From 44a39280b8c5bd803e56c2c7674958393605904e Mon Sep 17 00:00:00 2001 From: Santiago Mola Date: Fri, 8 May 2026 10:31:22 +0200 Subject: [PATCH 1/3] test(aiguard): migrate AIGuardSystemTests to JUnit 5 Convert AIGuardSystemTests.groovy to a Java JUnit 5 test extending DDJavaSpecification, using WithConfigExtension for env injection. Add junit5/mockito test deps to the agent-aiguard module and start tracking the module in .github/g2j-migrated-modules.txt. --- .github/g2j-migrated-modules.txt | 1 + dd-java-agent/agent-aiguard/build.gradle | 2 ++ .../datadog/aiguard/AIGuardSystemTests.groovy | 22 -------------- .../datadog/aiguard/AIGuardSystemTests.java | 30 +++++++++++++++++++ 4 files changed, 33 insertions(+), 22 deletions(-) create mode 100644 .github/g2j-migrated-modules.txt delete mode 100644 dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardSystemTests.groovy create mode 100644 dd-java-agent/agent-aiguard/src/test/java/com/datadog/aiguard/AIGuardSystemTests.java diff --git a/.github/g2j-migrated-modules.txt b/.github/g2j-migrated-modules.txt new file mode 100644 index 00000000000..7b864e53b35 --- /dev/null +++ b/.github/g2j-migrated-modules.txt @@ -0,0 +1 @@ +dd-java-agent/agent-aiguard diff --git a/dd-java-agent/agent-aiguard/build.gradle b/dd-java-agent/agent-aiguard/build.gradle index 5e4841dbf3c..890c42ee444 100644 --- a/dd-java-agent/agent-aiguard/build.gradle +++ b/dd-java-agent/agent-aiguard/build.gradle @@ -23,6 +23,8 @@ dependencies { implementation project(':communication') testImplementation project(':utils:test-utils') + testImplementation libs.bundles.junit5 + testImplementation libs.bundles.mockito testImplementation('org.skyscreamer:jsonassert:1.5.3') testImplementation('com.fasterxml.jackson.core:jackson-databind:2.20.0') } diff --git a/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardSystemTests.groovy b/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardSystemTests.groovy deleted file mode 100644 index 0929422df52..00000000000 --- a/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardSystemTests.groovy +++ /dev/null @@ -1,22 +0,0 @@ -package com.datadog.aiguard - -import datadog.trace.api.aiguard.AIGuard -import datadog.trace.test.util.DDSpecification - -class AIGuardSystemTests extends DDSpecification { - - void cleanup() { - AIGuardInternal.uninstall() - } - - void 'test SDK initialization'() { - injectEnvConfig('API_KEY', 'api') - injectEnvConfig('APP_KEY', 'app') - - when: - AIGuardSystem.start() - - then: - AIGuard.EVALUATOR instanceof AIGuardInternal - } -} diff --git a/dd-java-agent/agent-aiguard/src/test/java/com/datadog/aiguard/AIGuardSystemTests.java b/dd-java-agent/agent-aiguard/src/test/java/com/datadog/aiguard/AIGuardSystemTests.java new file mode 100644 index 00000000000..6cb3e6e078b --- /dev/null +++ b/dd-java-agent/agent-aiguard/src/test/java/com/datadog/aiguard/AIGuardSystemTests.java @@ -0,0 +1,30 @@ +package com.datadog.aiguard; + +import static datadog.trace.junit.utils.config.WithConfigExtension.injectEnvConfig; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import datadog.trace.api.aiguard.AIGuard; +import datadog.trace.test.util.DDJavaSpecification; +import java.lang.reflect.Field; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +class AIGuardSystemTests extends DDJavaSpecification { + + @AfterEach + void uninstallEvaluator() { + AIGuardInternal.uninstall(); + } + + @Test + void testSdkInitialization() throws ReflectiveOperationException { + injectEnvConfig("API_KEY", "api"); + injectEnvConfig("APP_KEY", "app"); + + AIGuardSystem.start(); + + Field evaluator = AIGuard.class.getDeclaredField("EVALUATOR"); + evaluator.setAccessible(true); + assertTrue(evaluator.get(null) instanceof AIGuardInternal); + } +} From 88f11dd154b1bd4654e161f42c01de0f7a1a2885 Mon Sep 17 00:00:00 2001 From: Santiago Mola Date: Fri, 8 May 2026 11:30:09 +0200 Subject: [PATCH 2/3] test(aiguard): migrate AIGuardInternalTests to JUnit 5 Convert the 58-test Spock spec into a JUnit 5 Java test extending DDJavaSpecification. Use @TableTest for tabular cases (missing keys, endpoint discovery), @MethodSource for parameterized cases that need non-tabular data (action/blocking/suite combinations, options, sds-finding shapes), and Mockito with ArgumentCaptor for the metaStruct capture pattern previously expressed as Spock closures. --- .../aiguard/AIGuardInternalTests.groovy | 896 ------------- .../datadog/aiguard/AIGuardInternalTests.java | 1148 +++++++++++++++++ 2 files changed, 1148 insertions(+), 896 deletions(-) delete mode 100644 dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy create mode 100644 dd-java-agent/agent-aiguard/src/test/java/com/datadog/aiguard/AIGuardInternalTests.java diff --git a/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy b/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy deleted file mode 100644 index 8a9fdc297de..00000000000 --- a/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy +++ /dev/null @@ -1,896 +0,0 @@ -package com.datadog.aiguard - -import com.fasterxml.jackson.annotation.JsonInclude -import com.fasterxml.jackson.databind.ObjectMapper -import com.fasterxml.jackson.databind.PropertyNamingStrategies -import com.squareup.moshi.Moshi -import datadog.common.version.VersionInfo -import datadog.trace.api.Config -import datadog.trace.api.aiguard.AIGuard -import datadog.trace.api.gateway.RequestContext -import datadog.trace.api.telemetry.WafMetricCollector -import datadog.trace.bootstrap.instrumentation.api.AgentSpan -import datadog.trace.bootstrap.instrumentation.api.AgentTracer -import datadog.trace.bootstrap.instrumentation.api.ClientIpAddressData -import datadog.trace.bootstrap.instrumentation.api.Tags -import datadog.trace.test.util.DDSpecification -import okhttp3.Call -import okhttp3.HttpUrl -import okhttp3.MediaType -import okhttp3.OkHttpClient -import okhttp3.Protocol -import okhttp3.Request -import okhttp3.RequestBody -import okhttp3.Response -import okhttp3.ResponseBody -import okio.Okio -import spock.lang.Shared - -import org.skyscreamer.jsonassert.JSONAssert -import org.skyscreamer.jsonassert.JSONCompareMode - -import static datadog.trace.api.aiguard.AIGuard.Action.ABORT -import static datadog.trace.api.aiguard.AIGuard.Action.ALLOW -import static datadog.trace.api.aiguard.AIGuard.Action.DENY -import static org.codehaus.groovy.runtime.DefaultGroovyMethods.combinations - -class AIGuardInternalTests extends DDSpecification { - - @Shared - protected static final URL = HttpUrl.parse('https://app.datadoghq.com/api/v2/ai-guard/evaluate') - - @Shared - protected static final HEADERS = ['DD-API-KEY': 'api', - 'DD-APPLICATION-KEY': 'app', - 'DD-AI-GUARD-VERSION': VersionInfo.VERSION, - 'DD-AI-GUARD-SOURCE': 'SDK', - 'DD-AI-GUARD-LANGUAGE': 'jvm'] - - @Shared - protected static final ORIGINAL_TRACER = AgentTracer.get() - - @Shared - protected static final MOSHI = new Moshi.Builder().build() - - @Shared - protected static final MAPPER = new ObjectMapper() - .setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE) - .setDefaultPropertyInclusion( - JsonInclude.Value.construct(JsonInclude.Include.NON_NULL, JsonInclude.Include.NON_NULL) - ) - - @Shared - protected static final TOOL_CALL = [ - AIGuard.Message.message('system', 'You are a beautiful AI assistant'), - AIGuard.Message.message('user', 'What is 2 + 2'), - AIGuard.Message.assistant( - AIGuard.ToolCall.toolCall('call_1', 'calc', '{ "operator": "+", "args": [2, 2] }') - ) - ] - - @Shared - protected static final TOOL_OUTPUT = TOOL_CALL + [AIGuard.Message.tool('call_1', '5')] - - @Shared - protected static final PROMPT = TOOL_OUTPUT + [AIGuard.Message.message('assistant', '2 + 2 is 5'), AIGuard.Message.message('user', '')] - - protected AgentSpan span - protected AgentSpan localRootSpan - - void setup() { - injectEnvConfig('SERVICE', 'ai_guard_test') - injectEnvConfig('ENV', 'test') - - span = Mock(AgentSpan) - localRootSpan = Mock(AgentSpan) - span.getLocalRootSpan() >> localRootSpan - final builder = Mock(AgentTracer.SpanBuilder) { - start() >> span - } - final tracer = Stub(AgentTracer.TracerAPI) { - buildSpan(_ as String, _ as String) >> builder - } - AgentTracer.forceRegister(tracer) - - WafMetricCollector.get().tap { - prepareMetrics() - drain() - } - } - - void cleanup() { - AgentTracer.forceRegister(ORIGINAL_TRACER) - AIGuardInternal.uninstall() - } - - void 'test missing api/app keys'() { - given: - if (apiKey) { - injectEnvConfig('API_KEY', apiKey) - } - if (appKey) { - injectEnvConfig('APP_KEY', appKey) - } - - when: - AIGuardInternal.install() - - then: - thrown(AIGuardInternal.BadConfigurationException) - - where: - apiKey | appKey - 'apiKey' | null - 'apiKey' | '' - null | 'appKey' - '' | 'appKey' - null | null - '' | '' - } - - void 'test endpoint discovery'() { - given: - injectEnvConfig('API_KEY', 'api') - injectEnvConfig('APP_KEY', 'app') - if (endpoint != null) { - injectEnvConfig("AI_GUARD_ENDPOINT", endpoint) - } else { - removeEnvConfig("AI_GUARD_ENDPOINT") - } - if (site != null) { - injectEnvConfig('SITE', site) - } else { - removeEnvConfig('SITE') - } - - when: - AIGuardInternal.install() - - then: - final internal = (AIGuardInternal) AIGuard.EVALUATOR - internal.url.toString() == expected - - where: - endpoint | site | expected - 'https://test' | null | 'https://test/evaluate' - null | null | 'https://app.datadoghq.com/api/v2/ai-guard/evaluate' - null | 'datadoghq.com' | 'https://app.datadoghq.com/api/v2/ai-guard/evaluate' - null | 'datad0g.com' | 'https://app.datad0g.com/api/v2/ai-guard/evaluate' - } - - void 'test evaluate'() { - given: - Request request = null - Throwable error = null - AIGuard.Evaluation eval = null - Map receivedMeta = null - final throwAbortError = suite.blocking && suite.action != ALLOW - final call = Mock(Call) { - execute() >> { - return mockResponse( - request, - 200, - [data: [attributes: [action: suite.action, reason: suite.reason, tags: suite.tags ?: [], tag_probs: suite.tagProbabilities ?: [:], is_blocking_enabled: suite.blocking]]] - ) - } - } - final client = Mock(OkHttpClient) { - newCall(_ as Request) >> { - request = (Request) it[0] - return call - } - } - final aiguard = new AIGuardInternal(URL, HEADERS, client) - - when: - try { - eval = aiguard.evaluate(suite.messages, new AIGuard.Options().block(suite.blocking)) - } catch (Throwable e) { - error = e - } - - then: - 1 * span.setTag(AIGuardInternal.TARGET_TAG, suite.target) - 1 * localRootSpan.setTag(Tags.AI_GUARD_KEEP, true) - 1 * localRootSpan.setTag(AIGuardInternal.EVENT_TAG, true) - if (suite.target == 'tool') { - 1 * span.setTag(AIGuardInternal.TOOL_TAG, 'calc') - } - 1 * span.setTag(AIGuardInternal.ACTION_TAG, suite.action) - 1 * span.setTag(AIGuardInternal.REASON_TAG, suite.reason) - 1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, _ as Map) >> { - receivedMeta = it[1] as Map - return span - } - if (throwAbortError) { - 1 * span.addThrowable(_ as AIGuard.AIGuardAbortError) - } - - assertMeta(receivedMeta, suite) - assertRequest(request, suite.messages) - if (throwAbortError) { - error instanceof AIGuard.AIGuardAbortError - error.action == suite.action - error.reason == suite.reason - error.tags == suite.tags - error.tagProbabilities == suite.tagProbabilities - error.sds == [] - } else { - error == null - eval.action == suite.action - eval.reason == suite.reason - eval.tags == suite.tags - eval.tagProbabilities == suite.tagProbabilities - eval.sds == [] - } - assertTelemetry('ai_guard.requests', "action:$suite.action", "block:$throwAbortError", 'error:false') - - where: - suite << TestSuite.build() - } - - void 'test evaluate block defaults to remote is_blocking_enabled'() { - given: - def request - final call = Mock(Call) { - execute() >> { - return mockResponse( - request, - 200, - [data: [attributes: [action: 'DENY', reason: 'Nope', tags: ['deny_everything'], is_blocking_enabled: remoteBlocking]]] - ) - } - } - final client = Mock(OkHttpClient) { - newCall(_ as Request) >> { - request = (Request) it[0] - return call - } - } - final aiguard = new AIGuardInternal(URL, HEADERS, client) - - when: - Throwable error = null - AIGuard.Evaluation eval = null - try { - eval = aiguard.evaluate(TOOL_CALL, options) - } catch (Throwable e) { - error = e - } - - then: - if (shouldBlock) { - error instanceof AIGuard.AIGuardAbortError - error.action == DENY - } else { - error == null - eval.action == DENY - } - - where: - options | remoteBlocking | shouldBlock - AIGuard.Options.DEFAULT | true | true - AIGuard.Options.DEFAULT | false | false - new AIGuard.Options().block(false) | true | false - } - - void 'test evaluate applies captured client ip tags to local root span'() { - given: - final requestContext = Mock(RequestContext) - localRootSpan.getRequestContext() >> requestContext - requestContext.getClientIpAddressData() >> new ClientIpAddressData('4.4.4.4', '2.3.4.5') - final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'It is fine']]]) - - when: - aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT) - - then: - 1 * localRootSpan.getTag(Tags.NETWORK_CLIENT_IP) >> null - 1 * localRootSpan.setTag(Tags.NETWORK_CLIENT_IP, '4.4.4.4') - 1 * localRootSpan.getTag(Tags.HTTP_CLIENT_IP) >> null - 1 * localRootSpan.setTag(Tags.HTTP_CLIENT_IP, '2.3.4.5') - } - - void 'test evaluate does not overwrite existing client ip tags'() { - given: - final requestContext = Mock(RequestContext) - localRootSpan.getRequestContext() >> requestContext - requestContext.getClientIpAddressData() >> new ClientIpAddressData('4.4.4.4', '2.3.4.5') - final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'It is fine']]]) - - when: - aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT) - - then: - 1 * localRootSpan.getTag(Tags.NETWORK_CLIENT_IP) >> '9.9.9.9' - 0 * localRootSpan.setTag(Tags.NETWORK_CLIENT_IP, _) - 1 * localRootSpan.getTag(Tags.HTTP_CLIENT_IP) >> '8.8.8.8' - 0 * localRootSpan.setTag(Tags.HTTP_CLIENT_IP, _) - } - - void 'test evaluate is a noop for client ip tags when no data captured'() { - given: - final requestContext = Mock(RequestContext) - localRootSpan.getRequestContext() >> requestContext - requestContext.getClientIpAddressData() >> null - final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'It is fine']]]) - - when: - aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT) - - then: - 0 * localRootSpan.setTag(Tags.NETWORK_CLIENT_IP, _) - 0 * localRootSpan.setTag(Tags.HTTP_CLIENT_IP, _) - } - - void 'test evaluate is a noop for client ip tags when no request context'() { - given: - localRootSpan.getRequestContext() >> null - final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'It is fine']]]) - - when: - aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT) - - then: - 0 * localRootSpan.setTag(Tags.NETWORK_CLIENT_IP, _) - 0 * localRootSpan.setTag(Tags.HTTP_CLIENT_IP, _) - } - - void 'test evaluate with API errors'() { - given: - final errors = [[status: 400, title: 'Bad request']] - final aiguard = mockClient(404, [errors: errors]) - - when: - aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT) - - then: - final exception = thrown(AIGuard.AIGuardClientError) - exception.errors == errors - 1 * span.addThrowable(_ as AIGuard.AIGuardClientError) - assertTelemetry('ai_guard.requests', 'error:true') - } - - void 'test evaluate with invalid JSON'() { - given: - final aiguard = mockClient(200, [bad: 'This is an invalid response']) - - when: - aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT) - - then: - thrown(AIGuard.AIGuardClientError) - 1 * span.addThrowable(_ as AIGuard.AIGuardClientError) - assertTelemetry('ai_guard.requests', 'error:true') - } - - void 'test evaluate with missing action'() { - given: - final aiguard = mockClient(200, [data: [attributes: [reason: 'I miss something']]]) - - when: - aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT) - - then: - thrown(AIGuard.AIGuardClientError) - 1 * span.addThrowable(_ as AIGuard.AIGuardClientError) - assertTelemetry('ai_guard.requests', 'error:true') - } - - void 'test evaluate with non JSON response'() { - given: - final aiguard = mockClient(200, 'I am no JSON') - - when: - aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT) - - then: - thrown(AIGuard.AIGuardClientError) - 1 * span.addThrowable(_ as AIGuard.AIGuardClientError) - assertTelemetry('ai_guard.requests', 'error:true') - } - - void 'test evaluate with empty response'() { - given: - final aiguard = mockClient(200, null) - - when: - aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT) - - then: - thrown(AIGuard.AIGuardClientError) - 1 * span.addThrowable(_ as AIGuard.AIGuardClientError) - assertTelemetry('ai_guard.requests', 'error:true') - } - - void 'test message length truncation'() { - given: - final maxMessages = Config.get().getAiGuardMaxMessagesLength() - final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'It is fine']]]) - final messages = (0..maxMessages) - .collect { AIGuard.Message.message('user', "This is a prompt: ${it}") } - .toList() - - when: - aiguard.evaluate(messages, AIGuard.Options.DEFAULT) - - then: - 1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, _) >> { - final received = (List) it[1].messages - assert received.size() == maxMessages - assert received.size() < messages.size() - } - assertTelemetry('ai_guard.truncated', 'type:messages') - } - - void 'test message content truncation'() { - given: - final maxContent = Config.get().getAiGuardMaxContentSize() - final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'It is fine']]]) - final message = AIGuard.Message.message("user", (0..maxContent).collect { 'A' }.join()) - - when: - aiguard.evaluate([message], AIGuard.Options.DEFAULT) - - then: - 1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, _) >> { - final received = (List) it[1].messages - received.last().with { - assert it.content.length() == maxContent - assert it.content.length() < message.content.length() - } - } - assertTelemetry('ai_guard.truncated', 'type:content') - } - - void 'test no messages'() { - given: - final aiguard = new AIGuardInternal(URL, HEADERS, Stub(OkHttpClient)) - - when: - aiguard.evaluate(messages, AIGuard.Options.DEFAULT) - - then: - thrown(IllegalArgumentException) - - - where: - messages << [[], null] - } - - void 'test evaluate with sds findings'() { - given: - final sdsFindings = [ - [ - rule_display_name: 'Credit Card Number', - rule_tag: 'credit_card', - category: 'pii', - matched_text: '4111111111111111', - location: [start_index: 10, end_index_exclusive: 26, path: 'messages[0].content[0].text'] - ], - [ - rule_display_name: 'Social Security Number', - rule_tag: 'ssn', - category: 'pii', - matched_text: '123-45-6789', - location: [start_index: 30, end_index_exclusive: 41, path: 'messages[1].tool_calls[0].function.arguments'] - ] - ] - final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'It is fine', sds_findings: sdsFindings]]]) - Map receivedMeta - - when: - final result = aiguard.evaluate(PROMPT, AIGuard.Options.DEFAULT) - - then: - 1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, _) >> { - receivedMeta = it[1] as Map - return span - } - receivedMeta.sds == sdsFindings - result.sds == sdsFindings - } - - void 'test evaluate with empty sds findings'() { - given: - final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'It is fine', sds_findings: sdsFindings]]]) - Map receivedMeta - - when: - final result = aiguard.evaluate(PROMPT, AIGuard.Options.DEFAULT) - - then: - 1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, _) >> { - receivedMeta = it[1] as Map - return span - } - !receivedMeta.containsKey('sds') - result.sds == (sdsFindings ?: []) - - where: - sdsFindings << [null, []] - } - - void 'test evaluate with sds findings in abort error'() { - given: - final sdsFindings = [ - [ - rule_display_name: 'Credit Card Number', - rule_tag: 'credit_card', - category: 'pii', - matched_text: '4111111111111111', - location: [start_index: 10, end_index_exclusive: 26, path: 'messages[0].content[0].text'] - ] - ] - final aiguard = mockClient(200, [data: [attributes: [action: 'ABORT', reason: 'PII detected', tags: ['pii'], sds_findings: sdsFindings, is_blocking_enabled: true]]]) - - when: - aiguard.evaluate(PROMPT, new AIGuard.Options().block(true)) - - then: - final error = thrown(AIGuard.AIGuardAbortError) - error.sds == sdsFindings - } - - void 'test missing tool name'() { - given: - final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'Just do it']]]) - - when: - aiguard.evaluate([AIGuard.Message.tool('call_1', 'Content')], AIGuard.Options.DEFAULT) - - then: - 1 * span.setTag(AIGuardInternal.TARGET_TAG, 'tool') - 0 * span.setTag(AIGuardInternal.TOOL_TAG, _) - } - - void 'map requires even number of params'() { - when: - AIGuardInternal.mapOf('1', '2', '3') - - then: - thrown(IllegalArgumentException) - } - - void 'test message immutability'() { - given: - final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'Just do it']]]) - final messages = [ - new AIGuard.Message( - "assistant", - (String) null, - [AIGuard.ToolCall.toolCall('call_1', 'execute_shell', '{"cmd": "ls -lah"}')], - null - ) - ] - Map receivedMeta - - when: - aiguard.evaluate(messages, AIGuard.Options.DEFAULT) - - then: - 1 * span.finish() >> { - // modify the messages before serialization - messages.first().toolCalls.add( - AIGuard.ToolCall.toolCall('call_2', 'execute_shell', '{"cmd": "rm -rf"}') - ) - messages.add(AIGuard.Message.tool('call_1', 'dir1, dir2, dir3')) - } - 1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, _ as Map) >> { - receivedMeta = it[1] as Map - return span - } - final metaStructMessages = receivedMeta.messages as List - metaStructMessages.size() != messages.size() - metaStructMessages.size() == 1 - metaStructMessages.first().toolCalls.size() != messages.first().toolCalls.size() - metaStructMessages.first().toolCalls.size() == 1 - } - - private AIGuardInternal mockClient(final int status, final Object response) { - def request - final call = Stub(Call) { - execute() >> { - return mockResponse(request, status, response) - } - } - final client = Stub(OkHttpClient) { - newCall(_ as Request) >> { - request = (Request) it[0] - return call - } - } - return new AIGuardInternal(URL, HEADERS, client) - } - - private static assertTelemetry(final String metric, final String...tags) { - final metrics = WafMetricCollector.get().with { - prepareMetrics() - drain() - } - final filtered = metrics.findAll { - it.namespace == 'appsec' - && it.metricName == metric - && it.tags == tags.toList() - } - assert filtered.size() == 1 : metrics - assert filtered*.value.sum() == 1 - return true - } - - private static assertMeta(final Map meta, final TestSuite suite) { - if (suite.tags) { - assert meta.attack_categories == suite.tags - } - if (suite.tagProbabilities) { - assert meta.tag_probs == suite.tagProbabilities - } - final receivedMessages = snakeCaseJson(meta.messages) - final expectedMessages = snakeCaseJson(suite.messages) - JSONAssert.assertEquals(expectedMessages, receivedMessages, JSONCompareMode.NON_EXTENSIBLE) - return true - } - - private static assertRequest(final Request request, final List messages) { - assert request.url() == URL - assert request.method() == 'POST' - HEADERS.each { entry -> - assert request.header(entry.key) == entry.value - } - assert request.body().contentType().toString().contains('application/json') - final receivedBody = readRequestBody(request.body()) - final expectedBody = snakeCaseJson([data: [attributes: [messages: messages, meta: [service: 'ai_guard_test', env: 'test']]]]) - JSONAssert.assertEquals(expectedBody, receivedBody, JSONCompareMode.NON_EXTENSIBLE) - return true - } - - private static String snakeCaseJson(final Object value) { - MAPPER.writeValueAsString(value) - } - - private static String readRequestBody(final RequestBody body) { - final output = new ByteArrayOutputStream() - final buffer = Okio.buffer(Okio.sink(output)) - body.writeTo(buffer) - buffer.flush() - return new String(output.toByteArray()) - } - - private static Response mockResponse(final Request request, final int status, final Object body) { - return new Response.Builder() - .protocol(Protocol.HTTP_1_1) - .message('ok') - .request(request) - .code(status) - .body(body == null ? null : ResponseBody.create(MediaType.parse('application/json'), MOSHI.adapter(Object).toJson(body))) - .build() - } - - void 'test JSON serialization with text content parts'() { - given: - final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'Good']]]) - final messages = [AIGuard.Message.message('user', [AIGuard.ContentPart.text('Hello world')])] - - when: - aiguard.evaluate(messages, AIGuard.Options.DEFAULT) - - then: - 1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, _) >> { - final meta = it[1] as Map - final receivedMessages = meta.messages as List - assert receivedMessages.size() == 1 - assert receivedMessages[0].contentParts.size() == 1 - assert receivedMessages[0].contentParts[0].type == AIGuard.ContentPart.Type.TEXT - assert receivedMessages[0].contentParts[0].text == 'Hello world' - return span - } - } - - void 'test JSON serialization with image_url content parts'() { - given: - final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'Good']]]) - final messages = [ - AIGuard.Message.message('user', [AIGuard.ContentPart.imageUrl('https://example.com/image.jpg')]) - ] - - when: - aiguard.evaluate(messages, AIGuard.Options.DEFAULT) - - then: - 1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, _) >> { - final meta = it[1] as Map - final receivedMessages = meta.messages as List - assert receivedMessages.size() == 1 - assert receivedMessages[0].contentParts.size() == 1 - assert receivedMessages[0].contentParts[0].type == AIGuard.ContentPart.Type.IMAGE_URL - assert receivedMessages[0].contentParts[0].imageUrl.url == 'https://example.com/image.jpg' - return span - } - } - - void 'test JSON serialization with mixed content parts'() { - given: - final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'Good']]]) - final messages = [ - AIGuard.Message.message('user', [ - AIGuard.ContentPart.text('Describe this image:'), - AIGuard.ContentPart.imageUrl('https://example.com/image.jpg'), - AIGuard.ContentPart.text('What do you see?') - ]) - ] - - when: - aiguard.evaluate(messages, AIGuard.Options.DEFAULT) - - then: - 1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, _) >> { - final meta = it[1] as Map - final receivedMessages = meta.messages as List - assert receivedMessages.size() == 1 - assert receivedMessages[0].contentParts.size() == 3 - assert receivedMessages[0].contentParts[0].type == AIGuard.ContentPart.Type.TEXT - assert receivedMessages[0].contentParts[0].text == 'Describe this image:' - assert receivedMessages[0].contentParts[1].type == AIGuard.ContentPart.Type.IMAGE_URL - assert receivedMessages[0].contentParts[1].imageUrl.url == 'https://example.com/image.jpg' - assert receivedMessages[0].contentParts[2].type == AIGuard.ContentPart.Type.TEXT - assert receivedMessages[0].contentParts[2].text == 'What do you see?' - return span - } - } - - void 'test content parts order is preserved'() { - given: - final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'Good']]]) - final parts = (0..4).collect { - it % 2 == 0 ? AIGuard.ContentPart.text("Text $it") : AIGuard.ContentPart.imageUrl("https://example.com/image${it}.jpg") - } - final messages = [AIGuard.Message.message('user', parts)] - - when: - aiguard.evaluate(messages, AIGuard.Options.DEFAULT) - - then: - 1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, _) >> { - final meta = it[1] as Map - final receivedMessages = meta.messages as List - assert receivedMessages[0].contentParts.size() == 5 - (0..4).each { i -> - if (i % 2 == 0) { - assert receivedMessages[0].contentParts[i].type == AIGuard.ContentPart.Type.TEXT - assert receivedMessages[0].contentParts[i].text == "Text $i" - } else { - assert receivedMessages[0].contentParts[i].type == AIGuard.ContentPart.Type.IMAGE_URL - assert receivedMessages[0].contentParts[i].imageUrl.url == "https://example.com/image${i}.jpg" - } - } - return span - } - } - - void 'test content part text truncation'() { - given: - final maxContent = Config.get().getAiGuardMaxContentSize() - final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'Good']]]) - final longText = (0..maxContent).collect { 'A' }.join() - final messages = [ - AIGuard.Message.message('user', [AIGuard.ContentPart.text(longText), AIGuard.ContentPart.text('Short text')]) - ] - - when: - aiguard.evaluate(messages, AIGuard.Options.DEFAULT) - - then: - 1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, _) >> { - final meta = it[1] as Map - final receivedMessages = meta.messages as List - assert receivedMessages[0].contentParts.size() == 2 - assert receivedMessages[0].contentParts[0].text.length() == maxContent - assert receivedMessages[0].contentParts[0].text.length() < longText.length() - assert receivedMessages[0].contentParts[1].text == 'Short text' - return span - } - assertTelemetry('ai_guard.truncated', 'type:content') - } - - void 'test content part image_url not truncated even with long data URI'() { - given: - final maxContent = Config.get().getAiGuardMaxContentSize() - final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'Good']]]) - // Create a very long data URI (longer than max content size) - final longDataUri = 'data:image/png;base64,' + (0..(maxContent + 1000)).collect { 'A' }.join() - final messages = [ - AIGuard.Message.message('user', [ - AIGuard.ContentPart.text('Image:'), - AIGuard.ContentPart.imageUrl(longDataUri) - ]) - ] - - when: - aiguard.evaluate(messages, AIGuard.Options.DEFAULT) - - then: - 1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, _) >> { - final meta = it[1] as Map - final receivedMessages = meta.messages as List - assert receivedMessages[0].contentParts.size() == 2 - assert receivedMessages[0].contentParts[1].type == AIGuard.ContentPart.Type.IMAGE_URL - // Image URL should NOT be truncated - assert receivedMessages[0].contentParts[1].imageUrl.url == longDataUri - assert receivedMessages[0].contentParts[1].imageUrl.url.length() > maxContent - return span - } - } - - void 'test backward compatibility with string content'() { - given: - final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'Good']]]) - final messages = [AIGuard.Message.message('user', 'Hello world')] - - when: - aiguard.evaluate(messages, AIGuard.Options.DEFAULT) - - then: - 1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, _) >> { - final meta = it[1] as Map - final receivedMessages = meta.messages as List - assert receivedMessages.size() == 1 - assert receivedMessages[0].content == 'Hello world' - assert receivedMessages[0].contentParts == null - return span - } - } - - private static class TestSuite { - private final AIGuard.Action action - private final String reason - private final List tags - private final Map tagProbabilities - private final boolean blocking - private final String description - private final String target - private final List messages - - TestSuite(AIGuard.Action action, String reason, Map tagProbabilities, boolean blocking, String description, String target, List messages) { - this.action = action - this.reason = reason - this.tags = new ArrayList<>(tagProbabilities.keySet()) - this.tagProbabilities = tagProbabilities - this.blocking = blocking - this.description = description - this.target = target - this.messages = messages - } - - static List build() { - def actionValues = [ - [ALLOW, 'Go ahead', [:]], - [DENY, 'Nope', ['deny_everything': 0.2D, 'test_deny': 0.8D]], - [ABORT, 'Kill it with fire', ['alarm_tag': 0.1D, 'abort_everything': 0.9D]] - ] - def blockingValues = [true, false] - def suiteValues = [ - ['tool call', 'tool', TOOL_CALL], - ['tool output', 'tool', TOOL_OUTPUT], - ['prompt', 'prompt', PROMPT] - ] - return combinations([actionValues, blockingValues, suiteValues] as Iterable) - .collect { action, blocking, suite -> - new TestSuite(action[0], action[1], action[2], blocking, suite[0], suite[1], suite[2]) - } - } - - - @Override - String toString() { - return "TestSuite{" + - "description='" + description + '\'' + - ", action=" + action + - ", reason='" + reason + '\'' + - ", blocking=" + blocking + - ", target='" + target + '\'' + - ", messages=" + messages.collect {it.content } + '\'' + - ", tags=" + tags + - '}' - } - } -} diff --git a/dd-java-agent/agent-aiguard/src/test/java/com/datadog/aiguard/AIGuardInternalTests.java b/dd-java-agent/agent-aiguard/src/test/java/com/datadog/aiguard/AIGuardInternalTests.java new file mode 100644 index 00000000000..f4494378f11 --- /dev/null +++ b/dd-java-agent/agent-aiguard/src/test/java/com/datadog/aiguard/AIGuardInternalTests.java @@ -0,0 +1,1148 @@ +package com.datadog.aiguard; + +import static datadog.trace.api.aiguard.AIGuard.Action.ABORT; +import static datadog.trace.api.aiguard.AIGuard.Action.ALLOW; +import static datadog.trace.api.aiguard.AIGuard.Action.DENY; +import static datadog.trace.junit.utils.config.WithConfigExtension.injectEnvConfig; +import static datadog.trace.junit.utils.config.WithConfigExtension.removeEnvConfig; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.junit.jupiter.params.provider.Arguments.arguments; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.squareup.moshi.Moshi; +import datadog.common.version.VersionInfo; +import datadog.trace.api.Config; +import datadog.trace.api.aiguard.AIGuard; +import datadog.trace.api.gateway.RequestContext; +import datadog.trace.api.telemetry.WafMetricCollector; +import datadog.trace.bootstrap.instrumentation.api.AgentSpan; +import datadog.trace.bootstrap.instrumentation.api.AgentTracer; +import datadog.trace.bootstrap.instrumentation.api.ClientIpAddressData; +import datadog.trace.bootstrap.instrumentation.api.Tags; +import datadog.trace.test.util.DDJavaSpecification; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import okhttp3.Call; +import okhttp3.HttpUrl; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Protocol; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.Response; +import okhttp3.ResponseBody; +import okio.BufferedSink; +import okio.Okio; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.ArgumentCaptor; +import org.skyscreamer.jsonassert.JSONAssert; +import org.skyscreamer.jsonassert.JSONCompareMode; +import org.tabletest.junit.TableTest; + +class AIGuardInternalTests extends DDJavaSpecification { + + private static final HttpUrl URL = + HttpUrl.parse("https://app.datadoghq.com/api/v2/ai-guard/evaluate"); + + private static final Map HEADERS = buildHeaders(); + + private static final AgentTracer.TracerAPI ORIGINAL_TRACER = AgentTracer.get(); + + private static final Moshi MOSHI = new Moshi.Builder().build(); + + private static final ObjectMapper MAPPER = + new ObjectMapper() + .setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE) + .setDefaultPropertyInclusion( + JsonInclude.Value.construct( + JsonInclude.Include.NON_NULL, JsonInclude.Include.NON_NULL)); + + private static final List TOOL_CALL = + Arrays.asList( + AIGuard.Message.message("system", "You are a beautiful AI assistant"), + AIGuard.Message.message("user", "What is 2 + 2"), + AIGuard.Message.assistant( + AIGuard.ToolCall.toolCall( + "call_1", "calc", "{ \"operator\": \"+\", \"args\": [2, 2] }"))); + + private static final List TOOL_OUTPUT = + appendMessage(TOOL_CALL, AIGuard.Message.tool("call_1", "5")); + + private static final List PROMPT = + appendMessages( + TOOL_OUTPUT, + AIGuard.Message.message("assistant", "2 + 2 is 5"), + AIGuard.Message.message("user", "")); + + private AgentSpan span; + private AgentSpan localRootSpan; + + private static Map buildHeaders() { + Map headers = new LinkedHashMap<>(); + headers.put("DD-API-KEY", "api"); + headers.put("DD-APPLICATION-KEY", "app"); + headers.put("DD-AI-GUARD-VERSION", VersionInfo.VERSION); + headers.put("DD-AI-GUARD-SOURCE", "SDK"); + headers.put("DD-AI-GUARD-LANGUAGE", "jvm"); + return headers; + } + + private static List appendMessage( + List base, AIGuard.Message extra) { + List result = new ArrayList<>(base); + result.add(extra); + return result; + } + + private static List appendMessages( + List base, AIGuard.Message... extra) { + List result = new ArrayList<>(base); + result.addAll(Arrays.asList(extra)); + return result; + } + + @BeforeEach + void setup() { + injectEnvConfig("SERVICE", "ai_guard_test"); + injectEnvConfig("ENV", "test"); + + span = mock(AgentSpan.class); + localRootSpan = mock(AgentSpan.class); + when(span.getLocalRootSpan()).thenReturn(localRootSpan); + + AgentTracer.SpanBuilder builder = mock(AgentTracer.SpanBuilder.class); + when(builder.start()).thenReturn(span); + + AgentTracer.TracerAPI tracer = mock(AgentTracer.TracerAPI.class); + when(tracer.buildSpan(anyString(), any(CharSequence.class))).thenReturn(builder); + AgentTracer.forceRegister(tracer); + + WafMetricCollector wafMetricCollector = WafMetricCollector.get(); + wafMetricCollector.prepareMetrics(); + wafMetricCollector.drain(); + } + + @AfterEach + void cleanup() { + AgentTracer.forceRegister(ORIGINAL_TRACER); + AIGuardInternal.uninstall(); + } + + @ParameterizedTest(name = "[{index}] {0}") + @TableTest({ + "scenario | apiKey | appKey", + "missing app | apiKey | ", + "empty app | apiKey | '' ", + "missing api | | appKey", + "empty api | '' | appKey", + "both missing | | ", + "both empty | '' | '' " + }) + void testMissingApiAppKeys(String scenario, String apiKey, String appKey) { + if (apiKey != null) { + injectEnvConfig("API_KEY", apiKey); + } + if (appKey != null) { + injectEnvConfig("APP_KEY", appKey); + } + + assertThrows(AIGuardInternal.BadConfigurationException.class, AIGuardInternal::install); + } + + @ParameterizedTest(name = "[{index}] {0}") + @TableTest({ + "scenario | endpoint | site | expected ", + "explicit endpoint | 'https://test' | | 'https://test/evaluate' ", + "no site | | | 'https://app.datadoghq.com/api/v2/ai-guard/evaluate'", + "default site | | 'datadoghq.com' | 'https://app.datadoghq.com/api/v2/ai-guard/evaluate'", + "staging site | | 'datad0g.com' | 'https://app.datad0g.com/api/v2/ai-guard/evaluate' " + }) + void testEndpointDiscovery(String scenario, String endpoint, String site, String expected) + throws Exception { + injectEnvConfig("API_KEY", "api"); + injectEnvConfig("APP_KEY", "app"); + if (endpoint != null) { + injectEnvConfig("AI_GUARD_ENDPOINT", endpoint); + } else { + removeEnvConfig("AI_GUARD_ENDPOINT"); + } + if (site != null) { + injectEnvConfig("SITE", site); + } else { + removeEnvConfig("SITE"); + } + + AIGuardInternal.install(); + + Field evaluator = AIGuard.class.getDeclaredField("EVALUATOR"); + evaluator.setAccessible(true); + AIGuardInternal internal = (AIGuardInternal) evaluator.get(null); + Field urlField = AIGuardInternal.class.getDeclaredField("url"); + urlField.setAccessible(true); + HttpUrl url = (HttpUrl) urlField.get(internal); + assertEquals(expected, url.toString()); + } + + @ParameterizedTest(name = "[{index}] {0}") + @MethodSource("testEvaluateArguments") + @SuppressWarnings("unchecked") + void testEvaluate(TestSuite suite) throws Exception { + boolean throwAbortError = suite.blocking && suite.action != ALLOW; + + Map attributes = new LinkedHashMap<>(); + attributes.put("action", suite.action); + attributes.put("reason", suite.reason); + attributes.put("tags", suite.tags != null ? suite.tags : emptyList()); + attributes.put( + "tag_probs", suite.tagProbabilities != null ? suite.tagProbabilities : emptyMap()); + attributes.put("is_blocking_enabled", suite.blocking); + + RequestHolder holder = new RequestHolder(); + AIGuardInternal aiguard = + mockClient(holder, 200, mapOf("data", mapOf("attributes", attributes))); + + ArgumentCaptor> metaCaptor = ArgumentCaptor.forClass(Map.class); + + AIGuard.Evaluation eval = null; + Throwable error = null; + try { + eval = aiguard.evaluate(suite.messages, new AIGuard.Options().block(suite.blocking)); + } catch (Throwable e) { + error = e; + } + + verify(span).setTag(AIGuardInternal.TARGET_TAG, suite.target); + verify(localRootSpan).setTag(Tags.AI_GUARD_KEEP, true); + verify(localRootSpan).setTag(AIGuardInternal.EVENT_TAG, true); + if ("tool".equals(suite.target)) { + verify(span).setTag(AIGuardInternal.TOOL_TAG, "calc"); + } + verify(span).setTag(AIGuardInternal.ACTION_TAG, suite.action); + verify(span).setTag(AIGuardInternal.REASON_TAG, suite.reason); + verify(span).setMetaStruct(eq(AIGuardInternal.META_STRUCT_TAG), metaCaptor.capture()); + if (throwAbortError) { + verify(span).addThrowable(any(AIGuard.AIGuardAbortError.class)); + } + + Map receivedMeta = metaCaptor.getValue(); + assertMeta(receivedMeta, suite); + assertRequest(holder.request, suite.messages); + + if (throwAbortError) { + assertTrue(error instanceof AIGuard.AIGuardAbortError); + AIGuard.AIGuardAbortError abort = (AIGuard.AIGuardAbortError) error; + assertEquals(suite.action, abort.getAction()); + assertEquals(suite.reason, abort.getReason()); + assertEquals(suite.tags, abort.getTags()); + assertEquals(suite.tagProbabilities, abort.getTagProbabilities()); + assertEquals(emptyList(), abort.getSds()); + } else { + assertNull(error); + assertEquals(suite.action, eval.getAction()); + assertEquals(suite.reason, eval.getReason()); + assertEquals(suite.tags, eval.getTags()); + assertEquals(suite.tagProbabilities, eval.getTagProbabilities()); + assertEquals(emptyList(), eval.getSds()); + } + assertTelemetry( + "ai_guard.requests", "action:" + suite.action, "block:" + throwAbortError, "error:false"); + } + + static Stream testEvaluateArguments() { + return TestSuite.build().stream().map(s -> arguments(s)); + } + + @ParameterizedTest(name = "[{index}] {0}") + @MethodSource("testEvaluateBlockDefaultsArguments") + void testEvaluateBlockDefaultsToRemoteIsBlockingEnabled( + String scenario, AIGuard.Options options, boolean remoteBlocking, boolean shouldBlock) { + Map attributes = new LinkedHashMap<>(); + attributes.put("action", "DENY"); + attributes.put("reason", "Nope"); + attributes.put("tags", singletonList("deny_everything")); + attributes.put("is_blocking_enabled", remoteBlocking); + + RequestHolder holder = new RequestHolder(); + AIGuardInternal aiguard = + mockClient(holder, 200, mapOf("data", mapOf("attributes", attributes))); + + AIGuard.Evaluation eval = null; + Throwable error = null; + try { + eval = aiguard.evaluate(TOOL_CALL, options); + } catch (Throwable e) { + error = e; + } + + if (shouldBlock) { + assertTrue(error instanceof AIGuard.AIGuardAbortError); + assertEquals(DENY, ((AIGuard.AIGuardAbortError) error).getAction()); + } else { + assertNull(error); + assertEquals(DENY, eval.getAction()); + } + } + + static Stream testEvaluateBlockDefaultsArguments() { + return Stream.of( + arguments("default options + remote blocking", AIGuard.Options.DEFAULT, true, true), + arguments("default options + no remote blocking", AIGuard.Options.DEFAULT, false, false), + arguments( + "explicit no block + remote blocking", + new AIGuard.Options().block(false), + true, + false)); + } + + @Test + void testEvaluateAppliesCapturedClientIpTagsToLocalRootSpan() { + RequestContext requestContext = mock(RequestContext.class); + when(localRootSpan.getRequestContext()).thenReturn(requestContext); + when(requestContext.getClientIpAddressData()) + .thenReturn(new ClientIpAddressData("4.4.4.4", "2.3.4.5")); + when(localRootSpan.getTag(Tags.NETWORK_CLIENT_IP)).thenReturn(null); + when(localRootSpan.getTag(Tags.HTTP_CLIENT_IP)).thenReturn(null); + + AIGuardInternal aiguard = + mockClient( + new RequestHolder(), + 200, + mapOf("data", mapOf("attributes", mapOf("action", "ALLOW", "reason", "It is fine")))); + + aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT); + + verify(localRootSpan).setTag(Tags.NETWORK_CLIENT_IP, "4.4.4.4"); + verify(localRootSpan).setTag(Tags.HTTP_CLIENT_IP, "2.3.4.5"); + } + + @Test + void testEvaluateDoesNotOverwriteExistingClientIpTags() { + RequestContext requestContext = mock(RequestContext.class); + when(localRootSpan.getRequestContext()).thenReturn(requestContext); + when(requestContext.getClientIpAddressData()) + .thenReturn(new ClientIpAddressData("4.4.4.4", "2.3.4.5")); + when(localRootSpan.getTag(Tags.NETWORK_CLIENT_IP)).thenReturn("9.9.9.9"); + when(localRootSpan.getTag(Tags.HTTP_CLIENT_IP)).thenReturn("8.8.8.8"); + + AIGuardInternal aiguard = + mockClient( + new RequestHolder(), + 200, + mapOf("data", mapOf("attributes", mapOf("action", "ALLOW", "reason", "It is fine")))); + + aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT); + + verify(localRootSpan, never()).setTag(eq(Tags.NETWORK_CLIENT_IP), any(String.class)); + verify(localRootSpan, never()).setTag(eq(Tags.HTTP_CLIENT_IP), any(String.class)); + } + + @Test + void testEvaluateIsANoopForClientIpTagsWhenNoDataCaptured() { + RequestContext requestContext = mock(RequestContext.class); + when(localRootSpan.getRequestContext()).thenReturn(requestContext); + when(requestContext.getClientIpAddressData()).thenReturn(null); + + AIGuardInternal aiguard = + mockClient( + new RequestHolder(), + 200, + mapOf("data", mapOf("attributes", mapOf("action", "ALLOW", "reason", "It is fine")))); + + aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT); + + verify(localRootSpan, never()).setTag(eq(Tags.NETWORK_CLIENT_IP), any(String.class)); + verify(localRootSpan, never()).setTag(eq(Tags.HTTP_CLIENT_IP), any(String.class)); + } + + @Test + void testEvaluateIsANoopForClientIpTagsWhenNoRequestContext() { + when(localRootSpan.getRequestContext()).thenReturn(null); + + AIGuardInternal aiguard = + mockClient( + new RequestHolder(), + 200, + mapOf("data", mapOf("attributes", mapOf("action", "ALLOW", "reason", "It is fine")))); + + aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT); + + verify(localRootSpan, never()).setTag(eq(Tags.NETWORK_CLIENT_IP), any(String.class)); + verify(localRootSpan, never()).setTag(eq(Tags.HTTP_CLIENT_IP), any(String.class)); + } + + @Test + void testEvaluateWithApiErrors() { + List> errors = new ArrayList<>(); + errors.add(mapOf("status", 400.0, "title", "Bad request")); + AIGuardInternal aiguard = mockClient(new RequestHolder(), 404, mapOf("errors", errors)); + + AIGuard.AIGuardClientError exception = + assertThrows( + AIGuard.AIGuardClientError.class, + () -> aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT)); + assertEquals(errors, exception.getErrors()); + verify(span).addThrowable(any(AIGuard.AIGuardClientError.class)); + assertTelemetry("ai_guard.requests", "error:true"); + } + + @Test + void testEvaluateWithInvalidJson() { + AIGuardInternal aiguard = + mockClient(new RequestHolder(), 200, mapOf("bad", "This is an invalid response")); + + assertThrows( + AIGuard.AIGuardClientError.class, + () -> aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT)); + verify(span).addThrowable(any(AIGuard.AIGuardClientError.class)); + assertTelemetry("ai_guard.requests", "error:true"); + } + + @Test + void testEvaluateWithMissingAction() { + AIGuardInternal aiguard = + mockClient( + new RequestHolder(), + 200, + mapOf("data", mapOf("attributes", mapOf("reason", "I miss something")))); + + assertThrows( + AIGuard.AIGuardClientError.class, + () -> aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT)); + verify(span).addThrowable(any(AIGuard.AIGuardClientError.class)); + assertTelemetry("ai_guard.requests", "error:true"); + } + + @Test + void testEvaluateWithNonJsonResponse() { + AIGuardInternal aiguard = mockClient(new RequestHolder(), 200, "I am no JSON"); + + assertThrows( + AIGuard.AIGuardClientError.class, + () -> aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT)); + verify(span).addThrowable(any(AIGuard.AIGuardClientError.class)); + assertTelemetry("ai_guard.requests", "error:true"); + } + + @Test + void testEvaluateWithEmptyResponse() { + AIGuardInternal aiguard = mockClient(new RequestHolder(), 200, null); + + assertThrows( + AIGuard.AIGuardClientError.class, + () -> aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT)); + verify(span).addThrowable(any(AIGuard.AIGuardClientError.class)); + assertTelemetry("ai_guard.requests", "error:true"); + } + + @Test + @SuppressWarnings("unchecked") + void testMessageLengthTruncation() { + int maxMessages = Config.get().getAiGuardMaxMessagesLength(); + AIGuardInternal aiguard = + mockClient( + new RequestHolder(), + 200, + mapOf("data", mapOf("attributes", mapOf("action", "ALLOW", "reason", "It is fine")))); + + List messages = new ArrayList<>(); + for (int i = 0; i <= maxMessages; i++) { + messages.add(AIGuard.Message.message("user", "This is a prompt: " + i)); + } + + ArgumentCaptor> metaCaptor = ArgumentCaptor.forClass(Map.class); + aiguard.evaluate(messages, AIGuard.Options.DEFAULT); + + verify(span).setMetaStruct(eq(AIGuardInternal.META_STRUCT_TAG), metaCaptor.capture()); + List received = (List) metaCaptor.getValue().get("messages"); + assertEquals(maxMessages, received.size()); + assertTrue(received.size() < messages.size()); + assertTelemetry("ai_guard.truncated", "type:messages"); + } + + @Test + @SuppressWarnings("unchecked") + void testMessageContentTruncation() { + int maxContent = Config.get().getAiGuardMaxContentSize(); + AIGuardInternal aiguard = + mockClient( + new RequestHolder(), + 200, + mapOf("data", mapOf("attributes", mapOf("action", "ALLOW", "reason", "It is fine")))); + + StringBuilder content = new StringBuilder(); + for (int i = 0; i <= maxContent; i++) { + content.append('A'); + } + AIGuard.Message message = AIGuard.Message.message("user", content.toString()); + + ArgumentCaptor> metaCaptor = ArgumentCaptor.forClass(Map.class); + aiguard.evaluate(singletonList(message), AIGuard.Options.DEFAULT); + + verify(span).setMetaStruct(eq(AIGuardInternal.META_STRUCT_TAG), metaCaptor.capture()); + List received = (List) metaCaptor.getValue().get("messages"); + AIGuard.Message last = received.get(received.size() - 1); + assertEquals(maxContent, last.getContent().length()); + assertTrue(last.getContent().length() < message.getContent().length()); + assertTelemetry("ai_guard.truncated", "type:content"); + } + + @ParameterizedTest(name = "[{index}] {0}") + @MethodSource("testNoMessagesArguments") + void testNoMessages(String scenario, List messages) { + AIGuardInternal aiguard = new AIGuardInternal(URL, HEADERS, mock(OkHttpClient.class)); + assertThrows( + IllegalArgumentException.class, () -> aiguard.evaluate(messages, AIGuard.Options.DEFAULT)); + } + + static Stream testNoMessagesArguments() { + return Stream.of(arguments("empty list", emptyList()), arguments("null list", null)); + } + + @Test + @SuppressWarnings("unchecked") + void testEvaluateWithSdsFindings() { + List> sdsFindings = new ArrayList<>(); + Map finding1 = new LinkedHashMap<>(); + finding1.put("rule_display_name", "Credit Card Number"); + finding1.put("rule_tag", "credit_card"); + finding1.put("category", "pii"); + finding1.put("matched_text", "4111111111111111"); + finding1.put( + "location", + mapOf( + "start_index", + 10.0, + "end_index_exclusive", + 26.0, + "path", + "messages[0].content[0].text")); + sdsFindings.add(finding1); + Map finding2 = new LinkedHashMap<>(); + finding2.put("rule_display_name", "Social Security Number"); + finding2.put("rule_tag", "ssn"); + finding2.put("category", "pii"); + finding2.put("matched_text", "123-45-6789"); + finding2.put( + "location", + mapOf( + "start_index", + 30.0, + "end_index_exclusive", + 41.0, + "path", + "messages[1].tool_calls[0].function.arguments")); + sdsFindings.add(finding2); + + AIGuardInternal aiguard = + mockClient( + new RequestHolder(), + 200, + mapOf( + "data", + mapOf( + "attributes", + mapOf( + "action", "ALLOW", "reason", "It is fine", "sds_findings", sdsFindings)))); + + ArgumentCaptor> metaCaptor = ArgumentCaptor.forClass(Map.class); + AIGuard.Evaluation result = aiguard.evaluate(PROMPT, AIGuard.Options.DEFAULT); + + verify(span).setMetaStruct(eq(AIGuardInternal.META_STRUCT_TAG), metaCaptor.capture()); + Map receivedMeta = metaCaptor.getValue(); + assertEquals(sdsFindings, receivedMeta.get("sds")); + assertEquals(sdsFindings, result.getSds()); + } + + @ParameterizedTest(name = "[{index}] {0}") + @MethodSource("testEvaluateWithEmptySdsFindingsArguments") + @SuppressWarnings("unchecked") + void testEvaluateWithEmptySdsFindings(String scenario, List sdsFindings) { + Map attrs = new LinkedHashMap<>(); + attrs.put("action", "ALLOW"); + attrs.put("reason", "It is fine"); + attrs.put("sds_findings", sdsFindings); + AIGuardInternal aiguard = + mockClient(new RequestHolder(), 200, mapOf("data", mapOf("attributes", attrs))); + + ArgumentCaptor> metaCaptor = ArgumentCaptor.forClass(Map.class); + AIGuard.Evaluation result = aiguard.evaluate(PROMPT, AIGuard.Options.DEFAULT); + + verify(span).setMetaStruct(eq(AIGuardInternal.META_STRUCT_TAG), metaCaptor.capture()); + Map receivedMeta = metaCaptor.getValue(); + assertFalse(receivedMeta.containsKey("sds")); + assertEquals(sdsFindings != null ? sdsFindings : emptyList(), result.getSds()); + } + + static Stream testEvaluateWithEmptySdsFindingsArguments() { + return Stream.of(arguments("null findings", null), arguments("empty findings", emptyList())); + } + + @Test + void testEvaluateWithSdsFindingsInAbortError() { + List> sdsFindings = new ArrayList<>(); + Map finding = new LinkedHashMap<>(); + finding.put("rule_display_name", "Credit Card Number"); + finding.put("rule_tag", "credit_card"); + finding.put("category", "pii"); + finding.put("matched_text", "4111111111111111"); + finding.put( + "location", + mapOf( + "start_index", + 10.0, + "end_index_exclusive", + 26.0, + "path", + "messages[0].content[0].text")); + sdsFindings.add(finding); + + AIGuardInternal aiguard = + mockClient( + new RequestHolder(), + 200, + mapOf( + "data", + mapOf( + "attributes", + mapOf( + "action", + "ABORT", + "reason", + "PII detected", + "tags", + singletonList("pii"), + "sds_findings", + sdsFindings, + "is_blocking_enabled", + true)))); + + AIGuard.AIGuardAbortError error = + assertThrows( + AIGuard.AIGuardAbortError.class, + () -> aiguard.evaluate(PROMPT, new AIGuard.Options().block(true))); + assertEquals(sdsFindings, error.getSds()); + } + + @Test + void testMissingToolName() { + AIGuardInternal aiguard = + mockClient( + new RequestHolder(), + 200, + mapOf("data", mapOf("attributes", mapOf("action", "ALLOW", "reason", "Just do it")))); + + aiguard.evaluate( + singletonList(AIGuard.Message.tool("call_1", "Content")), AIGuard.Options.DEFAULT); + + verify(span).setTag(AIGuardInternal.TARGET_TAG, "tool"); + verify(span, never()).setTag(eq(AIGuardInternal.TOOL_TAG), anyString()); + } + + @Test + void mapRequiresEvenNumberOfParams() throws Exception { + Method mapOf = AIGuardInternal.class.getDeclaredMethod("mapOf", String[].class); + mapOf.setAccessible(true); + InvocationTargetException invocation = + assertThrows( + InvocationTargetException.class, + () -> mapOf.invoke(null, (Object) new String[] {"1", "2", "3"})); + assertTrue(invocation.getTargetException() instanceof IllegalArgumentException); + } + + @Test + @SuppressWarnings("unchecked") + void testMessageImmutability() { + AIGuardInternal aiguard = + mockClient( + new RequestHolder(), + 200, + mapOf("data", mapOf("attributes", mapOf("action", "ALLOW", "reason", "Just do it")))); + + List toolCalls = new ArrayList<>(); + toolCalls.add(AIGuard.ToolCall.toolCall("call_1", "execute_shell", "{\"cmd\": \"ls -lah\"}")); + List messages = new ArrayList<>(); + messages.add(new AIGuard.Message("assistant", (String) null, toolCalls, null)); + + doAnswer( + invocation -> { + messages + .get(0) + .getToolCalls() + .add( + AIGuard.ToolCall.toolCall( + "call_2", "execute_shell", "{\"cmd\": \"rm -rf\"}")); + messages.add(AIGuard.Message.tool("call_1", "dir1, dir2, dir3")); + return null; + }) + .when(span) + .finish(); + + ArgumentCaptor> metaCaptor = ArgumentCaptor.forClass(Map.class); + aiguard.evaluate(messages, AIGuard.Options.DEFAULT); + + verify(span).setMetaStruct(eq(AIGuardInternal.META_STRUCT_TAG), metaCaptor.capture()); + List metaStructMessages = + (List) metaCaptor.getValue().get("messages"); + assertNotEquals(messages.size(), metaStructMessages.size()); + assertEquals(1, metaStructMessages.size()); + assertNotEquals( + messages.get(0).getToolCalls().size(), metaStructMessages.get(0).getToolCalls().size()); + assertEquals(1, metaStructMessages.get(0).getToolCalls().size()); + } + + @Test + @SuppressWarnings("unchecked") + void testJsonSerializationWithTextContentParts() { + AIGuardInternal aiguard = + mockClient( + new RequestHolder(), + 200, + mapOf("data", mapOf("attributes", mapOf("action", "ALLOW", "reason", "Good")))); + + List messages = + singletonList( + AIGuard.Message.message( + "user", singletonList(AIGuard.ContentPart.text("Hello world")))); + + ArgumentCaptor> metaCaptor = ArgumentCaptor.forClass(Map.class); + aiguard.evaluate(messages, AIGuard.Options.DEFAULT); + + verify(span).setMetaStruct(eq(AIGuardInternal.META_STRUCT_TAG), metaCaptor.capture()); + List received = (List) metaCaptor.getValue().get("messages"); + assertEquals(1, received.size()); + assertEquals(1, received.get(0).getContentParts().size()); + assertEquals(AIGuard.ContentPart.Type.TEXT, received.get(0).getContentParts().get(0).getType()); + assertEquals("Hello world", received.get(0).getContentParts().get(0).getText()); + } + + @Test + @SuppressWarnings("unchecked") + void testJsonSerializationWithImageUrlContentParts() { + AIGuardInternal aiguard = + mockClient( + new RequestHolder(), + 200, + mapOf("data", mapOf("attributes", mapOf("action", "ALLOW", "reason", "Good")))); + + List messages = + singletonList( + AIGuard.Message.message( + "user", + singletonList(AIGuard.ContentPart.imageUrl("https://example.com/image.jpg")))); + + ArgumentCaptor> metaCaptor = ArgumentCaptor.forClass(Map.class); + aiguard.evaluate(messages, AIGuard.Options.DEFAULT); + + verify(span).setMetaStruct(eq(AIGuardInternal.META_STRUCT_TAG), metaCaptor.capture()); + List received = (List) metaCaptor.getValue().get("messages"); + assertEquals(1, received.size()); + assertEquals(1, received.get(0).getContentParts().size()); + assertEquals( + AIGuard.ContentPart.Type.IMAGE_URL, received.get(0).getContentParts().get(0).getType()); + assertEquals( + "https://example.com/image.jpg", + received.get(0).getContentParts().get(0).getImageUrl().getUrl()); + } + + @Test + @SuppressWarnings("unchecked") + void testJsonSerializationWithMixedContentParts() { + AIGuardInternal aiguard = + mockClient( + new RequestHolder(), + 200, + mapOf("data", mapOf("attributes", mapOf("action", "ALLOW", "reason", "Good")))); + + List parts = + Arrays.asList( + AIGuard.ContentPart.text("Describe this image:"), + AIGuard.ContentPart.imageUrl("https://example.com/image.jpg"), + AIGuard.ContentPart.text("What do you see?")); + List messages = singletonList(AIGuard.Message.message("user", parts)); + + ArgumentCaptor> metaCaptor = ArgumentCaptor.forClass(Map.class); + aiguard.evaluate(messages, AIGuard.Options.DEFAULT); + + verify(span).setMetaStruct(eq(AIGuardInternal.META_STRUCT_TAG), metaCaptor.capture()); + List received = (List) metaCaptor.getValue().get("messages"); + assertEquals(1, received.size()); + List rcvParts = received.get(0).getContentParts(); + assertEquals(3, rcvParts.size()); + assertEquals(AIGuard.ContentPart.Type.TEXT, rcvParts.get(0).getType()); + assertEquals("Describe this image:", rcvParts.get(0).getText()); + assertEquals(AIGuard.ContentPart.Type.IMAGE_URL, rcvParts.get(1).getType()); + assertEquals("https://example.com/image.jpg", rcvParts.get(1).getImageUrl().getUrl()); + assertEquals(AIGuard.ContentPart.Type.TEXT, rcvParts.get(2).getType()); + assertEquals("What do you see?", rcvParts.get(2).getText()); + } + + @Test + @SuppressWarnings("unchecked") + void testContentPartsOrderIsPreserved() { + AIGuardInternal aiguard = + mockClient( + new RequestHolder(), + 200, + mapOf("data", mapOf("attributes", mapOf("action", "ALLOW", "reason", "Good")))); + + List parts = new ArrayList<>(); + for (int i = 0; i < 5; i++) { + parts.add( + i % 2 == 0 + ? AIGuard.ContentPart.text("Text " + i) + : AIGuard.ContentPart.imageUrl("https://example.com/image" + i + ".jpg")); + } + List messages = singletonList(AIGuard.Message.message("user", parts)); + + ArgumentCaptor> metaCaptor = ArgumentCaptor.forClass(Map.class); + aiguard.evaluate(messages, AIGuard.Options.DEFAULT); + + verify(span).setMetaStruct(eq(AIGuardInternal.META_STRUCT_TAG), metaCaptor.capture()); + List received = (List) metaCaptor.getValue().get("messages"); + List rcvParts = received.get(0).getContentParts(); + assertEquals(5, rcvParts.size()); + for (int i = 0; i < 5; i++) { + if (i % 2 == 0) { + assertEquals(AIGuard.ContentPart.Type.TEXT, rcvParts.get(i).getType()); + assertEquals("Text " + i, rcvParts.get(i).getText()); + } else { + assertEquals(AIGuard.ContentPart.Type.IMAGE_URL, rcvParts.get(i).getType()); + assertEquals( + "https://example.com/image" + i + ".jpg", rcvParts.get(i).getImageUrl().getUrl()); + } + } + } + + @Test + @SuppressWarnings("unchecked") + void testContentPartTextTruncation() { + int maxContent = Config.get().getAiGuardMaxContentSize(); + AIGuardInternal aiguard = + mockClient( + new RequestHolder(), + 200, + mapOf("data", mapOf("attributes", mapOf("action", "ALLOW", "reason", "Good")))); + + StringBuilder longText = new StringBuilder(); + for (int i = 0; i <= maxContent; i++) { + longText.append('A'); + } + List messages = + singletonList( + AIGuard.Message.message( + "user", + Arrays.asList( + AIGuard.ContentPart.text(longText.toString()), + AIGuard.ContentPart.text("Short text")))); + + ArgumentCaptor> metaCaptor = ArgumentCaptor.forClass(Map.class); + aiguard.evaluate(messages, AIGuard.Options.DEFAULT); + + verify(span).setMetaStruct(eq(AIGuardInternal.META_STRUCT_TAG), metaCaptor.capture()); + List received = (List) metaCaptor.getValue().get("messages"); + List rcvParts = received.get(0).getContentParts(); + assertEquals(2, rcvParts.size()); + assertEquals(maxContent, rcvParts.get(0).getText().length()); + assertTrue(rcvParts.get(0).getText().length() < longText.length()); + assertEquals("Short text", rcvParts.get(1).getText()); + assertTelemetry("ai_guard.truncated", "type:content"); + } + + @Test + @SuppressWarnings("unchecked") + void testContentPartImageUrlNotTruncatedEvenWithLongDataUri() { + int maxContent = Config.get().getAiGuardMaxContentSize(); + AIGuardInternal aiguard = + mockClient( + new RequestHolder(), + 200, + mapOf("data", mapOf("attributes", mapOf("action", "ALLOW", "reason", "Good")))); + + StringBuilder data = new StringBuilder("data:image/png;base64,"); + for (int i = 0; i <= maxContent + 1000; i++) { + data.append('A'); + } + String longDataUri = data.toString(); + List messages = + singletonList( + AIGuard.Message.message( + "user", + Arrays.asList( + AIGuard.ContentPart.text("Image:"), + AIGuard.ContentPart.imageUrl(longDataUri)))); + + ArgumentCaptor> metaCaptor = ArgumentCaptor.forClass(Map.class); + aiguard.evaluate(messages, AIGuard.Options.DEFAULT); + + verify(span).setMetaStruct(eq(AIGuardInternal.META_STRUCT_TAG), metaCaptor.capture()); + List received = (List) metaCaptor.getValue().get("messages"); + List rcvParts = received.get(0).getContentParts(); + assertEquals(2, rcvParts.size()); + assertEquals(AIGuard.ContentPart.Type.IMAGE_URL, rcvParts.get(1).getType()); + assertEquals(longDataUri, rcvParts.get(1).getImageUrl().getUrl()); + assertTrue(rcvParts.get(1).getImageUrl().getUrl().length() > maxContent); + } + + @Test + @SuppressWarnings("unchecked") + void testBackwardCompatibilityWithStringContent() { + AIGuardInternal aiguard = + mockClient( + new RequestHolder(), + 200, + mapOf("data", mapOf("attributes", mapOf("action", "ALLOW", "reason", "Good")))); + + List messages = singletonList(AIGuard.Message.message("user", "Hello world")); + + ArgumentCaptor> metaCaptor = ArgumentCaptor.forClass(Map.class); + aiguard.evaluate(messages, AIGuard.Options.DEFAULT); + + verify(span).setMetaStruct(eq(AIGuardInternal.META_STRUCT_TAG), metaCaptor.capture()); + List received = (List) metaCaptor.getValue().get("messages"); + assertEquals(1, received.size()); + assertEquals("Hello world", received.get(0).getContent()); + assertNull(received.get(0).getContentParts()); + } + + private AIGuardInternal mockClient(RequestHolder holder, int status, Object response) { + Call call = mock(Call.class); + try { + when(call.execute()).thenAnswer(invocation -> mockResponse(holder.request, status, response)); + } catch (IOException e) { + fail(e); + } + OkHttpClient client = mock(OkHttpClient.class); + when(client.newCall(any(Request.class))) + .thenAnswer( + invocation -> { + holder.request = invocation.getArgument(0); + return call; + }); + return new AIGuardInternal(URL, HEADERS, client); + } + + private static void assertTelemetry(String metric, String... tags) { + WafMetricCollector wafMetricCollector = WafMetricCollector.get(); + wafMetricCollector.prepareMetrics(); + Collection metrics = wafMetricCollector.drain(); + List tagList = Arrays.asList(tags); + List filtered = + metrics.stream() + .filter( + m -> + "appsec".equals(m.namespace) + && metric.equals(m.metricName) + && tagList.equals(m.tags)) + .collect(Collectors.toList()); + assertEquals(1, filtered.size(), () -> "metrics: " + metrics); + long sum = filtered.stream().mapToLong(m -> m.value.longValue()).sum(); + assertEquals(1, sum); + } + + @SuppressWarnings("unchecked") + private static void assertMeta(Map meta, TestSuite suite) throws Exception { + if (suite.tags != null && !suite.tags.isEmpty()) { + assertEquals(suite.tags, meta.get("attack_categories")); + } + if (suite.tagProbabilities != null && !suite.tagProbabilities.isEmpty()) { + assertEquals(suite.tagProbabilities, meta.get("tag_probs")); + } + String receivedMessages = snakeCaseJson(meta.get("messages")); + String expectedMessages = snakeCaseJson(suite.messages); + JSONAssert.assertEquals(expectedMessages, receivedMessages, JSONCompareMode.NON_EXTENSIBLE); + } + + private static void assertRequest(Request request, List messages) + throws Exception { + assertEquals(URL, request.url()); + assertEquals("POST", request.method()); + for (Map.Entry entry : HEADERS.entrySet()) { + assertEquals(entry.getValue(), request.header(entry.getKey())); + } + assertTrue(request.body().contentType().toString().contains("application/json")); + String receivedBody = readRequestBody(request.body()); + Map expected = + mapOf( + "data", + mapOf( + "attributes", + mapOf( + "messages", + messages, + "meta", + mapOf("service", "ai_guard_test", "env", "test")))); + String expectedBody = snakeCaseJson(expected); + JSONAssert.assertEquals(expectedBody, receivedBody, JSONCompareMode.NON_EXTENSIBLE); + } + + private static String snakeCaseJson(Object value) throws Exception { + return MAPPER.writeValueAsString(value); + } + + private static String readRequestBody(RequestBody body) throws IOException { + ByteArrayOutputStream output = new ByteArrayOutputStream(); + BufferedSink sink = Okio.buffer(Okio.sink(output)); + body.writeTo(sink); + sink.flush(); + return new String(output.toByteArray()); + } + + private static Response mockResponse(Request request, int status, Object body) { + Response.Builder builder = + new Response.Builder() + .protocol(Protocol.HTTP_1_1) + .message("ok") + .request(request) + .code(status); + if (body != null) { + String json = MOSHI.adapter(Object.class).toJson(body); + builder.body(ResponseBody.create(MediaType.parse("application/json"), json)); + } + return builder.build(); + } + + @SuppressWarnings("unchecked") + private static Map mapOf(Object... entries) { + if (entries.length % 2 != 0) { + throw new IllegalArgumentException("Expected even number of arguments"); + } + Map map = new LinkedHashMap<>(); + for (int i = 0; i < entries.length; i += 2) { + map.put((K) entries[i], (V) entries[i + 1]); + } + return map; + } + + private static class RequestHolder { + Request request; + } + + private static class TestSuite { + final AIGuard.Action action; + final String reason; + final List tags; + final Map tagProbabilities; + final boolean blocking; + final String description; + final String target; + final List messages; + + TestSuite( + AIGuard.Action action, + String reason, + Map tagProbabilities, + boolean blocking, + String description, + String target, + List messages) { + this.action = action; + this.reason = reason; + this.tags = new ArrayList<>(tagProbabilities.keySet()); + this.tagProbabilities = tagProbabilities; + this.blocking = blocking; + this.description = description; + this.target = target; + this.messages = messages; + } + + static List build() { + List all = new ArrayList<>(); + Object[][] actionValues = { + {ALLOW, "Go ahead", emptyMap()}, + {DENY, "Nope", probs("deny_everything", 0.2, "test_deny", 0.8)}, + {ABORT, "Kill it with fire", probs("alarm_tag", 0.1, "abort_everything", 0.9)} + }; + Object[][] suiteValues = { + {"tool call", "tool", TOOL_CALL}, + {"tool output", "tool", TOOL_OUTPUT}, + {"prompt", "prompt", PROMPT} + }; + for (Object[] action : actionValues) { + for (boolean blocking : new boolean[] {true, false}) { + for (Object[] suite : suiteValues) { + @SuppressWarnings("unchecked") + Map tagProbs = (Map) action[2]; + @SuppressWarnings("unchecked") + List messages = (List) suite[2]; + all.add( + new TestSuite( + (AIGuard.Action) action[0], + (String) action[1], + tagProbs, + blocking, + (String) suite[0], + (String) suite[1], + messages)); + } + } + } + return all; + } + + private static Map probs(Object... entries) { + Map result = new LinkedHashMap<>(); + for (int i = 0; i < entries.length; i += 2) { + result.put((String) entries[i], (Number) entries[i + 1]); + } + return result; + } + + @Override + public String toString() { + List contents = new ArrayList<>(); + for (AIGuard.Message m : messages) { + contents.add(m.getContent()); + } + return "TestSuite{description='" + + description + + "', action=" + + action + + ", reason='" + + reason + + "', blocking=" + + blocking + + ", target='" + + target + + "', messages=" + + contents + + "', tags=" + + tags + + "}"; + } + } +} From b5db0996c5b5e9e13fbc406e198968db3672ce71 Mon Sep 17 00:00:00 2001 From: Santiago Mola Date: Fri, 8 May 2026 11:32:56 +0200 Subject: [PATCH 3/3] test(communication): migrate MessageWriterTest to JUnit 5 Convert the AI Guard MessageWriter Spock spec to a Java JUnit 5 test extending DDJavaSpecification, replacing injectSysConfig with the @WithConfig annotation. The 8 tests cover msgpack serialization of plain content, tool calls/output, and content parts. --- .../aiguard/MessageWriterTest.groovy | 242 ------------------ .../aiguard/MessageWriterTest.java | 240 +++++++++++++++++ 2 files changed, 240 insertions(+), 242 deletions(-) delete mode 100644 communication/src/test/groovy/datadog/communication/serialization/aiguard/MessageWriterTest.groovy create mode 100644 communication/src/test/java/datadog/communication/serialization/aiguard/MessageWriterTest.java diff --git a/communication/src/test/groovy/datadog/communication/serialization/aiguard/MessageWriterTest.groovy b/communication/src/test/groovy/datadog/communication/serialization/aiguard/MessageWriterTest.groovy deleted file mode 100644 index 45191a2fc8f..00000000000 --- a/communication/src/test/groovy/datadog/communication/serialization/aiguard/MessageWriterTest.groovy +++ /dev/null @@ -1,242 +0,0 @@ -package datadog.communication.serialization.aiguard - -import datadog.communication.serialization.EncodingCache -import datadog.communication.serialization.GrowableBuffer -import datadog.communication.serialization.msgpack.MsgPackWriter -import datadog.trace.api.aiguard.AIGuard -import datadog.trace.test.util.DDSpecification -import org.msgpack.core.MessagePack -import org.msgpack.value.Value - -import java.nio.charset.StandardCharsets -import java.util.function.Function - -class MessageWriterTest extends DDSpecification { - - private EncodingCache encodingCache - private GrowableBuffer buffer - private MsgPackWriter writer - - void setup() { - injectSysConfig('ai_guard.enabled', 'true') - final HashMap cache = new HashMap<>() - encodingCache = new EncodingCache() { - @Override - byte[] encode(CharSequence chars) { - cache.computeIfAbsent(chars, s -> s.toString().getBytes(StandardCharsets.UTF_8)) - } - } - buffer = new GrowableBuffer(1024) - writer = new MsgPackWriter(buffer) - } - - void 'test write message'() { - given: - final message = AIGuard.Message.message('user', 'What day is today?') - - when: - writer.writeObject(message, encodingCache) - - then: - try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) { - final value = asStringValueMap(unpacker.unpackValue()) - value.size() == 2 - value.role == 'user' - value.content == 'What day is today?' - } - } - - void 'test write tool call'() { - given: - final message = - AIGuard.Message.assistant( - AIGuard.ToolCall.toolCall('call_1', 'function_1', 'args_1'), - AIGuard.ToolCall.toolCall('call_2', 'function_2', 'args_2')) - - when: - writer.writeObject(message, encodingCache) - - then: - try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) { - final value = asStringKeyMap(unpacker.unpackValue()) - value.size() == 2 - asString(value.role) == 'assistant' - - final toolCalls = value.get('tool_calls').asArrayValue().list() - toolCalls.size() == 2 - - final firstCall = asStringKeyMap(toolCalls[0]) - asString(firstCall.id) == 'call_1' - final firstFunction = asStringValueMap(firstCall.function) - firstFunction.name == 'function_1' - firstFunction.arguments == 'args_1' - - final secondCall = asStringKeyMap(toolCalls[1]) - asString(secondCall.id) == 'call_2' - final secondFunction = asStringValueMap(secondCall.function) - secondFunction.name == 'function_2' - secondFunction.arguments == 'args_2' - } - } - - void 'test write tool output'() throws IOException { - given: - final message = AIGuard.Message.tool('call_1', 'output') - - when: - writer.writeObject(message, encodingCache) - - then: - try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) { - final value = asStringValueMap(unpacker.unpackValue()) - value.size() == 3 - value.role == 'tool' - value.tool_call_id == 'call_1' - value.content == 'output' - } - } - - private static Map mapValue( - final Value values, - final Function keyMapper, - final Function valueMapper) { - return values.asMapValue().entrySet().collectEntries { - [(keyMapper.apply(it.key)): valueMapper.apply(it.value)] - } - } - - private static Map asStringKeyMap(final Value values) { - return mapValue(values, MessageWriterTest::asString, Function.identity()) - } - - private static Map asStringValueMap(final Value values) { - return mapValue(values, MessageWriterTest::asString, MessageWriterTest::asString) - } - - private static String asString(final Value value) { - return value.asStringValue().asString() - } - - void 'test write message with text content parts'() { - given: - final message = AIGuard.Message.message('user', [ - AIGuard.ContentPart.text('Hello world') - ]) - - when: - writer.writeObject(message, encodingCache) - - then: - try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) { - final value = asStringKeyMap(unpacker.unpackValue()) - value.size() == 2 - asString(value.role) == 'user' - - final contentParts = value.content.asArrayValue().list() - contentParts.size() == 1 - - final part = asStringKeyMap(contentParts[0]) - asString(part.type) == 'text' - asString(part.text) == 'Hello world' - } - } - - void 'test write message with image_url content parts'() { - given: - final message = AIGuard.Message.message('user', [ - AIGuard.ContentPart.imageUrl('https://example.com/image.jpg') - ]) - - when: - writer.writeObject(message, encodingCache) - - then: - try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) { - final value = asStringKeyMap(unpacker.unpackValue()) - value.size() == 2 - asString(value.role) == 'user' - - final contentParts = value.content.asArrayValue().list() - contentParts.size() == 1 - - final part = asStringKeyMap(contentParts[0]) - asString(part.type) == 'image_url' - - final imageUrl = asStringKeyMap(part.image_url) - asString(imageUrl.url) == 'https://example.com/image.jpg' - } - } - - void 'test write message with mixed content parts'() { - given: - final message = AIGuard.Message.message('user', [ - AIGuard.ContentPart.text('Describe this:'), - AIGuard.ContentPart.imageUrl('https://example.com/image.jpg'), - AIGuard.ContentPart.text('What is it?') - ]) - - when: - writer.writeObject(message, encodingCache) - - then: - try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) { - final value = asStringKeyMap(unpacker.unpackValue()) - value.size() == 2 - asString(value.role) == 'user' - - final contentParts = value.content.asArrayValue().list() - contentParts.size() == 3 - - final part1 = asStringKeyMap(contentParts[0]) - asString(part1.type) == 'text' - asString(part1.text) == 'Describe this:' - - final part2 = asStringKeyMap(contentParts[1]) - asString(part2.type) == 'image_url' - final imageUrl = asStringKeyMap(part2.image_url) - asString(imageUrl.url) == 'https://example.com/image.jpg' - - final part3 = asStringKeyMap(contentParts[2]) - asString(part3.type) == 'text' - asString(part3.text) == 'What is it?' - } - } - - void 'test content parts type serializes as string not integer'() { - given: - final message = AIGuard.Message.message('user', [ - AIGuard.ContentPart.text('Test') - ]) - - when: - writer.writeObject(message, encodingCache) - - then: - try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) { - final value = asStringKeyMap(unpacker.unpackValue()) - final contentParts = value.content.asArrayValue().list() - final part = asStringKeyMap(contentParts[0]) - - // Verify type is a string value, not an integer - part.type.isStringValue() - !part.type.isIntegerValue() - asString(part.type) == 'text' - } - } - - void 'test backward compatibility with string content'() { - given: - final message = AIGuard.Message.message('user', 'Plain text message') - - when: - writer.writeObject(message, encodingCache) - - then: - try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) { - final value = asStringValueMap(unpacker.unpackValue()) - value.size() == 2 - value.role == 'user' - value.content == 'Plain text message' - } - } -} diff --git a/communication/src/test/java/datadog/communication/serialization/aiguard/MessageWriterTest.java b/communication/src/test/java/datadog/communication/serialization/aiguard/MessageWriterTest.java new file mode 100644 index 00000000000..42ddecebc88 --- /dev/null +++ b/communication/src/test/java/datadog/communication/serialization/aiguard/MessageWriterTest.java @@ -0,0 +1,240 @@ +package datadog.communication.serialization.aiguard; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import datadog.communication.serialization.EncodingCache; +import datadog.communication.serialization.GrowableBuffer; +import datadog.communication.serialization.msgpack.MsgPackWriter; +import datadog.trace.api.aiguard.AIGuard; +import datadog.trace.junit.utils.config.WithConfig; +import datadog.trace.test.util.DDJavaSpecification; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.msgpack.core.MessagePack; +import org.msgpack.core.MessageUnpacker; +import org.msgpack.value.Value; + +@WithConfig(key = "ai_guard.enabled", value = "true") +class MessageWriterTest extends DDJavaSpecification { + + private EncodingCache encodingCache; + private GrowableBuffer buffer; + private MsgPackWriter writer; + + @BeforeEach + void setup() { + HashMap cache = new HashMap<>(); + encodingCache = + chars -> cache.computeIfAbsent(chars, s -> s.toString().getBytes(StandardCharsets.UTF_8)); + buffer = new GrowableBuffer(1024); + writer = new MsgPackWriter(buffer); + } + + @Test + void testWriteMessage() throws IOException { + AIGuard.Message message = AIGuard.Message.message("user", "What day is today?"); + + writer.writeObject(message, encodingCache); + + try (MessageUnpacker unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) { + Map value = asStringValueMap(unpacker.unpackValue()); + assertEquals(2, value.size()); + assertEquals("user", value.get("role")); + assertEquals("What day is today?", value.get("content")); + } + } + + @Test + void testWriteToolCall() throws IOException { + AIGuard.Message message = + AIGuard.Message.assistant( + AIGuard.ToolCall.toolCall("call_1", "function_1", "args_1"), + AIGuard.ToolCall.toolCall("call_2", "function_2", "args_2")); + + writer.writeObject(message, encodingCache); + + try (MessageUnpacker unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) { + Map value = asStringKeyMap(unpacker.unpackValue()); + assertEquals(2, value.size()); + assertEquals("assistant", asString(value.get("role"))); + + List toolCalls = value.get("tool_calls").asArrayValue().list(); + assertEquals(2, toolCalls.size()); + + Map firstCall = asStringKeyMap(toolCalls.get(0)); + assertEquals("call_1", asString(firstCall.get("id"))); + Map firstFunction = asStringValueMap(firstCall.get("function")); + assertEquals("function_1", firstFunction.get("name")); + assertEquals("args_1", firstFunction.get("arguments")); + + Map secondCall = asStringKeyMap(toolCalls.get(1)); + assertEquals("call_2", asString(secondCall.get("id"))); + Map secondFunction = asStringValueMap(secondCall.get("function")); + assertEquals("function_2", secondFunction.get("name")); + assertEquals("args_2", secondFunction.get("arguments")); + } + } + + @Test + void testWriteToolOutput() throws IOException { + AIGuard.Message message = AIGuard.Message.tool("call_1", "output"); + + writer.writeObject(message, encodingCache); + + try (MessageUnpacker unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) { + Map value = asStringValueMap(unpacker.unpackValue()); + assertEquals(3, value.size()); + assertEquals("tool", value.get("role")); + assertEquals("call_1", value.get("tool_call_id")); + assertEquals("output", value.get("content")); + } + } + + @Test + void testWriteMessageWithTextContentParts() throws IOException { + AIGuard.Message message = + AIGuard.Message.message( + "user", Collections.singletonList(AIGuard.ContentPart.text("Hello world"))); + + writer.writeObject(message, encodingCache); + + try (MessageUnpacker unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) { + Map value = asStringKeyMap(unpacker.unpackValue()); + assertEquals(2, value.size()); + assertEquals("user", asString(value.get("role"))); + + List contentParts = value.get("content").asArrayValue().list(); + assertEquals(1, contentParts.size()); + + Map part = asStringKeyMap(contentParts.get(0)); + assertEquals("text", asString(part.get("type"))); + assertEquals("Hello world", asString(part.get("text"))); + } + } + + @Test + void testWriteMessageWithImageUrlContentParts() throws IOException { + AIGuard.Message message = + AIGuard.Message.message( + "user", + Collections.singletonList( + AIGuard.ContentPart.imageUrl("https://example.com/image.jpg"))); + + writer.writeObject(message, encodingCache); + + try (MessageUnpacker unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) { + Map value = asStringKeyMap(unpacker.unpackValue()); + assertEquals(2, value.size()); + assertEquals("user", asString(value.get("role"))); + + List contentParts = value.get("content").asArrayValue().list(); + assertEquals(1, contentParts.size()); + + Map part = asStringKeyMap(contentParts.get(0)); + assertEquals("image_url", asString(part.get("type"))); + + Map imageUrl = asStringKeyMap(part.get("image_url")); + assertEquals("https://example.com/image.jpg", asString(imageUrl.get("url"))); + } + } + + @Test + void testWriteMessageWithMixedContentParts() throws IOException { + AIGuard.Message message = + AIGuard.Message.message( + "user", + Arrays.asList( + AIGuard.ContentPart.text("Describe this:"), + AIGuard.ContentPart.imageUrl("https://example.com/image.jpg"), + AIGuard.ContentPart.text("What is it?"))); + + writer.writeObject(message, encodingCache); + + try (MessageUnpacker unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) { + Map value = asStringKeyMap(unpacker.unpackValue()); + assertEquals(2, value.size()); + assertEquals("user", asString(value.get("role"))); + + List contentParts = value.get("content").asArrayValue().list(); + assertEquals(3, contentParts.size()); + + Map part1 = asStringKeyMap(contentParts.get(0)); + assertEquals("text", asString(part1.get("type"))); + assertEquals("Describe this:", asString(part1.get("text"))); + + Map part2 = asStringKeyMap(contentParts.get(1)); + assertEquals("image_url", asString(part2.get("type"))); + Map imageUrl = asStringKeyMap(part2.get("image_url")); + assertEquals("https://example.com/image.jpg", asString(imageUrl.get("url"))); + + Map part3 = asStringKeyMap(contentParts.get(2)); + assertEquals("text", asString(part3.get("type"))); + assertEquals("What is it?", asString(part3.get("text"))); + } + } + + @Test + void testContentPartsTypeSerializesAsStringNotInteger() throws IOException { + AIGuard.Message message = + AIGuard.Message.message( + "user", Collections.singletonList(AIGuard.ContentPart.text("Test"))); + + writer.writeObject(message, encodingCache); + + try (MessageUnpacker unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) { + Map value = asStringKeyMap(unpacker.unpackValue()); + List contentParts = value.get("content").asArrayValue().list(); + Map part = asStringKeyMap(contentParts.get(0)); + + assertTrue(part.get("type").isStringValue()); + assertFalse(part.get("type").isIntegerValue()); + assertEquals("text", asString(part.get("type"))); + } + } + + @Test + void testBackwardCompatibilityWithStringContent() throws IOException { + AIGuard.Message message = AIGuard.Message.message("user", "Plain text message"); + + writer.writeObject(message, encodingCache); + + try (MessageUnpacker unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) { + Map value = asStringValueMap(unpacker.unpackValue()); + assertEquals(2, value.size()); + assertEquals("user", value.get("role")); + assertEquals("Plain text message", value.get("content")); + } + } + + private static Map mapValue( + Value values, Function keyMapper, Function valueMapper) { + Map result = new LinkedHashMap<>(); + for (Map.Entry entry : values.asMapValue().entrySet()) { + result.put(keyMapper.apply(entry.getKey()), valueMapper.apply(entry.getValue())); + } + return result; + } + + private static Map asStringKeyMap(Value values) { + return mapValue(values, MessageWriterTest::asString, Function.identity()); + } + + private static Map asStringValueMap(Value values) { + return mapValue(values, MessageWriterTest::asString, MessageWriterTest::asString); + } + + private static String asString(Value value) { + return value.asStringValue().asString(); + } +}