11from __future__ import annotations
22
3+ import asyncio
4+ import contextlib
35import json
4- from typing import Literal
6+ from typing import Awaitable , Callable , Literal
57
68from agentex import AsyncAgentex
79from agentex .lib .utils .logging import make_logger
@@ -39,6 +41,187 @@ def _get_stream_topic(task_id: str) -> str:
3941 return f"task:{ task_id } "
4042
4143
# Accepted values for the ``streaming_mode`` argument of
# StreamingTaskMessageContext (which defaults to "coalesced").
# The 50ms / 128-char figures quoted below mirror
# CoalescingBuffer.FLUSH_INTERVAL_S and CoalescingBuffer.MAX_BUFFERED_CHARS;
# keep this text in sync if those constants change.
StreamingMode = Literal["off", "per_token", "coalesced"]
"""Controls how a StreamingTaskMessageContext publishes deltas.

- "off":       Feed the accumulator (so the persisted message body is correct)
               but never publish per-delta events. Consumers see start + done
               only. Lowest latency.
- "per_token": Publish every delta immediately. Highest UX fidelity for
               token-by-token rendering, highest Redis cost, and re-introduces
               head-of-line blocking on the producer's event loop.
- "coalesced": Buffer deltas in a small time/size window and publish them as
               merged batches. The first delta flushes immediately for fast
               perceived responsiveness; subsequent deltas flush every 50ms or
               whenever 128 buffered chars accumulate, whichever comes first.
               Order within each (delta type, index) channel is preserved
               exactly; only granularity changes.
"""
60+
61+
62+ def _delta_char_len (delta : TaskMessageDelta | None ) -> int :
63+ if delta is None :
64+ return 0
65+ if isinstance (delta , TextDelta ):
66+ return len (delta .text_delta or "" )
67+ if isinstance (delta , DataDelta ):
68+ return len (delta .data_delta or "" )
69+ if isinstance (delta , ReasoningSummaryDelta ):
70+ return len (delta .summary_delta or "" )
71+ if isinstance (delta , ReasoningContentDelta ):
72+ return len (delta .content_delta or "" )
73+ if isinstance (delta , ToolRequestDelta ):
74+ return len (delta .arguments_delta or "" )
75+ if isinstance (delta , ToolResponseDelta ):
76+ return len (delta .content_delta or "" )
77+ return 0
78+
79+
80+ def _can_merge (a : TaskMessageDelta , b : TaskMessageDelta ) -> bool :
81+ if type (a ) is not type (b ):
82+ return False
83+ if isinstance (a , ReasoningSummaryDelta ) and isinstance (b , ReasoningSummaryDelta ):
84+ return a .summary_index == b .summary_index
85+ if isinstance (a , ReasoningContentDelta ) and isinstance (b , ReasoningContentDelta ):
86+ return a .content_index == b .content_index
87+ if isinstance (a , ToolRequestDelta ) and isinstance (b , ToolRequestDelta ):
88+ return a .tool_call_id == b .tool_call_id
89+ if isinstance (a , ToolResponseDelta ) and isinstance (b , ToolResponseDelta ):
90+ return a .tool_call_id == b .tool_call_id
91+ return True
92+
93+
def _merge_pair(a: TaskMessageDelta, b: TaskMessageDelta) -> TaskMessageDelta:
    """Build one delta whose payload is ``a``'s payload followed by ``b``'s.

    Channel-identifying fields (``summary_index``, ``content_index``,
    ``tool_call_id``, ``name``) are taken from ``a``. A ``None`` payload on
    either side is treated as the empty string. When the pair matches none of
    the branches below, ``b`` is returned unchanged.
    """

    def joined(left, right):
        # None-tolerant concatenation of two optional payload strings.
        return (left or "") + (right or "")

    if isinstance(a, TextDelta) and isinstance(b, TextDelta):
        text = joined(a.text_delta, b.text_delta)
        return TextDelta(type="text", text_delta=text)

    if isinstance(a, DataDelta) and isinstance(b, DataDelta):
        data = joined(a.data_delta, b.data_delta)
        return DataDelta(type="data", data_delta=data)

    if isinstance(a, ReasoningSummaryDelta) and isinstance(b, ReasoningSummaryDelta):
        summary = joined(a.summary_delta, b.summary_delta)
        return ReasoningSummaryDelta(
            type="reasoning_summary",
            summary_index=a.summary_index,
            summary_delta=summary,
        )

    if isinstance(a, ReasoningContentDelta) and isinstance(b, ReasoningContentDelta):
        content = joined(a.content_delta, b.content_delta)
        return ReasoningContentDelta(
            type="reasoning_content",
            content_index=a.content_index,
            content_delta=content,
        )

    if isinstance(a, ToolRequestDelta) and isinstance(b, ToolRequestDelta):
        arguments = joined(a.arguments_delta, b.arguments_delta)
        return ToolRequestDelta(
            type="tool_request",
            tool_call_id=a.tool_call_id,
            name=a.name,
            arguments_delta=arguments,
        )

    if isinstance(a, ToolResponseDelta) and isinstance(b, ToolResponseDelta):
        response = joined(a.content_delta, b.content_delta)
        return ToolResponseDelta(
            type="tool_response",
            tool_call_id=a.tool_call_id,
            name=a.name,
            content_delta=response,
        )

    return b
126+
127+
128+ def _merge_consecutive (updates : list [StreamTaskMessageDelta ]) -> list [StreamTaskMessageDelta ]:
129+ """Merge consecutive same-channel deltas. Order across channels is preserved exactly."""
130+ result : list [StreamTaskMessageDelta ] = []
131+ for u in updates :
132+ if u .delta is None or not result :
133+ result .append (u )
134+ continue
135+ last = result [- 1 ]
136+ if last .delta is not None and _can_merge (last .delta , u .delta ):
137+ result [- 1 ] = StreamTaskMessageDelta (
138+ parent_task_message = last .parent_task_message ,
139+ delta = _merge_pair (last .delta , u .delta ),
140+ type = "delta" ,
141+ )
142+ else :
143+ result .append (u )
144+ return result
145+
146+
class CoalescingBuffer:
    """Time-and-size-windowed buffer that merges consecutive same-channel deltas.

    Decouples the producer (model event loop) from the publisher (Redis): ``add``
    only enqueues and may signal an early flush; the actual publish always runs
    on a background ticker, so the producer never awaits on a Redis round-trip.

    Lifecycle: construct, call ``start()`` from within a running event loop,
    call ``add()`` for each delta update, then ``await close()`` exactly when no
    more deltas will be produced. ``close()`` returns only after every update
    accepted by ``add()`` has been handed to ``on_flush`` in order.
    """

    # Longest time the ticker waits before publishing whatever is buffered.
    FLUSH_INTERVAL_S = 0.050
    # Buffered payload characters at which ``add`` requests an early flush.
    MAX_BUFFERED_CHARS = 128

    def __init__(self, on_flush: Callable[[StreamTaskMessageDelta], Awaitable[object]]):
        """Create an idle buffer.

        Args:
            on_flush: Async callable invoked once per merged update, in order.
                Exceptions it raises are logged and do not stop the flush.
        """
        self._on_flush = on_flush
        self._buf: list[StreamTaskMessageDelta] = []
        self._buf_chars = 0
        # False until the very first update has triggered its immediate flush.
        self._first_flushed = False
        self._closed = False
        # Guards _buf / _buf_chars between add(), the ticker and close().
        self._lock = asyncio.Lock()
        # Set by add()/close() to wake the ticker before its interval elapses.
        self._flush_signal = asyncio.Event()
        self._task: asyncio.Task[None] | None = None

    def start(self) -> None:
        """Start the background ticker; calling it again while running is a no-op."""
        if self._task is None:
            self._task = asyncio.create_task(self._run(), name="coalescing-buffer")

    async def add(self, update: StreamTaskMessageDelta) -> None:
        """Queue ``update`` for a later flush; ignored once the buffer is closed.

        Wakes the ticker immediately for the first update ever added and
        whenever the buffered payload reaches ``MAX_BUFFERED_CHARS``.
        """
        if self._closed:
            return
        async with self._lock:
            self._buf.append(update)
            self._buf_chars += _delta_char_len(update.delta)
            if not self._first_flushed or self._buf_chars >= self.MAX_BUFFERED_CHARS:
                self._first_flushed = True
                self._flush_signal.set()

    async def _run(self) -> None:
        """Ticker loop: wait for a signal or the interval, then drain and publish.

        The loop exits on its own once ``close()`` has set ``_closed``. It is
        never cancelled by ``close()``, so a batch that has already been taken
        out of the buffer is always published in full. External cancellation
        (e.g. event-loop shutdown) is allowed to propagate.
        """
        while not self._closed:
            try:
                await asyncio.wait_for(self._flush_signal.wait(), timeout=self.FLUSH_INTERVAL_S)
            except asyncio.TimeoutError:
                # Interval elapsed without an early-flush request: flush anyway.
                pass
            async with self._lock:
                self._flush_signal.clear()
                drained = self._drain_locked()
            # Publish outside the lock so add() is never blocked on on_flush.
            await self._publish(drained, "flush")

    async def close(self) -> None:
        """Stop the ticker and publish everything still buffered.

        The ticker is asked to stop cooperatively rather than cancelled:
        cancelling it while it is awaiting ``on_flush`` would discard the
        in-flight update and the rest of the batch it had already drained.
        Cancellation of the caller is not suppressed.
        """
        self._closed = True
        task, self._task = self._task, None
        if task is not None:
            # Wake the ticker so it notices _closed without waiting out the interval.
            self._flush_signal.set()
            await task
        async with self._lock:
            drained = self._drain_locked()
        await self._publish(drained, "final flush")

    async def _publish(self, drained: list[StreamTaskMessageDelta], label: str) -> None:
        """Hand each drained update to ``on_flush``; log failures and carry on."""
        for u in drained:
            try:
                await self._on_flush(u)
            except Exception as e:
                # Best-effort publish: one failed update must not block the rest.
                logger.exception(f"CoalescingBuffer {label} failed: {e}")

    def _drain_locked(self) -> list[StreamTaskMessageDelta]:
        """Empty the buffer and return its merged contents. Caller holds ``_lock``."""
        if not self._buf:
            return []
        merged = _merge_consecutive(self._buf)
        self._buf = []
        self._buf_chars = 0
        return merged
223+
224+
42225class DeltaAccumulator :
43226 def __init__ (self ):
44227 self ._accumulated_deltas : list [TaskMessageDelta ] = []
@@ -176,6 +359,7 @@ def __init__(
176359 initial_content : TaskMessageContent ,
177360 agentex_client : AsyncAgentex ,
178361 streaming_service : "StreamingService" ,
362+ streaming_mode : StreamingMode = "coalesced" ,
179363 ):
180364 self .task_id = task_id
181365 self .initial_content = initial_content
@@ -184,6 +368,8 @@ def __init__(
184368 self ._streaming_service = streaming_service
185369 self ._is_closed = False
186370 self ._delta_accumulator = DeltaAccumulator ()
371+ self ._streaming_mode : StreamingMode = streaming_mode
372+ self ._buffer : CoalescingBuffer | None = None
187373
188374 async def __aenter__ (self ) -> "StreamingTaskMessageContext" :
189375 return await self .open ()
@@ -208,6 +394,10 @@ async def open(self) -> "StreamingTaskMessageContext":
208394 )
209395 await self ._streaming_service .stream_update (start_event )
210396
397+ if self ._streaming_mode == "coalesced" :
398+ self ._buffer = CoalescingBuffer (on_flush = self ._streaming_service .stream_update )
399+ self ._buffer .start ()
400+
211401 return self
212402
213403 async def close (self ) -> TaskMessage :
@@ -218,6 +408,12 @@ async def close(self) -> TaskMessage:
218408 if self ._is_closed :
219409 return self .task_message # Already done
220410
411+ # Drain any buffered deltas before announcing DONE so consumers see the
412+ # full sequence in order.
413+ if self ._buffer is not None :
414+ await self ._buffer .close ()
415+ self ._buffer = None
416+
221417 # Send the DONE event
222418 done_event = StreamTaskMessageDone (
223419 parent_task_message = self .task_message ,
@@ -227,8 +423,8 @@ async def close(self) -> TaskMessage:
227423
228424 # Update the task message with the final content
229425 has_deltas = (
230- self ._delta_accumulator ._accumulated_deltas or
231- self ._delta_accumulator ._reasoning_summaries or
426+ self ._delta_accumulator ._accumulated_deltas or
427+ self ._delta_accumulator ._reasoning_summaries or
232428 self ._delta_accumulator ._reasoning_contents
233429 )
234430 if has_deltas :
@@ -248,7 +444,20 @@ async def close(self) -> TaskMessage:
248444 async def stream_update (
249445 self , update : TaskMessageUpdate
250446 ) -> TaskMessageUpdate | None :
251- """Stream an update to the repository."""
447+ """Stream an update to the repository.
448+
449+ Behavior depends on the context's ``streaming_mode``:
450+ - "off": delta updates feed the accumulator (so the persisted message
451+ body is correct) but are never published.
452+ - "per_token": delta updates are published immediately.
453+ - "coalesced": delta updates are queued in a 50ms / 128-char window and
454+ flushed as merged batches on a background ticker; the first delta
455+ flushes immediately for fast perceived responsiveness.
456+
457+ ``StreamTaskMessageDone`` and ``StreamTaskMessageFull`` updates always
458+ publish synchronously regardless of mode so consumers and persistence
459+ stay in sync.
460+ """
252461 if self ._is_closed :
253462 raise ValueError ("Context is already done" )
254463
@@ -258,6 +467,11 @@ async def stream_update(
258467 if isinstance (update , StreamTaskMessageDelta ):
259468 if update .delta is not None :
260469 self ._delta_accumulator .add_delta (update .delta )
470+ if self ._streaming_mode == "off" :
471+ return update
472+ if self ._streaming_mode == "coalesced" and self ._buffer is not None :
473+ await self ._buffer .add (update )
474+ return update
261475
262476 result = await self ._streaming_service .stream_update (update )
263477
@@ -288,12 +502,14 @@ def streaming_task_message_context(
288502 self ,
289503 task_id : str ,
290504 initial_content : TaskMessageContent ,
505+ streaming_mode : StreamingMode = "coalesced" ,
291506 ) -> StreamingTaskMessageContext :
292507 return StreamingTaskMessageContext (
293508 task_id = task_id ,
294509 initial_content = initial_content ,
295510 agentex_client = self ._agentex_client ,
296511 streaming_service = self ,
512+ streaming_mode = streaming_mode ,
297513 )
298514
299515 async def stream_update (
0 commit comments