Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,13 @@ langchain = [
"langchain-core",
]

langgraph = [
"nemo-flow[langchain]",
"langgraph>=1.1.10,<2.0.0",
]

langchain-nvidia = [
"langchain>=1.2.11,<2.0.0",
"langchain-core",
"nemo-flow[langchain]",
"langchain-nvidia-ai-endpoints~=1.0",
]

Expand Down
7 changes: 5 additions & 2 deletions python/nemo_flow/integrations/langchain/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
class NemoFlowCallbackHandler(BaseCallbackHandler):
"""Bridge LangChain chain run IDs to NeMo Flow Agent scopes."""

run_inline = True

def __init__(self) -> None:
super().__init__()
self._scope_handles: dict[UUID, typing.Any] = {}
Expand All @@ -39,9 +41,10 @@ def on_chain_start(
) -> typing.Any:
"""Push a NeMo Flow Agent scope for a LangChain chain run."""
try:
name: str | None = None
name = kwargs.get("name")

if serialized is not None:
name = serialized.get("name")
name = name or serialized.get("name")
if name is None:
id_list = serialized.get("id")
if isinstance(id_list, list) and len(id_list) > 0:
Expand Down
79 changes: 79 additions & 0 deletions python/nemo_flow/integrations/langgraph/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
<!--
SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
-->

# NeMo Flow LangGraph Integration

This directory contains the `nemo_flow.integrations.langgraph` package, which provides public-API LangGraph integration for NeMo Flow.

The integration builds on `nemo_flow.integrations.langchain`: `NemoFlowCallbackHandler` inherits the LangChain callback handler, and `NemoFlowMiddleware` is re-exported for LangChain agents used inside LangGraph workflows.

For an alternate approach refer to [the patch-based integration in `third_party/langchain`](../../../../third_party/README-langgraph.md).

## Setup

```bash
uv sync --all-groups --extra langgraph
just build-python
```

Installing the `langgraph` extra also installs the LangChain integration dependencies.

## Usage Example

```python
from typing_extensions import TypedDict

import nemo_flow
from langgraph.graph import END, START, StateGraph
from nemo_flow.integrations.langgraph import NemoFlowCallbackHandler


class State(TypedDict):
value: int


def increment(state: State) -> State:
return {"value": state["value"] + 1}


builder = StateGraph(State)
builder.add_node("increment", increment)
builder.add_edge(START, "increment")
builder.add_edge("increment", END)

graph = builder.compile()

with nemo_flow.scope.scope("langgraph-request", nemo_flow.ScopeType.Agent):
result = graph.invoke(
{"value": 1},
config={"callbacks": [NemoFlowCallbackHandler()]},
)

print(result)
```

For LangChain agents inside a LangGraph workflow, use `NemoFlowMiddleware` from this package the same way as the LangChain integration and pass the LangGraph `config` into the nested agent call:

```python
from langchain.agents import create_agent
from langchain_core.runnables import RunnableConfig
from nemo_flow.integrations.langgraph import NemoFlowMiddleware

agent = create_agent(
model="nvidia:nvidia/nemotron-3-nano-30b-a3b",
tools=[],
middleware=[NemoFlowMiddleware()],
)


def agent_node(state: dict, config: RunnableConfig) -> dict:
return agent.invoke({"messages": state["messages"]}, config=config)
```

## Validation

```bash
uv run pytest python/tests/integrations/langgraph
```
12 changes: 12 additions & 0 deletions python/nemo_flow/integrations/langgraph/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""NeMo Flow integrations for LangGraph."""

from nemo_flow.integrations.langchain import NemoFlowMiddleware
from nemo_flow.integrations.langgraph.callbacks import NemoFlowCallbackHandler

__all__ = [
"NemoFlowCallbackHandler",
"NemoFlowMiddleware",
]
91 changes: 91 additions & 0 deletions python/nemo_flow/integrations/langgraph/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""LangGraph callback handler that reuses the LangChain NeMo Flow integration."""

from __future__ import annotations

import logging
from typing import Any

from langgraph.callbacks import GraphCallbackHandler, GraphInterruptEvent, GraphResumeEvent

import nemo_flow
from nemo_flow.integrations.langchain._serialization import _prepare_outputs
from nemo_flow.integrations.langchain.callbacks import NemoFlowCallbackHandler as LangChainNemoFlowCallbackHandler

_logger = logging.getLogger(__name__)


def _json_safe(value: Any) -> nemo_flow.Json:
"""Return a conservative JSON-compatible representation for mark payloads."""
try:
value = _prepare_outputs(value)
except Exception:
pass

if value is None or isinstance(value, str | int | float | bool):
return value
if isinstance(value, dict):
return {str(key): _json_safe(item) for key, item in value.items()}
if isinstance(value, list | tuple | set):
return [_json_safe(item) for item in value]
return repr(value)


def _interrupt_to_payload(interrupt: Any) -> dict[str, nemo_flow.Json]:
return {
"id": _json_safe(getattr(interrupt, "id", None)),
"value": _json_safe(getattr(interrupt, "value", interrupt)),
}


class NemoFlowCallbackHandler(LangChainNemoFlowCallbackHandler, GraphCallbackHandler):
"""
Bridge LangChain and LangGraph runs to NeMo Flow using public callback APIs.

This handler inherits the existing LangChain callback integration, so normal
runnable scopes from LangGraph and LangChain are recorded by the same code
path. It also implements LangGraph's public lifecycle callback hooks for
interrupt and resume marks.
"""

def on_interrupt(self, event: GraphInterruptEvent) -> Any:
"""Emit a NeMo Flow mark for a LangGraph interrupt lifecycle event."""
self._emit_graph_mark(
"Graph Interrupt",
{
"run_id": str(event.run_id) if event.run_id is not None else None,
"status": event.status,
"checkpoint_id": event.checkpoint_id,
"checkpoint_ns": list(event.checkpoint_ns),
"interrupts": [_interrupt_to_payload(interrupt) for interrupt in event.interrupts],
},
)
return None

def on_resume(self, event: GraphResumeEvent) -> Any:
"""Emit a NeMo Flow mark for a LangGraph resume lifecycle event."""
self._emit_graph_mark(
"Graph Resume",
{
"run_id": str(event.run_id) if event.run_id is not None else None,
"status": event.status,
"checkpoint_id": event.checkpoint_id,
"checkpoint_ns": list(event.checkpoint_ns),
},
)
return None

def _emit_graph_mark(self, name: str, data: dict[str, Any]) -> None:
try:
nemo_flow.scope.event(
name,
data=_json_safe(data),
metadata={"integration": "langgraph"},
)
except Exception:
_logger.debug("NeMo Flow: LangGraph mark emission failed", exc_info=True)


__all__ = ["NemoFlowCallbackHandler"]
17 changes: 17 additions & 0 deletions python/tests/integrations/langchain/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@
class TestScopeLifecycle:
"""Verify that chain start/end/error map to scope push/pop."""

def test_handler_runs_inline_for_async_callback_managers(self, handler: NemoFlowCallbackHandler) -> None:
assert handler.run_inline is True

def test_on_chain_start_pushes_scope(self, handler: NemoFlowCallbackHandler, mock_nemo_flow: MagicMock) -> None:
run_id = uuid4()

Expand All @@ -69,6 +72,20 @@
}
assert run_id in handler._scope_handles

def test_on_chain_start_uses_callback_name(
self, handler: NemoFlowCallbackHandler, mock_nemo_flow: MagicMock
) -> None:
run_id = uuid4()

handler.on_chain_start(
None,

Check failure on line 81 in python/tests/integrations/langchain/test_callbacks.py

View workflow job for this annotation

GitHub Actions / Check / Run

Argument to bound method `on_chain_start` is incorrect
{"input": "test"},
run_id=run_id,
name="LangGraph",
)

assert mock_nemo_flow.scope.push.call_args.args[0] == "LangGraph"

def test_on_chain_end_pops_scope(self, handler: NemoFlowCallbackHandler, mock_nemo_flow: MagicMock) -> None:
run_id = uuid4()
handler.on_chain_start(
Expand Down
120 changes: 120 additions & 0 deletions python/tests/integrations/langgraph/test_langgraph_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Tests for the LangGraph NeMo Flow callback integration."""

from __future__ import annotations

import asyncio
from typing import Any
from uuid import uuid4

from langgraph.callbacks import GraphCallbackHandler, GraphInterruptEvent, GraphResumeEvent
from langgraph.graph import END, START, StateGraph
from langgraph.types import Interrupt
from typing_extensions import TypedDict

import nemo_flow
from nemo_flow.integrations.langchain.callbacks import NemoFlowCallbackHandler as LangChainCallbackHandler
from nemo_flow.integrations.langgraph import NemoFlowCallbackHandler


class State(TypedDict):
value: int


def _build_graph() -> Any:
def increment(state: State) -> State:
return {"value": state["value"] + 1}

builder = StateGraph(State)
builder.add_node("increment", increment)
builder.add_edge(START, "increment")
builder.add_edge("increment", END)
return builder.compile()


def _build_async_graph() -> Any:
async def increment(state: State) -> State:
await asyncio.sleep(0)
return {"value": state["value"] + 1}

builder = StateGraph(State)
builder.add_node("increment", increment)
builder.add_edge(START, "increment")
builder.add_edge("increment", END)
return builder.compile()


def _record_events() -> tuple[list[Any], str]:
events: list[Any] = []
subscriber_name = f"langgraph-test-{uuid4()}"
nemo_flow.subscribers.register(subscriber_name, events.append)
return events, subscriber_name


def test_langgraph_handler_builds_on_langchain_handler() -> None:
handler = NemoFlowCallbackHandler()

assert isinstance(handler, LangChainCallbackHandler)
assert isinstance(handler, GraphCallbackHandler)
assert handler.run_inline is True


def test_graph_invoke_with_callback_config_emits_named_graph_and_node_scopes() -> None:
graph = _build_graph()
events, subscriber_name = _record_events()

try:
with nemo_flow.scope.scope("request", nemo_flow.ScopeType.Agent):
result = graph.invoke({"value": 1}, config={"callbacks": [NemoFlowCallbackHandler()]})
finally:
nemo_flow.subscribers.deregister(subscriber_name)

assert result == {"value": 2}
scope_names = [event.name for event in events if event.kind == "scope" and event.scope_category == "start"]
assert scope_names == ["request", "LangGraph", "increment"]


def test_graph_ainvoke_with_callback_config_completes() -> None:
graph = _build_async_graph()

async def run_graph() -> dict[str, int]:
with nemo_flow.scope.scope("request", nemo_flow.ScopeType.Agent):
return await graph.ainvoke({"value": 1}, config={"callbacks": [NemoFlowCallbackHandler()]})

assert asyncio.run(run_graph()) == {"value": 2}


def test_graph_lifecycle_callbacks_emit_marks() -> None:
handler = NemoFlowCallbackHandler()
events, subscriber_name = _record_events()
run_id = uuid4()

try:
with nemo_flow.scope.scope("request", nemo_flow.ScopeType.Agent):
handler.on_resume(
GraphResumeEvent(
run_id=run_id,
status="pending",
checkpoint_id="checkpoint-1",
checkpoint_ns=("parent", "child"),
)
)
handler.on_interrupt(
GraphInterruptEvent(
run_id=run_id,
status="interrupt_after",
checkpoint_id="checkpoint-2",
checkpoint_ns=("parent",),
interrupts=(Interrupt("needs approval", id="interrupt-1"),),
)
)
finally:
nemo_flow.subscribers.deregister(subscriber_name)

marks = [event for event in events if event.kind == "mark"]
assert [event.name for event in marks] == ["Graph Resume", "Graph Interrupt"]
assert marks[0].data["checkpoint_ns"] == ["parent", "child"]
assert marks[1].data["interrupts"] == [{"id": "interrupt-1", "value": "needs approval"}]
assert marks[1].metadata == {"integration": "langgraph"}
Loading
Loading