diff --git a/modules/apm/src/main/java/org/elasticsearch/telemetry/apm/internal/tracing/APMTracer.java b/modules/apm/src/main/java/org/elasticsearch/telemetry/apm/internal/tracing/APMTracer.java index 4f8f55a473949..01878b3fae86c 100644 --- a/modules/apm/src/main/java/org/elasticsearch/telemetry/apm/internal/tracing/APMTracer.java +++ b/modules/apm/src/main/java/org/elasticsearch/telemetry/apm/internal/tracing/APMTracer.java @@ -177,7 +177,9 @@ public void startTrace(TraceContext traceContext, Traceable traceable, String sp // A span can have a parent span, which here is modelled though a parent span context. // Setting this is important for seeing a complete trace in the APM UI. - final Context parentContext = getParentContext(traceContext); + // Attempt to fetch a local parent context first, otherwise look for a remote parent + final Context localParentContext = traceContext.getTransient(Task.PARENT_APM_TRACE_CONTEXT); + final Context parentContext = localParentContext != null ? localParentContext : getRemoteParentContext(traceContext); if (parentContext != null) { spanBuilder.setParent(parentContext); } @@ -188,21 +190,21 @@ public void startTrace(TraceContext traceContext, Traceable traceable, String sp if (startTime != null) { spanBuilder.setStartTimestamp(startTime); } + final Span span = spanBuilder.startSpan(); - // If the agent decided not to record this span (e.g., due to transaction_max_spans), isRecording() will be false. - if (span.isRecording() == false) { - logger.trace("Span [{}] [{}] will not be recorded, e.g. due to transaction_max_spans reached", spanId, spanName); - // It's good practice to end the no-op span immediately to release any resources. - span.end(); - // Returning null from computeIfAbsent means no value will be inserted into the map. - return null; + // If not a root span (meaning a local parent exists) and the agent decided not to record the span, discard it immediately. + // Root spans (transactions), however, have to be kept to correctly report their duration. + if (localParentContext != null && span.isRecording() == false) { + logger.trace("Span [{}] [{}] will not be recorded due to transaction_max_spans reached", spanId, spanName); + span.end(); // end span immediately to release any resources. + return null; // return null to discard and not record in map of spans } - // If we are here, the span is real and being recorded. - logger.trace("Successfully started tracing [{}] [{}]", spanId, spanName); final Context contextForNewSpan = Context.current().with(span); - - updateThreadContext(traceContext, services, contextForNewSpan); + if (span.isRecording()) { + logger.trace("Recording trace [{}] [{}]", spanId, spanName); + updateThreadContext(traceContext, services, contextForNewSpan); + } return contextForNewSpan; }); @@ -240,30 +242,26 @@ private static void updateThreadContext(TraceContext traceContext, APMServices s }); } - private Context getParentContext(TraceContext traceContext) { + private Context getRemoteParentContext(TraceContext traceContext) { // https://github.com/open-telemetry/opentelemetry-java/discussions/2884#discussioncomment-381870 // If you just want to propagate across threads within the same process, you don't need context propagators (extract/inject). // You can just pass the Context object directly to another thread (it is immutable and thus thread-safe). - // Attempt to fetch a local parent context first, otherwise look for a remote parent - Context parentContext = traceContext.getTransient(Task.PARENT_APM_TRACE_CONTEXT); - if (parentContext == null) { - final String traceParentHeader = traceContext.getTransient(Task.PARENT_TRACE_PARENT_HEADER); - final String traceStateHeader = traceContext.getTransient(Task.PARENT_TRACE_STATE); - - if (traceParentHeader != null) { - final Map traceContextMap = Maps.newMapWithExpectedSize(2); - // traceparent and tracestate should match the keys used by W3CTraceContextPropagator - traceContextMap.put(Task.TRACE_PARENT_HTTP_HEADER, traceParentHeader); - if (traceStateHeader != null) { - traceContextMap.put(Task.TRACE_STATE, traceStateHeader); - } - parentContext = services.openTelemetry.getPropagators() - .getTextMapPropagator() - .extract(Context.current(), traceContextMap, new MapKeyGetter()); + final String traceParentHeader = traceContext.getTransient(Task.PARENT_TRACE_PARENT_HEADER); + final String traceStateHeader = traceContext.getTransient(Task.PARENT_TRACE_STATE); + + if (traceParentHeader != null) { + final Map traceContextMap = Maps.newMapWithExpectedSize(2); + // traceparent and tracestate should match the keys used by W3CTraceContextPropagator + traceContextMap.put(Task.TRACE_PARENT_HTTP_HEADER, traceParentHeader); + if (traceStateHeader != null) { + traceContextMap.put(Task.TRACE_STATE, traceStateHeader); } + return services.openTelemetry.getPropagators() + .getTextMapPropagator() + .extract(Context.current(), traceContextMap, new MapKeyGetter()); } - return parentContext; + return null; } /** @@ -288,7 +286,7 @@ private Context getParentContext(TraceContext traceContext) { @Override public Releasable withScope(Traceable traceable) { final Context context = spans.get(traceable.getSpanId()); - if (context != null) { + if (context != null && Span.fromContextOrNull(context).isRecording()) { return context.makeCurrent()::close; } return () -> {}; @@ -385,9 +383,10 @@ public void setAttribute(Traceable traceable, String key, String value) { @Override public void stopTrace(Traceable traceable) { - final var span = Span.fromContextOrNull(spans.remove(traceable.getSpanId())); + final String spanId = traceable.getSpanId(); + final var span = Span.fromContextOrNull(spans.remove(spanId)); if (span != null) { - logger.trace("Finishing trace [{}]", traceable); + logger.trace("Finishing trace [{}]", spanId); span.end(); } } diff --git a/modules/apm/src/test/java/org/elasticsearch/telemetry/apm/internal/tracing/APMTracerTests.java b/modules/apm/src/test/java/org/elasticsearch/telemetry/apm/internal/tracing/APMTracerTests.java index 60f4c80b1d7fd..c4af44597fd3e 100644 --- a/modules/apm/src/test/java/org/elasticsearch/telemetry/apm/internal/tracing/APMTracerTests.java +++ b/modules/apm/src/test/java/org/elasticsearch/telemetry/apm/internal/tracing/APMTracerTests.java @@ -23,6 +23,7 @@ import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.tasks.Task; import org.elasticsearch.telemetry.apm.internal.APMAgentSettings; +import org.elasticsearch.telemetry.tracing.TraceContext; import org.elasticsearch.telemetry.tracing.Traceable; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.junit.annotations.TestLogging; @@ -42,6 +43,7 @@ import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; @@ -87,22 +89,29 @@ public void test_onTraceStarted_startsTrace() { Settings settings = Settings.builder().put(APMAgentSettings.TELEMETRY_TRACING_ENABLED_SETTING.getKey(), true).build(); APMTracer apmTracer = buildTracer(settings); - apmTracer.startTrace(new ThreadContext(settings), TRACEABLE1, "name1", null); + ThreadContext traceContext = new ThreadContext(settings); + apmTracer.startTrace(traceContext, TRACEABLE1, "name1", null); + assertThat(traceContext.getTransient(Task.APM_TRACE_CONTEXT), notNullValue()); assertThat(apmTracer.getSpans(), aMapWithSize(1)); assertThat(apmTracer.getSpans(), hasKey(TRACEABLE1.getSpanId())); } /** - * Check that when a trace is started, but it is not recorded, e.g. due to sampling, the tracer does not record it either. + * Check that when a root trace is started, but it is not recorded, e.g. due to sampling, + * the tracer tracks it but doesn't start tracing. */ - public void test_onTraceStarted_ifNotRecorded_doesNotStartTrace() { + public void test_onTraceStarted_ifNotRecorded_doesNotStartTracing() { Settings settings = Settings.builder().put(APMAgentSettings.TELEMETRY_TRACING_ENABLED_SETTING.getKey(), true).build(); APMTracer apmTracer = buildTracer(settings); - apmTracer.startTrace(new ThreadContext(settings), TRACEABLE1, "name1_discard", null); + ThreadContext traceContext = new ThreadContext(settings); + apmTracer.startTrace(traceContext, TRACEABLE1, "name1_discard", null); - assertThat(apmTracer.getSpans(), anEmptyMap()); + assertThat(traceContext.getTransient(Task.APM_TRACE_CONTEXT), nullValue()); + // the root span (transaction) is tracked + assertThat(apmTracer.getSpans(), aMapWithSize(1)); + assertThat(apmTracer.getSpans(), hasKey(TRACEABLE1.getSpanId())); } /** @@ -116,8 +125,11 @@ public void test_onNestedTraceStarted_ifNotRecorded_doesNotStartTrace() { apmTracer.startTrace(traceContext, TRACEABLE1, "name1", null); try (var ignore1 = traceContext.newTraceContext()) { apmTracer.startTrace(traceContext, TRACEABLE2, "name2_discard", null); + assertThat(traceContext.getTransient(Task.APM_TRACE_CONTEXT), nullValue()); + try (var ignore2 = traceContext.newTraceContext()) { apmTracer.startTrace(traceContext, TRACEABLE3, "name3_discard", null); + assertThat(traceContext.getTransient(Task.APM_TRACE_CONTEXT), nullValue()); } } assertThat(apmTracer.getSpans(), aMapWithSize(1)); @@ -131,12 +143,13 @@ public void test_onTraceStartedWithStartTime_startsTrace() { Settings settings = Settings.builder().put(APMAgentSettings.TELEMETRY_TRACING_ENABLED_SETTING.getKey(), true).build(); APMTracer apmTracer = buildTracer(settings); - ThreadContext threadContext = new ThreadContext(settings); + TraceContext traceContext = new ThreadContext(settings); // 1_000_000L because of "toNanos" conversions that overflow for large long millis Instant spanStartTime = Instant.ofEpochMilli(randomLongBetween(0, Long.MAX_VALUE / 1_000_000L)); - threadContext.putTransient(Task.TRACE_START_TIME, spanStartTime); - apmTracer.startTrace(threadContext, TRACEABLE1, "name1", null); + traceContext.putTransient(Task.TRACE_START_TIME, spanStartTime); + apmTracer.startTrace(traceContext, TRACEABLE1, "name1", null); + assertThat(traceContext.getTransient(Task.APM_TRACE_CONTEXT), notNullValue()); assertThat(apmTracer.getSpans(), aMapWithSize(1)); assertThat(apmTracer.getSpans(), hasKey(TRACEABLE1.getSpanId())); assertThat(((SpyAPMTracer) apmTracer).getSpanStartTime("name1"), is(spanStartTime)); @@ -151,6 +164,7 @@ public void test_onTraceStopped_stopsTrace() { apmTracer.startTrace(new ThreadContext(settings), TRACEABLE1, "name1", null); apmTracer.stopTrace(TRACEABLE1); + apmTracer.stopTrace(TRACEABLE2); // stopping a non-existent trace is a noop assertThat(apmTracer.getSpans(), anEmptyMap()); } diff --git a/server/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java b/server/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java index 86f5c59284460..9eae6636f79a6 100644 --- a/server/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java +++ b/server/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java @@ -204,7 +204,13 @@ public StoredContext newTraceContext() { // this is the context when this method returns final ThreadContextStruct newContext; - if (originalContext.hasTraceContext() == false) { + + final boolean hasTraceHeaders = originalContext.requestHeaders.containsKey(Task.TRACE_PARENT_HTTP_HEADER) + || originalContext.requestHeaders.containsKey(Task.TRACE_STATE) + || originalContext.transientHeaders.containsKey(Task.APM_TRACE_CONTEXT); + + if (hasTraceHeaders == false) { + // no need to copy if no trace headers are present newContext = originalContext; } else { final Map newRequestHeaders = new HashMap<>(originalContext.requestHeaders); @@ -223,6 +229,9 @@ public StoredContext newTraceContext() { final Object previousTraceContext = newTransientHeaders.remove(Task.APM_TRACE_CONTEXT); if (previousTraceContext != null) { newTransientHeaders.put(Task.PARENT_APM_TRACE_CONTEXT, previousTraceContext); + // Remove the trace start time override for a previous context if such a context already exists. + // If kept, all spans would contain the same start time. + newTransientHeaders.remove(Task.TRACE_START_TIME); } newContext = new ThreadContextStruct( @@ -246,12 +255,12 @@ public StoredContext newTraceContext() { }; } - public boolean hasTraceContext() { - return threadLocal.get().hasTraceContext(); + public boolean hasApmTraceContext() { + return threadLocal.get().hasApmTraceContext(); } - public boolean hasParentTraceContext() { - return threadLocal.get().hasParentTraceContext(); + public boolean hasParentApmTraceContext() { + return threadLocal.get().hasParentApmTraceContext(); } /** @@ -644,6 +653,10 @@ public T getTransient(String key) { return (T) threadLocal.get().transientHeaders.get(key); } + public boolean hasTransient(Collection keys) { + return threadLocal.get().transientHeaders.keySet().containsAll(keys); + } + /** * Returns unmodifiable copy of all transient headers. */ @@ -873,18 +886,6 @@ private ThreadContextStruct putResponseHeaders(Map> headers) return new ThreadContextStruct(requestHeaders, newResponseHeaders, transientHeaders, isSystemContext); } - private boolean hasTraceContext() { - return requestHeaders.containsKey(Task.TRACE_PARENT_HTTP_HEADER) - || requestHeaders.containsKey(Task.TRACE_STATE) - || transientHeaders.containsKey(Task.APM_TRACE_CONTEXT); - } - - private boolean hasParentTraceContext() { - return transientHeaders.containsKey(Task.PARENT_TRACE_PARENT_HEADER) - || transientHeaders.containsKey(Task.PARENT_TRACE_STATE) - || transientHeaders.containsKey(Task.PARENT_APM_TRACE_CONTEXT); - } - private void logWarningHeaderThresholdExceeded(long threshold, Setting thresholdSetting) { // If available, log some selected headers to help identifying the source of the request. // Note: Only Task.HEADERS_TO_COPY are guaranteed to be preserved at this point. @@ -963,7 +964,18 @@ private ThreadContextStruct putResponse( return new ThreadContextStruct(requestHeaders, newResponseHeaders, transientHeaders, isSystemContext, newWarningHeaderSize); } + private boolean hasApmTraceContext() { + return transientHeaders.containsKey(Task.APM_TRACE_CONTEXT); + } + + private boolean hasParentApmTraceContext() { + return transientHeaders.containsKey(Task.PARENT_APM_TRACE_CONTEXT); + } + private ThreadContextStruct putTransient(String key, Object value) { + assert key != Task.TRACE_START_TIME || (hasApmTraceContext() || hasParentApmTraceContext()) == false + : "trace.starttime cannot be set after a trace context is present"; + Map newTransient = new HashMap<>(this.transientHeaders); putSingleHeader(key, value, newTransient); return new ThreadContextStruct(requestHeaders, responseHeaders, newTransient, isSystemContext); diff --git a/server/src/main/java/org/elasticsearch/tasks/Task.java b/server/src/main/java/org/elasticsearch/tasks/Task.java index 23ac4de8b618a..09d2de99b0214 100644 --- a/server/src/main/java/org/elasticsearch/tasks/Task.java +++ b/server/src/main/java/org/elasticsearch/tasks/Task.java @@ -60,6 +60,10 @@ public class Task implements Traceable { */ public static final String TRACE_STATE = "tracestate"; + /** + * Optional transient header allowing to override the start time of the root trace. + * This is discarded when creating a new trace context once an APM trace context exists. + */ public static final String TRACE_START_TIME = "trace.starttime"; /** diff --git a/server/src/main/java/org/elasticsearch/tasks/TaskManager.java b/server/src/main/java/org/elasticsearch/tasks/TaskManager.java index b6216d4b7db12..521f91f830d16 100644 --- a/server/src/main/java/org/elasticsearch/tasks/TaskManager.java +++ b/server/src/main/java/org/elasticsearch/tasks/TaskManager.java @@ -139,7 +139,7 @@ public Task register(String type, String action, TaskAwareRequest request, boole long maxSize = maxHeaderSize.getBytes(); ThreadContext threadContext = threadPool.getThreadContext(); - assert threadContext.hasTraceContext() == false : "Expected threadContext to have no traceContext fields"; + assert threadContext.hasApmTraceContext() == false : "Expected threadContext to have no APM trace context"; for (String key : taskHeaders) { String httpHeader = threadContext.getHeader(key); @@ -181,7 +181,7 @@ public Task register(String type, String action, TaskAwareRequest request, boole * For REST actions this will be the case, otherwise {@link Tracer#startTrace} can be used. */ void maybeStartTrace(ThreadContext threadContext, Task task) { - if (threadContext.hasParentTraceContext() == false) { + if (threadContext.hasParentApmTraceContext() == false) { return; } TaskId parentTask = task.getParentTaskId(); diff --git a/server/src/test/java/org/elasticsearch/common/util/concurrent/ThreadContextTests.java b/server/src/test/java/org/elasticsearch/common/util/concurrent/ThreadContextTests.java index b963cac114c4b..cee975ab32bcd 100644 --- a/server/src/test/java/org/elasticsearch/common/util/concurrent/ThreadContextTests.java +++ b/server/src/test/java/org/elasticsearch/common/util/concurrent/ThreadContextTests.java @@ -21,6 +21,7 @@ import org.hamcrest.Matcher; import java.io.IOException; +import java.time.Instant; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -1189,20 +1190,23 @@ public void testNewTraceContext() { var rootTraceContext = Map.of(Task.TRACE_PARENT_HTTP_HEADER, randomIdentifier(), Task.TRACE_STATE, randomIdentifier()); var apmTraceContext = new Object(); + var traceStartTime = Instant.now(); var responseKey = randomIdentifier(); var responseValue = randomAlphaOfLength(10); threadContext.putHeader(rootTraceContext); + threadContext.putTransient(Task.TRACE_START_TIME, traceStartTime); threadContext.putTransient(Task.APM_TRACE_CONTEXT, apmTraceContext); - assertThat(threadContext.hasTraceContext(), equalTo(true)); - assertThat(threadContext.hasParentTraceContext(), equalTo(false)); + assertThat(threadContext.hasApmTraceContext(), equalTo(true)); + assertThat(threadContext.hasParentApmTraceContext(), equalTo(false)); try (var ignored = threadContext.newTraceContext()) { - assertThat(threadContext.hasTraceContext(), equalTo(false)); // no trace started yet - assertThat(threadContext.hasParentTraceContext(), equalTo(true)); + assertThat(threadContext.hasApmTraceContext(), equalTo(false)); // no trace started yet + assertThat(threadContext.hasParentApmTraceContext(), equalTo(true)); assertThat(threadContext.getHeaders(), is(anEmptyMap())); + // trace start time is not propagated assertThat( threadContext.getTransientHeaders(), equalTo( @@ -1220,11 +1224,14 @@ public void testNewTraceContext() { threadContext.addResponseHeader(responseKey, responseValue); } - assertThat(threadContext.hasTraceContext(), equalTo(true)); - assertThat(threadContext.hasParentTraceContext(), equalTo(false)); + assertThat(threadContext.hasApmTraceContext(), equalTo(true)); + assertThat(threadContext.hasParentApmTraceContext(), equalTo(false)); assertThat(threadContext.getHeaders(), equalTo(rootTraceContext)); - assertThat(threadContext.getTransientHeaders(), equalTo(Map.of(Task.APM_TRACE_CONTEXT, apmTraceContext))); + assertThat( + threadContext.getTransientHeaders(), + equalTo(Map.of(Task.APM_TRACE_CONTEXT, apmTraceContext, Task.TRACE_START_TIME, traceStartTime)) + ); assertThat(threadContext.getResponseHeaders(), equalTo(Map.of(responseKey, List.of(responseValue)))); } @@ -1234,13 +1241,13 @@ public void testNewTraceContextWithoutParentTrace() { var responseKey = randomIdentifier(); var responseValue = randomAlphaOfLength(10); - assertThat(threadContext.hasTraceContext(), equalTo(false)); - assertThat(threadContext.hasParentTraceContext(), equalTo(false)); + assertThat(threadContext.hasApmTraceContext(), equalTo(false)); + assertThat(threadContext.hasParentApmTraceContext(), equalTo(false)); try (var ignored = threadContext.newTraceContext()) { assertTrue(threadContext.isDefaultContext()); - assertThat(threadContext.hasTraceContext(), equalTo(false)); - assertThat(threadContext.hasParentTraceContext(), equalTo(false)); + assertThat(threadContext.hasApmTraceContext(), equalTo(false)); + assertThat(threadContext.hasParentApmTraceContext(), equalTo(false)); // discared, just making sure the context is isolated threadContext.putTransient(randomIdentifier(), randomAlphaOfLength(10)); diff --git a/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java b/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java index f38af77be6150..3a05ba07877bc 100644 --- a/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java +++ b/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java @@ -287,8 +287,8 @@ public void testRegisterTaskStartsTracingIfTraceParentExists() { final Tracer mockTracer = mock(Tracer.class); final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of(), mockTracer); - // fake a trace parent - threadPool.getThreadContext().putHeader(Task.TRACE_PARENT_HTTP_HEADER, "traceparent"); + // fake a parent APM trace context + threadPool.getThreadContext().putTransient(Task.PARENT_APM_TRACE_CONTEXT, null); final boolean hasParentTask = randomBoolean(); final TaskId parentTask = hasParentTask ? new TaskId("parentNode", 1) : TaskId.EMPTY_TASK_ID; @@ -366,8 +366,8 @@ public TaskId getParentTask() { } }); - // fake a trace context (trace parent) - threadPool.getThreadContext().putHeader(Task.TRACE_PARENT_HTTP_HEADER, "traceparent"); + // fake an APM trace context + threadPool.getThreadContext().putTransient(Task.APM_TRACE_CONTEXT, null); taskManager.unregister(task); verify(mockTracer).stopTrace(task); @@ -408,8 +408,8 @@ public void testRegisterAndExecuteStartsTracingIfTraceParentExists() { final Tracer mockTracer = mock(Tracer.class); final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of(), mockTracer); - // fake a trace parent - threadPool.getThreadContext().putHeader(Task.TRACE_PARENT_HTTP_HEADER, "traceparent"); + // fake a parent APM trace context + threadPool.getThreadContext().putTransient(Task.PARENT_APM_TRACE_CONTEXT, null); final Task task = taskManager.registerAndExecute( "testType",