diff --git a/apps/web/src/app/api/openrouter/[...path]/route.test.ts b/apps/web/src/app/api/openrouter/[...path]/route.test.ts new file mode 100644 index 0000000000..054e9c46d4 --- /dev/null +++ b/apps/web/src/app/api/openrouter/[...path]/route.test.ts @@ -0,0 +1,175 @@ +import { getUserFromAuth } from '@/lib/user/server'; +import { getBalanceAndOrgSettings } from '@/lib/organizations/organization-usage'; +import { getProvider } from '@/lib/ai-gateway/providers/get-provider'; +import { upstreamRequest } from '@/lib/ai-gateway/providers/upstream-request'; +import { isPublicIdExperimented } from '@/lib/ai-gateway/experiments/membership'; +import { checkFreeModelRateLimit, logFreeModelRequest } from '@/lib/free-model-rate-limiter'; +import { classifyAbuse } from '@/lib/ai-gateway/abuse-service'; +import { emitApiMetricsForResponse } from '@/lib/ai-gateway/o11y/api-metrics.server'; +import { accountForMicrodollarUsage } from '@/lib/ai-gateway/llm-proxy-helpers'; +import { handleRequestLogging } from '@/lib/ai-gateway/handleRequestLogging'; +import type { User } from '@kilocode/db/schema'; + +jest.mock('@/lib/user/server'); +jest.mock('@/lib/organizations/organization-usage'); +jest.mock('@/lib/ai-gateway/providers/get-provider'); +jest.mock('@/lib/ai-gateway/providers/upstream-request'); +jest.mock('@/lib/ai-gateway/experiments/membership'); +jest.mock('@/lib/free-model-rate-limiter', () => ({ + checkFreeModelRateLimit: jest.fn(), + checkFreeModelRateLimitByUser: jest.fn(), + checkPromotionLimit: jest.fn(), + logFreeModelRequest: jest.fn(), +})); +jest.mock('@/lib/ai-gateway/abuse-service'); +jest.mock('@/lib/ai-gateway/o11y/api-metrics.server', () => ({ + emitApiMetricsForResponse: jest.fn(), + getToolsAvailable: jest.fn(() => []), + getToolsUsed: jest.fn(() => []), +})); +jest.mock('@/lib/ai-gateway/llm-proxy-helpers', () => { + const actual = jest.requireActual('@/lib/ai-gateway/llm-proxy-helpers'); + return { + ...actual, + accountForMicrodollarUsage: jest.fn(), + }; +}); +jest.mock('@/lib/ai-gateway/handleRequestLogging'); +jest.mock('@/lib/debugUtils', () => ({ + debugSaveProxyRequest: jest.fn(), + debugSaveProxyResponseStream: jest.fn(), +})); + +const publicModelId = 'kilo/preview-model'; +const upstreamInternalId = 'partner/secret-checkpoint-rc1'; + +const mockedGetUserFromAuth = jest.mocked(getUserFromAuth); +const mockedGetBalanceAndOrgSettings = jest.mocked(getBalanceAndOrgSettings); +const mockedGetProvider = jest.mocked(getProvider); +const mockedUpstreamRequest = jest.mocked(upstreamRequest); +const mockedIsPublicIdExperimented = jest.mocked(isPublicIdExperimented); +const mockedCheckFreeModelRateLimit = jest.mocked(checkFreeModelRateLimit); +const mockedLogFreeModelRequest = jest.mocked(logFreeModelRequest); +const mockedClassifyAbuse = jest.mocked(classifyAbuse); +const mockedEmitApiMetricsForResponse = jest.mocked(emitApiMetricsForResponse); +const mockedAccountForMicrodollarUsage = jest.mocked(accountForMicrodollarUsage); +const mockedHandleRequestLogging = jest.mocked(handleRequestLogging); + +function makeRequest(body: unknown) { + return new Request('http://localhost:3000/api/gateway/v1/chat/completions', { + method: 'POST', + headers: { + 'content-type': 'application/json', + 'x-forwarded-for': '127.0.0.1', + 'x-kilocode-machineid': 'machine-123', + }, + body: JSON.stringify(body), + }); +} + +function setExperimentRouting(upstreamResponse: Response) { + mockedGetUserFromAuth.mockResolvedValue({ + user: { + id: 'user-123', + google_user_email: 'test@example.com', + microdollars_used: 0, + } as User, + authFailedResponse: null, + organizationId: undefined, + }); + mockedGetBalanceAndOrgSettings.mockResolvedValue({ + balance: 1000, + settings: undefined, + plan: undefined, + }); + mockedIsPublicIdExperimented.mockResolvedValue(true); + mockedCheckFreeModelRateLimit.mockResolvedValue({ allowed: true, requestCount: 1 }); + mockedLogFreeModelRequest.mockResolvedValue(undefined); + mockedClassifyAbuse.mockResolvedValue(null); + mockedEmitApiMetricsForResponse.mockReturnValue(undefined); + mockedAccountForMicrodollarUsage.mockReturnValue(undefined); + mockedHandleRequestLogging.mockResolvedValue(undefined); + mockedUpstreamRequest.mockResolvedValue(upstreamResponse); + mockedGetProvider.mockResolvedValue({ + kind: 'provider', + provider: { + id: 'custom', + apiUrl: 'https://partner.example.test/v1', + apiKey: 'test-key-not-real', + supportedChatApis: ['chat_completions'], + transformRequest(context) { + context.request.body.model = upstreamInternalId; + }, + }, + userByok: null, + bypassAccessCheck: true, + skipProviderPin: true, + skipKiloExclusiveModelSettings: true, + experiment: { + experimentId: 'experiment-123', + variantId: 'variant-123', + variantVersionId: 'variant-version-123', + allocationSubject: 'machine', + }, + }); +} + +describe('POST /api/gateway/v1/chat/completions experiment response blinding', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('returns the requested public model id for experimented JSON responses', async () => { + setExperimentRouting( + new Response( + JSON.stringify({ + id: 'chatcmpl-test', + object: 'chat.completion', + model: upstreamInternalId, + choices: [], + }), + { headers: { 'content-type': 'application/json' } } + ) + ); + + const { POST } = await import('./route'); + const response = await POST( + makeRequest({ model: publicModelId, messages: [{ role: 'user', content: 'hi' }] }) as never + ); + + const text = await response.text(); + expect(text).toContain(publicModelId); + expect(text).not.toContain(upstreamInternalId); + expect(JSON.parse(text).model).toBe(publicModelId); + expect(mockedUpstreamRequest.mock.calls[0]?.[0].body.model).toBe(upstreamInternalId); + }); + + it('returns the requested public model id for experimented SSE responses', async () => { + setExperimentRouting( + new Response( + `data: ${JSON.stringify({ + id: 'chatcmpl-test', + object: 'chat.completion.chunk', + model: upstreamInternalId, + choices: [{ index: 0, delta: { content: 'hi' } }], + })}\n\ndata: [DONE]\n\n`, + { headers: { 'content-type': 'text/event-stream' } } + ) + ); + + const { POST } = await import('./route'); + const response = await POST( + makeRequest({ + model: publicModelId, + stream: true, + messages: [{ role: 'user', content: 'hi' }], + }) as never + ); + + const text = await response.text(); + expect(text).toContain(publicModelId); + expect(text).not.toContain(upstreamInternalId); + expect(text).toContain('data: [DONE]'); + expect(mockedUpstreamRequest.mock.calls[0]?.[0].body.model).toBe(upstreamInternalId); + }); +}); diff --git a/apps/web/src/app/api/openrouter/[...path]/route.ts b/apps/web/src/app/api/openrouter/[...path]/route.ts index ee49a96162..6fcf063c05 100644 --- a/apps/web/src/app/api/openrouter/[...path]/route.ts +++ b/apps/web/src/app/api/openrouter/[...path]/route.ts @@ -718,7 +718,7 @@ export async function POST(request: NextRequest): Promise { + it('rewrites chat completion JSON model ids to the requested public id', async () => { + const upstreamResponse = new Response( + JSON.stringify({ + id: 'chatcmpl-test', + object: 'chat.completion', + model: upstreamInternalId, + choices: [], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + }), + { headers: { 'content-type': 'application/json' } } + ); + + const rewritten = await rewriteFreeModelResponse_ChatCompletions( + upstreamResponse, + publicModelId + ); + + const text = await responseText(rewritten); + expect(text).toContain(publicModelId); + expect(text).not.toContain(upstreamInternalId); + expect(JSON.parse(text).model).toBe(publicModelId); + }); + + it('rewrites chat completion SSE model ids to the requested public id', async () => { + const upstreamResponse = sseResponse([ + `data: ${JSON.stringify({ + id: 'chatcmpl-test', + object: 'chat.completion.chunk', + model: upstreamInternalId, + choices: [{ index: 0, delta: { content: 'hello' } }], + })}\n\n`, + 'data: [DONE]\n\n', + ]); + + const rewritten = await rewriteFreeModelResponse_ChatCompletions( + upstreamResponse, + publicModelId + ); + + const text = await responseText(rewritten); + expect(text).toContain(publicModelId); + expect(text).not.toContain(upstreamInternalId); + expect(text).toContain('data: [DONE]'); + }); + + it('rewrites messages JSON model ids to the requested public id', async () => { + const upstreamResponse = new Response( + JSON.stringify({ + id: 'msg-test', + type: 'message', + role: 'assistant', + model: upstreamInternalId, + content: [], + stop_reason: 'end_turn', + stop_sequence: null, + usage: { input_tokens: 1, output_tokens: 1 }, + }), + { headers: { 'content-type': 'application/json' } } + ); + + const rewritten = await rewriteFreeModelResponse_Messages(upstreamResponse, publicModelId); + + const text = await responseText(rewritten); + expect(text).toContain(publicModelId); + expect(text).not.toContain(upstreamInternalId); + expect(JSON.parse(text).model).toBe(publicModelId); + }); + + it('rewrites messages SSE model ids to the requested public id', async () => { + const upstreamResponse = sseResponse([ + `event: message_start\ndata: ${JSON.stringify({ + type: 'message_start', + message: { + id: 'msg-test', + type: 'message', + role: 'assistant', + model: upstreamInternalId, + content: [], + stop_reason: null, + stop_sequence: null, + usage: { input_tokens: 1, output_tokens: 0 }, + }, + })}\n\n`, + `event: content_block_delta\ndata: ${JSON.stringify({ + type: 'content_block_delta', + index: 0, + delta: { type: 'text_delta', text: 'hello' }, + })}\n\n`, + ]); + + const rewritten = await rewriteFreeModelResponse_Messages(upstreamResponse, publicModelId); + + const text = await responseText(rewritten); + expect(text).toContain(publicModelId); + expect(text).not.toContain(upstreamInternalId); + expect(text).toContain('event: message_start'); + }); + + it('rewrites responses JSON model ids to the requested public id', async () => { + const upstreamResponse = new Response( + JSON.stringify({ + id: 'resp-test', + object: 'response', + model: upstreamInternalId, + status: 'completed', + }), + { headers: { 'content-type': 'application/json' } } + ); + + const rewritten = await rewriteFreeModelResponse_Responses(upstreamResponse, publicModelId); + + const text = await responseText(rewritten); + expect(text).toContain(publicModelId); + expect(text).not.toContain(upstreamInternalId); + expect(JSON.parse(text).model).toBe(publicModelId); + }); + + it('rewrites responses SSE model ids to the requested public id', async () => { + const upstreamResponse = sseResponse([ + `event: response.created\ndata: ${JSON.stringify({ + type: 'response.created', + response: { + id: 'resp-test', + object: 'response', + model: upstreamInternalId, + status: 'in_progress', + }, + })}\n\n`, + 'data: [DONE]\n\n', + ]); + + const rewritten = await rewriteFreeModelResponse_Responses(upstreamResponse, publicModelId); + + const text = await responseText(rewritten); + expect(text).toContain(publicModelId); + expect(text).not.toContain(upstreamInternalId); + expect(text).toContain('data: [DONE]'); + }); +}); diff --git a/apps/web/src/lib/rewriteModelResponse.ts b/apps/web/src/lib/rewriteModelResponse.ts index 1c15045bd0..0a723485df 100644 --- a/apps/web/src/lib/rewriteModelResponse.ts +++ b/apps/web/src/lib/rewriteModelResponse.ts @@ -59,6 +59,7 @@ export async function rewriteFreeModelResponse_ChatCompletions(response: Respons controller.close(); return; } + const encoder = new TextEncoder(); let doneReceived = false; const parser = createParser({ @@ -89,10 +90,10 @@ export async function rewriteFreeModelResponse_ChatCompletions(response: Respons rewriteUsage(json.usage); } - controller.enqueue('data: ' + JSON.stringify(json) + '\n\n'); + controller.enqueue(encoder.encode('data: ' + JSON.stringify(json) + '\n\n')); }, onComment() { - controller.enqueue(': KILO PROCESSING\n\n'); + controller.enqueue(encoder.encode(': KILO PROCESSING\n\n')); }, }); @@ -101,7 +102,7 @@ export async function rewriteFreeModelResponse_ChatCompletions(response: Respons const { done, value } = await reader.read(); if (done) { if (doneReceived) { - controller.enqueue('data: [DONE]\n\n'); + controller.enqueue(encoder.encode('data: [DONE]\n\n')); } controller.close(); break; @@ -179,6 +180,7 @@ export async function rewriteFreeModelResponse_Messages(response: Response, mode controller.close(); return; } + const encoder = new TextEncoder(); const parser = createParser({ onEvent(event: EventSourceMessage) { @@ -209,10 +211,10 @@ export async function rewriteFreeModelResponse_Messages(response: Response, mode } const eventLine = event.event ? 'event: ' + event.event + '\n' : ''; - controller.enqueue(eventLine + 'data: ' + JSON.stringify(json) + '\n\n'); + controller.enqueue(encoder.encode(eventLine + 'data: ' + JSON.stringify(json) + '\n\n')); }, onComment() { - controller.enqueue(': KILO PROCESSING\n\n'); + controller.enqueue(encoder.encode(': KILO PROCESSING\n\n')); }, }); @@ -278,6 +280,7 @@ export async function rewriteFreeModelResponse_Responses(response: Response, mod controller.close(); return; } + const encoder = new TextEncoder(); let doneReceived = false; const parser = createParser({ @@ -295,10 +298,10 @@ export async function rewriteFreeModelResponse_Responses(response: Response, mod rewriteUsage(json.response.usage); } } - controller.enqueue('data: ' + JSON.stringify(json) + '\n\n'); + controller.enqueue(encoder.encode('data: ' + JSON.stringify(json) + '\n\n')); }, onComment() { - controller.enqueue(': KILO PROCESSING\n\n'); + controller.enqueue(encoder.encode(': KILO PROCESSING\n\n')); }, }); @@ -307,7 +310,7 @@ export async function rewriteFreeModelResponse_Responses(response: Response, mod const { done, value } = await reader.read(); if (done) { if (doneReceived) { - controller.enqueue('data: [DONE]\n\n'); + controller.enqueue(encoder.encode('data: [DONE]\n\n')); } controller.close(); break;