diff --git a/dd-java-agent/instrumentation/java/java-lang/java-lang-21.0/src/main/java/datadog/trace/instrumentation/java/lang/jdk21/VirtualThreadInstrumentation.java b/dd-java-agent/instrumentation/java/java-lang/java-lang-21.0/src/main/java/datadog/trace/instrumentation/java/lang/jdk21/VirtualThreadInstrumentation.java index 1293e7c2398..323a618dbf0 100644 --- a/dd-java-agent/instrumentation/java/java-lang/java-lang-21.0/src/main/java/datadog/trace/instrumentation/java/lang/jdk21/VirtualThreadInstrumentation.java +++ b/dd-java-agent/instrumentation/java/java-lang/java-lang-21.0/src/main/java/datadog/trace/instrumentation/java/lang/jdk21/VirtualThreadInstrumentation.java @@ -1,13 +1,11 @@ package datadog.trace.instrumentation.java.lang.jdk21; import static datadog.trace.agent.tooling.bytebuddy.matcher.NameMatchers.named; -import static datadog.trace.bootstrap.instrumentation.java.concurrent.AdviceUtils.capture; -import static datadog.trace.bootstrap.instrumentation.java.concurrent.AdviceUtils.endTaskScope; -import static datadog.trace.bootstrap.instrumentation.java.concurrent.AdviceUtils.startTaskScope; import static datadog.trace.bootstrap.instrumentation.java.lang.VirtualThreadHelper.AGENT_SCOPE_CLASS_NAME; import static datadog.trace.bootstrap.instrumentation.java.lang.VirtualThreadHelper.VIRTUAL_THREAD_CLASS_NAME; import static net.bytebuddy.matcher.ElementMatchers.isConstructor; import static net.bytebuddy.matcher.ElementMatchers.isMethod; +import static net.bytebuddy.matcher.ElementMatchers.takesArguments; import com.google.auto.service.AutoService; import datadog.environment.JavaVirtualMachine; @@ -16,7 +14,8 @@ import datadog.trace.bootstrap.ContextStore; import datadog.trace.bootstrap.InstrumentationContext; import datadog.trace.bootstrap.instrumentation.api.AgentScope; -import datadog.trace.bootstrap.instrumentation.java.concurrent.State; +import datadog.trace.bootstrap.instrumentation.api.AgentTracer; +import datadog.trace.bootstrap.instrumentation.java.concurrent.ConcurrentState; import java.util.HashMap; import java.util.Map; import net.bytebuddy.asm.Advice; @@ -25,13 +24,13 @@ /** * Instruments {@code VirtualThread} to capture active state at creation, activate it on - * continuation mount, and close the scope from activation on continuation unmount. + * continuation mount, close the scope on continuation unmount, and release the continuation when + * the virtual thread terminates. * *

The instrumentation uses two context stores. The first from {@link Runnable} (as {@code - * VirtualThread} inherits from {@link Runnable}) to store the captured {@link State} to restore - * later. It additionally stores the {@link AgentScope} to be able to close it later as activation / - * close is not done around the same method (so passing the scope from {@link OnMethodEnter} / - * {@link OnMethodExit} using advice return value is not possible). + * VirtualThread} inherits from {@link Runnable}) stores a held {@link ConcurrentState} so the + * parent context can be re-activated on each mount. It additionally stores the {@link AgentScope} + * to be able to close it later as activation / close is not done around the same method. * *

Instrumenting the internal {@code VirtualThread.runContinuation()} method does not work as the * current thread is still the carrier thread and not a virtual thread. Activating the state when on @@ -62,7 +61,7 @@ public boolean isEnabled() { @Override public Map contextStore() { Map contextStore = new HashMap<>(); - contextStore.put(Runnable.class.getName(), State.class.getName()); + contextStore.put(Runnable.class.getName(), ConcurrentState.class.getName()); contextStore.put(VIRTUAL_THREAD_CLASS_NAME, AGENT_SCOPE_CLASS_NAME); return contextStore; } @@ -72,36 +71,66 @@ public void methodAdvice(MethodTransformer transformer) { transformer.applyAdvice(isConstructor(), getClass().getName() + "$Construct"); transformer.applyAdvice(isMethod().and(named("mount")), getClass().getName() + "$Activate"); transformer.applyAdvice(isMethod().and(named("unmount")), getClass().getName() + "$Close"); + transformer.applyAdvice( + isMethod() + .and( + // this one for jdk 21 + named("afterTerminate") + .and(takesArguments(2)) + // this one for jdk 25+ + .or(named("afterDone").and(takesArguments(1)))), + getClass().getName() + "$Terminate"); } public static final class Construct { @OnMethodExit(suppress = Throwable.class) public static void captureScope(@Advice.This Object virtualThread) { - capture(InstrumentationContext.get(Runnable.class, State.class), (Runnable) virtualThread); + ContextStore stateStore = + InstrumentationContext.get(Runnable.class, ConcurrentState.class); + ConcurrentState.captureContinuation( + stateStore, (Runnable) virtualThread, AgentTracer.activeSpan()); } } public static final class Activate { @OnMethodExit(suppress = Throwable.class) public static void activate(@Advice.This Object virtualThread) { - ContextStore stateStore = - InstrumentationContext.get(Runnable.class, State.class); + ContextStore stateStore = + InstrumentationContext.get(Runnable.class, ConcurrentState.class); ContextStore scopeStore = InstrumentationContext.get(VIRTUAL_THREAD_CLASS_NAME, AGENT_SCOPE_CLASS_NAME); - AgentScope agentScope = startTaskScope(stateStore, (Runnable) virtualThread); - scopeStore.put(virtualThread, agentScope); + AgentScope agentScope = + ConcurrentState.activateAndContinueContinuation(stateStore, (Runnable) virtualThread); + if (agentScope != null) { + scopeStore.put(virtualThread, agentScope); + } } } public static final class Close { @OnMethodEnter(suppress = Throwable.class) public static void close(@Advice.This Object virtualThread) { + ContextStore stateStore = + InstrumentationContext.get(Runnable.class, ConcurrentState.class); ContextStore scopeStore = InstrumentationContext.get(VIRTUAL_THREAD_CLASS_NAME, AGENT_SCOPE_CLASS_NAME); - Object agentScope = scopeStore.get(virtualThread); + Object agentScope = scopeStore.remove(virtualThread); if (agentScope instanceof AgentScope) { - endTaskScope((AgentScope) agentScope); + ConcurrentState.closeScope( + stateStore, (Runnable) virtualThread, (AgentScope) agentScope, null); } } } + + public static final class Terminate { + @OnMethodExit(suppress = Throwable.class) + public static void cleanup(@Advice.This Object virtualThread) { + ContextStore stateStore = + InstrumentationContext.get(Runnable.class, ConcurrentState.class); + ConcurrentState.cancelAndClearContinuation(stateStore, (Runnable) virtualThread); + ContextStore scopeStore = + InstrumentationContext.get(VIRTUAL_THREAD_CLASS_NAME, AGENT_SCOPE_CLASS_NAME); + scopeStore.remove(virtualThread); + } + } } diff --git a/dd-java-agent/instrumentation/java/java-lang/java-lang-21.0/src/test/java/testdog/trace/instrumentation/java/lang/jdk21/VirtualThreadApiInstrumentationTest.java b/dd-java-agent/instrumentation/java/java-lang/java-lang-21.0/src/test/java/testdog/trace/instrumentation/java/lang/jdk21/VirtualThreadApiInstrumentationTest.java index b6359e826d0..64c378d291f 100644 --- a/dd-java-agent/instrumentation/java/java-lang/java-lang-21.0/src/test/java/testdog/trace/instrumentation/java/lang/jdk21/VirtualThreadApiInstrumentationTest.java +++ b/dd-java-agent/instrumentation/java/java-lang/java-lang-21.0/src/test/java/testdog/trace/instrumentation/java/lang/jdk21/VirtualThreadApiInstrumentationTest.java @@ -3,11 +3,14 @@ import static datadog.trace.agent.test.assertions.SpanMatcher.span; import static datadog.trace.agent.test.assertions.TraceMatcher.SORT_BY_START_TIME; import static datadog.trace.agent.test.assertions.TraceMatcher.trace; +import static org.junit.jupiter.api.Assertions.assertEquals; import datadog.trace.agent.test.AbstractInstrumentationTest; +import datadog.trace.api.CorrelationIdentifier; import datadog.trace.api.Trace; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; @@ -137,6 +140,68 @@ public void run() { span().childOfPrevious().operationName("great-great-child"))); } + @DisplayName("test CorrelationIdentifier across virtual thread remount") + @Test + void testCorrelationIdentifierAcrossVirtualThreadRemount() throws InterruptedException { + AtomicReference parentTraceId = new AtomicReference<>(); + AtomicReference parentSpanId = new AtomicReference<>(); + AtomicReference traceIdBeforeRemount = new AtomicReference<>(); + AtomicReference spanIdBeforeRemount = new AtomicReference<>(); + AtomicReference traceIdAfterRemount = new AtomicReference<>(); + AtomicReference spanIdAfterRemount = new AtomicReference<>(); + + new Runnable() { + @Override + @Trace(operationName = "parent") + public void run() { + parentTraceId.set(CorrelationIdentifier.getTraceId()); + parentSpanId.set(CorrelationIdentifier.getSpanId()); + + Thread thread = + Thread.startVirtualThread( + () -> { + traceIdBeforeRemount.set(CorrelationIdentifier.getTraceId()); + spanIdBeforeRemount.set(CorrelationIdentifier.getSpanId()); + + try { + // Sleeping should park and later remount the virtual thread. + Thread.sleep(10); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + + traceIdAfterRemount.set(CorrelationIdentifier.getTraceId()); + spanIdAfterRemount.set(CorrelationIdentifier.getSpanId()); + }); + + try { + thread.join(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + }.run(); + + assertEquals( + parentTraceId.get(), + traceIdBeforeRemount.get(), + "trace id should be visible before the virtual thread remounts"); + assertEquals( + parentSpanId.get(), + spanIdBeforeRemount.get(), + "span id should be visible before the virtual thread remounts"); + assertEquals( + parentTraceId.get(), + traceIdAfterRemount.get(), + "trace id should survive a virtual thread remount"); + assertEquals( + parentSpanId.get(), + spanIdAfterRemount.get(), + "span id should survive a virtual thread remount"); + + assertTraces(trace(span().root().operationName("parent"))); + } + /** Verifies the parent / child span relation. */ void assertConnectedTrace() { assertTraces(