Skip to content
204 changes: 198 additions & 6 deletions py/PARITY_AUDIT.md

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions py/packages/genkit/src/genkit/ai/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,7 @@ def define_model(
metadata: dict[str, object] | None = None,
info: ModelInfo | None = None,
description: str | None = None,
use: list[ModelMiddleware] | None = None,
) -> Action:
"""Define a custom model action.

Expand All @@ -1018,6 +1019,10 @@ def define_model(
metadata: Optional metadata for the model.
info: Optional ModelInfo for the model.
description: Optional description for the model.
use: Optional list of model-level middleware to apply when
this model is invoked. These run after call-time middleware
(passed via ``generate(use=[...])``) in the dispatch chain,
matching JS SDK's ``defineModel({use: [mw]})``.
"""
# Build model options dict
model_options: dict[str, object] = {}
Expand Down Expand Up @@ -1058,6 +1063,7 @@ def define_model(
fn=fn,
metadata=model_meta,
description=model_description,
middleware=use,
)

def define_background_model(
Expand Down
11 changes: 9 additions & 2 deletions py/packages/genkit/src/genkit/blocks/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,13 @@ def wrapper(chunk: GenerateResponseChunk) -> None:
if raw_request.docs and not supports_context:
middleware.append(augment_with_context())

# Build the combined middleware chain: call-time middleware first,
# then model-level middleware (from define_model(use=[...])).
# This matches JS SDK execution order:
# call-time[0..N] → model-level[0..M] → runner
model_middleware: list[ModelMiddleware] = cast(list[ModelMiddleware], model.middleware)
combined_middleware: list[ModelMiddleware] = list(middleware) + model_middleware

async def dispatch(index: int, req: GenerateRequest, ctx: ActionRunContext) -> GenerateResponse:
"""Dispatches model request, passing it through middleware if present.

Expand All @@ -223,7 +230,7 @@ async def dispatch(index: int, req: GenerateRequest, ctx: ActionRunContext) -> G
Returns:
The generated response.
"""
if not middleware or index == len(middleware):
if index == len(combined_middleware):
# end of the chain, call the original model action
return (
await model.arun(
Expand All @@ -233,7 +240,7 @@ async def dispatch(index: int, req: GenerateRequest, ctx: ActionRunContext) -> G
)
).response

current_middleware = middleware[index]
current_middleware = combined_middleware[index]

async def next_fn(
modified_req: GenerateRequest | None = None,
Expand Down
39 changes: 38 additions & 1 deletion py/packages/genkit/src/genkit/blocks/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,44 @@
#
# SPDX-License-Identifier: Apache-2.0

"""Middleware for the Genkit framework."""
"""Model middleware for the Genkit framework.

This module contains **Layer 4 (Model Middleware)** — the primary user-facing
middleware in Genkit. Model middleware intercepts and transforms
``GenerateRequest`` and ``GenerateResponse`` objects in a chain around the
model runner.

Genkit has four distinct middleware layers, each at a different level:

- **Layer 1 — ASGI / HTTP Middleware** (Starlette, FastAPI, etc.):
Runs on every HTTP request. Use for CORS, rate limiting, security headers.

- **Layer 2 — Context Providers** (``Genkit(context=...)``):
Extracts auth/context from the HTTP request before every action.
Use for API key validation, JWT parsing, user session extraction.

- **Layer 3 — Action Middleware** (``Action.use`` in JS core):
Wraps any action type. Primarily internal framework wiring in Python.

- **Layer 4 — Model Middleware** (this module):
Runs only on model calls via ``generate()`` or ``prompt.generate()``.
Use for safety guardrails, retry, fallback, constrained generation
simulation, system prompt simulation, and media downloading.

Model middleware can be applied at three levels:

- **Call-time**: ``generate(use=[mw])`` — per-request.
- **Model-level**: ``define_model(use=[mw])`` — baked into the model.
- **Auto-wired**: Injected at model definition time based on model
capabilities (e.g., ``augment_with_context`` for models without
native context support).

Execution order: call-time → model-level → auto-wired → model runner.

See Also:
``genkit.blocks.model``: ``ModelMiddleware`` type definition.
``genkit.blocks.generate``: Middleware dispatch in ``generate_action()``.
"""

from genkit.blocks.model import (
ModelMiddleware,
Expand Down
13 changes: 13 additions & 0 deletions py/packages/genkit/src/genkit/core/action/_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def __init__(
description: str | None = None,
metadata: dict[str, object] | None = None,
span_metadata: dict[str, SpanAttributeValue] | None = None,
middleware: list[Any] | None = None,
) -> None:
"""Initialize an Action.

Expand All @@ -237,12 +238,15 @@ def __init__(
description: Optional human-readable description of the action.
metadata: Optional dictionary of metadata about the action.
span_metadata: Optional dictionary of tracing span metadata.
middleware: Optional list of middleware functions to be applied
when executing this action.
"""
self._kind: ActionKind = kind
self._name: str = name
self._metadata: dict[str, object] = metadata if metadata else {}
self._description: str | None = description
self._is_async: bool = inspect.iscoroutinefunction(fn)
self._middleware: list[Any] = middleware if middleware else []
# Optional matcher function for resource actions
self.matches: Callable[[object], bool] | None = None

Expand Down Expand Up @@ -302,6 +306,15 @@ def output_schema(self, value: dict[str, object]) -> None:
def is_async(self) -> bool:
return self._is_async

@property
def middleware(self) -> list[Any]: # noqa: ANN401
"""Middleware functions applied at action execution time.

These are model-level middleware set via ``define_model(use=[...])``.
They run after call-time middleware in the dispatch chain.
"""
return self._middleware

def run(
self,
input: InputT | None = None,
Expand Down
6 changes: 5 additions & 1 deletion py/packages/genkit/src/genkit/core/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import asyncio
import threading
from collections.abc import Awaitable, Callable
from typing import cast
from typing import Any, cast

from dotpromptz.dotprompt import Dotprompt
from pydantic import BaseModel
Expand Down Expand Up @@ -140,6 +140,7 @@ def register_action(
description: str | None = None,
metadata: dict[str, object] | None = None,
span_metadata: dict[str, SpanAttributeValue] | None = None,
middleware: list[Any] | None = None,
) -> Action[InputT, OutputT, ChunkT]:
"""Register a new action with the registry.

Expand All @@ -155,6 +156,8 @@ def register_action(
description: Optional human-readable description of the action.
metadata: Optional dictionary of metadata about the action.
span_metadata: Optional dictionary of tracing span metadata.
middleware: Optional list of middleware functions to apply when
executing this action (e.g., model-level middleware).

Returns:
The newly created and registered Action instance.
Expand All @@ -167,6 +170,7 @@ def register_action(
description=description,
metadata=metadata,
span_metadata=span_metadata,
middleware=middleware,
)
action_typed = cast(Action[InputT, OutputT, ChunkT], action)
with self._lock:
Expand Down
152 changes: 152 additions & 0 deletions py/packages/genkit/tests/genkit/blocks/generate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,158 @@ def collect_chunks(c: GenerateResponseChunk) -> None:
]


@pytest.mark.asyncio
async def test_model_level_middleware_applied() -> None:
"""Model-level middleware set via define_model(use=[...]) is applied."""
ai = Genkit()

call_log: list[str] = []

async def model_mw(
req: GenerateRequest,
ctx: ActionRunContext,
next: ModelMiddlewareNext,
) -> GenerateResponse:
call_log.append('model_mw')
txt = ''.join(text_from_message(m) for m in req.messages)
return await next(
GenerateRequest(
messages=[
Message(role=Role.USER, content=[Part(root=TextPart(text=f'MW({txt}))'))]),
],
),
ctx,
)

def echo_fn(request: GenerateRequest, ctx: ActionRunContext) -> GenerateResponse:
call_log.append('runner')
merged = ''.join(text_from_message(m) for m in request.messages)
return GenerateResponse(
message=Message(
role=Role.MODEL,
content=[Part(root=TextPart(text=f'[ECHO]{merged}'))],
)
)

ai.define_model(name='mwEchoModel', fn=echo_fn, use=[model_mw])

response = await generate_action(
ai.registry,
GenerateActionOptions(
model='mwEchoModel',
messages=[
Message(role=Role.USER, content=[Part(root=TextPart(text='hello'))]),
],
),
)

assert call_log == ['model_mw', 'runner']
assert 'MW(' in response.text
assert 'hello' in response.text


@pytest.mark.asyncio
async def test_call_time_middleware_runs_before_model_level() -> None:
"""Call-time middleware runs before model-level middleware in dispatch chain.

This matches the JS SDK execution order:
call-time[0..N] → model-level[0..M] → runner
"""
ai = Genkit()

call_order: list[str] = []

async def call_time_mw(
req: GenerateRequest,
ctx: ActionRunContext,
next: ModelMiddlewareNext,
) -> GenerateResponse:
call_order.append('call_time')
return await next(req, ctx)

async def model_mw(
req: GenerateRequest,
ctx: ActionRunContext,
next: ModelMiddlewareNext,
) -> GenerateResponse:
call_order.append('model_level')
return await next(req, ctx)

def echo_fn(request: GenerateRequest, ctx: ActionRunContext) -> GenerateResponse:
call_order.append('runner')
merged = ''.join(text_from_message(m) for m in request.messages)
return GenerateResponse(
message=Message(
role=Role.MODEL,
content=[Part(root=TextPart(text=f'[ECHO]{merged}'))],
)
)

ai.define_model(name='mwOrderModel', fn=echo_fn, use=[model_mw])

response = await generate_action(
ai.registry,
GenerateActionOptions(
model='mwOrderModel',
messages=[
Message(role=Role.USER, content=[Part(root=TextPart(text='test'))]),
],
),
middleware=[call_time_mw],
)

# Verify execution order: call-time → model-level → runner
assert call_order == ['call_time', 'model_level', 'runner']
assert response.text is not None


@pytest.mark.asyncio
async def test_multiple_model_level_middleware_chain() -> None:
"""Multiple model-level middleware chain in order."""
ai = Genkit()

call_order: list[str] = []

async def model_mw_a(
req: GenerateRequest,
ctx: ActionRunContext,
next: ModelMiddlewareNext,
) -> GenerateResponse:
call_order.append('model_a')
return await next(req, ctx)

async def model_mw_b(
req: GenerateRequest,
ctx: ActionRunContext,
next: ModelMiddlewareNext,
) -> GenerateResponse:
call_order.append('model_b')
return await next(req, ctx)

def echo_fn(request: GenerateRequest, ctx: ActionRunContext) -> GenerateResponse:
call_order.append('runner')
return GenerateResponse(
message=Message(
role=Role.MODEL,
content=[Part(root=TextPart(text='done'))],
)
)

ai.define_model(name='multiMwModel', fn=echo_fn, use=[model_mw_a, model_mw_b])

await generate_action(
ai.registry,
GenerateActionOptions(
model='multiMwModel',
messages=[
Message(role=Role.USER, content=[Part(root=TextPart(text='hi'))]),
],
),
)

assert call_order == ['model_a', 'model_b', 'runner']


##########################################################################
# run tests from /tests/specs/generate.yaml
##########################################################################
Expand Down
Loading
Loading