diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index 0a7a09864..20d7dfa4f 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -304,15 +304,12 @@ public Flowable runAsync(InvocationContext parentContext) { Tracer tracer = Tracing.getTracer(); return Flowable.defer( () -> { + InvocationContext invocationContext = createInvocationContext(parentContext); Span span = - tracer - .spanBuilder("agent_run [" + name() + "]") - .setParent(Context.current()) - .startSpan(); + tracer.spanBuilder("invoke_agent " + name()).setParent(Context.current()).startSpan(); + Tracing.traceAgentInvocation(span, name(), description(), invocationContext); Context spanContext = Context.current().with(span); - InvocationContext invocationContext = createInvocationContext(parentContext); - return Tracing.traceFlowable( spanContext, span, @@ -443,16 +440,41 @@ public Flowable runLive(InvocationContext parentContext) { Tracer tracer = Tracing.getTracer(); return Flowable.defer( () -> { + InvocationContext invocationContext = createInvocationContext(parentContext); Span span = - tracer - .spanBuilder("agent_run [" + name() + "]") - .setParent(Context.current()) - .startSpan(); + tracer.spanBuilder("invoke_agent " + name()).setParent(Context.current()).startSpan(); + Tracing.traceAgentInvocation(span, name(), description(), invocationContext); Context spanContext = Context.current().with(span); - InvocationContext invocationContext = createInvocationContext(parentContext); + return Tracing.traceFlowable( + spanContext, + span, + () -> + callCallback( + beforeCallbacksToFunctions( + invocationContext.pluginManager(), beforeAgentCallback), + invocationContext) + .flatMapPublisher( + beforeEventOpt -> { + if (invocationContext.endInvocation()) { + return Flowable.fromOptional(beforeEventOpt); + } + + Flowable beforeEvents = Flowable.fromOptional(beforeEventOpt); + Flowable mainEvents = + Flowable.defer(() -> runLiveImpl(invocationContext)); + Flowable afterEvents = + Flowable.defer( + () -> + callCallback( + afterCallbacksToFunctions( + invocationContext.pluginManager(), + afterAgentCallback), + invocationContext) + .flatMapPublisher(Flowable::fromOptional)); - return Tracing.traceFlowable(spanContext, span, () -> runLiveImpl(invocationContext)); + return Flowable.concat(beforeEvents, mainEvents, afterEvents); + })); }); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index a6fb74d88..5a855445a 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -181,7 +181,7 @@ public static Maybe handleFunctionCalls( Span mergedSpan = tracer.spanBuilder("tool_response").setParent(Context.current()).startSpan(); try (Scope scope = mergedSpan.makeCurrent()) { - Tracing.traceToolResponse(invocationContext, mergedEvent.id(), mergedEvent); + Tracing.traceToolResponse(mergedEvent.id(), mergedEvent); } finally { mergedSpan.end(); } @@ -571,7 +571,8 @@ private static Maybe> callTool( .setParent(parentContext) .startSpan(); try (Scope scope = span.makeCurrent()) { - Tracing.traceToolCall(args); + Tracing.traceToolCall( + tool.name(), tool.description(), tool.getClass().getSimpleName(), args); return tool.runAsync(args, toolContext) .toMaybe() .doOnError(span::recordException) @@ -620,7 +621,7 @@ private static Event buildResponseEvent( .content(Content.builder().role("user").parts(partFunctionResponse).build()) .actions(toolContext.eventActions()) .build(); - Tracing.traceToolResponse(invocationContext, event.id(), event); + Tracing.traceToolResponse(event.id(), event); return event; } finally { span.end(); diff --git a/core/src/main/java/com/google/adk/telemetry/Tracing.java b/core/src/main/java/com/google/adk/telemetry/Tracing.java index 23054b674..36c6e3e58 100644 --- a/core/src/main/java/com/google/adk/telemetry/Tracing.java +++ b/core/src/main/java/com/google/adk/telemetry/Tracing.java @@ -26,9 +26,12 @@ import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; +import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; import io.opentelemetry.api.GlobalOpenTelemetry; +import io.opentelemetry.api.common.AttributeKey; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; @@ -37,7 +40,9 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; +import java.util.Optional; import java.util.function.Supplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -52,8 +57,59 @@ public class Tracing { private static final Logger log = LoggerFactory.getLogger(Tracing.class); + private static final AttributeKey> GEN_AI_RESPONSE_FINISH_REASONS = + AttributeKey.stringArrayKey("gen_ai.response.finish_reasons"); + + private static final AttributeKey GEN_AI_OPERATION_NAME = + AttributeKey.stringKey("gen_ai.operation.name"); + private static final AttributeKey GEN_AI_AGENT_DESCRIPTION = + AttributeKey.stringKey("gen_ai.agent.description"); + private static final AttributeKey GEN_AI_AGENT_NAME = + AttributeKey.stringKey("gen_ai.agent.name"); + private static final AttributeKey GEN_AI_CONVERSATION_ID = + AttributeKey.stringKey("gen_ai.conversation.id"); + private static final AttributeKey GEN_AI_SYSTEM = AttributeKey.stringKey("gen_ai.system"); + private static final AttributeKey GEN_AI_TOOL_CALL_ID = + AttributeKey.stringKey("gen_ai.tool_call.id"); + private static final AttributeKey GEN_AI_TOOL_DESCRIPTION = + AttributeKey.stringKey("gen_ai.tool.description"); + private static final AttributeKey GEN_AI_TOOL_NAME = + AttributeKey.stringKey("gen_ai.tool.name"); + private static final AttributeKey GEN_AI_TOOL_TYPE = + AttributeKey.stringKey("gen_ai.tool.type"); + private static final AttributeKey GEN_AI_REQUEST_MODEL = + AttributeKey.stringKey("gen_ai.request.model"); + private static final AttributeKey GEN_AI_REQUEST_TOP_P = + AttributeKey.doubleKey("gen_ai.request.top_p"); + private static final AttributeKey GEN_AI_REQUEST_MAX_TOKENS = + AttributeKey.longKey("gen_ai.request.max_tokens"); + private static final AttributeKey GEN_AI_USAGE_INPUT_TOKENS = + AttributeKey.longKey("gen_ai.usage.input_tokens"); + private static final AttributeKey GEN_AI_USAGE_OUTPUT_TOKENS = + AttributeKey.longKey("gen_ai.usage.output_tokens"); + + private static final AttributeKey ADK_TOOL_CALL_ARGS = + AttributeKey.stringKey("gcp.vertex.agent.tool_call_args"); + private static final AttributeKey ADK_LLM_REQUEST = + AttributeKey.stringKey("gcp.vertex.agent.llm_request"); + private static final AttributeKey ADK_LLM_RESPONSE = + AttributeKey.stringKey("gcp.vertex.agent.llm_response"); + private static final AttributeKey ADK_INVOCATION_ID = + AttributeKey.stringKey("gcp.vertex.agent.invocation_id"); + private static final AttributeKey ADK_EVENT_ID = + AttributeKey.stringKey("gcp.vertex.agent.event_id"); + private static final AttributeKey ADK_TOOL_RESPONSE = + AttributeKey.stringKey("gcp.vertex.agent.tool_response"); + private static final AttributeKey ADK_SESSION_ID = + AttributeKey.stringKey("gcp.vertex.agent.session_id"); + private static final AttributeKey ADK_DATA = + AttributeKey.stringKey("gcp.vertex.agent.data"); + + private static final TypeReference> MAP_TYPE_REFERENCE = + new TypeReference>() {}; + @SuppressWarnings("NonFinalStaticField") - private static Tracer tracer = GlobalOpenTelemetry.getTracer("com.google.adk"); + private static Tracer tracer = GlobalOpenTelemetry.getTracer("gcp.vertex.agent"); private static final boolean CAPTURE_MESSAGE_CONTENT_IN_SPANS = Boolean.parseBoolean( @@ -66,52 +122,105 @@ public static void setTracerForTesting(Tracer tracer) { Tracing.tracer = tracer; } + /** + * Sets span attributes immediately available on agent invocation according to OTEL semconv + * version 1.37. + * + * @param span Span on which attributes are set. + * @param agentName Agent name from which attributes are gathered. + * @param agentDescription Agent description from which attributes are gathered. + * @param invocationContext InvocationContext from which attributes are gathered. + */ + public static void traceAgentInvocation( + Span span, String agentName, String agentDescription, InvocationContext invocationContext) { + span.setAttribute(GEN_AI_OPERATION_NAME, "invoke_agent"); + span.setAttribute(GEN_AI_AGENT_DESCRIPTION, agentDescription); + span.setAttribute(GEN_AI_AGENT_NAME, agentName); + if (invocationContext.session() != null && invocationContext.session().id() != null) { + span.setAttribute(GEN_AI_CONVERSATION_ID, invocationContext.session().id()); + } + } + /** * Traces tool call arguments. * * @param args The arguments to the tool call. */ - public static void traceToolCall(Map args) { + public static void traceToolCall( + String toolName, String toolDescription, String toolType, Map args) { Span span = Span.current(); if (span == null || !span.getSpanContext().isValid()) { log.trace("traceToolCall: No valid span in current context."); return; } - span.setAttribute("gen_ai.system", "com.google.adk"); - try { - span.setAttribute("adk.tool_call_args", JsonBaseModel.getMapper().writeValueAsString(args)); - } catch (JsonProcessingException e) { - log.warn("traceToolCall: Failed to serialize tool call args to JSON", e); + span.setAttribute(GEN_AI_OPERATION_NAME, "execute_tool"); + span.setAttribute(GEN_AI_TOOL_NAME, toolName); + span.setAttribute(GEN_AI_TOOL_DESCRIPTION, toolDescription); + span.setAttribute(GEN_AI_TOOL_TYPE, toolType); + if (CAPTURE_MESSAGE_CONTENT_IN_SPANS) { + try { + span.setAttribute(ADK_TOOL_CALL_ARGS, JsonBaseModel.getMapper().writeValueAsString(args)); + } catch (JsonProcessingException e) { + log.warn("traceToolCall: Failed to serialize tool call args to JSON", e); + } + } else { + span.setAttribute(ADK_TOOL_CALL_ARGS, "{}"); } + span.setAttribute(ADK_LLM_REQUEST, "{}"); + span.setAttribute(ADK_LLM_RESPONSE, "{}"); } /** * Traces tool response event. * - * @param invocationContext The invocation context for the current agent run. * @param eventId The ID of the event. * @param functionResponseEvent The function response event. */ - public static void traceToolResponse( - InvocationContext invocationContext, String eventId, Event functionResponseEvent) { + public static void traceToolResponse(String eventId, Event functionResponseEvent) { Span span = Span.current(); if (span == null || !span.getSpanContext().isValid()) { log.trace("traceToolResponse: No valid span in current context."); return; } - span.setAttribute("gen_ai.system", "com.google.adk"); - span.setAttribute("adk.invocation_id", invocationContext.invocationId()); - span.setAttribute("adk.event_id", eventId); - span.setAttribute("adk.tool_response", functionResponseEvent.toJson()); + span.setAttribute(GEN_AI_OPERATION_NAME, "execute_tool"); + span.setAttribute(ADK_EVENT_ID, eventId); - // Setting empty llm request and response (as the AdkDevServer UI expects these) - span.setAttribute("adk.llm_request", "{}"); - span.setAttribute("adk.llm_response", "{}"); - if (invocationContext.session() != null && invocationContext.session().id() != null) { - span.setAttribute("adk.session_id", invocationContext.session().id()); + String toolCallId = ""; + Object toolResponse = ""; + + Optional optionalFunctionResponse = + functionResponseEvent.functionResponses().stream().findFirst(); + + if (optionalFunctionResponse.isPresent()) { + FunctionResponse functionResponse = optionalFunctionResponse.get(); + toolCallId = functionResponse.id().orElse(toolCallId); + if (functionResponse.response().isPresent()) { + toolResponse = functionResponse.response().get(); + } + } + span.setAttribute(GEN_AI_TOOL_CALL_ID, toolCallId); + + if (!(toolResponse instanceof Map)) { + toolResponse = ImmutableMap.of("result", toolResponse); } + + if (CAPTURE_MESSAGE_CONTENT_IN_SPANS) { + try { + span.setAttribute( + ADK_TOOL_RESPONSE, JsonBaseModel.getMapper().writeValueAsString(toolResponse)); + } catch (JsonProcessingException e) { + log.warn("traceToolResponse: Failed to serialize tool response to JSON", e); + span.setAttribute(ADK_TOOL_RESPONSE, "{\"error\": \"serialization failed\"}"); + } + } else { + span.setAttribute(ADK_TOOL_RESPONSE, "{}"); + } + + // Setting empty llm request and response (as the AdkDevServer UI expects these) + span.setAttribute(ADK_LLM_REQUEST, "{}"); + span.setAttribute(ADK_LLM_RESPONSE, "{}"); } /** @@ -162,32 +271,62 @@ public static void traceCallLlm( return; } - span.setAttribute("gen_ai.system", "com.google.adk"); - llmRequest.model().ifPresent(modelName -> span.setAttribute("gen_ai.request.model", modelName)); - span.setAttribute("adk.invocation_id", invocationContext.invocationId()); - span.setAttribute("adk.event_id", eventId); + span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); + llmRequest.model().ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); + span.setAttribute(ADK_INVOCATION_ID, invocationContext.invocationId()); + span.setAttribute(ADK_EVENT_ID, eventId); if (invocationContext.session() != null && invocationContext.session().id() != null) { - span.setAttribute("adk.session_id", invocationContext.session().id()); + span.setAttribute(ADK_SESSION_ID, invocationContext.session().id()); } else { log.trace( "traceCallLlm: InvocationContext session or session ID is null, cannot set" - + " adk.session_id"); + + " gcp.vertex.agent.session_id"); } if (CAPTURE_MESSAGE_CONTENT_IN_SPANS) { try { span.setAttribute( - "adk.llm_request", + ADK_LLM_REQUEST, JsonBaseModel.getMapper().writeValueAsString(buildLlmRequestForTrace(llmRequest))); - span.setAttribute("adk.llm_response", llmResponse.toJson()); + span.setAttribute(ADK_LLM_RESPONSE, llmResponse.toJson()); } catch (JsonProcessingException e) { log.warn("traceCallLlm: Failed to serialize LlmRequest or LlmResponse to JSON", e); } } else { - span.setAttribute("adk.llm_request", "{}"); - span.setAttribute("adk.llm_response", "{}"); + span.setAttribute(ADK_LLM_REQUEST, "{}"); + span.setAttribute(ADK_LLM_RESPONSE, "{}"); } + llmRequest + .config() + .ifPresent( + config -> { + config + .topP() + .ifPresent(topP -> span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue())); + config + .maxOutputTokens() + .ifPresent( + maxTokens -> + span.setAttribute(GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue())); + }); + llmResponse + .usageMetadata() + .ifPresent( + usage -> { + usage + .promptTokenCount() + .ifPresent(tokens -> span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens)); + usage + .candidatesTokenCount() + .ifPresent( + tokens -> span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens)); + }); + llmResponse + .finishReason() + .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT)) + .ifPresent( + reason -> span.setAttribute(GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason))); } /** @@ -205,29 +344,27 @@ public static void traceSendData( return; } - span.setAttribute("adk.invocation_id", invocationContext.invocationId()); + span.setAttribute(ADK_INVOCATION_ID, invocationContext.invocationId()); if (eventId != null && !eventId.isEmpty()) { - span.setAttribute("adk.event_id", eventId); + span.setAttribute(ADK_EVENT_ID, eventId); } if (invocationContext.session() != null && invocationContext.session().id() != null) { - span.setAttribute("adk.session_id", invocationContext.session().id()); + span.setAttribute(ADK_SESSION_ID, invocationContext.session().id()); } - - try { - List> dataList = new ArrayList<>(); - if (data != null) { - for (Content content : data) { - if (content != null) { - dataList.add( - JsonBaseModel.getMapper() - .convertValue(content, new TypeReference>() {})); - } - } + if (CAPTURE_MESSAGE_CONTENT_IN_SPANS) { + try { + ImmutableList> dataList = + Optional.ofNullable(data).orElse(ImmutableList.of()).stream() + .filter(content -> content != null) + .map(content -> JsonBaseModel.getMapper().convertValue(content, MAP_TYPE_REFERENCE)) + .collect(toImmutableList()); + span.setAttribute(ADK_DATA, JsonBaseModel.toJsonString(dataList)); + } catch (IllegalStateException e) { + log.warn("traceSendData: Failed to serialize data to JSON", e); } - span.setAttribute("adk.data", JsonBaseModel.toJsonString(dataList)); - } catch (IllegalStateException e) { - log.warn("traceSendData: Failed to serialize data to JSON", e); + } else { + span.setAttribute(ADK_DATA, "{}"); } } diff --git a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java index bf68f905c..5e2fa5792 100644 --- a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java @@ -94,6 +94,21 @@ public void rootAgent_returnsRootAgent() { assertThat(subSubAgent.rootAgent()).isEqualTo(agent); assertThat(subAgent.rootAgent()).isEqualTo(agent); assertThat(agent.rootAgent()).isEqualTo(agent); + assertThat(subSubAgent.parentAgent()).isEqualTo(subAgent); + assertThat(subAgent.parentAgent()).isEqualTo(agent); + assertThat(agent.parentAgent()).isNull(); + } + + @Test + public void subAgents_returnsSubAgents() { + TestBaseAgent subAgent1 = + new TestBaseAgent("subAgent1", "subAgent1", null, ImmutableList.of(), null, null); + TestBaseAgent subAgent2 = + new TestBaseAgent("subAgent2", "subAgent2", null, ImmutableList.of(), null, null); + TestBaseAgent agent = + new TestBaseAgent( + "agent", "description", null, ImmutableList.of(subAgent1, subAgent2), null, null); + assertThat(agent.subAgents()).containsExactly(subAgent1, subAgent2).inOrder(); } @Test @@ -354,6 +369,199 @@ public void runLive_invokesRunLiveImpl() { assertThat(runLiveCallback.wasCalled()).isTrue(); } + @Test + public void + runLive_beforeAgentCallbackReturnsContent_endsInvocationAndSkipsRunLiveImplAndAfterCallback() { + var runLiveImpl = TestCallback.returningEmpty(); + Content callbackContent = Content.fromParts(Part.fromText("before_callback_output")); + var beforeCallback = TestCallback.returning(callbackContent); + var afterCallback = TestCallback.returningEmpty(); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of(beforeCallback.asBeforeAgentCallback()), + ImmutableList.of(afterCallback.asAfterAgentCallback()), + runLiveImpl.asRunLiveImplSupplier("main_output")); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runLive(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(1); + assertThat(results.get(0).content()).hasValue(callbackContent); + assertThat(runLiveImpl.wasCalled()).isFalse(); + assertThat(beforeCallback.wasCalled()).isTrue(); + assertThat(afterCallback.wasCalled()).isFalse(); + } + + @Test + public void runLive_firstBeforeCallbackReturnsContent_skipsSecondBeforeCallback() { + Content callbackContent = Content.fromParts(Part.fromText("before_callback_output")); + var beforeCallback1 = TestCallback.returning(callbackContent); + var beforeCallback2 = TestCallback.returningEmpty(); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of( + beforeCallback1.asBeforeAgentCallback(), beforeCallback2.asBeforeAgentCallback()), + ImmutableList.of(), + TestCallback.returningEmpty().asRunLiveImplSupplier("main_output")); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + var unused = agent.runLive(invocationContext).toList().blockingGet(); + assertThat(beforeCallback1.wasCalled()).isTrue(); + assertThat(beforeCallback2.wasCalled()).isFalse(); + } + + @Test + public void + runLive_beforeCallbackReturnsEmptyAndAfterCallbackReturnsEmpty_invokesRunLiveImplAndAfterCallbacks() { + var runLiveImpl = TestCallback.returningEmpty(); + Content runLiveImplContent = Content.fromParts(Part.fromText("main_output")); + var beforeCallback = TestCallback.returningEmpty(); + var afterCallback = TestCallback.returningEmpty(); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of(beforeCallback.asBeforeAgentCallback()), + ImmutableList.of(afterCallback.asAfterAgentCallback()), + runLiveImpl.asRunLiveImplSupplier(runLiveImplContent)); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runLive(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(1); + assertThat(results.get(0).content()).hasValue(runLiveImplContent); + assertThat(runLiveImpl.wasCalled()).isTrue(); + assertThat(beforeCallback.wasCalled()).isTrue(); + assertThat(afterCallback.wasCalled()).isTrue(); + } + + @Test + public void + runLive_afterCallbackReturnsContent_invokesRunLiveImplAndAfterCallbacksAndReturnsAllContent() { + var runLiveImpl = TestCallback.returningEmpty(); + Content runLiveImplContent = Content.fromParts(Part.fromText("main_output")); + Content afterCallbackContent = Content.fromParts(Part.fromText("after_callback_output")); + var beforeCallback = TestCallback.returningEmpty(); + var afterCallback = TestCallback.returning(afterCallbackContent); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of(beforeCallback.asBeforeAgentCallback()), + ImmutableList.of(afterCallback.asAfterAgentCallback()), + runLiveImpl.asRunLiveImplSupplier(runLiveImplContent)); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runLive(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(2); + assertThat(results.get(0).content()).hasValue(runLiveImplContent); + assertThat(results.get(1).content()).hasValue(afterCallbackContent); + assertThat(runLiveImpl.wasCalled()).isTrue(); + assertThat(beforeCallback.wasCalled()).isTrue(); + assertThat(afterCallback.wasCalled()).isTrue(); + } + + @Test + public void + runLive_beforeCallbackMutatesStateAndReturnsEmpty_invokesRunLiveImplAndReturnsStateEvent() { + var runLiveImpl = TestCallback.returningEmpty(); + Content runLiveImplContent = Content.fromParts(Part.fromText("main_output")); + BeforeAgentCallback beforeCallback = + new BeforeAgentCallback() { + @Override + public Maybe call(CallbackContext context) { + context.state().put("key", "value"); + return Maybe.empty(); + } + }; + var afterCallback = TestCallback.returningEmpty(); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of(beforeCallback), + ImmutableList.of(afterCallback.asAfterAgentCallback()), + runLiveImpl.asRunLiveImplSupplier(runLiveImplContent)); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runLive(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(2); + // State event from before callback + assertThat(results.get(0).content()).isEmpty(); + assertThat(results.get(0).actions().stateDelta()).containsEntry("key", "value"); + // Content event from runLiveImpl + assertThat(results.get(1).content()).hasValue(runLiveImplContent); + assertThat(runLiveImpl.wasCalled()).isTrue(); + assertThat(afterCallback.wasCalled()).isTrue(); + } + + @Test + public void + runLive_afterCallbackMutatesStateAndReturnsEmpty_invokesRunLiveImplAndReturnsStateEvent() { + var runLiveImpl = TestCallback.returningEmpty(); + Content runLiveImplContent = Content.fromParts(Part.fromText("main_output")); + var beforeCallback = TestCallback.returningEmpty(); + AfterAgentCallback afterCallback = + new AfterAgentCallback() { + @Override + public Maybe call(CallbackContext context) { + context.state().put("key", "value"); + return Maybe.empty(); + } + }; + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of(beforeCallback.asBeforeAgentCallback()), + ImmutableList.of(afterCallback), + runLiveImpl.asRunLiveImplSupplier(runLiveImplContent)); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runLive(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(2); + // Content event from runLiveImpl + assertThat(results.get(0).content()).hasValue(runLiveImplContent); + // State event from after callback + assertThat(results.get(1).content()).isEmpty(); + assertThat(results.get(1).actions().stateDelta()).containsEntry("key", "value"); + assertThat(runLiveImpl.wasCalled()).isTrue(); + assertThat(beforeCallback.wasCalled()).isTrue(); + } + + @Test + public void runLive_firstAfterCallbackReturnsContent_skipsSecondAfterCallback() { + var runLiveImpl = TestCallback.returningEmpty(); + Content runLiveImplContent = Content.fromParts(Part.fromText("main_output")); + Content afterCallbackContent = Content.fromParts(Part.fromText("after_callback_output")); + var afterCallback1 = TestCallback.returning(afterCallbackContent); + var afterCallback2 = TestCallback.returningEmpty(); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of(), + ImmutableList.of( + afterCallback1.asAfterAgentCallback(), afterCallback2.asAfterAgentCallback()), + runLiveImpl.asRunLiveImplSupplier(runLiveImplContent)); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runLive(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(2); + assertThat(results.get(0).content()).hasValue(runLiveImplContent); + assertThat(results.get(1).content()).hasValue(afterCallbackContent); + assertThat(runLiveImpl.wasCalled()).isTrue(); + assertThat(afterCallback1.wasCalled()).isTrue(); + assertThat(afterCallback2.wasCalled()).isFalse(); + } + @Test public void constructor_invalidName_throwsIllegalArgumentException() { assertThrows( @@ -379,6 +587,12 @@ public void constructor_duplicateSubAgentNames_throwsIllegalArgumentException() "agent", "description", null, ImmutableList.of(subAgent1, subAgent2), null, null)); } + @Test + @SuppressWarnings("DoNotCall") + public void fromConfig_throwsUnsupportedOperationException() { + assertThrows(UnsupportedOperationException.class, () -> BaseAgent.fromConfig(null, null)); + } + @Test public void close_noSubAgents_completesSuccessfully() { ClosableTestAgent agent = new ClosableTestAgent("agent", "description", ImmutableList.of()); diff --git a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java index 438522493..7d493c526 100644 --- a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java +++ b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java @@ -16,22 +16,39 @@ package com.google.adk.telemetry; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; - -import io.opentelemetry.api.GlobalOpenTelemetry; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.events.Event; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.adk.sessions.Session; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Content; +import com.google.genai.types.FinishReason; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.GenerateContentResponseUsageMetadata; +import com.google.genai.types.Part; +import io.opentelemetry.api.common.AttributeKey; +import io.opentelemetry.api.common.Attributes; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; import io.opentelemetry.context.Scope; -import io.opentelemetry.sdk.OpenTelemetrySdk; -import io.opentelemetry.sdk.testing.exporter.InMemorySpanExporter; -import io.opentelemetry.sdk.trace.SdkTracerProvider; +import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; import io.opentelemetry.sdk.trace.data.SpanData; -import io.opentelemetry.sdk.trace.export.SimpleSpanProcessor; +import io.reactivex.rxjava3.core.Flowable; import java.util.List; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; /** * Tests for OpenTelemetry context propagation in ADK. @@ -39,31 +56,28 @@ *

Verifies that spans created by ADK properly link to parent contexts when available, enabling * proper distributed tracing across async boundaries. */ -class ContextPropagationTest { +@RunWith(JUnit4.class) +public class ContextPropagationTest { + @Rule public final OpenTelemetryRule openTelemetryRule = OpenTelemetryRule.create(); - private InMemorySpanExporter spanExporter; private Tracer tracer; + private Tracer originalTracer; + + @Before + public void setup() { + this.originalTracer = Tracing.getTracer(); + Tracing.setTracerForTesting( + openTelemetryRule.getOpenTelemetry().getTracer("ContextPropagationTest")); + tracer = openTelemetryRule.getOpenTelemetry().getTracer("test"); + } - @BeforeEach - void setup() { - // Reset GlobalOpenTelemetry state - GlobalOpenTelemetry.resetForTest(); - - spanExporter = InMemorySpanExporter.create(); - - SdkTracerProvider tracerProvider = - SdkTracerProvider.builder() - .addSpanProcessor(SimpleSpanProcessor.create(spanExporter)) - .build(); - - OpenTelemetrySdk sdk = - OpenTelemetrySdk.builder().setTracerProvider(tracerProvider).buildAndRegisterGlobal(); - - tracer = sdk.getTracer("test"); + @After + public void tearDown() { + Tracing.setTracerForTesting(originalTracer); } @Test - void testToolCallSpanLinksToParent() { + public void testToolCallSpanLinksToParent() { // Given: Parent span is active Span parentSpan = tracer.spanBuilder("parent").startSpan(); @@ -82,8 +96,8 @@ void testToolCallSpanLinksToParent() { } // Then: tool_call should be child of parent - List spans = spanExporter.getFinishedSpanItems(); - assertEquals(2, spans.size(), "Should have 2 spans: parent and tool_call"); + List spans = openTelemetryRule.getSpans(); + assertEquals("Should have 2 spans: parent and tool_call", 2, spans.size()); SpanData parentSpanData = spans.stream() @@ -99,41 +113,43 @@ void testToolCallSpanLinksToParent() { // Verify parent-child relationship assertEquals( + "Tool call should have same trace ID as parent", parentSpanData.getSpanContext().getTraceId(), - toolCallSpanData.getSpanContext().getTraceId(), - "Tool call should have same trace ID as parent"); + toolCallSpanData.getSpanContext().getTraceId()); assertEquals( + "Tool call's parent should be the parent span", parentSpanData.getSpanContext().getSpanId(), - toolCallSpanData.getParentSpanContext().getSpanId(), - "Tool call's parent should be the parent span"); + toolCallSpanData.getParentSpanContext().getSpanId()); } @Test - void testToolCallWithoutParentCreatesRootSpan() { + public void testToolCallWithoutParentCreatesRootSpan() { // Given: No parent span active // When: ADK creates tool_call span with setParent(Context.current()) - Span toolCallSpan = - tracer.spanBuilder("tool_call [testTool]").setParent(Context.current()).startSpan(); + try (Scope s = Context.root().makeCurrent()) { + Span toolCallSpan = + tracer.spanBuilder("tool_call [testTool]").setParent(Context.current()).startSpan(); - try (Scope scope = toolCallSpan.makeCurrent()) { - // Work - } finally { - toolCallSpan.end(); + try (Scope scope = toolCallSpan.makeCurrent()) { + // Work + } finally { + toolCallSpan.end(); + } } // Then: Should create root span (backward compatible) - List spans = spanExporter.getFinishedSpanItems(); - assertEquals(1, spans.size(), "Should have exactly 1 span"); + List spans = openTelemetryRule.getSpans(); + assertEquals("Should have exactly 1 span", 1, spans.size()); SpanData toolCallSpanData = spans.get(0); assertFalse( - toolCallSpanData.getParentSpanContext().isValid(), - "Tool call should be root span when no parent exists"); + "Tool call should be root span when no parent exists", + toolCallSpanData.getParentSpanContext().isValid()); } @Test - void testNestedSpanHierarchy() { + public void testNestedSpanHierarchy() { // Test: parent → invocation → tool_call → tool_response hierarchy Span parentSpan = tracer.spanBuilder("parent").startSpan(); @@ -168,8 +184,8 @@ void testNestedSpanHierarchy() { } // Verify complete hierarchy - List spans = spanExporter.getFinishedSpanItems(); - assertEquals(4, spans.size(), "Should have 4 spans in the hierarchy"); + List spans = openTelemetryRule.getSpans(); + assertEquals("Should have 4 spans in the hierarchy", 4, spans.size()); String parentTraceId = spans.stream() @@ -182,9 +198,9 @@ void testNestedSpanHierarchy() { spans.forEach( span -> assertEquals( + "All spans should be in same trace", parentTraceId, - span.getSpanContext().getTraceId(), - "All spans should be in same trace")); + span.getSpanContext().getTraceId())); // Verify parent-child relationships SpanData parentSpanData = findSpanByName(spans, "parent"); @@ -194,25 +210,25 @@ void testNestedSpanHierarchy() { // invocation should be child of parent assertEquals( + "Invocation should be child of parent", parentSpanData.getSpanContext().getSpanId(), - invocationSpanData.getParentSpanContext().getSpanId(), - "Invocation should be child of parent"); + invocationSpanData.getParentSpanContext().getSpanId()); // tool_call should be child of invocation assertEquals( + "Tool call should be child of invocation", invocationSpanData.getSpanContext().getSpanId(), - toolCallSpanData.getParentSpanContext().getSpanId(), - "Tool call should be child of invocation"); + toolCallSpanData.getParentSpanContext().getSpanId()); // tool_response should be child of tool_call assertEquals( + "Tool response should be child of tool call", toolCallSpanData.getSpanContext().getSpanId(), - toolResponseSpanData.getParentSpanContext().getSpanId(), - "Tool response should be child of tool call"); + toolResponseSpanData.getParentSpanContext().getSpanId()); } @Test - void testMultipleSpansInParallel() { + public void testMultipleSpansInParallel() { // Test: Multiple tool calls in parallel should all link to same parent Span parentSpan = tracer.spanBuilder("parent").startSpan(); @@ -234,8 +250,8 @@ void testMultipleSpansInParallel() { } // Verify all tool calls link to same parent - List spans = spanExporter.getFinishedSpanItems(); - assertEquals(4, spans.size(), "Should have 4 spans: 1 parent + 3 tool calls"); + List spans = openTelemetryRule.getSpans(); + assertEquals("Should have 4 spans: 1 parent + 3 tool calls", 4, spans.size()); SpanData parentSpanData = findSpanByName(spans, "parent"); String parentTraceId = parentSpanData.getSpanContext().getTraceId(); @@ -245,59 +261,59 @@ void testMultipleSpansInParallel() { List toolCallSpans = spans.stream().filter(s -> s.getName().startsWith("tool_call")).toList(); - assertEquals(3, toolCallSpans.size(), "Should have 3 tool call spans"); + assertEquals("Should have 3 tool call spans", 3, toolCallSpans.size()); toolCallSpans.forEach( span -> { assertEquals( + "Tool call should have same trace ID as parent", parentTraceId, - span.getSpanContext().getTraceId(), - "Tool call should have same trace ID as parent"); + span.getSpanContext().getTraceId()); assertEquals( + "Tool call should have parent as parent span", parentSpanId, - span.getParentSpanContext().getSpanId(), - "Tool call should have parent as parent span"); + span.getParentSpanContext().getSpanId()); }); } @Test - void testAgentRunSpanLinksToInvocation() { - // Test: agent_run span should link to invocation span + public void testInvokeAgentSpanLinksToInvocation() { + // Test: invoke_agent span should link to invocation span Span invocationSpan = tracer.spanBuilder("invocation").startSpan(); try (Scope invocationScope = invocationSpan.makeCurrent()) { - Span agentRunSpan = - tracer.spanBuilder("agent_run [test-agent]").setParent(Context.current()).startSpan(); + Span invokeAgentSpan = + tracer.spanBuilder("invoke_agent test-agent").setParent(Context.current()).startSpan(); - try (Scope agentScope = agentRunSpan.makeCurrent()) { + try (Scope agentScope = invokeAgentSpan.makeCurrent()) { // Simulate agent work } finally { - agentRunSpan.end(); + invokeAgentSpan.end(); } } finally { invocationSpan.end(); } - List spans = spanExporter.getFinishedSpanItems(); - assertEquals(2, spans.size(), "Should have 2 spans: invocation and agent_run"); + List spans = openTelemetryRule.getSpans(); + assertEquals("Should have 2 spans: invocation and invoke_agent", 2, spans.size()); SpanData invocationSpanData = findSpanByName(spans, "invocation"); - SpanData agentRunSpanData = findSpanByName(spans, "agent_run [test-agent]"); + SpanData invokeAgentSpanData = findSpanByName(spans, "invoke_agent test-agent"); assertEquals( + "Agent run should be child of invocation", invocationSpanData.getSpanContext().getSpanId(), - agentRunSpanData.getParentSpanContext().getSpanId(), - "Agent run should be child of invocation"); + invokeAgentSpanData.getParentSpanContext().getSpanId()); } @Test - void testCallLlmSpanLinksToAgentRun() { + public void testCallLlmSpanLinksToAgentRun() { // Test: call_llm span should link to agent_run span - Span agentRunSpan = tracer.spanBuilder("agent_run [test-agent]").startSpan(); + Span invokeAgentSpan = tracer.spanBuilder("invoke_agent test-agent").startSpan(); - try (Scope agentScope = agentRunSpan.makeCurrent()) { + try (Scope agentScope = invokeAgentSpan.makeCurrent()) { Span callLlmSpan = tracer.spanBuilder("call_llm").setParent(Context.current()).startSpan(); try (Scope llmScope = callLlmSpan.makeCurrent()) { @@ -306,43 +322,76 @@ void testCallLlmSpanLinksToAgentRun() { callLlmSpan.end(); } } finally { - agentRunSpan.end(); + invokeAgentSpan.end(); } - List spans = spanExporter.getFinishedSpanItems(); - assertEquals(2, spans.size(), "Should have 2 spans: agent_run and call_llm"); + List spans = openTelemetryRule.getSpans(); + assertEquals("Should have 2 spans: invoke_agent and call_llm", 2, spans.size()); - SpanData agentRunSpanData = findSpanByName(spans, "agent_run [test-agent]"); + SpanData invokeAgentSpanData = findSpanByName(spans, "invoke_agent test-agent"); SpanData callLlmSpanData = findSpanByName(spans, "call_llm"); assertEquals( - agentRunSpanData.getSpanContext().getSpanId(), - callLlmSpanData.getParentSpanContext().getSpanId(), - "Call LLM should be child of agent run"); + "Call LLM should be child of agent run", + invokeAgentSpanData.getSpanContext().getSpanId(), + callLlmSpanData.getParentSpanContext().getSpanId()); } @Test - void testSpanCreatedWithinParentScopeIsCorrectlyParented() { + public void testSpanCreatedWithinParentScopeIsCorrectlyParented() { // Test: Simulates creating a span within the scope of a parent Span parentSpan = tracer.spanBuilder("invocation").startSpan(); try (Scope scope = parentSpan.makeCurrent()) { - Span agentSpan = tracer.spanBuilder("agent_run").setParent(Context.current()).startSpan(); + Span agentSpan = tracer.spanBuilder("invoke_agent").setParent(Context.current()).startSpan(); agentSpan.end(); } finally { parentSpan.end(); } - List spans = spanExporter.getFinishedSpanItems(); - assertEquals(2, spans.size(), "Should have 2 spans"); + List spans = openTelemetryRule.getSpans(); + assertEquals("Should have 2 spans", 2, spans.size()); SpanData parentSpanData = findSpanByName(spans, "invocation"); - SpanData agentSpanData = findSpanByName(spans, "agent_run"); + SpanData agentSpanData = findSpanByName(spans, "invoke_agent"); + + assertEquals( + "Agent span should be a child of the invocation span", + parentSpanData.getSpanContext().getSpanId(), + agentSpanData.getParentSpanContext().getSpanId()); + } + @Test + public void testTraceFlowable() throws InterruptedException { + Span parentSpan = tracer.spanBuilder("parent").startSpan(); + try (Scope s = parentSpan.makeCurrent()) { + Span flowableSpan = tracer.spanBuilder("flowable").setParent(Context.current()).startSpan(); + Flowable flowable = + Tracing.traceFlowable( + Context.current().with(flowableSpan), + flowableSpan, + () -> + Flowable.just(1, 2, 3) + .map( + i -> { + assertEquals( + flowableSpan.getSpanContext().getSpanId(), + Span.current().getSpanContext().getSpanId()); + return i * 2; + })); + flowable.test().await().assertComplete(); + } finally { + parentSpan.end(); + } + + List spans = openTelemetryRule.getSpans(); + assertEquals(2, spans.size()); + SpanData parentSpanData = findSpanByName(spans, "parent"); + SpanData flowableSpanData = findSpanByName(spans, "flowable"); assertEquals( parentSpanData.getSpanContext().getSpanId(), - agentSpanData.getParentSpanContext().getSpanId(), - "Agent span should be a child of the invocation span"); + flowableSpanData.getParentSpanContext().getSpanId()); + assertTrue(flowableSpanData.hasEnded()); } private SpanData findSpanByName(List spans, String name) { @@ -351,4 +400,164 @@ private SpanData findSpanByName(List spans, String name) { .findFirst() .orElseThrow(() -> new AssertionError("Span not found: " + name)); } + + @Test + public void testTraceAgentInvocation() { + Span span = tracer.spanBuilder("test").startSpan(); + try (Scope scope = span.makeCurrent()) { + Tracing.traceAgentInvocation( + span, + "test-agent", + "test-description", + InvocationContext.builder() + .invocationId("inv-1") + .session(Session.builder("session-1").build()) + .build()); + } finally { + span.end(); + } + List spans = openTelemetryRule.getSpans(); + assertEquals(1, spans.size()); + SpanData spanData = spans.get(0); + Attributes attrs = spanData.getAttributes(); + assertEquals("invoke_agent", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); + assertEquals("test-agent", attrs.get(AttributeKey.stringKey("gen_ai.agent.name"))); + assertEquals("test-description", attrs.get(AttributeKey.stringKey("gen_ai.agent.description"))); + assertEquals("session-1", attrs.get(AttributeKey.stringKey("gen_ai.conversation.id"))); + } + + @Test + public void testTraceToolCall() { + Span span = tracer.spanBuilder("test").startSpan(); + try (Scope scope = span.makeCurrent()) { + Tracing.traceToolCall( + "tool-name", "tool-description", "tool-type", ImmutableMap.of("arg1", "value1")); + } finally { + span.end(); + } + List spans = openTelemetryRule.getSpans(); + assertEquals(1, spans.size()); + SpanData spanData = spans.get(0); + Attributes attrs = spanData.getAttributes(); + assertEquals("execute_tool", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); + assertEquals("tool-name", attrs.get(AttributeKey.stringKey("gen_ai.tool.name"))); + assertEquals("tool-description", attrs.get(AttributeKey.stringKey("gen_ai.tool.description"))); + assertEquals("tool-type", attrs.get(AttributeKey.stringKey("gen_ai.tool.type"))); + assertEquals( + "{\"arg1\":\"value1\"}", + attrs.get(AttributeKey.stringKey("gcp.vertex.agent.tool_call_args"))); + assertEquals("{}", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.llm_request"))); + assertEquals("{}", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.llm_response"))); + } + + @Test + public void testTraceToolResponse() { + Span span = tracer.spanBuilder("test").startSpan(); + try (Scope scope = span.makeCurrent()) { + Event functionResponseEvent = + Event.builder() + .id("event-1") + .content( + Content.fromParts( + Part.builder() + .functionResponse( + FunctionResponse.builder() + .name("tool-name") + .id("tool-call-id") + .response(ImmutableMap.of("result", "tool-result")) + .build()) + .build())) + .build(); + Tracing.traceToolResponse("event-1", functionResponseEvent); + } finally { + span.end(); + } + List spans = openTelemetryRule.getSpans(); + assertEquals(1, spans.size()); + SpanData spanData = spans.get(0); + Attributes attrs = spanData.getAttributes(); + assertEquals("execute_tool", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); + assertEquals("event-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.event_id"))); + assertEquals("tool-call-id", attrs.get(AttributeKey.stringKey("gen_ai.tool_call.id"))); + assertEquals( + "{\"result\":\"tool-result\"}", + attrs.get(AttributeKey.stringKey("gcp.vertex.agent.tool_response"))); + } + + @Test + public void testTraceCallLlm() { + Span span = tracer.spanBuilder("test").startSpan(); + try (Scope scope = span.makeCurrent()) { + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-pro") + .contents(ImmutableList.of(Content.fromParts(Part.fromText("hello")))) + .config(GenerateContentConfig.builder().topP(0.9f).maxOutputTokens(100).build()) + .build(); + LlmResponse llmResponse = + LlmResponse.builder() + .content(Content.builder().parts(Part.fromText("world")).build()) + .finishReason(new FinishReason(FinishReason.Known.STOP)) + .usageMetadata( + GenerateContentResponseUsageMetadata.builder() + .promptTokenCount(10) + .candidatesTokenCount(20) + .totalTokenCount(30) + .build()) + .build(); + Tracing.traceCallLlm( + InvocationContext.builder() + .invocationId("inv-1") + .session(Session.builder("session-1").build()) + .build(), + "event-1", + llmRequest, + llmResponse); + } finally { + span.end(); + } + List spans = openTelemetryRule.getSpans(); + assertEquals(1, spans.size()); + SpanData spanData = spans.get(0); + Attributes attrs = spanData.getAttributes(); + assertEquals("gcp.vertex.agent", attrs.get(AttributeKey.stringKey("gen_ai.system"))); + assertEquals("gemini-pro", attrs.get(AttributeKey.stringKey("gen_ai.request.model"))); + assertEquals("inv-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.invocation_id"))); + assertEquals("event-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.event_id"))); + assertEquals("session-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.session_id"))); + assertEquals(0.9d, attrs.get(AttributeKey.doubleKey("gen_ai.request.top_p")), 0.01); + assertEquals(100L, (long) attrs.get(AttributeKey.longKey("gen_ai.request.max_tokens"))); + assertEquals(10L, (long) attrs.get(AttributeKey.longKey("gen_ai.usage.input_tokens"))); + assertEquals(20L, (long) attrs.get(AttributeKey.longKey("gen_ai.usage.output_tokens"))); + assertEquals( + ImmutableList.of("stop"), + attrs.get(AttributeKey.stringArrayKey("gen_ai.response.finish_reasons"))); + assertTrue( + attrs.get(AttributeKey.stringKey("gcp.vertex.agent.llm_request")).contains("gemini-pro")); + assertTrue(attrs.get(AttributeKey.stringKey("gcp.vertex.agent.llm_response")).contains("STOP")); + } + + @Test + public void testTraceSendData() { + Span span = tracer.spanBuilder("test").startSpan(); + try (Scope scope = span.makeCurrent()) { + Tracing.traceSendData( + InvocationContext.builder() + .invocationId("inv-1") + .session(Session.builder("session-1").build()) + .build(), + "event-1", + ImmutableList.of(Content.builder().role("user").parts(Part.fromText("hello")).build())); + } finally { + span.end(); + } + List spans = openTelemetryRule.getSpans(); + assertEquals(1, spans.size()); + SpanData spanData = spans.get(0); + Attributes attrs = spanData.getAttributes(); + assertEquals("inv-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.invocation_id"))); + assertEquals("event-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.event_id"))); + assertEquals("session-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.session_id"))); + assertTrue(attrs.get(AttributeKey.stringKey("gcp.vertex.agent.data")).contains("hello")); + } } diff --git a/core/src/test/java/com/google/adk/testing/TestCallback.java b/core/src/test/java/com/google/adk/testing/TestCallback.java index 434d85e6f..6f35f5a3c 100644 --- a/core/src/test/java/com/google/adk/testing/TestCallback.java +++ b/core/src/test/java/com/google/adk/testing/TestCallback.java @@ -115,6 +115,13 @@ public Supplier> asRunLiveImplSupplier(Content content) { }); } + /** + * Returns a {@link Supplier} that marks this callback as called and returns a {@link Flowable} + */ + public Supplier> asRunLiveImplSupplier(String contentText) { + return asRunLiveImplSupplier(Content.fromParts(Part.fromText(contentText))); + } + @SuppressWarnings("unchecked") // This cast is safe if T is Content. public BeforeAgentCallback asBeforeAgentCallback() { return (unusedCtx) -> (Maybe) callMaybe();