diff --git a/test/backends/test_ollama_unit.py b/test/backends/test_ollama_unit.py
new file mode 100644
index 000000000..9ade604b7
--- /dev/null
+++ b/test/backends/test_ollama_unit.py
@@ -0,0 +1,151 @@
+"""Unit tests for Ollama backend pure-logic helpers — no Ollama server required.
+
+Covers _simplify_and_merge, _make_backend_specific_and_remove, and
+chat_response_delta_merge.
+"""
+
+from unittest.mock import MagicMock, patch
+
+import ollama
+import pytest
+
+from mellea.backends import ModelOption
+from mellea.backends.ollama import OllamaModelBackend, chat_response_delta_merge
+from mellea.core import ModelOutputThunk
+
+
+def _make_backend(model_options: dict | None = None) -> OllamaModelBackend:
+ """Return an OllamaModelBackend with all network calls patched."""
+ with (
+ patch.object(OllamaModelBackend, "_check_ollama_server", return_value=True),
+ patch.object(OllamaModelBackend, "_pull_ollama_model", return_value=True),
+ patch("mellea.backends.ollama.ollama.Client", return_value=MagicMock()),
+ patch("mellea.backends.ollama.ollama.AsyncClient", return_value=MagicMock()),
+ ):
+ return OllamaModelBackend(model_id="granite3.3:8b", model_options=model_options)
+
+
+@pytest.fixture
+def backend():
+ """Return an OllamaModelBackend with no pre-set model options."""
+ return _make_backend()
+
+
+# --- Map consistency ---
+
+
+def test_from_mellea_keys_are_subset_of_to_mellea_values(backend):
+ """Every key in from_mellea must appear as a value in to_mellea (maps agree)."""
+ to_values = set(backend.to_mellea_model_opts_map.values())
+ from_keys = set(backend.from_mellea_model_opts_map.keys())
+ assert from_keys <= to_values, (
+ f"from_mellea has keys absent from to_mellea values: {from_keys - to_values}"
+ )
+
+
+# --- _simplify_and_merge ---
+
+
+def test_simplify_and_merge_none_returns_empty_dict(backend):
+ result = backend._simplify_and_merge(None)
+ assert result == {}
+
+
+def test_simplify_and_merge_all_to_mellea_entries(backend):
+ """Every to_mellea entry remaps to its ModelOption via _simplify_and_merge."""
+ for backend_key, mellea_key in backend.to_mellea_model_opts_map.items():
+ result = backend._simplify_and_merge({backend_key: 42})
+ assert mellea_key in result, f"{backend_key!r} did not produce {mellea_key!r}"
+ assert result[mellea_key] == 42
+
+
+def test_simplify_and_merge_remaps_num_predict(backend):
+ """Hardcoded anchor: the most critical mapping for generation length."""
+ result = backend._simplify_and_merge({"num_predict": 128})
+ assert ModelOption.MAX_NEW_TOKENS in result
+ assert result[ModelOption.MAX_NEW_TOKENS] == 128
+
+
+def test_simplify_and_merge_per_call_overrides_backend():
+ # Backend sets num_predict=128; per-call value of 256 must win.
+ b = _make_backend(model_options={"num_predict": 128})
+ result = b._simplify_and_merge({"num_predict": 256})
+ assert result[ModelOption.MAX_NEW_TOKENS] == 256
+
+
+# --- _make_backend_specific_and_remove ---
+
+
+def test_make_backend_specific_all_from_mellea_entries(backend):
+ """Every from_mellea entry remaps to its backend key via _make_backend_specific_and_remove."""
+ for mellea_key, backend_key in backend.from_mellea_model_opts_map.items():
+ result = backend._make_backend_specific_and_remove({mellea_key: 42})
+ assert backend_key in result, f"{mellea_key!r} did not produce {backend_key!r}"
+ assert result[backend_key] == 42
+
+
+def test_make_backend_specific_remaps_max_new_tokens(backend):
+ """Hardcoded anchor: the most critical mapping for generation length."""
+ opts = {ModelOption.MAX_NEW_TOKENS: 64}
+ result = backend._make_backend_specific_and_remove(opts)
+ assert "num_predict" in result
+ assert result["num_predict"] == 64
+
+
+def test_make_backend_specific_removes_sentinel_keys(backend):
+ opts = {ModelOption.MAX_NEW_TOKENS: 32, ModelOption.SYSTEM_PROMPT: "sys"}
+ result = backend._make_backend_specific_and_remove(opts)
+ # Sentinel keys not in from_mellea_model_opts_map should be removed
+ assert ModelOption.SYSTEM_PROMPT not in result
+
+
+# --- chat_response_delta_merge ---
+
+
+def _make_delta(
+ content: str,
+ role: str = "assistant",
+ done: bool = False,
+ thinking: str | None = None,
+) -> ollama.ChatResponse:
+ msg = ollama.Message(role=role, content=content, thinking=thinking)
+ return ollama.ChatResponse(model="test", created_at=None, message=msg, done=done)
+
+
+def test_delta_merge_first_sets_chat_response():
+ mot = ModelOutputThunk(value=None)
+ delta = _make_delta("Hello")
+ chat_response_delta_merge(mot, delta)
+ assert mot._meta["chat_response"] is delta
+
+
+def test_delta_merge_second_appends_content():
+ mot = ModelOutputThunk(value=None)
+ chat_response_delta_merge(mot, _make_delta("Hello"))
+ chat_response_delta_merge(mot, _make_delta(" world"))
+ assert mot._meta["chat_response"].message.content == "Hello world"
+
+
+def test_delta_merge_done_propagated():
+ mot = ModelOutputThunk(value=None)
+ chat_response_delta_merge(mot, _make_delta("partial", done=False))
+ chat_response_delta_merge(mot, _make_delta("", done=True))
+ assert mot._meta["chat_response"].done is True
+
+
+def test_delta_merge_role_set_from_first_delta():
+ mot = ModelOutputThunk(value=None)
+ chat_response_delta_merge(mot, _make_delta("hi", role="assistant"))
+ chat_response_delta_merge(mot, _make_delta(" there", role=""))
+ assert mot._meta["chat_response"].message.role == "assistant"
+
+
+def test_delta_merge_thinking_concatenated():
+ mot = ModelOutputThunk(value=None)
+ chat_response_delta_merge(mot, _make_delta("reply", thinking="step 1"))
+ chat_response_delta_merge(mot, _make_delta("", thinking=" step 2"))
+ assert mot._meta["chat_response"].message.thinking == "step 1 step 2"
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/test/backends/test_openai_unit.py b/test/backends/test_openai_unit.py
new file mode 100644
index 000000000..09524df8c
--- /dev/null
+++ b/test/backends/test_openai_unit.py
@@ -0,0 +1,172 @@
+"""Unit tests for OpenAI backend pure-logic helpers — no API calls required.
+
+Covers filter_openai_client_kwargs, filter_chat_completions_kwargs,
+_simplify_and_merge, and _make_backend_specific_and_remove.
+"""
+
+import pytest
+
+from mellea.backends import ModelOption
+from mellea.backends.openai import OpenAIBackend
+
+
+def _make_backend(model_options: dict | None = None) -> OpenAIBackend:
+ """Return an OpenAIBackend with a fake API key."""
+ return OpenAIBackend(
+ model_id="gpt-4o",
+ api_key="fake-key",
+ base_url="http://localhost:9999/v1",
+ model_options=model_options,
+ )
+
+
+@pytest.fixture
+def backend():
+ """Return an OpenAIBackend with no pre-set model options."""
+ return _make_backend()
+
+
+# --- filter_openai_client_kwargs ---
+
+
+def test_filter_openai_client_kwargs_removes_unknown():
+ result = OpenAIBackend.filter_openai_client_kwargs(
+ api_key="sk-test", unknown_param="x"
+ )
+ assert "api_key" in result
+ assert "unknown_param" not in result
+
+
+def test_filter_openai_client_kwargs_known_params():
+ result = OpenAIBackend.filter_openai_client_kwargs(
+ api_key="sk-test", base_url="http://localhost", timeout=30
+ )
+ assert "api_key" in result
+ assert "base_url" in result
+
+
+def test_filter_openai_client_kwargs_empty():
+ result = OpenAIBackend.filter_openai_client_kwargs()
+ assert result == {}
+
+
+# --- filter_chat_completions_kwargs ---
+
+
+def test_filter_chat_completions_keeps_valid_params(backend):
+ result = backend.filter_chat_completions_kwargs(
+ {"model": "gpt-4o", "temperature": 0.7, "unknown_option": True}
+ )
+ assert "model" in result
+ assert "temperature" in result
+ assert "unknown_option" not in result
+
+
+def test_filter_chat_completions_empty(backend):
+ result = backend.filter_chat_completions_kwargs({})
+ assert result == {}
+
+
+def test_filter_chat_completions_max_tokens(backend):
+ result = backend.filter_chat_completions_kwargs({"max_completion_tokens": 100})
+ assert "max_completion_tokens" in result
+
+
+# --- Map consistency ---
+
+
+@pytest.mark.parametrize("context", ["chats", "completions"])
+def test_from_mellea_keys_are_subset_of_to_mellea_values(backend, context):
+ """Every key in from_mellea must appear as a value in to_mellea (maps agree)."""
+ to_map = getattr(backend, f"to_mellea_model_opts_map_{context}")
+ from_map = getattr(backend, f"from_mellea_model_opts_map_{context}")
+ to_values = set(to_map.values())
+ from_keys = set(from_map.keys())
+ assert from_keys <= to_values, (
+ f"from_mellea_{context} has keys absent from to_mellea values: {from_keys - to_values}"
+ )
+
+
+# --- _simplify_and_merge ---
+
+
+def test_simplify_and_merge_none_returns_empty_dict(backend):
+ result = backend._simplify_and_merge(None, is_chat_context=True)
+ assert result == {}
+
+
+@pytest.mark.parametrize("context", ["chats", "completions"])
+def test_simplify_and_merge_all_to_mellea_entries(backend, context):
+ """Every to_mellea entry remaps to its ModelOption via _simplify_and_merge."""
+ is_chat = context == "chats"
+ to_map = getattr(backend, f"to_mellea_model_opts_map_{context}")
+ for backend_key, mellea_key in to_map.items():
+ result = backend._simplify_and_merge({backend_key: 42}, is_chat_context=is_chat)
+ assert mellea_key in result, f"{backend_key!r} did not produce {mellea_key!r}"
+ assert result[mellea_key] == 42
+
+
+def test_simplify_and_merge_remaps_max_completion_tokens(backend):
+ """Hardcoded anchor: the critical chat API mapping for generation length."""
+ result = backend._simplify_and_merge(
+ {"max_completion_tokens": 256}, is_chat_context=True
+ )
+ assert ModelOption.MAX_NEW_TOKENS in result
+ assert result[ModelOption.MAX_NEW_TOKENS] == 256
+
+
+def test_simplify_and_merge_completions_remaps_max_tokens(backend):
+ """Hardcoded anchor: completions API uses a different key for the same sentinel."""
+ result = backend._simplify_and_merge({"max_tokens": 100}, is_chat_context=False)
+ assert ModelOption.MAX_NEW_TOKENS in result
+ assert result[ModelOption.MAX_NEW_TOKENS] == 100
+
+
+def test_simplify_and_merge_per_call_overrides_backend():
+ # Backend sets max_completion_tokens=128; per-call value of 512 must win.
+ b = _make_backend(model_options={"max_completion_tokens": 128})
+ result = b._simplify_and_merge({"max_completion_tokens": 512}, is_chat_context=True)
+ assert result[ModelOption.MAX_NEW_TOKENS] == 512
+
+
+# --- _make_backend_specific_and_remove ---
+
+
+@pytest.mark.parametrize("context", ["chats", "completions"])
+def test_make_backend_specific_all_from_mellea_entries(backend, context):
+ """Every from_mellea entry remaps to its backend key via _make_backend_specific_and_remove."""
+ is_chat = context == "chats"
+ from_map = getattr(backend, f"from_mellea_model_opts_map_{context}")
+ for mellea_key, backend_key in from_map.items():
+ result = backend._make_backend_specific_and_remove(
+ {mellea_key: 42}, is_chat_context=is_chat
+ )
+ assert backend_key in result, f"{mellea_key!r} did not produce {backend_key!r}"
+ assert result[backend_key] == 42
+
+
+def test_make_backend_specific_chat_remaps_max_new_tokens(backend):
+ """Hardcoded anchor: chat API maps MAX_NEW_TOKENS → max_completion_tokens."""
+ opts = {ModelOption.MAX_NEW_TOKENS: 200}
+ result = backend._make_backend_specific_and_remove(opts, is_chat_context=True)
+ assert "max_completion_tokens" in result
+ assert result["max_completion_tokens"] == 200
+
+
+def test_make_backend_specific_completions_remaps_max_new_tokens(backend):
+ """Hardcoded anchor: completions API maps MAX_NEW_TOKENS → max_tokens."""
+ opts = {ModelOption.MAX_NEW_TOKENS: 100}
+ result = backend._make_backend_specific_and_remove(opts, is_chat_context=False)
+ assert "max_tokens" in result
+ assert result["max_tokens"] == 100
+
+
+def test_make_backend_specific_unknown_mellea_keys_removed(backend):
+ opts = {ModelOption.TOOLS: ["tool1"], ModelOption.SYSTEM_PROMPT: "sys"}
+ result = backend._make_backend_specific_and_remove(opts, is_chat_context=True)
+ # SYSTEM_PROMPT has no from_mellea mapping — should be removed
+ assert ModelOption.SYSTEM_PROMPT not in result
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/test/backends/test_utils.py b/test/backends/test_utils.py
new file mode 100644
index 000000000..764c8c9c7
--- /dev/null
+++ b/test/backends/test_utils.py
@@ -0,0 +1,164 @@
+"""Unit tests for backends/utils.py — get_value accessor and to_tool_calls parser."""
+
+from dataclasses import dataclass
+
+import pytest
+
+from mellea.backends.tools import MelleaTool
+from mellea.backends.utils import get_value, to_tool_calls
+from mellea.core import ModelToolCall
+
+# --- get_value ---
+
+
+def test_get_value_dict_present():
+ assert get_value({"a": 1, "b": 2}, "a") == 1
+
+
+def test_get_value_dict_missing():
+ assert get_value({"a": 1}, "missing") is None
+
+
+def test_get_value_object_attribute():
+ obj = type("Obj", (), {"x": "hello"})()
+ assert get_value(obj, "x") == "hello"
+
+
+def test_get_value_object_missing_attribute():
+ obj = type("Obj", (), {})()
+ assert get_value(obj, "nonexistent") is None
+
+
+def test_get_value_dict_none_value():
+ # Explicitly stored None should come back as None (same as get())
+ assert get_value({"k": None}, "k") is None
+
+
+@dataclass
+class _DC:
+ score: float
+ label: str
+
+
+def test_get_value_dataclass():
+ dc = _DC(score=0.9, label="positive")
+ assert get_value(dc, "score") == 0.9
+ assert get_value(dc, "label") == "positive"
+
+
+# --- to_tool_calls ---
+
+
+def _make_tool_registry() -> dict:
+ def add(x: int, y: int) -> int:
+ """Add two integers."""
+ return x + y
+
+ def greet(name: str) -> str:
+ """Greet a person."""
+ return f"Hello, {name}!"
+
+ return {
+ "add": MelleaTool.from_callable(add),
+ "greet": MelleaTool.from_callable(greet),
+ }
+
+
+def _tool_call_json(name: str, args: dict) -> str:
+ import json
+
+ return json.dumps([{"name": name, "arguments": args}])
+
+
+def test_to_tool_calls_single_call():
+ registry = _make_tool_registry()
+ raw = _tool_call_json("add", {"x": 3, "y": 4})
+ result = to_tool_calls(registry, raw)
+ assert result is not None
+ assert "add" in result
+ mtc = result["add"]
+ assert isinstance(mtc, ModelToolCall)
+ assert mtc.name == "add"
+ assert mtc.args == {"x": 3, "y": 4}
+
+
+def test_to_tool_calls_returns_none_when_no_calls():
+ registry = _make_tool_registry()
+ result = to_tool_calls(registry, "no tool call here")
+ assert result is None
+
+
+def test_to_tool_calls_unknown_tool_skipped():
+ registry = _make_tool_registry()
+ raw = _tool_call_json("nonexistent_fn", {"arg": "val"})
+ # Unknown tool is skipped — result should be None (empty dict → None)
+ result = to_tool_calls(registry, raw)
+ assert result is None
+
+
+def test_to_tool_calls_empty_params_cleared():
+ """When the tool has no parameters, hallucinated args should be stripped."""
+
+ def noop() -> str:
+ """Does nothing."""
+ return "done"
+
+ registry = {"noop": MelleaTool.from_callable(noop)}
+ raw = _tool_call_json("noop", {"hallucinated": "arg"})
+ result = to_tool_calls(registry, raw)
+ assert result is not None
+ assert result["noop"].args == {}
+
+
+def test_to_tool_calls_string_arg_coerced_to_int():
+ """validate_tool_arguments coerces strings to int when strict=False."""
+ registry = _make_tool_registry()
+ raw = _tool_call_json("add", {"x": "5", "y": "10"})
+ result = to_tool_calls(registry, raw)
+ assert result is not None
+ assert result["add"].args["x"] == 5
+ assert result["add"].args["y"] == 10
+
+
+# --- to_chat ---
+
+
+def test_to_chat_basic_message():
+ from mellea.backends.utils import to_chat
+ from mellea.formatters.template_formatter import TemplateFormatter as ChatFormatter
+ from mellea.stdlib.components import Message
+ from mellea.stdlib.context import ChatContext
+
+ ctx = ChatContext()
+ ctx = ctx.add(Message("user", "hello"))
+ action = Message("user", "next question")
+ formatter = ChatFormatter(model_id="test")
+
+ result = to_chat(action, ctx, formatter, system_prompt=None)
+ assert isinstance(result, list)
+ assert len(result) == 2
+ assert result[0]["role"] == "user"
+ assert result[0]["content"] == "hello"
+ assert result[1]["role"] == "user"
+ assert result[1]["content"] == "next question"
+
+
+def test_to_chat_with_system_prompt():
+ from mellea.backends.utils import to_chat
+ from mellea.formatters.template_formatter import TemplateFormatter as ChatFormatter
+ from mellea.stdlib.components import Message
+ from mellea.stdlib.context import ChatContext
+
+ ctx = ChatContext()
+ ctx = ctx.add(Message("user", "hi"))
+ action = Message("user", "q")
+ formatter = ChatFormatter(model_id="test")
+
+ result = to_chat(action, ctx, formatter, system_prompt="You are helpful.")
+ assert result[0]["role"] == "system"
+ assert result[0]["content"] == "You are helpful."
+ assert len(result) == 3 # system + user context + user action
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/test/core/test_base.py b/test/core/test_base.py
index cada2f42e..ff89148aa 100644
--- a/test/core/test_base.py
+++ b/test/core/test_base.py
@@ -1,8 +1,11 @@
+import base64
+import io
from typing import Any
import pytest
+from PIL import Image as PILImage
-from mellea.core import CBlock, Component, ModelOutputThunk
+from mellea.core import CBlock, Component, ImageBlock, ModelOutputThunk
from mellea.stdlib.components import Message
@@ -66,5 +69,93 @@ def __init__(self, msg: Message) -> None:
assert result.parsed_repr.content == "result value"
+# --- CBlock edge cases ---
+
+
+def test_cblock_non_string_value_raises():
+ with pytest.raises(TypeError, match="should always be a string or None"):
+ CBlock(value=42) # type: ignore
+
+
+def test_cblock_none_value_allowed():
+ cb = CBlock(value=None)
+ assert str(cb) == ""
+
+
+def test_cblock_value_setter():
+ cb = CBlock(value="old")
+ cb.value = "new"
+ assert cb.value == "new"
+
+
+# --- ImageBlock.is_valid_base64_png ---
+
+
+def _make_png_b64() -> str:
+ img = PILImage.new("RGB", (1, 1), color="red")
+ buf = io.BytesIO()
+ img.save(buf, format="PNG")
+ return base64.b64encode(buf.getvalue()).decode()
+
+
+def test_image_block_valid_png():
+ b64 = _make_png_b64()
+ assert ImageBlock.is_valid_base64_png(b64) is True
+
+
+def test_image_block_invalid_base64_returns_false():
+ assert ImageBlock.is_valid_base64_png("not-base64!!!") is False
+
+
+def test_image_block_valid_base64_but_not_png():
+ # Base64-encoded JPEG magic bytes
+ jpg_magic = base64.b64encode(b"\xff\xd8\xff" + b"\x00" * 20).decode()
+ assert ImageBlock.is_valid_base64_png(jpg_magic) is False
+
+
+def test_image_block_data_uri_prefix_stripped():
+ b64 = _make_png_b64()
+ data_uri = f"data:image/png;base64,{b64}"
+ assert ImageBlock.is_valid_base64_png(data_uri) is True
+
+
+def test_image_block_invalid_value_raises():
+ with pytest.raises(AssertionError, match="Invalid base64"):
+ ImageBlock(value="not-a-png")
+
+
+# --- ModelOutputThunk._copy_from ---
+
+
+def test_mot_copy_from_copies_underlying_value():
+ a = ModelOutputThunk(value=None)
+ b = ModelOutputThunk(value="copied")
+ a._copy_from(b)
+ # _copy_from copies _underlying_value (not _computed), so check raw field
+ assert a._underlying_value == "copied"
+
+
+def test_mot_copy_from_copies_meta():
+ a = ModelOutputThunk(value=None)
+ b = ModelOutputThunk(value="x", meta={"key": "val"})
+ a._copy_from(b)
+ assert a._meta["key"] == "val"
+
+
+def test_mot_copy_from_copies_tool_calls():
+ a = ModelOutputThunk(value=None)
+ b = ModelOutputThunk(value="x", tool_calls={"fn": None})
+ a._copy_from(b)
+ assert a.tool_calls == {"fn": None}
+
+
+def test_mot_copy_from_copies_usage():
+ a = ModelOutputThunk(value=None)
+ b = ModelOutputThunk(value="x")
+ b.usage = {"prompt_tokens": 10}
+ a._copy_from(b)
+ assert a.usage == {"prompt_tokens": 10}
+
+
if __name__ == "__main__":
pytest.main([__file__])
diff --git a/test/core/test_requirement_helpers.py b/test/core/test_requirement_helpers.py
new file mode 100644
index 000000000..9ae3f71b5
--- /dev/null
+++ b/test/core/test_requirement_helpers.py
@@ -0,0 +1,91 @@
+"""Unit tests for core/requirement.py pure helpers — ValidationResult, default_output_to_bool."""
+
+import pytest
+
+from mellea.core import CBlock, ModelOutputThunk
+from mellea.core.requirement import ValidationResult, default_output_to_bool
+
+# --- ValidationResult ---
+
+
+def test_validation_result_pass():
+ r = ValidationResult(result=True)
+ assert r.as_bool() is True
+ assert bool(r) is True
+
+
+def test_validation_result_fail():
+ r = ValidationResult(result=False)
+ assert r.as_bool() is False
+ assert bool(r) is False
+
+
+def test_validation_result_reason():
+ r = ValidationResult(result=True, reason="looks good")
+ assert r.reason == "looks good"
+
+
+def test_validation_result_score():
+ r = ValidationResult(result=True, score=0.95)
+ assert r.score == pytest.approx(0.95)
+
+
+def test_validation_result_thunk():
+ mot = ModelOutputThunk(value="x")
+ r = ValidationResult(result=True, thunk=mot)
+ assert r.thunk is mot
+
+
+def test_validation_result_context():
+ from mellea.stdlib.context import SimpleContext
+
+ ctx = SimpleContext()
+ r = ValidationResult(result=True, context=ctx)
+ assert r.context is ctx
+
+
+def test_validation_result_defaults_none():
+ r = ValidationResult(result=False)
+ assert r.reason is None
+ assert r.score is None
+ assert r.thunk is None
+ assert r.context is None
+
+
+# --- default_output_to_bool ---
+
+
+def test_yes_exact_passes():
+ assert default_output_to_bool(CBlock("yes")) is True
+
+
+def test_yes_uppercase_passes():
+ assert default_output_to_bool(CBlock("YES")) is True
+
+
+def test_y_passes():
+ assert default_output_to_bool(CBlock("y")) is True
+
+
+def test_yes_in_sentence():
+ assert default_output_to_bool(CBlock("Yes, it meets the requirement.")) is True
+
+
+def test_no_fails():
+ assert default_output_to_bool(CBlock("no")) is False
+
+
+def test_empty_string_fails():
+ assert default_output_to_bool(CBlock("")) is False
+
+
+def test_random_text_fails():
+ assert default_output_to_bool(CBlock("the output looks reasonable")) is False
+
+
+def test_plain_string_yes():
+ assert default_output_to_bool("YES") is True # type: ignore
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/test/formatters/granite/base/__init__.py b/test/formatters/granite/base/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/test/formatters/granite/base/test_base_util.py b/test/formatters/granite/base/test_base_util.py
new file mode 100644
index 000000000..0a78b3e37
--- /dev/null
+++ b/test/formatters/granite/base/test_base_util.py
@@ -0,0 +1,61 @@
+"""Unit tests for formatters/granite/base/util.py pure helpers."""
+
+import pytest
+
+from mellea.formatters.granite.base.util import find_substring_in_text
+
+# --- find_substring_in_text ---
+
+
+def test_find_single_match():
+ result = find_substring_in_text("hello", "say hello world")
+ assert len(result) == 1
+ assert result[0]["begin_idx"] == 4
+ assert result[0]["end_idx"] == 9
+
+
+def test_find_multiple_matches():
+ result = find_substring_in_text("ab", "ababab")
+ assert len(result) == 3
+ # Verify positions are non-overlapping
+ assert result[0]["begin_idx"] == 0
+ assert result[1]["begin_idx"] == 2
+ assert result[2]["begin_idx"] == 4
+
+
+def test_find_no_match_returns_empty():
+ result = find_substring_in_text("xyz", "hello world")
+ assert result == []
+
+
+def test_find_empty_text_returns_empty():
+ result = find_substring_in_text("hello", "")
+ assert result == []
+
+
+def test_find_at_start():
+ result = find_substring_in_text("the", "the quick fox")
+ assert result[0]["begin_idx"] == 0
+
+
+def test_find_at_end():
+ result = find_substring_in_text("fox", "the quick fox")
+ assert result[-1]["end_idx"] == len("the quick fox")
+
+
+def test_find_full_text_match():
+ result = find_substring_in_text("exact", "exact")
+ assert len(result) == 1
+ assert result[0]["begin_idx"] == 0
+ assert result[0]["end_idx"] == 5
+
+
+def test_find_special_regex_chars_escaped():
+ # Dots in the substring should be treated literally
+ result = find_substring_in_text("a.b", "a.b and axb")
+ assert len(result) == 1
+ assert result[0]["begin_idx"] == 0
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/test/formatters/granite/test_granite32_output.py b/test/formatters/granite/test_granite32_output.py
index a5d0aa89f..356588628 100644
--- a/test/formatters/granite/test_granite32_output.py
+++ b/test/formatters/granite/test_granite32_output.py
@@ -29,6 +29,7 @@
Granite3Controls,
Granite3Kwargs,
)
+from test.predicates import require_nltk_data
# ---------------------------------------------------------------------------
# _parse_citations_text
@@ -285,6 +286,7 @@ def test_invalid_tool_call_falls_through(self):
assert result.tool_calls == []
assert isinstance(result.content, str)
+ @require_nltk_data()
def test_citations_and_hallucinations_pipeline(self):
proc = Granite32OutputProcessor()
model_output = (
diff --git a/test/formatters/granite/test_granite33_output.py b/test/formatters/granite/test_granite33_output.py
index 8ce4e8827..a283d09ae 100644
--- a/test/formatters/granite/test_granite33_output.py
+++ b/test/formatters/granite/test_granite33_output.py
@@ -31,6 +31,7 @@
Granite3Controls,
Granite3Kwargs,
)
+from test.predicates import require_nltk_data
# ---------------------------------------------------------------------------
# _parse_citations_text
@@ -256,6 +257,7 @@ def test_tool_call_parsing(self):
assert len(result.tool_calls) == 1
assert result.tool_calls[0].name == "search"
+ @require_nltk_data()
def test_raw_content_set_when_different(self):
proc = Granite33OutputProcessor()
model_output = (
diff --git a/test/predicates.py b/test/predicates.py
index 2c00f1865..26f34c507 100644
--- a/test/predicates.py
+++ b/test/predicates.py
@@ -156,6 +156,52 @@ def test_watsonx_generate(): ...
return pytest.mark.skipif(False, reason="")
+# ---------------------------------------------------------------------------
+# NLTK data
+# ---------------------------------------------------------------------------
+
+
+def _nltk_data_available() -> tuple[bool, str]:
+ """Check whether nltk is installed *and* punkt_tab data is downloaded.
+
+ Returns a (available, reason) tuple so the skip message is specific:
+ - nltk not installed → "nltk not installed — install mellea[formatters]"
+ - punkt_tab missing → "NLTK punkt_tab data not downloaded — run: python -m nltk.downloader punkt_tab"
+ - both ok → (True, "")
+ """
+ try:
+ import nltk
+ except ImportError:
+ return False, "nltk not installed — install mellea[formatters]"
+
+ try:
+ import nltk.data
+
+ nltk.data.find("tokenizers/punkt_tab")
+ except LookupError:
+ return (
+ False,
+ "NLTK punkt_tab data not downloaded — run: python -m nltk.downloader punkt_tab",
+ )
+
+ return True, ""
+
+
+def require_nltk_data():
+ """Skip unless nltk is installed and punkt_tab tokenizer data is available.
+
+ Distinguishes between the two failure modes so the skip reason is actionable::
+
+ @require_nltk_data()
+ def test_citation_spans(): ...
+
+ # Module-level (skips all tests in the file):
+ pytestmark = [require_nltk_data()]
+ """
+ available, reason = _nltk_data_available()
+ return pytest.mark.skipif(not available, reason=reason)
+
+
# ---------------------------------------------------------------------------
# Optional dependencies
# ---------------------------------------------------------------------------
diff --git a/test/stdlib/components/test_chat.py b/test/stdlib/components/test_chat.py
index 66ebb9fc2..2319aafff 100644
--- a/test/stdlib/components/test_chat.py
+++ b/test/stdlib/components/test_chat.py
@@ -1,7 +1,10 @@
import pytest
+from mellea.core import CBlock, ModelOutputThunk, TemplateRepresentation
from mellea.helpers import messages_to_docs
from mellea.stdlib.components import Document, Message
+from mellea.stdlib.components.chat import ToolMessage, as_chat_history
+from mellea.stdlib.context import ChatContext
def test_message_with_docs():
@@ -22,5 +25,231 @@ def test_message_with_docs():
assert tr.args["documents"]
+# --- Message init ---
+
+
+def test_message_basic_fields():
+ msg = Message("user", "hello")
+ assert msg.role == "user"
+ assert msg.content == "hello"
+ assert msg._images is None
+ assert msg._docs is None
+
+
+def test_message_content_block_created():
+ msg = Message("assistant", "response")
+ assert isinstance(msg._content_cblock, CBlock)
+ assert msg._content_cblock.value == "response"
+
+
+def test_message_repr():
+ msg = Message("user", "hi there")
+ r = repr(msg)
+ assert 'role="user"' in r
+ assert 'content="hi there"' in r
+
+
+# --- Message images property ---
+
+
+def test_message_images_none():
+ msg = Message("user", "text")
+ assert msg.images is None
+
+
+# --- Message parts() ---
+
+
+def test_message_parts_no_docs_no_images():
+ msg = Message("user", "text")
+ parts = msg.parts()
+ assert len(parts) == 1
+ assert parts[0] is msg._content_cblock
+
+
+def test_message_parts_with_docs():
+ doc = Document("text", "title")
+ msg = Message("user", "hi", documents=[doc])
+ parts = msg.parts()
+ assert doc in parts
+
+
+# --- Message format_for_llm ---
+
+
+def test_message_format_for_llm_structure():
+ msg = Message("user", "hello")
+ tr = msg.format_for_llm()
+ assert isinstance(tr, TemplateRepresentation)
+ assert tr.args["role"] == "user"
+ assert tr.args["content"] is msg._content_cblock
+ assert tr.args["images"] is None
+ assert tr.args["documents"] is None
+
+
+# --- Message._parse — no tool calls ---
+
+
+def test_parse_plain_value_no_meta():
+ msg = Message("user", "original")
+ mot = ModelOutputThunk(value="model response")
+ result = msg._parse(mot)
+ assert isinstance(result, Message)
+ assert result.role == "assistant"
+ assert result.content == "model response"
+
+
+def test_parse_ollama_chat_response():
+ msg = Message("user", "q")
+ mot = ModelOutputThunk(value="v")
+ fake_response = type(
+ "Resp",
+ (),
+ {
+ "message": type(
+ "Msg", (), {"role": "assistant", "content": "ollama answer"}
+ )()
+ },
+ )()
+ mot._meta["chat_response"] = fake_response
+ result = msg._parse(mot)
+ assert result.role == "assistant"
+ assert result.content == "ollama answer"
+
+
+def test_parse_openai_chat_response():
+ msg = Message("user", "q")
+ mot = ModelOutputThunk(value="v")
+ mot._meta["oai_chat_response"] = {
+ "choices": [{"message": {"role": "assistant", "content": "openai answer"}}]
+ }
+ result = msg._parse(mot)
+ assert result.role == "assistant"
+ assert result.content == "openai answer"
+
+
+# --- Message._parse — with tool calls ---
+
+
+def test_parse_tool_calls_ollama():
+ msg = Message("user", "q")
+ mot = ModelOutputThunk(value="v", tool_calls={"some_fn": None})
+ fake_calls = [{"name": "some_fn"}]
+ fake_response = type(
+ "Resp",
+ (),
+ {"message": type("Msg", (), {"role": "assistant", "tool_calls": fake_calls})()},
+ )()
+ mot._meta["chat_response"] = fake_response
+ result = msg._parse(mot)
+ assert result.role == "assistant"
+ assert "some_fn" in result.content
+
+
+def test_parse_tool_calls_openai():
+ msg = Message("user", "q")
+ mot = ModelOutputThunk(value="v", tool_calls={"fn": None})
+ mot._meta["oai_chat_response"] = {
+ "choices": [
+ {
+ "message": {
+ "role": "assistant",
+ "tool_calls": [{"function": {"name": "fn"}}],
+ }
+ }
+ ]
+ }
+ result = msg._parse(mot)
+ assert result.role == "assistant"
+
+
+def test_parse_tool_calls_fallback_uses_value():
+ """No chat_response or oai_chat_response — falls back to computed.value."""
+ msg = Message("user", "q")
+ mot = ModelOutputThunk(value="fn()", tool_calls={"fn": None})
+ result = msg._parse(mot)
+ assert result.role == "assistant"
+ assert result.content == "fn()"
+
+
+# --- ToolMessage ---
+
+
+def test_tool_message_fields():
+ from mellea.core import ModelToolCall
+
+ fake_tool = type("T", (), {"as_json_tool": {}})()
+ mtc = ModelToolCall("my_tool", fake_tool, {"x": 1})
+ tm = ToolMessage(
+ role="tool",
+ content='{"result": 42}',
+ tool_output=42,
+ name="my_tool",
+ args={"x": 1},
+ tool=mtc,
+ )
+ assert tm.role == "tool"
+ assert tm.name == "my_tool"
+ assert tm.arguments == {"x": 1}
+
+
+def test_tool_message_format_for_llm_includes_name():
+ from mellea.core import ModelToolCall
+
+ fake_tool = type("T", (), {"as_json_tool": {}})()
+ mtc = ModelToolCall("my_tool", fake_tool, {})
+ tm = ToolMessage(
+ role="tool",
+ content="output",
+ tool_output="output",
+ name="my_tool",
+ args={},
+ tool=mtc,
+ )
+ tr = tm.format_for_llm()
+ assert isinstance(tr, TemplateRepresentation)
+ assert tr.args["name"] == "my_tool"
+
+
+def test_tool_message_repr():
+ from mellea.core import ModelToolCall
+
+ fake_tool = type("T", (), {"as_json_tool": {}})()
+ mtc = ModelToolCall("fn", fake_tool, {})
+ tm = ToolMessage("tool", "out", "out", "fn", {}, mtc)
+ r = repr(tm)
+ assert 'name="fn"' in r
+
+
+# --- as_chat_history ---
+
+
+def test_as_chat_history_messages_only():
+ ctx = ChatContext()
+ ctx = ctx.add(Message("user", "hello"))
+ ctx = ctx.add(Message("assistant", "hi"))
+ history = as_chat_history(ctx)
+ assert len(history) == 2
+ assert history[0].role == "user"
+ assert history[1].role == "assistant"
+
+
+def test_as_chat_history_empty():
+ ctx = ChatContext()
+ history = as_chat_history(ctx)
+ assert history == []
+
+
+def test_as_chat_history_with_parsed_mot():
+ ctx = ChatContext()
+ ctx = ctx.add(Message("user", "hello"))
+ mot = ModelOutputThunk(value="reply")
+ mot.parsed_repr = Message("assistant", "reply")
+ ctx = ctx.add(mot)
+ history = as_chat_history(ctx)
+ assert len(history) == 2
+ assert history[1].content == "reply"
+
+
if __name__ == "__main__":
- pytest.main([__file__])
+ pytest.main([__file__, "-v"])
diff --git a/test/stdlib/components/test_genstub_unit.py b/test/stdlib/components/test_genstub_unit.py
new file mode 100644
index 000000000..20d046814
--- /dev/null
+++ b/test/stdlib/components/test_genstub_unit.py
@@ -0,0 +1,326 @@
+"""Unit tests for genstub pure-logic helpers — no backend, no LLM required.
+
+Covers describe_function, get_argument, bind_function_arguments,
+create_response_format, GenerativeStub.format_for_llm, and @generative routing.
+"""
+
+from typing import Literal
+
+import pytest
+
+from mellea import generative
+from mellea.core import TemplateRepresentation, ValidationResult
+from mellea.stdlib.components.genstub import (
+ ArgPreconditionRequirement,
+ Arguments,
+ AsyncGenerativeStub,
+ Function,
+ PreconditionException,
+ SyncGenerativeStub,
+ bind_function_arguments,
+ create_response_format,
+ describe_function,
+ get_argument,
+)
+from mellea.stdlib.requirements.requirement import reqify
+
+# --- describe_function ---
+
+
+def test_describe_function_name():
+ def greet(name: str) -> str:
+ """Say hello."""
+ return f"Hello {name}"
+
+ result = describe_function(greet)
+ assert result["name"] == "greet"
+
+
+def test_describe_function_signature_includes_params():
+ def add(x: int, y: int) -> int:
+ return x + y
+
+ result = describe_function(add)
+ assert "x" in result["signature"]
+ assert "y" in result["signature"]
+
+
+def test_describe_function_docstring():
+ def noop() -> None:
+ """Does nothing."""
+
+ result = describe_function(noop)
+ assert result["docstring"] == "Does nothing."
+
+
+def test_describe_function_no_docstring():
+ def bare():
+ pass
+
+ result = describe_function(bare)
+ assert result["docstring"] is None
+
+
+# --- get_argument ---
+
+
+def test_get_argument_string_value_quoted():
+ def fn(name: str) -> None:
+ pass
+
+ arg = get_argument(fn, "name", "Alice")
+ assert arg._argument_dict["value"] == '"Alice"'
+ assert arg._argument_dict["name"] == "name"
+
+
+def test_get_argument_int_value_not_quoted():
+ def fn(count: int) -> None:
+ pass
+
+ arg = get_argument(fn, "count", 42)
+ assert arg._argument_dict["value"] == 42
+ assert "int" in str(arg._argument_dict["annotation"])
+
+
+def test_get_argument_no_annotation_falls_back_to_runtime_type():
+ # No annotation on kwargs — should fall back to type(val)
+ def fn(**kwargs) -> None:
+ pass
+
+ arg = get_argument(fn, "x", 3.14)
+ assert "float" in str(arg._argument_dict["annotation"])
+
+
+# --- bind_function_arguments ---
+
+
+def test_bind_function_arguments_basic():
+ def fn(x: int, y: int) -> int:
+ return x + y
+
+ result = bind_function_arguments(fn, x=1, y=2)
+ assert result == {"x": 1, "y": 2}
+
+
+def test_bind_function_arguments_with_defaults():
+ def fn(x: int, y: int = 10) -> int:
+ return x + y
+
+ result = bind_function_arguments(fn, x=5)
+ assert result == {"x": 5, "y": 10}
+
+
+def test_bind_function_arguments_missing_required_raises():
+ def fn(x: int, y: int) -> int:
+ return x + y
+
+ with pytest.raises(TypeError, match="missing required parameter"):
+ bind_function_arguments(fn, x=1)
+
+
+def test_bind_function_arguments_no_params():
+ def fn() -> str:
+ return "hi"
+
+ result = bind_function_arguments(fn)
+ assert result == {}
+
+
+# --- create_response_format ---
+
+
+def test_create_response_format_class_name_derived_from_func():
+ def get_sentiment() -> str: ...
+
+ model = create_response_format(get_sentiment)
+ assert "GetSentiment" in model.__name__
+
+
+def test_create_response_format_result_field_accessible():
+ def score_text() -> float: ...
+
+ model = create_response_format(score_text)
+ instance = model(result=0.9)
+ assert instance.result == 0.9
+
+
+def test_create_response_format_literal_type():
+ def classify() -> Literal["pos", "neg"]: ...
+
+ model = create_response_format(classify)
+ instance = model(result="pos")
+ assert instance.result == "pos"
+
+
+# --- GenerativeStub.format_for_llm ---
+
+
+def test_generative_stub_format_for_llm_returns_template_repr():
+ @generative
+ def summarise(text: str) -> str:
+ """Summarise the given text."""
+
+ result = summarise.format_for_llm()
+ assert isinstance(result, TemplateRepresentation)
+
+
+def test_generative_stub_format_for_llm_includes_function_name():
+ @generative
+ def my_function(x: int) -> int: ...
+
+ result = my_function.format_for_llm()
+ assert result.args["function"]["name"] == "my_function"
+
+
+def test_generative_stub_format_for_llm_includes_docstring():
+ @generative
+ def documented() -> str:
+ """This is the docstring."""
+
+ result = documented.format_for_llm()
+ assert result.args["function"]["docstring"] == "This is the docstring."
+
+
+def test_generative_stub_format_for_llm_no_args_until_called():
+ @generative
+ def fn() -> str: ...
+
+ result = fn.format_for_llm()
+ assert result.args["arguments"] is None
+
+
+# --- @generative decorator routing ---
+
+
+def test_generative_sync_function_returns_sync_stub():
+ @generative
+ def sync_fn() -> str: ...
+
+ assert isinstance(sync_fn, SyncGenerativeStub)
+
+
+def test_generative_async_function_returns_async_stub():
+ @generative
+ async def async_fn() -> str: ...
+
+ assert isinstance(async_fn, AsyncGenerativeStub)
+
+
+def test_generative_disallowed_param_name_raises():
+ with pytest.raises(ValueError, match="disallowed parameter names"):
+
+ @generative
+ def fn(backend: str) -> str: ...
+
+
+# --- Arguments (CBlock subclass rendering bound args) ---
+
+
+def test_arguments_renders_text():
+ def fn(name: str, count: int) -> None:
+ pass
+
+ args = [get_argument(fn, "name", "Alice"), get_argument(fn, "count", 3)]
+ block = Arguments(args)
+ assert "name" in block.value
+ assert "count" in block.value
+
+
+def test_arguments_stores_meta_by_name():
+ def fn(x: int) -> None:
+ pass
+
+ args = [get_argument(fn, "x", 5)]
+ block = Arguments(args)
+ assert "x" in block._meta
+
+
+def test_arguments_empty_list():
+ block = Arguments([])
+ assert block.value == ""
+
+
+# --- Function (wraps callable with metadata) ---
+
+
+def test_function_stores_callable():
+ def greet(name: str) -> str:
+ """Say hi."""
+ return f"hi {name}"
+
+ f = Function(greet)
+ assert f._func is greet
+ assert f._function_dict["name"] == "greet"
+ assert f._function_dict["docstring"] == "Say hi."
+
+
+# --- ArgPreconditionRequirement (requirement wrapper) ---
+
+
+def test_arg_precondition_delegates_description():
+ req = reqify("must be non-empty")
+ wrapper = ArgPreconditionRequirement(req)
+ assert wrapper.description == req.description
+
+
+def test_arg_precondition_copy():
+ from copy import copy
+
+ req = reqify("be valid")
+ wrapper = ArgPreconditionRequirement(req)
+ copied = copy(wrapper)
+ assert isinstance(copied, ArgPreconditionRequirement)
+ assert copied.req is req
+
+
+def test_arg_precondition_deepcopy():
+ from copy import deepcopy
+
+ req = reqify("be clean")
+ wrapper = ArgPreconditionRequirement(req)
+ cloned = deepcopy(wrapper)
+ assert isinstance(cloned, ArgPreconditionRequirement)
+ assert cloned.description == req.description
+
+
+# --- PreconditionException ---
+
+
+def test_precondition_exception_message():
+ vr = ValidationResult(result=False, reason="failed check")
+ exc = PreconditionException("precondition failed", [vr])
+ assert "precondition failed" in str(exc)
+ assert exc.validation == [vr]
+
+
+# --- GenerativeStub._parse ---
+
+
+def test_genstub_parse_json_to_result():
+ import json
+
+ from mellea.core import ModelOutputThunk
+
+ @generative
+ def classify(text: str) -> str: ...
+
+ mot = ModelOutputThunk(value=json.dumps({"result": "positive"}))
+ parsed = classify._parse(mot)
+ assert parsed == "positive"
+
+
+def test_genstub_parse_int_result():
+ import json
+
+ from mellea.core import ModelOutputThunk
+
+ @generative
+ def compute(x: int) -> int: ...
+
+ mot = ModelOutputThunk(value=json.dumps({"result": 42}))
+ parsed = compute._parse(mot)
+ assert parsed == 42
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/test/stdlib/components/test_instruction.py b/test/stdlib/components/test_instruction.py
new file mode 100644
index 000000000..aef2b6984
--- /dev/null
+++ b/test/stdlib/components/test_instruction.py
@@ -0,0 +1,264 @@
+"""Unit tests for the Instruction component — init, jinja rendering, copy/repair, parts, format."""
+
+import pytest
+
+from mellea.core import CBlock, ModelOutputThunk, Requirement, TemplateRepresentation
+from mellea.stdlib.components.instruction import Instruction
+
+# --- basic init ---
+
+
+def test_init_minimal():
+ ins = Instruction(description="summarise the text")
+ assert ins._description is not None
+ assert str(ins._description) == "summarise the text"
+ assert ins._requirements == []
+ assert ins._icl_examples == []
+ assert ins._grounding_context == {}
+ assert ins._repair_string is None
+
+
+def test_init_no_args():
+ ins = Instruction()
+ assert ins._description is None
+ assert ins._requirements == []
+
+
+def test_init_converts_string_description_to_cblock():
+ ins = Instruction(description="hello")
+ assert isinstance(ins._description, CBlock)
+
+
+def test_init_accepts_cblock_description():
+ cb = CBlock("already a block")
+ ins = Instruction(description=cb)
+ assert ins._description is cb
+
+
+def test_init_string_requirements_converted():
+ ins = Instruction(requirements=["must be concise", "must be accurate"])
+ assert len(ins._requirements) == 2
+ for r in ins._requirements:
+ assert isinstance(r, Requirement)
+
+
+def test_init_requirement_objects_preserved():
+ r = Requirement(description="no profanity")
+ ins = Instruction(requirements=[r])
+ assert ins._requirements[0].description == "no profanity"
+
+
+def test_init_grounding_context_strings_blockified():
+ ins = Instruction(grounding_context={"doc1": "some content"})
+ assert isinstance(ins._grounding_context["doc1"], CBlock)
+
+
+def test_init_prefix_converted():
+ ins = Instruction(prefix="Answer:")
+ assert isinstance(ins._prefix, CBlock)
+
+
+def test_init_output_prefix_raises():
+ """output_prefix is currently unsupported; should raise AssertionError."""
+ with pytest.raises(
+ AssertionError, match="output_prefix is not currently supported"
+ ):
+ Instruction(user_variables={"x": "y"}, output_prefix="Result:")
+
+
+# --- apply_user_dict_from_jinja ---
+
+
+def test_jinja_simple_substitution():
+ result = Instruction.apply_user_dict_from_jinja(
+ {"name": "world"}, "Hello {{ name }}!"
+ )
+ assert result == "Hello world!"
+
+
+def test_jinja_multiple_variables():
+ result = Instruction.apply_user_dict_from_jinja(
+ {"a": "foo", "b": "bar"}, "{{ a }} and {{ b }}"
+ )
+ assert result == "foo and bar"
+
+
+def test_jinja_missing_variable_renders_empty():
+ result = Instruction.apply_user_dict_from_jinja({}, "Hello {{ name }}!")
+ assert result == "Hello !"
+
+
+def test_jinja_no_variables():
+ result = Instruction.apply_user_dict_from_jinja({}, "plain string")
+ assert result == "plain string"
+
+
+# --- user_variables applied to fields ---
+
+
+def test_user_variables_applied_to_description():
+ ins = Instruction(
+ description="Task: {{ task }}", user_variables={"task": "translate"}
+ )
+ assert str(ins._description) == "Task: translate"
+
+
+def test_user_variables_applied_to_prefix():
+ ins = Instruction(
+ prefix="{{ prefix_word }}:", user_variables={"prefix_word": "Answer"}
+ )
+ assert str(ins._prefix) == "Answer:"
+
+
+def test_user_variables_applied_to_requirements():
+ ins = Instruction(
+ requirements=["must be in {{ lang }}"], user_variables={"lang": "French"}
+ )
+ assert ins._requirements[0].description == "must be in French"
+
+
+def test_user_variables_applied_to_icl_examples():
+ ins = Instruction(icl_examples=["Example: {{ ex }}"], user_variables={"ex": "blue"})
+ assert str(ins._icl_examples[0]) == "Example: blue"
+
+
+def test_user_variables_applied_to_grounding_context():
+ ins = Instruction(
+ grounding_context={"doc": "See {{ ref }}"}, user_variables={"ref": "section 3"}
+ )
+ assert str(ins._grounding_context["doc"]) == "See section 3"
+
+
+def test_user_variables_description_must_be_string():
+ with pytest.raises(AssertionError, match="description must be a string"):
+ Instruction(description=CBlock("not a string"), user_variables={"x": "y"})
+
+
+def test_user_variables_requirement_object_description_rendered():
+ r = Requirement(description="must be in {{ lang }}")
+ ins = Instruction(requirements=[r], user_variables={"lang": "Spanish"})
+ assert ins._requirements[0].description == "must be in Spanish"
+
+
+# --- parts() ---
+
+
+def test_parts_includes_description():
+ ins = Instruction(description="do something")
+ parts = ins.parts()
+ assert ins._description in parts
+
+
+def test_parts_includes_requirements():
+ r = Requirement(description="be concise")
+ ins = Instruction(description="task", requirements=[r])
+ assert r in ins.parts()
+
+
+def test_parts_includes_grounding_context_values():
+ ins = Instruction(grounding_context={"doc": "content"})
+ parts = ins.parts()
+ assert ins._grounding_context["doc"] in parts
+
+
+def test_parts_empty_instruction():
+ ins = Instruction()
+ # No description, no requirements, no grounding context
+ assert ins.parts() == []
+
+
+def test_parts_includes_icl_examples():
+ ins = Instruction(icl_examples=["example 1"])
+ parts = ins.parts()
+ assert len(parts) == 1
+
+
+# --- format_for_llm ---
+
+
+def test_format_for_llm_returns_template_representation():
+ ins = Instruction(description="do something")
+ result = ins.format_for_llm()
+ assert isinstance(result, TemplateRepresentation)
+
+
+def test_format_for_llm_args_structure():
+ ins = Instruction(description="task", requirements=["req 1"], icl_examples=["ex 1"])
+ result = ins.format_for_llm()
+ assert "description" in result.args
+ assert "requirements" in result.args
+ assert "icl_examples" in result.args
+ assert "grounding_context" in result.args
+ assert "repair" in result.args
+
+
+def test_format_for_llm_check_only_req_excluded():
+ r = Requirement(description="internal check", check_only=True)
+ ins = Instruction(requirements=[r])
+ result = ins.format_for_llm()
+ assert r.description not in result.args["requirements"]
+
+
+def test_format_for_llm_repair_is_none_by_default():
+ ins = Instruction(description="task")
+ result = ins.format_for_llm()
+ assert result.args["repair"] is None
+
+
+# --- copy_and_repair ---
+
+
+def test_copy_and_repair_sets_repair_string():
+ ins = Instruction(description="task", requirements=["be brief"])
+ repaired = ins.copy_and_repair("requirement 'be brief' not met")
+ assert repaired._repair_string == "requirement 'be brief' not met"
+
+
+def test_copy_and_repair_does_not_mutate_original():
+ ins = Instruction(description="task")
+ _ = ins.copy_and_repair("failed")
+ assert ins._repair_string is None
+
+
+def test_copy_and_repair_deep_copy():
+ ins = Instruction(description="task", requirements=["be brief"])
+ repaired = ins.copy_and_repair("reason")
+ # Mutating the copy's requirements should not affect the original
+ repaired._requirements.append(Requirement(description="new"))
+ assert len(ins._requirements) == 1
+
+
+def test_copy_and_repair_format_includes_repair():
+ ins = Instruction(description="task")
+ repaired = ins.copy_and_repair("please fix this")
+ result = repaired.format_for_llm()
+ assert result.args["repair"] == "please fix this"
+
+
+# --- _parse ---
+
+
+def test_parse_returns_value():
+ ins = Instruction(description="x")
+ mot = ModelOutputThunk(value="answer")
+ assert ins._parse(mot) == "answer"
+
+
+def test_parse_none_returns_empty_string():
+ ins = Instruction(description="x")
+ mot = ModelOutputThunk(value=None)
+ assert ins._parse(mot) == ""
+
+
+# --- requirements property ---
+
+
+def test_requirements_property():
+ ins = Instruction(requirements=["be brief", "be accurate"])
+ reqs = ins.requirements
+ assert len(reqs) == 2
+ assert all(isinstance(r, Requirement) for r in reqs)
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/test/stdlib/components/test_mobject.py b/test/stdlib/components/test_mobject.py
new file mode 100644
index 000000000..05166a4be
--- /dev/null
+++ b/test/stdlib/components/test_mobject.py
@@ -0,0 +1,213 @@
+"""Unit tests for Query, Transform, and MObject — no docling, no backend required."""
+
+import pytest
+
+from mellea.core import ModelOutputThunk, TemplateRepresentation
+from mellea.stdlib.components.mobject import MObject, Query, Transform
+
+# --- helpers ---
+
+
+class _SimpleComponent(MObject):
+ """Minimal MObject subclass for testing."""
+
+ def __init__(self, content: str = "hello") -> None:
+ super().__init__()
+ self._content = content
+
+ def content_as_string(self) -> str:
+ return self._content
+
+ def format_for_llm(self) -> str:
+ return self._content
+
+ def parts(self):
+ return []
+
+ def _parse(self, computed):
+ return computed.value or ""
+
+
+# --- Query ---
+
+
+def test_query_parts_returns_wrapped_object():
+ obj = _SimpleComponent("doc text")
+ q = Query(obj, "what is this?")
+ parts = q.parts()
+ assert len(parts) == 1
+ assert parts[0] is obj
+
+
+def test_query_format_for_llm_returns_template_repr():
+ obj = _SimpleComponent("text")
+ q = Query(obj, "summarise")
+ result = q.format_for_llm()
+ assert isinstance(result, TemplateRepresentation)
+
+
+def test_query_format_for_llm_query_field():
+ obj = _SimpleComponent("text")
+ q = Query(obj, "what colour?")
+ result = q.format_for_llm()
+ assert result.args["query"] == "what colour?"
+
+
+def test_query_format_for_llm_content_is_wrapped_object():
+ obj = _SimpleComponent("text")
+ q = Query(obj, "q")
+ result = q.format_for_llm()
+ assert result.args["content"] is obj
+
+
+def test_query_parse_returns_value():
+ obj = _SimpleComponent()
+ q = Query(obj, "q")
+ mot = ModelOutputThunk(value="answer")
+ assert q._parse(mot) == "answer"
+
+
+def test_query_parse_none_returns_empty():
+ obj = _SimpleComponent()
+ q = Query(obj, "q")
+ mot = ModelOutputThunk(value=None)
+ assert q._parse(mot) == ""
+
+
+# --- Transform ---
+
+
+def test_transform_parts_returns_wrapped_object():
+ obj = _SimpleComponent("doc text")
+ t = Transform(obj, "translate to French")
+ parts = t.parts()
+ assert len(parts) == 1
+ assert parts[0] is obj
+
+
+def test_transform_format_for_llm_returns_template_repr():
+ obj = _SimpleComponent("text")
+ t = Transform(obj, "rewrite formally")
+ result = t.format_for_llm()
+ assert isinstance(result, TemplateRepresentation)
+
+
+def test_transform_format_for_llm_transformation_field():
+ obj = _SimpleComponent("text")
+ t = Transform(obj, "make it shorter")
+ result = t.format_for_llm()
+ assert result.args["transformation"] == "make it shorter"
+
+
+def test_transform_format_for_llm_content_is_wrapped_object():
+ obj = _SimpleComponent("text")
+ t = Transform(obj, "x")
+ result = t.format_for_llm()
+ assert result.args["content"] is obj
+
+
+def test_transform_parse_returns_value():
+ obj = _SimpleComponent()
+ t = Transform(obj, "x")
+ mot = ModelOutputThunk(value="result")
+ assert t._parse(mot) == "result"
+
+
+# --- MObject ---
+
+
+def test_mobject_parts_empty():
+ obj = _SimpleComponent()
+ assert obj.parts() == []
+
+
+def test_mobject_get_query_object():
+ obj = _SimpleComponent("text")
+ q = obj.get_query_object("what is this?")
+ assert isinstance(q, Query)
+ assert q._query == "what is this?"
+ assert q._obj is obj
+
+
+def test_mobject_get_transform_object():
+ obj = _SimpleComponent("text")
+ t = obj.get_transform_object("shorten it")
+ assert isinstance(t, Transform)
+ assert t._transformation == "shorten it"
+ assert t._obj is obj
+
+
+def test_mobject_content_as_string():
+ obj = _SimpleComponent("my content")
+ assert obj.content_as_string() == "my content"
+
+
+def test_mobject_format_for_llm_returns_template_repr():
+ obj = _SimpleComponent("text")
+ result = obj.format_for_llm()
+ # Uses the overridden format_for_llm returning str
+ assert result == "text"
+
+
+def test_mobject_custom_query_type():
+ class _CustomQuery(Query):
+ pass
+
+ obj = MObject(query_type=_CustomQuery)
+ q = obj.get_query_object("q")
+ assert isinstance(q, _CustomQuery)
+
+
+def test_mobject_custom_transform_type():
+ class _CustomTransform(Transform):
+ pass
+
+ obj = MObject(transform_type=_CustomTransform)
+ t = obj.get_transform_object("t")
+ assert isinstance(t, _CustomTransform)
+
+
+def test_mobj_base_format_for_llm():
+ """Test MObject.format_for_llm (not the overridden version) via base class directly."""
+
+ class _MObjectWithTools(MObject):
+ def my_tool(self) -> str:
+ """A custom tool."""
+ return "result"
+
+ def content_as_string(self) -> str:
+ return "content"
+
+ def parts(self):
+ return []
+
+ def format_for_llm(self):
+ return MObject.format_for_llm(self)
+
+ def _parse(self, computed):
+ return ""
+
+ obj = _MObjectWithTools()
+ result = obj.format_for_llm()
+ assert isinstance(result, TemplateRepresentation)
+ assert result.args["content"] == "content"
+
+
+def test_mobj_parse_returns_value():
+ class _M(MObject):
+ def content_as_string(self):
+ return ""
+
+ def parts(self):
+ return []
+
+ def _parse(self, computed):
+ return MObject._parse(self, computed)
+
+ obj = _M()
+ mot = ModelOutputThunk(value="result")
+ assert obj._parse(mot) == "result"
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/test/stdlib/components/test_simple.py b/test/stdlib/components/test_simple.py
new file mode 100644
index 000000000..c62fa9a4a
--- /dev/null
+++ b/test/stdlib/components/test_simple.py
@@ -0,0 +1,135 @@
+"""Unit tests for SimpleComponent — kwargs rendering, type validation, JSON output."""
+
+import json
+
+import pytest
+
+from mellea.core import CBlock, ModelOutputThunk
+from mellea.stdlib.components.simple import SimpleComponent
+
+# --- constructor & type checking ---
+
+
+def test_init_converts_strings_to_cblocks():
+ sc = SimpleComponent(task="write a poem")
+ assert isinstance(sc._kwargs["task"], CBlock)
+ assert sc._kwargs["task"].value == "write a poem"
+
+
+def test_init_accepts_cblock_directly():
+ cb = CBlock("already a block")
+ sc = SimpleComponent(thing=cb)
+ assert sc._kwargs["thing"] is cb
+
+
+def test_init_rejects_non_string_non_component():
+ with pytest.raises(AssertionError):
+ SimpleComponent(bad=42)
+
+
+def test_init_rejects_non_string_key():
+ # We can't pass non-string keys via kwargs syntax; test _kwargs_type_check directly
+ sc = SimpleComponent(ok="fine")
+ with pytest.raises(AssertionError):
+ sc._kwargs_type_check({123: CBlock("v")})
+
+
+def test_init_multiple_kwargs():
+ sc = SimpleComponent(task="summarise", context="some text")
+ assert len(sc._kwargs) == 2
+ assert set(sc._kwargs.keys()) == {"task", "context"}
+
+
+# --- parts() ---
+
+
+def test_parts_returns_all_values():
+ sc = SimpleComponent(a="one", b="two")
+ parts = sc.parts()
+ assert len(parts) == 2
+ assert all(isinstance(p, CBlock) for p in parts)
+
+
+def test_parts_empty():
+ sc = SimpleComponent()
+ assert sc.parts() == []
+
+
+# --- make_simple_string ---
+
+
+def test_make_simple_string_single():
+ kwargs = {"task": CBlock("do something")}
+ result = SimpleComponent.make_simple_string(kwargs)
+ assert result == "<|task|>do something|task|>"
+
+
+def test_make_simple_string_multiple():
+ # Plain dict literals preserve insertion order (guaranteed since Python 3.7)
+ kwargs = {"a": CBlock("first"), "b": CBlock("second")}
+ result = SimpleComponent.make_simple_string(kwargs)
+ assert "<|a|>first|a|>" in result
+ assert "<|b|>second|b|>" in result
+ assert "\n" in result
+
+
+def test_make_simple_string_empty():
+ assert SimpleComponent.make_simple_string({}) == ""
+
+
+# --- make_json_string ---
+
+
+def test_make_json_string_cblock():
+ kwargs = {"key": CBlock("value")}
+ result = json.loads(SimpleComponent.make_json_string(kwargs))
+ assert result == {"key": "value"}
+
+
+def test_make_json_string_model_output_thunk():
+ mot = ModelOutputThunk(value="output text")
+ kwargs = {"out": mot}
+ result = json.loads(SimpleComponent.make_json_string(kwargs))
+ assert result == {"out": "output text"}
+
+
+def test_make_json_string_nested_component():
+ inner = SimpleComponent(x="nested")
+ kwargs = {"inner": inner}
+ result = json.loads(SimpleComponent.make_json_string(kwargs))
+ assert "inner" in result
+
+
+def test_make_json_string_empty():
+ result = json.loads(SimpleComponent.make_json_string({}))
+ assert result == {}
+
+
+# --- format_for_llm ---
+
+
+def test_format_for_llm_returns_json_string():
+ sc = SimpleComponent(topic="ocean", style="poetic")
+ formatted = sc.format_for_llm()
+ parsed = json.loads(formatted)
+ assert parsed["topic"] == "ocean"
+ assert parsed["style"] == "poetic"
+
+
+# --- _parse ---
+
+
+def test_parse_returns_value():
+ sc = SimpleComponent(x="whatever")
+ mot = ModelOutputThunk(value="result")
+ assert sc._parse(mot) == "result"
+
+
+def test_parse_none_returns_empty_string():
+ sc = SimpleComponent(x="whatever")
+ mot = ModelOutputThunk(value=None)
+ assert sc._parse(mot) == ""
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/test/stdlib/sampling/test_majority_voting_unit.py b/test/stdlib/sampling/test_majority_voting_unit.py
new file mode 100644
index 000000000..cf3429fe7
--- /dev/null
+++ b/test/stdlib/sampling/test_majority_voting_unit.py
@@ -0,0 +1,74 @@
+"""Unit tests for majority voting compare_strings methods — no backend required."""
+
+import pytest
+
+from mellea.stdlib.sampling.majority_voting import (
+ MajorityVotingStrategyForMath,
+ MBRDRougeLStrategy,
+)
+
+# --- MajorityVotingStrategyForMath.compare_strings ---
+
+
+@pytest.fixture
+def math_strategy():
+ return MajorityVotingStrategyForMath()
+
+
+def test_math_compare_identical_boxed(math_strategy):
+ assert math_strategy.compare_strings(r"\boxed{2}", r"\boxed{2}") == 1.0
+
+
+def test_math_compare_identical_latex(math_strategy):
+ assert math_strategy.compare_strings(r"\boxed{4}", r"\boxed{4}") == 1.0
+
+
+def test_math_compare_unboxed_integers_return_zero(math_strategy):
+ # Plain integers without boxed notation are not extracted — returns 0.0
+ assert math_strategy.compare_strings("2", "3") == 0.0
+
+
+def test_math_compare_different_boxed(math_strategy):
+ assert math_strategy.compare_strings(r"\boxed{2}", r"\boxed{3}") == 0.0
+
+
+def test_math_compare_returns_float(math_strategy):
+ result = math_strategy.compare_strings(r"\boxed{5}", r"\boxed{5}")
+ assert isinstance(result, float)
+
+
+# --- MBRDRougeLStrategy.compare_strings ---
+
+
+@pytest.fixture
+def rouge_strategy():
+ return MBRDRougeLStrategy()
+
+
+def test_rougel_compare_identical(rouge_strategy):
+ score = rouge_strategy.compare_strings("hello world", "hello world")
+ assert score == pytest.approx(1.0)
+
+
+def test_rougel_compare_completely_different(rouge_strategy):
+ score = rouge_strategy.compare_strings("hello world", "foo bar baz")
+ assert score < 0.5
+
+
+def test_rougel_compare_partial_overlap(rouge_strategy):
+ score = rouge_strategy.compare_strings("the quick brown fox", "the quick fox")
+ assert 0.0 < score < 1.0
+
+
+def test_rougel_compare_returns_float(rouge_strategy):
+ score = rouge_strategy.compare_strings("abc", "abc")
+ assert isinstance(score, float)
+
+
+def test_rougel_score_in_range(rouge_strategy):
+ score = rouge_strategy.compare_strings("some text here", "some different text")
+ assert 0.0 <= score <= 1.0
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/test/stdlib/sampling/test_sampling_base_unit.py b/test/stdlib/sampling/test_sampling_base_unit.py
new file mode 100644
index 000000000..7c4161eec
--- /dev/null
+++ b/test/stdlib/sampling/test_sampling_base_unit.py
@@ -0,0 +1,108 @@
+"""Unit tests for sampling/base.py static repair() logic — no backend required."""
+
+import pytest
+
+from mellea.core import (
+ ComputedModelOutputThunk,
+ ModelOutputThunk,
+ Requirement,
+ ValidationResult,
+)
+from mellea.stdlib.components import Instruction, Message
+from mellea.stdlib.context import ChatContext
+from mellea.stdlib.sampling.base import RepairTemplateStrategy
+
+# --- BaseSamplingStrategy.repair ---
+
+
+def _val(passed: bool, reason: str | None = None) -> ValidationResult:
+ return ValidationResult(result=passed, reason=reason)
+
+
+def test_repair_instruction_builds_repair_string():
+ ins = Instruction(description="Write a poem", requirements=["be concise"])
+ req = Requirement(description="be concise")
+ old_ctx = ChatContext()
+ new_ctx = ChatContext()
+
+ action, ctx = RepairTemplateStrategy.repair(
+ old_ctx=old_ctx,
+ new_ctx=new_ctx,
+ past_actions=[ins],
+ past_results=[
+ ComputedModelOutputThunk(thunk=ModelOutputThunk(value="long text"))
+ ],
+ past_val=[[(req, _val(False, reason="Output was too long"))]],
+ )
+ assert isinstance(action, Instruction)
+ assert action._repair_string is not None
+ assert "Output was too long" in action._repair_string
+ assert ctx is old_ctx
+
+
+def test_repair_uses_req_description_when_no_reason():
+ ins = Instruction(description="task")
+ req = Requirement(description="must be brief")
+ old_ctx = ChatContext()
+
+ action, _ = RepairTemplateStrategy.repair(
+ old_ctx=old_ctx,
+ new_ctx=ChatContext(),
+ past_actions=[ins],
+ past_results=[ComputedModelOutputThunk(thunk=ModelOutputThunk(value="x"))],
+ past_val=[[(req, _val(False))]],
+ )
+ assert "must be brief" in action._repair_string
+
+
+def test_repair_non_instruction_returns_same_action():
+ msg = Message("user", "hello")
+ old_ctx = ChatContext()
+
+ action, ctx = RepairTemplateStrategy.repair(
+ old_ctx=old_ctx,
+ new_ctx=ChatContext(),
+ past_actions=[msg],
+ past_results=[ComputedModelOutputThunk(thunk=ModelOutputThunk(value="x"))],
+ past_val=[[]],
+ )
+ assert action is msg
+ assert ctx is old_ctx
+
+
+def test_repair_multiple_failures_all_listed():
+ ins = Instruction(description="task")
+ r1 = Requirement(description="be short")
+ r2 = Requirement(description="be polite")
+ old_ctx = ChatContext()
+
+ action, _ = RepairTemplateStrategy.repair(
+ old_ctx=old_ctx,
+ new_ctx=ChatContext(),
+ past_actions=[ins],
+ past_results=[ComputedModelOutputThunk(thunk=ModelOutputThunk(value="x"))],
+ past_val=[[(r1, _val(False, "too long")), (r2, _val(False, "rude tone"))]],
+ )
+ assert "too long" in action._repair_string
+ assert "rude tone" in action._repair_string
+
+
+def test_repair_passed_requirements_excluded():
+ ins = Instruction(description="task")
+ r_pass = Requirement(description="format ok")
+ r_fail = Requirement(description="content wrong")
+ old_ctx = ChatContext()
+
+ action, _ = RepairTemplateStrategy.repair(
+ old_ctx=old_ctx,
+ new_ctx=ChatContext(),
+ past_actions=[ins],
+ past_results=[ComputedModelOutputThunk(thunk=ModelOutputThunk(value="x"))],
+ past_val=[[(r_pass, _val(True)), (r_fail, _val(False, "incorrect"))]],
+ )
+ assert "format ok" not in action._repair_string
+ assert "incorrect" in action._repair_string
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/test/stdlib/sampling/test_sofai_unit.py b/test/stdlib/sampling/test_sofai_unit.py
new file mode 100644
index 000000000..2b3f6f984
--- /dev/null
+++ b/test/stdlib/sampling/test_sofai_unit.py
@@ -0,0 +1,147 @@
+"""Unit tests for SOFAI sampling strategy pure static helpers — no backend required.
+
+Covers _extract_action_prompt, _parse_judgment, _extract_feedback, _select_best_attempt.
+"""
+
+import pytest
+
+from mellea.core import Requirement, TemplateRepresentation, ValidationResult
+from mellea.stdlib.components import Instruction, Message
+from mellea.stdlib.sampling.sofai import SOFAISamplingStrategy
+
+# --- _parse_judgment ---
+
+
+def test_parse_judgment_yes():
+    """A bare affirmative parses as True."""
+    assert SOFAISamplingStrategy._parse_judgment("Yes") is True
+
+
+def test_parse_judgment_yes_with_explanation():
+    """Trailing explanation text after "Yes" does not change the verdict."""
+    assert SOFAISamplingStrategy._parse_judgment("Yes, the output is correct.") is True
+
+
+def test_parse_judgment_no():
+    """A bare negative parses as False."""
+    assert SOFAISamplingStrategy._parse_judgment("No") is False
+
+
+def test_parse_judgment_no_with_explanation():
+    """Multi-line responses are judged by their leading verdict."""
+    assert (
+        SOFAISamplingStrategy._parse_judgment(
+            "No, it needs improvement.\nDetails here."
+        )
+        is False
+    )
+
+
+def test_parse_judgment_yes_in_first_line():
+    """A "yes" anywhere in the first line is treated as affirmative."""
+    assert SOFAISamplingStrategy._parse_judgment("The answer is yes") is True
+
+
+def test_parse_judgment_no_match_defaults_false():
+    """Ambiguous responses with neither yes nor no default to False (conservative)."""
+    assert SOFAISamplingStrategy._parse_judgment("Maybe, hard to tell") is False
+
+
+def test_parse_judgment_whitespace_stripped():
+    """Surrounding whitespace is stripped before matching."""
+    assert SOFAISamplingStrategy._parse_judgment("  Yes  ") is True
+
+
+def test_parse_judgment_case_insensitive():
+    """Matching ignores letter case."""
+    assert SOFAISamplingStrategy._parse_judgment("YES") is True
+
+
+# --- _extract_feedback ---
+
+
+def test_extract_feedback_with_tags():
+    """Feedback delimited by tags is extracted without the surrounding prose."""
+    # NOTE(review): the test name says "with_tags" but `text` contains no tag
+    # markup — angle-bracket tags (e.g. a <feedback>...</feedback> wrapper) may
+    # have been stripped from this file during extraction. Confirm against the
+    # tag format _extract_feedback actually looks for; as written this input
+    # would not exercise the tagged path.
+    text = "Some preamble. Fix the grammar. More text."
+    assert SOFAISamplingStrategy._extract_feedback(text) == "Fix the grammar."
+
+
+def test_extract_feedback_no_tags():
+    """Without tags, the whole text is returned as the feedback."""
+    text = "Just plain feedback text."
+    assert SOFAISamplingStrategy._extract_feedback(text) == "Just plain feedback text."
+
+
+def test_extract_feedback_multiline():
+    """Multi-line feedback is preserved line by line."""
+    # NOTE(review): tag markup may also have been stripped here — verify.
+    text = "\nLine 1\nLine 2\n"
+    result = SOFAISamplingStrategy._extract_feedback(text)
+    assert "Line 1" in result
+    assert "Line 2" in result
+
+
+def test_extract_feedback_case_insensitive_tags():
+    """Tag matching is case-insensitive."""
+    # NOTE(review): as with test_extract_feedback_with_tags, no tags survive in
+    # this input — the mixed-case tags this test is named for appear missing.
+    text = "Fix it."
+    assert SOFAISamplingStrategy._extract_feedback(text) == "Fix it."
+
+
+def test_extract_feedback_strips_whitespace():
+    """Leading/trailing whitespace is stripped from the extracted feedback."""
+    text = "  some feedback  "
+    assert SOFAISamplingStrategy._extract_feedback(text) == "some feedback"
+
+
+# --- _extract_action_prompt ---
+
+
+def test_extract_action_prompt_message():
+    """A Message's prompt is its content string."""
+    msg = Message("user", "What is 2+2?")
+    assert SOFAISamplingStrategy._extract_action_prompt(msg) == "What is 2+2?"
+
+
+def test_extract_action_prompt_instruction():
+    """An Instruction's prompt is its description."""
+    ins = Instruction(description="Summarise the text")
+    result = SOFAISamplingStrategy._extract_action_prompt(ins)
+    assert result == "Summarise the text"
+
+
+def test_extract_action_prompt_format_for_llm_str():
+ """Component whose format_for_llm returns a plain string."""
+ from mellea.core import CBlock, Component, ModelOutputThunk
+
+ class _StrComponent(Component[str]):
+ def parts(self):
+ return []
+
+ def format_for_llm(self) -> str:
+ return "plain text repr"
+
+ def _parse(self, computed: ModelOutputThunk) -> str:
+ return ""
+
+ result = SOFAISamplingStrategy._extract_action_prompt(_StrComponent())
+ assert result == "plain text repr"
+
+
+# --- _select_best_attempt ---
+
+
+def _vr(passed: bool) -> ValidationResult:
+    """Shorthand for a ValidationResult with only the pass/fail flag set."""
+    return ValidationResult(result=passed)
+
+
+def test_select_best_attempt_picks_most_passing():
+    """The attempt with the highest count of passing validations wins."""
+    r = Requirement(description="r")
+    val = [
+        [(r, _vr(True)), (r, _vr(False))],  # 1 pass
+        [(r, _vr(True)), (r, _vr(True))],  # 2 pass — best
+        [(r, _vr(False)), (r, _vr(False))],  # 0 pass
+    ]
+    assert SOFAISamplingStrategy._select_best_attempt(val) == 1
+
+
+def test_select_best_attempt_tie_prefers_later():
+    """On a tie, the later (presumably more refined) attempt is preferred."""
+    r = Requirement(description="r")
+    val = [
+        [(r, _vr(True))],  # 1 pass
+        [(r, _vr(True))],  # 1 pass — tie, but later → preferred
+    ]
+    assert SOFAISamplingStrategy._select_best_attempt(val) == 1
+
+
+def test_select_best_attempt_single():
+    """A single attempt is selected even if all its validations failed."""
+    r = Requirement(description="r")
+    val = [[(r, _vr(False))]]
+    assert SOFAISamplingStrategy._select_best_attempt(val) == 0
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/test/stdlib/test_functional_unit.py b/test/stdlib/test_functional_unit.py
new file mode 100644
index 000000000..7d4b215c9
--- /dev/null
+++ b/test/stdlib/test_functional_unit.py
@@ -0,0 +1,66 @@
+"""Unit tests for functional.py pure helpers — no backend, no LLM required.
+
+Covers _parse_and_clean_image_args image preprocessing.
+"""
+
+import base64
+import io
+
+import pytest
+from PIL import Image as PILImage
+
+from mellea.core import ImageBlock
+from mellea.stdlib.functional import _parse_and_clean_image_args
+
+
+def _make_image_block() -> ImageBlock:
+    """Return a valid ImageBlock backed by a 1x1 red PNG.
+
+    The PNG is rendered in memory via Pillow and base64-encoded, matching the
+    encoding ImageBlock expects for its value.
+    """
+    img = PILImage.new("RGB", (1, 1), color="red")
+    buf = io.BytesIO()
+    img.save(buf, format="PNG")
+    b64 = base64.b64encode(buf.getvalue()).decode()
+    return ImageBlock(value=b64)
+
+
+# --- _parse_and_clean_image_args ---
+
+
+def test_none_returns_none():
+    """No images argument maps to None."""
+    assert _parse_and_clean_image_args(None) is None
+
+
+def test_empty_list_returns_none():
+    """An empty image list is normalized to None, not an empty list."""
+    assert _parse_and_clean_image_args([]) is None
+
+
+def test_image_blocks_passed_through():
+    """Existing ImageBlocks are passed through unchanged."""
+    ib = _make_image_block()
+    result = _parse_and_clean_image_args([ib])
+    assert result == [ib]
+
+
+def test_multiple_image_blocks_preserved():
+    """Order and identity of multiple ImageBlocks are preserved."""
+    ib1 = _make_image_block()
+    ib2 = _make_image_block()
+    result = _parse_and_clean_image_args([ib1, ib2])
+    assert result is not None
+    assert len(result) == 2
+    assert result[0] is ib1
+    assert result[1] is ib2
+
+
+def test_pil_images_converted_to_image_blocks():
+    """Raw PIL images are converted into ImageBlocks."""
+    pil_img = PILImage.new("RGB", (1, 1), color="blue")
+    result = _parse_and_clean_image_args([pil_img])
+    assert result is not None
+    assert len(result) == 1
+    assert isinstance(result[0], ImageBlock)
+
+
+def test_non_list_raises():
+    """A non-list argument fails the function's input assertion."""
+    with pytest.raises(AssertionError, match="Images should be a list"):
+        _parse_and_clean_image_args("not_a_list")  # type: ignore
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/test/stdlib/test_session_unit.py b/test/stdlib/test_session_unit.py
new file mode 100644
index 000000000..466a85e9a
--- /dev/null
+++ b/test/stdlib/test_session_unit.py
@@ -0,0 +1,66 @@
+"""Unit tests for session.py pure-logic — no Ollama server required.
+
+Covers backend_name_to_class factory resolution and get_session error path.
+"""
+
+import pytest
+
+from mellea.backends.ollama import OllamaModelBackend
+from mellea.backends.openai import OpenAIBackend
+from mellea.stdlib.session import backend_name_to_class, get_session
+
+# --- backend_name_to_class ---
+
+
+def test_ollama_resolves_to_ollama_backend():
+ cls = backend_name_to_class("ollama")
+ assert cls is OllamaModelBackend
+
+
+def test_openai_resolves_to_openai_backend():
+ cls = backend_name_to_class("openai")
+ assert cls is OpenAIBackend
+
+
+def test_unknown_name_returns_none():
+ cls = backend_name_to_class("does_not_exist")
+ assert cls is None
+
+
+def test_hf_resolves_or_raises_import_error():
+    """"hf" resolves when the extra is installed, else fails with install hint."""
+    # Either resolves (if mellea[hf] is installed) or raises ImportError with helpful message
+    try:
+        cls = backend_name_to_class("hf")
+        assert cls is not None
+    except ImportError as e:
+        assert "mellea[hf]" in str(e)
+
+
+def test_huggingface_alias_same_as_hf():
+    """The "huggingface" alias resolves to the same class as "hf"."""
+    # "hf" and "huggingface" should resolve to the same class
+    try:
+        cls_hf = backend_name_to_class("hf")
+        cls_hf_full = backend_name_to_class("huggingface")
+        assert cls_hf is cls_hf_full
+    except ImportError:
+        pass  # OK if mellea[hf] is not installed
+
+
+def test_litellm_resolves_or_raises_import_error():
+    """"litellm" resolves when the extra is installed, else fails with install hint."""
+    try:
+        cls = backend_name_to_class("litellm")
+        assert cls is not None
+    except ImportError as e:
+        assert "mellea[litellm]" in str(e)
+
+
+# --- get_session ---
+
+
+def test_get_session_raises_when_no_active_session():
+    """Calling get_session outside any session context raises RuntimeError."""
+    with pytest.raises(RuntimeError, match="No active session found"):
+        get_session()
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/test/telemetry/test_backend_instrumentation.py b/test/telemetry/test_backend_instrumentation.py
new file mode 100644
index 000000000..4163ccb8d
--- /dev/null
+++ b/test/telemetry/test_backend_instrumentation.py
@@ -0,0 +1,209 @@
+"""Unit tests for backend_instrumentation helpers — model ID extraction, system name mapping,
+context size introspection, and span attribute recording."""
+
+from dataclasses import dataclass
+from unittest.mock import MagicMock
+
+import pytest
+
+from mellea.telemetry.backend_instrumentation import (
+ get_context_size,
+ get_model_id_str,
+ get_system_name,
+ record_response_metadata,
+ record_token_usage,
+)
+
+# --- get_model_id_str ---
+
+
+@dataclass
+class _BackendWithStrModelId:
+    """Stub backend whose model_id is a plain string."""
+
+    model_id: str
+
+
+@dataclass
+class _HFModelId:
+    """Stub model-id object exposing only hf_model_name."""
+
+    hf_model_name: str
+
+
+@dataclass
+class _BackendWithHFModelId:
+    """Stub backend whose model_id is an object with hf_model_name."""
+
+    model_id: _HFModelId
+
+
+def test_get_model_id_str_plain_string():
+    """A string model_id is returned as-is."""
+    backend = _BackendWithStrModelId(model_id="granite-3-8b")
+    assert get_model_id_str(backend) == "granite-3-8b"
+
+
+def test_get_model_id_str_hf_model_name():
+    """Object model_ids are unwrapped via their hf_model_name attribute."""
+    backend = _BackendWithHFModelId(
+        model_id=_HFModelId(hf_model_name="ibm-granite/granite-4.0-micro")
+    )
+    assert get_model_id_str(backend) == "ibm-granite/granite-4.0-micro"
+
+
+def test_get_model_id_str_no_model_id_returns_class_name():
+    """Backends without model_id fall back to their class name."""
+    class UnknownBackend:
+        pass
+
+    backend = UnknownBackend()
+    assert get_model_id_str(backend) == "UnknownBackend"
+
+
+# --- get_system_name ---
+
+
+def _fake_backend(class_name: str) -> object:
+ return type(class_name, (), {})()
+
+
+def test_get_system_name_openai():
+    """OpenAI backend class names map to "openai"."""
+    assert get_system_name(_fake_backend("OpenAIBackend")) == "openai"
+
+
+def test_get_system_name_ollama():
+    """Ollama backend class names map to "ollama"."""
+    assert get_system_name(_fake_backend("OllamaModelBackend")) == "ollama"
+
+
+def test_get_system_name_huggingface():
+    """LocalHFBackend maps to "huggingface"."""
+    assert get_system_name(_fake_backend("LocalHFBackend")) == "huggingface"
+
+
+def test_get_system_name_hf_shortname():
+    """The short "HF" class-name form also maps to "huggingface"."""
+    assert get_system_name(_fake_backend("HFBackend")) == "huggingface"
+
+
+def test_get_system_name_watsonx():
+    """Watsonx backend class names map to "watsonx"."""
+    assert get_system_name(_fake_backend("WatsonxBackend")) == "watsonx"
+
+
+def test_get_system_name_litellm():
+    """LiteLLM backend class names map to "litellm"."""
+    assert get_system_name(_fake_backend("LiteLLMBackend")) == "litellm"
+
+
+def test_get_system_name_unknown_returns_class_name():
+    """Unmapped backends fall back to the raw class name."""
+    backend = _fake_backend("SomeCustomBackend")
+    assert get_system_name(backend) == "SomeCustomBackend"
+
+
+# --- get_context_size ---
+
+
+def test_get_context_size_with_len():
+    """len()-able contexts report their length."""
+    ctx = [1, 2, 3]
+    assert get_context_size(ctx) == 3
+
+
+def test_get_context_size_empty_list():
+    """An empty context reports size 0."""
+    assert get_context_size([]) == 0
+
+
+def test_get_context_size_with_turns():
+    """Contexts exposing a `turns` attribute report len(turns)."""
+    ctx = type("Ctx", (), {"turns": [1, 2, 3, 4]})()
+    assert get_context_size(ctx) == 4
+
+
+def test_get_context_size_no_len_no_turns():
+    """Opaque contexts with neither __len__ nor turns fall back to 0."""
+    class Opaque:
+        pass
+
+    assert get_context_size(Opaque()) == 0
+
+
+def test_get_context_size_len_raises_returns_zero():
+    """A raising __len__ is swallowed; size defaults to 0."""
+    class Broken:
+        def __len__(self):
+            raise RuntimeError("broken")
+
+    assert get_context_size(Broken()) == 0
+
+
+# --- record_token_usage ---
+
+
+def _mock_span():
+ return MagicMock()
+
+
+def test_record_token_usage_from_dict():
+    """Dict-shaped usage maps prompt/completion/total onto gen_ai.* attributes."""
+    span = _mock_span()
+    usage = {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}
+    record_token_usage(span, usage)
+    # Collapse the recorded set_attribute calls into {key: value} for assertions.
+    calls = {call.args[0]: call.args[1] for call in span.set_attribute.call_args_list}
+    assert calls.get("gen_ai.usage.input_tokens") == 10
+    assert calls.get("gen_ai.usage.output_tokens") == 20
+    assert calls.get("gen_ai.usage.total_tokens") == 30
+
+
+def test_record_token_usage_from_object():
+    """Attribute-style usage objects are read the same way as dicts."""
+    span = _mock_span()
+    usage = type(
+        "Usage", (), {"prompt_tokens": 5, "completion_tokens": 15, "total_tokens": 20}
+    )()
+    record_token_usage(span, usage)
+    calls = {call.args[0]: call.args[1] for call in span.set_attribute.call_args_list}
+    assert calls.get("gen_ai.usage.input_tokens") == 5
+
+
+def test_record_token_usage_none_span_no_op():
+    """A None span is a silent no-op."""
+    # Should not raise
+    record_token_usage(None, {"prompt_tokens": 1})
+
+
+def test_record_token_usage_none_usage_no_op():
+    """None usage records nothing on the span."""
+    span = _mock_span()
+    record_token_usage(span, None)
+    span.set_attribute.assert_not_called()
+
+
+def test_record_token_usage_partial_fields():
+    """Only the usage fields actually present are recorded."""
+    span = _mock_span()
+    usage = {"prompt_tokens": 7}
+    record_token_usage(span, usage)
+    calls = {call.args[0]: call.args[1] for call in span.set_attribute.call_args_list}
+    assert calls.get("gen_ai.usage.input_tokens") == 7
+    assert "gen_ai.usage.output_tokens" not in calls
+
+
+# --- record_response_metadata ---
+
+
+def test_record_response_metadata_model_from_dict():
+    """Model name and response id are lifted from a dict response."""
+    span = _mock_span()
+    response = {"model": "granite-3-8b", "choices": [], "id": "resp-123"}
+    record_response_metadata(span, response)
+    calls = {call.args[0]: call.args[1] for call in span.set_attribute.call_args_list}
+    assert calls.get("gen_ai.response.model") == "granite-3-8b"
+    assert calls.get("gen_ai.response.id") == "resp-123"
+
+
+def test_record_response_metadata_explicit_model_id_overrides():
+    """An explicit model_id argument takes precedence over the response's model."""
+    span = _mock_span()
+    response = {"model": "old-model"}
+    record_response_metadata(span, response, model_id="new-model")
+    calls = {call.args[0]: call.args[1] for call in span.set_attribute.call_args_list}
+    assert calls.get("gen_ai.response.model") == "new-model"
+
+
+def test_record_response_metadata_finish_reason():
+    """Choice finish_reasons are recorded as a list attribute."""
+    span = _mock_span()
+    response = {"choices": [{"finish_reason": "stop"}]}
+    record_response_metadata(span, response)
+    calls = {call.args[0]: call.args[1] for call in span.set_attribute.call_args_list}
+    assert calls.get("gen_ai.response.finish_reasons") == ["stop"]
+
+
+def test_record_response_metadata_none_span_no_op():
+    """A None span is a silent no-op."""
+    record_response_metadata(None, {"model": "x"})
+
+
+def test_record_response_metadata_none_response_no_op():
+    """A None response records nothing on the span."""
+    span = _mock_span()
+    record_response_metadata(span, None)
+    span.set_attribute.assert_not_called()
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/test/telemetry/test_tracing_helpers.py b/test/telemetry/test_tracing_helpers.py
new file mode 100644
index 000000000..a8cb4e06e
--- /dev/null
+++ b/test/telemetry/test_tracing_helpers.py
@@ -0,0 +1,89 @@
+"""Unit tests for tracing helper functions — no OpenTelemetry installation required.
+
+_set_attribute_safe and end_backend_span operate on any object with a
+set_attribute / end method, so these tests use MagicMock spans and run
+unconditionally. test_set_span_error_records_exception calls into the real
+OTel trace API and is skipped when opentelemetry is not installed.
+"""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from mellea.telemetry.tracing import (
+ _set_attribute_safe,
+ end_backend_span,
+ set_span_error,
+)
+
+# --- _set_attribute_safe type-conversion ---
+
+
+def test_set_attribute_safe_none_value_no_op():
+    """None values are skipped entirely rather than recorded."""
+    span = MagicMock()
+    _set_attribute_safe(span, "key", None)
+    span.set_attribute.assert_not_called()
+
+
+def test_set_attribute_safe_bool():
+    """Booleans pass through unchanged."""
+    span = MagicMock()
+    _set_attribute_safe(span, "flag", True)
+    span.set_attribute.assert_called_once_with("flag", True)
+
+
+def test_set_attribute_safe_int():
+    """Integers pass through unchanged."""
+    span = MagicMock()
+    _set_attribute_safe(span, "count", 42)
+    span.set_attribute.assert_called_once_with("count", 42)
+
+
+def test_set_attribute_safe_str():
+    """Strings pass through unchanged."""
+    span = MagicMock()
+    _set_attribute_safe(span, "name", "hello")
+    span.set_attribute.assert_called_once_with("name", "hello")
+
+
+def test_set_attribute_safe_list_converted_to_string_list():
+    """Non-string list elements are stringified for OTel compatibility."""
+    span = MagicMock()
+    _set_attribute_safe(span, "items", [1, 2, 3])
+    span.set_attribute.assert_called_once_with("items", ["1", "2", "3"])
+
+
+def test_set_attribute_safe_unsupported_type_stringified():
+    """Unsupported value types (e.g. dicts) are recorded as their string form."""
+    span = MagicMock()
+    _set_attribute_safe(span, "obj", {"nested": "dict"})
+    span.set_attribute.assert_called_once()
+    call_args = span.set_attribute.call_args
+    assert call_args.args[0] == "obj"
+    assert isinstance(call_args.args[1], str)
+
+
+# --- set_span_error — requires opentelemetry for trace.Status ---
+
+
+def test_set_span_error_records_exception():
+    """set_span_error records the exception and marks the span status once."""
+    # Skip when opentelemetry is absent: set_span_error builds a real trace.Status.
+    pytest.importorskip(
+        "opentelemetry",
+        reason="opentelemetry not installed — install mellea[telemetry]",
+    )
+    span = MagicMock()
+    exc = ValueError("something went wrong")
+
+    # Force the availability flag so the OTel code path runs under the mock span.
+    with patch("mellea.telemetry.tracing._OTEL_AVAILABLE", True):
+        set_span_error(span, exc)
+
+    span.record_exception.assert_called_once_with(exc)
+    span.set_status.assert_called_once()
+
+
+# --- end_backend_span ---
+
+
+def test_end_backend_span_calls_end_on_span():
+    """end_backend_span delegates to the span's end() exactly once."""
+    span = MagicMock()
+    end_backend_span(span)
+    span.end.assert_called_once()
+
+
+def test_end_backend_span_none_no_op():
+    """A None span is a silent no-op."""
+    end_backend_span(None)