From cd9ccd750407b2a69148beeb3985a95da29a94b9 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Sun, 8 Feb 2026 14:18:17 -0500 Subject: [PATCH 01/10] feat: implemented generate middleware --- js/ai/src/generate.ts | 163 +++++- js/ai/src/generate/action.ts | 127 ++++- js/ai/src/generate/middleware.ts | 221 ++++++++ js/ai/src/generate/resolve-tool-requests.ts | 188 ++++--- js/ai/src/index.ts | 10 + js/ai/src/model-types.ts | 3 + js/ai/src/plugin.ts | 27 + js/ai/tests/generate/action_test.ts | 52 +- js/ai/tests/generate/middleware_test.ts | 578 ++++++++++++++++++++ js/core/src/plugin.ts | 14 + js/core/src/reflection.ts | 16 +- js/core/tests/reflection_test.ts | 102 ++++ js/docs/generate-middleware.md | 85 +++ js/genkit/src/common.ts | 7 + js/genkit/src/genkit.ts | 8 + js/genkit/src/plugin.ts | 31 +- js/genkit/tests/generate_test.ts | 37 ++ 17 files changed, 1531 insertions(+), 138 deletions(-) create mode 100644 js/ai/src/generate/middleware.ts create mode 100644 js/ai/src/plugin.ts create mode 100644 js/ai/tests/generate/middleware_test.ts create mode 100644 js/core/tests/reflection_test.ts create mode 100644 js/docs/generate-middleware.md diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index d5c0abf168..48dc68dd34 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -40,6 +40,13 @@ import { shouldInjectFormatInstructions, } from './generate/action.js'; import { GenerateResponseChunk } from './generate/chunk.js'; +import { + GenerateMiddleware, + generateMiddleware, + GenerateMiddlewareDef, + MiddlewareRef, + resolveMiddleware, +} from './generate/middleware.js'; import { GenerateResponse } from './generate/response.js'; import { Message } from './message.js'; import { @@ -171,7 +178,7 @@ export interface GenerateOptions< */ streamingCallback?: StreamingCallback; /** Middleware to be used with this model call. 
*/ - use?: ModelMiddlewareArgument[]; + use?: (ModelMiddlewareArgument | GenerateMiddleware | MiddlewareRef)[]; /** Additional context (data, like e.g. auth) to be passed down to tools, prompts and other sub actions. */ context?: ActionContext; /** Abort signal for the generate request. */ @@ -361,6 +368,108 @@ function messagesFromOptions(options: GenerateOptions): MessageData[] { /** A GenerationBlockedError is thrown when a generation is blocked. */ export class GenerationBlockedError extends GenerationResponseError {} +/** + * Normalizes a mix of middleware representations into an array of standardized `MiddlewareRef`s. + * Any raw functional middleware or unregistered middleware objects are dynamically registered + * into the provided registry. + * + * @param registry The registry to use for looking up or dynamically registering middleware. + * @param middlewareList An array of middleware functions, instances, or references. + * @returns A promise resolving to an array of normalized `MiddlewareRef` objects. 
+ */ +export async function normalizeMiddleware( + registry: Registry, + middlewareList?: ( + | ModelMiddlewareArgument + | GenerateMiddleware + | MiddlewareRef + )[] +): Promise { + if (!middlewareList || middlewareList.length === 0) { + return []; + } + + const refs: MiddlewareRef[] = []; + + for (let i = 0; i < middlewareList.length; i++) { + const middleware = middlewareList[i]; + + if ( + typeof middleware === 'function' && + (middleware as any).instantiate && + (middleware as any).plugin + ) { + throw new GenkitError({ + status: 'INVALID_ARGUMENT', + message: `Middleware ${(middleware as any).name || 'function'} must be called with () when used in 'use' array.`, + }); + } + + if (typeof middleware === 'function') { + const name = `dynamic-middleware-${i}-${Math.random().toString(36).slice(2)}`; + + const wrappedDef = generateMiddleware( + { name, metadata: { dynamic: true } }, + () => ({ + model: async (req, ctx, next) => { + if (middleware.length === 3) { + return (middleware as any)( + req, + ctx, + async (modifiedReq: any, opts: any) => + next(modifiedReq || req, opts || ctx) + ); + } else { + return (middleware as any)(req, async (modifiedReq: any) => + next(modifiedReq || req, ctx) + ); + } + }, + }) + ); + registry.registerValue('middleware', name, wrappedDef); + refs.push({ name }); + continue; + } + + if ( + typeof middleware === 'object' && + middleware !== null && + 'instantiate' in middleware && + typeof middleware.instantiate === 'function' + ) { + const def = middleware as GenerateMiddleware; + const registered = await registry.lookupValue( + 'middleware', + def.name + ); + if (!registered) { + registry.registerValue('middleware', def.name, def); + } + refs.push({ name: def.name }); + continue; + } + + if ( + typeof middleware === 'object' && + middleware !== null && + 'name' in middleware + ) { + const ref = middleware as MiddlewareRef & { __def?: GenerateMiddleware }; + const registered = await registry.lookupValue( + 'middleware', + ref.name + 
); + if (!registered && ref.__def) { + registry.registerValue('middleware', ref.name, ref.__def); + } + refs.push({ name: ref.name, config: ref.config }); + } + } + + return refs; +} + /** * Generate calls a generative model based on the provided prompt and configuration. If * `history` is provided, the generation will include a conversation history in its @@ -386,8 +495,16 @@ export async function generate< }; const resolvedFormat = await resolveFormat(registry, resolvedOptions.output); - registry = maybeRegisterDynamicTools(registry, resolvedOptions); - registry = maybeRegisterDynamicResources(registry, resolvedOptions); + registry = Registry.withParent(registry); + + maybeRegisterDynamicTools(registry, resolvedOptions); + maybeRegisterDynamicResources(registry, resolvedOptions); + + const middlewareRefs = await normalizeMiddleware( + registry, + resolvedOptions.use + ); + resolvedOptions.use = middlewareRefs; // Cast back because `use` can be generic const params = await toGenerateActionOptions(registry, resolvedOptions); @@ -399,10 +516,14 @@ export async function generate< const streamingCallback = stripNoop( resolvedOptions.onChunk ?? 
resolvedOptions.streamingCallback ) as StreamingCallback; + + const resolvedMiddleware = await resolveMiddleware(registry, middlewareRefs); + maybeRegisterDynamicMiddlewareTools(registry, resolvedMiddleware); + const response = await runWithContext(resolvedOptions.context, () => generateHelper(registry, { rawRequest: params, - middleware: resolvedOptions.use, + middleware: resolvedMiddleware, abortSignal: resolvedOptions.abortSignal, streamingCallback, }) @@ -450,42 +571,39 @@ export async function generateOperation< return operation; } +export function maybeRegisterDynamicMiddlewareTools( + registry: Registry, + middlewares?: GenerateMiddlewareDef[] +) { + middlewares?.forEach((mw) => { + mw.tools?.forEach((t) => { + if (isDynamicTool(t)) { + registry.registerAction('tool', t as Action); + } + }); + }); +} + function maybeRegisterDynamicTools< O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, ->(registry: Registry, options: GenerateOptions): Registry { - let hasDynamicTools = false; +>(registry: Registry, options: GenerateOptions) { options?.tools?.forEach((t) => { if (isDynamicTool(t)) { - if (!hasDynamicTools) { - hasDynamicTools = true; - // Create a temporary registry with dynamic tools for the duration of this - // generate request. - registry = Registry.withParent(registry); - } registry.registerAction('tool', t as Action); } }); - return registry; } function maybeRegisterDynamicResources< O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, ->(registry: Registry, options: GenerateOptions): Registry { - let hasDynamicResources = false; +>(registry: Registry, options: GenerateOptions) { options?.resources?.forEach((r) => { if (isDynamicResourceAction(r)) { - if (!hasDynamicResources) { - hasDynamicResources = true; - // Create a temporary registry with dynamic tools for the duration of this - // generate request. 
- registry = Registry.withParent(registry); - } registry.registerAction('resource', r); } }); - return registry; } export async function toGenerateActionOptions< @@ -542,6 +660,7 @@ export async function toGenerateActionOptions< returnToolRequests: options.returnToolRequests, maxTurns: options.maxTurns, stepName: options.stepName, + use: options.use as MiddlewareRef[] | undefined, }; // if config is empty and it was not explicitly passed in, we delete it, don't want {} if (Object.keys(params.config).length === 0 && !options.config) { diff --git a/js/ai/src/generate/action.ts b/js/ai/src/generate/action.ts index 8d755e3686..72ecc80f4d 100644 --- a/js/ai/src/generate/action.ts +++ b/js/ai/src/generate/action.ts @@ -24,7 +24,7 @@ import { type z, } from '@genkit-ai/core'; import { logger } from '@genkit-ai/core/logging'; -import type { Registry } from '@genkit-ai/core/registry'; +import { Registry } from '@genkit-ai/core/registry'; import { SPAN_TYPE_ATTR, runInNewSpan } from '@genkit-ai/core/tracing'; import { injectInstructions, @@ -35,6 +35,8 @@ import type { Formatter } from '../formats/types.js'; import { GenerateResponse, GenerationResponseError, + maybeRegisterDynamicMiddlewareTools, + normalizeMiddleware, tagAsPreamble, } from '../generate.js'; import { GenerateResponseChunk } from '../generate/chunk.js'; @@ -43,8 +45,6 @@ import { GenerateResponseChunkSchema, GenerateResponseSchema, MessageData, - ModelMiddlewareArgument, - ModelMiddlewareWithOptions, resolveModel, type GenerateActionOptions, type GenerateActionOutputConfig, @@ -54,7 +54,6 @@ import { type GenerateResponseData, type ModelAction, type ModelInfo, - type ModelMiddleware, type ModelRequest, type Part, type Role, @@ -65,6 +64,7 @@ import { type ResourceAction, } from '../resource.js'; import { resolveTools, toToolDefinition, type ToolAction } from '../tool.js'; +import { GenerateMiddlewareDef, resolveMiddleware } from './middleware.js'; import { assertValidToolNames, resolveResumeOption, @@ -89,15 
+89,27 @@ export function defineGenerateAction(registry: Registry): GenerateAction { streamSchema: GenerateResponseChunkSchema, }, async (request, { streamingRequested, sendChunk, context }) => { + let childRegistry = Registry.withParent(registry); + const middlewareRefs = await normalizeMiddleware( + childRegistry, + request.use + ); + request.use = middlewareRefs; // Cast back because `use` can be generic + + const resolvedMiddleware = await resolveMiddleware( + childRegistry, + request.use + ); + maybeRegisterDynamicMiddlewareTools(childRegistry, resolvedMiddleware); + const generateFn = ( sendChunk?: StreamingCallback ) => - generate(registry, { + generateActionImpl(childRegistry, { rawRequest: request, currentTurn: 0, messageIndex: 0, - // Generate util action does not support middleware. Maybe when we add named/registered middleware.... - middleware: [], + middleware: resolvedMiddleware, streamingCallback: sendChunk, context, }); @@ -117,7 +129,7 @@ export async function generateHelper( registry: Registry, options: { rawRequest: GenerateActionOptions; - middleware?: ModelMiddlewareArgument[]; + middleware?: GenerateMiddlewareDef[]; currentTurn?: number; messageIndex?: number; abortSignal?: AbortSignal; @@ -140,7 +152,7 @@ export async function generateHelper( async (metadata) => { metadata.name = options.rawRequest.stepName || 'generate'; metadata.input = options.rawRequest; - const output = await generate(registry, { + const output = await generateActionImpl(registry, { rawRequest: options.rawRequest, middleware: options.middleware, currentTurn, @@ -243,7 +255,65 @@ function applyTransferPreamble( }); } -async function generate( +async function generateActionImpl( + registry: Registry, + args: { + rawRequest: GenerateActionOptions; + middleware: GenerateMiddlewareDef[] | undefined; + currentTurn: number; + messageIndex: number; + abortSignal?: AbortSignal; + streamingCallback?: StreamingCallback; + context?: Record; + } +): Promise { + const { + rawRequest, 
+ middleware, + currentTurn, + messageIndex, + abortSignal, + streamingCallback, + context, + } = args; + + if (currentTurn === 0 && middleware && middleware.length > 0) { + const dispatchGenerate = async ( + index: number, + req: GenerateActionOptions, + ctx: ActionRunOptions + ): Promise => { + if (index === middleware.length) { + return generateActionTurn(registry, { + rawRequest: req, + middleware, + currentTurn, + messageIndex, + abortSignal: ctx.abortSignal, + streamingCallback: ctx.onChunk, + context: ctx.context, + }); + } + const currentMiddleware = middleware[index]; + if (currentMiddleware.generate) { + return currentMiddleware.generate(req, ctx, async (modifiedReq, opts) => + dispatchGenerate(index + 1, modifiedReq || req, opts || ctx) + ); + } else { + return dispatchGenerate(index + 1, req, ctx); + } + }; + return dispatchGenerate(0, rawRequest, { + abortSignal, + onChunk: streamingCallback, + context, + }); + } else { + return generateActionTurn(registry, args); + } +} + +async function generateActionTurn( registry: Registry, { rawRequest, @@ -255,7 +325,7 @@ async function generate( context, }: { rawRequest: GenerateActionOptions; - middleware: ModelMiddlewareArgument[] | undefined; + middleware: GenerateMiddlewareDef[] | undefined; currentTurn: number; messageIndex: number; abortSignal?: AbortSignal; @@ -267,6 +337,11 @@ async function generate( registry, rawRequest ); + + // Append tools supplied by middleware + if (middleware) { + tools.push(...middleware.flatMap((m) => m.tools || [])); + } rawRequest = applyFormat(rawRequest, format); rawRequest = await applyResources(registry, rawRequest, resources); @@ -277,7 +352,7 @@ async function generate( revisedRequest, interruptedResponse, toolMessage: resumedToolMessage, - } = await resolveResumeOption(registry, rawRequest); + } = await resolveResumeOption(registry, rawRequest, middleware || []); // NOTE: in the future we should make it possible to interrupt a restart, but // at the moment it's too 
complicated because it's not clear how to return a // response that amends history but doesn't generate a new message, so we throw @@ -329,35 +404,32 @@ async function generate( var response: GenerateResponse; const sendChunk = streamingCallback && - (((chunk: GenerateResponseChunkData) => - streamingCallback && - streamingCallback(makeChunk('model', chunk))) as any); - const dispatch = async ( + ((chunk: GenerateResponseChunkData) => + streamingCallback && streamingCallback(makeChunk('model', chunk))); + const dispatchModel = async ( index: number, req: z.infer, actionOpts: ActionRunOptions - ) => { + ): Promise => { if (!middleware || index === middleware.length) { // end of the chain, call the original model action return await model(req, actionOpts); } const currentMiddleware = middleware[index]; - if (currentMiddleware.length === 3) { - return (currentMiddleware as ModelMiddlewareWithOptions)( + if (currentMiddleware.model) { + return currentMiddleware.model( req, actionOpts, async (modifiedReq, opts) => - dispatch(index + 1, modifiedReq || req, opts || actionOpts) + dispatchModel(index + 1, modifiedReq || req, opts || actionOpts) ); } else { - return (currentMiddleware as ModelMiddleware)(req, async (modifiedReq) => - dispatch(index + 1, modifiedReq || req, actionOpts) - ); + return dispatchModel(index + 1, req, actionOpts); } }; - const modelResponse = await dispatch(0, request, { + const modelResponse = await dispatchModel(0, request, { abortSignal, context, onChunk: sendChunk, @@ -405,7 +477,12 @@ async function generate( } const { revisedModelMessage, toolMessage, transferPreamble } = - await resolveToolRequests(registry, rawRequest, generatedMessage); + await resolveToolRequests( + rawRequest, + generatedMessage, + tools, + middleware || [] + ); // if an interrupt message is returned, stop the tool loop and return a response if (revisedModelMessage) { diff --git a/js/ai/src/generate/middleware.ts b/js/ai/src/generate/middleware.ts new file mode 100644 
index 0000000000..5d8755201d --- /dev/null +++ b/js/ai/src/generate/middleware.ts @@ -0,0 +1,221 @@ +/** + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { ActionRunOptions, GenkitError, z } from '@genkit-ai/core'; +import type { Registry } from '@genkit-ai/core/registry'; +import { toJsonSchema } from '@genkit-ai/core/schema'; +import type { GenerateActionOptions } from '../model-types.js'; +import { + GenerateRequest, + GenerateResponseData, + ToolRequestPart, + ToolResponsePart, +} from '../model.js'; +import { GenkitPluginV2 } from '../plugin.js'; +import { ToolAction } from '../tool.js'; + +/** Descriptor for a registered middleware, returned by reflection API. */ +export const MiddlewareDescSchema = z.object({ + /** Unique name of the middleware. */ + name: z.string(), + /** Human-readable description of what the middleware does. */ + description: z.string().optional(), + /** JSON Schema for the middleware's configuration. */ + configSchema: z.record(z.any()).nullish(), + /** User defined metadata for the middleware. */ + metadata: z.record(z.any()).nullish(), +}); +export type MiddlewareDesc = z.infer; + +/** Reference to a registered middleware with optional configuration. */ +export const MiddlewareRefSchema = z.object({ + /** Name of the registered middleware. */ + name: z.string(), + /** Configuration for the middleware (schema defined by the middleware). 
*/ + config: z.any().optional(), +}); +export type MiddlewareRef = z.infer; + +/** + * Defines a Genkit Generate Middleware instance, which can be configured and registered. + * When invoked with an optional configuration, it returns a reference suitable for + * inclusion in a `GenerateOptions.use` array. + */ +export interface GenerateMiddleware + extends MiddlewareDesc { + /** Configures the middleware, returning a MiddlewareRef for usage in `generate({use: [...]})`. */ + (config?: z.infer): MiddlewareRef & { __def: GenerateMiddleware }; + /** The unique name of this middleware. */ + name: string; + /** An optional description of what the middleware does. */ + description?: string; + /** An optional Zod schema for validating the middleware's configuration. */ + configSchema?: C; + /** Metadata describing this middleware. */ + metadata?: Record; + /** + * Factory function that receives the validated configuration and creates + * a `GenerateMiddlewareDef` holding the active hooks. + */ + instantiate: (config?: z.infer) => GenerateMiddlewareDef; + /** + * Optional plugin wrapper exposing this middleware for framework-level registration. + */ + plugin: (config?: z.infer) => GenkitPluginV2; + /** Generates a JSON-compatible representation of the middleware metadata. */ + toJson: () => MiddlewareDesc; +} + +/** + * An instantiated implementation of a Generate Middleware. + * Provides optional hooks to intercept the high-level `generate` action, + * the underlying `model` execution, or individual `tool` calls, as well as + * tools to inject into the execution. + */ +export interface GenerateMiddlewareDef { + /** + * Hook for intercepting the top-level generate action. + * Can be used to inject request parameters, modify the response, or catch errors. 
+ */ + generate?: ( + req: GenerateActionOptions, + ctx: ActionRunOptions, + next: ( + req: GenerateActionOptions, + ctx: ActionRunOptions + ) => Promise + ) => Promise; + /** + * Hook for intercepting the underlying model execution. + * Ideal for model-level caching, retry logic, or prompt/response parsing. + */ + model?: ( + req: GenerateRequest, + ctx: ActionRunOptions, + next: ( + req: GenerateRequest, + ctx: ActionRunOptions + ) => Promise + ) => Promise; + /** + * Hook for intercepting individual tool calls. + * Enables caching tool responses, validating inputs, or overriding tool execution. + */ + tool?: ( + req: ToolRequestPart, + ctx: ActionRunOptions, + next: ( + req: ToolRequestPart, + ctx: ActionRunOptions + ) => Promise<{ + response?: ToolResponsePart; + interrupt?: ToolRequestPart; + preamble?: GenerateActionOptions; + }> + ) => Promise<{ + response?: ToolResponsePart; + interrupt?: ToolRequestPart; + preamble?: GenerateActionOptions; + }>; + /** + * Tools to statically inject into the generation request whenever this middleware is active. 
+ */ + tools?: ToolAction[]; +} + +export function generateMiddleware< + ConfigSchema extends z.ZodTypeAny = z.ZodTypeAny, +>( + options: { + name: string; + configSchema?: ConfigSchema; + description?: string; + metadata?: Record; + }, + middlewareFn: (config?: z.infer) => GenerateMiddlewareDef +): GenerateMiddleware { + const def = function (config?: z.infer) { + return { + name: options.name, + config, + __def: def, + }; + } as GenerateMiddleware; + + Object.defineProperty(def, 'name', { value: options.name }); + def.configSchema = options.configSchema; + def.description = options.description; + def.metadata = options.metadata; + def.instantiate = middlewareFn; + def.plugin = (pluginConfig?: z.infer) => ({ + name: `middleware:${options.name}`, + version: 'v2', + generateMiddleware: () => { + if (pluginConfig === undefined) { + return [def]; + } + const wrappedDef = function (config?: z.infer) { + return def(config ?? pluginConfig); + } as GenerateMiddleware; + + Object.defineProperty(wrappedDef, 'name', { value: options.name }); + wrappedDef.configSchema = options.configSchema; + wrappedDef.description = options.description; + wrappedDef.metadata = options.metadata; + wrappedDef.instantiate = (reqConfig) => + def.instantiate(reqConfig ?? pluginConfig); + wrappedDef.plugin = def.plugin; + + return [wrappedDef]; + }, + model: (_) => { + throw new Error('Not supported for middleware plugins'); + }, + }); + + def.toJson = () => ({ + name: options.name, + description: options.description, + configSchema: options.configSchema + ? 
toJsonSchema({ schema: options.configSchema }) + : undefined, + metadata: options.metadata, + }); + + return def; +} + +export async function resolveMiddleware( + registry: Registry, + refs?: MiddlewareRef[] +): Promise { + const result: GenerateMiddlewareDef[] = []; + if (!refs) return result; + for (const ref of refs) { + const def = await registry.lookupValue( + 'middleware', + ref.name + ); + if (!def) { + throw new GenkitError({ + status: 'NOT_FOUND', + message: `Middleware ${ref.name} not found in registry.`, + }); + } + result.push(def.instantiate(ref.config)); + } + return result; +} diff --git a/js/ai/src/generate/resolve-tool-requests.ts b/js/ai/src/generate/resolve-tool-requests.ts index c72fb096d4..d00e83038f 100644 --- a/js/ai/src/generate/resolve-tool-requests.ts +++ b/js/ai/src/generate/resolve-tool-requests.ts @@ -14,7 +14,12 @@ * limitations under the License. */ -import { GenkitError, stripUndefinedProps, z } from '@genkit-ai/core'; +import { + ActionRunOptions, + GenkitError, + stripUndefinedProps, + z, +} from '@genkit-ai/core'; import { logger } from '@genkit-ai/core/logging'; import type { Registry } from '@genkit-ai/core/registry'; import type { @@ -81,10 +86,13 @@ export function toPendingOutput( }; } +import { GenerateMiddlewareDef } from './middleware.js'; + export async function resolveToolRequest( rawRequest: GenerateActionOptions, part: ToolRequestPart, toolMap: Record, + middleware: GenerateMiddlewareDef[] = [], runOptions?: ToolRunOptions ): Promise<{ response?: ToolResponsePart; @@ -100,72 +108,106 @@ export async function resolveToolRequest( }); } - // if it's a prompt action, go ahead and render the preamble - if (isPromptAction(tool)) { - const metadata = tool.__action.metadata as Record; - const preamble = { - ...(await tool(part.toolRequest.input)), - model: metadata.prompt?.model, - }; - const response = { - toolResponse: { - name: part.toolRequest.name, - ref: part.toolRequest.ref, - output: `transferred to 
${part.toolRequest.name}`, - }, - }; - - return { preamble, response }; - } - - // otherwise, execute the tool and catch interrupts - try { - const output = await tool(part.toolRequest.input, toRunOptions(part)); - if (tool.__action.actionType === 'tool.v2') { - const multipartResponse = output as z.infer< - typeof MultipartToolResponseSchema - >; - const response = stripUndefinedProps({ - toolResponse: { - name: part.toolRequest.name, - ref: part.toolRequest.ref, - output: multipartResponse.output, - content: multipartResponse.content, - } as ToolResponse, - }); - - return { response }; + const dispatch = async ( + index: number, + req: ToolRequestPart, + ctx: ActionRunOptions + ): Promise<{ + response?: ToolResponsePart; + interrupt?: ToolRequestPart; + preamble?: GenerateActionOptions; + }> => { + if (index === middleware.length) { + return executeTool(req, ctx); + } + const currentMiddleware = middleware[index]; + if (currentMiddleware.tool) { + return currentMiddleware.tool(req, ctx, async (modifiedReq, opts) => + dispatch(index + 1, modifiedReq || req, opts || ctx) + ); } else { - const response = stripUndefinedProps({ + return dispatch(index + 1, req, ctx); + } + }; + + const executeTool = async ( + req: ToolRequestPart, + ctx: ActionRunOptions + ): Promise<{ + response?: ToolResponsePart; + interrupt?: ToolRequestPart; + preamble?: GenerateActionOptions; + }> => { + // if it's a prompt action, go ahead and render the preamble + if (isPromptAction(tool)) { + const metadata = tool.__action.metadata as Record; + const preamble = { + ...(await tool(req.toolRequest.input)), + model: metadata.prompt?.model, + }; + const response = { toolResponse: { - name: part.toolRequest.name, - ref: part.toolRequest.ref, - output, + name: req.toolRequest.name, + ref: req.toolRequest.ref, + output: `transferred to ${req.toolRequest.name}`, }, - }); + }; - return { response }; + return { preamble, response }; } - } catch (e) { - if ( - e instanceof ToolInterruptError || - // 
There's an inexplicable case when the above type check fails, only in tests. - (e as Error).name === 'ToolInterruptError' - ) { - const ie = e as ToolInterruptError; - logger.debug( - `tool '${toolMap[part.toolRequest?.name].__action.name}' triggered an interrupt${ie.metadata ? `: ${JSON.stringify(ie.metadata)}` : ''}` - ); - const interrupt = { - toolRequest: part.toolRequest, - metadata: { ...part.metadata, interrupt: ie.metadata || true }, - }; - return { interrupt }; + // otherwise, execute the tool and catch interrupts + try { + const output = await tool(req.toolRequest.input, ctx as ToolRunOptions); + if (tool.__action.actionType === 'tool.v2') { + const multipartResponse = output as z.infer< + typeof MultipartToolResponseSchema + >; + const response = stripUndefinedProps({ + toolResponse: { + name: req.toolRequest.name, + ref: req.toolRequest.ref, + output: multipartResponse.output, + content: multipartResponse.content, + } as ToolResponse, + }); + + return { response }; + } else { + const response = stripUndefinedProps({ + toolResponse: { + name: req.toolRequest.name, + ref: req.toolRequest.ref, + output, + }, + }); + + return { response }; + } + } catch (e) { + if ( + e instanceof ToolInterruptError || + // There's an inexplicable case when the above type check fails, only in tests. + (e as Error).name === 'ToolInterruptError' + ) { + const ie = e as ToolInterruptError; + logger.debug( + `tool '${toolMap[req.toolRequest?.name].__action.name}' triggered an interrupt${ie.metadata ? `: ${JSON.stringify(ie.metadata)}` : ''}` + ); + const interrupt = { + toolRequest: req.toolRequest, + metadata: { ...req.metadata, interrupt: ie.metadata || true }, + }; + + return { interrupt }; + } + + throw e; } + }; - throw e; - } + const initialCtx = runOptions ?? 
toRunOptions(part); + return dispatch(0, part, initialCtx); } /** @@ -174,15 +216,16 @@ export async function resolveToolRequest( * if a prompt tool is called */ export async function resolveToolRequests( - registry: Registry, rawRequest: GenerateActionOptions, - generatedMessage: MessageData + generatedMessage: MessageData, + tools: ToolAction[], + middleware: GenerateMiddlewareDef[] = [] ): Promise<{ revisedModelMessage?: MessageData; toolMessage?: MessageData; transferPreamble?: GenerateActionOptions; }> { - const toolMap = toToolMap(await resolveTools(registry, rawRequest.tools)); + const toolMap = toToolMap(tools); const responseParts: ToolResponsePart[] = []; let hasInterrupts = false; @@ -200,7 +243,8 @@ export async function resolveToolRequests( const { preamble, response, interrupt } = await resolveToolRequest( rawRequest, part as ToolRequestPart, - toolMap + toolMap, + middleware ); if (preamble) { @@ -268,7 +312,8 @@ function findCorrespondingToolResponse( async function resolveResumedToolRequest( rawRequest: GenerateActionOptions, part: ToolRequestPart, - toolMap: Record + toolMap: Record, + middleware: GenerateMiddlewareDef[] = [] ): Promise<{ toolRequest?: ToolRequestPart; toolResponse?: ToolResponsePart; @@ -320,7 +365,8 @@ async function resolveResumedToolRequest( const { response, interrupt, preamble } = await resolveToolRequest( rawRequest, restartRequest, - toolMap + toolMap, + middleware ); if (preamble) { @@ -357,7 +403,8 @@ async function resolveResumedToolRequest( /** Amends message history to handle `resume` arguments. Returns the amended history. 
*/ export async function resolveResumeOption( registry: Registry, - rawRequest: GenerateActionOptions + rawRequest: GenerateActionOptions, + middleware: GenerateMiddlewareDef[] = [] ): Promise<{ revisedRequest?: GenerateActionOptions; interruptedResponse?: GenerateResponseData; @@ -389,7 +436,8 @@ export async function resolveResumeOption( const resolved = await resolveResumedToolRequest( rawRequest, part, - toolMap + toolMap, + middleware ); if (resolved.interrupt) { interrupted = true; @@ -444,7 +492,8 @@ export async function resolveResumeOption( export async function resolveRestartedTools( registry: Registry, - rawRequest: GenerateActionOptions + rawRequest: GenerateActionOptions, + middleware: GenerateMiddlewareDef[] = [] ): Promise { const toolMap = toToolMap(await resolveTools(registry, rawRequest.tools)); const lastMessage = rawRequest.messages.at(-1); @@ -459,7 +508,8 @@ export async function resolveRestartedTools( const { response, interrupt } = await resolveToolRequest( rawRequest, p, - toolMap + toolMap, + middleware ); // this means that it interrupted *again* after the restart diff --git a/js/ai/src/index.ts b/js/ai/src/index.ts index 7baccf7ee6..b850c6ae6c 100644 --- a/js/ai/src/index.ts +++ b/js/ai/src/index.ts @@ -55,6 +55,15 @@ export { type ResumeOptions, type ToolChoice, } from './generate.js'; +export { + MiddlewareDescSchema, + MiddlewareRefSchema, + generateMiddleware, + type GenerateMiddleware, + type GenerateMiddlewareDef, + type MiddlewareDesc, + type MiddlewareRef, +} from './generate/middleware.js'; export { Message } from './message.js'; export { GenerateResponseChunkSchema, @@ -83,6 +92,7 @@ export { type ToolResponsePart, } from './model.js'; export { type ToolRequest, type ToolResponse } from './parts.js'; +export { type GenkitPluginV2 } from './plugin.js'; export { defineHelper, definePartial, diff --git a/js/ai/src/model-types.ts b/js/ai/src/model-types.ts index cf20755ac8..a149db5a0a 100644 --- a/js/ai/src/model-types.ts +++ 
b/js/ai/src/model-types.ts @@ -16,6 +16,7 @@ import { OperationSchema, z } from '@genkit-ai/core'; import { DocumentDataSchema } from './document.js'; +import { MiddlewareRefSchema } from './generate/middleware.js'; import { CustomPartSchema, DataPartSchema, @@ -416,5 +417,7 @@ export const GenerateActionOptionsSchema = z.object({ maxTurns: z.number().optional(), /** Custom step name for this generate call to display in trace views. Defaults to "generate". */ stepName: z.string().optional(), + /** Middleware to apply to this generation. */ + use: z.array(MiddlewareRefSchema).optional(), }); export type GenerateActionOptions = z.infer; diff --git a/js/ai/src/plugin.ts b/js/ai/src/plugin.ts new file mode 100644 index 0000000000..6b3a69f1f4 --- /dev/null +++ b/js/ai/src/plugin.ts @@ -0,0 +1,27 @@ +/** + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { BaseGenkitPluginV2 } from '@genkit-ai/core'; +import { GenerateMiddleware } from './generate/middleware.js'; +import { ModelAction } from './model.js'; + +export interface GenkitPluginV2 extends BaseGenkitPluginV2 { + // Returns a list of generate middleware to be used in `generate({use: [...])`. + generateMiddleware?: () => GenerateMiddleware[]; + + // A shortcut for resolving a model. 
+ model(name: string): Promise; +} diff --git a/js/ai/tests/generate/action_test.ts b/js/ai/tests/generate/action_test.ts index c918f1a2dd..e80449d627 100644 --- a/js/ai/tests/generate/action_test.ts +++ b/js/ai/tests/generate/action_test.ts @@ -25,13 +25,14 @@ import { defineGenerateAction, type GenerateAction, } from '../../src/generate/action.js'; +import { generateMiddleware } from '../../src/generate/middleware.js'; import { GenerateActionOptionsSchema, GenerateResponseChunkSchema, GenerateResponseSchema, type GenerateResponseChunkData, } from '../../src/model.js'; -import { defineTool } from '../../src/tool.js'; +import { defineTool, tool } from '../../src/tool.js'; import { defineProgrammableModel, type ProgrammableModel } from '../helpers.js'; initNodeFeatures(); @@ -113,3 +114,52 @@ describe('spec', () => { }); }); }); + +describe('generateAction middleware injection', () => { + let registry: Registry; + let pm: ProgrammableModel; + + beforeEach(() => { + registry = new Registry(); + defineGenerateAction(registry); + pm = defineProgrammableModel(registry); + }); + + it('supports injecting tools through middleware definitions directly via action route', async () => { + const injectedTool = tool( + { + name: 'injectedTool', + description: 'desc', + inputSchema: z.object({ arg: z.string() }), + }, + async (input) => `Result: ${input.arg}` + ); + + let toolsSeen = false; + pm.handleResponse = async (req) => { + if (req.tools?.find((t) => t.name === 'injectedTool')) { + toolsSeen = true; + } + return { + message: { role: 'model', content: [{ text: 'done' }] }, + finishReason: 'stop', + } as any; + }; + + const dummyMw = generateMiddleware({ name: 'dummyMw' }, () => ({ + tools: [injectedTool], + })); + + const action = await registry.lookupAction('/util/generate'); + await action({ + model: 'programmableModel', + messages: [{ role: 'user', content: [{ text: 'test' }] }], + use: [dummyMw()], + } as any); + + assert.ok( + toolsSeen, + 'Tool was not successfully 
passed to the model from action generated route.' + ); + }); +}); diff --git a/js/ai/tests/generate/middleware_test.ts b/js/ai/tests/generate/middleware_test.ts new file mode 100644 index 0000000000..706d32795d --- /dev/null +++ b/js/ai/tests/generate/middleware_test.ts @@ -0,0 +1,578 @@ +/** + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { z } from '@genkit-ai/core'; +import { initNodeFeatures } from '@genkit-ai/core/node'; +import { Registry } from '@genkit-ai/core/registry'; +import * as assert from 'assert'; +import { beforeEach, describe, it } from 'node:test'; +import { generate, generateStream } from '../../src/generate.js'; +import { generateMiddleware } from '../../src/generate/middleware.js'; +import { defineModel } from '../../src/model.js'; +import { defineTool, tool } from '../../src/tool.js'; + +initNodeFeatures(); + +describe('generateMiddleware', () => { + let registry: Registry; + + beforeEach(() => { + registry = new Registry(); + }); + + it('runs generate and model middleware in the correct order', async () => { + const executionOrder: string[] = []; + + const mockModel = defineModel( + registry, + { name: 'mockModel' }, + async (req) => { + executionOrder.push('modelExecution'); + return { + message: { + role: 'model', + content: [{ text: 'response' }], + }, + }; + } + ); + + const testMiddleware = generateMiddleware( + { name: 'testMiddleware' }, + () => ({ + generate: async (req, ctx, next) 
=> { + executionOrder.push('generateBefore'); + const res = await next(req, ctx); + executionOrder.push('generateAfter'); + return res; + }, + model: async (req, ctx, next) => { + executionOrder.push('modelBefore'); + const res = await next(req, ctx); + executionOrder.push('modelAfter'); + return res; + }, + }) + ); + + await generate(registry, { + model: mockModel, + prompt: 'hi', + use: [testMiddleware()], + }); + + assert.deepStrictEqual(executionOrder, [ + 'generateBefore', + 'modelBefore', + 'modelExecution', + 'modelAfter', + 'generateAfter', + ]); + }); + + it('runs tool middleware correctly', async () => { + const executionOrder: string[] = []; + + const mockTool = defineTool( + registry, + { + name: 'mockTool', + description: 'A mock tool', + inputSchema: z.object({}), + outputSchema: z.string(), + }, + async () => { + executionOrder.push('toolExecution'); + return 'tool output'; + } + ); + + let turns = 0; + const mockModel = defineModel( + registry, + { name: 'mockModelWithTool' }, + async (req) => { + executionOrder.push('modelExecution'); + turns++; + if (turns === 1) { + return { + message: { + role: 'model', + content: [ + { + toolRequest: { + name: mockTool.__action.name, + ref: '123', + input: {}, + }, + }, + ], + }, + }; + } else { + return { + message: { + role: 'model', + content: [{ text: 'final response' }], + }, + }; + } + } + ); + + const testMiddleware = generateMiddleware( + { name: 'testMiddleware' }, + () => ({ + generate: async (req, ctx, next) => { + executionOrder.push('generateBefore'); + const res = await next(req, ctx); + executionOrder.push('generateAfter'); + return res; + }, + model: async (req, ctx, next) => { + executionOrder.push('modelBefore'); + const res = await next(req, ctx); + executionOrder.push('modelAfter'); + return res; + }, + tool: async (req, ctx, next) => { + executionOrder.push('toolBefore'); + const res = await next(req, ctx); + executionOrder.push('toolAfter'); + return res; + }, + }) + ); + + await 
generate(registry, { + model: mockModel, + tools: [mockTool], + prompt: 'hi', + use: [testMiddleware()], + }); + + assert.deepStrictEqual(executionOrder, [ + 'generateBefore', + 'modelBefore', // Turn 1 + 'modelExecution', + 'modelAfter', + 'toolBefore', // Tool execution + 'toolExecution', + 'toolAfter', + 'modelBefore', // Turn 2 + 'modelExecution', + 'modelAfter', + 'generateAfter', + ]); + }); + + it('supports configuration and old-style function middleware', async () => { + let configValue = ''; + + const mockModel = defineModel( + registry, + { name: 'mockModel' }, + async (req) => { + return { + message: { + role: 'model', + content: [{ text: 'response' }], + }, + }; + } + ); + + const testMiddleware = generateMiddleware( + { name: 'configMw', configSchema: z.object({ val: z.string() }) }, + (config) => ({ + model: async (req, ctx, next) => { + configValue = config?.val || ''; + return next(req, ctx); + }, + }) + ); + + let oldStyleExecuted = false; + const oldStyleMiddleware = async (req: any, next: any) => { + oldStyleExecuted = true; + return next(req); + }; + + await generate(registry, { + model: mockModel, + prompt: 'test', + use: [testMiddleware({ val: 'test_config' }), oldStyleMiddleware], + }); + + assert.strictEqual(configValue, 'test_config'); + assert.strictEqual(oldStyleExecuted, true); + }); + + it('supports pre-registered middleware (e.g. 
installed via plugin)', async () => { + let executed = false; + let configValue = ''; + + const mockModel = defineModel( + registry, + { name: 'mockModel' }, + async (req) => { + return { + message: { + role: 'model', + content: [{ text: 'response' }], + }, + }; + } + ); + + const preRegisteredMw = generateMiddleware( + { name: 'preRegisteredMw', configSchema: z.object({ val: z.string() }) }, + (config) => ({ + model: async (req, ctx, next) => { + executed = true; + configValue = config?.val || ''; + return next(req, ctx); + }, + }) + ); + + // Act as a plugin registering the middleware + const myPlugin = preRegisteredMw.plugin({ val: 'plugin_config' }); + assert.ok(myPlugin.generateMiddleware); + myPlugin.generateMiddleware().forEach((mw: any) => { + registry.registerValue('middleware', mw.name, mw); + }); + + await generate(registry, { + model: mockModel, + prompt: 'test', + use: [{ name: 'preRegisteredMw' }], + }); + + assert.strictEqual(executed, true); + assert.strictEqual(configValue, 'plugin_config'); + }); + + it('throws an error if a middleware factory is passed without being called', async () => { + const mockModel = defineModel( + registry, + { name: 'mockModel' }, + async () => ({ + message: { role: 'model', content: [{ text: 'done' }] }, + finishReason: 'stop', + }) + ); + const streamModifyingMw = generateMiddleware({ name: 'dummy' }, () => ({})); + + await assert.rejects( + async () => { + await generate(registry, { + model: mockModel, + prompt: 'test', + use: [streamModifyingMw as any], + }); + }, + (err: any) => { + assert.strictEqual(err.name, 'GenkitError'); + assert.match(err.message, /must be called with \(\)/); + return true; + } + ); + }); + + it('can intercept and modify the stream from model and generate interceptors', async () => { + const chunkIntercepts: string[] = []; + + const mockStreamingModel = defineModel( + registry, + { name: 'mockStreamingModel' }, + async (req, streamingCallback) => { + if (streamingCallback) { + 
streamingCallback({ content: [{ text: 'chunk1' }] }); + streamingCallback({ content: [{ text: 'chunk2' }] }); + } + return { + message: { + role: 'model', + content: [{ text: 'chunk1chunk2' }], + }, + finishReason: 'stop', + }; + } + ); + + const streamModifyingMw = generateMiddleware( + { name: 'streamModifier' }, + () => ({ + model: async (req, ctx, next) => { + const originalOnChunk = ctx.onChunk; + let interceptedCtx = ctx; + if (originalOnChunk) { + interceptedCtx = { + ...ctx, + onChunk: (chunk) => { + chunkIntercepts.push(`model_mw: ${chunk.content[0].text}`); + chunk.content[0].text = chunk.content[0].text.toUpperCase(); + originalOnChunk(chunk); + }, + }; + } + return next(req, interceptedCtx); + }, + generate: async (req, ctx, next) => { + const originalOnChunk = ctx.onChunk; + let interceptedCtx = ctx; + if (originalOnChunk) { + interceptedCtx = { + ...ctx, + onChunk: (chunk) => { + chunkIntercepts.push(`gen_mw: ${chunk.content[0].text}`); + chunk.content[0].text = `[${chunk.content[0].text}]`; + originalOnChunk(chunk); + }, + }; + } + const res = await next(req, interceptedCtx); + if (res.message) { + return { + ...res, + message: { + ...res.message, + content: [ + { text: `modified_result: ${res.message.content[0].text}` }, + ], + }, + }; + } + return res; + }, + }) + ); + + let finalChunks: string[] = []; + + const { response, stream } = generateStream(registry, { + model: mockStreamingModel, + prompt: 'test streaming mw', + use: [streamModifyingMw()], + }); + + for await (const chunk of stream) { + finalChunks.push(chunk.text); + } + + const res = await response; + + assert.deepStrictEqual(chunkIntercepts, [ + 'model_mw: chunk1', + 'gen_mw: CHUNK1', + 'model_mw: chunk2', + 'gen_mw: CHUNK2', + ]); + + assert.deepStrictEqual(finalChunks, ['[CHUNK1]', '[CHUNK2]']); + assert.strictEqual(res.text, 'modified_result: chunk1chunk2'); + }); + + it('executes multiple middleware in the correct order', async () => { + const executionOrder: string[] = []; + + 
const mw1 = generateMiddleware({ name: 'mw1' }, () => ({ + async generate(opts, ctx, next) { + executionOrder.push('mw1:gen:start'); + const res = await next(opts, ctx); + executionOrder.push('mw1:gen:end'); + return res; + }, + async model(req, ctx, next) { + executionOrder.push('mw1:model:start'); + const res = await next(req, ctx); + executionOrder.push('mw1:model:end'); + return res; + }, + })); + + const mw2 = generateMiddleware({ name: 'mw2' }, () => ({ + async generate(opts, ctx, next) { + executionOrder.push('mw2:gen:start'); + const res = await next(opts, ctx); + executionOrder.push('mw2:gen:end'); + return res; + }, + async model(req, ctx, next) { + executionOrder.push('mw2:model:start'); + const res = await next(req, ctx); + executionOrder.push('mw2:model:end'); + return res; + }, + })); + + const mockModel = defineModel( + registry, + { name: 'mockModel' }, + async () => ({ + message: { role: 'model', content: [{ text: 'done' }] }, + finishReason: 'stop', + }) + ); + + await generate(registry, { + model: mockModel, + prompt: 'test multiple', + use: [mw1(), mw2()], + }); + + // The entire 'generate' layer runs before we ever descend to the 'model' level + assert.deepStrictEqual(executionOrder, [ + 'mw1:gen:start', + 'mw2:gen:start', + 'mw1:model:start', + 'mw2:model:start', + 'mw2:model:end', + 'mw1:model:end', + 'mw2:gen:end', + 'mw1:gen:end', + ]); + }); + + it('supports a combination of new middleware and old-style functional middleware', async () => { + const executionOrder: string[] = []; + + const newMw = generateMiddleware({ name: 'newMw' }, () => ({ + async generate(opts, ctx, next) { + executionOrder.push('newMw:gen:start'); + const res = await next(opts, ctx); + executionOrder.push('newMw:gen:end'); + return res; + }, + async model(req, ctx, next) { + executionOrder.push('newMw:model:start'); + const res = await next(req, ctx); + executionOrder.push('newMw:model:end'); + return res; + }, + })); + + const oldMw1 = async (req: any, next: any) => 
{ + executionOrder.push('oldMw1:model:start'); + const res = await next(); // Validating 0-argument backwards-compatibility + executionOrder.push('oldMw1:model:end'); + return res; + }; + + const oldMw2 = async (req: any, ctx: any, next: any) => { + executionOrder.push('oldMw2:model:start'); + const res = await next(req, ctx); + executionOrder.push('oldMw2:model:end'); + return res; + }; + + const mockModel = defineModel( + registry, + { name: 'mockModel' }, + async () => ({ + message: { role: 'model', content: [{ text: 'done' }] }, + finishReason: 'stop', + }) + ); + + await generate(registry, { + model: mockModel, + prompt: 'test mixed', + use: [oldMw1, newMw(), oldMw2], + }); + + assert.deepStrictEqual(executionOrder, [ + 'newMw:gen:start', // Generate level ALWAYS runs first across full array + 'oldMw1:model:start', + 'newMw:model:start', + 'oldMw2:model:start', + 'oldMw2:model:end', + 'newMw:model:end', + 'oldMw1:model:end', + 'newMw:gen:end', + ]); + }); + + it('injects tools from new-style generateMiddleware and executes tool requests', async () => { + let toolExecutionCount = 0; + + const injectedTool = tool( + { + name: 'injectedTool', + description: 'injected tool description', + inputSchema: z.object({ arg: z.string() }), + outputSchema: z.string(), + }, + async (input) => { + toolExecutionCount++; + return `Result: ${input.arg}`; + } + ); + + const toolMiddleware = generateMiddleware({ name: 'toolMw' }, () => ({ + tools: [injectedTool], + })); + + let callCount = 0; + const mockToolModel = defineModel( + registry, + { name: 'mockToolModel' }, + async (req) => { + callCount++; + // Assert that the tools sent to the model include the injected tool + assert.ok(req.tools?.find((t) => t.name === 'injectedTool')); + + if (callCount === 1) { + return { + message: { + role: 'model', + content: [ + { + toolRequest: { + name: 'injectedTool', + ref: 'call_1', + input: { arg: 'hello' }, + }, + }, + ], + }, + finishReason: 'stop', + }; + } else { + 
assert.strictEqual(req.messages[2].role, 'tool'); + const toolData = req.messages[2].content[0].toolResponse; + assert.strictEqual(toolData?.name, 'injectedTool'); + assert.strictEqual(toolData?.output, 'Result: hello'); + + return { + message: { role: 'model', content: [{ text: 'final response' }] }, + finishReason: 'stop', + }; + } + } + ); + + const result = await generate(registry, { + model: mockToolModel, + prompt: 'test tools', + use: [toolMiddleware()], + }); + + assert.strictEqual(result.text, 'final response'); + assert.strictEqual(toolExecutionCount, 1); + }); +}); diff --git a/js/core/src/plugin.ts b/js/core/src/plugin.ts index 3e06d6e47c..2c255fd156 100644 --- a/js/core/src/plugin.ts +++ b/js/core/src/plugin.ts @@ -16,6 +16,7 @@ import type { z } from 'zod'; import type { Action, ActionMetadata } from './action.js'; +import type { BackgroundAction } from './background-action.js'; import type { ActionType } from './registry.js'; export interface Provider { @@ -48,3 +49,16 @@ export interface InitializedPlugin { } export type Plugin = (...args: T) => PluginProvider; + +export type ResolvableAction = Action | BackgroundAction; + +export interface BaseGenkitPluginV2 { + version: 'v2'; + name: string; + init?: () => ResolvableAction[] | Promise; + resolve?: ( + actionType: ActionType, + name: string + ) => ResolvableAction | undefined | Promise; + list?: () => ActionMetadata[] | Promise; +} diff --git a/js/core/src/reflection.ts b/js/core/src/reflection.ts index a63959bf13..8ca1da3892 100644 --- a/js/core/src/reflection.ts +++ b/js/core/src/reflection.ts @@ -167,7 +167,7 @@ export class ReflectionServer { response.status(400).send('Query parameter "type" is required.'); return; } - if (type !== 'defaultModel') { + if (type !== 'defaultModel' && type !== 'middleware') { response .status(400) .send( @@ -175,8 +175,18 @@ export class ReflectionServer { ); return; } - const values = await this.registry.listValues(type as string); - response.send(values); + 
const values = Object.values( + await this.registry.listValues(type as string) + ); + + response.send( + values.map((v: any) => { + if (typeof v.toJson === 'function') { + return v.toJson(); + } + return v; + }) + ); } catch (err) { const { message, stack } = err as Error; next({ message, stack }); diff --git a/js/core/tests/reflection_test.ts b/js/core/tests/reflection_test.ts new file mode 100644 index 0000000000..ddc58b54f6 --- /dev/null +++ b/js/core/tests/reflection_test.ts @@ -0,0 +1,102 @@ +/** + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import * as assert from 'assert'; +import getPort from 'get-port'; +import * as http from 'http'; +import { afterEach, beforeEach, describe, it } from 'node:test'; +import { ReflectionServer } from '../src/reflection.js'; +import { Registry } from '../src/registry.js'; + +describe('ReflectionServer API', () => { + let registry: Registry; + let server: ReflectionServer; + let port: number; + + beforeEach(async () => { + registry = new Registry(); + server = new ReflectionServer(registry, { port: await getPort() }); + await server.start(); + port = (server as any).server.address().port; + }); + + afterEach(async () => { + await server.stop(); + }); + + async function fetchApi(path: string) { + return new Promise<{ status: number; body: any }>((resolve, reject) => { + http + .get(`http://localhost:${port}${path}`, (res) => { + let data = ''; + res.on('data', (chunk) => { + data += chunk; + }); + res.on('end', () => { + try { + resolve({ + status: res.statusCode || 200, + body: JSON.parse(data), + }); + } catch (e) { + resolve({ + status: res.statusCode || 200, + body: data, + }); + } + }); + }) + .on('error', reject); + }); + } + + it('rejects missing type parameter for /api/values', async () => { + const res = await fetchApi('/api/values'); + assert.strictEqual(res.status, 400); + assert.strictEqual(res.body, 'Query parameter "type" is required.'); + }); + + it('rejects unsupported type parameter for /api/values', async () => { + const res = await fetchApi('/api/values?type=foo'); + assert.strictEqual(res.status, 400); + assert.match(res.body, /is not supported/); + }); + + it('returns defaultModel values', async () => { + registry.registerValue('defaultModel', 'testModel', 'my-model'); + const res = await fetchApi('/api/values?type=defaultModel'); + assert.strictEqual(res.status, 200); + assert.deepStrictEqual(res.body, ['my-model']); + }); + + it('returns middleware values mapped via toJson if available', async () => { + registry.registerValue('middleware', 
'mw1', { + name: 'mw1', + __def: {}, + toJson: () => ({ name: 'mw1', description: 'test mw1' }), + }); + registry.registerValue('middleware', 'mw2', { + name: 'mw2', // No toJson + }); + + const res = await fetchApi('/api/values?type=middleware'); + assert.strictEqual(res.status, 200); + assert.deepStrictEqual(res.body, [ + { name: 'mw1', description: 'test mw1' }, + { name: 'mw2' }, + ]); + }); +}); diff --git a/js/docs/generate-middleware.md b/js/docs/generate-middleware.md new file mode 100644 index 0000000000..a95c095f50 --- /dev/null +++ b/js/docs/generate-middleware.md @@ -0,0 +1,85 @@ +# Generate Middleware + +Middleware in Genkit JS allows you to intercept, inspect, and modify the execution of models and tools during a `generate` call. +This is useful for implementing cross-cutting concerns like logging, telemetry, caching, and retry logic. + +## Defining Middleware + +Genkit provides a `generateMiddleware` helper to create configurable middleware that can be distributed as plugins. + +Middleware hooks into different stages of generation: + +- `generate`: Wraps the entire generation process (including the tool loop). Called for each tool call iteration. +- `model`: Wraps the call to the model implementation. Called for each model call. +- `tool`: Wraps the execution of independent tool calls. Called once per tool request. 
+ +```typescript +import { generateMiddleware } from '@genkit-ai/ai'; +import { z } from 'zod'; + +export const myLogger = generateMiddleware( + { + name: 'myLogger', + configSchema: z.object({ + verbose: z.boolean().optional(), + }), + }, + (config) => ({ + async generate(options, ctx, next) { + if (config?.verbose) { + console.log( + 'Generate started with options:', + JSON.stringify(options, null, 2) + ); + } + const result = await next(options, ctx); + if (config?.verbose) { + console.log('Generate finished:', result); + } + return result; + }, + async model(request, ctx, next) { + console.log('Model called:', request); + return next(request, ctx); + }, + async tool(request, ctx, next) { + console.log('Tool called:', request.toolRequest.name); + return next(request, ctx); + }, + // Inject additional tools into the generation + tools: [ + // myCustomTool + ], + }) +); +``` + +## Usage + +You can use the defined middleware directly in your `generate` calls: + +```typescript +import { generate } from '@genkit-ai/ai'; +import { myLogger } from './my-logger'; + +await generate({ + model: 'googleai/gemini-1.5-flash', + prompt: 'Hello', + use: [myLogger({ verbose: true })], +}); +``` + +## Registering as a Plugin + +If you want to register the middleware globally or make it available via the registry (e.g. 
for inspection tools), you can use the `.plugin()` method: + +```typescript +import { genkit } from 'genkit'; +import { myLogger } from './my-logger'; + +const ai = genkit({ + plugins: [ + myLogger.plugin(), // Can pass default config here + ], +}); +``` diff --git a/js/genkit/src/common.ts b/js/genkit/src/common.ts index dc59372f39..07f52dc2b1 100644 --- a/js/genkit/src/common.ts +++ b/js/genkit/src/common.ts @@ -28,6 +28,8 @@ export { LlmStatsSchema, Message, MessageSchema, + MiddlewareDescSchema, + MiddlewareRefSchema, ModelRequestSchema, ModelResponseSchema, PartSchema, @@ -39,6 +41,7 @@ export { embedderActionMetadata, embedderRef, evaluatorRef, + generateMiddleware, indexerRef, modelActionMetadata, modelRef, @@ -59,6 +62,8 @@ export { type EvaluatorParams, type EvaluatorReference, type ExecutablePrompt, + type GenerateMiddleware, + type GenerateMiddlewareDef, type GenerateOptions, type GenerateRequest, type GenerateRequestData, @@ -77,6 +82,8 @@ export { type LlmStats, type MediaPart, type MessageData, + type MiddlewareDesc, + type MiddlewareRef, type ModelArgument, type ModelReference, type ModelRequest, diff --git a/js/genkit/src/genkit.ts b/js/genkit/src/genkit.ts index 5053f6dbfd..118e475a2b 100644 --- a/js/genkit/src/genkit.ts +++ b/js/genkit/src/genkit.ts @@ -977,6 +977,14 @@ export class Genkit implements HasRegistry { resolvedActions?.forEach((resolvedAction) => { registerActionV2(activeRegistry, resolvedAction, plugin); }); + const definedMiddleware = plugin.generateMiddleware?.(); + definedMiddleware?.forEach((middleware) => { + activeRegistry.registerValue( + 'middleware', + middleware.name, + middleware + ); + }); }, async resolver(action: ActionType, target: string) { if (!plugin.resolve) return; diff --git a/js/genkit/src/plugin.ts b/js/genkit/src/plugin.ts index 85d5966569..f44ad8c16c 100644 --- a/js/genkit/src/plugin.ts +++ b/js/genkit/src/plugin.ts @@ -14,15 +14,17 @@ * limitations under the License. 
*/ +import type { GenerateMiddleware } from '@genkit-ai/ai'; +import { type GenkitPluginV2 } from '@genkit-ai/ai'; import { type ModelAction } from '@genkit-ai/ai/model'; import { GenkitError, - type Action, type ActionMetadata, - type BackgroundAction, + type ResolvableAction, } from '@genkit-ai/core'; import type { Genkit } from './genkit.js'; import type { ActionType } from './registry.js'; + export { embedder, embedderActionMetadata } from '@genkit-ai/ai/embedder'; export { evaluator } from '@genkit-ai/ai/evaluator'; export { @@ -32,6 +34,8 @@ export { } from '@genkit-ai/ai/model'; export { reranker } from '@genkit-ai/ai/reranker'; export { indexer, retriever } from '@genkit-ai/ai/retriever'; +export { type GenkitPluginV2, type ResolvableAction }; + export interface PluginProvider { name: string; initializer: () => void | Promise; @@ -39,22 +43,6 @@ export interface PluginProvider { listActions?: () => Promise; } -export type ResolvableAction = Action | BackgroundAction; - -export interface GenkitPluginV2 { - version: 'v2'; - name: string; - init?: () => ResolvableAction[] | Promise; - resolve?: ( - actionType: ActionType, - name: string - ) => ResolvableAction | undefined | Promise; - list?: () => ActionMetadata[] | Promise; - - // A shortcut for resolving a model. 
- model(name: string): Promise; -} - export type GenkitPlugin = (genkit: Genkit) => PluginProvider; export type PluginInit = (genkit: Genkit) => void | Promise; @@ -118,6 +106,13 @@ export class GenkitPluginV2Instance implements Required { return this.plugin.list(); } + generateMiddleware(): GenerateMiddleware[] { + if (!this.plugin.generateMiddleware) { + return []; + } + return this.plugin.generateMiddleware(); + } + resolve( actionType: ActionType, name: string diff --git a/js/genkit/tests/generate_test.ts b/js/genkit/tests/generate_test.ts index ecdcb783a8..89b29f9900 100644 --- a/js/genkit/tests/generate_test.ts +++ b/js/genkit/tests/generate_test.ts @@ -19,6 +19,7 @@ import { ModelAction } from '@genkit-ai/ai/model'; import { Operation, z, type JSONSchema7 } from '@genkit-ai/core'; import * as assert from 'assert'; import { beforeEach, describe, it } from 'node:test'; +import { generateMiddleware } from '../../ai/src/generate/middleware'; import { modelRef } from '../../ai/src/model'; import { interrupt } from '../../ai/src/tool'; import { @@ -114,6 +115,42 @@ describe('generate', () => { }); }); + it('works with a middleware plugin', async () => { + let middlewareExecuted = false; + const myMiddleware = generateMiddleware( + { name: 'myMiddleware', configSchema: z.string() }, + (config) => { + return { + model: async (req, ctx, next) => { + middlewareExecuted = true; + return { + request: req, + finishReason: 'stop', + message: { + role: 'model', + content: [{ text: `${config}: hi` }], + }, + }; + }, + }; + } + ); + + const aiWithPlugin = genkit({ + model: 'echoModel', + plugins: [myMiddleware.plugin()], + }); + defineEchoModel(aiWithPlugin); + + const response = await aiWithPlugin.generate({ + prompt: 'hi', + use: [myMiddleware('z-prefix')], + }); + + assert.strictEqual(response.text, 'z-prefix: hi'); + assert.strictEqual(middlewareExecuted, true); + }); + it('streams the default model', async () => { const { response, stream } = await 
ai.generateStream('hi'); From 2c31fa9bc4e055e4df22100ea29ad0b0adc0c033 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Sun, 8 Feb 2026 20:40:58 -0500 Subject: [PATCH 02/10] feat: Apply middleware to all turns of multi-turn generation and update the test to verify multi-turn execution. --- js/ai/src/generate/action.ts | 2 +- js/ai/tests/generate/middleware_test.ts | 62 ++++++++++++++----------- 2 files changed, 35 insertions(+), 29 deletions(-) diff --git a/js/ai/src/generate/action.ts b/js/ai/src/generate/action.ts index 72ecc80f4d..e60f61ee05 100644 --- a/js/ai/src/generate/action.ts +++ b/js/ai/src/generate/action.ts @@ -277,7 +277,7 @@ async function generateActionImpl( context, } = args; - if (currentTurn === 0 && middleware && middleware.length > 0) { + if (middleware && middleware.length > 0) { const dispatchGenerate = async ( index: number, req: GenerateActionOptions, diff --git a/js/ai/tests/generate/middleware_test.ts b/js/ai/tests/generate/middleware_test.ts index 706d32795d..b301ed6534 100644 --- a/js/ai/tests/generate/middleware_test.ts +++ b/js/ai/tests/generate/middleware_test.ts @@ -135,26 +135,30 @@ describe('generateMiddleware', () => { const testMiddleware = generateMiddleware( { name: 'testMiddleware' }, - () => ({ - generate: async (req, ctx, next) => { - executionOrder.push('generateBefore'); - const res = await next(req, ctx); - executionOrder.push('generateAfter'); - return res; - }, - model: async (req, ctx, next) => { - executionOrder.push('modelBefore'); - const res = await next(req, ctx); - executionOrder.push('modelAfter'); - return res; - }, - tool: async (req, ctx, next) => { - executionOrder.push('toolBefore'); - const res = await next(req, ctx); - executionOrder.push('toolAfter'); - return res; - }, - }) + () => { + let turnCount = 0; + return { + generate: async (req, ctx, next) => { + const t = ++turnCount; + executionOrder.push('generateBefore-' + t); + const res = await next(req, ctx); + executionOrder.push('generateAfter-' + 
t); + return res; + }, + model: async (req, ctx, next) => { + executionOrder.push(`modelBefore-${turnCount}`); + const res = await next(req, ctx); + executionOrder.push(`modelAfter-${turnCount}`); + return res; + }, + tool: async (req, ctx, next) => { + executionOrder.push(`toolBefore-${turnCount}`); + const res = await next(req, ctx); + executionOrder.push(`toolAfter-${turnCount}`); + return res; + }, + }; + } ); await generate(registry, { @@ -165,17 +169,19 @@ describe('generateMiddleware', () => { }); assert.deepStrictEqual(executionOrder, [ - 'generateBefore', - 'modelBefore', // Turn 1 + 'generateBefore-1', + 'modelBefore-1', // Turn 1 'modelExecution', - 'modelAfter', - 'toolBefore', // Tool execution + 'modelAfter-1', + 'toolBefore-1', // Tool execution 'toolExecution', - 'toolAfter', - 'modelBefore', // Turn 2 + 'toolAfter-1', + 'generateBefore-2', + 'modelBefore-2', // Turn 2 'modelExecution', - 'modelAfter', - 'generateAfter', + 'modelAfter-2', + 'generateAfter-2', + 'generateAfter-1', ]); }); From 8ce3844999a6fc7ce49a17791e3f15c32dc3d076 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Mon, 9 Feb 2026 19:21:08 -0500 Subject: [PATCH 03/10] feat: Introduce a unified `GenerateAPI` class to centralize AI model interactions including generate, stream, and embed. 
--- foo.txt | 1 - js/ai/src/generate-api.ts | 343 ++++++++++++++++++ js/ai/src/generate.ts | 2 +- js/ai/src/generate/middleware.ts | 24 +- js/ai/src/index.ts | 8 +- js/ai/src/model-types.ts | 7 +- js/ai/src/model.ts | 11 +- js/ai/src/plugin.ts | 8 +- js/genkit/src/common.ts | 1 + js/genkit/src/genkit.ts | 338 +---------------- js/genkit/src/plugin.ts | 9 +- js/plugins/anthropic/src/index.ts | 18 +- js/plugins/compat-oai/src/openai/index.ts | 4 +- js/plugins/google-genai/src/googleai/index.ts | 4 +- js/plugins/google-genai/src/vertexai/index.ts | 6 +- 15 files changed, 423 insertions(+), 361 deletions(-) delete mode 100644 foo.txt create mode 100644 js/ai/src/generate-api.ts diff --git a/foo.txt b/foo.txt deleted file mode 100644 index eb857a6d0c..0000000000 --- a/foo.txt +++ /dev/null @@ -1 +0,0 @@ -model conformance specs and JS parity — I'll tackle those as a separate pass after the current PR is green. That'll involve reading the conformance KI artifacts and comparing JS plugin implementations against our Python ones to identify gaps. diff --git a/js/ai/src/generate-api.ts b/js/ai/src/generate-api.ts new file mode 100644 index 0000000000..9d12873369 --- /dev/null +++ b/js/ai/src/generate-api.ts @@ -0,0 +1,343 @@ +/** + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import { + getContext, + run, + z, + type ActionContext, + type Operation, +} from '@genkit-ai/core'; +import { type Registry } from '@genkit-ai/core/registry'; +import { checkOperation } from './check-operation.js'; +import { type DocumentData } from './document.js'; +import { + embed, + embedMany, + type EmbedderArgument, + type EmbedderParams, + type Embedding, + type EmbeddingBatch, +} from './embedder.js'; +import { + generate, + generateStream, + type GenerateOptions, + type GenerateResponse, + type GenerateStreamOptions, + type GenerateStreamResponse, +} from './generate.js'; +import { GenerationCommonConfigSchema, type Part } from './model-types.js'; + +/** + * `GenerateAPI` encapsulates model generate APIs. + */ +export class GenerateAPI { + readonly registry: Registry; + + constructor(registry: Registry) { + this.registry = registry; + } + + /** + * Embeds the given `content` using the specified `embedder`. + */ + embed( + params: EmbedderParams + ): Promise { + return embed(this.registry, params); + } + + /** + * A veneer for interacting with embedder models in bulk. + */ + embedMany(params: { + embedder: EmbedderArgument; + content: string[] | DocumentData[]; + metadata?: Record; + options?: z.infer; + }): Promise { + return embedMany(this.registry, params); + } + + /** + * Make a generate call to the default model with a simple text prompt. + * + * ```ts + * const ai = genkit({ + * plugins: [googleAI()], + * model: googleAI.model('gemini-2.5-flash'), // default model + * }) + * + * const { text } = await ai.generate('hi'); + * ``` + */ + generate( + strPrompt: string + ): Promise>>; + + /** + * Make a generate call to the default model with a multipart request. 
+ * + * ```ts + * const ai = genkit({ + * plugins: [googleAI()], + * model: googleAI.model('gemini-2.5-flash'), // default model + * }) + * + * const { text } = await ai.generate([ + * { media: {url: 'http://....'} }, + * { text: 'describe this image' } + * ]); + * ``` + */ + generate( + parts: Part[] + ): Promise>>; + + /** + * Generate calls a generative model based on the provided prompt and configuration. If + * `messages` is provided, the generation will include a conversation history in its + * request. If `tools` are provided, the generate method will automatically resolve + * tool calls returned from the model unless `returnToolRequests` is set to `true`. + * + * See {@link GenerateOptions} for detailed information about available options. + * + * ```ts + * const ai = genkit({ + * plugins: [googleAI()], + * }) + * + * const { text } = await ai.generate({ + * system: 'talk like a pirate', + * prompt: [ + * { media: { url: 'http://....' } }, + * { text: 'describe this image' } + * ], + * messages: conversationHistory, + * tools: [ userInfoLookup ], + * model: googleAI.model('gemini-2.5-flash'), + * }); + * ``` + */ + generate< + O extends z.ZodTypeAny = z.ZodTypeAny, + CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, + >( + opts: + | GenerateOptions + | PromiseLike> + ): Promise>>; + + async generate< + O extends z.ZodTypeAny = z.ZodTypeAny, + CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, + >( + options: + | string + | Part[] + | GenerateOptions + | PromiseLike> + ): Promise>> { + let resolvedOptions: GenerateOptions; + if (options instanceof Promise) { + resolvedOptions = await options; + } else if (typeof options === 'string' || Array.isArray(options)) { + resolvedOptions = { + prompt: options, + }; + } else { + resolvedOptions = options as GenerateOptions; + } + return generate(this.registry, resolvedOptions); + } + + /** + * Make a streaming generate call to the default model with a simple text prompt. 
+ * + * ```ts + * const ai = genkit({ + * plugins: [googleAI()], + * model: googleAI.model('gemini-2.5-flash'), // default model + * }) + * + * const { response, stream } = ai.generateStream('hi'); + * for await (const chunk of stream) { + * console.log(chunk.text); + * } + * console.log((await response).text); + * ``` + */ + generateStream( + strPrompt: string + ): GenerateStreamResponse>; + + /** + * Make a streaming generate call to the default model with a multipart request. + * + * ```ts + * const ai = genkit({ + * plugins: [googleAI()], + * model: googleAI.model('gemini-2.5-flash'), // default model + * }) + * + * const { response, stream } = ai.generateStream([ + * { media: {url: 'http://....'} }, + * { text: 'describe this image' } + * ]); + * for await (const chunk of stream) { + * console.log(chunk.text); + * } + * console.log((await response).text); + * ``` + */ + generateStream( + parts: Part[] + ): GenerateStreamResponse>; + + /** + * Streaming generate calls a generative model based on the provided prompt and configuration. If + * `messages` is provided, the generation will include a conversation history in its + * request. If `tools` are provided, the generate method will automatically resolve + * tool calls returned from the model unless `returnToolRequests` is set to `true`. + * + * See {@link GenerateOptions} for detailed information about available options. + * + * ```ts + * const ai = genkit({ + * plugins: [googleAI()], + * }) + * + * const { response, stream } = ai.generateStream({ + * system: 'talk like a pirate', + * prompt: [ + * { media: { url: 'http://....' 
} }, + * { text: 'describe this image' } + * ], + * messages: conversationHistory, + * tools: [ userInfoLookup ], + * model: googleAI.model('gemini-2.5-flash'), + * }); + * for await (const chunk of stream) { + * console.log(chunk.text); + * } + * console.log((await response).text); + * ``` + */ + generateStream< + O extends z.ZodTypeAny = z.ZodTypeAny, + CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, + >( + parts: + | GenerateOptions + | PromiseLike> + ): GenerateStreamResponse>; + + generateStream< + O extends z.ZodTypeAny = z.ZodTypeAny, + CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, + >( + options: + | string + | Part[] + | GenerateStreamOptions + | PromiseLike> + ): GenerateStreamResponse> { + if (typeof options === 'string' || Array.isArray(options)) { + options = { prompt: options }; + } + return generateStream(this.registry, options); + } + + /** + * Checks the status of a given operation. Returns a new operation which will contain the updated status. + * + * ```ts + * let operation = await ai.generateOperation({ + * model: googleAI.model('veo-2.0-generate-001'), + * prompt: 'A banana riding a bicycle.', + * }); + * + * while (!operation.done) { + * operation = await ai.checkOperation(operation!); + * await new Promise((resolve) => setTimeout(resolve, 5000)); + * } + * ``` + * + * @param operation + * @returns + */ + checkOperation(operation: Operation): Promise> { + return checkOperation(this.registry, operation); + } + + /** + * A flow step that executes the provided function. Each run step is recorded separately in the trace. + * + * ```ts + * ai.defineFlow('hello', async() => { + * await ai.run('step1', async () => { + * // ... step 1 + * }); + * await ai.run('step2', async () => { + * // ... step 2 + * }); + * return result; + * }) + * ``` + */ + run(name: string, func: () => Promise): Promise; + + /** + * A flow step that executes the provided function.
Each run step is recorded separately in the trace. + * + * ```ts + * ai.defineFlow('hello', async() => { + * await ai.run('step1', async () => { + * // ... step 1 + * }); + * await ai.run('step2', async () => { + * // ... step 2 + * }); + * return result; + * }) + */ + run( + name: string, + input: any, + func: (input?: any) => Promise + ): Promise; + + run( + name: string, + funcOrInput: () => Promise | any, + maybeFunc?: (input?: any) => Promise + ): Promise { + if (maybeFunc) { + return run(name, funcOrInput, maybeFunc, this.registry); + } + return run(name, funcOrInput, this.registry); + } + + /** + * Returns current action (or flow) invocation context. Can be used to access things like auth + * data set by HTTP server frameworks. If invoked outside of an action (e.g. flow or tool) will + * return `undefined`. + */ + currentContext(): ActionContext | undefined { + return getContext(); + } +} diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 48dc68dd34..a6e2cb2ff9 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -44,7 +44,6 @@ import { GenerateMiddleware, generateMiddleware, GenerateMiddlewareDef, - MiddlewareRef, resolveMiddleware, } from './generate/middleware.js'; import { GenerateResponse } from './generate/response.js'; @@ -58,6 +57,7 @@ import { type GenerateRequest, type GenerationCommonConfigSchema, type MessageData, + type MiddlewareRef, type ModelArgument, type ModelMiddlewareArgument, type Part, diff --git a/js/ai/src/generate/middleware.ts b/js/ai/src/generate/middleware.ts index 5d8755201d..c34b34c908 100644 --- a/js/ai/src/generate/middleware.ts +++ b/js/ai/src/generate/middleware.ts @@ -17,14 +17,16 @@ import { ActionRunOptions, GenkitError, z } from '@genkit-ai/core'; import type { Registry } from '@genkit-ai/core/registry'; import { toJsonSchema } from '@genkit-ai/core/schema'; +import { GenerateAPI } from '../generate-api.js'; import type { GenerateActionOptions } from '../model-types.js'; +import { type 
MiddlewareRef } from '../model-types.js'; import { GenerateRequest, GenerateResponseData, ToolRequestPart, ToolResponsePart, } from '../model.js'; -import { GenkitPluginV2 } from '../plugin.js'; +import { type GenkitPluginV2 } from '../plugin.js'; import { ToolAction } from '../tool.js'; /** Descriptor for a registered middleware, returned by reflection API. */ @@ -40,15 +42,6 @@ export const MiddlewareDescSchema = z.object({ }); export type MiddlewareDesc = z.infer; -/** Reference to a registered middleware with optional configuration. */ -export const MiddlewareRefSchema = z.object({ - /** Name of the registered middleware. */ - name: z.string(), - /** Configuration for the middleware (schema defined by the middleware). */ - config: z.any().optional(), -}); -export type MiddlewareRef = z.infer; - /** * Defines a Genkit Generate Middleware instance, which can be configured and registered. * When invoked with an optional configuration, it returns a reference suitable for @@ -70,7 +63,10 @@ export interface GenerateMiddleware * Factory function that receives the validated configuration and creates * a `GenerateMiddlewareDef` holding the active hooks. */ - instantiate: (config?: z.infer) => GenerateMiddlewareDef; + instantiate: ( + config: z.infer | undefined, + ai: GenerateAPI + ) => GenerateMiddlewareDef; /** * Optional plugin wrapper exposing this middleware for framework-level registration. */ @@ -175,8 +171,8 @@ export function generateMiddleware< wrappedDef.configSchema = options.configSchema; wrappedDef.description = options.description; wrappedDef.metadata = options.metadata; - wrappedDef.instantiate = (reqConfig) => - def.instantiate(reqConfig ?? pluginConfig); + wrappedDef.instantiate = (reqConfig, ai) => + def.instantiate(reqConfig ?? 
pluginConfig, ai); wrappedDef.plugin = def.plugin; return [wrappedDef]; @@ -215,7 +211,7 @@ export async function resolveMiddleware( message: `Middleware ${ref.name} not found in registry.`, }); } - result.push(def.instantiate(ref.config)); + result.push(def.instantiate(ref.config, new GenerateAPI(registry))); } return result; } diff --git a/js/ai/src/index.ts b/js/ai/src/index.ts index b850c6ae6c..d55b7fa68d 100644 --- a/js/ai/src/index.ts +++ b/js/ai/src/index.ts @@ -38,6 +38,7 @@ export { type EvaluatorParams, type EvaluatorReference, } from './evaluator.js'; +export { GenerateAPI } from './generate-api.js'; export { GenerateResponse, GenerateResponseChunk, @@ -57,18 +58,18 @@ export { } from './generate.js'; export { MiddlewareDescSchema, - MiddlewareRefSchema, generateMiddleware, type GenerateMiddleware, type GenerateMiddlewareDef, type MiddlewareDesc, - type MiddlewareRef, } from './generate/middleware.js'; export { Message } from './message.js'; export { GenerateResponseChunkSchema, GenerationCommonConfigSchema, MessageSchema, + MiddlewareRefSchema, + ModelReferenceSchema, ModelRequestSchema, ModelResponseSchema, PartSchema, @@ -82,6 +83,7 @@ export { type GenerationUsage, type MediaPart, type MessageData, + type MiddlewareRef, type ModelArgument, type ModelReference, type ModelRequest, @@ -92,7 +94,7 @@ export { type ToolResponsePart, } from './model.js'; export { type ToolRequest, type ToolResponse } from './parts.js'; -export { type GenkitPluginV2 } from './plugin.js'; +export { type BaseGenkitPluginV2, type GenkitPluginV2 } from './plugin.js'; export { defineHelper, definePartial, diff --git a/js/ai/src/model-types.ts b/js/ai/src/model-types.ts index a149db5a0a..ce6853d27b 100644 --- a/js/ai/src/model-types.ts +++ b/js/ai/src/model-types.ts @@ -16,7 +16,6 @@ import { OperationSchema, z } from '@genkit-ai/core'; import { DocumentDataSchema } from './document.js'; -import { MiddlewareRefSchema } from './generate/middleware.js'; import { CustomPartSchema, 
DataPartSchema, @@ -28,6 +27,12 @@ import { ToolResponsePartSchema, } from './parts.js'; +export const MiddlewareRefSchema: z.ZodTypeAny = z.object({ + name: z.string(), + config: z.any().optional(), +}); +export type MiddlewareRef = z.infer; + // // IMPORTANT: Please keep type definitions in sync with // genkit-tools/src/types/model.ts diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 442576e3d8..753dba0b68 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -357,7 +357,16 @@ function getModelMiddleware(options: { return middleware; } -export interface ModelReference { +export const ModelReferenceSchema = z.object({ + name: z.string(), + configSchema: z.any().optional(), + info: z.any().optional(), + version: z.string().optional(), + config: z.any().optional(), +}); + +export interface ModelReference + extends z.infer { name: string; configSchema?: CustomOptions; info?: ModelInfo; diff --git a/js/ai/src/plugin.ts b/js/ai/src/plugin.ts index 6b3a69f1f4..babe4829f7 100644 --- a/js/ai/src/plugin.ts +++ b/js/ai/src/plugin.ts @@ -14,9 +14,11 @@ * limitations under the License. */ -import { BaseGenkitPluginV2 } from '@genkit-ai/core'; -import { GenerateMiddleware } from './generate/middleware.js'; -import { ModelAction } from './model.js'; +import { type BaseGenkitPluginV2 } from '@genkit-ai/core'; +import { type GenerateMiddleware } from './generate/middleware.js'; +import { type ModelAction } from './model.js'; + +export { type BaseGenkitPluginV2 }; export interface GenkitPluginV2 extends BaseGenkitPluginV2 { // Returns a list of generate middleware to be used in `generate({use: [...])`. 
diff --git a/js/genkit/src/common.ts b/js/genkit/src/common.ts index 07f52dc2b1..866f2eaa7f 100644 --- a/js/genkit/src/common.ts +++ b/js/genkit/src/common.ts @@ -30,6 +30,7 @@ export { MessageSchema, MiddlewareDescSchema, MiddlewareRefSchema, + ModelReferenceSchema, ModelRequestSchema, ModelResponseSchema, PartSchema, diff --git a/js/genkit/src/genkit.ts b/js/genkit/src/genkit.ts index 118e475a2b..55d5673635 100644 --- a/js/genkit/src/genkit.ts +++ b/js/genkit/src/genkit.ts @@ -15,60 +15,42 @@ */ import { - checkOperation, + GenerateAPI, defineHelper, definePartial, definePrompt, defineTool, - embed, - evaluate, generate, - generateStream, loadPromptFolder, modelRef, prompt, - rerank, - retrieve, type BaseDataPointSchema, type Document, type EmbedderInfo, - type EmbedderParams, - type Embedding, - type EvalResponses, - type EvaluatorParams, type ExecutablePrompt, type GenerateOptions, type GenerateRequest, type GenerateResponse, type GenerateResponseChunk, type GenerateResponseData, - type GenerateStreamOptions, type GenerateStreamResponse, - type GenerationCommonConfigSchema, - type IndexerParams, type ModelArgument, type ModelReference, - type Part, type PromptConfig, type PromptGenerateOptions, - type RankedDocument, - type RerankerParams, type RetrieverAction, type RetrieverInfo, - type RetrieverParams, type ToolAction, type ToolConfig, } from '@genkit-ai/ai'; import { defineEmbedder, - embedMany, type EmbedderAction, - type EmbedderArgument, type EmbedderFn, - type EmbeddingBatch, } from '@genkit-ai/ai/embedder'; import { defineEvaluator, + evaluate, type EvaluatorAction, type EvaluatorFn, } from '@genkit-ai/ai/evaluator'; @@ -84,16 +66,21 @@ import { type ModelAction, } from '@genkit-ai/ai/model'; import { + RankedDocument, + RerankerParams, defineReranker, + rerank, type RerankerFn, type RerankerInfo, } from '@genkit-ai/ai/reranker'; import { + IndexerParams, + RetrieverParams, defineIndexer, defineRetriever, defineSimpleRetriever, index, - type 
DocumentData, + retrieve, type IndexerAction, type IndexerFn, type RetrieverFn, @@ -108,18 +95,15 @@ import { import { ActionFnArg, GenkitError, - Operation, ReflectionServer, defineDynamicActionProvider, defineFlow, defineJsonSchema, defineSchema, - getContext, isAction, isBackgroundAction, isDevEnv, registerBackgroundAction, - run, setClientHeader, type Action, type ActionContext, @@ -134,13 +118,17 @@ import { } from '@genkit-ai/core'; import { Channel } from '@genkit-ai/core/async'; import type { HasRegistry } from '@genkit-ai/core/registry'; -import type { BaseEvalDataPointSchema } from './evaluator.js'; +import type { + BaseEvalDataPointSchema, + EvalResponses, + EvaluatorParams, +} from './evaluator.js'; import { logger } from './logging.js'; import { - ResolvableAction, isPluginV2, type GenkitPlugin, type GenkitPluginV2, + type ResolvableAction, } from './plugin.js'; import { Registry, type ActionType } from './registry.js'; import { SPAN_TYPE_ATTR, runInNewSpan } from './tracing.js'; @@ -180,11 +168,10 @@ export interface GenkitOptions { * * There may be multiple Genkit instances in a single codebase. */ -export class Genkit implements HasRegistry { +export class Genkit extends GenerateAPI implements HasRegistry { + readonly registry: Registry; /** Developer-configured options. */ readonly options: GenkitOptions; - /** Registry instance that is exclusively modified by this Genkit instance. */ - readonly registry: Registry; /** Reflection server for this registry. May be null if not started. */ private reflectionServer: ReflectionServer | null = null; /** List of flows that have been registered in this instance. 
*/ @@ -195,8 +182,10 @@ export class Genkit implements HasRegistry { } constructor(options?: GenkitOptions) { + const registry = new Registry(); + super(registry); + this.registry = registry; this.options = options || {}; - this.registry = new Registry(); if (this.options.context) { this.registry.context = this.options.context; } @@ -616,27 +605,6 @@ export class Genkit implements HasRegistry { return defineReranker(this.registry, options, runner); } - /** - * Embeds the given `content` using the specified `embedder`. - */ - embed( - params: EmbedderParams - ): Promise { - return embed(this.registry, params); - } - - /** - * A veneer for interacting with embedder models in bulk. - */ - embedMany(params: { - embedder: EmbedderArgument; - content: string[] | DocumentData[]; - metadata?: Record; - options?: z.infer; - }): Promise { - return embedMany(this.registry, params); - } - /** * Evaluates the given `dataset` using the specified `evaluator`. */ @@ -674,274 +642,6 @@ export class Genkit implements HasRegistry { return retrieve(this.registry, params); } - /** - * Make a generate call to the default model with a simple text prompt. - * - * ```ts - * const ai = genkit({ - * plugins: [googleAI()], - * model: googleAI.model('gemini-2.5-flash'), // default model - * }) - * - * const { text } = await ai.generate('hi'); - * ``` - */ - generate( - strPrompt: string - ): Promise>>; - - /** - * Make a generate call to the default model with a multipart request. - * - * ```ts - * const ai = genkit({ - * plugins: [googleAI()], - * model: googleAI.model('gemini-2.5-flash'), // default model - * }) - * - * const { text } = await ai.generate([ - * { media: {url: 'http://....'} }, - * { text: 'describe this image' } - * ]); - * ``` - */ - generate( - parts: Part[] - ): Promise>>; - - /** - * Generate calls a generative model based on the provided prompt and configuration. If - * `messages` is provided, the generation will include a conversation history in its - * request. 
If `tools` are provided, the generate method will automatically resolve - * tool calls returned from the model unless `returnToolRequests` is set to `true`. - * - * See {@link GenerateOptions} for detailed information about available options. - * - * ```ts - * const ai = genkit({ - * plugins: [googleAI()], - * }) - * - * const { text } = await ai.generate({ - * system: 'talk like a pirate', - * prompt: [ - * { media: { url: 'http://....' } }, - * { text: 'describe this image' } - * ], - * messages: conversationHistory, - * tools: [ userInfoLookup ], - * model: googleAI.model('gemini-2.5-flash'), - * }); - * ``` - */ - generate< - O extends z.ZodTypeAny = z.ZodTypeAny, - CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, - >( - opts: - | GenerateOptions - | PromiseLike> - ): Promise>>; - - async generate< - O extends z.ZodTypeAny = z.ZodTypeAny, - CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, - >( - options: - | string - | Part[] - | GenerateOptions - | PromiseLike> - ): Promise>> { - let resolvedOptions: GenerateOptions; - if (options instanceof Promise) { - resolvedOptions = await options; - } else if (typeof options === 'string' || Array.isArray(options)) { - resolvedOptions = { - prompt: options, - }; - } else { - resolvedOptions = options as GenerateOptions; - } - return generate(this.registry, resolvedOptions); - } - - /** - * Make a streaming generate call to the default model with a simple text prompt. - * - * ```ts - * const ai = genkit({ - * plugins: [googleAI()], - * model: googleAI.model('gemini-2.5-flash'), // default model - * }) - * - * const { response, stream } = ai.generateStream('hi'); - * for await (const chunk of stream) { - * console.log(chunk.text); - * } - * console.log((await response).text); - * ``` - */ - generateStream( - strPrompt: string - ): GenerateStreamResponse>; - - /** - * Make a streaming generate call to the default model with a multipart request. 
- * - * ```ts - * const ai = genkit({ - * plugins: [googleAI()], - * model: googleAI.model('gemini-2.5-flash'), // default model - * }) - * - * const { response, stream } = ai.generateStream([ - * { media: {url: 'http://....'} }, - * { text: 'describe this image' } - * ]); - * for await (const chunk of stream) { - * console.log(chunk.text); - * } - * console.log((await response).text); - * ``` - */ - generateStream( - parts: Part[] - ): GenerateStreamResponse>; - - /** - * Streaming generate calls a generative model based on the provided prompt and configuration. If - * `messages` is provided, the generation will include a conversation history in its - * request. If `tools` are provided, the generate method will automatically resolve - * tool calls returned from the model unless `returnToolRequests` is set to `true`. - * - * See {@link GenerateOptions} for detailed information about available options. - * - * ```ts - * const ai = genkit({ - * plugins: [googleAI()], - * }) - * - * const { response, stream } = ai.generateStream({ - * system: 'talk like a pirate', - * prompt: [ - * { media: { url: 'http://....' 
} }, - * { text: 'describe this image' } - * ], - * messages: conversationHistory, - * tools: [ userInfoLookup ], - * model: googleAI.model('gemini-2.5-flash'), - * }); - * for await (const chunk of stream) { - * console.log(chunk.text); - * } - * console.log((await response).text); - * ``` - */ - generateStream< - O extends z.ZodTypeAny = z.ZodTypeAny, - CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, - >( - parts: - | GenerateOptions - | PromiseLike> - ): GenerateStreamResponse>; - - generateStream< - O extends z.ZodTypeAny = z.ZodTypeAny, - CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, - >( - options: - | string - | Part[] - | GenerateStreamOptions - | PromiseLike> - ): GenerateStreamResponse> { - if (typeof options === 'string' || Array.isArray(options)) { - options = { prompt: options }; - } - return generateStream(this.registry, options); - } - - /** - * Checks the status of of a given operation. Returns a new operation which will contain the updated status. - * - * ```ts - * let operation = await ai.generateOperation({ - * model: googleAI.model('veo-2.0-generate-001'), - * prompt: 'A banana riding a bicycle.', - * }); - * - * while (!operation.done) { - * operation = await ai.checkOperation(operation!); - * await new Promise((resolve) => setTimeout(resolve, 5000)); - * } - * ``` - * - * @param operation - * @returns - */ - checkOperation(operation: Operation): Promise> { - return checkOperation(this.registry, operation); - } - - /** - * A flow step that executes the provided function. Each run step is recorded separately in the trace. - * - * ```ts - * ai.defineFlow('hello', async() => { - * await ai.run('step1', async () => { - * // ... step 1 - * }); - * await ai.run('step2', async () => { - * // ... step 2 - * }); - * return result; - * }) - * ``` - */ - run(name: string, func: () => Promise): Promise; - - /** - * A flow step that executes the provided function. 
Each run step is recorded separately in the trace. - * - * ```ts - * ai.defineFlow('hello', async() => { - * await ai.run('step1', async () => { - * // ... step 1 - * }); - * await ai.run('step2', async () => { - * // ... step 2 - * }); - * return result; - * }) - */ - run( - name: string, - input: any, - func: (input?: any) => Promise - ): Promise; - - run( - name: string, - funcOrInput: () => Promise | any, - maybeFunc?: (input?: any) => Promise - ): Promise { - if (maybeFunc) { - return run(name, funcOrInput, maybeFunc, this.registry); - } - return run(name, funcOrInput, this.registry); - } - - /** - * Returns current action (or flow) invocation context. Can be used to access things like auth - * data set by HTTP server frameworks. If invoked outside of an action (e.g. flow or tool) will - * return `undefined`. - */ - currentContext(): ActionContext | undefined { - return getContext(); - } - /** * Configures the Genkit instance. */ diff --git a/js/genkit/src/plugin.ts b/js/genkit/src/plugin.ts index f44ad8c16c..ff794d1aab 100644 --- a/js/genkit/src/plugin.ts +++ b/js/genkit/src/plugin.ts @@ -15,7 +15,7 @@ */ import type { GenerateMiddleware } from '@genkit-ai/ai'; -import { type GenkitPluginV2 } from '@genkit-ai/ai'; +import { type BaseGenkitPluginV2, type GenkitPluginV2 } from '@genkit-ai/ai'; import { type ModelAction } from '@genkit-ai/ai/model'; import { GenkitError, @@ -34,7 +34,12 @@ export { } from '@genkit-ai/ai/model'; export { reranker } from '@genkit-ai/ai/reranker'; export { indexer, retriever } from '@genkit-ai/ai/retriever'; -export { type GenkitPluginV2, type ResolvableAction }; +export { + type BaseGenkitPluginV2, + type GenerateMiddleware, + type GenkitPluginV2, + type ResolvableAction, +}; export interface PluginProvider { name: string; diff --git a/js/plugins/anthropic/src/index.ts b/js/plugins/anthropic/src/index.ts index a48947bde5..1de13548e7 100644 --- a/js/plugins/anthropic/src/index.ts +++ b/js/plugins/anthropic/src/index.ts @@ -19,24 
+19,24 @@ import Anthropic from '@anthropic-ai/sdk'; import { genkitPluginV2, type GenkitPluginV2 } from 'genkit/plugin'; import type { Part } from 'genkit'; -import { ActionMetadata, ModelReference, z } from 'genkit'; -import { ModelAction } from 'genkit/model'; -import { ActionType } from 'genkit/registry'; +import { z, type ActionMetadata, type ModelReference } from 'genkit'; +import { type ModelAction } from 'genkit/model'; +import { type ActionType } from 'genkit/registry'; import { listActions } from './list.js'; import { - AnthropicConfigSchemaType, - ClaudeConfig, - ClaudeModelName, KNOWN_CLAUDE_MODELS, - KnownClaudeModels, claudeModel, claudeModelReference, + type AnthropicConfigSchemaType, + type ClaudeConfig, + type ClaudeModelName, + type KnownClaudeModels, } from './models.js'; import { - InternalPluginOptions, - PluginOptions, __testClient, type AnthropicDocumentOptions, + type InternalPluginOptions, + type PluginOptions, } from './types.js'; // Re-export types and utilities for consumers diff --git a/js/plugins/compat-oai/src/openai/index.ts b/js/plugins/compat-oai/src/openai/index.ts index 38d6f80ced..702d82c4ad 100644 --- a/js/plugins/compat-oai/src/openai/index.ts +++ b/js/plugins/compat-oai/src/openai/index.ts @@ -24,8 +24,8 @@ import { ModelReference, z, } from 'genkit'; -import { ResolvableAction, type GenkitPluginV2 } from 'genkit/plugin'; -import { ActionType } from 'genkit/registry'; +import { type GenkitPluginV2, type ResolvableAction } from 'genkit/plugin'; +import { type ActionType } from 'genkit/registry'; import OpenAI from 'openai'; import { defineCompatOpenAISpeechModel, diff --git a/js/plugins/google-genai/src/googleai/index.ts b/js/plugins/google-genai/src/googleai/index.ts index 27dd7dde30..2e7defe59e 100644 --- a/js/plugins/google-genai/src/googleai/index.ts +++ b/js/plugins/google-genai/src/googleai/index.ts @@ -17,9 +17,9 @@ import { ActionMetadata, EmbedderReference, ModelReference, z } from 'genkit'; import { logger } from 
'genkit/logging'; import { - GenkitPluginV2, - ResolvableAction, genkitPluginV2, + type GenkitPluginV2, + type ResolvableAction, } from 'genkit/plugin'; import { ActionType } from 'genkit/registry'; import { extractErrMsg } from '../common/utils.js'; diff --git a/js/plugins/google-genai/src/vertexai/index.ts b/js/plugins/google-genai/src/vertexai/index.ts index 488d4a94f7..98d425a94b 100644 --- a/js/plugins/google-genai/src/vertexai/index.ts +++ b/js/plugins/google-genai/src/vertexai/index.ts @@ -22,11 +22,11 @@ import { EmbedderReference, ModelReference, z } from 'genkit'; import { - GenkitPluginV2, - ResolvableAction, genkitPluginV2, + type GenkitPluginV2, + type ResolvableAction, } from 'genkit/plugin'; -import { ActionType } from 'genkit/registry'; +import { type ActionType } from 'genkit/registry'; import { listModels } from './client.js'; import * as embedder from './embedder.js'; From c327325fa21f6230a287eaf784c656e39b6f938c Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Mon, 9 Feb 2026 19:27:34 -0500 Subject: [PATCH 04/10] Apply suggestion from @gemini-code-assist[bot] Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- js/ai/src/generate/action.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/ai/src/generate/action.ts b/js/ai/src/generate/action.ts index e60f61ee05..5ff0763b07 100644 --- a/js/ai/src/generate/action.ts +++ b/js/ai/src/generate/action.ts @@ -405,7 +405,7 @@ async function generateActionTurn( const sendChunk = streamingCallback && ((chunk: GenerateResponseChunkData) => - streamingCallback && streamingCallback(makeChunk('model', chunk))); + streamingCallback(makeChunk('model', chunk))); const dispatchModel = async ( index: number, req: z.infer, From 781f5746063034d35e89b0465fbeb885cac991fc Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Mon, 9 Feb 2026 19:28:49 -0500 Subject: [PATCH 05/10] Apply suggestion from @gemini-code-assist[bot] Co-authored-by: 
gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- js/core/src/reflection.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/core/src/reflection.ts b/js/core/src/reflection.ts index 8ca1da3892..b1a6c3b4cc 100644 --- a/js/core/src/reflection.ts +++ b/js/core/src/reflection.ts @@ -171,7 +171,7 @@ export class ReflectionServer { response .status(400) .send( - `'type' ${type} is not supported. Only 'defaultModel' is supported` + `'type' ${type} is not supported. Only 'defaultModel' and 'middleware' are supported` ); return; } From 87c56c1a6787d6e60d1904d357182b4606b75867 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Mon, 9 Feb 2026 21:57:02 -0500 Subject: [PATCH 06/10] feat: pass the `ai` object to generate middleware functions. --- js/ai/src/generate/middleware.ts | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/js/ai/src/generate/middleware.ts b/js/ai/src/generate/middleware.ts index c34b34c908..ab5316a208 100644 --- a/js/ai/src/generate/middleware.ts +++ b/js/ai/src/generate/middleware.ts @@ -141,7 +141,10 @@ export function generateMiddleware< description?: string; metadata?: Record; }, - middlewareFn: (config?: z.infer) => GenerateMiddlewareDef + middlewareFn: ( + config: z.infer | undefined, + ai: GenerateAPI + ) => GenerateMiddlewareDef ): GenerateMiddleware { const def = function (config?: z.infer) { return { From 943d3ba3a6e1d4ca0158761f3b015cebc11658a0 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Tue, 10 Feb 2026 19:12:38 -0500 Subject: [PATCH 07/10] fix(ai): handle void middleware returns and missing tool messages Updates the generation loop and middleware logic to safely handle cases where tool execution yields no result. - Updates `generateActionTurn` to verify `toolMessage` existence before streaming chunks or adding to the message history, preventing runtime errors on undefined values. - Updates middleware types and `resolveToolRequest` to allow `void` return values. 
- Ensures `resolveToolRequests` returns an empty object if no response parts or transfer preamble are generated, rather than constructing a malformed tool message. --- js/ai/src/generate/action.ts | 19 ++++++++++----- js/ai/src/generate/middleware.ts | 26 +++++++++++++-------- js/ai/src/generate/resolve-tool-requests.ts | 19 ++++++++++----- 3 files changed, 42 insertions(+), 22 deletions(-) diff --git a/js/ai/src/generate/action.ts b/js/ai/src/generate/action.ts index 5ff0763b07..ae54d80cf5 100644 --- a/js/ai/src/generate/action.ts +++ b/js/ai/src/generate/action.ts @@ -495,16 +495,23 @@ async function generateActionTurn( } // if the loop will continue, stream out the tool response message... - streamingCallback?.( - makeChunk('tool', { - content: toolMessage!.content, - }) - ); + if (toolMessage) { + streamingCallback?.( + makeChunk('tool', { + content: toolMessage.content, + }) + ); + } + + const messages = [...rawRequest.messages, generatedMessage.toJSON()]; + if (toolMessage) { + messages.push(toolMessage); + } let nextRequest = { ...rawRequest, - messages: [...rawRequest.messages, generatedMessage.toJSON(), toolMessage!], + messages, }; nextRequest = applyTransferPreamble(nextRequest, transferPreamble); diff --git a/js/ai/src/generate/middleware.ts b/js/ai/src/generate/middleware.ts index ab5316a208..e9b1874330 100644 --- a/js/ai/src/generate/middleware.ts +++ b/js/ai/src/generate/middleware.ts @@ -116,16 +116,22 @@ export interface GenerateMiddlewareDef { next: ( req: ToolRequestPart, ctx: ActionRunOptions - ) => Promise<{ - response?: ToolResponsePart; - interrupt?: ToolRequestPart; - preamble?: GenerateActionOptions; - }> - ) => Promise<{ - response?: ToolResponsePart; - interrupt?: ToolRequestPart; - preamble?: GenerateActionOptions; - }>; + ) => Promise< + | { + response?: ToolResponsePart; + interrupt?: ToolRequestPart; + preamble?: GenerateActionOptions; + } + | void + > + ) => Promise< + | { + response?: ToolResponsePart; + interrupt?: ToolRequestPart; 
+ preamble?: GenerateActionOptions; + } + | void + >; /** * Tools to statically inject into the generation request whenever this middleware is active. */ diff --git a/js/ai/src/generate/resolve-tool-requests.ts b/js/ai/src/generate/resolve-tool-requests.ts index d00e83038f..cffeae064f 100644 --- a/js/ai/src/generate/resolve-tool-requests.ts +++ b/js/ai/src/generate/resolve-tool-requests.ts @@ -112,11 +112,14 @@ export async function resolveToolRequest( index: number, req: ToolRequestPart, ctx: ActionRunOptions - ): Promise<{ - response?: ToolResponsePart; - interrupt?: ToolRequestPart; - preamble?: GenerateActionOptions; - }> => { + ): Promise< + | { + response?: ToolResponsePart; + interrupt?: ToolRequestPart; + preamble?: GenerateActionOptions; + } + | void + > => { if (index === middleware.length) { return executeTool(req, ctx); } @@ -207,7 +210,7 @@ export async function resolveToolRequest( }; const initialCtx = runOptions ?? toRunOptions(part); - return dispatch(0, part, initialCtx); + return (await dispatch(0, part, initialCtx)) || {}; } /** @@ -279,6 +282,10 @@ export async function resolveToolRequests( return { revisedModelMessage }; } + if (responseParts.length === 0 && !transferPreamble) { + return {}; + } + return { toolMessage: { role: 'tool', content: responseParts }, transferPreamble, From cbb3eb69f0aa1e8a822f242b0560b80e808ccfa8 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Tue, 10 Feb 2026 19:30:46 -0500 Subject: [PATCH 08/10] refactor(ai): simplify tool middleware and prompt handling logic - Simplify `GenerateMiddlewareDef` to return `ToolResponsePart | void` directly, removing the complex wrapper object that included `interrupt` and `preamble`. - Move `isPromptAction` check in `resolveToolRequest` to execute before the middleware chain, ensuring prompt actions are handled immediately without passing through tool middleware. - Update `executeTool` implementation and return types to align with the simplified middleware signature. 
--- js/ai/src/generate/middleware.ts | 18 +- js/ai/src/generate/resolve-tool-requests.ts | 190 +++++++++----------- js/ai/tests/generate/middleware_test.ts | 63 ++++++- 3 files changed, 152 insertions(+), 119 deletions(-) diff --git a/js/ai/src/generate/middleware.ts b/js/ai/src/generate/middleware.ts index e9b1874330..1d2bcda653 100644 --- a/js/ai/src/generate/middleware.ts +++ b/js/ai/src/generate/middleware.ts @@ -116,22 +116,8 @@ export interface GenerateMiddlewareDef { next: ( req: ToolRequestPart, ctx: ActionRunOptions - ) => Promise< - | { - response?: ToolResponsePart; - interrupt?: ToolRequestPart; - preamble?: GenerateActionOptions; - } - | void - > - ) => Promise< - | { - response?: ToolResponsePart; - interrupt?: ToolRequestPart; - preamble?: GenerateActionOptions; - } - | void - >; + ) => Promise + ) => Promise; /** * Tools to statically inject into the generation request whenever this middleware is active. */ diff --git a/js/ai/src/generate/resolve-tool-requests.ts b/js/ai/src/generate/resolve-tool-requests.ts index cffeae064f..e41152fef2 100644 --- a/js/ai/src/generate/resolve-tool-requests.ts +++ b/js/ai/src/generate/resolve-tool-requests.ts @@ -108,18 +108,29 @@ export async function resolveToolRequest( }); } + // if it's a prompt action, go ahead and render the preamble + if (isPromptAction(tool)) { + const metadata = tool.__action.metadata as Record; + const preamble = { + ...(await tool(part.toolRequest.input)), + model: metadata.prompt?.model, + }; + const response = { + toolResponse: { + name: part.toolRequest.name, + ref: part.toolRequest.ref, + output: `transferred to ${part.toolRequest.name}`, + }, + }; + + return { preamble, response }; + } + const dispatch = async ( index: number, req: ToolRequestPart, ctx: ActionRunOptions - ): Promise< - | { - response?: ToolResponsePart; - interrupt?: ToolRequestPart; - preamble?: GenerateActionOptions; - } - | void - > => { + ): Promise => { if (index === middleware.length) { return executeTool(req, 
ctx); } @@ -136,81 +147,35 @@ export async function resolveToolRequest( const executeTool = async ( req: ToolRequestPart, ctx: ActionRunOptions - ): Promise<{ - response?: ToolResponsePart; - interrupt?: ToolRequestPart; - preamble?: GenerateActionOptions; - }> => { - // if it's a prompt action, go ahead and render the preamble - if (isPromptAction(tool)) { - const metadata = tool.__action.metadata as Record; - const preamble = { - ...(await tool(req.toolRequest.input)), - model: metadata.prompt?.model, - }; - const response = { + ): Promise => { + // execute the tool and catch interrupts + const output = await tool(req.toolRequest.input, ctx as ToolRunOptions); + if (tool.__action.actionType === 'tool.v2') { + const multipartResponse = output as z.infer< + typeof MultipartToolResponseSchema + >; + return stripUndefinedProps({ + toolResponse: { + name: req.toolRequest.name, + ref: req.toolRequest.ref, + output: multipartResponse.output, + content: multipartResponse.content, + } as ToolResponse, + }); + } else { + return stripUndefinedProps({ toolResponse: { name: req.toolRequest.name, ref: req.toolRequest.ref, - output: `transferred to ${req.toolRequest.name}`, + output, }, - }; - - return { preamble, response }; - } - - // otherwise, execute the tool and catch interrupts - try { - const output = await tool(req.toolRequest.input, ctx as ToolRunOptions); - if (tool.__action.actionType === 'tool.v2') { - const multipartResponse = output as z.infer< - typeof MultipartToolResponseSchema - >; - const response = stripUndefinedProps({ - toolResponse: { - name: req.toolRequest.name, - ref: req.toolRequest.ref, - output: multipartResponse.output, - content: multipartResponse.content, - } as ToolResponse, - }); - - return { response }; - } else { - const response = stripUndefinedProps({ - toolResponse: { - name: req.toolRequest.name, - ref: req.toolRequest.ref, - output, - }, - }); - - return { response }; - } - } catch (e) { - if ( - e instanceof ToolInterruptError || - // 
There's an inexplicable case when the above type check fails, only in tests. - (e as Error).name === 'ToolInterruptError' - ) { - const ie = e as ToolInterruptError; - logger.debug( - `tool '${toolMap[req.toolRequest?.name].__action.name}' triggered an interrupt${ie.metadata ? `: ${JSON.stringify(ie.metadata)}` : ''}` - ); - const interrupt = { - toolRequest: req.toolRequest, - metadata: { ...req.metadata, interrupt: ie.metadata || true }, - }; - - return { interrupt }; - } - - throw e; + }); } }; const initialCtx = runOptions ?? toRunOptions(part); - return (await dispatch(0, part, initialCtx)) || {}; + const dispatchResult = await dispatch(0, part, initialCtx); + return dispatchResult ? { response: dispatchResult } : {}; } /** @@ -243,37 +208,58 @@ export async function resolveToolRequests( revisedModelMessage.content.map(async (part, i) => { if (!part.toolRequest) return; // skip non-tool-request parts - const { preamble, response, interrupt } = await resolveToolRequest( - rawRequest, - part as ToolRequestPart, - toolMap, - middleware - ); + try { + const { preamble, response, interrupt } = await resolveToolRequest( + rawRequest, + part as ToolRequestPart, + toolMap, + middleware + ); - if (preamble) { - if (transferPreamble) { - throw new GenkitError({ - status: 'INVALID_ARGUMENT', - message: `Model attempted to transfer to multiple prompt tools.`, - }); - } + if (preamble) { + if (transferPreamble) { + throw new GenkitError({ + status: 'INVALID_ARGUMENT', + message: `Model attempted to transfer to multiple prompt tools.`, + }); + } - transferPreamble = preamble; - } + transferPreamble = preamble; + } - // this happens for preamble or normal tools - if (response) { - responseParts.push(response!); - revisedModelMessage.content.splice( - i, - 1, - toPendingOutput(part, response) - ); - } + // this happens for preamble or normal tools + if (response) { + responseParts.push(response!); + revisedModelMessage.content.splice( + i, + 1, + toPendingOutput(part, 
response) + ); + } - if (interrupt) { - revisedModelMessage.content.splice(i, 1, interrupt); - hasInterrupts = true; + if (interrupt) { + revisedModelMessage.content.splice(i, 1, interrupt); + hasInterrupts = true; + } + } catch (e) { + if ( + e instanceof ToolInterruptError || + // There's an inexplicable case when the above type check fails, only in tests. + (e as Error).name === 'ToolInterruptError' + ) { + const ie = e as ToolInterruptError; + logger.debug( + `tool '${toolMap[part.toolRequest?.name].__action.name}' triggered an interrupt${ie.metadata ? `: ${JSON.stringify(ie.metadata)}` : ''}` + ); + const interrupt = { + toolRequest: part.toolRequest, + metadata: { ...part.metadata, interrupt: ie.metadata || true }, + }; + revisedModelMessage.content.splice(i, 1, interrupt); + hasInterrupts = true; + } else { + throw e; + } } }) ); diff --git a/js/ai/tests/generate/middleware_test.ts b/js/ai/tests/generate/middleware_test.ts index b301ed6534..750a2f49b1 100644 --- a/js/ai/tests/generate/middleware_test.ts +++ b/js/ai/tests/generate/middleware_test.ts @@ -22,7 +22,7 @@ import { beforeEach, describe, it } from 'node:test'; import { generate, generateStream } from '../../src/generate.js'; import { generateMiddleware } from '../../src/generate/middleware.js'; import { defineModel } from '../../src/model.js'; -import { defineTool, tool } from '../../src/tool.js'; +import { ToolInterruptError, defineTool, tool } from '../../src/tool.js'; initNodeFeatures(); @@ -581,4 +581,65 @@ describe('generateMiddleware', () => { assert.strictEqual(result.text, 'final response'); assert.strictEqual(toolExecutionCount, 1); }); + + it('handles ToolInterruptError from middleware', async () => { + const mockTool = defineTool( + registry, + { + name: 'interruptTool', + description: 'interrupts', + inputSchema: z.object({}), + outputSchema: z.string(), + }, + async () => { + return 'foo'; + } + ); + + const interruptMiddleware = generateMiddleware( + { name: 'interruptMw' }, + () => ({ 
+ tool: async (req, ctx, next) => { + throw new ToolInterruptError({ some: 'metadata' }); + }, + }) + ); + + const mockModel = defineModel( + registry, + { name: 'mockModelWithTool' }, + async (req) => { + return { + message: { + role: 'model', + content: [ + { + toolRequest: { + name: mockTool.__action.name, + ref: '123', + input: {}, + }, + }, + ], + }, + }; + } + ); + + const result = await generate(registry, { + model: mockModel, + prompt: 'hi', + tools: ['interruptTool'], + use: [interruptMiddleware()], + }); + + assert.strictEqual(result.finishReason, 'interrupted'); + const interruptPart = result.message?.content.find( + (p) => p.metadata?.interrupt + ); + assert.ok(interruptPart); + assert.deepStrictEqual(interruptPart.metadata?.interrupt, { + some: 'metadata', + }); + }); }); From 577916fe3c818bb251eab3521fa5ed21fab16a99 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Wed, 11 Feb 2026 15:42:39 -0500 Subject: [PATCH 09/10] fixed resume middleware --- js/ai/src/generate/action.ts | 48 +++-- js/ai/src/generate/resolve-tool-requests.ts | 3 +- js/ai/tests/generate/middleware_test.ts | 197 ++++++++++++++++++++ 3 files changed, 236 insertions(+), 12 deletions(-) diff --git a/js/ai/src/generate/action.ts b/js/ai/src/generate/action.ts index ae54d80cf5..78a6fe728f 100644 --- a/js/ai/src/generate/action.ts +++ b/js/ai/src/generate/action.ts @@ -352,16 +352,45 @@ async function generateActionTurn( revisedRequest, interruptedResponse, toolMessage: resumedToolMessage, - } = await resolveResumeOption(registry, rawRequest, middleware || []); + } = await resolveResumeOption(registry, rawRequest, tools, middleware || []); // NOTE: in the future we should make it possible to interrupt a restart, but // at the moment it's too complicated because it's not clear how to return a // response that amends history but doesn't generate a new message, so we throw - if (interruptedResponse) { - throw new GenkitError({ - status: 'FAILED_PRECONDITION', - message: - 'One or more tools 
triggered an interrupt during a restarted execution.', - detail: { message: interruptedResponse.message }, + if (revisedRequest && revisedRequest !== rawRequest) { + if (interruptedResponse) { + throw new GenkitError({ + status: 'FAILED_PRECONDITION', + message: + 'One or more tools triggered an interrupt during a restarted execution.', + detail: { message: interruptedResponse.message }, + }); + } + + if (resumedToolMessage && streamingCallback) { + streamingCallback( + new GenerateResponseChunk( + { + role: 'tool', + content: resumedToolMessage.content, + }, + { + index: messageIndex, + role: 'tool', + previousChunks: [], + parser: format?.handler(rawRequest.output?.jsonSchema).parseChunk, + } + ) + ); + } + + return await generateHelper(registry, { + rawRequest: revisedRequest, + middleware, + currentTurn, + messageIndex: messageIndex + (resumedToolMessage ? 1 : 0), + abortSignal, + streamingCallback, + context, }); } rawRequest = revisedRequest!; @@ -396,10 +425,7 @@ async function generateActionTurn( }); }; - // if resolving the 'resume' option above generated a tool message, stream it. 
- if (resumedToolMessage && streamingCallback) { - streamingCallback(makeChunk('tool', resumedToolMessage)); - } + var response: GenerateResponse; const sendChunk = diff --git a/js/ai/src/generate/resolve-tool-requests.ts b/js/ai/src/generate/resolve-tool-requests.ts index e41152fef2..433bbed16e 100644 --- a/js/ai/src/generate/resolve-tool-requests.ts +++ b/js/ai/src/generate/resolve-tool-requests.ts @@ -397,6 +397,7 @@ async function resolveResumedToolRequest( export async function resolveResumeOption( registry: Registry, rawRequest: GenerateActionOptions, + tools: ToolAction[], middleware: GenerateMiddlewareDef[] = [] ): Promise<{ revisedRequest?: GenerateActionOptions; @@ -404,7 +405,7 @@ export async function resolveResumeOption( toolMessage?: MessageData; }> { if (!rawRequest.resume) return { revisedRequest: rawRequest }; // no-op if no resume option - const toolMap = toToolMap(await resolveTools(registry, rawRequest.tools)); + const toolMap = toToolMap(tools); const messages = rawRequest.messages; const lastMessage = messages.at(-1); diff --git a/js/ai/tests/generate/middleware_test.ts b/js/ai/tests/generate/middleware_test.ts index 750a2f49b1..2e44f34f78 100644 --- a/js/ai/tests/generate/middleware_test.ts +++ b/js/ai/tests/generate/middleware_test.ts @@ -642,4 +642,201 @@ describe('generateMiddleware', () => { some: 'metadata', }); }); + + it('resumes tool execution with modified metadata after interrupt', async () => { + const mockTool = defineTool( + registry, + { + name: 'interruptTool', + description: 'interrupts', + inputSchema: z.object({}), + outputSchema: z.string(), + }, + async () => { + return 'tool output'; + } + ); + + let middlewareRunCount = 0; + const interruptMiddleware = generateMiddleware( + { name: 'interruptMw' }, + () => ({ + tool: async (req, ctx, next) => { + middlewareRunCount++; + if (req.metadata?.['approved'] === true) { + return next(req, ctx); + } + throw new ToolInterruptError({ some: 'metadata' }); + }, + }) + ); + + let 
callCount = 0; + const mockModel = defineModel( + registry, + { name: 'mockModelWithTool' }, + async (req) => { + callCount++; + if (callCount === 1) { + return { + message: { + role: 'model', + content: [ + { + toolRequest: { + name: mockTool.__action.name, + ref: '123', + input: {}, + }, + }, + ], + }, + }; + } else { + return { + message: { + role: 'model', + content: [{ text: 'final response' }], + }, + }; + } + } + ); + + const result = await generate(registry, { + model: mockModel, + prompt: 'hi', + tools: ['interruptTool'], + use: [interruptMiddleware()], + }); + + assert.strictEqual(result.finishReason, 'interrupted'); + const interruptPart = result.interrupts[0]; + assert.ok(interruptPart); + assert.strictEqual(middlewareRunCount, 1); + + // Modify metadata + if (interruptPart.metadata) { + interruptPart.metadata = { ...interruptPart.metadata, approved: true }; + } + + const result2 = await generate(registry, { + model: mockModel, + messages: result.messages, + tools: ['interruptTool'], + use: [interruptMiddleware()], + resume: { + restart: [interruptPart], + }, + }); + + assert.strictEqual(result2.text, 'final response'); + // Middleware should have run again + assert.strictEqual(middlewareRunCount, 2); + }); + + it('re-runs generate middleware after resuming tool execution', async () => { + const mockTool = defineTool( + registry, + { + name: 'interruptTool', + description: 'interrupts', + inputSchema: z.object({}), + outputSchema: z.string(), + }, + async () => { + return 'tool output'; + } + ); + + let generateMiddlewareCallCount = 0; + let seenToolResponseInGenerate = false; + + const testMiddleware = generateMiddleware( + { name: 'testMw' }, + () => ({ + generate: async (req, ctx, next) => { + generateMiddlewareCallCount++; + const lastMsg = req.messages[req.messages.length - 1]; + if (lastMsg?.role === 'tool') { + seenToolResponseInGenerate = true; + } + return next(req, ctx); + }, + tool: async (req, ctx, next) => { + if (req.metadata?.['approved'] 
=== true) { + return next(req, ctx); + } + throw new ToolInterruptError({ some: 'metadata' }); + }, + }) + ); + + let callCount = 0; + const mockModel = defineModel( + registry, + { name: 'mockModelWithTool2' }, + async (req) => { + callCount++; + if (callCount === 1) { + return { + message: { + role: 'model', + content: [ + { + toolRequest: { + name: mockTool.__action.name, + ref: '123', + input: {}, + }, + }, + ], + }, + }; + } else { + return { + message: { + role: 'model', + content: [{ text: 'final response' }], + }, + }; + } + } + ); + + const result = await generate(registry, { + model: mockModel, + prompt: 'hi', + tools: ['interruptTool'], + use: [testMiddleware()], + }); + + assert.strictEqual(result.finishReason, 'interrupted'); + const interruptPart = result.interrupts[0]; + assert.ok(interruptPart); + + // Modify metadata + if (interruptPart.metadata) { + interruptPart.metadata = { ...interruptPart.metadata, approved: true }; + } + + generateMiddlewareCallCount = 0; // Reset + seenToolResponseInGenerate = false; + + await generate(registry, { + model: mockModel, + messages: result.messages, + tools: ['interruptTool'], + use: [testMiddleware()], + resume: { + restart: [interruptPart], + }, + }); + + assert.ok( + seenToolResponseInGenerate, + 'Generate middleware should see the tool response' + ); + assert.strictEqual(generateMiddlewareCallCount, 2); + }); }); From 2f57f05fcfb262748eb4cf140de162ef1214c8fb Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Wed, 11 Feb 2026 21:25:31 -0500 Subject: [PATCH 10/10] allow dotprompt to use middleware --- js/ai/src/prompt.ts | 4 +++- js/genkit/src/genkit.ts | 15 +++++++-------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/js/ai/src/prompt.ts b/js/ai/src/prompt.ts index 1c489ad189..e497530b02 100644 --- a/js/ai/src/prompt.ts +++ b/js/ai/src/prompt.ts @@ -48,6 +48,7 @@ import { import { Message } from './message.js'; import { GenerateActionOptionsSchema, + MiddlewareRef, type GenerateActionOptions, 
type GenerateRequest, type GenerateRequestSchema, @@ -127,7 +128,7 @@ export interface PromptConfig< metadata?: Record; tools?: ToolArgument[]; toolChoice?: ToolChoice; - use?: ModelMiddleware[]; + use?: (ModelMiddleware | MiddlewareRef)[]; context?: ActionContext; } @@ -861,6 +862,7 @@ function loadPrompt( maxTurns: promptMetadata.raw?.['maxTurns'], toolChoice: promptMetadata.raw?.['toolChoice'], returnToolRequests: promptMetadata.raw?.['returnToolRequests'], + use: promptMetadata.raw?.['use'], messages: parsedPrompt.template, }; }) diff --git a/js/genkit/src/genkit.ts b/js/genkit/src/genkit.ts index 55d5673635..6f4b2998f5 100644 --- a/js/genkit/src/genkit.ts +++ b/js/genkit/src/genkit.ts @@ -668,6 +668,13 @@ export class Genkit extends GenerateAPI implements HasRegistry { plugins.forEach((plugin) => { if (isPluginV2(plugin)) { logger.debug(`Registering v2 plugin ${plugin.name}...`); + plugin.generateMiddleware?.()?.forEach((middleware) => { + activeRegistry.registerValue( + 'middleware', + middleware.name, + middleware + ); + }); activeRegistry.registerPluginProvider(plugin.name, { name: plugin.name, async initializer() { @@ -677,14 +684,6 @@ export class Genkit extends GenerateAPI implements HasRegistry { resolvedActions?.forEach((resolvedAction) => { registerActionV2(activeRegistry, resolvedAction, plugin); }); - const definedMiddleware = plugin.generateMiddleware?.(); - definedMiddleware?.forEach((middleware) => { - activeRegistry.registerValue( - 'middleware', - middleware.name, - middleware - ); - }); }, async resolver(action: ActionType, target: string) { if (!plugin.resolve) return;