Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions src/connectrpc/_interceptor_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ async def on_start(self, ctx: RequestContext) -> T:
"""
...

async def on_end(self, token: T, ctx: RequestContext) -> None:
async def on_end(
self, token: T, ctx: RequestContext, error: Exception | None
) -> None:
"""Called when the RPC ends."""
return

Expand Down Expand Up @@ -164,10 +166,14 @@ async def intercept_unary(
ctx: RequestContext,
) -> RES:
token = await self._delegate.on_start(ctx)
error: Exception | None = None
try:
return await call_next(request, ctx)
except Exception as e:
error = e
raise
finally:
await self._delegate.on_end(token, ctx)
await self._delegate.on_end(token, ctx, error)

async def intercept_client_stream(
self,
Expand All @@ -176,10 +182,14 @@ async def intercept_client_stream(
ctx: RequestContext,
) -> RES:
token = await self._delegate.on_start(ctx)
error: Exception | None = None
try:
return await call_next(request, ctx)
except Exception as e:
error = e
raise
finally:
await self._delegate.on_end(token, ctx)
await self._delegate.on_end(token, ctx, error)

async def intercept_server_stream(
self,
Expand All @@ -188,11 +198,15 @@ async def intercept_server_stream(
ctx: RequestContext,
) -> AsyncIterator[RES]:
token = await self._delegate.on_start(ctx)
error: Exception | None = None
try:
async for response in call_next(request, ctx):
yield response
except Exception as e:
error = e
raise
finally:
await self._delegate.on_end(token, ctx)
await self._delegate.on_end(token, ctx, error)

async def intercept_bidi_stream(
self,
Expand All @@ -201,11 +215,15 @@ async def intercept_bidi_stream(
ctx: RequestContext,
) -> AsyncIterator[RES]:
token = await self._delegate.on_start(ctx)
error: Exception | None = None
try:
async for response in call_next(request, ctx):
yield response
except Exception as e:
error = e
raise
finally:
await self._delegate.on_end(token, ctx)
await self._delegate.on_end(token, ctx, error)


def resolve_interceptors(
Expand Down
28 changes: 23 additions & 5 deletions src/connectrpc/_interceptor_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ def on_start_sync(self, ctx: RequestContext) -> T:
"""
...

def on_end_sync(self, token: T, ctx: RequestContext) -> None:
def on_end_sync(
self, token: T, ctx: RequestContext, error: Exception | None
) -> None:
"""Called when the RPC ends."""
return

Expand Down Expand Up @@ -164,10 +166,14 @@ def intercept_unary_sync(
ctx: RequestContext,
) -> RES:
token = self._delegate.on_start_sync(ctx)
error: Exception | None = None
try:
return call_next(request, ctx)
except Exception as e:
error = e
raise
finally:
self._delegate.on_end_sync(token, ctx)
self._delegate.on_end_sync(token, ctx, error)

def intercept_client_stream_sync(
self,
Expand All @@ -176,10 +182,14 @@ def intercept_client_stream_sync(
ctx: RequestContext,
) -> RES:
token = self._delegate.on_start_sync(ctx)
error: Exception | None = None
try:
return call_next(request, ctx)
except Exception as e:
error = e
raise
finally:
self._delegate.on_end_sync(token, ctx)
self._delegate.on_end_sync(token, ctx, error)

def intercept_server_stream_sync(
self,
Expand All @@ -188,10 +198,14 @@ def intercept_server_stream_sync(
ctx: RequestContext,
) -> Iterator[RES]:
token = self._delegate.on_start_sync(ctx)
error: Exception | None = None
try:
yield from call_next(request, ctx)
except Exception as e:
error = e
raise
finally:
self._delegate.on_end_sync(token, ctx)
self._delegate.on_end_sync(token, ctx, error)

def intercept_bidi_stream_sync(
self,
Expand All @@ -200,10 +214,14 @@ def intercept_bidi_stream_sync(
ctx: RequestContext,
) -> Iterator[RES]:
token = self._delegate.on_start_sync(ctx)
error: Exception | None = None
try:
yield from call_next(request, ctx)
except Exception as e:
error = e
raise
finally:
self._delegate.on_end_sync(token, ctx)
self._delegate.on_end_sync(token, ctx, error)


def resolve_interceptors(
Expand Down
Loading