Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion llama-tornado
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ class LlamaRunner:
[
"-cp",
self._find_llama_jar(),
"org.beehive.gpullama3.LlamaApp",
args.main_class,
]
)
cmd.extend(module_config)
Expand Down Expand Up @@ -246,6 +246,9 @@ class LlamaRunner:
elif args.instruct:
llama_args.append("--instruct")

if args.tool_demo:
llama_args.append("--tool-demo")

Comment on lines +250 to +251
return cmd + llama_args

def run(self, args: argparse.Namespace) -> int:
Expand Down Expand Up @@ -527,6 +530,16 @@ def create_parser() -> argparse.ArgumentParser:

# Advanced options
advanced_group = parser.add_argument_group("Advanced Options")
advanced_group.add_argument(
"--main-class",
default="org.beehive.gpullama3.cli.LlamaTornadoCli",
help="Java main class to run (default: LlamaTornadoCli)",
)
advanced_group.add_argument(
"--tool-demo",
action="store_true",
help="Run the tool calling demo (requires a LLaMA 3.1 or Qwen3 model)",
)
advanced_group.add_argument(
"--opencl-flags",
default="-cl-denorms-are-zero -cl-no-signed-zeros -cl-finite-math-only",
Expand Down
15 changes: 12 additions & 3 deletions llamaTornado
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ record Config(
boolean printBytecodes, boolean threads, boolean printKernel,
boolean fullDump, boolean verboseInit,
boolean showCommand, boolean executeAfterShow,
String openclFlags, int maxWaitEvents, boolean verbose
String openclFlags, int maxWaitEvents, boolean verbose,
boolean toolDemo
) {}

Config parseArgs(String[] args) {
Expand Down Expand Up @@ -50,6 +51,7 @@ Config parseArgs(String[] args) {
String openclFlags = "-cl-denorms-are-zero -cl-no-signed-zeros -cl-finite-math-only";
int maxWaitEvents = 32000;
boolean verbose = false;
boolean toolDemo = false;

for (int i = 0; i < args.length; i++) {
switch (args[i]) {
Expand Down Expand Up @@ -84,6 +86,7 @@ Config parseArgs(String[] args) {
case "--opencl-flags" -> openclFlags = args[++i];
case "--max-wait-events" -> maxWaitEvents = Integer.parseInt(args[++i]);
case "--verbose", "-v" -> verbose = true;
case "--tool-demo" -> toolDemo = true;
default -> {
System.err.println("Unknown option: " + args[i]);
System.exit(1);
Expand All @@ -104,7 +107,8 @@ Config parseArgs(String[] args) {
return new Config(modelPath, prompt, systemPrompt, temperature, topP, seed, maxTokens,
stream, echo, interactive, instruct, useGpu, backend, gpuMemory, heapMin, heapMax,
debug, profiler, profilerDumpDir, printBytecodes, threads, printKernel, fullDump,
verboseInit, showCommand, executeAfterShow, openclFlags, maxWaitEvents, verbose);
verboseInit, showCommand, executeAfterShow, openclFlags, maxWaitEvents, verbose,
toolDemo);
}

void printUsage() {
Expand Down Expand Up @@ -151,6 +155,7 @@ void printUsage() {
--show-command Display the full Java command
--execute-after-show Execute after showing command
--verbose, -v Verbose output
--tool-demo Run the tool calling demo

-help Show this help
-version Show version
Expand Down Expand Up @@ -262,7 +267,10 @@ List<String> buildCommand(Config cfg, String javaHome, String tornadoSdk, String
}
}

cmd.addAll(List.of("-cp", findLlamaJar(llamaRoot), "org.beehive.gpullama3.LlamaApp"));
var mainClass = cfg.toolDemo()
? "org.beehive.gpullama3.ToolCallingDemo"
: "org.beehive.gpullama3.LlamaApp";
cmd.addAll(List.of("-cp", findLlamaJar(llamaRoot), mainClass));

// LLaMA arguments
cmd.addAll(List.of(
Expand All @@ -279,6 +287,7 @@ List<String> buildCommand(Config cfg, String javaHome, String tornadoSdk, String
if (cfg.systemPrompt() != null) cmd.addAll(List.of("-sp", cfg.systemPrompt()));
if (cfg.interactive()) cmd.add("--interactive");
else if (cfg.instruct()) cmd.add("--instruct");
// --tool-demo is handled by main class selection above, not passed to Java

return cmd;
}
Expand Down
102 changes: 102 additions & 0 deletions src/main/java/org/beehive/gpullama3/ToolCallingApp.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package org.beehive.gpullama3;

import org.beehive.gpullama3.inference.sampler.Sampler;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.format.ToolCallExtract;
import org.beehive.gpullama3.model.format.ToolCallParserUtils;
import org.beehive.gpullama3.tools.*;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

import static org.beehive.gpullama3.inference.sampler.Sampler.createSampler;
import static org.beehive.gpullama3.model.loader.ModelLoader.loadModel;

/**
* Standalone tool-calling entry point for GPULlama3.java.
*
* Uses the same command-line flags as {@link LlamaApp} plus the tool-calling loop
* provided by {@link ToolCallingSession}. A {@code listDirectory} tool is registered
* as a built-in demo; extend {@link #buildRegistry()} to add more tools.
*
* <pre>
* java @$TORNADOVM_HOME/tornado-argfile \
* --add-modules jdk.incubator.vector --enable-preview \
* -cp gpu-llama3.jar org.beehive.gpullama3.ToolCallingApp \
* --model /path/to/model.gguf \
* --prompt "Show me what is inside /tmp" \
* --use-tornadovm true
* </pre>
*/
public class ToolCallingApp {

public static void main(String[] args) throws IOException {
Options options = Options.parseOptions(args);
Model model = loadModel(options);
Sampler sampler = createSampler(model, options);

ToolRegistry registry = buildRegistry();
ToolCallingOptions tcOptions = ToolCallingOptions.from(options);
ToolCallingSession session = new ToolCallingSession(model, sampler, registry, tcOptions);

ToolCallingResult result = session.run(options.systemPrompt(), options.prompt());

if (!tcOptions.verbose()) {
// verbose=false means ToolCallingSession did not stream tokens — print the answer now
System.out.println(result.finalAnswer());
}
}

// ── Tool registry ─────────────────────────────────────────────────────────

private static ToolRegistry buildRegistry() {
ToolRegistry registry = new ToolRegistry();
registry.register(listDirectoryDefinition(), ToolCallingApp::listDirectory);
return registry;
}

private static ToolDefinition listDirectoryDefinition() {
return new ToolDefinition(
"listDirectory",
"Lists the contents of a directory on the local filesystem. " +
"Returns file names and metadata. " +
"Use this when the user asks what is inside a directory or folder.",
"""
{"type":"object","properties":{"path":{"type":"string",\
"description":"Absolute path of the directory to list, e.g. /tmp or /home/orion"}},\
"required":["path"]}""");
}

private static ToolResult listDirectory(ToolCallExtract call) {
String path = ToolCallParserUtils.extractStringValue(call.argumentsJson(), "path");
if (path == null || path.isBlank()) {
return ToolResult.failure("listDirectory",
"Could not extract 'path' from arguments: " + call.argumentsJson());
}
Path dir = Path.of(path);
if (!dir.isAbsolute())
return ToolResult.failure("listDirectory", "Path must be absolute. Got: " + path);
if (!Files.exists(dir))
return ToolResult.failure("listDirectory", "Path does not exist: " + path);
if (!Files.isDirectory(dir))
return ToolResult.failure("listDirectory", "Not a directory: " + path);
try {
StringBuilder sb = new StringBuilder("Contents of ").append(path).append(":\n");
try (var stream = Files.list(dir)) {
stream.sorted().forEach(entry -> {
boolean isDir = Files.isDirectory(entry);
long size = 0;
try { if (!isDir) size = Files.size(entry); } catch (IOException ignored) {}
sb.append(isDir ? "dir: " : "file: ")
.append(entry.getFileName())
.append(isDir ? "" : ", " + size + " bytes")
.append("\n");
});
}
return ToolResult.success("listDirectory", sb.toString());
} catch (IOException e) {
return ToolResult.failure("listDirectory", e.getMessage());
}
}
}
103 changes: 103 additions & 0 deletions src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.beehive.gpullama3.tokenizer.Qwen3Tokenizer;

import java.util.List;
import java.util.Optional;
import java.util.Set;

public interface ChatFormat {
Expand Down Expand Up @@ -36,6 +37,108 @@ default ChatTokens chatTokens() {

Set<Integer> getStopTokens();

/**
* Returns plain text to append to the system message content when tools are available.
* Used by formats that inject tool definitions into the <em>system</em> message.
*
* <p>Formats that inject tools into the <em>user</em> message instead should override
* {@link #injectsToolsInUserMessage()}, {@link #toolSystemMessagePrefix()}, and
* {@link #toolFirstUserMessagePrefix(String)} rather than this method.
*
* @param toolsJson JSON array of tool definitions
*/
default String toolSystemPromptSuffix(String toolsJson) {
throw new UnsupportedOperationException("Tool calling not supported for: " + getClass().getSimpleName());
}

/**
* Returns {@code true} when this format injects tool definitions into the
* <em>first user message</em> instead of the system message.
*
* <p>When this returns {@code true}, callers should:
* <ol>
* <li>Prepend {@link #toolSystemMessagePrefix()} to the system message content.</li>
* <li>Prepend {@link #toolFirstUserMessagePrefix(String)} to the first user message.</li>
* </ol>
* When {@code false} (default), callers should append {@link #toolSystemPromptSuffix} to
* the system message as before.
*/
default boolean injectsToolsInUserMessage() {
return false;
}

/**
* Returns text to <em>prepend</em> to the system message content when tools are active
* and {@link #injectsToolsInUserMessage()} is {@code true}.
* Default: empty string (no prefix).
*/
default String toolSystemMessagePrefix() {
return "";
}

/**
* Returns the preamble to <em>prepend</em> to the first user message when
* {@link #injectsToolsInUserMessage()} is {@code true}.
* The preamble should include the tool definitions and usage instructions.
*
* @param toolsJson JSON array of tool definitions
*/
default String toolFirstUserMessagePrefix(String toolsJson) {
return "";
}

/**
* Re-encodes a prior assistant tool-call turn into the conversation token stream.
* Used when replaying multi-turn history that contains a previous tool call.
*
* @param toolCall the tool call to encode (name + raw arguments JSON)
*/
default List<Integer> encodeToolCallAssistantTurn(ToolCallExtract toolCall) {
throw new UnsupportedOperationException("Tool calling not supported for: " + getClass().getSimpleName());
}

/**
* Encodes a tool execution result message in the model-native format.
*
* @param toolCallId the ID of the originating tool call (may be ignored by some formats)
* @param toolName the name of the tool that was called
* @param result the result content string
*/
default List<Integer> encodeToolResultTurn(String toolCallId, String toolName, String result) {
throw new UnsupportedOperationException("Tool calling not supported for: " + getClass().getSimpleName());
}

/**
* Detects and extracts a tool call from fully decoded model response text.
* Returns {@link Optional#empty()} when the response is a plain text answer.
*
* @param responseText the fully decoded response from the model
*/
default Optional<ToolCallExtract> extractToolCall(String responseText) {
return Optional.empty();
}

/**
* Extracts ALL tool calls from a response. Models may emit multiple
* {@code <tool_call>} blocks in a single turn (batch tool calls).
* The default delegates to {@link #extractToolCall} for formats that
* do not support batch calls.
*
* @param responseText the fully decoded response from the model
*/
default List<ToolCallExtract> extractAllToolCalls(String responseText) {
return extractToolCall(responseText).map(List::of).orElse(List.of());
}

/**
* Stop tokens to use when tool calling is enabled.
* Some models (LLaMA 3.1+) use a different end-of-turn token ({@code <|eom_id|>})
* when emitting a tool call instead of a regular response.
*/
default Set<Integer> getToolAwareStopTokens() {
return getStopTokens();
}

record ChatTokens(String tStartHeader, String tEndHeader, String tEndOfTurn, String tEndOfText, String tEndOfTextFim) {
}

Expand Down
Loading