Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion nemoguardrails/integrations/langchain/runnable_rails.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def __init__(
self.passthrough_bot_output_key = output_key
self.verbose = verbose
self.config: Optional[RunnableConfig] = None
self._current_config: Optional[RunnableConfig] = None
self._current_kwargs: dict = {}

# We override the config passthrough.
config.passthrough = passthrough
Expand All @@ -74,7 +76,10 @@ async def passthrough_fn(context: dict, events: List[dict]):
# First, we fetch the input from the context
_input = context.get("passthrough_input")
async_wrapped_invoke = async_wrap(self.passthrough_runnable.invoke)
_output = await async_wrapped_invoke(_input, self.config, **self.kwargs)

# Pass the config and kwargs that were captured in the invoke method
# This ensures that callbacks (like Langfuse tracing) are properly propagated
_output = await async_wrapped_invoke(_input, self._current_config, **self._current_kwargs)

# If the output is a string, we consider it to be the output text
if isinstance(_output, str):
Expand Down Expand Up @@ -188,8 +193,12 @@ def invoke(
) -> Output:
"""Invoke this runnable synchronously."""
input_messages = self._transform_input_to_rails_format(input)
# Store config and kwargs for use in passthrough function
# This ensures callbacks are properly passed to the underlying runnable
self.config = config
self.kwargs = kwargs
self._current_config = config
self._current_kwargs = kwargs
res = self.rails.generate(
messages=input_messages, options=GenerationOptions(output_vars=True)
)
Expand Down
29 changes: 29 additions & 0 deletions tests/test_runnable_rails.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,3 +658,32 @@ def log(x):
print(result)
assert "LOL" not in result["output"]
assert "can't respond" in result["output"]


def test_runnable_config_callback_passthrough():
"""Test that RunnableConfig with callbacks is properly passed to passthrough runnable."""
config_received = []

class CallbackTestRunnable(Runnable):
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
# Capture the config to verify callbacks were passed
config_received.append(config)
return {"output": "Test response"}

# Create a mock callback for testing
mock_callbacks = ["mock_callback"]
test_config = RunnableConfig(callbacks=mock_callbacks)

rails_config = RailsConfig.from_content(config={"models": []})
runnable_with_rails = RunnableRails(
rails_config, passthrough=True, runnable=CallbackTestRunnable()
)

# Invoke with the config containing callbacks
result = runnable_with_rails.invoke("test input", config=test_config)

# Verify that the config with callbacks was passed through
assert len(config_received) == 1
assert config_received[0] is not None
assert config_received[0].get("callbacks") == mock_callbacks
assert result == {"output": "Test response"}
Loading