Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
/*
* Copyright 2026 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.spanner.spi.v1;

import com.google.common.annotations.VisibleForTesting;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.LongUnaryOperator;

/**
* Tracks short-lived endpoint cooldowns after routed {@code RESOURCE_EXHAUSTED} failures.
*
* <p>This allows later requests to try a different replica instead of immediately routing back to
* the same overloaded endpoint.
*/
final class EndpointOverloadCooldownTracker {

@VisibleForTesting static final Duration DEFAULT_INITIAL_COOLDOWN = Duration.ofSeconds(10);
@VisibleForTesting static final Duration DEFAULT_MAX_COOLDOWN = Duration.ofMinutes(1);
@VisibleForTesting static final Duration DEFAULT_RESET_AFTER = Duration.ofMinutes(10);

@VisibleForTesting
static final class CooldownState {
private final int consecutiveFailures;
private final Instant cooldownUntil;
private final Instant lastFailureAt;

private CooldownState(int consecutiveFailures, Instant cooldownUntil, Instant lastFailureAt) {
this.consecutiveFailures = consecutiveFailures;
this.cooldownUntil = cooldownUntil;
this.lastFailureAt = lastFailureAt;
}
}

private final ConcurrentHashMap<String, CooldownState> entries = new ConcurrentHashMap<>();
private final Duration initialCooldown;
private final Duration maxCooldown;
private final Duration resetAfter;
private final Clock clock;
private final LongUnaryOperator randomLong;

EndpointOverloadCooldownTracker() {
this(
DEFAULT_INITIAL_COOLDOWN,
DEFAULT_MAX_COOLDOWN,
DEFAULT_RESET_AFTER,
Clock.systemUTC(),
bound -> ThreadLocalRandom.current().nextLong(bound));
}

@VisibleForTesting
EndpointOverloadCooldownTracker(
Duration initialCooldown,
Duration maxCooldown,
Duration resetAfter,
Clock clock,
LongUnaryOperator randomLong) {
Duration resolvedInitial =
(initialCooldown == null || initialCooldown.isZero() || initialCooldown.isNegative())
? DEFAULT_INITIAL_COOLDOWN
: initialCooldown;
Duration resolvedMax =
(maxCooldown == null || maxCooldown.isZero() || maxCooldown.isNegative())
? DEFAULT_MAX_COOLDOWN
: maxCooldown;
if (resolvedMax.compareTo(resolvedInitial) < 0) {
resolvedMax = resolvedInitial;
}
this.initialCooldown = resolvedInitial;
this.maxCooldown = resolvedMax;
this.resetAfter =
(resetAfter == null || resetAfter.isZero() || resetAfter.isNegative())
? DEFAULT_RESET_AFTER
: resetAfter;
this.clock = clock == null ? Clock.systemUTC() : clock;
this.randomLong =
randomLong == null ? bound -> ThreadLocalRandom.current().nextLong(bound) : randomLong;
}

boolean isCoolingDown(String address) {
if (address == null || address.isEmpty()) {
return false;
}
Instant now = clock.instant();
CooldownState state = entries.get(address);
if (state == null) {
return false;
}
if (state.cooldownUntil.isAfter(now)) {
return true;
}
if (Duration.between(state.lastFailureAt, now).compareTo(resetAfter) < 0) {
return false;
}
entries.remove(address, state);
CooldownState current = entries.get(address);
return current != null && current.cooldownUntil.isAfter(now);
}

void recordFailure(String address) {
if (address == null || address.isEmpty()) {
return;
}
Instant now = clock.instant();
entries.compute(
address,
(ignored, state) -> {
int consecutiveFailures = 1;
if (state != null
&& Duration.between(state.lastFailureAt, now).compareTo(resetAfter) < 0) {
consecutiveFailures = state.consecutiveFailures + 1;
}
Duration cooldown = cooldownForFailures(consecutiveFailures);
return new CooldownState(consecutiveFailures, now.plus(cooldown), now);
});
}

private Duration cooldownForFailures(int failures) {
Duration cooldown = initialCooldown;
for (int i = 1; i < failures; i++) {
if (cooldown.compareTo(maxCooldown.dividedBy(2)) > 0) {
cooldown = maxCooldown;
break;
}
cooldown = cooldown.multipliedBy(2);
}
long bound = Math.max(1L, cooldown.toMillis() + 1L);
return Duration.ofMillis(randomLong.applyAsLong(bound));
}

@VisibleForTesting
CooldownState getState(String address) {
return entries.get(address);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -431,8 +431,15 @@ public GapicSpannerRpc(final SpannerOptions options) {
&& isEnableDirectAccess;
this.readRetrySettings =
options.getSpannerStubSettings().streamingReadSettings().getRetrySettings();
this.readRetryableCodes =
Set<Code> streamingReadRetryableCodes =
options.getSpannerStubSettings().streamingReadSettings().getRetryableCodes();
this.readRetryableCodes =
enableLocationApi
? ImmutableSet.<Code>builder()
.addAll(streamingReadRetryableCodes)
.add(Code.RESOURCE_EXHAUSTED)
.build()
: streamingReadRetryableCodes;
this.executeQueryRetrySettings =
options.getSpannerStubSettings().executeStreamingSqlSettings().getRetrySettings();
this.executeQueryRetryableCodes =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.google.api.core.InternalApi;
import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider;
import com.google.cloud.spanner.XGoogSpannerRequestId;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.protobuf.ByteString;
Expand Down Expand Up @@ -94,19 +95,28 @@ final class KeyAwareChannel extends ManagedChannel {
// Bounded to prevent unbounded growth if application code does not close read-only transactions.
private final Cache<ByteString, Boolean> readOnlyTxPreferLeader =
CacheBuilder.newBuilder().maximumSize(MAX_TRACKED_READ_ONLY_TRANSACTIONS).build();
// If a routed endpoint returns RESOURCE_EXHAUSTED, the next retry attempt of that same logical
// request should avoid that endpoint once so other requests are unaffected. Bound and age out
// entries in case a caller gives up and never issues a retry.
// If a routed endpoint returns RESOURCE_EXHAUSTED or UNAVAILABLE, the next retry attempt of
// that same logical request should avoid that endpoint once so other requests are unaffected.
// Bound and age out entries in case a caller gives up and never issues a retry.
private final Cache<String, Set<String>> excludedEndpointsForLogicalRequest =
CacheBuilder.newBuilder()
.maximumSize(MAX_TRACKED_EXCLUDED_LOGICAL_REQUESTS)
.expireAfterWrite(EXCLUDED_LOGICAL_REQUEST_TTL_MINUTES, TimeUnit.MINUTES)
.build();
private final EndpointOverloadCooldownTracker endpointOverloadCooldowns;

private KeyAwareChannel(
InstantiatingGrpcChannelProvider channelProvider,
@Nullable ChannelEndpointCacheFactory endpointCacheFactory)
throws IOException {
this(channelProvider, endpointCacheFactory, new EndpointOverloadCooldownTracker());
}

private KeyAwareChannel(
InstantiatingGrpcChannelProvider channelProvider,
@Nullable ChannelEndpointCacheFactory endpointCacheFactory,
EndpointOverloadCooldownTracker endpointOverloadCooldowns)
throws IOException {
if (endpointCacheFactory == null) {
this.endpointCache = new GrpcChannelEndpointCache(channelProvider);
} else {
Expand All @@ -120,6 +130,7 @@ private KeyAwareChannel(
// would interfere with test assertions.
this.lifecycleManager =
(endpointCacheFactory == null) ? new EndpointLifecycleManager(endpointCache) : null;
this.endpointOverloadCooldowns = endpointOverloadCooldowns;
}

static KeyAwareChannel create(
Expand All @@ -129,6 +140,15 @@ static KeyAwareChannel create(
return new KeyAwareChannel(channelProvider, endpointCacheFactory);
}

@VisibleForTesting
static KeyAwareChannel create(
InstantiatingGrpcChannelProvider channelProvider,
@Nullable ChannelEndpointCacheFactory endpointCacheFactory,
EndpointOverloadCooldownTracker endpointOverloadCooldowns)
throws IOException {
return new KeyAwareChannel(channelProvider, endpointCacheFactory, endpointOverloadCooldowns);
}

private static final class ChannelFinderReference extends SoftReference<ChannelFinder> {
final String databaseId;

Expand Down Expand Up @@ -321,36 +341,61 @@ void clearTransactionAndChannelAffinity(ByteString transactionId, @Nullable Long

private void maybeExcludeEndpointOnNextCall(
@Nullable ChannelEndpoint endpoint, @Nullable String logicalRequestKey) {
if (endpoint == null || logicalRequestKey == null) {
if (endpoint == null) {
return;
}
String address = endpoint.getAddress();
if (!defaultEndpointAddress.equals(address)) {
excludedEndpointsForLogicalRequest
.asMap()
.compute(
logicalRequestKey,
(ignored, excludedEndpoints) -> {
Set<String> updated =
excludedEndpoints == null ? ConcurrentHashMap.newKeySet() : excludedEndpoints;
updated.add(address);
return updated;
});
if (defaultEndpointAddress.equals(address)) {
return;
}
endpointOverloadCooldowns.recordFailure(address);
if (logicalRequestKey == null) {
return;
}
excludedEndpointsForLogicalRequest
.asMap()
.compute(
logicalRequestKey,
(ignored, excludedEndpoints) -> {
Set<String> updated =
excludedEndpoints == null ? ConcurrentHashMap.newKeySet() : excludedEndpoints;
updated.add(address);
return updated;
});
}

private static boolean shouldExcludeEndpointOnRetry(io.grpc.Status.Code statusCode) {
return statusCode == io.grpc.Status.Code.RESOURCE_EXHAUSTED
|| statusCode == io.grpc.Status.Code.UNAVAILABLE;
}

private Predicate<String> consumeExcludedEndpointsForCurrentCall(
@Nullable String logicalRequestKey) {
if (logicalRequestKey == null) {
return address -> false;
Predicate<String> requestScopedExcluded = address -> false;
if (logicalRequestKey != null) {
Set<String> excludedEndpoints =
excludedEndpointsForLogicalRequest.asMap().remove(logicalRequestKey);
if (excludedEndpoints != null && !excludedEndpoints.isEmpty()) {
excludedEndpoints = new HashSet<>(excludedEndpoints);
requestScopedExcluded = excludedEndpoints::contains;
}
}
Predicate<String> finalRequestScopedExcluded = requestScopedExcluded;
return address ->
finalRequestScopedExcluded.test(address)
|| endpointOverloadCooldowns.isCoolingDown(address);
}

@VisibleForTesting
boolean isCoolingDown(String address) {
return endpointOverloadCooldowns.isCoolingDown(address);
}

@VisibleForTesting
boolean hasExcludedEndpointForLogicalRequest(String logicalRequestKey, String address) {
Set<String> excludedEndpoints =
excludedEndpointsForLogicalRequest.asMap().remove(logicalRequestKey);
if (excludedEndpoints == null || excludedEndpoints.isEmpty()) {
return address -> false;
}
excludedEndpoints = new HashSet<>(excludedEndpoints);
return excludedEndpoints::contains;
excludedEndpointsForLogicalRequest.getIfPresent(logicalRequestKey);
return excludedEndpoints != null && excludedEndpoints.contains(address);
}

private boolean isReadOnlyTransaction(ByteString transactionId) {
Expand Down Expand Up @@ -858,7 +903,7 @@ public void onMessage(ResponseT message) {

@Override
public void onClose(io.grpc.Status status, Metadata trailers) {
if (status.getCode() == io.grpc.Status.Code.RESOURCE_EXHAUSTED) {
if (shouldExcludeEndpointOnRetry(status.getCode())) {
call.parentChannel.maybeExcludeEndpointOnNextCall(
call.selectedEndpoint, call.logicalRequestKey);
}
Expand Down
Loading
Loading