11 changes: 8 additions & 3 deletions nemoguardrails/llm/cache/lfu.py

@@ -54,7 +54,8 @@ def append(self, node: LFUNode) -> None:
         """Add node to the end of the list (before tail)."""
         node.prev = self.tail.prev
         node.next = self.tail
-        self.tail.prev.next = node
+        if self.tail.prev is not None:
+            self.tail.prev.next = node
         self.tail.prev = node
         self.size += 1

@@ -67,8 +68,10 @@ def pop(self, node: Optional[LFUNode] = None) -> Optional[LFUNode]:
             node = self.head.next

         # Remove node from the list
-        node.prev.next = node.next
-        node.next.prev = node.prev
+        if node is not None and node.prev is not None:
+            node.prev.next = node.next
+        if node is not None and node.next is not None:
+            node.next.prev = node.prev
         self.size -= 1

         return node
@@ -121,6 +124,7 @@ def __init__(
             "evictions": 0,
             "puts": 0,
             "updates": 0,
+            "hit_rate": 0.0,
         }

     def _update_node_freq(self, node: LFUNode) -> None:
@@ -288,6 +292,7 @@ def reset_stats(self) -> None:
             "evictions": 0,
             "puts": 0,
             "updates": 0,
+            "hit_rate": 0.0,
         }

     def _check_and_log_stats(self) -> None:
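
Note: the new `is not None` guards are defensive checks for the type checker rather than behavior changes; with sentinel head and tail nodes, the inner pointers are never None after construction. A minimal, self-contained sketch of the pattern (illustrative names, not the project's actual classes):

# Sketch of the sentinel-based doubly linked list the guards above protect.
from typing import Optional


class Node:
    def __init__(self, value: object = None) -> None:
        self.value = value
        self.prev: Optional["Node"] = None
        self.next: Optional["Node"] = None


class DLList:
    def __init__(self) -> None:
        # Sentinel head/tail mean prev/next are only None before linking,
        # but a type checker cannot prove that, hence the explicit guards.
        self.head = Node()
        self.tail = Node()
        self.head.next = self.tail
        self.tail.prev = self.head
        self.size = 0

    def append(self, node: Node) -> None:
        node.prev = self.tail.prev
        node.next = self.tail
        if self.tail.prev is not None:  # always true after __init__
            self.tail.prev.next = node
        self.tail.prev = node
        self.size += 1
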
8 changes: 4 additions & 4 deletions nemoguardrails/llm/filters.py

@@ -140,7 +140,7 @@ def to_messages(colang_history: str) -> List[dict]:
     # a message from the user, and the rest gets translated to messages from the assistant.
     lines = colang_history.split("\n")

-    bot_lines = []
+    bot_lines: list[str] = []
     for i, line in enumerate(lines):
         if line.startswith('user "'):
             # If we have bot lines in the buffer, we first add a bot message.
@@ -181,8 +181,8 @@ def to_messages_v2(colang_history: str) -> List[dict]:
     # a message from the user, and the rest gets translated to messages from the assistant.
     lines = colang_history.split("\n")

-    user_lines = []
-    bot_lines = []
+    user_lines: list[str] = []
+    bot_lines: list[str] = []
     for line in lines:
         if line.startswith("user action:"):
             if len(bot_lines) > 0:
@@ -275,7 +275,7 @@ def verbose_v1(colang_history: str) -> str:
     return "\n".join(lines)


-def to_chat_messages(events: List[dict]) -> str:
+def to_chat_messages(events: List[dict]) -> List[dict]:
     """Filter that turns an array of events into a sequence of user/assistant messages.

     Properly handles multimodal content by preserving the structure when the content
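
Note: the corrected annotation documents that `to_chat_messages` returns a list of message dicts, not a string. A hedged sketch of that shape of transformation; the event types and field names here are assumptions for illustration, not the exact schema:

from typing import List


def to_chat_messages_sketch(events: List[dict]) -> List[dict]:
    # Map runtime events onto user/assistant chat messages.
    messages: List[dict] = []
    for event in events:
        if event.get("type") == "UtteranceUserActionFinished":
            messages.append({"role": "user", "content": event.get("final_transcript", "")})
        elif event.get("type") == "StartUtteranceBotAction":
            messages.append({"role": "assistant", "content": event.get("script", "")})
    return messages
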
23 changes: 12 additions & 11 deletions nemoguardrails/llm/helpers.py

@@ -13,18 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import List, Optional, Type, Union
+from typing import List, Optional, Type

 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.language_models.llms import LLM, BaseLLM
+from langchain_core.language_models.llms import LLM


-def get_llm_instance_wrapper(
-    llm_instance: Union[LLM, BaseLLM], llm_type: str
-) -> Type[LLM]:
+def get_llm_instance_wrapper(llm_instance: LLM, llm_type: str) -> Type[LLM]:
     """Wraps an LLM instance in a class that can be registered with LLMRails.

     This is useful to create specific types of LLMs using a generic LLM provider
@@ -47,7 +45,7 @@ def model_kwargs(self):
             These are needed to allow changes to the arguments of the LLM calls.
             """
             if hasattr(llm_instance, "model_kwargs"):
-                return llm_instance.model_kwargs
+                return llm_instance.model_kwargs  # type: ignore[attr-defined] (We check in line above)
             return {}

         @property
@@ -66,26 +64,29 @@ def _modify_instance_kwargs(self):
             """

             if hasattr(llm_instance, "model_kwargs"):
-                if isinstance(llm_instance.model_kwargs, dict):
-                    llm_instance.model_kwargs["temperature"] = self.temperature
-                    llm_instance.model_kwargs["streaming"] = self.streaming
+                model_kwargs = getattr(llm_instance, "model_kwargs")
+                if isinstance(model_kwargs, dict):
+                    model_kwargs["temperature"] = self.temperature
+                    model_kwargs["streaming"] = self.streaming

         def _call(
             self,
             prompt: str,
             stop: Optional[List[str]] = None,
             run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs,
         ) -> str:
             self._modify_instance_kwargs()
-            return llm_instance._call(prompt, stop, run_manager)
+            return llm_instance._call(prompt, stop, run_manager, **kwargs)

         async def _acall(
             self,
             prompt: str,
             stop: Optional[List[str]] = None,
             run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+            **kwargs,
         ) -> str:
             self._modify_instance_kwargs()
-            return await llm_instance._acall(prompt, stop, run_manager)
+            return await llm_instance._acall(prompt, stop, run_manager, **kwargs)

     return WrapperLLM
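
Note: forwarding `**kwargs` means per-call arguments passed by LangChain now reach the wrapped instance instead of being silently dropped. A hedged usage sketch, assuming the `register_llm_provider` helper from `nemoguardrails.llm.providers` and using LangChain's `FakeListLLM` as a stand-in for any concrete LLM:

from langchain_community.llms import FakeListLLM

from nemoguardrails.llm.helpers import get_llm_instance_wrapper
from nemoguardrails.llm.providers import register_llm_provider

# Wrap a concrete instance so it can be registered as a named provider.
llm = FakeListLLM(responses=["Hello!"])
CustomLLM = get_llm_instance_wrapper(llm_instance=llm, llm_type="custom_llm")
register_llm_provider("custom_llm", CustomLLM)
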
9 changes: 6 additions & 3 deletions nemoguardrails/llm/models/initializer.py

@@ -20,12 +20,15 @@
 from langchain_core.language_models import BaseChatModel
 from langchain_core.language_models.llms import BaseLLM

-from .langchain_initializer import ModelInitializationError, init_langchain_model
+from nemoguardrails.llm.models.langchain_initializer import (
+    ModelInitializationError,
+    init_langchain_model,
+)


-# later we can easily conver it to a class
+# later we can easily convert it to a class
 def init_llm_model(
-    model_name: Optional[str],
+    model_name: str,
     provider_name: str,
     mode: Literal["chat", "text"],
     kwargs: Dict[str, Any],
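
Note: `model_name` is now a required `str` rather than `Optional[str]`. A hedged call sketch; the provider and model values are illustrative, and the call assumes the provider's own requirements (such as an API key) are satisfied:

from nemoguardrails.llm.models.initializer import init_llm_model

# model_name must now be a str; passing None is no longer allowed.
llm = init_llm_model(
    model_name="gpt-4o",
    provider_name="openai",
    mode="chat",
    kwargs={"temperature": 0.0},
)
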
15 changes: 8 additions & 7 deletions nemoguardrails/llm/params.py

@@ -36,7 +36,7 @@

 import logging
 import warnings
-from typing import Dict, Type
+from typing import Any, Dict, Type

 from langchain.base_language import BaseLanguageModel

@@ -61,18 +61,18 @@ def __init__(self, llm: BaseLanguageModel, **kwargs):
         warnings.warn(_DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
         self.llm = llm
         self.altered_params = kwargs
-        self.original_params = {}
+        self.original_params: dict[str, Any] = {}

     def __enter__(self):
         # Here we can access and modify the global language model parameters.
         self.original_params = {}
         for param, value in self.altered_params.items():
             if hasattr(self.llm, param):
                 self.original_params[param] = getattr(self.llm, param)
                 setattr(self.llm, param, value)

             elif hasattr(self.llm, "model_kwargs"):
-                if param not in self.llm.model_kwargs:
+                model_kwargs = getattr(self.llm, "model_kwargs", {})
+                if param not in model_kwargs:
                     log.warning(
                         "Parameter %s does not exist for %s. Passing to model_kwargs",
                         param,
@@ -81,9 +81,10 @@ def __enter__(self):

                     self.original_params[param] = None
                 else:
-                    self.original_params[param] = self.llm.model_kwargs[param]
+                    self.original_params[param] = model_kwargs[param]

-                self.llm.model_kwargs[param] = value
+                model_kwargs[param] = value
+                setattr(self.llm, "model_kwargs", model_kwargs)

             else:
                 log.warning(
@@ -92,7 +93,7 @@ def __enter__(self):
                     self.llm.__class__.__name__,
                 )

-    def __exit__(self, type, value, traceback):
+    def __exit__(self, exc_type, value, traceback):
         # Restore original parameters when exiting the context
         for param, value in self.original_params.items():
             if hasattr(self.llm, param):
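
Note: routing all `model_kwargs` access through `getattr`/`setattr` keeps the context manager working on models that do not expose the dict as a plain attribute, and renaming the `type` parameter of `__exit__` stops shadowing the builtin. A hedged usage sketch, assuming the module's `llm_params` factory (which, per the deprecation warning above, is being phased out) and `FakeListLLM` as a stand-in:

from langchain_community.llms import FakeListLLM

from nemoguardrails.llm.params import llm_params

llm = FakeListLLM(responses=["ok"])

# Temporarily override a parameter; __exit__ restores the original value
# (or the model_kwargs entry) when the block exits.
with llm_params(llm, temperature=0.1):
    pass  # calls made here would see temperature=0.1
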
39 changes: 30 additions & 9 deletions nemoguardrails/llm/providers/huggingface/pipeline.py

@@ -13,14 +13,33 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import asyncio
 from typing import Any, List, Optional

 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
 from langchain.schema.output import GenerationChunk
-from langchain_community.llms import HuggingFacePipeline
+
+# Import HuggingFacePipeline with fallbacks for different LangChain versions
+HuggingFacePipeline = None  # type: ignore[assignment]
+
+try:
+    from langchain_community.llms import (
+        HuggingFacePipeline,  # type: ignore[attr-defined,no-redef]
+    )
+except ImportError:
+    # Fallback for older versions of langchain
+    try:
+        from langchain.llms import (
+            HuggingFacePipeline,  # type: ignore[attr-defined,no-redef]
+        )
+    except ImportError:
+        # Create a dummy class if HuggingFacePipeline is not available
+        class HuggingFacePipeline:  # type: ignore[misc,no-redef]
+            def __init__(self, *args, **kwargs):
+                raise ImportError("HuggingFacePipeline is not available")


 class HuggingFacePipelineCompatible(HuggingFacePipeline):
@@ -47,12 +66,13 @@ def _call(
         )

         # Streaming for NeMo Guardrails is not supported in sync calls.
-        if self.model_kwargs and self.model_kwargs.get("streaming"):
-            raise Exception(
+        model_kwargs = getattr(self, "model_kwargs", {})
+        if model_kwargs and model_kwargs.get("streaming"):
+            raise NotImplementedError(
                 "Streaming mode not supported for HuggingFacePipeline in NeMo Guardrails!"
             )

-        llm_result = self._generate(
+        llm_result = self._generate(  # type: ignore[attr-defined]
             [prompt],
             stop=stop,
             run_manager=run_manager,
@@ -78,11 +98,12 @@ async def _acall(
         )

         # Handle streaming, if the flag is set
-        if self.model_kwargs and self.model_kwargs.get("streaming"):
+        model_kwargs = getattr(self, "model_kwargs", {})
+        if model_kwargs and model_kwargs.get("streaming"):
             # Retrieve the streamer object, needs to be set in model_kwargs
-            streamer = self.model_kwargs.get("streamer")
+            streamer = model_kwargs.get("streamer")
             if not streamer:
-                raise Exception(
+                raise ValueError(
                     "Cannot stream, please add HuggingFace streamer object to model_kwargs!"
                 )

@@ -99,7 +120,7 @@ async def _acall(
                 run_manager=run_manager,
                 **kwargs,
             )
-            loop.create_task(self._agenerate(**generation_kwargs))
+            loop.create_task(getattr(self, "_agenerate")(**generation_kwargs))

             # And start waiting for the chunks to come in.
             completion = ""
@@ -111,7 +132,7 @@ async def _acall(

             return completion

-        llm_result = await self._agenerate(
+        llm_result = await getattr(self, "_agenerate")(
             [prompt],
             stop=stop,
             run_manager=run_manager,
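
Note: the layered import makes the module importable whether `HuggingFacePipeline` lives in `langchain_community.llms`, in the legacy `langchain.llms`, or is absent entirely; the stub defers the failure from import time to first use. The same pattern on generic names (the modules below are deliberately fictitious, so both imports fail here and the stub branch is exercised):

# Version-tolerant import with a last-resort stub.
try:
    from new_location import Thing  # preferred, newer package layout
except ImportError:
    try:
        from old_location import Thing  # legacy package layout
    except ImportError:
        # Defer the failure from import time to first use.
        class Thing:  # type: ignore[no-redef]
            def __init__(self, *args, **kwargs):
                raise ImportError("Thing is not available; install the dependency")


try:
    Thing()
except ImportError as exc:
    print(exc)  # -> Thing is not available; install the dependency

The design point is that importing the module stays safe; only code that actually instantiates the missing class raises.
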
26 changes: 22 additions & 4 deletions nemoguardrails/llm/providers/huggingface/streamers.py

@@ -14,11 +14,27 @@
 # limitations under the License.

 import asyncio
+from typing import TYPE_CHECKING, Optional

-from transformers.generation.streamers import TextStreamer
+TRANSFORMERS_AVAILABLE = True
+try:
+    from transformers.generation.streamers import (  # type: ignore[import-untyped]
+        TextStreamer,
+    )
+except ImportError:
+    # Fallback if transformers is not available
+    TRANSFORMERS_AVAILABLE = False
+
+    class TextStreamer:  # type: ignore[no-redef]
+        def __init__(self, *args, **kwargs):
+            pass


-class AsyncTextIteratorStreamer(TextStreamer):
+if TYPE_CHECKING:
+    from transformers import AutoTokenizer  # type: ignore[import-untyped]
+
+
+class AsyncTextIteratorStreamer(TextStreamer):  # type: ignore[misc]
     """
     Simple async implementation for HuggingFace Transformers streamers.

@@ -30,12 +46,14 @@ def __init__(
         self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs
     ):
         super().__init__(tokenizer, skip_prompt, **decode_kwargs)
-        self.text_queue = asyncio.Queue()
+        self.text_queue: asyncio.Queue[str] = asyncio.Queue()
         self.stop_signal = None
-        self.loop = None
+        self.loop: Optional[asyncio.AbstractEventLoop] = None

     def on_finalized_text(self, text: str, stream_end: bool = False):
         """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
+        if self.loop is None:
+            return
         if len(text) > 0:
             asyncio.run_coroutine_threadsafe(self.text_queue.put(text), self.loop)
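
Note: the `self.loop is None` guard matters because `on_finalized_text` runs on the generation thread, not the event-loop thread, and `run_coroutine_threadsafe` needs a loop to hand the text to. A minimal runnable sketch of that thread-to-loop handoff (illustrative class, not the project's streamer):

import asyncio
import threading
from typing import Optional


class TinyStreamer:
    def __init__(self) -> None:
        self.text_queue: asyncio.Queue[str] = asyncio.Queue()
        self.loop: Optional[asyncio.AbstractEventLoop] = None

    def on_text(self, text: str) -> None:
        # Mirrors the new guard: no attached loop, no delivery.
        if self.loop is None:
            return
        asyncio.run_coroutine_threadsafe(self.text_queue.put(text), self.loop)


async def main() -> None:
    streamer = TinyStreamer()
    streamer.loop = asyncio.get_running_loop()
    # A producer thread pushes text into the queue owned by the loop.
    threading.Thread(target=streamer.on_text, args=("hello",)).start()
    print(await streamer.text_queue.get())  # -> hello


asyncio.run(main())
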
21 changes: 16 additions & 5 deletions nemoguardrails/llm/taskmanager.py

@@ -95,6 +95,8 @@ def __init__(self, config: RailsConfig):
     def _get_general_instructions(self):
         """Helper to extract the general instructions."""
         text = ""
+        if self.config.instructions is None:
+            return text
         for instruction in self.config.instructions:
             if instruction.type == "general":
                 text = instruction.content
@@ -266,7 +268,9 @@ def render_task_prompt(
             task_prompt = self._render_string(
                 prompt.content, context=context, events=events
             )
-            while len(task_prompt) > prompt.max_length:
+            while (
+                prompt.max_length is not None and len(task_prompt) > prompt.max_length
+            ):
                 if not events:
                     raise Exception(
                         f"Prompt exceeds max length of {prompt.max_length} characters even without history"
@@ -288,20 +292,27 @@ def render_task_prompt(

             return task_prompt
         else:
+            if prompt.messages is None:
+                return []
             task_messages = self._render_messages(
                 prompt.messages, context=context, events=events
             )
             task_prompt_length = self._get_messages_text_length(task_messages)
-            while task_prompt_length > prompt.max_length:
+            while (
+                prompt.max_length is not None and task_prompt_length > prompt.max_length
+            ):
                 if not events:
                     raise Exception(
                         f"Prompt exceeds max length of {prompt.max_length} characters even without history"
                     )
                 # Remove events from the beginning of the history until the prompt fits.
                 events = events[1:]
-                task_messages = self._render_messages(
-                    prompt.messages, context=context, events=events
-                )
+                if prompt.messages is not None:
+                    task_messages = self._render_messages(
+                        prompt.messages, context=context, events=events
+                    )
+                else:
+                    task_messages = []
                 task_prompt_length = self._get_messages_text_length(task_messages)
             return task_messages
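
Note: both trimming loops now tolerate `prompt.max_length is None` (no limit) and `prompt.messages is None`. A self-contained sketch of the guarded loop; `render` stands in for the Jinja rendering step and is an assumption, not the project's API:

from typing import Callable, List, Optional


def trim_to_max_length(
    events: List[dict],
    max_length: Optional[int],
    render: Callable[[List[dict]], str],
) -> str:
    prompt = render(events)
    # With max_length=None (no limit) the loop is skipped entirely.
    while max_length is not None and len(prompt) > max_length:
        if not events:
            raise Exception(
                f"Prompt exceeds max length of {max_length} characters even without history"
            )
        events = events[1:]  # drop the oldest event, then re-render
        prompt = render(events)
    return prompt


print(trim_to_max_length([{"e": 1}], None, lambda evs: "x" * (len(evs) + 5)))
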