diff --git a/robosystems_client/clients/operation_client.py b/robosystems_client/clients/operation_client.py index 449d5c4..acdffe3 100644 --- a/robosystems_client/clients/operation_client.py +++ b/robosystems_client/clients/operation_client.py @@ -71,9 +71,14 @@ class OperationClient: def __init__(self, config: Dict[str, Any]): self.config = config self.base_url = config["base_url"] - self.headers = config.get("headers", {}) + self.headers = dict(config.get("headers") or {}) # Get token from config if passed by parent self.token = config.get("token") + # Propagate the API key into SSE headers — SSE requests bypass the + # AuthenticatedClient used by the generated REST methods, so the token + # must be injected explicitly or the /stream endpoint returns 401. + if self.token and "X-API-Key" not in self.headers: + self.headers["X-API-Key"] = self.token self.active_operations: Dict[str, SSEClient] = {} # Thread safety for operations tracking import threading @@ -146,6 +151,14 @@ def on_operation_cancelled(): result.completed_at = datetime.now() completed = True + def on_connection_error(err): + nonlocal completed, error + result.status = OperationStatus.FAILED + result.error = str(err) + result.completed_at = datetime.now() + error = err if isinstance(err, Exception) else Exception(str(err)) + completed = True + # Register event handlers sse_client.on(EventType.OPERATION_STARTED.value, on_operation_started) sse_client.on(EventType.OPERATION_PROGRESS.value, on_operation_progress) @@ -153,6 +166,10 @@ def on_operation_cancelled(): sse_client.on(EventType.OPERATION_COMPLETED.value, on_operation_completed) sse_client.on(EventType.OPERATION_ERROR.value, on_operation_error) sse_client.on(EventType.OPERATION_CANCELLED.value, on_operation_cancelled) + # Surface transport-level errors (bad status, dropped connection, + # max retries exceeded) so the wait loop terminates instead of hanging. + sse_client.on("error", on_connection_error) + sse_client.on("max_retries_exceeded", on_connection_error) # Connect and monitor try: @@ -271,8 +288,10 @@ class AsyncOperationClient: def __init__(self, config: Dict[str, Any]): self.config = config self.base_url = config["base_url"] - self.headers = config.get("headers", {}) + self.headers = dict(config.get("headers") or {}) self.token = config.get("token") + if self.token and "X-API-Key" not in self.headers: + self.headers["X-API-Key"] = self.token self.active_operations: Dict[str, AsyncSSEClient] = {} async def monitor_operation( @@ -335,6 +354,14 @@ def on_operation_cancelled(): result.completed_at = datetime.now() completed = True + def on_connection_error(err): + nonlocal completed, error + result.status = OperationStatus.FAILED + result.error = str(err) + result.completed_at = datetime.now() + error = err if isinstance(err, Exception) else Exception(str(err)) + completed = True + # Register event handlers sse_client.on(EventType.OPERATION_STARTED.value, on_operation_started) sse_client.on(EventType.OPERATION_PROGRESS.value, on_operation_progress) @@ -342,6 +369,10 @@ def on_operation_cancelled(): sse_client.on(EventType.OPERATION_COMPLETED.value, on_operation_completed) sse_client.on(EventType.OPERATION_ERROR.value, on_operation_error) sse_client.on(EventType.OPERATION_CANCELLED.value, on_operation_cancelled) + # Surface transport-level errors (bad status, dropped connection, + # max retries exceeded) so the wait loop terminates instead of hanging. + sse_client.on("error", on_connection_error) + sse_client.on("max_retries_exceeded", on_connection_error) # Connect and monitor try: diff --git a/robosystems_client/clients/sse_client.py b/robosystems_client/clients/sse_client.py index 1befd84..15e1666 100644 --- a/robosystems_client/clients/sse_client.py +++ b/robosystems_client/clients/sse_client.py @@ -109,6 +109,19 @@ def connect(self, operation_id: str, from_sequence: int = 0) -> None: ) self._response = self._context_manager.__enter__() + if self._response.status_code != 200: + status = self._response.status_code + body = self._response.read().decode("utf-8", errors="replace")[:500] + self._context_manager.__exit__(None, None, None) + self._context_manager = None + self._response = None + self.closed = True + self.emit( + "error", + RuntimeError(f"SSE connection failed: HTTP {status} {body}".strip()), + ) + return + self.reconnect_attempts = 0 self.emit("connected", None) @@ -336,6 +349,19 @@ async def connect(self, operation_id: str, from_sequence: int = 0) -> None: ) self._response = await self._context_manager.__aenter__() + if self._response.status_code != 200: + status = self._response.status_code + body = (await self._response.aread()).decode("utf-8", errors="replace")[:500] + await self._context_manager.__aexit__(None, None, None) + self._context_manager = None + self._response = None + self.closed = True + self.emit( + "error", + RuntimeError(f"SSE connection failed: HTTP {status} {body}".strip()), + ) + return + self.reconnect_attempts = 0 self.emit("connected", None) diff --git a/robosystems_client/models/create_event_handler_request.py b/robosystems_client/models/create_event_handler_request.py index 66e876c..29b3bf8 100644 --- a/robosystems_client/models/create_event_handler_request.py +++ b/robosystems_client/models/create_event_handler_request.py @@ -46,7 +46,7 @@ class CreateEventHandlerRequest: match_agent_type (None | str | Unset): match_resource_type (None | str | Unset): match_metadata_expression (CreateEventHandlerRequestMatchMetadataExpressionType0 | None | Unset): JSONPath-style - equality map, e.g. {"metadata.category": "payroll"} + equality map against event.metadata, e.g. {"category": "payroll"} or {"metadata.category": "payroll"} priority (int | Unset): Default: 0. is_active (bool | Unset): Default: True. origin (CreateEventHandlerRequestOrigin | Unset): Default: CreateEventHandlerRequestOrigin.TENANT. diff --git a/tests/test_sse_client.py b/tests/test_sse_client.py index 5122857..b97074a 100644 --- a/tests/test_sse_client.py +++ b/tests/test_sse_client.py @@ -382,6 +382,7 @@ def test_connect_sets_up_stream(self, mock_httpx, sse_config): mock_http_client = MagicMock() mock_context = MagicMock() mock_response = MagicMock() + mock_response.status_code = 200 mock_response.iter_lines.return_value = iter( [ "event: operation_completed", @@ -402,6 +403,40 @@ def test_connect_sets_up_stream(self, mock_httpx, sse_config): assert len(connected_events) == 1 assert client.reconnect_attempts == 0 + @patch("robosystems_client.clients.sse_client.httpx") + def test_connect_non_200_emits_error_without_retry(self, mock_httpx, sse_config): + """A non-200 status must emit an error and not enter the event loop. + + Regression test: before this check, a 401 would cause connect() to fall + through to iter_lines() on an empty body, return silently, and leave any + caller spinning in a wait-for-completion loop forever. + """ + mock_http_client = MagicMock() + mock_context = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.read.return_value = b'{"detail": "Not authenticated"}' + mock_context.__enter__ = Mock(return_value=mock_response) + mock_http_client.stream.return_value = mock_context + mock_httpx.Client.return_value = mock_http_client + + client = SSEClient(sse_config) + errors: list = [] + connected: list = [] + client.on("error", lambda e: errors.append(e)) + client.on("connected", lambda d: connected.append(d)) + + client.connect("op-123") + + # Never entered the event loop / never emitted connected. + assert connected == [] + # Error event surfaced with status and body. + assert len(errors) == 1 + assert "HTTP 401" in str(errors[0]) + # iter_lines was never called — we bailed before the read loop. + mock_response.iter_lines.assert_not_called() + assert client.closed is True + @patch("robosystems_client.clients.sse_client.httpx") def test_connect_error_triggers_retry(self, mock_httpx, sse_config): """Test that connection error triggers _handle_error."""