Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/src/main/java/com/google/adk/agents/BaseAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public abstract class BaseAgent {
public BaseAgent(
String name,
String description,
List<? extends BaseAgent> subAgents,
@Nullable List<? extends BaseAgent> subAgents,
@Nullable List<? extends BeforeAgentCallback> beforeAgentCallback,
@Nullable List<? extends AfterAgentCallback> afterAgentCallback) {
this.name = name;
Expand Down
20 changes: 20 additions & 0 deletions core/src/test/java/com/google/adk/agents/BaseAgentTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -316,4 +316,24 @@ public void canonicalCallbacks_returnsListWhenPresent() {
assertThat(agent.canonicalBeforeAgentCallbacks()).containsExactly(bc);
assertThat(agent.canonicalAfterAgentCallbacks()).containsExactly(ac);
}

@Test
public void runLive_invokesRunLiveImpl() {
var runLiveCallback = TestCallback.<Void>returningEmpty();
Content runLiveImplContent = Content.fromParts(Part.fromText("live_output"));
TestBaseAgent agent =
new TestBaseAgent(
TEST_AGENT_NAME,
TEST_AGENT_DESCRIPTION,
/* beforeAgentCallbacks= */ ImmutableList.of(),
/* afterAgentCallbacks= */ ImmutableList.of(),
runLiveCallback.asRunLiveImplSupplier(runLiveImplContent));
InvocationContext invocationContext = TestUtils.createInvocationContext(agent);

List<Event> results = agent.runLive(invocationContext).toList().blockingGet();

assertThat(results).hasSize(1);
assertThat(results.get(0).content()).hasValue(runLiveImplContent);
assertThat(runLiveCallback.wasCalled()).isTrue();
}
}
41 changes: 29 additions & 12 deletions core/src/test/java/com/google/adk/testing/TestCallback.java
Original file line number Diff line number Diff line change
Expand Up @@ -102,63 +102,80 @@ public Supplier<Flowable<Event>> asRunAsyncImplSupplier(String contentText) {
return asRunAsyncImplSupplier(Content.fromParts(Part.fromText(contentText)));
}

/**
* Returns a {@link Supplier} that marks this callback as called and returns a {@link Flowable}
* with an event containing the given content.
*/
public Supplier<Flowable<Event>> asRunLiveImplSupplier(Content content) {
return () ->
Flowable.defer(
() -> {
markAsCalled();
return Flowable.just(Event.builder().content(content).build());
});
}

@SuppressWarnings("unchecked") // This cast is safe if T is Content.
public BeforeAgentCallback asBeforeAgentCallback() {
return ctx -> (Maybe<Content>) callMaybe();
return (unusedCtx) -> (Maybe<Content>) callMaybe();
}

@SuppressWarnings("unchecked") // This cast is safe if T is Content.
public BeforeAgentCallbackSync asBeforeAgentCallbackSync() {
return ctx -> (Optional<Content>) callOptional();
return (unusedCtx) -> (Optional<Content>) callOptional();
}

@SuppressWarnings("unchecked") // This cast is safe if T is Content.
public AfterAgentCallback asAfterAgentCallback() {
return ctx -> (Maybe<Content>) callMaybe();
return (unusedCtx) -> (Maybe<Content>) callMaybe();
}

@SuppressWarnings("unchecked") // This cast is safe if T is Content.
public AfterAgentCallbackSync asAfterAgentCallbackSync() {
return ctx -> (Optional<Content>) callOptional();
return (unusedCtx) -> (Optional<Content>) callOptional();
}

@SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse.
public BeforeModelCallback asBeforeModelCallback() {
return (ctx, req) -> (Maybe<LlmResponse>) callMaybe();
return (unusedCtx, unusedReq) -> (Maybe<LlmResponse>) callMaybe();
}

@SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse.
public BeforeModelCallbackSync asBeforeModelCallbackSync() {
return (ctx, req) -> (Optional<LlmResponse>) callOptional();
return (unusedCtx, unusedReq) -> (Optional<LlmResponse>) callOptional();
}

@SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse.
public AfterModelCallback asAfterModelCallback() {
return (ctx, res) -> (Maybe<LlmResponse>) callMaybe();
return (unusedCtx, unusedRes) -> (Maybe<LlmResponse>) callMaybe();
}

@SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse.
public AfterModelCallbackSync asAfterModelCallbackSync() {
return (ctx, res) -> (Optional<LlmResponse>) callOptional();
return (unusedCtx, unusedRes) -> (Optional<LlmResponse>) callOptional();
}

@SuppressWarnings("unchecked") // This cast is safe if T is Map<String, Object>.
public BeforeToolCallback asBeforeToolCallback() {
return (invCtx, tool, toolArgs, toolCtx) -> (Maybe<Map<String, Object>>) callMaybe();
return (unusedCtx, unusedTool, unusedToolArgs, unusedToolCtx) ->
(Maybe<Map<String, Object>>) callMaybe();
}

@SuppressWarnings("unchecked") // This cast is safe if T is Map<String, Object>.
public BeforeToolCallbackSync asBeforeToolCallbackSync() {
return (invCtx, tool, toolArgs, toolCtx) -> (Optional<Map<String, Object>>) callOptional();
return (unusedCtx, unusedTool, unusedToolArgs, unusedToolCtx) ->
(Optional<Map<String, Object>>) callOptional();
}

@SuppressWarnings("unchecked") // This cast is safe if T is Map<String, Object>.
public AfterToolCallback asAfterToolCallback() {
return (invCtx, tool, toolArgs, toolCtx, res) -> (Maybe<Map<String, Object>>) callMaybe();
return (unusedCtx, unusedTool, unusedToolArgs, unusedToolCtx, unusedRes) ->
(Maybe<Map<String, Object>>) callMaybe();
}

@SuppressWarnings("unchecked") // This cast is safe if T is Map<String, Object>.
public AfterToolCallbackSync asAfterToolCallbackSync() {
return (invCtx, tool, toolArgs, toolCtx, res) -> (Optional<Map<String, Object>>) callOptional();
return (unusedCtx, unusedTool, unusedToolArgs, unusedToolCtx, unusedRes) ->
(Optional<Map<String, Object>>) callOptional();
}
}