diff --git a/llama-tornado b/llama-tornado index 1d6c3d23..61439dad 100755 --- a/llama-tornado +++ b/llama-tornado @@ -191,7 +191,7 @@ class LlamaRunner: [ "-cp", self._find_llama_jar(), - "org.beehive.gpullama3.LlamaApp", + args.main_class, ] ) cmd.extend(module_config) @@ -246,6 +246,9 @@ class LlamaRunner: elif args.instruct: llama_args.append("--instruct") + if args.tool_demo: + llama_args.append("--tool-demo") + return cmd + llama_args def run(self, args: argparse.Namespace) -> int: @@ -527,6 +530,16 @@ def create_parser() -> argparse.ArgumentParser: # Advanced options advanced_group = parser.add_argument_group("Advanced Options") + advanced_group.add_argument( + "--main-class", + default="org.beehive.gpullama3.cli.LlamaTornadoCli", + help="Java main class to run (default: LlamaTornadoCli)", + ) + advanced_group.add_argument( + "--tool-demo", + action="store_true", + help="Run the tool calling demo (requires a LLaMA 3.1 or Qwen3 model)", + ) advanced_group.add_argument( "--opencl-flags", default="-cl-denorms-are-zero -cl-no-signed-zeros -cl-finite-math-only", diff --git a/llamaTornado b/llamaTornado index 6cc303d5..6f5328f5 100755 --- a/llamaTornado +++ b/llamaTornado @@ -17,7 +17,8 @@ record Config( boolean printBytecodes, boolean threads, boolean printKernel, boolean fullDump, boolean verboseInit, boolean showCommand, boolean executeAfterShow, - String openclFlags, int maxWaitEvents, boolean verbose + String openclFlags, int maxWaitEvents, boolean verbose, + boolean toolDemo ) {} Config parseArgs(String[] args) { @@ -50,6 +51,7 @@ Config parseArgs(String[] args) { String openclFlags = "-cl-denorms-are-zero -cl-no-signed-zeros -cl-finite-math-only"; int maxWaitEvents = 32000; boolean verbose = false; + boolean toolDemo = false; for (int i = 0; i < args.length; i++) { switch (args[i]) { @@ -84,6 +86,7 @@ Config parseArgs(String[] args) { case "--opencl-flags" -> openclFlags = args[++i]; case "--max-wait-events" -> maxWaitEvents = Integer.parseInt(args[++i]); 
case "--verbose", "-v" -> verbose = true; + case "--tool-demo" -> toolDemo = true; default -> { System.err.println("Unknown option: " + args[i]); System.exit(1); @@ -104,7 +107,8 @@ Config parseArgs(String[] args) { return new Config(modelPath, prompt, systemPrompt, temperature, topP, seed, maxTokens, stream, echo, interactive, instruct, useGpu, backend, gpuMemory, heapMin, heapMax, debug, profiler, profilerDumpDir, printBytecodes, threads, printKernel, fullDump, - verboseInit, showCommand, executeAfterShow, openclFlags, maxWaitEvents, verbose); + verboseInit, showCommand, executeAfterShow, openclFlags, maxWaitEvents, verbose, + toolDemo); } void printUsage() { @@ -151,6 +155,7 @@ void printUsage() { --show-command Display the full Java command --execute-after-show Execute after showing command --verbose, -v Verbose output + --tool-demo Run the tool calling demo -help Show this help -version Show version @@ -262,7 +267,10 @@ List buildCommand(Config cfg, String javaHome, String tornadoSdk, String } } - cmd.addAll(List.of("-cp", findLlamaJar(llamaRoot), "org.beehive.gpullama3.LlamaApp")); + var mainClass = cfg.toolDemo() + ? 
"org.beehive.gpullama3.ToolCallingApp" + : "org.beehive.gpullama3.LlamaApp"; + cmd.addAll(List.of("-cp", findLlamaJar(llamaRoot), mainClass)); // LLaMA arguments cmd.addAll(List.of( @@ -279,6 +287,7 @@ List buildCommand(Config cfg, String javaHome, String tornadoSdk, String if (cfg.systemPrompt() != null) cmd.addAll(List.of("-sp", cfg.systemPrompt())); if (cfg.interactive()) cmd.add("--interactive"); else if (cfg.instruct()) cmd.add("--instruct"); + // --tool-demo is handled by main class selection above, not passed to Java return cmd; } diff --git a/src/main/java/org/beehive/gpullama3/ToolCallingApp.java b/src/main/java/org/beehive/gpullama3/ToolCallingApp.java new file mode 100644 index 00000000..7729d695 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/ToolCallingApp.java @@ -0,0 +1,102 @@ +package org.beehive.gpullama3; + +import org.beehive.gpullama3.inference.sampler.Sampler; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.format.ToolCallExtract; +import org.beehive.gpullama3.model.format.ToolCallParserUtils; +import org.beehive.gpullama3.tools.*; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; + +import static org.beehive.gpullama3.inference.sampler.Sampler.createSampler; +import static org.beehive.gpullama3.model.loader.ModelLoader.loadModel; + +/** + * Standalone tool-calling entry point for GPULlama3.java. + * + * Uses the same command-line flags as {@link LlamaApp} plus the tool-calling loop + * provided by {@link ToolCallingSession}. A {@code listDirectory} tool is registered + * as a built-in demo; extend {@link #buildRegistry()} to add more tools. + * + *
+ * java @$TORNADOVM_HOME/tornado-argfile \
+ *     --add-modules jdk.incubator.vector --enable-preview \
+ *     -cp gpu-llama3.jar org.beehive.gpullama3.ToolCallingApp \
+ *     --model /path/to/model.gguf \
+ *     --prompt "Show me what is inside /tmp" \
+ *     --use-tornadovm true
+ * 
+ */ +public class ToolCallingApp { + + public static void main(String[] args) throws IOException { + Options options = Options.parseOptions(args); + Model model = loadModel(options); + Sampler sampler = createSampler(model, options); + + ToolRegistry registry = buildRegistry(); + ToolCallingOptions tcOptions = ToolCallingOptions.from(options); + ToolCallingSession session = new ToolCallingSession(model, sampler, registry, tcOptions); + + ToolCallingResult result = session.run(options.systemPrompt(), options.prompt()); + + if (!tcOptions.verbose()) { + // verbose=false means ToolCallingSession did not stream tokens — print the answer now + System.out.println(result.finalAnswer()); + } + } + + // ── Tool registry ───────────────────────────────────────────────────────── + + private static ToolRegistry buildRegistry() { + ToolRegistry registry = new ToolRegistry(); + registry.register(listDirectoryDefinition(), ToolCallingApp::listDirectory); + return registry; + } + + private static ToolDefinition listDirectoryDefinition() { + return new ToolDefinition( + "listDirectory", + "Lists the contents of a directory on the local filesystem. " + + "Returns file names and metadata. " + + "Use this when the user asks what is inside a directory or folder.", + """ + {"type":"object","properties":{"path":{"type":"string",\ + "description":"Absolute path of the directory to list, e.g. /tmp or /home/orion"}},\ + "required":["path"]}"""); + } + + private static ToolResult listDirectory(ToolCallExtract call) { + String path = ToolCallParserUtils.extractStringValue(call.argumentsJson(), "path"); + if (path == null || path.isBlank()) { + return ToolResult.failure("listDirectory", + "Could not extract 'path' from arguments: " + call.argumentsJson()); + } + Path dir = Path.of(path); + if (!dir.isAbsolute()) + return ToolResult.failure("listDirectory", "Path must be absolute. 
Got: " + path); + if (!Files.exists(dir)) + return ToolResult.failure("listDirectory", "Path does not exist: " + path); + if (!Files.isDirectory(dir)) + return ToolResult.failure("listDirectory", "Not a directory: " + path); + try { + StringBuilder sb = new StringBuilder("Contents of ").append(path).append(":\n"); + try (var stream = Files.list(dir)) { + stream.sorted().forEach(entry -> { + boolean isDir = Files.isDirectory(entry); + long size = 0; + try { if (!isDir) size = Files.size(entry); } catch (IOException ignored) {} + sb.append(isDir ? "dir: " : "file: ") + .append(entry.getFileName()) + .append(isDir ? "" : ", " + size + " bytes") + .append("\n"); + }); + } + return ToolResult.success("listDirectory", sb.toString()); + } catch (IOException e) { + return ToolResult.failure("listDirectory", e.getMessage()); + } + } +} diff --git a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java index 827ad625..c1ceb267 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java @@ -8,6 +8,7 @@ import org.beehive.gpullama3.tokenizer.Qwen3Tokenizer; import java.util.List; +import java.util.Optional; import java.util.Set; public interface ChatFormat { @@ -36,6 +37,108 @@ default ChatTokens chatTokens() { Set getStopTokens(); + /** + * Returns plain text to append to the system message content when tools are available. + * Used by formats that inject tool definitions into the system message. + * + *

Formats that inject tools into the user message instead should override + * {@link #injectsToolsInUserMessage()}, {@link #toolSystemMessagePrefix()}, and + * {@link #toolFirstUserMessagePrefix(String)} rather than this method. + * + * @param toolsJson JSON array of tool definitions + */ + default String toolSystemPromptSuffix(String toolsJson) { + throw new UnsupportedOperationException("Tool calling not supported for: " + getClass().getSimpleName()); + } + + /** + * Returns {@code true} when this format injects tool definitions into the + * first user message instead of the system message. + * + *

When this returns {@code true}, callers should: + *

    + *
  <li>Prepend {@link #toolSystemMessagePrefix()} to the system message content.</li> + *
  <li>Prepend {@link #toolFirstUserMessagePrefix(String)} to the first user message.</li> + *
+ * When {@code false} (default), callers should append {@link #toolSystemPromptSuffix} to + * the system message as before. + */ + default boolean injectsToolsInUserMessage() { + return false; + } + + /** + * Returns text to prepend to the system message content when tools are active + * and {@link #injectsToolsInUserMessage()} is {@code true}. + * Default: empty string (no prefix). + */ + default String toolSystemMessagePrefix() { + return ""; + } + + /** + * Returns the preamble to prepend to the first user message when + * {@link #injectsToolsInUserMessage()} is {@code true}. + * The preamble should include the tool definitions and usage instructions. + * + * @param toolsJson JSON array of tool definitions + */ + default String toolFirstUserMessagePrefix(String toolsJson) { + return ""; + } + + /** + * Re-encodes a prior assistant tool-call turn into the conversation token stream. + * Used when replaying multi-turn history that contains a previous tool call. + * + * @param toolCall the tool call to encode (name + raw arguments JSON) + */ + default List encodeToolCallAssistantTurn(ToolCallExtract toolCall) { + throw new UnsupportedOperationException("Tool calling not supported for: " + getClass().getSimpleName()); + } + + /** + * Encodes a tool execution result message in the model-native format. + * + * @param toolCallId the ID of the originating tool call (may be ignored by some formats) + * @param toolName the name of the tool that was called + * @param result the result content string + */ + default List encodeToolResultTurn(String toolCallId, String toolName, String result) { + throw new UnsupportedOperationException("Tool calling not supported for: " + getClass().getSimpleName()); + } + + /** + * Detects and extracts a tool call from fully decoded model response text. + * Returns {@link Optional#empty()} when the response is a plain text answer. 
+ * + * @param responseText the fully decoded response from the model + */ + default Optional extractToolCall(String responseText) { + return Optional.empty(); + } + + /** + * Extracts ALL tool calls from a response. Models may emit multiple + * {@code } blocks in a single turn (batch tool calls). + * The default delegates to {@link #extractToolCall} for formats that + * do not support batch calls. + * + * @param responseText the fully decoded response from the model + */ + default List extractAllToolCalls(String responseText) { + return extractToolCall(responseText).map(List::of).orElse(List.of()); + } + + /** + * Stop tokens to use when tool calling is enabled. + * Some models (LLaMA 3.1+) use a different end-of-turn token ({@code <|eom_id|>}) + * when emitting a tool call instead of a regular response. + */ + default Set getToolAwareStopTokens() { + return getStopTokens(); + } + record ChatTokens(String tStartHeader, String tEndHeader, String tEndOfTurn, String tEndOfText, String tEndOfTextFim) { } diff --git a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java index c98a72c9..421a5513 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java @@ -6,6 +6,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; public class LlamaChatFormat implements ChatFormat { @@ -17,6 +18,7 @@ public class LlamaChatFormat implements ChatFormat { protected final int endOfTurn; protected final int endOfText; protected final int endOfMessage; + protected final int pythonTag; protected final Set stopTokens; public LlamaChatFormat(Tokenizer tokenizer) { @@ -28,6 +30,7 @@ public LlamaChatFormat(Tokenizer tokenizer) { this.endOfTurn = specialTokens.get("<|eot_id|>"); this.endOfText = specialTokens.get("<|end_of_text|>"); 
this.endOfMessage = specialTokens.getOrDefault("<|eom_id|>", -1); // only in 3.1 + this.pythonTag = specialTokens.getOrDefault("<|python_tag|>", -1); // only in 3.1 this.stopTokens = Set.of(endOfText, endOfTurn); } @@ -71,4 +74,115 @@ public List encodeDialogPrompt(boolean appendAssistantTurn, Listfirst user message + * (the GGUF-embedded chat template has {@code tools_in_user_message = true} by default). + * The system message receives only an environment prefix; the tools and usage instructions + * go in the user turn. + */ + @Override + public boolean injectsToolsInUserMessage() { + return true; + } + + /** + * System-message prefix that signals tool availability to Llama 3.2. + * Matches the template's {@code "Environment: ipython\n"} line. + */ + @Override + public String toolSystemMessagePrefix() { + return "Environment: ipython\n\n"; + } + + /** + * Prepends tool definitions and usage instructions to the first user message, + * matching the Llama 3.2 GGUF chat template ({@code tools_in_user_message = true}). + * + *

Format mirrors: + *

+     * Given the following functions, please respond with a JSON for a function call
+     * with its proper arguments that best answers the given prompt.
+     *
+     * Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.
+     * Do not use variables.
+     *
+     * {toolsJson}
+     *
+     * 
+ */ + @Override + public String toolFirstUserMessagePrefix(String toolsJson) { + return "Given the following functions, please respond with a JSON for a function call " + + "with its proper arguments that best answers the given prompt.\n\n" + + "Respond in the format {\"name\": function name, \"parameters\": dictionary of " + + "argument name and its value}. Do not use variables.\n\n" + + toolsJson + "\n\n"; + } + + /** + * Re-encodes a prior assistant tool-call turn for multi-turn history using the + * Llama 3.2 native JSON format: {@code {"name":"…","parameters":{…}}<|eot_id|>}. + */ + @Override + public List encodeToolCallAssistantTurn(ToolCallExtract toolCall) { + List tokens = new ArrayList<>(encodeHeader(new Message(Role.ASSISTANT, ""))); + // Preserve the <|python_tag|> prefix used by LLaMA 3.1/3.2 for tool calls so that + // replayed history looks identical to what the model originally generated. + if (pythonTag != -1) { + tokens.add(pythonTag); + } + String json = "{\"name\": \"" + toolCall.name() + "\", \"parameters\": " + toolCall.argumentsJson() + "}"; + tokens.addAll(tokenizer.encodeAsList(json)); + // LLaMA 3.1 ends tool-call turns with <|eom_id|>; fall back to <|eot_id|> for 3.2. + tokens.add(endOfMessage != -1 ? endOfMessage : endOfTurn); + return tokens; + } + + /** + * Encodes a tool result using the LLaMA "ipython" role. + * Format: {@code <|start_header_id|>ipython<|end_header_id|>\nresult<|eot_id|>} + */ + @Override + public List encodeToolResultTurn(String toolCallId, String toolName, String result) { + List tokens = new ArrayList<>(); + tokens.add(startHeader); + tokens.addAll(tokenizer.encodeAsList("ipython")); + tokens.add(endHeader); + tokens.addAll(tokenizer.encodeAsList("\n")); + tokens.addAll(tokenizer.encodeAsList(result)); + tokens.add(endOfTurn); + return tokens; + } + + /** + * Detects a tool call in the decoded response text. 
+ * Supports LLaMA 3.1 (native {@code <|python_tag|>} + {@code "parameters"} key), + * LLaMA 3.2 ({@code "arguments"} key, tag often absent), and a raw-JSON fallback + * for smaller models. Delegates to {@link ToolCallParserUtils#parseLlamaResponse}. + */ + @Override + public Optional extractToolCall(String responseText) { + return ToolCallParserUtils.parseLlamaResponse(responseText); + } + + @Override + public List extractAllToolCalls(String responseText) { + return ToolCallParserUtils.parseAllToolCalls(responseText); + } + + /** + * Adds {@code <|eom_id|>} to the stop tokens when tools are enabled. + * LLaMA 3.1 ends tool-call turns with {@code <|eom_id|>} instead of {@code <|eot_id|>}. + */ + @Override + public Set getToolAwareStopTokens() { + if (endOfMessage != -1) { + return Set.of(endOfText, endOfTurn, endOfMessage); + } + return stopTokens; + } + } \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java index b6d2e798..35930ff1 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java @@ -129,4 +129,73 @@ public Set getStopTokens() { return stopTokens; } + + // ── Tool calling ────────────────────────────────────────────────────────── + + /** + * Qwen3 tool calling system prompt suffix. + * Appended to the system message; instructs the model to wrap tool calls in + * {@code } XML tags. 
+ */ + @Override + public String toolSystemPromptSuffix(String toolsJson) { + return "\n\n# Tools\n\n" + + "You may call one or more functions to assist with the user query.\n\n" + + "You are provided with function signatures within XML tags:\n" + + "\n" + + toolsJson + + "\n\n\n" + + "For each function call, return a json object with function name and arguments " + + "within XML tags:\n" + + "\n" + + "{\"name\": , \"arguments\": }\n" + + ""; + } + + /** + * Re-encodes a prior assistant tool-call turn for multi-turn history. + * Format: {@code <|im_start|>assistant\n\nJSON\n<|im_end|>} + */ + @Override + public List encodeToolCallAssistantTurn(ToolCallExtract toolCall) { + List tokens = new ArrayList<>(); + tokens.add(imStart); + tokens.addAll(tokenizer.encodeOrdinaryAsList("assistant\n")); + String json = "{\"name\":\"" + toolCall.name() + "\",\"arguments\":" + toolCall.argumentsJson() + "}"; + tokens.addAll(tokenizer.encodeOrdinaryAsList("\n" + json + "\n")); + if (imEnd != -1) { + tokens.add(imEnd); + } + return tokens; + } + + /** + * Encodes a tool result using the Qwen3 "tool" role. + * Format: {@code <|im_start|>tool\nresult<|im_end|>} + */ + @Override + public List encodeToolResultTurn(String toolCallId, String toolName, String result) { + List tokens = new ArrayList<>(); + tokens.add(imStart); + tokens.addAll(tokenizer.encodeOrdinaryAsList("tool\n")); + tokens.addAll(tokenizer.encodeOrdinaryAsList(result)); + if (imEnd != -1) { + tokens.add(imEnd); + } + return tokens; + } + + /** + * Detects a tool call enclosed in {@code } tags. + * Delegates to {@link ToolCallParserUtils#parseQwen3Response}. 
+ */ + @Override + public Optional extractToolCall(String responseText) { + return ToolCallParserUtils.parseQwen3Response(responseText); + } + + @Override + public List extractAllToolCalls(String responseText) { + return ToolCallParserUtils.parseAllToolCalls(responseText); + } } diff --git a/src/main/java/org/beehive/gpullama3/model/format/ToolCallExtract.java b/src/main/java/org/beehive/gpullama3/model/format/ToolCallExtract.java new file mode 100644 index 00000000..335c2176 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/format/ToolCallExtract.java @@ -0,0 +1,11 @@ +package org.beehive.gpullama3.model.format; + +/** + * Represents a single tool call extracted from a model response. + * Contains the raw strings — JSON parsing of arguments is left to the caller. + * + * @param name the tool/function name to invoke + * @param argumentsJson the arguments as a JSON object string, e.g. {"location":"Boston"} + */ +public record ToolCallExtract(String name, String argumentsJson) { +} diff --git a/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java b/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java new file mode 100644 index 00000000..146148fc --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java @@ -0,0 +1,215 @@ +package org.beehive.gpullama3.model.format; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +/** + * Pure-string tool-call extraction for Llama and Qwen3 response formats. + * + * All methods are stateless and do not require any model or tokenizer instance, + * making them directly unit-testable. + */ +public final class ToolCallParserUtils { + + private ToolCallParserUtils() {} + + // ── Llama ───────────────────────────────────────────────────────────────── + + /** + * Extracts a tool call from a LLaMA 3.1 or 3.2 model response. + * + * Recognised formats: + * 1. 
{@code <|python_tag|>{"name":…,"parameters":{…}}} — LLaMA 3.1 native, also accepted by 3.2 + * 2. Raw JSON with {@code "arguments"} key instead of {@code "parameters"} — LLaMA 3.2 instruction format + * 3. Raw JSON object optionally inside markdown code fences — fallback for models that + * follow system-prompt instructions but omit the special-token prefix + * + * Both {@code "parameters"} and {@code "arguments"} are tried so a single implementation + * handles the 3.1 and 3.2 variants transparently. + */ + public static Optional parseLlamaResponse(String responseText) { + // 1. Native LLaMA 3.1 format: <|python_tag|>{...} + int idx = responseText.indexOf("<|python_tag|>"); + if (idx != -1) { + String json = responseText.substring(idx + "<|python_tag|>".length()).strip(); + return parseLlamaJson(json); + } + + // 2. LLaMA 3.2 format: ... + int tcStart = responseText.indexOf(""); + int tcEnd = responseText.lastIndexOf(""); + if (tcStart != -1 && tcEnd != -1 && tcEnd > tcStart) { + String json = responseText.substring(tcStart + "".length(), tcEnd).strip(); + return parseLlamaJson(json); + } + // 2b. Unclosed — model stopped (eot_id / eom_id) before writing the closing tag + if (tcStart != -1 && tcEnd == -1) { + String json = responseText.substring(tcStart + "".length()).strip(); + return parseLlamaJson(json); + } + + // 3. Fallback: raw JSON, possibly inside markdown code fences + String stripped = stripMarkdownFences(responseText.strip()); + if (stripped.startsWith("{")) { + return parseLlamaJson(stripped); + } + + return Optional.empty(); + } + + /** + * Parses a LLaMA-style tool call JSON object. + * Accepts {@code {"name":…,"parameters":{…}}}, {@code {"function":…,"parameters":{…}}}, + * and {@code {"name":…,"arguments":{…}}}. 
+ */ + private static Optional parseLlamaJson(String json) { + String name = extractStringValue(json, "name"); + if (name == null) { + name = extractStringValue(json, "function"); + } + if (name == null) return Optional.empty(); + + String argsJson = extractNestedObject(json, "parameters"); + if (argsJson == null) argsJson = extractNestedObject(json, "arguments"); + if (argsJson == null) argsJson = "{}"; + + return Optional.of(new ToolCallExtract(name, argsJson)); + } + + // ── Qwen3 ───────────────────────────────────────────────────────────────── + + /** + * Extracts ALL tool calls from a response that may contain multiple + * {@code } blocks (Llama 3.2 and Qwen3 batch calls). + * + * Falls back to the raw-JSON single-call path if no tags are found. + * Returns an empty list when the response contains no tool calls. + */ + public static List parseAllToolCalls(String responseText) { + List calls = new java.util.ArrayList<>(); + + // <|python_tag|> (Llama 3.1) — single call by definition + int pythonIdx = responseText.indexOf("<|python_tag|>"); + if (pythonIdx != -1) { + parseLlamaJson(responseText.substring(pythonIdx + "<|python_tag|>".length()).strip()) + .ifPresent(calls::add); + return calls; + } + + // Scan for all blocks + int searchFrom = 0; + while (true) { + int start = responseText.indexOf("", searchFrom); + if (start == -1) break; + int end = responseText.indexOf("", start); + String json; + if (end != -1) { + json = responseText.substring(start + "".length(), end).strip(); + searchFrom = end + "".length(); + } else { + // Unclosed tag — model stopped before writing the closing tag + json = responseText.substring(start + "".length()).strip(); + searchFrom = responseText.length(); + } + parseLlamaJson(json).ifPresent(calls::add); + if (end == -1) break; + } + + // Raw JSON fallback (no tags at all) + if (calls.isEmpty()) { + String stripped = stripMarkdownFences(responseText.strip()); + if (stripped.startsWith("{")) { + 
parseLlamaJson(stripped).ifPresent(calls::add); + } + } + + return calls; + } + + /** + * Extracts a tool call enclosed in {@code } tags + * as produced by Qwen3 models. + */ + public static Optional parseQwen3Response(String responseText) { + int start = responseText.indexOf(""); + int end = responseText.lastIndexOf(""); + if (start == -1) return Optional.empty(); + + String json = (end != -1 && end > start) + ? responseText.substring(start + "".length(), end).strip() + : responseText.substring(start + "".length()).strip(); + + String name = extractStringValue(json, "name"); + if (name == null) return Optional.empty(); + + String argsJson = extractNestedObject(json, "arguments"); + if (argsJson == null) argsJson = "{}"; + + return Optional.of(new ToolCallExtract(name, argsJson)); + } + + // ── Shared helpers ──────────────────────────────────────────────────────── + + /** Strips surrounding markdown code fences (```…```) if present. */ + public static String stripMarkdownFences(String text) { + if (!text.startsWith("```")) return text; + int firstNewline = text.indexOf('\n'); + if (firstNewline == -1) return text; + String body = text.substring(firstNewline + 1); + if (body.endsWith("```")) body = body.substring(0, body.length() - 3).stripTrailing(); + return body.strip(); + } + + /** + * Extracts the string value for {@code "key": ""} from a JSON object. + * Tolerates whitespace around {@code :} and correctly skips escaped quotes ({@code \"}) + * inside the value, so multi-line code strings with embedded {@code "} are returned intact. 
+ */ + public static String extractStringValue(String json, String key) { + String marker = "\"" + key + "\""; + int markerIdx = json.indexOf(marker); + if (markerIdx == -1) return null; + int colonIdx = json.indexOf(':', markerIdx + marker.length()); + if (colonIdx == -1) return null; + int quoteStart = json.indexOf('"', colonIdx + 1); + if (quoteStart == -1) return null; + // Scan for the closing quote, honouring backslash escapes + int i = quoteStart + 1; + while (i < json.length()) { + char c = json.charAt(i); + if (c == '\\') { + i += 2; // skip escape sequence (e.g. \", \\, \n) + } else if (c == '"') { + break; + } else { + i++; + } + } + if (i >= json.length()) return null; + return json.substring(quoteStart + 1, i); + } + + /** + * Extracts the JSON object value for {@code "key": {…}} using brace-counting. + * Handles nested objects and tolerates whitespace around {@code :}. + */ + public static String extractNestedObject(String json, String key) { + String marker = "\"" + key + "\""; + int markerIdx = json.indexOf(marker); + if (markerIdx == -1) return null; + int colonIdx = json.indexOf(':', markerIdx + marker.length()); + if (colonIdx == -1) return null; + int braceStart = json.indexOf('{', colonIdx + 1); + if (braceStart == -1) return null; + int depth = 0; + for (int i = braceStart; i < json.length(); i++) { + char c = json.charAt(i); + if (c == '{') depth++; + else if (c == '}') { + if (--depth == 0) return json.substring(braceStart, i + 1); + } + } + return null; // unbalanced + } +} diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolCallingOptions.java b/src/main/java/org/beehive/gpullama3/tools/ToolCallingOptions.java new file mode 100644 index 00000000..50ad5c9d --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tools/ToolCallingOptions.java @@ -0,0 +1,31 @@ +package org.beehive.gpullama3.tools; + +/** + * Tuning parameters for a {@link ToolCallingSession}. 
+ * + * @param maxTokens max tokens per inference call + * @param maxRoundTrips max tool → result → re-inference cycles (default 1) + * @param maxToolResultChars tool output is truncated to this length before feeding back + * @param verbose print step-by-step output to stdout + * @param useGpu use TornadoVM GPU path for inference + */ +public record ToolCallingOptions( + int maxTokens, + int maxRoundTrips, + int maxToolResultChars, + boolean verbose, + boolean useGpu) { + + public static ToolCallingOptions defaults() { + return new ToolCallingOptions(1024, 1, 2000, true, false); + } + + public static ToolCallingOptions from(org.beehive.gpullama3.Options options) { + return new ToolCallingOptions( + options.maxTokens(), + 1, + 2000, + true, + options.useTornadovm()); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolCallingResult.java b/src/main/java/org/beehive/gpullama3/tools/ToolCallingResult.java new file mode 100644 index 00000000..9d1c539f --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tools/ToolCallingResult.java @@ -0,0 +1,29 @@ +package org.beehive.gpullama3.tools; + +import org.beehive.gpullama3.model.format.ToolCallExtract; + +import java.util.List; + +/** + * The outcome of a complete tool-calling session (prompt → [tool round-trips] → answer). + * + * @param finalAnswer the model's final plain-text answer + * @param callsMade tool calls that were extracted and executed (may be empty) + * @param results corresponding tool results (same order as callsMade) + * @param reachedMaxRoundTrips true when the session stopped because maxRoundTrips was hit + */ +public record ToolCallingResult( + String finalAnswer, + List callsMade, + List results, + boolean reachedMaxRoundTrips) { + + public boolean hadToolCalls() { + return !callsMade.isEmpty(); + } + + /** Returns a result representing a plain-text (no-tool) response. 
+     */
+    public static ToolCallingResult plainText(String answer) {
+        return new ToolCallingResult(answer, List.of(), List.of(), false);
+    }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolCallingSession.java b/src/main/java/org/beehive/gpullama3/tools/ToolCallingSession.java
new file mode 100644
index 00000000..565294db
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tools/ToolCallingSession.java
@@ -0,0 +1,208 @@
+package org.beehive.gpullama3.tools;
+
+import org.beehive.gpullama3.inference.sampler.Sampler;
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.format.ChatFormat;
+import org.beehive.gpullama3.model.format.ToolCallExtract;
+import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+import java.util.function.IntConsumer;
+
+/**
+ * Framework-agnostic orchestrator for the tool-calling loop:
+ * <pre>
+ *   prompt → first generation → extract tool call? → execute → feed result → final answer
+ * </pre>
+ *
+ * Supports Llama 3.1, Llama 3.2, and Qwen3 via the {@link ChatFormat} abstraction.
+ * The session is single-use; create a new instance per request.
+ */
+public class ToolCallingSession {
+
+    private final Model model;
+    private final Sampler sampler;
+    private final ToolRegistry registry;
+    private final ToolCallingOptions options;
+
+    public ToolCallingSession(Model model, Sampler sampler, ToolRegistry registry, ToolCallingOptions options) {
+        this.model = model;
+        this.sampler = sampler;
+        this.registry = registry;
+        this.options = options;
+    }
+
+    /** Run with no custom system prompt (the tool definitions become the system message). */
+    public ToolCallingResult run(String userPrompt) {
+        return run(null, userPrompt);
+    }
+
+    /**
+     * Run with an optional system prompt prefix. Tool definitions are appended to it.
+     *
+     * @param systemPrompt base system prompt, or {@code null} for tools-only
+     * @param userPrompt the user's request
+     */
+    public ToolCallingResult run(String systemPrompt, String userPrompt) {
+        ChatFormat chatFormat = model.chatFormat();
+        String toolsJson = registry.toToolsJson();
+
+        // Build effective system and user messages according to the model's tool injection strategy.
+        // Formats like Llama 3.2 inject tool definitions into the first user message
+        // (injectsToolsInUserMessage() == true); others (Qwen3, Mistral) append to the system message.
+        String effectiveSystem;
+        String effectiveUser;
+        if (chatFormat.injectsToolsInUserMessage()) {
+            String prefix = chatFormat.toolSystemMessagePrefix();
+            effectiveSystem = systemPrompt == null
+                    ? prefix.stripLeading()
+                    : prefix + systemPrompt;
+            effectiveUser = chatFormat.toolFirstUserMessagePrefix(toolsJson) + userPrompt;
+        } else {
+            String toolSuffix = chatFormat.toolSystemPromptSuffix(toolsJson);
+            effectiveSystem = systemPrompt == null
+                    ? toolSuffix.stripLeading()
+                    : systemPrompt + toolSuffix;
+            effectiveUser = userPrompt;
+        }
+
+        log("\n[DEBUG] model: %s", model.getClass().getSimpleName());
+        log("[DEBUG] chatFormat: %s", chatFormat.getClass().getSimpleName());
+        log("[DEBUG] shouldAddSystemPrompt: %s", model.shouldAddSystemPrompt());
+        log("[DEBUG] toolAwareStopTokens: %s", chatFormat.getToolAwareStopTokens());
+        log("[DEBUG] ── effective system prompt ──────────────────────────");
+        log("%s", effectiveSystem);
+        log("[DEBUG] ── end system prompt ─────────────────────────────────");
+
+        // ── Build initial prompt tokens ───────────────────────────────────────
+        List<Integer> promptTokens = new ArrayList<>();
+        if (model.shouldAddBeginOfText()) {
+            promptTokens.add(chatFormat.getBeginOfText());
+        }
+        if (model.shouldAddSystemPrompt() && !effectiveSystem.isBlank()) {
+            promptTokens.addAll(chatFormat.encodeMessage(
+                    new ChatFormat.Message(ChatFormat.Role.SYSTEM, effectiveSystem)));
+        }
+        promptTokens.addAll(chatFormat.encodeMessage(
+                new ChatFormat.Message(ChatFormat.Role.USER, effectiveUser)));
+        promptTokens.addAll(chatFormat.encodeHeader(
+                new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
+
+        Set<Integer> toolStopTokens = chatFormat.getToolAwareStopTokens();
+
+        List<ToolCallExtract> callsMade = new ArrayList<>();
+        List<ToolResult> toolResults = new ArrayList<>();
+
+        State state = model.createNewState();
+        TornadoVMMasterPlan plan = options.useGpu()
+                ? TornadoVMMasterPlan.initializeTornadoVMPlan(state, model)
+                : null;
+
+        try {
+            // ── Tool round-trip loop ──────────────────────────────────────────────
+            for (int round = 0; round < options.maxRoundTrips(); round++) {
+                log("\n--- First generation (round %d) ---", round + 1);
+                List<Integer> responseTokens = generateTokens(state, plan, promptTokens, toolStopTokens);
+
+                // Capture whether a stop token terminated the generation BEFORE
+                // stripping it; checking again after removal would (almost) always
+                // report "no" and make the debug output misleading.
+                boolean stopTokenStripped = !responseTokens.isEmpty() && toolStopTokens.contains(responseTokens.getLast());
+                if (stopTokenStripped) {
+                    responseTokens.removeLast();
+                }
+                String rawResponse = model.tokenizer().decode(responseTokens);
+
+                log("[DEBUG] raw response (round %d): >>>%s<<<", round + 1, rawResponse);
+                log("[DEBUG] response tokens: %d (stop token stripped: %s)",
+                        responseTokens.size(), stopTokenStripped ? "yes" : "no");
+
+                List<ToolCallExtract> batchCalls = chatFormat.extractAllToolCalls(rawResponse);
+                log("[DEBUG] extractAllToolCalls found: %d call(s)", batchCalls.size());
+                if (batchCalls.isEmpty()) {
+                    log("\n--- No tool call detected; returning plain text response ---");
+                    return new ToolCallingResult(rawResponse, callsMade, toolResults, false);
+                }
+
+                // ── Execute all tool calls found in this response ─────────────────
+                promptTokens = new ArrayList<>(promptTokens);
+                for (ToolCallExtract call : batchCalls) {
+                    callsMade.add(call);
+                    log("\n[Tool call] %s(%s)", call.name(), call.argumentsJson());
+
+                    ToolResult result = registry.execute(call);
+                    toolResults.add(result);
+                    log("[Tool result] %s", result.isError() ? "ERROR: " + result.error() : truncate(result.resultText()));
+
+                    String feedbackContent = result.isError()
+                            ? "Tool '" + call.name() + "' failed: " + result.error()
+                            : "Tool '" + call.name() + "' returned:\n" + truncate(result.resultText());
+
+                    promptTokens.addAll(chatFormat.encodeToolCallAssistantTurn(call));
+                    promptTokens.addAll(chatFormat.encodeToolResultTurn(null, call.name(), feedbackContent));
+                }
+                promptTokens.addAll(chatFormat.encodeMessage(
+                        new ChatFormat.Message(ChatFormat.Role.USER,
+                                "Using the tool results above, call the next required tool or answer the user's original question if all steps are complete.")));
+                promptTokens.addAll(chatFormat.encodeHeader(
+                        new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
+            }
+
+            // ── Final answer after all tool round-trips ───────────────────────────
+            log("\n--- Final generation ---");
+            List<Integer> finalTokens = generateTokens(state, plan, promptTokens, chatFormat.getStopTokens());
+            if (!finalTokens.isEmpty() && chatFormat.getStopTokens().contains(finalTokens.getLast())) {
+                finalTokens.removeLast();
+            }
+            String finalAnswer = model.tokenizer().decode(finalTokens);
+
+            // Reaching this point means every round produced at least one tool call
+            // (so callsMade.size() >= maxRoundTrips); flag the limit only when the
+            // final answer STILL tries to call a tool.
+            boolean hitLimit = callsMade.size() >= options.maxRoundTrips()
+                    && chatFormat.extractToolCall(finalAnswer).isPresent();
+
+            return new ToolCallingResult(finalAnswer, callsMade, toolResults, hitLimit);
+
+        } finally {
+            if (plan != null) plan.freeTornadoExecutionPlan();
+        }
+    }
+
+    // ── Inference ─────────────────────────────────────────────────────────────
+
+    private List<Integer> generateTokens(State state, TornadoVMMasterPlan plan,
+                                         List<Integer> prompt, Set<Integer> stopTokens) {
+        IntConsumer tokenConsumer = options.verbose() ? this::printToken : null;
+
+        if (options.useGpu()) {
+            return model.generateTokensGPU(
+                    state, 0, prompt, stopTokens, options.maxTokens(), sampler,
+                    false, tokenConsumer, plan);
+        } else {
+            return model.generateTokens(
+                    state, 0, prompt, stopTokens, options.maxTokens(), sampler,
+                    false, tokenConsumer);
+        }
+    }
+
+    private void printToken(int token) {
+        if (model.tokenizer().shouldDisplayToken(token)) {
+            System.out.print(model.tokenizer().decode(List.of(token)));
+            System.out.flush();
+        }
+    }
+
+    // ── Helpers ───────────────────────────────────────────────────────────────
+
+    private String truncate(String text) {
+        if (text == null) return "";
+        if (text.length() <= options.maxToolResultChars()) return text;
+        return text.substring(0, options.maxToolResultChars()) + "\n... (truncated)";
+    }
+
+    private void log(String fmt, Object... args) {
+        if (options.verbose()) {
+            System.out.printf((fmt) + "%n", args);
+        }
+    }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolDefinition.java b/src/main/java/org/beehive/gpullama3/tools/ToolDefinition.java
new file mode 100644
index 00000000..a07695c5
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tools/ToolDefinition.java
@@ -0,0 +1,17 @@
+package org.beehive.gpullama3.tools;
+
+/**
+ * Framework-agnostic description of a tool available to the model.
+ *
+ * @param name unique tool name
+ * @param description human-readable description used in the model's system prompt
+ * @param parametersJson JSON Schema object for the tool's parameters, e.g.
+ *        {@code {"type":"object","properties":{"city":{"type":"string"}},"required":["city"]}}
+ */
+public record ToolDefinition(String name, String description, String parametersJson) {
+
+    /** Convenience factory for a tool with no parameters. */
+    public static ToolDefinition noArgs(String name, String description) {
+        return new ToolDefinition(name, description, "{\"type\":\"object\",\"properties\":{}}");
+    }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolExecutor.java b/src/main/java/org/beehive/gpullama3/tools/ToolExecutor.java
new file mode 100644
index 00000000..b8807486
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tools/ToolExecutor.java
@@ -0,0 +1,13 @@
+package org.beehive.gpullama3.tools;
+
+import org.beehive.gpullama3.model.format.ToolCallExtract;
+
+/**
+ * Executes a single tool call and returns the result.
+ * Implementations are responsible for parsing {@link ToolCallExtract#argumentsJson()}
+ * and performing the actual action.
+ */
+@FunctionalInterface
+public interface ToolExecutor {
+    ToolResult execute(ToolCallExtract call);
+}
diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolRegistry.java b/src/main/java/org/beehive/gpullama3/tools/ToolRegistry.java
new file mode 100644
index 00000000..e31fa67e
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tools/ToolRegistry.java
@@ -0,0 +1,94 @@
+package org.beehive.gpullama3.tools;
+
+import org.beehive.gpullama3.model.format.ToolCallExtract;
+
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * Holds available tool definitions and their executors.
+ * Registration order is preserved for deterministic JSON output.
+ */
+public class ToolRegistry {
+
+    private final Map<String, Entry> entries = new LinkedHashMap<>();
+
+    private record Entry(ToolDefinition definition, ToolExecutor executor) {}
+
+    public ToolRegistry register(ToolDefinition definition, ToolExecutor executor) {
+        entries.put(definition.name(), new Entry(definition, executor));
+        return this;
+    }
+
+    public Optional<ToolDefinition> getDefinition(String name) {
+        Entry e = entries.get(name);
+        return e == null ? Optional.empty() : Optional.of(e.definition());
+    }
+
+    public Optional<ToolExecutor> getExecutor(String name) {
+        Entry e = entries.get(name);
+        return e == null ? Optional.empty() : Optional.of(e.executor());
+    }
+
+    public List<ToolDefinition> definitions() {
+        return entries.values().stream().map(Entry::definition).toList();
+    }
+
+    public boolean isEmpty() {
+        return entries.isEmpty();
+    }
+
+    /**
+     * Executes the named tool, returning a failure result for unknown tools or
+     * executor exceptions. Never throws.
+     */
+    public ToolResult execute(ToolCallExtract call) {
+        Optional<ToolExecutor> executor = getExecutor(call.name());
+        if (executor.isEmpty()) {
+            return ToolResult.failure(call.name(), "Unknown tool: " + call.name());
+        }
+        try {
+            return executor.get().execute(call);
+        } catch (Exception e) {
+            return ToolResult.failure(call.name(), "Tool execution failed: " + e.getMessage());
+        }
+    }
+
+    /**
+     * Serialises all registered tools to the flat JSON array expected by
+     * {@code LlamaChatFormat.toolSystemPromptSuffix()} and
+     * {@code Qwen3ChatFormat.toolSystemPromptSuffix()}.
+     *
+     * Format: {@code [{"name":…,"description":…,"parameters":{…}}]}
+     */
+    public String toToolsJson() {
+        List<ToolDefinition> defs = definitions();
+        if (defs.isEmpty()) return "[]";
+
+        StringBuilder sb = new StringBuilder("[\n");
+        for (int i = 0; i < defs.size(); i++) {
+            ToolDefinition d = defs.get(i);
+            sb.append("  {\n");
+            sb.append("    \"name\": \"").append(escapeJson(d.name())).append("\",\n");
+            sb.append("    \"description\": \"").append(escapeJson(d.description())).append("\",\n");
+            sb.append("    \"parameters\": ").append(d.parametersJson()).append("\n");
+            sb.append("  }");
+            if (i < defs.size() - 1) sb.append(",");
+            sb.append("\n");
+        }
+        sb.append("]");
+        return sb.toString();
+    }
+
+    private static String escapeJson(String s) {
+        if (s == null) return "";
+        return s.replace("\\", "\\\\")
+                .replace("\"", "\\\"")
+                .replace("\n", "\\n")
+                .replace("\r", "\\r")
+                .replace("\t", "\\t");
+    }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolResult.java b/src/main/java/org/beehive/gpullama3/tools/ToolResult.java
new file mode 100644
index 00000000..6752b207
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tools/ToolResult.java
@@ -0,0 +1,24 @@
+package org.beehive.gpullama3.tools;
+
+/**
+ * The result of executing a single tool call.
+ *
+ * @param toolName name of the tool that was invoked
+ * @param resultText the tool's output (may be JSON or plain text)
+ * @param error non-null when execution failed; contains the error message
+ */
+public record ToolResult(String toolName, String resultText, String error) {
+
+    /** Returns true when the tool execution failed. */
+    public boolean isError() {
+        return error != null;
+    }
+
+    public static ToolResult success(String toolName, String resultText) {
+        return new ToolResult(toolName, resultText, null);
+    }
+
+    public static ToolResult failure(String toolName, String errorMessage) {
+        return new ToolResult(toolName, null, errorMessage);
+    }
+}