Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
"""LaunchDarkly AI SDK - LangChain Provider.
"""LaunchDarkly AI SDK - LangChain Connector."""

This package provides LangChain integration for the LaunchDarkly Server-Side AI SDK,
"""

from ldai_langchain.langchain_provider import LangChainProvider
from ldai_langchain.langchain_provider import LangChainConnector

__version__ = "0.1.0"

__all__ = [
'__version__',
'LangChainProvider',
'LangChainConnector',
]
Original file line number Diff line number Diff line change
@@ -1,41 +1,56 @@
"""LangChain implementation of AIProvider for LaunchDarkly AI SDK."""
"""LangChain connector for LaunchDarkly AI SDK."""

from typing import Any, Dict, List, Optional, Union

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from ldai import LDMessage, log
from ldai.models import AIConfigKind
from ldai.providers import AIProvider
from ldai.providers import AIConnector
from ldai.providers.types import ChatResponse, LDAIMetrics, StructuredResponse
from ldai.tracker import TokenUsage


class LangChainProvider(AIProvider):
class LangChainConnector(AIConnector):
"""
LangChain implementation of AIProvider.

This provider integrates LangChain models with LaunchDarkly's tracking capabilities.
LangChain connector for the LaunchDarkly AI SDK.

Can be used in two ways:
- Transparently via ExecutorFactory (pass ``default_connector='langchain'`` to
``create_model()`` / ``create_chat()``).
- Directly for full control: instantiate with a ``BaseChatModel``, then call
``invoke_model()`` yourself and use the static convenience methods
(``get_ai_metrics_from_response``, ``convert_messages_to_langchain``,
``map_provider``, ``create_langchain_model``).
"""

def __init__(self, llm: BaseChatModel):
def __init__(self, llm: Optional[BaseChatModel] = None):
"""
Initialize the LangChain provider.
Initialize the LangChain connector.

When called with no arguments the connector acts as a per-backend factory
— call ``create_model(config)`` to obtain a configured instance.

:param llm: A LangChain BaseChatModel instance
When called with an explicit ``llm`` the connector is ready to invoke
the model immediately.

:param llm: A LangChain BaseChatModel instance (optional)
"""
self._llm = llm

@staticmethod
async def create(ai_config: AIConfigKind) -> 'LangChainProvider':
# --- AIConnector factory methods ---

def create_model(self, config: AIConfigKind) -> 'LangChainConnector':
"""
Static factory method to create a LangChain AIProvider from an AI configuration.
Create a configured LangChain model connector for the given AI config.

:param ai_config: The LaunchDarkly AI configuration
:return: Configured LangChainProvider instance
:param config: The LaunchDarkly AI configuration
:return: Configured LangChainConnector ready to invoke the model
"""
llm = LangChainProvider.create_langchain_model(ai_config)
return LangChainProvider(llm)
llm = LangChainConnector.create_langchain_model(config)
return LangChainConnector(llm)

# --- Model invocation ---

async def invoke_model(self, messages: List[LDMessage]) -> ChatResponse:
"""
Expand All @@ -45,9 +60,9 @@ async def invoke_model(self, messages: List[LDMessage]) -> ChatResponse:
:return: ChatResponse containing the model's response and metrics
"""
try:
langchain_messages = LangChainProvider.convert_messages_to_langchain(messages)
langchain_messages = LangChainConnector.convert_messages_to_langchain(messages)
response: BaseMessage = await self._llm.ainvoke(langchain_messages)
metrics = LangChainProvider.get_ai_metrics_from_response(response)
metrics = LangChainConnector.get_ai_metrics_from_response(response)

content: str = ''
if isinstance(response.content, str):
Expand Down Expand Up @@ -84,7 +99,7 @@ async def invoke_structured_model(
:return: StructuredResponse containing the structured data
"""
try:
langchain_messages = LangChainProvider.convert_messages_to_langchain(messages)
langchain_messages = LangChainConnector.convert_messages_to_langchain(messages)
structured_llm = self._llm.with_structured_output(response_structure)
response = await structured_llm.ainvoke(langchain_messages)

Expand Down Expand Up @@ -122,11 +137,13 @@ async def invoke_structured_model(
),
)

def get_chat_model(self) -> BaseChatModel:
# --- Convenience accessors ---

def get_chat_model(self) -> Optional[BaseChatModel]:
"""
Get the underlying LangChain model instance.

:return: The underlying BaseChatModel
:return: The underlying BaseChatModel, or None if not yet configured
"""
return self._llm

Expand All @@ -135,9 +152,6 @@ def map_provider(ld_provider_name: str) -> str:
"""
Map LaunchDarkly provider names to LangChain provider names.

This method enables seamless integration between LaunchDarkly's standardized
provider naming and LangChain's naming conventions.

:param ld_provider_name: LaunchDarkly provider name
:return: LangChain-compatible provider name
"""
Expand All @@ -152,25 +166,24 @@ def map_provider(ld_provider_name: str) -> str:
@staticmethod
def get_ai_metrics_from_response(response: BaseMessage) -> LDAIMetrics:
"""
Get AI metrics from a LangChain provider response.

This method extracts token usage information and success status from LangChain responses
and returns a LaunchDarkly AIMetrics object.
Extract LaunchDarkly AI metrics from a LangChain response.

:param response: The response from the LangChain model
:return: LDAIMetrics with success status and token usage

Example:
# Use with tracker.track_metrics_of for automatic tracking
Example::

response = await tracker.track_metrics_of(
lambda: llm.ainvoke(messages),
LangChainProvider.get_ai_metrics_from_response
LangChainConnector.get_ai_metrics_from_response
)
"""
# Extract token usage if available
usage: Optional[TokenUsage] = None
if hasattr(response, 'response_metadata') and response.response_metadata:
token_usage = response.response_metadata.get('tokenUsage') or response.response_metadata.get('token_usage')
token_usage = (
response.response_metadata.get('tokenUsage')
or response.response_metadata.get('token_usage')
)
if token_usage:
usage = TokenUsage(
total=token_usage.get('totalTokens', 0) or token_usage.get('total_tokens', 0),
Expand All @@ -187,9 +200,6 @@ def convert_messages_to_langchain(
"""
Convert LaunchDarkly messages to LangChain messages.

This helper method enables developers to work directly with LangChain message types
while maintaining compatibility with LaunchDarkly's standardized message format.

:param messages: List of LDMessage objects
:return: List of LangChain message objects
:raises ValueError: If an unsupported message role is encountered
Expand All @@ -211,10 +221,7 @@ def convert_messages_to_langchain(
@staticmethod
def create_langchain_model(ai_config: AIConfigKind) -> BaseChatModel:
"""
Create a LangChain model from an AI configuration.

This public helper method enables developers to initialize their own LangChain models
using LaunchDarkly AI configurations.
Create a LangChain model from a LaunchDarkly AI configuration.

:param ai_config: The LaunchDarkly AI configuration
:return: A configured LangChain BaseChatModel
Expand All @@ -231,6 +238,7 @@ def create_langchain_model(ai_config: AIConfigKind) -> BaseChatModel:

return init_chat_model(
model_name,
model_provider=LangChainProvider.map_provider(provider),
model_provider=LangChainConnector.map_provider(provider),
**parameters,
)

Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from ldai import LDMessage

from ldai_langchain import LangChainProvider
from ldai_langchain import LangChainConnector


class TestConvertMessagesToLangchain:
Expand All @@ -16,7 +16,7 @@ class TestConvertMessagesToLangchain:
def test_converts_system_messages_to_system_message(self):
"""Should convert system messages to SystemMessage."""
messages = [LDMessage(role='system', content='You are a helpful assistant.')]
result = LangChainProvider.convert_messages_to_langchain(messages)
result = LangChainConnector.convert_messages_to_langchain(messages)

assert len(result) == 1
assert isinstance(result[0], SystemMessage)
Expand All @@ -25,7 +25,7 @@ def test_converts_system_messages_to_system_message(self):
def test_converts_user_messages_to_human_message(self):
"""Should convert user messages to HumanMessage."""
messages = [LDMessage(role='user', content='Hello, how are you?')]
result = LangChainProvider.convert_messages_to_langchain(messages)
result = LangChainConnector.convert_messages_to_langchain(messages)

assert len(result) == 1
assert isinstance(result[0], HumanMessage)
Expand All @@ -34,7 +34,7 @@ def test_converts_user_messages_to_human_message(self):
def test_converts_assistant_messages_to_ai_message(self):
"""Should convert assistant messages to AIMessage."""
messages = [LDMessage(role='assistant', content='I am doing well, thank you!')]
result = LangChainProvider.convert_messages_to_langchain(messages)
result = LangChainConnector.convert_messages_to_langchain(messages)

assert len(result) == 1
assert isinstance(result[0], AIMessage)
Expand All @@ -47,7 +47,7 @@ def test_converts_multiple_messages_in_order(self):
LDMessage(role='user', content='What is the weather like?'),
LDMessage(role='assistant', content='I cannot check the weather.'),
]
result = LangChainProvider.convert_messages_to_langchain(messages)
result = LangChainConnector.convert_messages_to_langchain(messages)

assert len(result) == 3
assert isinstance(result[0], SystemMessage)
Expand All @@ -62,11 +62,11 @@ class MockMessage:
content = 'Test message'

with pytest.raises(ValueError, match='Unsupported message role: unknown'):
LangChainProvider.convert_messages_to_langchain([MockMessage()]) # type: ignore
LangChainConnector.convert_messages_to_langchain([MockMessage()]) # type: ignore

def test_handles_empty_message_array(self):
"""Should handle empty message array."""
result = LangChainProvider.convert_messages_to_langchain([])
result = LangChainConnector.convert_messages_to_langchain([])
assert len(result) == 0


Expand All @@ -84,7 +84,7 @@ def test_creates_metrics_with_success_true_and_token_usage(self):
},
}

result = LangChainProvider.get_ai_metrics_from_response(mock_response)
result = LangChainConnector.get_ai_metrics_from_response(mock_response)

assert result.success is True
assert result.usage is not None
Expand All @@ -103,7 +103,7 @@ def test_creates_metrics_with_snake_case_token_usage(self):
},
}

result = LangChainProvider.get_ai_metrics_from_response(mock_response)
result = LangChainConnector.get_ai_metrics_from_response(mock_response)

assert result.success is True
assert result.usage is not None
Expand All @@ -115,7 +115,7 @@ def test_creates_metrics_with_success_true_and_no_usage_when_metadata_missing(se
"""Should create metrics with success=True and no usage when metadata is missing."""
mock_response = AIMessage(content='Test response')

result = LangChainProvider.get_ai_metrics_from_response(mock_response)
result = LangChainConnector.get_ai_metrics_from_response(mock_response)

assert result.success is True
assert result.usage is None
Expand All @@ -126,15 +126,15 @@ class TestMapProvider:

def test_maps_gemini_to_google_genai(self):
"""Should map gemini to google-genai."""
assert LangChainProvider.map_provider('gemini') == 'google-genai'
assert LangChainProvider.map_provider('Gemini') == 'google-genai'
assert LangChainProvider.map_provider('GEMINI') == 'google-genai'
assert LangChainConnector.map_provider('gemini') == 'google-genai'
assert LangChainConnector.map_provider('Gemini') == 'google-genai'
assert LangChainConnector.map_provider('GEMINI') == 'google-genai'

def test_returns_provider_name_unchanged_for_unmapped_providers(self):
"""Should return provider name unchanged for unmapped providers."""
assert LangChainProvider.map_provider('openai') == 'openai'
assert LangChainProvider.map_provider('anthropic') == 'anthropic'
assert LangChainProvider.map_provider('unknown') == 'unknown'
assert LangChainConnector.map_provider('openai') == 'openai'
assert LangChainConnector.map_provider('anthropic') == 'anthropic'
assert LangChainConnector.map_provider('unknown') == 'unknown'


class TestInvokeModel:
Expand All @@ -150,7 +150,7 @@ async def test_returns_success_true_for_string_content(self, mock_llm):
"""Should return success=True for string content."""
mock_response = AIMessage(content='Test response')
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
provider = LangChainProvider(mock_llm)
provider = LangChainConnector(mock_llm)

messages = [LDMessage(role='user', content='Hello')]
result = await provider.invoke_model(messages)
Expand All @@ -163,7 +163,7 @@ async def test_returns_success_false_for_non_string_content_and_logs_warning(sel
"""Should return success=False for non-string content and log warning."""
mock_response = AIMessage(content=[{'type': 'image', 'data': 'base64data'}])
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
provider = LangChainProvider(mock_llm)
provider = LangChainConnector(mock_llm)

messages = [LDMessage(role='user', content='Hello')]
result = await provider.invoke_model(messages)
Expand All @@ -176,7 +176,7 @@ async def test_returns_success_false_when_model_invocation_throws_error(self, mo
"""Should return success=False when model invocation throws an error."""
error = Exception('Model invocation failed')
mock_llm.ainvoke = AsyncMock(side_effect=error)
provider = LangChainProvider(mock_llm)
provider = LangChainConnector(mock_llm)

messages = [LDMessage(role='user', content='Hello')]
result = await provider.invoke_model(messages)
Expand All @@ -201,7 +201,7 @@ async def test_returns_success_true_for_successful_invocation(self, mock_llm):
mock_structured_llm = MagicMock()
mock_structured_llm.ainvoke = AsyncMock(return_value=mock_response)
mock_llm.with_structured_output = MagicMock(return_value=mock_structured_llm)
provider = LangChainProvider(mock_llm)
provider = LangChainConnector(mock_llm)

messages = [LDMessage(role='user', content='Hello')]
response_structure = {'type': 'object', 'properties': {}}
Expand All @@ -217,7 +217,7 @@ async def test_returns_success_false_when_structured_model_invocation_throws_err
mock_structured_llm = MagicMock()
mock_structured_llm.ainvoke = AsyncMock(side_effect=error)
mock_llm.with_structured_output = MagicMock(return_value=mock_structured_llm)
provider = LangChainProvider(mock_llm)
provider = LangChainConnector(mock_llm)

messages = [LDMessage(role='user', content='Hello')]
response_structure = {'type': 'object', 'properties': {}}
Expand All @@ -236,7 +236,7 @@ class TestGetChatModel:
def test_returns_underlying_llm(self):
"""Should return the underlying LLM."""
mock_llm = MagicMock()
provider = LangChainProvider(mock_llm)
provider = LangChainConnector(mock_llm)

assert provider.get_chat_model() is mock_llm

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""LaunchDarkly AI SDK OpenAI Provider."""
"""LaunchDarkly AI SDK OpenAI Connector."""

from ldai_openai.openai_provider import OpenAIProvider
from ldai_openai.openai_provider import OpenAIConnector

__all__ = ['OpenAIProvider']
__all__ = [
'OpenAIConnector',
]
Loading
Loading