Commit db997da

llm interface v2 supported in mistralai
1 parent 0710a27

File tree

2 files changed: +398 -24 lines

src/neo4j_graphrag/llm/mistralai_llm.py

Lines changed: 172 additions & 20 deletions
```diff
@@ -15,12 +15,12 @@
 from __future__ import annotations

 import os
-from typing import Any, Iterable, List, Optional, Union, cast
+from typing import Any, Iterable, List, Optional, Union, cast, overload

 from pydantic import ValidationError

 from neo4j_graphrag.exceptions import LLMGenerationError
-from neo4j_graphrag.llm.base import LLMInterface
+from neo4j_graphrag.llm.base import LLMInterface, LLMInterfaceV2
 from neo4j_graphrag.utils.rate_limit import (
     RateLimitHandler,
     rate_limit_handler,
```
```diff
@@ -37,14 +37,22 @@
 from neo4j_graphrag.types import LLMMessage

 try:
-    from mistralai import Messages, Mistral
+    from mistralai import (
+        Messages,
+        UserMessage,
+        AssistantMessage,
+        SystemMessage,
+        Mistral,
+    )
     from mistralai.models.sdkerror import SDKError
 except ImportError:
     Mistral = None  # type: ignore
     SDKError = None  # type: ignore


+# pylint: disable=redefined-builtin, arguments-differ, raise-missing-from, no-else-return
 class MistralAILLM(LLMInterface):
+
     def __init__(
         self,
         model_name: str,
```
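The widened import stays inside the existing try/except guard, so mistralai remains an optional dependency. A minimal sketch of how this guard pattern is typically completed at construction time; the `Mistral is None` check and the error wording below are illustrative assumptions, not this file's unchanged code:

```python
# Hypothetical completion of the optional-import guard shown above.
# The module still imports cleanly without mistralai installed; the
# failure is deferred to instantiation, where a clear hint is given.
try:
    from mistralai import Mistral
except ImportError:
    Mistral = None  # type: ignore


class MistralAILLM:
    def __init__(self, model_name: str) -> None:
        if Mistral is None:  # assumed check; actual wording may differ
            raise ImportError(
                "Could not import mistralai. "
                "Please install it with `pip install mistralai`."
            )
        self.model_name = model_name
```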
```diff
@@ -73,28 +81,67 @@ def __init__(
         api_key = os.getenv("MISTRAL_API_KEY", "")
         self.client = Mistral(api_key=api_key, **kwargs)

-    def get_messages(
+    # overloads for LLMInterface and LLMInterfaceV2 methods
+    @overload
+    def invoke(
         self,
         input: str,
         message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
         system_instruction: Optional[str] = None,
-    ) -> list[Messages]:
-        messages = []
-        if system_instruction:
-            messages.append(SystemMessage(content=system_instruction).model_dump())
-        if message_history:
-            if isinstance(message_history, MessageHistory):
-                message_history = message_history.messages
-            try:
-                MessageList(messages=cast(list[BaseMessage], message_history))
-            except ValidationError as e:
-                raise LLMGenerationError(e.errors()) from e
-            messages.extend(cast(Iterable[dict[str, Any]], message_history))
-        messages.append(UserMessage(content=input).model_dump())
-        return cast(list[Messages], messages)
+    ) -> LLMResponse: ...

-    @rate_limit_handler
+    @overload
     def invoke(
+        self,
+        input: List[LLMMessage],
+    ) -> LLMResponse: ...
+
+    @overload
+    async def ainvoke(
+        self,
+        input: str,
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> LLMResponse: ...
+
+    @overload
+    async def ainvoke(
+        self,
+        input: List[LLMMessage],
+    ) -> LLMResponse: ...
+
+    # switching logic between LLMInterface and LLMInterfaceV2
+    def invoke(
+        self,
+        input: Union[str, List[LLMMessage]],
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> LLMResponse:
+        if isinstance(input, str):
+            return self.__legacy_invoke(input, message_history, system_instruction)
+        elif isinstance(input, list):
+            return self.__brand_new_invoke(input)
+        else:
+            raise ValueError(f"Invalid input type for invoke method - {type(input)}")
+
+    async def ainvoke(
+        self,
+        input: Union[str, List[LLMMessage]],
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> LLMResponse:
+        if isinstance(input, str):
+            return await self.__legacy_ainvoke(
+                input, message_history, system_instruction
+            )
+        elif isinstance(input, list):
+            return await self.__brand_new_ainvoke(input)
+        else:
+            raise ValueError(f"Invalid input type for ainvoke method - {type(input)}")
+
+    # implementations
+    @rate_limit_handler
+    def __legacy_invoke(
         self,
         input: str,
         message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
```
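Taken together, the overloads plus the isinstance dispatch mean a single invoke/ainvoke surface serves both interfaces: a plain str (optionally with message_history and system_instruction) takes the legacy LLMInterface path, while a list of LLMMessage dicts takes the LLMInterfaceV2 path. A minimal usage sketch; the model name and prompts are illustrative, and MISTRAL_API_KEY is assumed to be set in the environment:

```python
from neo4j_graphrag.llm.mistralai_llm import MistralAILLM

llm = MistralAILLM(model_name="mistral-small-latest")  # illustrative model

# Legacy LLMInterface call: a bare string plus keyword extras.
v1 = llm.invoke(
    "What is a knowledge graph?",
    system_instruction="Answer in one sentence.",
)

# LLMInterfaceV2 call: the whole conversation as a list of LLMMessage dicts,
# routed through the isinstance(input, list) branch of the dispatcher.
v2 = llm.invoke(
    [
        {"role": "system", "content": "Answer in one sentence."},
        {"role": "user", "content": "What is a knowledge graph?"},
    ]
)

print(v1.content)
print(v2.content)
```

Any other input type (a dict, for instance) raises ValueError, so the dispatch stays explicit rather than silently coercing.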
```diff
@@ -132,8 +179,40 @@ def invoke(
         except SDKError as e:
             raise LLMGenerationError(e)

+    def __brand_new_invoke(
+        self,
+        input: List[LLMMessage],
+    ) -> LLMResponse:
+        """Sends a list of messages to the Mistral chat completion model
+        and returns the response's content.
+
+        Args:
+            input (List[LLMMessage]): Messages sent to the LLM.
+
+        Returns:
+            LLMResponse: The response from MistralAI.
+
+        Raises:
+            LLMGenerationError: If anything goes wrong.
+        """
+        try:
+            messages = self.get_brand_new_messages(input)
+            response = self.client.chat.complete(
+                model=self.model_name,
+                messages=messages,
+                **self.model_params,
+            )
+            content: str = ""
+            if response and response.choices:
+                possible_content = response.choices[0].message.content
+                if isinstance(possible_content, str):
+                    content = possible_content
+            return LLMResponse(content=content)
+        except SDKError as e:
+            raise LLMGenerationError(e)
+
     @async_rate_limit_handler
-    async def ainvoke(
+    async def __legacy_ainvoke(
         self,
         input: str,
         message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
```
```diff
@@ -171,3 +250,76 @@ async def ainvoke(
             return LLMResponse(content=content)
         except SDKError as e:
             raise LLMGenerationError(e)
+
+    async def __brand_new_ainvoke(
+        self,
+        input: List[LLMMessage],
+    ) -> LLMResponse:
+        """Asynchronously sends a list of messages to the MistralAI chat
+        completion model and returns the response's content.
+
+        Args:
+            input (List[LLMMessage]): Messages sent to the LLM.
+
+        Returns:
+            LLMResponse: The response from MistralAI.
+
+        Raises:
+            LLMGenerationError: If anything goes wrong.
+        """
+        try:
+            messages = self.get_brand_new_messages(input)
+            response = await self.client.chat.complete_async(
+                model=self.model_name,
+                messages=messages,
+                **self.model_params,
+            )
+            content: str = ""
+            if response and response.choices:
+                possible_content = response.choices[0].message.content
+                if isinstance(possible_content, str):
+                    content = possible_content
+            return LLMResponse(content=content)
+        except SDKError as e:
+            raise LLMGenerationError(e)
+
+    # subsidiary methods
+    def get_messages(
+        self,
+        input: str,
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> list[Messages]:
+        """Constructs the message list for the Mistral chat completion model."""
+        messages = []
+        if system_instruction:
+            messages.append(SystemMessage(content=system_instruction).model_dump())
+        if message_history:
+            if isinstance(message_history, MessageHistory):
+                message_history = message_history.messages
+            try:
+                MessageList(messages=cast(list[BaseMessage], message_history))
+            except ValidationError as e:
+                raise LLMGenerationError(e.errors()) from e
+            messages.extend(cast(Iterable[dict[str, Any]], message_history))
+        messages.append(UserMessage(content=input).model_dump())
+        return cast(list[Messages], messages)
+
+    def get_brand_new_messages(
+        self,
+        input: list[LLMMessage],
+    ) -> list[Messages]:
+        """Constructs the message list for the Mistral chat completion model."""
+        messages: list[Messages] = []
+        for m in input:
+            if m["role"] == "system":
+                messages.append(SystemMessage(content=m["content"]))
+                continue
+            if m["role"] == "user":
+                messages.append(UserMessage(content=m["content"]))
+                continue
+            if m["role"] == "assistant":
+                messages.append(AssistantMessage(content=m["content"]))
+                continue
+            raise ValueError(f"Unknown role: {m['role']}")
+        return messages
```
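get_brand_new_messages is the V2 counterpart of the dict-based get_messages: it maps each LLMMessage role onto the matching mistralai message class instead of appending model_dump() dicts. An illustrative round trip; the conversation and model name are hypothetical:

```python
from neo4j_graphrag.llm.mistralai_llm import MistralAILLM

llm = MistralAILLM(model_name="mistral-small-latest")  # illustrative model

conversation = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello."},
    {"role": "user", "content": "Name one graph database."},
]

messages = llm.get_brand_new_messages(conversation)
# -> [SystemMessage(...), UserMessage(...), AssistantMessage(...), UserMessage(...)]

# Unlike the legacy helper, there is no pydantic MessageList validation here:
# an unsupported role such as "tool" fails fast with
# ValueError("Unknown role: tool").
```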
