From c356683b109a707459b11a1d3efa23c0cd3c0d06 Mon Sep 17 00:00:00 2001 From: Mike Willbanks Date: Mon, 17 Nov 2025 20:27:30 +0000 Subject: [PATCH] feat(server): add exports for custom handler and adapter support Expose internal handler types and adapter interfaces. Improve visibility of handler internals required for overrides. Add shared entry point for common server utilities. --- packages/server/package.json | 20 +++ packages/server/src/api/index.ts | 1 + packages/server/src/api/rest/index.ts | 128 ++++++++++---------- packages/server/src/api/rpc/index.ts | 14 +-- packages/server/test/adapter/custom.test.ts | 119 ++++++++++++++++++ packages/server/test/api/custom.test.ts | 62 ++++++++++ packages/server/test/api/rest.test.ts | 75 ++++++++++++ packages/server/test/api/rpc.test.ts | 58 +++++++++ packages/server/tsup.config.ts | 2 + 9 files changed, 408 insertions(+), 71 deletions(-) create mode 100644 packages/server/test/adapter/custom.test.ts create mode 100644 packages/server/test/api/custom.test.ts diff --git a/packages/server/package.json b/packages/server/package.json index 84e47382..2f6dacdb 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -39,6 +39,16 @@ "default": "./dist/api.cjs" } }, + "./common": { + "import": { + "types": "./dist/common.d.ts", + "default": "./dist/common.js" + }, + "require": { + "types": "./dist/common.d.cts", + "default": "./dist/common.cjs" + } + }, "./express": { "import": { "types": "./dist/express.d.ts", @@ -118,6 +128,16 @@ "types": "./dist/tanstack-start.d.cts", "default": "./dist/tanstack-start.cjs" } + }, + "./types": { + "import": { + "types": "./dist/types.d.ts", + "default": "./dist/types.js" + }, + "require": { + "types": "./dist/types.d.cts", + "default": "./dist/types.cjs" + } } }, "dependencies": { diff --git a/packages/server/src/api/index.ts b/packages/server/src/api/index.ts index 09d9700e..e57e81b8 100644 --- a/packages/server/src/api/index.ts +++ b/packages/server/src/api/index.ts @@ -1,2 +1,3 @@ export { RestApiHandler, type RestApiHandlerOptions } from './rest'; export { RPCApiHandler, type RPCApiHandlerOptions } from './rpc'; +export * from './utils'; diff --git a/packages/server/src/api/rest/index.ts b/packages/server/src/api/rest/index.ts index 536a45a5..4df8ee48 100644 --- a/packages/server/src/api/rest/index.ts +++ b/packages/server/src/api/rest/index.ts @@ -118,10 +118,10 @@ registerCustomSerializers(); */ export class RestApiHandler implements ApiHandler { // resource serializers - private serializers = new Map(); + protected serializers = new Map(); // error responses - private readonly errors: Record = { + protected readonly errors: Record = { unsupportedModel: { status: 404, title: 'Unsupported model type', @@ -200,10 +200,10 @@ export class RestApiHandler implements ApiHandler(\[[^[\]]+\])+)$/); + protected filterParamPattern = new RegExp(/^filter(?(\[[^[\]]+\])+)$/); // zod schema for payload of creating and updating a resource - private createUpdatePayloadSchema = z + protected createUpdatePayloadSchema = z .object({ data: z.object({ type: z.string(), @@ -225,16 +225,16 @@ export class RestApiHandler implements ApiHandler implements ApiHandler = {}; + protected typeMap: Record = {}; // divider used to separate compound ID fields - private idDivider; + protected idDivider; - private urlPatternMap: Record; - private modelNameMapping: Record; - private reverseModelNameMapping: Record; - private externalIdMapping: Record; + protected urlPatternMap: Record; + protected modelNameMapping: Record; + protected reverseModelNameMapping: Record; + protected externalIdMapping: Record; - constructor(private readonly options: RestApiHandlerOptions) { + constructor(protected readonly options: RestApiHandlerOptions) { this.idDivider = options.idDivider ?? DEFAULT_ID_DIVIDER; const segmentCharset = options.urlSegmentCharset ?? 'a-zA-Z0-9-_~ %'; @@ -283,7 +283,7 @@ export class RestApiHandler implements ApiHandler { + protected buildUrlPatternMap(urlSegmentNameCharset: string): Record { const options = { segmentValueCharset: urlSegmentNameCharset }; const buildPath = (segments: string[]) => { @@ -301,11 +301,11 @@ export class RestApiHandler implements ApiHandler implements ApiHandler { + protected handleGenericError(err: unknown): Response | PromiseLike { return this.makeError('unknownError', err instanceof Error ? `${err.message}\n${err.stack}` : 'Unknown error'); } - private async processSingleRead( + protected async processSingleRead( client: ClientContract, type: string, resourceId: string, @@ -528,7 +528,7 @@ export class RestApiHandler implements ApiHandler, type: string, resourceId: string, @@ -617,7 +617,7 @@ export class RestApiHandler implements ApiHandler, type: string, resourceId: string, @@ -683,7 +683,7 @@ export class RestApiHandler implements ApiHandler, type: string, query: Record | undefined, @@ -785,7 +785,7 @@ export class RestApiHandler implements ApiHandler | undefined) { + protected buildPartialSelect(type: string, query: Record | undefined) { const selectFieldsQuery = query?.[`fields[${type}]`]; if (!selectFieldsQuery) { return { select: undefined, error: undefined }; @@ -812,11 +812,11 @@ export class RestApiHandler implements ApiHandler implements ApiHandler implements ApiHandler, type: string, _query: Record | undefined, @@ -931,7 +931,7 @@ export class RestApiHandler implements ApiHandler, type: string, _query: Record | undefined, @@ -1014,7 +1014,7 @@ export class RestApiHandler implements ApiHandler @@ -1024,7 +1024,7 @@ export class RestApiHandler implements ApiHandler, mode: 'create' | 'update' | 'delete', type: string, @@ -1119,7 +1119,7 @@ export class RestApiHandler implements ApiHandler, type: any, resourceId: string, @@ -1186,7 +1186,7 @@ export class RestApiHandler implements ApiHandler, type: any, resourceId: string): Promise { + protected async processDelete(client: ClientContract, type: any, resourceId: string): Promise { const typeInfo = this.getModelInfo(type); if (!typeInfo) { return this.makeUnsupportedModelError(type); @@ -1203,7 +1203,7 @@ export class RestApiHandler implements ApiHandler implements ApiHandler implements ApiHandler implements ApiHandler implements ApiHandler> = {}; for (const model of Object.keys(this.schema.models)) { @@ -1382,7 +1382,7 @@ export class RestApiHandler implements ApiHandler implements ApiHandler>) { + protected async serializeItems(model: string, items: unknown, options?: Partial>) { model = lowerCaseFirst(model); const serializer = this.serializers.get(model); if (!serializer) { @@ -1421,7 +1421,7 @@ export class RestApiHandler implements ApiHandler implements ApiHandler implements ApiHandler) { + protected replaceURLSearchParams(url: string, params: Record) { const r = new URL(url); for (const [key, value] of Object.entries(params)) { r.searchParams.set(key, value.toString()); @@ -1486,7 +1486,7 @@ export class RestApiHandler implements ApiHandler implements ApiHandler ({ ...acc, [curr.name]: true }), {}); } - private makeIdConnect(idFields: FieldDef[], id: string | number) { + protected makeIdConnect(idFields: FieldDef[], id: string | number) { if (idFields.length === 1) { return { [idFields[0]!.name]: this.coerce(idFields[0]!, id) }; } else { @@ -1535,20 +1535,20 @@ export class RestApiHandler implements ApiHandler idf.name).join(this.idDivider); } - private makeDefaultIdKey(idFields: FieldDef[]) { + protected makeDefaultIdKey(idFields: FieldDef[]) { // TODO: support `@@id` with custom name return idFields.map((idf) => idf.name).join(DEFAULT_ID_DIVIDER); } - private makeCompoundId(idFields: FieldDef[], item: any) { + protected makeCompoundId(idFields: FieldDef[], item: any) { return idFields.map((idf) => item[idf.name]).join(this.idDivider); } - private makeUpsertWhere(matchFields: any[], attributes: any, typeInfo: ModelInfo) { + protected makeUpsertWhere(matchFields: any[], attributes: any, typeInfo: ModelInfo) { const where = matchFields.reduce((acc: any, field: string) => { acc[field] = attributes[field] ?? null; return acc; @@ -1566,7 +1566,7 @@ export class RestApiHandler implements ApiHandler implements ApiHandler attr.name === '@json')) { try { @@ -1624,7 +1624,7 @@ export class RestApiHandler implements ApiHandler | undefined) { + protected makeNormalizedUrl(path: string, query: Record | undefined) { const url = new URL(this.makeLinkUrl(path)); for (const [key, value] of Object.entries(query ?? {})) { if ( @@ -1642,7 +1642,7 @@ export class RestApiHandler implements ApiHandler | undefined) { + protected getPagination(query: Record | undefined) { if (!query) { return { offset: 0, limit: this.options.pageSize ?? DEFAULT_PAGE_SIZE }; } @@ -1676,7 +1676,7 @@ export class RestApiHandler implements ApiHandler | undefined, ): { filter: any; error: any } { @@ -1780,7 +1780,7 @@ export class RestApiHandler implements ApiHandler | undefined) { + protected buildSort(type: string, query: Record | undefined) { if (!query?.['sort']) { return { sort: undefined, error: undefined }; } @@ -1857,7 +1857,7 @@ export class RestApiHandler implements ApiHandler | undefined, @@ -1917,7 +1917,7 @@ export class RestApiHandler implements ApiHandler implements ApiHandler implements ApiHandler { return this.makeError('validationError', err.message, 422); @@ -2036,7 +2036,7 @@ export class RestApiHandler implements ApiHandler implements ApiHandler = { * RPC style API request handler that mirrors the ZenStackClient API */ export class RPCApiHandler implements ApiHandler { - constructor(private readonly options: RPCApiHandlerOptions) {} + constructor(protected readonly options: RPCApiHandlerOptions) {} get schema(): Schema { return this.options.schema; @@ -163,11 +163,11 @@ export class RPCApiHandler implements ApiHandler, model: string) { + protected isValidModel(client: ClientContract, model: string) { return Object.keys(client.$schema.models).some((m) => lowerCaseFirst(m) === lowerCaseFirst(model)); } - private makeBadInputErrorResponse(message: string) { + protected makeBadInputErrorResponse(message: string) { const resp = { status: 400, body: { error: { message } }, @@ -176,7 +176,7 @@ export class RPCApiHandler implements ApiHandler implements ApiHandler implements ApiHandler implements ApiHandler = { + method: string; + path: string; + query?: Record; + body?: unknown; + client: ClientContract; +}; + +class RecordingHandler implements ApiHandler { + constructor( + protected readonly schemaDef: SchemaDef, + protected readonly response: Response, + protected readonly logger?: (...args: any[]) => void, + ) {} + + readonly contexts: Array> = []; + + get schema(): SchemaDef { + return this.schemaDef; + } + + get log() { + return this.logger; + } + + async handleRequest(context: RequestContext): Promise { + this.contexts.push(context); + return this.response; + } +} + +class ThrowingHandler implements ApiHandler { + constructor(protected readonly schemaDef: SchemaDef, protected readonly logger: (...args: any[]) => void) {} + + get schema(): SchemaDef { + return this.schemaDef; + } + + get log() { + return this.logger; + } + + async handleRequest(): Promise { + throw new Error('adapter failure'); + } +} + +function createCustomAdapter( + options: CommonAdapterOptions, +): (request: AdapterRequest) => Promise { + return async (request) => { + const context: RequestContext = { + client: request.client, + method: request.method, + path: request.path, + query: request.query, + requestBody: request.body, + }; + + try { + return await options.apiHandler.handleRequest(context); + } catch (err) { + logInternalError(options.apiHandler.log, err); + throw err; + } + }; +} + +describe('Custom adapter test', () => { + const schema = {} as SchemaDef; + const client = { $schema: schema } as unknown as ClientContract; + + it('delegates to api handler', async () => { + const response: Response = { status: 201, body: { ok: true } }; + const handler = new RecordingHandler(schema, response); + const adapter = createCustomAdapter({ apiHandler: handler }); + + const result = await adapter({ + method: 'get', + path: '/something', + query: { foo: 'bar' }, + body: { value: 1 }, + client, + }); + + expect(result).toEqual(response); + expect(handler.contexts).toHaveLength(1); + const captured = handler.contexts[0]; + expect(captured.method).toBe('get'); + expect(captured.path).toBe('/something'); + expect(captured.query).toEqual({ foo: 'bar' }); + expect(captured.requestBody).toEqual({ value: 1 }); + expect(captured.client).toBe(client); + }); + + it('logs internal error when handler throws', async () => { + const logger = vi.fn(); + const handler = new ThrowingHandler(schema, logger); + const adapter = createCustomAdapter({ apiHandler: handler }); + + await expect( + adapter({ + method: 'post', + path: '/fail', + client, + }), + ).rejects.toThrow('adapter failure'); + expect(logger).toHaveBeenCalledTimes(1); + const call = logger.mock.calls[0]; + expect(call[0]).toBe('error'); + expect(call[1]).toContain('An unhandled error occurred while processing the request: Error: adapter failure'); + }); +}); diff --git a/packages/server/test/api/custom.test.ts b/packages/server/test/api/custom.test.ts new file mode 100644 index 00000000..d51d7b9f --- /dev/null +++ b/packages/server/test/api/custom.test.ts @@ -0,0 +1,62 @@ +import type { ClientContract } from '@zenstackhq/orm'; +import type { SchemaDef } from '@zenstackhq/orm/schema'; +import { Decimal } from 'decimal.js'; +import SuperJSON from 'superjson'; +import { describe, expect, it, vi } from 'vitest'; +import { log, registerCustomSerializers } from '../../src/api/utils'; +import { type ApiHandler, type LogConfig, type RequestContext, type Response } from '../../src/types'; + +class CustomApiHandler implements ApiHandler { + protected readonly handled: Array> = []; + + constructor(protected readonly schemaDef: SchemaDef, protected readonly logger: LogConfig) {} + + get schema(): SchemaDef { + return this.schemaDef; + } + + get log(): LogConfig { + return this.logger; + } + + get contexts(): ReadonlyArray> { + return this.handled; + } + + async handleRequest(context: RequestContext): Promise { + this.handled.push(context); + log(this.logger, 'info', () => `received ${context.method.toUpperCase()} ${context.path}`); + return { status: 202, body: { handled: true } }; + } +} + +describe('Custom API handler test', () => { + const schema = {} as SchemaDef; + const client = { $schema: schema } as unknown as ClientContract; + + it('allows building custom handlers with logging helpers', async () => { + const logger = vi.fn(); + const handler = new CustomApiHandler(schema, logger); + + const response = await handler.handleRequest({ + method: 'post', + path: '/custom', + query: { foo: 'bar' }, + requestBody: { value: 1 }, + client, + }); + + expect(response).toEqual({ status: 202, body: { handled: true } }); + expect(handler.contexts).toHaveLength(1); + expect(handler.contexts[0].query).toEqual({ foo: 'bar' }); + expect(logger).toHaveBeenCalledWith('info', 'received POST /custom', undefined); + }); + + it('provides serialization helpers for custom handlers', () => { + registerCustomSerializers(); + const serialized = SuperJSON.serialize({ value: new Decimal('3.14159') }); + const roundTripped = SuperJSON.deserialize(serialized) as { value: Decimal }; + expect(Decimal.isDecimal(roundTripped.value)).toBe(true); + expect(roundTripped.value.toString()).toBe('3.14159'); + }); +}); diff --git a/packages/server/test/api/rest.test.ts b/packages/server/test/api/rest.test.ts index b40ff604..32a39f8d 100644 --- a/packages/server/test/api/rest.test.ts +++ b/packages/server/test/api/rest.test.ts @@ -3163,4 +3163,79 @@ describe('REST server tests', () => { }); }); }); + + describe('REST server tests - handler extension', () => { + const schema = ` + model Post { + id String @id + title String + } + `; + + class CustomRestApiHandler extends RestApiHandler { + public readonly buildFilterCalls: Array<{ + type: string; + query: Record | undefined; + filter: unknown; + }> = []; + + protected override buildFilter( + type: string, + query: Record | undefined, + ) { + const result = super.buildFilter(type, query); + if (type !== 'post') { + this.buildFilterCalls.push({ type, query, filter: result.filter }); + return result; + } + + const baseFilter = + result.filter && typeof result.filter === 'object' && !Array.isArray(result.filter) + ? { ...(result.filter as Record) } + : {}; + + const modified = { + ...result, + filter: { + ...baseFilter, + title: 'second', + }, + }; + + this.buildFilterCalls.push({ type, query, filter: modified.filter }); + return modified; + } + } + + beforeEach(async () => { + client = await createTestClient(schema); + await client.post.create({ data: { id: 'post-first', title: 'first' } }); + await client.post.create({ data: { id: 'post-second', title: 'second' } }); + }); + + it('allows extending RestApiHandler to customize filtering', async () => { + const customHandler = new CustomRestApiHandler({ + schema: client.$schema, + endpoint: 'http://localhost/api', + }); + + const response = await customHandler.handleRequest({ + method: 'get', + path: '/post', + query: {}, + client, + }); + + expect(customHandler.buildFilterCalls).toHaveLength(1); + expect(customHandler.buildFilterCalls[0].type).toBe('post'); + expect(customHandler.buildFilterCalls[0].filter).toMatchObject({ title: 'second' }); + + expect(response.status).toBe(200); + const body = response.body as { + data: Array<{ attributes: { title: string } }>; + }; + expect(body.data).toHaveLength(1); + expect(body.data[0].attributes.title).toBe('second'); + }); + }); }); diff --git a/packages/server/test/api/rpc.test.ts b/packages/server/test/api/rpc.test.ts index 19e44ca0..2f50bf5d 100644 --- a/packages/server/test/api/rpc.test.ts +++ b/packages/server/test/api/rpc.test.ts @@ -508,6 +508,64 @@ describe('RPC API Handler Tests', () => { expect(r.data).toBeNull(); }); + it('allows extending RPCApiHandler to customize query unmarshalling', async () => { + await rawClient.post.deleteMany(); + await rawClient.user.deleteMany(); + + await rawClient.user.create({ + data: { + id: 'ext-user', + email: 'ext@example.com', + posts: { + create: [ + { id: 'ext-post-1', title: 'first', published: true }, + { id: 'ext-post-2', title: 'second', published: true }, + ], + }, + }, + }); + + class CustomHandler extends RPCApiHandler { + public readonly unmarshalCalls: Array<{ value: string; meta: string | undefined; result: unknown }> = []; + protected override unmarshalQ(value: string, meta: string | undefined) { + const result = super.unmarshalQ(value, meta); + this.unmarshalCalls.push({ value, meta, result }); + const asRecord = (result ?? {}) as Record; + const baseWhere = (asRecord.where ?? {}) as Record; + return { + ...asRecord, + where: { + ...baseWhere, + title: 'second', + }, + }; + } + } + + const handler = new CustomHandler({ schema: client.$schema }); + const callHandler = (args: Parameters[0]) => handler.handleRequest(args); + + const response = await callHandler({ + method: 'get', + path: '/post/findMany', + client: rawClient, + query: { + q: JSON.stringify({ where: {} }), + }, + }); + + expect(handler.unmarshalCalls).toHaveLength(1); + expect(handler.unmarshalCalls[0].value).toBeDefined(); + expect(handler.unmarshalCalls[0].result).toEqual({ where: {} }); + expect(response.status).toBe(200); + const responseBody = response.body as { data: Array<{ title: string }> }; + expect(responseBody.data).toHaveLength(1); + expect(responseBody.data[0].title).toBe('second'); + + await rawClient.post.deleteMany(); + await rawClient.user.deleteMany(); + }); + function makeHandler() { const handler = new RPCApiHandler({ schema: client.$schema }); return async (args: any) => { diff --git a/packages/server/tsup.config.ts b/packages/server/tsup.config.ts index 4c236d2f..70009f51 100644 --- a/packages/server/tsup.config.ts +++ b/packages/server/tsup.config.ts @@ -2,7 +2,9 @@ import { defineConfig } from 'tsup'; export default defineConfig({ entry: { + types: 'src/types.ts', api: 'src/api/index.ts', + common: 'src/adapter/common.ts', express: 'src/adapter/express/index.ts', next: 'src/adapter/next/index.ts', fastify: 'src/adapter/fastify/index.ts',