Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions robosystems_client/clients/operation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,14 @@ class OperationClient:
def __init__(self, config: Dict[str, Any]):
self.config = config
self.base_url = config["base_url"]
self.headers = config.get("headers", {})
self.headers = dict(config.get("headers") or {})
# Get token from config if passed by parent
self.token = config.get("token")
# Propagate the API key into SSE headers — SSE requests bypass the
# AuthenticatedClient used by the generated REST methods, so the token
# must be injected explicitly or the /stream endpoint returns 401.
if self.token and "X-API-Key" not in self.headers:
self.headers["X-API-Key"] = self.token
self.active_operations: Dict[str, SSEClient] = {}
# Thread safety for operations tracking
import threading
Expand Down Expand Up @@ -146,13 +151,25 @@ def on_operation_cancelled():
result.completed_at = datetime.now()
completed = True

def on_connection_error(err):
nonlocal completed, error
result.status = OperationStatus.FAILED
result.error = str(err)
result.completed_at = datetime.now()
error = err if isinstance(err, Exception) else Exception(str(err))
completed = True

# Register event handlers
sse_client.on(EventType.OPERATION_STARTED.value, on_operation_started)
sse_client.on(EventType.OPERATION_PROGRESS.value, on_operation_progress)
sse_client.on(EventType.QUEUE_UPDATE.value, on_queue_update)
sse_client.on(EventType.OPERATION_COMPLETED.value, on_operation_completed)
sse_client.on(EventType.OPERATION_ERROR.value, on_operation_error)
sse_client.on(EventType.OPERATION_CANCELLED.value, on_operation_cancelled)
# Surface transport-level errors (bad status, dropped connection,
# max retries exceeded) so the wait loop terminates instead of hanging.
sse_client.on("error", on_connection_error)
sse_client.on("max_retries_exceeded", on_connection_error)

# Connect and monitor
try:
Expand Down Expand Up @@ -271,8 +288,10 @@ class AsyncOperationClient:
def __init__(self, config: Dict[str, Any]):
self.config = config
self.base_url = config["base_url"]
self.headers = config.get("headers", {})
self.headers = dict(config.get("headers") or {})
self.token = config.get("token")
if self.token and "X-API-Key" not in self.headers:
self.headers["X-API-Key"] = self.token
self.active_operations: Dict[str, AsyncSSEClient] = {}

async def monitor_operation(
Expand Down Expand Up @@ -335,13 +354,25 @@ def on_operation_cancelled():
result.completed_at = datetime.now()
completed = True

def on_connection_error(err):
nonlocal completed, error
result.status = OperationStatus.FAILED
result.error = str(err)
result.completed_at = datetime.now()
error = err if isinstance(err, Exception) else Exception(str(err))
completed = True

# Register event handlers
sse_client.on(EventType.OPERATION_STARTED.value, on_operation_started)
sse_client.on(EventType.OPERATION_PROGRESS.value, on_operation_progress)
sse_client.on(EventType.QUEUE_UPDATE.value, on_queue_update)
sse_client.on(EventType.OPERATION_COMPLETED.value, on_operation_completed)
sse_client.on(EventType.OPERATION_ERROR.value, on_operation_error)
sse_client.on(EventType.OPERATION_CANCELLED.value, on_operation_cancelled)
# Surface transport-level errors (bad status, dropped connection,
# max retries exceeded) so the wait loop terminates instead of hanging.
sse_client.on("error", on_connection_error)
sse_client.on("max_retries_exceeded", on_connection_error)

# Connect and monitor
try:
Expand Down
26 changes: 26 additions & 0 deletions robosystems_client/clients/sse_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,19 @@ def connect(self, operation_id: str, from_sequence: int = 0) -> None:
)
self._response = self._context_manager.__enter__()

if self._response.status_code != 200:
status = self._response.status_code
body = self._response.read().decode("utf-8", errors="replace")[:500]
self._context_manager.__exit__(None, None, None)
self._context_manager = None
self._response = None
self.closed = True
self.emit(
"error",
RuntimeError(f"SSE connection failed: HTTP {status} {body}".strip()),
)
return

self.reconnect_attempts = 0
self.emit("connected", None)

Expand Down Expand Up @@ -336,6 +349,19 @@ async def connect(self, operation_id: str, from_sequence: int = 0) -> None:
)
self._response = await self._context_manager.__aenter__()

if self._response.status_code != 200:
status = self._response.status_code
body = (await self._response.aread()).decode("utf-8", errors="replace")[:500]
await self._context_manager.__aexit__(None, None, None)
self._context_manager = None
self._response = None
self.closed = True
self.emit(
"error",
RuntimeError(f"SSE connection failed: HTTP {status} {body}".strip()),
)
return

self.reconnect_attempts = 0
self.emit("connected", None)

Expand Down
2 changes: 1 addition & 1 deletion robosystems_client/models/create_event_handler_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class CreateEventHandlerRequest:
match_agent_type (None | str | Unset):
match_resource_type (None | str | Unset):
match_metadata_expression (CreateEventHandlerRequestMatchMetadataExpressionType0 | None | Unset): JSONPath-style
equality map, e.g. {"metadata.category": "payroll"}
equality map against event.metadata, e.g. {"category": "payroll"} or {"metadata.category": "payroll"}
priority (int | Unset): Default: 0.
is_active (bool | Unset): Default: True.
origin (CreateEventHandlerRequestOrigin | Unset): Default: CreateEventHandlerRequestOrigin.TENANT.
Expand Down
35 changes: 35 additions & 0 deletions tests/test_sse_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ def test_connect_sets_up_stream(self, mock_httpx, sse_config):
mock_http_client = MagicMock()
mock_context = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.iter_lines.return_value = iter(
[
"event: operation_completed",
Expand All @@ -402,6 +403,40 @@ def test_connect_sets_up_stream(self, mock_httpx, sse_config):
assert len(connected_events) == 1
assert client.reconnect_attempts == 0

@patch("robosystems_client.clients.sse_client.httpx")
def test_connect_non_200_emits_error_without_retry(self, mock_httpx, sse_config):
"""A non-200 status must emit an error and not enter the event loop.

Regression test: before this check, a 401 would cause connect() to fall
through to iter_lines() on an empty body, return silently, and leave any
caller spinning in a wait-for-completion loop forever.
"""
mock_http_client = MagicMock()
mock_context = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 401
mock_response.read.return_value = b'{"detail": "Not authenticated"}'
mock_context.__enter__ = Mock(return_value=mock_response)
mock_http_client.stream.return_value = mock_context
mock_httpx.Client.return_value = mock_http_client

client = SSEClient(sse_config)
errors: list = []
connected: list = []
client.on("error", lambda e: errors.append(e))
client.on("connected", lambda d: connected.append(d))

client.connect("op-123")

# Never entered the event loop / never emitted connected.
assert connected == []
# Error event surfaced with status and body.
assert len(errors) == 1
assert "HTTP 401" in str(errors[0])
# iter_lines was never called — we bailed before the read loop.
mock_response.iter_lines.assert_not_called()
assert client.closed is True

@patch("robosystems_client.clients.sse_client.httpx")
def test_connect_error_triggers_retry(self, mock_httpx, sse_config):
"""Test that connection error triggers _handle_error."""
Expand Down
Loading