Skip to content

Commit 8c324ab

Browse files
committed
llm interface v2 support added to cohere
1 parent db997da commit 8c324ab

File tree

2 files changed

+256
-21
lines changed

2 files changed

+256
-21
lines changed

src/neo4j_graphrag/llm/cohere_llm.py

Lines changed: 145 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,17 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
from __future__ import annotations
1615

17-
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union, cast
16+
# built-in dependencies
17+
from __future__ import annotations
18+
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union, cast, overload
1819

20+
# 3rd party dependencies
1921
from pydantic import ValidationError
2022

23+
# project dependencies
2124
from neo4j_graphrag.exceptions import LLMGenerationError
22-
from neo4j_graphrag.llm.base import LLMInterface
25+
from neo4j_graphrag.llm.base import LLMInterface, LLMInterfaceV2
2326
from neo4j_graphrag.utils.rate_limit import (
2427
RateLimitHandler,
2528
rate_limit_handler,
@@ -39,7 +42,8 @@
3942
from cohere import ChatMessages
4043

4144

42-
class CohereLLM(LLMInterface):
45+
# pylint: disable=redefined-builtin, arguments-differ, raise-missing-from, no-else-return
46+
class CohereLLM(LLMInterface, LLMInterfaceV2):
4347
"""Interface for large language models on the Cohere platform
4448
4549
Args:
@@ -82,28 +86,67 @@ def __init__(
8286
self.client = cohere.ClientV2(**kwargs)
8387
self.async_client = cohere.AsyncClientV2(**kwargs)
8488

85-
def get_messages(
89+
# overloads for LLMInterface and LLMInterfaceV2 methods
90+
@overload
91+
def invoke(
8692
self,
8793
input: str,
8894
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
8995
system_instruction: Optional[str] = None,
90-
) -> ChatMessages:
91-
messages = []
92-
if system_instruction:
93-
messages.append(SystemMessage(content=system_instruction).model_dump())
94-
if message_history:
95-
if isinstance(message_history, MessageHistory):
96-
message_history = message_history.messages
97-
try:
98-
MessageList(messages=cast(list[BaseMessage], message_history))
99-
except ValidationError as e:
100-
raise LLMGenerationError(e.errors()) from e
101-
messages.extend(cast(Iterable[dict[str, Any]], message_history))
102-
messages.append(UserMessage(content=input).model_dump())
103-
return messages # type: ignore
96+
) -> LLMResponse: ...
10497

105-
@rate_limit_handler
98+
@overload
99+
def invoke(
100+
self,
101+
input: List[LLMMessage],
102+
) -> LLMResponse: ...
103+
104+
@overload
105+
async def ainvoke(
106+
self,
107+
input: str,
108+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
109+
system_instruction: Optional[str] = None,
110+
) -> LLMResponse: ...
111+
112+
@overload
113+
async def ainvoke(
114+
self,
115+
input: List[LLMMessage],
116+
) -> LLMResponse: ...
117+
118+
# switching logic between LLMInterface and LLMInterfaceV2
106119
def invoke(
120+
self,
121+
input: Union[str, List[LLMMessage]],
122+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
123+
system_instruction: Optional[str] = None,
124+
) -> LLMResponse:
125+
if isinstance(input, str):
126+
return self.__legacy_invoke(input, message_history, system_instruction)
127+
elif isinstance(input, list):
128+
return self.__brand_new_invoke(input)
129+
else:
130+
raise ValueError(f"Invalid input type for invoke method - {type(input)}")
131+
132+
async def ainvoke(
133+
self,
134+
input: Union[str, List[LLMMessage]],
135+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
136+
system_instruction: Optional[str] = None,
137+
) -> LLMResponse:
138+
if isinstance(input, str):
139+
return await self.__legacy_ainvoke(
140+
input, message_history, system_instruction
141+
)
142+
elif isinstance(input, list):
143+
return await self.__brand_new_ainvoke(input)
144+
else:
145+
raise ValueError(f"Invalid input type for ainvoke method - {type(input)}")
146+
147+
# implementations
148+
@rate_limit_handler
149+
def __legacy_invoke(
107150
self,
108151
input: str,
109152
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
@@ -134,8 +177,32 @@ def invoke(
134177
content=res.message.content[0].text if res.message.content else "",
135178
)
136179

180+
def __brand_new_invoke(
181+
self,
182+
input: List[LLMMessage],
183+
) -> LLMResponse:
184+
"""Sends text to the LLM and returns a response.
185+
186+
Args:
187+
input (str): The text to send to the LLM.
188+
189+
Returns:
190+
LLMResponse: The response from the LLM.
191+
"""
192+
try:
193+
messages = self.get_brand_new_messages(input)
194+
res = self.client.chat(
195+
messages=messages,
196+
model=self.model_name,
197+
)
198+
except self.cohere_api_error as e:
199+
raise LLMGenerationError("Error calling cohere") from e
200+
return LLMResponse(
201+
content=res.message.content[0].text if res.message.content else "",
202+
)
203+
137204
@async_rate_limit_handler
138-
async def ainvoke(
205+
async def __legacy_ainvoke(
139206
self,
140207
input: str,
141208
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
@@ -165,3 +232,60 @@ async def ainvoke(
165232
return LLMResponse(
166233
content=res.message.content[0].text if res.message.content else "",
167234
)
235+
236+
async def __brand_new_ainvoke(
237+
self,
238+
input: List[LLMMessage],
239+
) -> LLMResponse:
240+
try:
241+
messages = self.get_brand_new_messages(input)
242+
res = await self.async_client.chat(
243+
messages=messages,
244+
model=self.model_name,
245+
)
246+
except self.cohere_api_error as e:
247+
raise LLMGenerationError("Error calling cohere") from e
248+
return LLMResponse(
249+
content=res.message.content[0].text if res.message.content else "",
250+
)
251+
252+
# subsidiary methods
253+
def get_messages(
254+
self,
255+
input: str,
256+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
257+
system_instruction: Optional[str] = None,
258+
) -> ChatMessages:
259+
"""Converts input and message history to ChatMessages for Cohere."""
260+
messages = []
261+
if system_instruction:
262+
messages.append(SystemMessage(content=system_instruction).model_dump())
263+
if message_history:
264+
if isinstance(message_history, MessageHistory):
265+
message_history = message_history.messages
266+
try:
267+
MessageList(messages=cast(list[BaseMessage], message_history))
268+
except ValidationError as e:
269+
raise LLMGenerationError(e.errors()) from e
270+
messages.extend(cast(Iterable[dict[str, Any]], message_history))
271+
messages.append(UserMessage(content=input).model_dump())
272+
return messages # type: ignore
273+
274+
def get_brand_new_messages(
275+
self,
276+
input: list[LLMMessage],
277+
) -> ChatMessages:
278+
"""Converts a list of LLMMessage to ChatMessages for Cohere."""
279+
messages: ChatMessages = []
280+
for i in input:
281+
if i["role"] == "system":
282+
messages.append(self.cohere.SystemChatMessageV2(content=i["content"]))
283+
elif i["role"] == "user":
284+
messages.append(self.cohere.UserChatMessageV2(content=i["content"]))
285+
elif i["role"] == "assistant":
286+
messages.append(
287+
self.cohere.AssistantChatMessageV2(content=i["content"])
288+
)
289+
else:
290+
raise ValueError(f"Unknown role: {i['role']}")
291+
return messages

tests/unit/llm/test_cohere_llm.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,114 @@ async def test_cohere_llm_failed_async(mock_cohere: Mock) -> None:
152152
with pytest.raises(LLMGenerationError) as excinfo:
153153
await llm.ainvoke("my text")
154154
assert "ApiError" in str(excinfo)
155+
156+
157+
# V2 Interface Tests
158+
159+
160+
def test_cohere_llm_invoke_v2_happy_path(mock_cohere: Mock) -> None:
161+
"""Test V2 interface invoke method with List[LLMMessage] input."""
162+
chat_response_mock = MagicMock()
163+
chat_response_mock.message.content = [MagicMock(text="cohere v2 response text")]
164+
mock_cohere.ClientV2.return_value.chat.return_value = chat_response_mock
165+
166+
# Mock Cohere message types
167+
mock_cohere.SystemChatMessageV2 = MagicMock()
168+
mock_cohere.UserChatMessageV2 = MagicMock()
169+
mock_cohere.AssistantChatMessageV2 = MagicMock()
170+
171+
messages = [
172+
{"role": "system", "content": "You are a helpful assistant."},
173+
{"role": "user", "content": "What is the capital of France?"},
174+
]
175+
176+
llm = CohereLLM(model_name="something")
177+
response = llm.invoke(messages)
178+
179+
assert isinstance(response, LLMResponse)
180+
assert response.content == "cohere v2 response text"
181+
182+
# Verify the client was called correctly
183+
mock_cohere.ClientV2.return_value.chat.assert_called_once()
184+
call_args = mock_cohere.ClientV2.return_value.chat.call_args[1]
185+
assert call_args["model"] == "something"
186+
187+
188+
@pytest.mark.asyncio
189+
async def test_cohere_llm_ainvoke_v2_happy_path(mock_cohere: Mock) -> None:
190+
"""Test V2 interface async invoke method with List[LLMMessage] input."""
191+
chat_response_mock = MagicMock()
192+
chat_response_mock.message.content = [
193+
MagicMock(text="cohere v2 async response text")
194+
]
195+
mock_cohere.AsyncClientV2.return_value.chat = AsyncMock(
196+
return_value=chat_response_mock
197+
)
198+
199+
# Mock Cohere message types
200+
mock_cohere.SystemChatMessageV2 = MagicMock()
201+
mock_cohere.UserChatMessageV2 = MagicMock()
202+
mock_cohere.AssistantChatMessageV2 = MagicMock()
203+
204+
messages = [
205+
{"role": "system", "content": "You are a helpful assistant."},
206+
{"role": "user", "content": "What is the capital of France?"},
207+
]
208+
209+
llm = CohereLLM(model_name="something")
210+
response = await llm.ainvoke(messages)
211+
212+
assert isinstance(response, LLMResponse)
213+
assert response.content == "cohere v2 async response text"
214+
215+
# Verify the async client was called correctly
216+
mock_cohere.AsyncClientV2.return_value.chat.assert_awaited_once()
217+
218+
219+
def test_cohere_llm_invoke_v2_validation_error(mock_cohere: Mock) -> None:
220+
"""Test V2 interface invoke with invalid message role raises error."""
221+
chat_response_mock = MagicMock()
222+
chat_response_mock.message.content = [MagicMock(text="should not get here")]
223+
mock_cohere.ClientV2.return_value.chat.return_value = chat_response_mock
224+
225+
messages = [
226+
{"role": "invalid_role", "content": "This should fail."},
227+
]
228+
229+
llm = CohereLLM(model_name="something")
230+
231+
with pytest.raises(ValueError) as exc_info:
232+
llm.invoke(messages)
233+
assert "Unknown role: invalid_role" in str(exc_info.value)
234+
235+
236+
def test_cohere_llm_get_brand_new_messages_all_roles(mock_cohere: Mock) -> None:
237+
"""Test get_brand_new_messages method handles all message roles correctly."""
238+
# Mock Cohere message types
239+
mock_system_msg = MagicMock()
240+
mock_user_msg = MagicMock()
241+
mock_assistant_msg = MagicMock()
242+
243+
mock_cohere.SystemChatMessageV2.return_value = mock_system_msg
244+
mock_cohere.UserChatMessageV2.return_value = mock_user_msg
245+
mock_cohere.AssistantChatMessageV2.return_value = mock_assistant_msg
246+
247+
messages = [
248+
{"role": "system", "content": "You are a helpful assistant."},
249+
{"role": "user", "content": "Hello"},
250+
{"role": "assistant", "content": "Hi there!"},
251+
{"role": "user", "content": "How are you?"},
252+
]
253+
254+
llm = CohereLLM(model_name="something")
255+
result_messages = llm.get_brand_new_messages(messages)
256+
257+
# Verify the correct number of messages are returned
258+
assert len(result_messages) == 4
259+
260+
# Verify the correct Cohere message constructors were called
261+
mock_cohere.SystemChatMessageV2.assert_called_once_with(
262+
content="You are a helpful assistant."
263+
)
264+
assert mock_cohere.UserChatMessageV2.call_count == 2
265+
mock_cohere.AssistantChatMessageV2.assert_called_once_with(content="Hi there!")

0 commit comments

Comments
 (0)