From f4256f371255a31053a094da41e97602f41d22af Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 17 Apr 2026 15:30:36 +0300 Subject: [PATCH 1/5] [tool][WIP] Introduce tool-calling capabilities by GPULlama3.java # Conflicts: # LlamaTornadoCli.java --- llama-tornado | 15 +- llamaTornado | 15 +- .../beehive/gpullama3/ToolCallingDemo.java | 234 ++++++++++++++++++ .../gpullama3/model/format/ChatFormat.java | 53 ++++ .../model/format/LlamaChatFormat.java | 180 ++++++++++++++ .../model/format/Qwen3ChatFormat.java | 124 ++++++++++ .../model/format/ToolCallExtract.java | 11 + 7 files changed, 628 insertions(+), 4 deletions(-) create mode 100644 src/main/java/org/beehive/gpullama3/ToolCallingDemo.java create mode 100644 src/main/java/org/beehive/gpullama3/model/format/ToolCallExtract.java diff --git a/llama-tornado b/llama-tornado index 1d6c3d23..61439dad 100755 --- a/llama-tornado +++ b/llama-tornado @@ -191,7 +191,7 @@ class LlamaRunner: [ "-cp", self._find_llama_jar(), - "org.beehive.gpullama3.LlamaApp", + args.main_class, ] ) cmd.extend(module_config) @@ -246,6 +246,9 @@ class LlamaRunner: elif args.instruct: llama_args.append("--instruct") + if args.tool_demo: + llama_args.append("--tool-demo") + return cmd + llama_args def run(self, args: argparse.Namespace) -> int: @@ -527,6 +530,16 @@ def create_parser() -> argparse.ArgumentParser: # Advanced options advanced_group = parser.add_argument_group("Advanced Options") + advanced_group.add_argument( + "--main-class", + default="org.beehive.gpullama3.cli.LlamaTornadoCli", + help="Java main class to run (default: LlamaTornadoCli)", + ) + advanced_group.add_argument( + "--tool-demo", + action="store_true", + help="Run the tool calling demo (requires a LLaMA 3.1 or Qwen3 model)", + ) advanced_group.add_argument( "--opencl-flags", default="-cl-denorms-are-zero -cl-no-signed-zeros -cl-finite-math-only", diff --git a/llamaTornado b/llamaTornado index 6cc303d5..6f5328f5 100755 --- a/llamaTornado +++ b/llamaTornado @@ -17,7 
+17,8 @@ record Config( boolean printBytecodes, boolean threads, boolean printKernel, boolean fullDump, boolean verboseInit, boolean showCommand, boolean executeAfterShow, - String openclFlags, int maxWaitEvents, boolean verbose + String openclFlags, int maxWaitEvents, boolean verbose, + boolean toolDemo ) {} Config parseArgs(String[] args) { @@ -50,6 +51,7 @@ Config parseArgs(String[] args) { String openclFlags = "-cl-denorms-are-zero -cl-no-signed-zeros -cl-finite-math-only"; int maxWaitEvents = 32000; boolean verbose = false; + boolean toolDemo = false; for (int i = 0; i < args.length; i++) { switch (args[i]) { @@ -84,6 +86,7 @@ Config parseArgs(String[] args) { case "--opencl-flags" -> openclFlags = args[++i]; case "--max-wait-events" -> maxWaitEvents = Integer.parseInt(args[++i]); case "--verbose", "-v" -> verbose = true; + case "--tool-demo" -> toolDemo = true; default -> { System.err.println("Unknown option: " + args[i]); System.exit(1); @@ -104,7 +107,8 @@ Config parseArgs(String[] args) { return new Config(modelPath, prompt, systemPrompt, temperature, topP, seed, maxTokens, stream, echo, interactive, instruct, useGpu, backend, gpuMemory, heapMin, heapMax, debug, profiler, profilerDumpDir, printBytecodes, threads, printKernel, fullDump, - verboseInit, showCommand, executeAfterShow, openclFlags, maxWaitEvents, verbose); + verboseInit, showCommand, executeAfterShow, openclFlags, maxWaitEvents, verbose, + toolDemo); } void printUsage() { @@ -151,6 +155,7 @@ void printUsage() { --show-command Display the full Java command --execute-after-show Execute after showing command --verbose, -v Verbose output + --tool-demo Run the tool calling demo -help Show this help -version Show version @@ -262,7 +267,10 @@ List buildCommand(Config cfg, String javaHome, String tornadoSdk, String } } - cmd.addAll(List.of("-cp", findLlamaJar(llamaRoot), "org.beehive.gpullama3.LlamaApp")); + var mainClass = cfg.toolDemo() + ? 
"org.beehive.gpullama3.ToolCallingDemo" + : "org.beehive.gpullama3.LlamaApp"; + cmd.addAll(List.of("-cp", findLlamaJar(llamaRoot), mainClass)); // LLaMA arguments cmd.addAll(List.of( @@ -279,6 +287,7 @@ List buildCommand(Config cfg, String javaHome, String tornadoSdk, String if (cfg.systemPrompt() != null) cmd.addAll(List.of("-sp", cfg.systemPrompt())); if (cfg.interactive()) cmd.add("--interactive"); else if (cfg.instruct()) cmd.add("--instruct"); + // --tool-demo is handled by main class selection above, not passed to Java return cmd; } diff --git a/src/main/java/org/beehive/gpullama3/ToolCallingDemo.java b/src/main/java/org/beehive/gpullama3/ToolCallingDemo.java new file mode 100644 index 00000000..1bee90b3 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/ToolCallingDemo.java @@ -0,0 +1,234 @@ +package org.beehive.gpullama3; + +import org.beehive.gpullama3.auxiliary.LastRunMetrics; +import org.beehive.gpullama3.inference.sampler.Sampler; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.format.ChatFormat; +import org.beehive.gpullama3.model.format.ToolCallExtract; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.function.IntConsumer; + +import static org.beehive.gpullama3.model.loader.ModelLoader.loadModel; + +/** + * Standalone demo that exercises tool calling end-to-end directly against the + * GPULlama3.java inference engine — no Quarkus or LangChain4J required. 
+ * + * Usage: + * ./llamaTornado --model /path/to/model.gguf --tool-demo + * ./llamaTornado --model /path/to/model.gguf --tool-demo --gpu --opencl + */ +public class ToolCallingDemo { + + private static final String TOOLS_JSON = """ + [ + { + "type": "function", + "function": { + "name": "list_directory", + "description": "List files and directories at the given path using ls", + "parameters": { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Directory path to list. Defaults to current directory if omitted." + } + }, + "required": [] + } + } + } + ]"""; + + private static final String USER_PROMPT = + "Use the list_directory tool to list the files in the current directory."; + + public static void main(String[] args) throws IOException { + Options options = Options.parseOptions(ensurePrompt(args)); + + System.out.println("=== GPULlama3 Tool Calling Demo ==="); + System.out.println("Model : " + options.modelPath()); + System.out.println("GPU : " + options.useTornadovm()); + System.out.println("Prompt: " + USER_PROMPT); + System.out.println(); + + Model model = loadModel(options); + Sampler sampler = Sampler.createSampler(model, options); + ChatFormat chatFormat = model.chatFormat(); + + // ── Build prompt ────────────────────────────────────────────────────── + String toolSuffix = chatFormat.toolSystemPromptSuffix(TOOLS_JSON); + List promptTokens = new ArrayList<>(); + + if (model.shouldAddBeginOfText()) { + promptTokens.add(chatFormat.getBeginOfText()); + } + if (model.shouldAddSystemPrompt()) { + // Keep the system message minimal — just the tool instructions. + // A preamble like "You are a helpful assistant" can cause the model + // to answer from knowledge instead of calling the tool. 
+ promptTokens.addAll(chatFormat.encodeMessage( + new ChatFormat.Message(ChatFormat.Role.SYSTEM, toolSuffix.stripLeading()))); + } + promptTokens.addAll(chatFormat.encodeMessage( + new ChatFormat.Message(ChatFormat.Role.USER, USER_PROMPT))); + promptTokens.addAll(chatFormat.encodeHeader( + new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); + + // ── First inference turn ────────────────────────────────────────────── + Set stopTokens = chatFormat.getToolAwareStopTokens(); + State state = model.createNewState(); + + System.out.println("--- Raw model output ---"); + List responseTokens = generateTokens( + model, options, state, promptTokens, stopTokens, options.maxTokens(), sampler); + System.out.println("\n--- End raw output ---\n"); + + if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) { + responseTokens.removeLast(); + } + String rawResponse = model.tokenizer().decode(responseTokens); + + // ── Detect and execute tool call ────────────────────────────────────── + Optional toolCall = chatFormat.extractToolCall(rawResponse); + if (toolCall.isPresent()) { + ToolCallExtract tc = toolCall.get(); + System.out.println("✓ Tool call detected!"); + System.out.println(" Function : " + tc.name()); + System.out.println(" Arguments: " + tc.argumentsJson()); + + String toolResult = executeTool(tc); + System.out.println(" Result :\n" + toolResult); + + // ── Feed result back and get final answer ───────────────────────── + List continuation = new ArrayList<>(promptTokens); + continuation.addAll(chatFormat.encodeToolCallAssistantTurn(tc)); + continuation.addAll(chatFormat.encodeToolResultTurn(null, tc.name(), toolResult)); + // Ask the model to summarise in plain text — prevents small models from + // looping back into another tool call when they see tool defs in the system prompt. 
+ continuation.addAll(chatFormat.encodeMessage( + new ChatFormat.Message(ChatFormat.Role.USER, + "Based on the tool result above, please answer my question in plain text."))); + continuation.addAll(chatFormat.encodeHeader( + new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); + + State state2 = model.createNewState(); + System.out.println("--- Final answer ---"); + generateTokens(model, options, state2, continuation, + chatFormat.getStopTokens(), options.maxTokens(), sampler); + System.out.println(); + + } else { + System.out.println("✗ No tool call detected."); + System.out.println(" The model responded with plain text instead of a tool call."); + System.out.println(" Note: reliable tool calling typically requires a 7B+ model."); + System.out.println("\n Full decoded response:\n" + rawResponse); + } + + LastRunMetrics.printMetrics(); + } + + // ── Tool executor ───────────────────────────────────────────────────────── + + /** Maximum characters of tool output fed back into the prompt. */ + private static final int MAX_TOOL_RESULT_CHARS = 600; + + private static String executeTool(ToolCallExtract tc) { + return switch (tc.name()) { + case "list_directory" -> { + String path = extractStringArg(tc.argumentsJson(), "path", "."); + yield truncate(runProcess("ls", "-la", path)); + } + default -> "Unknown tool: " + tc.name(); + }; + } + + private static String runProcess(String... command) { + try { + var process = new ProcessBuilder(command) + .redirectErrorStream(true) + .start(); + String output = new String(process.getInputStream().readAllBytes()); + process.waitFor(); + return output; + } catch (Exception e) { + return "Error executing command: " + e.getMessage(); + } + } + + private static String truncate(String text) { + if (text.length() <= MAX_TOOL_RESULT_CHARS) return text; + return text.substring(0, MAX_TOOL_RESULT_CHARS) + "\n... (truncated)"; + } + + /** + * Extracts a string value from a flat JSON object by key. 
+ * Falls back to {@code defaultValue} if the key is absent or if the value is a + * nested object (which happens when a small model echoes the schema definition + * instead of supplying an actual argument value). + */ + private static String extractStringArg(String json, String key, String defaultValue) { + String marker = "\"" + key + "\":"; + int idx = json.indexOf(marker); + if (idx == -1) return defaultValue; + int pos = idx + marker.length(); + while (pos < json.length() && Character.isWhitespace(json.charAt(pos))) pos++; + if (pos >= json.length()) return defaultValue; + // Nested object instead of a plain string — model echoed the schema, use default + if (json.charAt(pos) == '{') return defaultValue; + int valStart = json.indexOf('"', pos); + if (valStart == -1) return defaultValue; + int valEnd = json.indexOf('"', valStart + 1); + if (valEnd == -1) return defaultValue; + return json.substring(valStart + 1, valEnd); + } + + // ── Inference helpers ───────────────────────────────────────────────────── + + private static List generateTokens( + Model model, Options options, State state, + List promptTokens, Set stopTokens, + int maxTokens, Sampler sampler) { + + IntConsumer tokenConsumer = token -> { + if (model.tokenizer().shouldDisplayToken(token)) { + System.out.print(model.tokenizer().decode(List.of(token))); + System.out.flush(); + } + }; + + if (options.useTornadovm()) { + TornadoVMMasterPlan plan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model); + try { + return model.generateTokensGPU( + state, 0, promptTokens, stopTokens, maxTokens, sampler, + false, tokenConsumer, plan); + } finally { + plan.freeTornadoExecutionPlan(); + } + } else { + return model.generateTokens( + state, 0, promptTokens, stopTokens, maxTokens, sampler, + false, tokenConsumer); + } + } + + private static String[] ensurePrompt(String[] args) { + for (String arg : args) { + if (arg.equals("--prompt") || arg.equals("-p")) return args; + } + String[] extended = 
Arrays.copyOf(args, args.length + 2); + extended[args.length] = "--prompt"; + extended[args.length + 1] = USER_PROMPT; + return extended; + } +} diff --git a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java index 827ad625..486cdd54 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java @@ -8,6 +8,7 @@ import org.beehive.gpullama3.tokenizer.Qwen3Tokenizer; import java.util.List; +import java.util.Optional; import java.util.Set; public interface ChatFormat { @@ -36,6 +37,58 @@ default ChatTokens chatTokens() { Set getStopTokens(); + /** + * Returns plain text to append to the system message content when tools are available. + * The returned string is concatenated to the system message before encoding, so the + * normal {@link #encodeMessage} path handles tokenization. + * + * @param toolsJson JSON array of tool definitions, e.g. + * {@code [{"type":"function","function":{...}}]} + */ + default String toolSystemPromptSuffix(String toolsJson) { + throw new UnsupportedOperationException("Tool calling not supported for: " + getClass().getSimpleName()); + } + + /** + * Re-encodes a prior assistant tool-call turn into the conversation token stream. + * Used when replaying multi-turn history that contains a previous tool call. + * + * @param toolCall the tool call to encode (name + raw arguments JSON) + */ + default List encodeToolCallAssistantTurn(ToolCallExtract toolCall) { + throw new UnsupportedOperationException("Tool calling not supported for: " + getClass().getSimpleName()); + } + + /** + * Encodes a tool execution result message in the model-native format. 
+ * + * @param toolCallId the ID of the originating tool call (may be ignored by some formats) + * @param toolName the name of the tool that was called + * @param result the result content string + */ + default List encodeToolResultTurn(String toolCallId, String toolName, String result) { + throw new UnsupportedOperationException("Tool calling not supported for: " + getClass().getSimpleName()); + } + + /** + * Detects and extracts a tool call from fully decoded model response text. + * Returns {@link Optional#empty()} when the response is a plain text answer. + * + * @param responseText the fully decoded response from the model + */ + default Optional extractToolCall(String responseText) { + return Optional.empty(); + } + + /** + * Stop tokens to use when tool calling is enabled. + * Some models (LLaMA 3.1+) use a different end-of-turn token ({@code <|eom_id|>}) + * when emitting a tool call instead of a regular response. + */ + default Set getToolAwareStopTokens() { + return getStopTokens(); + } + record ChatTokens(String tStartHeader, String tEndHeader, String tEndOfTurn, String tEndOfText, String tEndOfTextFim) { } diff --git a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java index c98a72c9..6e5f1683 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java @@ -6,6 +6,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; public class LlamaChatFormat implements ChatFormat { @@ -17,6 +18,7 @@ public class LlamaChatFormat implements ChatFormat { protected final int endOfTurn; protected final int endOfText; protected final int endOfMessage; + protected final int pythonTag; protected final Set stopTokens; public LlamaChatFormat(Tokenizer tokenizer) { @@ -28,6 +30,7 @@ public LlamaChatFormat(Tokenizer 
tokenizer) { this.endOfTurn = specialTokens.get("<|eot_id|>"); this.endOfText = specialTokens.get("<|end_of_text|>"); this.endOfMessage = specialTokens.getOrDefault("<|eom_id|>", -1); // only in 3.1 + this.pythonTag = specialTokens.getOrDefault("<|python_tag|>", -1); // only in 3.1 this.stopTokens = Set.of(endOfText, endOfTurn); } @@ -71,4 +74,181 @@ public List encodeDialogPrompt(boolean appendAssistantTurn, Listassistant<|end_header_id|>\n<|python_tag|>JSON<|eom_id|>} + */ + @Override + public List encodeToolCallAssistantTurn(ToolCallExtract toolCall) { + List tokens = new ArrayList<>(encodeHeader(new Message(Role.ASSISTANT, ""))); + if (pythonTag != -1) { + tokens.add(pythonTag); + } + String json = "{\"name\":\"" + toolCall.name() + "\",\"parameters\":" + toolCall.argumentsJson() + "}"; + tokens.addAll(tokenizer.encodeAsList(json)); + if (endOfMessage != -1) { + tokens.add(endOfMessage); + } else { + tokens.add(endOfTurn); + } + return tokens; + } + + /** + * Encodes a tool result using the LLaMA "ipython" role. + * Format: {@code <|start_header_id|>ipython<|end_header_id|>\nresult<|eot_id|>} + */ + @Override + public List encodeToolResultTurn(String toolCallId, String toolName, String result) { + List tokens = new ArrayList<>(); + tokens.add(startHeader); + tokens.addAll(tokenizer.encodeAsList("ipython")); + tokens.add(endHeader); + tokens.addAll(tokenizer.encodeAsList("\n")); + tokens.addAll(tokenizer.encodeAsList(result)); + tokens.add(endOfTurn); + return tokens; + } + + /** + * Detects a tool call in the decoded response text. + * + *

Two formats are recognised:

+ *
    + *
  1. LLaMA 3.1 native: {@code <|python_tag|>{"name":…,"parameters":{…}}}
  2. + *
  3. Fallback: raw JSON object (possibly wrapped in markdown code fences), + * as produced by smaller models that follow the system-prompt instructions but skip + * the special token prefix.
  4. + *
+ */ + @Override + public Optional extractToolCall(String responseText) { + // ── 1. Native LLaMA 3.1 format: <|python_tag|>{...} ────────────────── + int idx = responseText.indexOf("<|python_tag|>"); + if (idx != -1) { + String json = responseText.substring(idx + "<|python_tag|>".length()).strip(); + return parseToolCallJson(json); + } + + // ── 2. Fallback: raw JSON, possibly inside markdown code fences ─────── + String stripped = stripMarkdownFences(responseText.strip()); + if (stripped.startsWith("{")) { + return parseToolCallJson(stripped); + } + + return Optional.empty(); + } + + /** + * Adds {@code <|eom_id|>} to the stop tokens when tools are enabled. + * LLaMA 3.1 ends tool-call turns with {@code <|eom_id|>} instead of {@code <|eot_id|>}. + */ + @Override + public Set getToolAwareStopTokens() { + if (endOfMessage != -1) { + return Set.of(endOfText, endOfTurn, endOfMessage); + } + return stopTokens; + } + + /** + * Strips surrounding markdown code fences (``` or ```json / ```python etc.) if present. + */ + private static String stripMarkdownFences(String text) { + if (!text.startsWith("```")) { + return text; + } + int firstNewline = text.indexOf('\n'); + if (firstNewline == -1) { + return text; + } + String body = text.substring(firstNewline + 1); + if (body.endsWith("```")) { + body = body.substring(0, body.length() - 3).stripTrailing(); + } + return body.strip(); + } + + /** + * Parses a tool call JSON object into a {@link ToolCallExtract}. + * + *

Accepts both formats produced by LLaMA variants:

+ *
    + *
  • {@code {"name":"fn","parameters":{…}}} — LLaMA 3.1 native
  • + *
  • {@code {"function":"fn","parameters":{…}}} — produced by some fine-tunes
  • + *
  • {@code {"name":"fn","arguments":{…}}} — alternative key
  • + *
+ * Uses brace-counting to extract nested argument objects correctly. + */ + private static Optional parseToolCallJson(String json) { + // ── extract tool name: try "name" then "function" ───────────────────── + String name = extractStringValue(json, "name"); + if (name == null) { + name = extractStringValue(json, "function"); + } + if (name == null) { + return Optional.empty(); + } + + // ── extract arguments object: try "parameters" then "arguments" ─────── + String argsJson = extractNestedObject(json, "parameters"); + if (argsJson == null) { + argsJson = extractNestedObject(json, "arguments"); + } + if (argsJson == null) { + argsJson = "{}"; // tool call with no arguments + } + + return Optional.of(new ToolCallExtract(name, argsJson)); + } + + /** Extracts the string value for {@code "key":""} from a JSON object. */ + private static String extractStringValue(String json, String key) { + String marker = "\"" + key + "\":"; + int markerIdx = json.indexOf(marker); + if (markerIdx == -1) return null; + int quoteStart = json.indexOf('"', markerIdx + marker.length()); + if (quoteStart == -1) return null; + int quoteEnd = json.indexOf('"', quoteStart + 1); + if (quoteEnd == -1) return null; + return json.substring(quoteStart + 1, quoteEnd); + } + + /** + * Extracts the JSON object value for {@code "key":{…}} using brace-counting, + * so nested objects are handled correctly regardless of what follows. 
+ */ + private static String extractNestedObject(String json, String key) { + String marker = "\"" + key + "\":"; + int markerIdx = json.indexOf(marker); + if (markerIdx == -1) return null; + int braceStart = json.indexOf('{', markerIdx + marker.length()); + if (braceStart == -1) return null; + int depth = 0; + for (int i = braceStart; i < json.length(); i++) { + char c = json.charAt(i); + if (c == '{') depth++; + else if (c == '}') { + depth--; + if (depth == 0) return json.substring(braceStart, i + 1); + } + } + return null; // unbalanced JSON + } } \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java index b6d2e798..c4e31270 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java @@ -129,4 +129,128 @@ public Set getStopTokens() { return stopTokens; } + + // ── Tool calling ────────────────────────────────────────────────────────── + + /** + * Qwen3 tool calling system prompt suffix. + * Appended to the system message; instructs the model to wrap tool calls in + * {@code } XML tags. + */ + @Override + public String toolSystemPromptSuffix(String toolsJson) { + return "\n\n# Tools\n\n" + + "You may call one or more functions to assist with the user query.\n\n" + + "You are provided with function signatures within XML tags:\n" + + "\n" + + toolsJson + + "\n\n\n" + + "For each function call, return a json object with function name and arguments " + + "within XML tags:\n" + + "\n" + + "{\"name\": , \"arguments\": }\n" + + ""; + } + + /** + * Re-encodes a prior assistant tool-call turn for multi-turn history. 
+ * Format: {@code <|im_start|>assistant\n\nJSON\n<|im_end|>} + */ + @Override + public List encodeToolCallAssistantTurn(ToolCallExtract toolCall) { + List tokens = new ArrayList<>(); + tokens.add(imStart); + tokens.addAll(tokenizer.encodeOrdinaryAsList("assistant\n")); + String json = "{\"name\":\"" + toolCall.name() + "\",\"arguments\":" + toolCall.argumentsJson() + "}"; + tokens.addAll(tokenizer.encodeOrdinaryAsList("\n" + json + "\n")); + if (imEnd != -1) { + tokens.add(imEnd); + } + return tokens; + } + + /** + * Encodes a tool result using the Qwen3 "tool" role. + * Format: {@code <|im_start|>tool\nresult<|im_end|>} + */ + @Override + public List encodeToolResultTurn(String toolCallId, String toolName, String result) { + List tokens = new ArrayList<>(); + tokens.add(imStart); + tokens.addAll(tokenizer.encodeOrdinaryAsList("tool\n")); + tokens.addAll(tokenizer.encodeOrdinaryAsList(result)); + if (imEnd != -1) { + tokens.add(imEnd); + } + return tokens; + } + + /** + * Detects a tool call enclosed in {@code } tags. + */ + @Override + public Optional extractToolCall(String responseText) { + int start = responseText.indexOf(""); + int end = responseText.lastIndexOf(""); + if (start == -1 || end == -1) { + return Optional.empty(); + } + String json = responseText.substring(start + "".length(), end).strip(); + return parseToolCallJson(json, "arguments"); + } + + /** + * Parses {@code name} and the value of {@code argsKey} out of a tool-call JSON object. + * Uses brace-counting to extract nested argument objects correctly. + * Avoids a JSON-library dependency. 
+ */ + private static Optional parseToolCallJson(String json, String argsKey) { + // extract "name" + String name = extractStringValue(json, "name"); + if (name == null) { + return Optional.empty(); + } + + // extract arguments object using brace-counting + String argsJson = extractNestedObject(json, argsKey); + if (argsJson == null) { + argsJson = "{}"; // tool call with no arguments + } + + return Optional.of(new ToolCallExtract(name, argsJson)); + } + + /** Extracts the string value for {@code "key":""} from a JSON object. */ + private static String extractStringValue(String json, String key) { + String marker = "\"" + key + "\":"; + int markerIdx = json.indexOf(marker); + if (markerIdx == -1) return null; + int quoteStart = json.indexOf('"', markerIdx + marker.length()); + if (quoteStart == -1) return null; + int quoteEnd = json.indexOf('"', quoteStart + 1); + if (quoteEnd == -1) return null; + return json.substring(quoteStart + 1, quoteEnd); + } + + /** + * Extracts the JSON object value for {@code "key":{…}} using brace-counting, + * so nested objects are handled correctly regardless of what follows. 
+ */ + private static String extractNestedObject(String json, String key) { + String marker = "\"" + key + "\":"; + int markerIdx = json.indexOf(marker); + if (markerIdx == -1) return null; + int braceStart = json.indexOf('{', markerIdx + marker.length()); + if (braceStart == -1) return null; + int depth = 0; + for (int i = braceStart; i < json.length(); i++) { + char c = json.charAt(i); + if (c == '{') depth++; + else if (c == '}') { + depth--; + if (depth == 0) return json.substring(braceStart, i + 1); + } + } + return null; // unbalanced JSON + } } diff --git a/src/main/java/org/beehive/gpullama3/model/format/ToolCallExtract.java b/src/main/java/org/beehive/gpullama3/model/format/ToolCallExtract.java new file mode 100644 index 00000000..335c2176 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/format/ToolCallExtract.java @@ -0,0 +1,11 @@ +package org.beehive.gpullama3.model.format; + +/** + * Represents a single tool call extracted from a model response. + * Contains the raw strings — JSON parsing of arguments is left to the caller. + * + * @param name the tool/function name to invoke + * @param argumentsJson the arguments as a JSON object string, e.g. 
{"location":"Boston"} + */ +public record ToolCallExtract(String name, String argumentsJson) { +} From 4e7e5f90a971334a02055cca96ed33567e0df980 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Sat, 18 Apr 2026 15:05:11 +0300 Subject: [PATCH 2/5] [tools] Refactor tool-calling architecture: modularize components and extract ToolCallingDemo --- .../beehive/gpullama3/ToolCallingDemo.java | 234 ------------------ .../model/format/LlamaChatFormat.java | 137 ++-------- .../model/format/Qwen3ChatFormat.java | 64 +---- .../model/format/ToolCallParserUtils.java | 143 +++++++++++ .../gpullama3/tools/ToolCallingOptions.java | 31 +++ .../gpullama3/tools/ToolCallingResult.java | 29 +++ .../gpullama3/tools/ToolCallingSession.java | 179 ++++++++++++++ .../gpullama3/tools/ToolDefinition.java | 17 ++ .../beehive/gpullama3/tools/ToolExecutor.java | 13 + .../beehive/gpullama3/tools/ToolRegistry.java | 94 +++++++ .../beehive/gpullama3/tools/ToolResult.java | 24 ++ 11 files changed, 549 insertions(+), 416 deletions(-) delete mode 100644 src/main/java/org/beehive/gpullama3/ToolCallingDemo.java create mode 100644 src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java create mode 100644 src/main/java/org/beehive/gpullama3/tools/ToolCallingOptions.java create mode 100644 src/main/java/org/beehive/gpullama3/tools/ToolCallingResult.java create mode 100644 src/main/java/org/beehive/gpullama3/tools/ToolCallingSession.java create mode 100644 src/main/java/org/beehive/gpullama3/tools/ToolDefinition.java create mode 100644 src/main/java/org/beehive/gpullama3/tools/ToolExecutor.java create mode 100644 src/main/java/org/beehive/gpullama3/tools/ToolRegistry.java create mode 100644 src/main/java/org/beehive/gpullama3/tools/ToolResult.java diff --git a/src/main/java/org/beehive/gpullama3/ToolCallingDemo.java b/src/main/java/org/beehive/gpullama3/ToolCallingDemo.java deleted file mode 100644 index 1bee90b3..00000000 --- a/src/main/java/org/beehive/gpullama3/ToolCallingDemo.java 
+++ /dev/null @@ -1,234 +0,0 @@ -package org.beehive.gpullama3; - -import org.beehive.gpullama3.auxiliary.LastRunMetrics; -import org.beehive.gpullama3.inference.sampler.Sampler; -import org.beehive.gpullama3.inference.state.State; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.model.format.ChatFormat; -import org.beehive.gpullama3.model.format.ToolCallExtract; -import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Optional; -import java.util.Set; -import java.util.function.IntConsumer; - -import static org.beehive.gpullama3.model.loader.ModelLoader.loadModel; - -/** - * Standalone demo that exercises tool calling end-to-end directly against the - * GPULlama3.java inference engine — no Quarkus or LangChain4J required. - * - * Usage: - * ./llamaTornado --model /path/to/model.gguf --tool-demo - * ./llamaTornado --model /path/to/model.gguf --tool-demo --gpu --opencl - */ -public class ToolCallingDemo { - - private static final String TOOLS_JSON = """ - [ - { - "type": "function", - "function": { - "name": "list_directory", - "description": "List files and directories at the given path using ls", - "parameters": { - "type": "object", - "properties": { - "path": { - "type": "string", - "description": "Directory path to list. Defaults to current directory if omitted." 
- } - }, - "required": [] - } - } - } - ]"""; - - private static final String USER_PROMPT = - "Use the list_directory tool to list the files in the current directory."; - - public static void main(String[] args) throws IOException { - Options options = Options.parseOptions(ensurePrompt(args)); - - System.out.println("=== GPULlama3 Tool Calling Demo ==="); - System.out.println("Model : " + options.modelPath()); - System.out.println("GPU : " + options.useTornadovm()); - System.out.println("Prompt: " + USER_PROMPT); - System.out.println(); - - Model model = loadModel(options); - Sampler sampler = Sampler.createSampler(model, options); - ChatFormat chatFormat = model.chatFormat(); - - // ── Build prompt ────────────────────────────────────────────────────── - String toolSuffix = chatFormat.toolSystemPromptSuffix(TOOLS_JSON); - List promptTokens = new ArrayList<>(); - - if (model.shouldAddBeginOfText()) { - promptTokens.add(chatFormat.getBeginOfText()); - } - if (model.shouldAddSystemPrompt()) { - // Keep the system message minimal — just the tool instructions. - // A preamble like "You are a helpful assistant" can cause the model - // to answer from knowledge instead of calling the tool. 
- promptTokens.addAll(chatFormat.encodeMessage( - new ChatFormat.Message(ChatFormat.Role.SYSTEM, toolSuffix.stripLeading()))); - } - promptTokens.addAll(chatFormat.encodeMessage( - new ChatFormat.Message(ChatFormat.Role.USER, USER_PROMPT))); - promptTokens.addAll(chatFormat.encodeHeader( - new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); - - // ── First inference turn ────────────────────────────────────────────── - Set stopTokens = chatFormat.getToolAwareStopTokens(); - State state = model.createNewState(); - - System.out.println("--- Raw model output ---"); - List responseTokens = generateTokens( - model, options, state, promptTokens, stopTokens, options.maxTokens(), sampler); - System.out.println("\n--- End raw output ---\n"); - - if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) { - responseTokens.removeLast(); - } - String rawResponse = model.tokenizer().decode(responseTokens); - - // ── Detect and execute tool call ────────────────────────────────────── - Optional toolCall = chatFormat.extractToolCall(rawResponse); - if (toolCall.isPresent()) { - ToolCallExtract tc = toolCall.get(); - System.out.println("✓ Tool call detected!"); - System.out.println(" Function : " + tc.name()); - System.out.println(" Arguments: " + tc.argumentsJson()); - - String toolResult = executeTool(tc); - System.out.println(" Result :\n" + toolResult); - - // ── Feed result back and get final answer ───────────────────────── - List continuation = new ArrayList<>(promptTokens); - continuation.addAll(chatFormat.encodeToolCallAssistantTurn(tc)); - continuation.addAll(chatFormat.encodeToolResultTurn(null, tc.name(), toolResult)); - // Ask the model to summarise in plain text — prevents small models from - // looping back into another tool call when they see tool defs in the system prompt. 
- continuation.addAll(chatFormat.encodeMessage( - new ChatFormat.Message(ChatFormat.Role.USER, - "Based on the tool result above, please answer my question in plain text."))); - continuation.addAll(chatFormat.encodeHeader( - new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); - - State state2 = model.createNewState(); - System.out.println("--- Final answer ---"); - generateTokens(model, options, state2, continuation, - chatFormat.getStopTokens(), options.maxTokens(), sampler); - System.out.println(); - - } else { - System.out.println("✗ No tool call detected."); - System.out.println(" The model responded with plain text instead of a tool call."); - System.out.println(" Note: reliable tool calling typically requires a 7B+ model."); - System.out.println("\n Full decoded response:\n" + rawResponse); - } - - LastRunMetrics.printMetrics(); - } - - // ── Tool executor ───────────────────────────────────────────────────────── - - /** Maximum characters of tool output fed back into the prompt. */ - private static final int MAX_TOOL_RESULT_CHARS = 600; - - private static String executeTool(ToolCallExtract tc) { - return switch (tc.name()) { - case "list_directory" -> { - String path = extractStringArg(tc.argumentsJson(), "path", "."); - yield truncate(runProcess("ls", "-la", path)); - } - default -> "Unknown tool: " + tc.name(); - }; - } - - private static String runProcess(String... command) { - try { - var process = new ProcessBuilder(command) - .redirectErrorStream(true) - .start(); - String output = new String(process.getInputStream().readAllBytes()); - process.waitFor(); - return output; - } catch (Exception e) { - return "Error executing command: " + e.getMessage(); - } - } - - private static String truncate(String text) { - if (text.length() <= MAX_TOOL_RESULT_CHARS) return text; - return text.substring(0, MAX_TOOL_RESULT_CHARS) + "\n... (truncated)"; - } - - /** - * Extracts a string value from a flat JSON object by key. 
- * Falls back to {@code defaultValue} if the key is absent or if the value is a - * nested object (which happens when a small model echoes the schema definition - * instead of supplying an actual argument value). - */ - private static String extractStringArg(String json, String key, String defaultValue) { - String marker = "\"" + key + "\":"; - int idx = json.indexOf(marker); - if (idx == -1) return defaultValue; - int pos = idx + marker.length(); - while (pos < json.length() && Character.isWhitespace(json.charAt(pos))) pos++; - if (pos >= json.length()) return defaultValue; - // Nested object instead of a plain string — model echoed the schema, use default - if (json.charAt(pos) == '{') return defaultValue; - int valStart = json.indexOf('"', pos); - if (valStart == -1) return defaultValue; - int valEnd = json.indexOf('"', valStart + 1); - if (valEnd == -1) return defaultValue; - return json.substring(valStart + 1, valEnd); - } - - // ── Inference helpers ───────────────────────────────────────────────────── - - private static List generateTokens( - Model model, Options options, State state, - List promptTokens, Set stopTokens, - int maxTokens, Sampler sampler) { - - IntConsumer tokenConsumer = token -> { - if (model.tokenizer().shouldDisplayToken(token)) { - System.out.print(model.tokenizer().decode(List.of(token))); - System.out.flush(); - } - }; - - if (options.useTornadovm()) { - TornadoVMMasterPlan plan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model); - try { - return model.generateTokensGPU( - state, 0, promptTokens, stopTokens, maxTokens, sampler, - false, tokenConsumer, plan); - } finally { - plan.freeTornadoExecutionPlan(); - } - } else { - return model.generateTokens( - state, 0, promptTokens, stopTokens, maxTokens, sampler, - false, tokenConsumer); - } - } - - private static String[] ensurePrompt(String[] args) { - for (String arg : args) { - if (arg.equals("--prompt") || arg.equals("-p")) return args; - } - String[] extended = 
Arrays.copyOf(args, args.length + 2); - extended[args.length] = "--prompt"; - extended[args.length + 1] = USER_PROMPT; - return extended; - } -} diff --git a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java index 6e5f1683..289b3004 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java @@ -83,11 +83,17 @@ public List encodeDialogPrompt(boolean appendAssistantTurn, List XML tags:\n\n" + + "\n" + toolsJson + "\n\n\n" + + "IMPORTANT: the \"name\" field in your tool call MUST be exactly one of the function names " + + "listed inside above — not a path, not a word from the user's message.\n\n" + + "For each function call, return a json object with function name and arguments " + + "within XML tags:\n\n" + + "\n" + + "{\"name\": , \"arguments\": }\n" + + ""; } /** @@ -97,16 +103,9 @@ public String toolSystemPromptSuffix(String toolsJson) { @Override public List encodeToolCallAssistantTurn(ToolCallExtract toolCall) { List tokens = new ArrayList<>(encodeHeader(new Message(Role.ASSISTANT, ""))); - if (pythonTag != -1) { - tokens.add(pythonTag); - } - String json = "{\"name\":\"" + toolCall.name() + "\",\"parameters\":" + toolCall.argumentsJson() + "}"; + String json = "\n{\"name\":\"" + toolCall.name() + "\",\"arguments\":" + toolCall.argumentsJson() + "}\n"; tokens.addAll(tokenizer.encodeAsList(json)); - if (endOfMessage != -1) { - tokens.add(endOfMessage); - } else { - tokens.add(endOfTurn); - } + tokens.add(endOfTurn); return tokens; } @@ -128,31 +127,13 @@ public List encodeToolResultTurn(String toolCallId, String toolName, St /** * Detects a tool call in the decoded response text. - * - *

-     * <p>Two formats are recognised:
-     *
-     * <ol>
-     *   <li>LLaMA 3.1 native: {@code <|python_tag|>{"name":…,"parameters":{…}}}</li>
-     *   <li>Fallback: raw JSON object (possibly wrapped in markdown code fences),
-     *       as produced by smaller models that follow the system-prompt instructions but skip
-     *       the special token prefix.</li>
-     * </ol>
+ * Supports LLaMA 3.1 (native {@code <|python_tag|>} + {@code "parameters"} key), + * LLaMA 3.2 ({@code "arguments"} key, tag often absent), and a raw-JSON fallback + * for smaller models. Delegates to {@link ToolCallParserUtils#parseLlamaResponse}. */ @Override public Optional extractToolCall(String responseText) { - // ── 1. Native LLaMA 3.1 format: <|python_tag|>{...} ────────────────── - int idx = responseText.indexOf("<|python_tag|>"); - if (idx != -1) { - String json = responseText.substring(idx + "<|python_tag|>".length()).strip(); - return parseToolCallJson(json); - } - - // ── 2. Fallback: raw JSON, possibly inside markdown code fences ─────── - String stripped = stripMarkdownFences(responseText.strip()); - if (stripped.startsWith("{")) { - return parseToolCallJson(stripped); - } - - return Optional.empty(); + return ToolCallParserUtils.parseLlamaResponse(responseText); } /** @@ -167,88 +148,4 @@ public Set getToolAwareStopTokens() { return stopTokens; } - /** - * Strips surrounding markdown code fences (``` or ```json / ```python etc.) if present. - */ - private static String stripMarkdownFences(String text) { - if (!text.startsWith("```")) { - return text; - } - int firstNewline = text.indexOf('\n'); - if (firstNewline == -1) { - return text; - } - String body = text.substring(firstNewline + 1); - if (body.endsWith("```")) { - body = body.substring(0, body.length() - 3).stripTrailing(); - } - return body.strip(); - } - - /** - * Parses a tool call JSON object into a {@link ToolCallExtract}. - * - *

-     * <p>Accepts both formats produced by LLaMA variants:
-     *
-     * <ul>
-     *   <li>{@code {"name":"fn","parameters":{…}}} — LLaMA 3.1 native</li>
-     *   <li>{@code {"function":"fn","parameters":{…}}} — produced by some fine-tunes</li>
-     *   <li>{@code {"name":"fn","arguments":{…}}} — alternative key</li>
-     * </ul>
-     *
- * Uses brace-counting to extract nested argument objects correctly. - */ - private static Optional parseToolCallJson(String json) { - // ── extract tool name: try "name" then "function" ───────────────────── - String name = extractStringValue(json, "name"); - if (name == null) { - name = extractStringValue(json, "function"); - } - if (name == null) { - return Optional.empty(); - } - - // ── extract arguments object: try "parameters" then "arguments" ─────── - String argsJson = extractNestedObject(json, "parameters"); - if (argsJson == null) { - argsJson = extractNestedObject(json, "arguments"); - } - if (argsJson == null) { - argsJson = "{}"; // tool call with no arguments - } - - return Optional.of(new ToolCallExtract(name, argsJson)); - } - - /** Extracts the string value for {@code "key":""} from a JSON object. */ - private static String extractStringValue(String json, String key) { - String marker = "\"" + key + "\":"; - int markerIdx = json.indexOf(marker); - if (markerIdx == -1) return null; - int quoteStart = json.indexOf('"', markerIdx + marker.length()); - if (quoteStart == -1) return null; - int quoteEnd = json.indexOf('"', quoteStart + 1); - if (quoteEnd == -1) return null; - return json.substring(quoteStart + 1, quoteEnd); - } - - /** - * Extracts the JSON object value for {@code "key":{…}} using brace-counting, - * so nested objects are handled correctly regardless of what follows. 
- */ - private static String extractNestedObject(String json, String key) { - String marker = "\"" + key + "\":"; - int markerIdx = json.indexOf(marker); - if (markerIdx == -1) return null; - int braceStart = json.indexOf('{', markerIdx + marker.length()); - if (braceStart == -1) return null; - int depth = 0; - for (int i = braceStart; i < json.length(); i++) { - char c = json.charAt(i); - if (c == '{') depth++; - else if (c == '}') { - depth--; - if (depth == 0) return json.substring(braceStart, i + 1); - } - } - return null; // unbalanced JSON - } } \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java index c4e31270..de671a59 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java @@ -187,70 +187,10 @@ public List encodeToolResultTurn(String toolCallId, String toolName, St /** * Detects a tool call enclosed in {@code } tags. + * Delegates to {@link ToolCallParserUtils#parseQwen3Response}. */ @Override public Optional extractToolCall(String responseText) { - int start = responseText.indexOf(""); - int end = responseText.lastIndexOf(""); - if (start == -1 || end == -1) { - return Optional.empty(); - } - String json = responseText.substring(start + "".length(), end).strip(); - return parseToolCallJson(json, "arguments"); - } - - /** - * Parses {@code name} and the value of {@code argsKey} out of a tool-call JSON object. - * Uses brace-counting to extract nested argument objects correctly. - * Avoids a JSON-library dependency. 
- */ - private static Optional parseToolCallJson(String json, String argsKey) { - // extract "name" - String name = extractStringValue(json, "name"); - if (name == null) { - return Optional.empty(); - } - - // extract arguments object using brace-counting - String argsJson = extractNestedObject(json, argsKey); - if (argsJson == null) { - argsJson = "{}"; // tool call with no arguments - } - - return Optional.of(new ToolCallExtract(name, argsJson)); - } - - /** Extracts the string value for {@code "key":""} from a JSON object. */ - private static String extractStringValue(String json, String key) { - String marker = "\"" + key + "\":"; - int markerIdx = json.indexOf(marker); - if (markerIdx == -1) return null; - int quoteStart = json.indexOf('"', markerIdx + marker.length()); - if (quoteStart == -1) return null; - int quoteEnd = json.indexOf('"', quoteStart + 1); - if (quoteEnd == -1) return null; - return json.substring(quoteStart + 1, quoteEnd); - } - - /** - * Extracts the JSON object value for {@code "key":{…}} using brace-counting, - * so nested objects are handled correctly regardless of what follows. 
- */ - private static String extractNestedObject(String json, String key) { - String marker = "\"" + key + "\":"; - int markerIdx = json.indexOf(marker); - if (markerIdx == -1) return null; - int braceStart = json.indexOf('{', markerIdx + marker.length()); - if (braceStart == -1) return null; - int depth = 0; - for (int i = braceStart; i < json.length(); i++) { - char c = json.charAt(i); - if (c == '{') depth++; - else if (c == '}') { - depth--; - if (depth == 0) return json.substring(braceStart, i + 1); - } - } - return null; // unbalanced JSON + return ToolCallParserUtils.parseQwen3Response(responseText); } } diff --git a/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java b/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java new file mode 100644 index 00000000..73a327a4 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java @@ -0,0 +1,143 @@ +package org.beehive.gpullama3.model.format; + +import java.util.Optional; + +/** + * Pure-string tool-call extraction for Llama and Qwen3 response formats. + * + * All methods are stateless and do not require any model or tokenizer instance, + * making them directly unit-testable. + */ +public final class ToolCallParserUtils { + + private ToolCallParserUtils() {} + + // ── Llama ───────────────────────────────────────────────────────────────── + + /** + * Extracts a tool call from a LLaMA 3.1 or 3.2 model response. + * + * Recognised formats: + * 1. {@code <|python_tag|>{"name":…,"parameters":{…}}} — LLaMA 3.1 native, also accepted by 3.2 + * 2. Raw JSON with {@code "arguments"} key instead of {@code "parameters"} — LLaMA 3.2 instruction format + * 3. 
Raw JSON object optionally inside markdown code fences — fallback for models that + * follow system-prompt instructions but omit the special-token prefix + * + * Both {@code "parameters"} and {@code "arguments"} are tried so a single implementation + * handles the 3.1 and 3.2 variants transparently. + */ + public static Optional parseLlamaResponse(String responseText) { + // 1. Native LLaMA 3.1 format: <|python_tag|>{...} + int idx = responseText.indexOf("<|python_tag|>"); + if (idx != -1) { + String json = responseText.substring(idx + "<|python_tag|>".length()).strip(); + return parseLlamaJson(json); + } + + // 2. LLaMA 3.2 format: ... + int tcStart = responseText.indexOf(""); + int tcEnd = responseText.lastIndexOf(""); + if (tcStart != -1 && tcEnd != -1 && tcEnd > tcStart) { + String json = responseText.substring(tcStart + "".length(), tcEnd).strip(); + return parseLlamaJson(json); + } + + // 3. Fallback: raw JSON, possibly inside markdown code fences + String stripped = stripMarkdownFences(responseText.strip()); + if (stripped.startsWith("{")) { + return parseLlamaJson(stripped); + } + + return Optional.empty(); + } + + /** + * Parses a LLaMA-style tool call JSON object. + * Accepts {@code {"name":…,"parameters":{…}}}, {@code {"function":…,"parameters":{…}}}, + * and {@code {"name":…,"arguments":{…}}}. + */ + private static Optional parseLlamaJson(String json) { + String name = extractStringValue(json, "name"); + if (name == null) { + name = extractStringValue(json, "function"); + } + if (name == null) return Optional.empty(); + + String argsJson = extractNestedObject(json, "parameters"); + if (argsJson == null) argsJson = extractNestedObject(json, "arguments"); + if (argsJson == null) argsJson = "{}"; + + return Optional.of(new ToolCallExtract(name, argsJson)); + } + + // ── Qwen3 ───────────────────────────────────────────────────────────────── + + /** + * Extracts a tool call enclosed in {@code } tags + * as produced by Qwen3 models. 
+ */ + public static Optional parseQwen3Response(String responseText) { + int start = responseText.indexOf(""); + int end = responseText.lastIndexOf(""); + if (start == -1 || end == -1 || end <= start) return Optional.empty(); + + String json = responseText.substring(start + "".length(), end).strip(); + + String name = extractStringValue(json, "name"); + if (name == null) return Optional.empty(); + + String argsJson = extractNestedObject(json, "arguments"); + if (argsJson == null) argsJson = "{}"; + + return Optional.of(new ToolCallExtract(name, argsJson)); + } + + // ── Shared helpers ──────────────────────────────────────────────────────── + + /** Strips surrounding markdown code fences (```…```) if present. */ + public static String stripMarkdownFences(String text) { + if (!text.startsWith("```")) return text; + int firstNewline = text.indexOf('\n'); + if (firstNewline == -1) return text; + String body = text.substring(firstNewline + 1); + if (body.endsWith("```")) body = body.substring(0, body.length() - 3).stripTrailing(); + return body.strip(); + } + + /** Extracts the string value for {@code "key": ""} from a JSON object. Tolerates whitespace around {@code :}. */ + public static String extractStringValue(String json, String key) { + String marker = "\"" + key + "\""; + int markerIdx = json.indexOf(marker); + if (markerIdx == -1) return null; + int colonIdx = json.indexOf(':', markerIdx + marker.length()); + if (colonIdx == -1) return null; + int quoteStart = json.indexOf('"', colonIdx + 1); + if (quoteStart == -1) return null; + int quoteEnd = json.indexOf('"', quoteStart + 1); + if (quoteEnd == -1) return null; + return json.substring(quoteStart + 1, quoteEnd); + } + + /** + * Extracts the JSON object value for {@code "key": {…}} using brace-counting. + * Handles nested objects and tolerates whitespace around {@code :}. 
+ */ + public static String extractNestedObject(String json, String key) { + String marker = "\"" + key + "\""; + int markerIdx = json.indexOf(marker); + if (markerIdx == -1) return null; + int colonIdx = json.indexOf(':', markerIdx + marker.length()); + if (colonIdx == -1) return null; + int braceStart = json.indexOf('{', colonIdx + 1); + if (braceStart == -1) return null; + int depth = 0; + for (int i = braceStart; i < json.length(); i++) { + char c = json.charAt(i); + if (c == '{') depth++; + else if (c == '}') { + if (--depth == 0) return json.substring(braceStart, i + 1); + } + } + return null; // unbalanced + } +} diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolCallingOptions.java b/src/main/java/org/beehive/gpullama3/tools/ToolCallingOptions.java new file mode 100644 index 00000000..50ad5c9d --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tools/ToolCallingOptions.java @@ -0,0 +1,31 @@ +package org.beehive.gpullama3.tools; + +/** + * Tuning parameters for a {@link ToolCallingSession}. 
+ * + * @param maxTokens max tokens per inference call + * @param maxRoundTrips max tool → result → re-inference cycles (default 1) + * @param maxToolResultChars tool output is truncated to this length before feeding back + * @param verbose print step-by-step output to stdout + * @param useGpu use TornadoVM GPU path for inference + */ +public record ToolCallingOptions( + int maxTokens, + int maxRoundTrips, + int maxToolResultChars, + boolean verbose, + boolean useGpu) { + + public static ToolCallingOptions defaults() { + return new ToolCallingOptions(1024, 1, 2000, true, false); + } + + public static ToolCallingOptions from(org.beehive.gpullama3.Options options) { + return new ToolCallingOptions( + options.maxTokens(), + 1, + 2000, + true, + options.useTornadovm()); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolCallingResult.java b/src/main/java/org/beehive/gpullama3/tools/ToolCallingResult.java new file mode 100644 index 00000000..9d1c539f --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tools/ToolCallingResult.java @@ -0,0 +1,29 @@ +package org.beehive.gpullama3.tools; + +import org.beehive.gpullama3.model.format.ToolCallExtract; + +import java.util.List; + +/** + * The outcome of a complete tool-calling session (prompt → [tool round-trips] → answer). + * + * @param finalAnswer the model's final plain-text answer + * @param callsMade tool calls that were extracted and executed (may be empty) + * @param results corresponding tool results (same order as callsMade) + * @param reachedMaxRoundTrips true when the session stopped because maxRoundTrips was hit + */ +public record ToolCallingResult( + String finalAnswer, + List callsMade, + List results, + boolean reachedMaxRoundTrips) { + + public boolean hadToolCalls() { + return !callsMade.isEmpty(); + } + + /** Returns a result representing a plain-text (no-tool) response. 
*/ + public static ToolCallingResult plainText(String answer) { + return new ToolCallingResult(answer, List.of(), List.of(), false); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolCallingSession.java b/src/main/java/org/beehive/gpullama3/tools/ToolCallingSession.java new file mode 100644 index 00000000..b56be36c --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tools/ToolCallingSession.java @@ -0,0 +1,179 @@ +package org.beehive.gpullama3.tools; + +import org.beehive.gpullama3.inference.sampler.Sampler; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.format.ChatFormat; +import org.beehive.gpullama3.model.format.ToolCallExtract; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.function.IntConsumer; + +/** + * Framework-agnostic orchestrator for the tool-calling loop: + *
+ *   prompt → first generation → extract tool call? → execute → feed result → final answer
+ * 
+ * + * Supports Llama 3.1, Llama 3.2, and Qwen3 via the {@link ChatFormat} abstraction. + * The session is single-use; create a new instance per request. + */ +public class ToolCallingSession { + + private final Model model; + private final Sampler sampler; + private final ToolRegistry registry; + private final ToolCallingOptions options; + + public ToolCallingSession(Model model, Sampler sampler, ToolRegistry registry, ToolCallingOptions options) { + this.model = model; + this.sampler = sampler; + this.registry = registry; + this.options = options; + } + + /** Run with no custom system prompt (the tool definitions become the system message). */ + public ToolCallingResult run(String userPrompt) { + return run(null, userPrompt); + } + + /** + * Run with an optional system prompt prefix. Tool definitions are appended to it. + * + * @param systemPrompt base system prompt, or {@code null} for tools-only + * @param userPrompt the user's request + */ + public ToolCallingResult run(String systemPrompt, String userPrompt) { + ChatFormat chatFormat = model.chatFormat(); + String toolsJson = registry.toToolsJson(); + String toolSuffix = chatFormat.toolSystemPromptSuffix(toolsJson); + + String effectiveSystem = systemPrompt == null + ? 
toolSuffix.stripLeading() + : systemPrompt + toolSuffix; + + // ── Build initial prompt tokens ─────────────────────────────────────── + List promptTokens = new ArrayList<>(); + if (model.shouldAddBeginOfText()) { + promptTokens.add(chatFormat.getBeginOfText()); + } + if (model.shouldAddSystemPrompt()) { + promptTokens.addAll(chatFormat.encodeMessage( + new ChatFormat.Message(ChatFormat.Role.SYSTEM, effectiveSystem))); + } + promptTokens.addAll(chatFormat.encodeMessage( + new ChatFormat.Message(ChatFormat.Role.USER, userPrompt))); + promptTokens.addAll(chatFormat.encodeHeader( + new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); + + Set toolStopTokens = chatFormat.getToolAwareStopTokens(); + + List callsMade = new ArrayList<>(); + List toolResults = new ArrayList<>(); + + State state = model.createNewState(); + TornadoVMMasterPlan plan = options.useGpu() + ? TornadoVMMasterPlan.initializeTornadoVMPlan(state, model) + : null; + + try { + // ── Tool round-trip loop ────────────────────────────────────────────── + for (int round = 0; round < options.maxRoundTrips(); round++) { + log("\n--- First generation (round %d) ---", round + 1); + List responseTokens = generateTokens(state, plan, promptTokens, toolStopTokens); + + // strip trailing stop token + if (!responseTokens.isEmpty() && toolStopTokens.contains(responseTokens.getLast())) { + responseTokens.removeLast(); + } + String rawResponse = model.tokenizer().decode(responseTokens); + + Optional maybeCall = chatFormat.extractToolCall(rawResponse); + if (maybeCall.isEmpty()) { + log("\n--- No tool call detected; returning plain text response ---"); + return new ToolCallingResult(rawResponse, callsMade, toolResults, false); + } + + ToolCallExtract call = maybeCall.get(); + callsMade.add(call); + log("\n[Tool call] %s(%s)", call.name(), call.argumentsJson()); + + ToolResult result = registry.execute(call); + toolResults.add(result); + log("[Tool result] %s", result.isError() ? 
"ERROR: " + result.error() : truncate(result.resultText())); + + // ── Build continuation tokens ───────────────────────────────────── + String feedbackContent = result.isError() + ? "Tool '" + call.name() + "' failed: " + result.error() + : "Tool '" + call.name() + "' returned:\n" + truncate(result.resultText()); + + promptTokens = new ArrayList<>(promptTokens); + promptTokens.addAll(chatFormat.encodeToolCallAssistantTurn(call)); + promptTokens.addAll(chatFormat.encodeToolResultTurn(null, call.name(), feedbackContent)); + promptTokens.addAll(chatFormat.encodeMessage( + new ChatFormat.Message(ChatFormat.Role.USER, + "Using only the tool result above, answer the user's question in plain text. Do not repeat the raw output."))); + promptTokens.addAll(chatFormat.encodeHeader( + new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); + } + + // ── Final answer after all tool round-trips ─────────────────────────── + log("\n--- Final generation ---"); + List finalTokens = generateTokens(state, plan, promptTokens, chatFormat.getStopTokens()); + if (!finalTokens.isEmpty() && chatFormat.getStopTokens().contains(finalTokens.getLast())) { + finalTokens.removeLast(); + } + String finalAnswer = model.tokenizer().decode(finalTokens); + + boolean hitLimit = callsMade.size() >= options.maxRoundTrips() + && chatFormat.extractToolCall(finalAnswer).isPresent(); + + return new ToolCallingResult(finalAnswer, callsMade, toolResults, hitLimit); + + } finally { + if (plan != null) plan.freeTornadoExecutionPlan(); + } + } + + // ── Inference ───────────────────────────────────────────────────────────── + + private List generateTokens(State state, TornadoVMMasterPlan plan, + List prompt, Set stopTokens) { + IntConsumer tokenConsumer = options.verbose() ? 
this::printToken : null; + + if (options.useGpu()) { + return model.generateTokensGPU( + state, 0, prompt, stopTokens, options.maxTokens(), sampler, + false, tokenConsumer, plan); + } else { + return model.generateTokens( + state, 0, prompt, stopTokens, options.maxTokens(), sampler, + false, tokenConsumer); + } + } + + private void printToken(int token) { + if (model.tokenizer().shouldDisplayToken(token)) { + System.out.print(model.tokenizer().decode(List.of(token))); + System.out.flush(); + } + } + + // ── Helpers ─────────────────────────────────────────────────────────────── + + private String truncate(String text) { + if (text == null) return ""; + if (text.length() <= options.maxToolResultChars()) return text; + return text.substring(0, options.maxToolResultChars()) + "\n... (truncated)"; + } + + private void log(String fmt, Object... args) { + if (options.verbose()) { + System.out.printf((fmt) + "%n", args); + } + } +} diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolDefinition.java b/src/main/java/org/beehive/gpullama3/tools/ToolDefinition.java new file mode 100644 index 00000000..a07695c5 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tools/ToolDefinition.java @@ -0,0 +1,17 @@ +package org.beehive.gpullama3.tools; + +/** + * Framework-agnostic description of a tool available to the model. + * + * @param name unique tool name + * @param description human-readable description used in the model's system prompt + * @param parametersJson JSON Schema object for the tool's parameters, e.g. + * {@code {"type":"object","properties":{"city":{"type":"string"}},"required":["city"]}} + */ +public record ToolDefinition(String name, String description, String parametersJson) { + + /** Convenience factory for a tool with no parameters. 
*/ + public static ToolDefinition noArgs(String name, String description) { + return new ToolDefinition(name, description, "{\"type\":\"object\",\"properties\":{}}"); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolExecutor.java b/src/main/java/org/beehive/gpullama3/tools/ToolExecutor.java new file mode 100644 index 00000000..b8807486 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tools/ToolExecutor.java @@ -0,0 +1,13 @@ +package org.beehive.gpullama3.tools; + +import org.beehive.gpullama3.model.format.ToolCallExtract; + +/** + * Executes a single tool call and returns the result. + * Implementations are responsible for parsing {@link ToolCallExtract#argumentsJson()} + * and performing the actual action. + */ +@FunctionalInterface +public interface ToolExecutor { + ToolResult execute(ToolCallExtract call); +} diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolRegistry.java b/src/main/java/org/beehive/gpullama3/tools/ToolRegistry.java new file mode 100644 index 00000000..e31fa67e --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tools/ToolRegistry.java @@ -0,0 +1,94 @@ +package org.beehive.gpullama3.tools; + +import org.beehive.gpullama3.model.format.ToolCallExtract; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * Holds available tool definitions and their executors. + * Registration order is preserved for deterministic JSON output. + */ +public class ToolRegistry { + + private final Map entries = new LinkedHashMap<>(); + + private record Entry(ToolDefinition definition, ToolExecutor executor) {} + + public ToolRegistry register(ToolDefinition definition, ToolExecutor executor) { + entries.put(definition.name(), new Entry(definition, executor)); + return this; + } + + public Optional getDefinition(String name) { + Entry e = entries.get(name); + return e == null ? 
Optional.empty() : Optional.of(e.definition()); + } + + public Optional getExecutor(String name) { + Entry e = entries.get(name); + return e == null ? Optional.empty() : Optional.of(e.executor()); + } + + public List definitions() { + return entries.values().stream().map(Entry::definition).toList(); + } + + public boolean isEmpty() { + return entries.isEmpty(); + } + + /** + * Executes the named tool, returning a failure result for unknown tools or + * executor exceptions. Never throws. + */ + public ToolResult execute(ToolCallExtract call) { + Optional executor = getExecutor(call.name()); + if (executor.isEmpty()) { + return ToolResult.failure(call.name(), "Unknown tool: " + call.name()); + } + try { + return executor.get().execute(call); + } catch (Exception e) { + return ToolResult.failure(call.name(), "Tool execution failed: " + e.getMessage()); + } + } + + /** + * Serialises all registered tools to the flat JSON array expected by + * {@code LlamaChatFormat.toolSystemPromptSuffix()} and + * {@code Qwen3ChatFormat.toolSystemPromptSuffix()}. 
+ * + * Format: {@code [{"name":…,"description":…,"parameters":{…}}]} + */ + public String toToolsJson() { + List defs = definitions(); + if (defs.isEmpty()) return "[]"; + + StringBuilder sb = new StringBuilder("[\n"); + for (int i = 0; i < defs.size(); i++) { + ToolDefinition d = defs.get(i); + sb.append(" {\n"); + sb.append(" \"name\": \"").append(escapeJson(d.name())).append("\",\n"); + sb.append(" \"description\": \"").append(escapeJson(d.description())).append("\",\n"); + sb.append(" \"parameters\": ").append(d.parametersJson()).append("\n"); + sb.append(" }"); + if (i < defs.size() - 1) sb.append(","); + sb.append("\n"); + } + sb.append("]"); + return sb.toString(); + } + + private static String escapeJson(String s) { + if (s == null) return ""; + return s.replace("\\", "\\\\") + .replace("\"", "\\\"") + .replace("\n", "\\n") + .replace("\r", "\\r") + .replace("\t", "\\t"); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolResult.java b/src/main/java/org/beehive/gpullama3/tools/ToolResult.java new file mode 100644 index 00000000..6752b207 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tools/ToolResult.java @@ -0,0 +1,24 @@ +package org.beehive.gpullama3.tools; + +/** + * The result of executing a single tool call. + * + * @param toolName name of the tool that was invoked + * @param resultText the tool's output (may be JSON or plain text) + * @param error non-null when execution failed; contains the error message + */ +public record ToolResult(String toolName, String resultText, String error) { + + /** Returns true when the tool execution failed. 
*/ + public boolean isError() { + return error != null; + } + + public static ToolResult success(String toolName, String resultText) { + return new ToolResult(toolName, resultText, null); + } + + public static ToolResult failure(String toolName, String errorMessage) { + return new ToolResult(toolName, null, errorMessage); + } +} From bf184413281d1eeabfcdcc6d75424314324642f7 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Sat, 18 Apr 2026 15:05:33 +0300 Subject: [PATCH 3/5] [introduce] Add standalone ToolCallingApp for tool-calling functionality in GPULlama3 --- .../org/beehive/gpullama3/ToolCallingApp.java | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 src/main/java/org/beehive/gpullama3/ToolCallingApp.java diff --git a/src/main/java/org/beehive/gpullama3/ToolCallingApp.java b/src/main/java/org/beehive/gpullama3/ToolCallingApp.java new file mode 100644 index 00000000..7729d695 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/ToolCallingApp.java @@ -0,0 +1,102 @@ +package org.beehive.gpullama3; + +import org.beehive.gpullama3.inference.sampler.Sampler; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.format.ToolCallExtract; +import org.beehive.gpullama3.model.format.ToolCallParserUtils; +import org.beehive.gpullama3.tools.*; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; + +import static org.beehive.gpullama3.inference.sampler.Sampler.createSampler; +import static org.beehive.gpullama3.model.loader.ModelLoader.loadModel; + +/** + * Standalone tool-calling entry point for GPULlama3.java. + * + * Uses the same command-line flags as {@link LlamaApp} plus the tool-calling loop + * provided by {@link ToolCallingSession}. A {@code listDirectory} tool is registered + * as a built-in demo; extend {@link #buildRegistry()} to add more tools. + * + *
+ * <pre>
+ * java @$TORNADOVM_HOME/tornado-argfile \
+ *     --add-modules jdk.incubator.vector --enable-preview \
+ *     -cp gpu-llama3.jar org.beehive.gpullama3.ToolCallingApp \
+ *     --model /path/to/model.gguf \
+ *     --prompt "Show me what is inside /tmp" \
+ *     --use-tornadovm true
+ * </pre>
+ */ +public class ToolCallingApp { + + public static void main(String[] args) throws IOException { + Options options = Options.parseOptions(args); + Model model = loadModel(options); + Sampler sampler = createSampler(model, options); + + ToolRegistry registry = buildRegistry(); + ToolCallingOptions tcOptions = ToolCallingOptions.from(options); + ToolCallingSession session = new ToolCallingSession(model, sampler, registry, tcOptions); + + ToolCallingResult result = session.run(options.systemPrompt(), options.prompt()); + + if (!tcOptions.verbose()) { + // verbose=false means ToolCallingSession did not stream tokens — print the answer now + System.out.println(result.finalAnswer()); + } + } + + // ── Tool registry ───────────────────────────────────────────────────────── + + private static ToolRegistry buildRegistry() { + ToolRegistry registry = new ToolRegistry(); + registry.register(listDirectoryDefinition(), ToolCallingApp::listDirectory); + return registry; + } + + private static ToolDefinition listDirectoryDefinition() { + return new ToolDefinition( + "listDirectory", + "Lists the contents of a directory on the local filesystem. " + + "Returns file names and metadata. " + + "Use this when the user asks what is inside a directory or folder.", + """ + {"type":"object","properties":{"path":{"type":"string",\ + "description":"Absolute path of the directory to list, e.g. /tmp or /home/orion"}},\ + "required":["path"]}"""); + } + + private static ToolResult listDirectory(ToolCallExtract call) { + String path = ToolCallParserUtils.extractStringValue(call.argumentsJson(), "path"); + if (path == null || path.isBlank()) { + return ToolResult.failure("listDirectory", + "Could not extract 'path' from arguments: " + call.argumentsJson()); + } + Path dir = Path.of(path); + if (!dir.isAbsolute()) + return ToolResult.failure("listDirectory", "Path must be absolute. 
Got: " + path); + if (!Files.exists(dir)) + return ToolResult.failure("listDirectory", "Path does not exist: " + path); + if (!Files.isDirectory(dir)) + return ToolResult.failure("listDirectory", "Not a directory: " + path); + try { + StringBuilder sb = new StringBuilder("Contents of ").append(path).append(":\n"); + try (var stream = Files.list(dir)) { + stream.sorted().forEach(entry -> { + boolean isDir = Files.isDirectory(entry); + long size = 0; + try { if (!isDir) size = Files.size(entry); } catch (IOException ignored) {} + sb.append(isDir ? "dir: " : "file: ") + .append(entry.getFileName()) + .append(isDir ? "" : ", " + size + " bytes") + .append("\n"); + }); + } + return ToolResult.success("listDirectory", sb.toString()); + } catch (IOException e) { + return ToolResult.failure("listDirectory", e.getMessage()); + } + } +} From a35e87ba4dd9d0fecb5778c0e19891b57aa860ab Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Tue, 21 Apr 2026 12:28:05 +0300 Subject: [PATCH 4/5] [tools][wip] Add support for Llama 3.2 tool-call injection: batch tool calls, user message integration, and enhanced response parsing. --- .../gpullama3/model/format/ChatFormat.java | 58 ++++++++++++- .../model/format/LlamaChatFormat.java | 65 ++++++++++---- .../model/format/Qwen3ChatFormat.java | 5 ++ .../model/format/ToolCallParserUtils.java | 84 +++++++++++++++++-- .../gpullama3/tools/ToolCallingSession.java | 73 +++++++++++----- 5 files changed, 236 insertions(+), 49 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java index 486cdd54..c1ceb267 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java @@ -39,16 +39,54 @@ default ChatTokens chatTokens() { /** * Returns plain text to append to the system message content when tools are available. 
- * The returned string is concatenated to the system message before encoding, so the - * normal {@link #encodeMessage} path handles tokenization. + * Used by formats that inject tool definitions into the system message. * - * @param toolsJson JSON array of tool definitions, e.g. - * {@code [{"type":"function","function":{...}}]} + *

Formats that inject tools into the user message instead should override + * {@link #injectsToolsInUserMessage()}, {@link #toolSystemMessagePrefix()}, and + * {@link #toolFirstUserMessagePrefix(String)} rather than this method. + * + * @param toolsJson JSON array of tool definitions */ default String toolSystemPromptSuffix(String toolsJson) { throw new UnsupportedOperationException("Tool calling not supported for: " + getClass().getSimpleName()); } + /** + * Returns {@code true} when this format injects tool definitions into the + * first user message instead of the system message. + * + *

+     * <p>When this returns {@code true}, callers should:
+     * <ol>
+     *   <li>Prepend {@link #toolSystemMessagePrefix()} to the system message content.</li>
+     *   <li>Prepend {@link #toolFirstUserMessagePrefix(String)} to the first user message.</li>
+     * </ol>
+     *
+ * When {@code false} (default), callers should append {@link #toolSystemPromptSuffix} to + * the system message as before. + */ + default boolean injectsToolsInUserMessage() { + return false; + } + + /** + * Returns text to prepend to the system message content when tools are active + * and {@link #injectsToolsInUserMessage()} is {@code true}. + * Default: empty string (no prefix). + */ + default String toolSystemMessagePrefix() { + return ""; + } + + /** + * Returns the preamble to prepend to the first user message when + * {@link #injectsToolsInUserMessage()} is {@code true}. + * The preamble should include the tool definitions and usage instructions. + * + * @param toolsJson JSON array of tool definitions + */ + default String toolFirstUserMessagePrefix(String toolsJson) { + return ""; + } + /** * Re-encodes a prior assistant tool-call turn into the conversation token stream. * Used when replaying multi-turn history that contains a previous tool call. @@ -80,6 +118,18 @@ default Optional extractToolCall(String responseText) { return Optional.empty(); } + /** + * Extracts ALL tool calls from a response. Models may emit multiple + * {@code } blocks in a single turn (batch tool calls). + * The default delegates to {@link #extractToolCall} for formats that + * do not support batch calls. + * + * @param responseText the fully decoded response from the model + */ + default List extractAllToolCalls(String responseText) { + return extractToolCall(responseText).map(List::of).orElse(List.of()); + } + /** * Stop tokens to use when tool calling is enabled. 
* Some models (LLaMA 3.1+) use a different end-of-turn token ({@code <|eom_id|>}) diff --git a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java index 289b3004..6f2e6ffe 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java @@ -78,32 +78,58 @@ public List encodeDialogPrompt(boolean appendAssistantTurn, Listfirst user message + * (the GGUF-embedded chat template has {@code tools_in_user_message = true} by default). + * The system message receives only an environment prefix; the tools and usage instructions + * go in the user turn. */ @Override - public String toolSystemPromptSuffix(String toolsJson) { - return "\n\n# Tools\n\n" - + "You may call one or more functions to assist with the user query.\n\n" - + "You are provided with function signatures within XML tags:\n\n" - + "\n" + toolsJson + "\n\n\n" - + "IMPORTANT: the \"name\" field in your tool call MUST be exactly one of the function names " - + "listed inside above — not a path, not a word from the user's message.\n\n" - + "For each function call, return a json object with function name and arguments " - + "within XML tags:\n\n" - + "\n" - + "{\"name\": , \"arguments\": }\n" - + ""; + public boolean injectsToolsInUserMessage() { + return true; } /** - * Re-encodes a prior assistant tool-call turn for multi-turn history. - * Format: {@code <|start_header_id|>assistant<|end_header_id|>\n<|python_tag|>JSON<|eom_id|>} + * System-message prefix that signals tool availability to Llama 3.2. + * Matches the template's {@code "Environment: ipython\n"} line. 
+ */ + @Override + public String toolSystemMessagePrefix() { + return "Environment: ipython\n\n"; + } + + /** + * Prepends tool definitions and usage instructions to the first user message, + * matching the Llama 3.2 GGUF chat template ({@code tools_in_user_message = true}). + * + *

+     * <p>Format mirrors:
+     * <pre>
+     * Given the following functions, please respond with a JSON for a function call
+     * with its proper arguments that best answers the given prompt.
+     *
+     * Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.
+     * Do not use variables.
+     *
+     * {toolsJson}
+     * </pre>
+ */ + @Override + public String toolFirstUserMessagePrefix(String toolsJson) { + return "Given the following functions, please respond with a JSON for a function call " + + "with its proper arguments that best answers the given prompt.\n\n" + + "Respond in the format {\"name\": function name, \"parameters\": dictionary of " + + "argument name and its value}. Do not use variables.\n\n" + + toolsJson + "\n\n"; + } + + /** + * Re-encodes a prior assistant tool-call turn for multi-turn history using the + * Llama 3.2 native JSON format: {@code {"name":"…","parameters":{…}}<|eot_id|>}. */ @Override public List encodeToolCallAssistantTurn(ToolCallExtract toolCall) { List tokens = new ArrayList<>(encodeHeader(new Message(Role.ASSISTANT, ""))); - String json = "\n{\"name\":\"" + toolCall.name() + "\",\"arguments\":" + toolCall.argumentsJson() + "}\n"; + String json = "{\"name\": \"" + toolCall.name() + "\", \"parameters\": " + toolCall.argumentsJson() + "}"; tokens.addAll(tokenizer.encodeAsList(json)); tokens.add(endOfTurn); return tokens; @@ -136,6 +162,11 @@ public Optional extractToolCall(String responseText) { return ToolCallParserUtils.parseLlamaResponse(responseText); } + @Override + public List extractAllToolCalls(String responseText) { + return ToolCallParserUtils.parseAllToolCalls(responseText); + } + /** * Adds {@code <|eom_id|>} to the stop tokens when tools are enabled. * LLaMA 3.1 ends tool-call turns with {@code <|eom_id|>} instead of {@code <|eot_id|>}. 
diff --git a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java index de671a59..35930ff1 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java @@ -193,4 +193,9 @@ public List encodeToolResultTurn(String toolCallId, String toolName, St public Optional extractToolCall(String responseText) { return ToolCallParserUtils.parseQwen3Response(responseText); } + + @Override + public List extractAllToolCalls(String responseText) { + return ToolCallParserUtils.parseAllToolCalls(responseText); + } } diff --git a/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java b/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java index 73a327a4..146148fc 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java +++ b/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java @@ -1,5 +1,7 @@ package org.beehive.gpullama3.model.format; +import java.util.ArrayList; +import java.util.List; import java.util.Optional; /** @@ -41,6 +43,11 @@ public static Optional parseLlamaResponse(String responseText) String json = responseText.substring(tcStart + "".length(), tcEnd).strip(); return parseLlamaJson(json); } + // 2b. Unclosed — model stopped (eot_id / eom_id) before writing the closing tag + if (tcStart != -1 && tcEnd == -1) { + String json = responseText.substring(tcStart + "".length()).strip(); + return parseLlamaJson(json); + } // 3. 
Fallback: raw JSON, possibly inside markdown code fences String stripped = stripMarkdownFences(responseText.strip()); @@ -72,6 +79,54 @@ private static Optional parseLlamaJson(String json) { // ── Qwen3 ───────────────────────────────────────────────────────────────── + /** + * Extracts ALL tool calls from a response that may contain multiple + * {@code } blocks (Llama 3.2 and Qwen3 batch calls). + * + * Falls back to the raw-JSON single-call path if no tags are found. + * Returns an empty list when the response contains no tool calls. + */ + public static List parseAllToolCalls(String responseText) { + List calls = new java.util.ArrayList<>(); + + // <|python_tag|> (Llama 3.1) — single call by definition + int pythonIdx = responseText.indexOf("<|python_tag|>"); + if (pythonIdx != -1) { + parseLlamaJson(responseText.substring(pythonIdx + "<|python_tag|>".length()).strip()) + .ifPresent(calls::add); + return calls; + } + + // Scan for all blocks + int searchFrom = 0; + while (true) { + int start = responseText.indexOf("", searchFrom); + if (start == -1) break; + int end = responseText.indexOf("", start); + String json; + if (end != -1) { + json = responseText.substring(start + "".length(), end).strip(); + searchFrom = end + "".length(); + } else { + // Unclosed tag — model stopped before writing the closing tag + json = responseText.substring(start + "".length()).strip(); + searchFrom = responseText.length(); + } + parseLlamaJson(json).ifPresent(calls::add); + if (end == -1) break; + } + + // Raw JSON fallback (no tags at all) + if (calls.isEmpty()) { + String stripped = stripMarkdownFences(responseText.strip()); + if (stripped.startsWith("{")) { + parseLlamaJson(stripped).ifPresent(calls::add); + } + } + + return calls; + } + /** * Extracts a tool call enclosed in {@code } tags * as produced by Qwen3 models. 
@@ -79,9 +134,11 @@ private static Optional parseLlamaJson(String json) { public static Optional parseQwen3Response(String responseText) { int start = responseText.indexOf(""); int end = responseText.lastIndexOf(""); - if (start == -1 || end == -1 || end <= start) return Optional.empty(); + if (start == -1) return Optional.empty(); - String json = responseText.substring(start + "".length(), end).strip(); + String json = (end != -1 && end > start) + ? responseText.substring(start + "".length(), end).strip() + : responseText.substring(start + "".length()).strip(); String name = extractStringValue(json, "name"); if (name == null) return Optional.empty(); @@ -104,7 +161,11 @@ public static String stripMarkdownFences(String text) { return body.strip(); } - /** Extracts the string value for {@code "key": ""} from a JSON object. Tolerates whitespace around {@code :}. */ + /** + * Extracts the string value for {@code "key": ""} from a JSON object. + * Tolerates whitespace around {@code :} and correctly skips escaped quotes ({@code \"}) + * inside the value, so multi-line code strings with embedded {@code "} are returned intact. + */ public static String extractStringValue(String json, String key) { String marker = "\"" + key + "\""; int markerIdx = json.indexOf(marker); @@ -113,9 +174,20 @@ public static String extractStringValue(String json, String key) { if (colonIdx == -1) return null; int quoteStart = json.indexOf('"', colonIdx + 1); if (quoteStart == -1) return null; - int quoteEnd = json.indexOf('"', quoteStart + 1); - if (quoteEnd == -1) return null; - return json.substring(quoteStart + 1, quoteEnd); + // Scan for the closing quote, honouring backslash escapes + int i = quoteStart + 1; + while (i < json.length()) { + char c = json.charAt(i); + if (c == '\\') { + i += 2; // skip escape sequence (e.g. 
\", \\, \n) + } else if (c == '"') { + break; + } else { + i++; + } + } + if (i >= json.length()) return null; + return json.substring(quoteStart + 1, i); } /** diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolCallingSession.java b/src/main/java/org/beehive/gpullama3/tools/ToolCallingSession.java index b56be36c..565294db 100644 --- a/src/main/java/org/beehive/gpullama3/tools/ToolCallingSession.java +++ b/src/main/java/org/beehive/gpullama3/tools/ToolCallingSession.java @@ -50,23 +50,45 @@ public ToolCallingResult run(String userPrompt) { public ToolCallingResult run(String systemPrompt, String userPrompt) { ChatFormat chatFormat = model.chatFormat(); String toolsJson = registry.toToolsJson(); - String toolSuffix = chatFormat.toolSystemPromptSuffix(toolsJson); - String effectiveSystem = systemPrompt == null - ? toolSuffix.stripLeading() - : systemPrompt + toolSuffix; + // Build effective system and user messages according to the model's tool injection strategy. + // Formats like Llama 3.2 inject tool definitions into the first user message + // (injectsToolsInUserMessage() == true); others (Qwen3, Mistral) append to the system message. + String effectiveSystem; + String effectiveUser; + if (chatFormat.injectsToolsInUserMessage()) { + String prefix = chatFormat.toolSystemMessagePrefix(); + effectiveSystem = systemPrompt == null + ? prefix.stripLeading() + : prefix + systemPrompt; + effectiveUser = chatFormat.toolFirstUserMessagePrefix(toolsJson) + userPrompt; + } else { + String toolSuffix = chatFormat.toolSystemPromptSuffix(toolsJson); + effectiveSystem = systemPrompt == null + ? 
toolSuffix.stripLeading() + : systemPrompt + toolSuffix; + effectiveUser = userPrompt; + } + + log("\n[DEBUG] model: %s", model.getClass().getSimpleName()); + log("[DEBUG] chatFormat: %s", chatFormat.getClass().getSimpleName()); + log("[DEBUG] shouldAddSystemPrompt: %s", model.shouldAddSystemPrompt()); + log("[DEBUG] toolAwareStopTokens: %s", chatFormat.getToolAwareStopTokens()); + log("[DEBUG] ── effective system prompt ──────────────────────────"); + log("%s", effectiveSystem); + log("[DEBUG] ── end system prompt ─────────────────────────────────"); // ── Build initial prompt tokens ─────────────────────────────────────── List promptTokens = new ArrayList<>(); if (model.shouldAddBeginOfText()) { promptTokens.add(chatFormat.getBeginOfText()); } - if (model.shouldAddSystemPrompt()) { + if (model.shouldAddSystemPrompt() && !effectiveSystem.isBlank()) { promptTokens.addAll(chatFormat.encodeMessage( new ChatFormat.Message(ChatFormat.Role.SYSTEM, effectiveSystem))); } promptTokens.addAll(chatFormat.encodeMessage( - new ChatFormat.Message(ChatFormat.Role.USER, userPrompt))); + new ChatFormat.Message(ChatFormat.Role.USER, effectiveUser))); promptTokens.addAll(chatFormat.encodeHeader( new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); @@ -92,31 +114,38 @@ public ToolCallingResult run(String systemPrompt, String userPrompt) { } String rawResponse = model.tokenizer().decode(responseTokens); - Optional maybeCall = chatFormat.extractToolCall(rawResponse); - if (maybeCall.isEmpty()) { + log("[DEBUG] raw response (round %d): >>>%s<<<", round + 1, rawResponse); + log("[DEBUG] response tokens: %d (stop token stripped: %s)", + responseTokens.size(), + !responseTokens.isEmpty() && toolStopTokens.contains(responseTokens.getLast()) ? 
"yes" : "no — stop token was removed before decoding"); + + List batchCalls = chatFormat.extractAllToolCalls(rawResponse); + log("[DEBUG] extractAllToolCalls found: %d call(s)", batchCalls.size()); + if (batchCalls.isEmpty()) { log("\n--- No tool call detected; returning plain text response ---"); return new ToolCallingResult(rawResponse, callsMade, toolResults, false); } - ToolCallExtract call = maybeCall.get(); - callsMade.add(call); - log("\n[Tool call] %s(%s)", call.name(), call.argumentsJson()); + // ── Execute all tool calls found in this response ───────────────── + promptTokens = new ArrayList<>(promptTokens); + for (ToolCallExtract call : batchCalls) { + callsMade.add(call); + log("\n[Tool call] %s(%s)", call.name(), call.argumentsJson()); - ToolResult result = registry.execute(call); - toolResults.add(result); - log("[Tool result] %s", result.isError() ? "ERROR: " + result.error() : truncate(result.resultText())); + ToolResult result = registry.execute(call); + toolResults.add(result); + log("[Tool result] %s", result.isError() ? "ERROR: " + result.error() : truncate(result.resultText())); - // ── Build continuation tokens ───────────────────────────────────── - String feedbackContent = result.isError() - ? "Tool '" + call.name() + "' failed: " + result.error() - : "Tool '" + call.name() + "' returned:\n" + truncate(result.resultText()); + String feedbackContent = result.isError() + ? 
"Tool '" + call.name() + "' failed: " + result.error() + : "Tool '" + call.name() + "' returned:\n" + truncate(result.resultText()); - promptTokens = new ArrayList<>(promptTokens); - promptTokens.addAll(chatFormat.encodeToolCallAssistantTurn(call)); - promptTokens.addAll(chatFormat.encodeToolResultTurn(null, call.name(), feedbackContent)); + promptTokens.addAll(chatFormat.encodeToolCallAssistantTurn(call)); + promptTokens.addAll(chatFormat.encodeToolResultTurn(null, call.name(), feedbackContent)); + } promptTokens.addAll(chatFormat.encodeMessage( new ChatFormat.Message(ChatFormat.Role.USER, - "Using only the tool result above, answer the user's question in plain text. Do not repeat the raw output."))); + "Using the tool results above, call the next required tool or answer the user's original question if all steps are complete."))); promptTokens.addAll(chatFormat.encodeHeader( new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); } From 5928ead65c6c2bd5e3f4e577238622776f2bda01 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Tue, 28 Apr 2026 14:11:28 +0300 Subject: [PATCH 5/5] [tools][wip] Add fix for tool-calling --- .../beehive/gpullama3/model/format/LlamaChatFormat.java | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java index 6f2e6ffe..421a5513 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java @@ -129,9 +129,15 @@ public String toolFirstUserMessagePrefix(String toolsJson) { @Override public List encodeToolCallAssistantTurn(ToolCallExtract toolCall) { List tokens = new ArrayList<>(encodeHeader(new Message(Role.ASSISTANT, ""))); + // Preserve the <|python_tag|> prefix used by LLaMA 3.1/3.2 for tool calls so that + // replayed history looks identical to what the model originally generated. 
+ if (pythonTag != -1) { + tokens.add(pythonTag); + } String json = "{\"name\": \"" + toolCall.name() + "\", \"parameters\": " + toolCall.argumentsJson() + "}"; tokens.addAll(tokenizer.encodeAsList(json)); - tokens.add(endOfTurn); + // LLaMA 3.1 ends tool-call turns with <|eom_id|>; fall back to <|eot_id|> for 3.2. + tokens.add(endOfMessage != -1 ? endOfMessage : endOfTurn); return tokens; }