diff --git a/dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java b/dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java index e68e2ac85ad..1aabdff4598 100644 --- a/dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java +++ b/dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java @@ -78,6 +78,15 @@ public BadConfigurationException(final String message) { static final String META_STRUCT_SDS = "sds"; static final String META_STRUCT_TAG_PROBS = "tag_probs"; + /** + * Anomaly detection tags copied from the local root span onto every {@code ai_guard} span with + * the {@code ai_guard.} prefix, so the AI Guard backend can correlate AI Guard requests with the + * request context (client IP, user, session) without depending on the local root span. + */ + static final String[] ANOMALY_DETECTION_TAGS = { + Tags.HTTP_CLIENT_IP, Tags.NETWORK_CLIENT_IP, Tags.HTTP_USER_AGENT, "usr.id", "usr.session_id" + }; + public static void install() { final Config config = Config.get(); final String apiKey = config.getApiKey(); @@ -241,6 +250,16 @@ private static void applyClientIpTags(final AgentSpan localRootSpan) { } } + private static void copyAnomalyDetectionTags( + final AgentSpan span, final AgentSpan localRootSpan) { + for (final String tag : ANOMALY_DETECTION_TAGS) { + final Object value = localRootSpan.getTag(tag); + if (value != null) { + span.setTag("ai_guard." + tag, value.toString()); + } + } + } + @Override public Evaluation evaluate(final List messages, final Options options) { if (messages == null || messages.isEmpty()) { @@ -258,6 +277,9 @@ public Evaluation evaluate(final List messages, final Options options) localRootSpan.setTag(Tags.AI_GUARD_KEEP, true); localRootSpan.setTag(EVENT_TAG, true); applyClientIpTags(localRootSpan); + // copyAnomalyDetectionTags MUST run after applyClientIpTags, to make + // sure client IP tags were populated. + copyAnomalyDetectionTags(span, localRootSpan); } try (final AgentScope scope = tracer.activateSpan(span)) { final Message last = messages.get(messages.size() - 1); diff --git a/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy b/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy index 8a9fdc297de..406b16c7025 100644 --- a/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy +++ b/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy @@ -279,15 +279,15 @@ class AIGuardInternalTests extends DDSpecification { final requestContext = Mock(RequestContext) localRootSpan.getRequestContext() >> requestContext requestContext.getClientIpAddressData() >> new ClientIpAddressData('4.4.4.4', '2.3.4.5') + localRootSpan.getTag(Tags.NETWORK_CLIENT_IP) >> null + localRootSpan.getTag(Tags.HTTP_CLIENT_IP) >> null final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'It is fine']]]) when: aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT) then: - 1 * localRootSpan.getTag(Tags.NETWORK_CLIENT_IP) >> null 1 * localRootSpan.setTag(Tags.NETWORK_CLIENT_IP, '4.4.4.4') - 1 * localRootSpan.getTag(Tags.HTTP_CLIENT_IP) >> null 1 * localRootSpan.setTag(Tags.HTTP_CLIENT_IP, '2.3.4.5') } @@ -296,15 +296,15 @@ class AIGuardInternalTests extends DDSpecification { final requestContext = Mock(RequestContext) localRootSpan.getRequestContext() >> requestContext requestContext.getClientIpAddressData() >> new ClientIpAddressData('4.4.4.4', '2.3.4.5') + localRootSpan.getTag(Tags.NETWORK_CLIENT_IP) >> '9.9.9.9' + localRootSpan.getTag(Tags.HTTP_CLIENT_IP) >> '8.8.8.8' final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'It is fine']]]) when: aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT) then: - 1 * localRootSpan.getTag(Tags.NETWORK_CLIENT_IP) >> '9.9.9.9' 0 * localRootSpan.setTag(Tags.NETWORK_CLIENT_IP, _) - 1 * localRootSpan.getTag(Tags.HTTP_CLIENT_IP) >> '8.8.8.8' 0 * localRootSpan.setTag(Tags.HTTP_CLIENT_IP, _) } @@ -336,6 +336,46 @@ class AIGuardInternalTests extends DDSpecification { 0 * localRootSpan.setTag(Tags.HTTP_CLIENT_IP, _) } + void 'test evaluate copies anomaly detection tags from local root span to ai_guard span'() { + given: + localRootSpan.getTag(Tags.HTTP_CLIENT_IP) >> '1.2.3.4' + localRootSpan.getTag(Tags.NETWORK_CLIENT_IP) >> '5.6.7.8' + localRootSpan.getTag(Tags.HTTP_USER_AGENT) >> 'curl/8.0' + localRootSpan.getTag('usr.id') >> 'u-123' + localRootSpan.getTag('usr.session_id') >> 's-456' + final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'It is fine']]]) + + when: + aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT) + + then: + 1 * span.setTag('ai_guard.http.client_ip', '1.2.3.4') + 1 * span.setTag('ai_guard.network.client.ip', '5.6.7.8') + 1 * span.setTag('ai_guard.http.useragent', 'curl/8.0') + 1 * span.setTag('ai_guard.usr.id', 'u-123') + 1 * span.setTag('ai_guard.usr.session_id', 's-456') + } + + void 'test evaluate skips missing anomaly detection tags'() { + given: + localRootSpan.getTag(Tags.HTTP_CLIENT_IP) >> '1.2.3.4' + localRootSpan.getTag(Tags.NETWORK_CLIENT_IP) >> null + localRootSpan.getTag(Tags.HTTP_USER_AGENT) >> null + localRootSpan.getTag('usr.id') >> 'u-123' + localRootSpan.getTag('usr.session_id') >> null + final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'It is fine']]]) + + when: + aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT) + + then: + 1 * span.setTag('ai_guard.http.client_ip', '1.2.3.4') + 1 * span.setTag('ai_guard.usr.id', 'u-123') + 0 * span.setTag('ai_guard.network.client.ip', _) + 0 * span.setTag('ai_guard.http.useragent', _) + 0 * span.setTag('ai_guard.usr.session_id', _) + } + void 'test evaluate with API errors'() { given: final errors = [[status: 400, title: 'Bad request']] diff --git a/dd-smoke-tests/appsec/springboot/build.gradle b/dd-smoke-tests/appsec/springboot/build.gradle index 2b2285a9ea4..94c44259a45 100644 --- a/dd-smoke-tests/appsec/springboot/build.gradle +++ b/dd-smoke-tests/appsec/springboot/build.gradle @@ -21,6 +21,10 @@ dependencies { implementation(group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version: '2.6.0') implementation group: 'com.h2database', name: 'h2', version: '2.1.212' + // Used by AIGuardController to set user/session tags on the local root span via the active OT span + implementation group: 'io.opentracing', name: 'opentracing-api', version: '0.32.0' + implementation group: 'io.opentracing', name: 'opentracing-util', version: '0.32.0' + // file upload implementation group: 'commons-fileupload', name: 'commons-fileupload', version: '1.5' diff --git a/dd-smoke-tests/appsec/springboot/src/main/java/datadog/smoketest/appsec/springboot/controller/AIGuardController.java b/dd-smoke-tests/appsec/springboot/src/main/java/datadog/smoketest/appsec/springboot/controller/AIGuardController.java index 34e7e035c15..4f972aecb88 100644 --- a/dd-smoke-tests/appsec/springboot/src/main/java/datadog/smoketest/appsec/springboot/controller/AIGuardController.java +++ b/dd-smoke-tests/appsec/springboot/src/main/java/datadog/smoketest/appsec/springboot/controller/AIGuardController.java @@ -8,6 +8,9 @@ import datadog.trace.api.aiguard.AIGuard.Evaluation; import datadog.trace.api.aiguard.AIGuard.Message; import datadog.trace.api.aiguard.AIGuard.Options; +import datadog.trace.api.interceptor.MutableSpan; +import io.opentracing.Span; +import io.opentracing.util.GlobalTracer; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -26,7 +29,19 @@ public class AIGuardController { @GetMapping(value = "/allow") - public ResponseEntity allow() { + public ResponseEntity allow( + @RequestHeader(name = "X-User-Id", required = false) final String userId, + @RequestHeader(name = "X-Session-Id", required = false) final String sessionId) { + final Span activeSpan = GlobalTracer.get().activeSpan(); + if (activeSpan instanceof MutableSpan) { + final MutableSpan rootSpan = ((MutableSpan) activeSpan).getLocalRootSpan(); + if (userId != null && !userId.isEmpty()) { + rootSpan.setTag("usr.id", userId); + } + if (sessionId != null && !sessionId.isEmpty()) { + rootSpan.setTag("usr.session_id", sessionId); + } + } final Evaluation result = AIGuard.evaluate( asList( diff --git a/dd-smoke-tests/appsec/springboot/src/test/groovy/datadog/smoketest/appsec/AIGuardSmokeTest.groovy b/dd-smoke-tests/appsec/springboot/src/test/groovy/datadog/smoketest/appsec/AIGuardSmokeTest.groovy index d1a3f84ec18..bfb3e62ca69 100644 --- a/dd-smoke-tests/appsec/springboot/src/test/groovy/datadog/smoketest/appsec/AIGuardSmokeTest.groovy +++ b/dd-smoke-tests/appsec/springboot/src/test/groovy/datadog/smoketest/appsec/AIGuardSmokeTest.groovy @@ -167,6 +167,52 @@ class AIGuardSmokeTest extends AbstractAppSecServerSmokeTest { rootSpan.meta.get('network.client.ip') != publicIp } + void 'anomaly detection tags are copied from the local root span to the ai_guard span'() { + given: + final publicIp = '5.6.7.9' + final userId = 'u12345' + final sessionId = 's12345' + final userAgent = 'AIGuardSmokeTest/1.0' + final request = new Request.Builder() + .url("http://localhost:${httpPort}/aiguard/allow") + .header('X-Forwarded-For', publicIp) + .header('X-User-Id', userId) + .header('X-Session-Id', sessionId) + .header('User-Agent', userAgent) + .get() + .build() + + when: + final response = client.newCall(request).execute() + + then: + response.code() == 200 + + and: + waitForTraceCount(2) // /aiguard/allow + internal /aiguard/evaluate mock + final aiGuardSpan = traces*.spans + ?.flatten() + ?.find { it.resource == 'ai_guard' } as DecodedSpan + aiGuardSpan != null + final rootSpan = traces*.spans + ?.flatten() + ?.find { it.traceId == aiGuardSpan.traceId && it.parentId == 0 } as DecodedSpan + rootSpan != null + + // Tags must match what is on the root span + aiGuardSpan.meta.get('ai_guard.http.client_ip') == rootSpan.meta.get('http.client_ip') + aiGuardSpan.meta.get('ai_guard.network.client.ip') == rootSpan.meta.get('network.client.ip') + aiGuardSpan.meta.get('ai_guard.http.useragent') == rootSpan.meta.get('http.useragent') + aiGuardSpan.meta.get('ai_guard.usr.id') == rootSpan.meta.get('usr.id') + aiGuardSpan.meta.get('ai_guard.usr.session_id') == rootSpan.meta.get('usr.session_id') + + // And carry the expected values + aiGuardSpan.meta.get('ai_guard.http.client_ip') == publicIp + aiGuardSpan.meta.get('ai_guard.http.useragent') == userAgent + aiGuardSpan.meta.get('ai_guard.usr.id') == userId + aiGuardSpan.meta.get('ai_guard.usr.session_id') == sessionId + } + void 'test multimodal content parts evaluation'() { given: def request = new Request.Builder()