Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions src/agentex/lib/core/clients/temporal/temporal_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from temporalio.client import Client, WorkflowExecutionStatus
from temporalio.common import RetryPolicy as TemporalRetryPolicy, WorkflowIDReusePolicy
from temporalio.service import RPCError, RPCStatusCode
from temporalio.converter import PayloadCodec

from agentex.lib.utils.logging import make_logger
from agentex.lib.utils.model_utils import BaseModel
Expand Down Expand Up @@ -76,9 +77,12 @@


class TemporalClient:
def __init__(self, temporal_client: Client | None = None, plugins: list[Any] = []):
def __init__(
self, temporal_client: Client | None = None, plugins: list[Any] = [], payload_codec: PayloadCodec | None = None
):
self._client: Client | None = temporal_client
self._plugins = plugins
self._payload_codec = payload_codec

@property
def client(self) -> Client:
Expand All @@ -88,7 +92,7 @@ def client(self) -> Client:
return self._client

@classmethod
async def create(cls, temporal_address: str, plugins: list[Any] = []):
async def create(cls, temporal_address: str, plugins: list[Any] = [], payload_codec: PayloadCodec | None = None):
if temporal_address in [
"false",
"False",
Expand All @@ -101,8 +105,8 @@ async def create(cls, temporal_address: str, plugins: list[Any] = []):
]:
_client = None
else:
_client = await get_temporal_client(temporal_address, plugins=plugins)
return cls(_client, plugins)
_client = await get_temporal_client(temporal_address, plugins=plugins, payload_codec=payload_codec)
return cls(_client, plugins, payload_codec)

async def setup(self, temporal_address: str):
self._client = await self._get_temporal_client(temporal_address=temporal_address)
Expand All @@ -120,7 +124,7 @@ async def _get_temporal_client(self, temporal_address: str) -> Client | None:
]:
return None
else:
return await get_temporal_client(temporal_address, plugins=self._plugins)
return await get_temporal_client(temporal_address, plugins=self._plugins, payload_codec=self._payload_codec)

async def start_workflow(
self,
Expand Down
28 changes: 22 additions & 6 deletions src/agentex/lib/core/clients/temporal/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import dataclasses
from typing import Any

from temporalio.client import Client, Plugin as ClientPlugin
from temporalio.worker import Interceptor
from temporalio.runtime import Runtime, TelemetryConfig, OpenTelemetryConfig
from temporalio.converter import PayloadCodec
from temporalio.contrib.pydantic import pydantic_data_converter

# class DateTimeJSONEncoder(AdvancedJSONEncoder):
Expand Down Expand Up @@ -79,14 +81,20 @@ def validate_worker_interceptors(interceptors: list[Any]) -> None:
)


async def get_temporal_client(temporal_address: str, metrics_url: str | None = None, plugins: list[Any] = []) -> Client:
async def get_temporal_client(
temporal_address: str,
metrics_url: str | None = None,
plugins: list[Any] = [],
payload_codec: PayloadCodec | None = None,
) -> Client:
"""
Create a Temporal client with plugin integration.

Args:
temporal_address: Temporal server address
metrics_url: Optional metrics endpoint URL
plugins: List of Temporal plugins to include
payload_codec: Optional payload codec for encoding/decoding payloads (e.g. encryption, compression)

Returns:
Configured Temporal client
Expand All @@ -98,18 +106,26 @@ async def get_temporal_client(temporal_address: str, metrics_url: str | None = N
# Check if OpenAI plugin is present - it needs to configure its own data converter
# Lazy import to avoid pulling in opentelemetry.sdk for non-Temporal agents
from temporalio.contrib.openai_agents import OpenAIAgentsPlugin
has_openai_plugin = any(
isinstance(p, OpenAIAgentsPlugin) for p in (plugins or [])
)

# Only set data_converter if OpenAI plugin is not present
has_openai_plugin = any(isinstance(p, OpenAIAgentsPlugin) for p in (plugins or []))

if has_openai_plugin and payload_codec is not None:
raise ValueError(
"payload_codec is not supported alongside OpenAIAgentsPlugin: the plugin "
"installs its own data converter and the codec would be silently ignored, "
"leaving payloads unencoded. Remove one or the other."
)

Comment thread
declan-scale marked this conversation as resolved.
connect_kwargs = {
"target_host": temporal_address,
"plugins": plugins,
}

if not has_openai_plugin:
connect_kwargs["data_converter"] = pydantic_data_converter
data_converter = pydantic_data_converter
if payload_codec:
data_converter = dataclasses.replace(data_converter, payload_codec=payload_codec)
connect_kwargs["data_converter"] = data_converter
Comment thread
greptile-apps[bot] marked this conversation as resolved.

if not metrics_url:
client = await Client.connect(**connect_kwargs)
Expand Down
32 changes: 26 additions & 6 deletions src/agentex/lib/core/temporal/workers/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from temporalio.runtime import Runtime, TelemetryConfig, OpenTelemetryConfig
from temporalio.converter import (
PayloadCodec,
DataConverter,
JSONTypeConverter,
AdvancedJSONEncoder,
Expand Down Expand Up @@ -89,16 +90,27 @@ def _validate_interceptors(interceptors: list) -> None:
)


async def get_temporal_client(temporal_address: str, metrics_url: str | None = None, plugins: list = []) -> Client:
async def get_temporal_client(
temporal_address: str,
metrics_url: str | None = None,
plugins: list = [],
payload_codec: PayloadCodec | None = None,
) -> Client:
if plugins != []: # We don't need to validate the plugins if they are empty
_validate_plugins(plugins)

# Check if OpenAI plugin is present - it needs to configure its own data converter
# Lazy import to avoid pulling in opentelemetry.sdk for non-Temporal agents
from temporalio.contrib.openai_agents import OpenAIAgentsPlugin
has_openai_plugin = any(
isinstance(p, OpenAIAgentsPlugin) for p in (plugins or [])
)

has_openai_plugin = any(isinstance(p, OpenAIAgentsPlugin) for p in (plugins or []))

if has_openai_plugin and payload_codec is not None:
raise ValueError(
"payload_codec is not supported alongside OpenAIAgentsPlugin: the plugin "
"installs its own data converter and the codec would be silently ignored, "
"leaving payloads unencoded. Remove one or the other."
)

# Build connection kwargs
connect_kwargs = {
Expand All @@ -108,7 +120,10 @@ async def get_temporal_client(temporal_address: str, metrics_url: str | None = N

# Only set data_converter if OpenAI plugin is not present
if not has_openai_plugin:
connect_kwargs["data_converter"] = custom_data_converter
data_converter = custom_data_converter
if payload_codec:
data_converter = dataclasses.replace(data_converter, payload_codec=payload_codec)
connect_kwargs["data_converter"] = data_converter
Comment thread
greptile-apps[bot] marked this conversation as resolved.

if not metrics_url:
client = await Client.connect(**connect_kwargs)
Expand All @@ -129,17 +144,21 @@ def __init__(
plugins: list = [],
interceptors: list = [],
metrics_url: str | None = None,
payload_codec: PayloadCodec | None = None,
):
self.task_queue = task_queue
self.activity_handles = []
self.max_workers = max_workers
self.max_concurrent_activities = max_concurrent_activities
self.health_check_server_running = False
self.healthy = False
self.health_check_port = health_check_port if health_check_port is not None else EnvironmentVariables.refresh().HEALTH_CHECK_PORT
self.health_check_port = (
health_check_port if health_check_port is not None else EnvironmentVariables.refresh().HEALTH_CHECK_PORT
)
self.plugins = plugins
self.interceptors = interceptors
self.metrics_url = metrics_url
self.payload_codec = payload_codec

@overload
async def run(
Expand Down Expand Up @@ -175,6 +194,7 @@ async def run(
temporal_address=os.environ.get("TEMPORAL_ADDRESS", "localhost:7233"),
plugins=self.plugins,
metrics_url=self.metrics_url,
payload_codec=self.payload_codec,
)

# Enable debug mode if AgentEx debug is enabled (disables deadlock detection)
Expand Down
4 changes: 3 additions & 1 deletion src/agentex/lib/sdk/fastacp/fastacp.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class FastACP:
Supports three main ACP types:
- "sync": Simple synchronous ACP implementation
- "async": Advanced ACP with sub-types "base" or "temporal" (requires config)
- "agentic": (Deprecated, use "async") Identical to "async"
- "agentic": (Deprecated, use "async") Identical to "async"
"""

@staticmethod
Expand Down Expand Up @@ -63,6 +63,8 @@ def create_async_acp(config: AsyncACPConfig, **kwargs) -> BaseACPServer:
temporal_config["plugins"] = config.plugins # type: ignore[attr-defined]
if hasattr(config, "interceptors"):
temporal_config["interceptors"] = config.interceptors # type: ignore[attr-defined]
if hasattr(config, "payload_codec"):
temporal_config["payload_codec"] = config.payload_codec # type: ignore[attr-defined]
return implementation_class.create(**temporal_config)
else:
return implementation_class.create(**kwargs)
Expand Down
17 changes: 14 additions & 3 deletions src/agentex/lib/sdk/fastacp/impl/temporal_acp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from contextlib import asynccontextmanager

from fastapi import FastAPI
from temporalio.converter import PayloadCodec

from agentex.lib.types.acp import (
SendEventParams,
Expand Down Expand Up @@ -31,20 +32,30 @@ def __init__(
temporal_task_service: TemporalTaskService | None = None,
plugins: list[Any] | None = None,
interceptors: list[Any] | None = None,
payload_codec: PayloadCodec | None = None,
):
super().__init__()
self._temporal_task_service = temporal_task_service
self._temporal_address = temporal_address
self._plugins = plugins or []
self._interceptors = interceptors or []
self._payload_codec = payload_codec

@classmethod
@override
def create(cls, temporal_address: str, plugins: list[Any] | None = None, interceptors: list[Any] | None = None) -> "TemporalACP":
def create(
cls,
temporal_address: str,
plugins: list[Any] | None = None,
interceptors: list[Any] | None = None,
payload_codec: PayloadCodec | None = None,
) -> "TemporalACP":
logger.info("Initializing TemporalACP instance")

# Create instance without temporal client initially
temporal_acp = cls(temporal_address=temporal_address, plugins=plugins, interceptors=interceptors)
temporal_acp = cls(
temporal_address=temporal_address, plugins=plugins, interceptors=interceptors, payload_codec=payload_codec
)
temporal_acp._setup_handlers()
logger.info("TemporalACP instance initialized now")
return temporal_acp
Expand All @@ -60,7 +71,7 @@ async def lifespan(app: FastAPI):
if self._temporal_task_service is None:
env_vars = EnvironmentVariables.refresh()
temporal_client = await TemporalClient.create(
temporal_address=self._temporal_address, plugins=self._plugins
temporal_address=self._temporal_address, plugins=self._plugins, payload_codec=self._payload_codec
)
self._temporal_task_service = TemporalTaskService(
temporal_client=temporal_client,
Expand Down
11 changes: 10 additions & 1 deletion src/agentex/lib/types/fastacp.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ class AsyncACPConfig(BaseACPConfig):

type: Literal["temporal", "base"] = Field(..., frozen=True)


AgenticACPConfig = AsyncACPConfig


class TemporalACPConfig(AsyncACPConfig):
"""
Configuration for TemporalACP implementation
Expand All @@ -50,12 +52,18 @@ class TemporalACPConfig(AsyncACPConfig):
temporal_address: The address of the temporal server
plugins: List of Temporal client plugins
interceptors: List of Temporal worker interceptors
payload_codec: Optional ``temporalio.converter.PayloadCodec`` for
encoding/decoding payloads (e.g. encryption, compression). NOTE:
this only configures the ACP (client) side. The worker side must
be configured separately via ``AgentexWorker(payload_codec=...)``
with the SAME codec, or decode will fail at runtime.
"""

type: Literal["temporal"] = Field(default="temporal", frozen=True)
temporal_address: str = Field(default="temporal-frontend.temporal.svc.cluster.local:7233", frozen=True)
plugins: list[Any] = Field(default=[], frozen=True)
interceptors: list[Any] = Field(default=[], frozen=True)
payload_codec: Any = Field(default=None, frozen=True)

@field_validator("plugins")
@classmethod
Expand All @@ -81,4 +89,5 @@ class AsyncBaseACPConfig(AsyncACPConfig):

type: Literal["base"] = Field(default="base", frozen=True)

AgenticBaseACPConfig = AsyncBaseACPConfig

AgenticBaseACPConfig = AsyncBaseACPConfig
Loading
Loading