diff --git a/src/oss/langchain/middleware.mdx b/src/oss/langchain/middleware.mdx index 929b7ff5b8..bf186c6cb5 100644 --- a/src/oss/langchain/middleware.mdx +++ b/src/oss/langchain/middleware.mdx @@ -299,6 +299,8 @@ const agent = createAgent({ **Important:** Human-in-the-loop middleware requires a [checkpointer](/oss/langgraph/persistence#checkpoints) to maintain state across interruptions. + This middleware uses the `after_model` hook to interrupt execution after the model generates tool calls but before they execute. This ensures that on resume, tools don't re-execute. See [Using interrupts in middleware](#using-interrupts-in-middleware) for details on which hooks support interrupts. + See the [human-in-the-loop documentation](/oss/langchain/human-in-the-loop) for complete examples and integration patterns. @@ -1997,6 +1999,181 @@ const conditionalMiddleware = createMiddleware({ ``` ::: +### Using interrupts in middleware + +Middleware can use @[interrupts](/oss/langgraph/interrupts) to pause execution and wait for external input. However, **not all hooks are suitable for interrupts** due to how resumption works. + +**Safe to interrupt:** +- `before_agent` - Interrupt before the agent loop starts +- `after_agent` - Interrupt after the agent loop ends +- `before_model` - Interrupt before the model call +- `after_model` - Interrupt after the model generates tool calls but before execution (used by [`HumanInTheLoopMiddleware`](/oss/langchain/human-in-the-loop)) + +**Avoid interrupting in:** +- `wrap_model_call` +- `wrap_tool_call` + +#### Example: Interrupting for tool call validation + +If you need to validate tool calls before they execute, use `after_model` to inspect the model's response. This example follows the same pattern as [`HumanInTheLoopMiddleware`](/oss/langchain/human-in-the-loop): + +:::python +```python expandable +from langchain.agents.middleware import ( + AgentMiddleware, + AgentState, + ActionRequest, + ReviewConfig, + HITLRequest, + HITLResponse, +) +from langchain_core.messages import AIMessage, ToolMessage +from langgraph.types import interrupt +from langgraph.runtime import Runtime +from typing import Any + +class ToolValidationMiddleware(AgentMiddleware): + def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + """Validate tool calls before execution.""" + # Find the last AI message with tool calls + last_ai_msg = next( + (msg for msg in reversed(state["messages"]) if isinstance(msg, AIMessage)), + None + ) + + if not last_ai_msg or not last_ai_msg.tool_calls: + return None + + # Filter tool calls that need validation + tool_calls_to_validate = [ + tc for tc in last_ai_msg.tool_calls + if tc["name"] == "dangerous_tool" + ] + + if not tool_calls_to_validate: + return None + + # Create HITL request + hitl_request: HITLRequest = { + "action_requests": [ + { + "name": tc["name"], + "args": tc["args"], + "description": f"Tool execution requires approval: {tc['name']}" + } + for tc in tool_calls_to_validate + ], + "review_configs": [ + { + "action_name": tc["name"], + "allowed_decisions": ["approve", "reject"] + } + for tc in tool_calls_to_validate + ] + } + + # Interrupt for human approval + hitl_response: HITLResponse = interrupt(hitl_request) + + # Process decisions + revised_tool_calls = [ + tc for tc in last_ai_msg.tool_calls + if tc["name"] != "dangerous_tool" + ] + tool_messages = [] + + for i, decision in enumerate(hitl_response["decisions"]): + tool_call = tool_calls_to_validate[i] + if decision["type"] == "approve": + revised_tool_calls.append(tool_call) + elif decision["type"] == "reject": + tool_messages.append( + ToolMessage( + content=decision.get("message", "Tool call rejected"), + name=tool_call["name"], + tool_call_id=tool_call["id"], + status="error" + ) + ) + + last_ai_msg.tool_calls = revised_tool_calls + return {"messages": [last_ai_msg, *tool_messages]} +``` +::: + +:::js +```typescript expandable +import { createMiddleware, AIMessage, ToolMessage } from "langchain"; +import { interrupt } from "@langchain/langgraph"; + +const toolValidationMiddleware = createMiddleware({ + name: "ToolValidationMiddleware", + afterModel: (state) => { + // Find the last AI message with tool calls + const lastAiMsg = state.messages + .slice() + .reverse() + .find((msg) => msg instanceof AIMessage); + + if (!lastAiMsg?.tool_calls?.length) { + return; + } + + // Filter tool calls that need validation + const toolCallsToValidate = lastAiMsg.tool_calls.filter( + (tc) => tc.name === "dangerous_tool" + ); + + if (toolCallsToValidate.length === 0) { + return; + } + + // Interrupt for human approval + const hitlResponse = interrupt({ + actionRequests: toolCallsToValidate.map((tc) => ({ + name: tc.name, + args: tc.args, + description: `Tool execution requires approval: ${tc.name}`, + })), + reviewConfigs: toolCallsToValidate.map((tc) => ({ + actionName: tc.name, + allowedDecisions: ["approve", "reject"], + })), + }) as { + decisions: { type: "approve" | "reject"; message?: string }[]; + }; + + // Process decisions + const revisedToolCalls = lastAiMsg.tool_calls.filter( + (tc) => tc.name !== "dangerous_tool" + ); + const toolMessages = []; + + for (let i = 0; i < hitlResponse.decisions.length; i++) { + const decision = hitlResponse.decisions[i]; + const toolCall = toolCallsToValidate[i]; + + if (decision.type === "approve") { + revisedToolCalls.push(toolCall); + } else if (decision.type === "reject") { + toolMessages.push( + new ToolMessage({ + content: decision.message || "Tool call rejected", + name: toolCall.name, + tool_call_id: toolCall.id, + status: "error", + }) + ); + } + } + + lastAiMsg.tool_calls = revisedToolCalls; + return { messages: [lastAiMsg, ...toolMessages] }; + }, +}); +``` +::: + ### Best practices 1. Keep middleware focused - each should do one thing well