Skip to content

Commit caadca4

Browse files
authored
Merge pull request tjardoo#46 from tjardoo/chat-image-url-support
Add image url support to chat completion endpoint
2 parents 0fd6e41 + 9ec2692 commit caadca4

File tree

15 files changed

+411
-227
lines changed

15 files changed

+411
-227
lines changed

README.md

Lines changed: 78 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ More information: [set API key](#set-api-key), [add proxy](#add-proxy), [rate li
2424
- [Delete fine-tune model](#delete-fine-tune-model)
2525
- Chat
2626
- [Create chat completion](#create-chat-completion)
27+
- [Create chat completion with image](#create-chat-completion-with-image)
2728
- [Function calling](#function-calling)
2829
- Images
2930
- [Create image](#create-image)
@@ -135,6 +136,7 @@ Creates a model response for the given chat conversation.
135136
136137
```rust
137138
use openai_dive::v1::api::Client;
139+
use openai_dive::v1::models::Gpt4Engine;
138140
use openai_dive::v1::resources::chat::{ChatCompletionParameters, ChatMessage, ChatMessageContent, Role};
139141
use std::env;
140142

@@ -145,16 +147,16 @@ async fn main() {
145147
let client = Client::new(api_key);
146148

147149
let parameters = ChatCompletionParameters {
148-
model: "gpt-3.5-turbo-16k-0613".to_string(),
150+
model: Gpt4Engine::Gpt41106Preview.to_string(),
149151
messages: vec![
150152
ChatMessage {
151153
role: Role::User,
152-
content: Some("Hello!".to_string()),
154+
content: ChatMessageContent::Text("Hello!".to_string()),
153155
..Default::default()
154156
},
155157
ChatMessage {
156158
role: Role::User,
157-
content: Some("Where are you located?".to_string()),
159+
content: ChatMessageContent::Text("What is the capital of Vietnam?".to_string()),
158160
..Default::default()
159161
},
160162
],
@@ -170,6 +172,55 @@ async fn main() {
170172

171173
More information: [Create chat completion](https://platform.openai.com/docs/api-reference/chat/create)
172174

175+
### Create chat completion with image
176+
177+
Creates a model response for the given chat conversation, including an image URL as part of the input so a vision-capable model can describe the image.
178+
179+
```rust
180+
use openai_dive::v1::api::Client;
181+
use openai_dive::v1::models::Gpt4Engine;
182+
use openai_dive::v1::resources::chat::{ChatCompletionParameters, ChatMessage, ChatMessageContent, ImageUrl, ImageUrlType, Role};
183+
use std::env;
184+
185+
#[tokio::main]
186+
async fn main() {
187+
let api_key = env::var("OPENAI_API_KEY").expect("$OPENAI_API_KEY is not set");
188+
189+
let client = Client::new(api_key);
190+
191+
let parameters = ChatCompletionParameters {
192+
model: Gpt4Engine::Gpt4VisionPreview.to_string(),
193+
messages: vec![
194+
ChatMessage {
195+
role: Role::User,
196+
content: ChatMessageContent::Text("What is in this image?".to_string()),
197+
..Default::default()
198+
},
199+
ChatMessage {
200+
role: Role::User,
201+
content: ChatMessageContent::ImageUrl(vec![ImageUrl {
202+
r#type: "image_url".to_string(),
203+
text: None,
204+
image_url: ImageUrlType {
205+
url: "https://images.unsplash.com/photo-1526682847805-721837c3f83b?w=640".to_string(),
206+
detail: None,
207+
},
208+
}]),
209+
..Default::default()
210+
},
211+
],
212+
max_tokens: Some(50),
213+
..Default::default()
214+
};
215+
216+
let result = client.chat().create(parameters).await.unwrap();
217+
218+
println!("{:#?}", result);
219+
}
220+
```
221+
222+
More information: [Create chat completion](https://platform.openai.com/docs/api-reference/chat/create)
223+
173224
### Function calling
174225

175226
In an API call, you can describe functions and have the model intelligently choose to output a JSON object containing arguments to call one or many functions. The Chat Completions API does not call the function; instead, the model generates JSON that you can use to call the function in your code.
@@ -179,12 +230,11 @@ In an API call, you can describe functions and have the model intelligently choo
179230
180231
```rust
181232
use openai_dive::v1::api::Client;
233+
use openai_dive::v1::models::Gpt4Engine;
182234
use openai_dive::v1::resources::chat::{
183-
ChatCompletionFunction, ChatCompletionParameters, ChatCompletionTool, ChatCompletionToolChoice,
184-
ChatCompletionToolChoiceFunction, ChatCompletionToolChoiceFunctionName, ChatCompletionToolType,
185-
ChatMessage, Role,
235+
ChatCompletionFunction, ChatCompletionParameters, ChatCompletionTool, ChatCompletionToolType, ChatMessage,
236+
ChatMessageContent,
186237
};
187-
use openai_dive::v1::resources::shared::FinishReason;
188238
use rand::Rng;
189239
use serde::{Deserialize, Serialize};
190240
use serde_json::{json, Value};
@@ -195,22 +245,14 @@ async fn main() {
195245

196246
let client = Client::new(api_key);
197247

198-
let mut messages = vec![ChatMessage {
199-
content: Some("Give me a random number between 25 and 50?".to_string()),
248+
let messages = vec![ChatMessage {
249+
content: ChatMessageContent::Text("Give me a random number between 100 and no more than 150?".to_string()),
200250
..Default::default()
201251
}];
202252

203253
let parameters = ChatCompletionParameters {
204-
model: "gpt-3.5-turbo-0613".to_string(),
254+
model: Gpt4Engine::Gpt41106Preview.to_string(),
205255
messages: messages.clone(),
206-
tool_choice: Some(ChatCompletionToolChoice::ChatCompletionToolChoiceFunction(
207-
ChatCompletionToolChoiceFunction {
208-
r#type: Some(ChatCompletionToolType::Function),
209-
function: ChatCompletionToolChoiceFunctionName {
210-
name: "get_random_number".to_string(),
211-
},
212-
},
213-
)),
214256
tools: Some(vec![ChatCompletionTool {
215257
r#type: ChatCompletionToolType::Function,
216258
function: ChatCompletionFunction {
@@ -221,7 +263,8 @@ async fn main() {
221263
"properties": {
222264
"min": {"type": "integer", "description": "Minimum value of the random number."},
223265
"max": {"type": "integer", "description": "Maximum value of the random number."},
224-
}
266+
},
267+
"required": ["min", "max"],
225268
}),
226269
},
227270
}]),
@@ -230,34 +273,22 @@ async fn main() {
230273

231274
let result = client.chat().create(parameters).await.unwrap();
232275

233-
for choice in result.choices.iter() {
234-
if choice.finish_reason == FinishReason::StopSequenceReached {
235-
if let Some(tool_calls) = &choice.message.tool_calls {
236-
for tool_call in tool_calls.iter() {
237-
let random_numbers =
238-
serde_json::from_str(&tool_call.function.arguments).unwrap();
239-
240-
if tool_call.function.name == "get_random_number" {
241-
let random_number_result = get_random_number(random_numbers);
242-
243-
messages.push(ChatMessage {
244-
role: Role::Function,
245-
content: Some(serde_json::to_string(&random_number_result).unwrap()),
246-
name: Some("get_random_number".to_string()),
247-
..Default::default()
248-
});
249-
250-
let parameters = ChatCompletionParameters {
251-
model: "gpt-3.5-turbo-0613".to_string(),
252-
messages: messages.clone(),
253-
..Default::default()
254-
};
255-
256-
let result = client.chat().create(parameters).await.unwrap();
257-
258-
println!("{:#?}", result);
259-
}
260-
}
276+
let message = result.choices[0].message.clone();
277+
278+
if let Some(tool_calls) = message.tool_calls {
279+
for tool_call in tool_calls {
280+
let name = tool_call.function.name;
281+
let arguments = tool_call.function.arguments;
282+
283+
if name == "get_random_number" {
284+
let random_numbers: RandomNumber = serde_json::from_str(&arguments).unwrap();
285+
286+
println!("Min: {:?}", &random_numbers.min);
287+
println!("Max: {:?}", &random_numbers.max);
288+
289+
let random_number_result = get_random_number(random_numbers);
290+
291+
println!("Random number between those numbers: {:?}", random_number_result.clone());
261292
}
262293
}
263294
}

examples/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ members = [
66
"audio/create_speech",
77
"chat/create_chat_completion",
88
"chat/create_chat_completion_stream",
9+
"chat/create_image_chat_completion",
910
"chat/function_calling",
1011
"chat/function_calling_stream",
1112
"chat/rate_limit_headers",

examples/chat/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ resolver = "2"
33
members = [
44
"create_chat_completion",
55
"create_chat_completion_stream",
6+
"create_image_chat_completion",
67
"function_calling",
78
"function_calling_stream",
89
"rate_limit_headers",

examples/chat/create_chat_completion/src/main.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use openai_dive::v1::api::Client;
2-
use openai_dive::v1::resources::chat::{ChatCompletionParameters, ChatMessage, Role};
2+
use openai_dive::v1::models::Gpt4Engine;
3+
use openai_dive::v1::resources::chat::{ChatCompletionParameters, ChatMessage, ChatMessageContent, Role};
34
use std::env;
45

56
#[tokio::main]
@@ -9,16 +10,16 @@ async fn main() {
910
let client = Client::new(api_key);
1011

1112
let parameters = ChatCompletionParameters {
12-
model: "gpt-3.5-turbo-16k-0613".to_string(),
13+
model: Gpt4Engine::Gpt41106Preview.to_string(),
1314
messages: vec![
1415
ChatMessage {
1516
role: Role::User,
16-
content: Some("Hello!".to_string()),
17+
content: ChatMessageContent::Text("Hello!".to_string()),
1718
..Default::default()
1819
},
1920
ChatMessage {
2021
role: Role::User,
21-
content: Some("Where are you located?".to_string()),
22+
content: ChatMessageContent::Text("What is the capital of Vietnam?".to_string()),
2223
..Default::default()
2324
},
2425
],

examples/chat/create_chat_completion_stream/src/main.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use futures::StreamExt;
22
use openai_dive::v1::api::Client;
3-
use openai_dive::v1::resources::chat::{ChatCompletionParameters, ChatMessage, Role};
3+
use openai_dive::v1::models::Gpt4Engine;
4+
use openai_dive::v1::resources::chat::{ChatCompletionParameters, ChatMessage, ChatMessageContent, Role};
45
use std::env;
56

67
#[tokio::main]
@@ -10,16 +11,16 @@ async fn main() {
1011
let client = Client::new(api_key);
1112

1213
let parameters = ChatCompletionParameters {
13-
model: "gpt-3.5-turbo-16k-0613".to_string(),
14+
model: Gpt4Engine::Gpt41106Preview.to_string(),
1415
messages: vec![
1516
ChatMessage {
1617
role: Role::User,
17-
content: Some("Hello!".to_string()),
18+
content: ChatMessageContent::Text("Hello!".to_string()),
1819
..Default::default()
1920
},
2021
ChatMessage {
2122
role: Role::User,
22-
content: Some("Where are you located?".to_string()),
23+
content: ChatMessageContent::Text("What is the capital of Vietnam?".to_string()),
2324
..Default::default()
2425
},
2526
],
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
[package]
2+
name = "create_image_chat_completion"
3+
version = "0.1.0"
4+
edition = "2021"
5+
publish = false
6+
7+
[dependencies]
8+
openai_dive = { path = "./../../../../openai-client" }
9+
tokio = { version = "1.0", features = ["full"] }
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
use openai_dive::v1::api::Client;
2+
use openai_dive::v1::models::Gpt4Engine;
3+
use openai_dive::v1::resources::chat::{ChatCompletionParameters, ChatMessage, ChatMessageContent, ImageUrl, ImageUrlType, Role};
4+
use std::env;
5+
6+
#[tokio::main]
7+
async fn main() {
8+
let api_key = env::var("OPENAI_API_KEY").expect("$OPENAI_API_KEY is not set");
9+
10+
let client = Client::new(api_key);
11+
12+
let parameters = ChatCompletionParameters {
13+
model: Gpt4Engine::Gpt4VisionPreview.to_string(),
14+
messages: vec![
15+
ChatMessage {
16+
role: Role::User,
17+
content: ChatMessageContent::Text("What is in this image?".to_string()),
18+
..Default::default()
19+
},
20+
ChatMessage {
21+
role: Role::User,
22+
content: ChatMessageContent::ImageUrl(vec![ImageUrl {
23+
r#type: "image_url".to_string(),
24+
text: None,
25+
image_url: ImageUrlType {
26+
url: "https://images.unsplash.com/photo-1526682847805-721837c3f83b?w=640".to_string(),
27+
detail: None,
28+
},
29+
}]),
30+
..Default::default()
31+
},
32+
],
33+
max_tokens: Some(50),
34+
..Default::default()
35+
};
36+
37+
let result = client.chat().create(parameters).await.unwrap();
38+
39+
println!("{:#?}", result);
40+
}

0 commit comments

Comments
 (0)