diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsModule.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsModule.java index b05884305dc..e5985f744e7 100644 --- a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsModule.java +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryMetricsModule.java @@ -159,14 +159,6 @@ static String recordMethodName(String fullMethodName, boolean isGeneratedMethod) return isGeneratedMethod ? fullMethodName : "other"; } - private static Context otelContextWithBaggage() { - Baggage baggage = BAGGAGE_KEY.get(); - if (baggage == null) { - return Context.current(); - } - return Context.current().with(baggage); - } - private static final class ClientTracer extends ClientStreamTracer { @Nullable private static final AtomicLongFieldUpdater outboundWireSizeUpdater; @Nullable private static final AtomicLongFieldUpdater inboundWireSizeUpdater; @@ -282,7 +274,6 @@ public void streamClosed(Status status) { } void recordFinishedAttempt() { - Context otelContext = otelContextWithBaggage(); AttributesBuilder builder = io.opentelemetry.api.common.Attributes.builder() .put(METHOD_KEY, fullMethodName) .put(TARGET_KEY, target) @@ -308,15 +299,15 @@ void recordFinishedAttempt() { if (module.resource.clientAttemptDurationCounter() != null ) { module.resource.clientAttemptDurationCounter() - .record(attemptNanos * SECONDS_PER_NANO, attribute, otelContext); + .record(attemptNanos * SECONDS_PER_NANO, attribute, attemptsState.otelContext); } if (module.resource.clientTotalSentCompressedMessageSizeCounter() != null) { module.resource.clientTotalSentCompressedMessageSizeCounter() - .record(outboundWireSize, attribute, otelContext); + .record(outboundWireSize, attribute, attemptsState.otelContext); } if (module.resource.clientTotalReceivedCompressedMessageSizeCounter() != null) { module.resource.clientTotalReceivedCompressedMessageSizeCounter() - .record(inboundWireSize, attribute, otelContext); + .record(inboundWireSize, attribute, attemptsState.otelContext); } } } @@ -331,6 +322,7 @@ static final class CallAttemptsTracerFactory extends ClientStreamTracer.Factory private boolean callEnded; private final String fullMethodName; private final List callPlugins; + private final Context otelContext; private Status status; private long retryDelayNanos; private long callLatencyNanos; @@ -343,15 +335,26 @@ static final class CallAttemptsTracerFactory extends ClientStreamTracer.Factory @GuardedBy("lock") private boolean finishedCallToBeRecorded; + // TODO: Let tests continue compiling. Probably a hack that we want to remove. CallAttemptsTracerFactory( OpenTelemetryMetricsModule module, String target, String fullMethodName, List callPlugins) { + this(module, target, fullMethodName, callPlugins, Context.current()); + } + + CallAttemptsTracerFactory( + OpenTelemetryMetricsModule module, + String target, + String fullMethodName, + List callPlugins, + Context otelContext) { this.module = checkNotNull(module, "module"); this.target = checkNotNull(target, "target"); this.fullMethodName = checkNotNull(fullMethodName, "fullMethodName"); this.callPlugins = checkNotNull(callPlugins, "callPlugins"); + this.otelContext = checkNotNull(otelContext, "otelContext"); this.attemptDelayStopwatch = module.stopwatchSupplier.get(); this.callStopWatch = module.stopwatchSupplier.get().start(); @@ -448,7 +451,6 @@ void callEnded(Status status) { } void recordFinishedCall() { - Context otelContext = otelContextWithBaggage(); if (attemptsPerCall.get() == 0) { ClientTracer tracer = newClientTracer(null); tracer.attemptNanos = attemptDelayStopwatch.elapsed(TimeUnit.NANOSECONDS); @@ -548,6 +550,7 @@ private static final class ServerTracer extends ServerStreamTracer { private final OpenTelemetryMetricsModule module; private final String fullMethodName; private final List streamPlugins; + private Context otelContext = Context.root(); private volatile boolean isGeneratedMethod; private volatile int streamClosed; private final Stopwatch stopwatch; @@ -562,6 +565,16 @@ private static final class ServerTracer extends ServerStreamTracer { this.stopwatch = module.stopwatchSupplier.get().start(); } + @Override + public io.grpc.Context filterContext(io.grpc.Context context) { + Baggage baggage = BAGGAGE_KEY.get(context); + if (baggage == null) { + throw new IllegalStateException("Baggage from OpenTelemetryTracingModule is missing"); + } + otelContext = Context.current().with(baggage); + return context; + } + @Override public void serverCallStarted(ServerCallInfo callInfo) { // Only record method name as an attribute if isSampledToLocalTracing is set to true, @@ -606,7 +619,6 @@ public void inboundWireSize(long bytes) { */ @Override public void streamClosed(Status status) { - Context otelContext = otelContextWithBaggage(); if (streamClosedUpdater != null) { if (streamClosedUpdater.getAndSet(this, 1) != 0) { return; @@ -694,7 +706,8 @@ public ClientCall interceptCall( final CallAttemptsTracerFactory tracerFactory = new CallAttemptsTracerFactory( OpenTelemetryMetricsModule.this, target, recordMethodName(method.getFullMethodName(), method.isSampledToLocalTracing()), - callPlugins); + callPlugins, + Context.current()); ClientCall call = next.newCall(method, callOptions.withStreamTracerFactory(tracerFactory)); return new SimpleForwardingClientCall(call) { diff --git a/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricsModuleTest.java b/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricsModuleTest.java index 391f94cefea..91df8ffb33f 100644 --- a/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricsModuleTest.java +++ b/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryMetricsModuleTest.java @@ -1647,14 +1647,16 @@ public void serverBaggagePropagationToMetrics() { io.grpc.Context grpcContext = io.grpc.Context.current() .withValue(OpenTelemetryConstants.BAGGAGE_KEY, testBaggage); - // 3. Attach the gRPC context, trigger metric recording, and detach io.grpc.Context previousContext = grpcContext.attach(); try { - tracer.streamClosed(Status.OK); + tracer.filterContext(grpcContext); } finally { grpcContext.detach(previousContext); } + // 3. Trigger metric recording + tracer.streamClosed(Status.OK); + // 4. Verify the record call and capture the OTel Context verify(mockServerCallDurationHistogram).record( anyDouble(),