diff --git a/java-iam/.repo-metadata.json b/java-iam/.repo-metadata.json index fa9ab8c76a75..d35667f774d9 100644 --- a/java-iam/.repo-metadata.json +++ b/java-iam/.repo-metadata.json @@ -10,7 +10,6 @@ "repo": "googleapis/google-cloud-java", "repo_short": "java-iam", "distribution_name": "com.google.cloud:google-iam-policy", - "api_id": "iam.googleapis.com", "library_type": "GAPIC_AUTO", "requires_billing": true, "excluded_dependencies": "google-iam-policy", diff --git a/java-iam/README.md b/java-iam/README.md index a31d56fecfdc..b5f33684f9cf 100644 --- a/java-iam/README.md +++ b/java-iam/README.md @@ -188,7 +188,7 @@ Java is a registered trademark of Oracle and/or its affiliates. [code-of-conduct]: https://github.com/googleapis/google-cloud-java/blob/main/CODE_OF_CONDUCT.md#contributor-code-of-conduct [license]: https://github.com/googleapis/google-cloud-java/blob/main/LICENSE [enable-billing]: https://cloud.google.com/apis/docs/getting-started#enabling_billing -[enable-api]: https://console.cloud.google.com/flows/enableapi?apiid=iam.googleapis.com + [libraries-bom]: https://github.com/GoogleCloudPlatform/cloud-opensource-java/wiki/The-Google-Cloud-Platform-Libraries-BOM [shell_img]: https://gstatic.com/cloudssh/images/open-btn.png diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointOverloadCooldownTracker.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointOverloadCooldownTracker.java new file mode 100644 index 000000000000..0663adf4d9b0 --- /dev/null +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/EndpointOverloadCooldownTracker.java @@ -0,0 +1,152 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner.spi.v1; + +import com.google.common.annotations.VisibleForTesting; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.LongUnaryOperator; + +/** + * Tracks short-lived endpoint cooldowns after routed {@code RESOURCE_EXHAUSTED} failures. + * + *

This allows later requests to try a different replica instead of immediately routing back to + * the same overloaded endpoint. + */ +final class EndpointOverloadCooldownTracker { + + @VisibleForTesting static final Duration DEFAULT_INITIAL_COOLDOWN = Duration.ofSeconds(10); + @VisibleForTesting static final Duration DEFAULT_MAX_COOLDOWN = Duration.ofMinutes(1); + @VisibleForTesting static final Duration DEFAULT_RESET_AFTER = Duration.ofMinutes(10); + + @VisibleForTesting + static final class CooldownState { + private final int consecutiveFailures; + private final Instant cooldownUntil; + private final Instant lastFailureAt; + + private CooldownState(int consecutiveFailures, Instant cooldownUntil, Instant lastFailureAt) { + this.consecutiveFailures = consecutiveFailures; + this.cooldownUntil = cooldownUntil; + this.lastFailureAt = lastFailureAt; + } + } + + private final ConcurrentHashMap entries = new ConcurrentHashMap<>(); + private final Duration initialCooldown; + private final Duration maxCooldown; + private final Duration resetAfter; + private final Clock clock; + private final LongUnaryOperator randomLong; + + EndpointOverloadCooldownTracker() { + this( + DEFAULT_INITIAL_COOLDOWN, + DEFAULT_MAX_COOLDOWN, + DEFAULT_RESET_AFTER, + Clock.systemUTC(), + bound -> ThreadLocalRandom.current().nextLong(bound)); + } + + @VisibleForTesting + EndpointOverloadCooldownTracker( + Duration initialCooldown, + Duration maxCooldown, + Duration resetAfter, + Clock clock, + LongUnaryOperator randomLong) { + Duration resolvedInitial = + (initialCooldown == null || initialCooldown.isZero() || initialCooldown.isNegative()) + ? DEFAULT_INITIAL_COOLDOWN + : initialCooldown; + Duration resolvedMax = + (maxCooldown == null || maxCooldown.isZero() || maxCooldown.isNegative()) + ? DEFAULT_MAX_COOLDOWN + : maxCooldown; + if (resolvedMax.compareTo(resolvedInitial) < 0) { + resolvedMax = resolvedInitial; + } + this.initialCooldown = resolvedInitial; + this.maxCooldown = resolvedMax; + this.resetAfter = + (resetAfter == null || resetAfter.isZero() || resetAfter.isNegative()) + ? DEFAULT_RESET_AFTER + : resetAfter; + this.clock = clock == null ? Clock.systemUTC() : clock; + this.randomLong = + randomLong == null ? bound -> ThreadLocalRandom.current().nextLong(bound) : randomLong; + } + + boolean isCoolingDown(String address) { + if (address == null || address.isEmpty()) { + return false; + } + Instant now = clock.instant(); + CooldownState state = entries.get(address); + if (state == null) { + return false; + } + if (state.cooldownUntil.isAfter(now)) { + return true; + } + if (Duration.between(state.lastFailureAt, now).compareTo(resetAfter) < 0) { + return false; + } + entries.remove(address, state); + CooldownState current = entries.get(address); + return current != null && current.cooldownUntil.isAfter(now); + } + + void recordFailure(String address) { + if (address == null || address.isEmpty()) { + return; + } + Instant now = clock.instant(); + entries.compute( + address, + (ignored, state) -> { + int consecutiveFailures = 1; + if (state != null + && Duration.between(state.lastFailureAt, now).compareTo(resetAfter) < 0) { + consecutiveFailures = state.consecutiveFailures + 1; + } + Duration cooldown = cooldownForFailures(consecutiveFailures); + return new CooldownState(consecutiveFailures, now.plus(cooldown), now); + }); + } + + private Duration cooldownForFailures(int failures) { + Duration cooldown = initialCooldown; + for (int i = 1; i < failures; i++) { + if (cooldown.compareTo(maxCooldown.dividedBy(2)) > 0) { + cooldown = maxCooldown; + break; + } + cooldown = cooldown.multipliedBy(2); + } + long bound = Math.max(1L, cooldown.toMillis() + 1L); + return Duration.ofMillis(randomLong.applyAsLong(bound)); + } + + @VisibleForTesting + CooldownState getState(String address) { + return entries.get(address); + } +} diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java index 7c1b6be1c1bd..6cc0a485d056 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java @@ -432,7 +432,11 @@ public GapicSpannerRpc(final SpannerOptions options) { this.readRetrySettings = options.getSpannerStubSettings().streamingReadSettings().getRetrySettings(); this.readRetryableCodes = - options.getSpannerStubSettings().streamingReadSettings().getRetryableCodes(); + ImmutableSet.builder() + .addAll( + options.getSpannerStubSettings().streamingReadSettings().getRetryableCodes()) + .add(Code.RESOURCE_EXHAUSTED) + .build(); this.executeQueryRetrySettings = options.getSpannerStubSettings().executeStreamingSqlSettings().getRetrySettings(); this.executeQueryRetryableCodes = diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java index d7b32f72bcd6..0d48d07924c6 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyAwareChannel.java @@ -21,6 +21,7 @@ import com.google.api.core.InternalApi; import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider; import com.google.cloud.spanner.XGoogSpannerRequestId; +import com.google.common.annotations.VisibleForTesting; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; import com.google.protobuf.ByteString; @@ -94,19 +95,28 @@ final class KeyAwareChannel extends ManagedChannel { // Bounded to prevent unbounded growth if application code does not close read-only transactions. private final Cache readOnlyTxPreferLeader = CacheBuilder.newBuilder().maximumSize(MAX_TRACKED_READ_ONLY_TRANSACTIONS).build(); - // If a routed endpoint returns RESOURCE_EXHAUSTED, the next retry attempt of that same logical - // request should avoid that endpoint once so other requests are unaffected. Bound and age out - // entries in case a caller gives up and never issues a retry. + // If a routed endpoint returns RESOURCE_EXHAUSTED or UNAVAILABLE, the next retry attempt of + // that same logical request should avoid that endpoint once so other requests are unaffected. + // Bound and age out entries in case a caller gives up and never issues a retry. private final Cache> excludedEndpointsForLogicalRequest = CacheBuilder.newBuilder() .maximumSize(MAX_TRACKED_EXCLUDED_LOGICAL_REQUESTS) .expireAfterWrite(EXCLUDED_LOGICAL_REQUEST_TTL_MINUTES, TimeUnit.MINUTES) .build(); + private final EndpointOverloadCooldownTracker endpointOverloadCooldowns; private KeyAwareChannel( InstantiatingGrpcChannelProvider channelProvider, @Nullable ChannelEndpointCacheFactory endpointCacheFactory) throws IOException { + this(channelProvider, endpointCacheFactory, new EndpointOverloadCooldownTracker()); + } + + private KeyAwareChannel( + InstantiatingGrpcChannelProvider channelProvider, + @Nullable ChannelEndpointCacheFactory endpointCacheFactory, + EndpointOverloadCooldownTracker endpointOverloadCooldowns) + throws IOException { if (endpointCacheFactory == null) { this.endpointCache = new GrpcChannelEndpointCache(channelProvider); } else { @@ -120,6 +130,7 @@ private KeyAwareChannel( // would interfere with test assertions. this.lifecycleManager = (endpointCacheFactory == null) ? new EndpointLifecycleManager(endpointCache) : null; + this.endpointOverloadCooldowns = endpointOverloadCooldowns; } static KeyAwareChannel create( @@ -129,6 +140,15 @@ static KeyAwareChannel create( return new KeyAwareChannel(channelProvider, endpointCacheFactory); } + @VisibleForTesting + static KeyAwareChannel create( + InstantiatingGrpcChannelProvider channelProvider, + @Nullable ChannelEndpointCacheFactory endpointCacheFactory, + EndpointOverloadCooldownTracker endpointOverloadCooldowns) + throws IOException { + return new KeyAwareChannel(channelProvider, endpointCacheFactory, endpointOverloadCooldowns); + } + private static final class ChannelFinderReference extends SoftReference { final String databaseId; @@ -321,36 +341,61 @@ void clearTransactionAndChannelAffinity(ByteString transactionId, @Nullable Long private void maybeExcludeEndpointOnNextCall( @Nullable ChannelEndpoint endpoint, @Nullable String logicalRequestKey) { - if (endpoint == null || logicalRequestKey == null) { + if (endpoint == null) { return; } String address = endpoint.getAddress(); - if (!defaultEndpointAddress.equals(address)) { - excludedEndpointsForLogicalRequest - .asMap() - .compute( - logicalRequestKey, - (ignored, excludedEndpoints) -> { - Set updated = - excludedEndpoints == null ? ConcurrentHashMap.newKeySet() : excludedEndpoints; - updated.add(address); - return updated; - }); + if (defaultEndpointAddress.equals(address)) { + return; } + endpointOverloadCooldowns.recordFailure(address); + if (logicalRequestKey == null) { + return; + } + excludedEndpointsForLogicalRequest + .asMap() + .compute( + logicalRequestKey, + (ignored, excludedEndpoints) -> { + Set updated = + excludedEndpoints == null ? ConcurrentHashMap.newKeySet() : excludedEndpoints; + updated.add(address); + return updated; + }); + } + + private static boolean shouldExcludeEndpointOnRetry(io.grpc.Status.Code statusCode) { + return statusCode == io.grpc.Status.Code.RESOURCE_EXHAUSTED + || statusCode == io.grpc.Status.Code.UNAVAILABLE; } private Predicate consumeExcludedEndpointsForCurrentCall( @Nullable String logicalRequestKey) { - if (logicalRequestKey == null) { - return address -> false; + Predicate requestScopedExcluded = address -> false; + if (logicalRequestKey != null) { + Set excludedEndpoints = + excludedEndpointsForLogicalRequest.asMap().remove(logicalRequestKey); + if (excludedEndpoints != null && !excludedEndpoints.isEmpty()) { + excludedEndpoints = new HashSet<>(excludedEndpoints); + requestScopedExcluded = excludedEndpoints::contains; + } } + Predicate finalRequestScopedExcluded = requestScopedExcluded; + return address -> + finalRequestScopedExcluded.test(address) + || endpointOverloadCooldowns.isCoolingDown(address); + } + + @VisibleForTesting + boolean isCoolingDown(String address) { + return endpointOverloadCooldowns.isCoolingDown(address); + } + + @VisibleForTesting + boolean hasExcludedEndpointForLogicalRequest(String logicalRequestKey, String address) { Set excludedEndpoints = - excludedEndpointsForLogicalRequest.asMap().remove(logicalRequestKey); - if (excludedEndpoints == null || excludedEndpoints.isEmpty()) { - return address -> false; - } - excludedEndpoints = new HashSet<>(excludedEndpoints); - return excludedEndpoints::contains; + excludedEndpointsForLogicalRequest.getIfPresent(logicalRequestKey); + return excludedEndpoints != null && excludedEndpoints.contains(address); } private boolean isReadOnlyTransaction(ByteString transactionId) { @@ -858,7 +903,7 @@ public void onMessage(ResponseT message) { @Override public void onClose(io.grpc.Status status, Metadata trailers) { - if (status.getCode() == io.grpc.Status.Code.RESOURCE_EXHAUSTED) { + if (shouldExcludeEndpointOnRetry(status.getCode())) { call.parentChannel.maybeExcludeEndpointOnNextCall( call.selectedEndpoint, call.logicalRequestKey); } diff --git a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java index 59955ccb4bd2..41b8798d9611 100644 --- a/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java +++ b/java-spanner/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/KeyRangeCache.java @@ -28,6 +28,7 @@ import com.google.spanner.v1.Tablet; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; @@ -185,11 +186,10 @@ Set getActiveAddresses() { readLock.lock(); try { for (CachedGroup group : groups.values()) { - synchronized (group) { - for (CachedTablet tablet : group.tablets) { - if (!tablet.serverAddress.isEmpty()) { - addresses.add(tablet.serverAddress); - } + GroupSnapshot snapshot = group.snapshot; + for (TabletSnapshot tablet : snapshot.tablets) { + if (!tablet.serverAddress.isEmpty()) { + addresses.add(tablet.serverAddress); } } } @@ -487,34 +487,27 @@ private int compare(ByteString left, ByteString right) { return ByteString.unsignedLexicographicalComparator().compare(left, right); } - /** Represents a single tablet within a group. */ - private class CachedTablet { - long tabletUid = 0; - ByteString incarnation = ByteString.EMPTY; - String serverAddress = ""; - int distance = 0; - boolean skip = false; - Tablet.Role role = Tablet.Role.ROLE_UNSPECIFIED; - String location = ""; - - ChannelEndpoint endpoint = null; - - void update(Tablet tabletIn) { - if (tabletUid > 0 && compare(incarnation, tabletIn.getIncarnation()) > 0) { - return; - } - - tabletUid = tabletIn.getTabletUid(); - incarnation = tabletIn.getIncarnation(); - distance = tabletIn.getDistance(); - skip = tabletIn.getSkip(); - role = tabletIn.getRole(); - location = tabletIn.getLocation(); - - if (!serverAddress.equals(tabletIn.getServerAddress())) { - serverAddress = tabletIn.getServerAddress(); - endpoint = null; - } + private static final GroupSnapshot EMPTY_GROUP_SNAPSHOT = + new GroupSnapshot(ByteString.EMPTY, -1, Collections.emptyList()); + + /** Immutable tablet metadata used by the read path without per-group locking. */ + private static final class TabletSnapshot { + final long tabletUid; + final ByteString incarnation; + final String serverAddress; + final int distance; + final boolean skip; + final Tablet.Role role; + final String location; + + private TabletSnapshot(Tablet tabletIn) { + this.tabletUid = tabletIn.getTabletUid(); + this.incarnation = tabletIn.getIncarnation(); + this.serverAddress = tabletIn.getServerAddress(); + this.distance = tabletIn.getDistance(); + this.skip = tabletIn.getSkip(); + this.role = tabletIn.getRole(); + this.location = tabletIn.getLocation(); } boolean matches(DirectedReadOptions directedReadOptions) { @@ -555,132 +548,6 @@ private boolean matches(DirectedReadOptions.ReplicaSelection selection) { } } - /** - * Evaluates whether this tablet should be skipped for location-aware routing. - * - *

State-aware skip logic: - * - *

- */ - boolean shouldSkip( - RoutingHint.Builder hintBuilder, - Predicate excludedEndpoints, - Set skippedTabletUids) { - // Server-marked skip, no address, or excluded endpoint: always report. - if (skip || serverAddress.isEmpty() || excludedEndpoints.test(serverAddress)) { - addSkippedTablet(hintBuilder, skippedTabletUids); - return true; - } - - // If the cached endpoint's channel has been shut down (e.g. after idle eviction), - // discard the stale reference so we re-lookup from the cache below. - if (endpoint != null && endpoint.getChannel().isShutdown()) { - logger.log( - Level.FINE, - "Tablet {0} at {1}: cached endpoint is shutdown, clearing stale reference", - new Object[] {tabletUid, serverAddress}); - endpoint = null; - } - - // Lookup without creating: location-aware routing should not trigger foreground endpoint - // creation. - if (endpoint == null) { - endpoint = endpointCache.getIfPresent(serverAddress); - } - - // No endpoint exists yet - skip silently, request background recreation so the - // endpoint becomes available for future requests. - if (endpoint == null) { - logger.log( - Level.FINE, - "Tablet {0} at {1}: no endpoint present, skipping silently", - new Object[] {tabletUid, serverAddress}); - maybeAddRecentTransientFailureSkip(hintBuilder, skippedTabletUids); - if (lifecycleManager != null) { - lifecycleManager.requestEndpointRecreation(serverAddress); - } - return true; - } - - // READY - usable for location-aware routing. - if (endpoint.isHealthy()) { - return false; - } - - // TRANSIENT_FAILURE - skip and report so server can refresh client cache. - if (endpoint.isTransientFailure()) { - logger.log( - Level.FINE, - "Tablet {0} at {1}: endpoint in TRANSIENT_FAILURE, adding to skipped_tablets", - new Object[] {tabletUid, serverAddress}); - addSkippedTablet(hintBuilder, skippedTabletUids); - return true; - } - - // IDLE, CONNECTING, SHUTDOWN, or unsupported - skip silently. - logger.log( - Level.FINE, - "Tablet {0} at {1}: endpoint not ready, skipping silently", - new Object[] {tabletUid, serverAddress}); - maybeAddRecentTransientFailureSkip(hintBuilder, skippedTabletUids); - return true; - } - - private void addSkippedTablet(RoutingHint.Builder hintBuilder, Set skippedTabletUids) { - if (!skippedTabletUids.add(tabletUid)) { - return; - } - RoutingHint.SkippedTablet.Builder skipped = hintBuilder.addSkippedTabletUidBuilder(); - skipped.setTabletUid(tabletUid); - skipped.setIncarnation(incarnation); - } - - private void recordKnownTransientFailure( - RoutingHint.Builder hintBuilder, - Predicate excludedEndpoints, - Set skippedTabletUids) { - if (skip || serverAddress.isEmpty() || excludedEndpoints.test(serverAddress)) { - return; - } - - if (endpoint != null && endpoint.getChannel().isShutdown()) { - endpoint = null; - } - - if (endpoint == null) { - endpoint = endpointCache.getIfPresent(serverAddress); - } - - if (endpoint != null && endpoint.isTransientFailure()) { - addSkippedTablet(hintBuilder, skippedTabletUids); - return; - } - - maybeAddRecentTransientFailureSkip(hintBuilder, skippedTabletUids); - } - - private void maybeAddRecentTransientFailureSkip( - RoutingHint.Builder hintBuilder, Set skippedTabletUids) { - if (lifecycleManager != null - && lifecycleManager.wasRecentlyEvictedTransientFailure(serverAddress)) { - addSkippedTablet(hintBuilder, skippedTabletUids); - } - } - - ChannelEndpoint pick(RoutingHint.Builder hintBuilder) { - hintBuilder.setTabletUid(tabletUid); - // Endpoint must already exist and be READY if shouldSkip returned false. - return endpoint; - } - String debugString() { return tabletUid + ":" @@ -698,19 +565,40 @@ String debugString() { } } + private static final class GroupSnapshot { + final ByteString generation; + final int leaderIndex; + final List tablets; + + private GroupSnapshot(ByteString generation, int leaderIndex, List tablets) { + this.generation = generation; + this.leaderIndex = leaderIndex; + this.tablets = Collections.unmodifiableList(new ArrayList<>(tablets)); + } + + boolean hasLeader() { + return leaderIndex >= 0 && leaderIndex < tablets.size(); + } + + TabletSnapshot leader() { + return tablets.get(leaderIndex); + } + } + /** Represents a paxos group with its tablets. */ private class CachedGroup { final long groupUid; - ByteString generation = ByteString.EMPTY; - List tablets = new ArrayList<>(); - int leaderIndex = -1; + volatile GroupSnapshot snapshot = EMPTY_GROUP_SNAPSHOT; int refs = 1; CachedGroup(long groupUid) { this.groupUid = groupUid; } - synchronized void update(Group groupIn) { + void update(Group groupIn) { + GroupSnapshot current = snapshot; + ByteString generation = current.generation; + int leaderIndex = current.leaderIndex; if (compare(groupIn.getGeneration(), generation) > 0) { generation = groupIn.getGeneration(); if (groupIn.getLeaderIndex() >= 0 && groupIn.getLeaderIndex() < groupIn.getTabletsCount()) { @@ -720,37 +608,11 @@ synchronized void update(Group groupIn) { } } - if (tablets.size() == groupIn.getTabletsCount()) { - boolean mismatch = false; - for (int t = 0; t < groupIn.getTabletsCount(); t++) { - if (tablets.get(t).tabletUid != groupIn.getTablets(t).getTabletUid()) { - mismatch = true; - break; - } - } - if (!mismatch) { - for (int t = 0; t < groupIn.getTabletsCount(); t++) { - tablets.get(t).update(groupIn.getTablets(t)); - } - return; - } - } - - Map tabletsByUid = new HashMap<>(tablets.size()); - for (CachedTablet tablet : tablets) { - tabletsByUid.put(tablet.tabletUid, tablet); - } - List newTablets = new ArrayList<>(groupIn.getTabletsCount()); + List tablets = new ArrayList<>(groupIn.getTabletsCount()); for (int t = 0; t < groupIn.getTabletsCount(); t++) { - Tablet tabletIn = groupIn.getTablets(t); - CachedTablet tablet = tabletsByUid.get(tabletIn.getTabletUid()); - if (tablet == null) { - tablet = new CachedTablet(); - } - tablet.update(tabletIn); - newTablets.add(tablet); + tablets.add(new TabletSnapshot(groupIn.getTablets(t))); } - tablets = newTablets; + snapshot = new GroupSnapshot(generation, leaderIndex, tablets); } ChannelEndpoint fillRoutingHint( @@ -758,59 +620,72 @@ ChannelEndpoint fillRoutingHint( DirectedReadOptions directedReadOptions, RoutingHint.Builder hintBuilder, Predicate excludedEndpoints) { + GroupSnapshot snapshot = this.snapshot; Set skippedTabletUids = skippedTabletUids(hintBuilder); boolean hasDirectedReadOptions = directedReadOptions.getReplicasCase() != DirectedReadOptions.ReplicasCase.REPLICAS_NOT_SET; - - // Select a tablet while holding the lock. With state-aware routing, only READY - // endpoints pass shouldSkip(), so the selected tablet always has a cached - // endpoint. No foreground endpoint creation is needed — the lifecycle manager - // creates endpoints in the background. - synchronized (this) { - CachedTablet selected = - selectTabletLocked( - preferLeader, - hasDirectedReadOptions, - hintBuilder, - directedReadOptions, - excludedEndpoints, - skippedTabletUids); - if (selected == null) { - return null; - } - recordKnownTransientFailuresLocked( - selected, directedReadOptions, hintBuilder, excludedEndpoints, skippedTabletUids); - return selected.pick(hintBuilder); - } - } - - private CachedTablet selectTabletLocked( + Map resolvedEndpoints = new HashMap<>(); + + TabletSnapshot selected = + selectTablet( + snapshot, + preferLeader, + hasDirectedReadOptions, + hintBuilder, + directedReadOptions, + excludedEndpoints, + skippedTabletUids, + resolvedEndpoints); + if (selected == null) { + return null; + } + recordKnownTransientFailures( + snapshot, + selected, + directedReadOptions, + hintBuilder, + excludedEndpoints, + skippedTabletUids, + resolvedEndpoints); + hintBuilder.setTabletUid(selected.tabletUid); + return resolveEndpoint(selected, resolvedEndpoints); + } + + private TabletSnapshot selectTablet( + GroupSnapshot snapshot, boolean preferLeader, boolean hasDirectedReadOptions, RoutingHint.Builder hintBuilder, DirectedReadOptions directedReadOptions, Predicate excludedEndpoints, - Set skippedTabletUids) { + Set skippedTabletUids, + Map resolvedEndpoints) { boolean checkedLeader = false; if (preferLeader && !hasDirectedReadOptions - && hasLeader() - && leader().distance <= MAX_LOCAL_REPLICA_DISTANCE) { + && snapshot.hasLeader() + && snapshot.leader().distance <= MAX_LOCAL_REPLICA_DISTANCE) { checkedLeader = true; - if (!leader().shouldSkip(hintBuilder, excludedEndpoints, skippedTabletUids)) { - return leader(); + if (!shouldSkip( + snapshot.leader(), + hintBuilder, + excludedEndpoints, + skippedTabletUids, + resolvedEndpoints)) { + return snapshot.leader(); } } - for (int index = 0; index < tablets.size(); index++) { - if (checkedLeader && index == leaderIndex) { + for (int index = 0; index < snapshot.tablets.size(); index++) { + if (checkedLeader && index == snapshot.leaderIndex) { continue; } - CachedTablet tablet = tablets.get(index); + TabletSnapshot tablet = snapshot.tablets.get(index); if (!tablet.matches(directedReadOptions)) { continue; } - if (tablet.shouldSkip(hintBuilder, excludedEndpoints, skippedTabletUids)) { + if (shouldSkip( + tablet, hintBuilder, excludedEndpoints, skippedTabletUids, resolvedEndpoints)) { continue; } return tablet; @@ -818,17 +693,20 @@ && leader().distance <= MAX_LOCAL_REPLICA_DISTANCE) { return null; } - private void recordKnownTransientFailuresLocked( - CachedTablet selected, + private void recordKnownTransientFailures( + GroupSnapshot snapshot, + TabletSnapshot selected, DirectedReadOptions directedReadOptions, RoutingHint.Builder hintBuilder, Predicate excludedEndpoints, - Set skippedTabletUids) { - for (CachedTablet tablet : tablets) { + Set skippedTabletUids, + Map resolvedEndpoints) { + for (TabletSnapshot tablet : snapshot.tablets) { if (tablet == selected || !tablet.matches(directedReadOptions)) { continue; } - tablet.recordKnownTransientFailure(hintBuilder, excludedEndpoints, skippedTabletUids); + recordKnownTransientFailure( + tablet, hintBuilder, excludedEndpoints, skippedTabletUids, resolvedEndpoints); } } @@ -840,27 +718,124 @@ private Set skippedTabletUids(RoutingHint.Builder hintBuilder) { return skippedTabletUids; } - boolean hasLeader() { - return leaderIndex >= 0 && leaderIndex < tablets.size(); + private boolean shouldSkip( + TabletSnapshot tablet, + RoutingHint.Builder hintBuilder, + Predicate excludedEndpoints, + Set skippedTabletUids, + Map resolvedEndpoints) { + if (tablet.skip + || tablet.serverAddress.isEmpty() + || excludedEndpoints.test(tablet.serverAddress)) { + addSkippedTablet(tablet, hintBuilder, skippedTabletUids); + return true; + } + + ChannelEndpoint endpoint = resolveEndpoint(tablet, resolvedEndpoints); + if (endpoint == null) { + logger.log( + Level.FINE, + "Tablet {0} at {1}: no endpoint present, skipping silently", + new Object[] {tablet.tabletUid, tablet.serverAddress}); + maybeAddRecentTransientFailureSkip(tablet, hintBuilder, skippedTabletUids); + if (lifecycleManager != null) { + lifecycleManager.requestEndpointRecreation(tablet.serverAddress); + } + return true; + } + if (endpoint.isHealthy()) { + return false; + } + if (endpoint.isTransientFailure()) { + logger.log( + Level.FINE, + "Tablet {0} at {1}: endpoint in TRANSIENT_FAILURE, adding to skipped_tablets", + new Object[] {tablet.tabletUid, tablet.serverAddress}); + addSkippedTablet(tablet, hintBuilder, skippedTabletUids); + return true; + } + + logger.log( + Level.FINE, + "Tablet {0} at {1}: endpoint not ready, skipping silently", + new Object[] {tablet.tabletUid, tablet.serverAddress}); + maybeAddRecentTransientFailureSkip(tablet, hintBuilder, skippedTabletUids); + return true; } - CachedTablet leader() { - return tablets.get(leaderIndex); + private void recordKnownTransientFailure( + TabletSnapshot tablet, + RoutingHint.Builder hintBuilder, + Predicate excludedEndpoints, + Set skippedTabletUids, + Map resolvedEndpoints) { + if (tablet.skip + || tablet.serverAddress.isEmpty() + || excludedEndpoints.test(tablet.serverAddress)) { + return; + } + + ChannelEndpoint endpoint = resolveEndpoint(tablet, resolvedEndpoints); + if (endpoint != null && endpoint.isTransientFailure()) { + addSkippedTablet(tablet, hintBuilder, skippedTabletUids); + return; + } + + maybeAddRecentTransientFailureSkip(tablet, hintBuilder, skippedTabletUids); + } + + private ChannelEndpoint resolveEndpoint( + TabletSnapshot tablet, Map resolvedEndpoints) { + if (tablet.serverAddress.isEmpty()) { + return null; + } + if (resolvedEndpoints.containsKey(tablet.serverAddress)) { + return resolvedEndpoints.get(tablet.serverAddress); + } + ChannelEndpoint endpoint = endpointCache.getIfPresent(tablet.serverAddress); + if (endpoint != null && endpoint.getChannel().isShutdown()) { + logger.log( + Level.FINE, + "Tablet {0} at {1}: cached endpoint is shutdown, clearing stale reference", + new Object[] {tablet.tabletUid, tablet.serverAddress}); + endpoint = null; + } + resolvedEndpoints.put(tablet.serverAddress, endpoint); + return endpoint; + } + + private void maybeAddRecentTransientFailureSkip( + TabletSnapshot tablet, RoutingHint.Builder hintBuilder, Set skippedTabletUids) { + if (lifecycleManager != null + && lifecycleManager.wasRecentlyEvictedTransientFailure(tablet.serverAddress)) { + addSkippedTablet(tablet, hintBuilder, skippedTabletUids); + } + } + + private void addSkippedTablet( + TabletSnapshot tablet, RoutingHint.Builder hintBuilder, Set skippedTabletUids) { + if (!skippedTabletUids.add(tablet.tabletUid)) { + return; + } + RoutingHint.SkippedTablet.Builder skipped = hintBuilder.addSkippedTabletUidBuilder(); + skipped.setTabletUid(tablet.tabletUid); + skipped.setIncarnation(tablet.incarnation); } String debugString() { + GroupSnapshot snapshot = this.snapshot; StringBuilder sb = new StringBuilder(); sb.append(groupUid).append(":["); - for (int i = 0; i < tablets.size(); i++) { - sb.append(tablets.get(i).debugString()); - if (hasLeader() && i == leaderIndex) { + for (int i = 0; i < snapshot.tablets.size(); i++) { + sb.append(snapshot.tablets.get(i).debugString()); + if (snapshot.hasLeader() && i == snapshot.leaderIndex) { sb.append(" (leader)"); } - if (i < tablets.size() - 1) { + if (i < snapshot.tablets.size() - 1) { sb.append(", "); } } - sb.append("]@").append(generation.toStringUtf8()); + sb.append("]@").append(snapshot.generation.toStringUtf8()); sb.append("#").append(refs); return sb.toString(); } diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java new file mode 100644 index 000000000000..9b6c2de65397 --- /dev/null +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/LocationAwareSharedBackendReplicaHarnessTest.java @@ -0,0 +1,680 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.NoCredentials; +import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime; +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.spi.v1.KeyRecipeCache; +import com.google.protobuf.AbstractMessage; +import com.google.protobuf.ByteString; +import com.google.protobuf.ListValue; +import com.google.protobuf.TextFormat; +import com.google.protobuf.Value; +import com.google.rpc.RetryInfo; +import com.google.spanner.v1.CacheUpdate; +import com.google.spanner.v1.DirectedReadOptions; +import com.google.spanner.v1.DirectedReadOptions.IncludeReplicas; +import com.google.spanner.v1.DirectedReadOptions.ReplicaSelection; +import com.google.spanner.v1.Group; +import com.google.spanner.v1.Range; +import com.google.spanner.v1.ReadRequest; +import com.google.spanner.v1.RecipeList; +import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.RoutingHint; +import com.google.spanner.v1.StructType; +import com.google.spanner.v1.Tablet; +import com.google.spanner.v1.Type; +import com.google.spanner.v1.TypeCode; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.protobuf.ProtoUtils; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.TimeUnit; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class LocationAwareSharedBackendReplicaHarnessTest { + + private static final String PROJECT = "fake-project"; + private static final String INSTANCE = "fake-instance"; + private static final String DATABASE = "fake-database"; + private static final String TABLE = "T"; + private static final String REPLICA_LOCATION = "us-east1"; + private static final Statement SEED_QUERY = Statement.of("SELECT 1"); + private static final ByteString RESUME_TOKEN_AFTER_FIRST_ROW = + ByteString.copyFromUtf8("000000001"); + private static final DirectedReadOptions DIRECTED_READ_OPTIONS = + DirectedReadOptions.newBuilder() + .setIncludeReplicas( + IncludeReplicas.newBuilder() + .addReplicaSelections( + ReplicaSelection.newBuilder() + .setLocation(REPLICA_LOCATION) + .setType(ReplicaSelection.Type.READ_ONLY) + .build()) + .build()) + .build(); + + @BeforeClass + public static void enableLocationAwareRouting() { + SpannerOptions.useEnvironment( + new SpannerOptions.SpannerEnvironment() { + @Override + public boolean isEnableLocationApi() { + return true; + } + }); + } + + @AfterClass + public static void restoreEnvironment() { + SpannerOptions.useDefaultEnvironment(); + } + + @Test + public void singleUseReadReroutesOnResourceExhaustedForBypassTraffic() throws Exception { + try (SharedBackendReplicaHarness harness = SharedBackendReplicaHarness.create(2); + Spanner spanner = createSpanner(harness)) { + configureBackend(harness, singleRowReadResultSet("b")); + DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of(PROJECT, INSTANCE, DATABASE)); + + seedLocationMetadata(client); + waitForReplicaRoutedRead(client, harness, 0); + harness.clearRequests(); + + harness + .replicas + .get(0) + .putMethodErrors( + SharedBackendReplicaHarness.METHOD_STREAMING_READ, + resourceExhausted("busy-routed-replica")); + + try (ResultSet resultSet = + client + .singleUse() + .read( + TABLE, + KeySet.singleKey(Key.of("b")), + Arrays.asList("k"), + Options.directedRead(DIRECTED_READ_OPTIONS))) { + assertTrue(resultSet.next()); + } + + assertEquals( + 1, + harness + .replicas + .get(0) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + assertEquals( + 1, + harness + .replicas + .get(1) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + assertEquals( + 0, + harness + .defaultReplica + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + ReadRequest replicaARequest = + (ReadRequest) + harness + .replicas + .get(0) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .get(0); + assertTrue(replicaARequest.getResumeToken().isEmpty()); + assertRetriedOnSameLogicalRequest( + harness + .replicas + .get(0) + .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .get(0), + harness + .replicas + .get(1) + .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .get(0)); + } + } + + @Test + public void singleUseReadCooldownSkipsReplicaOnNextRequestForBypassTraffic() throws Exception { + try (SharedBackendReplicaHarness harness = SharedBackendReplicaHarness.create(2); + Spanner spanner = createSpanner(harness)) { + configureBackend(harness, singleRowReadResultSet("b")); + DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of(PROJECT, INSTANCE, DATABASE)); + + seedLocationMetadata(client); + waitForReplicaRoutedRead(client, harness, 0); + harness.clearRequests(); + + harness + .replicas + .get(0) + .putMethodErrors( + SharedBackendReplicaHarness.METHOD_STREAMING_READ, + resourceExhaustedWithRetryInfo("busy-routed-replica")); + + try (ResultSet firstRead = + client + .singleUse() + .read( + TABLE, + KeySet.singleKey(Key.of("b")), + Arrays.asList("k"), + Options.directedRead(DIRECTED_READ_OPTIONS))) { + assertTrue(firstRead.next()); + } + + try (ResultSet secondRead = + client + .singleUse() + .read( + TABLE, + KeySet.singleKey(Key.of("b")), + Arrays.asList("k"), + Options.directedRead(DIRECTED_READ_OPTIONS))) { + assertTrue(secondRead.next()); + } + + assertEquals( + 1, + harness + .replicas + .get(0) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + assertEquals( + 2, + harness + .replicas + .get(1) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + assertEquals( + 0, + harness + .defaultReplica + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + List replicaBRequests = + harness.replicas.get(1).getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ); + for (AbstractMessage request : replicaBRequests) { + assertTrue(((ReadRequest) request).getResumeToken().isEmpty()); + } + List replicaBRequestIds = + harness.replicas.get(1).getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ); + assertRetriedOnSameLogicalRequest( + harness + .replicas + .get(0) + .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .get(0), + replicaBRequestIds.get(0)); + assertNotEquals( + XGoogSpannerRequestId.of(replicaBRequestIds.get(0)).getLogicalRequestKey(), + XGoogSpannerRequestId.of(replicaBRequestIds.get(1)).getLogicalRequestKey()); + } + } + + @Test + public void singleUseReadReroutesOnUnavailableForBypassTraffic() throws Exception { + try (SharedBackendReplicaHarness harness = SharedBackendReplicaHarness.create(2); + Spanner spanner = createSpanner(harness)) { + configureBackend(harness, singleRowReadResultSet("b")); + DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of(PROJECT, INSTANCE, DATABASE)); + + seedLocationMetadata(client); + waitForReplicaRoutedRead(client, harness, 0); + harness.clearRequests(); + + harness + .replicas + .get(0) + .putMethodErrors( + SharedBackendReplicaHarness.METHOD_STREAMING_READ, unavailable("isolated-replica")); + + try (ResultSet resultSet = + client + .singleUse() + .read( + TABLE, + KeySet.singleKey(Key.of("b")), + Arrays.asList("k"), + Options.directedRead(DIRECTED_READ_OPTIONS))) { + assertTrue(resultSet.next()); + } + + assertEquals( + 1, + harness + .replicas + .get(0) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + assertEquals( + 1, + harness + .replicas + .get(1) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + assertEquals( + 0, + harness + .defaultReplica + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + ReadRequest replicaARequest = + (ReadRequest) + harness + .replicas + .get(0) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .get(0); + assertTrue(replicaARequest.getResumeToken().isEmpty()); + assertRetriedOnSameLogicalRequest( + harness + .replicas + .get(0) + .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .get(0), + harness + .replicas + .get(1) + .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .get(0)); + } + } + + @Test + public void singleUseReadCooldownSkipsUnavailableReplicaOnNextRequestForBypassTraffic() + throws Exception { + try (SharedBackendReplicaHarness harness = SharedBackendReplicaHarness.create(2); + Spanner spanner = createSpanner(harness)) { + configureBackend(harness, singleRowReadResultSet("b")); + DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of(PROJECT, INSTANCE, DATABASE)); + + seedLocationMetadata(client); + waitForReplicaRoutedRead(client, harness, 0); + harness.clearRequests(); + + harness + .replicas + .get(0) + .putMethodErrors( + SharedBackendReplicaHarness.METHOD_STREAMING_READ, unavailable("isolated-replica")); + + try (ResultSet firstRead = + client + .singleUse() + .read( + TABLE, + KeySet.singleKey(Key.of("b")), + Arrays.asList("k"), + Options.directedRead(DIRECTED_READ_OPTIONS))) { + assertTrue(firstRead.next()); + } + + try (ResultSet secondRead = + client + .singleUse() + .read( + TABLE, + KeySet.singleKey(Key.of("b")), + Arrays.asList("k"), + Options.directedRead(DIRECTED_READ_OPTIONS))) { + assertTrue(secondRead.next()); + } + + assertEquals( + 1, + harness + .replicas + .get(0) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + assertEquals( + 2, + harness + .replicas + .get(1) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + assertEquals( + 0, + harness + .defaultReplica + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + List replicaBRequests = + harness.replicas.get(1).getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ); + for (AbstractMessage request : replicaBRequests) { + assertTrue(((ReadRequest) request).getResumeToken().isEmpty()); + } + List replicaBRequestIds = + harness.replicas.get(1).getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ); + assertRetriedOnSameLogicalRequest( + harness + .replicas + .get(0) + .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .get(0), + replicaBRequestIds.get(0)); + assertNotEquals( + XGoogSpannerRequestId.of(replicaBRequestIds.get(0)).getLogicalRequestKey(), + XGoogSpannerRequestId.of(replicaBRequestIds.get(1)).getLogicalRequestKey()); + } + } + + @Test + public void singleUseReadMidStreamRecvFailureWithoutRetryInfoRetriesForBypassTraffic() + throws Exception { + try (SharedBackendReplicaHarness harness = SharedBackendReplicaHarness.create(2); + Spanner spanner = createSpanner(harness)) { + configureBackend(harness, multiRowReadResultSet("b", "c", "d")); + DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of(PROJECT, INSTANCE, DATABASE)); + + seedLocationMetadata(client); + waitForReplicaRoutedRead(client, harness, 0); + harness.clearRequests(); + + harness.backend.setStreamingReadExecutionTime( + SimulatedExecutionTime.ofStreamException(resourceExhausted("busy-routed-replica"), 1L)); + + List rows = new ArrayList<>(); + try (ResultSet resultSet = + client + .singleUse() + .read( + TABLE, + KeySet.singleKey(Key.of("b")), + Arrays.asList("k"), + Options.directedRead(DIRECTED_READ_OPTIONS))) { + while (resultSet.next()) { + rows.add(resultSet.getString(0)); + } + } + + assertEquals(Arrays.asList("b", "c", "d"), rows); + assertEquals( + 1, + harness + .replicas + .get(0) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + assertEquals( + 1, + harness + .replicas + .get(1) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + assertEquals( + 0, + harness + .defaultReplica + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .size()); + + ReadRequest replicaARequest = + (ReadRequest) + harness + .replicas + .get(0) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .get(0); + ReadRequest replicaBRequest = + (ReadRequest) + harness + .replicas + .get(1) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .get(0); + assertTrue(replicaARequest.getResumeToken().isEmpty()); + assertEquals(RESUME_TOKEN_AFTER_FIRST_ROW, replicaBRequest.getResumeToken()); + assertRetriedOnSameLogicalRequest( + harness + .replicas + .get(0) + .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .get(0), + harness + .replicas + .get(1) + .getRequestIds(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .get(0)); + } + } + + private static Spanner createSpanner(SharedBackendReplicaHarness harness) { + return SpannerOptions.newBuilder() + .usePlainText() + .setExperimentalHost(harness.defaultAddress) + .setProjectId(PROJECT) + .setCredentials(NoCredentials.getInstance()) + .setChannelEndpointCacheFactory(null) + .build() + .getService(); + } + + private static void configureBackend( + SharedBackendReplicaHarness harness, com.google.spanner.v1.ResultSet readResultSet) + throws TextFormat.ParseException { + Statement readStatement = + StatementResult.createReadStatement( + TABLE, KeySet.singleKey(Key.of("b")), Arrays.asList("k")); + harness.backend.putStatementResult(StatementResult.query(readStatement, readResultSet)); + harness.backend.putStatementResult( + StatementResult.query( + SEED_QUERY, + singleRowReadResultSet("seed").toBuilder() + .setCacheUpdate(cacheUpdate(harness)) + .build())); + } + + private static void seedLocationMetadata(DatabaseClient client) { + try (com.google.cloud.spanner.ResultSet resultSet = + client.singleUse().executeQuery(SEED_QUERY)) { + while (resultSet.next()) { + // Consume the cache update on the first query result. + } + } + } + + private static void waitForReplicaRoutedRead( + DatabaseClient client, SharedBackendReplicaHarness harness, int replicaIndex) + throws InterruptedException { + long deadlineNanos = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); + while (System.nanoTime() < deadlineNanos) { + try (ResultSet resultSet = + client + .singleUse() + .read( + TABLE, + KeySet.singleKey(Key.of("b")), + Arrays.asList("k"), + Options.directedRead(DIRECTED_READ_OPTIONS))) { + if (resultSet.next() + && !harness + .replicas + .get(replicaIndex) + .getRequests(SharedBackendReplicaHarness.METHOD_STREAMING_READ) + .isEmpty()) { + return; + } + } + Thread.sleep(50L); + } + throw new AssertionError("Timed out waiting for location-aware read to route to replica"); + } + + private static CacheUpdate cacheUpdate(SharedBackendReplicaHarness harness) + throws TextFormat.ParseException { + RecipeList recipes = readRecipeList(); + RoutingHint routingHint = exactReadRoutingHint(recipes); + ByteString limitKey = routingHint.getLimitKey(); + if (limitKey.isEmpty()) { + limitKey = routingHint.getKey().concat(ByteString.copyFrom(new byte[] {0})); + } + + return CacheUpdate.newBuilder() + .setDatabaseId(12345L) + .setKeyRecipes(recipes) + .addRange( + Range.newBuilder() + .setStartKey(routingHint.getKey()) + .setLimitKey(limitKey) + .setGroupUid(1L) + .setSplitId(1L) + .setGeneration(com.google.protobuf.ByteString.copyFromUtf8("gen1"))) + .addGroup( + Group.newBuilder() + .setGroupUid(1L) + .setGeneration(com.google.protobuf.ByteString.copyFromUtf8("gen1")) + .setLeaderIndex(0) + .addTablets( + Tablet.newBuilder() + .setTabletUid(11L) + .setServerAddress(harness.replicaAddresses.get(0)) + .setLocation(REPLICA_LOCATION) + .setRole(Tablet.Role.READ_ONLY) + .setDistance(0)) + .addTablets( + Tablet.newBuilder() + .setTabletUid(12L) + .setServerAddress(harness.replicaAddresses.get(1)) + .setLocation(REPLICA_LOCATION) + .setRole(Tablet.Role.READ_ONLY) + .setDistance(0))) + .build(); + } + + private static RecipeList readRecipeList() throws TextFormat.ParseException { + RecipeList.Builder recipes = RecipeList.newBuilder(); + TextFormat.merge( + "schema_generation: \"1\"\n" + + "recipe {\n" + + " table_name: \"" + + TABLE + + "\"\n" + + " part { tag: 1 }\n" + + " part {\n" + + " order: ASCENDING\n" + + " null_order: NULLS_FIRST\n" + + " type { code: STRING }\n" + + " identifier: \"k\"\n" + + " }\n" + + "}\n", + recipes); + return recipes.build(); + } + + private static RoutingHint exactReadRoutingHint(RecipeList recipes) { + KeyRecipeCache recipeCache = new KeyRecipeCache(); + recipeCache.addRecipes(recipes); + ReadRequest.Builder request = + ReadRequest.newBuilder() + .setSession( + String.format( + "projects/%s/instances/%s/databases/%s/sessions/test-session", + PROJECT, INSTANCE, DATABASE)) + .setTable(TABLE) + .addAllColumns(Arrays.asList("k")) + .setDirectedReadOptions(DIRECTED_READ_OPTIONS); + KeySet.singleKey(Key.of("b")).appendToProto(request.getKeySetBuilder()); + recipeCache.computeKeys(request); + return request.getRoutingHint(); + } + + private static io.grpc.StatusRuntimeException resourceExhaustedWithRetryInfo(String description) { + Metadata trailers = new Metadata(); + trailers.put( + ProtoUtils.keyForProto(RetryInfo.getDefaultInstance()), + RetryInfo.newBuilder() + .setRetryDelay( + com.google.protobuf.Duration.newBuilder() + .setNanos((int) TimeUnit.MILLISECONDS.toNanos(1L)) + .build()) + .build()); + return Status.RESOURCE_EXHAUSTED.withDescription(description).asRuntimeException(trailers); + } + + private static StatusRuntimeException resourceExhausted(String description) { + return Status.RESOURCE_EXHAUSTED.withDescription(description).asRuntimeException(); + } + + private static StatusRuntimeException unavailable(String description) { + return Status.UNAVAILABLE.withDescription(description).asRuntimeException(); + } + + private static void assertRetriedOnSameLogicalRequest( + String firstRequestId, String secondRequestId) { + XGoogSpannerRequestId first = XGoogSpannerRequestId.of(firstRequestId); + XGoogSpannerRequestId second = XGoogSpannerRequestId.of(secondRequestId); + assertEquals(first.getLogicalRequestKey(), second.getLogicalRequestKey()); + assertEquals(first.getAttempt() + 1, second.getAttempt()); + } + + private static com.google.spanner.v1.ResultSet singleRowReadResultSet(String value) { + return readResultSet(Arrays.asList(value)); + } + + private static com.google.spanner.v1.ResultSet multiRowReadResultSet(String... values) { + return readResultSet(Arrays.asList(values)); + } + + private static com.google.spanner.v1.ResultSet readResultSet(List values) { + com.google.spanner.v1.ResultSet.Builder builder = + com.google.spanner.v1.ResultSet.newBuilder() + .setMetadata( + ResultSetMetadata.newBuilder() + .setRowType( + StructType.newBuilder() + .addFields( + StructType.Field.newBuilder() + .setName("k") + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .build()) + .build())); + for (String value : values) { + builder.addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue(value).build()) + .build()); + } + return builder.build(); + } +} diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java index 6f40052d0aed..3ea19ad2422a 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java @@ -202,9 +202,15 @@ private static class PartialResultSetsIterator implements Iterator responseObserver, SimulatedExecutionTime executionTime, boolean isMultiplexedSession) @@ -1783,7 +1803,8 @@ private void returnPartialResultSet( new PartialResultSetsIterator( resultSet, isMultiplexedSession && isReadWriteTransaction(transactionId), - transactionId); + transactionId, + resumeToken); long index = 0L; while (iterator.hasNext()) { SimulatedExecutionTime.checkStreamException( diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SharedBackendReplicaHarness.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SharedBackendReplicaHarness.java new file mode 100644 index 000000000000..891ae0f7d19e --- /dev/null +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SharedBackendReplicaHarness.java @@ -0,0 +1,310 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import com.google.protobuf.AbstractMessage; +import com.google.protobuf.Empty; +import com.google.spanner.v1.BatchCreateSessionsRequest; +import com.google.spanner.v1.BatchCreateSessionsResponse; +import com.google.spanner.v1.BeginTransactionRequest; +import com.google.spanner.v1.CommitRequest; +import com.google.spanner.v1.CommitResponse; +import com.google.spanner.v1.CreateSessionRequest; +import com.google.spanner.v1.DeleteSessionRequest; +import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.GetSessionRequest; +import com.google.spanner.v1.PartialResultSet; +import com.google.spanner.v1.ReadRequest; +import com.google.spanner.v1.ResultSet; +import com.google.spanner.v1.RollbackRequest; +import com.google.spanner.v1.Session; +import com.google.spanner.v1.SpannerGrpc; +import com.google.spanner.v1.Transaction; +import io.grpc.Metadata; +import io.grpc.Server; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.ServerInterceptors; +import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder; +import io.grpc.stub.StreamObserver; +import java.io.Closeable; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** Shared-backend replica harness for end-to-end location-aware routing tests. */ +final class SharedBackendReplicaHarness implements Closeable { + + static final String METHOD_BATCH_CREATE_SESSIONS = "BatchCreateSessions"; + static final String METHOD_BEGIN_TRANSACTION = "BeginTransaction"; + static final String METHOD_COMMIT = "Commit"; + static final String METHOD_CREATE_SESSION = "CreateSession"; + static final String METHOD_DELETE_SESSION = "DeleteSession"; + static final String METHOD_EXECUTE_SQL = "ExecuteSql"; + static final String METHOD_EXECUTE_STREAMING_SQL = "ExecuteStreamingSql"; + static final String METHOD_GET_SESSION = "GetSession"; + static final String METHOD_READ = "Read"; + static final String METHOD_ROLLBACK = "Rollback"; + static final String METHOD_STREAMING_READ = "StreamingRead"; + + static final class HookedReplicaSpannerService extends SpannerGrpc.SpannerImplBase { + private final MockSpannerServiceImpl backend; + private final Map> methodErrors = new HashMap<>(); + private final Map> requests = new HashMap<>(); + private final Map> requestIds = new HashMap<>(); + + private HookedReplicaSpannerService(MockSpannerServiceImpl backend) { + this.backend = backend; + } + + synchronized void putMethodErrors(String method, Throwable... errors) { + ArrayDeque queue = new ArrayDeque<>(); + for (Throwable error : errors) { + queue.addLast(error); + } + methodErrors.put(method, queue); + } + + synchronized List getRequests(String method) { + return new ArrayList<>(requests.getOrDefault(method, new ArrayList<>())); + } + + synchronized List getRequestIds(String method) { + return new ArrayList<>(requestIds.getOrDefault(method, new ArrayList<>())); + } + + synchronized void clearRequests() { + requests.clear(); + requestIds.clear(); + } + + private synchronized void recordRequest(String method, AbstractMessage request) { + requests.computeIfAbsent(method, ignored -> new ArrayList<>()).add(request); + } + + private synchronized void recordRequestId(String method, String requestId) { + requestIds.computeIfAbsent(method, ignored -> new ArrayList<>()).add(requestId); + } + + private synchronized Throwable nextError(String method) { + ArrayDeque queue = methodErrors.get(method); + if (queue == null || queue.isEmpty()) { + return null; + } + return queue.removeFirst(); + } + + private boolean maybeFail(String method, StreamObserver responseObserver) { + Throwable error = nextError(method); + if (error == null) { + return false; + } + responseObserver.onError(error); + return true; + } + + @Override + public void batchCreateSessions( + BatchCreateSessionsRequest request, + StreamObserver responseObserver) { + recordRequest(METHOD_BATCH_CREATE_SESSIONS, request); + if (!maybeFail(METHOD_BATCH_CREATE_SESSIONS, responseObserver)) { + backend.batchCreateSessions(request, responseObserver); + } + } + + @Override + public void beginTransaction( + BeginTransactionRequest request, StreamObserver responseObserver) { + recordRequest(METHOD_BEGIN_TRANSACTION, request); + if (!maybeFail(METHOD_BEGIN_TRANSACTION, responseObserver)) { + backend.beginTransaction(request, responseObserver); + } + } + + @Override + public void commit(CommitRequest request, StreamObserver responseObserver) { + recordRequest(METHOD_COMMIT, request); + if (!maybeFail(METHOD_COMMIT, responseObserver)) { + backend.commit(request, responseObserver); + } + } + + @Override + public void createSession( + CreateSessionRequest request, StreamObserver responseObserver) { + recordRequest(METHOD_CREATE_SESSION, request); + if (!maybeFail(METHOD_CREATE_SESSION, responseObserver)) { + backend.createSession(request, responseObserver); + } + } + + @Override + public void deleteSession( + DeleteSessionRequest request, StreamObserver responseObserver) { + recordRequest(METHOD_DELETE_SESSION, request); + if (!maybeFail(METHOD_DELETE_SESSION, responseObserver)) { + backend.deleteSession(request, responseObserver); + } + } + + @Override + public void executeSql(ExecuteSqlRequest request, StreamObserver responseObserver) { + recordRequest(METHOD_EXECUTE_SQL, request); + if (!maybeFail(METHOD_EXECUTE_SQL, responseObserver)) { + backend.executeSql(request, responseObserver); + } + } + + @Override + public void executeStreamingSql( + ExecuteSqlRequest request, StreamObserver responseObserver) { + recordRequest(METHOD_EXECUTE_STREAMING_SQL, request); + if (!maybeFail(METHOD_EXECUTE_STREAMING_SQL, responseObserver)) { + backend.executeStreamingSql(request, responseObserver); + } + } + + @Override + public void getSession(GetSessionRequest request, StreamObserver responseObserver) { + recordRequest(METHOD_GET_SESSION, request); + if (!maybeFail(METHOD_GET_SESSION, responseObserver)) { + backend.getSession(request, responseObserver); + } + } + + @Override + public void read(ReadRequest request, StreamObserver responseObserver) { + recordRequest(METHOD_READ, request); + if (!maybeFail(METHOD_READ, responseObserver)) { + backend.read(request, responseObserver); + } + } + + @Override + public void rollback(RollbackRequest request, StreamObserver responseObserver) { + recordRequest(METHOD_ROLLBACK, request); + if (!maybeFail(METHOD_ROLLBACK, responseObserver)) { + backend.rollback(request, responseObserver); + } + } + + @Override + public void streamingRead( + ReadRequest request, StreamObserver responseObserver) { + recordRequest(METHOD_STREAMING_READ, request); + if (!maybeFail(METHOD_STREAMING_READ, responseObserver)) { + backend.streamingRead(request, responseObserver); + } + } + } + + private final List servers; + final MockSpannerServiceImpl backend; + final HookedReplicaSpannerService defaultReplica; + final String defaultAddress; + final List replicas; + final List replicaAddresses; + + private SharedBackendReplicaHarness( + MockSpannerServiceImpl backend, + HookedReplicaSpannerService defaultReplica, + String defaultAddress, + List replicas, + List replicaAddresses, + List servers) { + this.backend = backend; + this.defaultReplica = defaultReplica; + this.defaultAddress = defaultAddress; + this.replicas = replicas; + this.replicaAddresses = replicaAddresses; + this.servers = servers; + } + + static SharedBackendReplicaHarness create(int replicaCount) throws IOException { + MockSpannerServiceImpl backend = new MockSpannerServiceImpl(); + backend.setAbortProbability(0.0D); + List servers = new ArrayList<>(); + HookedReplicaSpannerService defaultReplica = new HookedReplicaSpannerService(backend); + List replicas = new ArrayList<>(); + List replicaAddresses = new ArrayList<>(); + String defaultAddress = startServer(servers, defaultReplica); + for (int i = 0; i < replicaCount; i++) { + HookedReplicaSpannerService replica = new HookedReplicaSpannerService(backend); + replicas.add(replica); + replicaAddresses.add(startServer(servers, replica)); + } + return new SharedBackendReplicaHarness( + backend, defaultReplica, defaultAddress, replicas, replicaAddresses, servers); + } + + private static String startServer(List servers, HookedReplicaSpannerService service) + throws IOException { + InetSocketAddress address = new InetSocketAddress("localhost", 0); + ServerInterceptor interceptor = + new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall( + ServerCall call, Metadata headers, ServerCallHandler next) { + service.recordRequestId( + call.getMethodDescriptor().getBareMethodName(), + headers.get(XGoogSpannerRequestId.REQUEST_ID_HEADER_KEY)); + return next.startCall(call, headers); + } + }; + Server server = + NettyServerBuilder.forAddress(address) + .addService(ServerInterceptors.intercept(service, interceptor)) + .build() + .start(); + servers.add(server); + return "localhost:" + server.getPort(); + } + + void clearRequests() { + defaultReplica.clearRequests(); + for (HookedReplicaSpannerService replica : replicas) { + replica.clearRequests(); + } + } + + @Override + public void close() throws IOException { + IOException failure = null; + for (Server server : servers) { + server.shutdown(); + } + for (Server server : servers) { + try { + server.awaitTermination(5L, java.util.concurrent.TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + if (failure == null) { + failure = new IOException("Interrupted while stopping replica harness", e); + } + } + } + if (failure != null) { + throw failure; + } + } +} diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java index 1ad3888b4f9d..1c0a277ca4f4 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/spi/v1/KeyAwareChannelTest.java @@ -56,6 +56,10 @@ import io.grpc.MethodDescriptor; import io.grpc.Status; import java.io.IOException; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.time.ZoneOffset; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -458,9 +462,11 @@ public void singleUseCommitUsesSameMutationSelectionHeuristicAsBeginTransaction( @Test public void resourceExhaustedRoutedEndpointIsAvoidedOnRetry() throws Exception { - TestHarness harness = createHarness(); + TestHarness harness = createHarness(createDeterministicCooldownTracker()); seedCache(harness, createLeaderAndReplicaCacheUpdate()); - CallOptions retryCallOptions = retryCallOptions(1L); + XGoogSpannerRequestId requestId = retryRequestId(1L); + CallOptions retryCallOptions = retryCallOptions(requestId); + String logicalRequestKey = requestId.getLogicalRequestKey(); ExecuteSqlRequest request = ExecuteSqlRequest.newBuilder() @@ -481,6 +487,12 @@ public void resourceExhaustedRoutedEndpointIsAvoidedOnRetry() throws Exception { harness.endpointCache.latestCallForAddress("server-a:1234"); firstDelegate.emitOnClose(Status.RESOURCE_EXHAUSTED, new Metadata()); + assertThat(harness.channel.isCoolingDown("server-a:1234")).isTrue(); + assertThat( + harness.channel.hasExcludedEndpointForLogicalRequest( + logicalRequestKey, "server-a:1234")) + .isTrue(); + ClientCall secondCall = harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), retryCallOptions); secondCall.start(new CapturingListener(), new Metadata()); @@ -488,6 +500,11 @@ public void resourceExhaustedRoutedEndpointIsAvoidedOnRetry() throws Exception { assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(1); assertThat(harness.endpointCache.callCountForAddress("server-b:1234")).isEqualTo(1); + assertThat(harness.channel.isCoolingDown("server-a:1234")).isTrue(); + assertThat( + harness.channel.hasExcludedEndpointForLogicalRequest( + logicalRequestKey, "server-a:1234")) + .isFalse(); } @Test @@ -590,11 +607,16 @@ public void resourceExhaustedRoutedEndpointFallsBackToDefaultWhenNoReplicaExists } @Test - public void resourceExhaustedSkipDoesNotAffectDifferentLogicalRequest() throws Exception { - TestHarness harness = createHarness(); + public void resourceExhaustedCooldownAffectsDifferentLogicalRequestButExclusionDoesNot() + throws Exception { + TestHarness harness = createHarness(createDeterministicCooldownTracker()); seedCache(harness, createLeaderAndReplicaCacheUpdate()); - CallOptions firstLogicalRequest = retryCallOptions(4L); - CallOptions secondLogicalRequest = retryCallOptions(5L); + XGoogSpannerRequestId firstRequestId = retryRequestId(4L); + XGoogSpannerRequestId secondRequestId = retryRequestId(5L); + CallOptions firstLogicalRequest = retryCallOptions(firstRequestId); + CallOptions secondLogicalRequest = retryCallOptions(secondRequestId); + String firstLogicalRequestKey = firstRequestId.getLogicalRequestKey(); + String secondLogicalRequestKey = secondRequestId.getLogicalRequestKey(); ExecuteSqlRequest request = ExecuteSqlRequest.newBuilder() @@ -613,21 +635,47 @@ public void resourceExhaustedSkipDoesNotAffectDifferentLogicalRequest() throws E harness.endpointCache.latestCallForAddress("server-a:1234"); firstDelegate.emitOnClose(Status.RESOURCE_EXHAUSTED, new Metadata()); + assertThat(harness.channel.isCoolingDown("server-a:1234")).isTrue(); + assertThat( + harness.channel.hasExcludedEndpointForLogicalRequest( + firstLogicalRequestKey, "server-a:1234")) + .isTrue(); + assertThat( + harness.channel.hasExcludedEndpointForLogicalRequest( + secondLogicalRequestKey, "server-a:1234")) + .isFalse(); + ClientCall unrelatedCall = harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), secondLogicalRequest); unrelatedCall.start(new CapturingListener(), new Metadata()); unrelatedCall.sendMessage(request); - assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(2); - assertThat(harness.endpointCache.callCountForAddress("server-b:1234")).isEqualTo(0); + assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(1); + assertThat(harness.endpointCache.callCountForAddress("server-b:1234")).isEqualTo(1); + assertThat( + harness.channel.hasExcludedEndpointForLogicalRequest( + firstLogicalRequestKey, "server-a:1234")) + .isTrue(); + assertThat( + harness.channel.hasExcludedEndpointForLogicalRequest( + secondLogicalRequestKey, "server-a:1234")) + .isFalse(); ClientCall retriedFirstCall = harness.channel.newCall(SpannerGrpc.getExecuteSqlMethod(), firstLogicalRequest); retriedFirstCall.start(new CapturingListener(), new Metadata()); retriedFirstCall.sendMessage(request); - assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(2); - assertThat(harness.endpointCache.callCountForAddress("server-b:1234")).isEqualTo(1); + assertThat(harness.endpointCache.callCountForAddress("server-a:1234")).isEqualTo(1); + assertThat(harness.endpointCache.callCountForAddress("server-b:1234")).isEqualTo(2); + assertThat( + harness.channel.hasExcludedEndpointForLogicalRequest( + firstLogicalRequestKey, "server-a:1234")) + .isFalse(); + assertThat( + harness.channel.hasExcludedEndpointForLogicalRequest( + secondLogicalRequestKey, "server-a:1234")) + .isFalse(); } @Test @@ -1235,13 +1283,28 @@ private static RecipeList parseRecipeList(String text) throws TextFormat.ParseEx } private static TestHarness createHarness() throws IOException { + return createHarness(new EndpointOverloadCooldownTracker()); + } + + private static TestHarness createHarness(EndpointOverloadCooldownTracker tracker) + throws IOException { FakeEndpointCache endpointCache = new FakeEndpointCache(DEFAULT_ADDRESS); InstantiatingGrpcChannelProvider provider = InstantiatingGrpcChannelProvider.newBuilder().setEndpoint("localhost:9999").build(); - KeyAwareChannel channel = KeyAwareChannel.create(provider, baseProvider -> endpointCache); + KeyAwareChannel channel = + KeyAwareChannel.create(provider, baseProvider -> endpointCache, tracker); return new TestHarness(channel, endpointCache, endpointCache.defaultManagedChannel()); } + private static EndpointOverloadCooldownTracker createDeterministicCooldownTracker() { + return new EndpointOverloadCooldownTracker( + Duration.ofMinutes(1), + Duration.ofMinutes(1), + Duration.ofMinutes(10), + Clock.fixed(Instant.ofEpochSecond(100), ZoneOffset.UTC), + bound -> bound - 1L); + } + private static final class TestHarness { private final KeyAwareChannel channel; private final FakeEndpointCache endpointCache; @@ -1483,9 +1546,16 @@ private static ByteString bytes(String value) { return ByteString.copyFromUtf8(value); } + private static XGoogSpannerRequestId retryRequestId(long nthRequest) { + return XGoogSpannerRequestId.of(1L, 0L, nthRequest, 0L); + } + private static CallOptions retryCallOptions(long nthRequest) { + return retryCallOptions(retryRequestId(nthRequest)); + } + + private static CallOptions retryCallOptions(XGoogSpannerRequestId requestId) { return CallOptions.DEFAULT.withOption( - XGoogSpannerRequestId.REQUEST_ID_CALL_OPTIONS_KEY, - XGoogSpannerRequestId.of(1L, 0L, nthRequest, 0L)); + XGoogSpannerRequestId.REQUEST_ID_CALL_OPTIONS_KEY, requestId); } }