Commit 053f70f

LLM interface v2 support added to Anthropic
1 parent 8c324ab commit 053f70f

File tree

2 files changed (+367 −18 lines)


src/neo4j_graphrag/llm/anthropic_llm.py

Lines changed: 152 additions & 18 deletions
@@ -13,12 +13,12 @@
 # limitations under the License.
 from __future__ import annotations

-from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union, cast, overload

 from pydantic import ValidationError

 from neo4j_graphrag.exceptions import LLMGenerationError
-from neo4j_graphrag.llm.base import LLMInterface
+from neo4j_graphrag.llm.base import LLMInterface, LLMInterfaceV2
 from neo4j_graphrag.utils.rate_limit import (
     RateLimitHandler,
     rate_limit_handler,
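
The newly imported typing.overload is what lets one runtime method expose two typed call signatures, as the later hunks do for invoke and ainvoke. A minimal, self-contained sketch of the pattern (illustrative names only, not from this codebase):

    from typing import List, Union, overload

    class Greeter:
        # Signatures seen by type checkers; the "..." bodies are never executed.
        @overload
        def greet(self, input: str) -> str: ...
        @overload
        def greet(self, input: List[str]) -> str: ...

        # The single runtime implementation dispatches on the argument type.
        def greet(self, input: Union[str, List[str]]) -> str:
            if isinstance(input, str):
                return f"hello, {input}"
            return "hello, " + ", ".join(input)

    assert Greeter().greet("neo") == Greeter().greet(["neo"])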
@@ -35,9 +35,11 @@

 if TYPE_CHECKING:
     from anthropic.types.message_param import MessageParam
+    from anthropic import NotGiven


-class AnthropicLLM(LLMInterface):
+# pylint: disable=redefined-builtin, arguments-differ, raise-missing-from, no-else-return
+class AnthropicLLM(LLMInterface, LLMInterfaceV2):
     """Interface for large language models on Anthropic

     Args:
@@ -82,25 +84,67 @@ def __init__(
         self.client = anthropic.Anthropic(**kwargs)
         self.async_client = anthropic.AsyncAnthropic(**kwargs)

-    def get_messages(
+    # overloads for LLMInterface and LLMInterfaceV2 methods
+    @overload
+    def invoke(
         self,
         input: str,
         message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
-    ) -> Iterable[MessageParam]:
-        messages: list[dict[str, str]] = []
-        if message_history:
-            if isinstance(message_history, MessageHistory):
-                message_history = message_history.messages
-            try:
-                MessageList(messages=cast(list[BaseMessage], message_history))
-            except ValidationError as e:
-                raise LLMGenerationError(e.errors()) from e
-            messages.extend(cast(Iterable[dict[str, Any]], message_history))
-        messages.append(UserMessage(content=input).model_dump())
-        return messages  # type: ignore
+        system_instruction: Optional[str] = None,
+    ) -> LLMResponse: ...

-    @rate_limit_handler
+    @overload
+    def invoke(
+        self,
+        input: List[LLMMessage],
+    ) -> LLMResponse: ...
+
+    @overload
+    async def ainvoke(
+        self,
+        input: str,
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> LLMResponse: ...
+
+    @overload
+    async def ainvoke(
+        self,
+        input: List[LLMMessage],
+    ) -> LLMResponse: ...
+
+    # switching logic between LLMInterface and LLMInterfaceV2
     def invoke(
+        self,
+        input: Union[str, List[LLMMessage]],
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> LLMResponse:
+        if isinstance(input, str):
+            return self.__legacy_invoke(input, message_history, system_instruction)
+        elif isinstance(input, list):
+            return self.__brand_new_invoke(input)
+        else:
+            raise ValueError(f"Invalid input type for invoke method - {type(input)}")
+
+    async def ainvoke(
+        self,
+        input: Union[str, List[LLMMessage]],
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> LLMResponse:
+        if isinstance(input, str):
+            return await self.__legacy_ainvoke(
+                input, message_history, system_instruction
+            )
+        elif isinstance(input, list):
+            return await self.__brand_new_ainvoke(input)
+        else:
+            raise ValueError(f"Invalid input type for ainvoke method - {type(input)}")
+
+    # implementations
+    @rate_limit_handler
+    def __legacy_invoke(
         self,
         input: str,
         message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
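
With this dispatcher, one method serves both interfaces. A minimal usage sketch (assuming LLMMessage is a role/content mapping, as the subscripting in this commit suggests; the model name and max_tokens value are placeholders):

    from neo4j_graphrag.llm.anthropic_llm import AnthropicLLM

    llm = AnthropicLLM(
        model_name="claude-3-5-sonnet-20240620",  # placeholder model
        model_params={"max_tokens": 1000},
    )

    # LLMInterface (legacy) style: a plain string, system prompt passed separately.
    v1 = llm.invoke("Tell me about Neo4j.", system_instruction="Be concise.")

    # LLMInterfaceV2 style: a list of messages, system prompt inline.
    v2 = llm.invoke([
        {"role": "system", "content": "Be concise."},
        {"role": "user", "content": "Tell me about Neo4j."},
    ])
    print(v1.content)
    print(v2.content)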
@@ -136,8 +180,29 @@ def invoke(
         except self.anthropic.APIError as e:
             raise LLMGenerationError(e)

+    def __brand_new_invoke(
+        self,
+        input: List[LLMMessage],
+    ) -> LLMResponse:
+        try:
+            system_instruction, messages = self.get_brand_new_messages(input)
+            response = self.client.messages.create(
+                model=self.model_name,
+                system=system_instruction,
+                messages=messages,
+                **self.model_params,
+            )
+            response_content = response.content
+            if response_content and len(response_content) > 0:
+                text = response_content[0].text
+            else:
+                raise LLMGenerationError("LLM returned empty response.")
+            return LLMResponse(content=text)
+        except self.anthropic.APIError as e:
+            raise LLMGenerationError(e)
+
     @async_rate_limit_handler
-    async def ainvoke(
+    async def __legacy_ainvoke(
         self,
         input: str,
         message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
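
The system=system_instruction argument leans on the Anthropic SDK's NOT_GIVEN sentinel (produced by get_brand_new_messages, below): passing NOT_GIVEN behaves as if the parameter were omitted. A minimal sketch of the underlying SDK call, with placeholder model and token limit:

    import anthropic

    client = anthropic.Anthropic()  # API key read from ANTHROPIC_API_KEY

    response = client.messages.create(
        model="claude-3-5-sonnet-20240620",  # placeholder model
        max_tokens=1000,
        system=anthropic.NOT_GIVEN,  # sentinel: equivalent to omitting `system`
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(response.content[0].text)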
@@ -172,3 +237,72 @@ async def ainvoke(
             return LLMResponse(content=text)
         except self.anthropic.APIError as e:
             raise LLMGenerationError(e)
+
+    async def __brand_new_ainvoke(
+        self,
+        input: List[LLMMessage],
+    ) -> LLMResponse:
+        """Asynchronously sends messages to the LLM and returns a response.
+
+        Args:
+            input (List[LLMMessage]): The messages to send to the LLM.
+
+        Returns:
+            LLMResponse: The response from the LLM.
+        """
+        try:
+            system_instruction, messages = self.get_brand_new_messages(input)
+            response = await self.async_client.messages.create(
+                model=self.model_name,
+                system=system_instruction,
+                messages=messages,
+                **self.model_params,
+            )
+            response_content = response.content
+            if response_content and len(response_content) > 0:
+                text = response_content[0].text
+            else:
+                raise LLMGenerationError("LLM returned empty response.")
+            return LLMResponse(content=text)
+        except self.anthropic.APIError as e:
+            raise LLMGenerationError(e)
+
+    # subsidiary methods
+    def get_messages(
+        self,
+        input: str,
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
+    ) -> Iterable[MessageParam]:
+        """Constructs the message list for the LLM from the input and message history."""
+        messages: list[dict[str, str]] = []
+        if message_history:
+            if isinstance(message_history, MessageHistory):
+                message_history = message_history.messages
+            try:
+                MessageList(messages=cast(list[BaseMessage], message_history))
+            except ValidationError as e:
+                raise LLMGenerationError(e.errors()) from e
+            messages.extend(cast(Iterable[dict[str, Any]], message_history))
+        messages.append(UserMessage(content=input).model_dump())
+        return messages  # type: ignore
+
+    def get_brand_new_messages(
+        self,
+        input: list[LLMMessage],
+    ) -> tuple[Union[str, NotGiven], Iterable[MessageParam]]:
+        """Constructs the message list for the LLM from the input."""
+        messages: list[MessageParam] = []
+        system_instruction: Union[str, NotGiven] = self.anthropic.NOT_GIVEN
+        for i in input:
+            if i["role"] == "system":
+                system_instruction = i["content"]
+            else:
+                if i["role"] not in ("user", "assistant"):
+                    raise ValueError(f"Unknown role: {i['role']}")
+                messages.append(
+                    self.anthropic.types.MessageParam(
+                        role=i["role"],
+                        content=i["content"],
+                    )
+                )
+        return system_instruction, messages
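
The role-splitting in get_brand_new_messages is easy to exercise in isolation. A standalone sketch of the same idea (hypothetical helper name; plain dicts stand in for LLMMessage):

    from typing import List, Optional, Tuple

    def split_system(messages: List[dict]) -> Tuple[Optional[str], List[dict]]:
        """Pull the system message out; Anthropic takes it as a separate parameter."""
        system: Optional[str] = None
        chat: List[dict] = []
        for m in messages:
            if m["role"] == "system":
                system = m["content"]  # last system message wins, as in the diff
            elif m["role"] in ("user", "assistant"):
                chat.append({"role": m["role"], "content": m["content"]})
            else:
                raise ValueError(f"Unknown role: {m['role']}")
        return system, chat

    system, chat = split_system([
        {"role": "system", "content": "Be brief."},
        {"role": "user", "content": "Hi"},
    ])
    assert system == "Be brief." and chat == [{"role": "user", "content": "Hi"}]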
