Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 181 additions & 0 deletions core/src/main/java/com/google/adk/telemetry/Tracing.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,20 @@
import io.opentelemetry.context.Context;
import io.opentelemetry.context.Scope;
import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.CompletableObserver;
import io.reactivex.rxjava3.core.CompletableSource;
import io.reactivex.rxjava3.core.CompletableTransformer;
import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.core.FlowableTransformer;
import io.reactivex.rxjava3.core.Maybe;
import io.reactivex.rxjava3.core.MaybeObserver;
import io.reactivex.rxjava3.core.MaybeSource;
import io.reactivex.rxjava3.core.MaybeTransformer;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.core.SingleObserver;
import io.reactivex.rxjava3.core.SingleSource;
import io.reactivex.rxjava3.core.SingleTransformer;
import io.reactivex.rxjava3.disposables.Disposable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
Expand All @@ -58,6 +62,8 @@
import java.util.function.Consumer;
import java.util.function.Supplier;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -550,4 +556,179 @@ public CompletableSource apply(Completable upstream) {
});
}
}

/**
* Returns a transformer that re-activates a given context for the duration of the stream's
* subscription.
*
* @param context The context to re-activate.
* @param <T> The type of the stream.
* @return A transformer that re-activates the context.
*/
public static <T> ContextTransformer<T> withContext(Context context) {
return new ContextTransformer<>(context);
}

/**
* A transformer that re-activates a given context for the duration of the stream's subscription.
*
* @param <T> The type of the stream.
*/
public static final class ContextTransformer<T>
implements FlowableTransformer<T, T>,
SingleTransformer<T, T>,
MaybeTransformer<T, T>,
CompletableTransformer {
private final Context context;

private ContextTransformer(Context context) {
this.context = context;
}

@Override
public Publisher<T> apply(Flowable<T> upstream) {
return upstream.lift(subscriber -> TracingObserver.wrap(context, subscriber));
}

@Override
public SingleSource<T> apply(Single<T> upstream) {
return upstream.lift(observer -> TracingObserver.wrap(context, observer));
}

@Override
public MaybeSource<T> apply(Maybe<T> upstream) {
return upstream.lift(observer -> TracingObserver.wrap(context, observer));
}

@Override
public CompletableSource apply(Completable upstream) {
return upstream.lift(observer -> TracingObserver.wrap(context, observer));
}
}

/**
* An observer that wraps another observer and ensures that the OpenTelemetry context is active
* during all callback methods.
*
* @param <T> The type of the items emitted by the stream.
*/
private static final class TracingObserver<T>
implements Subscriber<T>, SingleObserver<T>, MaybeObserver<T>, CompletableObserver {
private final Context context;
private final Subscriber<? super T> subscriber;
private final SingleObserver<? super T> singleObserver;
private final MaybeObserver<? super T> maybeObserver;
private final CompletableObserver completableObserver;

private TracingObserver(
Context context,
Subscriber<? super T> subscriber,
SingleObserver<? super T> singleObserver,
MaybeObserver<? super T> maybeObserver,
CompletableObserver completableObserver) {
this.context = context;
this.subscriber = subscriber;
this.singleObserver = singleObserver;
this.maybeObserver = maybeObserver;
this.completableObserver = completableObserver;
}

static <T> TracingObserver<T> wrap(Context context, Subscriber<? super T> subscriber) {
return new TracingObserver<>(context, subscriber, null, null, null);
}

static <T> TracingObserver<T> wrap(Context context, SingleObserver<? super T> observer) {
return new TracingObserver<>(context, null, observer, null, null);
}

static <T> TracingObserver<T> wrap(Context context, MaybeObserver<? super T> observer) {
return new TracingObserver<>(context, null, null, observer, null);
}

static <T> TracingObserver<T> wrap(Context context, CompletableObserver observer) {
return new TracingObserver<>(context, null, null, null, observer);
}

private void runInContext(Runnable action) {
try (Scope scope = context.makeCurrent()) {
action.run();
}
}

@Override
public void onSubscribe(Subscription s) {
runInContext(
() -> {
if (subscriber != null) {
subscriber.onSubscribe(s);
}
});
}

@Override
public void onSubscribe(Disposable d) {
runInContext(
() -> {
if (singleObserver != null) {
singleObserver.onSubscribe(d);
} else if (maybeObserver != null) {
maybeObserver.onSubscribe(d);
} else if (completableObserver != null) {
completableObserver.onSubscribe(d);
}
});
}

@Override
public void onNext(T t) {
runInContext(
() -> {
if (subscriber != null) {
subscriber.onNext(t);
}
});
}

@Override
public void onSuccess(T t) {
runInContext(
() -> {
if (singleObserver != null) {
singleObserver.onSuccess(t);
} else if (maybeObserver != null) {
maybeObserver.onSuccess(t);
}
});
}

@Override
public void onError(Throwable t) {
runInContext(
() -> {
if (subscriber != null) {
subscriber.onError(t);
} else if (singleObserver != null) {
singleObserver.onError(t);
} else if (maybeObserver != null) {
maybeObserver.onError(t);
} else if (completableObserver != null) {
completableObserver.onError(t);
}
});
}

@Override
public void onComplete() {
runInContext(
() -> {
if (subscriber != null) {
subscriber.onComplete();
} else if (maybeObserver != null) {
maybeObserver.onComplete();
} else if (completableObserver != null) {
completableObserver.onComplete();
}
});
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import com.google.adk.runner.Runner;
import com.google.adk.sessions.InMemorySessionService;
import com.google.adk.sessions.Session;
import com.google.adk.sessions.SessionKey;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.genai.types.Content;
Expand All @@ -44,12 +45,17 @@
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.Tracer;
import io.opentelemetry.context.Context;
import io.opentelemetry.context.ContextKey;
import io.opentelemetry.context.Scope;
import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule;
import io.opentelemetry.sdk.trace.data.SpanData;
import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.core.Maybe;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.schedulers.Schedulers;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.junit.After;
import org.junit.Before;
Expand Down Expand Up @@ -380,6 +386,70 @@ public void testTraceFlowable() throws InterruptedException {
assertTrue(flowableSpanData.hasEnded());
}

@Test
public void testWithContextFlowable() throws InterruptedException {
ContextKey<String> testKey = ContextKey.named("test-key");
Context testContext = Context.root().with(testKey, "test-value");

Flowable<Integer> flowable =
Flowable.just(1, 2, 3)
.compose(Tracing.withContext(testContext))
.subscribeOn(Schedulers.computation())
.doOnNext(
i -> {
assertEquals("test-value", Context.current().get(testKey));
});
flowable.test().await().assertComplete();
}

@Test
public void testWithContextSingle() throws InterruptedException {
ContextKey<String> testKey = ContextKey.named("test-key");
Context testContext = Context.root().with(testKey, "test-value");

Single<Integer> single =
Single.just(1)
.compose(Tracing.withContext(testContext))
.subscribeOn(Schedulers.computation())
.doOnSuccess(
i -> {
assertEquals("test-value", Context.current().get(testKey));
});
single.test().await().assertComplete();
}

@Test
public void testWithContextMaybe() throws InterruptedException {
ContextKey<String> testKey = ContextKey.named("test-key");
Context testContext = Context.root().with(testKey, "test-value");

Maybe<Integer> maybe =
Maybe.just(1)
.compose(Tracing.withContext(testContext))
.subscribeOn(Schedulers.computation())
.doOnSuccess(
i -> {
assertEquals("test-value", Context.current().get(testKey));
});
maybe.test().await().assertComplete();
}

@Test
public void testWithContextCompletable() throws InterruptedException {
ContextKey<String> testKey = ContextKey.named("test-key");
Context testContext = Context.root().with(testKey, "test-value");

Completable completable =
Completable.complete()
.compose(Tracing.withContext(testContext))
.subscribeOn(Schedulers.computation())
.doOnComplete(
() -> {
assertEquals("test-value", Context.current().get(testKey));
});
completable.test().await().assertComplete();
}

@Test
public void testTraceTransformer() throws InterruptedException {
Span parentSpan = tracer.spanBuilder("parent").startSpan();
Expand Down Expand Up @@ -595,7 +665,7 @@ public void runnerRunAsync_propagatesContext() throws InterruptedException {
Session session =
runner
.sessionService()
.createSession("test-app", "test-user", null, "test-session")
.createSession(new SessionKey("test-app", "test-user", "test-session"))
.blockingGet();
Content newMessage = Content.fromParts(Part.fromText("hi"));
RunConfig runConfig = RunConfig.builder().build();
Expand Down Expand Up @@ -623,13 +693,20 @@ public void runnerRunLive_propagatesContext() throws InterruptedException {
Span parentSpan = tracer.spanBuilder("parent").startSpan();
try (Scope s = parentSpan.makeCurrent()) {
Session session =
Session.builder("test-session").userId("test-user").appName("test-app").build();
runner
.sessionService()
.createSession("test-app", "test-user", (Map<String, Object>) null, "test-session")
.blockingGet();
Content newMessage = Content.fromParts(Part.fromText("hi"));
RunConfig runConfig = RunConfig.builder().build();
LiveRequestQueue liveRequestQueue = new LiveRequestQueue();
liveRequestQueue.content(newMessage);
liveRequestQueue.close();
runner.runLive(session, liveRequestQueue, runConfig).test().await().assertComplete();
runner
.runLive(session.userId(), session.id(), liveRequestQueue, runConfig)
.test()
.await()
.assertComplete();
} finally {
parentSpan.end();
}
Expand Down
Loading