Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 76 additions & 10 deletions core/src/main/java/io/grpc/internal/DelayedClientCall.java
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ public final void start(Listener<RespT> listener, final Metadata headers) {
savedError = error;
savedPassThrough = passThrough;
if (!savedPassThrough) {
listener = delayedListener = new DelayedListener<>(listener);
listener = delayedListener = new DelayedListener<>(this, listener);
startHeaders = headers;
}
}
Expand Down Expand Up @@ -445,15 +445,33 @@ public void runInContext() {
}

private static final class DelayedListener<RespT> extends Listener<RespT> {
private final DelayedClientCall<?, RespT> call;
private final Listener<RespT> realListener;
private volatile boolean passThrough;
private volatile Status exceptionStatus;
@GuardedBy("this")
private List<Runnable> pendingCallbacks = new ArrayList<>();

public DelayedListener(Listener<RespT> listener) {
public DelayedListener(DelayedClientCall<?, RespT> call, Listener<RespT> listener) {
this.call = call;
this.realListener = listener;
}

/**
* Cancels call and schedules onClose() notification. May only be called from within a
* DelayedListener callback dispatch (either queued drain or passThrough). Both phases
* deliver callbacks serially on the transport's callExecutor, so the write to
* {@code exceptionStatus} is serialized with, and thus visible to, subsequent listener
* callbacks on that executor.
*/
private void exceptionThrown(Throwable t, String description) {
// onClose() must be delivered exactly once and last. Other callbacks may already be queued
// ahead of realCall's eventual onClose, so we can't call onClose() here. We set the status
// and overwrite the onClose() details when it arrives.
exceptionStatus = Status.CANCELLED.withCause(t).withDescription(description);
call.cancel(description, t);
}

private void delayOrExecute(Runnable runnable) {
synchronized (this) {
if (!passThrough) {
Expand All @@ -467,55 +485,104 @@ private void delayOrExecute(Runnable runnable) {
@Override
public void onHeaders(final Metadata headers) {
if (passThrough) {
realListener.onHeaders(headers);
deliverHeaders(headers);
} else {
delayOrExecute(new Runnable() {
@Override
public void run() {
realListener.onHeaders(headers);
deliverHeaders(headers);
}
});
}
}

private void deliverHeaders(Metadata headers) {
if (exceptionStatus != null) {
return;
}
try {
realListener.onHeaders(headers);
} catch (Throwable t) {
exceptionThrown(t, "Failed to read headers");
}
}

@Override
public void onMessage(final RespT message) {
if (passThrough) {
realListener.onMessage(message);
deliverMessage(message);
} else {
delayOrExecute(new Runnable() {
@Override
public void run() {
realListener.onMessage(message);
deliverMessage(message);
}
});
}
}

private void deliverMessage(RespT message) {
if (exceptionStatus != null) {
return;
}
try {
realListener.onMessage(message);
} catch (Throwable t) {
exceptionThrown(t, "Failed to read message.");
}
}

@Override
public void onClose(final Status status, final Metadata trailers) {
delayOrExecute(new Runnable() {
@Override
public void run() {
realListener.onClose(status, trailers);
Status effectiveStatus = status;
Metadata effectiveTrailers = trailers;
if (exceptionStatus != null) {
// Ideally exceptionStatus == status, as exceptionStatus was passed to cancel().
// However the cancel is racy and this onClose may have already been queued when the
// cancellation occurred. Since other callbacks throw away data if exceptionStatus !=
// null, it is semantically essential that we _not_ use a status provided by the
// server.
effectiveStatus = exceptionStatus;
// Replace trailers to prevent mixing sources of status and trailers.
effectiveTrailers = new Metadata();
}
try {
realListener.onClose(effectiveStatus, effectiveTrailers);
} catch (RuntimeException ex) {
logger.log(Level.WARNING, "Exception thrown by onClose() in ClientCall", ex);
}
}
});
}

@Override
public void onReady() {
if (passThrough) {
realListener.onReady();
deliverOnReady();
} else {
delayOrExecute(new Runnable() {
@Override
public void run() {
realListener.onReady();
deliverOnReady();
}
});
}
}

private void deliverOnReady() {
if (exceptionStatus != null) {
return;
}
try {
realListener.onReady();
} catch (Throwable t) {
exceptionThrown(t, "Failed to call onReady.");
}
}

void drainPendingCallbacks() {
assert !passThrough;
List<Runnable> toRun = new ArrayList<>();
Expand All @@ -535,7 +602,6 @@ void drainPendingCallbacks() {
}
for (Runnable runnable : toRun) {
// Avoid calling listener while lock is held to prevent deadlocks.
// TODO(ejona): exception handling
runnable.run();
}
toRun.clear();
Expand Down
226 changes: 226 additions & 0 deletions core/src/test/java/io/grpc/internal/DelayedClientCallTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,232 @@ public void delayedCallsRunUnderContext() throws Exception {
assertThat(contextKey.get(readyContext.get())).isEqualTo(goldenValue);
}

@Test
public void listenerThrowsInPendingCallback_cancelsRealCall() {
DelayedClientCall<String, Integer> delayedClientCall = new DelayedClientCall<>(
callExecutor, fakeClock.getScheduledExecutorService(), null);
final RuntimeException boom = new RuntimeException("boom");
ClientCall.Listener<Integer> throwingListener = new ClientCall.Listener<Integer>() {
@Override
public void onMessage(Integer msg) {
throw boom;
}
};
delayedClientCall.start(throwingListener, new Metadata());
// Deliver onMessage while the wrapping DelayedListener is still buffering, by firing
// it from within realCall.start() — drainPendingCalls has not yet flipped the listener
// to pass-through. The queued onMessage is then drained and throws; the fix must catch
// the throwable and cancel the real call rather than let it escape.
Runnable r = delayedClientCall.setCall(new SimpleForwardingClientCall<String, Integer>(
mockRealCall) {
@Override
public void start(Listener<Integer> listener, Metadata metadata) {
super.start(listener, metadata);
listener.onMessage(42);
}
});
assertThat(r).isNotNull();
r.run(); // Must not propagate `boom`.
verify(mockRealCall).cancel(eq("Failed to read message."), eq(boom));
}

@Test
public void listenerThrowsInPendingOnHeaders_cancelsRealCall() {
DelayedClientCall<String, Integer> delayedClientCall = new DelayedClientCall<>(
callExecutor, fakeClock.getScheduledExecutorService(), null);
final RuntimeException boom = new RuntimeException("boom");
ClientCall.Listener<Integer> throwingListener = new ClientCall.Listener<Integer>() {
@Override
public void onHeaders(Metadata headers) {
throw boom;
}
};
delayedClientCall.start(throwingListener, new Metadata());
Runnable r = delayedClientCall.setCall(new SimpleForwardingClientCall<String, Integer>(
mockRealCall) {
@Override
public void start(Listener<Integer> listener, Metadata metadata) {
super.start(listener, metadata);
listener.onHeaders(new Metadata());
}
});
assertThat(r).isNotNull();
r.run();
verify(mockRealCall).cancel(eq("Failed to read headers"), eq(boom));
}

@Test
public void listenerThrowsInPendingOnReady_cancelsRealCall() {
DelayedClientCall<String, Integer> delayedClientCall = new DelayedClientCall<>(
callExecutor, fakeClock.getScheduledExecutorService(), null);
final RuntimeException boom = new RuntimeException("boom");
ClientCall.Listener<Integer> throwingListener = new ClientCall.Listener<Integer>() {
@Override
public void onReady() {
throw boom;
}
};
delayedClientCall.start(throwingListener, new Metadata());
Runnable r = delayedClientCall.setCall(new SimpleForwardingClientCall<String, Integer>(
mockRealCall) {
@Override
public void start(Listener<Integer> listener, Metadata metadata) {
super.start(listener, metadata);
listener.onReady();
}
});
assertThat(r).isNotNull();
r.run();
verify(mockRealCall).cancel(eq("Failed to call onReady."), eq(boom));
}

@Test
public void onCloseExceptionCaughtAndLogged() {
DelayedClientCall<String, Integer> delayedClientCall = new DelayedClientCall<>(
callExecutor, fakeClock.getScheduledExecutorService(), null);
final RuntimeException boom = new RuntimeException("boom");
final AtomicReference<Status> observed = new AtomicReference<>();
ClientCall.Listener<Integer> throwingListener = new ClientCall.Listener<Integer>() {
@Override
public void onClose(Status status, Metadata trailers) {
observed.set(status);
throw boom;
}
};
delayedClientCall.start(throwingListener, new Metadata());
Runnable r = delayedClientCall.setCall(new SimpleForwardingClientCall<String, Integer>(
mockRealCall) {
@Override
public void start(Listener<Integer> listener, Metadata metadata) {
super.start(listener, metadata);
listener.onClose(Status.DATA_LOSS, new Metadata());
}
});
assertThat(r).isNotNull();
r.run(); // Must not propagate `boom`.
assertThat(observed.get().getCode()).isEqualTo(Status.Code.DATA_LOSS);
verify(mockRealCall, never()).cancel(any(), any());
}

@Test
public void listenerThrowsInPassThroughOnMessage_cancelsRealCall() {
DelayedClientCall<String, Integer> delayedClientCall = new DelayedClientCall<>(
callExecutor, fakeClock.getScheduledExecutorService(), null);
final RuntimeException boom = new RuntimeException("boom");
ClientCall.Listener<Integer> throwingListener = new ClientCall.Listener<Integer>() {
@Override
public void onMessage(Integer msg) {
throw boom;
}
};
delayedClientCall.start(throwingListener, new Metadata());
Runnable r = delayedClientCall.setCall(mockRealCall);
assertThat(r).isNotNull();
r.run(); // drain completes, listener transitions to passThrough
@SuppressWarnings("unchecked")
ArgumentCaptor<Listener<Integer>> listenerCaptor = ArgumentCaptor.forClass(Listener.class);
verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class));
Listener<Integer> realCallListener = listenerCaptor.getValue();
realCallListener.onMessage(42); // dispatched on passThrough fast path
verify(mockRealCall).cancel(eq("Failed to read message."), eq(boom));
}

@Test
public void listenerThrowsInPassThroughOnHeaders_cancelsRealCall() {
DelayedClientCall<String, Integer> delayedClientCall = new DelayedClientCall<>(
callExecutor, fakeClock.getScheduledExecutorService(), null);
final RuntimeException boom = new RuntimeException("boom");
ClientCall.Listener<Integer> throwingListener = new ClientCall.Listener<Integer>() {
@Override
public void onHeaders(Metadata headers) {
throw boom;
}
};
delayedClientCall.start(throwingListener, new Metadata());
Runnable r = delayedClientCall.setCall(mockRealCall);
assertThat(r).isNotNull();
r.run();
@SuppressWarnings("unchecked")
ArgumentCaptor<Listener<Integer>> listenerCaptor = ArgumentCaptor.forClass(Listener.class);
verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class));
Listener<Integer> realCallListener = listenerCaptor.getValue();
realCallListener.onHeaders(new Metadata());
verify(mockRealCall).cancel(eq("Failed to read headers"), eq(boom));
}

@Test
public void listenerThrowsInPassThroughOnReady_cancelsRealCall() {
DelayedClientCall<String, Integer> delayedClientCall = new DelayedClientCall<>(
callExecutor, fakeClock.getScheduledExecutorService(), null);
final RuntimeException boom = new RuntimeException("boom");
ClientCall.Listener<Integer> throwingListener = new ClientCall.Listener<Integer>() {
@Override
public void onReady() {
throw boom;
}
};
delayedClientCall.start(throwingListener, new Metadata());
Runnable r = delayedClientCall.setCall(mockRealCall);
assertThat(r).isNotNull();
r.run();
@SuppressWarnings("unchecked")
ArgumentCaptor<Listener<Integer>> listenerCaptor = ArgumentCaptor.forClass(Listener.class);
verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class));
Listener<Integer> realCallListener = listenerCaptor.getValue();
realCallListener.onReady();
verify(mockRealCall).cancel(eq("Failed to call onReady."), eq(boom));
}

@Test
public void listenerThrowsInPassThrough_subsequentCallbacksSwallowedAndOnCloseOverridden() {
DelayedClientCall<String, Integer> delayedClientCall = new DelayedClientCall<>(
callExecutor, fakeClock.getScheduledExecutorService(), null);
final RuntimeException boom = new RuntimeException("boom");
final AtomicReference<Integer> lastMessage = new AtomicReference<>();
final AtomicReference<Status> closeStatus = new AtomicReference<>();
final AtomicReference<Metadata> closeTrailers = new AtomicReference<>();
ClientCall.Listener<Integer> throwingListener = new ClientCall.Listener<Integer>() {
@Override
public void onMessage(Integer msg) {
lastMessage.set(msg);
if (msg == 1) {
throw boom;
}
}

@Override
public void onClose(Status status, Metadata trailers) {
closeStatus.set(status);
closeTrailers.set(trailers);
}
};
delayedClientCall.start(throwingListener, new Metadata());
Runnable r = delayedClientCall.setCall(mockRealCall);
assertThat(r).isNotNull();
r.run();
@SuppressWarnings("unchecked")
ArgumentCaptor<Listener<Integer>> listenerCaptor = ArgumentCaptor.forClass(Listener.class);
verify(mockRealCall).start(listenerCaptor.capture(), any(Metadata.class));
Listener<Integer> realCallListener = listenerCaptor.getValue();

realCallListener.onMessage(1); // throws -> exceptionStatus captured
assertThat(lastMessage.get()).isEqualTo(1);
verify(mockRealCall).cancel(eq("Failed to read message."), eq(boom));

// Later callbacks are swallowed — the listener must not see message 2.
realCallListener.onMessage(2);
assertThat(lastMessage.get()).isEqualTo(1);

// Transport onClose with OK must be overridden by the captured CANCELLED status.
Metadata serverTrailers = new Metadata();
serverTrailers.put(Metadata.Key.of("k", Metadata.ASCII_STRING_MARSHALLER), "v");
realCallListener.onClose(Status.OK, serverTrailers);
assertThat(closeStatus.get().getCode()).isEqualTo(Status.Code.CANCELLED);
assertThat(closeStatus.get().getCause()).isEqualTo(boom);
// Trailers replaced to avoid mixing sources.
assertThat(closeTrailers.get()).isNotSameInstanceAs(serverTrailers);
}

private void callMeMaybe(Runnable r) {
if (r != null) {
r.run();
Expand Down
Loading