Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "uipath-langchain"
version = "0.1.41"
version = "0.1.42"
description = "Python SDK that enables developers to build and deploy LangGraph agents to the UiPath Cloud Platform"
readme = { file = "README.md", content-type = "text/markdown" }
requires-python = ">=3.11"
Expand Down
24 changes: 19 additions & 5 deletions src/uipath_langchain/agent/guardrails/guardrail_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@

from uipath_langchain.agent.guardrails.types import ExecutionStage
from uipath_langchain.agent.guardrails.utils import (
_extract_tool_input_data,
_extract_tool_args_from_message,
_extract_tool_output_data,
_extract_tools_args_from_message,
get_message_content,
)
from uipath_langchain.agent.react.types import AgentGuardrailsGraphState
Expand Down Expand Up @@ -188,7 +189,11 @@ def create_llm_guardrail_node(
def _payload_generator(state: AgentGuardrailsGraphState) -> str:
if not state.messages:
return ""
return get_message_content(state.messages[-1])
match execution_stage:
case ExecutionStage.PRE_EXECUTION:
return get_message_content(state.messages[-1])
case ExecutionStage.POST_EXECUTION:
return json.dumps(_extract_tools_args_from_message(state.messages[-1]))

return _create_guardrail_node(
guardrail,
Expand Down Expand Up @@ -273,16 +278,25 @@ def _payload_generator(state: AgentGuardrailsGraphState) -> str:
return ""

if execution_stage == ExecutionStage.PRE_EXECUTION:
# Extract tool args as dict and convert to JSON string
args_dict = _extract_tool_input_data(state, tool_name, execution_stage)
last_message = state.messages[-1]
args_dict = _extract_tool_args_from_message(last_message, tool_name)
if args_dict:
return json.dumps(args_dict)

return get_message_content(state.messages[-1])

# Create closures for input/output data extraction (for deterministic guardrails)
def _input_data_extractor(state: AgentGuardrailsGraphState) -> dict[str, Any]:
return _extract_tool_input_data(state, tool_name, execution_stage)
if execution_stage == ExecutionStage.PRE_EXECUTION:
if len(state.messages) < 1:
return {}
message = state.messages[-1]
else: # POST_EXECUTION
if len(state.messages) < 2:
return {}
message = state.messages[-2]

return _extract_tool_args_from_message(message, tool_name)

def _output_data_extractor(state: AgentGuardrailsGraphState) -> dict[str, Any]:
return _extract_tool_output_data(state)
Expand Down
58 changes: 27 additions & 31 deletions src/uipath_langchain/agent/guardrails/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
ToolMessage,
)

from uipath_langchain.agent.guardrails.types import ExecutionStage
from uipath_langchain.agent.react.types import AgentGuardrailsGraphState
from uipath_langchain.agent.tools.utils import sanitize_tool_name

Expand Down Expand Up @@ -61,45 +60,42 @@ def _extract_tool_args_from_message(
return parsed
except json.JSONDecodeError:
logger.warning(
"Failed to parse tool args as JSON for tool '%s': %s",
tool_name,
args[:100] if len(args) > 100 else args,
"Failed to parse tool args as JSON for tool '%s'", tool_name
)
return {}

return {}


def _extract_tool_input_data(
state: AgentGuardrailsGraphState, tool_name: str, execution_stage: ExecutionStage
) -> dict[str, Any]:
"""Extract tool call arguments as dict for deterministic guardrails.
def _extract_tools_args_from_message(message: AnyMessage) -> list[dict[str, Any]]:
if not isinstance(message, AIMessage):
return []

Args:
state: The current agent graph state.
tool_name: Name of the tool to extract arguments from.
execution_stage: PRE_EXECUTION or POST_EXECUTION.
if not message.tool_calls:
return []

Returns:
Dict containing tool call arguments, or empty dict if not found.
- For PRE_EXECUTION: extracts from last message
- For POST_EXECUTION: extracts from second-to-last message
"""
if not state.messages:
return {}
result: list[dict[str, Any]] = []

# For PRE_EXECUTION, look at last message
# For POST_EXECUTION, look at second-to-last message (before the ToolMessage)
if execution_stage == ExecutionStage.PRE_EXECUTION:
if len(state.messages) < 1:
return {}
message = state.messages[-1]
else: # POST_EXECUTION
if len(state.messages) < 2:
return {}
message = state.messages[-2]

return _extract_tool_args_from_message(message, tool_name)
for tool_call in message.tool_calls:
args = (
tool_call.get("args")
if isinstance(tool_call, dict)
else getattr(tool_call, "args", None)
)
if args is not None:
# Args should already be a dict
if isinstance(args, dict):
result.append(args)
# If it's a JSON string, parse it
elif isinstance(args, str):
try:
parsed = json.loads(args)
if isinstance(parsed, dict):
result.append(parsed)
except json.JSONDecodeError:
logger.warning("Failed to parse tool args as JSON")

return result


def _extract_tool_output_data(state: AgentGuardrailsGraphState) -> dict[str, Any]:
Expand Down
153 changes: 153 additions & 0 deletions tests/agent/guardrails/test_extraction_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
"""Tests for guardrail utility functions."""

import json

from langchain_core.messages import AIMessage, HumanMessage, ToolMessage

from uipath_langchain.agent.guardrails.utils import (
_extract_tool_args_from_message,
_extract_tool_output_data,
_extract_tools_args_from_message,
get_message_content,
)
from uipath_langchain.agent.react.types import AgentGuardrailsGraphState


class TestExtractToolArgsFromMessage:
"""Tests for _extract_tool_args_from_message function."""

def test_extracts_args_from_matching_tool(self):
"""Should extract args from matching tool call."""
message = AIMessage(
content="",
tool_calls=[
{
"name": "test_tool",
"args": {"param1": "value1", "param2": 123},
"id": "call_1",
}
],
)
result = _extract_tool_args_from_message(message, "test_tool")
assert result == {"param1": "value1", "param2": 123}

def test_returns_empty_dict_for_non_matching_tool(self):
"""Should return empty dict when tool name doesn't match."""
message = AIMessage(
content="",
tool_calls=[
{"name": "other_tool", "args": {"data": "value"}, "id": "call_1"}
],
)
result = _extract_tool_args_from_message(message, "test_tool")
assert result == {}

def test_returns_empty_dict_for_non_ai_message(self):
"""Should return empty dict when message is not AIMessage."""
message = HumanMessage(content="Test message")
result = _extract_tool_args_from_message(message, "test_tool")
assert result == {}

def test_returns_first_matching_tool_when_multiple(self):
"""Should return args from first matching tool call."""
message = AIMessage(
content="",
tool_calls=[
{"name": "test_tool", "args": {"first": "call"}, "id": "call_1"},
{"name": "test_tool", "args": {"second": "call"}, "id": "call_2"},
],
)
result = _extract_tool_args_from_message(message, "test_tool")
assert result == {"first": "call"}


class TestExtractToolsArgsFromMessage:
"""Tests for _extract_tools_args_from_message function."""

def test_extracts_args_from_all_tool_calls(self):
"""Should extract args from all tool calls."""
message = AIMessage(
content="",
tool_calls=[
{"name": "tool1", "args": {"arg1": "val1"}, "id": "call_1"},
{"name": "tool2", "args": {"arg2": "val2"}, "id": "call_2"},
{"name": "tool3", "args": {"arg3": "val3"}, "id": "call_3"},
],
)
result = _extract_tools_args_from_message(message)
assert result == [{"arg1": "val1"}, {"arg2": "val2"}, {"arg3": "val3"}]

def test_returns_empty_list_for_non_ai_message(self):
"""Should return empty list when message is not AIMessage."""
message = HumanMessage(content="Test message")
result = _extract_tools_args_from_message(message)
assert result == []

def test_returns_empty_list_when_no_tool_calls(self):
"""Should return empty list when AIMessage has no tool calls."""
message = AIMessage(content="Test response")
result = _extract_tools_args_from_message(message)
assert result == []


class TestExtractToolOutputData:
"""Tests for _extract_tool_output_data function."""

def test_extracts_json_dict_content(self):
"""Should parse and return dict when content is JSON string."""
json_content = json.dumps({"result": "success", "data": {"value": 42}})
state = AgentGuardrailsGraphState(
messages=[ToolMessage(content=json_content, tool_call_id="call_1")]
)
result = _extract_tool_output_data(state)
assert result == {"result": "success", "data": {"value": 42}}

def test_wraps_non_json_string_in_output_field(self):
"""Should wrap non-JSON string content in 'output' field."""
state = AgentGuardrailsGraphState(
messages=[ToolMessage(content="Plain text result", tool_call_id="call_1")]
)
result = _extract_tool_output_data(state)
assert result == {"output": "Plain text result"}

def test_returns_empty_dict_for_empty_messages(self):
"""Should return empty dict when state has no messages."""
state = AgentGuardrailsGraphState(messages=[])
result = _extract_tool_output_data(state)
assert result == {}

def test_returns_empty_dict_for_non_tool_message(self):
"""Should return empty dict when last message is not ToolMessage."""
state = AgentGuardrailsGraphState(
messages=[AIMessage(content="Not a tool message")]
)
result = _extract_tool_output_data(state)
assert result == {}


class TestGetMessageContent:
"""Tests for get_message_content function."""

def test_extracts_string_content_from_human_message(self):
"""Should extract string content from HumanMessage."""
message = HumanMessage(content="Hello from human")
result = get_message_content(message)
assert result == "Hello from human"

def test_extracts_content_from_ai_message(self):
"""Should extract content from AIMessage."""
message = AIMessage(content="AI response")
result = get_message_content(message)
assert result == "AI response"

def test_extracts_content_from_tool_message(self):
"""Should extract content from ToolMessage."""
message = ToolMessage(content="Tool result", tool_call_id="call_1")
result = get_message_content(message)
assert result == "Tool result"

def test_handles_empty_content(self):
"""Should handle empty content string."""
message = AIMessage(content="")
result = get_message_content(message)
assert result == ""
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading