Commit 0710a27

committed
llminterfacev2 support added to ollama
1 parent bdae480 commit 0710a27

File tree

2 files changed: +395 −23 lines changed


src/neo4j_graphrag/llm/ollama_llm.py

Lines changed: 156 additions & 22 deletions
@@ -12,23 +12,36 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from __future__ import annotations
 
+# built-in dependencies
+from __future__ import annotations
 import warnings
-from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Sequence, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Union,
+    cast,
+    overload,
+)
 
+# 3rd-party dependencies
 from pydantic import ValidationError
 
+# project dependencies
 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.message_history import MessageHistory
 from neo4j_graphrag.types import LLMMessage
-
-from .base import LLMInterface
 from neo4j_graphrag.utils.rate_limit import (
     RateLimitHandler,
     rate_limit_handler,
     async_rate_limit_handler,
 )
+
+from .base import LLMInterface, LLMInterfaceV2
 from .types import (
     BaseMessage,
     LLMResponse,
@@ -40,8 +53,12 @@
 if TYPE_CHECKING:
     from ollama import Message
 
+# pylint: disable=redefined-builtin, arguments-differ, raise-missing-from, no-else-return
+
+
+class OllamaLLM(LLMInterface, LLMInterfaceV2):
+    """LLM wrapper for Ollama models."""
 
-class OllamaLLM(LLMInterface):
     def __init__(
         self,
         model_name: str,
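With the class now deriving from both LLMInterface and LLMInterfaceV2, existing code that type-checks against the legacy interface keeps working while new callers can target the V2 contract. A minimal sketch of what that means for callers, assuming the constructor accepts a `model_params` dict as in the base interface (the model name and params shown are hypothetical):

```python
from neo4j_graphrag.llm.base import LLMInterface, LLMInterfaceV2
from neo4j_graphrag.llm.ollama_llm import OllamaLLM

# Hypothetical model name; any model available on the local Ollama server works.
llm = OllamaLLM(model_name="llama3.1", model_params={"temperature": 0.0})

# The same instance satisfies both the legacy and the new interface.
assert isinstance(llm, LLMInterface)
assert isinstance(llm, LLMInterfaceV2)
```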
@@ -78,28 +95,66 @@ def __init__(
         )
         self.model_params = {"options": self.model_params}
 
-    def get_messages(
+    # overloads for LLMInterface and LLMInterfaceV2 methods
+    @overload
+    def invoke(
         self,
         input: str,
         message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
         system_instruction: Optional[str] = None,
-    ) -> Sequence[Message]:
-        messages = []
-        if system_instruction:
-            messages.append(SystemMessage(content=system_instruction).model_dump())
-        if message_history:
-            if isinstance(message_history, MessageHistory):
-                message_history = message_history.messages
-            try:
-                MessageList(messages=cast(list[BaseMessage], message_history))
-            except ValidationError as e:
-                raise LLMGenerationError(e.errors()) from e
-            messages.extend(cast(Iterable[dict[str, Any]], message_history))
-        messages.append(UserMessage(content=input).model_dump())
-        return messages  # type: ignore
+    ) -> LLMResponse: ...
 
-    @rate_limit_handler
+    @overload
     def invoke(
+        self,
+        input: List[LLMMessage],
+    ) -> LLMResponse: ...
+
+    @overload
+    async def ainvoke(
+        self,
+        input: str,
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> LLMResponse: ...
+
+    @overload
+    async def ainvoke(
+        self,
+        input: List[LLMMessage],
+    ) -> LLMResponse: ...
+
+    # switching logic between LLMInterface and LLMInterfaceV2
+    def invoke(
+        self,
+        input: Union[str, List[LLMMessage]],
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> LLMResponse:
+        if isinstance(input, str):
+            return self.__legacy_invoke(input, message_history, system_instruction)
+        elif isinstance(input, list):
+            return self.__brand_new_invoke(input)
+        else:
+            raise ValueError(f"Invalid input type for invoke method - {type(input)}")
+
+    async def ainvoke(
+        self,
+        input: Union[str, List[LLMMessage]],
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> LLMResponse:
+        if isinstance(input, str):
+            return await self.__legacy_ainvoke(
+                input, message_history, system_instruction
+            )
+        elif isinstance(input, list):
+            return await self.__brand_new_ainvoke(input)
+        else:
+            raise ValueError(f"Invalid input type for ainvoke method - {type(input)}")
+
+    @rate_limit_handler
+    def __legacy_invoke(
         self,
         input: str,
         message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
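The `@overload` declarations above only describe the two call shapes for static type checkers; the undecorated `invoke`/`ainvoke` definitions do the runtime dispatch on the type of `input`. A hedged sketch of the two call paths, assuming `LLMMessage` is a dict-like structure with `role` and `content` keys (as `get_brand_new_messages` implies) and using a hypothetical model name:

```python
from neo4j_graphrag.llm.ollama_llm import OllamaLLM
from neo4j_graphrag.types import LLMMessage

llm = OllamaLLM(model_name="llama3.1")  # hypothetical model name

# Legacy LLMInterface path: a plain prompt string, routed to __legacy_invoke.
legacy_response = llm.invoke(
    "What is a knowledge graph?",
    system_instruction="Answer in one sentence.",
)

# LLMInterfaceV2 path: a list of LLMMessage dicts, routed to __brand_new_invoke.
messages: list[LLMMessage] = [
    {"role": "system", "content": "Answer in one sentence."},
    {"role": "user", "content": "What is a knowledge graph?"},
]
v2_response = llm.invoke(messages)

print(legacy_response.content)
print(v2_response.content)
```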
@@ -129,8 +184,31 @@ def invoke(
         except self.ollama.ResponseError as e:
             raise LLMGenerationError(e)
 
+    def __brand_new_invoke(
+        self,
+        input: List[LLMMessage],
+    ) -> LLMResponse:
+        """Sends a list of messages to the LLM and returns a response.
+
+        Args:
+            input (List[LLMMessage]): The messages to send to the LLM.
+
+        Returns:
+            LLMResponse: The response from the LLM.
+        """
+        try:
+            response = self.client.chat(
+                model=self.model_name,
+                messages=self.get_brand_new_messages(input),
+                **self.model_params,
+            )
+            content = response.message.content or ""
+            return LLMResponse(content=content)
+        except self.ollama.ResponseError as e:
+            raise LLMGenerationError(e)
+
     @async_rate_limit_handler
-    async def ainvoke(
+    async def __legacy_ainvoke(
         self,
         input: str,
         message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
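The async side mirrors the sync dispatch: a string routes through the rate-limited `__legacy_ainvoke`, a message list through `__brand_new_ainvoke` and the Ollama `AsyncClient`. A small usage sketch, again assuming a hypothetical model name and dict-shaped `LLMMessage` values:

```python
import asyncio

from neo4j_graphrag.llm.ollama_llm import OllamaLLM


async def main() -> None:
    llm = OllamaLLM(model_name="llama3.1")  # hypothetical model name
    # V2-style call: the list input is routed to __brand_new_ainvoke,
    # which awaits the chat() call on the async client.
    response = await llm.ainvoke(
        [{"role": "user", "content": "Name one graph database."}]
    )
    print(response.content)


asyncio.run(main())
```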
@@ -163,3 +241,59 @@ async def ainvoke(
             return LLMResponse(content=content)
         except self.ollama.ResponseError as e:
             raise LLMGenerationError(e)
+
+    async def __brand_new_ainvoke(
+        self,
+        input: List[LLMMessage],
+    ) -> LLMResponse:
+        """Asynchronously sends a list of messages to the Ollama chat
+        model and returns the response's content.
+
+        Args:
+            input (List[LLMMessage]): Messages sent to the LLM.
+
+        Returns:
+            LLMResponse: The response from Ollama.
+
+        Raises:
+            LLMGenerationError: If anything goes wrong.
+        """
+        try:
+            response = await self.async_client.chat(
+                model=self.model_name,
+                messages=self.get_brand_new_messages(input),
+                **self.model_params,
+            )
+            content = response.message.content or ""
+            return LLMResponse(content=content)
+        except self.ollama.ResponseError as e:
+            raise LLMGenerationError(e)
+
+    # subsidiary methods
+    def get_messages(
+        self,
+        input: str,
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> Sequence[Message]:
+        """Constructs the message list for the Ollama chat API."""
+        messages = []
+        if system_instruction:
+            messages.append(SystemMessage(content=system_instruction).model_dump())
+        if message_history:
+            if isinstance(message_history, MessageHistory):
+                message_history = message_history.messages
+            try:
+                MessageList(messages=cast(list[BaseMessage], message_history))
+            except ValidationError as e:
+                raise LLMGenerationError(e.errors()) from e
+            messages.extend(cast(Iterable[dict[str, Any]], message_history))
+        messages.append(UserMessage(content=input).model_dump())
+        return messages  # type: ignore
+
+    def get_brand_new_messages(
+        self,
+        input: list[LLMMessage],
+    ) -> Sequence[Message]:
+        """Constructs the message list for the Ollama chat API."""
+        return [self.ollama.Message(**i) for i in input]
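For the V2 path, `get_brand_new_messages` simply unpacks each `LLMMessage` dict into an `ollama.Message`; unlike `get_messages`, no system/user wrapping or history validation is applied, so callers supply fully formed role/content pairs. A standalone sketch of the same conversion, assuming the `ollama` package is installed (the message contents are made up):

```python
from ollama import Message

raw_messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Define RAG."},
]

# Equivalent to get_brand_new_messages: each dict becomes an ollama.Message.
converted = [Message(**m) for m in raw_messages]
print(converted)
```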
