From fd13f6634b2c48c3e6c31fc88aa2b7086a7de34c Mon Sep 17 00:00:00 2001
From: jsonbailey
Date: Wed, 11 Mar 2026 12:06:39 -0500
Subject: [PATCH] feat: Support additional create methods for agent and
 agent_graph
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

feat!: Rename AIProviderFactory → RunnerFactory
feat!: Rename LangChainProvider → LangChainRunnerFactory
feat: Add create_model(), create_agent(), create_agent_graph() to AIProvider ABC (non-abstract, default warns)
---
 .../src/ldai_langchain/__init__.py            |   9 +-
 .../src/ldai_langchain/langchain_provider.py  |  88 +++++----
 .../tests/test_langchain_provider.py          |  44 ++---
 .../src/ldai_openai/__init__.py               |   6 +-
 .../src/ldai_openai/openai_provider.py        |  75 ++++----
 .../tests/test_openai_provider.py             |  14 +-
 packages/sdk/server-ai/src/ldai/client.py     |  12 +-
 .../server-ai/src/ldai/providers/__init__.py  |   6 +-
 .../src/ldai/providers/ai_provider.py         |  77 +++++---
 .../src/ldai/providers/ai_provider_factory.py | 125 ------------
 .../src/ldai/providers/runner_factory.py      | 181 ++++++++++++++++++
 11 files changed, 363 insertions(+), 274 deletions(-)
 delete mode 100644 packages/sdk/server-ai/src/ldai/providers/ai_provider_factory.py
 create mode 100644 packages/sdk/server-ai/src/ldai/providers/runner_factory.py

diff --git a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/__init__.py b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/__init__.py
index 1282648..0e0263b 100644
--- a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/__init__.py
+++ b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/__init__.py
@@ -1,13 +1,10 @@
-"""LaunchDarkly AI SDK - LangChain Provider.
+"""LaunchDarkly AI SDK - LangChain Connector."""
 
-This package provides LangChain integration for the LaunchDarkly Server-Side AI SDK,
-"""
-
-from ldai_langchain.langchain_provider import LangChainProvider
+from ldai_langchain.langchain_provider import LangChainRunnerFactory
 
 __version__ = "0.1.0"
 
 __all__ = [
     '__version__',
-    'LangChainProvider',
+    'LangChainRunnerFactory',
 ]
diff --git a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_provider.py b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_provider.py
index f4fa62d..89533b4 100644
--- a/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_provider.py
+++ b/packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_provider.py
@@ -1,4 +1,4 @@
-"""LangChain implementation of AIProvider for LaunchDarkly AI SDK."""
+"""LangChain connector for LaunchDarkly AI SDK."""
 
 from typing import Any, Dict, List, Optional, Union
 
@@ -11,31 +11,46 @@
 from ldai.tracker import TokenUsage
 
 
-class LangChainProvider(AIProvider):
+class LangChainRunnerFactory(AIProvider):
     """
-    LangChain implementation of AIProvider.
-
-    This provider integrates LangChain models with LaunchDarkly's tracking capabilities.
+    LangChain connector for the LaunchDarkly AI SDK.
+
+    Can be used in two ways:
+    - Transparently via RunnerFactory (pass ``default_ai_provider='langchain'`` to
+      ``create_model()`` / ``create_chat()``).
+    - Directly for full control: instantiate with a ``BaseChatModel``, then call
+      ``invoke_model()`` yourself and use the static convenience methods
+      (``get_ai_metrics_from_response``, ``convert_messages_to_langchain``,
+      ``map_provider``, ``create_langchain_model``).
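+
+    Example (direct use; a minimal sketch where ``ai_config`` and ``messages``
+    are assumed to come from your own config evaluation and conversation
+    state)::
+
+        llm = LangChainRunnerFactory.create_langchain_model(ai_config)
+        runner = LangChainRunnerFactory(llm)
+        response = await runner.invoke_model(messages)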
""" - def __init__(self, llm: BaseChatModel): + def __init__(self, llm: Optional[BaseChatModel] = None): """ - Initialize the LangChain provider. + Initialize the LangChain connector. + + When called with no arguments the connector acts as a per-backend factory + — call ``create_model(config)`` to obtain a configured instance. - :param llm: A LangChain BaseChatModel instance + When called with an explicit ``llm`` the connector is ready to invoke + the model immediately. + + :param llm: A LangChain BaseChatModel instance (optional) """ self._llm = llm - @staticmethod - async def create(ai_config: AIConfigKind) -> 'LangChainProvider': + # --- AIProvider factory methods --- + + def create_model(self, config: AIConfigKind) -> 'LangChainRunnerFactory': """ - Static factory method to create a LangChain AIProvider from an AI configuration. + Create a configured LangChain model connector for the given AI config. - :param ai_config: The LaunchDarkly AI configuration - :return: Configured LangChainProvider instance + :param config: The LaunchDarkly AI configuration + :return: Configured LangChainRunnerFactory ready to invoke the model """ - llm = LangChainProvider.create_langchain_model(ai_config) - return LangChainProvider(llm) + llm = LangChainRunnerFactory.create_langchain_model(config) + return LangChainRunnerFactory(llm) + + # --- Model invocation --- async def invoke_model(self, messages: List[LDMessage]) -> ChatResponse: """ @@ -45,9 +60,9 @@ async def invoke_model(self, messages: List[LDMessage]) -> ChatResponse: :return: ChatResponse containing the model's response and metrics """ try: - langchain_messages = LangChainProvider.convert_messages_to_langchain(messages) + langchain_messages = LangChainRunnerFactory.convert_messages_to_langchain(messages) response: BaseMessage = await self._llm.ainvoke(langchain_messages) - metrics = LangChainProvider.get_ai_metrics_from_response(response) + metrics = LangChainRunnerFactory.get_ai_metrics_from_response(response) content: str = '' if isinstance(response.content, str): @@ -84,7 +99,7 @@ async def invoke_structured_model( :return: StructuredResponse containing the structured data """ try: - langchain_messages = LangChainProvider.convert_messages_to_langchain(messages) + langchain_messages = LangChainRunnerFactory.convert_messages_to_langchain(messages) structured_llm = self._llm.with_structured_output(response_structure) response = await structured_llm.ainvoke(langchain_messages) @@ -122,11 +137,13 @@ async def invoke_structured_model( ), ) - def get_chat_model(self) -> BaseChatModel: + # --- Convenience accessors --- + + def get_chat_model(self) -> Optional[BaseChatModel]: """ Get the underlying LangChain model instance. - :return: The underlying BaseChatModel + :return: The underlying BaseChatModel, or None if not yet configured """ return self._llm @@ -135,9 +152,6 @@ def map_provider(ld_provider_name: str) -> str: """ Map LaunchDarkly provider names to LangChain provider names. - This method enables seamless integration between LaunchDarkly's standardized - provider naming and LangChain's naming conventions. - :param ld_provider_name: LaunchDarkly provider name :return: LangChain-compatible provider name """ @@ -152,25 +166,24 @@ def map_provider(ld_provider_name: str) -> str: @staticmethod def get_ai_metrics_from_response(response: BaseMessage) -> LDAIMetrics: """ - Get AI metrics from a LangChain provider response. 
- - This method extracts token usage information and success status from LangChain responses - and returns a LaunchDarkly AIMetrics object. + Extract LaunchDarkly AI metrics from a LangChain response. :param response: The response from the LangChain model :return: LDAIMetrics with success status and token usage - Example: - # Use with tracker.track_metrics_of for automatic tracking + Example:: + response = await tracker.track_metrics_of( lambda: llm.ainvoke(messages), - LangChainProvider.get_ai_metrics_from_response + LangChainRunnerFactory.get_ai_metrics_from_response ) """ - # Extract token usage if available usage: Optional[TokenUsage] = None if hasattr(response, 'response_metadata') and response.response_metadata: - token_usage = response.response_metadata.get('tokenUsage') or response.response_metadata.get('token_usage') + token_usage = ( + response.response_metadata.get('tokenUsage') + or response.response_metadata.get('token_usage') + ) if token_usage: usage = TokenUsage( total=token_usage.get('totalTokens', 0) or token_usage.get('total_tokens', 0), @@ -187,9 +200,6 @@ def convert_messages_to_langchain( """ Convert LaunchDarkly messages to LangChain messages. - This helper method enables developers to work directly with LangChain message types - while maintaining compatibility with LaunchDarkly's standardized message format. - :param messages: List of LDMessage objects :return: List of LangChain message objects :raises ValueError: If an unsupported message role is encountered @@ -211,10 +221,7 @@ def convert_messages_to_langchain( @staticmethod def create_langchain_model(ai_config: AIConfigKind) -> BaseChatModel: """ - Create a LangChain model from an AI configuration. - - This public helper method enables developers to initialize their own LangChain models - using LaunchDarkly AI configurations. + Create a LangChain model from a LaunchDarkly AI configuration. 
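+
+        Example (a sketch; ``ai_config`` is assumed to come from an
+        ``LDAIClient`` config evaluation)::
+
+            llm = LangChainRunnerFactory.create_langchain_model(ai_config)
+            result = await llm.ainvoke('Hello')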
:param ai_config: The LaunchDarkly AI configuration :return: A configured LangChain BaseChatModel @@ -231,6 +238,7 @@ def create_langchain_model(ai_config: AIConfigKind) -> BaseChatModel: return init_chat_model( model_name, - model_provider=LangChainProvider.map_provider(provider), + model_provider=LangChainRunnerFactory.map_provider(provider), **parameters, ) + diff --git a/packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py b/packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py index 0c90a43..d74f8c3 100644 --- a/packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py +++ b/packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py @@ -7,7 +7,7 @@ from ldai import LDMessage -from ldai_langchain import LangChainProvider +from ldai_langchain import LangChainRunnerFactory class TestConvertMessagesToLangchain: @@ -16,7 +16,7 @@ class TestConvertMessagesToLangchain: def test_converts_system_messages_to_system_message(self): """Should convert system messages to SystemMessage.""" messages = [LDMessage(role='system', content='You are a helpful assistant.')] - result = LangChainProvider.convert_messages_to_langchain(messages) + result = LangChainRunnerFactory.convert_messages_to_langchain(messages) assert len(result) == 1 assert isinstance(result[0], SystemMessage) @@ -25,7 +25,7 @@ def test_converts_system_messages_to_system_message(self): def test_converts_user_messages_to_human_message(self): """Should convert user messages to HumanMessage.""" messages = [LDMessage(role='user', content='Hello, how are you?')] - result = LangChainProvider.convert_messages_to_langchain(messages) + result = LangChainRunnerFactory.convert_messages_to_langchain(messages) assert len(result) == 1 assert isinstance(result[0], HumanMessage) @@ -34,7 +34,7 @@ def test_converts_user_messages_to_human_message(self): def test_converts_assistant_messages_to_ai_message(self): """Should convert assistant messages to AIMessage.""" messages = [LDMessage(role='assistant', content='I am doing well, thank you!')] - result = LangChainProvider.convert_messages_to_langchain(messages) + result = LangChainRunnerFactory.convert_messages_to_langchain(messages) assert len(result) == 1 assert isinstance(result[0], AIMessage) @@ -47,7 +47,7 @@ def test_converts_multiple_messages_in_order(self): LDMessage(role='user', content='What is the weather like?'), LDMessage(role='assistant', content='I cannot check the weather.'), ] - result = LangChainProvider.convert_messages_to_langchain(messages) + result = LangChainRunnerFactory.convert_messages_to_langchain(messages) assert len(result) == 3 assert isinstance(result[0], SystemMessage) @@ -62,11 +62,11 @@ class MockMessage: content = 'Test message' with pytest.raises(ValueError, match='Unsupported message role: unknown'): - LangChainProvider.convert_messages_to_langchain([MockMessage()]) # type: ignore + LangChainRunnerFactory.convert_messages_to_langchain([MockMessage()]) # type: ignore def test_handles_empty_message_array(self): """Should handle empty message array.""" - result = LangChainProvider.convert_messages_to_langchain([]) + result = LangChainRunnerFactory.convert_messages_to_langchain([]) assert len(result) == 0 @@ -84,7 +84,7 @@ def test_creates_metrics_with_success_true_and_token_usage(self): }, } - result = LangChainProvider.get_ai_metrics_from_response(mock_response) + result = LangChainRunnerFactory.get_ai_metrics_from_response(mock_response) assert result.success is True assert result.usage is not 
None @@ -103,7 +103,7 @@ def test_creates_metrics_with_snake_case_token_usage(self): }, } - result = LangChainProvider.get_ai_metrics_from_response(mock_response) + result = LangChainRunnerFactory.get_ai_metrics_from_response(mock_response) assert result.success is True assert result.usage is not None @@ -115,7 +115,7 @@ def test_creates_metrics_with_success_true_and_no_usage_when_metadata_missing(se """Should create metrics with success=True and no usage when metadata is missing.""" mock_response = AIMessage(content='Test response') - result = LangChainProvider.get_ai_metrics_from_response(mock_response) + result = LangChainRunnerFactory.get_ai_metrics_from_response(mock_response) assert result.success is True assert result.usage is None @@ -126,15 +126,15 @@ class TestMapProvider: def test_maps_gemini_to_google_genai(self): """Should map gemini to google-genai.""" - assert LangChainProvider.map_provider('gemini') == 'google-genai' - assert LangChainProvider.map_provider('Gemini') == 'google-genai' - assert LangChainProvider.map_provider('GEMINI') == 'google-genai' + assert LangChainRunnerFactory.map_provider('gemini') == 'google-genai' + assert LangChainRunnerFactory.map_provider('Gemini') == 'google-genai' + assert LangChainRunnerFactory.map_provider('GEMINI') == 'google-genai' def test_returns_provider_name_unchanged_for_unmapped_providers(self): """Should return provider name unchanged for unmapped providers.""" - assert LangChainProvider.map_provider('openai') == 'openai' - assert LangChainProvider.map_provider('anthropic') == 'anthropic' - assert LangChainProvider.map_provider('unknown') == 'unknown' + assert LangChainRunnerFactory.map_provider('openai') == 'openai' + assert LangChainRunnerFactory.map_provider('anthropic') == 'anthropic' + assert LangChainRunnerFactory.map_provider('unknown') == 'unknown' class TestInvokeModel: @@ -150,7 +150,7 @@ async def test_returns_success_true_for_string_content(self, mock_llm): """Should return success=True for string content.""" mock_response = AIMessage(content='Test response') mock_llm.ainvoke = AsyncMock(return_value=mock_response) - provider = LangChainProvider(mock_llm) + provider = LangChainRunnerFactory(mock_llm) messages = [LDMessage(role='user', content='Hello')] result = await provider.invoke_model(messages) @@ -163,7 +163,7 @@ async def test_returns_success_false_for_non_string_content_and_logs_warning(sel """Should return success=False for non-string content and log warning.""" mock_response = AIMessage(content=[{'type': 'image', 'data': 'base64data'}]) mock_llm.ainvoke = AsyncMock(return_value=mock_response) - provider = LangChainProvider(mock_llm) + provider = LangChainRunnerFactory(mock_llm) messages = [LDMessage(role='user', content='Hello')] result = await provider.invoke_model(messages) @@ -176,7 +176,7 @@ async def test_returns_success_false_when_model_invocation_throws_error(self, mo """Should return success=False when model invocation throws an error.""" error = Exception('Model invocation failed') mock_llm.ainvoke = AsyncMock(side_effect=error) - provider = LangChainProvider(mock_llm) + provider = LangChainRunnerFactory(mock_llm) messages = [LDMessage(role='user', content='Hello')] result = await provider.invoke_model(messages) @@ -201,7 +201,7 @@ async def test_returns_success_true_for_successful_invocation(self, mock_llm): mock_structured_llm = MagicMock() mock_structured_llm.ainvoke = AsyncMock(return_value=mock_response) mock_llm.with_structured_output = MagicMock(return_value=mock_structured_llm) - provider = 
LangChainProvider(mock_llm)
+        provider = LangChainRunnerFactory(mock_llm)
         messages = [LDMessage(role='user', content='Hello')]
         response_structure = {'type': 'object', 'properties': {}}
 
@@ -217,7 +217,7 @@ async def test_returns_success_false_when_structured_model_invocation_throws_err
         mock_structured_llm = MagicMock()
         mock_structured_llm.ainvoke = AsyncMock(side_effect=error)
         mock_llm.with_structured_output = MagicMock(return_value=mock_structured_llm)
-        provider = LangChainProvider(mock_llm)
+        provider = LangChainRunnerFactory(mock_llm)
         messages = [LDMessage(role='user', content='Hello')]
         response_structure = {'type': 'object', 'properties': {}}
 
@@ -236,7 +236,7 @@ class TestGetChatModel:
     def test_returns_underlying_llm(self):
         """Should return the underlying LLM."""
         mock_llm = MagicMock()
-        provider = LangChainProvider(mock_llm)
+        provider = LangChainRunnerFactory(mock_llm)
 
         assert provider.get_chat_model() is mock_llm
 
diff --git a/packages/ai-providers/server-ai-openai/src/ldai_openai/__init__.py b/packages/ai-providers/server-ai-openai/src/ldai_openai/__init__.py
index 5d5120f..8f4c6e1 100644
--- a/packages/ai-providers/server-ai-openai/src/ldai_openai/__init__.py
+++ b/packages/ai-providers/server-ai-openai/src/ldai_openai/__init__.py
@@ -1,5 +1,7 @@
-"""LaunchDarkly AI SDK OpenAI Provider."""
+"""LaunchDarkly AI SDK OpenAI Connector."""
 
 from ldai_openai.openai_provider import OpenAIProvider
 
-__all__ = ['OpenAIProvider']
+__all__ = [
+    'OpenAIProvider',
+]
diff --git a/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_provider.py b/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_provider.py
index c62cc80..f7013ac 100644
--- a/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_provider.py
+++ b/packages/ai-providers/server-ai-openai/src/ldai_openai/openai_provider.py
@@ -1,4 +1,4 @@
-"""OpenAI implementation of AIProvider for LaunchDarkly AI SDK."""
+"""OpenAI connector for LaunchDarkly AI SDK."""
 
 import json
 import os
@@ -15,46 +15,60 @@
 
 class OpenAIProvider(AIProvider):
     """
-    OpenAI implementation of AIProvider.
+    OpenAI connector for the LaunchDarkly AI SDK.
 
-    This provider integrates OpenAI's chat completions API with LaunchDarkly's tracking capabilities.
+    Can be used in two ways:
+    - Transparently via RunnerFactory (pass ``default_ai_provider='openai'`` to
+      ``create_model()`` / ``create_chat()``).
+    - Directly for full control: instantiate with an ``AsyncOpenAI`` client,
+      model name, and parameters, then call ``invoke_model()`` yourself.
     """
 
     def __init__(
         self,
-        client: AsyncOpenAI,
-        model_name: str,
-        parameters: Dict[str, Any],
+        client: Optional[AsyncOpenAI] = None,
+        model_name: str = '',
+        parameters: Optional[Dict[str, Any]] = None,
     ):
         """
-        Initialize the OpenAI provider.
+        Initialize the OpenAI connector.
 
-        :param client: An AsyncOpenAI client instance
+        When called with no arguments the connector reads credentials from the
+        environment (``OPENAI_API_KEY``) and acts as a per-backend factory —
+        call ``create_model(config)`` to obtain a configured instance.
+
+        When called with explicit arguments the connector is ready to invoke
+        the model immediately.
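+
+        Example (a sketch; assumes ``OPENAI_API_KEY`` is set in the
+        environment, with ``ai_config`` and ``messages`` supplied by your
+        own code)::
+
+            factory = OpenAIProvider()                  # per-backend factory
+            runner = factory.create_model(ai_config)    # configured connector
+            response = await runner.invoke_model(messages)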
+ + :param client: An AsyncOpenAI client instance (created from env if omitted) :param model_name: The name of the model to use :param parameters: Additional model parameters """ - self._client = client + self._client = client if client is not None else AsyncOpenAI( + api_key=os.environ.get('OPENAI_API_KEY'), + ) self._model_name = model_name - self._parameters = parameters + self._parameters = parameters or {} - @staticmethod - async def create(ai_config: AIConfigKind) -> 'OpenAIProvider': - """ - Static factory method to create an OpenAI AIProvider from an AI configuration. + # --- AIProvider factory methods --- - :param ai_config: The LaunchDarkly AI configuration - :return: Configured OpenAIProvider instance + def create_model(self, config: AIConfigKind) -> 'OpenAIProvider': """ - client = AsyncOpenAI( - api_key=os.environ.get('OPENAI_API_KEY'), - ) + Create a configured OpenAI model connector for the given AI config. + + Reuses the underlying AsyncOpenAI client so that connection pooling is + preserved across calls. - config_dict = ai_config.to_dict() + :param config: The LaunchDarkly AI configuration + :return: Configured OpenAIProvider ready to invoke the model + """ + config_dict = config.to_dict() model_dict = config_dict.get('model') or {} model_name = model_dict.get('name', '') parameters = model_dict.get('parameters') or {} + return OpenAIProvider(self._client, model_name, parameters) - return OpenAIProvider(client, model_name, parameters) + # --- Model invocation --- async def invoke_model(self, messages: List[LDMessage]) -> ChatResponse: """ @@ -64,7 +78,6 @@ async def invoke_model(self, messages: List[LDMessage]) -> ChatResponse: :return: ChatResponse containing the model's response and metrics """ try: - # Convert LDMessage to OpenAI message format openai_messages: Iterable[ChatCompletionMessageParam] = cast( Iterable[ChatCompletionMessageParam], [{'role': msg.role, 'content': msg.content} for msg in messages] @@ -76,10 +89,8 @@ async def invoke_model(self, messages: List[LDMessage]) -> ChatResponse: **self._parameters, ) - # Generate metrics early (assumes success by default) metrics = OpenAIProvider.get_ai_metrics_from_response(response) - # Safely extract the first choice content content = '' if response.choices and len(response.choices) > 0: message = response.choices[0].message @@ -115,7 +126,6 @@ async def invoke_structured_model( :return: StructuredResponse containing the structured data """ try: - # Convert LDMessage to OpenAI message format openai_messages: Iterable[ChatCompletionMessageParam] = cast( Iterable[ChatCompletionMessageParam], [{'role': msg.role, 'content': msg.content} for msg in messages] @@ -135,10 +145,8 @@ async def invoke_structured_model( **self._parameters, ) - # Generate metrics early (assumes success by default) metrics = OpenAIProvider.get_ai_metrics_from_response(response) - # Safely extract the first choice content content = '' if response.choices and len(response.choices) > 0: message = response.choices[0].message @@ -178,6 +186,8 @@ async def invoke_structured_model( metrics=LDAIMetrics(success=False, usage=None), ) + # --- Convenience accessors --- + def get_client(self) -> AsyncOpenAI: """ Get the underlying OpenAI client instance. @@ -189,21 +199,18 @@ def get_client(self) -> AsyncOpenAI: @staticmethod def get_ai_metrics_from_response(response: Any) -> LDAIMetrics: """ - Get AI metrics from an OpenAI response. 
- - This method extracts token usage information and success status from OpenAI responses - and returns a LaunchDarkly AIMetrics object. + Extract LaunchDarkly AI metrics from an OpenAI response. :param response: The response from OpenAI chat completions API :return: LDAIMetrics with success status and token usage - Example: + Example:: + response = await tracker.track_metrics_of( lambda: client.chat.completions.create(config), OpenAIProvider.get_ai_metrics_from_response ) """ - # Extract token usage if available usage: Optional[TokenUsage] = None if hasattr(response, 'usage') and response.usage: usage = TokenUsage( @@ -212,5 +219,5 @@ def get_ai_metrics_from_response(response: Any) -> LDAIMetrics: output=response.usage.completion_tokens or 0, ) - # OpenAI responses that complete successfully are considered successful by default return LDAIMetrics(success=True, usage=usage) + diff --git a/packages/ai-providers/server-ai-openai/tests/test_openai_provider.py b/packages/ai-providers/server-ai-openai/tests/test_openai_provider.py index ff9066b..6e22028 100644 --- a/packages/ai-providers/server-ai-openai/tests/test_openai_provider.py +++ b/packages/ai-providers/server-ai-openai/tests/test_openai_provider.py @@ -277,11 +277,10 @@ def test_returns_underlying_client(self): assert provider.get_client() is mock_client -class TestCreate: - """Tests for create static factory method.""" +class TestCreateModel: + """Tests for create_model instance method.""" - @pytest.mark.asyncio - async def test_creates_provider_with_correct_model_and_parameters(self): + def test_creates_connector_with_correct_model_and_parameters(self): """Should create OpenAIProvider with correct model and parameters.""" mock_ai_config = MagicMock() mock_ai_config.to_dict.return_value = { @@ -299,14 +298,13 @@ async def test_creates_provider_with_correct_model_and_parameters(self): mock_client = MagicMock() mock_openai_class.return_value = mock_client - result = await OpenAIProvider.create(mock_ai_config) + result = OpenAIProvider().create_model(mock_ai_config) assert isinstance(result, OpenAIProvider) assert result._model_name == 'gpt-4' assert result._parameters == {'temperature': 0.7, 'max_tokens': 1000} - @pytest.mark.asyncio - async def test_handles_missing_model_config(self): + def test_handles_missing_model_config(self): """Should handle missing model configuration.""" mock_ai_config = MagicMock() mock_ai_config.to_dict.return_value = {} @@ -315,7 +313,7 @@ async def test_handles_missing_model_config(self): mock_client = MagicMock() mock_openai_class.return_value = mock_client - result = await OpenAIProvider.create(mock_ai_config) + result = OpenAIProvider().create_model(mock_ai_config) assert isinstance(result, OpenAIProvider) assert result._model_name == '' diff --git a/packages/sdk/server-ai/src/ldai/client.py b/packages/sdk/server-ai/src/ldai/client.py index ae79bd9..c0cd175 100644 --- a/packages/sdk/server-ai/src/ldai/client.py +++ b/packages/sdk/server-ai/src/ldai/client.py @@ -24,7 +24,7 @@ ModelConfig, ProviderConfig, ) -from ldai.providers.ai_provider_factory import AIProviderFactory +from ldai.providers.runner_factory import RunnerFactory from ldai.sdk_info import AI_SDK_LANGUAGE, AI_SDK_NAME, AI_SDK_VERSION from ldai.tracker import AIGraphTracker, LDAIConfigTracker @@ -200,7 +200,7 @@ async def create_judge( :param default: A default value representing a standard AI config result :param variables: Dictionary of values for instruction interpolation. 
The variables `message_history` and `response_to_evaluate` are reserved for the judge and will be ignored. - :param default_ai_provider: Optional default AI provider to use. + :param default_ai_provider: Optional default AI backend to use. :return: Judge instance or None if disabled/unsupported Example:: @@ -245,7 +245,7 @@ async def create_judge( if not judge_config.enabled or not judge_config.tracker: return None - provider = await AIProviderFactory.create(judge_config, default_ai_provider) + provider = await RunnerFactory.create_model(judge_config, default_ai_provider) if not provider: return None @@ -266,7 +266,7 @@ async def _initialize_judges( :param judge_configs: List of judge configurations :param context: Standard Context used when evaluating flags :param variables: Dictionary of values for instruction interpolation - :param default_ai_provider: Optional default AI provider to use + :param default_ai_provider: Optional default AI backend to use :return: Dictionary of judge instances keyed by their configuration keys """ judges: Dict[str, Judge] = {} @@ -314,7 +314,7 @@ async def create_chat( :param default: A default value representing a standard AI config result. When not provided, a disabled config is used as the fallback. :param variables: Dictionary of values for instruction interpolation - :param default_ai_provider: Optional default AI provider to use + :param default_ai_provider: Optional default AI backend to use :return: Chat instance or None if disabled/unsupported Example:: @@ -346,7 +346,7 @@ async def create_chat( if not config.enabled or not config.tracker: return None - provider = await AIProviderFactory.create(config, default_ai_provider) + provider = await RunnerFactory.create_model(config, default_ai_provider) if not provider: return None diff --git a/packages/sdk/server-ai/src/ldai/providers/__init__.py b/packages/sdk/server-ai/src/ldai/providers/__init__.py index 71efb6c..4cebeea 100644 --- a/packages/sdk/server-ai/src/ldai/providers/__init__.py +++ b/packages/sdk/server-ai/src/ldai/providers/__init__.py @@ -1,9 +1,9 @@ -"""AI Provider interfaces and factory for LaunchDarkly AI SDK.""" +"""AI Connector interfaces and factory for LaunchDarkly AI SDK.""" from ldai.providers.ai_provider import AIProvider -from ldai.providers.ai_provider_factory import AIProviderFactory +from ldai.providers.runner_factory import RunnerFactory __all__ = [ 'AIProvider', - 'AIProviderFactory', + 'RunnerFactory', ] diff --git a/packages/sdk/server-ai/src/ldai/providers/ai_provider.py b/packages/sdk/server-ai/src/ldai/providers/ai_provider.py index 91c8cb9..238b9f4 100644 --- a/packages/sdk/server-ai/src/ldai/providers/ai_provider.py +++ b/packages/sdk/server-ai/src/ldai/providers/ai_provider.py @@ -1,38 +1,37 @@ -"""Abstract base class for AI providers.""" +"""Abstract base class for AI connectors.""" -from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Union +from abc import ABC +from typing import Any, Dict, List, Optional from ldai import log -from ldai.models import AIConfigKind, LDMessage +from ldai.models import LDMessage from ldai.providers.types import ChatResponse, StructuredResponse class AIProvider(ABC): """ - Abstract base class for AI providers that implement chat model functionality. + Abstract base class for AI backend connectors. - This class provides the contract that all provider implementations must follow - to integrate with LaunchDarkly's tracking and configuration capabilities. 
+ An AIProvider is a per-backend factory: it is instantiated once per backend + (with no arguments — credentials are read from environment variables) and is + responsible for constructing focused runtime capability objects via + create_model(), create_agent(), and create_agent_graph(). - Following the AICHAT spec recommendation to use base classes with non-abstract methods - for better extensibility and backwards compatibility. + The invoke_model() / invoke_structured_model() methods remain on this base + class for compatibility and will migrate to ModelExecutor in PR 2. """ async def invoke_model(self, messages: List[LDMessage]) -> ChatResponse: """ Invoke the chat model with an array of messages. - This method should convert messages to provider format, invoke the model, - and return a ChatResponse with the result and metrics. - Default implementation takes no action and returns a placeholder response. - Provider implementations should override this method. + Connector implementations should override this method. :param messages: Array of LDMessage objects representing the conversation :return: ChatResponse containing the model's response """ - log.warn('invokeModel not implemented by this provider') + log.warn('invoke_model not implemented by this connector') from ldai.models import LDMessage from ldai.providers.types import LDAIMetrics @@ -50,17 +49,14 @@ async def invoke_structured_model( """ Invoke the chat model with structured output support. - This method should convert messages to provider format, invoke the model with - structured output configuration, and return a structured response. - Default implementation takes no action and returns a placeholder response. - Provider implementations should override this method. + Connector implementations should override this method. :param messages: Array of LDMessage objects representing the conversation :param response_structure: Dictionary of output configurations keyed by output name :return: StructuredResponse containing the structured data """ - log.warn('invokeStructuredModel not implemented by this provider') + log.warn('invoke_structured_model not implemented by this connector') from ldai.providers.types import LDAIMetrics @@ -70,16 +66,41 @@ async def invoke_structured_model( metrics=LDAIMetrics(success=False, usage=None), ) - @staticmethod - @abstractmethod - async def create(ai_config: AIConfigKind) -> 'AIProvider': + def create_model(self, config: Any) -> Optional['AIProvider']: + """ + Create a configured model executor for the given AI config. + + Default implementation warns. Backend connectors should override this method. + + :param config: The LaunchDarkly AI configuration + :return: Configured AIProvider instance, or None if unsupported + """ + log.warn('create_model not implemented by this connector') + return None + + def create_agent(self, config: Any, tools: Any) -> Optional[Any]: """ - Static method that constructs an instance of the provider. + Create a configured agent executor for the given AI config and tool registry. - Each provider implementation must provide their own static create method - that accepts an AIConfigKind and returns a configured instance. + Default implementation warns. Backend connectors should override this method. 
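+
+        Example override in a backend connector (a sketch; ``MyAgentExecutor``
+        is a hypothetical executor type)::
+
+            def create_agent(self, config, tools):
+                # hypothetical executor wrapping this backend's agent runtime
+                return MyAgentExecutor(config, tools)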
- :param ai_config: The LaunchDarkly AI configuration - :return: Configured provider instance + :param config: The LaunchDarkly AI agent configuration + :param tools: Tool registry mapping tool names to callables + :return: AgentExecutor instance, or None if unsupported """ - raise NotImplementedError('Provider implementations must override the static create method') + log.warn('create_agent not implemented by this connector') + return None + + def create_agent_graph(self, graph_def: Any, tools: Any) -> Optional[Any]: + """ + Create a configured agent graph executor for the given graph definition and tools. + + Default implementation warns. Backend connectors should override this method. + + :param graph_def: The agent graph definition + :param tools: Tool registry mapping tool names to callables + :return: AgentGraphExecutor instance, or None if unsupported + """ + log.warn('create_agent_graph not implemented by this connector') + return None + diff --git a/packages/sdk/server-ai/src/ldai/providers/ai_provider_factory.py b/packages/sdk/server-ai/src/ldai/providers/ai_provider_factory.py deleted file mode 100644 index 74e55d5..0000000 --- a/packages/sdk/server-ai/src/ldai/providers/ai_provider_factory.py +++ /dev/null @@ -1,125 +0,0 @@ -"""Factory for creating AIProvider instances based on the provider configuration.""" - -from importlib import util -from typing import Any, Dict, List, Optional, Tuple, Type - -from ldai import log -from ldai.models import AIConfigKind -from ldai.providers.ai_provider import AIProvider - -# Supported AI providers -# Multi-provider packages should be last in the list -SUPPORTED_AI_PROVIDERS = ('openai', 'langchain') - - -class AIProviderFactory: - """ - Factory for creating AIProvider instances based on the provider configuration. - """ - - @staticmethod - async def create( - ai_config: AIConfigKind, - default_ai_provider: Optional[str] = None, - ) -> Optional[AIProvider]: - """ - Create an AIProvider instance based on the AI configuration. - - This method attempts to load provider-specific implementations dynamically. - Returns None if the provider is not supported. - - :param ai_config: The AI configuration - :param default_ai_provider: Optional default AI provider to use - :return: AIProvider instance or None if not supported - """ - provider_name = ai_config.provider.name.lower() if ai_config.provider else None - providers_to_try = AIProviderFactory._get_providers_to_try(default_ai_provider, provider_name) - - for provider_type in providers_to_try: - provider = await AIProviderFactory._try_create_provider(provider_type, ai_config) - if provider: - log.debug( - f"Successfully created AIProvider for: {provider_name} " - f"with provider type: {provider_type} for AIConfig: {ai_config.key}" - ) - return provider - - log.warn( - f"Provider is not supported or failed to initialize: {provider_name}" - ) - return None - - @staticmethod - def _get_providers_to_try( - default_ai_provider: Optional[str], - provider_name: Optional[str], - ) -> List[str]: - """ - Determine which providers to try based on default_ai_provider and provider_name. 
- - :param default_ai_provider: Optional default provider to use - :param provider_name: Optional provider name from config - :return: List of providers to try in order - """ - if default_ai_provider: - return [default_ai_provider] - - providers = [] - - if provider_name and provider_name in SUPPORTED_AI_PROVIDERS: - providers.append(provider_name) - - # Then try multi-provider packages, but avoid duplicates - multi_provider_packages: List[str] = ['langchain'] - for provider in multi_provider_packages: - if provider not in providers: - providers.append(provider) - - return providers - - @staticmethod - async def _try_create_provider( - provider_type: str, - ai_config: AIConfigKind, - ) -> Optional[AIProvider]: - """ - Try to create a provider of the specified type. - - :param provider_type: Type of provider to create - :param ai_config: AI configuration - :return: AIProvider instance or None if creation failed - """ - try: - if provider_type == 'langchain': - AIProviderFactory._pkg_exists('ldai_langchain') - from ldai_langchain import LangChainProvider - return await LangChainProvider.create(ai_config) - - if provider_type == 'openai': - AIProviderFactory._pkg_exists('ldai_openai') - from ldai_openai import OpenAIProvider - return await OpenAIProvider.create(ai_config) - - log.warn( - f"Provider {provider_type} is not supported. " - f"Supported providers are: {SUPPORTED_AI_PROVIDERS}" - ) - - return None - except ImportError as error: - log.warn( - f"Error creating {provider_type} provider: {error}. " - f"Make sure the {provider_type} package is installed." - ) - return None - - @staticmethod - def _pkg_exists(package_name: str) -> None: - """ - Check if a package exists. - - :param package_name: Name of the package to check - :return: None if the package exists, otherwise raises an ImportError - """ - if util.find_spec(package_name) is None: - raise ImportError(f"Package {package_name} not found") diff --git a/packages/sdk/server-ai/src/ldai/providers/runner_factory.py b/packages/sdk/server-ai/src/ldai/providers/runner_factory.py new file mode 100644 index 0000000..8c0c3ed --- /dev/null +++ b/packages/sdk/server-ai/src/ldai/providers/runner_factory.py @@ -0,0 +1,181 @@ +"""Factory for creating AIProvider instances and capability runners.""" + +from importlib import util +from typing import Any, Callable, List, Optional, TypeVar + +from ldai import log +from ldai.models import AIConfigKind +from ldai.providers.ai_provider import AIProvider + +T = TypeVar('T') + +# Supported AI backends. +# Multi-provider packages should be last in the list. +SUPPORTED_AI_PROVIDERS = ('openai', 'langchain') + + +class RunnerFactory: + """ + Sole entry point for capability creation. + + RunnerFactory instantiates the appropriate AIProvider for the configured + backend and delegates runner construction to it. The shared fallback + loop (_with_fallback) tries each candidate backend in order and returns + the first successful result. + """ + + @staticmethod + def _get_ai_adapter(provider_type: str) -> Optional[AIProvider]: + """ + Import and instantiate the AIProvider for the given backend type. + + This is the only place in the SDK that knows about connector package names. + + :param provider_type: Backend identifier, e.g. 
'openai' or 'langchain' + :return: AIProvider instance, or None if the package is not installed + """ + try: + if provider_type == 'langchain': + RunnerFactory._pkg_exists('ldai_langchain') + from ldai_langchain import LangChainRunnerFactory + return LangChainRunnerFactory() + + if provider_type == 'openai': + RunnerFactory._pkg_exists('ldai_openai') + from ldai_openai import OpenAIProvider + return OpenAIProvider() + + log.warn( + f"Backend '{provider_type}' is not supported. " + f"Supported backends: {SUPPORTED_AI_PROVIDERS}" + ) + return None + except ImportError as error: + log.warn( + f"Could not load backend '{provider_type}': {error}. " + f"Make sure the corresponding package is installed." + ) + return None + + @staticmethod + def _with_fallback( + providers: List[str], + fn: Callable[[AIProvider], Optional[T]], + ) -> Optional[T]: + """ + Try each backend in order; return the first successful result. + + Shared by all create_* methods so the fallback loop is written once. + + :param providers: Ordered list of backend identifiers to try + :param fn: Callable that receives an AIProvider and returns a result or None + :return: First non-None result, or None if all backends fail + """ + for provider_type in providers: + try: + connector = RunnerFactory._get_ai_adapter(provider_type) + if connector is None: + continue + result = fn(connector) + if result is not None: + log.debug(f"Successfully created capability using backend '{provider_type}'") + return result + except Exception as exc: + log.warn(f"Backend '{provider_type}' failed: {exc}") + + log.warn("All backends failed or are unavailable") + return None + + @staticmethod + def _get_providers_to_try( + default_ai_provider: Optional[str], + provider_name: Optional[str], + ) -> List[str]: + """ + Determine which backends to try, in priority order. + + :param default_ai_provider: Caller-specified override (tried exclusively if set) + :param provider_name: Provider name from the AI config + :return: Ordered list of backend identifiers + """ + if default_ai_provider: + return [default_ai_provider] + + providers: List[str] = [] + + if provider_name and provider_name in SUPPORTED_AI_PROVIDERS: + providers.append(provider_name) + + # Multi-provider packages act as a fallback + for multi in ['langchain']: + if multi not in providers: + providers.append(multi) + + return providers + + # --- Public API --- + + @staticmethod + async def create_model( + config: AIConfigKind, + default_ai_provider: Optional[str] = None, + ) -> Optional[AIProvider]: + """ + Create a model executor for the given AI completion config. + + :param config: LaunchDarkly AI config (completion or judge) + :param default_ai_provider: Optional backend override ('openai', 'langchain', …) + :return: Configured AIProvider that can invoke_model(), or None + """ + provider_name = config.provider.name.lower() if config.provider else None + providers = RunnerFactory._get_providers_to_try(default_ai_provider, provider_name) + return RunnerFactory._with_fallback(providers, lambda p: p.create_model(config)) + + @staticmethod + async def create_agent( + config: Any, + tools: Any, + default_ai_provider: Optional[str] = None, + ) -> Optional[Any]: + """ + Create an agent executor for the given AI agent config and tool registry. 
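+
+        Example (a sketch; ``agent_config`` and ``tools`` are assumed to come
+        from your own agent setup)::
+
+            executor = await RunnerFactory.create_agent(agent_config, tools)
+            # executor is None when no installed backend supports agents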
+
+        :param config: LaunchDarkly AI agent config
+        :param tools: Tool registry mapping tool names to callables
+        :param default_ai_provider: Optional backend override
+        :return: AgentExecutor instance, or None if no backend is available
+        """
+        provider_name = config.provider.name.lower() if config.provider else None
+        providers = RunnerFactory._get_providers_to_try(default_ai_provider, provider_name)
+        return RunnerFactory._with_fallback(providers, lambda p: p.create_agent(config, tools))
+
+    @staticmethod
+    async def create_agent_graph(
+        graph_def: Any,
+        tools: Any,
+        default_ai_provider: Optional[str] = None,
+    ) -> Optional[Any]:
+        """
+        Create an agent graph executor for the given graph definition and tool registry.
+
+        :param graph_def: AgentGraphDefinition instance
+        :param tools: Tool registry mapping tool names to callables
+        :param default_ai_provider: Optional backend override
+        :return: AgentGraphExecutor instance, or None if no backend is available
+        """
+        root = graph_def.root()
+        root_config = root.get_config() if root else None
+        provider_name = None
+        if root_config and root_config.provider:
+            provider_name = root_config.provider.name.lower()
+        providers = RunnerFactory._get_providers_to_try(default_ai_provider, provider_name)
+        return RunnerFactory._with_fallback(providers, lambda p: p.create_agent_graph(graph_def, tools))
+
+    @staticmethod
+    def _pkg_exists(package_name: str) -> None:
+        """
+        Raise ImportError if the given package is not importable.
+
+        :param package_name: Name of the package to check
+        """
+        if util.find_spec(package_name) is None:
+            raise ImportError(f"Package '{package_name}' not found")