diff --git a/test/backends/test_ollama_unit.py b/test/backends/test_ollama_unit.py new file mode 100644 index 000000000..9ade604b7 --- /dev/null +++ b/test/backends/test_ollama_unit.py @@ -0,0 +1,151 @@ +"""Unit tests for Ollama backend pure-logic helpers — no Ollama server required. + +Covers _simplify_and_merge, _make_backend_specific_and_remove, and +chat_response_delta_merge. +""" + +from unittest.mock import MagicMock, patch + +import ollama +import pytest + +from mellea.backends import ModelOption +from mellea.backends.ollama import OllamaModelBackend, chat_response_delta_merge +from mellea.core import ModelOutputThunk + + +def _make_backend(model_options: dict | None = None) -> OllamaModelBackend: + """Return an OllamaModelBackend with all network calls patched.""" + with ( + patch.object(OllamaModelBackend, "_check_ollama_server", return_value=True), + patch.object(OllamaModelBackend, "_pull_ollama_model", return_value=True), + patch("mellea.backends.ollama.ollama.Client", return_value=MagicMock()), + patch("mellea.backends.ollama.ollama.AsyncClient", return_value=MagicMock()), + ): + return OllamaModelBackend(model_id="granite3.3:8b", model_options=model_options) + + +@pytest.fixture +def backend(): + """Return an OllamaModelBackend with no pre-set model options.""" + return _make_backend() + + +# --- Map consistency --- + + +def test_from_mellea_keys_are_subset_of_to_mellea_values(backend): + """Every key in from_mellea must appear as a value in to_mellea (maps agree).""" + to_values = set(backend.to_mellea_model_opts_map.values()) + from_keys = set(backend.from_mellea_model_opts_map.keys()) + assert from_keys <= to_values, ( + f"from_mellea has keys absent from to_mellea values: {from_keys - to_values}" + ) + + +# --- _simplify_and_merge --- + + +def test_simplify_and_merge_none_returns_empty_dict(backend): + result = backend._simplify_and_merge(None) + assert result == {} + + +def test_simplify_and_merge_all_to_mellea_entries(backend): + """Every 
to_mellea entry remaps to its ModelOption via _simplify_and_merge.""" + for backend_key, mellea_key in backend.to_mellea_model_opts_map.items(): + result = backend._simplify_and_merge({backend_key: 42}) + assert mellea_key in result, f"{backend_key!r} did not produce {mellea_key!r}" + assert result[mellea_key] == 42 + + +def test_simplify_and_merge_remaps_num_predict(backend): + """Hardcoded anchor: the most critical mapping for generation length.""" + result = backend._simplify_and_merge({"num_predict": 128}) + assert ModelOption.MAX_NEW_TOKENS in result + assert result[ModelOption.MAX_NEW_TOKENS] == 128 + + +def test_simplify_and_merge_per_call_overrides_backend(): + # Backend sets num_predict=128; per-call value of 256 must win. + b = _make_backend(model_options={"num_predict": 128}) + result = b._simplify_and_merge({"num_predict": 256}) + assert result[ModelOption.MAX_NEW_TOKENS] == 256 + + +# --- _make_backend_specific_and_remove --- + + +def test_make_backend_specific_all_from_mellea_entries(backend): + """Every from_mellea entry remaps to its backend key via _make_backend_specific_and_remove.""" + for mellea_key, backend_key in backend.from_mellea_model_opts_map.items(): + result = backend._make_backend_specific_and_remove({mellea_key: 42}) + assert backend_key in result, f"{mellea_key!r} did not produce {backend_key!r}" + assert result[backend_key] == 42 + + +def test_make_backend_specific_remaps_max_new_tokens(backend): + """Hardcoded anchor: the most critical mapping for generation length.""" + opts = {ModelOption.MAX_NEW_TOKENS: 64} + result = backend._make_backend_specific_and_remove(opts) + assert "num_predict" in result + assert result["num_predict"] == 64 + + +def test_make_backend_specific_removes_sentinel_keys(backend): + opts = {ModelOption.MAX_NEW_TOKENS: 32, ModelOption.SYSTEM_PROMPT: "sys"} + result = backend._make_backend_specific_and_remove(opts) + # Sentinel keys not in from_mellea_model_opts_map should be removed + assert 
ModelOption.SYSTEM_PROMPT not in result + + +# --- chat_response_delta_merge --- + + +def _make_delta( + content: str, + role: str = "assistant", + done: bool = False, + thinking: str | None = None, +) -> ollama.ChatResponse: + msg = ollama.Message(role=role, content=content, thinking=thinking) + return ollama.ChatResponse(model="test", created_at=None, message=msg, done=done) + + +def test_delta_merge_first_sets_chat_response(): + mot = ModelOutputThunk(value=None) + delta = _make_delta("Hello") + chat_response_delta_merge(mot, delta) + assert mot._meta["chat_response"] is delta + + +def test_delta_merge_second_appends_content(): + mot = ModelOutputThunk(value=None) + chat_response_delta_merge(mot, _make_delta("Hello")) + chat_response_delta_merge(mot, _make_delta(" world")) + assert mot._meta["chat_response"].message.content == "Hello world" + + +def test_delta_merge_done_propagated(): + mot = ModelOutputThunk(value=None) + chat_response_delta_merge(mot, _make_delta("partial", done=False)) + chat_response_delta_merge(mot, _make_delta("", done=True)) + assert mot._meta["chat_response"].done is True + + +def test_delta_merge_role_set_from_first_delta(): + mot = ModelOutputThunk(value=None) + chat_response_delta_merge(mot, _make_delta("hi", role="assistant")) + chat_response_delta_merge(mot, _make_delta(" there", role="")) + assert mot._meta["chat_response"].message.role == "assistant" + + +def test_delta_merge_thinking_concatenated(): + mot = ModelOutputThunk(value=None) + chat_response_delta_merge(mot, _make_delta("reply", thinking="step 1")) + chat_response_delta_merge(mot, _make_delta("", thinking=" step 2")) + assert mot._meta["chat_response"].message.thinking == "step 1 step 2" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/backends/test_openai_unit.py b/test/backends/test_openai_unit.py new file mode 100644 index 000000000..09524df8c --- /dev/null +++ b/test/backends/test_openai_unit.py @@ -0,0 +1,172 @@ +"""Unit tests for 
OpenAI backend pure-logic helpers — no API calls required.
+
+Covers filter_openai_client_kwargs, filter_chat_completions_kwargs,
+_simplify_and_merge, and _make_backend_specific_and_remove.
+"""
+
+import pytest
+
+from mellea.backends import ModelOption
+from mellea.backends.openai import OpenAIBackend
+
+
+def _make_backend(model_options: dict | None = None) -> OpenAIBackend:
+    """Return an OpenAIBackend with a fake API key."""
+    return OpenAIBackend(
+        model_id="gpt-4o",
+        api_key="fake-key",
+        base_url="http://localhost:9999/v1",
+        model_options=model_options,
+    )
+
+
+@pytest.fixture
+def backend():
+    """Return an OpenAIBackend with no pre-set model options."""
+    return _make_backend()
+
+
+# --- filter_openai_client_kwargs ---
+
+
+def test_filter_openai_client_kwargs_removes_unknown():
+    result = OpenAIBackend.filter_openai_client_kwargs(
+        api_key="sk-test", unknown_param="x"
+    )
+    assert "api_key" in result
+    assert "unknown_param" not in result
+
+
+def test_filter_openai_client_kwargs_known_params():
+    result = OpenAIBackend.filter_openai_client_kwargs(
+        api_key="sk-test", base_url="http://localhost", timeout=30
+    )
+    # timeout is passed in as well, so it must survive filtering like the others.
+    assert {"api_key", "base_url", "timeout"} <= set(result)
+
+
+def test_filter_openai_client_kwargs_empty():
+    result = OpenAIBackend.filter_openai_client_kwargs()
+    assert result == {}
+
+
+# --- filter_chat_completions_kwargs ---
+
+
+def test_filter_chat_completions_keeps_valid_params(backend):
+    result = backend.filter_chat_completions_kwargs(
+        {"model": "gpt-4o", "temperature": 0.7, "unknown_option": True}
+    )
+    assert "model" in result
+    assert "temperature" in result
+    assert "unknown_option" not in result
+
+
+def test_filter_chat_completions_empty(backend):
+    result = backend.filter_chat_completions_kwargs({})
+    assert result == {}
+
+
+def test_filter_chat_completions_max_tokens(backend):
+    result = backend.filter_chat_completions_kwargs({"max_completion_tokens": 100})
+    assert "max_completion_tokens" in result
+
+
+# --- Map
consistency --- + + +@pytest.mark.parametrize("context", ["chats", "completions"]) +def test_from_mellea_keys_are_subset_of_to_mellea_values(backend, context): + """Every key in from_mellea must appear as a value in to_mellea (maps agree).""" + to_map = getattr(backend, f"to_mellea_model_opts_map_{context}") + from_map = getattr(backend, f"from_mellea_model_opts_map_{context}") + to_values = set(to_map.values()) + from_keys = set(from_map.keys()) + assert from_keys <= to_values, ( + f"from_mellea_{context} has keys absent from to_mellea values: {from_keys - to_values}" + ) + + +# --- _simplify_and_merge --- + + +def test_simplify_and_merge_none_returns_empty_dict(backend): + result = backend._simplify_and_merge(None, is_chat_context=True) + assert result == {} + + +@pytest.mark.parametrize("context", ["chats", "completions"]) +def test_simplify_and_merge_all_to_mellea_entries(backend, context): + """Every to_mellea entry remaps to its ModelOption via _simplify_and_merge.""" + is_chat = context == "chats" + to_map = getattr(backend, f"to_mellea_model_opts_map_{context}") + for backend_key, mellea_key in to_map.items(): + result = backend._simplify_and_merge({backend_key: 42}, is_chat_context=is_chat) + assert mellea_key in result, f"{backend_key!r} did not produce {mellea_key!r}" + assert result[mellea_key] == 42 + + +def test_simplify_and_merge_remaps_max_completion_tokens(backend): + """Hardcoded anchor: the critical chat API mapping for generation length.""" + result = backend._simplify_and_merge( + {"max_completion_tokens": 256}, is_chat_context=True + ) + assert ModelOption.MAX_NEW_TOKENS in result + assert result[ModelOption.MAX_NEW_TOKENS] == 256 + + +def test_simplify_and_merge_completions_remaps_max_tokens(backend): + """Hardcoded anchor: completions API uses a different key for the same sentinel.""" + result = backend._simplify_and_merge({"max_tokens": 100}, is_chat_context=False) + assert ModelOption.MAX_NEW_TOKENS in result + assert 
result[ModelOption.MAX_NEW_TOKENS] == 100 + + +def test_simplify_and_merge_per_call_overrides_backend(): + # Backend sets max_completion_tokens=128; per-call value of 512 must win. + b = _make_backend(model_options={"max_completion_tokens": 128}) + result = b._simplify_and_merge({"max_completion_tokens": 512}, is_chat_context=True) + assert result[ModelOption.MAX_NEW_TOKENS] == 512 + + +# --- _make_backend_specific_and_remove --- + + +@pytest.mark.parametrize("context", ["chats", "completions"]) +def test_make_backend_specific_all_from_mellea_entries(backend, context): + """Every from_mellea entry remaps to its backend key via _make_backend_specific_and_remove.""" + is_chat = context == "chats" + from_map = getattr(backend, f"from_mellea_model_opts_map_{context}") + for mellea_key, backend_key in from_map.items(): + result = backend._make_backend_specific_and_remove( + {mellea_key: 42}, is_chat_context=is_chat + ) + assert backend_key in result, f"{mellea_key!r} did not produce {backend_key!r}" + assert result[backend_key] == 42 + + +def test_make_backend_specific_chat_remaps_max_new_tokens(backend): + """Hardcoded anchor: chat API maps MAX_NEW_TOKENS → max_completion_tokens.""" + opts = {ModelOption.MAX_NEW_TOKENS: 200} + result = backend._make_backend_specific_and_remove(opts, is_chat_context=True) + assert "max_completion_tokens" in result + assert result["max_completion_tokens"] == 200 + + +def test_make_backend_specific_completions_remaps_max_new_tokens(backend): + """Hardcoded anchor: completions API maps MAX_NEW_TOKENS → max_tokens.""" + opts = {ModelOption.MAX_NEW_TOKENS: 100} + result = backend._make_backend_specific_and_remove(opts, is_chat_context=False) + assert "max_tokens" in result + assert result["max_tokens"] == 100 + + +def test_make_backend_specific_unknown_mellea_keys_removed(backend): + opts = {ModelOption.TOOLS: ["tool1"], ModelOption.SYSTEM_PROMPT: "sys"} + result = backend._make_backend_specific_and_remove(opts, is_chat_context=True) + # 
SYSTEM_PROMPT has no from_mellea mapping — should be removed + assert ModelOption.SYSTEM_PROMPT not in result + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/backends/test_utils.py b/test/backends/test_utils.py new file mode 100644 index 000000000..764c8c9c7 --- /dev/null +++ b/test/backends/test_utils.py @@ -0,0 +1,164 @@ +"""Unit tests for backends/utils.py — get_value accessor and to_tool_calls parser.""" + +from dataclasses import dataclass + +import pytest + +from mellea.backends.tools import MelleaTool +from mellea.backends.utils import get_value, to_tool_calls +from mellea.core import ModelToolCall + +# --- get_value --- + + +def test_get_value_dict_present(): + assert get_value({"a": 1, "b": 2}, "a") == 1 + + +def test_get_value_dict_missing(): + assert get_value({"a": 1}, "missing") is None + + +def test_get_value_object_attribute(): + obj = type("Obj", (), {"x": "hello"})() + assert get_value(obj, "x") == "hello" + + +def test_get_value_object_missing_attribute(): + obj = type("Obj", (), {})() + assert get_value(obj, "nonexistent") is None + + +def test_get_value_dict_none_value(): + # Explicitly stored None should come back as None (same as get()) + assert get_value({"k": None}, "k") is None + + +@dataclass +class _DC: + score: float + label: str + + +def test_get_value_dataclass(): + dc = _DC(score=0.9, label="positive") + assert get_value(dc, "score") == 0.9 + assert get_value(dc, "label") == "positive" + + +# --- to_tool_calls --- + + +def _make_tool_registry() -> dict: + def add(x: int, y: int) -> int: + """Add two integers.""" + return x + y + + def greet(name: str) -> str: + """Greet a person.""" + return f"Hello, {name}!" 
+ + return { + "add": MelleaTool.from_callable(add), + "greet": MelleaTool.from_callable(greet), + } + + +def _tool_call_json(name: str, args: dict) -> str: + import json + + return json.dumps([{"name": name, "arguments": args}]) + + +def test_to_tool_calls_single_call(): + registry = _make_tool_registry() + raw = _tool_call_json("add", {"x": 3, "y": 4}) + result = to_tool_calls(registry, raw) + assert result is not None + assert "add" in result + mtc = result["add"] + assert isinstance(mtc, ModelToolCall) + assert mtc.name == "add" + assert mtc.args == {"x": 3, "y": 4} + + +def test_to_tool_calls_returns_none_when_no_calls(): + registry = _make_tool_registry() + result = to_tool_calls(registry, "no tool call here") + assert result is None + + +def test_to_tool_calls_unknown_tool_skipped(): + registry = _make_tool_registry() + raw = _tool_call_json("nonexistent_fn", {"arg": "val"}) + # Unknown tool is skipped — result should be None (empty dict → None) + result = to_tool_calls(registry, raw) + assert result is None + + +def test_to_tool_calls_empty_params_cleared(): + """When the tool has no parameters, hallucinated args should be stripped.""" + + def noop() -> str: + """Does nothing.""" + return "done" + + registry = {"noop": MelleaTool.from_callable(noop)} + raw = _tool_call_json("noop", {"hallucinated": "arg"}) + result = to_tool_calls(registry, raw) + assert result is not None + assert result["noop"].args == {} + + +def test_to_tool_calls_string_arg_coerced_to_int(): + """validate_tool_arguments coerces strings to int when strict=False.""" + registry = _make_tool_registry() + raw = _tool_call_json("add", {"x": "5", "y": "10"}) + result = to_tool_calls(registry, raw) + assert result is not None + assert result["add"].args["x"] == 5 + assert result["add"].args["y"] == 10 + + +# --- to_chat --- + + +def test_to_chat_basic_message(): + from mellea.backends.utils import to_chat + from mellea.formatters.template_formatter import TemplateFormatter as ChatFormatter + 
from mellea.stdlib.components import Message + from mellea.stdlib.context import ChatContext + + ctx = ChatContext() + ctx = ctx.add(Message("user", "hello")) + action = Message("user", "next question") + formatter = ChatFormatter(model_id="test") + + result = to_chat(action, ctx, formatter, system_prompt=None) + assert isinstance(result, list) + assert len(result) == 2 + assert result[0]["role"] == "user" + assert result[0]["content"] == "hello" + assert result[1]["role"] == "user" + assert result[1]["content"] == "next question" + + +def test_to_chat_with_system_prompt(): + from mellea.backends.utils import to_chat + from mellea.formatters.template_formatter import TemplateFormatter as ChatFormatter + from mellea.stdlib.components import Message + from mellea.stdlib.context import ChatContext + + ctx = ChatContext() + ctx = ctx.add(Message("user", "hi")) + action = Message("user", "q") + formatter = ChatFormatter(model_id="test") + + result = to_chat(action, ctx, formatter, system_prompt="You are helpful.") + assert result[0]["role"] == "system" + assert result[0]["content"] == "You are helpful." 
+ assert len(result) == 3 # system + user context + user action + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/core/test_base.py b/test/core/test_base.py index cada2f42e..ff89148aa 100644 --- a/test/core/test_base.py +++ b/test/core/test_base.py @@ -1,8 +1,11 @@ +import base64 +import io from typing import Any import pytest +from PIL import Image as PILImage -from mellea.core import CBlock, Component, ModelOutputThunk +from mellea.core import CBlock, Component, ImageBlock, ModelOutputThunk from mellea.stdlib.components import Message @@ -66,5 +69,93 @@ def __init__(self, msg: Message) -> None: assert result.parsed_repr.content == "result value" +# --- CBlock edge cases --- + + +def test_cblock_non_string_value_raises(): + with pytest.raises(TypeError, match="should always be a string or None"): + CBlock(value=42) # type: ignore + + +def test_cblock_none_value_allowed(): + cb = CBlock(value=None) + assert str(cb) == "" + + +def test_cblock_value_setter(): + cb = CBlock(value="old") + cb.value = "new" + assert cb.value == "new" + + +# --- ImageBlock.is_valid_base64_png --- + + +def _make_png_b64() -> str: + img = PILImage.new("RGB", (1, 1), color="red") + buf = io.BytesIO() + img.save(buf, format="PNG") + return base64.b64encode(buf.getvalue()).decode() + + +def test_image_block_valid_png(): + b64 = _make_png_b64() + assert ImageBlock.is_valid_base64_png(b64) is True + + +def test_image_block_invalid_base64_returns_false(): + assert ImageBlock.is_valid_base64_png("not-base64!!!") is False + + +def test_image_block_valid_base64_but_not_png(): + # Base64-encoded JPEG magic bytes + jpg_magic = base64.b64encode(b"\xff\xd8\xff" + b"\x00" * 20).decode() + assert ImageBlock.is_valid_base64_png(jpg_magic) is False + + +def test_image_block_data_uri_prefix_stripped(): + b64 = _make_png_b64() + data_uri = f"data:image/png;base64,{b64}" + assert ImageBlock.is_valid_base64_png(data_uri) is True + + +def test_image_block_invalid_value_raises(): 
+ with pytest.raises(AssertionError, match="Invalid base64"): + ImageBlock(value="not-a-png") + + +# --- ModelOutputThunk._copy_from --- + + +def test_mot_copy_from_copies_underlying_value(): + a = ModelOutputThunk(value=None) + b = ModelOutputThunk(value="copied") + a._copy_from(b) + # _copy_from copies _underlying_value (not _computed), so check raw field + assert a._underlying_value == "copied" + + +def test_mot_copy_from_copies_meta(): + a = ModelOutputThunk(value=None) + b = ModelOutputThunk(value="x", meta={"key": "val"}) + a._copy_from(b) + assert a._meta["key"] == "val" + + +def test_mot_copy_from_copies_tool_calls(): + a = ModelOutputThunk(value=None) + b = ModelOutputThunk(value="x", tool_calls={"fn": None}) + a._copy_from(b) + assert a.tool_calls == {"fn": None} + + +def test_mot_copy_from_copies_usage(): + a = ModelOutputThunk(value=None) + b = ModelOutputThunk(value="x") + b.usage = {"prompt_tokens": 10} + a._copy_from(b) + assert a.usage == {"prompt_tokens": 10} + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/core/test_requirement_helpers.py b/test/core/test_requirement_helpers.py new file mode 100644 index 000000000..9ae3f71b5 --- /dev/null +++ b/test/core/test_requirement_helpers.py @@ -0,0 +1,91 @@ +"""Unit tests for core/requirement.py pure helpers — ValidationResult, default_output_to_bool.""" + +import pytest + +from mellea.core import CBlock, ModelOutputThunk +from mellea.core.requirement import ValidationResult, default_output_to_bool + +# --- ValidationResult --- + + +def test_validation_result_pass(): + r = ValidationResult(result=True) + assert r.as_bool() is True + assert bool(r) is True + + +def test_validation_result_fail(): + r = ValidationResult(result=False) + assert r.as_bool() is False + assert bool(r) is False + + +def test_validation_result_reason(): + r = ValidationResult(result=True, reason="looks good") + assert r.reason == "looks good" + + +def test_validation_result_score(): + r = 
ValidationResult(result=True, score=0.95) + assert r.score == pytest.approx(0.95) + + +def test_validation_result_thunk(): + mot = ModelOutputThunk(value="x") + r = ValidationResult(result=True, thunk=mot) + assert r.thunk is mot + + +def test_validation_result_context(): + from mellea.stdlib.context import SimpleContext + + ctx = SimpleContext() + r = ValidationResult(result=True, context=ctx) + assert r.context is ctx + + +def test_validation_result_defaults_none(): + r = ValidationResult(result=False) + assert r.reason is None + assert r.score is None + assert r.thunk is None + assert r.context is None + + +# --- default_output_to_bool --- + + +def test_yes_exact_passes(): + assert default_output_to_bool(CBlock("yes")) is True + + +def test_yes_uppercase_passes(): + assert default_output_to_bool(CBlock("YES")) is True + + +def test_y_passes(): + assert default_output_to_bool(CBlock("y")) is True + + +def test_yes_in_sentence(): + assert default_output_to_bool(CBlock("Yes, it meets the requirement.")) is True + + +def test_no_fails(): + assert default_output_to_bool(CBlock("no")) is False + + +def test_empty_string_fails(): + assert default_output_to_bool(CBlock("")) is False + + +def test_random_text_fails(): + assert default_output_to_bool(CBlock("the output looks reasonable")) is False + + +def test_plain_string_yes(): + assert default_output_to_bool("YES") is True # type: ignore + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/formatters/granite/base/__init__.py b/test/formatters/granite/base/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/formatters/granite/base/test_base_util.py b/test/formatters/granite/base/test_base_util.py new file mode 100644 index 000000000..0a78b3e37 --- /dev/null +++ b/test/formatters/granite/base/test_base_util.py @@ -0,0 +1,61 @@ +"""Unit tests for formatters/granite/base/util.py pure helpers.""" + +import pytest + +from mellea.formatters.granite.base.util import 
find_substring_in_text + +# --- find_substring_in_text --- + + +def test_find_single_match(): + result = find_substring_in_text("hello", "say hello world") + assert len(result) == 1 + assert result[0]["begin_idx"] == 4 + assert result[0]["end_idx"] == 9 + + +def test_find_multiple_matches(): + result = find_substring_in_text("ab", "ababab") + assert len(result) == 3 + # Verify positions are non-overlapping + assert result[0]["begin_idx"] == 0 + assert result[1]["begin_idx"] == 2 + assert result[2]["begin_idx"] == 4 + + +def test_find_no_match_returns_empty(): + result = find_substring_in_text("xyz", "hello world") + assert result == [] + + +def test_find_empty_text_returns_empty(): + result = find_substring_in_text("hello", "") + assert result == [] + + +def test_find_at_start(): + result = find_substring_in_text("the", "the quick fox") + assert result[0]["begin_idx"] == 0 + + +def test_find_at_end(): + result = find_substring_in_text("fox", "the quick fox") + assert result[-1]["end_idx"] == len("the quick fox") + + +def test_find_full_text_match(): + result = find_substring_in_text("exact", "exact") + assert len(result) == 1 + assert result[0]["begin_idx"] == 0 + assert result[0]["end_idx"] == 5 + + +def test_find_special_regex_chars_escaped(): + # Dots in the substring should be treated literally + result = find_substring_in_text("a.b", "a.b and axb") + assert len(result) == 1 + assert result[0]["begin_idx"] == 0 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/formatters/granite/test_granite32_output.py b/test/formatters/granite/test_granite32_output.py index a5d0aa89f..356588628 100644 --- a/test/formatters/granite/test_granite32_output.py +++ b/test/formatters/granite/test_granite32_output.py @@ -29,6 +29,7 @@ Granite3Controls, Granite3Kwargs, ) +from test.predicates import require_nltk_data # --------------------------------------------------------------------------- # _parse_citations_text @@ -285,6 +286,7 @@ def 
test_invalid_tool_call_falls_through(self): assert result.tool_calls == [] assert isinstance(result.content, str) + @require_nltk_data() def test_citations_and_hallucinations_pipeline(self): proc = Granite32OutputProcessor() model_output = ( diff --git a/test/formatters/granite/test_granite33_output.py b/test/formatters/granite/test_granite33_output.py index 8ce4e8827..a283d09ae 100644 --- a/test/formatters/granite/test_granite33_output.py +++ b/test/formatters/granite/test_granite33_output.py @@ -31,6 +31,7 @@ Granite3Controls, Granite3Kwargs, ) +from test.predicates import require_nltk_data # --------------------------------------------------------------------------- # _parse_citations_text @@ -256,6 +257,7 @@ def test_tool_call_parsing(self): assert len(result.tool_calls) == 1 assert result.tool_calls[0].name == "search" + @require_nltk_data() def test_raw_content_set_when_different(self): proc = Granite33OutputProcessor() model_output = ( diff --git a/test/predicates.py b/test/predicates.py index 2c00f1865..26f34c507 100644 --- a/test/predicates.py +++ b/test/predicates.py @@ -156,6 +156,52 @@ def test_watsonx_generate(): ... return pytest.mark.skipif(False, reason="") +# --------------------------------------------------------------------------- +# NLTK data +# --------------------------------------------------------------------------- + + +def _nltk_data_available() -> tuple[bool, str]: + """Check whether nltk is installed *and* punkt_tab data is downloaded. 
+ + Returns a (available, reason) tuple so the skip message is specific: + - nltk not installed → "nltk not installed — install mellea[formatters]" + - punkt_tab missing → "NLTK punkt_tab data not downloaded — run: python -m nltk.downloader punkt_tab" + - both ok → (True, "") + """ + try: + import nltk + except ImportError: + return False, "nltk not installed — install mellea[formatters]" + + try: + import nltk.data + + nltk.data.find("tokenizers/punkt_tab") + except LookupError: + return ( + False, + "NLTK punkt_tab data not downloaded — run: python -m nltk.downloader punkt_tab", + ) + + return True, "" + + +def require_nltk_data(): + """Skip unless nltk is installed and punkt_tab tokenizer data is available. + + Distinguishes between the two failure modes so the skip reason is actionable:: + + @require_nltk_data() + def test_citation_spans(): ... + + # Module-level (skips all tests in the file): + pytestmark = [require_nltk_data()] + """ + available, reason = _nltk_data_available() + return pytest.mark.skipif(not available, reason=reason) + + # --------------------------------------------------------------------------- # Optional dependencies # --------------------------------------------------------------------------- diff --git a/test/stdlib/components/test_chat.py b/test/stdlib/components/test_chat.py index 66ebb9fc2..2319aafff 100644 --- a/test/stdlib/components/test_chat.py +++ b/test/stdlib/components/test_chat.py @@ -1,7 +1,10 @@ import pytest +from mellea.core import CBlock, ModelOutputThunk, TemplateRepresentation from mellea.helpers import messages_to_docs from mellea.stdlib.components import Document, Message +from mellea.stdlib.components.chat import ToolMessage, as_chat_history +from mellea.stdlib.context import ChatContext def test_message_with_docs(): @@ -22,5 +25,231 @@ def test_message_with_docs(): assert tr.args["documents"] +# --- Message init --- + + +def test_message_basic_fields(): + msg = Message("user", "hello") + assert msg.role == "user" 
+ assert msg.content == "hello" + assert msg._images is None + assert msg._docs is None + + +def test_message_content_block_created(): + msg = Message("assistant", "response") + assert isinstance(msg._content_cblock, CBlock) + assert msg._content_cblock.value == "response" + + +def test_message_repr(): + msg = Message("user", "hi there") + r = repr(msg) + assert 'role="user"' in r + assert 'content="hi there"' in r + + +# --- Message images property --- + + +def test_message_images_none(): + msg = Message("user", "text") + assert msg.images is None + + +# --- Message parts() --- + + +def test_message_parts_no_docs_no_images(): + msg = Message("user", "text") + parts = msg.parts() + assert len(parts) == 1 + assert parts[0] is msg._content_cblock + + +def test_message_parts_with_docs(): + doc = Document("text", "title") + msg = Message("user", "hi", documents=[doc]) + parts = msg.parts() + assert doc in parts + + +# --- Message format_for_llm --- + + +def test_message_format_for_llm_structure(): + msg = Message("user", "hello") + tr = msg.format_for_llm() + assert isinstance(tr, TemplateRepresentation) + assert tr.args["role"] == "user" + assert tr.args["content"] is msg._content_cblock + assert tr.args["images"] is None + assert tr.args["documents"] is None + + +# --- Message._parse — no tool calls --- + + +def test_parse_plain_value_no_meta(): + msg = Message("user", "original") + mot = ModelOutputThunk(value="model response") + result = msg._parse(mot) + assert isinstance(result, Message) + assert result.role == "assistant" + assert result.content == "model response" + + +def test_parse_ollama_chat_response(): + msg = Message("user", "q") + mot = ModelOutputThunk(value="v") + fake_response = type( + "Resp", + (), + { + "message": type( + "Msg", (), {"role": "assistant", "content": "ollama answer"} + )() + }, + )() + mot._meta["chat_response"] = fake_response + result = msg._parse(mot) + assert result.role == "assistant" + assert result.content == "ollama answer" + 
+ +def test_parse_openai_chat_response(): + msg = Message("user", "q") + mot = ModelOutputThunk(value="v") + mot._meta["oai_chat_response"] = { + "choices": [{"message": {"role": "assistant", "content": "openai answer"}}] + } + result = msg._parse(mot) + assert result.role == "assistant" + assert result.content == "openai answer" + + +# --- Message._parse — with tool calls --- + + +def test_parse_tool_calls_ollama(): + msg = Message("user", "q") + mot = ModelOutputThunk(value="v", tool_calls={"some_fn": None}) + fake_calls = [{"name": "some_fn"}] + fake_response = type( + "Resp", + (), + {"message": type("Msg", (), {"role": "assistant", "tool_calls": fake_calls})()}, + )() + mot._meta["chat_response"] = fake_response + result = msg._parse(mot) + assert result.role == "assistant" + assert "some_fn" in result.content + + +def test_parse_tool_calls_openai(): + msg = Message("user", "q") + mot = ModelOutputThunk(value="v", tool_calls={"fn": None}) + mot._meta["oai_chat_response"] = { + "choices": [ + { + "message": { + "role": "assistant", + "tool_calls": [{"function": {"name": "fn"}}], + } + } + ] + } + result = msg._parse(mot) + assert result.role == "assistant" + + +def test_parse_tool_calls_fallback_uses_value(): + """No chat_response or oai_chat_response — falls back to computed.value.""" + msg = Message("user", "q") + mot = ModelOutputThunk(value="fn()", tool_calls={"fn": None}) + result = msg._parse(mot) + assert result.role == "assistant" + assert result.content == "fn()" + + +# --- ToolMessage --- + + +def test_tool_message_fields(): + from mellea.core import ModelToolCall + + fake_tool = type("T", (), {"as_json_tool": {}})() + mtc = ModelToolCall("my_tool", fake_tool, {"x": 1}) + tm = ToolMessage( + role="tool", + content='{"result": 42}', + tool_output=42, + name="my_tool", + args={"x": 1}, + tool=mtc, + ) + assert tm.role == "tool" + assert tm.name == "my_tool" + assert tm.arguments == {"x": 1} + + +def test_tool_message_format_for_llm_includes_name(): + 
from mellea.core import ModelToolCall + + fake_tool = type("T", (), {"as_json_tool": {}})() + mtc = ModelToolCall("my_tool", fake_tool, {}) + tm = ToolMessage( + role="tool", + content="output", + tool_output="output", + name="my_tool", + args={}, + tool=mtc, + ) + tr = tm.format_for_llm() + assert isinstance(tr, TemplateRepresentation) + assert tr.args["name"] == "my_tool" + + +def test_tool_message_repr(): + from mellea.core import ModelToolCall + + fake_tool = type("T", (), {"as_json_tool": {}})() + mtc = ModelToolCall("fn", fake_tool, {}) + tm = ToolMessage("tool", "out", "out", "fn", {}, mtc) + r = repr(tm) + assert 'name="fn"' in r + + +# --- as_chat_history --- + + +def test_as_chat_history_messages_only(): + ctx = ChatContext() + ctx = ctx.add(Message("user", "hello")) + ctx = ctx.add(Message("assistant", "hi")) + history = as_chat_history(ctx) + assert len(history) == 2 + assert history[0].role == "user" + assert history[1].role == "assistant" + + +def test_as_chat_history_empty(): + ctx = ChatContext() + history = as_chat_history(ctx) + assert history == [] + + +def test_as_chat_history_with_parsed_mot(): + ctx = ChatContext() + ctx = ctx.add(Message("user", "hello")) + mot = ModelOutputThunk(value="reply") + mot.parsed_repr = Message("assistant", "reply") + ctx = ctx.add(mot) + history = as_chat_history(ctx) + assert len(history) == 2 + assert history[1].content == "reply" + + if __name__ == "__main__": - pytest.main([__file__]) + pytest.main([__file__, "-v"]) diff --git a/test/stdlib/components/test_genstub_unit.py b/test/stdlib/components/test_genstub_unit.py new file mode 100644 index 000000000..20d046814 --- /dev/null +++ b/test/stdlib/components/test_genstub_unit.py @@ -0,0 +1,326 @@ +"""Unit tests for genstub pure-logic helpers — no backend, no LLM required. + +Covers describe_function, get_argument, bind_function_arguments, +create_response_format, GenerativeStub.format_for_llm, and @generative routing. 
+""" + +from typing import Literal + +import pytest + +from mellea import generative +from mellea.core import TemplateRepresentation, ValidationResult +from mellea.stdlib.components.genstub import ( + ArgPreconditionRequirement, + Arguments, + AsyncGenerativeStub, + Function, + PreconditionException, + SyncGenerativeStub, + bind_function_arguments, + create_response_format, + describe_function, + get_argument, +) +from mellea.stdlib.requirements.requirement import reqify + +# --- describe_function --- + + +def test_describe_function_name(): + def greet(name: str) -> str: + """Say hello.""" + return f"Hello {name}" + + result = describe_function(greet) + assert result["name"] == "greet" + + +def test_describe_function_signature_includes_params(): + def add(x: int, y: int) -> int: + return x + y + + result = describe_function(add) + assert "x" in result["signature"] + assert "y" in result["signature"] + + +def test_describe_function_docstring(): + def noop() -> None: + """Does nothing.""" + + result = describe_function(noop) + assert result["docstring"] == "Does nothing." 
+ + +def test_describe_function_no_docstring(): + def bare(): + pass + + result = describe_function(bare) + assert result["docstring"] is None + + +# --- get_argument --- + + +def test_get_argument_string_value_quoted(): + def fn(name: str) -> None: + pass + + arg = get_argument(fn, "name", "Alice") + assert arg._argument_dict["value"] == '"Alice"' + assert arg._argument_dict["name"] == "name" + + +def test_get_argument_int_value_not_quoted(): + def fn(count: int) -> None: + pass + + arg = get_argument(fn, "count", 42) + assert arg._argument_dict["value"] == 42 + assert "int" in str(arg._argument_dict["annotation"]) + + +def test_get_argument_no_annotation_falls_back_to_runtime_type(): + # No annotation on kwargs — should fall back to type(val) + def fn(**kwargs) -> None: + pass + + arg = get_argument(fn, "x", 3.14) + assert "float" in str(arg._argument_dict["annotation"]) + + +# --- bind_function_arguments --- + + +def test_bind_function_arguments_basic(): + def fn(x: int, y: int) -> int: + return x + y + + result = bind_function_arguments(fn, x=1, y=2) + assert result == {"x": 1, "y": 2} + + +def test_bind_function_arguments_with_defaults(): + def fn(x: int, y: int = 10) -> int: + return x + y + + result = bind_function_arguments(fn, x=5) + assert result == {"x": 5, "y": 10} + + +def test_bind_function_arguments_missing_required_raises(): + def fn(x: int, y: int) -> int: + return x + y + + with pytest.raises(TypeError, match="missing required parameter"): + bind_function_arguments(fn, x=1) + + +def test_bind_function_arguments_no_params(): + def fn() -> str: + return "hi" + + result = bind_function_arguments(fn) + assert result == {} + + +# --- create_response_format --- + + +def test_create_response_format_class_name_derived_from_func(): + def get_sentiment() -> str: ... + + model = create_response_format(get_sentiment) + assert "GetSentiment" in model.__name__ + + +def test_create_response_format_result_field_accessible(): + def score_text() -> float: ... 
+ + model = create_response_format(score_text) + instance = model(result=0.9) + assert instance.result == 0.9 + + +def test_create_response_format_literal_type(): + def classify() -> Literal["pos", "neg"]: ... + + model = create_response_format(classify) + instance = model(result="pos") + assert instance.result == "pos" + + +# --- GenerativeStub.format_for_llm --- + + +def test_generative_stub_format_for_llm_returns_template_repr(): + @generative + def summarise(text: str) -> str: + """Summarise the given text.""" + + result = summarise.format_for_llm() + assert isinstance(result, TemplateRepresentation) + + +def test_generative_stub_format_for_llm_includes_function_name(): + @generative + def my_function(x: int) -> int: ... + + result = my_function.format_for_llm() + assert result.args["function"]["name"] == "my_function" + + +def test_generative_stub_format_for_llm_includes_docstring(): + @generative + def documented() -> str: + """This is the docstring.""" + + result = documented.format_for_llm() + assert result.args["function"]["docstring"] == "This is the docstring." + + +def test_generative_stub_format_for_llm_no_args_until_called(): + @generative + def fn() -> str: ... + + result = fn.format_for_llm() + assert result.args["arguments"] is None + + +# --- @generative decorator routing --- + + +def test_generative_sync_function_returns_sync_stub(): + @generative + def sync_fn() -> str: ... + + assert isinstance(sync_fn, SyncGenerativeStub) + + +def test_generative_async_function_returns_async_stub(): + @generative + async def async_fn() -> str: ... + + assert isinstance(async_fn, AsyncGenerativeStub) + + +def test_generative_disallowed_param_name_raises(): + with pytest.raises(ValueError, match="disallowed parameter names"): + + @generative + def fn(backend: str) -> str: ... 
+ + +# --- Arguments (CBlock subclass rendering bound args) --- + + +def test_arguments_renders_text(): + def fn(name: str, count: int) -> None: + pass + + args = [get_argument(fn, "name", "Alice"), get_argument(fn, "count", 3)] + block = Arguments(args) + assert "name" in block.value + assert "count" in block.value + + +def test_arguments_stores_meta_by_name(): + def fn(x: int) -> None: + pass + + args = [get_argument(fn, "x", 5)] + block = Arguments(args) + assert "x" in block._meta + + +def test_arguments_empty_list(): + block = Arguments([]) + assert block.value == "" + + +# --- Function (wraps callable with metadata) --- + + +def test_function_stores_callable(): + def greet(name: str) -> str: + """Say hi.""" + return f"hi {name}" + + f = Function(greet) + assert f._func is greet + assert f._function_dict["name"] == "greet" + assert f._function_dict["docstring"] == "Say hi." + + +# --- ArgPreconditionRequirement (requirement wrapper) --- + + +def test_arg_precondition_delegates_description(): + req = reqify("must be non-empty") + wrapper = ArgPreconditionRequirement(req) + assert wrapper.description == req.description + + +def test_arg_precondition_copy(): + from copy import copy + + req = reqify("be valid") + wrapper = ArgPreconditionRequirement(req) + copied = copy(wrapper) + assert isinstance(copied, ArgPreconditionRequirement) + assert copied.req is req + + +def test_arg_precondition_deepcopy(): + from copy import deepcopy + + req = reqify("be clean") + wrapper = ArgPreconditionRequirement(req) + cloned = deepcopy(wrapper) + assert isinstance(cloned, ArgPreconditionRequirement) + assert cloned.description == req.description + + +# --- PreconditionException --- + + +def test_precondition_exception_message(): + vr = ValidationResult(result=False, reason="failed check") + exc = PreconditionException("precondition failed", [vr]) + assert "precondition failed" in str(exc) + assert exc.validation == [vr] + + +# --- GenerativeStub._parse --- + + +def 
test_genstub_parse_json_to_result(): + import json + + from mellea.core import ModelOutputThunk + + @generative + def classify(text: str) -> str: ... + + mot = ModelOutputThunk(value=json.dumps({"result": "positive"})) + parsed = classify._parse(mot) + assert parsed == "positive" + + +def test_genstub_parse_int_result(): + import json + + from mellea.core import ModelOutputThunk + + @generative + def compute(x: int) -> int: ... + + mot = ModelOutputThunk(value=json.dumps({"result": 42})) + parsed = compute._parse(mot) + assert parsed == 42 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/stdlib/components/test_instruction.py b/test/stdlib/components/test_instruction.py new file mode 100644 index 000000000..aef2b6984 --- /dev/null +++ b/test/stdlib/components/test_instruction.py @@ -0,0 +1,264 @@ +"""Unit tests for the Instruction component — init, jinja rendering, copy/repair, parts, format.""" + +import pytest + +from mellea.core import CBlock, ModelOutputThunk, Requirement, TemplateRepresentation +from mellea.stdlib.components.instruction import Instruction + +# --- basic init --- + + +def test_init_minimal(): + ins = Instruction(description="summarise the text") + assert ins._description is not None + assert str(ins._description) == "summarise the text" + assert ins._requirements == [] + assert ins._icl_examples == [] + assert ins._grounding_context == {} + assert ins._repair_string is None + + +def test_init_no_args(): + ins = Instruction() + assert ins._description is None + assert ins._requirements == [] + + +def test_init_converts_string_description_to_cblock(): + ins = Instruction(description="hello") + assert isinstance(ins._description, CBlock) + + +def test_init_accepts_cblock_description(): + cb = CBlock("already a block") + ins = Instruction(description=cb) + assert ins._description is cb + + +def test_init_string_requirements_converted(): + ins = Instruction(requirements=["must be concise", "must be accurate"]) + 
assert len(ins._requirements) == 2 + for r in ins._requirements: + assert isinstance(r, Requirement) + + +def test_init_requirement_objects_preserved(): + r = Requirement(description="no profanity") + ins = Instruction(requirements=[r]) + assert ins._requirements[0].description == "no profanity" + + +def test_init_grounding_context_strings_blockified(): + ins = Instruction(grounding_context={"doc1": "some content"}) + assert isinstance(ins._grounding_context["doc1"], CBlock) + + +def test_init_prefix_converted(): + ins = Instruction(prefix="Answer:") + assert isinstance(ins._prefix, CBlock) + + +def test_init_output_prefix_raises(): + """output_prefix is currently unsupported; should raise AssertionError.""" + with pytest.raises( + AssertionError, match="output_prefix is not currently supported" + ): + Instruction(user_variables={"x": "y"}, output_prefix="Result:") + + +# --- apply_user_dict_from_jinja --- + + +def test_jinja_simple_substitution(): + result = Instruction.apply_user_dict_from_jinja( + {"name": "world"}, "Hello {{ name }}!" + ) + assert result == "Hello world!" + + +def test_jinja_multiple_variables(): + result = Instruction.apply_user_dict_from_jinja( + {"a": "foo", "b": "bar"}, "{{ a }} and {{ b }}" + ) + assert result == "foo and bar" + + +def test_jinja_missing_variable_renders_empty(): + result = Instruction.apply_user_dict_from_jinja({}, "Hello {{ name }}!") + assert result == "Hello !" 
+ + +def test_jinja_no_variables(): + result = Instruction.apply_user_dict_from_jinja({}, "plain string") + assert result == "plain string" + + +# --- user_variables applied to fields --- + + +def test_user_variables_applied_to_description(): + ins = Instruction( + description="Task: {{ task }}", user_variables={"task": "translate"} + ) + assert str(ins._description) == "Task: translate" + + +def test_user_variables_applied_to_prefix(): + ins = Instruction( + prefix="{{ prefix_word }}:", user_variables={"prefix_word": "Answer"} + ) + assert str(ins._prefix) == "Answer:" + + +def test_user_variables_applied_to_requirements(): + ins = Instruction( + requirements=["must be in {{ lang }}"], user_variables={"lang": "French"} + ) + assert ins._requirements[0].description == "must be in French" + + +def test_user_variables_applied_to_icl_examples(): + ins = Instruction(icl_examples=["Example: {{ ex }}"], user_variables={"ex": "blue"}) + assert str(ins._icl_examples[0]) == "Example: blue" + + +def test_user_variables_applied_to_grounding_context(): + ins = Instruction( + grounding_context={"doc": "See {{ ref }}"}, user_variables={"ref": "section 3"} + ) + assert str(ins._grounding_context["doc"]) == "See section 3" + + +def test_user_variables_description_must_be_string(): + with pytest.raises(AssertionError, match="description must be a string"): + Instruction(description=CBlock("not a string"), user_variables={"x": "y"}) + + +def test_user_variables_requirement_object_description_rendered(): + r = Requirement(description="must be in {{ lang }}") + ins = Instruction(requirements=[r], user_variables={"lang": "Spanish"}) + assert ins._requirements[0].description == "must be in Spanish" + + +# --- parts() --- + + +def test_parts_includes_description(): + ins = Instruction(description="do something") + parts = ins.parts() + assert ins._description in parts + + +def test_parts_includes_requirements(): + r = Requirement(description="be concise") + ins = 
Instruction(description="task", requirements=[r]) + assert r in ins.parts() + + +def test_parts_includes_grounding_context_values(): + ins = Instruction(grounding_context={"doc": "content"}) + parts = ins.parts() + assert ins._grounding_context["doc"] in parts + + +def test_parts_empty_instruction(): + ins = Instruction() + # No description, no requirements, no grounding context + assert ins.parts() == [] + + +def test_parts_includes_icl_examples(): + ins = Instruction(icl_examples=["example 1"]) + parts = ins.parts() + assert len(parts) == 1 + + +# --- format_for_llm --- + + +def test_format_for_llm_returns_template_representation(): + ins = Instruction(description="do something") + result = ins.format_for_llm() + assert isinstance(result, TemplateRepresentation) + + +def test_format_for_llm_args_structure(): + ins = Instruction(description="task", requirements=["req 1"], icl_examples=["ex 1"]) + result = ins.format_for_llm() + assert "description" in result.args + assert "requirements" in result.args + assert "icl_examples" in result.args + assert "grounding_context" in result.args + assert "repair" in result.args + + +def test_format_for_llm_check_only_req_excluded(): + r = Requirement(description="internal check", check_only=True) + ins = Instruction(requirements=[r]) + result = ins.format_for_llm() + assert r.description not in result.args["requirements"] + + +def test_format_for_llm_repair_is_none_by_default(): + ins = Instruction(description="task") + result = ins.format_for_llm() + assert result.args["repair"] is None + + +# --- copy_and_repair --- + + +def test_copy_and_repair_sets_repair_string(): + ins = Instruction(description="task", requirements=["be brief"]) + repaired = ins.copy_and_repair("requirement 'be brief' not met") + assert repaired._repair_string == "requirement 'be brief' not met" + + +def test_copy_and_repair_does_not_mutate_original(): + ins = Instruction(description="task") + _ = ins.copy_and_repair("failed") + assert ins._repair_string 
is None + + +def test_copy_and_repair_deep_copy(): + ins = Instruction(description="task", requirements=["be brief"]) + repaired = ins.copy_and_repair("reason") + # Mutating the copy's requirements should not affect the original + repaired._requirements.append(Requirement(description="new")) + assert len(ins._requirements) == 1 + + +def test_copy_and_repair_format_includes_repair(): + ins = Instruction(description="task") + repaired = ins.copy_and_repair("please fix this") + result = repaired.format_for_llm() + assert result.args["repair"] == "please fix this" + + +# --- _parse --- + + +def test_parse_returns_value(): + ins = Instruction(description="x") + mot = ModelOutputThunk(value="answer") + assert ins._parse(mot) == "answer" + + +def test_parse_none_returns_empty_string(): + ins = Instruction(description="x") + mot = ModelOutputThunk(value=None) + assert ins._parse(mot) == "" + + +# --- requirements property --- + + +def test_requirements_property(): + ins = Instruction(requirements=["be brief", "be accurate"]) + reqs = ins.requirements + assert len(reqs) == 2 + assert all(isinstance(r, Requirement) for r in reqs) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/stdlib/components/test_mobject.py b/test/stdlib/components/test_mobject.py new file mode 100644 index 000000000..05166a4be --- /dev/null +++ b/test/stdlib/components/test_mobject.py @@ -0,0 +1,213 @@ +"""Unit tests for Query, Transform, and MObject — no docling, no backend required.""" + +import pytest + +from mellea.core import ModelOutputThunk, TemplateRepresentation +from mellea.stdlib.components.mobject import MObject, Query, Transform + +# --- helpers --- + + +class _SimpleComponent(MObject): + """Minimal MObject subclass for testing.""" + + def __init__(self, content: str = "hello") -> None: + super().__init__() + self._content = content + + def content_as_string(self) -> str: + return self._content + + def format_for_llm(self) -> str: + return self._content + + 
def parts(self): + return [] + + def _parse(self, computed): + return computed.value or "" + + +# --- Query --- + + +def test_query_parts_returns_wrapped_object(): + obj = _SimpleComponent("doc text") + q = Query(obj, "what is this?") + parts = q.parts() + assert len(parts) == 1 + assert parts[0] is obj + + +def test_query_format_for_llm_returns_template_repr(): + obj = _SimpleComponent("text") + q = Query(obj, "summarise") + result = q.format_for_llm() + assert isinstance(result, TemplateRepresentation) + + +def test_query_format_for_llm_query_field(): + obj = _SimpleComponent("text") + q = Query(obj, "what colour?") + result = q.format_for_llm() + assert result.args["query"] == "what colour?" + + +def test_query_format_for_llm_content_is_wrapped_object(): + obj = _SimpleComponent("text") + q = Query(obj, "q") + result = q.format_for_llm() + assert result.args["content"] is obj + + +def test_query_parse_returns_value(): + obj = _SimpleComponent() + q = Query(obj, "q") + mot = ModelOutputThunk(value="answer") + assert q._parse(mot) == "answer" + + +def test_query_parse_none_returns_empty(): + obj = _SimpleComponent() + q = Query(obj, "q") + mot = ModelOutputThunk(value=None) + assert q._parse(mot) == "" + + +# --- Transform --- + + +def test_transform_parts_returns_wrapped_object(): + obj = _SimpleComponent("doc text") + t = Transform(obj, "translate to French") + parts = t.parts() + assert len(parts) == 1 + assert parts[0] is obj + + +def test_transform_format_for_llm_returns_template_repr(): + obj = _SimpleComponent("text") + t = Transform(obj, "rewrite formally") + result = t.format_for_llm() + assert isinstance(result, TemplateRepresentation) + + +def test_transform_format_for_llm_transformation_field(): + obj = _SimpleComponent("text") + t = Transform(obj, "make it shorter") + result = t.format_for_llm() + assert result.args["transformation"] == "make it shorter" + + +def test_transform_format_for_llm_content_is_wrapped_object(): + obj = 
_SimpleComponent("text") + t = Transform(obj, "x") + result = t.format_for_llm() + assert result.args["content"] is obj + + +def test_transform_parse_returns_value(): + obj = _SimpleComponent() + t = Transform(obj, "x") + mot = ModelOutputThunk(value="result") + assert t._parse(mot) == "result" + + +# --- MObject --- + + +def test_mobject_parts_empty(): + obj = _SimpleComponent() + assert obj.parts() == [] + + +def test_mobject_get_query_object(): + obj = _SimpleComponent("text") + q = obj.get_query_object("what is this?") + assert isinstance(q, Query) + assert q._query == "what is this?" + assert q._obj is obj + + +def test_mobject_get_transform_object(): + obj = _SimpleComponent("text") + t = obj.get_transform_object("shorten it") + assert isinstance(t, Transform) + assert t._transformation == "shorten it" + assert t._obj is obj + + +def test_mobject_content_as_string(): + obj = _SimpleComponent("my content") + assert obj.content_as_string() == "my content" + + +def test_mobject_format_for_llm_returns_template_repr(): + obj = _SimpleComponent("text") + result = obj.format_for_llm() + # Uses the overridden format_for_llm returning str + assert result == "text" + + +def test_mobject_custom_query_type(): + class _CustomQuery(Query): + pass + + obj = MObject(query_type=_CustomQuery) + q = obj.get_query_object("q") + assert isinstance(q, _CustomQuery) + + +def test_mobject_custom_transform_type(): + class _CustomTransform(Transform): + pass + + obj = MObject(transform_type=_CustomTransform) + t = obj.get_transform_object("t") + assert isinstance(t, _CustomTransform) + + +def test_mobj_base_format_for_llm(): + """Test MObject.format_for_llm (not the overridden version) via base class directly.""" + + class _MObjectWithTools(MObject): + def my_tool(self) -> str: + """A custom tool.""" + return "result" + + def content_as_string(self) -> str: + return "content" + + def parts(self): + return [] + + def format_for_llm(self): + return MObject.format_for_llm(self) + + def 
_parse(self, computed): + return "" + + obj = _MObjectWithTools() + result = obj.format_for_llm() + assert isinstance(result, TemplateRepresentation) + assert result.args["content"] == "content" + + +def test_mobj_parse_returns_value(): + class _M(MObject): + def content_as_string(self): + return "" + + def parts(self): + return [] + + def _parse(self, computed): + return MObject._parse(self, computed) + + obj = _M() + mot = ModelOutputThunk(value="result") + assert obj._parse(mot) == "result" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/stdlib/components/test_simple.py b/test/stdlib/components/test_simple.py new file mode 100644 index 000000000..c62fa9a4a --- /dev/null +++ b/test/stdlib/components/test_simple.py @@ -0,0 +1,135 @@ +"""Unit tests for SimpleComponent — kwargs rendering, type validation, JSON output.""" + +import json + +import pytest + +from mellea.core import CBlock, ModelOutputThunk +from mellea.stdlib.components.simple import SimpleComponent + +# --- constructor & type checking --- + + +def test_init_converts_strings_to_cblocks(): + sc = SimpleComponent(task="write a poem") + assert isinstance(sc._kwargs["task"], CBlock) + assert sc._kwargs["task"].value == "write a poem" + + +def test_init_accepts_cblock_directly(): + cb = CBlock("already a block") + sc = SimpleComponent(thing=cb) + assert sc._kwargs["thing"] is cb + + +def test_init_rejects_non_string_non_component(): + with pytest.raises(AssertionError): + SimpleComponent(bad=42) + + +def test_init_rejects_non_string_key(): + # We can't pass non-string keys via kwargs syntax; test _kwargs_type_check directly + sc = SimpleComponent(ok="fine") + with pytest.raises(AssertionError): + sc._kwargs_type_check({123: CBlock("v")}) + + +def test_init_multiple_kwargs(): + sc = SimpleComponent(task="summarise", context="some text") + assert len(sc._kwargs) == 2 + assert set(sc._kwargs.keys()) == {"task", "context"} + + +# --- parts() --- + + +def 
test_parts_returns_all_values(): + sc = SimpleComponent(a="one", b="two") + parts = sc.parts() + assert len(parts) == 2 + assert all(isinstance(p, CBlock) for p in parts) + + +def test_parts_empty(): + sc = SimpleComponent() + assert sc.parts() == [] + + +# --- make_simple_string --- + + +def test_make_simple_string_single(): + kwargs = {"task": CBlock("do something")} + result = SimpleComponent.make_simple_string(kwargs) + assert result == "<|task|>do something" + + +def test_make_simple_string_multiple(): + # Use ordered dict (Python 3.7+ guarantees insertion order) + kwargs = {"a": CBlock("first"), "b": CBlock("second")} + result = SimpleComponent.make_simple_string(kwargs) + assert "<|a|>first" in result + assert "<|b|>second" in result + assert "\n" in result + + +def test_make_simple_string_empty(): + assert SimpleComponent.make_simple_string({}) == "" + + +# --- make_json_string --- + + +def test_make_json_string_cblock(): + kwargs = {"key": CBlock("value")} + result = json.loads(SimpleComponent.make_json_string(kwargs)) + assert result == {"key": "value"} + + +def test_make_json_string_model_output_thunk(): + mot = ModelOutputThunk(value="output text") + kwargs = {"out": mot} + result = json.loads(SimpleComponent.make_json_string(kwargs)) + assert result == {"out": "output text"} + + +def test_make_json_string_nested_component(): + inner = SimpleComponent(x="nested") + kwargs = {"inner": inner} + result = json.loads(SimpleComponent.make_json_string(kwargs)) + assert "inner" in result + + +def test_make_json_string_empty(): + result = json.loads(SimpleComponent.make_json_string({})) + assert result == {} + + +# --- format_for_llm --- + + +def test_format_for_llm_returns_json_string(): + sc = SimpleComponent(topic="ocean", style="poetic") + formatted = sc.format_for_llm() + parsed = json.loads(formatted) + assert parsed["topic"] == "ocean" + assert parsed["style"] == "poetic" + + +# --- _parse --- + + +def test_parse_returns_value(): + sc = 
SimpleComponent(x="whatever") + mot = ModelOutputThunk(value="result") + assert sc._parse(mot) == "result" + + +def test_parse_none_returns_empty_string(): + sc = SimpleComponent(x="whatever") + mot = ModelOutputThunk(value=None) + assert sc._parse(mot) == "" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/stdlib/sampling/test_majority_voting_unit.py b/test/stdlib/sampling/test_majority_voting_unit.py new file mode 100644 index 000000000..cf3429fe7 --- /dev/null +++ b/test/stdlib/sampling/test_majority_voting_unit.py @@ -0,0 +1,74 @@ +"""Unit tests for majority voting compare_strings methods — no backend required.""" + +import pytest + +from mellea.stdlib.sampling.majority_voting import ( + MajorityVotingStrategyForMath, + MBRDRougeLStrategy, +) + +# --- MajorityVotingStrategyForMath.compare_strings --- + + +@pytest.fixture +def math_strategy(): + return MajorityVotingStrategyForMath() + + +def test_math_compare_identical_boxed(math_strategy): + assert math_strategy.compare_strings(r"\boxed{2}", r"\boxed{2}") == 1.0 + + +def test_math_compare_identical_latex(math_strategy): + assert math_strategy.compare_strings(r"\boxed{4}", r"\boxed{4}") == 1.0 + + +def test_math_compare_unboxed_integers_return_zero(math_strategy): + # Plain integers without boxed notation are not extracted — returns 0.0 + assert math_strategy.compare_strings("2", "3") == 0.0 + + +def test_math_compare_different_boxed(math_strategy): + assert math_strategy.compare_strings(r"\boxed{2}", r"\boxed{3}") == 0.0 + + +def test_math_compare_returns_float(math_strategy): + result = math_strategy.compare_strings(r"\boxed{5}", r"\boxed{5}") + assert isinstance(result, float) + + +# --- MBRDRougeLStrategy.compare_strings --- + + +@pytest.fixture +def rouge_strategy(): + return MBRDRougeLStrategy() + + +def test_rougel_compare_identical(rouge_strategy): + score = rouge_strategy.compare_strings("hello world", "hello world") + assert score == pytest.approx(1.0) + + +def 
test_rougel_compare_completely_different(rouge_strategy): + score = rouge_strategy.compare_strings("hello world", "foo bar baz") + assert score < 0.5 + + +def test_rougel_compare_partial_overlap(rouge_strategy): + score = rouge_strategy.compare_strings("the quick brown fox", "the quick fox") + assert 0.0 < score < 1.0 + + +def test_rougel_compare_returns_float(rouge_strategy): + score = rouge_strategy.compare_strings("abc", "abc") + assert isinstance(score, float) + + +def test_rougel_score_in_range(rouge_strategy): + score = rouge_strategy.compare_strings("some text here", "some different text") + assert 0.0 <= score <= 1.0 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/stdlib/sampling/test_sampling_base_unit.py b/test/stdlib/sampling/test_sampling_base_unit.py new file mode 100644 index 000000000..7c4161eec --- /dev/null +++ b/test/stdlib/sampling/test_sampling_base_unit.py @@ -0,0 +1,108 @@ +"""Unit tests for sampling/base.py static repair() logic — no backend required.""" + +import pytest + +from mellea.core import ( + ComputedModelOutputThunk, + ModelOutputThunk, + Requirement, + ValidationResult, +) +from mellea.stdlib.components import Instruction, Message +from mellea.stdlib.context import ChatContext +from mellea.stdlib.sampling.base import RepairTemplateStrategy + +# --- BaseSamplingStrategy.repair --- + + +def _val(passed: bool, reason: str | None = None) -> ValidationResult: + return ValidationResult(result=passed, reason=reason) + + +def test_repair_instruction_builds_repair_string(): + ins = Instruction(description="Write a poem", requirements=["be concise"]) + req = Requirement(description="be concise") + old_ctx = ChatContext() + new_ctx = ChatContext() + + action, ctx = RepairTemplateStrategy.repair( + old_ctx=old_ctx, + new_ctx=new_ctx, + past_actions=[ins], + past_results=[ + ComputedModelOutputThunk(thunk=ModelOutputThunk(value="long text")) + ], + past_val=[[(req, _val(False, reason="Output was too long"))]], 
+ ) + assert isinstance(action, Instruction) + assert action._repair_string is not None + assert "Output was too long" in action._repair_string + assert ctx is old_ctx + + +def test_repair_uses_req_description_when_no_reason(): + ins = Instruction(description="task") + req = Requirement(description="must be brief") + old_ctx = ChatContext() + + action, _ = RepairTemplateStrategy.repair( + old_ctx=old_ctx, + new_ctx=ChatContext(), + past_actions=[ins], + past_results=[ComputedModelOutputThunk(thunk=ModelOutputThunk(value="x"))], + past_val=[[(req, _val(False))]], + ) + assert "must be brief" in action._repair_string + + +def test_repair_non_instruction_returns_same_action(): + msg = Message("user", "hello") + old_ctx = ChatContext() + + action, ctx = RepairTemplateStrategy.repair( + old_ctx=old_ctx, + new_ctx=ChatContext(), + past_actions=[msg], + past_results=[ComputedModelOutputThunk(thunk=ModelOutputThunk(value="x"))], + past_val=[[]], + ) + assert action is msg + assert ctx is old_ctx + + +def test_repair_multiple_failures_all_listed(): + ins = Instruction(description="task") + r1 = Requirement(description="be short") + r2 = Requirement(description="be polite") + old_ctx = ChatContext() + + action, _ = RepairTemplateStrategy.repair( + old_ctx=old_ctx, + new_ctx=ChatContext(), + past_actions=[ins], + past_results=[ComputedModelOutputThunk(thunk=ModelOutputThunk(value="x"))], + past_val=[[(r1, _val(False, "too long")), (r2, _val(False, "rude tone"))]], + ) + assert "too long" in action._repair_string + assert "rude tone" in action._repair_string + + +def test_repair_passed_requirements_excluded(): + ins = Instruction(description="task") + r_pass = Requirement(description="format ok") + r_fail = Requirement(description="content wrong") + old_ctx = ChatContext() + + action, _ = RepairTemplateStrategy.repair( + old_ctx=old_ctx, + new_ctx=ChatContext(), + past_actions=[ins], + past_results=[ComputedModelOutputThunk(thunk=ModelOutputThunk(value="x"))], + 
past_val=[[(r_pass, _val(True)), (r_fail, _val(False, "incorrect"))]], + ) + assert "format ok" not in action._repair_string + assert "incorrect" in action._repair_string + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/stdlib/sampling/test_sofai_unit.py b/test/stdlib/sampling/test_sofai_unit.py new file mode 100644 index 000000000..2b3f6f984 --- /dev/null +++ b/test/stdlib/sampling/test_sofai_unit.py @@ -0,0 +1,147 @@ +"""Unit tests for SOFAI sampling strategy pure static helpers — no backend required. + +Covers _extract_action_prompt, _parse_judgment, _extract_feedback, _select_best_attempt. +""" + +import pytest + +from mellea.core import Requirement, TemplateRepresentation, ValidationResult +from mellea.stdlib.components import Instruction, Message +from mellea.stdlib.sampling.sofai import SOFAISamplingStrategy + +# --- _parse_judgment --- + + +def test_parse_judgment_yes(): + assert SOFAISamplingStrategy._parse_judgment("Yes") is True + + +def test_parse_judgment_yes_with_explanation(): + assert SOFAISamplingStrategy._parse_judgment("Yes, the output is correct.") is True + + +def test_parse_judgment_no(): + assert SOFAISamplingStrategy._parse_judgment("No") is False + + +def test_parse_judgment_no_with_explanation(): + assert ( + SOFAISamplingStrategy._parse_judgment( + "No, it needs improvement.\nDetails here." + ) + is False + ) + + +def test_parse_judgment_yes_in_first_line(): + assert SOFAISamplingStrategy._parse_judgment("The answer is yes") is True + + +def test_parse_judgment_no_match_defaults_false(): + assert SOFAISamplingStrategy._parse_judgment("Maybe, hard to tell") is False + + +def test_parse_judgment_whitespace_stripped(): + assert SOFAISamplingStrategy._parse_judgment(" Yes ") is True + + +def test_parse_judgment_case_insensitive(): + assert SOFAISamplingStrategy._parse_judgment("YES") is True + + +# --- _extract_feedback --- + + +def test_extract_feedback_with_tags(): + text = "Some preamble. Fix the grammar. 
More text." + assert SOFAISamplingStrategy._extract_feedback(text) == "Fix the grammar." + + +def test_extract_feedback_no_tags(): + text = "Just plain feedback text." + assert SOFAISamplingStrategy._extract_feedback(text) == "Just plain feedback text." + + +def test_extract_feedback_multiline(): + text = "\nLine 1\nLine 2\n" + result = SOFAISamplingStrategy._extract_feedback(text) + assert "Line 1" in result + assert "Line 2" in result + + +def test_extract_feedback_case_insensitive_tags(): + text = "Fix it." + assert SOFAISamplingStrategy._extract_feedback(text) == "Fix it." + + +def test_extract_feedback_strips_whitespace(): + text = " some feedback " + assert SOFAISamplingStrategy._extract_feedback(text) == "some feedback" + + +# --- _extract_action_prompt --- + + +def test_extract_action_prompt_message(): + msg = Message("user", "What is 2+2?") + assert SOFAISamplingStrategy._extract_action_prompt(msg) == "What is 2+2?" + + +def test_extract_action_prompt_instruction(): + ins = Instruction(description="Summarise the text") + result = SOFAISamplingStrategy._extract_action_prompt(ins) + assert result == "Summarise the text" + + +def test_extract_action_prompt_format_for_llm_str(): + """Component whose format_for_llm returns a plain string.""" + from mellea.core import CBlock, Component, ModelOutputThunk + + class _StrComponent(Component[str]): + def parts(self): + return [] + + def format_for_llm(self) -> str: + return "plain text repr" + + def _parse(self, computed: ModelOutputThunk) -> str: + return "" + + result = SOFAISamplingStrategy._extract_action_prompt(_StrComponent()) + assert result == "plain text repr" + + +# --- _select_best_attempt --- + + +def _vr(passed: bool) -> ValidationResult: + return ValidationResult(result=passed) + + +def test_select_best_attempt_picks_most_passing(): + r = Requirement(description="r") + val = [ + [(r, _vr(True)), (r, _vr(False))], # 1 pass + [(r, _vr(True)), (r, _vr(True))], # 2 pass — best + [(r, _vr(False)), (r, 
_vr(False))], # 0 pass + ] + assert SOFAISamplingStrategy._select_best_attempt(val) == 1 + + +def test_select_best_attempt_tie_prefers_later(): + r = Requirement(description="r") + val = [ + [(r, _vr(True))], # 1 pass + [(r, _vr(True))], # 1 pass — tie, but later → preferred + ] + assert SOFAISamplingStrategy._select_best_attempt(val) == 1 + + +def test_select_best_attempt_single(): + r = Requirement(description="r") + val = [[(r, _vr(False))]] + assert SOFAISamplingStrategy._select_best_attempt(val) == 0 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/stdlib/test_functional_unit.py b/test/stdlib/test_functional_unit.py new file mode 100644 index 000000000..7d4b215c9 --- /dev/null +++ b/test/stdlib/test_functional_unit.py @@ -0,0 +1,66 @@ +"""Unit tests for functional.py pure helpers — no backend, no LLM required. + +Covers _parse_and_clean_image_args image preprocessing. +""" + +import base64 +import io + +import pytest +from PIL import Image as PILImage + +from mellea.core import ImageBlock +from mellea.stdlib.functional import _parse_and_clean_image_args + + +def _make_image_block() -> ImageBlock: + """Return a valid ImageBlock backed by a 1x1 red PNG.""" + img = PILImage.new("RGB", (1, 1), color="red") + buf = io.BytesIO() + img.save(buf, format="PNG") + b64 = base64.b64encode(buf.getvalue()).decode() + return ImageBlock(value=b64) + + +# --- _parse_and_clean_image_args --- + + +def test_none_returns_none(): + assert _parse_and_clean_image_args(None) is None + + +def test_empty_list_returns_none(): + assert _parse_and_clean_image_args([]) is None + + +def test_image_blocks_passed_through(): + ib = _make_image_block() + result = _parse_and_clean_image_args([ib]) + assert result == [ib] + + +def test_multiple_image_blocks_preserved(): + ib1 = _make_image_block() + ib2 = _make_image_block() + result = _parse_and_clean_image_args([ib1, ib2]) + assert result is not None + assert len(result) == 2 + assert result[0] is ib1 + assert 
result[1] is ib2 + + +def test_pil_images_converted_to_image_blocks(): + pil_img = PILImage.new("RGB", (1, 1), color="blue") + result = _parse_and_clean_image_args([pil_img]) + assert result is not None + assert len(result) == 1 + assert isinstance(result[0], ImageBlock) + + +def test_non_list_raises(): + with pytest.raises(AssertionError, match="Images should be a list"): + _parse_and_clean_image_args("not_a_list") # type: ignore + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/stdlib/test_session_unit.py b/test/stdlib/test_session_unit.py new file mode 100644 index 000000000..466a85e9a --- /dev/null +++ b/test/stdlib/test_session_unit.py @@ -0,0 +1,66 @@ +"""Unit tests for session.py pure-logic — no Ollama server required. + +Covers backend_name_to_class factory resolution and get_session error path. +""" + +import pytest + +from mellea.backends.ollama import OllamaModelBackend +from mellea.backends.openai import OpenAIBackend +from mellea.stdlib.session import backend_name_to_class, get_session + +# --- backend_name_to_class --- + + +def test_ollama_resolves_to_ollama_backend(): + cls = backend_name_to_class("ollama") + assert cls is OllamaModelBackend + + +def test_openai_resolves_to_openai_backend(): + cls = backend_name_to_class("openai") + assert cls is OpenAIBackend + + +def test_unknown_name_returns_none(): + cls = backend_name_to_class("does_not_exist") + assert cls is None + + +def test_hf_resolves_or_raises_import_error(): + # Either resolves (if mellea[hf] is installed) or raises ImportError with helpful message + try: + cls = backend_name_to_class("hf") + assert cls is not None + except ImportError as e: + assert "mellea[hf]" in str(e) + + +def test_huggingface_alias_same_as_hf(): + # "hf" and "huggingface" should resolve to the same class + try: + cls_hf = backend_name_to_class("hf") + cls_hf_full = backend_name_to_class("huggingface") + assert cls_hf is cls_hf_full + except ImportError: + pass # OK if mellea[hf] is 
not installed + + +def test_litellm_resolves_or_raises_import_error(): + try: + cls = backend_name_to_class("litellm") + assert cls is not None + except ImportError as e: + assert "mellea[litellm]" in str(e) + + +# --- get_session --- + + +def test_get_session_raises_when_no_active_session(): + with pytest.raises(RuntimeError, match="No active session found"): + get_session() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/telemetry/test_backend_instrumentation.py b/test/telemetry/test_backend_instrumentation.py new file mode 100644 index 000000000..4163ccb8d --- /dev/null +++ b/test/telemetry/test_backend_instrumentation.py @@ -0,0 +1,209 @@ +"""Unit tests for backend_instrumentation helpers — model ID extraction, system name mapping, +context size introspection, and span attribute recording.""" + +from dataclasses import dataclass +from unittest.mock import MagicMock + +import pytest + +from mellea.telemetry.backend_instrumentation import ( + get_context_size, + get_model_id_str, + get_system_name, + record_response_metadata, + record_token_usage, +) + +# --- get_model_id_str --- + + +@dataclass +class _BackendWithStrModelId: + model_id: str + + +@dataclass +class _HFModelId: + hf_model_name: str + + +@dataclass +class _BackendWithHFModelId: + model_id: _HFModelId + + +def test_get_model_id_str_plain_string(): + backend = _BackendWithStrModelId(model_id="granite-3-8b") + assert get_model_id_str(backend) == "granite-3-8b" + + +def test_get_model_id_str_hf_model_name(): + backend = _BackendWithHFModelId( + model_id=_HFModelId(hf_model_name="ibm-granite/granite-4.0-micro") + ) + assert get_model_id_str(backend) == "ibm-granite/granite-4.0-micro" + + +def test_get_model_id_str_no_model_id_returns_class_name(): + class UnknownBackend: + pass + + backend = UnknownBackend() + assert get_model_id_str(backend) == "UnknownBackend" + + +# --- get_system_name --- + + +def _fake_backend(class_name: str) -> object: + return type(class_name, 
(), {})() + + +def test_get_system_name_openai(): + assert get_system_name(_fake_backend("OpenAIBackend")) == "openai" + + +def test_get_system_name_ollama(): + assert get_system_name(_fake_backend("OllamaModelBackend")) == "ollama" + + +def test_get_system_name_huggingface(): + assert get_system_name(_fake_backend("LocalHFBackend")) == "huggingface" + + +def test_get_system_name_hf_shortname(): + assert get_system_name(_fake_backend("HFBackend")) == "huggingface" + + +def test_get_system_name_watsonx(): + assert get_system_name(_fake_backend("WatsonxBackend")) == "watsonx" + + +def test_get_system_name_litellm(): + assert get_system_name(_fake_backend("LiteLLMBackend")) == "litellm" + + +def test_get_system_name_unknown_returns_class_name(): + backend = _fake_backend("SomeCustomBackend") + assert get_system_name(backend) == "SomeCustomBackend" + + +# --- get_context_size --- + + +def test_get_context_size_with_len(): + ctx = [1, 2, 3] + assert get_context_size(ctx) == 3 + + +def test_get_context_size_empty_list(): + assert get_context_size([]) == 0 + + +def test_get_context_size_with_turns(): + ctx = type("Ctx", (), {"turns": [1, 2, 3, 4]})() + assert get_context_size(ctx) == 4 + + +def test_get_context_size_no_len_no_turns(): + class Opaque: + pass + + assert get_context_size(Opaque()) == 0 + + +def test_get_context_size_len_raises_returns_zero(): + class Broken: + def __len__(self): + raise RuntimeError("broken") + + assert get_context_size(Broken()) == 0 + + +# --- record_token_usage --- + + +def _mock_span(): + return MagicMock() + + +def test_record_token_usage_from_dict(): + span = _mock_span() + usage = {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30} + record_token_usage(span, usage) + calls = {call.args[0]: call.args[1] for call in span.set_attribute.call_args_list} + assert calls.get("gen_ai.usage.input_tokens") == 10 + assert calls.get("gen_ai.usage.output_tokens") == 20 + assert calls.get("gen_ai.usage.total_tokens") == 30 + + +def 
test_record_token_usage_from_object(): + span = _mock_span() + usage = type( + "Usage", (), {"prompt_tokens": 5, "completion_tokens": 15, "total_tokens": 20} + )() + record_token_usage(span, usage) + calls = {call.args[0]: call.args[1] for call in span.set_attribute.call_args_list} + assert calls.get("gen_ai.usage.input_tokens") == 5 + + +def test_record_token_usage_none_span_no_op(): + # Should not raise + record_token_usage(None, {"prompt_tokens": 1}) + + +def test_record_token_usage_none_usage_no_op(): + span = _mock_span() + record_token_usage(span, None) + span.set_attribute.assert_not_called() + + +def test_record_token_usage_partial_fields(): + span = _mock_span() + usage = {"prompt_tokens": 7} + record_token_usage(span, usage) + calls = {call.args[0]: call.args[1] for call in span.set_attribute.call_args_list} + assert calls.get("gen_ai.usage.input_tokens") == 7 + assert "gen_ai.usage.output_tokens" not in calls + + +# --- record_response_metadata --- + + +def test_record_response_metadata_model_from_dict(): + span = _mock_span() + response = {"model": "granite-3-8b", "choices": [], "id": "resp-123"} + record_response_metadata(span, response) + calls = {call.args[0]: call.args[1] for call in span.set_attribute.call_args_list} + assert calls.get("gen_ai.response.model") == "granite-3-8b" + assert calls.get("gen_ai.response.id") == "resp-123" + + +def test_record_response_metadata_explicit_model_id_overrides(): + span = _mock_span() + response = {"model": "old-model"} + record_response_metadata(span, response, model_id="new-model") + calls = {call.args[0]: call.args[1] for call in span.set_attribute.call_args_list} + assert calls.get("gen_ai.response.model") == "new-model" + + +def test_record_response_metadata_finish_reason(): + span = _mock_span() + response = {"choices": [{"finish_reason": "stop"}]} + record_response_metadata(span, response) + calls = {call.args[0]: call.args[1] for call in span.set_attribute.call_args_list} + assert 
calls.get("gen_ai.response.finish_reasons") == ["stop"] + + +def test_record_response_metadata_none_span_no_op(): + record_response_metadata(None, {"model": "x"}) + + +def test_record_response_metadata_none_response_no_op(): + span = _mock_span() + record_response_metadata(span, None) + span.set_attribute.assert_not_called() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/telemetry/test_tracing_helpers.py b/test/telemetry/test_tracing_helpers.py new file mode 100644 index 000000000..a8cb4e06e --- /dev/null +++ b/test/telemetry/test_tracing_helpers.py @@ -0,0 +1,89 @@ +"""Unit tests for tracing helper functions — no OpenTelemetry installation required. + +_set_attribute_safe and end_backend_span operate on any object with a +set_attribute / end method, so these tests use MagicMock spans and run +unconditionally. test_set_span_error_records_exception calls into the real +OTel trace API and is skipped when opentelemetry is not installed. +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from mellea.telemetry.tracing import ( + _set_attribute_safe, + end_backend_span, + set_span_error, +) + +# --- _set_attribute_safe type-conversion --- + + +def test_set_attribute_safe_none_value_no_op(): + span = MagicMock() + _set_attribute_safe(span, "key", None) + span.set_attribute.assert_not_called() + + +def test_set_attribute_safe_bool(): + span = MagicMock() + _set_attribute_safe(span, "flag", True) + span.set_attribute.assert_called_once_with("flag", True) + + +def test_set_attribute_safe_int(): + span = MagicMock() + _set_attribute_safe(span, "count", 42) + span.set_attribute.assert_called_once_with("count", 42) + + +def test_set_attribute_safe_str(): + span = MagicMock() + _set_attribute_safe(span, "name", "hello") + span.set_attribute.assert_called_once_with("name", "hello") + + +def test_set_attribute_safe_list_converted_to_string_list(): + span = MagicMock() + _set_attribute_safe(span, "items", [1, 2, 3]) + 
span.set_attribute.assert_called_once_with("items", ["1", "2", "3"]) + + +def test_set_attribute_safe_unsupported_type_stringified(): + span = MagicMock() + _set_attribute_safe(span, "obj", {"nested": "dict"}) + span.set_attribute.assert_called_once() + call_args = span.set_attribute.call_args + assert call_args.args[0] == "obj" + assert isinstance(call_args.args[1], str) + + +# --- set_span_error — requires opentelemetry for trace.Status --- + + +def test_set_span_error_records_exception(): + pytest.importorskip( + "opentelemetry", + reason="opentelemetry not installed — install mellea[telemetry]", + ) + span = MagicMock() + exc = ValueError("something went wrong") + + with patch("mellea.telemetry.tracing._OTEL_AVAILABLE", True): + set_span_error(span, exc) + + span.record_exception.assert_called_once_with(exc) + span.set_status.assert_called_once() + + +# --- end_backend_span --- + + +def test_end_backend_span_calls_end_on_span(): + span = MagicMock() + end_backend_span(span) + span.end.assert_called_once() + + +def test_end_backend_span_none_no_op(): + end_backend_span(None)