From 45204f1bb3e138b145014ac003ab35106cb6bdd2 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Tue, 26 May 2026 14:01:43 +0300 Subject: [PATCH 1/6] [prf/dec]Implement prefill-decode for Llama Q8_0 --- .../InferenceCoreWithPrefillDecode.java | 10 +- ...adoVMMasterPlanWithBatchPrefillDecode.java | 175 ++++++++----- .../TornadoVMMasterPlanWithPrefillDecode.java | 53 ++-- .../TransformerBatchPrefillKernels.java | 226 +++++++++++++++++ .../layers/type/q8_0/LlamaQ8_0FFNLayers.java | 11 +- .../layers/type/q8_0/LogitsQ8_0Layer.java | 5 + .../q8_0/decode/LlamaQ8_0FFNLayersDecode.java | 55 +++++ .../LlamaQ8_0FFNLayersPrefillDecode.java | 47 ++++ .../q8_0/decode/LogitsQ8_0LayerDecode.java | 35 +++ .../prefill/LlamaQ8_0LayersBatchPrefill.java | 230 ++++++++++++++++++ 10 files changed, 759 insertions(+), 88 deletions(-) create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LogitsQ8_0LayerDecode.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java index 91bb6f79..dbd81297 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java @@ -153,9 +153,13 @@ public static void forwardTornadoVMPrefill(Model model, State state, int token, MemorySegment.copy(tokenEmbeddings, (long) token * configuration.dim() * bytes, state.embeddingX.getSegment(), 0, (long) configuration.dim() * bytes); } - case Q8_0 -> throw new 
UnsupportedOperationException( - // TODO Phase 4: implement Q8_0 GPU batched prefill kernels - "GPU prefill/decode path not yet implemented for Q8_0 weights"); + case Q8_0 -> { + MemorySegment tokenEmbeddings = weights.getTokenEmbeddingTable().asByteArray().getSegment(); + int blocksPerToken = (configuration.dim() + 31) / 32; + long bytesPerToken = (long) blocksPerToken * 34; + MemorySegment.copy(tokenEmbeddings, (long) token * bytesPerToken, + state.embeddingX.getSegment(), 0, bytesPerToken); + } default -> throw new IllegalArgumentException("Unsupported weight type: " + weights.getWeightType()); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java index 2984508e..2fc310e2 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java @@ -7,13 +7,20 @@ import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersDecode; import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16LayersBatchPrefill; import 
org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LogitsFP16LayerDecode; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LlamaQ8_0FFNLayersDecode; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LogitsQ8_0LayerDecode; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.prefill.LlamaQ8_0LayersBatchPrefill; +import uk.ac.manchester.tornado.api.types.arrays.ByteArray; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.KernelContext; @@ -102,7 +109,7 @@ public class TornadoVMMasterPlanWithBatchPrefillDecode implements TornadoVMMaste // ── Batch Prefill Activation graphs ───────────────────────────────────────────────────── - /** Graph 0: B×dim FP16 embeddings → FP32 wrapXBatch. */ + /** Graph 0 (FP16): B×dim FP16 embeddings → FP32 wrapXBatch. */ private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) { return new TaskGraph("prefillActivation") .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapXBatch) @@ -112,8 +119,17 @@ private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) { .persistOnDevice(state.wrapXBatch); } + /** Graph 0 (Q8_0): wrapXBatch pre-filled with FP32 by host; upload and persist. */ + private TaskGraph buildQ8_0BatchPrefillActivationGraph(KernelContext ctx) { + return new TaskGraph("prefillActivation") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapXBatch) + .task("batchPassthrough", TransformerBatchPrefillKernels::batchPassthrough, + ctx, state.wrapXBatch) + .persistOnDevice(state.wrapXBatch); + } + /** - * Graph N+1: single-token FP16 → FP32. + * Graph N+1: single-token embedding → FP32 wrapX, with KV-cache pass-through. * *

Receives the KV-cache device pointer from batch layer N via * {@code consumeFromDevice}, then re-emits it via {@code persistOnDevice} so @@ -121,73 +137,86 @@ private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) { * Both halves of the chain are required; without the re-persist the pointer is * not forwarded in interpreter (non-CUDA-graph) mode.

*/ - private TaskGraph buildDecodeActivationGraph(KernelContext ctx, String lastBatchLayerID) { - return new TaskGraph("decodeActivation") - .consumeFromDevice(lastBatchLayerID, state.wrapKeyCache, state.wrapValueCache) // KV pass-through - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) - .task("updateX", - TransformerComputeKernels::convertFP16toFP32, - ctx, (HalfFloatArray) state.embeddingX, state.wrapX) - // wrapX persisted for decode layer 0; wrapKeyCache/wrapValueCache - // re-persisted so updatePersistedObjectState() propagates the device - // pointer to decode layer 0's consumeFromDevice without CUDA graphs. - .persistOnDevice(state.wrapX, state.wrapKeyCache, state.wrapValueCache); + private TaskGraph buildDecodeActivationGraph(KernelContext ctx, String lastBatchLayerID, + GGMLType weightType) { + TaskGraph tg = new TaskGraph("decodeActivation") + .consumeFromDevice(lastBatchLayerID, state.wrapKeyCache, state.wrapValueCache) + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX); + if (weightType == GGMLType.Q8_0) { + tg.task("updateX", TransformerComputeKernels::convertQ8_0toFP32, + ctx, (ByteArray) state.embeddingX, state.wrapX); + } else { + tg.task("updateX", TransformerComputeKernels::convertFP16toFP32, + ctx, (HalfFloatArray) state.embeddingX, state.wrapX); + } + return tg.persistOnDevice(state.wrapX, state.wrapKeyCache, state.wrapValueCache); } - /** - * Creates the {@link TornadoExecutionPlan} for forward pass with *prefill in batches and separated decode*. 
- * - * TODO: support Q8_0 weights - * To implement this, consult how {@link TornadoVMMasterPlanStandard} uses the {@link QuantizationPlannerFactory} - */ @Override public TornadoExecutionPlan createExecutionPlan() { GGMLType weightType = model.weights().getWeightType(); - switch (weightType) { - case F16 -> { /* supported — continue below */ } - case Q8_0 -> throw new UnsupportedOperationException( - "Batched prefill/decode GPU path not yet implemented for Q8_0 weights"); - default -> throw new UnsupportedOperationException( + if (weightType != GGMLType.F16 && weightType != GGMLType.Q8_0) { + throw new UnsupportedOperationException( "Batched prefill/decode GPU path not supported for weight type: " + weightType); } LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); SchedulerType schedulerType = SchedulerDetectionService.determineSchedulerType(model); + List all = new ArrayList<>(2 * N + 3); - List all = new ArrayList<>(2 * N + 3); - - // [0] Batch prefill activation ──────────────────────────────────────────────── + // [0] Batch prefill activation ──────────────────────────────────────── KernelContext batchActCtx = new KernelContext(); - all.add(buildBatchPrefillActivationGraph(batchActCtx).snapshot()); - gridScheduler.addWorkerGrid("prefillActivation.updateX", - WorkerGridFactory.genericWorker(batchSize * config.dim(), 128)); + if (weightType == GGMLType.F16) { + all.add(buildBatchPrefillActivationGraph(batchActCtx).snapshot()); + gridScheduler.addWorkerGrid("prefillActivation.updateX", + WorkerGridFactory.genericWorker(batchSize * config.dim(), 128)); + } else { + all.add(buildQ8_0BatchPrefillActivationGraph(batchActCtx).snapshot()); + gridScheduler.addWorkerGrid("prefillActivation.batchPassthrough", + WorkerGridFactory.genericWorker(1, 1)); + } - // [1..N] Batch prefill layer graphs ─────────────────────────────────────────── - LlamaFP16LayersBatchPrefill batchLayers = - new LlamaFP16LayersBatchPrefill(state, weights, config, batchSize); - 
all.addAll(batchLayers.getLayerImmutableTaskGraphs()); - batchLayers.updateGridScheduler(gridScheduler); + // [1..N] Batch prefill layer graphs ─────────────────────────────────── + String lastBatchLayerID; + if (weightType == GGMLType.F16) { + LlamaFP16LayersBatchPrefill batchLayers = + new LlamaFP16LayersBatchPrefill(state, weights, config, batchSize); + all.addAll(batchLayers.getLayerImmutableTaskGraphs()); + batchLayers.updateGridScheduler(gridScheduler); + lastBatchLayerID = batchLayers.getLastLayerTaskGraphID(); + } else { + LlamaQ8_0LayersBatchPrefill batchLayers = + new LlamaQ8_0LayersBatchPrefill(state, weights, config, batchSize); + all.addAll(batchLayers.getLayerImmutableTaskGraphs()); + batchLayers.updateGridScheduler(gridScheduler); + lastBatchLayerID = batchLayers.getLastLayerTaskGraphID(); + } // [N+1] Decode activation (with KV-cache pass-through) ──────────────── KernelContext decodeActCtx = new KernelContext(); - all.add(buildDecodeActivationGraph(decodeActCtx, batchLayers.getLastLayerTaskGraphID()).snapshot()); + all.add(buildDecodeActivationGraph(decodeActCtx, lastBatchLayerID, weightType).snapshot()); gridScheduler.addWorkerGrid("decodeActivation.updateX", WorkerGridFactory.genericWorker(config.dim(), 128)); - // [N+2..2N+1] Decode layer graphs ──────────────────────────────────── - // Layer 0 uses consumeFromDevice for KV cache (no FIRST_EXECUTION upload). 
- LlamaFP16FFNLayersDecode decodeLayers = - new LlamaFP16FFNLayersDecode( - "decode", state, weights, config, schedulerType); + // [N+2..2N+1] Decode layer graphs ───────────────────────────────────── + AbstractFFNLayers decodeLayers; + if (weightType == GGMLType.F16) { + decodeLayers = new LlamaFP16FFNLayersDecode("decode", state, weights, config, schedulerType); + } else { + decodeLayers = new LlamaQ8_0FFNLayersDecode("decode", state, weights, config, schedulerType); + } all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs()); decodeLayers.updateGridScheduler(gridScheduler); // [2N+2] Logits ─────────────────────────────────────────────────────── - // LogitsFP16LayerDecode extends LogitsFP16Layer: adds consumeFromDevice(wrapKeyCache) - // at the start of the graph and persistOnDevice(wrapKeyCache) at the end, so the - // KV-cache pointer survives the logits → decode-activation boundary across tokens. - LogitsFP16LayerDecode logitsLayer = new LogitsFP16LayerDecode("logits", state, weights, config, - decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType); + AbstractLogitsLayer logitsLayer; + if (weightType == GGMLType.F16) { + logitsLayer = new LogitsFP16LayerDecode("logits", state, weights, config, + decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType); + } else { + logitsLayer = new LogitsQ8_0LayerDecode("logits", state, weights, config, + decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType); + } all.add(logitsLayer.getImmutableTaskGraph()); logitsLayer.updateGridScheduler(gridScheduler); @@ -222,15 +251,32 @@ public void forceCopyInReadOnlyData() { */ public void tornadoVMForwardBatchPrefill(int[] tokenIds, int startPos, Model model, int chunkSize) { LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); - MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(); - int bytes = Short.BYTES; - int dim = config.dim(); - - // Copy B embeddings into embeddingXBatch - for (int b = 0; b < chunkSize; b++) { - 
MemorySegment.copy(embTable, (long) tokenIds[b] * dim * bytes, - state.embeddingXBatch.getSegment(), (long) b * dim * bytes, - (long) dim * bytes); + int dim = config.dim(); + + if (weights.getWeightType() == GGMLType.Q8_0) { + // CPU dequantizes Q8_0 embeddings into wrapXBatch (FP32) + ByteArray embTable = weights.getTokenEmbeddingTable().asByteArray(); + int blockSize = 32; + int Q8_0_BLOCK_BYTES = 34; + int blocksPerRow = (dim + blockSize - 1) / blockSize; + for (int b = 0; b < chunkSize; b++) { + int tokenId = tokenIds[b]; + for (int j = 0; j < dim; j++) { + int blockByteOffset = (tokenId * blocksPerRow + j / blockSize) * Q8_0_BLOCK_BYTES; + float scale = embTable.getHalfFloat(blockByteOffset).getFloat32(); + float quant = embTable.get(blockByteOffset + 2 + j % blockSize); + state.wrapXBatch.set(b * dim + j, quant * scale); + } + } + } else { + // FP16: copy raw half-float bytes into embeddingXBatch + MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(); + int bytes = Short.BYTES; + for (int b = 0; b < chunkSize; b++) { + MemorySegment.copy(embTable, (long) tokenIds[b] * dim * bytes, + state.embeddingXBatch.getSegment(), (long) b * dim * bytes, + (long) dim * bytes); + } } state.batchStartPosHolder.set(0, startPos); @@ -258,12 +304,19 @@ public void tornadoVMForwardBatchPrefill(int[] tokenIds, int startPos, Model mod */ public FloatArray tornadoVMForwardDecode(int token, int position, Model model) { LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); - MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(); - int bytes = Short.BYTES; - int dim = config.dim(); - - MemorySegment.copy(embTable, (long) token * dim * bytes, - state.embeddingX.getSegment(), 0L, (long) dim * bytes); + int dim = config.dim(); + + if (weights.getWeightType() == GGMLType.Q8_0) { + ByteArray embTable = weights.getTokenEmbeddingTable().asByteArray(); + int blocksPerRow = (dim + 31) / 32; + long 
bytesPerToken = (long) blocksPerRow * 34; + MemorySegment.copy(embTable.getSegment(), (long) token * bytesPerToken, + state.embeddingX.getSegment(), 0L, bytesPerToken); + } else { + MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(); + MemorySegment.copy(embTable, (long) token * dim * Short.BYTES, + state.embeddingX.getSegment(), 0L, (long) dim * Short.BYTES); + } state.positionHolder.set(0, position); state.temp.clear(); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java index 517024f9..500882ea 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java @@ -11,8 +11,13 @@ import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersPrefillDecode; import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LogitsFP16LayerDecode; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LlamaQ8_0FFNLayersPrefillDecode; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LogitsQ8_0LayerDecode; +import uk.ac.manchester.tornado.api.types.arrays.ByteArray; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.KernelContext; @@ -98,7 +103,14 @@ public class TornadoVMMasterPlanWithPrefillDecode implements TornadoVMMasterPlan * The KV cache is not managed here — it is 
allocated on the first forward pass * by decode layer 0 via {@code FIRST_EXECUTION}.

*/ - private TaskGraph buildActivationGraph(KernelContext ctx) { + private TaskGraph buildActivationGraph(KernelContext ctx, GGMLType weightType) { + if (weightType == GGMLType.Q8_0) { + return new TaskGraph("decodeActivation") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) + .task("updateX", TransformerComputeKernels::convertQ8_0toFP32, + ctx, (ByteArray) state.embeddingX, state.wrapX) + .persistOnDevice(state.wrapX); + } return new TaskGraph("decodeActivation") .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) .task("updateX", TransformerComputeKernels::convertFP16toFP32, @@ -107,21 +119,11 @@ private TaskGraph buildActivationGraph(KernelContext ctx) { } // ── Plan construction ───────────────────────────────────────────────────── - /** - * Creates the {@link TornadoExecutionPlan} for forward pass with *prefill/decode separation*. - * Prefill is token-by-token but does not compute logits. - * - * TODO: support Q8_0 weights - * To implement this, consult how {@link TornadoVMMasterPlanStandard} uses the {@link QuantizationPlannerFactory} - */ @Override public TornadoExecutionPlan createExecutionPlan() { GGMLType weightType = model.weights().getWeightType(); - switch (weightType) { - case F16 -> { /* supported — continue below */ } - case Q8_0 -> throw new UnsupportedOperationException( - "Prefill/decode GPU path not yet implemented for Q8_0 weights"); - default -> throw new UnsupportedOperationException( + if (weightType != GGMLType.F16 && weightType != GGMLType.Q8_0) { + throw new UnsupportedOperationException( "Prefill/decode GPU path not supported for weight type: " + weightType); } @@ -132,24 +134,29 @@ public TornadoExecutionPlan createExecutionPlan() { // [0] Activation ────────────────────────────────────────────────────── KernelContext actCtx = new KernelContext(); - all.add(buildActivationGraph(actCtx).snapshot()); + all.add(buildActivationGraph(actCtx, weightType).snapshot()); 
gridScheduler.addWorkerGrid("decodeActivation.updateX", WorkerGridFactory.genericWorker(config.dim(), 128)); // [1..N] Decode layer graphs ────────────────────────────────────────── - // Layer 0: FIRST_EXECUTION for KV cache + consumeFromDevice("decodeActivation", wrapX). - // Layers 1+: consumeFromDevice with explicit predecessor names for interpreter mode. - LlamaFP16FFNLayersPrefillDecode decodeLayers = - new LlamaFP16FFNLayersPrefillDecode("decode", state, weights, config, schedulerType); + AbstractFFNLayers decodeLayers; + if (weightType == GGMLType.F16) { + decodeLayers = new LlamaFP16FFNLayersPrefillDecode("decode", state, weights, config, schedulerType); + } else { + decodeLayers = new LlamaQ8_0FFNLayersPrefillDecode("decode", state, weights, config, schedulerType); + } all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs()); decodeLayers.updateGridScheduler(gridScheduler); // [N+1] Logits ──────────────────────────────────────────────────────── - // LogitsFP16LayerDecode re-persists the KV cache so the pointer survives - // the logits → layer_0 KV-cache FIRST_EXECUTION boundary across decode tokens. 
- LogitsFP16LayerDecode logitsLayer = new LogitsFP16LayerDecode( - "logits", state, weights, config, - decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType); + AbstractLogitsLayer logitsLayer; + if (weightType == GGMLType.F16) { + logitsLayer = new LogitsFP16LayerDecode("logits", state, weights, config, + decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType); + } else { + logitsLayer = new LogitsQ8_0LayerDecode("logits", state, weights, config, + decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType); + } all.add(logitsLayer.getImmutableTaskGraph()); logitsLayer.updateGridScheduler(gridScheduler); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java index 9bba3860..6c688649 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java @@ -5,6 +5,7 @@ import uk.ac.manchester.tornado.api.types.HalfFloat; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import uk.ac.manchester.tornado.api.types.arrays.ByteArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; /** @@ -458,4 +459,229 @@ public static void batchedFusedRmsNormFFNGateUp(KernelContext context, wrapHbBatch.set(batchIdx * hiddenDim + rowIdx, silu * result3); } } + + // ── Q8_0 Batch Kernels ─────────────────────────────────────────────────── + + /** + * No-op kernel for Q8_0 batch activation graph. + * The host fills wrapXBatch with dequantized FP32 embeddings before execution. + * Worker: 1 global thread. + */ + public static void batchPassthrough(KernelContext context, FloatArray wrapXBatch) { + if (context.globalIdx == 0) { + wrapXBatch.set(0, wrapXBatch.get(0)); + } + } + + /** + * Applies RMS normalization to FP32 — Q8_0 variant. 
+ * Writes normalized FP32 to wrapXbBatch (reused as xb intermediate before QKV). + * Worker: B*dim global threads, localSize=256. + */ + public static void batchedRmsApplyFP32(KernelContext context, + FloatArray wrapXbBatch, + FloatArray wrapXBatch, + FloatArray rmsWeights, + FloatArray attnScaleBatch, + int dim) { + int gid = context.globalIdx; + int b = gid / dim; + int i = gid % dim; + wrapXbBatch.set(gid, rmsWeights.get(i) * attnScaleBatch.get(b) * wrapXBatch.get(gid)); + } + + /** + * Fused batched QKV projection with Q8_0 weight dequantization. + * Input wrapXbBatch is FP32 (written by batchedRmsApplyFP32). + * groupIdx = batchIdx * (dim + 2*kvDim) + rowIdx. + * Worker: B*(dim+2*kvDim) workgroups × localWorkGroupSize threads. + */ + public static void batchedFusedQKVMatmulQ8(KernelContext context, + FloatArray wrapXbBatch, + FloatArray wrapQBatch, + FloatArray wrapKBatch, + FloatArray wrapVBatch, + ByteArray wq, + ByteArray wk, + ByteArray wv, + int dim, int kvDim, + int localWorkGroupSize) { + int groupId = context.groupIdx; + int localId = context.localIdx; + int totalRows = dim + 2 * kvDim; + int batchIdx = groupId / totalRows; + int rowIdx = groupId % totalRows; + int inputOff = batchIdx * dim; + + int blockSize = 32; + int Q8_0_BLOCK_BYTES = 34; + int blocksPerRow = (dim + blockSize - 1) / blockSize; + + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + if (rowIdx < dim) { + int rowBlockOffset = rowIdx * blocksPerRow; + float partial = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + int blockByteOffset = (rowBlockOffset + j / blockSize) * Q8_0_BLOCK_BYTES; + float scale = wq.getHalfFloat(blockByteOffset).getFloat32(); + float quant = wq.get(blockByteOffset + 2 + j % blockSize); + partial += quant * scale * wrapXbBatch.get(inputOff + j); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId 
+ s]; + context.localBarrier(); + } + if (localId == 0) wrapQBatch.set(batchIdx * dim + rowIdx, localSum[0]); + + } else if (rowIdx < dim + kvDim) { + int kRow = rowIdx - dim; + int rowBlockOffset = kRow * blocksPerRow; + float partial = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + int blockByteOffset = (rowBlockOffset + j / blockSize) * Q8_0_BLOCK_BYTES; + float scale = wk.getHalfFloat(blockByteOffset).getFloat32(); + float quant = wk.get(blockByteOffset + 2 + j % blockSize); + partial += quant * scale * wrapXbBatch.get(inputOff + j); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapKBatch.set(batchIdx * kvDim + kRow, localSum[0]); + + } else { + int vRow = rowIdx - dim - kvDim; + int rowBlockOffset = vRow * blocksPerRow; + float partial = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + int blockByteOffset = (rowBlockOffset + j / blockSize) * Q8_0_BLOCK_BYTES; + float scale = wv.getHalfFloat(blockByteOffset).getFloat32(); + float quant = wv.get(blockByteOffset + 2 + j % blockSize); + partial += quant * scale * wrapXbBatch.get(inputOff + j); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapVBatch.set(batchIdx * kvDim + vRow, localSum[0]); + } + } + + /** + * Batched matrix-vector multiply with residual add (Q8_0 weights). + * Used for attention output (Wo) and FFN down (W2) projections. + * groupIdx = batchIdx * d + rowIdx. + * Worker: B*d workgroups × localWorkGroupSize threads. 
+ */ + public static void batchedMatVecWithResidualQ8(KernelContext context, + FloatArray inputBatch, + FloatArray outputBatch, + ByteArray w, + int n, int d, + int localWorkGroupSize) { + int groupId = context.groupIdx; + int localId = context.localIdx; + int batchIdx = groupId / d; + int rowIdx = groupId % d; + + int blockSize = 32; + int Q8_0_BLOCK_BYTES = 34; + int blocksPerRow = (n + blockSize - 1) / blockSize; + int rowBlockOffset = rowIdx * blocksPerRow; + int inputOff = batchIdx * n; + + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + float partial = 0.0f; + for (int j = localId; j < n; j += localWorkGroupSize) { + int blockByteOffset = (rowBlockOffset + j / blockSize) * Q8_0_BLOCK_BYTES; + float scale = w.getHalfFloat(blockByteOffset).getFloat32(); + float quant = w.get(blockByteOffset + 2 + j % blockSize); + partial += quant * scale * inputBatch.get(inputOff + j); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) { + int outIdx = batchIdx * d + rowIdx; + outputBatch.set(outIdx, outputBatch.get(outIdx) + localSum[0]); + } + } + + /** + * Batched fused RMS-apply + W1/W3 gate-up projections + SiLU + GLU (Q8_0 weights). + * groupIdx = batchIdx * hiddenDim + rowIdx. + * Worker: B*hiddenDim workgroups × localWorkGroupSize threads. 
+ */ + public static void batchedFusedRmsNormFFNGateUpQ8(KernelContext context, + FloatArray wrapXBatch, + FloatArray wrapHbBatch, + FloatArray rmsFFNWeights, + FloatArray ffnScaleBatch, + ByteArray w1, + ByteArray w3, + int dim, int hiddenDim, + int localWorkGroupSize) { + int groupId = context.groupIdx; + int localId = context.localIdx; + int batchIdx = groupId / hiddenDim; + int rowIdx = groupId % hiddenDim; + + float scale = ffnScaleBatch.get(batchIdx); + int inputOff = batchIdx * dim; + + int blockSize = 32; + int Q8_0_BLOCK_BYTES = 34; + int blocksPerRow = (dim + blockSize - 1) / blockSize; + int rowBlockOffset = rowIdx * blocksPerRow; + + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + float sum1 = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + int blockByteOffset = (rowBlockOffset + j / blockSize) * Q8_0_BLOCK_BYTES; + float w1Scale = w1.getHalfFloat(blockByteOffset).getFloat32(); + float w1Quant = w1.get(blockByteOffset + 2 + j % blockSize); + float normed = rmsFFNWeights.get(j) * scale * wrapXBatch.get(inputOff + j); + sum1 += w1Quant * w1Scale * normed; + } + localSum[localId] = sum1; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + float result1 = localSum[0]; + + float sum3 = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + int blockByteOffset = (rowBlockOffset + j / blockSize) * Q8_0_BLOCK_BYTES; + float w3Scale = w3.getHalfFloat(blockByteOffset).getFloat32(); + float w3Quant = w3.get(blockByteOffset + 2 + j % blockSize); + float normed = rmsFFNWeights.get(j) * scale * wrapXBatch.get(inputOff + j); + sum3 += w3Quant * w3Scale * normed; + } + localSum[localId] = sum3; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + float result3 = 
localSum[0]; + + if (localId == 0) { + float silu = result1 / (1.0f + TornadoMath.exp(-result1)); + wrapHbBatch.set(batchIdx * hiddenDim + rowIdx, silu * result3); + } + } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java index 1f33f090..2cfacdc0 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java @@ -113,7 +113,12 @@ protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); // === Data Setup === - unifiedLayer.consumeFromDevice(state.wrapX); + String wrapXSrc = predecessorGraphName(layerIndex); + if (wrapXSrc != null) { + unifiedLayer.consumeFromDevice(wrapXSrc, state.wrapX); + } else { + unifiedLayer.consumeFromDevice(state.wrapX); + } unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // Copy-in weights per layer for batched-layered layout (Q8 format) weights.rms_att_weightLayered[layerIndex].asFloatArray(), @@ -224,6 +229,10 @@ protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { return unifiedLayer; } + protected String predecessorGraphName(int layerIndex) { + return null; + } + protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { // First layer: Transfer initial data to device (one-time transfer) if (layerIndex == 0) { diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java index ed0d6cfc..1a1313e2 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java @@ -22,12 +22,16 @@ public LogitsQ8_0Layer(String name, 
State state, Weights weights, Configuration super(name, state, weights, config, lastTaskGraphID, schedulerType); } + protected void configureAdditionalConsumes(TaskGraph logits) {} + protected void configureAdditionalPersists(TaskGraph logits) {} + // @formatter:off @Override protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { var logits = new TaskGraph("logits"); // === Data Setup === + configureAdditionalConsumes(logits); logits.consumeFromDevice(lastTaskGraphID, state.wrapX); logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits); logits.transferToDevice(DataTransferMode.FIRST_EXECUTION, @@ -74,6 +78,7 @@ protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration c LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); + configureAdditionalPersists(logits); return logits; } // @formatter:on diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java new file mode 100644 index 00000000..e159c99b --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java @@ -0,0 +1,55 @@ +package org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LlamaQ8_0FFNLayers; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +/** + * Decode FFN layers for the unified batched prefill-decode plan + * ({@link 
org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode}). + * + *

Layer 0 consumes the KV cache from device (passed through by the decode activation + * graph, which relays it from the last batch prefill layer). No FIRST_EXECUTION allocation + * for the KV cache — it was already allocated in the batch prefill phase.

+ */ +public class LlamaQ8_0FFNLayersDecode extends LlamaQ8_0FFNLayers { + + public LlamaQ8_0FFNLayersDecode(String taskGraph, LlamaState state, + LlamaTornadoWeights weights, LlamaConfiguration config, + SchedulerType schedulerType) { + super(taskGraph, state, weights, config, schedulerType); + } + + @Override + protected String predecessorGraphName(int layerIndex) { + return (layerIndex == 0) ? "decodeActivation" : "layer_" + (layerIndex - 1); + } + + @Override + protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) { + if (layerIndex == 0) { + layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + state.positionHolder, state.temp, state.tempFFN); + layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapAtt, state.wrapHb); + layer.consumeFromDevice("decodeActivation", state.wrapKeyCache, state.wrapValueCache); + } else { + String pred = "layer_" + (layerIndex - 1); + layer.consumeFromDevice(pred, + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapKeyCache, state.wrapValueCache, + state.wrapAtt, state.wrapHb, + state.positionHolder); + } + return layer; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java new file mode 100644 index 00000000..a1f02ada --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java @@ -0,0 +1,47 @@ +package org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import 
org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LlamaQ8_0FFNLayers; +import uk.ac.manchester.tornado.api.TaskGraph; + +/** + * Decode FFN layers for the single-token prefill/decode plan + * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode}). + * + *

Layer 0 delegates to {@link LlamaQ8_0FFNLayers#configureLayerDataTransfers} which + * includes {@code FIRST_EXECUTION} for {@code wrapKeyCache} and {@code wrapValueCache}, + * allocating the KV cache on the very first forward pass. Layers 1+ use explicit + * predecessor names for all consumed objects, required by TornadoVM's interpreter mode.

+ */ +public class LlamaQ8_0FFNLayersPrefillDecode extends LlamaQ8_0FFNLayers { + + public LlamaQ8_0FFNLayersPrefillDecode(String taskGraph, LlamaState state, + LlamaTornadoWeights weights, LlamaConfiguration config, + SchedulerType schedulerType) { + super(taskGraph, state, weights, config, schedulerType); + } + + @Override + protected String predecessorGraphName(int layerIndex) { + return (layerIndex == 0) ? "decodeActivation" : "layer_" + (layerIndex - 1); + } + + @Override + protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) { + if (layerIndex == 0) { + return super.configureLayerDataTransfers(layer, 0); + } + String pred = "layer_" + (layerIndex - 1); + layer.consumeFromDevice(pred, + context, + state.wrapXb, state.wrapXb2, + state.wrapQ, state.wrapK, state.wrapV, + state.wrapKeyCache, state.wrapValueCache, + state.wrapAtt, state.wrapHb, + state.positionHolder); + return layer; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LogitsQ8_0LayerDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LogitsQ8_0LayerDecode.java new file mode 100644 index 00000000..054b7f7f --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LogitsQ8_0LayerDecode.java @@ -0,0 +1,35 @@ +package org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; +import uk.ac.manchester.tornado.api.TaskGraph; + +/** + * Logits layer for the unified batched prefill-decode plan (Q8_0). + * + *

Extends {@link LogitsQ8_0Layer} with KV-cache pass-through so the device pointers for + * {@code wrapKeyCache} and {@code wrapValueCache} survive the logits → decode-activation + * boundary between decode tokens. Without the pass-through, the KV-cache pointer is absent + * from the logits persisted set, cleared to null, and the first decode layer crashes with + * an NPE in {@code executeAlloc}.

+ */ +public class LogitsQ8_0LayerDecode extends LogitsQ8_0Layer { + + public LogitsQ8_0LayerDecode(String name, State state, Weights weights, Configuration config, + String lastTaskGraphID, SchedulerType schedulerType) { + super(name, state, weights, config, lastTaskGraphID, schedulerType); + } + + @Override + protected void configureAdditionalConsumes(TaskGraph logits) { + logits.consumeFromDevice(lastTaskGraphID, state.wrapKeyCache, state.wrapValueCache); + } + + @Override + protected void configureAdditionalPersists(TaskGraph logits) { + logits.persistOnDevice(state.wrapKeyCache, state.wrapValueCache); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java new file mode 100644 index 00000000..79ea9bc2 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java @@ -0,0 +1,230 @@ +package org.beehive.gpullama3.tornadovm.layers.type.q8_0.prefill; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +import java.util.List; +import java.util.stream.IntStream; + +/** + * Prefill FFN layers with batching for the unified batched prefill-decode plan (Q8_0). + * + *

Mirrors {@link org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16LayersBatchPrefill} + * but uses Q8_0 kernels with inline dequantization. Key differences from the FP16 path:

+ *
    + *
  • {@code wrapXBatch} is filled with dequantized FP32 embeddings by the host before + * the activation graph runs (no on-device FP16→FP32 conversion).
  • + *
  • {@code wrapXbBatch} (FP32) is reused as the normalized xb intermediate: written + * by {@code batchedRmsApplyFP32}, read by {@code batchedFusedQKVMatmulQ8}, then + * overwritten by flash attention output.
  • + *
  • {@code wrapXbFP16Batch} is not used.
  • + *
  • Weight matrices are {@code ByteArray} (Q8_0 format).
  • + *
+ */ +public class LlamaQ8_0LayersBatchPrefill { + + static final int LOCAL_WORK_GROUP_SIZE = 32; + + private final LlamaState state; + private final LlamaTornadoWeights weights; + private final LlamaConfiguration config; + private final KernelContext context = new KernelContext(); + private final int batchSize; + private final List layerITGs; + private String lastLayerTaskGraphID; + + public LlamaQ8_0LayersBatchPrefill(LlamaState state, LlamaTornadoWeights weights, + LlamaConfiguration config, int batchSize) { + this.state = state; + this.weights = weights; + this.config = config; + this.batchSize = batchSize; + this.layerITGs = IntStream.range(0, config.numberOfLayers()) + .mapToObj(this::createBatchPrefillLayerTaskGraph) + .map(TaskGraph::snapshot) + .toList(); + } + + // @formatter:off + private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) { + String graphName = "batchPrefillLayer_" + layerIndex; + if (layerIndex == config.numberOfLayers() - 1) lastLayerTaskGraphID = graphName; + + TaskGraph layer = new TaskGraph(graphName); + + // ── Data Transfers ───────────────────────────────────────────────────── + if (layerIndex == 0) { + // batchStartPosHolder is set by host before each chunk → EVERY_EXECUTION + layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.batchStartPosHolder); + // Allocate GPU-side batch intermediates once. + // wrapXBatch is filled with dequantized FP32 by the host, persisted by prefillActivation. 
+ layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + state.attnScaleBatch, state.ffnScaleBatch, + state.wrapXbBatch, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + state.wrapHbBatch, + state.wrapKeyCache, state.wrapValueCache); + layer.consumeFromDevice("prefillActivation", state.wrapXBatch); + } else { + String pred = "batchPrefillLayer_" + (layerIndex - 1); + layer.consumeFromDevice(pred, + context, + state.wrapXBatch, + state.wrapXbBatch, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + state.wrapHbBatch, + state.wrapKeyCache, state.wrapValueCache, + state.batchStartPosHolder, + state.attnScaleBatch, state.ffnScaleBatch); + } + + // Per-layer weights: upload once (Q8_0 format) + layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + weights.wqLayered[layerIndex].asByteArray(), + weights.wkLayered[layerIndex].asByteArray(), + weights.wvLayered[layerIndex].asByteArray(), + weights.woLayered[layerIndex].asByteArray(), + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].asByteArray(), + weights.w2Layered[layerIndex].asByteArray(), + weights.w3Layered[layerIndex].asByteArray()); + + int dim = config.dim(); + int kvDim = config.kvDim(); + int hidDim = config.hiddenDim(); + + // ── Attention Block ──────────────────────────────────────────────────── + layer.task("batch_attn_rms", + TransformerBatchPrefillKernels::batchedRmsReduce, + context, state.wrapXBatch, state.attnScaleBatch, + dim, config.rmsNormEps()); + + // Writes FP32 normalized xb into wrapXbBatch (reused later by flash attention) + layer.task("batch_attn_rms_apply", + TransformerBatchPrefillKernels::batchedRmsApplyFP32, + context, state.wrapXbBatch, state.wrapXBatch, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + state.attnScaleBatch, dim); + + layer.task("batch_qkv", + TransformerBatchPrefillKernels::batchedFusedQKVMatmulQ8, + context, + 
state.wrapXbBatch, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + weights.wqLayered[layerIndex].asByteArray(), + weights.wkLayered[layerIndex].asByteArray(), + weights.wvLayered[layerIndex].asByteArray(), + dim, kvDim, LOCAL_WORK_GROUP_SIZE); + + layer.task("batch_rope_kv", + TransformerBatchPrefillKernels::batchedRopeWithKVCache, + context, state.batchStartPosHolder, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + state.wrapKeyCache, state.wrapValueCache, + kvDim, config.headSize(), layerIndex, config.contextLength(), dim); + + // Overwrites wrapXbBatch with attention output + layer.task("batch_attention", + TransformerBatchPrefillKernels::batchedFlashAttention, + context, state.batchStartPosHolder, + state.wrapQBatch, state.wrapKeyCache, state.wrapValueCache, + state.wrapXbBatch, + config.numberOfHeads(), config.headSize(), + kvDim, config.kvMul(), layerIndex, config.contextLength(), dim); + + layer.task("batch_attn_out", + TransformerBatchPrefillKernels::batchedMatVecWithResidualQ8, + context, state.wrapXbBatch, state.wrapXBatch, + weights.woLayered[layerIndex].asByteArray(), + dim, dim, LOCAL_WORK_GROUP_SIZE); + + // ── FFN Block ────────────────────────────────────────────────────────── + layer.task("batch_ffn_rms", + TransformerBatchPrefillKernels::batchedFFNRmsReduce, + context, state.wrapXBatch, state.ffnScaleBatch, + dim, config.rmsNormEps()); + + layer.task("batch_ffn_gate_up", + TransformerBatchPrefillKernels::batchedFusedRmsNormFFNGateUpQ8, + context, state.wrapXBatch, state.wrapHbBatch, + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + state.ffnScaleBatch, + weights.w1Layered[layerIndex].asByteArray(), + weights.w3Layered[layerIndex].asByteArray(), + dim, hidDim, LOCAL_WORK_GROUP_SIZE); + + layer.task("batch_ffn_down", + TransformerBatchPrefillKernels::batchedMatVecWithResidualQ8, + context, state.wrapHbBatch, state.wrapXBatch, + weights.w2Layered[layerIndex].asByteArray(), + hidDim, dim, LOCAL_WORK_GROUP_SIZE); + + 
layer.persistOnDevice(state.wrapXBatch, state.wrapKeyCache, state.wrapValueCache); + + return layer; + } + // @formatter:on + + public void updateGridScheduler(GridScheduler scheduler) { + int dim = config.dim(); + int kvDim = config.kvDim(); + int hidDim = config.hiddenDim(); + int nHeads = config.numberOfHeads(); + int headSz = config.headSize(); + + WorkerGrid rmsWorker = WorkerGridFactory.genericWorker(batchSize, 1); + WorkerGrid rmsApplyWorker = WorkerGridFactory.genericWorker(batchSize * dim, 256); + int qkvRows = dim + 2 * kvDim; + WorkerGrid qkvWorker = WorkerGridFactory.genericWorker( + batchSize * qkvRows * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE); + int ropeGlobal = batchSize * (dim / 2); + int ropeLocal = Math.min(512, ropeGlobal); + while (ropeLocal > 1 && ropeGlobal % ropeLocal != 0) ropeLocal--; + WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(ropeGlobal, ropeLocal); + int optLocal = findOptimalLocalSize(headSz); + WorkerGrid attnWorker = WorkerGridFactory.genericWorker( + batchSize * nHeads * optLocal, optLocal); + WorkerGrid matVecDimWorker = WorkerGridFactory.genericWorker( + batchSize * dim * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE); + WorkerGrid matVecHidWorker = WorkerGridFactory.genericWorker( + batchSize * hidDim * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE); + + for (int i = 0; i < config.numberOfLayers(); i++) { + String p = "batchPrefillLayer_" + i + "."; + scheduler.addWorkerGrid(p + "batch_attn_rms", rmsWorker); + scheduler.addWorkerGrid(p + "batch_attn_rms_apply", rmsApplyWorker); + scheduler.addWorkerGrid(p + "batch_qkv", qkvWorker); + scheduler.addWorkerGrid(p + "batch_rope_kv", ropeWorker); + scheduler.addWorkerGrid(p + "batch_attention", attnWorker); + scheduler.addWorkerGrid(p + "batch_attn_out", matVecDimWorker); + scheduler.addWorkerGrid(p + "batch_ffn_rms", rmsWorker); + scheduler.addWorkerGrid(p + "batch_ffn_gate_up", matVecHidWorker); + scheduler.addWorkerGrid(p + "batch_ffn_down", matVecDimWorker); + } + 
} + + private static int findOptimalLocalSize(int size) { + int optimal = Math.min(size, 64); + if (size % optimal != 0) { + for (int s = 64; s >= 1; s--) { + if (size % s == 0) { optimal = s; break; } + } + } + return optimal; + } + + public List getLayerImmutableTaskGraphs() { return layerITGs; } + public String getLastLayerTaskGraphID() { return lastLayerTaskGraphID; } + public KernelContext getContext() { return context; } +} From 8ebf91f2dd437d1dd6466e3b2c4bb4b37ba7fbfe Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Thu, 28 May 2026 15:14:01 +0300 Subject: [PATCH 2/6] Reorganize TornadoVM execution planning and improve naming conventions Reorganize TornadoVM execution planning around forward modes, model families, and quantization-specific components. --- .../gpullama3/inference/InferenceCore.java | 2 +- .../InferenceCoreBatchPrefillDecode.java | 14 +- .../InferenceCoreWithPrefillDecode.java | 6 +- ...InferenceEngineWithBatchPrefillDecode.java | 6 +- .../InferenceEngineWithPrefillDecode.java | 6 +- .../tornadovm/TornadoVMMasterPlan.java | 24 +- ...TornadoVMMasterPlanBatchPrefillDecode.java | 165 ++++++++ .../TornadoVMMasterPlanPrefillDecode.java | 166 ++++++++ ...va => TornadoVMMasterPlanSingleToken.java} | 59 +-- ...adoVMMasterPlanWithBatchPrefillDecode.java | 360 ------------------ .../TornadoVMMasterPlanWithPrefillDecode.java | 244 ------------ .../TransformerBatchPrefillKernels.java | 2 +- .../layerplanner/GenericLayerPlanner.java | 14 - .../QuantizationPlannerFactory.java | 97 ----- .../layerplanner/QuantizedLayerPlanner.java | 103 ----- .../model/fp16/DevstralFP16LayerPlanner.java | 20 - .../model/fp16/FP16LayerPlanner.java | 25 -- .../model/fp16/GraniteFP16LayerPlanner.java | 20 - .../model/fp16/LlamaFP16LayerPlanner.java | 20 - .../model/fp16/MistralFP16LayerPlanner.java | 20 - .../model/fp16/Phi3FP16LayerPlanner.java | 27 -- .../model/fp16/Qwen2FP16LayerPlanner.java | 27 -- .../model/fp16/Qwen3FP16LayerPlanner.java | 27 -- 
.../model/q8_0/DevstralQ8_0LayerPlanner.java | 20 - .../model/q8_0/GraniteQ8_0LayerPlanner.java | 20 - .../model/q8_0/LlamaQ8_0LayerPlanner.java | 20 - .../model/q8_0/MistralQ8_0LayerPlanner.java | 20 - .../model/q8_0/Phi3Q8_0LayerPlanner.java | 28 -- .../model/q8_0/Q8_0LayerPlanner.java | 26 -- .../model/q8_0/Qwen2Q8_0LayerPlanner.java | 28 -- .../model/q8_0/Qwen3Q8_0LayerPlanner.java | 28 -- .../layerplanner/strategy/SchedulerType.java | 5 - .../tornadovm/layers/AbstractFFNLayers.java | 4 +- .../tornadovm/layers/AbstractLogitsLayer.java | 2 +- .../tornadovm/layers/Activation.java | 4 +- .../tornadovm/layers/ActivationGraph.java | 15 + ...atchPrefillTransformerLayerTaskGraphs.java | 17 + .../layers/TransformerLayerTaskGraphs.java | 17 + .../type/fp16/DevstralFP16FFNLayers.java | 4 +- .../type/fp16/GraniteFP16FFNLayers.java | 4 +- .../layers/type/fp16/LlamaFP16FFNLayers.java | 4 +- .../layers/type/fp16/LogitsFP16Layer.java | 4 +- .../type/fp16/LogitsGraniteFP16Layer.java | 2 +- .../type/fp16/MistralFP16FFNLayers.java | 4 +- .../layers/type/fp16/Phi3FP16FFNLayers.java | 4 +- .../layers/type/fp16/Qwen2FP16FFNLayers.java | 4 +- .../layers/type/fp16/Qwen3FP16FFNLayers.java | 4 +- .../fp16/decode/LlamaFP16FFNLayersDecode.java | 4 +- .../LlamaFP16FFNLayersPrefillDecode.java | 4 +- .../fp16/decode/LogitsFP16LayerDecode.java | 4 +- .../prefill/LlamaFP16LayersBatchPrefill.java | 7 +- .../type/q8_0/DevstralQ8_0FFNLayers.java | 4 +- .../type/q8_0/GraniteQ8_0FFNLayers.java | 4 +- .../layers/type/q8_0/LlamaQ8_0FFNLayers.java | 4 +- .../type/q8_0/LogitsGraniteQ8_0Layer.java | 2 +- .../layers/type/q8_0/LogitsQ8_0Layer.java | 4 +- .../type/q8_0/MistralQ8_0FFNLayers.java | 4 +- .../layers/type/q8_0/Phi3Q8_0FFNLayers.java | 4 +- .../layers/type/q8_0/Qwen2Q8_0FFNLayers.java | 4 +- .../layers/type/q8_0/Qwen3Q8_0FFNLayers.java | 4 +- .../q8_0/decode/LlamaQ8_0FFNLayersDecode.java | 4 +- .../LlamaQ8_0FFNLayersPrefillDecode.java | 4 +- .../q8_0/decode/LogitsQ8_0LayerDecode.java | 2 +- 
.../prefill/LlamaQ8_0LayersBatchPrefill.java | 5 +- .../plan/BatchPrefillDecodeForwardPlan.java | 75 ++++ .../tornadovm/plan/ExecutionMode.java | 16 + .../gpullama3/tornadovm/plan/ForwardPlan.java | 32 ++ .../tornadovm/plan/ForwardPlanFactory.java | 207 ++++++++++ .../plan/PrefillDecodeForwardPlan.java | 57 +++ .../plan/SingleTokenForwardPlan.java | 54 +++ ...tchPrefillDecodeForwardPlanComponents.java | 25 ++ .../plan/components/EmbeddingPreparer.java | 18 + .../PrefillDecodeForwardPlanComponents.java | 20 + .../SingleTokenForwardPlanComponents.java | 21 + .../activation/BatchDecodeActivation.java | 62 +++ .../activation/BatchPrefillActivation.java | 72 ++++ .../fp16/DevstralFP16PlanComponents.java | 42 ++ .../fp16/GraniteFP16PlanComponents.java | 42 ++ .../fp16/LlamaFP16PlanComponents.java | 120 ++++++ .../fp16/MistralFP16PlanComponents.java | 42 ++ .../fp16/Phi3FP16PlanComponents.java | 42 ++ .../fp16/Qwen2FP16PlanComponents.java | 42 ++ .../fp16/Qwen3FP16PlanComponents.java | 42 ++ .../q8_0/DevstralQ8_0PlanComponents.java | 42 ++ .../q8_0/GraniteQ8_0PlanComponents.java | 42 ++ .../q8_0/LlamaQ8_0PlanComponents.java | 131 +++++++ .../q8_0/MistralQ8_0PlanComponents.java | 42 ++ .../q8_0/Phi3Q8_0PlanComponents.java | 42 ++ .../q8_0/Qwen2Q8_0PlanComponents.java | 42 ++ .../q8_0/Qwen3Q8_0PlanComponents.java | 42 ++ ...chPrefillDecodeForwardTaskGraphLayout.java | 21 + .../PrefillDecodeForwardTaskGraphLayout.java | 17 + .../SingleTokenForwardTaskGraphLayout.java | 17 + .../SchedulerDetectionService.java | 5 +- .../tornadovm/scheduling/SchedulerType.java | 5 + .../WorkerGridFactory.java | 3 +- 96 files changed, 1965 insertions(+), 1327 deletions(-) create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefillDecode.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanPrefillDecode.java rename src/main/java/org/beehive/gpullama3/tornadovm/{TornadoVMMasterPlanStandard.java => 
TornadoVMMasterPlanSingleToken.java} (51%) delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/GenericLayerPlanner.java delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizedLayerPlanner.java delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/DevstralFP16LayerPlanner.java delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/FP16LayerPlanner.java delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/MistralFP16LayerPlanner.java delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Phi3FP16LayerPlanner.java delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen2FP16LayerPlanner.java delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen3FP16LayerPlanner.java delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/DevstralQ8_0LayerPlanner.java delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/GraniteQ8_0LayerPlanner.java delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/LlamaQ8_0LayerPlanner.java delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/MistralQ8_0LayerPlanner.java delete mode 100644 
src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Q8_0LayerPlanner.java delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen2Q8_0LayerPlanner.java delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen3Q8_0LayerPlanner.java delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/strategy/SchedulerType.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGraph.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/BatchPrefillTransformerLayerTaskGraphs.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/TransformerLayerTaskGraphs.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/ExecutionMode.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlan.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlanFactory.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/SingleTokenForwardPlan.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/BatchPrefillDecodeForwardPlanComponents.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/EmbeddingPreparer.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java create mode 100644 
src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/BatchPrefillDecodeForwardTaskGraphLayout.java create mode 100644 
src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/PrefillDecodeForwardTaskGraphLayout.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/SingleTokenForwardTaskGraphLayout.java rename src/main/java/org/beehive/gpullama3/tornadovm/{layerplanner/strategy => scheduling}/SchedulerDetectionService.java (93%) create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/scheduling/SchedulerType.java rename src/main/java/org/beehive/gpullama3/tornadovm/{layerplanner => scheduling}/WorkerGridFactory.java (98%) diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java index 9beade35..67e862ce 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java @@ -814,7 +814,7 @@ public static FloatArray forwardTornadoVM(Model model, State state, int token, i default -> throw new IllegalArgumentException("Unsupported weight type: " + weights.getWeightType()); } - return tornadoVMMasterPlan.tornadoVMForwardExecuteLayered(position); + return tornadoVMMasterPlan.tornadoVMExecuteForward(position); } } diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreBatchPrefillDecode.java index d3c74599..0e89592d 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreBatchPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreBatchPrefillDecode.java @@ -7,7 +7,7 @@ import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; import org.beehive.gpullama3.tensor.standard.FloatTensor; -import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode; import 
uk.ac.manchester.tornado.api.types.arrays.FloatArray; /** @@ -21,9 +21,9 @@ * prompt tokens in one pass using batch matmul, avoiding redundant weight * traversals. Only the KV cache is populated; logits are intentionally omitted. *
  • {@link #batchForwardTornadoVMPrefill} — GPU batch prefill: delegates the chunk - * to {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardBatchPrefill}.
  • + * to {@link TornadoVMMasterPlanBatchPrefillDecode#tornadoVMForwardBatchPrefill}. *
  • {@link #forwardTornadoVMDecode} — GPU decode: delegates a single decode step to - * {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardDecode}, which + * {@link TornadoVMMasterPlanBatchPrefillDecode#tornadoVMForwardDecode}, which * handles the embedding copy and runs the full decode + logits graphs.
  • * */ @@ -165,7 +165,7 @@ public static void batchForwardJavaPrefill(Model model, State state, int[] token * GPU batched prefill forward pass (Phase 4). * *

    Delegates the full chunk to - * {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardBatchPrefill}, + * {@link TornadoVMMasterPlanBatchPrefillDecode#tornadoVMForwardBatchPrefill}, * which handles embedding lookup and GPU execution internally.

    * * @param model the LLaMA model @@ -175,7 +175,7 @@ public static void batchForwardJavaPrefill(Model model, State state, int[] token * @param plan the batched prefill/decode GPU plan */ public static void batchForwardTornadoVMPrefill(Model model, int[] tokens, int startPos, int chunkSize, - TornadoVMMasterPlanWithBatchPrefillDecode plan) { + TornadoVMMasterPlanBatchPrefillDecode plan) { plan.tornadoVMForwardBatchPrefill(tokens, startPos, model, chunkSize); } @@ -183,7 +183,7 @@ public static void batchForwardTornadoVMPrefill(Model model, int[] tokens, int s * GPU decode forward pass (Phase 4). * *

    Delegates a single-token decode step to - * {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardDecode}, + * {@link TornadoVMMasterPlanBatchPrefillDecode#tornadoVMForwardDecode}, * which copies the token embedding and runs the decode + logits graphs.

    * * @param model the LLaMA model @@ -193,7 +193,7 @@ public static void batchForwardTornadoVMPrefill(Model model, int[] tokens, int s * @return logits array for token sampling */ public static FloatArray forwardTornadoVMDecode(Model model, int token, int position, - TornadoVMMasterPlanWithBatchPrefillDecode plan) { + TornadoVMMasterPlanBatchPrefillDecode plan) { return plan.tornadoVMForwardDecode(token, position, model); } } diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java index dbd81297..9c812b10 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java @@ -7,7 +7,7 @@ import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tensor.standard.FloatTensor; -import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanPrefillDecode; import java.lang.foreign.MemorySegment; @@ -131,7 +131,7 @@ public static void forwardJavaPrefill(Model model, State state, int token, int p * *

    Copies the token embedding into {@code state.embeddingX} (same as * {@link InferenceCore#forwardTornadoVM}) then delegates to - * {@link TornadoVMMasterPlanWithPrefillDecode#tornadoVMForwardPrefill}, + * {@link TornadoVMMasterPlanPrefillDecode#tornadoVMForwardPrefill}, * which executes preprocessing + layer graphs but skips the logits graph.

    * * @param model the LLaMA model (must carry {@link TornadoWeights}, FP16 only) @@ -142,7 +142,7 @@ public static void forwardJavaPrefill(Model model, State state, int token, int p * @throws UnsupportedOperationException if the model uses Q8_0 weights */ public static void forwardTornadoVMPrefill(Model model, State state, int token, int position, - TornadoVMMasterPlanWithPrefillDecode prefillPlan) { + TornadoVMMasterPlanPrefillDecode prefillPlan) { final Configuration configuration = model.configuration(); final TornadoWeights weights = (TornadoWeights) model.weights(); diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java index 06dbe256..5f617bda 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java @@ -7,7 +7,7 @@ import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tokenizer.Tokenizer; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; -import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode; import java.util.ArrayList; import java.util.Arrays; @@ -163,8 +163,8 @@ public static List generateTokensGPULlama( ? 
config.contextLength() : maxTokens; final int batchSize = TornadoVMMasterPlan.PREFILL_BATCH_SIZE; - TornadoVMMasterPlanWithBatchPrefillDecode plan = - (TornadoVMMasterPlanWithBatchPrefillDecode) tornadoVMPlan; + TornadoVMMasterPlanBatchPrefillDecode plan = + (TornadoVMMasterPlanBatchPrefillDecode) tornadoVMPlan; List generatedTokens = new ArrayList<>(); diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java index ae9f6e9c..7239ab8a 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java @@ -7,7 +7,7 @@ import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tokenizer.Tokenizer; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; -import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanPrefillDecode; import java.util.ArrayList; import java.util.List; @@ -135,8 +135,8 @@ public static List generateTokensGPULlama( int actualMaxTokens = (maxTokens < 0 || config.contextLength() < maxTokens) ? config.contextLength() : maxTokens; - TornadoVMMasterPlanWithPrefillDecode prefillPlan = - (TornadoVMMasterPlanWithPrefillDecode) tornadoVMPlan; + TornadoVMMasterPlanPrefillDecode prefillPlan = + (TornadoVMMasterPlanPrefillDecode) tornadoVMPlan; List generatedTokens = new ArrayList<>(); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java index 3bf4bc21..962be21d 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java @@ -10,20 +10,20 @@ * *

    Three concrete implementations exist:

    *
      - *
    • {@link TornadoVMMasterPlanStandard} — baseline single-token forward pass + *
    • {@link TornadoVMMasterPlanSingleToken} — baseline single-token forward pass * (preprocessing + N layers + logits).
    • - *
    • {@link TornadoVMMasterPlanWithPrefillDecode} — sequential prefill/decode separation; + *
    • {@link TornadoVMMasterPlanPrefillDecode} — sequential prefill/decode separation; * reuses the same N layer graphs for both phases, skipping logits during prefill.
    • - *
    • {@link TornadoVMMasterPlanWithBatchPrefillDecode} — batched prefill + single-token + *
    • {@link TornadoVMMasterPlanBatchPrefillDecode} — batched prefill + single-token * decode; holds 2N+3 graphs in one plan to keep the KV cache on device across phases.
    • *
    * *

    The {@link #initializeTornadoVMPlan} factory selects the implementation based on * {@code llama.withPrefillDecode} and {@code llama.prefillBatchSize}:

    *
      - *
    • {@code withPrefillDecode=false} → {@link TornadoVMMasterPlanStandard}
    • - *
    • {@code withPrefillDecode=true}, {@code prefillBatchSize=1} → {@link TornadoVMMasterPlanWithPrefillDecode}
    • - *
    • {@code withPrefillDecode=true}, {@code prefillBatchSize>1} → {@link TornadoVMMasterPlanWithBatchPrefillDecode}
    • + *
    • {@code withPrefillDecode=false} → {@link TornadoVMMasterPlanSingleToken}
    • + *
    • {@code withPrefillDecode=true}, {@code prefillBatchSize=1} → {@link TornadoVMMasterPlanPrefillDecode}
    • + *
    • {@code withPrefillDecode=true}, {@code prefillBatchSize>1} → {@link TornadoVMMasterPlanBatchPrefillDecode}
    • *
    */ public interface TornadoVMMasterPlan { @@ -43,8 +43,8 @@ public interface TornadoVMMasterPlan { * Factory: creates, JIT-compiles, and warms up the appropriate TornadoVMMasterPlan. * *

    When {@code llama.withPrefillDecode=true} and {@code llama.prefillBatchSize > 1}, - * a {@link TornadoVMMasterPlanWithBatchPrefillDecode} is returned. - * Otherwise a {@link TornadoVMMasterPlanStandard} is returned (used for the baseline + * a {@link TornadoVMMasterPlanBatchPrefillDecode} is returned. + * Otherwise a {@link TornadoVMMasterPlanSingleToken} or a {@link TornadoVMMasterPlanPrefillDecode} is returned (used, respectively, for the baseline * path and the sequential prefill/decode path when batch size is 1).

    * * @param state the model state @@ -56,13 +56,13 @@ static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) { if (WITH_PREFILL_DECODE && PREFILL_BATCH_SIZE > 1) { // GPU path with batched prefill/decode - plan = new TornadoVMMasterPlanWithBatchPrefillDecode(state, model); + plan = new TornadoVMMasterPlanBatchPrefillDecode(state, model); } else if (WITH_PREFILL_DECODE) { // GPU path with simple prefill/decode - plan = new TornadoVMMasterPlanWithPrefillDecode(state, model); + plan = new TornadoVMMasterPlanPrefillDecode(state, model); } else { // GPU path with no prefill/decode - plan = new TornadoVMMasterPlanStandard(state, model); + plan = new TornadoVMMasterPlanSingleToken(state, model); } model.setTornadoVMPlan(plan); return plan; @@ -76,7 +76,7 @@ static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) { void forceCopyInReadOnlyData(); - FloatArray tornadoVMForwardExecuteLayered(int position); + FloatArray tornadoVMExecuteForward(int position); /** Releases all device memory held by this plan. 
*/ void freeTornadoExecutionPlan(); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefillDecode.java new file mode 100644 index 00000000..968e7b19 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefillDecode.java @@ -0,0 +1,165 @@ +package org.beehive.gpullama3.tornadovm; + +import org.beehive.gpullama3.auxiliary.RunMetrics; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tornadovm.plan.BatchPrefillDecodeForwardPlan; +import org.beehive.gpullama3.tornadovm.plan.ForwardPlanFactory; +import org.beehive.gpullama3.tornadovm.plan.layout.BatchPrefillDecodeForwardTaskGraphLayout; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TornadoExecutionPlan; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; + +/** + * GPU execution plan for batched prefill + single-token decode. + * + *

    A single {@link TornadoExecutionPlan} holds all task graphs for + * batched prefill and single-token decode phases:

    + * + *

    TaskGraph layout (2N+3 TaskGraphs total):

    + *
    + *   [0]         batchPrefillActivation  B×dim embeddings → FP32 wrapXBatch
    + *   [1..N]      batch-prefill layers    B tokens, all transformer ops
    + *   [N+1]       decodeActivation        single-token embedding → FP32 + KV-cache pass-through
    + *   [N+2..2N+1] decode layers           single-token, standard kernels
    + *   [2N+2]      logits
    + * 
    + */ +public class TornadoVMMasterPlanBatchPrefillDecode implements TornadoVMMasterPlan { + + private final State state; + private final Model model; + private final Configuration config; + + BatchPrefillDecodeForwardPlan batchPrefillDecodeForwardPlan; + BatchPrefillDecodeForwardTaskGraphLayout taskGraphLayout; + public TornadoExecutionPlan executionPlan; + + // ── Construction ───────────────────────────────────────────────────────── + TornadoVMMasterPlanBatchPrefillDecode(State initialState, Model model) { + if (ENABLE_TORNADOVM_INIT_TIME) { + System.err.println("\nStarting TornadoVM initialization..."); + } + + this.state = initialState; + this.model = model; + this.config = model.configuration(); + + long startTime = System.nanoTime(); + this.executionPlan = createExecutionPlan(); + long planCreationTime = System.nanoTime(); + + if (CUDA_GRAPHS) executionPlan.withAllGraphs().withCUDAGraph(); + executionPlan.withPreCompilation(); + long warmupTime = System.nanoTime(); + + forceCopyInReadOnlyData(); + long copyTime = System.nanoTime(); + + RunMetrics.setTornadoMetrics(planCreationTime - startTime, warmupTime - planCreationTime, copyTime - warmupTime); + } + + // ── Plan construction ───────────────────────────────────────────────────── + + @Override + public TornadoExecutionPlan createExecutionPlan() { + GGMLType weightType = model.weights().getWeightType(); + this.batchPrefillDecodeForwardPlan = + ForwardPlanFactory.createBatchPrefillDecode(weightType, state, model); + this.taskGraphLayout = batchPrefillDecodeForwardPlan.getTaskGraphLayout(); + var taskGraphs = batchPrefillDecodeForwardPlan.getImmutableTaskGraphs(); + return new TornadoExecutionPlan(taskGraphs.toArray(new ImmutableTaskGraph[0])); + } + + // ── Initialisation ──────────────────────────────────────────────────────── + + @Override + public void forceCopyInReadOnlyData() { + batchPrefillDecodeForwardPlan.getEmbeddingPreparer().initBatchState(); + state.wrapX.clear(); + 
state.positionHolder.init(0); + + for (int i = 0; i <= taskGraphLayout.logitsIdx(); i++) { + var g = executionPlan.withGraph(i) + .withGridScheduler(batchPrefillDecodeForwardPlan.getGridScheduler()); + if (CUDA_GRAPHS) g.withCUDAGraph(); + g.execute(); + } + } + + // ── Forward passes ──────────────────────────────────────────────────────── + + /** + * Batch prefill: runs graphs 0..N (activation + N layers), skips logits. + * + * @param tokenIds token IDs for this chunk + * @param startPos sequence position of tokenIds[0] + * @param model model (unused — kept for API compatibility) + * @param chunkSize actual number of tokens in this chunk (≤ batchSize) + */ + public void tornadoVMForwardBatchPrefill(int[] tokenIds, int startPos, Model model, int chunkSize) { + batchPrefillDecodeForwardPlan.getEmbeddingPreparer().copyBatchEmbeddings(tokenIds, startPos, chunkSize); + + var batchAct = executionPlan.withGraph(taskGraphLayout.batchActivationIdx()) + .withGridScheduler(batchPrefillDecodeForwardPlan.getGridScheduler()); + if (CUDA_GRAPHS) batchAct.withCUDAGraph(); + batchAct.execute(); + + for (int l = 0; l < config.numberOfLayers(); l++) { + var batchLayer = executionPlan.withGraph(taskGraphLayout.batchLayerIdx(l)) + .withGridScheduler(batchPrefillDecodeForwardPlan.getGridScheduler()); + if (CUDA_GRAPHS) batchLayer.withCUDAGraph(); + batchLayer.execute(); + } + } + + /** + * Single-token decode: runs graphs N+1..2N+2 (activation + N layers + logits). 
+ * + * @param token token ID to process + * @param position sequence position + * @param model model (unused — kept for API compatibility) + * @return logits array for sampling + */ + public FloatArray tornadoVMForwardDecode(int token, int position, Model model) { + batchPrefillDecodeForwardPlan.getEmbeddingPreparer().copyDecodeEmbedding(token); + state.positionHolder.set(0, position); + state.temp.clear(); + state.tempFFN.clear(); + + var decodeAct = executionPlan.withGraph(taskGraphLayout.decodeActivationIdx()) + .withGridScheduler(batchPrefillDecodeForwardPlan.getGridScheduler()); + if (CUDA_GRAPHS) decodeAct.withCUDAGraph(); + decodeAct.execute(); + + for (int l = 0; l < config.numberOfLayers(); l++) { + var decodeLayer = executionPlan.withGraph(taskGraphLayout.decodeLayerIdx(l)) + .withGridScheduler(batchPrefillDecodeForwardPlan.getGridScheduler()); + if (CUDA_GRAPHS) decodeLayer.withCUDAGraph(); + decodeLayer.execute(); + } + + state.tempLogits.clear(); + state.wrapLogits.clear(); + + var logits = executionPlan.withGraph(taskGraphLayout.logitsIdx()) + .withGridScheduler(batchPrefillDecodeForwardPlan.getGridScheduler()); + if (CUDA_GRAPHS) logits.withCUDAGraph(); + logits.execute(); + + return state.wrapLogits; + } + + @Override + public FloatArray tornadoVMExecuteForward(int position) { + throw new UnsupportedOperationException( + "Use tornadoVMForwardBatchPrefill / tornadoVMForwardDecode for batch plan"); + } + + @Override + public void freeTornadoExecutionPlan() { + executionPlan.freeDeviceMemory(); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanPrefillDecode.java new file mode 100644 index 00000000..a1f56549 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanPrefillDecode.java @@ -0,0 +1,166 @@ +package org.beehive.gpullama3.tornadovm; + +import org.beehive.gpullama3.auxiliary.RunMetrics; +import 
org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tornadovm.plan.ForwardPlanFactory; +import org.beehive.gpullama3.tornadovm.plan.PrefillDecodeForwardPlan; +import org.beehive.gpullama3.tornadovm.plan.layout.PrefillDecodeForwardTaskGraphLayout; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TornadoExecutionPlan; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; + +/** + * GPU execution plan for sequential (single-token) prefill/decode separation. + * + *

    A single {@link TornadoExecutionPlan} holds all graphs so that the KV cache + * ({@code wrapKeyCache}, {@code wrapValueCache}) is allocated once and remains on + * device across both phases. Prefill and decode reuse the same N layer graphs; + * only the logits graph is skipped during prefill.

    + * + *

    Graph layout (N+2 graphs total):

    + *
    + *   [0]      decodeActivation    single-token FP16 → FP32; KV-cache allocated on first execution
    + *   [1..N]   layer_0..layer_N-1  transformer layers (attention + FFN)
    + *   [N+1]    logits              final RMSNorm + wcls matmul
    + * 
    + * + *

    Two forward passes:

    + *
      + *
    • {@link #tornadoVMForwardPrefill} — graphs 0..N (activation + layers), logits skipped. + * Called once per prompt token; populates the KV cache.
    • + *
    • {@link #tornadoVMForwardDecode} — full pass including logits. + * Called once per generated token; returns logits for sampling.
    • + *
    + */ +public class TornadoVMMasterPlanPrefillDecode implements TornadoVMMasterPlan { + + private final State state; + private final Model model; + private final Configuration config; + + PrefillDecodeForwardPlan prefillDecodeForwardPlan; + PrefillDecodeForwardTaskGraphLayout taskGraphLayout; + public TornadoExecutionPlan executionPlan; + + // ── Construction ───────────────────────────────────────────────────────── + TornadoVMMasterPlanPrefillDecode(State state, Model model) { + if (ENABLE_TORNADOVM_INIT_TIME) { + System.err.println("\nStarting TornadoVM initialization..."); + } + + this.state = state; + this.model = model; + this.config = model.configuration(); + + long startTime = System.nanoTime(); + this.executionPlan = createExecutionPlan(); + long planCreationTime = System.nanoTime(); + + if (CUDA_GRAPHS) executionPlan.withAllGraphs().withCUDAGraph(); + executionPlan.withPreCompilation(); + long warmupTime = System.nanoTime(); + + forceCopyInReadOnlyData(); + long copyTime = System.nanoTime(); + + RunMetrics.setTornadoMetrics(planCreationTime - startTime, warmupTime - planCreationTime, copyTime - warmupTime); + } + + // ── Plan construction ───────────────────────────────────────────────────── + + @Override + public TornadoExecutionPlan createExecutionPlan() { + GGMLType weightType = model.weights().getWeightType(); + this.prefillDecodeForwardPlan = ForwardPlanFactory.createPrefillDecode(weightType, state, model); + this.taskGraphLayout = prefillDecodeForwardPlan.getTaskGraphLayout(); + var taskGraphs = prefillDecodeForwardPlan.getImmutableTaskGraphs(); + return new TornadoExecutionPlan(taskGraphs.toArray(new ImmutableTaskGraph[0])); + } + + // ── Initialisation ──────────────────────────────────────────────────────── + + /** Runs all graphs once to trigger FIRST_EXECUTION uploads and warm up CUDA graphs. 
*/ + @Override + public void forceCopyInReadOnlyData() { + state.wrapX.clear(); + state.positionHolder.init(0); + + for (int i = 0; i <= taskGraphLayout.logitsIdx(); i++) { + var g = executionPlan.withGraph(i) + .withGridScheduler(prefillDecodeForwardPlan.getGridScheduler()); + if (CUDA_GRAPHS) g.withCUDAGraph(); + g.execute(); + } + } + + // ── Forward passes ──────────────────────────────────────────────────────── + + /** + * GPU prefill forward: activation + all transformer layers, logits skipped. + * + * @param position sequence position being processed + */ + public void tornadoVMForwardPrefill(int position) { + var act = executionPlan.withGraph(taskGraphLayout.activationIdx()) + .withGridScheduler(prefillDecodeForwardPlan.getGridScheduler()); + if (CUDA_GRAPHS) act.withCUDAGraph(); + act.execute(); + + state.positionHolder.set(0, position); + state.temp.clear(); + state.tempFFN.clear(); + + for (int layer = 0; layer < config.numberOfLayers(); layer++) { + var l = executionPlan.withGraph(taskGraphLayout.layerIdx(layer)) + .withGridScheduler(prefillDecodeForwardPlan.getGridScheduler()); + if (CUDA_GRAPHS) l.withCUDAGraph(); + l.execute(); + } + } + + /** + * GPU decode forward: full execution including logits. 
+ * + * @param position sequence position being processed + * @return logits array for token sampling + */ + public FloatArray tornadoVMForwardDecode(int position) { + return tornadoVMExecuteForward(position); + } + + @Override + public FloatArray tornadoVMExecuteForward(int position) { + var act = executionPlan.withGraph(taskGraphLayout.activationIdx()) + .withGridScheduler(prefillDecodeForwardPlan.getGridScheduler()); + if (CUDA_GRAPHS) act.withCUDAGraph(); + act.execute(); + + state.positionHolder.set(0, position); + state.temp.clear(); + state.tempFFN.clear(); + + for (int layer = 0; layer < config.numberOfLayers(); layer++) { + var l = executionPlan.withGraph(taskGraphLayout.layerIdx(layer)) + .withGridScheduler(prefillDecodeForwardPlan.getGridScheduler()); + if (CUDA_GRAPHS) l.withCUDAGraph(); + l.execute(); + } + + state.tempLogits.clear(); + state.wrapLogits.clear(); + var logits = executionPlan.withGraph(taskGraphLayout.logitsIdx()) + .withGridScheduler(prefillDecodeForwardPlan.getGridScheduler()); + if (CUDA_GRAPHS) logits.withCUDAGraph(); + logits.execute(); + + return state.wrapLogits; + } + + @Override + public void freeTornadoExecutionPlan() { + executionPlan.freeDeviceMemory(); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanSingleToken.java similarity index 51% rename from src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java rename to src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanSingleToken.java index 14aaed10..ae3cdd03 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanSingleToken.java @@ -5,8 +5,9 @@ import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tensor.GGMLType; -import 
org.beehive.gpullama3.tornadovm.layerplanner.GenericLayerPlanner; -import org.beehive.gpullama3.tornadovm.layerplanner.QuantizationPlannerFactory; +import org.beehive.gpullama3.tornadovm.plan.ForwardPlanFactory; +import org.beehive.gpullama3.tornadovm.plan.SingleTokenForwardPlan; +import org.beehive.gpullama3.tornadovm.plan.layout.SingleTokenForwardTaskGraphLayout; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TornadoExecutionPlan; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; @@ -18,16 +19,17 @@ * logits projection. *

    */ -public class TornadoVMMasterPlanStandard implements TornadoVMMasterPlan { +public class TornadoVMMasterPlanSingleToken implements TornadoVMMasterPlan { private final State state; private final Model model; private final Configuration config; - GenericLayerPlanner tornadoVMLayerPlanner; + SingleTokenForwardPlan tornadoVMForwardPlan; + SingleTokenForwardTaskGraphLayout taskGraphLayout; public TornadoExecutionPlan executionPlan; - public TornadoVMMasterPlanStandard(State state, Model model) { + public TornadoVMMasterPlanSingleToken(State state, Model model) { if (ENABLE_TORNADOVM_INIT_TIME) { System.err.println("\nStarting TornadoVM initialization..."); } @@ -50,23 +52,20 @@ public TornadoVMMasterPlanStandard(State state, Model model) { RunMetrics.setTornadoMetrics(planCreationTime - startTime, warmupTime - planCreationTime, copyTime - warmupTime); } - /** - * Creates the {@link TornadoExecutionPlan} for *simple/standard* single-token forward pass. - */ @Override public TornadoExecutionPlan createExecutionPlan() { GGMLType weightType = model.weights().getWeightType(); - this.tornadoVMLayerPlanner = QuantizationPlannerFactory.create(weightType, state, model); - var taskGraphs = tornadoVMLayerPlanner.getImmutableTaskGraphs(); - var taskGraphArray = taskGraphs.toArray(new ImmutableTaskGraph[taskGraphs.size()]); - return new TornadoExecutionPlan(taskGraphArray); + this.tornadoVMForwardPlan = ForwardPlanFactory.createSingleToken(weightType, state, model); + this.taskGraphLayout = tornadoVMForwardPlan.getTaskGraphLayout(); + var taskGraphs = tornadoVMForwardPlan.getImmutableTaskGraphs(); + return new TornadoExecutionPlan(taskGraphs.toArray(new ImmutableTaskGraph[0])); } @Override - public FloatArray tornadoVMForwardExecuteLayered(int position) { + public FloatArray tornadoVMExecuteForward(int position) { // @formatter:off - var preGraph = executionPlan.withGraph(getPreprocessingGraphIndex()) - .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()); + var 
preGraph = executionPlan.withGraph(taskGraphLayout.activationIdx()) + .withGridScheduler(tornadoVMForwardPlan.getGridScheduler()); if (CUDA_GRAPHS) preGraph.withCUDAGraph(); preGraph.execute(); @@ -75,48 +74,32 @@ public FloatArray tornadoVMForwardExecuteLayered(int position) { state.tempFFN.clear(); for (int layer = 0; layer < config.numberOfLayers(); layer++) { - executionPlan.withGraph(getLayerGraphIndex(layer)) - .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) - //.withCUDAGraph() + executionPlan.withGraph(taskGraphLayout.layerIdx(layer)) + .withGridScheduler(tornadoVMForwardPlan.getGridScheduler()) .execute(); } state.tempLogits.clear(); state.wrapLogits.clear(); - var logitsGraph = executionPlan.withGraph(getFinalLogitsGraphIndex()) - .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()); + var logitsGraph = executionPlan.withGraph(taskGraphLayout.logitsIdx()) + .withGridScheduler(tornadoVMForwardPlan.getGridScheduler()); if (CUDA_GRAPHS) logitsGraph.withCUDAGraph(); logitsGraph.execute(); // @formatter:on return state.wrapLogits; } - private int getPreprocessingGraphIndex() { - return 0; - } - - private int getLayerGraphIndex(int layerIndex) { - return 1 + layerIndex; - } - - private int getFinalLogitsGraphIndex() { - return tornadoVMLayerPlanner.getImmutableTaskGraphs().size() - 1; - } - @Override public void forceCopyInReadOnlyData() { state.wrapX.clear(); state.positionHolder.init(0); - //executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); - executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); + executionPlan.withGraph(taskGraphLayout.activationIdx()).withGridScheduler(tornadoVMForwardPlan.getGridScheduler()).execute(); for (int layer = 0; layer < config.numberOfLayers(); layer++) { - //executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); - 
executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); + executionPlan.withGraph(taskGraphLayout.layerIdx(layer)).withGridScheduler(tornadoVMForwardPlan.getGridScheduler()).execute(); } - //executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute(); - executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); + executionPlan.withGraph(taskGraphLayout.logitsIdx()).withGridScheduler(tornadoVMForwardPlan.getGridScheduler()).execute(); } @Override diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java deleted file mode 100644 index 2fc310e2..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java +++ /dev/null @@ -1,360 +0,0 @@ -package org.beehive.gpullama3.tornadovm; - -import org.beehive.gpullama3.auxiliary.RunMetrics; -import org.beehive.gpullama3.inference.state.LlamaState; -import org.beehive.gpullama3.inference.state.State; -import org.beehive.gpullama3.tensor.GGMLType; -import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.model.llama.LlamaConfiguration; -import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels; -import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; -import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; -import 
org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersDecode; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16LayersBatchPrefill; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LogitsFP16LayerDecode; -import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LlamaQ8_0FFNLayersDecode; -import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LogitsQ8_0LayerDecode; -import org.beehive.gpullama3.tornadovm.layers.type.q8_0.prefill.LlamaQ8_0LayersBatchPrefill; -import uk.ac.manchester.tornado.api.types.arrays.ByteArray; -import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; -import uk.ac.manchester.tornado.api.KernelContext; -import uk.ac.manchester.tornado.api.TaskGraph; -import uk.ac.manchester.tornado.api.TornadoExecutionPlan; -import uk.ac.manchester.tornado.api.enums.DataTransferMode; -import uk.ac.manchester.tornado.api.types.arrays.FloatArray; -import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; - -import java.lang.foreign.MemorySegment; -import java.util.ArrayList; -import java.util.List; - -/** - * GPU execution plan for batched prefill + single-token decode. - * - *

    A single {@link TornadoExecutionPlan} holds all {@link TaskGraph} for - * batched prefill and single-token decode phases with the following structure:

    . - * - *

    TaskGraph layout (2N+3 TaskGraphs total):

    - *
    - *   [0]         prefill batch activation       B×dim FP16 → FP32
    - *   [1..N]      prefill batch layer graphs     B tokens, all transformer ops
    - *   [N+1]       decode activation              single-token FP16 → FP32 + KV-cache pass-through
    - *   [N+2..2N+1] decode layer graphs            single-token, standard kernels
    - *   [2N+2]      logits graph
    - * 
    - * - *

    - * Incorporating cross-phase {@link TaskGraph}s withing a single {@link TornadoExecutionPlan} - * is necessary to enable KV cache ({@code wrapKeyCache}, {@code wrapValueCache}) sharing - * across prefill and decode phases. The KV cache pointers are chained across {@link TaskGraph}s - * via the {@code persistOnDevice}/{@code consumeFromDevice} API within the {@link TornadoExecutionPlan}. - *

    - * - *

    KV cache pointer chain across phases:

    - *
    - *   batchLayer[N-1]  --persistOnDevice(wrapKeyCache)-→
    - *   decodeActivation --consumeFromDevice(wrapKeyCache)-→  (pass-through)
    - *   decodeLayer[0]   --consumeFromDevice(wrapKeyCache)-→  (used by attention)
    - * 
    - */ -public class TornadoVMMasterPlanWithBatchPrefillDecode implements TornadoVMMasterPlan { - - private final LlamaState state; - private final Model model; - private final LlamaConfiguration config; - private final int batchSize; - private final int N; // numberOfLayers - private final TornadoExecutionPlan executionPlan; - private final GridScheduler gridScheduler; - - // ── Graph-index helpers ─────────────────────────────────────────────────── - private int batchActivationIdx() { return 0; } - private int batchLayerIdx(int i) { return 1 + i; } - private int decodeActivationIdx() { return N + 1; } - private int decodeLayerIdx(int i) { return N + 2 + i; } - private int logitsIdx() { return 2 * N + 2; } - - // ── Construction ───────────────────────────────────────────────────────── - TornadoVMMasterPlanWithBatchPrefillDecode(State initialState, Model model) { - if (ENABLE_TORNADOVM_INIT_TIME) { - System.err.println("\nStarting TornadoVM initialization..."); - } - - this.state = (LlamaState) initialState; // only LlamaFP16 supports batched prefill for now - this.model = model; - this.config = (LlamaConfiguration) model.configuration(); - this.batchSize = PREFILL_BATCH_SIZE; - this.N = config.numberOfLayers(); - this.gridScheduler = new GridScheduler(); - - long startTime = System.nanoTime(); - this.executionPlan = createExecutionPlan(); - long planCreationTime = System.nanoTime(); - - if (CUDA_GRAPHS) executionPlan.withAllGraphs().withCUDAGraph(); - executionPlan.withPreCompilation(); - long warmupTime = System.nanoTime(); - - forceCopyInReadOnlyData(); - long copyTime = System.nanoTime(); - - RunMetrics.setTornadoMetrics(planCreationTime - startTime, warmupTime - planCreationTime, copyTime - warmupTime); - } - - // ── Batch Prefill Activation graphs ───────────────────────────────────────────────────── - - /** Graph 0 (FP16): B×dim FP16 embeddings → FP32 wrapXBatch. 
*/ - private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) { - return new TaskGraph("prefillActivation") - .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapXBatch) - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingXBatch) - .task("updateX", TransformerComputeKernels::convertFP16toFP32, - ctx, state.embeddingXBatch, state.wrapXBatch) - .persistOnDevice(state.wrapXBatch); - } - - /** Graph 0 (Q8_0): wrapXBatch pre-filled with FP32 by host; upload and persist. */ - private TaskGraph buildQ8_0BatchPrefillActivationGraph(KernelContext ctx) { - return new TaskGraph("prefillActivation") - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapXBatch) - .task("batchPassthrough", TransformerBatchPrefillKernels::batchPassthrough, - ctx, state.wrapXBatch) - .persistOnDevice(state.wrapXBatch); - } - - /** - * Graph N+1: single-token embedding → FP32 wrapX, with KV-cache pass-through. - * - *

    Receives the KV-cache device pointer from batch layer N via - * {@code consumeFromDevice}, then re-emits it via {@code persistOnDevice} so - * that {@code updatePersistedObjectState()} can propagate it to decode layer 0. - * Both halves of the chain are required; without the re-persist the pointer is - * not forwarded in interpreter (non-CUDA-graph) mode.

    - */ - private TaskGraph buildDecodeActivationGraph(KernelContext ctx, String lastBatchLayerID, - GGMLType weightType) { - TaskGraph tg = new TaskGraph("decodeActivation") - .consumeFromDevice(lastBatchLayerID, state.wrapKeyCache, state.wrapValueCache) - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX); - if (weightType == GGMLType.Q8_0) { - tg.task("updateX", TransformerComputeKernels::convertQ8_0toFP32, - ctx, (ByteArray) state.embeddingX, state.wrapX); - } else { - tg.task("updateX", TransformerComputeKernels::convertFP16toFP32, - ctx, (HalfFloatArray) state.embeddingX, state.wrapX); - } - return tg.persistOnDevice(state.wrapX, state.wrapKeyCache, state.wrapValueCache); - } - - @Override - public TornadoExecutionPlan createExecutionPlan() { - GGMLType weightType = model.weights().getWeightType(); - if (weightType != GGMLType.F16 && weightType != GGMLType.Q8_0) { - throw new UnsupportedOperationException( - "Batched prefill/decode GPU path not supported for weight type: " + weightType); - } - - LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); - SchedulerType schedulerType = SchedulerDetectionService.determineSchedulerType(model); - List all = new ArrayList<>(2 * N + 3); - - // [0] Batch prefill activation ──────────────────────────────────────── - KernelContext batchActCtx = new KernelContext(); - if (weightType == GGMLType.F16) { - all.add(buildBatchPrefillActivationGraph(batchActCtx).snapshot()); - gridScheduler.addWorkerGrid("prefillActivation.updateX", - WorkerGridFactory.genericWorker(batchSize * config.dim(), 128)); - } else { - all.add(buildQ8_0BatchPrefillActivationGraph(batchActCtx).snapshot()); - gridScheduler.addWorkerGrid("prefillActivation.batchPassthrough", - WorkerGridFactory.genericWorker(1, 1)); - } - - // [1..N] Batch prefill layer graphs ─────────────────────────────────── - String lastBatchLayerID; - if (weightType == GGMLType.F16) { - LlamaFP16LayersBatchPrefill batchLayers = - new 
LlamaFP16LayersBatchPrefill(state, weights, config, batchSize); - all.addAll(batchLayers.getLayerImmutableTaskGraphs()); - batchLayers.updateGridScheduler(gridScheduler); - lastBatchLayerID = batchLayers.getLastLayerTaskGraphID(); - } else { - LlamaQ8_0LayersBatchPrefill batchLayers = - new LlamaQ8_0LayersBatchPrefill(state, weights, config, batchSize); - all.addAll(batchLayers.getLayerImmutableTaskGraphs()); - batchLayers.updateGridScheduler(gridScheduler); - lastBatchLayerID = batchLayers.getLastLayerTaskGraphID(); - } - - // [N+1] Decode activation (with KV-cache pass-through) ──────────────── - KernelContext decodeActCtx = new KernelContext(); - all.add(buildDecodeActivationGraph(decodeActCtx, lastBatchLayerID, weightType).snapshot()); - gridScheduler.addWorkerGrid("decodeActivation.updateX", - WorkerGridFactory.genericWorker(config.dim(), 128)); - - // [N+2..2N+1] Decode layer graphs ───────────────────────────────────── - AbstractFFNLayers decodeLayers; - if (weightType == GGMLType.F16) { - decodeLayers = new LlamaFP16FFNLayersDecode("decode", state, weights, config, schedulerType); - } else { - decodeLayers = new LlamaQ8_0FFNLayersDecode("decode", state, weights, config, schedulerType); - } - all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs()); - decodeLayers.updateGridScheduler(gridScheduler); - - // [2N+2] Logits ─────────────────────────────────────────────────────── - AbstractLogitsLayer logitsLayer; - if (weightType == GGMLType.F16) { - logitsLayer = new LogitsFP16LayerDecode("logits", state, weights, config, - decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType); - } else { - logitsLayer = new LogitsQ8_0LayerDecode("logits", state, weights, config, - decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType); - } - all.add(logitsLayer.getImmutableTaskGraph()); - logitsLayer.updateGridScheduler(gridScheduler); - - return new TornadoExecutionPlan(all.toArray(new ImmutableTaskGraph[0])); - } - - - /** Runs all graphs once to trigger 
FIRST_EXECUTION uploads and warm up CUDA graphs. */ - @Override - public void forceCopyInReadOnlyData() { - state.wrapXBatch.clear(); - state.wrapX.clear(); - state.positionHolder.init(0); - state.batchStartPosHolder.init(0); - - for (int i = 0; i <= logitsIdx(); i++) { - var g = executionPlan.withGraph(i).withGridScheduler(gridScheduler); - if (CUDA_GRAPHS) g.withCUDAGraph(); - g.execute(); - } - } - - // ── Forward passes ──────────────────────────────────────────────────────── - - /** - * Batch prefill: runs graphs 0..N (activation + N layers), skips logits. - * - * @param tokenIds token IDs for this chunk (length == batchSize, or tail) - * @param startPos sequence position of tokenIds[0] - * @param model model (for embedding table) - * @param chunkSize actual number of tokens in this chunk (≤ batchSize) - */ - public void tornadoVMForwardBatchPrefill(int[] tokenIds, int startPos, Model model, int chunkSize) { - LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); - int dim = config.dim(); - - if (weights.getWeightType() == GGMLType.Q8_0) { - // CPU dequantizes Q8_0 embeddings into wrapXBatch (FP32) - ByteArray embTable = weights.getTokenEmbeddingTable().asByteArray(); - int blockSize = 32; - int Q8_0_BLOCK_BYTES = 34; - int blocksPerRow = (dim + blockSize - 1) / blockSize; - for (int b = 0; b < chunkSize; b++) { - int tokenId = tokenIds[b]; - for (int j = 0; j < dim; j++) { - int blockByteOffset = (tokenId * blocksPerRow + j / blockSize) * Q8_0_BLOCK_BYTES; - float scale = embTable.getHalfFloat(blockByteOffset).getFloat32(); - float quant = embTable.get(blockByteOffset + 2 + j % blockSize); - state.wrapXBatch.set(b * dim + j, quant * scale); - } - } - } else { - // FP16: copy raw half-float bytes into embeddingXBatch - MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(); - int bytes = Short.BYTES; - for (int b = 0; b < chunkSize; b++) { - MemorySegment.copy(embTable, (long) tokenIds[b] * dim * bytes, - 
state.embeddingXBatch.getSegment(), (long) b * dim * bytes, - (long) dim * bytes); - } - } - state.batchStartPosHolder.set(0, startPos); - - // Graph 0: batch activation - var batchAct = executionPlan.withGraph(batchActivationIdx()).withGridScheduler(gridScheduler); - if (CUDA_GRAPHS) batchAct.withCUDAGraph(); - batchAct.execute(); - - // Graphs 1..N: batch transformer layers - for (int l = 0; l < N; l++) { - var batchLayer = executionPlan.withGraph(batchLayerIdx(l)).withGridScheduler(gridScheduler); - if (CUDA_GRAPHS) batchLayer.withCUDAGraph(); - batchLayer.execute(); - } - // Logits skipped — not needed for prefill positions. - } - - /** - * Single-token decode: runs graphs N+1..2N+2 (activation + N layers + logits). - * - * @param token token ID to process - * @param position sequence position - * @param model model (for embedding table) - * @return logits array for sampling - */ - public FloatArray tornadoVMForwardDecode(int token, int position, Model model) { - LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); - int dim = config.dim(); - - if (weights.getWeightType() == GGMLType.Q8_0) { - ByteArray embTable = weights.getTokenEmbeddingTable().asByteArray(); - int blocksPerRow = (dim + 31) / 32; - long bytesPerToken = (long) blocksPerRow * 34; - MemorySegment.copy(embTable.getSegment(), (long) token * bytesPerToken, - state.embeddingX.getSegment(), 0L, bytesPerToken); - } else { - MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(); - MemorySegment.copy(embTable, (long) token * dim * Short.BYTES, - state.embeddingX.getSegment(), 0L, (long) dim * Short.BYTES); - } - - state.positionHolder.set(0, position); - state.temp.clear(); - state.tempFFN.clear(); - - // Graph N+1: decode activation - var decodeAct = executionPlan.withGraph(decodeActivationIdx()).withGridScheduler(gridScheduler); - if (CUDA_GRAPHS) decodeAct.withCUDAGraph(); - decodeAct.execute(); - - // Graphs N+2..2N+1: decode transformer layers - 
for (int l = 0; l < N; l++) { - var decodeLayer = executionPlan.withGraph(decodeLayerIdx(l)).withGridScheduler(gridScheduler); - if (CUDA_GRAPHS) decodeLayer.withCUDAGraph(); - //System.err.println("[DEBUG] about to execute decode transformer layer (graph " + decodeLayerIdx(l) + "--)"); - decodeLayer.execute(); - } - - state.tempLogits.clear(); - state.wrapLogits.clear(); - - // Graph 2N+2: logits - var logits = executionPlan.withGraph(logitsIdx()).withGridScheduler(gridScheduler); - if (CUDA_GRAPHS) logits.withCUDAGraph(); - logits.execute(); - - return state.wrapLogits; - } - - @Override - public FloatArray tornadoVMForwardExecuteLayered(int position) { - throw new UnsupportedOperationException( - "Use tornadoVMForwardBatchPrefill / tornadoVMForwardDecode for batch plan"); - } - - @Override - public void freeTornadoExecutionPlan() { - executionPlan.freeDeviceMemory(); - } - -} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java deleted file mode 100644 index 500882ea..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java +++ /dev/null @@ -1,244 +0,0 @@ -package org.beehive.gpullama3.tornadovm; - -import org.beehive.gpullama3.auxiliary.RunMetrics; -import org.beehive.gpullama3.inference.state.LlamaState; -import org.beehive.gpullama3.inference.state.State; -import org.beehive.gpullama3.tensor.GGMLType; -import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.model.llama.LlamaConfiguration; -import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService; -import 
org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; -import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersPrefillDecode; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LogitsFP16LayerDecode; -import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LlamaQ8_0FFNLayersPrefillDecode; -import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LogitsQ8_0LayerDecode; -import uk.ac.manchester.tornado.api.types.arrays.ByteArray; -import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; -import uk.ac.manchester.tornado.api.KernelContext; -import uk.ac.manchester.tornado.api.TaskGraph; -import uk.ac.manchester.tornado.api.TornadoExecutionPlan; -import uk.ac.manchester.tornado.api.enums.DataTransferMode; -import uk.ac.manchester.tornado.api.types.arrays.FloatArray; -import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; - -import java.util.ArrayList; -import java.util.List; - -/** - * GPU execution plan for sequential (single-token) prefill/decode separation. - * - *

    A single {@link TornadoExecutionPlan} holds all graphs so that the KV cache - * ({@code wrapKeyCache}, {@code wrapValueCache}) is allocated once and remains on - * device across both phases. Prefill and decode reuse the same N layer graphs; - * only the logits graph is skipped during prefill.

    - * - *

    Graph layout (N+2 graphs total):

    - *
    - *   [0]      decodeActivation    single-token FP16 → FP32; KV-cache allocated on first execution
    - *   [1..N]   layer_0..layer_N-1  transformer layers (attention + FFN)
    - *   [N+1]    logits              final RMSNorm + wcls matmul
    - * 
    - * - *

    Two forward passes:

    - *
      - *
    • {@link #tornadoVMForwardPrefill} — graphs 0..N (activation + layers), logits skipped. - * Called once per prompt token; populates the KV cache.
    • - *
    • {@link #tornadoVMForwardDecode} — full pass including logits. - * Called once per generated token; returns logits for sampling.
    • - *
    - */ -public class TornadoVMMasterPlanWithPrefillDecode implements TornadoVMMasterPlan { - - private final LlamaState state; - private final Model model; - private final LlamaConfiguration config; - private final int N; // numberOfLayers - private final TornadoExecutionPlan executionPlan; - private final GridScheduler gridScheduler; - - // ── Graph-index helpers ─────────────────────────────────────────────────── - private int activationIdx() { return 0; } - private int layerIdx(int i) { return 1 + i; } - private int logitsIdx() { return N + 1; } - - // ── Construction ───────────────────────────────────────────────────────── - TornadoVMMasterPlanWithPrefillDecode(State initialState, Model model) { - if (ENABLE_TORNADOVM_INIT_TIME) { - System.err.println("\nStarting TornadoVM initialization..."); - } - - this.state = (LlamaState) initialState; - this.model = model; - this.config = (LlamaConfiguration) model.configuration(); - this.N = config.numberOfLayers(); - this.gridScheduler = new GridScheduler(); - - long startTime = System.nanoTime(); - this.executionPlan = createExecutionPlan(); - long planCreationTime = System.nanoTime(); - - if (CUDA_GRAPHS) executionPlan.withAllGraphs().withCUDAGraph(); - executionPlan.withPreCompilation(); - long warmupTime = System.nanoTime(); - - forceCopyInReadOnlyData(); - long copyTime = System.nanoTime(); - - RunMetrics.setTornadoMetrics(planCreationTime - startTime, warmupTime - planCreationTime, copyTime - warmupTime); - } - - // ── Activation graph ───────────────────────────────────────────────────── - - /** - * Graph 0: single-token FP16 → FP32. - * - *

    Outputs {@code wrapX} (FP32 hidden state) and persists it on device so that - * decode layer 0 can pick it up via {@code consumeFromDevice("decodeActivation", wrapX)}. - * The KV cache is not managed here — it is allocated on the first forward pass - * by decode layer 0 via {@code FIRST_EXECUTION}.

    - */ - private TaskGraph buildActivationGraph(KernelContext ctx, GGMLType weightType) { - if (weightType == GGMLType.Q8_0) { - return new TaskGraph("decodeActivation") - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) - .task("updateX", TransformerComputeKernels::convertQ8_0toFP32, - ctx, (ByteArray) state.embeddingX, state.wrapX) - .persistOnDevice(state.wrapX); - } - return new TaskGraph("decodeActivation") - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) - .task("updateX", TransformerComputeKernels::convertFP16toFP32, - ctx, (HalfFloatArray) state.embeddingX, state.wrapX) - .persistOnDevice(state.wrapX); - } - - // ── Plan construction ───────────────────────────────────────────────────── - @Override - public TornadoExecutionPlan createExecutionPlan() { - GGMLType weightType = model.weights().getWeightType(); - if (weightType != GGMLType.F16 && weightType != GGMLType.Q8_0) { - throw new UnsupportedOperationException( - "Prefill/decode GPU path not supported for weight type: " + weightType); - } - - LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights(); - SchedulerType schedulerType = SchedulerDetectionService.determineSchedulerType(model); - - List all = new ArrayList<>(N + 2); - - // [0] Activation ────────────────────────────────────────────────────── - KernelContext actCtx = new KernelContext(); - all.add(buildActivationGraph(actCtx, weightType).snapshot()); - gridScheduler.addWorkerGrid("decodeActivation.updateX", - WorkerGridFactory.genericWorker(config.dim(), 128)); - - // [1..N] Decode layer graphs ────────────────────────────────────────── - AbstractFFNLayers decodeLayers; - if (weightType == GGMLType.F16) { - decodeLayers = new LlamaFP16FFNLayersPrefillDecode("decode", state, weights, config, schedulerType); - } else { - decodeLayers = new LlamaQ8_0FFNLayersPrefillDecode("decode", state, weights, config, schedulerType); - } - all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs()); - 
decodeLayers.updateGridScheduler(gridScheduler); - - // [N+1] Logits ──────────────────────────────────────────────────────── - AbstractLogitsLayer logitsLayer; - if (weightType == GGMLType.F16) { - logitsLayer = new LogitsFP16LayerDecode("logits", state, weights, config, - decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType); - } else { - logitsLayer = new LogitsQ8_0LayerDecode("logits", state, weights, config, - decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType); - } - all.add(logitsLayer.getImmutableTaskGraph()); - logitsLayer.updateGridScheduler(gridScheduler); - - return new TornadoExecutionPlan(all.toArray(new ImmutableTaskGraph[0])); - } - - // ── Initialisation ──────────────────────────────────────────────────────── - - /** Runs all graphs once to trigger FIRST_EXECUTION uploads and warm up CUDA graphs. */ - @Override - public void forceCopyInReadOnlyData() { - state.wrapX.clear(); - state.positionHolder.init(0); - - for (int i = 0; i <= logitsIdx(); i++) { - var g = executionPlan.withGraph(i).withGridScheduler(gridScheduler); - if (CUDA_GRAPHS) g.withCUDAGraph(); - g.execute(); - } - } - - // ── Forward passes ──────────────────────────────────────────────────────── - - /** - * GPU prefill forward: activation + all transformer layers, logits skipped. - * KV cache is populated for each prompt token. 
- * - * @param position sequence position being processed - */ - public void tornadoVMForwardPrefill(int position) { - var prefillActivation = executionPlan.withGraph(activationIdx()).withGridScheduler(gridScheduler); - if (CUDA_GRAPHS) prefillActivation.withCUDAGraph(); - prefillActivation.execute(); - - state.positionHolder.set(0, position); - state.temp.clear(); - state.tempFFN.clear(); - - for (int layer = 0; layer < N; layer++) { - var prefillLayer = executionPlan.withGraph(layerIdx(layer)).withGridScheduler(gridScheduler); - if (CUDA_GRAPHS) prefillLayer.withCUDAGraph(); - prefillLayer.execute(); - } - } - - /** - * GPU decode forward: full execution including logits. - * - * @param position sequence position being processed - * @return logits array for token sampling - */ - public FloatArray tornadoVMForwardDecode(int position) { - return tornadoVMForwardExecuteLayered(position); - } - - @Override - public FloatArray tornadoVMForwardExecuteLayered(int position) { - var act = executionPlan.withGraph(activationIdx()).withGridScheduler(gridScheduler); - if (CUDA_GRAPHS) act.withCUDAGraph(); - act.execute(); - - state.positionHolder.set(0, position); - state.temp.clear(); - state.tempFFN.clear(); - - for (int layer = 0; layer < N; layer++) { - var l = executionPlan.withGraph(layerIdx(layer)).withGridScheduler(gridScheduler); - if (CUDA_GRAPHS) l.withCUDAGraph(); - l.execute(); - } - - state.tempLogits.clear(); - state.wrapLogits.clear(); - var logits = executionPlan.withGraph(logitsIdx()).withGridScheduler(gridScheduler); - if (CUDA_GRAPHS) logits.withCUDAGraph(); - logits.execute(); - - return state.wrapLogits; - } - - @Override - public void freeTornadoExecutionPlan() { - executionPlan.freeDeviceMemory(); - } -} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java index 6c688649..ecc81296 100644 --- 
a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java @@ -15,7 +15,7 @@ * Batch tensors are flat: element [b][i] lives at index {@code b*stride + i}. * Worker-grid sizes are scaled by {@code batchSize} vs the single-token kernels.

    * - *

    These kernels are meant to be registered in {@link TornadoVMMasterPlanWithPrefillDecode} + *

    These kernels are meant to be registered in {@link TornadoVMMasterPlanBatchPrefillDecode} * batch task graphs; they are NOT invoked directly.

    */ public final class TransformerBatchPrefillKernels { diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/GenericLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/GenericLayerPlanner.java deleted file mode 100644 index 9e211051..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/GenericLayerPlanner.java +++ /dev/null @@ -1,14 +0,0 @@ -package org.beehive.gpullama3.tornadovm.layerplanner; - -import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; - -import java.util.List; - -public interface GenericLayerPlanner { - - List getImmutableTaskGraphs(); - - GridScheduler getGridScheduler(); - -} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java deleted file mode 100644 index 42d2dc0c..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java +++ /dev/null @@ -1,97 +0,0 @@ -package org.beehive.gpullama3.tornadovm.layerplanner; - -import org.beehive.gpullama3.inference.state.DevstralState; -import org.beehive.gpullama3.inference.state.GraniteState; -import org.beehive.gpullama3.tensor.GGMLType; -import org.beehive.gpullama3.inference.state.LlamaState; -import org.beehive.gpullama3.inference.state.Phi3State; -import org.beehive.gpullama3.inference.state.Qwen2State; -import org.beehive.gpullama3.inference.state.Qwen3State; -import org.beehive.gpullama3.inference.state.State; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.DevstralFP16LayerPlanner; -import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.GraniteFP16LayerPlanner; -import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.LlamaFP16LayerPlanner; -import 
org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.MistralFP16LayerPlanner; -import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Phi3FP16LayerPlanner; -import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen2FP16LayerPlanner; -import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen3FP16LayerPlanner; -import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.DevstralQ8_0LayerPlanner; -import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.GraniteQ8_0LayerPlanner; -import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.LlamaQ8_0LayerPlanner; -import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.MistralQ8_0LayerPlanner; -import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Phi3Q8_0LayerPlanner; -import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen2Q8_0LayerPlanner; -import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen3Q8_0LayerPlanner; - -/** - * Factory class responsible for creating appropriate layer planners based on model type and quantization. - *

    - * The factory follows a routing logic: - *

      - *
    1. Determine quantization type from {@link GGMLType}
    2. - *
    3. Determine model type from {@link Model}
    4. - *
    5. Instantiate appropriate planner implementation
    6. - *
    - *

    - * Examples: - *

      - *
    • {@code QuantizationType.FP16 + ModelType.LLAMA_3 → LlamaFP16LayerPlanner}
    • - *
    • {@code QuantizationType.Q8_0 + ModelType.QWEN_2 → Qwen2Q8_0LayerPlanner}
    • - *
    - */ -public class QuantizationPlannerFactory { - - /** - * Main factory method: create planner for given model + quantization - */ - public static GenericLayerPlanner create(GGMLType quantization, State state, Model model) { - return switch (quantization) { - case F32 -> createFP32Planner(state, model); - case F16 -> createFP16Planner(state, model); - case Q8_0 -> createQ8_0Planner(state, model); - case Q4_0 -> createQ4_0Planner(state, model); - default -> throw new UnsupportedOperationException("Quantization not supported: " + quantization); - }; - } - - // ============ FP16 Planners ============ - private static GenericLayerPlanner createFP16Planner(State state, Model model) { - return switch (model.getModelType()) { - case LLAMA_3 -> new LlamaFP16LayerPlanner((LlamaState) state, model); - case MISTRAL -> new MistralFP16LayerPlanner((LlamaState) state, model); - case DEVSTRAL_2 -> new DevstralFP16LayerPlanner((DevstralState) state, model); - case QWEN_2 -> new Qwen2FP16LayerPlanner((Qwen2State) state, model); - case QWEN_3 -> new Qwen3FP16LayerPlanner((Qwen3State) state, model); - case PHI_3 -> new Phi3FP16LayerPlanner((Phi3State) state, model); - case GRANITE -> new GraniteFP16LayerPlanner((GraniteState) state, model); - case DEEPSEEK_R1_DISTILL_QWEN -> new Qwen2FP16LayerPlanner((Qwen2State) state, model); - default -> throw new UnsupportedOperationException("FP16 not supported for model: " + model.getModelType()); - }; - } - - // ============ Q8_0 Planners ============ - private static GenericLayerPlanner createQ8_0Planner(State state, Model model) { - return switch (model.getModelType()) { - case LLAMA_3 -> new LlamaQ8_0LayerPlanner((LlamaState) state, model); - case MISTRAL -> new MistralQ8_0LayerPlanner((LlamaState) state, model); - case DEVSTRAL_2 -> new DevstralQ8_0LayerPlanner((DevstralState) state, model); - case QWEN_2 -> new Qwen2Q8_0LayerPlanner((Qwen2State) state, model); - case QWEN_3 -> new Qwen3Q8_0LayerPlanner((Qwen3State) state, model); - 
case PHI_3 -> new Phi3Q8_0LayerPlanner((Phi3State) state, model); - case GRANITE -> new GraniteQ8_0LayerPlanner((GraniteState) state, model); - case DEEPSEEK_R1_DISTILL_QWEN -> new Qwen2Q8_0LayerPlanner((Qwen2State) state, model); - default -> throw new UnsupportedOperationException("Q8_0 not supported for model: " + model.getModelType()); - }; - } - - // ============ FP32 Planners (FUTURE) ============ - private static GenericLayerPlanner createFP32Planner(State state, Model model) { - throw new UnsupportedOperationException("FP32 planners not yet implemented"); - } - - private static GenericLayerPlanner createQ4_0Planner(State state, Model model) { - throw new UnsupportedOperationException("Q4 planners not yet implemented"); - } - -} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizedLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizedLayerPlanner.java deleted file mode 100644 index 1b7c1953..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizedLayerPlanner.java +++ /dev/null @@ -1,103 +0,0 @@ -package org.beehive.gpullama3.tornadovm.layerplanner; - -import org.beehive.gpullama3.inference.state.State; -import org.beehive.gpullama3.inference.weights.Weights; -import org.beehive.gpullama3.model.Configuration; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; -import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; -import org.beehive.gpullama3.tornadovm.layers.Activation; -import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; -import uk.ac.manchester.tornado.api.KernelContext; - -import java.util.ArrayList; -import java.util.List; - -/** - * Abstract base for all quantization-specific 
planners. - * - * Extracts common state from the model, detects the hardware scheduler type, - * and assembles the full execution plan via createTornadoInferencePlan(). - * Subclasses (FP16LayerPlanner, Q8_0LayerPlanner) only provide quantization validation. - */ -public abstract class QuantizedLayerPlanner - implements GenericLayerPlanner { - - protected final S state; - protected final C config; - protected final W weights; - protected final KernelContext context; - protected final Model model; - protected final SchedulerType schedulerType; - - protected Activation activationLayer; - protected AbstractFFNLayers ffnLayers; - protected AbstractLogitsLayer logitsLayer; - - private List immutableTaskGraphs; - private GridScheduler gridScheduler; - - @SuppressWarnings("unchecked") - protected QuantizedLayerPlanner(S state, Model model) { - this.state = state; - this.model = model; - this.config = (C) model.configuration(); - this.weights = (W) model.weights(); - this.context = new KernelContext(); - this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); - validateQuantizationType(); - } - - /** Validates that the model weights match the expected quantization type. */ - protected abstract void validateQuantizationType(); - - /** - * Creates the TornadoVM inference execution pipeline. - * It represents the entire Feed-Forward Network (FFN) and consists of: - *
      - *
    • Activation layer
    • - *
    • FFN layers (N transformer layers, model-specific)
    • - *
    • Logits layer
    • - *
    - *

    - * Each component is represented as an {@link ImmutableTaskGraph}, along with a - * corresponding {@link GridScheduler} configuration that defines how tasks are - * mapped on the GPU. - *

    - * This method assembles all components into a unified execution pipeline and - * caches the resulting task graphs and scheduler for reuse across inference runs. - */ - protected final void createTornadoInferencePlan() { - List allTaskGraphs = new ArrayList<>(); - GridScheduler masterScheduler = new GridScheduler(); - - // 1. Activation layer (common to all models) - allTaskGraphs.add(activationLayer.getImmutableTaskGraph()); - activationLayer.updateGridScheduler(masterScheduler); - - // 2. FFN layers (N transformer layers - model-specific) - allTaskGraphs.addAll(ffnLayers.getFFNLayerImmutableTaskGraphs()); - ffnLayers.updateGridScheduler(masterScheduler); - - // 3. Logits layer (common to all models) - allTaskGraphs.add(logitsLayer.getImmutableTaskGraph()); - logitsLayer.updateGridScheduler(masterScheduler); - - // Cache for future retrievals - this.immutableTaskGraphs = allTaskGraphs; - this.gridScheduler = masterScheduler; - } - - @Override - public final List getImmutableTaskGraphs() { - return this.immutableTaskGraphs; - } - - @Override - public final GridScheduler getGridScheduler() { - return this.gridScheduler; - } -} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/DevstralFP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/DevstralFP16LayerPlanner.java deleted file mode 100644 index f72cfeb7..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/DevstralFP16LayerPlanner.java +++ /dev/null @@ -1,20 +0,0 @@ -package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16; - -import org.beehive.gpullama3.inference.state.DevstralState; -import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.model.devstral.DevstralConfiguration; -import org.beehive.gpullama3.tornadovm.layers.Activation; -import 
org.beehive.gpullama3.tornadovm.layers.type.fp16.DevstralFP16FFNLayers; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; - -public class DevstralFP16LayerPlanner extends FP16LayerPlanner { - - public DevstralFP16LayerPlanner(DevstralState state, Model model) { - super(state, model); - this.activationLayer = new Activation("activationUpdate", state, weights, config); - this.ffnLayers = new DevstralFP16FFNLayers("devstralFFN", state, weights, config, schedulerType); - this.logitsLayer = new LogitsFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType); - createTornadoInferencePlan(); - } -} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/FP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/FP16LayerPlanner.java deleted file mode 100644 index 3b33d98c..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/FP16LayerPlanner.java +++ /dev/null @@ -1,25 +0,0 @@ -package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16; - -import org.beehive.gpullama3.tensor.GGMLType; -import org.beehive.gpullama3.inference.state.State; -import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; -import org.beehive.gpullama3.model.Configuration; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.tornadovm.layerplanner.QuantizedLayerPlanner; - -/** - * Base for all FP16-quantized layer planners. 
- */ -public abstract class FP16LayerPlanner extends QuantizedLayerPlanner { - - protected FP16LayerPlanner(S state, Model model) { - super(state, model); - } - - @Override - protected void validateQuantizationType() { - if (this.weights.getWeightType() != GGMLType.F16) { - throw new IllegalArgumentException("FP16LayerPlanner requires GGMLType.F16, got: " + this.weights.getWeightType()); - } - } -} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java deleted file mode 100644 index 488e4b39..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java +++ /dev/null @@ -1,20 +0,0 @@ -package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16; - -import org.beehive.gpullama3.inference.state.GraniteState; -import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.model.granite.GraniteConfiguration; -import org.beehive.gpullama3.tornadovm.layers.ActivationGranite; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.GraniteFP16FFNLayers; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsGraniteFP16Layer; - -public class GraniteFP16LayerPlanner extends FP16LayerPlanner { - - public GraniteFP16LayerPlanner(GraniteState state, Model model) { - super(state, model); - this.activationLayer = new ActivationGranite("activationUpdate", state, weights, config); - this.ffnLayers = new GraniteFP16FFNLayers("graniteFFN", state, weights, config, schedulerType); - this.logitsLayer = new LogitsGraniteFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType); - createTornadoInferencePlan(); - } -} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java 
b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java deleted file mode 100644 index d6042c39..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java +++ /dev/null @@ -1,20 +0,0 @@ -package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16; - -import org.beehive.gpullama3.inference.state.LlamaState; -import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.model.llama.LlamaConfiguration; -import org.beehive.gpullama3.tornadovm.layers.Activation; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; - -public class LlamaFP16LayerPlanner extends FP16LayerPlanner { - - public LlamaFP16LayerPlanner(LlamaState state, Model model) { - super(state, model); - this.activationLayer = new Activation("activationUpdate", state, weights, config); - this.ffnLayers = new LlamaFP16FFNLayers("llamaFFN", state, weights, config, schedulerType); - this.logitsLayer = new LogitsFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType); - createTornadoInferencePlan(); - } -} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/MistralFP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/MistralFP16LayerPlanner.java deleted file mode 100644 index 78e1bffb..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/MistralFP16LayerPlanner.java +++ /dev/null @@ -1,20 +0,0 @@ -package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16; - -import org.beehive.gpullama3.inference.state.LlamaState; -import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; -import org.beehive.gpullama3.model.Model; -import 
org.beehive.gpullama3.model.mistral.MistralConfiguration; -import org.beehive.gpullama3.tornadovm.layers.Activation; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.MistralFP16FFNLayers; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; - -public class MistralFP16LayerPlanner extends FP16LayerPlanner { - - public MistralFP16LayerPlanner(LlamaState state, Model model) { - super(state, model); - this.activationLayer = new Activation("activationUpdate", state, weights, config); - this.ffnLayers = new MistralFP16FFNLayers("mistralFFN", state, weights, config, schedulerType); - this.logitsLayer = new LogitsFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType); - createTornadoInferencePlan(); - } -} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Phi3FP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Phi3FP16LayerPlanner.java deleted file mode 100644 index 5b50529a..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Phi3FP16LayerPlanner.java +++ /dev/null @@ -1,27 +0,0 @@ -package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16; - -import org.beehive.gpullama3.inference.state.Phi3State; -import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.model.phi3.Phi3Configuration; -import org.beehive.gpullama3.tornadovm.layers.Activation; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.Phi3FP16FFNLayers; - -/** - * Phi3FP16LayerPlanner: Phi3 model with FP16 weights. 
- * - * Follows the same pattern as Qwen3FP16LayerPlanner but with: - Phi3-specific FFN layers (combined QKV + gate/up FFN) - Phi3TornadoWeights - Phi3Configuration - * - * Inherits from FP16LayerPlanner - */ -public class Phi3FP16LayerPlanner extends FP16LayerPlanner { - - public Phi3FP16LayerPlanner(Phi3State state, Model model) { - super(state, model); - this.activationLayer = new Activation("activationUpdate", state, weights, config); - this.ffnLayers = new Phi3FP16FFNLayers("phi3FFN", state, weights, config, schedulerType); - this.logitsLayer = new LogitsFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType); - createTornadoInferencePlan(); - } -} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen2FP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen2FP16LayerPlanner.java deleted file mode 100644 index a3dcc5e2..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen2FP16LayerPlanner.java +++ /dev/null @@ -1,27 +0,0 @@ -package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16; - -import org.beehive.gpullama3.inference.state.Qwen2State; -import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; -import org.beehive.gpullama3.tornadovm.layers.Activation; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.Qwen2FP16FFNLayers; - -/** - * Qwen2FP16LayerPlanner: Qwen2 model with FP16 weights. 
- * - * Follows the same pattern as LlamaFP16LayerPlanner but with: - Qwen2-specific FFN layers (supports GQA with bias terms) - Qwen2TornadoWeights - Qwen2Configuration - * - * Inherits from FP16LayerPlanner - */ -public class Qwen2FP16LayerPlanner extends FP16LayerPlanner { - - public Qwen2FP16LayerPlanner(Qwen2State state, Model model) { - super(state, model); - this.activationLayer = new Activation("activationUpdate", state, weights, config); - this.ffnLayers = new Qwen2FP16FFNLayers("qwen2FFN", state, weights, config, schedulerType); - this.logitsLayer = new LogitsFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType); - createTornadoInferencePlan(); - } -} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen3FP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen3FP16LayerPlanner.java deleted file mode 100644 index 76239ae6..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen3FP16LayerPlanner.java +++ /dev/null @@ -1,27 +0,0 @@ -package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16; - -import org.beehive.gpullama3.inference.state.Qwen3State; -import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; -import org.beehive.gpullama3.tornadovm.layers.Activation; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.Qwen3FP16FFNLayers; - -/** - * Qwen3FP16LayerPlanner: Qwen3 model with FP16 weights. 
- * - * Follows the same pattern as LlamaFP16LayerPlanner but with: - Qwen3-specific FFN layers (supports GQA) - Qwen3TornadoWeights - Qwen3Configuration - * - * Inherits from FP16LayerPlanner - */ -public class Qwen3FP16LayerPlanner extends FP16LayerPlanner { - - public Qwen3FP16LayerPlanner(Qwen3State state, Model model) { - super(state, model); - this.activationLayer = new Activation("activationUpdate", state, weights, config); - this.ffnLayers = new Qwen3FP16FFNLayers("qwen3FFN", state, weights, config, schedulerType); - this.logitsLayer = new LogitsFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType); - createTornadoInferencePlan(); - } -} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/DevstralQ8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/DevstralQ8_0LayerPlanner.java deleted file mode 100644 index b6bf9634..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/DevstralQ8_0LayerPlanner.java +++ /dev/null @@ -1,20 +0,0 @@ -package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0; - -import org.beehive.gpullama3.inference.state.DevstralState; -import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.model.devstral.DevstralConfiguration; -import org.beehive.gpullama3.tornadovm.layers.Activation; -import org.beehive.gpullama3.tornadovm.layers.type.q8_0.DevstralQ8_0FFNLayers; -import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; - -public class DevstralQ8_0LayerPlanner extends Q8_0LayerPlanner { - - public DevstralQ8_0LayerPlanner(DevstralState state, Model model) { - super(state, model); - this.activationLayer = new Activation("activationUpdate", state, weights, config); - this.ffnLayers = new DevstralQ8_0FFNLayers("devstralFFN", state, weights, config, schedulerType); - this.logitsLayer = 
new LogitsQ8_0Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType); - createTornadoInferencePlan(); - } -} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/GraniteQ8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/GraniteQ8_0LayerPlanner.java deleted file mode 100644 index f7735dc0..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/GraniteQ8_0LayerPlanner.java +++ /dev/null @@ -1,20 +0,0 @@ -package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0; - -import org.beehive.gpullama3.inference.state.GraniteState; -import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.model.granite.GraniteConfiguration; -import org.beehive.gpullama3.tornadovm.layers.ActivationGranite; -import org.beehive.gpullama3.tornadovm.layers.type.q8_0.GraniteQ8_0FFNLayers; -import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsGraniteQ8_0Layer; - -public class GraniteQ8_0LayerPlanner extends Q8_0LayerPlanner { - - public GraniteQ8_0LayerPlanner(GraniteState state, Model model) { - super(state, model); - this.activationLayer = new ActivationGranite("activationUpdate", state, weights, config); - this.ffnLayers = new GraniteQ8_0FFNLayers("graniteFFN", state, weights, config, schedulerType); - this.logitsLayer = new LogitsGraniteQ8_0Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType); - createTornadoInferencePlan(); - } -} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/LlamaQ8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/LlamaQ8_0LayerPlanner.java deleted file mode 100644 index 827ae538..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/LlamaQ8_0LayerPlanner.java +++ /dev/null @@ 
-1,20 +0,0 @@ -package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0; - -import org.beehive.gpullama3.inference.state.LlamaState; -import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.model.llama.LlamaConfiguration; -import org.beehive.gpullama3.tornadovm.layers.Activation; -import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LlamaQ8_0FFNLayers; -import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; - -public class LlamaQ8_0LayerPlanner extends Q8_0LayerPlanner { - - public LlamaQ8_0LayerPlanner(LlamaState state, Model model) { - super(state, model); - this.activationLayer = new Activation("activationUpdate", state, weights, config); - this.ffnLayers = new LlamaQ8_0FFNLayers("llamaFFN", state, weights, config, schedulerType); - this.logitsLayer = new LogitsQ8_0Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType); - createTornadoInferencePlan(); - } -} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/MistralQ8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/MistralQ8_0LayerPlanner.java deleted file mode 100644 index 65e0150c..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/MistralQ8_0LayerPlanner.java +++ /dev/null @@ -1,20 +0,0 @@ -package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0; - -import org.beehive.gpullama3.inference.state.LlamaState; -import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.model.mistral.MistralConfiguration; -import org.beehive.gpullama3.tornadovm.layers.Activation; -import org.beehive.gpullama3.tornadovm.layers.type.q8_0.MistralQ8_0FFNLayers; -import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; - -public class MistralQ8_0LayerPlanner 
extends Q8_0LayerPlanner { - - public MistralQ8_0LayerPlanner(LlamaState state, Model model) { - super(state, model); - this.activationLayer = new Activation("activationUpdate", state, weights, config); - this.ffnLayers = new MistralQ8_0FFNLayers("mistralFFN", state, weights, config, schedulerType); - this.logitsLayer = new LogitsQ8_0Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType); - createTornadoInferencePlan(); - } -} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java deleted file mode 100644 index 6f088d36..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java +++ /dev/null @@ -1,28 +0,0 @@ -package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0; - -import org.beehive.gpullama3.inference.state.Phi3State; -import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.model.phi3.Phi3Configuration; -import org.beehive.gpullama3.tornadovm.layers.Activation; -import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; -import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Phi3Q8_0FFNLayers; - -/** - * Phi3Q8_0LayerPlanner: Phi3 model with Q8_0-quantized weights. 
- * - * Follows the same pattern as Qwen3Q8_0LayerPlanner but with: - Phi3-specific FFN layers (combined QKV + gate/up FFN) - Phi3TornadoWeights (8-bit integer quantization) - Phi3Configuration - 2x - * memory compression vs FP16 - * - * Inherits from Q8_0LayerPlanner - */ -public class Phi3Q8_0LayerPlanner extends Q8_0LayerPlanner { - - public Phi3Q8_0LayerPlanner(Phi3State state, Model model) { - super(state, model); - this.activationLayer = new Activation("activationUpdate", state, weights, config); - this.ffnLayers = new Phi3Q8_0FFNLayers("phi3FFN", state, weights, config, schedulerType); - this.logitsLayer = new LogitsQ8_0Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType); - createTornadoInferencePlan(); - } -} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Q8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Q8_0LayerPlanner.java deleted file mode 100644 index b525d2a3..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Q8_0LayerPlanner.java +++ /dev/null @@ -1,26 +0,0 @@ -package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0; - -import org.beehive.gpullama3.tensor.GGMLType; -import org.beehive.gpullama3.inference.state.State; -import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; -import org.beehive.gpullama3.model.Configuration; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.tornadovm.layerplanner.QuantizedLayerPlanner; - -/** - * Base for all Q8_0-quantized layer planners. 
- */ -public abstract class Q8_0LayerPlanner - extends QuantizedLayerPlanner { - - protected Q8_0LayerPlanner(S state, Model model) { - super(state, model); - } - - @Override - protected void validateQuantizationType() { - if (this.weights.getWeightType() != GGMLType.Q8_0) { - throw new IllegalArgumentException("Q8_0LayerPlanner requires GGMLType.Q8_0, got: " + this.weights.getWeightType()); - } - } -} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen2Q8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen2Q8_0LayerPlanner.java deleted file mode 100644 index 090812e7..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen2Q8_0LayerPlanner.java +++ /dev/null @@ -1,28 +0,0 @@ -package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0; - -import org.beehive.gpullama3.inference.state.Qwen2State; -import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; -import org.beehive.gpullama3.tornadovm.layers.Activation; -import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; -import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen2Q8_0FFNLayers; - -/** - * Qwen2Q8_0LayerPlanner: Qwen2 model with Q8_0-quantized weights. 
- * - * Follows the same pattern as LlamaQ8_0LayerPlanner but with: - Qwen2-specific FFN layers (supports GQA with bias terms) - Qwen2TornadoWeights (8-bit integer quantization) - Qwen2Configuration - - * 2x memory compression vs FP16 - * - * Inherits from Q8_0LayerPlanner - */ -public class Qwen2Q8_0LayerPlanner extends Q8_0LayerPlanner { - - public Qwen2Q8_0LayerPlanner(Qwen2State state, Model model) { - super(state, model); - this.activationLayer = new Activation("activationUpdate", state, weights, config); - this.ffnLayers = new Qwen2Q8_0FFNLayers("qwen2FFN", state, weights, config, schedulerType); - this.logitsLayer = new LogitsQ8_0Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType); - createTornadoInferencePlan(); - } -} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen3Q8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen3Q8_0LayerPlanner.java deleted file mode 100644 index b5a19be2..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen3Q8_0LayerPlanner.java +++ /dev/null @@ -1,28 +0,0 @@ -package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0; - -import org.beehive.gpullama3.inference.state.Qwen3State; -import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; -import org.beehive.gpullama3.tornadovm.layers.Activation; -import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; -import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen3Q8_0FFNLayers; - -/** - * Qwen3Q8_0LayerPlanner: Qwen3 model with Q8_0-quantized weights. 
- * - * Follows the same pattern as LlamaQ8_0LayerPlanner but with: - Qwen3-specific FFN layers (supports GQA) - Qwen3TornadoWeights (8-bit integer quantization) - Qwen3Configuration - 2x memory - * compression vs FP16 - * - * Inherits from Q8_0LayerPlanner - */ -public class Qwen3Q8_0LayerPlanner extends Q8_0LayerPlanner { - - public Qwen3Q8_0LayerPlanner(Qwen3State state, Model model) { - super(state, model); - this.activationLayer = new Activation("activationUpdate", state, weights, config); - this.ffnLayers = new Qwen3Q8_0FFNLayers("qwen3FFN", state, weights, config, schedulerType); - this.logitsLayer = new LogitsQ8_0Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType); - createTornadoInferencePlan(); - } -} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/strategy/SchedulerType.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/strategy/SchedulerType.java deleted file mode 100644 index 28b568a2..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/strategy/SchedulerType.java +++ /dev/null @@ -1,5 +0,0 @@ -package org.beehive.gpullama3.tornadovm.layerplanner.strategy; - -public enum SchedulerType { - NVIDIA, NON_NVIDIA -} \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java index 1a73897e..d0b3656e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java @@ -3,7 +3,7 @@ import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.Weights; import org.beehive.gpullama3.model.Configuration; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import 
uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; @@ -14,7 +14,7 @@ * Abstract base class for all FFN (Feed-Forward Network) layer implementations. * Extended by model and quantization-specific subclasses that provide specific implementations. */ -public abstract class AbstractFFNLayers extends AbstractLayer { +public abstract class AbstractFFNLayers extends AbstractLayer implements TransformerLayerTaskGraphs { /** * List of TornadoVM {@link ImmutableTaskGraph}s, one per FFN layer. diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsLayer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsLayer.java index 37288e5f..afead80d 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsLayer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsLayer.java @@ -4,7 +4,7 @@ import org.beehive.gpullama3.inference.weights.Weights; import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.model.Configuration; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java index 6ba36f39..d35f8252 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java @@ -4,7 +4,7 @@ import org.beehive.gpullama3.inference.weights.Weights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import 
org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; @@ -13,7 +13,7 @@ import uk.ac.manchester.tornado.api.types.arrays.ByteArray; import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; -public class Activation extends AbstractLayer { +public class Activation extends AbstractLayer implements ActivationGraph { private final TaskGraph activationTaskGraph; public Activation(String name, State state, Weights weights, Configuration config) { diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGraph.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGraph.java new file mode 100644 index 00000000..c46479bc --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGraph.java @@ -0,0 +1,15 @@ +package org.beehive.gpullama3.tornadovm.layers; + +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; + +/** + * Common interface for single activation graphs (standard, decode, batch-prefill variants). + * + *

    <p>Implemented by {@link Activation} and custom activation wrappers used by + * {@link org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents}.</p>

    + */ +public interface ActivationGraph { + ImmutableTaskGraph getImmutableTaskGraph(); + GridScheduler updateGridScheduler(GridScheduler scheduler); +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/BatchPrefillTransformerLayerTaskGraphs.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/BatchPrefillTransformerLayerTaskGraphs.java new file mode 100644 index 00000000..0e7a762f --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/BatchPrefillTransformerLayerTaskGraphs.java @@ -0,0 +1,17 @@ +package org.beehive.gpullama3.tornadovm.layers; + +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; + +import java.util.List; + +/** + * Interface for a group of N batched-prefill transformer-layer {@link uk.ac.manchester.tornado.api.TaskGraph}. + * + *

    <p>Implemented by {@code LlamaFP16LayersBatchPrefill} and {@code LlamaQ8_0LayersBatchPrefill}.</p>

    + */ +public interface BatchPrefillTransformerLayerTaskGraphs { + List getLayerImmutableTaskGraphs(); + void updateGridScheduler(GridScheduler scheduler); + String getLastLayerTaskGraphID(); +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/TransformerLayerTaskGraphs.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/TransformerLayerTaskGraphs.java new file mode 100644 index 00000000..245b3f40 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/TransformerLayerTaskGraphs.java @@ -0,0 +1,17 @@ +package org.beehive.gpullama3.tornadovm.layers; + +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; + +import java.util.List; + +/** + * Interface for a group of N transformer-layer {@link uk.ac.manchester.tornado.api.TaskGraph} (standard or prefill-decode variants). + * + *

    <p>Implemented by {@link AbstractFFNLayers} and its subclasses.</p>

    + */ +public interface TransformerLayerTaskGraphs { + List getFFNLayerImmutableTaskGraphs(); + GridScheduler updateGridScheduler(GridScheduler scheduler); + String getLastFFNLayerTaskGraphID(); +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java index 7fcc8717..e18c2a3a 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java @@ -5,8 +5,8 @@ import org.beehive.gpullama3.model.devstral.DevstralConfiguration; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java index 80f4a6ea..46aad290 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java @@ -7,8 +7,8 @@ import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import 
org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java index a7b1e87b..928954e9 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java @@ -5,8 +5,8 @@ import org.beehive.gpullama3.model.llama.LlamaConfiguration; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java index 1858408e..e3974ade 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java @@ -7,8 +7,8 @@ import 
org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java index 54ec9641..2e3a6fe6 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java @@ -8,7 +8,7 @@ import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.enums.DataTransferMode; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java index 499fc176..aa26bb69 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java @@ 
-5,8 +5,8 @@ import org.beehive.gpullama3.model.mistral.MistralConfiguration; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java index 065c6802..1d443ce3 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java @@ -5,8 +5,8 @@ import org.beehive.gpullama3.model.phi3.Phi3Configuration; import org.beehive.gpullama3.tornadovm.kernels.Phi3Kernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java index d6ac0eee..16e35a39 100644 --- 
a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java @@ -6,8 +6,8 @@ import org.beehive.gpullama3.tornadovm.kernels.Qwen2Kernels; import org.beehive.gpullama3.tornadovm.kernels.Qwen3Kernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java index 60b3cc3e..0530e7c0 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java @@ -5,8 +5,8 @@ import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; import org.beehive.gpullama3.tornadovm.kernels.Qwen3Kernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; diff --git 
a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java index f20d5bed..ea0220f6 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java @@ -3,14 +3,14 @@ import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.llama.LlamaConfiguration; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.enums.DataTransferMode; /** * Decode FFN layers of the unified batched prefill-decode plan - * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode}). + * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode}). * *

    Overrides data-transfer declarations so that all cross-graph boundaries use * the explicit-source form of {@code consumeFromDevice}. The no-arg form (used by diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java index 2f5bac64..c1167796 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java @@ -3,14 +3,14 @@ import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.llama.LlamaConfiguration; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.enums.DataTransferMode; /** * Decode FFN layers for the single-token prefill/decode plan - * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode}). + * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanPrefillDecode}). * *

    Combines two concerns:

    *
      diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java index 350e6760..7ee529d0 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java @@ -3,13 +3,13 @@ import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.Weights; import org.beehive.gpullama3.model.Configuration; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; import uk.ac.manchester.tornado.api.TaskGraph; /** * Logits layer of the unified batched prefill-decode plan - * * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode}). + * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode}). * *

      Extends {@link LogitsFP16Layer} with KV-cache pass-through so the device * pointers for {@code wrapKeyCache} and {@code wrapValueCache} survive the diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java index a44425ef..44ce4b35 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java @@ -4,7 +4,8 @@ import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.llama.LlamaConfiguration; import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.KernelContext; @@ -17,7 +18,7 @@ /** * Prefill FFN layers with batching for the unified batched prefill-decode plan - * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode}). + * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode}). * *

      One {@link ImmutableTaskGraph} per transformer layer, each processing * {@code batchSize} tokens simultaneously via {@link TransformerBatchPrefillKernels}.

      @@ -25,7 +26,7 @@ *

      KV cache ({@code wrapKeyCache}, {@code wrapValueCache}) is persisted on device * after every layer so the subsequent single-token decode layers can consume it.

      */ -public class LlamaFP16LayersBatchPrefill { +public class LlamaFP16LayersBatchPrefill implements BatchPrefillTransformerLayerTaskGraphs { // Matches the local workgroup size used by the single-token kernels. static final int LOCAL_WORK_GROUP_SIZE = 32; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java index 0c5aedd6..0d717d2e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java @@ -4,8 +4,8 @@ import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.devstral.DevstralConfiguration; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java index 8a91c75a..1aca4714 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java @@ -5,8 +5,8 @@ import org.beehive.gpullama3.model.granite.GraniteConfiguration; import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; import 
org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java index 2cfacdc0..f5683904 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java @@ -4,8 +4,8 @@ import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.llama.LlamaConfiguration; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java index c583bb00..dca19c83 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java +++ 
b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java @@ -8,7 +8,7 @@ import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.enums.DataTransferMode; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java index 1a1313e2..bf66118a 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java @@ -7,8 +7,8 @@ import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java index 64864114..4a6f3151 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java +++ 
b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java @@ -4,8 +4,8 @@ import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.mistral.MistralConfiguration; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java index 8f693adc..0ffe25d6 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java @@ -5,8 +5,8 @@ import org.beehive.gpullama3.model.phi3.Phi3Configuration; import org.beehive.gpullama3.tornadovm.kernels.Phi3Kernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java 
b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java index 6cdf32db..936c3573 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java @@ -6,8 +6,8 @@ import org.beehive.gpullama3.tornadovm.kernels.Qwen2Kernels; import org.beehive.gpullama3.tornadovm.kernels.Qwen3Kernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java index 43b88fb1..1696644d 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java @@ -5,8 +5,8 @@ import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; import org.beehive.gpullama3.tornadovm.kernels.Qwen3Kernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; import 
uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java index e159c99b..0b2f13cb 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java @@ -3,14 +3,14 @@ import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.llama.LlamaConfiguration; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LlamaQ8_0FFNLayers; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.enums.DataTransferMode; /** * Decode FFN layers for the unified batched prefill-decode plan - * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode}). + * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode}). * *

      Layer 0 consumes the KV cache from device (passed through by the decode activation * graph, which relays it from the last batch prefill layer). No FIRST_EXECUTION allocation diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java index a1f02ada..388853f7 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java @@ -3,13 +3,13 @@ import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.llama.LlamaConfiguration; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LlamaQ8_0FFNLayers; import uk.ac.manchester.tornado.api.TaskGraph; /** * Decode FFN layers for the single-token prefill/decode plan - * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode}). + * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanPrefillDecode}). * *

      Layer 0 delegates to {@link LlamaQ8_0FFNLayers#configureLayerDataTransfers} which * includes {@code FIRST_EXECUTION} for {@code wrapKeyCache} and {@code wrapValueCache}, diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LogitsQ8_0LayerDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LogitsQ8_0LayerDecode.java index 054b7f7f..2a6714f2 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LogitsQ8_0LayerDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LogitsQ8_0LayerDecode.java @@ -3,7 +3,7 @@ import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.Weights; import org.beehive.gpullama3.model.Configuration; -import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; import uk.ac.manchester.tornado.api.TaskGraph; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java index 79ea9bc2..14c30462 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java @@ -4,7 +4,8 @@ import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.llama.LlamaConfiguration; import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels; -import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs; 
import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.KernelContext; @@ -30,7 +31,7 @@ *

    1. Weight matrices are {@code ByteArray} (Q8_0 format).
    2. * */ -public class LlamaQ8_0LayersBatchPrefill { +public class LlamaQ8_0LayersBatchPrefill implements BatchPrefillTransformerLayerTaskGraphs { static final int LOCAL_WORK_GROUP_SIZE = 32; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java new file mode 100644 index 00000000..1246265c --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java @@ -0,0 +1,75 @@ +package org.beehive.gpullama3.tornadovm.plan; + +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.plan.components.BatchPrefillDecodeForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.plan.components.EmbeddingPreparer; +import org.beehive.gpullama3.tornadovm.plan.layout.BatchPrefillDecodeForwardTaskGraphLayout; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; + +import java.util.ArrayList; +import java.util.List; + +/** + * Topology plan for the 2N+3 batch-prefill/decode forward pass. + * + *

      Graph layout:

      + *
      + *   [0]         batchPrefillActivation
      + *   [1..N]      batch-prefill transformer layers
      + *   [N+1]       decodeActivation  (consumes + re-persists KV cache)
      + *   [N+2..2N+1] decode transformer layers
      + *   [2N+2]      logits
      + * 
      + * + *

      During batch prefill, the master plan executes graphs 0..N. + * During decode, graphs N+1..2N+2 run.

      + */ +public class BatchPrefillDecodeForwardPlan extends ForwardPlan { + + private final BatchPrefillDecodeForwardTaskGraphLayout taskGraphLayout; + private final EmbeddingPreparer embeddingPreparer; + + public BatchPrefillDecodeForwardPlan(Model model, BatchPrefillDecodeForwardPlanComponents components, int batchSize) { + int N = model.configuration().numberOfLayers(); + this.taskGraphLayout = new BatchPrefillDecodeForwardTaskGraphLayout(N); + this.embeddingPreparer = components.embeddingPreparer(); + + List all = new ArrayList<>(2 * N + 3); + GridScheduler scheduler = new GridScheduler(); + + ActivationGraph batchAct = components.batchPrefillActivation(batchSize); + all.add(batchAct.getImmutableTaskGraph()); + batchAct.updateGridScheduler(scheduler); + + BatchPrefillTransformerLayerTaskGraphs batchLayers = components.batchPrefillLayers(batchSize); + all.addAll(batchLayers.getLayerImmutableTaskGraphs()); + batchLayers.updateGridScheduler(scheduler); + + ActivationGraph decodeAct = components.batchDecodeActivation(batchLayers.getLastLayerTaskGraphID()); + all.add(decodeAct.getImmutableTaskGraph()); + decodeAct.updateGridScheduler(scheduler); + + TransformerLayerTaskGraphs decodeLayers = components.batchDecodeLayers(); + all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs()); + decodeLayers.updateGridScheduler(scheduler); + + AbstractLogitsLayer logits = components.decodeLogits(decodeLayers.getLastFFNLayerTaskGraphID()); + all.add(logits.getImmutableTaskGraph()); + logits.updateGridScheduler(scheduler); + + setGraphs(all, scheduler); + } + + public BatchPrefillDecodeForwardTaskGraphLayout getTaskGraphLayout() { + return taskGraphLayout; + } + + public EmbeddingPreparer getEmbeddingPreparer() { + return embeddingPreparer; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/ExecutionMode.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/ExecutionMode.java new file mode 100644 index 00000000..0e597af2 --- /dev/null +++ 
b/src/main/java/org/beehive/gpullama3/tornadovm/plan/ExecutionMode.java @@ -0,0 +1,16 @@ +package org.beehive.gpullama3.tornadovm.plan; + +/** + * Selects the GPU execution topology for a TornadoVM inference plan. + * + *
        + *
      • {@link #STANDARD} — single forward pass (activation + N layers + logits).
      • + *
      • {@link #PREFILL_DECODE} — shared N+2 graph plan; prefill skips logits, decode runs all.
      • + *
      • {@link #BATCH_PREFILL_DECODE} — 2N+3 graph plan; batch-prefill activation + N batch-prefill layer graphs + decode activation + N decode layer graphs + logits.
      • + *
      + */ +public enum ExecutionMode { + STANDARD, + PREFILL_DECODE, + BATCH_PREFILL_DECODE +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlan.java new file mode 100644 index 00000000..5ccdb4c0 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlan.java @@ -0,0 +1,32 @@ +package org.beehive.gpullama3.tornadovm.plan; + +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; + +import java.util.List; + +/** + * Abstract base for GPU forward-pass execution plans. + * + *

      Subclasses assemble the {@link ImmutableTaskGraph} list and {@link GridScheduler} + * in their constructors by calling {@link #setGraphs}, then expose the results + * via {@link #getImmutableTaskGraphs} and {@link #getGridScheduler}.

      + */ +public abstract class ForwardPlan { + + private List immutableTaskGraphs; + private GridScheduler gridScheduler; + + protected final void setGraphs(List itgs, GridScheduler scheduler) { + this.immutableTaskGraphs = itgs; + this.gridScheduler = scheduler; + } + + public final List getImmutableTaskGraphs() { + return immutableTaskGraphs; + } + + public final GridScheduler getGridScheduler() { + return gridScheduler; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlanFactory.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlanFactory.java new file mode 100644 index 00000000..4eac09b6 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlanFactory.java @@ -0,0 +1,207 @@ +package org.beehive.gpullama3.tornadovm.plan; + +import org.beehive.gpullama3.inference.state.DevstralState; +import org.beehive.gpullama3.inference.state.GraniteState; +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.state.Phi3State; +import org.beehive.gpullama3.inference.state.Qwen2State; +import org.beehive.gpullama3.inference.state.Qwen3State; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; +import org.beehive.gpullama3.tornadovm.plan.components.BatchPrefillDecodeForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.plan.components.fp16.DevstralFP16PlanComponents; +import org.beehive.gpullama3.tornadovm.plan.components.fp16.GraniteFP16PlanComponents; +import org.beehive.gpullama3.tornadovm.plan.components.fp16.LlamaFP16PlanComponents; +import org.beehive.gpullama3.tornadovm.plan.components.fp16.MistralFP16PlanComponents; +import org.beehive.gpullama3.tornadovm.plan.components.fp16.Phi3FP16PlanComponents; +import org.beehive.gpullama3.tornadovm.plan.components.fp16.Qwen2FP16PlanComponents; +import 
org.beehive.gpullama3.tornadovm.plan.components.fp16.Qwen3FP16PlanComponents; +import org.beehive.gpullama3.tornadovm.plan.components.q8_0.DevstralQ8_0PlanComponents; +import org.beehive.gpullama3.tornadovm.plan.components.q8_0.GraniteQ8_0PlanComponents; +import org.beehive.gpullama3.tornadovm.plan.components.q8_0.LlamaQ8_0PlanComponents; +import org.beehive.gpullama3.tornadovm.plan.components.q8_0.MistralQ8_0PlanComponents; +import org.beehive.gpullama3.tornadovm.plan.components.q8_0.Phi3Q8_0PlanComponents; +import org.beehive.gpullama3.tornadovm.plan.components.q8_0.Qwen2Q8_0PlanComponents; +import org.beehive.gpullama3.tornadovm.plan.components.q8_0.Qwen3Q8_0PlanComponents; + +/** + * Factory for {@link ForwardPlan} instances. + * + *

      Dispatches across three axes in order:

      1. Quantization ({@link GGMLType})
      2. Model family ({@link org.beehive.gpullama3.model.ModelType})
      3. Execution mode ({@link ExecutionMode})

      Use the typed convenience methods when the execution mode is known at the call site:

      • {@link #createSingleToken} — returns {@link SingleTokenForwardPlan}
      • {@link #createPrefillDecode} — returns {@link PrefillDecodeForwardPlan}
      • {@link #createBatchPrefillDecode} — returns {@link BatchPrefillDecodeForwardPlan}
      + */ +public class ForwardPlanFactory { + + private ForwardPlanFactory() {} + + // ── Typed public API ────────────────────────────────────────────────────── + + public static SingleTokenForwardPlan createSingleToken(GGMLType quantization, State state, Model model) { + ForwardPlan plan = create(quantization, ExecutionMode.STANDARD, state, model); + if (plan instanceof SingleTokenForwardPlan singleToken) return singleToken; + throw new IllegalStateException("Expected SingleTokenForwardPlan for STANDARD mode but got " + plan.getClass().getSimpleName()); + } + + public static PrefillDecodeForwardPlan createPrefillDecode(GGMLType quantization, State state, Model model) { + ForwardPlan plan = create(quantization, ExecutionMode.PREFILL_DECODE, state, model); + if (plan instanceof PrefillDecodeForwardPlan prefillDecode) return prefillDecode; + throw new IllegalStateException("Expected PrefillDecodeForwardPlan for PREFILL_DECODE mode but got " + plan.getClass().getSimpleName()); + } + + public static BatchPrefillDecodeForwardPlan createBatchPrefillDecode(GGMLType quantization, State state, Model model) { + ForwardPlan plan = create(quantization, ExecutionMode.BATCH_PREFILL_DECODE, state, model); + if (plan instanceof BatchPrefillDecodeForwardPlan batchPrefillDecode) return batchPrefillDecode; + throw new IllegalStateException("Expected BatchPrefillDecodeForwardPlan for BATCH_PREFILL_DECODE mode but got " + plan.getClass().getSimpleName()); + } + + // ── Generic dispatch ────────────────────────────────────────────────────── + + static ForwardPlan create(GGMLType quantization, ExecutionMode mode, State state, Model model) { + return switch (quantization) { + case F16 -> createFP16Plan(mode, state, model); + case Q8_0 -> createQ8_0Plan(mode, state, model); + case F32 -> throw new UnsupportedOperationException("F32 plans not yet implemented"); + case Q4_0 -> throw new UnsupportedOperationException("Q4_0 plans not yet implemented"); + default -> throw new 
UnsupportedOperationException("Quantization not supported: " + quantization); + }; + } + + // ── FP16 branch ─────────────────────────────────────────────────────────── + + private static ForwardPlan createFP16Plan(ExecutionMode mode, State state, Model model) { + return switch (model.getModelType()) { + case LLAMA_3 -> createLlamaFP16Plan(mode, (LlamaState) state, model); + case MISTRAL -> createMistralFP16Plan(mode, (LlamaState) state, model); + case DEVSTRAL_2 -> createDevstralFP16Plan(mode, (DevstralState) state, model); + case QWEN_2 -> createQwen2FP16Plan(mode, (Qwen2State) state, model); + case QWEN_3 -> createQwen3FP16Plan(mode, (Qwen3State) state, model); + case PHI_3 -> createPhi3FP16Plan(mode, (Phi3State) state, model); + case GRANITE -> createGraniteFP16Plan(mode, (GraniteState) state, model); + case DEEPSEEK_R1_DISTILL_QWEN -> createQwen2FP16Plan(mode, (Qwen2State) state, model); + default -> throw new UnsupportedOperationException("F16 not supported for model: " + model.getModelType()); + }; + } + + // ── Q8_0 branch ─────────────────────────────────────────────────────────── + + private static ForwardPlan createQ8_0Plan(ExecutionMode mode, State state, Model model) { + return switch (model.getModelType()) { + case LLAMA_3 -> createLlamaQ8_0Plan(mode, (LlamaState) state, model); + case MISTRAL -> createMistralQ8_0Plan(mode, (LlamaState) state, model); + case DEVSTRAL_2 -> createDevstralQ8_0Plan(mode, (DevstralState) state, model); + case QWEN_2 -> createQwen2Q8_0Plan(mode, (Qwen2State) state, model); + case QWEN_3 -> createQwen3Q8_0Plan(mode, (Qwen3State) state, model); + case PHI_3 -> createPhi3Q8_0Plan(mode, (Phi3State) state, model); + case GRANITE -> createGraniteQ8_0Plan(mode, (GraniteState) state, model); + case DEEPSEEK_R1_DISTILL_QWEN -> createQwen2Q8_0Plan(mode, (Qwen2State) state, model); + default -> throw new UnsupportedOperationException("Q8_0 not supported for model: " + model.getModelType()); + }; + } + + // ── Model+quant helpers — 
Llama (all 3 modes supported) ────────────────── + + private static ForwardPlan createLlamaFP16Plan(ExecutionMode mode, LlamaState state, Model model) { + BatchPrefillDecodeForwardPlanComponents components = new LlamaFP16PlanComponents(state, model); + return switch (mode) { + case STANDARD -> new SingleTokenForwardPlan(model, components); + case PREFILL_DECODE -> new PrefillDecodeForwardPlan(model, components); + case BATCH_PREFILL_DECODE -> new BatchPrefillDecodeForwardPlan(model, components, TornadoVMMasterPlan.PREFILL_BATCH_SIZE); + }; + } + + private static ForwardPlan createLlamaQ8_0Plan(ExecutionMode mode, LlamaState state, Model model) { + BatchPrefillDecodeForwardPlanComponents components = new LlamaQ8_0PlanComponents(state, model); + return switch (mode) { + case STANDARD -> new SingleTokenForwardPlan(model, components); + case PREFILL_DECODE -> new PrefillDecodeForwardPlan(model, components); + case BATCH_PREFILL_DECODE -> new BatchPrefillDecodeForwardPlan(model, components, TornadoVMMasterPlan.PREFILL_BATCH_SIZE); + }; + } + + // ── Model+quant helpers — STANDARD only ────────────────────────────────── + + private static ForwardPlan createMistralFP16Plan(ExecutionMode mode, LlamaState state, Model model) { + if (mode != ExecutionMode.STANDARD) + throw new UnsupportedOperationException(mode + " not yet supported for MISTRAL + F16"); + return new SingleTokenForwardPlan(model, new MistralFP16PlanComponents(state, model)); + } + + private static ForwardPlan createMistralQ8_0Plan(ExecutionMode mode, LlamaState state, Model model) { + if (mode != ExecutionMode.STANDARD) + throw new UnsupportedOperationException(mode + " not yet supported for MISTRAL + Q8_0"); + return new SingleTokenForwardPlan(model, new MistralQ8_0PlanComponents(state, model)); + } + + private static ForwardPlan createDevstralFP16Plan(ExecutionMode mode, DevstralState state, Model model) { + if (mode != ExecutionMode.STANDARD) + throw new UnsupportedOperationException(mode + " not yet 
supported for DEVSTRAL_2 + F16"); + return new SingleTokenForwardPlan(model, new DevstralFP16PlanComponents(state, model)); + } + + private static ForwardPlan createDevstralQ8_0Plan(ExecutionMode mode, DevstralState state, Model model) { + if (mode != ExecutionMode.STANDARD) + throw new UnsupportedOperationException(mode + " not yet supported for DEVSTRAL_2 + Q8_0"); + return new SingleTokenForwardPlan(model, new DevstralQ8_0PlanComponents(state, model)); + } + + private static ForwardPlan createQwen2FP16Plan(ExecutionMode mode, Qwen2State state, Model model) { + if (mode != ExecutionMode.STANDARD) + throw new UnsupportedOperationException(mode + " not yet supported for QWEN_2 + F16"); + return new SingleTokenForwardPlan(model, new Qwen2FP16PlanComponents(state, model)); + } + + private static ForwardPlan createQwen2Q8_0Plan(ExecutionMode mode, Qwen2State state, Model model) { + if (mode != ExecutionMode.STANDARD) + throw new UnsupportedOperationException(mode + " not yet supported for QWEN_2 + Q8_0"); + return new SingleTokenForwardPlan(model, new Qwen2Q8_0PlanComponents(state, model)); + } + + private static ForwardPlan createQwen3FP16Plan(ExecutionMode mode, Qwen3State state, Model model) { + if (mode != ExecutionMode.STANDARD) + throw new UnsupportedOperationException(mode + " not yet supported for QWEN_3 + F16"); + return new SingleTokenForwardPlan(model, new Qwen3FP16PlanComponents(state, model)); + } + + private static ForwardPlan createQwen3Q8_0Plan(ExecutionMode mode, Qwen3State state, Model model) { + if (mode != ExecutionMode.STANDARD) + throw new UnsupportedOperationException(mode + " not yet supported for QWEN_3 + Q8_0"); + return new SingleTokenForwardPlan(model, new Qwen3Q8_0PlanComponents(state, model)); + } + + private static ForwardPlan createPhi3FP16Plan(ExecutionMode mode, Phi3State state, Model model) { + if (mode != ExecutionMode.STANDARD) + throw new UnsupportedOperationException(mode + " not yet supported for PHI_3 + F16"); + return new 
SingleTokenForwardPlan(model, new Phi3FP16PlanComponents(state, model)); + } + + private static ForwardPlan createPhi3Q8_0Plan(ExecutionMode mode, Phi3State state, Model model) { + if (mode != ExecutionMode.STANDARD) + throw new UnsupportedOperationException(mode + " not yet supported for PHI_3 + Q8_0"); + return new SingleTokenForwardPlan(model, new Phi3Q8_0PlanComponents(state, model)); + } + + private static ForwardPlan createGraniteFP16Plan(ExecutionMode mode, GraniteState state, Model model) { + if (mode != ExecutionMode.STANDARD) + throw new UnsupportedOperationException(mode + " not yet supported for GRANITE + F16"); + return new SingleTokenForwardPlan(model, new GraniteFP16PlanComponents(state, model)); + } + + private static ForwardPlan createGraniteQ8_0Plan(ExecutionMode mode, GraniteState state, Model model) { + if (mode != ExecutionMode.STANDARD) + throw new UnsupportedOperationException(mode + " not yet supported for GRANITE + Q8_0"); + return new SingleTokenForwardPlan(model, new GraniteQ8_0PlanComponents(state, model)); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java new file mode 100644 index 00000000..241f542b --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java @@ -0,0 +1,57 @@ +package org.beehive.gpullama3.tornadovm.plan; + +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.plan.components.PrefillDecodeForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.plan.layout.PrefillDecodeForwardTaskGraphLayout; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; + +import 
java.util.ArrayList; +import java.util.List; + +/** + * Topology plan for the N+2 prefill/decode forward pass. + * + *

      Graph layout:

      + *
      + *   [0]      decodeActivation
      + *   [1..N]   decode transformer layers
      + *   [N+1]    logits
      + * 
      + * + *

      During prefill, the master plan executes graphs 0..N (skipping logits). During decode, all N+2 graphs run.

      + */ +public class PrefillDecodeForwardPlan extends ForwardPlan { + + private final PrefillDecodeForwardTaskGraphLayout taskGraphLayout; + + public PrefillDecodeForwardPlan(Model model, PrefillDecodeForwardPlanComponents components) { + int N = model.configuration().numberOfLayers(); + this.taskGraphLayout = new PrefillDecodeForwardTaskGraphLayout(N); + + List all = new ArrayList<>(N + 2); + GridScheduler scheduler = new GridScheduler(); + + ActivationGraph act = components.decodeActivation(); + all.add(act.getImmutableTaskGraph()); + act.updateGridScheduler(scheduler); + + TransformerLayerTaskGraphs layers = components.prefillDecodeLayers(); + all.addAll(layers.getFFNLayerImmutableTaskGraphs()); + layers.updateGridScheduler(scheduler); + + AbstractLogitsLayer logits = components.decodeLogits(layers.getLastFFNLayerTaskGraphID()); + all.add(logits.getImmutableTaskGraph()); + logits.updateGridScheduler(scheduler); + + setGraphs(all, scheduler); + } + + public PrefillDecodeForwardTaskGraphLayout getTaskGraphLayout() { + return taskGraphLayout; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/SingleTokenForwardPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/SingleTokenForwardPlan.java new file mode 100644 index 00000000..bdc20a97 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/SingleTokenForwardPlan.java @@ -0,0 +1,54 @@ +package org.beehive.gpullama3.tornadovm.plan; + +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.plan.layout.SingleTokenForwardTaskGraphLayout; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; + +import 
java.util.ArrayList; +import java.util.List; + +/** + * Topology plan for the N+2 single-token forward pass. + * + *

      Graph layout:

      + *
      + *   [0]      activation
      + *   [1..N]   transformer layers
      + *   [N+1]    logits
      + * 
      + */ +public class SingleTokenForwardPlan extends ForwardPlan { + + private final SingleTokenForwardTaskGraphLayout taskGraphLayout; + + public SingleTokenForwardPlan(Model model, SingleTokenForwardPlanComponents components) { + int N = model.configuration().numberOfLayers(); + this.taskGraphLayout = new SingleTokenForwardTaskGraphLayout(N); + + List all = new ArrayList<>(N + 2); + GridScheduler scheduler = new GridScheduler(); + + ActivationGraph act = components.standardActivation(); + all.add(act.getImmutableTaskGraph()); + act.updateGridScheduler(scheduler); + + TransformerLayerTaskGraphs layers = components.standardLayers(); + all.addAll(layers.getFFNLayerImmutableTaskGraphs()); + layers.updateGridScheduler(scheduler); + + AbstractLogitsLayer logits = components.standardLogits(layers.getLastFFNLayerTaskGraphID()); + all.add(logits.getImmutableTaskGraph()); + logits.updateGridScheduler(scheduler); + + setGraphs(all, scheduler); + } + + public SingleTokenForwardTaskGraphLayout getTaskGraphLayout() { + return taskGraphLayout; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/BatchPrefillDecodeForwardPlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/BatchPrefillDecodeForwardPlanComponents.java new file mode 100644 index 00000000..10595b27 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/BatchPrefillDecodeForwardPlanComponents.java @@ -0,0 +1,25 @@ +package org.beehive.gpullama3.tornadovm.plan.components; + +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; + +/** + * Components for the 2N+3 batch-prefill/decode forward plan. + * + *

      Extends {@link PrefillDecodeForwardPlanComponents} with the batch-prefill activation and batch-prefill layer group, the KV-cache decode activation and batch-decode layer group, and the host-side embedding preparer.

      + */ +public interface BatchPrefillDecodeForwardPlanComponents extends PrefillDecodeForwardPlanComponents { + + ActivationGraph batchPrefillActivation(int batchSize); + + ActivationGraph batchDecodeActivation(String lastBatchLayerId); + + TransformerLayerTaskGraphs batchDecodeLayers(); + + BatchPrefillTransformerLayerTaskGraphs batchPrefillLayers(int batchSize); + + EmbeddingPreparer embeddingPreparer(); +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/EmbeddingPreparer.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/EmbeddingPreparer.java new file mode 100644 index 00000000..074bc4cd --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/EmbeddingPreparer.java @@ -0,0 +1,18 @@ +package org.beehive.gpullama3.tornadovm.plan.components; + +/** + * Prepares embedding vectors in host memory before GPU activation graphs run. + * + *

      Concrete implementations are format-specific (FP16 byte copy vs Q8_0 CPU dequantization) and live in the component-provider classes.

      + */ +public interface EmbeddingPreparer { + /** Clears the batch input buffer and resets the batch-start position holder to zero. */ + void initBatchState(); + + /** Copies or dequantizes embeddings for {@code chunkSize} tokens into the batch buffer and sets the batch-start position holder. */ + void copyBatchEmbeddings(int[] tokenIds, int startPos, int chunkSize); + + /** Copies or dequantizes the embedding for a single decode token into the single-token buffer. */ + void copyDecodeEmbedding(int token); +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java new file mode 100644 index 00000000..8f0b6cdd --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java @@ -0,0 +1,20 @@ +package org.beehive.gpullama3.tornadovm.plan.components; + +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; + +/** + * Components for the N+2 prefill/decode forward plan. + * + *

      Extends {@link SingleTokenForwardPlanComponents} with the decode-phase activation, KV-cache-aware layer group, and decode logits.

      + */ +public interface PrefillDecodeForwardPlanComponents extends SingleTokenForwardPlanComponents { + + ActivationGraph decodeActivation(); + + TransformerLayerTaskGraphs prefillDecodeLayers(); + + AbstractLogitsLayer decodeLogits(String previousGraphId); +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java new file mode 100644 index 00000000..6ecc2077 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java @@ -0,0 +1,21 @@ +package org.beehive.gpullama3.tornadovm.plan.components; + +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; + +/** + * Components for the single-token forward pass. + * + *

      All model+quantization combinations implement this interface. Models that support prefill/decode modes implement the richer {@link PrefillDecodeForwardPlanComponents} or {@link BatchPrefillDecodeForwardPlanComponents}.

      + */ +public interface SingleTokenForwardPlanComponents { + + ActivationGraph standardActivation(); + + TransformerLayerTaskGraphs standardLayers(); + + AbstractLogitsLayer standardLogits(String previousGraphId); +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java new file mode 100644 index 00000000..d3617a6c --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java @@ -0,0 +1,62 @@ +package org.beehive.gpullama3.tornadovm.plan.components.activation; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; +import uk.ac.manchester.tornado.api.types.arrays.ByteArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; + +/** + * Decode activation graph with KV-cache pass-through ("decodeActivation"). + * + *

      Used in the 2N+3 batch-prefill/decode plan. Consumes {@code wrapKeyCache}/{@code wrapValueCache} from the last batch-prefill layer, converts the single-token embedding to FP32, then re-persists the KV cache so that decode layer 0 can consume it.

      + */ +public class BatchDecodeActivation implements ActivationGraph { + + private final ImmutableTaskGraph itg; + private final int dim; + + public BatchDecodeActivation(LlamaState state, LlamaConfiguration config, + String lastBatchLayerId, boolean isQ8) { + this.dim = config.dim(); + KernelContext ctx = new KernelContext(); + this.itg = buildGraph(ctx, state, lastBatchLayerId, isQ8).snapshot(); + } + + private TaskGraph buildGraph(KernelContext ctx, LlamaState state, + String lastBatchLayerId, boolean isQ8) { + TaskGraph tg = new TaskGraph("decodeActivation") + .consumeFromDevice(lastBatchLayerId, state.wrapKeyCache, state.wrapValueCache) + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX); + if (isQ8) { + tg.task("updateX", TransformerComputeKernels::convertQ8_0toFP32, + ctx, (ByteArray) state.embeddingX, state.wrapX); + } else { + tg.task("updateX", TransformerComputeKernels::convertFP16toFP32, + ctx, (HalfFloatArray) state.embeddingX, state.wrapX); + } + return tg.persistOnDevice(state.wrapX, state.wrapKeyCache, state.wrapValueCache); + } + + @Override + public ImmutableTaskGraph getImmutableTaskGraph() { + return itg; + } + + @Override + public GridScheduler updateGridScheduler(GridScheduler scheduler) { + scheduler.addWorkerGrid("decodeActivation.updateX", + WorkerGridFactory.genericWorker(dim, 128)); + return scheduler; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java new file mode 100644 index 00000000..73eec954 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java @@ -0,0 +1,72 @@ +package org.beehive.gpullama3.tornadovm.plan.components.activation; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import 
org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +/** + * Batch-prefill activation graph ("prefillActivation"). + * + *
      • FP16: uploads {@code embeddingXBatch} (EVERY_EXECUTION) and converts to FP32 {@code wrapXBatch}.
      • Q8_0: uploads FP32 {@code wrapXBatch} pre-filled by host (EVERY_EXECUTION) and runs a no-op {@code batchPassthrough} so TornadoVM has at least one task.
      + */ +public class BatchPrefillActivation implements ActivationGraph { + + private final ImmutableTaskGraph itg; + private final boolean isQ8; + private final int batchSize; + private final int dim; + + public BatchPrefillActivation(LlamaState state, LlamaConfiguration config, int batchSize, boolean isQ8) { + this.isQ8 = isQ8; + this.batchSize = batchSize; + this.dim = config.dim(); + KernelContext ctx = new KernelContext(); + this.itg = buildGraph(ctx, state).snapshot(); + } + + private TaskGraph buildGraph(KernelContext ctx, LlamaState state) { + if (isQ8) { + return new TaskGraph("prefillActivation") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapXBatch) + .task("batchPassthrough", TransformerBatchPrefillKernels::batchPassthrough, + ctx, state.wrapXBatch) + .persistOnDevice(state.wrapXBatch); + } + return new TaskGraph("prefillActivation") + .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapXBatch) + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingXBatch) + .task("updateX", TransformerComputeKernels::convertFP16toFP32, + ctx, state.embeddingXBatch, state.wrapXBatch) + .persistOnDevice(state.wrapXBatch); + } + + @Override + public ImmutableTaskGraph getImmutableTaskGraph() { + return itg; + } + + @Override + public GridScheduler updateGridScheduler(GridScheduler scheduler) { + if (isQ8) { + scheduler.addWorkerGrid("prefillActivation.batchPassthrough", + WorkerGridFactory.genericWorker(1, 1)); + } else { + scheduler.addWorkerGrid("prefillActivation.updateX", + WorkerGridFactory.genericWorker(batchSize * dim, 128)); + } + return scheduler; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java new file mode 100644 index 00000000..66f3b9ad --- /dev/null +++ 
b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.fp16; + +import org.beehive.gpullama3.inference.state.DevstralState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.devstral.DevstralConfiguration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.DevstralFP16FFNLayers; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class DevstralFP16PlanComponents implements SingleTokenForwardPlanComponents { + + private final DevstralState state; + private final LlamaTornadoWeights weights; + private final DevstralConfiguration config; + private final SchedulerType schedulerType; + + public DevstralFP16PlanComponents(DevstralState state, Model model) { + this.state = state; + this.config = (DevstralConfiguration) model.configuration(); + this.weights = (LlamaTornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationGraph standardActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs standardLayers() { + return new DevstralFP16FFNLayers("devstralFFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsLayer 
standardLogits(String previousGraphId) { + return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java new file mode 100644 index 00000000..5c8a6cfd --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.fp16; + +import org.beehive.gpullama3.inference.state.GraniteState; +import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.granite.GraniteConfiguration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.ActivationGranite; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.GraniteFP16FFNLayers; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsGraniteFP16Layer; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class GraniteFP16PlanComponents implements SingleTokenForwardPlanComponents { + + private final GraniteState state; + private final GraniteTornadoWeights weights; + private final GraniteConfiguration config; + private final SchedulerType schedulerType; + + public GraniteFP16PlanComponents(GraniteState state, Model model) { + this.state = state; + this.config = (GraniteConfiguration) model.configuration(); + this.weights = (GraniteTornadoWeights) model.weights(); + 
this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationGraph standardActivation() { + return new ActivationGranite("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs standardLayers() { + return new GraniteFP16FFNLayers("graniteFFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + return new LogitsGraniteFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java new file mode 100644 index 00000000..3fc093c3 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java @@ -0,0 +1,120 @@ +package org.beehive.gpullama3.tornadovm.plan.components.fp16; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersDecode; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersPrefillDecode; +import 
org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LogitsFP16LayerDecode; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16LayersBatchPrefill; +import org.beehive.gpullama3.tornadovm.plan.components.BatchPrefillDecodeForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.plan.components.EmbeddingPreparer; +import org.beehive.gpullama3.tornadovm.plan.components.activation.BatchDecodeActivation; +import org.beehive.gpullama3.tornadovm.plan.components.activation.BatchPrefillActivation; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +import java.lang.foreign.MemorySegment; + +/** + * {@link BatchPrefillDecodeForwardPlanComponents} for Llama + FP16. + * + *

      <p>Provides all layer objects and the FP16 embedding preparer (raw half-float byte copy).</p>

      + */ +public class LlamaFP16PlanComponents implements BatchPrefillDecodeForwardPlanComponents { + + private final LlamaState state; + private final LlamaTornadoWeights weights; + private final LlamaConfiguration config; + private final SchedulerType schedulerType; + + public LlamaFP16PlanComponents(LlamaState state, Model model) { + this.state = state; + this.config = (LlamaConfiguration) model.configuration(); + this.weights = (LlamaTornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + // ── Activations ─────────────────────────────────────────────────────────── + + @Override public ActivationGraph standardActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public ActivationGraph decodeActivation() { + return new Activation("decodeActivation", state, weights, config); + } + + @Override public ActivationGraph batchPrefillActivation(int batchSize) { + return new BatchPrefillActivation(state, config, batchSize, false); + } + + @Override public ActivationGraph batchDecodeActivation(String lastBatchLayerId) { + return new BatchDecodeActivation(state, config, lastBatchLayerId, false); + } + + // ── FFN layer groups ────────────────────────────────────────────────────── + + @Override public TransformerLayerTaskGraphs standardLayers() { + return new LlamaFP16FFNLayers("llamaFFN", state, weights, config, schedulerType); + } + + @Override public TransformerLayerTaskGraphs prefillDecodeLayers() { + return new LlamaFP16FFNLayersPrefillDecode("decode", state, weights, config, schedulerType); + } + + @Override public TransformerLayerTaskGraphs batchDecodeLayers() { + return new LlamaFP16FFNLayersDecode("decode", state, weights, config, schedulerType); + } + + @Override public BatchPrefillTransformerLayerTaskGraphs batchPrefillLayers(int batchSize) { + return new LlamaFP16LayersBatchPrefill(state, weights, config, batchSize); + } + + // ── Logits layers 
───────────────────────────────────────────────────────── + + @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); + } + + @Override public AbstractLogitsLayer decodeLogits(String previousGraphId) { + return new LogitsFP16LayerDecode("logits", state, weights, config, previousGraphId, schedulerType); + } + + // ── Embedding preparation ───────────────────────────────────────────────── + + @Override public EmbeddingPreparer embeddingPreparer() { + MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(); + int dim = config.dim(); + int bytes = Short.BYTES; + return new EmbeddingPreparer() { + @Override + public void initBatchState() { + state.wrapXBatch.clear(); + state.batchStartPosHolder.init(0); + } + @Override + public void copyBatchEmbeddings(int[] tokenIds, int startPos, int chunkSize) { + state.batchStartPosHolder.set(0, startPos); + for (int b = 0; b < chunkSize; b++) { + MemorySegment.copy(embTable, (long) tokenIds[b] * dim * bytes, + state.embeddingXBatch.getSegment(), (long) b * dim * bytes, + (long) dim * bytes); + } + } + @Override + public void copyDecodeEmbedding(int token) { + MemorySegment.copy(embTable, (long) token * dim * bytes, + state.embeddingX.getSegment(), 0L, (long) dim * bytes); + } + }; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java new file mode 100644 index 00000000..5cc01685 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.fp16; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import 
org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.mistral.MistralConfiguration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.MistralFP16FFNLayers; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class MistralFP16PlanComponents implements SingleTokenForwardPlanComponents { + + private final LlamaState state; + private final LlamaTornadoWeights weights; + private final MistralConfiguration config; + private final SchedulerType schedulerType; + + public MistralFP16PlanComponents(LlamaState state, Model model) { + this.state = state; + this.config = (MistralConfiguration) model.configuration(); + this.weights = (LlamaTornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationGraph standardActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs standardLayers() { + return new MistralFP16FFNLayers("mistralFFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java new file 
mode 100644 index 00000000..b98eb47d --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.fp16; + +import org.beehive.gpullama3.inference.state.Phi3State; +import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.phi3.Phi3Configuration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.Phi3FP16FFNLayers; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class Phi3FP16PlanComponents implements SingleTokenForwardPlanComponents { + + private final Phi3State state; + private final Phi3TornadoWeights weights; + private final Phi3Configuration config; + private final SchedulerType schedulerType; + + public Phi3FP16PlanComponents(Phi3State state, Model model) { + this.state = state; + this.config = (Phi3Configuration) model.configuration(); + this.weights = (Phi3TornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationGraph standardActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs standardLayers() { + return new Phi3FP16FFNLayers("phi3FFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsLayer 
standardLogits(String previousGraphId) { + return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java new file mode 100644 index 00000000..2eaae7a2 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.fp16; + +import org.beehive.gpullama3.inference.state.Qwen2State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.Qwen2FP16FFNLayers; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class Qwen2FP16PlanComponents implements SingleTokenForwardPlanComponents { + + private final Qwen2State state; + private final Qwen2TornadoWeights weights; + private final Qwen2Configuration config; + private final SchedulerType schedulerType; + + public Qwen2FP16PlanComponents(Qwen2State state, Model model) { + this.state = state; + this.config = (Qwen2Configuration) model.configuration(); + this.weights = (Qwen2TornadoWeights) model.weights(); + this.schedulerType = 
SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationGraph standardActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs standardLayers() { + return new Qwen2FP16FFNLayers("qwen2FFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java new file mode 100644 index 00000000..2ed13ab1 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.fp16; + +import org.beehive.gpullama3.inference.state.Qwen3State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.Qwen3FP16FFNLayers; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class Qwen3FP16PlanComponents implements SingleTokenForwardPlanComponents { + + private final Qwen3State state; + 
private final Qwen3TornadoWeights weights; + private final Qwen3Configuration config; + private final SchedulerType schedulerType; + + public Qwen3FP16PlanComponents(Qwen3State state, Model model) { + this.state = state; + this.config = (Qwen3Configuration) model.configuration(); + this.weights = (Qwen3TornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationGraph standardActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs standardLayers() { + return new Qwen3FP16FFNLayers("qwen3FFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java new file mode 100644 index 00000000..8302f7cb --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.q8_0; + +import org.beehive.gpullama3.inference.state.DevstralState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.devstral.DevstralConfiguration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.DevstralQ8_0FFNLayers; +import 
org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class DevstralQ8_0PlanComponents implements SingleTokenForwardPlanComponents { + + private final DevstralState state; + private final LlamaTornadoWeights weights; + private final DevstralConfiguration config; + private final SchedulerType schedulerType; + + public DevstralQ8_0PlanComponents(DevstralState state, Model model) { + this.state = state; + this.config = (DevstralConfiguration) model.configuration(); + this.weights = (LlamaTornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationGraph standardActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs standardLayers() { + return new DevstralQ8_0FFNLayers("devstralFFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java new file mode 100644 index 00000000..93f01f10 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.q8_0; + +import org.beehive.gpullama3.inference.state.GraniteState; +import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights; +import org.beehive.gpullama3.model.Model; +import 
org.beehive.gpullama3.model.granite.GraniteConfiguration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.ActivationGranite; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.GraniteQ8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsGraniteQ8_0Layer; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class GraniteQ8_0PlanComponents implements SingleTokenForwardPlanComponents { + + private final GraniteState state; + private final GraniteTornadoWeights weights; + private final GraniteConfiguration config; + private final SchedulerType schedulerType; + + public GraniteQ8_0PlanComponents(GraniteState state, Model model) { + this.state = state; + this.config = (GraniteConfiguration) model.configuration(); + this.weights = (GraniteTornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationGraph standardActivation() { + return new ActivationGranite("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs standardLayers() { + return new GraniteQ8_0FFNLayers("graniteFFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + return new LogitsGraniteQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java new file mode 
100644 index 00000000..bd82a767 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java @@ -0,0 +1,131 @@ +package org.beehive.gpullama3.tornadovm.plan.components.q8_0; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LlamaQ8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LlamaQ8_0FFNLayersDecode; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LlamaQ8_0FFNLayersPrefillDecode; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LogitsQ8_0LayerDecode; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.prefill.LlamaQ8_0LayersBatchPrefill; +import org.beehive.gpullama3.tornadovm.plan.components.BatchPrefillDecodeForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.plan.components.EmbeddingPreparer; +import org.beehive.gpullama3.tornadovm.plan.components.activation.BatchDecodeActivation; +import org.beehive.gpullama3.tornadovm.plan.components.activation.BatchPrefillActivation; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; +import uk.ac.manchester.tornado.api.types.arrays.ByteArray; + +import java.lang.foreign.MemorySegment; + +/** + * {@link 
BatchPrefillDecodeForwardPlanComponents} for Llama + Q8_0. + * + *

      <p>Batch embedding prep: CPU dequantizes Q8_0 embeddings into {@code wrapXBatch} (FP32). + * Decode embedding prep: raw Q8_0 block copy into {@code embeddingX} for on-device conversion.</p>

      + */ +public class LlamaQ8_0PlanComponents implements BatchPrefillDecodeForwardPlanComponents { + + private static final int BLOCK_SIZE = 32; + private static final int Q8_0_BLOCK_BYTES = 34; + + private final LlamaState state; + private final LlamaTornadoWeights weights; + private final LlamaConfiguration config; + private final SchedulerType schedulerType; + + public LlamaQ8_0PlanComponents(LlamaState state, Model model) { + this.state = state; + this.config = (LlamaConfiguration) model.configuration(); + this.weights = (LlamaTornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + // ── Activations ─────────────────────────────────────────────────────────── + + @Override public ActivationGraph standardActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public ActivationGraph decodeActivation() { + return new Activation("decodeActivation", state, weights, config); + } + + @Override public ActivationGraph batchPrefillActivation(int batchSize) { + return new BatchPrefillActivation(state, config, batchSize, true); + } + + @Override public ActivationGraph batchDecodeActivation(String lastBatchLayerId) { + return new BatchDecodeActivation(state, config, lastBatchLayerId, true); + } + + // ── FFN layer groups ────────────────────────────────────────────────────── + + @Override public TransformerLayerTaskGraphs standardLayers() { + return new LlamaQ8_0FFNLayers("llamaFFN", state, weights, config, schedulerType); + } + + @Override public TransformerLayerTaskGraphs prefillDecodeLayers() { + return new LlamaQ8_0FFNLayersPrefillDecode("decode", state, weights, config, schedulerType); + } + + @Override public TransformerLayerTaskGraphs batchDecodeLayers() { + return new LlamaQ8_0FFNLayersDecode("decode", state, weights, config, schedulerType); + } + + @Override public BatchPrefillTransformerLayerTaskGraphs batchPrefillLayers(int batchSize) { + return new 
LlamaQ8_0LayersBatchPrefill(state, weights, config, batchSize); + } + + // ── Logits layers ───────────────────────────────────────────────────────── + + @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); + } + + @Override public AbstractLogitsLayer decodeLogits(String previousGraphId) { + return new LogitsQ8_0LayerDecode("logits", state, weights, config, previousGraphId, schedulerType); + } + + // ── Embedding preparation ───────────────────────────────────────────────── + + @Override public EmbeddingPreparer embeddingPreparer() { + ByteArray embTable = weights.getTokenEmbeddingTable().asByteArray(); + int dim = config.dim(); + int blocksPerRow = (dim + BLOCK_SIZE - 1) / BLOCK_SIZE; + long bytesPerToken = (long) blocksPerRow * Q8_0_BLOCK_BYTES; + + return new EmbeddingPreparer() { + @Override + public void initBatchState() { + state.wrapXBatch.clear(); + state.batchStartPosHolder.init(0); + } + @Override + public void copyBatchEmbeddings(int[] tokenIds, int startPos, int chunkSize) { + state.batchStartPosHolder.set(0, startPos); + for (int b = 0; b < chunkSize; b++) { + int tokenId = tokenIds[b]; + for (int j = 0; j < dim; j++) { + int blockByteOffset = (tokenId * blocksPerRow + j / BLOCK_SIZE) * Q8_0_BLOCK_BYTES; + float scale = embTable.getHalfFloat(blockByteOffset).getFloat32(); + float quant = embTable.get(blockByteOffset + 2 + j % BLOCK_SIZE); + state.wrapXBatch.set(b * dim + j, quant * scale); + } + } + } + @Override + public void copyDecodeEmbedding(int token) { + MemorySegment.copy(embTable.getSegment(), token * bytesPerToken, + state.embeddingX.getSegment(), 0L, bytesPerToken); + } + }; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java new file mode 100644 index 
00000000..4b8a78ec --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.q8_0; + +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.mistral.MistralConfiguration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.MistralQ8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class MistralQ8_0PlanComponents implements SingleTokenForwardPlanComponents { + + private final LlamaState state; + private final LlamaTornadoWeights weights; + private final MistralConfiguration config; + private final SchedulerType schedulerType; + + public MistralQ8_0PlanComponents(LlamaState state, Model model) { + this.state = state; + this.config = (MistralConfiguration) model.configuration(); + this.weights = (LlamaTornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationGraph standardActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs standardLayers() { + return new MistralQ8_0FFNLayers("mistralFFN", state, weights, config, schedulerType); + } + + @Override public 
AbstractLogitsLayer standardLogits(String previousGraphId) { + return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java new file mode 100644 index 00000000..57cd9508 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.q8_0; + +import org.beehive.gpullama3.inference.state.Phi3State; +import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.phi3.Phi3Configuration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Phi3Q8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class Phi3Q8_0PlanComponents implements SingleTokenForwardPlanComponents { + + private final Phi3State state; + private final Phi3TornadoWeights weights; + private final Phi3Configuration config; + private final SchedulerType schedulerType; + + public Phi3Q8_0PlanComponents(Phi3State state, Model model) { + this.state = state; + this.config = (Phi3Configuration) model.configuration(); + this.weights = (Phi3TornadoWeights) model.weights(); + this.schedulerType = 
SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationGraph standardActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs standardLayers() { + return new Phi3Q8_0FFNLayers("phi3FFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java new file mode 100644 index 00000000..534eedc5 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.q8_0; + +import org.beehive.gpullama3.inference.state.Qwen2State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen2Q8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class Qwen2Q8_0PlanComponents implements SingleTokenForwardPlanComponents { + + private final Qwen2State state; + 
private final Qwen2TornadoWeights weights; + private final Qwen2Configuration config; + private final SchedulerType schedulerType; + + public Qwen2Q8_0PlanComponents(Qwen2State state, Model model) { + this.state = state; + this.config = (Qwen2Configuration) model.configuration(); + this.weights = (Qwen2TornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationGraph standardActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs standardLayers() { + return new Qwen2Q8_0FFNLayers("qwen2FFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java new file mode 100644 index 00000000..3c615b39 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java @@ -0,0 +1,42 @@ +package org.beehive.gpullama3.tornadovm.plan.components.q8_0; + +import org.beehive.gpullama3.inference.state.Qwen3State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; +import 
org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen3Q8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; + +public class Qwen3Q8_0PlanComponents implements SingleTokenForwardPlanComponents { + + private final Qwen3State state; + private final Qwen3TornadoWeights weights; + private final Qwen3Configuration config; + private final SchedulerType schedulerType; + + public Qwen3Q8_0PlanComponents(Qwen3State state, Model model) { + this.state = state; + this.config = (Qwen3Configuration) model.configuration(); + this.weights = (Qwen3TornadoWeights) model.weights(); + this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); + } + + @Override public ActivationGraph standardActivation() { + return new Activation("activationUpdate", state, weights, config); + } + + @Override public TransformerLayerTaskGraphs standardLayers() { + return new Qwen3Q8_0FFNLayers("qwen3FFN", state, weights, config, schedulerType); + } + + @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/BatchPrefillDecodeForwardTaskGraphLayout.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/BatchPrefillDecodeForwardTaskGraphLayout.java new file mode 100644 index 00000000..aec212bd --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/BatchPrefillDecodeForwardTaskGraphLayout.java @@ -0,0 +1,21 @@ +package org.beehive.gpullama3.tornadovm.plan.layout; + +/** + * Graph-index arithmetic for the 2N+3 batch-prefill/decode forward plan. + * + *
      + *   [0]         batchPrefillActivation
      + *   [1..N]      batchPrefillLayer_0 .. batchPrefillLayer_{N-1}
      + *   [N+1]       decodeActivation    (consumes + re-persists KV cache)
      + *   [N+2..2N+1] decodeLayer_0 .. decodeLayer_{N-1}
      + *   [2N+2]      logits
      + * 
      + */ +public record BatchPrefillDecodeForwardTaskGraphLayout(int N) { + public int batchActivationIdx() { return 0; } + public int batchLayerIdx(int i) { return 1 + i; } + public int decodeActivationIdx() { return N + 1; } + public int decodeLayerIdx(int i) { return N + 2 + i; } + public int logitsIdx() { return 2 * N + 2; } + public int totalGraphs() { return 2 * N + 3; } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/PrefillDecodeForwardTaskGraphLayout.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/PrefillDecodeForwardTaskGraphLayout.java new file mode 100644 index 00000000..c252c92d --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/PrefillDecodeForwardTaskGraphLayout.java @@ -0,0 +1,17 @@ +package org.beehive.gpullama3.tornadovm.plan.layout; + +/** + * Graph-index arithmetic for the N+2 prefill/decode forward plan. + * + *
      + *   [0]      decodeActivation
      + *   [1..N]   layer_0 .. layer_{N-1}
      + *   [N+1]    logits
      + * 
      + */ +public record PrefillDecodeForwardTaskGraphLayout(int N) { + public int activationIdx() { return 0; } + public int layerIdx(int i) { return 1 + i; } + public int logitsIdx() { return N + 1; } + public int totalGraphs() { return N + 2; } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/SingleTokenForwardTaskGraphLayout.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/SingleTokenForwardTaskGraphLayout.java new file mode 100644 index 00000000..d8fe8d13 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/SingleTokenForwardTaskGraphLayout.java @@ -0,0 +1,17 @@ +package org.beehive.gpullama3.tornadovm.plan.layout; + +/** + * Graph-index arithmetic for the N+2 single-token forward plan. + * + *
      + *   [0]      activation
      + *   [1..N]   layer_0 .. layer_{N-1}
      + *   [N+1]    logits
      + * 
      + */ +public record SingleTokenForwardTaskGraphLayout(int N) { + public int activationIdx() { return 0; } + public int layerIdx(int i) { return 1 + i; } + public int logitsIdx() { return N + 1; } + public int totalGraphs() { return N + 2; } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/strategy/SchedulerDetectionService.java b/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/SchedulerDetectionService.java similarity index 93% rename from src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/strategy/SchedulerDetectionService.java rename to src/main/java/org/beehive/gpullama3/tornadovm/scheduling/SchedulerDetectionService.java index 5a81caa8..5e29c528 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/strategy/SchedulerDetectionService.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/SchedulerDetectionService.java @@ -1,4 +1,4 @@ -package org.beehive.gpullama3.tornadovm.layerplanner.strategy; +package org.beehive.gpullama3.tornadovm.scheduling; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.ModelType; @@ -9,7 +9,6 @@ public class SchedulerDetectionService { - public static SchedulerType determineSchedulerType(Model model) { TornadoRuntime tornadoRuntime = TornadoRuntimeProvider.getTornadoRuntime(); String platformName = tornadoRuntime.getBackend(0) @@ -24,4 +23,4 @@ public static SchedulerType determineSchedulerType(Model model) { return (isNvidia && isNotMistral) ? 
SchedulerType.NVIDIA : SchedulerType.NON_NVIDIA; } -} \ No newline at end of file +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/SchedulerType.java b/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/SchedulerType.java new file mode 100644 index 00000000..bbd27169 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/SchedulerType.java @@ -0,0 +1,5 @@ +package org.beehive.gpullama3.tornadovm.scheduling; + +public enum SchedulerType { + NVIDIA, NON_NVIDIA +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/WorkerGridFactory.java b/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/WorkerGridFactory.java similarity index 98% rename from src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/WorkerGridFactory.java rename to src/main/java/org/beehive/gpullama3/tornadovm/scheduling/WorkerGridFactory.java index af39c133..4d5a55a5 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/WorkerGridFactory.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/WorkerGridFactory.java @@ -1,4 +1,4 @@ -package org.beehive.gpullama3.tornadovm.layerplanner; +package org.beehive.gpullama3.tornadovm.scheduling; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.WorkerGrid1D; @@ -98,5 +98,4 @@ private static int findOptimalLocalSize(int size) { } return optimal; } - } From 4e4478af24a77b80111618f7e76050f65d44b33f Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Thu, 28 May 2026 15:19:25 +0300 Subject: [PATCH 3/6] Update naming from `ActivationGraph` to `ActivationTaskGraph` across TornadoVM components --- .../beehive/gpullama3/tornadovm/layers/Activation.java | 2 +- .../{ActivationGraph.java => ActivationTaskGraph.java} | 4 ++-- .../tornadovm/plan/BatchPrefillDecodeForwardPlan.java | 6 +++--- .../tornadovm/plan/PrefillDecodeForwardPlan.java | 4 ++-- .../tornadovm/plan/SingleTokenForwardPlan.java | 4 ++-- 
.../BatchPrefillDecodeForwardPlanComponents.java | 6 +++--- .../components/PrefillDecodeForwardPlanComponents.java | 4 ++-- .../components/SingleTokenForwardPlanComponents.java | 4 ++-- .../components/activation/BatchDecodeActivation.java | 4 ++-- .../components/activation/BatchPrefillActivation.java | 4 ++-- .../components/fp16/DevstralFP16PlanComponents.java | 4 ++-- .../components/fp16/GraniteFP16PlanComponents.java | 4 ++-- .../plan/components/fp16/LlamaFP16PlanComponents.java | 10 +++++----- .../components/fp16/MistralFP16PlanComponents.java | 4 ++-- .../plan/components/fp16/Phi3FP16PlanComponents.java | 4 ++-- .../plan/components/fp16/Qwen2FP16PlanComponents.java | 4 ++-- .../plan/components/fp16/Qwen3FP16PlanComponents.java | 4 ++-- .../components/q8_0/DevstralQ8_0PlanComponents.java | 4 ++-- .../components/q8_0/GraniteQ8_0PlanComponents.java | 4 ++-- .../plan/components/q8_0/LlamaQ8_0PlanComponents.java | 10 +++++----- .../components/q8_0/MistralQ8_0PlanComponents.java | 4 ++-- .../plan/components/q8_0/Phi3Q8_0PlanComponents.java | 4 ++-- .../plan/components/q8_0/Qwen2Q8_0PlanComponents.java | 4 ++-- .../plan/components/q8_0/Qwen3Q8_0PlanComponents.java | 4 ++-- 24 files changed, 55 insertions(+), 55 deletions(-) rename src/main/java/org/beehive/gpullama3/tornadovm/layers/{ActivationGraph.java => ActivationTaskGraph.java} (76%) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java index d35f8252..b25a613a 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java @@ -13,7 +13,7 @@ import uk.ac.manchester.tornado.api.types.arrays.ByteArray; import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; -public class Activation extends AbstractLayer implements ActivationGraph { +public class Activation extends AbstractLayer implements ActivationTaskGraph { private 
final TaskGraph activationTaskGraph; public Activation(String name, State state, Weights weights, Configuration config) { diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGraph.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationTaskGraph.java similarity index 76% rename from src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGraph.java rename to src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationTaskGraph.java index c46479bc..8a751c18 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGraph.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationTaskGraph.java @@ -4,12 +4,12 @@ import uk.ac.manchester.tornado.api.ImmutableTaskGraph; /** - * Common interface for single activation graphs (standard, decode, batch-prefill variants). + * Common interface for a single activation task graph (standard, decode, and batch-prefill variants). * *

      Implemented by {@link Activation} and custom activation wrappers used by * {@link org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents}.

      */ -public interface ActivationGraph { +public interface ActivationTaskGraph { ImmutableTaskGraph getImmutableTaskGraph(); GridScheduler updateGridScheduler(GridScheduler scheduler); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java index 1246265c..55202738 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java @@ -2,7 +2,7 @@ import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; -import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.plan.components.BatchPrefillDecodeForwardPlanComponents; @@ -42,7 +42,7 @@ public BatchPrefillDecodeForwardPlan(Model model, BatchPrefillDecodeForwardPlanC List all = new ArrayList<>(2 * N + 3); GridScheduler scheduler = new GridScheduler(); - ActivationGraph batchAct = components.batchPrefillActivation(batchSize); + ActivationTaskGraph batchAct = components.batchPrefillActivation(batchSize); all.add(batchAct.getImmutableTaskGraph()); batchAct.updateGridScheduler(scheduler); @@ -50,7 +50,7 @@ public BatchPrefillDecodeForwardPlan(Model model, BatchPrefillDecodeForwardPlanC all.addAll(batchLayers.getLayerImmutableTaskGraphs()); batchLayers.updateGridScheduler(scheduler); - ActivationGraph decodeAct = components.batchDecodeActivation(batchLayers.getLastLayerTaskGraphID()); + ActivationTaskGraph decodeAct = components.batchDecodeActivation(batchLayers.getLastLayerTaskGraphID()); all.add(decodeAct.getImmutableTaskGraph()); 
decodeAct.updateGridScheduler(scheduler); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java index 241f542b..4c91b7fe 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java @@ -2,7 +2,7 @@ import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; -import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.plan.components.PrefillDecodeForwardPlanComponents; import org.beehive.gpullama3.tornadovm.plan.layout.PrefillDecodeForwardTaskGraphLayout; @@ -36,7 +36,7 @@ public PrefillDecodeForwardPlan(Model model, PrefillDecodeForwardPlanComponents List all = new ArrayList<>(N + 2); GridScheduler scheduler = new GridScheduler(); - ActivationGraph act = components.decodeActivation(); + ActivationTaskGraph act = components.decodeActivation(); all.add(act.getImmutableTaskGraph()); act.updateGridScheduler(scheduler); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/SingleTokenForwardPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/SingleTokenForwardPlan.java index bdc20a97..4db7ad86 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/SingleTokenForwardPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/SingleTokenForwardPlan.java @@ -2,7 +2,7 @@ import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; -import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import 
org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; import org.beehive.gpullama3.tornadovm.plan.layout.SingleTokenForwardTaskGraphLayout; @@ -33,7 +33,7 @@ public SingleTokenForwardPlan(Model model, SingleTokenForwardPlanComponents comp List all = new ArrayList<>(N + 2); GridScheduler scheduler = new GridScheduler(); - ActivationGraph act = components.standardActivation(); + ActivationTaskGraph act = components.standardActivation(); all.add(act.getImmutableTaskGraph()); act.updateGridScheduler(scheduler); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/BatchPrefillDecodeForwardPlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/BatchPrefillDecodeForwardPlanComponents.java index 10595b27..38b03d30 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/BatchPrefillDecodeForwardPlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/BatchPrefillDecodeForwardPlanComponents.java @@ -1,6 +1,6 @@ package org.beehive.gpullama3.tornadovm.plan.components; -import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; @@ -13,9 +13,9 @@ */ public interface BatchPrefillDecodeForwardPlanComponents extends PrefillDecodeForwardPlanComponents { - ActivationGraph batchPrefillActivation(int batchSize); + ActivationTaskGraph batchPrefillActivation(int batchSize); - ActivationGraph batchDecodeActivation(String lastBatchLayerId); + ActivationTaskGraph batchDecodeActivation(String lastBatchLayerId); TransformerLayerTaskGraphs batchDecodeLayers(); diff --git 
a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java index 8f0b6cdd..89f53051 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java @@ -1,7 +1,7 @@ package org.beehive.gpullama3.tornadovm.plan.components; import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; -import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; /** @@ -12,7 +12,7 @@ */ public interface PrefillDecodeForwardPlanComponents extends SingleTokenForwardPlanComponents { - ActivationGraph decodeActivation(); + ActivationTaskGraph decodeActivation(); TransformerLayerTaskGraphs prefillDecodeLayers(); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java index 6ecc2077..37ef3dc7 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java @@ -1,7 +1,7 @@ package org.beehive.gpullama3.tornadovm.plan.components; import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; -import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; /** @@ -13,7 +13,7 @@ */ public interface SingleTokenForwardPlanComponents { - ActivationGraph standardActivation(); + ActivationTaskGraph 
standardActivation(); TransformerLayerTaskGraphs standardLayers(); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java index d3617a6c..be6f8d02 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java @@ -3,7 +3,7 @@ import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.model.llama.LlamaConfiguration; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; -import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; @@ -21,7 +21,7 @@ * converts the single-token embedding to FP32, then re-persists the KV cache so * that decode layer 0 can consume it.

      */ -public class BatchDecodeActivation implements ActivationGraph { +public class BatchDecodeActivation implements ActivationTaskGraph { private final ImmutableTaskGraph itg; private final int dim; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java index 73eec954..17b80cfa 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java @@ -4,7 +4,7 @@ import org.beehive.gpullama3.model.llama.LlamaConfiguration; import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; -import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; @@ -22,7 +22,7 @@ * and runs a no-op {@code batchPassthrough} so TornadoVM has at least one task. 
* */ -public class BatchPrefillActivation implements ActivationGraph { +public class BatchPrefillActivation implements ActivationTaskGraph { private final ImmutableTaskGraph itg; private final boolean isQ8; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java index 66f3b9ad..862b8a92 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java @@ -6,7 +6,7 @@ import org.beehive.gpullama3.model.devstral.DevstralConfiguration; import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; import org.beehive.gpullama3.tornadovm.layers.Activation; -import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.layers.type.fp16.DevstralFP16FFNLayers; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; @@ -28,7 +28,7 @@ public DevstralFP16PlanComponents(DevstralState state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } - @Override public ActivationGraph standardActivation() { + @Override public ActivationTaskGraph standardActivation() { return new Activation("activationUpdate", state, weights, config); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java index 5c8a6cfd..13f0b0f9 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java +++ 
b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java @@ -6,7 +6,7 @@ import org.beehive.gpullama3.model.granite.GraniteConfiguration; import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; import org.beehive.gpullama3.tornadovm.layers.ActivationGranite; -import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.layers.type.fp16.GraniteFP16FFNLayers; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsGraniteFP16Layer; @@ -28,7 +28,7 @@ public GraniteFP16PlanComponents(GraniteState state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } - @Override public ActivationGraph standardActivation() { + @Override public ActivationTaskGraph standardActivation() { return new ActivationGranite("activationUpdate", state, weights, config); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java index 3fc093c3..23b63c8e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java @@ -6,7 +6,7 @@ import org.beehive.gpullama3.model.llama.LlamaConfiguration; import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; import org.beehive.gpullama3.tornadovm.layers.Activation; -import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; import 
org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers; @@ -45,19 +45,19 @@ public LlamaFP16PlanComponents(LlamaState state, Model model) { // ── Activations ─────────────────────────────────────────────────────────── - @Override public ActivationGraph standardActivation() { + @Override public ActivationTaskGraph standardActivation() { return new Activation("activationUpdate", state, weights, config); } - @Override public ActivationGraph decodeActivation() { + @Override public ActivationTaskGraph decodeActivation() { return new Activation("decodeActivation", state, weights, config); } - @Override public ActivationGraph batchPrefillActivation(int batchSize) { + @Override public ActivationTaskGraph batchPrefillActivation(int batchSize) { return new BatchPrefillActivation(state, config, batchSize, false); } - @Override public ActivationGraph batchDecodeActivation(String lastBatchLayerId) { + @Override public ActivationTaskGraph batchDecodeActivation(String lastBatchLayerId) { return new BatchDecodeActivation(state, config, lastBatchLayerId, false); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java index 5cc01685..542b9827 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java @@ -6,7 +6,7 @@ import org.beehive.gpullama3.model.mistral.MistralConfiguration; import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; import org.beehive.gpullama3.tornadovm.layers.Activation; -import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; import 
org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; import org.beehive.gpullama3.tornadovm.layers.type.fp16.MistralFP16FFNLayers; @@ -28,7 +28,7 @@ public MistralFP16PlanComponents(LlamaState state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } - @Override public ActivationGraph standardActivation() { + @Override public ActivationTaskGraph standardActivation() { return new Activation("activationUpdate", state, weights, config); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java index b98eb47d..4a619211 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java @@ -6,7 +6,7 @@ import org.beehive.gpullama3.model.phi3.Phi3Configuration; import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; import org.beehive.gpullama3.tornadovm.layers.Activation; -import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; import org.beehive.gpullama3.tornadovm.layers.type.fp16.Phi3FP16FFNLayers; @@ -28,7 +28,7 @@ public Phi3FP16PlanComponents(Phi3State state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } - @Override public ActivationGraph standardActivation() { + @Override public ActivationTaskGraph standardActivation() { return new Activation("activationUpdate", state, weights, config); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java 
b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java index 2eaae7a2..0a359579 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java @@ -6,7 +6,7 @@ import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; import org.beehive.gpullama3.tornadovm.layers.Activation; -import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; import org.beehive.gpullama3.tornadovm.layers.type.fp16.Qwen2FP16FFNLayers; @@ -28,7 +28,7 @@ public Qwen2FP16PlanComponents(Qwen2State state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } - @Override public ActivationGraph standardActivation() { + @Override public ActivationTaskGraph standardActivation() { return new Activation("activationUpdate", state, weights, config); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java index 2ed13ab1..d18c67e6 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java @@ -6,7 +6,7 @@ import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; import org.beehive.gpullama3.tornadovm.layers.Activation; -import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import 
org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; import org.beehive.gpullama3.tornadovm.layers.type.fp16.Qwen3FP16FFNLayers; @@ -28,7 +28,7 @@ public Qwen3FP16PlanComponents(Qwen3State state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } - @Override public ActivationGraph standardActivation() { + @Override public ActivationTaskGraph standardActivation() { return new Activation("activationUpdate", state, weights, config); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java index 8302f7cb..2019beae 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java @@ -6,7 +6,7 @@ import org.beehive.gpullama3.model.devstral.DevstralConfiguration; import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; import org.beehive.gpullama3.tornadovm.layers.Activation; -import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.DevstralQ8_0FFNLayers; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; @@ -28,7 +28,7 @@ public DevstralQ8_0PlanComponents(DevstralState state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } - @Override public ActivationGraph standardActivation() { + @Override public ActivationTaskGraph standardActivation() { return new Activation("activationUpdate", state, 
weights, config); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java index 93f01f10..30c6c28b 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java @@ -6,7 +6,7 @@ import org.beehive.gpullama3.model.granite.GraniteConfiguration; import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; import org.beehive.gpullama3.tornadovm.layers.ActivationGranite; -import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.GraniteQ8_0FFNLayers; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsGraniteQ8_0Layer; @@ -28,7 +28,7 @@ public GraniteQ8_0PlanComponents(GraniteState state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } - @Override public ActivationGraph standardActivation() { + @Override public ActivationTaskGraph standardActivation() { return new ActivationGranite("activationUpdate", state, weights, config); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java index bd82a767..aa16bd52 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java @@ -6,7 +6,7 @@ import org.beehive.gpullama3.model.llama.LlamaConfiguration; import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; 
import org.beehive.gpullama3.tornadovm.layers.Activation; -import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LlamaQ8_0FFNLayers; @@ -50,19 +50,19 @@ public LlamaQ8_0PlanComponents(LlamaState state, Model model) { // ── Activations ─────────────────────────────────────────────────────────── - @Override public ActivationGraph standardActivation() { + @Override public ActivationTaskGraph standardActivation() { return new Activation("activationUpdate", state, weights, config); } - @Override public ActivationGraph decodeActivation() { + @Override public ActivationTaskGraph decodeActivation() { return new Activation("decodeActivation", state, weights, config); } - @Override public ActivationGraph batchPrefillActivation(int batchSize) { + @Override public ActivationTaskGraph batchPrefillActivation(int batchSize) { return new BatchPrefillActivation(state, config, batchSize, true); } - @Override public ActivationGraph batchDecodeActivation(String lastBatchLayerId) { + @Override public ActivationTaskGraph batchDecodeActivation(String lastBatchLayerId) { return new BatchDecodeActivation(state, config, lastBatchLayerId, true); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java index 4b8a78ec..b377a5f1 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java @@ -6,7 +6,7 @@ import org.beehive.gpullama3.model.mistral.MistralConfiguration; import 
org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; import org.beehive.gpullama3.tornadovm.layers.Activation; -import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.MistralQ8_0FFNLayers; @@ -28,7 +28,7 @@ public MistralQ8_0PlanComponents(LlamaState state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } - @Override public ActivationGraph standardActivation() { + @Override public ActivationTaskGraph standardActivation() { return new Activation("activationUpdate", state, weights, config); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java index 57cd9508..9c9735ac 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java @@ -6,7 +6,7 @@ import org.beehive.gpullama3.model.phi3.Phi3Configuration; import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; import org.beehive.gpullama3.tornadovm.layers.Activation; -import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Phi3Q8_0FFNLayers; @@ -28,7 +28,7 @@ public Phi3Q8_0PlanComponents(Phi3State state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } - @Override public 
ActivationGraph standardActivation() { + @Override public ActivationTaskGraph standardActivation() { return new Activation("activationUpdate", state, weights, config); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java index 534eedc5..d4a22184 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java @@ -6,7 +6,7 @@ import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; import org.beehive.gpullama3.tornadovm.layers.Activation; -import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen2Q8_0FFNLayers; @@ -28,7 +28,7 @@ public Qwen2Q8_0PlanComponents(Qwen2State state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } - @Override public ActivationGraph standardActivation() { + @Override public ActivationTaskGraph standardActivation() { return new Activation("activationUpdate", state, weights, config); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java index 3c615b39..a033849e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java @@ -6,7 +6,7 @@ import 
org.beehive.gpullama3.model.qwen3.Qwen3Configuration; import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; import org.beehive.gpullama3.tornadovm.layers.Activation; -import org.beehive.gpullama3.tornadovm.layers.ActivationGraph; +import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen3Q8_0FFNLayers; @@ -28,7 +28,7 @@ public Qwen3Q8_0PlanComponents(Qwen3State state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } - @Override public ActivationGraph standardActivation() { + @Override public ActivationTaskGraph standardActivation() { return new Activation("activationUpdate", state, weights, config); } From ea478f871985e5e150aa0a060480d08e2221cf5f Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Thu, 28 May 2026 15:25:40 +0300 Subject: [PATCH 4/6] Rename `AbstractFFNLayers` to `AbstractTransformerLayerTaskGraphs` and `AbstractLogitsLayer` to `AbstractLogitsTaskGraph`, updating all references to improve clarity and align with naming conventions. 
--- ...ayer.java => AbstractLogitsTaskGraph.java} | 6 ++--- ...> AbstractTransformerLayerTaskGraphs.java} | 24 +++++++++---------- .../layers/TransformerLayerTaskGraphs.java | 2 +- .../type/fp16/DevstralFP16FFNLayers.java | 4 ++-- .../type/fp16/GraniteFP16FFNLayers.java | 4 ++-- .../layers/type/fp16/LlamaFP16FFNLayers.java | 4 ++-- .../layers/type/fp16/LogitsFP16Layer.java | 4 ++-- .../type/fp16/MistralFP16FFNLayers.java | 4 ++-- .../layers/type/fp16/Phi3FP16FFNLayers.java | 4 ++-- .../layers/type/fp16/Qwen2FP16FFNLayers.java | 4 ++-- .../layers/type/fp16/Qwen3FP16FFNLayers.java | 4 ++-- .../type/q8_0/DevstralQ8_0FFNLayers.java | 4 ++-- .../type/q8_0/GraniteQ8_0FFNLayers.java | 4 ++-- .../layers/type/q8_0/LlamaQ8_0FFNLayers.java | 4 ++-- .../layers/type/q8_0/LogitsQ8_0Layer.java | 4 ++-- .../type/q8_0/MistralQ8_0FFNLayers.java | 4 ++-- .../layers/type/q8_0/Phi3Q8_0FFNLayers.java | 4 ++-- .../layers/type/q8_0/Qwen2Q8_0FFNLayers.java | 4 ++-- .../layers/type/q8_0/Qwen3Q8_0FFNLayers.java | 4 ++-- .../plan/BatchPrefillDecodeForwardPlan.java | 4 ++-- .../plan/PrefillDecodeForwardPlan.java | 4 ++-- .../plan/SingleTokenForwardPlan.java | 4 ++-- .../PrefillDecodeForwardPlanComponents.java | 4 ++-- .../SingleTokenForwardPlanComponents.java | 4 ++-- .../fp16/DevstralFP16PlanComponents.java | 4 ++-- .../fp16/GraniteFP16PlanComponents.java | 4 ++-- .../fp16/LlamaFP16PlanComponents.java | 6 ++--- .../fp16/MistralFP16PlanComponents.java | 4 ++-- .../fp16/Phi3FP16PlanComponents.java | 4 ++-- .../fp16/Qwen2FP16PlanComponents.java | 4 ++-- .../fp16/Qwen3FP16PlanComponents.java | 4 ++-- .../q8_0/DevstralQ8_0PlanComponents.java | 4 ++-- .../q8_0/GraniteQ8_0PlanComponents.java | 4 ++-- .../q8_0/LlamaQ8_0PlanComponents.java | 6 ++--- .../q8_0/MistralQ8_0PlanComponents.java | 4 ++-- .../q8_0/Phi3Q8_0PlanComponents.java | 4 ++-- .../q8_0/Qwen2Q8_0PlanComponents.java | 4 ++-- .../q8_0/Qwen3Q8_0PlanComponents.java | 4 ++-- 38 files changed, 88 insertions(+), 88 deletions(-) rename 
src/main/java/org/beehive/gpullama3/tornadovm/layers/{AbstractLogitsLayer.java => AbstractLogitsTaskGraph.java} (86%) rename src/main/java/org/beehive/gpullama3/tornadovm/layers/{AbstractFFNLayers.java => AbstractTransformerLayerTaskGraphs.java} (70%) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsLayer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsTaskGraph.java similarity index 86% rename from src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsLayer.java rename to src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsTaskGraph.java index afead80d..026c3d90 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsLayer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsTaskGraph.java @@ -9,20 +9,20 @@ import uk.ac.manchester.tornado.api.TaskGraph; /** - * Abstract base for all logits layers (final vocabulary projection step). + * Abstract base for all logits task graphs (final vocabulary projection step). * * Holds the shared fields and calls the protected buildLogitsTaskGraph() hook once * during construction. Subclasses implement buildLogitsTaskGraph() to define the * quantization-specific task sequence; Granite variants override it to swap in * their scaled kernel. 
*/ -public abstract class AbstractLogitsLayer extends AbstractLayer { +public abstract class AbstractLogitsTaskGraph extends AbstractLayer { protected final String lastTaskGraphID; protected final SchedulerType schedulerType; private final TaskGraph logitsTaskGraph; - protected AbstractLogitsLayer(String name, State state, Weights weights, Configuration config, + protected AbstractLogitsTaskGraph(String name, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) { super(name, state, weights, config); this.lastTaskGraphID = lastTaskGraphID; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractTransformerLayerTaskGraphs.java similarity index 70% rename from src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java rename to src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractTransformerLayerTaskGraphs.java index d0b3656e..14a71e88 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractTransformerLayerTaskGraphs.java @@ -11,14 +11,14 @@ import java.util.stream.IntStream; /** - * Abstract base class for all FFN (Feed-Forward Network) layer implementations. + * Abstract base class for all transformer-layer task graph implementations. * Extended by model and quantization-specific subclasses that provide specific implementations. */ -public abstract class AbstractFFNLayers extends AbstractLayer implements TransformerLayerTaskGraphs { +public abstract class AbstractTransformerLayerTaskGraphs extends AbstractLayer implements TransformerLayerTaskGraphs { /** - * List of TornadoVM {@link ImmutableTaskGraph}s, one per FFN layer. - * Build by {@link #setupFFNLayers()}. + * List of TornadoVM {@link ImmutableTaskGraph}s, one per transformer layer. + * Built by {@link #setupFFNLayers()}. 
*/ private List ffnLayerITGs; protected final W weights; @@ -27,7 +27,7 @@ public abstract class AbstractFFNLayers getFFNLayerImmutableTaskGraphs() { } /** - * Returns the TaskGraph ID of the last FFN layer. - * Used by the logits layer to chain its consumeFromDevice call. + * Returns the task graph ID of the last transformer layer. + * Used by the logits task graph to chain its consumeFromDevice call. */ public String getLastFFNLayerTaskGraphID() { return lastFFNLayerTaskGraphID; @@ -91,4 +91,4 @@ public String getLastFFNLayerTaskGraphID() { protected boolean shouldUseFinalNormalization() { return schedulerType == SchedulerType.NON_NVIDIA; } -} \ No newline at end of file +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/TransformerLayerTaskGraphs.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/TransformerLayerTaskGraphs.java index 245b3f40..e815574d 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/TransformerLayerTaskGraphs.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/TransformerLayerTaskGraphs.java @@ -8,7 +8,7 @@ /** * Interface for a group of N transformer-layer {@link uk.ac.manchester.tornado.api.TaskGraph} (standard or prefill-decode variants). * - *

<p>Implemented by {@link AbstractFFNLayers} and its subclasses.</p>

      + *

<p>Implemented by {@link AbstractTransformerLayerTaskGraphs} and its subclasses.</p>

      */ public interface TransformerLayerTaskGraphs { List getFFNLayerImmutableTaskGraphs(); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java index e18c2a3a..0597a1c7 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java @@ -7,7 +7,7 @@ import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; @@ -17,7 +17,7 @@ * FP16 FFN layers for Devstral 2 models. * Uses precomputed RoPE frequencies (YaRN scaling) instead of on-the-fly computation. 
*/ -public class DevstralFP16FFNLayers extends AbstractFFNLayers { +public class DevstralFP16FFNLayers extends AbstractTransformerLayerTaskGraphs { public DevstralFP16FFNLayers(String taskGraph, State state, LlamaTornadoWeights weights, DevstralConfiguration config, SchedulerType schedulerType) { super(taskGraph, state, weights, config, schedulerType); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java index 46aad290..2d7234b8 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java @@ -9,13 +9,13 @@ import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -public class GraniteFP16FFNLayers extends AbstractFFNLayers { +public class GraniteFP16FFNLayers extends AbstractTransformerLayerTaskGraphs { public GraniteFP16FFNLayers(String taskGraph, State state, GraniteTornadoWeights weights, GraniteConfiguration config, SchedulerType schedulerType) { super(taskGraph, state, weights, config, schedulerType); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java index 928954e9..21bae5a6 100644 --- 
a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java @@ -7,13 +7,13 @@ import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -public class LlamaFP16FFNLayers extends AbstractFFNLayers { +public class LlamaFP16FFNLayers extends AbstractTransformerLayerTaskGraphs { public LlamaFP16FFNLayers(String taskGraph, State state, LlamaTornadoWeights weights, LlamaConfiguration config, SchedulerType schedulerType) { super(taskGraph, state, weights, config, schedulerType); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java index e3974ade..1295db09 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java @@ -9,13 +9,13 @@ import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; import 
uk.ac.manchester.tornado.api.WorkerGrid1D; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -public class LogitsFP16Layer extends AbstractLogitsLayer { +public class LogitsFP16Layer extends AbstractLogitsTaskGraph { public LogitsFP16Layer(String name, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) { diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java index aa26bb69..cd03b2ba 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java @@ -7,13 +7,13 @@ import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -public class MistralFP16FFNLayers extends AbstractFFNLayers { +public class MistralFP16FFNLayers extends AbstractTransformerLayerTaskGraphs { public MistralFP16FFNLayers(String taskGraph, State state, LlamaTornadoWeights weights, MistralConfiguration config, SchedulerType schedulerType) { super(taskGraph, state, weights, config, schedulerType); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java index 1d443ce3..e8ca1e54 100644 --- 
a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java @@ -7,7 +7,7 @@ import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; @@ -22,7 +22,7 @@ * * Works directly with Phi3State to access and mutate Phi3-specific state fields. */ -public class Phi3FP16FFNLayers extends AbstractFFNLayers { +public class Phi3FP16FFNLayers extends AbstractTransformerLayerTaskGraphs { // Typed references to Phi3-specific state and config private final Phi3State phi3State; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java index 16e35a39..747b7213 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java @@ -8,7 +8,7 @@ import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; @@ -25,7 +25,7 @@ * * Works 
directly with Qwen2State to access and mutate Qwen2-specific state fields. */ -public class Qwen2FP16FFNLayers extends AbstractFFNLayers { +public class Qwen2FP16FFNLayers extends AbstractTransformerLayerTaskGraphs { // Typed reference to Qwen2-specific state private final Qwen2State qwen2State; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java index 0530e7c0..3b298249 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java @@ -7,7 +7,7 @@ import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; @@ -21,7 +21,7 @@ * * Works directly with Qwen3State to access and mutate Qwen3-specific state fields like tempQcur and tempKcur. 
*/ -public class Qwen3FP16FFNLayers extends AbstractFFNLayers { +public class Qwen3FP16FFNLayers extends AbstractTransformerLayerTaskGraphs { // Typed reference to Qwen3-specific state private final Qwen3State qwen3State; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java index 0d717d2e..26686738 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java @@ -6,7 +6,7 @@ import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; @@ -16,7 +16,7 @@ * Q8_0 FFN layers for Devstral 2 models. * Uses precomputed RoPE frequencies (YaRN scaling) instead of on-the-fly computation. 
*/ -public class DevstralQ8_0FFNLayers extends AbstractFFNLayers { +public class DevstralQ8_0FFNLayers extends AbstractTransformerLayerTaskGraphs { public DevstralQ8_0FFNLayers(String taskGraphName, State state, LlamaTornadoWeights weights, DevstralConfiguration config, SchedulerType schedulerType) { super(taskGraphName, state, weights, config, schedulerType); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java index 1aca4714..78f834f4 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java @@ -7,13 +7,13 @@ import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -public class GraniteQ8_0FFNLayers extends AbstractFFNLayers { +public class GraniteQ8_0FFNLayers extends AbstractTransformerLayerTaskGraphs { public GraniteQ8_0FFNLayers(String taskGraphName, GraniteState state, GraniteTornadoWeights weights, GraniteConfiguration config, SchedulerType schedulerType) { super(taskGraphName, state, weights, config, schedulerType); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java index f5683904..8cab58ee 100644 --- 
a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java @@ -6,13 +6,13 @@ import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -public class LlamaQ8_0FFNLayers extends AbstractFFNLayers { +public class LlamaQ8_0FFNLayers extends AbstractTransformerLayerTaskGraphs { public LlamaQ8_0FFNLayers(String taskGraphName, LlamaState state, LlamaTornadoWeights weights, LlamaConfiguration config, SchedulerType schedulerType) { super(taskGraphName, state, weights, config, schedulerType); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java index bf66118a..725dfd24 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java @@ -9,13 +9,13 @@ import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; import 
uk.ac.manchester.tornado.api.WorkerGrid1D; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -public class LogitsQ8_0Layer extends AbstractLogitsLayer { +public class LogitsQ8_0Layer extends AbstractLogitsTaskGraph { public LogitsQ8_0Layer(String name, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) { diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java index 4a6f3151..68073c63 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java @@ -6,13 +6,13 @@ import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; import uk.ac.manchester.tornado.api.enums.DataTransferMode; -public class MistralQ8_0FFNLayers extends AbstractFFNLayers { +public class MistralQ8_0FFNLayers extends AbstractTransformerLayerTaskGraphs { public MistralQ8_0FFNLayers(String taskGraphName, LlamaState state, LlamaTornadoWeights weights, MistralConfiguration config, SchedulerType schedulerType) { super(taskGraphName, state, weights, config, schedulerType); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java index 0ffe25d6..9739c14f 100644 --- 
a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java @@ -7,7 +7,7 @@ import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; @@ -21,7 +21,7 @@ * * Works directly with Phi3State to access and mutate Phi3-specific state fields. */ -public class Phi3Q8_0FFNLayers extends AbstractFFNLayers { +public class Phi3Q8_0FFNLayers extends AbstractTransformerLayerTaskGraphs { // Typed reference to Phi3-specific state private final Phi3State phi3State; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java index 936c3573..ec96364e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java @@ -8,7 +8,7 @@ import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; @@ -28,7 +28,7 @@ * * Works directly with 
Qwen2State to access and mutate Qwen2-specific state fields. */ -public class Qwen2Q8_0FFNLayers extends AbstractFFNLayers { +public class Qwen2Q8_0FFNLayers extends AbstractTransformerLayerTaskGraphs { // Typed reference to Qwen2-specific state private final Qwen2State qwen2State; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java index 1696644d..9397dc2d 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java @@ -7,7 +7,7 @@ import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; @@ -25,7 +25,7 @@ * Works directly with Qwen3State to access and mutate Qwen3-specific state fields * like tempQcur and tempKcur. 
*/ -public class Qwen3Q8_0FFNLayers extends AbstractFFNLayers { +public class Qwen3Q8_0FFNLayers extends AbstractTransformerLayerTaskGraphs { // Typed reference to Qwen3-specific state private final Qwen3State qwen3State; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java index 55202738..0126c7ef 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java @@ -1,7 +1,7 @@ package org.beehive.gpullama3.tornadovm.plan; import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; @@ -58,7 +58,7 @@ public BatchPrefillDecodeForwardPlan(Model model, BatchPrefillDecodeForwardPlanC all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs()); decodeLayers.updateGridScheduler(scheduler); - AbstractLogitsLayer logits = components.decodeLogits(decodeLayers.getLastFFNLayerTaskGraphID()); + AbstractLogitsTaskGraph logits = components.decodeLogits(decodeLayers.getLastFFNLayerTaskGraphID()); all.add(logits.getImmutableTaskGraph()); logits.updateGridScheduler(scheduler); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java index 4c91b7fe..d526080d 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java @@ -1,7 +1,7 @@ package 
org.beehive.gpullama3.tornadovm.plan; import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.plan.components.PrefillDecodeForwardPlanComponents; @@ -44,7 +44,7 @@ public PrefillDecodeForwardPlan(Model model, PrefillDecodeForwardPlanComponents all.addAll(layers.getFFNLayerImmutableTaskGraphs()); layers.updateGridScheduler(scheduler); - AbstractLogitsLayer logits = components.decodeLogits(layers.getLastFFNLayerTaskGraphID()); + AbstractLogitsTaskGraph logits = components.decodeLogits(layers.getLastFFNLayerTaskGraphID()); all.add(logits.getImmutableTaskGraph()); logits.updateGridScheduler(scheduler); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/SingleTokenForwardPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/SingleTokenForwardPlan.java index 4db7ad86..2eca71bd 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/SingleTokenForwardPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/SingleTokenForwardPlan.java @@ -1,7 +1,7 @@ package org.beehive.gpullama3.tornadovm.plan; import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; @@ -41,7 +41,7 @@ public SingleTokenForwardPlan(Model model, SingleTokenForwardPlanComponents comp all.addAll(layers.getFFNLayerImmutableTaskGraphs()); layers.updateGridScheduler(scheduler); - AbstractLogitsLayer logits = 
components.standardLogits(layers.getLastFFNLayerTaskGraphID()); + AbstractLogitsTaskGraph logits = components.standardLogits(layers.getLastFFNLayerTaskGraphID()); all.add(logits.getImmutableTaskGraph()); logits.updateGridScheduler(scheduler); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java index 89f53051..dad27f94 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java @@ -1,6 +1,6 @@ package org.beehive.gpullama3.tornadovm.plan.components; -import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; @@ -16,5 +16,5 @@ public interface PrefillDecodeForwardPlanComponents extends SingleTokenForwardPl TransformerLayerTaskGraphs prefillDecodeLayers(); - AbstractLogitsLayer decodeLogits(String previousGraphId); + AbstractLogitsTaskGraph decodeLogits(String previousGraphId); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java index 37ef3dc7..596156f4 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java @@ -1,6 +1,6 @@ package org.beehive.gpullama3.tornadovm.plan.components; -import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import 
org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; @@ -17,5 +17,5 @@ public interface SingleTokenForwardPlanComponents { TransformerLayerTaskGraphs standardLayers(); - AbstractLogitsLayer standardLogits(String previousGraphId); + AbstractLogitsTaskGraph standardLogits(String previousGraphId); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java index 862b8a92..00f49f28 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java @@ -4,7 +4,7 @@ import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.devstral.DevstralConfiguration; -import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; @@ -36,7 +36,7 @@ public DevstralFP16PlanComponents(DevstralState state, Model model) { return new DevstralFP16FFNLayers("devstralFFN", state, weights, config, schedulerType); } - @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java 
b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java index 13f0b0f9..5f47ec94 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java @@ -4,7 +4,7 @@ import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.granite.GraniteConfiguration; -import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import org.beehive.gpullama3.tornadovm.layers.ActivationGranite; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; @@ -36,7 +36,7 @@ public GraniteFP16PlanComponents(GraniteState state, Model model) { return new GraniteFP16FFNLayers("graniteFFN", state, weights, config, schedulerType); } - @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { return new LogitsGraniteFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java index 23b63c8e..350d1fd0 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java @@ -4,7 +4,7 @@ import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.llama.LlamaConfiguration; -import 
org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs; @@ -81,11 +81,11 @@ public LlamaFP16PlanComponents(LlamaState state, Model model) { // ── Logits layers ───────────────────────────────────────────────────────── - @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); } - @Override public AbstractLogitsLayer decodeLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph decodeLogits(String previousGraphId) { return new LogitsFP16LayerDecode("logits", state, weights, config, previousGraphId, schedulerType); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java index 542b9827..fdd4d43f 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java @@ -4,7 +4,7 @@ import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.mistral.MistralConfiguration; -import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import 
org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; @@ -36,7 +36,7 @@ public MistralFP16PlanComponents(LlamaState state, Model model) { return new MistralFP16FFNLayers("mistralFFN", state, weights, config, schedulerType); } - @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java index 4a619211..fb21c4c2 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java @@ -4,7 +4,7 @@ import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.phi3.Phi3Configuration; -import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; @@ -36,7 +36,7 @@ public Phi3FP16PlanComponents(Phi3State state, Model model) { return new Phi3FP16FFNLayers("phi3FFN", state, weights, config, schedulerType); } - @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); } } diff --git 
a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java index 0a359579..03fc3c88 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java @@ -4,7 +4,7 @@ import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; -import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; @@ -36,7 +36,7 @@ public Qwen2FP16PlanComponents(Qwen2State state, Model model) { return new Qwen2FP16FFNLayers("qwen2FFN", state, weights, config, schedulerType); } - @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java index d18c67e6..a11270bd 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java @@ -4,7 +4,7 @@ import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; import org.beehive.gpullama3.model.Model; import 
org.beehive.gpullama3.model.qwen3.Qwen3Configuration; -import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; @@ -36,7 +36,7 @@ public Qwen3FP16PlanComponents(Qwen3State state, Model model) { return new Qwen3FP16FFNLayers("qwen3FFN", state, weights, config, schedulerType); } - @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java index 2019beae..8589ee05 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java @@ -4,7 +4,7 @@ import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.devstral.DevstralConfiguration; -import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; @@ -36,7 +36,7 @@ public DevstralQ8_0PlanComponents(DevstralState state, Model model) { return new DevstralQ8_0FFNLayers("devstralFFN", state, weights, config, schedulerType); } - 
@Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java index 30c6c28b..03e0ccb9 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java @@ -4,7 +4,7 @@ import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.granite.GraniteConfiguration; -import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import org.beehive.gpullama3.tornadovm.layers.ActivationGranite; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; @@ -36,7 +36,7 @@ public GraniteQ8_0PlanComponents(GraniteState state, Model model) { return new GraniteQ8_0FFNLayers("graniteFFN", state, weights, config, schedulerType); } - @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { return new LogitsGraniteQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java index aa16bd52..f7b8c0c2 100644 --- 
a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java @@ -4,7 +4,7 @@ import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.llama.LlamaConfiguration; -import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs; @@ -86,11 +86,11 @@ public LlamaQ8_0PlanComponents(LlamaState state, Model model) { // ── Logits layers ───────────────────────────────────────────────────────── - @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); } - @Override public AbstractLogitsLayer decodeLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph decodeLogits(String previousGraphId) { return new LogitsQ8_0LayerDecode("logits", state, weights, config, previousGraphId, schedulerType); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java index b377a5f1..d4d91993 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java @@ -4,7 +4,7 @@ import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import 
org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.mistral.MistralConfiguration; -import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; @@ -36,7 +36,7 @@ public MistralQ8_0PlanComponents(LlamaState state, Model model) { return new MistralQ8_0FFNLayers("mistralFFN", state, weights, config, schedulerType); } - @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java index 9c9735ac..7974a3fe 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java @@ -4,7 +4,7 @@ import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.phi3.Phi3Configuration; -import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; @@ -36,7 +36,7 @@ public Phi3Q8_0PlanComponents(Phi3State state, Model model) { return new Phi3Q8_0FFNLayers("phi3FFN", state, weights, config, 
schedulerType); } - @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java index d4a22184..fb739835 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java @@ -4,7 +4,7 @@ import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; -import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; @@ -36,7 +36,7 @@ public Qwen2Q8_0PlanComponents(Qwen2State state, Model model) { return new Qwen2Q8_0FFNLayers("qwen2FFN", state, weights, config, schedulerType); } - @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java index a033849e..8ef1881d 100644 --- 
a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java @@ -4,7 +4,7 @@ import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; -import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer; +import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; @@ -36,7 +36,7 @@ public Qwen3Q8_0PlanComponents(Qwen3State state, Model model) { return new Qwen3Q8_0FFNLayers("qwen3FFN", state, weights, config, schedulerType); } - @Override public AbstractLogitsLayer standardLogits(String previousGraphId) { + @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) { return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); } } From e20ebc5c35214980d90281cdca2f0b3f20b4d9ab Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Thu, 28 May 2026 15:28:42 +0300 Subject: [PATCH 5/6] Refactor FFN layer comments to `transformer-layer task graphs`, aligning with updated naming conventions. 
--- .../org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java | 2 +- .../tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java | 2 +- .../gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java | 2 +- .../tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java | 2 +- .../tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java | 2 +- .../layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java | 2 +- .../type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java | 2 +- .../layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java | 2 +- .../tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java | 2 +- .../gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java | 2 +- .../tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java | 2 +- .../tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java | 2 +- .../layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java | 2 +- .../type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java | 2 +- .../layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java | 2 +- .../tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java | 2 +- .../tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java | 2 +- 17 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java index f34f5777..f73a55a3 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java @@ -8,7 +8,7 @@ import uk.ac.manchester.tornado.api.TaskGraph; /** - * Abstract base class for Activations, FFN Layers, and Logits. + * Abstract base class for activation, transformer-layer, and logits task graphs. 
*/ public abstract class AbstractLayer { diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java index 0597a1c7..7b709797 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java @@ -14,7 +14,7 @@ import uk.ac.manchester.tornado.api.enums.DataTransferMode; /** - * FP16 FFN layers for Devstral 2 models. + * FP16 transformer-layer task graphs for Devstral 2 models. * Uses precomputed RoPE frequencies (YaRN scaling) instead of on-the-fly computation. */ public class DevstralFP16FFNLayers extends AbstractTransformerLayerTaskGraphs { diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java index e8ca1e54..b06fb942 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java @@ -14,7 +14,7 @@ import uk.ac.manchester.tornado.api.enums.DataTransferMode; /** - * Phi3FP16FFNLayers: FP16 FFN layers for Phi3 with Group Query Attention (GQA) support. + * Phi3FP16FFNLayers: FP16 transformer-layer task graphs for Phi3 with Group Query Attention (GQA) support. 
* * Key Differences from Qwen2/Qwen3: - Uses combined QKV matrix (wqkv) instead of separate Q, K, V matrices - Includes splitQKV task to separate combined buffer - Uses ropeRotationPhi3 kernel for * position embeddings - FFN uses single wUp matrix that outputs both Gate and Up (2 * hiddenDim) - Includes splitGateUpAndSiLU task for FFN activation - Uses wDown for final FFN projection - No Q, K, diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java index 747b7213..4b3bbca3 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java @@ -17,7 +17,7 @@ import uk.ac.manchester.tornado.api.enums.DataTransferMode; /** - * Qwen2FP16FFNLayers: FP16 FFN layers for Qwen2 with Group Query Attention (GQA) support. + * Qwen2FP16FFNLayers: FP16 transformer-layer task graphs for Qwen2 with Group Query Attention (GQA) support. 
* * Key Differences from Qwen3: - No tempQcur/tempKcur fields in Qwen2State - Includes bias terms for Q, K, V projections - Standard GQA (no parallel offset RMSNorm) - Uses * Qwen2Kernels::processHeadsFlashAttention for attention computation - Uses Qwen3Kernels::ropeRotation for position embeddings - Simpler matrix dimensions (uses config.dim() and config.kvDim() diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java index 3b298249..d0bee6a9 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java @@ -14,7 +14,7 @@ import uk.ac.manchester.tornado.api.enums.DataTransferMode; /** - * Qwen3FP16FFNLayers: FP16 FFN layers for Qwen3 with Group Query Attention (GQA) support. + * Qwen3FP16FFNLayers: FP16 transformer-layer task graphs for Qwen3 with Group Query Attention (GQA) support. 
* * Key Differences from Llama: - Supports GQA with separate KV heads (nHeadKv) - Uses Qwen3Kernels for RMSNorm with parallel offset - Custom RoPE rotation for Qwen3 - Different attention computation * due to GQA structure diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java index ea0220f6..28d8a1f6 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java @@ -9,7 +9,7 @@ import uk.ac.manchester.tornado.api.enums.DataTransferMode; /** - * Decode FFN layers of the unified batched prefill-decode plan + * Decode transformer-layer task graphs of the unified batched prefill-decode plan * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode}). * *

      <p>Overrides data-transfer declarations so that all cross-graph boundaries use diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java index c1167796..f3f628b3 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java @@ -9,7 +9,7 @@ import uk.ac.manchester.tornado.api.enums.DataTransferMode; /** - * Decode FFN layers for the single-token prefill/decode plan + * Decode transformer-layer task graphs for the single-token prefill/decode plan * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanPrefillDecode}). * *

      <p>Combines two concerns:</p>

      diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java index 44ce4b35..dcfe626e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java @@ -17,7 +17,7 @@ import java.util.stream.IntStream; /** - * Prefill FFN layers with batching for the unified batched prefill-decode plan + * Batched-prefill transformer-layer task graphs for the unified batched prefill-decode plan * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode}). * *

      <p>One {@link ImmutableTaskGraph} per transformer layer, each processing diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java index 26686738..b8a9fafa 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java @@ -13,7 +13,7 @@ import uk.ac.manchester.tornado.api.enums.DataTransferMode; /** - * Q8_0 FFN layers for Devstral 2 models. + * Q8_0 transformer-layer task graphs for Devstral 2 models. * Uses precomputed RoPE frequencies (YaRN scaling) instead of on-the-fly computation. */ public class DevstralQ8_0FFNLayers extends AbstractTransformerLayerTaskGraphs { diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java index 9739c14f..9831486d 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java @@ -14,7 +14,7 @@ import uk.ac.manchester.tornado.api.enums.DataTransferMode; /** - * Phi3Q8_0FFNLayers: Q8_0-quantized FFN layers for Phi3 with Group Query Attention (GQA) support. + * Phi3Q8_0FFNLayers: Q8_0 transformer-layer task graphs for Phi3 with Group Query Attention (GQA) support.
* * Key Differences from Phi3FP16FFNLayers: - Uses Q8_0-quantized weights (getQuants() and getScales()) - Same attention and RoPE kernels as FP16 version - 8-bit integer computations with * dequantization - 2x memory compression vs FP16 - Same combined QKV and gate/up FFN structure diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java index ec96364e..9ba6b363 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java @@ -17,7 +17,7 @@ import uk.ac.manchester.tornado.api.enums.DataTransferMode; /** - * Qwen2Q8_0FFNLayers: Q8_0-quantized FFN layers for Qwen2 with Group Query Attention (GQA) support. + * Qwen2Q8_0FFNLayers: Q8_0 transformer-layer task graphs for Qwen2 with Group Query Attention (GQA) support. * * Key Differences from Qwen2FP16FFNLayers: * - Uses Q8_0-quantized weights (getQuants() and getScales()) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java index 9397dc2d..b27224bd 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java @@ -14,7 +14,7 @@ import uk.ac.manchester.tornado.api.enums.DataTransferMode; /** - * Qwen3Q8_0FFNLayers: Q8_0-quantized FFN layers for Qwen3 with Group Query Attention (GQA) support. + * Qwen3Q8_0FFNLayers: Q8_0 transformer-layer task graphs for Qwen3 with Group Query Attention (GQA) support. 
* * Key Differences from Qwen3FP16FFNLayers: * - Uses Q8_0-quantized weights (getQuants() and getScales()) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java index 0b2f13cb..d1ee5d14 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java @@ -9,7 +9,7 @@ import uk.ac.manchester.tornado.api.enums.DataTransferMode; /** - * Decode FFN layers for the unified batched prefill-decode plan + * Decode transformer-layer task graphs for the unified batched prefill-decode plan * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode}). * *

      <p>Layer 0 consumes the KV cache from device (passed through by the decode activation diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java index 388853f7..6c4c22b6 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java @@ -8,7 +8,7 @@ import uk.ac.manchester.tornado.api.TaskGraph; /** - * Decode FFN layers for the single-token prefill/decode plan + * Decode transformer-layer task graphs for the single-token prefill/decode plan * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanPrefillDecode}). * *

      <p>Layer 0 delegates to {@link LlamaQ8_0FFNLayers#configureLayerDataTransfers} which diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java index 14c30462..c0904ef3 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java @@ -17,7 +17,7 @@ import java.util.stream.IntStream; /** - * Prefill FFN layers with batching for the unified batched prefill-decode plan (Q8_0). + * Batched-prefill transformer-layer task graphs for the unified batched prefill-decode plan (Q8_0). * *

      <p>Mirrors {@link org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16LayersBatchPrefill} * but uses Q8_0 kernels with inline dequantization. Key differences from the FP16 path:</p>

      diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java index 350d1fd0..67392338 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java @@ -61,7 +61,7 @@ public LlamaFP16PlanComponents(LlamaState state, Model model) { return new BatchDecodeActivation(state, config, lastBatchLayerId, false); } - // ── FFN layer groups ────────────────────────────────────────────────────── + // ── Transformer layer task graphs ────────────────────────────────────────────────────── @Override public TransformerLayerTaskGraphs standardLayers() { return new LlamaFP16FFNLayers("llamaFFN", state, weights, config, schedulerType); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java index f7b8c0c2..bf9a5f29 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java @@ -66,7 +66,7 @@ public LlamaQ8_0PlanComponents(LlamaState state, Model model) { return new BatchDecodeActivation(state, config, lastBatchLayerId, true); } - // ── FFN layer groups ────────────────────────────────────────────────────── + // ── Transformer layer task graphs ────────────────────────────────────────────────────── @Override public TransformerLayerTaskGraphs standardLayers() { return new LlamaQ8_0FFNLayers("llamaFFN", state, weights, config, schedulerType); From c7522d16af0e0baf2f3d35fa28917a7aad32ca17 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Thu, 28 May 2026 15:42:24 +0300 Subject: [PATCH 6/6] [ci] 
Add workflows for Llama-3.2-1B-Instruct Q8_0 inference with prefill-decode and CUDA-graph variants --- .github/workflows/build-and-run.yml | 101 ++++++++++++++++++++++++++++ 1 file changed, 101 insertions(+) diff --git a/.github/workflows/build-and-run.yml b/.github/workflows/build-and-run.yml index a17c9228..46dc3bf4 100644 --- a/.github/workflows/build-and-run.yml +++ b/.github/workflows/build-and-run.yml @@ -385,6 +385,107 @@ jobs: flags="" \ prompt="Say hello" + - name: Q8 - Run Llama-3.2-1B-Instruct-Q8_0.gguf - Prefill-Decode + env: + JAVA_TOOL_OPTIONS: >- + -Dllama.metrics.format=json + -Dllama.metrics.output=file + -Dllama.metrics.file=${{ runner.temp }}/metrics-${{ matrix.backend.name }}-llama-1b-q8-prefill-decode.json + run: | + cd ${{ github.workspace }} + export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH" + ./llama-tornado --gpu --${{ matrix.backend.name }} \ + --model $MODELS_DIR/Llama-3.2-1B-Instruct-Q8_0.gguf \ + --prompt "Say hello" \ + --with-prefill-decode + python3 scripts/write_metrics_sidecar.py \ + --out "${{ runner.temp }}/metrics-${{ matrix.backend.name }}-llama-1b-q8-prefill-decode.meta.json" \ + backend="${{ matrix.backend.name }}" \ + task=llama-inference \ + model_file=Llama-3.2-1B-Instruct-Q8_0.gguf \ + model=Llama-3.2-1B-Instruct \ + quantization=Q8_0 \ + configuration=prefill-decode \ + "flags=--with-prefill-decode" \ + prompt="Say hello" + + - name: Q8 - Run Llama-3.2-1B-Instruct-Q8_0.gguf - Batch-Prefill-Decode + env: + JAVA_TOOL_OPTIONS: >- + -Dllama.metrics.format=json + -Dllama.metrics.output=file + -Dllama.metrics.file=${{ runner.temp }}/metrics-${{ matrix.backend.name }}-llama-1b-q8-batch-prefill-decode.json + run: | + cd ${{ github.workspace }} + export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH" + ./llama-tornado --gpu --${{ matrix.backend.name }} \ + --model $MODELS_DIR/Llama-3.2-1B-Instruct-Q8_0.gguf \ + --prompt "Say hello" \ + --with-prefill-decode --batch-prefill-size 32 + python3 scripts/write_metrics_sidecar.py 
\ + --out "${{ runner.temp }}/metrics-${{ matrix.backend.name }}-llama-1b-q8-batch-prefill-decode.meta.json" \ + backend="${{ matrix.backend.name }}" \ + task=llama-inference \ + model_file=Llama-3.2-1B-Instruct-Q8_0.gguf \ + model=Llama-3.2-1B-Instruct \ + quantization=Q8_0 \ + configuration=batch-prefill-decode \ + "flags=--with-prefill-decode --batch-prefill-size 32" \ + prompt="Say hello" + + # ── PTX-only: CUDA-graph variants ──────────────────────────────────────── + - name: PTX - Q8 - Run Llama-3.2-1B-Instruct-Q8_0.gguf - Prefill-Decode-CUDA-Graphs + if: matrix.backend.name == 'ptx' + env: + JAVA_TOOL_OPTIONS: >- + -Dllama.metrics.format=json + -Dllama.metrics.output=file + -Dllama.metrics.file=${{ runner.temp }}/metrics-ptx-llama-1b-q8-prefill-decode-cuda-graphs.json + run: | + cd ${{ github.workspace }} + export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH" + ./llama-tornado --gpu --ptx \ + --model $MODELS_DIR/Llama-3.2-1B-Instruct-Q8_0.gguf \ + --prompt "Say hello" \ + --with-prefill-decode \ + --cuda-graphs + python3 scripts/write_metrics_sidecar.py \ + --out "${{ runner.temp }}/metrics-ptx-llama-1b-q8-prefill-decode-cuda-graphs.meta.json" \ + backend=ptx \ + task=llama-inference \ + model_file=Llama-3.2-1B-Instruct-Q8_0.gguf \ + model=Llama-3.2-1B-Instruct \ + quantization=Q8_0 \ + configuration=prefill-decode-cuda-graphs \ + "flags=--with-prefill-decode --cuda-graphs" \ + prompt="Say hello" + + - name: PTX - Q8 - Run Llama-3.2-1B-Instruct-Q8_0.gguf - Batch-Prefill-Decode-CUDA-Graphs + if: matrix.backend.name == 'ptx' + env: + JAVA_TOOL_OPTIONS: >- + -Dllama.metrics.format=json + -Dllama.metrics.output=file + -Dllama.metrics.file=${{ runner.temp }}/metrics-ptx-llama-1b-q8-batch-prefill-decode-cuda-graphs.json + run: | + cd ${{ github.workspace }} + export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH" + ./llama-tornado --gpu --ptx \ + --model $MODELS_DIR/Llama-3.2-1B-Instruct-Q8_0.gguf \ + --prompt "Say hello" \ + --with-prefill-decode 
--batch-prefill-size 32 \ + --cuda-graphs + python3 scripts/write_metrics_sidecar.py \ + --out "${{ runner.temp }}/metrics-ptx-llama-1b-q8-batch-prefill-decode-cuda-graphs.meta.json" \ + backend=ptx \ + task=llama-inference \ + model_file=Llama-3.2-1B-Instruct-Q8_0.gguf \ + model=Llama-3.2-1B-Instruct \ + quantization=Q8_0 \ + configuration=batch-prefill-decode-cuda-graphs \ + "flags=--with-prefill-decode --batch-prefill-size 32 --cuda-graphs" \ + prompt="Say hello" + - name: Q8 - Run Qwen3-0.6B-Q8_0.gguf env: JAVA_TOOL_OPTIONS: >-