diff --git a/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py b/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py index 1376df06c..7caeaf86c 100644 --- a/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py +++ b/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py @@ -1,7 +1,10 @@ -from typing import override +from __future__ import annotations + +import asyncio +from typing import Optional, override import scale_gp_beta.lib.tracing as tracing -from scale_gp_beta import SGPClient, AsyncSGPClient +from scale_gp_beta import APIError, SGPClient, AsyncSGPClient from scale_gp_beta.lib.tracing import create_span, flush_queue from scale_gp_beta.lib.tracing.span import Span as SGPSpan @@ -17,6 +20,19 @@ logger = make_logger(__name__) +# Mirrored from scale_gp_beta.lib.tracing.trace_queue_manager defaults so the +# async processor batches and retries the same way the sync (daemon-thread) +# path does. +DEFAULT_MAX_QUEUE_SIZE = 4_000 +DEFAULT_TRIGGER_QUEUE_SIZE = 200 +DEFAULT_TRIGGER_CADENCE = 4.0 +DEFAULT_MAX_BATCH_SIZE = 50 +DEFAULT_RETRIES = 4 +INITIAL_BACKOFF = 0.4 +MAX_BACKOFF = 20.0 +SHUTDOWN_DRAIN_TIMEOUT = 10.0 + + def _get_span_type(span: Span) -> str: """Read span_type from span.data['__span_type__'], defaulting to STANDALONE.""" if isinstance(span.data, dict): @@ -90,6 +106,18 @@ def shutdown(self) -> None: class SGPAsyncTracingProcessor(AsyncTracingProcessor): + """Async tracing processor that buffers spans and flushes them in batches. + + Mirrors the buffer-plus-flush behavior of the SDK's synchronous + `TraceQueueManager`, but uses asyncio primitives so it works inside an + asyncio event loop without blocking it. + + Spans are enqueued on `on_span_start` and `on_span_end`; a background + `asyncio.Task` worker drains the queue into batches and posts them via + `client.spans.upsert_batch`. The worker is lazy-initialized on the + running event loop on first use. + """ + def __init__(self, config: SGPTracingProcessorConfig): self.disabled = config.sgp_api_key == "" or config.sgp_account_id == "" self._spans: dict[str, SGPSpan] = {} @@ -111,6 +139,22 @@ def __init__(self, config: SGPTracingProcessorConfig): ) self.env_vars = EnvironmentVariables.refresh() + # Lazy-initialized on the running loop on first use. Re-created if + # the loop changes (e.g. sync-ACP / per-request loops) so the worker + # is always bound to the loop currently consuming it. + self._loop: Optional[asyncio.AbstractEventLoop] = None + self._queue: Optional[asyncio.Queue[SGPSpan]] = None + self._worker: Optional[asyncio.Task[None]] = None + self._shutdown_event: Optional[asyncio.Event] = None + self._flush_event: Optional[asyncio.Event] = None + + if self.disabled: + # Log once at init rather than on every span event, which would + # flood logs at agent throughput. + logger.warning( + "SGP tracing is disabled (sgp_api_key or sgp_account_id missing); span events will be ignored" + ) + def _add_source_to_span(self, span: Span) -> None: if span.data is None: span.data = {} @@ -123,6 +167,25 @@ def _add_source_to_span(self, span: Span) -> None: if self.env_vars.AGENT_ID is not None: span.data["__agent_id__"] = self.env_vars.AGENT_ID + def _ensure_started(self) -> None: + """Initialize per-loop queue + worker on first use, or after a loop swap. + + Must be called from inside an async method so `get_running_loop()` is + safe. Idempotent on the same loop while the worker is healthy; on a + loop change or worker death, it rebuilds the queue and worker (items + in the previous queue are lost — they were tied to a now-dead loop). + """ + if self.disabled: + return + loop = asyncio.get_running_loop() + if self._loop is loop and self._worker is not None and not self._worker.done(): + return + self._loop = loop + self._queue = asyncio.Queue(maxsize=DEFAULT_MAX_QUEUE_SIZE) + self._shutdown_event = asyncio.Event() + self._flush_event = asyncio.Event() + self._worker = loop.create_task(self._run()) + @override async def on_span_start(self, span: Span) -> None: self._add_source_to_span(span) @@ -137,18 +200,13 @@ async def on_span_start(self, span: Span) -> None: metadata=span.data, ) sgp_span.start_time = span.start_time.isoformat() # type: ignore[union-attr] + self._spans[span.id] = sgp_span if self.disabled: - logger.warning("SGP is disabled, skipping span upsert") return - # TODO(AGX1-198): Batch multiple spans into a single upsert_batch call - # instead of one span per HTTP request. - # https://linear.app/scale-epd/issue/AGX1-198/actually-use-sgp-batching-for-spans - await self.sgp_async_client.spans.upsert_batch( # type: ignore[union-attr] - items=[sgp_span.to_request_params()] - ) - self._spans[span.id] = sgp_span + self._ensure_started() + self._enqueue(sgp_span) @override async def on_span_end(self, span: Span) -> None: @@ -158,20 +216,142 @@ async def on_span_end(self, span: Span) -> None: return self._add_source_to_span(span) - sgp_span.input = span.input # type: ignore[assignment] sgp_span.output = span.output # type: ignore[assignment] sgp_span.metadata = span.data # type: ignore[assignment] sgp_span.end_time = span.end_time.isoformat() # type: ignore[union-attr] if self.disabled: return - await self.sgp_async_client.spans.upsert_batch( # type: ignore[union-attr] - items=[sgp_span.to_request_params()] - ) + + self._ensure_started() + self._enqueue(sgp_span) @override async def shutdown(self) -> None: - await self.sgp_async_client.spans.upsert_batch( # type: ignore[union-attr] - items=[sgp_span.to_request_params() for sgp_span in self._spans.values()] - ) + # Fast path when the processor was never started (disabled, or + # shutdown called before any span event). + if self._worker is None: + self._spans.clear() + return + + # Re-enqueue any spans whose end was never recorded so they aren't + # silently lost. They were already enqueued at start, but on_span_end + # is what mutates output / metadata / end_timestamp; without a second + # enqueue, the server only sees the start payload for them. + for sgp_span in list(self._spans.values()): + self._enqueue(sgp_span) self._spans.clear() + + assert self._shutdown_event is not None + self._shutdown_event.set() + if self._flush_event is not None: + self._flush_event.set() + + try: + await asyncio.wait_for(self._worker, timeout=SHUTDOWN_DRAIN_TIMEOUT) + except asyncio.TimeoutError: + logger.warning(f"Async tracing worker did not exit within {SHUTDOWN_DRAIN_TIMEOUT}s; cancelling") + self._worker.cancel() + + def _enqueue(self, sgp_span: SGPSpan) -> None: + """Push a span onto the queue and signal an early flush if the queue + has crossed `DEFAULT_TRIGGER_QUEUE_SIZE`. Drops the span on overflow.""" + if self._queue is None: + return + try: + self._queue.put_nowait(sgp_span) + except asyncio.QueueFull: + logger.warning(f"Tracing queue full; dropping span {sgp_span.span_id}") + return + if self._flush_event is not None and self._queue.qsize() >= DEFAULT_TRIGGER_QUEUE_SIZE: + self._flush_event.set() + + def _is_shutting_down(self) -> bool: + return self._shutdown_event is not None and self._shutdown_event.is_set() + + async def _wait_for_flush_signal(self) -> None: + """Block until either an early-flush signal arrives or the cadence + timer fires. Returns either way; the caller is responsible for + draining.""" + assert self._flush_event is not None + try: + await asyncio.wait_for(self._flush_event.wait(), timeout=DEFAULT_TRIGGER_CADENCE) + except asyncio.TimeoutError: + pass + self._flush_event.clear() + + async def _safe_drain(self, log_label: str) -> None: + """Run `_drain`, catching unexpected errors so one bad iteration + doesn't kill the worker. CancelledError is always re-raised.""" + try: + await self._drain() + except asyncio.CancelledError: + raise + except Exception: + logger.exception(log_label) + + async def _run(self) -> None: + """Background worker. Sleeps until a flush trigger fires, drains the + queue, and repeats. On shutdown signal, does one final drain so + nothing pending is dropped. The outermost try / except keeps a worker + crash from being silent.""" + try: + while not self._is_shutting_down(): + await self._wait_for_flush_signal() + await self._safe_drain("Tracing worker iteration failed; continuing") + await self._safe_drain("Final tracing drain failed; some spans may be lost") + except asyncio.CancelledError: + raise + except Exception: + logger.exception("Async tracing worker crashed") + + async def _drain(self) -> None: + """Pull spans from the queue and upsert them in batches of up to + `DEFAULT_MAX_BATCH_SIZE`. Stops when the queue is empty. + + A span whose `to_request_params()` raises is dropped (logged); the + rest of the batch still goes out. This matches the SDK's exporter.""" + if self._queue is None or self.sgp_async_client is None: + return + while not self._queue.empty(): + batch: list[dict] = [] + while len(batch) < DEFAULT_MAX_BATCH_SIZE and not self._queue.empty(): + try: + sgp_span = self._queue.get_nowait() + except asyncio.QueueEmpty: + break + try: + batch.append(sgp_span.to_request_params()) + except Exception: + logger.exception("Failed to build span params; dropping span") + if not batch: + continue + await self._upsert_with_retry(batch) + + async def _upsert_with_retry(self, batch: list[dict]) -> None: + """POST a single batch with the SDK's retry policy: 4 attempts with + exponential backoff (`INITIAL_BACKOFF` -> `MAX_BACKOFF` capped). + + - `APIError` triggers retry up to `DEFAULT_RETRIES` attempts. + - Anything else is logged and the batch is dropped (we don't know + whether the server saw the request, and the SDK already wraps + transport-level failures as `APIError`).""" + if self.sgp_async_client is None: + return + backoff = INITIAL_BACKOFF + for attempt in range(DEFAULT_RETRIES): + try: + await self.sgp_async_client.spans.upsert_batch(items=batch) # type: ignore[arg-type] + return + except APIError as exc: + if attempt == DEFAULT_RETRIES - 1: + logger.error(f"Failed to export {len(batch)} spans after {DEFAULT_RETRIES} attempts: {exc.message}") + return + logger.warning(f"Span export failed ({exc.message}); retrying in {backoff:.1f}s") + await asyncio.sleep(backoff) + backoff = min(backoff * 2, MAX_BACKOFF) + except asyncio.CancelledError: + raise + except Exception: + logger.exception(f"Unexpected error exporting {len(batch)} spans; dropping batch") + return diff --git a/tests/lib/core/tracing/processors/test_sgp_tracing_processor.py b/tests/lib/core/tracing/processors/test_sgp_tracing_processor.py index 818fed375..7a5b6b041 100644 --- a/tests/lib/core/tracing/processors/test_sgp_tracing_processor.py +++ b/tests/lib/core/tracing/processors/test_sgp_tracing_processor.py @@ -1,6 +1,7 @@ from __future__ import annotations import uuid +import asyncio from datetime import UTC, datetime from unittest.mock import AsyncMock, MagicMock, patch @@ -48,11 +49,9 @@ def _make_processor(): mock_env.refresh.return_value = MagicMock(ACP_TYPE=None, AGENT_NAME=None, AGENT_ID=None) mock_create_span = MagicMock(side_effect=lambda **kwargs: _make_mock_sgp_span()) - with patch(f"{MODULE}.EnvironmentVariables", mock_env), \ - patch(f"{MODULE}.SGPClient"), \ - patch(f"{MODULE}.tracing"), \ - patch(f"{MODULE}.flush_queue"), \ - patch(f"{MODULE}.create_span", mock_create_span): + with patch(f"{MODULE}.EnvironmentVariables", mock_env), patch(f"{MODULE}.SGPClient"), patch( + f"{MODULE}.tracing" + ), patch(f"{MODULE}.flush_queue"), patch(f"{MODULE}.create_span", mock_create_span): from agentex.lib.core.tracing.processors.sgp_tracing_processor import ( SGPSyncTracingProcessor, ) @@ -113,9 +112,9 @@ def _make_processor(): mock_async_client = MagicMock() mock_async_client.spans.upsert_batch = AsyncMock() - with patch(f"{MODULE}.EnvironmentVariables", mock_env), \ - patch(f"{MODULE}.create_span", mock_create_span), \ - patch(f"{MODULE}.AsyncSGPClient", return_value=mock_async_client): + with patch(f"{MODULE}.EnvironmentVariables", mock_env), patch(f"{MODULE}.create_span", mock_create_span), patch( + f"{MODULE}.AsyncSGPClient", return_value=mock_async_client + ): from agentex.lib.core.tracing.processors.sgp_tracing_processor import ( SGPAsyncTracingProcessor, ) @@ -164,7 +163,9 @@ async def test_span_end_for_unknown_span_is_noop(self): assert len(processor._spans) == 0 async def test_sgp_span_input_updated_on_end(self): - """on_span_end should update sgp_span.input from the incoming span.""" + """on_span_end should mutate the tracked SGP span and enqueue it. + With batched flushing, the upsert happens once on shutdown, with the + final state of the span after both start and end have run.""" processor, _ = self._make_processor() with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()): @@ -175,10 +176,12 @@ async def test_sgp_span_input_updated_on_end(self): assert len(processor._spans) == 1 # Simulate modified input at end time - updated_input: dict[str, object] = {"messages": [ - {"role": "user", "content": "hello"}, - {"role": "assistant", "content": "hi"}, - ]} + updated_input: dict[str, object] = { + "messages": [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi"}, + ] + } span.input = updated_input span.output = {"response": "hi"} span.end_time = datetime.now(UTC) @@ -186,5 +189,245 @@ async def test_sgp_span_input_updated_on_end(self): # Span should be removed after end assert len(processor._spans) == 0 - # The end upsert should have been called - assert processor.sgp_async_client.spans.upsert_batch.call_count == 2 # start + end + + # No upsert on the hot path; the worker batches and flushes asynchronously. + assert processor.sgp_async_client.spans.upsert_batch.call_count == 0 + + # Shutdown drains the queue and produces a single batched upsert. + await processor.shutdown() + assert processor.sgp_async_client.spans.upsert_batch.call_count == 1 + + +# --------------------------------------------------------------------------- +# Async processor batching tests +# +# Before this change, on_span_start and on_span_end each issued an awaited +# upsert_batch(items=[one]) call on the agent's hot path. The processor now +# buffers events and flushes them in batches from a background asyncio.Task, +# mirroring the SDK's TraceQueueManager. +# --------------------------------------------------------------------------- + + +class TestSGPAsyncTracingProcessorBatching: + @staticmethod + def _make_processor(): + mock_env = MagicMock() + mock_env.refresh.return_value = MagicMock(ACP_TYPE=None, AGENT_NAME=None, AGENT_ID=None) + + mock_async_client = MagicMock() + mock_async_client.spans.upsert_batch = AsyncMock() + + with patch(f"{MODULE}.EnvironmentVariables", mock_env), patch( + f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span() + ), patch(f"{MODULE}.AsyncSGPClient", return_value=mock_async_client): + from agentex.lib.core.tracing.processors.sgp_tracing_processor import ( + SGPAsyncTracingProcessor, + ) + + processor = SGPAsyncTracingProcessor(_make_config()) + + processor.sgp_async_client = mock_async_client + return processor, mock_async_client + + async def test_span_event_does_not_trigger_immediate_upsert(self): + """Regression: a single span event must not result in an upsert call + on the hot path. Events must be enqueued and flushed by the worker.""" + processor, client = self._make_processor() + + with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()): + span = _make_span() + await processor.on_span_start(span) + + assert client.spans.upsert_batch.call_count == 0, "on_span_start should enqueue, not trigger a network call" + + async def test_shutdown_flushes_queued_spans_in_one_batch(self): + """Many span events should be coalesced into a single upsert_batch + call when the buffer fits under MAX_BATCH_SIZE (50).""" + processor, client = self._make_processor() + + with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()): + for _ in range(5): + span = _make_span() + await processor.on_span_start(span) + span.end_time = datetime.now(UTC) + await processor.on_span_end(span) + + await processor.shutdown() + + assert client.spans.upsert_batch.call_count == 1, ( + f"Expected a single batched upsert, got {client.spans.upsert_batch.call_count}" + ) + items = client.spans.upsert_batch.call_args.kwargs["items"] + # 5 starts + 5 ends = 10 enqueued items, well under MAX_BATCH_SIZE. + assert len(items) == 10, f"Expected 10 items in the batch, got {len(items)}" + + async def test_drain_splits_into_multiple_batches_above_max_batch_size(self): + """Spans beyond MAX_BATCH_SIZE (50) must be split across multiple + upsert_batch calls so a single call never exceeds the cap.""" + processor, client = self._make_processor() + + with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()): + for _ in range(40): + span = _make_span() + await processor.on_span_start(span) + span.end_time = datetime.now(UTC) + await processor.on_span_end(span) + + # 40 starts + 40 ends = 80 enqueued items. With MAX_BATCH_SIZE=50, + # that's at least 2 upsert calls. + await processor.shutdown() + + assert client.spans.upsert_batch.call_count >= 2, ( + f"Expected ≥2 batched upserts for 80 events, got {client.spans.upsert_batch.call_count}" + ) + for call in client.spans.upsert_batch.call_args_list: + items = call.kwargs["items"] + assert len(items) <= 50, f"Batch of {len(items)} exceeds MAX_BATCH_SIZE=50" + total_items = sum(len(call.kwargs["items"]) for call in client.spans.upsert_batch.call_args_list) + assert total_items == 80, f"Expected 80 items across all batches, got {total_items}" + + async def test_worker_continues_after_unexpected_exception_in_one_batch(self): + """A single upsert raising an unexpected (non-APIError) exception + must drop that batch and let the worker keep flushing subsequent + ones. Regression test for the per-iteration try/except in `_run`.""" + processor, client = self._make_processor() + + # First call raises (unexpected exception → batch dropped), + # subsequent calls succeed. + client.spans.upsert_batch.side_effect = [RuntimeError("boom"), None] + + with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()): + # First flush — will raise inside _upsert_with_retry, batch dropped. + span_a = _make_span() + await processor.on_span_start(span_a) + span_a.end_time = datetime.now(UTC) + await processor.on_span_end(span_a) + assert processor._flush_event is not None + processor._flush_event.set() + # Yield so the worker runs the failing flush. + await asyncio.sleep(0) + await asyncio.sleep(0) + + # Worker must still be alive and able to handle a second batch. + span_b = _make_span() + await processor.on_span_start(span_b) + span_b.end_time = datetime.now(UTC) + await processor.on_span_end(span_b) + + await processor.shutdown() + + # First call raised, second succeeded → 2 calls total. + assert client.spans.upsert_batch.call_count == 2, ( + f"Worker should have made a second upsert attempt after the first failed; " + f"got {client.spans.upsert_batch.call_count}" + ) + + +# --------------------------------------------------------------------------- +# Edge-case correctness tests +# --------------------------------------------------------------------------- + + +class TestSGPAsyncTracingProcessorEdgeCases: + async def test_disabled_processor_never_enqueues_or_calls_upsert(self): + """When the config has no api_key / account_id, the processor must + be a no-op: no client constructed, no worker spun up, no upsert + calls. Only span tracking in `_spans` is preserved (matches the + sync processor's contract).""" + env_mock = MagicMock(refresh=MagicMock(return_value=MagicMock(ACP_TYPE=None, AGENT_NAME=None, AGENT_ID=None))) + with patch(f"{MODULE}.EnvironmentVariables", env_mock), patch( + f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span() + ), patch(f"{MODULE}.AsyncSGPClient") as mock_client_cls: + from agentex.lib.core.tracing.processors.sgp_tracing_processor import ( + SGPAsyncTracingProcessor, + ) + + disabled_config = SGPTracingProcessorConfig(sgp_api_key="", sgp_account_id="") + processor = SGPAsyncTracingProcessor(disabled_config) + + assert processor.disabled is True + assert processor.sgp_async_client is None, "Disabled processor must not construct a client" + mock_client_cls.assert_not_called() + + span = _make_span() + await processor.on_span_start(span) + span.end_time = datetime.now(UTC) + await processor.on_span_end(span) + + # No worker, no queue. + assert processor._worker is None + assert processor._queue is None + + # Shutdown is also a no-op. + await processor.shutdown() + + async def test_shutdown_is_safe_when_called_multiple_times(self): + """Shutdown must be idempotent: a second call after the worker has + already exited cleanly should not raise, double-flush, or hang.""" + processor, client = TestSGPAsyncTracingProcessorBatching._make_processor() + + with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()): + span = _make_span() + await processor.on_span_start(span) + span.end_time = datetime.now(UTC) + await processor.on_span_end(span) + + await processor.shutdown() + first_call_count = client.spans.upsert_batch.call_count + assert first_call_count == 1 + + # Second shutdown: worker is already done; should not raise or + # produce additional upserts since _spans is already cleared and + # the queue has been drained. + await processor.shutdown() + assert client.spans.upsert_batch.call_count == first_call_count, ( + "Calling shutdown twice must not produce extra upserts" + ) + + async def test_shutdown_before_any_event_is_noop(self): + """If shutdown runs before any span event, the worker was never + started; it must early-return without spinning anything up just to + tear it down.""" + env_mock = MagicMock(refresh=MagicMock(return_value=MagicMock(ACP_TYPE=None, AGENT_NAME=None, AGENT_ID=None))) + with patch(f"{MODULE}.EnvironmentVariables", env_mock), patch(f"{MODULE}.AsyncSGPClient") as mock_client_cls: + mock_client_cls.return_value = MagicMock(spans=MagicMock(upsert_batch=AsyncMock())) + + from agentex.lib.core.tracing.processors.sgp_tracing_processor import ( + SGPAsyncTracingProcessor, + ) + + processor = SGPAsyncTracingProcessor(_make_config()) + assert processor._worker is None + + await processor.shutdown() + + assert processor._worker is None, "Shutdown must not spin up a worker just to tear it down" + + async def test_apierror_triggers_retry_then_drops_batch_on_exhaustion(self): + """`APIError` must be retried up to DEFAULT_RETRIES times. After + exhaustion, the batch is dropped and the worker continues.""" + from scale_gp_beta import APIError + + processor, client = TestSGPAsyncTracingProcessorBatching._make_processor() + + # Make every attempt raise APIError so we exhaust the retry budget. + api_error = APIError(message="boom", request=MagicMock(), body=None) + client.spans.upsert_batch.side_effect = api_error + + # Patch sleep so retries don't block the test on real backoff timing. + with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()), patch( + "asyncio.sleep", new=AsyncMock() + ): + span = _make_span() + await processor.on_span_start(span) + span.end_time = datetime.now(UTC) + await processor.on_span_end(span) + + await processor.shutdown() + + # 4 attempts, all failed. Batch dropped. Importantly, no fifth call. + from agentex.lib.core.tracing.processors.sgp_tracing_processor import DEFAULT_RETRIES + + assert client.spans.upsert_batch.call_count == DEFAULT_RETRIES, ( + f"Expected exactly {DEFAULT_RETRIES} attempts before dropping; got {client.spans.upsert_batch.call_count}" + )