Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ protected List<GeneratorTask> createTasks() {
tasks.add(generateDefaultParamsImpl());
tasks.add(generateModelBasedProvider());
tasks.add(generatePreferenceProvider());
tasks.add(generateAuthSchemeInterceptor());
if (authSchemeSpecUtils.useEndpointBasedAuthProvider()) {
tasks.add(generateEndpointBasedProvider());
tasks.add(generateEndpointAwareAuthSchemeParams());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,6 @@ private MethodSpec finalizeServiceConfigurationMethod() {

List<ClassName> builtInInterceptors = new ArrayList<>();

builtInInterceptors.add(authSchemeSpecUtils.authSchemeInterceptor());
builtInInterceptors.add(endpointRulesSpecUtils.resolverInterceptorName());
builtInInterceptors.add(endpointRulesSpecUtils.requestModifierInterceptorName());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,12 @@
import software.amazon.awssdk.codegen.poet.PoetExtension;
import software.amazon.awssdk.codegen.poet.PoetUtils;
import software.amazon.awssdk.codegen.poet.StaticImport;
import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeSpecUtils;
import software.amazon.awssdk.codegen.poet.client.specs.ProtocolSpec;
import software.amazon.awssdk.codegen.poet.eventstream.EventStreamUtils;
import software.amazon.awssdk.codegen.poet.model.EventStreamSpecHelper;
import software.amazon.awssdk.codegen.poet.model.ServiceClientConfigurationUtils;
import software.amazon.awssdk.codegen.poet.rules.EndpointRulesSpecUtils;
import software.amazon.awssdk.core.RequestOverrideConfiguration;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
import software.amazon.awssdk.core.async.AsyncResponseTransformerUtils;
Expand Down Expand Up @@ -100,6 +102,8 @@ public final class AsyncClientClass extends AsyncClientInterface {
private final ProtocolSpec protocolSpec;
private final ClassName serviceClientConfigurationClassName;
private final ServiceClientConfigurationUtils configurationUtils;
private final AuthSchemeSpecUtils authSchemeSpecUtils;
private final EndpointRulesSpecUtils endpointRulesSpecUtils;
private boolean hasScheduledExecutor;

public AsyncClientClass(GeneratorTaskParams dependencies) {
Expand All @@ -110,6 +114,8 @@ public AsyncClientClass(GeneratorTaskParams dependencies) {
this.protocolSpec = getProtocolSpecs(poetExtensions, model);
this.serviceClientConfigurationClassName = new PoetExtension(model).getServiceConfigClass();
this.configurationUtils = new ServiceClientConfigurationUtils(model);
this.authSchemeSpecUtils = new AuthSchemeSpecUtils(model);
this.endpointRulesSpecUtils = new EndpointRulesSpecUtils(model);
}

@Override
Expand Down Expand Up @@ -165,7 +171,8 @@ protected void addAdditionalMethods(TypeSpec.Builder type) {
.addMethod(nameMethod())
.addMethods(protocolSpec.additionalMethods())
.addMethod(protocolSpec.initProtocolFactory(model))
.addMethod(resolveMetricPublishersMethod());
.addMethod(resolveMetricPublishersMethod())
.addMethod(ClientClassUtils.resolveAuthSchemeOptionsMethod(authSchemeSpecUtils, endpointRulesSpecUtils));

type.addMethod(ClientClassUtils.updateRetryStrategyClientConfigurationMethod());
type.addMethod(updateSdkClientConfigurationMethod(configurationUtils.serviceClientConfigurationBuilderClassName(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import javax.lang.model.element.Modifier;
Expand All @@ -45,6 +46,7 @@
import software.amazon.awssdk.codegen.model.service.HostPrefixProcessor;
import software.amazon.awssdk.codegen.poet.PoetExtension;
import software.amazon.awssdk.codegen.poet.PoetUtils;
import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeSpecUtils;
import software.amazon.awssdk.codegen.poet.rules.EndpointRulesSpecUtils;
import software.amazon.awssdk.core.SdkPlugin;
import software.amazon.awssdk.core.SdkRequest;
Expand All @@ -53,6 +55,7 @@
import software.amazon.awssdk.core.client.config.SdkClientOption;
import software.amazon.awssdk.core.retry.RetryMode;
import software.amazon.awssdk.core.signer.Signer;
import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption;
import software.amazon.awssdk.retries.api.RetryStrategy;
import software.amazon.awssdk.utils.AttributeMap;
import software.amazon.awssdk.utils.CollectionUtils;
Expand Down Expand Up @@ -346,4 +349,120 @@ public static MethodSpec updateRetryStrategyClientConfigurationMethod() {
static String transformServiceId(String serviceId) {
return serviceId.replace(" ", "_");
}

static MethodSpec resolveAuthSchemeOptionsMethod(AuthSchemeSpecUtils authSchemeSpecUtils,
EndpointRulesSpecUtils endpointRulesSpecUtils) {
MethodSpec.Builder builder = MethodSpec.methodBuilder("resolveAuthSchemeOptions")
.addModifiers(PRIVATE)
.returns(ParameterizedTypeName.get(ClassName.get(List.class), ClassName.get(AuthSchemeOption.class)))
.addParameter(SdkRequest.class, "request")
.addParameter(String.class, "operationName")
.addParameter(SdkClientConfiguration.class, "clientConfiguration");

ClassName providerInterface = authSchemeSpecUtils.providerInterfaceName();

builder.addStatement("$T authSchemeProvider = ($T) clientConfiguration.option($T.AUTH_SCHEME_PROVIDER)",
providerInterface, providerInterface, SdkClientOption.class);

if (authSchemeSpecUtils.useEndpointBasedAuthProvider()) {
addEndpointBasedAuthSchemeResolution(builder, authSchemeSpecUtils, endpointRulesSpecUtils);
} else {
addSimpleAuthSchemeResolution(builder, authSchemeSpecUtils);
}

return builder.build();
}

private static void addSimpleAuthSchemeResolution(MethodSpec.Builder builder,
AuthSchemeSpecUtils authSchemeSpecUtils) {
ClassName paramsInterface = authSchemeSpecUtils.parametersInterfaceName();
ClassName awsClientOption = ClassName.get("software.amazon.awssdk.awscore.client.config", "AwsClientOption");

builder.addStatement("$T.Builder paramsBuilder = $T.builder().operation(operationName)",
paramsInterface, paramsInterface);

if (authSchemeSpecUtils.usesSigV4()) {
builder.addStatement("paramsBuilder.region(clientConfiguration.option($T.AWS_REGION))", awsClientOption);
}

if (authSchemeSpecUtils.hasSigV4aSupport()) {
ClassName regionSet = ClassName.get("software.amazon.awssdk.http.auth.aws.signer", "RegionSet");
builder.addStatement("$T<String> sigv4aRegionSet = clientConfiguration.option($T.AWS_SIGV4A_SIGNING_REGION_SET)",
ClassName.get(Set.class), awsClientOption);
builder.beginControlFlow("if (!$T.isNullOrEmpty(sigv4aRegionSet))", CollectionUtils.class);
builder.addStatement("paramsBuilder.regionSet($T.create(sigv4aRegionSet))", regionSet);
builder.nextControlFlow("else");
builder.addStatement("paramsBuilder.regionSet($T.create(clientConfiguration.option($T.AWS_REGION).id()))",
regionSet, awsClientOption);
builder.endControlFlow();
}

builder.addStatement("return authSchemeProvider.resolveAuthScheme(paramsBuilder.build())");
}

private static void addEndpointBasedAuthSchemeResolution(MethodSpec.Builder builder,
AuthSchemeSpecUtils authSchemeSpecUtils,
EndpointRulesSpecUtils endpointRulesSpecUtils) {
ClassName paramsInterface = authSchemeSpecUtils.parametersInterfaceName();
ClassName awsClientOption = ClassName.get("software.amazon.awssdk.awscore.client.config", "AwsClientOption");
ClassName endpointParamsClass = endpointRulesSpecUtils.parametersClassName();
ClassName resolverInterceptor = endpointRulesSpecUtils.resolverInterceptorName();
ClassName executionAttributesClass = ClassName.get("software.amazon.awssdk.core.interceptor", "ExecutionAttributes");
ClassName awsExecutionAttribute = ClassName.get("software.amazon.awssdk.awscore", "AwsExecutionAttribute");
ClassName sdkExecutionAttribute = ClassName.get("software.amazon.awssdk.core.interceptor", "SdkExecutionAttribute");
ClassName sdkInternalExecutionAttribute = ClassName.get("software.amazon.awssdk.core.interceptor",
"SdkInternalExecutionAttribute");

builder.addStatement("$T executionAttributes = new $T()", executionAttributesClass, executionAttributesClass);
builder.addStatement("executionAttributes.putAttribute($T.AWS_REGION, clientConfiguration.option($T.AWS_REGION))",
awsExecutionAttribute, awsClientOption);
builder.addStatement("executionAttributes.putAttribute($T.DUALSTACK_ENDPOINT_ENABLED, "
+ "clientConfiguration.option($T.DUALSTACK_ENDPOINT_ENABLED))",
awsExecutionAttribute, awsClientOption);
builder.addStatement("executionAttributes.putAttribute($T.FIPS_ENDPOINT_ENABLED, "
+ "clientConfiguration.option($T.FIPS_ENDPOINT_ENABLED))",
awsExecutionAttribute, awsClientOption);
builder.addStatement("executionAttributes.putAttribute($T.OPERATION_NAME, operationName)", sdkExecutionAttribute);
builder.addStatement("executionAttributes.putAttribute($T.CLIENT_ENDPOINT_PROVIDER, "
+ "clientConfiguration.option($T.CLIENT_ENDPOINT_PROVIDER))",
sdkInternalExecutionAttribute, SdkClientOption.class);
builder.addStatement("executionAttributes.putAttribute($T.CLIENT_CONTEXT_PARAMS, "
+ "clientConfiguration.option($T.CLIENT_CONTEXT_PARAMS))",
sdkInternalExecutionAttribute, SdkClientOption.class);

builder.addStatement("$T endpointParams = $T.ruleParams(request, executionAttributes)",
endpointParamsClass, resolverInterceptor);

builder.addStatement("$T.Builder paramsBuilder = $T.builder()", paramsInterface, paramsInterface);

boolean regionIncluded = false;
for (String paramName : endpointRulesSpecUtils.parameters().keySet()) {
if (!authSchemeSpecUtils.includeParamForProvider(paramName)) {
continue;
}
regionIncluded = regionIncluded || paramName.equalsIgnoreCase("region");
String methodName = endpointRulesSpecUtils.paramMethodName(paramName);
builder.addStatement("paramsBuilder.$1N(endpointParams.$1N())", methodName);
}

builder.addStatement("paramsBuilder.operation(operationName)");

if (authSchemeSpecUtils.usesSigV4() && !regionIncluded) {
builder.addStatement("paramsBuilder.region(clientConfiguration.option($T.AWS_REGION))", awsClientOption);
}

ClassName paramsBuilderClass = authSchemeSpecUtils.parametersEndpointAwareDefaultImplName().nestedClass("Builder");
ClassName endpointProviderInterface = endpointRulesSpecUtils.providerInterfaceName();

builder.beginControlFlow("if (paramsBuilder instanceof $T)", paramsBuilderClass);
builder.addStatement("$T endpointProvider = clientConfiguration.option($T.ENDPOINT_PROVIDER)",
ClassName.get("software.amazon.awssdk.endpoints", "EndpointProvider"), SdkClientOption.class);
builder.beginControlFlow("if (endpointProvider instanceof $T)", endpointProviderInterface);
builder.addStatement("(($T) paramsBuilder).endpointProvider(($T) endpointProvider)",
paramsBuilderClass, endpointProviderInterface);
builder.endControlFlow();
builder.endControlFlow();

builder.addStatement("return authSchemeProvider.resolveAuthScheme(paramsBuilder.build())");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,14 @@
import software.amazon.awssdk.codegen.model.service.PreClientExecutionRequestCustomizer;
import software.amazon.awssdk.codegen.poet.PoetExtension;
import software.amazon.awssdk.codegen.poet.PoetUtils;
import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeSpecUtils;
import software.amazon.awssdk.codegen.poet.client.specs.Ec2ProtocolSpec;
import software.amazon.awssdk.codegen.poet.client.specs.JsonProtocolSpec;
import software.amazon.awssdk.codegen.poet.client.specs.ProtocolSpec;
import software.amazon.awssdk.codegen.poet.client.specs.QueryProtocolSpec;
import software.amazon.awssdk.codegen.poet.client.specs.XmlProtocolSpec;
import software.amazon.awssdk.codegen.poet.model.ServiceClientConfigurationUtils;
import software.amazon.awssdk.codegen.poet.rules.EndpointRulesSpecUtils;
import software.amazon.awssdk.core.RequestOverrideConfiguration;
import software.amazon.awssdk.core.client.config.SdkClientConfiguration;
import software.amazon.awssdk.core.client.config.SdkClientOption;
Expand All @@ -83,6 +85,8 @@ public class SyncClientClass extends SyncClientInterface {
private final ProtocolSpec protocolSpec;
private final ClassName serviceClientConfigurationClassName;
private final ServiceClientConfigurationUtils configurationUtils;
private final AuthSchemeSpecUtils authSchemeSpecUtils;
private final EndpointRulesSpecUtils endpointRulesSpecUtils;

public SyncClientClass(GeneratorTaskParams taskParams) {
super(taskParams.getModel());
Expand All @@ -92,6 +96,8 @@ public SyncClientClass(GeneratorTaskParams taskParams) {
this.protocolSpec = getProtocolSpecs(poetExtensions, model);
this.serviceClientConfigurationClassName = new PoetExtension(model).getServiceConfigClass();
this.configurationUtils = new ServiceClientConfigurationUtils(model);
this.authSchemeSpecUtils = new AuthSchemeSpecUtils(model);
this.endpointRulesSpecUtils = new EndpointRulesSpecUtils(model);
}

@Override
Expand Down Expand Up @@ -133,7 +139,8 @@ protected void addAdditionalMethods(TypeSpec.Builder type) {
type.addMethod(constructor())
.addMethod(nameMethod())
.addMethods(protocolSpec.additionalMethods())
.addMethod(resolveMetricPublishersMethod());
.addMethod(resolveMetricPublishersMethod())
.addMethod(ClientClassUtils.resolveAuthSchemeOptionsMethod(authSchemeSpecUtils, endpointRulesSpecUtils));

protocolSpec.createErrorResponseHandler().ifPresent(type::addMethod);
type.addMethod(ClientClassUtils.updateRetryStrategyClientConfigurationMethod());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,9 @@ public CodeBlock executionHandler(OperationModel opModel) {
.add(credentialType(opModel, model))
.add(".withRequestConfiguration(clientConfiguration)")
.add(".withInput($L)\n", opModel.getInput().getVariableName())
.add(".withMetricCollector(apiCallMetricCollector)")
.add(".withMetricCollector(apiCallMetricCollector)\n")
.add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n",
opModel.getOperationName())
.add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel))
.add(HttpChecksumTrait.create(opModel));

Expand Down Expand Up @@ -295,6 +297,8 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper
.add(".withErrorResponseHandler(errorResponseHandler)\n")
.add(".withRequestConfiguration(clientConfiguration)")
.add(".withMetricCollector(apiCallMetricCollector)\n")
.add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n",
opModel.getOperationName())
.add(hostPrefixExpression(opModel))
.add(discoveredEndpoint(opModel))
.add(credentialType(opModel, model))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ public CodeBlock executionHandler(OperationModel opModel) {
.add(".withRequestConfiguration(clientConfiguration)")
.add(".withInput($L)", opModel.getInput().getVariableName())
.add(".withMetricCollector(apiCallMetricCollector)")
.add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n",
opModel.getOperationName())
.add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel))
.add(HttpChecksumTrait.create(opModel));

Expand Down Expand Up @@ -155,6 +157,8 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper
.add(credentialType(opModel, intermediateModel))
.add(".withRequestConfiguration(clientConfiguration)")
.add(".withMetricCollector(apiCallMetricCollector)\n")
.add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n",
opModel.getOperationName())
.add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel))
.add(HttpChecksumTrait.create(opModel));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,10 @@ public CodeBlock executionHandler(OperationModel opModel) {
discoveredEndpoint(opModel))
.add(credentialType(opModel, model))
.add(".withRequestConfiguration(clientConfiguration)")
.add(".withInput($L)", opModel.getInput().getVariableName())
.add(".withInput($L)", opModel.getInput().getVariableName())
.add(".withMetricCollector(apiCallMetricCollector)")
.add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n",
opModel.getOperationName())
.add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel))
.add(HttpChecksumTrait.create(opModel));

Expand Down Expand Up @@ -212,7 +215,9 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper

builder.add(hostPrefixExpression(opModel))
.add(credentialType(opModel, model))
.add(".withMetricCollector(apiCallMetricCollector)\n")
.add(".withMetricCollector(apiCallMetricCollector)\n")
.add(".withAuthSchemeOptionsResolver(r -> resolveAuthSchemeOptions(r, $S, clientConfiguration))\n",
opModel.getOperationName())
.add(asyncRequestBody(opModel))
.add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel))
.add(HttpChecksumTrait.create(opModel));
Expand Down
Loading
Loading