Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ async def test_send_event_and_stream(self, client: AsyncAgentex, agent_id: str):
task = task_response.result
assert task is not None

await asyncio.sleep(1) # wait for state to be initialized
# Check initial state
states = await client.states.list(agent_id=agent_id, task_id=task.id)
assert len(states) == 1
Expand Down
16 changes: 14 additions & 2 deletions src/agentex/lib/adk/_modules/messages.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# ruff: noqa: I001
# Import order matters - AsyncTracer must come after client import to avoid circular imports
from __future__ import annotations
from datetime import timedelta
from datetime import datetime, timedelta

from temporalio.common import RetryPolicy

Expand All @@ -22,7 +22,7 @@
from agentex.lib.core.tracing.tracer import AsyncTracer
from agentex.types.task_message import TaskMessage, TaskMessageContent
from agentex.lib.utils.logging import make_logger
from agentex.lib.utils.temporal import in_temporal_workflow
from agentex.lib.utils.temporal import in_temporal_workflow, workflow_now_if_in_workflow

logger = make_logger(__name__)

Expand Down Expand Up @@ -66,6 +66,7 @@ async def create(
start_to_close_timeout: timedelta = timedelta(seconds=5),
heartbeat_timeout: timedelta = timedelta(seconds=5),
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
created_at: datetime | None = None,
) -> TaskMessage:
"""
Create a new message for a task.
Expand All @@ -82,12 +83,17 @@ async def create(
Returns:
TaskMessageEntity: The created message.
"""
# Default created_at to workflow.now() so two awaited adk.messages.create
# calls from the same workflow are guaranteed monotonic at the server.
if created_at is None:
created_at = workflow_now_if_in_workflow()
params = CreateMessageParams(
trace_id=trace_id,
parent_span_id=parent_span_id,
task_id=task_id,
content=content,
emit_updates=emit_updates,
created_at=created_at,
)
if in_temporal_workflow():
return await ActivityHelpers.execute_activity(
Expand All @@ -103,6 +109,7 @@ async def create(
task_id=task_id,
content=content,
emit_updates=emit_updates,
created_at=created_at,
)

async def update(
Expand Down Expand Up @@ -163,6 +170,7 @@ async def create_batch(
start_to_close_timeout: timedelta = timedelta(seconds=5),
heartbeat_timeout: timedelta = timedelta(seconds=5),
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
created_at: datetime | None = None,
) -> list[TaskMessage]:
"""
Create a batch of messages for a task.
Expand All @@ -177,12 +185,15 @@ async def create_batch(
Returns:
List[TaskMessageEntity]: The created messages.
"""
if created_at is None:
created_at = workflow_now_if_in_workflow()
params = CreateMessagesBatchParams(
task_id=task_id,
contents=contents,
emit_updates=emit_updates,
trace_id=trace_id,
parent_span_id=parent_span_id,
created_at=created_at,
)
if in_temporal_workflow():
return await ActivityHelpers.execute_activity(
Expand All @@ -198,6 +209,7 @@ async def create_batch(
task_id=task_id,
contents=contents,
emit_updates=emit_updates,
created_at=created_at,
)

async def update_batch(
Expand Down
3 changes: 3 additions & 0 deletions src/agentex/lib/adk/_modules/streaming.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# ruff: noqa: I001
# Import order matters - AsyncTracer must come after client import to avoid circular imports
from __future__ import annotations
from datetime import datetime
from temporalio.common import RetryPolicy

from agentex import AsyncAgentex # noqa: F401
Expand Down Expand Up @@ -52,6 +53,7 @@ def streaming_task_message_context(
task_id: str,
initial_content: TaskMessageContent,
streaming_mode: StreamingMode = "coalesced",
created_at: datetime | None = None,
) -> StreamingTaskMessageContext:
"""
Create a streaming context for managing TaskMessage lifecycle.
Expand Down Expand Up @@ -83,4 +85,5 @@ def streaming_task_message_context(
task_id=task_id,
initial_content=initial_content,
streaming_mode=streaming_mode,
created_at=created_at,
)
8 changes: 4 additions & 4 deletions src/agentex/lib/adk/providers/_modules/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from temporalio.common import RetryPolicy

from agentex.lib.utils.logging import make_logger
from agentex.lib.utils.temporal import in_temporal_workflow
from agentex.lib.utils.temporal import in_temporal_workflow, workflow_now_if_in_workflow
from agentex.types.task_message import TaskMessage
from agentex.lib.types.llm_messages import LLMConfig, Completion
from agentex.lib.core.tracing.tracer import AsyncTracer
Expand Down Expand Up @@ -88,9 +88,7 @@ async def chat_completion(
Completion: An OpenAI compatible Completion object
"""
if in_temporal_workflow():
params = ChatCompletionParams(
trace_id=trace_id, parent_span_id=parent_span_id, llm_config=llm_config
)
params = ChatCompletionParams(trace_id=trace_id, parent_span_id=parent_span_id, llm_config=llm_config)
return await ActivityHelpers.execute_activity(
activity_name=LiteLLMActivityName.CHAT_COMPLETION,
request=params,
Expand Down Expand Up @@ -138,6 +136,7 @@ async def chat_completion_auto_send(
parent_span_id=parent_span_id,
task_id=task_id,
llm_config=llm_config,
created_at=workflow_now_if_in_workflow(),
)
return await ActivityHelpers.execute_activity(
activity_name=LiteLLMActivityName.CHAT_COMPLETION_AUTO_SEND,
Expand Down Expand Up @@ -222,6 +221,7 @@ async def chat_completion_stream_auto_send(
parent_span_id=parent_span_id,
task_id=task_id,
llm_config=llm_config,
created_at=workflow_now_if_in_workflow(),
)
return await ActivityHelpers.execute_activity(
activity_name=LiteLLMActivityName.CHAT_COMPLETION_STREAM_AUTO_SEND,
Expand Down
6 changes: 4 additions & 2 deletions src/agentex/lib/adk/providers/_modules/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing_extensions import deprecated

from agentex.lib.utils.logging import make_logger
from agentex.lib.utils.temporal import in_temporal_workflow
from agentex.lib.utils.temporal import in_temporal_workflow, workflow_now_if_in_workflow
from agentex.lib.core.tracing.tracer import AsyncTracer
from agentex.lib.types.agent_results import (
SerializableRunResult,
Expand Down Expand Up @@ -265,6 +265,7 @@ async def run_agent_auto_send(
output_guardrails=output_guardrails, # type: ignore[arg-type]
max_turns=max_turns,
previous_response_id=previous_response_id,
created_at=workflow_now_if_in_workflow(),
)
return await ActivityHelpers.execute_activity(
activity_name=OpenAIActivityName.RUN_AGENT_AUTO_SEND,
Expand Down Expand Up @@ -479,6 +480,7 @@ async def run_agent_streamed_auto_send(
input_guardrails=input_guardrails,
output_guardrails=output_guardrails,
max_turns=max_turns,
created_at=workflow_now_if_in_workflow(),
)
return await ActivityHelpers.execute_activity(
activity_name=OpenAIActivityName.RUN_AGENT_STREAMED_AUTO_SEND,
Expand Down Expand Up @@ -509,4 +511,4 @@ async def run_agent_streamed_auto_send(
output_guardrails=output_guardrails,
max_turns=max_turns,
previous_response_id=previous_response_id,
)
)
11 changes: 7 additions & 4 deletions src/agentex/lib/core/services/adk/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import asyncio
from typing import Any, Coroutine
from datetime import datetime

from agentex import AsyncAgentex
from agentex._types import omit
from agentex.lib.utils.logging import make_logger
from agentex.lib.utils.temporal import heartbeat_if_in_workflow
from agentex.types.task_message import TaskMessage, TaskMessageContent
Expand Down Expand Up @@ -32,6 +34,7 @@ async def create_message(
emit_updates: bool = True,
trace_id: str | None = None,
parent_span_id: str | None = None,
created_at: datetime | None = None,
) -> TaskMessage:
trace = self._tracer.trace(trace_id)
async with trace.span(
Expand All @@ -43,6 +46,7 @@ async def create_message(
task_message = await self._agentex_client.messages.create(
task_id=task_id,
content=content.model_dump(),
created_at=created_at if created_at is not None else omit,
)
if emit_updates:
await self._emit_updates([task_message])
Expand Down Expand Up @@ -85,6 +89,7 @@ async def create_messages_batch(
emit_updates: bool = True,
trace_id: str | None = None,
parent_span_id: str | None = None,
created_at: datetime | None = None,
) -> list[TaskMessage]:
trace = self._tracer.trace(trace_id)
async with trace.span(
Expand All @@ -96,6 +101,7 @@ async def create_messages_batch(
task_messages = await self._agentex_client.messages.batch.create(
task_id=task_id,
contents=[content.model_dump() for content in contents],
created_at=created_at if created_at is not None else omit,
)
if emit_updates:
await self._emit_updates(task_messages)
Expand All @@ -119,10 +125,7 @@ async def update_messages_batch(
heartbeat_if_in_workflow("update messages batch")
task_messages = await self._agentex_client.messages.batch.update(
task_id=task_id,
updates={
message_id: content.model_dump()
for message_id, content in updates.items()
},
updates={message_id: content.model_dump() for message_id, content in updates.items()},
)
if span:
span.output = [task_message.model_dump() for task_message in task_messages]
Expand Down
31 changes: 10 additions & 21 deletions src/agentex/lib/core/services/adk/providers/litellm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from datetime import datetime
from collections.abc import AsyncGenerator

from agentex import AsyncAgentex
Expand Down Expand Up @@ -63,6 +64,7 @@ async def chat_completion_auto_send(
llm_config: LLMConfig,
trace_id: str | None = None,
parent_span_id: str | None = None,
created_at: datetime | None = None,
) -> TaskMessage | None:
"""
Chat completion with automatic TaskMessage creation. This does not stream the completion. To stream use chat_completion_stream_auto_send.
Expand Down Expand Up @@ -98,13 +100,10 @@ async def chat_completion_auto_send(
content="",
format="markdown",
),
created_at=created_at,
) as streaming_context:
completion = await self.llm_gateway.acompletion(**llm_config.model_dump())
if (
completion.choices
and len(completion.choices) > 0
and completion.choices[0].message
):
if completion.choices and len(completion.choices) > 0 and completion.choices[0].message:
final_content = TextContent(
author="agent",
content=completion.choices[0].message.content or "",
Expand Down Expand Up @@ -159,9 +158,7 @@ async def chat_completion_stream(
) as span:
# Direct streaming outside temporal - yield each chunk as it comes
chunks: list[Completion] = []
async for chunk in self.llm_gateway.acompletion_stream(
**llm_config.model_dump()
):
async for chunk in self.llm_gateway.acompletion_stream(**llm_config.model_dump()):
chunks.append(chunk)
yield chunk
if span:
Expand All @@ -173,6 +170,7 @@ async def chat_completion_stream_auto_send(
llm_config: LLMConfig,
trace_id: str | None = None,
parent_span_id: str | None = None,
created_at: datetime | None = None,
) -> TaskMessage | None:
"""
Stream chat completion with automatic TaskMessage creation and streaming.
Expand Down Expand Up @@ -206,18 +204,13 @@ async def chat_completion_stream_auto_send(
content="",
format="markdown",
),
created_at=created_at,
) as streaming_context:
# Get the streaming response
chunks = []
async for response in self.llm_gateway.acompletion_stream(
**llm_config.model_dump()
):
async for response in self.llm_gateway.acompletion_stream(**llm_config.model_dump()):
heartbeat_if_in_workflow("chat completion streaming")
if (
response.choices
and len(response.choices) > 0
and response.choices[0].delta
):
if response.choices and len(response.choices) > 0 and response.choices[0].delta:
delta = response.choices[0].delta.content
if delta:
# Stream the chunk via the context manager
Expand All @@ -235,11 +228,7 @@ async def chat_completion_stream_auto_send(

# Update the final message content
complete_message = concat_completion_chunks(chunks)
if (
complete_message
and complete_message.choices
and complete_message.choices[0].message
):
if complete_message and complete_message.choices and complete_message.choices[0].message:
final_content = TextContent(
author="agent",
content=complete_message.choices[0].message.content or "",
Expand Down
Loading
Loading