From ac409fb8d4237c25bfd3fd8eb765e3cdb129381a Mon Sep 17 00:00:00 2001 From: MV Shiva Prasad Date: Thu, 12 Feb 2026 15:37:22 +0530 Subject: [PATCH 1/7] unwrap Forwarding subchannel --- api/src/main/java/io/grpc/LoadBalancer.java | 22 +++++ .../test/java/io/grpc/LoadBalancerTest.java | 20 +++++ .../HealthCheckingLoadBalancerFactory.java | 24 +++++ .../io/grpc/util/HealthProducerHelper.java | 26 ++++++ .../util/OutlierDetectionLoadBalancer.java | 11 ++- .../OutlierDetectionLoadBalancerTest.java | 2 +- .../io/grpc/xds/ClusterImplLoadBalancer.java | 87 +++++++++++-------- .../xds/WeightedRoundRobinLoadBalancer.java | 11 ++- .../java/io/grpc/xds/orca/OrcaOobUtil.java | 27 ++++++ .../WeightedRoundRobinLoadBalancerTest.java | 39 ++++++--- 10 files changed, 214 insertions(+), 55 deletions(-) diff --git a/api/src/main/java/io/grpc/LoadBalancer.java b/api/src/main/java/io/grpc/LoadBalancer.java index adc43b19841..c3d52116471 100644 --- a/api/src/main/java/io/grpc/LoadBalancer.java +++ b/api/src/main/java/io/grpc/LoadBalancer.java @@ -661,6 +661,28 @@ public static PickResult withSubchannel(Subchannel subchannel) { return withSubchannel(subchannel, null); } + /** + * Creates a new {@code PickResult} with the given {@code subchannel}, + * but retains all other properties from this {@code PickResult}. + * + * @since 1.80.0 + */ + public PickResult withSubchannelReplacement(Subchannel subchannel) { + return new PickResult(checkNotNull(subchannel, "subchannel"), streamTracerFactory, + status, drop, authorityOverride); + } + + /** + * Creates a new {@code PickResult} with the given {@code streamTracerFactory}, + * but retains all other properties from this {@code PickResult}. + * + * @since 1.80.0 + */ + public PickResult withStreamTracerFactory( + @Nullable ClientStreamTracer.Factory streamTracerFactory) { + return new PickResult(subchannel, streamTracerFactory, status, drop, authorityOverride); + } + /** * A decision to report a connectivity error to the RPC. If the RPC is {@link * CallOptions#withWaitForReady wait-for-ready}, it will stay buffered. Otherwise, it will fail diff --git a/api/src/test/java/io/grpc/LoadBalancerTest.java b/api/src/test/java/io/grpc/LoadBalancerTest.java index 5e9e5cbe816..2aa0585c18c 100644 --- a/api/src/test/java/io/grpc/LoadBalancerTest.java +++ b/api/src/test/java/io/grpc/LoadBalancerTest.java @@ -64,6 +64,26 @@ public void pickResult_withSubchannelAndTracer() { assertThat(result.isDrop()).isFalse(); } + @Test + public void pickResult_withSubchannelReplacement() { + PickResult result = PickResult.withSubchannel(subchannel, tracerFactory) + .withSubchannelReplacement(subchannel2); + assertThat(result.getSubchannel()).isSameInstanceAs(subchannel2); + assertThat(result.getStatus()).isSameInstanceAs(Status.OK); + assertThat(result.getStreamTracerFactory()).isSameInstanceAs(tracerFactory); + assertThat(result.isDrop()).isFalse(); + } + + @Test + public void pickResult_withStreamTracerFactory() { + PickResult result = PickResult.withSubchannel(subchannel) + .withStreamTracerFactory(tracerFactory); + assertThat(result.getSubchannel()).isSameInstanceAs(subchannel); + assertThat(result.getStatus()).isSameInstanceAs(Status.OK); + assertThat(result.getStreamTracerFactory()).isSameInstanceAs(tracerFactory); + assertThat(result.isDrop()).isFalse(); + } + @Test public void pickResult_withNoResult() { PickResult result = PickResult.withNoResult(); diff --git a/services/src/main/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactory.java b/services/src/main/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactory.java index 8cf1458f5dc..ce6d2e70eae 100644 --- a/services/src/main/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactory.java +++ b/services/src/main/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactory.java @@ -144,6 +144,30 @@ void setHealthCheckedService(@Nullable String service) { public String toString() { return MoreObjects.toStringHelper(this).add("delegate", delegate()).toString(); } + + @Override + public void updateBalancingState( + io.grpc.ConnectivityState newState, LoadBalancer.SubchannelPicker newPicker) { + delegate().updateBalancingState(newState, new HealthCheckPicker(newPicker)); + } + + private final class HealthCheckPicker extends LoadBalancer.SubchannelPicker { + private final LoadBalancer.SubchannelPicker delegate; + + HealthCheckPicker(LoadBalancer.SubchannelPicker delegate) { + this.delegate = delegate; + } + + @Override + public LoadBalancer.PickResult pickSubchannel(LoadBalancer.PickSubchannelArgs args) { + LoadBalancer.PickResult result = delegate.pickSubchannel(args); + LoadBalancer.Subchannel subchannel = result.getSubchannel(); + if (subchannel instanceof SubchannelImpl) { + return result.withSubchannelReplacement(((SubchannelImpl) subchannel).delegate()); + } + return result; + } + } } @VisibleForTesting diff --git a/util/src/main/java/io/grpc/util/HealthProducerHelper.java b/util/src/main/java/io/grpc/util/HealthProducerHelper.java index b11864765ea..7913c63d3ad 100644 --- a/util/src/main/java/io/grpc/util/HealthProducerHelper.java +++ b/util/src/main/java/io/grpc/util/HealthProducerHelper.java @@ -22,6 +22,7 @@ import com.google.common.annotations.VisibleForTesting; import io.grpc.Attributes; +import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.Internal; import io.grpc.LoadBalancer; @@ -84,6 +85,31 @@ protected LoadBalancer.Helper delegate() { return delegate; } + @Override + public void updateBalancingState( + ConnectivityState newState, LoadBalancer.SubchannelPicker newPicker) { + delegate.updateBalancingState(newState, new HealthProducerPicker(newPicker)); + } + + private static final class HealthProducerPicker extends LoadBalancer.SubchannelPicker { + private final LoadBalancer.SubchannelPicker delegate; + + HealthProducerPicker(LoadBalancer.SubchannelPicker delegate) { + this.delegate = delegate; + } + + @Override + public LoadBalancer.PickResult pickSubchannel(LoadBalancer.PickSubchannelArgs args) { + LoadBalancer.PickResult result = delegate.pickSubchannel(args); + LoadBalancer.Subchannel subchannel = result.getSubchannel(); + if (subchannel instanceof HealthProducerSubchannel) { + return result.withSubchannelReplacement( + ((HealthProducerSubchannel) subchannel).delegate()); + } + return result; + } + } + // The parent subchannel in the health check producer LB chain. It duplicates subchannel state to // both the state listener and health listener. @VisibleForTesting diff --git a/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java b/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java index d72a85012f2..ddf29d2866b 100644 --- a/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java @@ -442,9 +442,14 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { Subchannel subchannel = pickResult.getSubchannel(); if (subchannel != null) { - return PickResult.withSubchannel(subchannel, new ResultCountingClientStreamTracerFactory( - subchannel.getAttributes().get(ENDPOINT_TRACKER_KEY), - pickResult.getStreamTracerFactory())); + EndpointTracker tracker = subchannel.getAttributes().get(ENDPOINT_TRACKER_KEY); + if (subchannel instanceof OutlierDetectionSubchannel) { + subchannel = ((OutlierDetectionSubchannel) subchannel).delegate(); + } + return pickResult.withSubchannelReplacement(subchannel) + .withStreamTracerFactory(new ResultCountingClientStreamTracerFactory( + tracker, + pickResult.getStreamTracerFactory())); } return pickResult; diff --git a/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java b/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java index 10436407422..39f5b5fb7d6 100644 --- a/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java +++ b/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java @@ -408,7 +408,7 @@ public void delegatePick() throws Exception { // Make sure that we can pick the single READY subchannel. SubchannelPicker picker = pickerCaptor.getAllValues().get(2); PickResult pickResult = picker.pickSubchannel(mock(PickSubchannelArgs.class)); - Subchannel s = ((OutlierDetectionSubchannel) pickResult.getSubchannel()).delegate(); + Subchannel s = pickResult.getSubchannel(); if (s instanceof HealthProducerHelper.HealthProducerSubchannel) { s = ((HealthProducerHelper.HealthProducerSubchannel) s).delegate(); } diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java index ec4bec7f25c..a3c038074df 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java @@ -252,42 +252,55 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { args = args.toBuilder().setAddresses(addresses).setAttributes(attrsBuilder.build()).build(); final Subchannel subchannel = delegate().createSubchannel(args); - return new ForwardingSubchannel() { - @Override - public void start(SubchannelStateListener listener) { - delegate().start(new SubchannelStateListener() { - @Override - public void onSubchannelState(ConnectivityStateInfo newState) { - // Do nothing if LB has been shutdown - if (xdsClient != null && newState.getState().equals(ConnectivityState.READY)) { - // Get locality based on the connected address attributes - ClusterLocality updatedClusterLocality = createClusterLocalityFromAttributes( - subchannel.getConnectedAddressAttributes()); - ClusterLocality oldClusterLocality = localityAtomicReference - .getAndSet(updatedClusterLocality); - oldClusterLocality.release(); + return new ClusterImplSubchannel(subchannel, localityAtomicReference); + } + + private final class ClusterImplSubchannel extends ForwardingSubchannel { + private final Subchannel delegate; + private final AtomicReference localityAtomicReference; + + private ClusterImplSubchannel( + Subchannel delegate, AtomicReference localityAtomicReference) { + this.delegate = delegate; + this.localityAtomicReference = localityAtomicReference; + } + + @Override + public void start(SubchannelStateListener listener) { + delegate().start( + new SubchannelStateListener() { + @Override + public void onSubchannelState(ConnectivityStateInfo newState) { + // Do nothing if LB has been shutdown + if (xdsClient != null && newState.getState().equals(ConnectivityState.READY)) { + // Get locality based on the connected address attributes + ClusterLocality updatedClusterLocality = + createClusterLocalityFromAttributes( + delegate.getConnectedAddressAttributes()); + ClusterLocality oldClusterLocality = + localityAtomicReference.getAndSet(updatedClusterLocality); + oldClusterLocality.release(); + } + listener.onSubchannelState(newState); } - listener.onSubchannelState(newState); - } - }); - } + }); + } - @Override - public void shutdown() { - localityAtomicReference.get().release(); - delegate().shutdown(); - } + @Override + public void shutdown() { + localityAtomicReference.get().release(); + delegate().shutdown(); + } - @Override - public void updateAddresses(List addresses) { - delegate().updateAddresses(withAdditionalAttributes(addresses)); - } + @Override + public void updateAddresses(List addresses) { + delegate().updateAddresses(withAdditionalAttributes(addresses)); + } - @Override - protected Subchannel delegate() { - return subchannel; - } - }; + @Override + protected Subchannel delegate() { + return delegate; + } } private List withAdditionalAttributes( @@ -411,6 +424,13 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { } } PickResult result = delegate.pickSubchannel(args); + Subchannel subchannel = result.getSubchannel(); + if (subchannel != null) { + if (subchannel instanceof ClusterImplLbHelper.ClusterImplSubchannel) { + subchannel = ((ClusterImplLbHelper.ClusterImplSubchannel) subchannel).delegate(); + result = result.withSubchannelReplacement(subchannel); + } + } if (result.getStatus().isOk() && result.getSubchannel() != null) { if (enableCircuitBreaking) { if (inFlights.get() >= maxConcurrentRequests) { @@ -437,8 +457,7 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { stats, inFlights, result.getStreamTracerFactory()); ClientStreamTracer.Factory orcaTracerFactory = OrcaPerRequestUtil.getInstance() .newOrcaClientStreamTracerFactory(tracerFactory, new OrcaPerRpcListener(stats)); - result = PickResult.withSubchannel(result.getSubchannel(), - orcaTracerFactory); + result = result.withStreamTracerFactory(orcaTracerFactory); } } if (args.getCallOptions().getOption(XdsNameResolver.AUTO_HOST_REWRITE_KEY) != null diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java index 6cf3189d587..2c123512a7a 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java @@ -508,12 +508,15 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { if (subchannel == null) { return pickResult; } + + subchannel = ((WrrSubchannel) subchannel).delegate(); if (!enableOobLoadReport) { - return PickResult.withSubchannel(subchannel, - OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( - reportListeners.get(pick))); + return pickResult.withSubchannelReplacement(subchannel) + .withStreamTracerFactory( + OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( + reportListeners.get(pick))); } else { - return PickResult.withSubchannel(subchannel); + return pickResult.withSubchannelReplacement(subchannel); } } diff --git a/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java b/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java index 9ac06d362fc..f02ae639f2f 100644 --- a/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java +++ b/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java @@ -36,12 +36,16 @@ import io.grpc.ChannelLogger; import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ClientCall; +import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.ExperimentalApi; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.PickResult; +import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.Subchannel; +import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.Metadata; import io.grpc.Status; @@ -236,6 +240,29 @@ protected Helper delegate() { return delegate; } + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + delegate.updateBalancingState(newState, new OrcaOobPicker(newPicker)); + } + + private static final class OrcaOobPicker extends SubchannelPicker { + private final SubchannelPicker delegate; + + OrcaOobPicker(SubchannelPicker delegate) { + this.delegate = delegate; + } + + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + PickResult result = delegate.pickSubchannel(args); + Subchannel subchannel = result.getSubchannel(); + if (subchannel instanceof SubchannelImpl) { + return result.withSubchannelReplacement(((SubchannelImpl) subchannel).delegate()); + } + return result; + } + } + @Override public Subchannel createSubchannel(CreateSubchannelArgs args) { syncContext.throwIfNotInThisSynchronizationContext(); diff --git a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java index 9fac46eaf09..72717dc753b 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java @@ -171,6 +171,19 @@ public WeightedRoundRobinLoadBalancerTest() { helper = mock(Helper.class, delegatesTo(testHelperInstance)); } + private static WeightedRoundRobinPicker getWrrPicker(SubchannelPicker picker) { + if (picker.getClass().getName().endsWith("OrcaOobPicker")) { + try { + java.lang.reflect.Field f = picker.getClass().getDeclaredField("delegate"); + f.setAccessible(true); + return (WeightedRoundRobinPicker) f.get(picker); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + return (WeightedRoundRobinPicker) picker; + } + @Before public void setup() { for (int i = 0; i < 3; i++) { @@ -213,7 +226,7 @@ public void pickChildLbTF() throws Exception { verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); final WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getValue(); + getWrrPicker(pickerCaptor.getValue()); weightedPicker.pickSubchannel(mockArgs); } @@ -274,9 +287,9 @@ public void wrrLifeCycle() { eq(ConnectivityState.READY), pickerCaptor.capture()); assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(0); + getWrrPicker(pickerCaptor.getAllValues().get(0)); assertThat(weightedPicker.getChildren().size()).isEqualTo(1); - weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); + weightedPicker = getWrrPicker(pickerCaptor.getAllValues().get(1)); assertThat(weightedPicker.getChildren().size()).isEqualTo(2); String weightedPickerStr = weightedPicker.toString(); assertThat(weightedPickerStr).contains("enableOobLoadReport=false"); @@ -337,7 +350,7 @@ public void enableOobLoadReportConfig() { verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); + getWrrPicker(pickerCaptor.getAllValues().get(1)); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( @@ -361,7 +374,7 @@ weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).on .setAttributes(affinity).build())); verify(helper, times(3)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor2.capture()); - weightedPicker = (WeightedRoundRobinPicker) pickerCaptor2.getAllValues().get(2); + weightedPicker = getWrrPicker(pickerCaptor2.getAllValues().get(2)); pickResult = weightedPicker.pickSubchannel(mockArgs); assertThat(getAddresses(pickResult)).isEqualTo(servers.get(0)); assertThat(pickResult.getStreamTracerFactory()).isNull(); @@ -395,7 +408,7 @@ private void pickByWeight(MetricReport r1, MetricReport r2, MetricReport r3, verify(helper, times(3)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(2); + getWrrPicker(pickerCaptor.getAllValues().get(2)); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); WeightedChildLbState weightedChild3 = (WeightedChildLbState) getChild(weightedPicker, 2); @@ -595,7 +608,7 @@ public void blackoutPeriod() { verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); + getWrrPicker(pickerCaptor.getAllValues().get(1)); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( @@ -655,9 +668,9 @@ public void updateWeightTimer() { eq(ConnectivityState.READY), pickerCaptor.capture()); assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(0); + getWrrPicker(pickerCaptor.getAllValues().get(0)); assertThat(weightedPicker.getChildren().size()).isEqualTo(1); - weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); + weightedPicker = getWrrPicker(pickerCaptor.getAllValues().get(1)); assertThat(weightedPicker.getChildren().size()).isEqualTo(2); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); @@ -710,7 +723,7 @@ public void weightExpired() { verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); + getWrrPicker(pickerCaptor.getAllValues().get(1)); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( @@ -761,7 +774,7 @@ public void rrFallback() { verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); + getWrrPicker(pickerCaptor.getAllValues().get(1)); int expectedTasks = isEnabledHappyEyeballs() ? 2 : 1; assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(expectedTasks); Map qpsByChannel = ImmutableMap.of(servers.get(0), 2, @@ -816,7 +829,7 @@ public void unknownWeightIsAvgWeight() { verify(helper, times(3)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(2); + getWrrPicker(pickerCaptor.getAllValues().get(2)); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( @@ -857,7 +870,7 @@ public void pickFromOtherThread() throws Exception { verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); + getWrrPicker(pickerCaptor.getAllValues().get(1)); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( From 7ed0cc949a77513036aea982845291e373ab793b Mon Sep 17 00:00:00 2001 From: MV Shiva Prasad Date: Thu, 19 Feb 2026 00:52:20 +0530 Subject: [PATCH 2/7] address comments --- .../io/grpc/xds/ClusterImplLoadBalancer.java | 6 ++-- .../java/io/grpc/xds/orca/OrcaOobUtil.java | 4 +-- .../WeightedRoundRobinLoadBalancerTest.java | 21 ++++------- .../io/grpc/xds/orca/OrcaOobUtilAccessor.java | 35 +++++++++++++++++++ 4 files changed, 45 insertions(+), 21 deletions(-) create mode 100644 xds/src/test/java/io/grpc/xds/orca/OrcaOobUtilAccessor.java diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java index a3c038074df..4191e61721d 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java @@ -424,14 +424,12 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { } } PickResult result = delegate.pickSubchannel(args); - Subchannel subchannel = result.getSubchannel(); - if (subchannel != null) { + if (result.getStatus().isOk() && result.getSubchannel() != null) { + Subchannel subchannel = result.getSubchannel(); if (subchannel instanceof ClusterImplLbHelper.ClusterImplSubchannel) { subchannel = ((ClusterImplLbHelper.ClusterImplSubchannel) subchannel).delegate(); result = result.withSubchannelReplacement(subchannel); } - } - if (result.getStatus().isOk() && result.getSubchannel() != null) { if (enableCircuitBreaking) { if (inFlights.get() >= maxConcurrentRequests) { if (dropStats != null) { diff --git a/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java b/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java index f02ae639f2f..e20b4e15fe6 100644 --- a/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java +++ b/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java @@ -245,8 +245,8 @@ public void updateBalancingState(ConnectivityState newState, SubchannelPicker ne delegate.updateBalancingState(newState, new OrcaOobPicker(newPicker)); } - private static final class OrcaOobPicker extends SubchannelPicker { - private final SubchannelPicker delegate; + static final class OrcaOobPicker extends SubchannelPicker { + final SubchannelPicker delegate; OrcaOobPicker(SubchannelPicker delegate) { this.delegate = delegate; diff --git a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java index 72717dc753b..0b026b06f6f 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java @@ -85,6 +85,7 @@ import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedChildLbState; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinPicker; +import io.grpc.xds.orca.OrcaOobUtilAccessor; import java.net.SocketAddress; import java.util.Arrays; import java.util.Collections; @@ -172,16 +173,7 @@ public WeightedRoundRobinLoadBalancerTest() { } private static WeightedRoundRobinPicker getWrrPicker(SubchannelPicker picker) { - if (picker.getClass().getName().endsWith("OrcaOobPicker")) { - try { - java.lang.reflect.Field f = picker.getClass().getDeclaredField("delegate"); - f.setAccessible(true); - return (WeightedRoundRobinPicker) f.get(picker); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - return (WeightedRoundRobinPicker) picker; + return (WeightedRoundRobinPicker) OrcaOobUtilAccessor.getDelegate(picker); } @Before @@ -225,9 +217,8 @@ public void pickChildLbTF() throws Exception { .forTransientFailure(Status.UNAVAILABLE)); verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - final WeightedRoundRobinPicker weightedPicker = - getWrrPicker(pickerCaptor.getValue()); - weightedPicker.pickSubchannel(mockArgs); + final SubchannelPicker picker = pickerCaptor.getValue(); + picker.pickSubchannel(mockArgs); } @Test @@ -374,8 +365,8 @@ weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).on .setAttributes(affinity).build())); verify(helper, times(3)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor2.capture()); - weightedPicker = getWrrPicker(pickerCaptor2.getAllValues().get(2)); - pickResult = weightedPicker.pickSubchannel(mockArgs); + SubchannelPicker rawPicker = pickerCaptor2.getAllValues().get(2); + pickResult = rawPicker.pickSubchannel(mockArgs); assertThat(getAddresses(pickResult)).isEqualTo(servers.get(0)); assertThat(pickResult.getStreamTracerFactory()).isNull(); OrcaLoadReportRequest golden = OrcaLoadReportRequest.newBuilder().setReportInterval( diff --git a/xds/src/test/java/io/grpc/xds/orca/OrcaOobUtilAccessor.java b/xds/src/test/java/io/grpc/xds/orca/OrcaOobUtilAccessor.java new file mode 100644 index 00000000000..8acb14f323b --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/orca/OrcaOobUtilAccessor.java @@ -0,0 +1,35 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.orca; + +import io.grpc.LoadBalancer; + +/** + * Accessor for white-box testing involving OrcaOobUtil. + */ +public final class OrcaOobUtilAccessor { + private OrcaOobUtilAccessor() { + // Do not instantiate + } + + public static LoadBalancer.SubchannelPicker getDelegate(LoadBalancer.SubchannelPicker picker) { + if (picker instanceof OrcaOobUtil.OrcaReportingHelper.OrcaOobPicker) { + return ((OrcaOobUtil.OrcaReportingHelper.OrcaOobPicker) picker).delegate; + } + return picker; + } +} From 9941ef53e473f73b681e5a65f1250ee8be4d8ce6 Mon Sep 17 00:00:00 2001 From: MV Shiva Prasad Date: Thu, 19 Feb 2026 00:56:27 +0530 Subject: [PATCH 3/7] use 2026 in Copyright --- xds/src/test/java/io/grpc/xds/orca/OrcaOobUtilAccessor.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xds/src/test/java/io/grpc/xds/orca/OrcaOobUtilAccessor.java b/xds/src/test/java/io/grpc/xds/orca/OrcaOobUtilAccessor.java index 8acb14f323b..db9168dd08e 100644 --- a/xds/src/test/java/io/grpc/xds/orca/OrcaOobUtilAccessor.java +++ b/xds/src/test/java/io/grpc/xds/orca/OrcaOobUtilAccessor.java @@ -1,5 +1,5 @@ /* - * Copyright 2024 The gRPC Authors + * Copyright 2026 The gRPC Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From bafa9a11349a443e5517c2474d7817f93bff0ed0 Mon Sep 17 00:00:00 2001 From: MV Shiva Prasad Date: Thu, 19 Feb 2026 13:36:31 +0530 Subject: [PATCH 4/7] rename APIs --- api/src/main/java/io/grpc/LoadBalancer.java | 4 ++-- api/src/test/java/io/grpc/LoadBalancerTest.java | 4 ++-- .../services/HealthCheckingLoadBalancerFactory.java | 2 +- util/src/main/java/io/grpc/util/HealthProducerHelper.java | 2 +- .../java/io/grpc/util/OutlierDetectionLoadBalancer.java | 4 ++-- xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java | 4 ++-- .../java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java | 6 +++--- xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java | 2 +- 8 files changed, 14 insertions(+), 14 deletions(-) diff --git a/api/src/main/java/io/grpc/LoadBalancer.java b/api/src/main/java/io/grpc/LoadBalancer.java index c3d52116471..0f3e03feb14 100644 --- a/api/src/main/java/io/grpc/LoadBalancer.java +++ b/api/src/main/java/io/grpc/LoadBalancer.java @@ -667,7 +667,7 @@ public static PickResult withSubchannel(Subchannel subchannel) { * * @since 1.80.0 */ - public PickResult withSubchannelReplacement(Subchannel subchannel) { + public PickResult copyWithSubchannel(Subchannel subchannel) { return new PickResult(checkNotNull(subchannel, "subchannel"), streamTracerFactory, status, drop, authorityOverride); } @@ -678,7 +678,7 @@ public PickResult withSubchannelReplacement(Subchannel subchannel) { * * @since 1.80.0 */ - public PickResult withStreamTracerFactory( + public PickResult copyWithStreamTracerFactory( @Nullable ClientStreamTracer.Factory streamTracerFactory) { return new PickResult(subchannel, streamTracerFactory, status, drop, authorityOverride); } diff --git a/api/src/test/java/io/grpc/LoadBalancerTest.java b/api/src/test/java/io/grpc/LoadBalancerTest.java index 2aa0585c18c..22fdc220081 100644 --- a/api/src/test/java/io/grpc/LoadBalancerTest.java +++ b/api/src/test/java/io/grpc/LoadBalancerTest.java @@ -67,7 +67,7 @@ public void pickResult_withSubchannelAndTracer() { @Test public void pickResult_withSubchannelReplacement() { PickResult result = PickResult.withSubchannel(subchannel, tracerFactory) - .withSubchannelReplacement(subchannel2); + .copyWithSubchannel(subchannel2); assertThat(result.getSubchannel()).isSameInstanceAs(subchannel2); assertThat(result.getStatus()).isSameInstanceAs(Status.OK); assertThat(result.getStreamTracerFactory()).isSameInstanceAs(tracerFactory); @@ -77,7 +77,7 @@ public void pickResult_withSubchannelReplacement() { @Test public void pickResult_withStreamTracerFactory() { PickResult result = PickResult.withSubchannel(subchannel) - .withStreamTracerFactory(tracerFactory); + .copyWithStreamTracerFactory(tracerFactory); assertThat(result.getSubchannel()).isSameInstanceAs(subchannel); assertThat(result.getStatus()).isSameInstanceAs(Status.OK); assertThat(result.getStreamTracerFactory()).isSameInstanceAs(tracerFactory); diff --git a/services/src/main/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactory.java b/services/src/main/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactory.java index ce6d2e70eae..b9f235d0aff 100644 --- a/services/src/main/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactory.java +++ b/services/src/main/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactory.java @@ -163,7 +163,7 @@ public LoadBalancer.PickResult pickSubchannel(LoadBalancer.PickSubchannelArgs ar LoadBalancer.PickResult result = delegate.pickSubchannel(args); LoadBalancer.Subchannel subchannel = result.getSubchannel(); if (subchannel instanceof SubchannelImpl) { - return result.withSubchannelReplacement(((SubchannelImpl) subchannel).delegate()); + return result.copyWithSubchannel(((SubchannelImpl) subchannel).delegate()); } return result; } diff --git a/util/src/main/java/io/grpc/util/HealthProducerHelper.java b/util/src/main/java/io/grpc/util/HealthProducerHelper.java index 7913c63d3ad..d871911d203 100644 --- a/util/src/main/java/io/grpc/util/HealthProducerHelper.java +++ b/util/src/main/java/io/grpc/util/HealthProducerHelper.java @@ -103,7 +103,7 @@ public LoadBalancer.PickResult pickSubchannel(LoadBalancer.PickSubchannelArgs ar LoadBalancer.PickResult result = delegate.pickSubchannel(args); LoadBalancer.Subchannel subchannel = result.getSubchannel(); if (subchannel instanceof HealthProducerSubchannel) { - return result.withSubchannelReplacement( + return result.copyWithSubchannel( ((HealthProducerSubchannel) subchannel).delegate()); } return result; diff --git a/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java b/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java index ddf29d2866b..dc61441bccd 100644 --- a/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java @@ -446,8 +446,8 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { if (subchannel instanceof OutlierDetectionSubchannel) { subchannel = ((OutlierDetectionSubchannel) subchannel).delegate(); } - return pickResult.withSubchannelReplacement(subchannel) - .withStreamTracerFactory(new ResultCountingClientStreamTracerFactory( + return pickResult.copyWithSubchannel(subchannel) + .copyWithStreamTracerFactory(new ResultCountingClientStreamTracerFactory( tracker, pickResult.getStreamTracerFactory())); } diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java index 4191e61721d..8b8ce0f03ce 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java @@ -428,7 +428,7 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { Subchannel subchannel = result.getSubchannel(); if (subchannel instanceof ClusterImplLbHelper.ClusterImplSubchannel) { subchannel = ((ClusterImplLbHelper.ClusterImplSubchannel) subchannel).delegate(); - result = result.withSubchannelReplacement(subchannel); + result = result.copyWithSubchannel(subchannel); } if (enableCircuitBreaking) { if (inFlights.get() >= maxConcurrentRequests) { @@ -455,7 +455,7 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { stats, inFlights, result.getStreamTracerFactory()); ClientStreamTracer.Factory orcaTracerFactory = OrcaPerRequestUtil.getInstance() .newOrcaClientStreamTracerFactory(tracerFactory, new OrcaPerRpcListener(stats)); - result = result.withStreamTracerFactory(orcaTracerFactory); + result = result.copyWithStreamTracerFactory(orcaTracerFactory); } } if (args.getCallOptions().getOption(XdsNameResolver.AUTO_HOST_REWRITE_KEY) != null diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java index 2c123512a7a..fc6c3f5b119 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java @@ -511,12 +511,12 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { subchannel = ((WrrSubchannel) subchannel).delegate(); if (!enableOobLoadReport) { - return pickResult.withSubchannelReplacement(subchannel) - .withStreamTracerFactory( + return pickResult.copyWithSubchannel(subchannel) + .copyWithStreamTracerFactory( OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( reportListeners.get(pick))); } else { - return pickResult.withSubchannelReplacement(subchannel); + return pickResult.copyWithSubchannel(subchannel); } } diff --git a/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java b/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java index e20b4e15fe6..227a9b646d2 100644 --- a/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java +++ b/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java @@ -257,7 +257,7 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { PickResult result = delegate.pickSubchannel(args); Subchannel subchannel = result.getSubchannel(); if (subchannel instanceof SubchannelImpl) { - return result.withSubchannelReplacement(((SubchannelImpl) subchannel).delegate()); + return result.copyWithSubchannel(((SubchannelImpl) subchannel).delegate()); } return result; } From 600c22be5985624cc6f9ab1cdb46661c5e55e1bb Mon Sep 17 00:00:00 2001 From: MV Shiva Prasad Date: Thu, 19 Feb 2026 14:08:21 +0530 Subject: [PATCH 5/7] add TODO --- api/src/main/java/io/grpc/LoadBalancer.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/api/src/main/java/io/grpc/LoadBalancer.java b/api/src/main/java/io/grpc/LoadBalancer.java index 0f3e03feb14..f106e6aa2a4 100644 --- a/api/src/main/java/io/grpc/LoadBalancer.java +++ b/api/src/main/java/io/grpc/LoadBalancer.java @@ -632,6 +632,9 @@ private PickResult( * stream is created at all in some cases. * @since 1.3.0 */ + // TODO(shivaspeaks): We need to deprecate old APIs and create new ones. + // Ideally these static methods should start with "of.." instead of "with.." + // to have consistency with other classes. public static PickResult withSubchannel( Subchannel subchannel, @Nullable ClientStreamTracer.Factory streamTracerFactory) { return new PickResult( From b18ca55cb4900c8d61303b107c3946df2796142e Mon Sep 17 00:00:00 2001 From: MV Shiva Prasad Date: Thu, 19 Feb 2026 14:26:52 +0530 Subject: [PATCH 6/7] update TODO --- api/src/main/java/io/grpc/LoadBalancer.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/api/src/main/java/io/grpc/LoadBalancer.java b/api/src/main/java/io/grpc/LoadBalancer.java index f106e6aa2a4..9816de95c5e 100644 --- a/api/src/main/java/io/grpc/LoadBalancer.java +++ b/api/src/main/java/io/grpc/LoadBalancer.java @@ -632,9 +632,8 @@ private PickResult( * stream is created at all in some cases. * @since 1.3.0 */ - // TODO(shivaspeaks): We need to deprecate old APIs and create new ones. - // Ideally these static methods should start with "of.." instead of "with.." - // to have consistency with other classes. + // TODO(shivaspeaks): Need to deprecate old APIs and create new ones, + // per https://github.com/grpc/grpc-java/issues/12662. public static PickResult withSubchannel( Subchannel subchannel, @Nullable ClientStreamTracer.Factory streamTracerFactory) { return new PickResult( From 17b3a3e7fd79c5956238d0a1836d3ed901d83be8 Mon Sep 17 00:00:00 2001 From: MV Shiva Prasad Date: Fri, 20 Feb 2026 11:11:30 +0530 Subject: [PATCH 7/7] add VisibleForTestingAnnotation --- xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java | 1 + 1 file changed, 1 insertion(+) diff --git a/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java b/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java index 227a9b646d2..b37b9bc42e3 100644 --- a/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java +++ b/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java @@ -245,6 +245,7 @@ public void updateBalancingState(ConnectivityState newState, SubchannelPicker ne delegate.updateBalancingState(newState, new OrcaOobPicker(newPicker)); } + @VisibleForTesting static final class OrcaOobPicker extends SubchannelPicker { final SubchannelPicker delegate;