diff --git a/.changeset/long-keys-watch.md b/.changeset/long-keys-watch.md
new file mode 100644
index 00000000000..fb13ed74987
--- /dev/null
+++ b/.changeset/long-keys-watch.md
@@ -0,0 +1,6 @@
+---
+'firebase': minor
+'@firebase/ai': minor
+---
+
+Add support for `AbortSignal`, allowing requests to be aborted.
diff --git a/common/api-review/ai.api.md b/common/api-review/ai.api.md
index 2bf194fbaf2..c5a180e0824 100644
--- a/common/api-review/ai.api.md
+++ b/common/api-review/ai.api.md
@@ -150,8 +150,8 @@ export class ChatSession {
     params?: StartChatParams | undefined;
     // (undocumented)
     requestOptions?: RequestOptions | undefined;
-    sendMessage(request: string | Array): Promise;
-    sendMessageStream(request: string | Array): Promise;
+    sendMessage(request: string | Array, singleRequestOptions?: SingleRequestOptions): Promise;
+    sendMessageStream(request: string | Array, singleRequestOptions?: SingleRequestOptions): Promise;
 }

 // @beta
@@ -539,9 +539,9 @@ export interface GenerativeContentBlob {
 export class GenerativeModel extends AIModel {
     // Warning: (ae-incompatible-release-tags) The symbol "__constructor" is marked as @public, but its signature references "ChromeAdapter" which is marked as @beta
     constructor(ai: AI, modelParams: ModelParams, requestOptions?: RequestOptions, chromeAdapter?: ChromeAdapter | undefined);
-    countTokens(request: CountTokensRequest | string | Array): Promise;
-    generateContent(request: GenerateContentRequest | string | Array): Promise;
-    generateContentStream(request: GenerateContentRequest | string | Array): Promise;
+    countTokens(request: CountTokensRequest | string | Array, singleRequestOptions?: SingleRequestOptions): Promise;
+    generateContent(request: GenerateContentRequest | string | Array, singleRequestOptions?: SingleRequestOptions): Promise;
+    generateContentStream(request: GenerateContentRequest | string | Array, singleRequestOptions?: SingleRequestOptions): Promise;
     // (undocumented)
     generationConfig: GenerationConfig;
     // (undocumented)
@@ -784,9 +784,9 @@ export interface ImagenInlineImage {
 // @public
 export class ImagenModel extends AIModel {
     constructor(ai: AI, modelParams: ImagenModelParams, requestOptions?: RequestOptions | undefined);
-    generateImages(prompt: string): Promise>;
+    generateImages(prompt: string, singleRequestOptions?: SingleRequestOptions): Promise>;
     // @internal
-    generateImagesGCS(prompt: string, gcsURI: string): Promise>;
+    generateImagesGCS(prompt: string, gcsURI: string, singleRequestOptions?: SingleRequestOptions): Promise>;
     generationConfig?: ImagenGenerationConfig;
     // (undocumented)
     requestOptions?: RequestOptions | undefined;
@@ -1294,6 +1294,11 @@ export interface Segment {
     text: string;
 }

+// @public
+export interface SingleRequestOptions extends RequestOptions {
+    signal?: AbortSignal;
+}
+
 // @beta
 export interface SpeechConfig {
     voiceConfig?: VoiceConfig;
@@ -1333,8 +1338,8 @@ export class TemplateGenerativeModel {
     constructor(ai: AI, requestOptions?: RequestOptions);
     // @internal (undocumented)
     _apiSettings: ApiSettings;
-    generateContent(templateId: string, templateVariables: object): Promise;
-    generateContentStream(templateId: string, templateVariables: object): Promise;
+    generateContent(templateId: string, templateVariables: object, singleRequestOptions?: SingleRequestOptions): Promise;
+    generateContentStream(templateId: string, templateVariables: object, singleRequestOptions?: SingleRequestOptions): Promise;
     requestOptions?: RequestOptions;
 }

@@ -1343,7 +1348,7 @@ export class TemplateImagenModel {
     constructor(ai: AI, requestOptions?: RequestOptions);
     // @internal (undocumented)
     _apiSettings: ApiSettings;
-    generateImages(templateId: string, templateVariables: object): Promise>;
+    generateImages(templateId: string, templateVariables: object, singleRequestOptions?: SingleRequestOptions): Promise>;
     requestOptions?: RequestOptions;
 }

diff --git a/docs-devsite/_toc.yaml b/docs-devsite/_toc.yaml
index 92633c553a3..06a976686f9 100644
--- a/docs-devsite/_toc.yaml
+++ b/docs-devsite/_toc.yaml
@@ -190,6 +190,8 @@ toc:
       path: /docs/reference/js/ai.searchentrypoint.md
     - title: Segment
       path: /docs/reference/js/ai.segment.md
+    - title: SingleRequestOptions
+      path: /docs/reference/js/ai.singlerequestoptions.md
     - title: SpeechConfig
       path: /docs/reference/js/ai.speechconfig.md
     - title: StartAudioConversationOptions
       path: /docs/reference/js/ai.startaudioconversationoptions.md
diff --git a/docs-devsite/ai.chatsession.md b/docs-devsite/ai.chatsession.md
index 4e4358898a5..2062f9868f1 100644
--- a/docs-devsite/ai.chatsession.md
+++ b/docs-devsite/ai.chatsession.md
@@ -37,8 +37,8 @@ export declare class ChatSession
 | Method | Modifiers | Description |
 | --- | --- | --- |
 | [getHistory()](./ai.chatsession.md#chatsessiongethistory) | | Gets the chat history so far. Blocked prompts are not added to history. Neither blocked candidates nor the prompts that generated them are added to history. |
-| [sendMessage(request)](./ai.chatsession.md#chatsessionsendmessage) | | Sends a chat message and receives a non-streaming [GenerateContentResult](./ai.generatecontentresult.md#generatecontentresult_interface) |
-| [sendMessageStream(request)](./ai.chatsession.md#chatsessionsendmessagestream) | | Sends a chat message and receives the response as a [GenerateContentStreamResult](./ai.generatecontentstreamresult.md#generatecontentstreamresult_interface) containing an iterable stream and a response promise. |
+| [sendMessage(request, singleRequestOptions)](./ai.chatsession.md#chatsessionsendmessage) | | Sends a chat message and receives a non-streaming [GenerateContentResult](./ai.generatecontentresult.md#generatecontentresult_interface) |
+| [sendMessageStream(request, singleRequestOptions)](./ai.chatsession.md#chatsessionsendmessagestream) | | Sends a chat message and receives the response as a [GenerateContentStreamResult](./ai.generatecontentstreamresult.md#generatecontentstreamresult_interface) containing an iterable stream and a response promise.
| ## ChatSession.(constructor) @@ -104,7 +104,7 @@ Sends a chat message and receives a non-streaming [GenerateContentResult](./ai.g Signature: ```typescript -sendMessage(request: string | Array): Promise; +sendMessage(request: string | Array, singleRequestOptions?: SingleRequestOptions): Promise; ``` #### Parameters @@ -112,6 +112,7 @@ sendMessage(request: string | Array): Promise> | | +| singleRequestOptions | [SingleRequestOptions](./ai.singlerequestoptions.md#singlerequestoptions_interface) | | Returns: @@ -124,7 +125,7 @@ Sends a chat message and receives the response as a [GenerateContentStreamResult Signature: ```typescript -sendMessageStream(request: string | Array): Promise; +sendMessageStream(request: string | Array, singleRequestOptions?: SingleRequestOptions): Promise; ``` #### Parameters @@ -132,6 +133,7 @@ sendMessageStream(request: string | Array): Promise> | | +| singleRequestOptions | [SingleRequestOptions](./ai.singlerequestoptions.md#singlerequestoptions_interface) | | Returns: diff --git a/docs-devsite/ai.generativemodel.md b/docs-devsite/ai.generativemodel.md index 323fcfe9d76..4b1c71b8d2c 100644 --- a/docs-devsite/ai.generativemodel.md +++ b/docs-devsite/ai.generativemodel.md @@ -40,9 +40,9 @@ export declare class GenerativeModel extends AIModel | Method | Modifiers | Description | | --- | --- | --- | -| [countTokens(request)](./ai.generativemodel.md#generativemodelcounttokens) | | Counts the tokens in the provided request. | -| [generateContent(request)](./ai.generativemodel.md#generativemodelgeneratecontent) | | Makes a single non-streaming call to the model and returns an object containing a single [GenerateContentResponse](./ai.generatecontentresponse.md#generatecontentresponse_interface). | -| [generateContentStream(request)](./ai.generativemodel.md#generativemodelgeneratecontentstream) | | Makes a single streaming call to the model and returns an object containing an iterable stream that iterates over all chunks in the streaming response as well as a promise that returns the final aggregated response. | +| [countTokens(request, singleRequestOptions)](./ai.generativemodel.md#generativemodelcounttokens) | | Counts the tokens in the provided request. | +| [generateContent(request, singleRequestOptions)](./ai.generativemodel.md#generativemodelgeneratecontent) | | Makes a single non-streaming call to the model and returns an object containing a single [GenerateContentResponse](./ai.generatecontentresponse.md#generatecontentresponse_interface). | +| [generateContentStream(request, singleRequestOptions)](./ai.generativemodel.md#generativemodelgeneratecontentstream) | | Makes a single streaming call to the model and returns an object containing an iterable stream that iterates over all chunks in the streaming response as well as a promise that returns the final aggregated response. | | [startChat(startChatParams)](./ai.generativemodel.md#generativemodelstartchat) | | Gets a new [ChatSession](./ai.chatsession.md#chatsession_class) instance which can be used for multi-turn chats. | ## GenerativeModel.(constructor) @@ -119,7 +119,7 @@ Counts the tokens in the provided request. 
Signature: ```typescript -countTokens(request: CountTokensRequest | string | Array): Promise; +countTokens(request: CountTokensRequest | string | Array, singleRequestOptions?: SingleRequestOptions): Promise; ``` #### Parameters @@ -127,6 +127,7 @@ countTokens(request: CountTokensRequest | string | Array): Promis | Parameter | Type | Description | | --- | --- | --- | | request | [CountTokensRequest](./ai.counttokensrequest.md#counttokensrequest_interface) \| string \| Array<string \| [Part](./ai.md#part)> | | +| singleRequestOptions | [SingleRequestOptions](./ai.singlerequestoptions.md#singlerequestoptions_interface) | | Returns: @@ -139,7 +140,7 @@ Makes a single non-streaming call to the model and returns an object containing Signature: ```typescript -generateContent(request: GenerateContentRequest | string | Array): Promise; +generateContent(request: GenerateContentRequest | string | Array, singleRequestOptions?: SingleRequestOptions): Promise; ``` #### Parameters @@ -147,6 +148,7 @@ generateContent(request: GenerateContentRequest | string | Array) | Parameter | Type | Description | | --- | --- | --- | | request | [GenerateContentRequest](./ai.generatecontentrequest.md#generatecontentrequest_interface) \| string \| Array<string \| [Part](./ai.md#part)> | | +| singleRequestOptions | [SingleRequestOptions](./ai.singlerequestoptions.md#singlerequestoptions_interface) | | Returns: @@ -159,7 +161,7 @@ Makes a single streaming call to the model and returns an object containing an i Signature: ```typescript -generateContentStream(request: GenerateContentRequest | string | Array): Promise; +generateContentStream(request: GenerateContentRequest | string | Array, singleRequestOptions?: SingleRequestOptions): Promise; ``` #### Parameters @@ -167,6 +169,7 @@ generateContentStream(request: GenerateContentRequest | string | Array> | | +| singleRequestOptions | [SingleRequestOptions](./ai.singlerequestoptions.md#singlerequestoptions_interface) | | Returns: diff --git a/docs-devsite/ai.imagenmodel.md b/docs-devsite/ai.imagenmodel.md index 68375972cbb..6559723878a 100644 --- a/docs-devsite/ai.imagenmodel.md +++ b/docs-devsite/ai.imagenmodel.md @@ -39,7 +39,7 @@ export declare class ImagenModel extends AIModel | Method | Modifiers | Description | | --- | --- | --- | -| [generateImages(prompt)](./ai.imagenmodel.md#imagenmodelgenerateimages) | | Generates images using the Imagen model and returns them as base64-encoded strings. | +| [generateImages(prompt, singleRequestOptions)](./ai.imagenmodel.md#imagenmodelgenerateimages) | | Generates images using the Imagen model and returns them as base64-encoded strings. | ## ImagenModel.(constructor) @@ -100,7 +100,7 @@ If the prompt was not blocked, but one or more of the generated images were filt Signature: ```typescript -generateImages(prompt: string): Promise>; +generateImages(prompt: string, singleRequestOptions?: SingleRequestOptions): Promise>; ``` #### Parameters @@ -108,6 +108,7 @@ generateImages(prompt: string): PromiseReturns: diff --git a/docs-devsite/ai.md b/docs-devsite/ai.md index 53e4057cade..482c49c3cdd 100644 --- a/docs-devsite/ai.md +++ b/docs-devsite/ai.md @@ -133,6 +133,7 @@ The Firebase AI Web SDK. | [SchemaShared](./ai.schemashared.md#schemashared_interface) | Basic [Schema](./ai.schema.md#schema_class) properties shared across several Schema-related types. | | [SearchEntrypoint](./ai.searchentrypoint.md#searchentrypoint_interface) | Google search entry point. 
| | [Segment](./ai.segment.md#segment_interface) | Represents a specific segment within a [Content](./ai.content.md#content_interface) object, often used to pinpoint the exact location of text or data that grounding information refers to. | +| [SingleRequestOptions](./ai.singlerequestoptions.md#singlerequestoptions_interface) | Options that can be provided per-request. Extends the base [RequestOptions](./ai.requestoptions.md#requestoptions_interface) (like timeout and baseUrl) with request-specific controls like cancellation via AbortSignal.Options specified here will override any default [RequestOptions](./ai.requestoptions.md#requestoptions_interface) configured on a model (for example, [GenerativeModel](./ai.generativemodel.md#generativemodel_class)). | | [SpeechConfig](./ai.speechconfig.md#speechconfig_interface) | (Public Preview) Configures speech synthesis. | | [StartAudioConversationOptions](./ai.startaudioconversationoptions.md#startaudioconversationoptions_interface) | (Public Preview) Options for [startAudioConversation()](./ai.md#startaudioconversation_01c8e7f). | | [StartChatParams](./ai.startchatparams.md#startchatparams_interface) | Params for [GenerativeModel.startChat()](./ai.generativemodel.md#generativemodelstartchat). | diff --git a/docs-devsite/ai.singlerequestoptions.md b/docs-devsite/ai.singlerequestoptions.md new file mode 100644 index 00000000000..a55bd3c2f3c --- /dev/null +++ b/docs-devsite/ai.singlerequestoptions.md @@ -0,0 +1,61 @@ +Project: /docs/reference/js/_project.yaml +Book: /docs/reference/_book.yaml +page_type: reference + +{% comment %} +DO NOT EDIT THIS FILE! +This is generated by the JS SDK team, and any local changes will be +overwritten. Changes should be made in the source code at +https://github.com/firebase/firebase-js-sdk +{% endcomment %} + +# SingleRequestOptions interface +Options that can be provided per-request. Extends the base [RequestOptions](./ai.requestoptions.md#requestoptions_interface) (like `timeout` and `baseUrl`) with request-specific controls like cancellation via `AbortSignal`. + +Options specified here will override any default [RequestOptions](./ai.requestoptions.md#requestoptions_interface) configured on a model (for example, [GenerativeModel](./ai.generativemodel.md#generativemodel_class)). + +Signature: + +```typescript +export interface SingleRequestOptions extends RequestOptions +``` +Extends: [RequestOptions](./ai.requestoptions.md#requestoptions_interface) + +## Properties + +| Property | Type | Description | +| --- | --- | --- | +| [signal](./ai.singlerequestoptions.md#singlerequestoptionssignal) | AbortSignal | An AbortSignal instance that allows cancelling ongoing requests (like generateContent or generateImages).If provided, calling abort() on the corresponding AbortController will attempt to cancel the underlying HTTP request. An AbortError will be thrown if cancellation is successful.Note that this will not cancel the request in the backend, so any applicable billing charges will still be applied despite cancellation. | + +## SingleRequestOptions.signal + +An `AbortSignal` instance that allows cancelling ongoing requests (like `generateContent` or `generateImages`). + +If provided, calling `abort()` on the corresponding `AbortController` will attempt to cancel the underlying HTTP request. An `AbortError` will be thrown if cancellation is successful. + +Note that this will not cancel the request in the backend, so any applicable billing charges will still be applied despite cancellation. 
+Signature:
+
+```typescript
+signal?: AbortSignal;
+```
+
+### Example
+
+```javascript
+const controller = new AbortController();
+const model = getGenerativeModel({
+  // ...
+});
+model.generateContent(
+  "Write a story about a magic backpack.",
+  { signal: controller.signal }
+);
+
+// To cancel request:
+controller.abort();
+
+```
+
diff --git a/docs-devsite/ai.templategenerativemodel.md b/docs-devsite/ai.templategenerativemodel.md
index c115af62b1e..a9ed568fa19 100644
--- a/docs-devsite/ai.templategenerativemodel.md
+++ b/docs-devsite/ai.templategenerativemodel.md
@@ -39,8 +39,8 @@ export declare class TemplateGenerativeModel
 | Method | Modifiers | Description |
 | --- | --- | --- |
-| [generateContent(templateId, templateVariables)](./ai.templategenerativemodel.md#templategenerativemodelgeneratecontent) | | (Public Preview) Makes a single non-streaming call to the model and returns an object containing a single [GenerateContentResponse](./ai.generatecontentresponse.md#generatecontentresponse_interface). |
-| [generateContentStream(templateId, templateVariables)](./ai.templategenerativemodel.md#templategenerativemodelgeneratecontentstream) | | (Public Preview) Makes a single streaming call to the model and returns an object containing an iterable stream that iterates over all chunks in the streaming response as well as a promise that returns the final aggregated response. |
+| [generateContent(templateId, templateVariables, singleRequestOptions)](./ai.templategenerativemodel.md#templategenerativemodelgeneratecontent) | | (Public Preview) Makes a single non-streaming call to the model and returns an object containing a single [GenerateContentResponse](./ai.generatecontentresponse.md#generatecontentresponse_interface). |
+| [generateContentStream(templateId, templateVariables, singleRequestOptions)](./ai.templategenerativemodel.md#templategenerativemodelgeneratecontentstream) | | (Public Preview) Makes a single streaming call to the model and returns an object containing an iterable stream that iterates over all chunks in the streaming response as well as a promise that returns the final aggregated response. |

 ## TemplateGenerativeModel.(constructor)
@@ -85,7 +85,7 @@ Makes a single non-streaming call to the model and returns an object containing

 Signature:

 ```typescript
-generateContent(templateId: string, templateVariables: object): Promise;
+generateContent(templateId: string, templateVariables: object, singleRequestOptions?: SingleRequestOptions): Promise;
 ```

 #### Parameters
@@ -94,6 +94,7 @@ generateContent(templateId: string, templateVariables: object): Promise

 Returns:

@@ -109,7 +110,7 @@ Makes a single streaming call to the model and returns an object containing an i

 Signature:

 ```typescript
-generateContentStream(templateId: string, templateVariables: object): Promise;
+generateContentStream(templateId: string, templateVariables: object, singleRequestOptions?: SingleRequestOptions): Promise;
 ```

 #### Parameters
@@ -118,6 +119,7 @@ generateContentStream(templateId: string, templateVariables: object): Promise

 Returns:

diff --git a/docs-devsite/ai.templateimagenmodel.md b/docs-devsite/ai.templateimagenmodel.md
index 2d86071993f..3b33d94f71f 100644
--- a/docs-devsite/ai.templateimagenmodel.md
+++ b/docs-devsite/ai.templateimagenmodel.md
@@ -39,7 +39,7 @@ export declare class TemplateImagenModel
 | Method | Modifiers | Description |
 | --- | --- | --- |
-| [generateImages(templateId, templateVariables)](./ai.templateimagenmodel.md#templateimagenmodelgenerateimages) | | (Public Preview) Makes a single call to the model and returns an object containing a single [ImagenGenerationResponse](./ai.imagengenerationresponse.md#imagengenerationresponse_interface). |
+| [generateImages(templateId, templateVariables, singleRequestOptions)](./ai.templateimagenmodel.md#templateimagenmodelgenerateimages) | | (Public Preview) Makes a single call to the model and returns an object containing a single [ImagenGenerationResponse](./ai.imagengenerationresponse.md#imagengenerationresponse_interface). |

 ## TemplateImagenModel.(constructor)
@@ -84,7 +84,7 @@ Makes a single call to the model and returns an object containing a single [Imag

 Signature:

 ```typescript
-generateImages(templateId: string, templateVariables: object): Promise>;
+generateImages(templateId: string, templateVariables: object, singleRequestOptions?: SingleRequestOptions): Promise>;
 ```

 #### Parameters
@@ -93,6 +93,7 @@ generateImages(templateId: string, templateVariables: object): Promise

 Returns:

diff --git a/packages/ai/integration/constants.ts b/packages/ai/integration/constants.ts
index 99a65f31c54..04061ba2e47 100644
--- a/packages/ai/integration/constants.ts
+++ b/packages/ai/integration/constants.ts
@@ -22,7 +22,8 @@ import {
   BackendType,
   GoogleAIBackend,
   VertexAIBackend,
-  getAI
+  getAI,
+  getGenerativeModel
 } from '../src';
 import { FIREBASE_CONFIG } from './firebase-config';
@@ -54,6 +55,16 @@ const backendNames: Map = new Map([
 const modelNames: readonly string[] = ['gemini-2.0-flash', 'gemini-2.5-flash'];

+// Used for testing non-AI behavior (e.g., network requests). Configured to minimize cost.
+export const cheapestModel = 'gemini-2.0-flash';
+export const defaultAIInstance = getAI(app, { backend: new VertexAIBackend() });
+export const defaultGenerativeModel = getGenerativeModel(defaultAIInstance, {
+  model: cheapestModel,
+  generationConfig: {
+    maxOutputTokens: 10 // Just enough to confirm we actually get something back.
+  }
+});
+
 // The Live API requires a different set of models, and they're different for each backend.
 const liveModelNames: Map = new Map([
   [BackendType.GOOGLE_AI, ['gemini-live-2.5-flash-preview']],
diff --git a/packages/ai/integration/generate-content.test.ts b/packages/ai/integration/generate-content.test.ts
index ffb1ecca698..083a6c18000 100644
--- a/packages/ai/integration/generate-content.test.ts
+++ b/packages/ai/integration/generate-content.test.ts
@@ -15,7 +15,11 @@
  * limitations under the License.
 */

-import { expect } from 'chai';
+import chai, { AssertionError } from 'chai';
+import chaiAsPromised from 'chai-as-promised';
+chai.use(chaiAsPromised);
+const expect = chai.expect;
+
 import {
   BackendType,
   Content,
@@ -29,7 +33,15 @@ import {
   URLRetrievalStatus,
   getGenerativeModel
 } from '../src';
-import { testConfigs, TOKEN_COUNT_DELTA } from './constants';
+import {
+  cheapestModel,
+  defaultAIInstance,
+  defaultGenerativeModel,
+  testConfigs,
+  TOKEN_COUNT_DELTA
+} from './constants';
+import { TIMEOUT_EXPIRED_MESSAGE } from '../src/requests/request';
+import { isNode } from '@firebase/util';

 describe('Generate Content', function () {
   this.timeout(20_000);
@@ -370,4 +382,154 @@ describe('Generate Content', function () {
       });
     });
   });
+
+  describe('Request Options', () => {
+    const defaultAbortReason = isNode()
+      ? 'This operation was aborted'
+      : 'signal is aborted without reason';
+    describe('unary', () => {
+      it('timeout cancels request', async () => {
+        await expect(
+          defaultGenerativeModel.generateContent('hello', { timeout: 100 })
+        ).to.be.rejectedWith(DOMException, TIMEOUT_EXPIRED_MESSAGE);
+      });
+
+      it('long timeout does not cancel request', async () => {
+        const result = await defaultGenerativeModel.generateContent('hello', {
+          timeout: 50_000
+        });
+        expect(result.response.text().length).to.be.greaterThan(0);
+      });
+
+      it('abort signal with no reason causes request to throw AbortError', async () => {
+        const abortController = new AbortController();
+        const responsePromise = defaultGenerativeModel.generateContent(
+          'hello',
+          { signal: abortController.signal }
+        );
+        abortController.abort();
+        await expect(responsePromise)
+          .to.be.rejectedWith(DOMException, defaultAbortReason)
+          .and.eventually.have.property('name', 'AbortError');
+      });
+
+      it('abort signal with string reason causes request to throw reason string', async () => {
+        const abortController = new AbortController();
+        const responsePromise = defaultGenerativeModel.generateContent(
+          'hello',
+          { signal: abortController.signal }
+        );
+        const reason = 'Cancelled';
+        abortController.abort(reason);
+        await expect(responsePromise).to.be.rejectedWith(reason);
+      });
+
+      it('abort signal with error reason causes request to throw reason error', async () => {
+        const abortController = new AbortController();
+        const responsePromise = defaultGenerativeModel.generateContent(
+          'hello',
+          { signal: abortController.signal }
+        );
+        abortController.abort(new Error('Cancelled'));
+        // `fetch()` will reject with the exact object we passed to `abort()`. Since we throw a generic
+        // Error, we cannot differentiate between this error and other generic fetch errors, which
+        // we wrap in an AIError.
+        await expect(responsePromise)
+          .to.be.rejectedWith(Error, 'Cancelled')
+          .and.eventually.have.property('name', 'FirebaseError');
+      });
+    });
+
+    describe('streaming', () => {
+      it('timeout cancels initial request', async () => {
+        await expect(
+          defaultGenerativeModel.generateContentStream('hello', { timeout: 50 })
+        ).to.be.rejectedWith(DOMException, TIMEOUT_EXPIRED_MESSAGE);
+      });
+
+      it('timeout does not cancel request once streaming has begun', async () => {
+        const generativeModel = getGenerativeModel(defaultAIInstance, {
+          model: cheapestModel
+        });
+        // Setting a timeout that will be in the interval between the stream starting and ending.
+        // Since the timeout will expire once the stream has begun, it should have already been
+        // cleared, and so it shouldn't abort the stream.
+ const { stream, response } = + await generativeModel.generateContentStream( + 'tell me a short story with 200 words.', + { timeout: 1_000 } + ); + + // We should be able to get through the entire stream without an error being thrown + // from the async generator. + for await (const chunk of stream) { + expect(chunk.text().length).to.be.greaterThan(0); + } + + expect((await response).text().length).to.be.greaterThan(0); + }); + + it('abort signal without reason should cancel stream with default abort reason', async () => { + const abortController = new AbortController(); + const generativeModel = getGenerativeModel(defaultAIInstance, { + model: cheapestModel + }); + const { stream, response } = + await generativeModel.generateContentStream( + 'tell me a short story with 200 words.', + { signal: abortController.signal } + ); + + // As soon as the initial request resolves and the stream starts, abort the stream. + abortController.abort(); + + try { + for await (const _ of stream) { + expect.fail('Expected stream to throw an error'); + } + expect.fail('Expected stream to throw an error'); + } catch (err) { + if ((err as Error) instanceof AssertionError) { + throw err; + } + expect(err).to.be.instanceof(DOMException); + expect((err as Error).name).to.equal('AbortError'); + expect((err as Error).message).to.equal(defaultAbortReason); + } + + await expect(response) + .to.be.rejectedWith(DOMException, defaultAbortReason) + .and.to.eventually.have.property('name', 'AbortError'); + }); + + it('abort signal with reason string should cancel stream with string abort reason', async () => { + const abortController = new AbortController(); + const generativeModel = getGenerativeModel(defaultAIInstance, { + model: cheapestModel + }); + const { stream, response } = + await generativeModel.generateContentStream( + 'tell me a short story with 200 words.', + { signal: abortController.signal } + ); + + // As soon as the initial request resolves and the stream starts, abort the stream. 
+        abortController.abort('Cancelled');
+
+        try {
+          for await (const _ of stream) {
+            expect.fail('Expected stream to throw an error');
+          }
+          expect.fail('Expected stream to throw an error');
+        } catch (err) {
+          if ((err as Error) instanceof AssertionError) {
+            throw err;
+          }
+          expect(err).to.equal('Cancelled');
+        }
+
+        await expect(response).to.be.rejectedWith('Cancelled');
+      });
+    });
+  });
 });
diff --git a/packages/ai/src/constants.ts b/packages/ai/src/constants.ts
index 0a6f7e91436..0282edb2e13 100644
--- a/packages/ai/src/constants.ts
+++ b/packages/ai/src/constants.ts
@@ -32,7 +32,8 @@ export const PACKAGE_VERSION = version;

 export const LANGUAGE_TAG = 'gl-js';

-export const DEFAULT_FETCH_TIMEOUT_MS = 180 * 1000;
+// TODO: Extend the default timeout to accommodate longer generation requests with pro models.
+export const DEFAULT_FETCH_TIMEOUT_MS = 180 * 1000;

 /**
  * Defines the name of the default in-cloud model to use for hybrid inference.
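For reference (not part of the patch): a minimal sketch of how the new `SingleRequestOptions` argument is used from app code, based on the API surface in this diff. The Firebase config, model name, prompt, and timeout values are illustrative assumptions only.

```typescript
import { initializeApp } from 'firebase/app';
import { getAI, getGenerativeModel, VertexAIBackend } from 'firebase/ai';

const app = initializeApp({ /* your Firebase config */ });
const ai = getAI(app, { backend: new VertexAIBackend() });

// Model-level defaults are still plain RequestOptions (the third argument).
const model = getGenerativeModel(
  ai,
  { model: 'gemini-2.0-flash' },
  { timeout: 180_000 }
);

const controller = new AbortController();

// Per-request options are the new trailing SingleRequestOptions argument:
// `signal` lets the caller abort the underlying fetch, and `timeout` here
// overrides the model-level default for this call only.
const pending = model.generateContent('Write a haiku about patience.', {
  signal: controller.signal,
  timeout: 10_000
});
pending.catch(() => {
  // Rejected with an AbortError after abort() below.
});

// Per the docs above, aborting cancels the HTTP request client-side; the
// backend may still process (and bill) the already-dispatched request.
controller.abort();
```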
diff --git a/packages/ai/src/methods/chat-session.test.ts b/packages/ai/src/methods/chat-session.test.ts index 1273d02876c..1f084d0b1e5 100644 --- a/packages/ai/src/methods/chat-session.test.ts +++ b/packages/ai/src/methods/chat-session.test.ts @@ -25,6 +25,7 @@ import { ChatSession } from './chat-session'; import { ApiSettings } from '../types/internal'; import { VertexAIBackend } from '../backend'; import { fakeChromeAdapter } from '../../test-utils/get-fake-firebase-services'; +import { logger } from '../logger'; use(sinonChai); use(chaiAsPromised); @@ -59,6 +60,68 @@ describe('ChatSession', () => { match.any ); }); + it('singleRequestOptions overrides requestOptions', async () => { + const generateContentStub = stub( + generateContentMethods, + 'generateContent' + ).rejects('generateContent failed'); // not important + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + timeout: 2000 + }; + const chatSession = new ChatSession( + fakeApiSettings, + 'a-model', + undefined, + undefined, + requestOptions + ); + await expect(chatSession.sendMessage('hello', singleRequestOptions)).to.be + .rejected; + expect(generateContentStub).to.be.calledWith( + fakeApiSettings, + 'a-model', + match.any, + match.any, + match({ + timeout: singleRequestOptions.timeout + }) + ); + }); + it('singleRequestOptions is merged with requestOptions', async () => { + const generateContentStub = stub( + generateContentMethods, + 'generateContent' + ).rejects('generateContent failed'); // not important + const abortController = new AbortController(); + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + signal: abortController.signal + }; + const chatSession = new ChatSession( + fakeApiSettings, + 'a-model', + undefined, + undefined, + requestOptions + ); + await expect(chatSession.sendMessage('hello', singleRequestOptions)).to.be + .rejected; + expect(generateContentStub).to.be.calledWith( + fakeApiSettings, + 'a-model', + match.any, + match.any, + match({ + timeout: requestOptions.timeout, + signal: singleRequestOptions.signal + }) + ); + }); it('adds message and response to history', async () => { const fakeContent: Content = { role: 'model', @@ -124,6 +187,7 @@ describe('ChatSession', () => { expect(generateContentStreamStub).to.be.calledWith( fakeApiSettings, 'a-model', + match.any, match.any ); await clock.runAllAsync(); @@ -147,6 +211,7 @@ describe('ChatSession', () => { expect(generateContentStreamStub).to.be.calledWith( fakeApiSettings, 'a-model', + match.any, match.any ); await clock.runAllAsync(); @@ -156,5 +221,161 @@ describe('ChatSession', () => { ); clock.restore(); }); + it('logs error and rejects user promise when response aggregation fails', async () => { + const loggerStub = stub(logger, 'error'); + const error = new Error('Aggregation failed'); + + // Simulate stream returning, but the response promise failing (e.g. 
parsing error) + stub(generateContentMethods, 'generateContentStream').resolves({ + stream: (async function* () {})(), + response: Promise.reject(error) + } as unknown as GenerateContentStreamResult); + + const chatSession = new ChatSession(fakeApiSettings, 'a-model'); + const initialHistoryLength = (await chatSession.getHistory()).length; + + // Immediate call resolves with the stream object + const result = await chatSession.sendMessageStream('hello'); + + // User's response promise should reject + await expect(result.response).to.be.rejectedWith(error); + + // Wait for the internal _sendPromise chain to settle + await new Promise(resolve => setTimeout(resolve, 0)); + + expect(loggerStub).to.have.been.calledWith(error); + + // History should NOT have been updated (no response appended) + const finalHistory = await chatSession.getHistory(); + expect(finalHistory.length).to.equal(initialHistoryLength); + }); + it('logs error but resolves user promise when history appending logic fails', async () => { + const loggerStub = stub(logger, 'error'); + + // Simulate a response that is technically valid enough to resolve aggregation, + // but malformed in a way that causes the history update logic to throw. + // Passing `null` as a candidate causes `{ ...response.candidates[0].content }` to throw. + const malformedResponse = { + candidates: [null] + }; + + stub(generateContentMethods, 'generateContentStream').resolves({ + stream: (async function* () {})(), + response: Promise.resolve(malformedResponse) + } as unknown as GenerateContentStreamResult); + + const chatSession = new ChatSession(fakeApiSettings, 'a-model'); + const initialHistoryLength = (await chatSession.getHistory()).length; + + const result = await chatSession.sendMessageStream('hello'); + + // The user's response promise SHOULD resolve, because aggregation succeeded. + // The error is purely internal side-effect (history update). + await expect(result.response).to.eventually.equal(malformedResponse); + + // Wait for internal chain + await new Promise(resolve => setTimeout(resolve, 0)); + + expect(loggerStub).to.have.been.called; + const errorArg = loggerStub.firstCall.args[0]; + expect(errorArg).to.be.instanceOf(TypeError); + + // The user message WAS added before the crash, but the response wasn't. + const finalHistory = await chatSession.getHistory(); + expect(finalHistory.length).to.equal(initialHistoryLength + 1); + expect(finalHistory[finalHistory.length - 1].role).to.equal('user'); + }); + it('error from stream promise should not be logged', async () => { + const consoleStub = stub(console, 'error'); + stub(generateContentMethods, 'generateContentStream').rejects('foo'); + const chatSession = new ChatSession( + fakeApiSettings, + 'a-model', + fakeChromeAdapter + ); + try { + // This will throw since generateContentStream will reject immediately. 
+ await chatSession.sendMessageStream('hello'); + } catch (e) { + expect((e as unknown as any).name).to.equal('foo'); + } + + expect(consoleStub).to.not.have.been.called; + }); + it('error from final response promise should not be logged', async () => { + const consoleStub = stub(console, 'error'); + stub(generateContentMethods, 'generateContentStream').resolves({ + response: new Promise((_, reject) => reject(new Error())) + } as unknown as GenerateContentStreamResult); + const chatSession = new ChatSession( + fakeApiSettings, + 'a-model', + fakeChromeAdapter + ); + await chatSession.sendMessageStream('hello'); + expect(consoleStub).to.not.have.been.called; + }); + it('singleRequestOptions overrides requestOptions', async () => { + const generateContentStreamStub = stub( + generateContentMethods, + 'generateContentStream' + ).rejects('generateContentStream failed'); // not important + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + timeout: 2000 + }; + const chatSession = new ChatSession( + fakeApiSettings, + 'a-model', + undefined, + undefined, + requestOptions + ); + await expect(chatSession.sendMessageStream('hello', singleRequestOptions)) + .to.be.rejected; + expect(generateContentStreamStub).to.be.calledWith( + fakeApiSettings, + 'a-model', + match.any, + match.any, + match({ + timeout: singleRequestOptions.timeout + }) + ); + }); + it('singleRequestOptions is merged with requestOptions', async () => { + const generateContentStreamStub = stub( + generateContentMethods, + 'generateContentStream' + ).rejects('generateContentStream failed'); // not important + const abortController = new AbortController(); + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + signal: abortController.signal + }; + const chatSession = new ChatSession( + fakeApiSettings, + 'a-model', + undefined, + undefined, + requestOptions + ); + await expect(chatSession.sendMessageStream('hello', singleRequestOptions)) + .to.be.rejected; + expect(generateContentStreamStub).to.be.calledWith( + fakeApiSettings, + 'a-model', + match.any, + match.any, + match({ + timeout: requestOptions.timeout, + signal: singleRequestOptions.signal + }) + ); + }); }); }); diff --git a/packages/ai/src/methods/chat-session.ts b/packages/ai/src/methods/chat-session.ts index dac16430b7a..d70d755a32b 100644 --- a/packages/ai/src/methods/chat-session.ts +++ b/packages/ai/src/methods/chat-session.ts @@ -22,6 +22,7 @@ import { GenerateContentStreamResult, Part, RequestOptions, + SingleRequestOptions, StartChatParams } from '../types'; import { formatNewContent } from '../requests/request-helpers'; @@ -33,7 +34,11 @@ import { logger } from '../logger'; import { ChromeAdapter } from '../types/chrome-adapter'; /** - * Do not log a message for this error. + * Used to break the internal promise chain when an error is already handled + * by the user, preventing duplicate console logs. + * + * TODO: Refactor to use `Promise.allSettled` to decouple the internal + * sequencing chain from user error handling. */ const SILENT_ERROR = 'SILENT_ERROR'; @@ -46,6 +51,11 @@ const SILENT_ERROR = 'SILENT_ERROR'; export class ChatSession { private _apiSettings: ApiSettings; private _history: Content[] = []; + + /** + * Ensures sequential execution of chat messages to maintain history order. + * Each call waits for the previous one to settle before proceeding. 
+ */ private _sendPromise: Promise = Promise.resolve(); constructor( @@ -77,7 +87,8 @@ export class ChatSession { * {@link GenerateContentResult} */ async sendMessage( - request: string | Array + request: string | Array, + singleRequestOptions?: SingleRequestOptions ): Promise { await this._sendPromise; const newContent = formatNewContent(request); @@ -90,7 +101,7 @@ export class ChatSession { contents: [...this._history, newContent] }; let finalResult = {} as GenerateContentResult; - // Add onto the chain. + this._sendPromise = this._sendPromise .then(() => generateContent( @@ -98,10 +109,16 @@ export class ChatSession { this.model, generateContentRequest, this.chromeAdapter, - this.requestOptions + { + ...this.requestOptions, + ...singleRequestOptions + } ) ) .then(result => { + // TODO: Make this update atomic. If creating `responseContent` throws, + // history will contain the user message but not the response, causing + // validation errors on the next request. if ( result.response.candidates && result.response.candidates.length > 0 @@ -109,7 +126,6 @@ export class ChatSession { this._history.push(newContent); const responseContent: Content = { parts: result.response.candidates?.[0].content.parts || [], - // Response seems to come back without a role set. role: result.response.candidates?.[0].content.role || 'model' }; this._history.push(responseContent); @@ -133,7 +149,8 @@ export class ChatSession { * and a response promise. */ async sendMessageStream( - request: string | Array + request: string | Array, + singleRequestOptions?: SingleRequestOptions ): Promise { await this._sendPromise; const newContent = formatNewContent(request); @@ -150,23 +167,30 @@ export class ChatSession { this.model, generateContentRequest, this.chromeAdapter, - this.requestOptions + { + ...this.requestOptions, + ...singleRequestOptions + } ); - // Add onto the chain. + // We hook into the chain to update history, but we don't block the + // return of `streamPromise` to the user. this._sendPromise = this._sendPromise .then(() => streamPromise) - // This must be handled to avoid unhandled rejection, but jump - // to the final catch block with a label to not log this error. .catch(_ignored => { + // If the initial fetch fails, the user's `streamPromise` rejects. + // We swallow the error here to prevent double logging in the final catch. throw new Error(SILENT_ERROR); }) .then(streamResult => streamResult.response) .then(response => { + // This runs after the stream completes. Runtime errors here cannot be + // caught by the user because their promise has likely already resolved. + // TODO: Move response validation logic upstream to `stream-reader` so + // errors propagate to the user's `result.response` promise. if (response.candidates && response.candidates.length > 0) { this._history.push(newContent); const responseContent = { ...response.candidates[0].content }; - // Response seems to come back without a role set. if (!responseContent.role) { responseContent.role = 'model'; } @@ -181,12 +205,8 @@ export class ChatSession { } }) .catch(e => { - // Errors in streamPromise are already catchable by the user as - // streamPromise is returned. - // Avoid duplicating the error message in logs. - if (e.message !== SILENT_ERROR) { - // Users do not have access to _sendPromise to catch errors - // downstream from streamPromise, so they should not throw. + // Filter out errors already handled by the user or initiated by them. 
+ if (e.message !== SILENT_ERROR && e.name !== 'AbortError') { logger.error(e); } }); diff --git a/packages/ai/src/methods/count-tokens.test.ts b/packages/ai/src/methods/count-tokens.test.ts index b3ed7f7fa4d..67eed84ea13 100644 --- a/packages/ai/src/methods/count-tokens.test.ts +++ b/packages/ai/src/methods/count-tokens.test.ts @@ -77,7 +77,7 @@ describe('countTokens()', () => { task: Task.COUNT_TOKENS, apiSettings: fakeApiSettings, stream: false, - requestOptions: undefined + singleRequestOptions: undefined }, match((value: string) => { return value.includes('contents'); @@ -108,7 +108,7 @@ describe('countTokens()', () => { task: Task.COUNT_TOKENS, apiSettings: fakeApiSettings, stream: false, - requestOptions: undefined + singleRequestOptions: undefined }, match((value: string) => { return value.includes('contents'); @@ -137,7 +137,7 @@ describe('countTokens()', () => { task: Task.COUNT_TOKENS, apiSettings: fakeApiSettings, stream: false, - requestOptions: undefined + singleRequestOptions: undefined }, match((value: string) => { return value.includes('contents'); @@ -191,7 +191,7 @@ describe('countTokens()', () => { task: Task.COUNT_TOKENS, apiSettings: fakeGoogleAIApiSettings, stream: false, - requestOptions: undefined + singleRequestOptions: undefined }, JSON.stringify(mapCountTokensRequest(fakeRequestParams, 'model')) ); diff --git a/packages/ai/src/methods/count-tokens.ts b/packages/ai/src/methods/count-tokens.ts index 20c633ee703..1731592a0e2 100644 --- a/packages/ai/src/methods/count-tokens.ts +++ b/packages/ai/src/methods/count-tokens.ts @@ -19,6 +19,7 @@ import { AIError } from '../errors'; import { CountTokensRequest, CountTokensResponse, + SingleRequestOptions, InferenceMode, RequestOptions, AIErrorCode @@ -33,7 +34,7 @@ export async function countTokensOnCloud( apiSettings: ApiSettings, model: string, params: CountTokensRequest, - requestOptions?: RequestOptions + singleRequestOptions?: SingleRequestOptions ): Promise { let body: string = ''; if (apiSettings.backend.backendType === BackendType.GOOGLE_AI) { @@ -48,7 +49,7 @@ export async function countTokensOnCloud( task: Task.COUNT_TOKENS, apiSettings, stream: false, - requestOptions + singleRequestOptions }, body ); diff --git a/packages/ai/src/methods/generate-content.test.ts b/packages/ai/src/methods/generate-content.test.ts index 8a274c24417..82858844266 100644 --- a/packages/ai/src/methods/generate-content.test.ts +++ b/packages/ai/src/methods/generate-content.test.ts @@ -115,7 +115,7 @@ describe('generateContent()', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: false, - requestOptions: undefined + singleRequestOptions: undefined }, JSON.stringify(fakeRequestParams) ); @@ -141,7 +141,7 @@ describe('generateContent()', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: false, - requestOptions: undefined + singleRequestOptions: undefined }, JSON.stringify(fakeRequestParams) ); @@ -179,7 +179,7 @@ describe('generateContent()', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: false, - requestOptions: undefined + singleRequestOptions: undefined }, JSON.stringify(fakeRequestParams) ); @@ -209,7 +209,7 @@ describe('generateContent()', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: false, - requestOptions: undefined + singleRequestOptions: undefined }, JSON.stringify(fakeRequestParams) ); @@ -259,7 +259,7 @@ describe('generateContent()', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: false, - 
requestOptions: undefined + singleRequestOptions: undefined }, JSON.stringify(fakeRequestParams) ); @@ -356,7 +356,7 @@ describe('generateContent()', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: false, - requestOptions: undefined + singleRequestOptions: undefined }, JSON.stringify(fakeRequestParams) ); @@ -381,7 +381,7 @@ describe('generateContent()', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: false, - requestOptions: undefined + singleRequestOptions: undefined }, JSON.stringify(fakeRequestParams) ); @@ -406,7 +406,7 @@ describe('generateContent()', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: false, - requestOptions: undefined + singleRequestOptions: undefined }, JSON.stringify(fakeRequestParams) ); @@ -447,7 +447,7 @@ describe('generateContent()', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: false, - requestOptions: undefined + singleRequestOptions: undefined }, JSON.stringify(fakeRequestParams) ); @@ -542,7 +542,7 @@ describe('generateContent()', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeGoogleAIApiSettings, stream: false, - requestOptions: match.any + singleRequestOptions: match.any }, JSON.stringify(mapGenerateContentRequest(fakeGoogleAIRequestParams)) ); @@ -585,13 +585,13 @@ describe('templateGenerateContent', () => { ); const templateId = 'my-template'; const templateParams = { name: 'world' }; - const requestOptions = { timeout: 5000 }; + const singleRequestOptions = { timeout: 5000 }; const result = await templateGenerateContent( fakeApiSettings, templateId, templateParams, - requestOptions + singleRequestOptions ); expect(makeRequestStub).to.have.been.calledOnceWith( @@ -600,7 +600,7 @@ describe('templateGenerateContent', () => { templateId, apiSettings: fakeApiSettings, stream: false, - requestOptions + singleRequestOptions }, JSON.stringify(templateParams) ); @@ -622,13 +622,13 @@ describe('templateGenerateContentStream', () => { ); const templateId = 'my-stream-template'; const templateParams = { name: 'streaming world' }; - const requestOptions = { timeout: 10000 }; + const singleRequestOptions = { timeout: 10000 }; const result = await templateGenerateContentStream( fakeApiSettings, templateId, templateParams, - requestOptions + singleRequestOptions ); expect(makeRequestStub).to.have.been.calledOnceWith( @@ -637,7 +637,7 @@ describe('templateGenerateContentStream', () => { templateId, apiSettings: fakeApiSettings, stream: true, - requestOptions + singleRequestOptions }, JSON.stringify(templateParams) ); diff --git a/packages/ai/src/methods/generate-content.ts b/packages/ai/src/methods/generate-content.ts index fc6eac15c74..ce15e7c7f7c 100644 --- a/packages/ai/src/methods/generate-content.ts +++ b/packages/ai/src/methods/generate-content.ts @@ -20,7 +20,7 @@ import { GenerateContentResponse, GenerateContentResult, GenerateContentStreamResult, - RequestOptions + SingleRequestOptions } from '../types'; import { makeRequest, @@ -39,7 +39,7 @@ async function generateContentStreamOnCloud( apiSettings: ApiSettings, model: string, params: GenerateContentRequest, - requestOptions?: RequestOptions + singleRequestOptions?: SingleRequestOptions ): Promise { if (apiSettings.backend.backendType === BackendType.GOOGLE_AI) { params = GoogleAIMapper.mapGenerateContentRequest(params); @@ -50,7 +50,7 @@ async function generateContentStreamOnCloud( model, apiSettings, stream: true, - requestOptions + singleRequestOptions }, JSON.stringify(params) ); @@ -61,14 
+61,19 @@ export async function generateContentStream( model: string, params: GenerateContentRequest, chromeAdapter?: ChromeAdapter, - requestOptions?: RequestOptions + singleRequestOptions?: SingleRequestOptions ): Promise { const callResult = await callCloudOrDevice( params, chromeAdapter, () => chromeAdapter!.generateContentStream(params), () => - generateContentStreamOnCloud(apiSettings, model, params, requestOptions) + generateContentStreamOnCloud( + apiSettings, + model, + params, + singleRequestOptions + ) ); return processStream(callResult.response, apiSettings); // TODO: Map streaming responses } @@ -77,7 +82,7 @@ async function generateContentOnCloud( apiSettings: ApiSettings, model: string, params: GenerateContentRequest, - requestOptions?: RequestOptions + singleRequestOptions?: SingleRequestOptions ): Promise { if (apiSettings.backend.backendType === BackendType.GOOGLE_AI) { params = GoogleAIMapper.mapGenerateContentRequest(params); @@ -88,7 +93,7 @@ async function generateContentOnCloud( task: Task.GENERATE_CONTENT, apiSettings, stream: false, - requestOptions + singleRequestOptions }, JSON.stringify(params) ); @@ -98,7 +103,7 @@ export async function templateGenerateContent( apiSettings: ApiSettings, templateId: string, templateParams: object, - requestOptions?: RequestOptions + singleRequestOptions?: SingleRequestOptions ): Promise { const response = await makeRequest( { @@ -106,7 +111,7 @@ export async function templateGenerateContent( templateId, apiSettings, stream: false, - requestOptions + singleRequestOptions }, JSON.stringify(templateParams) ); @@ -126,7 +131,7 @@ export async function templateGenerateContentStream( apiSettings: ApiSettings, templateId: string, templateParams: object, - requestOptions?: RequestOptions + singleRequestOptions?: SingleRequestOptions ): Promise { const response = await makeRequest( { @@ -134,7 +139,7 @@ export async function templateGenerateContentStream( templateId, apiSettings, stream: true, - requestOptions + singleRequestOptions }, JSON.stringify(templateParams) ); @@ -146,13 +151,14 @@ export async function generateContent( model: string, params: GenerateContentRequest, chromeAdapter?: ChromeAdapter, - requestOptions?: RequestOptions + singleRequestOptions?: SingleRequestOptions ): Promise { const callResult = await callCloudOrDevice( params, chromeAdapter, () => chromeAdapter!.generateContent(params), - () => generateContentOnCloud(apiSettings, model, params, requestOptions) + () => + generateContentOnCloud(apiSettings, model, params, singleRequestOptions) ); const generateContentResponse = await processGenerateContentResponse( callResult.response, diff --git a/packages/ai/src/models/generative-model.test.ts b/packages/ai/src/models/generative-model.test.ts index 45430cb5f59..8d8bfc7c544 100644 --- a/packages/ai/src/models/generative-model.test.ts +++ b/packages/ai/src/models/generative-model.test.ts @@ -30,6 +30,8 @@ import { getMockResponseStreaming } from '../../test-utils/mock-response'; import sinonChai from 'sinon-chai'; +import * as generateContentMethods from '../methods/generate-content'; +import * as countTokens from '../methods/count-tokens'; import { VertexAIBackend } from '../backend'; import { AIError } from '../errors'; import chaiAsPromised from 'chai-as-promised'; @@ -53,6 +55,9 @@ const fakeAI: AI = { }; describe('GenerativeModel', () => { + afterEach(() => { + restore(); + }); it('passes params through to generateContent', async () => { const genModel = new GenerativeModel( fakeAI, @@ -97,7 +102,7 @@ 
describe('GenerativeModel', () => { task: request.Task.GENERATE_CONTENT, apiSettings: match.any, stream: false, - requestOptions: {} + singleRequestOptions: {} }, match((value: string) => { return ( @@ -136,7 +141,7 @@ describe('GenerativeModel', () => { task: request.Task.GENERATE_CONTENT, apiSettings: match.any, stream: false, - requestOptions: {} + singleRequestOptions: {} }, match((value: string) => { return value.includes('be friendly'); @@ -199,7 +204,7 @@ describe('GenerativeModel', () => { task: request.Task.GENERATE_CONTENT, apiSettings: match.any, stream: false, - requestOptions: {} + singleRequestOptions: {} }, match((value: string) => { return ( @@ -213,6 +218,34 @@ describe('GenerativeModel', () => { ); restore(); }); + it('generateContent singleRequestOptions overrides requestOptions', async () => { + const generateContentStub = stub( + generateContentMethods, + 'generateContent' + ).rejects('generateContent failed'); // not important + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + timeout: 2000 + }; + const genModel = new GenerativeModel( + fakeAI, + { model: 'my-model' }, + requestOptions + ); + await expect(genModel.generateContent('hello', singleRequestOptions)).to.be + .rejected; + expect(generateContentStub).to.be.calledWith( + match.any, + match.any, + match.any, + match.any, + match({ + timeout: singleRequestOptions.timeout + }) + ); + }); it('passes base model params through to ChatSession when there are no startChatParams', async () => { const genModel = new GenerativeModel( fakeAI, @@ -231,18 +264,56 @@ describe('GenerativeModel', () => { }); restore(); }); - it('overrides base model params with startChatParams', () => { + it('generateContent singleRequestOptions is merged with requestOptions', async () => { + const generateContentStub = stub( + generateContentMethods, + 'generateContent' + ).rejects('generateContent failed'); // not important + const abortController = new AbortController(); + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + signal: abortController.signal + }; const genModel = new GenerativeModel( fakeAI, - { - model: 'my-model', - generationConfig: { - topK: 1 - } - }, - {}, - fakeChromeAdapter + { model: 'my-model' }, + requestOptions + ); + await expect(genModel.generateContent('hello', singleRequestOptions)).to.be + .rejected; + expect(generateContentStub).to.be.calledWith( + match.any, + match.any, + match.any, + match.any, + match({ + timeout: requestOptions.timeout, + signal: singleRequestOptions.signal + }) ); + }); + it('passes base model params through to ChatSession when there are no startChatParams', async () => { + const genModel = new GenerativeModel(fakeAI, { + model: 'my-model', + generationConfig: { + topK: 1 + } + }); + const chatSession = genModel.startChat(); + expect(chatSession.params?.generationConfig).to.deep.equal({ + topK: 1 + }); + restore(); + }); + it('overrides base model params with startChatParams', () => { + const genModel = new GenerativeModel(fakeAI, { + model: 'my-model', + generationConfig: { + topK: 1 + } + }); const chatSession = genModel.startChat({ generationConfig: { topK: 2 @@ -292,7 +363,7 @@ describe('GenerativeModel', () => { task: request.Task.GENERATE_CONTENT, apiSettings: match.any, stream: false, - requestOptions: {} + singleRequestOptions: {} }, match((value: string) => { return ( @@ -332,7 +403,7 @@ describe('GenerativeModel', () => { task: request.Task.GENERATE_CONTENT, apiSettings: match.any, stream: false, - requestOptions: {} + 
singleRequestOptions: {} }, match((value: string) => { return value.includes('be friendly'); @@ -346,7 +417,9 @@ describe('GenerativeModel', () => { { model: 'my-model', tools: [ - { functionDeclarations: [{ name: 'myfunc', description: 'mydesc' }] } + { functionDeclarations: [{ name: 'myfunc', description: 'mydesc' }] }, + { googleSearch: {} }, + { urlContext: {} } ], toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } @@ -359,6 +432,80 @@ describe('GenerativeModel', () => { {}, fakeChromeAdapter ); + expect(genModel.tools?.length).to.equal(3); + expect(genModel.toolConfig?.functionCallingConfig?.mode).to.equal( + FunctionCallingMode.NONE + ); + expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly'); + const mockResponse = getMockResponse( + 'vertexAI', + 'unary-success-basic-reply-short.json' + ); + const makeRequestStub = stub(request, 'makeRequest').resolves( + mockResponse as Response + ); + await genModel.startChat().sendMessage('hello'); + expect(makeRequestStub).to.be.calledWith( + { + model: 'publishers/google/models/my-model', + task: request.Task.GENERATE_CONTENT, + apiSettings: match.any, + stream: false, + singleRequestOptions: {} + }, + match((value: string) => { + return ( + value.includes('myfunc') && + value.includes(FunctionCallingMode.NONE) && + value.includes('be friendly') + // value.includes('topK') + ); + }) + ); + restore(); + }); + it('passes text-only systemInstruction through to chat.sendMessage', async () => { + const genModel = new GenerativeModel(fakeAI, { + model: 'my-model', + systemInstruction: 'be friendly' + }); + expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly'); + const mockResponse = getMockResponse( + 'vertexAI', + 'unary-success-basic-reply-short.json' + ); + const makeRequestStub = stub(request, 'makeRequest').resolves( + mockResponse as Response + ); + await genModel.startChat().sendMessage('hello'); + expect(makeRequestStub).to.be.calledWith( + { + model: 'publishers/google/models/my-model', + task: request.Task.GENERATE_CONTENT, + apiSettings: match.any, + stream: false, + singleRequestOptions: {} + }, + match((value: string) => { + return value.includes('be friendly'); + }) + ); + restore(); + }); + it('startChat overrides model values', async () => { + const genModel = new GenerativeModel(fakeAI, { + model: 'my-model', + tools: [ + { functionDeclarations: [{ name: 'myfunc', description: 'mydesc' }] } + ], + toolConfig: { + functionCallingConfig: { mode: FunctionCallingMode.NONE } + }, + systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }, + generationConfig: { + responseMimeType: 'image/jpeg' + } + }); expect(genModel.tools?.length).to.equal(1); expect(genModel.toolConfig?.functionCallingConfig?.mode).to.equal( FunctionCallingMode.NONE @@ -378,9 +525,7 @@ describe('GenerativeModel', () => { functionDeclarations: [ { name: 'otherfunc', description: 'otherdesc' } ] - }, - { googleSearch: {} }, - { codeExecution: {} } + } ], toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.AUTO } @@ -397,13 +542,11 @@ describe('GenerativeModel', () => { task: request.Task.GENERATE_CONTENT, apiSettings: match.any, stream: false, - requestOptions: {} + singleRequestOptions: {} }, match((value: string) => { return ( value.includes('otherfunc') && - value.includes('googleSearch') && - value.includes('codeExecution') && value.includes(FunctionCallingMode.AUTO) && value.includes('be formal') && value.includes('image/png') && @@ -434,7 +577,7 @@ describe('GenerativeModel', 
() => { task: request.Task.COUNT_TOKENS, apiSettings: match.any, stream: false, - requestOptions: undefined + singleRequestOptions: {} }, match((value: string) => { return value.includes('hello'); @@ -442,6 +585,62 @@ describe('GenerativeModel', () => { ); restore(); }); + it('countTokens singleRequestOptions overrides requestOptions', async () => { + const countTokensStub = stub(countTokens, 'countTokens').rejects( + 'countTokens failed' + ); + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + timeout: 2000 + }; + const genModel = new GenerativeModel( + fakeAI, + { model: 'my-model' }, + requestOptions + ); + await expect(genModel.countTokens('hello', singleRequestOptions)).to.be + .rejected; + expect(countTokensStub).to.be.calledWith( + match.any, + match.any, + match.any, + match.any, + match({ + timeout: singleRequestOptions.timeout + }) + ); + }); + it('countTokens singleRequestOptions is merged with requestOptions', async () => { + const countTokensStub = stub(countTokens, 'countTokens').rejects( + 'countTokens failed' + ); + const abortController = new AbortController(); + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + signal: abortController.signal + }; + const genModel = new GenerativeModel( + fakeAI, + { model: 'my-model' }, + requestOptions + ); + await expect(genModel.countTokens('hello', singleRequestOptions)).to.be + .rejected; + expect(countTokensStub).to.be.calledWith( + match.any, + match.any, + match.any, + match.any, + match({ + timeout: requestOptions.timeout, + signal: singleRequestOptions.signal + }) + ); + }); }); describe('GenerativeModel dispatch logic', () => { diff --git a/packages/ai/src/models/generative-model.ts b/packages/ai/src/models/generative-model.ts index ffce645eeb1..8defedd33bd 100644 --- a/packages/ai/src/models/generative-model.ts +++ b/packages/ai/src/models/generative-model.ts @@ -29,11 +29,12 @@ import { GenerationConfig, ModelParams, Part, - RequestOptions, SafetySetting, + RequestOptions, StartChatParams, Tool, - ToolConfig + ToolConfig, + SingleRequestOptions } from '../types'; import { ChatSession } from '../methods/chat-session'; import { countTokens } from '../methods/count-tokens'; @@ -79,7 +80,8 @@ export class GenerativeModel extends AIModel { * and returns an object containing a single {@link GenerateContentResponse}. */ async generateContent( - request: GenerateContentRequest | string | Array + request: GenerateContentRequest | string | Array, + singleRequestOptions?: SingleRequestOptions ): Promise { const formattedParams = formatGenerateContentInput(request); return generateContent( @@ -94,7 +96,11 @@ export class GenerativeModel extends AIModel { ...formattedParams }, this.chromeAdapter, - this.requestOptions + // Merge request options + { + ...this.requestOptions, + ...singleRequestOptions + } ); } @@ -105,7 +111,8 @@ export class GenerativeModel extends AIModel { * a promise that returns the final aggregated response. 
*/ async generateContentStream( - request: GenerateContentRequest | string | Array + request: GenerateContentRequest | string | Array, + singleRequestOptions?: SingleRequestOptions ): Promise { const formattedParams = formatGenerateContentInput(request); return generateContentStream( @@ -120,7 +127,11 @@ export class GenerativeModel extends AIModel { ...formattedParams }, this.chromeAdapter, - this.requestOptions + // Merge request options + { + ...this.requestOptions, + ...singleRequestOptions + } ); } @@ -154,14 +165,20 @@ export class GenerativeModel extends AIModel { * Counts the tokens in the provided request. */ async countTokens( - request: CountTokensRequest | string | Array + request: CountTokensRequest | string | Array, + singleRequestOptions?: SingleRequestOptions ): Promise { const formattedParams = formatGenerateContentInput(request); return countTokens( this._apiSettings, this.model, formattedParams, - this.chromeAdapter + this.chromeAdapter, + // Merge request options + { + ...this.requestOptions, + ...singleRequestOptions + } ); } } diff --git a/packages/ai/src/models/imagen-model.test.ts b/packages/ai/src/models/imagen-model.test.ts index 68b6caca098..34470cc1d14 100644 --- a/packages/ai/src/models/imagen-model.test.ts +++ b/packages/ai/src/models/imagen-model.test.ts @@ -47,6 +47,9 @@ const fakeAI: AI = { }; describe('ImagenModel', () => { + afterEach(() => { + restore(); + }); it('generateImages makes a request to predict with default parameters', async () => { const mockResponse = getMockResponse( 'vertexAI', @@ -67,7 +70,7 @@ describe('ImagenModel', () => { task: request.Task.PREDICT, apiSettings: match.any, stream: false, - requestOptions: undefined + singleRequestOptions: {} }, match((value: string) => { return ( @@ -76,7 +79,6 @@ describe('ImagenModel', () => { ); }) ); - restore(); }); it('generateImages makes a request to predict with generation config and safety settings', async () => { const imagenModel = new ImagenModel(fakeAI, { @@ -109,7 +111,7 @@ describe('ImagenModel', () => { task: request.Task.PREDICT, apiSettings: match.any, stream: false, - requestOptions: undefined + singleRequestOptions: {} }, match((value: string) => { return ( @@ -137,7 +139,76 @@ describe('ImagenModel', () => { ); }) ); - restore(); + }); + it('generateImages singleRequestOptions overrides requestOptions', async () => { + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + timeout: 2000 + }; + const imagenModel = new ImagenModel( + fakeAI, + { model: 'my-model' }, + requestOptions + ); + const mockResponse = getMockResponse( + 'vertexAI', + 'unary-success-generate-images-base64.json' + ); + const makeRequestStub = stub(request, 'makeRequest').resolves( + mockResponse as Response + ); + const prompt = 'A photorealistic image of a toy boat at sea.'; + await imagenModel.generateImages(prompt, singleRequestOptions); + expect(makeRequestStub).to.be.calledWith( + { + model: match.any, + task: request.Task.PREDICT, + apiSettings: match.any, + stream: false, + singleRequestOptions: { + timeout: singleRequestOptions.timeout + } + }, + match.any + ); + }); + it('generateImages singleRequestOptions is merged with requestOptions', async () => { + const abortController = new AbortController(); + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + signal: abortController.signal + }; + const imagenModel = new ImagenModel( + fakeAI, + { model: 'my-model' }, + requestOptions + ); + const mockResponse = getMockResponse( + 'vertexAI', + 
'unary-success-generate-images-base64.json' + ); + const makeRequestStub = stub(request, 'makeRequest').resolves( + mockResponse as Response + ); + const prompt = 'A photorealistic image of a toy boat at sea.'; + await imagenModel.generateImages(prompt, singleRequestOptions); + expect(makeRequestStub).to.be.calledWith( + { + model: match.any, + task: request.Task.PREDICT, + apiSettings: match.any, + stream: false, + singleRequestOptions: { + timeout: requestOptions.timeout, + signal: singleRequestOptions.signal + } + }, + match.any + ); }); it('throws if prompt blocked', async () => { const mockResponse = getMockResponse( @@ -163,8 +234,76 @@ describe('ImagenModel', () => { expect((e as AIError).message).to.include( "Image generation failed with the following error: The prompt could not be submitted. This prompt contains sensitive words that violate Google's Responsible AI practices. Try rephrasing the prompt. If you think this was an error, send feedback." ); - } finally { - restore(); } }); + it('generateImagesGCS singleRequestOptions overrides requestOptions', async () => { + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + timeout: 2000 + }; + const imagenModel = new ImagenModel( + fakeAI, + { model: 'my-model' }, + requestOptions + ); + const mockResponse = getMockResponse( + 'vertexAI', + 'unary-success-generate-images-gcs.json' + ); + const makeRequestStub = stub(request, 'makeRequest').resolves( + mockResponse as Response + ); + const prompt = 'A photorealistic image of a toy boat at sea.'; + await imagenModel.generateImagesGCS(prompt, '', singleRequestOptions); + expect(makeRequestStub).to.be.calledWith( + { + model: match.any, + task: request.Task.PREDICT, + apiSettings: match.any, + stream: false, + singleRequestOptions: { + timeout: singleRequestOptions.timeout + } + }, + match.any + ); + }); + it('generateImagesGCS singleRequestOptions is merged with requestOptions', async () => { + const abortController = new AbortController(); + const requestOptions = { + timeout: 1000 + }; + const singleRequestOptions = { + signal: abortController.signal + }; + const imagenModel = new ImagenModel( + fakeAI, + { model: 'my-model' }, + requestOptions + ); + const mockResponse = getMockResponse( + 'vertexAI', + 'unary-success-generate-images-gcs.json' + ); + const makeRequestStub = stub(request, 'makeRequest').resolves( + mockResponse as Response + ); + const prompt = 'A photorealistic image of a toy boat at sea.'; + await imagenModel.generateImagesGCS(prompt, '', singleRequestOptions); + expect(makeRequestStub).to.be.calledWith( + { + model: match.any, + task: request.Task.PREDICT, + apiSettings: match.any, + stream: false, + singleRequestOptions: { + timeout: requestOptions.timeout, + signal: singleRequestOptions.signal + } + }, + match.any + ); + }); }); diff --git a/packages/ai/src/models/imagen-model.ts b/packages/ai/src/models/imagen-model.ts index 567333ee64f..beeb01ac12c 100644 --- a/packages/ai/src/models/imagen-model.ts +++ b/packages/ai/src/models/imagen-model.ts @@ -26,7 +26,8 @@ import { RequestOptions, ImagenModelParams, ImagenGenerationResponse, - ImagenSafetySettings + ImagenSafetySettings, + SingleRequestOptions } from '../types'; import { AIModel } from './ai-model'; @@ -102,7 +103,8 @@ export class ImagenModel extends AIModel { * @public */ async generateImages( - prompt: string + prompt: string, + singleRequestOptions?: SingleRequestOptions ): Promise> { const body = createPredictRequestBody(prompt, { ...this.generationConfig, @@ -114,7 +116,11 
@@ export class ImagenModel extends AIModel { model: this.model, apiSettings: this._apiSettings, stream: false, - requestOptions: this.requestOptions + // Merge request options. Single request options overwrite the model's request options. + singleRequestOptions: { + ...this.requestOptions, + ...singleRequestOptions + } }, JSON.stringify(body) ); @@ -142,7 +148,8 @@ export class ImagenModel extends AIModel { */ async generateImagesGCS( prompt: string, - gcsURI: string + gcsURI: string, + singleRequestOptions?: SingleRequestOptions ): Promise> { const body = createPredictRequestBody(prompt, { gcsURI, @@ -155,7 +162,11 @@ export class ImagenModel extends AIModel { model: this.model, apiSettings: this._apiSettings, stream: false, - requestOptions: this.requestOptions + // Merge request options. Single request options overwrite the model's request options. + singleRequestOptions: { + ...this.requestOptions, + ...singleRequestOptions + } }, JSON.stringify(body) ); diff --git a/packages/ai/src/models/template-generative-model.test.ts b/packages/ai/src/models/template-generative-model.test.ts index c3eb43af491..d3f7ec28ffa 100644 --- a/packages/ai/src/models/template-generative-model.test.ts +++ b/packages/ai/src/models/template-generative-model.test.ts @@ -73,6 +73,51 @@ describe('TemplateGenerativeModel', () => { { timeout: 5000 } ); }); + + it('singleRequestOptions overrides requestOptions', async () => { + const templateGenerateContentStub = stub( + generateContentMethods, + 'templateGenerateContent' + ).resolves({} as any); + const model = new TemplateGenerativeModel(fakeAI, { timeout: 1000 }); + const singleRequestOptions = { timeout: 2000 }; + + await model.generateContent( + TEMPLATE_ID, + TEMPLATE_VARS, + singleRequestOptions + ); + + expect(templateGenerateContentStub).to.have.been.calledOnceWith( + model._apiSettings, + TEMPLATE_ID, + { inputs: TEMPLATE_VARS }, + { timeout: 2000 } + ); + }); + + it('singleRequestOptions is merged with requestOptions', async () => { + const templateGenerateContentStub = stub( + generateContentMethods, + 'templateGenerateContent' + ).resolves({} as any); + const abortController = new AbortController(); + const model = new TemplateGenerativeModel(fakeAI, { timeout: 1000 }); + const singleRequestOptions = { signal: abortController.signal }; + + await model.generateContent( + TEMPLATE_ID, + TEMPLATE_VARS, + singleRequestOptions + ); + + expect(templateGenerateContentStub).to.have.been.calledOnceWith( + model._apiSettings, + TEMPLATE_ID, + { inputs: TEMPLATE_VARS }, + { timeout: 1000, signal: abortController.signal } + ); + }); }); describe('generateContentStream', () => { @@ -92,5 +137,50 @@ describe('TemplateGenerativeModel', () => { { timeout: 5000 } ); }); + + it('singleRequestOptions overrides requestOptions', async () => { + const templateGenerateContentStreamStub = stub( + generateContentMethods, + 'templateGenerateContentStream' + ).resolves({} as any); + const model = new TemplateGenerativeModel(fakeAI, { timeout: 1000 }); + const singleRequestOptions = { timeout: 2000 }; + + await model.generateContentStream( + TEMPLATE_ID, + TEMPLATE_VARS, + singleRequestOptions + ); + + expect(templateGenerateContentStreamStub).to.have.been.calledOnceWith( + model._apiSettings, + TEMPLATE_ID, + { inputs: TEMPLATE_VARS }, + { timeout: 2000 } + ); + }); + + it('singleRequestOptions is merged with requestOptions', async () => { + const templateGenerateContentStreamStub = stub( + generateContentMethods, + 'templateGenerateContentStream' + ).resolves({} as any); + 
const abortController = new AbortController(); + const model = new TemplateGenerativeModel(fakeAI, { timeout: 1000 }); + const singleRequestOptions = { signal: abortController.signal }; + + await model.generateContentStream( + TEMPLATE_ID, + TEMPLATE_VARS, + singleRequestOptions + ); + + expect(templateGenerateContentStreamStub).to.have.been.calledOnceWith( + model._apiSettings, + TEMPLATE_ID, + { inputs: TEMPLATE_VARS }, + { timeout: 1000, signal: abortController.signal } + ); + }); }); }); diff --git a/packages/ai/src/models/template-generative-model.ts b/packages/ai/src/models/template-generative-model.ts index ec9e653618d..ccc61253ed9 100644 --- a/packages/ai/src/models/template-generative-model.ts +++ b/packages/ai/src/models/template-generative-model.ts @@ -20,7 +20,11 @@ import { templateGenerateContentStream } from '../methods/generate-content'; import { GenerateContentResult, RequestOptions } from '../types'; -import { AI, GenerateContentStreamResult } from '../public-types'; +import { + AI, + GenerateContentStreamResult, + SingleRequestOptions +} from '../public-types'; import { ApiSettings } from '../types/internal'; import { initApiSettings } from './utils'; @@ -62,13 +66,17 @@ export class TemplateGenerativeModel { */ async generateContent( templateId: string, - templateVariables: object // anything! + templateVariables: object, // anything! + singleRequestOptions?: SingleRequestOptions ): Promise { return templateGenerateContent( this._apiSettings, templateId, { inputs: templateVariables }, - this.requestOptions + { + ...this.requestOptions, + ...singleRequestOptions + } ); } @@ -86,13 +94,17 @@ export class TemplateGenerativeModel { */ async generateContentStream( templateId: string, - templateVariables: object + templateVariables: object, + singleRequestOptions?: SingleRequestOptions ): Promise { return templateGenerateContentStream( this._apiSettings, templateId, { inputs: templateVariables }, - this.requestOptions + { + ...this.requestOptions, + ...singleRequestOptions + } ); } } diff --git a/packages/ai/src/models/template-imagen-model.test.ts b/packages/ai/src/models/template-imagen-model.test.ts index c053753ea0f..9451981f83d 100644 --- a/packages/ai/src/models/template-imagen-model.test.ts +++ b/packages/ai/src/models/template-imagen-model.test.ts @@ -18,7 +18,7 @@ import { use, expect } from 'chai'; import sinonChai from 'sinon-chai'; import chaiAsPromised from 'chai-as-promised'; -import { restore, stub } from 'sinon'; +import { restore, stub, match } from 'sinon'; import { AI } from '../public-types'; import { VertexAIBackend } from '../backend'; import { TemplateImagenModel } from './template-imagen-model'; @@ -83,12 +83,68 @@ describe('TemplateImagenModel', () => { templateId: TEMPLATE_ID, apiSettings: model._apiSettings, stream: false, - requestOptions: { timeout: 5000 } + singleRequestOptions: { timeout: 5000 } }, JSON.stringify({ inputs: TEMPLATE_VARS }) ); }); + it('singleRequestOptions overrides requestOptions', async () => { + const mockPrediction = { + 'bytesBase64Encoded': + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==', + 'mimeType': 'image/png' + }; + const makeRequestStub = stub(request, 'makeRequest').resolves({ + json: () => Promise.resolve({ predictions: [mockPrediction] }) + } as Response); + const model = new TemplateImagenModel(fakeAI, { timeout: 1000 }); + const singleRequestOptions = { timeout: 2000 }; + + await model.generateImages( + TEMPLATE_ID, + TEMPLATE_VARS, + singleRequestOptions + ); 
+ + expect(makeRequestStub).to.have.been.calledOnceWith( + match({ + singleRequestOptions: { timeout: 2000 } + }), + match.any + ); + }); + + it('singleRequestOptions is merged with requestOptions', async () => { + const mockPrediction = { + 'bytesBase64Encoded': + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==', + 'mimeType': 'image/png' + }; + const makeRequestStub = stub(request, 'makeRequest').resolves({ + json: () => Promise.resolve({ predictions: [mockPrediction] }) + } as Response); + const abortController = new AbortController(); + const model = new TemplateImagenModel(fakeAI, { timeout: 1000 }); + const singleRequestOptions = { signal: abortController.signal }; + + await model.generateImages( + TEMPLATE_ID, + TEMPLATE_VARS, + singleRequestOptions + ); + + expect(makeRequestStub).to.have.been.calledOnceWith( + match({ + singleRequestOptions: { + timeout: 1000, + signal: abortController.signal + } + }), + match.any + ); + }); + it('should return the result of handlePredictResponse', async () => { const mockPrediction = { 'bytesBase64Encoded': diff --git a/packages/ai/src/models/template-imagen-model.ts b/packages/ai/src/models/template-imagen-model.ts index 34325c711b3..be4d10f72d0 100644 --- a/packages/ai/src/models/template-imagen-model.ts +++ b/packages/ai/src/models/template-imagen-model.ts @@ -19,7 +19,8 @@ import { RequestOptions } from '../types'; import { AI, ImagenGenerationResponse, - ImagenInlineImage + ImagenInlineImage, + SingleRequestOptions } from '../public-types'; import { ApiSettings } from '../types/internal'; import { makeRequest, ServerPromptTemplateTask } from '../requests/request'; @@ -64,7 +65,8 @@ export class TemplateImagenModel { */ async generateImages( templateId: string, - templateVariables: object + templateVariables: object, + singleRequestOptions?: SingleRequestOptions ): Promise> { const response = await makeRequest( { @@ -72,7 +74,10 @@ export class TemplateImagenModel { templateId, apiSettings: this._apiSettings, stream: false, - requestOptions: this.requestOptions + singleRequestOptions: { + ...this.requestOptions, + ...singleRequestOptions + } }, JSON.stringify({ inputs: templateVariables }) ); diff --git a/packages/ai/src/requests/request.test.ts b/packages/ai/src/requests/request.test.ts index a54ff521bea..29b5118390f 100644 --- a/packages/ai/src/requests/request.test.ts +++ b/packages/ai/src/requests/request.test.ts @@ -16,12 +16,14 @@ */ import { expect, use } from 'chai'; -import { match, restore, stub } from 'sinon'; +import Sinon, { match, restore, stub, useFakeTimers } from 'sinon'; import sinonChai from 'sinon-chai'; import chaiAsPromised from 'chai-as-promised'; import { + ABORT_ERROR_NAME, RequestURL, ServerPromptTemplateTask, + TIMEOUT_EXPIRED_MESSAGE, Task, getHeaders, makeRequest @@ -55,7 +57,7 @@ describe('request methods', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: true, - requestOptions: {} + singleRequestOptions: undefined }); expect(url.toString()).to.include('models/model-name:generateContent'); expect(url.toString()).to.include('alt=sse'); @@ -66,7 +68,7 @@ describe('request methods', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: false, - requestOptions: {} + singleRequestOptions: undefined }); expect(url.toString()).to.include('models/model-name:generateContent'); expect(url.toString()).to.not.include(fakeApiSettings); @@ -78,7 +80,7 @@ describe('request methods', () => { task: Task.GENERATE_CONTENT, apiSettings: 
fakeApiSettings, stream: false, - requestOptions: {} + singleRequestOptions: undefined }); expect(url.toString()).to.include(DEFAULT_API_VERSION); }); @@ -88,7 +90,7 @@ describe('request methods', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: false, - requestOptions: { baseUrl: 'https://my.special.endpoint' } + singleRequestOptions: { baseUrl: 'https://my.special.endpoint' } }); expect(url.toString()).to.include('https://my.special.endpoint'); }); @@ -98,7 +100,7 @@ describe('request methods', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: false, - requestOptions: {} + singleRequestOptions: undefined }); expect(url.toString()).to.include( 'tunedModels/model-name:generateContent' @@ -112,7 +114,7 @@ describe('request methods', () => { task: ServerPromptTemplateTask.TEMPLATE_GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: false, - requestOptions: {} + singleRequestOptions: undefined }); expect(url.toString()).to.include( 'templates/my-template:templateGenerateContent' @@ -135,7 +137,7 @@ describe('request methods', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: true, - requestOptions: {} + singleRequestOptions: undefined }); it('adds client headers', async () => { const headers = await getHeaders(fakeUrl); @@ -163,7 +165,7 @@ describe('request methods', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: true, - requestOptions: {} + singleRequestOptions: undefined }); const headers = await getHeaders(fakeUrl); expect(headers.get('X-Firebase-Appid')).to.equal('my-appid'); @@ -188,7 +190,7 @@ describe('request methods', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: true, - requestOptions: {} + singleRequestOptions: undefined }); const headers = await getHeaders(fakeUrl); expect(headers.get('X-Firebase-Appid')).to.be.null; @@ -209,7 +211,7 @@ describe('request methods', () => { backend: new VertexAIBackend() }, stream: true, - requestOptions: {} + singleRequestOptions: undefined }); const headers = await getHeaders(fakeUrl); expect(headers.has('X-Firebase-AppCheck')).to.be.false; @@ -226,7 +228,7 @@ describe('request methods', () => { getAppCheckToken: () => Promise.resolve() }, stream: true, - requestOptions: {} + singleRequestOptions: undefined }); const headers = await getHeaders(fakeUrl); expect(headers.has('X-Firebase-AppCheck')).to.be.false; @@ -245,7 +247,7 @@ describe('request methods', () => { Promise.resolve({ token: 'dummytoken', error: Error('oops') }) }, stream: true, - requestOptions: {} + singleRequestOptions: undefined }); const warnStub = stub(console, 'warn'); const headers = await getHeaders(fakeUrl); @@ -271,7 +273,7 @@ describe('request methods', () => { backend: new VertexAIBackend() }, stream: true, - requestOptions: {} + singleRequestOptions: undefined }); const headers = await getHeaders(fakeUrl); expect(headers.has('Authorization')).to.be.false; @@ -288,15 +290,45 @@ describe('request methods', () => { getAppCheckToken: () => Promise.resolve() }, stream: true, - requestOptions: {} + singleRequestOptions: undefined }); const headers = await getHeaders(fakeUrl); expect(headers.has('Authorization')).to.be.false; }); }); describe('makeRequest', () => { + let fetchStub: Sinon.SinonStub; + let clock: Sinon.SinonFakeTimers; + const fetchAborter = ( + _url: string, + options?: RequestInit + ): Promise => { + expect(options).to.not.be.undefined; + expect(options!.signal).to.not.be.undefined; + const signal = options!.signal; + return 
new Promise((_resolve, reject): void => { + const abortListener = (): void => { + reject( + new DOMException(signal?.reason || 'Aborted', ABORT_ERROR_NAME) + ); + }; + + signal?.addEventListener('abort', abortListener, { once: true }); + }); + }; + + beforeEach(() => { + fetchStub = stub(globalThis, 'fetch'); + clock = useFakeTimers(); + }); + + afterEach(() => { + restore(); + clock.restore(); + }); + it('no error', async () => { - const fetchStub = stub(globalThis, 'fetch').resolves({ + fetchStub.resolves({ ok: true } as Response); const response = await makeRequest( @@ -312,10 +344,10 @@ describe('request methods', () => { expect(response.ok).to.be.true; }); it('error with timeout', async () => { - const fetchStub = stub(globalThis, 'fetch').resolves({ + fetchStub.resolves({ ok: false, status: 500, - statusText: 'AbortError' + statusText: ABORT_ERROR_NAME } as Response); try { @@ -325,7 +357,7 @@ describe('request methods', () => { task: Task.GENERATE_CONTENT, apiSettings: fakeApiSettings, stream: false, - requestOptions: { + singleRequestOptions: { timeout: 180000 } }, @@ -335,7 +367,7 @@ describe('request methods', () => { expect((e as AIError).code).to.equal(AIErrorCode.FETCH_ERROR); expect((e as AIError).customErrorData?.status).to.equal(500); expect((e as AIError).customErrorData?.statusText).to.equal( - 'AbortError' + ABORT_ERROR_NAME ); expect((e as AIError).message).to.include('500 AbortError'); } @@ -343,7 +375,7 @@ describe('request methods', () => { expect(fetchStub).to.be.calledOnce; }); it('Network error, no response.json()', async () => { - const fetchStub = stub(globalThis, 'fetch').resolves({ + fetchStub.resolves({ ok: false, status: 500, statusText: 'Server Error' @@ -369,7 +401,7 @@ describe('request methods', () => { expect(fetchStub).to.be.calledOnce; }); it('Network error, includes response.json()', async () => { - const fetchStub = stub(globalThis, 'fetch').resolves({ + fetchStub.resolves({ ok: false, status: 500, statusText: 'Server Error', @@ -397,7 +429,7 @@ describe('request methods', () => { expect(fetchStub).to.be.calledOnce; }); it('Network error, includes response.json() and details', async () => { - const fetchStub = stub(globalThis, 'fetch').resolves({ + fetchStub.resolves({ ok: false, status: 500, statusText: 'Server Error', @@ -437,16 +469,209 @@ describe('request methods', () => { } expect(fetchStub).to.be.calledOnce; }); - }); - it('Network error, API not enabled', async () => { - const mockResponse = getMockResponse( - 'vertexAI', - 'unary-failure-firebasevertexai-api-not-enabled.json' - ); - const fetchStub = stub(globalThis, 'fetch').resolves( - mockResponse as Response - ); - try { + it('Network error, API not enabled', async () => { + const mockResponse = getMockResponse( + 'vertexAI', + 'unary-failure-firebasevertexai-api-not-enabled.json' + ); + fetchStub.resolves(mockResponse as Response); + try { + await makeRequest( + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false + }, + '' + ); + } catch (e) { + expect((e as AIError).code).to.equal(AIErrorCode.API_NOT_ENABLED); + expect((e as AIError).message).to.include('my-project'); + expect((e as AIError).message).to.include('googleapis.com'); + } + expect(fetchStub).to.be.calledOnce; + }); + + it('should throw DOMException if external signal is already aborted', async () => { + const controller = new AbortController(); + const abortReason = 'Aborted before request'; + controller.abort(abortReason); + + const requestPromise = makeRequest( + 
{ + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + singleRequestOptions: { signal: controller.signal } + }, + '{}' + ); + + await expect(requestPromise).to.be.rejectedWith( + DOMException, + abortReason + ); + + expect(fetchStub).not.to.have.been.called; + }); + it('should throw DOMException if external signal aborts during request', async () => { + fetchStub.callsFake(fetchAborter); + const controller = new AbortController(); + const abortReason = 'Aborted during request'; + + const requestPromise = makeRequest( + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + singleRequestOptions: { signal: controller.signal } + }, + '{}' + ); + + await clock.tickAsync(0); + controller.abort(abortReason); + + await expect(requestPromise).to.be.rejectedWith( + DOMException, + abortReason + ); + }); + + it('should abort fetch if timeout expires during request', async () => { + const timeoutDuration = 100; + fetchStub.callsFake(fetchAborter); + + const requestPromise = makeRequest( + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + singleRequestOptions: { timeout: timeoutDuration } + }, + '{}' + ); + + await clock.tickAsync(timeoutDuration + 100); + + await expect(requestPromise).to.be.rejectedWith( + DOMException, + TIMEOUT_EXPIRED_MESSAGE + ); + + expect(fetchStub).to.have.been.calledOnce; + const fetchOptions = fetchStub.firstCall.args[1] as RequestInit; + const internalSignal = fetchOptions.signal; + + expect(internalSignal?.aborted).to.be.true; + expect((internalSignal?.reason as Error).name).to.equal(ABORT_ERROR_NAME); + expect((internalSignal?.reason as Error).message).to.equal( + 'Timeout has expired.' 
+ ); + }); + + it('should succeed and clear timeout if fetch completes before timeout', async () => { + const mockResponse = new Response('{}', { + status: 200, + statusText: 'OK' + }); + const fetchPromise = Promise.resolve(mockResponse); + fetchStub.resolves(fetchPromise); + const clearTimeoutStub = stub(globalThis, 'clearTimeout'); + + const requestPromise = makeRequest( + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + singleRequestOptions: { timeout: 5000 } // Generous timeout + }, + '{}' + ); + + // Advance time slightly, well within timeout + await clock.tickAsync(10); + + const response = await requestPromise; + expect(response.ok).to.be.true; + expect(clearTimeoutStub).to.have.been.calledOnce; + expect(fetchStub).to.have.been.calledOnce; + }); + + it('should use external signal abort reason if it occurs before timeout', async () => { + const controller = new AbortController(); + const abortReason = 'External Abort Wins'; + const timeoutDuration = 500; + fetchStub.callsFake(fetchAborter); + + const requestPromise = makeRequest( + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + singleRequestOptions: { + signal: controller.signal, + timeout: timeoutDuration + } + }, + '{}' + ); + + // Advance time, but less than the timeout + await clock.tickAsync(timeoutDuration / 2); + controller.abort(abortReason); + + await expect(requestPromise).to.be.rejectedWith( + DOMException, + abortReason + ); + }); + + it('should use timeout reason if it occurs before external signal abort', async () => { + const controller = new AbortController(); + const abortReason = 'External Abort Loses'; + const timeoutDuration = 100; + fetchStub.callsFake(fetchAborter); + + const requestPromise = makeRequest( + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + singleRequestOptions: { + signal: controller.signal, + timeout: timeoutDuration + } + }, + '{}' + ); + + // Schedule external abort after timeout + setTimeout(() => controller.abort(abortReason), timeoutDuration * 2); + + // Advance time past the timeout + await clock.tickAsync(timeoutDuration + 1); + + await expect(requestPromise).to.be.rejectedWith( + DOMException, + TIMEOUT_EXPIRED_MESSAGE + ); + }); + + it('should pass internal signal to fetch options', async () => { + const mockResponse = new Response('{}', { + status: 200, + statusText: 'OK' + }); + fetchStub.resolves(mockResponse); + await makeRequest( { model: 'models/model-name', @@ -456,11 +681,59 @@ describe('request methods', () => { }, '' ); - } catch (e) { - expect((e as AIError).code).to.equal(AIErrorCode.API_NOT_ENABLED); - expect((e as AIError).message).to.include('my-project'); - expect((e as AIError).message).to.include('googleapis.com'); - } - expect(fetchStub).to.be.calledOnce; + + expect(fetchStub).to.have.been.calledOnce; + const fetchOptions = fetchStub.firstCall.args[1] as RequestInit; + expect(fetchOptions.signal).to.exist; + expect(fetchOptions.signal).to.be.instanceOf(AbortSignal); + expect(fetchOptions.signal?.aborted).to.be.false; + }); + + it('should abort immediately if timeout is 0', async () => { + fetchStub.callsFake(fetchAborter); + const requestPromise = makeRequest( + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + singleRequestOptions: { timeout: 0 } + }, + '{}' + ); + + // Tick the clock just enough to trigger a 
timeout(0) + await clock.tickAsync(1); + + await expect(requestPromise).to.be.rejectedWith( + DOMException, + TIMEOUT_EXPIRED_MESSAGE + ); + }); + + it('should not error if signal is aborted after completion', async () => { + const controller = new AbortController(); + const mockResponse = new Response('{}', { + status: 200, + statusText: 'OK' + }); + fetchStub.resolves(mockResponse); + + const response = await makeRequest( + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + singleRequestOptions: { signal: controller.signal } + }, + '{}' + ); + + // Listener should be removed, so this abort should do nothing. + controller.abort('Too late'); + + expect(response.ok).to.be.true; + }); }); }); diff --git a/packages/ai/src/requests/request.ts b/packages/ai/src/requests/request.ts index 7664765ab03..06e1620000c 100644 --- a/packages/ai/src/requests/request.ts +++ b/packages/ai/src/requests/request.ts @@ -15,7 +15,7 @@ * limitations under the License. */ -import { ErrorDetails, RequestOptions, AIErrorCode } from '../types'; +import { SingleRequestOptions, AIErrorCode, ErrorDetails } from '../types'; import { AIError } from '../errors'; import { ApiSettings } from '../types/internal'; import { @@ -27,6 +27,9 @@ import { import { logger } from '../logger'; import { BackendType } from '../public-types'; +export const TIMEOUT_EXPIRED_MESSAGE = 'Timeout has expired.'; +export const ABORT_ERROR_NAME = 'AbortError'; + export const enum Task { GENERATE_CONTENT = 'generateContent', STREAM_GENERATE_CONTENT = 'streamGenerateContent', @@ -43,7 +46,7 @@ export const enum ServerPromptTemplateTask { interface BaseRequestURLParams { apiSettings: ApiSettings; stream: boolean; - requestOptions?: RequestOptions; + singleRequestOptions?: SingleRequestOptions; } /** @@ -94,7 +97,9 @@ export class RequestURL { } private get baseUrl(): string { - return this.params.requestOptions?.baseUrl ?? `https://${DEFAULT_DOMAIN}`; + return ( + this.params.singleRequestOptions?.baseUrl ?? `https://${DEFAULT_DOMAIN}` + ); } private get queryParams(): URLSearchParams { @@ -175,24 +180,48 @@ export async function makeRequest( ): Promise { const url = new RequestURL(requestUrlParams); let response; - let fetchTimeoutId: string | number | NodeJS.Timeout | undefined; + + const externalSignal = requestUrlParams.singleRequestOptions?.signal; + const timeoutMillis = + requestUrlParams.singleRequestOptions?.timeout != null && + requestUrlParams.singleRequestOptions.timeout >= 0 + ? requestUrlParams.singleRequestOptions.timeout + : DEFAULT_FETCH_TIMEOUT_MS; + + const internalAbortController = new AbortController(); + const fetchTimeoutId = setTimeout(() => { + internalAbortController.abort( + new DOMException(TIMEOUT_EXPIRED_MESSAGE, ABORT_ERROR_NAME) + ); + logger.debug( + `Aborting request to ${url} due to timeout (${timeoutMillis}ms)` + ); + }, timeoutMillis); + + // Used to abort the fetch if either the user-defined `externalSignal` is aborted, or if the + // internal signal (triggered by timeouts) is aborted. + const combinedSignal = AbortSignal.any( + externalSignal + ? [externalSignal, internalAbortController.signal] + : [internalAbortController.signal] + ); + + if (externalSignal && externalSignal.aborted) { + clearTimeout(fetchTimeoutId); + throw new DOMException( + externalSignal.reason ?? 
'Aborted externally before fetch', + ABORT_ERROR_NAME + ); + } + try { const fetchOptions: RequestInit = { method: 'POST', headers: await getHeaders(url), + signal: combinedSignal, body }; - // Timeout is 180s by default. - const timeoutMillis = - requestUrlParams.requestOptions?.timeout != null && - requestUrlParams.requestOptions.timeout >= 0 - ? requestUrlParams.requestOptions.timeout - : DEFAULT_FETCH_TIMEOUT_MS; - const abortController = new AbortController(); - fetchTimeoutId = setTimeout(() => abortController.abort(), timeoutMillis); - fetchOptions.signal = abortController.signal; - response = await fetch(url.toString(), fetchOptions); if (!response.ok) { let message = ''; @@ -252,7 +281,8 @@ if ( (e as AIError).code !== AIErrorCode.FETCH_ERROR && (e as AIError).code !== AIErrorCode.API_NOT_ENABLED && - e instanceof Error + e instanceof Error && + (e as DOMException).name !== ABORT_ERROR_NAME ) { err = new AIError( AIErrorCode.ERROR, @@ -263,9 +293,10 @@ throw err; } finally { - if (fetchTimeoutId) { - clearTimeout(fetchTimeoutId); - } + // When doing streaming requests, this will clear the timeout once the stream begins. + // If the timeout is 3000ms, and the stream starts after 300ms and ends after 5000ms, the + // timeout will be cleared after 300ms, so it won't abort the request. + clearTimeout(fetchTimeoutId); } return response; } diff --git a/packages/ai/src/requests/stream-reader.ts b/packages/ai/src/requests/stream-reader.ts index b4968969be7..7af4d857bca 100644 --- a/packages/ai/src/requests/stream-reader.ts +++ b/packages/ai/src/requests/stream-reader.ts @@ -52,9 +52,14 @@ export function processStream( const inputStream = response.body!.pipeThrough( new TextDecoderStream('utf8', { fatal: true }) ); + const responseStream = getResponseStream(inputStream); + + // We split the stream so the user can iterate over partial results (stream1) + // while we aggregate the full result for history/final response (stream2). const [stream1, stream2] = responseStream.tee(); + return { stream: generateResponseSequence(stream1, apiSettings, inferenceSource), response: getResponsePromise(stream2, apiSettings, inferenceSource) @@ -82,7 +87,6 @@ async function getResponsePromise( inferenceSource ); } - allResponses.push(value); } } @@ -112,7 +116,6 @@ async function* generateResponseSequence( } const firstCandidate = enhancedResponse.candidates?.[0]; - // Don't yield a response with no useful data for the developer. if ( !firstCandidate?.content?.parts && !firstCandidate?.finishReason && @@ -127,9 +130,7 @@ } /** - * Reads a raw stream from the fetch response and join incomplete - * chunks, returning a new stream that provides a single complete - * GenerateContentResponse in each iteration. + * Reads a raw string stream, buffers incomplete chunks, and yields parsed JSON objects. */ export function getResponseStream( inputStream: ReadableStream @@ -153,6 +154,8 @@ } currentText += value; + // SSE events may span chunk boundaries, so we buffer until we match + // the full "data: {json}\n\n" pattern. let match = currentText.match(responseLineRE); let parsedResponse: T; while (match) { @@ -193,8 +196,7 @@ export function aggregateResponses( for (const response of responses) { if (response.candidates) { for (const candidate of response.candidates) { - // Index will be undefined if it's the first index (0), so we should use 0 if it's undefined. 
- // See: https://github.com/firebase/firebase-js-sdk/issues/8566 + // Use 0 if index is undefined (protobuf default value omission). const i = candidate.index || 0; if (!aggregatedResponse.candidates) { aggregatedResponse.candidates = []; @@ -204,7 +206,8 @@ export function aggregateResponses( index: candidate.index } as GenerateContentCandidate; } - // Keep overwriting, the last one will be final + + // Overwrite with the latest metadata aggregatedResponse.candidates[i].citationMetadata = candidate.citationMetadata; aggregatedResponse.candidates[i].finishReason = candidate.finishReason; @@ -229,12 +232,7 @@ export function aggregateResponses( urlContextMetadata as URLContextMetadata; } - /** - * Candidates should always have content and parts, but this handles - * possible malformed responses. - */ if (candidate.content) { - // Skip a candidate without parts. if (!candidate.content.parts) { continue; } diff --git a/packages/ai/src/types/requests.ts b/packages/ai/src/types/requests.ts index 6e5d2147686..991453c53f3 100644 --- a/packages/ai/src/types/requests.ts +++ b/packages/ai/src/types/requests.ts @@ -253,6 +253,47 @@ export interface RequestOptions { baseUrl?: string; } +/** + * Options that can be provided per-request. + * Extends the base {@link RequestOptions} (like `timeout` and `baseUrl`) + * with request-specific controls like cancellation via `AbortSignal`. + * + * Options specified here will override any default {@link RequestOptions} + * configured on a model (for example, {@link GenerativeModel}). + * + * @public + */ +export interface SingleRequestOptions extends RequestOptions { + /** + * An `AbortSignal` instance that allows cancelling ongoing requests (like `generateContent` or + * `generateImages`). + * + * If provided, calling `abort()` on the corresponding `AbortController` + * will attempt to cancel the underlying HTTP request. An `AbortError` will be thrown + * if cancellation is successful. + * + * Note that this will not cancel the request in the backend, so any applicable billing charges + * will still be applied despite cancellation. + * + * @example + * ```javascript + * const controller = new AbortController(); + * const model = getGenerativeModel({ + * // ... + * }); + * model.generateContent( + * "Write a story about a magic backpack.", + * { signal: controller.signal } + * ); + * + * // To cancel request: + * controller.abort(); + * ``` + * @see https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal + */ + signal?: AbortSignal; +} + /** * Defines a tool that model can call to access external knowledge. * @public
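For reference, a minimal consumer-side sketch of the merge semantics the tests above pin down: per-request `SingleRequestOptions` are spread over the model-level `RequestOptions`, so a per-request field wins while any field not overridden (here, `timeout`) carries through. This assumes the public `firebase/ai` entry point; the app config and model name are placeholders.

```typescript
import { initializeApp } from 'firebase/app';
import { getAI, getGenerativeModel } from 'firebase/ai';

async function run(): Promise<void> {
  const app = initializeApp({ /* project config */ });

  // Model-level RequestOptions become the defaults for every call.
  const model = getGenerativeModel(
    getAI(app),
    { model: 'gemini-2.0-flash' }, // placeholder model name
    { timeout: 60_000 }
  );

  const controller = new AbortController();

  // Per-request options are spread over the model's defaults:
  // { ...requestOptions, ...singleRequestOptions } yields { timeout: 60000, signal }.
  const pending = model.generateContent('Write a story about a magic backpack.', {
    signal: controller.signal
  });

  controller.abort(); // cancels the underlying fetch

  try {
    await pending;
  } catch (e) {
    console.log((e as DOMException).name); // "AbortError"
  }
}

run().catch(console.error);
```

As the `SingleRequestOptions` TSDoc notes, aborting only cancels the in-flight HTTP request; the backend may still process it, so billing can still apply.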
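The `makeRequest` change combines two abort sources: an internal controller driven by the timeout and the caller's optional `signal`, merged with `AbortSignal.any`. A standalone sketch of that pattern, assuming a Node 20+ or modern-browser runtime (the helper name `fetchWithTimeout` is hypothetical; the two constants mirror the ones added to request.ts):

```typescript
const TIMEOUT_EXPIRED_MESSAGE = 'Timeout has expired.';
const ABORT_ERROR_NAME = 'AbortError';

async function fetchWithTimeout(
  url: string,
  body: string,
  timeoutMillis: number,
  externalSignal?: AbortSignal
): Promise<Response> {
  const internalAbortController = new AbortController();
  const fetchTimeoutId = setTimeout(() => {
    // Same shape the SDK rejects with: an AbortError DOMException.
    internalAbortController.abort(
      new DOMException(TIMEOUT_EXPIRED_MESSAGE, ABORT_ERROR_NAME)
    );
  }, timeoutMillis);

  // Whichever source aborts first propagates to the fetch.
  const combinedSignal = AbortSignal.any(
    externalSignal
      ? [externalSignal, internalAbortController.signal]
      : [internalAbortController.signal]
  );

  // A signal that is already aborted never reaches fetch at all.
  if (externalSignal?.aborted) {
    clearTimeout(fetchTimeoutId);
    throw new DOMException(
      externalSignal.reason ?? 'Aborted before fetch',
      ABORT_ERROR_NAME
    );
  }

  try {
    return await fetch(url, { method: 'POST', body, signal: combinedSignal });
  } finally {
    // Runs when the response headers arrive, so for streaming responses the
    // timeout stops applying once the stream begins.
    clearTimeout(fetchTimeoutId);
  }
}
```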
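And the buffering rule the new stream-reader comment describes, sketched in isolation: decoded text accumulates until a complete `data: {json}` event (terminated by a blank line) can be matched, and any trailing partial event stays in the buffer for the next chunk. The regex follows the framing described in the comment; `drainSseBuffer` and the mutable buffer shape are hypothetical names, not the SDK's.

```typescript
// One complete SSE event is "data: {json}" terminated by a blank line
// (CR/LF variants included).
const responseLineRE = /^data\: (.*)(?:\n\n|\r\r|\r\n\r\n)/;

function* drainSseBuffer<T>(buffer: { text: string }): Generator<T> {
  let match = buffer.text.match(responseLineRE);
  while (match) {
    // Throws on malformed JSON, mirroring the SDK's hard failure on bad chunks.
    yield JSON.parse(match[1]) as T;
    buffer.text = buffer.text.substring(match[0].length);
    match = buffer.text.match(responseLineRE);
  }
}

// A complete event parses; the trailing partial event stays buffered.
const buf = { text: 'data: {"a":1}\n\ndata: {"b"' };
console.log([...drainSseBuffer<object>(buf)]); // [ { a: 1 } ]
console.log(buf.text); // 'data: {"b"'
```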