Skip to content

Commit 870cc20

Browse files
authored
fix: preserve Gemini thought_signature in multi-turn tool calls (#718)
1 parent 4ea9550 commit 870cc20

File tree

3 files changed

+221
-1
lines changed

3 files changed

+221
-1
lines changed

.changeset/fine-symbols-jam.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'@openai/agents-extensions': patch
3+
---
4+
5+
fix: preserve Gemini thought_signature in multi-turn tool calls

packages/agents-extensions/src/aiSdk.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,9 @@ export class AiSdkModel implements Model {
711711
name: toolCall.toolName,
712712
arguments: toolCallArguments,
713713
status: 'completed',
714-
providerData: hasToolCalls ? result.providerMetadata : undefined,
714+
providerData:
715+
toolCall.providerMetadata ??
716+
(hasToolCalls ? result.providerMetadata : undefined),
715717
});
716718
}
717719

@@ -916,6 +918,9 @@ export class AiSdkModel implements Model {
916918
name: (part as any).toolName,
917919
arguments: (part as any).input ?? '',
918920
status: 'completed',
921+
...((part as any).providerMetadata
922+
? { providerData: (part as any).providerMetadata }
923+
: {}),
919924
};
920925
}
921926
break;

packages/agents-extensions/test/aiSdk.test.ts

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,106 @@ describe('AiSdkModel.getResponse', () => {
745745
]);
746746
});
747747

748+
test('preserves per-tool-call providerMetadata (e.g., Gemini thoughtSignature)', async () => {
749+
const toolCallProviderMetadata = {
750+
google: { thoughtSignature: 'sig123' },
751+
};
752+
const resultProviderMetadata = {
753+
google: { usageMetadata: { totalTokenCount: 100 } },
754+
};
755+
756+
const model = new AiSdkModel(
757+
stubModel({
758+
async doGenerate() {
759+
return {
760+
content: [
761+
{
762+
type: 'tool-call',
763+
toolCallId: 'c1',
764+
toolName: 'get_weather',
765+
input: { location: 'Tokyo' },
766+
providerMetadata: toolCallProviderMetadata,
767+
},
768+
],
769+
usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30 },
770+
providerMetadata: resultProviderMetadata,
771+
response: { id: 'resp-1' },
772+
finishReason: 'tool-calls',
773+
warnings: [],
774+
} as any;
775+
},
776+
}),
777+
);
778+
779+
const res = await withTrace('t', () =>
780+
model.getResponse({
781+
input: 'What is the weather in Tokyo?',
782+
tools: [
783+
{
784+
type: 'function',
785+
name: 'get_weather',
786+
description: 'Get weather',
787+
parameters: { type: 'object', properties: {} },
788+
},
789+
],
790+
handoffs: [],
791+
modelSettings: {},
792+
outputType: 'text',
793+
tracing: false,
794+
} as any),
795+
);
796+
797+
expect(res.output).toHaveLength(1);
798+
expect(res.output[0]).toMatchObject({
799+
type: 'function_call',
800+
callId: 'c1',
801+
name: 'get_weather',
802+
providerData: toolCallProviderMetadata,
803+
});
804+
// Ensure we get per-tool-call metadata, not result-level metadata
805+
expect(res.output[0].providerData).not.toEqual(resultProviderMetadata);
806+
});
807+
808+
test('falls back to result.providerMetadata when toolCall.providerMetadata is undefined', async () => {
809+
const resultProviderMetadata = { fallback: true };
810+
811+
const model = new AiSdkModel(
812+
stubModel({
813+
async doGenerate() {
814+
return {
815+
content: [
816+
{
817+
type: 'tool-call',
818+
toolCallId: 'c1',
819+
toolName: 'foo',
820+
input: {},
821+
// No providerMetadata on tool call
822+
},
823+
],
824+
usage: { inputTokens: 1, outputTokens: 2, totalTokens: 3 },
825+
providerMetadata: resultProviderMetadata,
826+
response: { id: 'id' },
827+
finishReason: 'tool-calls',
828+
warnings: [],
829+
} as any;
830+
},
831+
}),
832+
);
833+
834+
const res = await withTrace('t', () =>
835+
model.getResponse({
836+
input: 'hi',
837+
tools: [],
838+
handoffs: [],
839+
modelSettings: {},
840+
outputType: 'text',
841+
tracing: false,
842+
} as any),
843+
);
844+
845+
expect(res.output[0].providerData).toEqual(resultProviderMetadata);
846+
});
847+
748848
test('propagates errors', async () => {
749849
const model = new AiSdkModel(
750850
stubModel({
@@ -905,6 +1005,116 @@ describe('AiSdkModel.getStreamedResponse', () => {
9051005
]);
9061006
});
9071007

1008+
test('preserves per-tool-call providerMetadata in streaming mode (e.g., Gemini thoughtSignature)', async () => {
1009+
const toolCallProviderMetadata = {
1010+
google: { thoughtSignature: 'stream-sig-456' },
1011+
};
1012+
1013+
const parts = [
1014+
{
1015+
type: 'tool-call',
1016+
toolCallId: 'c1',
1017+
toolName: 'get_weather',
1018+
input: '{"location":"Tokyo"}',
1019+
providerMetadata: toolCallProviderMetadata,
1020+
},
1021+
{ type: 'response-metadata', id: 'resp-stream-1' },
1022+
{
1023+
type: 'finish',
1024+
finishReason: 'tool-calls',
1025+
usage: { inputTokens: 10, outputTokens: 20 },
1026+
},
1027+
];
1028+
1029+
const model = new AiSdkModel(
1030+
stubModel({
1031+
async doStream() {
1032+
return {
1033+
stream: partsStream(parts),
1034+
} as any;
1035+
},
1036+
}),
1037+
);
1038+
1039+
const events: any[] = [];
1040+
for await (const ev of model.getStreamedResponse({
1041+
input: 'What is the weather?',
1042+
tools: [
1043+
{
1044+
type: 'function',
1045+
name: 'get_weather',
1046+
description: 'Get weather',
1047+
parameters: { type: 'object', properties: {} },
1048+
},
1049+
],
1050+
handoffs: [],
1051+
modelSettings: {},
1052+
outputType: 'text',
1053+
tracing: false,
1054+
} as any)) {
1055+
events.push(ev);
1056+
}
1057+
1058+
const final = events.at(-1);
1059+
expect(final.type).toBe('response_done');
1060+
expect(final.response.output).toHaveLength(1);
1061+
expect(final.response.output[0]).toMatchObject({
1062+
type: 'function_call',
1063+
callId: 'c1',
1064+
name: 'get_weather',
1065+
providerData: toolCallProviderMetadata,
1066+
});
1067+
});
1068+
1069+
test('omits providerData in streaming mode when providerMetadata is not present', async () => {
1070+
const parts = [
1071+
{
1072+
type: 'tool-call',
1073+
toolCallId: 'c1',
1074+
toolName: 'foo',
1075+
input: '{}',
1076+
// No providerMetadata
1077+
},
1078+
{
1079+
type: 'finish',
1080+
finishReason: 'tool-calls',
1081+
usage: { inputTokens: 1, outputTokens: 2 },
1082+
},
1083+
];
1084+
1085+
const model = new AiSdkModel(
1086+
stubModel({
1087+
async doStream() {
1088+
return {
1089+
stream: partsStream(parts),
1090+
} as any;
1091+
},
1092+
}),
1093+
);
1094+
1095+
const events: any[] = [];
1096+
for await (const ev of model.getStreamedResponse({
1097+
input: 'hi',
1098+
tools: [],
1099+
handoffs: [],
1100+
modelSettings: {},
1101+
outputType: 'text',
1102+
tracing: false,
1103+
} as any)) {
1104+
events.push(ev);
1105+
}
1106+
1107+
const final = events.at(-1);
1108+
expect(final.type).toBe('response_done');
1109+
expect(final.response.output[0]).toMatchObject({
1110+
type: 'function_call',
1111+
callId: 'c1',
1112+
name: 'foo',
1113+
});
1114+
// providerData should not be present when providerMetadata was not provided
1115+
expect(final.response.output[0].providerData).toBeUndefined();
1116+
});
1117+
9081118
test('propagates stream errors', async () => {
9091119
const err = new Error('bad');
9101120
const parts = [{ type: 'error', error: err }];

0 commit comments

Comments (0)