diff --git a/src/google/adk/agents/context.py b/src/google/adk/agents/context.py index 70dfa05f59..8e0e23bb6d 100644 --- a/src/google/adk/agents/context.py +++ b/src/google/adk/agents/context.py @@ -136,14 +136,15 @@ async def load_artifact( async def save_artifact( self, filename: str, - artifact: types.Part, + artifact: types.Part | dict[str, Any], custom_metadata: dict[str, Any] | None = None, ) -> int: """Saves an artifact and records it as delta for the current session. Args: filename: The filename of the artifact. - artifact: The artifact to save. + artifact: The artifact to save. Can be a types.Part object or a + dict-shaped (serialized) artifact. custom_metadata: Custom metadata to associate with the artifact. Returns: diff --git a/src/google/adk/artifacts/base_artifact_service.py b/src/google/adk/artifacts/base_artifact_service.py index 1a265f8ad9..48fc7457fd 100644 --- a/src/google/adk/artifacts/base_artifact_service.py +++ b/src/google/adk/artifacts/base_artifact_service.py @@ -63,6 +63,23 @@ class ArtifactVersion(BaseModel): class BaseArtifactService(ABC): """Abstract base class for artifact services.""" + @staticmethod + def _convert_artifact_if_dict( + artifact: types.Part | dict[str, Any], + ) -> types.Part: + """Converts a dict-shaped artifact to types.Part if necessary. + + Args: + artifact: The artifact to convert. Can be a types.Part or dict. + + Returns: + A types.Part object. If input is already a Part, returns as-is. + If input is a dict, converts it to Part via model_validate. + """ + if isinstance(artifact, dict): + return types.Part.model_validate(artifact) + return artifact + @abstractmethod async def save_artifact( self, @@ -70,7 +87,7 @@ async def save_artifact( app_name: str, user_id: str, filename: str, - artifact: types.Part, + artifact: types.Part | dict[str, Any], session_id: Optional[str] = None, custom_metadata: Optional[dict[str, Any]] = None, ) -> int: @@ -84,10 +101,11 @@ async def save_artifact( app_name: The app name. user_id: The user ID. filename: The filename of the artifact. - artifact: The artifact to save. If the artifact consists of `file_data`, - the artifact service assumes its content has been uploaded separately, - and this method will associate the `file_data` with the artifact if - necessary. + artifact: The artifact to save. Can be a types.Part object or a + dict-shaped (serialized) artifact that will be converted to types.Part. + If the artifact consists of `file_data`, the artifact service assumes + its content has been uploaded separately, and this method will associate + the `file_data` with the artifact if necessary. session_id: The session ID. If `None`, the artifact is user-scoped. custom_metadata: custom metadata to associate with the artifact. diff --git a/src/google/adk/artifacts/file_artifact_service.py b/src/google/adk/artifacts/file_artifact_service.py index be5adb4818..37457b8b29 100644 --- a/src/google/adk/artifacts/file_artifact_service.py +++ b/src/google/adk/artifacts/file_artifact_service.py @@ -314,7 +314,7 @@ async def save_artifact( app_name: str, user_id: str, filename: str, - artifact: types.Part, + artifact: types.Part | dict[str, Any], session_id: Optional[str] = None, custom_metadata: Optional[dict[str, Any]] = None, ) -> int: @@ -326,6 +326,9 @@ async def save_artifact( computed scope root; absolute paths or inputs that traverse outside that root (for example ``"../../secret.txt"``) raise ``ValueError``. """ + # Convert dict-shaped artifact to types.Part if necessary + artifact = self._convert_artifact_if_dict(artifact) + return await asyncio.to_thread( self._save_artifact_sync, user_id, diff --git a/src/google/adk/artifacts/gcs_artifact_service.py b/src/google/adk/artifacts/gcs_artifact_service.py index 4108cfb06b..03692c3605 100644 --- a/src/google/adk/artifacts/gcs_artifact_service.py +++ b/src/google/adk/artifacts/gcs_artifact_service.py @@ -61,10 +61,13 @@ async def save_artifact( app_name: str, user_id: str, filename: str, - artifact: types.Part, + artifact: types.Part | dict[str, Any], session_id: Optional[str] = None, custom_metadata: Optional[dict[str, Any]] = None, ) -> int: + # Convert dict-shaped artifact to types.Part if necessary + artifact = self._convert_artifact_if_dict(artifact) + return await asyncio.to_thread( self._save_artifact, app_name, diff --git a/src/google/adk/artifacts/in_memory_artifact_service.py b/src/google/adk/artifacts/in_memory_artifact_service.py index 45552b1452..421b66ec42 100644 --- a/src/google/adk/artifacts/in_memory_artifact_service.py +++ b/src/google/adk/artifacts/in_memory_artifact_service.py @@ -99,10 +99,13 @@ async def save_artifact( app_name: str, user_id: str, filename: str, - artifact: types.Part, + artifact: types.Part | dict[str, Any], session_id: Optional[str] = None, custom_metadata: Optional[dict[str, Any]] = None, ) -> int: + # Convert dict-shaped artifact to types.Part if necessary + artifact = self._convert_artifact_if_dict(artifact) + path = self._artifact_path(app_name, user_id, filename, session_id) if path not in self.artifacts: self.artifacts[path] = [] diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index c61f855fc7..758e7640e2 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -232,8 +232,8 @@ class SaveArtifactRequest(common.BaseModel): """Request payload for saving a new artifact.""" filename: str = Field(description="Artifact filename.") - artifact: types.Part = Field( - description="Artifact payload encoded as google.genai.types.Part." + artifact: types.Part | dict[str, Any] = Field( + description="Artifact payload encoded as google.genai.types.Part or as a dict-shaped artifact." ) custom_metadata: Optional[dict[str, Any]] = Field( default=None, diff --git a/src/google/adk/tools/_forwarding_artifact_service.py b/src/google/adk/tools/_forwarding_artifact_service.py index 9667e8d4c3..48fbeb1aa8 100644 --- a/src/google/adk/tools/_forwarding_artifact_service.py +++ b/src/google/adk/tools/_forwarding_artifact_service.py @@ -42,10 +42,12 @@ async def save_artifact( app_name: str, user_id: str, filename: str, - artifact: types.Part, + artifact: types.Part | dict[str, Any], session_id: Optional[str] = None, custom_metadata: Optional[dict[str, Any]] = None, ) -> int: + # Delegate to parent tool context, which will handle conversion in the + # concrete artifact service implementation. return await self.tool_context.save_artifact( filename=filename, artifact=artifact, diff --git a/tests/unittests/artifacts/test_artifact_service.py b/tests/unittests/artifacts/test_artifact_service.py index ec74f8abe3..3b052b4d1e 100644 --- a/tests/unittests/artifacts/test_artifact_service.py +++ b/tests/unittests/artifacts/test_artifact_service.py @@ -766,3 +766,102 @@ async def test_file_save_artifact_rejects_absolute_path_within_scope(tmp_path): filename=str(absolute_in_scope), artifact=part, ) + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "service_type", + [ + ArtifactServiceType.IN_MEMORY, + ArtifactServiceType.GCS, + ArtifactServiceType.FILE, + ], +) +async def test_save_load_dict_shaped_artifact( + service_type, artifact_service_factory +): + """Tests saving and loading dict-shaped artifacts. + + This tests the fix for accepting dict-shaped (serialized) artifacts + in the save_artifact method. Dict-shaped artifacts are commonly used + when artifacts are stored/retrieved from JSON or other serialization formats. + """ + artifact_service = artifact_service_factory(service_type) + # Create a dict-shaped artifact by serializing a real Part instance + part = types.Part.from_bytes(data=b"test_data", mime_type="text/plain") + dict_artifact = part.model_dump(exclude_none=True) + + app_name = "app0" + user_id = "user0" + session_id = "123" + filename = "dict_file.txt" + + # Save the dict-shaped artifact + version = await artifact_service.save_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + artifact=dict_artifact, + ) + assert version == 0 + + # Load and verify the artifact + loaded = await artifact_service.load_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + ) + assert loaded is not None + assert loaded.inline_data is not None + assert loaded.inline_data.mime_type == "text/plain" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "service_type", + [ + ArtifactServiceType.IN_MEMORY, + ArtifactServiceType.GCS, + ArtifactServiceType.FILE, + ], +) +async def test_save_text_dict_shaped_artifact( + service_type, artifact_service_factory +): + """Tests saving and loading dict-shaped artifacts with text content.""" + artifact_service = artifact_service_factory(service_type) + # Create a dict-shaped artifact by serializing a real Part instance + part = types.Part(text="Hello, World!") + dict_artifact = part.model_dump(exclude_none=True) + + app_name = "app0" + user_id = "user0" + session_id = "123" + filename = "text_file.txt" + + # Save the dict-shaped artifact + await artifact_service.save_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + artifact=dict_artifact, + ) + + # Load and verify the artifact + loaded = await artifact_service.load_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + ) + assert loaded is not None + # GCS/File services may return text as inline_data bytes; accept either form. + if loaded.text is not None: + assert loaded.text == "Hello, World!" + else: + assert ( + loaded.inline_data is not None + and loaded.inline_data.data == b"Hello, World!" + ) \ No newline at end of file