diff --git a/pyproject.toml b/pyproject.toml index 7adcc8c..2414cd5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,9 +79,13 @@ langchain = [ "langchain-core", ] +langgraph = [ + "nemo-flow[langchain]", + "langgraph>=1.1.10,<2.0.0", +] + langchain-nvidia = [ - "langchain>=1.2.11,<2.0.0", - "langchain-core", + "nemo-flow[langchain]", "langchain-nvidia-ai-endpoints~=1.0", ] diff --git a/python/nemo_flow/integrations/langchain/callbacks.py b/python/nemo_flow/integrations/langchain/callbacks.py index 42082b7..1b2e343 100644 --- a/python/nemo_flow/integrations/langchain/callbacks.py +++ b/python/nemo_flow/integrations/langchain/callbacks.py @@ -22,6 +22,8 @@ class NemoFlowCallbackHandler(BaseCallbackHandler): """Bridge LangChain chain run IDs to NeMo Flow Agent scopes.""" + run_inline = True + def __init__(self) -> None: super().__init__() self._scope_handles: dict[UUID, typing.Any] = {} @@ -39,9 +41,10 @@ def on_chain_start( ) -> typing.Any: """Push a NeMo Flow Agent scope for a LangChain chain run.""" try: - name: str | None = None + name = kwargs.get("name") + if serialized is not None: - name = serialized.get("name") + name = name or serialized.get("name") if name is None: id_list = serialized.get("id") if isinstance(id_list, list) and len(id_list) > 0: diff --git a/python/nemo_flow/integrations/langgraph/README.md b/python/nemo_flow/integrations/langgraph/README.md new file mode 100644 index 0000000..c54ca15 --- /dev/null +++ b/python/nemo_flow/integrations/langgraph/README.md @@ -0,0 +1,79 @@ + + +# NeMo Flow LangGraph Integration + +This directory contains the `nemo_flow.integrations.langgraph` package, which provides public-API LangGraph integration for NeMo Flow. + +The integration builds on `nemo_flow.integrations.langchain`: `NemoFlowCallbackHandler` inherits the LangChain callback handler, and `NemoFlowMiddleware` is re-exported for LangChain agents used inside LangGraph workflows. + +For an alternate approach refer to [the patch-based integration in `third_party/langchain`](../../../../third_party/README-langgraph.md). + +## Setup + +```bash +uv sync --all-groups --extra langgraph +just build-python +``` + +Installing the `langgraph` extra also installs the LangChain integration dependencies. + +## Usage Example + +```python +from typing_extensions import TypedDict + +import nemo_flow +from langgraph.graph import END, START, StateGraph +from nemo_flow.integrations.langgraph import NemoFlowCallbackHandler + + +class State(TypedDict): + value: int + + +def increment(state: State) -> State: + return {"value": state["value"] + 1} + + +builder = StateGraph(State) +builder.add_node("increment", increment) +builder.add_edge(START, "increment") +builder.add_edge("increment", END) + +graph = builder.compile() + +with nemo_flow.scope.scope("langgraph-request", nemo_flow.ScopeType.Agent): + result = graph.invoke( + {"value": 1}, + config={"callbacks": [NemoFlowCallbackHandler()]}, + ) + +print(result) +``` + +For LangChain agents inside a LangGraph workflow, use `NemoFlowMiddleware` from this package the same way as the LangChain integration and pass the LangGraph `config` into the nested agent call: + +```python +from langchain.agents import create_agent +from langchain_core.runnables import RunnableConfig +from nemo_flow.integrations.langgraph import NemoFlowMiddleware + +agent = create_agent( + model="nvidia:nvidia/nemotron-3-nano-30b-a3b", + tools=[], + middleware=[NemoFlowMiddleware()], +) + + +def agent_node(state: dict, config: RunnableConfig) -> dict: + return agent.invoke({"messages": state["messages"]}, config=config) +``` + +## Validation + +```bash +uv run pytest python/tests/integrations/langgraph +``` diff --git a/python/nemo_flow/integrations/langgraph/__init__.py b/python/nemo_flow/integrations/langgraph/__init__.py new file mode 100644 index 0000000..d95db5b --- /dev/null +++ b/python/nemo_flow/integrations/langgraph/__init__.py @@ -0,0 +1,12 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""NeMo Flow integrations for LangGraph.""" + +from nemo_flow.integrations.langchain import NemoFlowMiddleware +from nemo_flow.integrations.langgraph.callbacks import NemoFlowCallbackHandler + +__all__ = [ + "NemoFlowCallbackHandler", + "NemoFlowMiddleware", +] diff --git a/python/nemo_flow/integrations/langgraph/callbacks.py b/python/nemo_flow/integrations/langgraph/callbacks.py new file mode 100644 index 0000000..6c069a1 --- /dev/null +++ b/python/nemo_flow/integrations/langgraph/callbacks.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""LangGraph callback handler that reuses the LangChain NeMo Flow integration.""" + +from __future__ import annotations + +import logging +from typing import Any + +from langgraph.callbacks import GraphCallbackHandler, GraphInterruptEvent, GraphResumeEvent + +import nemo_flow +from nemo_flow.integrations.langchain._serialization import _prepare_outputs +from nemo_flow.integrations.langchain.callbacks import NemoFlowCallbackHandler as LangChainNemoFlowCallbackHandler + +_logger = logging.getLogger(__name__) + + +def _json_safe(value: Any) -> nemo_flow.Json: + """Return a conservative JSON-compatible representation for mark payloads.""" + try: + value = _prepare_outputs(value) + except Exception: + pass + + if value is None or isinstance(value, str | int | float | bool): + return value + if isinstance(value, dict): + return {str(key): _json_safe(item) for key, item in value.items()} + if isinstance(value, list | tuple | set): + return [_json_safe(item) for item in value] + return repr(value) + + +def _interrupt_to_payload(interrupt: Any) -> dict[str, nemo_flow.Json]: + return { + "id": _json_safe(getattr(interrupt, "id", None)), + "value": _json_safe(getattr(interrupt, "value", interrupt)), + } + + +class NemoFlowCallbackHandler(LangChainNemoFlowCallbackHandler, GraphCallbackHandler): + """ + Bridge LangChain and LangGraph runs to NeMo Flow using public callback APIs. + + This handler inherits the existing LangChain callback integration, so normal + runnable scopes from LangGraph and LangChain are recorded by the same code + path. It also implements LangGraph's public lifecycle callback hooks for + interrupt and resume marks. + """ + + def on_interrupt(self, event: GraphInterruptEvent) -> Any: + """Emit a NeMo Flow mark for a LangGraph interrupt lifecycle event.""" + self._emit_graph_mark( + "Graph Interrupt", + { + "run_id": str(event.run_id) if event.run_id is not None else None, + "status": event.status, + "checkpoint_id": event.checkpoint_id, + "checkpoint_ns": list(event.checkpoint_ns), + "interrupts": [_interrupt_to_payload(interrupt) for interrupt in event.interrupts], + }, + ) + return None + + def on_resume(self, event: GraphResumeEvent) -> Any: + """Emit a NeMo Flow mark for a LangGraph resume lifecycle event.""" + self._emit_graph_mark( + "Graph Resume", + { + "run_id": str(event.run_id) if event.run_id is not None else None, + "status": event.status, + "checkpoint_id": event.checkpoint_id, + "checkpoint_ns": list(event.checkpoint_ns), + }, + ) + return None + + def _emit_graph_mark(self, name: str, data: dict[str, Any]) -> None: + try: + nemo_flow.scope.event( + name, + data=_json_safe(data), + metadata={"integration": "langgraph"}, + ) + except Exception: + _logger.debug("NeMo Flow: LangGraph mark emission failed", exc_info=True) + + +__all__ = ["NemoFlowCallbackHandler"] diff --git a/python/tests/integrations/langchain/test_callbacks.py b/python/tests/integrations/langchain/test_callbacks.py index e95732d..4669bbb 100644 --- a/python/tests/integrations/langchain/test_callbacks.py +++ b/python/tests/integrations/langchain/test_callbacks.py @@ -49,6 +49,9 @@ def handler(mock_nemo_flow: MagicMock) -> NemoFlowCallbackHandler: class TestScopeLifecycle: """Verify that chain start/end/error map to scope push/pop.""" + def test_handler_runs_inline_for_async_callback_managers(self, handler: NemoFlowCallbackHandler) -> None: + assert handler.run_inline is True + def test_on_chain_start_pushes_scope(self, handler: NemoFlowCallbackHandler, mock_nemo_flow: MagicMock) -> None: run_id = uuid4() @@ -69,6 +72,20 @@ def test_on_chain_start_pushes_scope(self, handler: NemoFlowCallbackHandler, moc } assert run_id in handler._scope_handles + def test_on_chain_start_uses_callback_name( + self, handler: NemoFlowCallbackHandler, mock_nemo_flow: MagicMock + ) -> None: + run_id = uuid4() + + handler.on_chain_start( + None, + {"input": "test"}, + run_id=run_id, + name="LangGraph", + ) + + assert mock_nemo_flow.scope.push.call_args.args[0] == "LangGraph" + def test_on_chain_end_pops_scope(self, handler: NemoFlowCallbackHandler, mock_nemo_flow: MagicMock) -> None: run_id = uuid4() handler.on_chain_start( diff --git a/python/tests/integrations/langgraph/test_langgraph_integration.py b/python/tests/integrations/langgraph/test_langgraph_integration.py new file mode 100644 index 0000000..b0f9031 --- /dev/null +++ b/python/tests/integrations/langgraph/test_langgraph_integration.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the LangGraph NeMo Flow callback integration.""" + +from __future__ import annotations + +import asyncio +from typing import Any +from uuid import uuid4 + +from langgraph.callbacks import GraphCallbackHandler, GraphInterruptEvent, GraphResumeEvent +from langgraph.graph import END, START, StateGraph +from langgraph.types import Interrupt +from typing_extensions import TypedDict + +import nemo_flow +from nemo_flow.integrations.langchain.callbacks import NemoFlowCallbackHandler as LangChainCallbackHandler +from nemo_flow.integrations.langgraph import NemoFlowCallbackHandler + + +class State(TypedDict): + value: int + + +def _build_graph() -> Any: + def increment(state: State) -> State: + return {"value": state["value"] + 1} + + builder = StateGraph(State) + builder.add_node("increment", increment) + builder.add_edge(START, "increment") + builder.add_edge("increment", END) + return builder.compile() + + +def _build_async_graph() -> Any: + async def increment(state: State) -> State: + await asyncio.sleep(0) + return {"value": state["value"] + 1} + + builder = StateGraph(State) + builder.add_node("increment", increment) + builder.add_edge(START, "increment") + builder.add_edge("increment", END) + return builder.compile() + + +def _record_events() -> tuple[list[Any], str]: + events: list[Any] = [] + subscriber_name = f"langgraph-test-{uuid4()}" + nemo_flow.subscribers.register(subscriber_name, events.append) + return events, subscriber_name + + +def test_langgraph_handler_builds_on_langchain_handler() -> None: + handler = NemoFlowCallbackHandler() + + assert isinstance(handler, LangChainCallbackHandler) + assert isinstance(handler, GraphCallbackHandler) + assert handler.run_inline is True + + +def test_graph_invoke_with_callback_config_emits_named_graph_and_node_scopes() -> None: + graph = _build_graph() + events, subscriber_name = _record_events() + + try: + with nemo_flow.scope.scope("request", nemo_flow.ScopeType.Agent): + result = graph.invoke({"value": 1}, config={"callbacks": [NemoFlowCallbackHandler()]}) + finally: + nemo_flow.subscribers.deregister(subscriber_name) + + assert result == {"value": 2} + scope_names = [event.name for event in events if event.kind == "scope" and event.scope_category == "start"] + assert scope_names == ["request", "LangGraph", "increment"] + + +def test_graph_ainvoke_with_callback_config_completes() -> None: + graph = _build_async_graph() + + async def run_graph() -> dict[str, int]: + with nemo_flow.scope.scope("request", nemo_flow.ScopeType.Agent): + return await graph.ainvoke({"value": 1}, config={"callbacks": [NemoFlowCallbackHandler()]}) + + assert asyncio.run(run_graph()) == {"value": 2} + + +def test_graph_lifecycle_callbacks_emit_marks() -> None: + handler = NemoFlowCallbackHandler() + events, subscriber_name = _record_events() + run_id = uuid4() + + try: + with nemo_flow.scope.scope("request", nemo_flow.ScopeType.Agent): + handler.on_resume( + GraphResumeEvent( + run_id=run_id, + status="pending", + checkpoint_id="checkpoint-1", + checkpoint_ns=("parent", "child"), + ) + ) + handler.on_interrupt( + GraphInterruptEvent( + run_id=run_id, + status="interrupt_after", + checkpoint_id="checkpoint-2", + checkpoint_ns=("parent",), + interrupts=(Interrupt("needs approval", id="interrupt-1"),), + ) + ) + finally: + nemo_flow.subscribers.deregister(subscriber_name) + + marks = [event for event in events if event.kind == "mark"] + assert [event.name for event in marks] == ["Graph Resume", "Graph Interrupt"] + assert marks[0].data["checkpoint_ns"] == ["parent", "child"] + assert marks[1].data["interrupts"] == [{"id": "interrupt-1", "value": "needs approval"}] + assert marks[1].metadata == {"integration": "langgraph"} diff --git a/uv.lock b/uv.lock index 1777558..c7474a6 100644 --- a/uv.lock +++ b/uv.lock @@ -1182,6 +1182,11 @@ langchain-nvidia = [ { name = "langchain-core" }, { name = "langchain-nvidia-ai-endpoints" }, ] +langgraph = [ + { name = "langchain" }, + { name = "langchain-core" }, + { name = "langgraph" }, +] [package.dev-dependencies] dev = [ @@ -1217,11 +1222,14 @@ test = [ requires-dist = [ { name = "langchain", marker = "extra == 'langchain'", specifier = ">=1.2.11,<2.0.0" }, { name = "langchain", marker = "extra == 'langchain-nvidia'", specifier = ">=1.2.11,<2.0.0" }, + { name = "langchain", marker = "extra == 'langgraph'", specifier = ">=1.2.11,<2.0.0" }, { name = "langchain-core", marker = "extra == 'langchain'" }, { name = "langchain-core", marker = "extra == 'langchain-nvidia'" }, + { name = "langchain-core", marker = "extra == 'langgraph'" }, { name = "langchain-nvidia-ai-endpoints", marker = "extra == 'langchain-nvidia'", specifier = "~=1.0" }, + { name = "langgraph", marker = "extra == 'langgraph'", specifier = ">=1.1.10,<2.0.0" }, ] -provides-extras = ["langchain", "langchain-nvidia"] +provides-extras = ["langchain", "langchain-nvidia", "langgraph"] [package.metadata.requires-dev] dev = [