Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions java/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,39 @@ Or run it directly from the repository:
jbang https://github.com/github/copilot-sdk/blob/main/java/jbang-example.java
```

## Annotation-based tools and `ToolInvocation` context

When you define tools with `@CopilotTool`, parameters of type `ToolInvocation` are injected as runtime context and are not exposed in the tool schema.
`ToolInvocation` can appear before, between, or after schema-visible parameters.

```java
import com.github.copilot.rpc.ToolInvocation;
import com.github.copilot.tool.CopilotTool;
import com.github.copilot.tool.Param;

class ProgressTools {
@CopilotTool("Reports the current phase and session")
public String reportProgress(
@Param("Current phase") String phase,
ToolInvocation invocation) {
return "phase=" + phase + ", sessionId=" + invocation.getSessionId();
}
}
```

Position examples:

```java
@CopilotTool("Invocation first")
public String report(ToolInvocation invocation, @Param("Phase") String phase) { ... }

@CopilotTool("Invocation only")
public String onlyContext(ToolInvocation invocation) { ... }

@CopilotTool("Invocation middle")
public String report(@Param("Phase") String phase, ToolInvocation invocation, @Param("Limit") int limit) { ... }
```

## Memory

Sessions can opt into persistent memory, allowing the agent to read and write memory across turns. Memory is configured per session and applies to both `createSession` and `resumeSession`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
* When the assistant invokes a tool, this object contains the context including
* the session ID, tool call ID, tool name, and arguments parsed from the
* assistant's request.
* <p>
* In annotation-based tools, methods annotated with
* {@link com.github.copilot.tool.CopilotTool} may declare a
* {@code ToolInvocation} parameter in any position (before, between, or after
* schema-visible parameters). It is always injected as runtime context and is
* never included in the tool's JSON schema.
*
* @see ToolHandler
* @see ToolDefinition
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
@CopilotExperimental
public class CopilotToolProcessor extends AbstractProcessor {

private static final String TOOL_INVOCATION_TYPE = "com.github.copilot.rpc.ToolInvocation";

private final SchemaGenerator schemaGenerator = new SchemaGenerator();

@Override
Expand All @@ -67,7 +69,17 @@ public boolean process(Set<? extends TypeElement> annotations, RoundEnvironment
}

// Validate @Param conflicts
int toolInvocationParamCount = 0;
for (VariableElement param : method.getParameters()) {
if (isToolInvocationType(param.asType())) {
toolInvocationParamCount++;
if (param.getAnnotation(Param.class) != null) {
processingEnv.getMessager().printMessage(Diagnostic.Kind.ERROR,
"@Param is not supported on ToolInvocation parameters because ToolInvocation is injected runtime context and not part of the tool schema",
param);
}
continue;
}
Param paramAnnotation = param.getAnnotation(Param.class);
if (paramAnnotation != null && paramAnnotation.required()
&& !paramAnnotation.defaultValue().isEmpty()) {
Expand All @@ -88,10 +100,16 @@ public boolean process(Set<? extends TypeElement> annotations, RoundEnvironment
param);
}
}
if (toolInvocationParamCount > 1) {
processingEnv.getMessager().printMessage(Diagnostic.Kind.ERROR,
"@CopilotTool methods may declare at most one ToolInvocation parameter; ToolInvocation is injected runtime context and not part of the tool schema",
method);
}

// Validate single-record wrapper parameter metadata
if (method.getParameters().size() == 1) {
VariableElement singleParam = method.getParameters().get(0);
List<? extends VariableElement> schemaParameters = getSchemaParameters(method.getParameters());
if (schemaParameters.size() == 1) {
VariableElement singleParam = schemaParameters.get(0);
if (isRecord(singleParam.asType())) {
Param paramAnnotation = singleParam.getAnnotation(Param.class);
if (paramAnnotation != null) {
Expand Down Expand Up @@ -262,18 +280,20 @@ private void writeToolDefinition(PrintWriter out, ExecutableElement method) {
}

private String generateSchemaWithParamMetadata(List<? extends VariableElement> parameters) {
if (parameters.isEmpty()) {
List<? extends VariableElement> schemaParameters = getSchemaParameters(parameters);

if (schemaParameters.isEmpty()) {
return "Map.of(\"type\", \"object\", \"properties\", Map.of(), \"required\", List.of())";
}
if (parameters.size() == 1 && isRecord(parameters.get(0).asType())) {
return schemaGenerator.generateSchemaSource(parameters.get(0).asType(), processingEnv.getTypeUtils(),
if (schemaParameters.size() == 1 && isRecord(schemaParameters.get(0).asType())) {
return schemaGenerator.generateSchemaSource(schemaParameters.get(0).asType(), processingEnv.getTypeUtils(),
processingEnv.getElementUtils());
}

List<String> propertyEntries = new ArrayList<>();
List<String> requiredNames = new ArrayList<>();

for (VariableElement param : parameters) {
for (VariableElement param : schemaParameters) {
String paramName = getParamName(param);
TypeMirror paramType = param.asType();
Param paramAnnotation = param.getAnnotation(Param.class);
Expand Down Expand Up @@ -304,6 +324,20 @@ private String generateSchemaWithParamMetadata(List<? extends VariableElement> p
return "Map.of(\"type\", \"object\", \"properties\", " + properties + ", \"required\", " + required + ")";
}

private List<? extends VariableElement> getSchemaParameters(List<? extends VariableElement> parameters) {
List<VariableElement> filtered = new ArrayList<>();
for (VariableElement param : parameters) {
if (!isToolInvocationType(param.asType())) {
filtered.add(param);
}
}
return filtered;
}

private boolean isToolInvocationType(TypeMirror type) {
return TOOL_INVOCATION_TYPE.equals(processingEnv.getTypeUtils().erasure(type).toString());
}

private String buildPropertySchema(String typeSchema, Param paramAnnotation, TypeMirror paramType) {
if (paramAnnotation == null) {
return typeSchema;
Expand All @@ -328,20 +362,21 @@ private String buildPropertySchema(String typeSchema, Param paramAnnotation, Typ

private String generateLambdaBody(ExecutableElement method) {
List<? extends VariableElement> params = method.getParameters();
List<? extends VariableElement> schemaParameters = getSchemaParameters(params);
StringBuilder sb = new StringBuilder();

// Generate argument extraction
if (!params.isEmpty()) {
if (!schemaParameters.isEmpty()) {
// Check if single-record-parameter shortcut applies
if (params.size() == 1 && isRecord(params.get(0).asType())) {
String typeName = getTypeString(params.get(0).asType());
String paramName = params.get(0).getSimpleName().toString();
if (schemaParameters.size() == 1 && isRecord(schemaParameters.get(0).asType())) {
String typeName = getTypeString(schemaParameters.get(0).asType());
String paramName = schemaParameters.get(0).getSimpleName().toString();
sb.append(" ").append(typeName).append(" ").append(paramName)
.append(" = mapper.convertValue(invocation.getArguments(), ").append(typeName)
.append(".class);\n");
} else {
sb.append("Map<String, Object> args = invocation.getArguments();\n");
for (VariableElement param : params) {
for (VariableElement param : schemaParameters) {
String paramName = getParamName(param);
String varName = param.getSimpleName().toString();
TypeMirror paramType = param.asType();
Expand Down Expand Up @@ -404,7 +439,11 @@ private String generateArgList(List<? extends VariableElement> params) {
if (i > 0) {
sb.append(", ");
}
sb.append(params.get(i).getSimpleName().toString());
if (isToolInvocationType(params.get(i).asType())) {
sb.append("invocation");
} else {
sb.append(params.get(i).getSimpleName().toString());
}
}
return sb.toString();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
*--------------------------------------------------------------------------------------------*/

package com.github.copilot.rpc;

public record RecordInvocationArgs(String query, int limit) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@
import com.github.copilot.rpc.fixtures.ArgCoercionTools;
import com.github.copilot.rpc.fixtures.DateTimeTools;
import com.github.copilot.rpc.fixtures.DefaultValueTools;
import com.github.copilot.rpc.fixtures.InvocationAwareTools;
import com.github.copilot.rpc.fixtures.MultiReturnTools;
import com.github.copilot.rpc.fixtures.OptionalParamTools;
import com.github.copilot.rpc.fixtures.OverrideTools;
import com.github.copilot.rpc.fixtures.SimpleTools;
import com.github.copilot.rpc.fixtures.StaticInvocationTools;
import com.github.copilot.rpc.fixtures.StaticTools;

/**
Expand Down Expand Up @@ -309,6 +311,163 @@ void fromObject_optionalLongAbsent() throws Exception {
assertEquals("100", result);
}

// ── Test 11: ToolInvocation injection ───────────────────────────────────────

@Test
void fromObject_toolInvocationInjection_instanceMethod() throws Exception {
var instance = new InvocationAwareTools();
var tools = ToolDefinition.fromObject(instance);
var tool = findTool(tools, "report_progress");
assertNotNull(tool);

var result = tool.handler().invoke(createInvocation("report_progress", Map.of("phase", "analyzing"))
.setSessionId("session-123").setToolCallId("call-456")).get();
assertEquals("phase=analyzing,sessionId=session-123,toolCallId=call-456,toolName=report_progress", result);
}

@Test
void fromObject_toolInvocationInjection_schemaExcludesToolInvocation() {
var tools = ToolDefinition.fromObject(new InvocationAwareTools());
var tool = findTool(tools, "report_progress");
assertNotNull(tool);

@SuppressWarnings("unchecked")
var schema = (Map<String, Object>) tool.parameters();
@SuppressWarnings("unchecked")
var properties = (Map<String, Object>) schema.get("properties");
@SuppressWarnings("unchecked")
var required = (List<String>) schema.get("required");

assertTrue(properties.containsKey("phase"));
assertFalse(properties.containsKey("invocation"));
assertEquals(List.of("phase"), required);
}

@Test
void fromObject_toolInvocationInjection_asyncMethod() throws Exception {
var instance = new InvocationAwareTools();
var tools = ToolDefinition.fromObject(instance);
var tool = findTool(tools, "report_progress_async");
assertNotNull(tool);

var result = tool.handler().invoke(createInvocation("report_progress_async", Map.of("phase", "planning"))
.setSessionId("session-789").setToolCallId("call-012")).get();
assertEquals("async phase=planning,sessionId=session-789,toolCallId=call-012,toolName=report_progress_async",
result);
}

@Test
void fromClass_toolInvocationInjection_staticMethod() throws Exception {
var tools = ToolDefinition.fromClass(StaticInvocationTools.class);
var tool = findTool(tools, "report_static");
assertNotNull(tool);

var result = tool.handler().invoke(createInvocation("report_static", Map.of("phase", "completed"))
.setSessionId("session-321").setToolCallId("call-654")).get();
assertEquals("phase=completed,sessionId=session-321,toolCallId=call-654,toolName=report_static", result);
}

@Test
void fromObject_toolInvocationInjection_firstParameter() throws Exception {
var tools = ToolDefinition.fromObject(new InvocationAwareTools());
var tool = findTool(tools, "report_progress_first");
assertNotNull(tool);

@SuppressWarnings("unchecked")
var schema = (Map<String, Object>) tool.parameters();
@SuppressWarnings("unchecked")
var properties = (Map<String, Object>) schema.get("properties");
@SuppressWarnings("unchecked")
var required = (List<String>) schema.get("required");

assertTrue(properties.containsKey("phase"));
assertFalse(properties.containsKey("invocation"));
assertEquals(List.of("phase"), required);

var result = tool.handler().invoke(createInvocation("report_progress_first", Map.of("phase", "starting"))
.setSessionId("session-first").setToolCallId("call-first")).get();
assertEquals(
"first phase=starting,sessionId=session-first,toolCallId=call-first,toolName=report_progress_first",
result);
}

@Test
void fromObject_toolInvocationInjection_onlyParameter() throws Exception {
var tools = ToolDefinition.fromObject(new InvocationAwareTools());
var tool = findTool(tools, "only_context");
assertNotNull(tool);

@SuppressWarnings("unchecked")
var schema = (Map<String, Object>) tool.parameters();
@SuppressWarnings("unchecked")
var properties = (Map<String, Object>) schema.get("properties");
@SuppressWarnings("unchecked")
var required = (List<String>) schema.get("required");

assertTrue(properties.isEmpty());
assertTrue(required.isEmpty());

var result = tool.handler().invoke(
createInvocation("only_context", Map.of()).setSessionId("session-only").setToolCallId("call-only"))
.get();
assertEquals("only sessionId=session-only,toolCallId=call-only,toolName=only_context", result);
}

@Test
void fromObject_toolInvocationInjection_middleParameter() throws Exception {
var tools = ToolDefinition.fromObject(new InvocationAwareTools());
var tool = findTool(tools, "report_progress_middle");
assertNotNull(tool);

@SuppressWarnings("unchecked")
var schema = (Map<String, Object>) tool.parameters();
@SuppressWarnings("unchecked")
var properties = (Map<String, Object>) schema.get("properties");
@SuppressWarnings("unchecked")
var required = (List<String>) schema.get("required");

assertTrue(properties.containsKey("phase"));
assertTrue(properties.containsKey("limit"));
assertFalse(properties.containsKey("invocation"));
assertEquals(List.of("phase", "limit"), required);

var result = tool.handler()
.invoke(createInvocation("report_progress_middle", Map.of("phase", "running", "limit", 7))
.setSessionId("session-middle").setToolCallId("call-middle"))
.get();
assertEquals(
"middle phase=running,limit=7,sessionId=session-middle,toolCallId=call-middle,toolName=report_progress_middle",
result);
}

@Test
void fromObject_toolInvocationInjection_singleRecordAndInvocation() throws Exception {
var tools = ToolDefinition.fromObject(new InvocationAwareTools());
var tool = findTool(tools, "report_progress_with_record");
assertNotNull(tool);

@SuppressWarnings("unchecked")
var schema = (Map<String, Object>) tool.parameters();
@SuppressWarnings("unchecked")
var properties = (Map<String, Object>) schema.get("properties");
@SuppressWarnings("unchecked")
var required = (List<String>) schema.get("required");

assertTrue(properties.containsKey("query"));
assertTrue(properties.containsKey("limit"));
assertFalse(properties.containsKey("args"));
assertFalse(properties.containsKey("invocation"));
assertEquals(List.of("query", "limit"), required);

var result = tool.handler()
.invoke(createInvocation("report_progress_with_record", Map.of("query", "logs", "limit", 3))
.setSessionId("session-record").setToolCallId("call-record"))
.get();
assertEquals(
"record query=logs,limit=3,sessionId=session-record,toolCallId=call-record,toolName=report_progress_with_record",
result);
}

// ── Helpers ─────────────────────────────────────────────────────────────────

private static ToolDefinition findTool(List<ToolDefinition> tools, String name) {
Expand Down
Loading
Loading