diff --git a/packages/typescript/ai-client/src/chat-client.ts b/packages/typescript/ai-client/src/chat-client.ts
index 65554a447..69b6e743c 100644
--- a/packages/typescript/ai-client/src/chat-client.ts
+++ b/packages/typescript/ai-client/src/chat-client.ts
@@ -4,13 +4,17 @@ import {
   normalizeToUIMessage,
 } from '@tanstack/ai'
 import { DefaultChatClientEventEmitter } from './events'
+import { normalizeConnectionAdapter } from './connection-adapters'
 import type {
   AnyClientTool,
   ContentPart,
   ModelMessage,
   StreamChunk,
 } from '@tanstack/ai'
-import type { ConnectionAdapter } from './connection-adapters'
+import type {
+  ConnectionAdapter,
+  SubscribeConnectionAdapter,
+} from './connection-adapters'
 import type { ChatClientEventEmitter } from './events'
 import type {
   ChatClientOptions,
@@ -23,7 +27,7 @@ import type {
 export class ChatClient {
   private processor: StreamProcessor
-  private connection: ConnectionAdapter
+  private connection: SubscribeConnectionAdapter
   private uniqueId: string
   private body: Record<string, unknown> = {}
   private pendingMessageBody: Record<string, unknown> | undefined = undefined
@@ -40,6 +44,10 @@ export class ChatClient {
   private pendingToolExecutions: Map<string, Promise<void>> = new Map()
   // Flag to deduplicate continuation checks during action draining
   private continuationPending = false
+  private subscriptionAbortController: AbortController | null = null
+  private processingResolve: (() => void) | null = null
+  private errorReportedGeneration: number | null = null
+  private streamGeneration = 0

   private callbacksRef: {
     current: {
@@ -57,7 +65,7 @@
   constructor(options: ChatClientOptions) {
     this.uniqueId = options.id || this.generateUniqueId('chat')
     this.body = options.body || {}
-    this.connection = options.connection
+    this.connection = normalizeConnectionAdapter(options.connection)
     this.events = new DefaultChatClientEventEmitter(this.uniqueId)

     // Build client tools map
@@ -91,15 +99,30 @@
       },
       onStreamStart: () => {
         this.setStatus('streaming')
+        const assistantMessageId = this.processor.getCurrentAssistantMessageId()
+        if (!assistantMessageId) {
+          return
+        }
+        const messages = this.processor.getMessages()
+        const assistantMessage = messages.find(
+          (m: UIMessage) => m.id === assistantMessageId,
+        )
+        if (assistantMessage) {
+          this.currentMessageId = assistantMessage.id
+          this.events.messageAppended(
+            assistantMessage,
+            this.currentStreamId || undefined,
+          )
+        }
       },
       onStreamEnd: (message: UIMessage) => {
         this.callbacksRef.current.onFinish(message)
         this.setStatus('ready')
+        // Resolve the processing-complete promise so streamResponse can continue
+        this.resolveProcessing()
       },
       onError: (error: Error) => {
-        this.setError(error)
-        this.setStatus('error')
-        this.callbacksRef.current.onError(error)
+        this.reportStreamError(error)
       },
       onTextUpdate: (messageId: string, content: string) => {
         // Emit text update to devtools
@@ -225,69 +248,94 @@
     this.events.errorChanged(error?.message || null)
   }

+  private abortSubscriptionLoop(): void {
+    this.subscriptionAbortController?.abort()
+    this.subscriptionAbortController = null
+  }
+
+  private resolveProcessing(): void {
+    this.processingResolve?.()
+    this.processingResolve = null
+  }
+
+  private cancelInFlightStream(options?: { setReadyStatus?: boolean }): void {
+    this.abortController?.abort()
+    this.abortController = null
+    this.abortSubscriptionLoop()
+    this.resolveProcessing()
+    this.setIsLoading(false)
+    if (options?.setReadyStatus) {
+      this.setStatus('ready')
+    }
+  }
+
+  private reportStreamError(error: Error): void {
+    const alreadyReported =
+      this.errorReportedGeneration === this.streamGeneration
+    this.setError(error)
+    this.setStatus('error')
+    if (!alreadyReported) {
+      this.errorReportedGeneration = this.streamGeneration
+      this.callbacksRef.current.onError(error)
+    }
+  }

   /**
-   * Process a stream through the StreamProcessor
+   * Start the background subscription loop.
    */
-  private async processStream(
-    source: AsyncIterable<StreamChunk>,
-  ): Promise<UIMessage | null> {
-    // Generate a stream ID for this streaming operation
-    this.currentStreamId = this.generateUniqueId('stream')
+  private startSubscription(): void {
+    this.subscriptionAbortController = new AbortController()
+    const signal = this.subscriptionAbortController.signal

-    // Prepare for a new assistant message (created lazily on first content)
-    this.processor.prepareAssistantMessage()
+    this.consumeSubscription(signal).catch((err) => {
+      if (err instanceof Error && err.name !== 'AbortError') {
+        this.reportStreamError(err)
+      }
+      // Resolve pending processing so streamResponse doesn't hang
+      this.resolveProcessing()
+    })
+  }

-    // Process each chunk
-    for await (const chunk of source) {
+  /**
+   * Consume chunks from the connection subscription.
+   */
+  private async consumeSubscription(signal: AbortSignal): Promise<void> {
+    const stream = this.connection.subscribe(signal)
+    for await (const chunk of stream) {
+      if (signal.aborted) break
       this.callbacksRef.current.onChunk(chunk)
       this.processor.processChunk(chunk)
-
-      // Track the message ID once the processor lazily creates it
-      if (!this.currentMessageId) {
-        const newMessageId =
-          this.processor.getCurrentAssistantMessageId() ?? null
-        if (newMessageId) {
-          this.currentMessageId = newMessageId
-          // Emit message appended event now that the assistant message exists
-          const assistantMessage = this.processor
-            .getMessages()
-            .find((m: UIMessage) => m.id === newMessageId)
-          if (assistantMessage) {
-            this.events.messageAppended(
-              assistantMessage,
-              this.currentStreamId || undefined,
-            )
-          }
-        }
+      // RUN_FINISHED / RUN_ERROR signal run completion — resolve processing
+      // (harmless if onStreamEnd has already resolved it)
+      if (chunk.type === 'RUN_FINISHED' || chunk.type === 'RUN_ERROR') {
+        this.resolveProcessing()
       }
-
-      // Yield control back to event loop to allow UI updates
+      // Yield control back to event loop for UI updates
       await new Promise((resolve) => setTimeout(resolve, 0))
     }
+  }

-    // Wait for all pending tool executions to complete before finalizing
-    // This ensures client tools finish before we check for continuation
-    if (this.pendingToolExecutions.size > 0) {
-      await Promise.all(this.pendingToolExecutions.values())
-    }
-
-    // Finalize the stream
-    this.processor.finalizeStream()
-
-    // Get the message ID (may be null if no content arrived)
-    const messageId = this.processor.getCurrentAssistantMessageId()
-
-    // Clear the current stream and message IDs
-    this.currentStreamId = null
-    this.currentMessageId = null
-
-    // Return the assistant message if one was created
-    if (messageId) {
-      const messages = this.processor.getMessages()
-      return messages.find((m: UIMessage) => m.id === messageId) || null
+  /**
+   * Ensure subscription loop is running, starting it if needed.
+   */
+  private ensureSubscription(): void {
+    if (
+      !this.subscriptionAbortController ||
+      this.subscriptionAbortController.signal.aborted
+    ) {
+      this.startSubscription()
     }
+  }

-    return null
+  /**
+   * Create a promise that resolves when onStreamEnd fires.
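+   * Only one waiter is pending at a time; creating a new wait first
+   * resolves any stale waiter left behind by an aborted request.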
+   * Used by streamResponse to await processing completion.
+   */
+  private waitForProcessing(): Promise<void> {
+    // Resolve any stale promise (e.g., from a previous aborted request)
+    this.resolveProcessing()
+    return new Promise<void>((resolve) => {
+      this.processingResolve = resolve
+    })
   }

   /**
@@ -407,9 +455,13 @@
       return
     }

+    // Track generation so a superseded stream's cleanup doesn't clobber the new one
+    const generation = ++this.streamGeneration
+
     this.setIsLoading(true)
     this.setStatus('submitted')
     this.setError(undefined)
+    this.errorReportedGeneration = null
     this.abortController = new AbortController()
     // Reset pending tool executions for the new stream
     this.pendingToolExecutions.clear()
@@ -433,42 +485,81 @@
       // Clear the pending message body after use
       this.pendingMessageBody = undefined

-      // Connect and stream
-      const stream = this.connection.connect(
-        messages,
-        mergedBody,
-        this.abortController.signal,
-      )
+      // Generate stream ID — assistant message will be created by stream events
+      this.currentStreamId = this.generateUniqueId('stream')
+      this.currentMessageId = null
+
+      // Reset processor stream state for new response — prevents stale
+      // messageStates entries (from a previous stream) from blocking
+      // creation of a new assistant message (e.g. after reload).
+      this.processor.prepareAssistantMessage()
+
+      // Ensure subscription loop is running
+      this.ensureSubscription()

-      await this.processStream(stream)
+      // Set up promise that resolves when onStreamEnd fires
+      const processingComplete = this.waitForProcessing()
+
+      // Send through normalized connection (pushes chunks to subscription queue)
+      await this.connection.send(messages, mergedBody, this.abortController.signal)
+
+      // Wait for subscription loop to finish processing all chunks
+      await processingComplete
+
+      // If this stream was superseded (e.g. by reload()), bail out —
+      // the new stream owns the processor and processingResolve now.
+      if (generation !== this.streamGeneration) {
+        return
+      }
+
+      // A RUN_ERROR from the stream transitions status to error.
+      // Do not treat this stream as a successful completion.
+      if (this.status === 'error') {
+        return
+      }
+
+      // Wait for pending client tool executions
+      if (this.pendingToolExecutions.size > 0) {
+        await Promise.all(this.pendingToolExecutions.values())
+      }
+
+      // Finalize (idempotent — may already be done by RUN_FINISHED handler)
+      this.processor.finalizeStream()

       streamCompletedSuccessfully = true
     } catch (err) {
       if (err instanceof Error) {
         if (err.name === 'AbortError') {
           return
         }
-        this.setError(err)
-        this.setStatus('error')
-        this.callbacksRef.current.onError(err)
+        if (generation === this.streamGeneration) {
+          this.reportStreamError(err)
+        }
       }
     } finally {
-      this.abortController = null
-      this.setIsLoading(false)
-      this.pendingMessageBody = undefined // Ensure it's cleared even on error
-
-      // Drain any actions that were queued while the stream was in progress
-      await this.drainPostStreamActions()
-
-      // Continue conversation if the stream ended with a tool result (server tool completed)
-      if (streamCompletedSuccessfully) {
-        const messages = this.processor.getMessages()
-        const lastPart = messages.at(-1)?.parts.at(-1)
-
-        if (lastPart?.type === 'tool-result' && this.shouldAutoSend()) {
-          try {
-            await this.checkForContinuation()
-          } catch (error) {
-            console.error('Failed to continue flow after tool result:', error)
+      // Only clean up if this is still the active stream.
+      // A superseded stream (e.g. reload() started a new one) must not
+      // clobber the new stream's abortController or isLoading state.
+      if (generation === this.streamGeneration) {
+        this.currentStreamId = null
+        this.currentMessageId = null
+        this.abortController = null
+        this.setIsLoading(false)
+        this.pendingMessageBody = undefined // Ensure it's cleared even on error
+
+        // Drain any actions that were queued while the stream was in progress
+        await this.drainPostStreamActions()
+
+        // Continue conversation if the stream ended with a tool result (server tool completed)
+        if (streamCompletedSuccessfully) {
+          const messages = this.processor.getMessages()
+          const lastPart = messages.at(-1)?.parts.at(-1)
+
+          if (lastPart?.type === 'tool-result' && this.shouldAutoSend()) {
+            try {
+              await this.checkForContinuation()
+            } catch (error) {
+              console.error('Failed to continue flow after tool result:', error)
+            }
          }
        }
      }
@@ -489,6 +580,11 @@

     if (lastUserMessageIndex === -1) return

+    // Cancel any active stream before reloading
+    if (this.isLoading) {
+      this.cancelInFlightStream()
+    }
+
     this.events.reloaded(lastUserMessageIndex)

     // Remove all messages after the last user message
@@ -502,12 +598,7 @@
    * Stop the current stream
    */
   stop(): void {
-    if (this.abortController) {
-      this.abortController.abort()
-      this.abortController = null
-    }
-    this.setIsLoading(false)
-    this.setStatus('ready')
+    this.cancelInFlightStream({ setReadyStatus: true })
     this.events.stopped()
   }

@@ -686,7 +777,13 @@
     onError?: (error: Error) => void
   }): void {
     if (options.connection !== undefined) {
-      this.connection = options.connection
+      // Cancel any in-flight stream to avoid hanging on stale processing promises
+      if (this.isLoading) {
+        this.cancelInFlightStream({ setReadyStatus: true })
+      } else {
+        this.abortSubscriptionLoop()
+      }
+      this.connection = normalizeConnectionAdapter(options.connection)
     }
     if (options.body !== undefined) {
       this.body = options.body
diff --git a/packages/typescript/ai-client/src/connection-adapters.ts b/packages/typescript/ai-client/src/connection-adapters.ts
index b29e90551..0f1c2e02f 100644
--- a/packages/typescript/ai-client/src/connection-adapters.ts
+++ b/packages/typescript/ai-client/src/connection-adapters.ts
@@ -62,15 +62,9 @@ async function* readStreamLines(
   }
 }

-/**
- * Connection adapter interface - converts a connection into a stream of chunks
- */
-export interface ConnectionAdapter {
+export interface ConnectConnectionAdapter {
   /**
-   * Connect and return an async iterable of StreamChunks
-   * @param messages - The messages to send (UIMessages or ModelMessages)
-   * @param data - Additional data to send
-   * @param abortSignal - Optional abort signal for request cancellation
+   * Connect and return an async iterable of StreamChunks.
    */
   connect: (
     messages: Array<UIMessage> | Array<ModelMessage>,
     data?: Record<string, unknown>,
     abortSignal?: AbortSignal,
   ) => AsyncIterable<StreamChunk>
 }

+export interface SubscribeConnectionAdapter {
+  /**
+   * Subscribe to stream chunks.
+   */
+  subscribe: (abortSignal?: AbortSignal) => AsyncIterable<StreamChunk>
+  /**
+   * Send a request; chunks arrive through subscribe().
+   */
+  send: (
+    messages: Array<UIMessage> | Array<ModelMessage>,
+    data?: Record<string, unknown>,
+    abortSignal?: AbortSignal,
+  ) => Promise<void>
+}
+
+/**
+ * Connection adapter union.
+ * Provide either `connect`, or `subscribe` + `send`.
+ */
+export type ConnectionAdapter =
+  | ConnectConnectionAdapter
+  | SubscribeConnectionAdapter
+
+/**
+ * Normalize a ConnectionAdapter to subscribe/send operations.
+ *
+ * If a connection provides native subscribe/send, that mode is used.
+ * Otherwise, connect() is wrapped using an async queue.
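+ *
+ * @example
+ * // Hedged usage sketch: `toChunks`, `msgs`, and `handleChunk` are
+ * // illustrative names, not part of this package.
+ * const adapter = normalizeConnectionAdapter(stream(toChunks))
+ * void (async () => {
+ *   for await (const chunk of adapter.subscribe()) handleChunk(chunk)
+ * })()
+ * await adapter.send(msgs) // connect() runs; its chunks surface via subscribe()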
+ */
+export function normalizeConnectionAdapter(
+  connection: ConnectionAdapter | undefined,
+): SubscribeConnectionAdapter {
+  if (!connection) {
+    throw new Error('Connection adapter is required')
+  }
+
+  const hasConnect = 'connect' in connection
+  const hasSubscribe = 'subscribe' in connection
+  const hasSend = 'send' in connection
+
+  if (hasConnect && (hasSubscribe || hasSend)) {
+    throw new Error(
+      'Connection adapter must provide either connect or both subscribe and send, not both modes',
+    )
+  }
+
+  if (hasSubscribe && hasSend) {
+    return {
+      subscribe: connection.subscribe.bind(connection),
+      send: connection.send.bind(connection),
+    }
+  }
+
+  if (!hasConnect) {
+    throw new Error(
+      'Connection adapter must provide either connect or both subscribe and send',
+    )
+  }
+
+  // Legacy connect() wrapper
+  let activeBuffer: Array<StreamChunk> = []
+  let activeWaiters: Array<(chunk: StreamChunk | null) => void> = []
+
+  function push(chunk: StreamChunk): void {
+    const waiter = activeWaiters.shift()
+    if (waiter) {
+      waiter(chunk)
+    } else {
+      activeBuffer.push(chunk)
+    }
+  }
+
+  return {
+    subscribe(abortSignal?: AbortSignal): AsyncIterable<StreamChunk> {
+      // Transfer ownership to the latest subscriber so only one active
+      // subscribe() call receives chunks from the shared connect-wrapper queue.
+      const myBuffer: Array<StreamChunk> = activeBuffer.splice(0)
+      const myWaiters: Array<(chunk: StreamChunk | null) => void> = []
+      activeBuffer = myBuffer
+      activeWaiters = myWaiters
+
+      return (async function* () {
+        while (!abortSignal?.aborted) {
+          let chunk: StreamChunk | null
+          if (myBuffer.length > 0) {
+            chunk = myBuffer.shift()!
+          } else {
+            chunk = await new Promise<StreamChunk | null>((resolve) => {
+              const onAbort = () => resolve(null)
+              myWaiters.push((c) => {
+                abortSignal?.removeEventListener('abort', onAbort)
+                resolve(c)
+              })
+              abortSignal?.addEventListener('abort', onAbort, { once: true })
+            })
+          }
+          if (chunk !== null) yield chunk
+        }
+      })()
+    },
+    async send(messages, data, abortSignal) {
+      let hasTerminalEvent = false
+      try {
+        const stream = connection.connect(messages, data, abortSignal)
+        for await (const chunk of stream) {
+          if (chunk.type === 'RUN_FINISHED' || chunk.type === 'RUN_ERROR') {
+            hasTerminalEvent = true
+          }
+          push(chunk)
+        }
+
+        // If the connect stream ended cleanly without a terminal event,
+        // synthesize RUN_FINISHED so request-scoped consumers can complete.
+        if (!abortSignal?.aborted && !hasTerminalEvent) {
+          push({
+            type: 'RUN_FINISHED',
+            runId: `run-${Date.now()}`,
+            model: 'connect-wrapper',
+            timestamp: Date.now(),
+            finishReason: 'stop',
+          })
+        }
+      } catch (err) {
+        if (!abortSignal?.aborted && !hasTerminalEvent) {
+          push({
+            type: 'RUN_ERROR',
+            timestamp: Date.now(),
+            error: {
+              message:
+                err instanceof Error ? err.message : 'Unknown error in connect()',
+            },
+          })
+        }
+        throw err
+      }
+    },
+  }
+}

 /**
  * Options for fetch-based connection adapters
  */
@@ -129,7 +262,7 @@ export function fetchServerSentEvents(
   options:
     | FetchConnectionOptions
     | (() => FetchConnectionOptions | Promise<FetchConnectionOptions>) = {},
-): ConnectionAdapter {
+): ConnectConnectionAdapter {
   return {
     async *connect(messages, data, abortSignal) {
       // Resolve URL and options if they are functions
@@ -228,7 +361,7 @@ export function fetchHttpStream(
   options:
     | FetchConnectionOptions
     | (() => FetchConnectionOptions | Promise<FetchConnectionOptions>) = {},
-): ConnectionAdapter {
+): ConnectConnectionAdapter {
   return {
     async *connect(messages, data, abortSignal) {
       // Resolve URL and options if they are functions
@@ -301,7 +434,7 @@ export function stream(
     messages: Array<UIMessage> | Array<ModelMessage>,
     data?: Record<string, unknown>,
   ) => AsyncIterable<StreamChunk>,
-): ConnectionAdapter {
+): ConnectConnectionAdapter {
   return {
     async *connect(messages, data) {
       // Pass messages as-is (UIMessages with parts preserved)
@@ -332,7 +465,7 @@ export function rpcStream(
     messages: Array<UIMessage> | Array<ModelMessage>,
     data?: Record<string, unknown>,
   ) => AsyncIterable<StreamChunk>,
-): ConnectionAdapter {
+): ConnectConnectionAdapter {
   return {
     async *connect(messages, data) {
       // Pass messages as-is (UIMessages with parts preserved)
diff --git a/packages/typescript/ai-client/src/index.ts b/packages/typescript/ai-client/src/index.ts
index b279605d1..2c6d21a81 100644
--- a/packages/typescript/ai-client/src/index.ts
+++ b/packages/typescript/ai-client/src/index.ts
@@ -27,8 +27,10 @@ export {
   fetchHttpStream,
   stream,
   rpcStream,
+  type ConnectConnectionAdapter,
   type ConnectionAdapter,
   type FetchConnectionOptions,
+  type SubscribeConnectionAdapter,
 } from './connection-adapters'

 // Re-export message converters from @tanstack/ai
diff --git a/packages/typescript/ai-client/src/types.ts b/packages/typescript/ai-client/src/types.ts
index 985725481..388e1d73b 100644
--- a/packages/typescript/ai-client/src/types.ts
+++ b/packages/typescript/ai-client/src/types.ts
@@ -178,8 +178,9 @@ export interface ChatClientOptions<
   TTools extends ReadonlyArray<AnyClientTool> = any,
 > {
   /**
-   * Connection adapter for streaming
-   * Use fetchServerSentEvents(), fetchHttpStream(), or stream() to create adapters
+   * Connection adapter for streaming.
+   * Supports two mutually exclusive modes: request-response via `connect()`,
+   * or subscribe/send mode via `subscribe()` + `send()`.
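+   * @example
+   * // Hedged sketch of a custom subscribe/send adapter; `mySocket` is an
+   * // illustrative transport, not part of this package.
+   * const connection: ConnectionAdapter = {
+   *   subscribe: (signal) => mySocket.chunks(signal), // AsyncIterable<StreamChunk>
+   *   send: async (messages, data) => mySocket.post({ messages, data }),
+   * }
+   * const client = new ChatClient({ connection })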
   */
  connection: ConnectionAdapter

diff --git a/packages/typescript/ai-client/tests/chat-client.test.ts b/packages/typescript/ai-client/tests/chat-client.test.ts
index 279603783..2e0df40b8 100644
--- a/packages/typescript/ai-client/tests/chat-client.test.ts
+++ b/packages/typescript/ai-client/tests/chat-client.test.ts
@@ -6,6 +6,7 @@ import {
   createThinkingChunks,
   createToolCallChunks,
 } from './test-utils'
+import type { StreamChunk } from '@tanstack/ai'
 import type { UIMessage } from '../src/types'

 describe('ChatClient', () => {
@@ -74,6 +75,95 @@
     // Message IDs should be unique between clients
     expect(client1MessageId).not.toBe(client2MessageId)
   })
+
+    it('should throw if connection is not provided', () => {
+      expect(() => new ChatClient({} as any)).toThrow(
+        'Connection adapter is required',
+      )
+    })
+  })
+
+  describe('subscribe/send connection mode', () => {
+    function createSubscribeAdapter(chunksToSend: Array<StreamChunk>) {
+      let hasPendingSend = false
+      let wakeSubscriber: (() => void) | null = null
+      let removeAbortListener: (() => void) | null = null
+
+      const subscribe = vi.fn((signal?: AbortSignal) => {
+        return (async function* () {
+          while (!signal?.aborted) {
+            if (!hasPendingSend) {
+              await new Promise<void>((resolve) => {
+                removeAbortListener?.()
+                removeAbortListener = null
+                wakeSubscriber = resolve
+                const onAbort = () => resolve()
+                signal?.addEventListener('abort', onAbort, { once: true })
+                removeAbortListener = () => {
+                  signal?.removeEventListener('abort', onAbort)
+                }
+              })
+              continue
+            }
+
+            hasPendingSend = false
+            for (const chunk of chunksToSend) {
+              yield chunk
+            }
+          }
+          removeAbortListener?.()
+          removeAbortListener = null
+        })()
+      })
+
+      const send = vi.fn(async () => {
+        removeAbortListener?.()
+        removeAbortListener = null
+        hasPendingSend = true
+        wakeSubscriber?.()
+        wakeSubscriber = null
+      })
+
+      return { subscribe, send }
+    }
+
+    it('should use subscribe/send adapter mode', async () => {
+      const adapter = createSubscribeAdapter(
+        createTextChunks('From subscribe/send mode'),
+      )
+      const client = new ChatClient({ connection: adapter })
+
+      await client.sendMessage('Hello')
+
+      expect(adapter.subscribe).toHaveBeenCalled()
+      expect(adapter.send).toHaveBeenCalled()
+    })
+
+    it('should remain pending without terminal run events', async () => {
+      const adapter = createSubscribeAdapter([
+        {
+          type: 'TEXT_MESSAGE_CONTENT',
+          messageId: 'msg-1',
+          model: 'test',
+          timestamp: Date.now(),
+          delta: 'H',
+          content: 'H',
+        },
+      ])
+      const client = new ChatClient({ connection: adapter })
+
+      const sendPromise = client.sendMessage('Hello')
+      const completed = await Promise.race([
+        sendPromise.then(() => true),
+        new Promise<boolean>((resolve) => setTimeout(() => resolve(false), 100)),
+      ])
+
+      expect(completed).toBe(false)
+
+      // Explicitly stop to unblock the in-flight request.
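+      // stop() aborts the subscription loop and resolves the pending processing promise.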
+ client.stop() + await sendPromise + }) }) describe('sendMessage', () => { @@ -387,8 +477,12 @@ describe('ChatClient', () => { await client.sendMessage('Hello') - expect(onError).toHaveBeenCalledWith(error) - expect(client.getError()).toBe(error) + expect(onError).toHaveBeenCalled() + expect(onError).toHaveBeenCalledTimes(1) + expect(onError.mock.calls[0]![0]).toBeInstanceOf(Error) + expect(onError.mock.calls[0]![0].message).toBe('Connection failed') + expect(client.getError()).toBeInstanceOf(Error) + expect(client.getError()?.message).toBe('Connection failed') }) }) @@ -500,7 +594,8 @@ describe('ChatClient', () => { await client.sendMessage('Hello') - expect(client.getError()).toBe(error) + expect(client.getError()).toBeInstanceOf(Error) + expect(client.getError()?.message).toBe('Network error') expect(client.getStatus()).toBe('error') }) @@ -526,6 +621,56 @@ describe('ChatClient', () => { expect(client.getError()).toBeUndefined() expect(client.getStatus()).not.toBe('error') }) + + it('should not hang when connection is updated during an active stream', async () => { + const noTerminalAdapter = createMockConnectionAdapter({ + chunks: [ + { + type: 'TEXT_MESSAGE_CONTENT', + messageId: 'msg-1', + model: 'test', + timestamp: Date.now(), + delta: 'H', + content: 'H', + }, + ], + chunkDelay: 50, + }) + const replacementAdapter = createMockConnectionAdapter({ + chunks: createTextChunks('replacement'), + }) + const client = new ChatClient({ connection: noTerminalAdapter }) + + const sendPromise = client.sendMessage('Hello') + await new Promise((resolve) => setTimeout(resolve, 10)) + client.updateOptions({ connection: replacementAdapter }) + + const completed = await Promise.race([ + sendPromise.then(() => true), + new Promise((resolve) => setTimeout(() => resolve(false), 500)), + ]) + + expect(completed).toBe(true) + expect(client.getIsLoading()).toBe(false) + }) + + it('should surface subscription loop failures without hanging', async () => { + const connection = { + subscribe: async function* () { + throw new Error('subscription exploded') + }, + send: async () => {}, + } + const onError = vi.fn() + const client = new ChatClient({ connection, onError }) + + await client.sendMessage('Hello') + + expect(onError).toHaveBeenCalledTimes(1) + expect(onError.mock.calls[0]?.[0]).toBeInstanceOf(Error) + expect(onError.mock.calls[0]?.[0].message).toBe('subscription exploded') + expect(client.getStatus()).toBe('error') + }) }) describe('devtools events', () => { diff --git a/packages/typescript/ai-client/tests/connection-adapters.test.ts b/packages/typescript/ai-client/tests/connection-adapters.test.ts index b25b76b32..893c3f3b5 100644 --- a/packages/typescript/ai-client/tests/connection-adapters.test.ts +++ b/packages/typescript/ai-client/tests/connection-adapters.test.ts @@ -2,6 +2,7 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' import { fetchHttpStream, fetchServerSentEvents, + normalizeConnectionAdapter, stream, } from '../src/connection-adapters' import type { StreamChunk } from '@tanstack/ai' @@ -523,4 +524,133 @@ describe('connection-adapters', () => { ) }) }) + + describe('normalizeConnectionAdapter', () => { + it('should throw when connection is not provided', () => { + expect(() => normalizeConnectionAdapter(undefined)).toThrow( + 'Connection adapter is required', + ) + }) + + it('should throw when subscribe/send are partially implemented', () => { + const invalidAdapters = [ + { subscribe: async function* () {} }, + { send: async () => {} }, + ] as const + + for 
(const adapter of invalidAdapters) {
+        expect(() => normalizeConnectionAdapter(adapter as any)).toThrow(
+          'Connection adapter must provide either connect or both subscribe and send',
+        )
+      }
+    })
+
+    it('should throw when both connection modes are provided', () => {
+      const invalidAdapter = {
+        connect: async function* () {},
+        subscribe: async function* () {},
+        send: async () => {},
+      }
+
+      expect(() => normalizeConnectionAdapter(invalidAdapter as any)).toThrow(
+        'Connection adapter must provide either connect or both subscribe and send, not both modes',
+      )
+    })
+
+    it('should synthesize RUN_FINISHED when wrapped connect stream has no terminal event', async () => {
+      const base = stream(async function* () {
+        yield {
+          type: 'TEXT_MESSAGE_CONTENT',
+          messageId: 'msg-1',
+          model: 'test',
+          timestamp: Date.now(),
+          delta: 'Hi',
+          content: 'Hi',
+        }
+      })
+
+      const adapter = normalizeConnectionAdapter(base)
+      const abortController = new AbortController()
+      const receivedPromise = (async () => {
+        const received: Array<StreamChunk> = []
+        for await (const chunk of adapter.subscribe(abortController.signal)) {
+          received.push(chunk)
+          if (received.length === 2) {
+            abortController.abort()
+          }
+        }
+        return received
+      })()
+
+      await adapter.send([{ role: 'user', content: 'Hello' }])
+      const received = await receivedPromise
+
+      expect(received).toHaveLength(2)
+      expect(received[1]?.type).toBe('RUN_FINISHED')
+    })
+
+    it('should synthesize RUN_ERROR when wrapped connect stream throws', async () => {
+      const base = stream(async function* () {
+        throw new Error('connect exploded')
+      })
+
+      const adapter = normalizeConnectionAdapter(base)
+      const abortController = new AbortController()
+      const receivedPromise = (async () => {
+        const received: Array<StreamChunk> = []
+        for await (const chunk of adapter.subscribe(abortController.signal)) {
+          received.push(chunk)
+          if (received.length === 1) {
+            abortController.abort()
+          }
+        }
+        return received
+      })()
+
+      await expect(
+        adapter.send([{ role: 'user', content: 'Hello' }]),
+      ).rejects.toThrow('connect exploded')
+      const received = await receivedPromise
+
+      expect(received).toHaveLength(1)
+      expect(received[0]?.type).toBe('RUN_ERROR')
+    })
+
+    it('should not synthesize duplicate RUN_ERROR when stream already emitted one before throwing', async () => {
+      const base = stream(async function* () {
+        yield {
+          type: 'RUN_ERROR',
+          timestamp: Date.now(),
+          error: {
+            message: 'already failed',
+          },
+        }
+        throw new Error('connect exploded')
+      })
+
+      const adapter = normalizeConnectionAdapter(base)
+      const abortController = new AbortController()
+      const receivedPromise = (async () => {
+        const received: Array<StreamChunk> = []
+        for await (const chunk of adapter.subscribe(abortController.signal)) {
+          received.push(chunk)
+          if (received.length === 1) {
+            abortController.abort()
+          }
+        }
+        return received
+      })()
+
+      await expect(
+        adapter.send([{ role: 'user', content: 'Hello' }]),
+      ).rejects.toThrow('connect exploded')
+      const received = await receivedPromise
+
+      expect(received).toHaveLength(1)
+      expect(received[0]?.type).toBe('RUN_ERROR')
+      if (received[0]?.type === 'RUN_ERROR') {
+        expect(received[0].error.message).toBe('already failed')
+      }
+    })
+  })
 })
diff --git a/packages/typescript/ai-client/tests/test-utils.ts b/packages/typescript/ai-client/tests/test-utils.ts
index 9d0b3d364..3a8a4f805 100644
--- a/packages/typescript/ai-client/tests/test-utils.ts
+++ b/packages/typescript/ai-client/tests/test-utils.ts
@@ -1,4 +1,4 @@
-import type { ConnectionAdapter } from '../src/connection-adapters'
+import type { ConnectConnectionAdapter } from '../src/connection-adapters'
 import type { ModelMessage, StreamChunk } from '@tanstack/ai'
 import type { UIMessage } from '../src/types'
 /**
@@ -55,7 +55,7 @@ interface MockConnectionAdapterOptions {
  */
 export function createMockConnectionAdapter(
   options: MockConnectionAdapterOptions = {},
-): ConnectionAdapter {
+): ConnectConnectionAdapter {
   const {
     chunks = [],
     chunkDelay = 0,
diff --git a/packages/typescript/ai/src/activities/chat/stream/processor.ts b/packages/typescript/ai/src/activities/chat/stream/processor.ts
index 96d95865d..555c46d41 100644
--- a/packages/typescript/ai/src/activities/chat/stream/processor.ts
+++ b/packages/typescript/ai/src/activities/chat/stream/processor.ts
@@ -12,6 +12,7 @@
  * - Thinking/reasoning content
  * - Recording/replay for testing
  * - Event-driven architecture for UI updates
+ * - Per-message stream state tracking for multi-message sessions
  *
  * @see docs/chat-architecture.md — Canonical reference for AG-UI chunk ordering,
  *   adapter contract, single-shot flows, and expected UIMessage output.
@@ -32,6 +33,7 @@ import type {
   ChunkRecording,
   ChunkStrategy,
   InternalToolCallState,
+  MessageStreamState,
   ProcessorResult,
   ProcessorState,
   ToolCallState,
@@ -109,9 +111,8 @@ export interface StreamProcessorOptions {
  *
  * State tracking:
  * - Full message array
- * - Current assistant message being streamed
- * - Text content accumulation (reset on TEXT_MESSAGE_START)
- * - Multiple parallel tool calls
+ * - Per-message stream state (text, tool calls, thinking)
+ * - Multiple concurrent message streams
  * - Tool call completion via TOOL_CALL_END events
  *
  * @see docs/chat-architecture.md#streamprocessor-internal-state — State field reference
@@ -125,17 +126,14 @@ export class StreamProcessor {
   // Message state
   private messages: Array<UIMessage> = []
-  private currentAssistantMessageId: string | null = null
-
-  // Stream state for current assistant message
-  // Total accumulated text across all segments (for the final result)
-  private totalTextContent = ''
-  // Current segment's text content (for onTextUpdate callbacks)
-  private currentSegmentText = ''
-  private lastEmittedText = ''
-  private thinkingContent = ''
-  private toolCalls: Map<string, InternalToolCallState> = new Map()
-  private toolCallOrder: Array<string> = []
+
+  // Per-message stream state
+  private messageStates: Map<string, MessageStreamState> = new Map()
+  private activeMessageIds: Set<string> = new Set()
+  private toolCallToMessage: Map<string, string> = new Map()
+  private pendingManualMessageId: string | null = null
+
+  // Shared stream state
   private finishReason: string | null = null
   private hasError = false
   private isDone = false
@@ -224,18 +222,17 @@
   prepareAssistantMessage(): void {
     // Reset stream state for new message
     this.resetStreamState()
-    // Clear the current assistant message ID so ensureAssistantMessage()
-    // will create a fresh message on the first content chunk
-    this.currentAssistantMessageId = null
   }

   /**
    * @deprecated Use prepareAssistantMessage() instead. This eagerly creates
    * an assistant message which can cause empty message flicker.
    */
-  startAssistantMessage(): string {
+  startAssistantMessage(messageId?: string): string {
     this.prepareAssistantMessage()
-    return this.ensureAssistantMessage()
+    const { messageId: id } = this.ensureAssistantMessage(messageId)
+    this.pendingManualMessageId = id
+    return id
   }

   /**
@@ -244,39 +241,16 @@
    * has arrived yet.
*/ getCurrentAssistantMessageId(): string | null { - return this.currentAssistantMessageId - } - - /** - * Lazily create the assistant message if it hasn't been created yet. - * Called by content handlers on the first content-bearing chunk. - * Returns the message ID. - * - * Content-bearing chunks that trigger this: - * TEXT_MESSAGE_CONTENT, TOOL_CALL_START, STEP_FINISHED, RUN_ERROR. - * - * @see docs/chat-architecture.md#streamprocessor-internal-state — Lazy creation pattern - */ - private ensureAssistantMessage(): string { - if (this.currentAssistantMessageId) { - return this.currentAssistantMessageId - } - - const assistantMessage: UIMessage = { - id: generateMessageId(), - role: 'assistant', - parts: [], - createdAt: new Date(), + // Scan all message states (not just active) for the last assistant. + // After finalizeStream() clears activeMessageIds, messageStates retains entries. + // After reset() / resetStreamState(), messageStates is cleared → returns null. + let lastId: string | null = null + for (const [id, state] of this.messageStates) { + if (state.role === 'assistant') { + lastId = id + } } - - this.currentAssistantMessageId = assistantMessage.id - this.messages = [...this.messages, assistantMessage] - - // Emit events - this.events.onStreamStart?.() - this.emitMessagesChange() - - return assistantMessage.id + return lastId } /** @@ -403,7 +377,10 @@ export class StreamProcessor { */ clearMessages(): void { this.messages = [] - this.currentAssistantMessageId = null + this.messageStates.clear() + this.activeMessageIds.clear() + this.toolCallToMessage.clear() + this.pendingManualMessageId = null this.emitMessagesChange() } @@ -444,7 +421,7 @@ export class StreamProcessor { * * Central dispatch for all AG-UI events. Each event type maps to a specific * handler. Events not listed in the switch are intentionally ignored - * (RUN_STARTED, TEXT_MESSAGE_END, STEP_STARTED, STATE_SNAPSHOT, STATE_DELTA). + * (RUN_STARTED, STEP_STARTED, STATE_DELTA). * * @see docs/chat-architecture.md#adapter-contract — Expected event types and ordering */ @@ -461,13 +438,17 @@ export class StreamProcessor { switch (chunk.type) { // AG-UI Events case 'TEXT_MESSAGE_START': - this.handleTextMessageStartEvent() + this.handleTextMessageStartEvent(chunk) break case 'TEXT_MESSAGE_CONTENT': this.handleTextMessageContentEvent(chunk) break + case 'TEXT_MESSAGE_END': + this.handleTextMessageEndEvent(chunk) + break + case 'TOOL_CALL_START': this.handleToolCallStartEvent(chunk) break @@ -492,35 +473,230 @@ export class StreamProcessor { this.handleStepFinishedEvent(chunk) break + case 'MESSAGES_SNAPSHOT': + this.handleMessagesSnapshotEvent(chunk) + break + case 'CUSTOM': this.handleCustomEvent(chunk) break default: - // RUN_STARTED, TEXT_MESSAGE_END, STEP_STARTED, - // STATE_SNAPSHOT, STATE_DELTA - no special handling needed + // RUN_STARTED, STEP_STARTED, STATE_SNAPSHOT, STATE_DELTA - no special handling needed break } } + // ============================================ + // Per-Message State Helpers + // ============================================ + /** - * Handle TEXT_MESSAGE_START event — marks the beginning of a new text segment. - * Resets segment accumulation so text after tool calls starts fresh. - * - * This is the key mechanism for multi-segment text (text before and after tool - * calls becoming separate TextParts). Without this reset, all text would merge - * into a single TextPart and tool-call interleaving would be lost. 
- * - * @see docs/chat-architecture.md#single-shot-text-response — Step-by-step text processing - * @see docs/chat-architecture.md#text-then-tool-interleaving-single-shot — Multi-segment text + * Create a new MessageStreamState for a message + */ + private createMessageState( + messageId: string, + role: 'user' | 'assistant' | 'system', + ): MessageStreamState { + const state: MessageStreamState = { + id: messageId, + role, + totalTextContent: '', + currentSegmentText: '', + lastEmittedText: '', + thinkingContent: '', + toolCalls: new Map(), + toolCallOrder: [], + hasToolCallsSinceTextStart: false, + isComplete: false, + } + this.messageStates.set(messageId, state) + return state + } + + /** + * Get the MessageStreamState for a message */ - private handleTextMessageStartEvent(): void { - // Emit any pending text from a previous segment before resetting - if (this.currentSegmentText !== this.lastEmittedText) { - this.emitTextUpdate() + private getMessageState(messageId: string): MessageStreamState | undefined { + return this.messageStates.get(messageId) + } + + /** + * Get the most recent active assistant message ID. + * Used as fallback for events that don't include a messageId. + */ + private getActiveAssistantMessageId(): string | null { + // Set iteration is insertion-order; convert to array and search from the end + const ids = Array.from(this.activeMessageIds) + for (let i = ids.length - 1; i >= 0; i--) { + const id = ids[i]! + const state = this.messageStates.get(id) + if (state && state.role === 'assistant') { + return id + } } - this.currentSegmentText = '' - this.lastEmittedText = '' + return null + } + + /** + * Ensure an active assistant message exists, creating one if needed. + * Used for backward compat when events arrive without prior TEXT_MESSAGE_START. + */ + private ensureAssistantMessage(preferredId?: string): { + messageId: string + state: MessageStreamState + } { + // Try to find state by preferred ID + if (preferredId) { + const state = this.getMessageState(preferredId) + if (state) return { messageId: preferredId, state } + } + + // Try active assistant message + const activeId = this.getActiveAssistantMessageId() + if (activeId) { + const state = this.getMessageState(activeId)! + return { messageId: activeId, state } + } + + // Auto-create an assistant message (backward compat for process() without TEXT_MESSAGE_START) + const id = preferredId || generateMessageId() + const assistantMessage: UIMessage = { + id, + role: 'assistant', + parts: [], + createdAt: new Date(), + } + this.messages = [...this.messages, assistantMessage] + const state = this.createMessageState(id, 'assistant') + this.activeMessageIds.add(id) + this.pendingManualMessageId = id + this.events.onStreamStart?.() + this.emitMessagesChange() + return { messageId: id, state } + } + + // ============================================ + // Event Handlers + // ============================================ + + /** + * Handle TEXT_MESSAGE_START event + */ + private handleTextMessageStartEvent( + chunk: Extract, + ): void { + const { messageId, role } = chunk + + // Map 'tool' role to 'assistant' for both UIMessage and MessageStreamState + // (UIMessage doesn't support 'tool' role, and lookups like + // getActiveAssistantMessageId() check state.role === 'assistant') + const uiRole: 'system' | 'user' | 'assistant' = + role === 'tool' ? 
'assistant' : role + + // Case 1: A manual message was created via startAssistantMessage() + if (this.pendingManualMessageId) { + const pendingId = this.pendingManualMessageId + this.pendingManualMessageId = null + + if (pendingId !== messageId) { + // Update the message's ID in the messages array + this.messages = this.messages.map((msg) => + msg.id === pendingId ? { ...msg, id: messageId } : msg, + ) + + // Move state to the new key + const existingState = this.messageStates.get(pendingId) + if (existingState) { + existingState.id = messageId + this.messageStates.delete(pendingId) + this.messageStates.set(messageId, existingState) + } + + // Update activeMessageIds + this.activeMessageIds.delete(pendingId) + this.activeMessageIds.add(messageId) + } + + // Ensure state exists + if (!this.messageStates.has(messageId)) { + this.createMessageState(messageId, uiRole) + this.activeMessageIds.add(messageId) + } + + this.emitMessagesChange() + return + } + + // Case 2: Message already exists (dedup) + const existingMsg = this.messages.find((m) => m.id === messageId) + if (existingMsg) { + this.activeMessageIds.add(messageId) + if (!this.messageStates.has(messageId)) { + this.createMessageState(messageId, uiRole) + } else { + const existingState = this.messageStates.get(messageId)! + // If tool calls happened since last text, this TEXT_MESSAGE_START + // signals a new text segment — reset segment accumulation + if (existingState.hasToolCallsSinceTextStart) { + if ( + existingState.currentSegmentText !== existingState.lastEmittedText + ) { + this.emitTextUpdateForMessage(messageId) + } + existingState.currentSegmentText = '' + existingState.lastEmittedText = '' + existingState.hasToolCallsSinceTextStart = false + } + } + return + } + + // Case 3: New message from the stream + const newMessage: UIMessage = { + id: messageId, + role: uiRole, + parts: [], + createdAt: new Date(), + } + + this.messages = [...this.messages, newMessage] + this.createMessageState(messageId, uiRole) + this.activeMessageIds.add(messageId) + + this.events.onStreamStart?.() + this.emitMessagesChange() + } + + /** + * Handle TEXT_MESSAGE_END event + */ + private handleTextMessageEndEvent( + chunk: Extract, + ): void { + const { messageId } = chunk + const state = this.getMessageState(messageId) + if (!state) return + if (state.isComplete) return + + // Emit any pending text for this message + if (state.currentSegmentText !== state.lastEmittedText) { + this.emitTextUpdateForMessage(messageId) + } + + // Complete all tool calls for this message + this.completeAllToolCallsForMessage(messageId) + } + + /** + * Handle MESSAGES_SNAPSHOT event + */ + private handleMessagesSnapshotEvent( + chunk: Extract, + ): void { + this.resetStreamState() + this.messages = [...chunk.messages] + this.emitMessagesChange() } /** @@ -537,17 +713,62 @@ export class StreamProcessor { private handleTextMessageContentEvent( chunk: Extract, ): void { - this.ensureAssistantMessage() + const { messageId, state } = this.ensureAssistantMessage(chunk.messageId) + + // Content arriving means all current tool calls for this message are complete + this.completeAllToolCallsForMessage(messageId) + + const previousSegment = state.currentSegmentText + + // Detect if this is a NEW text segment (after tool calls) vs continuation + const isNewSegment = + state.hasToolCallsSinceTextStart && + previousSegment.length > 0 && + this.isNewTextSegment(chunk, previousSegment) + + if (isNewSegment) { + // Emit any accumulated text before starting new segment + if 
(previousSegment !== state.lastEmittedText) { + this.emitTextUpdateForMessage(messageId) + } + // Reset SEGMENT text accumulation for the new text segment after tool calls + state.currentSegmentText = '' + state.lastEmittedText = '' + state.hasToolCallsSinceTextStart = false + } + + const currentText = state.currentSegmentText + let nextText = currentText + + // Prefer delta over content - delta is the incremental change + // Normalize to empty string to avoid "undefined" string concatenation + const delta = chunk.delta || '' + if (delta !== '') { + nextText = currentText + delta + } else if (chunk.content !== undefined && chunk.content !== '') { + // Fallback: use content if delta is not provided + if (chunk.content.startsWith(currentText)) { + nextText = chunk.content + } else if (currentText.startsWith(chunk.content)) { + nextText = currentText + } else { + nextText = currentText + chunk.content + } + } - this.currentSegmentText += chunk.delta - this.totalTextContent += chunk.delta + // Calculate the delta for totalTextContent + const textDelta = nextText.slice(currentText.length) + state.currentSegmentText = nextText + state.totalTextContent += textDelta + // Use delta for chunk strategy if available + const chunkPortion = chunk.delta || chunk.content || '' const shouldEmit = this.chunkStrategy.shouldEmit( - chunk.delta, - this.currentSegmentText, + chunkPortion, + state.currentSegmentText, ) - if (shouldEmit && this.currentSegmentText !== this.lastEmittedText) { - this.emitTextUpdate() + if (shouldEmit && state.currentSegmentText !== state.lastEmittedText) { + this.emitTextUpdateForMessage(messageId) } } @@ -567,10 +788,18 @@ export class StreamProcessor { private handleToolCallStartEvent( chunk: Extract, ): void { - this.ensureAssistantMessage() + // Determine the message this tool call belongs to + const targetMessageId = + chunk.parentMessageId ?? this.getActiveAssistantMessageId() + const { messageId, state } = this.ensureAssistantMessage( + targetMessageId ?? undefined, + ) + + // Mark that we've seen tool calls since the last text segment + state.hasToolCallsSinceTextStart = true const toolCallId = chunk.toolCallId - const existingToolCall = this.toolCalls.get(toolCallId) + const existingToolCall = state.toolCalls.get(toolCallId) if (!existingToolCall) { // New tool call starting @@ -582,34 +811,31 @@ export class StreamProcessor { arguments: '', state: initialState, parsedArguments: undefined, - index: chunk.index ?? this.toolCalls.size, + index: chunk.index ?? 
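+          // fall back to arrival order within this message when the provider omits an index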
state.toolCalls.size, } - this.toolCalls.set(toolCallId, newToolCall) - this.toolCallOrder.push(toolCallId) + state.toolCalls.set(toolCallId, newToolCall) + state.toolCallOrder.push(toolCallId) + + // Store mapping for TOOL_CALL_ARGS/END routing + this.toolCallToMessage.set(toolCallId, messageId) // Update UIMessage - if (this.currentAssistantMessageId) { - this.messages = updateToolCallPart( - this.messages, - this.currentAssistantMessageId, - { - id: chunk.toolCallId, - name: chunk.toolName, - arguments: '', - state: initialState, - }, - ) - this.emitMessagesChange() + this.messages = updateToolCallPart(this.messages, messageId, { + id: chunk.toolCallId, + name: chunk.toolName, + arguments: '', + state: initialState, + }) + this.emitMessagesChange() - // Emit granular event - this.events.onToolCallStateChange?.( - this.currentAssistantMessageId, - chunk.toolCallId, - initialState, - '', - ) - } + // Emit granular event + this.events.onToolCallStateChange?.( + messageId, + chunk.toolCallId, + initialState, + '', + ) } } @@ -629,47 +855,46 @@ export class StreamProcessor { chunk: Extract, ): void { const toolCallId = chunk.toolCallId - const existingToolCall = this.toolCalls.get(toolCallId) + const messageId = this.toolCallToMessage.get(toolCallId) + if (!messageId) return - if (existingToolCall) { - const wasAwaitingInput = existingToolCall.state === 'awaiting-input' + const state = this.getMessageState(messageId) + if (!state) return - // Accumulate arguments from delta - existingToolCall.arguments += chunk.delta || '' + const existingToolCall = state.toolCalls.get(toolCallId) + if (!existingToolCall) return - // Update state - if (wasAwaitingInput && chunk.delta) { - existingToolCall.state = 'input-streaming' - } + const wasAwaitingInput = existingToolCall.state === 'awaiting-input' - // Try to parse the updated arguments - existingToolCall.parsedArguments = this.jsonParser.parse( - existingToolCall.arguments, - ) - - // Update UIMessage - if (this.currentAssistantMessageId) { - this.messages = updateToolCallPart( - this.messages, - this.currentAssistantMessageId, - { - id: existingToolCall.id, - name: existingToolCall.name, - arguments: existingToolCall.arguments, - state: existingToolCall.state, - }, - ) - this.emitMessagesChange() + // Accumulate arguments from delta + existingToolCall.arguments += chunk.delta || '' - // Emit granular event - this.events.onToolCallStateChange?.( - this.currentAssistantMessageId, - existingToolCall.id, - existingToolCall.state, - existingToolCall.arguments, - ) - } + // Update state + if (wasAwaitingInput && chunk.delta) { + existingToolCall.state = 'input-streaming' } + + // Try to parse the updated arguments + existingToolCall.parsedArguments = this.jsonParser.parse( + existingToolCall.arguments, + ) + + // Update UIMessage + this.messages = updateToolCallPart(this.messages, messageId, { + id: existingToolCall.id, + name: existingToolCall.name, + arguments: existingToolCall.arguments, + state: existingToolCall.state, + }) + this.emitMessagesChange() + + // Emit granular event + this.events.onToolCallStateChange?.( + messageId, + existingToolCall.id, + existingToolCall.state, + existingToolCall.arguments, + ) } /** @@ -689,11 +914,17 @@ export class StreamProcessor { private handleToolCallEndEvent( chunk: Extract, ): void { + const messageId = this.toolCallToMessage.get(chunk.toolCallId) + if (!messageId) return + + const msgState = this.getMessageState(messageId) + if (!msgState) return + // Transition the tool call to input-complete (the 
authoritative completion signal) - const existingToolCall = this.toolCalls.get(chunk.toolCallId) + const existingToolCall = msgState.toolCalls.get(chunk.toolCallId) if (existingToolCall && existingToolCall.state !== 'input-complete') { - const index = this.toolCallOrder.indexOf(chunk.toolCallId) - this.completeToolCall(index, existingToolCall) + const index = msgState.toolCallOrder.indexOf(chunk.toolCallId) + this.completeToolCall(messageId, index, existingToolCall) // If TOOL_CALL_END provides parsed input, use it as the canonical parsed // arguments (overrides the accumulated string parse from completeToolCall) if (chunk.input !== undefined) { @@ -701,10 +932,8 @@ export class StreamProcessor { } } - // Update UIMessage if we have a current assistant message and a result - if (this.currentAssistantMessageId && chunk.result) { - const state: ToolResultState = 'complete' - + // Update UIMessage if there's a result + if (chunk.result) { // Step 1: Update the tool-call part's output field (for UI consistency // with client tools — see GitHub issue #176) let output: unknown @@ -720,12 +949,13 @@ export class StreamProcessor { ) // Step 2: Create/update the tool-result part (for LLM conversation history) + const resultState: ToolResultState = 'complete' this.messages = updateToolResultPart( this.messages, - this.currentAssistantMessageId, + messageId, chunk.toolCallId, chunk.result, - state, + resultState, ) this.emitMessagesChange() } @@ -747,6 +977,7 @@ export class StreamProcessor { this.finishReason = chunk.finishReason this.isDone = true this.completeAllToolCalls() + this.finalizeStream() } /** @@ -772,25 +1003,38 @@ export class StreamProcessor { private handleStepFinishedEvent( chunk: Extract, ): void { - this.ensureAssistantMessage() + const { messageId, state } = this.ensureAssistantMessage( + this.getActiveAssistantMessageId() ?? 
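+      // STEP_FINISHED carries no messageId; route thinking to the active assistant message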
undefined, + ) + + const previous = state.thinkingContent + let nextThinking = previous + + // Prefer delta over content + if (chunk.delta && chunk.delta !== '') { + nextThinking = previous + chunk.delta + } else if (chunk.content && chunk.content !== '') { + if (chunk.content.startsWith(previous)) { + nextThinking = chunk.content + } else if (previous.startsWith(chunk.content)) { + nextThinking = previous + } else { + nextThinking = previous + chunk.content + } + } - this.thinkingContent += chunk.delta + state.thinkingContent = nextThinking // Update UIMessage - if (this.currentAssistantMessageId) { - this.messages = updateThinkingPart( - this.messages, - this.currentAssistantMessageId, - this.thinkingContent, - ) - this.emitMessagesChange() + this.messages = updateThinkingPart( + this.messages, + messageId, + state.thinkingContent, + ) + this.emitMessagesChange() - // Emit granular event - this.events.onThinkingUpdate?.( - this.currentAssistantMessageId, - this.thinkingContent, - ) - } + // Emit granular event + this.events.onThinkingUpdate?.(messageId, state.thinkingContent) } /** @@ -806,6 +1050,8 @@ export class StreamProcessor { private handleCustomEvent( chunk: Extract, ): void { + const messageId = this.getActiveAssistantMessageId() + // Handle client tool input availability - trigger client-side execution if (chunk.name === 'tool-input-available' && chunk.data) { const { toolCallId, toolName, input } = chunk.data as { @@ -832,10 +1078,10 @@ export class StreamProcessor { } // Update the tool call part with approval state - if (this.currentAssistantMessageId) { + if (messageId) { this.messages = updateToolCallApproval( this.messages, - this.currentAssistantMessageId, + messageId, toolCallId, approval.id, ) @@ -852,8 +1098,34 @@ export class StreamProcessor { } } + // ============================================ + // Internal Helpers + // ============================================ + + /** + * Detect if an incoming content chunk represents a NEW text segment + */ + private isNewTextSegment( + chunk: Extract, + previous: string, + ): boolean { + // Check if content is present (delta is always defined but may be empty string) + if (chunk.content !== undefined) { + if (chunk.content.length < previous.length) { + return true + } + if ( + !chunk.content.startsWith(previous) && + !previous.startsWith(chunk.content) + ) { + return true + } + } + return false + } + /** - * Complete all tool calls — safety net for stream termination. + * Complete all tool calls across all active messages — safety net for stream termination. * * Called by RUN_FINISHED and finalizeStream(). Force-transitions any tool call * not yet in input-complete state. 
Handles cases where TOOL_CALL_END was @@ -862,10 +1134,22 @@ export class StreamProcessor { * @see docs/chat-architecture.md#single-shot-tool-call-response — Safety net behavior */ private completeAllToolCalls(): void { - this.toolCalls.forEach((toolCall, id) => { + for (const messageId of this.activeMessageIds) { + this.completeAllToolCallsForMessage(messageId) + } + } + + /** + * Complete all tool calls for a specific message + */ + private completeAllToolCallsForMessage(messageId: string): void { + const state = this.getMessageState(messageId) + if (!state) return + + state.toolCalls.forEach((toolCall, id) => { if (toolCall.state !== 'input-complete') { - const index = this.toolCallOrder.indexOf(id) - this.completeToolCall(index, toolCall) + const index = state.toolCallOrder.indexOf(id) + this.completeToolCall(messageId, index, toolCall) } }) } @@ -874,6 +1158,7 @@ export class StreamProcessor { * Mark a tool call as complete and emit event */ private completeToolCall( + messageId: string, _index: number, toolCall: InternalToolCallState, ): void { @@ -883,31 +1168,25 @@ export class StreamProcessor { toolCall.parsedArguments = this.jsonParser.parse(toolCall.arguments) // Update UIMessage - if (this.currentAssistantMessageId) { - this.messages = updateToolCallPart( - this.messages, - this.currentAssistantMessageId, - { - id: toolCall.id, - name: toolCall.name, - arguments: toolCall.arguments, - state: 'input-complete', - }, - ) - this.emitMessagesChange() + this.messages = updateToolCallPart(this.messages, messageId, { + id: toolCall.id, + name: toolCall.name, + arguments: toolCall.arguments, + state: 'input-complete', + }) + this.emitMessagesChange() - // Emit granular event - this.events.onToolCallStateChange?.( - this.currentAssistantMessageId, - toolCall.id, - 'input-complete', - toolCall.arguments, - ) - } + // Emit granular event + this.events.onToolCallStateChange?.( + messageId, + toolCall.id, + 'input-complete', + toolCall.arguments, + ) } /** - * Emit pending text update. + * Emit pending text update for a specific message. * * Calls updateTextPart() which has critical append-vs-replace logic: * - If last UIMessage part is TextPart → replaces its content (same segment). @@ -915,24 +1194,22 @@ export class StreamProcessor { * * @see docs/chat-architecture.md#uimessage-part-ordering-invariants — Replace vs. push logic */ - private emitTextUpdate(): void { - this.lastEmittedText = this.currentSegmentText + private emitTextUpdateForMessage(messageId: string): void { + const state = this.getMessageState(messageId) + if (!state) return + + state.lastEmittedText = state.currentSegmentText // Update UIMessage - if (this.currentAssistantMessageId) { - this.messages = updateTextPart( - this.messages, - this.currentAssistantMessageId, - this.currentSegmentText, - ) - this.emitMessagesChange() + this.messages = updateTextPart( + this.messages, + messageId, + state.currentSegmentText, + ) + this.emitMessagesChange() - // Emit granular event - this.events.onTextUpdate?.( - this.currentAssistantMessageId, - this.currentSegmentText, - ) - } + // Emit granular event + this.events.onTextUpdate?.(messageId, state.currentSegmentText) } /** @@ -952,81 +1229,116 @@ export class StreamProcessor { * @see docs/chat-architecture.md#single-shot-text-response — Finalization step */ finalizeStream(): void { - // Safety net: complete any remaining tool calls (e.g. 
   /**
@@ -952,81 +1229,116 @@
    * @see docs/chat-architecture.md#single-shot-text-response — Finalization step
    */
   finalizeStream(): void {
-    // Safety net: complete any remaining tool calls (e.g. on network errors / aborted streams)
-    this.completeAllToolCalls()
+    let lastAssistantMessage: UIMessage | undefined

-    // Emit any pending text if not already emitted
-    if (this.currentSegmentText !== this.lastEmittedText) {
-      this.emitTextUpdate()
+    // Finalize ALL active messages
+    for (const messageId of this.activeMessageIds) {
+      const state = this.getMessageState(messageId)
+      if (!state) continue
+
+      // Complete any remaining tool calls
+      this.completeAllToolCallsForMessage(messageId)
+
+      // Emit any pending text if not already emitted
+      if (state.currentSegmentText !== state.lastEmittedText) {
+        this.emitTextUpdateForMessage(messageId)
+      }
+
+      state.isComplete = true
+
+      const msg = this.messages.find((m) => m.id === messageId)
+      if (msg && msg.role === 'assistant') {
+        lastAssistantMessage = msg
+      }
     }

-    // Remove the assistant message if it only contains whitespace text
-    // (no tool calls, no meaningful content). This handles models like Gemini
-    // that sometimes return just "\n" during auto-continuation.
+    this.activeMessageIds.clear()
+
+    // Remove whitespace-only assistant messages (handles models like Gemini
+    // that sometimes return just "\n" during auto-continuation).
     // Preserve the message on errors so the UI can show error state.
-    if (this.currentAssistantMessageId && !this.hasError) {
-      const assistantMessage = this.messages.find(
-        (m) => m.id === this.currentAssistantMessageId,
-      )
-      if (assistantMessage && this.isWhitespaceOnlyMessage(assistantMessage)) {
+    if (lastAssistantMessage && !this.hasError) {
+      if (this.isWhitespaceOnlyMessage(lastAssistantMessage)) {
+        // Capture the id in a const: the filter callback must not close over
+        // the mutable `let` binding, which TypeScript would otherwise widen
+        // back to `UIMessage | undefined` inside the closure.
+        const removedId = lastAssistantMessage.id
         this.messages = this.messages.filter(
-          (m) => m.id !== this.currentAssistantMessageId,
+          (m) => m.id !== removedId,
         )
         this.emitMessagesChange()
-        this.currentAssistantMessageId = null
         return
       }
     }

-    // Emit stream end event (only if a message was actually created)
-    if (this.currentAssistantMessageId) {
-      const assistantMessage = this.messages.find(
-        (m) => m.id === this.currentAssistantMessageId,
-      )
-      if (assistantMessage) {
-        this.events.onStreamEnd?.(assistantMessage)
-      }
+    // Emit stream end for the last assistant message
+    if (lastAssistantMessage) {
+      this.events.onStreamEnd?.(lastAssistantMessage)
     }
   }

   /**
-   * Get completed tool calls in API format
+   * Get completed tool calls in API format (aggregated across all messages)
    */
   private getCompletedToolCalls(): Array<ToolCall> {
-    return Array.from(this.toolCalls.values())
-      .filter((tc) => tc.state === 'input-complete')
-      .map((tc) => ({
-        id: tc.id,
-        type: 'function' as const,
-        function: {
-          name: tc.name,
-          arguments: tc.arguments,
-        },
-      }))
+    const result: Array<ToolCall> = []
+    for (const state of this.messageStates.values()) {
+      for (const tc of state.toolCalls.values()) {
+        if (tc.state === 'input-complete') {
+          result.push({
+            id: tc.id,
+            type: 'function' as const,
+            function: {
+              name: tc.name,
+              arguments: tc.arguments,
+            },
+          })
+        }
+      }
+    }
+    return result
   }
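`finalizeStream` defers the whitespace judgment to `isWhitespaceOnlyMessage`, which this patch does not touch. For context, a plausible shape for that check, hypothetical and not the package's actual helper: any non-text part keeps the message, and only text parts that trim to the empty string count as whitespace.

```ts
type SketchPart =
  | { type: 'text'; content: string }
  | { type: 'tool-call'; id: string; name: string }

// Hypothetical: true when every part is a text part whose content trims to
// ''. An empty parts array also counts as whitespace-only under this rule,
// and a single tool-call part keeps the message alive.
function isWhitespaceOnly(parts: Array<SketchPart>): boolean {
  return parts.every(
    (part) => part.type === 'text' && part.content.trim() === '',
  )
}
```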
   /**
-   * Get current result
+   * Get current result (aggregated across all messages)
    */
   private getResult(): ProcessorResult {
     const toolCalls = this.getCompletedToolCalls()

+    let content = ''
+    let thinking = ''
+
+    for (const state of this.messageStates.values()) {
+      content += state.totalTextContent
+      thinking += state.thinkingContent
+    }
+
     return {
-      content: this.totalTextContent,
-      thinking: this.thinkingContent || undefined,
+      content,
+      thinking: thinking || undefined,
       toolCalls: toolCalls.length > 0 ? toolCalls : undefined,
       finishReason: this.finishReason,
     }
   }

   /**
-   * Get current processor state
+   * Get current processor state (aggregated across all messages)
    */
   getState(): ProcessorState {
+    let content = ''
+    let thinking = ''
+    const toolCalls = new Map<string, InternalToolCallState>()
+    const toolCallOrder: Array<string> = []
+
+    for (const state of this.messageStates.values()) {
+      content += state.totalTextContent
+      thinking += state.thinkingContent
+      for (const [id, tc] of state.toolCalls) {
+        toolCalls.set(id, tc)
+      }
+      toolCallOrder.push(...state.toolCallOrder)
+    }
+
     return {
-      content: this.totalTextContent,
-      thinking: this.thinkingContent,
-      toolCalls: new Map(this.toolCalls),
-      toolCallOrder: [...this.toolCallOrder],
+      content,
+      thinking,
+      toolCalls,
+      toolCallOrder,
       finishReason: this.finishReason,
       done: this.isDone,
     }
@@ -1056,12 +1368,10 @@
    * Reset stream state (but keep messages)
    */
   private resetStreamState(): void {
-    this.totalTextContent = ''
-    this.currentSegmentText = ''
-    this.lastEmittedText = ''
-    this.thinkingContent = ''
-    this.toolCalls.clear()
-    this.toolCallOrder = []
+    this.messageStates.clear()
+    this.activeMessageIds.clear()
+    this.toolCallToMessage.clear()
+    this.pendingManualMessageId = null
     this.finishReason = null
     this.hasError = false
     this.isDone = false
@@ -1074,7 +1384,6 @@
   reset(): void {
     this.resetStreamState()
     this.messages = []
-    this.currentAssistantMessageId = null
   }

   /**
diff --git a/packages/typescript/ai/src/activities/chat/stream/types.ts b/packages/typescript/ai/src/activities/chat/stream/types.ts
index 2a323507c..c1806238f 100644
--- a/packages/typescript/ai/src/activities/chat/stream/types.ts
+++ b/packages/typescript/ai/src/activities/chat/stream/types.ts
@@ -45,6 +45,24 @@ export interface ChunkStrategy {
   reset?: () => void
 }

+/**
+ * Per-message streaming state.
+ * Tracks the accumulation of text, tool calls, and thinking content
+ * for a single message in the stream.
+ */
+export interface MessageStreamState {
+  id: string
+  role: 'user' | 'assistant' | 'system'
+  totalTextContent: string
+  currentSegmentText: string
+  lastEmittedText: string
+  thinkingContent: string
+  toolCalls: Map<string, InternalToolCallState>
+  toolCallOrder: Array<string>
+  hasToolCallsSinceTextStart: boolean
+  isComplete: boolean
+}
+
 /**
  * Result from processing a stream
  */
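Every field of `MessageStreamState` starts from a zero value the first time the processor sees a message id. A plausible initializer, assuming `InternalToolCallState` is exported from the same types module (the import path and factory name are illustrative, not part of this patch):

```ts
import type { InternalToolCallState, MessageStreamState } from './types'

// Fresh per-message state: empty accumulators, no tool calls, not complete.
function createMessageState(
  id: string,
  role: MessageStreamState['role'],
): MessageStreamState {
  return {
    id,
    role,
    totalTextContent: '',
    currentSegmentText: '',
    lastEmittedText: '',
    thinkingContent: '',
    toolCalls: new Map<string, InternalToolCallState>(),
    toolCallOrder: [],
    hasToolCallsSinceTextStart: false,
    isComplete: false,
  }
}
```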
diff --git a/packages/typescript/ai/src/types.ts b/packages/typescript/ai/src/types.ts
index 4d7ca6e52..390abfa98 100644
--- a/packages/typescript/ai/src/types.ts
+++ b/packages/typescript/ai/src/types.ts
@@ -702,6 +702,7 @@ export type AGUIEventType =
   | 'TOOL_CALL_END'
   | 'STEP_STARTED'
   | 'STEP_FINISHED'
+  | 'MESSAGES_SNAPSHOT'
   | 'STATE_SNAPSHOT'
   | 'STATE_DELTA'
   | 'CUSTOM'
@@ -778,8 +779,8 @@ export interface TextMessageStartEvent extends BaseAGUIEvent {
   type: 'TEXT_MESSAGE_START'
   /** Unique identifier for this message */
   messageId: string
-  /** Role is always assistant for generated messages */
-  role: 'assistant'
+  /** Role of the message sender */
+  role: 'user' | 'assistant' | 'system' | 'tool'
 }

 /**
@@ -813,6 +814,8 @@ export interface ToolCallStartEvent extends BaseAGUIEvent {
   toolCallId: string
   /** Name of the tool being called */
   toolName: string
+  /** ID of the parent message that initiated this tool call */
+  parentMessageId?: string
   /** Index for parallel tool calls */
   index?: number
 }
@@ -869,6 +872,19 @@ export interface StepFinishedEvent extends BaseAGUIEvent {
   content?: string
 }

+/**
+ * Emitted to provide a snapshot of all messages in a conversation.
+ *
+ * Unlike StateSnapshot (which carries arbitrary application state),
+ * MessagesSnapshot specifically delivers the conversation transcript.
+ * This is a first-class AG-UI event type.
+ */
+export interface MessagesSnapshotEvent extends BaseAGUIEvent {
+  type: 'MESSAGES_SNAPSHOT'
+  /** Complete array of messages in the conversation */
+  messages: Array<UIMessage>
+}
+
 /**
  * Emitted to provide a full state snapshot.
  */
@@ -913,6 +929,7 @@ export type AGUIEvent =
   | ToolCallEndEvent
   | StepStartedEvent
   | StepFinishedEvent
+  | MessagesSnapshotEvent
   | StateSnapshotEvent
   | StateDeltaEvent
   | CustomEvent
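A typical producer of the new event is a reconnect path that replays the stored transcript as a single MESSAGES_SNAPSHOT before resuming live chunks. A hedged sketch: `send` and `loadTranscript` are illustrative stand-ins rather than package APIs, and the event and message types are assumed to be re-exported from the package root.

```ts
import type { MessagesSnapshotEvent, UIMessage } from '@tanstack/ai'

async function replayTranscript(
  conversationId: string,
  send: (event: MessagesSnapshotEvent) => void,
  loadTranscript: (id: string) => Promise<Array<UIMessage>>,
): Promise<void> {
  // One snapshot replaces the client's message list wholesale, so stale
  // streaming state from before the disconnect cannot leak through.
  const messages = await loadTranscript(conversationId)
  send({ type: 'MESSAGES_SNAPSHOT', messages, timestamp: Date.now() })
}
```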
diff --git a/packages/typescript/ai/tests/stream-processor.test.ts b/packages/typescript/ai/tests/stream-processor.test.ts
index ddb7f8129..033afabc6 100644
--- a/packages/typescript/ai/tests/stream-processor.test.ts
+++ b/packages/typescript/ai/tests/stream-processor.test.ts
@@ -621,8 +621,8 @@ describe('StreamProcessor', () => {
     processor.processChunk(ev.textContent('First segment'))
     processor.processChunk(ev.toolStart('tc-1', 'search'))
     processor.processChunk(ev.toolEnd('tc-1', 'search', { input: {} }))
-    processor.processChunk(ev.textStart('msg-2'))
-    processor.processChunk(ev.textContent('Second segment', 'msg-2'))
+    processor.processChunk(ev.textStart())
+    processor.processChunk(ev.textContent('Second segment'))
     processor.processChunk(ev.runFinished('stop'))
     processor.finalizeStream()
@@ -649,10 +649,10 @@
       ev.toolEnd('call_1', 'getWeather', { result: '{"temp":"72F"}' }),
     )

-    // Second adapter stream: more text
-    processor.processChunk(ev.textStart('msg-2'))
-    processor.processChunk(ev.textContent("It's 72F in NYC.", 'msg-2'))
-    processor.processChunk(ev.textEnd('msg-2'))
+    // Second adapter stream: more text (same message)
+    processor.processChunk(ev.textStart())
+    processor.processChunk(ev.textContent("It's 72F in NYC."))
+    processor.processChunk(ev.textEnd())
     processor.processChunk(ev.runFinished('stop'))
     processor.finalizeStream()
@@ -685,9 +685,9 @@
     processor.processChunk(ev.textEnd())
     processor.processChunk(ev.toolStart('tc-1', 'tool'))
     processor.processChunk(ev.toolEnd('tc-1', 'tool'))
-    processor.processChunk(ev.textStart('msg-2'))
-    processor.processChunk(ev.textContent('After', 'msg-2'))
-    processor.processChunk(ev.textEnd('msg-2'))
+    processor.processChunk(ev.textStart())
+    processor.processChunk(ev.textContent('After'))
+    processor.processChunk(ev.textEnd())
     processor.processChunk(ev.runFinished('stop'))
     processor.finalizeStream()
@@ -1798,4 +1798,658 @@
       expect(state2.toolCallOrder).toEqual(['tc-1'])
     })
   })
+
+  describe('TEXT_MESSAGE_START', () => {
+    it('should create a message with correct role and messageId', () => {
+      const processor = new StreamProcessor()
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_START',
+        messageId: 'msg-1',
+        role: 'assistant',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_CONTENT',
+        messageId: 'msg-1',
+        delta: 'Hello',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.finalizeStream()
+
+      const messages = processor.getMessages()
+      expect(messages).toHaveLength(1)
+      expect(messages[0]?.id).toBe('msg-1')
+      expect(messages[0]?.role).toBe('assistant')
+      expect(messages[0]?.parts[0]).toEqual({
+        type: 'text',
+        content: 'Hello',
+      })
+    })
+
+    it('should create a user message via TEXT_MESSAGE_START', () => {
+      const processor = new StreamProcessor()
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_START',
+        messageId: 'user-msg-1',
+        role: 'user',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_END',
+        messageId: 'user-msg-1',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      const messages = processor.getMessages()
+      expect(messages).toHaveLength(1)
+      expect(messages[0]?.id).toBe('user-msg-1')
+      expect(messages[0]?.role).toBe('user')
+    })
+
+    it('should emit onStreamStart when a new message arrives', () => {
+      const onStreamStart = vi.fn()
+      const processor = new StreamProcessor({ events: { onStreamStart } })
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_START',
+        messageId: 'msg-1',
+        role: 'assistant',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      expect(onStreamStart).toHaveBeenCalledTimes(1)
+    })
+  })
+
+  describe('TEXT_MESSAGE_END', () => {
+    it('should not emit onStreamEnd (that happens in finalizeStream)', () => {
+      const onStreamEnd = vi.fn()
+      const processor = new StreamProcessor({ events: { onStreamEnd } })
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_START',
+        messageId: 'msg-1',
+        role: 'assistant',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_CONTENT',
+        messageId: 'msg-1',
+        delta: 'Hello world',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_END',
+        messageId: 'msg-1',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      // TEXT_MESSAGE_END means "text segment done", not "message done"
+      // onStreamEnd fires from finalizeStream(), not TEXT_MESSAGE_END
+      expect(onStreamEnd).not.toHaveBeenCalled()
+
+      processor.finalizeStream()
+
+      expect(onStreamEnd).toHaveBeenCalledTimes(1)
+      const endMessage = onStreamEnd.mock.calls[0]![0] as UIMessage
+      expect(endMessage.id).toBe('msg-1')
+      expect(endMessage.parts[0]).toEqual({
+        type: 'text',
+        content: 'Hello world',
+      })
+    })
+
+    it('should emit pending text on TEXT_MESSAGE_END', () => {
+      const onTextUpdate = vi.fn()
+      // Use a strategy that never emits during streaming
+      const processor = new StreamProcessor({
+        events: { onTextUpdate },
+        chunkStrategy: {
+          shouldEmit: () => false,
+        },
+      })
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_START',
+        messageId: 'msg-1',
+        role: 'assistant',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_CONTENT',
+        messageId: 'msg-1',
+        delta: 'Hello',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      // Text not emitted yet due to strategy
+      expect(onTextUpdate).not.toHaveBeenCalled()
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_END',
+        messageId: 'msg-1',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      // TEXT_MESSAGE_END should flush pending text
+      expect(onTextUpdate).toHaveBeenCalledWith('msg-1', 'Hello')
+    })
+  })
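The inline `{ shouldEmit: () => false }` above suppresses every intermediate emit so the flush-on-TEXT_MESSAGE_END path can be observed in isolation. A real strategy can batch on any boundary; a sketch of a word-boundary variant, assuming `shouldEmit` receives the pending segment text (the full `ChunkStrategy` signature is not shown in this diff):

```ts
// Only flush once the pending segment ends at whitespace, so the UI never
// paints half a word. The shouldEmit parameter is an assumption; see above.
const emitOnWordBoundary = {
  shouldEmit: (pending: string) => /\s$/.test(pending),
  reset: () => {},
}

const processor = new StreamProcessor({ chunkStrategy: emitOnWordBoundary })
```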
+
+  describe('interleaved messages', () => {
+    it('should handle two interleaved assistant messages', () => {
+      const onMessagesChange = vi.fn()
+      const processor = new StreamProcessor({
+        events: { onMessagesChange },
+      })
+
+      // Start two messages
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_START',
+        messageId: 'msg-a',
+        role: 'assistant',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_START',
+        messageId: 'msg-b',
+        role: 'assistant',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      // Interleave content
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_CONTENT',
+        messageId: 'msg-a',
+        delta: 'Hello from A',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_CONTENT',
+        messageId: 'msg-b',
+        delta: 'Hello from B',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      // End both
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_END',
+        messageId: 'msg-a',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_END',
+        messageId: 'msg-b',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      const messages = processor.getMessages()
+      expect(messages).toHaveLength(2)
+
+      expect(messages[0]?.id).toBe('msg-a')
+      expect(messages[0]?.parts[0]).toEqual({
+        type: 'text',
+        content: 'Hello from A',
+      })
+
+      expect(messages[1]?.id).toBe('msg-b')
+      expect(messages[1]?.parts[0]).toEqual({
+        type: 'text',
+        content: 'Hello from B',
+      })
+    })
+
+    it('should emit onStreamEnd on finalizeStream (not on TEXT_MESSAGE_END)', () => {
+      const onStreamEnd = vi.fn()
+      const processor = new StreamProcessor({
+        events: { onStreamEnd },
+      })
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_START',
+        messageId: 'msg-a',
+        role: 'assistant',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_START',
+        messageId: 'msg-b',
+        role: 'assistant',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_CONTENT',
+        messageId: 'msg-a',
+        delta: 'A',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_END',
+        messageId: 'msg-a',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      // TEXT_MESSAGE_END does not fire onStreamEnd
+      expect(onStreamEnd).not.toHaveBeenCalled()
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_CONTENT',
+        messageId: 'msg-b',
+        delta: 'B',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_END',
+        messageId: 'msg-b',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      // Still not fired
+      expect(onStreamEnd).not.toHaveBeenCalled()
+
+      // finalizeStream fires onStreamEnd for the last assistant message
+      processor.finalizeStream()
+      expect(onStreamEnd).toHaveBeenCalledTimes(1)
+    })
+  })
+
+  describe('startAssistantMessage + TEXT_MESSAGE_START dedup', () => {
+    it('should associate TEXT_MESSAGE_START with pending manual message (different ID)', () => {
+      const processor = new StreamProcessor()
+      processor.startAssistantMessage()
+
+      // Server sends TEXT_MESSAGE_START with a different ID
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_START',
+        messageId: 'server-msg-1',
+        role: 'assistant',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      // Should have only one message (not two)
+      const messages = processor.getMessages()
+      expect(messages).toHaveLength(1)
+
+      // The message should have been updated to the server's ID
+      expect(messages[0]?.id).toBe('server-msg-1')
+      expect(messages[0]?.role).toBe('assistant')
+
+      // Content should route to the correct message
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_CONTENT',
+        messageId: 'server-msg-1',
+        delta: 'Hello',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.finalizeStream()
+
+      expect(processor.getMessages()[0]?.parts[0]).toEqual({
+        type: 'text',
+        content: 'Hello',
+      })
+    })
+
+    it('should associate TEXT_MESSAGE_START with pending manual message (same ID)', () => {
+      const processor = new StreamProcessor()
+      processor.startAssistantMessage('my-msg-id')
+
+      // Server sends TEXT_MESSAGE_START with the same ID
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_START',
+        messageId: 'my-msg-id',
+        role: 'assistant',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      // Should still have only one message
+      const messages = processor.getMessages()
+      expect(messages).toHaveLength(1)
+      expect(messages[0]?.id).toBe('my-msg-id')
+    })
+
+    it('should work when TEXT_MESSAGE_START arrives without startAssistantMessage', () => {
+      const onStreamStart = vi.fn()
+      const processor = new StreamProcessor({
+        events: { onStreamStart },
+      })
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_START',
+        messageId: 'msg-1',
+        role: 'assistant',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_CONTENT',
+        messageId: 'msg-1',
+        delta: 'Hello',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_END',
+        messageId: 'msg-1',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      expect(onStreamStart).toHaveBeenCalledTimes(1)
+
+      const messages = processor.getMessages()
+      expect(messages).toHaveLength(1)
+      expect(messages[0]?.id).toBe('msg-1')
+      expect(messages[0]?.parts[0]).toEqual({
+        type: 'text',
+        content: 'Hello',
+      })
+    })
+  })
+
+  describe('backward compat: ensureAssistantMessage auto-creation', () => {
+    it('should emit onStreamStart when auto-creating a message from content event', () => {
+      const onStreamStart = vi.fn()
+      const processor = new StreamProcessor({
+        events: { onStreamStart },
+      })
+
+      // No TEXT_MESSAGE_START or startAssistantMessage — content arrives directly
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_CONTENT',
+        messageId: 'auto-msg',
+        delta: 'Hello',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      expect(onStreamStart).toHaveBeenCalledTimes(1)
+      expect(processor.getMessages()).toHaveLength(1)
+      expect(processor.getMessages()[0]?.role).toBe('assistant')
+    })
+  })
+
+  describe('backward compat: startAssistantMessage without TEXT_MESSAGE_START', () => {
+    it('should still work when only startAssistantMessage is used', () => {
+      const processor = new StreamProcessor()
+      const msgId = processor.startAssistantMessage()
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_CONTENT',
+        messageId: 'some-other-id',
+        delta: 'Hello',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_CONTENT',
+        messageId: 'some-other-id',
+        delta: ' world',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.finalizeStream()
+
+      const messages = processor.getMessages()
+      expect(messages).toHaveLength(1)
+      expect(messages[0]?.id).toBe(msgId)
+      expect(messages[0]?.parts[0]).toEqual({
+        type: 'text',
+        content: 'Hello world',
+      })
+    })
+  })
+
+  describe('MESSAGES_SNAPSHOT', () => {
+    it('should hydrate messages and emit onMessagesChange', () => {
+      const onMessagesChange = vi.fn()
+      const processor = new StreamProcessor({
+        events: { onMessagesChange },
+      })
+
+      const snapshotMessages: Array<UIMessage> = [
+        {
+          id: 'snap-1',
+          role: 'user',
+          parts: [{ type: 'text', content: 'Hello' }],
+          createdAt: new Date(),
+        },
+        {
+          id: 'snap-2',
+          role: 'assistant',
+          parts: [{ type: 'text', content: 'Hi there!' }],
+          createdAt: new Date(),
+        },
+      ]
+
+      processor.processChunk({
+        type: 'MESSAGES_SNAPSHOT',
+        messages: snapshotMessages,
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      const messages = processor.getMessages()
+      expect(messages).toHaveLength(2)
+      expect(messages[0]?.id).toBe('snap-1')
+      expect(messages[0]?.role).toBe('user')
+      expect(messages[1]?.id).toBe('snap-2')
+      expect(messages[1]?.role).toBe('assistant')
+      expect(onMessagesChange).toHaveBeenCalled()
+    })
+
+    it('should replace existing messages (not append)', () => {
+      const processor = new StreamProcessor()
+
+      // Add an initial message
+      processor.addUserMessage('First message')
+      expect(processor.getMessages()).toHaveLength(1)
+
+      // Snapshot replaces all messages
+      processor.processChunk({
+        type: 'MESSAGES_SNAPSHOT',
+        messages: [
+          {
+            id: 'snap-1',
+            role: 'assistant',
+            parts: [{ type: 'text', content: 'Snapshot content' }],
+            createdAt: new Date(),
+          },
+        ],
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      const messages = processor.getMessages()
+      expect(messages).toHaveLength(1)
+      expect(messages[0]?.id).toBe('snap-1')
+      expect(messages[0]?.role).toBe('assistant')
+    })
+  })
+
+  describe('per-message tool calls', () => {
+    it('should route tool calls to the correct message via parentMessageId', () => {
+      const processor = new StreamProcessor()
+
+      // Create the parent message
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_START',
+        messageId: 'msg-a',
+        role: 'assistant',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      // Tool call on msg-a
+      processor.processChunk({
+        type: 'TOOL_CALL_START',
+        toolCallId: 'tc-1',
+        toolName: 'myTool',
+        parentMessageId: 'msg-a',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TOOL_CALL_ARGS',
+        toolCallId: 'tc-1',
+        delta: '{"arg": "val"}',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TOOL_CALL_END',
+        toolCallId: 'tc-1',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.finalizeStream()
+
+      const messages = processor.getMessages()
+      expect(messages).toHaveLength(1)
+
+      const toolCallPart = messages[0]?.parts.find(
+        (p) => p.type === 'tool-call',
+      )
+      expect(toolCallPart).toBeDefined()
+      expect(toolCallPart?.type).toBe('tool-call')
+      if (toolCallPart?.type === 'tool-call') {
+        expect(toolCallPart.name).toBe('myTool')
+        expect(toolCallPart.state).toBe('input-complete')
+      }
+    })
+  })
+
+  describe('double onStreamEnd guard', () => {
+    it('should fire onStreamEnd exactly once when RUN_FINISHED arrives before TEXT_MESSAGE_END', () => {
+      const onStreamEnd = vi.fn()
+      const processor = new StreamProcessor({ events: { onStreamEnd } })
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_START',
+        messageId: 'msg-1',
+        role: 'assistant',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_CONTENT',
+        messageId: 'msg-1',
+        delta: 'Hello',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      // RUN_FINISHED fires first — calls finalizeStream which sets isComplete and fires onStreamEnd
+      processor.processChunk({
+        type: 'RUN_FINISHED',
+        model: 'test',
+        timestamp: Date.now(),
+        finishReason: 'stop',
+      } as StreamChunk)
+
+      expect(onStreamEnd).toHaveBeenCalledTimes(1)
+
+      // TEXT_MESSAGE_END arrives after — should NOT fire onStreamEnd again
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_END',
+        messageId: 'msg-1',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      expect(onStreamEnd).toHaveBeenCalledTimes(1)
+    })
+  })
+
+  describe('MESSAGES_SNAPSHOT resets transient state', () => {
+    it('should reset stale state and process subsequent stream events correctly', () => {
+      const onStreamEnd = vi.fn()
+      const processor = new StreamProcessor({ events: { onStreamEnd } })
+
+      // Simulate an active streaming session
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_START',
+        messageId: 'msg-old',
+        role: 'assistant',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_CONTENT',
+        messageId: 'msg-old',
+        delta: 'Old content',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TOOL_CALL_START',
+        toolCallId: 'tc-old',
+        toolName: 'oldTool',
+        parentMessageId: 'msg-old',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      // MESSAGES_SNAPSHOT replaces everything (e.g., on reconnection)
+      processor.processChunk({
+        type: 'MESSAGES_SNAPSHOT',
+        messages: [
+          {
+            id: 'snap-user',
+            role: 'user',
+            parts: [{ type: 'text', content: 'Hello' }],
+            createdAt: new Date(),
+          },
+        ],
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      // Verify old messages are replaced
+      const messagesAfterSnapshot = processor.getMessages()
+      expect(messagesAfterSnapshot).toHaveLength(1)
+      expect(messagesAfterSnapshot[0]?.id).toBe('snap-user')
+
+      // New stream events should be processed correctly without stale state
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_START',
+        messageId: 'msg-new',
+        role: 'assistant',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_CONTENT',
+        messageId: 'msg-new',
+        delta: 'New content',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      processor.processChunk({
+        type: 'TEXT_MESSAGE_END',
+        messageId: 'msg-new',
+        timestamp: Date.now(),
+      } as StreamChunk)
+
+      const finalMessages = processor.getMessages()
+      expect(finalMessages).toHaveLength(2)
+      expect(finalMessages[1]?.id).toBe('msg-new')
+      expect(finalMessages[1]?.parts[0]).toEqual({
+        type: 'text',
+        content: 'New content',
+      })
+
+      // onStreamEnd fires from finalizeStream, not TEXT_MESSAGE_END
+      expect(onStreamEnd).not.toHaveBeenCalled()
+      processor.finalizeStream()
+      expect(onStreamEnd).toHaveBeenCalledTimes(1)
+      expect(onStreamEnd.mock.calls[0]![0].id).toBe('msg-new')
+    })
+  })
 })