Introduce tool calling support#116
Conversation
# Conflicts: # LlamaTornadoCli.java
… extract ToolCallingDemo
…l calls, user message integration, and enhanced response parsing.
There was a problem hiding this comment.
Pull request overview
Introduces a tool/function-calling subsystem to GPULlama3.java. A new tools package adds a registry, executor interface, options, session orchestrator, and a standalone ToolCallingApp entry point. The chat format abstraction is extended with tool-aware hooks, implemented for Llama 3.1/3.2 and Qwen3, with shared parsing utilities for <tool_call>, <|python_tag|>, and raw-JSON variants. The llamaTornado/llama-tornado runners gain a --tool-demo flag to launch the new entry point.
Changes:
- New
org.beehive.gpullama3.toolspackage:ToolRegistry,ToolDefinition,ToolExecutor,ToolResult,ToolCallingOptions,ToolCallingSession,ToolCallingResult, plus aToolCallingAppCLI with a demolistDirectorytool. ChatFormatgains tool-calling defaults;LlamaChatFormatandQwen3ChatFormatimplement system/user prompt injection, tool-call/result re-encoding, response extraction, and tool-aware stop tokens, backed by a newToolCallParserUtils.- Both runner scripts add
--tool-demo; the Java runner swaps the main class, while the Python runner adds a--main-classoverride.
Reviewed changes
Copilot reviewed 15 out of 15 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
| src/main/java/org/beehive/gpullama3/tools/ToolResult.java | Record with success/failure factories for tool execution outcomes. |
| src/main/java/org/beehive/gpullama3/tools/ToolRegistry.java | Ordered registry, execution dispatch, JSON serialization of tool defs. |
| src/main/java/org/beehive/gpullama3/tools/ToolExecutor.java | Functional interface for tool implementations. |
| src/main/java/org/beehive/gpullama3/tools/ToolDefinition.java | Record describing a tool (name, description, parameter schema JSON). |
| src/main/java/org/beehive/gpullama3/tools/ToolCallingSession.java | Orchestrates the prompt→tool→re-prompt→answer loop across formats. |
| src/main/java/org/beehive/gpullama3/tools/ToolCallingResult.java | Aggregated session outcome (final answer + call/result history). |
| src/main/java/org/beehive/gpullama3/tools/ToolCallingOptions.java | Tuning record (token limit, round-trips, truncation, verbose, GPU). |
| src/main/java/org/beehive/gpullama3/ToolCallingApp.java | CLI entry point wiring Options to a ToolCallingSession plus demo tool. |
| src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java | Shared string parsers for Llama/Qwen3 tool-call payloads. |
| src/main/java/org/beehive/gpullama3/model/format/ToolCallExtract.java | Record holding parsed tool name + raw arguments JSON. |
| src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java | Adds default tool-calling hooks to the chat-format SPI. |
| src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java | Llama 3.1/3.2 implementation of tool prompts and (de)serialization. |
| src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java | Qwen3 implementation of tool prompts, <tool_call> encoding, parsing. |
| llamaTornado | Adds --tool-demo flag that swaps the launched main class. |
| llama-tornado | Adds --tool-demo and a configurable --main-class argument. |
Comments suppressed due to low confidence (2)
llama-tornado:537
- The new default for
--main-classisorg.beehive.gpullama3.cli.LlamaTornadoCli, but that class does not exist in the codebase — the only main class isorg.beehive.gpullama3.LlamaApp. Changing the default silently breaks every invocation ofllama-tornadothat does not explicitly pass--main-class. The default should remainorg.beehive.gpullama3.LlamaApp.
advanced_group.add_argument(
"--main-class",
default="org.beehive.gpullama3.cli.LlamaTornadoCli",
help="Java main class to run (default: LlamaTornadoCli)",
)
src/main/java/org/beehive/gpullama3/tools/ToolCallingSession.java:164
- The feedback content sent back to the model uses the truncated tool result (
truncate(result.resultText())), but the truncation is silently dropped if the result is short and the model receives a misleading marker ("\n... (truncated)") only when the limit is exceeded. More importantly,callsMade.size() >= options.maxRoundTrips()is compared against the total number of calls across all rounds, not the number of rounds elapsed — if a single round emits multiple batch tool calls,callsMade.size()will exceedmaxRoundTripseven when only one round occurred, causinghitLimitto be reported incorrectly. Consider tracking round count separately from total calls for thereachedMaxRoundTripsflag.
boolean hitLimit = callsMade.size() >= options.maxRoundTrips()
&& chatFormat.extractToolCall(finalAnswer).isPresent();
return new ToolCallingResult(finalAnswer, callsMade, toolResults, hitLimit);
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
|
||
| cmd.addAll(List.of("-cp", findLlamaJar(llamaRoot), "org.beehive.gpullama3.LlamaApp")); | ||
| var mainClass = cfg.toolDemo() | ||
| ? "org.beehive.gpullama3.ToolCallingDemo" |
| llama_args.append("--tool-demo") | ||
|
|
| public static ToolCallingOptions defaults() { | ||
| return new ToolCallingOptions(1024, 1, 2000, true, false); | ||
| } | ||
|
|
||
| public static ToolCallingOptions from(org.beehive.gpullama3.Options options) { | ||
| return new ToolCallingOptions( | ||
| options.maxTokens(), | ||
| 1, | ||
| 2000, | ||
| true, | ||
| options.useTornadovm()); | ||
| } |
| private List<Integer> generateTokens(State state, TornadoVMMasterPlan plan, | ||
| List<Integer> prompt, Set<Integer> stopTokens) { | ||
| IntConsumer tokenConsumer = options.verbose() ? this::printToken : null; | ||
|
|
||
| if (options.useGpu()) { | ||
| return model.generateTokensGPU( | ||
| state, 0, prompt, stopTokens, options.maxTokens(), sampler, | ||
| false, tokenConsumer, plan); | ||
| } else { | ||
| return model.generateTokens( | ||
| state, 0, prompt, stopTokens, options.maxTokens(), sampler, | ||
| false, tokenConsumer); | ||
| } | ||
| } |
| log("[DEBUG] response tokens: %d (stop token stripped: %s)", | ||
| responseTokens.size(), | ||
| !responseTokens.isEmpty() && toolStopTokens.contains(responseTokens.getLast()) ? "yes" : "no — stop token was removed before decoding"); |
| public String toToolsJson() { | ||
| List<ToolDefinition> defs = definitions(); | ||
| if (defs.isEmpty()) return "[]"; | ||
|
|
||
| StringBuilder sb = new StringBuilder("[\n"); | ||
| for (int i = 0; i < defs.size(); i++) { | ||
| ToolDefinition d = defs.get(i); | ||
| sb.append(" {\n"); | ||
| sb.append(" \"name\": \"").append(escapeJson(d.name())).append("\",\n"); | ||
| sb.append(" \"description\": \"").append(escapeJson(d.description())).append("\",\n"); | ||
| sb.append(" \"parameters\": ").append(d.parametersJson()).append("\n"); | ||
| sb.append(" }"); | ||
| if (i < defs.size() - 1) sb.append(","); | ||
| sb.append("\n"); | ||
| } | ||
| sb.append("]"); | ||
| return sb.toString(); | ||
| } |
| /** | ||
| * Extracts the string value for {@code "key": "<value>"} from a JSON object. | ||
| * Tolerates whitespace around {@code :} and correctly skips escaped quotes ({@code \"}) | ||
| * inside the value, so multi-line code strings with embedded {@code "} are returned intact. | ||
| */ | ||
| public static String extractStringValue(String json, String key) { | ||
| String marker = "\"" + key + "\""; | ||
| int markerIdx = json.indexOf(marker); | ||
| if (markerIdx == -1) return null; | ||
| int colonIdx = json.indexOf(':', markerIdx + marker.length()); | ||
| if (colonIdx == -1) return null; | ||
| int quoteStart = json.indexOf('"', colonIdx + 1); | ||
| if (quoteStart == -1) return null; | ||
| // Scan for the closing quote, honouring backslash escapes | ||
| int i = quoteStart + 1; | ||
| while (i < json.length()) { | ||
| char c = json.charAt(i); | ||
| if (c == '\\') { | ||
| i += 2; // skip escape sequence (e.g. \", \\, \n) | ||
| } else if (c == '"') { | ||
| break; | ||
| } else { | ||
| i++; | ||
| } | ||
| } | ||
| if (i >= json.length()) return null; | ||
| return json.substring(quoteStart + 1, i); | ||
| } | ||
|
|
||
| /** | ||
| * Extracts the JSON object value for {@code "key": {…}} using brace-counting. | ||
| * Handles nested objects and tolerates whitespace around {@code :}. | ||
| */ | ||
| public static String extractNestedObject(String json, String key) { | ||
| String marker = "\"" + key + "\""; | ||
| int markerIdx = json.indexOf(marker); | ||
| if (markerIdx == -1) return null; | ||
| int colonIdx = json.indexOf(':', markerIdx + marker.length()); | ||
| if (colonIdx == -1) return null; | ||
| int braceStart = json.indexOf('{', colonIdx + 1); | ||
| if (braceStart == -1) return null; | ||
| int depth = 0; | ||
| for (int i = braceStart; i < json.length(); i++) { | ||
| char c = json.charAt(i); | ||
| if (c == '{') depth++; | ||
| else if (c == '}') { | ||
| if (--depth == 0) return json.substring(braceStart, i + 1); | ||
| } | ||
| } | ||
| return null; // unbalanced | ||
| } |
| @Override | ||
| public List<Integer> encodeToolResultTurn(String toolCallId, String toolName, String result) { | ||
| List<Integer> tokens = new ArrayList<>(); | ||
| tokens.add(imStart); | ||
| tokens.addAll(tokenizer.encodeOrdinaryAsList("tool\n")); | ||
| tokens.addAll(tokenizer.encodeOrdinaryAsList(result)); | ||
| if (imEnd != -1) { | ||
| tokens.add(imEnd); | ||
| } | ||
| return tokens; | ||
| } |
No description provided.