diff --git a/.agents/skills/mellea-logging/SKILL.md b/.agents/skills/mellea-logging/SKILL.md new file mode 100644 index 000000000..d479ab14c --- /dev/null +++ b/.agents/skills/mellea-logging/SKILL.md @@ -0,0 +1,141 @@ +--- +name: mellea-logging +description: > + Best-practices guide for adding or reviewing logging in the Mellea codebase. + Covers when to use log_context() vs a dedicated logger call, canonical field + names, reserved attribute constraints, async/thread safety, and what events + deserve dedicated log lines. + Use when: adding a new log call; reviewing a PR that touches MelleaLogger; + deciding where to inject context fields; debugging why a field is missing from + a log record; or ensuring consistency with the project logging conventions. +argument-hint: "[file-or-directory]" +compatibility: "Claude Code, IBM Bob" +metadata: + version: "2026-04-15" + capabilities: [read_file, grep, glob] +--- + +# Mellea Logging Best Practices + +All logging in Mellea flows through `MelleaLogger.get_logger()`, defined in +`mellea/core/utils.py`. This skill documents the conventions for adding and +reviewing log instrumentation. + +## Quick reference + +```python +from mellea.core import MelleaLogger, log_context, set_log_context, clear_log_context + +logger = MelleaLogger.get_logger() + +# Dedicated log call — for a discrete event +logger.info("SUCCESS") + +# Context injection — attach fields to every record in a scope +with log_context(request_id="req-abc", trace_id="t-1"): + logger.info("Starting generation") # includes request_id, trace_id + # ... 
all nested calls inherit these fields automatically +``` + +## When to add a dedicated log call + +Use `logger.info/warning/error(...)` for **discrete, named events**: + +| Event type | Level | Example | +|------------|-------|---------| +| Phase transition | INFO | `"SUCCESS"`, `"FAILED"`, `"Starting session"` | +| Loop progress | INFO | `"Running loop 2 of 3"` | +| Recoverable issue | WARNING | `"Warmup failed for model: ..."` | +| Unexpected failure | ERROR | exception tracebacks, hard failures | +| Verbose diagnostics | DEBUG | token counts, prompt previews | + +Do **not** add log calls for: + +- Values already captured in a `log_context` field (redundant noise) +- Internal helper functions where the calling function already logs the event +- State that is already reflected in telemetry spans + +## When to use log_context + +Use `log_context` (or `set_log_context`) to attach **identifiers and metadata +that should appear on every log record within a scope** — without threading +them through every call. + +Typical injection points: + +| Scope | Where to inject | Fields | +|-------|----------------|--------| +| Session lifetime | `MelleaSession.__enter__` | `session_id`, `backend`, `model_id` | +| Sampling loop | `BaseSamplingStrategy.sample()` | `strategy`, `loop_budget` | +| HTTP request handler | entry point of the handler | `request_id`, `trace_id` | +| Background task | top of the task coroutine | `task_id`, `job_name` | + +## Canonical field names + +Use these names consistently. Do not invent synonyms. + +| Field | Type | Description | +|-------|------|-------------| +| `session_id` | str (UUID) | Unique ID for a `MelleaSession` | +| `backend` | str | Backend class name, e.g. 
`"OllamaModelBackend"` | +| `model_id` | str | Model identifier string | +| `strategy` | str | Sampling strategy class name | +| `loop_budget` | int | Max generate/validate cycles for this sampling call | +| `request_id` | str | Caller-supplied request identifier | +| `trace_id` | str | Distributed trace ID (from OpenTelemetry or caller) | +| `span_id` | str | Span ID within a trace | +| `user_id` | str | End-user identifier (when applicable) | + +## Reserved attribute names — do not use as context fields + +The following names are standard `logging.LogRecord` attributes. Passing them +to `log_context()` or `set_log_context()` raises `ValueError`. See +`RESERVED_LOG_RECORD_ATTRS` in `mellea/core/utils.py` for the full set. + +`args`, `created`, `exc_info`, `exc_text`, `filename`, `funcName`, +`levelname`, `levelno`, `lineno`, `message`, `module`, `msecs`, `msg`, +`name`, `pathname`, `process`, `processName`, `relativeCreated`, +`stack_info`, `thread`, `threadName` + +## Prefer the context manager over set/clear + +```python +# Preferred — guaranteed cleanup even on exceptions +with log_context(trace_id="abc"): + do_work() + +# Acceptable only when lifetime equals __enter__/__exit__ +# (e.g. MelleaSession, where the CM already guarantees cleanup) +set_log_context(session_id=self.id) +# ... later in __exit__ ... +clear_log_context() +``` + +The context manager uses a `ContextVar` token to restore the previous state +on exit. This means **nesting works correctly** — inner calls can add fields +without clobbering the outer scope's values. + +## Async and thread safety + +`log_context` uses `contextvars.ContextVar`, which is safe for concurrent +asyncio tasks: + +- Each `asyncio.Task` gets its own copy of the context. +- Fields set in one task do not bleed into sibling tasks. + +**Plugin hooks**: Mellea hooks (`AUDIT`, `SEQUENTIAL`, `CONCURRENT`) are +`await`ed in the same asyncio task as the call site. 
`ContextVar` state IS +inherited — fields set around a `strategy.sample()` call will appear on +records emitted inside hook handlers automatically. + +## Checklist before committing + +1. New log calls use `MelleaLogger.get_logger()`, not `logging.getLogger(...)`. +2. Context fields use canonical names from the table above. +3. No reserved attribute names passed to `log_context`. +4. Scoped fields use `with log_context(...)`, not `set_log_context` (unless + managing an `__enter__`/`__exit__` pair). +5. Hook handlers run in the same asyncio task as the call site and inherit the + caller's `ContextVar` context — do not re-inject fields the caller already set. +6. New events that span multiple log records inject fields via context, not by + repeating them on every `logger.info(...)` call. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 7deee4c6c..c03d30836 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -439,8 +439,8 @@ CICD=1 uv run pytest ```python # Enable debug logging -from mellea.core import FancyLogger -FancyLogger.get_logger().setLevel("DEBUG") +from mellea.core import MelleaLogger +MelleaLogger.get_logger().setLevel("DEBUG") # See exact prompt sent to LLM print(m.last_prompt()) diff --git a/cli/decompose/prompt_modules/subtask_constraint_assign/_subtask_constraint_assign.py b/cli/decompose/prompt_modules/subtask_constraint_assign/_subtask_constraint_assign.py index 50a152648..2d14f0d9f 100644 --- a/cli/decompose/prompt_modules/subtask_constraint_assign/_subtask_constraint_assign.py +++ b/cli/decompose/prompt_modules/subtask_constraint_assign/_subtask_constraint_assign.py @@ -4,7 +4,7 @@ from mellea import MelleaSession from mellea.backends import ModelOption -from mellea.core import FancyLogger +from mellea.core import MelleaLogger from mellea.stdlib.components import Message from .._prompt_modules import PromptModule, PromptModuleString @@ -12,7 +12,7 @@ from ._prompt import get_system_prompt, get_user_prompt from ._types import SubtaskPromptConstraintsItem -FancyLogger.get_logger().setLevel("DEBUG") 
+MelleaLogger.get_logger().setLevel("DEBUG") T = TypeVar("T") @@ -175,7 +175,7 @@ def _default_parser(generated_str: str) -> list[SubtaskPromptConstraintsItem]: ).strip() else: # Fallback to raw text if tags are missing - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "Expected tags missing from LLM response; falling back to raw response text. " "Downstream stages may receive unstructured content." ) @@ -202,7 +202,7 @@ def _default_parser(generated_str: str) -> list[SubtaskPromptConstraintsItem]: # If content exists but no list items were parsed, # treat the whole text as a single constraint. if subtask_constraint_assign_str and not subtask_constraint_assign: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "No list-style constraints detected; falling back to full text as a single constraint." ) subtask_constraint_assign = [subtask_constraint_assign_str] diff --git a/docs/AGENTS_TEMPLATE.md b/docs/AGENTS_TEMPLATE.md index 1233bbfaf..116315526 100644 --- a/docs/AGENTS_TEMPLATE.md +++ b/docs/AGENTS_TEMPLATE.md @@ -160,8 +160,8 @@ Session methods: `ainstruct`, `achat`, `aact`, `avalidate`, `aquery`, `atransfor #### 12. 
Debugging ```python -from mellea.core import FancyLogger -FancyLogger.get_logger().setLevel("DEBUG") +from mellea.core import MelleaLogger +MelleaLogger.get_logger().setLevel("DEBUG") ``` - `m.last_prompt()` — see exact prompt sent diff --git a/docs/docs/community/contributing-guide.md b/docs/docs/community/contributing-guide.md index bb583292c..3ef7fcb84 100644 --- a/docs/docs/community/contributing-guide.md +++ b/docs/docs/community/contributing-guide.md @@ -315,10 +315,10 @@ CICD=1 uv run pytest ### Debugging tips ```python -from mellea.core import FancyLogger +from mellea.core import MelleaLogger # Enable debug logging -FancyLogger.get_logger().setLevel("DEBUG") +MelleaLogger.get_logger().setLevel("DEBUG") # Inspect the exact prompt sent to the LLM print(m.last_prompt()) diff --git a/docs/docs/evaluation-and-observability/logging.md b/docs/docs/evaluation-and-observability/logging.md index 908c43928..253c08693 100644 --- a/docs/docs/evaluation-and-observability/logging.md +++ b/docs/docs/evaluation-and-observability/logging.md @@ -14,22 +14,23 @@ Both work simultaneously when enabled. ## Console logging -Mellea uses `FancyLogger`, a color-coded singleton logger built on Python's +Mellea uses `MelleaLogger`, a color-coded singleton logger built on Python's `logging` module. All internal Mellea modules obtain their logger via -`FancyLogger.get_logger()`. +`MelleaLogger.get_logger()`. ### Configuration | Variable | Description | Default | | -------- | ----------- | ------- | -| `DEBUG` | Set to any value to enable `DEBUG`-level output | unset (`INFO` level) | -| `FLOG` | Set to any value to forward logs to a local REST endpoint at `http://localhost:8000/api/receive` | unset | +| `MELLEA_LOG_LEVEL` | Log level name (e.g. 
`DEBUG`, `INFO`, `WARNING`) | `INFO` | +| `MELLEA_LOG_JSON` | Set to any truthy value (`1`, `true`, `yes`) to emit structured JSON instead of colour-coded output | unset | +| `MELLEA_FLOG` | Set to any value to forward logs to a local REST endpoint at `http://localhost:8000/api/receive` | unset | -By default, `FancyLogger` logs at `INFO` level with color-coded output to -stdout. Set the `DEBUG` environment variable to lower the level to `DEBUG`: +By default, `MelleaLogger` logs at `INFO` level with color-coded output to +stdout. Set `MELLEA_LOG_LEVEL` to change the level: ```bash -export DEBUG=1 +export MELLEA_LOG_LEVEL=DEBUG python your_script.py ``` @@ -50,6 +51,55 @@ Each message is formatted as: message ``` +## Sample output + +### Console format (default) + +Running `m.instruct(...)` inside a session produces lines like: + +```text +=== 11:11:25-INFO ====== +SUCCESS +``` + +### JSON format (`MELLEA_LOG_JSON=1`) + +With structured JSON output enabled, the same `SUCCESS` record looks like: + +```json +{ + "timestamp": "2026-04-08T11:11:25", + "level": "INFO", + "message": "SUCCESS", + "module": "base", + "function": "sample", + "line_number": 258, + "process_id": 73738, + "thread_id": 6179762176, + "session_id": "550e8400-e29b-41d4-a716-446655440000", + "backend": "OllamaModelBackend", + "model_id": "granite4:micro", + "strategy": "RejectionSamplingStrategy", + "loop_budget": 3 +} +``` + +The `session_id`, `backend`, `model_id`, `strategy`, and `loop_budget` fields +are injected automatically when the call runs inside a `with session:` block. +They appear on every log record within that scope. 
+ +### Adding custom context fields + +Use `log_context` to attach your own fields for the duration of a block: + +```python +from mellea.core import log_context + +with log_context(request_id="req-abc", user_id="usr-42"): + result = m.instruct("Summarise this document") + # Every log record emitted here will include request_id and user_id +``` + ## OTLP log export When the `[telemetry]` extra is installed, Mellea can export logs to an OTLP @@ -74,11 +124,11 @@ export OTEL_SERVICE_NAME=my-mellea-app ### How it works -When `MELLEA_LOGS_OTLP=true`, `FancyLogger` adds an OpenTelemetry +When `MELLEA_LOGS_OTLP=true`, `MelleaLogger` adds an OpenTelemetry `LoggingHandler` alongside its existing handlers: - **Console handler** — continues to work normally (color-coded output) -- **REST handler** — continues to work normally (when `FLOG` is set) +- **REST handler** — continues to work normally (when `MELLEA_FLOG` is set) - **OTLP handler** — exports logs to the configured OTLP collector Logs are exported using OpenTelemetry's Logs API with batched processing diff --git a/docs/docs/evaluation-and-observability/telemetry.md b/docs/docs/evaluation-and-observability/telemetry.md index 12f669461..dda40dbc4 100644 --- a/docs/docs/evaluation-and-observability/telemetry.md +++ b/docs/docs/evaluation-and-observability/telemetry.md @@ -140,7 +140,7 @@ and troubleshooting. ## Logging -Mellea uses a color-coded console logger (`FancyLogger`) by default. When the +Mellea uses a color-coded console logger (`MelleaLogger`) by default. When the `[telemetry]` extra is installed and `MELLEA_LOGS_OTLP=true` is set, Mellea also exports logs to an OTLP collector alongside existing console output. 
diff --git a/docs/examples/agents/react/react_from_scratch/react.py b/docs/examples/agents/react/react_from_scratch/react.py index 351bce08b..b74d6a43b 100644 --- a/docs/examples/agents/react/react_from_scratch/react.py +++ b/docs/examples/agents/react/react_from_scratch/react.py @@ -12,10 +12,10 @@ import mellea import mellea.stdlib.components.chat -from mellea.core import FancyLogger +from mellea.core import MelleaLogger from mellea.stdlib.context import ChatContext -FancyLogger.get_logger().setLevel("ERROR") +MelleaLogger.get_logger().setLevel("ERROR") react_system_template: Template = Template( """Answer the user's question as best you can. diff --git a/docs/examples/mcp/mcp_example.py b/docs/examples/mcp/mcp_example.py index 3cd4d344f..49d747b92 100644 --- a/docs/examples/mcp/mcp_example.py +++ b/docs/examples/mcp/mcp_example.py @@ -16,7 +16,7 @@ from mellea import MelleaSession from mellea.backends import ModelOption, model_ids from mellea.backends.ollama import OllamaModelBackend -from mellea.core import FancyLogger, ModelOutputThunk, Requirement +from mellea.core import MelleaLogger, ModelOutputThunk, Requirement from mellea.stdlib.requirements import simple_validate from mellea.stdlib.sampling import RejectionSamplingStrategy diff --git a/docs/examples/mify/rich_table_execute_basic.py b/docs/examples/mify/rich_table_execute_basic.py index 3c4b9e665..04d2350ca 100644 --- a/docs/examples/mify/rich_table_execute_basic.py +++ b/docs/examples/mify/rich_table_execute_basic.py @@ -5,10 +5,10 @@ from mellea import start_session from mellea.backends import ModelOption, model_ids -from mellea.core import FancyLogger +from mellea.core import MelleaLogger from mellea.stdlib.components.docs.richdocument import RichDocument, Table -FancyLogger.get_logger().setLevel("ERROR") +MelleaLogger.get_logger().setLevel("ERROR") """ Here we demonstrate the use of the (internally m-ified) class diff --git a/docs/examples/sofai/sofai_graph_coloring.py 
b/docs/examples/sofai/sofai_graph_coloring.py index 67ed7acdc..06a46189b 100644 --- a/docs/examples/sofai/sofai_graph_coloring.py +++ b/docs/examples/sofai/sofai_graph_coloring.py @@ -23,7 +23,7 @@ import mellea from mellea.backends.ollama import OllamaModelBackend -from mellea.core import FancyLogger +from mellea.core import MelleaLogger from mellea.stdlib.components import Message from mellea.stdlib.context import ChatContext from mellea.stdlib.requirements import ValidationResult, req @@ -230,5 +230,5 @@ def main(): if __name__ == "__main__": # Set logging level - FancyLogger.get_logger().setLevel(logging.INFO) + MelleaLogger.get_logger().setLevel(logging.INFO) main() diff --git a/docs/metrics/coverage-baseline.json b/docs/metrics/coverage-baseline.json index 6a983a331..823445d3d 100644 --- a/docs/metrics/coverage-baseline.json +++ b/docs/metrics/coverage-baseline.json @@ -49,7 +49,7 @@ "ComponentParseError", "Context", "ContextTurn", - "FancyLogger", + "MelleaLogger", "Formatter", "GenerateLog", "GenerateType", diff --git a/docs/metrics/coverage-current.json b/docs/metrics/coverage-current.json index 62b134beb..80d7493e9 100644 --- a/docs/metrics/coverage-current.json +++ b/docs/metrics/coverage-current.json @@ -64,7 +64,7 @@ "default_output_to_bool", "SamplingResult", "SamplingStrategy", - "FancyLogger" + "MelleaLogger" ], "mellea.core.formatter": [ "CBlock", @@ -85,7 +85,7 @@ "Component", "Context", "ModelOutputThunk", - "FancyLogger", + "MelleaLogger", "BaseModelSubclass" ], "mellea.core.sampling": [ @@ -126,7 +126,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "GenerateType", "ModelOutputThunk", @@ -145,7 +145,7 @@ "format" ], "mellea.backends.tools": [ - "FancyLogger", + "MelleaLogger", "CBlock", "Component", "TemplateRepresentation", @@ -158,7 +158,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "GenerateType", "ModelOutputThunk", @@ -221,7 +221,7 @@ 
"HF_SMOLLM3_3B_no_ollama" ], "mellea.backends.model_options": [ - "FancyLogger" + "MelleaLogger" ], "mellea.backends.openai": [ "ALoraRequirement", @@ -231,7 +231,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "GenerateType", "ModelOutputThunk", @@ -301,7 +301,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "ModelToolCall", "AbstractMelleaTool", "ChatFormatter", @@ -316,7 +316,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "GenerateType", "ModelOutputThunk", @@ -735,7 +735,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "GenerateType", "ModelOutputThunk", @@ -849,7 +849,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "GenerateType", "ModelOutputThunk", @@ -1051,7 +1051,7 @@ "ModelIdentifier", "CBlock", "Component", - "FancyLogger", + "MelleaLogger", "TemplateRepresentation", "ChatFormatter" ], @@ -1348,7 +1348,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "ImageBlock", "ModelOutputThunk", @@ -1370,7 +1370,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "ImageBlock", "ModelOutputThunk", @@ -1421,7 +1421,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "ImageBlock", "ModelOutputThunk", @@ -1464,7 +1464,7 @@ "local_code_interpreter" ], "mellea.stdlib.tools.interpreter": [ - "FancyLogger", + "MelleaLogger", "logger" ], "mellea.stdlib.sampling": [ @@ -1481,7 +1481,7 @@ "BaseModelSubclass", "Component", "Context", - "FancyLogger", + "MelleaLogger", "ModelOutputThunk", "Requirement", "S", @@ -1499,7 +1499,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "ImageBlock", "ModelOutputThunk", @@ -1559,7 +1559,7 @@ "BaseModelSubclass", "Component", "Context", - "FancyLogger", + "MelleaLogger", "ModelOutputThunk", "Requirement", 
"S", @@ -1575,7 +1575,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "ImageBlock", "ModelOutputThunk", @@ -1624,7 +1624,7 @@ "BaseModelSubclass", "Component", "Context", - "FancyLogger", + "MelleaLogger", "ModelOutputThunk", "Requirement", "S", @@ -1642,7 +1642,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "ImageBlock", "ModelOutputThunk", @@ -1726,7 +1726,7 @@ "mellea.stdlib.requirements.requirement": [ "CBlock", "Context", - "FancyLogger", + "MelleaLogger", "Requirement", "ValidationResult", "Intrinsic" @@ -1737,7 +1737,7 @@ "StaticAnalysisEnvironment", "UnsafeEnvironment", "Context", - "FancyLogger", + "MelleaLogger", "Requirement", "ValidationResult", "logger" @@ -1752,7 +1752,7 @@ "BaseModelSubclass", "CBlock", "Context", - "FancyLogger", + "MelleaLogger", "Requirement", "ValidationResult", "Message", @@ -1799,7 +1799,7 @@ "Component", "ModelOutputThunk", "TemplateRepresentation", - "FancyLogger", + "MelleaLogger", "MELLEA_FINALIZER_TOOL" ], "mellea.stdlib.components.chat": [ @@ -1824,7 +1824,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "ModelOutputThunk", "Requirement", "SamplingStrategy", @@ -1842,7 +1842,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "ImageBlock", "ModelOutputThunk", @@ -1953,7 +1953,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "ImageBlock", "ModelOutputThunk", @@ -2003,7 +2003,7 @@ "BaseModelSubclass", "AbstractMelleaTool", "ModelOutputThunk", - "FancyLogger", + "MelleaLogger", "ToolMessage", "MELLEA_FINALIZER_TOOL", "ReactInitiator", @@ -2017,7 +2017,7 @@ "CBlock", "Component", "Context", - "FancyLogger", + "MelleaLogger", "GenerateLog", "ImageBlock", "ModelOutputThunk", @@ -2102,7 +2102,7 @@ ], "mellea.helpers.openai_compatible_helpers": [ "validate_tool_arguments", - "FancyLogger", + "MelleaLogger", "ModelToolCall", "AbstractMelleaTool", 
"Document", diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index 55f85be7b..64d315328 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -34,9 +34,9 @@ CBlock, Component, Context, - FancyLogger, GenerateLog, GenerateType, + MelleaLogger, ModelOutputThunk, Requirement, ) @@ -199,7 +199,7 @@ def __call__( ) err = ll_matcher.get_error() # type: ignore[attr-defined] if err: - FancyLogger.get_logger().warning("Error in LLMatcher: %s", err) + MelleaLogger.get_logger().warning("Error in LLMatcher: %s", err) llguidance.torch.fill_next_token_bitmask(ll_matcher, bitmask, 0) llguidance.torch.apply_token_bitmask_inplace( @@ -402,7 +402,7 @@ async def _generate_from_context( if alora_req_adapter is None: # Log a warning if using an AloraRequirement but no adapter fit. if reroute_to_alora and isinstance(action, ALoraRequirement): - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"attempted to use an AloraRequirement but backend {self} doesn't have the specified adapter added {adapter_name}; defaulting to regular generation" ) reroute_to_alora = False @@ -472,7 +472,7 @@ async def _generate_from_intrinsic( raise Exception("Does not yet support non-chat contexts.") if len(model_options.items()) > 0: - FancyLogger.get_logger().info( + MelleaLogger.get_logger().info( "passing in model options when generating with an adapter; some model options may be overwritten / ignored" ) @@ -648,11 +648,11 @@ def _make_merged_kv_cache( case CBlock() if c.cache: assert c.value is not None if c.value in self._cached_blocks: - FancyLogger.get_logger().info( + MelleaLogger.get_logger().info( f"KV CACHE HIT for: {hash(c.value)} ({c.value[:3]}..{c.value[-3:]})" # type: ignore ) else: - FancyLogger.get_logger().debug( + MelleaLogger.get_logger().debug( f"HF backend is caching a CBlock with hashed contents: {hash(c.value)} ({c.value[:3]}..{c.value[-3:]})" ) tokens = self._tokenizer(c.value, return_tensors="pt") @@ 
-690,14 +690,14 @@ def _make_merged_kv_cache( prefix, suffix = parts # Add the prefix, if any, to str+tok+dc parts. if prefix != "": - FancyLogger.get_logger().debug( + MelleaLogger.get_logger().debug( f"Doing a forward pass on uncached block which is prefix to a cached CBlock: {prefix[:3]}.{len(prefix)}.{prefix[-3:]}" ) str_parts.append(prefix) tok_parts.append(self._tokenizer(prefix, return_tensors="pt")) dc_parts.append(self._make_dc_cache(tok_parts[-1])) # Add the cached CBlock to str+tok+dc parts. - FancyLogger.get_logger().debug( + MelleaLogger.get_logger().debug( f"Replacing a substring with previously computed/retrieved cache with hahs value {hash(key)} ({key[:3]}..{key[-3:]})" ) # str_parts.append(key) @@ -710,7 +710,7 @@ def _make_merged_kv_cache( current_suffix = suffix # "base" case: the final suffix. if current_suffix != "": - FancyLogger.get_logger().debug( # type: ignore + MelleaLogger.get_logger().debug( # type: ignore f"Doing a forward pass on final suffix, an uncached block: {current_suffix[:3]}.{len(current_suffix)}.{current_suffix[-3:]}" # type: ignore ) # type: ignore str_parts.append(current_suffix) @@ -753,7 +753,7 @@ async def _generate_from_context_with_kv_cache( tools: dict[str, AbstractMelleaTool] = dict() if tool_calls: if _format: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}" ) else: @@ -765,7 +765,7 @@ async def _generate_from_context_with_kv_cache( # Add the tools from the action for this generation last so that # they overwrite conflicting names. 
add_tools_from_context_actions(tools, [action]) - FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}") + MelleaLogger.get_logger().info(f"Tools for call: {tools.keys()}") seed = model_options.get(ModelOption.SEED, None) if seed is not None: @@ -904,7 +904,7 @@ async def _generate_from_context_standard( tools: dict[str, AbstractMelleaTool] = dict() if tool_calls: if _format: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}" ) else: @@ -916,7 +916,7 @@ async def _generate_from_context_standard( # Add the tools from the action for this generation last so that # they overwrite conflicting names. add_tools_from_context_actions(tools, [action]) - FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}") + MelleaLogger.get_logger().info(f"Tools for call: {tools.keys()}") seed = model_options.get(ModelOption.SEED, None) if seed is not None: @@ -1266,7 +1266,7 @@ async def generate_from_raw( await self.do_generate_walks(list(actions)) if tool_calls: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "The raw endpoint does not support tool calling at the moment." ) @@ -1274,7 +1274,7 @@ async def generate_from_raw( # TODO: Remove this when we are able to update the torch package. # Test this by ensuring all outputs from this call are populated when running on mps. # https://github.com/pytorch/pytorch/pull/157727 - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "utilizing device mps with a `generate_from_raw` request; you may see issues when submitting batches of prompts to a huggingface backend; ensure all ModelOutputThunks have non-empty values." 
) @@ -1475,7 +1475,7 @@ def add_adapter(self, adapter: LocalHFAdapter): """ if adapter.backend is not None: if adapter.backend is self: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"attempted to add adapter {adapter.name} with type {adapter.adapter_type} to the same backend {adapter.backend}" ) return @@ -1485,7 +1485,7 @@ def add_adapter(self, adapter: LocalHFAdapter): ) if self._added_adapters.get(adapter.qualified_name) is not None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"Client code attempted to add {adapter.name} with type {adapter.adapter_type} but {adapter.name} was already added to {self.__class__}. The backend is refusing to do this, because adapter loading is not idempotent." ) return None @@ -1551,7 +1551,7 @@ def unload_adapter(self, adapter_qualified_name: str): # Check if the backend knows about this adapter. adapter = self._loaded_adapters.get(adapter_qualified_name, None) if adapter is None: - FancyLogger.get_logger().info( + MelleaLogger.get_logger().info( f"could not unload adapter {adapter_qualified_name} for backend {self}: adapter is not loaded" ) return diff --git a/mellea/backends/litellm.py b/mellea/backends/litellm.py index bf3224ad0..cfa91593c 100644 --- a/mellea/backends/litellm.py +++ b/mellea/backends/litellm.py @@ -18,9 +18,9 @@ CBlock, Component, Context, - FancyLogger, GenerateLog, GenerateType, + MelleaLogger, ModelOutputThunk, ModelToolCall, ) @@ -261,12 +261,12 @@ def _make_backend_specific_and_remove( unknown_keys.append(key) if len(unknown_keys) > 0: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"litellm allows for unknown / non-openai input params; mellea won't validate the following params that may cause issues: {', '.join(unknown_keys)}" ) if len(unsupported_openai_params) > 0: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"litellm may drop the following openai keys that it doesn't seem to recognize as being 
supported by the current model/provider: {', '.join(unsupported_openai_params)}" "\nThere are sometimes false positives here." ) @@ -339,7 +339,7 @@ async def _generate_from_chat_context_standard( model_specific_options = self._make_backend_specific_and_remove(model_opts) if self._has_potential_event_loop_errors(): - FancyLogger().get_logger().warning( + MelleaLogger.get_logger().warning( "There is a known bug with litellm. This generation call may fail. If it does, you should ensure that you are either running only synchronous Mellea functions or running async Mellea functions from one asyncio.run() call." ) @@ -570,7 +570,7 @@ def _extract_tools( tools: dict[str, AbstractMelleaTool] = dict() if tool_calls: if _format: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}" ) else: @@ -580,7 +580,7 @@ def _extract_tools( # Add the tools from the action for this generation last so that # they overwrite conflicting names. add_tools_from_context_actions(tools, [action]) - FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}") + MelleaLogger.get_logger().info(f"Tools for call: {tools.keys()}") return tools @overload @@ -636,7 +636,7 @@ async def generate_from_raw( await self.do_generate_walks(list(actions)) extra_body = {} if format is not None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "The official OpenAI completion api does not accept response format / structured decoding; " "it will be passed as an extra arg." ) @@ -644,7 +644,7 @@ async def generate_from_raw( # Some versions (like vllm's version) of the OpenAI API support structured decoding for completions requests. 
extra_body["guided_json"] = format.model_json_schema() # type: ignore if tool_calls: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "The completion endpoint does not support tool calling." ) @@ -653,7 +653,7 @@ async def generate_from_raw( model_specific_options = self._make_backend_specific_and_remove(model_opts) if self._has_potential_event_loop_errors(): - FancyLogger().get_logger().warning( + MelleaLogger.get_logger().warning( "There is a known bug with litellm. This generation call may fail. If it does, you should ensure that you are either running only synchronous Mellea functions or running async Mellea functions from one asyncio.run() call." ) @@ -670,7 +670,7 @@ async def generate_from_raw( date = datetime.datetime.now() responses = completion_response.choices if len(responses) != len(prompts): - FancyLogger().get_logger().error( + MelleaLogger.get_logger().error( "litellm appears to have sent your batch request as a single message; this typically happens with providers like ollama that don't support batching" ) @@ -722,7 +722,7 @@ def _extract_model_tool_requests( func = tools.get(tool_name) if func is None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"model attempted to call a non-existing function: {tool_name}" ) continue # skip this function if we can't find it. diff --git a/mellea/backends/model_options.py b/mellea/backends/model_options.py index 2350797ff..f71ddfb53 100644 --- a/mellea/backends/model_options.py +++ b/mellea/backends/model_options.py @@ -2,7 +2,7 @@ from typing import Any -from ..core import FancyLogger +from ..core import MelleaLogger class ModelOption: @@ -111,7 +111,7 @@ def replace_keys(options: dict, from_to: dict[str, str]) -> dict[str, Any]: "Encountered conflict(s) when replacing keys. 
Could not replace keys for:\n" + "\n".join(conflict_log) ) - FancyLogger.get_logger().warning(f"{text_line}") + MelleaLogger.get_logger().warning(f"{text_line}") return new_options @staticmethod diff --git a/mellea/backends/ollama.py b/mellea/backends/ollama.py index f682a49d2..a81add1f5 100644 --- a/mellea/backends/ollama.py +++ b/mellea/backends/ollama.py @@ -16,9 +16,9 @@ CBlock, Component, Context, - FancyLogger, GenerateLog, GenerateType, + MelleaLogger, ModelOutputThunk, ModelToolCall, ) @@ -89,11 +89,11 @@ def __init__( if not self._check_ollama_server(): err = f"could not create OllamaModelBackend: ollama server not running at {base_url}" - FancyLogger.get_logger().error(err) + MelleaLogger.get_logger().error(err) raise Exception(err) if not self._pull_ollama_model(): err = f"could not create OllamaModelBackend: {self._get_ollama_model_id()} could not be pulled from ollama library" - FancyLogger.get_logger().error(err) + MelleaLogger.get_logger().error(err) raise Exception(err) # A mapping of common options for this backend mapped to their Mellea ModelOptions equivalent. @@ -168,7 +168,7 @@ def _pull_ollama_model(self) -> bool: return True try: - FancyLogger.get_logger().debug( + MelleaLogger.get_logger().debug( f"Loading/Pulling model from Ollama: {self._get_ollama_model_id()}" ) stream = self._client.pull(self._get_ollama_model_id(), stream=True) @@ -376,7 +376,7 @@ async def generate_from_chat_context( tools: dict[str, AbstractMelleaTool] = dict() if tool_calls: if _format: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}" ) else: @@ -386,7 +386,7 @@ async def generate_from_chat_context( # Add the tools from the action for this generation last so that # they overwrite conflicting names. 
add_tools_from_context_actions(tools, [action]) - FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}") + MelleaLogger.get_logger().info(f"Tools for call: {tools.keys()}") # Generate a chat response from ollama, using the chat messages. Can be either type since stream is passed as a model option. chat_response: Coroutine[ @@ -483,11 +483,11 @@ async def generate_from_raw( list[ModelOutputThunk]: A list of model output thunks, one per action. """ if len(actions) > 1: - FancyLogger.get_logger().info( + MelleaLogger.get_logger().info( "Ollama doesn't support batching; will attempt to process concurrently." ) if tool_calls: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "The completion endpoint does not support tool calling at the moment." ) @@ -523,7 +523,7 @@ async def generate_from_raw( result = None error = None if isinstance(response, BaseException): - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"generate_from_raw: request {i} failed with " f"{type(response).__name__}: {response}" ) @@ -583,7 +583,7 @@ def _extract_model_tool_requests( for tool in chat_response.message.tool_calls: func = tools.get(tool.function.name) if func is None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"model attempted to call a non-existing function: {tool.function.name}" ) continue # skip this function if we can't find it. 
diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py index 90d6c3646..301ad45e2 100644 --- a/mellea/backends/openai.py +++ b/mellea/backends/openai.py @@ -22,9 +22,9 @@ CBlock, Component, Context, - FancyLogger, GenerateLog, GenerateType, + MelleaLogger, ModelOutputThunk, Requirement, ) @@ -175,7 +175,7 @@ def __init__( ) if self._base_url is None and os.getenv("OPENAI_BASE_URL") is None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "OPENAI_BASE_URL or base_url is not set.\n" "The openai SDK is going to assume that the base_url is `https://api.openai.com/v1`" ) @@ -481,7 +481,7 @@ async def _generate_from_chat_context_standard( }, } else: - FancyLogger().get_logger().warning( + MelleaLogger.get_logger().warning( "Mellea assumes you are NOT using the OpenAI platform, and that other model providers have less strict requirements on support JSON schemas passed into `format=`. If you encounter a server-side error following this message, then you found an exception to this assumption. Please open an issue at github.com/generative_computing/mellea with this stack trace and your inference engine / model provider." ) extra_params["response_format"] = { @@ -497,7 +497,7 @@ async def _generate_from_chat_context_standard( tools: dict[str, AbstractMelleaTool] = dict() if tool_calls: if _format: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}" ) else: @@ -507,7 +507,7 @@ async def _generate_from_chat_context_standard( # Add the tools from the action for this generation last so that # they overwrite conflicting names. 
add_tools_from_context_actions(tools, [action]) - FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}") + MelleaLogger.get_logger().info(f"Tools for call: {tools.keys()}") thinking = model_opts.get(ModelOption.THINKING, None) if type(thinking) is bool and thinking: @@ -796,7 +796,7 @@ async def generate_from_raw( extra_body = {} if format is not None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "The official OpenAI completion api does not accept response format / structured decoding; " "it will be passed as an extra arg." ) @@ -808,7 +808,7 @@ async def generate_from_raw( else: extra_body["guided_json"] = format.model_json_schema() # type: ignore if tool_calls: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "The completion endpoint does not support tool calling at the moment." ) @@ -832,7 +832,7 @@ async def generate_from_raw( ) # type: ignore except openai.BadRequestError as e: if openai_ollama_batching_error in e.message: - FancyLogger.get_logger().error( + MelleaLogger.get_logger().error( "If you are trying to call `OpenAIBackend._generate_from_raw while targeting an ollama server, " "your requests will fail since ollama doesn't support batching requests." ) diff --git a/mellea/backends/tools.py b/mellea/backends/tools.py index cd5c541c6..3dd056512 100644 --- a/mellea/backends/tools.py +++ b/mellea/backends/tools.py @@ -16,7 +16,7 @@ from pydantic import BaseModel, ConfigDict, Field -from mellea.core.utils import FancyLogger +from mellea.core.utils import MelleaLogger from ..core import CBlock, Component, TemplateRepresentation from ..core.base import AbstractMelleaTool @@ -97,7 +97,7 @@ def parameter_remapper(*args, **kwargs): """Langchain tools expect their first argument to be 'tool_input'.""" if args: # This shouldn't happen. Our ModelToolCall.call_func actually passes in everything as kwargs. 
- FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"ignoring unexpected args while calling langchain tool ({tool_name}): ({args})" ) @@ -158,7 +158,7 @@ def tool_call(*args, **kwargs): """Wrapper for smolagents tool forward method.""" if args: # This shouldn't happen. Our ModelToolCall.call_func passes everything as kwargs. - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"ignoring unexpected args while calling smolagents tool ({tool_name}): ({args})" ) return tool.forward(**kwargs) @@ -480,7 +480,7 @@ def validate_tool_arguments( """ from pydantic import ValidationError, create_model - from ..core import FancyLogger + from ..core import MelleaLogger # Extract JSON schema from tool tool_schema = tool.as_json_tool.get("function", {}) @@ -581,7 +581,7 @@ def validate_tool_arguments( ) if coerced_fields and coerce_types: - FancyLogger.get_logger().debug( + MelleaLogger.get_logger().debug( f"Tool '{tool_name}' arguments coerced: {', '.join(coerced_fields)}" ) @@ -601,11 +601,11 @@ def validate_tool_arguments( if strict: # Re-raise with enhanced message - FancyLogger.get_logger().error(error_msg) + MelleaLogger.get_logger().error(error_msg) raise else: # Log warning and return original args - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( error_msg + "\nReturning original arguments without validation." ) return dict(args) @@ -615,10 +615,10 @@ def validate_tool_arguments( error_msg = f"Unexpected error validating tool '{tool_name}' arguments: {e}" if strict: - FancyLogger.get_logger().error(error_msg) + MelleaLogger.get_logger().error(error_msg) raise else: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( error_msg + "\nReturning original arguments without validation." 
) return dict(args) diff --git a/mellea/backends/utils.py b/mellea/backends/utils.py index 5044a0f2a..f27f55950 100644 --- a/mellea/backends/utils.py +++ b/mellea/backends/utils.py @@ -13,7 +13,7 @@ from collections.abc import Callable from typing import Any -from ..core import CBlock, Component, Context, FancyLogger, ModelToolCall +from ..core import CBlock, Component, Context, MelleaLogger, ModelToolCall from ..core.base import AbstractMelleaTool from ..formatters import ChatFormatter from ..stdlib.components import Message @@ -76,7 +76,7 @@ def to_chat( for msg in ctx_as_conversation: for v in msg.values(): if "CBlock" in v: - FancyLogger.get_logger().error( + MelleaLogger.get_logger().error( f"Found the string `CBlock` in what should've been a stringified context: {ctx_as_conversation}" ) @@ -104,7 +104,7 @@ def to_tool_calls( for tool_name, tool_args in parse_tools(decoded_result): func = tools.get(tool_name) if func is None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"model attempted to call a non-existing function: {tool_name}" ) continue diff --git a/mellea/backends/watsonx.py b/mellea/backends/watsonx.py index 5d414f52e..9ffe2d540 100644 --- a/mellea/backends/watsonx.py +++ b/mellea/backends/watsonx.py @@ -21,9 +21,9 @@ CBlock, Component, Context, - FancyLogger, GenerateLog, GenerateType, + MelleaLogger, ModelOutputThunk, ModelToolCall, ) @@ -387,7 +387,7 @@ async def generate_from_chat_context( tools: dict[str, AbstractMelleaTool] = {} if tool_calls: if _format: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"tool calling is superseded by format; will not call tools for request: {action}" ) else: @@ -397,7 +397,7 @@ async def generate_from_chat_context( # Add the tools from the action for this generation last so that # they overwrite conflicting names. 
add_tools_from_context_actions(tools, [action]) - FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}") + MelleaLogger.get_logger().info(f"Tools for call: {tools.keys()}") formatted_tools = convert_tools_to_json(tools) @@ -676,7 +676,7 @@ async def generate_from_raw( await self.do_generate_walks(list(actions)) if format is not None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "WatsonxAI completion api does not accept response format, ignoring it for this request." ) @@ -743,7 +743,7 @@ def _extract_model_tool_requests( func = tools.get(tool_name) if func is None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"model attempted to call a non-existing function: {tool_name}" ) continue # skip this function if we can't find it. diff --git a/mellea/core/__init__.py b/mellea/core/__init__.py index 910fced57..f31303a80 100644 --- a/mellea/core/__init__.py +++ b/mellea/core/__init__.py @@ -30,7 +30,22 @@ from .formatter import Formatter from .requirement import Requirement, ValidationResult, default_output_to_bool from .sampling import SamplingResult, SamplingStrategy -from .utils import FancyLogger +from .utils import MelleaLogger, clear_log_context, log_context, set_log_context + + +def __getattr__(name: str) -> object: + if name == "FancyLogger": + import warnings + + warnings.warn( + "FancyLogger has been renamed to MelleaLogger and will be removed in a future release. 
" + "Update your imports to use mellea.core.MelleaLogger.", + DeprecationWarning, + stacklevel=2, + ) + return MelleaLogger + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + __all__ = [ "Backend", @@ -42,11 +57,11 @@ "ComputedModelOutputThunk", "Context", "ContextTurn", - "FancyLogger", "Formatter", "GenerateLog", "GenerateType", "ImageBlock", + "MelleaLogger", "ModelOutputThunk", "ModelToolCall", "Requirement", @@ -56,6 +71,9 @@ "TemplateRepresentation", "ValidationResult", "blockify", + "clear_log_context", "default_output_to_bool", "generate_walk", + "log_context", + "set_log_context", ] diff --git a/mellea/core/backend.py b/mellea/core/backend.py index 82f9fae5f..a1ea18f13 100644 --- a/mellea/core/backend.py +++ b/mellea/core/backend.py @@ -21,7 +21,7 @@ from ..plugins.manager import has_plugins, invoke_hook from ..plugins.types import HookType from .base import C, CBlock, Component, Context, ModelOutputThunk -from .utils import FancyLogger +from .utils import MelleaLogger # Necessary to define a type variable that has a default value. # This is because VSCode's pyright static type checker instantiates @@ -180,7 +180,7 @@ async def do_generate_walk( coroutines = [x.avalue() for x in _to_compute] # The following log message might get noisy. Feel free to remove if so. if len(_to_compute) > 0: - FancyLogger.get_logger().info( + MelleaLogger.get_logger().info( f"generate_from_chat_context awaited on {len(_to_compute)} uncomputed mots." ) await asyncio.gather(*coroutines) @@ -202,7 +202,7 @@ async def do_generate_walks( coroutines = [x.avalue() for x in _to_compute] # The following log message might get noisy. Feel free to remove if so. if len(_to_compute) > 0: - FancyLogger.get_logger().info( + MelleaLogger.get_logger().info( f"generate_from_chat_context awaited on {len(_to_compute)} uncomputed mots." 
) await asyncio.gather(*coroutines) diff --git a/mellea/core/utils.py b/mellea/core/utils.py index 3823a83a6..256d9ace6 100644 --- a/mellea/core/utils.py +++ b/mellea/core/utils.py @@ -1,23 +1,171 @@ """Logging utilities for the mellea core library. -Provides ``FancyLogger``, a singleton logger with colour-coded console output and +Provides ``MelleaLogger``, a singleton logger with colour-coded console output and an optional REST handler (``RESTHandler``) that forwards log records to a local -``/api/receive`` endpoint when the ``FLOG`` environment variable is set. All -internal mellea modules obtain their logger via ``FancyLogger.get_logger()``. +``/api/receive`` endpoint when the ``MELLEA_FLOG`` environment variable is set. All +internal mellea modules obtain their logger via ``MelleaLogger.get_logger()``. + +Environment variables +--------------------- +``MELLEA_LOG_LEVEL`` + Minimum log level name (e.g. ``DEBUG``, ``INFO``, ``WARNING``). Defaults to + ``INFO``. +``MELLEA_LOG_JSON`` + Set to any truthy value (``1``, ``true``, ``yes``) to emit structured JSON on + the console instead of the colour-coded human-readable format. +``MELLEA_FLOG`` + When set, log records are forwarded to a local REST endpoint. """ +import contextlib +import contextvars import json import logging import os import sys +import threading +from collections.abc import Generator +from typing import Any import requests +# --------------------------------------------------------------------------- +# Per-task/coroutine context fields (safe for asyncio — each Task gets its own copy) +# --------------------------------------------------------------------------- +_log_context: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar( + "log_context_fields", default={} +) + +# Lock used to make MelleaLogger singleton initialisation thread-safe. +_logger_lock: threading.Lock = threading.Lock() + +# Standard LogRecord attribute names that must not be overwritten by callers. 
+RESERVED_LOG_RECORD_ATTRS: frozenset[str] = frozenset( + ( + "args", + "created", + "exc_info", + "exc_text", + "filename", + "funcName", + "levelname", + "levelno", + "lineno", + "message", + "module", + "msecs", + "msg", + "name", + "pathname", + "process", + "processName", + "relativeCreated", + "stack_info", + "thread", + "threadName", + ) +) + + +def set_log_context(**fields: Any) -> None: + """Inject extra fields into every log record emitted from this coroutine or thread. + + Call this at the start of a request or task to attach identifiers such as + ``trace_id`` or ``request_id`` without modifying individual log calls. + + .. note:: + Prefer :func:`log_context` as the primary API — it guarantees cleanup + (including restoring outer values on same-key nesting) even on + exceptions. + + Args: + **fields: Arbitrary key-value pairs to include in log records. + + Raises: + ValueError: If any key clashes with a standard ``logging.LogRecord`` + attribute (e.g. ``levelname``, ``module``, ``thread``). + """ + invalid = frozenset(fields) & RESERVED_LOG_RECORD_ATTRS + if invalid: + raise ValueError( + f"Context field names clash with LogRecord reserved attributes: " + f"{sorted(invalid)}. Choose different names." + ) + _log_context.set({**_log_context.get(), **fields}) + + +def clear_log_context() -> None: + """Remove all context fields set by :func:`set_log_context` for this coroutine/thread.""" + _log_context.set({}) + + +@contextlib.contextmanager +def log_context(**fields: Any) -> Generator[None, None, None]: + """Context manager that injects *fields* for the duration of the block. + + On exit — including on exceptions — the context is restored to its state + before the block via a ``ContextVar`` token. This is safe for both nested + usage and concurrent asyncio tasks: each ``asyncio.Task`` owns an isolated + copy of the context variable, so coroutines running on the same event-loop + thread cannot overwrite each other's fields. 
+ + Example:: + + with log_context(trace_id="abc-123", request_id="req-1"): + logger.info("Handling request") # both IDs appear here + logger.info("After request") # IDs are gone + + Args: + **fields: Key-value pairs to inject. Same restrictions as + :func:`set_log_context` — reserved ``LogRecord`` attribute names + are rejected with ``ValueError``. + + Yields: + None. The manager is used only for its enter/exit side effects. + + Raises: + ValueError: If any key clashes with a reserved ``LogRecord`` attribute. + """ + invalid = frozenset(fields) & RESERVED_LOG_RECORD_ATTRS + if invalid: + raise ValueError( + f"Context field names clash with LogRecord reserved attributes: " + f"{sorted(invalid)}. Choose different names." + ) + token = _log_context.set({**_log_context.get(), **fields}) + try: + yield + finally: + _log_context.reset(token) + + +class ContextFilter(logging.Filter): + """Logging filter that injects async-safe ContextVar fields into every record. + + Fields registered via :func:`set_log_context` are copied onto the + ``logging.LogRecord`` before formatters see it, enabling trace/request IDs + to appear in structured output without touching call sites. + """ + + def filter(self, record: logging.LogRecord) -> bool: + """Attach async-safe ContextVar fields to *record* and allow it through. + + Args: + record (logging.LogRecord): The log record being processed. + + Returns: + bool: Always ``True`` — the record is never suppressed. + """ + fields: dict[str, Any] = _log_context.get() + for key, value in fields.items(): + setattr(record, key, value) + return True + class RESTHandler(logging.Handler): """Logging handler that forwards records to a local REST endpoint. - Sends log records as JSON to ``/api/receive`` when the ``FLOG`` environment + Sends log records as JSON to ``/api/receive`` when the ``MELLEA_FLOG`` environment variable is set. Failures are silently suppressed to avoid disrupting the application. 
@@ -38,22 +186,25 @@ def __init__( self.headers = headers or {"Content-Type": "application/json"} def emit(self, record: logging.LogRecord) -> None: - """Forwards a log record to the REST endpoint when the ``FLOG`` environment variable is set. + """Forwards a log record to the REST endpoint when the ``MELLEA_FLOG`` environment variable is set. Silently suppresses any network or HTTP errors to avoid disrupting the application. Args: record (logging.LogRecord): The log record to forward. """ - if os.environ.get("FLOG"): - log_data = self.format(record) + if _check_flog_env(): + formatter = self.formatter + if isinstance(formatter, JsonFormatter): + log_dict = formatter.format_as_dict(record) + else: + log_dict = {"message": self.format(record)} try: response = requests.request( self.method, self.api_url, headers=self.headers, - # data=json.dumps([{"log": log_data}]), - data=json.dumps([log_data]), + data=json.dumps([log_dict]), ) response.raise_for_status() except requests.exceptions.RequestException as _: @@ -61,39 +212,148 @@ def emit(self, record: logging.LogRecord) -> None: class JsonFormatter(logging.Formatter): - """Logging formatter that serialises log records as structured JSON dicts. + """Logging formatter that serialises log records as structured JSON strings. - Includes timestamp, level, message, module, function name, line number, - process ID, thread ID, and (if present) exception information. + Produces a consistent JSON schema with a fixed set of core fields. + Additional fields can be injected at construction time (``extra_fields``) or + dynamically per-thread via :func:`set_log_context` / :class:`ContextFilter`. + + Args: + timestamp_format: ``strftime`` format for the ``timestamp`` field. + Defaults to ISO-8601 (``"%Y-%m-%dT%H:%M:%S"``). + include_fields: Whitelist of **core** field names to keep. When ``None`` + all core fields are included. 
Note: this filter applies only to the + fields listed in ``_DEFAULT_FIELDS``; ``extra_fields`` passed to the + constructor and dynamic context fields (set via + :func:`set_log_context`) are **always** included regardless of this + setting. + exclude_fields: Set of core field names to drop. Applied after + *include_fields*. + extra_fields: Static key-value pairs merged into every log record. + + Attributes: + _DEFAULT_FIELDS (tuple[str, ...]): Canonical ordered list of core field + names produced by this formatter. """ - def format(self, record: logging.LogRecord) -> dict: # type: ignore[override] - """Formats a log record as a JSON-serialisable dictionary. + _DEFAULT_FIELDS: tuple[str, ...] = ( + "timestamp", + "level", + "message", + "module", + "function", + "line_number", + "process_id", + "thread_id", + ) + + def __init__( + self, + timestamp_format: str = "%Y-%m-%dT%H:%M:%S", + include_fields: list[str] | None = None, + exclude_fields: list[str] | None = None, + extra_fields: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + """Initialises the formatter; passes remaining kwargs to ``logging.Formatter``.""" + super().__init__(datefmt=timestamp_format, **kwargs) + + if include_fields is not None: + unknown = set(include_fields) - set(self._DEFAULT_FIELDS) + if unknown: + raise ValueError( + f"include_fields contains unknown field names: {sorted(unknown)}. " + f"Valid fields: {list(self._DEFAULT_FIELDS)}" + ) + + self._include: frozenset[str] | None = ( + frozenset(include_fields) if include_fields is not None else None + ) + self._exclude: frozenset[str] = frozenset(exclude_fields or []) + self._extra: dict[str, Any] = dict(extra_fields or {}) - Includes timestamp, level, message, module, function name, line number, - process ID, thread ID, and exception info if present. + def format_as_dict(self, record: logging.LogRecord) -> dict[str, Any]: + """Return the log record as a dictionary (public API for external callers). 
+ + Equivalent to :meth:`_build_log_dict` but part of the public interface so + handlers and other callers do not need to reach into private methods. Args: - record (logging.LogRecord): The log record to format. + record: The log record to convert. + + Returns: + A dictionary ready for JSON serialisation. + """ + return self._build_log_dict(record) + + def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]: + """Build a log record dictionary with core, extra, and context fields. + + Args: + record: The log record to convert. Returns: - dict: A dictionary containing timestamp, level, message, module, function, - line number, process/thread IDs, and optional exception info. + A dictionary ready for JSON serialisation. """ - log_record = { + # Build the full set of core fields first. + # A TypeError here means the caller used %-style format placeholders + # with the wrong number of arguments (e.g. logger.info("%s %s", one)). + # Catch it and substitute a safe error string so the record is still emitted. 
+ try: + message = record.getMessage() + except TypeError as exc: + message = f" original msg={record.msg!r}" + + all_core: dict[str, Any] = { "timestamp": self.formatTime(record, self.datefmt), "level": record.levelname, - "message": record.getMessage(), + "message": message, "module": record.module, "function": record.funcName, "line_number": record.lineno, "process_id": record.process, "thread_id": record.thread, } + + # Apply include/exclude filtering + if self._include is not None: + log_record: dict[str, Any] = { + k: v for k, v in all_core.items() if k in self._include + } + else: + log_record = {k: v for k, v in all_core.items() if k not in self._exclude} + + # Exception info if record.exc_info: log_record["exception"] = self.formatException(record.exc_info) + + # Static extra fields (constructor-level) + log_record.update(self._extra) + + # Dynamic context fields — prefer record attributes (set by + # ContextFilter) but fall back to ContextVar storage so the + # formatter works standalone without a filter attached. + context_fields: dict[str, Any] = _log_context.get() + for key, value in context_fields.items(): + log_record[key] = getattr(record, key, value) + return log_record + def format(self, record: logging.LogRecord) -> str: + """Formats a log record as a JSON string. + + Core fields are filtered by *include_fields* / *exclude_fields*. + Static *extra_fields* and any per-task ContextVar fields (set via + :func:`set_log_context`) are merged in after the core fields. + + Args: + record (logging.LogRecord): The log record to format. + + Returns: + str: A JSON-serialised log record. + """ + return json.dumps(self._build_log_dict(record), default=str) + class CustomFormatter(logging.Formatter): """A nice custom formatter copied from [Sergey Pleshakov's post on StackOverflow](https://stackoverflow.com/questions/384076/how-can-i-color-python-logging-output). 
@@ -138,13 +398,13 @@ def format(self, record: logging.LogRecord) -> str: return formatter.format(record) -class FancyLogger: +class MelleaLogger: """Singleton logger with colour-coded console output and optional REST forwarding. - Obtain the shared logger instance via ``FancyLogger.get_logger()``. Log level - defaults to ``INFO`` but can be raised to ``DEBUG`` by setting the ``DEBUG`` - environment variable. When the ``FLOG`` environment variable is set, records are - also forwarded to a local ``/api/receive`` REST endpoint via ``RESTHandler``. + Obtain the shared logger instance via ``MelleaLogger.get_logger()``. Log level + defaults to ``INFO`` but can be overridden via ``MELLEA_LOG_LEVEL``. When the + ``MELLEA_FLOG`` environment variable is set, records are also forwarded to a + local ``/api/receive`` REST endpoint via ``RESTHandler``. Attributes: logger (logging.Logger | None): The shared ``logging.Logger`` instance; ``None`` until first call to ``get_logger()``. @@ -169,47 +429,90 @@ class FancyLogger: DEBUG = 10 NOTSET = 0 + @staticmethod + def _resolve_log_level() -> int: + """Resolves the effective log level from environment variables. + + Checks ``MELLEA_LOG_LEVEL`` and defaults to ``INFO``. + + Returns: + int: A :mod:`logging` level integer. + """ + level_name = os.environ.get("MELLEA_LOG_LEVEL", "").strip().upper() + if level_name: + numeric = getattr(logging, level_name, None) + if isinstance(numeric, int): + return numeric + return MelleaLogger.INFO + @staticmethod def get_logger() -> logging.Logger: - """Returns a FancyLogger.logger and sets level based upon env vars. + """Returns a MelleaLogger.logger and sets level based upon env vars. + + The logger is created once (singleton). Subsequent calls return the + cached instance. Initialisation is protected by a module-level lock so + concurrent callers at startup cannot create duplicate handlers. Returns: Configured logger with REST, stream, and optional OTLP handlers. 
""" - if FancyLogger.logger is None: - logger = logging.getLogger("fancy_logger") - # Only set default level if user hasn't already configured it - if logger.level == logging.NOTSET: - if os.environ.get("DEBUG"): - logger.setLevel(FancyLogger.DEBUG) - else: - logger.setLevel(FancyLogger.INFO) - - # Define REST API endpoint - api_url = "http://localhost:8000/api/receive" - - # Create REST handler - rest_handler = RESTHandler(api_url) - - # Create formatter and set it for the handler - # formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") - rest_handler.setFormatter(JsonFormatter()) - - # Add handler to the logger - logger.addHandler(rest_handler) - - stream_handler = logging.StreamHandler(stream=sys.stdout) - # stream_handler.setLevel(logging.INFO) - stream_handler.setFormatter(CustomFormatter(datefmt="%H:%M:%S,%03d")) - logger.addHandler(stream_handler) - - # Add OTLP handler if enabled - from ..telemetry import get_otlp_log_handler - - otlp_handler = get_otlp_log_handler() - if otlp_handler: - otlp_handler.setFormatter(JsonFormatter()) - logger.addHandler(otlp_handler) - - FancyLogger.logger = logger - return FancyLogger.logger + if MelleaLogger.logger is None: + with _logger_lock: + # Second check inside the lock: another thread may have finished + # initialisation while we were waiting. 
+ if MelleaLogger.logger is None: + logger = logging.getLogger("fancy_logger") + + # Attach the context filter so ContextFilter fields reach all handlers + logger.addFilter(ContextFilter()) + + # Only set default level if user hasn't already configured it + if logger.level == logging.NOTSET: + logger.setLevel(MelleaLogger._resolve_log_level()) + + # --- REST handler --- + api_url = "http://localhost:8000/api/receive" + rest_handler = RESTHandler(api_url) + rest_handler.setFormatter(JsonFormatter()) + logger.addHandler(rest_handler) + + # --- Console / stream handler --- + stream_handler = logging.StreamHandler(stream=sys.stdout) + use_json_console = os.environ.get( + "MELLEA_LOG_JSON", "" + ).strip().lower() in ("1", "true", "yes") + if use_json_console: + stream_handler.setFormatter(JsonFormatter()) + else: + stream_handler.setFormatter( + CustomFormatter(datefmt="%H:%M:%S,%03d") + ) + logger.addHandler(stream_handler) + + # --- Optional OTLP handler --- + from ..telemetry import get_otlp_log_handler + + otlp_handler = get_otlp_log_handler() + if otlp_handler: + otlp_handler.setFormatter(JsonFormatter()) + logger.addHandler(otlp_handler) + + MelleaLogger.logger = logger + return MelleaLogger.logger + + +def _check_flog_env() -> bool: + """Check MELLEA_FLOG, with a DeprecationWarning fallback for the old FLOG name.""" + if os.environ.get("MELLEA_FLOG"): + return True + if os.environ.get("FLOG"): + import warnings + + warnings.warn( + "The FLOG environment variable is deprecated and will be removed in a future release. 
" + "Use MELLEA_FLOG instead.", + DeprecationWarning, + stacklevel=2, + ) + return True + return False diff --git a/mellea/formatters/template_formatter.py b/mellea/formatters/template_formatter.py index 90c806297..ba967da64 100644 --- a/mellea/formatters/template_formatter.py +++ b/mellea/formatters/template_formatter.py @@ -19,7 +19,7 @@ from ..backends.cache import SimpleLRUCache from ..backends.model_ids import ModelIdentifier -from ..core import CBlock, Component, FancyLogger, TemplateRepresentation +from ..core import CBlock, Component, MelleaLogger, TemplateRepresentation from .chat_formatter import ChatFormatter @@ -104,7 +104,7 @@ def _stringify( stringified_template_args[key] = self._stringify(val) if representation.obj is None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"template formatter encountered a TemplateRepresentation with no obj when stringifying {c.__class__}; setting obj to {c}" ) representation.obj = c @@ -127,7 +127,7 @@ def _stringify( return stringified_list case _: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"formatter encountered an unexpected type in _stringify; using str() on {c.__class__}" ) return str(c) @@ -161,7 +161,7 @@ def _load_template(self, repr: TemplateRepresentation) -> jinja2.Template: return jinja2.Environment().from_string(repr.template) # type: ignore if repr.template_order is None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"using template formatter for {repr.obj.__class__.__name__} but no template order was provided, defaulting to class name" ) repr.template_order = ["*"] diff --git a/mellea/helpers/openai_compatible_helpers.py b/mellea/helpers/openai_compatible_helpers.py index 4858ea366..b8d5506a5 100644 --- a/mellea/helpers/openai_compatible_helpers.py +++ b/mellea/helpers/openai_compatible_helpers.py @@ -5,7 +5,7 @@ from typing import Any from ..backends.tools import validate_tool_arguments -from ..core import FancyLogger, 
ModelToolCall +from ..core import MelleaLogger, ModelToolCall from ..core.base import AbstractMelleaTool from ..stdlib.components import Document, Message @@ -33,7 +33,7 @@ def extract_model_tool_requests( func = tools.get(tool_name) if func is None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"model attempted to call a non-existing function: {tool_name}" ) continue # skip this function if we can't find it. diff --git a/mellea/stdlib/components/genstub.py b/mellea/stdlib/components/genstub.py index 66e11e893..05ca11088 100644 --- a/mellea/stdlib/components/genstub.py +++ b/mellea/stdlib/components/genstub.py @@ -17,7 +17,7 @@ CBlock, Component, Context, - FancyLogger, + MelleaLogger, ModelOutputThunk, Requirement, SamplingStrategy, @@ -616,7 +616,7 @@ def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: # No retries if precondition validation fails. if not all(bool(val_result) for val_result in val_results): - FancyLogger.get_logger().error( + MelleaLogger.get_logger().error( "generative stub arguments did not satisfy precondition requirements" ) raise PreconditionException( @@ -625,7 +625,7 @@ def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: ) elif len(stub_copy.precondition_requirements) > 0: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "calling a generative stub with precondition requirements but no args to validate the preconditions against; ignoring precondition validation" ) @@ -755,7 +755,7 @@ async def __async_call__() -> tuple[R, Context] | R: # No retries if precondition validation fails. 
if not all(bool(val_result) for val_result in val_results): - FancyLogger.get_logger().error( + MelleaLogger.get_logger().error( "generative stub arguments did not satisfy precondition requirements" ) raise PreconditionException( @@ -764,7 +764,7 @@ async def __async_call__() -> tuple[R, Context] | R: ) elif len(stub_copy.precondition_requirements) > 0: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "calling a generative stub with precondition requirements but no args to validate the preconditions against; ignoring precondition validation" ) diff --git a/mellea/stdlib/components/react.py b/mellea/stdlib/components/react.py index bee9eaf4a..8f08fe8a0 100644 --- a/mellea/stdlib/components/react.py +++ b/mellea/stdlib/components/react.py @@ -20,7 +20,7 @@ ModelOutputThunk, TemplateRepresentation, ) -from mellea.core.utils import FancyLogger +from mellea.core.utils import MelleaLogger MELLEA_FINALIZER_TOOL = "final_answer" """Used in the react loop to symbolize the loop is done.""" @@ -66,7 +66,7 @@ def format_for_llm(self) -> TemplateRepresentation: tools = {tool.name: tool for tool in self.tools} if tools.get(MELLEA_FINALIZER_TOOL, None) is not None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"overriding user tool '{MELLEA_FINALIZER_TOOL}' in react call; this tool name is required for internal use" ) diff --git a/mellea/stdlib/frameworks/react.py b/mellea/stdlib/frameworks/react.py index f924ad7bb..117af4866 100644 --- a/mellea/stdlib/frameworks/react.py +++ b/mellea/stdlib/frameworks/react.py @@ -11,7 +11,7 @@ from mellea.backends.model_options import ModelOption from mellea.core.backend import Backend, BaseModelSubclass from mellea.core.base import AbstractMelleaTool, ComputedModelOutputThunk -from mellea.core.utils import FancyLogger +from mellea.core.utils import MelleaLogger from mellea.stdlib import functional as mfuncs # from mellea.stdlib.components.docs.document import Document @@ -79,7 +79,7 @@ async 
def react( turn_num = 0 while (turn_num < loop_budget) or (loop_budget == -1): turn_num += 1 - FancyLogger.get_logger().info(f"## ReACT TURN NUMBER {turn_num}") + MelleaLogger.get_logger().info(f"## ReACT TURN NUMBER {turn_num}") step, next_context = await mfuncs.aact( action=ReactThought(), diff --git a/mellea/stdlib/functional.py b/mellea/stdlib/functional.py index 335546e71..f4bdb2371 100644 --- a/mellea/stdlib/functional.py +++ b/mellea/stdlib/functional.py @@ -17,9 +17,9 @@ Component, ComputedModelOutputThunk, Context, - FancyLogger, GenerateLog, ImageBlock, + MelleaLogger, ModelOutputThunk, ModelToolCall, Requirement, @@ -441,7 +441,7 @@ def transform( if chosen_tool is None: chosen_tool = tools[0] - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"multiple tool calls returned in transform of {obj} with description '{transformation}'; picked `{chosen_tool.name}`" # type: ignore ) @@ -449,12 +449,12 @@ def transform( if chosen_tool: # Tell the user the function they should've called if no generated values were added. if len(chosen_tool._tool.args.keys()) == 0: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"the transform of {obj} with transformation description '{transformation}' resulted in a tool call with no generated arguments; consider calling the function `{chosen_tool._tool.name}` directly" ) new_ctx.add(chosen_tool) - FancyLogger.get_logger().info( + MelleaLogger.get_logger().info( "added a tool message from transform to the context" ) return chosen_tool._tool_output, new_ctx @@ -575,7 +575,7 @@ async def aact( tool_calls=tool_calls, ) as span: if not silence_context_type_warning and not isinstance(context, SimpleContext): - FancyLogger().get_logger().warning( + MelleaLogger.get_logger().warning( "Not using a SimpleContext with asynchronous requests could cause unexpected results due to stale contexts. Ensure you await between requests." 
"\nSee the async section of the docs: https://docs.mellea.ai/how-to/use-async-and-streaming" ) @@ -618,7 +618,7 @@ async def aact( if strategy is None: # Only use the strategy if one is provided. Add a warning if requirements were passed in though. if requirements is not None and len(requirements) > 0: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( "Calling the function with NO strategy BUT requirements. No requirement is being checked!" ) @@ -1201,7 +1201,7 @@ async def atransform( if chosen_tool is None: chosen_tool = tools[0] - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"multiple tool calls returned in transform of {obj} with description '{transformation}'; picked `{chosen_tool.name}`" # type: ignore ) @@ -1209,12 +1209,12 @@ async def atransform( if chosen_tool: # Tell the user the function they should've called if no generated values were added. if len(chosen_tool._tool.args.keys()) == 0: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"the transform of {obj} with transformation description '{transformation}' resulted in a tool call with no generated arguments; consider calling the function `{chosen_tool._tool.name}` directly" ) new_ctx.add(chosen_tool) - FancyLogger.get_logger().info( + MelleaLogger.get_logger().info( "added a tool message from transform to the context" ) return chosen_tool._tool_output, new_ctx diff --git a/mellea/stdlib/requirements/python_reqs.py b/mellea/stdlib/requirements/python_reqs.py index ad2d6a74b..3152acb71 100644 --- a/mellea/stdlib/requirements/python_reqs.py +++ b/mellea/stdlib/requirements/python_reqs.py @@ -9,9 +9,9 @@ UnsafeEnvironment, ) -from ...core import Context, FancyLogger, Requirement, ValidationResult +from ...core import Context, MelleaLogger, Requirement, ValidationResult -logger = FancyLogger.get_logger() +logger = MelleaLogger.get_logger() # region code extraction diff --git a/mellea/stdlib/requirements/requirement.py 
b/mellea/stdlib/requirements/requirement.py index 3ec5e8164..098151771 100644 --- a/mellea/stdlib/requirements/requirement.py +++ b/mellea/stdlib/requirements/requirement.py @@ -4,7 +4,7 @@ from collections.abc import Callable from typing import Any, overload -from ...core import CBlock, Context, FancyLogger, Requirement, ValidationResult +from ...core import CBlock, Context, MelleaLogger, Requirement, ValidationResult from ..components.intrinsic import Intrinsic @@ -37,7 +37,7 @@ def requirement_check_to_bool(x: CBlock | str) -> bool: likelihood = req_dict.get("requirement_likelihood", None) if likelihood is None: - FancyLogger.get_logger().warning( + MelleaLogger.get_logger().warning( f"could not get value from alora requirement output; looking for `requirement_likelihood` in {req_dict}" ) return False @@ -173,7 +173,7 @@ def simple_validate( def validate(ctx: Context) -> ValidationResult: o = ctx.last_output() if o is None or o.value is None: - FancyLogger.get_logger().warn( + MelleaLogger.get_logger().warn( "Last output of context was None. That might be a problem. We return validation as False to be able to continue..." ) return ValidationResult( diff --git a/mellea/stdlib/requirements/safety/guardian.py b/mellea/stdlib/requirements/safety/guardian.py index 26cd632af..e90e86157 100644 --- a/mellea/stdlib/requirements/safety/guardian.py +++ b/mellea/stdlib/requirements/safety/guardian.py @@ -9,7 +9,7 @@ BaseModelSubclass, CBlock, Context, - FancyLogger, + MelleaLogger, Requirement, ValidationResult, ) @@ -197,7 +197,7 @@ def __init__( except Exception: pass - self._logger = FancyLogger.get_logger() + self._logger = MelleaLogger.get_logger() def get_effective_risk(self) -> str: """Return the effective risk criteria to use for validation. 
diff --git a/mellea/stdlib/sampling/base.py b/mellea/stdlib/sampling/base.py index 9dca02087..388417c3b 100644 --- a/mellea/stdlib/sampling/base.py +++ b/mellea/stdlib/sampling/base.py @@ -25,12 +25,13 @@ Component, ComputedModelOutputThunk, Context, - FancyLogger, + MelleaLogger, Requirement, S, SamplingResult, SamplingStrategy, ValidationResult, + log_context, ) from ...plugins.manager import has_plugins, invoke_hook from ...plugins.types import HookType @@ -143,236 +144,243 @@ async def sample( """ validation_ctx = validation_ctx if validation_ctx is not None else context - flog = FancyLogger.get_logger() - - sampled_results: list[ComputedModelOutputThunk] = [] - sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] - sampled_actions: list[Component] = [] - sample_contexts: list[Context] = [] - - # The `logging_redirect_tqdm` approach did not work, so instead we will use the show_progress - # flag to determine whether we should show the pbar. - show_progress = show_progress and flog.getEffectiveLevel() <= FancyLogger.INFO - - reqs = [] - # global requirements supersede local requirements (global requirements can be defined by user) - # Todo: re-evaluate if this makes sense - if self.requirements is not None: - reqs += self.requirements - elif requirements is not None: - reqs += requirements - reqs = list(set(reqs)) - - loop_count = 0 - - # --- sampling_loop_start hook --- - effective_loop_budget = self.loop_budget - if has_plugins(HookType.SAMPLING_LOOP_START): - from ...plugins.hooks.sampling import SamplingLoopStartPayload - - start_payload = SamplingLoopStartPayload( - strategy_name=type(self).__name__, - action=action, - context=context, - requirements=reqs, - loop_budget=self.loop_budget, - ) - _, start_payload = await invoke_hook( - HookType.SAMPLING_LOOP_START, start_payload, backend=backend - ) - effective_loop_budget = start_payload.loop_budget + flog = MelleaLogger.get_logger() - loop_budget_range_iterator = ( - 
tqdm.tqdm(range(effective_loop_budget)) # type: ignore - if show_progress - else range(effective_loop_budget) # type: ignore - ) + with log_context(strategy=type(self).__name__, loop_budget=self.loop_budget): + sampled_results: list[ComputedModelOutputThunk] = [] + sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] + sampled_actions: list[Component] = [] + sample_contexts: list[Context] = [] - next_action = deepcopy(action) - next_context = context - for _ in loop_budget_range_iterator: # type: ignore - loop_count += 1 - if not show_progress: - flog.info(f"Running loop {loop_count} of {self.loop_budget}") - - # run a generation pass - result, result_ctx = await backend.generate_from_context( - next_action, - ctx=next_context, - format=format, - model_options=model_options, - tool_calls=tool_calls, + # The `logging_redirect_tqdm` approach did not work, so instead we will use the show_progress + # flag to determine whether we should show the pbar. + show_progress = ( + show_progress and flog.getEffectiveLevel() <= MelleaLogger.INFO ) - await result.avalue() - result = ComputedModelOutputThunk(result) - - # Sampling strategies may use different components from the original - # action. This might cause discrepancies in the expected parsed_repr - # type / value. Explicitly overwrite that here. - # TODO: See if there's a more elegant way for this so that each sampling - # strategy doesn't have to re-implement it. - result.parsed_repr = action.parse(result) - - # validation pass - val_scores_co = mfuncs.avalidate( - reqs=reqs, - context=result_ctx, - backend=backend, - output=result, - format=None, - model_options=model_options, - # tool_calls=tool_calls # Don't support using tool calls in validation strategies. 
- ) - val_scores = await val_scores_co - - # match up reqs with scores - constraint_scores = list(zip(reqs, val_scores)) - - # collect all data - sampled_results.append(result) - sampled_scores.append(constraint_scores) - sampled_actions.append(next_action) - sample_contexts.append(result_ctx) - - all_validations_passed = all(bool(s[1]) for s in constraint_scores) - - # --- sampling_iteration hook --- - if has_plugins(HookType.SAMPLING_ITERATION): - from ...plugins.hooks.sampling import SamplingIterationPayload - - iter_payload = SamplingIterationPayload( - iteration=loop_count, - action=next_action, - result=result, - validation_results=constraint_scores, - all_validations_passed=all_validations_passed, - valid_count=sum(1 for s in constraint_scores if bool(s[1])), - total_count=len(constraint_scores), + + reqs = [] + # global requirements supersede local requirements (global requirements can be defined by user) + # Todo: re-evaluate if this makes sense + if self.requirements is not None: + reqs += self.requirements + elif requirements is not None: + reqs += requirements + reqs = list(set(reqs)) + + loop_count = 0 + + # --- sampling_loop_start hook --- + effective_loop_budget = self.loop_budget + if has_plugins(HookType.SAMPLING_LOOP_START): + from ...plugins.hooks.sampling import SamplingLoopStartPayload + + start_payload = SamplingLoopStartPayload( + strategy_name=type(self).__name__, + action=action, + context=context, + requirements=reqs, + loop_budget=self.loop_budget, ) - await invoke_hook( - HookType.SAMPLING_ITERATION, iter_payload, backend=backend + _, start_payload = await invoke_hook( + HookType.SAMPLING_LOOP_START, start_payload, backend=backend ) + effective_loop_budget = start_payload.loop_budget - # if all vals are true -- break and return success - if all_validations_passed: - flog.info("SUCCESS") - assert ( - result._generate_log is not None - ) # Cannot be None after generation. 
- result._generate_log.is_final_result = True + loop_budget_range_iterator = ( + tqdm.tqdm(range(effective_loop_budget)) # type: ignore + if show_progress + else range(effective_loop_budget) # type: ignore + ) - # --- sampling_loop_end hook (success) --- - if has_plugins(HookType.SAMPLING_LOOP_END): - from ...plugins.hooks.sampling import SamplingLoopEndPayload + next_action = deepcopy(action) + next_context = context + for _ in loop_budget_range_iterator: # type: ignore + loop_count += 1 + if not show_progress: + flog.info(f"Running loop {loop_count} of {self.loop_budget}") + + # run a generation pass + result, result_ctx = await backend.generate_from_context( + next_action, + ctx=next_context, + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) + await result.avalue() + result = ComputedModelOutputThunk(result) + + # Sampling strategies may use different components from the original + # action. This might cause discrepancies in the expected parsed_repr + # type / value. Explicitly overwrite that here. + # TODO: See if there's a more elegant way for this so that each sampling + # strategy doesn't have to re-implement it. + result.parsed_repr = action.parse(result) + + # validation pass + val_scores_co = mfuncs.avalidate( + reqs=reqs, + context=result_ctx, + backend=backend, + output=result, + format=None, + model_options=model_options, + # tool_calls=tool_calls # Don't support using tool calls in validation strategies. 
+ ) + val_scores = await val_scores_co + + # match up reqs with scores + constraint_scores = list(zip(reqs, val_scores)) + + # collect all data + sampled_results.append(result) + sampled_scores.append(constraint_scores) + sampled_actions.append(next_action) + sample_contexts.append(result_ctx) + + all_validations_passed = all(bool(s[1]) for s in constraint_scores) + + # --- sampling_iteration hook --- + if has_plugins(HookType.SAMPLING_ITERATION): + from ...plugins.hooks.sampling import SamplingIterationPayload + + iter_payload = SamplingIterationPayload( + iteration=loop_count, + action=next_action, + result=result, + validation_results=constraint_scores, + all_validations_passed=all_validations_passed, + valid_count=sum(1 for s in constraint_scores if bool(s[1])), + total_count=len(constraint_scores), + ) + await invoke_hook( + HookType.SAMPLING_ITERATION, iter_payload, backend=backend + ) - end_payload = SamplingLoopEndPayload( + # if all vals are true -- break and return success + if all_validations_passed: + flog.info("SUCCESS") + assert ( + result._generate_log is not None + ) # Cannot be None after generation. + result._generate_log.is_final_result = True + + # --- sampling_loop_end hook (success) --- + if has_plugins(HookType.SAMPLING_LOOP_END): + from ...plugins.hooks.sampling import SamplingLoopEndPayload + + end_payload = SamplingLoopEndPayload( + success=True, + iterations_used=loop_count, + final_result=result, + final_action=next_action, + final_context=result_ctx, + all_results=sampled_results, + all_validations=sampled_scores, + ) + await invoke_hook( + HookType.SAMPLING_LOOP_END, end_payload, backend=backend + ) + + # SUCCESS !!!! 
+ return SamplingResult( + result_index=len(sampled_results) - 1, success=True, - iterations_used=loop_count, - final_result=result, - final_action=next_action, - final_context=result_ctx, - all_results=sampled_results, - all_validations=sampled_scores, + sample_generations=sampled_results, + sample_validations=sampled_scores, + sample_contexts=sample_contexts, + sample_actions=sampled_actions, ) - await invoke_hook( - HookType.SAMPLING_LOOP_END, end_payload, backend=backend + + else: + # log partial success and continue + failed = [s for s in constraint_scores if not bool(s[1])] + count_failed = len(failed) + failed_reqs = [ + r[0].description + if r[0].description is not None + else "[no description]" + for r in failed + ] + stringify_failed = "\n\t - " + "\n\t - ".join(failed_reqs) + flog.info( + f"FAILED. Valid: {len(constraint_scores) - count_failed}/{len(constraint_scores)}. Failed: {stringify_failed}" ) - # SUCCESS !!!! - return SamplingResult( - result_index=len(sampled_results) - 1, - success=True, - sample_generations=sampled_results, - sample_validations=sampled_scores, - sample_contexts=sample_contexts, - sample_actions=sampled_actions, + # If we did not pass all constraints, update the instruction and try again. + next_action, next_context = self.repair( + next_context, + result_ctx, + sampled_actions, + sampled_results, + sampled_scores, ) - else: - # log partial success and continue - failed = [s for s in constraint_scores if not bool(s[1])] - count_failed = len(failed) - failed_reqs = [ - r[0].description - if r[0].description is not None - else "[no description]" - for r in failed - ] - stringify_failed = "\n\t - " + "\n\t - ".join(failed_reqs) - flog.info( - f"FAILED. Valid: {len(constraint_scores) - count_failed}/{len(constraint_scores)}. 
Failed: {stringify_failed}" - ) + # --- sampling_repair hook --- + if has_plugins(HookType.SAMPLING_REPAIR): + from ...plugins.hooks.sampling import SamplingRepairPayload + + repair_payload = SamplingRepairPayload( + repair_type=getattr( + self, "_get_repair_type", lambda: "unknown" + )(), + failed_action=sampled_actions[-1], + failed_result=sampled_results[-1], + failed_validations=sampled_scores[-1], + repair_action=next_action, + repair_context=next_context, + repair_iteration=loop_count, + ) + await invoke_hook( + HookType.SAMPLING_REPAIR, repair_payload, backend=backend + ) - # If we did not pass all constraints, update the instruction and try again. - next_action, next_context = self.repair( - next_context, - result_ctx, - sampled_actions, - sampled_results, - sampled_scores, + flog.info( + f"Invoking select_from_failure after {len(sampled_results)} failed attempts." ) - # --- sampling_repair hook --- - if has_plugins(HookType.SAMPLING_REPAIR): - from ...plugins.hooks.sampling import SamplingRepairPayload - - repair_payload = SamplingRepairPayload( - repair_type=getattr(self, "_get_repair_type", lambda: "unknown")(), - failed_action=sampled_actions[-1], - failed_result=sampled_results[-1], - failed_validations=sampled_scores[-1], - repair_action=next_action, - repair_context=next_context, - repair_iteration=loop_count, - ) - await invoke_hook( - HookType.SAMPLING_REPAIR, repair_payload, backend=backend - ) + # if no valid result could be determined, find a last resort. + best_failed_index = self.select_from_failure( + sampled_actions, sampled_results, sampled_scores + ) + assert best_failed_index < len(sampled_results), ( + "The select_from_failure method did not return a valid result. It has to selected from failed_results." + ) - flog.info( - f"Invoking select_from_failure after {len(sampled_results)} failed attempts." - ) + assert ( + sampled_results[best_failed_index]._generate_log is not None + ) # Cannot be None after generation. 
+ sampled_results[best_failed_index]._generate_log.is_final_result = True # type: ignore - # if no valid result could be determined, find a last resort. - best_failed_index = self.select_from_failure( - sampled_actions, sampled_results, sampled_scores - ) - assert best_failed_index < len(sampled_results), ( - "The select_from_failure method did not return a valid result. It has to selected from failed_results." - ) - - assert ( - sampled_results[best_failed_index]._generate_log is not None - ) # Cannot be None after generation. - sampled_results[best_failed_index]._generate_log.is_final_result = True # type: ignore + # --- sampling_loop_end hook (failure) --- + if has_plugins(HookType.SAMPLING_LOOP_END): + from ...plugins.hooks.sampling import SamplingLoopEndPayload - # --- sampling_loop_end hook (failure) --- - if has_plugins(HookType.SAMPLING_LOOP_END): - from ...plugins.hooks.sampling import SamplingLoopEndPayload + _final_ctx = ( + sample_contexts[best_failed_index] if sample_contexts else context + ) + end_payload = SamplingLoopEndPayload( + success=False, + iterations_used=loop_count, + final_result=sampled_results[best_failed_index], + final_action=sampled_actions[best_failed_index], + final_context=_final_ctx, + failure_reason=f"Budget exhausted after {loop_count} iterations", + all_results=sampled_results, + all_validations=sampled_scores, + ) + await invoke_hook( + HookType.SAMPLING_LOOP_END, end_payload, backend=backend + ) - _final_ctx = ( - sample_contexts[best_failed_index] if sample_contexts else context - ) - end_payload = SamplingLoopEndPayload( + return SamplingResult( + result_index=best_failed_index, success=False, - iterations_used=loop_count, - final_result=sampled_results[best_failed_index], - final_action=sampled_actions[best_failed_index], - final_context=_final_ctx, - failure_reason=f"Budget exhausted after {loop_count} iterations", - all_results=sampled_results, - all_validations=sampled_scores, + sample_generations=sampled_results, + 
sample_validations=sampled_scores, + sample_actions=sampled_actions, + sample_contexts=sample_contexts, ) - await invoke_hook(HookType.SAMPLING_LOOP_END, end_payload, backend=backend) - - return SamplingResult( - result_index=best_failed_index, - success=False, - sample_generations=sampled_results, - sample_validations=sampled_scores, - sample_actions=sampled_actions, - sample_contexts=sample_contexts, - ) class RejectionSamplingStrategy(BaseSamplingStrategy): diff --git a/mellea/stdlib/sampling/budget_forcing.py b/mellea/stdlib/sampling/budget_forcing.py index 25faccdca..5da1d0fae 100644 --- a/mellea/stdlib/sampling/budget_forcing.py +++ b/mellea/stdlib/sampling/budget_forcing.py @@ -11,11 +11,12 @@ Component, ComputedModelOutputThunk, Context, - FancyLogger, + MelleaLogger, Requirement, S, SamplingResult, ValidationResult, + log_context, ) from ...stdlib import functional as mfuncs from .base import RejectionSamplingStrategy @@ -125,148 +126,151 @@ async def sample( """ validation_ctx = validation_ctx if validation_ctx is not None else context - flog = FancyLogger.get_logger() - - sampled_results: list[ComputedModelOutputThunk] = [] - sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] - sampled_actions: list[Component] = [] - sample_contexts: list[Context] = [] - - # The `logging_redirect_tqdm` approach did not work, so instead we will use the show_progress - # flag to determine whether we should show the pbar. 
- show_progress = show_progress and flog.getEffectiveLevel() <= FancyLogger.INFO - - reqs = [] - # global requirements supersede local requirements (global requirements can be defined by user) - # Todo: re-evaluate if this makes sense - if self.requirements is not None: - reqs += self.requirements - elif requirements is not None: - reqs += requirements - reqs = list(set(reqs)) - - loop_count = 0 - loop_budget_range_iterator = ( - tqdm.tqdm(range(self.loop_budget)) # type: ignore - if show_progress - else range(self.loop_budget) # type: ignore - ) - - next_action = deepcopy(action) - next_context = context - for _ in loop_budget_range_iterator: # type: ignore - loop_count += 1 - if not show_progress: - flog.info(f"Running loop {loop_count} of {self.loop_budget}") - - # TODO - # tool_calls is not supported for budget forcing - assert tool_calls is False, ( - "tool_calls is not supported with budget forcing" - ) - # TODO - assert isinstance(backend, OllamaModelBackend), ( - "Only ollama backend supported with budget forcing" - ) - # run a generation pass with budget forcing - result = await think_budget_forcing( - backend, - next_action, - ctx=context, - format=format, - tool_calls=tool_calls, - think_max_tokens=self.think_max_tokens, - answer_max_tokens=self.answer_max_tokens, - start_think_token=self.start_think_token, - end_think_token=self.end_think_token, - think_more_suffix=self.think_more_suffix, - answer_suffix=self.answer_suffix, - model_options=model_options, + flog = MelleaLogger.get_logger() + + with log_context(strategy=type(self).__name__, loop_budget=self.loop_budget): + sampled_results: list[ComputedModelOutputThunk] = [] + sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] + sampled_actions: list[Component] = [] + sample_contexts: list[Context] = [] + + # The `logging_redirect_tqdm` approach did not work, so instead we will use the show_progress + # flag to determine whether we should show the pbar. 
+ show_progress = ( + show_progress and flog.getEffectiveLevel() <= MelleaLogger.INFO ) - result_ctx = next_context - await result.avalue() - result = ComputedModelOutputThunk(result) - - # Sampling strategies may use different components from the original - # action. This might cause discrepancies in the expected parsed_repr - # type / value. Explicitly overwrite that here. - result.parsed_repr = action.parse(result) - - # validation pass - val_scores_co = mfuncs.avalidate( - reqs=reqs, - context=result_ctx, - backend=backend, - output=result, - format=format, - model_options=model_options, - # tool_calls=tool_calls # Don't support using tool calls in validation strategies. + + reqs = [] + # global requirements supersede local requirements (global requirements can be defined by user) + # Todo: re-evaluate if this makes sense + if self.requirements is not None: + reqs += self.requirements + elif requirements is not None: + reqs += requirements + reqs = list(set(reqs)) + + loop_count = 0 + loop_budget_range_iterator = ( + tqdm.tqdm(range(self.loop_budget)) # type: ignore + if show_progress + else range(self.loop_budget) # type: ignore ) - val_scores = await val_scores_co - - # match up reqs with scores - constraint_scores = list(zip(reqs, val_scores)) - - # collect all data - sampled_results.append(result) - sampled_scores.append(constraint_scores) - sampled_actions.append(next_action) - sample_contexts.append(result_ctx) - - # if all vals are true -- break and return success - if all(bool(s[1]) for s in constraint_scores): - flog.info("SUCCESS") - assert ( - result._generate_log is not None - ) # Cannot be None after generation. - result._generate_log.is_final_result = True - - # SUCCESS !!!! 
- return SamplingResult( - result_index=len(sampled_results) - 1, - success=True, - sample_generations=sampled_results, - sample_validations=sampled_scores, - sample_contexts=sample_contexts, - sample_actions=sampled_actions, + + next_action = deepcopy(action) + next_context = context + for _ in loop_budget_range_iterator: # type: ignore + loop_count += 1 + if not show_progress: + flog.info(f"Running loop {loop_count} of {self.loop_budget}") + + # TODO + # tool_calls is not supported for budget forcing + assert tool_calls is False, ( + "tool_calls is not supported with budget forcing" + ) + # TODO + assert isinstance(backend, OllamaModelBackend), ( + "Only ollama backend supported with budget forcing" + ) + # run a generation pass with budget forcing + result = await think_budget_forcing( + backend, + next_action, + ctx=context, + format=format, + tool_calls=tool_calls, + think_max_tokens=self.think_max_tokens, + answer_max_tokens=self.answer_max_tokens, + start_think_token=self.start_think_token, + end_think_token=self.end_think_token, + think_more_suffix=self.think_more_suffix, + answer_suffix=self.answer_suffix, + model_options=model_options, + ) + result_ctx = next_context + await result.avalue() + result = ComputedModelOutputThunk(result) + + # Sampling strategies may use different components from the original + # action. This might cause discrepancies in the expected parsed_repr + # type / value. Explicitly overwrite that here. + result.parsed_repr = action.parse(result) + + # validation pass + val_scores_co = mfuncs.avalidate( + reqs=reqs, + context=result_ctx, + backend=backend, + output=result, + format=format, + model_options=model_options, + # tool_calls=tool_calls # Don't support using tool calls in validation strategies. 
+ ) + val_scores = await val_scores_co + + # match up reqs with scores + constraint_scores = list(zip(reqs, val_scores)) + + # collect all data + sampled_results.append(result) + sampled_scores.append(constraint_scores) + sampled_actions.append(next_action) + sample_contexts.append(result_ctx) + + # if all vals are true -- break and return success + if all(bool(s[1]) for s in constraint_scores): + flog.info("SUCCESS") + assert ( + result._generate_log is not None + ) # Cannot be None after generation. + result._generate_log.is_final_result = True + + # SUCCESS !!!! + return SamplingResult( + result_index=len(sampled_results) - 1, + success=True, + sample_generations=sampled_results, + sample_validations=sampled_scores, + sample_contexts=sample_contexts, + sample_actions=sampled_actions, + ) + + else: + # log partial success and continue + count_valid = len([s for s in constraint_scores if bool(s[1])]) + flog.info(f"FAILED. Valid: {count_valid}/{len(constraint_scores)}") + + # If we did not pass all constraints, update the instruction and try again. + next_action, next_context = self.repair( + next_context, + result_ctx, + sampled_actions, + sampled_results, + sampled_scores, ) - else: - # log partial success and continue - count_valid = len([s for s in constraint_scores if bool(s[1])]) - flog.info(f"FAILED. Valid: {count_valid}/{len(constraint_scores)}") - - # If we did not pass all constraints, update the instruction and try again. - next_action, next_context = self.repair( - next_context, - result_ctx, - sampled_actions, - sampled_results, - sampled_scores, + flog.info( + f"Invoking select_from_failure after {len(sampled_results)} failed attempts." ) - flog.info( - f"Invoking select_from_failure after {len(sampled_results)} failed attempts." - ) - - # if no valid result could be determined, find a last resort. 
- best_failed_index = self.select_from_failure( - sampled_actions, sampled_results, sampled_scores - ) - assert best_failed_index < len(sampled_results), ( - "The select_from_failure method did not return a valid result. It has to selected from failed_results." - ) - - assert ( - sampled_results[best_failed_index]._generate_log is not None - ) # Cannot be None after generation. - sampled_results[best_failed_index]._generate_log.is_final_result = True # type: ignore - - return SamplingResult( - result_index=best_failed_index, - success=False, - sample_generations=sampled_results, - sample_validations=sampled_scores, - sample_actions=sampled_actions, - sample_contexts=sample_contexts, - ) + # if no valid result could be determined, find a last resort. + best_failed_index = self.select_from_failure( + sampled_actions, sampled_results, sampled_scores + ) + assert best_failed_index < len(sampled_results), ( + "The select_from_failure method did not return a valid result. It has to selected from failed_results." + ) + + assert ( + sampled_results[best_failed_index]._generate_log is not None + ) # Cannot be None after generation. 
+ sampled_results[best_failed_index]._generate_log.is_final_result = True # type: ignore + + return SamplingResult( + result_index=best_failed_index, + success=False, + sample_generations=sampled_results, + sample_validations=sampled_scores, + sample_actions=sampled_actions, + sample_contexts=sample_contexts, + ) diff --git a/mellea/stdlib/sampling/sofai.py b/mellea/stdlib/sampling/sofai.py index ff1c8f41f..792de2de4 100644 --- a/mellea/stdlib/sampling/sofai.py +++ b/mellea/stdlib/sampling/sofai.py @@ -20,13 +20,14 @@ Component, ComputedModelOutputThunk, Context, - FancyLogger, + MelleaLogger, Requirement, S, SamplingResult, SamplingStrategy, TemplateRepresentation, ValidationResult, + log_context, ) from ...stdlib import functional as mfuncs from ..components import Message @@ -431,7 +432,7 @@ def _prepare_s2_context( Returns: Tuple of (action_for_s2, context_for_s2). """ - flog = FancyLogger.get_logger() + flog = MelleaLogger.get_logger() if s2_mode == "fresh_start": # Clean slate: same prompt as S1 @@ -605,47 +606,144 @@ async def sample( "SOFAI requires ChatContext for conversation management." 
) - flog = FancyLogger.get_logger() - reqs: list[Requirement] = list(requirements) if requirements else [] + flog = MelleaLogger.get_logger() - # State tracking for all attempts - sampled_results: list[ComputedModelOutputThunk] = [] - sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] - sampled_actions: list[Component] = [] - sample_contexts: list[Context] = [] + with log_context(strategy=type(self).__name__, loop_budget=self.loop_budget): + reqs: list[Requirement] = list(requirements) if requirements else [] - # --------------------------------------------------------------------- - # PHASE 1: S1 Solver Loop - # --------------------------------------------------------------------- - flog.info( - f"SOFAI: Starting S1 Solver ({getattr(self.s1_solver_backend, 'model_id', 'unknown')}) " - f"loop (budget={self.loop_budget})" - ) + # State tracking for all attempts + sampled_results: list[ComputedModelOutputThunk] = [] + sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] + sampled_actions: list[Component] = [] + sample_contexts: list[Context] = [] - previous_failed_set: set[tuple[str | None, str | None, float | None]] = set() - loop_count = 0 - next_action = deepcopy(action) - next_context: Context = context + # --------------------------------------------------------------------- + # PHASE 1: S1 Solver Loop + # --------------------------------------------------------------------- + flog.info( + f"SOFAI: Starting S1 Solver ({getattr(self.s1_solver_backend, 'model_id', 'unknown')}) " + f"loop (budget={self.loop_budget})" + ) - show_progress = flog.getEffectiveLevel() <= FancyLogger.INFO - loop_iterator = ( - tqdm.tqdm(range(self.loop_budget), desc="S1 Solver") - if show_progress - else range(self.loop_budget) - ) + previous_failed_set: set[tuple[str | None, str | None, float | None]] = ( + set() + ) + loop_count = 0 + next_action = deepcopy(action) + next_context: Context = context + + show_progress = flog.getEffectiveLevel() <= 
MelleaLogger.INFO + loop_iterator = ( + tqdm.tqdm(range(self.loop_budget), desc="S1 Solver") + if show_progress + else range(self.loop_budget) + ) + + # Exit conditions: success returns immediately; no-improvement breaks + # early to S2 escalation; loop budget exhaustion flows to S2 escalation. + for _ in loop_iterator: + loop_count += 1 + if not show_progress: + flog.info( + f"SOFAI S1: Running loop {loop_count} of {self.loop_budget}" + ) + + # Generate and validate + ( + result, + result_ctx, + constraint_scores, + ) = await self._generate_and_validate( + solver_backend=self.s1_solver_backend, + action=next_action, + ctx=next_context, + reqs=reqs, + session_backend=backend, + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) - # Exit conditions: success returns immediately; no-improvement breaks - # early to S2 escalation; loop budget exhaustion flows to S2 escalation. - for _ in loop_iterator: - loop_count += 1 - if not show_progress: - flog.info(f"SOFAI S1: Running loop {loop_count} of {self.loop_budget}") + # Store attempt + sampled_results.append(result) + sampled_scores.append(constraint_scores) + sampled_actions.append(next_action) + sample_contexts.append(result_ctx) + + # Check for success + if all(bool(score[1]) for score in constraint_scores): + flog.info(f"SOFAI S1: SUCCESS on attempt {loop_count}") + assert result._generate_log is not None + result._generate_log.is_final_result = True + + # Exit with success + return SamplingResult( + result_index=len(sampled_results) - 1, + success=True, + sample_generations=sampled_results, + sample_validations=sampled_scores, + sample_contexts=sample_contexts, + sample_actions=sampled_actions, + ) + + # Log partial progress + count_valid = sum(1 for s in constraint_scores if bool(s[1])) + flog.info( + f"SOFAI S1: FAILED attempt {loop_count}. 
" + f"Valid: {count_valid}/{len(constraint_scores)}" + ) - # Generate and validate + # Check for no improvement (early exit to S2) + current_failed_set = { + (req.description, val.reason, val.score) + for req, val in constraint_scores + if not val.as_bool() + } + if loop_count > 1 and current_failed_set == previous_failed_set: + flog.warning( + f"SOFAI S1: No improvement detected between attempt " + f"{loop_count - 1} and {loop_count}. Escalating to S2 Solver." + ) + # Exit with no improvement + break + previous_failed_set = current_failed_set + + # Prepare repair for next iteration + if loop_count < self.loop_budget: + next_action, next_context = self.repair( + next_context, + result_ctx, + sampled_actions, + sampled_results, + sampled_scores, + ) + # Exit due to loop budget exhaustion or no improvement + + # --------------------------------------------------------------------- + # PHASE 2: S2 Solver Escalation + # --------------------------------------------------------------------- + flog.info( + f"SOFAI: S1 Solver completed {loop_count} attempts. " + f"Escalating to S2 Solver ({getattr(self.s2_solver_backend, 'model_id', 'unknown')})." 
+ ) + + # Prepare S2 context based on mode + s2_action, s2_context = self._prepare_s2_context( + s2_mode=self.s2_solver_mode, + original_action=action, + original_context=context, + last_result_ctx=result_ctx, + last_action=next_action, + sampled_results=sampled_results, + sampled_scores=sampled_scores, + loop_count=loop_count, + ) + + # Generate and validate with S2 result, result_ctx, constraint_scores = await self._generate_and_validate( - solver_backend=self.s1_solver_backend, - action=next_action, - ctx=next_context, + solver_backend=self.s2_solver_backend, + action=s2_action, + ctx=s2_context, reqs=reqs, session_backend=backend, format=format, @@ -653,19 +751,18 @@ async def sample( tool_calls=tool_calls, ) - # Store attempt + # Store S2 attempt sampled_results.append(result) sampled_scores.append(constraint_scores) - sampled_actions.append(next_action) + sampled_actions.append(s2_action) sample_contexts.append(result_ctx) - # Check for success - if all(bool(score[1]) for score in constraint_scores): - flog.info(f"SOFAI S1: SUCCESS on attempt {loop_count}") - assert result._generate_log is not None - result._generate_log.is_final_result = True + # Check S2 success + assert result._generate_log is not None + result._generate_log.is_final_result = True - # Exit with success + if all(bool(score[1]) for score in constraint_scores): + flog.info("SOFAI S2: SUCCESS") return SamplingResult( result_index=len(sampled_results) - 1, success=True, @@ -674,103 +771,17 @@ async def sample( sample_contexts=sample_contexts, sample_actions=sampled_actions, ) - - # Log partial progress - count_valid = sum(1 for s in constraint_scores if bool(s[1])) - flog.info( - f"SOFAI S1: FAILED attempt {loop_count}. 
" - f"Valid: {count_valid}/{len(constraint_scores)}" - ) - - # Check for no improvement (early exit to S2) - current_failed_set = { - (req.description, val.reason, val.score) - for req, val in constraint_scores - if not val.as_bool() - } - if loop_count > 1 and current_failed_set == previous_failed_set: + else: + count_valid = sum(1 for s in constraint_scores if bool(s[1])) flog.warning( - f"SOFAI S1: No improvement detected between attempt " - f"{loop_count - 1} and {loop_count}. Escalating to S2 Solver." + f"SOFAI S2: FAILED. Valid: {count_valid}/{len(constraint_scores)}. " + f"Returning S2 Solver's attempt as final result." ) - # Exit with no improvement - break - previous_failed_set = current_failed_set - - # Prepare repair for next iteration - if loop_count < self.loop_budget: - next_action, next_context = self.repair( - next_context, - result_ctx, - sampled_actions, - sampled_results, - sampled_scores, + return SamplingResult( + result_index=len(sampled_results) - 1, + success=False, + sample_generations=sampled_results, + sample_validations=sampled_scores, + sample_contexts=sample_contexts, + sample_actions=sampled_actions, ) - # Exit due to loop budget exhaustion or no improvement - - # --------------------------------------------------------------------- - # PHASE 2: S2 Solver Escalation - # --------------------------------------------------------------------- - flog.info( - f"SOFAI: S1 Solver completed {loop_count} attempts. " - f"Escalating to S2 Solver ({getattr(self.s2_solver_backend, 'model_id', 'unknown')})." 
- ) - - # Prepare S2 context based on mode - s2_action, s2_context = self._prepare_s2_context( - s2_mode=self.s2_solver_mode, - original_action=action, - original_context=context, - last_result_ctx=result_ctx, - last_action=next_action, - sampled_results=sampled_results, - sampled_scores=sampled_scores, - loop_count=loop_count, - ) - - # Generate and validate with S2 - result, result_ctx, constraint_scores = await self._generate_and_validate( - solver_backend=self.s2_solver_backend, - action=s2_action, - ctx=s2_context, - reqs=reqs, - session_backend=backend, - format=format, - model_options=model_options, - tool_calls=tool_calls, - ) - - # Store S2 attempt - sampled_results.append(result) - sampled_scores.append(constraint_scores) - sampled_actions.append(s2_action) - sample_contexts.append(result_ctx) - - # Check S2 success - assert result._generate_log is not None - result._generate_log.is_final_result = True - - if all(bool(score[1]) for score in constraint_scores): - flog.info("SOFAI S2: SUCCESS") - return SamplingResult( - result_index=len(sampled_results) - 1, - success=True, - sample_generations=sampled_results, - sample_validations=sampled_scores, - sample_contexts=sample_contexts, - sample_actions=sampled_actions, - ) - else: - count_valid = sum(1 for s in constraint_scores if bool(s[1])) - flog.warning( - f"SOFAI S2: FAILED. Valid: {count_valid}/{len(constraint_scores)}. " - f"Returning S2 Solver's attempt as final result." 
- ) - return SamplingResult( - result_index=len(sampled_results) - 1, - success=False, - sample_generations=sampled_results, - sample_validations=sampled_scores, - sample_contexts=sample_contexts, - sample_actions=sampled_actions, - ) diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index d18cfbdcf..1e3e3a672 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -25,9 +25,9 @@ Component, ComputedModelOutputThunk, Context, - FancyLogger, GenerateLog, ImageBlock, + MelleaLogger, ModelOutputThunk, Requirement, S, @@ -35,6 +35,7 @@ SamplingStrategy, ValidationResult, ) +from ..core.utils import _log_context from ..helpers import _run_async_in_thread from ..plugins.manager import has_plugins, invoke_hook from ..plugins.types import HookType @@ -191,7 +192,7 @@ def start_session( session.cleanup() ``` """ - logger = FancyLogger.get_logger() + logger = MelleaLogger.get_logger() # Get model_id string for logging and tracing if isinstance(model_id, ModelIdentifier): @@ -307,8 +308,9 @@ def __init__(self, backend: Backend, ctx: Context | None = None): self.id = str(uuid.uuid4()) self.backend = backend self.ctx: Context = ctx if ctx is not None else SimpleContext() - self._session_logger = FancyLogger.get_logger() + self._session_logger = MelleaLogger.get_logger() self._context_token = None + self._log_context_token = None self._session_span = None def __enter__(self): @@ -320,11 +322,22 @@ def __enter__(self): context_type=self.ctx.__class__.__name__, ).__enter__() self._context_token = _context_session.set(self) + self._log_context_token = _log_context.set( + { + **_log_context.get(), + "session_id": self.id, + "backend": self.backend.__class__.__name__, + "model_id": str(getattr(self.backend, "model_id", "unknown")), + } + ) return self def __exit__(self, exc_type, exc_val, exc_tb): """Exit context manager and cleanup session.""" self.cleanup() + if self._log_context_token is not None: + _log_context.reset(self._log_context_token) + 
self._log_context_token = None if self._context_token is not None: _context_session.reset(self._context_token) self._context_token = None diff --git a/mellea/stdlib/tools/interpreter.py b/mellea/stdlib/tools/interpreter.py index 6ee9c17bf..4c537cc32 100644 --- a/mellea/stdlib/tools/interpreter.py +++ b/mellea/stdlib/tools/interpreter.py @@ -19,9 +19,9 @@ from pathlib import Path from typing import Any -from ...core import FancyLogger +from ...core import MelleaLogger -logger = FancyLogger.get_logger() +logger = MelleaLogger.get_logger() @dataclass diff --git a/test/conftest.py b/test/conftest.py index 48b6c7d48..a904a06ba 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -6,7 +6,7 @@ import pytest import requests -from mellea.core import FancyLogger +from mellea.core import MelleaLogger # Try to import optional dependencies for system detection try: @@ -302,7 +302,8 @@ def cleanup_gpu_backend(backend, backend_name="unknown"): backend: The backend instance to clean up. backend_name: Name for logging. """ - logger = FancyLogger.get_logger() + + logger = MelleaLogger.get_logger() logger.info(f"Cleaning up {backend_name} backend GPU memory...") try: @@ -455,7 +456,7 @@ def pytest_collection_modifyitems(config, items): # Reorder tests by backend if requested if config.getoption("--group-by-backend", default=False): - logger = FancyLogger.get_logger() + logger = MelleaLogger.get_logger() logger.info("Grouping tests by backend (--group-by-backend enabled)") # Group items by backend @@ -517,7 +518,7 @@ def pytest_runtest_setup(item): prev_group = getattr(pytest_runtest_setup, "_last_backend_group", None) if prev_group is not None and current_group != prev_group: - logger = FancyLogger.get_logger() + logger = MelleaLogger.get_logger() logger.info( f"Backend transition: {prev_group} → {current_group}. " "Running GPU cleanup." 
@@ -527,7 +528,7 @@ def pytest_runtest_setup(item): # Warm up Ollama models when entering Ollama group if current_group == "ollama" and prev_group != "ollama": - logger = FancyLogger.get_logger() + logger = MelleaLogger.get_logger() host_str = os.environ.get("OLLAMA_HOST", "127.0.0.1") port = os.environ.get("OLLAMA_PORT", "11434") logger.info( @@ -551,7 +552,7 @@ def pytest_runtest_setup(item): # Evict Ollama models when leaving Ollama group if prev_group == "ollama" and current_group != "ollama": - logger = FancyLogger.get_logger() + logger = MelleaLogger.get_logger() host_str = os.environ.get("OLLAMA_HOST", "127.0.0.1") port = os.environ.get("OLLAMA_PORT", "11434") logger.info("Evicting ollama models from VRAM after ollama group...") @@ -613,7 +614,7 @@ def evict_ollama_models() -> None: Best-effort: errors are logged but never raised. """ - logger = FancyLogger.get_logger() + logger = MelleaLogger.get_logger() # Parse OLLAMA_HOST which may be "host", "host:port", or absent. host = os.environ.get("OLLAMA_HOST", "127.0.0.1") diff --git a/test/core/test_logger_plugin_hooks.py b/test/core/test_logger_plugin_hooks.py new file mode 100644 index 000000000..888252bcf --- /dev/null +++ b/test/core/test_logger_plugin_hooks.py @@ -0,0 +1,251 @@ +"""Integration tests verifying MelleaLogger works correctly inside plugin hook dispatch. + +Key properties verified: +- MelleaLogger.get_logger() is callable and usable from inside a hook handler. +- Log records emitted from inside a hook are captured. +- log_context fields set inside a hook appear on records from that hook. +- log_context fields set in the caller ARE visible inside AUDIT hook execution + (AUDIT hooks are awaited in the same asyncio task, so ContextVar state is inherited). 
+""" + +# pytest: integration + +from __future__ import annotations + +import logging +from typing import Any + +import pytest + +pytestmark = pytest.mark.integration + +pytest.importorskip("cpex.framework") + +import datetime + +# --------------------------------------------------------------------------- +# Minimal mock backend (avoids real LLM calls) +# --------------------------------------------------------------------------- +from unittest.mock import MagicMock + +from mellea.core.backend import Backend +from mellea.core.base import ( + CBlock, + Context, + GenerateLog, + GenerateType, + ModelOutputThunk, +) +from mellea.core.utils import ( + MelleaLogger, + clear_log_context, + log_context, + set_log_context, +) +from mellea.plugins import PluginMode, hook, register +from mellea.plugins.manager import shutdown_plugins +from mellea.stdlib.context import SimpleContext + + +class _MockBackend(Backend): + model_id = "mock-model" + + def __init__(self, *args, **kwargs): + pass + + async def _generate_from_context(self, action, ctx, **kwargs): + mot = MagicMock(spec=ModelOutputThunk) + glog = GenerateLog() + glog.prompt = "mocked prompt" + mot._generate_log = glog + mot.parsed_repr = None + mot._start = datetime.datetime.now() + + async def _avalue(): + return "mocked output" + + mot.avalue = _avalue + mot.value = "mocked output" + return mot, SimpleContext() + + async def generate_from_raw(self, actions, ctx, **kwargs): + return [] + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +async def reset_plugins(): + """Shut down and reset the plugin manager after every test.""" + yield + await shutdown_plugins() + + +@pytest.fixture(autouse=True) +def reset_log_context(): + """Ensure log context is clean before and after each test.""" + clear_log_context() + yield + clear_log_context() + + +# 
--------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestMelleaLoggerInHooks: + async def test_mellea_logger_callable_from_hook(self, caplog) -> None: + """MelleaLogger.get_logger() is usable inside a hook handler without error.""" + from mellea.stdlib.components import Instruction + from mellea.stdlib.sampling.base import RejectionSamplingStrategy + + fired: list[bool] = [] + + @hook("sampling_loop_start", mode=PluginMode.AUDIT) + async def log_hook(payload: Any, ctx: Any) -> None: + logger = MelleaLogger.get_logger() + logger.info("hook fired from MelleaLogger") + fired.append(True) + + register(log_hook) + + with caplog.at_level(logging.INFO, logger="fancy_logger"): + await RejectionSamplingStrategy(loop_budget=1).sample( + Instruction("test"), + context=SimpleContext(), + backend=_MockBackend(), + requirements=[], + format=None, + model_options=None, + tool_calls=False, + show_progress=False, + ) + + assert fired, "Hook did not fire" + assert any("hook fired from MelleaLogger" in r.message for r in caplog.records) + + async def test_log_context_set_inside_hook_appears_on_hook_records( + self, caplog + ) -> None: + """Context fields set inside a hook appear on records emitted in that hook.""" + from mellea.stdlib.components import Instruction + from mellea.stdlib.sampling.base import RejectionSamplingStrategy + + hook_records: list[logging.LogRecord] = [] + + @hook("sampling_loop_start", mode=PluginMode.AUDIT) + async def context_hook(payload: Any, ctx: Any) -> None: + with log_context(hook_trace_id="hook-abc"): + logger = MelleaLogger.get_logger() + # Emit via a plain handler so we can capture the LogRecord + record = logger.makeRecord( + name="fancy_logger", + level=logging.INFO, + fn="test", + lno=0, + msg="inside hook", + args=(), + exc_info=None, + ) + # Apply the context filter manually (as the logger would) + from mellea.core.utils 
import ContextFilter + + ContextFilter().filter(record) + hook_records.append(record) + + register(context_hook) + + await RejectionSamplingStrategy(loop_budget=1).sample( + Instruction("test"), + context=SimpleContext(), + backend=_MockBackend(), + requirements=[], + format=None, + model_options=None, + tool_calls=False, + show_progress=False, + ) + + assert hook_records, "Hook did not produce records" + assert getattr(hook_records[0], "hook_trace_id", None) == "hook-abc" + + async def test_log_context_is_visible_inside_hook(self) -> None: + """ContextVar state set in the caller IS visible inside hook execution. + + AUDIT hooks are awaited in the same asyncio task as the caller, so they + inherit the caller's ContextVar copy. This is the documented behaviour: + log_context fields set around a strategy.sample() call will appear on + records emitted inside hook handlers too. + """ + from mellea.stdlib.components import Instruction + from mellea.stdlib.sampling.base import RejectionSamplingStrategy + + hook_records: list[logging.LogRecord] = [] + + set_log_context(outer_field="visible-in-hook") + + @hook("sampling_loop_start", mode=PluginMode.AUDIT) + async def visibility_hook(payload: Any, ctx: Any) -> None: + from mellea.core.utils import ContextFilter + + logger = MelleaLogger.get_logger() + # Create a log record to test context visibility + record = logger.makeRecord( + name="fancy_logger", + level=logging.INFO, + fn="test", + lno=0, + msg="testing context visibility", + args=(), + exc_info=None, + ) + # Apply the context filter to populate context fields + ContextFilter().filter(record) + hook_records.append(record) + + register(visibility_hook) + + await RejectionSamplingStrategy(loop_budget=1).sample( + Instruction("test"), + context=SimpleContext(), + backend=_MockBackend(), + requirements=[], + format=None, + model_options=None, + tool_calls=False, + show_progress=False, + ) + + assert hook_records, "Hook did not fire" + assert getattr(hook_records[0], 
"outer_field", None) == "visible-in-hook", ( + "log_context fields should be visible inside AUDIT hooks (same asyncio task)" + ) + + async def test_sampling_log_context_fields_present_on_success_record( + self, caplog + ) -> None: + """strategy and loop_budget context fields appear on the SUCCESS log record.""" + from mellea.stdlib.components import Instruction + from mellea.stdlib.sampling.base import RejectionSamplingStrategy + + with caplog.at_level(logging.INFO, logger="fancy_logger"): + await RejectionSamplingStrategy(loop_budget=2).sample( + Instruction("test"), + context=SimpleContext(), + backend=_MockBackend(), + requirements=[], + format=None, + model_options=None, + tool_calls=False, + show_progress=False, + ) + + success_records = [r for r in caplog.records if r.getMessage() == "SUCCESS"] + assert success_records, "No SUCCESS record found" + record = success_records[0] + assert getattr(record, "strategy", None) == "RejectionSamplingStrategy" + assert getattr(record, "loop_budget", None) == 2 diff --git a/test/core/test_utils_logging.py b/test/core/test_utils_logging.py new file mode 100644 index 000000000..c4f5929d8 --- /dev/null +++ b/test/core/test_utils_logging.py @@ -0,0 +1,509 @@ +"""Unit tests for MelleaLogger, JsonFormatter, and ContextFilter enhancements.""" + +# pytest: unit + +import asyncio +import json +import logging +import threading +from typing import Any + +import pytest + +pytestmark = pytest.mark.unit + +from mellea.core.utils import ( + RESERVED_LOG_RECORD_ATTRS, + ContextFilter, + JsonFormatter, + MelleaLogger, + clear_log_context, + log_context, + set_log_context, +) + + +def _make_record(msg: str = "hello", level: int = logging.INFO) -> logging.LogRecord: + record = logging.LogRecord( + name="test", + level=level, + pathname="test_utils_logging.py", + lineno=1, + msg=msg, + args=(), + exc_info=None, + ) + return record + + +class TestJsonFormatterCoreSchema: + def test_returns_valid_json_string(self) -> None: + fmt = 
JsonFormatter() + output = fmt.format(_make_record("hi")) + parsed = json.loads(output) + assert isinstance(parsed, dict) + + def test_all_default_fields_present(self) -> None: + fmt = JsonFormatter() + parsed = json.loads(fmt.format(_make_record("hi"))) + for field in JsonFormatter._DEFAULT_FIELDS: + assert field in parsed, f"missing field: {field}" + + def test_message_content(self) -> None: + fmt = JsonFormatter() + parsed = json.loads(fmt.format(_make_record("test message"))) + assert parsed["message"] == "test message" + + def test_level_name(self) -> None: + fmt = JsonFormatter() + parsed = json.loads(fmt.format(_make_record(level=logging.WARNING))) + assert parsed["level"] == "WARNING" + + def test_exception_field_added_when_exc_info(self) -> None: + fmt = JsonFormatter() + try: + raise ValueError("boom") + except ValueError: + import sys + + record = _make_record("oops") + record.exc_info = sys.exc_info() + parsed = json.loads(fmt.format(record)) + assert "exception" in parsed + assert "ValueError" in parsed["exception"] + + +class TestJsonFormatterFieldConfig: + def test_include_fields_limits_output(self) -> None: + fmt = JsonFormatter(include_fields=["timestamp", "level", "message"]) + parsed = json.loads(fmt.format(_make_record())) + assert set(parsed.keys()) == {"timestamp", "level", "message"} + + def test_exclude_fields_removes_keys(self) -> None: + fmt = JsonFormatter(exclude_fields=["process_id", "thread_id"]) + parsed = json.loads(fmt.format(_make_record())) + assert "process_id" not in parsed + assert "thread_id" not in parsed + assert "level" in parsed # other fields still present + + def test_extra_fields_merged(self) -> None: + fmt = JsonFormatter(extra_fields={"service": "mellea", "env": "test"}) + parsed = json.loads(fmt.format(_make_record())) + assert parsed["service"] == "mellea" + assert parsed["env"] == "test" + + def test_extra_fields_override_core(self) -> None: + # static extras come *after* core fields — they win on collision + fmt = 
JsonFormatter(extra_fields={"level": "OVERRIDDEN"}) + parsed = json.loads(fmt.format(_make_record(level=logging.DEBUG))) + assert parsed["level"] == "OVERRIDDEN" + + def test_timestamp_format_respected(self) -> None: + fmt = JsonFormatter(timestamp_format="%Y") + parsed = json.loads(fmt.format(_make_record())) + # Should be just a 4-digit year + assert len(parsed["timestamp"]) == 4 + assert parsed["timestamp"].isdigit() + + +class TestJsonFormatterContextInjection: + def setup_method(self) -> None: + clear_log_context() + + def teardown_method(self) -> None: + clear_log_context() + + def test_context_fields_appear_in_output(self) -> None: + set_log_context(trace_id="abc-123") + fmt = JsonFormatter() + parsed = json.loads(fmt.format(_make_record())) + assert parsed.get("trace_id") == "abc-123" + + def test_multiple_context_fields(self) -> None: + set_log_context(trace_id="t1", request_id="r1", user="alice") + fmt = JsonFormatter() + parsed = json.loads(fmt.format(_make_record())) + assert parsed["trace_id"] == "t1" + assert parsed["request_id"] == "r1" + assert parsed["user"] == "alice" + + def test_clear_context_removes_fields(self) -> None: + set_log_context(trace_id="gone") + clear_log_context() + fmt = JsonFormatter() + parsed = json.loads(fmt.format(_make_record())) + assert "trace_id" not in parsed + + def test_context_is_thread_local(self) -> None: + """Fields set in one thread must not bleed into another.""" + results: dict[str, Any] = {} + barrier = threading.Barrier(2) + + def worker_a() -> None: + set_log_context(trace_id="thread-a") + barrier.wait() # both threads read context at the same time + fmt = JsonFormatter() + results["a"] = json.loads(fmt.format(_make_record())) + clear_log_context() + + def worker_b() -> None: + # does NOT call set_log_context + barrier.wait() # both threads read context at the same time + fmt = JsonFormatter() + results["b"] = json.loads(fmt.format(_make_record())) + + ta = threading.Thread(target=worker_a) + tb = 
threading.Thread(target=worker_b) + ta.start() + tb.start() + ta.join() + tb.join() + + assert results["a"].get("trace_id") == "thread-a" + assert "trace_id" not in results["b"] + + +class TestContextFilter: + def setup_method(self) -> None: + clear_log_context() + + def teardown_method(self) -> None: + clear_log_context() + + def test_filter_always_returns_true(self) -> None: + f = ContextFilter() + assert f.filter(_make_record()) is True + + def test_filter_attaches_context_fields_to_record(self) -> None: + set_log_context(span_id="span-999") + f = ContextFilter() + record = _make_record() + f.filter(record) + assert getattr(record, "span_id", None) == "span-999" + + def test_filter_noop_when_no_context(self) -> None: + f = ContextFilter() + record = _make_record() + f.filter(record) + assert not hasattr(record, "trace_id") + + +@pytest.mark.unit +class TestMelleaLoggerLogLevel: + def _reset(self) -> None: + MelleaLogger.logger = None + logging.getLogger("fancy_logger").handlers.clear() + logging.getLogger("fancy_logger").setLevel(logging.NOTSET) + + def teardown_method(self) -> None: + self._reset() + + def test_default_level_is_info(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("MELLEA_LOG_LEVEL", raising=False) + assert MelleaLogger._resolve_log_level() == logging.INFO + + def test_mellea_log_level_env_var(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("MELLEA_LOG_LEVEL", "DEBUG") + assert MelleaLogger._resolve_log_level() == logging.DEBUG + + def test_mellea_log_level_warning(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("MELLEA_LOG_LEVEL", "WARNING") + assert MelleaLogger._resolve_log_level() == logging.WARNING + + def test_invalid_level_falls_back_to_info( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("MELLEA_LOG_LEVEL", "BOGUS") + assert MelleaLogger._resolve_log_level() == logging.INFO + + +@pytest.mark.unit +class TestMelleaLoggerJsonConsole: + def 
_reset(self) -> None: + MelleaLogger.logger = None + logger = logging.getLogger("fancy_logger") + logger.handlers.clear() + logger.setLevel(logging.NOTSET) + + def setup_method(self) -> None: + self._reset() + + def teardown_method(self) -> None: + self._reset() + + def _get_stream_handler(self) -> logging.StreamHandler: # type: ignore[type-arg] + logger = MelleaLogger.get_logger() + handlers = [h for h in logger.handlers if isinstance(h, logging.StreamHandler)] + # RESTHandler is a subclass of Handler but not StreamHandler, so this + # correctly picks the console handler. + return handlers[0] + + def test_default_uses_custom_formatter( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + from mellea.core.utils import CustomFormatter + + monkeypatch.delenv("MELLEA_LOG_JSON", raising=False) + handler = self._get_stream_handler() + assert isinstance(handler.formatter, CustomFormatter) + + def test_json_console_enabled_with_true( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("MELLEA_LOG_JSON", "true") + handler = self._get_stream_handler() + assert isinstance(handler.formatter, JsonFormatter) + + def test_json_console_enabled_with_1(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("MELLEA_LOG_JSON", "1") + handler = self._get_stream_handler() + assert isinstance(handler.formatter, JsonFormatter) + + def test_json_console_enabled_with_yes( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + monkeypatch.setenv("MELLEA_LOG_JSON", "yes") + handler = self._get_stream_handler() + assert isinstance(handler.formatter, JsonFormatter) + + def test_json_console_disabled_with_false( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + from mellea.core.utils import CustomFormatter + + monkeypatch.setenv("MELLEA_LOG_JSON", "false") + handler = self._get_stream_handler() + assert isinstance(handler.formatter, CustomFormatter) + + +@pytest.mark.unit +class TestMelleaLoggerContextFilterWired: + def setup_method(self) -> None: + 
MelleaLogger.logger = None + logging.getLogger("fancy_logger").handlers.clear() + logging.getLogger("fancy_logger").setLevel(logging.NOTSET) + clear_log_context() + + def teardown_method(self) -> None: + MelleaLogger.logger = None + logging.getLogger("fancy_logger").handlers.clear() + logging.getLogger("fancy_logger").setLevel(logging.NOTSET) + clear_log_context() + + def test_context_filter_present(self) -> None: + logger = MelleaLogger.get_logger() + assert any(isinstance(f, ContextFilter) for f in logger.filters) + + +class TestLogContext: + def setup_method(self) -> None: + clear_log_context() + + def teardown_method(self) -> None: + clear_log_context() + + def test_fields_present_inside_block(self) -> None: + fmt = JsonFormatter() + with log_context(trace_id="ctx-1"): + parsed = json.loads(fmt.format(_make_record())) + assert parsed["trace_id"] == "ctx-1" + + def test_fields_removed_after_block(self) -> None: + fmt = JsonFormatter() + with log_context(trace_id="ctx-2"): + pass + parsed = json.loads(fmt.format(_make_record())) + assert "trace_id" not in parsed + + def test_cleanup_on_exception(self) -> None: + fmt = JsonFormatter() + with pytest.raises(RuntimeError): + with log_context(trace_id="ctx-err"): + raise RuntimeError("boom") + parsed = json.loads(fmt.format(_make_record())) + assert "trace_id" not in parsed + + def test_nested_contexts_preserve_outer(self) -> None: + fmt = JsonFormatter() + with log_context(outer="yes"): + with log_context(inner="yes"): + parsed = json.loads(fmt.format(_make_record())) + assert parsed["outer"] == "yes" + assert parsed["inner"] == "yes" + # inner should be gone, outer still present + parsed = json.loads(fmt.format(_make_record())) + assert parsed["outer"] == "yes" + assert "inner" not in parsed + # both gone + parsed = json.loads(fmt.format(_make_record())) + assert "outer" not in parsed + + def test_nested_same_key_restores_outer(self) -> None: + fmt = JsonFormatter() + with log_context(trace_id="outer"): + with 
log_context(trace_id="inner"): + parsed = json.loads(fmt.format(_make_record())) + assert parsed["trace_id"] == "inner" + parsed = json.loads(fmt.format(_make_record())) + assert parsed["trace_id"] == "outer" + parsed = json.loads(fmt.format(_make_record())) + assert "trace_id" not in parsed + + def test_rejects_reserved_attribute(self) -> None: + with pytest.raises(ValueError, match="reserved"): + with log_context(levelname="BAD"): + pass + + +class TestLogContextAsyncIsolation: + """Verify that concurrent asyncio tasks cannot contaminate each other's context.""" + + def setup_method(self) -> None: + clear_log_context() + + def teardown_method(self) -> None: + clear_log_context() + + def test_concurrent_tasks_isolated(self) -> None: + """Fields set inside one asyncio.Task must not bleed into a sibling task.""" + fmt = JsonFormatter() + results: dict[str, Any] = {} + + async def task_a() -> None: + with log_context(trace_id="task-a"): + # Yield so task_b can run and attempt to overwrite the context + await asyncio.sleep(0) + results["a"] = json.loads(fmt.format(_make_record())) + + async def task_b() -> None: + with log_context(trace_id="task-b"): + await asyncio.sleep(0) + results["b"] = json.loads(fmt.format(_make_record())) + + async def run() -> None: + await asyncio.gather( + asyncio.create_task(task_a()), asyncio.create_task(task_b()) + ) + + asyncio.run(run()) + + assert results["a"].get("trace_id") == "task-a" + assert results["b"].get("trace_id") == "task-b" + + def test_task_context_does_not_leak_after_completion(self) -> None: + """A task's context fields must not persist into the caller after the task ends.""" + fmt = JsonFormatter() + + async def child() -> None: + set_log_context(trace_id="child-task") + + async def run() -> dict[str, object]: + await asyncio.create_task(child()) + # The caller's context should be unaffected + return json.loads(fmt.format(_make_record())) + + parsed = asyncio.run(run()) + assert "trace_id" not in parsed + + +class 
TestReservedAttributeValidation: + def setup_method(self) -> None: + clear_log_context() + + def teardown_method(self) -> None: + clear_log_context() + + def test_set_log_context_rejects_reserved_key(self) -> None: + with pytest.raises(ValueError, match="reserved"): + set_log_context(module="bad") + + def test_set_log_context_rejects_multiple_reserved(self) -> None: + with pytest.raises(ValueError, match="reserved"): + set_log_context(thread="x", process="y") + + def test_set_log_context_accepts_non_reserved(self) -> None: + set_log_context(custom_field="fine") + fmt = JsonFormatter() + parsed = json.loads(fmt.format(_make_record())) + assert parsed["custom_field"] == "fine" + + def test_reserved_set_is_non_empty(self) -> None: + assert len(RESERVED_LOG_RECORD_ATTRS) > 10 + + +@pytest.mark.unit +class TestGetLoggerThreadSafety: + def setup_method(self) -> None: + MelleaLogger.logger = None + logging.getLogger("fancy_logger").handlers.clear() + logging.getLogger("fancy_logger").setLevel(logging.NOTSET) + + def teardown_method(self) -> None: + MelleaLogger.logger = None + logging.getLogger("fancy_logger").handlers.clear() + logging.getLogger("fancy_logger").setLevel(logging.NOTSET) + + def test_concurrent_get_logger_returns_same_instance(self) -> None: + """Multiple threads calling get_logger() must all get the same object.""" + results: list[logging.Logger] = [] + barrier = threading.Barrier(4) + + def worker() -> None: + barrier.wait() + results.append(MelleaLogger.get_logger()) + + threads = [threading.Thread(target=worker) for _ in range(4)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(results) == 4 + assert all(r is results[0] for r in results) + + +class TestJsonFormatterFormatSignature: + def test_format_returns_str(self) -> None: + """JsonFormatter.format returns str — no type: ignore needed.""" + fmt = JsonFormatter() + result = fmt.format(_make_record("check")) + assert isinstance(result, str) + json.loads(result) # also 
valid JSON + + +class TestFormatAsDict: + def test_format_as_dict_returns_same_as_internal(self) -> None: + fmt = JsonFormatter() + record = _make_record("hi") + assert fmt.format_as_dict(record) == fmt._build_log_dict(record) + + def test_format_as_dict_malformed_args_does_not_raise(self) -> None: + fmt = JsonFormatter() + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="x.py", + lineno=1, + msg="val: %s %s", + args=("only_one",), + exc_info=None, + ) + result = fmt.format_as_dict(record) + assert "message" in result + assert "format error" in result["message"] + + +class TestIncludeFieldsValidation: + def test_valid_include_fields_accepted(self) -> None: + fmt = JsonFormatter(include_fields=["timestamp", "level", "message"]) + parsed = json.loads(fmt.format(_make_record())) + assert set(parsed.keys()) == {"timestamp", "level", "message"} + + def test_unknown_include_field_raises(self) -> None: + with pytest.raises(ValueError, match="unknown field names"): + JsonFormatter(include_fields=["timestamp", "bogus_field"]) + + def test_all_default_fields_accepted(self) -> None: + fmt = JsonFormatter(include_fields=list(JsonFormatter._DEFAULT_FIELDS)) + parsed = json.loads(fmt.format(_make_record())) + assert set(parsed.keys()) == set(JsonFormatter._DEFAULT_FIELDS) diff --git a/test/telemetry/test_logging.py b/test/telemetry/test_logging.py index 595def8c0..b8cddbce8 100644 --- a/test/telemetry/test_logging.py +++ b/test/telemetry/test_logging.py @@ -27,14 +27,14 @@ def _reset_logging_modules(): import mellea.core.utils import mellea.telemetry.logging - from mellea.core.utils import FancyLogger + from mellea.core.utils import MelleaLogger # Clear any existing handlers from previous tests fancy_logger = logging.getLogger("fancy_logger") fancy_logger.handlers.clear() - # Reset FancyLogger singleton - FancyLogger.logger = None + # Reset MelleaLogger singleton + MelleaLogger.logger = None # Force reload of logging module and core.utils to pick up 
env vars importlib.reload(mellea.telemetry.logging) @@ -160,28 +160,28 @@ def test_get_otlp_log_handler_can_be_added_to_logger(enable_otlp_logging): logger.removeHandler(handler) -# FancyLogger Integration Tests +# MelleaLogger Integration Tests def test_fancy_logger_includes_otlp_handler_when_enabled(enable_otlp_logging): - """Test that FancyLogger includes OTLP handler when enabled.""" - from mellea.core.utils import FancyLogger + """Test that MelleaLogger includes OTLP handler when enabled.""" + from mellea.core.utils import MelleaLogger - logger = FancyLogger.get_logger() + logger = MelleaLogger.get_logger() # Check that logger has handlers assert len(logger.handlers) > 0 # Check if any handler is a LoggingHandler (OTLP) has_otlp_handler = any(isinstance(h, LoggingHandler) for h in logger.handlers) # type: ignore - assert has_otlp_handler, "FancyLogger should have OTLP handler when enabled" + assert has_otlp_handler, "MelleaLogger should have OTLP handler when enabled" def test_fancy_logger_works_without_otlp(clean_logging_env): - """Test that FancyLogger works normally when OTLP is disabled.""" - from mellea.core.utils import FancyLogger + """Test that MelleaLogger works normally when OTLP is disabled.""" + from mellea.core.utils import MelleaLogger - logger = FancyLogger.get_logger() + logger = MelleaLogger.get_logger() # Should still have REST and console handlers assert len(logger.handlers) >= 2 @@ -189,7 +189,7 @@ def test_fancy_logger_works_without_otlp(clean_logging_env): # Should not have OTLP handler has_otlp_handler = any(isinstance(h, LoggingHandler) for h in logger.handlers) # type: ignore assert not has_otlp_handler, ( - "FancyLogger should not have OTLP handler when disabled" + "MelleaLogger should not have OTLP handler when disabled" ) # Verify logger can log messages (backward compatibility)