From 45204f1bb3e138b145014ac003ab35106cb6bdd2 Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Tue, 26 May 2026 14:01:43 +0300
Subject: [PATCH 1/6] [prf/dec]Implement prefill-decode for Llama Q8_0
---
.../InferenceCoreWithPrefillDecode.java | 10 +-
...adoVMMasterPlanWithBatchPrefillDecode.java | 175 ++++++++-----
.../TornadoVMMasterPlanWithPrefillDecode.java | 53 ++--
.../TransformerBatchPrefillKernels.java | 226 +++++++++++++++++
.../layers/type/q8_0/LlamaQ8_0FFNLayers.java | 11 +-
.../layers/type/q8_0/LogitsQ8_0Layer.java | 5 +
.../q8_0/decode/LlamaQ8_0FFNLayersDecode.java | 55 +++++
.../LlamaQ8_0FFNLayersPrefillDecode.java | 47 ++++
.../q8_0/decode/LogitsQ8_0LayerDecode.java | 35 +++
.../prefill/LlamaQ8_0LayersBatchPrefill.java | 230 ++++++++++++++++++
10 files changed, 759 insertions(+), 88 deletions(-)
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LogitsQ8_0LayerDecode.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java
diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java
index 91bb6f79..dbd81297 100644
--- a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java
+++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java
@@ -153,9 +153,13 @@ public static void forwardTornadoVMPrefill(Model model, State state, int token,
MemorySegment.copy(tokenEmbeddings, (long) token * configuration.dim() * bytes,
state.embeddingX.getSegment(), 0, (long) configuration.dim() * bytes);
}
- case Q8_0 -> throw new UnsupportedOperationException(
- // TODO Phase 4: implement Q8_0 GPU batched prefill kernels
- "GPU prefill/decode path not yet implemented for Q8_0 weights");
+ case Q8_0 -> {
+ MemorySegment tokenEmbeddings = weights.getTokenEmbeddingTable().asByteArray().getSegment();
+ int blocksPerToken = (configuration.dim() + 31) / 32;
+ long bytesPerToken = (long) blocksPerToken * 34;
+ MemorySegment.copy(tokenEmbeddings, (long) token * bytesPerToken,
+ state.embeddingX.getSegment(), 0, bytesPerToken);
+ }
default -> throw new IllegalArgumentException("Unsupported weight type: " + weights.getWeightType());
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java
index 2984508e..2fc310e2 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java
@@ -7,13 +7,20 @@
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
+import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService;
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersDecode;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16LayersBatchPrefill;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LogitsFP16LayerDecode;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LlamaQ8_0FFNLayersDecode;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LogitsQ8_0LayerDecode;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.prefill.LlamaQ8_0LayersBatchPrefill;
+import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.KernelContext;
@@ -102,7 +109,7 @@ public class TornadoVMMasterPlanWithBatchPrefillDecode implements TornadoVMMaste
// ── Batch Prefill Activation graphs ─────────────────────────────────────────────────────
- /** Graph 0: B×dim FP16 embeddings → FP32 wrapXBatch. */
+ /** Graph 0 (FP16): B×dim FP16 embeddings → FP32 wrapXBatch. */
private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) {
return new TaskGraph("prefillActivation")
.transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapXBatch)
@@ -112,8 +119,17 @@ private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) {
.persistOnDevice(state.wrapXBatch);
}
+ /** Graph 0 (Q8_0): wrapXBatch pre-filled with FP32 by host; upload and persist. */
+ private TaskGraph buildQ8_0BatchPrefillActivationGraph(KernelContext ctx) {
+ return new TaskGraph("prefillActivation")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapXBatch)
+ .task("batchPassthrough", TransformerBatchPrefillKernels::batchPassthrough,
+ ctx, state.wrapXBatch)
+ .persistOnDevice(state.wrapXBatch);
+ }
+
/**
- * Graph N+1: single-token FP16 → FP32.
+ * Graph N+1: single-token embedding → FP32 wrapX, with KV-cache pass-through.
*
* Receives the KV-cache device pointer from batch layer N via
* {@code consumeFromDevice}, then re-emits it via {@code persistOnDevice} so
@@ -121,73 +137,86 @@ private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) {
* Both halves of the chain are required; without the re-persist the pointer is
* not forwarded in interpreter (non-CUDA-graph) mode.
*/
- private TaskGraph buildDecodeActivationGraph(KernelContext ctx, String lastBatchLayerID) {
- return new TaskGraph("decodeActivation")
- .consumeFromDevice(lastBatchLayerID, state.wrapKeyCache, state.wrapValueCache) // KV pass-through
- .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX)
- .task("updateX",
- TransformerComputeKernels::convertFP16toFP32,
- ctx, (HalfFloatArray) state.embeddingX, state.wrapX)
- // wrapX persisted for decode layer 0; wrapKeyCache/wrapValueCache
- // re-persisted so updatePersistedObjectState() propagates the device
- // pointer to decode layer 0's consumeFromDevice without CUDA graphs.
- .persistOnDevice(state.wrapX, state.wrapKeyCache, state.wrapValueCache);
+ private TaskGraph buildDecodeActivationGraph(KernelContext ctx, String lastBatchLayerID,
+ GGMLType weightType) {
+ TaskGraph tg = new TaskGraph("decodeActivation")
+ .consumeFromDevice(lastBatchLayerID, state.wrapKeyCache, state.wrapValueCache)
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX);
+ if (weightType == GGMLType.Q8_0) {
+ tg.task("updateX", TransformerComputeKernels::convertQ8_0toFP32,
+ ctx, (ByteArray) state.embeddingX, state.wrapX);
+ } else {
+ tg.task("updateX", TransformerComputeKernels::convertFP16toFP32,
+ ctx, (HalfFloatArray) state.embeddingX, state.wrapX);
+ }
+ return tg.persistOnDevice(state.wrapX, state.wrapKeyCache, state.wrapValueCache);
}
- /**
- * Creates the {@link TornadoExecutionPlan} for forward pass with *prefill in batches and separated decode*.
- *
- * TODO: support Q8_0 weights
- * To implement this, consult how {@link TornadoVMMasterPlanStandard} uses the {@link QuantizationPlannerFactory}
- */
@Override
public TornadoExecutionPlan createExecutionPlan() {
GGMLType weightType = model.weights().getWeightType();
- switch (weightType) {
- case F16 -> { /* supported — continue below */ }
- case Q8_0 -> throw new UnsupportedOperationException(
- "Batched prefill/decode GPU path not yet implemented for Q8_0 weights");
- default -> throw new UnsupportedOperationException(
+ if (weightType != GGMLType.F16 && weightType != GGMLType.Q8_0) {
+ throw new UnsupportedOperationException(
"Batched prefill/decode GPU path not supported for weight type: " + weightType);
}
LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights();
SchedulerType schedulerType = SchedulerDetectionService.determineSchedulerType(model);
+ List all = new ArrayList<>(2 * N + 3);
- List all = new ArrayList<>(2 * N + 3);
-
- // [0] Batch prefill activation ────────────────────────────────────────────────
+ // [0] Batch prefill activation ────────────────────────────────────────
KernelContext batchActCtx = new KernelContext();
- all.add(buildBatchPrefillActivationGraph(batchActCtx).snapshot());
- gridScheduler.addWorkerGrid("prefillActivation.updateX",
- WorkerGridFactory.genericWorker(batchSize * config.dim(), 128));
+ if (weightType == GGMLType.F16) {
+ all.add(buildBatchPrefillActivationGraph(batchActCtx).snapshot());
+ gridScheduler.addWorkerGrid("prefillActivation.updateX",
+ WorkerGridFactory.genericWorker(batchSize * config.dim(), 128));
+ } else {
+ all.add(buildQ8_0BatchPrefillActivationGraph(batchActCtx).snapshot());
+ gridScheduler.addWorkerGrid("prefillActivation.batchPassthrough",
+ WorkerGridFactory.genericWorker(1, 1));
+ }
- // [1..N] Batch prefill layer graphs ───────────────────────────────────────────
- LlamaFP16LayersBatchPrefill batchLayers =
- new LlamaFP16LayersBatchPrefill(state, weights, config, batchSize);
- all.addAll(batchLayers.getLayerImmutableTaskGraphs());
- batchLayers.updateGridScheduler(gridScheduler);
+ // [1..N] Batch prefill layer graphs ───────────────────────────────────
+ String lastBatchLayerID;
+ if (weightType == GGMLType.F16) {
+ LlamaFP16LayersBatchPrefill batchLayers =
+ new LlamaFP16LayersBatchPrefill(state, weights, config, batchSize);
+ all.addAll(batchLayers.getLayerImmutableTaskGraphs());
+ batchLayers.updateGridScheduler(gridScheduler);
+ lastBatchLayerID = batchLayers.getLastLayerTaskGraphID();
+ } else {
+ LlamaQ8_0LayersBatchPrefill batchLayers =
+ new LlamaQ8_0LayersBatchPrefill(state, weights, config, batchSize);
+ all.addAll(batchLayers.getLayerImmutableTaskGraphs());
+ batchLayers.updateGridScheduler(gridScheduler);
+ lastBatchLayerID = batchLayers.getLastLayerTaskGraphID();
+ }
// [N+1] Decode activation (with KV-cache pass-through) ────────────────
KernelContext decodeActCtx = new KernelContext();
- all.add(buildDecodeActivationGraph(decodeActCtx, batchLayers.getLastLayerTaskGraphID()).snapshot());
+ all.add(buildDecodeActivationGraph(decodeActCtx, lastBatchLayerID, weightType).snapshot());
gridScheduler.addWorkerGrid("decodeActivation.updateX",
WorkerGridFactory.genericWorker(config.dim(), 128));
- // [N+2..2N+1] Decode layer graphs ────────────────────────────────────
- // Layer 0 uses consumeFromDevice for KV cache (no FIRST_EXECUTION upload).
- LlamaFP16FFNLayersDecode decodeLayers =
- new LlamaFP16FFNLayersDecode(
- "decode", state, weights, config, schedulerType);
+ // [N+2..2N+1] Decode layer graphs ─────────────────────────────────────
+ AbstractFFNLayers decodeLayers;
+ if (weightType == GGMLType.F16) {
+ decodeLayers = new LlamaFP16FFNLayersDecode("decode", state, weights, config, schedulerType);
+ } else {
+ decodeLayers = new LlamaQ8_0FFNLayersDecode("decode", state, weights, config, schedulerType);
+ }
all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs());
decodeLayers.updateGridScheduler(gridScheduler);
// [2N+2] Logits ───────────────────────────────────────────────────────
- // LogitsFP16LayerDecode extends LogitsFP16Layer: adds consumeFromDevice(wrapKeyCache)
- // at the start of the graph and persistOnDevice(wrapKeyCache) at the end, so the
- // KV-cache pointer survives the logits → decode-activation boundary across tokens.
- LogitsFP16LayerDecode logitsLayer = new LogitsFP16LayerDecode("logits", state, weights, config,
- decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType);
+ AbstractLogitsLayer logitsLayer;
+ if (weightType == GGMLType.F16) {
+ logitsLayer = new LogitsFP16LayerDecode("logits", state, weights, config,
+ decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType);
+ } else {
+ logitsLayer = new LogitsQ8_0LayerDecode("logits", state, weights, config,
+ decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType);
+ }
all.add(logitsLayer.getImmutableTaskGraph());
logitsLayer.updateGridScheduler(gridScheduler);
@@ -222,15 +251,32 @@ public void forceCopyInReadOnlyData() {
*/
public void tornadoVMForwardBatchPrefill(int[] tokenIds, int startPos, Model model, int chunkSize) {
LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights();
- MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment();
- int bytes = Short.BYTES;
- int dim = config.dim();
-
- // Copy B embeddings into embeddingXBatch
- for (int b = 0; b < chunkSize; b++) {
- MemorySegment.copy(embTable, (long) tokenIds[b] * dim * bytes,
- state.embeddingXBatch.getSegment(), (long) b * dim * bytes,
- (long) dim * bytes);
+ int dim = config.dim();
+
+ if (weights.getWeightType() == GGMLType.Q8_0) {
+ // CPU dequantizes Q8_0 embeddings into wrapXBatch (FP32)
+ ByteArray embTable = weights.getTokenEmbeddingTable().asByteArray();
+ int blockSize = 32;
+ int Q8_0_BLOCK_BYTES = 34;
+ int blocksPerRow = (dim + blockSize - 1) / blockSize;
+ for (int b = 0; b < chunkSize; b++) {
+ int tokenId = tokenIds[b];
+ for (int j = 0; j < dim; j++) {
+ int blockByteOffset = (tokenId * blocksPerRow + j / blockSize) * Q8_0_BLOCK_BYTES;
+ float scale = embTable.getHalfFloat(blockByteOffset).getFloat32();
+ float quant = embTable.get(blockByteOffset + 2 + j % blockSize);
+ state.wrapXBatch.set(b * dim + j, quant * scale);
+ }
+ }
+ } else {
+ // FP16: copy raw half-float bytes into embeddingXBatch
+ MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment();
+ int bytes = Short.BYTES;
+ for (int b = 0; b < chunkSize; b++) {
+ MemorySegment.copy(embTable, (long) tokenIds[b] * dim * bytes,
+ state.embeddingXBatch.getSegment(), (long) b * dim * bytes,
+ (long) dim * bytes);
+ }
}
state.batchStartPosHolder.set(0, startPos);
@@ -258,12 +304,19 @@ public void tornadoVMForwardBatchPrefill(int[] tokenIds, int startPos, Model mod
*/
public FloatArray tornadoVMForwardDecode(int token, int position, Model model) {
LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights();
- MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment();
- int bytes = Short.BYTES;
- int dim = config.dim();
-
- MemorySegment.copy(embTable, (long) token * dim * bytes,
- state.embeddingX.getSegment(), 0L, (long) dim * bytes);
+ int dim = config.dim();
+
+ if (weights.getWeightType() == GGMLType.Q8_0) {
+ ByteArray embTable = weights.getTokenEmbeddingTable().asByteArray();
+ int blocksPerRow = (dim + 31) / 32;
+ long bytesPerToken = (long) blocksPerRow * 34;
+ MemorySegment.copy(embTable.getSegment(), (long) token * bytesPerToken,
+ state.embeddingX.getSegment(), 0L, bytesPerToken);
+ } else {
+ MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment();
+ MemorySegment.copy(embTable, (long) token * dim * Short.BYTES,
+ state.embeddingX.getSegment(), 0L, (long) dim * Short.BYTES);
+ }
state.positionHolder.set(0, position);
state.temp.clear();
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java
index 517024f9..500882ea 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java
@@ -11,8 +11,13 @@
import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService;
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersPrefillDecode;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LogitsFP16LayerDecode;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LlamaQ8_0FFNLayersPrefillDecode;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LogitsQ8_0LayerDecode;
+import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.KernelContext;
@@ -98,7 +103,14 @@ public class TornadoVMMasterPlanWithPrefillDecode implements TornadoVMMasterPlan
* The KV cache is not managed here — it is allocated on the first forward pass
* by decode layer 0 via {@code FIRST_EXECUTION}.
*/
- private TaskGraph buildActivationGraph(KernelContext ctx) {
+ private TaskGraph buildActivationGraph(KernelContext ctx, GGMLType weightType) {
+ if (weightType == GGMLType.Q8_0) {
+ return new TaskGraph("decodeActivation")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX)
+ .task("updateX", TransformerComputeKernels::convertQ8_0toFP32,
+ ctx, (ByteArray) state.embeddingX, state.wrapX)
+ .persistOnDevice(state.wrapX);
+ }
return new TaskGraph("decodeActivation")
.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX)
.task("updateX", TransformerComputeKernels::convertFP16toFP32,
@@ -107,21 +119,11 @@ private TaskGraph buildActivationGraph(KernelContext ctx) {
}
// ── Plan construction ─────────────────────────────────────────────────────
- /**
- * Creates the {@link TornadoExecutionPlan} for forward pass with *prefill/decode separation*.
- * Prefill is token-by-token but does not compute logits.
- *
- * TODO: support Q8_0 weights
- * To implement this, consult how {@link TornadoVMMasterPlanStandard} uses the {@link QuantizationPlannerFactory}
- */
@Override
public TornadoExecutionPlan createExecutionPlan() {
GGMLType weightType = model.weights().getWeightType();
- switch (weightType) {
- case F16 -> { /* supported — continue below */ }
- case Q8_0 -> throw new UnsupportedOperationException(
- "Prefill/decode GPU path not yet implemented for Q8_0 weights");
- default -> throw new UnsupportedOperationException(
+ if (weightType != GGMLType.F16 && weightType != GGMLType.Q8_0) {
+ throw new UnsupportedOperationException(
"Prefill/decode GPU path not supported for weight type: " + weightType);
}
@@ -132,24 +134,29 @@ public TornadoExecutionPlan createExecutionPlan() {
// [0] Activation ──────────────────────────────────────────────────────
KernelContext actCtx = new KernelContext();
- all.add(buildActivationGraph(actCtx).snapshot());
+ all.add(buildActivationGraph(actCtx, weightType).snapshot());
gridScheduler.addWorkerGrid("decodeActivation.updateX",
WorkerGridFactory.genericWorker(config.dim(), 128));
// [1..N] Decode layer graphs ──────────────────────────────────────────
- // Layer 0: FIRST_EXECUTION for KV cache + consumeFromDevice("decodeActivation", wrapX).
- // Layers 1+: consumeFromDevice with explicit predecessor names for interpreter mode.
- LlamaFP16FFNLayersPrefillDecode decodeLayers =
- new LlamaFP16FFNLayersPrefillDecode("decode", state, weights, config, schedulerType);
+ AbstractFFNLayers decodeLayers;
+ if (weightType == GGMLType.F16) {
+ decodeLayers = new LlamaFP16FFNLayersPrefillDecode("decode", state, weights, config, schedulerType);
+ } else {
+ decodeLayers = new LlamaQ8_0FFNLayersPrefillDecode("decode", state, weights, config, schedulerType);
+ }
all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs());
decodeLayers.updateGridScheduler(gridScheduler);
// [N+1] Logits ────────────────────────────────────────────────────────
- // LogitsFP16LayerDecode re-persists the KV cache so the pointer survives
- // the logits → layer_0 KV-cache FIRST_EXECUTION boundary across decode tokens.
- LogitsFP16LayerDecode logitsLayer = new LogitsFP16LayerDecode(
- "logits", state, weights, config,
- decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType);
+ AbstractLogitsLayer logitsLayer;
+ if (weightType == GGMLType.F16) {
+ logitsLayer = new LogitsFP16LayerDecode("logits", state, weights, config,
+ decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType);
+ } else {
+ logitsLayer = new LogitsQ8_0LayerDecode("logits", state, weights, config,
+ decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType);
+ }
all.add(logitsLayer.getImmutableTaskGraph());
logitsLayer.updateGridScheduler(gridScheduler);
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java
index 9bba3860..6c688649 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java
@@ -5,6 +5,7 @@
import uk.ac.manchester.tornado.api.types.HalfFloat;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
/**
@@ -458,4 +459,229 @@ public static void batchedFusedRmsNormFFNGateUp(KernelContext context,
wrapHbBatch.set(batchIdx * hiddenDim + rowIdx, silu * result3);
}
}
+
+ // ── Q8_0 Batch Kernels ───────────────────────────────────────────────────
+
+ /**
+ * No-op kernel for Q8_0 batch activation graph.
+ * The host fills wrapXBatch with dequantized FP32 embeddings before execution.
+ * Worker: 1 global thread.
+ */
+ public static void batchPassthrough(KernelContext context, FloatArray wrapXBatch) {
+ if (context.globalIdx == 0) {
+ wrapXBatch.set(0, wrapXBatch.get(0));
+ }
+ }
+
+ /**
+ * Applies RMS normalization to FP32 — Q8_0 variant.
+ * Writes normalized FP32 to wrapXbBatch (reused as xb intermediate before QKV).
+ * Worker: B*dim global threads, localSize=256.
+ */
+ public static void batchedRmsApplyFP32(KernelContext context,
+ FloatArray wrapXbBatch,
+ FloatArray wrapXBatch,
+ FloatArray rmsWeights,
+ FloatArray attnScaleBatch,
+ int dim) {
+ int gid = context.globalIdx;
+ int b = gid / dim;
+ int i = gid % dim;
+ wrapXbBatch.set(gid, rmsWeights.get(i) * attnScaleBatch.get(b) * wrapXBatch.get(gid));
+ }
+
+ /**
+ * Fused batched QKV projection with Q8_0 weight dequantization.
+ * Input wrapXbBatch is FP32 (written by batchedRmsApplyFP32).
+ * groupIdx = batchIdx * (dim + 2*kvDim) + rowIdx.
+ * Worker: B*(dim+2*kvDim) workgroups × localWorkGroupSize threads.
+ */
+ public static void batchedFusedQKVMatmulQ8(KernelContext context,
+ FloatArray wrapXbBatch,
+ FloatArray wrapQBatch,
+ FloatArray wrapKBatch,
+ FloatArray wrapVBatch,
+ ByteArray wq,
+ ByteArray wk,
+ ByteArray wv,
+ int dim, int kvDim,
+ int localWorkGroupSize) {
+ int groupId = context.groupIdx;
+ int localId = context.localIdx;
+ int totalRows = dim + 2 * kvDim;
+ int batchIdx = groupId / totalRows;
+ int rowIdx = groupId % totalRows;
+ int inputOff = batchIdx * dim;
+
+ int blockSize = 32;
+ int Q8_0_BLOCK_BYTES = 34;
+ int blocksPerRow = (dim + blockSize - 1) / blockSize;
+
+ float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize);
+
+ if (rowIdx < dim) {
+ int rowBlockOffset = rowIdx * blocksPerRow;
+ float partial = 0.0f;
+ for (int j = localId; j < dim; j += localWorkGroupSize) {
+ int blockByteOffset = (rowBlockOffset + j / blockSize) * Q8_0_BLOCK_BYTES;
+ float scale = wq.getHalfFloat(blockByteOffset).getFloat32();
+ float quant = wq.get(blockByteOffset + 2 + j % blockSize);
+ partial += quant * scale * wrapXbBatch.get(inputOff + j);
+ }
+ localSum[localId] = partial;
+ context.localBarrier();
+ for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) {
+ if (localId < s) localSum[localId] += localSum[localId + s];
+ context.localBarrier();
+ }
+ if (localId == 0) wrapQBatch.set(batchIdx * dim + rowIdx, localSum[0]);
+
+ } else if (rowIdx < dim + kvDim) {
+ int kRow = rowIdx - dim;
+ int rowBlockOffset = kRow * blocksPerRow;
+ float partial = 0.0f;
+ for (int j = localId; j < dim; j += localWorkGroupSize) {
+ int blockByteOffset = (rowBlockOffset + j / blockSize) * Q8_0_BLOCK_BYTES;
+ float scale = wk.getHalfFloat(blockByteOffset).getFloat32();
+ float quant = wk.get(blockByteOffset + 2 + j % blockSize);
+ partial += quant * scale * wrapXbBatch.get(inputOff + j);
+ }
+ localSum[localId] = partial;
+ context.localBarrier();
+ for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) {
+ if (localId < s) localSum[localId] += localSum[localId + s];
+ context.localBarrier();
+ }
+ if (localId == 0) wrapKBatch.set(batchIdx * kvDim + kRow, localSum[0]);
+
+ } else {
+ int vRow = rowIdx - dim - kvDim;
+ int rowBlockOffset = vRow * blocksPerRow;
+ float partial = 0.0f;
+ for (int j = localId; j < dim; j += localWorkGroupSize) {
+ int blockByteOffset = (rowBlockOffset + j / blockSize) * Q8_0_BLOCK_BYTES;
+ float scale = wv.getHalfFloat(blockByteOffset).getFloat32();
+ float quant = wv.get(blockByteOffset + 2 + j % blockSize);
+ partial += quant * scale * wrapXbBatch.get(inputOff + j);
+ }
+ localSum[localId] = partial;
+ context.localBarrier();
+ for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) {
+ if (localId < s) localSum[localId] += localSum[localId + s];
+ context.localBarrier();
+ }
+ if (localId == 0) wrapVBatch.set(batchIdx * kvDim + vRow, localSum[0]);
+ }
+ }
+
+ /**
+ * Batched matrix-vector multiply with residual add (Q8_0 weights).
+ * Used for attention output (Wo) and FFN down (W2) projections.
+ * groupIdx = batchIdx * d + rowIdx.
+ * Worker: B*d workgroups × localWorkGroupSize threads.
+ */
+ public static void batchedMatVecWithResidualQ8(KernelContext context,
+ FloatArray inputBatch,
+ FloatArray outputBatch,
+ ByteArray w,
+ int n, int d,
+ int localWorkGroupSize) {
+ int groupId = context.groupIdx;
+ int localId = context.localIdx;
+ int batchIdx = groupId / d;
+ int rowIdx = groupId % d;
+
+ int blockSize = 32;
+ int Q8_0_BLOCK_BYTES = 34;
+ int blocksPerRow = (n + blockSize - 1) / blockSize;
+ int rowBlockOffset = rowIdx * blocksPerRow;
+ int inputOff = batchIdx * n;
+
+ float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize);
+
+ float partial = 0.0f;
+ for (int j = localId; j < n; j += localWorkGroupSize) {
+ int blockByteOffset = (rowBlockOffset + j / blockSize) * Q8_0_BLOCK_BYTES;
+ float scale = w.getHalfFloat(blockByteOffset).getFloat32();
+ float quant = w.get(blockByteOffset + 2 + j % blockSize);
+ partial += quant * scale * inputBatch.get(inputOff + j);
+ }
+ localSum[localId] = partial;
+ context.localBarrier();
+ for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) {
+ if (localId < s) localSum[localId] += localSum[localId + s];
+ context.localBarrier();
+ }
+ if (localId == 0) {
+ int outIdx = batchIdx * d + rowIdx;
+ outputBatch.set(outIdx, outputBatch.get(outIdx) + localSum[0]);
+ }
+ }
+
+ /**
+ * Batched fused RMS-apply + W1/W3 gate-up projections + SiLU + GLU (Q8_0 weights).
+ * groupIdx = batchIdx * hiddenDim + rowIdx.
+ * Worker: B*hiddenDim workgroups × localWorkGroupSize threads.
+ */
+ public static void batchedFusedRmsNormFFNGateUpQ8(KernelContext context,
+ FloatArray wrapXBatch,
+ FloatArray wrapHbBatch,
+ FloatArray rmsFFNWeights,
+ FloatArray ffnScaleBatch,
+ ByteArray w1,
+ ByteArray w3,
+ int dim, int hiddenDim,
+ int localWorkGroupSize) {
+ int groupId = context.groupIdx;
+ int localId = context.localIdx;
+ int batchIdx = groupId / hiddenDim;
+ int rowIdx = groupId % hiddenDim;
+
+ float scale = ffnScaleBatch.get(batchIdx);
+ int inputOff = batchIdx * dim;
+
+ int blockSize = 32;
+ int Q8_0_BLOCK_BYTES = 34;
+ int blocksPerRow = (dim + blockSize - 1) / blockSize;
+ int rowBlockOffset = rowIdx * blocksPerRow;
+
+ float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize);
+
+ float sum1 = 0.0f;
+ for (int j = localId; j < dim; j += localWorkGroupSize) {
+ int blockByteOffset = (rowBlockOffset + j / blockSize) * Q8_0_BLOCK_BYTES;
+ float w1Scale = w1.getHalfFloat(blockByteOffset).getFloat32();
+ float w1Quant = w1.get(blockByteOffset + 2 + j % blockSize);
+ float normed = rmsFFNWeights.get(j) * scale * wrapXBatch.get(inputOff + j);
+ sum1 += w1Quant * w1Scale * normed;
+ }
+ localSum[localId] = sum1;
+ context.localBarrier();
+ for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) {
+ if (localId < s) localSum[localId] += localSum[localId + s];
+ context.localBarrier();
+ }
+ float result1 = localSum[0];
+
+ float sum3 = 0.0f;
+ for (int j = localId; j < dim; j += localWorkGroupSize) {
+ int blockByteOffset = (rowBlockOffset + j / blockSize) * Q8_0_BLOCK_BYTES;
+ float w3Scale = w3.getHalfFloat(blockByteOffset).getFloat32();
+ float w3Quant = w3.get(blockByteOffset + 2 + j % blockSize);
+ float normed = rmsFFNWeights.get(j) * scale * wrapXBatch.get(inputOff + j);
+ sum3 += w3Quant * w3Scale * normed;
+ }
+ localSum[localId] = sum3;
+ context.localBarrier();
+ for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) {
+ if (localId < s) localSum[localId] += localSum[localId + s];
+ context.localBarrier();
+ }
+ float result3 = localSum[0];
+
+ if (localId == 0) {
+ float silu = result1 / (1.0f + TornadoMath.exp(-result1));
+ wrapHbBatch.set(batchIdx * hiddenDim + rowIdx, silu * result3);
+ }
+ }
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java
index 1f33f090..2cfacdc0 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java
@@ -113,7 +113,12 @@ protected TaskGraph createFFNLayerTaskGraph(int layerIndex) {
TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName);
// === Data Setup ===
- unifiedLayer.consumeFromDevice(state.wrapX);
+ String wrapXSrc = predecessorGraphName(layerIndex);
+ if (wrapXSrc != null) {
+ unifiedLayer.consumeFromDevice(wrapXSrc, state.wrapX);
+ } else {
+ unifiedLayer.consumeFromDevice(state.wrapX);
+ }
unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
// Copy-in weights per layer for batched-layered layout (Q8 format)
weights.rms_att_weightLayered[layerIndex].asFloatArray(),
@@ -224,6 +229,10 @@ protected TaskGraph createFFNLayerTaskGraph(int layerIndex) {
return unifiedLayer;
}
+ protected String predecessorGraphName(int layerIndex) {
+ return null;
+ }
+
protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) {
// First layer: Transfer initial data to device (one-time transfer)
if (layerIndex == 0) {
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java
index ed0d6cfc..1a1313e2 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java
@@ -22,12 +22,16 @@ public LogitsQ8_0Layer(String name, State state, Weights weights, Configuration
super(name, state, weights, config, lastTaskGraphID, schedulerType);
}
+ protected void configureAdditionalConsumes(TaskGraph logits) {}
+ protected void configureAdditionalPersists(TaskGraph logits) {}
+
// @formatter:off
@Override
protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) {
var logits = new TaskGraph("logits");
// === Data Setup ===
+ configureAdditionalConsumes(logits);
logits.consumeFromDevice(lastTaskGraphID, state.wrapX);
logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits);
logits.transferToDevice(DataTransferMode.FIRST_EXECUTION,
@@ -74,6 +78,7 @@ protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration c
LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS);
logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
+ configureAdditionalPersists(logits);
return logits;
}
// @formatter:on
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java
new file mode 100644
index 00000000..e159c99b
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java
@@ -0,0 +1,55 @@
+package org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode;
+
+import org.beehive.gpullama3.inference.state.LlamaState;
+import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
+import org.beehive.gpullama3.model.llama.LlamaConfiguration;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LlamaQ8_0FFNLayers;
+import uk.ac.manchester.tornado.api.TaskGraph;
+import uk.ac.manchester.tornado.api.enums.DataTransferMode;
+
+/**
+ * Decode FFN layers for the unified batched prefill-decode plan
+ * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode}).
+ *
+ * Layer 0 consumes the KV cache from device (passed through by the decode activation
+ * graph, which relays it from the last batch prefill layer). No FIRST_EXECUTION allocation
+ * for the KV cache — it was already allocated in the batch prefill phase.
+ */
+public class LlamaQ8_0FFNLayersDecode extends LlamaQ8_0FFNLayers {
+
+ public LlamaQ8_0FFNLayersDecode(String taskGraph, LlamaState state,
+ LlamaTornadoWeights weights, LlamaConfiguration config,
+ SchedulerType schedulerType) {
+ super(taskGraph, state, weights, config, schedulerType);
+ }
+
+ @Override
+ protected String predecessorGraphName(int layerIndex) {
+ return (layerIndex == 0) ? "decodeActivation" : "layer_" + (layerIndex - 1);
+ }
+
+ @Override
+ protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) {
+ if (layerIndex == 0) {
+ layer.transferToDevice(DataTransferMode.EVERY_EXECUTION,
+ state.positionHolder, state.temp, state.tempFFN);
+ layer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
+ context,
+ state.wrapXb, state.wrapXb2,
+ state.wrapQ, state.wrapK, state.wrapV,
+ state.wrapAtt, state.wrapHb);
+ layer.consumeFromDevice("decodeActivation", state.wrapKeyCache, state.wrapValueCache);
+ } else {
+ String pred = "layer_" + (layerIndex - 1);
+ layer.consumeFromDevice(pred,
+ context,
+ state.wrapXb, state.wrapXb2,
+ state.wrapQ, state.wrapK, state.wrapV,
+ state.wrapKeyCache, state.wrapValueCache,
+ state.wrapAtt, state.wrapHb,
+ state.positionHolder);
+ }
+ return layer;
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java
new file mode 100644
index 00000000..a1f02ada
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java
@@ -0,0 +1,47 @@
+package org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode;
+
+import org.beehive.gpullama3.inference.state.LlamaState;
+import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
+import org.beehive.gpullama3.model.llama.LlamaConfiguration;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LlamaQ8_0FFNLayers;
+import uk.ac.manchester.tornado.api.TaskGraph;
+
+/**
+ * Decode FFN layers for the single-token prefill/decode plan
+ * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode}).
+ *
+ * Layer 0 delegates to {@link LlamaQ8_0FFNLayers#configureLayerDataTransfers} which
+ * includes {@code FIRST_EXECUTION} for {@code wrapKeyCache} and {@code wrapValueCache},
+ * allocating the KV cache on the very first forward pass. Layers 1+ use explicit
+ * predecessor names for all consumed objects, required by TornadoVM's interpreter mode.
+ */
+public class LlamaQ8_0FFNLayersPrefillDecode extends LlamaQ8_0FFNLayers {
+
+ public LlamaQ8_0FFNLayersPrefillDecode(String taskGraph, LlamaState state,
+ LlamaTornadoWeights weights, LlamaConfiguration config,
+ SchedulerType schedulerType) {
+ super(taskGraph, state, weights, config, schedulerType);
+ }
+
+ @Override
+ protected String predecessorGraphName(int layerIndex) {
+ return (layerIndex == 0) ? "decodeActivation" : "layer_" + (layerIndex - 1);
+ }
+
+ @Override
+ protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) {
+ if (layerIndex == 0) {
+ return super.configureLayerDataTransfers(layer, 0);
+ }
+ String pred = "layer_" + (layerIndex - 1);
+ layer.consumeFromDevice(pred,
+ context,
+ state.wrapXb, state.wrapXb2,
+ state.wrapQ, state.wrapK, state.wrapV,
+ state.wrapKeyCache, state.wrapValueCache,
+ state.wrapAtt, state.wrapHb,
+ state.positionHolder);
+ return layer;
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LogitsQ8_0LayerDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LogitsQ8_0LayerDecode.java
new file mode 100644
index 00000000..054b7f7f
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LogitsQ8_0LayerDecode.java
@@ -0,0 +1,35 @@
+package org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode;
+
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.inference.weights.Weights;
+import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
+import uk.ac.manchester.tornado.api.TaskGraph;
+
+/**
+ * Logits layer for the unified batched prefill-decode plan (Q8_0).
+ *
+ * Extends {@link LogitsQ8_0Layer} with KV-cache pass-through so the device pointers for
+ * {@code wrapKeyCache} and {@code wrapValueCache} survive the logits → decode-activation
+ * boundary between decode tokens. Without the pass-through, the KV-cache pointer is absent
+ * from the logits persisted set, cleared to null, and the first decode layer crashes with
+ * an NPE in {@code executeAlloc}.
+ */
+public class LogitsQ8_0LayerDecode extends LogitsQ8_0Layer {
+
+ public LogitsQ8_0LayerDecode(String name, State state, Weights weights, Configuration config,
+ String lastTaskGraphID, SchedulerType schedulerType) {
+ super(name, state, weights, config, lastTaskGraphID, schedulerType);
+ }
+
+ @Override
+ protected void configureAdditionalConsumes(TaskGraph logits) {
+ logits.consumeFromDevice(lastTaskGraphID, state.wrapKeyCache, state.wrapValueCache);
+ }
+
+ @Override
+ protected void configureAdditionalPersists(TaskGraph logits) {
+ logits.persistOnDevice(state.wrapKeyCache, state.wrapValueCache);
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java
new file mode 100644
index 00000000..79ea9bc2
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java
@@ -0,0 +1,230 @@
+package org.beehive.gpullama3.tornadovm.layers.type.q8_0.prefill;
+
+import org.beehive.gpullama3.inference.state.LlamaState;
+import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
+import org.beehive.gpullama3.model.llama.LlamaConfiguration;
+import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels;
+import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+import uk.ac.manchester.tornado.api.KernelContext;
+import uk.ac.manchester.tornado.api.TaskGraph;
+import uk.ac.manchester.tornado.api.WorkerGrid;
+import uk.ac.manchester.tornado.api.enums.DataTransferMode;
+
+import java.util.List;
+import java.util.stream.IntStream;
+
+/**
+ * Prefill FFN layers with batching for the unified batched prefill-decode plan (Q8_0).
+ *
+ * Mirrors {@link org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16LayersBatchPrefill}
+ * but uses Q8_0 kernels with inline dequantization. Key differences from the FP16 path:
+ *
+ * - {@code wrapXBatch} is filled with dequantized FP32 embeddings by the host before
+ * the activation graph runs (no on-device FP16→FP32 conversion).
+ * - {@code wrapXbBatch} (FP32) is reused as the normalized xb intermediate: written
+ * by {@code batchedRmsApplyFP32}, read by {@code batchedFusedQKVMatmulQ8}, then
+ * overwritten by flash attention output.
+ * - {@code wrapXbFP16Batch} is not used.
+ * - Weight matrices are {@code ByteArray} (Q8_0 format).
+ *
+ */
+public class LlamaQ8_0LayersBatchPrefill {
+
+ static final int LOCAL_WORK_GROUP_SIZE = 32;
+
+ private final LlamaState state;
+ private final LlamaTornadoWeights weights;
+ private final LlamaConfiguration config;
+ private final KernelContext context = new KernelContext();
+ private final int batchSize;
+ private final List layerITGs;
+ private String lastLayerTaskGraphID;
+
+ public LlamaQ8_0LayersBatchPrefill(LlamaState state, LlamaTornadoWeights weights,
+ LlamaConfiguration config, int batchSize) {
+ this.state = state;
+ this.weights = weights;
+ this.config = config;
+ this.batchSize = batchSize;
+ this.layerITGs = IntStream.range(0, config.numberOfLayers())
+ .mapToObj(this::createBatchPrefillLayerTaskGraph)
+ .map(TaskGraph::snapshot)
+ .toList();
+ }
+
+ // @formatter:off
+ private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) {
+ String graphName = "batchPrefillLayer_" + layerIndex;
+ if (layerIndex == config.numberOfLayers() - 1) lastLayerTaskGraphID = graphName;
+
+ TaskGraph layer = new TaskGraph(graphName);
+
+ // ── Data Transfers ─────────────────────────────────────────────────────
+ if (layerIndex == 0) {
+ // batchStartPosHolder is set by host before each chunk → EVERY_EXECUTION
+ layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.batchStartPosHolder);
+ // Allocate GPU-side batch intermediates once.
+ // wrapXBatch is filled with dequantized FP32 by the host, persisted by prefillActivation.
+ layer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
+ context,
+ state.attnScaleBatch, state.ffnScaleBatch,
+ state.wrapXbBatch,
+ state.wrapQBatch, state.wrapKBatch, state.wrapVBatch,
+ state.wrapHbBatch,
+ state.wrapKeyCache, state.wrapValueCache);
+ layer.consumeFromDevice("prefillActivation", state.wrapXBatch);
+ } else {
+ String pred = "batchPrefillLayer_" + (layerIndex - 1);
+ layer.consumeFromDevice(pred,
+ context,
+ state.wrapXBatch,
+ state.wrapXbBatch,
+ state.wrapQBatch, state.wrapKBatch, state.wrapVBatch,
+ state.wrapHbBatch,
+ state.wrapKeyCache, state.wrapValueCache,
+ state.batchStartPosHolder,
+ state.attnScaleBatch, state.ffnScaleBatch);
+ }
+
+ // Per-layer weights: upload once (Q8_0 format)
+ layer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
+ weights.rms_att_weightLayered[layerIndex].asFloatArray(),
+ weights.wqLayered[layerIndex].asByteArray(),
+ weights.wkLayered[layerIndex].asByteArray(),
+ weights.wvLayered[layerIndex].asByteArray(),
+ weights.woLayered[layerIndex].asByteArray(),
+ weights.rms_ffn_weightLayered[layerIndex].asFloatArray(),
+ weights.w1Layered[layerIndex].asByteArray(),
+ weights.w2Layered[layerIndex].asByteArray(),
+ weights.w3Layered[layerIndex].asByteArray());
+
+ int dim = config.dim();
+ int kvDim = config.kvDim();
+ int hidDim = config.hiddenDim();
+
+ // ── Attention Block ────────────────────────────────────────────────────
+ layer.task("batch_attn_rms",
+ TransformerBatchPrefillKernels::batchedRmsReduce,
+ context, state.wrapXBatch, state.attnScaleBatch,
+ dim, config.rmsNormEps());
+
+ // Writes FP32 normalized xb into wrapXbBatch (reused later by flash attention)
+ layer.task("batch_attn_rms_apply",
+ TransformerBatchPrefillKernels::batchedRmsApplyFP32,
+ context, state.wrapXbBatch, state.wrapXBatch,
+ weights.rms_att_weightLayered[layerIndex].asFloatArray(),
+ state.attnScaleBatch, dim);
+
+ layer.task("batch_qkv",
+ TransformerBatchPrefillKernels::batchedFusedQKVMatmulQ8,
+ context,
+ state.wrapXbBatch,
+ state.wrapQBatch, state.wrapKBatch, state.wrapVBatch,
+ weights.wqLayered[layerIndex].asByteArray(),
+ weights.wkLayered[layerIndex].asByteArray(),
+ weights.wvLayered[layerIndex].asByteArray(),
+ dim, kvDim, LOCAL_WORK_GROUP_SIZE);
+
+ layer.task("batch_rope_kv",
+ TransformerBatchPrefillKernels::batchedRopeWithKVCache,
+ context, state.batchStartPosHolder,
+ state.wrapQBatch, state.wrapKBatch, state.wrapVBatch,
+ state.wrapKeyCache, state.wrapValueCache,
+ kvDim, config.headSize(), layerIndex, config.contextLength(), dim);
+
+ // Overwrites wrapXbBatch with attention output
+ layer.task("batch_attention",
+ TransformerBatchPrefillKernels::batchedFlashAttention,
+ context, state.batchStartPosHolder,
+ state.wrapQBatch, state.wrapKeyCache, state.wrapValueCache,
+ state.wrapXbBatch,
+ config.numberOfHeads(), config.headSize(),
+ kvDim, config.kvMul(), layerIndex, config.contextLength(), dim);
+
+ layer.task("batch_attn_out",
+ TransformerBatchPrefillKernels::batchedMatVecWithResidualQ8,
+ context, state.wrapXbBatch, state.wrapXBatch,
+ weights.woLayered[layerIndex].asByteArray(),
+ dim, dim, LOCAL_WORK_GROUP_SIZE);
+
+ // ── FFN Block ──────────────────────────────────────────────────────────
+ layer.task("batch_ffn_rms",
+ TransformerBatchPrefillKernels::batchedFFNRmsReduce,
+ context, state.wrapXBatch, state.ffnScaleBatch,
+ dim, config.rmsNormEps());
+
+ layer.task("batch_ffn_gate_up",
+ TransformerBatchPrefillKernels::batchedFusedRmsNormFFNGateUpQ8,
+ context, state.wrapXBatch, state.wrapHbBatch,
+ weights.rms_ffn_weightLayered[layerIndex].asFloatArray(),
+ state.ffnScaleBatch,
+ weights.w1Layered[layerIndex].asByteArray(),
+ weights.w3Layered[layerIndex].asByteArray(),
+ dim, hidDim, LOCAL_WORK_GROUP_SIZE);
+
+ layer.task("batch_ffn_down",
+ TransformerBatchPrefillKernels::batchedMatVecWithResidualQ8,
+ context, state.wrapHbBatch, state.wrapXBatch,
+ weights.w2Layered[layerIndex].asByteArray(),
+ hidDim, dim, LOCAL_WORK_GROUP_SIZE);
+
+ layer.persistOnDevice(state.wrapXBatch, state.wrapKeyCache, state.wrapValueCache);
+
+ return layer;
+ }
+ // @formatter:on
+
+ public void updateGridScheduler(GridScheduler scheduler) {
+ int dim = config.dim();
+ int kvDim = config.kvDim();
+ int hidDim = config.hiddenDim();
+ int nHeads = config.numberOfHeads();
+ int headSz = config.headSize();
+
+ WorkerGrid rmsWorker = WorkerGridFactory.genericWorker(batchSize, 1);
+ WorkerGrid rmsApplyWorker = WorkerGridFactory.genericWorker(batchSize * dim, 256);
+ int qkvRows = dim + 2 * kvDim;
+ WorkerGrid qkvWorker = WorkerGridFactory.genericWorker(
+ batchSize * qkvRows * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE);
+ int ropeGlobal = batchSize * (dim / 2);
+ int ropeLocal = Math.min(512, ropeGlobal);
+ while (ropeLocal > 1 && ropeGlobal % ropeLocal != 0) ropeLocal--;
+ WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(ropeGlobal, ropeLocal);
+ int optLocal = findOptimalLocalSize(headSz);
+ WorkerGrid attnWorker = WorkerGridFactory.genericWorker(
+ batchSize * nHeads * optLocal, optLocal);
+ WorkerGrid matVecDimWorker = WorkerGridFactory.genericWorker(
+ batchSize * dim * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE);
+ WorkerGrid matVecHidWorker = WorkerGridFactory.genericWorker(
+ batchSize * hidDim * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE);
+
+ for (int i = 0; i < config.numberOfLayers(); i++) {
+ String p = "batchPrefillLayer_" + i + ".";
+ scheduler.addWorkerGrid(p + "batch_attn_rms", rmsWorker);
+ scheduler.addWorkerGrid(p + "batch_attn_rms_apply", rmsApplyWorker);
+ scheduler.addWorkerGrid(p + "batch_qkv", qkvWorker);
+ scheduler.addWorkerGrid(p + "batch_rope_kv", ropeWorker);
+ scheduler.addWorkerGrid(p + "batch_attention", attnWorker);
+ scheduler.addWorkerGrid(p + "batch_attn_out", matVecDimWorker);
+ scheduler.addWorkerGrid(p + "batch_ffn_rms", rmsWorker);
+ scheduler.addWorkerGrid(p + "batch_ffn_gate_up", matVecHidWorker);
+ scheduler.addWorkerGrid(p + "batch_ffn_down", matVecDimWorker);
+ }
+ }
+
+ private static int findOptimalLocalSize(int size) {
+ int optimal = Math.min(size, 64);
+ if (size % optimal != 0) {
+ for (int s = 64; s >= 1; s--) {
+ if (size % s == 0) { optimal = s; break; }
+ }
+ }
+ return optimal;
+ }
+
+ public List getLayerImmutableTaskGraphs() { return layerITGs; }
+ public String getLastLayerTaskGraphID() { return lastLayerTaskGraphID; }
+ public KernelContext getContext() { return context; }
+}
From 8ebf91f2dd437d1dd6466e3b2c4bb4b37ba7fbfe Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Thu, 28 May 2026 15:14:01 +0300
Subject: [PATCH 2/6] Reorganize TornadoVM execution planning and improve
naming conventions
Reorganize TornadoVM execution planning around forward modes, model families, and quantization-specific components.
---
.../gpullama3/inference/InferenceCore.java | 2 +-
.../InferenceCoreBatchPrefillDecode.java | 14 +-
.../InferenceCoreWithPrefillDecode.java | 6 +-
...InferenceEngineWithBatchPrefillDecode.java | 6 +-
.../InferenceEngineWithPrefillDecode.java | 6 +-
.../tornadovm/TornadoVMMasterPlan.java | 24 +-
...TornadoVMMasterPlanBatchPrefillDecode.java | 165 ++++++++
.../TornadoVMMasterPlanPrefillDecode.java | 166 ++++++++
...va => TornadoVMMasterPlanSingleToken.java} | 59 +--
...adoVMMasterPlanWithBatchPrefillDecode.java | 360 ------------------
.../TornadoVMMasterPlanWithPrefillDecode.java | 244 ------------
.../TransformerBatchPrefillKernels.java | 2 +-
.../layerplanner/GenericLayerPlanner.java | 14 -
.../QuantizationPlannerFactory.java | 97 -----
.../layerplanner/QuantizedLayerPlanner.java | 103 -----
.../model/fp16/DevstralFP16LayerPlanner.java | 20 -
.../model/fp16/FP16LayerPlanner.java | 25 --
.../model/fp16/GraniteFP16LayerPlanner.java | 20 -
.../model/fp16/LlamaFP16LayerPlanner.java | 20 -
.../model/fp16/MistralFP16LayerPlanner.java | 20 -
.../model/fp16/Phi3FP16LayerPlanner.java | 27 --
.../model/fp16/Qwen2FP16LayerPlanner.java | 27 --
.../model/fp16/Qwen3FP16LayerPlanner.java | 27 --
.../model/q8_0/DevstralQ8_0LayerPlanner.java | 20 -
.../model/q8_0/GraniteQ8_0LayerPlanner.java | 20 -
.../model/q8_0/LlamaQ8_0LayerPlanner.java | 20 -
.../model/q8_0/MistralQ8_0LayerPlanner.java | 20 -
.../model/q8_0/Phi3Q8_0LayerPlanner.java | 28 --
.../model/q8_0/Q8_0LayerPlanner.java | 26 --
.../model/q8_0/Qwen2Q8_0LayerPlanner.java | 28 --
.../model/q8_0/Qwen3Q8_0LayerPlanner.java | 28 --
.../layerplanner/strategy/SchedulerType.java | 5 -
.../tornadovm/layers/AbstractFFNLayers.java | 4 +-
.../tornadovm/layers/AbstractLogitsLayer.java | 2 +-
.../tornadovm/layers/Activation.java | 4 +-
.../tornadovm/layers/ActivationGraph.java | 15 +
...atchPrefillTransformerLayerTaskGraphs.java | 17 +
.../layers/TransformerLayerTaskGraphs.java | 17 +
.../type/fp16/DevstralFP16FFNLayers.java | 4 +-
.../type/fp16/GraniteFP16FFNLayers.java | 4 +-
.../layers/type/fp16/LlamaFP16FFNLayers.java | 4 +-
.../layers/type/fp16/LogitsFP16Layer.java | 4 +-
.../type/fp16/LogitsGraniteFP16Layer.java | 2 +-
.../type/fp16/MistralFP16FFNLayers.java | 4 +-
.../layers/type/fp16/Phi3FP16FFNLayers.java | 4 +-
.../layers/type/fp16/Qwen2FP16FFNLayers.java | 4 +-
.../layers/type/fp16/Qwen3FP16FFNLayers.java | 4 +-
.../fp16/decode/LlamaFP16FFNLayersDecode.java | 4 +-
.../LlamaFP16FFNLayersPrefillDecode.java | 4 +-
.../fp16/decode/LogitsFP16LayerDecode.java | 4 +-
.../prefill/LlamaFP16LayersBatchPrefill.java | 7 +-
.../type/q8_0/DevstralQ8_0FFNLayers.java | 4 +-
.../type/q8_0/GraniteQ8_0FFNLayers.java | 4 +-
.../layers/type/q8_0/LlamaQ8_0FFNLayers.java | 4 +-
.../type/q8_0/LogitsGraniteQ8_0Layer.java | 2 +-
.../layers/type/q8_0/LogitsQ8_0Layer.java | 4 +-
.../type/q8_0/MistralQ8_0FFNLayers.java | 4 +-
.../layers/type/q8_0/Phi3Q8_0FFNLayers.java | 4 +-
.../layers/type/q8_0/Qwen2Q8_0FFNLayers.java | 4 +-
.../layers/type/q8_0/Qwen3Q8_0FFNLayers.java | 4 +-
.../q8_0/decode/LlamaQ8_0FFNLayersDecode.java | 4 +-
.../LlamaQ8_0FFNLayersPrefillDecode.java | 4 +-
.../q8_0/decode/LogitsQ8_0LayerDecode.java | 2 +-
.../prefill/LlamaQ8_0LayersBatchPrefill.java | 5 +-
.../plan/BatchPrefillDecodeForwardPlan.java | 75 ++++
.../tornadovm/plan/ExecutionMode.java | 16 +
.../gpullama3/tornadovm/plan/ForwardPlan.java | 32 ++
.../tornadovm/plan/ForwardPlanFactory.java | 207 ++++++++++
.../plan/PrefillDecodeForwardPlan.java | 57 +++
.../plan/SingleTokenForwardPlan.java | 54 +++
...tchPrefillDecodeForwardPlanComponents.java | 25 ++
.../plan/components/EmbeddingPreparer.java | 18 +
.../PrefillDecodeForwardPlanComponents.java | 20 +
.../SingleTokenForwardPlanComponents.java | 21 +
.../activation/BatchDecodeActivation.java | 62 +++
.../activation/BatchPrefillActivation.java | 72 ++++
.../fp16/DevstralFP16PlanComponents.java | 42 ++
.../fp16/GraniteFP16PlanComponents.java | 42 ++
.../fp16/LlamaFP16PlanComponents.java | 120 ++++++
.../fp16/MistralFP16PlanComponents.java | 42 ++
.../fp16/Phi3FP16PlanComponents.java | 42 ++
.../fp16/Qwen2FP16PlanComponents.java | 42 ++
.../fp16/Qwen3FP16PlanComponents.java | 42 ++
.../q8_0/DevstralQ8_0PlanComponents.java | 42 ++
.../q8_0/GraniteQ8_0PlanComponents.java | 42 ++
.../q8_0/LlamaQ8_0PlanComponents.java | 131 +++++++
.../q8_0/MistralQ8_0PlanComponents.java | 42 ++
.../q8_0/Phi3Q8_0PlanComponents.java | 42 ++
.../q8_0/Qwen2Q8_0PlanComponents.java | 42 ++
.../q8_0/Qwen3Q8_0PlanComponents.java | 42 ++
...chPrefillDecodeForwardTaskGraphLayout.java | 21 +
.../PrefillDecodeForwardTaskGraphLayout.java | 17 +
.../SingleTokenForwardTaskGraphLayout.java | 17 +
.../SchedulerDetectionService.java | 5 +-
.../tornadovm/scheduling/SchedulerType.java | 5 +
.../WorkerGridFactory.java | 3 +-
96 files changed, 1965 insertions(+), 1327 deletions(-)
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefillDecode.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanPrefillDecode.java
rename src/main/java/org/beehive/gpullama3/tornadovm/{TornadoVMMasterPlanStandard.java => TornadoVMMasterPlanSingleToken.java} (51%)
delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java
delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java
delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/GenericLayerPlanner.java
delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java
delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizedLayerPlanner.java
delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/DevstralFP16LayerPlanner.java
delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/FP16LayerPlanner.java
delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java
delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java
delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/MistralFP16LayerPlanner.java
delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Phi3FP16LayerPlanner.java
delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen2FP16LayerPlanner.java
delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen3FP16LayerPlanner.java
delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/DevstralQ8_0LayerPlanner.java
delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/GraniteQ8_0LayerPlanner.java
delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/LlamaQ8_0LayerPlanner.java
delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/MistralQ8_0LayerPlanner.java
delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java
delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Q8_0LayerPlanner.java
delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen2Q8_0LayerPlanner.java
delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen3Q8_0LayerPlanner.java
delete mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/strategy/SchedulerType.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGraph.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/BatchPrefillTransformerLayerTaskGraphs.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/TransformerLayerTaskGraphs.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/ExecutionMode.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlan.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlanFactory.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/SingleTokenForwardPlan.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/BatchPrefillDecodeForwardPlanComponents.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/EmbeddingPreparer.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/BatchPrefillDecodeForwardTaskGraphLayout.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/PrefillDecodeForwardTaskGraphLayout.java
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/SingleTokenForwardTaskGraphLayout.java
rename src/main/java/org/beehive/gpullama3/tornadovm/{layerplanner/strategy => scheduling}/SchedulerDetectionService.java (93%)
create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/scheduling/SchedulerType.java
rename src/main/java/org/beehive/gpullama3/tornadovm/{layerplanner => scheduling}/WorkerGridFactory.java (98%)
diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java
index 9beade35..67e862ce 100644
--- a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java
+++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java
@@ -814,7 +814,7 @@ public static FloatArray forwardTornadoVM(Model model, State state, int token, i
default -> throw new IllegalArgumentException("Unsupported weight type: " + weights.getWeightType());
}
- return tornadoVMMasterPlan.tornadoVMForwardExecuteLayered(position);
+ return tornadoVMMasterPlan.tornadoVMExecuteForward(position);
}
}
diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreBatchPrefillDecode.java
index d3c74599..0e89592d 100644
--- a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreBatchPrefillDecode.java
+++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreBatchPrefillDecode.java
@@ -7,7 +7,7 @@
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor;
import org.beehive.gpullama3.tensor.standard.FloatTensor;
-import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode;
+import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
/**
@@ -21,9 +21,9 @@
* prompt tokens in one pass using batch matmul, avoiding redundant weight
* traversals. Only the KV cache is populated; logits are intentionally omitted.
* {@link #batchForwardTornadoVMPrefill} — GPU batch prefill: delegates the chunk
- * to {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardBatchPrefill}.
+ * to {@link TornadoVMMasterPlanBatchPrefillDecode#tornadoVMForwardBatchPrefill}.
* {@link #forwardTornadoVMDecode} — GPU decode: delegates a single decode step to
- * {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardDecode}, which
+ * {@link TornadoVMMasterPlanBatchPrefillDecode#tornadoVMForwardDecode}, which
* handles the embedding copy and runs the full decode + logits graphs.
*
*/
@@ -165,7 +165,7 @@ public static void batchForwardJavaPrefill(Model model, State state, int[] token
* GPU batched prefill forward pass (Phase 4).
*
* Delegates the full chunk to
- * {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardBatchPrefill},
+ * {@link TornadoVMMasterPlanBatchPrefillDecode#tornadoVMForwardBatchPrefill},
* which handles embedding lookup and GPU execution internally.
*
* @param model the LLaMA model
@@ -175,7 +175,7 @@ public static void batchForwardJavaPrefill(Model model, State state, int[] token
* @param plan the batched prefill/decode GPU plan
*/
public static void batchForwardTornadoVMPrefill(Model model, int[] tokens, int startPos, int chunkSize,
- TornadoVMMasterPlanWithBatchPrefillDecode plan) {
+ TornadoVMMasterPlanBatchPrefillDecode plan) {
plan.tornadoVMForwardBatchPrefill(tokens, startPos, model, chunkSize);
}
@@ -183,7 +183,7 @@ public static void batchForwardTornadoVMPrefill(Model model, int[] tokens, int s
* GPU decode forward pass (Phase 4).
*
* Delegates a single-token decode step to
- * {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardDecode},
+ * {@link TornadoVMMasterPlanBatchPrefillDecode#tornadoVMForwardDecode},
* which copies the token embedding and runs the decode + logits graphs.
*
* @param model the LLaMA model
@@ -193,7 +193,7 @@ public static void batchForwardTornadoVMPrefill(Model model, int[] tokens, int s
* @return logits array for token sampling
*/
public static FloatArray forwardTornadoVMDecode(Model model, int token, int position,
- TornadoVMMasterPlanWithBatchPrefillDecode plan) {
+ TornadoVMMasterPlanBatchPrefillDecode plan) {
return plan.tornadoVMForwardDecode(token, position, model);
}
}
diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java
index dbd81297..9c812b10 100644
--- a/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java
+++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java
@@ -7,7 +7,7 @@
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.tensor.standard.FloatTensor;
-import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode;
+import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanPrefillDecode;
import java.lang.foreign.MemorySegment;
@@ -131,7 +131,7 @@ public static void forwardJavaPrefill(Model model, State state, int token, int p
*
* Copies the token embedding into {@code state.embeddingX} (same as
* {@link InferenceCore#forwardTornadoVM}) then delegates to
- * {@link TornadoVMMasterPlanWithPrefillDecode#tornadoVMForwardPrefill},
+ * {@link TornadoVMMasterPlanPrefillDecode#tornadoVMForwardPrefill},
* which executes preprocessing + layer graphs but skips the logits graph.
*
* @param model the LLaMA model (must carry {@link TornadoWeights}, FP16 only)
@@ -142,7 +142,7 @@ public static void forwardJavaPrefill(Model model, State state, int token, int p
* @throws UnsupportedOperationException if the model uses Q8_0 weights
*/
public static void forwardTornadoVMPrefill(Model model, State state, int token, int position,
- TornadoVMMasterPlanWithPrefillDecode prefillPlan) {
+ TornadoVMMasterPlanPrefillDecode prefillPlan) {
final Configuration configuration = model.configuration();
final TornadoWeights weights = (TornadoWeights) model.weights();
diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java
index 06dbe256..5f617bda 100644
--- a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java
+++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java
@@ -7,7 +7,7 @@
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.tokenizer.Tokenizer;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
-import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode;
+import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode;
import java.util.ArrayList;
import java.util.Arrays;
@@ -163,8 +163,8 @@ public static List generateTokensGPULlama(
? config.contextLength() : maxTokens;
final int batchSize = TornadoVMMasterPlan.PREFILL_BATCH_SIZE;
- TornadoVMMasterPlanWithBatchPrefillDecode plan =
- (TornadoVMMasterPlanWithBatchPrefillDecode) tornadoVMPlan;
+ TornadoVMMasterPlanBatchPrefillDecode plan =
+ (TornadoVMMasterPlanBatchPrefillDecode) tornadoVMPlan;
List generatedTokens = new ArrayList<>();
diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java
index ae9f6e9c..7239ab8a 100644
--- a/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java
+++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java
@@ -7,7 +7,7 @@
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.tokenizer.Tokenizer;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
-import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode;
+import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanPrefillDecode;
import java.util.ArrayList;
import java.util.List;
@@ -135,8 +135,8 @@ public static List generateTokensGPULlama(
int actualMaxTokens = (maxTokens < 0 || config.contextLength() < maxTokens)
? config.contextLength() : maxTokens;
- TornadoVMMasterPlanWithPrefillDecode prefillPlan =
- (TornadoVMMasterPlanWithPrefillDecode) tornadoVMPlan;
+ TornadoVMMasterPlanPrefillDecode prefillPlan =
+ (TornadoVMMasterPlanPrefillDecode) tornadoVMPlan;
List generatedTokens = new ArrayList<>();
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
index 3bf4bc21..962be21d 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
@@ -10,20 +10,20 @@
*
* Three concrete implementations exist:
*
- * - {@link TornadoVMMasterPlanStandard} — baseline single-token forward pass
+ *
- {@link TornadoVMMasterPlanSingleToken} — baseline single-token forward pass
* (preprocessing + N layers + logits).
- * - {@link TornadoVMMasterPlanWithPrefillDecode} — sequential prefill/decode separation;
+ *
- {@link TornadoVMMasterPlanPrefillDecode} — sequential prefill/decode separation;
* reuses the same N layer graphs for both phases, skipping logits during prefill.
- * - {@link TornadoVMMasterPlanWithBatchPrefillDecode} — batched prefill + single-token
+ *
- {@link TornadoVMMasterPlanBatchPrefillDecode} — batched prefill + single-token
* decode; holds 2N+3 graphs in one plan to keep the KV cache on device across phases.
*
*
* The {@link #initializeTornadoVMPlan} factory selects the implementation based on
* {@code llama.withPrefillDecode} and {@code llama.prefillBatchSize}:
*
- * - {@code withPrefillDecode=false} → {@link TornadoVMMasterPlanStandard}
- * - {@code withPrefillDecode=true}, {@code prefillBatchSize=1} → {@link TornadoVMMasterPlanWithPrefillDecode}
- * - {@code withPrefillDecode=true}, {@code prefillBatchSize>1} → {@link TornadoVMMasterPlanWithBatchPrefillDecode}
+ * - {@code withPrefillDecode=false} → {@link TornadoVMMasterPlanSingleToken}
+ * - {@code withPrefillDecode=true}, {@code prefillBatchSize=1} → {@link TornadoVMMasterPlanPrefillDecode}
+ * - {@code withPrefillDecode=true}, {@code prefillBatchSize>1} → {@link TornadoVMMasterPlanBatchPrefillDecode}
*
*/
public interface TornadoVMMasterPlan {
@@ -43,8 +43,8 @@ public interface TornadoVMMasterPlan {
* Factory: creates, JIT-compiles, and warms up the appropriate TornadoVMMasterPlan.
*
* When {@code llama.withPrefillDecode=true} and {@code llama.prefillBatchSize > 1},
- * a {@link TornadoVMMasterPlanWithBatchPrefillDecode} is returned.
- * Otherwise a {@link TornadoVMMasterPlanStandard} is returned (used for the baseline
+ * a {@link TornadoVMMasterPlanBatchPrefillDecode} is returned.
+ * Otherwise a {@link TornadoVMMasterPlanSingleToken} is returned (used for the baseline
* path and the sequential prefill/decode path when batch size is 1).
*
* @param state the model state
@@ -56,13 +56,13 @@ static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) {
if (WITH_PREFILL_DECODE && PREFILL_BATCH_SIZE > 1) {
// GPU path with batched prefill/decode
- plan = new TornadoVMMasterPlanWithBatchPrefillDecode(state, model);
+ plan = new TornadoVMMasterPlanBatchPrefillDecode(state, model);
} else if (WITH_PREFILL_DECODE) {
// GPU path with simple prefill/decode
- plan = new TornadoVMMasterPlanWithPrefillDecode(state, model);
+ plan = new TornadoVMMasterPlanPrefillDecode(state, model);
} else {
// GPU path with no prefill/decode
- plan = new TornadoVMMasterPlanStandard(state, model);
+ plan = new TornadoVMMasterPlanSingleToken(state, model);
}
model.setTornadoVMPlan(plan);
return plan;
@@ -76,7 +76,7 @@ static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) {
void forceCopyInReadOnlyData();
- FloatArray tornadoVMForwardExecuteLayered(int position);
+ FloatArray tornadoVMExecuteForward(int position);
/** Releases all device memory held by this plan. */
void freeTornadoExecutionPlan();
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefillDecode.java
new file mode 100644
index 00000000..968e7b19
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanBatchPrefillDecode.java
@@ -0,0 +1,165 @@
+package org.beehive.gpullama3.tornadovm;
+
+import org.beehive.gpullama3.auxiliary.RunMetrics;
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.tensor.GGMLType;
+import org.beehive.gpullama3.tornadovm.plan.BatchPrefillDecodeForwardPlan;
+import org.beehive.gpullama3.tornadovm.plan.ForwardPlanFactory;
+import org.beehive.gpullama3.tornadovm.plan.layout.BatchPrefillDecodeForwardTaskGraphLayout;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+
+/**
+ * GPU execution plan for batched prefill + single-token decode.
+ *
+ * A single {@link TornadoExecutionPlan} holds all task graphs for
+ * batched prefill and single-token decode phases:
+ *
+ * TaskGraph layout (2N+3 TaskGraphs total):
+ *
+ * [0] batchPrefillActivation B×dim embeddings → FP32 wrapXBatch
+ * [1..N] batch-prefill layers B tokens, all transformer ops
+ * [N+1] decodeActivation single-token embedding → FP32 + KV-cache pass-through
+ * [N+2..2N+1] decode layers single-token, standard kernels
+ * [2N+2] logits
+ *
+ */
+public class TornadoVMMasterPlanBatchPrefillDecode implements TornadoVMMasterPlan {
+
+ private final State state;
+ private final Model model;
+ private final Configuration config;
+
+ BatchPrefillDecodeForwardPlan batchPrefillDecodeForwardPlan;
+ BatchPrefillDecodeForwardTaskGraphLayout taskGraphLayout;
+ public TornadoExecutionPlan executionPlan;
+
+ // ── Construction ─────────────────────────────────────────────────────────
+ TornadoVMMasterPlanBatchPrefillDecode(State initialState, Model model) {
+ if (ENABLE_TORNADOVM_INIT_TIME) {
+ System.err.println("\nStarting TornadoVM initialization...");
+ }
+
+ this.state = initialState;
+ this.model = model;
+ this.config = model.configuration();
+
+ long startTime = System.nanoTime();
+ this.executionPlan = createExecutionPlan();
+ long planCreationTime = System.nanoTime();
+
+ if (CUDA_GRAPHS) executionPlan.withAllGraphs().withCUDAGraph();
+ executionPlan.withPreCompilation();
+ long warmupTime = System.nanoTime();
+
+ forceCopyInReadOnlyData();
+ long copyTime = System.nanoTime();
+
+ RunMetrics.setTornadoMetrics(planCreationTime - startTime, warmupTime - planCreationTime, copyTime - warmupTime);
+ }
+
+ // ── Plan construction ─────────────────────────────────────────────────────
+
+ @Override
+ public TornadoExecutionPlan createExecutionPlan() {
+ GGMLType weightType = model.weights().getWeightType();
+ this.batchPrefillDecodeForwardPlan =
+ ForwardPlanFactory.createBatchPrefillDecode(weightType, state, model);
+ this.taskGraphLayout = batchPrefillDecodeForwardPlan.getTaskGraphLayout();
+ var taskGraphs = batchPrefillDecodeForwardPlan.getImmutableTaskGraphs();
+ return new TornadoExecutionPlan(taskGraphs.toArray(new ImmutableTaskGraph[0]));
+ }
+
+ // ── Initialisation ────────────────────────────────────────────────────────
+
+ @Override
+ public void forceCopyInReadOnlyData() {
+ batchPrefillDecodeForwardPlan.getEmbeddingPreparer().initBatchState();
+ state.wrapX.clear();
+ state.positionHolder.init(0);
+
+ for (int i = 0; i <= taskGraphLayout.logitsIdx(); i++) {
+ var g = executionPlan.withGraph(i)
+ .withGridScheduler(batchPrefillDecodeForwardPlan.getGridScheduler());
+ if (CUDA_GRAPHS) g.withCUDAGraph();
+ g.execute();
+ }
+ }
+
+ // ── Forward passes ────────────────────────────────────────────────────────
+
+ /**
+ * Batch prefill: runs graphs 0..N (activation + N layers), skips logits.
+ *
+ * @param tokenIds token IDs for this chunk
+ * @param startPos sequence position of tokenIds[0]
+ * @param model model (unused — kept for API compatibility)
+ * @param chunkSize actual number of tokens in this chunk (≤ batchSize)
+ */
+ public void tornadoVMForwardBatchPrefill(int[] tokenIds, int startPos, Model model, int chunkSize) {
+ batchPrefillDecodeForwardPlan.getEmbeddingPreparer().copyBatchEmbeddings(tokenIds, startPos, chunkSize);
+
+ var batchAct = executionPlan.withGraph(taskGraphLayout.batchActivationIdx())
+ .withGridScheduler(batchPrefillDecodeForwardPlan.getGridScheduler());
+ if (CUDA_GRAPHS) batchAct.withCUDAGraph();
+ batchAct.execute();
+
+ for (int l = 0; l < config.numberOfLayers(); l++) {
+ var batchLayer = executionPlan.withGraph(taskGraphLayout.batchLayerIdx(l))
+ .withGridScheduler(batchPrefillDecodeForwardPlan.getGridScheduler());
+ if (CUDA_GRAPHS) batchLayer.withCUDAGraph();
+ batchLayer.execute();
+ }
+ }
+
+ /**
+ * Single-token decode: runs graphs N+1..2N+2 (activation + N layers + logits).
+ *
+ * @param token token ID to process
+ * @param position sequence position
+ * @param model model (unused — kept for API compatibility)
+ * @return logits array for sampling
+ */
+ public FloatArray tornadoVMForwardDecode(int token, int position, Model model) {
+ batchPrefillDecodeForwardPlan.getEmbeddingPreparer().copyDecodeEmbedding(token);
+ state.positionHolder.set(0, position);
+ state.temp.clear();
+ state.tempFFN.clear();
+
+ var decodeAct = executionPlan.withGraph(taskGraphLayout.decodeActivationIdx())
+ .withGridScheduler(batchPrefillDecodeForwardPlan.getGridScheduler());
+ if (CUDA_GRAPHS) decodeAct.withCUDAGraph();
+ decodeAct.execute();
+
+ for (int l = 0; l < config.numberOfLayers(); l++) {
+ var decodeLayer = executionPlan.withGraph(taskGraphLayout.decodeLayerIdx(l))
+ .withGridScheduler(batchPrefillDecodeForwardPlan.getGridScheduler());
+ if (CUDA_GRAPHS) decodeLayer.withCUDAGraph();
+ decodeLayer.execute();
+ }
+
+ state.tempLogits.clear();
+ state.wrapLogits.clear();
+
+ var logits = executionPlan.withGraph(taskGraphLayout.logitsIdx())
+ .withGridScheduler(batchPrefillDecodeForwardPlan.getGridScheduler());
+ if (CUDA_GRAPHS) logits.withCUDAGraph();
+ logits.execute();
+
+ return state.wrapLogits;
+ }
+
+ @Override
+ public FloatArray tornadoVMExecuteForward(int position) {
+ throw new UnsupportedOperationException(
+ "Use tornadoVMForwardBatchPrefill / tornadoVMForwardDecode for batch plan");
+ }
+
+ @Override
+ public void freeTornadoExecutionPlan() {
+ executionPlan.freeDeviceMemory();
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanPrefillDecode.java
new file mode 100644
index 00000000..a1f56549
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanPrefillDecode.java
@@ -0,0 +1,166 @@
+package org.beehive.gpullama3.tornadovm;
+
+import org.beehive.gpullama3.auxiliary.RunMetrics;
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.tensor.GGMLType;
+import org.beehive.gpullama3.tornadovm.plan.ForwardPlanFactory;
+import org.beehive.gpullama3.tornadovm.plan.PrefillDecodeForwardPlan;
+import org.beehive.gpullama3.tornadovm.plan.layout.PrefillDecodeForwardTaskGraphLayout;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+
+/**
+ * GPU execution plan for sequential (single-token) prefill/decode separation.
+ *
+ * A single {@link TornadoExecutionPlan} holds all graphs so that the KV cache
+ * ({@code wrapKeyCache}, {@code wrapValueCache}) is allocated once and remains on
+ * device across both phases. Prefill and decode reuse the same N layer graphs;
+ * only the logits graph is skipped during prefill.
+ *
+ * Graph layout (N+2 graphs total):
+ *
+ * [0] decodeActivation single-token FP16 → FP32; KV-cache allocated on first execution
+ * [1..N] layer_0..layer_N-1 transformer layers (attention + FFN)
+ * [N+1] logits final RMSNorm + wcls matmul
+ *
+ *
+ * Two forward passes:
+ *
+ * - {@link #tornadoVMForwardPrefill} — graphs 0..N (activation + layers), logits skipped.
+ * Called once per prompt token; populates the KV cache.
+ * - {@link #tornadoVMForwardDecode} — full pass including logits.
+ * Called once per generated token; returns logits for sampling.
+ *
+ */
+public class TornadoVMMasterPlanPrefillDecode implements TornadoVMMasterPlan {
+
+ private final State state;
+ private final Model model;
+ private final Configuration config;
+
+ PrefillDecodeForwardPlan prefillDecodeForwardPlan;
+ PrefillDecodeForwardTaskGraphLayout taskGraphLayout;
+ public TornadoExecutionPlan executionPlan;
+
+ // ── Construction ─────────────────────────────────────────────────────────
+ TornadoVMMasterPlanPrefillDecode(State state, Model model) {
+ if (ENABLE_TORNADOVM_INIT_TIME) {
+ System.err.println("\nStarting TornadoVM initialization...");
+ }
+
+ this.state = state;
+ this.model = model;
+ this.config = model.configuration();
+
+ long startTime = System.nanoTime();
+ this.executionPlan = createExecutionPlan();
+ long planCreationTime = System.nanoTime();
+
+ if (CUDA_GRAPHS) executionPlan.withAllGraphs().withCUDAGraph();
+ executionPlan.withPreCompilation();
+ long warmupTime = System.nanoTime();
+
+ forceCopyInReadOnlyData();
+ long copyTime = System.nanoTime();
+
+ RunMetrics.setTornadoMetrics(planCreationTime - startTime, warmupTime - planCreationTime, copyTime - warmupTime);
+ }
+
+ // ── Plan construction ─────────────────────────────────────────────────────
+
+ @Override
+ public TornadoExecutionPlan createExecutionPlan() {
+ GGMLType weightType = model.weights().getWeightType();
+ this.prefillDecodeForwardPlan = ForwardPlanFactory.createPrefillDecode(weightType, state, model);
+ this.taskGraphLayout = prefillDecodeForwardPlan.getTaskGraphLayout();
+ var taskGraphs = prefillDecodeForwardPlan.getImmutableTaskGraphs();
+ return new TornadoExecutionPlan(taskGraphs.toArray(new ImmutableTaskGraph[0]));
+ }
+
+ // ── Initialisation ────────────────────────────────────────────────────────
+
+ /** Runs all graphs once to trigger FIRST_EXECUTION uploads and warm up CUDA graphs. */
+ @Override
+ public void forceCopyInReadOnlyData() {
+ state.wrapX.clear();
+ state.positionHolder.init(0);
+
+ for (int i = 0; i <= taskGraphLayout.logitsIdx(); i++) {
+ var g = executionPlan.withGraph(i)
+ .withGridScheduler(prefillDecodeForwardPlan.getGridScheduler());
+ if (CUDA_GRAPHS) g.withCUDAGraph();
+ g.execute();
+ }
+ }
+
+ // ── Forward passes ────────────────────────────────────────────────────────
+
+ /**
+ * GPU prefill forward: activation + all transformer layers, logits skipped.
+ *
+ * @param position sequence position being processed
+ */
+ public void tornadoVMForwardPrefill(int position) {
+ var act = executionPlan.withGraph(taskGraphLayout.activationIdx())
+ .withGridScheduler(prefillDecodeForwardPlan.getGridScheduler());
+ if (CUDA_GRAPHS) act.withCUDAGraph();
+ act.execute();
+
+ state.positionHolder.set(0, position);
+ state.temp.clear();
+ state.tempFFN.clear();
+
+ for (int layer = 0; layer < config.numberOfLayers(); layer++) {
+ var l = executionPlan.withGraph(taskGraphLayout.layerIdx(layer))
+ .withGridScheduler(prefillDecodeForwardPlan.getGridScheduler());
+ if (CUDA_GRAPHS) l.withCUDAGraph();
+ l.execute();
+ }
+ }
+
+ /**
+ * GPU decode forward: full execution including logits.
+ *
+ * @param position sequence position being processed
+ * @return logits array for token sampling
+ */
+ public FloatArray tornadoVMForwardDecode(int position) {
+ return tornadoVMExecuteForward(position);
+ }
+
+ @Override
+ public FloatArray tornadoVMExecuteForward(int position) {
+ var act = executionPlan.withGraph(taskGraphLayout.activationIdx())
+ .withGridScheduler(prefillDecodeForwardPlan.getGridScheduler());
+ if (CUDA_GRAPHS) act.withCUDAGraph();
+ act.execute();
+
+ state.positionHolder.set(0, position);
+ state.temp.clear();
+ state.tempFFN.clear();
+
+ for (int layer = 0; layer < config.numberOfLayers(); layer++) {
+ var l = executionPlan.withGraph(taskGraphLayout.layerIdx(layer))
+ .withGridScheduler(prefillDecodeForwardPlan.getGridScheduler());
+ if (CUDA_GRAPHS) l.withCUDAGraph();
+ l.execute();
+ }
+
+ state.tempLogits.clear();
+ state.wrapLogits.clear();
+ var logits = executionPlan.withGraph(taskGraphLayout.logitsIdx())
+ .withGridScheduler(prefillDecodeForwardPlan.getGridScheduler());
+ if (CUDA_GRAPHS) logits.withCUDAGraph();
+ logits.execute();
+
+ return state.wrapLogits;
+ }
+
+ @Override
+ public void freeTornadoExecutionPlan() {
+ executionPlan.freeDeviceMemory();
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanSingleToken.java
similarity index 51%
rename from src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java
rename to src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanSingleToken.java
index 14aaed10..ae3cdd03 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanStandard.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanSingleToken.java
@@ -5,8 +5,9 @@
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.tensor.GGMLType;
-import org.beehive.gpullama3.tornadovm.layerplanner.GenericLayerPlanner;
-import org.beehive.gpullama3.tornadovm.layerplanner.QuantizationPlannerFactory;
+import org.beehive.gpullama3.tornadovm.plan.ForwardPlanFactory;
+import org.beehive.gpullama3.tornadovm.plan.SingleTokenForwardPlan;
+import org.beehive.gpullama3.tornadovm.plan.layout.SingleTokenForwardTaskGraphLayout;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
@@ -18,16 +19,17 @@
* logits projection.
*
*/
-public class TornadoVMMasterPlanStandard implements TornadoVMMasterPlan {
+public class TornadoVMMasterPlanSingleToken implements TornadoVMMasterPlan {
private final State state;
private final Model model;
private final Configuration config;
- GenericLayerPlanner tornadoVMLayerPlanner;
+ SingleTokenForwardPlan tornadoVMForwardPlan;
+ SingleTokenForwardTaskGraphLayout taskGraphLayout;
public TornadoExecutionPlan executionPlan;
- public TornadoVMMasterPlanStandard(State state, Model model) {
+ public TornadoVMMasterPlanSingleToken(State state, Model model) {
if (ENABLE_TORNADOVM_INIT_TIME) {
System.err.println("\nStarting TornadoVM initialization...");
}
@@ -50,23 +52,20 @@ public TornadoVMMasterPlanStandard(State state, Model model) {
RunMetrics.setTornadoMetrics(planCreationTime - startTime, warmupTime - planCreationTime, copyTime - warmupTime);
}
- /**
- * Creates the {@link TornadoExecutionPlan} for *simple/standard* single-token forward pass.
- */
@Override
public TornadoExecutionPlan createExecutionPlan() {
GGMLType weightType = model.weights().getWeightType();
- this.tornadoVMLayerPlanner = QuantizationPlannerFactory.create(weightType, state, model);
- var taskGraphs = tornadoVMLayerPlanner.getImmutableTaskGraphs();
- var taskGraphArray = taskGraphs.toArray(new ImmutableTaskGraph[taskGraphs.size()]);
- return new TornadoExecutionPlan(taskGraphArray);
+ this.tornadoVMForwardPlan = ForwardPlanFactory.createSingleToken(weightType, state, model);
+ this.taskGraphLayout = tornadoVMForwardPlan.getTaskGraphLayout();
+ var taskGraphs = tornadoVMForwardPlan.getImmutableTaskGraphs();
+ return new TornadoExecutionPlan(taskGraphs.toArray(new ImmutableTaskGraph[0]));
}
@Override
- public FloatArray tornadoVMForwardExecuteLayered(int position) {
+ public FloatArray tornadoVMExecuteForward(int position) {
// @formatter:off
- var preGraph = executionPlan.withGraph(getPreprocessingGraphIndex())
- .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler());
+ var preGraph = executionPlan.withGraph(taskGraphLayout.activationIdx())
+ .withGridScheduler(tornadoVMForwardPlan.getGridScheduler());
if (CUDA_GRAPHS) preGraph.withCUDAGraph();
preGraph.execute();
@@ -75,48 +74,32 @@ public FloatArray tornadoVMForwardExecuteLayered(int position) {
state.tempFFN.clear();
for (int layer = 0; layer < config.numberOfLayers(); layer++) {
- executionPlan.withGraph(getLayerGraphIndex(layer))
- .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler())
- //.withCUDAGraph()
+ executionPlan.withGraph(taskGraphLayout.layerIdx(layer))
+ .withGridScheduler(tornadoVMForwardPlan.getGridScheduler())
.execute();
}
state.tempLogits.clear();
state.wrapLogits.clear();
- var logitsGraph = executionPlan.withGraph(getFinalLogitsGraphIndex())
- .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler());
+ var logitsGraph = executionPlan.withGraph(taskGraphLayout.logitsIdx())
+ .withGridScheduler(tornadoVMForwardPlan.getGridScheduler());
if (CUDA_GRAPHS) logitsGraph.withCUDAGraph();
logitsGraph.execute();
// @formatter:on
return state.wrapLogits;
}
- private int getPreprocessingGraphIndex() {
- return 0;
- }
-
- private int getLayerGraphIndex(int layerIndex) {
- return 1 + layerIndex;
- }
-
- private int getFinalLogitsGraphIndex() {
- return tornadoVMLayerPlanner.getImmutableTaskGraphs().size() - 1;
- }
-
@Override
public void forceCopyInReadOnlyData() {
state.wrapX.clear();
state.positionHolder.init(0);
- //executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute();
- executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute();
+ executionPlan.withGraph(taskGraphLayout.activationIdx()).withGridScheduler(tornadoVMForwardPlan.getGridScheduler()).execute();
for (int layer = 0; layer < config.numberOfLayers(); layer++) {
- //executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute();
- executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute();
+ executionPlan.withGraph(taskGraphLayout.layerIdx(layer)).withGridScheduler(tornadoVMForwardPlan.getGridScheduler()).execute();
}
- //executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).withCUDAGraph().execute();
- executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute();
+ executionPlan.withGraph(taskGraphLayout.logitsIdx()).withGridScheduler(tornadoVMForwardPlan.getGridScheduler()).execute();
}
@Override
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java
deleted file mode 100644
index 2fc310e2..00000000
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java
+++ /dev/null
@@ -1,360 +0,0 @@
-package org.beehive.gpullama3.tornadovm;
-
-import org.beehive.gpullama3.auxiliary.RunMetrics;
-import org.beehive.gpullama3.inference.state.LlamaState;
-import org.beehive.gpullama3.inference.state.State;
-import org.beehive.gpullama3.tensor.GGMLType;
-import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.model.llama.LlamaConfiguration;
-import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels;
-import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
-import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
-import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
-import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
-import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersDecode;
-import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16LayersBatchPrefill;
-import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LogitsFP16LayerDecode;
-import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LlamaQ8_0FFNLayersDecode;
-import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LogitsQ8_0LayerDecode;
-import org.beehive.gpullama3.tornadovm.layers.type.q8_0.prefill.LlamaQ8_0LayersBatchPrefill;
-import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
-import uk.ac.manchester.tornado.api.GridScheduler;
-import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
-import uk.ac.manchester.tornado.api.KernelContext;
-import uk.ac.manchester.tornado.api.TaskGraph;
-import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
-import uk.ac.manchester.tornado.api.enums.DataTransferMode;
-import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
-import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
-
-import java.lang.foreign.MemorySegment;
-import java.util.ArrayList;
-import java.util.List;
-
-/**
- * GPU execution plan for batched prefill + single-token decode.
- *
- * A single {@link TornadoExecutionPlan} holds all {@link TaskGraph} for
- * batched prefill and single-token decode phases with the following structure:
.
- *
- * TaskGraph layout (2N+3 TaskGraphs total):
- *
- * [0] prefill batch activation B×dim FP16 → FP32
- * [1..N] prefill batch layer graphs B tokens, all transformer ops
- * [N+1] decode activation single-token FP16 → FP32 + KV-cache pass-through
- * [N+2..2N+1] decode layer graphs single-token, standard kernels
- * [2N+2] logits graph
- *
- *
- *
- * Incorporating cross-phase {@link TaskGraph}s withing a single {@link TornadoExecutionPlan}
- * is necessary to enable KV cache ({@code wrapKeyCache}, {@code wrapValueCache}) sharing
- * across prefill and decode phases. The KV cache pointers are chained across {@link TaskGraph}s
- * via the {@code persistOnDevice}/{@code consumeFromDevice} API within the {@link TornadoExecutionPlan}.
- *
- *
- * KV cache pointer chain across phases:
- *
- * batchLayer[N-1] --persistOnDevice(wrapKeyCache)-→
- * decodeActivation --consumeFromDevice(wrapKeyCache)-→ (pass-through)
- * decodeLayer[0] --consumeFromDevice(wrapKeyCache)-→ (used by attention)
- *
- */
-public class TornadoVMMasterPlanWithBatchPrefillDecode implements TornadoVMMasterPlan {
-
- private final LlamaState state;
- private final Model model;
- private final LlamaConfiguration config;
- private final int batchSize;
- private final int N; // numberOfLayers
- private final TornadoExecutionPlan executionPlan;
- private final GridScheduler gridScheduler;
-
- // ── Graph-index helpers ───────────────────────────────────────────────────
- private int batchActivationIdx() { return 0; }
- private int batchLayerIdx(int i) { return 1 + i; }
- private int decodeActivationIdx() { return N + 1; }
- private int decodeLayerIdx(int i) { return N + 2 + i; }
- private int logitsIdx() { return 2 * N + 2; }
-
- // ── Construction ─────────────────────────────────────────────────────────
- TornadoVMMasterPlanWithBatchPrefillDecode(State initialState, Model model) {
- if (ENABLE_TORNADOVM_INIT_TIME) {
- System.err.println("\nStarting TornadoVM initialization...");
- }
-
- this.state = (LlamaState) initialState; // only LlamaFP16 supports batched prefill for now
- this.model = model;
- this.config = (LlamaConfiguration) model.configuration();
- this.batchSize = PREFILL_BATCH_SIZE;
- this.N = config.numberOfLayers();
- this.gridScheduler = new GridScheduler();
-
- long startTime = System.nanoTime();
- this.executionPlan = createExecutionPlan();
- long planCreationTime = System.nanoTime();
-
- if (CUDA_GRAPHS) executionPlan.withAllGraphs().withCUDAGraph();
- executionPlan.withPreCompilation();
- long warmupTime = System.nanoTime();
-
- forceCopyInReadOnlyData();
- long copyTime = System.nanoTime();
-
- RunMetrics.setTornadoMetrics(planCreationTime - startTime, warmupTime - planCreationTime, copyTime - warmupTime);
- }
-
- // ── Batch Prefill Activation graphs ─────────────────────────────────────────────────────
-
- /** Graph 0 (FP16): B×dim FP16 embeddings → FP32 wrapXBatch. */
- private TaskGraph buildBatchPrefillActivationGraph(KernelContext ctx) {
- return new TaskGraph("prefillActivation")
- .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapXBatch)
- .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingXBatch)
- .task("updateX", TransformerComputeKernels::convertFP16toFP32,
- ctx, state.embeddingXBatch, state.wrapXBatch)
- .persistOnDevice(state.wrapXBatch);
- }
-
- /** Graph 0 (Q8_0): wrapXBatch pre-filled with FP32 by host; upload and persist. */
- private TaskGraph buildQ8_0BatchPrefillActivationGraph(KernelContext ctx) {
- return new TaskGraph("prefillActivation")
- .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapXBatch)
- .task("batchPassthrough", TransformerBatchPrefillKernels::batchPassthrough,
- ctx, state.wrapXBatch)
- .persistOnDevice(state.wrapXBatch);
- }
-
- /**
- * Graph N+1: single-token embedding → FP32 wrapX, with KV-cache pass-through.
- *
- * Receives the KV-cache device pointer from batch layer N via
- * {@code consumeFromDevice}, then re-emits it via {@code persistOnDevice} so
- * that {@code updatePersistedObjectState()} can propagate it to decode layer 0.
- * Both halves of the chain are required; without the re-persist the pointer is
- * not forwarded in interpreter (non-CUDA-graph) mode.
- */
- private TaskGraph buildDecodeActivationGraph(KernelContext ctx, String lastBatchLayerID,
- GGMLType weightType) {
- TaskGraph tg = new TaskGraph("decodeActivation")
- .consumeFromDevice(lastBatchLayerID, state.wrapKeyCache, state.wrapValueCache)
- .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX);
- if (weightType == GGMLType.Q8_0) {
- tg.task("updateX", TransformerComputeKernels::convertQ8_0toFP32,
- ctx, (ByteArray) state.embeddingX, state.wrapX);
- } else {
- tg.task("updateX", TransformerComputeKernels::convertFP16toFP32,
- ctx, (HalfFloatArray) state.embeddingX, state.wrapX);
- }
- return tg.persistOnDevice(state.wrapX, state.wrapKeyCache, state.wrapValueCache);
- }
-
- @Override
- public TornadoExecutionPlan createExecutionPlan() {
- GGMLType weightType = model.weights().getWeightType();
- if (weightType != GGMLType.F16 && weightType != GGMLType.Q8_0) {
- throw new UnsupportedOperationException(
- "Batched prefill/decode GPU path not supported for weight type: " + weightType);
- }
-
- LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights();
- SchedulerType schedulerType = SchedulerDetectionService.determineSchedulerType(model);
- List all = new ArrayList<>(2 * N + 3);
-
- // [0] Batch prefill activation ────────────────────────────────────────
- KernelContext batchActCtx = new KernelContext();
- if (weightType == GGMLType.F16) {
- all.add(buildBatchPrefillActivationGraph(batchActCtx).snapshot());
- gridScheduler.addWorkerGrid("prefillActivation.updateX",
- WorkerGridFactory.genericWorker(batchSize * config.dim(), 128));
- } else {
- all.add(buildQ8_0BatchPrefillActivationGraph(batchActCtx).snapshot());
- gridScheduler.addWorkerGrid("prefillActivation.batchPassthrough",
- WorkerGridFactory.genericWorker(1, 1));
- }
-
- // [1..N] Batch prefill layer graphs ───────────────────────────────────
- String lastBatchLayerID;
- if (weightType == GGMLType.F16) {
- LlamaFP16LayersBatchPrefill batchLayers =
- new LlamaFP16LayersBatchPrefill(state, weights, config, batchSize);
- all.addAll(batchLayers.getLayerImmutableTaskGraphs());
- batchLayers.updateGridScheduler(gridScheduler);
- lastBatchLayerID = batchLayers.getLastLayerTaskGraphID();
- } else {
- LlamaQ8_0LayersBatchPrefill batchLayers =
- new LlamaQ8_0LayersBatchPrefill(state, weights, config, batchSize);
- all.addAll(batchLayers.getLayerImmutableTaskGraphs());
- batchLayers.updateGridScheduler(gridScheduler);
- lastBatchLayerID = batchLayers.getLastLayerTaskGraphID();
- }
-
- // [N+1] Decode activation (with KV-cache pass-through) ────────────────
- KernelContext decodeActCtx = new KernelContext();
- all.add(buildDecodeActivationGraph(decodeActCtx, lastBatchLayerID, weightType).snapshot());
- gridScheduler.addWorkerGrid("decodeActivation.updateX",
- WorkerGridFactory.genericWorker(config.dim(), 128));
-
- // [N+2..2N+1] Decode layer graphs ─────────────────────────────────────
- AbstractFFNLayers decodeLayers;
- if (weightType == GGMLType.F16) {
- decodeLayers = new LlamaFP16FFNLayersDecode("decode", state, weights, config, schedulerType);
- } else {
- decodeLayers = new LlamaQ8_0FFNLayersDecode("decode", state, weights, config, schedulerType);
- }
- all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs());
- decodeLayers.updateGridScheduler(gridScheduler);
-
- // [2N+2] Logits ───────────────────────────────────────────────────────
- AbstractLogitsLayer logitsLayer;
- if (weightType == GGMLType.F16) {
- logitsLayer = new LogitsFP16LayerDecode("logits", state, weights, config,
- decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType);
- } else {
- logitsLayer = new LogitsQ8_0LayerDecode("logits", state, weights, config,
- decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType);
- }
- all.add(logitsLayer.getImmutableTaskGraph());
- logitsLayer.updateGridScheduler(gridScheduler);
-
- return new TornadoExecutionPlan(all.toArray(new ImmutableTaskGraph[0]));
- }
-
-
- /** Runs all graphs once to trigger FIRST_EXECUTION uploads and warm up CUDA graphs. */
- @Override
- public void forceCopyInReadOnlyData() {
- state.wrapXBatch.clear();
- state.wrapX.clear();
- state.positionHolder.init(0);
- state.batchStartPosHolder.init(0);
-
- for (int i = 0; i <= logitsIdx(); i++) {
- var g = executionPlan.withGraph(i).withGridScheduler(gridScheduler);
- if (CUDA_GRAPHS) g.withCUDAGraph();
- g.execute();
- }
- }
-
- // ── Forward passes ────────────────────────────────────────────────────────
-
- /**
- * Batch prefill: runs graphs 0..N (activation + N layers), skips logits.
- *
- * @param tokenIds token IDs for this chunk (length == batchSize, or tail)
- * @param startPos sequence position of tokenIds[0]
- * @param model model (for embedding table)
- * @param chunkSize actual number of tokens in this chunk (≤ batchSize)
- */
- public void tornadoVMForwardBatchPrefill(int[] tokenIds, int startPos, Model model, int chunkSize) {
- LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights();
- int dim = config.dim();
-
- if (weights.getWeightType() == GGMLType.Q8_0) {
- // CPU dequantizes Q8_0 embeddings into wrapXBatch (FP32)
- ByteArray embTable = weights.getTokenEmbeddingTable().asByteArray();
- int blockSize = 32;
- int Q8_0_BLOCK_BYTES = 34;
- int blocksPerRow = (dim + blockSize - 1) / blockSize;
- for (int b = 0; b < chunkSize; b++) {
- int tokenId = tokenIds[b];
- for (int j = 0; j < dim; j++) {
- int blockByteOffset = (tokenId * blocksPerRow + j / blockSize) * Q8_0_BLOCK_BYTES;
- float scale = embTable.getHalfFloat(blockByteOffset).getFloat32();
- float quant = embTable.get(blockByteOffset + 2 + j % blockSize);
- state.wrapXBatch.set(b * dim + j, quant * scale);
- }
- }
- } else {
- // FP16: copy raw half-float bytes into embeddingXBatch
- MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment();
- int bytes = Short.BYTES;
- for (int b = 0; b < chunkSize; b++) {
- MemorySegment.copy(embTable, (long) tokenIds[b] * dim * bytes,
- state.embeddingXBatch.getSegment(), (long) b * dim * bytes,
- (long) dim * bytes);
- }
- }
- state.batchStartPosHolder.set(0, startPos);
-
- // Graph 0: batch activation
- var batchAct = executionPlan.withGraph(batchActivationIdx()).withGridScheduler(gridScheduler);
- if (CUDA_GRAPHS) batchAct.withCUDAGraph();
- batchAct.execute();
-
- // Graphs 1..N: batch transformer layers
- for (int l = 0; l < N; l++) {
- var batchLayer = executionPlan.withGraph(batchLayerIdx(l)).withGridScheduler(gridScheduler);
- if (CUDA_GRAPHS) batchLayer.withCUDAGraph();
- batchLayer.execute();
- }
- // Logits skipped — not needed for prefill positions.
- }
-
- /**
- * Single-token decode: runs graphs N+1..2N+2 (activation + N layers + logits).
- *
- * @param token token ID to process
- * @param position sequence position
- * @param model model (for embedding table)
- * @return logits array for sampling
- */
- public FloatArray tornadoVMForwardDecode(int token, int position, Model model) {
- LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights();
- int dim = config.dim();
-
- if (weights.getWeightType() == GGMLType.Q8_0) {
- ByteArray embTable = weights.getTokenEmbeddingTable().asByteArray();
- int blocksPerRow = (dim + 31) / 32;
- long bytesPerToken = (long) blocksPerRow * 34;
- MemorySegment.copy(embTable.getSegment(), (long) token * bytesPerToken,
- state.embeddingX.getSegment(), 0L, bytesPerToken);
- } else {
- MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment();
- MemorySegment.copy(embTable, (long) token * dim * Short.BYTES,
- state.embeddingX.getSegment(), 0L, (long) dim * Short.BYTES);
- }
-
- state.positionHolder.set(0, position);
- state.temp.clear();
- state.tempFFN.clear();
-
- // Graph N+1: decode activation
- var decodeAct = executionPlan.withGraph(decodeActivationIdx()).withGridScheduler(gridScheduler);
- if (CUDA_GRAPHS) decodeAct.withCUDAGraph();
- decodeAct.execute();
-
- // Graphs N+2..2N+1: decode transformer layers
- for (int l = 0; l < N; l++) {
- var decodeLayer = executionPlan.withGraph(decodeLayerIdx(l)).withGridScheduler(gridScheduler);
- if (CUDA_GRAPHS) decodeLayer.withCUDAGraph();
- //System.err.println("[DEBUG] about to execute decode transformer layer (graph " + decodeLayerIdx(l) + "--)");
- decodeLayer.execute();
- }
-
- state.tempLogits.clear();
- state.wrapLogits.clear();
-
- // Graph 2N+2: logits
- var logits = executionPlan.withGraph(logitsIdx()).withGridScheduler(gridScheduler);
- if (CUDA_GRAPHS) logits.withCUDAGraph();
- logits.execute();
-
- return state.wrapLogits;
- }
-
- @Override
- public FloatArray tornadoVMForwardExecuteLayered(int position) {
- throw new UnsupportedOperationException(
- "Use tornadoVMForwardBatchPrefill / tornadoVMForwardDecode for batch plan");
- }
-
- @Override
- public void freeTornadoExecutionPlan() {
- executionPlan.freeDeviceMemory();
- }
-
-}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java
deleted file mode 100644
index 500882ea..00000000
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java
+++ /dev/null
@@ -1,244 +0,0 @@
-package org.beehive.gpullama3.tornadovm;
-
-import org.beehive.gpullama3.auxiliary.RunMetrics;
-import org.beehive.gpullama3.inference.state.LlamaState;
-import org.beehive.gpullama3.inference.state.State;
-import org.beehive.gpullama3.tensor.GGMLType;
-import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.model.llama.LlamaConfiguration;
-import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
-import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
-import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
-import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
-import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersPrefillDecode;
-import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LogitsFP16LayerDecode;
-import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LlamaQ8_0FFNLayersPrefillDecode;
-import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LogitsQ8_0LayerDecode;
-import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
-import uk.ac.manchester.tornado.api.GridScheduler;
-import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
-import uk.ac.manchester.tornado.api.KernelContext;
-import uk.ac.manchester.tornado.api.TaskGraph;
-import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
-import uk.ac.manchester.tornado.api.enums.DataTransferMode;
-import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
-import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
-
-import java.util.ArrayList;
-import java.util.List;
-
-/**
- * GPU execution plan for sequential (single-token) prefill/decode separation.
- *
- * A single {@link TornadoExecutionPlan} holds all graphs so that the KV cache
- * ({@code wrapKeyCache}, {@code wrapValueCache}) is allocated once and remains on
- * device across both phases. Prefill and decode reuse the same N layer graphs;
- * only the logits graph is skipped during prefill.
- *
- * Graph layout (N+2 graphs total):
- *
- * [0] decodeActivation single-token FP16 → FP32; KV-cache allocated on first execution
- * [1..N] layer_0..layer_N-1 transformer layers (attention + FFN)
- * [N+1] logits final RMSNorm + wcls matmul
- *
- *
- * Two forward passes:
- *
- * - {@link #tornadoVMForwardPrefill} — graphs 0..N (activation + layers), logits skipped.
- * Called once per prompt token; populates the KV cache.
- * - {@link #tornadoVMForwardDecode} — full pass including logits.
- * Called once per generated token; returns logits for sampling.
- *
- */
-public class TornadoVMMasterPlanWithPrefillDecode implements TornadoVMMasterPlan {
-
- private final LlamaState state;
- private final Model model;
- private final LlamaConfiguration config;
- private final int N; // numberOfLayers
- private final TornadoExecutionPlan executionPlan;
- private final GridScheduler gridScheduler;
-
- // ── Graph-index helpers ───────────────────────────────────────────────────
- private int activationIdx() { return 0; }
- private int layerIdx(int i) { return 1 + i; }
- private int logitsIdx() { return N + 1; }
-
- // ── Construction ─────────────────────────────────────────────────────────
- TornadoVMMasterPlanWithPrefillDecode(State initialState, Model model) {
- if (ENABLE_TORNADOVM_INIT_TIME) {
- System.err.println("\nStarting TornadoVM initialization...");
- }
-
- this.state = (LlamaState) initialState;
- this.model = model;
- this.config = (LlamaConfiguration) model.configuration();
- this.N = config.numberOfLayers();
- this.gridScheduler = new GridScheduler();
-
- long startTime = System.nanoTime();
- this.executionPlan = createExecutionPlan();
- long planCreationTime = System.nanoTime();
-
- if (CUDA_GRAPHS) executionPlan.withAllGraphs().withCUDAGraph();
- executionPlan.withPreCompilation();
- long warmupTime = System.nanoTime();
-
- forceCopyInReadOnlyData();
- long copyTime = System.nanoTime();
-
- RunMetrics.setTornadoMetrics(planCreationTime - startTime, warmupTime - planCreationTime, copyTime - warmupTime);
- }
-
- // ── Activation graph ─────────────────────────────────────────────────────
-
- /**
- * Graph 0: single-token FP16 → FP32.
- *
- * Outputs {@code wrapX} (FP32 hidden state) and persists it on device so that
- * decode layer 0 can pick it up via {@code consumeFromDevice("decodeActivation", wrapX)}.
- * The KV cache is not managed here — it is allocated on the first forward pass
- * by decode layer 0 via {@code FIRST_EXECUTION}.
- */
- private TaskGraph buildActivationGraph(KernelContext ctx, GGMLType weightType) {
- if (weightType == GGMLType.Q8_0) {
- return new TaskGraph("decodeActivation")
- .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX)
- .task("updateX", TransformerComputeKernels::convertQ8_0toFP32,
- ctx, (ByteArray) state.embeddingX, state.wrapX)
- .persistOnDevice(state.wrapX);
- }
- return new TaskGraph("decodeActivation")
- .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX)
- .task("updateX", TransformerComputeKernels::convertFP16toFP32,
- ctx, (HalfFloatArray) state.embeddingX, state.wrapX)
- .persistOnDevice(state.wrapX);
- }
-
- // ── Plan construction ─────────────────────────────────────────────────────
- @Override
- public TornadoExecutionPlan createExecutionPlan() {
- GGMLType weightType = model.weights().getWeightType();
- if (weightType != GGMLType.F16 && weightType != GGMLType.Q8_0) {
- throw new UnsupportedOperationException(
- "Prefill/decode GPU path not supported for weight type: " + weightType);
- }
-
- LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights();
- SchedulerType schedulerType = SchedulerDetectionService.determineSchedulerType(model);
-
- List all = new ArrayList<>(N + 2);
-
- // [0] Activation ──────────────────────────────────────────────────────
- KernelContext actCtx = new KernelContext();
- all.add(buildActivationGraph(actCtx, weightType).snapshot());
- gridScheduler.addWorkerGrid("decodeActivation.updateX",
- WorkerGridFactory.genericWorker(config.dim(), 128));
-
- // [1..N] Decode layer graphs ──────────────────────────────────────────
- AbstractFFNLayers decodeLayers;
- if (weightType == GGMLType.F16) {
- decodeLayers = new LlamaFP16FFNLayersPrefillDecode("decode", state, weights, config, schedulerType);
- } else {
- decodeLayers = new LlamaQ8_0FFNLayersPrefillDecode("decode", state, weights, config, schedulerType);
- }
- all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs());
- decodeLayers.updateGridScheduler(gridScheduler);
-
- // [N+1] Logits ────────────────────────────────────────────────────────
- AbstractLogitsLayer logitsLayer;
- if (weightType == GGMLType.F16) {
- logitsLayer = new LogitsFP16LayerDecode("logits", state, weights, config,
- decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType);
- } else {
- logitsLayer = new LogitsQ8_0LayerDecode("logits", state, weights, config,
- decodeLayers.getLastFFNLayerTaskGraphID(), schedulerType);
- }
- all.add(logitsLayer.getImmutableTaskGraph());
- logitsLayer.updateGridScheduler(gridScheduler);
-
- return new TornadoExecutionPlan(all.toArray(new ImmutableTaskGraph[0]));
- }
-
- // ── Initialisation ────────────────────────────────────────────────────────
-
- /** Runs all graphs once to trigger FIRST_EXECUTION uploads and warm up CUDA graphs. */
- @Override
- public void forceCopyInReadOnlyData() {
- state.wrapX.clear();
- state.positionHolder.init(0);
-
- for (int i = 0; i <= logitsIdx(); i++) {
- var g = executionPlan.withGraph(i).withGridScheduler(gridScheduler);
- if (CUDA_GRAPHS) g.withCUDAGraph();
- g.execute();
- }
- }
-
- // ── Forward passes ────────────────────────────────────────────────────────
-
- /**
- * GPU prefill forward: activation + all transformer layers, logits skipped.
- * KV cache is populated for each prompt token.
- *
- * @param position sequence position being processed
- */
- public void tornadoVMForwardPrefill(int position) {
- var prefillActivation = executionPlan.withGraph(activationIdx()).withGridScheduler(gridScheduler);
- if (CUDA_GRAPHS) prefillActivation.withCUDAGraph();
- prefillActivation.execute();
-
- state.positionHolder.set(0, position);
- state.temp.clear();
- state.tempFFN.clear();
-
- for (int layer = 0; layer < N; layer++) {
- var prefillLayer = executionPlan.withGraph(layerIdx(layer)).withGridScheduler(gridScheduler);
- if (CUDA_GRAPHS) prefillLayer.withCUDAGraph();
- prefillLayer.execute();
- }
- }
-
- /**
- * GPU decode forward: full execution including logits.
- *
- * @param position sequence position being processed
- * @return logits array for token sampling
- */
- public FloatArray tornadoVMForwardDecode(int position) {
- return tornadoVMForwardExecuteLayered(position);
- }
-
- @Override
- public FloatArray tornadoVMForwardExecuteLayered(int position) {
- var act = executionPlan.withGraph(activationIdx()).withGridScheduler(gridScheduler);
- if (CUDA_GRAPHS) act.withCUDAGraph();
- act.execute();
-
- state.positionHolder.set(0, position);
- state.temp.clear();
- state.tempFFN.clear();
-
- for (int layer = 0; layer < N; layer++) {
- var l = executionPlan.withGraph(layerIdx(layer)).withGridScheduler(gridScheduler);
- if (CUDA_GRAPHS) l.withCUDAGraph();
- l.execute();
- }
-
- state.tempLogits.clear();
- state.wrapLogits.clear();
- var logits = executionPlan.withGraph(logitsIdx()).withGridScheduler(gridScheduler);
- if (CUDA_GRAPHS) logits.withCUDAGraph();
- logits.execute();
-
- return state.wrapLogits;
- }
-
- @Override
- public void freeTornadoExecutionPlan() {
- executionPlan.freeDeviceMemory();
- }
-}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java
index 6c688649..ecc81296 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerBatchPrefillKernels.java
@@ -15,7 +15,7 @@
* Batch tensors are flat: element [b][i] lives at index {@code b*stride + i}.
* Worker-grid sizes are scaled by {@code batchSize} vs the single-token kernels.
*
- * These kernels are meant to be registered in {@link TornadoVMMasterPlanWithPrefillDecode}
+ *
These kernels are meant to be registered in {@link TornadoVMMasterPlanBatchPrefillDecode}
* batch task graphs; they are NOT invoked directly.
*/
public final class TransformerBatchPrefillKernels {
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/GenericLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/GenericLayerPlanner.java
deleted file mode 100644
index 9e211051..00000000
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/GenericLayerPlanner.java
+++ /dev/null
@@ -1,14 +0,0 @@
-package org.beehive.gpullama3.tornadovm.layerplanner;
-
-import uk.ac.manchester.tornado.api.GridScheduler;
-import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
-
-import java.util.List;
-
-public interface GenericLayerPlanner {
-
- List getImmutableTaskGraphs();
-
- GridScheduler getGridScheduler();
-
-}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java
deleted file mode 100644
index 42d2dc0c..00000000
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java
+++ /dev/null
@@ -1,97 +0,0 @@
-package org.beehive.gpullama3.tornadovm.layerplanner;
-
-import org.beehive.gpullama3.inference.state.DevstralState;
-import org.beehive.gpullama3.inference.state.GraniteState;
-import org.beehive.gpullama3.tensor.GGMLType;
-import org.beehive.gpullama3.inference.state.LlamaState;
-import org.beehive.gpullama3.inference.state.Phi3State;
-import org.beehive.gpullama3.inference.state.Qwen2State;
-import org.beehive.gpullama3.inference.state.Qwen3State;
-import org.beehive.gpullama3.inference.state.State;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.DevstralFP16LayerPlanner;
-import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.GraniteFP16LayerPlanner;
-import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.LlamaFP16LayerPlanner;
-import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.MistralFP16LayerPlanner;
-import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Phi3FP16LayerPlanner;
-import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen2FP16LayerPlanner;
-import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen3FP16LayerPlanner;
-import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.DevstralQ8_0LayerPlanner;
-import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.GraniteQ8_0LayerPlanner;
-import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.LlamaQ8_0LayerPlanner;
-import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.MistralQ8_0LayerPlanner;
-import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Phi3Q8_0LayerPlanner;
-import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen2Q8_0LayerPlanner;
-import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen3Q8_0LayerPlanner;
-
-/**
- * Factory class responsible for creating appropriate layer planners based on model type and quantization.
- *
- * The factory follows a routing logic:
- *
- * - Determine quantization type from {@link GGMLType}
- * - Determine model type from {@link Model}
- * - Instantiate appropriate planner implementation
- *
- *
- * Examples:
- *
- * - {@code QuantizationType.FP16 + ModelType.LLAMA_3 → LlamaFP16LayerPlanner}
- * - {@code QuantizationType.Q8_0 + ModelType.QWEN_2 → Qwen2Q8_0LayerPlanner}
- *
- */
-public class QuantizationPlannerFactory {
-
- /**
- * Main factory method: create planner for given model + quantization
- */
- public static GenericLayerPlanner create(GGMLType quantization, State state, Model model) {
- return switch (quantization) {
- case F32 -> createFP32Planner(state, model);
- case F16 -> createFP16Planner(state, model);
- case Q8_0 -> createQ8_0Planner(state, model);
- case Q4_0 -> createQ4_0Planner(state, model);
- default -> throw new UnsupportedOperationException("Quantization not supported: " + quantization);
- };
- }
-
- // ============ FP16 Planners ============
- private static GenericLayerPlanner createFP16Planner(State state, Model model) {
- return switch (model.getModelType()) {
- case LLAMA_3 -> new LlamaFP16LayerPlanner((LlamaState) state, model);
- case MISTRAL -> new MistralFP16LayerPlanner((LlamaState) state, model);
- case DEVSTRAL_2 -> new DevstralFP16LayerPlanner((DevstralState) state, model);
- case QWEN_2 -> new Qwen2FP16LayerPlanner((Qwen2State) state, model);
- case QWEN_3 -> new Qwen3FP16LayerPlanner((Qwen3State) state, model);
- case PHI_3 -> new Phi3FP16LayerPlanner((Phi3State) state, model);
- case GRANITE -> new GraniteFP16LayerPlanner((GraniteState) state, model);
- case DEEPSEEK_R1_DISTILL_QWEN -> new Qwen2FP16LayerPlanner((Qwen2State) state, model);
- default -> throw new UnsupportedOperationException("FP16 not supported for model: " + model.getModelType());
- };
- }
-
- // ============ Q8_0 Planners ============
- private static GenericLayerPlanner createQ8_0Planner(State state, Model model) {
- return switch (model.getModelType()) {
- case LLAMA_3 -> new LlamaQ8_0LayerPlanner((LlamaState) state, model);
- case MISTRAL -> new MistralQ8_0LayerPlanner((LlamaState) state, model);
- case DEVSTRAL_2 -> new DevstralQ8_0LayerPlanner((DevstralState) state, model);
- case QWEN_2 -> new Qwen2Q8_0LayerPlanner((Qwen2State) state, model);
- case QWEN_3 -> new Qwen3Q8_0LayerPlanner((Qwen3State) state, model);
- case PHI_3 -> new Phi3Q8_0LayerPlanner((Phi3State) state, model);
- case GRANITE -> new GraniteQ8_0LayerPlanner((GraniteState) state, model);
- case DEEPSEEK_R1_DISTILL_QWEN -> new Qwen2Q8_0LayerPlanner((Qwen2State) state, model);
- default -> throw new UnsupportedOperationException("Q8_0 not supported for model: " + model.getModelType());
- };
- }
-
- // ============ FP32 Planners (FUTURE) ============
- private static GenericLayerPlanner createFP32Planner(State state, Model model) {
- throw new UnsupportedOperationException("FP32 planners not yet implemented");
- }
-
- private static GenericLayerPlanner createQ4_0Planner(State state, Model model) {
- throw new UnsupportedOperationException("Q4 planners not yet implemented");
- }
-
-}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizedLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizedLayerPlanner.java
deleted file mode 100644
index 1b7c1953..00000000
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizedLayerPlanner.java
+++ /dev/null
@@ -1,103 +0,0 @@
-package org.beehive.gpullama3.tornadovm.layerplanner;
-
-import org.beehive.gpullama3.inference.state.State;
-import org.beehive.gpullama3.inference.weights.Weights;
-import org.beehive.gpullama3.model.Configuration;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
-import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
-import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
-import org.beehive.gpullama3.tornadovm.layers.Activation;
-import uk.ac.manchester.tornado.api.GridScheduler;
-import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
-import uk.ac.manchester.tornado.api.KernelContext;
-
-import java.util.ArrayList;
-import java.util.List;
-
-/**
- * Abstract base for all quantization-specific planners.
- *
- * Extracts common state from the model, detects the hardware scheduler type,
- * and assembles the full execution plan via createTornadoInferencePlan().
- * Subclasses (FP16LayerPlanner, Q8_0LayerPlanner) only provide quantization validation.
- */
-public abstract class QuantizedLayerPlanner
- implements GenericLayerPlanner {
-
- protected final S state;
- protected final C config;
- protected final W weights;
- protected final KernelContext context;
- protected final Model model;
- protected final SchedulerType schedulerType;
-
- protected Activation activationLayer;
- protected AbstractFFNLayers ffnLayers;
- protected AbstractLogitsLayer logitsLayer;
-
- private List immutableTaskGraphs;
- private GridScheduler gridScheduler;
-
- @SuppressWarnings("unchecked")
- protected QuantizedLayerPlanner(S state, Model model) {
- this.state = state;
- this.model = model;
- this.config = (C) model.configuration();
- this.weights = (W) model.weights();
- this.context = new KernelContext();
- this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
- validateQuantizationType();
- }
-
- /** Validates that the model weights match the expected quantization type. */
- protected abstract void validateQuantizationType();
-
- /**
- * Creates the TornadoVM inference execution pipeline.
- * It represents the entire Feed-Forward Network (FFN) and consists of:
- *
- * - Activation layer
- * - FFN layers (N transformer layers, model-specific)
- * - Logits layer
- *
- *
- * Each component is represented as an {@link ImmutableTaskGraph}, along with a
- * corresponding {@link GridScheduler} configuration that defines how tasks are
- * mapped on the GPU.
- *
- * This method assembles all components into a unified execution pipeline and
- * caches the resulting task graphs and scheduler for reuse across inference runs.
- */
- protected final void createTornadoInferencePlan() {
- List allTaskGraphs = new ArrayList<>();
- GridScheduler masterScheduler = new GridScheduler();
-
- // 1. Activation layer (common to all models)
- allTaskGraphs.add(activationLayer.getImmutableTaskGraph());
- activationLayer.updateGridScheduler(masterScheduler);
-
- // 2. FFN layers (N transformer layers - model-specific)
- allTaskGraphs.addAll(ffnLayers.getFFNLayerImmutableTaskGraphs());
- ffnLayers.updateGridScheduler(masterScheduler);
-
- // 3. Logits layer (common to all models)
- allTaskGraphs.add(logitsLayer.getImmutableTaskGraph());
- logitsLayer.updateGridScheduler(masterScheduler);
-
- // Cache for future retrievals
- this.immutableTaskGraphs = allTaskGraphs;
- this.gridScheduler = masterScheduler;
- }
-
- @Override
- public final List getImmutableTaskGraphs() {
- return this.immutableTaskGraphs;
- }
-
- @Override
- public final GridScheduler getGridScheduler() {
- return this.gridScheduler;
- }
-}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/DevstralFP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/DevstralFP16LayerPlanner.java
deleted file mode 100644
index f72cfeb7..00000000
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/DevstralFP16LayerPlanner.java
+++ /dev/null
@@ -1,20 +0,0 @@
-package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16;
-
-import org.beehive.gpullama3.inference.state.DevstralState;
-import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.model.devstral.DevstralConfiguration;
-import org.beehive.gpullama3.tornadovm.layers.Activation;
-import org.beehive.gpullama3.tornadovm.layers.type.fp16.DevstralFP16FFNLayers;
-import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
-
-public class DevstralFP16LayerPlanner extends FP16LayerPlanner {
-
- public DevstralFP16LayerPlanner(DevstralState state, Model model) {
- super(state, model);
- this.activationLayer = new Activation("activationUpdate", state, weights, config);
- this.ffnLayers = new DevstralFP16FFNLayers("devstralFFN", state, weights, config, schedulerType);
- this.logitsLayer = new LogitsFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
- createTornadoInferencePlan();
- }
-}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/FP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/FP16LayerPlanner.java
deleted file mode 100644
index 3b33d98c..00000000
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/FP16LayerPlanner.java
+++ /dev/null
@@ -1,25 +0,0 @@
-package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16;
-
-import org.beehive.gpullama3.tensor.GGMLType;
-import org.beehive.gpullama3.inference.state.State;
-import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
-import org.beehive.gpullama3.model.Configuration;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.tornadovm.layerplanner.QuantizedLayerPlanner;
-
-/**
- * Base for all FP16-quantized layer planners.
- */
-public abstract class FP16LayerPlanner extends QuantizedLayerPlanner {
-
- protected FP16LayerPlanner(S state, Model model) {
- super(state, model);
- }
-
- @Override
- protected void validateQuantizationType() {
- if (this.weights.getWeightType() != GGMLType.F16) {
- throw new IllegalArgumentException("FP16LayerPlanner requires GGMLType.F16, got: " + this.weights.getWeightType());
- }
- }
-}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java
deleted file mode 100644
index 488e4b39..00000000
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java
+++ /dev/null
@@ -1,20 +0,0 @@
-package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16;
-
-import org.beehive.gpullama3.inference.state.GraniteState;
-import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.model.granite.GraniteConfiguration;
-import org.beehive.gpullama3.tornadovm.layers.ActivationGranite;
-import org.beehive.gpullama3.tornadovm.layers.type.fp16.GraniteFP16FFNLayers;
-import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsGraniteFP16Layer;
-
-public class GraniteFP16LayerPlanner extends FP16LayerPlanner {
-
- public GraniteFP16LayerPlanner(GraniteState state, Model model) {
- super(state, model);
- this.activationLayer = new ActivationGranite("activationUpdate", state, weights, config);
- this.ffnLayers = new GraniteFP16FFNLayers("graniteFFN", state, weights, config, schedulerType);
- this.logitsLayer = new LogitsGraniteFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
- createTornadoInferencePlan();
- }
-}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java
deleted file mode 100644
index d6042c39..00000000
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java
+++ /dev/null
@@ -1,20 +0,0 @@
-package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16;
-
-import org.beehive.gpullama3.inference.state.LlamaState;
-import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.model.llama.LlamaConfiguration;
-import org.beehive.gpullama3.tornadovm.layers.Activation;
-import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers;
-import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
-
-public class LlamaFP16LayerPlanner extends FP16LayerPlanner {
-
- public LlamaFP16LayerPlanner(LlamaState state, Model model) {
- super(state, model);
- this.activationLayer = new Activation("activationUpdate", state, weights, config);
- this.ffnLayers = new LlamaFP16FFNLayers("llamaFFN", state, weights, config, schedulerType);
- this.logitsLayer = new LogitsFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
- createTornadoInferencePlan();
- }
-}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/MistralFP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/MistralFP16LayerPlanner.java
deleted file mode 100644
index 78e1bffb..00000000
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/MistralFP16LayerPlanner.java
+++ /dev/null
@@ -1,20 +0,0 @@
-package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16;
-
-import org.beehive.gpullama3.inference.state.LlamaState;
-import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.model.mistral.MistralConfiguration;
-import org.beehive.gpullama3.tornadovm.layers.Activation;
-import org.beehive.gpullama3.tornadovm.layers.type.fp16.MistralFP16FFNLayers;
-import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
-
-public class MistralFP16LayerPlanner extends FP16LayerPlanner {
-
- public MistralFP16LayerPlanner(LlamaState state, Model model) {
- super(state, model);
- this.activationLayer = new Activation("activationUpdate", state, weights, config);
- this.ffnLayers = new MistralFP16FFNLayers("mistralFFN", state, weights, config, schedulerType);
- this.logitsLayer = new LogitsFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
- createTornadoInferencePlan();
- }
-}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Phi3FP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Phi3FP16LayerPlanner.java
deleted file mode 100644
index 5b50529a..00000000
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Phi3FP16LayerPlanner.java
+++ /dev/null
@@ -1,27 +0,0 @@
-package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16;
-
-import org.beehive.gpullama3.inference.state.Phi3State;
-import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.model.phi3.Phi3Configuration;
-import org.beehive.gpullama3.tornadovm.layers.Activation;
-import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
-import org.beehive.gpullama3.tornadovm.layers.type.fp16.Phi3FP16FFNLayers;
-
-/**
- * Phi3FP16LayerPlanner: Phi3 model with FP16 weights.
- *
- * Follows the same pattern as Qwen3FP16LayerPlanner but with: - Phi3-specific FFN layers (combined QKV + gate/up FFN) - Phi3TornadoWeights - Phi3Configuration
- *
- * Inherits from FP16LayerPlanner
- */
-public class Phi3FP16LayerPlanner extends FP16LayerPlanner {
-
- public Phi3FP16LayerPlanner(Phi3State state, Model model) {
- super(state, model);
- this.activationLayer = new Activation("activationUpdate", state, weights, config);
- this.ffnLayers = new Phi3FP16FFNLayers("phi3FFN", state, weights, config, schedulerType);
- this.logitsLayer = new LogitsFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
- createTornadoInferencePlan();
- }
-}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen2FP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen2FP16LayerPlanner.java
deleted file mode 100644
index a3dcc5e2..00000000
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen2FP16LayerPlanner.java
+++ /dev/null
@@ -1,27 +0,0 @@
-package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16;
-
-import org.beehive.gpullama3.inference.state.Qwen2State;
-import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
-import org.beehive.gpullama3.tornadovm.layers.Activation;
-import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
-import org.beehive.gpullama3.tornadovm.layers.type.fp16.Qwen2FP16FFNLayers;
-
-/**
- * Qwen2FP16LayerPlanner: Qwen2 model with FP16 weights.
- *
- * Follows the same pattern as LlamaFP16LayerPlanner but with: - Qwen2-specific FFN layers (supports GQA with bias terms) - Qwen2TornadoWeights - Qwen2Configuration
- *
- * Inherits from FP16LayerPlanner
- */
-public class Qwen2FP16LayerPlanner extends FP16LayerPlanner {
-
- public Qwen2FP16LayerPlanner(Qwen2State state, Model model) {
- super(state, model);
- this.activationLayer = new Activation("activationUpdate", state, weights, config);
- this.ffnLayers = new Qwen2FP16FFNLayers("qwen2FFN", state, weights, config, schedulerType);
- this.logitsLayer = new LogitsFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
- createTornadoInferencePlan();
- }
-}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen3FP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen3FP16LayerPlanner.java
deleted file mode 100644
index 76239ae6..00000000
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen3FP16LayerPlanner.java
+++ /dev/null
@@ -1,27 +0,0 @@
-package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16;
-
-import org.beehive.gpullama3.inference.state.Qwen3State;
-import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
-import org.beehive.gpullama3.tornadovm.layers.Activation;
-import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
-import org.beehive.gpullama3.tornadovm.layers.type.fp16.Qwen3FP16FFNLayers;
-
-/**
- * Qwen3FP16LayerPlanner: Qwen3 model with FP16 weights.
- *
- * Follows the same pattern as LlamaFP16LayerPlanner but with: - Qwen3-specific FFN layers (supports GQA) - Qwen3TornadoWeights - Qwen3Configuration
- *
- * Inherits from FP16LayerPlanner
- */
-public class Qwen3FP16LayerPlanner extends FP16LayerPlanner {
-
- public Qwen3FP16LayerPlanner(Qwen3State state, Model model) {
- super(state, model);
- this.activationLayer = new Activation("activationUpdate", state, weights, config);
- this.ffnLayers = new Qwen3FP16FFNLayers("qwen3FFN", state, weights, config, schedulerType);
- this.logitsLayer = new LogitsFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
- createTornadoInferencePlan();
- }
-}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/DevstralQ8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/DevstralQ8_0LayerPlanner.java
deleted file mode 100644
index b6bf9634..00000000
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/DevstralQ8_0LayerPlanner.java
+++ /dev/null
@@ -1,20 +0,0 @@
-package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0;
-
-import org.beehive.gpullama3.inference.state.DevstralState;
-import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.model.devstral.DevstralConfiguration;
-import org.beehive.gpullama3.tornadovm.layers.Activation;
-import org.beehive.gpullama3.tornadovm.layers.type.q8_0.DevstralQ8_0FFNLayers;
-import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
-
-public class DevstralQ8_0LayerPlanner extends Q8_0LayerPlanner {
-
- public DevstralQ8_0LayerPlanner(DevstralState state, Model model) {
- super(state, model);
- this.activationLayer = new Activation("activationUpdate", state, weights, config);
- this.ffnLayers = new DevstralQ8_0FFNLayers("devstralFFN", state, weights, config, schedulerType);
- this.logitsLayer = new LogitsQ8_0Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
- createTornadoInferencePlan();
- }
-}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/GraniteQ8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/GraniteQ8_0LayerPlanner.java
deleted file mode 100644
index f7735dc0..00000000
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/GraniteQ8_0LayerPlanner.java
+++ /dev/null
@@ -1,20 +0,0 @@
-package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0;
-
-import org.beehive.gpullama3.inference.state.GraniteState;
-import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.model.granite.GraniteConfiguration;
-import org.beehive.gpullama3.tornadovm.layers.ActivationGranite;
-import org.beehive.gpullama3.tornadovm.layers.type.q8_0.GraniteQ8_0FFNLayers;
-import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsGraniteQ8_0Layer;
-
-public class GraniteQ8_0LayerPlanner extends Q8_0LayerPlanner {
-
- public GraniteQ8_0LayerPlanner(GraniteState state, Model model) {
- super(state, model);
- this.activationLayer = new ActivationGranite("activationUpdate", state, weights, config);
- this.ffnLayers = new GraniteQ8_0FFNLayers("graniteFFN", state, weights, config, schedulerType);
- this.logitsLayer = new LogitsGraniteQ8_0Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
- createTornadoInferencePlan();
- }
-}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/LlamaQ8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/LlamaQ8_0LayerPlanner.java
deleted file mode 100644
index 827ae538..00000000
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/LlamaQ8_0LayerPlanner.java
+++ /dev/null
@@ -1,20 +0,0 @@
-package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0;
-
-import org.beehive.gpullama3.inference.state.LlamaState;
-import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.model.llama.LlamaConfiguration;
-import org.beehive.gpullama3.tornadovm.layers.Activation;
-import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LlamaQ8_0FFNLayers;
-import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
-
-public class LlamaQ8_0LayerPlanner extends Q8_0LayerPlanner {
-
- public LlamaQ8_0LayerPlanner(LlamaState state, Model model) {
- super(state, model);
- this.activationLayer = new Activation("activationUpdate", state, weights, config);
- this.ffnLayers = new LlamaQ8_0FFNLayers("llamaFFN", state, weights, config, schedulerType);
- this.logitsLayer = new LogitsQ8_0Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
- createTornadoInferencePlan();
- }
-}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/MistralQ8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/MistralQ8_0LayerPlanner.java
deleted file mode 100644
index 65e0150c..00000000
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/MistralQ8_0LayerPlanner.java
+++ /dev/null
@@ -1,20 +0,0 @@
-package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0;
-
-import org.beehive.gpullama3.inference.state.LlamaState;
-import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.model.mistral.MistralConfiguration;
-import org.beehive.gpullama3.tornadovm.layers.Activation;
-import org.beehive.gpullama3.tornadovm.layers.type.q8_0.MistralQ8_0FFNLayers;
-import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
-
-public class MistralQ8_0LayerPlanner extends Q8_0LayerPlanner {
-
- public MistralQ8_0LayerPlanner(LlamaState state, Model model) {
- super(state, model);
- this.activationLayer = new Activation("activationUpdate", state, weights, config);
- this.ffnLayers = new MistralQ8_0FFNLayers("mistralFFN", state, weights, config, schedulerType);
- this.logitsLayer = new LogitsQ8_0Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
- createTornadoInferencePlan();
- }
-}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java
deleted file mode 100644
index 6f088d36..00000000
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java
+++ /dev/null
@@ -1,28 +0,0 @@
-package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0;
-
-import org.beehive.gpullama3.inference.state.Phi3State;
-import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.model.phi3.Phi3Configuration;
-import org.beehive.gpullama3.tornadovm.layers.Activation;
-import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
-import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Phi3Q8_0FFNLayers;
-
-/**
- * Phi3Q8_0LayerPlanner: Phi3 model with Q8_0-quantized weights.
- *
- * Follows the same pattern as Qwen3Q8_0LayerPlanner but with: - Phi3-specific FFN layers (combined QKV + gate/up FFN) - Phi3TornadoWeights (8-bit integer quantization) - Phi3Configuration - 2x
- * memory compression vs FP16
- *
- * Inherits from Q8_0LayerPlanner
- */
-public class Phi3Q8_0LayerPlanner extends Q8_0LayerPlanner {
-
- public Phi3Q8_0LayerPlanner(Phi3State state, Model model) {
- super(state, model);
- this.activationLayer = new Activation("activationUpdate", state, weights, config);
- this.ffnLayers = new Phi3Q8_0FFNLayers("phi3FFN", state, weights, config, schedulerType);
- this.logitsLayer = new LogitsQ8_0Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
- createTornadoInferencePlan();
- }
-}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Q8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Q8_0LayerPlanner.java
deleted file mode 100644
index b525d2a3..00000000
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Q8_0LayerPlanner.java
+++ /dev/null
@@ -1,26 +0,0 @@
-package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0;
-
-import org.beehive.gpullama3.tensor.GGMLType;
-import org.beehive.gpullama3.inference.state.State;
-import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
-import org.beehive.gpullama3.model.Configuration;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.tornadovm.layerplanner.QuantizedLayerPlanner;
-
-/**
- * Base for all Q8_0-quantized layer planners.
- */
-public abstract class Q8_0LayerPlanner
- extends QuantizedLayerPlanner {
-
- protected Q8_0LayerPlanner(S state, Model model) {
- super(state, model);
- }
-
- @Override
- protected void validateQuantizationType() {
- if (this.weights.getWeightType() != GGMLType.Q8_0) {
- throw new IllegalArgumentException("Q8_0LayerPlanner requires GGMLType.Q8_0, got: " + this.weights.getWeightType());
- }
- }
-}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen2Q8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen2Q8_0LayerPlanner.java
deleted file mode 100644
index 090812e7..00000000
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen2Q8_0LayerPlanner.java
+++ /dev/null
@@ -1,28 +0,0 @@
-package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0;
-
-import org.beehive.gpullama3.inference.state.Qwen2State;
-import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
-import org.beehive.gpullama3.tornadovm.layers.Activation;
-import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
-import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen2Q8_0FFNLayers;
-
-/**
- * Qwen2Q8_0LayerPlanner: Qwen2 model with Q8_0-quantized weights.
- *
- * Follows the same pattern as LlamaQ8_0LayerPlanner but with: - Qwen2-specific FFN layers (supports GQA with bias terms) - Qwen2TornadoWeights (8-bit integer quantization) - Qwen2Configuration -
- * 2x memory compression vs FP16
- *
- * Inherits from Q8_0LayerPlanner
- */
-public class Qwen2Q8_0LayerPlanner extends Q8_0LayerPlanner {
-
- public Qwen2Q8_0LayerPlanner(Qwen2State state, Model model) {
- super(state, model);
- this.activationLayer = new Activation("activationUpdate", state, weights, config);
- this.ffnLayers = new Qwen2Q8_0FFNLayers("qwen2FFN", state, weights, config, schedulerType);
- this.logitsLayer = new LogitsQ8_0Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
- createTornadoInferencePlan();
- }
-}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen3Q8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen3Q8_0LayerPlanner.java
deleted file mode 100644
index b5a19be2..00000000
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen3Q8_0LayerPlanner.java
+++ /dev/null
@@ -1,28 +0,0 @@
-package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0;
-
-import org.beehive.gpullama3.inference.state.Qwen3State;
-import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
-import org.beehive.gpullama3.tornadovm.layers.Activation;
-import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
-import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen3Q8_0FFNLayers;
-
-/**
- * Qwen3Q8_0LayerPlanner: Qwen3 model with Q8_0-quantized weights.
- *
- * Follows the same pattern as LlamaQ8_0LayerPlanner but with: - Qwen3-specific FFN layers (supports GQA) - Qwen3TornadoWeights (8-bit integer quantization) - Qwen3Configuration - 2x memory
- * compression vs FP16
- *
- * Inherits from Q8_0LayerPlanner
- */
-public class Qwen3Q8_0LayerPlanner extends Q8_0LayerPlanner {
-
- public Qwen3Q8_0LayerPlanner(Qwen3State state, Model model) {
- super(state, model);
- this.activationLayer = new Activation("activationUpdate", state, weights, config);
- this.ffnLayers = new Qwen3Q8_0FFNLayers("qwen3FFN", state, weights, config, schedulerType);
- this.logitsLayer = new LogitsQ8_0Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
- createTornadoInferencePlan();
- }
-}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/strategy/SchedulerType.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/strategy/SchedulerType.java
deleted file mode 100644
index 28b568a2..00000000
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/strategy/SchedulerType.java
+++ /dev/null
@@ -1,5 +0,0 @@
-package org.beehive.gpullama3.tornadovm.layerplanner.strategy;
-
-public enum SchedulerType {
- NVIDIA, NON_NVIDIA
-}
\ No newline at end of file
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java
index 1a73897e..d0b3656e 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java
@@ -3,7 +3,7 @@
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.inference.weights.Weights;
import org.beehive.gpullama3.model.Configuration;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
@@ -14,7 +14,7 @@
* Abstract base class for all FFN (Feed-Forward Network) layer implementations.
* Extended by model and quantization-specific subclasses that provide specific implementations.
*/
-public abstract class AbstractFFNLayers extends AbstractLayer {
+public abstract class AbstractFFNLayers extends AbstractLayer implements TransformerLayerTaskGraphs {
/**
* List of TornadoVM {@link ImmutableTaskGraph}s, one per FFN layer.
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsLayer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsLayer.java
index 37288e5f..afead80d 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsLayer.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsLayer.java
@@ -4,7 +4,7 @@
import org.beehive.gpullama3.inference.weights.Weights;
import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
import org.beehive.gpullama3.model.Configuration;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java
index 6ba36f39..d35f8252 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java
@@ -4,7 +4,7 @@
import org.beehive.gpullama3.inference.weights.Weights;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
-import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
@@ -13,7 +13,7 @@
import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
-public class Activation extends AbstractLayer {
+public class Activation extends AbstractLayer implements ActivationGraph {
private final TaskGraph activationTaskGraph;
public Activation(String name, State state, Weights weights, Configuration config) {
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGraph.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGraph.java
new file mode 100644
index 00000000..c46479bc
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGraph.java
@@ -0,0 +1,15 @@
+package org.beehive.gpullama3.tornadovm.layers;
+
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+
+/**
+ * Common interface for single activation graphs (standard, decode, batch-prefill variants).
+ *
+ * Implemented by {@link Activation} and custom activation wrappers used by
+ * {@link org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents}.
+ */
+public interface ActivationGraph {
+ ImmutableTaskGraph getImmutableTaskGraph();
+ GridScheduler updateGridScheduler(GridScheduler scheduler);
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/BatchPrefillTransformerLayerTaskGraphs.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/BatchPrefillTransformerLayerTaskGraphs.java
new file mode 100644
index 00000000..0e7a762f
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/BatchPrefillTransformerLayerTaskGraphs.java
@@ -0,0 +1,17 @@
+package org.beehive.gpullama3.tornadovm.layers;
+
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+
+import java.util.List;
+
+/**
+ * Interface for a group of N batched-prefill transformer-layer {@link uk.ac.manchester.tornado.api.TaskGraph}.
+ *
+ * Implemented by {@code LlamaFP16LayersBatchPrefill} and {@code LlamaQ8_0LayersBatchPrefill}.
+ */
+public interface BatchPrefillTransformerLayerTaskGraphs {
+ List getLayerImmutableTaskGraphs();
+ void updateGridScheduler(GridScheduler scheduler);
+ String getLastLayerTaskGraphID();
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/TransformerLayerTaskGraphs.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/TransformerLayerTaskGraphs.java
new file mode 100644
index 00000000..245b3f40
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/TransformerLayerTaskGraphs.java
@@ -0,0 +1,17 @@
+package org.beehive.gpullama3.tornadovm.layers;
+
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+
+import java.util.List;
+
+/**
+ * Interface for a group of N transformer-layer {@link uk.ac.manchester.tornado.api.TaskGraph} (standard or prefill-decode variants).
+ *
+ * Implemented by {@link AbstractFFNLayers} and its subclasses.
+ */
+public interface TransformerLayerTaskGraphs {
+ List getFFNLayerImmutableTaskGraphs();
+ GridScheduler updateGridScheduler(GridScheduler scheduler);
+ String getLastFFNLayerTaskGraphID();
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java
index 7fcc8717..e18c2a3a 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java
@@ -5,8 +5,8 @@
import org.beehive.gpullama3.model.devstral.DevstralConfiguration;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
-import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java
index 80f4a6ea..46aad290 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java
@@ -7,8 +7,8 @@
import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
-import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java
index a7b1e87b..928954e9 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java
@@ -5,8 +5,8 @@
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
-import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java
index 1858408e..e3974ade 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java
@@ -7,8 +7,8 @@
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
-import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java
index 54ec9641..2e3a6fe6 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java
@@ -8,7 +8,7 @@
import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java
index 499fc176..aa26bb69 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java
@@ -5,8 +5,8 @@
import org.beehive.gpullama3.model.mistral.MistralConfiguration;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
-import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java
index 065c6802..1d443ce3 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java
@@ -5,8 +5,8 @@
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
import org.beehive.gpullama3.tornadovm.kernels.Phi3Kernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
-import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java
index d6ac0eee..16e35a39 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java
@@ -6,8 +6,8 @@
import org.beehive.gpullama3.tornadovm.kernels.Qwen2Kernels;
import org.beehive.gpullama3.tornadovm.kernels.Qwen3Kernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
-import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java
index 60b3cc3e..0530e7c0 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java
@@ -5,8 +5,8 @@
import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
import org.beehive.gpullama3.tornadovm.kernels.Qwen3Kernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
-import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java
index f20d5bed..ea0220f6 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java
@@ -3,14 +3,14 @@
import org.beehive.gpullama3.inference.state.LlamaState;
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
/**
* Decode FFN layers of the unified batched prefill-decode plan
- * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode}).
+ * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode}).
*
* Overrides data-transfer declarations so that all cross-graph boundaries use
* the explicit-source form of {@code consumeFromDevice}. The no-arg form (used by
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java
index 2f5bac64..c1167796 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java
@@ -3,14 +3,14 @@
import org.beehive.gpullama3.inference.state.LlamaState;
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
/**
* Decode FFN layers for the single-token prefill/decode plan
- * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode}).
+ * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanPrefillDecode}).
*
*
Combines two concerns:
*
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java
index 350e6760..7ee529d0 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LogitsFP16LayerDecode.java
@@ -3,13 +3,13 @@
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.inference.weights.Weights;
import org.beehive.gpullama3.model.Configuration;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
import uk.ac.manchester.tornado.api.TaskGraph;
/**
* Logits layer of the unified batched prefill-decode plan
- * * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode}).
+ * * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode}).
*
* Extends {@link LogitsFP16Layer} with KV-cache pass-through so the device
* pointers for {@code wrapKeyCache} and {@code wrapValueCache} survive the
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java
index a44425ef..44ce4b35 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java
@@ -4,7 +4,8 @@
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels;
-import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.KernelContext;
@@ -17,7 +18,7 @@
/**
* Prefill FFN layers with batching for the unified batched prefill-decode plan
- * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode}).
+ * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode}).
*
*
One {@link ImmutableTaskGraph} per transformer layer, each processing
* {@code batchSize} tokens simultaneously via {@link TransformerBatchPrefillKernels}.
@@ -25,7 +26,7 @@
* KV cache ({@code wrapKeyCache}, {@code wrapValueCache}) is persisted on device
* after every layer so the subsequent single-token decode layers can consume it.
*/
-public class LlamaFP16LayersBatchPrefill {
+public class LlamaFP16LayersBatchPrefill implements BatchPrefillTransformerLayerTaskGraphs {
// Matches the local workgroup size used by the single-token kernels.
static final int LOCAL_WORK_GROUP_SIZE = 32;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java
index 0c5aedd6..0d717d2e 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java
@@ -4,8 +4,8 @@
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
import org.beehive.gpullama3.model.devstral.DevstralConfiguration;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
-import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java
index 8a91c75a..1aca4714 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java
@@ -5,8 +5,8 @@
import org.beehive.gpullama3.model.granite.GraniteConfiguration;
import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
-import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java
index 2cfacdc0..f5683904 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java
@@ -4,8 +4,8 @@
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
-import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java
index c583bb00..dca19c83 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java
@@ -8,7 +8,7 @@
import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java
index 1a1313e2..bf66118a 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java
@@ -7,8 +7,8 @@
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
-import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java
index 64864114..4a6f3151 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java
@@ -4,8 +4,8 @@
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
import org.beehive.gpullama3.model.mistral.MistralConfiguration;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
-import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java
index 8f693adc..0ffe25d6 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java
@@ -5,8 +5,8 @@
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
import org.beehive.gpullama3.tornadovm.kernels.Phi3Kernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
-import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java
index 6cdf32db..936c3573 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java
@@ -6,8 +6,8 @@
import org.beehive.gpullama3.tornadovm.kernels.Qwen2Kernels;
import org.beehive.gpullama3.tornadovm.kernels.Qwen3Kernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
-import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java
index 43b88fb1..1696644d 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java
@@ -5,8 +5,8 @@
import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
import org.beehive.gpullama3.tornadovm.kernels.Qwen3Kernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
-import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java
index e159c99b..0b2f13cb 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java
@@ -3,14 +3,14 @@
import org.beehive.gpullama3.inference.state.LlamaState;
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LlamaQ8_0FFNLayers;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
/**
* Decode FFN layers for the unified batched prefill-decode plan
- * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode}).
+ * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode}).
*
* Layer 0 consumes the KV cache from device (passed through by the decode activation
* graph, which relays it from the last batch prefill layer). No FIRST_EXECUTION allocation
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java
index a1f02ada..388853f7 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java
@@ -3,13 +3,13 @@
import org.beehive.gpullama3.inference.state.LlamaState;
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LlamaQ8_0FFNLayers;
import uk.ac.manchester.tornado.api.TaskGraph;
/**
* Decode FFN layers for the single-token prefill/decode plan
- * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode}).
+ * ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanPrefillDecode}).
*
*
Layer 0 delegates to {@link LlamaQ8_0FFNLayers#configureLayerDataTransfers} which
* includes {@code FIRST_EXECUTION} for {@code wrapKeyCache} and {@code wrapValueCache},
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LogitsQ8_0LayerDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LogitsQ8_0LayerDecode.java
index 054b7f7f..2a6714f2 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LogitsQ8_0LayerDecode.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LogitsQ8_0LayerDecode.java
@@ -3,7 +3,7 @@
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.inference.weights.Weights;
import org.beehive.gpullama3.model.Configuration;
-import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
import uk.ac.manchester.tornado.api.TaskGraph;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java
index 79ea9bc2..14c30462 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java
@@ -4,7 +4,8 @@
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels;
-import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.KernelContext;
@@ -30,7 +31,7 @@
*
- Weight matrices are {@code ByteArray} (Q8_0 format).
*
*/
-public class LlamaQ8_0LayersBatchPrefill {
+public class LlamaQ8_0LayersBatchPrefill implements BatchPrefillTransformerLayerTaskGraphs {
static final int LOCAL_WORK_GROUP_SIZE = 32;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java
new file mode 100644
index 00000000..1246265c
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java
@@ -0,0 +1,75 @@
+package org.beehive.gpullama3.tornadovm.plan;
+
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs;
+import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
+import org.beehive.gpullama3.tornadovm.plan.components.BatchPrefillDecodeForwardPlanComponents;
+import org.beehive.gpullama3.tornadovm.plan.components.EmbeddingPreparer;
+import org.beehive.gpullama3.tornadovm.plan.layout.BatchPrefillDecodeForwardTaskGraphLayout;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Topology plan for the 2N+3 batch-prefill/decode forward pass.
+ *
+ * Graph layout:
+ *
+ * [0] batchPrefillActivation
+ * [1..N] batch-prefill transformer layers
+ * [N+1] decodeActivation (consumes + re-persists KV cache)
+ * [N+2..2N+1] decode transformer layers
+ * [2N+2] logits
+ *
+ *
+ * During batch prefill, the master plan executes graphs 0..N.
+ * During decode, graphs N+1..2N+2 run.
+ */
+public class BatchPrefillDecodeForwardPlan extends ForwardPlan {
+
+ private final BatchPrefillDecodeForwardTaskGraphLayout taskGraphLayout;
+ private final EmbeddingPreparer embeddingPreparer;
+
+ public BatchPrefillDecodeForwardPlan(Model model, BatchPrefillDecodeForwardPlanComponents components, int batchSize) {
+ int N = model.configuration().numberOfLayers();
+ this.taskGraphLayout = new BatchPrefillDecodeForwardTaskGraphLayout(N);
+ this.embeddingPreparer = components.embeddingPreparer();
+
+ List all = new ArrayList<>(2 * N + 3);
+ GridScheduler scheduler = new GridScheduler();
+
+ ActivationGraph batchAct = components.batchPrefillActivation(batchSize);
+ all.add(batchAct.getImmutableTaskGraph());
+ batchAct.updateGridScheduler(scheduler);
+
+ BatchPrefillTransformerLayerTaskGraphs batchLayers = components.batchPrefillLayers(batchSize);
+ all.addAll(batchLayers.getLayerImmutableTaskGraphs());
+ batchLayers.updateGridScheduler(scheduler);
+
+ ActivationGraph decodeAct = components.batchDecodeActivation(batchLayers.getLastLayerTaskGraphID());
+ all.add(decodeAct.getImmutableTaskGraph());
+ decodeAct.updateGridScheduler(scheduler);
+
+ TransformerLayerTaskGraphs decodeLayers = components.batchDecodeLayers();
+ all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs());
+ decodeLayers.updateGridScheduler(scheduler);
+
+ AbstractLogitsLayer logits = components.decodeLogits(decodeLayers.getLastFFNLayerTaskGraphID());
+ all.add(logits.getImmutableTaskGraph());
+ logits.updateGridScheduler(scheduler);
+
+ setGraphs(all, scheduler);
+ }
+
+ public BatchPrefillDecodeForwardTaskGraphLayout getTaskGraphLayout() {
+ return taskGraphLayout;
+ }
+
+ public EmbeddingPreparer getEmbeddingPreparer() {
+ return embeddingPreparer;
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/ExecutionMode.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/ExecutionMode.java
new file mode 100644
index 00000000..0e597af2
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/ExecutionMode.java
@@ -0,0 +1,16 @@
+package org.beehive.gpullama3.tornadovm.plan;
+
+/**
+ * Selects the GPU execution topology for a TornadoVM inference plan.
+ *
+ *
+ * - {@link #STANDARD} — single forward pass (activation + N layers + logits).
+ * - {@link #PREFILL_DECODE} — shared N+2 graph plan; prefill skips logits, decode runs all.
+ * - {@link #BATCH_PREFILL_DECODE} — 2N+3 graph plan; N batch-prefill graphs + N decode graphs + logits.
+ *
+ */
+public enum ExecutionMode {
+ STANDARD,
+ PREFILL_DECODE,
+ BATCH_PREFILL_DECODE
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlan.java
new file mode 100644
index 00000000..5ccdb4c0
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlan.java
@@ -0,0 +1,32 @@
+package org.beehive.gpullama3.tornadovm.plan;
+
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+
+import java.util.List;
+
+/**
+ * Abstract base for GPU forward-pass execution plans.
+ *
+ * Subclasses assemble the {@link ImmutableTaskGraph} list and {@link GridScheduler}
+ * in their constructors by calling {@link #setGraphs}, then expose the results
+ * via {@link #getImmutableTaskGraphs} and {@link #getGridScheduler}.
+ */
+public abstract class ForwardPlan {
+
+ private List immutableTaskGraphs;
+ private GridScheduler gridScheduler;
+
+ protected final void setGraphs(List itgs, GridScheduler scheduler) {
+ this.immutableTaskGraphs = itgs;
+ this.gridScheduler = scheduler;
+ }
+
+ public final List getImmutableTaskGraphs() {
+ return immutableTaskGraphs;
+ }
+
+ public final GridScheduler getGridScheduler() {
+ return gridScheduler;
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlanFactory.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlanFactory.java
new file mode 100644
index 00000000..4eac09b6
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlanFactory.java
@@ -0,0 +1,207 @@
+package org.beehive.gpullama3.tornadovm.plan;
+
+import org.beehive.gpullama3.inference.state.DevstralState;
+import org.beehive.gpullama3.inference.state.GraniteState;
+import org.beehive.gpullama3.inference.state.LlamaState;
+import org.beehive.gpullama3.inference.state.Phi3State;
+import org.beehive.gpullama3.inference.state.Qwen2State;
+import org.beehive.gpullama3.inference.state.Qwen3State;
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.tensor.GGMLType;
+import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
+import org.beehive.gpullama3.tornadovm.plan.components.BatchPrefillDecodeForwardPlanComponents;
+import org.beehive.gpullama3.tornadovm.plan.components.fp16.DevstralFP16PlanComponents;
+import org.beehive.gpullama3.tornadovm.plan.components.fp16.GraniteFP16PlanComponents;
+import org.beehive.gpullama3.tornadovm.plan.components.fp16.LlamaFP16PlanComponents;
+import org.beehive.gpullama3.tornadovm.plan.components.fp16.MistralFP16PlanComponents;
+import org.beehive.gpullama3.tornadovm.plan.components.fp16.Phi3FP16PlanComponents;
+import org.beehive.gpullama3.tornadovm.plan.components.fp16.Qwen2FP16PlanComponents;
+import org.beehive.gpullama3.tornadovm.plan.components.fp16.Qwen3FP16PlanComponents;
+import org.beehive.gpullama3.tornadovm.plan.components.q8_0.DevstralQ8_0PlanComponents;
+import org.beehive.gpullama3.tornadovm.plan.components.q8_0.GraniteQ8_0PlanComponents;
+import org.beehive.gpullama3.tornadovm.plan.components.q8_0.LlamaQ8_0PlanComponents;
+import org.beehive.gpullama3.tornadovm.plan.components.q8_0.MistralQ8_0PlanComponents;
+import org.beehive.gpullama3.tornadovm.plan.components.q8_0.Phi3Q8_0PlanComponents;
+import org.beehive.gpullama3.tornadovm.plan.components.q8_0.Qwen2Q8_0PlanComponents;
+import org.beehive.gpullama3.tornadovm.plan.components.q8_0.Qwen3Q8_0PlanComponents;
+
+/**
+ * Factory for {@link ForwardPlan} instances.
+ *
+ * Dispatches across three axes in order:
+ *
+ * - Quantization ({@link GGMLType})
+ * - Model family ({@link org.beehive.gpullama3.model.ModelType})
+ * - Execution mode ({@link ExecutionMode})
+ *
+ *
+ * Use the typed convenience methods when the execution mode is known at the call site:
+ *
+ * - {@link #createSingleToken} — returns {@link SingleTokenForwardPlan}
+ * - {@link #createPrefillDecode} — returns {@link PrefillDecodeForwardPlan}
+ * - {@link #createBatchPrefillDecode} — returns {@link BatchPrefillDecodeForwardPlan}
+ *
+ */
+public class ForwardPlanFactory {
+
+ private ForwardPlanFactory() {}
+
+ // ── Typed public API ──────────────────────────────────────────────────────
+
+ public static SingleTokenForwardPlan createSingleToken(GGMLType quantization, State state, Model model) {
+ ForwardPlan plan = create(quantization, ExecutionMode.STANDARD, state, model);
+ if (plan instanceof SingleTokenForwardPlan singleToken) return singleToken;
+ throw new IllegalStateException("Expected SingleTokenForwardPlan for STANDARD mode but got " + plan.getClass().getSimpleName());
+ }
+
+ public static PrefillDecodeForwardPlan createPrefillDecode(GGMLType quantization, State state, Model model) {
+ ForwardPlan plan = create(quantization, ExecutionMode.PREFILL_DECODE, state, model);
+ if (plan instanceof PrefillDecodeForwardPlan prefillDecode) return prefillDecode;
+ throw new IllegalStateException("Expected PrefillDecodeForwardPlan for PREFILL_DECODE mode but got " + plan.getClass().getSimpleName());
+ }
+
+ public static BatchPrefillDecodeForwardPlan createBatchPrefillDecode(GGMLType quantization, State state, Model model) {
+ ForwardPlan plan = create(quantization, ExecutionMode.BATCH_PREFILL_DECODE, state, model);
+ if (plan instanceof BatchPrefillDecodeForwardPlan batchPrefillDecode) return batchPrefillDecode;
+ throw new IllegalStateException("Expected BatchPrefillDecodeForwardPlan for BATCH_PREFILL_DECODE mode but got " + plan.getClass().getSimpleName());
+ }
+
+ // ── Generic dispatch ──────────────────────────────────────────────────────
+
+ static ForwardPlan create(GGMLType quantization, ExecutionMode mode, State state, Model model) {
+ return switch (quantization) {
+ case F16 -> createFP16Plan(mode, state, model);
+ case Q8_0 -> createQ8_0Plan(mode, state, model);
+ case F32 -> throw new UnsupportedOperationException("F32 plans not yet implemented");
+ case Q4_0 -> throw new UnsupportedOperationException("Q4_0 plans not yet implemented");
+ default -> throw new UnsupportedOperationException("Quantization not supported: " + quantization);
+ };
+ }
+
+ // ── FP16 branch ───────────────────────────────────────────────────────────
+
+ private static ForwardPlan createFP16Plan(ExecutionMode mode, State state, Model model) {
+ return switch (model.getModelType()) {
+ case LLAMA_3 -> createLlamaFP16Plan(mode, (LlamaState) state, model);
+ case MISTRAL -> createMistralFP16Plan(mode, (LlamaState) state, model);
+ case DEVSTRAL_2 -> createDevstralFP16Plan(mode, (DevstralState) state, model);
+ case QWEN_2 -> createQwen2FP16Plan(mode, (Qwen2State) state, model);
+ case QWEN_3 -> createQwen3FP16Plan(mode, (Qwen3State) state, model);
+ case PHI_3 -> createPhi3FP16Plan(mode, (Phi3State) state, model);
+ case GRANITE -> createGraniteFP16Plan(mode, (GraniteState) state, model);
+ case DEEPSEEK_R1_DISTILL_QWEN -> createQwen2FP16Plan(mode, (Qwen2State) state, model);
+ default -> throw new UnsupportedOperationException("F16 not supported for model: " + model.getModelType());
+ };
+ }
+
+ // ── Q8_0 branch ───────────────────────────────────────────────────────────
+
+ private static ForwardPlan createQ8_0Plan(ExecutionMode mode, State state, Model model) {
+ return switch (model.getModelType()) {
+ case LLAMA_3 -> createLlamaQ8_0Plan(mode, (LlamaState) state, model);
+ case MISTRAL -> createMistralQ8_0Plan(mode, (LlamaState) state, model);
+ case DEVSTRAL_2 -> createDevstralQ8_0Plan(mode, (DevstralState) state, model);
+ case QWEN_2 -> createQwen2Q8_0Plan(mode, (Qwen2State) state, model);
+ case QWEN_3 -> createQwen3Q8_0Plan(mode, (Qwen3State) state, model);
+ case PHI_3 -> createPhi3Q8_0Plan(mode, (Phi3State) state, model);
+ case GRANITE -> createGraniteQ8_0Plan(mode, (GraniteState) state, model);
+ case DEEPSEEK_R1_DISTILL_QWEN -> createQwen2Q8_0Plan(mode, (Qwen2State) state, model);
+ default -> throw new UnsupportedOperationException("Q8_0 not supported for model: " + model.getModelType());
+ };
+ }
+
+ // ── Model+quant helpers — Llama (all 3 modes supported) ──────────────────
+
+ private static ForwardPlan createLlamaFP16Plan(ExecutionMode mode, LlamaState state, Model model) {
+ BatchPrefillDecodeForwardPlanComponents components = new LlamaFP16PlanComponents(state, model);
+ return switch (mode) {
+ case STANDARD -> new SingleTokenForwardPlan(model, components);
+ case PREFILL_DECODE -> new PrefillDecodeForwardPlan(model, components);
+ case BATCH_PREFILL_DECODE -> new BatchPrefillDecodeForwardPlan(model, components, TornadoVMMasterPlan.PREFILL_BATCH_SIZE);
+ };
+ }
+
+ private static ForwardPlan createLlamaQ8_0Plan(ExecutionMode mode, LlamaState state, Model model) {
+ BatchPrefillDecodeForwardPlanComponents components = new LlamaQ8_0PlanComponents(state, model);
+ return switch (mode) {
+ case STANDARD -> new SingleTokenForwardPlan(model, components);
+ case PREFILL_DECODE -> new PrefillDecodeForwardPlan(model, components);
+ case BATCH_PREFILL_DECODE -> new BatchPrefillDecodeForwardPlan(model, components, TornadoVMMasterPlan.PREFILL_BATCH_SIZE);
+ };
+ }
+
+ // ── Model+quant helpers — STANDARD only ──────────────────────────────────
+
+ private static ForwardPlan createMistralFP16Plan(ExecutionMode mode, LlamaState state, Model model) {
+ if (mode != ExecutionMode.STANDARD)
+ throw new UnsupportedOperationException(mode + " not yet supported for MISTRAL + F16");
+ return new SingleTokenForwardPlan(model, new MistralFP16PlanComponents(state, model));
+ }
+
+ private static ForwardPlan createMistralQ8_0Plan(ExecutionMode mode, LlamaState state, Model model) {
+ if (mode != ExecutionMode.STANDARD)
+ throw new UnsupportedOperationException(mode + " not yet supported for MISTRAL + Q8_0");
+ return new SingleTokenForwardPlan(model, new MistralQ8_0PlanComponents(state, model));
+ }
+
+ private static ForwardPlan createDevstralFP16Plan(ExecutionMode mode, DevstralState state, Model model) {
+ if (mode != ExecutionMode.STANDARD)
+ throw new UnsupportedOperationException(mode + " not yet supported for DEVSTRAL_2 + F16");
+ return new SingleTokenForwardPlan(model, new DevstralFP16PlanComponents(state, model));
+ }
+
+ private static ForwardPlan createDevstralQ8_0Plan(ExecutionMode mode, DevstralState state, Model model) {
+ if (mode != ExecutionMode.STANDARD)
+ throw new UnsupportedOperationException(mode + " not yet supported for DEVSTRAL_2 + Q8_0");
+ return new SingleTokenForwardPlan(model, new DevstralQ8_0PlanComponents(state, model));
+ }
+
+ private static ForwardPlan createQwen2FP16Plan(ExecutionMode mode, Qwen2State state, Model model) {
+ if (mode != ExecutionMode.STANDARD)
+ throw new UnsupportedOperationException(mode + " not yet supported for QWEN_2 + F16");
+ return new SingleTokenForwardPlan(model, new Qwen2FP16PlanComponents(state, model));
+ }
+
+ private static ForwardPlan createQwen2Q8_0Plan(ExecutionMode mode, Qwen2State state, Model model) {
+ if (mode != ExecutionMode.STANDARD)
+ throw new UnsupportedOperationException(mode + " not yet supported for QWEN_2 + Q8_0");
+ return new SingleTokenForwardPlan(model, new Qwen2Q8_0PlanComponents(state, model));
+ }
+
+ private static ForwardPlan createQwen3FP16Plan(ExecutionMode mode, Qwen3State state, Model model) {
+ if (mode != ExecutionMode.STANDARD)
+ throw new UnsupportedOperationException(mode + " not yet supported for QWEN_3 + F16");
+ return new SingleTokenForwardPlan(model, new Qwen3FP16PlanComponents(state, model));
+ }
+
+ private static ForwardPlan createQwen3Q8_0Plan(ExecutionMode mode, Qwen3State state, Model model) {
+ if (mode != ExecutionMode.STANDARD)
+ throw new UnsupportedOperationException(mode + " not yet supported for QWEN_3 + Q8_0");
+ return new SingleTokenForwardPlan(model, new Qwen3Q8_0PlanComponents(state, model));
+ }
+
+ private static ForwardPlan createPhi3FP16Plan(ExecutionMode mode, Phi3State state, Model model) {
+ if (mode != ExecutionMode.STANDARD)
+ throw new UnsupportedOperationException(mode + " not yet supported for PHI_3 + F16");
+ return new SingleTokenForwardPlan(model, new Phi3FP16PlanComponents(state, model));
+ }
+
+ private static ForwardPlan createPhi3Q8_0Plan(ExecutionMode mode, Phi3State state, Model model) {
+ if (mode != ExecutionMode.STANDARD)
+ throw new UnsupportedOperationException(mode + " not yet supported for PHI_3 + Q8_0");
+ return new SingleTokenForwardPlan(model, new Phi3Q8_0PlanComponents(state, model));
+ }
+
+ private static ForwardPlan createGraniteFP16Plan(ExecutionMode mode, GraniteState state, Model model) {
+ if (mode != ExecutionMode.STANDARD)
+ throw new UnsupportedOperationException(mode + " not yet supported for GRANITE + F16");
+ return new SingleTokenForwardPlan(model, new GraniteFP16PlanComponents(state, model));
+ }
+
+ private static ForwardPlan createGraniteQ8_0Plan(ExecutionMode mode, GraniteState state, Model model) {
+ if (mode != ExecutionMode.STANDARD)
+ throw new UnsupportedOperationException(mode + " not yet supported for GRANITE + Q8_0");
+ return new SingleTokenForwardPlan(model, new GraniteQ8_0PlanComponents(state, model));
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java
new file mode 100644
index 00000000..241f542b
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java
@@ -0,0 +1,57 @@
+package org.beehive.gpullama3.tornadovm.plan;
+
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
+import org.beehive.gpullama3.tornadovm.plan.components.PrefillDecodeForwardPlanComponents;
+import org.beehive.gpullama3.tornadovm.plan.layout.PrefillDecodeForwardTaskGraphLayout;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Topology plan for the N+2 prefill/decode forward pass.
+ *
+ * Graph layout:
+ *
+ * [0] decodeActivation
+ * [1..N] decode transformer layers
+ * [N+1] logits
+ *
+ *
+ * During prefill, the master plan executes graphs 0..N (skipping logits).
+ * During decode, all N+2 graphs run.
+ */
+public class PrefillDecodeForwardPlan extends ForwardPlan {
+
+ private final PrefillDecodeForwardTaskGraphLayout taskGraphLayout;
+
+ public PrefillDecodeForwardPlan(Model model, PrefillDecodeForwardPlanComponents components) {
+ int N = model.configuration().numberOfLayers();
+ this.taskGraphLayout = new PrefillDecodeForwardTaskGraphLayout(N);
+
+ List all = new ArrayList<>(N + 2);
+ GridScheduler scheduler = new GridScheduler();
+
+ ActivationGraph act = components.decodeActivation();
+ all.add(act.getImmutableTaskGraph());
+ act.updateGridScheduler(scheduler);
+
+ TransformerLayerTaskGraphs layers = components.prefillDecodeLayers();
+ all.addAll(layers.getFFNLayerImmutableTaskGraphs());
+ layers.updateGridScheduler(scheduler);
+
+ AbstractLogitsLayer logits = components.decodeLogits(layers.getLastFFNLayerTaskGraphID());
+ all.add(logits.getImmutableTaskGraph());
+ logits.updateGridScheduler(scheduler);
+
+ setGraphs(all, scheduler);
+ }
+
+ public PrefillDecodeForwardTaskGraphLayout getTaskGraphLayout() {
+ return taskGraphLayout;
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/SingleTokenForwardPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/SingleTokenForwardPlan.java
new file mode 100644
index 00000000..bdc20a97
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/SingleTokenForwardPlan.java
@@ -0,0 +1,54 @@
+package org.beehive.gpullama3.tornadovm.plan;
+
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
+import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents;
+import org.beehive.gpullama3.tornadovm.plan.layout.SingleTokenForwardTaskGraphLayout;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Topology plan for the N+2 single-token forward pass.
+ *
+ * Graph layout:
+ *
+ * [0] activation
+ * [1..N] transformer layers
+ * [N+1] logits
+ *
+ */
+public class SingleTokenForwardPlan extends ForwardPlan {
+
+ private final SingleTokenForwardTaskGraphLayout taskGraphLayout;
+
+ public SingleTokenForwardPlan(Model model, SingleTokenForwardPlanComponents components) {
+ int N = model.configuration().numberOfLayers();
+ this.taskGraphLayout = new SingleTokenForwardTaskGraphLayout(N);
+
+ List all = new ArrayList<>(N + 2);
+ GridScheduler scheduler = new GridScheduler();
+
+ ActivationGraph act = components.standardActivation();
+ all.add(act.getImmutableTaskGraph());
+ act.updateGridScheduler(scheduler);
+
+ TransformerLayerTaskGraphs layers = components.standardLayers();
+ all.addAll(layers.getFFNLayerImmutableTaskGraphs());
+ layers.updateGridScheduler(scheduler);
+
+ AbstractLogitsLayer logits = components.standardLogits(layers.getLastFFNLayerTaskGraphID());
+ all.add(logits.getImmutableTaskGraph());
+ logits.updateGridScheduler(scheduler);
+
+ setGraphs(all, scheduler);
+ }
+
+ public SingleTokenForwardTaskGraphLayout getTaskGraphLayout() {
+ return taskGraphLayout;
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/BatchPrefillDecodeForwardPlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/BatchPrefillDecodeForwardPlanComponents.java
new file mode 100644
index 00000000..10595b27
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/BatchPrefillDecodeForwardPlanComponents.java
@@ -0,0 +1,25 @@
+package org.beehive.gpullama3.tornadovm.plan.components;
+
+import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs;
+import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
+
+/**
+ * Components for the 2N+3 batch-prefill/decode forward plan.
+ *
+ * Extends {@link PrefillDecodeForwardPlanComponents} with the batch-prefill
+ * activation, batch layer group, KV-cache decode activation, and the
+ * host-side embedding preparer.
+ */
+public interface BatchPrefillDecodeForwardPlanComponents extends PrefillDecodeForwardPlanComponents {
+
+ ActivationGraph batchPrefillActivation(int batchSize);
+
+ ActivationGraph batchDecodeActivation(String lastBatchLayerId);
+
+ TransformerLayerTaskGraphs batchDecodeLayers();
+
+ BatchPrefillTransformerLayerTaskGraphs batchPrefillLayers(int batchSize);
+
+ EmbeddingPreparer embeddingPreparer();
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/EmbeddingPreparer.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/EmbeddingPreparer.java
new file mode 100644
index 00000000..074bc4cd
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/EmbeddingPreparer.java
@@ -0,0 +1,18 @@
+package org.beehive.gpullama3.tornadovm.plan.components;
+
+/**
+ * Prepares embedding vectors in host memory before GPU activation graphs run.
+ *
+ * Concrete implementations are format-specific (FP16 byte copy vs Q8_0 CPU dequantization)
+ * and live in the component-provider classes.
+ */
+public interface EmbeddingPreparer {
+ /** Clears the batch input buffer and resets the batch-start position holder to zero. */
+ void initBatchState();
+
+ /** Copies or dequantizes embeddings for {@code chunkSize} tokens into the batch buffer and sets the batch-start position holder. */
+ void copyBatchEmbeddings(int[] tokenIds, int startPos, int chunkSize);
+
+ /** Copies or dequantizes the embedding for a single decode token into the single-token buffer. */
+ void copyDecodeEmbedding(int token);
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java
new file mode 100644
index 00000000..8f0b6cdd
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java
@@ -0,0 +1,20 @@
+package org.beehive.gpullama3.tornadovm.plan.components;
+
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
+
+/**
+ * Components for the N+2 prefill/decode forward plan.
+ *
+ * Extends {@link SingleTokenForwardPlanComponents} with the decode-phase
+ * activation, KV-cache-aware layer group, and decode logits.
+ */
+public interface PrefillDecodeForwardPlanComponents extends SingleTokenForwardPlanComponents {
+
+ ActivationGraph decodeActivation();
+
+ TransformerLayerTaskGraphs prefillDecodeLayers();
+
+ AbstractLogitsLayer decodeLogits(String previousGraphId);
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java
new file mode 100644
index 00000000..6ecc2077
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java
@@ -0,0 +1,21 @@
+package org.beehive.gpullama3.tornadovm.plan.components;
+
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
+
+/**
+ * Components for the single-token forward pass.
+ *
+ * All model+quantization combinations implement this interface.
+ * Models that support prefill/decode modes implement the richer
+ * {@link PrefillDecodeForwardPlanComponents} or {@link BatchPrefillDecodeForwardPlanComponents}.
+ */
+public interface SingleTokenForwardPlanComponents {
+
+ ActivationGraph standardActivation();
+
+ TransformerLayerTaskGraphs standardLayers();
+
+ AbstractLogitsLayer standardLogits(String previousGraphId);
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java
new file mode 100644
index 00000000..d3617a6c
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java
@@ -0,0 +1,62 @@
+package org.beehive.gpullama3.tornadovm.plan.components.activation;
+
+import org.beehive.gpullama3.inference.state.LlamaState;
+import org.beehive.gpullama3.model.llama.LlamaConfiguration;
+import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
+import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+import uk.ac.manchester.tornado.api.KernelContext;
+import uk.ac.manchester.tornado.api.TaskGraph;
+import uk.ac.manchester.tornado.api.enums.DataTransferMode;
+import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
+
+/**
+ * Decode activation graph with KV-cache pass-through ("decodeActivation").
+ *
+ * Used in the 2N+3 batch-prefill/decode plan. Consumes
+ * {@code wrapKeyCache}/{@code wrapValueCache} from the last batch-prefill layer,
+ * converts the single-token embedding to FP32, then re-persists the KV cache so
+ * that decode layer 0 can consume it.
+ */
+public class BatchDecodeActivation implements ActivationGraph {
+
+ private final ImmutableTaskGraph itg;
+ private final int dim;
+
+ public BatchDecodeActivation(LlamaState state, LlamaConfiguration config,
+ String lastBatchLayerId, boolean isQ8) {
+ this.dim = config.dim();
+ KernelContext ctx = new KernelContext();
+ this.itg = buildGraph(ctx, state, lastBatchLayerId, isQ8).snapshot();
+ }
+
+ private TaskGraph buildGraph(KernelContext ctx, LlamaState state,
+ String lastBatchLayerId, boolean isQ8) {
+ TaskGraph tg = new TaskGraph("decodeActivation")
+ .consumeFromDevice(lastBatchLayerId, state.wrapKeyCache, state.wrapValueCache)
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX);
+ if (isQ8) {
+ tg.task("updateX", TransformerComputeKernels::convertQ8_0toFP32,
+ ctx, (ByteArray) state.embeddingX, state.wrapX);
+ } else {
+ tg.task("updateX", TransformerComputeKernels::convertFP16toFP32,
+ ctx, (HalfFloatArray) state.embeddingX, state.wrapX);
+ }
+ return tg.persistOnDevice(state.wrapX, state.wrapKeyCache, state.wrapValueCache);
+ }
+
+ @Override
+ public ImmutableTaskGraph getImmutableTaskGraph() {
+ return itg;
+ }
+
+ @Override
+ public GridScheduler updateGridScheduler(GridScheduler scheduler) {
+ scheduler.addWorkerGrid("decodeActivation.updateX",
+ WorkerGridFactory.genericWorker(dim, 128));
+ return scheduler;
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java
new file mode 100644
index 00000000..73eec954
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java
@@ -0,0 +1,72 @@
+package org.beehive.gpullama3.tornadovm.plan.components.activation;
+
+import org.beehive.gpullama3.inference.state.LlamaState;
+import org.beehive.gpullama3.model.llama.LlamaConfiguration;
+import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels;
+import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
+import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+import uk.ac.manchester.tornado.api.KernelContext;
+import uk.ac.manchester.tornado.api.TaskGraph;
+import uk.ac.manchester.tornado.api.enums.DataTransferMode;
+
+/**
+ * Batch-prefill activation graph ("prefillActivation").
+ *
+ *
+ * - FP16: uploads {@code embeddingXBatch} (EVERY_EXECUTION) and converts to
+ * FP32 {@code wrapXBatch}.
+ * - Q8_0: uploads FP32 {@code wrapXBatch} pre-filled by host (EVERY_EXECUTION)
+ * and runs a no-op {@code batchPassthrough} so TornadoVM has at least one task.
+ *
+ */
+public class BatchPrefillActivation implements ActivationGraph {
+
+ private final ImmutableTaskGraph itg;
+ private final boolean isQ8;
+ private final int batchSize;
+ private final int dim;
+
+ public BatchPrefillActivation(LlamaState state, LlamaConfiguration config, int batchSize, boolean isQ8) {
+ this.isQ8 = isQ8;
+ this.batchSize = batchSize;
+ this.dim = config.dim();
+ KernelContext ctx = new KernelContext();
+ this.itg = buildGraph(ctx, state).snapshot();
+ }
+
+ private TaskGraph buildGraph(KernelContext ctx, LlamaState state) {
+ if (isQ8) {
+ return new TaskGraph("prefillActivation")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapXBatch)
+ .task("batchPassthrough", TransformerBatchPrefillKernels::batchPassthrough,
+ ctx, state.wrapXBatch)
+ .persistOnDevice(state.wrapXBatch);
+ }
+ return new TaskGraph("prefillActivation")
+ .transferToDevice(DataTransferMode.FIRST_EXECUTION, ctx, state.wrapXBatch)
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingXBatch)
+ .task("updateX", TransformerComputeKernels::convertFP16toFP32,
+ ctx, state.embeddingXBatch, state.wrapXBatch)
+ .persistOnDevice(state.wrapXBatch);
+ }
+
+ @Override
+ public ImmutableTaskGraph getImmutableTaskGraph() {
+ return itg;
+ }
+
+ @Override
+ public GridScheduler updateGridScheduler(GridScheduler scheduler) {
+ if (isQ8) {
+ scheduler.addWorkerGrid("prefillActivation.batchPassthrough",
+ WorkerGridFactory.genericWorker(1, 1));
+ } else {
+ scheduler.addWorkerGrid("prefillActivation.updateX",
+ WorkerGridFactory.genericWorker(batchSize * dim, 128));
+ }
+ return scheduler;
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java
new file mode 100644
index 00000000..66f3b9ad
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java
@@ -0,0 +1,42 @@
+package org.beehive.gpullama3.tornadovm.plan.components.fp16;
+
+import org.beehive.gpullama3.inference.state.DevstralState;
+import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.devstral.DevstralConfiguration;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.DevstralFP16FFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
+import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
+
+public class DevstralFP16PlanComponents implements SingleTokenForwardPlanComponents {
+
+ private final DevstralState state;
+ private final LlamaTornadoWeights weights;
+ private final DevstralConfiguration config;
+ private final SchedulerType schedulerType;
+
+ public DevstralFP16PlanComponents(DevstralState state, Model model) {
+ this.state = state;
+ this.config = (DevstralConfiguration) model.configuration();
+ this.weights = (LlamaTornadoWeights) model.weights();
+ this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
+ }
+
+ @Override public ActivationGraph standardActivation() {
+ return new Activation("activationUpdate", state, weights, config);
+ }
+
+ @Override public TransformerLayerTaskGraphs standardLayers() {
+ return new DevstralFP16FFNLayers("devstralFFN", state, weights, config, schedulerType);
+ }
+
+ @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType);
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java
new file mode 100644
index 00000000..5c8a6cfd
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java
@@ -0,0 +1,42 @@
+package org.beehive.gpullama3.tornadovm.plan.components.fp16;
+
+import org.beehive.gpullama3.inference.state.GraniteState;
+import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.granite.GraniteConfiguration;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.ActivationGranite;
+import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.GraniteFP16FFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsGraniteFP16Layer;
+import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
+
+public class GraniteFP16PlanComponents implements SingleTokenForwardPlanComponents {
+
+ private final GraniteState state;
+ private final GraniteTornadoWeights weights;
+ private final GraniteConfiguration config;
+ private final SchedulerType schedulerType;
+
+ public GraniteFP16PlanComponents(GraniteState state, Model model) {
+ this.state = state;
+ this.config = (GraniteConfiguration) model.configuration();
+ this.weights = (GraniteTornadoWeights) model.weights();
+ this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
+ }
+
+ @Override public ActivationGraph standardActivation() {
+ return new ActivationGranite("activationUpdate", state, weights, config);
+ }
+
+ @Override public TransformerLayerTaskGraphs standardLayers() {
+ return new GraniteFP16FFNLayers("graniteFFN", state, weights, config, schedulerType);
+ }
+
+ @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ return new LogitsGraniteFP16Layer("logits", state, weights, config, previousGraphId, schedulerType);
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java
new file mode 100644
index 00000000..3fc093c3
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java
@@ -0,0 +1,120 @@
+package org.beehive.gpullama3.tornadovm.plan.components.fp16;
+
+import org.beehive.gpullama3.inference.state.LlamaState;
+import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.llama.LlamaConfiguration;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs;
+import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersDecode;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LlamaFP16FFNLayersPrefillDecode;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LogitsFP16LayerDecode;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16LayersBatchPrefill;
+import org.beehive.gpullama3.tornadovm.plan.components.BatchPrefillDecodeForwardPlanComponents;
+import org.beehive.gpullama3.tornadovm.plan.components.EmbeddingPreparer;
+import org.beehive.gpullama3.tornadovm.plan.components.activation.BatchDecodeActivation;
+import org.beehive.gpullama3.tornadovm.plan.components.activation.BatchPrefillActivation;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
+
+import java.lang.foreign.MemorySegment;
+
+/**
+ * {@link BatchPrefillDecodeForwardPlanComponents} for Llama + FP16.
+ *
+ * Provides all layer objects and the FP16 embedding preparer (raw half-float byte copy).
+ */
+public class LlamaFP16PlanComponents implements BatchPrefillDecodeForwardPlanComponents {
+
+ private final LlamaState state;
+ private final LlamaTornadoWeights weights;
+ private final LlamaConfiguration config;
+ private final SchedulerType schedulerType;
+
+ public LlamaFP16PlanComponents(LlamaState state, Model model) {
+ this.state = state;
+ this.config = (LlamaConfiguration) model.configuration();
+ this.weights = (LlamaTornadoWeights) model.weights();
+ this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
+ }
+
+ // ── Activations ───────────────────────────────────────────────────────────
+
+ @Override public ActivationGraph standardActivation() {
+ return new Activation("activationUpdate", state, weights, config);
+ }
+
+ @Override public ActivationGraph decodeActivation() {
+ return new Activation("decodeActivation", state, weights, config);
+ }
+
+ @Override public ActivationGraph batchPrefillActivation(int batchSize) {
+ return new BatchPrefillActivation(state, config, batchSize, false);
+ }
+
+ @Override public ActivationGraph batchDecodeActivation(String lastBatchLayerId) {
+ return new BatchDecodeActivation(state, config, lastBatchLayerId, false);
+ }
+
+ // ── FFN layer groups ──────────────────────────────────────────────────────
+
+ @Override public TransformerLayerTaskGraphs standardLayers() {
+ return new LlamaFP16FFNLayers("llamaFFN", state, weights, config, schedulerType);
+ }
+
+ @Override public TransformerLayerTaskGraphs prefillDecodeLayers() {
+ return new LlamaFP16FFNLayersPrefillDecode("decode", state, weights, config, schedulerType);
+ }
+
+ @Override public TransformerLayerTaskGraphs batchDecodeLayers() {
+ return new LlamaFP16FFNLayersDecode("decode", state, weights, config, schedulerType);
+ }
+
+ @Override public BatchPrefillTransformerLayerTaskGraphs batchPrefillLayers(int batchSize) {
+ return new LlamaFP16LayersBatchPrefill(state, weights, config, batchSize);
+ }
+
+ // ── Logits layers ─────────────────────────────────────────────────────────
+
+ @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType);
+ }
+
+ @Override public AbstractLogitsLayer decodeLogits(String previousGraphId) {
+ return new LogitsFP16LayerDecode("logits", state, weights, config, previousGraphId, schedulerType);
+ }
+
+ // ── Embedding preparation ─────────────────────────────────────────────────
+
+ @Override public EmbeddingPreparer embeddingPreparer() {
+ MemorySegment embTable = weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment();
+ int dim = config.dim();
+ int bytes = Short.BYTES;
+ return new EmbeddingPreparer() {
+ @Override
+ public void initBatchState() {
+ state.wrapXBatch.clear();
+ state.batchStartPosHolder.init(0);
+ }
+ @Override
+ public void copyBatchEmbeddings(int[] tokenIds, int startPos, int chunkSize) {
+ state.batchStartPosHolder.set(0, startPos);
+ for (int b = 0; b < chunkSize; b++) {
+ MemorySegment.copy(embTable, (long) tokenIds[b] * dim * bytes,
+ state.embeddingXBatch.getSegment(), (long) b * dim * bytes,
+ (long) dim * bytes);
+ }
+ }
+ @Override
+ public void copyDecodeEmbedding(int token) {
+ MemorySegment.copy(embTable, (long) token * dim * bytes,
+ state.embeddingX.getSegment(), 0L, (long) dim * bytes);
+ }
+ };
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java
new file mode 100644
index 00000000..5cc01685
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java
@@ -0,0 +1,42 @@
+package org.beehive.gpullama3.tornadovm.plan.components.fp16;
+
+import org.beehive.gpullama3.inference.state.LlamaState;
+import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.mistral.MistralConfiguration;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.MistralFP16FFNLayers;
+import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
+
+public class MistralFP16PlanComponents implements SingleTokenForwardPlanComponents {
+
+ private final LlamaState state;
+ private final LlamaTornadoWeights weights;
+ private final MistralConfiguration config;
+ private final SchedulerType schedulerType;
+
+ public MistralFP16PlanComponents(LlamaState state, Model model) {
+ this.state = state;
+ this.config = (MistralConfiguration) model.configuration();
+ this.weights = (LlamaTornadoWeights) model.weights();
+ this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
+ }
+
+ @Override public ActivationGraph standardActivation() {
+ return new Activation("activationUpdate", state, weights, config);
+ }
+
+ @Override public TransformerLayerTaskGraphs standardLayers() {
+ return new MistralFP16FFNLayers("mistralFFN", state, weights, config, schedulerType);
+ }
+
+ @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType);
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java
new file mode 100644
index 00000000..b98eb47d
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java
@@ -0,0 +1,42 @@
+package org.beehive.gpullama3.tornadovm.plan.components.fp16;
+
+import org.beehive.gpullama3.inference.state.Phi3State;
+import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.phi3.Phi3Configuration;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.Phi3FP16FFNLayers;
+import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
+
+public class Phi3FP16PlanComponents implements SingleTokenForwardPlanComponents {
+
+ private final Phi3State state;
+ private final Phi3TornadoWeights weights;
+ private final Phi3Configuration config;
+ private final SchedulerType schedulerType;
+
+ public Phi3FP16PlanComponents(Phi3State state, Model model) {
+ this.state = state;
+ this.config = (Phi3Configuration) model.configuration();
+ this.weights = (Phi3TornadoWeights) model.weights();
+ this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
+ }
+
+ @Override public ActivationGraph standardActivation() {
+ return new Activation("activationUpdate", state, weights, config);
+ }
+
+ @Override public TransformerLayerTaskGraphs standardLayers() {
+ return new Phi3FP16FFNLayers("phi3FFN", state, weights, config, schedulerType);
+ }
+
+ @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType);
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java
new file mode 100644
index 00000000..2eaae7a2
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java
@@ -0,0 +1,42 @@
+package org.beehive.gpullama3.tornadovm.plan.components.fp16;
+
+import org.beehive.gpullama3.inference.state.Qwen2State;
+import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.Qwen2FP16FFNLayers;
+import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
+
+public class Qwen2FP16PlanComponents implements SingleTokenForwardPlanComponents {
+
+ private final Qwen2State state;
+ private final Qwen2TornadoWeights weights;
+ private final Qwen2Configuration config;
+ private final SchedulerType schedulerType;
+
+ public Qwen2FP16PlanComponents(Qwen2State state, Model model) {
+ this.state = state;
+ this.config = (Qwen2Configuration) model.configuration();
+ this.weights = (Qwen2TornadoWeights) model.weights();
+ this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
+ }
+
+ @Override public ActivationGraph standardActivation() {
+ return new Activation("activationUpdate", state, weights, config);
+ }
+
+ @Override public TransformerLayerTaskGraphs standardLayers() {
+ return new Qwen2FP16FFNLayers("qwen2FFN", state, weights, config, schedulerType);
+ }
+
+ @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType);
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java
new file mode 100644
index 00000000..2ed13ab1
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java
@@ -0,0 +1,42 @@
+package org.beehive.gpullama3.tornadovm.plan.components.fp16;
+
+import org.beehive.gpullama3.inference.state.Qwen3State;
+import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.Qwen3FP16FFNLayers;
+import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
+
+public class Qwen3FP16PlanComponents implements SingleTokenForwardPlanComponents {
+
+ private final Qwen3State state;
+ private final Qwen3TornadoWeights weights;
+ private final Qwen3Configuration config;
+ private final SchedulerType schedulerType;
+
+ public Qwen3FP16PlanComponents(Qwen3State state, Model model) {
+ this.state = state;
+ this.config = (Qwen3Configuration) model.configuration();
+ this.weights = (Qwen3TornadoWeights) model.weights();
+ this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
+ }
+
+ @Override public ActivationGraph standardActivation() {
+ return new Activation("activationUpdate", state, weights, config);
+ }
+
+ @Override public TransformerLayerTaskGraphs standardLayers() {
+ return new Qwen3FP16FFNLayers("qwen3FFN", state, weights, config, schedulerType);
+ }
+
+ @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType);
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java
new file mode 100644
index 00000000..8302f7cb
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java
@@ -0,0 +1,42 @@
+package org.beehive.gpullama3.tornadovm.plan.components.q8_0;
+
+import org.beehive.gpullama3.inference.state.DevstralState;
+import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.devstral.DevstralConfiguration;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.DevstralQ8_0FFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
+import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
+
+public class DevstralQ8_0PlanComponents implements SingleTokenForwardPlanComponents {
+
+ private final DevstralState state;
+ private final LlamaTornadoWeights weights;
+ private final DevstralConfiguration config;
+ private final SchedulerType schedulerType;
+
+ public DevstralQ8_0PlanComponents(DevstralState state, Model model) {
+ this.state = state;
+ this.config = (DevstralConfiguration) model.configuration();
+ this.weights = (LlamaTornadoWeights) model.weights();
+ this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
+ }
+
+ @Override public ActivationGraph standardActivation() {
+ return new Activation("activationUpdate", state, weights, config);
+ }
+
+ @Override public TransformerLayerTaskGraphs standardLayers() {
+ return new DevstralQ8_0FFNLayers("devstralFFN", state, weights, config, schedulerType);
+ }
+
+ @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType);
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java
new file mode 100644
index 00000000..93f01f10
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java
@@ -0,0 +1,42 @@
+package org.beehive.gpullama3.tornadovm.plan.components.q8_0;
+
+import org.beehive.gpullama3.inference.state.GraniteState;
+import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.granite.GraniteConfiguration;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.ActivationGranite;
+import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.GraniteQ8_0FFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsGraniteQ8_0Layer;
+import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
+
+public class GraniteQ8_0PlanComponents implements SingleTokenForwardPlanComponents {
+
+ private final GraniteState state;
+ private final GraniteTornadoWeights weights;
+ private final GraniteConfiguration config;
+ private final SchedulerType schedulerType;
+
+ public GraniteQ8_0PlanComponents(GraniteState state, Model model) {
+ this.state = state;
+ this.config = (GraniteConfiguration) model.configuration();
+ this.weights = (GraniteTornadoWeights) model.weights();
+ this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
+ }
+
+ @Override public ActivationGraph standardActivation() {
+ return new ActivationGranite("activationUpdate", state, weights, config);
+ }
+
+ @Override public TransformerLayerTaskGraphs standardLayers() {
+ return new GraniteQ8_0FFNLayers("graniteFFN", state, weights, config, schedulerType);
+ }
+
+ @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ return new LogitsGraniteQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType);
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java
new file mode 100644
index 00000000..bd82a767
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java
@@ -0,0 +1,131 @@
+package org.beehive.gpullama3.tornadovm.plan.components.q8_0;
+
+import org.beehive.gpullama3.inference.state.LlamaState;
+import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.llama.LlamaConfiguration;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs;
+import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LlamaQ8_0FFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LlamaQ8_0FFNLayersDecode;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LlamaQ8_0FFNLayersPrefillDecode;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LogitsQ8_0LayerDecode;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.prefill.LlamaQ8_0LayersBatchPrefill;
+import org.beehive.gpullama3.tornadovm.plan.components.BatchPrefillDecodeForwardPlanComponents;
+import org.beehive.gpullama3.tornadovm.plan.components.EmbeddingPreparer;
+import org.beehive.gpullama3.tornadovm.plan.components.activation.BatchDecodeActivation;
+import org.beehive.gpullama3.tornadovm.plan.components.activation.BatchPrefillActivation;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
+import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
+
+import java.lang.foreign.MemorySegment;
+
+/**
+ * {@link BatchPrefillDecodeForwardPlanComponents} for Llama + Q8_0.
+ *
+ * Batch embedding prep: CPU dequantizes Q8_0 embeddings into {@code wrapXBatch} (FP32).
+ * Decode embedding prep: raw Q8_0 block copy into {@code embeddingX} for on-device conversion.
+ */
+public class LlamaQ8_0PlanComponents implements BatchPrefillDecodeForwardPlanComponents {
+
+ private static final int BLOCK_SIZE = 32;
+ private static final int Q8_0_BLOCK_BYTES = 34;
+
+ private final LlamaState state;
+ private final LlamaTornadoWeights weights;
+ private final LlamaConfiguration config;
+ private final SchedulerType schedulerType;
+
+ public LlamaQ8_0PlanComponents(LlamaState state, Model model) {
+ this.state = state;
+ this.config = (LlamaConfiguration) model.configuration();
+ this.weights = (LlamaTornadoWeights) model.weights();
+ this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
+ }
+
+ // ── Activations ───────────────────────────────────────────────────────────
+
+ @Override public ActivationGraph standardActivation() {
+ return new Activation("activationUpdate", state, weights, config);
+ }
+
+ @Override public ActivationGraph decodeActivation() {
+ return new Activation("decodeActivation", state, weights, config);
+ }
+
+ @Override public ActivationGraph batchPrefillActivation(int batchSize) {
+ return new BatchPrefillActivation(state, config, batchSize, true);
+ }
+
+ @Override public ActivationGraph batchDecodeActivation(String lastBatchLayerId) {
+ return new BatchDecodeActivation(state, config, lastBatchLayerId, true);
+ }
+
+ // ── FFN layer groups ──────────────────────────────────────────────────────
+
+ @Override public TransformerLayerTaskGraphs standardLayers() {
+ return new LlamaQ8_0FFNLayers("llamaFFN", state, weights, config, schedulerType);
+ }
+
+ @Override public TransformerLayerTaskGraphs prefillDecodeLayers() {
+ return new LlamaQ8_0FFNLayersPrefillDecode("decode", state, weights, config, schedulerType);
+ }
+
+ @Override public TransformerLayerTaskGraphs batchDecodeLayers() {
+ return new LlamaQ8_0FFNLayersDecode("decode", state, weights, config, schedulerType);
+ }
+
+ @Override public BatchPrefillTransformerLayerTaskGraphs batchPrefillLayers(int batchSize) {
+ return new LlamaQ8_0LayersBatchPrefill(state, weights, config, batchSize);
+ }
+
+ // ── Logits layers ─────────────────────────────────────────────────────────
+
+ @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType);
+ }
+
+ @Override public AbstractLogitsLayer decodeLogits(String previousGraphId) {
+ return new LogitsQ8_0LayerDecode("logits", state, weights, config, previousGraphId, schedulerType);
+ }
+
+ // ── Embedding preparation ─────────────────────────────────────────────────
+
+ @Override public EmbeddingPreparer embeddingPreparer() {
+ ByteArray embTable = weights.getTokenEmbeddingTable().asByteArray();
+ int dim = config.dim();
+ int blocksPerRow = (dim + BLOCK_SIZE - 1) / BLOCK_SIZE;
+ long bytesPerToken = (long) blocksPerRow * Q8_0_BLOCK_BYTES;
+
+ return new EmbeddingPreparer() {
+ @Override
+ public void initBatchState() {
+ state.wrapXBatch.clear();
+ state.batchStartPosHolder.init(0);
+ }
+ @Override
+ public void copyBatchEmbeddings(int[] tokenIds, int startPos, int chunkSize) {
+ state.batchStartPosHolder.set(0, startPos);
+ for (int b = 0; b < chunkSize; b++) {
+ int tokenId = tokenIds[b];
+ for (int j = 0; j < dim; j++) {
+ int blockByteOffset = (tokenId * blocksPerRow + j / BLOCK_SIZE) * Q8_0_BLOCK_BYTES;
+ float scale = embTable.getHalfFloat(blockByteOffset).getFloat32();
+ float quant = embTable.get(blockByteOffset + 2 + j % BLOCK_SIZE);
+ state.wrapXBatch.set(b * dim + j, quant * scale);
+ }
+ }
+ }
+ @Override
+ public void copyDecodeEmbedding(int token) {
+ MemorySegment.copy(embTable.getSegment(), token * bytesPerToken,
+ state.embeddingX.getSegment(), 0L, bytesPerToken);
+ }
+ };
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java
new file mode 100644
index 00000000..4b8a78ec
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java
@@ -0,0 +1,42 @@
+package org.beehive.gpullama3.tornadovm.plan.components.q8_0;
+
+import org.beehive.gpullama3.inference.state.LlamaState;
+import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.mistral.MistralConfiguration;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.MistralQ8_0FFNLayers;
+import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
+
+public class MistralQ8_0PlanComponents implements SingleTokenForwardPlanComponents {
+
+ private final LlamaState state;
+ private final LlamaTornadoWeights weights;
+ private final MistralConfiguration config;
+ private final SchedulerType schedulerType;
+
+ public MistralQ8_0PlanComponents(LlamaState state, Model model) {
+ this.state = state;
+ this.config = (MistralConfiguration) model.configuration();
+ this.weights = (LlamaTornadoWeights) model.weights();
+ this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
+ }
+
+ @Override public ActivationGraph standardActivation() {
+ return new Activation("activationUpdate", state, weights, config);
+ }
+
+ @Override public TransformerLayerTaskGraphs standardLayers() {
+ return new MistralQ8_0FFNLayers("mistralFFN", state, weights, config, schedulerType);
+ }
+
+ @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType);
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java
new file mode 100644
index 00000000..57cd9508
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java
@@ -0,0 +1,42 @@
+package org.beehive.gpullama3.tornadovm.plan.components.q8_0;
+
+import org.beehive.gpullama3.inference.state.Phi3State;
+import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.phi3.Phi3Configuration;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Phi3Q8_0FFNLayers;
+import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
+
+public class Phi3Q8_0PlanComponents implements SingleTokenForwardPlanComponents {
+
+ private final Phi3State state;
+ private final Phi3TornadoWeights weights;
+ private final Phi3Configuration config;
+ private final SchedulerType schedulerType;
+
+ public Phi3Q8_0PlanComponents(Phi3State state, Model model) {
+ this.state = state;
+ this.config = (Phi3Configuration) model.configuration();
+ this.weights = (Phi3TornadoWeights) model.weights();
+ this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
+ }
+
+ @Override public ActivationGraph standardActivation() {
+ return new Activation("activationUpdate", state, weights, config);
+ }
+
+ @Override public TransformerLayerTaskGraphs standardLayers() {
+ return new Phi3Q8_0FFNLayers("phi3FFN", state, weights, config, schedulerType);
+ }
+
+ @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType);
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java
new file mode 100644
index 00000000..534eedc5
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java
@@ -0,0 +1,42 @@
+package org.beehive.gpullama3.tornadovm.plan.components.q8_0;
+
+import org.beehive.gpullama3.inference.state.Qwen2State;
+import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen2Q8_0FFNLayers;
+import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
+
+public class Qwen2Q8_0PlanComponents implements SingleTokenForwardPlanComponents {
+
+ private final Qwen2State state;
+ private final Qwen2TornadoWeights weights;
+ private final Qwen2Configuration config;
+ private final SchedulerType schedulerType;
+
+ public Qwen2Q8_0PlanComponents(Qwen2State state, Model model) {
+ this.state = state;
+ this.config = (Qwen2Configuration) model.configuration();
+ this.weights = (Qwen2TornadoWeights) model.weights();
+ this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
+ }
+
+ @Override public ActivationGraph standardActivation() {
+ return new Activation("activationUpdate", state, weights, config);
+ }
+
+ @Override public TransformerLayerTaskGraphs standardLayers() {
+ return new Qwen2Q8_0FFNLayers("qwen2FFN", state, weights, config, schedulerType);
+ }
+
+ @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType);
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java
new file mode 100644
index 00000000..3c615b39
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java
@@ -0,0 +1,42 @@
+package org.beehive.gpullama3.tornadovm.plan.components.q8_0;
+
+import org.beehive.gpullama3.inference.state.Qwen3State;
+import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen3Q8_0FFNLayers;
+import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService;
+import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
+
+public class Qwen3Q8_0PlanComponents implements SingleTokenForwardPlanComponents {
+
+ private final Qwen3State state;
+ private final Qwen3TornadoWeights weights;
+ private final Qwen3Configuration config;
+ private final SchedulerType schedulerType;
+
+ public Qwen3Q8_0PlanComponents(Qwen3State state, Model model) {
+ this.state = state;
+ this.config = (Qwen3Configuration) model.configuration();
+ this.weights = (Qwen3TornadoWeights) model.weights();
+ this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
+ }
+
+ @Override public ActivationGraph standardActivation() {
+ return new Activation("activationUpdate", state, weights, config);
+ }
+
+ @Override public TransformerLayerTaskGraphs standardLayers() {
+ return new Qwen3Q8_0FFNLayers("qwen3FFN", state, weights, config, schedulerType);
+ }
+
+ @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType);
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/BatchPrefillDecodeForwardTaskGraphLayout.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/BatchPrefillDecodeForwardTaskGraphLayout.java
new file mode 100644
index 00000000..aec212bd
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/BatchPrefillDecodeForwardTaskGraphLayout.java
@@ -0,0 +1,21 @@
+package org.beehive.gpullama3.tornadovm.plan.layout;
+
+/**
+ * Graph-index arithmetic for the 2N+3 batch-prefill/decode forward plan.
+ *
+ *
+ * [0] batchPrefillActivation
+ * [1..N] batchPrefillLayer_0 .. batchPrefillLayer_{N-1}
+ * [N+1] decodeActivation (consumes + re-persists KV cache)
+ * [N+2..2N+1] decodeLayer_0 .. decodeLayer_{N-1}
+ * [2N+2] logits
+ *
+ */
+public record BatchPrefillDecodeForwardTaskGraphLayout(int N) {
+ public int batchActivationIdx() { return 0; }
+ public int batchLayerIdx(int i) { return 1 + i; }
+ public int decodeActivationIdx() { return N + 1; }
+ public int decodeLayerIdx(int i) { return N + 2 + i; }
+ public int logitsIdx() { return 2 * N + 2; }
+ public int totalGraphs() { return 2 * N + 3; }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/PrefillDecodeForwardTaskGraphLayout.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/PrefillDecodeForwardTaskGraphLayout.java
new file mode 100644
index 00000000..c252c92d
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/PrefillDecodeForwardTaskGraphLayout.java
@@ -0,0 +1,17 @@
+package org.beehive.gpullama3.tornadovm.plan.layout;
+
+/**
+ * Graph-index arithmetic for the N+2 prefill/decode forward plan.
+ *
+ *
+ * [0] decodeActivation
+ * [1..N] layer_0 .. layer_{N-1}
+ * [N+1] logits
+ *
+ */
+public record PrefillDecodeForwardTaskGraphLayout(int N) {
+ public int activationIdx() { return 0; }
+ public int layerIdx(int i) { return 1 + i; }
+ public int logitsIdx() { return N + 1; }
+ public int totalGraphs() { return N + 2; }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/SingleTokenForwardTaskGraphLayout.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/SingleTokenForwardTaskGraphLayout.java
new file mode 100644
index 00000000..d8fe8d13
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/layout/SingleTokenForwardTaskGraphLayout.java
@@ -0,0 +1,17 @@
+package org.beehive.gpullama3.tornadovm.plan.layout;
+
+/**
+ * Graph-index arithmetic for the N+2 single-token forward plan.
+ *
+ *
+ * [0] activation
+ * [1..N] layer_0 .. layer_{N-1}
+ * [N+1] logits
+ *
+ */
+public record SingleTokenForwardTaskGraphLayout(int N) {
+ public int activationIdx() { return 0; }
+ public int layerIdx(int i) { return 1 + i; }
+ public int logitsIdx() { return N + 1; }
+ public int totalGraphs() { return N + 2; }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/strategy/SchedulerDetectionService.java b/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/SchedulerDetectionService.java
similarity index 93%
rename from src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/strategy/SchedulerDetectionService.java
rename to src/main/java/org/beehive/gpullama3/tornadovm/scheduling/SchedulerDetectionService.java
index 5a81caa8..5e29c528 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/strategy/SchedulerDetectionService.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/SchedulerDetectionService.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.tornadovm.layerplanner.strategy;
+package org.beehive.gpullama3.tornadovm.scheduling;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.ModelType;
@@ -9,7 +9,6 @@
public class SchedulerDetectionService {
-
public static SchedulerType determineSchedulerType(Model model) {
TornadoRuntime tornadoRuntime = TornadoRuntimeProvider.getTornadoRuntime();
String platformName = tornadoRuntime.getBackend(0)
@@ -24,4 +23,4 @@ public static SchedulerType determineSchedulerType(Model model) {
return (isNvidia && isNotMistral) ? SchedulerType.NVIDIA : SchedulerType.NON_NVIDIA;
}
-}
\ No newline at end of file
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/SchedulerType.java b/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/SchedulerType.java
new file mode 100644
index 00000000..bbd27169
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/SchedulerType.java
@@ -0,0 +1,5 @@
+package org.beehive.gpullama3.tornadovm.scheduling;
+
+public enum SchedulerType {
+ NVIDIA, NON_NVIDIA
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/WorkerGridFactory.java b/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/WorkerGridFactory.java
similarity index 98%
rename from src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/WorkerGridFactory.java
rename to src/main/java/org/beehive/gpullama3/tornadovm/scheduling/WorkerGridFactory.java
index af39c133..4d5a55a5 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/WorkerGridFactory.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/scheduling/WorkerGridFactory.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.tornadovm.layerplanner;
+package org.beehive.gpullama3.tornadovm.scheduling;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
@@ -98,5 +98,4 @@ private static int findOptimalLocalSize(int size) {
}
return optimal;
}
-
}
From 4e4478af24a77b80111618f7e76050f65d44b33f Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Thu, 28 May 2026 15:19:25 +0300
Subject: [PATCH 3/6] Update naming from `ActivationGraph` to
`ActivationTaskGraph` across TornadoVM components
---
.../beehive/gpullama3/tornadovm/layers/Activation.java | 2 +-
.../{ActivationGraph.java => ActivationTaskGraph.java} | 4 ++--
.../tornadovm/plan/BatchPrefillDecodeForwardPlan.java | 6 +++---
.../tornadovm/plan/PrefillDecodeForwardPlan.java | 4 ++--
.../tornadovm/plan/SingleTokenForwardPlan.java | 4 ++--
.../BatchPrefillDecodeForwardPlanComponents.java | 6 +++---
.../components/PrefillDecodeForwardPlanComponents.java | 4 ++--
.../components/SingleTokenForwardPlanComponents.java | 4 ++--
.../components/activation/BatchDecodeActivation.java | 4 ++--
.../components/activation/BatchPrefillActivation.java | 4 ++--
.../components/fp16/DevstralFP16PlanComponents.java | 4 ++--
.../components/fp16/GraniteFP16PlanComponents.java | 4 ++--
.../plan/components/fp16/LlamaFP16PlanComponents.java | 10 +++++-----
.../components/fp16/MistralFP16PlanComponents.java | 4 ++--
.../plan/components/fp16/Phi3FP16PlanComponents.java | 4 ++--
.../plan/components/fp16/Qwen2FP16PlanComponents.java | 4 ++--
.../plan/components/fp16/Qwen3FP16PlanComponents.java | 4 ++--
.../components/q8_0/DevstralQ8_0PlanComponents.java | 4 ++--
.../components/q8_0/GraniteQ8_0PlanComponents.java | 4 ++--
.../plan/components/q8_0/LlamaQ8_0PlanComponents.java | 10 +++++-----
.../components/q8_0/MistralQ8_0PlanComponents.java | 4 ++--
.../plan/components/q8_0/Phi3Q8_0PlanComponents.java | 4 ++--
.../plan/components/q8_0/Qwen2Q8_0PlanComponents.java | 4 ++--
.../plan/components/q8_0/Qwen3Q8_0PlanComponents.java | 4 ++--
24 files changed, 55 insertions(+), 55 deletions(-)
rename src/main/java/org/beehive/gpullama3/tornadovm/layers/{ActivationGraph.java => ActivationTaskGraph.java} (76%)
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java
index d35f8252..b25a613a 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java
@@ -13,7 +13,7 @@
import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
-public class Activation extends AbstractLayer implements ActivationGraph {
+public class Activation extends AbstractLayer implements ActivationTaskGraph {
private final TaskGraph activationTaskGraph;
public Activation(String name, State state, Weights weights, Configuration config) {
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGraph.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationTaskGraph.java
similarity index 76%
rename from src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGraph.java
rename to src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationTaskGraph.java
index c46479bc..8a751c18 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGraph.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationTaskGraph.java
@@ -4,12 +4,12 @@
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
/**
- * Common interface for single activation graphs (standard, decode, batch-prefill variants).
+ * Common interface for a single activation task graph (standard, decode, and batch-prefill variants).
*
* Implemented by {@link Activation} and custom activation wrappers used by
* {@link org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents}.
*/
-public interface ActivationGraph {
+public interface ActivationTaskGraph {
ImmutableTaskGraph getImmutableTaskGraph();
GridScheduler updateGridScheduler(GridScheduler scheduler);
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java
index 1246265c..55202738 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java
@@ -2,7 +2,7 @@
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
-import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
import org.beehive.gpullama3.tornadovm.plan.components.BatchPrefillDecodeForwardPlanComponents;
@@ -42,7 +42,7 @@ public BatchPrefillDecodeForwardPlan(Model model, BatchPrefillDecodeForwardPlanC
List all = new ArrayList<>(2 * N + 3);
GridScheduler scheduler = new GridScheduler();
- ActivationGraph batchAct = components.batchPrefillActivation(batchSize);
+ ActivationTaskGraph batchAct = components.batchPrefillActivation(batchSize);
all.add(batchAct.getImmutableTaskGraph());
batchAct.updateGridScheduler(scheduler);
@@ -50,7 +50,7 @@ public BatchPrefillDecodeForwardPlan(Model model, BatchPrefillDecodeForwardPlanC
all.addAll(batchLayers.getLayerImmutableTaskGraphs());
batchLayers.updateGridScheduler(scheduler);
- ActivationGraph decodeAct = components.batchDecodeActivation(batchLayers.getLastLayerTaskGraphID());
+ ActivationTaskGraph decodeAct = components.batchDecodeActivation(batchLayers.getLastLayerTaskGraphID());
all.add(decodeAct.getImmutableTaskGraph());
decodeAct.updateGridScheduler(scheduler);
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java
index 241f542b..4c91b7fe 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java
@@ -2,7 +2,7 @@
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
-import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
import org.beehive.gpullama3.tornadovm.plan.components.PrefillDecodeForwardPlanComponents;
import org.beehive.gpullama3.tornadovm.plan.layout.PrefillDecodeForwardTaskGraphLayout;
@@ -36,7 +36,7 @@ public PrefillDecodeForwardPlan(Model model, PrefillDecodeForwardPlanComponents
List all = new ArrayList<>(N + 2);
GridScheduler scheduler = new GridScheduler();
- ActivationGraph act = components.decodeActivation();
+ ActivationTaskGraph act = components.decodeActivation();
all.add(act.getImmutableTaskGraph());
act.updateGridScheduler(scheduler);
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/SingleTokenForwardPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/SingleTokenForwardPlan.java
index bdc20a97..4db7ad86 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/SingleTokenForwardPlan.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/SingleTokenForwardPlan.java
@@ -2,7 +2,7 @@
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
-import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents;
import org.beehive.gpullama3.tornadovm.plan.layout.SingleTokenForwardTaskGraphLayout;
@@ -33,7 +33,7 @@ public SingleTokenForwardPlan(Model model, SingleTokenForwardPlanComponents comp
List all = new ArrayList<>(N + 2);
GridScheduler scheduler = new GridScheduler();
- ActivationGraph act = components.standardActivation();
+ ActivationTaskGraph act = components.standardActivation();
all.add(act.getImmutableTaskGraph());
act.updateGridScheduler(scheduler);
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/BatchPrefillDecodeForwardPlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/BatchPrefillDecodeForwardPlanComponents.java
index 10595b27..38b03d30 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/BatchPrefillDecodeForwardPlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/BatchPrefillDecodeForwardPlanComponents.java
@@ -1,6 +1,6 @@
package org.beehive.gpullama3.tornadovm.plan.components;
-import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
@@ -13,9 +13,9 @@
*/
public interface BatchPrefillDecodeForwardPlanComponents extends PrefillDecodeForwardPlanComponents {
- ActivationGraph batchPrefillActivation(int batchSize);
+ ActivationTaskGraph batchPrefillActivation(int batchSize);
- ActivationGraph batchDecodeActivation(String lastBatchLayerId);
+ ActivationTaskGraph batchDecodeActivation(String lastBatchLayerId);
TransformerLayerTaskGraphs batchDecodeLayers();
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java
index 8f0b6cdd..89f53051 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java
@@ -1,7 +1,7 @@
package org.beehive.gpullama3.tornadovm.plan.components;
import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
-import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
/**
@@ -12,7 +12,7 @@
*/
public interface PrefillDecodeForwardPlanComponents extends SingleTokenForwardPlanComponents {
- ActivationGraph decodeActivation();
+ ActivationTaskGraph decodeActivation();
TransformerLayerTaskGraphs prefillDecodeLayers();
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java
index 6ecc2077..37ef3dc7 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java
@@ -1,7 +1,7 @@
package org.beehive.gpullama3.tornadovm.plan.components;
import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
-import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
/**
@@ -13,7 +13,7 @@
*/
public interface SingleTokenForwardPlanComponents {
- ActivationGraph standardActivation();
+ ActivationTaskGraph standardActivation();
TransformerLayerTaskGraphs standardLayers();
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java
index d3617a6c..be6f8d02 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java
@@ -3,7 +3,7 @@
import org.beehive.gpullama3.inference.state.LlamaState;
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
-import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
@@ -21,7 +21,7 @@
* converts the single-token embedding to FP32, then re-persists the KV cache so
* that decode layer 0 can consume it.
*/
-public class BatchDecodeActivation implements ActivationGraph {
+public class BatchDecodeActivation implements ActivationTaskGraph {
private final ImmutableTaskGraph itg;
private final int dim;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java
index 73eec954..17b80cfa 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java
@@ -4,7 +4,7 @@
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels;
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
-import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
@@ -22,7 +22,7 @@
* and runs a no-op {@code batchPassthrough} so TornadoVM has at least one task.
*
*/
-public class BatchPrefillActivation implements ActivationGraph {
+public class BatchPrefillActivation implements ActivationTaskGraph {
private final ImmutableTaskGraph itg;
private final boolean isQ8;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java
index 66f3b9ad..862b8a92 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java
@@ -6,7 +6,7 @@
import org.beehive.gpullama3.model.devstral.DevstralConfiguration;
import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
import org.beehive.gpullama3.tornadovm.layers.Activation;
-import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.DevstralFP16FFNLayers;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
@@ -28,7 +28,7 @@ public DevstralFP16PlanComponents(DevstralState state, Model model) {
this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
}
- @Override public ActivationGraph standardActivation() {
+ @Override public ActivationTaskGraph standardActivation() {
return new Activation("activationUpdate", state, weights, config);
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java
index 5c8a6cfd..13f0b0f9 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java
@@ -6,7 +6,7 @@
import org.beehive.gpullama3.model.granite.GraniteConfiguration;
import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
import org.beehive.gpullama3.tornadovm.layers.ActivationGranite;
-import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.GraniteFP16FFNLayers;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsGraniteFP16Layer;
@@ -28,7 +28,7 @@ public GraniteFP16PlanComponents(GraniteState state, Model model) {
this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
}
- @Override public ActivationGraph standardActivation() {
+ @Override public ActivationTaskGraph standardActivation() {
return new ActivationGranite("activationUpdate", state, weights, config);
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java
index 3fc093c3..23b63c8e 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java
@@ -6,7 +6,7 @@
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
import org.beehive.gpullama3.tornadovm.layers.Activation;
-import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers;
@@ -45,19 +45,19 @@ public LlamaFP16PlanComponents(LlamaState state, Model model) {
// ── Activations ───────────────────────────────────────────────────────────
- @Override public ActivationGraph standardActivation() {
+ @Override public ActivationTaskGraph standardActivation() {
return new Activation("activationUpdate", state, weights, config);
}
- @Override public ActivationGraph decodeActivation() {
+ @Override public ActivationTaskGraph decodeActivation() {
return new Activation("decodeActivation", state, weights, config);
}
- @Override public ActivationGraph batchPrefillActivation(int batchSize) {
+ @Override public ActivationTaskGraph batchPrefillActivation(int batchSize) {
return new BatchPrefillActivation(state, config, batchSize, false);
}
- @Override public ActivationGraph batchDecodeActivation(String lastBatchLayerId) {
+ @Override public ActivationTaskGraph batchDecodeActivation(String lastBatchLayerId) {
return new BatchDecodeActivation(state, config, lastBatchLayerId, false);
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java
index 5cc01685..542b9827 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java
@@ -6,7 +6,7 @@
import org.beehive.gpullama3.model.mistral.MistralConfiguration;
import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
import org.beehive.gpullama3.tornadovm.layers.Activation;
-import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.MistralFP16FFNLayers;
@@ -28,7 +28,7 @@ public MistralFP16PlanComponents(LlamaState state, Model model) {
this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
}
- @Override public ActivationGraph standardActivation() {
+ @Override public ActivationTaskGraph standardActivation() {
return new Activation("activationUpdate", state, weights, config);
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java
index b98eb47d..4a619211 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java
@@ -6,7 +6,7 @@
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
import org.beehive.gpullama3.tornadovm.layers.Activation;
-import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.Phi3FP16FFNLayers;
@@ -28,7 +28,7 @@ public Phi3FP16PlanComponents(Phi3State state, Model model) {
this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
}
- @Override public ActivationGraph standardActivation() {
+ @Override public ActivationTaskGraph standardActivation() {
return new Activation("activationUpdate", state, weights, config);
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java
index 2eaae7a2..0a359579 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java
@@ -6,7 +6,7 @@
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
import org.beehive.gpullama3.tornadovm.layers.Activation;
-import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.Qwen2FP16FFNLayers;
@@ -28,7 +28,7 @@ public Qwen2FP16PlanComponents(Qwen2State state, Model model) {
this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
}
- @Override public ActivationGraph standardActivation() {
+ @Override public ActivationTaskGraph standardActivation() {
return new Activation("activationUpdate", state, weights, config);
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java
index 2ed13ab1..d18c67e6 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java
@@ -6,7 +6,7 @@
import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
import org.beehive.gpullama3.tornadovm.layers.Activation;
-import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
import org.beehive.gpullama3.tornadovm.layers.type.fp16.Qwen3FP16FFNLayers;
@@ -28,7 +28,7 @@ public Qwen3FP16PlanComponents(Qwen3State state, Model model) {
this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
}
- @Override public ActivationGraph standardActivation() {
+ @Override public ActivationTaskGraph standardActivation() {
return new Activation("activationUpdate", state, weights, config);
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java
index 8302f7cb..2019beae 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java
@@ -6,7 +6,7 @@
import org.beehive.gpullama3.model.devstral.DevstralConfiguration;
import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
import org.beehive.gpullama3.tornadovm.layers.Activation;
-import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
import org.beehive.gpullama3.tornadovm.layers.type.q8_0.DevstralQ8_0FFNLayers;
import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
@@ -28,7 +28,7 @@ public DevstralQ8_0PlanComponents(DevstralState state, Model model) {
this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
}
- @Override public ActivationGraph standardActivation() {
+ @Override public ActivationTaskGraph standardActivation() {
return new Activation("activationUpdate", state, weights, config);
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java
index 93f01f10..30c6c28b 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java
@@ -6,7 +6,7 @@
import org.beehive.gpullama3.model.granite.GraniteConfiguration;
import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
import org.beehive.gpullama3.tornadovm.layers.ActivationGranite;
-import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
import org.beehive.gpullama3.tornadovm.layers.type.q8_0.GraniteQ8_0FFNLayers;
import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsGraniteQ8_0Layer;
@@ -28,7 +28,7 @@ public GraniteQ8_0PlanComponents(GraniteState state, Model model) {
this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
}
- @Override public ActivationGraph standardActivation() {
+ @Override public ActivationTaskGraph standardActivation() {
return new ActivationGranite("activationUpdate", state, weights, config);
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java
index bd82a767..aa16bd52 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java
@@ -6,7 +6,7 @@
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
import org.beehive.gpullama3.tornadovm.layers.Activation;
-import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LlamaQ8_0FFNLayers;
@@ -50,19 +50,19 @@ public LlamaQ8_0PlanComponents(LlamaState state, Model model) {
// ── Activations ───────────────────────────────────────────────────────────
- @Override public ActivationGraph standardActivation() {
+ @Override public ActivationTaskGraph standardActivation() {
return new Activation("activationUpdate", state, weights, config);
}
- @Override public ActivationGraph decodeActivation() {
+ @Override public ActivationTaskGraph decodeActivation() {
return new Activation("decodeActivation", state, weights, config);
}
- @Override public ActivationGraph batchPrefillActivation(int batchSize) {
+ @Override public ActivationTaskGraph batchPrefillActivation(int batchSize) {
return new BatchPrefillActivation(state, config, batchSize, true);
}
- @Override public ActivationGraph batchDecodeActivation(String lastBatchLayerId) {
+ @Override public ActivationTaskGraph batchDecodeActivation(String lastBatchLayerId) {
return new BatchDecodeActivation(state, config, lastBatchLayerId, true);
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java
index 4b8a78ec..b377a5f1 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java
@@ -6,7 +6,7 @@
import org.beehive.gpullama3.model.mistral.MistralConfiguration;
import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
import org.beehive.gpullama3.tornadovm.layers.Activation;
-import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
import org.beehive.gpullama3.tornadovm.layers.type.q8_0.MistralQ8_0FFNLayers;
@@ -28,7 +28,7 @@ public MistralQ8_0PlanComponents(LlamaState state, Model model) {
this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
}
- @Override public ActivationGraph standardActivation() {
+ @Override public ActivationTaskGraph standardActivation() {
return new Activation("activationUpdate", state, weights, config);
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java
index 57cd9508..9c9735ac 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java
@@ -6,7 +6,7 @@
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
import org.beehive.gpullama3.tornadovm.layers.Activation;
-import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Phi3Q8_0FFNLayers;
@@ -28,7 +28,7 @@ public Phi3Q8_0PlanComponents(Phi3State state, Model model) {
this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
}
- @Override public ActivationGraph standardActivation() {
+ @Override public ActivationTaskGraph standardActivation() {
return new Activation("activationUpdate", state, weights, config);
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java
index 534eedc5..d4a22184 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java
@@ -6,7 +6,7 @@
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
import org.beehive.gpullama3.tornadovm.layers.Activation;
-import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen2Q8_0FFNLayers;
@@ -28,7 +28,7 @@ public Qwen2Q8_0PlanComponents(Qwen2State state, Model model) {
this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
}
- @Override public ActivationGraph standardActivation() {
+ @Override public ActivationTaskGraph standardActivation() {
return new Activation("activationUpdate", state, weights, config);
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java
index 3c615b39..a033849e 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java
@@ -6,7 +6,7 @@
import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
import org.beehive.gpullama3.tornadovm.layers.Activation;
-import org.beehive.gpullama3.tornadovm.layers.ActivationGraph;
+import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen3Q8_0FFNLayers;
@@ -28,7 +28,7 @@ public Qwen3Q8_0PlanComponents(Qwen3State state, Model model) {
this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
}
- @Override public ActivationGraph standardActivation() {
+ @Override public ActivationTaskGraph standardActivation() {
return new Activation("activationUpdate", state, weights, config);
}
From ea478f871985e5e150aa0a060480d08e2221cf5f Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Thu, 28 May 2026 15:25:40 +0300
Subject: [PATCH 4/6] Rename `AbstractFFNLayers` to
`AbstractTransformerLayerTaskGraphs` and `AbstractLogitsLayer` to
`AbstractLogitsTaskGraph`, updating all references to improve clarity and
align with naming conventions.
---
...ayer.java => AbstractLogitsTaskGraph.java} | 6 ++---
...> AbstractTransformerLayerTaskGraphs.java} | 24 +++++++++----------
.../layers/TransformerLayerTaskGraphs.java | 2 +-
.../type/fp16/DevstralFP16FFNLayers.java | 4 ++--
.../type/fp16/GraniteFP16FFNLayers.java | 4 ++--
.../layers/type/fp16/LlamaFP16FFNLayers.java | 4 ++--
.../layers/type/fp16/LogitsFP16Layer.java | 4 ++--
.../type/fp16/MistralFP16FFNLayers.java | 4 ++--
.../layers/type/fp16/Phi3FP16FFNLayers.java | 4 ++--
.../layers/type/fp16/Qwen2FP16FFNLayers.java | 4 ++--
.../layers/type/fp16/Qwen3FP16FFNLayers.java | 4 ++--
.../type/q8_0/DevstralQ8_0FFNLayers.java | 4 ++--
.../type/q8_0/GraniteQ8_0FFNLayers.java | 4 ++--
.../layers/type/q8_0/LlamaQ8_0FFNLayers.java | 4 ++--
.../layers/type/q8_0/LogitsQ8_0Layer.java | 4 ++--
.../type/q8_0/MistralQ8_0FFNLayers.java | 4 ++--
.../layers/type/q8_0/Phi3Q8_0FFNLayers.java | 4 ++--
.../layers/type/q8_0/Qwen2Q8_0FFNLayers.java | 4 ++--
.../layers/type/q8_0/Qwen3Q8_0FFNLayers.java | 4 ++--
.../plan/BatchPrefillDecodeForwardPlan.java | 4 ++--
.../plan/PrefillDecodeForwardPlan.java | 4 ++--
.../plan/SingleTokenForwardPlan.java | 4 ++--
.../PrefillDecodeForwardPlanComponents.java | 4 ++--
.../SingleTokenForwardPlanComponents.java | 4 ++--
.../fp16/DevstralFP16PlanComponents.java | 4 ++--
.../fp16/GraniteFP16PlanComponents.java | 4 ++--
.../fp16/LlamaFP16PlanComponents.java | 6 ++---
.../fp16/MistralFP16PlanComponents.java | 4 ++--
.../fp16/Phi3FP16PlanComponents.java | 4 ++--
.../fp16/Qwen2FP16PlanComponents.java | 4 ++--
.../fp16/Qwen3FP16PlanComponents.java | 4 ++--
.../q8_0/DevstralQ8_0PlanComponents.java | 4 ++--
.../q8_0/GraniteQ8_0PlanComponents.java | 4 ++--
.../q8_0/LlamaQ8_0PlanComponents.java | 6 ++---
.../q8_0/MistralQ8_0PlanComponents.java | 4 ++--
.../q8_0/Phi3Q8_0PlanComponents.java | 4 ++--
.../q8_0/Qwen2Q8_0PlanComponents.java | 4 ++--
.../q8_0/Qwen3Q8_0PlanComponents.java | 4 ++--
38 files changed, 88 insertions(+), 88 deletions(-)
rename src/main/java/org/beehive/gpullama3/tornadovm/layers/{AbstractLogitsLayer.java => AbstractLogitsTaskGraph.java} (86%)
rename src/main/java/org/beehive/gpullama3/tornadovm/layers/{AbstractFFNLayers.java => AbstractTransformerLayerTaskGraphs.java} (70%)
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsLayer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsTaskGraph.java
similarity index 86%
rename from src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsLayer.java
rename to src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsTaskGraph.java
index afead80d..026c3d90 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsLayer.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLogitsTaskGraph.java
@@ -9,20 +9,20 @@
import uk.ac.manchester.tornado.api.TaskGraph;
/**
- * Abstract base for all logits layers (final vocabulary projection step).
+ * Abstract base for all logits task graphs (final vocabulary projection step).
*
* Holds the shared fields and calls the protected buildLogitsTaskGraph() hook once
* during construction. Subclasses implement buildLogitsTaskGraph() to define the
* quantization-specific task sequence; Granite variants override it to swap in
* their scaled kernel.
*/
-public abstract class AbstractLogitsLayer extends AbstractLayer {
+public abstract class AbstractLogitsTaskGraph extends AbstractLayer {
protected final String lastTaskGraphID;
protected final SchedulerType schedulerType;
private final TaskGraph logitsTaskGraph;
- protected AbstractLogitsLayer(String name, State state, Weights weights, Configuration config,
+ protected AbstractLogitsTaskGraph(String name, State state, Weights weights, Configuration config,
String lastTaskGraphID, SchedulerType schedulerType) {
super(name, state, weights, config);
this.lastTaskGraphID = lastTaskGraphID;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractTransformerLayerTaskGraphs.java
similarity index 70%
rename from src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java
rename to src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractTransformerLayerTaskGraphs.java
index d0b3656e..14a71e88 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractTransformerLayerTaskGraphs.java
@@ -11,14 +11,14 @@
import java.util.stream.IntStream;
/**
- * Abstract base class for all FFN (Feed-Forward Network) layer implementations.
+ * Abstract base class for all transformer-layer task graph implementations.
* Extended by model and quantization-specific subclasses that provide specific implementations.
*/
-public abstract class AbstractFFNLayers extends AbstractLayer implements TransformerLayerTaskGraphs {
+public abstract class AbstractTransformerLayerTaskGraphs extends AbstractLayer implements TransformerLayerTaskGraphs {
/**
- * List of TornadoVM {@link ImmutableTaskGraph}s, one per FFN layer.
- * Build by {@link #setupFFNLayers()}.
+ * List of TornadoVM {@link ImmutableTaskGraph}s, one per transformer layer.
+ * Built by {@link #setupFFNLayers()}.
*/
private List ffnLayerITGs;
protected final W weights;
@@ -27,7 +27,7 @@ public abstract class AbstractFFNLayers getFFNLayerImmutableTaskGraphs() {
}
/**
- * Returns the TaskGraph ID of the last FFN layer.
- * Used by the logits layer to chain its consumeFromDevice call.
+ * Returns the task graph ID of the last transformer layer.
+ * Used by the logits task graph to chain its consumeFromDevice call.
*/
public String getLastFFNLayerTaskGraphID() {
return lastFFNLayerTaskGraphID;
@@ -91,4 +91,4 @@ public String getLastFFNLayerTaskGraphID() {
protected boolean shouldUseFinalNormalization() {
return schedulerType == SchedulerType.NON_NVIDIA;
}
-}
\ No newline at end of file
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/TransformerLayerTaskGraphs.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/TransformerLayerTaskGraphs.java
index 245b3f40..e815574d 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/TransformerLayerTaskGraphs.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/TransformerLayerTaskGraphs.java
@@ -8,7 +8,7 @@
/**
* Interface for a group of N transformer-layer {@link uk.ac.manchester.tornado.api.TaskGraph} (standard or prefill-decode variants).
*
- * Implemented by {@link AbstractFFNLayers} and its subclasses.
+ * Implemented by {@link AbstractTransformerLayerTaskGraphs} and its subclasses.
*/
public interface TransformerLayerTaskGraphs {
List getFFNLayerImmutableTaskGraphs();
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java
index e18c2a3a..0597a1c7 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java
@@ -7,7 +7,7 @@
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
-import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.WorkerGrid;
@@ -17,7 +17,7 @@
* FP16 FFN layers for Devstral 2 models.
* Uses precomputed RoPE frequencies (YaRN scaling) instead of on-the-fly computation.
*/
-public class DevstralFP16FFNLayers extends AbstractFFNLayers {
+public class DevstralFP16FFNLayers extends AbstractTransformerLayerTaskGraphs {
public DevstralFP16FFNLayers(String taskGraph, State state, LlamaTornadoWeights weights, DevstralConfiguration config, SchedulerType schedulerType) {
super(taskGraph, state, weights, config, schedulerType);
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java
index 46aad290..2d7234b8 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java
@@ -9,13 +9,13 @@
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
-import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
-public class GraniteFP16FFNLayers extends AbstractFFNLayers {
+public class GraniteFP16FFNLayers extends AbstractTransformerLayerTaskGraphs {
public GraniteFP16FFNLayers(String taskGraph, State state, GraniteTornadoWeights weights, GraniteConfiguration config, SchedulerType schedulerType) {
super(taskGraph, state, weights, config, schedulerType);
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java
index 928954e9..21bae5a6 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java
@@ -7,13 +7,13 @@
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
-import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
-public class LlamaFP16FFNLayers extends AbstractFFNLayers {
+public class LlamaFP16FFNLayers extends AbstractTransformerLayerTaskGraphs {
public LlamaFP16FFNLayers(String taskGraph, State state, LlamaTornadoWeights weights, LlamaConfiguration config, SchedulerType schedulerType) {
super(taskGraph, state, weights, config, schedulerType);
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java
index e3974ade..1295db09 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java
@@ -9,13 +9,13 @@
import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
-import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
-public class LogitsFP16Layer extends AbstractLogitsLayer {
+public class LogitsFP16Layer extends AbstractLogitsTaskGraph {
public LogitsFP16Layer(String name, State state, Weights weights, Configuration config,
String lastTaskGraphID, SchedulerType schedulerType) {
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java
index aa26bb69..cd03b2ba 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/MistralFP16FFNLayers.java
@@ -7,13 +7,13 @@
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
-import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
-public class MistralFP16FFNLayers extends AbstractFFNLayers {
+public class MistralFP16FFNLayers extends AbstractTransformerLayerTaskGraphs {
public MistralFP16FFNLayers(String taskGraph, State state, LlamaTornadoWeights weights, MistralConfiguration config, SchedulerType schedulerType) {
super(taskGraph, state, weights, config, schedulerType);
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java
index 1d443ce3..e8ca1e54 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java
@@ -7,7 +7,7 @@
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
-import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.WorkerGrid;
@@ -22,7 +22,7 @@
*
* Works directly with Phi3State to access and mutate Phi3-specific state fields.
*/
-public class Phi3FP16FFNLayers extends AbstractFFNLayers {
+public class Phi3FP16FFNLayers extends AbstractTransformerLayerTaskGraphs {
// Typed references to Phi3-specific state and config
private final Phi3State phi3State;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java
index 16e35a39..747b7213 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java
@@ -8,7 +8,7 @@
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
-import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.WorkerGrid;
@@ -25,7 +25,7 @@
*
* Works directly with Qwen2State to access and mutate Qwen2-specific state fields.
*/
-public class Qwen2FP16FFNLayers extends AbstractFFNLayers {
+public class Qwen2FP16FFNLayers extends AbstractTransformerLayerTaskGraphs {
// Typed reference to Qwen2-specific state
private final Qwen2State qwen2State;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java
index 0530e7c0..3b298249 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java
@@ -7,7 +7,7 @@
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
-import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.WorkerGrid;
@@ -21,7 +21,7 @@
*
* Works directly with Qwen3State to access and mutate Qwen3-specific state fields like tempQcur and tempKcur.
*/
-public class Qwen3FP16FFNLayers extends AbstractFFNLayers {
+public class Qwen3FP16FFNLayers extends AbstractTransformerLayerTaskGraphs {
// Typed reference to Qwen3-specific state
private final Qwen3State qwen3State;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java
index 0d717d2e..26686738 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java
@@ -6,7 +6,7 @@
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
-import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.WorkerGrid;
@@ -16,7 +16,7 @@
* Q8_0 FFN layers for Devstral 2 models.
* Uses precomputed RoPE frequencies (YaRN scaling) instead of on-the-fly computation.
*/
-public class DevstralQ8_0FFNLayers extends AbstractFFNLayers {
+public class DevstralQ8_0FFNLayers extends AbstractTransformerLayerTaskGraphs {
public DevstralQ8_0FFNLayers(String taskGraphName, State state, LlamaTornadoWeights weights, DevstralConfiguration config, SchedulerType schedulerType) {
super(taskGraphName, state, weights, config, schedulerType);
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java
index 1aca4714..78f834f4 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java
@@ -7,13 +7,13 @@
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
-import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
-public class GraniteQ8_0FFNLayers extends AbstractFFNLayers {
+public class GraniteQ8_0FFNLayers extends AbstractTransformerLayerTaskGraphs {
public GraniteQ8_0FFNLayers(String taskGraphName, GraniteState state, GraniteTornadoWeights weights, GraniteConfiguration config, SchedulerType schedulerType) {
super(taskGraphName, state, weights, config, schedulerType);
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java
index f5683904..8cab58ee 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java
@@ -6,13 +6,13 @@
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
-import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
-public class LlamaQ8_0FFNLayers extends AbstractFFNLayers {
+public class LlamaQ8_0FFNLayers extends AbstractTransformerLayerTaskGraphs {
public LlamaQ8_0FFNLayers(String taskGraphName, LlamaState state, LlamaTornadoWeights weights, LlamaConfiguration config, SchedulerType schedulerType) {
super(taskGraphName, state, weights, config, schedulerType);
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java
index bf66118a..725dfd24 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java
@@ -9,13 +9,13 @@
import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
-import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
-public class LogitsQ8_0Layer extends AbstractLogitsLayer {
+public class LogitsQ8_0Layer extends AbstractLogitsTaskGraph {
public LogitsQ8_0Layer(String name, State state, Weights weights, Configuration config,
String lastTaskGraphID, SchedulerType schedulerType) {
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java
index 4a6f3151..68073c63 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/MistralQ8_0FFNLayers.java
@@ -6,13 +6,13 @@
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
-import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
-public class MistralQ8_0FFNLayers extends AbstractFFNLayers {
+public class MistralQ8_0FFNLayers extends AbstractTransformerLayerTaskGraphs {
public MistralQ8_0FFNLayers(String taskGraphName, LlamaState state, LlamaTornadoWeights weights, MistralConfiguration config, SchedulerType schedulerType) {
super(taskGraphName, state, weights, config, schedulerType);
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java
index 0ffe25d6..9739c14f 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java
@@ -7,7 +7,7 @@
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
-import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.WorkerGrid;
@@ -21,7 +21,7 @@
*
* Works directly with Phi3State to access and mutate Phi3-specific state fields.
*/
-public class Phi3Q8_0FFNLayers extends AbstractFFNLayers {
+public class Phi3Q8_0FFNLayers extends AbstractTransformerLayerTaskGraphs {
// Typed reference to Phi3-specific state
private final Phi3State phi3State;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java
index 936c3573..ec96364e 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java
@@ -8,7 +8,7 @@
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
-import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.WorkerGrid;
@@ -28,7 +28,7 @@
*
* Works directly with Qwen2State to access and mutate Qwen2-specific state fields.
*/
-public class Qwen2Q8_0FFNLayers extends AbstractFFNLayers {
+public class Qwen2Q8_0FFNLayers extends AbstractTransformerLayerTaskGraphs {
// Typed reference to Qwen2-specific state
private final Qwen2State qwen2State;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java
index 1696644d..9397dc2d 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java
@@ -7,7 +7,7 @@
import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory;
import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType;
-import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.AbstractTransformerLayerTaskGraphs;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.WorkerGrid;
@@ -25,7 +25,7 @@
* Works directly with Qwen3State to access and mutate Qwen3-specific state fields
* like tempQcur and tempKcur.
*/
-public class Qwen3Q8_0FFNLayers extends AbstractFFNLayers {
+public class Qwen3Q8_0FFNLayers extends AbstractTransformerLayerTaskGraphs {
// Typed reference to Qwen3-specific state
private final Qwen3State qwen3State;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java
index 55202738..0126c7ef 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/BatchPrefillDecodeForwardPlan.java
@@ -1,7 +1,7 @@
package org.beehive.gpullama3.tornadovm.plan;
import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
@@ -58,7 +58,7 @@ public BatchPrefillDecodeForwardPlan(Model model, BatchPrefillDecodeForwardPlanC
all.addAll(decodeLayers.getFFNLayerImmutableTaskGraphs());
decodeLayers.updateGridScheduler(scheduler);
- AbstractLogitsLayer logits = components.decodeLogits(decodeLayers.getLastFFNLayerTaskGraphID());
+ AbstractLogitsTaskGraph logits = components.decodeLogits(decodeLayers.getLastFFNLayerTaskGraphID());
all.add(logits.getImmutableTaskGraph());
logits.updateGridScheduler(scheduler);
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java
index 4c91b7fe..d526080d 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/PrefillDecodeForwardPlan.java
@@ -1,7 +1,7 @@
package org.beehive.gpullama3.tornadovm.plan;
import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
import org.beehive.gpullama3.tornadovm.plan.components.PrefillDecodeForwardPlanComponents;
@@ -44,7 +44,7 @@ public PrefillDecodeForwardPlan(Model model, PrefillDecodeForwardPlanComponents
all.addAll(layers.getFFNLayerImmutableTaskGraphs());
layers.updateGridScheduler(scheduler);
- AbstractLogitsLayer logits = components.decodeLogits(layers.getLastFFNLayerTaskGraphID());
+ AbstractLogitsTaskGraph logits = components.decodeLogits(layers.getLastFFNLayerTaskGraphID());
all.add(logits.getImmutableTaskGraph());
logits.updateGridScheduler(scheduler);
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/SingleTokenForwardPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/SingleTokenForwardPlan.java
index 4db7ad86..2eca71bd 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/SingleTokenForwardPlan.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/SingleTokenForwardPlan.java
@@ -1,7 +1,7 @@
package org.beehive.gpullama3.tornadovm.plan;
import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents;
@@ -41,7 +41,7 @@ public SingleTokenForwardPlan(Model model, SingleTokenForwardPlanComponents comp
all.addAll(layers.getFFNLayerImmutableTaskGraphs());
layers.updateGridScheduler(scheduler);
- AbstractLogitsLayer logits = components.standardLogits(layers.getLastFFNLayerTaskGraphID());
+ AbstractLogitsTaskGraph logits = components.standardLogits(layers.getLastFFNLayerTaskGraphID());
all.add(logits.getImmutableTaskGraph());
logits.updateGridScheduler(scheduler);
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java
index 89f53051..dad27f94 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/PrefillDecodeForwardPlanComponents.java
@@ -1,6 +1,6 @@
package org.beehive.gpullama3.tornadovm.plan.components;
-import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
@@ -16,5 +16,5 @@ public interface PrefillDecodeForwardPlanComponents extends SingleTokenForwardPl
TransformerLayerTaskGraphs prefillDecodeLayers();
- AbstractLogitsLayer decodeLogits(String previousGraphId);
+ AbstractLogitsTaskGraph decodeLogits(String previousGraphId);
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java
index 37ef3dc7..596156f4 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/SingleTokenForwardPlanComponents.java
@@ -1,6 +1,6 @@
package org.beehive.gpullama3.tornadovm.plan.components;
-import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
@@ -17,5 +17,5 @@ public interface SingleTokenForwardPlanComponents {
TransformerLayerTaskGraphs standardLayers();
- AbstractLogitsLayer standardLogits(String previousGraphId);
+ AbstractLogitsTaskGraph standardLogits(String previousGraphId);
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java
index 862b8a92..00f49f28 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/DevstralFP16PlanComponents.java
@@ -4,7 +4,7 @@
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.devstral.DevstralConfiguration;
-import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.Activation;
import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
@@ -36,7 +36,7 @@ public DevstralFP16PlanComponents(DevstralState state, Model model) {
return new DevstralFP16FFNLayers("devstralFFN", state, weights, config, schedulerType);
}
- @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) {
return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType);
}
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java
index 13f0b0f9..5f47ec94 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/GraniteFP16PlanComponents.java
@@ -4,7 +4,7 @@
import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.granite.GraniteConfiguration;
-import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.ActivationGranite;
import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
@@ -36,7 +36,7 @@ public GraniteFP16PlanComponents(GraniteState state, Model model) {
return new GraniteFP16FFNLayers("graniteFFN", state, weights, config, schedulerType);
}
- @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) {
return new LogitsGraniteFP16Layer("logits", state, weights, config, previousGraphId, schedulerType);
}
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java
index 23b63c8e..350d1fd0 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java
@@ -4,7 +4,7 @@
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
-import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.Activation;
import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs;
@@ -81,11 +81,11 @@ public LlamaFP16PlanComponents(LlamaState state, Model model) {
// ── Logits layers ─────────────────────────────────────────────────────────
- @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) {
return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType);
}
- @Override public AbstractLogitsLayer decodeLogits(String previousGraphId) {
+ @Override public AbstractLogitsTaskGraph decodeLogits(String previousGraphId) {
return new LogitsFP16LayerDecode("logits", state, weights, config, previousGraphId, schedulerType);
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java
index 542b9827..fdd4d43f 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/MistralFP16PlanComponents.java
@@ -4,7 +4,7 @@
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.mistral.MistralConfiguration;
-import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.Activation;
import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
@@ -36,7 +36,7 @@ public MistralFP16PlanComponents(LlamaState state, Model model) {
return new MistralFP16FFNLayers("mistralFFN", state, weights, config, schedulerType);
}
- @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) {
return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType);
}
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java
index 4a619211..fb21c4c2 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Phi3FP16PlanComponents.java
@@ -4,7 +4,7 @@
import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
-import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.Activation;
import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
@@ -36,7 +36,7 @@ public Phi3FP16PlanComponents(Phi3State state, Model model) {
return new Phi3FP16FFNLayers("phi3FFN", state, weights, config, schedulerType);
}
- @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) {
return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType);
}
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java
index 0a359579..03fc3c88 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen2FP16PlanComponents.java
@@ -4,7 +4,7 @@
import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
-import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.Activation;
import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
@@ -36,7 +36,7 @@ public Qwen2FP16PlanComponents(Qwen2State state, Model model) {
return new Qwen2FP16FFNLayers("qwen2FFN", state, weights, config, schedulerType);
}
- @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) {
return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType);
}
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java
index d18c67e6..a11270bd 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java
@@ -4,7 +4,7 @@
import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
-import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.Activation;
import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
@@ -36,7 +36,7 @@ public Qwen3FP16PlanComponents(Qwen3State state, Model model) {
return new Qwen3FP16FFNLayers("qwen3FFN", state, weights, config, schedulerType);
}
- @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) {
return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType);
}
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java
index 2019beae..8589ee05 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/DevstralQ8_0PlanComponents.java
@@ -4,7 +4,7 @@
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.devstral.DevstralConfiguration;
-import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.Activation;
import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
@@ -36,7 +36,7 @@ public DevstralQ8_0PlanComponents(DevstralState state, Model model) {
return new DevstralQ8_0FFNLayers("devstralFFN", state, weights, config, schedulerType);
}
- @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) {
return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType);
}
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java
index 30c6c28b..03e0ccb9 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/GraniteQ8_0PlanComponents.java
@@ -4,7 +4,7 @@
import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.granite.GraniteConfiguration;
-import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.ActivationGranite;
import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
@@ -36,7 +36,7 @@ public GraniteQ8_0PlanComponents(GraniteState state, Model model) {
return new GraniteQ8_0FFNLayers("graniteFFN", state, weights, config, schedulerType);
}
- @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) {
return new LogitsGraniteQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType);
}
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java
index aa16bd52..f7b8c0c2 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java
@@ -4,7 +4,7 @@
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
-import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.Activation;
import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs;
@@ -86,11 +86,11 @@ public LlamaQ8_0PlanComponents(LlamaState state, Model model) {
// ── Logits layers ─────────────────────────────────────────────────────────
- @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) {
return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType);
}
- @Override public AbstractLogitsLayer decodeLogits(String previousGraphId) {
+ @Override public AbstractLogitsTaskGraph decodeLogits(String previousGraphId) {
return new LogitsQ8_0LayerDecode("logits", state, weights, config, previousGraphId, schedulerType);
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java
index b377a5f1..d4d91993 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/MistralQ8_0PlanComponents.java
@@ -4,7 +4,7 @@
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.mistral.MistralConfiguration;
-import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.Activation;
import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
@@ -36,7 +36,7 @@ public MistralQ8_0PlanComponents(LlamaState state, Model model) {
return new MistralQ8_0FFNLayers("mistralFFN", state, weights, config, schedulerType);
}
- @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) {
return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType);
}
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java
index 9c9735ac..7974a3fe 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Phi3Q8_0PlanComponents.java
@@ -4,7 +4,7 @@
import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
-import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.Activation;
import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
@@ -36,7 +36,7 @@ public Phi3Q8_0PlanComponents(Phi3State state, Model model) {
return new Phi3Q8_0FFNLayers("phi3FFN", state, weights, config, schedulerType);
}
- @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) {
return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType);
}
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java
index d4a22184..fb739835 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen2Q8_0PlanComponents.java
@@ -4,7 +4,7 @@
import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
-import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.Activation;
import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
@@ -36,7 +36,7 @@ public Qwen2Q8_0PlanComponents(Qwen2State state, Model model) {
return new Qwen2Q8_0FFNLayers("qwen2FFN", state, weights, config, schedulerType);
}
- @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) {
return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType);
}
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java
index a033849e..8ef1881d 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java
@@ -4,7 +4,7 @@
import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
-import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsLayer;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.Activation;
import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph;
import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs;
@@ -36,7 +36,7 @@ public Qwen3Q8_0PlanComponents(Qwen3State state, Model model) {
return new Qwen3Q8_0FFNLayers("qwen3FFN", state, weights, config, schedulerType);
}
- @Override public AbstractLogitsLayer standardLogits(String previousGraphId) {
+ @Override public AbstractLogitsTaskGraph standardLogits(String previousGraphId) {
return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType);
}
}
From e20ebc5c35214980d90281cdca2f0b3f20b4d9ab Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Thu, 28 May 2026 15:28:42 +0300
Subject: [PATCH 5/6] Refactor FFN layer comments to `transformer-layer task
graphs`, aligning with updated naming conventions.
---
.../org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java | 2 +-
.../tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java | 2 +-
.../gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java | 2 +-
.../tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java | 2 +-
.../tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java | 2 +-
.../layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java | 2 +-
.../type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java | 2 +-
.../layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java | 2 +-
.../tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java | 2 +-
.../gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java | 2 +-
.../tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java | 2 +-
.../tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java | 2 +-
.../layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java | 2 +-
.../type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java | 2 +-
.../layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java | 2 +-
.../tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java | 2 +-
.../tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java | 2 +-
17 files changed, 17 insertions(+), 17 deletions(-)
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java
index f34f5777..f73a55a3 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java
@@ -8,7 +8,7 @@
import uk.ac.manchester.tornado.api.TaskGraph;
/**
- * Abstract base class for Activations, FFN Layers, and Logits.
+ * Abstract base class for activation, transformer-layer, and logits task graphs.
*/
public abstract class AbstractLayer {
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java
index 0597a1c7..7b709797 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/DevstralFP16FFNLayers.java
@@ -14,7 +14,7 @@
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
/**
- * FP16 FFN layers for Devstral 2 models.
+ * FP16 transformer-layer task graphs for Devstral 2 models.
* Uses precomputed RoPE frequencies (YaRN scaling) instead of on-the-fly computation.
*/
public class DevstralFP16FFNLayers extends AbstractTransformerLayerTaskGraphs {
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java
index e8ca1e54..b06fb942 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java
@@ -14,7 +14,7 @@
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
/**
- * Phi3FP16FFNLayers: FP16 FFN layers for Phi3 with Group Query Attention (GQA) support.
+ * Phi3FP16FFNLayers: FP16 transformer-layer task graphs for Phi3 with Group Query Attention (GQA) support.
*
* Key Differences from Qwen2/Qwen3: - Uses combined QKV matrix (wqkv) instead of separate Q, K, V matrices - Includes splitQKV task to separate combined buffer - Uses ropeRotationPhi3 kernel for
* position embeddings - FFN uses single wUp matrix that outputs both Gate and Up (2 * hiddenDim) - Includes splitGateUpAndSiLU task for FFN activation - Uses wDown for final FFN projection - No Q, K,
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java
index 747b7213..4b3bbca3 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java
@@ -17,7 +17,7 @@
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
/**
- * Qwen2FP16FFNLayers: FP16 FFN layers for Qwen2 with Group Query Attention (GQA) support.
+ * Qwen2FP16FFNLayers: FP16 transformer-layer task graphs for Qwen2 with Group Query Attention (GQA) support.
*
* Key Differences from Qwen3: - No tempQcur/tempKcur fields in Qwen2State - Includes bias terms for Q, K, V projections - Standard GQA (no parallel offset RMSNorm) - Uses
* Qwen2Kernels::processHeadsFlashAttention for attention computation - Uses Qwen3Kernels::ropeRotation for position embeddings - Simpler matrix dimensions (uses config.dim() and config.kvDim()
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java
index 3b298249..d0bee6a9 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java
@@ -14,7 +14,7 @@
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
/**
- * Qwen3FP16FFNLayers: FP16 FFN layers for Qwen3 with Group Query Attention (GQA) support.
+ * Qwen3FP16FFNLayers: FP16 transformer-layer task graphs for Qwen3 with Group Query Attention (GQA) support.
*
* Key Differences from Llama: - Supports GQA with separate KV heads (nHeadKv) - Uses Qwen3Kernels for RMSNorm with parallel offset - Custom RoPE rotation for Qwen3 - Different attention computation
* due to GQA structure
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java
index ea0220f6..28d8a1f6 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersDecode.java
@@ -9,7 +9,7 @@
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
/**
- * Decode FFN layers of the unified batched prefill-decode plan
+ * Decode transformer-layer task graphs of the unified batched prefill-decode plan
* ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode}).
*
* Overrides data-transfer declarations so that all cross-graph boundaries use
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java
index c1167796..f3f628b3 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/LlamaFP16FFNLayersPrefillDecode.java
@@ -9,7 +9,7 @@
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
/**
- * Decode FFN layers for the single-token prefill/decode plan
+ * Decode transformer-layer task graphs for the single-token prefill/decode plan
* ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanPrefillDecode}).
*
*
Combines two concerns:
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java
index 44ce4b35..dcfe626e 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/LlamaFP16LayersBatchPrefill.java
@@ -17,7 +17,7 @@
import java.util.stream.IntStream;
/**
- * Prefill FFN layers with batching for the unified batched prefill-decode plan
+ * Batched-prefill transformer-layer task graphs for the unified batched prefill-decode plan
* ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode}).
*
* One {@link ImmutableTaskGraph} per transformer layer, each processing
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java
index 26686738..b8a9fafa 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/DevstralQ8_0FFNLayers.java
@@ -13,7 +13,7 @@
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
/**
- * Q8_0 FFN layers for Devstral 2 models.
+ * Q8_0 transformer-layer task graphs for Devstral 2 models.
* Uses precomputed RoPE frequencies (YaRN scaling) instead of on-the-fly computation.
*/
public class DevstralQ8_0FFNLayers extends AbstractTransformerLayerTaskGraphs {
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java
index 9739c14f..9831486d 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java
@@ -14,7 +14,7 @@
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
/**
- * Phi3Q8_0FFNLayers: Q8_0-quantized FFN layers for Phi3 with Group Query Attention (GQA) support.
+ * Phi3Q8_0FFNLayers: Q8_0 transformer-layer task graphs for Phi3 with Group Query Attention (GQA) support.
*
* Key Differences from Phi3FP16FFNLayers: - Uses Q8_0-quantized weights (getQuants() and getScales()) - Same attention and RoPE kernels as FP16 version - 8-bit integer computations with
* dequantization - 2x memory compression vs FP16 - Same combined QKV and gate/up FFN structure
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java
index ec96364e..9ba6b363 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java
@@ -17,7 +17,7 @@
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
/**
- * Qwen2Q8_0FFNLayers: Q8_0-quantized FFN layers for Qwen2 with Group Query Attention (GQA) support.
+ * Qwen2Q8_0FFNLayers: Q8_0 transformer-layer task graphs for Qwen2 with Group Query Attention (GQA) support.
*
* Key Differences from Qwen2FP16FFNLayers:
* - Uses Q8_0-quantized weights (getQuants() and getScales())
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java
index 9397dc2d..b27224bd 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java
@@ -14,7 +14,7 @@
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
/**
- * Qwen3Q8_0FFNLayers: Q8_0-quantized FFN layers for Qwen3 with Group Query Attention (GQA) support.
+ * Qwen3Q8_0FFNLayers: Q8_0 transformer-layer task graphs for Qwen3 with Group Query Attention (GQA) support.
*
* Key Differences from Qwen3FP16FFNLayers:
* - Uses Q8_0-quantized weights (getQuants() and getScales())
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java
index 0b2f13cb..d1ee5d14 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersDecode.java
@@ -9,7 +9,7 @@
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
/**
- * Decode FFN layers for the unified batched prefill-decode plan
+ * Decode transformer-layer task graphs for the unified batched prefill-decode plan
* ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanBatchPrefillDecode}).
*
* Layer 0 consumes the KV cache from device (passed through by the decode activation
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java
index 388853f7..6c4c22b6 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/LlamaQ8_0FFNLayersPrefillDecode.java
@@ -8,7 +8,7 @@
import uk.ac.manchester.tornado.api.TaskGraph;
/**
- * Decode FFN layers for the single-token prefill/decode plan
+ * Decode transformer-layer task graphs for the single-token prefill/decode plan
* ({@link org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanPrefillDecode}).
*
*
Layer 0 delegates to {@link LlamaQ8_0FFNLayers#configureLayerDataTransfers} which
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java
index 14c30462..c0904ef3 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/LlamaQ8_0LayersBatchPrefill.java
@@ -17,7 +17,7 @@
import java.util.stream.IntStream;
/**
- * Prefill FFN layers with batching for the unified batched prefill-decode plan (Q8_0).
+ * Batched-prefill transformer-layer task graphs for the unified batched prefill-decode plan (Q8_0).
*
*
Mirrors {@link org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.LlamaFP16LayersBatchPrefill}
* but uses Q8_0 kernels with inline dequantization. Key differences from the FP16 path:
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java
index 350d1fd0..67392338 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/LlamaFP16PlanComponents.java
@@ -61,7 +61,7 @@ public LlamaFP16PlanComponents(LlamaState state, Model model) {
return new BatchDecodeActivation(state, config, lastBatchLayerId, false);
}
- // ── FFN layer groups ──────────────────────────────────────────────────────
+ // ── Transformer layer task graphs ──────────────────────────────────────────────────────
@Override public TransformerLayerTaskGraphs standardLayers() {
return new LlamaFP16FFNLayers("llamaFFN", state, weights, config, schedulerType);
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java
index f7b8c0c2..bf9a5f29 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/LlamaQ8_0PlanComponents.java
@@ -66,7 +66,7 @@ public LlamaQ8_0PlanComponents(LlamaState state, Model model) {
return new BatchDecodeActivation(state, config, lastBatchLayerId, true);
}
- // ── FFN layer groups ──────────────────────────────────────────────────────
+ // ── Transformer layer task graphs ──────────────────────────────────────────────────────
@Override public TransformerLayerTaskGraphs standardLayers() {
return new LlamaQ8_0FFNLayers("llamaFFN", state, weights, config, schedulerType);
From c7522d16af0e0baf2f3d35fa28917a7aad32ca17 Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Thu, 28 May 2026 15:42:24 +0300
Subject: [PATCH 6/6] [ci] Add workflows for Llama-3.2-1B-Instruct Q8_0
inference with prefill-decode and CUDA-graph variants
---
.github/workflows/build-and-run.yml | 101 ++++++++++++++++++++++++++++
1 file changed, 101 insertions(+)
diff --git a/.github/workflows/build-and-run.yml b/.github/workflows/build-and-run.yml
index a17c9228..46dc3bf4 100644
--- a/.github/workflows/build-and-run.yml
+++ b/.github/workflows/build-and-run.yml
@@ -385,6 +385,107 @@ jobs:
flags="" \
prompt="Say hello"
+ - name: Q8 - Run Llama-3.2-1B-Instruct-Q8_0.gguf - Prefill-Decode
+ env:
+ JAVA_TOOL_OPTIONS: >-
+ -Dllama.metrics.format=json
+ -Dllama.metrics.output=file
+ -Dllama.metrics.file=${{ runner.temp }}/metrics-${{ matrix.backend.name }}-llama-1b-q8-prefill-decode.json
+ run: |
+ cd ${{ github.workspace }}
+ export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH"
+ ./llama-tornado --gpu --${{ matrix.backend.name }} \
+ --model $MODELS_DIR/Llama-3.2-1B-Instruct-Q8_0.gguf \
+ --prompt "Say hello" \
+ --with-prefill-decode
+ python3 scripts/write_metrics_sidecar.py \
+ --out "${{ runner.temp }}/metrics-${{ matrix.backend.name }}-llama-1b-q8-prefill-decode.meta.json" \
+ backend="${{ matrix.backend.name }}" \
+ task=llama-inference \
+ model_file=Llama-3.2-1B-Instruct-Q8_0.gguf \
+ model=Llama-3.2-1B-Instruct \
+ quantization=Q8_0 \
+ configuration=prefill-decode \
+ "flags=--with-prefill-decode" \
+ prompt="Say hello"
+
+ - name: Q8 - Run Llama-3.2-1B-Instruct-Q8_0.gguf - Batch-Prefill-Decode
+ env:
+ JAVA_TOOL_OPTIONS: >-
+ -Dllama.metrics.format=json
+ -Dllama.metrics.output=file
+ -Dllama.metrics.file=${{ runner.temp }}/metrics-${{ matrix.backend.name }}-llama-1b-q8-batch-prefill-decode.json
+ run: |
+ cd ${{ github.workspace }}
+ export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH"
+ ./llama-tornado --gpu --${{ matrix.backend.name }} \
+ --model $MODELS_DIR/Llama-3.2-1B-Instruct-Q8_0.gguf \
+ --prompt "Say hello" \
+ --with-prefill-decode --batch-prefill-size 32
+ python3 scripts/write_metrics_sidecar.py \
+ --out "${{ runner.temp }}/metrics-${{ matrix.backend.name }}-llama-1b-q8-batch-prefill-decode.meta.json" \
+ backend="${{ matrix.backend.name }}" \
+ task=llama-inference \
+ model_file=Llama-3.2-1B-Instruct-Q8_0.gguf \
+ model=Llama-3.2-1B-Instruct \
+ quantization=Q8_0 \
+ configuration=batch-prefill-decode \
+ "flags=--with-prefill-decode --batch-prefill-size 32" \
+ prompt="Say hello"
+
+ # ── PTX-only: CUDA-graph variants ────────────────────────────────────────
+ - name: PTX - Q8 - Run Llama-3.2-1B-Instruct-Q8_0.gguf - Prefill-Decode-CUDA-Graphs
+ if: matrix.backend.name == 'ptx'
+ env:
+ JAVA_TOOL_OPTIONS: >-
+ -Dllama.metrics.format=json
+ -Dllama.metrics.output=file
+ -Dllama.metrics.file=${{ runner.temp }}/metrics-ptx-llama-1b-q8-prefill-decode-cuda-graphs.json
+ run: |
+ cd ${{ github.workspace }}
+ export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH"
+ ./llama-tornado --gpu --ptx \
+ --model $MODELS_DIR/Llama-3.2-1B-Instruct-Q8_0.gguf \
+ --prompt "Say hello" \
+ --with-prefill-decode \
+ --cuda-graphs
+ python3 scripts/write_metrics_sidecar.py \
+ --out "${{ runner.temp }}/metrics-ptx-llama-1b-q8-prefill-decode-cuda-graphs.meta.json" \
+ backend=ptx \
+ task=llama-inference \
+ model_file=Llama-3.2-1B-Instruct-Q8_0.gguf \
+ model=Llama-3.2-1B-Instruct \
+ quantization=Q8_0 \
+ configuration=prefill-decode-cuda-graphs \
+ "flags=--with-prefill-decode --cuda-graphs" \
+ prompt="Say hello"
+
+ - name: PTX - Q8 - Run Llama-3.2-1B-Instruct-Q8_0.gguf - Batch-Prefill-Decode-CUDA-Graphs
+ if: matrix.backend.name == 'ptx'
+ env:
+ JAVA_TOOL_OPTIONS: >-
+ -Dllama.metrics.format=json
+ -Dllama.metrics.output=file
+ -Dllama.metrics.file=${{ runner.temp }}/metrics-ptx-llama-1b-q8-batch-prefill-decode-cuda-graphs.json
+ run: |
+ cd ${{ github.workspace }}
+ export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH"
+ ./llama-tornado --gpu --ptx \
+ --model $MODELS_DIR/Llama-3.2-1B-Instruct-Q8_0.gguf \
+ --prompt "Say hello" \
+ --with-prefill-decode --batch-prefill-size 32 \
+ --cuda-graphs
+ python3 scripts/write_metrics_sidecar.py \
+ --out "${{ runner.temp }}/metrics-ptx-llama-1b-q8-batch-prefill-decode-cuda-graphs.meta.json" \
+ backend=ptx \
+ task=llama-inference \
+ model_file=Llama-3.2-1B-Instruct-Q8_0.gguf \
+ model=Llama-3.2-1B-Instruct \
+ quantization=Q8_0 \
+ configuration=batch-prefill-decode-cuda-graphs \
+ "flags=--with-prefill-decode --batch-prefill-size 32 --cuda-graphs" \
+ prompt="Say hello"
+
- name: Q8 - Run Qwen3-0.6B-Q8_0.gguf
env:
JAVA_TOOL_OPTIONS: >-