Skip to content

Commit a9fba56

Browse files
hjh320 and ccreutzi
authored and committed
Enhance tool handling in generate methods
Add Tools to generate NV pairs and change the default ToolChoice to "auto" (where supported).
1 parent c603ec7 commit a9fba56

File tree

10 files changed

+280
-32
lines changed

10 files changed

+280
-32
lines changed

+llms/+internal/hasTools.m

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,33 @@
1414
end
1515

1616
methods(Hidden)
17-
function mustBeValidFunctionCall(this, functionCall)
17+
function mustBeValidFunctionCall(this, functionCall, functionNames)
18+
if nargin < 3
19+
functionNames = this.FunctionNames;
20+
end
21+
1822
if ~isempty(functionCall)
1923
mustBeTextScalar(functionCall);
20-
if isempty(this.FunctionNames)
24+
if isempty(functionNames) && ~ismember(functionCall, ["auto", "none"])
2125
error("llms:mustSetFunctionsForCall", llms.utils.errorMessageCatalog.getMessage("llms:mustSetFunctionsForCall"));
2226
end
23-
mustBeMember(functionCall, ["none","auto","required", this.FunctionNames]);
27+
mustBeMember(functionCall, ["none","auto","required", functionNames]);
2428
end
2529
end
2630

27-
function toolChoice = convertToolChoice(this, toolChoice)
31+
function toolChoice = convertToolChoice(this, toolChoice, functionNames)
32+
if nargin < 3
33+
functionNames = this.FunctionNames;
34+
end
35+
2836
% if toolChoice is empty
2937
if isempty(toolChoice)
3038
% if Tools is not empty, the default is 'auto'.
3139
if ~isempty(this.Tools)
3240
toolChoice = "auto";
3341
end
42+
elseif ismember(toolChoice, ["auto", "none"]) && isempty(functionNames)
43+
toolChoice = strings(1,0);
3444
elseif ~ismember(toolChoice,["auto","none","required"])
3545
% if toolChoice is not empty, then it must be "auto", "none",
3646
% "required", or in the format {"type": "function", "function":

+llms/+utils/errorMessageCatalog.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
catalog("llms:deploymentMustBeSpecified") = "Unable to find deployment name. Either set environment variable AZURE_OPENAI_DEPLOYMENT or specify name-value argument ""DeploymentID"".";
5959
catalog("llms:keyMustBeSpecified") = "Unable to find API key. Either set environment variable {1} or specify name-value argument ""APIKey"".";
6060
catalog("llms:mustHaveMessages") = "Message history must not be empty.";
61-
catalog("llms:mustSetFunctionsForCall") = "When no functions are defined, ToolChoice must not be specified.";
61+
catalog("llms:mustSetFunctionsForCall") = "When Tools is empty, ToolChoice must be ""none"" or ""auto"".";
6262
catalog("llms:mustBeMessagesOrTxt") = "Message must be nonempty string, character array, cell array of character vectors, or messageHistory object.";
6363
catalog("llms:invalidOptionAndValueForModel") = "'{1}' with value '{2}' is not supported for model ""{3}"".";
6464
catalog("llms:invalidOptionForModel") = "Invalid argument name {1} for model ""{2}"".";

azureChat.m

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,13 @@
173173
% MaxNumTokens - Maximum number of tokens in the generated response.
174174
% Default value is inf.
175175
%
176-
% ToolChoice - Function to execute. 'none', 'auto', 'required',
176+
% ToolChoice - Function to execute. "none", "auto", "required",
177177
% or specify the function to call.
178+
% The default value is "auto".
179+
%
180+
% Tools - Array of openAIFunction objects representing
181+
% custom functions to be used during chat completions.
182+
% The default value is CHAT.Tools.
178183
%
179184
% Seed - An integer value to use to obtain
180185
% reproducible responses
@@ -234,10 +239,21 @@
234239
nvp.StreamFun (1,1) {mustBeA(nvp.StreamFun,'function_handle')}
235240
nvp.NumCompletions (1,1) {mustBeNumeric,mustBePositive, mustBeInteger} = 1
236241
nvp.MaxNumTokens (1,1) {mustBeNumeric,mustBePositive} = inf
237-
nvp.ToolChoice {mustBeValidFunctionCall(this, nvp.ToolChoice)} = []
242+
nvp.ToolChoice (1,:) {mustBeTextScalar} = "auto"
243+
nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")}
238244
nvp.Seed {mustBeIntegerOrEmpty(nvp.Seed)} = []
239245
end
240246

247+
if isfield(nvp, 'Tools')
248+
[functionsStruct, functionNames] = functionAsStruct(nvp.Tools);
249+
else
250+
functionsStruct = this.FunctionsStruct;
251+
functionNames = this.FunctionNames;
252+
end
253+
254+
mustBeValidFunctionCall(this, nvp.ToolChoice, functionNames);
255+
toolChoice = convertToolChoice(this, nvp.ToolChoice, functionNames);
256+
241257
messages = convertCharsToStrings(messages);
242258
if isstring(messages) && isscalar(messages)
243259
messagesStruct = {struct("role", "user", "content", messages)};
@@ -248,9 +264,7 @@
248264
if ~isempty(this.SystemPrompt)
249265
messagesStruct = horzcat(this.SystemPrompt, messagesStruct);
250266
end
251-
252-
toolChoice = convertToolChoice(this, nvp.ToolChoice);
253-
267+
254268
if isfield(nvp,"StreamFun")
255269
streamFun = nvp.StreamFun;
256270
else
@@ -259,7 +273,7 @@
259273

260274
try
261275
[text, message, response] = llms.internal.callAzureChatAPI(this.Endpoint, ...
262-
this.DeploymentID, messagesStruct, this.FunctionsStruct, ...
276+
this.DeploymentID, messagesStruct, functionsStruct, ...
263277
ToolChoice=toolChoice, APIVersion = this.APIVersion, Temperature=nvp.Temperature, ...
264278
TopP=nvp.TopP, NumCompletions=nvp.NumCompletions,...
265279
StopSequences=nvp.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...

doc/functions/generate.md

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ The supported name\-value arguments depend on the chat completion API.
9999
| `PresencePenalty` | Supported | Supported | |
100100
| `FrequencyPenalty` | Supported | Supported | |
101101
| `NumCompletions` | Supported | Supported | |
102+
| `Tools` | Supported | Supported | Supported |
102103
| `ToolChoice` | Supported | Supported | |
103104
| `MinP` | | | Supported |
104105
| `TopK` | | | Supported |
@@ -259,11 +260,17 @@ This option is only supported for these chat completion APIs:
259260

260261
- [`openAIChat`](openAIChat.md) objects
261262
- [`azureChat`](azureChat.md) objects
263+
264+
### `Tools` — Functions to call during output generation
265+
266+
`model.Tools` (default) | `openAIFunction` object | array of `openAIFunction` objects
267+
268+
Information about tools available for function calling, specified as [`openAIFunction`](openAIFunction.md) objects.
269+
262270
### `ToolChoice` — Tool choice
263271

264272
`"auto"` (default) | `"none"` | `"required"` | string scalar
265273

266-
267274
Tools that a model is allowed to call during output generation, specified as `"auto"`, `"none"`, `"required"`, or as a tool name. For more information on OpenAI function calling, see [`openAIFunction`](openAIFunction.md).
268275

269276

@@ -274,13 +281,13 @@ If the tool choice is set to `"none"`, then no tools are called during output ge
274281

275282
If the tool choice is set to `"required"`, then one or more tools are called during output generation.
276283

277-
You can also require that the model uses a specific tool by setting `ToolChoice` to the name of that tool. The name must be part of `model.FunctionNames`.
284+
You can also require that the model uses a specific tool by setting `ToolChoice` to the name of that tool. The name must refer to a tool that is available to the model. To give a model access to specific tools, either specify the `Tools` name-value argument during construction of the `model` object, or specify the `Tools` name-value argument of the `generate` function.
278285

279286
This option is only supported for these chat completion APIs:
280287

281288
- [`openAIChat`](openAIChat.md) objects
282289
- [`azureChat`](azureChat.md) objects
283-
290+
284291
### `MinP` — Minimum probability ratio
285292

286293
`model.MinP` (default) | numeric scalar between `0` and `1`

functionSignatures.json

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@
4141
{"name":"PresencePenalty","kind":"namevalue","type":["numeric","scalar","<=2",">=-2"]},
4242
{"name":"FrequencyPenalty","kind":"namevalue","type":["numeric","scalar","<=2",">=-2"]},
4343
{"name":"TimeOut","kind":"namevalue","type":["numeric","scalar","real","positive"]},
44-
{"name":"StreamFun","kind":"namevalue","type":"function_handle"}
44+
{"name":"StreamFun","kind":"namevalue","type":"function_handle"},
45+
{"name":"Tools","kind":"namevalue","type":"openAIFunction"}
4546
],
4647
"outputs":
4748
[
@@ -92,7 +93,8 @@
9293
{"name":"PresencePenalty","kind":"namevalue","type":["numeric","scalar","<=2",">=-2"]},
9394
{"name":"FrequencyPenalty","kind":"namevalue","type":["numeric","scalar","<=2",">=-2"]},
9495
{"name":"TimeOut","kind":"namevalue","type":["numeric","scalar","real","positive"]},
95-
{"name":"StreamFun","kind":"namevalue","type":"function_handle"}
96+
{"name":"StreamFun","kind":"namevalue","type":"function_handle"},
97+
{"name":"Tools","kind":"namevalue","type":"openAIFunction"}
9698
],
9799
"outputs":
98100
[
@@ -142,7 +144,8 @@
142144
{"name":"TailFreeSamplingZ","kind":"namevalue","type":["numeric","scalar","real"]},
143145
{"name":"TimeOut","kind":"namevalue","type":["numeric","scalar","real","positive"]},
144146
{"name":"StreamFun","kind":"namevalue","type":"function_handle"},
145-
{"name":"Endpoint","kind":"namevalue","type":["string","scalar"]}
147+
{"name":"Endpoint","kind":"namevalue","type":["string","scalar"]},
148+
{"name":"Tools","kind":"namevalue","type":"openAIFunction"}
146149
],
147150
"outputs":
148151
[

ollamaChat.m

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,10 @@
147147
% MaxNumTokens - Maximum number of tokens in the generated response.
148148
% Default value is inf.
149149
%
150+
% Tools - Array of openAIFunction objects representing
151+
% custom functions to be used during chat completions.
152+
% The default value is CHAT.Tools.
153+
%
150154
% Seed - An integer value to use to obtain
151155
% reproducible responses
152156
%
@@ -216,6 +220,13 @@
216220
nvp.Endpoint (1,1) string = this.Endpoint
217221
nvp.MaxNumTokens (1,1) {mustBeNumeric,mustBePositive} = inf
218222
nvp.Seed {mustBeIntegerOrEmpty(nvp.Seed)} = []
223+
nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")}
224+
end
225+
226+
if ~isfield(nvp, 'Tools')
227+
functionsStruct = this.FunctionsStruct;
228+
else
229+
functionsStruct = functionAsStruct(nvp.Tools);
219230
end
220231

221232
messages = convertCharsToStrings(messages);
@@ -237,7 +248,7 @@
237248

238249
try
239250
[text, message, response] = llms.internal.callOllamaChatAPI(...
240-
nvp.ModelName, messagesStruct, this.FunctionsStruct, ...
251+
nvp.ModelName, messagesStruct, functionsStruct, ...
241252
Temperature=nvp.Temperature, ...
242253
TopP=nvp.TopP, MinP=nvp.MinP, TopK=nvp.TopK,...
243254
TailFreeSamplingZ=nvp.TailFreeSamplingZ,...

openAIChat.m

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,13 @@
157157
% MaxNumTokens - Maximum number of tokens in the generated response.
158158
% Default value is inf.
159159
%
160-
% ToolChoice - Function to execute. 'none', 'auto', 'required',
160+
% ToolChoice - Function to execute. "none", "auto", "required",
161161
% or specify the function to call.
162+
% The default value is "auto".
163+
%
164+
% Tools - Array of openAIFunction objects representing
165+
% custom functions to be used during chat completions.
166+
% The default value is CHAT.Tools.
162167
%
163168
% Seed - An integer value to use to obtain
164169
% reproducible responses
@@ -222,11 +227,20 @@
222227
nvp.StreamFun (1,1) {mustBeA(nvp.StreamFun,'function_handle')}
223228
nvp.NumCompletions (1,1) {mustBeNumeric,mustBePositive, mustBeInteger} = 1
224229
nvp.MaxNumTokens (1,1) {mustBeNumeric,mustBePositive} = inf
225-
nvp.ToolChoice {mustBeValidFunctionCall(this, nvp.ToolChoice)} = []
230+
nvp.ToolChoice (1,:) {mustBeTextScalar} = "auto"
231+
nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")}
226232
nvp.Seed {mustBeIntegerOrEmpty(nvp.Seed)} = []
227233
end
228234

229-
toolChoice = convertToolChoice(this, nvp.ToolChoice);
235+
if ~isfield(nvp, 'Tools')
236+
functionsStruct = this.FunctionsStruct;
237+
functionNames = this.FunctionNames;
238+
else
239+
[functionsStruct, functionNames] = functionAsStruct(nvp.Tools);
240+
end
241+
242+
mustBeValidFunctionCall(this, nvp.ToolChoice, functionNames);
243+
toolChoice = convertToolChoice(this, nvp.ToolChoice, functionNames);
230244

231245
messages = convertCharsToStrings(messages);
232246
if isstring(messages) && isscalar(messages)
@@ -246,7 +260,7 @@
246260
end
247261

248262
try % just for nicer errors, reducing the stack depth shown
249-
[text, message, response] = llms.internal.callOpenAIChatAPI(messagesStruct, this.FunctionsStruct,...
263+
[text, message, response] = llms.internal.callOpenAIChatAPI(messagesStruct, functionsStruct,...
250264
ModelName=nvp.ModelName, ToolChoice=toolChoice, Temperature=nvp.Temperature, ...
251265
TopP=nvp.TopP, NumCompletions=nvp.NumCompletions,...
252266
StopSequences=nvp.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...

tests/hmockSendRequest.m

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
classdef hmockSendRequest < matlab.mock.TestCase
2+
%Helper method(s) for working with the mock framework.
3+
4+
% Copyright 2025 The MathWorks, Inc.
5+
6+
methods
7+
function [sendRequestMock, sendRequestBehaviour] = setUpSendRequestMock(testCase)
8+
[sendRequestMock,sendRequestBehaviour] = ...
9+
createMock(testCase, AddedMethods="sendRequest");
10+
testCase.assignOutputsWhen( ...
11+
withAnyInputs(sendRequestBehaviour.sendRequest),...
12+
testCase.responseMessage("Hello"),"This output is unused with Stream=false");
13+
end
14+
end
15+
end

tests/htoolCalls.m

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1-
classdef (Abstract) htoolCalls < matlab.mock.TestCase
1+
classdef (Abstract) htoolCalls < hmockSendRequest
22
% Tests for backends with tool calls
33

44
% Copyright 2023-2025 The MathWorks, Inc.
55
properties(Abstract)
66
constructor
77
defaultModel
88
end
9+
10+
properties(TestParameter)
11+
ToolsForChatConstruction = {openAIFunction.empty, openAIFunction("SomeToolNameForChat")}
12+
end
913

1014
methods (Test) % calling the server, end-to-end tests
1115
function generateWithToolsAndStreamFunc(testCase)
@@ -84,4 +88,78 @@ function generateWithToolsAndStreamFunc(testCase)
8488
testCase.verifyThat(data,HasField("explanation"));
8589
end
8690
end
91+
92+
methods (Test) % generate Tools tests without calling the server
93+
94+
function generateWithEmptyToolsSwitchesOffChatTools(testCase)
95+
import matlab.unittest.constraints.HasField
96+
97+
sendRequestMock = testCase.setUpSendRequestMock;
98+
99+
chat = testCase.constructor("You are a helpful assistant", Tools=openAIFunction("SomeToolNameForChat"));
100+
chat.sendRequestFcn = @(varargin) sendRequestMock.sendRequest(varargin{:});
101+
102+
emptyTools = openAIFunction.empty;
103+
testCase.verifyWarningFree(@() generate(chat,"Hi",Tools=emptyTools));
104+
105+
calls = testCase.getMockHistory(sendRequestMock);
106+
107+
testCase.verifySize(calls,[1,1]);
108+
sentHistory = calls.Inputs{2};
109+
testCase.verifyFalse(isfield(sentHistory,'tools'))
110+
end
111+
112+
function generateSendsOverriddenTools(testCase, ToolsForChatConstruction)
113+
import matlab.unittest.constraints.HasField
114+
115+
sendRequestMock = testCase.setUpSendRequestMock;
116+
117+
chat = testCase.constructor("You are a helpful assistant", Tools=ToolsForChatConstruction);
118+
chat.sendRequestFcn = @(varargin) sendRequestMock.sendRequest(varargin{:});
119+
120+
tools = openAIFunction("ToolForGenerate");
121+
response = testCase.verifyWarningFree(@() generate(chat,"Hi",Tools=tools));
122+
123+
calls = testCase.getMockHistory(sendRequestMock);
124+
125+
testCase.verifySize(calls,[1,1]);
126+
sentHistory = calls.Inputs{2};
127+
testCase.verifyThat(sentHistory,HasField("tools"));
128+
expectedTool = struct('type', 'function', ...
129+
'function', struct('name', "ToolForGenerate", ...
130+
'parameters', struct('type', "object", ...
131+
'properties', struct())));
132+
testCase.verifySize(sentHistory.tools, [1,1]);
133+
testCase.verifyEqual(sentHistory.tools{1}, expectedTool);
134+
testCase.verifyEqual(response,"Hello");
135+
end
136+
137+
function generateWithNoToolsDoesNotOverrideTopLevelTools(testCase)
138+
import matlab.unittest.constraints.HasField
139+
140+
sendRequestMock = testCase.setUpSendRequestMock;
141+
142+
tools = openAIFunction("funNameAtTopLevel");
143+
chat = testCase.constructor("You are a helpful assistant", Tools=tools);
144+
chat.sendRequestFcn = @(varargin) sendRequestMock.sendRequest(varargin{:});
145+
146+
response = testCase.verifyWarningFree(@() generate(chat,"Hi"));
147+
148+
calls = testCase.getMockHistory(sendRequestMock);
149+
150+
testCase.verifySize(calls,[1,1]);
151+
sentHistory = calls.Inputs{2};
152+
testCase.verifyThat(sentHistory,HasField("tools"));
153+
expectedTool = struct( ...
154+
'type', 'function', ...
155+
'function', struct( ...
156+
'name', "funNameAtTopLevel", ...
157+
'parameters', struct( ...
158+
'type', "object", ...
159+
'properties', struct())));
160+
testCase.verifySize(sentHistory.tools, [1,1]);
161+
testCase.verifyEqual(sentHistory.tools{1}, expectedTool);
162+
testCase.verifyEqual(response,"Hello");
163+
end
164+
end
87165
end

0 commit comments

Comments
 (0)