
Commit 9ec2692

Fix function calling examples
1 parent beadef0 commit 9ec2692

11 files changed, +202 -212 lines

README.md

Lines changed: 22 additions & 45 deletions
@@ -232,11 +232,9 @@ In an API call, you can describe functions and have the model intelligently choo
 use openai_dive::v1::api::Client;
 use openai_dive::v1::models::Gpt4Engine;
 use openai_dive::v1::resources::chat::{
-    ChatCompletionFunction, ChatCompletionParameters, ChatCompletionTool, ChatCompletionToolChoice,
-    ChatCompletionToolChoiceFunction, ChatCompletionToolChoiceFunctionName, ChatCompletionToolType,
-    ChatMessage, Role,
+    ChatCompletionFunction, ChatCompletionParameters, ChatCompletionTool, ChatCompletionToolType, ChatMessage,
+    ChatMessageContent,
 };
-use openai_dive::v1::resources::shared::FinishReason;
 use rand::Rng;
 use serde::{Deserialize, Serialize};
 use serde_json::{json, Value};
@@ -247,22 +245,14 @@ async fn main() {
 
     let client = Client::new(api_key);
 
-    let mut messages = vec![ChatMessage {
-        content: ChatMessageContent::Text("Give me a random number between 25 and 50?".to_string()),
+    let messages = vec![ChatMessage {
+        content: ChatMessageContent::Text("Give me a random number between 100 and no more than 150?".to_string()),
         ..Default::default()
     }];
 
     let parameters = ChatCompletionParameters {
         model: Gpt4Engine::Gpt41106Preview.to_string(),
         messages: messages.clone(),
-        tool_choice: Some(ChatCompletionToolChoice::ChatCompletionToolChoiceFunction(
-            ChatCompletionToolChoiceFunction {
-                r#type: Some(ChatCompletionToolType::Function),
-                function: ChatCompletionToolChoiceFunctionName {
-                    name: "get_random_number".to_string(),
-                },
-            },
-        )),
         tools: Some(vec![ChatCompletionTool {
             r#type: ChatCompletionToolType::Function,
             function: ChatCompletionFunction {
@@ -273,7 +263,8 @@ async fn main() {
                 "properties": {
                     "min": {"type": "integer", "description": "Minimum value of the random number."},
                     "max": {"type": "integer", "description": "Maximum value of the random number."},
-                }
+                },
+                "required": ["min", "max"],
             }),
         },
     }]),
@@ -282,36 +273,22 @@ async fn main() {
 
     let result = client.chat().create(parameters).await.unwrap();
 
-    for choice in result.choices.iter() {
-        if choice.finish_reason == FinishReason::StopSequenceReached {
-            if let Some(tool_calls) = &choice.message.tool_calls {
-                for tool_call in tool_calls.iter() {
-                    let random_numbers =
-                        serde_json::from_str(&tool_call.function.arguments).unwrap();
-
-                    if tool_call.function.name == "get_random_number" {
-                        let random_number_result = get_random_number(random_numbers);
-
-                        messages.push(ChatMessage {
-                            role: Role::Function,
-                            content: ChatMessageContent::Text(
-                                serde_json::to_string(&random_number_result).unwrap(),
-                            ),
-                            name: Some("get_random_number".to_string()),
-                            ..Default::default()
-                        });
-
-                        let parameters = ChatCompletionParameters {
-                            model: Gpt4Engine::Gpt41106Preview.to_string(),
-                            messages: messages.clone(),
-                            ..Default::default()
-                        };
-
-                        let result = client.chat().create(parameters).await.unwrap();
-
-                        println!("{:#?}", result);
-                    }
-                }
+    let message = result.choices[0].message.clone();
+
+    if let Some(tool_calls) = message.tool_calls {
+        for tool_call in tool_calls {
+            let name = tool_call.function.name;
+            let arguments = tool_call.function.arguments;
+
+            if name == "get_random_number" {
+                let random_numbers: RandomNumber = serde_json::from_str(&arguments).unwrap();
+
+                println!("Min: {:?}", &random_numbers.min);
+                println!("Max: {:?}", &random_numbers.max);
+
+                let random_number_result = get_random_number(random_numbers);
+
+                println!("Random number between those numbers: {:?}", random_number_result.clone());
            }
        }
    }
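
The snippet above deserializes the tool-call arguments into a RandomNumber struct and hands it to get_random_number, neither of which is shown in this hunk. Below is a minimal sketch of what those two items could look like, based only on the names and imports already used above; the field types, the Value return type, and the body are assumptions for illustration, not code from this commit.

use rand::Rng;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};

// Sketch only: the u32 fields and the Value return type are assumptions.
#[derive(Serialize, Deserialize)]
struct RandomNumber {
    min: u32, // minimum value requested by the model
    max: u32, // maximum value requested by the model
}

fn get_random_number(params: RandomNumber) -> Value {
    // gen_range panics if min > max; acceptable for a sketch, guard it in real code.
    let number = rand::thread_rng().gen_range(params.min..=params.max);

    json!(number)
}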

examples/chat/create_chat_completion/src/main.rs

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ async fn main() {
             },
             ChatMessage {
                 role: Role::User,
-                content: ChatMessageContent::Text("Where are you located?".to_string()),
+                content: ChatMessageContent::Text("What is the capital of Vietnam?".to_string()),
                 ..Default::default()
             },
         ],

examples/chat/create_chat_completion_stream/src/main.rs

Lines changed: 4 additions & 5 deletions
@@ -1,8 +1,7 @@
 use futures::StreamExt;
 use openai_dive::v1::api::Client;
-use openai_dive::v1::resources::chat::{
-    ChatCompletionParameters, ChatMessage, ChatMessageContent, Role,
-};
+use openai_dive::v1::models::Gpt4Engine;
+use openai_dive::v1::resources::chat::{ChatCompletionParameters, ChatMessage, ChatMessageContent, Role};
 use std::env;
 
 #[tokio::main]
@@ -12,7 +11,7 @@ async fn main() {
     let client = Client::new(api_key);
 
     let parameters = ChatCompletionParameters {
-        model: "gpt-3.5-turbo-16k-0613".to_string(),
+        model: Gpt4Engine::Gpt41106Preview.to_string(),
         messages: vec![
             ChatMessage {
                 role: Role::User,
@@ -21,7 +20,7 @@ async fn main() {
             },
             ChatMessage {
                 role: Role::User,
-                content: ChatMessageContent::Text("Where are you located?".to_string()),
+                content: ChatMessageContent::Text("What is the capital of Vietnam?".to_string()),
                 ..Default::default()
             },
         ],
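
Only the import and parameter changes are visible in these hunks; the part of the example that actually consumes the stream sits outside the diff. The following is a rough, hypothetical sketch of such a consumption loop: the create_stream method name and the chunk/delta layout are assumptions based on typical streaming chat clients, not taken from this diff.

// Hypothetical sketch; method name and response shape are assumptions.
let mut stream = client.chat().create_stream(parameters).await.unwrap();

while let Some(response) = stream.next().await {
    match response {
        Ok(chunk) => {
            for choice in chunk.choices {
                // Print each token fragment as soon as it arrives.
                if let Some(content) = choice.delta.content {
                    print!("{}", content);
                }
            }
        }
        Err(error) => eprintln!("stream error: {:#?}", error),
    }
}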

examples/chat/function_calling/src/main.rs

Lines changed: 20 additions & 34 deletions
@@ -1,10 +1,9 @@
 use openai_dive::v1::api::Client;
 use openai_dive::v1::models::Gpt4Engine;
 use openai_dive::v1::resources::chat::{
-    ChatCompletionFunction, ChatCompletionParameters, ChatCompletionTool, ChatCompletionToolChoice, ChatCompletionToolChoiceFunction, ChatCompletionToolChoiceFunctionName, ChatCompletionToolType,
-    ChatMessage, ChatMessageContent, Role,
+    ChatCompletionFunction, ChatCompletionParameters, ChatCompletionTool, ChatCompletionToolType, ChatMessage,
+    ChatMessageContent,
 };
-use openai_dive::v1::resources::shared::FinishReason;
 use rand::Rng;
 use serde::{Deserialize, Serialize};
 use serde_json::{json, Value};
@@ -15,20 +14,14 @@ async fn main() {
 
     let client = Client::new(api_key);
 
-    let mut messages = vec![ChatMessage {
-        content: ChatMessageContent::Text("Give me a random number between 25 and 50?".to_string()),
+    let messages = vec![ChatMessage {
+        content: ChatMessageContent::Text("Give me a random number between 100 and no more than 150?".to_string()),
         ..Default::default()
     }];
 
     let parameters = ChatCompletionParameters {
         model: Gpt4Engine::Gpt41106Preview.to_string(),
         messages: messages.clone(),
-        tool_choice: Some(ChatCompletionToolChoice::ChatCompletionToolChoiceFunction(ChatCompletionToolChoiceFunction {
-            r#type: Some(ChatCompletionToolType::Function),
-            function: ChatCompletionToolChoiceFunctionName {
-                name: "get_random_number".to_string(),
-            },
-        })),
         tools: Some(vec![ChatCompletionTool {
             r#type: ChatCompletionToolType::Function,
             function: ChatCompletionFunction {
@@ -39,7 +32,8 @@ async fn main() {
                 "properties": {
                     "min": {"type": "integer", "description": "Minimum value of the random number."},
                     "max": {"type": "integer", "description": "Maximum value of the random number."},
-                }
+                },
+                "required": ["min", "max"],
             }),
         },
     }]),
@@ -48,33 +42,25 @@ async fn main() {
 
     let result = client.chat().create(parameters).await.unwrap();
 
-    for choice in result.choices.iter() {
-        if choice.finish_reason == FinishReason::StopSequenceReached {
-            if let Some(tool_calls) = &choice.message.tool_calls {
-                for tool_call in tool_calls.iter() {
-                    let random_numbers = serde_json::from_str(&tool_call.function.arguments).unwrap();
+    let message = result.choices[0].message.clone();
 
-                    if tool_call.function.name == "get_random_number" {
-                        let random_number_result = get_random_number(random_numbers);
+    if let Some(tool_calls) = message.tool_calls {
+        for tool_call in tool_calls {
+            let name = tool_call.function.name;
+            let arguments = tool_call.function.arguments;
 
-                        messages.push(ChatMessage {
-                            role: Role::Function,
-                            content: ChatMessageContent::Text(serde_json::to_string(&random_number_result).unwrap()),
-                            name: Some("get_random_number".to_string()),
-                            ..Default::default()
-                        });
+            if name == "get_random_number" {
+                let random_numbers: RandomNumber = serde_json::from_str(&arguments).unwrap();
 
-                        let parameters = ChatCompletionParameters {
-                            model: "gpt-3.5-turbo-0613".to_string(),
-                            messages: messages.clone(),
-                            ..Default::default()
-                        };
+                println!("Min: {:?}", &random_numbers.min);
+                println!("Max: {:?}", &random_numbers.max);
 
-                        let result = client.chat().create(parameters).await.unwrap();
+                let random_number_result = get_random_number(random_numbers);
 
-                        println!("{:#?}", result);
-                    }
-                }
+                println!(
+                    "Random number between those numbers: {:?}",
+                    random_number_result.clone()
+                );
             }
         }
     }
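
The updated example still unwraps the argument JSON, so malformed tool arguments would panic at serde_json::from_str. Below is an optional hardening sketch, not part of this commit, that reports the error instead; it relies only on serde_json returning a Result and reuses the names already bound above.

if name == "get_random_number" {
    // Fall back gracefully instead of panicking on malformed tool arguments.
    match serde_json::from_str::<RandomNumber>(&arguments) {
        Ok(random_numbers) => {
            let random_number_result = get_random_number(random_numbers);
            println!("Random number between those numbers: {:?}", random_number_result);
        }
        Err(error) => eprintln!("could not parse tool arguments {:?}: {}", arguments, error),
    }
}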

examples/chat/function_calling_stream/src/main.rs

Lines changed: 20 additions & 39 deletions
@@ -1,10 +1,9 @@
-use futures::executor::block_on;
 use futures::StreamExt;
 use openai_dive::v1::api::Client;
+use openai_dive::v1::models::Gpt4Engine;
 use openai_dive::v1::resources::chat::{
-    ChatCompletionFunction, ChatCompletionParameters, ChatCompletionTool, ChatCompletionToolChoice,
-    ChatCompletionToolChoiceFunction, ChatCompletionToolChoiceFunctionName, ChatCompletionToolType,
-    ChatMessage, ChatMessageContent, DeltaFunction, Role,
+    ChatCompletionFunction, ChatCompletionParameters, ChatCompletionTool, ChatCompletionToolType, ChatMessage,
+    ChatMessageContent, DeltaFunction,
 };
 use openai_dive::v1::resources::shared::FinishReason;
 use rand::Rng;
@@ -17,22 +16,14 @@ async fn main() {
 
     let client = Client::new(api_key);
 
-    let mut messages = vec![ChatMessage {
-        content: ChatMessageContent::Text("Give me a random number between 25 and 50?".to_string()),
+    let messages = vec![ChatMessage {
+        content: ChatMessageContent::Text("Give me a random number higher than 100 but less than 2*150?".to_string()),
         ..Default::default()
     }];
 
     let parameters = ChatCompletionParameters {
-        model: "gpt-3.5-turbo-0613".to_string(),
+        model: Gpt4Engine::Gpt41106Preview.to_string(),
         messages: messages.clone(),
-        tool_choice: Some(ChatCompletionToolChoice::ChatCompletionToolChoiceFunction(
-            ChatCompletionToolChoiceFunction {
-                r#type: Some(ChatCompletionToolType::Function),
-                function: ChatCompletionToolChoiceFunctionName {
-                    name: "get_random_number".to_string(),
-                },
-            },
-        )),
         tools: Some(vec![ChatCompletionTool {
             r#type: ChatCompletionToolType::Function,
             function: ChatCompletionFunction {
@@ -43,7 +34,8 @@ async fn main() {
                 "properties": {
                     "min": {"type": "integer", "description": "Minimum value of the random number."},
                     "max": {"type": "integer", "description": "Maximum value of the random number."},
-                }
+                },
+                "required": ["min", "max"],
             }),
         },
     }]),
@@ -66,33 +58,22 @@ async fn main() {
                 print!("{}", content);
             }
 
-            if choice.finish_reason == Some(FinishReason::StopSequenceReached) {
-                let random_numbers =
-                    serde_json::from_str(function.clone().arguments.unwrap().as_ref()).unwrap();
-
-                if let Some(name) = function.clone().name {
-                    if name == "get_random_number" {
-                        let random_number_result = get_random_number(random_numbers);
+            if choice.finish_reason == Some(FinishReason::ToolCalls) {
+                let name = function.name.clone().unwrap();
+                let arguments = function.arguments.clone().unwrap();
 
-                        messages.push(ChatMessage {
-                            role: Role::Function,
-                            content: ChatMessageContent::Text(
-                                serde_json::to_string(&random_number_result).unwrap(),
-                            ),
-                            name: Some("get_random_number".to_string()),
-                            ..Default::default()
-                        });
+                if name == "get_random_number" {
+                    let random_numbers: RandomNumber = serde_json::from_str(&arguments).unwrap();
 
-                        let parameters = ChatCompletionParameters {
-                            model: "gpt-3.5-turbo-0613".to_string(),
-                            messages: messages.clone(),
-                            ..Default::default()
-                        };
+                    println!("Min: {:?}", &random_numbers.min);
+                    println!("Max: {:?}", &random_numbers.max);
 
-                        let result = block_on(client.chat().create(parameters));
+                    let random_number_result = get_random_number(random_numbers);
 
-                        println!("{:#?}", result);
-                    }
+                    println!(
+                        "Random number between those numbers: {:?}",
+                        random_number_result.clone()
+                    );
                 }
             }
         }),
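
In the streaming variant the tool-call name and arguments arrive in fragments that are merged into a DeltaFunction before the FinishReason::ToolCalls branch above runs; that merging happens outside this hunk. The following is a library-agnostic sketch of the general accumulation pattern, using a hypothetical ArgumentsBuffer type rather than openai_dive's own delta types.

// Hypothetical accumulator for streamed tool-call fragments.
#[derive(Default)]
struct ArgumentsBuffer {
    name: Option<String>,
    arguments: String,
}

impl ArgumentsBuffer {
    // Merge one streamed fragment: the function name usually arrives once,
    // while the JSON arguments arrive as string pieces to be concatenated.
    fn merge(&mut self, name: Option<&str>, arguments_fragment: Option<&str>) {
        if let Some(name) = name {
            self.name = Some(name.to_string());
        }
        if let Some(fragment) = arguments_fragment {
            self.arguments.push_str(fragment);
        }
    }
}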

examples/chat/rate_limit_headers/src/main.rs

Lines changed: 4 additions & 4 deletions
@@ -1,4 +1,5 @@
 use openai_dive::v1::api::Client;
+use openai_dive::v1::models::Gpt4Engine;
 use openai_dive::v1::resources::chat::{
     ChatCompletionParameters, ChatCompletionResponse, ChatMessage, ChatMessageContent, Role,
 };
@@ -12,7 +13,7 @@ async fn main() {
     let client = Client::new(api_key);
 
     let parameters = ChatCompletionParameters {
-        model: "gpt-3.5-turbo-16k-0613".to_string(),
+        model: Gpt4Engine::Gpt41106Preview.to_string(),
         messages: vec![
             ChatMessage {
                 role: Role::User,
@@ -21,16 +22,15 @@ async fn main() {
             },
             ChatMessage {
                 role: Role::User,
-                content: ChatMessageContent::Text("Where are you located?".to_string()),
+                content: ChatMessageContent::Text("Which country has the largest population?".to_string()),
                 ..Default::default()
             },
         ],
         max_tokens: Some(12),
         ..Default::default()
     };
 
-    let result: ResponseWrapper<ChatCompletionResponse> =
-        client.chat().create_wrapped(parameters).await.unwrap();
+    let result: ResponseWrapper<ChatCompletionResponse> = client.chat().create_wrapped(parameters).await.unwrap();
 
     println!("{:#?}", result.headers);
 
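The example prints result.headers wholesale; the values of interest are OpenAI's x-ratelimit-* response headers. Below is a hedged sketch of filtering for those, assuming the headers field can be iterated as (name, value) string pairs; its concrete type is not visible in this diff.

// Assumption: `result.headers` iterates as (name, value) string pairs.
// The x-ratelimit-* names are the rate-limit headers documented by OpenAI.
for (name, value) in &result.headers {
    if name.starts_with("x-ratelimit-") {
        // e.g. x-ratelimit-remaining-requests, x-ratelimit-remaining-tokens
        println!("{}: {}", name, value);
    }
}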

examples/fine_tuning/create_fine_tuning_job/src/main.rs

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ async fn main() {
     let file_id = env::var("FILE_ID").expect("FILE_ID is not set in the .env file.");
 
     let parameters = CreateFineTuningJobParameters {
-        model: "gpt-3.5-turbo-1106".to_string(),
+        model: "gpt-4-1106-preview".to_string(),
         training_file: file_id,
         hyperparameters: None,
         suffix: None,

examples/models/retrieve_model/src/main.rs

Lines changed: 7 additions & 2 deletions
@@ -3,11 +3,16 @@ use std::env;
 
 #[tokio::main]
 async fn main() {
-    let api_key = env::var("OPENAI_API_KEY").expect("$OPENAI_API_KEY is not set");
+    let api_key = env::var("OPENAI_API_KEY")
+        .expect("$OPENAI_API_KEY is not set");
 
     let client = Client::new(api_key);
 
-    let result = client.models().get("gpt-3.5-turbo-16k-0613").await.unwrap();
+    let result = client
+        .models()
+        .get("gpt-4-1106-preview")
+        .await
+        .unwrap();
 
     println!("{:#?}", result);
 }
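
The call above unwraps, so a missing or renamed model id would panic. An optional variant, not part of this commit, matches on the result instead; the only assumption is that the returned error type implements Debug, which mirrors how the success value is printed above.

// Optional: handle a missing model without panicking.
match client.models().get("gpt-4-1106-preview").await {
    Ok(model) => println!("{:#?}", model),
    Err(error) => eprintln!("could not retrieve model: {:#?}", error),
}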

rustfmt.toml

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-max_width = 200
+max_width = 120
