diff --git a/packages/openai-adapters/src/apis/Gemini.test.ts b/packages/openai-adapters/src/apis/Gemini.test.ts index 3b47620bc71..e1e81d6a743 100644 --- a/packages/openai-adapters/src/apis/Gemini.test.ts +++ b/packages/openai-adapters/src/apis/Gemini.test.ts @@ -1,7 +1,16 @@ -import { describe, expect, it } from "vitest"; +import { afterEach, describe, expect, it, vi } from "vitest"; import { GeminiApi } from "./Gemini.js"; +// Mock the fetch package so the embeddings test can stub the HTTP response. +vi.mock("@continuedev/fetch", async () => { + const actual = await vi.importActual("@continuedev/fetch"); + return { + ...actual, + fetchwithRequestOptions: vi.fn(), + }; +}); + describe("GeminiApi", () => { const api = new GeminiApi({ provider: "gemini", @@ -133,4 +142,43 @@ describe("GeminiApi", () => { expect(result.contents[2].role).toBe("user"); }); }); + + describe("embed", () => { + afterEach(() => { + vi.clearAllMocks(); + vi.unstubAllGlobals(); + }); + + it("parses the 'embeddings' field from the batchEmbedContents response", async () => { + const mockFetch = vi.fn().mockResolvedValue( + new Response( + JSON.stringify({ + embeddings: [ + { values: [0.1, 0.2, 0.3] }, + { values: [0.4, 0.5, 0.6] }, + ], + }), + { headers: { "Content-Type": "application/json" } }, + ), + ); + vi.stubGlobal("fetch", mockFetch); + const fetchPackage = await import("@continuedev/fetch"); + vi.mocked(fetchPackage.fetchwithRequestOptions).mockImplementation( + mockFetch as any, + ); + + const response = await api.embed({ + model: "gemini-embedding-001", + input: ["Hello", "World"], + }); + + expect(mockFetch).toHaveBeenCalledTimes(1); + const [url] = mockFetch.mock.calls[0]; + expect(url.toString()).toContain(":batchEmbedContents"); + expect(response.data.map((d) => d.embedding)).toEqual([ + [0.1, 0.2, 0.3], + [0.4, 0.5, 0.6], + ]); + }); + }); }); diff --git a/packages/openai-adapters/src/apis/Gemini.ts b/packages/openai-adapters/src/apis/Gemini.ts index f7298267849..fc573b76598 100644 --- a/packages/openai-adapters/src/apis/Gemini.ts +++ b/packages/openai-adapters/src/apis/Gemini.ts @@ -494,11 +494,7 @@ export class GeminiApi implements BaseLlmApi { const data = (await response.json()) as any; return embedding({ model: body.model, - usage: { - total_tokens: data.total_tokens, - prompt_tokens: data.prompt_tokens, - }, - data: data.batchEmbedContents.map((embedding: any) => embedding.values), + data: data.embeddings.map((embedding: any) => embedding.values), }); }