diff --git a/src/agentex/lib/core/clients/temporal/temporal_client.py b/src/agentex/lib/core/clients/temporal/temporal_client.py index 76b419b2b..f44648da2 100644 --- a/src/agentex/lib/core/clients/temporal/temporal_client.py +++ b/src/agentex/lib/core/clients/temporal/temporal_client.py @@ -7,6 +7,7 @@ from temporalio.client import Client, WorkflowExecutionStatus from temporalio.common import RetryPolicy as TemporalRetryPolicy, WorkflowIDReusePolicy from temporalio.service import RPCError, RPCStatusCode +from temporalio.converter import PayloadCodec from agentex.lib.utils.logging import make_logger from agentex.lib.utils.model_utils import BaseModel @@ -76,9 +77,12 @@ class TemporalClient: - def __init__(self, temporal_client: Client | None = None, plugins: list[Any] = []): + def __init__( + self, temporal_client: Client | None = None, plugins: list[Any] = [], payload_codec: PayloadCodec | None = None + ): self._client: Client | None = temporal_client self._plugins = plugins + self._payload_codec = payload_codec @property def client(self) -> Client: @@ -88,7 +92,7 @@ def client(self) -> Client: return self._client @classmethod - async def create(cls, temporal_address: str, plugins: list[Any] = []): + async def create(cls, temporal_address: str, plugins: list[Any] = [], payload_codec: PayloadCodec | None = None): if temporal_address in [ "false", "False", @@ -101,8 +105,8 @@ async def create(cls, temporal_address: str, plugins: list[Any] = []): ]: _client = None else: - _client = await get_temporal_client(temporal_address, plugins=plugins) - return cls(_client, plugins) + _client = await get_temporal_client(temporal_address, plugins=plugins, payload_codec=payload_codec) + return cls(_client, plugins, payload_codec) async def setup(self, temporal_address: str): self._client = await self._get_temporal_client(temporal_address=temporal_address) @@ -120,7 +124,7 @@ async def _get_temporal_client(self, temporal_address: str) -> Client | None: ]: return None else: - return await get_temporal_client(temporal_address, plugins=self._plugins) + return await get_temporal_client(temporal_address, plugins=self._plugins, payload_codec=self._payload_codec) async def start_workflow( self, diff --git a/src/agentex/lib/core/clients/temporal/utils.py b/src/agentex/lib/core/clients/temporal/utils.py index 991e7cd1c..8c2241c62 100644 --- a/src/agentex/lib/core/clients/temporal/utils.py +++ b/src/agentex/lib/core/clients/temporal/utils.py @@ -1,10 +1,12 @@ from __future__ import annotations +import dataclasses from typing import Any from temporalio.client import Client, Plugin as ClientPlugin from temporalio.worker import Interceptor from temporalio.runtime import Runtime, TelemetryConfig, OpenTelemetryConfig +from temporalio.converter import PayloadCodec from temporalio.contrib.pydantic import pydantic_data_converter # class DateTimeJSONEncoder(AdvancedJSONEncoder): @@ -79,7 +81,12 @@ def validate_worker_interceptors(interceptors: list[Any]) -> None: ) -async def get_temporal_client(temporal_address: str, metrics_url: str | None = None, plugins: list[Any] = []) -> Client: +async def get_temporal_client( + temporal_address: str, + metrics_url: str | None = None, + plugins: list[Any] = [], + payload_codec: PayloadCodec | None = None, +) -> Client: """ Create a Temporal client with plugin integration. @@ -87,6 +94,7 @@ async def get_temporal_client(temporal_address: str, metrics_url: str | None = N temporal_address: Temporal server address metrics_url: Optional metrics endpoint URL plugins: List of Temporal plugins to include + payload_codec: Optional payload codec for encoding/decoding payloads (e.g. encryption, compression) Returns: Configured Temporal client @@ -98,18 +106,26 @@ async def get_temporal_client(temporal_address: str, metrics_url: str | None = N # Check if OpenAI plugin is present - it needs to configure its own data converter # Lazy import to avoid pulling in opentelemetry.sdk for non-Temporal agents from temporalio.contrib.openai_agents import OpenAIAgentsPlugin - has_openai_plugin = any( - isinstance(p, OpenAIAgentsPlugin) for p in (plugins or []) - ) - # Only set data_converter if OpenAI plugin is not present + has_openai_plugin = any(isinstance(p, OpenAIAgentsPlugin) for p in (plugins or [])) + + if has_openai_plugin and payload_codec is not None: + raise ValueError( + "payload_codec is not supported alongside OpenAIAgentsPlugin: the plugin " + "installs its own data converter and the codec would be silently ignored, " + "leaving payloads unencoded. Remove one or the other." + ) + connect_kwargs = { "target_host": temporal_address, "plugins": plugins, } if not has_openai_plugin: - connect_kwargs["data_converter"] = pydantic_data_converter + data_converter = pydantic_data_converter + if payload_codec: + data_converter = dataclasses.replace(data_converter, payload_codec=payload_codec) + connect_kwargs["data_converter"] = data_converter if not metrics_url: client = await Client.connect(**connect_kwargs) diff --git a/src/agentex/lib/core/temporal/workers/worker.py b/src/agentex/lib/core/temporal/workers/worker.py index eb284a5a2..2e8591242 100644 --- a/src/agentex/lib/core/temporal/workers/worker.py +++ b/src/agentex/lib/core/temporal/workers/worker.py @@ -18,6 +18,7 @@ ) from temporalio.runtime import Runtime, TelemetryConfig, OpenTelemetryConfig from temporalio.converter import ( + PayloadCodec, DataConverter, JSONTypeConverter, AdvancedJSONEncoder, @@ -89,16 +90,27 @@ def _validate_interceptors(interceptors: list) -> None: ) -async def get_temporal_client(temporal_address: str, metrics_url: str | None = None, plugins: list = []) -> Client: +async def get_temporal_client( + temporal_address: str, + metrics_url: str | None = None, + plugins: list = [], + payload_codec: PayloadCodec | None = None, +) -> Client: if plugins != []: # We don't need to validate the plugins if they are empty _validate_plugins(plugins) # Check if OpenAI plugin is present - it needs to configure its own data converter # Lazy import to avoid pulling in opentelemetry.sdk for non-Temporal agents from temporalio.contrib.openai_agents import OpenAIAgentsPlugin - has_openai_plugin = any( - isinstance(p, OpenAIAgentsPlugin) for p in (plugins or []) - ) + + has_openai_plugin = any(isinstance(p, OpenAIAgentsPlugin) for p in (plugins or [])) + + if has_openai_plugin and payload_codec is not None: + raise ValueError( + "payload_codec is not supported alongside OpenAIAgentsPlugin: the plugin " + "installs its own data converter and the codec would be silently ignored, " + "leaving payloads unencoded. Remove one or the other." + ) # Build connection kwargs connect_kwargs = { @@ -108,7 +120,10 @@ async def get_temporal_client(temporal_address: str, metrics_url: str | None = N # Only set data_converter if OpenAI plugin is not present if not has_openai_plugin: - connect_kwargs["data_converter"] = custom_data_converter + data_converter = custom_data_converter + if payload_codec: + data_converter = dataclasses.replace(data_converter, payload_codec=payload_codec) + connect_kwargs["data_converter"] = data_converter if not metrics_url: client = await Client.connect(**connect_kwargs) @@ -129,6 +144,7 @@ def __init__( plugins: list = [], interceptors: list = [], metrics_url: str | None = None, + payload_codec: PayloadCodec | None = None, ): self.task_queue = task_queue self.activity_handles = [] @@ -136,10 +152,13 @@ def __init__( self.max_concurrent_activities = max_concurrent_activities self.health_check_server_running = False self.healthy = False - self.health_check_port = health_check_port if health_check_port is not None else EnvironmentVariables.refresh().HEALTH_CHECK_PORT + self.health_check_port = ( + health_check_port if health_check_port is not None else EnvironmentVariables.refresh().HEALTH_CHECK_PORT + ) self.plugins = plugins self.interceptors = interceptors self.metrics_url = metrics_url + self.payload_codec = payload_codec @overload async def run( @@ -175,6 +194,7 @@ async def run( temporal_address=os.environ.get("TEMPORAL_ADDRESS", "localhost:7233"), plugins=self.plugins, metrics_url=self.metrics_url, + payload_codec=self.payload_codec, ) # Enable debug mode if AgentEx debug is enabled (disables deadlock detection) diff --git a/src/agentex/lib/sdk/fastacp/fastacp.py b/src/agentex/lib/sdk/fastacp/fastacp.py index 9e3ae78ec..fbd4f0511 100644 --- a/src/agentex/lib/sdk/fastacp/fastacp.py +++ b/src/agentex/lib/sdk/fastacp/fastacp.py @@ -34,7 +34,7 @@ class FastACP: Supports three main ACP types: - "sync": Simple synchronous ACP implementation - "async": Advanced ACP with sub-types "base" or "temporal" (requires config) - - "agentic": (Deprecated, use "async") Identical to "async" + - "agentic": (Deprecated, use "async") Identical to "async" """ @staticmethod @@ -63,6 +63,8 @@ def create_async_acp(config: AsyncACPConfig, **kwargs) -> BaseACPServer: temporal_config["plugins"] = config.plugins # type: ignore[attr-defined] if hasattr(config, "interceptors"): temporal_config["interceptors"] = config.interceptors # type: ignore[attr-defined] + if hasattr(config, "payload_codec"): + temporal_config["payload_codec"] = config.payload_codec # type: ignore[attr-defined] return implementation_class.create(**temporal_config) else: return implementation_class.create(**kwargs) diff --git a/src/agentex/lib/sdk/fastacp/impl/temporal_acp.py b/src/agentex/lib/sdk/fastacp/impl/temporal_acp.py index 750707c49..f64e16d72 100644 --- a/src/agentex/lib/sdk/fastacp/impl/temporal_acp.py +++ b/src/agentex/lib/sdk/fastacp/impl/temporal_acp.py @@ -4,6 +4,7 @@ from contextlib import asynccontextmanager from fastapi import FastAPI +from temporalio.converter import PayloadCodec from agentex.lib.types.acp import ( SendEventParams, @@ -31,20 +32,30 @@ def __init__( temporal_task_service: TemporalTaskService | None = None, plugins: list[Any] | None = None, interceptors: list[Any] | None = None, + payload_codec: PayloadCodec | None = None, ): super().__init__() self._temporal_task_service = temporal_task_service self._temporal_address = temporal_address self._plugins = plugins or [] self._interceptors = interceptors or [] + self._payload_codec = payload_codec @classmethod @override - def create(cls, temporal_address: str, plugins: list[Any] | None = None, interceptors: list[Any] | None = None) -> "TemporalACP": + def create( + cls, + temporal_address: str, + plugins: list[Any] | None = None, + interceptors: list[Any] | None = None, + payload_codec: PayloadCodec | None = None, + ) -> "TemporalACP": logger.info("Initializing TemporalACP instance") # Create instance without temporal client initially - temporal_acp = cls(temporal_address=temporal_address, plugins=plugins, interceptors=interceptors) + temporal_acp = cls( + temporal_address=temporal_address, plugins=plugins, interceptors=interceptors, payload_codec=payload_codec + ) temporal_acp._setup_handlers() logger.info("TemporalACP instance initialized now") return temporal_acp @@ -60,7 +71,7 @@ async def lifespan(app: FastAPI): if self._temporal_task_service is None: env_vars = EnvironmentVariables.refresh() temporal_client = await TemporalClient.create( - temporal_address=self._temporal_address, plugins=self._plugins + temporal_address=self._temporal_address, plugins=self._plugins, payload_codec=self._payload_codec ) self._temporal_task_service = TemporalTaskService( temporal_client=temporal_client, diff --git a/src/agentex/lib/types/fastacp.py b/src/agentex/lib/types/fastacp.py index c589a0c99..e11091e93 100644 --- a/src/agentex/lib/types/fastacp.py +++ b/src/agentex/lib/types/fastacp.py @@ -39,8 +39,10 @@ class AsyncACPConfig(BaseACPConfig): type: Literal["temporal", "base"] = Field(..., frozen=True) + AgenticACPConfig = AsyncACPConfig + class TemporalACPConfig(AsyncACPConfig): """ Configuration for TemporalACP implementation @@ -50,12 +52,18 @@ class TemporalACPConfig(AsyncACPConfig): temporal_address: The address of the temporal server plugins: List of Temporal client plugins interceptors: List of Temporal worker interceptors + payload_codec: Optional ``temporalio.converter.PayloadCodec`` for + encoding/decoding payloads (e.g. encryption, compression). NOTE: + this only configures the ACP (client) side. The worker side must + be configured separately via ``AgentexWorker(payload_codec=...)`` + with the SAME codec, or decode will fail at runtime. """ type: Literal["temporal"] = Field(default="temporal", frozen=True) temporal_address: str = Field(default="temporal-frontend.temporal.svc.cluster.local:7233", frozen=True) plugins: list[Any] = Field(default=[], frozen=True) interceptors: list[Any] = Field(default=[], frozen=True) + payload_codec: Any = Field(default=None, frozen=True) @field_validator("plugins") @classmethod @@ -81,4 +89,5 @@ class AsyncBaseACPConfig(AsyncACPConfig): type: Literal["base"] = Field(default="base", frozen=True) -AgenticBaseACPConfig = AsyncBaseACPConfig \ No newline at end of file + +AgenticBaseACPConfig = AsyncBaseACPConfig diff --git a/tests/lib/test_payload_codec.py b/tests/lib/test_payload_codec.py new file mode 100644 index 000000000..bb2b24228 --- /dev/null +++ b/tests/lib/test_payload_codec.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +from typing import Any, override +from unittest.mock import AsyncMock, patch + +import pytest +from temporalio.client import Client, Plugin as ClientPlugin +from temporalio.converter import PayloadCodec +from temporalio.contrib.pydantic import pydantic_data_converter + + +class _NoopCodec(PayloadCodec): + @override + async def encode(self, payloads): + return list(payloads) + + @override + async def decode(self, payloads): + return list(payloads) + + +class _FakeOpenAIPlugin(ClientPlugin): + @override + def configure_client(self, config): + return config + + @override + async def connect_service_client(self, config, next): + return await next(config) + + +def _mock_connect(): + return patch.object(Client, "connect", new=AsyncMock(return_value=object())) + + +def _patch_openai_plugin(): + return patch("temporalio.contrib.openai_agents.OpenAIAgentsPlugin", _FakeOpenAIPlugin) + + +class TestTemporalClient: + def test_init_stores_payload_codec(self): + from agentex.lib.core.clients.temporal.temporal_client import TemporalClient + + codec = _NoopCodec() + client = TemporalClient(payload_codec=codec) + assert client._payload_codec is codec + + def test_init_default_payload_codec_is_none(self): + from agentex.lib.core.clients.temporal.temporal_client import TemporalClient + + assert TemporalClient()._payload_codec is None + + async def test_create_with_disabled_address_stores_codec(self): + from agentex.lib.core.clients.temporal.temporal_client import TemporalClient + + codec = _NoopCodec() + client = await TemporalClient.create(temporal_address="false", payload_codec=codec) + assert client._client is None + assert client._payload_codec is codec + + async def test_create_propagates_codec_to_get_temporal_client(self): + import agentex.lib.core.clients.temporal.temporal_client as module + + codec = _NoopCodec() + with patch.object(module, "get_temporal_client", new=AsyncMock(return_value=object())) as mock_get: + await module.TemporalClient.create(temporal_address="localhost:7233", plugins=[], payload_codec=codec) + + mock_get.assert_awaited_once() + assert mock_get.await_args.kwargs["payload_codec"] is codec + + +class TestGetTemporalClientUtils: + async def test_no_codec_uses_pydantic_data_converter_unchanged(self): + from agentex.lib.core.clients.temporal.utils import get_temporal_client + + with _mock_connect() as mock_connect: + await get_temporal_client(temporal_address="localhost:7233") + + kwargs = mock_connect.await_args.kwargs + assert kwargs["data_converter"] is pydantic_data_converter + assert kwargs["data_converter"].payload_codec is None + + async def test_codec_is_attached_to_pydantic_data_converter(self): + from agentex.lib.core.clients.temporal.utils import get_temporal_client + + codec = _NoopCodec() + with _mock_connect() as mock_connect: + await get_temporal_client(temporal_address="localhost:7233", payload_codec=codec) + + data_converter = mock_connect.await_args.kwargs["data_converter"] + assert data_converter.payload_codec is codec + assert data_converter.payload_converter_class is pydantic_data_converter.payload_converter_class + + async def test_codec_with_openai_plugin_raises(self): + from agentex.lib.core.clients.temporal.utils import get_temporal_client + + codec = _NoopCodec() + with _patch_openai_plugin(), _mock_connect() as mock_connect: + with pytest.raises(ValueError, match="payload_codec is not supported alongside OpenAIAgentsPlugin"): + await get_temporal_client( + temporal_address="localhost:7233", + plugins=[_FakeOpenAIPlugin()], + payload_codec=codec, + ) + mock_connect.assert_not_awaited() + + async def test_openai_plugin_without_codec_omits_data_converter(self): + from agentex.lib.core.clients.temporal.utils import get_temporal_client + + with _patch_openai_plugin(), _mock_connect() as mock_connect: + await get_temporal_client(temporal_address="localhost:7233", plugins=[_FakeOpenAIPlugin()]) + + assert "data_converter" not in mock_connect.await_args.kwargs + + +class TestGetTemporalClientWorker: + async def test_no_codec_uses_custom_data_converter_unchanged(self): + from agentex.lib.core.temporal.workers.worker import get_temporal_client, custom_data_converter + + with _mock_connect() as mock_connect: + await get_temporal_client(temporal_address="localhost:7233") + + kwargs = mock_connect.await_args.kwargs + assert kwargs["data_converter"] is custom_data_converter + assert kwargs["data_converter"].payload_codec is None + + async def test_codec_is_attached_to_custom_data_converter(self): + from agentex.lib.core.temporal.workers.worker import get_temporal_client, custom_data_converter + + codec = _NoopCodec() + with _mock_connect() as mock_connect: + await get_temporal_client(temporal_address="localhost:7233", payload_codec=codec) + + data_converter = mock_connect.await_args.kwargs["data_converter"] + assert data_converter.payload_codec is codec + assert data_converter.payload_converter_class is custom_data_converter.payload_converter_class + + async def test_codec_with_openai_plugin_raises(self): + from agentex.lib.core.temporal.workers.worker import get_temporal_client + + codec = _NoopCodec() + with _patch_openai_plugin(), _mock_connect() as mock_connect: + with pytest.raises(ValueError, match="payload_codec is not supported alongside OpenAIAgentsPlugin"): + await get_temporal_client( + temporal_address="localhost:7233", + plugins=[_FakeOpenAIPlugin()], + payload_codec=codec, + ) + mock_connect.assert_not_awaited() + + async def test_openai_plugin_without_codec_omits_data_converter(self): + from agentex.lib.core.temporal.workers.worker import get_temporal_client + + with _patch_openai_plugin(), _mock_connect() as mock_connect: + await get_temporal_client(temporal_address="localhost:7233", plugins=[_FakeOpenAIPlugin()]) + + assert "data_converter" not in mock_connect.await_args.kwargs + + +class TestAgentexWorkerCodec: + def test_worker_stores_payload_codec(self): + from agentex.lib.core.temporal.workers.worker import AgentexWorker + + codec = _NoopCodec() + worker = AgentexWorker(task_queue="test-queue", health_check_port=80, payload_codec=codec) + assert worker.payload_codec is codec + + def test_worker_default_payload_codec_is_none(self): + from agentex.lib.core.temporal.workers.worker import AgentexWorker + + worker = AgentexWorker(task_queue="test-queue", health_check_port=80) + assert worker.payload_codec is None + + +class TestTemporalACPCodec: + def test_create_stores_payload_codec(self): + from agentex.lib.sdk.fastacp.impl.temporal_acp import TemporalACP + + codec = _NoopCodec() + acp = TemporalACP.create(temporal_address="localhost:7233", payload_codec=codec) + assert acp._payload_codec is codec + + def test_create_default_payload_codec_is_none(self): + from agentex.lib.sdk.fastacp.impl.temporal_acp import TemporalACP + + acp = TemporalACP.create(temporal_address="localhost:7233") + assert acp._payload_codec is None + + +class TestFastACPConfigCodec: + def test_config_default_codec_is_none(self): + from agentex.lib.types.fastacp import TemporalACPConfig + + assert TemporalACPConfig().payload_codec is None + + def test_config_accepts_codec(self): + from agentex.lib.types.fastacp import TemporalACPConfig + + codec = _NoopCodec() + assert TemporalACPConfig(payload_codec=codec).payload_codec is codec + + def test_fastacp_forwards_codec_from_config(self): + from agentex.lib.types.fastacp import TemporalACPConfig + from agentex.lib.sdk.fastacp.fastacp import FastACP + + codec = _NoopCodec() + config = TemporalACPConfig(payload_codec=codec) + captured: dict[str, Any] = {} + + def fake_create(**kwargs): + captured.update(kwargs) + return object() + + with patch( + "agentex.lib.sdk.fastacp.impl.temporal_acp.TemporalACP.create", + side_effect=fake_create, + ): + FastACP.create("async", config=config) + + assert captured.get("payload_codec") is codec