From f4b209b4a168295e42277e266a20693b03b9cb1a Mon Sep 17 00:00:00 2001 From: Bri <34875062+Monkatraz@users.noreply.github.com> Date: Tue, 15 Jul 2025 11:06:24 -0700 Subject: [PATCH 1/2] add uncaughts --- __tests__/cancellation.test.ts | 236 +++++++++++++++++++++++++++++++++ router/context.ts | 11 +- router/errors.ts | 11 ++ router/server.ts | 14 +- 4 files changed, 268 insertions(+), 4 deletions(-) diff --git a/__tests__/cancellation.test.ts b/__tests__/cancellation.test.ts index 97b4f796..a411b357 100644 --- a/__tests__/cancellation.test.ts +++ b/__tests__/cancellation.test.ts @@ -790,6 +790,242 @@ describe.each(testMatrix())( }, ); +describe.each(testMatrix())('handler explicit uncaught error cancellation ($transport.name transport, $codec.name codec)', + async ({ transport, codec }) => { + const opts = { codec: codec.codec }; + + const { addPostTestCleanup, postTestCleanup } = createPostTestCleanups(); + let getClientTransport: TestSetupHelpers['getClientTransport']; + let getServerTransport: TestSetupHelpers['getServerTransport']; + beforeEach(async () => { + const setup = await transport.setup({ client: opts, server: opts }); + getClientTransport = setup.getClientTransport; + getServerTransport = setup.getServerTransport; + + return async () => { + await postTestCleanup(); + await setup.cleanup(); + }; + }); + + describe('e2e', () => { + test('rpc', async () => { + const clientTransport = getClientTransport('client'); + const serverTransport = getServerTransport(); + addPostTestCleanup(async () => { + await cleanupTransports([clientTransport, serverTransport]); + }); + + const handler = makeMockHandler('rpc'); + const services = { + service: ServiceSchema.define({ + rpc: Procedure.rpc({ + requestInit: Type.Object({}), + responseData: Type.Object({}), + handler, + }), + }), + }; + + const server = createServer(serverTransport, services); + const client = createClient( + clientTransport, + serverTransport.clientId, + ); + + const resP = client.service.rpc.rpc({}); + + await waitFor(() => { + expect(handler).toHaveBeenCalledTimes(1); + }); + + const [{ ctx }] = handler.mock.calls[0]; + const onRequestFinished = vi.fn(); + ctx.signal.addEventListener('abort', onRequestFinished); + + const err = ctx.uncaught(new Error('test')); + + expect(err).toEqual( + Err({ + code: UNCAUGHT_ERROR_CODE, + message: 'test', + }), + ); + + await waitFor(() => { + expect(onRequestFinished).toHaveBeenCalled(); + }); + await expect(resP).resolves.toEqual(err); + + await testFinishesCleanly({ + clientTransports: [clientTransport], + serverTransport, + server, + }); + }); + + test('stream', async () => { + const clientTransport = getClientTransport('client'); + const serverTransport = getServerTransport(); + addPostTestCleanup(async () => { + await cleanupTransports([clientTransport, serverTransport]); + }); + + const handler = makeMockHandler('stream'); + const services = { + service: ServiceSchema.define({ + stream: Procedure.stream({ + requestInit: Type.Object({}), + requestData: Type.Object({}), + responseData: Type.Object({}), + handler, + }), + }), + }; + + const server = createServer(serverTransport, services); + const client = createClient( + clientTransport, + serverTransport.clientId, + ); + + const { reqWritable, resReadable } = client.service.stream.stream({}); + + await waitFor(() => { + expect(handler).toHaveBeenCalledTimes(1); + }); + + const [{ ctx, reqReadable, resWritable }] = handler.mock.calls[0]; + + const err = ctx.uncaught(new Error('test')); + + expect(err).toEqual( + Err({ + code: UNCAUGHT_ERROR_CODE, + message: 'test', + }), + ); + + expect(await reqReadable.collect()).toEqual([err]); + expect(resWritable.isWritable()).toEqual(false); + + expect(await resReadable.collect()).toEqual([err]); + expect(reqWritable.isWritable()).toEqual(false); + + await testFinishesCleanly({ + clientTransports: [clientTransport], + serverTransport, + server, + }); + }); + + test('upload', async () => { + const clientTransport = getClientTransport('client'); + const serverTransport = getServerTransport(); + addPostTestCleanup(async () => { + await cleanupTransports([clientTransport, serverTransport]); + }); + + const handler = makeMockHandler('upload'); + const services = { + service: ServiceSchema.define({ + upload: Procedure.upload({ + requestInit: Type.Object({}), + requestData: Type.Object({}), + responseData: Type.Object({}), + handler, + }), + }), + }; + + const server = createServer(serverTransport, services); + const client = createClient( + clientTransport, + serverTransport.clientId, + ); + + const { reqWritable, finalize } = client.service.upload.upload({}); + + await waitFor(() => { + expect(handler).toHaveBeenCalledTimes(1); + }); + + const [{ ctx, reqReadable }] = handler.mock.calls[0]; + + const err = ctx.uncaught(new Error('test')); + + expect(err).toEqual( + Err({ + code: UNCAUGHT_ERROR_CODE, + message: 'test', + }), + ); + + expect(await finalize()).toEqual(err); + expect(reqWritable.isWritable()).toEqual(false); + expect(await reqReadable.collect()).toEqual([err]); + + await testFinishesCleanly({ + clientTransports: [clientTransport], + serverTransport, + server, + }); + }); + + test('subscribe', async () => { + const clientTransport = getClientTransport('client'); + const serverTransport = getServerTransport(); + addPostTestCleanup(async () => { + await cleanupTransports([clientTransport, serverTransport]); + }); + + const handler = makeMockHandler('subscription'); + const services = { + service: ServiceSchema.define({ + subscribe: Procedure.subscription({ + requestInit: Type.Object({}), + responseData: Type.Object({}), + handler, + }), + }), + }; + + const server = createServer(serverTransport, services); + const client = createClient( + clientTransport, + serverTransport.clientId, + ); + + const { resReadable } = client.service.subscribe.subscribe({}); + + await waitFor(() => { + expect(handler).toHaveBeenCalledTimes(1); + }); + + const [{ ctx, resWritable }] = handler.mock.calls[0]; + + const err = ctx.uncaught(new Error('test')); + + expect(err).toEqual( + Err({ + code: UNCAUGHT_ERROR_CODE, + message: 'test', + }), + ); + + expect(await resReadable.collect()).toEqual([err]); + expect(resWritable.isWritable()).toEqual(false); + + await testFinishesCleanly({ + clientTransports: [clientTransport], + serverTransport, + server, + }); + }); + }); + }, +); + const createRejectable = () => { let reject: (reason: Error) => void; const promise = new Promise((_res, rej) => { diff --git a/router/context.ts b/router/context.ts index 922c6574..0cdc5f12 100644 --- a/router/context.ts +++ b/router/context.ts @@ -2,7 +2,7 @@ import { Span } from '@opentelemetry/api'; import { TransportClientId } from '../transport/message'; import { SessionId } from '../transport/sessionStateMachine/common'; import { ErrResult } from './result'; -import { CancelErrorSchema } from './errors'; +import { CancelErrorSchema, UncaughtErrorSchema } from './errors'; import { Static } from '@sinclair/typebox'; /** @@ -40,6 +40,15 @@ export type ProcedureHandlerContext = * the river documentation to understand the difference between the two concepts. */ cancel: (message?: string) => ErrResult>; + /** + * This emits an uncaught error in the same way that throwing an error in a handler + * would. You should minimize the amount of work you do after calling this function + * as this will start a cleanup of the entire procedure call. + * + * You'll typically want to use this for streaming procedures, as in e.g. an RPC + * you can just throw instead. + */ + uncaught: (err?: unknown) => ErrResult>; /** * This signal is a standard [AbortSignal](https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal) * triggered when the procedure invocation is done. This signal tracks the invocation/request finishing diff --git a/router/errors.ts b/router/errors.ts index 836d9411..83fdcee4 100644 --- a/router/errors.ts +++ b/router/errors.ts @@ -72,6 +72,16 @@ export function castTypeboxValueErrors( return result; } +/** + * A schema for unexpected errors in handlers + */ +export const UncaughtErrorSchema = Type.Object({ + code: Type.Literal(UNCAUGHT_ERROR_CODE), + message: Type.String(), +}); + +export const UncaughtResultSchema = ErrResultSchema(UncaughtErrorSchema); + /** * A schema for cancel payloads sent from the client */ @@ -88,6 +98,7 @@ export const CancelResultSchema = ErrResultSchema(CancelErrorSchema); * on the client). */ export const ReaderErrorSchema = Type.Union([ + UncaughtErrorSchema, Type.Object({ code: Type.Literal(UNCAUGHT_ERROR_CODE), message: Type.String(), diff --git a/router/server.ts b/router/server.ts index 87341e44..17003ac5 100644 --- a/router/server.ts +++ b/router/server.ts @@ -10,6 +10,7 @@ import { ValidationErrors, castTypeboxValueErrors, CancelResultSchema, + UncaughtResultSchema, } from './errors'; import { AnyService, @@ -550,7 +551,7 @@ class RiverServer< }, }); - const onHandlerError = (err: unknown, span: Span) => { + const onHandlerError = (err: unknown, span: Span): Static => { const errorMsg = coerceErrorString(err); span.recordException(err instanceof Error ? err : new Error(errorMsg)); @@ -571,10 +572,14 @@ class RiverServer< }, ); - onServerCancel({ + const res = Err({ code: UNCAUGHT_ERROR_CODE, message: errorMsg, }); + + onServerCancel(res.payload); + + return res; }; // if the init message has a close flag then we know this stream @@ -603,6 +608,9 @@ class RiverServer< return Err(errRes); }, + uncaught: (err?: unknown) => { + return onHandlerError(err, span); + }, signal: finishedController.signal, }; @@ -1039,7 +1047,7 @@ function getStreamCloseBackwardsCompat(protocolVersion: ProtocolVersion) { export interface MiddlewareContext extends Readonly< - Omit, 'cancel'> + Omit, 'cancel' | 'uncaught'> > { readonly streamId: StreamId; readonly procedureName: string; From 5a201047845ccfca2b4020fb9678151987e2f2c8 Mon Sep 17 00:00:00 2001 From: Bri <34875062+Monkatraz@users.noreply.github.com> Date: Tue, 15 Jul 2025 11:10:42 -0700 Subject: [PATCH 2/2] remove duplicate --- router/errors.ts | 4 ---- 1 file changed, 4 deletions(-) diff --git a/router/errors.ts b/router/errors.ts index 83fdcee4..7b13e6eb 100644 --- a/router/errors.ts +++ b/router/errors.ts @@ -99,10 +99,6 @@ export const CancelResultSchema = ErrResultSchema(CancelErrorSchema); */ export const ReaderErrorSchema = Type.Union([ UncaughtErrorSchema, - Type.Object({ - code: Type.Literal(UNCAUGHT_ERROR_CODE), - message: Type.String(), - }), Type.Object({ code: Type.Literal(UNEXPECTED_DISCONNECT_CODE), message: Type.String(),