diff --git a/packages/appkit/src/connectors/genie/tests/client.test.ts b/packages/appkit/src/connectors/genie/tests/client.test.ts new file mode 100644 index 000000000..62fc3578d --- /dev/null +++ b/packages/appkit/src/connectors/genie/tests/client.test.ts @@ -0,0 +1,786 @@ +import type { GenieMessage } from "@databricks/sdk-experimental/dist/apis/dashboards"; +import { beforeEach, describe, expect, test, vi } from "vitest"; +import { GenieConnector } from "../client"; +import type { GenieStreamEvent } from "../types"; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +async function collect( + gen: AsyncGenerator, +): Promise { + const events: GenieStreamEvent[] = []; + for await (const event of gen) { + events.push(event); + } + return events; +} + +function makeGenieMessage(overrides: Partial = {}): GenieMessage { + return { + message_id: "msg-1", + conversation_id: "conv-1", + space_id: "space-1", + status: "COMPLETED", + content: "Hello from Genie", + attachments: [], + ...overrides, + } as GenieMessage; +} + +function makeGenieMessageWithQuery( + overrides: Partial = {}, +): GenieMessage { + return makeGenieMessage({ + attachments: [ + { + attachment_id: "att-1", + query: { + title: "Sales Query", + description: "Total sales", + query: "SELECT sum(amount) FROM sales", + statement_id: "stmt-1", + }, + }, + ], + ...overrides, + }); +} + +/** Creates a mock WorkspaceClient with genie methods stubbed. */ +function createMockWorkspaceClient() { + return { + genie: { + startConversation: vi.fn(), + createMessage: vi.fn(), + getMessage: vi.fn(), + listConversationMessages: vi.fn(), + getMessageAttachmentQueryResult: vi.fn(), + }, + } as any; +} + +/** + * Builds a mock waiter whose `.wait()` invokes `onProgress` for each + * progress value, then resolves with the final result. + */ +function createMockWaiter(opts: { + progressValues?: Partial[]; + result: GenieMessage; +}) { + return { + wait: vi.fn().mockImplementation(async (options: any = {}) => { + if (opts.progressValues) { + for (const value of opts.progressValues) { + if (options.onProgress) { + await options.onProgress(value); + } + } + } + return opts.result; + }), + message_id: opts.result.message_id, + conversation_id: opts.result.conversation_id, + }; +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +describe("GenieConnector", () => { + let connector: GenieConnector; + let ws: ReturnType; + + beforeEach(() => { + connector = new GenieConnector({ timeout: 0 }); + ws = createMockWorkspaceClient(); + }); + + // ----------------------------------------------------------------------- + // streamSendMessage + // ----------------------------------------------------------------------- + + describe("streamSendMessage", () => { + test("yields message_start, status updates, then message_result", async () => { + const completedMsg = makeGenieMessage(); + const waiter = createMockWaiter({ + progressValues: [ + { status: "EXECUTING_QUERY" }, + { status: "COMPLETED" }, + ], + result: completedMsg, + }); + ws.genie.startConversation.mockResolvedValue(waiter); + + const events = await collect( + connector.streamSendMessage( + ws, + "space-1", + "What are sales?", + undefined, + ), + ); + + expect(events[0]).toEqual({ + type: "message_start", + conversationId: "conv-1", + messageId: "msg-1", + spaceId: "space-1", + }); + + const statusEvents = events.filter((e) => e.type === "status"); + expect(statusEvents).toEqual([ + { type: "status", status: "EXECUTING_QUERY" }, + { type: "status", status: "COMPLETED" }, + ]); + + const msgResult = events.find((e) => e.type === "message_result"); + expect(msgResult).toBeDefined(); + expect((msgResult as any).message.messageId).toBe("msg-1"); + }); + + test("new conversation calls startConversation", async () => { + const completedMsg = makeGenieMessage(); + const waiter = createMockWaiter({ result: completedMsg }); + ws.genie.startConversation.mockResolvedValue(waiter); + + await collect( + connector.streamSendMessage(ws, "space-1", "hello", undefined), + ); + + expect(ws.genie.startConversation).toHaveBeenCalledWith({ + space_id: "space-1", + content: "hello", + }); + expect(ws.genie.createMessage).not.toHaveBeenCalled(); + }); + + test("existing conversation calls createMessage", async () => { + const completedMsg = makeGenieMessage(); + const waiter = createMockWaiter({ result: completedMsg }); + ws.genie.createMessage.mockResolvedValue(waiter); + + await collect( + connector.streamSendMessage(ws, "space-1", "hello", "conv-existing"), + ); + + expect(ws.genie.createMessage).toHaveBeenCalledWith({ + space_id: "space-1", + conversation_id: "conv-existing", + content: "hello", + }); + expect(ws.genie.startConversation).not.toHaveBeenCalled(); + }); + + test("emits query_result for attachments with statementIds", async () => { + const completedMsg = makeGenieMessageWithQuery(); + const waiter = createMockWaiter({ result: completedMsg }); + ws.genie.startConversation.mockResolvedValue(waiter); + + const statementResponse = { + manifest: { + schema: { columns: [{ name: "total", type_name: "DOUBLE" }] }, + }, + result: { data_array: [["1234.56"]] }, + }; + ws.genie.getMessageAttachmentQueryResult.mockResolvedValue({ + statement_response: statementResponse, + }); + + const events = await collect( + connector.streamSendMessage(ws, "space-1", "query", undefined), + ); + + const queryResult = events.find((e) => e.type === "query_result"); + expect(queryResult).toEqual({ + type: "query_result", + attachmentId: "att-1", + statementId: "stmt-1", + data: statementResponse, + }); + }); + + test("yields error event on SDK failure", async () => { + ws.genie.startConversation.mockRejectedValue( + new Error("Network timeout"), + ); + + const events = await collect( + connector.streamSendMessage(ws, "space-1", "hello", undefined), + ); + + expect(events).toEqual([{ type: "error", error: "Network timeout" }]); + }); + + test("classifies RESOURCE_DOES_NOT_EXIST as access denied", async () => { + ws.genie.startConversation.mockRejectedValue( + new Error("RESOURCE_DOES_NOT_EXIST: space not found"), + ); + + const events = await collect( + connector.streamSendMessage(ws, "space-1", "hello", undefined), + ); + + expect(events).toEqual([ + { + type: "error", + error: "You don't have access to this Genie Space.", + }, + ]); + }); + + test("emits error event when query result fetch fails", async () => { + const completedMsg = makeGenieMessageWithQuery(); + const waiter = createMockWaiter({ result: completedMsg }); + ws.genie.startConversation.mockResolvedValue(waiter); + ws.genie.getMessageAttachmentQueryResult.mockRejectedValue( + new Error("statement expired"), + ); + + const events = await collect( + connector.streamSendMessage(ws, "space-1", "query", undefined), + ); + + const errorEvent = events.find((e) => e.type === "error"); + expect(errorEvent).toEqual({ + type: "error", + error: "Failed to fetch query result for attachment att-1", + }); + }); + }); + + // ----------------------------------------------------------------------- + // streamConversation + // ----------------------------------------------------------------------- + + describe("streamConversation", () => { + test("yields message_result for each message, then history_info", async () => { + ws.genie.listConversationMessages.mockResolvedValue({ + messages: [ + makeGenieMessage({ message_id: "m1", content: "first" }), + makeGenieMessage({ message_id: "m2", content: "second" }), + ], + next_page_token: null, + }); + + const events = await collect( + connector.streamConversation(ws, "space-1", "conv-1", { + includeQueryResults: false, + }), + ); + + const messageResults = events.filter((e) => e.type === "message_result"); + expect(messageResults).toHaveLength(2); + + const historyInfo = events.find((e) => e.type === "history_info"); + expect(historyInfo).toEqual({ + type: "history_info", + conversationId: "conv-1", + spaceId: "space-1", + nextPageToken: null, + loadedCount: 2, + }); + }); + + test("fetches query results in parallel when includeQueryResults=true", async () => { + ws.genie.listConversationMessages.mockResolvedValue({ + messages: [ + makeGenieMessageWithQuery({ + message_id: "m1", + attachments: [ + { + attachment_id: "att-a", + query: { + title: "Q1", + query: "SELECT 1", + statement_id: "stmt-a", + }, + }, + { + attachment_id: "att-b", + query: { + title: "Q2", + query: "SELECT 2", + statement_id: "stmt-b", + }, + }, + ], + }), + ], + next_page_token: null, + }); + + const stmtResponse = { + manifest: { schema: { columns: [] } }, + result: { data_array: [] }, + }; + ws.genie.getMessageAttachmentQueryResult.mockResolvedValue({ + statement_response: stmtResponse, + }); + + const events = await collect( + connector.streamConversation(ws, "space-1", "conv-1", { + includeQueryResults: true, + }), + ); + + const queryResults = events.filter((e) => e.type === "query_result"); + expect(queryResults).toHaveLength(2); + expect(ws.genie.getMessageAttachmentQueryResult).toHaveBeenCalledTimes(2); + }); + + test("skips query results when includeQueryResults=false", async () => { + ws.genie.listConversationMessages.mockResolvedValue({ + messages: [makeGenieMessageWithQuery()], + next_page_token: null, + }); + + const events = await collect( + connector.streamConversation(ws, "space-1", "conv-1", { + includeQueryResults: false, + }), + ); + + expect(events.filter((e) => e.type === "query_result")).toHaveLength(0); + expect(ws.genie.getMessageAttachmentQueryResult).not.toHaveBeenCalled(); + }); + + test("handles partial query result failures via Promise.allSettled", async () => { + ws.genie.listConversationMessages.mockResolvedValue({ + messages: [ + makeGenieMessage({ + message_id: "m1", + attachments: [ + { + attachment_id: "att-ok", + query: { + title: "OK", + query: "SELECT 1", + statement_id: "stmt-ok", + }, + }, + { + attachment_id: "att-fail", + query: { + title: "Fail", + query: "SELECT 2", + statement_id: "stmt-fail", + }, + }, + ], + }), + ], + next_page_token: null, + }); + + const stmtResponse = { + manifest: { schema: { columns: [] } }, + result: { data_array: [] }, + }; + + ws.genie.getMessageAttachmentQueryResult + .mockResolvedValueOnce({ statement_response: stmtResponse }) + .mockRejectedValueOnce(new Error("statement expired")); + + const events = await collect( + connector.streamConversation(ws, "space-1", "conv-1", { + includeQueryResults: true, + }), + ); + + const queryResults = events.filter((e) => e.type === "query_result"); + expect(queryResults).toHaveLength(1); + + const errors = events.filter((e) => e.type === "error"); + expect(errors).toHaveLength(1); + expect((errors[0] as any).error).toBe("statement expired"); + }); + + test("yields error when listConversationMessages fails", async () => { + ws.genie.listConversationMessages.mockRejectedValue( + new Error("RESOURCE_DOES_NOT_EXIST: conv not found"), + ); + + const events = await collect( + connector.streamConversation(ws, "space-1", "conv-1"), + ); + + expect(events).toEqual([ + { + type: "error", + error: "You don't have access to this Genie Space.", + }, + ]); + }); + }); + + // ----------------------------------------------------------------------- + // streamGetMessage + // ----------------------------------------------------------------------- + + describe("streamGetMessage", () => { + test("polls until COMPLETED, yields status + message_result", async () => { + ws.genie.getMessage + .mockResolvedValueOnce(makeGenieMessage({ status: "EXECUTING_QUERY" })) + .mockResolvedValueOnce(makeGenieMessage({ status: "COMPLETED" })); + + const events = await collect( + connector.streamGetMessage(ws, "space-1", "conv-1", "msg-1", { + pollInterval: 0, + }), + ); + + expect(events[0]).toEqual({ + type: "status", + status: "EXECUTING_QUERY", + }); + expect(events[1]).toEqual({ type: "status", status: "COMPLETED" }); + expect(events[2]).toMatchObject({ type: "message_result" }); + expect(ws.genie.getMessage).toHaveBeenCalledTimes(2); + }); + + test("polls until FAILED, yields status + message_result", async () => { + ws.genie.getMessage + .mockResolvedValueOnce(makeGenieMessage({ status: "EXECUTING_QUERY" })) + .mockResolvedValueOnce( + makeGenieMessage({ + status: "FAILED", + error: { error: "query timed out" }, + }), + ); + + const events = await collect( + connector.streamGetMessage(ws, "space-1", "conv-1", "msg-1", { + pollInterval: 0, + }), + ); + + const statusEvents = events.filter((e) => e.type === "status"); + expect(statusEvents).toEqual([ + { type: "status", status: "EXECUTING_QUERY" }, + { type: "status", status: "FAILED" }, + ]); + + const msgResult = events.find((e) => e.type === "message_result") as any; + expect(msgResult.message.status).toBe("FAILED"); + expect(msgResult.message.error).toBe("query timed out"); + }); + + test("respects abort signal", async () => { + const controller = new AbortController(); + + ws.genie.getMessage.mockResolvedValue( + makeGenieMessage({ status: "EXECUTING_QUERY" }), + ); + + const gen = connector.streamGetMessage(ws, "space-1", "conv-1", "msg-1", { + pollInterval: 50, + signal: controller.signal, + }); + + const events: GenieStreamEvent[] = []; + // Collect the first status event, then abort + for await (const event of gen) { + events.push(event); + if (events.length === 1) { + controller.abort(); + } + } + + // Should have stopped after abort - at most 2 events + // (the status from poll 1, and possibly status from poll 2 that was already in-flight) + expect(events.length).toBeLessThanOrEqual(2); + expect(events[0]).toEqual({ + type: "status", + status: "EXECUTING_QUERY", + }); + }); + + test("yields error when getMessage throws", async () => { + ws.genie.getMessage.mockRejectedValue(new Error("service unavailable")); + + const events = await collect( + connector.streamGetMessage(ws, "space-1", "conv-1", "msg-1", { + pollInterval: 0, + }), + ); + + expect(events).toEqual([{ type: "error", error: "service unavailable" }]); + }); + + test("does not duplicate status events for same status", async () => { + ws.genie.getMessage + .mockResolvedValueOnce(makeGenieMessage({ status: "EXECUTING_QUERY" })) + .mockResolvedValueOnce(makeGenieMessage({ status: "EXECUTING_QUERY" })) + .mockResolvedValueOnce(makeGenieMessage({ status: "COMPLETED" })); + + const events = await collect( + connector.streamGetMessage(ws, "space-1", "conv-1", "msg-1", { + pollInterval: 0, + }), + ); + + const statusEvents = events.filter((e) => e.type === "status"); + expect(statusEvents).toEqual([ + { type: "status", status: "EXECUTING_QUERY" }, + { type: "status", status: "COMPLETED" }, + ]); + }); + }); + + // ----------------------------------------------------------------------- + // sendMessage + // ----------------------------------------------------------------------- + + describe("sendMessage", () => { + test("returns completed message response", async () => { + const completedMsg = makeGenieMessage({ + message_id: "msg-42", + conversation_id: "conv-new", + }); + const waiter = createMockWaiter({ result: completedMsg }); + ws.genie.startConversation.mockResolvedValue(waiter); + + const result = await connector.sendMessage( + ws, + "space-1", + "What are sales?", + undefined, + ); + + expect(result.messageId).toBe("msg-42"); + expect(result.conversationId).toBe("conv-new"); + expect(result.status).toBe("COMPLETED"); + }); + }); + + // ----------------------------------------------------------------------- + // getConversation + // ----------------------------------------------------------------------- + + describe("getConversation", () => { + test("paginates through all pages", async () => { + // listConversationMessages reverses the SDK response, so mock data + // is ordered newest-first (as the SDK returns) and results are + // oldest-first after reversal. + ws.genie.listConversationMessages + .mockResolvedValueOnce({ + messages: [ + makeGenieMessage({ message_id: "m2" }), + makeGenieMessage({ message_id: "m1" }), + ], + next_page_token: "page2", + }) + .mockResolvedValueOnce({ + messages: [makeGenieMessage({ message_id: "m3" })], + next_page_token: null, + }); + + const result = await connector.getConversation(ws, "space-1", "conv-1"); + + expect(result.messages).toHaveLength(3); + expect(result.messages.map((m) => m.messageId)).toEqual([ + "m1", + "m2", + "m3", + ]); + expect(ws.genie.listConversationMessages).toHaveBeenCalledTimes(2); + }); + + test("respects maxMessages limit", async () => { + const smallConnector = new GenieConnector({ + timeout: 0, + maxMessages: 2, + }); + + ws.genie.listConversationMessages.mockResolvedValueOnce({ + messages: [ + makeGenieMessage({ message_id: "m1" }), + makeGenieMessage({ message_id: "m2" }), + makeGenieMessage({ message_id: "m3" }), + ], + next_page_token: "page2", + }); + + const result = await smallConnector.getConversation( + ws, + "space-1", + "conv-1", + ); + + // Should be sliced to maxMessages + expect(result.messages).toHaveLength(2); + // Should NOT fetch a second page since length already >= maxMessages + expect(ws.genie.listConversationMessages).toHaveBeenCalledTimes(1); + }); + }); + + // ----------------------------------------------------------------------- + // mapAttachments (tested indirectly via toMessageResponse) + // ----------------------------------------------------------------------- + + describe("mapAttachments", () => { + test("handles query attachments", async () => { + const msg = makeGenieMessageWithQuery(); + const waiter = createMockWaiter({ result: msg }); + ws.genie.startConversation.mockResolvedValue(waiter); + + // We drive through streamSendMessage to exercise mapAttachments + ws.genie.getMessageAttachmentQueryResult.mockResolvedValue({ + statement_response: { + manifest: { schema: { columns: [] } }, + result: { data_array: [] }, + }, + }); + + const events = await collect( + connector.streamSendMessage(ws, "space-1", "q", undefined), + ); + + const msgResult = events.find((e) => e.type === "message_result") as any; + expect(msgResult.message.attachments[0]).toEqual({ + attachmentId: "att-1", + query: { + title: "Sales Query", + description: "Total sales", + query: "SELECT sum(amount) FROM sales", + statementId: "stmt-1", + }, + text: undefined, + suggestedQuestions: undefined, + }); + }); + + test("handles text attachments", async () => { + const msg = makeGenieMessage({ + attachments: [ + { + attachment_id: "att-text", + text: { content: "Here is the explanation" }, + }, + ], + }); + const waiter = createMockWaiter({ result: msg }); + ws.genie.startConversation.mockResolvedValue(waiter); + + const events = await collect( + connector.streamSendMessage(ws, "space-1", "q", undefined), + ); + + const msgResult = events.find((e) => e.type === "message_result") as any; + expect(msgResult.message.attachments[0]).toEqual({ + attachmentId: "att-text", + query: undefined, + text: { content: "Here is the explanation" }, + suggestedQuestions: undefined, + }); + }); + + test("handles suggestedQuestions attachments", async () => { + const msg = makeGenieMessage({ + attachments: [ + { + attachment_id: "att-sq", + suggested_questions: { + questions: ["What is X?", "Show me Y"], + }, + }, + ], + }); + const waiter = createMockWaiter({ result: msg }); + ws.genie.startConversation.mockResolvedValue(waiter); + + const events = await collect( + connector.streamSendMessage(ws, "space-1", "q", undefined), + ); + + const msgResult = events.find((e) => e.type === "message_result") as any; + expect(msgResult.message.attachments[0]).toEqual({ + attachmentId: "att-sq", + query: undefined, + text: undefined, + suggestedQuestions: ["What is X?", "Show me Y"], + }); + }); + + test("returns empty array when message has no attachments", async () => { + const msg = makeGenieMessage({ attachments: undefined }); + const waiter = createMockWaiter({ result: msg }); + ws.genie.startConversation.mockResolvedValue(waiter); + + const events = await collect( + connector.streamSendMessage(ws, "space-1", "q", undefined), + ); + + const msgResult = events.find((e) => e.type === "message_result") as any; + expect(msgResult.message.attachments).toEqual([]); + }); + }); + + // ----------------------------------------------------------------------- + // classifyGenieError (tested indirectly via error events) + // ----------------------------------------------------------------------- + + describe("classifyGenieError", () => { + test("maps RESOURCE_DOES_NOT_EXIST to space access denied", async () => { + ws.genie.startConversation.mockRejectedValue( + new Error("RESOURCE_DOES_NOT_EXIST: space xyz"), + ); + + const events = await collect( + connector.streamSendMessage(ws, "space-1", "hi", undefined), + ); + + expect(events[0]).toEqual({ + type: "error", + error: "You don't have access to this Genie Space.", + }); + }); + + test("maps failed-to-reach-COMPLETED + FAILED to table permissions", async () => { + ws.genie.startConversation.mockRejectedValue( + new Error("failed to reach COMPLETED state, got FAILED"), + ); + + const events = await collect( + connector.streamSendMessage(ws, "space-1", "hi", undefined), + ); + + expect(events[0]).toEqual({ + type: "error", + error: + "You may not have access to the data tables. Please verify your table permissions.", + }); + }); + + test("passes through unknown error messages", async () => { + ws.genie.startConversation.mockRejectedValue( + new Error("something unexpected"), + ); + + const events = await collect( + connector.streamSendMessage(ws, "space-1", "hi", undefined), + ); + + expect(events[0]).toEqual({ + type: "error", + error: "something unexpected", + }); + }); + + test("handles non-Error throwable", async () => { + ws.genie.startConversation.mockRejectedValue("string error"); + + const events = await collect( + connector.streamSendMessage(ws, "space-1", "hi", undefined), + ); + + expect(events[0]).toEqual({ + type: "error", + error: "string error", + }); + }); + }); +}); diff --git a/packages/appkit/src/context/tests/service-context.test.ts b/packages/appkit/src/context/tests/service-context.test.ts new file mode 100644 index 000000000..e8610da14 --- /dev/null +++ b/packages/appkit/src/context/tests/service-context.test.ts @@ -0,0 +1,457 @@ +import { setupDatabricksEnv } from "@tools/test-helpers"; +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; +import { + AuthenticationError, + ConfigurationError, + InitializationError, +} from "../../errors"; +import { ServiceContext } from "../service-context"; + +// ── Mock @databricks/sdk-experimental ────────────────────────────── + +const { mockMe, mockApiRequest, MockWorkspaceClient } = vi.hoisted(() => { + const mockMe = vi.fn(); + const mockApiRequest = vi.fn(); + + const MockWorkspaceClient = vi.fn().mockImplementation(() => ({ + currentUser: { me: mockMe }, + apiClient: { request: mockApiRequest }, + })); + + return { mockMe, mockApiRequest, MockWorkspaceClient }; +}); + +vi.mock("@databricks/sdk-experimental", () => ({ + WorkspaceClient: MockWorkspaceClient, +})); + +// ── Helpers ──────────────────────────────────────────────────────── + +function setupDefaultMocks() { + mockMe.mockResolvedValue({ id: "service-user-123" }); + mockApiRequest.mockResolvedValue({ "x-databricks-org-id": "ws-456" }); +} + +// ── Tests ────────────────────────────────────────────────────────── + +describe("ServiceContext", () => { + const originalEnv = { ...process.env }; + + beforeEach(() => { + vi.clearAllMocks(); + ServiceContext.reset(); + setupDatabricksEnv(); + setupDefaultMocks(); + }); + + afterEach(() => { + process.env = { ...originalEnv }; + ServiceContext.reset(); + }); + + // ── initialize() ─────────────────────────────────────────────── + + describe("initialize()", () => { + test("should initialize with a pre-configured client", async () => { + const client = new MockWorkspaceClient() as any; + + const state = await ServiceContext.initialize({}, client); + + expect(state.client).toBe(client); + expect(state.serviceUserId).toBe("service-user-123"); + expect(await state.workspaceId).toBe("ws-456"); + }); + + test("should create a WorkspaceClient when none is provided", async () => { + await ServiceContext.initialize(); + + // The mock constructor is called once internally + expect(MockWorkspaceClient).toHaveBeenCalled(); + }); + + test("should resolve warehouseId when options.warehouseId is true", async () => { + process.env.DATABRICKS_WAREHOUSE_ID = "wh-789"; + + const state = await ServiceContext.initialize({ warehouseId: true }); + + expect(state.warehouseId).toBeDefined(); + expect(await state.warehouseId).toBe("wh-789"); + }); + + test("should not set warehouseId when options.warehouseId is false", async () => { + const state = await ServiceContext.initialize({ warehouseId: false }); + + expect(state.warehouseId).toBeUndefined(); + }); + + test("should not set warehouseId when options are omitted", async () => { + const state = await ServiceContext.initialize(); + + expect(state.warehouseId).toBeUndefined(); + }); + + test("should throw when currentUser.me() returns no id", async () => { + mockMe.mockResolvedValue({}); + + await expect(ServiceContext.initialize()).rejects.toThrow( + ConfigurationError, + ); + }); + + test("should be idempotent - calling twice returns same instance", async () => { + const state1 = await ServiceContext.initialize(); + const state2 = await ServiceContext.initialize(); + + expect(state1).toBe(state2); + }); + + test("concurrent calls return the same promise", async () => { + const p1 = ServiceContext.initialize(); + const p2 = ServiceContext.initialize(); + + const [state1, state2] = await Promise.all([p1, p2]); + + expect(state1).toBe(state2); + // currentUser.me should only be called once regardless of concurrent calls + expect(mockMe).toHaveBeenCalledTimes(1); + }); + }); + + // ── get() ────────────────────────────────────────────────────── + + describe("get()", () => { + test("should throw InitializationError when not initialized", () => { + expect(() => ServiceContext.get()).toThrow(InitializationError); + expect(() => ServiceContext.get()).toThrow( + /ServiceContext not initialized/, + ); + }); + + test("should return state after initialization", async () => { + const state = await ServiceContext.initialize(); + const retrieved = ServiceContext.get(); + + expect(retrieved).toBe(state); + }); + }); + + // ── isInitialized() ──────────────────────────────────────────── + + describe("isInitialized()", () => { + test("should return false before initialization", () => { + expect(ServiceContext.isInitialized()).toBe(false); + }); + + test("should return true after initialization", async () => { + await ServiceContext.initialize(); + + expect(ServiceContext.isInitialized()).toBe(true); + }); + + test("should return false after reset()", async () => { + await ServiceContext.initialize(); + ServiceContext.reset(); + + expect(ServiceContext.isInitialized()).toBe(false); + }); + }); + + // ── createUserContext() ──────────────────────────────────────── + + describe("createUserContext()", () => { + beforeEach(async () => { + await ServiceContext.initialize({ warehouseId: true }); + }); + + test("should create a user context with correct properties", () => { + const userCtx = ServiceContext.createUserContext( + "user-token-abc", + "user-42", + "Alice", + ); + + expect(userCtx.userId).toBe("user-42"); + expect(userCtx.userName).toBe("Alice"); + expect(userCtx.isUserContext).toBe(true); + expect(userCtx.client).toBeDefined(); + }); + + test("should share warehouseId and workspaceId from service context", async () => { + process.env.DATABRICKS_WAREHOUSE_ID = "wh-shared"; + + // Re-initialize with the new env + ServiceContext.reset(); + mockApiRequest.mockResolvedValue({ "x-databricks-org-id": "ws-shared" }); + await ServiceContext.initialize({ warehouseId: true }); + + const userCtx = ServiceContext.createUserContext("user-token", "user-1"); + + const serviceCtx = ServiceContext.get(); + expect(userCtx.warehouseId).toBe(serviceCtx.warehouseId); + expect(userCtx.workspaceId).toBe(serviceCtx.workspaceId); + }); + + test("should create user client with PAT authType", () => { + ServiceContext.createUserContext("user-token", "user-1"); + + // The last call to MockWorkspaceClient should be for the user client + const lastCall = + MockWorkspaceClient.mock.calls[ + MockWorkspaceClient.mock.calls.length - 1 + ]; + expect(lastCall[0]).toMatchObject({ + token: "user-token", + host: process.env.DATABRICKS_HOST, + authType: "pat", + }); + }); + + test("should handle missing userName gracefully", () => { + const userCtx = ServiceContext.createUserContext("user-token", "user-1"); + + expect(userCtx.userName).toBeUndefined(); + }); + + test("should throw AuthenticationError on missing token", () => { + expect(() => ServiceContext.createUserContext("", "user-1")).toThrow( + AuthenticationError, + ); + }); + + test("should throw ConfigurationError when DATABRICKS_HOST is not set", () => { + delete process.env.DATABRICKS_HOST; + + expect(() => ServiceContext.createUserContext("token", "user-1")).toThrow( + ConfigurationError, + ); + }); + + test("should throw InitializationError when service context is not initialized", () => { + ServiceContext.reset(); + + expect(() => ServiceContext.createUserContext("token", "user-1")).toThrow( + InitializationError, + ); + }); + }); + + // ── reset() ──────────────────────────────────────────────────── + + describe("reset()", () => { + test("should clear the singleton state", async () => { + await ServiceContext.initialize(); + expect(ServiceContext.isInitialized()).toBe(true); + + ServiceContext.reset(); + + expect(ServiceContext.isInitialized()).toBe(false); + expect(() => ServiceContext.get()).toThrow(InitializationError); + }); + + test("should allow re-initialization after reset", async () => { + await ServiceContext.initialize(); + ServiceContext.reset(); + + mockMe.mockResolvedValue({ id: "new-service-user" }); + const state = await ServiceContext.initialize(); + + expect(state.serviceUserId).toBe("new-service-user"); + }); + }); + + // ── getWorkspaceId() (private, tested via initialize) ───────── + + describe("getWorkspaceId()", () => { + test("should use DATABRICKS_WORKSPACE_ID env var when set", async () => { + process.env.DATABRICKS_WORKSPACE_ID = "env-ws-123"; + + const state = await ServiceContext.initialize(); + + expect(await state.workspaceId).toBe("env-ws-123"); + // Should not call the SCIM API when env var is set + expect(mockApiRequest).not.toHaveBeenCalledWith( + expect.objectContaining({ path: "/api/2.0/preview/scim/v2/Me" }), + ); + }); + + test("should call SCIM API when env var is not set", async () => { + delete process.env.DATABRICKS_WORKSPACE_ID; + mockApiRequest.mockResolvedValue({ + "x-databricks-org-id": "scim-ws-789", + }); + + const state = await ServiceContext.initialize(); + + expect(await state.workspaceId).toBe("scim-ws-789"); + expect(mockApiRequest).toHaveBeenCalledWith( + expect.objectContaining({ + path: "/api/2.0/preview/scim/v2/Me", + method: "GET", + responseHeaders: ["x-databricks-org-id"], + }), + ); + }); + + test("should throw when SCIM API returns no workspace ID", async () => { + delete process.env.DATABRICKS_WORKSPACE_ID; + mockApiRequest.mockResolvedValue({}); + + const state = await ServiceContext.initialize(); + + await expect(state.workspaceId).rejects.toThrow(ConfigurationError); + }); + }); + + // ── getWarehouseId() (private, tested via initialize) ───────── + + describe("getWarehouseId()", () => { + test("should use DATABRICKS_WAREHOUSE_ID env var when set", async () => { + process.env.DATABRICKS_WAREHOUSE_ID = "env-wh-abc"; + + const state = await ServiceContext.initialize({ warehouseId: true }); + + expect(await state.warehouseId).toBe("env-wh-abc"); + }); + + test("should auto-discover warehouse in development mode", async () => { + delete process.env.DATABRICKS_WAREHOUSE_ID; + process.env.NODE_ENV = "development"; + + mockApiRequest.mockImplementation(({ path }: { path: string }) => { + if (path === "/api/2.0/sql/warehouses") { + return Promise.resolve({ + warehouses: [ + { id: "wh-stopped", state: "STOPPED" }, + { id: "wh-running", state: "RUNNING" }, + { id: "wh-starting", state: "STARTING" }, + ], + }); + } + // SCIM response for workspaceId + return Promise.resolve({ "x-databricks-org-id": "ws-dev" }); + }); + + const state = await ServiceContext.initialize({ warehouseId: true }); + + // Should pick RUNNING warehouse (highest priority) + expect(await state.warehouseId).toBe("wh-running"); + }); + + test("should sort warehouses by state priority in dev mode", async () => { + delete process.env.DATABRICKS_WAREHOUSE_ID; + process.env.NODE_ENV = "development"; + + mockApiRequest.mockImplementation(({ path }: { path: string }) => { + if (path === "/api/2.0/sql/warehouses") { + return Promise.resolve({ + warehouses: [ + { id: "wh-stopping", state: "STOPPING" }, + { id: "wh-starting", state: "STARTING" }, + { id: "wh-stopped", state: "STOPPED" }, + ], + }); + } + return Promise.resolve({ "x-databricks-org-id": "ws-dev" }); + }); + + const state = await ServiceContext.initialize({ warehouseId: true }); + + // STOPPED (priority 1) < STARTING (priority 2) < STOPPING (priority 3) + expect(await state.warehouseId).toBe("wh-stopped"); + }); + + test("should throw in dev mode when no warehouses are available", async () => { + delete process.env.DATABRICKS_WAREHOUSE_ID; + process.env.NODE_ENV = "development"; + + mockApiRequest.mockImplementation(({ path }: { path: string }) => { + if (path === "/api/2.0/sql/warehouses") { + return Promise.resolve({ warehouses: [] }); + } + return Promise.resolve({ "x-databricks-org-id": "ws-dev" }); + }); + + const state = await ServiceContext.initialize({ warehouseId: true }); + + await expect(state.warehouseId).rejects.toThrow(ConfigurationError); + }); + + test("should throw in dev mode when all warehouses are deleted", async () => { + delete process.env.DATABRICKS_WAREHOUSE_ID; + process.env.NODE_ENV = "development"; + + mockApiRequest.mockImplementation(({ path }: { path: string }) => { + if (path === "/api/2.0/sql/warehouses") { + return Promise.resolve({ + warehouses: [ + { id: "wh-deleted", state: "DELETED" }, + { id: "wh-deleting", state: "DELETING" }, + ], + }); + } + return Promise.resolve({ "x-databricks-org-id": "ws-dev" }); + }); + + const state = await ServiceContext.initialize({ warehouseId: true }); + + await expect(state.warehouseId).rejects.toThrow(ConfigurationError); + }); + + test("should throw in dev mode when best warehouse has no id", async () => { + delete process.env.DATABRICKS_WAREHOUSE_ID; + process.env.NODE_ENV = "development"; + + mockApiRequest.mockImplementation(({ path }: { path: string }) => { + if (path === "/api/2.0/sql/warehouses") { + return Promise.resolve({ + warehouses: [{ state: "RUNNING" }], + }); + } + return Promise.resolve({ "x-databricks-org-id": "ws-dev" }); + }); + + const state = await ServiceContext.initialize({ warehouseId: true }); + + await expect(state.warehouseId).rejects.toThrow(ConfigurationError); + }); + + test("should throw in production when DATABRICKS_WAREHOUSE_ID is not set", async () => { + delete process.env.DATABRICKS_WAREHOUSE_ID; + process.env.NODE_ENV = "production"; + + const state = await ServiceContext.initialize({ warehouseId: true }); + + await expect(state.warehouseId).rejects.toThrow(ConfigurationError); + await expect(state.warehouseId).rejects.toThrow( + /DATABRICKS_WAREHOUSE_ID/, + ); + }); + }); + + // ── getClientOptions() ───────────────────────────────────────── + + describe("getClientOptions()", () => { + test("should return product name and version", () => { + const options = ServiceContext.getClientOptions(); + + expect(options.product).toBe("@databricks/appkit"); + expect(options.productVersion).toBeDefined(); + }); + + test("should include dev mode user agent extra in development", () => { + process.env.NODE_ENV = "development"; + + const options = ServiceContext.getClientOptions(); + + expect(options.userAgentExtra).toEqual({ mode: "dev" }); + }); + + test("should not include dev mode user agent extra in production", () => { + process.env.NODE_ENV = "production"; + + const options = ServiceContext.getClientOptions(); + + expect(options.userAgentExtra).toBeUndefined(); + }); + }); +}); diff --git a/packages/appkit/src/stream/tests/stream-registry.test.ts b/packages/appkit/src/stream/tests/stream-registry.test.ts new file mode 100644 index 000000000..d3f70e95a --- /dev/null +++ b/packages/appkit/src/stream/tests/stream-registry.test.ts @@ -0,0 +1,582 @@ +import type { Context } from "@opentelemetry/api"; +import { beforeEach, describe, expect, test, vi } from "vitest"; +import { EventRingBuffer } from "../buffers"; +import { StreamRegistry } from "../stream-registry"; +import type { StreamEntry } from "../types"; +import { SSEErrorCode } from "../types"; + +/** Create a minimal mock StreamEntry for testing. */ +function createMockStreamEntry( + streamId: string, + overrides: Partial = {}, +): StreamEntry { + return { + streamId, + generator: (async function* () {})(), + eventBuffer: new EventRingBuffer(10), + clients: new Set(), + isCompleted: false, + lastAccess: Date.now(), + abortController: new AbortController(), + traceContext: {} as Context, + ...overrides, + }; +} + +/** Create a mock response object that mimics express.Response for SSE writes. */ +function createMockClient(writableEnded = false) { + return { + write: vi.fn().mockReturnValue(true), + writableEnded, + } as unknown as import("express").Response; +} + +describe("StreamRegistry", () => { + let registry: StreamRegistry; + + beforeEach(() => { + registry = new StreamRegistry(3); + }); + + describe("add and get", () => { + test("should add a stream and retrieve it by id", () => { + const entry = createMockStreamEntry("stream-1"); + registry.add(entry); + + const result = registry.get("stream-1"); + + expect(result).toBe(entry); + }); + + test("should return null for a non-existent stream", () => { + const result = registry.get("non-existent"); + + expect(result).toBeNull(); + }); + + test("should add multiple streams and retrieve each", () => { + const entry1 = createMockStreamEntry("stream-1"); + const entry2 = createMockStreamEntry("stream-2"); + const entry3 = createMockStreamEntry("stream-3"); + + registry.add(entry1); + registry.add(entry2); + registry.add(entry3); + + expect(registry.get("stream-1")).toBe(entry1); + expect(registry.get("stream-2")).toBe(entry2); + expect(registry.get("stream-3")).toBe(entry3); + }); + }); + + describe("has", () => { + test("should return true for an existing stream", () => { + const entry = createMockStreamEntry("stream-1"); + registry.add(entry); + + expect(registry.has("stream-1")).toBe(true); + }); + + test("should return false for a non-existent stream", () => { + expect(registry.has("non-existent")).toBe(false); + }); + + test("should return false after a stream is removed", () => { + const entry = createMockStreamEntry("stream-1"); + registry.add(entry); + registry.remove("stream-1"); + + expect(registry.has("stream-1")).toBe(false); + }); + }); + + describe("remove", () => { + test("should remove an existing stream", () => { + const entry = createMockStreamEntry("stream-1"); + registry.add(entry); + + registry.remove("stream-1"); + + expect(registry.get("stream-1")).toBeNull(); + expect(registry.size()).toBe(0); + }); + + test("should not throw when removing a non-existent stream", () => { + expect(() => registry.remove("non-existent")).not.toThrow(); + }); + + test("should only remove the specified stream", () => { + const entry1 = createMockStreamEntry("stream-1"); + const entry2 = createMockStreamEntry("stream-2"); + registry.add(entry1); + registry.add(entry2); + + registry.remove("stream-1"); + + expect(registry.get("stream-1")).toBeNull(); + expect(registry.get("stream-2")).toBe(entry2); + expect(registry.size()).toBe(1); + }); + }); + + describe("size", () => { + test("should return 0 for an empty registry", () => { + expect(registry.size()).toBe(0); + }); + + test("should track size as streams are added", () => { + registry.add(createMockStreamEntry("stream-1")); + expect(registry.size()).toBe(1); + + registry.add(createMockStreamEntry("stream-2")); + expect(registry.size()).toBe(2); + + registry.add(createMockStreamEntry("stream-3")); + expect(registry.size()).toBe(3); + }); + + test("should decrease when streams are removed", () => { + registry.add(createMockStreamEntry("stream-1")); + registry.add(createMockStreamEntry("stream-2")); + + registry.remove("stream-1"); + + expect(registry.size()).toBe(1); + }); + + test("should not exceed capacity after eviction", () => { + registry.add(createMockStreamEntry("stream-1", { lastAccess: 100 })); + registry.add(createMockStreamEntry("stream-2", { lastAccess: 200 })); + registry.add(createMockStreamEntry("stream-3", { lastAccess: 300 })); + + // Adding a fourth stream to a capacity-3 registry triggers eviction + registry.add(createMockStreamEntry("stream-4", { lastAccess: 400 })); + + expect(registry.size()).toBe(3); + }); + }); + + describe("capacity enforcement and eviction", () => { + test("should evict the oldest stream when at capacity", () => { + registry.add(createMockStreamEntry("stream-1", { lastAccess: 100 })); + registry.add(createMockStreamEntry("stream-2", { lastAccess: 200 })); + registry.add(createMockStreamEntry("stream-3", { lastAccess: 300 })); + + // Adding a fourth should evict stream-1 (oldest lastAccess=100) + registry.add(createMockStreamEntry("stream-4", { lastAccess: 400 })); + + expect(registry.has("stream-1")).toBe(false); + expect(registry.has("stream-2")).toBe(true); + expect(registry.has("stream-3")).toBe(true); + expect(registry.has("stream-4")).toBe(true); + }); + + test("should evict the stream with the smallest lastAccess and abort it", () => { + // When lastAccess order matches insertion order, the eviction logic + // cleanly targets the LRU stream. The stream with the smallest + // lastAccess is found and aborted. + const ac1 = new AbortController(); + const ac2 = new AbortController(); + const ac3 = new AbortController(); + + registry.add( + createMockStreamEntry("stream-1", { + lastAccess: 100, + abortController: ac1, + }), + ); + registry.add( + createMockStreamEntry("stream-2", { + lastAccess: 300, + abortController: ac2, + }), + ); + registry.add( + createMockStreamEntry("stream-3", { + lastAccess: 200, + abortController: ac3, + }), + ); + + // Adding stream-4 triggers eviction. stream-1 has the smallest + // lastAccess (100) so it should be targeted. + registry.add(createMockStreamEntry("stream-4", { lastAccess: 400 })); + + expect(ac1.signal.aborted).toBe(true); + expect(ac2.signal.aborted).toBe(false); + expect(ac3.signal.aborted).toBe(false); + expect(registry.has("stream-1")).toBe(false); + expect(registry.has("stream-4")).toBe(true); + }); + + test("should exclude the stream being added from eviction", () => { + // This tests the excludeStreamId parameter: if a stream with the same + // ID as the one being added already exists and is the oldest, it should + // still be excluded from eviction. In practice, the new stream won't be + // in the registry yet when eviction runs, so excludeStreamId prevents + // misidentification. + registry.add(createMockStreamEntry("stream-1", { lastAccess: 100 })); + registry.add(createMockStreamEntry("stream-2", { lastAccess: 200 })); + registry.add(createMockStreamEntry("stream-3", { lastAccess: 300 })); + + // Add stream with id "stream-1" again; eviction should skip "stream-1" + // even though stream-1 has the oldest lastAccess, because it's the + // excludeStreamId. stream-2 should be evicted instead. + registry.add(createMockStreamEntry("stream-1", { lastAccess: 400 })); + + // stream-1 is updated (RingBuffer updates existing keys in place) + expect(registry.has("stream-1")).toBe(true); + // stream-2 should have been evicted as it was the oldest non-excluded + expect(registry.has("stream-2")).toBe(false); + expect(registry.has("stream-3")).toBe(true); + }); + + test("should abort the evicted stream's AbortController", () => { + const abortController1 = new AbortController(); + registry.add( + createMockStreamEntry("stream-1", { + lastAccess: 100, + abortController: abortController1, + }), + ); + registry.add(createMockStreamEntry("stream-2", { lastAccess: 200 })); + registry.add(createMockStreamEntry("stream-3", { lastAccess: 300 })); + + registry.add(createMockStreamEntry("stream-4", { lastAccess: 400 })); + + expect(abortController1.signal.aborted).toBe(true); + }); + + test("should abort with 'Stream evicted' reason", () => { + const abortController1 = new AbortController(); + registry.add( + createMockStreamEntry("stream-1", { + lastAccess: 100, + abortController: abortController1, + }), + ); + registry.add(createMockStreamEntry("stream-2", { lastAccess: 200 })); + registry.add(createMockStreamEntry("stream-3", { lastAccess: 300 })); + + registry.add(createMockStreamEntry("stream-4", { lastAccess: 400 })); + + expect(abortController1.signal.reason).toBe("Stream evicted"); + }); + }); + + describe("eviction SSE broadcast", () => { + test("should send STREAM_EVICTED error to all clients of evicted stream", () => { + const client1 = createMockClient(); + const client2 = createMockClient(); + + const clients = new Set([client1, client2]); + + registry.add( + createMockStreamEntry("stream-1", { + lastAccess: 100, + clients, + }), + ); + registry.add(createMockStreamEntry("stream-2", { lastAccess: 200 })); + registry.add(createMockStreamEntry("stream-3", { lastAccess: 300 })); + + // Trigger eviction of stream-1 + registry.add(createMockStreamEntry("stream-4", { lastAccess: 400 })); + + // Each client should have received the SSE error event + for (const client of [client1, client2]) { + expect(client.write).toHaveBeenCalledWith("event: error\n"); + expect(client.write).toHaveBeenCalledWith( + `data: ${JSON.stringify({ error: "Stream evicted", code: SSEErrorCode.STREAM_EVICTED })}\n\n`, + ); + } + }); + + test("should skip clients with writableEnded=true during eviction broadcast", () => { + const activeClient = createMockClient(false); + const endedClient = createMockClient(true); + + const clients = new Set([activeClient, endedClient]); + + registry.add( + createMockStreamEntry("stream-1", { + lastAccess: 100, + clients, + }), + ); + registry.add(createMockStreamEntry("stream-2", { lastAccess: 200 })); + registry.add(createMockStreamEntry("stream-3", { lastAccess: 300 })); + + registry.add(createMockStreamEntry("stream-4", { lastAccess: 400 })); + + // Active client should receive the error + expect(activeClient.write).toHaveBeenCalledWith("event: error\n"); + + // Ended client should NOT receive any writes + expect(endedClient.write).not.toHaveBeenCalled(); + }); + + test("should handle client.write throwing an error gracefully", () => { + const throwingClient = createMockClient(false); + (throwingClient.write as ReturnType).mockImplementation( + () => { + throw new Error("Connection reset"); + }, + ); + + const normalClient = createMockClient(false); + + const clients = new Set([throwingClient, normalClient]); + + registry.add( + createMockStreamEntry("stream-1", { + lastAccess: 100, + clients, + }), + ); + registry.add(createMockStreamEntry("stream-2", { lastAccess: 200 })); + registry.add(createMockStreamEntry("stream-3", { lastAccess: 300 })); + + // Should not throw despite the throwing client + expect(() => { + registry.add(createMockStreamEntry("stream-4", { lastAccess: 400 })); + }).not.toThrow(); + + // The normal client should still receive the error despite the other + // client throwing. Note: both clients are in a Set, iteration order is + // insertion order. The throwing client's error is caught per-client. + // We verify the abort still happened (the overall eviction completed). + expect(registry.has("stream-1")).toBe(false); + expect(registry.has("stream-4")).toBe(true); + }); + + test("should send correct SSE error format with STREAM_EVICTED code", () => { + const client = createMockClient(); + const clients = new Set([client]); + + registry.add( + createMockStreamEntry("stream-1", { + lastAccess: 100, + clients, + }), + ); + registry.add(createMockStreamEntry("stream-2", { lastAccess: 200 })); + registry.add(createMockStreamEntry("stream-3", { lastAccess: 300 })); + + registry.add(createMockStreamEntry("stream-4", { lastAccess: 400 })); + + // Verify the exact data payload + const dataCall = ( + client.write as ReturnType + ).mock.calls.find((call: unknown[]) => + (call[0] as string).startsWith("data:"), + ); + expect(dataCall).toBeDefined(); + + const payload = JSON.parse( + (dataCall![0] as string).replace("data: ", "").trim(), + ); + expect(payload).toEqual({ + error: "Stream evicted", + code: "STREAM_EVICTED", + }); + }); + + test("should broadcast to multiple clients on the same evicted stream", () => { + const client1 = createMockClient(); + const client2 = createMockClient(); + const client3 = createMockClient(); + + const clients = new Set([client1, client2, client3]); + + registry.add( + createMockStreamEntry("stream-1", { + lastAccess: 100, + clients, + }), + ); + registry.add(createMockStreamEntry("stream-2", { lastAccess: 200 })); + registry.add(createMockStreamEntry("stream-3", { lastAccess: 300 })); + + registry.add(createMockStreamEntry("stream-4", { lastAccess: 400 })); + + // All three clients should have received exactly 2 write calls each + // (one for "event: error\n" and one for the data line) + for (const client of [client1, client2, client3]) { + expect(client.write).toHaveBeenCalledTimes(2); + } + }); + + test("should not broadcast if evicted stream has no clients", () => { + const abortController = new AbortController(); + + registry.add( + createMockStreamEntry("stream-1", { + lastAccess: 100, + clients: new Set(), + abortController, + }), + ); + registry.add(createMockStreamEntry("stream-2", { lastAccess: 200 })); + registry.add(createMockStreamEntry("stream-3", { lastAccess: 300 })); + + // Should not throw even with no clients + expect(() => { + registry.add(createMockStreamEntry("stream-4", { lastAccess: 400 })); + }).not.toThrow(); + + // Stream should still be evicted and aborted + expect(registry.has("stream-1")).toBe(false); + expect(abortController.signal.aborted).toBe(true); + }); + }); + + describe("clear", () => { + test("should abort all streams and clear the registry", () => { + const ac1 = new AbortController(); + const ac2 = new AbortController(); + const ac3 = new AbortController(); + + registry.add(createMockStreamEntry("stream-1", { abortController: ac1 })); + registry.add(createMockStreamEntry("stream-2", { abortController: ac2 })); + registry.add(createMockStreamEntry("stream-3", { abortController: ac3 })); + + registry.clear(); + + expect(registry.size()).toBe(0); + expect(ac1.signal.aborted).toBe(true); + expect(ac2.signal.aborted).toBe(true); + expect(ac3.signal.aborted).toBe(true); + }); + + test("should abort with 'Server shutdown' reason", () => { + const ac = new AbortController(); + registry.add(createMockStreamEntry("stream-1", { abortController: ac })); + + registry.clear(); + + expect(ac.signal.reason).toBe("Server shutdown"); + }); + + test("should handle clearing an empty registry", () => { + expect(() => registry.clear()).not.toThrow(); + expect(registry.size()).toBe(0); + }); + + test("should make all streams inaccessible after clear", () => { + registry.add(createMockStreamEntry("stream-1")); + registry.add(createMockStreamEntry("stream-2")); + + registry.clear(); + + expect(registry.get("stream-1")).toBeNull(); + expect(registry.get("stream-2")).toBeNull(); + expect(registry.has("stream-1")).toBe(false); + expect(registry.has("stream-2")).toBe(false); + }); + + test("should allow adding new streams after clear", () => { + registry.add(createMockStreamEntry("stream-1")); + registry.clear(); + + const newEntry = createMockStreamEntry("stream-new"); + registry.add(newEntry); + + expect(registry.get("stream-new")).toBe(newEntry); + expect(registry.size()).toBe(1); + }); + }); + + describe("edge cases", () => { + test("should work with capacity of 1", () => { + const smallRegistry = new StreamRegistry(1); + const ac1 = new AbortController(); + + smallRegistry.add( + createMockStreamEntry("stream-1", { + lastAccess: 100, + abortController: ac1, + }), + ); + expect(smallRegistry.size()).toBe(1); + + smallRegistry.add(createMockStreamEntry("stream-2", { lastAccess: 200 })); + + expect(smallRegistry.size()).toBe(1); + expect(smallRegistry.has("stream-1")).toBe(false); + expect(smallRegistry.has("stream-2")).toBe(true); + expect(ac1.signal.aborted).toBe(true); + }); + + test("should handle adding a stream with the same id (update)", () => { + const entry1 = createMockStreamEntry("stream-1", { + lastAccess: 100, + }); + const entry2 = createMockStreamEntry("stream-1", { + lastAccess: 200, + }); + + registry.add(entry1); + registry.add(entry2); + + // The RingBuffer updates in place for same key + expect(registry.size()).toBe(1); + const retrieved = registry.get("stream-1"); + expect(retrieved?.lastAccess).toBe(200); + }); + + test("should handle sequential evictions correctly", () => { + registry.add(createMockStreamEntry("stream-1", { lastAccess: 100 })); + registry.add(createMockStreamEntry("stream-2", { lastAccess: 200 })); + registry.add(createMockStreamEntry("stream-3", { lastAccess: 300 })); + + // First eviction: stream-1 evicted + registry.add(createMockStreamEntry("stream-4", { lastAccess: 400 })); + expect(registry.has("stream-1")).toBe(false); + + // Second eviction: stream-2 evicted + registry.add(createMockStreamEntry("stream-5", { lastAccess: 500 })); + expect(registry.has("stream-2")).toBe(false); + + // stream-3, stream-4, stream-5 remain + expect(registry.has("stream-3")).toBe(true); + expect(registry.has("stream-4")).toBe(true); + expect(registry.has("stream-5")).toBe(true); + expect(registry.size()).toBe(3); + }); + + test("should not evict when under capacity", () => { + const ac1 = new AbortController(); + registry.add(createMockStreamEntry("stream-1", { abortController: ac1 })); + registry.add(createMockStreamEntry("stream-2")); + + // Only 2 streams in a capacity-3 registry, no eviction + expect(registry.size()).toBe(2); + expect(ac1.signal.aborted).toBe(false); + }); + + test("should handle mixed writable states during eviction", () => { + const activeClient = createMockClient(false); + const endedClient1 = createMockClient(true); + const endedClient2 = createMockClient(true); + + const clients = new Set([endedClient1, activeClient, endedClient2]); + + registry.add( + createMockStreamEntry("stream-1", { + lastAccess: 100, + clients, + }), + ); + registry.add(createMockStreamEntry("stream-2", { lastAccess: 200 })); + registry.add(createMockStreamEntry("stream-3", { lastAccess: 300 })); + + registry.add(createMockStreamEntry("stream-4", { lastAccess: 400 })); + + // Only the active client should receive writes + expect(activeClient.write).toHaveBeenCalledTimes(2); + expect(endedClient1.write).not.toHaveBeenCalled(); + expect(endedClient2.write).not.toHaveBeenCalled(); + }); + }); +});