diff --git a/src/agents/extensions/memory/async_sqlite_session.py b/src/agents/extensions/memory/async_sqlite_session.py index 27a23b1cbe..3ccb1acc09 100644 --- a/src/agents/extensions/memory/async_sqlite_session.py +++ b/src/agents/extensions/memory/async_sqlite_session.py @@ -5,12 +5,12 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from pathlib import Path -from typing import cast import aiosqlite from ...items import TResponseInputItem from ...memory import SessionABC +from ...memory.session import is_session_input_item from ...memory.session_settings import SessionSettings, resolve_session_limit @@ -150,7 +150,8 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: for (message_data,) in rows: try: item = json.loads(message_data) - items.append(item) + if is_session_input_item(item): + items.append(item) except json.JSONDecodeError: continue @@ -220,24 +221,27 @@ async def pop_item(self) -> TResponseInputItem | None: while result: message_data = result[0] try: - return cast(TResponseInputItem, json.loads(message_data)) + item = json.loads(message_data) + if is_session_input_item(item): + return item except (json.JSONDecodeError, TypeError): - cursor = await conn.execute( - f""" - DELETE FROM {self.messages_table} - WHERE id = ( - SELECT id FROM {self.messages_table} - WHERE session_id = ? - ORDER BY id DESC - LIMIT 1 - ) - RETURNING message_data - """, - (self.session_id,), + pass + cursor = await conn.execute( + f""" + DELETE FROM {self.messages_table} + WHERE id = ( + SELECT id FROM {self.messages_table} + WHERE session_id = ? + ORDER BY id DESC + LIMIT 1 ) - result = await cursor.fetchone() - await cursor.close() - await conn.commit() + RETURNING message_data + """, + (self.session_id,), + ) + result = await cursor.fetchone() + await cursor.close() + await conn.commit() return None diff --git a/src/agents/extensions/memory/dapr_session.py b/src/agents/extensions/memory/dapr_session.py index 6ac68f6020..2768c7b3f3 100644 --- a/src/agents/extensions/memory/dapr_session.py +++ b/src/agents/extensions/memory/dapr_session.py @@ -44,7 +44,7 @@ from ...items import TResponseInputItem from ...logger import logger -from ...memory.session import SessionABC +from ...memory.session import SessionABC, is_session_input_item from ...memory.session_settings import SessionSettings, resolve_session_limit # Type alias for consistency levels @@ -180,7 +180,10 @@ async def _serialize_item(self, item: TResponseInputItem) -> str: async def _deserialize_item(self, item: str) -> TResponseInputItem: """Deserialize a JSON string to an item. Can be overridden by subclasses.""" - return json.loads(item) # type: ignore[no-any-return] + decoded = json.loads(item) + if not is_session_input_item(decoded): + raise TypeError("Decoded session item is not a response input item") + return decoded def _decode_messages(self, data: bytes | None, *, strict: bool = False) -> list[Any]: if not data: @@ -284,7 +287,8 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: item = await self._deserialize_item(msg) else: item = msg - items.append(item) + if is_session_input_item(item): + items.append(item) except (json.JSONDecodeError, TypeError): continue return items @@ -381,8 +385,11 @@ async def pop_item(self) -> TResponseInputItem | None: raise try: if isinstance(last_item, str): - return await self._deserialize_item(last_item) - return last_item # type: ignore[no-any-return] + item = await self._deserialize_item(last_item) + else: + item = last_item + if is_session_input_item(item): + return item except (json.JSONDecodeError, TypeError): continue diff --git a/src/agents/extensions/memory/mongodb_session.py b/src/agents/extensions/memory/mongodb_session.py index 113acdc6af..f6127aa2b3 100644 --- a/src/agents/extensions/memory/mongodb_session.py +++ b/src/agents/extensions/memory/mongodb_session.py @@ -59,7 +59,7 @@ ) from ...items import TResponseInputItem -from ...memory.session import SessionABC +from ...memory.session import SessionABC, is_session_input_item from ...memory.session_settings import SessionSettings, resolve_session_limit # Identifies this library in the MongoDB handshake for server-side telemetry. @@ -241,7 +241,10 @@ async def _serialize_item(self, item: TResponseInputItem) -> str: async def _deserialize_item(self, raw: str) -> TResponseInputItem: """Deserialize a JSON string to an item. Can be overridden by subclasses.""" - return json.loads(raw) # type: ignore[no-any-return] + decoded = json.loads(raw) + if not is_session_input_item(decoded): + raise TypeError("Decoded session item is not a response input item") + return decoded # ------------------------------------------------------------------ # Session protocol implementation diff --git a/src/agents/extensions/memory/redis_session.py b/src/agents/extensions/memory/redis_session.py index 11e2dd838b..014f10a08b 100644 --- a/src/agents/extensions/memory/redis_session.py +++ b/src/agents/extensions/memory/redis_session.py @@ -40,7 +40,7 @@ ) from ...items import TResponseInputItem -from ...memory.session import SessionABC +from ...memory.session import SessionABC, is_session_input_item from ...memory.session_settings import SessionSettings, resolve_session_limit @@ -126,7 +126,10 @@ async def _serialize_item(self, item: TResponseInputItem) -> str: async def _deserialize_item(self, item: str) -> TResponseInputItem: """Deserialize a JSON string to an item. Can be overridden by subclasses.""" - return json.loads(item) # type: ignore[no-any-return] # json.loads returns Any but we know the structure + decoded = json.loads(item) + if not is_session_input_item(decoded): + raise TypeError("Decoded session item is not a response input item") + return decoded async def _get_next_id(self) -> int: """Get the next message ID using Redis INCR for atomic increment.""" @@ -178,7 +181,7 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: msg_str = raw_msg # Already a string item = await self._deserialize_item(msg_str) items.append(item) - except (json.JSONDecodeError, UnicodeDecodeError): + except (json.JSONDecodeError, TypeError, UnicodeDecodeError): # Skip corrupted messages continue @@ -242,7 +245,7 @@ async def pop_item(self) -> TResponseInputItem | None: else: msg_str = raw_msg # Already a string return await self._deserialize_item(msg_str) - except (json.JSONDecodeError, UnicodeDecodeError): + except (json.JSONDecodeError, TypeError, UnicodeDecodeError): # Drop corrupted messages and keep looking for a valid item. continue diff --git a/src/agents/extensions/memory/sqlalchemy_session.py b/src/agents/extensions/memory/sqlalchemy_session.py index fd2502e24b..0676b8041e 100644 --- a/src/agents/extensions/memory/sqlalchemy_session.py +++ b/src/agents/extensions/memory/sqlalchemy_session.py @@ -49,7 +49,7 @@ from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine from ...items import TResponseInputItem -from ...memory.session import SessionABC +from ...memory.session import SessionABC, is_session_input_item from ...memory.session_settings import SessionSettings, resolve_session_limit @@ -249,7 +249,10 @@ async def _serialize_item(self, item: TResponseInputItem) -> str: async def _deserialize_item(self, item: str) -> TResponseInputItem: """Deserialize a JSON string to an item. Can be overridden by subclasses.""" - return json.loads(item) # type: ignore[no-any-return] + decoded = json.loads(item) + if not is_session_input_item(decoded): + raise TypeError("Decoded session item is not a response input item") + return decoded # ------------------------------------------------------------------ # Session protocol implementation @@ -321,7 +324,7 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: for raw in rows: try: items.append(await self._deserialize_item(raw)) - except json.JSONDecodeError: + except (json.JSONDecodeError, TypeError): # Skip corrupted rows continue return items diff --git a/src/agents/memory/session.py b/src/agents/memory/session.py index 1781b7ac9f..341648613c 100644 --- a/src/agents/memory/session.py +++ b/src/agents/memory/session.py @@ -10,6 +10,11 @@ from .session_settings import SessionSettings +def is_session_input_item(value: object) -> TypeGuard[TResponseInputItem]: + """Return whether a decoded session payload is shaped like a response input item.""" + return isinstance(value, dict) + + @runtime_checkable class Session(Protocol): """Protocol for session implementations. diff --git a/src/agents/memory/sqlite_session.py b/src/agents/memory/sqlite_session.py index 3a69f9883a..26406662e0 100644 --- a/src/agents/memory/sqlite_session.py +++ b/src/agents/memory/sqlite_session.py @@ -10,7 +10,7 @@ from typing import ClassVar from ..items import TResponseInputItem -from .session import SessionABC +from .session import SessionABC, is_session_input_item from .session_settings import SessionSettings, resolve_session_limit @@ -245,7 +245,8 @@ def _get_items_sync(): for (message_data,) in rows: try: item = json.loads(message_data) - items.append(item) + if is_session_input_item(item): + items.append(item) except (json.JSONDecodeError, TypeError): # Skip invalid JSON entries continue @@ -301,24 +302,26 @@ def _pop_item_sync(): message_data = result[0] try: item = json.loads(message_data) - return item + if is_session_input_item(item): + return item except (json.JSONDecodeError, TypeError): # Drop corrupted JSON entries and keep looking for a valid item. - cursor = conn.execute( - f""" - DELETE FROM {self.messages_table} - WHERE id = ( - SELECT id FROM {self.messages_table} - WHERE session_id = ? - ORDER BY id DESC - LIMIT 1 - ) - RETURNING message_data - """, - (self.session_id,), + pass + cursor = conn.execute( + f""" + DELETE FROM {self.messages_table} + WHERE id = ( + SELECT id FROM {self.messages_table} + WHERE session_id = ? + ORDER BY id DESC + LIMIT 1 ) - result = cursor.fetchone() - conn.commit() + RETURNING message_data + """, + (self.session_id,), + ) + result = cursor.fetchone() + conn.commit() return None diff --git a/tests/extensions/memory/test_async_sqlite_session.py b/tests/extensions/memory/test_async_sqlite_session.py index 7269951829..c72b0e34e1 100644 --- a/tests/extensions/memory/test_async_sqlite_session.py +++ b/tests/extensions/memory/test_async_sqlite_session.py @@ -119,6 +119,56 @@ async def test_async_sqlite_session_pop_item_returns_none_after_dropping_only_co await session.close() +async def test_async_sqlite_session_get_items_skips_json_values_that_are_not_input_items(): + """get_items skips JSON values that decode but are not response input items.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "async_get_non_item_json.db" + session = AsyncSQLiteSession("async_get_non_item_json", db_path) + + valid_item: TResponseInputItem = {"role": "user", "content": "valid"} + await session.add_items([valid_item]) + + conn = await session._get_connection() + await conn.executemany( + f"INSERT INTO {session.messages_table} (session_id, message_data) VALUES (?, ?)", + [ + (session.session_id, json.dumps("not an input item")), + (session.session_id, json.dumps(["also", "not", "an", "input", "item"])), + (session.session_id, json.dumps(123)), + ], + ) + await conn.commit() + + assert await session.get_items() == [valid_item] + + await session.close() + + +async def test_async_sqlite_session_pop_item_skips_json_values_that_are_not_input_items(): + """pop_item drops JSON values that decode but are not response input items.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "async_pop_non_item_json.db" + session = AsyncSQLiteSession("async_pop_non_item_json", db_path) + + valid_item: TResponseInputItem = {"role": "user", "content": "valid"} + await session.add_items([valid_item]) + + conn = await session._get_connection() + await conn.executemany( + f"INSERT INTO {session.messages_table} (session_id, message_data) VALUES (?, ?)", + [ + (session.session_id, json.dumps("not an input item")), + (session.session_id, json.dumps(["also", "not", "an", "input", "item"])), + ], + ) + await conn.commit() + + assert await session.pop_item() == valid_item + assert await session.get_items() == [] + + await session.close() + + async def test_async_sqlite_session_get_items_limit(): """Test AsyncSQLiteSession get_items limit handling.""" with tempfile.TemporaryDirectory() as temp_dir: diff --git a/tests/extensions/memory/test_dapr_session.py b/tests/extensions/memory/test_dapr_session.py index 9766f35d40..3ae80dcc0d 100644 --- a/tests/extensions/memory/test_dapr_session.py +++ b/tests/extensions/memory/test_dapr_session.py @@ -837,6 +837,25 @@ async def test_already_deserialized_messages(fake_dapr_client: FakeDaprClient): await session.close() +async def test_non_item_json_values_are_skipped(fake_dapr_client: FakeDaprClient): + """JSON-valid values that are not response input items are skipped.""" + session = await _create_test_session(fake_dapr_client, "non_item_json_test") + + messages_list = [ + {"role": "user", "content": "valid"}, + "not an input item", + ["also", "not", "an", "item"], + json.dumps(123), + ] + fake_dapr_client._state[session._messages_key] = json.dumps(messages_list).encode("utf-8") + + assert await session.get_items() == [{"role": "user", "content": "valid"}] + assert await session.pop_item() == {"role": "user", "content": "valid"} + assert await session.get_items() == [] + + await session.close() + + async def test_context_manager(fake_dapr_client: FakeDaprClient): """Test that DaprSession works as an async context manager.""" # Test that the context manager enters and exits properly diff --git a/tests/extensions/memory/test_mongodb_session.py b/tests/extensions/memory/test_mongodb_session.py index cd7954e3ae..bc2d1ee902 100644 --- a/tests/extensions/memory/test_mongodb_session.py +++ b/tests/extensions/memory/test_mongodb_session.py @@ -8,6 +8,7 @@ from __future__ import annotations +import json import sys import types from collections import defaultdict @@ -539,6 +540,31 @@ async def test_non_string_message_data_is_skipped(session: MongoDBSession) -> No assert items[0].get("content") == "valid" +async def test_json_values_that_are_not_input_items_are_skipped(session: MongoDBSession) -> None: + """Documents whose JSON decodes to non-item values are silently skipped.""" + await session.add_items([{"role": "user", "content": "valid"}]) + + for seq, payload in enumerate( + [ + json.dumps("not an input item"), + json.dumps(["also", "not", "an", "item"]), + json.dumps(123), + ], + start=10, + ): + bad_doc = { + "_id": FakeObjectId(), + "session_id": session.session_id, + "seq": seq, + "message_data": payload, + } + session._messages._docs[id(bad_doc["_id"])] = bad_doc + + items = await session.get_items() + assert len(items) == 1 + assert items[0].get("content") == "valid" + + async def test_pop_item_skips_corrupt_most_recent(session: MongoDBSession) -> None: """pop_item must skip a corrupt most-recent document and return the next valid one.""" await session.add_items([{"role": "user", "content": "valid"}]) @@ -551,6 +577,13 @@ async def test_pop_item_skips_corrupt_most_recent(session: MongoDBSession) -> No "message_data": "not valid json {{{", } session._messages._docs[id(bad_doc["_id"])] = bad_doc + non_item_doc = { + "_id": FakeObjectId(), + "session_id": session.session_id, + "seq": 1000, + "message_data": json.dumps("not an input item"), + } + session._messages._docs[id(non_item_doc["_id"])] = non_item_doc popped = await session.pop_item() assert popped is not None diff --git a/tests/extensions/memory/test_redis_session.py b/tests/extensions/memory/test_redis_session.py index b5011cdd4d..dd504b86cf 100644 --- a/tests/extensions/memory/test_redis_session.py +++ b/tests/extensions/memory/test_redis_session.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from typing import cast import pytest @@ -725,6 +726,8 @@ async def test_corrupted_data_handling(): # Add invalid JSON directly using the typed Redis client await _safe_rpush(fake_redis, messages_key, "invalid json data") await _safe_rpush(fake_redis, messages_key, "{incomplete json") + await _safe_rpush(fake_redis, messages_key, json.dumps("not an input item")) + await _safe_rpush(fake_redis, messages_key, json.dumps(["also", "not", "an", "item"])) # get_items should skip corrupted data and return valid items items = await session.get_items() @@ -740,7 +743,7 @@ async def test_corrupted_data_handling(): assert items[0].get("content") == "valid message" assert items[1].get("content") == "valid after corruption" - # Test pop_item with corrupted data at the end. + # Test pop_item with corrupted and non-item JSON data at the end. await _safe_rpush(fake_redis, messages_key, "corrupted at end") # The corrupted item should be dropped and pop_item should keep looking diff --git a/tests/extensions/memory/test_sqlalchemy_session.py b/tests/extensions/memory/test_sqlalchemy_session.py index fe30993699..7cbeaef9bf 100644 --- a/tests/extensions/memory/test_sqlalchemy_session.py +++ b/tests/extensions/memory/test_sqlalchemy_session.py @@ -211,6 +211,36 @@ async def test_pop_item_skips_corrupt_most_recent(): assert await session.get_items() == [] +async def test_json_values_that_are_not_input_items_are_skipped(): + """Rows whose JSON decodes to non-item values are skipped.""" + session = SQLAlchemySession.from_url("non_item_json", url=DB_URL, create_tables=True) + + valid_item: TResponseInputItem = {"role": "user", "content": "valid"} + await session.add_items([valid_item]) + + await session._ensure_tables() + async with session._session_factory() as sess: + async with sess.begin(): + await sess.execute( + insert(session._messages), + [ + { + "session_id": session.session_id, + "message_data": json.dumps("not an input item"), + }, + { + "session_id": session.session_id, + "message_data": json.dumps(["also", "not", "an", "item"]), + }, + {"session_id": session.session_id, "message_data": json.dumps(123)}, + ], + ) + + assert await session.get_items() == [valid_item] + assert await session.pop_item() == valid_item + assert await session.get_items() == [] + + async def test_pop_item_returns_none_after_dropping_only_corrupt_rows(): """pop_item removes corrupt rows and returns None when no valid items remain.""" session = SQLAlchemySession.from_url("pop_only_corrupt", url=DB_URL, create_tables=True) diff --git a/tests/memory/test_session.py b/tests/memory/test_session.py index f9cc324d2e..482f822601 100644 --- a/tests/memory/test_session.py +++ b/tests/memory/test_session.py @@ -1,6 +1,7 @@ """Tests for session memory functionality.""" import asyncio +import json import sqlite3 import tempfile from pathlib import Path @@ -377,6 +378,58 @@ async def test_sqlite_session_pop_item_returns_none_after_dropping_only_corrupt_ session.close() +@pytest.mark.asyncio +async def test_sqlite_session_get_items_skips_json_values_that_are_not_input_items(): + """get_items skips JSON values that decode but are not response input items.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_get_non_item_json.db" + session = SQLiteSession("get_non_item_json", db_path) + + valid_item: TResponseInputItem = {"role": "user", "content": "valid"} + await session.add_items([valid_item]) + + with session._locked_connection() as conn: + conn.executemany( + f"INSERT INTO {session.messages_table} (session_id, message_data) VALUES (?, ?)", + [ + (session.session_id, json.dumps("not an input item")), + (session.session_id, json.dumps(["also", "not", "an", "input", "item"])), + (session.session_id, json.dumps(123)), + ], + ) + conn.commit() + + assert await session.get_items() == [valid_item] + + session.close() + + +@pytest.mark.asyncio +async def test_sqlite_session_pop_item_skips_json_values_that_are_not_input_items(): + """pop_item drops JSON values that decode but are not response input items.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_pop_non_item_json.db" + session = SQLiteSession("pop_non_item_json", db_path) + + valid_item: TResponseInputItem = {"role": "user", "content": "valid"} + await session.add_items([valid_item]) + + with session._locked_connection() as conn: + conn.executemany( + f"INSERT INTO {session.messages_table} (session_id, message_data) VALUES (?, ?)", + [ + (session.session_id, json.dumps("not an input item")), + (session.session_id, json.dumps(["also", "not", "an", "input", "item"])), + ], + ) + conn.commit() + + assert await session.pop_item() == valid_item + assert await session.get_items() == [] + + session.close() + + @pytest.mark.asyncio async def test_sqlite_session_get_items_with_limit(): """Test SQLiteSession get_items with limit parameter."""