From b2df3956a4e8026db28d15052f4ee19ab906fb4b Mon Sep 17 00:00:00 2001 From: AngeloDanducci Date: Wed, 8 Apr 2026 09:47:12 -0400 Subject: [PATCH 01/11] improve fancylogger implementation --- mellea/core/__init__.py | 5 +- mellea/core/utils.py | 361 ++++++++++++++++++++++---- test/core/test_utils_logging.py | 444 ++++++++++++++++++++++++++++++++ 3 files changed, 760 insertions(+), 50 deletions(-) create mode 100644 test/core/test_utils_logging.py diff --git a/mellea/core/__init__.py b/mellea/core/__init__.py index 910fced57..34facdd7e 100644 --- a/mellea/core/__init__.py +++ b/mellea/core/__init__.py @@ -30,7 +30,7 @@ from .formatter import Formatter from .requirement import Requirement, ValidationResult, default_output_to_bool from .sampling import SamplingResult, SamplingStrategy -from .utils import FancyLogger +from .utils import FancyLogger, clear_log_context, log_context, set_log_context __all__ = [ "Backend", @@ -56,6 +56,9 @@ "TemplateRepresentation", "ValidationResult", "blockify", + "clear_log_context", "default_output_to_bool", "generate_walk", + "log_context", + "set_log_context", ] diff --git a/mellea/core/utils.py b/mellea/core/utils.py index 3823a83a6..e517dd632 100644 --- a/mellea/core/utils.py +++ b/mellea/core/utils.py @@ -4,15 +4,163 @@ an optional REST handler (``RESTHandler``) that forwards log records to a local ``/api/receive`` endpoint when the ``FLOG`` environment variable is set. All internal mellea modules obtain their logger via ``FancyLogger.get_logger()``. + +Environment variables +--------------------- +``MELLEA_LOG_LEVEL`` + Minimum log level name (e.g. ``DEBUG``, ``INFO``, ``WARNING``). Defaults to + ``INFO``. The legacy ``DEBUG`` variable is still honoured as a fallback. +``MELLEA_LOG_JSON`` + Set to any truthy value (``1``, ``true``, ``yes``) to emit structured JSON on + the console instead of the colour-coded human-readable format. +``FLOG`` + When set, log records are forwarded to a local REST endpoint. +``DEBUG`` + Legacy flag — equivalent to ``MELLEA_LOG_LEVEL=DEBUG``. """ +import contextlib import json import logging import os import sys +import threading +from collections.abc import Generator +from typing import Any import requests +# --------------------------------------------------------------------------- +# Thread-local storage for per-request context fields +# --------------------------------------------------------------------------- +_context_local: threading.local = threading.local() + +# Lock used to make FancyLogger singleton initialisation thread-safe. +_logger_lock: threading.Lock = threading.Lock() + +# Standard LogRecord attribute names that must not be overwritten by callers. +_RESERVED_LOG_RECORD_ATTRS: frozenset[str] = frozenset( + ( + "args", + "created", + "exc_info", + "exc_text", + "filename", + "funcName", + "levelname", + "levelno", + "lineno", + "message", + "module", + "msecs", + "msg", + "name", + "pathname", + "process", + "processName", + "relativeCreated", + "stack_info", + "thread", + "threadName", + ) +) + + +def set_log_context(**fields: Any) -> None: + """Inject extra fields into every log record emitted from this thread. + + Call this at the start of a request or task to attach identifiers such as + ``trace_id`` or ``request_id`` without modifying individual log calls. + + .. note:: + Prefer :func:`log_context` as the primary API — it guarantees cleanup + (including restoring outer values on same-key nesting) even on + exceptions. + + Args: + **fields: Arbitrary key-value pairs to include in log records. + + Raises: + ValueError: If any key clashes with a standard ``logging.LogRecord`` + attribute (e.g. ``levelname``, ``module``, ``thread``). + """ + invalid = frozenset(fields) & _RESERVED_LOG_RECORD_ATTRS + if invalid: + raise ValueError( + f"Context field names clash with LogRecord reserved attributes: " + f"{sorted(invalid)}. Choose different names." + ) + existing: dict[str, Any] = getattr(_context_local, "fields", {}) + existing.update(fields) + _context_local.fields = existing + + +def clear_log_context() -> None: + """Remove all thread-local log context fields set by :func:`set_log_context`.""" + _context_local.fields = {} + + +@contextlib.contextmanager +def log_context(**fields: Any) -> Generator[None, None, None]: + """Context manager that injects *fields* for the duration of the block. + + On exit — including on exceptions — removes only the keys that *this* + invocation set, leaving any enclosing context intact. This makes nested + and thread-pool usage safe without requiring manual cleanup. + + Example:: + + with log_context(trace_id="abc-123", request_id="req-1"): + logger.info("Handling request") # both IDs appear here + logger.info("After request") # IDs are gone + + Args: + **fields: Key-value pairs to inject. Same restrictions as + :func:`set_log_context` — reserved ``LogRecord`` attribute names + are rejected with ``ValueError``. + + Raises: + ValueError: If any key clashes with a reserved ``LogRecord`` attribute. + """ + existing_before: dict[str, Any] = getattr(_context_local, "fields", {}) + previous: dict[str, Any] = { + k: existing_before[k] for k in fields if k in existing_before + } + set_log_context(**fields) + try: + yield + finally: + existing: dict[str, Any] = getattr(_context_local, "fields", {}) + for key in fields: + if key in previous: + existing[key] = previous[key] + else: + existing.pop(key, None) + _context_local.fields = existing + + +class ContextFilter(logging.Filter): + """Logging filter that injects thread-local context fields into every record. + + Fields registered via :func:`set_log_context` are copied onto the + ``logging.LogRecord`` before formatters see it, enabling trace/request IDs + to appear in structured output without touching call sites. + """ + + def filter(self, record: logging.LogRecord) -> bool: + """Attach thread-local context fields to *record* and allow it through. + + Args: + record (logging.LogRecord): The log record being processed. + + Returns: + bool: Always ``True`` — the record is never suppressed. + """ + fields: dict[str, Any] = getattr(_context_local, "fields", {}) + for key, value in fields.items(): + setattr(record, key, value) + return True + class RESTHandler(logging.Handler): """Logging handler that forwards records to a local REST endpoint. @@ -46,14 +194,17 @@ def emit(self, record: logging.LogRecord) -> None: record (logging.LogRecord): The log record to forward. """ if os.environ.get("FLOG"): - log_data = self.format(record) + formatter = self.formatter + if isinstance(formatter, JsonFormatter): + log_dict = formatter._build_log_dict(record) + else: + log_dict = {"message": self.format(record)} try: response = requests.request( self.method, self.api_url, headers=self.headers, - # data=json.dumps([{"log": log_data}]), - data=json.dumps([log_data]), + data=json.dumps([log_dict]), ) response.raise_for_status() except requests.exceptions.RequestException as _: @@ -61,26 +212,73 @@ def emit(self, record: logging.LogRecord) -> None: class JsonFormatter(logging.Formatter): - """Logging formatter that serialises log records as structured JSON dicts. + """Logging formatter that serialises log records as structured JSON strings. + + Produces a consistent JSON schema with a fixed set of core fields. + Additional fields can be injected at construction time (``extra_fields``) or + dynamically per-thread via :func:`set_log_context` / :class:`ContextFilter`. + + Args: + timestamp_format: ``strftime`` format for the ``timestamp`` field. + Defaults to ISO-8601 (``"%Y-%m-%dT%H:%M:%S"``). + include_fields: Whitelist of core field names to keep. When ``None`` + all core fields are included. + exclude_fields: Set of core field names to drop. Applied after + *include_fields*. + extra_fields: Static key-value pairs merged into every log record. - Includes timestamp, level, message, module, function name, line number, - process ID, thread ID, and (if present) exception information. + Attributes: + _DEFAULT_FIELDS (tuple[str, ...]): Canonical ordered list of core field + names produced by this formatter. """ - def format(self, record: logging.LogRecord) -> dict: # type: ignore[override] - """Formats a log record as a JSON-serialisable dictionary. + _DEFAULT_FIELDS: tuple[str, ...] = ( + "timestamp", + "level", + "message", + "module", + "function", + "line_number", + "process_id", + "thread_id", + ) + + def __init__( + self, + timestamp_format: str = "%Y-%m-%dT%H:%M:%S", + include_fields: list[str] | None = None, + exclude_fields: list[str] | None = None, + extra_fields: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + """Initialises the formatter; passes remaining kwargs to ``logging.Formatter``.""" + super().__init__(datefmt=timestamp_format, **kwargs) + + if include_fields is not None: + unknown = set(include_fields) - set(self._DEFAULT_FIELDS) + if unknown: + raise ValueError( + f"include_fields contains unknown field names: {sorted(unknown)}. " + f"Valid fields: {list(self._DEFAULT_FIELDS)}" + ) + + self._include: frozenset[str] | None = ( + frozenset(include_fields) if include_fields is not None else None + ) + self._exclude: frozenset[str] = frozenset(exclude_fields or []) + self._extra: dict[str, Any] = dict(extra_fields or {}) - Includes timestamp, level, message, module, function name, line number, - process ID, thread ID, and exception info if present. + def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]: + """Build a log record dictionary with core, extra, and context fields. Args: - record (logging.LogRecord): The log record to format. + record: The log record to convert. Returns: - dict: A dictionary containing timestamp, level, message, module, function, - line number, process/thread IDs, and optional exception info. + A dictionary ready for JSON serialisation. """ - log_record = { + # Build the full set of core fields first + all_core: dict[str, Any] = { "timestamp": self.formatTime(record, self.datefmt), "level": record.levelname, "message": record.getMessage(), @@ -90,10 +288,46 @@ def format(self, record: logging.LogRecord) -> dict: # type: ignore[override] "process_id": record.process, "thread_id": record.thread, } + + # Apply include/exclude filtering + if self._include is not None: + log_record: dict[str, Any] = { + k: v for k, v in all_core.items() if k in self._include + } + else: + log_record = {k: v for k, v in all_core.items() if k not in self._exclude} + + # Exception info if record.exc_info: log_record["exception"] = self.formatException(record.exc_info) + + # Static extra fields (constructor-level) + log_record.update(self._extra) + + # Dynamic context fields — prefer record attributes (set by + # ContextFilter) but fall back to thread-local storage so the + # formatter works standalone without a filter attached. + context_fields: dict[str, Any] = getattr(_context_local, "fields", {}) + for key, value in context_fields.items(): + log_record[key] = getattr(record, key, value) + return log_record + def format(self, record: logging.LogRecord) -> str: + """Formats a log record as a JSON string. + + Core fields are filtered by *include_fields* / *exclude_fields*. + Static *extra_fields* and any thread-local context fields (set via + :func:`set_log_context`) are merged in after the core fields. + + Args: + record (logging.LogRecord): The log record to format. + + Returns: + str: A JSON-serialised log record. + """ + return json.dumps(self._build_log_dict(record), default=str) + class CustomFormatter(logging.Formatter): """A nice custom formatter copied from [Sergey Pleshakov's post on StackOverflow](https://stackoverflow.com/questions/384076/how-can-i-color-python-logging-output). @@ -169,47 +403,76 @@ class FancyLogger: DEBUG = 10 NOTSET = 0 + @staticmethod + def _resolve_log_level() -> int: + """Resolves the effective log level from environment variables. + + Checks ``MELLEA_LOG_LEVEL`` first, then falls back to the legacy + ``DEBUG`` flag, then defaults to ``INFO``. + + Returns: + int: A :mod:`logging` level integer. + """ + level_name = os.environ.get("MELLEA_LOG_LEVEL", "").strip().upper() + if level_name: + numeric = getattr(logging, level_name, None) + if isinstance(numeric, int): + return numeric + if os.environ.get("DEBUG"): + return FancyLogger.DEBUG + return FancyLogger.INFO + @staticmethod def get_logger() -> logging.Logger: """Returns a FancyLogger.logger and sets level based upon env vars. + The logger is created once (singleton). Subsequent calls return the + cached instance. Initialisation is protected by a module-level lock so + concurrent callers at startup cannot create duplicate handlers. + Returns: Configured logger with REST, stream, and optional OTLP handlers. """ if FancyLogger.logger is None: - logger = logging.getLogger("fancy_logger") - # Only set default level if user hasn't already configured it - if logger.level == logging.NOTSET: - if os.environ.get("DEBUG"): - logger.setLevel(FancyLogger.DEBUG) - else: - logger.setLevel(FancyLogger.INFO) - - # Define REST API endpoint - api_url = "http://localhost:8000/api/receive" - - # Create REST handler - rest_handler = RESTHandler(api_url) - - # Create formatter and set it for the handler - # formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") - rest_handler.setFormatter(JsonFormatter()) - - # Add handler to the logger - logger.addHandler(rest_handler) - - stream_handler = logging.StreamHandler(stream=sys.stdout) - # stream_handler.setLevel(logging.INFO) - stream_handler.setFormatter(CustomFormatter(datefmt="%H:%M:%S,%03d")) - logger.addHandler(stream_handler) - - # Add OTLP handler if enabled - from ..telemetry import get_otlp_log_handler - - otlp_handler = get_otlp_log_handler() - if otlp_handler: - otlp_handler.setFormatter(JsonFormatter()) - logger.addHandler(otlp_handler) - - FancyLogger.logger = logger + with _logger_lock: + # Second check inside the lock: another thread may have finished + # initialisation while we were waiting. + if FancyLogger.logger is None: + logger = logging.getLogger("fancy_logger") + + # Attach the context filter so ContextFilter fields reach all handlers + logger.addFilter(ContextFilter()) + + # Only set default level if user hasn't already configured it + if logger.level == logging.NOTSET: + logger.setLevel(FancyLogger._resolve_log_level()) + + # --- REST handler --- + api_url = "http://localhost:8000/api/receive" + rest_handler = RESTHandler(api_url) + rest_handler.setFormatter(JsonFormatter()) + logger.addHandler(rest_handler) + + # --- Console / stream handler --- + stream_handler = logging.StreamHandler(stream=sys.stdout) + use_json_console = os.environ.get( + "MELLEA_LOG_JSON", "" + ).strip().lower() in ("1", "true", "yes") + if use_json_console: + stream_handler.setFormatter(JsonFormatter()) + else: + stream_handler.setFormatter( + CustomFormatter(datefmt="%H:%M:%S,%03d") + ) + logger.addHandler(stream_handler) + + # --- Optional OTLP handler --- + from ..telemetry import get_otlp_log_handler + + otlp_handler = get_otlp_log_handler() + if otlp_handler: + otlp_handler.setFormatter(JsonFormatter()) + logger.addHandler(otlp_handler) + + FancyLogger.logger = logger return FancyLogger.logger diff --git a/test/core/test_utils_logging.py b/test/core/test_utils_logging.py new file mode 100644 index 000000000..241283c6a --- /dev/null +++ b/test/core/test_utils_logging.py @@ -0,0 +1,444 @@ +"""Unit tests for FancyLogger, JsonFormatter, and ContextFilter enhancements.""" + +import json +import logging +import threading +from typing import Any + +import pytest + +from mellea.core.utils import ( + _RESERVED_LOG_RECORD_ATTRS, + ContextFilter, + FancyLogger, + JsonFormatter, + clear_log_context, + log_context, + set_log_context, +) + + +def _make_record(msg: str = "hello", level: int = logging.INFO) -> logging.LogRecord: + record = logging.LogRecord( + name="test", + level=level, + pathname="test_utils_logging.py", + lineno=1, + msg=msg, + args=(), + exc_info=None, + ) + return record + + +class TestJsonFormatterCoreSchema: + def test_returns_valid_json_string(self) -> None: + fmt = JsonFormatter() + output = fmt.format(_make_record("hi")) + parsed = json.loads(output) + assert isinstance(parsed, dict) + + def test_all_default_fields_present(self) -> None: + fmt = JsonFormatter() + parsed = json.loads(fmt.format(_make_record("hi"))) + for field in JsonFormatter._DEFAULT_FIELDS: + assert field in parsed, f"missing field: {field}" + + def test_message_content(self) -> None: + fmt = JsonFormatter() + parsed = json.loads(fmt.format(_make_record("test message"))) + assert parsed["message"] == "test message" + + def test_level_name(self) -> None: + fmt = JsonFormatter() + parsed = json.loads(fmt.format(_make_record(level=logging.WARNING))) + assert parsed["level"] == "WARNING" + + def test_exception_field_added_when_exc_info(self) -> None: + fmt = JsonFormatter() + try: + raise ValueError("boom") + except ValueError: + import sys + + record = _make_record("oops") + record.exc_info = sys.exc_info() + parsed = json.loads(fmt.format(record)) + assert "exception" in parsed + assert "ValueError" in parsed["exception"] + + +class TestJsonFormatterFieldConfig: + def test_include_fields_limits_output(self) -> None: + fmt = JsonFormatter(include_fields=["timestamp", "level", "message"]) + parsed = json.loads(fmt.format(_make_record())) + assert set(parsed.keys()) == {"timestamp", "level", "message"} + + def test_exclude_fields_removes_keys(self) -> None: + fmt = JsonFormatter(exclude_fields=["process_id", "thread_id"]) + parsed = json.loads(fmt.format(_make_record())) + assert "process_id" not in parsed + assert "thread_id" not in parsed + assert "level" in parsed # other fields still present + + def test_extra_fields_merged(self) -> None: + fmt = JsonFormatter(extra_fields={"service": "mellea", "env": "test"}) + parsed = json.loads(fmt.format(_make_record())) + assert parsed["service"] == "mellea" + assert parsed["env"] == "test" + + def test_extra_fields_override_core(self) -> None: + # static extras come *after* core fields — they win on collision + fmt = JsonFormatter(extra_fields={"level": "OVERRIDDEN"}) + parsed = json.loads(fmt.format(_make_record(level=logging.DEBUG))) + assert parsed["level"] == "OVERRIDDEN" + + def test_timestamp_format_respected(self) -> None: + fmt = JsonFormatter(timestamp_format="%Y") + parsed = json.loads(fmt.format(_make_record())) + # Should be just a 4-digit year + assert len(parsed["timestamp"]) == 4 + assert parsed["timestamp"].isdigit() + + +class TestJsonFormatterContextInjection: + def setup_method(self) -> None: + clear_log_context() + + def teardown_method(self) -> None: + clear_log_context() + + def test_context_fields_appear_in_output(self) -> None: + set_log_context(trace_id="abc-123") + fmt = JsonFormatter() + parsed = json.loads(fmt.format(_make_record())) + assert parsed.get("trace_id") == "abc-123" + + def test_multiple_context_fields(self) -> None: + set_log_context(trace_id="t1", request_id="r1", user="alice") + fmt = JsonFormatter() + parsed = json.loads(fmt.format(_make_record())) + assert parsed["trace_id"] == "t1" + assert parsed["request_id"] == "r1" + assert parsed["user"] == "alice" + + def test_clear_context_removes_fields(self) -> None: + set_log_context(trace_id="gone") + clear_log_context() + fmt = JsonFormatter() + parsed = json.loads(fmt.format(_make_record())) + assert "trace_id" not in parsed + + def test_context_is_thread_local(self) -> None: + """Fields set in one thread must not bleed into another.""" + results: dict[str, Any] = {} + + def worker_a() -> None: + set_log_context(trace_id="thread-a") + import time + + time.sleep(0.05) + fmt = JsonFormatter() + results["a"] = json.loads(fmt.format(_make_record())) + clear_log_context() + + def worker_b() -> None: + # does NOT call set_log_context + import time + + time.sleep(0.02) + fmt = JsonFormatter() + results["b"] = json.loads(fmt.format(_make_record())) + + ta = threading.Thread(target=worker_a) + tb = threading.Thread(target=worker_b) + ta.start() + tb.start() + ta.join() + tb.join() + + assert results["a"].get("trace_id") == "thread-a" + assert "trace_id" not in results["b"] + + +class TestContextFilter: + def setup_method(self) -> None: + clear_log_context() + + def teardown_method(self) -> None: + clear_log_context() + + def test_filter_always_returns_true(self) -> None: + f = ContextFilter() + assert f.filter(_make_record()) is True + + def test_filter_attaches_context_fields_to_record(self) -> None: + set_log_context(span_id="span-999") + f = ContextFilter() + record = _make_record() + f.filter(record) + assert getattr(record, "span_id", None) == "span-999" + + def test_filter_noop_when_no_context(self) -> None: + f = ContextFilter() + record = _make_record() + f.filter(record) + assert not hasattr(record, "trace_id") + + +class TestFancyLoggerLogLevel: + def _reset(self) -> None: + FancyLogger.logger = None + logging.getLogger("fancy_logger").handlers.clear() + logging.getLogger("fancy_logger").setLevel(logging.NOTSET) + + def teardown_method(self) -> None: + self._reset() + + def test_default_level_is_info(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("MELLEA_LOG_LEVEL", raising=False) + monkeypatch.delenv("DEBUG", raising=False) + assert FancyLogger._resolve_log_level() == logging.INFO + + def test_mellea_log_level_env_var(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("MELLEA_LOG_LEVEL", "DEBUG") + assert FancyLogger._resolve_log_level() == logging.DEBUG + + def test_mellea_log_level_warning(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("MELLEA_LOG_LEVEL", "WARNING") + assert FancyLogger._resolve_log_level() == logging.WARNING + + def test_legacy_debug_flag(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("MELLEA_LOG_LEVEL", raising=False) + monkeypatch.setenv("DEBUG", "1") + assert FancyLogger._resolve_log_level() == logging.DEBUG + + def test_mellea_log_level_takes_precedence_over_debug( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("MELLEA_LOG_LEVEL", "WARNING") + monkeypatch.setenv("DEBUG", "1") + assert FancyLogger._resolve_log_level() == logging.WARNING + + def test_invalid_level_falls_back_to_info( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("MELLEA_LOG_LEVEL", "BOGUS") + monkeypatch.delenv("DEBUG", raising=False) + assert FancyLogger._resolve_log_level() == logging.INFO + + +class TestFancyLoggerJsonConsole: + def _reset(self) -> None: + FancyLogger.logger = None + logger = logging.getLogger("fancy_logger") + logger.handlers.clear() + logger.setLevel(logging.NOTSET) + + def setup_method(self) -> None: + self._reset() + + def teardown_method(self) -> None: + self._reset() + + def _get_stream_handler(self) -> logging.StreamHandler: # type: ignore[type-arg] + logger = FancyLogger.get_logger() + handlers = [h for h in logger.handlers if isinstance(h, logging.StreamHandler)] + # RESTHandler is a subclass of Handler but not StreamHandler, so this + # correctly picks the console handler. + return handlers[0] + + def test_default_uses_custom_formatter( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + from mellea.core.utils import CustomFormatter + + monkeypatch.delenv("MELLEA_LOG_JSON", raising=False) + handler = self._get_stream_handler() + assert isinstance(handler.formatter, CustomFormatter) + + def test_json_console_enabled_with_true( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("MELLEA_LOG_JSON", "true") + handler = self._get_stream_handler() + assert isinstance(handler.formatter, JsonFormatter) + + def test_json_console_enabled_with_1(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("MELLEA_LOG_JSON", "1") + handler = self._get_stream_handler() + assert isinstance(handler.formatter, JsonFormatter) + + def test_json_console_enabled_with_yes( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("MELLEA_LOG_JSON", "yes") + handler = self._get_stream_handler() + assert isinstance(handler.formatter, JsonFormatter) + + def test_json_console_disabled_with_false( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + from mellea.core.utils import CustomFormatter + + monkeypatch.setenv("MELLEA_LOG_JSON", "false") + handler = self._get_stream_handler() + assert isinstance(handler.formatter, CustomFormatter) + + +class TestFancyLoggerContextFilterWired: + def setup_method(self) -> None: + FancyLogger.logger = None + logging.getLogger("fancy_logger").handlers.clear() + logging.getLogger("fancy_logger").setLevel(logging.NOTSET) + clear_log_context() + + def teardown_method(self) -> None: + FancyLogger.logger = None + logging.getLogger("fancy_logger").handlers.clear() + logging.getLogger("fancy_logger").setLevel(logging.NOTSET) + clear_log_context() + + def test_context_filter_present(self) -> None: + logger = FancyLogger.get_logger() + assert any(isinstance(f, ContextFilter) for f in logger.filters) + + +class TestLogContext: + def setup_method(self) -> None: + clear_log_context() + + def teardown_method(self) -> None: + clear_log_context() + + def test_fields_present_inside_block(self) -> None: + fmt = JsonFormatter() + with log_context(trace_id="ctx-1"): + parsed = json.loads(fmt.format(_make_record())) + assert parsed["trace_id"] == "ctx-1" + + def test_fields_removed_after_block(self) -> None: + fmt = JsonFormatter() + with log_context(trace_id="ctx-2"): + pass + parsed = json.loads(fmt.format(_make_record())) + assert "trace_id" not in parsed + + def test_cleanup_on_exception(self) -> None: + fmt = JsonFormatter() + with pytest.raises(RuntimeError): + with log_context(trace_id="ctx-err"): + raise RuntimeError("boom") + parsed = json.loads(fmt.format(_make_record())) + assert "trace_id" not in parsed + + def test_nested_contexts_preserve_outer(self) -> None: + fmt = JsonFormatter() + with log_context(outer="yes"): + with log_context(inner="yes"): + parsed = json.loads(fmt.format(_make_record())) + assert parsed["outer"] == "yes" + assert parsed["inner"] == "yes" + # inner should be gone, outer still present + parsed = json.loads(fmt.format(_make_record())) + assert parsed["outer"] == "yes" + assert "inner" not in parsed + # both gone + parsed = json.loads(fmt.format(_make_record())) + assert "outer" not in parsed + + def test_nested_same_key_restores_outer(self) -> None: + fmt = JsonFormatter() + with log_context(trace_id="outer"): + with log_context(trace_id="inner"): + parsed = json.loads(fmt.format(_make_record())) + assert parsed["trace_id"] == "inner" + parsed = json.loads(fmt.format(_make_record())) + assert parsed["trace_id"] == "outer" + parsed = json.loads(fmt.format(_make_record())) + assert "trace_id" not in parsed + + def test_rejects_reserved_attribute(self) -> None: + with pytest.raises(ValueError, match="reserved"): + with log_context(levelname="BAD"): + pass + + +class TestReservedAttributeValidation: + def setup_method(self) -> None: + clear_log_context() + + def teardown_method(self) -> None: + clear_log_context() + + def test_set_log_context_rejects_reserved_key(self) -> None: + with pytest.raises(ValueError, match="reserved"): + set_log_context(module="bad") + + def test_set_log_context_rejects_multiple_reserved(self) -> None: + with pytest.raises(ValueError, match="reserved"): + set_log_context(thread="x", process="y") + + def test_set_log_context_accepts_non_reserved(self) -> None: + set_log_context(custom_field="fine") + fmt = JsonFormatter() + parsed = json.loads(fmt.format(_make_record())) + assert parsed["custom_field"] == "fine" + + def test_reserved_set_is_non_empty(self) -> None: + assert len(_RESERVED_LOG_RECORD_ATTRS) > 10 + + +class TestGetLoggerThreadSafety: + def setup_method(self) -> None: + FancyLogger.logger = None + logging.getLogger("fancy_logger").handlers.clear() + logging.getLogger("fancy_logger").setLevel(logging.NOTSET) + + def teardown_method(self) -> None: + FancyLogger.logger = None + logging.getLogger("fancy_logger").handlers.clear() + logging.getLogger("fancy_logger").setLevel(logging.NOTSET) + + def test_concurrent_get_logger_returns_same_instance(self) -> None: + """Multiple threads calling get_logger() must all get the same object.""" + results: list[logging.Logger] = [] + barrier = threading.Barrier(4) + + def worker() -> None: + barrier.wait() + results.append(FancyLogger.get_logger()) + + threads = [threading.Thread(target=worker) for _ in range(4)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(results) == 4 + assert all(r is results[0] for r in results) + + +class TestJsonFormatterFormatSignature: + def test_format_returns_str(self) -> None: + """JsonFormatter.format returns str — no type: ignore needed.""" + fmt = JsonFormatter() + result = fmt.format(_make_record("check")) + assert isinstance(result, str) + json.loads(result) # also valid JSON + + +class TestIncludeFieldsValidation: + def test_valid_include_fields_accepted(self) -> None: + fmt = JsonFormatter(include_fields=["timestamp", "level", "message"]) + parsed = json.loads(fmt.format(_make_record())) + assert set(parsed.keys()) == {"timestamp", "level", "message"} + + def test_unknown_include_field_raises(self) -> None: + with pytest.raises(ValueError, match="unknown field names"): + JsonFormatter(include_fields=["timestamp", "bogus_field"]) + + def test_all_default_fields_accepted(self) -> None: + fmt = JsonFormatter(include_fields=list(JsonFormatter._DEFAULT_FIELDS)) + parsed = json.loads(fmt.format(_make_record())) + assert set(parsed.keys()) == set(JsonFormatter._DEFAULT_FIELDS) From 4ffaf91100140012c307d128cb383667464f5b84 Mon Sep 17 00:00:00 2001 From: AngeloDanducci Date: Wed, 8 Apr 2026 09:52:38 -0400 Subject: [PATCH 02/11] update docstring --- mellea/core/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mellea/core/utils.py b/mellea/core/utils.py index e517dd632..d652aa40a 100644 --- a/mellea/core/utils.py +++ b/mellea/core/utils.py @@ -119,6 +119,9 @@ def log_context(**fields: Any) -> Generator[None, None, None]: :func:`set_log_context` — reserved ``LogRecord`` attribute names are rejected with ``ValueError``. + Yields: + None. The manager is used only for its enter/exit side effects. + Raises: ValueError: If any key clashes with a reserved ``LogRecord`` attribute. """ From e54b637ed266a7e52af820ddb265c071c4b75a72 Mon Sep 17 00:00:00 2001 From: AngeloDanducci Date: Wed, 15 Apr 2026 01:44:26 -0400 Subject: [PATCH 03/11] address feedback from jake --- CONTRIBUTING.md | 4 +- docs/AGENTS_TEMPLATE.md | 4 +- docs/docs/community/contributing-guide.md | 4 +- .../evaluation-and-observability/logging.md | 19 +++--- .../evaluation-and-observability/telemetry.md | 2 +- .../agents/react/react_from_scratch/react.py | 4 +- docs/examples/mcp/mcp_example.py | 2 +- .../examples/mify/rich_table_execute_basic.py | 4 +- docs/examples/sofai/sofai_graph_coloring.py | 4 +- docs/metrics/coverage-baseline.json | 2 +- docs/metrics/coverage-current.json | 64 +++++++++---------- mellea/backends/huggingface.py | 36 +++++------ mellea/backends/litellm.py | 22 +++---- mellea/backends/model_options.py | 4 +- mellea/backends/ollama.py | 20 +++--- mellea/backends/openai.py | 16 ++--- mellea/backends/tools.py | 18 +++--- mellea/backends/utils.py | 6 +- mellea/backends/watsonx.py | 10 +-- mellea/core/__init__.py | 4 +- mellea/core/backend.py | 6 +- mellea/core/utils.py | 49 +++++++------- mellea/formatters/template_formatter.py | 8 +-- mellea/helpers/openai_compatible_helpers.py | 4 +- mellea/stdlib/components/react.py | 4 +- mellea/stdlib/frameworks/react.py | 4 +- mellea/stdlib/functional.py | 18 +++--- mellea/stdlib/requirements/python_reqs.py | 4 +- mellea/stdlib/requirements/requirement.py | 6 +- mellea/stdlib/requirements/safety/guardian.py | 4 +- mellea/stdlib/sampling/base.py | 6 +- mellea/stdlib/sampling/budget_forcing.py | 6 +- mellea/stdlib/sampling/sofai.py | 8 +-- mellea/stdlib/session.py | 6 +- mellea/stdlib/tools/interpreter.py | 4 +- test/conftest.py | 14 ++-- test/core/test_utils_logging.py | 50 ++++++--------- test/telemetry/test_logging.py | 24 +++---- 38 files changed, 229 insertions(+), 245 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 7deee4c6c..c03d30836 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -439,8 +439,8 @@ CICD=1 uv run pytest ```python # Enable debug logging -from mellea.core import FancyLogger -FancyLogger.get_logger().setLevel("DEBUG") +from mellea.core import MelleaLogger +MelleaLogger.get_logger().setLevel("DEBUG") # See exact prompt sent to LLM print(m.last_prompt()) diff --git a/docs/AGENTS_TEMPLATE.md b/docs/AGENTS_TEMPLATE.md index 1233bbfaf..116315526 100644 --- a/docs/AGENTS_TEMPLATE.md +++ b/docs/AGENTS_TEMPLATE.md @@ -160,8 +160,8 @@ Session methods: `ainstruct`, `achat`, `aact`, `avalidate`, `aquery`, `atransfor #### 12. Debugging ```python -from mellea.core import FancyLogger -FancyLogger.get_logger().setLevel("DEBUG") +from mellea.core import MelleaLogger +MelleaLogger.get_logger().setLevel("DEBUG") ``` - `m.last_prompt()` — see exact prompt sent diff --git a/docs/docs/community/contributing-guide.md b/docs/docs/community/contributing-guide.md index bb583292c..3ef7fcb84 100644 --- a/docs/docs/community/contributing-guide.md +++ b/docs/docs/community/contributing-guide.md @@ -315,10 +315,10 @@ CICD=1 uv run pytest ### Debugging tips ```python -from mellea.core import FancyLogger +from mellea.core import MelleaLogger # Enable debug logging -FancyLogger.get_logger().setLevel("DEBUG") +MelleaLogger.get_logger().setLevel("DEBUG") # Inspect the exact prompt sent to the LLM print(m.last_prompt()) diff --git a/docs/docs/evaluation-and-observability/logging.md b/docs/docs/evaluation-and-observability/logging.md index 908c43928..854aec956 100644 --- a/docs/docs/evaluation-and-observability/logging.md +++ b/docs/docs/evaluation-and-observability/logging.md @@ -14,22 +14,23 @@ Both work simultaneously when enabled. ## Console logging -Mellea uses `FancyLogger`, a color-coded singleton logger built on Python's +Mellea uses `MelleaLogger`, a color-coded singleton logger built on Python's `logging` module. All internal Mellea modules obtain their logger via -`FancyLogger.get_logger()`. +`MelleaLogger.get_logger()`. ### Configuration | Variable | Description | Default | | -------- | ----------- | ------- | -| `DEBUG` | Set to any value to enable `DEBUG`-level output | unset (`INFO` level) | -| `FLOG` | Set to any value to forward logs to a local REST endpoint at `http://localhost:8000/api/receive` | unset | +| `MELLEA_LOG_LEVEL` | Log level name (e.g. `DEBUG`, `INFO`, `WARNING`) | `INFO` | +| `MELLEA_LOG_JSON` | Set to any truthy value (`1`, `true`, `yes`) to emit structured JSON instead of colour-coded output | unset | +| `MELLEA_FLOG` | Set to any value to forward logs to a local REST endpoint at `http://localhost:8000/api/receive` | unset | -By default, `FancyLogger` logs at `INFO` level with color-coded output to -stdout. Set the `DEBUG` environment variable to lower the level to `DEBUG`: +By default, `MelleaLogger` logs at `INFO` level with color-coded output to +stdout. Set `MELLEA_LOG_LEVEL` to change the level: ```bash -export DEBUG=1 +export MELLEA_LOG_LEVEL=DEBUG python your_script.py ``` @@ -74,11 +75,11 @@ export OTEL_SERVICE_NAME=my-mellea-app ### How it works -When `MELLEA_LOGS_OTLP=true`, `FancyLogger` adds an OpenTelemetry +When `MELLEA_LOGS_OTLP=true`, `MelleaLogger` adds an OpenTelemetry `LoggingHandler` alongside its existing handlers: - **Console handler** — continues to work normally (color-coded output) -- **REST handler** — continues to work normally (when `FLOG` is set) +- **REST handler** — continues to work normally (when `MELLEA_FLOG` is set) - **OTLP handler** — exports logs to the configured OTLP collector Logs are exported using OpenTelemetry's Logs API with batched processing diff --git a/docs/docs/evaluation-and-observability/telemetry.md b/docs/docs/evaluation-and-observability/telemetry.md index 12f669461..dda40dbc4 100644 --- a/docs/docs/evaluation-and-observability/telemetry.md +++ b/docs/docs/evaluation-and-observability/telemetry.md @@ -140,7 +140,7 @@ and troubleshooting. ## Logging -Mellea uses a color-coded console logger (`FancyLogger`) by default. When the +Mellea uses a color-coded console logger (`MelleaLogger`) by default. When the `[telemetry]` extra is installed and `MELLEA_LOGS_OTLP=true` is set, Mellea also exports logs to an OTLP collector alongside existing console output. diff --git a/docs/examples/agents/react/react_from_scratch/react.py b/docs/examples/agents/react/react_from_scratch/react.py index 351bce08b..b74d6a43b 100644 --- a/docs/examples/agents/react/react_from_scratch/react.py +++ b/docs/examples/agents/react/react_from_scratch/react.py @@ -12,10 +12,10 @@ import mellea import mellea.stdlib.components.chat -from mellea.core import FancyLogger +from mellea.core import MelleaLogger from mellea.stdlib.context import ChatContext -FancyLogger.get_logger().setLevel("ERROR") +MelleaLogger.get_logger().setLevel("ERROR") react_system_template: Template = Template( """Answer the user's question as best you can. diff --git a/docs/examples/mcp/mcp_example.py b/docs/examples/mcp/mcp_example.py index 3cd4d344f..49d747b92 100644 --- a/docs/examples/mcp/mcp_example.py +++ b/docs/examples/mcp/mcp_example.py @@ -16,7 +16,7 @@ from mellea import MelleaSession from mellea.backends import ModelOption, model_ids from mellea.backends.ollama import OllamaModelBackend -from mellea.core import FancyLogger, ModelOutputThunk, Requirement +from mellea.core import MelleaLogger, ModelOutputThunk, Requirement from mellea.stdlib.requirements import simple_validate from mellea.stdlib.sampling import RejectionSamplingStrategy diff --git a/docs/examples/mify/rich_table_execute_basic.py b/docs/examples/mify/rich_table_execute_basic.py index 3c4b9e665..04d2350ca 100644 --- a/docs/examples/mify/rich_table_execute_basic.py +++ b/docs/examples/mify/rich_table_execute_basic.py @@ -5,10 +5,10 @@ from mellea import start_session from mellea.backends import ModelOption, model_ids -from mellea.core import FancyLogger +from mellea.core import MelleaLogger from mellea.stdlib.components.docs.richdocument import RichDocument, Table -FancyLogger.get_logger().setLevel("ERROR") +MelleaLogger.get_logger().setLevel("ERROR") """ Here we demonstrate the use of the (internally m-ified) class diff --git a/docs/examples/sofai/sofai_graph_coloring.py b/docs/examples/sofai/sofai_graph_coloring.py index 67ed7acdc..06a46189b 100644 --- a/docs/examples/sofai/sofai_graph_coloring.py +++ b/docs/examples/sofai/sofai_graph_coloring.py @@ -23,7 +23,7 @@ import mellea from mellea.backends.ollama import OllamaModelBackend -from mellea.core import FancyLogger +from mellea.core import MelleaLogger from mellea.stdlib.components import Message from mellea.stdlib.context import ChatContext from mellea.stdlib.requirements import ValidationResult, req @@ -230,5 +230,5 @@ def main(): if __name__ == "__main__": # Set logging level - FancyLogger.get_logger().setLevel(logging.INFO) + MelleaLogger.get_logger().setLevel(logging.INFO) main() diff --git a/docs/metrics/coverage-baseline.json b/docs/metrics/coverage-baseline.json index 6a983a331..823445d3d 100644 --- a/docs/metrics/coverage-baseline.json +++ b/docs/metrics/coverage-baseline.json @@ -49,7 +49,7 @@ "ComponentParseError", "Context", "ContextTurn", - "FancyLogger", + "MelleaLogger", "Formatter", "GenerateLog", "GenerateType", diff --git a/docs/metrics/coverage-current.json b/docs/metrics/coverage-current.json index 62b134beb..80d7493e9 100644 --- a/docs/metrics/coverage-current.json +++ b/docs/metrics/coverage-current.json @@ -64,7 +64,7 @@ "default_output_to_bool", "SamplingResult", "SamplingStrategy", - "FancyLogger" + "MelleaLogger" ], "mellea.core.formatter": [ "CBlock", @@ -85,7 +85,7 @@ "Component", "Context", "ModelOutputThunk", - "FancyLogger", + "MelleaLogger", "BaseModelSubclass" ], "mellea.core.sampling": [ @@ -126,7 +126,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "GenerateType", "ModelOutputThunk", @@ -145,7 +145,7 @@ "format" ], "mellea.backends.tools": [ - "FancyLogger", + "MelleaLogger", "CBlock", "Component", "TemplateRepresentation", @@ -158,7 +158,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "GenerateType", "ModelOutputThunk", @@ -221,7 +221,7 @@ "HF_SMOLLM3_3B_no_ollama" ], "mellea.backends.model_options": [ - "FancyLogger" + "MelleaLogger" ], "mellea.backends.openai": [ "ALoraRequirement", @@ -231,7 +231,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "GenerateType", "ModelOutputThunk", @@ -301,7 +301,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "ModelToolCall", "AbstractMelleaTool", "ChatFormatter", @@ -316,7 +316,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "GenerateType", "ModelOutputThunk", @@ -735,7 +735,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "GenerateType", "ModelOutputThunk", @@ -849,7 +849,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "GenerateType", "ModelOutputThunk", @@ -1051,7 +1051,7 @@ "ModelIdentifier", "CBlock", "Component", - "FancyLogger", + "MelleaLogger", "TemplateRepresentation", "ChatFormatter" ], @@ -1348,7 +1348,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "ImageBlock", "ModelOutputThunk", @@ -1370,7 +1370,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "ImageBlock", "ModelOutputThunk", @@ -1421,7 +1421,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "ImageBlock", "ModelOutputThunk", @@ -1464,7 +1464,7 @@ "local_code_interpreter" ], "mellea.stdlib.tools.interpreter": [ - "FancyLogger", + "MelleaLogger", "logger" ], "mellea.stdlib.sampling": [ @@ -1481,7 +1481,7 @@ "BaseModelSubclass", "Component", "Context", - "FancyLogger", + "MelleaLogger", "ModelOutputThunk", "Requirement", "S", @@ -1499,7 +1499,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "ImageBlock", "ModelOutputThunk", @@ -1559,7 +1559,7 @@ "BaseModelSubclass", "Component", "Context", - "FancyLogger", + "MelleaLogger", "ModelOutputThunk", "Requirement", "S", @@ -1575,7 +1575,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "ImageBlock", "ModelOutputThunk", @@ -1624,7 +1624,7 @@ "BaseModelSubclass", "Component", "Context", - "FancyLogger", + "MelleaLogger", "ModelOutputThunk", "Requirement", "S", @@ -1642,7 +1642,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "ImageBlock", "ModelOutputThunk", @@ -1726,7 +1726,7 @@ "mellea.stdlib.requirements.requirement": [ "CBlock", "Context", - "FancyLogger", + "MelleaLogger", "Requirement", "ValidationResult", "Intrinsic" @@ -1737,7 +1737,7 @@ "StaticAnalysisEnvironment", "UnsafeEnvironment", "Context", - "FancyLogger", + "MelleaLogger", "Requirement", "ValidationResult", "logger" @@ -1752,7 +1752,7 @@ "BaseModelSubclass", "CBlock", "Context", - "FancyLogger", + "MelleaLogger", "Requirement", "ValidationResult", "Message", @@ -1799,7 +1799,7 @@ "Component", "ModelOutputThunk", "TemplateRepresentation", - "FancyLogger", + "MelleaLogger", "MELLEA_FINALIZER_TOOL" ], "mellea.stdlib.components.chat": [ @@ -1824,7 +1824,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "ModelOutputThunk", "Requirement", "SamplingStrategy", @@ -1842,7 +1842,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "ImageBlock", "ModelOutputThunk", @@ -1953,7 +1953,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "ImageBlock", "ModelOutputThunk", @@ -2003,7 +2003,7 @@ "BaseModelSubclass", "AbstractMelleaTool", "ModelOutputThunk", - "FancyLogger", + "MelleaLogger", "ToolMessage", "MELLEA_FINALIZER_TOOL", "ReactInitiator", @@ -2017,7 +2017,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "ImageBlock", "ModelOutputThunk", @@ -2102,7 +2102,7 @@ ], "mellea.helpers.openai_compatible_helpers": [ "validate_tool_arguments", - "FancyLogger", + "MelleaLogger", "ModelToolCall", "AbstractMelleaTool", "Document", diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index 55f85be7b..af0ba88be 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -34,7 +34,7 @@ CBlock, Component, Context, - FancyLogger, + MelleaLogger, GenerateLog, GenerateType, ModelOutputThunk, @@ -199,7 +199,7 @@ def __call__( ) err = ll_matcher.get_error() # type: ignore[attr-defined] if err: - FancyLogger.get_logger().warning("Error in LLMatcher: %s", err) + MelleaLogger.get_logger().warning("Error in LLMatcher: %s", err) llguidance.torch.fill_next_token_bitmask(ll_matcher, bitmask, 0) llguidance.torch.apply_token_bitmask_inplace( @@ -402,7 +402,7 @@ async def _generate_from_context( if alora_req_adapter is None: # Log a warning if using an AloraRequirement but no adapter fit. if reroute_to_alora and isinstance(action, ALoraRequirement): - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"attempted to use an AloraRequirement but backend {self} doesn't have the specified adapter added {adapter_name}; defaulting to regular generation" ) reroute_to_alora = False @@ -472,7 +472,7 @@ async def _generate_from_intrinsic( raise Exception("Does not yet support non-chat contexts.") if len(model_options.items()) > 0: - FancyLogger.get_logger().info( + MelleaLogger.get_logger().info( "passing in model options when generating with an adapter; some model options may be overwritten / ignored" ) @@ -648,11 +648,11 @@ def _make_merged_kv_cache( case CBlock() if c.cache: assert c.value is not None if c.value in self._cached_blocks: - FancyLogger.get_logger().info( + MelleaLogger.get_logger().info( f"KV CACHE HIT for: {hash(c.value)} ({c.value[:3]}..{c.value[-3:]})" # type: ignore ) else: - FancyLogger.get_logger().debug( + MelleaLogger.get_logger().debug( f"HF backend is caching a CBlock with hashed contents: {hash(c.value)} ({c.value[:3]}..{c.value[-3:]})" ) tokens = self._tokenizer(c.value, return_tensors="pt") @@ -690,14 +690,14 @@ def _make_merged_kv_cache( prefix, suffix = parts # Add the prefix, if any, to str+tok+dc parts. if prefix != "": - FancyLogger.get_logger().debug( + MelleaLogger.get_logger().debug( f"Doing a forward pass on uncached block which is prefix to a cached CBlock: {prefix[:3]}.{len(prefix)}.{prefix[-3:]}" ) str_parts.append(prefix) tok_parts.append(self._tokenizer(prefix, return_tensors="pt")) dc_parts.append(self._make_dc_cache(tok_parts[-1])) # Add the cached CBlock to str+tok+dc parts. - FancyLogger.get_logger().debug( + MelleaLogger.get_logger().debug( f"Replacing a substring with previously computed/retrieved cache with hahs value {hash(key)} ({key[:3]}..{key[-3:]})" ) # str_parts.append(key) @@ -710,7 +710,7 @@ def _make_merged_kv_cache( current_suffix = suffix # "base" case: the final suffix. if current_suffix != "": - FancyLogger.get_logger().debug( # type: ignore + MelleaLogger.get_logger().debug( # type: ignore f"Doing a forward pass on final suffix, an uncached block: {current_suffix[:3]}.{len(current_suffix)}.{current_suffix[-3:]}" # type: ignore ) # type: ignore str_parts.append(current_suffix) @@ -753,7 +753,7 @@ async def _generate_from_context_with_kv_cache( tools: dict[str, AbstractMelleaTool] = dict() if tool_calls: if _format: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}" ) else: @@ -765,7 +765,7 @@ async def _generate_from_context_with_kv_cache( # Add the tools from the action for this generation last so that # they overwrite conflicting names. add_tools_from_context_actions(tools, [action]) - FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}") + MelleaLogger.get_logger().info(f"Tools for call: {tools.keys()}") seed = model_options.get(ModelOption.SEED, None) if seed is not None: @@ -904,7 +904,7 @@ async def _generate_from_context_standard( tools: dict[str, AbstractMelleaTool] = dict() if tool_calls: if _format: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}" ) else: @@ -916,7 +916,7 @@ async def _generate_from_context_standard( # Add the tools from the action for this generation last so that # they overwrite conflicting names. add_tools_from_context_actions(tools, [action]) - FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}") + MelleaLogger.get_logger().info(f"Tools for call: {tools.keys()}") seed = model_options.get(ModelOption.SEED, None) if seed is not None: @@ -1266,7 +1266,7 @@ async def generate_from_raw( await self.do_generate_walks(list(actions)) if tool_calls: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "The raw endpoint does not support tool calling at the moment." ) @@ -1274,7 +1274,7 @@ async def generate_from_raw( # TODO: Remove this when we are able to update the torch package. # Test this by ensuring all outputs from this call are populated when running on mps. # https://github.com/pytorch/pytorch/pull/157727 - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "utilizing device mps with a `generate_from_raw` request; you may see issues when submitting batches of prompts to a huggingface backend; ensure all ModelOutputThunks have non-empty values." ) @@ -1475,7 +1475,7 @@ def add_adapter(self, adapter: LocalHFAdapter): """ if adapter.backend is not None: if adapter.backend is self: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"attempted to add adapter {adapter.name} with type {adapter.adapter_type} to the same backend {adapter.backend}" ) return @@ -1485,7 +1485,7 @@ def add_adapter(self, adapter: LocalHFAdapter): ) if self._added_adapters.get(adapter.qualified_name) is not None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"Client code attempted to add {adapter.name} with type {adapter.adapter_type} but {adapter.name} was already added to {self.__class__}. The backend is refusing to do this, because adapter loading is not idempotent." ) return None @@ -1551,7 +1551,7 @@ def unload_adapter(self, adapter_qualified_name: str): # Check if the backend knows about this adapter. adapter = self._loaded_adapters.get(adapter_qualified_name, None) if adapter is None: - FancyLogger.get_logger().info( + MelleaLogger.get_logger().info( f"could not unload adapter {adapter_qualified_name} for backend {self}: adapter is not loaded" ) return diff --git a/mellea/backends/litellm.py b/mellea/backends/litellm.py index bf3224ad0..79610aa5e 100644 --- a/mellea/backends/litellm.py +++ b/mellea/backends/litellm.py @@ -18,7 +18,7 @@ CBlock, Component, Context, - FancyLogger, + MelleaLogger, GenerateLog, GenerateType, ModelOutputThunk, @@ -261,12 +261,12 @@ def _make_backend_specific_and_remove( unknown_keys.append(key) if len(unknown_keys) > 0: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"litellm allows for unknown / non-openai input params; mellea won't validate the following params that may cause issues: {', '.join(unknown_keys)}" ) if len(unsupported_openai_params) > 0: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"litellm may drop the following openai keys that it doesn't seem to recognize as being supported by the current model/provider: {', '.join(unsupported_openai_params)}" "\nThere are sometimes false positives here." ) @@ -339,7 +339,7 @@ async def _generate_from_chat_context_standard( model_specific_options = self._make_backend_specific_and_remove(model_opts) if self._has_potential_event_loop_errors(): - FancyLogger().get_logger().warning( + MelleaLogger().get_logger().warning( "There is a known bug with litellm. This generation call may fail. If it does, you should ensure that you are either running only synchronous Mellea functions or running async Mellea functions from one asyncio.run() call." ) @@ -570,7 +570,7 @@ def _extract_tools( tools: dict[str, AbstractMelleaTool] = dict() if tool_calls: if _format: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}" ) else: @@ -580,7 +580,7 @@ def _extract_tools( # Add the tools from the action for this generation last so that # they overwrite conflicting names. add_tools_from_context_actions(tools, [action]) - FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}") + MelleaLogger.get_logger().info(f"Tools for call: {tools.keys()}") return tools @overload @@ -636,7 +636,7 @@ async def generate_from_raw( await self.do_generate_walks(list(actions)) extra_body = {} if format is not None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "The official OpenAI completion api does not accept response format / structured decoding; " "it will be passed as an extra arg." ) @@ -644,7 +644,7 @@ async def generate_from_raw( # Some versions (like vllm's version) of the OpenAI API support structured decoding for completions requests. extra_body["guided_json"] = format.model_json_schema() # type: ignore if tool_calls: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "The completion endpoint does not support tool calling." ) @@ -653,7 +653,7 @@ async def generate_from_raw( model_specific_options = self._make_backend_specific_and_remove(model_opts) if self._has_potential_event_loop_errors(): - FancyLogger().get_logger().warning( + MelleaLogger().get_logger().warning( "There is a known bug with litellm. This generation call may fail. If it does, you should ensure that you are either running only synchronous Mellea functions or running async Mellea functions from one asyncio.run() call." ) @@ -670,7 +670,7 @@ async def generate_from_raw( date = datetime.datetime.now() responses = completion_response.choices if len(responses) != len(prompts): - FancyLogger().get_logger().error( + MelleaLogger().get_logger().error( "litellm appears to have sent your batch request as a single message; this typically happens with providers like ollama that don't support batching" ) @@ -722,7 +722,7 @@ def _extract_model_tool_requests( func = tools.get(tool_name) if func is None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"model attempted to call a non-existing function: {tool_name}" ) continue # skip this function if we can't find it. diff --git a/mellea/backends/model_options.py b/mellea/backends/model_options.py index 2350797ff..f71ddfb53 100644 --- a/mellea/backends/model_options.py +++ b/mellea/backends/model_options.py @@ -2,7 +2,7 @@ from typing import Any -from ..core import FancyLogger +from ..core import MelleaLogger class ModelOption: @@ -111,7 +111,7 @@ def replace_keys(options: dict, from_to: dict[str, str]) -> dict[str, Any]: "Encountered conflict(s) when replacing keys. Could not replace keys for:\n" + "\n".join(conflict_log) ) - FancyLogger.get_logger().warning(f"{text_line}") + MelleaLogger.get_logger().warning(f"{text_line}") return new_options @staticmethod diff --git a/mellea/backends/ollama.py b/mellea/backends/ollama.py index f682a49d2..621bc9e21 100644 --- a/mellea/backends/ollama.py +++ b/mellea/backends/ollama.py @@ -16,7 +16,7 @@ CBlock, Component, Context, - FancyLogger, + MelleaLogger, GenerateLog, GenerateType, ModelOutputThunk, @@ -89,11 +89,11 @@ def __init__( if not self._check_ollama_server(): err = f"could not create OllamaModelBackend: ollama server not running at {base_url}" - FancyLogger.get_logger().error(err) + MelleaLogger.get_logger().error(err) raise Exception(err) if not self._pull_ollama_model(): err = f"could not create OllamaModelBackend: {self._get_ollama_model_id()} could not be pulled from ollama library" - FancyLogger.get_logger().error(err) + MelleaLogger.get_logger().error(err) raise Exception(err) # A mapping of common options for this backend mapped to their Mellea ModelOptions equivalent. @@ -168,7 +168,7 @@ def _pull_ollama_model(self) -> bool: return True try: - FancyLogger.get_logger().debug( + MelleaLogger.get_logger().debug( f"Loading/Pulling model from Ollama: {self._get_ollama_model_id()}" ) stream = self._client.pull(self._get_ollama_model_id(), stream=True) @@ -376,7 +376,7 @@ async def generate_from_chat_context( tools: dict[str, AbstractMelleaTool] = dict() if tool_calls: if _format: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}" ) else: @@ -386,7 +386,7 @@ async def generate_from_chat_context( # Add the tools from the action for this generation last so that # they overwrite conflicting names. add_tools_from_context_actions(tools, [action]) - FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}") + MelleaLogger.get_logger().info(f"Tools for call: {tools.keys()}") # Generate a chat response from ollama, using the chat messages. Can be either type since stream is passed as a model option. chat_response: Coroutine[ @@ -483,11 +483,11 @@ async def generate_from_raw( list[ModelOutputThunk]: A list of model output thunks, one per action. """ if len(actions) > 1: - FancyLogger.get_logger().info( + MelleaLogger.get_logger().info( "Ollama doesn't support batching; will attempt to process concurrently." ) if tool_calls: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "The completion endpoint does not support tool calling at the moment." ) @@ -523,7 +523,7 @@ async def generate_from_raw( result = None error = None if isinstance(response, BaseException): - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"generate_from_raw: request {i} failed with " f"{type(response).__name__}: {response}" ) @@ -583,7 +583,7 @@ def _extract_model_tool_requests( for tool in chat_response.message.tool_calls: func = tools.get(tool.function.name) if func is None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"model attempted to call a non-existing function: {tool.function.name}" ) continue # skip this function if we can't find it. diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py index 90d6c3646..7b9af23a3 100644 --- a/mellea/backends/openai.py +++ b/mellea/backends/openai.py @@ -22,7 +22,7 @@ CBlock, Component, Context, - FancyLogger, + MelleaLogger, GenerateLog, GenerateType, ModelOutputThunk, @@ -175,7 +175,7 @@ def __init__( ) if self._base_url is None and os.getenv("OPENAI_BASE_URL") is None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "OPENAI_BASE_URL or base_url is not set.\n" "The openai SDK is going to assume that the base_url is `https://api.openai.com/v1`" ) @@ -481,7 +481,7 @@ async def _generate_from_chat_context_standard( }, } else: - FancyLogger().get_logger().warning( + MelleaLogger().get_logger().warning( "Mellea assumes you are NOT using the OpenAI platform, and that other model providers have less strict requirements on support JSON schemas passed into `format=`. If you encounter a server-side error following this message, then you found an exception to this assumption. Please open an issue at github.com/generative_computing/mellea with this stack trace and your inference engine / model provider." ) extra_params["response_format"] = { @@ -497,7 +497,7 @@ async def _generate_from_chat_context_standard( tools: dict[str, AbstractMelleaTool] = dict() if tool_calls: if _format: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}" ) else: @@ -507,7 +507,7 @@ async def _generate_from_chat_context_standard( # Add the tools from the action for this generation last so that # they overwrite conflicting names. add_tools_from_context_actions(tools, [action]) - FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}") + MelleaLogger.get_logger().info(f"Tools for call: {tools.keys()}") thinking = model_opts.get(ModelOption.THINKING, None) if type(thinking) is bool and thinking: @@ -796,7 +796,7 @@ async def generate_from_raw( extra_body = {} if format is not None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "The official OpenAI completion api does not accept response format / structured decoding; " "it will be passed as an extra arg." ) @@ -808,7 +808,7 @@ async def generate_from_raw( else: extra_body["guided_json"] = format.model_json_schema() # type: ignore if tool_calls: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "The completion endpoint does not support tool calling at the moment." ) @@ -832,7 +832,7 @@ async def generate_from_raw( ) # type: ignore except openai.BadRequestError as e: if openai_ollama_batching_error in e.message: - FancyLogger.get_logger().error( + MelleaLogger.get_logger().error( "If you are trying to call `OpenAIBackend._generate_from_raw while targeting an ollama server, " "your requests will fail since ollama doesn't support batching requests." ) diff --git a/mellea/backends/tools.py b/mellea/backends/tools.py index cd5c541c6..3dd056512 100644 --- a/mellea/backends/tools.py +++ b/mellea/backends/tools.py @@ -16,7 +16,7 @@ from pydantic import BaseModel, ConfigDict, Field -from mellea.core.utils import FancyLogger +from mellea.core.utils import MelleaLogger from ..core import CBlock, Component, TemplateRepresentation from ..core.base import AbstractMelleaTool @@ -97,7 +97,7 @@ def parameter_remapper(*args, **kwargs): """Langchain tools expect their first argument to be 'tool_input'.""" if args: # This shouldn't happen. Our ModelToolCall.call_func actually passes in everything as kwargs. - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"ignoring unexpected args while calling langchain tool ({tool_name}): ({args})" ) @@ -158,7 +158,7 @@ def tool_call(*args, **kwargs): """Wrapper for smolagents tool forward method.""" if args: # This shouldn't happen. Our ModelToolCall.call_func passes everything as kwargs. - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"ignoring unexpected args while calling smolagents tool ({tool_name}): ({args})" ) return tool.forward(**kwargs) @@ -480,7 +480,7 @@ def validate_tool_arguments( """ from pydantic import ValidationError, create_model - from ..core import FancyLogger + from ..core import MelleaLogger # Extract JSON schema from tool tool_schema = tool.as_json_tool.get("function", {}) @@ -581,7 +581,7 @@ def validate_tool_arguments( ) if coerced_fields and coerce_types: - FancyLogger.get_logger().debug( + MelleaLogger.get_logger().debug( f"Tool '{tool_name}' arguments coerced: {', '.join(coerced_fields)}" ) @@ -601,11 +601,11 @@ def validate_tool_arguments( if strict: # Re-raise with enhanced message - FancyLogger.get_logger().error(error_msg) + MelleaLogger.get_logger().error(error_msg) raise else: # Log warning and return original args - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( error_msg + "\nReturning original arguments without validation." ) return dict(args) @@ -615,10 +615,10 @@ def validate_tool_arguments( error_msg = f"Unexpected error validating tool '{tool_name}' arguments: {e}" if strict: - FancyLogger.get_logger().error(error_msg) + MelleaLogger.get_logger().error(error_msg) raise else: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( error_msg + "\nReturning original arguments without validation." ) return dict(args) diff --git a/mellea/backends/utils.py b/mellea/backends/utils.py index 5044a0f2a..f27f55950 100644 --- a/mellea/backends/utils.py +++ b/mellea/backends/utils.py @@ -13,7 +13,7 @@ from collections.abc import Callable from typing import Any -from ..core import CBlock, Component, Context, FancyLogger, ModelToolCall +from ..core import CBlock, Component, Context, MelleaLogger, ModelToolCall from ..core.base import AbstractMelleaTool from ..formatters import ChatFormatter from ..stdlib.components import Message @@ -76,7 +76,7 @@ def to_chat( for msg in ctx_as_conversation: for v in msg.values(): if "CBlock" in v: - FancyLogger.get_logger().error( + MelleaLogger.get_logger().error( f"Found the string `CBlock` in what should've been a stringified context: {ctx_as_conversation}" ) @@ -104,7 +104,7 @@ def to_tool_calls( for tool_name, tool_args in parse_tools(decoded_result): func = tools.get(tool_name) if func is None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"model attempted to call a non-existing function: {tool_name}" ) continue diff --git a/mellea/backends/watsonx.py b/mellea/backends/watsonx.py index 5d414f52e..7dd781563 100644 --- a/mellea/backends/watsonx.py +++ b/mellea/backends/watsonx.py @@ -21,7 +21,7 @@ CBlock, Component, Context, - FancyLogger, + MelleaLogger, GenerateLog, GenerateType, ModelOutputThunk, @@ -387,7 +387,7 @@ async def generate_from_chat_context( tools: dict[str, AbstractMelleaTool] = {} if tool_calls: if _format: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"tool calling is superseded by format; will not call tools for request: {action}" ) else: @@ -397,7 +397,7 @@ async def generate_from_chat_context( # Add the tools from the action for this generation last so that # they overwrite conflicting names. add_tools_from_context_actions(tools, [action]) - FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}") + MelleaLogger.get_logger().info(f"Tools for call: {tools.keys()}") formatted_tools = convert_tools_to_json(tools) @@ -676,7 +676,7 @@ async def generate_from_raw( await self.do_generate_walks(list(actions)) if format is not None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "WatsonxAI completion api does not accept response format, ignoring it for this request." ) @@ -743,7 +743,7 @@ def _extract_model_tool_requests( func = tools.get(tool_name) if func is None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"model attempted to call a non-existing function: {tool_name}" ) continue # skip this function if we can't find it. diff --git a/mellea/core/__init__.py b/mellea/core/__init__.py index 34facdd7e..ec275296d 100644 --- a/mellea/core/__init__.py +++ b/mellea/core/__init__.py @@ -30,7 +30,7 @@ from .formatter import Formatter from .requirement import Requirement, ValidationResult, default_output_to_bool from .sampling import SamplingResult, SamplingStrategy -from .utils import FancyLogger, clear_log_context, log_context, set_log_context +from .utils import MelleaLogger, clear_log_context, log_context, set_log_context __all__ = [ "Backend", @@ -42,7 +42,7 @@ "ComputedModelOutputThunk", "Context", "ContextTurn", - "FancyLogger", + "MelleaLogger", "Formatter", "GenerateLog", "GenerateType", diff --git a/mellea/core/backend.py b/mellea/core/backend.py index 82f9fae5f..a1ea18f13 100644 --- a/mellea/core/backend.py +++ b/mellea/core/backend.py @@ -21,7 +21,7 @@ from ..plugins.manager import has_plugins, invoke_hook from ..plugins.types import HookType from .base import C, CBlock, Component, Context, ModelOutputThunk -from .utils import FancyLogger +from .utils import MelleaLogger # Necessary to define a type variable that has a default value. # This is because VSCode's pyright static type checker instantiates @@ -180,7 +180,7 @@ async def do_generate_walk( coroutines = [x.avalue() for x in _to_compute] # The following log message might get noisy. Feel free to remove if so. if len(_to_compute) > 0: - FancyLogger.get_logger().info( + MelleaLogger.get_logger().info( f"generate_from_chat_context awaited on {len(_to_compute)} uncomputed mots." ) await asyncio.gather(*coroutines) @@ -202,7 +202,7 @@ async def do_generate_walks( coroutines = [x.avalue() for x in _to_compute] # The following log message might get noisy. Feel free to remove if so. if len(_to_compute) > 0: - FancyLogger.get_logger().info( + MelleaLogger.get_logger().info( f"generate_from_chat_context awaited on {len(_to_compute)} uncomputed mots." ) await asyncio.gather(*coroutines) diff --git a/mellea/core/utils.py b/mellea/core/utils.py index d652aa40a..ef73800b1 100644 --- a/mellea/core/utils.py +++ b/mellea/core/utils.py @@ -1,22 +1,20 @@ """Logging utilities for the mellea core library. -Provides ``FancyLogger``, a singleton logger with colour-coded console output and +Provides ``MelleaLogger``, a singleton logger with colour-coded console output and an optional REST handler (``RESTHandler``) that forwards log records to a local -``/api/receive`` endpoint when the ``FLOG`` environment variable is set. All -internal mellea modules obtain their logger via ``FancyLogger.get_logger()``. +``/api/receive`` endpoint when the ``MELLEA_FLOG`` environment variable is set. All +internal mellea modules obtain their logger via ``MelleaLogger.get_logger()``. Environment variables --------------------- ``MELLEA_LOG_LEVEL`` Minimum log level name (e.g. ``DEBUG``, ``INFO``, ``WARNING``). Defaults to - ``INFO``. The legacy ``DEBUG`` variable is still honoured as a fallback. + ``INFO``. ``MELLEA_LOG_JSON`` Set to any truthy value (``1``, ``true``, ``yes``) to emit structured JSON on the console instead of the colour-coded human-readable format. -``FLOG`` +``MELLEA_FLOG`` When set, log records are forwarded to a local REST endpoint. -``DEBUG`` - Legacy flag — equivalent to ``MELLEA_LOG_LEVEL=DEBUG``. """ import contextlib @@ -35,7 +33,7 @@ # --------------------------------------------------------------------------- _context_local: threading.local = threading.local() -# Lock used to make FancyLogger singleton initialisation thread-safe. +# Lock used to make MelleaLogger singleton initialisation thread-safe. _logger_lock: threading.Lock = threading.Lock() # Standard LogRecord attribute names that must not be overwritten by callers. @@ -168,7 +166,7 @@ def filter(self, record: logging.LogRecord) -> bool: class RESTHandler(logging.Handler): """Logging handler that forwards records to a local REST endpoint. - Sends log records as JSON to ``/api/receive`` when the ``FLOG`` environment + Sends log records as JSON to ``/api/receive`` when the ``MELLEA_FLOG`` environment variable is set. Failures are silently suppressed to avoid disrupting the application. @@ -189,14 +187,14 @@ def __init__( self.headers = headers or {"Content-Type": "application/json"} def emit(self, record: logging.LogRecord) -> None: - """Forwards a log record to the REST endpoint when the ``FLOG`` environment variable is set. + """Forwards a log record to the REST endpoint when the ``MELLEA_FLOG`` environment variable is set. Silently suppresses any network or HTTP errors to avoid disrupting the application. Args: record (logging.LogRecord): The log record to forward. """ - if os.environ.get("FLOG"): + if os.environ.get("MELLEA_FLOG"): formatter = self.formatter if isinstance(formatter, JsonFormatter): log_dict = formatter._build_log_dict(record) @@ -375,13 +373,13 @@ def format(self, record: logging.LogRecord) -> str: return formatter.format(record) -class FancyLogger: +class MelleaLogger: """Singleton logger with colour-coded console output and optional REST forwarding. - Obtain the shared logger instance via ``FancyLogger.get_logger()``. Log level - defaults to ``INFO`` but can be raised to ``DEBUG`` by setting the ``DEBUG`` - environment variable. When the ``FLOG`` environment variable is set, records are - also forwarded to a local ``/api/receive`` REST endpoint via ``RESTHandler``. + Obtain the shared logger instance via ``MelleaLogger.get_logger()``. Log level + defaults to ``INFO`` but can be overridden via ``MELLEA_LOG_LEVEL``. When the + ``MELLEA_FLOG`` environment variable is set, records are also forwarded to a + local ``/api/receive`` REST endpoint via ``RESTHandler``. Attributes: logger (logging.Logger | None): The shared ``logging.Logger`` instance; ``None`` until first call to ``get_logger()``. @@ -410,8 +408,7 @@ class FancyLogger: def _resolve_log_level() -> int: """Resolves the effective log level from environment variables. - Checks ``MELLEA_LOG_LEVEL`` first, then falls back to the legacy - ``DEBUG`` flag, then defaults to ``INFO``. + Checks ``MELLEA_LOG_LEVEL`` and defaults to ``INFO``. Returns: int: A :mod:`logging` level integer. @@ -421,13 +418,11 @@ def _resolve_log_level() -> int: numeric = getattr(logging, level_name, None) if isinstance(numeric, int): return numeric - if os.environ.get("DEBUG"): - return FancyLogger.DEBUG - return FancyLogger.INFO + return MelleaLogger.INFO @staticmethod def get_logger() -> logging.Logger: - """Returns a FancyLogger.logger and sets level based upon env vars. + """Returns a MelleaLogger.logger and sets level based upon env vars. The logger is created once (singleton). Subsequent calls return the cached instance. Initialisation is protected by a module-level lock so @@ -436,11 +431,11 @@ def get_logger() -> logging.Logger: Returns: Configured logger with REST, stream, and optional OTLP handlers. """ - if FancyLogger.logger is None: + if MelleaLogger.logger is None: with _logger_lock: # Second check inside the lock: another thread may have finished # initialisation while we were waiting. - if FancyLogger.logger is None: + if MelleaLogger.logger is None: logger = logging.getLogger("fancy_logger") # Attach the context filter so ContextFilter fields reach all handlers @@ -448,7 +443,7 @@ def get_logger() -> logging.Logger: # Only set default level if user hasn't already configured it if logger.level == logging.NOTSET: - logger.setLevel(FancyLogger._resolve_log_level()) + logger.setLevel(MelleaLogger._resolve_log_level()) # --- REST handler --- api_url = "http://localhost:8000/api/receive" @@ -477,5 +472,5 @@ def get_logger() -> logging.Logger: otlp_handler.setFormatter(JsonFormatter()) logger.addHandler(otlp_handler) - FancyLogger.logger = logger - return FancyLogger.logger + MelleaLogger.logger = logger + return MelleaLogger.logger diff --git a/mellea/formatters/template_formatter.py b/mellea/formatters/template_formatter.py index 90c806297..ba967da64 100644 --- a/mellea/formatters/template_formatter.py +++ b/mellea/formatters/template_formatter.py @@ -19,7 +19,7 @@ from ..backends.cache import SimpleLRUCache from ..backends.model_ids import ModelIdentifier -from ..core import CBlock, Component, FancyLogger, TemplateRepresentation +from ..core import CBlock, Component, MelleaLogger, TemplateRepresentation from .chat_formatter import ChatFormatter @@ -104,7 +104,7 @@ def _stringify( stringified_template_args[key] = self._stringify(val) if representation.obj is None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"template formatter encountered a TemplateRepresentation with no obj when stringifying {c.__class__}; setting obj to {c}" ) representation.obj = c @@ -127,7 +127,7 @@ def _stringify( return stringified_list case _: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"formatter encountered an unexpected type in _stringify; using str() on {c.__class__}" ) return str(c) @@ -161,7 +161,7 @@ def _load_template(self, repr: TemplateRepresentation) -> jinja2.Template: return jinja2.Environment().from_string(repr.template) # type: ignore if repr.template_order is None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"using template formatter for {repr.obj.__class__.__name__} but no template order was provided, defaulting to class name" ) repr.template_order = ["*"] diff --git a/mellea/helpers/openai_compatible_helpers.py b/mellea/helpers/openai_compatible_helpers.py index 4858ea366..b8d5506a5 100644 --- a/mellea/helpers/openai_compatible_helpers.py +++ b/mellea/helpers/openai_compatible_helpers.py @@ -5,7 +5,7 @@ from typing import Any from ..backends.tools import validate_tool_arguments -from ..core import FancyLogger, ModelToolCall +from ..core import MelleaLogger, ModelToolCall from ..core.base import AbstractMelleaTool from ..stdlib.components import Document, Message @@ -33,7 +33,7 @@ def extract_model_tool_requests( func = tools.get(tool_name) if func is None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"model attempted to call a non-existing function: {tool_name}" ) continue # skip this function if we can't find it. diff --git a/mellea/stdlib/components/react.py b/mellea/stdlib/components/react.py index bee9eaf4a..8f08fe8a0 100644 --- a/mellea/stdlib/components/react.py +++ b/mellea/stdlib/components/react.py @@ -20,7 +20,7 @@ ModelOutputThunk, TemplateRepresentation, ) -from mellea.core.utils import FancyLogger +from mellea.core.utils import MelleaLogger MELLEA_FINALIZER_TOOL = "final_answer" """Used in the react loop to symbolize the loop is done.""" @@ -66,7 +66,7 @@ def format_for_llm(self) -> TemplateRepresentation: tools = {tool.name: tool for tool in self.tools} if tools.get(MELLEA_FINALIZER_TOOL, None) is not None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"overriding user tool '{MELLEA_FINALIZER_TOOL}' in react call; this tool name is required for internal use" ) diff --git a/mellea/stdlib/frameworks/react.py b/mellea/stdlib/frameworks/react.py index f924ad7bb..117af4866 100644 --- a/mellea/stdlib/frameworks/react.py +++ b/mellea/stdlib/frameworks/react.py @@ -11,7 +11,7 @@ from mellea.backends.model_options import ModelOption from mellea.core.backend import Backend, BaseModelSubclass from mellea.core.base import AbstractMelleaTool, ComputedModelOutputThunk -from mellea.core.utils import FancyLogger +from mellea.core.utils import MelleaLogger from mellea.stdlib import functional as mfuncs # from mellea.stdlib.components.docs.document import Document @@ -79,7 +79,7 @@ async def react( turn_num = 0 while (turn_num < loop_budget) or (loop_budget == -1): turn_num += 1 - FancyLogger.get_logger().info(f"## ReACT TURN NUMBER {turn_num}") + MelleaLogger.get_logger().info(f"## ReACT TURN NUMBER {turn_num}") step, next_context = await mfuncs.aact( action=ReactThought(), diff --git a/mellea/stdlib/functional.py b/mellea/stdlib/functional.py index 335546e71..a3d361714 100644 --- a/mellea/stdlib/functional.py +++ b/mellea/stdlib/functional.py @@ -17,7 +17,7 @@ Component, ComputedModelOutputThunk, Context, - FancyLogger, + MelleaLogger, GenerateLog, ImageBlock, ModelOutputThunk, @@ -441,7 +441,7 @@ def transform( if chosen_tool is None: chosen_tool = tools[0] - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"multiple tool calls returned in transform of {obj} with description '{transformation}'; picked `{chosen_tool.name}`" # type: ignore ) @@ -449,12 +449,12 @@ def transform( if chosen_tool: # Tell the user the function they should've called if no generated values were added. if len(chosen_tool._tool.args.keys()) == 0: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"the transform of {obj} with transformation description '{transformation}' resulted in a tool call with no generated arguments; consider calling the function `{chosen_tool._tool.name}` directly" ) new_ctx.add(chosen_tool) - FancyLogger.get_logger().info( + MelleaLogger.get_logger().info( "added a tool message from transform to the context" ) return chosen_tool._tool_output, new_ctx @@ -575,7 +575,7 @@ async def aact( tool_calls=tool_calls, ) as span: if not silence_context_type_warning and not isinstance(context, SimpleContext): - FancyLogger().get_logger().warning( + MelleaLogger().get_logger().warning( "Not using a SimpleContext with asynchronous requests could cause unexpected results due to stale contexts. Ensure you await between requests." "\nSee the async section of the docs: https://docs.mellea.ai/how-to/use-async-and-streaming" ) @@ -618,7 +618,7 @@ async def aact( if strategy is None: # Only use the strategy if one is provided. Add a warning if requirements were passed in though. if requirements is not None and len(requirements) > 0: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "Calling the function with NO strategy BUT requirements. No requirement is being checked!" ) @@ -1201,7 +1201,7 @@ async def atransform( if chosen_tool is None: chosen_tool = tools[0] - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"multiple tool calls returned in transform of {obj} with description '{transformation}'; picked `{chosen_tool.name}`" # type: ignore ) @@ -1209,12 +1209,12 @@ async def atransform( if chosen_tool: # Tell the user the function they should've called if no generated values were added. if len(chosen_tool._tool.args.keys()) == 0: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"the transform of {obj} with transformation description '{transformation}' resulted in a tool call with no generated arguments; consider calling the function `{chosen_tool._tool.name}` directly" ) new_ctx.add(chosen_tool) - FancyLogger.get_logger().info( + MelleaLogger.get_logger().info( "added a tool message from transform to the context" ) return chosen_tool._tool_output, new_ctx diff --git a/mellea/stdlib/requirements/python_reqs.py b/mellea/stdlib/requirements/python_reqs.py index ad2d6a74b..3152acb71 100644 --- a/mellea/stdlib/requirements/python_reqs.py +++ b/mellea/stdlib/requirements/python_reqs.py @@ -9,9 +9,9 @@ UnsafeEnvironment, ) -from ...core import Context, FancyLogger, Requirement, ValidationResult +from ...core import Context, MelleaLogger, Requirement, ValidationResult -logger = FancyLogger.get_logger() +logger = MelleaLogger.get_logger() # region code extraction diff --git a/mellea/stdlib/requirements/requirement.py b/mellea/stdlib/requirements/requirement.py index 3ec5e8164..098151771 100644 --- a/mellea/stdlib/requirements/requirement.py +++ b/mellea/stdlib/requirements/requirement.py @@ -4,7 +4,7 @@ from collections.abc import Callable from typing import Any, overload -from ...core import CBlock, Context, FancyLogger, Requirement, ValidationResult +from ...core import CBlock, Context, MelleaLogger, Requirement, ValidationResult from ..components.intrinsic import Intrinsic @@ -37,7 +37,7 @@ def requirement_check_to_bool(x: CBlock | str) -> bool: likelihood = req_dict.get("requirement_likelihood", None) if likelihood is None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"could not get value from alora requirement output; looking for `requirement_likelihood` in {req_dict}" ) return False @@ -173,7 +173,7 @@ def simple_validate( def validate(ctx: Context) -> ValidationResult: o = ctx.last_output() if o is None or o.value is None: - FancyLogger.get_logger().warn( + MelleaLogger.get_logger().warn( "Last output of context was None. That might be a problem. We return validation as False to be able to continue..." ) return ValidationResult( diff --git a/mellea/stdlib/requirements/safety/guardian.py b/mellea/stdlib/requirements/safety/guardian.py index 26cd632af..e90e86157 100644 --- a/mellea/stdlib/requirements/safety/guardian.py +++ b/mellea/stdlib/requirements/safety/guardian.py @@ -9,7 +9,7 @@ BaseModelSubclass, CBlock, Context, - FancyLogger, + MelleaLogger, Requirement, ValidationResult, ) @@ -197,7 +197,7 @@ def __init__( except Exception: pass - self._logger = FancyLogger.get_logger() + self._logger = MelleaLogger.get_logger() def get_effective_risk(self) -> str: """Return the effective risk criteria to use for validation. diff --git a/mellea/stdlib/sampling/base.py b/mellea/stdlib/sampling/base.py index 9dca02087..65c77cd89 100644 --- a/mellea/stdlib/sampling/base.py +++ b/mellea/stdlib/sampling/base.py @@ -25,7 +25,7 @@ Component, ComputedModelOutputThunk, Context, - FancyLogger, + MelleaLogger, Requirement, S, SamplingResult, @@ -143,7 +143,7 @@ async def sample( """ validation_ctx = validation_ctx if validation_ctx is not None else context - flog = FancyLogger.get_logger() + flog = MelleaLogger.get_logger() sampled_results: list[ComputedModelOutputThunk] = [] sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] @@ -152,7 +152,7 @@ async def sample( # The `logging_redirect_tqdm` approach did not work, so instead we will use the show_progress # flag to determine whether we should show the pbar. - show_progress = show_progress and flog.getEffectiveLevel() <= FancyLogger.INFO + show_progress = show_progress and flog.getEffectiveLevel() <= MelleaLogger.INFO reqs = [] # global requirements supersede local requirements (global requirements can be defined by user) diff --git a/mellea/stdlib/sampling/budget_forcing.py b/mellea/stdlib/sampling/budget_forcing.py index 25faccdca..78e5a5efe 100644 --- a/mellea/stdlib/sampling/budget_forcing.py +++ b/mellea/stdlib/sampling/budget_forcing.py @@ -11,7 +11,7 @@ Component, ComputedModelOutputThunk, Context, - FancyLogger, + MelleaLogger, Requirement, S, SamplingResult, @@ -125,7 +125,7 @@ async def sample( """ validation_ctx = validation_ctx if validation_ctx is not None else context - flog = FancyLogger.get_logger() + flog = MelleaLogger.get_logger() sampled_results: list[ComputedModelOutputThunk] = [] sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] @@ -134,7 +134,7 @@ async def sample( # The `logging_redirect_tqdm` approach did not work, so instead we will use the show_progress # flag to determine whether we should show the pbar. - show_progress = show_progress and flog.getEffectiveLevel() <= FancyLogger.INFO + show_progress = show_progress and flog.getEffectiveLevel() <= MelleaLogger.INFO reqs = [] # global requirements supersede local requirements (global requirements can be defined by user) diff --git a/mellea/stdlib/sampling/sofai.py b/mellea/stdlib/sampling/sofai.py index ff1c8f41f..e50eecb94 100644 --- a/mellea/stdlib/sampling/sofai.py +++ b/mellea/stdlib/sampling/sofai.py @@ -20,7 +20,7 @@ Component, ComputedModelOutputThunk, Context, - FancyLogger, + MelleaLogger, Requirement, S, SamplingResult, @@ -431,7 +431,7 @@ def _prepare_s2_context( Returns: Tuple of (action_for_s2, context_for_s2). """ - flog = FancyLogger.get_logger() + flog = MelleaLogger.get_logger() if s2_mode == "fresh_start": # Clean slate: same prompt as S1 @@ -605,7 +605,7 @@ async def sample( "SOFAI requires ChatContext for conversation management." ) - flog = FancyLogger.get_logger() + flog = MelleaLogger.get_logger() reqs: list[Requirement] = list(requirements) if requirements else [] # State tracking for all attempts @@ -627,7 +627,7 @@ async def sample( next_action = deepcopy(action) next_context: Context = context - show_progress = flog.getEffectiveLevel() <= FancyLogger.INFO + show_progress = flog.getEffectiveLevel() <= MelleaLogger.INFO loop_iterator = ( tqdm.tqdm(range(self.loop_budget), desc="S1 Solver") if show_progress diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index d18cfbdcf..7b2cf26bc 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -25,7 +25,7 @@ Component, ComputedModelOutputThunk, Context, - FancyLogger, + MelleaLogger, GenerateLog, ImageBlock, ModelOutputThunk, @@ -191,7 +191,7 @@ def start_session( session.cleanup() ``` """ - logger = FancyLogger.get_logger() + logger = MelleaLogger.get_logger() # Get model_id string for logging and tracing if isinstance(model_id, ModelIdentifier): @@ -307,7 +307,7 @@ def __init__(self, backend: Backend, ctx: Context | None = None): self.id = str(uuid.uuid4()) self.backend = backend self.ctx: Context = ctx if ctx is not None else SimpleContext() - self._session_logger = FancyLogger.get_logger() + self._session_logger = MelleaLogger.get_logger() self._context_token = None self._session_span = None diff --git a/mellea/stdlib/tools/interpreter.py b/mellea/stdlib/tools/interpreter.py index 6ee9c17bf..4c537cc32 100644 --- a/mellea/stdlib/tools/interpreter.py +++ b/mellea/stdlib/tools/interpreter.py @@ -19,9 +19,9 @@ from pathlib import Path from typing import Any -from ...core import FancyLogger +from ...core import MelleaLogger -logger = FancyLogger.get_logger() +logger = MelleaLogger.get_logger() @dataclass diff --git a/test/conftest.py b/test/conftest.py index 48b6c7d48..22415fcbd 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -6,7 +6,7 @@ import pytest import requests -from mellea.core import FancyLogger +from mellea.core import MelleaLogger # Try to import optional dependencies for system detection try: @@ -302,7 +302,9 @@ def cleanup_gpu_backend(backend, backend_name="unknown"): backend: The backend instance to clean up. backend_name: Name for logging. """ - logger = FancyLogger.get_logger() + import gc + + logger = MelleaLogger.get_logger() logger.info(f"Cleaning up {backend_name} backend GPU memory...") try: @@ -455,7 +457,7 @@ def pytest_collection_modifyitems(config, items): # Reorder tests by backend if requested if config.getoption("--group-by-backend", default=False): - logger = FancyLogger.get_logger() + logger = MelleaLogger.get_logger() logger.info("Grouping tests by backend (--group-by-backend enabled)") # Group items by backend @@ -517,7 +519,7 @@ def pytest_runtest_setup(item): prev_group = getattr(pytest_runtest_setup, "_last_backend_group", None) if prev_group is not None and current_group != prev_group: - logger = FancyLogger.get_logger() + logger = MelleaLogger.get_logger() logger.info( f"Backend transition: {prev_group} → {current_group}. " "Running GPU cleanup." @@ -527,7 +529,7 @@ def pytest_runtest_setup(item): # Warm up Ollama models when entering Ollama group if current_group == "ollama" and prev_group != "ollama": - logger = FancyLogger.get_logger() + logger = MelleaLogger.get_logger() host_str = os.environ.get("OLLAMA_HOST", "127.0.0.1") port = os.environ.get("OLLAMA_PORT", "11434") logger.info( @@ -551,7 +553,7 @@ def pytest_runtest_setup(item): # Evict Ollama models when leaving Ollama group if prev_group == "ollama" and current_group != "ollama": - logger = FancyLogger.get_logger() + logger = MelleaLogger.get_logger() host_str = os.environ.get("OLLAMA_HOST", "127.0.0.1") port = os.environ.get("OLLAMA_PORT", "11434") logger.info("Evicting ollama models from VRAM after ollama group...") diff --git a/test/core/test_utils_logging.py b/test/core/test_utils_logging.py index 241283c6a..75333f866 100644 --- a/test/core/test_utils_logging.py +++ b/test/core/test_utils_logging.py @@ -1,4 +1,4 @@ -"""Unit tests for FancyLogger, JsonFormatter, and ContextFilter enhancements.""" +"""Unit tests for MelleaLogger, JsonFormatter, and ContextFilter enhancements.""" import json import logging @@ -10,7 +10,7 @@ from mellea.core.utils import ( _RESERVED_LOG_RECORD_ATTRS, ContextFilter, - FancyLogger, + MelleaLogger, JsonFormatter, clear_log_context, log_context, @@ -186,9 +186,9 @@ def test_filter_noop_when_no_context(self) -> None: assert not hasattr(record, "trace_id") -class TestFancyLoggerLogLevel: +class TestMelleaLoggerLogLevel: def _reset(self) -> None: - FancyLogger.logger = None + MelleaLogger.logger = None logging.getLogger("fancy_logger").handlers.clear() logging.getLogger("fancy_logger").setLevel(logging.NOTSET) @@ -197,40 +197,26 @@ def teardown_method(self) -> None: def test_default_level_is_info(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.delenv("MELLEA_LOG_LEVEL", raising=False) - monkeypatch.delenv("DEBUG", raising=False) - assert FancyLogger._resolve_log_level() == logging.INFO + assert MelleaLogger._resolve_log_level() == logging.INFO def test_mellea_log_level_env_var(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("MELLEA_LOG_LEVEL", "DEBUG") - assert FancyLogger._resolve_log_level() == logging.DEBUG + assert MelleaLogger._resolve_log_level() == logging.DEBUG def test_mellea_log_level_warning(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("MELLEA_LOG_LEVEL", "WARNING") - assert FancyLogger._resolve_log_level() == logging.WARNING - - def test_legacy_debug_flag(self, monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.delenv("MELLEA_LOG_LEVEL", raising=False) - monkeypatch.setenv("DEBUG", "1") - assert FancyLogger._resolve_log_level() == logging.DEBUG - - def test_mellea_log_level_takes_precedence_over_debug( - self, monkeypatch: pytest.MonkeyPatch - ) -> None: - monkeypatch.setenv("MELLEA_LOG_LEVEL", "WARNING") - monkeypatch.setenv("DEBUG", "1") - assert FancyLogger._resolve_log_level() == logging.WARNING + assert MelleaLogger._resolve_log_level() == logging.WARNING def test_invalid_level_falls_back_to_info( self, monkeypatch: pytest.MonkeyPatch ) -> None: monkeypatch.setenv("MELLEA_LOG_LEVEL", "BOGUS") - monkeypatch.delenv("DEBUG", raising=False) - assert FancyLogger._resolve_log_level() == logging.INFO + assert MelleaLogger._resolve_log_level() == logging.INFO -class TestFancyLoggerJsonConsole: +class TestMelleaLoggerJsonConsole: def _reset(self) -> None: - FancyLogger.logger = None + MelleaLogger.logger = None logger = logging.getLogger("fancy_logger") logger.handlers.clear() logger.setLevel(logging.NOTSET) @@ -242,7 +228,7 @@ def teardown_method(self) -> None: self._reset() def _get_stream_handler(self) -> logging.StreamHandler: # type: ignore[type-arg] - logger = FancyLogger.get_logger() + logger = MelleaLogger.get_logger() handlers = [h for h in logger.handlers if isinstance(h, logging.StreamHandler)] # RESTHandler is a subclass of Handler but not StreamHandler, so this # correctly picks the console handler. @@ -286,21 +272,21 @@ def test_json_console_disabled_with_false( assert isinstance(handler.formatter, CustomFormatter) -class TestFancyLoggerContextFilterWired: +class TestMelleaLoggerContextFilterWired: def setup_method(self) -> None: - FancyLogger.logger = None + MelleaLogger.logger = None logging.getLogger("fancy_logger").handlers.clear() logging.getLogger("fancy_logger").setLevel(logging.NOTSET) clear_log_context() def teardown_method(self) -> None: - FancyLogger.logger = None + MelleaLogger.logger = None logging.getLogger("fancy_logger").handlers.clear() logging.getLogger("fancy_logger").setLevel(logging.NOTSET) clear_log_context() def test_context_filter_present(self) -> None: - logger = FancyLogger.get_logger() + logger = MelleaLogger.get_logger() assert any(isinstance(f, ContextFilter) for f in logger.filters) @@ -391,12 +377,12 @@ def test_reserved_set_is_non_empty(self) -> None: class TestGetLoggerThreadSafety: def setup_method(self) -> None: - FancyLogger.logger = None + MelleaLogger.logger = None logging.getLogger("fancy_logger").handlers.clear() logging.getLogger("fancy_logger").setLevel(logging.NOTSET) def teardown_method(self) -> None: - FancyLogger.logger = None + MelleaLogger.logger = None logging.getLogger("fancy_logger").handlers.clear() logging.getLogger("fancy_logger").setLevel(logging.NOTSET) @@ -407,7 +393,7 @@ def test_concurrent_get_logger_returns_same_instance(self) -> None: def worker() -> None: barrier.wait() - results.append(FancyLogger.get_logger()) + results.append(MelleaLogger.get_logger()) threads = [threading.Thread(target=worker) for _ in range(4)] for t in threads: diff --git a/test/telemetry/test_logging.py b/test/telemetry/test_logging.py index 595def8c0..b8cddbce8 100644 --- a/test/telemetry/test_logging.py +++ b/test/telemetry/test_logging.py @@ -27,14 +27,14 @@ def _reset_logging_modules(): import mellea.core.utils import mellea.telemetry.logging - from mellea.core.utils import FancyLogger + from mellea.core.utils import MelleaLogger # Clear any existing handlers from previous tests fancy_logger = logging.getLogger("fancy_logger") fancy_logger.handlers.clear() - # Reset FancyLogger singleton - FancyLogger.logger = None + # Reset MelleaLogger singleton + MelleaLogger.logger = None # Force reload of logging module and core.utils to pick up env vars importlib.reload(mellea.telemetry.logging) @@ -160,28 +160,28 @@ def test_get_otlp_log_handler_can_be_added_to_logger(enable_otlp_logging): logger.removeHandler(handler) -# FancyLogger Integration Tests +# MelleaLogger Integration Tests def test_fancy_logger_includes_otlp_handler_when_enabled(enable_otlp_logging): - """Test that FancyLogger includes OTLP handler when enabled.""" - from mellea.core.utils import FancyLogger + """Test that MelleaLogger includes OTLP handler when enabled.""" + from mellea.core.utils import MelleaLogger - logger = FancyLogger.get_logger() + logger = MelleaLogger.get_logger() # Check that logger has handlers assert len(logger.handlers) > 0 # Check if any handler is a LoggingHandler (OTLP) has_otlp_handler = any(isinstance(h, LoggingHandler) for h in logger.handlers) # type: ignore - assert has_otlp_handler, "FancyLogger should have OTLP handler when enabled" + assert has_otlp_handler, "MelleaLogger should have OTLP handler when enabled" def test_fancy_logger_works_without_otlp(clean_logging_env): - """Test that FancyLogger works normally when OTLP is disabled.""" - from mellea.core.utils import FancyLogger + """Test that MelleaLogger works normally when OTLP is disabled.""" + from mellea.core.utils import MelleaLogger - logger = FancyLogger.get_logger() + logger = MelleaLogger.get_logger() # Should still have REST and console handlers assert len(logger.handlers) >= 2 @@ -189,7 +189,7 @@ def test_fancy_logger_works_without_otlp(clean_logging_env): # Should not have OTLP handler has_otlp_handler = any(isinstance(h, LoggingHandler) for h in logger.handlers) # type: ignore assert not has_otlp_handler, ( - "FancyLogger should not have OTLP handler when disabled" + "MelleaLogger should not have OTLP handler when disabled" ) # Verify logger can log messages (backward compatibility) From 2ca30b46b1aea098315ff7d5e7d283bc31b07c6b Mon Sep 17 00:00:00 2001 From: AngeloDanducci Date: Wed, 15 Apr 2026 01:51:01 -0400 Subject: [PATCH 04/11] use contextvars --- mellea/core/utils.py | 51 ++++++++++++++++----------------- test/core/test_utils_logging.py | 50 ++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 26 deletions(-) diff --git a/mellea/core/utils.py b/mellea/core/utils.py index ef73800b1..89bc1a41f 100644 --- a/mellea/core/utils.py +++ b/mellea/core/utils.py @@ -18,6 +18,7 @@ """ import contextlib +import contextvars import json import logging import os @@ -29,9 +30,11 @@ import requests # --------------------------------------------------------------------------- -# Thread-local storage for per-request context fields +# Per-task/coroutine context fields (safe for asyncio — each Task gets its own copy) # --------------------------------------------------------------------------- -_context_local: threading.local = threading.local() +_log_context: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar( + "log_context_fields", default={} +) # Lock used to make MelleaLogger singleton initialisation thread-safe. _logger_lock: threading.Lock = threading.Lock() @@ -65,7 +68,7 @@ def set_log_context(**fields: Any) -> None: - """Inject extra fields into every log record emitted from this thread. + """Inject extra fields into every log record emitted from this coroutine or thread. Call this at the start of a request or task to attach identifiers such as ``trace_id`` or ``request_id`` without modifying individual log calls. @@ -88,23 +91,23 @@ def set_log_context(**fields: Any) -> None: f"Context field names clash with LogRecord reserved attributes: " f"{sorted(invalid)}. Choose different names." ) - existing: dict[str, Any] = getattr(_context_local, "fields", {}) - existing.update(fields) - _context_local.fields = existing + _log_context.set({**_log_context.get(), **fields}) def clear_log_context() -> None: - """Remove all thread-local log context fields set by :func:`set_log_context`.""" - _context_local.fields = {} + """Remove all context fields set by :func:`set_log_context` for this coroutine/thread.""" + _log_context.set({}) @contextlib.contextmanager def log_context(**fields: Any) -> Generator[None, None, None]: """Context manager that injects *fields* for the duration of the block. - On exit — including on exceptions — removes only the keys that *this* - invocation set, leaving any enclosing context intact. This makes nested - and thread-pool usage safe without requiring manual cleanup. + On exit — including on exceptions — the context is restored to its state + before the block via a ``ContextVar`` token. This is safe for both nested + usage and concurrent asyncio tasks: each ``asyncio.Task`` owns an isolated + copy of the context variable, so coroutines running on the same event-loop + thread cannot overwrite each other's fields. Example:: @@ -123,21 +126,17 @@ def log_context(**fields: Any) -> Generator[None, None, None]: Raises: ValueError: If any key clashes with a reserved ``LogRecord`` attribute. """ - existing_before: dict[str, Any] = getattr(_context_local, "fields", {}) - previous: dict[str, Any] = { - k: existing_before[k] for k in fields if k in existing_before - } - set_log_context(**fields) + invalid = frozenset(fields) & _RESERVED_LOG_RECORD_ATTRS + if invalid: + raise ValueError( + f"Context field names clash with LogRecord reserved attributes: " + f"{sorted(invalid)}. Choose different names." + ) + token = _log_context.set({**_log_context.get(), **fields}) try: yield finally: - existing: dict[str, Any] = getattr(_context_local, "fields", {}) - for key in fields: - if key in previous: - existing[key] = previous[key] - else: - existing.pop(key, None) - _context_local.fields = existing + _log_context.reset(token) class ContextFilter(logging.Filter): @@ -157,7 +156,7 @@ def filter(self, record: logging.LogRecord) -> bool: Returns: bool: Always ``True`` — the record is never suppressed. """ - fields: dict[str, Any] = getattr(_context_local, "fields", {}) + fields: dict[str, Any] = _log_context.get() for key, value in fields.items(): setattr(record, key, value) return True @@ -306,9 +305,9 @@ def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]: log_record.update(self._extra) # Dynamic context fields — prefer record attributes (set by - # ContextFilter) but fall back to thread-local storage so the + # ContextFilter) but fall back to ContextVar storage so the # formatter works standalone without a filter attached. - context_fields: dict[str, Any] = getattr(_context_local, "fields", {}) + context_fields: dict[str, Any] = _log_context.get() for key, value in context_fields.items(): log_record[key] = getattr(record, key, value) diff --git a/test/core/test_utils_logging.py b/test/core/test_utils_logging.py index 75333f866..945c797e7 100644 --- a/test/core/test_utils_logging.py +++ b/test/core/test_utils_logging.py @@ -1,5 +1,6 @@ """Unit tests for MelleaLogger, JsonFormatter, and ContextFilter enhancements.""" +import asyncio import json import logging import threading @@ -350,6 +351,55 @@ def test_rejects_reserved_attribute(self) -> None: pass +class TestLogContextAsyncIsolation: + """Verify that concurrent asyncio tasks cannot contaminate each other's context.""" + + def setup_method(self) -> None: + clear_log_context() + + def teardown_method(self) -> None: + clear_log_context() + + def test_concurrent_tasks_isolated(self) -> None: + """Fields set inside one asyncio.Task must not bleed into a sibling task.""" + fmt = JsonFormatter() + results: dict[str, Any] = {} + + async def task_a() -> None: + with log_context(trace_id="task-a"): + # Yield so task_b can run and attempt to overwrite the context + await asyncio.sleep(0) + results["a"] = json.loads(fmt.format(_make_record())) + + async def task_b() -> None: + with log_context(trace_id="task-b"): + await asyncio.sleep(0) + results["b"] = json.loads(fmt.format(_make_record())) + + async def run() -> None: + await asyncio.gather(asyncio.create_task(task_a()), asyncio.create_task(task_b())) + + asyncio.run(run()) + + assert results["a"].get("trace_id") == "task-a" + assert results["b"].get("trace_id") == "task-b" + + def test_task_context_does_not_leak_after_completion(self) -> None: + """A task's context fields must not persist into the caller after the task ends.""" + fmt = JsonFormatter() + + async def child() -> None: + set_log_context(trace_id="child-task") + + async def run() -> None: + await asyncio.create_task(child()) + # The caller's context should be unaffected + return json.loads(fmt.format(_make_record())) + + parsed = asyncio.run(run()) + assert "trace_id" not in parsed + + class TestReservedAttributeValidation: def setup_method(self) -> None: clear_log_context() From e6fbaacc088c47fe8b39448cce6a9eec3e7bdd88 Mon Sep 17 00:00:00 2001 From: AngeloDanducci Date: Wed, 15 Apr 2026 01:57:30 -0400 Subject: [PATCH 05/11] address feedback from alex --- mellea/core/utils.py | 42 ++++++++++++++++++++++++++------- test/core/test_utils_logging.py | 34 ++++++++++++++++++++++++-- 2 files changed, 66 insertions(+), 10 deletions(-) diff --git a/mellea/core/utils.py b/mellea/core/utils.py index 89bc1a41f..bafa06b01 100644 --- a/mellea/core/utils.py +++ b/mellea/core/utils.py @@ -40,7 +40,7 @@ _logger_lock: threading.Lock = threading.Lock() # Standard LogRecord attribute names that must not be overwritten by callers. -_RESERVED_LOG_RECORD_ATTRS: frozenset[str] = frozenset( +RESERVED_LOG_RECORD_ATTRS: frozenset[str] = frozenset( ( "args", "created", @@ -85,7 +85,7 @@ def set_log_context(**fields: Any) -> None: ValueError: If any key clashes with a standard ``logging.LogRecord`` attribute (e.g. ``levelname``, ``module``, ``thread``). """ - invalid = frozenset(fields) & _RESERVED_LOG_RECORD_ATTRS + invalid = frozenset(fields) & RESERVED_LOG_RECORD_ATTRS if invalid: raise ValueError( f"Context field names clash with LogRecord reserved attributes: " @@ -126,7 +126,7 @@ def log_context(**fields: Any) -> Generator[None, None, None]: Raises: ValueError: If any key clashes with a reserved ``LogRecord`` attribute. """ - invalid = frozenset(fields) & _RESERVED_LOG_RECORD_ATTRS + invalid = frozenset(fields) & RESERVED_LOG_RECORD_ATTRS if invalid: raise ValueError( f"Context field names clash with LogRecord reserved attributes: " @@ -196,7 +196,7 @@ def emit(self, record: logging.LogRecord) -> None: if os.environ.get("MELLEA_FLOG"): formatter = self.formatter if isinstance(formatter, JsonFormatter): - log_dict = formatter._build_log_dict(record) + log_dict = formatter.format_as_dict(record) else: log_dict = {"message": self.format(record)} try: @@ -221,8 +221,12 @@ class JsonFormatter(logging.Formatter): Args: timestamp_format: ``strftime`` format for the ``timestamp`` field. Defaults to ISO-8601 (``"%Y-%m-%dT%H:%M:%S"``). - include_fields: Whitelist of core field names to keep. When ``None`` - all core fields are included. + include_fields: Whitelist of **core** field names to keep. When ``None`` + all core fields are included. Note: this filter applies only to the + fields listed in ``_DEFAULT_FIELDS``; ``extra_fields`` passed to the + constructor and dynamic context fields (set via + :func:`set_log_context`) are **always** included regardless of this + setting. exclude_fields: Set of core field names to drop. Applied after *include_fields*. extra_fields: Static key-value pairs merged into every log record. @@ -268,6 +272,20 @@ def __init__( self._exclude: frozenset[str] = frozenset(exclude_fields or []) self._extra: dict[str, Any] = dict(extra_fields or {}) + def format_as_dict(self, record: logging.LogRecord) -> dict[str, Any]: + """Return the log record as a dictionary (public API for external callers). + + Equivalent to :meth:`_build_log_dict` but part of the public interface so + handlers and other callers do not need to reach into private methods. + + Args: + record: The log record to convert. + + Returns: + A dictionary ready for JSON serialisation. + """ + return self._build_log_dict(record) + def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]: """Build a log record dictionary with core, extra, and context fields. @@ -277,11 +295,19 @@ def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]: Returns: A dictionary ready for JSON serialisation. """ - # Build the full set of core fields first + # Build the full set of core fields first. + # A TypeError here means the caller used %-style format placeholders + # with the wrong number of arguments (e.g. logger.info("%s %s", one)). + # Catch it and substitute a safe error string so the record is still emitted. + try: + message = record.getMessage() + except TypeError as exc: + message = f" original msg={record.msg!r}" + all_core: dict[str, Any] = { "timestamp": self.formatTime(record, self.datefmt), "level": record.levelname, - "message": record.getMessage(), + "message": message, "module": record.module, "function": record.funcName, "line_number": record.lineno, diff --git a/test/core/test_utils_logging.py b/test/core/test_utils_logging.py index 945c797e7..738819c95 100644 --- a/test/core/test_utils_logging.py +++ b/test/core/test_utils_logging.py @@ -1,5 +1,7 @@ """Unit tests for MelleaLogger, JsonFormatter, and ContextFilter enhancements.""" +# pytest: unit + import asyncio import json import logging @@ -8,8 +10,10 @@ import pytest +pytestmark = pytest.mark.unit + from mellea.core.utils import ( - _RESERVED_LOG_RECORD_ATTRS, + RESERVED_LOG_RECORD_ATTRS, ContextFilter, MelleaLogger, JsonFormatter, @@ -187,6 +191,7 @@ def test_filter_noop_when_no_context(self) -> None: assert not hasattr(record, "trace_id") +@pytest.mark.unit class TestMelleaLoggerLogLevel: def _reset(self) -> None: MelleaLogger.logger = None @@ -215,6 +220,7 @@ def test_invalid_level_falls_back_to_info( assert MelleaLogger._resolve_log_level() == logging.INFO +@pytest.mark.unit class TestMelleaLoggerJsonConsole: def _reset(self) -> None: MelleaLogger.logger = None @@ -273,6 +279,7 @@ def test_json_console_disabled_with_false( assert isinstance(handler.formatter, CustomFormatter) +@pytest.mark.unit class TestMelleaLoggerContextFilterWired: def setup_method(self) -> None: MelleaLogger.logger = None @@ -422,9 +429,10 @@ def test_set_log_context_accepts_non_reserved(self) -> None: assert parsed["custom_field"] == "fine" def test_reserved_set_is_non_empty(self) -> None: - assert len(_RESERVED_LOG_RECORD_ATTRS) > 10 + assert len(RESERVED_LOG_RECORD_ATTRS) > 10 +@pytest.mark.unit class TestGetLoggerThreadSafety: def setup_method(self) -> None: MelleaLogger.logger = None @@ -464,6 +472,28 @@ def test_format_returns_str(self) -> None: json.loads(result) # also valid JSON +class TestFormatAsDict: + def test_format_as_dict_returns_same_as_internal(self) -> None: + fmt = JsonFormatter() + record = _make_record("hi") + assert fmt.format_as_dict(record) == fmt._build_log_dict(record) + + def test_format_as_dict_malformed_args_does_not_raise(self) -> None: + fmt = JsonFormatter() + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="x.py", + lineno=1, + msg="val: %s %s", + args=("only_one",), + exc_info=None, + ) + result = fmt.format_as_dict(record) + assert "message" in result + assert "format error" in result["message"] + + class TestIncludeFieldsValidation: def test_valid_include_fields_accepted(self) -> None: fmt = JsonFormatter(include_fields=["timestamp", "level", "message"]) From 208ecd0463663402c8c16e68c0f54bb4f6bb6d89 Mon Sep 17 00:00:00 2001 From: AngeloDanducci Date: Wed, 15 Apr 2026 02:30:24 -0400 Subject: [PATCH 06/11] additional changes based on feedback --- .agents/skills/mellea-logging/SKILL.md | 141 ++++++ .../evaluation-and-observability/logging.md | 49 ++ mellea/stdlib/sampling/base.py | 420 +++++++++--------- mellea/stdlib/session.py | 10 +- test/core/test_logger_plugin_hooks.py | 239 ++++++++++ test/core/test_utils_logging.py | 8 +- 6 files changed, 657 insertions(+), 210 deletions(-) create mode 100644 .agents/skills/mellea-logging/SKILL.md create mode 100644 test/core/test_logger_plugin_hooks.py diff --git a/.agents/skills/mellea-logging/SKILL.md b/.agents/skills/mellea-logging/SKILL.md new file mode 100644 index 000000000..d479ab14c --- /dev/null +++ b/.agents/skills/mellea-logging/SKILL.md @@ -0,0 +1,141 @@ +--- +name: mellea-logging +description: > + Best-practices guide for adding or reviewing logging in the Mellea codebase. + Covers when to use log_context() vs a dedicated logger call, canonical field + names, reserved attribute constraints, async/thread safety, and what events + deserve dedicated log lines. + Use when: adding a new log call; reviewing a PR that touches MelleaLogger; + deciding where to inject context fields; debugging why a field is missing from + a log record; or ensuring consistency with the project logging conventions. +argument-hint: "[file-or-directory]" +compatibility: "Claude Code, IBM Bob" +metadata: + version: "2026-04-15" + capabilities: [read_file, grep, glob] +--- + +# Mellea Logging Best Practices + +All logging in Mellea flows through `MelleaLogger.get_logger()`, defined in +`mellea/core/utils.py`. This skill documents the conventions for adding and +reviewing log instrumentation. + +## Quick reference + +```python +from mellea.core import MelleaLogger, log_context, set_log_context, clear_log_context + +logger = MelleaLogger.get_logger() + +# Dedicated log call — for a discrete event +logger.info("SUCCESS") + +# Context injection — attach fields to every record in a scope +with log_context(request_id="req-abc", trace_id="t-1"): + logger.info("Starting generation") # includes request_id, trace_id + # ... all nested calls inherit these fields automatically +``` + +## When to add a dedicated log call + +Use `logger.info/warning/error(...)` for **discrete, named events**: + +| Event type | Level | Example | +|------------|-------|---------| +| Phase transition | INFO | `"SUCCESS"`, `"FAILED"`, `"Starting session"` | +| Loop progress | INFO | `"Running loop 2 of 3"` | +| Recoverable issue | WARNING | `"Warmup failed for model: ..."` | +| Unexpected failure | ERROR | exception tracebacks, hard failures | +| Verbose diagnostics | DEBUG | token counts, prompt previews | + +Do **not** add log calls for: + +- Values already captured in a `log_context` field (redundant noise) +- Internal helper functions where the calling function already logs the event +- State that is already reflected in telemetry spans + +## When to use log_context + +Use `log_context` (or `set_log_context`) to attach **identifiers and metadata +that should appear on every log record within a scope** — without threading +them through every call. + +Typical injection points: + +| Scope | Where to inject | Fields | +|-------|----------------|--------| +| Session lifetime | `MelleaSession.__enter__` | `session_id`, `backend`, `model_id` | +| Sampling loop | `BaseSamplingStrategy.sample()` | `strategy`, `loop_budget` | +| HTTP request handler | entry point of the handler | `request_id`, `trace_id` | +| Background task | top of the task coroutine | `task_id`, `job_name` | + +## Canonical field names + +Use these names consistently. Do not invent synonyms. + +| Field | Type | Description | +|-------|------|-------------| +| `session_id` | str (UUID) | Unique ID for a `MelleaSession` | +| `backend` | str | Backend class name, e.g. `"OllamaModelBackend"` | +| `model_id` | str | Model identifier string | +| `strategy` | str | Sampling strategy class name | +| `loop_budget` | int | Max generate/validate cycles for this sampling call | +| `request_id` | str | Caller-supplied request identifier | +| `trace_id` | str | Distributed trace ID (from OpenTelemetry or caller) | +| `span_id` | str | Span ID within a trace | +| `user_id` | str | End-user identifier (when applicable) | + +## Reserved attribute names — do not use as context fields + +The following names are standard `logging.LogRecord` attributes. Passing them +to `log_context()` or `set_log_context()` raises `ValueError`. See +`RESERVED_LOG_RECORD_ATTRS` in `mellea/core/utils.py` for the full set. + +`args`, `created`, `exc_info`, `exc_text`, `filename`, `funcName`, +`levelname`, `levelno`, `lineno`, `message`, `module`, `msecs`, `msg`, +`name`, `pathname`, `process`, `processName`, `relativeCreated`, +`stack_info`, `thread`, `threadName` + +## Prefer the context manager over set/clear + +```python +# Preferred — guaranteed cleanup even on exceptions +with log_context(trace_id="abc"): + do_work() + +# Acceptable only when lifetime equals __enter__/__exit__ +# (e.g. MelleaSession, where the CM already guarantees cleanup) +set_log_context(session_id=self.id) +# ... later in __exit__ ... +clear_log_context() +``` + +The context manager uses a `ContextVar` token to restore the previous state +on exit. This means **nesting works correctly** — inner calls can add fields +without clobbering the outer scope's values. + +## Async and thread safety + +`log_context` uses `contextvars.ContextVar`, which is safe for concurrent +asyncio tasks: + +- Each `asyncio.Task` gets its own copy of the context. +- Fields set in one task do not bleed into sibling tasks. + +**Plugin hooks**: Mellea hooks (`AUDIT`, `SEQUENTIAL`, `CONCURRENT`) are +`await`ed in the same asyncio task as the call site. `ContextVar` state IS +inherited — fields set around a `strategy.sample()` call will appear on +records emitted inside hook handlers automatically. + +## Checklist before committing + +1. New log calls use `MelleaLogger.get_logger()`, not `logging.getLogger(...)`. +2. Context fields use canonical names from the table above. +3. No reserved attribute names passed to `log_context`. +4. Scoped fields use `with log_context(...)`, not `set_log_context` (unless + managing an `__enter__`/`__exit__` pair). +5. Hook handlers that need context set it internally — they do not inherit the + caller's context. +6. New events that span multiple log records inject fields via context, not by + repeating them on every `logger.info(...)` call. diff --git a/docs/docs/evaluation-and-observability/logging.md b/docs/docs/evaluation-and-observability/logging.md index 854aec956..672593a9f 100644 --- a/docs/docs/evaluation-and-observability/logging.md +++ b/docs/docs/evaluation-and-observability/logging.md @@ -51,6 +51,55 @@ Each message is formatted as: message ``` +## Sample output + +### Console format (default) + +Running `m.instruct(...)` inside a session produces lines like: + +```text +=== 11:11:25-INFO ====== +SUCCESS +``` + +### JSON format (`MELLEA_LOG_JSON=1`) + +With structured JSON output enabled, the same `SUCCESS` record looks like: + +```json +{ + "timestamp": "2026-04-08T11:11:25", + "level": "INFO", + "message": "SUCCESS", + "module": "base", + "function": "sample", + "line_number": 258, + "process_id": 73738, + "thread_id": 6179762176, + "session_id": "550e8400-e29b-41d4-a716-446655440000", + "backend": "OllamaModelBackend", + "model_id": "granite4:micro", + "strategy": "RejectionSamplingStrategy", + "loop_budget": 3 +} +``` + +The `session_id`, `backend`, `model_id`, `strategy`, and `loop_budget` fields +are injected automatically when the call runs inside a `with session:` block. +They appear on every log record within that scope — not just `SUCCESS`. + +### Adding custom context fields + +Use `log_context` to attach your own fields for the duration of a block: + +```python +from mellea.core import log_context + +with log_context(request_id="req-abc", user_id="usr-42"): + result = m.instruct("Summarise this document") + # Every log record emitted here will include request_id and user_id +``` + ## OTLP log export When the `[telemetry]` extra is installed, Mellea can export logs to an OTLP diff --git a/mellea/stdlib/sampling/base.py b/mellea/stdlib/sampling/base.py index 65c77cd89..388417c3b 100644 --- a/mellea/stdlib/sampling/base.py +++ b/mellea/stdlib/sampling/base.py @@ -31,6 +31,7 @@ SamplingResult, SamplingStrategy, ValidationResult, + log_context, ) from ...plugins.manager import has_plugins, invoke_hook from ...plugins.types import HookType @@ -145,234 +146,241 @@ async def sample( flog = MelleaLogger.get_logger() - sampled_results: list[ComputedModelOutputThunk] = [] - sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] - sampled_actions: list[Component] = [] - sample_contexts: list[Context] = [] - - # The `logging_redirect_tqdm` approach did not work, so instead we will use the show_progress - # flag to determine whether we should show the pbar. - show_progress = show_progress and flog.getEffectiveLevel() <= MelleaLogger.INFO - - reqs = [] - # global requirements supersede local requirements (global requirements can be defined by user) - # Todo: re-evaluate if this makes sense - if self.requirements is not None: - reqs += self.requirements - elif requirements is not None: - reqs += requirements - reqs = list(set(reqs)) - - loop_count = 0 - - # --- sampling_loop_start hook --- - effective_loop_budget = self.loop_budget - if has_plugins(HookType.SAMPLING_LOOP_START): - from ...plugins.hooks.sampling import SamplingLoopStartPayload - - start_payload = SamplingLoopStartPayload( - strategy_name=type(self).__name__, - action=action, - context=context, - requirements=reqs, - loop_budget=self.loop_budget, - ) - _, start_payload = await invoke_hook( - HookType.SAMPLING_LOOP_START, start_payload, backend=backend - ) - effective_loop_budget = start_payload.loop_budget + with log_context(strategy=type(self).__name__, loop_budget=self.loop_budget): + sampled_results: list[ComputedModelOutputThunk] = [] + sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] + sampled_actions: list[Component] = [] + sample_contexts: list[Context] = [] - loop_budget_range_iterator = ( - tqdm.tqdm(range(effective_loop_budget)) # type: ignore - if show_progress - else range(effective_loop_budget) # type: ignore - ) - - next_action = deepcopy(action) - next_context = context - for _ in loop_budget_range_iterator: # type: ignore - loop_count += 1 - if not show_progress: - flog.info(f"Running loop {loop_count} of {self.loop_budget}") - - # run a generation pass - result, result_ctx = await backend.generate_from_context( - next_action, - ctx=next_context, - format=format, - model_options=model_options, - tool_calls=tool_calls, - ) - await result.avalue() - result = ComputedModelOutputThunk(result) - - # Sampling strategies may use different components from the original - # action. This might cause discrepancies in the expected parsed_repr - # type / value. Explicitly overwrite that here. - # TODO: See if there's a more elegant way for this so that each sampling - # strategy doesn't have to re-implement it. - result.parsed_repr = action.parse(result) - - # validation pass - val_scores_co = mfuncs.avalidate( - reqs=reqs, - context=result_ctx, - backend=backend, - output=result, - format=None, - model_options=model_options, - # tool_calls=tool_calls # Don't support using tool calls in validation strategies. + # The `logging_redirect_tqdm` approach did not work, so instead we will use the show_progress + # flag to determine whether we should show the pbar. + show_progress = ( + show_progress and flog.getEffectiveLevel() <= MelleaLogger.INFO ) - val_scores = await val_scores_co - - # match up reqs with scores - constraint_scores = list(zip(reqs, val_scores)) - - # collect all data - sampled_results.append(result) - sampled_scores.append(constraint_scores) - sampled_actions.append(next_action) - sample_contexts.append(result_ctx) - - all_validations_passed = all(bool(s[1]) for s in constraint_scores) - - # --- sampling_iteration hook --- - if has_plugins(HookType.SAMPLING_ITERATION): - from ...plugins.hooks.sampling import SamplingIterationPayload - - iter_payload = SamplingIterationPayload( - iteration=loop_count, - action=next_action, - result=result, - validation_results=constraint_scores, - all_validations_passed=all_validations_passed, - valid_count=sum(1 for s in constraint_scores if bool(s[1])), - total_count=len(constraint_scores), + + reqs = [] + # global requirements supersede local requirements (global requirements can be defined by user) + # Todo: re-evaluate if this makes sense + if self.requirements is not None: + reqs += self.requirements + elif requirements is not None: + reqs += requirements + reqs = list(set(reqs)) + + loop_count = 0 + + # --- sampling_loop_start hook --- + effective_loop_budget = self.loop_budget + if has_plugins(HookType.SAMPLING_LOOP_START): + from ...plugins.hooks.sampling import SamplingLoopStartPayload + + start_payload = SamplingLoopStartPayload( + strategy_name=type(self).__name__, + action=action, + context=context, + requirements=reqs, + loop_budget=self.loop_budget, ) - await invoke_hook( - HookType.SAMPLING_ITERATION, iter_payload, backend=backend + _, start_payload = await invoke_hook( + HookType.SAMPLING_LOOP_START, start_payload, backend=backend ) + effective_loop_budget = start_payload.loop_budget - # if all vals are true -- break and return success - if all_validations_passed: - flog.info("SUCCESS") - assert ( - result._generate_log is not None - ) # Cannot be None after generation. - result._generate_log.is_final_result = True + loop_budget_range_iterator = ( + tqdm.tqdm(range(effective_loop_budget)) # type: ignore + if show_progress + else range(effective_loop_budget) # type: ignore + ) - # --- sampling_loop_end hook (success) --- - if has_plugins(HookType.SAMPLING_LOOP_END): - from ...plugins.hooks.sampling import SamplingLoopEndPayload + next_action = deepcopy(action) + next_context = context + for _ in loop_budget_range_iterator: # type: ignore + loop_count += 1 + if not show_progress: + flog.info(f"Running loop {loop_count} of {self.loop_budget}") + + # run a generation pass + result, result_ctx = await backend.generate_from_context( + next_action, + ctx=next_context, + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) + await result.avalue() + result = ComputedModelOutputThunk(result) + + # Sampling strategies may use different components from the original + # action. This might cause discrepancies in the expected parsed_repr + # type / value. Explicitly overwrite that here. + # TODO: See if there's a more elegant way for this so that each sampling + # strategy doesn't have to re-implement it. + result.parsed_repr = action.parse(result) + + # validation pass + val_scores_co = mfuncs.avalidate( + reqs=reqs, + context=result_ctx, + backend=backend, + output=result, + format=None, + model_options=model_options, + # tool_calls=tool_calls # Don't support using tool calls in validation strategies. + ) + val_scores = await val_scores_co + + # match up reqs with scores + constraint_scores = list(zip(reqs, val_scores)) + + # collect all data + sampled_results.append(result) + sampled_scores.append(constraint_scores) + sampled_actions.append(next_action) + sample_contexts.append(result_ctx) + + all_validations_passed = all(bool(s[1]) for s in constraint_scores) + + # --- sampling_iteration hook --- + if has_plugins(HookType.SAMPLING_ITERATION): + from ...plugins.hooks.sampling import SamplingIterationPayload + + iter_payload = SamplingIterationPayload( + iteration=loop_count, + action=next_action, + result=result, + validation_results=constraint_scores, + all_validations_passed=all_validations_passed, + valid_count=sum(1 for s in constraint_scores if bool(s[1])), + total_count=len(constraint_scores), + ) + await invoke_hook( + HookType.SAMPLING_ITERATION, iter_payload, backend=backend + ) - end_payload = SamplingLoopEndPayload( + # if all vals are true -- break and return success + if all_validations_passed: + flog.info("SUCCESS") + assert ( + result._generate_log is not None + ) # Cannot be None after generation. + result._generate_log.is_final_result = True + + # --- sampling_loop_end hook (success) --- + if has_plugins(HookType.SAMPLING_LOOP_END): + from ...plugins.hooks.sampling import SamplingLoopEndPayload + + end_payload = SamplingLoopEndPayload( + success=True, + iterations_used=loop_count, + final_result=result, + final_action=next_action, + final_context=result_ctx, + all_results=sampled_results, + all_validations=sampled_scores, + ) + await invoke_hook( + HookType.SAMPLING_LOOP_END, end_payload, backend=backend + ) + + # SUCCESS !!!! + return SamplingResult( + result_index=len(sampled_results) - 1, success=True, - iterations_used=loop_count, - final_result=result, - final_action=next_action, - final_context=result_ctx, - all_results=sampled_results, - all_validations=sampled_scores, + sample_generations=sampled_results, + sample_validations=sampled_scores, + sample_contexts=sample_contexts, + sample_actions=sampled_actions, ) - await invoke_hook( - HookType.SAMPLING_LOOP_END, end_payload, backend=backend + + else: + # log partial success and continue + failed = [s for s in constraint_scores if not bool(s[1])] + count_failed = len(failed) + failed_reqs = [ + r[0].description + if r[0].description is not None + else "[no description]" + for r in failed + ] + stringify_failed = "\n\t - " + "\n\t - ".join(failed_reqs) + flog.info( + f"FAILED. Valid: {len(constraint_scores) - count_failed}/{len(constraint_scores)}. Failed: {stringify_failed}" ) - # SUCCESS !!!! - return SamplingResult( - result_index=len(sampled_results) - 1, - success=True, - sample_generations=sampled_results, - sample_validations=sampled_scores, - sample_contexts=sample_contexts, - sample_actions=sampled_actions, + # If we did not pass all constraints, update the instruction and try again. + next_action, next_context = self.repair( + next_context, + result_ctx, + sampled_actions, + sampled_results, + sampled_scores, ) - else: - # log partial success and continue - failed = [s for s in constraint_scores if not bool(s[1])] - count_failed = len(failed) - failed_reqs = [ - r[0].description - if r[0].description is not None - else "[no description]" - for r in failed - ] - stringify_failed = "\n\t - " + "\n\t - ".join(failed_reqs) - flog.info( - f"FAILED. Valid: {len(constraint_scores) - count_failed}/{len(constraint_scores)}. Failed: {stringify_failed}" - ) + # --- sampling_repair hook --- + if has_plugins(HookType.SAMPLING_REPAIR): + from ...plugins.hooks.sampling import SamplingRepairPayload + + repair_payload = SamplingRepairPayload( + repair_type=getattr( + self, "_get_repair_type", lambda: "unknown" + )(), + failed_action=sampled_actions[-1], + failed_result=sampled_results[-1], + failed_validations=sampled_scores[-1], + repair_action=next_action, + repair_context=next_context, + repair_iteration=loop_count, + ) + await invoke_hook( + HookType.SAMPLING_REPAIR, repair_payload, backend=backend + ) - # If we did not pass all constraints, update the instruction and try again. - next_action, next_context = self.repair( - next_context, - result_ctx, - sampled_actions, - sampled_results, - sampled_scores, + flog.info( + f"Invoking select_from_failure after {len(sampled_results)} failed attempts." ) - # --- sampling_repair hook --- - if has_plugins(HookType.SAMPLING_REPAIR): - from ...plugins.hooks.sampling import SamplingRepairPayload - - repair_payload = SamplingRepairPayload( - repair_type=getattr(self, "_get_repair_type", lambda: "unknown")(), - failed_action=sampled_actions[-1], - failed_result=sampled_results[-1], - failed_validations=sampled_scores[-1], - repair_action=next_action, - repair_context=next_context, - repair_iteration=loop_count, - ) - await invoke_hook( - HookType.SAMPLING_REPAIR, repair_payload, backend=backend - ) - - flog.info( - f"Invoking select_from_failure after {len(sampled_results)} failed attempts." - ) + # if no valid result could be determined, find a last resort. + best_failed_index = self.select_from_failure( + sampled_actions, sampled_results, sampled_scores + ) + assert best_failed_index < len(sampled_results), ( + "The select_from_failure method did not return a valid result. It has to selected from failed_results." + ) - # if no valid result could be determined, find a last resort. - best_failed_index = self.select_from_failure( - sampled_actions, sampled_results, sampled_scores - ) - assert best_failed_index < len(sampled_results), ( - "The select_from_failure method did not return a valid result. It has to selected from failed_results." - ) + assert ( + sampled_results[best_failed_index]._generate_log is not None + ) # Cannot be None after generation. + sampled_results[best_failed_index]._generate_log.is_final_result = True # type: ignore - assert ( - sampled_results[best_failed_index]._generate_log is not None - ) # Cannot be None after generation. - sampled_results[best_failed_index]._generate_log.is_final_result = True # type: ignore + # --- sampling_loop_end hook (failure) --- + if has_plugins(HookType.SAMPLING_LOOP_END): + from ...plugins.hooks.sampling import SamplingLoopEndPayload - # --- sampling_loop_end hook (failure) --- - if has_plugins(HookType.SAMPLING_LOOP_END): - from ...plugins.hooks.sampling import SamplingLoopEndPayload + _final_ctx = ( + sample_contexts[best_failed_index] if sample_contexts else context + ) + end_payload = SamplingLoopEndPayload( + success=False, + iterations_used=loop_count, + final_result=sampled_results[best_failed_index], + final_action=sampled_actions[best_failed_index], + final_context=_final_ctx, + failure_reason=f"Budget exhausted after {loop_count} iterations", + all_results=sampled_results, + all_validations=sampled_scores, + ) + await invoke_hook( + HookType.SAMPLING_LOOP_END, end_payload, backend=backend + ) - _final_ctx = ( - sample_contexts[best_failed_index] if sample_contexts else context - ) - end_payload = SamplingLoopEndPayload( + return SamplingResult( + result_index=best_failed_index, success=False, - iterations_used=loop_count, - final_result=sampled_results[best_failed_index], - final_action=sampled_actions[best_failed_index], - final_context=_final_ctx, - failure_reason=f"Budget exhausted after {loop_count} iterations", - all_results=sampled_results, - all_validations=sampled_scores, + sample_generations=sampled_results, + sample_validations=sampled_scores, + sample_actions=sampled_actions, + sample_contexts=sample_contexts, ) - await invoke_hook(HookType.SAMPLING_LOOP_END, end_payload, backend=backend) - - return SamplingResult( - result_index=best_failed_index, - success=False, - sample_generations=sampled_results, - sample_validations=sampled_scores, - sample_actions=sampled_actions, - sample_contexts=sample_contexts, - ) class RejectionSamplingStrategy(BaseSamplingStrategy): diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index 7b2cf26bc..0cace1d63 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -25,15 +25,17 @@ Component, ComputedModelOutputThunk, Context, - MelleaLogger, GenerateLog, ImageBlock, + MelleaLogger, ModelOutputThunk, Requirement, S, SamplingResult, SamplingStrategy, ValidationResult, + clear_log_context, + set_log_context, ) from ..helpers import _run_async_in_thread from ..plugins.manager import has_plugins, invoke_hook @@ -320,11 +322,17 @@ def __enter__(self): context_type=self.ctx.__class__.__name__, ).__enter__() self._context_token = _context_session.set(self) + set_log_context( + session_id=self.id, + backend=self.backend.__class__.__name__, + model_id=str(getattr(self.backend, "model_id", "unknown")), + ) return self def __exit__(self, exc_type, exc_val, exc_tb): """Exit context manager and cleanup session.""" self.cleanup() + clear_log_context() if self._context_token is not None: _context_session.reset(self._context_token) self._context_token = None diff --git a/test/core/test_logger_plugin_hooks.py b/test/core/test_logger_plugin_hooks.py new file mode 100644 index 000000000..77757f3c4 --- /dev/null +++ b/test/core/test_logger_plugin_hooks.py @@ -0,0 +1,239 @@ +"""Integration tests verifying MelleaLogger works correctly inside plugin hook dispatch. + +Key properties verified: +- MelleaLogger.get_logger() is callable and usable from inside a hook handler. +- Log records emitted from inside a hook are captured. +- log_context fields set inside a hook appear on records from that hook. +- log_context fields set in the test body do NOT bleed into hook execution + (hooks run in a different thread, so ContextVar state is not inherited). +""" + +# pytest: integration + +from __future__ import annotations + +import logging +from typing import Any + +import pytest + +pytestmark = pytest.mark.integration + +pytest.importorskip("cpex.framework") + +import datetime + +# --------------------------------------------------------------------------- +# Minimal mock backend (avoids real LLM calls) +# --------------------------------------------------------------------------- +from unittest.mock import MagicMock + +from mellea.core.backend import Backend +from mellea.core.base import ( + CBlock, + Context, + GenerateLog, + GenerateType, + ModelOutputThunk, +) +from mellea.core.utils import ( + MelleaLogger, + clear_log_context, + log_context, + set_log_context, +) +from mellea.plugins import PluginMode, hook, register +from mellea.plugins.manager import shutdown_plugins +from mellea.stdlib.context import SimpleContext + + +class _MockBackend(Backend): + model_id = "mock-model" + + def __init__(self, *args, **kwargs): + pass + + async def _generate_from_context(self, action, ctx, **kwargs): + mot = MagicMock(spec=ModelOutputThunk) + glog = GenerateLog() + glog.prompt = "mocked prompt" + mot._generate_log = glog + mot.parsed_repr = None + mot._start = datetime.datetime.now() + + async def _avalue(): + return "mocked output" + + mot.avalue = _avalue + mot.value = "mocked output" + return mot, SimpleContext() + + async def generate_from_raw(self, actions, ctx, **kwargs): + return [] + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +async def reset_plugins(): + """Shut down and reset the plugin manager after every test.""" + yield + await shutdown_plugins() + + +@pytest.fixture(autouse=True) +def reset_log_context(): + """Ensure log context is clean before and after each test.""" + clear_log_context() + yield + clear_log_context() + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestMelleaLoggerInHooks: + async def test_mellea_logger_callable_from_hook(self, caplog) -> None: + """MelleaLogger.get_logger() is usable inside a hook handler without error.""" + from mellea.stdlib.components import Instruction + from mellea.stdlib.sampling.base import RejectionSamplingStrategy + + fired: list[bool] = [] + + @hook("sampling_loop_start", mode=PluginMode.AUDIT) + async def log_hook(payload: Any, ctx: Any) -> None: + logger = MelleaLogger.get_logger() + logger.info("hook fired from MelleaLogger") + fired.append(True) + + register(log_hook) + + with caplog.at_level(logging.INFO, logger="fancy_logger"): + await RejectionSamplingStrategy(loop_budget=1).sample( + Instruction("test"), + context=SimpleContext(), + backend=_MockBackend(), + requirements=[], + format=None, + model_options=None, + tool_calls=False, + show_progress=False, + ) + + assert fired, "Hook did not fire" + assert any("hook fired from MelleaLogger" in r.message for r in caplog.records) + + async def test_log_context_set_inside_hook_appears_on_hook_records( + self, caplog + ) -> None: + """Context fields set inside a hook appear on records emitted in that hook.""" + from mellea.stdlib.components import Instruction + from mellea.stdlib.sampling.base import RejectionSamplingStrategy + + hook_records: list[logging.LogRecord] = [] + + @hook("sampling_loop_start", mode=PluginMode.AUDIT) + async def context_hook(payload: Any, ctx: Any) -> None: + with log_context(hook_trace_id="hook-abc"): + logger = MelleaLogger.get_logger() + # Emit via a plain handler so we can capture the LogRecord + record = logger.makeRecord( + name="fancy_logger", + level=logging.INFO, + fn="test", + lno=0, + msg="inside hook", + args=(), + exc_info=None, + ) + # Apply the context filter manually (as the logger would) + from mellea.core.utils import ContextFilter + + ContextFilter().filter(record) + hook_records.append(record) + + register(context_hook) + + await RejectionSamplingStrategy(loop_budget=1).sample( + Instruction("test"), + context=SimpleContext(), + backend=_MockBackend(), + requirements=[], + format=None, + model_options=None, + tool_calls=False, + show_progress=False, + ) + + assert hook_records, "Hook did not produce records" + assert getattr(hook_records[0], "hook_trace_id", None) == "hook-abc" + + async def test_log_context_is_visible_inside_hook(self) -> None: + """ContextVar state set in the caller IS visible inside hook execution. + + AUDIT hooks are awaited in the same asyncio task as the caller, so they + inherit the caller's ContextVar copy. This is the documented behaviour: + log_context fields set around a strategy.sample() call will appear on + records emitted inside hook handlers too. + """ + from mellea.stdlib.components import Instruction + from mellea.stdlib.sampling.base import RejectionSamplingStrategy + + hook_saw_context: list[bool] = [] + + set_log_context(outer_field="visible-in-hook") + + @hook("sampling_loop_start", mode=PluginMode.AUDIT) + async def visibility_hook(payload: Any, ctx: Any) -> None: + from mellea.core.utils import _log_context + + fields = _log_context.get() + hook_saw_context.append("outer_field" in fields) + + register(visibility_hook) + + await RejectionSamplingStrategy(loop_budget=1).sample( + Instruction("test"), + context=SimpleContext(), + backend=_MockBackend(), + requirements=[], + format=None, + model_options=None, + tool_calls=False, + show_progress=False, + ) + + assert hook_saw_context, "Hook did not fire" + assert hook_saw_context[0], ( + "log_context fields should be visible inside AUDIT hooks (same asyncio task)" + ) + + async def test_sampling_log_context_fields_present_on_success_record( + self, caplog + ) -> None: + """strategy and loop_budget context fields appear on the SUCCESS log record.""" + from mellea.stdlib.components import Instruction + from mellea.stdlib.sampling.base import RejectionSamplingStrategy + + with caplog.at_level(logging.INFO, logger="fancy_logger"): + await RejectionSamplingStrategy(loop_budget=2).sample( + Instruction("test"), + context=SimpleContext(), + backend=_MockBackend(), + requirements=[], + format=None, + model_options=None, + tool_calls=False, + show_progress=False, + ) + + success_records = [r for r in caplog.records if r.getMessage() == "SUCCESS"] + assert success_records, "No SUCCESS record found" + record = success_records[0] + assert getattr(record, "strategy", None) == "RejectionSamplingStrategy" + assert getattr(record, "loop_budget", None) == 2 diff --git a/test/core/test_utils_logging.py b/test/core/test_utils_logging.py index 738819c95..b25f70be3 100644 --- a/test/core/test_utils_logging.py +++ b/test/core/test_utils_logging.py @@ -15,8 +15,8 @@ from mellea.core.utils import ( RESERVED_LOG_RECORD_ATTRS, ContextFilter, - MelleaLogger, JsonFormatter, + MelleaLogger, clear_log_context, log_context, set_log_context, @@ -384,7 +384,9 @@ async def task_b() -> None: results["b"] = json.loads(fmt.format(_make_record())) async def run() -> None: - await asyncio.gather(asyncio.create_task(task_a()), asyncio.create_task(task_b())) + await asyncio.gather( + asyncio.create_task(task_a()), asyncio.create_task(task_b()) + ) asyncio.run(run()) @@ -398,7 +400,7 @@ def test_task_context_does_not_leak_after_completion(self) -> None: async def child() -> None: set_log_context(trace_id="child-task") - async def run() -> None: + async def run() -> dict[str, object]: await asyncio.create_task(child()) # The caller's context should be unaffected return json.loads(fmt.format(_make_record())) From e81f211b30b78ef433f8721fbe41190f7c1402a1 Mon Sep 17 00:00:00 2001 From: AngeloDanducci Date: Wed, 15 Apr 2026 02:54:11 -0400 Subject: [PATCH 07/11] additional fancylogger removal --- mellea/stdlib/components/genstub.py | 10 +++++----- test/conftest.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mellea/stdlib/components/genstub.py b/mellea/stdlib/components/genstub.py index 66e11e893..05ca11088 100644 --- a/mellea/stdlib/components/genstub.py +++ b/mellea/stdlib/components/genstub.py @@ -17,7 +17,7 @@ CBlock, Component, Context, - FancyLogger, + MelleaLogger, ModelOutputThunk, Requirement, SamplingStrategy, @@ -616,7 +616,7 @@ def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: # No retries if precondition validation fails. if not all(bool(val_result) for val_result in val_results): - FancyLogger.get_logger().error( + MelleaLogger.get_logger().error( "generative stub arguments did not satisfy precondition requirements" ) raise PreconditionException( @@ -625,7 +625,7 @@ def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: ) elif len(stub_copy.precondition_requirements) > 0: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "calling a generative stub with precondition requirements but no args to validate the preconditions against; ignoring precondition validation" ) @@ -755,7 +755,7 @@ async def __async_call__() -> tuple[R, Context] | R: # No retries if precondition validation fails. if not all(bool(val_result) for val_result in val_results): - FancyLogger.get_logger().error( + MelleaLogger.get_logger().error( "generative stub arguments did not satisfy precondition requirements" ) raise PreconditionException( @@ -764,7 +764,7 @@ async def __async_call__() -> tuple[R, Context] | R: ) elif len(stub_copy.precondition_requirements) > 0: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "calling a generative stub with precondition requirements but no args to validate the preconditions against; ignoring precondition validation" ) diff --git a/test/conftest.py b/test/conftest.py index 22415fcbd..4859b0121 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -615,7 +615,7 @@ def evict_ollama_models() -> None: Best-effort: errors are logged but never raised. """ - logger = FancyLogger.get_logger() + logger = MelleaLogger.get_logger() # Parse OLLAMA_HOST which may be "host", "host:port", or absent. host = os.environ.get("OLLAMA_HOST", "127.0.0.1") From af4178e9f8fc7e963760de55dec6008a726bd56d Mon Sep 17 00:00:00 2001 From: AngeloDanducci Date: Wed, 15 Apr 2026 03:50:57 -0400 Subject: [PATCH 08/11] additional mellealogger updates --- .../_subtask_constraint_assign.py | 8 ++++---- mellea/backends/huggingface.py | 2 +- mellea/backends/litellm.py | 2 +- mellea/backends/ollama.py | 2 +- mellea/backends/openai.py | 2 +- mellea/backends/watsonx.py | 2 +- mellea/core/__init__.py | 2 +- mellea/stdlib/functional.py | 2 +- 8 files changed, 11 insertions(+), 11 deletions(-) diff --git a/cli/decompose/prompt_modules/subtask_constraint_assign/_subtask_constraint_assign.py b/cli/decompose/prompt_modules/subtask_constraint_assign/_subtask_constraint_assign.py index 50a152648..2d14f0d9f 100644 --- a/cli/decompose/prompt_modules/subtask_constraint_assign/_subtask_constraint_assign.py +++ b/cli/decompose/prompt_modules/subtask_constraint_assign/_subtask_constraint_assign.py @@ -4,7 +4,7 @@ from mellea import MelleaSession from mellea.backends import ModelOption -from mellea.core import FancyLogger +from mellea.core import MelleaLogger from mellea.stdlib.components import Message from .._prompt_modules import PromptModule, PromptModuleString @@ -12,7 +12,7 @@ from ._prompt import get_system_prompt, get_user_prompt from ._types import SubtaskPromptConstraintsItem -FancyLogger.get_logger().setLevel("DEBUG") +MelleaLogger.get_logger().setLevel("DEBUG") T = TypeVar("T") @@ -175,7 +175,7 @@ def _default_parser(generated_str: str) -> list[SubtaskPromptConstraintsItem]: ).strip() else: # Fallback to raw text if tags are missing - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "Expected tags missing from LLM response; falling back to raw response text. " "Downstream stages may receive unstructured content." ) @@ -202,7 +202,7 @@ def _default_parser(generated_str: str) -> list[SubtaskPromptConstraintsItem]: # If content exists but no list items were parsed, # treat the whole text as a single constraint. if subtask_constraint_assign_str and not subtask_constraint_assign: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "No list-style constraints detected; falling back to full text as a single constraint." ) subtask_constraint_assign = [subtask_constraint_assign_str] diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index af0ba88be..64d315328 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -34,9 +34,9 @@ CBlock, Component, Context, - MelleaLogger, GenerateLog, GenerateType, + MelleaLogger, ModelOutputThunk, Requirement, ) diff --git a/mellea/backends/litellm.py b/mellea/backends/litellm.py index 79610aa5e..c9fdc1b59 100644 --- a/mellea/backends/litellm.py +++ b/mellea/backends/litellm.py @@ -18,9 +18,9 @@ CBlock, Component, Context, - MelleaLogger, GenerateLog, GenerateType, + MelleaLogger, ModelOutputThunk, ModelToolCall, ) diff --git a/mellea/backends/ollama.py b/mellea/backends/ollama.py index 621bc9e21..a81add1f5 100644 --- a/mellea/backends/ollama.py +++ b/mellea/backends/ollama.py @@ -16,9 +16,9 @@ CBlock, Component, Context, - MelleaLogger, GenerateLog, GenerateType, + MelleaLogger, ModelOutputThunk, ModelToolCall, ) diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py index 7b9af23a3..022c07da6 100644 --- a/mellea/backends/openai.py +++ b/mellea/backends/openai.py @@ -22,9 +22,9 @@ CBlock, Component, Context, - MelleaLogger, GenerateLog, GenerateType, + MelleaLogger, ModelOutputThunk, Requirement, ) diff --git a/mellea/backends/watsonx.py b/mellea/backends/watsonx.py index 7dd781563..9ffe2d540 100644 --- a/mellea/backends/watsonx.py +++ b/mellea/backends/watsonx.py @@ -21,9 +21,9 @@ CBlock, Component, Context, - MelleaLogger, GenerateLog, GenerateType, + MelleaLogger, ModelOutputThunk, ModelToolCall, ) diff --git a/mellea/core/__init__.py b/mellea/core/__init__.py index ec275296d..7ebc6e1d4 100644 --- a/mellea/core/__init__.py +++ b/mellea/core/__init__.py @@ -42,11 +42,11 @@ "ComputedModelOutputThunk", "Context", "ContextTurn", - "MelleaLogger", "Formatter", "GenerateLog", "GenerateType", "ImageBlock", + "MelleaLogger", "ModelOutputThunk", "ModelToolCall", "Requirement", diff --git a/mellea/stdlib/functional.py b/mellea/stdlib/functional.py index a3d361714..710af6203 100644 --- a/mellea/stdlib/functional.py +++ b/mellea/stdlib/functional.py @@ -17,9 +17,9 @@ Component, ComputedModelOutputThunk, Context, - MelleaLogger, GenerateLog, ImageBlock, + MelleaLogger, ModelOutputThunk, ModelToolCall, Requirement, From 1f885b80f12a9c68f700fa2fc24f35d84a7c0fb2 Mon Sep 17 00:00:00 2001 From: AngeloDanducci Date: Thu, 16 Apr 2026 08:38:03 -0400 Subject: [PATCH 09/11] additional review feedback --- mellea/backends/litellm.py | 6 +- mellea/backends/openai.py | 2 +- mellea/core/__init__.py | 15 ++ mellea/core/utils.py | 25 +- mellea/stdlib/functional.py | 2 +- mellea/stdlib/sampling/budget_forcing.py | 278 +++++++++++----------- mellea/stdlib/sampling/sofai.py | 287 ++++++++++++----------- mellea/stdlib/session.py | 19 +- test/conftest.py | 1 - test/core/test_logger_plugin_hooks.py | 4 +- test/core/test_utils_logging.py | 9 +- 11 files changed, 348 insertions(+), 300 deletions(-) diff --git a/mellea/backends/litellm.py b/mellea/backends/litellm.py index c9fdc1b59..cfa91593c 100644 --- a/mellea/backends/litellm.py +++ b/mellea/backends/litellm.py @@ -339,7 +339,7 @@ async def _generate_from_chat_context_standard( model_specific_options = self._make_backend_specific_and_remove(model_opts) if self._has_potential_event_loop_errors(): - MelleaLogger().get_logger().warning( + MelleaLogger.get_logger().warning( "There is a known bug with litellm. This generation call may fail. If it does, you should ensure that you are either running only synchronous Mellea functions or running async Mellea functions from one asyncio.run() call." ) @@ -653,7 +653,7 @@ async def generate_from_raw( model_specific_options = self._make_backend_specific_and_remove(model_opts) if self._has_potential_event_loop_errors(): - MelleaLogger().get_logger().warning( + MelleaLogger.get_logger().warning( "There is a known bug with litellm. This generation call may fail. If it does, you should ensure that you are either running only synchronous Mellea functions or running async Mellea functions from one asyncio.run() call." ) @@ -670,7 +670,7 @@ async def generate_from_raw( date = datetime.datetime.now() responses = completion_response.choices if len(responses) != len(prompts): - MelleaLogger().get_logger().error( + MelleaLogger.get_logger().error( "litellm appears to have sent your batch request as a single message; this typically happens with providers like ollama that don't support batching" ) diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py index 022c07da6..301ad45e2 100644 --- a/mellea/backends/openai.py +++ b/mellea/backends/openai.py @@ -481,7 +481,7 @@ async def _generate_from_chat_context_standard( }, } else: - MelleaLogger().get_logger().warning( + MelleaLogger.get_logger().warning( "Mellea assumes you are NOT using the OpenAI platform, and that other model providers have less strict requirements on support JSON schemas passed into `format=`. If you encounter a server-side error following this message, then you found an exception to this assumption. Please open an issue at github.com/generative_computing/mellea with this stack trace and your inference engine / model provider." ) extra_params["response_format"] = { diff --git a/mellea/core/__init__.py b/mellea/core/__init__.py index 7ebc6e1d4..f31303a80 100644 --- a/mellea/core/__init__.py +++ b/mellea/core/__init__.py @@ -32,6 +32,21 @@ from .sampling import SamplingResult, SamplingStrategy from .utils import MelleaLogger, clear_log_context, log_context, set_log_context + +def __getattr__(name: str) -> object: + if name == "FancyLogger": + import warnings + + warnings.warn( + "FancyLogger has been renamed to MelleaLogger and will be removed in a future release. " + "Update your imports to use mellea.core.MelleaLogger.", + DeprecationWarning, + stacklevel=2, + ) + return MelleaLogger + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + __all__ = [ "Backend", "BaseModelSubclass", diff --git a/mellea/core/utils.py b/mellea/core/utils.py index bafa06b01..256d9ace6 100644 --- a/mellea/core/utils.py +++ b/mellea/core/utils.py @@ -140,7 +140,7 @@ def log_context(**fields: Any) -> Generator[None, None, None]: class ContextFilter(logging.Filter): - """Logging filter that injects thread-local context fields into every record. + """Logging filter that injects async-safe ContextVar fields into every record. Fields registered via :func:`set_log_context` are copied onto the ``logging.LogRecord`` before formatters see it, enabling trace/request IDs @@ -148,7 +148,7 @@ class ContextFilter(logging.Filter): """ def filter(self, record: logging.LogRecord) -> bool: - """Attach thread-local context fields to *record* and allow it through. + """Attach async-safe ContextVar fields to *record* and allow it through. Args: record (logging.LogRecord): The log record being processed. @@ -193,7 +193,7 @@ def emit(self, record: logging.LogRecord) -> None: Args: record (logging.LogRecord): The log record to forward. """ - if os.environ.get("MELLEA_FLOG"): + if _check_flog_env(): formatter = self.formatter if isinstance(formatter, JsonFormatter): log_dict = formatter.format_as_dict(record) @@ -343,7 +343,7 @@ def format(self, record: logging.LogRecord) -> str: """Formats a log record as a JSON string. Core fields are filtered by *include_fields* / *exclude_fields*. - Static *extra_fields* and any thread-local context fields (set via + Static *extra_fields* and any per-task ContextVar fields (set via :func:`set_log_context`) are merged in after the core fields. Args: @@ -499,3 +499,20 @@ def get_logger() -> logging.Logger: MelleaLogger.logger = logger return MelleaLogger.logger + + +def _check_flog_env() -> bool: + """Check MELLEA_FLOG, with a DeprecationWarning fallback for the old FLOG name.""" + if os.environ.get("MELLEA_FLOG"): + return True + if os.environ.get("FLOG"): + import warnings + + warnings.warn( + "The FLOG environment variable is deprecated and will be removed in a future release. " + "Use MELLEA_FLOG instead.", + DeprecationWarning, + stacklevel=2, + ) + return True + return False diff --git a/mellea/stdlib/functional.py b/mellea/stdlib/functional.py index 710af6203..f4bdb2371 100644 --- a/mellea/stdlib/functional.py +++ b/mellea/stdlib/functional.py @@ -575,7 +575,7 @@ async def aact( tool_calls=tool_calls, ) as span: if not silence_context_type_warning and not isinstance(context, SimpleContext): - MelleaLogger().get_logger().warning( + MelleaLogger.get_logger().warning( "Not using a SimpleContext with asynchronous requests could cause unexpected results due to stale contexts. Ensure you await between requests." "\nSee the async section of the docs: https://docs.mellea.ai/how-to/use-async-and-streaming" ) diff --git a/mellea/stdlib/sampling/budget_forcing.py b/mellea/stdlib/sampling/budget_forcing.py index 78e5a5efe..5da1d0fae 100644 --- a/mellea/stdlib/sampling/budget_forcing.py +++ b/mellea/stdlib/sampling/budget_forcing.py @@ -16,6 +16,7 @@ S, SamplingResult, ValidationResult, + log_context, ) from ...stdlib import functional as mfuncs from .base import RejectionSamplingStrategy @@ -127,146 +128,149 @@ async def sample( flog = MelleaLogger.get_logger() - sampled_results: list[ComputedModelOutputThunk] = [] - sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] - sampled_actions: list[Component] = [] - sample_contexts: list[Context] = [] - - # The `logging_redirect_tqdm` approach did not work, so instead we will use the show_progress - # flag to determine whether we should show the pbar. - show_progress = show_progress and flog.getEffectiveLevel() <= MelleaLogger.INFO - - reqs = [] - # global requirements supersede local requirements (global requirements can be defined by user) - # Todo: re-evaluate if this makes sense - if self.requirements is not None: - reqs += self.requirements - elif requirements is not None: - reqs += requirements - reqs = list(set(reqs)) - - loop_count = 0 - loop_budget_range_iterator = ( - tqdm.tqdm(range(self.loop_budget)) # type: ignore - if show_progress - else range(self.loop_budget) # type: ignore - ) - - next_action = deepcopy(action) - next_context = context - for _ in loop_budget_range_iterator: # type: ignore - loop_count += 1 - if not show_progress: - flog.info(f"Running loop {loop_count} of {self.loop_budget}") - - # TODO - # tool_calls is not supported for budget forcing - assert tool_calls is False, ( - "tool_calls is not supported with budget forcing" - ) - # TODO - assert isinstance(backend, OllamaModelBackend), ( - "Only ollama backend supported with budget forcing" - ) - # run a generation pass with budget forcing - result = await think_budget_forcing( - backend, - next_action, - ctx=context, - format=format, - tool_calls=tool_calls, - think_max_tokens=self.think_max_tokens, - answer_max_tokens=self.answer_max_tokens, - start_think_token=self.start_think_token, - end_think_token=self.end_think_token, - think_more_suffix=self.think_more_suffix, - answer_suffix=self.answer_suffix, - model_options=model_options, + with log_context(strategy=type(self).__name__, loop_budget=self.loop_budget): + sampled_results: list[ComputedModelOutputThunk] = [] + sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] + sampled_actions: list[Component] = [] + sample_contexts: list[Context] = [] + + # The `logging_redirect_tqdm` approach did not work, so instead we will use the show_progress + # flag to determine whether we should show the pbar. + show_progress = ( + show_progress and flog.getEffectiveLevel() <= MelleaLogger.INFO ) - result_ctx = next_context - await result.avalue() - result = ComputedModelOutputThunk(result) - - # Sampling strategies may use different components from the original - # action. This might cause discrepancies in the expected parsed_repr - # type / value. Explicitly overwrite that here. - result.parsed_repr = action.parse(result) - - # validation pass - val_scores_co = mfuncs.avalidate( - reqs=reqs, - context=result_ctx, - backend=backend, - output=result, - format=format, - model_options=model_options, - # tool_calls=tool_calls # Don't support using tool calls in validation strategies. + + reqs = [] + # global requirements supersede local requirements (global requirements can be defined by user) + # Todo: re-evaluate if this makes sense + if self.requirements is not None: + reqs += self.requirements + elif requirements is not None: + reqs += requirements + reqs = list(set(reqs)) + + loop_count = 0 + loop_budget_range_iterator = ( + tqdm.tqdm(range(self.loop_budget)) # type: ignore + if show_progress + else range(self.loop_budget) # type: ignore ) - val_scores = await val_scores_co - - # match up reqs with scores - constraint_scores = list(zip(reqs, val_scores)) - - # collect all data - sampled_results.append(result) - sampled_scores.append(constraint_scores) - sampled_actions.append(next_action) - sample_contexts.append(result_ctx) - - # if all vals are true -- break and return success - if all(bool(s[1]) for s in constraint_scores): - flog.info("SUCCESS") - assert ( - result._generate_log is not None - ) # Cannot be None after generation. - result._generate_log.is_final_result = True - - # SUCCESS !!!! - return SamplingResult( - result_index=len(sampled_results) - 1, - success=True, - sample_generations=sampled_results, - sample_validations=sampled_scores, - sample_contexts=sample_contexts, - sample_actions=sampled_actions, + + next_action = deepcopy(action) + next_context = context + for _ in loop_budget_range_iterator: # type: ignore + loop_count += 1 + if not show_progress: + flog.info(f"Running loop {loop_count} of {self.loop_budget}") + + # TODO + # tool_calls is not supported for budget forcing + assert tool_calls is False, ( + "tool_calls is not supported with budget forcing" + ) + # TODO + assert isinstance(backend, OllamaModelBackend), ( + "Only ollama backend supported with budget forcing" + ) + # run a generation pass with budget forcing + result = await think_budget_forcing( + backend, + next_action, + ctx=context, + format=format, + tool_calls=tool_calls, + think_max_tokens=self.think_max_tokens, + answer_max_tokens=self.answer_max_tokens, + start_think_token=self.start_think_token, + end_think_token=self.end_think_token, + think_more_suffix=self.think_more_suffix, + answer_suffix=self.answer_suffix, + model_options=model_options, ) + result_ctx = next_context + await result.avalue() + result = ComputedModelOutputThunk(result) + + # Sampling strategies may use different components from the original + # action. This might cause discrepancies in the expected parsed_repr + # type / value. Explicitly overwrite that here. + result.parsed_repr = action.parse(result) + + # validation pass + val_scores_co = mfuncs.avalidate( + reqs=reqs, + context=result_ctx, + backend=backend, + output=result, + format=format, + model_options=model_options, + # tool_calls=tool_calls # Don't support using tool calls in validation strategies. + ) + val_scores = await val_scores_co + + # match up reqs with scores + constraint_scores = list(zip(reqs, val_scores)) + + # collect all data + sampled_results.append(result) + sampled_scores.append(constraint_scores) + sampled_actions.append(next_action) + sample_contexts.append(result_ctx) + + # if all vals are true -- break and return success + if all(bool(s[1]) for s in constraint_scores): + flog.info("SUCCESS") + assert ( + result._generate_log is not None + ) # Cannot be None after generation. + result._generate_log.is_final_result = True + + # SUCCESS !!!! + return SamplingResult( + result_index=len(sampled_results) - 1, + success=True, + sample_generations=sampled_results, + sample_validations=sampled_scores, + sample_contexts=sample_contexts, + sample_actions=sampled_actions, + ) + + else: + # log partial success and continue + count_valid = len([s for s in constraint_scores if bool(s[1])]) + flog.info(f"FAILED. Valid: {count_valid}/{len(constraint_scores)}") + + # If we did not pass all constraints, update the instruction and try again. + next_action, next_context = self.repair( + next_context, + result_ctx, + sampled_actions, + sampled_results, + sampled_scores, + ) + + flog.info( + f"Invoking select_from_failure after {len(sampled_results)} failed attempts." + ) - else: - # log partial success and continue - count_valid = len([s for s in constraint_scores if bool(s[1])]) - flog.info(f"FAILED. Valid: {count_valid}/{len(constraint_scores)}") - - # If we did not pass all constraints, update the instruction and try again. - next_action, next_context = self.repair( - next_context, - result_ctx, - sampled_actions, - sampled_results, - sampled_scores, + # if no valid result could be determined, find a last resort. + best_failed_index = self.select_from_failure( + sampled_actions, sampled_results, sampled_scores + ) + assert best_failed_index < len(sampled_results), ( + "The select_from_failure method did not return a valid result. It has to selected from failed_results." ) - flog.info( - f"Invoking select_from_failure after {len(sampled_results)} failed attempts." - ) - - # if no valid result could be determined, find a last resort. - best_failed_index = self.select_from_failure( - sampled_actions, sampled_results, sampled_scores - ) - assert best_failed_index < len(sampled_results), ( - "The select_from_failure method did not return a valid result. It has to selected from failed_results." - ) - - assert ( - sampled_results[best_failed_index]._generate_log is not None - ) # Cannot be None after generation. - sampled_results[best_failed_index]._generate_log.is_final_result = True # type: ignore - - return SamplingResult( - result_index=best_failed_index, - success=False, - sample_generations=sampled_results, - sample_validations=sampled_scores, - sample_actions=sampled_actions, - sample_contexts=sample_contexts, - ) + assert ( + sampled_results[best_failed_index]._generate_log is not None + ) # Cannot be None after generation. + sampled_results[best_failed_index]._generate_log.is_final_result = True # type: ignore + + return SamplingResult( + result_index=best_failed_index, + success=False, + sample_generations=sampled_results, + sample_validations=sampled_scores, + sample_actions=sampled_actions, + sample_contexts=sample_contexts, + ) diff --git a/mellea/stdlib/sampling/sofai.py b/mellea/stdlib/sampling/sofai.py index e50eecb94..792de2de4 100644 --- a/mellea/stdlib/sampling/sofai.py +++ b/mellea/stdlib/sampling/sofai.py @@ -27,6 +27,7 @@ SamplingStrategy, TemplateRepresentation, ValidationResult, + log_context, ) from ...stdlib import functional as mfuncs from ..components import Message @@ -606,46 +607,143 @@ async def sample( ) flog = MelleaLogger.get_logger() - reqs: list[Requirement] = list(requirements) if requirements else [] - # State tracking for all attempts - sampled_results: list[ComputedModelOutputThunk] = [] - sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] - sampled_actions: list[Component] = [] - sample_contexts: list[Context] = [] + with log_context(strategy=type(self).__name__, loop_budget=self.loop_budget): + reqs: list[Requirement] = list(requirements) if requirements else [] - # --------------------------------------------------------------------- - # PHASE 1: S1 Solver Loop - # --------------------------------------------------------------------- - flog.info( - f"SOFAI: Starting S1 Solver ({getattr(self.s1_solver_backend, 'model_id', 'unknown')}) " - f"loop (budget={self.loop_budget})" - ) + # State tracking for all attempts + sampled_results: list[ComputedModelOutputThunk] = [] + sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] + sampled_actions: list[Component] = [] + sample_contexts: list[Context] = [] - previous_failed_set: set[tuple[str | None, str | None, float | None]] = set() - loop_count = 0 - next_action = deepcopy(action) - next_context: Context = context + # --------------------------------------------------------------------- + # PHASE 1: S1 Solver Loop + # --------------------------------------------------------------------- + flog.info( + f"SOFAI: Starting S1 Solver ({getattr(self.s1_solver_backend, 'model_id', 'unknown')}) " + f"loop (budget={self.loop_budget})" + ) - show_progress = flog.getEffectiveLevel() <= MelleaLogger.INFO - loop_iterator = ( - tqdm.tqdm(range(self.loop_budget), desc="S1 Solver") - if show_progress - else range(self.loop_budget) - ) + previous_failed_set: set[tuple[str | None, str | None, float | None]] = ( + set() + ) + loop_count = 0 + next_action = deepcopy(action) + next_context: Context = context + + show_progress = flog.getEffectiveLevel() <= MelleaLogger.INFO + loop_iterator = ( + tqdm.tqdm(range(self.loop_budget), desc="S1 Solver") + if show_progress + else range(self.loop_budget) + ) + + # Exit conditions: success returns immediately; no-improvement breaks + # early to S2 escalation; loop budget exhaustion flows to S2 escalation. + for _ in loop_iterator: + loop_count += 1 + if not show_progress: + flog.info( + f"SOFAI S1: Running loop {loop_count} of {self.loop_budget}" + ) + + # Generate and validate + ( + result, + result_ctx, + constraint_scores, + ) = await self._generate_and_validate( + solver_backend=self.s1_solver_backend, + action=next_action, + ctx=next_context, + reqs=reqs, + session_backend=backend, + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) - # Exit conditions: success returns immediately; no-improvement breaks - # early to S2 escalation; loop budget exhaustion flows to S2 escalation. - for _ in loop_iterator: - loop_count += 1 - if not show_progress: - flog.info(f"SOFAI S1: Running loop {loop_count} of {self.loop_budget}") + # Store attempt + sampled_results.append(result) + sampled_scores.append(constraint_scores) + sampled_actions.append(next_action) + sample_contexts.append(result_ctx) + + # Check for success + if all(bool(score[1]) for score in constraint_scores): + flog.info(f"SOFAI S1: SUCCESS on attempt {loop_count}") + assert result._generate_log is not None + result._generate_log.is_final_result = True + + # Exit with success + return SamplingResult( + result_index=len(sampled_results) - 1, + success=True, + sample_generations=sampled_results, + sample_validations=sampled_scores, + sample_contexts=sample_contexts, + sample_actions=sampled_actions, + ) + + # Log partial progress + count_valid = sum(1 for s in constraint_scores if bool(s[1])) + flog.info( + f"SOFAI S1: FAILED attempt {loop_count}. " + f"Valid: {count_valid}/{len(constraint_scores)}" + ) - # Generate and validate + # Check for no improvement (early exit to S2) + current_failed_set = { + (req.description, val.reason, val.score) + for req, val in constraint_scores + if not val.as_bool() + } + if loop_count > 1 and current_failed_set == previous_failed_set: + flog.warning( + f"SOFAI S1: No improvement detected between attempt " + f"{loop_count - 1} and {loop_count}. Escalating to S2 Solver." + ) + # Exit with no improvement + break + previous_failed_set = current_failed_set + + # Prepare repair for next iteration + if loop_count < self.loop_budget: + next_action, next_context = self.repair( + next_context, + result_ctx, + sampled_actions, + sampled_results, + sampled_scores, + ) + # Exit due to loop budget exhaustion or no improvement + + # --------------------------------------------------------------------- + # PHASE 2: S2 Solver Escalation + # --------------------------------------------------------------------- + flog.info( + f"SOFAI: S1 Solver completed {loop_count} attempts. " + f"Escalating to S2 Solver ({getattr(self.s2_solver_backend, 'model_id', 'unknown')})." + ) + + # Prepare S2 context based on mode + s2_action, s2_context = self._prepare_s2_context( + s2_mode=self.s2_solver_mode, + original_action=action, + original_context=context, + last_result_ctx=result_ctx, + last_action=next_action, + sampled_results=sampled_results, + sampled_scores=sampled_scores, + loop_count=loop_count, + ) + + # Generate and validate with S2 result, result_ctx, constraint_scores = await self._generate_and_validate( - solver_backend=self.s1_solver_backend, - action=next_action, - ctx=next_context, + solver_backend=self.s2_solver_backend, + action=s2_action, + ctx=s2_context, reqs=reqs, session_backend=backend, format=format, @@ -653,19 +751,18 @@ async def sample( tool_calls=tool_calls, ) - # Store attempt + # Store S2 attempt sampled_results.append(result) sampled_scores.append(constraint_scores) - sampled_actions.append(next_action) + sampled_actions.append(s2_action) sample_contexts.append(result_ctx) - # Check for success - if all(bool(score[1]) for score in constraint_scores): - flog.info(f"SOFAI S1: SUCCESS on attempt {loop_count}") - assert result._generate_log is not None - result._generate_log.is_final_result = True + # Check S2 success + assert result._generate_log is not None + result._generate_log.is_final_result = True - # Exit with success + if all(bool(score[1]) for score in constraint_scores): + flog.info("SOFAI S2: SUCCESS") return SamplingResult( result_index=len(sampled_results) - 1, success=True, @@ -674,103 +771,17 @@ async def sample( sample_contexts=sample_contexts, sample_actions=sampled_actions, ) - - # Log partial progress - count_valid = sum(1 for s in constraint_scores if bool(s[1])) - flog.info( - f"SOFAI S1: FAILED attempt {loop_count}. " - f"Valid: {count_valid}/{len(constraint_scores)}" - ) - - # Check for no improvement (early exit to S2) - current_failed_set = { - (req.description, val.reason, val.score) - for req, val in constraint_scores - if not val.as_bool() - } - if loop_count > 1 and current_failed_set == previous_failed_set: + else: + count_valid = sum(1 for s in constraint_scores if bool(s[1])) flog.warning( - f"SOFAI S1: No improvement detected between attempt " - f"{loop_count - 1} and {loop_count}. Escalating to S2 Solver." + f"SOFAI S2: FAILED. Valid: {count_valid}/{len(constraint_scores)}. " + f"Returning S2 Solver's attempt as final result." ) - # Exit with no improvement - break - previous_failed_set = current_failed_set - - # Prepare repair for next iteration - if loop_count < self.loop_budget: - next_action, next_context = self.repair( - next_context, - result_ctx, - sampled_actions, - sampled_results, - sampled_scores, + return SamplingResult( + result_index=len(sampled_results) - 1, + success=False, + sample_generations=sampled_results, + sample_validations=sampled_scores, + sample_contexts=sample_contexts, + sample_actions=sampled_actions, ) - # Exit due to loop budget exhaustion or no improvement - - # --------------------------------------------------------------------- - # PHASE 2: S2 Solver Escalation - # --------------------------------------------------------------------- - flog.info( - f"SOFAI: S1 Solver completed {loop_count} attempts. " - f"Escalating to S2 Solver ({getattr(self.s2_solver_backend, 'model_id', 'unknown')})." - ) - - # Prepare S2 context based on mode - s2_action, s2_context = self._prepare_s2_context( - s2_mode=self.s2_solver_mode, - original_action=action, - original_context=context, - last_result_ctx=result_ctx, - last_action=next_action, - sampled_results=sampled_results, - sampled_scores=sampled_scores, - loop_count=loop_count, - ) - - # Generate and validate with S2 - result, result_ctx, constraint_scores = await self._generate_and_validate( - solver_backend=self.s2_solver_backend, - action=s2_action, - ctx=s2_context, - reqs=reqs, - session_backend=backend, - format=format, - model_options=model_options, - tool_calls=tool_calls, - ) - - # Store S2 attempt - sampled_results.append(result) - sampled_scores.append(constraint_scores) - sampled_actions.append(s2_action) - sample_contexts.append(result_ctx) - - # Check S2 success - assert result._generate_log is not None - result._generate_log.is_final_result = True - - if all(bool(score[1]) for score in constraint_scores): - flog.info("SOFAI S2: SUCCESS") - return SamplingResult( - result_index=len(sampled_results) - 1, - success=True, - sample_generations=sampled_results, - sample_validations=sampled_scores, - sample_contexts=sample_contexts, - sample_actions=sampled_actions, - ) - else: - count_valid = sum(1 for s in constraint_scores if bool(s[1])) - flog.warning( - f"SOFAI S2: FAILED. Valid: {count_valid}/{len(constraint_scores)}. " - f"Returning S2 Solver's attempt as final result." - ) - return SamplingResult( - result_index=len(sampled_results) - 1, - success=False, - sample_generations=sampled_results, - sample_validations=sampled_scores, - sample_contexts=sample_contexts, - sample_actions=sampled_actions, - ) diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index 0cace1d63..1e3e3a672 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -34,9 +34,8 @@ SamplingResult, SamplingStrategy, ValidationResult, - clear_log_context, - set_log_context, ) +from ..core.utils import _log_context from ..helpers import _run_async_in_thread from ..plugins.manager import has_plugins, invoke_hook from ..plugins.types import HookType @@ -311,6 +310,7 @@ def __init__(self, backend: Backend, ctx: Context | None = None): self.ctx: Context = ctx if ctx is not None else SimpleContext() self._session_logger = MelleaLogger.get_logger() self._context_token = None + self._log_context_token = None self._session_span = None def __enter__(self): @@ -322,17 +322,22 @@ def __enter__(self): context_type=self.ctx.__class__.__name__, ).__enter__() self._context_token = _context_session.set(self) - set_log_context( - session_id=self.id, - backend=self.backend.__class__.__name__, - model_id=str(getattr(self.backend, "model_id", "unknown")), + self._log_context_token = _log_context.set( + { + **_log_context.get(), + "session_id": self.id, + "backend": self.backend.__class__.__name__, + "model_id": str(getattr(self.backend, "model_id", "unknown")), + } ) return self def __exit__(self, exc_type, exc_val, exc_tb): """Exit context manager and cleanup session.""" self.cleanup() - clear_log_context() + if self._log_context_token is not None: + _log_context.reset(self._log_context_token) + self._log_context_token = None if self._context_token is not None: _context_session.reset(self._context_token) self._context_token = None diff --git a/test/conftest.py b/test/conftest.py index 4859b0121..a904a06ba 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -302,7 +302,6 @@ def cleanup_gpu_backend(backend, backend_name="unknown"): backend: The backend instance to clean up. backend_name: Name for logging. """ - import gc logger = MelleaLogger.get_logger() logger.info(f"Cleaning up {backend_name} backend GPU memory...") diff --git a/test/core/test_logger_plugin_hooks.py b/test/core/test_logger_plugin_hooks.py index 77757f3c4..ef461aa38 100644 --- a/test/core/test_logger_plugin_hooks.py +++ b/test/core/test_logger_plugin_hooks.py @@ -4,8 +4,8 @@ - MelleaLogger.get_logger() is callable and usable from inside a hook handler. - Log records emitted from inside a hook are captured. - log_context fields set inside a hook appear on records from that hook. -- log_context fields set in the test body do NOT bleed into hook execution - (hooks run in a different thread, so ContextVar state is not inherited). +- log_context fields set in the caller ARE visible inside AUDIT hook execution + (AUDIT hooks are awaited in the same asyncio task, so ContextVar state is inherited). """ # pytest: integration diff --git a/test/core/test_utils_logging.py b/test/core/test_utils_logging.py index b25f70be3..c4f5929d8 100644 --- a/test/core/test_utils_logging.py +++ b/test/core/test_utils_logging.py @@ -137,21 +137,18 @@ def test_clear_context_removes_fields(self) -> None: def test_context_is_thread_local(self) -> None: """Fields set in one thread must not bleed into another.""" results: dict[str, Any] = {} + barrier = threading.Barrier(2) def worker_a() -> None: set_log_context(trace_id="thread-a") - import time - - time.sleep(0.05) + barrier.wait() # both threads read context at the same time fmt = JsonFormatter() results["a"] = json.loads(fmt.format(_make_record())) clear_log_context() def worker_b() -> None: # does NOT call set_log_context - import time - - time.sleep(0.02) + barrier.wait() # both threads read context at the same time fmt = JsonFormatter() results["b"] = json.loads(fmt.format(_make_record())) From 8f49206beb114e67215ef42f12f34f54db984f8b Mon Sep 17 00:00:00 2001 From: Angelo Danducci Date: Thu, 16 Apr 2026 13:40:10 -0400 Subject: [PATCH 10/11] Update docs/docs/evaluation-and-observability/logging.md Co-authored-by: Alex Bozarth --- docs/docs/evaluation-and-observability/logging.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/docs/evaluation-and-observability/logging.md b/docs/docs/evaluation-and-observability/logging.md index 672593a9f..253c08693 100644 --- a/docs/docs/evaluation-and-observability/logging.md +++ b/docs/docs/evaluation-and-observability/logging.md @@ -86,7 +86,7 @@ With structured JSON output enabled, the same `SUCCESS` record looks like: The `session_id`, `backend`, `model_id`, `strategy`, and `loop_budget` fields are injected automatically when the call runs inside a `with session:` block. -They appear on every log record within that scope — not just `SUCCESS`. +They appear on every log record within that scope. ### Adding custom context fields From 077cc46e7afad1802f0f38270b6f5f171fb60830 Mon Sep 17 00:00:00 2001 From: AngeloDanducci Date: Thu, 16 Apr 2026 13:51:36 -0400 Subject: [PATCH 11/11] fix missed logger plugin hookst update --- test/core/test_logger_plugin_hooks.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/test/core/test_logger_plugin_hooks.py b/test/core/test_logger_plugin_hooks.py index ef461aa38..888252bcf 100644 --- a/test/core/test_logger_plugin_hooks.py +++ b/test/core/test_logger_plugin_hooks.py @@ -184,16 +184,28 @@ async def test_log_context_is_visible_inside_hook(self) -> None: from mellea.stdlib.components import Instruction from mellea.stdlib.sampling.base import RejectionSamplingStrategy - hook_saw_context: list[bool] = [] + hook_records: list[logging.LogRecord] = [] set_log_context(outer_field="visible-in-hook") @hook("sampling_loop_start", mode=PluginMode.AUDIT) async def visibility_hook(payload: Any, ctx: Any) -> None: - from mellea.core.utils import _log_context + from mellea.core.utils import ContextFilter - fields = _log_context.get() - hook_saw_context.append("outer_field" in fields) + logger = MelleaLogger.get_logger() + # Create a log record to test context visibility + record = logger.makeRecord( + name="fancy_logger", + level=logging.INFO, + fn="test", + lno=0, + msg="testing context visibility", + args=(), + exc_info=None, + ) + # Apply the context filter to populate context fields + ContextFilter().filter(record) + hook_records.append(record) register(visibility_hook) @@ -208,8 +220,8 @@ async def visibility_hook(payload: Any, ctx: Any) -> None: show_progress=False, ) - assert hook_saw_context, "Hook did not fire" - assert hook_saw_context[0], ( + assert hook_records, "Hook did not fire" + assert getattr(hook_records[0], "outer_field", None) == "visible-in-hook", ( "log_context fields should be visible inside AUDIT hooks (same asyncio task)" )