From 8f2f97a031f4fb3a5cc4ace0900fb15c8da0afc0 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 27 Feb 2026 08:46:36 -0800 Subject: [PATCH] feat: Adding a SessionKey for typeSafety PiperOrigin-RevId: 876274053 --- .../adk/artifacts/BaseArtifactService.java | 39 +++++++ .../java/com/google/adk/runner/Runner.java | 31 ++++++ .../adk/sessions/BaseSessionService.java | 57 +++++++--- .../java/com/google/adk/sessions/Session.java | 22 ++++ .../com/google/adk/sessions/SessionKey.java | 82 ++++++++++++++ .../com/google/adk/runner/RunnerTest.java | 103 ++++++++++++++++++ 6 files changed, 321 insertions(+), 13 deletions(-) create mode 100644 core/src/main/java/com/google/adk/sessions/SessionKey.java diff --git a/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java b/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java index 32ef9ff4d..51b2dfb6d 100644 --- a/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java @@ -16,6 +16,7 @@ package com.google.adk.artifacts; +import com.google.adk.sessions.SessionKey; import com.google.common.collect.ImmutableList; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Completable; @@ -39,6 +40,11 @@ public interface BaseArtifactService { Single saveArtifact( String appName, String userId, String sessionId, String filename, Part artifact); + default Single saveArtifact(SessionKey sessionKey, String filename, Part artifact) { + return saveArtifact( + sessionKey.appName(), sessionKey.userId(), sessionKey.id(), filename, artifact); + } + /** * Saves an artifact and returns it with fileData if available. * @@ -58,18 +64,33 @@ default Single saveAndReloadArtifact( .flatMap(version -> loadArtifact(appName, userId, sessionId, filename, version).toSingle()); } + default Single saveAndReloadArtifact( + SessionKey sessionKey, String filename, Part artifact) { + return saveAndReloadArtifact( + sessionKey.appName(), sessionKey.userId(), sessionKey.id(), filename, artifact); + } + /** Loads the latest version of an artifact from the service. */ default Maybe loadArtifact( String appName, String userId, String sessionId, String filename) { return loadArtifact(appName, userId, sessionId, filename, Optional.empty()); } + default Maybe loadArtifact(SessionKey sessionKey, String filename) { + return loadArtifact(sessionKey.appName(), sessionKey.userId(), sessionKey.id(), filename); + } + /** Loads a specific version of an artifact from the service. */ default Maybe loadArtifact( String appName, String userId, String sessionId, String filename, int version) { return loadArtifact(appName, userId, sessionId, filename, Optional.of(version)); } + default Maybe loadArtifact(SessionKey sessionKey, String filename, int version) { + return loadArtifact( + sessionKey.appName(), sessionKey.userId(), sessionKey.id(), filename, version); + } + /** * @deprecated Use {@link #loadArtifact(String, String, String, String)} or {@link * #loadArtifact(String, String, String, String, int)} instead. @@ -78,6 +99,12 @@ default Maybe loadArtifact( Maybe loadArtifact( String appName, String userId, String sessionId, String filename, Optional version); + default Maybe loadArtifact( + SessionKey sessionKey, String filename, Optional version) { + return loadArtifact( + sessionKey.appName(), sessionKey.userId(), sessionKey.id(), filename, version); + } + /** * Lists all the artifact filenames within a session. * @@ -88,6 +115,10 @@ Maybe loadArtifact( */ Single listArtifactKeys(String appName, String userId, String sessionId); + default Single listArtifactKeys(SessionKey sessionKey) { + return listArtifactKeys(sessionKey.appName(), sessionKey.userId(), sessionKey.id()); + } + /** * Deletes an artifact. * @@ -98,6 +129,10 @@ Maybe loadArtifact( */ Completable deleteArtifact(String appName, String userId, String sessionId, String filename); + default Completable deleteArtifact(SessionKey sessionKey, String filename) { + return deleteArtifact(sessionKey.appName(), sessionKey.userId(), sessionKey.id(), filename); + } + /** * Lists all the versions (as revision IDs) of an artifact. * @@ -109,4 +144,8 @@ Maybe loadArtifact( */ Single> listVersions( String appName, String userId, String sessionId, String filename); + + default Single> listVersions(SessionKey sessionKey, String filename) { + return listVersions(sessionKey.appName(), sessionKey.userId(), sessionKey.id(), filename); + } } diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 0ddfdaea1..ea7cb80f6 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -35,6 +35,7 @@ import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; +import com.google.adk.sessions.SessionKey; import com.google.adk.summarizer.EventsCompactionConfig; import com.google.adk.summarizer.LlmEventSummarizer; import com.google.adk.summarizer.SlidingWindowEventCompactor; @@ -383,6 +384,25 @@ public Flowable runAsync( .flatMapPublisher(session -> this.runAsyncImpl(session, newMessage, runConfig, stateDelta)); } + /** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */ + public Flowable runAsync( + SessionKey sessionKey, + Content newMessage, + RunConfig runConfig, + @Nullable Map stateDelta) { + return runAsync(sessionKey.userId(), sessionKey.id(), newMessage, runConfig, stateDelta); + } + + /** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */ + public Flowable runAsync(SessionKey sessionKey, Content newMessage, RunConfig runConfig) { + return runAsync(sessionKey, newMessage, runConfig, /* stateDelta= */ null); + } + + /** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */ + public Flowable runAsync(SessionKey sessionKey, Content newMessage) { + return runAsync(sessionKey, newMessage, RunConfig.builder().build()); + } + /** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */ public Flowable runAsync(String userId, String sessionId, Content newMessage) { return runAsync(userId, sessionId, newMessage, RunConfig.builder().build()); @@ -671,6 +691,17 @@ public Flowable runLive( .flatMapPublisher(session -> this.runLive(session, liveRequestQueue, runConfig)); } + /** + * Retrieves the session and runs the agent in live mode. + * + * @return stream of events from the agent. + * @throws IllegalArgumentException if the session is not found. + */ + public Flowable runLive( + SessionKey sessionKey, LiveRequestQueue liveRequestQueue, RunConfig runConfig) { + return runLive(sessionKey.userId(), sessionKey.id(), liveRequestQueue, runConfig); + } + /** * Runs the agent asynchronously with a default user ID. * diff --git a/core/src/main/java/com/google/adk/sessions/BaseSessionService.java b/core/src/main/java/com/google/adk/sessions/BaseSessionService.java index 94e8cd7ba..610efd68e 100644 --- a/core/src/main/java/com/google/adk/sessions/BaseSessionService.java +++ b/core/src/main/java/com/google/adk/sessions/BaseSessionService.java @@ -78,6 +78,18 @@ default Single createSession( return createSession(appName, userId, ensureConcurrentMap(state), sessionId); } + /** + * Creates a new session with the specified parameters. + * + * @param sessionKey The session key containing appName, userId and sessionId. + * @param state An optional map representing the initial state of the session. Can be null or + * empty. + */ + default Single createSession( + SessionKey sessionKey, @Nullable Map state) { + return createSession(sessionKey.appName(), sessionKey.userId(), state, sessionKey.id()); + } + /** * Creates a new session with the specified application name and user ID, using a default state * (null) and allowing the service to generate a unique session ID. @@ -94,6 +106,10 @@ default Single createSession(String appName, String userId) { return createSession(appName, userId, null, null); } + default Single createSession(SessionKey sessionKey) { + return createSession(sessionKey.appName(), sessionKey.userId(), null, sessionKey.id()); + } + /** * Retrieves a specific session, optionally filtering the events included. * @@ -110,6 +126,11 @@ default Single createSession(String appName, String userId) { Maybe getSession( String appName, String userId, String sessionId, Optional config); + default Maybe getSession(SessionKey sessionKey, @Nullable GetSessionConfig config) { + return getSession( + sessionKey.appName(), sessionKey.userId(), sessionKey.id(), Optional.ofNullable(config)); + } + /** * Lists sessions associated with a specific application and user. * @@ -123,6 +144,10 @@ Maybe getSession( */ Single listSessions(String appName, String userId); + default Single listSessions(SessionKey sessionKey) { + return listSessions(sessionKey.appName(), sessionKey.userId()); + } + /** * Deletes a specific session. * @@ -134,6 +159,10 @@ Maybe getSession( */ Completable deleteSession(String appName, String userId, String sessionId); + default Completable deleteSession(SessionKey sessionKey) { + return deleteSession(sessionKey.appName(), sessionKey.userId(), sessionKey.id()); + } + /** * Lists the events within a specific session. Supports pagination via the response object. * @@ -147,6 +176,10 @@ Maybe getSession( */ Single listEvents(String appName, String userId, String sessionId); + default Single listEvents(SessionKey sessionKey) { + return listEvents(sessionKey.appName(), sessionKey.userId(), sessionKey.id()); + } + /** * Closes a session. This is currently a placeholder and may involve finalizing session state or * performing cleanup actions in future implementations. The default implementation does nothing. @@ -190,20 +223,18 @@ default Single appendEvent(Session session, Event event) { EventActions actions = event.actions(); if (actions != null) { Map stateDelta = actions.stateDelta(); - if (stateDelta != null && !stateDelta.isEmpty()) { - Map sessionState = session.state(); - if (sessionState != null) { - stateDelta.forEach( - (key, value) -> { - if (!key.startsWith(State.TEMP_PREFIX)) { - if (value == State.REMOVED) { - sessionState.remove(key); - } else { - sessionState.put(key, value); - } + Map sessionState = session.state(); + if (stateDelta != null && !stateDelta.isEmpty() && sessionState != null) { + stateDelta.forEach( + (key, value) -> { + if (!key.startsWith(State.TEMP_PREFIX)) { + if (value == State.REMOVED) { + sessionState.remove(key); + } else { + sessionState.put(key, value); } - }); - } + } + }); } } diff --git a/core/src/main/java/com/google/adk/sessions/Session.java b/core/src/main/java/com/google/adk/sessions/Session.java index 877a95220..bba687403 100644 --- a/core/src/main/java/com/google/adk/sessions/Session.java +++ b/core/src/main/java/com/google/adk/sessions/Session.java @@ -49,6 +49,10 @@ public static Builder builder(String id) { return new Builder(id); } + public static Builder builder(SessionKey sessionKey) { + return new Builder(sessionKey); + } + /** Builder for {@link Session}. */ public static final class Builder { private String id; @@ -62,6 +66,12 @@ public Builder(String id) { this.id = id; } + public Builder(SessionKey sessionKey) { + this.id = sessionKey.id(); + this.appName = sessionKey.appName(); + this.userId = sessionKey.userId(); + } + @JsonCreator private Builder() {} @@ -72,6 +82,14 @@ public Builder id(String id) { return this; } + @CanIgnoreReturnValue + public Builder sessionKey(SessionKey sessionKey) { + this.id = sessionKey.id(); + this.appName = sessionKey.appName(); + this.userId = sessionKey.userId(); + return this; + } + @CanIgnoreReturnValue public Builder state(State state) { this.state = state; @@ -130,6 +148,10 @@ public Session build() { } } + public SessionKey sessionKey() { + return new SessionKey(appName, userId, id); + } + @JsonProperty("id") public String id() { return id; diff --git a/core/src/main/java/com/google/adk/sessions/SessionKey.java b/core/src/main/java/com/google/adk/sessions/SessionKey.java new file mode 100644 index 000000000..db26b5a3a --- /dev/null +++ b/core/src/main/java/com/google/adk/sessions/SessionKey.java @@ -0,0 +1,82 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.sessions; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.adk.JsonBaseModel; +import java.util.Objects; + +/** Key for a session, composed of appName, userId and session id. */ +public final class SessionKey extends JsonBaseModel { + private final String appName; + private final String userId; + private final String id; + + @JsonCreator + public SessionKey( + @JsonProperty("appName") String appName, + @JsonProperty("userId") String userId, + @JsonProperty("id") String id) { + this.appName = appName; + this.userId = userId; + this.id = id; + } + + @JsonProperty("appName") + public String appName() { + return appName; + } + + @JsonProperty("userId") + public String userId() { + return userId; + } + + @JsonProperty("id") + public String id() { + return id; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SessionKey that = (SessionKey) o; + return Objects.equals(appName, that.appName) + && Objects.equals(userId, that.userId) + && Objects.equals(id, that.id); + } + + @Override + public int hashCode() { + return Objects.hash(appName, userId, id); + } + + @Override + public String toString() { + return toJson(); + } + + public static SessionKey fromJson(String json) { + return fromJsonString(json, SessionKey.class); + } +} diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index 421b79abb..42452c6a0 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -41,6 +41,7 @@ import com.google.adk.models.LlmResponse; import com.google.adk.plugins.BasePlugin; import com.google.adk.sessions.Session; +import com.google.adk.sessions.SessionKey; import com.google.adk.summarizer.EventsCompactionConfig; import com.google.adk.telemetry.Tracing; import com.google.adk.testing.TestLlm; @@ -578,6 +579,14 @@ public void onEventCallback_success() { verify(plugin).onEventCallback(any(), any()); } + @Test + public void runAsync_withSessionKey_success() { + var events = + runner.runAsync(session.sessionKey(), createContent("from user")).toList().blockingGet(); + + assertThat(simplifyEvents(events)).containsExactly("test agent: from llm"); + } + @Test public void runAsync_withStateDelta_mergesStateIntoSession() { ImmutableMap stateDelta = ImmutableMap.of("key1", "value1", "key2", 42); @@ -605,6 +614,32 @@ public void runAsync_withStateDelta_mergesStateIntoSession() { assertThat(finalSession.state()).containsAtLeastEntriesIn(stateDelta); } + @Test + public void runAsync_withSessionKeyAndStateDelta_mergesStateIntoSession() { + ImmutableMap stateDelta = ImmutableMap.of("key1", "value1", "key2", 42); + + var events = + runner + .runAsync( + session.sessionKey(), + createContent("test message"), + RunConfig.builder().build(), + stateDelta) + .toList() + .blockingGet(); + + // Verify agent runs successfully + assertThat(simplifyEvents(events)).containsExactly("test agent: from llm"); + + // Verify state was merged into session + Session finalSession = + runner + .sessionService() + .getSession("test", "user", session.id(), Optional.empty()) + .blockingGet(); + assertThat(finalSession.state()).containsAtLeastEntriesIn(stateDelta); + } + @Test public void runAsync_withEmptyStateDelta_doesNotModifySession() { ImmutableMap emptyStateDelta = ImmutableMap.of(); @@ -840,6 +875,20 @@ public void runLive_success() throws Exception { assertThat(simplifyEvents(testSubscriber.values())).containsExactly("test agent: from llm"); } + @Test + public void runLive_withSessionKey_success() throws Exception { + LiveRequestQueue liveRequestQueue = new LiveRequestQueue(); + TestSubscriber testSubscriber = + runner.runLive(session.sessionKey(), liveRequestQueue, RunConfig.builder().build()).test(); + + liveRequestQueue.content(createContent("from user")); + liveRequestQueue.close(); + + testSubscriber.await(); + testSubscriber.assertComplete(); + assertThat(simplifyEvents(testSubscriber.values())).containsExactly("test agent: from llm"); + } + @Test public void runLive_withToolExecution() throws Exception { LlmAgent agentWithTool = @@ -948,6 +997,18 @@ public void runAsync_withoutSessionAndAutoCreateSessionTrue_createsSession() { .isNotNull(); } + @Test + public void runAsync_withoutSessionAndAutoCreateSessionTrue_withSessionKey_createsSession() { + RunConfig runConfig = RunConfig.builder().setAutoCreateSession(true).build(); + SessionKey sessionKey = new SessionKey("test", "user", UUID.randomUUID().toString()); + + var events = + runner.runAsync(sessionKey, createContent("from user"), runConfig).toList().blockingGet(); + + assertThat(simplifyEvents(events)).containsExactly("test agent: from llm"); + assertThat(runner.sessionService().getSession(sessionKey, null).blockingGet()).isNotNull(); + } + @Test public void runAsync_withoutSessionAndAutoCreateSessionFalse_throwsException() { RunConfig runConfig = RunConfig.builder().setAutoCreateSession(false).build(); @@ -959,6 +1020,17 @@ public void runAsync_withoutSessionAndAutoCreateSessionFalse_throwsException() { .assertError(IllegalArgumentException.class); } + @Test + public void runAsync_withoutSessionAndAutoCreateSessionFalse_withSessionKey_throwsException() { + RunConfig runConfig = RunConfig.builder().setAutoCreateSession(false).build(); + SessionKey sessionKey = new SessionKey("test", "user", UUID.randomUUID().toString()); + + runner + .runAsync(sessionKey, createContent("from user"), runConfig) + .test() + .assertError(IllegalArgumentException.class); + } + @Test public void runLive_withoutSessionAndAutoCreateSessionTrue_createsSession() throws Exception { RunConfig runConfig = RunConfig.builder().setAutoCreateSession(true).build(); @@ -982,6 +1054,25 @@ public void runLive_withoutSessionAndAutoCreateSessionTrue_createsSession() thro .isNotNull(); } + @Test + public void runLive_withoutSessionAndAutoCreateSessionTrue_withSessionKey_createsSession() + throws Exception { + RunConfig runConfig = RunConfig.builder().setAutoCreateSession(true).build(); + SessionKey sessionKey = new SessionKey("test", "user", UUID.randomUUID().toString()); + LiveRequestQueue liveRequestQueue = new LiveRequestQueue(); + + TestSubscriber testSubscriber = + runner.runLive(sessionKey, liveRequestQueue, runConfig).test(); + + liveRequestQueue.content(createContent("from user")); + liveRequestQueue.close(); + + testSubscriber.await(); + testSubscriber.assertComplete(); + assertThat(simplifyEvents(testSubscriber.values())).containsExactly("test agent: from llm"); + assertThat(runner.sessionService().getSession(sessionKey, null).blockingGet()).isNotNull(); + } + @Test public void runLive_withoutSessionAndAutoCreateSessionFalse_throwsException() { RunConfig runConfig = RunConfig.builder().setAutoCreateSession(false).build(); @@ -994,6 +1085,18 @@ public void runLive_withoutSessionAndAutoCreateSessionFalse_throwsException() { .assertError(IllegalArgumentException.class); } + @Test + public void runLive_withoutSessionAndAutoCreateSessionFalse_withSessionKey_throwsException() { + RunConfig runConfig = RunConfig.builder().setAutoCreateSession(false).build(); + SessionKey sessionKey = new SessionKey("test", "user", UUID.randomUUID().toString()); + LiveRequestQueue liveRequestQueue = new LiveRequestQueue(); + + runner + .runLive(sessionKey, liveRequestQueue, runConfig) + .test() + .assertError(IllegalArgumentException.class); + } + @Test public void runAsync_withToolConfirmation() { TestLlm testLlm =