Commit 376693b

feat: make LlmGenerationClient::generate return json
fix: handle json value

1 parent 1d5bf05
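
At the core of the commit is the response type in rust/cocoindex/src/llm/mod.rs, which becomes an enum so providers with native structured output can return an already-parsed JSON value instead of serializing it back to a string. A condensed before/after sketch (the full diff appears below):

    // Before: generate() always returned plain text.
    pub struct LlmGenerateResponse {
        pub text: String,
    }

    // After: either text, or a JSON value that needs no re-parsing.
    pub enum LlmGenerateResponse {
        Text(String),
        Json(serde_json::Value),
    }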

7 files changed: +64 -24 lines changed


rust/cocoindex/src/llm/anthropic.rs

Lines changed: 2 additions & 2 deletions

@@ -123,7 +123,7 @@ impl LlmGenerationClient for Client {
         }
         let text = if let Some(json) = extracted_json {
             // Try strict JSON serialization first
-            serde_json::to_string(&json)?
+            return Ok(LlmGenerateResponse::Json(json));
         } else {
             // Fallback: try text if no tool output found
             match &mut resp_json["content"][0]["text"] {
@@ -155,7 +155,7 @@ impl LlmGenerationClient for Client {
             }
         };

-        Ok(LlmGenerateResponse { text })
+        Ok(LlmGenerateResponse::Text(text))
     }

     fn json_schema_options(&self) -> ToJsonSchemaOptions {

rust/cocoindex/src/llm/bedrock.rs

Lines changed: 2 additions & 2 deletions

@@ -148,7 +148,7 @@ impl LlmGenerationClient for Client {

         if let Some(json) = extracted_json {
             // Return the structured output as JSON
-            serde_json::to_string(&json)?
+            return Ok(LlmGenerateResponse::Json(json));
         } else {
             // Fall back to text content
             let mut text_parts = Vec::new();
@@ -165,7 +165,7 @@ impl LlmGenerationClient for Client {
             return Err(anyhow::anyhow!("No content found in Bedrock response"));
         };

-        Ok(LlmGenerateResponse { text })
+        Ok(LlmGenerateResponse::Text(text))
     }

     fn json_schema_options(&self) -> ToJsonSchemaOptions {

rust/cocoindex/src/llm/gemini.rs

Lines changed: 28 additions & 4 deletions

@@ -147,8 +147,11 @@ impl LlmGenerationClient for AiStudioClient {
             });
         }

+        let mut need_json = false;
+
         // If structured output is requested, add schema and responseMimeType
         if let Some(OutputFormat::JsonSchema { schema, .. }) = &request.output_format {
+            need_json = true;
             let mut schema_json = serde_json::to_value(schema)?;
             remove_additional_properties(&mut schema_json);
             payload["generationConfig"] = serde_json::json!({
@@ -161,18 +164,24 @@ impl LlmGenerationClient for AiStudioClient {
         let resp = http::request(|| self.client.post(&url).json(&payload))
             .await
             .context("Gemini API error")?;
-        let resp_json: Value = resp.json().await.context("Invalid JSON")?;
+        let mut resp_json: Value = resp.json().await.context("Invalid JSON")?;

         if let Some(error) = resp_json.get("error") {
             bail!("Gemini API error: {:?}", error);
         }
-        let mut resp_json = resp_json;
+
+        if need_json {
+            return Ok(super::LlmGenerateResponse::Json(serde_json::json!(
+                resp_json["candidates"][0]
+            )));
+        }
+
         let text = match &mut resp_json["candidates"][0]["content"]["parts"][0]["text"] {
             Value::String(s) => std::mem::take(s),
             _ => bail!("No text in response"),
         };

-        Ok(LlmGenerateResponse { text })
+        Ok(LlmGenerateResponse::Text(text))
     }

     fn json_schema_options(&self) -> ToJsonSchemaOptions {
@@ -333,9 +342,12 @@ impl LlmGenerationClient for VertexAiClient {
                 .set_parts(vec![Part::new().set_text(sys.to_string())])
             });

+        let mut need_json = false;
+
         // Compose generation config
         let mut generation_config = None;
         if let Some(OutputFormat::JsonSchema { schema, .. }) = &request.output_format {
+            need_json = true;
             let schema_json = serde_json::to_value(schema)?;
             generation_config = Some(
                 GenerationConfig::new()
@@ -359,6 +371,18 @@ impl LlmGenerationClient for VertexAiClient {

         // Call the API
         let resp = req.send().await?;
+
+        if need_json {
+            match resp.candidates.into_iter().next() {
+                Some(resp_json) => {
+                    return Ok(super::LlmGenerateResponse::Json(serde_json::json!(
+                        resp_json
+                    )));
+                }
+                None => bail!("No response"),
+            }
+        }
+
         // Extract text from response
         let Some(Data::Text(text)) = resp
             .candidates
@@ -370,7 +394,7 @@ impl LlmGenerationClient for VertexAiClient {
         else {
             bail!("No text in response");
         };
-        Ok(super::LlmGenerateResponse { text })
+        Ok(super::LlmGenerateResponse::Text(text))
     }

     fn json_schema_options(&self) -> ToJsonSchemaOptions {

rust/cocoindex/src/llm/mod.rs

Lines changed: 3 additions & 2 deletions

@@ -66,8 +66,9 @@ pub struct LlmGenerateRequest<'a> {
 }

 #[derive(Debug)]
-pub struct LlmGenerateResponse {
-    pub text: String,
+pub enum LlmGenerateResponse {
+    Text(String),
+    Json(serde_json::Value),
 }

 #[async_trait]
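
A minimal caller-side sketch of consuming the new enum (the `client` and `req` values are hypothetical, and plain serde_json::from_str stands in for the crate's internal utils::deser::from_json_str):

    let res = client.generate(req).await?;
    let json_value: serde_json::Value = match res {
        // Providers without native structured output return text that may
        // itself contain JSON, so the caller parses it.
        LlmGenerateResponse::Text(text) => serde_json::from_str(&text)?,
        // Providers with structured output return an already-parsed value.
        LlmGenerateResponse::Json(value) => value,
    };

The same pattern appears in extract_by_llm.rs at the end of this diff.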

rust/cocoindex/src/llm/ollama.rs

Lines changed: 2 additions & 4 deletions

@@ -108,10 +108,8 @@ impl LlmGenerationClient for Client {
             })
             .await
             .context("Ollama API error")?;
-        let json: OllamaResponse = res.json().await?;
-        Ok(super::LlmGenerateResponse {
-            text: json.response,
-        })
+
+        Ok(super::LlmGenerateResponse::Json(res.json().await?))
     }

     fn json_schema_options(&self) -> super::ToJsonSchemaOptions {

rust/cocoindex/src/llm/openai.rs

Lines changed: 23 additions & 9 deletions

@@ -1,4 +1,4 @@
-use crate::prelude::*;
+use crate::{llm::OutputFormat, prelude::*};
 use base64::prelude::*;

 use super::{LlmEmbeddingClient, LlmGenerationClient, detect_image_mime_type};
@@ -145,15 +145,29 @@ impl LlmGenerationClient for Client {
         )
         .await?;

-        // Extract the response text from the first choice
-        let text = response
-            .choices
-            .into_iter()
-            .next()
-            .and_then(|choice| choice.message.content)
-            .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?;
+        let mut response_iter = response.choices.into_iter();

-        Ok(super::LlmGenerateResponse { text })
+        match request.output_format {
+            Some(OutputFormat::JsonSchema { .. }) => {
+                // Extract the response json from the first choice
+                let response_json = serde_json::json!(
+                    response_iter
+                        .next()
+                        .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?
+                );
+
+                Ok(super::LlmGenerateResponse::Json(response_json))
+            }
+            None => {
+                // Extract the response text from the first choice
+                let text = response_iter
+                    .next()
+                    .and_then(|choice| choice.message.content)
+                    .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?;
+
+                Ok(super::LlmGenerateResponse::Text(text))
+            }
+        }
     }

     fn json_schema_options(&self) -> super::ToJsonSchemaOptions {

rust/cocoindex/src/ops/functions/extract_by_llm.rs

Lines changed: 4 additions & 1 deletion

@@ -113,7 +113,10 @@ impl SimpleFunctionExecutor for Executor {
             }),
         };
         let res = self.client.generate(req).await?;
-        let json_value: serde_json::Value = utils::deser::from_json_str(res.text.as_str())?;
+        let json_value = match res {
+            crate::llm::LlmGenerateResponse::Text(text) => utils::deser::from_json_str(&text)?,
+            crate::llm::LlmGenerateResponse::Json(value) => value,
+        };
         let value = self.value_extractor.extract_value(json_value)?;
         Ok(value)
     }
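
Both arms of that match converge on the same serde_json::Value before it reaches the value extractor. A standalone illustrative sketch of the equivalence, assuming only serde_json (the sample payload is made up):

    use serde_json::json;

    // Text from a provider without structured output, parsed by the Text arm...
    let from_text: serde_json::Value =
        serde_json::from_str(r#"{"name": "Ada"}"#).expect("valid JSON");
    // ...and the value a structured-output provider returns via the Json arm...
    let from_json = json!({"name": "Ada"});
    // ...are identical inputs to extract_value.
    assert_eq!(from_text, from_json);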
