diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 79066b213..ab5f6567a 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -190,20 +190,23 @@ private Flowable callLlm( context, llmRequestBuilder, eventForCallbackUsage, exception) .switchIfEmpty(Single.error(exception)) .toFlowable()) - .doOnNext( - llmResp -> - Tracing.traceCallLlm( - context, - eventForCallbackUsage.id(), - llmRequestBuilder.build(), - llmResp)) .doOnError( error -> { Span span = Span.current(); span.setStatus(StatusCode.ERROR, error.getMessage()); span.recordException(error); }) - .compose(Tracing.trace("call_llm").setParent(spanContext)) + .compose( + Tracing.trace("call_llm") + .setParent(spanContext) + .onSuccess( + (span, llmResp) -> + Tracing.traceCallLlm( + span, + context, + eventForCallbackUsage.id(), + llmRequestBuilder.build(), + llmResp))) .concatMap( llmResp -> handleAfterModelCallback(context, llmResp, eventForCallbackUsage) diff --git a/core/src/main/java/com/google/adk/telemetry/Tracing.java b/core/src/main/java/com/google/adk/telemetry/Tracing.java index 7f338fdcf..35bf3cc96 100644 --- a/core/src/main/java/com/google/adk/telemetry/Tracing.java +++ b/core/src/main/java/com/google/adk/telemetry/Tracing.java @@ -54,6 +54,7 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Supplier; import org.reactivestreams.Publisher; @@ -292,58 +293,49 @@ private static Map buildLlmRequestForTrace(LlmRequest llmRequest * @param llmResponse The LLM response object. */ public static void traceCallLlm( + Span span, InvocationContext invocationContext, String eventId, LlmRequest llmRequest, LlmResponse llmResponse) { - traceWithSpan( - "traceCallLlm", - span -> { - span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); - llmRequest - .model() - .ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); - - setInvocationAttributes(span, invocationContext, eventId); - - setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest)); - setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse); - - llmRequest - .config() - .ifPresent( - config -> { - config - .topP() - .ifPresent( - topP -> span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue())); - config - .maxOutputTokens() - .ifPresent( - maxTokens -> - span.setAttribute( - GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue())); - }); - llmResponse - .usageMetadata() - .ifPresent( - usage -> { - usage - .promptTokenCount() - .ifPresent( - tokens -> span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens)); - usage - .candidatesTokenCount() - .ifPresent( - tokens -> span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens)); - }); - llmResponse - .finishReason() - .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT)) - .ifPresent( - reason -> - span.setAttribute(GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason))); - }); + span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); + llmRequest.model().ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); + + setInvocationAttributes(span, invocationContext, eventId); + + setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest)); + setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse); + + llmRequest + .config() + .ifPresent( + config -> { + config + .topP() + .ifPresent(topP -> span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue())); + config + .maxOutputTokens() + .ifPresent( + maxTokens -> + span.setAttribute(GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue())); + }); + llmResponse + .usageMetadata() + .ifPresent( + usage -> { + usage + .promptTokenCount() + .ifPresent(tokens -> span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens)); + usage + .candidatesTokenCount() + .ifPresent( + tokens -> span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens)); + }); + llmResponse + .finishReason() + .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT)) + .ifPresent( + reason -> span.setAttribute(GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason))); } /** @@ -455,6 +447,7 @@ public static final class TracerProvider private final String spanName; private Context explicitParentContext; private final List> spanConfigurers = new ArrayList<>(); + private BiConsumer onSuccessConsumer; private TracerProvider(String spanName) { this.spanName = spanName; @@ -474,6 +467,16 @@ public TracerProvider setParent(Context parentContext) { return this; } + /** + * Registers a callback to be executed with the span and the result item when the stream emits a + * success value. + */ + @CanIgnoreReturnValue + public TracerProvider onSuccess(BiConsumer consumer) { + this.onSuccessConsumer = consumer; + return this; + } + private Context getParentContext() { return explicitParentContext != null ? explicitParentContext : Context.current(); } @@ -504,7 +507,11 @@ public Publisher apply(Flowable upstream) { return Flowable.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + Flowable pipeline = upstream.doOnSubscribe(s -> lifecycle.start()); + if (onSuccessConsumer != null) { + pipeline = pipeline.doOnNext(t -> onSuccessConsumer.accept(lifecycle.span, t)); + } + return pipeline.doFinally(lifecycle::end); }); } @@ -513,7 +520,11 @@ public SingleSource apply(Single upstream) { return Single.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + Single pipeline = upstream.doOnSubscribe(s -> lifecycle.start()); + if (onSuccessConsumer != null) { + pipeline = pipeline.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t)); + } + return pipeline.doFinally(lifecycle::end); }); } @@ -522,7 +533,11 @@ public MaybeSource apply(Maybe upstream) { return Maybe.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); + Maybe pipeline = upstream.doOnSubscribe(s -> lifecycle.start()); + if (onSuccessConsumer != null) { + pipeline = pipeline.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t)); + } + return pipeline.doFinally(lifecycle::end); }); } diff --git a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java index 9439fe718..f809193cf 100644 --- a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java +++ b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java @@ -503,7 +503,7 @@ public void testTraceCallLlm() { .totalTokenCount(30) .build()) .build(); - Tracing.traceCallLlm(buildInvocationContext(), "event-1", llmRequest, llmResponse); + Tracing.traceCallLlm(span, buildInvocationContext(), "event-1", llmRequest, llmResponse); } finally { span.end(); }