diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/AuthSchemeGeneratorTasks.java b/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/AuthSchemeGeneratorTasks.java index 4b6719cc709c..471d9a5ddaab 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/AuthSchemeGeneratorTasks.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/AuthSchemeGeneratorTasks.java @@ -47,7 +47,6 @@ protected List createTasks() { tasks.add(generateDefaultParamsImpl()); tasks.add(generateModelBasedProvider()); tasks.add(generatePreferenceProvider()); - tasks.add(generateAuthSchemeInterceptor()); if (authSchemeSpecUtils.useEndpointBasedAuthProvider()) { tasks.add(generateEndpointBasedProvider()); tasks.add(generateEndpointAwareAuthSchemeParams()); diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/BaseClientBuilderClass.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/BaseClientBuilderClass.java index 44d8ce154e72..e342a79b4e25 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/BaseClientBuilderClass.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/BaseClientBuilderClass.java @@ -370,7 +370,6 @@ private MethodSpec finalizeServiceConfigurationMethod() { List builtInInterceptors = new ArrayList<>(); - builtInInterceptors.add(authSchemeSpecUtils.authSchemeInterceptor()); builtInInterceptors.add(endpointRulesSpecUtils.resolverInterceptorName()); builtInInterceptors.add(endpointRulesSpecUtils.requestModifierInterceptorName()); diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java index 539d5df35a43..fc31f3910476 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java @@ -69,10 +69,12 @@ import software.amazon.awssdk.codegen.poet.PoetExtension; import software.amazon.awssdk.codegen.poet.PoetUtils; import software.amazon.awssdk.codegen.poet.StaticImport; +import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeSpecUtils; import software.amazon.awssdk.codegen.poet.client.specs.ProtocolSpec; import software.amazon.awssdk.codegen.poet.eventstream.EventStreamUtils; import software.amazon.awssdk.codegen.poet.model.EventStreamSpecHelper; import software.amazon.awssdk.codegen.poet.model.ServiceClientConfigurationUtils; +import software.amazon.awssdk.codegen.poet.rules.EndpointRulesSpecUtils; import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.core.async.AsyncResponseTransformerUtils; @@ -100,6 +102,8 @@ public final class AsyncClientClass extends AsyncClientInterface { private final ProtocolSpec protocolSpec; private final ClassName serviceClientConfigurationClassName; private final ServiceClientConfigurationUtils configurationUtils; + private final AuthSchemeSpecUtils authSchemeSpecUtils; + private final EndpointRulesSpecUtils endpointRulesSpecUtils; private boolean hasScheduledExecutor; public AsyncClientClass(GeneratorTaskParams dependencies) { @@ -110,6 +114,8 @@ public AsyncClientClass(GeneratorTaskParams dependencies) { this.protocolSpec = getProtocolSpecs(poetExtensions, model); this.serviceClientConfigurationClassName = new PoetExtension(model).getServiceConfigClass(); this.configurationUtils = new ServiceClientConfigurationUtils(model); + this.authSchemeSpecUtils = new AuthSchemeSpecUtils(model); + this.endpointRulesSpecUtils = new EndpointRulesSpecUtils(model); } @Override @@ -165,7 +171,8 @@ protected void addAdditionalMethods(TypeSpec.Builder type) { .addMethod(nameMethod()) .addMethods(protocolSpec.additionalMethods()) .addMethod(protocolSpec.initProtocolFactory(model)) - .addMethod(resolveMetricPublishersMethod()); + .addMethod(resolveMetricPublishersMethod()) + .addMethod(ClientClassUtils.resolveAuthSchemeOptionsMethod(authSchemeSpecUtils, endpointRulesSpecUtils)); type.addMethod(ClientClassUtils.updateRetryStrategyClientConfigurationMethod()); type.addMethod(updateSdkClientConfigurationMethod(configurationUtils.serviceClientConfigurationBuilderClassName(), diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/ClientClassUtils.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/ClientClassUtils.java index 9d4e15b3dd86..f0f27524cf65 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/ClientClassUtils.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/ClientClassUtils.java @@ -29,6 +29,7 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.function.Consumer; import java.util.stream.Collectors; import javax.lang.model.element.Modifier; @@ -45,6 +46,7 @@ import software.amazon.awssdk.codegen.model.service.HostPrefixProcessor; import software.amazon.awssdk.codegen.poet.PoetExtension; import software.amazon.awssdk.codegen.poet.PoetUtils; +import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeSpecUtils; import software.amazon.awssdk.codegen.poet.rules.EndpointRulesSpecUtils; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; @@ -53,6 +55,7 @@ import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.retry.RetryMode; import software.amazon.awssdk.core.signer.Signer; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.utils.AttributeMap; import software.amazon.awssdk.utils.CollectionUtils; @@ -346,4 +349,120 @@ public static MethodSpec updateRetryStrategyClientConfigurationMethod() { static String transformServiceId(String serviceId) { return serviceId.replace(" ", "_"); } + + static MethodSpec resolveAuthSchemeOptionsMethod(AuthSchemeSpecUtils authSchemeSpecUtils, + EndpointRulesSpecUtils endpointRulesSpecUtils) { + MethodSpec.Builder builder = MethodSpec.methodBuilder("resolveAuthSchemeOptions") + .addModifiers(PRIVATE) + .returns(ParameterizedTypeName.get(ClassName.get(List.class), ClassName.get(AuthSchemeOption.class))) + .addParameter(SdkRequest.class, "request") + .addParameter(String.class, "operationName") + .addParameter(SdkClientConfiguration.class, "clientConfiguration"); + + ClassName providerInterface = authSchemeSpecUtils.providerInterfaceName(); + + builder.addStatement("$T authSchemeProvider = ($T) clientConfiguration.option($T.AUTH_SCHEME_PROVIDER)", + providerInterface, providerInterface, SdkClientOption.class); + + if (authSchemeSpecUtils.useEndpointBasedAuthProvider()) { + addEndpointBasedAuthSchemeResolution(builder, authSchemeSpecUtils, endpointRulesSpecUtils); + } else { + addSimpleAuthSchemeResolution(builder, authSchemeSpecUtils); + } + + return builder.build(); + } + + private static void addSimpleAuthSchemeResolution(MethodSpec.Builder builder, + AuthSchemeSpecUtils authSchemeSpecUtils) { + ClassName paramsInterface = authSchemeSpecUtils.parametersInterfaceName(); + ClassName awsClientOption = ClassName.get("software.amazon.awssdk.awscore.client.config", "AwsClientOption"); + + builder.addStatement("$T.Builder paramsBuilder = $T.builder().operation(operationName)", + paramsInterface, paramsInterface); + + if (authSchemeSpecUtils.usesSigV4()) { + builder.addStatement("paramsBuilder.region(clientConfiguration.option($T.AWS_REGION))", awsClientOption); + } + + if (authSchemeSpecUtils.hasSigV4aSupport()) { + ClassName regionSet = ClassName.get("software.amazon.awssdk.http.auth.aws.signer", "RegionSet"); + builder.addStatement("$T sigv4aRegionSet = clientConfiguration.option($T.AWS_SIGV4A_SIGNING_REGION_SET)", + ClassName.get(Set.class), awsClientOption); + builder.beginControlFlow("if (!$T.isNullOrEmpty(sigv4aRegionSet))", CollectionUtils.class); + builder.addStatement("paramsBuilder.regionSet($T.create(sigv4aRegionSet))", regionSet); + builder.nextControlFlow("else"); + builder.addStatement("paramsBuilder.regionSet($T.create(clientConfiguration.option($T.AWS_REGION).id()))", + regionSet, awsClientOption); + builder.endControlFlow(); + } + + builder.addStatement("return authSchemeProvider.resolveAuthScheme(paramsBuilder.build())"); + } + + private static void addEndpointBasedAuthSchemeResolution(MethodSpec.Builder builder, + AuthSchemeSpecUtils authSchemeSpecUtils, + EndpointRulesSpecUtils endpointRulesSpecUtils) { + ClassName paramsInterface = authSchemeSpecUtils.parametersInterfaceName(); + ClassName awsClientOption = ClassName.get("software.amazon.awssdk.awscore.client.config", "AwsClientOption"); + ClassName endpointParamsClass = endpointRulesSpecUtils.parametersClassName(); + ClassName resolverInterceptor = endpointRulesSpecUtils.resolverInterceptorName(); + ClassName executionAttributesClass = ClassName.get("software.amazon.awssdk.core.interceptor", "ExecutionAttributes"); + ClassName awsExecutionAttribute = ClassName.get("software.amazon.awssdk.awscore", "AwsExecutionAttribute"); + ClassName sdkExecutionAttribute = ClassName.get("software.amazon.awssdk.core.interceptor", "SdkExecutionAttribute"); + ClassName sdkInternalExecutionAttribute = ClassName.get("software.amazon.awssdk.core.interceptor", + "SdkInternalExecutionAttribute"); + + builder.addStatement("$T executionAttributes = new $T()", executionAttributesClass, executionAttributesClass); + builder.addStatement("executionAttributes.putAttribute($T.AWS_REGION, clientConfiguration.option($T.AWS_REGION))", + awsExecutionAttribute, awsClientOption); + builder.addStatement("executionAttributes.putAttribute($T.DUALSTACK_ENDPOINT_ENABLED, " + + "clientConfiguration.option($T.DUALSTACK_ENDPOINT_ENABLED))", + awsExecutionAttribute, awsClientOption); + builder.addStatement("executionAttributes.putAttribute($T.FIPS_ENDPOINT_ENABLED, " + + "clientConfiguration.option($T.FIPS_ENDPOINT_ENABLED))", + awsExecutionAttribute, awsClientOption); + builder.addStatement("executionAttributes.putAttribute($T.OPERATION_NAME, operationName)", sdkExecutionAttribute); + builder.addStatement("executionAttributes.putAttribute($T.CLIENT_ENDPOINT_PROVIDER, " + + "clientConfiguration.option($T.CLIENT_ENDPOINT_PROVIDER))", + sdkInternalExecutionAttribute, SdkClientOption.class); + builder.addStatement("executionAttributes.putAttribute($T.CLIENT_CONTEXT_PARAMS, " + + "clientConfiguration.option($T.CLIENT_CONTEXT_PARAMS))", + sdkInternalExecutionAttribute, SdkClientOption.class); + + builder.addStatement("$T endpointParams = $T.ruleParams(request, executionAttributes)", + endpointParamsClass, resolverInterceptor); + + builder.addStatement("$T.Builder paramsBuilder = $T.builder()", paramsInterface, paramsInterface); + + boolean regionIncluded = false; + for (String paramName : endpointRulesSpecUtils.parameters().keySet()) { + if (!authSchemeSpecUtils.includeParamForProvider(paramName)) { + continue; + } + regionIncluded = regionIncluded || paramName.equalsIgnoreCase("region"); + String methodName = endpointRulesSpecUtils.paramMethodName(paramName); + builder.addStatement("paramsBuilder.$1N(endpointParams.$1N())", methodName); + } + + builder.addStatement("paramsBuilder.operation(operationName)"); + + if (authSchemeSpecUtils.usesSigV4() && !regionIncluded) { + builder.addStatement("paramsBuilder.region(clientConfiguration.option($T.AWS_REGION))", awsClientOption); + } + + ClassName paramsBuilderClass = authSchemeSpecUtils.parametersEndpointAwareDefaultImplName().nestedClass("Builder"); + ClassName endpointProviderInterface = endpointRulesSpecUtils.providerInterfaceName(); + + builder.beginControlFlow("if (paramsBuilder instanceof $T)", paramsBuilderClass); + builder.addStatement("$T endpointProvider = clientConfiguration.option($T.ENDPOINT_PROVIDER)", + ClassName.get("software.amazon.awssdk.endpoints", "EndpointProvider"), SdkClientOption.class); + builder.beginControlFlow("if (endpointProvider instanceof $T)", endpointProviderInterface); + builder.addStatement("(($T) paramsBuilder).endpointProvider(($T) endpointProvider)", + paramsBuilderClass, endpointProviderInterface); + builder.endControlFlow(); + builder.endControlFlow(); + + builder.addStatement("return authSchemeProvider.resolveAuthScheme(paramsBuilder.build())"); + } } diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java index 5736ddbecaf5..6d9bc4e0ef4f 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java @@ -54,12 +54,14 @@ import software.amazon.awssdk.codegen.model.service.PreClientExecutionRequestCustomizer; import software.amazon.awssdk.codegen.poet.PoetExtension; import software.amazon.awssdk.codegen.poet.PoetUtils; +import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeSpecUtils; import software.amazon.awssdk.codegen.poet.client.specs.Ec2ProtocolSpec; import software.amazon.awssdk.codegen.poet.client.specs.JsonProtocolSpec; import software.amazon.awssdk.codegen.poet.client.specs.ProtocolSpec; import software.amazon.awssdk.codegen.poet.client.specs.QueryProtocolSpec; import software.amazon.awssdk.codegen.poet.client.specs.XmlProtocolSpec; import software.amazon.awssdk.codegen.poet.model.ServiceClientConfigurationUtils; +import software.amazon.awssdk.codegen.poet.rules.EndpointRulesSpecUtils; import software.amazon.awssdk.core.RequestOverrideConfiguration; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; @@ -83,6 +85,8 @@ public class SyncClientClass extends SyncClientInterface { private final ProtocolSpec protocolSpec; private final ClassName serviceClientConfigurationClassName; private final ServiceClientConfigurationUtils configurationUtils; + private final AuthSchemeSpecUtils authSchemeSpecUtils; + private final EndpointRulesSpecUtils endpointRulesSpecUtils; public SyncClientClass(GeneratorTaskParams taskParams) { super(taskParams.getModel()); @@ -92,6 +96,8 @@ public SyncClientClass(GeneratorTaskParams taskParams) { this.protocolSpec = getProtocolSpecs(poetExtensions, model); this.serviceClientConfigurationClassName = new PoetExtension(model).getServiceConfigClass(); this.configurationUtils = new ServiceClientConfigurationUtils(model); + this.authSchemeSpecUtils = new AuthSchemeSpecUtils(model); + this.endpointRulesSpecUtils = new EndpointRulesSpecUtils(model); } @Override @@ -133,7 +139,8 @@ protected void addAdditionalMethods(TypeSpec.Builder type) { type.addMethod(constructor()) .addMethod(nameMethod()) .addMethods(protocolSpec.additionalMethods()) - .addMethod(resolveMetricPublishersMethod()); + .addMethod(resolveMetricPublishersMethod()) + .addMethod(ClientClassUtils.resolveAuthSchemeOptionsMethod(authSchemeSpecUtils, endpointRulesSpecUtils)); protocolSpec.createErrorResponseHandler().ifPresent(type::addMethod); type.addMethod(ClientClassUtils.updateRetryStrategyClientConfigurationMethod()); diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java index b9b69d9bd8a0..ad964239ca02 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java @@ -221,7 +221,9 @@ public CodeBlock executionHandler(OperationModel opModel) { .add(credentialType(opModel, model)) .add(".withRequestConfiguration(clientConfiguration)") .add(".withInput($L)\n", opModel.getInput().getVariableName()) - .add(".withMetricCollector(apiCallMetricCollector)") + .add(".withMetricCollector(apiCallMetricCollector)\n") + .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", + opModel.getOperationName()) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); @@ -295,6 +297,8 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper .add(".withErrorResponseHandler(errorResponseHandler)\n") .add(".withRequestConfiguration(clientConfiguration)") .add(".withMetricCollector(apiCallMetricCollector)\n") + .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", + opModel.getOperationName()) .add(hostPrefixExpression(opModel)) .add(discoveredEndpoint(opModel)) .add(credentialType(opModel, model)) diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/QueryProtocolSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/QueryProtocolSpec.java index bab1f29e7281..624a4813655d 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/QueryProtocolSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/QueryProtocolSpec.java @@ -116,6 +116,8 @@ public CodeBlock executionHandler(OperationModel opModel) { .add(".withRequestConfiguration(clientConfiguration)") .add(".withInput($L)", opModel.getInput().getVariableName()) .add(".withMetricCollector(apiCallMetricCollector)") + .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", + opModel.getOperationName()) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); @@ -155,6 +157,8 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper .add(credentialType(opModel, intermediateModel)) .add(".withRequestConfiguration(clientConfiguration)") .add(".withMetricCollector(apiCallMetricCollector)\n") + .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", + opModel.getOperationName()) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/XmlProtocolSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/XmlProtocolSpec.java index dbbd33f4276f..92e646312b2a 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/XmlProtocolSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/XmlProtocolSpec.java @@ -134,7 +134,10 @@ public CodeBlock executionHandler(OperationModel opModel) { discoveredEndpoint(opModel)) .add(credentialType(opModel, model)) .add(".withRequestConfiguration(clientConfiguration)") - .add(".withInput($L)", opModel.getInput().getVariableName()) + .add(".withInput($L)", opModel.getInput().getVariableName()) + .add(".withMetricCollector(apiCallMetricCollector)") + .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", + opModel.getOperationName()) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); @@ -212,7 +215,9 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper builder.add(hostPrefixExpression(opModel)) .add(credentialType(opModel, model)) - .add(".withMetricCollector(apiCallMetricCollector)\n") + .add(".withMetricCollector(apiCallMetricCollector)\n") + .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", + opModel.getOperationName()) .add(asyncRequestBody(opModel)) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-h2-usePriorKnowledgeForH2-service-client-builder-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-h2-usePriorKnowledgeForH2-service-client-builder-class.java index d3fc4996de98..4ff83066fe22 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-h2-usePriorKnowledgeForH2-service-client-builder-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-h2-usePriorKnowledgeForH2-service-client-builder-class.java @@ -31,7 +31,6 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.h2.auth.scheme.H2AuthSchemeProvider; -import software.amazon.awssdk.services.h2.auth.scheme.internal.H2AuthSchemeInterceptor; import software.amazon.awssdk.services.h2.endpoints.H2EndpointProvider; import software.amazon.awssdk.services.h2.endpoints.internal.H2RequestSetEndpointInterceptor; import software.amazon.awssdk.services.h2.endpoints.internal.H2ResolveEndpointInterceptor; @@ -70,7 +69,6 @@ protected final SdkClientConfiguration mergeServiceDefaults(SdkClientConfigurati @Override protected final SdkClientConfiguration finalizeServiceConfiguration(SdkClientConfiguration config) { List endpointInterceptors = new ArrayList<>(); - endpointInterceptors.add(new H2AuthSchemeInterceptor()); endpointInterceptors.add(new H2ResolveEndpointInterceptor()); endpointInterceptors.add(new H2RequestSetEndpointInterceptor()); ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custom-context-params-async-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custom-context-params-async-client-class.java index d59409dcbcdd..67473d68951c 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custom-context-params-async-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custom-context-params-async-client-class.java @@ -13,6 +13,7 @@ import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; @@ -29,6 +30,7 @@ import software.amazon.awssdk.core.http.HttpResponseHandler; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -38,6 +40,8 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.foobar.auth.scheme.FooBarAuthSchemeParams; +import software.amazon.awssdk.services.foobar.auth.scheme.FooBarAuthSchemeProvider; import software.amazon.awssdk.services.foobar.endpoints.FooBarClientContextParams; import software.amazon.awssdk.services.foobar.internal.FooBarServiceClientConfigurationBuilder; import software.amazon.awssdk.services.foobar.internal.ServiceVersionInfo; @@ -60,7 +64,7 @@ final class DefaultFooBarAsyncClient implements FooBarAsyncClient { private static final Logger log = LoggerFactory.getLogger(DefaultFooBarAsyncClient.class); private static final AwsProtocolMetadata protocolMetadata = AwsProtocolMetadata.builder() - .serviceProtocol(AwsServiceProtocol.REST_JSON).build(); + .serviceProtocol(AwsServiceProtocol.REST_JSON).build(); private final AsyncClientHandler clientHandler; @@ -71,7 +75,7 @@ final class DefaultFooBarAsyncClient implements FooBarAsyncClient { protected DefaultFooBarAsyncClient(SdkClientConfiguration clientConfiguration) { this.clientHandler = new AwsAsyncClientHandler(clientConfiguration); this.clientConfiguration = clientConfiguration.toBuilder().option(SdkClientOption.SDK_CLIENT, this) - .option(SdkClientOption.API_METADATA, "Foo_Bar" + "#" + ServiceVersionInfo.VERSION).build(); + .option(SdkClientOption.API_METADATA, "Foo_Bar" + "#" + ServiceVersionInfo.VERSION).build(); this.protocolFactory = init(AwsJsonProtocolFactory.builder()).build(); } @@ -100,38 +104,43 @@ protected DefaultFooBarAsyncClient(SdkClientConfiguration clientConfiguration) { @Override public CompletableFuture getDatabaseVersion(GetDatabaseVersionRequest getDatabaseVersionRequest) { SdkClientConfiguration clientConfiguration = updateSdkClientConfiguration(getDatabaseVersionRequest, - this.clientConfiguration); + this.clientConfiguration); List metricPublishers = resolveMetricPublishers(clientConfiguration, getDatabaseVersionRequest - .overrideConfiguration().orElse(null)); + .overrideConfiguration().orElse(null)); MetricCollector apiCallMetricCollector = metricPublishers.isEmpty() ? NoOpMetricCollector.create() : MetricCollector - .create("ApiCall"); + .create("ApiCall"); try { apiCallMetricCollector.reportMetric(CoreMetric.SERVICE_ID, "Foo Bar"); apiCallMetricCollector.reportMetric(CoreMetric.OPERATION_NAME, "GetDatabaseVersion"); JsonOperationMetadata operationMetadata = JsonOperationMetadata.builder().hasStreamingSuccessResponse(false) - .isPayloadJson(true).build(); + .isPayloadJson(true).build(); HttpResponseHandler responseHandler = protocolFactory.createResponseHandler( - operationMetadata, GetDatabaseVersionResponse::builder); + operationMetadata, GetDatabaseVersionResponse::builder); Function> exceptionMetadataMapper = errorCode -> { if (errorCode == null) { return Optional.empty(); } switch (errorCode) { - default: - return Optional.empty(); + default: + return Optional.empty(); } }; HttpResponseHandler errorResponseHandler = createErrorResponseHandler(protocolFactory, - operationMetadata, exceptionMetadataMapper); + operationMetadata, exceptionMetadataMapper); CompletableFuture executeFuture = clientHandler - .execute(new ClientExecutionParams() - .withOperationName("GetDatabaseVersion").withProtocolMetadata(protocolMetadata) - .withMarshaller(new GetDatabaseVersionRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) - .withInput(getDatabaseVersionRequest)); + .execute(new ClientExecutionParams() + .withOperationName("GetDatabaseVersion") + .withProtocolMetadata(protocolMetadata) + .withMarshaller(new GetDatabaseVersionRequestMarshaller(protocolFactory)) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "GetDatabaseVersion", clientConfiguration)) + .withInput(getDatabaseVersionRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -155,11 +164,11 @@ public final String serviceName() { private > T init(T builder) { return builder.clientConfiguration(clientConfiguration).defaultServiceExceptionSupplier(FooBarException::builder) - .protocol(AwsJsonProtocol.REST_JSON).protocolVersion("1.1"); + .protocol(AwsJsonProtocol.REST_JSON).protocolVersion("1.1"); } private static List resolveMetricPublishers(SdkClientConfiguration clientConfiguration, - RequestOverrideConfiguration requestOverrideConfiguration) { + RequestOverrideConfiguration requestOverrideConfiguration) { List publishers = null; if (requestOverrideConfiguration != null) { publishers = requestOverrideConfiguration.metricPublishers(); @@ -173,6 +182,15 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + FooBarAuthSchemeProvider authSchemeProvider = (FooBarAuthSchemeProvider) clientConfiguration + .option(SdkClientOption.AUTH_SCHEME_PROVIDER); + FooBarAuthSchemeParams.Builder paramsBuilder = FooBarAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + return authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); @@ -211,15 +229,15 @@ private SdkClientConfiguration updateSdkClientConfiguration(SdkRequest request, newContextParams = (newContextParams != null) ? newContextParams : AttributeMap.empty(); originalContextParams = originalContextParams != null ? originalContextParams : AttributeMap.empty(); Validate.validState( - Objects.equals(originalContextParams.get(FooBarClientContextParams.CROSS_REGION_ACCESS_ENABLED), - newContextParams.get(FooBarClientContextParams.CROSS_REGION_ACCESS_ENABLED)), - "CROSS_REGION_ACCESS_ENABLED cannot be modified by request level plugins"); + Objects.equals(originalContextParams.get(FooBarClientContextParams.CROSS_REGION_ACCESS_ENABLED), + newContextParams.get(FooBarClientContextParams.CROSS_REGION_ACCESS_ENABLED)), + "CROSS_REGION_ACCESS_ENABLED cannot be modified by request level plugins"); updateRetryStrategyClientConfiguration(configuration); return configuration.build(); } private HttpResponseHandler createErrorResponseHandler(BaseAwsJsonProtocolFactory protocolFactory, - JsonOperationMetadata operationMetadata, Function> exceptionMetadataMapper) { + JsonOperationMetadata operationMetadata, Function> exceptionMetadataMapper) { return protocolFactory.createErrorResponseHandler(operationMetadata, exceptionMetadataMapper); } diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-query-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-query-client-class.java index 3c168823a547..766210e8c6f4 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-query-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-query-client-class.java @@ -5,6 +5,7 @@ import java.util.function.Consumer; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; @@ -32,12 +33,15 @@ import software.amazon.awssdk.core.runtime.transform.StreamingRequestMarshaller; import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.core.sync.ResponseTransformer; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; import software.amazon.awssdk.protocols.core.ExceptionMetadata; import software.amazon.awssdk.protocols.query.AwsQueryProtocolFactory; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.query.auth.scheme.QueryAuthSchemeParams; +import software.amazon.awssdk.services.query.auth.scheme.QueryAuthSchemeProvider; import software.amazon.awssdk.services.query.internal.QueryServiceClientConfigurationBuilder; import software.amazon.awssdk.services.query.internal.ServiceVersionInfo; import software.amazon.awssdk.services.query.model.APostOperationRequest; @@ -163,6 +167,7 @@ public APostOperationResponse aPostOperation(APostOperationRequest aPostOperatio .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .hostPrefixExpression(resolvedHostExpression).withRequestConfiguration(clientConfiguration) .withInput(aPostOperationRequest).withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "APostOperation", clientConfiguration)) .withMarshaller(new APostOperationRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -210,10 +215,15 @@ public APostOperationWithOutputResponse aPostOperationWithOutput( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperationWithOutput").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(aPostOperationWithOutputRequest) + .withOperationName("APostOperationWithOutput") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(aPostOperationWithOutputRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperationWithOutput", clientConfiguration)) .withMarshaller(new APostOperationWithOutputRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -259,6 +269,7 @@ public BearerAuthOperationResponse bearerAuthOperation(BearerAuthOperationReques .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .credentialType(CredentialType.TOKEN).withRequestConfiguration(clientConfiguration) .withInput(bearerAuthOperationRequest).withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "BearerAuthOperation", clientConfiguration)) .withMarshaller(new BearerAuthOperationRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -309,6 +320,8 @@ public GetOperationWithChecksumResponse getOperationWithChecksum( .withRequestConfiguration(clientConfiguration) .withInput(getOperationWithChecksumRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "GetOperationWithChecksum", clientConfiguration)) .putExecutionAttribute( SdkInternalExecutionAttribute.HTTP_CHECKSUM, HttpChecksum.builder().requestChecksumRequired(true).isRequestStreaming(false) @@ -364,6 +377,8 @@ public OperationWithChecksumRequiredResponse operationWithChecksumRequired( .withRequestConfiguration(clientConfiguration) .withInput(operationWithChecksumRequiredRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithChecksumRequired", clientConfiguration)) .putExecutionAttribute(SdkInternalExecutionAttribute.HTTP_CHECKSUM_REQUIRED, HttpChecksumRequired.create()) .withMarshaller(new OperationWithChecksumRequiredRequestMarshaller(protocolFactory))); @@ -409,10 +424,15 @@ public OperationWithContextParamResponse operationWithContextParam( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithContextParam").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(operationWithContextParamRequest) + .withOperationName("OperationWithContextParam") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(operationWithContextParamRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithContextParam", clientConfiguration)) .withMarshaller(new OperationWithContextParamRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -457,10 +477,15 @@ public OperationWithCustomMemberResponse operationWithCustomMember( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithCustomMember").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(operationWithCustomMemberRequest) + .withOperationName("OperationWithCustomMember") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(operationWithCustomMemberRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithCustomMember", clientConfiguration)) .withMarshaller(new OperationWithCustomMemberRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -506,10 +531,15 @@ public OperationWithCustomizedOperationContextParamResponse operationWithCustomi return clientHandler .execute(new ClientExecutionParams() .withOperationName("OperationWithCustomizedOperationContextParam") - .withProtocolMetadata(protocolMetadata).withResponseHandler(responseHandler) - .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) .withInput(operationWithCustomizedOperationContextParamRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithCustomizedOperationContextParam", + clientConfiguration)) .withMarshaller(new OperationWithCustomizedOperationContextParamRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -554,10 +584,15 @@ public OperationWithMapOperationContextParamResponse operationWithMapOperationCo return clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithMapOperationContextParam").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) + .withOperationName("OperationWithMapOperationContextParam") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) - .withInput(operationWithMapOperationContextParamRequest).withMetricCollector(apiCallMetricCollector) + .withInput(operationWithMapOperationContextParamRequest) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithMapOperationContextParam", clientConfiguration)) .withMarshaller(new OperationWithMapOperationContextParamRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -601,10 +636,15 @@ public OperationWithNoneAuthTypeResponse operationWithNoneAuthType( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithNoneAuthType").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(operationWithNoneAuthTypeRequest) + .withOperationName("OperationWithNoneAuthType") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(operationWithNoneAuthTypeRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithNoneAuthType", clientConfiguration)) .withMarshaller(new OperationWithNoneAuthTypeRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -649,10 +689,15 @@ public OperationWithOperationContextParamResponse operationWithOperationContextP return clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithOperationContextParam").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(operationWithOperationContextParamRequest) + .withOperationName("OperationWithOperationContextParam") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(operationWithOperationContextParamRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithOperationContextParam", clientConfiguration)) .withMarshaller(new OperationWithOperationContextParamRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -703,6 +748,8 @@ public OperationWithRequestCompressionResponse operationWithRequestCompression( .withRequestConfiguration(clientConfiguration) .withInput(operationWithRequestCompressionRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithRequestCompression", clientConfiguration)) .putExecutionAttribute(SdkInternalExecutionAttribute.REQUEST_COMPRESSION, RequestCompression.builder().encodings("gzip").isStreaming(false).build()) .withMarshaller(new OperationWithRequestCompressionRequestMarshaller(protocolFactory))); @@ -748,10 +795,15 @@ public OperationWithStaticContextParamsResponse operationWithStaticContextParams return clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithStaticContextParams").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(operationWithStaticContextParamsRequest) + .withOperationName("OperationWithStaticContextParams") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(operationWithStaticContextParamsRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithStaticContextParams", clientConfiguration)) .withMarshaller(new OperationWithStaticContextParamsRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -828,6 +880,8 @@ public ReturnT putOperationWithChecksum(PutOperationWithChecksumReques .withRequestConfiguration(clientConfiguration) .withInput(putOperationWithChecksumRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PutOperationWithChecksum", clientConfiguration)) .putExecutionAttribute( SdkInternalExecutionAttribute.HTTP_CHECKSUM, HttpChecksum @@ -906,6 +960,8 @@ public StreamingInputOperationResponse streamingInputOperation(StreamingInputOpe .withRequestConfiguration(clientConfiguration) .withInput(streamingInputOperationRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOperation", clientConfiguration)) .withRequestBody(requestBody) .withMarshaller( StreamingRequestMarshaller.builder() @@ -960,10 +1016,16 @@ public ReturnT streamingOutputOperation(StreamingOutputOperationReques return clientHandler.execute( new ClientExecutionParams() - .withOperationName("StreamingOutputOperation").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(streamingOutputOperationRequest) - .withMetricCollector(apiCallMetricCollector).withResponseTransformer(responseTransformer) + .withOperationName("StreamingOutputOperation") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(streamingOutputOperationRequest) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingOutputOperation", clientConfiguration)) + .withResponseTransformer(responseTransformer) .withMarshaller(new StreamingOutputOperationRequestMarshaller(protocolFactory)), responseTransformer); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -1003,6 +1065,15 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + QueryAuthSchemeProvider authSchemeProvider = (QueryAuthSchemeProvider) clientConfiguration + .option(SdkClientOption.AUTH_SCHEME_PROVIDER); + QueryAuthSchemeParams.Builder paramsBuilder = QueryAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + return authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilder.java b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilder.java index 32273f16019c..24a9fc09157e 100644 --- a/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilder.java +++ b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilder.java @@ -33,6 +33,7 @@ import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.internal.authcontext.AuthorizationStrategy; import software.amazon.awssdk.awscore.internal.authcontext.AuthorizationStrategyFactory; +import software.amazon.awssdk.awscore.internal.identity.AwsIdentityProviderUpdater; import software.amazon.awssdk.awscore.util.SignerOverrideUtils; import software.amazon.awssdk.core.HttpChecksumConstant; import software.amazon.awssdk.core.RequestOverrideConfiguration; @@ -63,6 +64,7 @@ import software.amazon.awssdk.http.ContentStreamProvider; import software.amazon.awssdk.http.auth.scheme.NoAuthAuthScheme; import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeProvider; import software.amazon.awssdk.identity.spi.IdentityProviders; import software.amazon.awssdk.metrics.MetricCollector; @@ -144,6 +146,18 @@ private AwsExecutionContextBuilder() { // Auth Scheme resolution related attributes putAuthSchemeResolutionAttributes(executionAttributes, clientConfig, originalRequest); + if (executionParams.authSchemeOptionsResolver() != null) { + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, + executionParams.authSchemeOptionsResolver()); + + List authOptions = executionParams.authSchemeOptionsResolver().resolve(originalRequest); + recordAuthSchemeBusinessMetrics(authOptions, executionAttributes, originalRequest); + } + + // Set the identity provider updater for the pipeline stage to use + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDER_UPDATER, + AwsIdentityProviderUpdater.INSTANCE); + ExecutionInterceptorChain executionInterceptorChain = new ExecutionInterceptorChain(clientConfig.option(SdkClientOption.EXECUTION_INTERCEPTORS)); @@ -355,7 +369,7 @@ private static EndpointProvider resolveEndpointProvider(SdkRequest request, } private static BusinessMetricCollection - resolveUserAgentBusinessMetrics(SdkClientConfiguration clientConfig, + resolveUserAgentBusinessMetrics(SdkClientConfiguration clientConfig, ClientExecutionParams executionParams) { BusinessMetricCollection businessMetrics = new BusinessMetricCollection(); Optional retryModeMetric = resolveRetryMode(clientConfig.option(RETRY_POLICY), @@ -373,4 +387,66 @@ private static boolean isRpcV2CborProtocol(SdkProtocolMetadata protocolMetadata) return protocolMetadata != null && SMITHY_RPC_V2_CBOR.toString().equals(protocolMetadata.serviceProtocol()); } + + /** + * Records business metrics for auth scheme selection (SigV4a, bearer token). + */ + private static void recordAuthSchemeBusinessMetrics(List authSchemeOptions, + ExecutionAttributes executionAttributes, + SdkRequest request) { + if (authSchemeOptions == null || authSchemeOptions.isEmpty()) { + return; + } + + BusinessMetricCollection businessMetrics = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS); + if (businessMetrics == null) { + return; + } + + Map> authSchemes = executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES); + IdentityProviders identityProviders = executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS); + + if (authSchemes == null || identityProviders == null) { + return; + } + + for (AuthSchemeOption authOption : authSchemeOptions) { + AuthScheme authScheme = authSchemes.get(authOption.schemeId()); + if (authScheme == null) { + continue; // Auth scheme not enabled, try next option + } + + if (authScheme.identityProvider(identityProviders) == null) { + continue; // Identity provider not configured, try next option + } + + // Check for SigV4a + if ("aws.auth#sigv4a".equals(authOption.schemeId()) && + !SignerOverrideUtils.isSignerOverridden(request, executionAttributes)) { + businessMetrics.addMetric(BusinessMetricFeatureId.SIGV4A_SIGNING.value()); + } + + // Check for bearer token from environment + if ("smithy.api#httpBearerAuth".equals(authOption.schemeId())) { + String tokenFromEnv = executionAttributes.getAttribute(SdkInternalExecutionAttribute.TOKEN_CONFIGURED_FROM_ENV); + if (tokenFromEnv != null && !hasTokenOverride(request)) { + businessMetrics.addMetric(BusinessMetricFeatureId.BEARER_SERVICE_ENV_VARS.value()); + } + } + + return; + } + } + + /** + * Check if the request has a token identity provider override. + */ + private static boolean hasTokenOverride(SdkRequest request) { + return request.overrideConfiguration() + .filter(c -> c instanceof AwsRequestOverrideConfiguration) + .map(c -> (AwsRequestOverrideConfiguration) c) + .flatMap(AwsRequestOverrideConfiguration::tokenIdentityProvider) + .isPresent(); + } } diff --git a/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/identity/AwsIdentityProviderUpdater.java b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/identity/AwsIdentityProviderUpdater.java new file mode 100644 index 000000000000..1cafd4675a06 --- /dev/null +++ b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/identity/AwsIdentityProviderUpdater.java @@ -0,0 +1,50 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.awscore.internal.identity; + +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.spi.identity.IdentityProviderUpdater; +import software.amazon.awssdk.identity.spi.IdentityProviders; + +/** + * AWS implementation of {@link IdentityProviderUpdater} that reads credential overrides + * from {@link AwsRequestOverrideConfiguration}. + */ +@SdkInternalApi +public final class AwsIdentityProviderUpdater implements IdentityProviderUpdater { + + public static final AwsIdentityProviderUpdater INSTANCE = new AwsIdentityProviderUpdater(); + + private AwsIdentityProviderUpdater() { + } + + @Override + public IdentityProviders update(SdkRequest request, IdentityProviders base) { + if (base == null) { + return null; + } + return request.overrideConfiguration() + .filter(c -> c instanceof AwsRequestOverrideConfiguration) + .map(c -> (AwsRequestOverrideConfiguration) c) + .map(c -> base.copy(b -> { + c.credentialsIdentityProvider().ifPresent(b::putIdentityProvider); + c.tokenIdentityProvider().ifPresent(b::putIdentityProvider); + })) + .orElse(base); + } +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/SelectedAuthScheme.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/SelectedAuthScheme.java index 10cb488792e2..8772260e3042 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/SelectedAuthScheme.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/SelectedAuthScheme.java @@ -22,9 +22,27 @@ import software.amazon.awssdk.identity.spi.Identity; import software.amazon.awssdk.utils.Validate; -/** - * A container for the identity resolver, signer and auth option that we selected for use with this service call attempt. - */ + +/// +/// A container for the identity resolver, signer and auth option that we selected for use with this service call attempt. +/// ## The Hierarchy +/// ``` +/// IDENTITY_PROVIDERS (IdentityProviders) +/// └── contains multiple IdentityProvider instances +/// e.g., IdentityProvider for AWS credentials +/// e.g., IdentityProvider for bearer tokens +/// +/// AUTH_SCHEMES (Map>) +/// └── each AuthScheme knows: +/// - which IdentityProvider type it needs +/// - which HttpSigner to use +/// +/// SELECTED_AUTH_SCHEME (SelectedAuthScheme) +/// └── the chosen auth scheme, containing: +/// - identity: CompletableFuture ← the resolved identity! +/// - signer: HttpSigner +/// - authSchemeOption: AuthSchemeOption +/// ``` @SdkProtectedApi public final class SelectedAuthScheme { private final CompletableFuture identity; diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/client/handler/ClientExecutionParams.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/client/handler/ClientExecutionParams.java index e307f5857ce7..66b4d9e7bf0f 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/client/handler/ClientExecutionParams.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/client/handler/ClientExecutionParams.java @@ -30,6 +30,7 @@ import software.amazon.awssdk.core.interceptor.ExecutionAttribute; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.runtime.transform.Marshaller; +import software.amazon.awssdk.core.spi.identity.AuthSchemeOptionsResolver; import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.core.sync.ResponseTransformer; import software.amazon.awssdk.metrics.MetricCollector; @@ -63,6 +64,7 @@ public final class ClientExecutionParams { private MetricCollector metricCollector; private final ExecutionAttributes attributes = new ExecutionAttributes(); private SdkClientConfiguration requestConfiguration; + private AuthSchemeOptionsResolver authSchemeOptionsResolver; public Marshaller getMarshaller() { return marshaller; @@ -261,4 +263,14 @@ public ClientExecutionParams withRequestConfiguration(SdkCl this.requestConfiguration = requestConfiguration; return this; } + + public AuthSchemeOptionsResolver authSchemeOptionsResolver() { + return authSchemeOptionsResolver; + } + + public ClientExecutionParams withAuthSchemeOptionsResolver( + AuthSchemeOptionsResolver authSchemeOptionsResolver) { + this.authSchemeOptionsResolver = authSchemeOptionsResolver; + return this; + } } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java index 3fe0f69d3ab8..8c73ce3b70ef 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java @@ -29,6 +29,8 @@ import software.amazon.awssdk.core.interceptor.trait.HttpChecksum; import software.amazon.awssdk.core.interceptor.trait.HttpChecksumRequired; import software.amazon.awssdk.core.internal.interceptor.trait.RequestCompression; +import software.amazon.awssdk.core.spi.identity.AuthSchemeOptionsResolver; +import software.amazon.awssdk.core.spi.identity.IdentityProviderUpdater; import software.amazon.awssdk.core.useragent.AdditionalMetadata; import software.amazon.awssdk.core.useragent.BusinessMetricCollection; import software.amazon.awssdk.endpoints.Endpoint; @@ -166,6 +168,20 @@ public final class SdkInternalExecutionAttribute extends SdkExecutionAttribute { */ public static final ExecutionAttribute IDENTITY_PROVIDERS = new ExecutionAttribute<>("IdentityProviders"); + /** + * Callback for updating identity providers based on request-level overrides. + * This allows aws-core to provide AWS-specific logic without sdk-core depending on aws-core. + */ + public static final ExecutionAttribute IDENTITY_PROVIDER_UPDATER = + new ExecutionAttribute<>("IdentityProviderUpdater"); + + /** + * Callback to resolve auth scheme options from the (possibly modified) request. + * Called by AuthSchemeResolutionStage after interceptors have run. + */ + public static final ExecutionAttribute AUTH_SCHEME_OPTIONS_RESOLVER = + new ExecutionAttribute<>("AuthSchemeOptionsResolver"); + /** * The selected auth scheme for a request. */ diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonAsyncHttpClient.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonAsyncHttpClient.java index 59bd964d6fad..f470bab80494 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonAsyncHttpClient.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonAsyncHttpClient.java @@ -39,6 +39,7 @@ import software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncExecutionFailureExceptionReportingStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncRetryableStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncSigningStage; +import software.amazon.awssdk.core.internal.http.pipeline.stages.AuthSchemeResolutionStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.CompressRequestStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.HttpChecksumStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.MakeAsyncHttpRequestStage; @@ -200,6 +201,7 @@ public CompletableFuture execute( .then(() -> new HttpChecksumStage(ClientType.ASYNC)) .then(ApplyUserAgentStage::new) .then(MakeRequestImmutableStage::new) + .then(AuthSchemeResolutionStage::new) .then(RequestPipelineBuilder .first(AsyncSigningStage::new) .then(AsyncBeforeTransmissionExecutionInterceptorsStage::new) diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonSyncHttpClient.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonSyncHttpClient.java index dccaad1e2109..bb5f6816f321 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonSyncHttpClient.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonSyncHttpClient.java @@ -34,6 +34,7 @@ import software.amazon.awssdk.core.internal.http.pipeline.stages.ApiCallTimeoutTrackingStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.ApplyTransactionIdStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.ApplyUserAgentStage; +import software.amazon.awssdk.core.internal.http.pipeline.stages.AuthSchemeResolutionStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.BeforeTransmissionExecutionInterceptorsStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.BeforeUnmarshallingExecutionInterceptorsStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.CompressRequestStage; @@ -189,6 +190,7 @@ public OutputT execute(HttpResponseHandler> response .then(ApplyUserAgentStage::new) .then(MakeRequestImmutableStage::new) // End of mutating request + .then(AuthSchemeResolutionStage::new) .then(RequestPipelineBuilder .first(SigningStage::new) .then(BeforeTransmissionExecutionInterceptorsStage::new) diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/auth/AuthSchemeResolver.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/auth/AuthSchemeResolver.java new file mode 100644 index 000000000000..6348a85f33b9 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/auth/AuthSchemeResolver.java @@ -0,0 +1,179 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.http.auth; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.SelectedAuthScheme; +import software.amazon.awssdk.core.exception.SdkException; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.internal.util.MetricUtils; +import software.amazon.awssdk.core.metrics.CoreMetric; +import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; +import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; +import software.amazon.awssdk.identity.spi.Identity; +import software.amazon.awssdk.identity.spi.IdentityProvider; +import software.amazon.awssdk.identity.spi.IdentityProviders; +import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; +import software.amazon.awssdk.identity.spi.TokenIdentity; +import software.amazon.awssdk.metrics.MetricCollector; +import software.amazon.awssdk.metrics.SdkMetric; +import software.amazon.awssdk.utils.Logger; + +/** + * Shared utility for selecting auth schemes from a list of options. + */ +@SdkInternalApi +public final class AuthSchemeResolver { + + private static final Logger LOG = Logger.loggerFor(AuthSchemeResolver.class); + + private AuthSchemeResolver() { + } + + /** + * Select an auth scheme from the given options. + * + * @param authOptions List of auth scheme options to try in order + * @param authSchemes Map of available auth schemes + * @param identityProviders Identity providers to use for resolving identity + * @param metricCollector Optional metric collector for recording identity fetch duration + * @return The selected auth scheme + * @throws SdkException if no auth scheme could be selected + */ + public static SelectedAuthScheme selectAuthScheme( + List authOptions, + Map> authSchemes, + IdentityProviders identityProviders, + MetricCollector metricCollector) { + + List> discardedReasons = new ArrayList<>(); + + for (AuthSchemeOption authOption : authOptions) { + AuthScheme authScheme = authSchemes.get(authOption.schemeId()); + SelectedAuthScheme selectedAuthScheme = trySelectAuthScheme( + authOption, authScheme, identityProviders, discardedReasons, metricCollector); + + if (selectedAuthScheme != null) { + if (!discardedReasons.isEmpty()) { + LOG.debug(() -> String.format("%s auth will be used, discarded: '%s'", + authOption.schemeId(), + discardedReasons.stream().map(Supplier::get).collect(Collectors.joining(", ")))); + } + return selectedAuthScheme; + } + } + + throw SdkException.builder() + .message("Failed to determine how to authenticate the user: " + + discardedReasons.stream().map(Supplier::get).collect(Collectors.joining(", "))) + .build(); + } + + /** + * Merge properties from any pre-existing auth scheme into the selected one. + */ + public static SelectedAuthScheme mergePreExistingAuthSchemeProperties( + SelectedAuthScheme selectedAuthScheme, + ExecutionAttributes executionAttributes) { + + SelectedAuthScheme existingAuthScheme = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + + if (existingAuthScheme == null) { + return selectedAuthScheme; + } + + AuthSchemeOption.Builder mergedOption = selectedAuthScheme.authSchemeOption().toBuilder(); + existingAuthScheme.authSchemeOption().forEachIdentityProperty(mergedOption::putIdentityPropertyIfAbsent); + existingAuthScheme.authSchemeOption().forEachSignerProperty(mergedOption::putSignerPropertyIfAbsent); + + return new SelectedAuthScheme<>( + selectedAuthScheme.identity(), + selectedAuthScheme.signer(), + mergedOption.build() + ); + } + + private static SelectedAuthScheme trySelectAuthScheme( + AuthSchemeOption authOption, + AuthScheme authScheme, + IdentityProviders identityProviders, + List> discardedReasons, + MetricCollector metricCollector) { + + if (authScheme == null) { + discardedReasons.add(() -> String.format("'%s' is not enabled for this request.", authOption.schemeId())); + return null; + } + + IdentityProvider identityProvider = authScheme.identityProvider(identityProviders); + if (identityProvider == null) { + discardedReasons.add(() -> String.format("'%s' does not have an identity provider configured.", + authOption.schemeId())); + return null; + } + + HttpSigner signer; + try { + signer = authScheme.signer(); + } catch (RuntimeException e) { + discardedReasons.add(() -> String.format("'%s' signer could not be retrieved: %s", + authOption.schemeId(), e.getMessage())); + return null; + } + + ResolveIdentityRequest.Builder identityRequestBuilder = ResolveIdentityRequest.builder(); + authOption.forEachIdentityProperty(identityRequestBuilder::putProperty); + + CompletableFuture identity = resolveIdentity( + identityProvider, identityRequestBuilder.build(), metricCollector); + + return new SelectedAuthScheme<>(identity, signer, authOption); + } + + private static CompletableFuture resolveIdentity( + IdentityProvider identityProvider, + ResolveIdentityRequest request, + MetricCollector metricCollector) { + + SdkMetric metric = getIdentityMetric(identityProvider); + if (metric == null || metricCollector == null) { + return identityProvider.resolveIdentity(request); + } + return MetricUtils.reportDuration(() -> identityProvider.resolveIdentity(request), metricCollector, metric); + } + + private static SdkMetric getIdentityMetric(IdentityProvider identityProvider) { + Class identityType = identityProvider.identityType(); + if (identityType == AwsCredentialsIdentity.class) { + return CoreMetric.CREDENTIALS_FETCH_DURATION; + } + if (identityType == TokenIdentity.class) { + return CoreMetric.TOKEN_FETCH_DURATION; + } + return null; + } +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStage.java new file mode 100644 index 000000000000..fed726670b99 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStage.java @@ -0,0 +1,94 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.http.pipeline.stages; + +import java.util.List; +import java.util.Map; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.internal.http.HttpClientDependencies; +import software.amazon.awssdk.core.internal.http.RequestExecutionContext; +import software.amazon.awssdk.core.internal.http.auth.AuthSchemeResolver; +import software.amazon.awssdk.core.internal.http.pipeline.RequestToRequestPipeline; +import software.amazon.awssdk.core.spi.identity.AuthSchemeOptionsResolver; +import software.amazon.awssdk.core.spi.identity.IdentityProviderUpdater; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.identity.spi.Identity; +import software.amazon.awssdk.identity.spi.IdentityProviders; +import software.amazon.awssdk.metrics.MetricCollector; + +/** + * Pipeline stage that resolves the auth scheme and identity for signing. + */ +@SdkInternalApi +public final class AuthSchemeResolutionStage implements RequestToRequestPipeline { + + public AuthSchemeResolutionStage(HttpClientDependencies dependencies) { + } + + @Override + public SdkHttpFullRequest execute(SdkHttpFullRequest request, RequestExecutionContext context) throws Exception { + ExecutionAttributes executionAttributes = context.executionAttributes(); + + Map> authSchemes = executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES); + if (authSchemes == null) { + return request; + } + + SdkRequest sdkRequest = context.executionContext().interceptorContext().request(); + List authOptions = resolveAuthSchemeOptions(executionAttributes, sdkRequest); + if (authOptions == null || authOptions.isEmpty()) { + return request; + } + + IdentityProviders identityProviders = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS); + + IdentityProviderUpdater updater = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDER_UPDATER); + if (updater != null) { + identityProviders = updater.update(sdkRequest, identityProviders); + } + + MetricCollector metricCollector = + executionAttributes.getAttribute(SdkExecutionAttribute.API_CALL_METRIC_COLLECTOR); + + SelectedAuthScheme selectedAuthScheme = + AuthSchemeResolver.selectAuthScheme(authOptions, authSchemes, identityProviders, metricCollector); + + selectedAuthScheme = AuthSchemeResolver.mergePreExistingAuthSchemeProperties(selectedAuthScheme, executionAttributes); + + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + + return request; + } + + private List resolveAuthSchemeOptions(ExecutionAttributes executionAttributes, SdkRequest request) { + AuthSchemeOptionsResolver resolver = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER); + + if (resolver == null) { + return null; + } + return resolver.resolve(request); + } +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/AuthSchemeOptionsResolver.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/AuthSchemeOptionsResolver.java new file mode 100644 index 000000000000..7a37d7c3dd6b --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/AuthSchemeOptionsResolver.java @@ -0,0 +1,39 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.spi.identity; + +import java.util.List; +import software.amazon.awssdk.annotations.SdkProtectedApi; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; + +/** + * Callback interface for resolving auth scheme options from the request. + *

+ * This allows auth scheme resolution to happen after interceptors have modified the request, + * ensuring that any request modifications affecting auth scheme selection are respected. + */ +@FunctionalInterface +@SdkProtectedApi +public interface AuthSchemeOptionsResolver { + /** + * Resolves auth scheme options for the given request. + * + * @param request The request (after interceptors have modified it) + * @return List of auth scheme options in priority order + */ + List resolve(SdkRequest request); +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/IdentityProviderUpdater.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/IdentityProviderUpdater.java new file mode 100644 index 000000000000..af54e5ef3661 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/IdentityProviderUpdater.java @@ -0,0 +1,39 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.spi.identity; + +import software.amazon.awssdk.annotations.SdkProtectedApi; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.identity.spi.IdentityProviders; + +/** + * Callback interface for updating identity providers based on request-level overrides. + *

+ * This allows aws-core to provide AWS-specific logic for reading credential overrides + * from {@code AwsRequestOverrideConfiguration} without sdk-core depending on aws-core. + */ +@FunctionalInterface +@SdkProtectedApi +public interface IdentityProviderUpdater { + /** + * Updates identity providers based on request-level overrides. + * + * @param request The request (after interceptors have modified it) + * @param base The base identity providers from client configuration + * @return Updated identity providers, or base if no overrides + */ + IdentityProviders update(SdkRequest request, IdentityProviders base); +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/auth/AuthSchemeResolverTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/auth/AuthSchemeResolverTest.java new file mode 100644 index 000000000000..90c2cdf12008 --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/auth/AuthSchemeResolverTest.java @@ -0,0 +1,189 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.http.auth; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.SelectedAuthScheme; +import software.amazon.awssdk.core.exception.SdkException; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; +import software.amazon.awssdk.identity.spi.Identity; +import software.amazon.awssdk.identity.spi.IdentityProvider; +import software.amazon.awssdk.identity.spi.IdentityProviders; +import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; + +class AuthSchemeResolverTest { + + private static final String SCHEME_A = "schemeA"; + private static final String SCHEME_B = "schemeB"; + + @Test + void selectAuthScheme_firstOptionSucceeds_returnsFirstScheme() { + AuthScheme schemeA = createMockAuthScheme(); + Map> authSchemes = new HashMap<>(); + authSchemes.put(SCHEME_A, schemeA); + + List options = Collections.singletonList( + AuthSchemeOption.builder().schemeId(SCHEME_A).build() + ); + + SelectedAuthScheme result = AuthSchemeResolver.selectAuthScheme( + options, authSchemes, mock(IdentityProviders.class), null); + + assertThat(result.authSchemeOption().schemeId()).isEqualTo(SCHEME_A); + } + + @Test + void selectAuthScheme_firstOptionNoScheme_fallsBackToSecond() { + AuthScheme schemeB = createMockAuthScheme(); + Map> authSchemes = new HashMap<>(); + authSchemes.put(SCHEME_B, schemeB); + + List options = Arrays.asList( + AuthSchemeOption.builder().schemeId(SCHEME_A).build(), + AuthSchemeOption.builder().schemeId(SCHEME_B).build() + ); + + SelectedAuthScheme result = AuthSchemeResolver.selectAuthScheme( + options, authSchemes, mock(IdentityProviders.class), null); + + assertThat(result.authSchemeOption().schemeId()).isEqualTo(SCHEME_B); + } + + @Test + void selectAuthScheme_firstOptionNoIdentityProvider_fallsBackToSecond() { + AuthScheme schemeA = createMockAuthScheme(); + when(schemeA.identityProvider(any())).thenReturn(null); + + AuthScheme schemeB = createMockAuthScheme(); + + Map> authSchemes = new HashMap<>(); + authSchemes.put(SCHEME_A, schemeA); + authSchemes.put(SCHEME_B, schemeB); + + List options = Arrays.asList( + AuthSchemeOption.builder().schemeId(SCHEME_A).build(), + AuthSchemeOption.builder().schemeId(SCHEME_B).build() + ); + + SelectedAuthScheme result = AuthSchemeResolver.selectAuthScheme( + options, authSchemes, mock(IdentityProviders.class), null); + + assertThat(result.authSchemeOption().schemeId()).isEqualTo(SCHEME_B); + } + + @Test + void selectAuthScheme_signerThrows_fallsBackToSecond() { + AuthScheme schemeA = createMockAuthScheme(); + when(schemeA.signer()).thenThrow(new RuntimeException("Signer not available")); + + AuthScheme schemeB = createMockAuthScheme(); + + Map> authSchemes = new HashMap<>(); + authSchemes.put(SCHEME_A, schemeA); + authSchemes.put(SCHEME_B, schemeB); + + List options = Arrays.asList( + AuthSchemeOption.builder().schemeId(SCHEME_A).build(), + AuthSchemeOption.builder().schemeId(SCHEME_B).build() + ); + + SelectedAuthScheme result = AuthSchemeResolver.selectAuthScheme( + options, authSchemes, mock(IdentityProviders.class), null); + + assertThat(result.authSchemeOption().schemeId()).isEqualTo(SCHEME_B); + } + + @Test + void selectAuthScheme_allOptionsFail_throwsException() { + Map> authSchemes = new HashMap<>(); + + List options = Collections.singletonList( + AuthSchemeOption.builder().schemeId(SCHEME_A).build() + ); + + assertThatThrownBy(() -> AuthSchemeResolver.selectAuthScheme( + options, authSchemes, mock(IdentityProviders.class), null)) + .isInstanceOf(SdkException.class) + .hasMessageContaining("Failed to determine how to authenticate"); + } + + @Test + void mergeProperties_noExistingScheme_returnsOriginal() { + SelectedAuthScheme selected = createSelectedAuthScheme(SCHEME_A); + ExecutionAttributes attributes = new ExecutionAttributes(); + + SelectedAuthScheme result = AuthSchemeResolver.mergePreExistingAuthSchemeProperties( + selected, attributes); + + assertThat(result).isSameAs(selected); + } + + @Test + @SuppressWarnings("unchecked") + void mergeProperties_withExistingScheme_returnsNewInstance() { + SelectedAuthScheme selected = createSelectedAuthScheme(SCHEME_A); + SelectedAuthScheme existing = createSelectedAuthScheme(SCHEME_B); + + ExecutionAttributes attributes = new ExecutionAttributes(); + attributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, existing); + + SelectedAuthScheme result = AuthSchemeResolver.mergePreExistingAuthSchemeProperties( + selected, attributes); + + assertThat(result).isNotSameAs(selected); + assertThat(result.authSchemeOption().schemeId()).isEqualTo(SCHEME_A); + } + + @SuppressWarnings("unchecked") + private AuthScheme createMockAuthScheme() { + AuthScheme scheme = mock(AuthScheme.class); + IdentityProvider identityProvider = mock(IdentityProvider.class); + Identity mockIdentity = mock(Identity.class); + doReturn(CompletableFuture.completedFuture(mockIdentity)) + .when(identityProvider).resolveIdentity(any(ResolveIdentityRequest.class)); + when(scheme.identityProvider(any())).thenReturn(identityProvider); + when(scheme.signer()).thenReturn(mock(HttpSigner.class)); + return scheme; + } + + @SuppressWarnings("unchecked") + private SelectedAuthScheme createSelectedAuthScheme(String schemeId) { + Identity mockIdentity = mock(Identity.class); + HttpSigner mockSigner = mock(HttpSigner.class); + return new SelectedAuthScheme<>( + CompletableFuture.completedFuture(mockIdentity), + mockSigner, + AuthSchemeOption.builder().schemeId(schemeId).build() + ); + } +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStageTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStageTest.java new file mode 100644 index 000000000000..c8307eb7fda1 --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStageTest.java @@ -0,0 +1,231 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.http.pipeline.stages; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; +import software.amazon.awssdk.core.http.ExecutionContext; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.InterceptorContext; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.internal.http.RequestExecutionContext; +import software.amazon.awssdk.core.spi.identity.AuthSchemeOptionsResolver; +import software.amazon.awssdk.core.spi.identity.IdentityProviderUpdater; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; +import software.amazon.awssdk.identity.spi.Identity; +import software.amazon.awssdk.identity.spi.IdentityProvider; +import software.amazon.awssdk.identity.spi.IdentityProviders; +import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; + +class AuthSchemeResolutionStageTest { + + private static final String SCHEME_ID = "test.scheme"; + + private AuthSchemeResolutionStage stage; + private SdkHttpFullRequest httpRequest; + private RequestExecutionContext context; + private ExecutionAttributes executionAttributes; + private SdkRequest sdkRequest; + + @BeforeEach + void setup() { + stage = new AuthSchemeResolutionStage(null); + httpRequest = mock(SdkHttpFullRequest.class); + sdkRequest = mock(SdkRequest.class); + executionAttributes = new ExecutionAttributes(); + + InterceptorContext interceptorContext = InterceptorContext.builder() + .request(sdkRequest) + .build(); + ExecutionContext executionContext = ExecutionContext.builder() + .interceptorContext(interceptorContext) + .executionAttributes(executionAttributes) + .build(); + context = RequestExecutionContext.builder() + .executionContext(executionContext) + .originalRequest(sdkRequest) + .build(); + } + + @Test + void execute_noAuthSchemes_returnsRequestUnchanged() throws Exception { + // AUTH_SCHEMES is null + SdkHttpFullRequest result = stage.execute(httpRequest, context); + + assertThat(result).isSameAs(httpRequest); + assertThat(executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME)).isNull(); + } + + @Test + void execute_noResolver_returnsRequestUnchanged() throws Exception { + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, createAuthSchemes()); + // AUTH_SCHEME_OPTIONS_RESOLVER is null + + SdkHttpFullRequest result = stage.execute(httpRequest, context); + + assertThat(result).isSameAs(httpRequest); + assertThat(executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME)).isNull(); + } + + @Test + void execute_resolverReturnsEmpty_returnsRequestUnchanged() throws Exception { + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, createAuthSchemes()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, + (AuthSchemeOptionsResolver) req -> Collections.emptyList()); + + SdkHttpFullRequest result = stage.execute(httpRequest, context); + + assertThat(result).isSameAs(httpRequest); + assertThat(executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME)).isNull(); + } + + @Test + void execute_resolverReceivesRequestFromInterceptorContext() throws Exception { + SdkRequest modifiedRequest = mock(SdkRequest.class); + + // Setup interceptor context with a DIFFERENT request than originalRequest + InterceptorContext interceptorContext = InterceptorContext.builder() + .request(modifiedRequest) + .build(); + ExecutionContext executionContext = ExecutionContext.builder() + .interceptorContext(interceptorContext) + .executionAttributes(executionAttributes) + .build(); + context = RequestExecutionContext.builder() + .executionContext(executionContext) + .originalRequest(sdkRequest) // Different from modifiedRequest + .build(); + + AuthSchemeOptionsResolver resolver = mock(AuthSchemeOptionsResolver.class); + doReturn(createAuthOptions()).when(resolver).resolve(modifiedRequest); + + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, createAuthSchemes()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, resolver); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, createIdentityProviders()); + + stage.execute(httpRequest, context); + + // Verify resolver was called with the MODIFIED request, not originalRequest + verify(resolver).resolve(modifiedRequest); + } + + @Test + void execute_withIdentityProviderUpdater_callsUpdaterWithRequest() throws Exception { + // Create mocks first before any stubbing + IdentityProvider identityProvider = createMockIdentityProvider(); + Map> authSchemes = createAuthSchemes(); + IdentityProviders baseProviders = mock(IdentityProviders.class); + IdentityProviders updatedProviders = mock(IdentityProviders.class); + + IdentityProviderUpdater updater = mock(IdentityProviderUpdater.class); + doReturn(updatedProviders).when(updater).update(sdkRequest, baseProviders); + + // Setup so that auth scheme uses the updated providers + @SuppressWarnings("unchecked") + AuthScheme scheme = (AuthScheme) authSchemes.get(SCHEME_ID); + doReturn(identityProvider).when(scheme).identityProvider(updatedProviders); + + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, authSchemes); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, + (AuthSchemeOptionsResolver) req -> createAuthOptions()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, baseProviders); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDER_UPDATER, updater); + + stage.execute(httpRequest, context); + + verify(updater).update(sdkRequest, baseProviders); + } + + @Test + void execute_withoutIdentityProviderUpdater_doesNotFail() throws Exception { + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, createAuthSchemes()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, + (AuthSchemeOptionsResolver) req -> createAuthOptions()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, createIdentityProviders()); + // No IDENTITY_PROVIDER_UPDATER set + + SdkHttpFullRequest result = stage.execute(httpRequest, context); + + assertThat(result).isSameAs(httpRequest); + assertThat(executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME)).isNotNull(); + } + + @Test + void execute_happyPath_setsSelectedAuthScheme() throws Exception { + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, createAuthSchemes()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, + (AuthSchemeOptionsResolver) req -> createAuthOptions()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, createIdentityProviders()); + + SdkHttpFullRequest result = stage.execute(httpRequest, context); + + assertThat(result).isSameAs(httpRequest); + SelectedAuthScheme selectedAuthScheme = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + assertThat(selectedAuthScheme).isNotNull(); + assertThat(selectedAuthScheme.authSchemeOption().schemeId()).isEqualTo(SCHEME_ID); + } + + @SuppressWarnings("unchecked") + private Map> createAuthSchemes() { + IdentityProvider identityProvider = createMockIdentityProvider(); + HttpSigner signer = mock(HttpSigner.class); + + AuthScheme scheme = mock(AuthScheme.class); + doReturn(identityProvider).when(scheme).identityProvider(any()); + doReturn(signer).when(scheme).signer(); + + Map> schemes = new HashMap<>(); + schemes.put(SCHEME_ID, scheme); + return schemes; + } + + @SuppressWarnings("unchecked") + private IdentityProvider createMockIdentityProvider() { + IdentityProvider provider = mock(IdentityProvider.class); + Identity mockIdentity = mock(Identity.class); + doReturn(CompletableFuture.completedFuture(mockIdentity)) + .when(provider).resolveIdentity(any(ResolveIdentityRequest.class)); + return provider; + } + + private IdentityProviders createIdentityProviders() { + IdentityProviders providers = mock(IdentityProviders.class); + return providers; + } + + private List createAuthOptions() { + return Collections.singletonList( + AuthSchemeOption.builder().schemeId(SCHEME_ID).build() + ); + } +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java index 48e09c49a6ba..cbc9497f93f7 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java @@ -43,6 +43,7 @@ import software.amazon.awssdk.awscore.endpoint.AwsClientEndpointProvider; import software.amazon.awssdk.awscore.internal.AwsExecutionContextBuilder; import software.amazon.awssdk.awscore.internal.defaultsmode.DefaultsModeConfiguration; +import software.amazon.awssdk.awscore.internal.identity.AwsIdentityProviderUpdater; import software.amazon.awssdk.awscore.presigner.PresignRequest; import software.amazon.awssdk.awscore.presigner.PresignedRequest; import software.amazon.awssdk.core.ClientType; @@ -61,6 +62,7 @@ import software.amazon.awssdk.core.interceptor.InterceptorContext; import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.internal.http.auth.AuthSchemeResolver; import software.amazon.awssdk.core.signer.Presigner; import software.amazon.awssdk.core.signer.Signer; import software.amazon.awssdk.core.sync.RequestBody; @@ -84,9 +86,10 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.s3.S3Configuration; +import software.amazon.awssdk.services.s3.auth.scheme.S3AuthSchemeParams; import software.amazon.awssdk.services.s3.auth.scheme.S3AuthSchemeProvider; -import software.amazon.awssdk.services.s3.auth.scheme.internal.S3AuthSchemeInterceptor; import software.amazon.awssdk.services.s3.endpoints.S3ClientContextParams; +import software.amazon.awssdk.services.s3.endpoints.S3EndpointParams; import software.amazon.awssdk.services.s3.endpoints.S3EndpointProvider; import software.amazon.awssdk.services.s3.endpoints.internal.S3RequestSetEndpointInterceptor; import software.amazon.awssdk.services.s3.endpoints.internal.S3ResolveEndpointInterceptor; @@ -240,7 +243,6 @@ private List initializeInterceptors() { List s3Interceptors = interceptorFactory.getInterceptors("software/amazon/awssdk/services/s3/execution.interceptors"); List additionalInterceptors = new ArrayList<>(); - additionalInterceptors.add(new S3AuthSchemeInterceptor()); additionalInterceptors.add(new S3ResolveEndpointInterceptor()); additionalInterceptors.add(new S3RequestSetEndpointInterceptor()); s3Interceptors = mergeLists(s3Interceptors, additionalInterceptors); @@ -405,6 +407,9 @@ private T presign(T presignedRequest, addRequestLevelHeadersAndQueryParameters(execCtx); callModifyHttpRequestHooksAndUpdateContext(execCtx); + // Resolve auth scheme after interceptors complete + resolveAndSelectAuthScheme(execCtx, requestToPresign, operationName); + SdkHttpFullRequest httpRequest = getHttpFullRequest(execCtx); SdkHttpFullRequest signedHttpRequest = execCtx.signer() != null @@ -594,6 +599,81 @@ private SdkHttpFullRequest getHttpFullRequest(ExecutionContext execCtx) { .build(); } + /** + * Resolve and select auth scheme for presigning. + */ + private void resolveAndSelectAuthScheme(ExecutionContext execCtx, SdkRequest request, String operationName) { + ExecutionAttributes executionAttributes = execCtx.executionAttributes(); + + Map> authSchemes = executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES); + if (authSchemes == null) { + return; + } + + List authOptions = resolveAuthSchemeOptions(request, operationName, executionAttributes); + if (authOptions == null || authOptions.isEmpty()) { + return; + } + + IdentityProviders identityProviders = executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS); + identityProviders = AwsIdentityProviderUpdater.INSTANCE.update(request, identityProviders); + + SelectedAuthScheme selectedAuthScheme = + AuthSchemeResolver.selectAuthScheme(authOptions, authSchemes, identityProviders, null); + + selectedAuthScheme = AuthSchemeResolver.mergePreExistingAuthSchemeProperties(selectedAuthScheme, executionAttributes); + + executionAttributes.putAttribute(SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + + /** + * Resolve auth scheme options using full endpoint params. + */ + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + ExecutionAttributes executionAttributes) { + S3AuthSchemeProvider authSchemeProvider = (S3AuthSchemeProvider) + executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_RESOLVER); + + // Build full endpoint params using the generated interceptor's logic + S3EndpointParams endpointParams = + S3ResolveEndpointInterceptor.ruleParams(request, executionAttributes); + + // Copy all endpoint params to auth scheme params + S3AuthSchemeParams.Builder authParamsBuilder = + S3AuthSchemeParams.builder() + .operation(operationName) + .region(endpointParams.region()) + .bucket(endpointParams.bucket()) + .prefix(endpointParams.prefix()) + .copySource(endpointParams.copySource()) + .key(endpointParams.key()); + + // Set optional endpoint params if present + if (endpointParams.accelerate() != null) { + authParamsBuilder.accelerate(endpointParams.accelerate()); + } + if (endpointParams.disableMultiRegionAccessPoints() != null) { + authParamsBuilder.disableMultiRegionAccessPoints(endpointParams.disableMultiRegionAccessPoints()); + } + if (endpointParams.disableS3ExpressSessionAuth() != null) { + authParamsBuilder.disableS3ExpressSessionAuth(endpointParams.disableS3ExpressSessionAuth()); + } + if (endpointParams.forcePathStyle() != null) { + authParamsBuilder.forcePathStyle(endpointParams.forcePathStyle()); + } + if (endpointParams.useArnRegion() != null) { + authParamsBuilder.useArnRegion(endpointParams.useArnRegion()); + } + if (endpointParams.useFips() != null) { + authParamsBuilder.useFips(endpointParams.useFips()); + } + if (endpointParams.useDualStack() != null) { + authParamsBuilder.useDualStack(endpointParams.useDualStack()); + } + + return authSchemeProvider.resolveAuthScheme(authParamsBuilder.build()); + } + /** * Presign the provided HTTP request using old Signer */ diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressAuthSchemeProviderTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressAuthSchemeProviderTest.java index 13c7b37ab958..1a176f1e119b 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressAuthSchemeProviderTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressAuthSchemeProviderTest.java @@ -17,125 +17,48 @@ import static org.assertj.core.api.AssertionsForClassTypes.assertThat; -import java.net.URI; -import java.util.HashMap; -import java.util.Map; -import java.util.concurrent.CompletableFuture; +import java.util.List; import org.junit.jupiter.api.Test; -import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; -import software.amazon.awssdk.awscore.AwsExecutionAttribute; -import software.amazon.awssdk.core.ClientEndpointProvider; -import software.amazon.awssdk.core.SelectedAuthScheme; -import software.amazon.awssdk.core.interceptor.ExecutionAttributes; -import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; -import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; -import software.amazon.awssdk.http.auth.aws.scheme.AwsV4AuthScheme; -import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; -import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; -import software.amazon.awssdk.identity.spi.IdentityProvider; -import software.amazon.awssdk.identity.spi.IdentityProviders; -import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; -import software.amazon.awssdk.identity.spi.internal.DefaultIdentityProviders; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.auth.scheme.S3AuthSchemeParams; +import software.amazon.awssdk.services.s3.auth.scheme.S3AuthSchemeProvider; import software.amazon.awssdk.services.s3.auth.scheme.internal.DefaultS3AuthSchemeProvider; -import software.amazon.awssdk.services.s3.auth.scheme.internal.S3AuthSchemeInterceptor; -import software.amazon.awssdk.services.s3.endpoints.S3ClientContextParams; -import software.amazon.awssdk.services.s3.endpoints.S3EndpointProvider; -import software.amazon.awssdk.services.s3.model.PutObjectRequest; -import software.amazon.awssdk.services.s3.s3express.S3ExpressAuthScheme; -import software.amazon.awssdk.services.s3.s3express.S3ExpressSessionCredentials; -import software.amazon.awssdk.utils.AttributeMap; class S3ExpressAuthSchemeProviderTest { private static final String S3EXPRESS_BUCKET = "s3expressformat--use1-az1--x-s3"; @Test - public void s3express_defaultAuthEnabled_returnspressAuthScheme() { - PutObjectRequest request = PutObjectRequest.builder().bucket(S3EXPRESS_BUCKET).key("k").build(); - AttributeMap clientContextParams = AttributeMap.builder().build(); - ExecutionAttributes executionAttributes = requiredExecutionAttributes(clientContextParams); - - new S3AuthSchemeInterceptor().beforeExecution(() -> request, executionAttributes); - - SelectedAuthScheme attribute = executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); - assertThat(attribute).isNotNull(); - verifyAuthScheme("aws.auth#sigv4-s3express", attribute); - } - - private ExecutionAttributes requiredExecutionAttributes(AttributeMap clientContextParams) { - ExecutionAttributes executionAttributes = new ExecutionAttributes(); - executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_RESOLVER, DefaultS3AuthSchemeProvider.create()); - executionAttributes.putAttribute(SdkExecutionAttribute.OPERATION_NAME, "PutObject"); - executionAttributes.putAttribute(AwsExecutionAttribute.AWS_REGION, Region.US_EAST_1); - executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_CONTEXT_PARAMS, clientContextParams); - executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, authSchemesWithS3Express()); - executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, - DefaultIdentityProviders.builder() - .putIdentityProvider(DefaultCredentialsProvider.create()) - .build()); - executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, - ClientEndpointProvider.forEndpointOverride(URI.create("https://localhost"))); - executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER, - S3EndpointProvider.defaultProvider()); - return executionAttributes; + public void s3express_defaultAuthEnabled_returnsS3ExpressAuthScheme() { + S3AuthSchemeProvider provider = DefaultS3AuthSchemeProvider.create(); + S3AuthSchemeParams params = S3AuthSchemeParams.builder() + .operation("PutObject") + .region(Region.US_EAST_1) + .bucket(S3EXPRESS_BUCKET) + .build(); + + List authOptions = provider.resolveAuthScheme(params); + + assertThat(authOptions).isNotNull(); + assertThat(authOptions.isEmpty()).isFalse(); + assertThat(authOptions.get(0).schemeId()).isEqualTo("aws.auth#sigv4-s3express"); } @Test public void s3express_authDisabled_returnsV4AuthScheme() { - PutObjectRequest request = PutObjectRequest.builder().bucket(S3EXPRESS_BUCKET).key("k").build(); - AttributeMap clientContextParams = AttributeMap.builder() - .put(S3ClientContextParams.DISABLE_S3_EXPRESS_SESSION_AUTH, true) - .build(); - ExecutionAttributes executionAttributes = requiredExecutionAttributes(clientContextParams); - - new S3AuthSchemeInterceptor().beforeExecution(() -> request, executionAttributes); - - SelectedAuthScheme attribute = executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); - assertThat(attribute).isNotNull(); - verifyAuthScheme("aws.auth#sigv4", attribute); - } - - private void verifyAuthScheme(String expectedAuthSchemeId, SelectedAuthScheme authScheme) { - assertThat(authScheme).isNotNull(); - assertThat(authScheme.authSchemeOption()).isNotNull(); - assertThat(authScheme.authSchemeOption().schemeId()).isEqualTo(expectedAuthSchemeId); - - assertThat(authScheme.identity()).isNotNull(); - assertThat(authScheme.signer()).isNotNull(); - } - - private Map> authSchemesWithS3Express() { - Map> schemes = new HashMap<>(); - AwsV4AuthScheme awsV4AuthScheme = AwsV4AuthScheme.create(); - schemes.put(awsV4AuthScheme.schemeId(), awsV4AuthScheme); - S3ExpressAuthScheme s3ExpressAuthScheme = new S3ExpressAuthScheme() { - @Override - public IdentityProvider identityProvider(IdentityProviders providers) { - return new IdentityProvider() { - @Override - public Class identityType() { - return null; - } - - @Override - public CompletableFuture resolveIdentity(ResolveIdentityRequest request) { - return CompletableFuture.completedFuture(S3ExpressSessionCredentials.create("a","b","c")); - } - }; - } - - @Override - public HttpSigner signer() { - return DefaultS3ExpressHttpSigner.create(); - } - - @Override - public String schemeId() { - return SCHEME_ID; - } - }; - schemes.put(s3ExpressAuthScheme.schemeId(), s3ExpressAuthScheme); - return schemes; + S3AuthSchemeProvider provider = DefaultS3AuthSchemeProvider.create(); + S3AuthSchemeParams params = S3AuthSchemeParams.builder() + .operation("PutObject") + .region(Region.US_EAST_1) + .bucket(S3EXPRESS_BUCKET) + .disableS3ExpressSessionAuth(true) + .build(); + + List authOptions = provider.resolveAuthScheme(params); + + assertThat(authOptions).isNotNull(); + assertThat(authOptions.isEmpty()).isFalse(); + assertThat(authOptions.get(0).schemeId()).isEqualTo("aws.auth#sigv4"); } -} \ No newline at end of file +} diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/AuthSchemeInterceptorTest.java b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/AuthSchemeInterceptorTest.java deleted file mode 100644 index e187ab4435d5..000000000000 --- a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/AuthSchemeInterceptorTest.java +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.services; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider; -import software.amazon.awssdk.core.SelectedAuthScheme; -import software.amazon.awssdk.core.interceptor.Context; -import software.amazon.awssdk.core.interceptor.ExecutionAttributes; -import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; -import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; -import software.amazon.awssdk.http.auth.aws.scheme.AwsV4AuthScheme; -import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; -import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; -import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; -import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; -import software.amazon.awssdk.identity.spi.IdentityProvider; -import software.amazon.awssdk.identity.spi.IdentityProviders; -import software.amazon.awssdk.services.protocolrestjson.auth.scheme.ProtocolRestJsonAuthSchemeParams; -import software.amazon.awssdk.services.protocolrestjson.auth.scheme.ProtocolRestJsonAuthSchemeProvider; -import software.amazon.awssdk.services.protocolrestjson.auth.scheme.internal.ProtocolRestJsonAuthSchemeInterceptor; - -public class AuthSchemeInterceptorTest { - private static final ProtocolRestJsonAuthSchemeInterceptor INTERCEPTOR = new ProtocolRestJsonAuthSchemeInterceptor(); - - private Context.BeforeExecution mockContext; - - @BeforeEach - public void setup() { - mockContext = mock(Context.BeforeExecution.class); - } - - @Test - public void resolveAuthScheme_authSchemeSignerThrows_continuesToNextAuthScheme() { - ProtocolRestJsonAuthSchemeProvider mockAuthSchemeProvider = mock(ProtocolRestJsonAuthSchemeProvider.class); - List authSchemeOptions = Arrays.asList( - AuthSchemeOption.builder().schemeId(TestAuthScheme.SCHEME_ID).build(), - AuthSchemeOption.builder().schemeId(AwsV4AuthScheme.SCHEME_ID).build() - ); - when(mockAuthSchemeProvider.resolveAuthScheme(any(ProtocolRestJsonAuthSchemeParams.class))).thenReturn(authSchemeOptions); - - IdentityProviders mockIdentityProviders = mock(IdentityProviders.class); - when(mockIdentityProviders.identityProvider(any(Class.class))).thenReturn(AnonymousCredentialsProvider.create()); - - Map> authSchemes = new HashMap<>(); - authSchemes.put(AwsV4AuthScheme.SCHEME_ID, AwsV4AuthScheme.create()); - - TestAuthScheme notProvidedAuthScheme = spy(new TestAuthScheme()); - authSchemes.put(TestAuthScheme.SCHEME_ID, notProvidedAuthScheme); - - ExecutionAttributes attributes = new ExecutionAttributes(); - attributes.putAttribute(SdkExecutionAttribute.OPERATION_NAME, "GetFoo"); - attributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_RESOLVER, mockAuthSchemeProvider); - attributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, mockIdentityProviders); - attributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, authSchemes); - - INTERCEPTOR.beforeExecution(mockContext, attributes); - - SelectedAuthScheme selectedAuthScheme = attributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); - - verify(notProvidedAuthScheme).signer(); - assertThat(selectedAuthScheme.authSchemeOption().schemeId()).isEqualTo(AwsV4AuthScheme.SCHEME_ID); - } - - private static class TestAuthScheme implements AuthScheme { - public static final String SCHEME_ID = "codegen-test-scheme"; - - @Override - public String schemeId() { - return SCHEME_ID; - } - - @Override - public IdentityProvider identityProvider(IdentityProviders providers) { - return providers.identityProvider(AwsCredentialsIdentity.class); - } - - @Override - public HttpSigner signer() { - throw new RuntimeException("Not on classpath"); - } - } -} diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/IdentityResolutionOverrideTest.java b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/IdentityResolutionOverrideTest.java index f20c84a25756..daf73cd838e3 100644 --- a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/IdentityResolutionOverrideTest.java +++ b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/IdentityResolutionOverrideTest.java @@ -71,9 +71,8 @@ void when_apiCall_setsCredentialsProviderInRequestOverride_overrideCredentialsAr assertSelectedAuthSchemeBeforeTransmissionContains(OVERRIDE_CREDENTIALS); } - // Changing the credentials provider in modifyRequest does not work in SRA identity resolution - // Identity is resolved in beforeExecution (and happens before user applied interceptors) and cannot - // be affected by execution interceptors. + // After moving identity resolution to pipeline stage (after interceptors), credentials provider + // set in modifyRequest is now respected (identity resolved after interceptors complete) @Test void when_executionInterceptorModifyRequest_setsCredentialProviderInRequestOverride_clientCredentialsAreUsed() { ExecutionInterceptor overridingInterceptor = @@ -83,7 +82,7 @@ void when_executionInterceptorModifyRequest_setsCredentialProviderInRequestOverr assertThatThrownBy(() -> syncClient.allTypes(r -> {})).hasMessageContaining("stop"); - assertSelectedAuthSchemeBeforeTransmissionContains(CLIENT_CREDENTIALS); + assertSelectedAuthSchemeBeforeTransmissionContains(OVERRIDE_CREDENTIALS); } @Test