Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,19 @@ public McpClientTransport build(Object connectionParams) {
.orElse(""))))
.build();
} else if (connectionParams instanceof StreamableHttpServerParameters streamableParams) {
return HttpClientStreamableHttpTransport.builder(streamableParams.url())
.connectTimeout(streamableParams.timeout())
.jsonMapper(jsonMapper)
.asyncHttpRequestCustomizer(
(builder, method, uri, body, context) -> {
streamableParams.headers().forEach((key, value) -> builder.header(key, value));
return Mono.just(builder);
})
.build();
HttpClientStreamableHttpTransport.Builder transportBuilder =
HttpClientStreamableHttpTransport.builder(streamableParams.url())
.connectTimeout(streamableParams.timeout())
.jsonMapper(jsonMapper)
.asyncHttpRequestCustomizer(
(builder, method, uri, body, context) -> {
streamableParams.headers().forEach((key, value) -> builder.header(key, value));
return Mono.just(builder);
});
if (streamableParams.endpoint() != null) {
transportBuilder.endpoint(streamableParams.endpoint());
}
return transportBuilder.build();
} else {
throw new IllegalArgumentException(
"DefaultMcpTransportBuilder supports only ServerParameters, SseServerParameters, or"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
/** Server parameters for Streamable HTTP client transport. */
public class StreamableHttpServerParameters {
private final String url;
private final String endpoint;
private final Map<String, String> headers;
private final Duration timeout;
private final Duration readTimeout;
Expand All @@ -35,19 +36,23 @@ public class StreamableHttpServerParameters {
* Server parameters for Streamable HTTP client transport.
*
* @param url The base URL for the MCP Streamable HTTP server.
* @param endpoint The endpoint path on the server (e.g. {@code /mcp/stream}). When {@code null},
* the MCP library default ({@code /mcp}) is used.
* @param headers Optional headers to include in requests.
* @param timeout Timeout for HTTP operations (default: 30 seconds).
* @param readTimeout Timeout for reading data from the streamed http events(default: 5 minutes).
* @param terminateOnClose Whether to terminate the session on close (default: true).
*/
public StreamableHttpServerParameters(
String url,
@Nullable String endpoint,
Map<String, String> headers,
@Nullable Duration timeout,
@Nullable Duration readTimeout,
@Nullable Boolean terminateOnClose) {
Assert.hasText(url, "url must not be empty");
this.url = url;
this.endpoint = endpoint;
this.headers = headers == null ? Collections.emptyMap() : headers;
this.timeout = timeout == null ? Duration.ofSeconds(30) : timeout;
this.readTimeout = readTimeout == null ? Duration.ofMinutes(5) : readTimeout;
Expand All @@ -58,6 +63,11 @@ public String url() {
return url;
}

@Nullable
public String endpoint() {
return endpoint;
}

public Map<String, String> headers() {
return headers;
}
Expand All @@ -81,6 +91,7 @@ public static Builder builder() {
/** Builder for {@link StreamableHttpServerParameters}. */
public static class Builder {
private String url;
private String endpoint;
private Map<String, String> headers = Collections.emptyMap();
private Duration timeout = Duration.ofSeconds(30);
private Duration readTimeout = Duration.ofMinutes(5);
Expand All @@ -95,6 +106,12 @@ public Builder url(String url) {
return this;
}

@CanIgnoreReturnValue
public Builder endpoint(String endpoint) {
this.endpoint = endpoint;
return this;
}

@CanIgnoreReturnValue
public Builder headers(Map<String, String> headers) {
this.headers = headers;
Expand All @@ -121,7 +138,7 @@ public Builder terminateOnClose(boolean terminateOnClose) {

public StreamableHttpServerParameters build() {
return new StreamableHttpServerParameters(
url, headers, timeout, readTimeout, terminateOnClose);
url, endpoint, headers, timeout, readTimeout, terminateOnClose);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.adk.tools.mcp;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;

import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport;
import io.modelcontextprotocol.client.transport.ServerParameters;
import io.modelcontextprotocol.client.transport.StdioClientTransport;
import io.modelcontextprotocol.spec.McpClientTransport;
import java.lang.reflect.Field;
import java.net.URI;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Unit tests for {@link DefaultMcpTransportBuilder}. */
@RunWith(JUnit4.class)
public final class DefaultMcpTransportBuilderTest {

private final DefaultMcpTransportBuilder builder = new DefaultMcpTransportBuilder();

// -------------------------------------------------------------------------
// Helper: read private fields via reflection
// -------------------------------------------------------------------------

private static String getEndpointField(HttpClientStreamableHttpTransport transport)
throws Exception {
Field field = HttpClientStreamableHttpTransport.class.getDeclaredField("endpoint");
field.setAccessible(true);
return (String) field.get(transport);
}

private static URI getBaseUriField(HttpClientStreamableHttpTransport transport) throws Exception {
Field field = HttpClientStreamableHttpTransport.class.getDeclaredField("baseUri");
field.setAccessible(true);
return (URI) field.get(transport);
}

// -------------------------------------------------------------------------
// StreamableHttp transport tests
// -------------------------------------------------------------------------

@Test
public void build_withStreamableHttpParamsWithoutEndpoint_usesLibraryDefaultEndpoint()
throws Exception {
// When the user does NOT set .endpoint(), the library default "/mcp" must be preserved.
StreamableHttpServerParameters params =
StreamableHttpServerParameters.builder()
.url("http://localhost:8080")
// No .endpoint() call → endpoint() returns null
.build();

McpClientTransport transport = builder.build(params);

assertThat(transport).isInstanceOf(HttpClientStreamableHttpTransport.class);
HttpClientStreamableHttpTransport streamableTransport =
(HttpClientStreamableHttpTransport) transport;

// baseUri stores exactly the url passed by the user
assertThat(getBaseUriField(streamableTransport))
.isEqualTo(URI.create("http://localhost:8080"));
// endpoint is the library's hard-coded default because the user did not override it
assertThat(getEndpointField(streamableTransport)).isEqualTo("/mcp");
}

@Test
public void build_withStreamableHttpParamsWithCustomEndpoint_setsEndpointOnTransport()
throws Exception {
// When the user sets .endpoint("/mcp/stream"), the transport's endpoint field must reflect it.
// This is the core of the bug fix: the library's Utils.resolveUri(baseUri, endpoint) will now
// compute URI.create("http://localhost:8080").resolve("/mcp/stream")
// = "http://localhost:8080/mcp/stream" ✅
// instead of the broken pre-fix behaviour:
// URI.create("http://localhost:8080/mcp/stream").resolve("/mcp")
// = "http://localhost:8080/mcp" ❌
StreamableHttpServerParameters params =
StreamableHttpServerParameters.builder()
.url("http://localhost:8080")
.endpoint("/mcp/stream")
.build();

McpClientTransport transport = builder.build(params);

assertThat(transport).isInstanceOf(HttpClientStreamableHttpTransport.class);
HttpClientStreamableHttpTransport streamableTransport =
(HttpClientStreamableHttpTransport) transport;

assertThat(getBaseUriField(streamableTransport))
.isEqualTo(URI.create("http://localhost:8080"));
// endpoint was explicitly overridden — must NOT be the default "/mcp"
assertThat(getEndpointField(streamableTransport)).isEqualTo("/mcp/stream");
}

@Test
public void build_withStreamableHttpParams_defaultEndpointProducesCorrectFinalUri()
throws Exception {
// Confirm that the final resolved URI for default case is http://host/mcp (not broken).
StreamableHttpServerParameters params =
StreamableHttpServerParameters.builder().url("http://localhost:8080").build();

McpClientTransport transport = builder.build(params);
HttpClientStreamableHttpTransport streamableTransport =
(HttpClientStreamableHttpTransport) transport;

URI base = getBaseUriField(streamableTransport);
String endpoint = getEndpointField(streamableTransport);
URI resolved = base.resolve(endpoint);

assertThat(resolved).isEqualTo(URI.create("http://localhost:8080/mcp"));
}

@Test
public void build_withStreamableHttpParams_customEndpointProducesCorrectFinalUri()
throws Exception {
// Confirm that the final resolved URI for the custom endpoint case is correct.
StreamableHttpServerParameters params =
StreamableHttpServerParameters.builder()
.url("http://localhost:8080")
.endpoint("/mcp/stream")
.build();

McpClientTransport transport = builder.build(params);
HttpClientStreamableHttpTransport streamableTransport =
(HttpClientStreamableHttpTransport) transport;

URI base = getBaseUriField(streamableTransport);
String endpoint = getEndpointField(streamableTransport);
URI resolved = base.resolve(endpoint);

assertThat(resolved).isEqualTo(URI.create("http://localhost:8080/mcp/stream"));
}

// -------------------------------------------------------------------------
// SSE transport tests — ensure existing behaviour is unchanged
// -------------------------------------------------------------------------

@Test
public void build_withSseParams_returnsSseTransport() {
SseServerParameters params =
SseServerParameters.builder().url("http://localhost:8080").build();

McpClientTransport transport = builder.build(params);

assertThat(transport).isInstanceOf(HttpClientSseClientTransport.class);
}

@Test
public void build_withSseParamsWithCustomSseEndpoint_returnsSseTransport() {
SseServerParameters params =
SseServerParameters.builder()
.url("http://localhost:8080")
.sseEndpoint("events")
.build();

McpClientTransport transport = builder.build(params);

assertThat(transport).isInstanceOf(HttpClientSseClientTransport.class);
}

// -------------------------------------------------------------------------
// Stdio transport tests — ensure existing behaviour is unchanged
// -------------------------------------------------------------------------

@Test
public void build_withStdioParams_returnsStdioTransport() {
ServerParameters params = ServerParameters.builder("echo").args("hello").build();

McpClientTransport transport = builder.build(params);

assertThat(transport).isInstanceOf(StdioClientTransport.class);
}

// -------------------------------------------------------------------------
// Unknown param type test
// -------------------------------------------------------------------------

@Test
public void build_withUnknownParamType_throwsIllegalArgumentException() {
Object unknownParams = new Object();

assertThrows(IllegalArgumentException.class, () -> builder.build(unknownParams));
}
}
Loading