Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions src/google/adk/events/event_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,38 @@ class EventCompaction(BaseModel):
"""The compacted content of the events."""


class RewindAuditReceipt(BaseModel): # type: ignore[misc]
"""Audit receipt metadata emitted for rewind operations."""

model_config = ConfigDict(
extra='forbid',
alias_generator=alias_generators.to_camel,
populate_by_name=True,
)
"""The pydantic model config."""

rewind_before_invocation_id: str
"""The invocation ID that the rewind operation targeted."""

boundary_after_invocation_id: Optional[str] = None
"""The last invocation ID retained before the rewind boundary, if any."""

events_before_rewind: int
"""The number of events present before appending the rewind event."""

events_after_rewind: int
"""The number of pre-existing events retained after rewind filtering."""

history_before_hash: str
"""Canonical hash of the full pre-rewind event history."""

history_after_hash: str
"""Canonical hash of the retained pre-rewind event history."""

receipt_hash: str
"""Tamper-evident hash over the rewind receipt summary."""


class EventActions(BaseModel):
"""Represents the actions attached to an event."""

Expand Down Expand Up @@ -108,3 +140,6 @@ class EventActions(BaseModel):

rewind_before_invocation_id: Optional[str] = None
"""The invocation id to rewind to. This is only set for rewind event."""

rewind_audit_receipt: Optional[RewindAuditReceipt] = None
"""Structured receipt proving rewind boundaries and history digests."""
70 changes: 70 additions & 0 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from __future__ import annotations

import asyncio
import hashlib
import inspect
import json
import logging
from pathlib import Path
import queue
Expand Down Expand Up @@ -47,6 +49,7 @@
from .code_executors.built_in_code_executor import BuiltInCodeExecutor
from .events.event import Event
from .events.event import EventActions
from .events.event_actions import RewindAuditReceipt
from .flows.llm_flows import contents
from .flows.llm_flows.functions import find_matching_function_call
from .memory.base_memory_service import BaseMemoryService
Expand Down Expand Up @@ -594,6 +597,11 @@ async def rewind_async(
artifact_delta = await self._compute_artifact_delta_for_rewind(
session, rewind_event_index
)
rewind_audit_receipt = self._build_rewind_audit_receipt(
session=session,
rewind_event_index=rewind_event_index,
rewind_before_invocation_id=rewind_before_invocation_id,
)

# Create rewind event
rewind_event = Event(
Expand All @@ -603,13 +611,75 @@ async def rewind_async(
rewind_before_invocation_id=rewind_before_invocation_id,
state_delta=state_delta,
artifact_delta=artifact_delta,
rewind_audit_receipt=rewind_audit_receipt,
),
)

logger.info('Rewinding session to invocation: %s', rewind_event)

await self.session_service.append_event(session=session, event=rewind_event)

def _build_rewind_audit_receipt(
self,
*,
session: Session,
rewind_event_index: int,
rewind_before_invocation_id: str,
) -> RewindAuditReceipt:
"""Builds a deterministic audit receipt for a rewind operation."""
events_before = session.events
events_after = session.events[:rewind_event_index]
boundary_after_invocation_id = None
if rewind_event_index > 0:
boundary_after_invocation_id = session.events[
rewind_event_index - 1
].invocation_id

history_before_hash = self._hash_rewind_events(events_before)
history_after_hash = self._hash_rewind_events(events_after)

receipt_payload = {
'rewind_before_invocation_id': rewind_before_invocation_id,
'boundary_after_invocation_id': boundary_after_invocation_id,
'events_before_rewind': len(events_before),
'events_after_rewind': len(events_after),
'history_before_hash': history_before_hash,
'history_after_hash': history_after_hash,
}
receipt_hash = self._hash_rewind_payload(receipt_payload)

return RewindAuditReceipt(
**receipt_payload,
receipt_hash=receipt_hash,
)
Comment on lines 651 to 654
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The instantiation of RewindAuditReceipt repeats all the fields that were just defined in receipt_payload. You can simplify this and avoid duplication by unpacking the receipt_payload dictionary when creating the RewindAuditReceipt instance. This makes the code more concise and easier to maintain if the fields change in the future.

    return RewindAuditReceipt(**receipt_payload, receipt_hash=receipt_hash)


def _hash_rewind_events(self, events: List[Event]) -> str:
"""Hashes event summaries for deterministic rewind audit receipts."""
summarized_events = [
{
'event_id': event.id,
'invocation_id': event.invocation_id,
'author': event.author,
'state_delta': event.actions.state_delta,
'artifact_delta': event.actions.artifact_delta,
'rewind_before_invocation_id': (
event.actions.rewind_before_invocation_id
),
}
for event in events
]
return self._hash_rewind_payload({'events': summarized_events})

def _hash_rewind_payload(self, payload: dict[str, Any]) -> str:
"""Returns a canonical SHA-256 digest for rewind audit payloads."""
canonical_json = json.dumps(
payload,
sort_keys=True,
separators=(',', ':'),
ensure_ascii=True,
)
return hashlib.sha256(canonical_json.encode('utf-8')).hexdigest()

async def _compute_state_delta_for_rewind(
self, session: Session, rewind_event_index: int
) -> dict[str, Any]:
Expand Down
46 changes: 46 additions & 0 deletions tests/unittests/runners/test_runner_rewind.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,15 @@ async def test_rewind_async_with_state_and_artifacts(self):
)
is None
)
rewind_receipt = session.events[-1].actions.rewind_audit_receipt
assert rewind_receipt is not None
assert rewind_receipt.rewind_before_invocation_id == "invocation2"
assert rewind_receipt.boundary_after_invocation_id == "invocation1"
assert rewind_receipt.events_before_rewind == 3
assert rewind_receipt.events_after_rewind == 1
assert rewind_receipt.history_before_hash
assert rewind_receipt.history_after_hash
assert rewind_receipt.receipt_hash

@pytest.mark.asyncio
async def test_rewind_async_not_first_invocation(self):
Expand Down Expand Up @@ -246,3 +255,40 @@ async def test_rewind_async_not_first_invocation(self):
session_id=session_id,
filename="f2",
) == types.Part.from_text(text="f2v0")

@pytest.mark.asyncio
async def test_rewind_receipt_hash_is_deterministic(self):
"""Tests that rewind receipt hashes are stable for the same history."""
runner = self.runner
user_id = "test_user"
session_id = "test_session"
session = await runner.session_service.create_session(
app_name=runner.app_name, user_id=user_id, session_id=session_id
)

for invocation_id in ("invocation1", "invocation2", "invocation3"):
await runner.session_service.append_event(
session=session,
event=Event(
invocation_id=invocation_id,
author="agent",
actions=EventActions(state_delta={invocation_id: invocation_id}),
),
)

first_receipt = runner._build_rewind_audit_receipt(
session=session,
rewind_event_index=1,
rewind_before_invocation_id="invocation2",
)
second_receipt = runner._build_rewind_audit_receipt(
session=session,
rewind_event_index=1,
rewind_before_invocation_id="invocation2",
)

assert (
first_receipt.history_before_hash == second_receipt.history_before_hash
)
assert first_receipt.history_after_hash == second_receipt.history_after_hash
assert first_receipt.receipt_hash == second_receipt.receipt_hash