diff --git a/core/src/main/java/io/grpc/internal/DelayedClientCall.java b/core/src/main/java/io/grpc/internal/DelayedClientCall.java index b568bb12c46..36b5a36eb8e 100644 --- a/core/src/main/java/io/grpc/internal/DelayedClientCall.java +++ b/core/src/main/java/io/grpc/internal/DelayedClientCall.java @@ -206,7 +206,7 @@ public final void start(Listener listener, final Metadata headers) { savedError = error; savedPassThrough = passThrough; if (!savedPassThrough) { - listener = delayedListener = new DelayedListener<>(listener); + listener = delayedListener = new DelayedListener<>(this, listener); startHeaders = headers; } } @@ -445,15 +445,33 @@ public void runInContext() { } private static final class DelayedListener extends Listener { + private final DelayedClientCall call; private final Listener realListener; private volatile boolean passThrough; + private volatile Status exceptionStatus; @GuardedBy("this") private List pendingCallbacks = new ArrayList<>(); - public DelayedListener(Listener listener) { + public DelayedListener(DelayedClientCall call, Listener listener) { + this.call = call; this.realListener = listener; } + /** + * Cancels call and schedules onClose() notification. May only be called from within a + * DelayedListener callback dispatch (either queued drain or passThrough). Both phases + * deliver callbacks serially on the transport's callExecutor, so the write to + * {@code exceptionStatus} is serialized with, and thus visible to, subsequent listener + * callbacks on that executor. + */ + private void exceptionThrown(Throwable t, String description) { + // onClose() must be delivered exactly once and last. Other callbacks may already be queued + // ahead of realCall's eventual onClose, so we can't call onClose() here. We set the status + // and overwrite the onClose() details when it arrives. + exceptionStatus = Status.CANCELLED.withCause(t).withDescription(description); + call.cancel(description, t); + } + private void delayOrExecute(Runnable runnable) { synchronized (this) { if (!passThrough) { @@ -467,37 +485,75 @@ private void delayOrExecute(Runnable runnable) { @Override public void onHeaders(final Metadata headers) { if (passThrough) { - realListener.onHeaders(headers); + deliverHeaders(headers); } else { delayOrExecute(new Runnable() { @Override public void run() { - realListener.onHeaders(headers); + deliverHeaders(headers); } }); } } + private void deliverHeaders(Metadata headers) { + if (exceptionStatus != null) { + return; + } + try { + realListener.onHeaders(headers); + } catch (Throwable t) { + exceptionThrown(t, "Failed to read headers"); + } + } + @Override public void onMessage(final RespT message) { if (passThrough) { - realListener.onMessage(message); + deliverMessage(message); } else { delayOrExecute(new Runnable() { @Override public void run() { - realListener.onMessage(message); + deliverMessage(message); } }); } } + private void deliverMessage(RespT message) { + if (exceptionStatus != null) { + return; + } + try { + realListener.onMessage(message); + } catch (Throwable t) { + exceptionThrown(t, "Failed to read message."); + } + } + @Override public void onClose(final Status status, final Metadata trailers) { delayOrExecute(new Runnable() { @Override public void run() { - realListener.onClose(status, trailers); + Status effectiveStatus = status; + Metadata effectiveTrailers = trailers; + if (exceptionStatus != null) { + // Ideally exceptionStatus == status, as exceptionStatus was passed to cancel(). + // However the cancel is racy and this onClose may have already been queued when the + // cancellation occurred. Since other callbacks throw away data if exceptionStatus != + // null, it is semantically essential that we _not_ use a status provided by the + // server. + effectiveStatus = exceptionStatus; + // Replace trailers to prevent mixing sources of status and trailers. + effectiveTrailers = new Metadata(); + } + try { + realListener.onClose(effectiveStatus, effectiveTrailers); + } catch (RuntimeException ex) { + logger.log(Level.WARNING, "Exception thrown by onClose() in ClientCall", ex); + } } }); } @@ -505,17 +561,28 @@ public void run() { @Override public void onReady() { if (passThrough) { - realListener.onReady(); + deliverOnReady(); } else { delayOrExecute(new Runnable() { @Override public void run() { - realListener.onReady(); + deliverOnReady(); } }); } } + private void deliverOnReady() { + if (exceptionStatus != null) { + return; + } + try { + realListener.onReady(); + } catch (Throwable t) { + exceptionThrown(t, "Failed to call onReady."); + } + } + void drainPendingCallbacks() { assert !passThrough; List toRun = new ArrayList<>(); @@ -535,7 +602,6 @@ void drainPendingCallbacks() { } for (Runnable runnable : toRun) { // Avoid calling listener while lock is held to prevent deadlocks. - // TODO(ejona): exception handling runnable.run(); } toRun.clear(); diff --git a/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java b/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java index ff131d29975..0d30e947b0c 100644 --- a/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedClientCallTest.java @@ -229,6 +229,232 @@ public void delayedCallsRunUnderContext() throws Exception { assertThat(contextKey.get(readyContext.get())).isEqualTo(goldenValue); } + @Test + public void listenerThrowsInPendingCallback_cancelsRealCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onMessage(Integer msg) { + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + // Deliver onMessage while the wrapping DelayedListener is still buffering, by firing + // it from within realCall.start() — drainPendingCalls has not yet flipped the listener + // to pass-through. The queued onMessage is then drained and throws; the fix must catch + // the throwable and cancel the real call rather than let it escape. + Runnable r = delayedClientCall.setCall(new SimpleForwardingClientCall( + mockRealCall) { + @Override + public void start(Listener listener, Metadata metadata) { + super.start(listener, metadata); + listener.onMessage(42); + } + }); + assertThat(r).isNotNull(); + r.run(); // Must not propagate `boom`. + verify(mockRealCall).cancel(eq("Failed to read message."), eq(boom)); + } + + @Test + public void listenerThrowsInPendingOnHeaders_cancelsRealCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onHeaders(Metadata headers) { + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(new SimpleForwardingClientCall( + mockRealCall) { + @Override + public void start(Listener listener, Metadata metadata) { + super.start(listener, metadata); + listener.onHeaders(new Metadata()); + } + }); + assertThat(r).isNotNull(); + r.run(); + verify(mockRealCall).cancel(eq("Failed to read headers"), eq(boom)); + } + + @Test + public void listenerThrowsInPendingOnReady_cancelsRealCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onReady() { + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(new SimpleForwardingClientCall( + mockRealCall) { + @Override + public void start(Listener listener, Metadata metadata) { + super.start(listener, metadata); + listener.onReady(); + } + }); + assertThat(r).isNotNull(); + r.run(); + verify(mockRealCall).cancel(eq("Failed to call onReady."), eq(boom)); + } + + @Test + public void onCloseExceptionCaughtAndLogged() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + final AtomicReference observed = new AtomicReference<>(); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onClose(Status status, Metadata trailers) { + observed.set(status); + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(new SimpleForwardingClientCall( + mockRealCall) { + @Override + public void start(Listener listener, Metadata metadata) { + super.start(listener, metadata); + listener.onClose(Status.DATA_LOSS, new Metadata()); + } + }); + assertThat(r).isNotNull(); + r.run(); // Must not propagate `boom`. + assertThat(observed.get().getCode()).isEqualTo(Status.Code.DATA_LOSS); + verify(mockRealCall, never()).cancel(any(), any()); + } + + @Test + public void listenerThrowsInPassThroughOnMessage_cancelsRealCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onMessage(Integer msg) { + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(mockRealCall); + assertThat(r).isNotNull(); + r.run(); // drain completes, listener transitions to passThrough + @SuppressWarnings("unchecked") + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(Listener.class); + verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); + Listener realCallListener = listenerCaptor.getValue(); + realCallListener.onMessage(42); // dispatched on passThrough fast path + verify(mockRealCall).cancel(eq("Failed to read message."), eq(boom)); + } + + @Test + public void listenerThrowsInPassThroughOnHeaders_cancelsRealCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onHeaders(Metadata headers) { + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(mockRealCall); + assertThat(r).isNotNull(); + r.run(); + @SuppressWarnings("unchecked") + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(Listener.class); + verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); + Listener realCallListener = listenerCaptor.getValue(); + realCallListener.onHeaders(new Metadata()); + verify(mockRealCall).cancel(eq("Failed to read headers"), eq(boom)); + } + + @Test + public void listenerThrowsInPassThroughOnReady_cancelsRealCall() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onReady() { + throw boom; + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(mockRealCall); + assertThat(r).isNotNull(); + r.run(); + @SuppressWarnings("unchecked") + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(Listener.class); + verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); + Listener realCallListener = listenerCaptor.getValue(); + realCallListener.onReady(); + verify(mockRealCall).cancel(eq("Failed to call onReady."), eq(boom)); + } + + @Test + public void listenerThrowsInPassThrough_subsequentCallbacksSwallowedAndOnCloseOverridden() { + DelayedClientCall delayedClientCall = new DelayedClientCall<>( + callExecutor, fakeClock.getScheduledExecutorService(), null); + final RuntimeException boom = new RuntimeException("boom"); + final AtomicReference lastMessage = new AtomicReference<>(); + final AtomicReference closeStatus = new AtomicReference<>(); + final AtomicReference closeTrailers = new AtomicReference<>(); + ClientCall.Listener throwingListener = new ClientCall.Listener() { + @Override + public void onMessage(Integer msg) { + lastMessage.set(msg); + if (msg == 1) { + throw boom; + } + } + + @Override + public void onClose(Status status, Metadata trailers) { + closeStatus.set(status); + closeTrailers.set(trailers); + } + }; + delayedClientCall.start(throwingListener, new Metadata()); + Runnable r = delayedClientCall.setCall(mockRealCall); + assertThat(r).isNotNull(); + r.run(); + @SuppressWarnings("unchecked") + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(Listener.class); + verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class)); + Listener realCallListener = listenerCaptor.getValue(); + + realCallListener.onMessage(1); // throws -> exceptionStatus captured + assertThat(lastMessage.get()).isEqualTo(1); + verify(mockRealCall).cancel(eq("Failed to read message."), eq(boom)); + + // Later callbacks are swallowed — the listener must not see message 2. + realCallListener.onMessage(2); + assertThat(lastMessage.get()).isEqualTo(1); + + // Transport onClose with OK must be overridden by the captured CANCELLED status. + Metadata serverTrailers = new Metadata(); + serverTrailers.put(Metadata.Key.of("k", Metadata.ASCII_STRING_MARSHALLER), "v"); + realCallListener.onClose(Status.OK, serverTrailers); + assertThat(closeStatus.get().getCode()).isEqualTo(Status.Code.CANCELLED); + assertThat(closeStatus.get().getCause()).isEqualTo(boom); + // Trailers replaced to avoid mixing sources. + assertThat(closeTrailers.get()).isNotSameInstanceAs(serverTrailers); + } + private void callMeMaybe(Runnable r) { if (r != null) { r.run();