From f0c479576e02d3779c9233071c5e1c5c5c549bb8 Mon Sep 17 00:00:00 2001 From: Saranya Somepalli Date: Thu, 26 Feb 2026 08:53:01 -0800 Subject: [PATCH 1/6] Move auth scheme identity resolution to pipeline stage from interceptors --- .../tasks/AuthSchemeGeneratorTasks.java | 2 +- .../poet/builder/BaseClientBuilderClass.java | 2 +- .../codegen/poet/client/AsyncClientClass.java | 123 +++++++++- .../codegen/poet/client/SyncClientClass.java | 119 +++++++++- .../poet/client/specs/JsonProtocolSpec.java | 6 +- .../poet/client/specs/QueryProtocolSpec.java | 4 + .../poet/client/specs/XmlProtocolSpec.java | 9 +- .../internal/AwsExecutionContextBuilder.java | 78 ++++++- .../identity/AwsIdentityProviderUpdater.java | 50 ++++ .../awssdk/core/SelectedAuthScheme.java | 24 +- .../client/handler/ClientExecutionParams.java | 12 + .../SdkInternalExecutionAttribute.java | 16 ++ .../internal/http/AmazonAsyncHttpClient.java | 2 + .../internal/http/AmazonSyncHttpClient.java | 2 + .../AsyncAuthSchemeResolutionStage.java | 218 +++++++++++++++++ .../stages/AuthSchemeResolutionStage.java | 219 ++++++++++++++++++ .../spi/identity/IdentityProviderUpdater.java | 39 ++++ .../internal/signing/DefaultS3Presigner.java | 195 +++++++++++++++- .../S3ExpressAuthSchemeProviderTest.java | 139 +++-------- .../services/AuthSchemeInterceptorTest.java | 108 --------- .../IdentityResolutionOverrideTest.java | 7 +- 21 files changed, 1141 insertions(+), 233 deletions(-) create mode 100644 core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/identity/AwsIdentityProviderUpdater.java create mode 100644 core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncAuthSchemeResolutionStage.java create mode 100644 core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStage.java create mode 100644 core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/IdentityProviderUpdater.java delete mode 100644 test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/AuthSchemeInterceptorTest.java diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/AuthSchemeGeneratorTasks.java b/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/AuthSchemeGeneratorTasks.java index 4b6719cc709c..7bb0eabb14d3 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/AuthSchemeGeneratorTasks.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/AuthSchemeGeneratorTasks.java @@ -47,7 +47,7 @@ protected List createTasks() { tasks.add(generateDefaultParamsImpl()); tasks.add(generateModelBasedProvider()); tasks.add(generatePreferenceProvider()); - tasks.add(generateAuthSchemeInterceptor()); + // AuthSchemeInterceptor removed - auth scheme resolution now happens in client operation methods if (authSchemeSpecUtils.useEndpointBasedAuthProvider()) { tasks.add(generateEndpointBasedProvider()); tasks.add(generateEndpointAwareAuthSchemeParams()); diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/BaseClientBuilderClass.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/BaseClientBuilderClass.java index 44d8ce154e72..9560eac7fbe8 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/BaseClientBuilderClass.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/BaseClientBuilderClass.java @@ -370,7 +370,7 @@ private MethodSpec finalizeServiceConfigurationMethod() { List builtInInterceptors = new ArrayList<>(); - builtInInterceptors.add(authSchemeSpecUtils.authSchemeInterceptor()); + // AuthSchemeInterceptor removed - auth scheme resolution now happens in client operation methods builtInInterceptors.add(endpointRulesSpecUtils.resolverInterceptorName()); builtInInterceptors.add(endpointRulesSpecUtils.requestModifierInterceptorName()); diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java index 539d5df35a43..043297ca8024 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java @@ -69,11 +69,14 @@ import software.amazon.awssdk.codegen.poet.PoetExtension; import software.amazon.awssdk.codegen.poet.PoetUtils; import software.amazon.awssdk.codegen.poet.StaticImport; +import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeSpecUtils; import software.amazon.awssdk.codegen.poet.client.specs.ProtocolSpec; import software.amazon.awssdk.codegen.poet.eventstream.EventStreamUtils; import software.amazon.awssdk.codegen.poet.model.EventStreamSpecHelper; import software.amazon.awssdk.codegen.poet.model.ServiceClientConfigurationUtils; +import software.amazon.awssdk.codegen.poet.rules.EndpointRulesSpecUtils; import software.amazon.awssdk.core.RequestOverrideConfiguration; +import software.amazon.awssdk.core.SdkRequest; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.core.async.AsyncResponseTransformerUtils; import software.amazon.awssdk.core.async.SdkPublisher; @@ -84,6 +87,7 @@ import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRefreshCache; import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRequest; import software.amazon.awssdk.core.metrics.CoreMetric; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; @@ -100,6 +104,8 @@ public final class AsyncClientClass extends AsyncClientInterface { private final ProtocolSpec protocolSpec; private final ClassName serviceClientConfigurationClassName; private final ServiceClientConfigurationUtils configurationUtils; + private final AuthSchemeSpecUtils authSchemeSpecUtils; + private final EndpointRulesSpecUtils endpointRulesSpecUtils; private boolean hasScheduledExecutor; public AsyncClientClass(GeneratorTaskParams dependencies) { @@ -110,6 +116,8 @@ public AsyncClientClass(GeneratorTaskParams dependencies) { this.protocolSpec = getProtocolSpecs(poetExtensions, model); this.serviceClientConfigurationClassName = new PoetExtension(model).getServiceConfigClass(); this.configurationUtils = new ServiceClientConfigurationUtils(model); + this.authSchemeSpecUtils = new AuthSchemeSpecUtils(model); + this.endpointRulesSpecUtils = new EndpointRulesSpecUtils(model); } @Override @@ -165,7 +173,8 @@ protected void addAdditionalMethods(TypeSpec.Builder type) { .addMethod(nameMethod()) .addMethods(protocolSpec.additionalMethods()) .addMethod(protocolSpec.initProtocolFactory(model)) - .addMethod(resolveMetricPublishersMethod()); + .addMethod(resolveMetricPublishersMethod()) + .addMethod(resolveAuthSchemeOptionsMethod()); type.addMethod(ClientClassUtils.updateRetryStrategyClientConfigurationMethod()); type.addMethod(updateSdkClientConfigurationMethod(configurationUtils.serviceClientConfigurationBuilderClassName(), @@ -582,4 +591,116 @@ private void addScheduledExecutorIfNeeded(Builder classBuilder) { hasScheduledExecutor = true; } } + + private MethodSpec resolveAuthSchemeOptionsMethod() { + MethodSpec.Builder builder = MethodSpec.methodBuilder("resolveAuthSchemeOptions") + .addModifiers(PRIVATE) + .returns(ParameterizedTypeName.get(ClassName.get(List.class), ClassName.get(AuthSchemeOption.class))) + .addParameter(SdkRequest.class, "request") + .addParameter(String.class, "operationName") + .addParameter(SdkClientConfiguration.class, "clientConfiguration"); + + ClassName providerInterface = authSchemeSpecUtils.providerInterfaceName(); + ClassName paramsInterface = authSchemeSpecUtils.parametersInterfaceName(); + + builder.addStatement("$T authSchemeProvider = ($T) clientConfiguration.option($T.AUTH_SCHEME_PROVIDER)", + providerInterface, providerInterface, SdkClientOption.class); + + if (!authSchemeSpecUtils.useEndpointBasedAuthProvider()) { + // Simple case: operation, region, and optionally regionSet + builder.addStatement("$T.Builder paramsBuilder = $T.builder().operation(operationName)", + paramsInterface, paramsInterface); + ClassName awsClientOption = ClassName.get("software.amazon.awssdk.awscore.client.config", "AwsClientOption"); + if (authSchemeSpecUtils.usesSigV4()) { + builder.addStatement("paramsBuilder.region(clientConfiguration.option($T.AWS_REGION))", awsClientOption); + } + if (authSchemeSpecUtils.hasSigV4aSupport()) { + ClassName regionSet = ClassName.get("software.amazon.awssdk.http.auth.aws.signer", "RegionSet"); + ClassName collectionUtils = ClassName.get("software.amazon.awssdk.utils", "CollectionUtils"); + builder.addStatement("$T sigv4aRegionSet = clientConfiguration.option($T.AWS_SIGV4A_SIGNING_REGION_SET)", + ClassName.get(java.util.Set.class), awsClientOption); + builder.beginControlFlow("if (!$T.isNullOrEmpty(sigv4aRegionSet))", collectionUtils); + builder.addStatement("paramsBuilder.regionSet($T.create(sigv4aRegionSet))", regionSet); + builder.nextControlFlow("else"); + builder.addStatement("paramsBuilder.regionSet($T.create(clientConfiguration.option($T.AWS_REGION).id()))", + regionSet, awsClientOption); + builder.endControlFlow(); + } + builder.addStatement("return authSchemeProvider.resolveAuthScheme(paramsBuilder.build())"); + } else { + // Endpoint-based auth: need to build endpoint params and copy to auth params + ClassName endpointParamsClass = endpointRulesSpecUtils.parametersClassName(); + ClassName resolverInterceptor = endpointRulesSpecUtils.resolverInterceptorName(); + ClassName executionAttributesClass = ClassName.get("software.amazon.awssdk.core.interceptor", + "ExecutionAttributes"); + ClassName awsExecutionAttribute = ClassName.get("software.amazon.awssdk.awscore", + "AwsExecutionAttribute"); + ClassName sdkExecutionAttribute = ClassName.get("software.amazon.awssdk.core.interceptor", + "SdkExecutionAttribute"); + ClassName awsClientOption = ClassName.get("software.amazon.awssdk.awscore.client.config", + "AwsClientOption"); + + // Build execution attributes with values from client config + builder.addStatement("$T executionAttributes = new $T()", executionAttributesClass, executionAttributesClass); + builder.addStatement("executionAttributes.putAttribute($T.AWS_REGION, clientConfiguration.option($T.AWS_REGION))", + awsExecutionAttribute, awsClientOption); + builder.addStatement("executionAttributes.putAttribute($T.DUALSTACK_ENDPOINT_ENABLED, " + + "clientConfiguration.option($T.DUALSTACK_ENDPOINT_ENABLED))", + awsExecutionAttribute, awsClientOption); + builder.addStatement("executionAttributes.putAttribute($T.FIPS_ENDPOINT_ENABLED, " + + "clientConfiguration.option($T.FIPS_ENDPOINT_ENABLED))", + awsExecutionAttribute, awsClientOption); + builder.addStatement("executionAttributes.putAttribute($T.OPERATION_NAME, operationName)", + sdkExecutionAttribute); + builder.addStatement("executionAttributes.putAttribute($T.CLIENT_ENDPOINT_PROVIDER, " + + "clientConfiguration.option($T.CLIENT_ENDPOINT_PROVIDER))", + ClassName.get("software.amazon.awssdk.core.interceptor", "SdkInternalExecutionAttribute"), + SdkClientOption.class); + builder.addStatement("executionAttributes.putAttribute($T.CLIENT_CONTEXT_PARAMS, " + + "clientConfiguration.option($T.CLIENT_CONTEXT_PARAMS))", + ClassName.get("software.amazon.awssdk.core.interceptor", "SdkInternalExecutionAttribute"), + SdkClientOption.class); + + // Use ruleParams to build endpoint params + builder.addStatement("$T endpointParams = $T.ruleParams(request, executionAttributes)", + endpointParamsClass, resolverInterceptor); + + // Build auth scheme params from endpoint params + builder.addStatement("$T.Builder paramsBuilder = $T.builder()", paramsInterface, paramsInterface); + + boolean regionIncluded = false; + for (String paramName : endpointRulesSpecUtils.parameters().keySet()) { + if (!authSchemeSpecUtils.includeParamForProvider(paramName)) { + continue; + } + regionIncluded = regionIncluded || paramName.equalsIgnoreCase("region"); + String methodName = endpointRulesSpecUtils.paramMethodName(paramName); + builder.addStatement("paramsBuilder.$1N(endpointParams.$1N())", methodName); + } + + builder.addStatement("paramsBuilder.operation(operationName)"); + + if (authSchemeSpecUtils.usesSigV4() && !regionIncluded) { + builder.addStatement("paramsBuilder.region(clientConfiguration.option($T.AWS_REGION))", awsClientOption); + } + + // Set endpoint provider on params if applicable + ClassName paramsBuilderClass = authSchemeSpecUtils.parametersEndpointAwareDefaultImplName().nestedClass("Builder"); + ClassName endpointProviderInterface = endpointRulesSpecUtils.providerInterfaceName(); + + builder.beginControlFlow("if (paramsBuilder instanceof $T)", paramsBuilderClass); + builder.addStatement("$T endpointProvider = clientConfiguration.option($T.ENDPOINT_PROVIDER)", + ClassName.get("software.amazon.awssdk.endpoints", "EndpointProvider"), + SdkClientOption.class); + builder.beginControlFlow("if (endpointProvider instanceof $T)", endpointProviderInterface); + builder.addStatement("(($T) paramsBuilder).endpointProvider(($T) endpointProvider)", + paramsBuilderClass, endpointProviderInterface); + builder.endControlFlow(); + builder.endControlFlow(); + + builder.addStatement("return authSchemeProvider.resolveAuthScheme(paramsBuilder.build())"); + } + + return builder.build(); + } } diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java index 5736ddbecaf5..4a7f52c67b95 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java @@ -54,19 +54,23 @@ import software.amazon.awssdk.codegen.model.service.PreClientExecutionRequestCustomizer; import software.amazon.awssdk.codegen.poet.PoetExtension; import software.amazon.awssdk.codegen.poet.PoetUtils; +import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeSpecUtils; import software.amazon.awssdk.codegen.poet.client.specs.Ec2ProtocolSpec; import software.amazon.awssdk.codegen.poet.client.specs.JsonProtocolSpec; import software.amazon.awssdk.codegen.poet.client.specs.ProtocolSpec; import software.amazon.awssdk.codegen.poet.client.specs.QueryProtocolSpec; import software.amazon.awssdk.codegen.poet.client.specs.XmlProtocolSpec; import software.amazon.awssdk.codegen.poet.model.ServiceClientConfigurationUtils; +import software.amazon.awssdk.codegen.poet.rules.EndpointRulesSpecUtils; import software.amazon.awssdk.core.RequestOverrideConfiguration; +import software.amazon.awssdk.core.SdkRequest; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.client.handler.SyncClientHandler; import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRefreshCache; import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRequest; import software.amazon.awssdk.core.metrics.CoreMetric; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; @@ -83,6 +87,8 @@ public class SyncClientClass extends SyncClientInterface { private final ProtocolSpec protocolSpec; private final ClassName serviceClientConfigurationClassName; private final ServiceClientConfigurationUtils configurationUtils; + private final AuthSchemeSpecUtils authSchemeSpecUtils; + private final EndpointRulesSpecUtils endpointRulesSpecUtils; public SyncClientClass(GeneratorTaskParams taskParams) { super(taskParams.getModel()); @@ -92,6 +98,8 @@ public SyncClientClass(GeneratorTaskParams taskParams) { this.protocolSpec = getProtocolSpecs(poetExtensions, model); this.serviceClientConfigurationClassName = new PoetExtension(model).getServiceConfigClass(); this.configurationUtils = new ServiceClientConfigurationUtils(model); + this.authSchemeSpecUtils = new AuthSchemeSpecUtils(model); + this.endpointRulesSpecUtils = new EndpointRulesSpecUtils(model); } @Override @@ -133,7 +141,8 @@ protected void addAdditionalMethods(TypeSpec.Builder type) { type.addMethod(constructor()) .addMethod(nameMethod()) .addMethods(protocolSpec.additionalMethods()) - .addMethod(resolveMetricPublishersMethod()); + .addMethod(resolveMetricPublishersMethod()) + .addMethod(resolveAuthSchemeOptionsMethod()); protocolSpec.createErrorResponseHandler().ifPresent(type::addMethod); type.addMethod(ClientClassUtils.updateRetryStrategyClientConfigurationMethod()); @@ -445,4 +454,112 @@ protected MethodSpec.Builder waiterOperationBody(MethodSpec.Builder builder) { .addStatement("return $T.builder().client(this).build()", poetExtensions.getSyncWaiterInterface()); } + + private MethodSpec resolveAuthSchemeOptionsMethod() { + MethodSpec.Builder builder = MethodSpec.methodBuilder("resolveAuthSchemeOptions") + .addModifiers(PRIVATE) + .returns(ParameterizedTypeName.get(ClassName.get(List.class), ClassName.get(AuthSchemeOption.class))) + .addParameter(SdkRequest.class, "request") + .addParameter(String.class, "operationName") + .addParameter(SdkClientConfiguration.class, "clientConfiguration"); + + ClassName providerInterface = authSchemeSpecUtils.providerInterfaceName(); + ClassName paramsInterface = authSchemeSpecUtils.parametersInterfaceName(); + ClassName awsClientOption = ClassName.get("software.amazon.awssdk.awscore.client.config", "AwsClientOption"); + + builder.addStatement("$T authSchemeProvider = ($T) clientConfiguration.option($T.AUTH_SCHEME_PROVIDER)", + providerInterface, providerInterface, SdkClientOption.class); + + if (!authSchemeSpecUtils.useEndpointBasedAuthProvider()) { + // Simple case: operation, region, and optionally regionSet + builder.addStatement("$T.Builder paramsBuilder = $T.builder().operation(operationName)", + paramsInterface, paramsInterface); + if (authSchemeSpecUtils.usesSigV4()) { + builder.addStatement("paramsBuilder.region(clientConfiguration.option($T.AWS_REGION))", awsClientOption); + } + if (authSchemeSpecUtils.hasSigV4aSupport()) { + ClassName regionSet = ClassName.get("software.amazon.awssdk.http.auth.aws.signer", "RegionSet"); + ClassName collectionUtils = ClassName.get("software.amazon.awssdk.utils", "CollectionUtils"); + builder.addStatement("$T sigv4aRegionSet = clientConfiguration.option($T.AWS_SIGV4A_SIGNING_REGION_SET)", + ClassName.get(java.util.Set.class), awsClientOption); + builder.beginControlFlow("if (!$T.isNullOrEmpty(sigv4aRegionSet))", collectionUtils); + builder.addStatement("paramsBuilder.regionSet($T.create(sigv4aRegionSet))", regionSet); + builder.nextControlFlow("else"); + builder.addStatement("paramsBuilder.regionSet($T.create(clientConfiguration.option($T.AWS_REGION).id()))", + regionSet, awsClientOption); + builder.endControlFlow(); + } + builder.addStatement("return authSchemeProvider.resolveAuthScheme(paramsBuilder.build())"); + } else { + // Endpoint-based auth: use ruleParams to build endpoint params, then copy to auth params + ClassName endpointParamsClass = endpointRulesSpecUtils.parametersClassName(); + ClassName resolverInterceptor = endpointRulesSpecUtils.resolverInterceptorName(); + ClassName executionAttributesClass = ClassName.get("software.amazon.awssdk.core.interceptor", "ExecutionAttributes"); + ClassName awsExecutionAttribute = ClassName.get("software.amazon.awssdk.awscore", "AwsExecutionAttribute"); + ClassName sdkExecutionAttribute = ClassName.get("software.amazon.awssdk.core.interceptor", "SdkExecutionAttribute"); + + // Build minimal execution attributes needed by ruleParams + builder.addStatement("$T executionAttributes = new $T()", executionAttributesClass, executionAttributesClass); + builder.addStatement("executionAttributes.putAttribute($T.AWS_REGION, " + + "clientConfiguration.option($T.AWS_REGION))", + awsExecutionAttribute, awsClientOption); + builder.addStatement("executionAttributes.putAttribute($T.DUALSTACK_ENDPOINT_ENABLED, " + + "clientConfiguration.option($T.DUALSTACK_ENDPOINT_ENABLED))", + awsExecutionAttribute, awsClientOption); + builder.addStatement("executionAttributes.putAttribute($T.FIPS_ENDPOINT_ENABLED, " + + "clientConfiguration.option($T.FIPS_ENDPOINT_ENABLED))", + awsExecutionAttribute, awsClientOption); + builder.addStatement("executionAttributes.putAttribute($T.OPERATION_NAME, operationName)", + sdkExecutionAttribute); + builder.addStatement("executionAttributes.putAttribute($T.CLIENT_ENDPOINT_PROVIDER, " + + "clientConfiguration.option($T.CLIENT_ENDPOINT_PROVIDER))", + ClassName.get("software.amazon.awssdk.core.interceptor", "SdkInternalExecutionAttribute"), + SdkClientOption.class); + builder.addStatement("executionAttributes.putAttribute($T.CLIENT_CONTEXT_PARAMS, " + + "clientConfiguration.option($T.CLIENT_CONTEXT_PARAMS))", + ClassName.get("software.amazon.awssdk.core.interceptor", "SdkInternalExecutionAttribute"), + SdkClientOption.class); + + // Use ruleParams to build endpoint params (handles context params from request) + builder.addStatement("$T endpointParams = $T.ruleParams(request, executionAttributes)", + endpointParamsClass, resolverInterceptor); + + // Build auth scheme params from endpoint params + builder.addStatement("$T.Builder paramsBuilder = $T.builder()", paramsInterface, paramsInterface); + + boolean regionIncluded = false; + for (String paramName : endpointRulesSpecUtils.parameters().keySet()) { + if (!authSchemeSpecUtils.includeParamForProvider(paramName)) { + continue; + } + regionIncluded = regionIncluded || paramName.equalsIgnoreCase("region"); + String methodName = endpointRulesSpecUtils.paramMethodName(paramName); + builder.addStatement("paramsBuilder.$1N(endpointParams.$1N())", methodName); + } + + builder.addStatement("paramsBuilder.operation(operationName)"); + + if (authSchemeSpecUtils.usesSigV4() && !regionIncluded) { + builder.addStatement("paramsBuilder.region(clientConfiguration.option($T.AWS_REGION))", awsClientOption); + } + + // Set endpoint provider on params if applicable + ClassName paramsBuilderClass = authSchemeSpecUtils.parametersEndpointAwareDefaultImplName().nestedClass("Builder"); + ClassName endpointProviderInterface = endpointRulesSpecUtils.providerInterfaceName(); + + builder.beginControlFlow("if (paramsBuilder instanceof $T)", paramsBuilderClass); + builder.addStatement("$T endpointProvider = clientConfiguration.option($T.ENDPOINT_PROVIDER)", + ClassName.get("software.amazon.awssdk.endpoints", "EndpointProvider"), + SdkClientOption.class); + builder.beginControlFlow("if (endpointProvider instanceof $T)", endpointProviderInterface); + builder.addStatement("(($T) paramsBuilder).endpointProvider(($T) endpointProvider)", + paramsBuilderClass, endpointProviderInterface); + builder.endControlFlow(); + builder.endControlFlow(); + + builder.addStatement("return authSchemeProvider.resolveAuthScheme(paramsBuilder.build())"); + } + + return builder.build(); + } } diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java index b9b69d9bd8a0..3b1f18672dc8 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java @@ -221,7 +221,9 @@ public CodeBlock executionHandler(OperationModel opModel) { .add(credentialType(opModel, model)) .add(".withRequestConfiguration(clientConfiguration)") .add(".withInput($L)\n", opModel.getInput().getVariableName()) - .add(".withMetricCollector(apiCallMetricCollector)") + .add(".withMetricCollector(apiCallMetricCollector)\n") + .add(".withAuthSchemeOptions(resolveAuthSchemeOptions($L, $S, clientConfiguration))\n", + opModel.getInput().getVariableName(), opModel.getOperationName()) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); @@ -295,6 +297,8 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper .add(".withErrorResponseHandler(errorResponseHandler)\n") .add(".withRequestConfiguration(clientConfiguration)") .add(".withMetricCollector(apiCallMetricCollector)\n") + .add(".withAuthSchemeOptions(resolveAuthSchemeOptions($L, $S, clientConfiguration))\n", + opModel.getInput().getVariableName(), opModel.getOperationName()) .add(hostPrefixExpression(opModel)) .add(discoveredEndpoint(opModel)) .add(credentialType(opModel, model)) diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/QueryProtocolSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/QueryProtocolSpec.java index bab1f29e7281..deb6b7d3b775 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/QueryProtocolSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/QueryProtocolSpec.java @@ -116,6 +116,8 @@ public CodeBlock executionHandler(OperationModel opModel) { .add(".withRequestConfiguration(clientConfiguration)") .add(".withInput($L)", opModel.getInput().getVariableName()) .add(".withMetricCollector(apiCallMetricCollector)") + .add(".withAuthSchemeOptions(resolveAuthSchemeOptions($L, $S, clientConfiguration))\n", + opModel.getInput().getVariableName(), opModel.getOperationName()) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); @@ -155,6 +157,8 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper .add(credentialType(opModel, intermediateModel)) .add(".withRequestConfiguration(clientConfiguration)") .add(".withMetricCollector(apiCallMetricCollector)\n") + .add(".withAuthSchemeOptions(resolveAuthSchemeOptions($L, $S, clientConfiguration))\n", + opModel.getInput().getVariableName(), opModel.getOperationName()) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/XmlProtocolSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/XmlProtocolSpec.java index dbbd33f4276f..6a524ff89f0a 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/XmlProtocolSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/XmlProtocolSpec.java @@ -134,7 +134,10 @@ public CodeBlock executionHandler(OperationModel opModel) { discoveredEndpoint(opModel)) .add(credentialType(opModel, model)) .add(".withRequestConfiguration(clientConfiguration)") - .add(".withInput($L)", opModel.getInput().getVariableName()) + .add(".withInput($L)", opModel.getInput().getVariableName()) + .add(".withMetricCollector(apiCallMetricCollector)") + .add(".withAuthSchemeOptions(resolveAuthSchemeOptions($L, $S, clientConfiguration))\n", + opModel.getInput().getVariableName(), opModel.getOperationName()) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); @@ -212,7 +215,9 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper builder.add(hostPrefixExpression(opModel)) .add(credentialType(opModel, model)) - .add(".withMetricCollector(apiCallMetricCollector)\n") + .add(".withMetricCollector(apiCallMetricCollector)\n") + .add(".withAuthSchemeOptions(resolveAuthSchemeOptions($L, $S, clientConfiguration))\n", + opModel.getInput().getVariableName(), opModel.getOperationName()) .add(asyncRequestBody(opModel)) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); diff --git a/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilder.java b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilder.java index 32273f16019c..58c27ad5d3b7 100644 --- a/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilder.java +++ b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilder.java @@ -33,6 +33,7 @@ import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.internal.authcontext.AuthorizationStrategy; import software.amazon.awssdk.awscore.internal.authcontext.AuthorizationStrategyFactory; +import software.amazon.awssdk.awscore.internal.identity.AwsIdentityProviderUpdater; import software.amazon.awssdk.awscore.util.SignerOverrideUtils; import software.amazon.awssdk.core.HttpChecksumConstant; import software.amazon.awssdk.core.RequestOverrideConfiguration; @@ -63,6 +64,7 @@ import software.amazon.awssdk.http.ContentStreamProvider; import software.amazon.awssdk.http.auth.scheme.NoAuthAuthScheme; import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeProvider; import software.amazon.awssdk.identity.spi.IdentityProviders; import software.amazon.awssdk.metrics.MetricCollector; @@ -144,6 +146,18 @@ private AwsExecutionContextBuilder() { // Auth Scheme resolution related attributes putAuthSchemeResolutionAttributes(executionAttributes, clientConfig, originalRequest); + // Set auth scheme options from ClientExecutionParams + if (executionParams.authSchemeOptions() != null) { + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS, + executionParams.authSchemeOptions()); + + recordAuthSchemeBusinessMetrics(executionParams.authSchemeOptions(), executionAttributes, originalRequest); + } + + // Set the identity provider updater for the pipeline stage to use + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDER_UPDATER, + AwsIdentityProviderUpdater.INSTANCE); + ExecutionInterceptorChain executionInterceptorChain = new ExecutionInterceptorChain(clientConfig.option(SdkClientOption.EXECUTION_INTERCEPTORS)); @@ -355,7 +369,7 @@ private static EndpointProvider resolveEndpointProvider(SdkRequest request, } private static BusinessMetricCollection - resolveUserAgentBusinessMetrics(SdkClientConfiguration clientConfig, + resolveUserAgentBusinessMetrics(SdkClientConfiguration clientConfig, ClientExecutionParams executionParams) { BusinessMetricCollection businessMetrics = new BusinessMetricCollection(); Optional retryModeMetric = resolveRetryMode(clientConfig.option(RETRY_POLICY), @@ -373,4 +387,66 @@ private static boolean isRpcV2CborProtocol(SdkProtocolMetadata protocolMetadata) return protocolMetadata != null && SMITHY_RPC_V2_CBOR.toString().equals(protocolMetadata.serviceProtocol()); } + + /** + * Records business metrics for auth scheme selection (SigV4a, bearer token). + */ + private static void recordAuthSchemeBusinessMetrics(List authSchemeOptions, + ExecutionAttributes executionAttributes, + SdkRequest request) { + if (authSchemeOptions == null || authSchemeOptions.isEmpty()) { + return; + } + + BusinessMetricCollection businessMetrics = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS); + if (businessMetrics == null) { + return; + } + + Map> authSchemes = executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES); + IdentityProviders identityProviders = executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS); + + if (authSchemes == null || identityProviders == null) { + return; + } + + for (AuthSchemeOption authOption : authSchemeOptions) { + AuthScheme authScheme = authSchemes.get(authOption.schemeId()); + if (authScheme == null) { + continue; // Auth scheme not enabled, try next option + } + + if (authScheme.identityProvider(identityProviders) == null) { + continue; // Identity provider not configured, try next option + } + + // Check for SigV4a + if ("aws.auth#sigv4a".equals(authOption.schemeId()) && + !SignerOverrideUtils.isSignerOverridden(request, executionAttributes)) { + businessMetrics.addMetric(BusinessMetricFeatureId.SIGV4A_SIGNING.value()); + } + + // Check for bearer token from environment + if ("smithy.api#httpBearerAuth".equals(authOption.schemeId())) { + String tokenFromEnv = executionAttributes.getAttribute(SdkInternalExecutionAttribute.TOKEN_CONFIGURED_FROM_ENV); + if (tokenFromEnv != null && !hasTokenOverride(request)) { + businessMetrics.addMetric(BusinessMetricFeatureId.BEARER_SERVICE_ENV_VARS.value()); + } + } + + return; + } + } + + /** + * Check if the request has a token identity provider override. + */ + private static boolean hasTokenOverride(SdkRequest request) { + return request.overrideConfiguration() + .filter(c -> c instanceof AwsRequestOverrideConfiguration) + .map(c -> (AwsRequestOverrideConfiguration) c) + .flatMap(AwsRequestOverrideConfiguration::tokenIdentityProvider) + .isPresent(); + } } diff --git a/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/identity/AwsIdentityProviderUpdater.java b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/identity/AwsIdentityProviderUpdater.java new file mode 100644 index 000000000000..1cafd4675a06 --- /dev/null +++ b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/identity/AwsIdentityProviderUpdater.java @@ -0,0 +1,50 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.awscore.internal.identity; + +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.spi.identity.IdentityProviderUpdater; +import software.amazon.awssdk.identity.spi.IdentityProviders; + +/** + * AWS implementation of {@link IdentityProviderUpdater} that reads credential overrides + * from {@link AwsRequestOverrideConfiguration}. + */ +@SdkInternalApi +public final class AwsIdentityProviderUpdater implements IdentityProviderUpdater { + + public static final AwsIdentityProviderUpdater INSTANCE = new AwsIdentityProviderUpdater(); + + private AwsIdentityProviderUpdater() { + } + + @Override + public IdentityProviders update(SdkRequest request, IdentityProviders base) { + if (base == null) { + return null; + } + return request.overrideConfiguration() + .filter(c -> c instanceof AwsRequestOverrideConfiguration) + .map(c -> (AwsRequestOverrideConfiguration) c) + .map(c -> base.copy(b -> { + c.credentialsIdentityProvider().ifPresent(b::putIdentityProvider); + c.tokenIdentityProvider().ifPresent(b::putIdentityProvider); + })) + .orElse(base); + } +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/SelectedAuthScheme.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/SelectedAuthScheme.java index 10cb488792e2..8772260e3042 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/SelectedAuthScheme.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/SelectedAuthScheme.java @@ -22,9 +22,27 @@ import software.amazon.awssdk.identity.spi.Identity; import software.amazon.awssdk.utils.Validate; -/** - * A container for the identity resolver, signer and auth option that we selected for use with this service call attempt. - */ + +/// +/// A container for the identity resolver, signer and auth option that we selected for use with this service call attempt. +/// ## The Hierarchy +/// ``` +/// IDENTITY_PROVIDERS (IdentityProviders) +/// └── contains multiple IdentityProvider instances +/// e.g., IdentityProvider for AWS credentials +/// e.g., IdentityProvider for bearer tokens +/// +/// AUTH_SCHEMES (Map>) +/// └── each AuthScheme knows: +/// - which IdentityProvider type it needs +/// - which HttpSigner to use +/// +/// SELECTED_AUTH_SCHEME (SelectedAuthScheme) +/// └── the chosen auth scheme, containing: +/// - identity: CompletableFuture ← the resolved identity! +/// - signer: HttpSigner +/// - authSchemeOption: AuthSchemeOption +/// ``` @SdkProtectedApi public final class SelectedAuthScheme { private final CompletableFuture identity; diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/client/handler/ClientExecutionParams.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/client/handler/ClientExecutionParams.java index e307f5857ce7..ad95335ef5ce 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/client/handler/ClientExecutionParams.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/client/handler/ClientExecutionParams.java @@ -16,6 +16,7 @@ package software.amazon.awssdk.core.client.handler; import java.net.URI; +import java.util.List; import software.amazon.awssdk.annotations.NotThreadSafe; import software.amazon.awssdk.annotations.SdkProtectedApi; import software.amazon.awssdk.core.CredentialType; @@ -32,6 +33,7 @@ import software.amazon.awssdk.core.runtime.transform.Marshaller; import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.core.sync.ResponseTransformer; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; /** @@ -63,6 +65,7 @@ public final class ClientExecutionParams { private MetricCollector metricCollector; private final ExecutionAttributes attributes = new ExecutionAttributes(); private SdkClientConfiguration requestConfiguration; + private List authSchemeOptions; public Marshaller getMarshaller() { return marshaller; @@ -261,4 +264,13 @@ public ClientExecutionParams withRequestConfiguration(SdkCl this.requestConfiguration = requestConfiguration; return this; } + + public List authSchemeOptions() { + return authSchemeOptions; + } + + public ClientExecutionParams withAuthSchemeOptions(List authSchemeOptions) { + this.authSchemeOptions = authSchemeOptions; + return this; + } } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java index 3fe0f69d3ab8..f4858b27def1 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java @@ -29,12 +29,14 @@ import software.amazon.awssdk.core.interceptor.trait.HttpChecksum; import software.amazon.awssdk.core.interceptor.trait.HttpChecksumRequired; import software.amazon.awssdk.core.internal.interceptor.trait.RequestCompression; +import software.amazon.awssdk.core.spi.identity.IdentityProviderUpdater; import software.amazon.awssdk.core.useragent.AdditionalMetadata; import software.amazon.awssdk.core.useragent.BusinessMetricCollection; import software.amazon.awssdk.endpoints.Endpoint; import software.amazon.awssdk.endpoints.EndpointProvider; import software.amazon.awssdk.http.SdkHttpExecutionAttributes; import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeProvider; import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; import software.amazon.awssdk.http.auth.spi.signer.PayloadChecksumStore; @@ -166,6 +168,20 @@ public final class SdkInternalExecutionAttribute extends SdkExecutionAttribute { */ public static final ExecutionAttribute IDENTITY_PROVIDERS = new ExecutionAttribute<>("IdentityProviders"); + /** + * Callback for updating identity providers based on request-level overrides. + * This allows aws-core to provide AWS-specific logic without sdk-core depending on aws-core. + */ + public static final ExecutionAttribute IDENTITY_PROVIDER_UPDATER = + new ExecutionAttribute<>("IdentityProviderUpdater"); + + /** + * The resolved auth scheme options for a request. These are resolved by the auth scheme provider + * but identity resolution is deferred to the AuthSchemeResolutionStage. + */ + public static final ExecutionAttribute> AUTH_SCHEME_OPTIONS = + new ExecutionAttribute<>("AuthSchemeOptions"); + /** * The selected auth scheme for a request. */ diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonAsyncHttpClient.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonAsyncHttpClient.java index 59bd964d6fad..cf3722f3587e 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonAsyncHttpClient.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonAsyncHttpClient.java @@ -35,6 +35,7 @@ import software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncApiCallAttemptMetricCollectionStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncApiCallMetricCollectionStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncApiCallTimeoutTrackingStage; +import software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncAuthSchemeResolutionStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncBeforeTransmissionExecutionInterceptorsStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncExecutionFailureExceptionReportingStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncRetryableStage; @@ -200,6 +201,7 @@ public CompletableFuture execute( .then(() -> new HttpChecksumStage(ClientType.ASYNC)) .then(ApplyUserAgentStage::new) .then(MakeRequestImmutableStage::new) + .then(AsyncAuthSchemeResolutionStage::new) .then(RequestPipelineBuilder .first(AsyncSigningStage::new) .then(AsyncBeforeTransmissionExecutionInterceptorsStage::new) diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonSyncHttpClient.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonSyncHttpClient.java index dccaad1e2109..bb5f6816f321 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonSyncHttpClient.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonSyncHttpClient.java @@ -34,6 +34,7 @@ import software.amazon.awssdk.core.internal.http.pipeline.stages.ApiCallTimeoutTrackingStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.ApplyTransactionIdStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.ApplyUserAgentStage; +import software.amazon.awssdk.core.internal.http.pipeline.stages.AuthSchemeResolutionStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.BeforeTransmissionExecutionInterceptorsStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.BeforeUnmarshallingExecutionInterceptorsStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.CompressRequestStage; @@ -189,6 +190,7 @@ public OutputT execute(HttpResponseHandler> response .then(ApplyUserAgentStage::new) .then(MakeRequestImmutableStage::new) // End of mutating request + .then(AuthSchemeResolutionStage::new) .then(RequestPipelineBuilder .first(SigningStage::new) .then(BeforeTransmissionExecutionInterceptorsStage::new) diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncAuthSchemeResolutionStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncAuthSchemeResolutionStage.java new file mode 100644 index 000000000000..79e9a6b85e9e --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncAuthSchemeResolutionStage.java @@ -0,0 +1,218 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.http.pipeline.stages; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.SelectedAuthScheme; +import software.amazon.awssdk.core.exception.SdkException; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.internal.http.HttpClientDependencies; +import software.amazon.awssdk.core.internal.http.RequestExecutionContext; +import software.amazon.awssdk.core.internal.http.pipeline.RequestToRequestPipeline; +import software.amazon.awssdk.core.internal.util.MetricUtils; +import software.amazon.awssdk.core.metrics.CoreMetric; +import software.amazon.awssdk.core.spi.identity.IdentityProviderUpdater; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; +import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; +import software.amazon.awssdk.identity.spi.Identity; +import software.amazon.awssdk.identity.spi.IdentityProvider; +import software.amazon.awssdk.identity.spi.IdentityProviders; +import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; +import software.amazon.awssdk.identity.spi.TokenIdentity; +import software.amazon.awssdk.metrics.MetricCollector; +import software.amazon.awssdk.metrics.SdkMetric; +import software.amazon.awssdk.utils.Logger; + +/** + * Async pipeline stage that resolves the auth scheme and identity for signing. + */ +@SdkInternalApi +public final class AsyncAuthSchemeResolutionStage implements RequestToRequestPipeline { + + private static final Logger LOG = Logger.loggerFor(AsyncAuthSchemeResolutionStage.class); + + public AsyncAuthSchemeResolutionStage(HttpClientDependencies dependencies) { + } + + @Override + public SdkHttpFullRequest execute(SdkHttpFullRequest request, RequestExecutionContext context) + throws Exception { + ExecutionAttributes executionAttributes = context.executionAttributes(); + + // Skip if no auth schemes configured (pre-SRA client or authType=None) + Map> authSchemes = executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES); + if (authSchemes == null) { + return request; + } + + // Get auth options (set by generated client code via ClientExecutionParams) + List authOptions = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS); + + if (authOptions == null || authOptions.isEmpty()) { + // No auth options means either pre-SRA or no auth required + return request; + } + + // Get base identity providers + IdentityProviders identityProviders = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS); + + // Apply request-level overrides via callback (aws-core provides this) + IdentityProviderUpdater updater = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDER_UPDATER); + if (updater != null) { + // Get the modified request (after interceptors) + identityProviders = updater.update( + context.executionContext().interceptorContext().request(), + identityProviders); + } + + SelectedAuthScheme selectedAuthScheme = + selectAuthScheme(authOptions, authSchemes, identityProviders, executionAttributes); + + // Merge pre-existing properties (preserves any externally set properties) + selectedAuthScheme = mergePreExistingAuthSchemeProperties(selectedAuthScheme, executionAttributes); + + // Store the final selected auth scheme + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + + return request; + } + + private SelectedAuthScheme selectAuthScheme( + List authOptions, + Map> authSchemes, + IdentityProviders identityProviders, + ExecutionAttributes executionAttributes) { + + MetricCollector metricCollector = + executionAttributes.getAttribute(SdkExecutionAttribute.API_CALL_METRIC_COLLECTOR); + List> discardedReasons = new ArrayList<>(); + + for (AuthSchemeOption authOption : authOptions) { + AuthScheme authScheme = authSchemes.get(authOption.schemeId()); + SelectedAuthScheme selectedAuthScheme = trySelectAuthScheme( + authOption, authScheme, identityProviders, discardedReasons, metricCollector); + + if (selectedAuthScheme != null) { + if (!discardedReasons.isEmpty()) { + LOG.debug(() -> String.format("%s auth will be used, discarded: '%s'", + authOption.schemeId(), + discardedReasons.stream().map(Supplier::get).collect(Collectors.joining(", ")))); + } + return selectedAuthScheme; + } + } + + throw SdkException.builder() + .message("Failed to determine how to authenticate the user: " + + discardedReasons.stream().map(Supplier::get).collect(Collectors.joining(", "))) + .build(); + } + + private SelectedAuthScheme trySelectAuthScheme( + AuthSchemeOption authOption, + AuthScheme authScheme, + IdentityProviders identityProviders, + List> discardedReasons, + MetricCollector metricCollector) { + + if (authScheme == null) { + discardedReasons.add(() -> String.format("'%s' is not enabled for this request.", authOption.schemeId())); + return null; + } + + IdentityProvider identityProvider = authScheme.identityProvider(identityProviders); + if (identityProvider == null) { + discardedReasons.add(() -> String.format("'%s' does not have an identity provider configured.", + authOption.schemeId())); + return null; + } + + HttpSigner signer; + try { + signer = authScheme.signer(); + } catch (RuntimeException e) { + discardedReasons.add(() -> String.format("'%s' signer could not be retrieved: %s", + authOption.schemeId(), e.getMessage())); + return null; + } + + ResolveIdentityRequest.Builder identityRequestBuilder = ResolveIdentityRequest.builder(); + authOption.forEachIdentityProperty(identityRequestBuilder::putProperty); + + CompletableFuture identity; + SdkMetric metric = getIdentityMetric(identityProvider); + if (metric == null) { + identity = identityProvider.resolveIdentity(identityRequestBuilder.build()); + } else { + identity = MetricUtils.reportDuration( + () -> identityProvider.resolveIdentity(identityRequestBuilder.build()), + metricCollector, + metric); + } + + return new SelectedAuthScheme<>(identity, signer, authOption); + } + + private SdkMetric getIdentityMetric(IdentityProvider identityProvider) { + Class identityType = identityProvider.identityType(); + if (identityType == AwsCredentialsIdentity.class) { + return CoreMetric.CREDENTIALS_FETCH_DURATION; + } + if (identityType == TokenIdentity.class) { + return CoreMetric.TOKEN_FETCH_DURATION; + } + return null; + } + + private SelectedAuthScheme mergePreExistingAuthSchemeProperties( + SelectedAuthScheme selectedAuthScheme, + ExecutionAttributes executionAttributes) { + + SelectedAuthScheme existingAuthScheme = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + + // Skip if no existing scheme + if (existingAuthScheme == null) { + return selectedAuthScheme; + } + + // Merge properties from existing scheme + AuthSchemeOption.Builder mergedOption = selectedAuthScheme.authSchemeOption().toBuilder(); + existingAuthScheme.authSchemeOption().forEachIdentityProperty(mergedOption::putIdentityPropertyIfAbsent); + existingAuthScheme.authSchemeOption().forEachSignerProperty(mergedOption::putSignerPropertyIfAbsent); + + return new SelectedAuthScheme<>( + selectedAuthScheme.identity(), + selectedAuthScheme.signer(), + mergedOption.build() + ); + } +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStage.java new file mode 100644 index 000000000000..7e51da93d549 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStage.java @@ -0,0 +1,219 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.http.pipeline.stages; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.SelectedAuthScheme; +import software.amazon.awssdk.core.exception.SdkException; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.internal.http.HttpClientDependencies; +import software.amazon.awssdk.core.internal.http.RequestExecutionContext; +import software.amazon.awssdk.core.internal.http.pipeline.RequestToRequestPipeline; +import software.amazon.awssdk.core.internal.util.MetricUtils; +import software.amazon.awssdk.core.metrics.CoreMetric; +import software.amazon.awssdk.core.spi.identity.IdentityProviderUpdater; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; +import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; +import software.amazon.awssdk.identity.spi.Identity; +import software.amazon.awssdk.identity.spi.IdentityProvider; +import software.amazon.awssdk.identity.spi.IdentityProviders; +import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; +import software.amazon.awssdk.identity.spi.TokenIdentity; +import software.amazon.awssdk.metrics.MetricCollector; +import software.amazon.awssdk.metrics.SdkMetric; +import software.amazon.awssdk.utils.Logger; + +/** + * Pipeline stage that resolves the auth scheme and identity for signing. + */ +@SdkInternalApi +public final class AuthSchemeResolutionStage implements RequestToRequestPipeline { + + private static final Logger LOG = Logger.loggerFor(AuthSchemeResolutionStage.class); + + public AuthSchemeResolutionStage(HttpClientDependencies dependencies) { + } + + @Override + public SdkHttpFullRequest execute(SdkHttpFullRequest request, RequestExecutionContext context) + throws Exception { + ExecutionAttributes executionAttributes = context.executionAttributes(); + + // Skip if no auth schemes configured (pre-SRA client or authType=None) + Map> authSchemes = executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES); + if (authSchemes == null) { + return request; + } + + // Get auth options (set by generated client code via ClientExecutionParams) + List authOptions = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS); + + if (authOptions == null || authOptions.isEmpty()) { + // No auth options means either pre-SRA or no auth required + return request; + } + + // Get base identity providers + IdentityProviders identityProviders = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS); + + // Apply request-level overrides via callback (aws-core provides this) + IdentityProviderUpdater updater = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDER_UPDATER); + if (updater != null) { + // Get the modified request (after interceptors) + identityProviders = updater.update( + context.executionContext().interceptorContext().request(), + identityProviders); + } + + SelectedAuthScheme selectedAuthScheme = + selectAuthScheme(authOptions, authSchemes, identityProviders, executionAttributes); + + + selectedAuthScheme = mergePreExistingAuthSchemeProperties(selectedAuthScheme, executionAttributes); + + // Store the final selected auth scheme + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + + return request; + } + + + private SelectedAuthScheme selectAuthScheme( + List authOptions, + Map> authSchemes, + IdentityProviders identityProviders, + ExecutionAttributes executionAttributes) { + + MetricCollector metricCollector = + executionAttributes.getAttribute(SdkExecutionAttribute.API_CALL_METRIC_COLLECTOR); + List> discardedReasons = new ArrayList<>(); + + for (AuthSchemeOption authOption : authOptions) { + AuthScheme authScheme = authSchemes.get(authOption.schemeId()); + SelectedAuthScheme selectedAuthScheme = trySelectAuthScheme( + authOption, authScheme, identityProviders, discardedReasons, metricCollector); + + if (selectedAuthScheme != null) { + if (!discardedReasons.isEmpty()) { + LOG.debug(() -> String.format("%s auth will be used, discarded: '%s'", + authOption.schemeId(), + discardedReasons.stream().map(Supplier::get).collect(Collectors.joining(", ")))); + } + return selectedAuthScheme; + } + } + + throw SdkException.builder() + .message("Failed to determine how to authenticate the user: " + + discardedReasons.stream().map(Supplier::get).collect(Collectors.joining(", "))) + .build(); + } + + private SelectedAuthScheme trySelectAuthScheme( + AuthSchemeOption authOption, + AuthScheme authScheme, + IdentityProviders identityProviders, + List> discardedReasons, + MetricCollector metricCollector) { + + if (authScheme == null) { + discardedReasons.add(() -> String.format("'%s' is not enabled for this request.", authOption.schemeId())); + return null; + } + + IdentityProvider identityProvider = authScheme.identityProvider(identityProviders); + if (identityProvider == null) { + discardedReasons.add(() -> String.format("'%s' does not have an identity provider configured.", + authOption.schemeId())); + return null; + } + + HttpSigner signer; + try { + signer = authScheme.signer(); + } catch (RuntimeException e) { + discardedReasons.add(() -> String.format("'%s' signer could not be retrieved: %s", + authOption.schemeId(), e.getMessage())); + return null; + } + + ResolveIdentityRequest.Builder identityRequestBuilder = ResolveIdentityRequest.builder(); + authOption.forEachIdentityProperty(identityRequestBuilder::putProperty); + + CompletableFuture identity; + SdkMetric metric = getIdentityMetric(identityProvider); + if (metric == null) { + identity = identityProvider.resolveIdentity(identityRequestBuilder.build()); + } else { + identity = MetricUtils.reportDuration( + () -> identityProvider.resolveIdentity(identityRequestBuilder.build()), + metricCollector, + metric); + } + + return new SelectedAuthScheme<>(identity, signer, authOption); + } + + private SdkMetric getIdentityMetric(IdentityProvider identityProvider) { + Class identityType = identityProvider.identityType(); + if (identityType == AwsCredentialsIdentity.class) { + return CoreMetric.CREDENTIALS_FETCH_DURATION; + } + if (identityType == TokenIdentity.class) { + return CoreMetric.TOKEN_FETCH_DURATION; + } + return null; + } + + private SelectedAuthScheme mergePreExistingAuthSchemeProperties( + SelectedAuthScheme selectedAuthScheme, + ExecutionAttributes executionAttributes) { + + SelectedAuthScheme existingAuthScheme = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + + // Skip if no existing scheme + if (existingAuthScheme == null) { + return selectedAuthScheme; + } + + // Merge properties from existing scheme + AuthSchemeOption.Builder mergedOption = selectedAuthScheme.authSchemeOption().toBuilder(); + existingAuthScheme.authSchemeOption().forEachIdentityProperty(mergedOption::putIdentityPropertyIfAbsent); + existingAuthScheme.authSchemeOption().forEachSignerProperty(mergedOption::putSignerPropertyIfAbsent); + + return new SelectedAuthScheme<>( + selectedAuthScheme.identity(), + selectedAuthScheme.signer(), + mergedOption.build() + ); + } +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/IdentityProviderUpdater.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/IdentityProviderUpdater.java new file mode 100644 index 000000000000..af54e5ef3661 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/IdentityProviderUpdater.java @@ -0,0 +1,39 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.spi.identity; + +import software.amazon.awssdk.annotations.SdkProtectedApi; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.identity.spi.IdentityProviders; + +/** + * Callback interface for updating identity providers based on request-level overrides. + *

+ * This allows aws-core to provide AWS-specific logic for reading credential overrides + * from {@code AwsRequestOverrideConfiguration} without sdk-core depending on aws-core. + */ +@FunctionalInterface +@SdkProtectedApi +public interface IdentityProviderUpdater { + /** + * Updates identity providers based on request-level overrides. + * + * @param request The request (after interceptors have modified it) + * @param base The base identity providers from client configuration + * @return Updated identity providers, or base if no overrides + */ + IdentityProviders update(SdkRequest request, IdentityProviders base); +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java index 48e09c49a6ba..ba61788152f4 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java @@ -33,6 +33,8 @@ import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; import java.util.stream.Stream; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.auth.signer.AwsSignerExecutionAttribute; @@ -80,13 +82,15 @@ import software.amazon.awssdk.identity.spi.Identity; import software.amazon.awssdk.identity.spi.IdentityProvider; import software.amazon.awssdk.identity.spi.IdentityProviders; +import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; import software.amazon.awssdk.protocols.xml.AwsS3ProtocolFactory; import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.s3.S3Configuration; +import software.amazon.awssdk.services.s3.auth.scheme.S3AuthSchemeParams; import software.amazon.awssdk.services.s3.auth.scheme.S3AuthSchemeProvider; -import software.amazon.awssdk.services.s3.auth.scheme.internal.S3AuthSchemeInterceptor; import software.amazon.awssdk.services.s3.endpoints.S3ClientContextParams; +import software.amazon.awssdk.services.s3.endpoints.S3EndpointParams; import software.amazon.awssdk.services.s3.endpoints.S3EndpointProvider; import software.amazon.awssdk.services.s3.endpoints.internal.S3RequestSetEndpointInterceptor; import software.amazon.awssdk.services.s3.endpoints.internal.S3ResolveEndpointInterceptor; @@ -240,7 +244,7 @@ private List initializeInterceptors() { List s3Interceptors = interceptorFactory.getInterceptors("software/amazon/awssdk/services/s3/execution.interceptors"); List additionalInterceptors = new ArrayList<>(); - additionalInterceptors.add(new S3AuthSchemeInterceptor()); + // S3AuthSchemeInterceptor removed - auth scheme resolution now done inline additionalInterceptors.add(new S3ResolveEndpointInterceptor()); additionalInterceptors.add(new S3RequestSetEndpointInterceptor()); s3Interceptors = mergeLists(s3Interceptors, additionalInterceptors); @@ -405,6 +409,9 @@ private T presign(T presignedRequest, addRequestLevelHeadersAndQueryParameters(execCtx); callModifyHttpRequestHooksAndUpdateContext(execCtx); + // Resolve auth scheme after interceptors complete + resolveAndSelectAuthScheme(execCtx, requestToPresign, operationName); + SdkHttpFullRequest httpRequest = getHttpFullRequest(execCtx); SdkHttpFullRequest signedHttpRequest = execCtx.signer() != null @@ -594,6 +601,190 @@ private SdkHttpFullRequest getHttpFullRequest(ExecutionContext execCtx) { .build(); } + /** + * Resolve and select auth scheme for presigning. + */ + private void resolveAndSelectAuthScheme(ExecutionContext execCtx, SdkRequest request, String operationName) { + ExecutionAttributes executionAttributes = execCtx.executionAttributes(); + + // Get auth schemes and identity providers + Map> authSchemes = executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES); + if (authSchemes == null) { + return; // No auth schemes configured + } + + // Resolve auth options (same logic as generated client) + List authOptions = resolveAuthSchemeOptions(request, operationName, executionAttributes); + if (authOptions == null || authOptions.isEmpty()) { + return; + } + + // Get identity providers and apply credential overrides + IdentityProviders identityProviders = executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS); + identityProviders = applyCredentialOverrides(request, identityProviders); + + // Select auth scheme + SelectedAuthScheme selectedAuthScheme = selectAuthScheme(authOptions, authSchemes, identityProviders); + + // Merge pre-existing properties + selectedAuthScheme = mergePreExistingAuthSchemeProperties(selectedAuthScheme, executionAttributes); + + executionAttributes.putAttribute(SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + + /** + * Merge properties from any pre-existing auth scheme into the selected one. + */ + private SelectedAuthScheme mergePreExistingAuthSchemeProperties( + SelectedAuthScheme selectedAuthScheme, + ExecutionAttributes executionAttributes) { + + SelectedAuthScheme existingAuthScheme = executionAttributes.getAttribute(SELECTED_AUTH_SCHEME); + + // Skip if no existing scheme + if (existingAuthScheme == null) { + return selectedAuthScheme; + } + + AuthSchemeOption.Builder mergedOption = selectedAuthScheme.authSchemeOption().toBuilder(); + existingAuthScheme.authSchemeOption().forEachIdentityProperty(mergedOption::putIdentityPropertyIfAbsent); + existingAuthScheme.authSchemeOption().forEachSignerProperty(mergedOption::putSignerPropertyIfAbsent); + + return new SelectedAuthScheme<>( + selectedAuthScheme.identity(), + selectedAuthScheme.signer(), + mergedOption.build() + ); + } + + /** + * Resolve auth scheme options using full endpoint params. + */ + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + ExecutionAttributes executionAttributes) { + S3AuthSchemeProvider authSchemeProvider = (S3AuthSchemeProvider) + executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_RESOLVER); + + // Build full endpoint params using the generated interceptor's logic + S3EndpointParams endpointParams = + S3ResolveEndpointInterceptor.ruleParams(request, executionAttributes); + + // Copy all endpoint params to auth scheme params + S3AuthSchemeParams.Builder authParamsBuilder = + S3AuthSchemeParams.builder() + .operation(operationName) + .region(endpointParams.region()) + .bucket(endpointParams.bucket()) + .prefix(endpointParams.prefix()) + .copySource(endpointParams.copySource()) + .key(endpointParams.key()); + + // Set optional endpoint params if present + if (endpointParams.accelerate() != null) { + authParamsBuilder.accelerate(endpointParams.accelerate()); + } + if (endpointParams.disableMultiRegionAccessPoints() != null) { + authParamsBuilder.disableMultiRegionAccessPoints(endpointParams.disableMultiRegionAccessPoints()); + } + if (endpointParams.disableS3ExpressSessionAuth() != null) { + authParamsBuilder.disableS3ExpressSessionAuth(endpointParams.disableS3ExpressSessionAuth()); + } + if (endpointParams.forcePathStyle() != null) { + authParamsBuilder.forcePathStyle(endpointParams.forcePathStyle()); + } + if (endpointParams.useArnRegion() != null) { + authParamsBuilder.useArnRegion(endpointParams.useArnRegion()); + } + if (endpointParams.useFips() != null) { + authParamsBuilder.useFips(endpointParams.useFips()); + } + if (endpointParams.useDualStack() != null) { + authParamsBuilder.useDualStack(endpointParams.useDualStack()); + } + + return authSchemeProvider.resolveAuthScheme(authParamsBuilder.build()); + } + + /** + * Apply credential overrides from request configuration. + */ + private IdentityProviders applyCredentialOverrides(SdkRequest request, IdentityProviders base) { + if (base == null) { + return null; + } + return request.overrideConfiguration() + .filter(c -> c instanceof AwsRequestOverrideConfiguration) + .map(c -> (AwsRequestOverrideConfiguration) c) + .map(c -> base.copy(b -> { + c.credentialsIdentityProvider().ifPresent(b::putIdentityProvider); + c.tokenIdentityProvider().ifPresent(b::putIdentityProvider); + })) + .orElse(base); + } + + /** + * Select auth scheme from options + */ + private SelectedAuthScheme selectAuthScheme(List authOptions, + Map> authSchemes, + IdentityProviders identityProviders) { + List> discardedReasons = new ArrayList<>(); + + for (AuthSchemeOption authOption : authOptions) { + AuthScheme authScheme = authSchemes.get(authOption.schemeId()); + SelectedAuthScheme selectedAuthScheme = + trySelectAuthScheme(authOption, authScheme, identityProviders, discardedReasons); + + if (selectedAuthScheme != null) { + if (!discardedReasons.isEmpty()) { + log.debug(() -> String.format("%s auth will be used, discarded: '%s'", authOption.schemeId(), + discardedReasons.stream().map(Supplier::get).collect(Collectors.joining(", ")))); + } + return selectedAuthScheme; + } + } + + throw new IllegalStateException("Failed to determine how to authenticate the user: " + + discardedReasons.stream().map(Supplier::get).collect(Collectors.joining(", "))); + } + + /** + * Try to select a specific auth scheme (copied from AuthSchemeResolutionStage). + */ + private SelectedAuthScheme trySelectAuthScheme(AuthSchemeOption authOption, + AuthScheme authScheme, + IdentityProviders identityProviders, + List> discardedReasons) { + if (authScheme == null) { + discardedReasons.add(() -> String.format("'%s' is not enabled for this request.", authOption.schemeId())); + return null; + } + + IdentityProvider identityProvider = authScheme.identityProvider(identityProviders); + if (identityProvider == null) { + discardedReasons.add(() -> String.format("'%s' does not have an identity provider configured.", + authOption.schemeId())); + return null; + } + + HttpSigner signer; + try { + signer = authScheme.signer(); + } catch (RuntimeException e) { + discardedReasons.add(() -> String.format("'%s' signer could not be retrieved: %s", + authOption.schemeId(), e.getMessage())); + return null; + } + + ResolveIdentityRequest.Builder identityRequestBuilder = + ResolveIdentityRequest.builder(); + authOption.forEachIdentityProperty(identityRequestBuilder::putProperty); + + CompletableFuture identity = identityProvider.resolveIdentity(identityRequestBuilder.build()); + + return new SelectedAuthScheme<>(identity, signer, authOption); + } + /** * Presign the provided HTTP request using old Signer */ diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressAuthSchemeProviderTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressAuthSchemeProviderTest.java index 13c7b37ab958..1a176f1e119b 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressAuthSchemeProviderTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressAuthSchemeProviderTest.java @@ -17,125 +17,48 @@ import static org.assertj.core.api.AssertionsForClassTypes.assertThat; -import java.net.URI; -import java.util.HashMap; -import java.util.Map; -import java.util.concurrent.CompletableFuture; +import java.util.List; import org.junit.jupiter.api.Test; -import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; -import software.amazon.awssdk.awscore.AwsExecutionAttribute; -import software.amazon.awssdk.core.ClientEndpointProvider; -import software.amazon.awssdk.core.SelectedAuthScheme; -import software.amazon.awssdk.core.interceptor.ExecutionAttributes; -import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; -import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; -import software.amazon.awssdk.http.auth.aws.scheme.AwsV4AuthScheme; -import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; -import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; -import software.amazon.awssdk.identity.spi.IdentityProvider; -import software.amazon.awssdk.identity.spi.IdentityProviders; -import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; -import software.amazon.awssdk.identity.spi.internal.DefaultIdentityProviders; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.auth.scheme.S3AuthSchemeParams; +import software.amazon.awssdk.services.s3.auth.scheme.S3AuthSchemeProvider; import software.amazon.awssdk.services.s3.auth.scheme.internal.DefaultS3AuthSchemeProvider; -import software.amazon.awssdk.services.s3.auth.scheme.internal.S3AuthSchemeInterceptor; -import software.amazon.awssdk.services.s3.endpoints.S3ClientContextParams; -import software.amazon.awssdk.services.s3.endpoints.S3EndpointProvider; -import software.amazon.awssdk.services.s3.model.PutObjectRequest; -import software.amazon.awssdk.services.s3.s3express.S3ExpressAuthScheme; -import software.amazon.awssdk.services.s3.s3express.S3ExpressSessionCredentials; -import software.amazon.awssdk.utils.AttributeMap; class S3ExpressAuthSchemeProviderTest { private static final String S3EXPRESS_BUCKET = "s3expressformat--use1-az1--x-s3"; @Test - public void s3express_defaultAuthEnabled_returnspressAuthScheme() { - PutObjectRequest request = PutObjectRequest.builder().bucket(S3EXPRESS_BUCKET).key("k").build(); - AttributeMap clientContextParams = AttributeMap.builder().build(); - ExecutionAttributes executionAttributes = requiredExecutionAttributes(clientContextParams); - - new S3AuthSchemeInterceptor().beforeExecution(() -> request, executionAttributes); - - SelectedAuthScheme attribute = executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); - assertThat(attribute).isNotNull(); - verifyAuthScheme("aws.auth#sigv4-s3express", attribute); - } - - private ExecutionAttributes requiredExecutionAttributes(AttributeMap clientContextParams) { - ExecutionAttributes executionAttributes = new ExecutionAttributes(); - executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_RESOLVER, DefaultS3AuthSchemeProvider.create()); - executionAttributes.putAttribute(SdkExecutionAttribute.OPERATION_NAME, "PutObject"); - executionAttributes.putAttribute(AwsExecutionAttribute.AWS_REGION, Region.US_EAST_1); - executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_CONTEXT_PARAMS, clientContextParams); - executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, authSchemesWithS3Express()); - executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, - DefaultIdentityProviders.builder() - .putIdentityProvider(DefaultCredentialsProvider.create()) - .build()); - executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_ENDPOINT_PROVIDER, - ClientEndpointProvider.forEndpointOverride(URI.create("https://localhost"))); - executionAttributes.putAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER, - S3EndpointProvider.defaultProvider()); - return executionAttributes; + public void s3express_defaultAuthEnabled_returnsS3ExpressAuthScheme() { + S3AuthSchemeProvider provider = DefaultS3AuthSchemeProvider.create(); + S3AuthSchemeParams params = S3AuthSchemeParams.builder() + .operation("PutObject") + .region(Region.US_EAST_1) + .bucket(S3EXPRESS_BUCKET) + .build(); + + List authOptions = provider.resolveAuthScheme(params); + + assertThat(authOptions).isNotNull(); + assertThat(authOptions.isEmpty()).isFalse(); + assertThat(authOptions.get(0).schemeId()).isEqualTo("aws.auth#sigv4-s3express"); } @Test public void s3express_authDisabled_returnsV4AuthScheme() { - PutObjectRequest request = PutObjectRequest.builder().bucket(S3EXPRESS_BUCKET).key("k").build(); - AttributeMap clientContextParams = AttributeMap.builder() - .put(S3ClientContextParams.DISABLE_S3_EXPRESS_SESSION_AUTH, true) - .build(); - ExecutionAttributes executionAttributes = requiredExecutionAttributes(clientContextParams); - - new S3AuthSchemeInterceptor().beforeExecution(() -> request, executionAttributes); - - SelectedAuthScheme attribute = executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); - assertThat(attribute).isNotNull(); - verifyAuthScheme("aws.auth#sigv4", attribute); - } - - private void verifyAuthScheme(String expectedAuthSchemeId, SelectedAuthScheme authScheme) { - assertThat(authScheme).isNotNull(); - assertThat(authScheme.authSchemeOption()).isNotNull(); - assertThat(authScheme.authSchemeOption().schemeId()).isEqualTo(expectedAuthSchemeId); - - assertThat(authScheme.identity()).isNotNull(); - assertThat(authScheme.signer()).isNotNull(); - } - - private Map> authSchemesWithS3Express() { - Map> schemes = new HashMap<>(); - AwsV4AuthScheme awsV4AuthScheme = AwsV4AuthScheme.create(); - schemes.put(awsV4AuthScheme.schemeId(), awsV4AuthScheme); - S3ExpressAuthScheme s3ExpressAuthScheme = new S3ExpressAuthScheme() { - @Override - public IdentityProvider identityProvider(IdentityProviders providers) { - return new IdentityProvider() { - @Override - public Class identityType() { - return null; - } - - @Override - public CompletableFuture resolveIdentity(ResolveIdentityRequest request) { - return CompletableFuture.completedFuture(S3ExpressSessionCredentials.create("a","b","c")); - } - }; - } - - @Override - public HttpSigner signer() { - return DefaultS3ExpressHttpSigner.create(); - } - - @Override - public String schemeId() { - return SCHEME_ID; - } - }; - schemes.put(s3ExpressAuthScheme.schemeId(), s3ExpressAuthScheme); - return schemes; + S3AuthSchemeProvider provider = DefaultS3AuthSchemeProvider.create(); + S3AuthSchemeParams params = S3AuthSchemeParams.builder() + .operation("PutObject") + .region(Region.US_EAST_1) + .bucket(S3EXPRESS_BUCKET) + .disableS3ExpressSessionAuth(true) + .build(); + + List authOptions = provider.resolveAuthScheme(params); + + assertThat(authOptions).isNotNull(); + assertThat(authOptions.isEmpty()).isFalse(); + assertThat(authOptions.get(0).schemeId()).isEqualTo("aws.auth#sigv4"); } -} \ No newline at end of file +} diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/AuthSchemeInterceptorTest.java b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/AuthSchemeInterceptorTest.java deleted file mode 100644 index e187ab4435d5..000000000000 --- a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/AuthSchemeInterceptorTest.java +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.services; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider; -import software.amazon.awssdk.core.SelectedAuthScheme; -import software.amazon.awssdk.core.interceptor.Context; -import software.amazon.awssdk.core.interceptor.ExecutionAttributes; -import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; -import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; -import software.amazon.awssdk.http.auth.aws.scheme.AwsV4AuthScheme; -import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; -import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; -import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; -import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; -import software.amazon.awssdk.identity.spi.IdentityProvider; -import software.amazon.awssdk.identity.spi.IdentityProviders; -import software.amazon.awssdk.services.protocolrestjson.auth.scheme.ProtocolRestJsonAuthSchemeParams; -import software.amazon.awssdk.services.protocolrestjson.auth.scheme.ProtocolRestJsonAuthSchemeProvider; -import software.amazon.awssdk.services.protocolrestjson.auth.scheme.internal.ProtocolRestJsonAuthSchemeInterceptor; - -public class AuthSchemeInterceptorTest { - private static final ProtocolRestJsonAuthSchemeInterceptor INTERCEPTOR = new ProtocolRestJsonAuthSchemeInterceptor(); - - private Context.BeforeExecution mockContext; - - @BeforeEach - public void setup() { - mockContext = mock(Context.BeforeExecution.class); - } - - @Test - public void resolveAuthScheme_authSchemeSignerThrows_continuesToNextAuthScheme() { - ProtocolRestJsonAuthSchemeProvider mockAuthSchemeProvider = mock(ProtocolRestJsonAuthSchemeProvider.class); - List authSchemeOptions = Arrays.asList( - AuthSchemeOption.builder().schemeId(TestAuthScheme.SCHEME_ID).build(), - AuthSchemeOption.builder().schemeId(AwsV4AuthScheme.SCHEME_ID).build() - ); - when(mockAuthSchemeProvider.resolveAuthScheme(any(ProtocolRestJsonAuthSchemeParams.class))).thenReturn(authSchemeOptions); - - IdentityProviders mockIdentityProviders = mock(IdentityProviders.class); - when(mockIdentityProviders.identityProvider(any(Class.class))).thenReturn(AnonymousCredentialsProvider.create()); - - Map> authSchemes = new HashMap<>(); - authSchemes.put(AwsV4AuthScheme.SCHEME_ID, AwsV4AuthScheme.create()); - - TestAuthScheme notProvidedAuthScheme = spy(new TestAuthScheme()); - authSchemes.put(TestAuthScheme.SCHEME_ID, notProvidedAuthScheme); - - ExecutionAttributes attributes = new ExecutionAttributes(); - attributes.putAttribute(SdkExecutionAttribute.OPERATION_NAME, "GetFoo"); - attributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_RESOLVER, mockAuthSchemeProvider); - attributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, mockIdentityProviders); - attributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, authSchemes); - - INTERCEPTOR.beforeExecution(mockContext, attributes); - - SelectedAuthScheme selectedAuthScheme = attributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); - - verify(notProvidedAuthScheme).signer(); - assertThat(selectedAuthScheme.authSchemeOption().schemeId()).isEqualTo(AwsV4AuthScheme.SCHEME_ID); - } - - private static class TestAuthScheme implements AuthScheme { - public static final String SCHEME_ID = "codegen-test-scheme"; - - @Override - public String schemeId() { - return SCHEME_ID; - } - - @Override - public IdentityProvider identityProvider(IdentityProviders providers) { - return providers.identityProvider(AwsCredentialsIdentity.class); - } - - @Override - public HttpSigner signer() { - throw new RuntimeException("Not on classpath"); - } - } -} diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/IdentityResolutionOverrideTest.java b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/IdentityResolutionOverrideTest.java index f20c84a25756..daf73cd838e3 100644 --- a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/IdentityResolutionOverrideTest.java +++ b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/IdentityResolutionOverrideTest.java @@ -71,9 +71,8 @@ void when_apiCall_setsCredentialsProviderInRequestOverride_overrideCredentialsAr assertSelectedAuthSchemeBeforeTransmissionContains(OVERRIDE_CREDENTIALS); } - // Changing the credentials provider in modifyRequest does not work in SRA identity resolution - // Identity is resolved in beforeExecution (and happens before user applied interceptors) and cannot - // be affected by execution interceptors. + // After moving identity resolution to pipeline stage (after interceptors), credentials provider + // set in modifyRequest is now respected (identity resolved after interceptors complete) @Test void when_executionInterceptorModifyRequest_setsCredentialProviderInRequestOverride_clientCredentialsAreUsed() { ExecutionInterceptor overridingInterceptor = @@ -83,7 +82,7 @@ void when_executionInterceptorModifyRequest_setsCredentialProviderInRequestOverr assertThatThrownBy(() -> syncClient.allTypes(r -> {})).hasMessageContaining("stop"); - assertSelectedAuthSchemeBeforeTransmissionContains(CLIENT_CREDENTIALS); + assertSelectedAuthSchemeBeforeTransmissionContains(OVERRIDE_CREDENTIALS); } @Test From ceaff05eac1631157f3dcd0bc8006b74945b355c Mon Sep 17 00:00:00 2001 From: Saranya Somepalli Date: Thu, 26 Feb 2026 12:04:51 -0800 Subject: [PATCH 2/6] Additional changes --- .../internal/http/AmazonAsyncHttpClient.java | 4 +- .../AuthSchemeResolver.java} | 155 +++++++----------- .../stages/AuthSchemeResolutionStage.java | 149 +---------------- .../internal/signing/DefaultS3Presigner.java | 102 +----------- 4 files changed, 72 insertions(+), 338 deletions(-) rename core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/{pipeline/stages/AsyncAuthSchemeResolutionStage.java => auth/AuthSchemeResolver.java} (59%) diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonAsyncHttpClient.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonAsyncHttpClient.java index cf3722f3587e..f470bab80494 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonAsyncHttpClient.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/AmazonAsyncHttpClient.java @@ -35,11 +35,11 @@ import software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncApiCallAttemptMetricCollectionStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncApiCallMetricCollectionStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncApiCallTimeoutTrackingStage; -import software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncAuthSchemeResolutionStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncBeforeTransmissionExecutionInterceptorsStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncExecutionFailureExceptionReportingStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncRetryableStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncSigningStage; +import software.amazon.awssdk.core.internal.http.pipeline.stages.AuthSchemeResolutionStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.CompressRequestStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.HttpChecksumStage; import software.amazon.awssdk.core.internal.http.pipeline.stages.MakeAsyncHttpRequestStage; @@ -201,7 +201,7 @@ public CompletableFuture execute( .then(() -> new HttpChecksumStage(ClientType.ASYNC)) .then(ApplyUserAgentStage::new) .then(MakeRequestImmutableStage::new) - .then(AsyncAuthSchemeResolutionStage::new) + .then(AuthSchemeResolutionStage::new) .then(RequestPipelineBuilder .first(AsyncSigningStage::new) .then(AsyncBeforeTransmissionExecutionInterceptorsStage::new) diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncAuthSchemeResolutionStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/auth/AuthSchemeResolver.java similarity index 59% rename from core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncAuthSchemeResolutionStage.java rename to core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/auth/AuthSchemeResolver.java index 79e9a6b85e9e..6348a85f33b9 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AsyncAuthSchemeResolutionStage.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/auth/AuthSchemeResolver.java @@ -13,7 +13,7 @@ * permissions and limitations under the License. */ -package software.amazon.awssdk.core.internal.http.pipeline.stages; +package software.amazon.awssdk.core.internal.http.auth; import java.time.Duration; import java.util.ArrayList; @@ -26,15 +26,9 @@ import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; -import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; -import software.amazon.awssdk.core.internal.http.HttpClientDependencies; -import software.amazon.awssdk.core.internal.http.RequestExecutionContext; -import software.amazon.awssdk.core.internal.http.pipeline.RequestToRequestPipeline; import software.amazon.awssdk.core.internal.util.MetricUtils; import software.amazon.awssdk.core.metrics.CoreMetric; -import software.amazon.awssdk.core.spi.identity.IdentityProviderUpdater; -import software.amazon.awssdk.http.SdkHttpFullRequest; import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; @@ -49,70 +43,32 @@ import software.amazon.awssdk.utils.Logger; /** - * Async pipeline stage that resolves the auth scheme and identity for signing. + * Shared utility for selecting auth schemes from a list of options. */ @SdkInternalApi -public final class AsyncAuthSchemeResolutionStage implements RequestToRequestPipeline { +public final class AuthSchemeResolver { - private static final Logger LOG = Logger.loggerFor(AsyncAuthSchemeResolutionStage.class); + private static final Logger LOG = Logger.loggerFor(AuthSchemeResolver.class); - public AsyncAuthSchemeResolutionStage(HttpClientDependencies dependencies) { + private AuthSchemeResolver() { } - @Override - public SdkHttpFullRequest execute(SdkHttpFullRequest request, RequestExecutionContext context) - throws Exception { - ExecutionAttributes executionAttributes = context.executionAttributes(); - - // Skip if no auth schemes configured (pre-SRA client or authType=None) - Map> authSchemes = executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES); - if (authSchemes == null) { - return request; - } - - // Get auth options (set by generated client code via ClientExecutionParams) - List authOptions = - executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS); - - if (authOptions == null || authOptions.isEmpty()) { - // No auth options means either pre-SRA or no auth required - return request; - } - - // Get base identity providers - IdentityProviders identityProviders = - executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS); - - // Apply request-level overrides via callback (aws-core provides this) - IdentityProviderUpdater updater = - executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDER_UPDATER); - if (updater != null) { - // Get the modified request (after interceptors) - identityProviders = updater.update( - context.executionContext().interceptorContext().request(), - identityProviders); - } - - SelectedAuthScheme selectedAuthScheme = - selectAuthScheme(authOptions, authSchemes, identityProviders, executionAttributes); - - // Merge pre-existing properties (preserves any externally set properties) - selectedAuthScheme = mergePreExistingAuthSchemeProperties(selectedAuthScheme, executionAttributes); - - // Store the final selected auth scheme - executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); - - return request; - } - - private SelectedAuthScheme selectAuthScheme( + /** + * Select an auth scheme from the given options. + * + * @param authOptions List of auth scheme options to try in order + * @param authSchemes Map of available auth schemes + * @param identityProviders Identity providers to use for resolving identity + * @param metricCollector Optional metric collector for recording identity fetch duration + * @return The selected auth scheme + * @throws SdkException if no auth scheme could be selected + */ + public static SelectedAuthScheme selectAuthScheme( List authOptions, Map> authSchemes, IdentityProviders identityProviders, - ExecutionAttributes executionAttributes) { + MetricCollector metricCollector) { - MetricCollector metricCollector = - executionAttributes.getAttribute(SdkExecutionAttribute.API_CALL_METRIC_COLLECTOR); List> discardedReasons = new ArrayList<>(); for (AuthSchemeOption authOption : authOptions) { @@ -136,7 +92,32 @@ private SelectedAuthScheme selectAuthScheme( .build(); } - private SelectedAuthScheme trySelectAuthScheme( + /** + * Merge properties from any pre-existing auth scheme into the selected one. + */ + public static SelectedAuthScheme mergePreExistingAuthSchemeProperties( + SelectedAuthScheme selectedAuthScheme, + ExecutionAttributes executionAttributes) { + + SelectedAuthScheme existingAuthScheme = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + + if (existingAuthScheme == null) { + return selectedAuthScheme; + } + + AuthSchemeOption.Builder mergedOption = selectedAuthScheme.authSchemeOption().toBuilder(); + existingAuthScheme.authSchemeOption().forEachIdentityProperty(mergedOption::putIdentityPropertyIfAbsent); + existingAuthScheme.authSchemeOption().forEachSignerProperty(mergedOption::putSignerPropertyIfAbsent); + + return new SelectedAuthScheme<>( + selectedAuthScheme.identity(), + selectedAuthScheme.signer(), + mergedOption.build() + ); + } + + private static SelectedAuthScheme trySelectAuthScheme( AuthSchemeOption authOption, AuthScheme authScheme, IdentityProviders identityProviders, @@ -167,21 +148,25 @@ private SelectedAuthScheme trySelectAuthScheme( ResolveIdentityRequest.Builder identityRequestBuilder = ResolveIdentityRequest.builder(); authOption.forEachIdentityProperty(identityRequestBuilder::putProperty); - CompletableFuture identity; - SdkMetric metric = getIdentityMetric(identityProvider); - if (metric == null) { - identity = identityProvider.resolveIdentity(identityRequestBuilder.build()); - } else { - identity = MetricUtils.reportDuration( - () -> identityProvider.resolveIdentity(identityRequestBuilder.build()), - metricCollector, - metric); - } + CompletableFuture identity = resolveIdentity( + identityProvider, identityRequestBuilder.build(), metricCollector); return new SelectedAuthScheme<>(identity, signer, authOption); } - private SdkMetric getIdentityMetric(IdentityProvider identityProvider) { + private static CompletableFuture resolveIdentity( + IdentityProvider identityProvider, + ResolveIdentityRequest request, + MetricCollector metricCollector) { + + SdkMetric metric = getIdentityMetric(identityProvider); + if (metric == null || metricCollector == null) { + return identityProvider.resolveIdentity(request); + } + return MetricUtils.reportDuration(() -> identityProvider.resolveIdentity(request), metricCollector, metric); + } + + private static SdkMetric getIdentityMetric(IdentityProvider identityProvider) { Class identityType = identityProvider.identityType(); if (identityType == AwsCredentialsIdentity.class) { return CoreMetric.CREDENTIALS_FETCH_DURATION; @@ -191,28 +176,4 @@ private SdkMetric getIdentityMetric(IdentityProvider identityProvid } return null; } - - private SelectedAuthScheme mergePreExistingAuthSchemeProperties( - SelectedAuthScheme selectedAuthScheme, - ExecutionAttributes executionAttributes) { - - SelectedAuthScheme existingAuthScheme = - executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); - - // Skip if no existing scheme - if (existingAuthScheme == null) { - return selectedAuthScheme; - } - - // Merge properties from existing scheme - AuthSchemeOption.Builder mergedOption = selectedAuthScheme.authSchemeOption().toBuilder(); - existingAuthScheme.authSchemeOption().forEachIdentityProperty(mergedOption::putIdentityPropertyIfAbsent); - existingAuthScheme.authSchemeOption().forEachSignerProperty(mergedOption::putSignerPropertyIfAbsent); - - return new SelectedAuthScheme<>( - selectedAuthScheme.identity(), - selectedAuthScheme.signer(), - mergedOption.build() - ); - } } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStage.java index 7e51da93d549..81a1437f1ff4 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStage.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStage.java @@ -15,38 +15,24 @@ package software.amazon.awssdk.core.internal.http.pipeline.stages; -import java.time.Duration; -import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.concurrent.CompletableFuture; -import java.util.function.Supplier; -import java.util.stream.Collectors; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.SelectedAuthScheme; -import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; import software.amazon.awssdk.core.internal.http.HttpClientDependencies; import software.amazon.awssdk.core.internal.http.RequestExecutionContext; +import software.amazon.awssdk.core.internal.http.auth.AuthSchemeResolver; import software.amazon.awssdk.core.internal.http.pipeline.RequestToRequestPipeline; -import software.amazon.awssdk.core.internal.util.MetricUtils; -import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.spi.identity.IdentityProviderUpdater; import software.amazon.awssdk.http.SdkHttpFullRequest; import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; -import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; -import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; import software.amazon.awssdk.identity.spi.Identity; -import software.amazon.awssdk.identity.spi.IdentityProvider; import software.amazon.awssdk.identity.spi.IdentityProviders; -import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; -import software.amazon.awssdk.identity.spi.TokenIdentity; import software.amazon.awssdk.metrics.MetricCollector; -import software.amazon.awssdk.metrics.SdkMetric; -import software.amazon.awssdk.utils.Logger; /** * Pipeline stage that resolves the auth scheme and identity for signing. @@ -54,166 +40,45 @@ @SdkInternalApi public final class AuthSchemeResolutionStage implements RequestToRequestPipeline { - private static final Logger LOG = Logger.loggerFor(AuthSchemeResolutionStage.class); - public AuthSchemeResolutionStage(HttpClientDependencies dependencies) { } @Override - public SdkHttpFullRequest execute(SdkHttpFullRequest request, RequestExecutionContext context) - throws Exception { + public SdkHttpFullRequest execute(SdkHttpFullRequest request, RequestExecutionContext context) throws Exception { ExecutionAttributes executionAttributes = context.executionAttributes(); - // Skip if no auth schemes configured (pre-SRA client or authType=None) Map> authSchemes = executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES); if (authSchemes == null) { return request; } - // Get auth options (set by generated client code via ClientExecutionParams) List authOptions = executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS); - if (authOptions == null || authOptions.isEmpty()) { - // No auth options means either pre-SRA or no auth required return request; } - // Get base identity providers IdentityProviders identityProviders = executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS); - // Apply request-level overrides via callback (aws-core provides this) IdentityProviderUpdater updater = executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDER_UPDATER); if (updater != null) { - // Get the modified request (after interceptors) identityProviders = updater.update( context.executionContext().interceptorContext().request(), identityProviders); } - SelectedAuthScheme selectedAuthScheme = - selectAuthScheme(authOptions, authSchemes, identityProviders, executionAttributes); - - - selectedAuthScheme = mergePreExistingAuthSchemeProperties(selectedAuthScheme, executionAttributes); - - // Store the final selected auth scheme - executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); - - return request; - } - - - private SelectedAuthScheme selectAuthScheme( - List authOptions, - Map> authSchemes, - IdentityProviders identityProviders, - ExecutionAttributes executionAttributes) { - MetricCollector metricCollector = executionAttributes.getAttribute(SdkExecutionAttribute.API_CALL_METRIC_COLLECTOR); - List> discardedReasons = new ArrayList<>(); - - for (AuthSchemeOption authOption : authOptions) { - AuthScheme authScheme = authSchemes.get(authOption.schemeId()); - SelectedAuthScheme selectedAuthScheme = trySelectAuthScheme( - authOption, authScheme, identityProviders, discardedReasons, metricCollector); - - if (selectedAuthScheme != null) { - if (!discardedReasons.isEmpty()) { - LOG.debug(() -> String.format("%s auth will be used, discarded: '%s'", - authOption.schemeId(), - discardedReasons.stream().map(Supplier::get).collect(Collectors.joining(", ")))); - } - return selectedAuthScheme; - } - } - - throw SdkException.builder() - .message("Failed to determine how to authenticate the user: " + - discardedReasons.stream().map(Supplier::get).collect(Collectors.joining(", "))) - .build(); - } - private SelectedAuthScheme trySelectAuthScheme( - AuthSchemeOption authOption, - AuthScheme authScheme, - IdentityProviders identityProviders, - List> discardedReasons, - MetricCollector metricCollector) { - - if (authScheme == null) { - discardedReasons.add(() -> String.format("'%s' is not enabled for this request.", authOption.schemeId())); - return null; - } - - IdentityProvider identityProvider = authScheme.identityProvider(identityProviders); - if (identityProvider == null) { - discardedReasons.add(() -> String.format("'%s' does not have an identity provider configured.", - authOption.schemeId())); - return null; - } - - HttpSigner signer; - try { - signer = authScheme.signer(); - } catch (RuntimeException e) { - discardedReasons.add(() -> String.format("'%s' signer could not be retrieved: %s", - authOption.schemeId(), e.getMessage())); - return null; - } - - ResolveIdentityRequest.Builder identityRequestBuilder = ResolveIdentityRequest.builder(); - authOption.forEachIdentityProperty(identityRequestBuilder::putProperty); - - CompletableFuture identity; - SdkMetric metric = getIdentityMetric(identityProvider); - if (metric == null) { - identity = identityProvider.resolveIdentity(identityRequestBuilder.build()); - } else { - identity = MetricUtils.reportDuration( - () -> identityProvider.resolveIdentity(identityRequestBuilder.build()), - metricCollector, - metric); - } - - return new SelectedAuthScheme<>(identity, signer, authOption); - } - - private SdkMetric getIdentityMetric(IdentityProvider identityProvider) { - Class identityType = identityProvider.identityType(); - if (identityType == AwsCredentialsIdentity.class) { - return CoreMetric.CREDENTIALS_FETCH_DURATION; - } - if (identityType == TokenIdentity.class) { - return CoreMetric.TOKEN_FETCH_DURATION; - } - return null; - } - - private SelectedAuthScheme mergePreExistingAuthSchemeProperties( - SelectedAuthScheme selectedAuthScheme, - ExecutionAttributes executionAttributes) { - - SelectedAuthScheme existingAuthScheme = - executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + SelectedAuthScheme selectedAuthScheme = + AuthSchemeResolver.selectAuthScheme(authOptions, authSchemes, identityProviders, metricCollector); - // Skip if no existing scheme - if (existingAuthScheme == null) { - return selectedAuthScheme; - } + selectedAuthScheme = AuthSchemeResolver.mergePreExistingAuthSchemeProperties(selectedAuthScheme, executionAttributes); - // Merge properties from existing scheme - AuthSchemeOption.Builder mergedOption = selectedAuthScheme.authSchemeOption().toBuilder(); - existingAuthScheme.authSchemeOption().forEachIdentityProperty(mergedOption::putIdentityPropertyIfAbsent); - existingAuthScheme.authSchemeOption().forEachSignerProperty(mergedOption::putSignerPropertyIfAbsent); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); - return new SelectedAuthScheme<>( - selectedAuthScheme.identity(), - selectedAuthScheme.signer(), - mergedOption.build() - ); + return request; } } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java index ba61788152f4..b21644c7d82f 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java @@ -63,6 +63,7 @@ import software.amazon.awssdk.core.interceptor.InterceptorContext; import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.internal.http.auth.AuthSchemeResolver; import software.amazon.awssdk.core.signer.Presigner; import software.amazon.awssdk.core.signer.Signer; import software.amazon.awssdk.core.sync.RequestBody; @@ -82,7 +83,6 @@ import software.amazon.awssdk.identity.spi.Identity; import software.amazon.awssdk.identity.spi.IdentityProvider; import software.amazon.awssdk.identity.spi.IdentityProviders; -import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; import software.amazon.awssdk.protocols.xml.AwsS3ProtocolFactory; import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.services.s3.S3Client; @@ -607,56 +607,27 @@ private SdkHttpFullRequest getHttpFullRequest(ExecutionContext execCtx) { private void resolveAndSelectAuthScheme(ExecutionContext execCtx, SdkRequest request, String operationName) { ExecutionAttributes executionAttributes = execCtx.executionAttributes(); - // Get auth schemes and identity providers Map> authSchemes = executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES); if (authSchemes == null) { - return; // No auth schemes configured + return; } - // Resolve auth options (same logic as generated client) List authOptions = resolveAuthSchemeOptions(request, operationName, executionAttributes); if (authOptions == null || authOptions.isEmpty()) { return; } - // Get identity providers and apply credential overrides IdentityProviders identityProviders = executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS); identityProviders = applyCredentialOverrides(request, identityProviders); - // Select auth scheme - SelectedAuthScheme selectedAuthScheme = selectAuthScheme(authOptions, authSchemes, identityProviders); + SelectedAuthScheme selectedAuthScheme = + AuthSchemeResolver.selectAuthScheme(authOptions, authSchemes, identityProviders, null); - // Merge pre-existing properties - selectedAuthScheme = mergePreExistingAuthSchemeProperties(selectedAuthScheme, executionAttributes); + selectedAuthScheme = AuthSchemeResolver.mergePreExistingAuthSchemeProperties(selectedAuthScheme, executionAttributes); executionAttributes.putAttribute(SELECTED_AUTH_SCHEME, selectedAuthScheme); } - /** - * Merge properties from any pre-existing auth scheme into the selected one. - */ - private SelectedAuthScheme mergePreExistingAuthSchemeProperties( - SelectedAuthScheme selectedAuthScheme, - ExecutionAttributes executionAttributes) { - - SelectedAuthScheme existingAuthScheme = executionAttributes.getAttribute(SELECTED_AUTH_SCHEME); - - // Skip if no existing scheme - if (existingAuthScheme == null) { - return selectedAuthScheme; - } - - AuthSchemeOption.Builder mergedOption = selectedAuthScheme.authSchemeOption().toBuilder(); - existingAuthScheme.authSchemeOption().forEachIdentityProperty(mergedOption::putIdentityPropertyIfAbsent); - existingAuthScheme.authSchemeOption().forEachSignerProperty(mergedOption::putSignerPropertyIfAbsent); - - return new SelectedAuthScheme<>( - selectedAuthScheme.identity(), - selectedAuthScheme.signer(), - mergedOption.build() - ); - } - /** * Resolve auth scheme options using full endpoint params. */ @@ -722,69 +693,6 @@ private IdentityProviders applyCredentialOverrides(SdkRequest request, IdentityP .orElse(base); } - /** - * Select auth scheme from options - */ - private SelectedAuthScheme selectAuthScheme(List authOptions, - Map> authSchemes, - IdentityProviders identityProviders) { - List> discardedReasons = new ArrayList<>(); - - for (AuthSchemeOption authOption : authOptions) { - AuthScheme authScheme = authSchemes.get(authOption.schemeId()); - SelectedAuthScheme selectedAuthScheme = - trySelectAuthScheme(authOption, authScheme, identityProviders, discardedReasons); - - if (selectedAuthScheme != null) { - if (!discardedReasons.isEmpty()) { - log.debug(() -> String.format("%s auth will be used, discarded: '%s'", authOption.schemeId(), - discardedReasons.stream().map(Supplier::get).collect(Collectors.joining(", ")))); - } - return selectedAuthScheme; - } - } - - throw new IllegalStateException("Failed to determine how to authenticate the user: " + - discardedReasons.stream().map(Supplier::get).collect(Collectors.joining(", "))); - } - - /** - * Try to select a specific auth scheme (copied from AuthSchemeResolutionStage). - */ - private SelectedAuthScheme trySelectAuthScheme(AuthSchemeOption authOption, - AuthScheme authScheme, - IdentityProviders identityProviders, - List> discardedReasons) { - if (authScheme == null) { - discardedReasons.add(() -> String.format("'%s' is not enabled for this request.", authOption.schemeId())); - return null; - } - - IdentityProvider identityProvider = authScheme.identityProvider(identityProviders); - if (identityProvider == null) { - discardedReasons.add(() -> String.format("'%s' does not have an identity provider configured.", - authOption.schemeId())); - return null; - } - - HttpSigner signer; - try { - signer = authScheme.signer(); - } catch (RuntimeException e) { - discardedReasons.add(() -> String.format("'%s' signer could not be retrieved: %s", - authOption.schemeId(), e.getMessage())); - return null; - } - - ResolveIdentityRequest.Builder identityRequestBuilder = - ResolveIdentityRequest.builder(); - authOption.forEachIdentityProperty(identityRequestBuilder::putProperty); - - CompletableFuture identity = identityProvider.resolveIdentity(identityRequestBuilder.build()); - - return new SelectedAuthScheme<>(identity, signer, authOption); - } - /** * Presign the provided HTTP request using old Signer */ From b9600eaaa33ebc97b140264469eff7e80ca89c9e Mon Sep 17 00:00:00 2001 From: Saranya Somepalli Date: Fri, 27 Feb 2026 09:17:29 -0800 Subject: [PATCH 3/6] Add tests --- .../http/auth/AuthSchemeResolverTest.java | 189 ++++++++++++++++++ 1 file changed, 189 insertions(+) create mode 100644 core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/auth/AuthSchemeResolverTest.java diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/auth/AuthSchemeResolverTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/auth/AuthSchemeResolverTest.java new file mode 100644 index 000000000000..90c2cdf12008 --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/auth/AuthSchemeResolverTest.java @@ -0,0 +1,189 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.http.auth; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.SelectedAuthScheme; +import software.amazon.awssdk.core.exception.SdkException; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; +import software.amazon.awssdk.identity.spi.Identity; +import software.amazon.awssdk.identity.spi.IdentityProvider; +import software.amazon.awssdk.identity.spi.IdentityProviders; +import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; + +class AuthSchemeResolverTest { + + private static final String SCHEME_A = "schemeA"; + private static final String SCHEME_B = "schemeB"; + + @Test + void selectAuthScheme_firstOptionSucceeds_returnsFirstScheme() { + AuthScheme schemeA = createMockAuthScheme(); + Map> authSchemes = new HashMap<>(); + authSchemes.put(SCHEME_A, schemeA); + + List options = Collections.singletonList( + AuthSchemeOption.builder().schemeId(SCHEME_A).build() + ); + + SelectedAuthScheme result = AuthSchemeResolver.selectAuthScheme( + options, authSchemes, mock(IdentityProviders.class), null); + + assertThat(result.authSchemeOption().schemeId()).isEqualTo(SCHEME_A); + } + + @Test + void selectAuthScheme_firstOptionNoScheme_fallsBackToSecond() { + AuthScheme schemeB = createMockAuthScheme(); + Map> authSchemes = new HashMap<>(); + authSchemes.put(SCHEME_B, schemeB); + + List options = Arrays.asList( + AuthSchemeOption.builder().schemeId(SCHEME_A).build(), + AuthSchemeOption.builder().schemeId(SCHEME_B).build() + ); + + SelectedAuthScheme result = AuthSchemeResolver.selectAuthScheme( + options, authSchemes, mock(IdentityProviders.class), null); + + assertThat(result.authSchemeOption().schemeId()).isEqualTo(SCHEME_B); + } + + @Test + void selectAuthScheme_firstOptionNoIdentityProvider_fallsBackToSecond() { + AuthScheme schemeA = createMockAuthScheme(); + when(schemeA.identityProvider(any())).thenReturn(null); + + AuthScheme schemeB = createMockAuthScheme(); + + Map> authSchemes = new HashMap<>(); + authSchemes.put(SCHEME_A, schemeA); + authSchemes.put(SCHEME_B, schemeB); + + List options = Arrays.asList( + AuthSchemeOption.builder().schemeId(SCHEME_A).build(), + AuthSchemeOption.builder().schemeId(SCHEME_B).build() + ); + + SelectedAuthScheme result = AuthSchemeResolver.selectAuthScheme( + options, authSchemes, mock(IdentityProviders.class), null); + + assertThat(result.authSchemeOption().schemeId()).isEqualTo(SCHEME_B); + } + + @Test + void selectAuthScheme_signerThrows_fallsBackToSecond() { + AuthScheme schemeA = createMockAuthScheme(); + when(schemeA.signer()).thenThrow(new RuntimeException("Signer not available")); + + AuthScheme schemeB = createMockAuthScheme(); + + Map> authSchemes = new HashMap<>(); + authSchemes.put(SCHEME_A, schemeA); + authSchemes.put(SCHEME_B, schemeB); + + List options = Arrays.asList( + AuthSchemeOption.builder().schemeId(SCHEME_A).build(), + AuthSchemeOption.builder().schemeId(SCHEME_B).build() + ); + + SelectedAuthScheme result = AuthSchemeResolver.selectAuthScheme( + options, authSchemes, mock(IdentityProviders.class), null); + + assertThat(result.authSchemeOption().schemeId()).isEqualTo(SCHEME_B); + } + + @Test + void selectAuthScheme_allOptionsFail_throwsException() { + Map> authSchemes = new HashMap<>(); + + List options = Collections.singletonList( + AuthSchemeOption.builder().schemeId(SCHEME_A).build() + ); + + assertThatThrownBy(() -> AuthSchemeResolver.selectAuthScheme( + options, authSchemes, mock(IdentityProviders.class), null)) + .isInstanceOf(SdkException.class) + .hasMessageContaining("Failed to determine how to authenticate"); + } + + @Test + void mergeProperties_noExistingScheme_returnsOriginal() { + SelectedAuthScheme selected = createSelectedAuthScheme(SCHEME_A); + ExecutionAttributes attributes = new ExecutionAttributes(); + + SelectedAuthScheme result = AuthSchemeResolver.mergePreExistingAuthSchemeProperties( + selected, attributes); + + assertThat(result).isSameAs(selected); + } + + @Test + @SuppressWarnings("unchecked") + void mergeProperties_withExistingScheme_returnsNewInstance() { + SelectedAuthScheme selected = createSelectedAuthScheme(SCHEME_A); + SelectedAuthScheme existing = createSelectedAuthScheme(SCHEME_B); + + ExecutionAttributes attributes = new ExecutionAttributes(); + attributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, existing); + + SelectedAuthScheme result = AuthSchemeResolver.mergePreExistingAuthSchemeProperties( + selected, attributes); + + assertThat(result).isNotSameAs(selected); + assertThat(result.authSchemeOption().schemeId()).isEqualTo(SCHEME_A); + } + + @SuppressWarnings("unchecked") + private AuthScheme createMockAuthScheme() { + AuthScheme scheme = mock(AuthScheme.class); + IdentityProvider identityProvider = mock(IdentityProvider.class); + Identity mockIdentity = mock(Identity.class); + doReturn(CompletableFuture.completedFuture(mockIdentity)) + .when(identityProvider).resolveIdentity(any(ResolveIdentityRequest.class)); + when(scheme.identityProvider(any())).thenReturn(identityProvider); + when(scheme.signer()).thenReturn(mock(HttpSigner.class)); + return scheme; + } + + @SuppressWarnings("unchecked") + private SelectedAuthScheme createSelectedAuthScheme(String schemeId) { + Identity mockIdentity = mock(Identity.class); + HttpSigner mockSigner = mock(HttpSigner.class); + return new SelectedAuthScheme<>( + CompletableFuture.completedFuture(mockIdentity), + mockSigner, + AuthSchemeOption.builder().schemeId(schemeId).build() + ); + } +} From ac3d8a48aeb4ad27c181ccab6d8fbfc8a845fdd5 Mon Sep 17 00:00:00 2001 From: Saranya Somepalli Date: Mon, 2 Mar 2026 14:08:40 -0800 Subject: [PATCH 4/6] Address PR feedback: Added callback pattern for auth scheme options resolution Removed duplicate code Added fixture tests --- .../tasks/AuthSchemeGeneratorTasks.java | 1 - .../poet/builder/BaseClientBuilderClass.java | 1 - .../codegen/poet/client/AsyncClientClass.java | 114 +--------------- .../codegen/poet/client/ClientClassUtils.java | 119 +++++++++++++++++ .../codegen/poet/client/SyncClientClass.java | 110 +-------------- .../poet/client/specs/JsonProtocolSpec.java | 8 +- .../poet/client/specs/QueryProtocolSpec.java | 8 +- .../poet/client/specs/XmlProtocolSpec.java | 8 +- ...dgeForH2-service-client-builder-class.java | 2 - ...tom-context-params-async-client-class.java | 62 ++++++--- .../poet/client/test-query-client-class.java | 125 ++++++++++++++---- .../internal/AwsExecutionContextBuilder.java | 10 +- .../client/handler/ClientExecutionParams.java | 14 +- .../SdkInternalExecutionAttribute.java | 10 +- .../stages/AuthSchemeResolutionStage.java | 20 ++- .../identity/AuthSchemeOptionsResolver.java | 39 ++++++ .../internal/signing/DefaultS3Presigner.java | 23 +--- 17 files changed, 344 insertions(+), 330 deletions(-) create mode 100644 core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/AuthSchemeOptionsResolver.java diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/AuthSchemeGeneratorTasks.java b/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/AuthSchemeGeneratorTasks.java index 7bb0eabb14d3..471d9a5ddaab 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/AuthSchemeGeneratorTasks.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/emitters/tasks/AuthSchemeGeneratorTasks.java @@ -47,7 +47,6 @@ protected List createTasks() { tasks.add(generateDefaultParamsImpl()); tasks.add(generateModelBasedProvider()); tasks.add(generatePreferenceProvider()); - // AuthSchemeInterceptor removed - auth scheme resolution now happens in client operation methods if (authSchemeSpecUtils.useEndpointBasedAuthProvider()) { tasks.add(generateEndpointBasedProvider()); tasks.add(generateEndpointAwareAuthSchemeParams()); diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/BaseClientBuilderClass.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/BaseClientBuilderClass.java index 9560eac7fbe8..e342a79b4e25 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/BaseClientBuilderClass.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/BaseClientBuilderClass.java @@ -370,7 +370,6 @@ private MethodSpec finalizeServiceConfigurationMethod() { List builtInInterceptors = new ArrayList<>(); - // AuthSchemeInterceptor removed - auth scheme resolution now happens in client operation methods builtInInterceptors.add(endpointRulesSpecUtils.resolverInterceptorName()); builtInInterceptors.add(endpointRulesSpecUtils.requestModifierInterceptorName()); diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java index 043297ca8024..b3bb15e9d257 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java @@ -174,7 +174,7 @@ protected void addAdditionalMethods(TypeSpec.Builder type) { .addMethods(protocolSpec.additionalMethods()) .addMethod(protocolSpec.initProtocolFactory(model)) .addMethod(resolveMetricPublishersMethod()) - .addMethod(resolveAuthSchemeOptionsMethod()); + .addMethod(ClientClassUtils.resolveAuthSchemeOptionsMethod(authSchemeSpecUtils, endpointRulesSpecUtils)); type.addMethod(ClientClassUtils.updateRetryStrategyClientConfigurationMethod()); type.addMethod(updateSdkClientConfigurationMethod(configurationUtils.serviceClientConfigurationBuilderClassName(), @@ -591,116 +591,4 @@ private void addScheduledExecutorIfNeeded(Builder classBuilder) { hasScheduledExecutor = true; } } - - private MethodSpec resolveAuthSchemeOptionsMethod() { - MethodSpec.Builder builder = MethodSpec.methodBuilder("resolveAuthSchemeOptions") - .addModifiers(PRIVATE) - .returns(ParameterizedTypeName.get(ClassName.get(List.class), ClassName.get(AuthSchemeOption.class))) - .addParameter(SdkRequest.class, "request") - .addParameter(String.class, "operationName") - .addParameter(SdkClientConfiguration.class, "clientConfiguration"); - - ClassName providerInterface = authSchemeSpecUtils.providerInterfaceName(); - ClassName paramsInterface = authSchemeSpecUtils.parametersInterfaceName(); - - builder.addStatement("$T authSchemeProvider = ($T) clientConfiguration.option($T.AUTH_SCHEME_PROVIDER)", - providerInterface, providerInterface, SdkClientOption.class); - - if (!authSchemeSpecUtils.useEndpointBasedAuthProvider()) { - // Simple case: operation, region, and optionally regionSet - builder.addStatement("$T.Builder paramsBuilder = $T.builder().operation(operationName)", - paramsInterface, paramsInterface); - ClassName awsClientOption = ClassName.get("software.amazon.awssdk.awscore.client.config", "AwsClientOption"); - if (authSchemeSpecUtils.usesSigV4()) { - builder.addStatement("paramsBuilder.region(clientConfiguration.option($T.AWS_REGION))", awsClientOption); - } - if (authSchemeSpecUtils.hasSigV4aSupport()) { - ClassName regionSet = ClassName.get("software.amazon.awssdk.http.auth.aws.signer", "RegionSet"); - ClassName collectionUtils = ClassName.get("software.amazon.awssdk.utils", "CollectionUtils"); - builder.addStatement("$T sigv4aRegionSet = clientConfiguration.option($T.AWS_SIGV4A_SIGNING_REGION_SET)", - ClassName.get(java.util.Set.class), awsClientOption); - builder.beginControlFlow("if (!$T.isNullOrEmpty(sigv4aRegionSet))", collectionUtils); - builder.addStatement("paramsBuilder.regionSet($T.create(sigv4aRegionSet))", regionSet); - builder.nextControlFlow("else"); - builder.addStatement("paramsBuilder.regionSet($T.create(clientConfiguration.option($T.AWS_REGION).id()))", - regionSet, awsClientOption); - builder.endControlFlow(); - } - builder.addStatement("return authSchemeProvider.resolveAuthScheme(paramsBuilder.build())"); - } else { - // Endpoint-based auth: need to build endpoint params and copy to auth params - ClassName endpointParamsClass = endpointRulesSpecUtils.parametersClassName(); - ClassName resolverInterceptor = endpointRulesSpecUtils.resolverInterceptorName(); - ClassName executionAttributesClass = ClassName.get("software.amazon.awssdk.core.interceptor", - "ExecutionAttributes"); - ClassName awsExecutionAttribute = ClassName.get("software.amazon.awssdk.awscore", - "AwsExecutionAttribute"); - ClassName sdkExecutionAttribute = ClassName.get("software.amazon.awssdk.core.interceptor", - "SdkExecutionAttribute"); - ClassName awsClientOption = ClassName.get("software.amazon.awssdk.awscore.client.config", - "AwsClientOption"); - - // Build execution attributes with values from client config - builder.addStatement("$T executionAttributes = new $T()", executionAttributesClass, executionAttributesClass); - builder.addStatement("executionAttributes.putAttribute($T.AWS_REGION, clientConfiguration.option($T.AWS_REGION))", - awsExecutionAttribute, awsClientOption); - builder.addStatement("executionAttributes.putAttribute($T.DUALSTACK_ENDPOINT_ENABLED, " - + "clientConfiguration.option($T.DUALSTACK_ENDPOINT_ENABLED))", - awsExecutionAttribute, awsClientOption); - builder.addStatement("executionAttributes.putAttribute($T.FIPS_ENDPOINT_ENABLED, " - + "clientConfiguration.option($T.FIPS_ENDPOINT_ENABLED))", - awsExecutionAttribute, awsClientOption); - builder.addStatement("executionAttributes.putAttribute($T.OPERATION_NAME, operationName)", - sdkExecutionAttribute); - builder.addStatement("executionAttributes.putAttribute($T.CLIENT_ENDPOINT_PROVIDER, " - + "clientConfiguration.option($T.CLIENT_ENDPOINT_PROVIDER))", - ClassName.get("software.amazon.awssdk.core.interceptor", "SdkInternalExecutionAttribute"), - SdkClientOption.class); - builder.addStatement("executionAttributes.putAttribute($T.CLIENT_CONTEXT_PARAMS, " - + "clientConfiguration.option($T.CLIENT_CONTEXT_PARAMS))", - ClassName.get("software.amazon.awssdk.core.interceptor", "SdkInternalExecutionAttribute"), - SdkClientOption.class); - - // Use ruleParams to build endpoint params - builder.addStatement("$T endpointParams = $T.ruleParams(request, executionAttributes)", - endpointParamsClass, resolverInterceptor); - - // Build auth scheme params from endpoint params - builder.addStatement("$T.Builder paramsBuilder = $T.builder()", paramsInterface, paramsInterface); - - boolean regionIncluded = false; - for (String paramName : endpointRulesSpecUtils.parameters().keySet()) { - if (!authSchemeSpecUtils.includeParamForProvider(paramName)) { - continue; - } - regionIncluded = regionIncluded || paramName.equalsIgnoreCase("region"); - String methodName = endpointRulesSpecUtils.paramMethodName(paramName); - builder.addStatement("paramsBuilder.$1N(endpointParams.$1N())", methodName); - } - - builder.addStatement("paramsBuilder.operation(operationName)"); - - if (authSchemeSpecUtils.usesSigV4() && !regionIncluded) { - builder.addStatement("paramsBuilder.region(clientConfiguration.option($T.AWS_REGION))", awsClientOption); - } - - // Set endpoint provider on params if applicable - ClassName paramsBuilderClass = authSchemeSpecUtils.parametersEndpointAwareDefaultImplName().nestedClass("Builder"); - ClassName endpointProviderInterface = endpointRulesSpecUtils.providerInterfaceName(); - - builder.beginControlFlow("if (paramsBuilder instanceof $T)", paramsBuilderClass); - builder.addStatement("$T endpointProvider = clientConfiguration.option($T.ENDPOINT_PROVIDER)", - ClassName.get("software.amazon.awssdk.endpoints", "EndpointProvider"), - SdkClientOption.class); - builder.beginControlFlow("if (endpointProvider instanceof $T)", endpointProviderInterface); - builder.addStatement("(($T) paramsBuilder).endpointProvider(($T) endpointProvider)", - paramsBuilderClass, endpointProviderInterface); - builder.endControlFlow(); - builder.endControlFlow(); - - builder.addStatement("return authSchemeProvider.resolveAuthScheme(paramsBuilder.build())"); - } - - return builder.build(); - } } diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/ClientClassUtils.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/ClientClassUtils.java index 9d4e15b3dd86..f0f27524cf65 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/ClientClassUtils.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/ClientClassUtils.java @@ -29,6 +29,7 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.function.Consumer; import java.util.stream.Collectors; import javax.lang.model.element.Modifier; @@ -45,6 +46,7 @@ import software.amazon.awssdk.codegen.model.service.HostPrefixProcessor; import software.amazon.awssdk.codegen.poet.PoetExtension; import software.amazon.awssdk.codegen.poet.PoetUtils; +import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeSpecUtils; import software.amazon.awssdk.codegen.poet.rules.EndpointRulesSpecUtils; import software.amazon.awssdk.core.SdkPlugin; import software.amazon.awssdk.core.SdkRequest; @@ -53,6 +55,7 @@ import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.retry.RetryMode; import software.amazon.awssdk.core.signer.Signer; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.utils.AttributeMap; import software.amazon.awssdk.utils.CollectionUtils; @@ -346,4 +349,120 @@ public static MethodSpec updateRetryStrategyClientConfigurationMethod() { static String transformServiceId(String serviceId) { return serviceId.replace(" ", "_"); } + + static MethodSpec resolveAuthSchemeOptionsMethod(AuthSchemeSpecUtils authSchemeSpecUtils, + EndpointRulesSpecUtils endpointRulesSpecUtils) { + MethodSpec.Builder builder = MethodSpec.methodBuilder("resolveAuthSchemeOptions") + .addModifiers(PRIVATE) + .returns(ParameterizedTypeName.get(ClassName.get(List.class), ClassName.get(AuthSchemeOption.class))) + .addParameter(SdkRequest.class, "request") + .addParameter(String.class, "operationName") + .addParameter(SdkClientConfiguration.class, "clientConfiguration"); + + ClassName providerInterface = authSchemeSpecUtils.providerInterfaceName(); + + builder.addStatement("$T authSchemeProvider = ($T) clientConfiguration.option($T.AUTH_SCHEME_PROVIDER)", + providerInterface, providerInterface, SdkClientOption.class); + + if (authSchemeSpecUtils.useEndpointBasedAuthProvider()) { + addEndpointBasedAuthSchemeResolution(builder, authSchemeSpecUtils, endpointRulesSpecUtils); + } else { + addSimpleAuthSchemeResolution(builder, authSchemeSpecUtils); + } + + return builder.build(); + } + + private static void addSimpleAuthSchemeResolution(MethodSpec.Builder builder, + AuthSchemeSpecUtils authSchemeSpecUtils) { + ClassName paramsInterface = authSchemeSpecUtils.parametersInterfaceName(); + ClassName awsClientOption = ClassName.get("software.amazon.awssdk.awscore.client.config", "AwsClientOption"); + + builder.addStatement("$T.Builder paramsBuilder = $T.builder().operation(operationName)", + paramsInterface, paramsInterface); + + if (authSchemeSpecUtils.usesSigV4()) { + builder.addStatement("paramsBuilder.region(clientConfiguration.option($T.AWS_REGION))", awsClientOption); + } + + if (authSchemeSpecUtils.hasSigV4aSupport()) { + ClassName regionSet = ClassName.get("software.amazon.awssdk.http.auth.aws.signer", "RegionSet"); + builder.addStatement("$T sigv4aRegionSet = clientConfiguration.option($T.AWS_SIGV4A_SIGNING_REGION_SET)", + ClassName.get(Set.class), awsClientOption); + builder.beginControlFlow("if (!$T.isNullOrEmpty(sigv4aRegionSet))", CollectionUtils.class); + builder.addStatement("paramsBuilder.regionSet($T.create(sigv4aRegionSet))", regionSet); + builder.nextControlFlow("else"); + builder.addStatement("paramsBuilder.regionSet($T.create(clientConfiguration.option($T.AWS_REGION).id()))", + regionSet, awsClientOption); + builder.endControlFlow(); + } + + builder.addStatement("return authSchemeProvider.resolveAuthScheme(paramsBuilder.build())"); + } + + private static void addEndpointBasedAuthSchemeResolution(MethodSpec.Builder builder, + AuthSchemeSpecUtils authSchemeSpecUtils, + EndpointRulesSpecUtils endpointRulesSpecUtils) { + ClassName paramsInterface = authSchemeSpecUtils.parametersInterfaceName(); + ClassName awsClientOption = ClassName.get("software.amazon.awssdk.awscore.client.config", "AwsClientOption"); + ClassName endpointParamsClass = endpointRulesSpecUtils.parametersClassName(); + ClassName resolverInterceptor = endpointRulesSpecUtils.resolverInterceptorName(); + ClassName executionAttributesClass = ClassName.get("software.amazon.awssdk.core.interceptor", "ExecutionAttributes"); + ClassName awsExecutionAttribute = ClassName.get("software.amazon.awssdk.awscore", "AwsExecutionAttribute"); + ClassName sdkExecutionAttribute = ClassName.get("software.amazon.awssdk.core.interceptor", "SdkExecutionAttribute"); + ClassName sdkInternalExecutionAttribute = ClassName.get("software.amazon.awssdk.core.interceptor", + "SdkInternalExecutionAttribute"); + + builder.addStatement("$T executionAttributes = new $T()", executionAttributesClass, executionAttributesClass); + builder.addStatement("executionAttributes.putAttribute($T.AWS_REGION, clientConfiguration.option($T.AWS_REGION))", + awsExecutionAttribute, awsClientOption); + builder.addStatement("executionAttributes.putAttribute($T.DUALSTACK_ENDPOINT_ENABLED, " + + "clientConfiguration.option($T.DUALSTACK_ENDPOINT_ENABLED))", + awsExecutionAttribute, awsClientOption); + builder.addStatement("executionAttributes.putAttribute($T.FIPS_ENDPOINT_ENABLED, " + + "clientConfiguration.option($T.FIPS_ENDPOINT_ENABLED))", + awsExecutionAttribute, awsClientOption); + builder.addStatement("executionAttributes.putAttribute($T.OPERATION_NAME, operationName)", sdkExecutionAttribute); + builder.addStatement("executionAttributes.putAttribute($T.CLIENT_ENDPOINT_PROVIDER, " + + "clientConfiguration.option($T.CLIENT_ENDPOINT_PROVIDER))", + sdkInternalExecutionAttribute, SdkClientOption.class); + builder.addStatement("executionAttributes.putAttribute($T.CLIENT_CONTEXT_PARAMS, " + + "clientConfiguration.option($T.CLIENT_CONTEXT_PARAMS))", + sdkInternalExecutionAttribute, SdkClientOption.class); + + builder.addStatement("$T endpointParams = $T.ruleParams(request, executionAttributes)", + endpointParamsClass, resolverInterceptor); + + builder.addStatement("$T.Builder paramsBuilder = $T.builder()", paramsInterface, paramsInterface); + + boolean regionIncluded = false; + for (String paramName : endpointRulesSpecUtils.parameters().keySet()) { + if (!authSchemeSpecUtils.includeParamForProvider(paramName)) { + continue; + } + regionIncluded = regionIncluded || paramName.equalsIgnoreCase("region"); + String methodName = endpointRulesSpecUtils.paramMethodName(paramName); + builder.addStatement("paramsBuilder.$1N(endpointParams.$1N())", methodName); + } + + builder.addStatement("paramsBuilder.operation(operationName)"); + + if (authSchemeSpecUtils.usesSigV4() && !regionIncluded) { + builder.addStatement("paramsBuilder.region(clientConfiguration.option($T.AWS_REGION))", awsClientOption); + } + + ClassName paramsBuilderClass = authSchemeSpecUtils.parametersEndpointAwareDefaultImplName().nestedClass("Builder"); + ClassName endpointProviderInterface = endpointRulesSpecUtils.providerInterfaceName(); + + builder.beginControlFlow("if (paramsBuilder instanceof $T)", paramsBuilderClass); + builder.addStatement("$T endpointProvider = clientConfiguration.option($T.ENDPOINT_PROVIDER)", + ClassName.get("software.amazon.awssdk.endpoints", "EndpointProvider"), SdkClientOption.class); + builder.beginControlFlow("if (endpointProvider instanceof $T)", endpointProviderInterface); + builder.addStatement("(($T) paramsBuilder).endpointProvider(($T) endpointProvider)", + paramsBuilderClass, endpointProviderInterface); + builder.endControlFlow(); + builder.endControlFlow(); + + builder.addStatement("return authSchemeProvider.resolveAuthScheme(paramsBuilder.build())"); + } } diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java index 4a7f52c67b95..605c55f29458 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java @@ -142,7 +142,7 @@ protected void addAdditionalMethods(TypeSpec.Builder type) { .addMethod(nameMethod()) .addMethods(protocolSpec.additionalMethods()) .addMethod(resolveMetricPublishersMethod()) - .addMethod(resolveAuthSchemeOptionsMethod()); + .addMethod(ClientClassUtils.resolveAuthSchemeOptionsMethod(authSchemeSpecUtils, endpointRulesSpecUtils)); protocolSpec.createErrorResponseHandler().ifPresent(type::addMethod); type.addMethod(ClientClassUtils.updateRetryStrategyClientConfigurationMethod()); @@ -454,112 +454,4 @@ protected MethodSpec.Builder waiterOperationBody(MethodSpec.Builder builder) { .addStatement("return $T.builder().client(this).build()", poetExtensions.getSyncWaiterInterface()); } - - private MethodSpec resolveAuthSchemeOptionsMethod() { - MethodSpec.Builder builder = MethodSpec.methodBuilder("resolveAuthSchemeOptions") - .addModifiers(PRIVATE) - .returns(ParameterizedTypeName.get(ClassName.get(List.class), ClassName.get(AuthSchemeOption.class))) - .addParameter(SdkRequest.class, "request") - .addParameter(String.class, "operationName") - .addParameter(SdkClientConfiguration.class, "clientConfiguration"); - - ClassName providerInterface = authSchemeSpecUtils.providerInterfaceName(); - ClassName paramsInterface = authSchemeSpecUtils.parametersInterfaceName(); - ClassName awsClientOption = ClassName.get("software.amazon.awssdk.awscore.client.config", "AwsClientOption"); - - builder.addStatement("$T authSchemeProvider = ($T) clientConfiguration.option($T.AUTH_SCHEME_PROVIDER)", - providerInterface, providerInterface, SdkClientOption.class); - - if (!authSchemeSpecUtils.useEndpointBasedAuthProvider()) { - // Simple case: operation, region, and optionally regionSet - builder.addStatement("$T.Builder paramsBuilder = $T.builder().operation(operationName)", - paramsInterface, paramsInterface); - if (authSchemeSpecUtils.usesSigV4()) { - builder.addStatement("paramsBuilder.region(clientConfiguration.option($T.AWS_REGION))", awsClientOption); - } - if (authSchemeSpecUtils.hasSigV4aSupport()) { - ClassName regionSet = ClassName.get("software.amazon.awssdk.http.auth.aws.signer", "RegionSet"); - ClassName collectionUtils = ClassName.get("software.amazon.awssdk.utils", "CollectionUtils"); - builder.addStatement("$T sigv4aRegionSet = clientConfiguration.option($T.AWS_SIGV4A_SIGNING_REGION_SET)", - ClassName.get(java.util.Set.class), awsClientOption); - builder.beginControlFlow("if (!$T.isNullOrEmpty(sigv4aRegionSet))", collectionUtils); - builder.addStatement("paramsBuilder.regionSet($T.create(sigv4aRegionSet))", regionSet); - builder.nextControlFlow("else"); - builder.addStatement("paramsBuilder.regionSet($T.create(clientConfiguration.option($T.AWS_REGION).id()))", - regionSet, awsClientOption); - builder.endControlFlow(); - } - builder.addStatement("return authSchemeProvider.resolveAuthScheme(paramsBuilder.build())"); - } else { - // Endpoint-based auth: use ruleParams to build endpoint params, then copy to auth params - ClassName endpointParamsClass = endpointRulesSpecUtils.parametersClassName(); - ClassName resolverInterceptor = endpointRulesSpecUtils.resolverInterceptorName(); - ClassName executionAttributesClass = ClassName.get("software.amazon.awssdk.core.interceptor", "ExecutionAttributes"); - ClassName awsExecutionAttribute = ClassName.get("software.amazon.awssdk.awscore", "AwsExecutionAttribute"); - ClassName sdkExecutionAttribute = ClassName.get("software.amazon.awssdk.core.interceptor", "SdkExecutionAttribute"); - - // Build minimal execution attributes needed by ruleParams - builder.addStatement("$T executionAttributes = new $T()", executionAttributesClass, executionAttributesClass); - builder.addStatement("executionAttributes.putAttribute($T.AWS_REGION, " - + "clientConfiguration.option($T.AWS_REGION))", - awsExecutionAttribute, awsClientOption); - builder.addStatement("executionAttributes.putAttribute($T.DUALSTACK_ENDPOINT_ENABLED, " - + "clientConfiguration.option($T.DUALSTACK_ENDPOINT_ENABLED))", - awsExecutionAttribute, awsClientOption); - builder.addStatement("executionAttributes.putAttribute($T.FIPS_ENDPOINT_ENABLED, " - + "clientConfiguration.option($T.FIPS_ENDPOINT_ENABLED))", - awsExecutionAttribute, awsClientOption); - builder.addStatement("executionAttributes.putAttribute($T.OPERATION_NAME, operationName)", - sdkExecutionAttribute); - builder.addStatement("executionAttributes.putAttribute($T.CLIENT_ENDPOINT_PROVIDER, " - + "clientConfiguration.option($T.CLIENT_ENDPOINT_PROVIDER))", - ClassName.get("software.amazon.awssdk.core.interceptor", "SdkInternalExecutionAttribute"), - SdkClientOption.class); - builder.addStatement("executionAttributes.putAttribute($T.CLIENT_CONTEXT_PARAMS, " - + "clientConfiguration.option($T.CLIENT_CONTEXT_PARAMS))", - ClassName.get("software.amazon.awssdk.core.interceptor", "SdkInternalExecutionAttribute"), - SdkClientOption.class); - - // Use ruleParams to build endpoint params (handles context params from request) - builder.addStatement("$T endpointParams = $T.ruleParams(request, executionAttributes)", - endpointParamsClass, resolverInterceptor); - - // Build auth scheme params from endpoint params - builder.addStatement("$T.Builder paramsBuilder = $T.builder()", paramsInterface, paramsInterface); - - boolean regionIncluded = false; - for (String paramName : endpointRulesSpecUtils.parameters().keySet()) { - if (!authSchemeSpecUtils.includeParamForProvider(paramName)) { - continue; - } - regionIncluded = regionIncluded || paramName.equalsIgnoreCase("region"); - String methodName = endpointRulesSpecUtils.paramMethodName(paramName); - builder.addStatement("paramsBuilder.$1N(endpointParams.$1N())", methodName); - } - - builder.addStatement("paramsBuilder.operation(operationName)"); - - if (authSchemeSpecUtils.usesSigV4() && !regionIncluded) { - builder.addStatement("paramsBuilder.region(clientConfiguration.option($T.AWS_REGION))", awsClientOption); - } - - // Set endpoint provider on params if applicable - ClassName paramsBuilderClass = authSchemeSpecUtils.parametersEndpointAwareDefaultImplName().nestedClass("Builder"); - ClassName endpointProviderInterface = endpointRulesSpecUtils.providerInterfaceName(); - - builder.beginControlFlow("if (paramsBuilder instanceof $T)", paramsBuilderClass); - builder.addStatement("$T endpointProvider = clientConfiguration.option($T.ENDPOINT_PROVIDER)", - ClassName.get("software.amazon.awssdk.endpoints", "EndpointProvider"), - SdkClientOption.class); - builder.beginControlFlow("if (endpointProvider instanceof $T)", endpointProviderInterface); - builder.addStatement("(($T) paramsBuilder).endpointProvider(($T) endpointProvider)", - paramsBuilderClass, endpointProviderInterface); - builder.endControlFlow(); - builder.endControlFlow(); - - builder.addStatement("return authSchemeProvider.resolveAuthScheme(paramsBuilder.build())"); - } - - return builder.build(); - } } diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java index 3b1f18672dc8..ad964239ca02 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java @@ -222,8 +222,8 @@ public CodeBlock executionHandler(OperationModel opModel) { .add(".withRequestConfiguration(clientConfiguration)") .add(".withInput($L)\n", opModel.getInput().getVariableName()) .add(".withMetricCollector(apiCallMetricCollector)\n") - .add(".withAuthSchemeOptions(resolveAuthSchemeOptions($L, $S, clientConfiguration))\n", - opModel.getInput().getVariableName(), opModel.getOperationName()) + .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", + opModel.getOperationName()) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); @@ -297,8 +297,8 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper .add(".withErrorResponseHandler(errorResponseHandler)\n") .add(".withRequestConfiguration(clientConfiguration)") .add(".withMetricCollector(apiCallMetricCollector)\n") - .add(".withAuthSchemeOptions(resolveAuthSchemeOptions($L, $S, clientConfiguration))\n", - opModel.getInput().getVariableName(), opModel.getOperationName()) + .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", + opModel.getOperationName()) .add(hostPrefixExpression(opModel)) .add(discoveredEndpoint(opModel)) .add(credentialType(opModel, model)) diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/QueryProtocolSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/QueryProtocolSpec.java index deb6b7d3b775..624a4813655d 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/QueryProtocolSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/QueryProtocolSpec.java @@ -116,8 +116,8 @@ public CodeBlock executionHandler(OperationModel opModel) { .add(".withRequestConfiguration(clientConfiguration)") .add(".withInput($L)", opModel.getInput().getVariableName()) .add(".withMetricCollector(apiCallMetricCollector)") - .add(".withAuthSchemeOptions(resolveAuthSchemeOptions($L, $S, clientConfiguration))\n", - opModel.getInput().getVariableName(), opModel.getOperationName()) + .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", + opModel.getOperationName()) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); @@ -157,8 +157,8 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper .add(credentialType(opModel, intermediateModel)) .add(".withRequestConfiguration(clientConfiguration)") .add(".withMetricCollector(apiCallMetricCollector)\n") - .add(".withAuthSchemeOptions(resolveAuthSchemeOptions($L, $S, clientConfiguration))\n", - opModel.getInput().getVariableName(), opModel.getOperationName()) + .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", + opModel.getOperationName()) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/XmlProtocolSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/XmlProtocolSpec.java index 6a524ff89f0a..92e646312b2a 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/XmlProtocolSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/XmlProtocolSpec.java @@ -136,8 +136,8 @@ public CodeBlock executionHandler(OperationModel opModel) { .add(".withRequestConfiguration(clientConfiguration)") .add(".withInput($L)", opModel.getInput().getVariableName()) .add(".withMetricCollector(apiCallMetricCollector)") - .add(".withAuthSchemeOptions(resolveAuthSchemeOptions($L, $S, clientConfiguration))\n", - opModel.getInput().getVariableName(), opModel.getOperationName()) + .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", + opModel.getOperationName()) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); @@ -216,8 +216,8 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper builder.add(hostPrefixExpression(opModel)) .add(credentialType(opModel, model)) .add(".withMetricCollector(apiCallMetricCollector)\n") - .add(".withAuthSchemeOptions(resolveAuthSchemeOptions($L, $S, clientConfiguration))\n", - opModel.getInput().getVariableName(), opModel.getOperationName()) + .add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n", + opModel.getOperationName()) .add(asyncRequestBody(opModel)) .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) .add(HttpChecksumTrait.create(opModel)); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-h2-usePriorKnowledgeForH2-service-client-builder-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-h2-usePriorKnowledgeForH2-service-client-builder-class.java index d3fc4996de98..4ff83066fe22 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-h2-usePriorKnowledgeForH2-service-client-builder-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-h2-usePriorKnowledgeForH2-service-client-builder-class.java @@ -31,7 +31,6 @@ import software.amazon.awssdk.regions.ServiceMetadataAdvancedOption; import software.amazon.awssdk.retries.api.RetryStrategy; import software.amazon.awssdk.services.h2.auth.scheme.H2AuthSchemeProvider; -import software.amazon.awssdk.services.h2.auth.scheme.internal.H2AuthSchemeInterceptor; import software.amazon.awssdk.services.h2.endpoints.H2EndpointProvider; import software.amazon.awssdk.services.h2.endpoints.internal.H2RequestSetEndpointInterceptor; import software.amazon.awssdk.services.h2.endpoints.internal.H2ResolveEndpointInterceptor; @@ -70,7 +69,6 @@ protected final SdkClientConfiguration mergeServiceDefaults(SdkClientConfigurati @Override protected final SdkClientConfiguration finalizeServiceConfiguration(SdkClientConfiguration config) { List endpointInterceptors = new ArrayList<>(); - endpointInterceptors.add(new H2AuthSchemeInterceptor()); endpointInterceptors.add(new H2ResolveEndpointInterceptor()); endpointInterceptors.add(new H2RequestSetEndpointInterceptor()); ClasspathInterceptorChainFactory interceptorFactory = new ClasspathInterceptorChainFactory(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custom-context-params-async-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custom-context-params-async-client-class.java index d59409dcbcdd..67473d68951c 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custom-context-params-async-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-custom-context-params-async-client-class.java @@ -13,6 +13,7 @@ import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; @@ -29,6 +30,7 @@ import software.amazon.awssdk.core.http.HttpResponseHandler; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.core.retry.RetryMode; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; @@ -38,6 +40,8 @@ import software.amazon.awssdk.protocols.json.BaseAwsJsonProtocolFactory; import software.amazon.awssdk.protocols.json.JsonOperationMetadata; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.foobar.auth.scheme.FooBarAuthSchemeParams; +import software.amazon.awssdk.services.foobar.auth.scheme.FooBarAuthSchemeProvider; import software.amazon.awssdk.services.foobar.endpoints.FooBarClientContextParams; import software.amazon.awssdk.services.foobar.internal.FooBarServiceClientConfigurationBuilder; import software.amazon.awssdk.services.foobar.internal.ServiceVersionInfo; @@ -60,7 +64,7 @@ final class DefaultFooBarAsyncClient implements FooBarAsyncClient { private static final Logger log = LoggerFactory.getLogger(DefaultFooBarAsyncClient.class); private static final AwsProtocolMetadata protocolMetadata = AwsProtocolMetadata.builder() - .serviceProtocol(AwsServiceProtocol.REST_JSON).build(); + .serviceProtocol(AwsServiceProtocol.REST_JSON).build(); private final AsyncClientHandler clientHandler; @@ -71,7 +75,7 @@ final class DefaultFooBarAsyncClient implements FooBarAsyncClient { protected DefaultFooBarAsyncClient(SdkClientConfiguration clientConfiguration) { this.clientHandler = new AwsAsyncClientHandler(clientConfiguration); this.clientConfiguration = clientConfiguration.toBuilder().option(SdkClientOption.SDK_CLIENT, this) - .option(SdkClientOption.API_METADATA, "Foo_Bar" + "#" + ServiceVersionInfo.VERSION).build(); + .option(SdkClientOption.API_METADATA, "Foo_Bar" + "#" + ServiceVersionInfo.VERSION).build(); this.protocolFactory = init(AwsJsonProtocolFactory.builder()).build(); } @@ -100,38 +104,43 @@ protected DefaultFooBarAsyncClient(SdkClientConfiguration clientConfiguration) { @Override public CompletableFuture getDatabaseVersion(GetDatabaseVersionRequest getDatabaseVersionRequest) { SdkClientConfiguration clientConfiguration = updateSdkClientConfiguration(getDatabaseVersionRequest, - this.clientConfiguration); + this.clientConfiguration); List metricPublishers = resolveMetricPublishers(clientConfiguration, getDatabaseVersionRequest - .overrideConfiguration().orElse(null)); + .overrideConfiguration().orElse(null)); MetricCollector apiCallMetricCollector = metricPublishers.isEmpty() ? NoOpMetricCollector.create() : MetricCollector - .create("ApiCall"); + .create("ApiCall"); try { apiCallMetricCollector.reportMetric(CoreMetric.SERVICE_ID, "Foo Bar"); apiCallMetricCollector.reportMetric(CoreMetric.OPERATION_NAME, "GetDatabaseVersion"); JsonOperationMetadata operationMetadata = JsonOperationMetadata.builder().hasStreamingSuccessResponse(false) - .isPayloadJson(true).build(); + .isPayloadJson(true).build(); HttpResponseHandler responseHandler = protocolFactory.createResponseHandler( - operationMetadata, GetDatabaseVersionResponse::builder); + operationMetadata, GetDatabaseVersionResponse::builder); Function> exceptionMetadataMapper = errorCode -> { if (errorCode == null) { return Optional.empty(); } switch (errorCode) { - default: - return Optional.empty(); + default: + return Optional.empty(); } }; HttpResponseHandler errorResponseHandler = createErrorResponseHandler(protocolFactory, - operationMetadata, exceptionMetadataMapper); + operationMetadata, exceptionMetadataMapper); CompletableFuture executeFuture = clientHandler - .execute(new ClientExecutionParams() - .withOperationName("GetDatabaseVersion").withProtocolMetadata(protocolMetadata) - .withMarshaller(new GetDatabaseVersionRequestMarshaller(protocolFactory)) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withMetricCollector(apiCallMetricCollector) - .withInput(getDatabaseVersionRequest)); + .execute(new ClientExecutionParams() + .withOperationName("GetDatabaseVersion") + .withProtocolMetadata(protocolMetadata) + .withMarshaller(new GetDatabaseVersionRequestMarshaller(protocolFactory)) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "GetDatabaseVersion", clientConfiguration)) + .withInput(getDatabaseVersionRequest)); CompletableFuture whenCompleted = executeFuture.whenComplete((r, e) -> { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); }); @@ -155,11 +164,11 @@ public final String serviceName() { private > T init(T builder) { return builder.clientConfiguration(clientConfiguration).defaultServiceExceptionSupplier(FooBarException::builder) - .protocol(AwsJsonProtocol.REST_JSON).protocolVersion("1.1"); + .protocol(AwsJsonProtocol.REST_JSON).protocolVersion("1.1"); } private static List resolveMetricPublishers(SdkClientConfiguration clientConfiguration, - RequestOverrideConfiguration requestOverrideConfiguration) { + RequestOverrideConfiguration requestOverrideConfiguration) { List publishers = null; if (requestOverrideConfiguration != null) { publishers = requestOverrideConfiguration.metricPublishers(); @@ -173,6 +182,15 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + FooBarAuthSchemeProvider authSchemeProvider = (FooBarAuthSchemeProvider) clientConfiguration + .option(SdkClientOption.AUTH_SCHEME_PROVIDER); + FooBarAuthSchemeParams.Builder paramsBuilder = FooBarAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + return authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); @@ -211,15 +229,15 @@ private SdkClientConfiguration updateSdkClientConfiguration(SdkRequest request, newContextParams = (newContextParams != null) ? newContextParams : AttributeMap.empty(); originalContextParams = originalContextParams != null ? originalContextParams : AttributeMap.empty(); Validate.validState( - Objects.equals(originalContextParams.get(FooBarClientContextParams.CROSS_REGION_ACCESS_ENABLED), - newContextParams.get(FooBarClientContextParams.CROSS_REGION_ACCESS_ENABLED)), - "CROSS_REGION_ACCESS_ENABLED cannot be modified by request level plugins"); + Objects.equals(originalContextParams.get(FooBarClientContextParams.CROSS_REGION_ACCESS_ENABLED), + newContextParams.get(FooBarClientContextParams.CROSS_REGION_ACCESS_ENABLED)), + "CROSS_REGION_ACCESS_ENABLED cannot be modified by request level plugins"); updateRetryStrategyClientConfiguration(configuration); return configuration.build(); } private HttpResponseHandler createErrorResponseHandler(BaseAwsJsonProtocolFactory protocolFactory, - JsonOperationMetadata operationMetadata, Function> exceptionMetadataMapper) { + JsonOperationMetadata operationMetadata, Function> exceptionMetadataMapper) { return protocolFactory.createErrorResponseHandler(operationMetadata, exceptionMetadataMapper); } diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-query-client-class.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-query-client-class.java index 3c168823a547..766210e8c6f4 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-query-client-class.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-query-client-class.java @@ -5,6 +5,7 @@ import java.util.function.Consumer; import software.amazon.awssdk.annotations.Generated; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.client.config.AwsClientOption; import software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.awscore.internal.AwsProtocolMetadata; @@ -32,12 +33,15 @@ import software.amazon.awssdk.core.runtime.transform.StreamingRequestMarshaller; import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.core.sync.ResponseTransformer; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; import software.amazon.awssdk.metrics.NoOpMetricCollector; import software.amazon.awssdk.protocols.core.ExceptionMetadata; import software.amazon.awssdk.protocols.query.AwsQueryProtocolFactory; import software.amazon.awssdk.retries.api.RetryStrategy; +import software.amazon.awssdk.services.query.auth.scheme.QueryAuthSchemeParams; +import software.amazon.awssdk.services.query.auth.scheme.QueryAuthSchemeProvider; import software.amazon.awssdk.services.query.internal.QueryServiceClientConfigurationBuilder; import software.amazon.awssdk.services.query.internal.ServiceVersionInfo; import software.amazon.awssdk.services.query.model.APostOperationRequest; @@ -163,6 +167,7 @@ public APostOperationResponse aPostOperation(APostOperationRequest aPostOperatio .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .hostPrefixExpression(resolvedHostExpression).withRequestConfiguration(clientConfiguration) .withInput(aPostOperationRequest).withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "APostOperation", clientConfiguration)) .withMarshaller(new APostOperationRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -210,10 +215,15 @@ public APostOperationWithOutputResponse aPostOperationWithOutput( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("APostOperationWithOutput").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(aPostOperationWithOutputRequest) + .withOperationName("APostOperationWithOutput") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(aPostOperationWithOutputRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "APostOperationWithOutput", clientConfiguration)) .withMarshaller(new APostOperationWithOutputRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -259,6 +269,7 @@ public BearerAuthOperationResponse bearerAuthOperation(BearerAuthOperationReques .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) .credentialType(CredentialType.TOKEN).withRequestConfiguration(clientConfiguration) .withInput(bearerAuthOperationRequest).withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, "BearerAuthOperation", clientConfiguration)) .withMarshaller(new BearerAuthOperationRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -309,6 +320,8 @@ public GetOperationWithChecksumResponse getOperationWithChecksum( .withRequestConfiguration(clientConfiguration) .withInput(getOperationWithChecksumRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "GetOperationWithChecksum", clientConfiguration)) .putExecutionAttribute( SdkInternalExecutionAttribute.HTTP_CHECKSUM, HttpChecksum.builder().requestChecksumRequired(true).isRequestStreaming(false) @@ -364,6 +377,8 @@ public OperationWithChecksumRequiredResponse operationWithChecksumRequired( .withRequestConfiguration(clientConfiguration) .withInput(operationWithChecksumRequiredRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithChecksumRequired", clientConfiguration)) .putExecutionAttribute(SdkInternalExecutionAttribute.HTTP_CHECKSUM_REQUIRED, HttpChecksumRequired.create()) .withMarshaller(new OperationWithChecksumRequiredRequestMarshaller(protocolFactory))); @@ -409,10 +424,15 @@ public OperationWithContextParamResponse operationWithContextParam( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithContextParam").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(operationWithContextParamRequest) + .withOperationName("OperationWithContextParam") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(operationWithContextParamRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithContextParam", clientConfiguration)) .withMarshaller(new OperationWithContextParamRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -457,10 +477,15 @@ public OperationWithCustomMemberResponse operationWithCustomMember( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithCustomMember").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(operationWithCustomMemberRequest) + .withOperationName("OperationWithCustomMember") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(operationWithCustomMemberRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithCustomMember", clientConfiguration)) .withMarshaller(new OperationWithCustomMemberRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -506,10 +531,15 @@ public OperationWithCustomizedOperationContextParamResponse operationWithCustomi return clientHandler .execute(new ClientExecutionParams() .withOperationName("OperationWithCustomizedOperationContextParam") - .withProtocolMetadata(protocolMetadata).withResponseHandler(responseHandler) - .withErrorResponseHandler(errorResponseHandler).withRequestConfiguration(clientConfiguration) + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) .withInput(operationWithCustomizedOperationContextParamRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithCustomizedOperationContextParam", + clientConfiguration)) .withMarshaller(new OperationWithCustomizedOperationContextParamRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -554,10 +584,15 @@ public OperationWithMapOperationContextParamResponse operationWithMapOperationCo return clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithMapOperationContextParam").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) + .withOperationName("OperationWithMapOperationContextParam") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) .withRequestConfiguration(clientConfiguration) - .withInput(operationWithMapOperationContextParamRequest).withMetricCollector(apiCallMetricCollector) + .withInput(operationWithMapOperationContextParamRequest) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithMapOperationContextParam", clientConfiguration)) .withMarshaller(new OperationWithMapOperationContextParamRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -601,10 +636,15 @@ public OperationWithNoneAuthTypeResponse operationWithNoneAuthType( return clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithNoneAuthType").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(operationWithNoneAuthTypeRequest) + .withOperationName("OperationWithNoneAuthType") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(operationWithNoneAuthTypeRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithNoneAuthType", clientConfiguration)) .withMarshaller(new OperationWithNoneAuthTypeRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -649,10 +689,15 @@ public OperationWithOperationContextParamResponse operationWithOperationContextP return clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithOperationContextParam").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(operationWithOperationContextParamRequest) + .withOperationName("OperationWithOperationContextParam") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(operationWithOperationContextParamRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithOperationContextParam", clientConfiguration)) .withMarshaller(new OperationWithOperationContextParamRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -703,6 +748,8 @@ public OperationWithRequestCompressionResponse operationWithRequestCompression( .withRequestConfiguration(clientConfiguration) .withInput(operationWithRequestCompressionRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithRequestCompression", clientConfiguration)) .putExecutionAttribute(SdkInternalExecutionAttribute.REQUEST_COMPRESSION, RequestCompression.builder().encodings("gzip").isStreaming(false).build()) .withMarshaller(new OperationWithRequestCompressionRequestMarshaller(protocolFactory))); @@ -748,10 +795,15 @@ public OperationWithStaticContextParamsResponse operationWithStaticContextParams return clientHandler .execute(new ClientExecutionParams() - .withOperationName("OperationWithStaticContextParams").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(operationWithStaticContextParamsRequest) + .withOperationName("OperationWithStaticContextParams") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(operationWithStaticContextParamsRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "OperationWithStaticContextParams", clientConfiguration)) .withMarshaller(new OperationWithStaticContextParamsRequestMarshaller(protocolFactory))); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -828,6 +880,8 @@ public ReturnT putOperationWithChecksum(PutOperationWithChecksumReques .withRequestConfiguration(clientConfiguration) .withInput(putOperationWithChecksumRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "PutOperationWithChecksum", clientConfiguration)) .putExecutionAttribute( SdkInternalExecutionAttribute.HTTP_CHECKSUM, HttpChecksum @@ -906,6 +960,8 @@ public StreamingInputOperationResponse streamingInputOperation(StreamingInputOpe .withRequestConfiguration(clientConfiguration) .withInput(streamingInputOperationRequest) .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingInputOperation", clientConfiguration)) .withRequestBody(requestBody) .withMarshaller( StreamingRequestMarshaller.builder() @@ -960,10 +1016,16 @@ public ReturnT streamingOutputOperation(StreamingOutputOperationReques return clientHandler.execute( new ClientExecutionParams() - .withOperationName("StreamingOutputOperation").withProtocolMetadata(protocolMetadata) - .withResponseHandler(responseHandler).withErrorResponseHandler(errorResponseHandler) - .withRequestConfiguration(clientConfiguration).withInput(streamingOutputOperationRequest) - .withMetricCollector(apiCallMetricCollector).withResponseTransformer(responseTransformer) + .withOperationName("StreamingOutputOperation") + .withProtocolMetadata(protocolMetadata) + .withResponseHandler(responseHandler) + .withErrorResponseHandler(errorResponseHandler) + .withRequestConfiguration(clientConfiguration) + .withInput(streamingOutputOperationRequest) + .withMetricCollector(apiCallMetricCollector) + .withAuthSchemeOptionsResolver( + r -> resolveAuthSchemeOptions(r, "StreamingOutputOperation", clientConfiguration)) + .withResponseTransformer(responseTransformer) .withMarshaller(new StreamingOutputOperationRequestMarshaller(protocolFactory)), responseTransformer); } finally { metricPublishers.forEach(p -> p.publish(apiCallMetricCollector.collect())); @@ -1003,6 +1065,15 @@ private static List resolveMetricPublishers(SdkClientConfigurat return publishers; } + private List resolveAuthSchemeOptions(SdkRequest request, String operationName, + SdkClientConfiguration clientConfiguration) { + QueryAuthSchemeProvider authSchemeProvider = (QueryAuthSchemeProvider) clientConfiguration + .option(SdkClientOption.AUTH_SCHEME_PROVIDER); + QueryAuthSchemeParams.Builder paramsBuilder = QueryAuthSchemeParams.builder().operation(operationName); + paramsBuilder.region(clientConfiguration.option(AwsClientOption.AWS_REGION)); + return authSchemeProvider.resolveAuthScheme(paramsBuilder.build()); + } + private void updateRetryStrategyClientConfiguration(SdkClientConfiguration.Builder configuration) { ClientOverrideConfiguration.Builder builder = configuration.asOverrideConfigurationBuilder(); RetryMode retryMode = builder.retryMode(); diff --git a/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilder.java b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilder.java index 58c27ad5d3b7..24a9fc09157e 100644 --- a/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilder.java +++ b/core/aws-core/src/main/java/software/amazon/awssdk/awscore/internal/AwsExecutionContextBuilder.java @@ -146,12 +146,12 @@ private AwsExecutionContextBuilder() { // Auth Scheme resolution related attributes putAuthSchemeResolutionAttributes(executionAttributes, clientConfig, originalRequest); - // Set auth scheme options from ClientExecutionParams - if (executionParams.authSchemeOptions() != null) { - executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS, - executionParams.authSchemeOptions()); + if (executionParams.authSchemeOptionsResolver() != null) { + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, + executionParams.authSchemeOptionsResolver()); - recordAuthSchemeBusinessMetrics(executionParams.authSchemeOptions(), executionAttributes, originalRequest); + List authOptions = executionParams.authSchemeOptionsResolver().resolve(originalRequest); + recordAuthSchemeBusinessMetrics(authOptions, executionAttributes, originalRequest); } // Set the identity provider updater for the pipeline stage to use diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/client/handler/ClientExecutionParams.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/client/handler/ClientExecutionParams.java index ad95335ef5ce..66b4d9e7bf0f 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/client/handler/ClientExecutionParams.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/client/handler/ClientExecutionParams.java @@ -16,7 +16,6 @@ package software.amazon.awssdk.core.client.handler; import java.net.URI; -import java.util.List; import software.amazon.awssdk.annotations.NotThreadSafe; import software.amazon.awssdk.annotations.SdkProtectedApi; import software.amazon.awssdk.core.CredentialType; @@ -31,9 +30,9 @@ import software.amazon.awssdk.core.interceptor.ExecutionAttribute; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.runtime.transform.Marshaller; +import software.amazon.awssdk.core.spi.identity.AuthSchemeOptionsResolver; import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.core.sync.ResponseTransformer; -import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.metrics.MetricCollector; /** @@ -65,7 +64,7 @@ public final class ClientExecutionParams { private MetricCollector metricCollector; private final ExecutionAttributes attributes = new ExecutionAttributes(); private SdkClientConfiguration requestConfiguration; - private List authSchemeOptions; + private AuthSchemeOptionsResolver authSchemeOptionsResolver; public Marshaller getMarshaller() { return marshaller; @@ -265,12 +264,13 @@ public ClientExecutionParams withRequestConfiguration(SdkCl return this; } - public List authSchemeOptions() { - return authSchemeOptions; + public AuthSchemeOptionsResolver authSchemeOptionsResolver() { + return authSchemeOptionsResolver; } - public ClientExecutionParams withAuthSchemeOptions(List authSchemeOptions) { - this.authSchemeOptions = authSchemeOptions; + public ClientExecutionParams withAuthSchemeOptionsResolver( + AuthSchemeOptionsResolver authSchemeOptionsResolver) { + this.authSchemeOptionsResolver = authSchemeOptionsResolver; return this; } } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java index f4858b27def1..8c73ce3b70ef 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/interceptor/SdkInternalExecutionAttribute.java @@ -29,6 +29,7 @@ import software.amazon.awssdk.core.interceptor.trait.HttpChecksum; import software.amazon.awssdk.core.interceptor.trait.HttpChecksumRequired; import software.amazon.awssdk.core.internal.interceptor.trait.RequestCompression; +import software.amazon.awssdk.core.spi.identity.AuthSchemeOptionsResolver; import software.amazon.awssdk.core.spi.identity.IdentityProviderUpdater; import software.amazon.awssdk.core.useragent.AdditionalMetadata; import software.amazon.awssdk.core.useragent.BusinessMetricCollection; @@ -36,7 +37,6 @@ import software.amazon.awssdk.endpoints.EndpointProvider; import software.amazon.awssdk.http.SdkHttpExecutionAttributes; import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; -import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeProvider; import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; import software.amazon.awssdk.http.auth.spi.signer.PayloadChecksumStore; @@ -176,11 +176,11 @@ public final class SdkInternalExecutionAttribute extends SdkExecutionAttribute { new ExecutionAttribute<>("IdentityProviderUpdater"); /** - * The resolved auth scheme options for a request. These are resolved by the auth scheme provider - * but identity resolution is deferred to the AuthSchemeResolutionStage. + * Callback to resolve auth scheme options from the (possibly modified) request. + * Called by AuthSchemeResolutionStage after interceptors have run. */ - public static final ExecutionAttribute> AUTH_SCHEME_OPTIONS = - new ExecutionAttribute<>("AuthSchemeOptions"); + public static final ExecutionAttribute AUTH_SCHEME_OPTIONS_RESOLVER = + new ExecutionAttribute<>("AuthSchemeOptionsResolver"); /** * The selected auth scheme for a request. diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStage.java index 81a1437f1ff4..fed726670b99 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStage.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStage.java @@ -18,6 +18,7 @@ import java.util.List; import java.util.Map; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.SdkRequest; import software.amazon.awssdk.core.SelectedAuthScheme; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; @@ -26,6 +27,7 @@ import software.amazon.awssdk.core.internal.http.RequestExecutionContext; import software.amazon.awssdk.core.internal.http.auth.AuthSchemeResolver; import software.amazon.awssdk.core.internal.http.pipeline.RequestToRequestPipeline; +import software.amazon.awssdk.core.spi.identity.AuthSchemeOptionsResolver; import software.amazon.awssdk.core.spi.identity.IdentityProviderUpdater; import software.amazon.awssdk.http.SdkHttpFullRequest; import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; @@ -52,8 +54,8 @@ public SdkHttpFullRequest execute(SdkHttpFullRequest request, RequestExecutionCo return request; } - List authOptions = - executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS); + SdkRequest sdkRequest = context.executionContext().interceptorContext().request(); + List authOptions = resolveAuthSchemeOptions(executionAttributes, sdkRequest); if (authOptions == null || authOptions.isEmpty()) { return request; } @@ -64,9 +66,7 @@ public SdkHttpFullRequest execute(SdkHttpFullRequest request, RequestExecutionCo IdentityProviderUpdater updater = executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDER_UPDATER); if (updater != null) { - identityProviders = updater.update( - context.executionContext().interceptorContext().request(), - identityProviders); + identityProviders = updater.update(sdkRequest, identityProviders); } MetricCollector metricCollector = @@ -81,4 +81,14 @@ public SdkHttpFullRequest execute(SdkHttpFullRequest request, RequestExecutionCo return request; } + + private List resolveAuthSchemeOptions(ExecutionAttributes executionAttributes, SdkRequest request) { + AuthSchemeOptionsResolver resolver = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER); + + if (resolver == null) { + return null; + } + return resolver.resolve(request); + } } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/AuthSchemeOptionsResolver.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/AuthSchemeOptionsResolver.java new file mode 100644 index 000000000000..7a37d7c3dd6b --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/spi/identity/AuthSchemeOptionsResolver.java @@ -0,0 +1,39 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.spi.identity; + +import java.util.List; +import software.amazon.awssdk.annotations.SdkProtectedApi; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; + +/** + * Callback interface for resolving auth scheme options from the request. + *

+ * This allows auth scheme resolution to happen after interceptors have modified the request, + * ensuring that any request modifications affecting auth scheme selection are respected. + */ +@FunctionalInterface +@SdkProtectedApi +public interface AuthSchemeOptionsResolver { + /** + * Resolves auth scheme options for the given request. + * + * @param request The request (after interceptors have modified it) + * @return List of auth scheme options in priority order + */ + List resolve(SdkRequest request); +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java index b21644c7d82f..cbc9497f93f7 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/signing/DefaultS3Presigner.java @@ -33,8 +33,6 @@ import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.function.Function; -import java.util.function.Supplier; -import java.util.stream.Collectors; import java.util.stream.Stream; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.auth.signer.AwsSignerExecutionAttribute; @@ -45,6 +43,7 @@ import software.amazon.awssdk.awscore.endpoint.AwsClientEndpointProvider; import software.amazon.awssdk.awscore.internal.AwsExecutionContextBuilder; import software.amazon.awssdk.awscore.internal.defaultsmode.DefaultsModeConfiguration; +import software.amazon.awssdk.awscore.internal.identity.AwsIdentityProviderUpdater; import software.amazon.awssdk.awscore.presigner.PresignRequest; import software.amazon.awssdk.awscore.presigner.PresignedRequest; import software.amazon.awssdk.core.ClientType; @@ -244,7 +243,6 @@ private List initializeInterceptors() { List s3Interceptors = interceptorFactory.getInterceptors("software/amazon/awssdk/services/s3/execution.interceptors"); List additionalInterceptors = new ArrayList<>(); - // S3AuthSchemeInterceptor removed - auth scheme resolution now done inline additionalInterceptors.add(new S3ResolveEndpointInterceptor()); additionalInterceptors.add(new S3RequestSetEndpointInterceptor()); s3Interceptors = mergeLists(s3Interceptors, additionalInterceptors); @@ -618,7 +616,7 @@ private void resolveAndSelectAuthScheme(ExecutionContext execCtx, SdkRequest req } IdentityProviders identityProviders = executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS); - identityProviders = applyCredentialOverrides(request, identityProviders); + identityProviders = AwsIdentityProviderUpdater.INSTANCE.update(request, identityProviders); SelectedAuthScheme selectedAuthScheme = AuthSchemeResolver.selectAuthScheme(authOptions, authSchemes, identityProviders, null); @@ -676,23 +674,6 @@ private List resolveAuthSchemeOptions(SdkRequest request, Stri return authSchemeProvider.resolveAuthScheme(authParamsBuilder.build()); } - /** - * Apply credential overrides from request configuration. - */ - private IdentityProviders applyCredentialOverrides(SdkRequest request, IdentityProviders base) { - if (base == null) { - return null; - } - return request.overrideConfiguration() - .filter(c -> c instanceof AwsRequestOverrideConfiguration) - .map(c -> (AwsRequestOverrideConfiguration) c) - .map(c -> base.copy(b -> { - c.credentialsIdentityProvider().ifPresent(b::putIdentityProvider); - c.tokenIdentityProvider().ifPresent(b::putIdentityProvider); - })) - .orElse(base); - } - /** * Presign the provided HTTP request using old Signer */ From eb989a402652dccbdebe7272803e0d237efb4a55 Mon Sep 17 00:00:00 2001 From: Saranya Somepalli Date: Mon, 2 Mar 2026 14:34:04 -0800 Subject: [PATCH 5/6] Fix checkstyle --- .../amazon/awssdk/codegen/poet/client/AsyncClientClass.java | 2 -- .../amazon/awssdk/codegen/poet/client/SyncClientClass.java | 2 -- 2 files changed, 4 deletions(-) diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java index b3bb15e9d257..fc31f3910476 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java @@ -76,7 +76,6 @@ import software.amazon.awssdk.codegen.poet.model.ServiceClientConfigurationUtils; import software.amazon.awssdk.codegen.poet.rules.EndpointRulesSpecUtils; import software.amazon.awssdk.core.RequestOverrideConfiguration; -import software.amazon.awssdk.core.SdkRequest; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.core.async.AsyncResponseTransformerUtils; import software.amazon.awssdk.core.async.SdkPublisher; @@ -87,7 +86,6 @@ import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRefreshCache; import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRequest; import software.amazon.awssdk.core.metrics.CoreMetric; -import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java index 605c55f29458..6d9bc4e0ef4f 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/SyncClientClass.java @@ -63,14 +63,12 @@ import software.amazon.awssdk.codegen.poet.model.ServiceClientConfigurationUtils; import software.amazon.awssdk.codegen.poet.rules.EndpointRulesSpecUtils; import software.amazon.awssdk.core.RequestOverrideConfiguration; -import software.amazon.awssdk.core.SdkRequest; import software.amazon.awssdk.core.client.config.SdkClientConfiguration; import software.amazon.awssdk.core.client.config.SdkClientOption; import software.amazon.awssdk.core.client.handler.SyncClientHandler; import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRefreshCache; import software.amazon.awssdk.core.endpointdiscovery.EndpointDiscoveryRequest; import software.amazon.awssdk.core.metrics.CoreMetric; -import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; import software.amazon.awssdk.metrics.MetricCollector; import software.amazon.awssdk.metrics.MetricPublisher; From c3d031cf0c74c148e6850579b7c4b6697271bb86 Mon Sep 17 00:00:00 2001 From: Saranya Somepalli Date: Mon, 2 Mar 2026 15:40:20 -0800 Subject: [PATCH 6/6] Add tests --- .../stages/AuthSchemeResolutionStageTest.java | 231 ++++++++++++++++++ 1 file changed, 231 insertions(+) create mode 100644 core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStageTest.java diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStageTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStageTest.java new file mode 100644 index 000000000000..c8307eb7fda1 --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/AuthSchemeResolutionStageTest.java @@ -0,0 +1,231 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.http.pipeline.stages; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; +import software.amazon.awssdk.core.http.ExecutionContext; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.InterceptorContext; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.internal.http.RequestExecutionContext; +import software.amazon.awssdk.core.spi.identity.AuthSchemeOptionsResolver; +import software.amazon.awssdk.core.spi.identity.IdentityProviderUpdater; +import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; +import software.amazon.awssdk.identity.spi.Identity; +import software.amazon.awssdk.identity.spi.IdentityProvider; +import software.amazon.awssdk.identity.spi.IdentityProviders; +import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; + +class AuthSchemeResolutionStageTest { + + private static final String SCHEME_ID = "test.scheme"; + + private AuthSchemeResolutionStage stage; + private SdkHttpFullRequest httpRequest; + private RequestExecutionContext context; + private ExecutionAttributes executionAttributes; + private SdkRequest sdkRequest; + + @BeforeEach + void setup() { + stage = new AuthSchemeResolutionStage(null); + httpRequest = mock(SdkHttpFullRequest.class); + sdkRequest = mock(SdkRequest.class); + executionAttributes = new ExecutionAttributes(); + + InterceptorContext interceptorContext = InterceptorContext.builder() + .request(sdkRequest) + .build(); + ExecutionContext executionContext = ExecutionContext.builder() + .interceptorContext(interceptorContext) + .executionAttributes(executionAttributes) + .build(); + context = RequestExecutionContext.builder() + .executionContext(executionContext) + .originalRequest(sdkRequest) + .build(); + } + + @Test + void execute_noAuthSchemes_returnsRequestUnchanged() throws Exception { + // AUTH_SCHEMES is null + SdkHttpFullRequest result = stage.execute(httpRequest, context); + + assertThat(result).isSameAs(httpRequest); + assertThat(executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME)).isNull(); + } + + @Test + void execute_noResolver_returnsRequestUnchanged() throws Exception { + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, createAuthSchemes()); + // AUTH_SCHEME_OPTIONS_RESOLVER is null + + SdkHttpFullRequest result = stage.execute(httpRequest, context); + + assertThat(result).isSameAs(httpRequest); + assertThat(executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME)).isNull(); + } + + @Test + void execute_resolverReturnsEmpty_returnsRequestUnchanged() throws Exception { + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, createAuthSchemes()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, + (AuthSchemeOptionsResolver) req -> Collections.emptyList()); + + SdkHttpFullRequest result = stage.execute(httpRequest, context); + + assertThat(result).isSameAs(httpRequest); + assertThat(executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME)).isNull(); + } + + @Test + void execute_resolverReceivesRequestFromInterceptorContext() throws Exception { + SdkRequest modifiedRequest = mock(SdkRequest.class); + + // Setup interceptor context with a DIFFERENT request than originalRequest + InterceptorContext interceptorContext = InterceptorContext.builder() + .request(modifiedRequest) + .build(); + ExecutionContext executionContext = ExecutionContext.builder() + .interceptorContext(interceptorContext) + .executionAttributes(executionAttributes) + .build(); + context = RequestExecutionContext.builder() + .executionContext(executionContext) + .originalRequest(sdkRequest) // Different from modifiedRequest + .build(); + + AuthSchemeOptionsResolver resolver = mock(AuthSchemeOptionsResolver.class); + doReturn(createAuthOptions()).when(resolver).resolve(modifiedRequest); + + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, createAuthSchemes()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, resolver); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, createIdentityProviders()); + + stage.execute(httpRequest, context); + + // Verify resolver was called with the MODIFIED request, not originalRequest + verify(resolver).resolve(modifiedRequest); + } + + @Test + void execute_withIdentityProviderUpdater_callsUpdaterWithRequest() throws Exception { + // Create mocks first before any stubbing + IdentityProvider identityProvider = createMockIdentityProvider(); + Map> authSchemes = createAuthSchemes(); + IdentityProviders baseProviders = mock(IdentityProviders.class); + IdentityProviders updatedProviders = mock(IdentityProviders.class); + + IdentityProviderUpdater updater = mock(IdentityProviderUpdater.class); + doReturn(updatedProviders).when(updater).update(sdkRequest, baseProviders); + + // Setup so that auth scheme uses the updated providers + @SuppressWarnings("unchecked") + AuthScheme scheme = (AuthScheme) authSchemes.get(SCHEME_ID); + doReturn(identityProvider).when(scheme).identityProvider(updatedProviders); + + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, authSchemes); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, + (AuthSchemeOptionsResolver) req -> createAuthOptions()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, baseProviders); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDER_UPDATER, updater); + + stage.execute(httpRequest, context); + + verify(updater).update(sdkRequest, baseProviders); + } + + @Test + void execute_withoutIdentityProviderUpdater_doesNotFail() throws Exception { + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, createAuthSchemes()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, + (AuthSchemeOptionsResolver) req -> createAuthOptions()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, createIdentityProviders()); + // No IDENTITY_PROVIDER_UPDATER set + + SdkHttpFullRequest result = stage.execute(httpRequest, context); + + assertThat(result).isSameAs(httpRequest); + assertThat(executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME)).isNotNull(); + } + + @Test + void execute_happyPath_setsSelectedAuthScheme() throws Exception { + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, createAuthSchemes()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_OPTIONS_RESOLVER, + (AuthSchemeOptionsResolver) req -> createAuthOptions()); + executionAttributes.putAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS, createIdentityProviders()); + + SdkHttpFullRequest result = stage.execute(httpRequest, context); + + assertThat(result).isSameAs(httpRequest); + SelectedAuthScheme selectedAuthScheme = + executionAttributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + assertThat(selectedAuthScheme).isNotNull(); + assertThat(selectedAuthScheme.authSchemeOption().schemeId()).isEqualTo(SCHEME_ID); + } + + @SuppressWarnings("unchecked") + private Map> createAuthSchemes() { + IdentityProvider identityProvider = createMockIdentityProvider(); + HttpSigner signer = mock(HttpSigner.class); + + AuthScheme scheme = mock(AuthScheme.class); + doReturn(identityProvider).when(scheme).identityProvider(any()); + doReturn(signer).when(scheme).signer(); + + Map> schemes = new HashMap<>(); + schemes.put(SCHEME_ID, scheme); + return schemes; + } + + @SuppressWarnings("unchecked") + private IdentityProvider createMockIdentityProvider() { + IdentityProvider provider = mock(IdentityProvider.class); + Identity mockIdentity = mock(Identity.class); + doReturn(CompletableFuture.completedFuture(mockIdentity)) + .when(provider).resolveIdentity(any(ResolveIdentityRequest.class)); + return provider; + } + + private IdentityProviders createIdentityProviders() { + IdentityProviders providers = mock(IdentityProviders.class); + return providers; + } + + private List createAuthOptions() { + return Collections.singletonList( + AuthSchemeOption.builder().schemeId(SCHEME_ID).build() + ); + } +}