Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions src/google/adk/sessions/base_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,67 @@ async def get_user_state(
'call get_session on each result to access the merged state.'
)

async def merge_state(
self,
*,
app_name: str,
user_id: str,
session_id: str,
delta: dict[str, Any],
) -> None:
"""Atomically merges a state delta without appending an event.

This is a state-only write path that bypasses the event log and the
whole-session optimistic-concurrency (OCC) check that ``append_event``
performs. It is intended for *commutative* updates to independent state
keys (counters, flags, per-user balances, feature toggles), where coupling
the write to event-log append + session-level OCC would force unrelated
writers to serialize and retry on spurious "stale session" errors.

The ``delta`` uses the same prefix convention as ``create_session`` and
event ``state_delta``; keys are routed to the scope-appropriate storage
row:

* ``app:`` -> app-scoped state (keyed by ``app_name``).
* ``user:`` -> user-scoped state (keyed by ``(app_name, user_id)``).
* no prefix -> session-scoped state (keyed by
``(app_name, user_id, session_id)``).

``temp:`` keys are rejected with ``ValueError`` because temp state is never
persisted.

Guarantees (for backends that implement this method):

* No ``events`` row is written.
* The session's OCC revision marker is **not** advanced, so a concurrently
held in-memory ``Session`` does not become stale and its next
``append_event`` still succeeds.
* App- and user-scoped merges do not require a pre-existing session and are
fully decoupled from any session's revision. ``session_id`` is used only
to route session-scoped keys.

Merge semantics for non-scalar values are backend-dependent; see the
concrete implementations (notably ``DatabaseSessionService.merge_state``)
for the dialect-specific caveats around nested objects and ``None`` values.
For flat scalar values all backends behave identically.

Args:
app_name: The name of the app.
user_id: The ID of the user.
session_id: The ID of the session. Required for routing session-scoped
keys; need not exist when the delta has no session-scoped keys.
delta: The state delta to merge. An empty or ``None`` delta is a no-op.

Raises:
ValueError: If ``delta`` contains ``temp:``-prefixed keys, or if a
session-scoped key is supplied for a session that does not exist.
NotImplementedError: When the concrete ``BaseSessionService``
implementation does not support a server-side state merge.
"""
raise NotImplementedError(
f'{type(self).__name__} does not support merge_state.'
)

async def append_event(self, session: Session, event: Event) -> Event:
"""Appends an event to a session object."""
if event.partial:
Expand Down
172 changes: 172 additions & 0 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import copy
from datetime import datetime
from datetime import timezone
import json
import logging
from typing import Any
from typing import AsyncIterator
Expand All @@ -34,6 +35,7 @@
from sqlalchemy import event
from sqlalchemy import MetaData
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.engine import Connection
from sqlalchemy.engine import make_url
from sqlalchemy.exc import ArgumentError
Expand Down Expand Up @@ -187,6 +189,30 @@ def _merge_state(
return merged_state


def _json_merge_set_clause(dialect_name: str) -> Optional[str]:
"""Returns a 'state = <merge>' SET clause for an atomic server-side JSON merge.

The returned SQL fragment merges a ``:delta`` bind parameter into the
existing ``state`` column in place. Only the ``state`` column is referenced,
so when issued via ``sqlalchemy.text`` no column ``onupdate`` callable (such
as the ``update_time`` revision marker) is triggered.

Returns None for dialects that do not provide a known server-side JSON merge
function; the caller raises ``NotImplementedError`` in that case.
"""
if dialect_name == _POSTGRESQL_DIALECT:
# JSONB concatenation: shallow top-level merge; an explicit JSON null is
# stored (key kept), matching append_event's Python dict union.
return "state = state || CAST(:delta AS JSONB)"
if dialect_name in (_MYSQL_DIALECT, _MARIADB_DIALECT):
# RFC 7396: recursive merge; a JSON null deletes the key.
return "state = JSON_MERGE_PATCH(state, CAST(:delta AS JSON))"
if dialect_name == _SQLITE_DIALECT:
# RFC 7396 semantics, like the native SqliteSessionService.
return "state = json_patch(state, :delta)"
return None


class _SchemaClasses:
"""A helper class to hold schema classes based on version."""

Expand Down Expand Up @@ -737,6 +763,152 @@ async def get_user_state(
return {}
return dict(storage_user_state.state or {})

@override
async def merge_state(
self,
*,
app_name: str,
user_id: str,
session_id: str,
delta: dict[str, Any],
) -> None:
"""Atomically merges a state delta server-side without appending an event.

Each scoped sub-delta is merged into its storage row with a single atomic
SQL statement (PostgreSQL ``state || delta``, MySQL/MariaDB
``JSON_MERGE_PATCH``, SQLite ``json_patch``) issued via ``text`` so that
only the ``state`` column is written. As a result:

* No ``events`` row is created.
* ``sessions.update_time`` (the optimistic-concurrency revision marker) is
never advanced, so a concurrently held ``Session`` does not go stale.
* Independent keys merged concurrently do not lose updates, because the
merge happens inside the database rather than via Python
read-modify-write.

Merge semantics for non-scalar values differ by dialect. On PostgreSQL the
top-level keys are merged shallowly and an explicit ``None`` is stored as
JSON ``null`` (the key is kept), matching ``append_event``'s Python ``dict``
union. On MySQL, MariaDB and SQLite the merge follows RFC 7396: nested
objects are merged recursively and a ``None`` value deletes the key. For
flat scalar values (counters, flags, balances -- the intended use case) all
dialects behave identically. Use ``append_event`` if you need uniform
Python-``dict``-union semantics for nested objects or ``None`` overwrites.

Args:
app_name: The name of the app.
user_id: The ID of the user.
session_id: The ID of the session. Used to route session-scoped (no
prefix) keys; need not exist when ``delta`` has no session-scoped keys.
delta: The state delta to merge, using the ``app:``/``user:``/no-prefix
convention. An empty or ``None`` delta is a no-op.

Raises:
ValueError: If ``delta`` contains ``temp:`` keys, or if a session-scoped
key targets a session that does not exist.
NotImplementedError: If the active SQL dialect provides no known
server-side JSON merge function.
"""
await self.prepare_tables()
if not delta:
return
if any(key.startswith(State.TEMP_PREFIX) for key in delta):
raise ValueError(
"merge_state does not support temp: keys; temp state is never"
" persisted."
)

merge_clause = _json_merge_set_clause(self.db_engine.dialect.name)
if merge_clause is None:
raise ValueError(
"merge_state is not supported for dialect"
f" {self.db_engine.dialect.name!r}: no server-side JSON merge"
" function is available."
)

state_deltas = _session_util.extract_state_delta(delta)
app_delta = state_deltas["app"]
user_delta = state_deltas["user"]
session_delta = state_deltas["session"]

schema = self._get_schema_classes()
app_table = schema.StorageAppState.__tablename__
user_table = schema.StorageUserState.__tablename__
session_table = schema.StorageSession.__tablename__
use_row_level_locking = self._supports_row_level_locking()

async with self._with_session_lock(
app_name=app_name,
user_id=user_id,
session_id=session_id,
):
async with self._rollback_on_exception_session() as sql_session:
if session_delta:
# A session-scoped key requires the session row to exist; never
# auto-create it (that is create_session's responsibility).
session_exists_stmt = (
select(schema.StorageSession.id)
.filter(schema.StorageSession.app_name == app_name)
.filter(schema.StorageSession.user_id == user_id)
.filter(schema.StorageSession.id == session_id)
)
if use_row_level_locking:
session_exists_stmt = session_exists_stmt.with_for_update()
session_exists = await sql_session.execute(session_exists_stmt)
if session_exists.scalar_one_or_none() is None:
raise ValueError(f"Session {session_id} not found.")

if app_delta:
# Ensure the row exists (handles concurrent inserts), then merge
# server-side. App/user rows have their own update_time which is not
# an OCC marker, so merging here is independent of any session.
await _get_or_create_state(
sql_session=sql_session,
state_model=schema.StorageAppState,
primary_key=app_name,
defaults={"app_name": app_name, "state": {}},
)
await sql_session.execute(
text(
f"UPDATE {app_table} SET {merge_clause} WHERE"
" app_name = :app_name"
),
{"app_name": app_name, "delta": json.dumps(app_delta)},
)
if user_delta:
await _get_or_create_state(
sql_session=sql_session,
state_model=schema.StorageUserState,
primary_key=(app_name, user_id),
defaults={"app_name": app_name, "user_id": user_id, "state": {}},
)
await sql_session.execute(
text(
f"UPDATE {user_table} SET {merge_clause} WHERE"
" app_name = :app_name AND user_id = :user_id"
),
{
"app_name": app_name,
"user_id": user_id,
"delta": json.dumps(user_delta),
},
)
if session_delta:
await sql_session.execute(
text(
f"UPDATE {session_table} SET {merge_clause} WHERE"
" app_name = :app_name AND user_id = :user_id"
" AND id = :session_id"
),
{
"app_name": app_name,
"user_id": user_id,
"session_id": session_id,
"delta": json.dumps(session_delta),
},
)
await sql_session.commit()

@override
async def append_event(self, session: Session, event: Event) -> Event:
await self.prepare_tables()
Expand Down
40 changes: 40 additions & 0 deletions src/google/adk/sessions/in_memory_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,46 @@ async def get_user_state(
) -> dict[str, Any]:
return dict(self.user_state.get(app_name, {}).get(user_id, {}))

@override
async def merge_state(
self,
*,
app_name: str,
user_id: str,
session_id: str,
delta: dict[str, Any],
) -> None:
if not delta:
return
if any(key.startswith(State.TEMP_PREFIX) for key in delta):
raise ValueError(
'merge_state does not support temp: keys; temp state is never'
' persisted.'
)

state_deltas = _session_util.extract_state_delta(delta)
app_state_delta = state_deltas['app']
user_state_delta = state_deltas['user']
session_state_delta = state_deltas['session']

if session_state_delta:
storage_session = (
self.sessions.get(app_name, {}).get(user_id, {}).get(session_id)
)
if storage_session is None:
raise ValueError(f'Session {session_id} not found.')

if app_state_delta:
self.app_state.setdefault(app_name, {}).update(app_state_delta)
if user_state_delta:
self.user_state.setdefault(app_name, {}).setdefault(user_id, {}).update(
user_state_delta
)
if session_state_delta:
# Merge into the stored session's state without bumping
# last_update_time, so a concurrently held session does not go stale.
storage_session.state.update(session_state_delta)

@override
async def append_event(self, session: Session, event: Event) -> Event:
if event.partial:
Expand Down
Loading