Skip to content

Commit 46e5c32

Browse files
Copilotedburns
andauthored
Port ProviderConfig model/token overrides and client startup cleanup fix
- Add modelId, wireModel, maxInputTokens, maxOutputTokens to ProviderConfig - Fix client startup cleanup race: properly destroy CLI process on failure - Add unit tests for new ProviderConfig field serialization - Add E2E tests for provider wire model forwarding - Simplify BYOK identity limitations documentation (per reference impl) Co-authored-by: edburns <75821+edburns@users.noreply.github.com>
1 parent 6ff7384 commit 46e5c32

5 files changed

Lines changed: 255 additions & 10 deletions

File tree

src/main/java/com/github/copilot/sdk/CopilotClient.java

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,9 @@ private CompletableFuture<Connection> startCore() {
187187
}
188188

189189
private Connection startCoreBody() {
190+
Process process = null;
190191
try {
191192
JsonRpcClient rpc;
192-
Process process = null;
193193

194194
if (optionsHost != null && optionsPort != null) {
195195
// External server (TCP)
@@ -215,6 +215,11 @@ private Connection startCoreBody() {
215215
LOG.info("Copilot client connected");
216216
return connection;
217217
} catch (Exception e) {
218+
// Clean up process if startup failed partway through
219+
if (process != null) {
220+
cleanupCliProcess(process);
221+
}
222+
218223
String stderr = serverManager.getStderrOutput();
219224
if (!stderr.isEmpty()) {
220225
throw new CompletionException(new IOException(
@@ -224,6 +229,20 @@ private Connection startCoreBody() {
224229
}
225230
}
226231

232+
private static void cleanupCliProcess(Process process) {
233+
try {
234+
if (process.isAlive()) {
235+
process.destroyForcibly();
236+
process.waitFor(FORCE_KILL_TIMEOUT_SECONDS, TimeUnit.SECONDS);
237+
}
238+
} catch (InterruptedException ie) {
239+
Thread.currentThread().interrupt();
240+
LOG.log(Level.FINE, "Interrupted while cleaning up CLI process", ie);
241+
} catch (Exception ex) {
242+
LOG.log(Level.FINE, "Error cleaning up CLI process during failed startup", ex);
243+
}
244+
}
245+
227246
private static final int MIN_PROTOCOL_VERSION = 2;
228247
private static final int METHOD_NOT_FOUND_ERROR_CODE = -32601;
229248

src/main/java/com/github/copilot/sdk/json/ProviderConfig.java

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,18 @@ public class ProviderConfig {
5757
@JsonProperty("headers")
5858
private Map<String, String> headers;
5959

60+
@JsonProperty("modelId")
61+
private String modelId;
62+
63+
@JsonProperty("wireModel")
64+
private String wireModel;
65+
66+
@JsonProperty("maxPromptTokens")
67+
private Integer maxInputTokens;
68+
69+
@JsonProperty("maxOutputTokens")
70+
private Integer maxOutputTokens;
71+
6072
/**
6173
* Gets the provider type.
6274
*
@@ -225,4 +237,116 @@ public ProviderConfig setHeaders(Map<String, String> headers) {
225237
this.headers = headers;
226238
return this;
227239
}
240+
241+
/**
242+
* Gets the well-known model name used by the runtime to look up agent
243+
* configuration (tools, prompts, reasoning behavior) and default token limits.
244+
* <p>
245+
* Also used as the wire model when {@link #getWireModel()} is not set. Falls
246+
* back to {@link SessionConfig#getModel()}.
247+
*
248+
* @return the model ID, or {@code null} if not set
249+
*/
250+
public String getModelId() {
251+
return modelId;
252+
}
253+
254+
/**
255+
* Sets the well-known model name used by the runtime to look up agent
256+
* configuration (tools, prompts, reasoning behavior) and default token limits.
257+
* <p>
258+
* Also used as the wire model when {@link #setWireModel(String)} is not set.
259+
* Falls back to {@link SessionConfig#getModel()}.
260+
*
261+
* @param modelId
262+
* the model ID
263+
* @return this config for method chaining
264+
*/
265+
public ProviderConfig setModelId(String modelId) {
266+
this.modelId = modelId;
267+
return this;
268+
}
269+
270+
/**
271+
* Gets the model name sent to the provider API for inference.
272+
* <p>
273+
* Use this when the provider's model name (e.g. an Azure deployment name or a
274+
* custom fine-tune name) differs from {@link #getModelId()}. Falls back to
275+
* {@link #getModelId()}, then {@link SessionConfig#getModel()}.
276+
*
277+
* @return the wire model name, or {@code null} if not set
278+
*/
279+
public String getWireModel() {
280+
return wireModel;
281+
}
282+
283+
/**
284+
* Sets the model name sent to the provider API for inference.
285+
* <p>
286+
* Use this when the provider's model name (e.g. an Azure deployment name or a
287+
* custom fine-tune name) differs from {@link #getModelId()}. Falls back to
288+
* {@link #getModelId()}, then {@link SessionConfig#getModel()}.
289+
*
290+
* @param wireModel
291+
* the wire model name
292+
* @return this config for method chaining
293+
*/
294+
public ProviderConfig setWireModel(String wireModel) {
295+
this.wireModel = wireModel;
296+
return this;
297+
}
298+
299+
/**
300+
* Gets the override for the resolved model's default max prompt tokens.
301+
* <p>
302+
* The runtime triggers conversation compaction before sending a request when
303+
* the prompt (system message, history, tool definitions, user message) would
304+
* exceed this limit.
305+
*
306+
* @return the max input tokens, or {@code null} if not set
307+
*/
308+
public Integer getMaxInputTokens() {
309+
return maxInputTokens;
310+
}
311+
312+
/**
313+
* Sets the override for the resolved model's default max prompt tokens.
314+
* <p>
315+
* The runtime triggers conversation compaction before sending a request when
316+
* the prompt (system message, history, tool definitions, user message) would
317+
* exceed this limit.
318+
*
319+
* @param maxInputTokens
320+
* the max input tokens
321+
* @return this config for method chaining
322+
*/
323+
public ProviderConfig setMaxInputTokens(Integer maxInputTokens) {
324+
this.maxInputTokens = maxInputTokens;
325+
return this;
326+
}
327+
328+
/**
329+
* Gets the override for the resolved model's default max output tokens.
330+
* <p>
331+
* When hit, the model stops generating and returns a truncated response.
332+
*
333+
* @return the max output tokens, or {@code null} if not set
334+
*/
335+
public Integer getMaxOutputTokens() {
336+
return maxOutputTokens;
337+
}
338+
339+
/**
340+
* Sets the override for the resolved model's default max output tokens.
341+
* <p>
342+
* When hit, the model stops generating and returns a truncated response.
343+
*
344+
* @param maxOutputTokens
345+
* the max output tokens
346+
* @return this config for method chaining
347+
*/
348+
public ProviderConfig setMaxOutputTokens(Integer maxOutputTokens) {
349+
this.maxOutputTokens = maxOutputTokens;
350+
return this;
351+
}
228352
}

src/site/markdown/advanced.md

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -421,17 +421,36 @@ foundry service status
421421

422422
When using BYOK, be aware of these limitations:
423423

424-
#### Identity Limitations
424+
#### Model and Token Limit Overrides
425425

426-
BYOK authentication uses **static credentials only**. The following identity providers are NOT supported:
426+
You can override the model name and token limits used by the provider:
427427

428-
-**Microsoft Entra ID (Azure AD)** - No support for Entra managed identities or service principals
429-
-**Third-party identity providers** - No OIDC, SAML, or other federated identity
430-
-**Managed identities** - Azure Managed Identity is not supported
428+
```java
429+
var session = client.createSession(
430+
new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL)
431+
.setProvider(new ProviderConfig()
432+
.setType("openai")
433+
.setBaseUrl("https://api.openai.com/v1")
434+
.setApiKey("sk-...")
435+
.setModelId("gpt-4o") // Runtime model for config lookup
436+
.setWireModel("my-finetune-v3") // Actual model name sent to provider API
437+
.setMaxInputTokens(100_000) // Override max prompt tokens
438+
.setMaxOutputTokens(4096)) // Override max output tokens
439+
).get();
440+
```
431441

432-
You must use an API key or static bearer token that you manage yourself.
442+
| Property | Description |
443+
|---|---|
444+
| `modelId` | Well-known model name for runtime config lookup (tools, prompts, reasoning). Also used as wire model when `wireModel` is not set. Falls back to `SessionConfig.model`. |
445+
| `wireModel` | Model name sent to the provider API. Use when the provider's model name (e.g. Azure deployment name or fine-tune) differs from `modelId`. Falls back to `modelId`, then `SessionConfig.model`. |
446+
| `maxInputTokens` | Override max prompt tokens. The runtime compacts conversation before exceeding this limit. |
447+
| `maxOutputTokens` | Override max output tokens. The model stops generating when this limit is hit. |
433448

434-
**Why not Entra ID?** While Entra ID does issue bearer tokens, these tokens are short-lived (typically 1 hour) and require automatic refresh via the Azure Identity SDK. The `bearerToken` option only accepts a static string—there is no callback mechanism for the SDK to request fresh tokens. For long-running workloads requiring Entra authentication, you would need to implement your own token refresh logic and create new sessions with updated tokens.
449+
#### Identity Limitations
450+
451+
BYOK authentication uses **static credentials only**.
452+
453+
You must use an API key or static bearer token that you manage yourself.
435454

436455
---
437456

src/test/java/com/github/copilot/sdk/ProviderConfigTest.java

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ void testDefaultsAreNull() {
4646
assertNull(provider.getApiKey());
4747
assertNull(provider.getBearerToken());
4848
assertNull(provider.getAzure());
49+
assertNull(provider.getModelId());
50+
assertNull(provider.getWireModel());
51+
assertNull(provider.getMaxInputTokens());
52+
assertNull(provider.getMaxOutputTokens());
4953
}
5054

5155
@Test
@@ -232,7 +236,8 @@ void testSerializeCustomWireApi() throws Exception {
232236
void testSerializeAllFields() throws Exception {
233237
var provider = new ProviderConfig().setType("azure-openai").setWireApi("completions")
234238
.setBaseUrl("https://my-resource.openai.azure.com").setApiKey("my-api-key")
235-
.setBearerToken("my-bearer-token").setAzure(new AzureOptions().setApiVersion("2024-02-01"));
239+
.setBearerToken("my-bearer-token").setAzure(new AzureOptions().setApiVersion("2024-02-01"))
240+
.setModelId("gpt-4o").setWireModel("my-deployment").setMaxInputTokens(50_000).setMaxOutputTokens(2048);
236241

237242
JsonNode json = MAPPER.valueToTree(provider);
238243

@@ -242,7 +247,11 @@ void testSerializeAllFields() throws Exception {
242247
assertEquals("my-api-key", json.get("apiKey").asText());
243248
assertEquals("my-bearer-token", json.get("bearerToken").asText());
244249
assertEquals("2024-02-01", json.get("azure").get("apiVersion").asText());
245-
assertEquals(6, json.size(), "Expected exactly 6 JSON fields");
250+
assertEquals("gpt-4o", json.get("modelId").asText());
251+
assertEquals("my-deployment", json.get("wireModel").asText());
252+
assertEquals(50_000, json.get("maxPromptTokens").asInt());
253+
assertEquals(2048, json.get("maxOutputTokens").asInt());
254+
assertEquals(10, json.size(), "Expected exactly 10 JSON fields");
246255
}
247256

248257
@Test
@@ -285,6 +294,30 @@ void testRoundTripProviderConfig() throws Exception {
285294
assertEquals(original.getAzure().getApiVersion(), deserialized.getAzure().getApiVersion());
286295
}
287296

297+
@Test
298+
void testSerializeProviderModelAndTokenOverrides() throws Exception {
299+
var provider = new ProviderConfig().setType("openai").setBaseUrl("https://example.com/provider")
300+
.setHeaders(java.util.Map.of("Authorization", "Bearer provider-token")).setModelId("gpt-4o")
301+
.setWireModel("my-finetune-v3").setMaxInputTokens(100_000).setMaxOutputTokens(4096);
302+
303+
JsonNode json = MAPPER.valueToTree(provider);
304+
305+
assertEquals("https://example.com/provider", json.get("baseUrl").asText());
306+
assertEquals("Bearer provider-token", json.get("headers").get("Authorization").asText());
307+
assertEquals("gpt-4o", json.get("modelId").asText());
308+
assertEquals("my-finetune-v3", json.get("wireModel").asText());
309+
assertEquals(100_000, json.get("maxPromptTokens").asInt());
310+
assertEquals(4096, json.get("maxOutputTokens").asInt());
311+
312+
ProviderConfig deserialized = MAPPER.treeToValue(json, ProviderConfig.class);
313+
assertNotNull(deserialized);
314+
assertEquals("https://example.com/provider", deserialized.getBaseUrl());
315+
assertEquals("gpt-4o", deserialized.getModelId());
316+
assertEquals("my-finetune-v3", deserialized.getWireModel());
317+
assertEquals(100_000, deserialized.getMaxInputTokens());
318+
assertEquals(4096, deserialized.getMaxOutputTokens());
319+
}
320+
288321
@Test
289322
void testForwardCompatibilityIgnoresUnknownFields() throws Exception {
290323
String json = """

src/test/java/com/github/copilot/sdk/SessionConfigE2ETest.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,4 +127,54 @@ private static String getSystemMessage(Map<String, Object> exchange) {
127127
}
128128
return null;
129129
}
130+
131+
@SuppressWarnings("unchecked")
132+
private static String getRequestModel(Map<String, Object> exchange) {
133+
Object requestObj = exchange.get("request");
134+
if (!(requestObj instanceof Map<?, ?> request)) {
135+
return null;
136+
}
137+
Object model = request.get("model");
138+
return model != null ? model.toString() : null;
139+
}
140+
141+
@Test
142+
void testShouldForwardProviderWireModel() throws Exception {
143+
ctx.configureForTest("session_config", "should_forward_provider_wire_model");
144+
145+
try (CopilotClient client = ctx.createClient()) {
146+
CopilotSession session = client.createSession(new SessionConfig().setModel("claude-sonnet-4.5")
147+
.setProvider(new com.github.copilot.sdk.json.ProviderConfig().setType("openai")
148+
.setBaseUrl(ctx.getProxyUrl()).setApiKey("test-provider-key")
149+
.setWireModel("test-wire-model").setMaxOutputTokens(1024))
150+
.setOnPermissionRequest(PermissionHandler.APPROVE_ALL)).get();
151+
152+
session.sendAndWait(new MessageOptions().setPrompt("What is 1+1?")).get(60, TimeUnit.SECONDS);
153+
154+
List<Map<String, Object>> exchanges = ctx.getExchanges();
155+
assertFalse(exchanges.isEmpty(), "Should have at least one exchange");
156+
assertEquals("test-wire-model", getRequestModel(exchanges.get(0)));
157+
}
158+
}
159+
160+
@Test
161+
void testShouldUseProviderModelIdAsWireModel() throws Exception {
162+
ctx.configureForTest("session_config", "should_use_provider_model_id_as_wire_model");
163+
164+
try (CopilotClient client = ctx.createClient()) {
165+
CopilotSession session = client
166+
.createSession(new SessionConfig()
167+
.setProvider(new com.github.copilot.sdk.json.ProviderConfig().setType("openai")
168+
.setBaseUrl(ctx.getProxyUrl()).setApiKey("test-provider-key")
169+
.setModelId("claude-sonnet-4.5"))
170+
.setOnPermissionRequest(PermissionHandler.APPROVE_ALL))
171+
.get();
172+
173+
session.sendAndWait(new MessageOptions().setPrompt("What is 1+1?")).get(60, TimeUnit.SECONDS);
174+
175+
List<Map<String, Object>> exchanges = ctx.getExchanges();
176+
assertFalse(exchanges.isEmpty(), "Should have at least one exchange");
177+
assertEquals("claude-sonnet-4.5", getRequestModel(exchanges.get(0)));
178+
}
179+
}
130180
}

0 commit comments

Comments
 (0)