From 784dd98f658357c4149cc85a8f53d5a67f645987 Mon Sep 17 00:00:00 2001 From: svetanis Date: Sat, 13 Jun 2026 10:57:38 -0700 Subject: [PATCH] Enforce strict JSON Schema compliance to prevent OpenAI 400 Bad Request errors --- .../models/chat/ChatCompletionsCommon.java | 43 +++++++--- .../models/chat/ChatCompletionsRequest.java | 50 +++++++++-- .../models/chat/ChatCompletionsResponse.java | 4 +- .../chat/ChatCompletionsCommonTest.java | 75 +++++++++++++++++ .../chat/ChatCompletionsRequestTest.java | 83 ++++++++++++++++++- .../chat/ChatCompletionsResponseTest.java | 39 +++++++++ 6 files changed, 272 insertions(+), 22 deletions(-) create mode 100644 core/src/test/java/com/google/adk/models/chat/ChatCompletionsCommonTest.java diff --git a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java index 1ed997824..530154727 100644 --- a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java +++ b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java @@ -19,8 +19,12 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JsonMappingException; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableMap; import com.google.genai.types.FunctionCall; import com.google.genai.types.Part; import java.util.Base64; @@ -37,6 +41,10 @@ private ChatCompletionsCommon() {} private static final ObjectMapper objectMapper = new ObjectMapper(); + static final String EMPTY_JSON_OBJECT = "{}"; + static final ImmutableMap EMPTY_PARAMETERS_SCHEMA = + ImmutableMap.of("type", "object", "properties", ImmutableMap.of()); + public static final String ROLE_ASSISTANT = "assistant"; public static final String ROLE_MODEL = "model"; @@ -157,6 +165,21 @@ public Part applyThoughtSignature(Part part) { } } + static ImmutableMap parseToolCallArguments(String arguments, ObjectMapper mapper) + throws JsonProcessingException { + if (arguments == null || arguments.trim().isEmpty()) { + return ImmutableMap.of(); + } + Map result = + mapper.readValue(arguments, new TypeReference>() {}); + if (result == null) { + throw JsonMappingException.from( + (JsonParser) null, + "JSON literal 'null' is not a valid JSON object for tool call arguments"); + } + return ImmutableMap.copyOf(result); + } + /** * See * https://developers.openai.com/api/reference/resources/chat#(resource)%20chat.completions%20%3E%20(model)%20chat_completion_message_function_tool_call%20%3E%20(schema) @@ -181,21 +204,21 @@ public FunctionCall toFunctionCall(@Nullable String toolCallId) { if (name != null) { fcBuilder.name(name); } - if (arguments != null && !arguments.isEmpty()) { - try { - Map args = - objectMapper.readValue(arguments, new TypeReference>() {}); - fcBuilder.args(args); - } catch (Exception e) { - throw new IllegalArgumentException( - "Failed to parse function arguments JSON: " + arguments, e); - } - } + fcBuilder.args(parseArguments(arguments)); if (toolCallId != null) { fcBuilder.id(toolCallId); } return fcBuilder.build(); } + + private ImmutableMap parseArguments(String arguments) { + try { + return parseToolCallArguments(arguments, objectMapper); + } catch (Exception e) { + throw new IllegalArgumentException( + "Failed to parse function arguments JSON: " + arguments, e); + } + } } /** diff --git a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsRequest.java b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsRequest.java index 2ad0733b9..f9d4adf6e 100644 --- a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsRequest.java +++ b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsRequest.java @@ -21,8 +21,12 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonValue; +import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JsonSerializer; import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializerProvider; +import com.fasterxml.jackson.databind.module.SimpleModule; import com.google.adk.JsonBaseModel; import com.google.adk.models.LlmRequest; import com.google.common.collect.ImmutableList; @@ -32,6 +36,8 @@ import com.google.genai.types.FunctionResponse; import com.google.genai.types.GenerateContentConfig; import com.google.genai.types.Part; +import com.google.genai.types.Type; +import java.io.IOException; import java.util.ArrayList; import java.util.Base64; import java.util.List; @@ -270,7 +276,28 @@ public final class ChatCompletionsRequest { public Map extraBody; private static final Logger logger = LoggerFactory.getLogger(ChatCompletionsRequest.class); - private static final ObjectMapper objectMapper = JsonBaseModel.getMapper(); + + /** + * Registers a custom serializer to force JSON Schema types to lowercase (e.g., "STRING" -> + * "string"). The genai SDK uses uppercase Enums for schema types, which strict OpenAI-compatible + * endpoints reject with HTTP 400. + */ + private static SimpleModule schemaNormalizerModule() { + SimpleModule module = new SimpleModule(); + module.addSerializer( + Type.class, + new JsonSerializer() { + @Override + public void serialize(Type value, JsonGenerator gen, SerializerProvider serializers) + throws IOException { + gen.writeString(value.toString().toLowerCase()); + } + }); + return module; + } + + private static final ObjectMapper objectMapper = + JsonBaseModel.getMapper().copy().registerModule(schemaNormalizerModule()); /** * Converts a standard {@link LlmRequest} into a {@link ChatCompletionsRequest} for @@ -476,7 +503,10 @@ private static ChatCompletionsCommon.ToolCall processFunctionCallPart(Part part) function.arguments = objectMapper.writeValueAsString(fc.args().get()); } catch (Exception e) { logger.warn("Failed to serialize function arguments", e); + function.arguments = ChatCompletionsCommon.EMPTY_JSON_OBJECT; } + } else { + function.arguments = ChatCompletionsCommon.EMPTY_JSON_OBJECT; } toolCall.function = function; part.thoughtSignature() @@ -505,7 +535,10 @@ private static Message processFunctionResponsePart(Part part) { toolResp.content = new MessageContent(objectMapper.writeValueAsString(fr.response().get())); } catch (Exception e) { logger.warn("Failed to serialize tool response", e); + toolResp.content = new MessageContent(ChatCompletionsCommon.EMPTY_JSON_OBJECT); } + } else { + toolResp.content = new MessageContent(ChatCompletionsCommon.EMPTY_JSON_OBJECT); } return toolResp; } @@ -570,12 +603,15 @@ private static void handleTools(GenerateContentConfig config, ChatCompletionsReq FunctionDefinition def = new FunctionDefinition(); def.name = fd.name().orElse(""); def.description = fd.description().orElse(""); - fd.parameters() - .ifPresent( - params -> - def.parameters = - objectMapper.convertValue( - params, new TypeReference>() {})); + if (fd.parameters().isPresent()) { + def.parameters = + objectMapper.convertValue( + fd.parameters().get(), new TypeReference>() {}); + } else { + // OpenAI-compatible APIs (like Groq) strictly require the parameters object + // to exist, even for zero-argument functions. + def.parameters = ChatCompletionsCommon.EMPTY_PARAMETERS_SCHEMA; + } tool.function = def; tools.add(tool); } diff --git a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java index 6cb25f38f..768af850e 100644 --- a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java +++ b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java @@ -20,7 +20,6 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.models.LlmResponse; import com.google.common.collect.ImmutableList; @@ -836,8 +835,7 @@ private ImmutableList getFinalToolCallParts() { if (argsSb != null && argsSb.length() > 0) { try { Map args = - objectMapper.readValue( - argsSb.toString(), new TypeReference>() {}); + ChatCompletionsCommon.parseToolCallArguments(argsSb.toString(), objectMapper); fc = fc.toBuilder().args(args).build(); part = part.toBuilder().functionCall(fc).build(); } catch (JsonProcessingException e) { diff --git a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsCommonTest.java b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsCommonTest.java new file mode 100644 index 000000000..8ddafe0f9 --- /dev/null +++ b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsCommonTest.java @@ -0,0 +1,75 @@ +package com.google.adk.models.chat; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableMap; +import java.util.Map; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class ChatCompletionsCommonTest { + + private ObjectMapper objectMapper; + + @Before + public void setUp() { + objectMapper = new ObjectMapper(); + } + + @Test + public void parseToolCallArguments_withValidJson() throws Exception { + String json = "{\"pr_number\": 1042, \"reason\": \"review\"}"; + ImmutableMap args = + ChatCompletionsCommon.parseToolCallArguments(json, objectMapper); + assertThat(args).hasSize(2); + assertThat(args.get("pr_number")).isEqualTo(1042); + assertThat(args.get("reason")).isEqualTo("review"); + assertThat(args).isInstanceOf(ImmutableMap.class); + } + + @Test + public void parseToolCallArguments_withEmptyString() throws Exception { + Map args = ChatCompletionsCommon.parseToolCallArguments("", objectMapper); + assertThat(args).isEmpty(); + } + + @Test + public void parseToolCallArguments_withNullString() throws Exception { + Map args = ChatCompletionsCommon.parseToolCallArguments(null, objectMapper); + assertThat(args).isEmpty(); + } + + @Test + public void parseToolCallArguments_withWhitespaceString() throws Exception { + Map args = ChatCompletionsCommon.parseToolCallArguments(" ", objectMapper); + assertThat(args).isEmpty(); + } + + @Test + public void parseToolCallArguments_withInvalidJson_throwsException() { + assertThrows( + JsonProcessingException.class, + () -> ChatCompletionsCommon.parseToolCallArguments("none", objectMapper)); + + assertThrows( + JsonProcessingException.class, + () -> ChatCompletionsCommon.parseToolCallArguments("{bad_json:", objectMapper)); + } + + @Test + public void parseToolCallArguments_withLiteralNullString_throwsException() { + JsonProcessingException exception = + assertThrows( + JsonProcessingException.class, + () -> ChatCompletionsCommon.parseToolCallArguments("null", objectMapper)); + assertThat(exception) + .hasMessageThat() + .contains("JSON literal 'null' is not a valid JSON object for tool call arguments"); + } +} diff --git a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java index ec4246f90..ac61fcbcc 100644 --- a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java +++ b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java @@ -34,6 +34,7 @@ import com.google.genai.types.FunctionResponse; import com.google.genai.types.GenerateContentConfig; import com.google.genai.types.Part; +import com.google.genai.types.Schema; import com.google.genai.types.Tool; import com.google.genai.types.ToolConfig; import java.util.AbstractMap; @@ -567,6 +568,84 @@ public void testFromLlmRequest_withFunctionCall() throws Exception { assertThat(msg.toolCalls.get(0).function.arguments).isEqualTo("{\"location\":\"Paris\"}"); } + @Test + public void testFromLlmRequest_withAbsentFunctionArguments() throws Exception { + FunctionCall functionCall = FunctionCall.builder().id("call_123").name("get_time").build(); + Part part = Part.builder().functionCall(functionCall).build(); + Content content = Content.builder().role("model").parts(ImmutableList.of(part)).build(); + + LlmRequest llmRequest = + LlmRequest.builder().model("gemini-1.5-pro").contents(ImmutableList.of(content)).build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.messages).hasSize(1); + ChatCompletionsRequest.Message msg = request.messages.get(0); + assertThat(msg.role).isEqualTo("assistant"); + assertThat(msg.toolCalls).hasSize(1); + assertThat(msg.toolCalls.get(0).function.name).isEqualTo("get_time"); + assertThat(msg.toolCalls.get(0).function.arguments).isEqualTo("{}"); + } + + @Test + public void testFromLlmRequest_withAbsentParameters() throws Exception { + FunctionDeclaration function = + FunctionDeclaration.builder().name("test_func").description("A test function").build(); + + Tool tool = Tool.builder().functionDeclarations(ImmutableList.of(function)).build(); + GenerateContentConfig config = + GenerateContentConfig.builder().tools(ImmutableList.of(tool)).build(); + + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .config(config) + .contents(ImmutableList.of()) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.tools).hasSize(1); + Map params = (Map) request.tools.get(0).function.parameters; + assertThat(params.get("type")).isEqualTo("object"); + @SuppressWarnings("unchecked") + Map props = (Map) params.get("properties"); + assertThat(props).isEmpty(); + } + + @Test + public void testFromLlmRequest_normalizesSchemaTypeToLowerCase() throws Exception { + Schema param1Schema = Schema.builder().type("STRING").build(); + + Schema functionSchema = + Schema.builder().type("OBJECT").properties(ImmutableMap.of("param1", param1Schema)).build(); + + FunctionDeclaration function = + FunctionDeclaration.builder().name("test_func").parameters(functionSchema).build(); + + Tool tool = Tool.builder().functionDeclarations(ImmutableList.of(function)).build(); + GenerateContentConfig config = + GenerateContentConfig.builder().tools(ImmutableList.of(tool)).build(); + + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .config(config) + .contents(ImmutableList.of()) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.tools).hasSize(1); + Map params = (Map) request.tools.get(0).function.parameters; + assertThat(params.get("type")).isEqualTo("object"); + @SuppressWarnings("unchecked") + Map props = (Map) params.get("properties"); + @SuppressWarnings("unchecked") + Map param1 = (Map) props.get("param1"); + assertThat(param1.get("type")).isEqualTo("string"); + } + @Test public void testFromLlmRequest_withStreamOptions() throws Exception { LlmRequest llmRequest = @@ -628,11 +707,11 @@ public void testFromLlmRequest_withFunctionResponse() throws Exception { assertThat(request.messages.get(1).role).isEqualTo("tool"); assertThat(request.messages.get(1).toolCallId).isEmpty(); - assertThat(request.messages.get(1).content).isNull(); + assertThat(request.messages.get(1).content.getValue()).isEqualTo("{}"); assertThat(request.messages.get(2).role).isEqualTo("tool"); assertThat(request.messages.get(2).toolCallId).isEqualTo("call_faulty"); - assertThat(request.messages.get(2).content).isNull(); + assertThat(request.messages.get(2).content.getValue()).isEqualTo("{}"); } @Test diff --git a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java index 97f35576d..d93486e98 100644 --- a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java +++ b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java @@ -1200,6 +1200,45 @@ public void testChunkCollection_streamingToolCall_backfillsMessageLevelSignature assertThat(finalToolPart.thoughtSignature()).hasValue(streamingToolSignature); } + @Test + public void testChunkCollection_streamingToolCall_parsesValidJsonArgs() throws Exception { + String chunk1 = + "{\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"do_thing\",\"arguments\":\"{\\\"key\\\": \\\"value\\\"}\"}}]}}]}"; + String chunk2 = "{\"choices\":[{\"finish_reason\":\"tool_calls\"}]}"; + + ImmutableList all = runStream(chunk1, chunk2); + + assertThat(all).hasSize(1); + LlmResponse finalResponse = all.get(0); + Part finalToolPart = finalResponse.content().get().parts().get().get(0); + assertThat(finalToolPart.functionCall().get().name()).hasValue("do_thing"); + assertThat(finalToolPart.functionCall().get().args().get().get("key")).isEqualTo("value"); + } + + @Test + public void testChunkCollection_streamingToolCall_handlesEmptyArgs() throws Exception { + String chunk1 = + "{\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"do_thing\",\"arguments\":\"\"}}]}}]}"; + String chunk2 = "{\"choices\":[{\"finish_reason\":\"tool_calls\"}]}"; + + ImmutableList all = runStream(chunk1, chunk2); + + assertThat(all).hasSize(1); + LlmResponse finalResponse = all.get(0); + Part finalToolPart = finalResponse.content().get().parts().get().get(0); + assertThat(finalToolPart.functionCall().get().name()).hasValue("do_thing"); + assertThat(finalToolPart.functionCall().get().args()).isEmpty(); + } + + @Test + public void testChunkCollection_streamingToolCall_throwsOnInvalidJsonArgs() { + String chunk1 = + "{\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"do_thing\",\"arguments\":\"none\"}}]}}]}"; + String chunk2 = "{\"choices\":[{\"finish_reason\":\"tool_calls\"}]}"; + + org.junit.Assert.assertThrows(IllegalArgumentException.class, () -> runStream(chunk1, chunk2)); + } + // ----- Round-trip: Part(sig) --> request --> response --> Part(sig) bytewise equal ------- @Test