diff --git a/api/src/main/java/io/grpc/LoadBalancer.java b/api/src/main/java/io/grpc/LoadBalancer.java index adc43b19841..9816de95c5e 100644 --- a/api/src/main/java/io/grpc/LoadBalancer.java +++ b/api/src/main/java/io/grpc/LoadBalancer.java @@ -632,6 +632,8 @@ private PickResult( * stream is created at all in some cases. * @since 1.3.0 */ + // TODO(shivaspeaks): Need to deprecate old APIs and create new ones, + // per https://github.com/grpc/grpc-java/issues/12662. public static PickResult withSubchannel( Subchannel subchannel, @Nullable ClientStreamTracer.Factory streamTracerFactory) { return new PickResult( @@ -661,6 +663,28 @@ public static PickResult withSubchannel(Subchannel subchannel) { return withSubchannel(subchannel, null); } + /** + * Creates a new {@code PickResult} with the given {@code subchannel}, + * but retains all other properties from this {@code PickResult}. + * + * @since 1.80.0 + */ + public PickResult copyWithSubchannel(Subchannel subchannel) { + return new PickResult(checkNotNull(subchannel, "subchannel"), streamTracerFactory, + status, drop, authorityOverride); + } + + /** + * Creates a new {@code PickResult} with the given {@code streamTracerFactory}, + * but retains all other properties from this {@code PickResult}. + * + * @since 1.80.0 + */ + public PickResult copyWithStreamTracerFactory( + @Nullable ClientStreamTracer.Factory streamTracerFactory) { + return new PickResult(subchannel, streamTracerFactory, status, drop, authorityOverride); + } + /** * A decision to report a connectivity error to the RPC. If the RPC is {@link * CallOptions#withWaitForReady wait-for-ready}, it will stay buffered. Otherwise, it will fail diff --git a/api/src/test/java/io/grpc/LoadBalancerTest.java b/api/src/test/java/io/grpc/LoadBalancerTest.java index 5e9e5cbe816..22fdc220081 100644 --- a/api/src/test/java/io/grpc/LoadBalancerTest.java +++ b/api/src/test/java/io/grpc/LoadBalancerTest.java @@ -64,6 +64,26 @@ public void pickResult_withSubchannelAndTracer() { assertThat(result.isDrop()).isFalse(); } + @Test + public void pickResult_withSubchannelReplacement() { + PickResult result = PickResult.withSubchannel(subchannel, tracerFactory) + .copyWithSubchannel(subchannel2); + assertThat(result.getSubchannel()).isSameInstanceAs(subchannel2); + assertThat(result.getStatus()).isSameInstanceAs(Status.OK); + assertThat(result.getStreamTracerFactory()).isSameInstanceAs(tracerFactory); + assertThat(result.isDrop()).isFalse(); + } + + @Test + public void pickResult_withStreamTracerFactory() { + PickResult result = PickResult.withSubchannel(subchannel) + .copyWithStreamTracerFactory(tracerFactory); + assertThat(result.getSubchannel()).isSameInstanceAs(subchannel); + assertThat(result.getStatus()).isSameInstanceAs(Status.OK); + assertThat(result.getStreamTracerFactory()).isSameInstanceAs(tracerFactory); + assertThat(result.isDrop()).isFalse(); + } + @Test public void pickResult_withNoResult() { PickResult result = PickResult.withNoResult(); diff --git a/services/src/main/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactory.java b/services/src/main/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactory.java index 8cf1458f5dc..b9f235d0aff 100644 --- a/services/src/main/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactory.java +++ b/services/src/main/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactory.java @@ -144,6 +144,30 @@ void setHealthCheckedService(@Nullable String service) { public String toString() { return MoreObjects.toStringHelper(this).add("delegate", delegate()).toString(); } + + @Override + public void updateBalancingState( + io.grpc.ConnectivityState newState, LoadBalancer.SubchannelPicker newPicker) { + delegate().updateBalancingState(newState, new HealthCheckPicker(newPicker)); + } + + private final class HealthCheckPicker extends LoadBalancer.SubchannelPicker { + private final LoadBalancer.SubchannelPicker delegate; + + HealthCheckPicker(LoadBalancer.SubchannelPicker delegate) { + this.delegate = delegate; + } + + @Override + public LoadBalancer.PickResult pickSubchannel(LoadBalancer.PickSubchannelArgs args) { + LoadBalancer.PickResult result = delegate.pickSubchannel(args); + LoadBalancer.Subchannel subchannel = result.getSubchannel(); + if (subchannel instanceof SubchannelImpl) { + return result.copyWithSubchannel(((SubchannelImpl) subchannel).delegate()); + } + return result; + } + } } @VisibleForTesting diff --git a/util/src/main/java/io/grpc/util/HealthProducerHelper.java b/util/src/main/java/io/grpc/util/HealthProducerHelper.java index b11864765ea..d871911d203 100644 --- a/util/src/main/java/io/grpc/util/HealthProducerHelper.java +++ b/util/src/main/java/io/grpc/util/HealthProducerHelper.java @@ -22,6 +22,7 @@ import com.google.common.annotations.VisibleForTesting; import io.grpc.Attributes; +import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.Internal; import io.grpc.LoadBalancer; @@ -84,6 +85,31 @@ protected LoadBalancer.Helper delegate() { return delegate; } + @Override + public void updateBalancingState( + ConnectivityState newState, LoadBalancer.SubchannelPicker newPicker) { + delegate.updateBalancingState(newState, new HealthProducerPicker(newPicker)); + } + + private static final class HealthProducerPicker extends LoadBalancer.SubchannelPicker { + private final LoadBalancer.SubchannelPicker delegate; + + HealthProducerPicker(LoadBalancer.SubchannelPicker delegate) { + this.delegate = delegate; + } + + @Override + public LoadBalancer.PickResult pickSubchannel(LoadBalancer.PickSubchannelArgs args) { + LoadBalancer.PickResult result = delegate.pickSubchannel(args); + LoadBalancer.Subchannel subchannel = result.getSubchannel(); + if (subchannel instanceof HealthProducerSubchannel) { + return result.copyWithSubchannel( + ((HealthProducerSubchannel) subchannel).delegate()); + } + return result; + } + } + // The parent subchannel in the health check producer LB chain. It duplicates subchannel state to // both the state listener and health listener. @VisibleForTesting diff --git a/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java b/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java index d72a85012f2..dc61441bccd 100644 --- a/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java @@ -442,9 +442,14 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { Subchannel subchannel = pickResult.getSubchannel(); if (subchannel != null) { - return PickResult.withSubchannel(subchannel, new ResultCountingClientStreamTracerFactory( - subchannel.getAttributes().get(ENDPOINT_TRACKER_KEY), - pickResult.getStreamTracerFactory())); + EndpointTracker tracker = subchannel.getAttributes().get(ENDPOINT_TRACKER_KEY); + if (subchannel instanceof OutlierDetectionSubchannel) { + subchannel = ((OutlierDetectionSubchannel) subchannel).delegate(); + } + return pickResult.copyWithSubchannel(subchannel) + .copyWithStreamTracerFactory(new ResultCountingClientStreamTracerFactory( + tracker, + pickResult.getStreamTracerFactory())); } return pickResult; diff --git a/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java b/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java index 10436407422..39f5b5fb7d6 100644 --- a/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java +++ b/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java @@ -408,7 +408,7 @@ public void delegatePick() throws Exception { // Make sure that we can pick the single READY subchannel. SubchannelPicker picker = pickerCaptor.getAllValues().get(2); PickResult pickResult = picker.pickSubchannel(mock(PickSubchannelArgs.class)); - Subchannel s = ((OutlierDetectionSubchannel) pickResult.getSubchannel()).delegate(); + Subchannel s = pickResult.getSubchannel(); if (s instanceof HealthProducerHelper.HealthProducerSubchannel) { s = ((HealthProducerHelper.HealthProducerSubchannel) s).delegate(); } diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java index ec4bec7f25c..8b8ce0f03ce 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java @@ -252,42 +252,55 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { args = args.toBuilder().setAddresses(addresses).setAttributes(attrsBuilder.build()).build(); final Subchannel subchannel = delegate().createSubchannel(args); - return new ForwardingSubchannel() { - @Override - public void start(SubchannelStateListener listener) { - delegate().start(new SubchannelStateListener() { - @Override - public void onSubchannelState(ConnectivityStateInfo newState) { - // Do nothing if LB has been shutdown - if (xdsClient != null && newState.getState().equals(ConnectivityState.READY)) { - // Get locality based on the connected address attributes - ClusterLocality updatedClusterLocality = createClusterLocalityFromAttributes( - subchannel.getConnectedAddressAttributes()); - ClusterLocality oldClusterLocality = localityAtomicReference - .getAndSet(updatedClusterLocality); - oldClusterLocality.release(); + return new ClusterImplSubchannel(subchannel, localityAtomicReference); + } + + private final class ClusterImplSubchannel extends ForwardingSubchannel { + private final Subchannel delegate; + private final AtomicReference localityAtomicReference; + + private ClusterImplSubchannel( + Subchannel delegate, AtomicReference localityAtomicReference) { + this.delegate = delegate; + this.localityAtomicReference = localityAtomicReference; + } + + @Override + public void start(SubchannelStateListener listener) { + delegate().start( + new SubchannelStateListener() { + @Override + public void onSubchannelState(ConnectivityStateInfo newState) { + // Do nothing if LB has been shutdown + if (xdsClient != null && newState.getState().equals(ConnectivityState.READY)) { + // Get locality based on the connected address attributes + ClusterLocality updatedClusterLocality = + createClusterLocalityFromAttributes( + delegate.getConnectedAddressAttributes()); + ClusterLocality oldClusterLocality = + localityAtomicReference.getAndSet(updatedClusterLocality); + oldClusterLocality.release(); + } + listener.onSubchannelState(newState); } - listener.onSubchannelState(newState); - } - }); - } + }); + } - @Override - public void shutdown() { - localityAtomicReference.get().release(); - delegate().shutdown(); - } + @Override + public void shutdown() { + localityAtomicReference.get().release(); + delegate().shutdown(); + } - @Override - public void updateAddresses(List addresses) { - delegate().updateAddresses(withAdditionalAttributes(addresses)); - } + @Override + public void updateAddresses(List addresses) { + delegate().updateAddresses(withAdditionalAttributes(addresses)); + } - @Override - protected Subchannel delegate() { - return subchannel; - } - }; + @Override + protected Subchannel delegate() { + return delegate; + } } private List withAdditionalAttributes( @@ -412,6 +425,11 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { } PickResult result = delegate.pickSubchannel(args); if (result.getStatus().isOk() && result.getSubchannel() != null) { + Subchannel subchannel = result.getSubchannel(); + if (subchannel instanceof ClusterImplLbHelper.ClusterImplSubchannel) { + subchannel = ((ClusterImplLbHelper.ClusterImplSubchannel) subchannel).delegate(); + result = result.copyWithSubchannel(subchannel); + } if (enableCircuitBreaking) { if (inFlights.get() >= maxConcurrentRequests) { if (dropStats != null) { @@ -437,8 +455,7 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { stats, inFlights, result.getStreamTracerFactory()); ClientStreamTracer.Factory orcaTracerFactory = OrcaPerRequestUtil.getInstance() .newOrcaClientStreamTracerFactory(tracerFactory, new OrcaPerRpcListener(stats)); - result = PickResult.withSubchannel(result.getSubchannel(), - orcaTracerFactory); + result = result.copyWithStreamTracerFactory(orcaTracerFactory); } } if (args.getCallOptions().getOption(XdsNameResolver.AUTO_HOST_REWRITE_KEY) != null diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java index 6cf3189d587..fc6c3f5b119 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java @@ -508,12 +508,15 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { if (subchannel == null) { return pickResult; } + + subchannel = ((WrrSubchannel) subchannel).delegate(); if (!enableOobLoadReport) { - return PickResult.withSubchannel(subchannel, - OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( - reportListeners.get(pick))); + return pickResult.copyWithSubchannel(subchannel) + .copyWithStreamTracerFactory( + OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( + reportListeners.get(pick))); } else { - return PickResult.withSubchannel(subchannel); + return pickResult.copyWithSubchannel(subchannel); } } diff --git a/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java b/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java index 9ac06d362fc..b37b9bc42e3 100644 --- a/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java +++ b/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java @@ -36,12 +36,16 @@ import io.grpc.ChannelLogger; import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ClientCall; +import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.ExperimentalApi; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.PickResult; +import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.Subchannel; +import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.Metadata; import io.grpc.Status; @@ -236,6 +240,30 @@ protected Helper delegate() { return delegate; } + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + delegate.updateBalancingState(newState, new OrcaOobPicker(newPicker)); + } + + @VisibleForTesting + static final class OrcaOobPicker extends SubchannelPicker { + final SubchannelPicker delegate; + + OrcaOobPicker(SubchannelPicker delegate) { + this.delegate = delegate; + } + + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + PickResult result = delegate.pickSubchannel(args); + Subchannel subchannel = result.getSubchannel(); + if (subchannel instanceof SubchannelImpl) { + return result.copyWithSubchannel(((SubchannelImpl) subchannel).delegate()); + } + return result; + } + } + @Override public Subchannel createSubchannel(CreateSubchannelArgs args) { syncContext.throwIfNotInThisSynchronizationContext(); diff --git a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java index 9fac46eaf09..0b026b06f6f 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java @@ -85,6 +85,7 @@ import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedChildLbState; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinPicker; +import io.grpc.xds.orca.OrcaOobUtilAccessor; import java.net.SocketAddress; import java.util.Arrays; import java.util.Collections; @@ -171,6 +172,10 @@ public WeightedRoundRobinLoadBalancerTest() { helper = mock(Helper.class, delegatesTo(testHelperInstance)); } + private static WeightedRoundRobinPicker getWrrPicker(SubchannelPicker picker) { + return (WeightedRoundRobinPicker) OrcaOobUtilAccessor.getDelegate(picker); + } + @Before public void setup() { for (int i = 0; i < 3; i++) { @@ -212,9 +217,8 @@ public void pickChildLbTF() throws Exception { .forTransientFailure(Status.UNAVAILABLE)); verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - final WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getValue(); - weightedPicker.pickSubchannel(mockArgs); + final SubchannelPicker picker = pickerCaptor.getValue(); + picker.pickSubchannel(mockArgs); } @Test @@ -274,9 +278,9 @@ public void wrrLifeCycle() { eq(ConnectivityState.READY), pickerCaptor.capture()); assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(0); + getWrrPicker(pickerCaptor.getAllValues().get(0)); assertThat(weightedPicker.getChildren().size()).isEqualTo(1); - weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); + weightedPicker = getWrrPicker(pickerCaptor.getAllValues().get(1)); assertThat(weightedPicker.getChildren().size()).isEqualTo(2); String weightedPickerStr = weightedPicker.toString(); assertThat(weightedPickerStr).contains("enableOobLoadReport=false"); @@ -337,7 +341,7 @@ public void enableOobLoadReportConfig() { verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); + getWrrPicker(pickerCaptor.getAllValues().get(1)); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( @@ -361,8 +365,8 @@ weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).on .setAttributes(affinity).build())); verify(helper, times(3)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor2.capture()); - weightedPicker = (WeightedRoundRobinPicker) pickerCaptor2.getAllValues().get(2); - pickResult = weightedPicker.pickSubchannel(mockArgs); + SubchannelPicker rawPicker = pickerCaptor2.getAllValues().get(2); + pickResult = rawPicker.pickSubchannel(mockArgs); assertThat(getAddresses(pickResult)).isEqualTo(servers.get(0)); assertThat(pickResult.getStreamTracerFactory()).isNull(); OrcaLoadReportRequest golden = OrcaLoadReportRequest.newBuilder().setReportInterval( @@ -395,7 +399,7 @@ private void pickByWeight(MetricReport r1, MetricReport r2, MetricReport r3, verify(helper, times(3)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(2); + getWrrPicker(pickerCaptor.getAllValues().get(2)); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); WeightedChildLbState weightedChild3 = (WeightedChildLbState) getChild(weightedPicker, 2); @@ -595,7 +599,7 @@ public void blackoutPeriod() { verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); + getWrrPicker(pickerCaptor.getAllValues().get(1)); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( @@ -655,9 +659,9 @@ public void updateWeightTimer() { eq(ConnectivityState.READY), pickerCaptor.capture()); assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(0); + getWrrPicker(pickerCaptor.getAllValues().get(0)); assertThat(weightedPicker.getChildren().size()).isEqualTo(1); - weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); + weightedPicker = getWrrPicker(pickerCaptor.getAllValues().get(1)); assertThat(weightedPicker.getChildren().size()).isEqualTo(2); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); @@ -710,7 +714,7 @@ public void weightExpired() { verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); + getWrrPicker(pickerCaptor.getAllValues().get(1)); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( @@ -761,7 +765,7 @@ public void rrFallback() { verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); + getWrrPicker(pickerCaptor.getAllValues().get(1)); int expectedTasks = isEnabledHappyEyeballs() ? 2 : 1; assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(expectedTasks); Map qpsByChannel = ImmutableMap.of(servers.get(0), 2, @@ -816,7 +820,7 @@ public void unknownWeightIsAvgWeight() { verify(helper, times(3)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(2); + getWrrPicker(pickerCaptor.getAllValues().get(2)); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( @@ -857,7 +861,7 @@ public void pickFromOtherThread() throws Exception { verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = - (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); + getWrrPicker(pickerCaptor.getAllValues().get(1)); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( diff --git a/xds/src/test/java/io/grpc/xds/orca/OrcaOobUtilAccessor.java b/xds/src/test/java/io/grpc/xds/orca/OrcaOobUtilAccessor.java new file mode 100644 index 00000000000..db9168dd08e --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/orca/OrcaOobUtilAccessor.java @@ -0,0 +1,35 @@ +/* + * Copyright 2026 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.orca; + +import io.grpc.LoadBalancer; + +/** + * Accessor for white-box testing involving OrcaOobUtil. + */ +public final class OrcaOobUtilAccessor { + private OrcaOobUtilAccessor() { + // Do not instantiate + } + + public static LoadBalancer.SubchannelPicker getDelegate(LoadBalancer.SubchannelPicker picker) { + if (picker instanceof OrcaOobUtil.OrcaReportingHelper.OrcaOobPicker) { + return ((OrcaOobUtil.OrcaReportingHelper.OrcaOobPicker) picker).delegate; + } + return picker; + } +}