diff --git a/__tests__/e2e.test.ts b/__tests__/e2e.test.ts index 1071f857..b94b968f 100644 --- a/__tests__/e2e.test.ts +++ b/__tests__/e2e.test.ts @@ -422,6 +422,151 @@ describe.each(testMatrix())( }); }); + test('defaultCallOptions provides signal when caller omits it', async () => { + const clientTransport = getClientTransport('client'); + const serverTransport = getServerTransport(); + const services = { subscribable: SubscribableServiceSchema }; + const server = createServer(serverTransport, services); + const abortController = new AbortController(); + const client = createClient( + clientTransport, + serverTransport.clientId, + { defaultCallOptions: { signal: abortController.signal } }, + ); + addPostTestCleanup(async () => { + await cleanupTransports([clientTransport, serverTransport]); + }); + + // No signal passed at the call site — comes from defaultCallOptions. + const { resReadable } = client.subscribable.value.subscribe({}); + let result = await readNextResult(resReadable); + expect(result).toStrictEqual({ ok: true, payload: { result: 0 } }); + + abortController.abort(); + result = await readNextResult(resReadable); + expect(result).toStrictEqual({ + ok: false, + // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment + payload: expect.objectContaining({ code: CANCEL_CODE }), + }); + expect(await isReadableDone(resReadable)).toEqual(true); + + await testFinishesCleanly({ + clientTransports: [clientTransport], + serverTransport, + server, + }); + }); + + test('caller-supplied signal overrides defaultCallOptions', async () => { + const clientTransport = getClientTransport('client'); + const serverTransport = getServerTransport(); + const services = { subscribable: SubscribableServiceSchema }; + const server = createServer(serverTransport, services); + const defaultAc = new AbortController(); + const callerAc = new AbortController(); + const client = createClient( + clientTransport, + serverTransport.clientId, + { defaultCallOptions: { signal: defaultAc.signal } }, + ); + addPostTestCleanup(async () => { + await cleanupTransports([clientTransport, serverTransport]); + }); + + // Caller signal is the one that should drive cancellation. + const { resReadable } = client.subscribable.value.subscribe( + {}, + { signal: callerAc.signal }, + ); + let result = await readNextResult(resReadable); + expect(result).toStrictEqual({ ok: true, payload: { result: 0 } }); + + // Aborting the default-options signal must NOT cancel — caller wins. + defaultAc.abort(); + const add1 = await client.subscribable.add.rpc({ n: 1 }); + expect(add1).toMatchObject({ ok: true }); + result = await readNextResult(resReadable); + expect(result).toStrictEqual({ ok: true, payload: { result: 1 } }); + + // Aborting the caller signal cancels. + callerAc.abort(); + result = await readNextResult(resReadable); + expect(result).toStrictEqual({ + ok: false, + // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment + payload: expect.objectContaining({ code: CANCEL_CODE }), + }); + expect(await isReadableDone(resReadable)).toEqual(true); + + await testFinishesCleanly({ + clientTransports: [clientTransport], + serverTransport, + server, + }); + }); + + test('function-form defaultCallOptions is resolved per call', async () => { + const clientTransport = getClientTransport('client'); + const serverTransport = getServerTransport(); + const services = { subscribable: SubscribableServiceSchema }; + const server = createServer(serverTransport, services); + let currentSignal: AbortSignal | undefined; + const client = createClient( + clientTransport, + serverTransport.clientId, + { defaultCallOptions: () => ({ signal: currentSignal }) }, + ); + addPostTestCleanup(async () => { + await cleanupTransports([clientTransport, serverTransport]); + }); + + // Each subscribe resolves the getter at call time, so each call + // captures whatever signal is current. + const ac1 = new AbortController(); + currentSignal = ac1.signal; + const sub1 = client.subscribable.value.subscribe({}); + + const ac2 = new AbortController(); + currentSignal = ac2.signal; + const sub2 = client.subscribable.value.subscribe({}); + + let r1 = await readNextResult(sub1.resReadable); + let r2 = await readNextResult(sub2.resReadable); + expect(r1).toStrictEqual({ ok: true, payload: { result: 0 } }); + expect(r2).toStrictEqual({ ok: true, payload: { result: 0 } }); + + // ac1 cancels sub1 only — sub2 keeps streaming. + ac1.abort(); + r1 = await readNextResult(sub1.resReadable); + expect(r1).toStrictEqual({ + ok: false, + // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment + payload: expect.objectContaining({ code: CANCEL_CODE }), + }); + expect(await isReadableDone(sub1.resReadable)).toEqual(true); + + const add1 = await client.subscribable.add.rpc({ n: 1 }); + expect(add1).toMatchObject({ ok: true }); + r2 = await readNextResult(sub2.resReadable); + expect(r2).toStrictEqual({ ok: true, payload: { result: 1 } }); + + ac2.abort(); + r2 = await readNextResult(sub2.resReadable); + expect(r2).toStrictEqual({ + ok: false, + // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment + payload: expect.objectContaining({ code: CANCEL_CODE }), + }); + expect(await isReadableDone(sub2.resReadable)).toEqual(true); + + await testFinishesCleanly({ + clientTransports: [clientTransport], + serverTransport, + server, + }); + }); + test('subscription idempotent close', async () => { // setup const clientTransport = getClientTransport('client'); diff --git a/package-lock.json b/package-lock.json index bb735940..dedd385b 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@replit/river", - "version": "0.216.0", + "version": "0.216.1", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "@replit/river", - "version": "0.216.0", + "version": "0.216.1", "license": "MIT", "dependencies": { "@bufbuild/protobuf": "^2.11.0", diff --git a/package.json b/package.json index adccb671..d899c78b 100644 --- a/package.json +++ b/package.json @@ -1,7 +1,7 @@ { "name": "@replit/river", "description": "It's like tRPC but... with JSON Schema Support, duplex streaming and support for service multiplexing. Transport agnostic!", - "version": "0.216.0", + "version": "0.216.1", "type": "module", "exports": { ".": { diff --git a/router/client.ts b/router/client.ts index 818a02bc..f3c6fb52 100644 --- a/router/client.ts +++ b/router/client.ts @@ -37,7 +37,7 @@ import { UNEXPECTED_DISCONNECT_CODE, } from './errors'; -interface CallOptions { +export interface CallOptions { signal?: AbortSignal; } @@ -205,6 +205,16 @@ function _createRecursiveProxy( export interface ClientOptions { connectOnInvoke: boolean; eagerlyConnect: boolean; + /** + * Default options merged into every leaf call (`rpc`, `stream`, + * `upload`, `subscribe`). Caller-supplied `options` win field-by-field, + * so a caller can override `signal` while keeping other defaults. + * + * Pass a function form when the default needs to be re-resolved per + * call (e.g. an ambient signal that changes between invocations of + * the same client). + */ + defaultCallOptions?: CallOptions | (() => CallOptions); } const defaultClientOptions: ClientOptions = { @@ -273,6 +283,11 @@ export function createClient>( ); } + const merged = mergeCallOptions( + clientOptions.defaultCallOptions, + callOptions as CallOptions | undefined, + ); + return handleProc( procMethod === 'subscribe' ? 'subscription' : procMethod, transport, @@ -280,11 +295,21 @@ export function createClient>( init, serviceName, procName, - callOptions ? (callOptions as CallOptions).signal : undefined, + merged.signal, ); }, []) as Client; } +function mergeCallOptions( + defaults: ClientOptions['defaultCallOptions'], + caller: CallOptions | undefined, +): CallOptions { + const resolved = typeof defaults === 'function' ? defaults() : defaults ?? {}; + + // Caller fields win: spread defaults first, caller second. + return { ...resolved, ...caller }; +} + type AnyProcReturn = | ReturnType> | ReturnType> diff --git a/router/index.ts b/router/index.ts index e6738093..f748101b 100644 --- a/router/index.ts +++ b/router/index.ts @@ -43,7 +43,7 @@ export { BaseErrorSchemaType, } from './errors'; export { createClient } from './client'; -export type { Client } from './client'; +export type { CallOptions, Client, ClientOptions } from './client'; export { createServer } from './server'; export type { Server,