diff --git a/py/PARITY_AUDIT.md b/py/PARITY_AUDIT.md index 1a2b5ca351..6cbfadac83 100644 --- a/py/PARITY_AUDIT.md +++ b/py/PARITY_AUDIT.md @@ -481,7 +481,7 @@ Full plugin list from the repository README (10 plugins, 33 contributors, 54 rel | Gap ID | SDK | Work Item | Reference | Status | |--------|-----|-----------|-----------|:------:| -| G2 → G1 | Python | Add `middleware` storage to `Action`, then add `use=` to `define_model` | §8b.1 | ⬜ | +| G2 → G1 | Python | Add `middleware` storage to `Action`, then add `use=` to `define_model` | §8b.1 | ✅ Done | | G7 | Python | Wire DAP action discovery into `GET /api/actions` | §8a, §8c.5 | ⏳ Deferred | | G6 → G5 | Python | Pass `span_id` in `on_trace_start`, send `X-Genkit-Span-Id` | §8c.3, §8c.4 | ⬜ | | G3 | Python | Implement `simulate_constrained_generation` middleware | §8b.3, §8f | ⬜ | @@ -496,6 +496,7 @@ Full plugin list from the repository README (10 plugins, 33 contributors, 54 rel | G21 | Python | Add `clientHeader` parameter to `Genkit()` constructor | §8j | ⬜ | | G22 | Python | Add `name` parameter to `Genkit()` constructor | §8j | ⬜ | | G4 | Python | Move `augment_with_context` to define-model time | §8b.2 | ⬜ | +| G38 | Python | Implement `get_model_middleware()` auto-wiring (like JS `getModelMiddleware()`) | §8f.1 | ⬜ | | G9 | Python | Add Pinecone vector store plugin | §5g | ⬜ | | G10 | Python | Add ChromaDB vector store plugin | §5g | ⬜ | | G30 | Python | Add Cloud SQL PG vector store parity | §5g | ⬜ | @@ -854,6 +855,93 @@ export function simulateSystemPrompt(options?: { **Impact**: Models without native system prompt support (e.g., some older or fine-tuned models) get automatic simulation in JS but not in Python. +### 8f.1. 
Model Middleware vs Plugin Inline Implementation — Impact Analysis + +**Status**: 🔍 Analysis Complete — Action Items Identified + +The JS SDK's `getModelMiddleware()` function (in `model.ts:337-358`) **automatically** +wires built-in middleware into every model at `defineModel()` time based on the model's +`supports` metadata. This means JS plugins do **not** need to implement these concerns +themselves — the framework handles them. + +In contrast, Python plugins currently implement some of these concerns **inline** in their +model runner functions. Now that Python has `define_model(use=[...])`, we can optionally +migrate these plugins to use middleware instead, but more importantly we need to implement +the **auto-wiring** pattern from JS. + +#### JS `getModelMiddleware()` auto-wiring + +```typescript +// js/ai/src/model.ts:337-358 +function getModelMiddleware(options) { + const middleware = options.use || []; // user-supplied + if (!options?.supports?.context) + middleware.push(augmentWithContext()); // auto-add context MW + const constrainedSim = simulateConstrainedGeneration(); + middleware.push((req, next) => { // auto-add constrained MW + if (!options?.supports?.constrained || ...) + return constrainedSim(req, next); + return next(req); + }); + return middleware; +} +``` + +This means in JS, **every** model automatically gets: +1. `augmentWithContext()` — unless `supports.context` is true +2. `simulateConstrainedGeneration()` — unless `supports.constrained` is truthy and compatible + +#### Overlap Analysis: What Python Plugins Do Inline Today + +| Concern | JS (Framework MW) | Python Plugin Inline? 
| Plugins Affected | Migration Path | +|---------|--------------------|-----------------------|------------------|----------------| +| **System prompt simulation** | `simulateSystemPrompt()` | ❌ Not inline — plugins either support `system_role: true` natively or simply don't handle it | All plugins declare `system_role` in `ModelInfo.supports` | Implement `simulateSystemPrompt()` middleware (G16); auto-wire for models with `system_role=False` | +| **Constrained generation** | `simulateConstrainedGeneration()` | ❌ Not inline — plugins rely on `output.schema` passthrough | google-genai, compat-oai, anthropic, ollama | Implement `simulateConstrainedGeneration()` middleware (G3); auto-wire based on `supports.constrained` | +| **Media download (URL→base64)** | `downloadRequestMedia()` | ⚠️ Some plugins handle this inline | google-genai (partially), compat-oai | Implement `downloadRequestMedia()` middleware (G15); plugins can remove inline handling | +| **Context augmentation** | `augmentWithContext()` | ✅ Already middleware in Python | — | ✅ Already done, but wired in `generate_action()` instead of at define-model time (G4) | +| **Support validation** | `validateSupport()` | ❌ Not implemented | All plugins | Implement `validateSupport()` middleware (G14); auto-wire at define-model time | +| **Retry** | `retry()` | ❌ Not implemented — users must wrap generate() | — | Implement as standalone middleware (G12), not auto-wired | +| **Fallback** | `fallback()` | ❌ Not implemented — users must manually try/catch | — | Implement as standalone middleware (G13), not auto-wired | + +#### Key Architectural Gap: Auto-Wiring (New Gap G38) + +The most important missing piece is **auto-wiring**: Python's `define_model()` should +automatically inject middleware based on the model's `ModelInfo.supports` metadata, just +like JS's `getModelMiddleware()`. This is separate from (and more important than) the +individual middleware implementations. + +**What needs to happen:** + +1. 
**G38** (New): Implement `get_model_middleware()` helper in Python that mirrors JS + - Read `supports.context` → auto-add `augment_with_context()` (currently done in `generate_action()`, should move to define-model time) + - Read `supports.constrained` → auto-add `simulate_constrained_generation()` when implemented (G3) + - Read `supports.system_role` → auto-add `simulate_system_prompt()` when implemented (G16) + - Append auto-wired middleware **after** user-supplied `use=[...]` middleware + +2. **Plugin impact: None for now** + - Since Python plugins don't currently implement these concerns inline, adding the + auto-wiring middleware won't conflict with existing plugin logic + - Plugins only need to ensure their `ModelInfo.supports` metadata is accurate + - No breaking changes required + +3. **Plugin impact: Future** + - Once `downloadRequestMedia()` (G15) is implemented as middleware, plugins that + currently handle media download inline can remove that code + - This is a simplification, not a breaking change + +#### Updated Gap Priority + +| Gap | Depends On | Impact | +|-----|-----------|--------| +| **G38** (New) | G2 ✅, G1 ✅ | Auto-wire middleware at define-model time — critical for parity | +| G3 | G38 | `simulateConstrainedGeneration()` — auto-wired via G38 | +| G16 | G38 | `simulateSystemPrompt()` — auto-wired via G38 | +| G4 | G38 | Move `augmentWithContext()` from generate-time to define-model time | +| G14 | G38 | `validateSupport()` — auto-wired via G38 | +| G15 | G38 | `downloadRequestMedia()` — auto-wired via G38; plugins can remove inline code | +| G12 | G2 ✅ | `retry()` — standalone, not auto-wired | +| G13 | G2 ✅ | `fallback()` — standalone, not auto-wired | + ### 8g. 
Context Providers — Built-in Helpers **Status**: ⚠️ Minor Gap @@ -1023,8 +1111,8 @@ export interface GenkitOptions { | Gap ID | SDK | Gap | Priority | Primary Files to Touch | Fast Validation | |--------|-----|-----|:--------:|------------------------|-----------------| -| G1 | Python | `define_model(use=[...])` missing | P1 | `py/packages/genkit/src/genkit/ai/_registry.py` | unit: model registration accepts and stores `use` | -| G2 | Python | Action-level middleware storage missing | P1 | `py/packages/genkit/src/genkit/core/action/_action.py` | unit: middleware chain wraps action execution | +| G1 | Python | `define_model(use=[...])` ~~missing~~ **done** | P1 | `py/packages/genkit/src/genkit/ai/_registry.py` | unit: model registration accepts and stores `use` | +| G2 | Python | Action-level middleware storage ~~missing~~ **done** | P1 | `py/packages/genkit/src/genkit/core/action/_action.py` | unit: middleware chain wraps action execution | | G3 | Python | `simulate_constrained_generation` missing | P1 | `py/packages/genkit/src/genkit/blocks/middleware.py` | unit: constrained request on unsupported model rewrites prompt | | G4 | Python | `augment_with_context` lifecycle mismatch | P2 | `py/packages/genkit/src/genkit/blocks/generate.py`, `.../blocks/model.py` | parity test: same middleware ordering as JS | | G5 | Python | `X-Genkit-Span-Id` header missing | P1 | `py/packages/genkit/src/genkit/core/reflection.py` | integration: reflection response exposes span header | @@ -1196,7 +1284,7 @@ Reverse topological sort of the gap DAG yields the following dependency levels. 
| ID | Gap | Work Item | Files to Touch | Effort | Unblocks | |----|-----|-----------|----------------|:------:|----------| -| **P1.1** | **G2** | Add `middleware` storage to `Action` class; implement `action_with_middleware()` wrapper that chains model-level middleware around `action.run()` | `core/action/_action.py` | L | G1, G12, G13, G15, G19 | +| **P1.1** | **G2** | ~~Add `middleware` storage to `Action` class; implement `action_with_middleware()` wrapper that chains model-level middleware around `action.run()`~~ **Done** — `Action.__init__(middleware=...)`, `Action.middleware` property, `register_action(middleware=...)`, `define_model(use=[...])`, `dispatch()` chains model middleware after call-time middleware | `core/action/_action.py`, `core/registry.py`, `ai/_registry.py`, `blocks/generate.py` | L | ~~G1, G12, G13, G15, G19~~ | | **P1.2** | **G6** | Update `on_trace_start` callback signature to `(trace_id: str, span_id: str)` throughout action system | `core/action/_action.py`, `core/reflection.py`, `core/trace/` | S | G5 | | **P1.3** | **G18** | Add multipart tool support: `define_tool(multipart=True)`, `MultipartToolAction` type `tool.v2`, dual registration for non-multipart tools | `blocks/tools.py`, `blocks/generate.py` | M | — | | **P1.4** | **G20** | Add `context` parameter to `Genkit()` that sets `registry.context` for default action context | `ai/_aio.py` | XS | — | @@ -1437,7 +1525,7 @@ Milestone ▲ P1 infra ▲ Middleware ▲ Full P1 ▲ Client | PR | Scope | Gaps | Contents | Depends On | |----|:-----:|------|----------|:----------:| -| **PR-1a** | Core | G2 | Add `middleware` list to `Action.__init__()`, implement `action_with_middleware()` dispatch wrapper, unit tests for middleware chaining | — | +| **PR-1a** | Core | G2 | ~~Add `middleware` list to `Action.__init__()`, implement `action_with_middleware()` dispatch wrapper, unit tests for middleware chaining~~ **Done** — `Action(middleware=...)`, `Action.middleware`, 
`register_action(middleware=...)`, `define_model(use=[...])`, `dispatch()` chains model-level mw after call-time mw, 3 tests | — | | **PR-1b** | Core | G6 | Update `on_trace_start` callback signature to `(trace_id, span_id)` across action system + tracing, update all call sites | — | | **PR-1c** | Core | G18 | Multipart tool support: `define_tool(multipart=True)`, `tool.v2` action type, dual registration for non-multipart tools, unit tests | — | | **PR-1d** | Core | G20, G21 | `Genkit(context=..., client_header=...)` constructor params — small additive changes, can combine in one PR | — | @@ -1448,7 +1536,7 @@ Milestone ▲ P1 infra ▲ Middleware ▲ Full P1 ▲ Client | PR | Scope | Gaps | Contents | Depends On | |----|:-----:|------|----------|:----------:| -| **PR-2a** | Core | G1 | Add `use` param to `define_model()`, wire to `action_with_middleware()`, build `get_model_middleware()` helper, tests | PR-1a | +| **PR-2a** | Core | G1 | ~~Add `use` param to `define_model()`, wire to `action_with_middleware()`, build `get_model_middleware()` helper, tests~~ **Done** — merged with PR-1a (G2) | ~~PR-1a~~ | | **PR-2b** | Core | G5 | Emit `X-Genkit-Span-Id` response header in reflection server (small, ~20 lines) | PR-1b | | **PR-2c** | Core | G12 | `retry()` middleware — exponential backoff, jitter, configurable statuses, `on_error` callback, dedicated test suite | PR-1a | | **PR-2d** | Core | G13 | `fallback()` middleware — ordered model list, error status config, `on_error` callback, dedicated test suite | PR-1a | @@ -1534,3 +1622,107 @@ The current `yesudeep/feat/checks-plugin` branch bundles 32 changed files spanni | Estimated calendar time to active P2 closure | ~9 weeks | | Plugins needing test uplift | 13 of 20 | | New test files needed (est.) | ~40–50 across all plugins | + +### 10j. 
Open PR Merge Order — Dependency Graph + +**Current open PRs** (as of 2026-02-08): + +``` +INDEPENDENT (merge in any order) +════════════════════════════════════════════════════════════════════ + + #4494 fix: RedactedSpan dropped_* overrides + Files: core/trace/adjusting_exporter.py + Conflicts: NONE → merge anytime + + #4504 feat: Google Checks AI Safety plugin + Files: plugins/checks/*, samples/provider-checks-hello/* + Conflicts: NONE (new plugin/sample) → merge anytime + + #4510 feat: Model middleware functions (retry, fallback, etc.) + Files: blocks/middleware.py + Conflicts: NONE → merge anytime + + +LAYER 1 — merge these next (independent of each other) +════════════════════════════════════════════════════════════════════ + + #4495 fix: Prompt recursion fix + Files: blocks/prompt.py, core/registry.py + Conflicts with: #4512, #4516 (core/registry.py) + → merge BEFORE #4512 and #4516, or rebase them after + + #4514 fix: Transfer-Encoding chunked + Files: core/reflection.py + Conflicts with: #4401 (core/reflection.py) + → merge BEFORE #4401 (simpler fix, easier to rebase #4401) + + +LAYER 2 — merge after Layer 1 +════════════════════════════════════════════════════════════════════ + + #4401 feat: Reflection API v2 + Files: ai/_base_async.py, core/constants.py, core/reflection.py + Conflicts with: #4512 (ai/_base_async.py, core/constants.py) + #4514 (core/reflection.py) + → merge AFTER #4514, BEFORE #4512 + + #4513 feat: Multipart tool support (tool.v2) + Files: ai/_registry.py, blocks/generate.py, core/action/types.py + Conflicts with: #4516 (ai/_registry.py, blocks/generate.py) + → merge BEFORE #4516, then rebase #4516 + + +LAYER 3 — merge last (most file overlaps) +════════════════════════════════════════════════════════════════════ + + #4512 feat: Genkit constructor parity (context, name, client_header) + Files: ai/_aio.py, ai/_base_async.py, ai/_runtime.py, + core/constants.py, core/registry.py + Conflicts with: #4401 (ai/_base_async.py, core/constants.py) 
+ #4495 (core/registry.py) + #4516 (core/registry.py) + → merge AFTER #4401 and #4495 + + #4516 feat: Model-level middleware (define_model use=[...]) + Files: ai/_registry.py, blocks/generate.py, + core/action/_action.py, core/registry.py + Conflicts with: #4495 (core/registry.py) + #4512 (core/registry.py) + #4513 (ai/_registry.py, blocks/generate.py) + → merge LAST (after #4495, #4512, #4513) +``` + +**Recommended merge sequence:** + +``` + ┌─ #4494 (RedactedSpan) ──┐ + ├─ #4504 (Checks plugin) ──┤ + ├─ #4510 (MW functions) ──┤── all independent, merge in parallel + ├─ #4495 (Prompt fix) ──┤ + └─ #4514 (Transfer-Enc) ──┘ + │ + ▼ + #4401 (Reflection v2) ← rebase onto #4514 + #4513 (Multipart tools) ← independent of #4401 + │ + ▼ + #4512 (Constructor) ← rebase onto #4401 + #4495 + #4516 (MW storage) ← rebase onto #4513 + #4495 + #4512 +``` + +**File conflict matrix:** + +| File | #4401 | #4494 | #4495 | #4504 | #4510 | #4512 | #4513 | #4514 | #4516 | +|------|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:|:-----:| +| `ai/_base_async.py` | ✏️ | | | | | ✏️ | | | | +| `ai/_registry.py` | | | | | | | ✏️ | | ✏️ | +| `blocks/generate.py` | | | | | | | ✏️ | | ✏️ | +| `blocks/middleware.py` | | | | | ✏️ | | | | | +| `core/action/_action.py` | | | | | | | | | ✏️ | +| `core/action/types.py` | | | | | | | ✏️ | | | +| `core/constants.py` | ✏️ | | | | | ✏️ | | | | +| `core/reflection.py` | ✏️ | | | | | | | ✏️ | | +| `core/registry.py` | | | ✏️ | | | ✏️ | | | ✏️ | +| `core/trace/*` | | ✏️ | | | | | | | | +| `plugins/checks/*` | | | | ✏️ | | | | | | diff --git a/py/packages/genkit/src/genkit/ai/_registry.py b/py/packages/genkit/src/genkit/ai/_registry.py index 12f6f57d1a..cfbfbef7e3 100644 --- a/py/packages/genkit/src/genkit/ai/_registry.py +++ b/py/packages/genkit/src/genkit/ai/_registry.py @@ -1008,6 +1008,7 @@ def define_model( metadata: dict[str, object] | None = None, info: ModelInfo | None = None, description: str | None = None, + use: list[ModelMiddleware] 
| None = None, ) -> Action: """Define a custom model action. @@ -1018,6 +1019,10 @@ def define_model( metadata: Optional metadata for the model. info: Optional ModelInfo for the model. description: Optional description for the model. + use: Optional list of model-level middleware to apply when + this model is invoked. These run after call-time middleware + (passed via ``generate(use=[...])``) in the dispatch chain, + matching JS SDK's ``defineModel({use: [mw]})``. """ # Build model options dict model_options: dict[str, object] = {} @@ -1058,6 +1063,7 @@ def define_model( fn=fn, metadata=model_meta, description=model_description, + middleware=use, ) def define_background_model( diff --git a/py/packages/genkit/src/genkit/blocks/generate.py b/py/packages/genkit/src/genkit/blocks/generate.py index 8fd33f2b0b..2b87e4ab3c 100644 --- a/py/packages/genkit/src/genkit/blocks/generate.py +++ b/py/packages/genkit/src/genkit/blocks/generate.py @@ -212,6 +212,13 @@ def wrapper(chunk: GenerateResponseChunk) -> None: if raw_request.docs and not supports_context: middleware.append(augment_with_context()) + # Build the combined middleware chain: call-time middleware first, + # then model-level middleware (from define_model(use=[...])). + # This matches JS SDK execution order: + # call-time[0..N] → model-level[0..M] → runner + model_middleware: list[ModelMiddleware] = cast(list[ModelMiddleware], model.middleware) + combined_middleware: list[ModelMiddleware] = list(middleware) + model_middleware + async def dispatch(index: int, req: GenerateRequest, ctx: ActionRunContext) -> GenerateResponse: """Dispatches model request, passing it through middleware if present. @@ -223,7 +230,7 @@ async def dispatch(index: int, req: GenerateRequest, ctx: ActionRunContext) -> G Returns: The generated response. 
""" - if not middleware or index == len(middleware): + if index == len(combined_middleware): # end of the chain, call the original model action return ( await model.arun( @@ -233,7 +240,7 @@ async def dispatch(index: int, req: GenerateRequest, ctx: ActionRunContext) -> G ) ).response - current_middleware = middleware[index] + current_middleware = combined_middleware[index] async def next_fn( modified_req: GenerateRequest | None = None, diff --git a/py/packages/genkit/src/genkit/blocks/middleware.py b/py/packages/genkit/src/genkit/blocks/middleware.py index 30aa8cf4cf..19d8518451 100644 --- a/py/packages/genkit/src/genkit/blocks/middleware.py +++ b/py/packages/genkit/src/genkit/blocks/middleware.py @@ -14,7 +14,44 @@ # # SPDX-License-Identifier: Apache-2.0 -"""Middleware for the Genkit framework.""" +"""Model middleware for the Genkit framework. + +This module contains **Layer 4 (Model Middleware)** — the primary user-facing +middleware in Genkit. Model middleware intercepts and transforms +``GenerateRequest`` and ``GenerateResponse`` objects in a chain around the +model runner. + +Genkit has four distinct middleware layers, each at a different level: + +- **Layer 1 — ASGI / HTTP Middleware** (Starlette, FastAPI, etc.): + Runs on every HTTP request. Use for CORS, rate limiting, security headers. + +- **Layer 2 — Context Providers** (``Genkit(context=...)``): + Extracts auth/context from the HTTP request before every action. + Use for API key validation, JWT parsing, user session extraction. + +- **Layer 3 — Action Middleware** (``Action.use`` in JS core): + Wraps any action type. Primarily internal framework wiring in Python. + +- **Layer 4 — Model Middleware** (this module): + Runs only on model calls via ``generate()`` or ``prompt.generate()``. + Use for safety guardrails, retry, fallback, constrained generation + simulation, system prompt simulation, and media downloading. 
+ +Model middleware can be applied at three levels: + +- **Call-time**: ``generate(use=[mw])`` — per-request. +- **Model-level**: ``define_model(use=[mw])`` — baked into the model. +- **Auto-wired**: Injected at model definition time based on model + capabilities (e.g., ``augment_with_context`` for models without + native context support). + +Execution order: call-time → model-level → auto-wired → model runner. + +See Also: + ``genkit.blocks.model``: ``ModelMiddleware`` type definition. + ``genkit.blocks.generate``: Middleware dispatch in ``generate_action()``. +""" from genkit.blocks.model import ( ModelMiddleware, diff --git a/py/packages/genkit/src/genkit/core/action/_action.py b/py/packages/genkit/src/genkit/core/action/_action.py index b507d9fe57..9d96dd4a21 100644 --- a/py/packages/genkit/src/genkit/core/action/_action.py +++ b/py/packages/genkit/src/genkit/core/action/_action.py @@ -225,6 +225,7 @@ def __init__( description: str | None = None, metadata: dict[str, object] | None = None, span_metadata: dict[str, SpanAttributeValue] | None = None, + middleware: list[Any] | None = None, ) -> None: """Initialize an Action. @@ -237,12 +238,15 @@ def __init__( description: Optional human-readable description of the action. metadata: Optional dictionary of metadata about the action. span_metadata: Optional dictionary of tracing span metadata. + middleware: Optional list of middleware functions to be applied + when executing this action. 
""" self._kind: ActionKind = kind self._name: str = name self._metadata: dict[str, object] = metadata if metadata else {} self._description: str | None = description self._is_async: bool = inspect.iscoroutinefunction(fn) + self._middleware: list[Any] = middleware if middleware else [] # Optional matcher function for resource actions self.matches: Callable[[object], bool] | None = None @@ -302,6 +306,15 @@ def output_schema(self, value: dict[str, object]) -> None: def is_async(self) -> bool: return self._is_async + @property + def middleware(self) -> list[Any]: # noqa: ANN401 + """Middleware functions applied at action execution time. + + These are model-level middleware set via ``define_model(use=[...])``. + They run after call-time middleware in the dispatch chain. + """ + return self._middleware + def run( self, input: InputT | None = None, diff --git a/py/packages/genkit/src/genkit/core/registry.py b/py/packages/genkit/src/genkit/core/registry.py index ac7aa75986..d0fd049daf 100644 --- a/py/packages/genkit/src/genkit/core/registry.py +++ b/py/packages/genkit/src/genkit/core/registry.py @@ -30,7 +30,7 @@ import asyncio import threading from collections.abc import Awaitable, Callable -from typing import cast +from typing import Any, cast from dotpromptz.dotprompt import Dotprompt from pydantic import BaseModel @@ -140,6 +140,7 @@ def register_action( description: str | None = None, metadata: dict[str, object] | None = None, span_metadata: dict[str, SpanAttributeValue] | None = None, + middleware: list[Any] | None = None, ) -> Action[InputT, OutputT, ChunkT]: """Register a new action with the registry. @@ -155,6 +156,8 @@ def register_action( description: Optional human-readable description of the action. metadata: Optional dictionary of metadata about the action. span_metadata: Optional dictionary of tracing span metadata. + middleware: Optional list of middleware functions to apply when + executing this action (e.g., model-level middleware). 
Returns: The newly created and registered Action instance. @@ -167,6 +170,7 @@ def register_action( description=description, metadata=metadata, span_metadata=span_metadata, + middleware=middleware, ) action_typed = cast(Action[InputT, OutputT, ChunkT], action) with self._lock: diff --git a/py/packages/genkit/tests/genkit/blocks/generate_test.py b/py/packages/genkit/tests/genkit/blocks/generate_test.py index 17fb0a4d55..faef7c15af 100644 --- a/py/packages/genkit/tests/genkit/blocks/generate_test.py +++ b/py/packages/genkit/tests/genkit/blocks/generate_test.py @@ -355,6 +355,158 @@ def collect_chunks(c: GenerateResponseChunk) -> None: ] +@pytest.mark.asyncio +async def test_model_level_middleware_applied() -> None: + """Model-level middleware set via define_model(use=[...]) is applied.""" + ai = Genkit() + + call_log: list[str] = [] + + async def model_mw( + req: GenerateRequest, + ctx: ActionRunContext, + next: ModelMiddlewareNext, + ) -> GenerateResponse: + call_log.append('model_mw') + txt = ''.join(text_from_message(m) for m in req.messages) + return await next( + GenerateRequest( + messages=[ + Message(role=Role.USER, content=[Part(root=TextPart(text=f'MW({txt})'))]), + ], + ), + ctx, + ) + + def echo_fn(request: GenerateRequest, ctx: ActionRunContext) -> GenerateResponse: + call_log.append('runner') + merged = ''.join(text_from_message(m) for m in request.messages) + return GenerateResponse( + message=Message( + role=Role.MODEL, + content=[Part(root=TextPart(text=f'[ECHO]{merged}'))], + ) + ) + + ai.define_model(name='mwEchoModel', fn=echo_fn, use=[model_mw]) + + response = await generate_action( + ai.registry, + GenerateActionOptions( + model='mwEchoModel', + messages=[ + Message(role=Role.USER, content=[Part(root=TextPart(text='hello'))]), + ], + ), + ) + + assert call_log == ['model_mw', 'runner'] + assert 'MW(' in response.text + assert 'hello' in response.text + + +@pytest.mark.asyncio +async def test_call_time_middleware_runs_before_model_level() -> 
None: + """Call-time middleware runs before model-level middleware in dispatch chain. + + This matches the JS SDK execution order: + call-time[0..N] → model-level[0..M] → runner + """ + ai = Genkit() + + call_order: list[str] = [] + + async def call_time_mw( + req: GenerateRequest, + ctx: ActionRunContext, + next: ModelMiddlewareNext, + ) -> GenerateResponse: + call_order.append('call_time') + return await next(req, ctx) + + async def model_mw( + req: GenerateRequest, + ctx: ActionRunContext, + next: ModelMiddlewareNext, + ) -> GenerateResponse: + call_order.append('model_level') + return await next(req, ctx) + + def echo_fn(request: GenerateRequest, ctx: ActionRunContext) -> GenerateResponse: + call_order.append('runner') + merged = ''.join(text_from_message(m) for m in request.messages) + return GenerateResponse( + message=Message( + role=Role.MODEL, + content=[Part(root=TextPart(text=f'[ECHO]{merged}'))], + ) + ) + + ai.define_model(name='mwOrderModel', fn=echo_fn, use=[model_mw]) + + response = await generate_action( + ai.registry, + GenerateActionOptions( + model='mwOrderModel', + messages=[ + Message(role=Role.USER, content=[Part(root=TextPart(text='test'))]), + ], + ), + middleware=[call_time_mw], + ) + + # Verify execution order: call-time → model-level → runner + assert call_order == ['call_time', 'model_level', 'runner'] + assert response.text is not None + + +@pytest.mark.asyncio +async def test_multiple_model_level_middleware_chain() -> None: + """Multiple model-level middleware chain in order.""" + ai = Genkit() + + call_order: list[str] = [] + + async def model_mw_a( + req: GenerateRequest, + ctx: ActionRunContext, + next: ModelMiddlewareNext, + ) -> GenerateResponse: + call_order.append('model_a') + return await next(req, ctx) + + async def model_mw_b( + req: GenerateRequest, + ctx: ActionRunContext, + next: ModelMiddlewareNext, + ) -> GenerateResponse: + call_order.append('model_b') + return await next(req, ctx) + + def echo_fn(request: 
GenerateRequest, ctx: ActionRunContext) -> GenerateResponse: + call_order.append('runner') + return GenerateResponse( + message=Message( + role=Role.MODEL, + content=[Part(root=TextPart(text='done'))], + ) + ) + + ai.define_model(name='multiMwModel', fn=echo_fn, use=[model_mw_a, model_mw_b]) + + await generate_action( + ai.registry, + GenerateActionOptions( + model='multiMwModel', + messages=[ + Message(role=Role.USER, content=[Part(root=TextPart(text='hi'))]), + ], + ), + ) + + assert call_order == ['model_a', 'model_b', 'runner'] + + ########################################################################## # run tests from /tests/specs/generate.yaml ########################################################################## diff --git a/py/samples/framework-middleware-demo/README.md b/py/samples/framework-middleware-demo/README.md index 9a651e289d..2036389e05 100644 --- a/py/samples/framework-middleware-demo/README.md +++ b/py/samples/framework-middleware-demo/README.md @@ -1,8 +1,9 @@ # Middleware Demo -Demonstrates Genkit's middleware system using the `use=` parameter on -`ai.generate()`. Middleware intercepts the request/response pipeline, -enabling logging, retries, request modification, and more. +Demonstrates Genkit's middleware system at two levels: + +1. **Call-time middleware** — attached per `generate()` call via `use=[...]` +2. **Model-level middleware** — baked into a model via `define_model(use=[...])` ## Quick Start @@ -17,29 +18,61 @@ Then open the Dev UI at http://localhost:4000. 
| Flow | What It Demonstrates | |------|---------------------| -| `logging_demo` | Middleware that logs request metadata and response info | -| `request_modifier_demo` | Middleware that modifies the request before it reaches the model | -| `chained_middleware_demo` | Multiple middleware functions composed in a pipeline | +| `logging_demo` | Call-time middleware that logs request and response metadata | +| `request_modifier_demo` | Call-time middleware that modifies the request before the model sees it | +| `chained_middleware_demo` | Multiple call-time middleware composed in a pipeline | +| `model_level_middleware_demo` | Model-level middleware set via `define_model(use=[...])` | +| `combined_middleware_demo` | Both call-time and model-level middleware running together | ## How Middleware Works +### Call-Time Middleware + +Passed directly to `generate()`. Runs first in the chain. + +```python +response = await ai.generate( + prompt='Hello', + use=[logging_middleware, system_instruction_middleware], +) +``` + +### Model-Level Middleware + +Baked into a model at registration time. Every caller gets it automatically. + +```python +ai.define_model( + name='custom/safe-model', + fn=my_model_fn, + use=[safety_middleware], +) + +# safety_middleware runs automatically — no need for use=[...] 
+response = await ai.generate(model='custom/safe-model', prompt='Hello') +``` + +### Execution Order + +When both are used, call-time middleware runs first, then model-level: + ``` -ai.generate(prompt=..., use=[middleware_a, middleware_b]) +ai.generate(model='custom/safe-model', prompt=..., use=[call_mw]) | v - middleware_a(req, ctx, next) + call_mw(req, ctx, next) ← call-time middleware | v - middleware_b(req, ctx, next) + safety_mw(req, ctx, next) ← model-level middleware | v - Model (actual API call) + my_model_fn(req, ctx) ← model runner | v - middleware_b returns response + safety_mw returns response | v - middleware_a returns response + call_mw returns response | v Final response to caller diff --git a/py/samples/framework-middleware-demo/src/main.py b/py/samples/framework-middleware-demo/src/main.py index 3da6983bb1..747121217f 100644 --- a/py/samples/framework-middleware-demo/src/main.py +++ b/py/samples/framework-middleware-demo/src/main.py @@ -251,9 +251,161 @@ async def chained_middleware_demo(input: ChainedInput) -> str: return response.text +# ============================================================================ +# PART 2: Model-Level Middleware via define_model(use=[...]) +# ============================================================================ +# Model-level middleware is baked into a model at registration time. +# Every caller of this model gets the middleware automatically, without +# needing to pass use=[...] in generate(). +# +# This is how plugin authors add cross-cutting concerns like safety +# checks, rate limiting, or request augmentation to their models. 
+# ============================================================================ + + +class ModelLevelInput(BaseModel): + """Input for model-level middleware demo.""" + + prompt: str = Field( + default='Tell me something interesting about Python', + description='Prompt to send to the model with baked-in middleware', + ) + + +class CombinedInput(BaseModel): + """Input for combined call-time + model-level middleware demo.""" + + prompt: str = Field( + default='Write a limerick about coding', + description='Prompt to send through both call-time and model-level middleware', + ) + + +async def safety_prefix_middleware( + req: GenerateRequest, + ctx: ActionRunContext, + next_handler: ModelMiddlewareNext, +) -> GenerateResponse: + """Model-level middleware that prepends a safety instruction. + + This middleware is baked into the model via define_model(use=[...]). + Every generate() call using this model will automatically get + the safety instruction injected, even without passing use=[...]. + + Args: + req: The generation request about to be sent. + ctx: The action execution context. + next_handler: Calls the next middleware or the model. + + Returns: + The generation response. + """ + safety_text = 'You are a helpful, harmless, and honest assistant. Never produce harmful content.' + safety_message = Message( + role=Role.SYSTEM, + content=[Part(root=TextPart(text=safety_text))], + ) + modified_messages = [safety_message, *req.messages] + modified_req = req.model_copy(update={'messages': modified_messages}) + await logger.ainfo('safety_prefix_middleware (model-level): injected safety system message') + return await next_handler(modified_req, ctx) + + +# Define a custom model runner to pair with the safety middleware. +# For demonstration it echoes the request instead of calling a real +# model API; the safety middleware still runs before every request.
+def custom_model_fn(request: GenerateRequest, ctx: ActionRunContext) -> GenerateResponse: + """Custom model runner that echoes the request for demonstration purposes. + + This function demonstrates how define_model() works: you implement + the model runner function, and Genkit handles all the middleware + chaining, tracing, and registry management. + + Args: + request: The generation request (possibly modified by middleware). + ctx: The action execution context. + + Returns: + An echo response summarizing the (possibly modified) request. + """ + # Build a response echoing the request for demonstration purposes. + # In a real plugin, you'd call an API here. + merged = ' '.join(str(p.root.text) for m in request.messages for p in m.content if p.root.text) + echo_text = f'[custom-model] Processed request with {len(request.messages)} messages. Content: {merged[:100]}...' + return GenerateResponse( + message=Message( + role=Role.MODEL, + content=[Part(root=TextPart(text=echo_text))], + ), + ) + + +# Register the custom model WITH model-level middleware. +# Every call to generate(model='custom/safe-model') will run +# safety_prefix_middleware automatically. +ai.define_model( + name='custom/safe-model', + fn=custom_model_fn, + use=[safety_prefix_middleware], +) + + +@ai.flow() +async def model_level_middleware_demo(input: ModelLevelInput) -> str: + """Demonstrate model-level middleware set via define_model(use=[...]). + + No call-time middleware is passed -- the safety middleware is baked + into the 'custom/safe-model' model definition. Every caller gets + the middleware automatically. + + Args: + input: Input with prompt text. + + Returns: + The model's response text (with safety middleware applied). + """ + response = await ai.generate( + model='custom/safe-model', + prompt=input.prompt, + ) + return response.text + + +@ai.flow() +async def combined_middleware_demo(input: CombinedInput) -> str: + """Demonstrate call-time + model-level middleware running together.
+ + The execution order is: + 1. logging_middleware (call-time, from generate(use=[...])) + 2. safety_prefix_middleware (model-level, from define_model(use=[...])) + 3. custom_model_fn (the actual model runner) + + This matches the JS SDK execution order: + call-time[0..N] -> model-level[0..M] -> runner + + Args: + input: Input with prompt text. + + Returns: + The model's response text. + """ + response = await ai.generate( + model='custom/safe-model', + prompt=input.prompt, + use=[logging_middleware], + ) + return response.text + + async def main() -> None: """Main function -- keep alive for Dev UI.""" await logger.ainfo('Middleware demo started. Open http://localhost:4000 to test flows.') + await logger.ainfo('Flows available:') + await logger.ainfo(' - logging_demo: call-time logging middleware') + await logger.ainfo(' - request_modifier_demo: call-time request modification') + await logger.ainfo(' - chained_middleware_demo: multiple call-time middleware') + await logger.ainfo(' - model_level_middleware_demo: model-level middleware via define_model(use=[...])') + await logger.ainfo(' - combined_middleware_demo: call-time + model-level middleware together') while True: await asyncio.sleep(3600)
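
The execution order described in this change (call-time middleware first, then model-level middleware, then the model runner, with responses unwinding in reverse) can be illustrated with a small standalone sketch. It does not use Genkit and is not how Genkit implements chaining. `build_chain`, `tracing_middleware`, and `run_demo` are hypothetical names introduced here, and plain dicts stand in for `GenerateRequest`, `GenerateResponse`, and `ActionRunContext`.

```python
import asyncio
from typing import Awaitable, Callable

# Hypothetical stand-ins for illustration only; these are not Genkit types.
Request = dict
Response = dict
Next = Callable[[Request, dict], Awaitable[Response]]
Middleware = Callable[[Request, dict, Next], Awaitable[Response]]


def _wrap(mw: Middleware, nxt: Next) -> Next:
    """Bind one middleware to the handler that follows it."""

    async def handler(req: Request, ctx: dict) -> Response:
        return await mw(req, ctx, nxt)

    return handler


def build_chain(call_time: list[Middleware], model_level: list[Middleware], runner: Next) -> Next:
    """Compose call-time[0..N] -> model-level[0..M] -> runner."""
    handler = runner
    # Wrap from the innermost layer outward so the first
    # call-time middleware ends up as the outermost handler.
    for mw in reversed([*call_time, *model_level]):
        handler = _wrap(mw, handler)
    return handler


def tracing_middleware(name: str, trace: list[str]) -> Middleware:
    """Create a middleware that records when it is entered and exited."""

    async def mw(req: Request, ctx: dict, next_handler: Next) -> Response:
        trace.append(f'{name}:before')
        resp = await next_handler(req, ctx)
        trace.append(f'{name}:after')
        return resp

    return mw


def run_demo() -> tuple[list[str], Response]:
    """Run one request through a call-time and a model-level middleware."""
    trace: list[str] = []

    async def runner(req: Request, ctx: dict) -> Response:
        trace.append('runner')
        return {'text': f"echo: {req['prompt']}"}

    chain = build_chain(
        call_time=[tracing_middleware('call_mw', trace)],
        model_level=[tracing_middleware('safety_mw', trace)],
        runner=runner,
    )
    resp = asyncio.run(chain({'prompt': 'Hello'}, {}))
    return trace, resp


if __name__ == '__main__':
    trace, resp = run_demo()
    print(trace)
    print(resp)
```

Wrapping in reverse order keeps the composition a single loop and makes the first listed middleware the outermost one, which is why `call_mw` sees the request first and the response last.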