Skip to content

Commit 3fc939a

Browse files
committed
Add @JsonAnySetter support for extraBody deserialization in ChatCompletionRequest
1 parent b6ccb03 commit 3fc939a

File tree

2 files changed

+141
-18
lines changed

2 files changed

+141
-18
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import java.util.stream.Collectors;
2727

2828
import com.fasterxml.jackson.annotation.JsonAnyGetter;
29+
import com.fasterxml.jackson.annotation.JsonAnySetter;
2930
import com.fasterxml.jackson.annotation.JsonFormat;
3031
import com.fasterxml.jackson.annotation.JsonIgnore;
3132
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
@@ -1135,6 +1136,16 @@ public record ChatCompletionRequest(// @formatter:off
11351136
@JsonProperty("safety_identifier") String safetyIdentifier,
11361137
Map<String, Object> extraBody) {
11371138

1139+
/**
1140+
* Compact constructor that ensures extraBody is initialized as a mutable HashMap
1141+
* when null, enabling @JsonAnySetter to populate it during deserialization.
1142+
*/
1143+
public ChatCompletionRequest {
1144+
if (extraBody == null) {
1145+
extraBody = new java.util.HashMap<>();
1146+
}
1147+
}
1148+
11381149
/**
11391150
* Shortcut constructor for a chat completion request with the given messages, model and temperature.
11401151
*
@@ -1231,6 +1242,20 @@ public Map<String, Object> extraBody() {
12311242
return this.extraBody;
12321243
}
12331244

1245+
/**
1246+
* Handles deserialization of unknown properties into the extraBody map.
1247+
* This enables JSON with extra fields to be deserialized into ChatCompletionRequest,
1248+
* which is useful for implementing OpenAI API proxy servers with @RestController.
1249+
* @param key The property name
1250+
* @param value The property value
1251+
*/
1252+
@JsonAnySetter
1253+
private void setExtraBodyProperty(String key, Object value) {
1254+
if (this.extraBody != null) {
1255+
this.extraBody.put(key, value);
1256+
}
1257+
}
1258+
12341259
/**
12351260
* Helper factory that creates a tool_choice of type 'none', 'auto' or selected function by name.
12361261
*/

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/ExtraBodySerializationTest.java

Lines changed: 116 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,10 @@ void testExtraBodySerializationFlattensToTopLevel() throws Exception {
4949
// Act: Serialize to JSON
5050
String json = this.objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(request);
5151

52-
// Debug: Print the actual JSON
53-
System.out.println("=== JSON Output (with @JsonAnyGetter) ===");
54-
System.out.println(json);
55-
5652
// Assert: Verify @JsonAnyGetter flattens fields to top level
5753
assertThat(json).contains("\"top_k\" : 50");
5854
assertThat(json).contains("\"repetition_penalty\" : 1.1");
5955
assertThat(json).doesNotContain("\"extra_body\"");
60-
61-
System.out.println("\n=== Analysis ===");
62-
System.out.println("✓ Fields are FLATTENED to top level (correct!)");
63-
System.out.println(" Format: { \"model\": \"gpt-4\", \"top_k\": 50, \"repetition_penalty\": 1.1 }");
64-
System.out.println(" This matches official OpenAI SDK and LangChain4j behavior");
65-
System.out.println(" This is CORRECT for vLLM, Ollama, and other OpenAI-compatible servers");
6656
}
6757

6858
@Test
@@ -78,10 +68,6 @@ void testExtraBodyWithEmptyMap() throws Exception {
7868
// Act
7969
String json = this.objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(request);
8070

81-
// Debug
82-
System.out.println("\n=== JSON Output (empty extraBody map) ===");
83-
System.out.println(json);
84-
8571
// Assert: No extra fields should appear
8672
assertThat(json).doesNotContain("extra_body");
8773
assertThat(json).doesNotContain("top_k");
@@ -101,13 +87,125 @@ void testExtraBodyNullSerialization() throws Exception {
10187
// Act
10288
String json = this.objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(request);
10389

104-
// Debug
105-
System.out.println("\n=== JSON Output (null extraBody) ===");
106-
System.out.println(json);
107-
10890
// Assert: extra_body should not appear in JSON when null
10991
assertThat(json).doesNotContain("extra_body");
11092
assertThat(json).doesNotContain("top_k");
11193
}
11294

95+
@Test
96+
void testExtraBodyDeserialization() throws Exception {
97+
// Arrange: JSON with extra fields (simulating proxy server receiving request)
98+
String json = """
99+
{
100+
"model": "gpt-4",
101+
"messages": [],
102+
"stream": false,
103+
"top_k": 50,
104+
"repetition_penalty": 1.1,
105+
"custom_param": "test_value"
106+
}
107+
""";
108+
109+
// Act: Deserialize JSON to ChatCompletionRequest
110+
ChatCompletionRequest request = this.objectMapper.readValue(json, ChatCompletionRequest.class);
111+
112+
// Assert: Extra fields should be captured in extraBody map
113+
assertThat(request.extraBody()).isNotNull();
114+
assertThat(request.extraBody()).containsEntry("top_k", 50);
115+
assertThat(request.extraBody()).containsEntry("repetition_penalty", 1.1);
116+
assertThat(request.extraBody()).containsEntry("custom_param", "test_value");
117+
118+
// Assert: Standard fields should be set correctly
119+
assertThat(request.model()).isEqualTo("gpt-4");
120+
assertThat(request.messages()).isEmpty();
121+
assertThat(request.stream()).isFalse();
122+
}
123+
124+
@Test
125+
void testRoundTripSerializationDeserialization() throws Exception {
126+
// Arrange: Create request with extraBody
127+
ChatCompletionRequest originalRequest = new ChatCompletionRequest(List.of(), // messages
128+
"gpt-4", // model
129+
null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, false,
130+
null, null, null, null, null, null, null, null, null, null, null, null,
131+
Map.of("top_k", 50, "min_p", 0.05, "stop_token_ids", List.of(128001, 128009)) // extraBody
132+
);
133+
134+
// Act: Serialize to JSON
135+
String json = this.objectMapper.writeValueAsString(originalRequest);
136+
137+
// Act: Deserialize back to object
138+
ChatCompletionRequest deserializedRequest = this.objectMapper.readValue(json, ChatCompletionRequest.class);
139+
140+
// Assert: All extraBody fields should survive round trip
141+
assertThat(deserializedRequest.extraBody()).isNotNull();
142+
assertThat(deserializedRequest.extraBody()).containsEntry("top_k", 50);
143+
assertThat(deserializedRequest.extraBody()).containsEntry("min_p", 0.05);
144+
assertThat(deserializedRequest.extraBody()).containsKey("stop_token_ids");
145+
146+
// Assert: Standard fields should match
147+
assertThat(deserializedRequest.model()).isEqualTo(originalRequest.model());
148+
assertThat(deserializedRequest.stream()).isEqualTo(originalRequest.stream());
149+
}
150+
151+
@Test
152+
void testDeserializationWithNullExtraBody() throws Exception {
153+
// Arrange: JSON without any extra fields (standard OpenAI request)
154+
String json = """
155+
{
156+
"model": "gpt-4",
157+
"messages": [],
158+
"stream": false,
159+
"temperature": 0.7
160+
}
161+
""";
162+
163+
// Act: Deserialize
164+
ChatCompletionRequest request = this.objectMapper.readValue(json, ChatCompletionRequest.class);
165+
166+
// Assert: extraBody should be null or empty when no extra fields present
167+
// (depending on Jackson configuration and constructor behavior)
168+
if (request.extraBody() != null) {
169+
assertThat(request.extraBody()).isEmpty();
170+
}
171+
172+
// Assert: Standard fields should work
173+
assertThat(request.model()).isEqualTo("gpt-4");
174+
assertThat(request.temperature()).isEqualTo(0.7);
175+
}
176+
177+
@Test
178+
void testDeserializationWithComplexExtraFields() throws Exception {
179+
// Arrange: JSON with real vLLM extra fields (complex types)
180+
String json = """
181+
{
182+
"model": "deepseek-r1",
183+
"messages": [],
184+
"stream": false,
185+
"top_k": 50,
186+
"min_p": 0.05,
187+
"best_of": 3,
188+
"guided_json": "{\\"type\\": \\"object\\", \\"properties\\": {\\"name\\": {\\"type\\": \\"string\\"}}}",
189+
"stop_token_ids": [128001, 128009],
190+
"skip_special_tokens": true
191+
}
192+
""";
193+
194+
// Act: Deserialize
195+
ChatCompletionRequest request = this.objectMapper.readValue(json, ChatCompletionRequest.class);
196+
197+
// Assert: Real vLLM extra fields should be captured
198+
assertThat(request.extraBody()).isNotNull();
199+
assertThat(request.extraBody()).containsEntry("top_k", 50);
200+
assertThat(request.extraBody()).containsEntry("min_p", 0.05);
201+
assertThat(request.extraBody()).containsEntry("best_of", 3);
202+
assertThat(request.extraBody()).containsKey("guided_json");
203+
assertThat(request.extraBody()).containsKey("stop_token_ids");
204+
assertThat(request.extraBody()).containsEntry("skip_special_tokens", true);
205+
206+
// Assert: Complex types should be preserved as String/List
207+
assertThat(request.extraBody().get("guided_json")).isInstanceOf(String.class);
208+
assertThat(request.extraBody().get("stop_token_ids")).isInstanceOf(List.class);
209+
}
210+
113211
}

0 commit comments

Comments
 (0)