Commit e92cf26

fix(llm): Add async streaming support to ChatNVIDIA provider patch (#1504)
* feat(llm): Add async streaming support to ChatNVIDIA provider

  Enables stream_async() to work with ChatNVIDIA/NIM models by implementing an async streaming decorator and an _agenerate method. Prior to this fix, stream_async() would fail with NIM engine configurations.

* fix: ensure stream_async background task completes before exit (#1508)

  Wrap the returned iterator to await the background generation task in a finally block, preventing the "Task was destroyed but it is pending" warning. Add overloaded type signatures to provide accurate return types based on the include_generation_metadata parameter.
1 parent d2bfaea commit e92cf26
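
For context, this is the kind of call the patch enables: asynchronously streaming tokens from a guardrails configuration that uses the nim engine (ChatNVIDIA). A minimal sketch; the model name and YAML snippet are illustrative assumptions and require a reachable NIM endpoint with valid credentials.

```python
import asyncio

from nemoguardrails import LLMRails, RailsConfig

# Illustrative configuration: any NIM-hosted chat model reachable from your
# environment should behave the same way.
YAML_CONFIG = """
models:
  - type: main
    engine: nim
    model: meta/llama-3.1-8b-instruct
streaming: True
"""


async def main() -> None:
    rails = LLMRails(RailsConfig.from_content(yaml_content=YAML_CONFIG))
    # Before this fix, stream_async() failed with NIM engine configurations.
    async for token in rails.stream_async(
        messages=[{"role": "user", "content": "Tell me a short joke."}]
    ):
        print(token, end="", flush=True)
    print()


asyncio.run(main())
```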

3 files changed: +486 -6 lines changed

nemoguardrails/llm/providers/_langchain_nvidia_ai_endpoints_patch.py

Lines changed: 45 additions & 2 deletions
@@ -18,8 +18,14 @@
 from functools import wraps
 from typing import Any, Dict, List, Optional

-from langchain_core.callbacks.manager import CallbackManagerForLLMRun
-from langchain_core.language_models.chat_models import generate_from_stream
+from langchain_core.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import (
+    agenerate_from_stream,
+    generate_from_stream,
+)
 from langchain_core.messages import BaseMessage
 from langchain_core.outputs import ChatResult
 from langchain_nvidia_ai_endpoints import ChatNVIDIA as ChatNVIDIAOriginal
@@ -50,6 +56,28 @@ def wrapper(
     return wrapper


+def async_stream_decorator(func):  # pragma: no cover
+    @wraps(func)
+    async def wrapper(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        should_stream = stream if stream is not None else self.streaming
+        if should_stream:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await agenerate_from_stream(stream_iter)
+        else:
+            return await func(self, messages, stop, run_manager, **kwargs)
+
+    return wrapper
+
+
 # NOTE: this needs to have the same name as the original class,
 # otherwise, there's a check inside `langchain-nvidia-ai-endpoints` that will fail.
 class ChatNVIDIA(ChatNVIDIAOriginal):  # pragma: no cover
@@ -108,6 +136,21 @@ def _generate(
         **kwargs: Any,
     ) -> ChatResult:
         return super()._generate(
+            messages=messages,
+            stop=stop,
+            run_manager=run_manager,
+            **kwargs,
+        )
+
+    @async_stream_decorator
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        return await super()._agenerate(
             messages=messages, stop=stop, run_manager=run_manager, **kwargs
         )
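
The new async_stream_decorator mirrors the existing synchronous stream_decorator: when streaming is enabled it drains self._astream() and folds the chunks into a single ChatResult via agenerate_from_stream(); otherwise it falls through to the wrapped coroutine. Below is a self-contained sketch of that routing pattern; ToyModel and toy_async_stream_decorator are illustrative stand-ins, not part of the patch.

```python
import asyncio
from functools import wraps
from typing import Any, AsyncIterator, List, Optional


def toy_async_stream_decorator(func):
    """Route to the streaming path when requested, else call the wrapped coroutine."""

    @wraps(func)
    async def wrapper(self, prompt: str, stream: Optional[bool] = None, **kwargs: Any) -> str:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            parts: List[str] = []
            async for chunk in self._astream(prompt, **kwargs):
                parts.append(chunk)
            # Stand-in for agenerate_from_stream(): fold chunks into one result.
            return "".join(parts)
        return await func(self, prompt, **kwargs)

    return wrapper


class ToyModel:
    streaming = True  # mirrors the provider's `streaming` flag

    async def _astream(self, prompt: str, **kwargs: Any) -> AsyncIterator[str]:
        for token in ("Hello", ", ", "world"):
            yield token

    @toy_async_stream_decorator
    async def _agenerate(self, prompt: str, **kwargs: Any) -> str:
        return "non-streaming result"


print(asyncio.run(ToyModel()._agenerate("hi")))  # -> Hello, world
```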

nemoguardrails/rails/llm/llmrails.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,13 @@
     Callable,
     Dict,
     List,
+    Literal,
     Optional,
     Tuple,
     Type,
     Union,
     cast,
+    overload,
 )

 from langchain_core.language_models import BaseChatModel
@@ -1255,15 +1257,39 @@ def _validate_streaming_with_output_rails(self) -> None:
                 "generate_async() instead of stream_async()."
             )

+    @overload
     def stream_async(
         self,
         prompt: Optional[str] = None,
         messages: Optional[List[dict]] = None,
         options: Optional[Union[dict, GenerationOptions]] = None,
         state: Optional[Union[dict, State]] = None,
-        include_generation_metadata: Optional[bool] = False,
+        include_generation_metadata: Literal[False] = False,
         generator: Optional[AsyncIterator[str]] = None,
     ) -> AsyncIterator[str]:
+        ...
+
+    @overload
+    def stream_async(
+        self,
+        prompt: Optional[str] = None,
+        messages: Optional[List[dict]] = None,
+        options: Optional[Union[dict, GenerationOptions]] = None,
+        state: Optional[Union[dict, State]] = None,
+        include_generation_metadata: Literal[True] = ...,
+        generator: Optional[AsyncIterator[str]] = None,
+    ) -> AsyncIterator[Union[str, dict]]:
+        ...
+
+    def stream_async(
+        self,
+        prompt: Optional[str] = None,
+        messages: Optional[List[dict]] = None,
+        options: Optional[Union[dict, GenerationOptions]] = None,
+        state: Optional[Union[dict, State]] = None,
+        include_generation_metadata: Optional[bool] = False,
+        generator: Optional[AsyncIterator[str]] = None,
+    ) -> AsyncIterator[Union[str, dict]]:
         """Simplified interface for getting directly the streamed tokens from the LLM."""

         self._validate_streaming_with_output_rails()
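
With the Literal-typed overloads in place, a type checker can narrow the element type of the returned iterator from the include_generation_metadata argument alone. A hedged sketch of what now type-checks; the config path is illustrative and running it requires a valid guardrails configuration.

```python
from typing import AsyncIterator, Union

from nemoguardrails import LLMRails, RailsConfig

rails = LLMRails(RailsConfig.from_path("./config"))  # path is illustrative

# First overload matches: the iterator yields plain token strings.
tokens: AsyncIterator[str] = rails.stream_async(prompt="Hello")

# Second overload matches: chunks may also be dicts carrying generation metadata.
chunks: AsyncIterator[Union[str, dict]] = rails.stream_async(
    prompt="Hello", include_generation_metadata=True
)
```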
@@ -1328,15 +1354,24 @@ def task_done_callback(task):
             self.config.rails.output.streaming
             and self.config.rails.output.streaming.enabled
         ):
-            # returns an async generator
-            return self._run_output_rails_in_streaming(
+            base_iterator = self._run_output_rails_in_streaming(
                 streaming_handler=streaming_handler,
                 output_rails_streaming_config=self.config.rails.output.streaming,
                 messages=messages,
                 prompt=prompt,
             )
         else:
-            return streaming_handler
+            base_iterator = streaming_handler
+
+        async def wrapped_iterator():
+            try:
+                async for chunk in base_iterator:
+                    if chunk is not None:
+                        yield chunk
+            finally:
+                await task
+
+        return wrapped_iterator()

     def generate(
         self,
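
The wrapped iterator exists to keep the background generation task alive until the consumer is done: awaiting it in the finally block lets it finish (and surface any exception) instead of being garbage-collected while still pending. A self-contained sketch of the same pattern with a toy producer task; the queue and produce() helper are illustrative, not NeMo Guardrails code.

```python
import asyncio
from typing import AsyncIterator, Optional


async def produce(queue: "asyncio.Queue[Optional[str]]") -> None:
    # Stand-in for the background generation task that feeds the streaming handler.
    for token in ("Hello", " ", "world"):
        await queue.put(token)
    await queue.put(None)  # sentinel: no more chunks


def stream() -> AsyncIterator[str]:
    queue: "asyncio.Queue[Optional[str]]" = asyncio.Queue()
    task = asyncio.ensure_future(produce(queue))

    async def wrapped_iterator() -> AsyncIterator[str]:
        try:
            while True:
                chunk = await queue.get()
                if chunk is None:
                    break
                yield chunk
        finally:
            # Awaiting the producer before the generator closes prevents the
            # "Task was destroyed but it is pending!" warning and re-raises
            # any exception the background task hit.
            await task

    return wrapped_iterator()


async def main() -> None:
    async for chunk in stream():
        print(chunk, end="")
    print()


asyncio.run(main())
```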
