Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 30 additions & 7 deletions src/agentex/lib/core/services/adk/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,12 @@ async def add(self, update: StreamTaskMessageDelta) -> None:
if self._closed:
return
async with self._lock:
# Re-check under the lock: a concurrent close() (e.g. from a racing
# Full) may have drained and shut down the ticker after the check
# above but before we acquired the lock. Appending now would strand
# the delta in a dead buffer, never published.
if self._closed:
return
self._buf.append(update)
self._buf_chars += _delta_char_len(update.delta)
if not self._first_flushed or self._buf_chars >= self.MAX_BUFFERED_CHARS:
Expand Down Expand Up @@ -415,19 +421,28 @@ async def open(self) -> "StreamingTaskMessageContext":

return self

async def _reap_buffer(self) -> None:
    """Close the coalescing buffer, if one is attached, and detach it.

    Awaits the buffer's ``close()`` (which drains pending deltas and stops
    the background ticker) and then clears ``self._buffer``.

    Idempotent: when ``self._buffer`` is already ``None`` this does nothing.
    """
    buffer = self._buffer
    if buffer is None:
        return  # Nothing to reap.
    await buffer.close()
    self._buffer = None

async def close(self) -> TaskMessage:
"""Close the streaming context."""
if not self.task_message:
raise ValueError("Context not properly initialized - no task message")

if self._is_closed:
return self.task_message # Already done
# Reap the buffer (stopping its ticker) before the _is_closed
# short-circuit, so a context already marked done by a Full update can't
# leave the ticker orphaned. Draining here also lets consumers see the
# full delta sequence in order before DONE.
await self._reap_buffer()

# Drain any buffered deltas before announcing DONE so consumers see the
# full sequence in order.
if self._buffer is not None:
await self._buffer.close()
self._buffer = None
if self._is_closed:
return self.task_message # Already done (buffer reaped above)

# Send the DONE event
done_event = StreamTaskMessageDone(
Expand Down Expand Up @@ -486,6 +501,14 @@ async def stream_update(self, update: TaskMessageUpdate) -> TaskMessageUpdate |
await self._buffer.add(update)
return update

# A Full ends the stream and supersedes buffered deltas. Drain and stop
# the buffer BEFORE publishing the Full, so leftover deltas land in order
# (deltas -> Full) instead of trailing the terminal Full as a stale
# duplicate tail. This also stops the ticker, which would otherwise be
# orphaned when __aexit__'s close() short-circuits on _is_closed.
if isinstance(update, StreamTaskMessageFull):
await self._reap_buffer()

result = await self._streaming_service.stream_update(update)
Comment thread
eberki-scale marked this conversation as resolved.

if isinstance(update, StreamTaskMessageDone):
Expand Down
79 changes: 78 additions & 1 deletion tests/lib/core/services/adk/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
ToolResponseDelta,
ReasoningSummaryDelta,
)
from agentex.types.task_message_update import StreamTaskMessageDelta
from agentex.types.task_message_update import (
StreamTaskMessageFull,
StreamTaskMessageDelta,
)
from agentex.lib.core.services.adk.streaming import (
CoalescingBuffer,
StreamingTaskMessageContext,
Expand Down Expand Up @@ -352,6 +355,24 @@ async def on_flush(u: StreamTaskMessageDelta) -> None:
await buf.add(_text(task_message, "after"))
assert flushed == []

@pytest.mark.asyncio
async def test_add_racing_close_is_not_stranded(self, task_message: TaskMessage) -> None:
    """TOCTOU: a delta that passes add()'s pre-lock _closed check but only
    acquires the lock after close() set _closed must be dropped, not appended
    to a drained, ticker-less buffer where it would never be published.

    The race is staged deterministically: the test holds the buffer's lock,
    lets add() run up to the lock acquisition, flips ``_closed`` while add()
    is parked, then releases the lock and checks nothing was appended.
    """
    buf = CoalescingBuffer(on_flush=AsyncMock())
    buf.start()
    try:
        # Hold the lock so add() parks *after* its pre-lock _closed check.
        # ``async with`` guarantees the lock is released even if staging the
        # race raises, so the cleanup below cannot block on it.
        async with buf._lock:
            add_task = asyncio.create_task(buf.add(_text(task_message, "racing")))
            await asyncio.sleep(0)  # add() passes the _closed check, blocks on the lock
            buf._closed = True  # close() wins the race
        await add_task

        assert buf._buf == [], "racing delta was stranded in the closed buffer"
    finally:
        # Always clean up, including when the assertion above fails, so a
        # failing run does not leave the ticker started by start() pending.
        # NOTE(review): _closed was forced True above; confirm close() still
        # stops the ticker in that state rather than short-circuiting.
        await buf.close()


class TestCoalescingBufferCloseDuringFlush:
@pytest.mark.asyncio
Expand Down Expand Up @@ -520,3 +541,59 @@ async def test_open_without_created_at_passes_omit(self) -> None:

kwargs = client.messages.create.call_args.kwargs
assert kwargs["created_at"] is omit


class TestFullMessageClosesBuffer:
    """Streaming a StreamTaskMessageFull must reap the coalescing buffer.

    Two invariants are pinned here:

    * the buffer's ticker task is finished and the context's buffer reference
      is cleared once the Full has been streamed, so a later close() that
      short-circuits on _is_closed cannot orphan the ticker;
    * buffered deltas are published before the Full, never after it, where
      they would read as a stale duplicate tail.
    """

    @staticmethod
    def _full(tm: TaskMessage, text: str) -> StreamTaskMessageFull:
        """Build a terminal Full update for *tm* carrying *text* as markdown."""
        body = TextContent(author="agent", content=text, format="markdown")
        return StreamTaskMessageFull(
            parent_task_message=tm,
            content=body,
            type="full",
        )

    @pytest.mark.asyncio
    async def test_full_message_stops_ticker(self) -> None:
        ctx, _svc, tm = await _make_context("coalesced")

        # Streaming one delta brings the buffer and its ticker to life.
        await ctx.stream_update(_text(tm, "hello"))
        live_buffer = ctx._buffer
        assert live_buffer is not None
        ticker = live_buffer._task
        assert ticker is not None and not ticker.done()

        await ctx.stream_update(self._full(tm, "final"))

        assert ctx._buffer is None, "Full message left the buffer un-closed"
        assert ticker.done(), "coalescing-buffer ticker still running after Full (orphaned)"

    @pytest.mark.asyncio
    async def test_full_is_terminal_publish_no_trailing_deltas(self) -> None:
        # Invariant under test: every buffered delta is published BEFORE the
        # terminal Full. How the coalescing window batches the two deltas
        # below (one publish or two) is irrelevant to that ordering.
        ctx, svc, tm = await _make_context("coalesced")
        for chunk in ("alpha", "beta"):
            await ctx.stream_update(_text(tm, chunk))

        full = self._full(tm, "alphabeta")
        await ctx.stream_update(full)

        published = [call.args[0] for call in svc.stream_update.await_args_list]
        assert published, "nothing was published"

        *earlier, last = published
        assert last is full, (
            f"Full must be the terminal publish; saw trailing "
            f"{type(last).__name__} after it (stale duplicate tail)"
        )
        assert any(isinstance(item, StreamTaskMessageDelta) for item in earlier), (
            "expected the buffered deltas to be published before the Full"
        )
Loading