@@ -1,14 +1,12 @@
"""LangGraph agent graph runner for LaunchDarkly AI SDK."""

import asyncio
import time
from contextvars import ContextVar
from typing import Annotated, Any, Dict, List, Set, Tuple

from ldai import log
from ldai.agent_graph import AgentGraphDefinition, AgentGraphNode
from ldai.providers import AgentGraphResult, AgentGraphRunner, ToolRegistry
from ldai.providers.types import LDAIMetrics
from ldai.providers import AgentGraphRunner, ToolRegistry
from ldai.providers.types import AgentGraphRunnerResult, GraphMetrics

from ldai_langchain.langchain_helper import (
build_structured_tools,
@@ -18,9 +16,6 @@
)
from ldai_langchain.langgraph_callback_handler import LDMetricsCallbackHandler

# Per-run eval task accumulator, isolated per concurrent run() call via ContextVar.
_run_eval_tasks: ContextVar[Dict[str, List[asyncio.Task]]] = ContextVar('_run_eval_tasks')


def _make_handoff_tool(child_key: str, description: str) -> Any:
"""
@@ -65,9 +60,10 @@ class LangGraphAgentGraphRunner(AgentGraphRunner):

AgentGraphRunner implementation for LangGraph.

Compiles and runs the agent graph with LangGraph and automatically records
graph- and node-level AI metric data to the LaunchDarkly trackers on the
graph definition and each node.
Compiles and runs the agent graph with LangGraph and collects graph- and
node-level metrics via a LangChain callback handler. Tracking events are
emitted by the managed layer (:class:`~ldai.ManagedAgentGraph`) from the
returned :class:`~ldai.providers.types.AgentGraphRunnerResult`.

Requires ``langgraph`` to be installed.
"""
@@ -181,26 +177,6 @@ async def invoke(state: WorkflowState) -> dict:
if node_instructions:
msgs = [SystemMessage(content=node_instructions)] + msgs
response = await bound_model.ainvoke(msgs)

node_obj = self._graph.get_node(nk)
if node_obj is not None:
input_text = '\r\n'.join(
m.content if isinstance(m.content, str) else str(m.content)
for m in msgs
) if msgs else ''
output_text = (
response.content if hasattr(response, 'content') else str(response)
)
task = node_obj.get_config().evaluator.evaluate(input_text, output_text)
run_tasks = _run_eval_tasks.get(None)
if run_tasks is not None:
run_tasks.setdefault(nk, []).append(task)
else:
log.warning(
f"LangGraphAgentGraphRunner: eval task for node '{nk}' "
"has no run context; judge results will not be tracked"
)

return {'messages': [response]}

invoke.__name__ = nk
@@ -298,20 +274,18 @@ def route(state: WorkflowState) -> str:
compiled = agent_builder.compile()
return compiled, fn_name_to_config_key, node_keys

async def run(self, input: Any) -> AgentGraphResult:
async def run(self, input: Any) -> AgentGraphRunnerResult:
"""
Run the agent graph with the given input.

Builds a LangGraph StateGraph from the AgentGraphDefinition, compiles
it, and invokes it. Uses a LangChain callback handler to collect
per-node metrics, then flushes them to LaunchDarkly trackers.
per-node metrics. Graph-level tracking events are emitted by the
managed layer from the returned GraphMetrics.

:param input: The string prompt to send to the agent graph
:return: AgentGraphResult with the final output and metrics
:return: AgentGraphRunnerResult with the final content and GraphMetrics
"""
pending_eval_tasks: Dict[str, List[asyncio.Task]] = {}
token = _run_eval_tasks.set(pending_eval_tasks)
tracker = self._graph.create_tracker()
start_ns = time.perf_counter_ns()

try:
@@ -325,24 +299,23 @@ async def run(self, input: Any) -> AgentGraphResult:
config={'callbacks': [handler], 'recursion_limit': 25},
)

duration = (time.perf_counter_ns() - start_ns) // 1_000_000
duration_ms = (time.perf_counter_ns() - start_ns) // 1_000_000
messages = result.get('messages', [])
output = extract_last_message_content(messages)
total_usage = sum_token_usage_from_messages(messages)

# Flush per-node metrics to LD trackers; eval results are tracked
# internally and intentionally not exposed on AgentGraphResult here
# — judge dispatch is the managed layer's responsibility.
await handler.flush(self._graph, pending_eval_tasks)

tracker.track_path(handler.path)
tracker.track_duration(duration)
tracker.track_invocation_success()
tracker.track_total_tokens(sum_token_usage_from_messages(messages))
node_metrics = handler.node_metrics

return AgentGraphResult(
output=output,
return AgentGraphRunnerResult(
content=output,
raw=result,
metrics=LDAIMetrics(success=True),
metrics=GraphMetrics(
success=True,
path=handler.path,
duration_ms=duration_ms,
usage=total_usage if (total_usage is not None and total_usage.total > 0) else None,
node_metrics=node_metrics,
),
)

except Exception as exc:
@@ -353,13 +326,12 @@ async def run(self, input: Any) -> AgentGraphResult:
)
else:
log.warning(f'LangGraphAgentGraphRunner run failed: {exc}')
duration = (time.perf_counter_ns() - start_ns) // 1_000_000
tracker.track_duration(duration)
tracker.track_invocation_failure()
return AgentGraphResult(
output='',
duration_ms = (time.perf_counter_ns() - start_ns) // 1_000_000
return AgentGraphRunnerResult(
content='',
raw=None,
metrics=LDAIMetrics(success=False),
metrics=GraphMetrics(
success=False,
duration_ms=duration_ms,
),
)
finally:
_run_eval_tasks.reset(token)
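A hedged sketch of consuming the new return shape from run(); the field names come from this diff, and the construction of the runner is elided:

async def report(runner: LangGraphAgentGraphRunner) -> None:
    result = await runner.run('Summarize the incident')
    m = result.metrics  # GraphMetrics
    if m.success:
        print(m.path)         # ordered, deduplicated node keys
        print(m.duration_ms)  # whole-run wall clock, in milliseconds
        if m.usage is not None:
            print(m.usage.total, m.usage.input, m.usage.output)
        for node_key, node in (m.node_metrics or {}).items():
            # Each value is an LDAIMetrics accumulated by the callback handler.
            print(node_key, node.duration_ms, node.tool_calls)
    else:
        # On failure, only success and duration_ms are populated.
        print('run failed after', m.duration_ms, 'ms')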
@@ -4,8 +4,7 @@

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import ChatGeneration, LLMResult
from ldai.agent_graph import AgentGraphDefinition
from ldai.providers.types import JudgeResult
from ldai.providers.types import LDAIMetrics
from ldai.tracker import TokenUsage

from ldai_langchain.langchain_helper import get_ai_usage_from_response
@@ -20,8 +19,10 @@ class LDMetricsCallbackHandler(BaseCallbackHandler):

LangChain callback handler that collects per-node metrics during a LangGraph run.

Records token usage, tool calls, and duration for each agent node in the graph,
then flushes them to LaunchDarkly trackers after the run completes via ``flush()``.
Records token usage, tool calls, and duration for each agent node in the graph.
Each node's :class:`~ldai.providers.types.LDAIMetrics` is built incrementally
as callbacks fire. Access the ``node_metrics`` property after the run completes
to retrieve the accumulated per-node metrics.
"""

def __init__(self, node_keys: Set[str], fn_name_to_config_key: Dict[str, str]):
@@ -39,14 +40,10 @@ def __init__(self, node_keys: Set[str], fn_name_to_config_key: Dict[str, str]):

# run_id -> node_key for active chain runs
self._run_to_node: Dict[UUID, str] = {}
# accumulated token usage per node
self._node_tokens: Dict[str, TokenUsage] = {}
# tool config keys called per node
self._node_tool_calls: Dict[str, List[str]] = {}
# start time (ns) per active run_id — keyed by run_id to handle re-entrant nodes
self._node_start_ns: Dict[UUID, int] = {}
# accumulated duration (ms) per node
self._node_duration_ms: Dict[str, int] = {}
# per-node metrics, built incrementally as callbacks fire
self._node_metrics: Dict[str, LDAIMetrics] = {}
# execution path in order (deduplicated)
self._path: List[str] = []
self._path_set: Set[str] = set()
@@ -61,19 +58,9 @@ def path(self) -> List[str]:
return list(self._path)

@property
def node_tokens(self) -> Dict[str, TokenUsage]:
"""Accumulated token usage per node key."""
return dict(self._node_tokens)

@property
def node_tool_calls(self) -> Dict[str, List[str]]:
"""Tool config keys called per node key."""
return {k: list(v) for k, v in self._node_tool_calls.items()}

@property
def node_durations_ms(self) -> Dict[str, int]:
"""Accumulated duration in milliseconds per node key."""
return dict(self._node_duration_ms)
def node_metrics(self) -> Dict[str, LDAIMetrics]:
"""Per-node metrics keyed by node key."""
return dict(self._node_metrics)
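To make the handler's lifecycle concrete, a small usage sketch; the constructor arguments follow __init__ above, and the graph invocation itself is elided:

handler = LDMetricsCallbackHandler(
    node_keys={'researcher', 'writer'},
    fn_name_to_config_key={'search_web': 'search-tool'},
)
# ... run the compiled graph with config={'callbacks': [handler]} ...
for node_key, metrics in handler.node_metrics.items():
    # Each LDAIMetrics is seeded in on_chain_start and filled in by
    # on_chain_end, on_llm_end, and on_tool_end.
    print(node_key, metrics.success, metrics.duration_ms,
          metrics.usage, metrics.tool_calls)
print(handler.path)  # execution order, deduplicated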

# ------------------------------------------------------------------
# Callbacks
@@ -101,10 +88,10 @@ def on_chain_start(
if name not in self._path_set:
self._path.append(name)
self._path_set.add(name)
self._node_metrics[name] = LDAIMetrics(success=False)
elif name.endswith('__tools'):
stripped = name[: -len('__tools')]
if stripped in self._node_keys:
# Attribute tool events to the owning agent node
self._run_to_node[run_id] = stripped

def on_chain_end(
@@ -121,9 +108,10 @@ def on_chain_end(
start_ns = self._node_start_ns.pop(run_id, None)
if start_ns is not None:
elapsed_ms = (time.perf_counter_ns() - start_ns) // 1_000_000
self._node_duration_ms[node_key] = (
self._node_duration_ms.get(node_key, 0) + elapsed_ms
)
metrics = self._node_metrics.get(node_key)
if metrics is not None:
metrics.success = True
metrics.duration_ms = (metrics.duration_ms or 0) + elapsed_ms

def on_llm_end(
self,
@@ -151,11 +139,14 @@ def on_llm_end(
if usage is None:
return

existing = self._node_tokens.get(node_key)
metrics = self._node_metrics.get(node_key)
if metrics is None:
return
existing = metrics.usage
if existing is None:
self._node_tokens[node_key] = usage
metrics.usage = usage
else:
self._node_tokens[node_key] = TokenUsage(
metrics.usage = TokenUsage(
total=existing.total + usage.total,
input=existing.input + usage.input,
output=existing.output + usage.output,
@@ -179,64 +170,11 @@ def on_tool_end(

config_key = self._fn_name_to_config_key.get(name)
if config_key is None:
# Tool is not a registered functional tool (e.g. a handoff tool) — skip tracking.
return
if node_key not in self._node_tool_calls:
self._node_tool_calls[node_key] = []
self._node_tool_calls[node_key].append(config_key)

# ------------------------------------------------------------------
# Flush
# ------------------------------------------------------------------

async def flush(
self, graph: AgentGraphDefinition, eval_tasks=None
) -> List[JudgeResult]:
"""
Emit all collected per-node metrics to the LaunchDarkly trackers.

Call this once after the graph run completes.

:param graph: The AgentGraphDefinition whose nodes hold the LD config trackers.
:param eval_tasks: Optional dict mapping node key to a list of awaitables that
return judge evaluation results. Multiple tasks arise when a node is visited
more than once (e.g. in a graph with cycles).
:return: All judge results collected across all nodes.
"""
node_trackers: Dict[str, Any] = {}
all_eval_results: List[JudgeResult] = []
for node_key in self._path:
if node_key in node_trackers:
continue
node = graph.get_node(node_key)
if not node:
continue
config_tracker = node.get_config().create_tracker()
if not config_tracker:
continue
node_trackers[node_key] = config_tracker

usage = self._node_tokens.get(node_key)
if usage:
config_tracker.track_tokens(usage)

duration = self._node_duration_ms.get(node_key)
if duration is not None:
config_tracker.track_duration(duration)

config_tracker.track_success()

for tool_key in self._node_tool_calls.get(node_key, []):
config_tracker.track_tool_call(tool_key)

if not eval_tasks:
continue

for eval_task in eval_tasks.get(node_key, []):
results = await eval_task
all_eval_results.extend(results)
for r in results:
if r.success:
config_tracker.track_judge_result(r)

return all_eval_results
metrics = self._node_metrics.get(node_key)
if metrics is None:
return
if metrics.tool_calls is None:
metrics.tool_calls = [config_key]
else:
metrics.tool_calls.append(config_key)
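For clarity on the merge in on_llm_end above, a tiny worked example; the TokenUsage fields match those used in this diff, and the numbers are invented:

from ldai.tracker import TokenUsage

first = TokenUsage(total=120, input=100, output=20)  # first LLM call
second = TokenUsage(total=80, input=50, output=30)   # re-entrant visit

# Field-wise addition, as on_llm_end accumulates into metrics.usage.
merged = TokenUsage(
    total=first.total + second.total,     # 200
    input=first.input + second.input,     # 150
    output=first.output + second.output,  # 50
)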
@@ -1,10 +1,9 @@
"""Tests for LangChain Provider."""

import pytest
from unittest.mock import AsyncMock, MagicMock

import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

from ldai import LDMessage
from ldai.evaluator import Evaluator

@@ -404,6 +403,7 @@ class TestCreateAgent:
def test_creates_agent_runner_with_instructions_and_tool_definitions(self):
"""Should create LangChainAgentRunner wrapping a compiled graph."""
from unittest.mock import patch

from ldai_langchain import LangChainAgentRunner

mock_ai_config = MagicMock()
@@ -436,6 +436,7 @@ def test_creates_agent_runner_with_no_tools(self):
def test_creates_agent_runner_with_no_tools(self):
"""Should create LangChainAgentRunner with no tool definitions."""
from unittest.mock import patch

from ldai_langchain import LangChainAgentRunner

mock_ai_config = MagicMock()
@@ -522,6 +523,7 @@ class TestBuildTools:

def test_registers_sync_callable_as_structured_tool_func(self):
from ldai.models import AIAgentConfig, ModelConfig, ProviderConfig

from ldai_langchain.langchain_helper import build_structured_tools

def sync_tool(x: str = '') -> str:
@@ -546,6 +548,7 @@ def test_registers_async_callable_as_structured_tool_coroutine(self):

def test_registers_async_callable_as_structured_tool_coroutine(self):
from ldai.models import AIAgentConfig, ModelConfig, ProviderConfig

from ldai_langchain.langchain_helper import build_structured_tools

async def async_tool(x: str = '') -> str: