diff --git a/core/src/main/java/com/google/adk/telemetry/Tracing.java b/core/src/main/java/com/google/adk/telemetry/Tracing.java index 35bf3cc96..691a0e51f 100644 --- a/core/src/main/java/com/google/adk/telemetry/Tracing.java +++ b/core/src/main/java/com/google/adk/telemetry/Tracing.java @@ -37,16 +37,20 @@ import io.opentelemetry.context.Context; import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.CompletableObserver; import io.reactivex.rxjava3.core.CompletableSource; import io.reactivex.rxjava3.core.CompletableTransformer; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.FlowableTransformer; import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.MaybeObserver; import io.reactivex.rxjava3.core.MaybeSource; import io.reactivex.rxjava3.core.MaybeTransformer; import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.core.SingleObserver; import io.reactivex.rxjava3.core.SingleSource; import io.reactivex.rxjava3.core.SingleTransformer; +import io.reactivex.rxjava3.disposables.Disposable; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -58,6 +62,8 @@ import java.util.function.Consumer; import java.util.function.Supplier; import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -550,4 +556,179 @@ public CompletableSource apply(Completable upstream) { }); } } + + /** + * Returns a transformer that re-activates a given context for the duration of the stream's + * subscription. + * + * @param context The context to re-activate. + * @param The type of the stream. + * @return A transformer that re-activates the context. + */ + public static ContextTransformer withContext(Context context) { + return new ContextTransformer<>(context); + } + + /** + * A transformer that re-activates a given context for the duration of the stream's subscription. + * + * @param The type of the stream. + */ + public static final class ContextTransformer + implements FlowableTransformer, + SingleTransformer, + MaybeTransformer, + CompletableTransformer { + private final Context context; + + private ContextTransformer(Context context) { + this.context = context; + } + + @Override + public Publisher apply(Flowable upstream) { + return upstream.lift(subscriber -> TracingObserver.wrap(context, subscriber)); + } + + @Override + public SingleSource apply(Single upstream) { + return upstream.lift(observer -> TracingObserver.wrap(context, observer)); + } + + @Override + public MaybeSource apply(Maybe upstream) { + return upstream.lift(observer -> TracingObserver.wrap(context, observer)); + } + + @Override + public CompletableSource apply(Completable upstream) { + return upstream.lift(observer -> TracingObserver.wrap(context, observer)); + } + } + + /** + * An observer that wraps another observer and ensures that the OpenTelemetry context is active + * during all callback methods. + * + * @param The type of the items emitted by the stream. + */ + private static final class TracingObserver + implements Subscriber, SingleObserver, MaybeObserver, CompletableObserver { + private final Context context; + private final Subscriber subscriber; + private final SingleObserver singleObserver; + private final MaybeObserver maybeObserver; + private final CompletableObserver completableObserver; + + private TracingObserver( + Context context, + Subscriber subscriber, + SingleObserver singleObserver, + MaybeObserver maybeObserver, + CompletableObserver completableObserver) { + this.context = context; + this.subscriber = subscriber; + this.singleObserver = singleObserver; + this.maybeObserver = maybeObserver; + this.completableObserver = completableObserver; + } + + static TracingObserver wrap(Context context, Subscriber subscriber) { + return new TracingObserver<>(context, subscriber, null, null, null); + } + + static TracingObserver wrap(Context context, SingleObserver observer) { + return new TracingObserver<>(context, null, observer, null, null); + } + + static TracingObserver wrap(Context context, MaybeObserver observer) { + return new TracingObserver<>(context, null, null, observer, null); + } + + static TracingObserver wrap(Context context, CompletableObserver observer) { + return new TracingObserver<>(context, null, null, null, observer); + } + + private void runInContext(Runnable action) { + try (Scope scope = context.makeCurrent()) { + action.run(); + } + } + + @Override + public void onSubscribe(Subscription s) { + runInContext( + () -> { + if (subscriber != null) { + subscriber.onSubscribe(s); + } + }); + } + + @Override + public void onSubscribe(Disposable d) { + runInContext( + () -> { + if (singleObserver != null) { + singleObserver.onSubscribe(d); + } else if (maybeObserver != null) { + maybeObserver.onSubscribe(d); + } else if (completableObserver != null) { + completableObserver.onSubscribe(d); + } + }); + } + + @Override + public void onNext(T t) { + runInContext( + () -> { + if (subscriber != null) { + subscriber.onNext(t); + } + }); + } + + @Override + public void onSuccess(T t) { + runInContext( + () -> { + if (singleObserver != null) { + singleObserver.onSuccess(t); + } else if (maybeObserver != null) { + maybeObserver.onSuccess(t); + } + }); + } + + @Override + public void onError(Throwable t) { + runInContext( + () -> { + if (subscriber != null) { + subscriber.onError(t); + } else if (singleObserver != null) { + singleObserver.onError(t); + } else if (maybeObserver != null) { + maybeObserver.onError(t); + } else if (completableObserver != null) { + completableObserver.onError(t); + } + }); + } + + @Override + public void onComplete() { + runInContext( + () -> { + if (subscriber != null) { + subscriber.onComplete(); + } else if (maybeObserver != null) { + maybeObserver.onComplete(); + } else if (completableObserver != null) { + completableObserver.onComplete(); + } + }); + } + } } diff --git a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java index f809193cf..e5795d61f 100644 --- a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java +++ b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java @@ -31,6 +31,7 @@ import com.google.adk.runner.Runner; import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; +import com.google.adk.sessions.SessionKey; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.genai.types.Content; @@ -44,12 +45,17 @@ import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; +import io.opentelemetry.context.ContextKey; import io.opentelemetry.context.Scope; import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; import io.opentelemetry.sdk.trace.data.SpanData; +import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.schedulers.Schedulers; import java.util.List; +import java.util.Map; import java.util.Optional; import org.junit.After; import org.junit.Before; @@ -380,6 +386,70 @@ public void testTraceFlowable() throws InterruptedException { assertTrue(flowableSpanData.hasEnded()); } + @Test + public void testWithContextFlowable() throws InterruptedException { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.root().with(testKey, "test-value"); + + Flowable flowable = + Flowable.just(1, 2, 3) + .compose(Tracing.withContext(testContext)) + .subscribeOn(Schedulers.computation()) + .doOnNext( + i -> { + assertEquals("test-value", Context.current().get(testKey)); + }); + flowable.test().await().assertComplete(); + } + + @Test + public void testWithContextSingle() throws InterruptedException { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.root().with(testKey, "test-value"); + + Single single = + Single.just(1) + .compose(Tracing.withContext(testContext)) + .subscribeOn(Schedulers.computation()) + .doOnSuccess( + i -> { + assertEquals("test-value", Context.current().get(testKey)); + }); + single.test().await().assertComplete(); + } + + @Test + public void testWithContextMaybe() throws InterruptedException { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.root().with(testKey, "test-value"); + + Maybe maybe = + Maybe.just(1) + .compose(Tracing.withContext(testContext)) + .subscribeOn(Schedulers.computation()) + .doOnSuccess( + i -> { + assertEquals("test-value", Context.current().get(testKey)); + }); + maybe.test().await().assertComplete(); + } + + @Test + public void testWithContextCompletable() throws InterruptedException { + ContextKey testKey = ContextKey.named("test-key"); + Context testContext = Context.root().with(testKey, "test-value"); + + Completable completable = + Completable.complete() + .compose(Tracing.withContext(testContext)) + .subscribeOn(Schedulers.computation()) + .doOnComplete( + () -> { + assertEquals("test-value", Context.current().get(testKey)); + }); + completable.test().await().assertComplete(); + } + @Test public void testTraceTransformer() throws InterruptedException { Span parentSpan = tracer.spanBuilder("parent").startSpan(); @@ -595,7 +665,7 @@ public void runnerRunAsync_propagatesContext() throws InterruptedException { Session session = runner .sessionService() - .createSession("test-app", "test-user", null, "test-session") + .createSession(new SessionKey("test-app", "test-user", "test-session")) .blockingGet(); Content newMessage = Content.fromParts(Part.fromText("hi")); RunConfig runConfig = RunConfig.builder().build(); @@ -623,13 +693,20 @@ public void runnerRunLive_propagatesContext() throws InterruptedException { Span parentSpan = tracer.spanBuilder("parent").startSpan(); try (Scope s = parentSpan.makeCurrent()) { Session session = - Session.builder("test-session").userId("test-user").appName("test-app").build(); + runner + .sessionService() + .createSession("test-app", "test-user", (Map) null, "test-session") + .blockingGet(); Content newMessage = Content.fromParts(Part.fromText("hi")); RunConfig runConfig = RunConfig.builder().build(); LiveRequestQueue liveRequestQueue = new LiveRequestQueue(); liveRequestQueue.content(newMessage); liveRequestQueue.close(); - runner.runLive(session, liveRequestQueue, runConfig).test().await().assertComplete(); + runner + .runLive(session.userId(), session.id(), liveRequestQueue, runConfig) + .test() + .await() + .assertComplete(); } finally { parentSpan.end(); }