From b651518e637cb33d32fadd6fdc43f884c04a577d Mon Sep 17 00:00:00 2001 From: David Gardner Date: Mon, 11 May 2026 14:19:54 -0700 Subject: [PATCH 1/7] Add dependency extra for langgraph, remove redundant dependency for langchain Signed-off-by: David Gardner --- pyproject.toml | 8 ++++++-- uv.lock | 10 +++++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7adcc8c..2414cd5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,9 +79,13 @@ langchain = [ "langchain-core", ] +langgraph = [ + "nemo-flow[langchain]", + "langgraph>=1.1.10,<2.0.0", +] + langchain-nvidia = [ - "langchain>=1.2.11,<2.0.0", - "langchain-core", + "nemo-flow[langchain]", "langchain-nvidia-ai-endpoints~=1.0", ] diff --git a/uv.lock b/uv.lock index 1777558..c7474a6 100644 --- a/uv.lock +++ b/uv.lock @@ -1182,6 +1182,11 @@ langchain-nvidia = [ { name = "langchain-core" }, { name = "langchain-nvidia-ai-endpoints" }, ] +langgraph = [ + { name = "langchain" }, + { name = "langchain-core" }, + { name = "langgraph" }, +] [package.dev-dependencies] dev = [ @@ -1217,11 +1222,14 @@ test = [ requires-dist = [ { name = "langchain", marker = "extra == 'langchain'", specifier = ">=1.2.11,<2.0.0" }, { name = "langchain", marker = "extra == 'langchain-nvidia'", specifier = ">=1.2.11,<2.0.0" }, + { name = "langchain", marker = "extra == 'langgraph'", specifier = ">=1.2.11,<2.0.0" }, { name = "langchain-core", marker = "extra == 'langchain'" }, { name = "langchain-core", marker = "extra == 'langchain-nvidia'" }, + { name = "langchain-core", marker = "extra == 'langgraph'" }, { name = "langchain-nvidia-ai-endpoints", marker = "extra == 'langchain-nvidia'", specifier = "~=1.0" }, + { name = "langgraph", marker = "extra == 'langgraph'", specifier = ">=1.1.10,<2.0.0" }, ] -provides-extras = ["langchain", "langchain-nvidia"] +provides-extras = ["langchain", "langchain-nvidia", "langgraph"] [package.metadata.requires-dev] dev = [ From 
eee3a2853d6ea4852bda6109b151edf9ae9909ac Mon Sep 17 00:00:00 2001 From: David Gardner Date: Mon, 11 May 2026 14:29:05 -0700 Subject: [PATCH 2/7] First pass at LangGraph integration Signed-off-by: David Gardner --- .../integrations/langchain/callbacks.py | 7 +- .../integrations/langgraph/README.md | 75 ++++++ .../integrations/langgraph/__init__.py | 16 ++ .../integrations/langgraph/callbacks.py | 90 +++++++ .../nemo_flow/integrations/langgraph/graph.py | 238 ++++++++++++++++++ .../integrations/langchain/test_callbacks.py | 17 ++ .../langgraph/test_langgraph_integration.py | 174 +++++++++++++ 7 files changed, 615 insertions(+), 2 deletions(-) create mode 100644 python/nemo_flow/integrations/langgraph/README.md create mode 100644 python/nemo_flow/integrations/langgraph/__init__.py create mode 100644 python/nemo_flow/integrations/langgraph/callbacks.py create mode 100644 python/nemo_flow/integrations/langgraph/graph.py create mode 100644 python/tests/integrations/langgraph/test_langgraph_integration.py diff --git a/python/nemo_flow/integrations/langchain/callbacks.py b/python/nemo_flow/integrations/langchain/callbacks.py index 42082b7..1b2e343 100644 --- a/python/nemo_flow/integrations/langchain/callbacks.py +++ b/python/nemo_flow/integrations/langchain/callbacks.py @@ -22,6 +22,8 @@ class NemoFlowCallbackHandler(BaseCallbackHandler): """Bridge LangChain chain run IDs to NeMo Flow Agent scopes.""" + run_inline = True + def __init__(self) -> None: super().__init__() self._scope_handles: dict[UUID, typing.Any] = {} @@ -39,9 +41,10 @@ def on_chain_start( ) -> typing.Any: """Push a NeMo Flow Agent scope for a LangChain chain run.""" try: - name: str | None = None + name = kwargs.get("name") + if serialized is not None: - name = serialized.get("name") + name = name or serialized.get("name") if name is None: id_list = serialized.get("id") if isinstance(id_list, list) and len(id_list) > 0: diff --git a/python/nemo_flow/integrations/langgraph/README.md 
b/python/nemo_flow/integrations/langgraph/README.md new file mode 100644 index 0000000..5a15649 --- /dev/null +++ b/python/nemo_flow/integrations/langgraph/README.md @@ -0,0 +1,75 @@ + + +# NeMo Flow LangGraph Integration + +This directory contains the `nemo_flow.integrations.langgraph` package, which provides public-API LangGraph integration for NeMo Flow. + +The integration builds on `nemo_flow.integrations.langchain`: `NemoFlowCallbackHandler` inherits the LangChain callback handler, and `NemoFlowMiddleware` is re-exported for LangChain agents used inside LangGraph workflows. + +## Setup + +```bash +uv sync --all-groups --extra langgraph +just build-python +``` + +Installing the `langgraph` extra also installs the LangChain integration dependencies. + +## Usage Example + +```python +from typing_extensions import TypedDict + +import nemo_flow +from langgraph.graph import END, START, StateGraph +from nemo_flow.integrations.langgraph import instrument_graph + + +class State(TypedDict): + value: int + + +def increment(state: State) -> State: + return {"value": state["value"] + 1} + + +builder = StateGraph(State) +builder.add_node("increment", increment) +builder.add_edge(START, "increment") +builder.add_edge("increment", END) + +graph = instrument_graph(builder.compile()) + +with nemo_flow.scope.scope("langgraph-request", nemo_flow.ScopeType.Agent): + result = graph.invoke({"value": 1}) + +print(result) +``` + +For LangChain agents inside a LangGraph workflow, use `NemoFlowMiddleware` from this package the same way as the LangChain integration: + +```python +from langchain.agents import create_agent +from nemo_flow.integrations.langgraph import NemoFlowMiddleware + +agent = create_agent( + model="nvidia:nvidia/nemotron-3-nano-30b-a3b", + tools=[], + middleware=[NemoFlowMiddleware()], +) +``` + +## Public API Coverage + +The public callback path records LangGraph graph and node runnable scopes through LangChain callbacks. 
LangGraph resume and interrupt lifecycle callbacks are emitted as NeMo Flow marks. When using `instrument_graph(...).stream(..., stream_mode="tasks")`, `stream_mode="checkpoints"`, or `stream_mode="debug"`, public stream events are also emitted as marks. + +The patch-based integration in `patches/langgraph/0001-add-nemo-flow-integration.patch` can observe lower-level scheduler details such as internal supersteps, edge writes, and per-branch scope-stack isolation. Those details are not exposed by LangGraph's public callback API, so this package intentionally does not rely on them. + +## Validation + +```bash +uv run pytest python/tests/integrations/langgraph +``` diff --git a/python/nemo_flow/integrations/langgraph/__init__.py b/python/nemo_flow/integrations/langgraph/__init__.py new file mode 100644 index 0000000..8eb8b5d --- /dev/null +++ b/python/nemo_flow/integrations/langgraph/__init__.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""NeMo Flow integrations for LangGraph.""" + +from nemo_flow.integrations.langchain import NemoFlowMiddleware +from nemo_flow.integrations.langgraph.callbacks import NemoFlowCallbackHandler +from nemo_flow.integrations.langgraph.graph import NemoFlowGraph, instrument_graph, with_nemo_flow_callbacks + +__all__ = [ + "NemoFlowCallbackHandler", + "NemoFlowGraph", + "NemoFlowMiddleware", + "instrument_graph", + "with_nemo_flow_callbacks", +] diff --git a/python/nemo_flow/integrations/langgraph/callbacks.py b/python/nemo_flow/integrations/langgraph/callbacks.py new file mode 100644 index 0000000..1fbed19 --- /dev/null +++ b/python/nemo_flow/integrations/langgraph/callbacks.py @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""LangGraph callback handler that reuses the LangChain NeMo Flow integration.""" + +from __future__ import annotations + +import logging +from typing import Any + +from langgraph.callbacks import GraphCallbackHandler, GraphInterruptEvent, GraphResumeEvent + +import nemo_flow +from nemo_flow.integrations.langchain._serialization import _prepare_outputs +from nemo_flow.integrations.langchain.callbacks import NemoFlowCallbackHandler as LangChainNemoFlowCallbackHandler + +_logger = logging.getLogger(__name__) + + +def _json_safe(value: Any) -> nemo_flow.Json: + """Return a conservative JSON-compatible representation for mark payloads.""" + try: + value = _prepare_outputs(value) + except Exception: + pass + + if value is None or isinstance(value, str | int | float | bool): + return value + if isinstance(value, dict): + return {str(key): _json_safe(item) for key, item in value.items()} + if isinstance(value, list | tuple | set): + return [_json_safe(item) for item in value] + return repr(value) + + +def _interrupt_to_payload(interrupt: Any) -> dict[str, nemo_flow.Json]: + return { + "id": _json_safe(getattr(interrupt, "id", None)), + "value": _json_safe(getattr(interrupt, "value", interrupt)), + } + + +class NemoFlowCallbackHandler(LangChainNemoFlowCallbackHandler, GraphCallbackHandler): + """Bridge LangGraph runs to NeMo Flow using public callback APIs. + + This handler inherits the existing LangChain callback integration, so normal + runnable scopes from LangGraph and LangChain are recorded by the same code + path. It also implements LangGraph's public lifecycle callback hooks for + interrupt and resume marks. 
+ """ + + def on_interrupt(self, event: GraphInterruptEvent) -> Any: + """Emit a NeMo Flow mark for a LangGraph interrupt lifecycle event.""" + self._emit_graph_mark( + "Graph Interrupt", + { + "run_id": str(event.run_id) if event.run_id is not None else None, + "status": event.status, + "checkpoint_id": event.checkpoint_id, + "checkpoint_ns": list(event.checkpoint_ns), + "interrupts": [_interrupt_to_payload(interrupt) for interrupt in event.interrupts], + }, + ) + return None + + def on_resume(self, event: GraphResumeEvent) -> Any: + """Emit a NeMo Flow mark for a LangGraph resume lifecycle event.""" + self._emit_graph_mark( + "Graph Resume", + { + "run_id": str(event.run_id) if event.run_id is not None else None, + "status": event.status, + "checkpoint_id": event.checkpoint_id, + "checkpoint_ns": list(event.checkpoint_ns), + }, + ) + return None + + def _emit_graph_mark(self, name: str, data: dict[str, Any]) -> None: + try: + nemo_flow.scope.event( + name, + data=_json_safe(data), + metadata={"integration": "langgraph"}, + ) + except Exception: + _logger.debug("NeMo Flow: LangGraph mark emission failed", exc_info=True) + + +__all__ = ["NemoFlowCallbackHandler"] diff --git a/python/nemo_flow/integrations/langgraph/graph.py b/python/nemo_flow/integrations/langgraph/graph.py new file mode 100644 index 0000000..1674179 --- /dev/null +++ b/python/nemo_flow/integrations/langgraph/graph.py @@ -0,0 +1,238 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Helpers for applying NeMo Flow callbacks to compiled LangGraph graphs.""" + +from __future__ import annotations + +import logging +from collections.abc import AsyncIterator, Iterator, Sequence +from typing import Any + +from langchain_core.callbacks import BaseCallbackManager +from langchain_core.callbacks.base import BaseCallbackHandler +from langchain_core.runnables import RunnableConfig + +import nemo_flow +from nemo_flow.integrations.langgraph.callbacks import NemoFlowCallbackHandler + +_logger = logging.getLogger(__name__) + + +def with_nemo_flow_callbacks( + config: RunnableConfig | None = None, + *, + callback_handler: NemoFlowCallbackHandler | None = None, +) -> RunnableConfig: + """Return a LangChain runnable config with a LangGraph NeMo Flow callback. + + The returned config is a shallow copy. Existing callbacks are preserved and + a handler is added only when a NeMo Flow LangGraph handler is not already + present. + """ + next_config: RunnableConfig = dict(config or {}) + next_config["callbacks"] = _append_callback( + next_config.get("callbacks"), + callback_handler or NemoFlowCallbackHandler(), + ) + return next_config + + +def instrument_graph( + graph: Any, + *, + callback_handler: NemoFlowCallbackHandler | None = None, +) -> "NemoFlowGraph": + """Wrap a compiled LangGraph graph so invocations include NeMo Flow callbacks.""" + return NemoFlowGraph(graph, callback_handler=callback_handler) + + +class NemoFlowGraph: + """Thin proxy that injects NeMo Flow callbacks into graph invocations.""" + + def __init__( + self, + graph: Any, + *, + callback_handler: NemoFlowCallbackHandler | None = None, + ) -> None: + self._graph = graph + self._callback_handler = callback_handler + + def __getattr__(self, name: str) -> Any: + return getattr(self._graph, name) + + def with_config(self, config: RunnableConfig | None = None, **kwargs: Any) -> "NemoFlowGraph": + """Return an instrumented copy of the graph with updated 
LangChain config.""" + return NemoFlowGraph( + self._graph.with_config(config, **kwargs), + callback_handler=self._callback_handler, + ) + + def invoke( + self, + input: Any, + config: RunnableConfig | None = None, + **kwargs: Any, + ) -> Any: + """Invoke the wrapped graph with NeMo Flow callbacks installed.""" + return self._graph.invoke( + input, + self._config(config), + **kwargs, + ) + + async def ainvoke( + self, + input: Any, + config: RunnableConfig | None = None, + **kwargs: Any, + ) -> Any: + """Asynchronously invoke the wrapped graph with NeMo Flow callbacks installed.""" + return await self._graph.ainvoke( + input, + self._config(config), + **kwargs, + ) + + def stream( + self, + input: Any, + config: RunnableConfig | None = None, + **kwargs: Any, + ) -> Iterator[Any]: + """Stream from the wrapped graph and mark public task/checkpoint events.""" + for chunk in self._graph.stream( + input, + self._config(config), + **kwargs, + ): + _emit_public_stream_mark(chunk) + yield chunk + + async def astream( + self, + input: Any, + config: RunnableConfig | None = None, + **kwargs: Any, + ) -> AsyncIterator[Any]: + """Asynchronously stream from the wrapped graph and mark public events.""" + async for chunk in self._graph.astream( + input, + self._config(config), + **kwargs, + ): + _emit_public_stream_mark(chunk) + yield chunk + + def _config(self, config: RunnableConfig | None) -> RunnableConfig: + return with_nemo_flow_callbacks(config, callback_handler=self._callback_handler) + + +def _append_callback(callbacks: Any, handler: BaseCallbackHandler) -> Any: + if callbacks is None: + return [handler] + + if isinstance(callbacks, NemoFlowCallbackHandler): + return callbacks + + if isinstance(callbacks, BaseCallbackManager): + if _manager_has_nemo_flow_handler(callbacks): + return callbacks + manager = callbacks.copy() + manager.add_handler(handler, inherit=True) + return manager + + if isinstance(callbacks, BaseCallbackHandler): + if isinstance(callbacks, 
NemoFlowCallbackHandler): + return callbacks + return [callbacks, handler] + + if isinstance(callbacks, Sequence) and not isinstance(callbacks, str | bytes): + callback_list = list(callbacks) + if any(isinstance(callback, NemoFlowCallbackHandler) for callback in callback_list): + return callback_list + return [*callback_list, handler] + + return callbacks + + +def _manager_has_nemo_flow_handler(manager: BaseCallbackManager) -> bool: + handlers = [*manager.handlers, *manager.inheritable_handlers] + return any(isinstance(handler, NemoFlowCallbackHandler) for handler in handlers) + + +def _emit_public_stream_mark(chunk: Any) -> None: + mode, payload = _stream_mode_payload(chunk) + if mode == "tasks": + _emit_task_mark(payload) + elif mode == "checkpoints": + _emit_mark("Checkpoint Save", payload) + elif mode == "debug": + _emit_debug_mark(payload) + + +def _stream_mode_payload(chunk: Any) -> tuple[str | None, Any]: + if isinstance(chunk, dict) and isinstance(chunk.get("type"), str) and "data" in chunk: + return chunk["type"], chunk["data"] + if isinstance(chunk, tuple): + if len(chunk) == 2 and isinstance(chunk[0], str): + return chunk[0], chunk[1] + if len(chunk) == 3 and isinstance(chunk[1], str): + return chunk[1], chunk[2] + if isinstance(chunk, dict): + if "type" in chunk and "payload" in chunk: + return "debug", chunk + if {"id", "name", "triggers"}.issubset(chunk): + return "tasks", chunk + if {"id", "name", "result"}.issubset(chunk) or {"id", "name", "error"}.issubset(chunk): + return "tasks", chunk + if {"config", "metadata", "values", "next", "tasks"}.issubset(chunk): + return "checkpoints", chunk + return None, None + + +def _emit_debug_mark(payload: Any) -> None: + if not isinstance(payload, dict): + return + event_type = payload.get("type") + event_payload = payload.get("payload") + if event_type == "task": + _emit_mark("Graph Task Start", event_payload) + elif event_type == "task_result": + _emit_mark("Graph Task End", event_payload) + elif event_type == 
"checkpoint": + _emit_mark("Checkpoint Save", event_payload) + + +def _emit_task_mark(payload: Any) -> None: + if not isinstance(payload, dict): + return + if "triggers" in payload: + _emit_mark("Graph Task Start", payload) + elif "result" in payload or "error" in payload: + _emit_mark("Graph Task End", payload) + + +def _emit_mark(name: str, payload: Any) -> None: + try: + nemo_flow.scope.event( + name, + data=_json_safe(payload), + metadata={"integration": "langgraph"}, + ) + except Exception: + _logger.debug("NeMo Flow: LangGraph stream mark emission failed", exc_info=True) + + +def _json_safe(value: Any) -> nemo_flow.Json: + if value is None or isinstance(value, str | int | float | bool): + return value + if isinstance(value, dict): + return {str(key): _json_safe(item) for key, item in value.items()} + if isinstance(value, list | tuple | set): + return [_json_safe(item) for item in value] + return repr(value) + + +__all__ = ["NemoFlowGraph", "instrument_graph", "with_nemo_flow_callbacks"] diff --git a/python/tests/integrations/langchain/test_callbacks.py b/python/tests/integrations/langchain/test_callbacks.py index e95732d..4669bbb 100644 --- a/python/tests/integrations/langchain/test_callbacks.py +++ b/python/tests/integrations/langchain/test_callbacks.py @@ -49,6 +49,9 @@ def handler(mock_nemo_flow: MagicMock) -> NemoFlowCallbackHandler: class TestScopeLifecycle: """Verify that chain start/end/error map to scope push/pop.""" + def test_handler_runs_inline_for_async_callback_managers(self, handler: NemoFlowCallbackHandler) -> None: + assert handler.run_inline is True + def test_on_chain_start_pushes_scope(self, handler: NemoFlowCallbackHandler, mock_nemo_flow: MagicMock) -> None: run_id = uuid4() @@ -69,6 +72,20 @@ def test_on_chain_start_pushes_scope(self, handler: NemoFlowCallbackHandler, moc } assert run_id in handler._scope_handles + def test_on_chain_start_uses_callback_name( + self, handler: NemoFlowCallbackHandler, mock_nemo_flow: MagicMock + ) -> None: 
+ run_id = uuid4() + + handler.on_chain_start( + None, + {"input": "test"}, + run_id=run_id, + name="LangGraph", + ) + + assert mock_nemo_flow.scope.push.call_args.args[0] == "LangGraph" + def test_on_chain_end_pops_scope(self, handler: NemoFlowCallbackHandler, mock_nemo_flow: MagicMock) -> None: run_id = uuid4() handler.on_chain_start( diff --git a/python/tests/integrations/langgraph/test_langgraph_integration.py b/python/tests/integrations/langgraph/test_langgraph_integration.py new file mode 100644 index 0000000..54717a9 --- /dev/null +++ b/python/tests/integrations/langgraph/test_langgraph_integration.py @@ -0,0 +1,174 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the LangGraph NeMo Flow callback integration.""" + +from __future__ import annotations + +import asyncio +import tomllib +from pathlib import Path +from typing import Any +from uuid import uuid4 + +from langchain_core.callbacks import CallbackManager +from langgraph.callbacks import GraphCallbackHandler, GraphInterruptEvent, GraphResumeEvent +from langgraph.graph import END, START, StateGraph +from langgraph.types import Interrupt +from typing_extensions import TypedDict + +import nemo_flow +from nemo_flow.integrations.langchain import NemoFlowMiddleware +from nemo_flow.integrations.langchain.callbacks import NemoFlowCallbackHandler as LangChainCallbackHandler +from nemo_flow.integrations.langgraph import ( + NemoFlowCallbackHandler, + instrument_graph, + with_nemo_flow_callbacks, +) + + +class State(TypedDict): + value: int + + +def _build_graph() -> Any: + def increment(state: State) -> State: + return {"value": state["value"] + 1} + + builder = StateGraph(State) + builder.add_node("increment", increment) + builder.add_edge(START, "increment") + builder.add_edge("increment", END) + return builder.compile() + + +def _build_async_graph() -> Any: + async def increment(state: State) -> State: 
+ await asyncio.sleep(0) + return {"value": state["value"] + 1} + + builder = StateGraph(State) + builder.add_node("increment", increment) + builder.add_edge(START, "increment") + builder.add_edge("increment", END) + return builder.compile() + + +def _record_events() -> tuple[list[Any], str]: + events: list[Any] = [] + subscriber_name = f"langgraph-test-{uuid4()}" + nemo_flow.subscribers.register(subscriber_name, events.append) + return events, subscriber_name + + +def test_langgraph_handler_builds_on_langchain_handler() -> None: + handler = NemoFlowCallbackHandler() + + assert isinstance(handler, LangChainCallbackHandler) + assert isinstance(handler, GraphCallbackHandler) + assert handler.run_inline is True + + +def test_with_nemo_flow_callbacks_preserves_existing_callbacks() -> None: + class ExistingHandler(GraphCallbackHandler): + pass + + existing = ExistingHandler() + config = with_nemo_flow_callbacks({"callbacks": [existing]}) + + callbacks = config["callbacks"] + assert callbacks[0] is existing + assert isinstance(callbacks[1], NemoFlowCallbackHandler) + + +def test_with_nemo_flow_callbacks_handles_callback_managers() -> None: + manager = CallbackManager([]) + config = with_nemo_flow_callbacks({"callbacks": manager}) + + callbacks = config["callbacks"] + assert callbacks is not manager + assert any(isinstance(handler, NemoFlowCallbackHandler) for handler in callbacks.handlers) + + +def test_instrumented_graph_invoke_emits_named_graph_and_node_scopes() -> None: + graph = instrument_graph(_build_graph()) + events, subscriber_name = _record_events() + + try: + with nemo_flow.scope.scope("request", nemo_flow.ScopeType.Agent): + result = graph.invoke({"value": 1}) + finally: + nemo_flow.subscribers.deregister(subscriber_name) + + assert result == {"value": 2} + scope_names = [event.name for event in events if event.kind == "scope" and event.scope_category == "start"] + assert scope_names == ["request", "LangGraph", "increment"] + + +def 
test_instrumented_graph_ainvoke_completes_with_inline_callbacks() -> None: + graph = instrument_graph(_build_async_graph()) + + async def run_graph() -> dict[str, int]: + with nemo_flow.scope.scope("request", nemo_flow.ScopeType.Agent): + return await graph.ainvoke({"value": 1}) + + assert asyncio.run(run_graph()) == {"value": 2} + + +def test_graph_lifecycle_callbacks_emit_marks() -> None: + handler = NemoFlowCallbackHandler() + events, subscriber_name = _record_events() + run_id = uuid4() + + try: + with nemo_flow.scope.scope("request", nemo_flow.ScopeType.Agent): + handler.on_resume( + GraphResumeEvent( + run_id=run_id, + status="pending", + checkpoint_id="checkpoint-1", + checkpoint_ns=("parent", "child"), + ) + ) + handler.on_interrupt( + GraphInterruptEvent( + run_id=run_id, + status="interrupt_after", + checkpoint_id="checkpoint-2", + checkpoint_ns=("parent",), + interrupts=(Interrupt("needs approval", id="interrupt-1"),), + ) + ) + finally: + nemo_flow.subscribers.deregister(subscriber_name) + + marks = [event for event in events if event.kind == "mark"] + assert [event.name for event in marks] == ["Graph Resume", "Graph Interrupt"] + assert marks[0].data["checkpoint_ns"] == ["parent", "child"] + assert marks[1].data["interrupts"] == [{"id": "interrupt-1", "value": "needs approval"}] + assert marks[1].metadata == {"integration": "langgraph"} + + +def test_instrumented_graph_stream_emits_public_task_marks() -> None: + graph = instrument_graph(_build_graph()) + events, subscriber_name = _record_events() + + try: + with nemo_flow.scope.scope("request", nemo_flow.ScopeType.Agent): + chunks = list(graph.stream({"value": 1}, stream_mode="tasks")) + finally: + nemo_flow.subscribers.deregister(subscriber_name) + + assert len(chunks) == 2 + marks = [event.name for event in events if event.kind == "mark"] + assert marks == ["Graph Task Start", "Graph Task End"] + + +def test_langgraph_extra_includes_langchain_integration_dependencies() -> None: + pyproject = 
tomllib.loads(Path("pyproject.toml").read_text(encoding="utf-8")) + extra = pyproject["project"]["optional-dependencies"]["langgraph"] + + assert "langchain-core" in extra + assert any(requirement.startswith("langchain>=") for requirement in extra) + assert any(requirement.startswith("langgraph>=") for requirement in extra) + assert NemoFlowMiddleware is not None From f894208254bf8e0dbcdab288cb12f3a1287d86d4 Mon Sep 17 00:00:00 2001 From: David Gardner Date: Mon, 11 May 2026 14:32:11 -0700 Subject: [PATCH 3/7] Remove uneeded test for pyproject.toml entries Signed-off-by: David Gardner --- .../langgraph/test_langgraph_integration.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/python/tests/integrations/langgraph/test_langgraph_integration.py b/python/tests/integrations/langgraph/test_langgraph_integration.py index 54717a9..10bb07d 100644 --- a/python/tests/integrations/langgraph/test_langgraph_integration.py +++ b/python/tests/integrations/langgraph/test_langgraph_integration.py @@ -162,13 +162,3 @@ def test_instrumented_graph_stream_emits_public_task_marks() -> None: assert len(chunks) == 2 marks = [event.name for event in events if event.kind == "mark"] assert marks == ["Graph Task Start", "Graph Task End"] - - -def test_langgraph_extra_includes_langchain_integration_dependencies() -> None: - pyproject = tomllib.loads(Path("pyproject.toml").read_text(encoding="utf-8")) - extra = pyproject["project"]["optional-dependencies"]["langgraph"] - - assert "langchain-core" in extra - assert any(requirement.startswith("langchain>=") for requirement in extra) - assert any(requirement.startswith("langgraph>=") for requirement in extra) - assert NemoFlowMiddleware is not None From c37d2e376e1df7122b9a5e175d2947f45744fa81 Mon Sep 17 00:00:00 2001 From: David Gardner Date: Mon, 11 May 2026 15:50:16 -0700 Subject: [PATCH 4/7] Lingting Signed-off-by: David Gardner --- .../tests/integrations/langgraph/test_langgraph_integration.py | 3 --- 1 file changed, 3 
deletions(-) diff --git a/python/tests/integrations/langgraph/test_langgraph_integration.py b/python/tests/integrations/langgraph/test_langgraph_integration.py index 10bb07d..145d125 100644 --- a/python/tests/integrations/langgraph/test_langgraph_integration.py +++ b/python/tests/integrations/langgraph/test_langgraph_integration.py @@ -6,8 +6,6 @@ from __future__ import annotations import asyncio -import tomllib -from pathlib import Path from typing import Any from uuid import uuid4 @@ -18,7 +16,6 @@ from typing_extensions import TypedDict import nemo_flow -from nemo_flow.integrations.langchain import NemoFlowMiddleware from nemo_flow.integrations.langchain.callbacks import NemoFlowCallbackHandler as LangChainCallbackHandler from nemo_flow.integrations.langgraph import ( NemoFlowCallbackHandler, From 437fc0c2f82331bda070b76d51f8107582ab963a Mon Sep 17 00:00:00 2001 From: David Gardner Date: Mon, 11 May 2026 16:13:54 -0700 Subject: [PATCH 5/7] Remove unnecessary graph wrapper Signed-off-by: David Gardner --- .../integrations/langgraph/README.md | 18 +- .../integrations/langgraph/__init__.py | 4 - .../nemo_flow/integrations/langgraph/graph.py | 238 ------------------ .../langgraph/test_langgraph_integration.py | 55 +--- 4 files changed, 20 insertions(+), 295 deletions(-) delete mode 100644 python/nemo_flow/integrations/langgraph/graph.py diff --git a/python/nemo_flow/integrations/langgraph/README.md b/python/nemo_flow/integrations/langgraph/README.md index 5a15649..f8bcc7e 100644 --- a/python/nemo_flow/integrations/langgraph/README.md +++ b/python/nemo_flow/integrations/langgraph/README.md @@ -25,7 +25,7 @@ from typing_extensions import TypedDict import nemo_flow from langgraph.graph import END, START, StateGraph -from nemo_flow.integrations.langgraph import instrument_graph +from nemo_flow.integrations.langgraph import NemoFlowCallbackHandler class State(TypedDict): @@ -41,18 +41,22 @@ builder.add_node("increment", increment) builder.add_edge(START, "increment") 
builder.add_edge("increment", END) -graph = instrument_graph(builder.compile()) +graph = builder.compile() with nemo_flow.scope.scope("langgraph-request", nemo_flow.ScopeType.Agent): - result = graph.invoke({"value": 1}) + result = graph.invoke( + {"value": 1}, + config={"callbacks": [NemoFlowCallbackHandler()]}, + ) print(result) ``` -For LangChain agents inside a LangGraph workflow, use `NemoFlowMiddleware` from this package the same way as the LangChain integration: +For LangChain agents inside a LangGraph workflow, use `NemoFlowMiddleware` from this package the same way as the LangChain integration and pass the LangGraph `config` into the nested agent call: ```python from langchain.agents import create_agent +from langchain_core.runnables import RunnableConfig from nemo_flow.integrations.langgraph import NemoFlowMiddleware agent = create_agent( @@ -60,11 +64,15 @@ agent = create_agent( tools=[], middleware=[NemoFlowMiddleware()], ) + + +def agent_node(state: dict, config: RunnableConfig) -> dict: + return agent.invoke({"messages": state["messages"]}, config=config) ``` ## Public API Coverage -The public callback path records LangGraph graph and node runnable scopes through LangChain callbacks. LangGraph resume and interrupt lifecycle callbacks are emitted as NeMo Flow marks. When using `instrument_graph(...).stream(..., stream_mode="tasks")`, `stream_mode="checkpoints"`, or `stream_mode="debug"`, public stream events are also emitted as marks. +The public callback path records LangGraph graph and node runnable scopes through LangChain callbacks. LangGraph resume and interrupt lifecycle callbacks are emitted as NeMo Flow marks when LangGraph exposes those events through its public callback API. The patch-based integration in `patches/langgraph/0001-add-nemo-flow-integration.patch` can observe lower-level scheduler details such as internal supersteps, edge writes, and per-branch scope-stack isolation. 
Those details are not exposed by LangGraph's public callback API, so this package intentionally does not rely on them. diff --git a/python/nemo_flow/integrations/langgraph/__init__.py b/python/nemo_flow/integrations/langgraph/__init__.py index 8eb8b5d..d95db5b 100644 --- a/python/nemo_flow/integrations/langgraph/__init__.py +++ b/python/nemo_flow/integrations/langgraph/__init__.py @@ -5,12 +5,8 @@ from nemo_flow.integrations.langchain import NemoFlowMiddleware from nemo_flow.integrations.langgraph.callbacks import NemoFlowCallbackHandler -from nemo_flow.integrations.langgraph.graph import NemoFlowGraph, instrument_graph, with_nemo_flow_callbacks __all__ = [ "NemoFlowCallbackHandler", - "NemoFlowGraph", "NemoFlowMiddleware", - "instrument_graph", - "with_nemo_flow_callbacks", ] diff --git a/python/nemo_flow/integrations/langgraph/graph.py b/python/nemo_flow/integrations/langgraph/graph.py deleted file mode 100644 index 1674179..0000000 --- a/python/nemo_flow/integrations/langgraph/graph.py +++ /dev/null @@ -1,238 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -"""Helpers for applying NeMo Flow callbacks to compiled LangGraph graphs.""" - -from __future__ import annotations - -import logging -from collections.abc import AsyncIterator, Iterator, Sequence -from typing import Any - -from langchain_core.callbacks import BaseCallbackManager -from langchain_core.callbacks.base import BaseCallbackHandler -from langchain_core.runnables import RunnableConfig - -import nemo_flow -from nemo_flow.integrations.langgraph.callbacks import NemoFlowCallbackHandler - -_logger = logging.getLogger(__name__) - - -def with_nemo_flow_callbacks( - config: RunnableConfig | None = None, - *, - callback_handler: NemoFlowCallbackHandler | None = None, -) -> RunnableConfig: - """Return a LangChain runnable config with a LangGraph NeMo Flow callback. - - The returned config is a shallow copy. 
Existing callbacks are preserved and - a handler is added only when a NeMo Flow LangGraph handler is not already - present. - """ - next_config: RunnableConfig = dict(config or {}) - next_config["callbacks"] = _append_callback( - next_config.get("callbacks"), - callback_handler or NemoFlowCallbackHandler(), - ) - return next_config - - -def instrument_graph( - graph: Any, - *, - callback_handler: NemoFlowCallbackHandler | None = None, -) -> "NemoFlowGraph": - """Wrap a compiled LangGraph graph so invocations include NeMo Flow callbacks.""" - return NemoFlowGraph(graph, callback_handler=callback_handler) - - -class NemoFlowGraph: - """Thin proxy that injects NeMo Flow callbacks into graph invocations.""" - - def __init__( - self, - graph: Any, - *, - callback_handler: NemoFlowCallbackHandler | None = None, - ) -> None: - self._graph = graph - self._callback_handler = callback_handler - - def __getattr__(self, name: str) -> Any: - return getattr(self._graph, name) - - def with_config(self, config: RunnableConfig | None = None, **kwargs: Any) -> "NemoFlowGraph": - """Return an instrumented copy of the graph with updated LangChain config.""" - return NemoFlowGraph( - self._graph.with_config(config, **kwargs), - callback_handler=self._callback_handler, - ) - - def invoke( - self, - input: Any, - config: RunnableConfig | None = None, - **kwargs: Any, - ) -> Any: - """Invoke the wrapped graph with NeMo Flow callbacks installed.""" - return self._graph.invoke( - input, - self._config(config), - **kwargs, - ) - - async def ainvoke( - self, - input: Any, - config: RunnableConfig | None = None, - **kwargs: Any, - ) -> Any: - """Asynchronously invoke the wrapped graph with NeMo Flow callbacks installed.""" - return await self._graph.ainvoke( - input, - self._config(config), - **kwargs, - ) - - def stream( - self, - input: Any, - config: RunnableConfig | None = None, - **kwargs: Any, - ) -> Iterator[Any]: - """Stream from the wrapped graph and mark public task/checkpoint 
events.""" - for chunk in self._graph.stream( - input, - self._config(config), - **kwargs, - ): - _emit_public_stream_mark(chunk) - yield chunk - - async def astream( - self, - input: Any, - config: RunnableConfig | None = None, - **kwargs: Any, - ) -> AsyncIterator[Any]: - """Asynchronously stream from the wrapped graph and mark public events.""" - async for chunk in self._graph.astream( - input, - self._config(config), - **kwargs, - ): - _emit_public_stream_mark(chunk) - yield chunk - - def _config(self, config: RunnableConfig | None) -> RunnableConfig: - return with_nemo_flow_callbacks(config, callback_handler=self._callback_handler) - - -def _append_callback(callbacks: Any, handler: BaseCallbackHandler) -> Any: - if callbacks is None: - return [handler] - - if isinstance(callbacks, NemoFlowCallbackHandler): - return callbacks - - if isinstance(callbacks, BaseCallbackManager): - if _manager_has_nemo_flow_handler(callbacks): - return callbacks - manager = callbacks.copy() - manager.add_handler(handler, inherit=True) - return manager - - if isinstance(callbacks, BaseCallbackHandler): - if isinstance(callbacks, NemoFlowCallbackHandler): - return callbacks - return [callbacks, handler] - - if isinstance(callbacks, Sequence) and not isinstance(callbacks, str | bytes): - callback_list = list(callbacks) - if any(isinstance(callback, NemoFlowCallbackHandler) for callback in callback_list): - return callback_list - return [*callback_list, handler] - - return callbacks - - -def _manager_has_nemo_flow_handler(manager: BaseCallbackManager) -> bool: - handlers = [*manager.handlers, *manager.inheritable_handlers] - return any(isinstance(handler, NemoFlowCallbackHandler) for handler in handlers) - - -def _emit_public_stream_mark(chunk: Any) -> None: - mode, payload = _stream_mode_payload(chunk) - if mode == "tasks": - _emit_task_mark(payload) - elif mode == "checkpoints": - _emit_mark("Checkpoint Save", payload) - elif mode == "debug": - _emit_debug_mark(payload) - - -def 
_stream_mode_payload(chunk: Any) -> tuple[str | None, Any]: - if isinstance(chunk, dict) and isinstance(chunk.get("type"), str) and "data" in chunk: - return chunk["type"], chunk["data"] - if isinstance(chunk, tuple): - if len(chunk) == 2 and isinstance(chunk[0], str): - return chunk[0], chunk[1] - if len(chunk) == 3 and isinstance(chunk[1], str): - return chunk[1], chunk[2] - if isinstance(chunk, dict): - if "type" in chunk and "payload" in chunk: - return "debug", chunk - if {"id", "name", "triggers"}.issubset(chunk): - return "tasks", chunk - if {"id", "name", "result"}.issubset(chunk) or {"id", "name", "error"}.issubset(chunk): - return "tasks", chunk - if {"config", "metadata", "values", "next", "tasks"}.issubset(chunk): - return "checkpoints", chunk - return None, None - - -def _emit_debug_mark(payload: Any) -> None: - if not isinstance(payload, dict): - return - event_type = payload.get("type") - event_payload = payload.get("payload") - if event_type == "task": - _emit_mark("Graph Task Start", event_payload) - elif event_type == "task_result": - _emit_mark("Graph Task End", event_payload) - elif event_type == "checkpoint": - _emit_mark("Checkpoint Save", event_payload) - - -def _emit_task_mark(payload: Any) -> None: - if not isinstance(payload, dict): - return - if "triggers" in payload: - _emit_mark("Graph Task Start", payload) - elif "result" in payload or "error" in payload: - _emit_mark("Graph Task End", payload) - - -def _emit_mark(name: str, payload: Any) -> None: - try: - nemo_flow.scope.event( - name, - data=_json_safe(payload), - metadata={"integration": "langgraph"}, - ) - except Exception: - _logger.debug("NeMo Flow: LangGraph stream mark emission failed", exc_info=True) - - -def _json_safe(value: Any) -> nemo_flow.Json: - if value is None or isinstance(value, str | int | float | bool): - return value - if isinstance(value, dict): - return {str(key): _json_safe(item) for key, item in value.items()} - if isinstance(value, list | tuple | set): - 
return [_json_safe(item) for item in value] - return repr(value) - - -__all__ = ["NemoFlowGraph", "instrument_graph", "with_nemo_flow_callbacks"] diff --git a/python/tests/integrations/langgraph/test_langgraph_integration.py b/python/tests/integrations/langgraph/test_langgraph_integration.py index 145d125..b0f9031 100644 --- a/python/tests/integrations/langgraph/test_langgraph_integration.py +++ b/python/tests/integrations/langgraph/test_langgraph_integration.py @@ -9,7 +9,6 @@ from typing import Any from uuid import uuid4 -from langchain_core.callbacks import CallbackManager from langgraph.callbacks import GraphCallbackHandler, GraphInterruptEvent, GraphResumeEvent from langgraph.graph import END, START, StateGraph from langgraph.types import Interrupt @@ -17,11 +16,7 @@ import nemo_flow from nemo_flow.integrations.langchain.callbacks import NemoFlowCallbackHandler as LangChainCallbackHandler -from nemo_flow.integrations.langgraph import ( - NemoFlowCallbackHandler, - instrument_graph, - with_nemo_flow_callbacks, -) +from nemo_flow.integrations.langgraph import NemoFlowCallbackHandler class State(TypedDict): @@ -66,34 +61,13 @@ def test_langgraph_handler_builds_on_langchain_handler() -> None: assert handler.run_inline is True -def test_with_nemo_flow_callbacks_preserves_existing_callbacks() -> None: - class ExistingHandler(GraphCallbackHandler): - pass - - existing = ExistingHandler() - config = with_nemo_flow_callbacks({"callbacks": [existing]}) - - callbacks = config["callbacks"] - assert callbacks[0] is existing - assert isinstance(callbacks[1], NemoFlowCallbackHandler) - - -def test_with_nemo_flow_callbacks_handles_callback_managers() -> None: - manager = CallbackManager([]) - config = with_nemo_flow_callbacks({"callbacks": manager}) - - callbacks = config["callbacks"] - assert callbacks is not manager - assert any(isinstance(handler, NemoFlowCallbackHandler) for handler in callbacks.handlers) - - -def 
test_instrumented_graph_invoke_emits_named_graph_and_node_scopes() -> None: - graph = instrument_graph(_build_graph()) +def test_graph_invoke_with_callback_config_emits_named_graph_and_node_scopes() -> None: + graph = _build_graph() events, subscriber_name = _record_events() try: with nemo_flow.scope.scope("request", nemo_flow.ScopeType.Agent): - result = graph.invoke({"value": 1}) + result = graph.invoke({"value": 1}, config={"callbacks": [NemoFlowCallbackHandler()]}) finally: nemo_flow.subscribers.deregister(subscriber_name) @@ -102,12 +76,12 @@ def test_instrumented_graph_invoke_emits_named_graph_and_node_scopes() -> None: assert scope_names == ["request", "LangGraph", "increment"] -def test_instrumented_graph_ainvoke_completes_with_inline_callbacks() -> None: - graph = instrument_graph(_build_async_graph()) +def test_graph_ainvoke_with_callback_config_completes() -> None: + graph = _build_async_graph() async def run_graph() -> dict[str, int]: with nemo_flow.scope.scope("request", nemo_flow.ScopeType.Agent): - return await graph.ainvoke({"value": 1}) + return await graph.ainvoke({"value": 1}, config={"callbacks": [NemoFlowCallbackHandler()]}) assert asyncio.run(run_graph()) == {"value": 2} @@ -144,18 +118,3 @@ def test_graph_lifecycle_callbacks_emit_marks() -> None: assert marks[0].data["checkpoint_ns"] == ["parent", "child"] assert marks[1].data["interrupts"] == [{"id": "interrupt-1", "value": "needs approval"}] assert marks[1].metadata == {"integration": "langgraph"} - - -def test_instrumented_graph_stream_emits_public_task_marks() -> None: - graph = instrument_graph(_build_graph()) - events, subscriber_name = _record_events() - - try: - with nemo_flow.scope.scope("request", nemo_flow.ScopeType.Agent): - chunks = list(graph.stream({"value": 1}, stream_mode="tasks")) - finally: - nemo_flow.subscribers.deregister(subscriber_name) - - assert len(chunks) == 2 - marks = [event.name for event in events if event.kind == "mark"] - assert marks == ["Graph Task Start", 
"Graph Task End"] From cef4ef855f2a6162861752ee5cf810a2940e03c1 Mon Sep 17 00:00:00 2001 From: David Gardner Date: Mon, 11 May 2026 16:38:45 -0700 Subject: [PATCH 6/7] Update docstring Signed-off-by: David Gardner --- python/nemo_flow/integrations/langgraph/callbacks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/nemo_flow/integrations/langgraph/callbacks.py b/python/nemo_flow/integrations/langgraph/callbacks.py index 1fbed19..6c069a1 100644 --- a/python/nemo_flow/integrations/langgraph/callbacks.py +++ b/python/nemo_flow/integrations/langgraph/callbacks.py @@ -41,7 +41,8 @@ def _interrupt_to_payload(interrupt: Any) -> dict[str, nemo_flow.Json]: class NemoFlowCallbackHandler(LangChainNemoFlowCallbackHandler, GraphCallbackHandler): - """Bridge LangGraph runs to NeMo Flow using public callback APIs. + """ + Bridge LangChain and LangGraph runs to NeMo Flow using public callback APIs. This handler inherits the existing LangChain callback integration, so normal runnable scopes from LangGraph and LangChain are recorded by the same code From ea88adc9cf65b1e6e91e12c4c2cbc97193fc302f Mon Sep 17 00:00:00 2001 From: David Gardner Date: Mon, 11 May 2026 16:46:58 -0700 Subject: [PATCH 7/7] Cleanup wording Signed-off-by: David Gardner --- python/nemo_flow/integrations/langgraph/README.md | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/python/nemo_flow/integrations/langgraph/README.md b/python/nemo_flow/integrations/langgraph/README.md index f8bcc7e..c54ca15 100644 --- a/python/nemo_flow/integrations/langgraph/README.md +++ b/python/nemo_flow/integrations/langgraph/README.md @@ -9,6 +9,8 @@ This directory contains the `nemo_flow.integrations.langgraph` package, which pr The integration builds on `nemo_flow.integrations.langchain`: `NemoFlowCallbackHandler` inherits the LangChain callback handler, and `NemoFlowMiddleware` is re-exported for LangChain agents used inside LangGraph workflows. 
+For an alternate approach, refer to [the patch-based integration in `third_party/langgraph`](../../../../third_party/README-langgraph.md). + ## Setup ```bash @@ -70,12 +72,6 @@ def agent_node(state: dict, config: RunnableConfig) -> dict: return agent.invoke({"messages": state["messages"]}, config=config) ``` -## Public API Coverage - -The public callback path records LangGraph graph and node runnable scopes through LangChain callbacks. LangGraph resume and interrupt lifecycle callbacks are emitted as NeMo Flow marks when LangGraph exposes those events through its public callback API. - -The patch-based integration in `patches/langgraph/0001-add-nemo-flow-integration.patch` can observe lower-level scheduler details such as internal supersteps, edge writes, and per-branch scope-stack isolation. Those details are not exposed by LangGraph's public callback API, so this package intentionally does not rely on them. - ## Validation ```bash