diff --git a/src/skillspector/llm_utils.py b/src/skillspector/llm_utils.py index 1e03fc1..6e336ab 100644 --- a/src/skillspector/llm_utils.py +++ b/src/skillspector/llm_utils.py @@ -29,13 +29,32 @@ from __future__ import annotations -import os - from langchain_openai import ChatOpenAI -from skillspector.constants import MODEL_CONFIG from skillspector.model_info import get_max_input_tokens, get_max_output_tokens -from skillspector.providers import resolve_provider_credentials +from skillspector.providers import get_metadata_provider, resolve_provider_credentials +from skillspector.providers.openai import OpenAIProvider + + +def _resolve_llm_client_config() -> tuple[str, str | None, str]: + """Return ``(api_key, base_url, default_model)`` for the resolved endpoint.""" + creds = resolve_provider_credentials() + if creds is not None: + api_key, base_url = creds + return api_key, base_url, get_metadata_provider().resolve_model() + + openai_provider = OpenAIProvider() + openai_creds = openai_provider.resolve_credentials() + if openai_creds is not None: + api_key, base_url = openai_creds + return api_key, base_url, openai_provider.resolve_model() + + raise ValueError( + "No LLM API key configured. Set the credential env var for the " + "active provider, or set OPENAI_API_KEY (and optionally " + "OPENAI_BASE_URL) to use a standard OpenAI-compatible endpoint. " + "Use --no-llm to skip LLM analysis and run static checks only." + ) def _resolve_llm_credentials() -> tuple[str, str | None]: @@ -47,20 +66,7 @@ def _resolve_llm_credentials() -> tuple[str, str | None]: Raises: ValueError: when no API key can be resolved from any source. """ - creds = resolve_provider_credentials() - if creds is not None: - return creds - - resolved_key = os.environ.get("OPENAI_API_KEY", "").strip() - if not resolved_key: - raise ValueError( - "No LLM API key configured. Set the credential env var for the " - "active provider, or set OPENAI_API_KEY (and optionally " - "OPENAI_BASE_URL) to use a standard OpenAI-compatible endpoint. " - "Use --no-llm to skip LLM analysis and run static checks only." - ) - - resolved_base = os.environ.get("OPENAI_BASE_URL", "").strip() or None + resolved_key, resolved_base, _ = _resolve_llm_client_config() return resolved_key, resolved_base @@ -84,8 +90,8 @@ def get_chat_model(model: str | None = None) -> ChatOpenAI: Raises: ValueError: when no API key is configured (see ``is_llm_available``). """ - resolved_key, resolved_base = _resolve_llm_credentials() - model = model or MODEL_CONFIG["default"] + resolved_key, resolved_base, default_model = _resolve_llm_client_config() + model = model or default_model return ChatOpenAI( model=model, diff --git a/tests/unit/test_llm_utils.py b/tests/unit/test_llm_utils.py index 97a46c1..04d3fc8 100644 --- a/tests/unit/test_llm_utils.py +++ b/tests/unit/test_llm_utils.py @@ -23,14 +23,25 @@ from __future__ import annotations import pytest +from langchain_openai import ChatOpenAI -from skillspector.llm_utils import _resolve_llm_credentials, is_llm_available +from skillspector.llm_utils import ( + _resolve_llm_credentials, + get_chat_model, + is_llm_available, +) from skillspector.providers import resolve_provider_credentials +from skillspector.providers.nv_build import NvBuildProvider +from skillspector.providers.openai import OpenAIProvider _LLM_ENV_VARS = ( "OPENAI_API_KEY", "OPENAI_BASE_URL", "NVIDIA_INFERENCE_KEY", + "NVIDIA_INFERENCE_METADATA_KEY", + "ANTHROPIC_API_KEY", + "SKILLSPECTOR_MODEL", + "SKILLSPECTOR_PROVIDER", ) @@ -100,3 +111,37 @@ def test_returns_false_with_message_when_no_credentials(self) -> None: assert ok is False assert msg is not None assert "API key" in msg + + +class TestGetChatModel: + def test_openai_fallback_uses_openai_default_model( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("OPENAI_API_KEY", "sk-test-openai-only") + + llm = get_chat_model() + + assert _chat_model_name(llm) == OpenAIProvider.DEFAULT_MODEL + + def test_explicit_model_still_overrides_openai_fallback( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("OPENAI_API_KEY", "sk-test-openai-only") + + llm = get_chat_model(model="custom/model") + + assert _chat_model_name(llm) == "custom/model" + + def test_provider_credentials_use_provider_default_model( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("NVIDIA_INFERENCE_KEY", "nvapi-test") + monkeypatch.setenv("OPENAI_API_KEY", "sk-test-openai") + + llm = get_chat_model() + + assert _chat_model_name(llm) == NvBuildProvider.DEFAULT_MODEL + + +def _chat_model_name(llm: ChatOpenAI) -> str: + return str(getattr(llm, "model_name", None) or getattr(llm, "model", None))