From 8e59f3e8a26caac954b045dcf22c3daa16b32b30 Mon Sep 17 00:00:00 2001 From: Jacob Nowjack Date: Thu, 16 Apr 2026 17:01:57 -0600 Subject: [PATCH] Cancel DelayedClientCall when application listener throws Align DelayedClientCall.DelayedListener with ClientCallImpl's existing behavior for listener exceptions. When the application listener throws from onHeaders/onMessage/onReady, catch the Throwable, cancel the call with CANCELLED (cause = the throwable), and swallow subsequent callbacks. When onClose throws, log and continue, matching ClientCallImpl.closeObserver. If onClose arrives from the transport after a prior callback threw, override its status/trailers with the captured CANCELLED so a server-supplied OK can't mask the local failure. Previously, a throw from the application listener escaped to the callExecutor's uncaught-exception handler. The real call was not cancelled and the transport kept delivering callbacks to an already broken listener, different from how the same bug behaves on a normal ClientCallImpl, and a timing-dependent inconsistency depending on whether callbacks arrived before or after setCall + drain completed. Trade-off: listener-callback throws are no longer visible to the executor's UncaughtExceptionHandler (they're attached as Status.cause instead). This matches ClientCallImpl and is the intended behavior. Exception handling for the outer drainPendingCalls loop (realCall.sendMessage/request/halfClose/cancel) remains unaddressed; that TODO is preserved. --- .../io/grpc/internal/DelayedClientCall.java | 86 ++++++- .../grpc/internal/DelayedClientCallTest.java | 226 ++++++++++++++++++ 2 files changed, 302 insertions(+), 10 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/DelayedClientCall.java b/core/src/main/java/io/grpc/internal/DelayedClientCall.java index b568bb12c46..36b5a36eb8e 100644 --- a/core/src/main/java/io/grpc/internal/DelayedClientCall.java +++ b/core/src/main/java/io/grpc/internal/DelayedClientCall.java @@ -206,7 +206,7 @@ public final void start(Listener listener, final Metadata headers) { savedError = error; savedPassThrough = passThrough; if (!savedPassThrough) { - listener = delayedListener = new DelayedListener<>(listener); + listener = delayedListener = new DelayedListener<>(this, listener); startHeaders = headers; } } @@ -445,15 +445,33 @@ public void runInContext() { } private static final class DelayedListener extends Listener { + private final DelayedClientCall call; private final Listener realListener; private volatile boolean passThrough; + private volatile Status exceptionStatus; @GuardedBy("this") private List pendingCallbacks = new ArrayList<>(); - public DelayedListener(Listener listener) { + public DelayedListener(DelayedClientCall call, Listener listener) { + this.call = call; this.realListener = listener; } + /** + * Cancels call and schedules onClose() notification. May only be called from within a + * DelayedListener callback dispatch (either queued drain or passThrough). Both phases + * deliver callbacks serially on the transport's callExecutor, so the write to + * {@code exceptionStatus} is serialized with, and thus visible to, subsequent listener + * callbacks on that executor. + */ + private void exceptionThrown(Throwable t, String description) { + // onClose() must be delivered exactly once and last. Other callbacks may already be queued + // ahead of realCall's eventual onClose, so we can't call onClose() here. We set the status + // and overwrite the onClose() details when it arrives. + exceptionStatus = Status.CANCELLED.withCause(t).withDescription(description); + call.cancel(description, t); + } + private void delayOrExecute(Runnable runnable) { synchronized (this) { if (!passThrough) { @@ -467,37 +485,75 @@ private void delayOrExecute(Runnable runnable) { @Override public void onHeaders(final Metadata headers) { if (passThrough) { - realListener.onHeaders(headers); + deliverHeaders(headers); } else { delayOrExecute(new Runnable() { @Override public void run() { - realListener.onHeaders(headers); + deliverHeaders(headers); } }); } } + private void deliverHeaders(Metadata headers) { + if (exceptionStatus != null) { + return; + } + try { + realListener.onHeaders(headers); + } catch (Throwable t) { + exceptionThrown(t, "Failed to read headers"); + } + } + @Override public void onMessage(final RespT message) { if (passThrough) { - realListener.onMessage(message); + deliverMessage(message); } else { delayOrExecute(new Runnable() { @Override public void run() { - realListener.onMessage(message); + deliverMessage(message); } }); } } + private void deliverMessage(RespT message) { + if (exceptionStatus != null) { + return; + } + try { + realListener.onMessage(message); + } catch (Throwable t) { + exceptionThrown(t, "Failed to read message."); + } + } + @Override public void onClose(final Status status, final Metadata trailers) { delayOrExecute(new Runnable() { @Override public void run() { - realListener.onClose(status, trailers); + Status effectiveStatus = status; + Metadata effectiveTrailers = trailers; + if (exceptionStatus != null) { + // Ideally exceptionStatus == status, as exceptionStatus was passed to cancel(). + // However the cancel is racy and this onClose may have already been queued when the + // cancellation occurred. Since other callbacks throw away data if exceptionStatus != + // null, it is semantically essential that we _not_ use a status provided by the + // server. + effectiveStatus = exceptionStatus; + // Replace trailers to prevent mixing sources of status and trailers. + effectiveTrailers = new Metadata(); + } + try { + realListener.onClose(effectiveStatus, effectiveTrailers); + } catch (RuntimeException ex) { + logger.log(Level.WARNING, "Exception thrown by onClose() in ClientCall", ex); + } } }); } @@ -505,17 +561,28 @@ public void run() { @Override public void onReady() { if (passThrough) { - realListener.onReady(); + deliverOnReady(); } else { delayOrExecute(new Runnable() { @Override public void run() { - realListener.onReady(); + deliverOnReady(); } }); } } + private void deliverOnReady() { + if (exceptionStatus != null) { + return; + } + try { + realListener.onReady(); + } catch (Throwable t) { + exceptionThrown(t, "Failed to call onReady."); + } + } + void drainPendingCallbacks() { assert !passThrough; List toRun = new ArrayList<>(); @@ -535,7 +602,6 @@ void drainPendingCallbacks() { } for (Runnable runnable : toRun) { // Avoid calling listener while lock is held to prevent deadlocks. - // TODO(ejona): exception handling runnable.run(); } toRun.clear(); diff --git a/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java b/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java index ff131d29975..0d30e947b0c 100644 --- a/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java @@ -229,6 +229,232 @@ public void delayedCallsRunUnderContext() throws Exception { assertThat(contextKey.get(readyContext.get())).isEqualTo(goldenValue); } + @Test + public void listenerThrowsInPendingCallback_cancelsRealCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onMessage(Integer msg) { + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + // Deliver onMessage while the wrapping DelayedListener is still buffering, by firing + // it from within realCall.start() — drainPendingCalls has not yet flipped the listener + // to pass-through. The queued onMessage is then drained and throws; the fix must catch + // the throwable and cancel the real call rather than let it escape. + Runnable r = delayedClientCall.setCall(new SimpleForwardingClientCall( + mockRealCall) { + @Override + public void start(Listener listener, Metadata metadata) { + super.start(listener, metadata); + listener.onMessage(42); + } + }); + assertThat(r).isNotNull(); + r.run(); // Must not propagate `boom`. + verify(mockRealCall).cancel(eq("Failed to read message."), eq(boom)); + } + + @Test + public void listenerThrowsInPendingOnHeaders_cancelsRealCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onHeaders(Metadata headers) { + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(new SimpleForwardingClientCall( + mockRealCall) { + @Override + public void start(Listener listener, Metadata metadata) { + super.start(listener, metadata); + listener.onHeaders(new Metadata()); + } + }); + assertThat(r).isNotNull(); + r.run(); + verify(mockRealCall).cancel(eq("Failed to read headers"), eq(boom)); + } + + @Test + public void listenerThrowsInPendingOnReady_cancelsRealCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onReady() { + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(new SimpleForwardingClientCall( + mockRealCall) { + @Override + public void start(Listener listener, Metadata metadata) { + super.start(listener, metadata); + listener.onReady(); + } + }); + assertThat(r).isNotNull(); + r.run(); + verify(mockRealCall).cancel(eq("Failed to call onReady."), eq(boom)); + } + + @Test + public void onCloseExceptionCaughtAndLogged() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + final AtomicReference observed = new AtomicReference<>(); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onClose(Status status, Metadata trailers) { + observed.set(status); + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(new SimpleForwardingClientCall( + mockRealCall) { + @Override + public void start(Listener listener, Metadata metadata) { + super.start(listener, metadata); + listener.onClose(Status.DATA_LOSS, new Metadata()); + } + }); + assertThat(r).isNotNull(); + r.run(); // Must not propagate `boom`. + assertThat(observed.get().getCode()).isEqualTo(Status.Code.DATA_LOSS); + verify(mockRealCall, never()).cancel(any(), any()); + } + + @Test + public void listenerThrowsInPassThroughOnMessage_cancelsRealCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onMessage(Integer msg) { + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(mockRealCall); + assertThat(r).isNotNull(); + r.run(); // drain completes, listener transitions to passThrough + @SuppressWarnings("unchecked") + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(Listener.class); + verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); + Listener realCallListener = listenerCaptor.getValue(); + realCallListener.onMessage(42); // dispatched on passThrough fast path + verify(mockRealCall).cancel(eq("Failed to read message."), eq(boom)); + } + + @Test + public void listenerThrowsInPassThroughOnHeaders_cancelsRealCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onHeaders(Metadata headers) { + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(mockRealCall); + assertThat(r).isNotNull(); + r.run(); + @SuppressWarnings("unchecked") + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(Listener.class); + verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); + Listener realCallListener = listenerCaptor.getValue(); + realCallListener.onHeaders(new Metadata()); + verify(mockRealCall).cancel(eq("Failed to read headers"), eq(boom)); + } + + @Test + public void listenerThrowsInPassThroughOnReady_cancelsRealCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onReady() { + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(mockRealCall); + assertThat(r).isNotNull(); + r.run(); + @SuppressWarnings("unchecked") + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(Listener.class); + verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); + Listener realCallListener = listenerCaptor.getValue(); + realCallListener.onReady(); + verify(mockRealCall).cancel(eq("Failed to call onReady."), eq(boom)); + } + + @Test + public void listenerThrowsInPassThrough_subsequentCallbacksSwallowedAndOnCloseOverridden() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + final AtomicReference lastMessage = new AtomicReference<>(); + final AtomicReference closeStatus = new AtomicReference<>(); + final AtomicReference closeTrailers = new AtomicReference<>(); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onMessage(Integer msg) { + lastMessage.set(msg); + if (msg == 1) { + throw boom; + } + } + + @Override + public void onClose(Status status, Metadata trailers) { + closeStatus.set(status); + closeTrailers.set(trailers); + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(mockRealCall); + assertThat(r).isNotNull(); + r.run(); + @SuppressWarnings("unchecked") + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(Listener.class); + verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); + Listener realCallListener = listenerCaptor.getValue(); + + realCallListener.onMessage(1); // throws -> exceptionStatus captured + assertThat(lastMessage.get()).isEqualTo(1); + verify(mockRealCall).cancel(eq("Failed to read message."), eq(boom)); + + // Later callbacks are swallowed — the listener must not see message 2. + realCallListener.onMessage(2); + assertThat(lastMessage.get()).isEqualTo(1); + + // Transport onClose with OK must be overridden by the captured CANCELLED status. + Metadata serverTrailers = new Metadata(); + serverTrailers.put(Metadata.Key.of("k", Metadata.ASCII_STRING_MARSHALLER), "v"); + realCallListener.onClose(Status.OK, serverTrailers); + assertThat(closeStatus.get().getCode()).isEqualTo(Status.Code.CANCELLED); + assertThat(closeStatus.get().getCause()).isEqualTo(boom); + // Trailers replaced to avoid mixing sources. + assertThat(closeTrailers.get()).isNotSameInstanceAs(serverTrailers); + } + private void callMeMaybe(Runnable r) { if (r != null) { r.run();