Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import static io.vertx.core.http.HttpHeaders.CONTENT_TYPE;
import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON;
import static jakarta.ws.rs.core.MediaType.SERVER_SENT_EVENTS;
import static org.a2aproject.sdk.compat03.transport.jsonrpc.context.JSONRPCContextKeys_v0_3.HEADERS_KEY;
import static org.a2aproject.sdk.compat03.transport.jsonrpc.context.JSONRPCContextKeys_v0_3.METHOD_NAME_KEY;

Expand All @@ -24,11 +23,6 @@
import io.quarkus.security.ForbiddenException;
import io.quarkus.security.UnauthorizedException;
import io.smallrye.mutiny.Multi;
import io.vertx.core.AsyncResult;
import io.vertx.core.Handler;
import io.vertx.core.MultiMap;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.HttpServerResponse;
import io.vertx.ext.web.Router;
import io.vertx.ext.web.RoutingContext;
import io.vertx.ext.web.handler.BodyHandler;
Expand Down Expand Up @@ -66,20 +60,19 @@
import org.a2aproject.sdk.server.ServerCallContext;
import org.a2aproject.sdk.server.auth.UnauthenticatedUser;
import org.a2aproject.sdk.server.auth.User;
import org.a2aproject.sdk.server.common.quarkus.SseResponseWriter;
import org.a2aproject.sdk.server.common.quarkus.VertxSecurityHelper;
import org.a2aproject.sdk.server.extensions.A2AExtensions;
import org.a2aproject.sdk.spec.AgentCard;
import org.jspecify.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@Singleton
public class A2AServerRoutes_v0_3 {

@Inject
JSONRPCHandler_v0_3 jsonRpcHandler;

// Hook so testing can wait until the MultiSseSupport is subscribed.
// Hook so testing can wait until the SSE subscriber is attached.
// Without this we get intermittent failures
private static volatile @Nullable Runnable streamingMultiSseSupportSubscribedRunnable;

Expand Down Expand Up @@ -206,7 +199,7 @@ public void invokeJSONRPCHandler(String body, RoutingContext rc) {
AtomicLong eventIdCounter = new AtomicLong(0);
Multi<String> sseEvents = streamingResponse
.map(response -> formatSseEvent(response, eventIdCounter.getAndIncrement()));
MultiSseSupport.writeSseStrings(sseEvents, rc, context);
SseResponseWriter.writeSseStrings(sseEvents, rc, context, streamingMultiSseSupportSubscribedRunnable);
} else {
rc.response()
.setStatusCode(200)
Expand Down Expand Up @@ -335,81 +328,4 @@ public String getUsername() {
}
}

private static class MultiSseSupport {
private static final Logger logger = LoggerFactory.getLogger(MultiSseSupport.class);

private MultiSseSupport() {
// Avoid direct instantiation.
}

public static void writeSseStrings(Multi<String> sseStrings, RoutingContext rc, ServerCallContext context) {
HttpServerResponse response = rc.response();

sseStrings.subscribe().withSubscriber(new Flow.Subscriber<String>() {
Flow.@Nullable Subscription upstream;

@Override
public void onSubscribe(Flow.Subscription subscription) {
this.upstream = subscription;
this.upstream.request(1);

response.closeHandler(v -> {
logger.info("SSE connection closed by client, calling EventConsumer.cancel() to stop polling loop");
context.invokeEventConsumerCancelCallback();
subscription.cancel();
});

// Notify tests that we are subscribed
Runnable runnable = streamingMultiSseSupportSubscribedRunnable;
if (runnable != null) {
runnable.run();
}
}

@Override
public void onNext(String sseEvent) {
if (response.bytesWritten() == 0) {
MultiMap headers = response.headers();
if (headers.get(CONTENT_TYPE) == null) {
headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS);
}
headers.set("Cache-Control", "no-cache");
headers.set("X-Accel-Buffering", "no");
response.setChunked(true);
response.setWriteQueueMaxSize(1);
response.write(": SSE stream started\n\n");
}

response.write(Buffer.buffer(sseEvent), new Handler<AsyncResult<Void>>() {
@Override
public void handle(AsyncResult<Void> ar) {
if (ar.failed()) {
java.util.Objects.requireNonNull(upstream).cancel();
rc.fail(ar.cause());
} else {
java.util.Objects.requireNonNull(upstream).request(1);
}
}
});
}

@Override
public void onError(Throwable throwable) {
java.util.Objects.requireNonNull(upstream).cancel();
rc.fail(throwable);
}

@Override
public void onComplete() {
if (response.bytesWritten() == 0) {
MultiMap headers = response.headers();
if (headers.get(CONTENT_TYPE) == null) {
headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS);
}
}
response.end();
}
});
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@

import static io.vertx.core.http.HttpHeaders.CONTENT_TYPE;
import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON;
import static jakarta.ws.rs.core.MediaType.SERVER_SENT_EVENTS;
import static org.a2aproject.sdk.compat03.transport.rest.context.RestContextKeys_v0_3.HEADERS_KEY;
import static org.a2aproject.sdk.compat03.transport.rest.context.RestContextKeys_v0_3.METHOD_NAME_KEY;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;

Expand All @@ -24,11 +22,7 @@
import io.quarkus.security.ForbiddenException;
import io.quarkus.security.UnauthorizedException;
import io.smallrye.mutiny.Multi;
import io.vertx.core.AsyncResult;
import io.vertx.core.Handler;
import io.vertx.core.MultiMap;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.HttpServerResponse;
import io.vertx.ext.web.Router;
import io.vertx.ext.web.RoutingContext;
import io.vertx.ext.web.handler.BodyHandler;
Expand All @@ -53,20 +47,19 @@
import org.a2aproject.sdk.server.ServerCallContext;
import org.a2aproject.sdk.server.auth.UnauthenticatedUser;
import org.a2aproject.sdk.server.auth.User;
import org.a2aproject.sdk.server.common.quarkus.SseResponseWriter;
import org.a2aproject.sdk.server.common.quarkus.VertxSecurityHelper;
import org.a2aproject.sdk.server.extensions.A2AExtensions;
import org.a2aproject.sdk.spec.AgentCard;
import org.jspecify.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@Singleton
public class A2AServerRoutes_v0_3 {

@Inject
RestHandler_v0_3 jsonRestHandler;

// Hook so testing can wait until the MultiSseSupport is subscribed.
// Hook so testing can wait until the SSE subscriber is attached.
// Without this we get intermittent failures
private static volatile @Nullable Runnable streamingMultiSseSupportSubscribedRunnable;

Expand Down Expand Up @@ -202,7 +195,7 @@ public void sendMessageStreaming(String body, RoutingContext rc) {
AtomicLong eventIdCounter = new AtomicLong(0);
Multi<String> sseEvents = Multi.createFrom().publisher(streamingResponse.getPublisher())
.map(json -> formatSseEvent(json, eventIdCounter.getAndIncrement()));
MultiSseSupport.writeSseStrings(sseEvents, rc, context);
SseResponseWriter.writeSseStrings(sseEvents, rc, context, streamingMultiSseSupportSubscribedRunnable);
}
}
}
Expand Down Expand Up @@ -302,7 +295,7 @@ public void resubscribeTask(RoutingContext rc) {
AtomicLong eventIdCounter = new AtomicLong(0);
Multi<String> sseEvents = Multi.createFrom().publisher(streamingResponse.getPublisher())
.map(json -> formatSseEvent(json, eventIdCounter.getAndIncrement()));
MultiSseSupport.writeSseStrings(sseEvents, rc, context);
SseResponseWriter.writeSseStrings(sseEvents, rc, context, streamingMultiSseSupportSubscribedRunnable);
}
}
}
Expand Down Expand Up @@ -463,81 +456,4 @@ private static boolean hasNonDefaultV10AgentCard() {
return false;
}

private static class MultiSseSupport {
private static final Logger logger = LoggerFactory.getLogger(MultiSseSupport.class);

private MultiSseSupport() {
// Avoid direct instantiation.
}

public static void writeSseStrings(Multi<String> sseStrings, RoutingContext rc, ServerCallContext context) {
HttpServerResponse response = rc.response();

sseStrings.subscribe().withSubscriber(new Flow.Subscriber<String>() {
Flow.@Nullable Subscription upstream;

@Override
public void onSubscribe(Flow.Subscription subscription) {
this.upstream = subscription;
this.upstream.request(1);

response.closeHandler(v -> {
logger.debug("REST SSE connection closed by client, calling EventConsumer.cancel() to stop polling loop");
context.invokeEventConsumerCancelCallback();
subscription.cancel();
});

// Notify tests that we are subscribed
Runnable runnable = streamingMultiSseSupportSubscribedRunnable;
if (runnable != null) {
runnable.run();
}
}

@Override
public void onNext(String sseEvent) {
if (response.bytesWritten() == 0) {
MultiMap headers = response.headers();
if (headers.get(CONTENT_TYPE) == null) {
headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS);
}
headers.set("Cache-Control", "no-cache");
headers.set("X-Accel-Buffering", "no");
response.setChunked(true);
response.setWriteQueueMaxSize(1);
response.write(": SSE stream started\n\n");
}

response.write(Buffer.buffer(sseEvent), new Handler<AsyncResult<Void>>() {
@Override
public void handle(AsyncResult<Void> ar) {
if (ar.failed()) {
java.util.Objects.requireNonNull(upstream).cancel();
rc.fail(ar.cause());
} else {
java.util.Objects.requireNonNull(upstream).request(1);
}
}
});
}

@Override
public void onError(Throwable throwable) {
java.util.Objects.requireNonNull(upstream).cancel();
rc.fail(throwable);
}

@Override
public void onComplete() {
if (response.bytesWritten() == 0) {
MultiMap headers = response.headers();
if (headers.get(CONTENT_TYPE) == null) {
headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS);
}
}
response.end();
}
});
}
}
}
Loading
Loading