Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 22 additions & 18 deletions src/agents/extensions/memory/async_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pathlib import Path
from typing import cast

import aiosqlite

from ...items import TResponseInputItem
from ...memory import SessionABC
from ...memory.session import is_session_input_item
from ...memory.session_settings import SessionSettings, resolve_session_limit


Expand Down Expand Up @@ -150,7 +150,8 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
for (message_data,) in rows:
try:
item = json.loads(message_data)
items.append(item)
if is_session_input_item(item):
items.append(item)
except json.JSONDecodeError:
continue

Expand Down Expand Up @@ -220,24 +221,27 @@ async def pop_item(self) -> TResponseInputItem | None:
while result:
message_data = result[0]
try:
return cast(TResponseInputItem, json.loads(message_data))
item = json.loads(message_data)
if is_session_input_item(item):
return item
except (json.JSONDecodeError, TypeError):
cursor = await conn.execute(
f"""
DELETE FROM {self.messages_table}
WHERE id = (
SELECT id FROM {self.messages_table}
WHERE session_id = ?
ORDER BY id DESC
LIMIT 1
)
RETURNING message_data
""",
(self.session_id,),
pass
cursor = await conn.execute(
f"""
DELETE FROM {self.messages_table}
WHERE id = (
SELECT id FROM {self.messages_table}
WHERE session_id = ?
ORDER BY id DESC
LIMIT 1
)
result = await cursor.fetchone()
await cursor.close()
await conn.commit()
RETURNING message_data
""",
(self.session_id,),
)
result = await cursor.fetchone()
await cursor.close()
await conn.commit()

return None

Expand Down
17 changes: 12 additions & 5 deletions src/agents/extensions/memory/dapr_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@

from ...items import TResponseInputItem
from ...logger import logger
from ...memory.session import SessionABC
from ...memory.session import SessionABC, is_session_input_item
from ...memory.session_settings import SessionSettings, resolve_session_limit

# Type alias for consistency levels
Expand Down Expand Up @@ -180,7 +180,10 @@ async def _serialize_item(self, item: TResponseInputItem) -> str:

async def _deserialize_item(self, item: str) -> TResponseInputItem:
"""Deserialize a JSON string to an item. Can be overridden by subclasses."""
return json.loads(item) # type: ignore[no-any-return]
decoded = json.loads(item)
if not is_session_input_item(decoded):
raise TypeError("Decoded session item is not a response input item")
return decoded

def _decode_messages(self, data: bytes | None, *, strict: bool = False) -> list[Any]:
if not data:
Expand Down Expand Up @@ -284,7 +287,8 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
item = await self._deserialize_item(msg)
else:
item = msg
items.append(item)
if is_session_input_item(item):
items.append(item)
except (json.JSONDecodeError, TypeError):
continue
return items
Expand Down Expand Up @@ -381,8 +385,11 @@ async def pop_item(self) -> TResponseInputItem | None:
raise
try:
if isinstance(last_item, str):
return await self._deserialize_item(last_item)
return last_item # type: ignore[no-any-return]
item = await self._deserialize_item(last_item)
else:
item = last_item
if is_session_input_item(item):
return item
except (json.JSONDecodeError, TypeError):
continue

Expand Down
7 changes: 5 additions & 2 deletions src/agents/extensions/memory/mongodb_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
)

from ...items import TResponseInputItem
from ...memory.session import SessionABC
from ...memory.session import SessionABC, is_session_input_item
from ...memory.session_settings import SessionSettings, resolve_session_limit

# Identifies this library in the MongoDB handshake for server-side telemetry.
Expand Down Expand Up @@ -241,7 +241,10 @@ async def _serialize_item(self, item: TResponseInputItem) -> str:

async def _deserialize_item(self, raw: str) -> TResponseInputItem:
"""Deserialize a JSON string to an item. Can be overridden by subclasses."""
return json.loads(raw) # type: ignore[no-any-return]
decoded = json.loads(raw)
if not is_session_input_item(decoded):
raise TypeError("Decoded session item is not a response input item")
return decoded

# ------------------------------------------------------------------
# Session protocol implementation
Expand Down
11 changes: 7 additions & 4 deletions src/agents/extensions/memory/redis_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
)

from ...items import TResponseInputItem
from ...memory.session import SessionABC
from ...memory.session import SessionABC, is_session_input_item
from ...memory.session_settings import SessionSettings, resolve_session_limit


Expand Down Expand Up @@ -126,7 +126,10 @@ async def _serialize_item(self, item: TResponseInputItem) -> str:

async def _deserialize_item(self, item: str) -> TResponseInputItem:
"""Deserialize a JSON string to an item. Can be overridden by subclasses."""
return json.loads(item) # type: ignore[no-any-return] # json.loads returns Any but we know the structure
decoded = json.loads(item)
if not is_session_input_item(decoded):
raise TypeError("Decoded session item is not a response input item")
return decoded

async def _get_next_id(self) -> int:
"""Get the next message ID using Redis INCR for atomic increment."""
Expand Down Expand Up @@ -178,7 +181,7 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
msg_str = raw_msg # Already a string
item = await self._deserialize_item(msg_str)
items.append(item)
except (json.JSONDecodeError, UnicodeDecodeError):
except (json.JSONDecodeError, TypeError, UnicodeDecodeError):
# Skip corrupted messages
continue

Expand Down Expand Up @@ -242,7 +245,7 @@ async def pop_item(self) -> TResponseInputItem | None:
else:
msg_str = raw_msg # Already a string
return await self._deserialize_item(msg_str)
except (json.JSONDecodeError, UnicodeDecodeError):
except (json.JSONDecodeError, TypeError, UnicodeDecodeError):
# Drop corrupted messages and keep looking for a valid item.
continue

Expand Down
9 changes: 6 additions & 3 deletions src/agents/extensions/memory/sqlalchemy_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine

from ...items import TResponseInputItem
from ...memory.session import SessionABC
from ...memory.session import SessionABC, is_session_input_item
from ...memory.session_settings import SessionSettings, resolve_session_limit


Expand Down Expand Up @@ -249,7 +249,10 @@ async def _serialize_item(self, item: TResponseInputItem) -> str:

async def _deserialize_item(self, item: str) -> TResponseInputItem:
"""Deserialize a JSON string to an item. Can be overridden by subclasses."""
return json.loads(item) # type: ignore[no-any-return]
decoded = json.loads(item)
if not is_session_input_item(decoded):
raise TypeError("Decoded session item is not a response input item")
return decoded

# ------------------------------------------------------------------
# Session protocol implementation
Expand Down Expand Up @@ -321,7 +324,7 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
for raw in rows:
try:
items.append(await self._deserialize_item(raw))
except json.JSONDecodeError:
except (json.JSONDecodeError, TypeError):
# Skip corrupted rows
continue
return items
Expand Down
5 changes: 5 additions & 0 deletions src/agents/memory/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
from .session_settings import SessionSettings


def is_session_input_item(value: object) -> TypeGuard[TResponseInputItem]:
"""Return whether a decoded session payload is shaped like a response input item."""
return isinstance(value, dict)


@runtime_checkable
class Session(Protocol):
"""Protocol for session implementations.
Expand Down
37 changes: 20 additions & 17 deletions src/agents/memory/sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import ClassVar

from ..items import TResponseInputItem
from .session import SessionABC
from .session import SessionABC, is_session_input_item
from .session_settings import SessionSettings, resolve_session_limit


Expand Down Expand Up @@ -245,7 +245,8 @@ def _get_items_sync():
for (message_data,) in rows:
try:
item = json.loads(message_data)
items.append(item)
if is_session_input_item(item):
items.append(item)
except (json.JSONDecodeError, TypeError):
# Skip invalid JSON entries
continue
Expand Down Expand Up @@ -301,24 +302,26 @@ def _pop_item_sync():
message_data = result[0]
try:
item = json.loads(message_data)
return item
if is_session_input_item(item):
return item
except (json.JSONDecodeError, TypeError):
# Drop corrupted JSON entries and keep looking for a valid item.
cursor = conn.execute(
f"""
DELETE FROM {self.messages_table}
WHERE id = (
SELECT id FROM {self.messages_table}
WHERE session_id = ?
ORDER BY id DESC
LIMIT 1
)
RETURNING message_data
""",
(self.session_id,),
pass
cursor = conn.execute(
f"""
DELETE FROM {self.messages_table}
WHERE id = (
SELECT id FROM {self.messages_table}
WHERE session_id = ?
ORDER BY id DESC
LIMIT 1
)
result = cursor.fetchone()
conn.commit()
RETURNING message_data
""",
(self.session_id,),
)
result = cursor.fetchone()
conn.commit()

return None

Expand Down
50 changes: 50 additions & 0 deletions tests/extensions/memory/test_async_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,56 @@ async def test_async_sqlite_session_pop_item_returns_none_after_dropping_only_co
await session.close()


async def test_async_sqlite_session_get_items_skips_json_values_that_are_not_input_items():
"""get_items skips JSON values that decode but are not response input items."""
with tempfile.TemporaryDirectory() as temp_dir:
db_path = Path(temp_dir) / "async_get_non_item_json.db"
session = AsyncSQLiteSession("async_get_non_item_json", db_path)

valid_item: TResponseInputItem = {"role": "user", "content": "valid"}
await session.add_items([valid_item])

conn = await session._get_connection()
await conn.executemany(
f"INSERT INTO {session.messages_table} (session_id, message_data) VALUES (?, ?)",
[
(session.session_id, json.dumps("not an input item")),
(session.session_id, json.dumps(["also", "not", "an", "input", "item"])),
(session.session_id, json.dumps(123)),
],
)
await conn.commit()

assert await session.get_items() == [valid_item]

await session.close()


async def test_async_sqlite_session_pop_item_skips_json_values_that_are_not_input_items():
"""pop_item drops JSON values that decode but are not response input items."""
with tempfile.TemporaryDirectory() as temp_dir:
db_path = Path(temp_dir) / "async_pop_non_item_json.db"
session = AsyncSQLiteSession("async_pop_non_item_json", db_path)

valid_item: TResponseInputItem = {"role": "user", "content": "valid"}
await session.add_items([valid_item])

conn = await session._get_connection()
await conn.executemany(
f"INSERT INTO {session.messages_table} (session_id, message_data) VALUES (?, ?)",
[
(session.session_id, json.dumps("not an input item")),
(session.session_id, json.dumps(["also", "not", "an", "input", "item"])),
],
)
await conn.commit()

assert await session.pop_item() == valid_item
assert await session.get_items() == []

await session.close()


async def test_async_sqlite_session_get_items_limit():
"""Test AsyncSQLiteSession get_items limit handling."""
with tempfile.TemporaryDirectory() as temp_dir:
Expand Down
19 changes: 19 additions & 0 deletions tests/extensions/memory/test_dapr_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,25 @@ async def test_already_deserialized_messages(fake_dapr_client: FakeDaprClient):
await session.close()


async def test_non_item_json_values_are_skipped(fake_dapr_client: FakeDaprClient):
"""JSON-valid values that are not response input items are skipped."""
session = await _create_test_session(fake_dapr_client, "non_item_json_test")

messages_list = [
{"role": "user", "content": "valid"},
"not an input item",
["also", "not", "an", "item"],
json.dumps(123),
]
fake_dapr_client._state[session._messages_key] = json.dumps(messages_list).encode("utf-8")

assert await session.get_items() == [{"role": "user", "content": "valid"}]
assert await session.pop_item() == {"role": "user", "content": "valid"}
assert await session.get_items() == []

await session.close()


async def test_context_manager(fake_dapr_client: FakeDaprClient):
"""Test that DaprSession works as an async context manager."""
# Test that the context manager enters and exits properly
Expand Down
Loading
Loading