From ff959b7788c20d26a6f5a0281e122f6aabdc0b88 Mon Sep 17 00:00:00 2001 From: Saranya Somepalli Date: Thu, 16 Apr 2026 10:58:53 -0700 Subject: [PATCH 1/2] Update codegen fixture files --- ...test-bearer-auth-client-builder-class.java | 6 - .../builder/test-client-builder-class.java | 6 - ...-client-builder-endpoints-auth-params.java | 6 - ...lient-builder-internal-defaults-class.java | 6 - ...-composed-sync-default-client-builder.java | 6 - ...env-bearer-token-client-builder-class.java | 6 - .../test-h2-service-client-builder-class.java | 6 - ...ulti-auth-sigv4a-client-builder-class.java | 6 - ...test-no-auth-ops-client-builder-class.java | 6 - ...-no-auth-service-client-builder-class.java | 6 - .../test-query-client-builder-class.java | 6 - .../test-aws-json-async-client-class.java | 216 +++++++++++++---- ...ry-compatible-json-async-client-class.java | 73 +++++- ...ery-compatible-json-sync-client-class.java | 62 +++++ .../poet/client/test-batchmanager-async.java | 64 ++++- .../client/test-cbor-async-client-class.java | 219 ++++++++++++++---- .../poet/client/test-cbor-client-class.java | 146 ++++++++++-- ...tom-context-params-async-client-class.java | 2 +- ...stom-context-params-sync-client-class.java | 103 ++++++-- .../poet/client/test-custompackage-async.java | 73 +++++- .../poet/client/test-custompackage-sync.java | 71 +++++- .../test-customservicemetadata-async.java | 73 +++++- .../test-customservicemetadata-sync.java | 64 +++++ .../client/test-endpoint-discovery-async.java | 109 ++++++++- .../client/test-endpoint-discovery-sync.java | 116 ++++++++-- .../client/test-json-async-client-class.java | 216 ++++++++++++++--- .../poet/client/test-json-client-class.java | 2 +- .../client/test-query-async-client-class.java | 217 ++++++++++++++--- .../poet/client/test-query-client-class.java | 67 +++++- .../client/test-rpcv2-async-client-class.java | 209 ++++++++++++++--- .../codegen/poet/client/test-rpcv2-sync.java | 113 ++++++++- ...gned-payload-trait-async-client-class.java | 182 ++++++++++++--- ...igned-payload-trait-sync-client-class.java | 168 ++++++++++++-- .../client/test-xml-async-client-class.java | 148 ++++++++++-- .../poet/client/test-xml-client-class.java | 119 +++++++++- ...esolver-utils-with-endpointsbasedauth.java | 90 +++---- ...t-resolver-utils-with-multiauthsigv4a.java | 24 +- .../poet/rules/endpoint-resolver-utils.java | 93 ++++---- 38 files changed, 2577 insertions(+), 528 deletions(-) diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-bearer-auth-client-builder-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-bearer-auth-client-builder-class.java index ee8f9a73d3e5..1b7d802857d9 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-bearer-auth-client-builder-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-bearer-auth-client-builder-class.java @@ -32,10 +32,7 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeProvider; -import software.amazon.awssdk.services.json.auth.scheme.internal.JsonAuthSchemeInterceptor; import software.amazon.awssdk.services.json.endpoints.JsonEndpointProvider; -import software.amazon.awssdk.services.json.endpoints.internal.JsonRequestSetEndpointInterceptor; -import software.amazon.awssdk.services.json.endpoints.internal.JsonResolveEndpointInterceptor; import software.amazon.awssdk.services.json.internal.JsonServiceClientConfigurationBuilder; import software.amazon.awssdk.utils.CollectionUtils; import software.amazon.awssdk.utils.Validate; @@ -74,9 +71,6 @@ protected final SdkClientConfiguration mergeServiceDefaults(SdkClientConfigurati @Override protected final SdkClientConfiguration finalizeServiceConfiguration(SdkClientConfiguration config) { List endpointInterceptors = new ArrayList<>(); - endpointInterceptors.add(new JsonAuthSchemeInterceptor()); - endpointInterceptors.add(new JsonResolveEndpointInterceptor()); - endpointInterceptors.add(new JsonRequestSetEndpointInterceptor()); ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); List interceptors = interceptorFactory .getInterceptors("software/amazon/awssdk/services/json/execution.interceptors"); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-client-builder-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-client-builder-class.java index a0bdac67d04d..cdba2ca0888c 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-client-builder-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-client-builder-class.java @@ -41,11 +41,8 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeProvider; -import software.amazon.awssdk.services.json.auth.scheme.internal.JsonAuthSchemeInterceptor; import software.amazon.awssdk.services.json.endpoints.JsonClientContextParams; import software.amazon.awssdk.services.json.endpoints.JsonEndpointProvider; -import software.amazon.awssdk.services.json.endpoints.internal.JsonRequestSetEndpointInterceptor; -import software.amazon.awssdk.services.json.endpoints.internal.JsonResolveEndpointInterceptor; import software.amazon.awssdk.services.json.internal.JsonServiceClientConfigurationBuilder; import software.amazon.awssdk.utils.AttributeMap; import software.amazon.awssdk.utils.CollectionUtils; @@ -86,9 +83,6 @@ protected final SdkClientConfiguration mergeServiceDefaults(SdkClientConfigurati @Override protected final SdkClientConfiguration finalizeServiceConfiguration(SdkClientConfiguration config) { List endpointInterceptors = new ArrayList<>(); - endpointInterceptors.add(new JsonAuthSchemeInterceptor()); - endpointInterceptors.add(new JsonResolveEndpointInterceptor()); - endpointInterceptors.add(new JsonRequestSetEndpointInterceptor()); ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); List interceptors = interceptorFactory .getInterceptors("software/amazon/awssdk/services/json/execution.interceptors"); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-client-builder-endpoints-auth-params.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-client-builder-endpoints-auth-params.java index 360d3664eaad..0fe341a73bca 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-client-builder-endpoints-auth-params.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-client-builder-endpoints-auth-params.java @@ -40,11 +40,8 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.query.auth.scheme.QueryAuthSchemeProvider; -import software.amazon.awssdk.services.query.auth.scheme.internal.QueryAuthSchemeInterceptor; import software.amazon.awssdk.services.query.endpoints.QueryClientContextParams; import software.amazon.awssdk.services.query.endpoints.QueryEndpointProvider; -import software.amazon.awssdk.services.query.endpoints.internal.QueryRequestSetEndpointInterceptor; -import software.amazon.awssdk.services.query.endpoints.internal.QueryResolveEndpointInterceptor; import software.amazon.awssdk.services.query.internal.QueryServiceClientConfigurationBuilder; import software.amazon.awssdk.utils.CollectionUtils; import software.amazon.awssdk.utils.Validate; @@ -83,9 +80,6 @@ protected final SdkClientConfiguration mergeServiceDefaults(SdkClientConfigurati @Override protected final SdkClientConfiguration finalizeServiceConfiguration(SdkClientConfiguration config) { List endpointInterceptors = new ArrayList<>(); - endpointInterceptors.add(new QueryAuthSchemeInterceptor()); - endpointInterceptors.add(new QueryResolveEndpointInterceptor()); - endpointInterceptors.add(new QueryRequestSetEndpointInterceptor()); ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); List interceptors = interceptorFactory .getInterceptors("software/amazon/awssdk/services/query/execution.interceptors"); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-client-builder-internal-defaults-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-client-builder-internal-defaults-class.java index 9b143b9ccd69..f8caa8885f3a 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-client-builder-internal-defaults-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-client-builder-internal-defaults-class.java @@ -29,10 +29,7 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeProvider; -import software.amazon.awssdk.services.json.auth.scheme.internal.JsonAuthSchemeInterceptor; import software.amazon.awssdk.services.json.endpoints.JsonEndpointProvider; -import software.amazon.awssdk.services.json.endpoints.internal.JsonRequestSetEndpointInterceptor; -import software.amazon.awssdk.services.json.endpoints.internal.JsonResolveEndpointInterceptor; import software.amazon.awssdk.services.json.internal.JsonServiceClientConfigurationBuilder; import software.amazon.awssdk.utils.CollectionUtils; @@ -75,9 +72,6 @@ protected final SdkClientConfiguration mergeInternalDefaults(SdkClientConfigurat @Override protected final SdkClientConfiguration finalizeServiceConfiguration(SdkClientConfiguration config) { List endpointInterceptors = new ArrayList<>(); - endpointInterceptors.add(new JsonAuthSchemeInterceptor()); - endpointInterceptors.add(new JsonResolveEndpointInterceptor()); - endpointInterceptors.add(new JsonRequestSetEndpointInterceptor()); ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); List interceptors = interceptorFactory .getInterceptors("software/amazon/awssdk/services/json/execution.interceptors"); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-composed-sync-default-client-builder.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-composed-sync-default-client-builder.java index 117e19038881..168c15a64e31 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-composed-sync-default-client-builder.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-composed-sync-default-client-builder.java @@ -37,11 +37,8 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeProvider; -import software.amazon.awssdk.services.json.auth.scheme.internal.JsonAuthSchemeInterceptor; import software.amazon.awssdk.services.json.endpoints.JsonClientContextParams; import software.amazon.awssdk.services.json.endpoints.JsonEndpointProvider; -import software.amazon.awssdk.services.json.endpoints.internal.JsonRequestSetEndpointInterceptor; -import software.amazon.awssdk.services.json.endpoints.internal.JsonResolveEndpointInterceptor; import software.amazon.awssdk.services.json.internal.JsonServiceClientConfigurationBuilder; import software.amazon.awssdk.utils.CollectionUtils; import software.amazon.awssdk.utils.Validate; @@ -81,9 +78,6 @@ protected final SdkClientConfiguration mergeServiceDefaults(SdkClientConfigurati @Override protected final SdkClientConfiguration finalizeServiceConfiguration(SdkClientConfiguration config) { List endpointInterceptors = new ArrayList<>(); - endpointInterceptors.add(new JsonAuthSchemeInterceptor()); - endpointInterceptors.add(new JsonResolveEndpointInterceptor()); - endpointInterceptors.add(new JsonRequestSetEndpointInterceptor()); ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); List interceptors = interceptorFactory .getInterceptors("software/amazon/awssdk/services/json/execution.interceptors"); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-env-bearer-token-client-builder-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-env-bearer-token-client-builder-class.java index 48ecf08535fa..615f6a76d1a5 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-env-bearer-token-client-builder-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-env-bearer-token-client-builder-class.java @@ -36,10 +36,7 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeProvider; -import software.amazon.awssdk.services.json.auth.scheme.internal.JsonAuthSchemeInterceptor; import software.amazon.awssdk.services.json.endpoints.JsonEndpointProvider; -import software.amazon.awssdk.services.json.endpoints.internal.JsonRequestSetEndpointInterceptor; -import software.amazon.awssdk.services.json.endpoints.internal.JsonResolveEndpointInterceptor; import software.amazon.awssdk.services.json.internal.EnvironmentTokenSystemSettings; import software.amazon.awssdk.services.json.internal.JsonServiceClientConfigurationBuilder; import software.amazon.awssdk.utils.CollectionUtils; @@ -91,9 +88,6 @@ protected final SdkClientConfiguration mergeServiceDefaults(SdkClientConfigurati @Override protected final SdkClientConfiguration finalizeServiceConfiguration(SdkClientConfiguration config) { List endpointInterceptors = new ArrayList<>(); - endpointInterceptors.add(new JsonAuthSchemeInterceptor()); - endpointInterceptors.add(new JsonResolveEndpointInterceptor()); - endpointInterceptors.add(new JsonRequestSetEndpointInterceptor()); ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); List interceptors = interceptorFactory .getInterceptors("software/amazon/awssdk/services/json/execution.interceptors"); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-h2-service-client-builder-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-h2-service-client-builder-class.java index 76c1cd2fc7eb..2e5198dac4b6 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-h2-service-client-builder-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-h2-service-client-builder-class.java @@ -32,10 +32,7 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.h2.auth.scheme.H2AuthSchemeProvider; -import software.amazon.awssdk.services.h2.auth.scheme.internal.H2AuthSchemeInterceptor; import software.amazon.awssdk.services.h2.endpoints.H2EndpointProvider; -import software.amazon.awssdk.services.h2.endpoints.internal.H2RequestSetEndpointInterceptor; -import software.amazon.awssdk.services.h2.endpoints.internal.H2ResolveEndpointInterceptor; import software.amazon.awssdk.services.h2.internal.H2ServiceClientConfigurationBuilder; import software.amazon.awssdk.utils.AttributeMap; import software.amazon.awssdk.utils.CollectionUtils; @@ -71,9 +68,6 @@ protected final SdkClientConfiguration mergeServiceDefaults(SdkClientConfigurati @Override protected final SdkClientConfiguration finalizeServiceConfiguration(SdkClientConfiguration config) { List endpointInterceptors = new ArrayList<>(); - endpointInterceptors.add(new H2AuthSchemeInterceptor()); - endpointInterceptors.add(new H2ResolveEndpointInterceptor()); - endpointInterceptors.add(new H2RequestSetEndpointInterceptor()); ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); List interceptors = interceptorFactory .getInterceptors("software/amazon/awssdk/services/h2/execution.interceptors"); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-multi-auth-sigv4a-client-builder-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-multi-auth-sigv4a-client-builder-class.java index 75faf2cad7a8..fcc5f72b4e25 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-multi-auth-sigv4a-client-builder-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-multi-auth-sigv4a-client-builder-class.java @@ -31,10 +31,7 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.database.auth.scheme.DatabaseAuthSchemeProvider; -import software.amazon.awssdk.services.database.auth.scheme.internal.DatabaseAuthSchemeInterceptor; import software.amazon.awssdk.services.database.endpoints.DatabaseEndpointProvider; -import software.amazon.awssdk.services.database.endpoints.internal.DatabaseRequestSetEndpointInterceptor; -import software.amazon.awssdk.services.database.endpoints.internal.DatabaseResolveEndpointInterceptor; import software.amazon.awssdk.services.database.internal.DatabaseServiceClientConfigurationBuilder; import software.amazon.awssdk.utils.CollectionUtils; @@ -70,9 +67,6 @@ protected final SdkClientConfiguration mergeServiceDefaults(SdkClientConfigurati @Override protected final SdkClientConfiguration finalizeServiceConfiguration(SdkClientConfiguration config) { List endpointInterceptors = new ArrayList<>(); - endpointInterceptors.add(new DatabaseAuthSchemeInterceptor()); - endpointInterceptors.add(new DatabaseResolveEndpointInterceptor()); - endpointInterceptors.add(new DatabaseRequestSetEndpointInterceptor()); ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); List interceptors = interceptorFactory .getInterceptors("software/amazon/awssdk/services/database/execution.interceptors"); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-no-auth-ops-client-builder-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-no-auth-ops-client-builder-class.java index 72d4f526bfb3..5da93144892e 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-no-auth-ops-client-builder-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-no-auth-ops-client-builder-class.java @@ -33,10 +33,7 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.database.auth.scheme.DatabaseAuthSchemeProvider; -import software.amazon.awssdk.services.database.auth.scheme.internal.DatabaseAuthSchemeInterceptor; import software.amazon.awssdk.services.database.endpoints.DatabaseEndpointProvider; -import software.amazon.awssdk.services.database.endpoints.internal.DatabaseRequestSetEndpointInterceptor; -import software.amazon.awssdk.services.database.endpoints.internal.DatabaseResolveEndpointInterceptor; import software.amazon.awssdk.services.database.internal.DatabaseServiceClientConfigurationBuilder; import software.amazon.awssdk.utils.CollectionUtils; import software.amazon.awssdk.utils.Validate; @@ -76,9 +73,6 @@ protected final SdkClientConfiguration mergeServiceDefaults(SdkClientConfigurati @Override protected final SdkClientConfiguration finalizeServiceConfiguration(SdkClientConfiguration config) { List endpointInterceptors = new ArrayList<>(); - endpointInterceptors.add(new DatabaseAuthSchemeInterceptor()); - endpointInterceptors.add(new DatabaseResolveEndpointInterceptor()); - endpointInterceptors.add(new DatabaseRequestSetEndpointInterceptor()); ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); List interceptors = interceptorFactory .getInterceptors("software/amazon/awssdk/services/database/execution.interceptors"); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-no-auth-service-client-builder-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-no-auth-service-client-builder-class.java index 0be9c031d828..05f0b72afa2b 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-no-auth-service-client-builder-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-no-auth-service-client-builder-class.java @@ -27,10 +27,7 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.database.auth.scheme.DatabaseAuthSchemeProvider; -import software.amazon.awssdk.services.database.auth.scheme.internal.DatabaseAuthSchemeInterceptor; import software.amazon.awssdk.services.database.endpoints.DatabaseEndpointProvider; -import software.amazon.awssdk.services.database.endpoints.internal.DatabaseRequestSetEndpointInterceptor; -import software.amazon.awssdk.services.database.endpoints.internal.DatabaseResolveEndpointInterceptor; import software.amazon.awssdk.services.database.internal.DatabaseServiceClientConfigurationBuilder; import software.amazon.awssdk.utils.CollectionUtils; @@ -66,9 +63,6 @@ protected final SdkClientConfiguration mergeServiceDefaults(SdkClientConfigurati @Override protected final SdkClientConfiguration finalizeServiceConfiguration(SdkClientConfiguration config) { List endpointInterceptors = new ArrayList<>(); - endpointInterceptors.add(new DatabaseAuthSchemeInterceptor()); - endpointInterceptors.add(new DatabaseResolveEndpointInterceptor()); - endpointInterceptors.add(new DatabaseRequestSetEndpointInterceptor()); ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); List interceptors = interceptorFactory .getInterceptors("software/amazon/awssdk/services/database/execution.interceptors"); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-query-client-builder-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-query-client-builder-class.java index 19b8d5abbae1..db0dae50266e 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-query-client-builder-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-query-client-builder-class.java @@ -38,11 +38,8 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.query.auth.scheme.QueryAuthSchemeProvider; -import software.amazon.awssdk.services.query.auth.scheme.internal.QueryAuthSchemeInterceptor; import software.amazon.awssdk.services.query.endpoints.QueryClientContextParams; import software.amazon.awssdk.services.query.endpoints.QueryEndpointProvider; -import software.amazon.awssdk.services.query.endpoints.internal.QueryRequestSetEndpointInterceptor; -import software.amazon.awssdk.services.query.endpoints.internal.QueryResolveEndpointInterceptor; import software.amazon.awssdk.services.query.internal.QueryServiceClientConfigurationBuilder; import software.amazon.awssdk.utils.CollectionUtils; import software.amazon.awssdk.utils.Validate; @@ -81,9 +78,6 @@ protected final SdkClientConfiguration mergeServiceDefaults(SdkClientConfigurati @Override protected final SdkClientConfiguration finalizeServiceConfiguration(SdkClientConfiguration config) { List endpointInterceptors = new ArrayList<>(); - endpointInterceptors.add(new QueryAuthSchemeInterceptor()); - endpointInterceptors.add(new QueryResolveEndpointInterceptor()); - endpointInterceptors.add(new QueryRequestSetEndpointInterceptor()); ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); List interceptors = interceptorFactory .getInterceptors("software/amazon/awssdk/services/query/execution.interceptors"); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-aws-json-async-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-aws-json-async-client-class.java index a7d94b064cd8..0c9eb9d0326b 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-aws-json-async-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-aws-json-async-client-class.java @@ -7,6 +7,7 @@ import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.Executor; import java.util.function.Consumer; import java.util.function.Function; @@ -15,8 +16,12 @@ import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; import software.amazon.awssdk.awscore.client.handler.AwsClientHandlerUtils; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.eventstream.EventStreamAsyncResponseTransformer; import software.amazon.awssdk.awscore.eventstream.EventStreamTaggedUnionJsonMarshaller; import software.amazon.awssdk.awscore.eventstream.EventStreamTaggedUnionPojoSupplier; @@ -29,6 +34,7 @@ import software.amazon.awssdk.core.SdkPojoBuilder; import software.amazon.awssdk.core.SdkRequest; import software.amazon.awssdk.core.SdkResponse; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.core.async.AsyncResponseTransformerUtils; @@ -40,7 +46,9 @@ import software.amazon.awssdk.core.client.handler.AsyncClientHandler; import software.amazon.awssdk.core.client.handler.AttachHttpMetadataResponseHandler; import software.amazon.awssdk.core.client.handler.ClientExecutionParams; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.interceptor.trait.HttpChecksumRequired; import software.amazon.awssdk.core.internal.interceptor.trait.RequestCompression; @@ -48,6 +56,8 @@ import software.amazon.awssdk.core.protocol.VoidSdkResponse; import software.amazon.awssdk.core.retry.RetryMode; import software.amazon.awssdk.core.runtime.transform.AsyncStreamingRequestMarshaller; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -57,6 +67,11 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeParams; +import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeProvider; +import software.amazon.awssdk.services.json.endpoints.JsonEndpointParams; +import software.amazon.awssdk.services.json.endpoints.JsonEndpointProvider; +import software.amazon.awssdk.services.json.endpoints.internal.JsonEndpointResolverUtils; import software.amazon.awssdk.services.json.internal.JsonServiceClientConfigurationBuilder; import software.amazon.awssdk.services.json.internal.ServiceVersionInfo; import software.amazon.awssdk.services.json.model.APostOperationRequest; @@ -118,6 +133,7 @@ import software.amazon.awssdk.utils.CompletableFutureUtils; import software.amazon.awssdk.utils.HostnameValidator; import software.amazon.awssdk.utils.Pair; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link JsonAsyncClient}. @@ -216,10 +232,16 @@ public CompletableFuture aPostOperation(APostOperationRe CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("APostOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new APostOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperation")) .hostPrefixExpression(resolvedHostExpression).withInput(aPostOperationRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -294,10 +316,16 @@ public CompletableFuture aPostOperationWithOut CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperationWithOutput").withProtocolMetadata(protocolMetadata) + .withOperationName("APostOperationWithOutput") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new APostOperationWithOutputRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperationWithOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperationWithOutput")) .withInput(aPostOperationWithOutputRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -400,13 +428,20 @@ public CompletableFuture eventStreamOperation(EventStreamOperationRequest CompletableFuture executeFuture = clientHandler.execute( new ClientExecutionParams() - .withOperationName("EventStreamOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("EventStreamOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new EventStreamOperationRequestMarshaller(protocolFactory)) - .withAsyncRequestBody(AsyncRequestBody.fromPublisher(adapted)).withFullDuplex(true) - .withInitialRequestEvent(true).withResponseHandler(voidResponseHandler) - .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) - .withMetricCollector(apiCallMetricCollector).withInput(eventStreamOperationRequest), - asyncResponseTransformer); + .withAsyncRequestBody(AsyncRequestBody.fromPublisher(adapted)) + .withFullDuplex(true) + .withInitialRequestEvent(true) + .withResponseHandler(voidResponseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "EventStreamOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "EventStreamOperation")) + .withInput(eventStreamOperationRequest), asyncResponseTransformer); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { if (e != null) { try { @@ -492,11 +527,18 @@ public CompletableFuture eventStreamO CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("EventStreamOperationWithOnlyInput").withProtocolMetadata(protocolMetadata) + .withOperationName("EventStreamOperationWithOnlyInput") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new EventStreamOperationWithOnlyInputRequestMarshaller(protocolFactory)) - .withAsyncRequestBody(AsyncRequestBody.fromPublisher(adapted)).withInitialRequestEvent(true) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withAsyncRequestBody(AsyncRequestBody.fromPublisher(adapted)) + .withInitialRequestEvent(true) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "EventStreamOperationWithOnlyInput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "EventStreamOperationWithOnlyInput")) .withInput(eventStreamOperationWithOnlyInputRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -596,10 +638,16 @@ public CompletableFuture eventStreamOperationWithOnlyOutput( CompletableFuture executeFuture = clientHandler.execute( new ClientExecutionParams() - .withOperationName("EventStreamOperationWithOnlyOutput").withProtocolMetadata(protocolMetadata) + .withOperationName("EventStreamOperationWithOnlyOutput") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new EventStreamOperationWithOnlyOutputRequestMarshaller(protocolFactory)) - .withResponseHandler(voidResponseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(voidResponseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "EventStreamOperationWithOnlyOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "EventStreamOperationWithOnlyOutput")) .withInput(eventStreamOperationWithOnlyOutputRequest), asyncResponseTransformer); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { if (e != null) { @@ -682,10 +730,16 @@ public CompletableFuture getWithoutRequiredMe CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("GetWithoutRequiredMembers").withProtocolMetadata(protocolMetadata) + .withOperationName("GetWithoutRequiredMembers") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new GetWithoutRequiredMembersRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "GetWithoutRequiredMembers", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetWithoutRequiredMembers")) .withInput(getWithoutRequiredMembersRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -763,6 +817,9 @@ public CompletableFuture operationWithChe .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithChecksumRequired", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithChecksumRequired")) .putExecutionAttribute(SdkInternalExecutionAttribute.HTTP_CHECKSUM_REQUIRED, HttpChecksumRequired.create()).withInput(operationWithChecksumRequiredRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { @@ -833,10 +890,16 @@ public CompletableFuture operationWithNoneAut CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithNoneAuthType").withProtocolMetadata(protocolMetadata) + .withOperationName("OperationWithNoneAuthType") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OperationWithNoneAuthTypeRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithNoneAuthType", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithNoneAuthType")) .withInput(operationWithNoneAuthTypeRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -914,6 +977,9 @@ public CompletableFuture operationWithR .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithRequestCompression", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithRequestCompression")) .putExecutionAttribute(SdkInternalExecutionAttribute.REQUEST_COMPRESSION, RequestCompression.builder().encodings("gzip").isStreaming(false).build()) .withInput(operationWithRequestCompressionRequest)); @@ -986,10 +1052,16 @@ public CompletableFuture paginatedOpera CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("PaginatedOperationWithResultKey").withProtocolMetadata(protocolMetadata) + .withOperationName("PaginatedOperationWithResultKey") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new PaginatedOperationWithResultKeyRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PaginatedOperationWithResultKey", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PaginatedOperationWithResultKey")) .withInput(paginatedOperationWithResultKeyRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -1060,10 +1132,16 @@ public CompletableFuture paginatedOp CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("PaginatedOperationWithoutResultKey").withProtocolMetadata(protocolMetadata) + .withOperationName("PaginatedOperationWithoutResultKey") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new PaginatedOperationWithoutResultKeyRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PaginatedOperationWithoutResultKey", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PaginatedOperationWithoutResultKey")) .withInput(paginatedOperationWithoutResultKeyRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -1143,10 +1221,15 @@ public CompletableFuture streamingInputOperatio .withMarshaller( AsyncStreamingRequestMarshaller.builder() .delegateMarshaller(new StreamingInputOperationRequestMarshaller(protocolFactory)) - .asyncRequestBody(requestBody).build()).withResponseHandler(responseHandler) - .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) - .withMetricCollector(apiCallMetricCollector).withAsyncRequestBody(requestBody) - .withInput(streamingInputOperationRequest)); + .asyncRequestBody(requestBody).build()) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingInputOperation")) + .withAsyncRequestBody(requestBody).withInput(streamingInputOperationRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -1238,8 +1321,13 @@ public CompletableFuture streamingInputOutputOperation( .delegateMarshaller( new StreamingInputOutputOperationRequestMarshaller(protocolFactory)) .asyncRequestBody(requestBody).transferEncoding(true).build()) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOutputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingInputOutputOperation")) .withAsyncRequestBody(requestBody).withAsyncResponseTransformer(asyncResponseTransformer) .withInput(streamingInputOutputOperationRequest), asyncResponseTransformer); AsyncResponseTransformer finalAsyncResponseTransformer = asyncResponseTransformer; @@ -1330,10 +1418,16 @@ public CompletableFuture streamingOutputOperation( CompletableFuture executeFuture = clientHandler.execute( new ClientExecutionParams() - .withOperationName("StreamingOutputOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("StreamingOutputOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new StreamingOutputOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingOutputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingOutputOperation")) .withAsyncResponseTransformer(asyncResponseTransformer).withInput(streamingOutputOperationRequest), asyncResponseTransformer); AsyncResponseTransformer finalAsyncResponseTransformer = asyncResponseTransformer; @@ -1387,6 +1481,48 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + JsonAuthSchemeProvider authSchemeProvider = Validate.isInstanceOf(JsonAuthSchemeProvider.class, + clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of JsonAuthSchemeProvider"); + JsonAuthSchemeParams.Builder paramsBuilder = JsonAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + JsonEndpointProvider provider = (JsonEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + JsonEndpointParams endpointParams = JsonEndpointResolverUtils.ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = JsonEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = JsonEndpointResolverUtils.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, + selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + JsonEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-aws-query-compatible-json-async-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-aws-query-compatible-json-async-client-class.java index 96f890a789a1..04eb7325e084 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-aws-query-compatible-json-async-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-aws-query-compatible-json-async-client-class.java @@ -6,13 +6,18 @@ import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -20,14 +25,20 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.client.handler.AsyncClientHandler; import software.amazon.awssdk.core.client.handler.ClientExecutionParams; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -37,6 +48,11 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.querytojsoncompatible.auth.scheme.QueryToJsonCompatibleAuthSchemeParams; +import software.amazon.awssdk.services.querytojsoncompatible.auth.scheme.QueryToJsonCompatibleAuthSchemeProvider; +import software.amazon.awssdk.services.querytojsoncompatible.endpoints.QueryToJsonCompatibleEndpointParams; +import software.amazon.awssdk.services.querytojsoncompatible.endpoints.QueryToJsonCompatibleEndpointProvider; +import software.amazon.awssdk.services.querytojsoncompatible.endpoints.internal.QueryToJsonCompatibleEndpointResolverUtils; import software.amazon.awssdk.services.querytojsoncompatible.internal.QueryToJsonCompatibleServiceClientConfigurationBuilder; import software.amazon.awssdk.services.querytojsoncompatible.internal.ServiceVersionInfo; import software.amazon.awssdk.services.querytojsoncompatible.model.APostOperationRequest; @@ -46,6 +62,7 @@ import software.amazon.awssdk.services.querytojsoncompatible.transform.APostOperationRequestMarshaller; import software.amazon.awssdk.utils.CompletableFutureUtils; import software.amazon.awssdk.utils.HostnameValidator; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link QueryToJsonCompatibleAsyncClient}. @@ -133,10 +150,16 @@ public CompletableFuture aPostOperation(APostOperationRe CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("APostOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new APostOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperation")) .hostPrefixExpression(resolvedHostExpression).withInput(aPostOperationRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -180,6 +203,50 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + QueryToJsonCompatibleAuthSchemeProvider authSchemeProvider = Validate.isInstanceOf( + QueryToJsonCompatibleAuthSchemeProvider.class, clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of QueryToJsonCompatibleAuthSchemeProvider"); + QueryToJsonCompatibleAuthSchemeParams.Builder paramsBuilder = QueryToJsonCompatibleAuthSchemeParams.builder().operation( + operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + QueryToJsonCompatibleEndpointProvider provider = (QueryToJsonCompatibleEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + QueryToJsonCompatibleEndpointParams endpointParams = QueryToJsonCompatibleEndpointResolverUtils.ruleParams(request, + executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = QueryToJsonCompatibleEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = QueryToJsonCompatibleEndpointResolverUtils.authSchemeWithEndpointSignerProperties( + endpointAuthSchemes, selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + QueryToJsonCompatibleEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-aws-query-compatible-json-sync-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-aws-query-compatible-json-sync-client-class.java index d4fb640000aa..605351f3612d 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-aws-query-compatible-json-sync-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-aws-query-compatible-json-sync-client-class.java @@ -3,11 +3,16 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -15,6 +20,7 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; @@ -22,8 +28,12 @@ import software.amazon.awssdk.core.client.handler.SyncClientHandler; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -33,6 +43,11 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.querytojsoncompatible.auth.scheme.QueryToJsonCompatibleAuthSchemeParams; +import software.amazon.awssdk.services.querytojsoncompatible.auth.scheme.QueryToJsonCompatibleAuthSchemeProvider; +import software.amazon.awssdk.services.querytojsoncompatible.endpoints.QueryToJsonCompatibleEndpointParams; +import software.amazon.awssdk.services.querytojsoncompatible.endpoints.QueryToJsonCompatibleEndpointProvider; +import software.amazon.awssdk.services.querytojsoncompatible.endpoints.internal.QueryToJsonCompatibleEndpointResolverUtils; import software.amazon.awssdk.services.querytojsoncompatible.internal.QueryToJsonCompatibleServiceClientConfigurationBuilder; import software.amazon.awssdk.services.querytojsoncompatible.internal.ServiceVersionInfo; import software.amazon.awssdk.services.querytojsoncompatible.model.APostOperationRequest; @@ -42,6 +57,7 @@ import software.amazon.awssdk.services.querytojsoncompatible.transform.APostOperationRequestMarshaller; import software.amazon.awssdk.utils.HostnameValidator; import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link QueryToJsonCompatibleClient}. @@ -129,6 +145,8 @@ public APostOperationResponse aPostOperation(APostOperationRequest aPostOperatio .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .hostPrefixExpression(resolvedHostExpression).withRequestConfiguration(clientConfiguration) .withInput(aPostOperationRequest).withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "APostOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperation")) .withMarshaller(new APostOperationRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -155,6 +173,50 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + QueryToJsonCompatibleAuthSchemeProvider authSchemeProvider = Validate.isInstanceOf( + QueryToJsonCompatibleAuthSchemeProvider.class, clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of QueryToJsonCompatibleAuthSchemeProvider"); + QueryToJsonCompatibleAuthSchemeParams.Builder paramsBuilder = QueryToJsonCompatibleAuthSchemeParams.builder().operation( + operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + QueryToJsonCompatibleEndpointProvider provider = (QueryToJsonCompatibleEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + QueryToJsonCompatibleEndpointParams endpointParams = QueryToJsonCompatibleEndpointResolverUtils.ruleParams(request, + executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = QueryToJsonCompatibleEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = QueryToJsonCompatibleEndpointResolverUtils.authSchemeWithEndpointSignerProperties( + endpointAuthSchemes, selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + QueryToJsonCompatibleEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private HttpResponseHandler createErrorResponseHandler(BaseAwsJsonProtocolFactory protocolFactory, JsonOperationMetadata operationMetadata, Function> exceptionMetadataMapper) { return protocolFactory.createErrorResponseHandler(operationMetadata, exceptionMetadataMapper); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-batchmanager-async.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-batchmanager-async.java index 1db4ed5a574f..5ede080c6e5e 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-batchmanager-async.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-batchmanager-async.java @@ -6,6 +6,7 @@ import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.ScheduledExecutorService; import java.util.function.Consumer; import java.util.function.Function; @@ -13,7 +14,11 @@ import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -21,14 +26,20 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.client.handler.AsyncClientHandler; import software.amazon.awssdk.core.client.handler.ClientExecutionParams; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -38,7 +49,12 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.batchmanagertest.auth.scheme.BatchManagerTestAuthSchemeParams; +import software.amazon.awssdk.services.batchmanagertest.auth.scheme.BatchManagerTestAuthSchemeProvider; import software.amazon.awssdk.services.batchmanagertest.batchmanager.BatchManagerTestAsyncBatchManager; +import software.amazon.awssdk.services.batchmanagertest.endpoints.BatchManagerTestEndpointParams; +import software.amazon.awssdk.services.batchmanagertest.endpoints.BatchManagerTestEndpointProvider; +import software.amazon.awssdk.services.batchmanagertest.endpoints.internal.BatchManagerTestEndpointResolverUtils; import software.amazon.awssdk.services.batchmanagertest.internal.BatchManagerTestServiceClientConfigurationBuilder; import software.amazon.awssdk.services.batchmanagertest.internal.ServiceVersionInfo; import software.amazon.awssdk.services.batchmanagertest.model.BatchManagerTestException; @@ -46,6 +62,7 @@ import software.amazon.awssdk.services.batchmanagertest.model.SendRequestResponse; import software.amazon.awssdk.services.batchmanagertest.transform.SendRequestRequestMarshaller; import software.amazon.awssdk.utils.CompletableFutureUtils; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link BatchManagerTestAsyncClient}. @@ -129,7 +146,8 @@ public CompletableFuture sendRequest(SendRequestRequest sen .withMarshaller(new SendRequestRequestMarshaller(protocolFactory)) .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) - .withInput(sendRequestRequest)); + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "SendRequest", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "SendRequest")).withInput(sendRequestRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -177,6 +195,50 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + BatchManagerTestAuthSchemeProvider authSchemeProvider = Validate.isInstanceOf(BatchManagerTestAuthSchemeProvider.class, + clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of BatchManagerTestAuthSchemeProvider"); + BatchManagerTestAuthSchemeParams.Builder paramsBuilder = BatchManagerTestAuthSchemeParams.builder().operation( + operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + BatchManagerTestEndpointProvider provider = (BatchManagerTestEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + BatchManagerTestEndpointParams endpointParams = BatchManagerTestEndpointResolverUtils.ruleParams(request, + executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = BatchManagerTestEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = BatchManagerTestEndpointResolverUtils.authSchemeWithEndpointSignerProperties( + endpointAuthSchemes, selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + BatchManagerTestEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-cbor-async-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-cbor-async-client-class.java index 073ba89daf3e..4f1b4180cdcf 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-cbor-async-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-cbor-async-client-class.java @@ -7,6 +7,7 @@ import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.Executor; import java.util.function.Consumer; import java.util.function.Function; @@ -15,8 +16,12 @@ import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; import software.amazon.awssdk.awscore.client.handler.AwsClientHandlerUtils; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.eventstream.EventStreamAsyncResponseTransformer; import software.amazon.awssdk.awscore.eventstream.EventStreamTaggedUnionJsonMarshaller; import software.amazon.awssdk.awscore.eventstream.EventStreamTaggedUnionPojoSupplier; @@ -29,6 +34,7 @@ import software.amazon.awssdk.core.SdkPojoBuilder; import software.amazon.awssdk.core.SdkRequest; import software.amazon.awssdk.core.SdkResponse; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.core.async.AsyncResponseTransformerUtils; @@ -40,7 +46,9 @@ import software.amazon.awssdk.core.client.handler.AsyncClientHandler; import software.amazon.awssdk.core.client.handler.AttachHttpMetadataResponseHandler; import software.amazon.awssdk.core.client.handler.ClientExecutionParams; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.interceptor.trait.HttpChecksumRequired; import software.amazon.awssdk.core.internal.interceptor.trait.RequestCompression; @@ -48,6 +56,8 @@ import software.amazon.awssdk.core.protocol.VoidSdkResponse; import software.amazon.awssdk.core.retry.RetryMode; import software.amazon.awssdk.core.runtime.transform.AsyncStreamingRequestMarshaller; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -58,6 +68,11 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeParams; +import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeProvider; +import software.amazon.awssdk.services.json.endpoints.JsonEndpointParams; +import software.amazon.awssdk.services.json.endpoints.JsonEndpointProvider; +import software.amazon.awssdk.services.json.endpoints.internal.JsonEndpointResolverUtils; import software.amazon.awssdk.services.json.internal.JsonServiceClientConfigurationBuilder; import software.amazon.awssdk.services.json.internal.ServiceVersionInfo; import software.amazon.awssdk.services.json.model.APostOperationRequest; @@ -119,6 +134,7 @@ import software.amazon.awssdk.utils.CompletableFutureUtils; import software.amazon.awssdk.utils.HostnameValidator; import software.amazon.awssdk.utils.Pair; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link JsonAsyncClient}. @@ -146,8 +162,7 @@ final class DefaultJsonAsyncClient implements JsonAsyncClient { protected DefaultJsonAsyncClient(SdkClientConfiguration clientConfiguration) { this.clientHandler = new AwsAsyncClientHandler(clientConfiguration); this.clientConfiguration = clientConfiguration.toBuilder().option(SdkClientOption.SDK_CLIENT, this) - .option(SdkClientOption.API_METADATA, - "Json_Service" + "#" + ServiceVersionInfo.VERSION).build(); + .option(SdkClientOption.API_METADATA, "Json_Service" + "#" + ServiceVersionInfo.VERSION).build(); this.protocolFactory = init(AwsCborProtocolFactory.builder()).build(); this.jsonProtocolFactory = init(AwsJsonProtocolFactory.builder()).build(); this.executor = clientConfiguration.option(SdkAdvancedAsyncClientOption.FUTURE_COMPLETION_EXECUTOR); @@ -221,10 +236,16 @@ public CompletableFuture aPostOperation(APostOperationRe CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("APostOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new APostOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperation")) .hostPrefixExpression(resolvedHostExpression).withInput(aPostOperationRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -299,10 +320,16 @@ public CompletableFuture aPostOperationWithOut CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperationWithOutput").withProtocolMetadata(protocolMetadata) + .withOperationName("APostOperationWithOutput") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new APostOperationWithOutputRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperationWithOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperationWithOutput")) .withInput(aPostOperationWithOutputRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -405,13 +432,20 @@ public CompletableFuture eventStreamOperation(EventStreamOperationRequest CompletableFuture executeFuture = clientHandler.execute( new ClientExecutionParams() - .withOperationName("EventStreamOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("EventStreamOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new EventStreamOperationRequestMarshaller(protocolFactory)) - .withAsyncRequestBody(AsyncRequestBody.fromPublisher(adapted)).withFullDuplex(true) - .withInitialRequestEvent(true).withResponseHandler(voidResponseHandler) - .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) - .withMetricCollector(apiCallMetricCollector).withInput(eventStreamOperationRequest), - asyncResponseTransformer); + .withAsyncRequestBody(AsyncRequestBody.fromPublisher(adapted)) + .withFullDuplex(true) + .withInitialRequestEvent(true) + .withResponseHandler(voidResponseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "EventStreamOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "EventStreamOperation")) + .withInput(eventStreamOperationRequest), asyncResponseTransformer); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { if (e != null) { try { @@ -497,11 +531,18 @@ public CompletableFuture eventStreamO CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("EventStreamOperationWithOnlyInput").withProtocolMetadata(protocolMetadata) + .withOperationName("EventStreamOperationWithOnlyInput") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new EventStreamOperationWithOnlyInputRequestMarshaller(protocolFactory)) - .withAsyncRequestBody(AsyncRequestBody.fromPublisher(adapted)).withInitialRequestEvent(true) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withAsyncRequestBody(AsyncRequestBody.fromPublisher(adapted)) + .withInitialRequestEvent(true) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "EventStreamOperationWithOnlyInput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "EventStreamOperationWithOnlyInput")) .withInput(eventStreamOperationWithOnlyInputRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -601,10 +642,16 @@ public CompletableFuture eventStreamOperationWithOnlyOutput( CompletableFuture executeFuture = clientHandler.execute( new ClientExecutionParams() - .withOperationName("EventStreamOperationWithOnlyOutput").withProtocolMetadata(protocolMetadata) + .withOperationName("EventStreamOperationWithOnlyOutput") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new EventStreamOperationWithOnlyOutputRequestMarshaller(protocolFactory)) - .withResponseHandler(voidResponseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(voidResponseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "EventStreamOperationWithOnlyOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "EventStreamOperationWithOnlyOutput")) .withInput(eventStreamOperationWithOnlyOutputRequest), asyncResponseTransformer); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { if (e != null) { @@ -687,10 +734,16 @@ public CompletableFuture getWithoutRequiredMe CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("GetWithoutRequiredMembers").withProtocolMetadata(protocolMetadata) + .withOperationName("GetWithoutRequiredMembers") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new GetWithoutRequiredMembersRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "GetWithoutRequiredMembers", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetWithoutRequiredMembers")) .withInput(getWithoutRequiredMembersRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -768,6 +821,9 @@ public CompletableFuture operationWithChe .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithChecksumRequired", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithChecksumRequired")) .putExecutionAttribute(SdkInternalExecutionAttribute.HTTP_CHECKSUM_REQUIRED, HttpChecksumRequired.create()).withInput(operationWithChecksumRequiredRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { @@ -838,10 +894,16 @@ public CompletableFuture operationWithNoneAut CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithNoneAuthType").withProtocolMetadata(protocolMetadata) + .withOperationName("OperationWithNoneAuthType") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OperationWithNoneAuthTypeRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithNoneAuthType", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithNoneAuthType")) .withInput(operationWithNoneAuthTypeRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -919,6 +981,9 @@ public CompletableFuture operationWithR .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithRequestCompression", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithRequestCompression")) .putExecutionAttribute(SdkInternalExecutionAttribute.REQUEST_COMPRESSION, RequestCompression.builder().encodings("gzip").isStreaming(false).build()) .withInput(operationWithRequestCompressionRequest)); @@ -991,10 +1056,16 @@ public CompletableFuture paginatedOpera CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("PaginatedOperationWithResultKey").withProtocolMetadata(protocolMetadata) + .withOperationName("PaginatedOperationWithResultKey") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new PaginatedOperationWithResultKeyRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PaginatedOperationWithResultKey", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PaginatedOperationWithResultKey")) .withInput(paginatedOperationWithResultKeyRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -1065,10 +1136,16 @@ public CompletableFuture paginatedOp CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("PaginatedOperationWithoutResultKey").withProtocolMetadata(protocolMetadata) + .withOperationName("PaginatedOperationWithoutResultKey") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new PaginatedOperationWithoutResultKeyRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PaginatedOperationWithoutResultKey", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PaginatedOperationWithoutResultKey")) .withInput(paginatedOperationWithoutResultKeyRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -1148,10 +1225,15 @@ public CompletableFuture streamingInputOperatio .withMarshaller( AsyncStreamingRequestMarshaller.builder() .delegateMarshaller(new StreamingInputOperationRequestMarshaller(protocolFactory)) - .asyncRequestBody(requestBody).build()).withResponseHandler(responseHandler) - .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) - .withMetricCollector(apiCallMetricCollector).withAsyncRequestBody(requestBody) - .withInput(streamingInputOperationRequest)); + .asyncRequestBody(requestBody).build()) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingInputOperation")) + .withAsyncRequestBody(requestBody).withInput(streamingInputOperationRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -1243,8 +1325,13 @@ public CompletableFuture streamingInputOutputOperation( .delegateMarshaller( new StreamingInputOutputOperationRequestMarshaller(protocolFactory)) .asyncRequestBody(requestBody).transferEncoding(true).build()) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOutputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingInputOutputOperation")) .withAsyncRequestBody(requestBody).withAsyncResponseTransformer(asyncResponseTransformer) .withInput(streamingInputOutputOperationRequest), asyncResponseTransformer); AsyncResponseTransformer finalAsyncResponseTransformer = asyncResponseTransformer; @@ -1335,10 +1422,16 @@ public CompletableFuture streamingOutputOperation( CompletableFuture executeFuture = clientHandler.execute( new ClientExecutionParams() - .withOperationName("StreamingOutputOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("StreamingOutputOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new StreamingOutputOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingOutputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingOutputOperation")) .withAsyncResponseTransformer(asyncResponseTransformer).withInput(streamingOutputOperationRequest), asyncResponseTransformer); AsyncResponseTransformer finalAsyncResponseTransformer = asyncResponseTransformer; @@ -1392,6 +1485,48 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + JsonAuthSchemeProvider authSchemeProvider = Validate.isInstanceOf(JsonAuthSchemeProvider.class, + clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of JsonAuthSchemeProvider"); + JsonAuthSchemeParams.Builder paramsBuilder = JsonAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + JsonEndpointProvider provider = (JsonEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + JsonEndpointParams endpointParams = JsonEndpointResolverUtils.ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = JsonEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = JsonEndpointResolverUtils.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, + selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + JsonEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-cbor-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-cbor-client-class.java index a69befcfad26..a7e8c5ce9faa 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-cbor-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-cbor-client-class.java @@ -3,11 +3,16 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -15,6 +20,7 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; @@ -22,6 +28,7 @@ import software.amazon.awssdk.core.client.handler.SyncClientHandler; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.interceptor.trait.HttpChecksumRequired; import software.amazon.awssdk.core.internal.interceptor.trait.RequestCompression; @@ -30,6 +37,8 @@ import software.amazon.awssdk.core.runtime.transform.StreamingRequestMarshaller; import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.core.sync.ResponseTransformer; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -39,6 +48,11 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeParams; +import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeProvider; +import software.amazon.awssdk.services.json.endpoints.JsonEndpointParams; +import software.amazon.awssdk.services.json.endpoints.JsonEndpointProvider; +import software.amazon.awssdk.services.json.endpoints.internal.JsonEndpointResolverUtils; import software.amazon.awssdk.services.json.internal.JsonServiceClientConfigurationBuilder; import software.amazon.awssdk.services.json.internal.ServiceVersionInfo; import software.amazon.awssdk.services.json.model.APostOperationRequest; @@ -79,6 +93,7 @@ import software.amazon.awssdk.services.json.transform.StreamingOutputOperationRequestMarshaller; import software.amazon.awssdk.utils.HostnameValidator; import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link JsonClient}. @@ -169,6 +184,8 @@ public APostOperationResponse aPostOperation(APostOperationRequest aPostOperatio .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .hostPrefixExpression(resolvedHostExpression).withRequestConfiguration(clientConfiguration) .withInput(aPostOperationRequest).withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "APostOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperation")) .withMarshaller(new APostOperationRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -234,10 +251,16 @@ public APostOperationWithOutputResponse aPostOperationWithOutput( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperationWithOutput").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(aPostOperationWithOutputRequest) + .withOperationName("APostOperationWithOutput") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(aPostOperationWithOutputRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperationWithOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperationWithOutput")) .withMarshaller(new APostOperationWithOutputRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -302,10 +325,16 @@ public GetWithoutRequiredMembersResponse getWithoutRequiredMembers( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("GetWithoutRequiredMembers").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(getWithoutRequiredMembersRequest) + .withOperationName("GetWithoutRequiredMembers") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(getWithoutRequiredMembersRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "GetWithoutRequiredMembers", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetWithoutRequiredMembers")) .withMarshaller(new GetWithoutRequiredMembersRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -373,6 +402,9 @@ public OperationWithChecksumRequiredResponse operationWithChecksumRequired( .withRequestConfiguration(clientConfiguration) .withInput(operationWithChecksumRequiredRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithChecksumRequired", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithChecksumRequired")) .putExecutionAttribute(SdkInternalExecutionAttribute.HTTP_CHECKSUM_REQUIRED, HttpChecksumRequired.create()) .withMarshaller(new OperationWithChecksumRequiredRequestMarshaller(protocolFactory))); @@ -435,10 +467,16 @@ public OperationWithNoneAuthTypeResponse operationWithNoneAuthType( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithNoneAuthType").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(operationWithNoneAuthTypeRequest) + .withOperationName("OperationWithNoneAuthType") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(operationWithNoneAuthTypeRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithNoneAuthType", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithNoneAuthType")) .withMarshaller(new OperationWithNoneAuthTypeRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -506,6 +544,9 @@ public OperationWithRequestCompressionResponse operationWithRequestCompression( .withRequestConfiguration(clientConfiguration) .withInput(operationWithRequestCompressionRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithRequestCompression", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithRequestCompression")) .putExecutionAttribute(SdkInternalExecutionAttribute.REQUEST_COMPRESSION, RequestCompression.builder().encodings("gzip").isStreaming(false).build()) .withMarshaller(new OperationWithRequestCompressionRequestMarshaller(protocolFactory))); @@ -568,10 +609,16 @@ public PaginatedOperationWithResultKeyResponse paginatedOperationWithResultKey( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("PaginatedOperationWithResultKey").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(paginatedOperationWithResultKeyRequest) + .withOperationName("PaginatedOperationWithResultKey") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(paginatedOperationWithResultKeyRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PaginatedOperationWithResultKey", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PaginatedOperationWithResultKey")) .withMarshaller(new PaginatedOperationWithResultKeyRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -632,10 +679,16 @@ public PaginatedOperationWithoutResultKeyResponse paginatedOperationWithoutResul return clientHandler .execute(new ClientExecutionParams() - .withOperationName("PaginatedOperationWithoutResultKey").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(paginatedOperationWithoutResultKeyRequest) + .withOperationName("PaginatedOperationWithoutResultKey") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(paginatedOperationWithoutResultKeyRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PaginatedOperationWithoutResultKey", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PaginatedOperationWithoutResultKey")) .withMarshaller(new PaginatedOperationWithoutResultKeyRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -713,6 +766,9 @@ public StreamingInputOperationResponse streamingInputOperation(StreamingInputOpe .withRequestConfiguration(clientConfiguration) .withInput(streamingInputOperationRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingInputOperation")) .withRequestBody(requestBody) .withMarshaller( StreamingRequestMarshaller.builder() @@ -803,6 +859,9 @@ public ReturnT streamingInputOutputOperation( .withRequestConfiguration(clientConfiguration) .withInput(streamingInputOutputOperationRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOutputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingInputOutputOperation")) .withResponseTransformer(responseTransformer) .withRequestBody(requestBody) .withMarshaller( @@ -877,10 +936,17 @@ public ReturnT streamingOutputOperation(StreamingOutputOperationReques return clientHandler.execute( new ClientExecutionParams() - .withOperationName("StreamingOutputOperation").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(streamingOutputOperationRequest) - .withMetricCollector(apiCallMetricCollector).withResponseTransformer(responseTransformer) + .withOperationName("StreamingOutputOperation") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(streamingOutputOperationRequest) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingOutputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingOutputOperation")) + .withResponseTransformer(responseTransformer) .withMarshaller(new StreamingOutputOperationRequestMarshaller(protocolFactory)), responseTransformer); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -915,6 +981,48 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + JsonAuthSchemeProvider authSchemeProvider = Validate.isInstanceOf(JsonAuthSchemeProvider.class, + clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of JsonAuthSchemeProvider"); + JsonAuthSchemeParams.Builder paramsBuilder = JsonAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + JsonEndpointProvider provider = (JsonEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + JsonEndpointParams endpointParams = JsonEndpointResolverUtils.ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = JsonEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = JsonEndpointResolverUtils.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, + selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + JsonEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private HttpResponseHandler createErrorResponseHandler(BaseAwsJsonProtocolFactory protocolFactory, JsonOperationMetadata operationMetadata, Function> exceptionMetadataMapper) { return protocolFactory.createErrorResponseHandler(operationMetadata, exceptionMetadataMapper); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custom-context-params-async-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custom-context-params-async-client-class.java index 4e6c868852fe..168c55537fdc 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custom-context-params-async-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custom-context-params-async-client-class.java @@ -17,11 +17,11 @@ import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; -import software.amazon.awssdk.awscore.internal.endpoints.AwsEndpointProviderUtils; import software.amazon.awssdk.awscore.retry.AwsRetryStrategy; import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custom-context-params-sync-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custom-context-params-sync-client-class.java index 7b8faf909d0c..a8fa677a4052 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custom-context-params-sync-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custom-context-params-sync-client-class.java @@ -4,11 +4,16 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -16,6 +21,7 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; @@ -23,8 +29,12 @@ import software.amazon.awssdk.core.client.handler.SyncClientHandler; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -34,7 +44,12 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.foobar.auth.scheme.FooBarAuthSchemeParams; +import software.amazon.awssdk.services.foobar.auth.scheme.FooBarAuthSchemeProvider; import software.amazon.awssdk.services.foobar.endpoints.FooBarClientContextParams; +import software.amazon.awssdk.services.foobar.endpoints.FooBarEndpointParams; +import software.amazon.awssdk.services.foobar.endpoints.FooBarEndpointProvider; +import software.amazon.awssdk.services.foobar.endpoints.internal.FooBarEndpointResolverUtils; import software.amazon.awssdk.services.foobar.internal.FooBarServiceClientConfigurationBuilder; import software.amazon.awssdk.services.foobar.internal.ServiceVersionInfo; import software.amazon.awssdk.services.foobar.model.FooBarException; @@ -56,7 +71,7 @@ final class DefaultFooBarClient implements FooBarClient { private static final Logger log = Logger.loggerFor(DefaultFooBarClient.class); private static final AwsProtocolMetadata protocolMetadata = AwsProtocolMetadata.builder() - .serviceProtocol(AwsServiceProtocol.REST_JSON).build(); + .serviceProtocol(AwsServiceProtocol.REST_JSON).build(); private final SyncClientHandler clientHandler; @@ -67,7 +82,7 @@ final class DefaultFooBarClient implements FooBarClient { protected DefaultFooBarClient(SdkClientConfiguration clientConfiguration) { this.clientHandler = new AwsSyncClientHandler(clientConfiguration); this.clientConfiguration = clientConfiguration.toBuilder().option(SdkClientOption.SDK_CLIENT, this) - .option(SdkClientOption.API_METADATA, "Foo_Bar" + "#" + ServiceVersionInfo.VERSION).build(); + .option(SdkClientOption.API_METADATA, "Foo_Bar" + "#" + ServiceVersionInfo.VERSION).build(); this.protocolFactory = init(AwsJsonProtocolFactory.builder()).build(); } @@ -91,39 +106,41 @@ protected DefaultFooBarClient(SdkClientConfiguration clientConfiguration) { */ @Override public GetDatabaseVersionResponse getDatabaseVersion(GetDatabaseVersionRequest getDatabaseVersionRequest) - throws AwsServiceException, SdkClientException, FooBarException { + throws AwsServiceException, SdkClientException, FooBarException { JsonOperationMetadata operationMetadata = JsonOperationMetadata.builder().hasStreamingSuccessResponse(false) - .isPayloadJson(true).build(); + .isPayloadJson(true).build(); HttpResponseHandler responseHandler = protocolFactory.createResponseHandler( - operationMetadata, GetDatabaseVersionResponse::builder); + operationMetadata, GetDatabaseVersionResponse::builder); Function> exceptionMetadataMapper = errorCode -> { if (errorCode == null) { return Optional.empty(); } switch (errorCode) { - default: - return Optional.empty(); + default: + return Optional.empty(); } }; HttpResponseHandler errorResponseHandler = createErrorResponseHandler(protocolFactory, - operationMetadata, exceptionMetadataMapper); + operationMetadata, exceptionMetadataMapper); SdkClientConfiguration clientConfiguration = updateSdkClientConfiguration(getDatabaseVersionRequest, - this.clientConfiguration); + this.clientConfiguration); List metricPublishers = resolveMetricPublishers(clientConfiguration, getDatabaseVersionRequest - .overrideConfiguration().orElse(null)); + .overrideConfiguration().orElse(null)); MetricCollector apiCallMetricCollector = metricPublishers.isEmpty() ? NoOpMetricCollector.create() : MetricCollector - .create("ApiCall"); + .create("ApiCall"); try { apiCallMetricCollector.reportMetric(CoreMetric.SERVICE_ID, "Foo Bar"); apiCallMetricCollector.reportMetric(CoreMetric.OPERATION_NAME, "GetDatabaseVersion"); return clientHandler.execute(new ClientExecutionParams() - .withOperationName("GetDatabaseVersion").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(getDatabaseVersionRequest) - .withMetricCollector(apiCallMetricCollector) - .withMarshaller(new GetDatabaseVersionRequestMarshaller(protocolFactory))); + .withOperationName("GetDatabaseVersion").withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration).withInput(getDatabaseVersionRequest) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "GetDatabaseVersion", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetDatabaseVersion")) + .withMarshaller(new GetDatabaseVersionRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); } @@ -135,7 +152,7 @@ public final String serviceName() { } private static List resolveMetricPublishers(SdkClientConfiguration clientConfiguration, - RequestOverrideConfiguration requestOverrideConfiguration) { + RequestOverrideConfiguration requestOverrideConfiguration) { List publishers = null; if (requestOverrideConfiguration != null) { publishers = requestOverrideConfiguration.metricPublishers(); @@ -149,8 +166,50 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + FooBarAuthSchemeProvider authSchemeProvider = Validate.isInstanceOf(FooBarAuthSchemeProvider.class, + clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of FooBarAuthSchemeProvider"); + FooBarAuthSchemeParams.Builder paramsBuilder = FooBarAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + FooBarEndpointProvider provider = (FooBarEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + FooBarEndpointParams endpointParams = FooBarEndpointResolverUtils.ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = FooBarEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = FooBarEndpointResolverUtils.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, + selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + FooBarEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private HttpResponseHandler createErrorResponseHandler(BaseAwsJsonProtocolFactory protocolFactory, - JsonOperationMetadata operationMetadata, Function> exceptionMetadataMapper) { + JsonOperationMetadata operationMetadata, Function> exceptionMetadataMapper) { return protocolFactory.createErrorResponseHandler(operationMetadata, exceptionMetadataMapper); } @@ -192,16 +251,16 @@ private SdkClientConfiguration updateSdkClientConfiguration(SdkRequest request, newContextParams = (newContextParams != null) ? newContextParams : AttributeMap.empty(); originalContextParams = originalContextParams != null ? originalContextParams : AttributeMap.empty(); Validate.validState( - Objects.equals(originalContextParams.get(FooBarClientContextParams.CROSS_REGION_ACCESS_ENABLED), - newContextParams.get(FooBarClientContextParams.CROSS_REGION_ACCESS_ENABLED)), - "CROSS_REGION_ACCESS_ENABLED cannot be modified by request level plugins"); + Objects.equals(originalContextParams.get(FooBarClientContextParams.CROSS_REGION_ACCESS_ENABLED), + newContextParams.get(FooBarClientContextParams.CROSS_REGION_ACCESS_ENABLED)), + "CROSS_REGION_ACCESS_ENABLED cannot be modified by request level plugins"); updateRetryStrategyClientConfiguration(configuration); return configuration.build(); } private > T init(T builder) { return builder.clientConfiguration(clientConfiguration).defaultServiceExceptionSupplier(FooBarException::builder) - .protocol(AwsJsonProtocol.REST_JSON).protocolVersion("1.1"); + .protocol(AwsJsonProtocol.REST_JSON).protocolVersion("1.1"); } @Override diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custompackage-async.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custompackage-async.java index bc69607bf5d9..949b5fa5e92a 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custompackage-async.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custompackage-async.java @@ -2,6 +2,11 @@ import static software.amazon.awssdk.utils.FunctionalUtils.runAndLogError; +import foo.bar.helloworld.auth.scheme.ProtocolRestJsonWithCustomPackageAuthSchemeParams; +import foo.bar.helloworld.auth.scheme.ProtocolRestJsonWithCustomPackageAuthSchemeProvider; +import foo.bar.helloworld.endpoints.ProtocolRestJsonWithCustomPackageEndpointParams; +import foo.bar.helloworld.endpoints.ProtocolRestJsonWithCustomPackageEndpointProvider; +import foo.bar.helloworld.endpoints.internal.ProtocolRestJsonWithCustomPackageEndpointResolverUtils; import foo.bar.helloworld.internal.ProtocolRestJsonWithCustomPackageServiceClientConfigurationBuilder; import foo.bar.helloworld.internal.ServiceVersionInfo; import foo.bar.helloworld.model.OneOperationRequest; @@ -12,13 +17,18 @@ import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -26,14 +36,20 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.client.handler.AsyncClientHandler; import software.amazon.awssdk.core.client.handler.ClientExecutionParams; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -44,6 +60,7 @@ import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.utils.CompletableFutureUtils; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link ProtocolRestJsonWithCustomPackageAsyncClient}. @@ -66,8 +83,11 @@ final class DefaultProtocolRestJsonWithCustomPackageAsyncClient implements Proto protected DefaultProtocolRestJsonWithCustomPackageAsyncClient(SdkClientConfiguration clientConfiguration) { this.clientHandler = new AwsAsyncClientHandler(clientConfiguration); - this.clientConfiguration = clientConfiguration.toBuilder().option(SdkClientOption.SDK_CLIENT, this) - .option(SdkClientOption.API_METADATA, "AmazonProtocolRestJsonWithCustomPackage" + "#" + ServiceVersionInfo.VERSION).build(); + this.clientConfiguration = clientConfiguration + .toBuilder() + .option(SdkClientOption.SDK_CLIENT, this) + .option(SdkClientOption.API_METADATA, + "AmazonProtocolRestJsonWithCustomPackage" + "#" + ServiceVersionInfo.VERSION).build(); this.protocolFactory = init(AwsJsonProtocolFactory.builder()).build(); } @@ -124,7 +144,8 @@ public CompletableFuture oneOperation(OneOperationRequest .withMarshaller(new OneOperationRequestMarshaller(protocolFactory)) .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) - .withInput(oneOperationRequest)); + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "OneOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OneOperation")).withInput(oneOperationRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -168,6 +189,52 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + ProtocolRestJsonWithCustomPackageAuthSchemeProvider authSchemeProvider = Validate.isInstanceOf( + ProtocolRestJsonWithCustomPackageAuthSchemeProvider.class, + clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of ProtocolRestJsonWithCustomPackageAuthSchemeProvider"); + ProtocolRestJsonWithCustomPackageAuthSchemeParams.Builder paramsBuilder = ProtocolRestJsonWithCustomPackageAuthSchemeParams + .builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + ProtocolRestJsonWithCustomPackageEndpointProvider provider = (ProtocolRestJsonWithCustomPackageEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + ProtocolRestJsonWithCustomPackageEndpointParams endpointParams = ProtocolRestJsonWithCustomPackageEndpointResolverUtils + .ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = ProtocolRestJsonWithCustomPackageEndpointResolverUtils.hostPrefix(operationName, + request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = ProtocolRestJsonWithCustomPackageEndpointResolverUtils + .authSchemeWithEndpointSignerProperties(endpointAuthSchemes, selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + ProtocolRestJsonWithCustomPackageEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custompackage-sync.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custompackage-sync.java index fe80546d6f1a..62fa8305f07c 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custompackage-sync.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custompackage-sync.java @@ -1,5 +1,10 @@ package foo.bar.helloworld; +import foo.bar.helloworld.auth.scheme.ProtocolRestJsonWithCustomPackageAuthSchemeParams; +import foo.bar.helloworld.auth.scheme.ProtocolRestJsonWithCustomPackageAuthSchemeProvider; +import foo.bar.helloworld.endpoints.ProtocolRestJsonWithCustomPackageEndpointParams; +import foo.bar.helloworld.endpoints.ProtocolRestJsonWithCustomPackageEndpointProvider; +import foo.bar.helloworld.endpoints.internal.ProtocolRestJsonWithCustomPackageEndpointResolverUtils; import foo.bar.helloworld.internal.ProtocolRestJsonWithCustomPackageServiceClientConfigurationBuilder; import foo.bar.helloworld.internal.ServiceVersionInfo; import foo.bar.helloworld.model.OneOperationRequest; @@ -9,11 +14,16 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -21,6 +31,7 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; @@ -28,8 +39,12 @@ import software.amazon.awssdk.core.client.handler.SyncClientHandler; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -40,6 +55,7 @@ import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link ProtocolRestJsonWithCustomPackageClient}. @@ -62,8 +78,11 @@ final class DefaultProtocolRestJsonWithCustomPackageClient implements ProtocolRe protected DefaultProtocolRestJsonWithCustomPackageClient(SdkClientConfiguration clientConfiguration) { this.clientHandler = new AwsSyncClientHandler(clientConfiguration); - this.clientConfiguration = clientConfiguration.toBuilder().option(SdkClientOption.SDK_CLIENT, this) - .option(SdkClientOption.API_METADATA, "AmazonProtocolRestJsonWithCustomPackage" + "#" + ServiceVersionInfo.VERSION).build(); + this.clientConfiguration = clientConfiguration + .toBuilder() + .option(SdkClientOption.SDK_CLIENT, this) + .option(SdkClientOption.API_METADATA, + "AmazonProtocolRestJsonWithCustomPackage" + "#" + ServiceVersionInfo.VERSION).build(); this.protocolFactory = init(AwsJsonProtocolFactory.builder()).build(); } @@ -116,6 +135,8 @@ public OneOperationResponse oneOperation(OneOperationRequest oneOperationRequest .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withInput(oneOperationRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "OneOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OneOperation")) .withMarshaller(new OneOperationRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -142,6 +163,52 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + ProtocolRestJsonWithCustomPackageAuthSchemeProvider authSchemeProvider = Validate.isInstanceOf( + ProtocolRestJsonWithCustomPackageAuthSchemeProvider.class, + clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of ProtocolRestJsonWithCustomPackageAuthSchemeProvider"); + ProtocolRestJsonWithCustomPackageAuthSchemeParams.Builder paramsBuilder = ProtocolRestJsonWithCustomPackageAuthSchemeParams + .builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + ProtocolRestJsonWithCustomPackageEndpointProvider provider = (ProtocolRestJsonWithCustomPackageEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + ProtocolRestJsonWithCustomPackageEndpointParams endpointParams = ProtocolRestJsonWithCustomPackageEndpointResolverUtils + .ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = ProtocolRestJsonWithCustomPackageEndpointResolverUtils.hostPrefix(operationName, + request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = ProtocolRestJsonWithCustomPackageEndpointResolverUtils + .authSchemeWithEndpointSignerProperties(endpointAuthSchemes, selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + ProtocolRestJsonWithCustomPackageEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private HttpResponseHandler createErrorResponseHandler(BaseAwsJsonProtocolFactory protocolFactory, JsonOperationMetadata operationMetadata, Function> exceptionMetadataMapper) { return protocolFactory.createErrorResponseHandler(operationMetadata, exceptionMetadataMapper); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-customservicemetadata-async.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-customservicemetadata-async.java index 96f626d27f13..0c905d1b77b4 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-customservicemetadata-async.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-customservicemetadata-async.java @@ -6,13 +6,18 @@ import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -20,14 +25,20 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.client.handler.AsyncClientHandler; import software.amazon.awssdk.core.client.handler.ClientExecutionParams; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -37,6 +48,11 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.auth.scheme.ProtocolRestJsonWithCustomContentTypeAuthSchemeParams; +import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.auth.scheme.ProtocolRestJsonWithCustomContentTypeAuthSchemeProvider; +import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.endpoints.ProtocolRestJsonWithCustomContentTypeEndpointParams; +import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.endpoints.ProtocolRestJsonWithCustomContentTypeEndpointProvider; +import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.endpoints.internal.ProtocolRestJsonWithCustomContentTypeEndpointResolverUtils; import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.internal.ProtocolRestJsonWithCustomContentTypeServiceClientConfigurationBuilder; import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.internal.ServiceVersionInfo; import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.model.OneOperationRequest; @@ -44,6 +60,7 @@ import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.model.ProtocolRestJsonWithCustomContentTypeException; import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.transform.OneOperationRequestMarshaller; import software.amazon.awssdk.utils.CompletableFutureUtils; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link ProtocolRestJsonWithCustomContentTypeAsyncClient}. @@ -66,8 +83,11 @@ final class DefaultProtocolRestJsonWithCustomContentTypeAsyncClient implements P protected DefaultProtocolRestJsonWithCustomContentTypeAsyncClient(SdkClientConfiguration clientConfiguration) { this.clientHandler = new AwsAsyncClientHandler(clientConfiguration); - this.clientConfiguration = clientConfiguration.toBuilder().option(SdkClientOption.SDK_CLIENT, this) - .option(SdkClientOption.API_METADATA, "AmazonProtocolRestJsonWithCustomContentType" + "#" + ServiceVersionInfo.VERSION).build(); + this.clientConfiguration = clientConfiguration + .toBuilder() + .option(SdkClientOption.SDK_CLIENT, this) + .option(SdkClientOption.API_METADATA, + "AmazonProtocolRestJsonWithCustomContentType" + "#" + ServiceVersionInfo.VERSION).build(); this.protocolFactory = init(AwsJsonProtocolFactory.builder()).build(); } @@ -124,7 +144,8 @@ public CompletableFuture oneOperation(OneOperationRequest .withMarshaller(new OneOperationRequestMarshaller(protocolFactory)) .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) - .withInput(oneOperationRequest)); + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "OneOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OneOperation")).withInput(oneOperationRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -168,6 +189,52 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + ProtocolRestJsonWithCustomContentTypeAuthSchemeProvider authSchemeProvider = Validate.isInstanceOf( + ProtocolRestJsonWithCustomContentTypeAuthSchemeProvider.class, + clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of ProtocolRestJsonWithCustomContentTypeAuthSchemeProvider"); + ProtocolRestJsonWithCustomContentTypeAuthSchemeParams.Builder paramsBuilder = ProtocolRestJsonWithCustomContentTypeAuthSchemeParams + .builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + ProtocolRestJsonWithCustomContentTypeEndpointProvider provider = (ProtocolRestJsonWithCustomContentTypeEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + ProtocolRestJsonWithCustomContentTypeEndpointParams endpointParams = ProtocolRestJsonWithCustomContentTypeEndpointResolverUtils + .ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = ProtocolRestJsonWithCustomContentTypeEndpointResolverUtils.hostPrefix( + operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = ProtocolRestJsonWithCustomContentTypeEndpointResolverUtils + .authSchemeWithEndpointSignerProperties(endpointAuthSchemes, selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + ProtocolRestJsonWithCustomContentTypeEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-customservicemetadata-sync.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-customservicemetadata-sync.java index 3852e59b5710..dc7df15efd3d 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-customservicemetadata-sync.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-customservicemetadata-sync.java @@ -3,11 +3,16 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -15,6 +20,7 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; @@ -22,8 +28,12 @@ import software.amazon.awssdk.core.client.handler.SyncClientHandler; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -33,6 +43,11 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.auth.scheme.ProtocolRestJsonWithCustomContentTypeAuthSchemeParams; +import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.auth.scheme.ProtocolRestJsonWithCustomContentTypeAuthSchemeProvider; +import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.endpoints.ProtocolRestJsonWithCustomContentTypeEndpointParams; +import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.endpoints.ProtocolRestJsonWithCustomContentTypeEndpointProvider; +import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.endpoints.internal.ProtocolRestJsonWithCustomContentTypeEndpointResolverUtils; import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.internal.ProtocolRestJsonWithCustomContentTypeServiceClientConfigurationBuilder; import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.internal.ServiceVersionInfo; import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.model.OneOperationRequest; @@ -40,6 +55,7 @@ import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.model.ProtocolRestJsonWithCustomContentTypeException; import software.amazon.awssdk.services.protocolrestjsonwithcustomcontenttype.transform.OneOperationRequestMarshaller; import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link ProtocolRestJsonWithCustomContentTypeClient}. @@ -119,6 +135,8 @@ public OneOperationResponse oneOperation(OneOperationRequest oneOperationRequest .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withInput(oneOperationRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "OneOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OneOperation")) .withMarshaller(new OneOperationRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -145,6 +163,52 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + ProtocolRestJsonWithCustomContentTypeAuthSchemeProvider authSchemeProvider = Validate.isInstanceOf( + ProtocolRestJsonWithCustomContentTypeAuthSchemeProvider.class, + clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of ProtocolRestJsonWithCustomContentTypeAuthSchemeProvider"); + ProtocolRestJsonWithCustomContentTypeAuthSchemeParams.Builder paramsBuilder = ProtocolRestJsonWithCustomContentTypeAuthSchemeParams + .builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + ProtocolRestJsonWithCustomContentTypeEndpointProvider provider = (ProtocolRestJsonWithCustomContentTypeEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + ProtocolRestJsonWithCustomContentTypeEndpointParams endpointParams = ProtocolRestJsonWithCustomContentTypeEndpointResolverUtils + .ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = ProtocolRestJsonWithCustomContentTypeEndpointResolverUtils.hostPrefix( + operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = ProtocolRestJsonWithCustomContentTypeEndpointResolverUtils + .authSchemeWithEndpointSignerProperties(endpointAuthSchemes, selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + ProtocolRestJsonWithCustomContentTypeEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private HttpResponseHandler createErrorResponseHandler(BaseAwsJsonProtocolFactory protocolFactory, JsonOperationMetadata operationMetadata, Function> exceptionMetadataMapper) { return protocolFactory.createErrorResponseHandler(operationMetadata, exceptionMetadataMapper); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-async.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-async.java index b0426321ea91..c123a0786011 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-async.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-async.java @@ -7,6 +7,7 @@ import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import org.slf4j.Logger; @@ -16,6 +17,9 @@ import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -23,6 +27,7 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; @@ -30,9 +35,14 @@ import software.amazon.awssdk.core.client.handler.ClientExecutionParams; import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRefreshCache; import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRequest; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; @@ -43,6 +53,11 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.endpointdiscoverytest.auth.scheme.EndpointDiscoveryTestAuthSchemeParams; +import software.amazon.awssdk.services.endpointdiscoverytest.auth.scheme.EndpointDiscoveryTestAuthSchemeProvider; +import software.amazon.awssdk.services.endpointdiscoverytest.endpoints.EndpointDiscoveryTestEndpointParams; +import software.amazon.awssdk.services.endpointdiscoverytest.endpoints.EndpointDiscoveryTestEndpointProvider; +import software.amazon.awssdk.services.endpointdiscoverytest.endpoints.internal.EndpointDiscoveryTestEndpointResolverUtils; import software.amazon.awssdk.services.endpointdiscoverytest.internal.EndpointDiscoveryTestServiceClientConfigurationBuilder; import software.amazon.awssdk.services.endpointdiscoverytest.internal.ServiceVersionInfo; import software.amazon.awssdk.services.endpointdiscoverytest.model.DescribeEndpointsRequest; @@ -59,6 +74,7 @@ import software.amazon.awssdk.services.endpointdiscoverytest.transform.TestDiscoveryOptionalRequestMarshaller; import software.amazon.awssdk.services.endpointdiscoverytest.transform.TestDiscoveryRequiredRequestMarshaller; import software.amazon.awssdk.utils.CompletableFutureUtils; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link EndpointDiscoveryTestAsyncClient}. @@ -140,10 +156,16 @@ public CompletableFuture describeEndpoints(DescribeEn CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("DescribeEndpoints").withProtocolMetadata(protocolMetadata) + .withOperationName("DescribeEndpoints") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new DescribeEndpointsRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "DescribeEndpoints", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "DescribeEndpoints")) .withInput(describeEndpointsRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -235,10 +257,17 @@ public CompletableFuture testDiscovery CompletableFuture executeFuture = endpointFuture .thenCompose(cachedEndpoint -> clientHandler .execute(new ClientExecutionParams() - .withOperationName("TestDiscoveryIdentifiersRequired").withProtocolMetadata(protocolMetadata) + .withOperationName("TestDiscoveryIdentifiersRequired") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new TestDiscoveryIdentifiersRequiredRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "TestDiscoveryIdentifiersRequired", + clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "TestDiscoveryIdentifiersRequired")) .discoveredEndpoint(cachedEndpoint).withInput(testDiscoveryIdentifiersRequiredRequest))); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -320,10 +349,16 @@ public CompletableFuture testDiscoveryOptional( CompletableFuture executeFuture = endpointFuture .thenCompose(cachedEndpoint -> clientHandler .execute(new ClientExecutionParams() - .withOperationName("TestDiscoveryOptional").withProtocolMetadata(protocolMetadata) + .withOperationName("TestDiscoveryOptional") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new TestDiscoveryOptionalRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "TestDiscoveryOptional", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "TestDiscoveryOptional")) .discoveredEndpoint(cachedEndpoint).withInput(testDiscoveryOptionalRequest))); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -413,10 +448,16 @@ public CompletableFuture testDiscoveryRequired( CompletableFuture executeFuture = endpointFuture .thenCompose(cachedEndpoint -> clientHandler .execute(new ClientExecutionParams() - .withOperationName("TestDiscoveryRequired").withProtocolMetadata(protocolMetadata) + .withOperationName("TestDiscoveryRequired") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new TestDiscoveryRequiredRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "TestDiscoveryRequired", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "TestDiscoveryRequired")) .discoveredEndpoint(cachedEndpoint).withInput(testDiscoveryRequiredRequest))); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -460,6 +501,50 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + EndpointDiscoveryTestAuthSchemeProvider authSchemeProvider = Validate.isInstanceOf( + EndpointDiscoveryTestAuthSchemeProvider.class, clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of EndpointDiscoveryTestAuthSchemeProvider"); + EndpointDiscoveryTestAuthSchemeParams.Builder paramsBuilder = EndpointDiscoveryTestAuthSchemeParams.builder().operation( + operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + EndpointDiscoveryTestEndpointProvider provider = (EndpointDiscoveryTestEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + EndpointDiscoveryTestEndpointParams endpointParams = EndpointDiscoveryTestEndpointResolverUtils.ruleParams(request, + executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = EndpointDiscoveryTestEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = EndpointDiscoveryTestEndpointResolverUtils.authSchemeWithEndpointSignerProperties( + endpointAuthSchemes, selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + EndpointDiscoveryTestEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-sync.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-sync.java index 731f8a467062..f494d60f1364 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-sync.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-sync.java @@ -5,6 +5,7 @@ import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import software.amazon.awssdk.annotations.Generated; @@ -12,6 +13,9 @@ import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -19,6 +23,7 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; @@ -28,8 +33,12 @@ import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRequest; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; @@ -40,6 +49,11 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.endpointdiscoverytest.auth.scheme.EndpointDiscoveryTestAuthSchemeParams; +import software.amazon.awssdk.services.endpointdiscoverytest.auth.scheme.EndpointDiscoveryTestAuthSchemeProvider; +import software.amazon.awssdk.services.endpointdiscoverytest.endpoints.EndpointDiscoveryTestEndpointParams; +import software.amazon.awssdk.services.endpointdiscoverytest.endpoints.EndpointDiscoveryTestEndpointProvider; +import software.amazon.awssdk.services.endpointdiscoverytest.endpoints.internal.EndpointDiscoveryTestEndpointResolverUtils; import software.amazon.awssdk.services.endpointdiscoverytest.internal.EndpointDiscoveryTestServiceClientConfigurationBuilder; import software.amazon.awssdk.services.endpointdiscoverytest.internal.ServiceVersionInfo; import software.amazon.awssdk.services.endpointdiscoverytest.model.DescribeEndpointsRequest; @@ -57,6 +71,7 @@ import software.amazon.awssdk.services.endpointdiscoverytest.transform.TestDiscoveryRequiredRequestMarshaller; import software.amazon.awssdk.utils.CompletableFutureUtils; import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link EndpointDiscoveryTestClient}. @@ -138,6 +153,8 @@ public DescribeEndpointsResponse describeEndpoints(DescribeEndpointsRequest desc .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withInput(describeEndpointsRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "DescribeEndpoints", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "DescribeEndpoints")) .withMarshaller(new DescribeEndpointsRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -211,10 +228,17 @@ public TestDiscoveryIdentifiersRequiredResponse testDiscoveryIdentifiersRequired return clientHandler .execute(new ClientExecutionParams() - .withOperationName("TestDiscoveryIdentifiersRequired").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .discoveredEndpoint(cachedEndpoint).withRequestConfiguration(clientConfiguration) - .withInput(testDiscoveryIdentifiersRequiredRequest).withMetricCollector(apiCallMetricCollector) + .withOperationName("TestDiscoveryIdentifiersRequired") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .discoveredEndpoint(cachedEndpoint) + .withRequestConfiguration(clientConfiguration) + .withInput(testDiscoveryIdentifiersRequiredRequest) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "TestDiscoveryIdentifiersRequired", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "TestDiscoveryIdentifiersRequired")) .withMarshaller(new TestDiscoveryIdentifiersRequiredRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -277,12 +301,20 @@ public TestDiscoveryOptionalResponse testDiscoveryOptional(TestDiscoveryOptional apiCallMetricCollector.reportMetric(CoreMetric.SERVICE_ID, "AwsEndpointDiscoveryTest"); apiCallMetricCollector.reportMetric(CoreMetric.OPERATION_NAME, "TestDiscoveryOptional"); - return clientHandler.execute(new ClientExecutionParams() - .withOperationName("TestDiscoveryOptional").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .discoveredEndpoint(cachedEndpoint).withRequestConfiguration(clientConfiguration) - .withInput(testDiscoveryOptionalRequest).withMetricCollector(apiCallMetricCollector) - .withMarshaller(new TestDiscoveryOptionalRequestMarshaller(protocolFactory))); + return clientHandler + .execute(new ClientExecutionParams() + .withOperationName("TestDiscoveryOptional") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .discoveredEndpoint(cachedEndpoint) + .withRequestConfiguration(clientConfiguration) + .withInput(testDiscoveryOptionalRequest) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "TestDiscoveryOptional", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "TestDiscoveryOptional")) + .withMarshaller(new TestDiscoveryOptionalRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); } @@ -352,12 +384,20 @@ public TestDiscoveryRequiredResponse testDiscoveryRequired(TestDiscoveryRequired apiCallMetricCollector.reportMetric(CoreMetric.SERVICE_ID, "AwsEndpointDiscoveryTest"); apiCallMetricCollector.reportMetric(CoreMetric.OPERATION_NAME, "TestDiscoveryRequired"); - return clientHandler.execute(new ClientExecutionParams() - .withOperationName("TestDiscoveryRequired").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .discoveredEndpoint(cachedEndpoint).withRequestConfiguration(clientConfiguration) - .withInput(testDiscoveryRequiredRequest).withMetricCollector(apiCallMetricCollector) - .withMarshaller(new TestDiscoveryRequiredRequestMarshaller(protocolFactory))); + return clientHandler + .execute(new ClientExecutionParams() + .withOperationName("TestDiscoveryRequired") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .discoveredEndpoint(cachedEndpoint) + .withRequestConfiguration(clientConfiguration) + .withInput(testDiscoveryRequiredRequest) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "TestDiscoveryRequired", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "TestDiscoveryRequired")) + .withMarshaller(new TestDiscoveryRequiredRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); } @@ -383,6 +423,50 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + EndpointDiscoveryTestAuthSchemeProvider authSchemeProvider = Validate.isInstanceOf( + EndpointDiscoveryTestAuthSchemeProvider.class, clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of EndpointDiscoveryTestAuthSchemeProvider"); + EndpointDiscoveryTestAuthSchemeParams.Builder paramsBuilder = EndpointDiscoveryTestAuthSchemeParams.builder().operation( + operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + EndpointDiscoveryTestEndpointProvider provider = (EndpointDiscoveryTestEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + EndpointDiscoveryTestEndpointParams endpointParams = EndpointDiscoveryTestEndpointResolverUtils.ruleParams(request, + executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = EndpointDiscoveryTestEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = EndpointDiscoveryTestEndpointResolverUtils.authSchemeWithEndpointSignerProperties( + endpointAuthSchemes, selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + EndpointDiscoveryTestEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private HttpResponseHandler createErrorResponseHandler(BaseAwsJsonProtocolFactory protocolFactory, JsonOperationMetadata operationMetadata, Function> exceptionMetadataMapper) { return protocolFactory.createErrorResponseHandler(operationMetadata, exceptionMetadataMapper); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-json-async-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-json-async-client-class.java index 4fcb45c2f919..db9913412e2a 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-json-async-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-json-async-client-class.java @@ -7,6 +7,7 @@ import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.function.Consumer; @@ -16,8 +17,12 @@ import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; import software.amazon.awssdk.awscore.client.handler.AwsClientHandlerUtils; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.eventstream.EventStreamAsyncResponseTransformer; import software.amazon.awssdk.awscore.eventstream.EventStreamTaggedUnionJsonMarshaller; import software.amazon.awssdk.awscore.eventstream.EventStreamTaggedUnionPojoSupplier; @@ -33,6 +38,7 @@ import software.amazon.awssdk.core.SdkPojoBuilder; import software.amazon.awssdk.core.SdkRequest; import software.amazon.awssdk.core.SdkResponse; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.core.async.AsyncResponseTransformerUtils; @@ -44,7 +50,9 @@ import software.amazon.awssdk.core.client.handler.AsyncClientHandler; import software.amazon.awssdk.core.client.handler.AttachHttpMetadataResponseHandler; import software.amazon.awssdk.core.client.handler.ClientExecutionParams; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.interceptor.trait.HttpChecksum; import software.amazon.awssdk.core.interceptor.trait.HttpChecksumRequired; @@ -53,6 +61,8 @@ import software.amazon.awssdk.core.protocol.VoidSdkResponse; import software.amazon.awssdk.core.retry.RetryMode; import software.amazon.awssdk.core.runtime.transform.AsyncStreamingRequestMarshaller; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -62,7 +72,12 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeParams; +import software.amazon.awssdk.services.json.auth.scheme.JsonAuthSchemeProvider; import software.amazon.awssdk.services.json.batchmanager.JsonAsyncBatchManager; +import software.amazon.awssdk.services.json.endpoints.JsonEndpointParams; +import software.amazon.awssdk.services.json.endpoints.JsonEndpointProvider; +import software.amazon.awssdk.services.json.endpoints.internal.JsonEndpointResolverUtils; import software.amazon.awssdk.services.json.internal.JsonServiceClientConfigurationBuilder; import software.amazon.awssdk.services.json.internal.ServiceVersionInfo; import software.amazon.awssdk.services.json.model.APostOperationRequest; @@ -129,6 +144,7 @@ import software.amazon.awssdk.utils.CompletableFutureUtils; import software.amazon.awssdk.utils.HostnameValidator; import software.amazon.awssdk.utils.Pair; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link JsonAsyncClient}. @@ -227,10 +243,16 @@ public CompletableFuture aPostOperation(APostOperationRe CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("APostOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new APostOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperation")) .hostPrefixExpression(resolvedHostExpression).withInput(aPostOperationRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -301,10 +323,16 @@ public CompletableFuture aPostOperationWithOut CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperationWithOutput").withProtocolMetadata(protocolMetadata) + .withOperationName("APostOperationWithOutput") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new APostOperationWithOutputRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperationWithOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperationWithOutput")) .withInput(aPostOperationWithOutputRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -371,10 +399,16 @@ public CompletableFuture bearerAuthOperation( CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("BearerAuthOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("BearerAuthOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new BearerAuthOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "BearerAuthOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "BearerAuthOperation")) .credentialType(CredentialType.TOKEN).withInput(bearerAuthOperationRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -484,11 +518,18 @@ public CompletableFuture eventStreamOperation(EventStreamOperationRequest CompletableFuture executeFuture = clientHandler.execute( new ClientExecutionParams() - .withOperationName("EventStreamOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("EventStreamOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new EventStreamOperationRequestMarshaller(protocolFactory)) - .withAsyncRequestBody(AsyncRequestBody.fromPublisher(adapted)).withFullDuplex(true) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withAsyncRequestBody(AsyncRequestBody.fromPublisher(adapted)) + .withFullDuplex(true) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "EventStreamOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "EventStreamOperation")) .withInput(eventStreamOperationRequest), restAsyncResponseTransformer); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { if (e != null) { @@ -572,11 +613,18 @@ public CompletableFuture eventStreamO CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("EventStreamOperationWithOnlyInput").withProtocolMetadata(protocolMetadata) + .withOperationName("EventStreamOperationWithOnlyInput") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new EventStreamOperationWithOnlyInputRequestMarshaller(protocolFactory)) - .withAsyncRequestBody(AsyncRequestBody.fromPublisher(adapted)).withResponseHandler(responseHandler) - .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) - .withMetricCollector(apiCallMetricCollector).withInput(eventStreamOperationWithOnlyInputRequest)); + .withAsyncRequestBody(AsyncRequestBody.fromPublisher(adapted)) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "EventStreamOperationWithOnlyInput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "EventStreamOperationWithOnlyInput")) + .withInput(eventStreamOperationWithOnlyInputRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -686,8 +734,14 @@ public CompletableFuture eventStreamOperationWithOnlyOutput( .withOperationName("EventStreamOperationWithOnlyOutput") .withProtocolMetadata(protocolMetadata) .withMarshaller(new EventStreamOperationWithOnlyOutputRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "EventStreamOperationWithOnlyOutput", + clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "EventStreamOperationWithOnlyOutput")) .withInput(eventStreamOperationWithOnlyOutputRequest), restAsyncResponseTransformer); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { if (e != null) { @@ -770,6 +824,9 @@ public CompletableFuture getOperationWithCheck .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "GetOperationWithChecksum", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetOperationWithChecksum")) .putExecutionAttribute( SdkInternalExecutionAttribute.HTTP_CHECKSUM, HttpChecksum.builder().requestChecksumRequired(true).isRequestStreaming(false) @@ -845,10 +902,16 @@ public CompletableFuture getWithoutRequiredMe CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("GetWithoutRequiredMembers").withProtocolMetadata(protocolMetadata) + .withOperationName("GetWithoutRequiredMembers") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new GetWithoutRequiredMembersRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "GetWithoutRequiredMembers", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetWithoutRequiredMembers")) .withInput(getWithoutRequiredMembersRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -923,6 +986,9 @@ public CompletableFuture operationWithChe .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithChecksumRequired", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithChecksumRequired")) .putExecutionAttribute(SdkInternalExecutionAttribute.HTTP_CHECKSUM_REQUIRED, HttpChecksumRequired.create()).withInput(operationWithChecksumRequiredRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { @@ -998,6 +1064,9 @@ public CompletableFuture operationWithR .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithRequestCompression", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithRequestCompression")) .putExecutionAttribute(SdkInternalExecutionAttribute.REQUEST_COMPRESSION, RequestCompression.builder().encodings("gzip").isStreaming(false).build()) .withInput(operationWithRequestCompressionRequest)); @@ -1067,10 +1136,16 @@ public CompletableFuture paginatedOpera CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("PaginatedOperationWithResultKey").withProtocolMetadata(protocolMetadata) + .withOperationName("PaginatedOperationWithResultKey") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new PaginatedOperationWithResultKeyRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PaginatedOperationWithResultKey", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PaginatedOperationWithResultKey")) .withInput(paginatedOperationWithResultKeyRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -1138,10 +1213,16 @@ public CompletableFuture paginatedOp CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("PaginatedOperationWithoutResultKey").withProtocolMetadata(protocolMetadata) + .withOperationName("PaginatedOperationWithoutResultKey") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new PaginatedOperationWithoutResultKeyRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PaginatedOperationWithoutResultKey", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PaginatedOperationWithoutResultKey")) .withInput(paginatedOperationWithoutResultKeyRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -1241,6 +1322,9 @@ public CompletableFuture putOperationWithChecksum( .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PutOperationWithChecksum", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PutOperationWithChecksum")) .withAsyncRequestBody(requestBody) .putExecutionAttribute( SdkInternalExecutionAttribute.HTTP_CHECKSUM, @@ -1344,10 +1428,15 @@ public CompletableFuture streamingInputOperatio .withMarshaller( AsyncStreamingRequestMarshaller.builder() .delegateMarshaller(new StreamingInputOperationRequestMarshaller(protocolFactory)) - .asyncRequestBody(requestBody).build()).withResponseHandler(responseHandler) - .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) - .withMetricCollector(apiCallMetricCollector).withAsyncRequestBody(requestBody) - .withInput(streamingInputOperationRequest)); + .asyncRequestBody(requestBody).build()) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingInputOperation")) + .withAsyncRequestBody(requestBody).withInput(streamingInputOperationRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -1436,8 +1525,13 @@ public CompletableFuture streamingInputOutputOperation( .delegateMarshaller( new StreamingInputOutputOperationRequestMarshaller(protocolFactory)) .asyncRequestBody(requestBody).transferEncoding(true).build()) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOutputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingInputOutputOperation")) .withAsyncRequestBody(requestBody).withAsyncResponseTransformer(asyncResponseTransformer) .withInput(streamingInputOutputOperationRequest), asyncResponseTransformer); AsyncResponseTransformer finalAsyncResponseTransformer = asyncResponseTransformer; @@ -1525,10 +1619,16 @@ public CompletableFuture streamingOutputOperation( CompletableFuture executeFuture = clientHandler.execute( new ClientExecutionParams() - .withOperationName("StreamingOutputOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("StreamingOutputOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new StreamingOutputOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingOutputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingOutputOperation")) .withAsyncResponseTransformer(asyncResponseTransformer).withInput(streamingOutputOperationRequest), asyncResponseTransformer); AsyncResponseTransformer finalAsyncResponseTransformer = asyncResponseTransformer; @@ -1587,6 +1687,48 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + JsonAuthSchemeProvider authSchemeProvider = Validate.isInstanceOf(JsonAuthSchemeProvider.class, + clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of JsonAuthSchemeProvider"); + JsonAuthSchemeParams.Builder paramsBuilder = JsonAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + JsonEndpointProvider provider = (JsonEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + JsonEndpointParams endpointParams = JsonEndpointResolverUtils.ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = JsonEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = JsonEndpointResolverUtils.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, + selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + JsonEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-json-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-json-client-class.java index cb125ce31ad6..b737870efece 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-json-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-json-client-class.java @@ -11,11 +11,11 @@ import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler; import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; -import software.amazon.awssdk.awscore.internal.endpoints.AwsEndpointProviderUtils; import software.amazon.awssdk.awscore.retry.AwsRetryStrategy; import software.amazon.awssdk.checksums.DefaultChecksumAlgorithm; import software.amazon.awssdk.core.CredentialType; diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-query-async-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-query-async-client-class.java index b43795a31a9b..7da95a5d40cb 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-query-async-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-query-async-client-class.java @@ -4,14 +4,20 @@ import java.util.Collections; import java.util.List; +import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.ScheduledExecutorService; import java.util.function.Consumer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -22,6 +28,7 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.core.async.AsyncResponseTransformerUtils; @@ -30,7 +37,9 @@ import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.client.handler.AsyncClientHandler; import software.amazon.awssdk.core.client.handler.ClientExecutionParams; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.interceptor.trait.HttpChecksum; import software.amazon.awssdk.core.interceptor.trait.HttpChecksumRequired; @@ -38,12 +47,19 @@ import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; import software.amazon.awssdk.core.runtime.transform.AsyncStreamingRequestMarshaller; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; import software.amazon.awssdk.protocols.core.ExceptionMetadata; import software.amazon.awssdk.protocols.query.AwsQueryProtocolFactory; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.query.auth.scheme.QueryAuthSchemeParams; +import software.amazon.awssdk.services.query.auth.scheme.QueryAuthSchemeProvider; +import software.amazon.awssdk.services.query.endpoints.QueryEndpointParams; +import software.amazon.awssdk.services.query.endpoints.QueryEndpointProvider; +import software.amazon.awssdk.services.query.endpoints.internal.QueryEndpointResolverUtils; import software.amazon.awssdk.services.query.internal.QueryServiceClientConfigurationBuilder; import software.amazon.awssdk.services.query.internal.ServiceVersionInfo; import software.amazon.awssdk.services.query.model.APostOperationRequest; @@ -99,6 +115,7 @@ import software.amazon.awssdk.services.query.waiters.QueryAsyncWaiter; import software.amazon.awssdk.utils.CompletableFutureUtils; import software.amazon.awssdk.utils.Pair; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link QueryAsyncClient}. @@ -173,10 +190,16 @@ public CompletableFuture aPostOperation(APostOperationRe CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("APostOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new APostOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperation")) .hostPrefixExpression(resolvedHostExpression).withInput(aPostOperationRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { @@ -233,10 +256,16 @@ public CompletableFuture aPostOperationWithOut CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperationWithOutput").withProtocolMetadata(protocolMetadata) + .withOperationName("APostOperationWithOutput") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new APostOperationWithOutputRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperationWithOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperationWithOutput")) .withInput(aPostOperationWithOutputRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { @@ -289,11 +318,18 @@ public CompletableFuture bearerAuthOperation( CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("BearerAuthOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("BearerAuthOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new BearerAuthOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .credentialType(CredentialType.TOKEN).withRequestConfiguration(clientConfiguration) - .withMetricCollector(apiCallMetricCollector).withInput(bearerAuthOperationRequest)); + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .credentialType(CredentialType.TOKEN) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "BearerAuthOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "BearerAuthOperation")) + .withInput(bearerAuthOperationRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -352,6 +388,9 @@ public CompletableFuture getOperationWithCheck .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "GetOperationWithChecksum", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetOperationWithChecksum")) .putExecutionAttribute( SdkInternalExecutionAttribute.HTTP_CHECKSUM, HttpChecksum.builder().requestChecksumRequired(true).isRequestStreaming(false) @@ -417,6 +456,9 @@ public CompletableFuture operationWithChe .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithChecksumRequired", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithChecksumRequired")) .putExecutionAttribute(SdkInternalExecutionAttribute.HTTP_CHECKSUM_REQUIRED, HttpChecksumRequired.create()).withInput(operationWithChecksumRequiredRequest)); CompletableFuture whenCompleteFuture = null; @@ -470,10 +512,16 @@ public CompletableFuture operationWithContext CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithContextParam").withProtocolMetadata(protocolMetadata) + .withOperationName("OperationWithContextParam") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OperationWithContextParamRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithContextParam", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithContextParam")) .withInput(operationWithContextParamRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { @@ -527,10 +575,16 @@ public CompletableFuture operationWithCustomM CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithCustomMember").withProtocolMetadata(protocolMetadata) + .withOperationName("OperationWithCustomMember") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OperationWithCustomMemberRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithCustomMember", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithCustomMember")) .withInput(operationWithCustomMemberRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { @@ -588,8 +642,14 @@ public CompletableFuture o .withOperationName("OperationWithCustomizedOperationContextParam") .withProtocolMetadata(protocolMetadata) .withMarshaller(new OperationWithCustomizedOperationContextParamRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithCustomizedOperationContextParam", + clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithCustomizedOperationContextParam")) .withInput(operationWithCustomizedOperationContextParamRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { @@ -644,10 +704,16 @@ public CompletableFuture operatio CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithMapOperationContextParam").withProtocolMetadata(protocolMetadata) + .withOperationName("OperationWithMapOperationContextParam") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OperationWithMapOperationContextParamRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithMapOperationContextParam", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithMapOperationContextParam")) .withInput(operationWithMapOperationContextParamRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { @@ -700,10 +766,16 @@ public CompletableFuture operationWithNoneAut CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithNoneAuthType").withProtocolMetadata(protocolMetadata) + .withOperationName("OperationWithNoneAuthType") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OperationWithNoneAuthTypeRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithNoneAuthType", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithNoneAuthType")) .withInput(operationWithNoneAuthTypeRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { @@ -758,10 +830,16 @@ public CompletableFuture operationWi CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithOperationContextParam").withProtocolMetadata(protocolMetadata) + .withOperationName("OperationWithOperationContextParam") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OperationWithOperationContextParamRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithOperationContextParam", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithOperationContextParam")) .withInput(operationWithOperationContextParamRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { @@ -822,6 +900,9 @@ public CompletableFuture operationWithR .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithRequestCompression", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithRequestCompression")) .putExecutionAttribute(SdkInternalExecutionAttribute.REQUEST_COMPRESSION, RequestCompression.builder().encodings("gzip").isStreaming(false).build()) .withInput(operationWithRequestCompressionRequest)); @@ -877,10 +958,16 @@ public CompletableFuture operationWith CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithStaticContextParams").withProtocolMetadata(protocolMetadata) + .withOperationName("OperationWithStaticContextParams") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OperationWithStaticContextParamsRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithStaticContextParams", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithStaticContextParams")) .withInput(operationWithStaticContextParamsRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { @@ -966,6 +1053,9 @@ public CompletableFuture putOperationWithChecksum( .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PutOperationWithChecksum", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PutOperationWithChecksum")) .putExecutionAttribute( SdkInternalExecutionAttribute.HTTP_CHECKSUM, HttpChecksum @@ -1052,10 +1142,15 @@ public CompletableFuture streamingInputOperatio .withMarshaller( AsyncStreamingRequestMarshaller.builder() .delegateMarshaller(new StreamingInputOperationRequestMarshaller(protocolFactory)) - .asyncRequestBody(requestBody).build()).withResponseHandler(responseHandler) - .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) - .withMetricCollector(apiCallMetricCollector).withAsyncRequestBody(requestBody) - .withInput(streamingInputOperationRequest)); + .asyncRequestBody(requestBody).build()) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingInputOperation")) + .withAsyncRequestBody(requestBody).withInput(streamingInputOperationRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -1117,10 +1212,16 @@ public CompletableFuture streamingOutputOperation( CompletableFuture executeFuture = clientHandler.execute( new ClientExecutionParams() - .withOperationName("StreamingOutputOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("StreamingOutputOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new StreamingOutputOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingOutputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingOutputOperation")) .withAsyncResponseTransformer(asyncResponseTransformer).withInput(streamingOutputOperationRequest), asyncResponseTransformer); CompletableFuture whenCompleteFuture = null; @@ -1183,6 +1284,48 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + QueryAuthSchemeProvider authSchemeProvider = Validate.isInstanceOf(QueryAuthSchemeProvider.class, + clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of QueryAuthSchemeProvider"); + QueryAuthSchemeParams.Builder paramsBuilder = QueryAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + QueryEndpointProvider provider = (QueryEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + QueryEndpointParams endpointParams = QueryEndpointResolverUtils.ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = QueryEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = QueryEndpointResolverUtils.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, + selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + QueryEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-query-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-query-client-class.java index 766210e8c6f4..661ca78263d5 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-query-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-query-client-class.java @@ -2,11 +2,16 @@ import java.util.Collections; import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -17,6 +22,7 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; @@ -24,6 +30,7 @@ import software.amazon.awssdk.core.client.handler.SyncClientHandler; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.interceptor.trait.HttpChecksum; import software.amazon.awssdk.core.interceptor.trait.HttpChecksumRequired; @@ -33,6 +40,7 @@ import software.amazon.awssdk.core.runtime.transform.StreamingRequestMarshaller; import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.core.sync.ResponseTransformer; +import software.amazon.awssdk.endpoints.Endpoint; import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; @@ -42,6 +50,9 @@ import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.query.auth.scheme.QueryAuthSchemeParams; import software.amazon.awssdk.services.query.auth.scheme.QueryAuthSchemeProvider; +import software.amazon.awssdk.services.query.endpoints.QueryEndpointParams; +import software.amazon.awssdk.services.query.endpoints.QueryEndpointProvider; +import software.amazon.awssdk.services.query.endpoints.internal.QueryEndpointResolverUtils; import software.amazon.awssdk.services.query.internal.QueryServiceClientConfigurationBuilder; import software.amazon.awssdk.services.query.internal.ServiceVersionInfo; import software.amazon.awssdk.services.query.model.APostOperationRequest; @@ -96,6 +107,7 @@ import software.amazon.awssdk.services.query.transform.StreamingOutputOperationRequestMarshaller; import software.amazon.awssdk.services.query.waiters.QueryWaiter; import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link QueryClient}. @@ -168,6 +180,7 @@ public APostOperationResponse aPostOperation(APostOperationRequest aPostOperatio .hostPrefixExpression(resolvedHostExpression).withRequestConfiguration(clientConfiguration) .withInput(aPostOperationRequest).withMetricCollector(apiCallMetricCollector) .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "APostOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperation")) .withMarshaller(new APostOperationRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -224,6 +237,7 @@ public APostOperationWithOutputResponse aPostOperationWithOutput( .withMetricCollector(apiCallMetricCollector) .withAuthSchemeOptionsResolver( r -> resolveAuthSchemeOptions(r, "APostOperationWithOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperationWithOutput")) .withMarshaller(new APostOperationWithOutputRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -270,6 +284,7 @@ public BearerAuthOperationResponse bearerAuthOperation(BearerAuthOperationReques .credentialType(CredentialType.TOKEN).withRequestConfiguration(clientConfiguration) .withInput(bearerAuthOperationRequest).withMetricCollector(apiCallMetricCollector) .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "BearerAuthOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "BearerAuthOperation")) .withMarshaller(new BearerAuthOperationRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -322,6 +337,7 @@ public GetOperationWithChecksumResponse getOperationWithChecksum( .withMetricCollector(apiCallMetricCollector) .withAuthSchemeOptionsResolver( r -> resolveAuthSchemeOptions(r, "GetOperationWithChecksum", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetOperationWithChecksum")) .putExecutionAttribute( SdkInternalExecutionAttribute.HTTP_CHECKSUM, HttpChecksum.builder().requestChecksumRequired(true).isRequestStreaming(false) @@ -379,6 +395,7 @@ public OperationWithChecksumRequiredResponse operationWithChecksumRequired( .withMetricCollector(apiCallMetricCollector) .withAuthSchemeOptionsResolver( r -> resolveAuthSchemeOptions(r, "OperationWithChecksumRequired", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithChecksumRequired")) .putExecutionAttribute(SdkInternalExecutionAttribute.HTTP_CHECKSUM_REQUIRED, HttpChecksumRequired.create()) .withMarshaller(new OperationWithChecksumRequiredRequestMarshaller(protocolFactory))); @@ -433,6 +450,7 @@ public OperationWithContextParamResponse operationWithContextParam( .withMetricCollector(apiCallMetricCollector) .withAuthSchemeOptionsResolver( r -> resolveAuthSchemeOptions(r, "OperationWithContextParam", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithContextParam")) .withMarshaller(new OperationWithContextParamRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -486,6 +504,7 @@ public OperationWithCustomMemberResponse operationWithCustomMember( .withMetricCollector(apiCallMetricCollector) .withAuthSchemeOptionsResolver( r -> resolveAuthSchemeOptions(r, "OperationWithCustomMember", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithCustomMember")) .withMarshaller(new OperationWithCustomMemberRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -540,6 +559,7 @@ public OperationWithCustomizedOperationContextParamResponse operationWithCustomi .withAuthSchemeOptionsResolver( r -> resolveAuthSchemeOptions(r, "OperationWithCustomizedOperationContextParam", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithCustomizedOperationContextParam")) .withMarshaller(new OperationWithCustomizedOperationContextParamRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -593,6 +613,7 @@ public OperationWithMapOperationContextParamResponse operationWithMapOperationCo .withMetricCollector(apiCallMetricCollector) .withAuthSchemeOptionsResolver( r -> resolveAuthSchemeOptions(r, "OperationWithMapOperationContextParam", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithMapOperationContextParam")) .withMarshaller(new OperationWithMapOperationContextParamRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -645,6 +666,7 @@ public OperationWithNoneAuthTypeResponse operationWithNoneAuthType( .withMetricCollector(apiCallMetricCollector) .withAuthSchemeOptionsResolver( r -> resolveAuthSchemeOptions(r, "OperationWithNoneAuthType", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithNoneAuthType")) .withMarshaller(new OperationWithNoneAuthTypeRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -698,6 +720,7 @@ public OperationWithOperationContextParamResponse operationWithOperationContextP .withMetricCollector(apiCallMetricCollector) .withAuthSchemeOptionsResolver( r -> resolveAuthSchemeOptions(r, "OperationWithOperationContextParam", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithOperationContextParam")) .withMarshaller(new OperationWithOperationContextParamRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -750,6 +773,7 @@ public OperationWithRequestCompressionResponse operationWithRequestCompression( .withMetricCollector(apiCallMetricCollector) .withAuthSchemeOptionsResolver( r -> resolveAuthSchemeOptions(r, "OperationWithRequestCompression", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithRequestCompression")) .putExecutionAttribute(SdkInternalExecutionAttribute.REQUEST_COMPRESSION, RequestCompression.builder().encodings("gzip").isStreaming(false).build()) .withMarshaller(new OperationWithRequestCompressionRequestMarshaller(protocolFactory))); @@ -804,6 +828,7 @@ public OperationWithStaticContextParamsResponse operationWithStaticContextParams .withMetricCollector(apiCallMetricCollector) .withAuthSchemeOptionsResolver( r -> resolveAuthSchemeOptions(r, "OperationWithStaticContextParams", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithStaticContextParams")) .withMarshaller(new OperationWithStaticContextParamsRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -882,6 +907,7 @@ public ReturnT putOperationWithChecksum(PutOperationWithChecksumReques .withMetricCollector(apiCallMetricCollector) .withAuthSchemeOptionsResolver( r -> resolveAuthSchemeOptions(r, "PutOperationWithChecksum", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PutOperationWithChecksum")) .putExecutionAttribute( SdkInternalExecutionAttribute.HTTP_CHECKSUM, HttpChecksum @@ -962,6 +988,7 @@ public StreamingInputOperationResponse streamingInputOperation(StreamingInputOpe .withMetricCollector(apiCallMetricCollector) .withAuthSchemeOptionsResolver( r -> resolveAuthSchemeOptions(r, "StreamingInputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingInputOperation")) .withRequestBody(requestBody) .withMarshaller( StreamingRequestMarshaller.builder() @@ -1025,6 +1052,7 @@ public ReturnT streamingOutputOperation(StreamingOutputOperationReques .withMetricCollector(apiCallMetricCollector) .withAuthSchemeOptionsResolver( r -> resolveAuthSchemeOptions(r, "StreamingOutputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingOutputOperation")) .withResponseTransformer(responseTransformer) .withMarshaller(new StreamingOutputOperationRequestMarshaller(protocolFactory)), responseTransformer); } finally { @@ -1067,11 +1095,44 @@ private static List resolveMetricPublishers(SdkClientConfigurat private List resolveAuthSchemeOptions(SdkRequest request, String operationName, SdkClientConfiguration clientConfiguration) { - QueryAuthSchemeProvider authSchemeProvider = (QueryAuthSchemeProvider) clientConfiguration - .option(SdkClientOption.AUTH_SCHEME_PROVIDER); + QueryAuthSchemeProvider authSchemeProvider = Validate.isInstanceOf(QueryAuthSchemeProvider.class, + clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of QueryAuthSchemeProvider"); QueryAuthSchemeParams.Builder paramsBuilder = QueryAuthSchemeParams.builder().operation(operationName); paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); - return authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + QueryEndpointProvider provider = (QueryEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + QueryEndpointParams endpointParams = QueryEndpointResolverUtils.ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = QueryEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = QueryEndpointResolverUtils.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, + selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + QueryEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } } private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-rpcv2-async-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-rpcv2-async-client-class.java index bd3083b4602b..f91dfa34ce19 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-rpcv2-async-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-rpcv2-async-client-class.java @@ -6,13 +6,18 @@ import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -20,14 +25,20 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.client.handler.AsyncClientHandler; import software.amazon.awssdk.core.client.handler.ClientExecutionParams; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -37,6 +48,11 @@ import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.protocols.rpcv2.SmithyRpcV2CborProtocolFactory; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.smithyrpcv2protocol.auth.scheme.SmithyRpcV2ProtocolAuthSchemeParams; +import software.amazon.awssdk.services.smithyrpcv2protocol.auth.scheme.SmithyRpcV2ProtocolAuthSchemeProvider; +import software.amazon.awssdk.services.smithyrpcv2protocol.endpoints.SmithyRpcV2ProtocolEndpointParams; +import software.amazon.awssdk.services.smithyrpcv2protocol.endpoints.SmithyRpcV2ProtocolEndpointProvider; +import software.amazon.awssdk.services.smithyrpcv2protocol.endpoints.internal.SmithyRpcV2ProtocolEndpointResolverUtils; import software.amazon.awssdk.services.smithyrpcv2protocol.internal.ServiceVersionInfo; import software.amazon.awssdk.services.smithyrpcv2protocol.internal.SmithyRpcV2ProtocolServiceClientConfigurationBuilder; import software.amazon.awssdk.services.smithyrpcv2protocol.model.ComplexErrorException; @@ -83,6 +99,7 @@ import software.amazon.awssdk.services.smithyrpcv2protocol.transform.SimpleScalarPropertiesRequestMarshaller; import software.amazon.awssdk.services.smithyrpcv2protocol.transform.SparseNullsOperationRequestMarshaller; import software.amazon.awssdk.utils.CompletableFutureUtils; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link SmithyRpcV2ProtocolAsyncClient}. @@ -169,10 +186,16 @@ public CompletableFuture emptyInputOutput(EmptyInputOu CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("EmptyInputOutput").withProtocolMetadata(protocolMetadata) + .withOperationName("EmptyInputOutput") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new EmptyInputOutputRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "EmptyInputOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "EmptyInputOutput")) .withInput(emptyInputOutputRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -246,7 +269,8 @@ public CompletableFuture float16(Float16Request float16Request) .withProtocolMetadata(protocolMetadata).withMarshaller(new Float16RequestMarshaller(protocolFactory)) .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) - .withInput(float16Request)); + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "Float16", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "Float16")).withInput(float16Request)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -317,10 +341,16 @@ public CompletableFuture fractionalSeconds(Fractional CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("FractionalSeconds").withProtocolMetadata(protocolMetadata) + .withOperationName("FractionalSeconds") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new FractionalSecondsRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "FractionalSeconds", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "FractionalSeconds")) .withInput(fractionalSecondsRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -394,10 +424,16 @@ public CompletableFuture greetingWithErrors(Greeting CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("GreetingWithErrors").withProtocolMetadata(protocolMetadata) + .withOperationName("GreetingWithErrors") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new GreetingWithErrorsRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "GreetingWithErrors", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GreetingWithErrors")) .withInput(greetingWithErrorsRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -468,10 +504,15 @@ public CompletableFuture noInputOutput(NoInputOutputReque CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("NoInputOutput").withProtocolMetadata(protocolMetadata) + .withOperationName("NoInputOutput") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new NoInputOutputRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "NoInputOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "NoInputOutput")) .withInput(noInputOutputRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -545,10 +586,16 @@ public CompletableFuture operationWithDefaults( CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithDefaults").withProtocolMetadata(protocolMetadata) + .withOperationName("OperationWithDefaults") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OperationWithDefaultsRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithDefaults", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithDefaults")) .withInput(operationWithDefaultsRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -621,10 +668,16 @@ public CompletableFuture optionalInputOutput( CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("OptionalInputOutput").withProtocolMetadata(protocolMetadata) + .withOperationName("OptionalInputOutput") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OptionalInputOutputRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OptionalInputOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OptionalInputOutput")) .withInput(optionalInputOutputRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -696,10 +749,16 @@ public CompletableFuture recursiveShapes(RecursiveShape CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("RecursiveShapes").withProtocolMetadata(protocolMetadata) + .withOperationName("RecursiveShapes") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new RecursiveShapesRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "RecursiveShapes", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "RecursiveShapes")) .withInput(recursiveShapesRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -772,10 +831,16 @@ public CompletableFuture rpcV2CborDenseMaps(RpcV2Cbo CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("RpcV2CborDenseMaps").withProtocolMetadata(protocolMetadata) + .withOperationName("RpcV2CborDenseMaps") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new RpcV2CborDenseMapsRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "RpcV2CborDenseMaps", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "RpcV2CborDenseMaps")) .withInput(rpcV2CborDenseMapsRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -847,10 +912,16 @@ public CompletableFuture rpcV2CborLists(RpcV2CborListsRe CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("RpcV2CborLists").withProtocolMetadata(protocolMetadata) + .withOperationName("RpcV2CborLists") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new RpcV2CborListsRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "RpcV2CborLists", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "RpcV2CborLists")) .withInput(rpcV2CborListsRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -924,10 +995,16 @@ public CompletableFuture rpcV2CborSparseMaps( CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("RpcV2CborSparseMaps").withProtocolMetadata(protocolMetadata) + .withOperationName("RpcV2CborSparseMaps") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new RpcV2CborSparseMapsRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "RpcV2CborSparseMaps", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "RpcV2CborSparseMaps")) .withInput(rpcV2CborSparseMapsRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -1000,10 +1077,16 @@ public CompletableFuture simpleScalarProperties( CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("SimpleScalarProperties").withProtocolMetadata(protocolMetadata) + .withOperationName("SimpleScalarProperties") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new SimpleScalarPropertiesRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "SimpleScalarProperties", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "SimpleScalarProperties")) .withInput(simpleScalarPropertiesRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -1076,10 +1159,16 @@ public CompletableFuture sparseNullsOperation( CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("SparseNullsOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("SparseNullsOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new SparseNullsOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "SparseNullsOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "SparseNullsOperation")) .withInput(sparseNullsOperationRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -1123,6 +1212,50 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + SmithyRpcV2ProtocolAuthSchemeProvider authSchemeProvider = Validate.isInstanceOf( + SmithyRpcV2ProtocolAuthSchemeProvider.class, clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of SmithyRpcV2ProtocolAuthSchemeProvider"); + SmithyRpcV2ProtocolAuthSchemeParams.Builder paramsBuilder = SmithyRpcV2ProtocolAuthSchemeParams.builder().operation( + operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + SmithyRpcV2ProtocolEndpointProvider provider = (SmithyRpcV2ProtocolEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + SmithyRpcV2ProtocolEndpointParams endpointParams = SmithyRpcV2ProtocolEndpointResolverUtils.ruleParams(request, + executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = SmithyRpcV2ProtocolEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = SmithyRpcV2ProtocolEndpointResolverUtils.authSchemeWithEndpointSignerProperties( + endpointAuthSchemes, selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + SmithyRpcV2ProtocolEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); @@ -1170,4 +1303,4 @@ private HttpResponseHandler createErrorResponseHandler(Base public void close() { clientHandler.close(); } -} \ No newline at end of file +} diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-rpcv2-sync.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-rpcv2-sync.java index fcda9aa09cb8..fa3c0f0eaeca 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-rpcv2-sync.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-rpcv2-sync.java @@ -3,11 +3,16 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -15,6 +20,7 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; @@ -22,8 +28,12 @@ import software.amazon.awssdk.core.client.handler.SyncClientHandler; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -33,6 +43,11 @@ import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.protocols.rpcv2.SmithyRpcV2CborProtocolFactory; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.smithyrpcv2protocol.auth.scheme.SmithyRpcV2ProtocolAuthSchemeParams; +import software.amazon.awssdk.services.smithyrpcv2protocol.auth.scheme.SmithyRpcV2ProtocolAuthSchemeProvider; +import software.amazon.awssdk.services.smithyrpcv2protocol.endpoints.SmithyRpcV2ProtocolEndpointParams; +import software.amazon.awssdk.services.smithyrpcv2protocol.endpoints.SmithyRpcV2ProtocolEndpointProvider; +import software.amazon.awssdk.services.smithyrpcv2protocol.endpoints.internal.SmithyRpcV2ProtocolEndpointResolverUtils; import software.amazon.awssdk.services.smithyrpcv2protocol.internal.ServiceVersionInfo; import software.amazon.awssdk.services.smithyrpcv2protocol.internal.SmithyRpcV2ProtocolServiceClientConfigurationBuilder; import software.amazon.awssdk.services.smithyrpcv2protocol.model.ComplexErrorException; @@ -79,6 +94,7 @@ import software.amazon.awssdk.services.smithyrpcv2protocol.transform.SimpleScalarPropertiesRequestMarshaller; import software.amazon.awssdk.services.smithyrpcv2protocol.transform.SparseNullsOperationRequestMarshaller; import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link SmithyRpcV2ProtocolClient}. @@ -165,6 +181,8 @@ public EmptyInputOutputResponse emptyInputOutput(EmptyInputOutputRequest emptyIn .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withInput(emptyInputOutputRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "EmptyInputOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "EmptyInputOutput")) .withMarshaller(new EmptyInputOutputRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -228,6 +246,8 @@ public Float16Response float16(Float16Request float16Request) throws AwsServiceE .withOperationName("Float16").withProtocolMetadata(protocolMetadata).withResponseHandler(responseHandler) .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) .withInput(float16Request).withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "Float16", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "Float16")) .withMarshaller(new Float16RequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -293,6 +313,8 @@ public FractionalSecondsResponse fractionalSeconds(FractionalSecondsRequest frac .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withInput(fractionalSecondsRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "FractionalSeconds", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "FractionalSeconds")) .withMarshaller(new FractionalSecondsRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -361,6 +383,8 @@ public GreetingWithErrorsResponse greetingWithErrors(GreetingWithErrorsRequest g .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withInput(greetingWithErrorsRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "GreetingWithErrors", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GreetingWithErrors")) .withMarshaller(new GreetingWithErrorsRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -425,6 +449,8 @@ public NoInputOutputResponse noInputOutput(NoInputOutputRequest noInputOutputReq .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withInput(noInputOutputRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "NoInputOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "NoInputOutput")) .withMarshaller(new NoInputOutputRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -486,12 +512,19 @@ public OperationWithDefaultsResponse operationWithDefaults(OperationWithDefaults apiCallMetricCollector.reportMetric(CoreMetric.SERVICE_ID, "SmithyRpcV2Protocol"); apiCallMetricCollector.reportMetric(CoreMetric.OPERATION_NAME, "OperationWithDefaults"); - return clientHandler.execute(new ClientExecutionParams() - .withOperationName("OperationWithDefaults").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(operationWithDefaultsRequest) - .withMetricCollector(apiCallMetricCollector) - .withMarshaller(new OperationWithDefaultsRequestMarshaller(protocolFactory))); + return clientHandler + .execute(new ClientExecutionParams() + .withOperationName("OperationWithDefaults") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(operationWithDefaultsRequest) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithDefaults", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithDefaults")) + .withMarshaller(new OperationWithDefaultsRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); } @@ -556,6 +589,8 @@ public OptionalInputOutputResponse optionalInputOutput(OptionalInputOutputReques .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withInput(optionalInputOutputRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "OptionalInputOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OptionalInputOutput")) .withMarshaller(new OptionalInputOutputRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -621,6 +656,8 @@ public RecursiveShapesResponse recursiveShapes(RecursiveShapesRequest recursiveS .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withInput(recursiveShapesRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "RecursiveShapes", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "RecursiveShapes")) .withMarshaller(new RecursiveShapesRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -687,6 +724,8 @@ public RpcV2CborDenseMapsResponse rpcV2CborDenseMaps(RpcV2CborDenseMapsRequest r .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withInput(rpcV2CborDenseMapsRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "RpcV2CborDenseMaps", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "RpcV2CborDenseMaps")) .withMarshaller(new RpcV2CborDenseMapsRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -752,6 +791,8 @@ public RpcV2CborListsResponse rpcV2CborLists(RpcV2CborListsRequest rpcV2CborList .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withInput(rpcV2CborListsRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "RpcV2CborLists", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "RpcV2CborLists")) .withMarshaller(new RpcV2CborListsRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -818,6 +859,8 @@ public RpcV2CborSparseMapsResponse rpcV2CborSparseMaps(RpcV2CborSparseMapsReques .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withInput(rpcV2CborSparseMapsRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "RpcV2CborSparseMaps", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "RpcV2CborSparseMaps")) .withMarshaller(new RpcV2CborSparseMapsRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -880,10 +923,16 @@ public SimpleScalarPropertiesResponse simpleScalarProperties(SimpleScalarPropert return clientHandler .execute(new ClientExecutionParams() - .withOperationName("SimpleScalarProperties").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(simpleScalarPropertiesRequest) + .withOperationName("SimpleScalarProperties") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(simpleScalarPropertiesRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "SimpleScalarProperties", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "SimpleScalarProperties")) .withMarshaller(new SimpleScalarPropertiesRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -949,6 +998,8 @@ public SparseNullsOperationResponse sparseNullsOperation(SparseNullsOperationReq .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withInput(sparseNullsOperationRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "SparseNullsOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "SparseNullsOperation")) .withMarshaller(new SparseNullsOperationRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -975,6 +1026,50 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + SmithyRpcV2ProtocolAuthSchemeProvider authSchemeProvider = Validate.isInstanceOf( + SmithyRpcV2ProtocolAuthSchemeProvider.class, clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of SmithyRpcV2ProtocolAuthSchemeProvider"); + SmithyRpcV2ProtocolAuthSchemeParams.Builder paramsBuilder = SmithyRpcV2ProtocolAuthSchemeParams.builder().operation( + operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + SmithyRpcV2ProtocolEndpointProvider provider = (SmithyRpcV2ProtocolEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + SmithyRpcV2ProtocolEndpointParams endpointParams = SmithyRpcV2ProtocolEndpointResolverUtils.ruleParams(request, + executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = SmithyRpcV2ProtocolEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = SmithyRpcV2ProtocolEndpointResolverUtils.authSchemeWithEndpointSignerProperties( + endpointAuthSchemes, selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + SmithyRpcV2ProtocolEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private HttpResponseHandler createErrorResponseHandler(BaseAwsJsonProtocolFactory protocolFactory, JsonOperationMetadata operationMetadata, Function> exceptionMetadataMapper) { return protocolFactory.createErrorResponseHandler(operationMetadata, exceptionMetadataMapper); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-unsigned-payload-trait-async-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-unsigned-payload-trait-async-client-class.java index efa307505463..b0fdf65bb357 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-unsigned-payload-trait-async-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-unsigned-payload-trait-async-client-class.java @@ -5,14 +5,20 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.Set; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -20,16 +26,25 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.client.handler.AsyncClientHandler; import software.amazon.awssdk.core.client.handler.ClientExecutionParams; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; import software.amazon.awssdk.core.runtime.transform.AsyncStreamingRequestMarshaller; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.aws.scheme.AwsV4aAuthScheme; +import software.amazon.awssdk.http.auth.aws.signer.AwsV4aHttpSigner; +import software.amazon.awssdk.http.auth.aws.signer.RegionSet; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -39,6 +54,11 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.database.auth.scheme.DatabaseAuthSchemeParams; +import software.amazon.awssdk.services.database.auth.scheme.DatabaseAuthSchemeProvider; +import software.amazon.awssdk.services.database.endpoints.DatabaseEndpointParams; +import software.amazon.awssdk.services.database.endpoints.DatabaseEndpointProvider; +import software.amazon.awssdk.services.database.endpoints.internal.DatabaseEndpointResolverUtils; import software.amazon.awssdk.services.database.internal.DatabaseServiceClientConfigurationBuilder; import software.amazon.awssdk.services.database.internal.ServiceVersionInfo; import software.amazon.awssdk.services.database.model.DatabaseException; @@ -76,7 +96,9 @@ import software.amazon.awssdk.services.database.transform.OpsWithSigv4AndSigv4ASignedPayloadRequestMarshaller; import software.amazon.awssdk.services.database.transform.PutRowRequestMarshaller; import software.amazon.awssdk.services.database.transform.SecondOpsWithSigv4AndSigv4ASignedPayloadRequestMarshaller; +import software.amazon.awssdk.utils.CollectionUtils; import software.amazon.awssdk.utils.CompletableFutureUtils; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link DatabaseAsyncClient}. @@ -163,7 +185,9 @@ public CompletableFuture deleteRow(DeleteRowRequest deleteRow .withProtocolMetadata(protocolMetadata) .withMarshaller(new DeleteRowRequestMarshaller(protocolFactory)).withResponseHandler(responseHandler) .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) - .withMetricCollector(apiCallMetricCollector).withInput(deleteRowRequest)); + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "DeleteRow", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "DeleteRow")).withInput(deleteRowRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -234,7 +258,8 @@ public CompletableFuture getRow(GetRowRequest getRowRequest) { .withProtocolMetadata(protocolMetadata).withMarshaller(new GetRowRequestMarshaller(protocolFactory)) .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) - .withInput(getRowRequest)); + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "GetRow", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetRow")).withInput(getRowRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -306,10 +331,16 @@ public CompletableFuture opWithSigv CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("opWithSigv4AndSigv4aUnSignedPayload").withProtocolMetadata(protocolMetadata) + .withOperationName("opWithSigv4AndSigv4aUnSignedPayload") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OpWithSigv4AndSigv4AUnSignedPayloadRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opWithSigv4AndSigv4aUnSignedPayload", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opWithSigv4AndSigv4aUnSignedPayload")) .withInput(opWithSigv4AndSigv4AUnSignedPayloadRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -380,10 +411,16 @@ public CompletableFuture opWithSigv4SignedPayl CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("opWithSigv4SignedPayload").withProtocolMetadata(protocolMetadata) + .withOperationName("opWithSigv4SignedPayload") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OpWithSigv4SignedPayloadRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opWithSigv4SignedPayload", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opWithSigv4SignedPayload")) .withInput(opWithSigv4SignedPayloadRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -454,10 +491,16 @@ public CompletableFuture opWithSigv4UnSigned CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("opWithSigv4UnSignedPayload").withProtocolMetadata(protocolMetadata) + .withOperationName("opWithSigv4UnSignedPayload") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OpWithSigv4UnSignedPayloadRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opWithSigv4UnSignedPayload", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opWithSigv4UnSignedPayload")) .withInput(opWithSigv4UnSignedPayloadRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -544,8 +587,14 @@ public CompletableFuture opWithS .delegateMarshaller( new OpWithSigv4UnSignedPayloadAndStreamingRequestMarshaller(protocolFactory)) .asyncRequestBody(requestBody).transferEncoding(true).build()) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opWithSigv4UnSignedPayloadAndStreaming", + clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opWithSigv4UnSignedPayloadAndStreaming")) .withAsyncRequestBody(requestBody).withInput(opWithSigv4UnSignedPayloadAndStreamingRequest)); CompletableFuture whenCompleted = executeFuture .whenComplete((r, e) -> { @@ -617,10 +666,16 @@ public CompletableFuture opWithSigv4aSignedPa CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("opWithSigv4aSignedPayload").withProtocolMetadata(protocolMetadata) + .withOperationName("opWithSigv4aSignedPayload") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OpWithSigv4ASignedPayloadRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opWithSigv4aSignedPayload", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opWithSigv4aSignedPayload")) .withInput(opWithSigv4ASignedPayloadRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -691,10 +746,16 @@ public CompletableFuture opWithSigv4aUnSign CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("opWithSigv4aUnSignedPayload").withProtocolMetadata(protocolMetadata) + .withOperationName("opWithSigv4aUnSignedPayload") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OpWithSigv4AUnSignedPayloadRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opWithSigv4aUnSignedPayload", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opWithSigv4aUnSignedPayload")) .withInput(opWithSigv4AUnSignedPayloadRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -767,10 +828,16 @@ public CompletableFuture opsWithSigv CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("opsWithSigv4andSigv4aSignedPayload").withProtocolMetadata(protocolMetadata) + .withOperationName("opsWithSigv4andSigv4aSignedPayload") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new OpsWithSigv4AndSigv4ASignedPayloadRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opsWithSigv4andSigv4aSignedPayload", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opsWithSigv4andSigv4aSignedPayload")) .withInput(opsWithSigv4AndSigv4ASignedPayloadRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -842,7 +909,8 @@ public CompletableFuture putRow(PutRowRequest putRowRequest) { .withProtocolMetadata(protocolMetadata).withMarshaller(new PutRowRequestMarshaller(protocolFactory)) .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) - .withInput(putRowRequest)); + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "PutRow", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PutRow")).withInput(putRowRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -914,10 +982,17 @@ public CompletableFuture secon CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("secondOpsWithSigv4andSigv4aSignedPayload").withProtocolMetadata(protocolMetadata) + .withOperationName("secondOpsWithSigv4andSigv4aSignedPayload") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new SecondOpsWithSigv4AndSigv4ASignedPayloadRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "secondOpsWithSigv4andSigv4aSignedPayload", + clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "secondOpsWithSigv4andSigv4aSignedPayload")) .withInput(secondOpsWithSigv4AndSigv4ASignedPayloadRequest)); CompletableFuture whenCompleted = executeFuture .whenComplete((r, e) -> { @@ -961,6 +1036,61 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + DatabaseAuthSchemeProvider authSchemeProvider = Validate.isInstanceOf(DatabaseAuthSchemeProvider.class, + clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of DatabaseAuthSchemeProvider"); + DatabaseAuthSchemeParams.Builder paramsBuilder = DatabaseAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + Set sigv4aRegionSet = clientConfiguration.option(AwsClientOption.AWS_SIGV4A_SIGNING_REGION_SET); + if (!CollectionUtils.isNullOrEmpty(sigv4aRegionSet)) { + paramsBuilder.regionSet(RegionSet.create(sigv4aRegionSet)); + } + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + DatabaseEndpointProvider provider = (DatabaseEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + DatabaseEndpointParams endpointParams = DatabaseEndpointResolverUtils.ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = DatabaseEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = DatabaseEndpointResolverUtils.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, + selectedAuthScheme); + // Precedence of SigV4a RegionSet is set according to multi-auth SigV4a specifications + if (selectedAuthScheme.authSchemeOption().schemeId().equals(AwsV4aAuthScheme.SCHEME_ID) + && selectedAuthScheme.authSchemeOption().signerProperty(AwsV4aHttpSigner.REGION_SET) == null) { + AuthSchemeOption.Builder optionBuilder = selectedAuthScheme.authSchemeOption().toBuilder(); + RegionSet rs = RegionSet.create(endpointParams.region().id()); + optionBuilder.putSignerProperty(AwsV4aHttpSigner.REGION_SET, rs); + selectedAuthScheme = new SelectedAuthScheme(selectedAuthScheme.identity(), selectedAuthScheme.signer(), + optionBuilder.build()); + } + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + DatabaseEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-unsigned-payload-trait-sync-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-unsigned-payload-trait-sync-client-class.java index ff8b3a285415..e7c9da6f9d7e 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-unsigned-payload-trait-sync-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-unsigned-payload-trait-sync-client-class.java @@ -3,11 +3,17 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import java.util.function.Function; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -15,6 +21,7 @@ import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; @@ -22,10 +29,17 @@ import software.amazon.awssdk.core.client.handler.SyncClientHandler; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; import software.amazon.awssdk.core.runtime.transform.StreamingRequestMarshaller; import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.aws.scheme.AwsV4aAuthScheme; +import software.amazon.awssdk.http.auth.aws.signer.AwsV4aHttpSigner; +import software.amazon.awssdk.http.auth.aws.signer.RegionSet; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -35,6 +49,11 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.database.auth.scheme.DatabaseAuthSchemeParams; +import software.amazon.awssdk.services.database.auth.scheme.DatabaseAuthSchemeProvider; +import software.amazon.awssdk.services.database.endpoints.DatabaseEndpointParams; +import software.amazon.awssdk.services.database.endpoints.DatabaseEndpointProvider; +import software.amazon.awssdk.services.database.endpoints.internal.DatabaseEndpointResolverUtils; import software.amazon.awssdk.services.database.internal.DatabaseServiceClientConfigurationBuilder; import software.amazon.awssdk.services.database.internal.ServiceVersionInfo; import software.amazon.awssdk.services.database.model.DatabaseException; @@ -72,7 +91,9 @@ import software.amazon.awssdk.services.database.transform.OpsWithSigv4AndSigv4ASignedPayloadRequestMarshaller; import software.amazon.awssdk.services.database.transform.PutRowRequestMarshaller; import software.amazon.awssdk.services.database.transform.SecondOpsWithSigv4AndSigv4ASignedPayloadRequestMarshaller; +import software.amazon.awssdk.utils.CollectionUtils; import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link DatabaseClient}. @@ -155,6 +176,8 @@ public DeleteRowResponse deleteRow(DeleteRowRequest deleteRowRequest) throws Inv .withOperationName("DeleteRow").withProtocolMetadata(protocolMetadata).withResponseHandler(responseHandler) .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) .withInput(deleteRowRequest).withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "DeleteRow", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "DeleteRow")) .withMarshaller(new DeleteRowRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -216,6 +239,8 @@ public GetRowResponse getRow(GetRowRequest getRowRequest) throws InvalidInputExc .withProtocolMetadata(protocolMetadata).withResponseHandler(responseHandler) .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) .withInput(getRowRequest).withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "GetRow", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetRow")) .withMarshaller(new GetRowRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -277,6 +302,8 @@ public PutRowResponse putRow(PutRowRequest putRowRequest) throws InvalidInputExc .withProtocolMetadata(protocolMetadata).withResponseHandler(responseHandler) .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) .withInput(putRowRequest).withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "PutRow", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PutRow")) .withMarshaller(new PutRowRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -339,10 +366,16 @@ public OpWithSigv4AndSigv4AUnSignedPayloadResponse opWithSigv4AndSigv4aUnSignedP return clientHandler .execute(new ClientExecutionParams() - .withOperationName("opWithSigv4AndSigv4aUnSignedPayload").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(opWithSigv4AndSigv4AUnSignedPayloadRequest) + .withOperationName("opWithSigv4AndSigv4aUnSignedPayload") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(opWithSigv4AndSigv4AUnSignedPayloadRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opWithSigv4AndSigv4aUnSignedPayload", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opWithSigv4AndSigv4aUnSignedPayload")) .withMarshaller(new OpWithSigv4AndSigv4AUnSignedPayloadRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -404,10 +437,16 @@ public OpWithSigv4SignedPayloadResponse opWithSigv4SignedPayload( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("opWithSigv4SignedPayload").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(opWithSigv4SignedPayloadRequest) + .withOperationName("opWithSigv4SignedPayload") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(opWithSigv4SignedPayloadRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opWithSigv4SignedPayload", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opWithSigv4SignedPayload")) .withMarshaller(new OpWithSigv4SignedPayloadRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -469,10 +508,16 @@ public OpWithSigv4UnSignedPayloadResponse opWithSigv4UnSignedPayload( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("opWithSigv4UnSignedPayload").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(opWithSigv4UnSignedPayloadRequest) + .withOperationName("opWithSigv4UnSignedPayload") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(opWithSigv4UnSignedPayloadRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opWithSigv4UnSignedPayload", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opWithSigv4UnSignedPayload")) .withMarshaller(new OpWithSigv4UnSignedPayloadRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -553,6 +598,10 @@ public OpWithSigv4UnSignedPayloadAndStreamingResponse opWithSigv4UnSignedPayload .withRequestConfiguration(clientConfiguration) .withInput(opWithSigv4UnSignedPayloadAndStreamingRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opWithSigv4UnSignedPayloadAndStreaming", + clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opWithSigv4UnSignedPayloadAndStreaming")) .withRequestBody(requestBody) .withMarshaller( StreamingRequestMarshaller @@ -620,10 +669,16 @@ public OpWithSigv4ASignedPayloadResponse opWithSigv4aSignedPayload( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("opWithSigv4aSignedPayload").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(opWithSigv4ASignedPayloadRequest) + .withOperationName("opWithSigv4aSignedPayload") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(opWithSigv4ASignedPayloadRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opWithSigv4aSignedPayload", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opWithSigv4aSignedPayload")) .withMarshaller(new OpWithSigv4ASignedPayloadRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -685,10 +740,16 @@ public OpWithSigv4AUnSignedPayloadResponse opWithSigv4aUnSignedPayload( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("opWithSigv4aUnSignedPayload").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(opWithSigv4AUnSignedPayloadRequest) + .withOperationName("opWithSigv4aUnSignedPayload") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(opWithSigv4AUnSignedPayloadRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opWithSigv4aUnSignedPayload", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opWithSigv4aUnSignedPayload")) .withMarshaller(new OpWithSigv4AUnSignedPayloadRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -751,10 +812,16 @@ public OpsWithSigv4AndSigv4ASignedPayloadResponse opsWithSigv4andSigv4aSignedPay return clientHandler .execute(new ClientExecutionParams() - .withOperationName("opsWithSigv4andSigv4aSignedPayload").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(opsWithSigv4AndSigv4ASignedPayloadRequest) + .withOperationName("opsWithSigv4andSigv4aSignedPayload") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(opsWithSigv4AndSigv4ASignedPayloadRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "opsWithSigv4andSigv4aSignedPayload", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "opsWithSigv4andSigv4aSignedPayload")) .withMarshaller(new OpsWithSigv4AndSigv4ASignedPayloadRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -817,11 +884,17 @@ public SecondOpsWithSigv4AndSigv4ASignedPayloadResponse secondOpsWithSigv4andSig return clientHandler .execute(new ClientExecutionParams() - .withOperationName("secondOpsWithSigv4andSigv4aSignedPayload").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) + .withOperationName("secondOpsWithSigv4andSigv4aSignedPayload") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withInput(secondOpsWithSigv4AndSigv4ASignedPayloadRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "secondOpsWithSigv4andSigv4aSignedPayload", + clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "secondOpsWithSigv4andSigv4aSignedPayload")) .withMarshaller(new SecondOpsWithSigv4AndSigv4ASignedPayloadRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -848,6 +921,61 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + DatabaseAuthSchemeProvider authSchemeProvider = Validate.isInstanceOf(DatabaseAuthSchemeProvider.class, + clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of DatabaseAuthSchemeProvider"); + DatabaseAuthSchemeParams.Builder paramsBuilder = DatabaseAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + Set sigv4aRegionSet = clientConfiguration.option(AwsClientOption.AWS_SIGV4A_SIGNING_REGION_SET); + if (!CollectionUtils.isNullOrEmpty(sigv4aRegionSet)) { + paramsBuilder.regionSet(RegionSet.create(sigv4aRegionSet)); + } + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + DatabaseEndpointProvider provider = (DatabaseEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + DatabaseEndpointParams endpointParams = DatabaseEndpointResolverUtils.ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = DatabaseEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = DatabaseEndpointResolverUtils.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, + selectedAuthScheme); + // Precedence of SigV4a RegionSet is set according to multi-auth SigV4a specifications + if (selectedAuthScheme.authSchemeOption().schemeId().equals(AwsV4aAuthScheme.SCHEME_ID) + && selectedAuthScheme.authSchemeOption().signerProperty(AwsV4aHttpSigner.REGION_SET) == null) { + AuthSchemeOption.Builder optionBuilder = selectedAuthScheme.authSchemeOption().toBuilder(); + RegionSet rs = RegionSet.create(endpointParams.region().id()); + optionBuilder.putSignerProperty(AwsV4aHttpSigner.REGION_SET, rs); + selectedAuthScheme = new SelectedAuthScheme(selectedAuthScheme.identity(), selectedAuthScheme.signer(), + optionBuilder.build()); + } + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + DatabaseEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private HttpResponseHandler createErrorResponseHandler(BaseAwsJsonProtocolFactory protocolFactory, JsonOperationMetadata operationMetadata, Function> exceptionMetadataMapper) { return protocolFactory.createErrorResponseHandler(operationMetadata, exceptionMetadataMapper); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-xml-async-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-xml-async-client-class.java index a7d0c7b4a984..a60d2b04556e 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-xml-async-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-xml-async-client-class.java @@ -4,14 +4,20 @@ import java.util.Collections; import java.util.List; +import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.Executor; import java.util.function.Consumer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.eventstream.EventStreamAsyncResponseTransformer; import software.amazon.awssdk.awscore.eventstream.EventStreamTaggedUnionPojoSupplier; import software.amazon.awssdk.awscore.eventstream.RestEventStreamAsyncResponseTransformer; @@ -26,6 +32,7 @@ import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkPojoBuilder; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.core.async.AsyncResponseTransformerUtils; @@ -35,7 +42,9 @@ import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.client.handler.AsyncClientHandler; import software.amazon.awssdk.core.client.handler.ClientExecutionParams; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.interceptor.trait.HttpChecksum; import software.amazon.awssdk.core.interceptor.trait.HttpChecksumRequired; @@ -43,6 +52,8 @@ import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; import software.amazon.awssdk.core.runtime.transform.AsyncStreamingRequestMarshaller; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -50,6 +61,11 @@ import software.amazon.awssdk.protocols.xml.AwsXmlProtocolFactory; import software.amazon.awssdk.protocols.xml.XmlOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.xml.auth.scheme.XmlAuthSchemeParams; +import software.amazon.awssdk.services.xml.auth.scheme.XmlAuthSchemeProvider; +import software.amazon.awssdk.services.xml.endpoints.XmlEndpointParams; +import software.amazon.awssdk.services.xml.endpoints.XmlEndpointProvider; +import software.amazon.awssdk.services.xml.endpoints.internal.XmlEndpointResolverUtils; import software.amazon.awssdk.services.xml.internal.ServiceVersionInfo; import software.amazon.awssdk.services.xml.internal.XmlServiceClientConfigurationBuilder; import software.amazon.awssdk.services.xml.model.APostOperationRequest; @@ -91,6 +107,7 @@ import software.amazon.awssdk.services.xml.transform.StreamingOutputOperationRequestMarshaller; import software.amazon.awssdk.utils.CompletableFutureUtils; import software.amazon.awssdk.utils.Pair; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link XmlAsyncClient}. @@ -164,11 +181,17 @@ public CompletableFuture aPostOperation(APostOperationRe CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperation").withRequestConfiguration(clientConfiguration) + .withOperationName("APostOperation") + .withRequestConfiguration(clientConfiguration) .withProtocolMetadata(protocolMetadata) .withMarshaller(new APostOperationRequestMarshaller(protocolFactory)) - .withCombinedResponseHandler(responseHandler).hostPrefixExpression(resolvedHostExpression) - .withMetricCollector(apiCallMetricCollector).withInput(aPostOperationRequest)); + .withCombinedResponseHandler(responseHandler) + .hostPrefixExpression(resolvedHostExpression) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperation")) + .withInput(aPostOperationRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -224,10 +247,15 @@ public CompletableFuture aPostOperationWithOut CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperationWithOutput").withRequestConfiguration(clientConfiguration) + .withOperationName("APostOperationWithOutput") + .withRequestConfiguration(clientConfiguration) .withProtocolMetadata(protocolMetadata) .withMarshaller(new APostOperationWithOutputRequestMarshaller(protocolFactory)) - .withCombinedResponseHandler(responseHandler).withMetricCollector(apiCallMetricCollector) + .withCombinedResponseHandler(responseHandler) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperationWithOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperationWithOutput")) .withInput(aPostOperationWithOutputRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { @@ -280,11 +308,17 @@ public CompletableFuture bearerAuthOperation( CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("BearerAuthOperation").withRequestConfiguration(clientConfiguration) + .withOperationName("BearerAuthOperation") + .withRequestConfiguration(clientConfiguration) .withProtocolMetadata(protocolMetadata) .withMarshaller(new BearerAuthOperationRequestMarshaller(protocolFactory)) - .withCombinedResponseHandler(responseHandler).credentialType(CredentialType.TOKEN) - .withMetricCollector(apiCallMetricCollector).withInput(bearerAuthOperationRequest)); + .withCombinedResponseHandler(responseHandler) + .credentialType(CredentialType.TOKEN) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "BearerAuthOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "BearerAuthOperation")) + .withInput(bearerAuthOperationRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -353,12 +387,17 @@ public CompletableFuture eventStreamOperation(EventStreamOperationRequest CompletableFuture executeFuture = clientHandler.execute( new ClientExecutionParams() - .withOperationName("EventStreamOperation").withRequestConfiguration(clientConfiguration) + .withOperationName("EventStreamOperation") + .withRequestConfiguration(clientConfiguration) .withProtocolMetadata(protocolMetadata) .withMarshaller(new EventStreamOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withMetricCollector(apiCallMetricCollector).withInput(eventStreamOperationRequest), - restAsyncResponseTransformer); + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "EventStreamOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "EventStreamOperation")) + .withInput(eventStreamOperationRequest), restAsyncResponseTransformer); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { if (e != null) { @@ -423,6 +462,9 @@ public CompletableFuture getOperationWithCheck .withMarshaller(new GetOperationWithChecksumRequestMarshaller(protocolFactory)) .withCombinedResponseHandler(responseHandler) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "GetOperationWithChecksum", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetOperationWithChecksum")) .putExecutionAttribute( SdkInternalExecutionAttribute.HTTP_CHECKSUM, HttpChecksum.builder().requestChecksumRequired(true).isRequestStreaming(false) @@ -487,6 +529,9 @@ public CompletableFuture operationWithChe .withMarshaller(new OperationWithChecksumRequiredRequestMarshaller(protocolFactory)) .withCombinedResponseHandler(responseHandler) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithChecksumRequired", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithChecksumRequired")) .putExecutionAttribute(SdkInternalExecutionAttribute.HTTP_CHECKSUM_REQUIRED, HttpChecksumRequired.create()).withInput(operationWithChecksumRequiredRequest)); CompletableFuture whenCompleteFuture = null; @@ -540,10 +585,15 @@ public CompletableFuture operationWithNoneAut CompletableFuture executeFuture = clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithNoneAuthType").withRequestConfiguration(clientConfiguration) + .withOperationName("OperationWithNoneAuthType") + .withRequestConfiguration(clientConfiguration) .withProtocolMetadata(protocolMetadata) .withMarshaller(new OperationWithNoneAuthTypeRequestMarshaller(protocolFactory)) - .withCombinedResponseHandler(responseHandler).withMetricCollector(apiCallMetricCollector) + .withCombinedResponseHandler(responseHandler) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithNoneAuthType", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithNoneAuthType")) .withInput(operationWithNoneAuthTypeRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { @@ -603,6 +653,9 @@ public CompletableFuture operationWithR .withMarshaller(new OperationWithRequestCompressionRequestMarshaller(protocolFactory)) .withCombinedResponseHandler(responseHandler) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithRequestCompression", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithRequestCompression")) .putExecutionAttribute(SdkInternalExecutionAttribute.REQUEST_COMPRESSION, RequestCompression.builder().encodings("gzip").isStreaming(false).build()) .withInput(operationWithRequestCompressionRequest)); @@ -691,6 +744,9 @@ public CompletableFuture putOperationWithChecksum( .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PutOperationWithChecksum", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PutOperationWithChecksum")) .putExecutionAttribute( SdkInternalExecutionAttribute.HTTP_CHECKSUM, HttpChecksum @@ -777,9 +833,13 @@ public CompletableFuture streamingInputOperatio .withMarshaller( AsyncStreamingRequestMarshaller.builder() .delegateMarshaller(new StreamingInputOperationRequestMarshaller(protocolFactory)) - .asyncRequestBody(requestBody).build()).withCombinedResponseHandler(responseHandler) - .withMetricCollector(apiCallMetricCollector).withAsyncRequestBody(requestBody) - .withInput(streamingInputOperationRequest)); + .asyncRequestBody(requestBody).build()) + .withCombinedResponseHandler(responseHandler) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingInputOperation")) + .withAsyncRequestBody(requestBody).withInput(streamingInputOperationRequest)); CompletableFuture whenCompleteFuture = null; whenCompleteFuture = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -842,10 +902,16 @@ public CompletableFuture streamingOutputOperation( CompletableFuture executeFuture = clientHandler.execute( new ClientExecutionParams() - .withOperationName("StreamingOutputOperation").withProtocolMetadata(protocolMetadata) + .withOperationName("StreamingOutputOperation") + .withProtocolMetadata(protocolMetadata) .withMarshaller(new StreamingOutputOperationRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingOutputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingOutputOperation")) .withAsyncResponseTransformer(asyncResponseTransformer).withInput(streamingOutputOperationRequest), asyncResponseTransformer); CompletableFuture whenCompleteFuture = null; @@ -903,6 +969,48 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + XmlAuthSchemeProvider authSchemeProvider = Validate + .isInstanceOf(XmlAuthSchemeProvider.class, clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of XmlAuthSchemeProvider"); + XmlAuthSchemeParams.Builder paramsBuilder = XmlAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + XmlEndpointProvider provider = (XmlEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + XmlEndpointParams endpointParams = XmlEndpointResolverUtils.ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = XmlEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = XmlEndpointResolverUtils.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, + selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + XmlEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-xml-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-xml-client-class.java index 627684e47ff2..9706ee65f074 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-xml-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-xml-client-class.java @@ -2,10 +2,16 @@ import java.util.Collections; import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletionException; import java.util.function.Consumer; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; import software.amazon.awssdk.awscore.internal.AwsServiceProtocol; @@ -16,6 +22,7 @@ import software.amazon.awssdk.core.Response; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; @@ -23,6 +30,7 @@ import software.amazon.awssdk.core.client.handler.SyncClientHandler; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.interceptor.trait.HttpChecksum; import software.amazon.awssdk.core.interceptor.trait.HttpChecksumRequired; @@ -32,6 +40,8 @@ import software.amazon.awssdk.core.runtime.transform.StreamingRequestMarshaller; import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.core.sync.ResponseTransformer; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -39,6 +49,11 @@ import software.amazon.awssdk.protocols.xml.AwsXmlProtocolFactory; import software.amazon.awssdk.protocols.xml.XmlOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.xml.auth.scheme.XmlAuthSchemeParams; +import software.amazon.awssdk.services.xml.auth.scheme.XmlAuthSchemeProvider; +import software.amazon.awssdk.services.xml.endpoints.XmlEndpointParams; +import software.amazon.awssdk.services.xml.endpoints.XmlEndpointProvider; +import software.amazon.awssdk.services.xml.endpoints.internal.XmlEndpointResolverUtils; import software.amazon.awssdk.services.xml.internal.ServiceVersionInfo; import software.amazon.awssdk.services.xml.internal.XmlServiceClientConfigurationBuilder; import software.amazon.awssdk.services.xml.model.APostOperationRequest; @@ -74,6 +89,7 @@ import software.amazon.awssdk.services.xml.transform.StreamingInputOperationRequestMarshaller; import software.amazon.awssdk.services.xml.transform.StreamingOutputOperationRequestMarshaller; import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; /** * Internal implementation of {@link XmlClient}. @@ -142,7 +158,10 @@ public APostOperationResponse aPostOperation(APostOperationRequest aPostOperatio .withOperationName("APostOperation").withProtocolMetadata(protocolMetadata) .withCombinedResponseHandler(responseHandler).withMetricCollector(apiCallMetricCollector) .hostPrefixExpression(resolvedHostExpression).withRequestConfiguration(clientConfiguration) - .withInput(aPostOperationRequest).withMarshaller(new APostOperationRequestMarshaller(protocolFactory))); + .withInput(aPostOperationRequest) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "APostOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperation")) + .withMarshaller(new APostOperationRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); } @@ -188,9 +207,15 @@ public APostOperationWithOutputResponse aPostOperationWithOutput( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperationWithOutput").withProtocolMetadata(protocolMetadata) - .withCombinedResponseHandler(responseHandler).withMetricCollector(apiCallMetricCollector) - .withRequestConfiguration(clientConfiguration).withInput(aPostOperationWithOutputRequest) + .withOperationName("APostOperationWithOutput") + .withProtocolMetadata(protocolMetadata) + .withCombinedResponseHandler(responseHandler) + .withMetricCollector(apiCallMetricCollector) + .withRequestConfiguration(clientConfiguration) + .withInput(aPostOperationWithOutputRequest) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperationWithOutput", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "APostOperationWithOutput")) .withMarshaller(new APostOperationWithOutputRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -235,6 +260,8 @@ public BearerAuthOperationResponse bearerAuthOperation(BearerAuthOperationReques .withCombinedResponseHandler(responseHandler).withMetricCollector(apiCallMetricCollector) .credentialType(CredentialType.TOKEN).withRequestConfiguration(clientConfiguration) .withInput(bearerAuthOperationRequest) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "BearerAuthOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "BearerAuthOperation")) .withMarshaller(new BearerAuthOperationRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -283,6 +310,9 @@ public GetOperationWithChecksumResponse getOperationWithChecksum( .withMetricCollector(apiCallMetricCollector) .withRequestConfiguration(clientConfiguration) .withInput(getOperationWithChecksumRequest) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "GetOperationWithChecksum", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "GetOperationWithChecksum")) .putExecutionAttribute( SdkInternalExecutionAttribute.HTTP_CHECKSUM, HttpChecksum.builder().requestChecksumRequired(true).isRequestStreaming(false) @@ -336,6 +366,9 @@ public OperationWithChecksumRequiredResponse operationWithChecksumRequired( .withMetricCollector(apiCallMetricCollector) .withRequestConfiguration(clientConfiguration) .withInput(operationWithChecksumRequiredRequest) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithChecksumRequired", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithChecksumRequired")) .putExecutionAttribute(SdkInternalExecutionAttribute.HTTP_CHECKSUM_REQUIRED, HttpChecksumRequired.create()) .withMarshaller(new OperationWithChecksumRequiredRequestMarshaller(protocolFactory))); @@ -380,9 +413,15 @@ public OperationWithNoneAuthTypeResponse operationWithNoneAuthType( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithNoneAuthType").withProtocolMetadata(protocolMetadata) - .withCombinedResponseHandler(responseHandler).withMetricCollector(apiCallMetricCollector) - .withRequestConfiguration(clientConfiguration).withInput(operationWithNoneAuthTypeRequest) + .withOperationName("OperationWithNoneAuthType") + .withProtocolMetadata(protocolMetadata) + .withCombinedResponseHandler(responseHandler) + .withMetricCollector(apiCallMetricCollector) + .withRequestConfiguration(clientConfiguration) + .withInput(operationWithNoneAuthTypeRequest) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithNoneAuthType", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithNoneAuthType")) .withMarshaller(new OperationWithNoneAuthTypeRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -431,6 +470,9 @@ public OperationWithRequestCompressionResponse operationWithRequestCompression( .withMetricCollector(apiCallMetricCollector) .withRequestConfiguration(clientConfiguration) .withInput(operationWithRequestCompressionRequest) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithRequestCompression", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "OperationWithRequestCompression")) .putExecutionAttribute(SdkInternalExecutionAttribute.REQUEST_COMPRESSION, RequestCompression.builder().encodings("gzip").isStreaming(false).build()) .withMarshaller(new OperationWithRequestCompressionRequestMarshaller(protocolFactory))); @@ -509,6 +551,9 @@ public ReturnT putOperationWithChecksum(PutOperationWithChecksumReques .withRequestConfiguration(clientConfiguration) .withInput(putOperationWithChecksumRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PutOperationWithChecksum", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "PutOperationWithChecksum")) .putExecutionAttribute( SdkInternalExecutionAttribute.HTTP_CHECKSUM, HttpChecksum @@ -585,6 +630,9 @@ public StreamingInputOperationResponse streamingInputOperation(StreamingInputOpe .withMetricCollector(apiCallMetricCollector) .withRequestConfiguration(clientConfiguration) .withInput(streamingInputOperationRequest) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingInputOperation")) .withRequestBody(requestBody) .withMarshaller( StreamingRequestMarshaller.builder() @@ -639,10 +687,17 @@ public ReturnT streamingOutputOperation(StreamingOutputOperationReques return clientHandler.execute( new ClientExecutionParams() - .withOperationName("StreamingOutputOperation").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(streamingOutputOperationRequest) - .withMetricCollector(apiCallMetricCollector).withResponseTransformer(responseTransformer) + .withOperationName("StreamingOutputOperation") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(streamingOutputOperationRequest) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingOutputOperation", clientConfiguration)) + .withEndpointResolver((r, a) -> resolveEndpoint(r, a, "StreamingOutputOperation")) + .withResponseTransformer(responseTransformer) .withMarshaller(new StreamingOutputOperationRequestMarshaller(protocolFactory)), responseTransformer); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -669,6 +724,48 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + XmlAuthSchemeProvider authSchemeProvider = Validate + .isInstanceOf(XmlAuthSchemeProvider.class, clientConfiguration.option(SdkClientOption.AUTH_SCHEME_PROVIDER), + "Expected an instance of XmlAuthSchemeProvider"); + XmlAuthSchemeParams.Builder paramsBuilder = XmlAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + List options = authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + return options; + } + + private Endpoint resolveEndpoint(SdkRequest request, ExecutionAttributes executionAttributes, String operationName) { + XmlEndpointProvider provider = (XmlEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + XmlEndpointParams endpointParams = XmlEndpointResolverUtils.ruleParams(request, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = XmlEndpointResolverUtils.hostPrefix(operationName, request); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = XmlEndpointResolverUtils.authSchemeWithEndpointSignerProperties(endpointAuthSchemes, + selectedAuthScheme); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + XmlEndpointResolverUtils.setMetricValues(endpoint, executionAttributes); + return endpoint; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } + throw SdkClientException.create("Endpoint resolution failed: " + cause.getMessage(), cause); + } + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolver-utils-with-endpointsbasedauth.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolver-utils-with-endpointsbasedauth.java index 4c4d67aebfbb..ad3b2e674750 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolver-utils-with-endpointsbasedauth.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolver-utils-with-endpointsbasedauth.java @@ -7,10 +7,10 @@ import software.amazon.awssdk.awscore.AwsExecutionAttribute; import software.amazon.awssdk.awscore.endpoints.AccountIdEndpointMode; import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.endpoints.authscheme.SigV4AuthScheme; import software.amazon.awssdk.awscore.endpoints.authscheme.SigV4aAuthScheme; -import software.amazon.awssdk.awscore.internal.endpoints.AwsEndpointProviderUtils; import software.amazon.awssdk.awscore.internal.useragent.BusinessMetricsUtils; import software.amazon.awssdk.core.SdkRequest; import software.amazon.awssdk.core.SelectedAuthScheme; @@ -41,8 +41,7 @@ public final class QueryEndpointResolverUtils { private QueryEndpointResolverUtils() { } - public static QueryEndpointParams ruleParams(SdkRequest request, - ExecutionAttributes executionAttributes) { + public static QueryEndpointParams ruleParams(SdkRequest request, ExecutionAttributes executionAttributes) { QueryEndpointParams.Builder builder = QueryEndpointParams.builder(); builder.region(AwsEndpointProviderUtils.regionBuiltIn(executionAttributes)); builder.useDualStackEndpoint(AwsEndpointProviderUtils.dualStackEnabledBuiltIn(executionAttributes)); @@ -56,31 +55,31 @@ public static QueryEndpointParams ruleParams(SdkRequest request, return builder.build(); } - private static void setContextParams(QueryEndpointParams.Builder params, String operationName, - SdkRequest request) { + private static void setContextParams(QueryEndpointParams.Builder params, String operationName, SdkRequest request) { switch (operationName) { - case "OperationWithContextParam":setContextParams(params, (OperationWithContextParamRequest) request); - break; - default:break; + case "OperationWithContextParam": + setContextParams(params, (OperationWithContextParamRequest) request); + break; + default: + break; } } - private static void setContextParams(QueryEndpointParams.Builder params, - OperationWithContextParamRequest request) { + private static void setContextParams(QueryEndpointParams.Builder params, OperationWithContextParamRequest request) { params.operationContextParam(request.stringMember()); } - private static void setStaticContextParams(QueryEndpointParams.Builder params, - String operationName) { + private static void setStaticContextParams(QueryEndpointParams.Builder params, String operationName) { switch (operationName) { - case "OperationWithStaticContextParams":operationWithStaticContextParamsStaticContextParams(params); - break; - default:break; + case "OperationWithStaticContextParams": + operationWithStaticContextParamsStaticContextParams(params); + break; + default: + break; } } - private static void operationWithStaticContextParamsStaticContextParams( - QueryEndpointParams.Builder params) { + private static void operationWithStaticContextParamsStaticContextParams(QueryEndpointParams.Builder params) { params.staticStringParam("hello"); } @@ -109,7 +108,9 @@ public static SelectedAuthScheme authSchemeWithEndpointS if (v4aAuthScheme.isDisableDoubleEncodingSet()) { option.putSignerProperty(AwsV4aHttpSigner.DOUBLE_URL_ENCODE, !v4aAuthScheme.disableDoubleEncoding()); } - if (!(selectedAuthScheme.authSchemeOption().schemeId().equals(AwsV4aAuthScheme.SCHEME_ID) && selectedAuthScheme.authSchemeOption().signerProperty(AwsV4aHttpSigner.REGION_SET) != null) && !CollectionUtils.isNullOrEmpty(v4aAuthScheme.signingRegionSet())) { + if (!(selectedAuthScheme.authSchemeOption().schemeId().equals(AwsV4aAuthScheme.SCHEME_ID) && selectedAuthScheme + .authSchemeOption().signerProperty(AwsV4aHttpSigner.REGION_SET) != null) + && !CollectionUtils.isNullOrEmpty(v4aAuthScheme.signingRegionSet())) { RegionSet regionSet = RegionSet.create(v4aAuthScheme.signingRegionSet()); option.putSignerProperty(AwsV4aHttpSigner.REGION_SET, regionSet); } @@ -118,37 +119,41 @@ public static SelectedAuthScheme authSchemeWithEndpointS } return new SelectedAuthScheme<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build()); } - throw new IllegalArgumentException("Endpoint auth scheme '" + endpointAuthScheme.name() + "' cannot be mapped to the SDK auth scheme. Was it declared in the service's model?"); + throw new IllegalArgumentException("Endpoint auth scheme '" + endpointAuthScheme.name() + + "' cannot be mapped to the SDK auth scheme. Was it declared in the service's model?"); } return selectedAuthScheme; } - private static void setClientContextParams(QueryEndpointParams.Builder params, - ExecutionAttributes executionAttributes) { + private static void setClientContextParams(QueryEndpointParams.Builder params, ExecutionAttributes executionAttributes) { AttributeMap clientContextParams = executionAttributes.getAttribute(SdkInternalExecutionAttribute.CLIENT_CONTEXT_PARAMS); - Optional.ofNullable(clientContextParams.get(QueryClientContextParams.BOOLEAN_CONTEXT_PARAM)).ifPresent(params::booleanContextParam); - Optional.ofNullable(clientContextParams.get(QueryClientContextParams.STRING_CONTEXT_PARAM)).ifPresent(params::stringContextParam); + Optional.ofNullable(clientContextParams.get(QueryClientContextParams.BOOLEAN_CONTEXT_PARAM)).ifPresent( + params::booleanContextParam); + Optional.ofNullable(clientContextParams.get(QueryClientContextParams.STRING_CONTEXT_PARAM)).ifPresent( + params::stringContextParam); } - private static void setOperationContextParams(QueryEndpointParams.Builder params, - String operationName, SdkRequest request) { + private static void setOperationContextParams(QueryEndpointParams.Builder params, String operationName, SdkRequest request) { switch (operationName) { - case "OperationWithMapOperationContextParam":setOperationContextParams(params, (OperationWithMapOperationContextParamRequest) request); - break; - case "OperationWithOperationContextParam":setOperationContextParams(params, (OperationWithOperationContextParamRequest) request); - break; - default:break; + case "OperationWithMapOperationContextParam": + setOperationContextParams(params, (OperationWithMapOperationContextParamRequest) request); + break; + case "OperationWithOperationContextParam": + setOperationContextParams(params, (OperationWithOperationContextParamRequest) request); + break; + default: + break; } } private static void setOperationContextParams(QueryEndpointParams.Builder params, - OperationWithMapOperationContextParamRequest request) { + OperationWithMapOperationContextParamRequest request) { JmesPathRuntime.Value input = new JmesPathRuntime.Value(request); params.arnList(input.field("RequestMap").keys()); } private static void setOperationContextParams(QueryEndpointParams.Builder params, - OperationWithOperationContextParamRequest request) { + OperationWithOperationContextParamRequest request) { JmesPathRuntime.Value input = new JmesPathRuntime.Value(request); params.customEndpointArray(input.field("ListMember").field("StringList").wildcard().field("LeafString")); } @@ -158,21 +163,22 @@ public static Optional hostPrefix(String operationName, SdkRequest reque case "APostOperation": { return Optional.of("foo-"); } - default:return Optional.empty(); + default: + return Optional.empty(); } } - private static String resolveAndRecordAccountIdFromIdentity( - ExecutionAttributes executionAttributes) { - String accountId = accountIdFromIdentity(executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME)); + private static String resolveAndRecordAccountIdFromIdentity(ExecutionAttributes executionAttributes) { + String accountId = accountIdFromIdentity(executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME)); if (accountId != null) { - executionAttributes.getAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS).addMetric(BusinessMetricFeatureId.RESOLVED_ACCOUNT_ID.value()); + executionAttributes.getAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS).addMetric( + BusinessMetricFeatureId.RESOLVED_ACCOUNT_ID.value()); } return accountId; } - private static String accountIdFromIdentity( - SelectedAuthScheme selectedAuthScheme) { + private static String accountIdFromIdentity(SelectedAuthScheme selectedAuthScheme) { T identity = CompletableFutureUtils.joinLikeSync(selectedAuthScheme.identity()); String accountId = null; if (identity instanceof AwsCredentialsIdentity) { @@ -183,13 +189,15 @@ private static String accountIdFromIdentity( private static String recordAccountIdEndpointMode(ExecutionAttributes executionAttributes) { AccountIdEndpointMode mode = executionAttributes.getAttribute(AwsExecutionAttribute.AWS_AUTH_ACCOUNT_ID_ENDPOINT_MODE); - BusinessMetricsUtils.resolveAccountIdEndpointModeMetric(mode).ifPresent(m -> executionAttributes.getAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS).addMetric(m)); + BusinessMetricsUtils.resolveAccountIdEndpointModeMetric(mode).ifPresent( + m -> executionAttributes.getAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS).addMetric(m)); return mode.name().toLowerCase(); } public static void setMetricValues(Endpoint endpoint, ExecutionAttributes executionAttributes) { if (endpoint.attribute(AwsEndpointAttribute.METRIC_VALUES) != null) { - executionAttributes.getOptionalAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS).ifPresent(metrics -> endpoint.attribute(AwsEndpointAttribute.METRIC_VALUES).forEach(v -> metrics.addMetric(v))); + executionAttributes.getOptionalAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS).ifPresent( + metrics -> endpoint.attribute(AwsEndpointAttribute.METRIC_VALUES).forEach(v -> metrics.addMetric(v))); } } } diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolver-utils-with-multiauthsigv4a.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolver-utils-with-multiauthsigv4a.java index b37a93855a11..f4a83932a826 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolver-utils-with-multiauthsigv4a.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolver-utils-with-multiauthsigv4a.java @@ -6,10 +6,10 @@ import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.awscore.AwsExecutionAttribute; import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.endpoints.authscheme.SigV4AuthScheme; import software.amazon.awssdk.awscore.endpoints.authscheme.SigV4aAuthScheme; -import software.amazon.awssdk.awscore.internal.endpoints.AwsEndpointProviderUtils; import software.amazon.awssdk.core.SdkRequest; import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; @@ -30,8 +30,7 @@ public final class DatabaseEndpointResolverUtils { private DatabaseEndpointResolverUtils() { } - public static DatabaseEndpointParams ruleParams(SdkRequest request, - ExecutionAttributes executionAttributes) { + public static DatabaseEndpointParams ruleParams(SdkRequest request, ExecutionAttributes executionAttributes) { DatabaseEndpointParams.Builder builder = DatabaseEndpointParams.builder(); builder.region(AwsEndpointProviderUtils.regionBuiltIn(executionAttributes)); builder.endpoint(AwsEndpointProviderUtils.endpointBuiltIn(executionAttributes)); @@ -41,12 +40,10 @@ public static DatabaseEndpointParams ruleParams(SdkRequest request, return builder.build(); } - private static void setContextParams(DatabaseEndpointParams.Builder params, String operationName, - SdkRequest request) { + private static void setContextParams(DatabaseEndpointParams.Builder params, String operationName, SdkRequest request) { } - private static void setStaticContextParams(DatabaseEndpointParams.Builder params, - String operationName) { + private static void setStaticContextParams(DatabaseEndpointParams.Builder params, String operationName) { } public static SelectedAuthScheme authSchemeWithEndpointSignerProperties( @@ -74,7 +71,9 @@ public static SelectedAuthScheme authSchemeWithEndpointS if (v4aAuthScheme.isDisableDoubleEncodingSet()) { option.putSignerProperty(AwsV4aHttpSigner.DOUBLE_URL_ENCODE, !v4aAuthScheme.disableDoubleEncoding()); } - if (!(selectedAuthScheme.authSchemeOption().schemeId().equals(AwsV4aAuthScheme.SCHEME_ID) && selectedAuthScheme.authSchemeOption().signerProperty(AwsV4aHttpSigner.REGION_SET) != null) && !CollectionUtils.isNullOrEmpty(v4aAuthScheme.signingRegionSet())) { + if (!(selectedAuthScheme.authSchemeOption().schemeId().equals(AwsV4aAuthScheme.SCHEME_ID) && selectedAuthScheme + .authSchemeOption().signerProperty(AwsV4aHttpSigner.REGION_SET) != null) + && !CollectionUtils.isNullOrEmpty(v4aAuthScheme.signingRegionSet())) { RegionSet regionSet = RegionSet.create(v4aAuthScheme.signingRegionSet()); option.putSignerProperty(AwsV4aHttpSigner.REGION_SET, regionSet); } @@ -83,13 +82,13 @@ public static SelectedAuthScheme authSchemeWithEndpointS } return new SelectedAuthScheme<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build()); } - throw new IllegalArgumentException("Endpoint auth scheme '" + endpointAuthScheme.name() + "' cannot be mapped to the SDK auth scheme. Was it declared in the service's model?"); + throw new IllegalArgumentException("Endpoint auth scheme '" + endpointAuthScheme.name() + + "' cannot be mapped to the SDK auth scheme. Was it declared in the service's model?"); } return selectedAuthScheme; } - private static void setOperationContextParams(DatabaseEndpointParams.Builder params, - String operationName, SdkRequest request) { + private static void setOperationContextParams(DatabaseEndpointParams.Builder params, String operationName, SdkRequest request) { } public static Optional hostPrefix(String operationName, SdkRequest request) { @@ -98,7 +97,8 @@ public static Optional hostPrefix(String operationName, SdkRequest reque public static void setMetricValues(Endpoint endpoint, ExecutionAttributes executionAttributes) { if (endpoint.attribute(AwsEndpointAttribute.METRIC_VALUES) != null) { - executionAttributes.getOptionalAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS).ifPresent(metrics -> endpoint.attribute(AwsEndpointAttribute.METRIC_VALUES).forEach(v -> metrics.addMetric(v))); + executionAttributes.getOptionalAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS).ifPresent( + metrics -> endpoint.attribute(AwsEndpointAttribute.METRIC_VALUES).forEach(v -> metrics.addMetric(v))); } } } diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolver-utils.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolver-utils.java index 97ac44776a59..01c7ba43831f 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolver-utils.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolver-utils.java @@ -7,10 +7,10 @@ import software.amazon.awssdk.awscore.AwsExecutionAttribute; import software.amazon.awssdk.awscore.endpoints.AccountIdEndpointMode; import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointProviderUtils; import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; import software.amazon.awssdk.awscore.endpoints.authscheme.SigV4AuthScheme; import software.amazon.awssdk.awscore.endpoints.authscheme.SigV4aAuthScheme; -import software.amazon.awssdk.awscore.internal.endpoints.AwsEndpointProviderUtils; import software.amazon.awssdk.awscore.internal.useragent.BusinessMetricsUtils; import software.amazon.awssdk.core.SdkRequest; import software.amazon.awssdk.core.SelectedAuthScheme; @@ -41,8 +41,7 @@ public final class QueryEndpointResolverUtils { private QueryEndpointResolverUtils() { } - public static QueryEndpointParams ruleParams(SdkRequest request, - ExecutionAttributes executionAttributes) { + public static QueryEndpointParams ruleParams(SdkRequest request, ExecutionAttributes executionAttributes) { QueryEndpointParams.Builder builder = QueryEndpointParams.builder(); builder.region(AwsEndpointProviderUtils.regionBuiltIn(executionAttributes)); builder.useDualStackEndpoint(AwsEndpointProviderUtils.dualStackEnabledBuiltIn(executionAttributes)); @@ -56,31 +55,31 @@ public static QueryEndpointParams ruleParams(SdkRequest request, return builder.build(); } - private static void setContextParams(QueryEndpointParams.Builder params, String operationName, - SdkRequest request) { + private static void setContextParams(QueryEndpointParams.Builder params, String operationName, SdkRequest request) { switch (operationName) { - case "OperationWithContextParam":setContextParams(params, (OperationWithContextParamRequest) request); - break; - default:break; + case "OperationWithContextParam": + setContextParams(params, (OperationWithContextParamRequest) request); + break; + default: + break; } } - private static void setContextParams(QueryEndpointParams.Builder params, - OperationWithContextParamRequest request) { + private static void setContextParams(QueryEndpointParams.Builder params, OperationWithContextParamRequest request) { params.operationContextParam(request.stringMember()); } - private static void setStaticContextParams(QueryEndpointParams.Builder params, - String operationName) { + private static void setStaticContextParams(QueryEndpointParams.Builder params, String operationName) { switch (operationName) { - case "OperationWithStaticContextParams":operationWithStaticContextParamsStaticContextParams(params); - break; - default:break; + case "OperationWithStaticContextParams": + operationWithStaticContextParamsStaticContextParams(params); + break; + default: + break; } } - private static void operationWithStaticContextParamsStaticContextParams( - QueryEndpointParams.Builder params) { + private static void operationWithStaticContextParamsStaticContextParams(QueryEndpointParams.Builder params) { params.staticStringParam("hello"); } @@ -118,45 +117,50 @@ public static SelectedAuthScheme authSchemeWithEndpointS } return new SelectedAuthScheme<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build()); } - throw new IllegalArgumentException("Endpoint auth scheme '" + endpointAuthScheme.name() + "' cannot be mapped to the SDK auth scheme. Was it declared in the service's model?"); + throw new IllegalArgumentException("Endpoint auth scheme '" + endpointAuthScheme.name() + + "' cannot be mapped to the SDK auth scheme. Was it declared in the service's model?"); } return selectedAuthScheme; } - private static void setClientContextParams(QueryEndpointParams.Builder params, - ExecutionAttributes executionAttributes) { + private static void setClientContextParams(QueryEndpointParams.Builder params, ExecutionAttributes executionAttributes) { AttributeMap clientContextParams = executionAttributes.getAttribute(SdkInternalExecutionAttribute.CLIENT_CONTEXT_PARAMS); - Optional.ofNullable(clientContextParams.get(QueryClientContextParams.BOOLEAN_CONTEXT_PARAM)).ifPresent(params::booleanContextParam); - Optional.ofNullable(clientContextParams.get(QueryClientContextParams.STRING_CONTEXT_PARAM)).ifPresent(params::stringContextParam); + Optional.ofNullable(clientContextParams.get(QueryClientContextParams.BOOLEAN_CONTEXT_PARAM)).ifPresent( + params::booleanContextParam); + Optional.ofNullable(clientContextParams.get(QueryClientContextParams.STRING_CONTEXT_PARAM)).ifPresent( + params::stringContextParam); } - private static void setOperationContextParams(QueryEndpointParams.Builder params, - String operationName, SdkRequest request) { + private static void setOperationContextParams(QueryEndpointParams.Builder params, String operationName, SdkRequest request) { switch (operationName) { - case "OperationWithCustomizedOperationContextParam":setOperationContextParams(params, (OperationWithCustomizedOperationContextParamRequest) request); - break; - case "OperationWithMapOperationContextParam":setOperationContextParams(params, (OperationWithMapOperationContextParamRequest) request); - break; - case "OperationWithOperationContextParam":setOperationContextParams(params, (OperationWithOperationContextParamRequest) request); - break; - default:break; + case "OperationWithCustomizedOperationContextParam": + setOperationContextParams(params, (OperationWithCustomizedOperationContextParamRequest) request); + break; + case "OperationWithMapOperationContextParam": + setOperationContextParams(params, (OperationWithMapOperationContextParamRequest) request); + break; + case "OperationWithOperationContextParam": + setOperationContextParams(params, (OperationWithOperationContextParamRequest) request); + break; + default: + break; } } private static void setOperationContextParams(QueryEndpointParams.Builder params, - OperationWithCustomizedOperationContextParamRequest request) { + OperationWithCustomizedOperationContextParamRequest request) { JmesPathRuntime.Value input = new JmesPathRuntime.Value(request); params.customEndpointArray(input.field("ListMember").field("StringList").wildcard().field("LeafString").stringValues()); } private static void setOperationContextParams(QueryEndpointParams.Builder params, - OperationWithMapOperationContextParamRequest request) { + OperationWithMapOperationContextParamRequest request) { JmesPathRuntime.Value input = new JmesPathRuntime.Value(request); params.arnList(input.field("RequestMap").keys().stringValues()); } private static void setOperationContextParams(QueryEndpointParams.Builder params, - OperationWithOperationContextParamRequest request) { + OperationWithOperationContextParamRequest request) { JmesPathRuntime.Value input = new JmesPathRuntime.Value(request); params.customEndpointArray(input.field("ListMember").field("StringList").wildcard().field("LeafString").stringValues()); } @@ -166,21 +170,22 @@ public static Optional hostPrefix(String operationName, SdkRequest reque case "APostOperation": { return Optional.of("foo-"); } - default:return Optional.empty(); + default: + return Optional.empty(); } } - private static String resolveAndRecordAccountIdFromIdentity( - ExecutionAttributes executionAttributes) { - String accountId = accountIdFromIdentity(executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME)); + private static String resolveAndRecordAccountIdFromIdentity(ExecutionAttributes executionAttributes) { + String accountId = accountIdFromIdentity(executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME)); if (accountId != null) { - executionAttributes.getAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS).addMetric(BusinessMetricFeatureId.RESOLVED_ACCOUNT_ID.value()); + executionAttributes.getAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS).addMetric( + BusinessMetricFeatureId.RESOLVED_ACCOUNT_ID.value()); } return accountId; } - private static String accountIdFromIdentity( - SelectedAuthScheme selectedAuthScheme) { + private static String accountIdFromIdentity(SelectedAuthScheme selectedAuthScheme) { T identity = CompletableFutureUtils.joinLikeSync(selectedAuthScheme.identity()); String accountId = null; if (identity instanceof AwsCredentialsIdentity) { @@ -191,13 +196,15 @@ private static String accountIdFromIdentity( private static String recordAccountIdEndpointMode(ExecutionAttributes executionAttributes) { AccountIdEndpointMode mode = executionAttributes.getAttribute(AwsExecutionAttribute.AWS_AUTH_ACCOUNT_ID_ENDPOINT_MODE); - BusinessMetricsUtils.resolveAccountIdEndpointModeMetric(mode).ifPresent(m -> executionAttributes.getAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS).addMetric(m)); + BusinessMetricsUtils.resolveAccountIdEndpointModeMetric(mode).ifPresent( + m -> executionAttributes.getAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS).addMetric(m)); return mode.name().toLowerCase(); } public static void setMetricValues(Endpoint endpoint, ExecutionAttributes executionAttributes) { if (endpoint.attribute(AwsEndpointAttribute.METRIC_VALUES) != null) { - executionAttributes.getOptionalAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS).ifPresent(metrics -> endpoint.attribute(AwsEndpointAttribute.METRIC_VALUES).forEach(v -> metrics.addMetric(v))); + executionAttributes.getOptionalAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS).ifPresent( + metrics -> endpoint.attribute(AwsEndpointAttribute.METRIC_VALUES).forEach(v -> metrics.addMetric(v))); } } } From 059c60c5c48a17a523d60ee0787308edebc19dd4 Mon Sep 17 00:00:00 2001 From: Saranya Somepalli Date: Mon, 20 Apr 2026 05:20:53 -0700 Subject: [PATCH 2/2] Fix NPE in EndpointResolutionStage when port is null --- .../internal/http/pipeline/stages/EndpointResolutionStage.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/EndpointResolutionStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/EndpointResolutionStage.java index 40c329f01a84..b01dc444fa77 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/EndpointResolutionStage.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/EndpointResolutionStage.java @@ -104,10 +104,11 @@ private static boolean interceptorModifiedEndpoint(SdkHttpFullRequest.Builder re return false; } String requestHost = request.host(); + Integer requestPort = request.port(); return requestHost != null && (!requestHost.equals(preModifyUri.getHost()) || !String.valueOf(request.protocol()).equals(preModifyUri.getScheme()) - || request.port() != preModifyUri.getPort()); + || (requestPort != null && requestPort != preModifyUri.getPort())); } /**