Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,35 @@ default AsyncRunnable thenRunRetryingWhile(
});
}

/**
* This method is equivalent to a while loop, where the condition is checked before each iteration.
* If the condition returns {@code false} on the first check, the body is never executed.
*
* @param loopBodyRunnable the asynchronous task to be executed in each iteration of the loop
* @param whileCheck a condition to check before each iteration; the loop continues as long as this condition returns true
Comment on lines +250 to +251
Copy link

Copilot AI Mar 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Javadoc @param descriptions are swapped relative to the method signature. The signature is (BooleanSupplier whileCheck, AsyncRunnable loopBodyRunnable), but the Javadoc lists loopBodyRunnable first and whileCheck second, which is misleading for IDE hover/docs.

Suggested change
* @param loopBodyRunnable the asynchronous task to be executed in each iteration of the loop
* @param whileCheck a condition to check before each iteration; the loop continues as long as this condition returns true
* @param whileCheck a condition to check before each iteration; the loop continues as long as this condition returns true
* @param loopBodyRunnable the asynchronous task to be executed in each iteration of the loop

Copilot uses AI. Check for mistakes.
* @return the composition of this and the looping branch
* @see AsyncCallbackLoop
*/
default AsyncRunnable thenRunWhileLoop(final BooleanSupplier whileCheck, final AsyncRunnable loopBodyRunnable) {
return thenRun(finalCallback -> {
LoopState loopState = new LoopState();
new AsyncCallbackLoop(loopState, iterationCallback -> {

if (loopState.breakAndCompleteIf(() -> !whileCheck.getAsBoolean(), iterationCallback)) {
return;
}
loopBodyRunnable.finish((result, t) -> {
if (t != null) {
iterationCallback.completeExceptionally(t);
return;
}
iterationCallback.complete(iterationCallback);
});

}).run(finalCallback);
});
}

/**
* This method is equivalent to a do-while loop, where the loop body is executed first and
* then the condition is checked to determine whether the loop should continue.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright 2008-present MongoDB, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.mongodb.internal.async;

import com.mongodb.annotations.NotThreadSafe;
import com.mongodb.lang.Nullable;

/**
* A trampoline that converts recursive callback invocations into an iterative loop,
* preventing stack overflow in async loops.
*
* <p>When async loop iterations complete synchronously on the same thread, callback
* recursion occurs: each iteration's {@code callback.onResult()} immediately triggers
* the next iteration, causing unbounded stack growth. For example, a 1000-iteration
* loop would create > 1000 stack frames and cause {@code StackOverflowError}.</p>
*
* <p>The trampoline intercepts this recursion: instead of executing the next iteration
* immediately (which would deepen the stack), it enqueues the continuation and returns, allowing
* the stack to unwind. A flat loop at the top then processes enqueued continuation iteratively,
* maintaining constant stack depth regardless of iteration count.</p>
*
* <p>Since async chains are sequential, at most one task is pending at any time.
* The trampoline uses a single slot rather than a queue.</p>
*
* The first call on a thread becomes the "trampoline owner" and runs the drain loop.
* Subsequent (re-entrant) calls on the same thread enqueue their continuation and return immediately.</p>
*
* <p>This class is not part of the public API and may be removed or changed at any time</p>
*/
@NotThreadSafe
public final class AsyncTrampoline {

private static final ThreadLocal<ContinuationHolder> TRAMPOLINE = new ThreadLocal<>();

private AsyncTrampoline() {}

/**
* Execute continuation through the trampoline. If no trampoline is active, become the owner
* and drain all enqueued continuations. If a trampoline is already active, enqueue and return.
*/
public static void run(final Runnable continuation) {
ContinuationHolder continuationHolder = TRAMPOLINE.get();
if (continuationHolder != null) {
continuationHolder.enqueue(continuation);
} else {
continuationHolder = new ContinuationHolder();
TRAMPOLINE.set(continuationHolder);
try {
continuation.run();
while (continuationHolder.continuation != null) {
Runnable continuationToRun = continuationHolder.continuation;
continuationHolder.continuation = null;
continuationToRun.run();
}
} finally {
TRAMPOLINE.remove();
}
}
}

/**
* A single-slot container for continuation.
* At most one continuation is pending at any time in a sequential async chain.
*/
@NotThreadSafe
private static final class ContinuationHolder {
@Nullable
private Runnable continuation;

void enqueue(final Runnable continuation) {
if (this.continuation != null) {
throw new AssertionError("Trampoline slot already occupied");
}
this.continuation = continuation;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.mongodb.internal.async.function;

import com.mongodb.annotations.NotThreadSafe;
import com.mongodb.internal.async.AsyncTrampoline;
import com.mongodb.internal.async.SingleResultCallback;
import com.mongodb.lang.Nullable;

Expand Down Expand Up @@ -62,9 +63,11 @@ public void run(final SingleResultCallback<Void> callback) {
@NotThreadSafe
private class LoopingCallback implements SingleResultCallback<Void> {
private final SingleResultCallback<Void> wrapped;
private final Runnable nextIteration;

LoopingCallback(final SingleResultCallback<Void> callback) {
wrapped = callback;
nextIteration = () -> body.run(this);
Copy link
Member Author

@vbabanin vbabanin Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The nextIteration is reused to avoid creation of extra objects via LambdaMetafactory as we have the capturing lambda.

bounce.work = task is a write to a heap object's field, which can be considered an automatic escape in the JIT's analysis. Even if the Bounce object is short-lived, the JIT sees "object written to another object's field" and should give up.

The AsyncCallbackLoop JMH GC profiling (OpenJDK 17.0.10 LTS, 64-bit Server VM, mixed mode with compressed oops).

Metric Runnable (this) Lambda
Alloc rate 0.039 MB/sec 96.924 MB/sec
Alloc per op 64 B/op 160,048 B/op
GC count ~ 0 10
GC time 0 ms 9 ms

For Lambda case:
Per iteration: 1 lambda * 16 bytes = 16 B

  • Per op (10,000 iterations): 10,000 * 16 = 160,000 B
  • Plus one-time objects ~ 48 B

}

@Override
Expand All @@ -80,7 +83,7 @@ public void onResult(@Nullable final Void result, @Nullable final Throwable t) {
return;
}
if (continueLooping) {
body.run(this);
AsyncTrampoline.run(nextIteration);
} else {
wrapped.onResult(result, null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import static com.mongodb.assertions.Assertions.assertNotNull;
import static com.mongodb.internal.async.AsyncRunnable.beginAsync;
import static org.junit.jupiter.api.Assertions.assertEquals;

abstract class AsyncFunctionsAbstractTest extends AsyncFunctionsTestBase {
private static final TimeoutContext TIMEOUT_CONTEXT = new TimeoutContext(new TimeoutSettings(0, 0, 0, 0L, 0));
Expand Down Expand Up @@ -723,6 +724,120 @@ void testTryCatchTestAndRethrow() {
});
}

@Test
void testWhile() {
// last iteration: 3 < 3 = 1
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 1(transition to next iteration) = 4
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 4(transition to next iteration) = 7
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 7(transition to next iteration) = 10
assertBehavesSameVariations(10,
() -> {
int counter = 0;
while (counter < 3 && plainTest(counter)) {
counter++;
sync(counter);
}
},
(callback) -> {
MutableValue<Integer> counter = new MutableValue<>(0);
beginAsync().thenRunWhileLoop(() -> counter.get() < 3 && plainTest(counter.get()), c2 -> {
counter.set(counter.get() + 1);
async(counter.get(), c2);
}).finish(callback);
});
}

@Test
void testWhileWithThenRun() {
// while: last iteration: 3 < 3 = 1
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 1(transition to next iteration) = 4
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 4(transition to next iteration) = 7
// 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 7(transition to next iteration) = 10
// trailing sync: 1(exception) + 1(success) = 2
// 6(while exception) + 4(while success) * 2(trailing sync) = 14
assertBehavesSameVariations(14,
() -> {
int counter = 0;
while (counter < 3 && plainTest(counter)) {
counter++;
sync(counter);
}
sync(counter + 1);
},
(callback) -> {
MutableValue<Integer> counter = new MutableValue<>(0);
beginAsync().thenRun(c -> {
beginAsync().thenRunWhileLoop(() -> counter.get() < 3 && plainTest(counter.get()), c2 -> {
counter.set(counter.get() + 1);
async(counter.get(), c2);
}).finish(c);
}).thenRun(c -> {
async(counter.get() + 1, c);
}).finish(callback);
});
}

@Test
void testNestedWhileLoops() {
// inner while: 4 success + 6 exception = 10
// last inner iteration: 3 < 3 = 1
// 1(outer plainTest exception) + 1(outer plainTest false) + (inner while) * 1(transition to next iteration) = 12
// 1(outer plainTest exception) + 1(outer plainTest false) + (inner while) * 12(transition to next iteration) = 56
// 1(outer plainTest exception) + 1(outer plainTest false) + (inner while) * 56(transition to next iteration) = 232
assertBehavesSameVariations(232,
() -> {
int outer = 0;
while (outer < 3 && plainTest(outer)) {
int inner = 0;
while (inner < 3 && plainTest(inner)) {
sync(outer + inner);
inner++;
}
outer++;
}
},
(callback) -> {
MutableValue<Integer> outer = new MutableValue<>(0);
beginAsync().thenRunWhileLoop(() -> outer.get() < 3 && plainTest(outer.get()), c -> {
MutableValue<Integer> inner = new MutableValue<>(0);
beginAsync().thenRunWhileLoop(
() -> inner.get() < 3 && plainTest(inner.get()),
c2 -> {
beginAsync().thenRun(c3 -> {
async(outer.get() + inner.get(), c3);
}).thenRun(c3 -> {
inner.set(inner.get() + 1);
c3.complete(c3);
}).finish(c2);
}
).thenRun(c2 -> {
outer.set(outer.get() + 1);
c2.complete(c2);
}).finish(c);
}).finish(callback);
});
}

@Test
void testWhileLoopStackConstant() {
int depthWith100 = maxStackDepthForIterations(100);
int depthWith10000 = maxStackDepthForIterations(10_000);
assertEquals(depthWith100, depthWith10000, "Stack depth should be constant regardless of iteration count (trampoline)");
}
Comment on lines +823 to +826
Copy link

Copilot AI Mar 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These stack-depth assertions require exact equality between two runs (assertEquals(depthWith100, depthWith10000)). Measuring stack depth via Thread.currentThread().getStackTrace() can vary slightly due to JVM/JIT/runtime differences, making this potentially flaky even if depth is effectively constant. Consider asserting that the depth does not grow with iteration count (e.g., depthWith10000 <= depthWith100 + smallDelta, or compare ratios / a fixed upper bound) rather than strict equality.

Copilot uses AI. Check for mistakes.

private int maxStackDepthForIterations(final int iterations) {
MutableValue<Integer> counter = new MutableValue<>(0);
MutableValue<Integer> maxDepth = new MutableValue<>(0);
beginAsync().thenRunWhileLoop(() -> counter.get() < iterations, c -> {
maxDepth.set(Math.max(maxDepth.get(), Thread.currentThread().getStackTrace().length));
counter.set(counter.get() + 1);
c.complete(c);
}).finish((v, t) -> {});

assertEquals(iterations, counter.get());
return maxDepth.get();
}

@Test
void testRetryLoop() {
assertBehavesSameVariations(InvocationTracker.DEPTH_LIMIT * 2 + 1,
Expand Down Expand Up @@ -768,6 +883,65 @@ void testDoWhileLoop() {
});
}

@Test
void testNestedDoWhileLoops() {
// inner do-while: 3 success + 5 exception = 8
// last outer iteration: 3 < 3 = 1
// 5(inner exception) + 3(inner success) * 1(transition to next iteration) = 8
// 5(inner exception) + 3(inner success) * (1(outer plainTest exception) + 1(outer plainTest false) + 8(transition to next iteration)) = 35
// 5(inner exception) + 3(inner success) * (1(outer plainTest exception) + 1(outer plainTest false) + 35(transition to next iteration)) = 116
assertBehavesSameVariations(116,
() -> {
int outer = 0;
do {
int inner = 0;
do {
sync(outer + inner);
inner++;
} while (inner < 3 && plainTest(inner));
outer++;
} while (outer < 3 && plainTest(outer));
},
(callback) -> {
MutableValue<Integer> outer = new MutableValue<>(0);
beginAsync().thenRunDoWhileLoop(c -> {
MutableValue<Integer> inner = new MutableValue<>(0);
beginAsync().thenRunDoWhileLoop(c2 -> {
beginAsync().thenRun(c3 -> {
async(outer.get() + inner.get(), c3);
}).thenRun(c3 -> {
inner.set(inner.get() + 1);
c3.complete(c3);
}).finish(c2);
}, () -> inner.get() < 3 && plainTest(inner.get())
).thenRun(c2 -> {
outer.set(outer.get() + 1);
c2.complete(c2);
}).finish(c);
}, () -> outer.get() < 3 && plainTest(outer.get())).finish(callback);
});
}

@Test
void testDoWhileLoopStackConstant() {
int depthWith100 = maxDoWhileStackDepthForIterations(100);
int depthWith10000 = maxDoWhileStackDepthForIterations(10_000);
assertEquals(depthWith100, depthWith10000,
"Stack depth should be constant regardless of iteration count");
Comment on lines +929 to +930
Copy link

Copilot AI Mar 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same concern as the while-loop stack-depth test: strict equality on stack trace depth (assertEquals(depthWith100, depthWith10000)) can be brittle across JVMs/runs. Prefer an assertion that verifies non-growth with iterations (bounded delta / upper bound) to avoid flakes while still validating the trampoline behavior.

Suggested change
assertEquals(depthWith100, depthWith10000,
"Stack depth should be constant regardless of iteration count");
org.junit.jupiter.api.Assertions.assertTrue(
depthWith10000 <= depthWith100,
"Stack depth with more iterations should not exceed that with fewer iterations");

Copilot uses AI. Check for mistakes.
}

private int maxDoWhileStackDepthForIterations(final int iterations) {
MutableValue<Integer> counter = new MutableValue<>(0);
MutableValue<Integer> maxDepth = new MutableValue<>(0);
beginAsync().thenRunDoWhileLoop(c -> {
maxDepth.set(Math.max(maxDepth.get(), Thread.currentThread().getStackTrace().length));
counter.set(counter.get() + 1);
c.complete(c);
}, () -> counter.get() < iterations).finish((v, t) -> {});
assertEquals(iterations, counter.get());
return maxDepth.get();
}

@Test
void testFinallyWithPlainInsideTry() {
// (in try: normal flow + exception + exception) * (in finally: normal + exception) = 6
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.function.Consumer;
import java.util.function.Supplier;

import static java.lang.String.format;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
Expand Down Expand Up @@ -272,14 +273,16 @@ private <T> void assertBehavesSame(final Supplier<T> sync, final Runnable betwee
}

assertTrue(wasCalledFuture.isDone(), "callback should have been called");
assertEquals(expectedEvents, listener.getEventStrings(), "steps should have matched");
assertEquals(expectedValue, actualValue.get());
assertEquals(expectedException == null, actualException.get() == null,
"both or neither should have produced an exception");
format("both or neither should have produced an exception. Expected exception: %s, actual exception: %s",
expectedException,
actualException));
Comment on lines +277 to +279
Copy link

Copilot AI Mar 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The assertion message formats actualException (the AtomicReference) instead of the stored throwable (actualException.get()), so failures will print something like AtomicReference@... rather than the actual exception. Pass actualException.get() to format (and consider also including expectedException.getClass()/actualException.get().getClass() if helpful).

Suggested change
format("both or neither should have produced an exception. Expected exception: %s, actual exception: %s",
expectedException,
actualException));
format("both or neither should have produced an exception. "
+ "Expected exception: %s (%s), actual exception: %s (%s)",
expectedException,
expectedException == null ? null : expectedException.getClass(),
actualException.get(),
actualException.get() == null ? null : actualException.get().getClass()));

Copilot uses AI. Check for mistakes.
if (expectedException != null) {
assertEquals(expectedException.getMessage(), actualException.get().getMessage());
assertEquals(expectedException.getClass(), actualException.get().getClass());
}
assertEquals(expectedEvents, listener.getEventStrings(), "steps should have matched");
assertEquals(expectedValue, actualValue.get());

listener.clear();
}
Expand Down
Loading