diff --git a/docs/examples/context/README.md b/docs/examples/context/README.md index dde027bc5..e7b8b3752 100644 --- a/docs/examples/context/README.md +++ b/docs/examples/context/README.md @@ -1,13 +1,15 @@ # Context Examples -This directory contains examples demonstrating how to work with Mellea's context system, particularly when using sampling strategies and validation. +This directory contains examples demonstrating how to work with Mellea's context system: inspecting per-attempt contexts produced by sampling strategies, and shrinking contexts with the `Compactor` protocol. ## Files ### contexts_with_sampling.py + Shows how to retrieve and inspect context information when using sampling strategies and validation. **Key Features:** + - Using `RejectionSamplingStrategy` with requirements - Accessing `SamplingResult` objects to inspect generation attempts - Retrieving context for different generation attempts @@ -15,10 +17,34 @@ Shows how to retrieve and inspect context information when using sampling strate - Understanding the context tree structure **Usage:** -```bash + +```bash python docs/examples/context/contexts_with_sampling.py ``` +### window_compactor.py + +`WindowCompactor` — opt-in by passing `compactor=` (or the `window_size=` sugar). Demonstrates system-prefix pinning, `pin_system_and_initial_user`, `pin_nothing` (pure last-N), and `size=0` to clear the body. + +### threshold_compactor.py + +`ThresholdCompactor` — gate an inner compactor on the conversation's cumulative token size. The reading is taken from the most recent `ModelOutputThunk`'s `total_tokens`, which for a chat backend equals `prompt_tokens` (full conversation history sent to the model) + `completion_tokens` (reply). The gate fires once the running conversation size crosses the threshold; once compaction shrinks the context, the next call produces a smaller reading and the gate closes again. 
+ +### custom_compactor.py + +Implement the `Compactor` protocol with a plain class (no inheritance). Shows Pattern 1 (wired into `ChatContext`) and Pattern 2 (manual `compact()` call). + +### react_compaction.py + +Compose the ReACT loop with a sync `Compactor`. Two integration points: + +- **Per-add** — wire a `Compactor` onto the `ChatContext` so it runs every time `react()` appends a Message, ToolMessage, or thunk. +- **Per-turn** — pass `compactor=` to `react()`; it fires once per ReACT iteration after the tool observation. + +`LLMSummarizeCompactor` is also a sync `Compactor` — it hides the async backend call internally (worker thread when called from an already-running event loop) so callers don't have to think about sync vs async. + +Use `pin_react_initiator` (from `mellea.stdlib.components.react`) as the predicate so the goal and tool registration survive compaction. + ## Concepts Demonstrated - **Sampling Results**: Working with `SamplingResult` objects @@ -26,6 +52,8 @@ python docs/examples/context/contexts_with_sampling.py - **Multiple Attempts**: Examining different generation attempts - **Context Trees**: Understanding how contexts link together - **Validation Context**: Inspecting how requirements were evaluated +- **Compaction Protocol**: Sync `Compactor` for per-`add()` shrinking +- **Pin Predicates**: Auto-protect leading system messages or the user's initial prompt during compaction ## Key APIs @@ -48,8 +76,23 @@ gen_ctx.previous_node.node_data val_ctx.node_data ``` +```python +# Wire a compactor into a ChatContext (Pattern 1 — runs on every add()) +from mellea.stdlib.context import ChatContext, WindowCompactor, ThresholdCompactor + +ctx = ChatContext(compactor=WindowCompactor(size=5)) # default: pin_system +ctx = ChatContext(window_size=5) # sugar for the line above +ctx = ChatContext( + compactor=ThresholdCompactor(WindowCompactor(size=5), threshold=8000), +) + +# Manual compaction (Pattern 2) +ctx = WindowCompactor(size=0).compact(ctx) # drop 
body, keep pinned prefix +``` + ## Related Documentation -- See `mellea/stdlib/context.py` for context implementation +- See `mellea/stdlib/context/` for context and compactor implementations - See `mellea/stdlib/sampling/` for sampling strategies -- See `docs/dev/spans.md` for context architecture details \ No newline at end of file +- See `mellea/stdlib/frameworks/react.py` for the ReACT loop +- See `docs/dev/spans.md` for context architecture details diff --git a/docs/examples/context/custom_compactor.py b/docs/examples/context/custom_compactor.py new file mode 100644 index 000000000..663f21b4a --- /dev/null +++ b/docs/examples/context/custom_compactor.py @@ -0,0 +1,63 @@ +# pytest: unit +"""Implementing the Compactor protocol — anything with ``compact()`` works. + +The protocol is structurally typed: a class with a ``compact(ctx, *, +backend=None) -> ChatContext`` method is a valid Compactor. No +inheritance is required. +""" + +from mellea.stdlib.components.chat import Message +from mellea.stdlib.context import ChatContext, Compactor +from mellea.stdlib.context.chat import _rebuild_chat_context + + +class TruncateOldest: + """Drop only the very first body component each call. + + Demonstrates the smallest possible Compactor implementation. Pattern + 1 (wired into ``ChatContext``) means each ``add()`` appends the + new item, then drops the oldest — net result: the context never grows. 
+ """ + + def compact(self, ctx, *, backend=None): + items = ctx.as_list() + if len(items) <= 1: + return ctx + return _rebuild_chat_context(items[1:], compactor=ctx._compactor) + + +def pattern_1_wired_into_context(): + """Pattern 1: compactor lives on the context, runs in ``add()``.""" + ctx = ChatContext(compactor=TruncateOldest()) + for i in range(4): + ctx = ctx.add(Message("user", f"msg {i}")) + return [m.content for m in ctx.as_list()] + # → ['msg 3'] (oldest dropped after each append) + + +def pattern_2_manual_call(): + """Pattern 2: caller invokes ``compact()`` directly between turns.""" + ctx = ChatContext(window_size=10_000) # permissive — no auto-compaction + for i in range(5): + ctx = ctx.add(Message("user", f"msg {i}")) + truncated = TruncateOldest().compact(ctx) + return [m.content for m in truncated.as_list()] + + +def structural_typing_check(): + """The Compactor protocol is satisfied structurally, no inheritance.""" + c: Compactor = TruncateOldest() # mypy-checked Protocol assignment + return type(c).__name__ + + +if __name__ == "__main__": + for fn in [pattern_1_wired_into_context, pattern_2_manual_call]: + print(f"--- {fn.__name__} ---") + print(fn()) + print(f"structural typing: {structural_typing_check()} satisfies Compactor") + + +def test_custom_compactor_examples(): + assert pattern_1_wired_into_context() == ["msg 3"] + assert pattern_2_manual_call() == ["msg 1", "msg 2", "msg 3", "msg 4"] + assert structural_typing_check() == "TruncateOldest" diff --git a/docs/examples/context/react_compaction.py b/docs/examples/context/react_compaction.py new file mode 100644 index 000000000..06c6c76ec --- /dev/null +++ b/docs/examples/context/react_compaction.py @@ -0,0 +1,237 @@ +# pytest: unit +"""Compose the ReACT loop with a sync `Compactor`. + +Two integration points are available, and they're complementary: + +1. **Per-add** — the `ChatContext`'s own compactor runs every time the + ReACT loop appends a Message, ToolMessage, or thunk. 
This is fine + for cheap strategies like `WindowCompactor`. +2. **Per-turn** — pass `compactor=` to ``react(...)`` to invoke a + compactor once per ReACT iteration after the tool observation. Use + it for heavier strategies that should fire at turn boundaries + instead of on every component append. + +In both cases use ``pin_react_initiator`` (from +``mellea.stdlib.components.react``) so the goal and tool registration +survive compaction. + +This example exercises the wiring end-to-end against a fake backend so +no LLM is required. +""" + +from __future__ import annotations + +import asyncio +from collections.abc import Sequence +from dataclasses import dataclass + +from mellea.backends.tools import MelleaTool +from mellea.core.backend import Backend, BaseModelSubclass +from mellea.core.base import ( + C, + CBlock, + Component, + Context, + GenerateLog, + ModelOutputThunk, + ModelToolCall, +) +from mellea.stdlib.components.react import ( + MELLEA_FINALIZER_TOOL, + ReactInitiator, + _mellea_finalize_tool, + pin_react_initiator, +) +from mellea.stdlib.context import ChatContext, WindowCompactor +from mellea.stdlib.frameworks.react import react + +# --------------------------------------------------------------------------- # +# Fake backend so the example runs without an LLM # +# --------------------------------------------------------------------------- # + + +@dataclass +class _ScriptedTurn: + value: str + tool_calls: dict[str, ModelToolCall] | None = None + + +class ScriptedBackend(Backend): + """Returns pre-scripted responses; no real model is called.""" + + def __init__(self, script: list[_ScriptedTurn]) -> None: + self._script = iter(script) + + async def _generate_from_context( + self, + action: Component[C] | CBlock, + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> tuple[ModelOutputThunk[C], Context]: + turn = next(self._script) + mot: ModelOutputThunk = 
ModelOutputThunk( + value=turn.value, tool_calls=turn.tool_calls + ) + mot._generate_log = GenerateLog(is_final_result=True) + return mot, ctx.add(action).add(mot) + + async def generate_from_raw( + self, + actions: Sequence[Component[C] | CBlock], + ctx: Context, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> list[ModelOutputThunk]: + raise NotImplementedError + + +def _tool(name: str, return_value: str = "ok") -> MelleaTool: + def _fn() -> str: + return return_value + + return MelleaTool.from_callable(_fn, name=name) + + +def _tool_call(tool_name: str, tool: MelleaTool, thought: str) -> _ScriptedTurn: + tc = ModelToolCall(name=tool_name, func=tool, args={}) + return _ScriptedTurn(value=thought, tool_calls={tool_name: tc}) + + +def _final(answer: str) -> _ScriptedTurn: + finalizer = MelleaTool.from_callable(_mellea_finalize_tool, MELLEA_FINALIZER_TOOL) + tc = ModelToolCall( + name=MELLEA_FINALIZER_TOOL, func=finalizer, args={"answer": answer} + ) + return _ScriptedTurn(value="", tool_calls={MELLEA_FINALIZER_TOOL: tc}) + + +# --------------------------------------------------------------------------- # +# Pattern A — per-add compaction wired into the ChatContext # +# --------------------------------------------------------------------------- # + + +async def per_add_compaction(): + """A `WindowCompactor(pin_react_initiator)` on the ChatContext compacts + on every ``add()`` — Messages, ToolMessages, thunks. The ReactInitiator + stays pinned across the whole loop. 
+ """ + search = _tool("search") + backend = ScriptedBackend( + [ + _tool_call("search", search, "step 1"), + _tool_call("search", search, "step 2"), + _tool_call("search", search, "step 3"), + _final("done"), + ] + ) + ctx = ChatContext( + compactor=WindowCompactor(size=3, pin_predicate=pin_react_initiator) + ) + result, ctx = await react( + goal="find info", context=ctx, backend=backend, tools=[search], loop_budget=10 + ) + return ( + result.value, + any(isinstance(c, ReactInitiator) for c in ctx.as_list()), + len(ctx.as_list()), + ) + + +# --------------------------------------------------------------------------- # +# Pattern B — per-turn compaction passed to react() # +# --------------------------------------------------------------------------- # + + +async def per_turn_compaction(): + """Pass ``compactor=`` to ``react`` for once-per-turn invocation. + + Use a permissive ``ChatContext`` (large window) so the per-add path is + effectively disabled — only the per-turn hook drives compaction. + """ + search = _tool("search") + backend = ScriptedBackend( + [ + _tool_call("search", search, "step 1"), + _tool_call("search", search, "step 2"), + _tool_call("search", search, "step 3"), + _final("done"), + ] + ) + result, ctx = await react( + goal="find info", + context=ChatContext(window_size=10_000), + backend=backend, + tools=[search], + loop_budget=10, + compactor=WindowCompactor(size=2, pin_predicate=pin_react_initiator), + ) + return (result.value, any(isinstance(c, ReactInitiator) for c in ctx.as_list())) + + +# --------------------------------------------------------------------------- # +# Pattern C — LLM-driven summarisation # +# --------------------------------------------------------------------------- # + + +async def llm_summarize_compaction(): + """Wire :class:`LLMSummarizeCompactor` into ``react()``. 
+ + ``LLMSummarizeCompactor`` implements the sync :class:`Compactor` + protocol — its ``compact`` method internally orchestrates the async + backend call (running it on a worker thread when invoked from inside + an event loop). From ``react()``'s perspective it's just another + sync compactor. + + To keep the scripted backend simple, this example sets ``keep_n`` + large enough that summarisation never fires (no LLM call is needed). + Real usage would pair it with ``ThresholdCompactor`` so it only + activates once the conversation crosses a token budget. See + ``TestLLMSummarizeCompactor`` in ``test/stdlib/test_compactor.py`` for + unit tests that exercise the actual summary path. + """ + from mellea.stdlib.context import LLMSummarizeCompactor + + search = _tool("search") + backend = ScriptedBackend([_tool_call("search", search, "step 1"), _final("done")]) + result, ctx = await react( + goal="find info", + context=ChatContext(window_size=10_000), + backend=backend, + tools=[search], + loop_budget=10, + # keep_n=1000 → no summarisation triggers in this short script; + # the example just shows the LLM-backed compactor is wired correctly. 
+ compactor=LLMSummarizeCompactor( + default_backend=backend, keep_n=1000, pin_predicate=pin_react_initiator + ), + ) + return (result.value, any(isinstance(c, ReactInitiator) for c in ctx.as_list())) + + +if __name__ == "__main__": + print(f"per_add_compaction: {asyncio.run(per_add_compaction())}") + print(f"per_turn_compaction: {asyncio.run(per_turn_compaction())}") + print(f"llm_summarize_compact: {asyncio.run(llm_summarize_compaction())}") + + +def test_per_add_compaction(): + answer, has_initiator, _length = asyncio.run(per_add_compaction()) + assert answer == "done" + assert has_initiator + + +def test_per_turn_compaction(): + answer, has_initiator = asyncio.run(per_turn_compaction()) + assert answer == "done" + assert has_initiator + + +def test_llm_summarize_compaction(): + answer, has_initiator = asyncio.run(llm_summarize_compaction()) + assert answer == "done" + assert has_initiator diff --git a/docs/examples/context/threshold_compactor.py b/docs/examples/context/threshold_compactor.py new file mode 100644 index 000000000..120eba07c --- /dev/null +++ b/docs/examples/context/threshold_compactor.py @@ -0,0 +1,57 @@ +# pytest: unit +"""ThresholdCompactor — gate an inner Compactor on conversation size. + +Reads ``ModelOutputThunk.generation.usage`` from the most recent thunk +in the context. For a chat backend, ``total_tokens`` on that thunk is +``prompt_tokens`` (full conversation history sent to the model) plus +``completion_tokens`` (the reply), so it tracks *cumulative* context +size — not just one call's isolated tokens. The inner compactor fires +once that running size exceeds the configured threshold. 
+""" + +from mellea.core.base import ModelOutputThunk +from mellea.stdlib.components.chat import Message +from mellea.stdlib.context import ChatContext, ThresholdCompactor, WindowCompactor + + +def _thunk(total_tokens: int) -> ModelOutputThunk: + """Build a ModelOutputThunk with a populated usage dict (test helper).""" + mot = ModelOutputThunk(value="") + mot.generation.usage = { + "prompt_tokens": total_tokens, + "completion_tokens": 0, + "total_tokens": total_tokens, + } + return mot + + +def below_threshold_passthrough(): + """Token usage is below threshold → inner compactor is NOT invoked.""" + gated = ThresholdCompactor(WindowCompactor(size=2), threshold=1000) + ctx = ChatContext(window_size=10_000) + for i in range(5): + ctx = ctx.add(Message("user", f"msg {i}")) + ctx = ctx.add(_thunk(50)) # only 50 tokens — below 1000 + out = gated.compact(ctx) + return len(out.as_list()) # 6 (5 messages + thunk) — unchanged + + +def above_threshold_compacts(): + """Token usage exceeds threshold → inner compactor runs.""" + gated = ThresholdCompactor(WindowCompactor(size=2), threshold=1000) + ctx = ChatContext(window_size=10_000) + for i in range(5): + ctx = ctx.add(Message("user", f"msg {i}")) + ctx = ctx.add(_thunk(2000)) # 2000 tokens — over the gate + out = gated.compact(ctx) + return len(out.as_list()) # 2 — WindowCompactor(size=2) ran + + +if __name__ == "__main__": + print(f"below_threshold_passthrough: {below_threshold_passthrough()}") + print(f"above_threshold_compacts: {above_threshold_compacts()}") + + +def test_threshold_compactor_examples(): + assert below_threshold_passthrough() == 6 + assert above_threshold_compacts() == 2 diff --git a/docs/examples/context/window_compactor.py b/docs/examples/context/window_compactor.py new file mode 100644 index 000000000..320dd0e31 --- /dev/null +++ b/docs/examples/context/window_compactor.py @@ -0,0 +1,101 @@ +# pytest: unit +"""WindowCompactor — keep the last N body components. 
+ +Demonstrates the default behaviour, the ``window_size=`` sugar on +``ChatContext``, and how the auto-pinned system prefix is preserved. +""" + +from mellea.stdlib.components.chat import Message +from mellea.stdlib.context import ( + ChatContext, + WindowCompactor, + pin_nothing, + pin_system_and_initial_user, +) + + +def basic_window(): + """``ChatContext()`` keeps the full history by default; opt in via + ``compactor=`` to start truncating. + """ + ctx = ChatContext(compactor=WindowCompactor(size=5)) + for i in range(8): + ctx = ctx.add(Message("user", f"msg {i}")) + return [m.content for m in ctx.as_list()] + # → ['msg 3', 'msg 4', 'msg 5', 'msg 6', 'msg 7'] + + +def window_size_sugar(): + """``window_size=`` is sugar for ``WindowCompactor(size=...)``.""" + ctx = ChatContext(window_size=3) + for i in range(6): + ctx = ctx.add(Message("user", f"msg {i}")) + return [m.content for m in ctx.as_list()] + # → ['msg 3', 'msg 4', 'msg 5'] + + +def system_prefix_pinned(): + """Default predicate ``pin_system`` keeps a leading system message.""" + ctx = ChatContext(window_size=3) + ctx = ctx.add(Message("system", "You are a helpful assistant.")) + for i in range(6): + ctx = ctx.add(Message("user", f"msg {i}")) + return [(m.role, m.content) for m in ctx.as_list()] + # → [('system', '...'), ('user', 'msg 3'), ('user', 'msg 4'), ('user', 'msg 5')] + + +def pin_initial_user_too(): + """Use ``pin_system_and_initial_user`` to also keep the user's first turn.""" + ctx = ChatContext( + compactor=WindowCompactor(size=3, pin_predicate=pin_system_and_initial_user) + ) + ctx = ctx.add(Message("system", "You are helpful.")) + ctx = ctx.add(Message("user", "What is the capital of France?")) + for i in range(6): + ctx = ctx.add(Message("assistant", f"reply {i}")) + return [(m.role, m.content) for m in ctx.as_list()] + + +def pure_last_n(): + """``pin_nothing`` disables prefix pinning — the system message is dropped.""" + ctx = ChatContext(compactor=WindowCompactor(size=3, 
pin_predicate=pin_nothing)) + ctx = ctx.add(Message("system", "ignored after a few turns")) + for i in range(6): + ctx = ctx.add(Message("user", f"msg {i}")) + return [(m.role, m.content) for m in ctx.as_list()] + + +def clear_body_keep_prefix(): + """``size=0`` drops the body entirely while keeping the pinned prefix.""" + ctx = ChatContext(window_size=10_000) + ctx = ctx.add(Message("system", "You are helpful.")) + for i in range(5): + ctx = ctx.add(Message("user", f"msg {i}")) + cleared = WindowCompactor(size=0).compact(ctx) + return [(m.role, m.content) for m in cleared.as_list()] + # → [('system', 'You are helpful.')] + + +if __name__ == "__main__": + for fn in [ + basic_window, + window_size_sugar, + system_prefix_pinned, + pin_initial_user_too, + pure_last_n, + clear_body_keep_prefix, + ]: + print(f"--- {fn.__name__} ---") + print(fn()) + + +def test_window_compactor_examples(): + """Smoke test all examples — invariants documented in each docstring.""" + assert basic_window() == ["msg 3", "msg 4", "msg 5", "msg 6", "msg 7"] + assert window_size_sugar() == ["msg 3", "msg 4", "msg 5"] + assert system_prefix_pinned()[0] == ("system", "You are a helpful assistant.") + pinned = pin_initial_user_too() + assert pinned[0] == ("system", "You are helpful.") + assert pinned[1] == ("user", "What is the capital of France?") + assert all(role == "user" for role, _ in pure_last_n()) + assert clear_body_keep_prefix() == [("system", "You are helpful.")] diff --git a/mellea/plugins/hooks/session.py b/mellea/plugins/hooks/session.py index 6933461ff..d77138294 100644 --- a/mellea/plugins/hooks/session.py +++ b/mellea/plugins/hooks/session.py @@ -57,7 +57,10 @@ class SessionCleanupPayload(MelleaBasePayload): Attributes: context: The `Context` at the time of cleanup (observe-only). - interaction_count: Number of items in the context at cleanup time. 
+ interaction_count: Number of model-interaction turns committed during + the session (each ``self.ctx = ...`` assignment in ``MelleaSession`` + counts as one). Reset to 0 by ``MelleaSession.reset()``. Stable + under any context-compaction strategy. """ context: Any = None diff --git a/mellea/stdlib/components/react.py b/mellea/stdlib/components/react.py index b94e61ab6..6d7473688 100644 --- a/mellea/stdlib/components/react.py +++ b/mellea/stdlib/components/react.py @@ -32,6 +32,102 @@ def _mellea_finalize_tool(answer: str) -> str: return answer +def pin_react_initiator(components: list[Component | CBlock]) -> int: + """A ``PinPredicate`` that pins everything up to and including the first ``ReactInitiator``. + + Plug it into any compactor in :mod:`mellea.stdlib.context` that takes a + ``pin_predicate`` (e.g. :class:`WindowCompactor`, + :class:`ThresholdCompactor`'s inner compactor) so the react goal and + tool registration survive compaction: + + from mellea.stdlib.context import ChatContext, WindowCompactor + from mellea.stdlib.components.react import pin_react_initiator + + ctx = ChatContext( + compactor=WindowCompactor(size=5, pin_predicate=pin_react_initiator), + ) + result, _ = await react(goal=..., context=ctx, ...) + + Returns ``0`` when no ``ReactInitiator`` is found, so a context that + has not yet been seeded with a react goal compacts as if there were + no prefix. + """ + for i, c in enumerate(components): + if isinstance(c, ReactInitiator): + return i + 1 + return 0 + + +def react_summary_prompt( + goal: str | None = None, max_tokens_hint: int | None = None +) -> str: + """Build a research-flavoured summary prompt for :class:`LLMSummarizeCompactor`. + + Returns a template with a ``{conversation}`` placeholder that + :class:`LLMSummarizeCompactor` fills in at compaction time. Pass the + react goal via ``goal=`` to anchor the summarisation around the + objective; with ``goal=None`` the ``GOAL:`` line is omitted. 
+ + Pass ``max_tokens_hint=N`` to inject a soft length-cap bullet + ("Be at most ~N tokens") into the summarizer's instructions. The hint + is a plan-time anchor for the model — combine it with a hard + ``max_tokens`` API arg on the summarizer's LLM call to enforce. + ``max_tokens_hint=None`` (default) or non-positive values omit the + bullet, so the prompt is byte-identical to the un-hinted form. + + Curly braces in ``goal`` are escaped so :meth:`str.format` (used by the + compactor) preserves them as literal characters. + + Example:: + + from mellea.stdlib.components.react import ( + pin_react_initiator, + react_summary_prompt, + ) + from mellea.stdlib.context import LLMSummarizeCompactor + + compactor = LLMSummarizeCompactor( + default_backend=my_backend, + keep_n=5, + pin_predicate=pin_react_initiator, + prompt_template=react_summary_prompt( + goal="find papers on X", + max_tokens_hint=2000, + ), + ) + """ + if goal is not None: + # Escape braces so .format() in the compactor keeps them literal. + safe_goal = goal.replace("{", "{{").replace("}", "}}") + goal_block = f"GOAL: {safe_goal}\n\n" + else: + goal_block = "" + if max_tokens_hint is not None and max_tokens_hint > 0: + # Rough heuristic: ~0.75 words per token for English research text. + words_estimate = int(max_tokens_hint * 0.75) + length_bullet = ( + f"- Be at most ~{max_tokens_hint} tokens (roughly " + f"{words_estimate} words). Prioritize density: drop redundant " + "or ancillary detail.\n" + ) + else: + length_bullet = "" + return ( + "You are summarizing research progress to maintain context " + "within token limits.\n\n" + f"{goal_block}" + "Provide a comprehensive summary of the research context below. 
" + "Your summary should:\n" + "- Preserve ALL specific facts, numbers, names, URLs, and search " + "queries found\n" + "- Note which tools were called and what results were obtained\n" + "- Highlight key findings and any dead ends encountered\n" + "- Be structured clearly so the research can continue seamlessly\n" + f"{length_bullet}" + "\nContext to summarize:\n{conversation}" + ) + + class ReactInitiator(Component[str]): """`ReactInitiator` is used at the start of the ReACT loop to prime the model. diff --git a/mellea/stdlib/context.py b/mellea/stdlib/context.py deleted file mode 100644 index b8a748fab..000000000 --- a/mellea/stdlib/context.py +++ /dev/null @@ -1,82 +0,0 @@ -"""Concrete `Context` implementations for common conversation patterns. - -Provides `ChatContext`, which accumulates all turns in a sliding-window chat history -(configurable via `window_size`), and `SimpleContext`, in which each interaction -is treated as a stateless single-turn exchange (no prior history is passed to the -model). Import `ChatContext` for multi-turn conversations and `SimpleContext` when -you want each call to the model to be independent. -""" - -from __future__ import annotations - -# Leave unused `ContextTurn` import for import ergonomics. -from ..core import CBlock, Component, Context, ContextTurn - - -class ChatContext(Context): - """Initializes a chat context with unbounded window_size and is_chat=True by default. - - Args: - window_size (int | None): Maximum number of context turns to include when - calling `view_for_generation`. `None` (the default) means the full - history is always returned. - """ - - def __init__(self, *, window_size: int | None = None): - """Initialize ChatContext with an optional sliding-window size.""" - super().__init__() - self._window_size = window_size - - def add(self, c: Component | CBlock) -> ChatContext: - """Add a new component or CBlock to the context and return the updated context. 
- - Args: - c (Component | CBlock): The component or content block to append. - - Returns: - ChatContext: A new `ChatContext` with the added entry, preserving the - current `window_size` setting. - """ - new = ChatContext.from_previous(self, c) - new._window_size = self._window_size - return new - - def view_for_generation(self) -> list[Component | CBlock] | None: - """Return the context entries to pass to the model, respecting the configured window. - - Uses the `window_size` set during initialisation to limit how many past - turns are included. `None` is returned when the underlying history is - non-linear. - - Returns: - list[Component | CBlock] | None: Ordered list of context entries up to - `window_size` turns, or `None` if the history is non-linear. - """ - return self.as_list(self._window_size) - - -class SimpleContext(Context): - """A `SimpleContext` is a context in which each interaction is a separate and independent turn. The history of all previous turns is NOT saved..""" - - def add(self, c: Component | CBlock) -> SimpleContext: - """Add a new component or CBlock to the context and return the updated context. - - Args: - c (Component | CBlock): The component or content block to record. - - Returns: - SimpleContext: A new `SimpleContext` containing only the added entry; - prior history is not retained. - """ - return SimpleContext.from_previous(self, c) - - def view_for_generation(self) -> list[Component | CBlock] | None: - """Return an empty list, since `SimpleContext` does not pass history to the model. - - Each call to the model is treated as a stateless, independent exchange. - No prior turns are forwarded. - - Returns: - list[Component | CBlock] | None: Always an empty list. 
- """ - return [] diff --git a/mellea/stdlib/context/__init__.py b/mellea/stdlib/context/__init__.py new file mode 100644 index 000000000..03c4c7aaa --- /dev/null +++ b/mellea/stdlib/context/__init__.py @@ -0,0 +1,47 @@ +"""Concrete `Context` implementations and the `Compactor` protocol. + +Provides: + +- :class:`ChatContext` — accumulates all turns in a chat history (with an + optional sliding window). +- :class:`SimpleContext` — stateless, single-turn exchange (no prior history is + passed to the model). +- :class:`Compactor` — generic protocol for shrinking any `Context` subtype. + +The names :class:`Context`, :class:`ContextTurn`, :class:`CBlock`, and +:class:`Component` are re-exported from :mod:`mellea.core` for the convenience +of callers that import them via `mellea.stdlib.context`. +""" + +from mellea.core import CBlock, Component, Context, ContextTurn +from mellea.stdlib.context.chat import ChatContext +from mellea.stdlib.context.compactor import ( + Compactor, + InlineCompactor, + LLMSummarizeCompactor, + PinPredicate, + ThresholdCompactor, + WindowCompactor, + pin_nothing, + pin_system, + pin_system_and_initial_user, +) +from mellea.stdlib.context.simple import SimpleContext + +__all__ = [ + "CBlock", + "ChatContext", + "Compactor", + "Component", + "Context", + "ContextTurn", + "InlineCompactor", + "LLMSummarizeCompactor", + "PinPredicate", + "SimpleContext", + "ThresholdCompactor", + "WindowCompactor", + "pin_nothing", + "pin_system", + "pin_system_and_initial_user", +] diff --git a/mellea/stdlib/context/chat.py b/mellea/stdlib/context/chat.py new file mode 100644 index 000000000..5549ff070 --- /dev/null +++ b/mellea/stdlib/context/chat.py @@ -0,0 +1,122 @@ +"""Chat-style context with pluggable compaction.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from mellea.core import CBlock, Component, Context + +if TYPE_CHECKING: + from mellea.stdlib.context.compactor import InlineCompactor + + +class ChatContext(Context): + 
"""Chat context that accumulates turns and optionally compacts on each `add`. + + By default the context performs **no compaction** — the full history is + retained. Compaction is opt-in: pass `compactor=` for a custom + strategy, or `window_size=` as sugar for `WindowCompactor(size=...)`. + + Note: + Compaction is now applied at `add()` time and persists in the linked + list, so `as_list()` and `view_for_generation()` both reflect the + post-compaction history. Earlier versions kept the full history in + `as_list()` and only windowed the model-facing view, so any caller + that used `len(ctx.as_list())` as a session-wide interaction count + will now silently undercount once the compactor fires. Track turn + counts out-of-band (e.g. on the session) if you need them. + + Args: + compactor (InlineCompactor | None): The compactor invoked on every + `add`. `None` (the default) means no compaction; full history + is kept. + window_size (int | None): Sugar that constructs a + :class:`WindowCompactor`. Mutually exclusive with `compactor`. + `None` (the default) means no windowing. + """ + + def __init__( + self, + *, + compactor: InlineCompactor | None = None, + window_size: int | None = None, + ) -> None: + """Initialize a ChatContext with an optional compactor.""" + if compactor is not None and window_size is not None: + raise ValueError( + "ChatContext: pass either `compactor` or `window_size`, not both." + ) + if compactor is not None: + from mellea.stdlib.context.compactor import InlineCompactor + + if not isinstance(compactor, InlineCompactor): + raise TypeError( + f"ChatContext requires an InlineCompactor; got " + f"{type(compactor).__name__}. Wrap it in ThresholdCompactor, " + "use via react(compactor=...), or call compact(ctx, ...) " + "manually instead." 
+ ) + super().__init__() + if compactor is None and window_size is not None: + from mellea.stdlib.context.compactor import WindowCompactor + + self._compactor: InlineCompactor | None = WindowCompactor(size=window_size) + else: + self._compactor = compactor + + def add(self, c: Component | CBlock) -> ChatContext: + """Append `c` and run the compactor; return the resulting context. + + Args: + c (Component | CBlock): The component or content block to append. + + Returns: + ChatContext: A new `ChatContext` carrying the same compactor. + """ + new = ChatContext.from_previous(self, c) + new._compactor = self._compactor + if self._compactor is not None: + new = self._compactor.compact(new) + return new + + def view_for_generation(self) -> list[Component | CBlock] | None: + """Return the components to forward to the model. + + Compaction is now applied at `add` time (Pattern 1), so this just + returns the linear history. `None` is returned when the underlying + history is non-linear. + + Returns: + list[Component | CBlock] | None: Ordered list of context entries. + """ + return self.as_list() + + +def _rebuild_chat_context( + components: list[Component | CBlock], *, compactor: InlineCompactor | None = None +) -> ChatContext: + """Build a fresh `ChatContext` linked-list without triggering compaction. + + Used by `WindowCompactor` (and any future compactors that need to rebuild + a chat history). Manual node construction sidesteps `ChatContext.add` so + compactors don't recurse during their own work. + + Args: + components: Components to materialise as the new context, in order. + compactor: Compactor to attach to every node of the rebuilt context. + + Returns: + A new `ChatContext` whose linear history is exactly `components`. 
+ """ + ctx: ChatContext = ChatContext.__new__(ChatContext) + Context.__init__(ctx) + ctx._compactor = compactor + for c in components: + new: ChatContext = ChatContext.__new__(ChatContext) + new._previous = ctx + new._data = c + new._is_root = False + new._is_chat_context = ctx._is_chat_context + new._compactor = compactor + ctx = new + return ctx diff --git a/mellea/stdlib/context/compactor.py b/mellea/stdlib/context/compactor.py new file mode 100644 index 000000000..cfc5ca590 --- /dev/null +++ b/mellea/stdlib/context/compactor.py @@ -0,0 +1,569 @@ +"""Generic `Compactor` protocol for shrinking a `Context`. + +A `Compactor` returns a fresh, compacted copy of a context. Implementations +must never mutate the input — by convention, every alteration must produce a +new `Context` instance (the base class enforces this via `from_previous`). + +Two usage patterns are supported: + +- **Pattern 1 (in `Context.add`):** A subclass of `Context` holds a + `Compactor` and applies it whenever a new component is appended. +- **Pattern 2 (manual):** The caller invokes `compactor.compact(ctx)` + directly between turns, e.g. when compaction is exposed to the model as a + tool. + +See `docs/examples/context/` for full usage examples. +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import TYPE_CHECKING, Protocol, TypeAlias, TypeVar + +from mellea.core import CBlock, Component, Context, ModelOutputThunk +from mellea.core.backend import Backend + +if TYPE_CHECKING: + from mellea.stdlib.context.chat import ChatContext + +T = TypeVar("T", bound=Context) + + +# --------------------------------------------------------------------------- # +# Pin predicates # +# --------------------------------------------------------------------------- # + +PinPredicate: TypeAlias = Callable[[list[Component | CBlock]], int] +"""A function that returns the index after the pinned prefix. 
+ +Given the full ordered list of context components, a `PinPredicate` +returns the integer index `idx` such that `components[:idx]` is the +pinned prefix that the compactor must preserve, and `components[idx:]` +is the body that compaction acts on. + +The shape subsumes both "contiguous role-based prefix" (e.g. +:func:`pin_system`) and "find the first marker component" styles. +""" + + +def pin_nothing(components: list[Component | CBlock]) -> int: + """A :class:`PinPredicate` that pins nothing — pure body, no protected prefix.""" + return 0 + + +def pin_system(components: list[Component | CBlock]) -> int: + """Pin contiguous leading `Message(role="system")` components. + + Stops at the first non-system component. A system message that appears + later in the conversation is *not* pinned. + """ + from mellea.stdlib.components.chat import Message + + i = 0 + while i < len(components): + c = components[i] + if isinstance(c, Message) and c.role == "system": + i += 1 + else: + break + return i + + +def pin_system_and_initial_user(components: list[Component | CBlock]) -> int: + """Pin leading system messages PLUS the first user message that follows. + + Useful when the initial user prompt encodes the goal of the conversation + and should survive compaction along with any system instructions. + """ + from mellea.stdlib.components.chat import Message + + i = pin_system(components) + if i < len(components): + c = components[i] + if isinstance(c, Message) and c.role == "user": + i += 1 + return i + + +def _last_usage_tokens(ctx: Context) -> int | None: + """Return cumulative token count of the conversation as of the most recent turn. + + Walks `ctx` back-to-front looking for a `ModelOutputThunk` whose + `generation.usage` dict has been populated by a backend's + `post_processing`. Returns `total_tokens` from that thunk — which, + for a chat backend, is `prompt_tokens` (size of the full conversation + sent to the model) plus `completion_tokens` (the model's reply). 
It + is therefore an estimate of the *current* conversation size, not just + one call's tokens in isolation. + + Falls back to `prompt_tokens + completion_tokens` when `total_tokens` + is missing. Returns `None` if no usable token count can be recovered + (typical before the first model call completes). + """ + for c in reversed(ctx.as_list()): + if isinstance(c, ModelOutputThunk) and c.generation.usage is not None: + usage = c.generation.usage + total = usage.get("total_tokens") + if total is None: + pt = usage.get("prompt_tokens") or 0 + ct = usage.get("completion_tokens") or 0 + total = pt + ct + return total if total and total > 0 else None + return None + + +class Compactor(Protocol): + """Protocol for objects that compact a `Context` into a smaller copy. + + A compactor receives a context and returns a new context that retains only + the data the strategy considers worth keeping. Implementations MUST NOT + mutate the input context; they must return a fresh instance and copy over + any data that should be preserved. + + The protocol is generic in `T` (a `Context` subtype) so concrete + compactors can narrow their input/output type — for example a chat-only + compactor declares `T = ChatContext`. + + The protocol is sync. Compactors that need to perform a backend call + (e.g. :class:`LLMSummarizeCompactor`) hide the async work behind the sync + method internally — see that class for the strategy used. + """ + + def compact(self, ctx: T, *, backend: Backend | None = None) -> T: + """Return a compacted copy of `ctx`. + + Args: + ctx: The context to compact. Must be left unchanged. + backend: Optional backend. Generic compactors that only filter + components can ignore it. + + Returns: + A new context of the same type as `ctx` containing only the + retained data. + """ + ... + + +class InlineCompactor: + """Marker base for compactors safe to attach directly to `ChatContext`. 
+ + A compactor is "inline-safe" when its `compact()` does not call a backend + on every `add()`. `ChatContext.add()` invokes `compact()` without a + backend argument, so any compactor wired into `ChatContext(compactor=...)` + must either avoid backend calls (e.g. :class:`WindowCompactor`) or gate + them sparsely (e.g. :class:`ThresholdCompactor`). Compactors that would + invoke the backend on every `add()` (e.g. :class:`LLMSummarizeCompactor`) + must NOT inherit this marker — use them via `react(compactor=...)` or + by calling `compact(ctx, backend=...)` manually instead. + + The marker is purely nominal: opt in by inheriting, opt out by not. Pure + structural :class:`Compactor` Protocol satisfaction is not enough. + + Subclasses must override :meth:`compact`; the base implementation raises + :class:`NotImplementedError`. Carrying the method signature here lets + `InlineCompactor` be used as a static type (`ChatContext` parameters, + `_compactor` attribute) without losing the `Compactor` contract. + """ + + def compact( + self, ctx: ChatContext, *, backend: Backend | None = None + ) -> ChatContext: + """Subclasses must override this with their concrete strategy.""" + raise NotImplementedError("InlineCompactor subclasses must implement compact()") + + +class WindowCompactor(InlineCompactor): + """Retains the last `size` body components of a `ChatContext`. + + Uses `pin_predicate` to decide which leading components to preserve as + a protected prefix; the size limit is then applied to the body that + remains. The total context length after compaction is + `len(prefix) + min(size, body_len)`. `size` counts only body + components. + + When the body is already at or below `size`, `ctx` is returned + unchanged so the original linked-list and `previous_node` chain are + preserved. The result carries the same `Compactor` as the input so + subsequent `add()` calls keep compacting. + + Args: + size (int): Maximum number of most-recent body components to retain. 
+ Pinned prefix components do NOT count against this budget. + `size=0` is a special case that drops the body entirely, + keeping only the pinned prefix. Negative values raise + :class:`ValueError`. + pin_predicate (PinPredicate): Function that decides the prefix + boundary. Defaults to :func:`pin_system`, which pins contiguous + leading `Message(role="system")` components. Pass + :func:`pin_nothing` for pure last-N behaviour or any other + `PinPredicate` (e.g. :func:`pin_system_and_initial_user`). + """ + + def __init__(self, *, size: int, pin_predicate: PinPredicate = pin_system) -> None: + """Initialize with the desired body window size and a pin predicate.""" + if size < 0: + raise ValueError("WindowCompactor size must be non-negative") + self.size = size + self.pin_predicate = pin_predicate + + def compact( + self, ctx: ChatContext, *, backend: Backend | None = None + ) -> ChatContext: + """Return a copy of `ctx` truncated to the last `size` body components. + + Args: + ctx: The chat context to compact. + backend: Unused by this strategy; accepted for protocol compatibility. + + Returns: + A new `ChatContext` whose history is the pinned prefix plus the + last `size` body components, carrying `ctx`'s compactor. + Returns `ctx` itself if no truncation is required. + """ + full = ctx.as_list() + pin_end = self.pin_predicate(full) + body_len = len(full) - pin_end + + if body_len <= self.size: + return ctx + + from mellea.stdlib.context.chat import _rebuild_chat_context + + keep_body = full[pin_end:][-self.size :] if self.size > 0 else [] + compacted = full[:pin_end] + keep_body + return _rebuild_chat_context(compacted, compactor=ctx._compactor) + + +class ThresholdCompactor(InlineCompactor): + """Wraps an inner `Compactor`, gating it on the conversation's token size. 
+ + Despite the suffix, this class does not compact directly — it forwards + to `inner.compact` only when the conversation has grown larger than + `threshold` tokens; otherwise the input is returned unchanged. + + The token measurement is read off the most recent `ModelOutputThunk`'s + `generation.usage` (via :func:`_last_usage_tokens`). Because chat + backends report `prompt_tokens` as the size of the full history they + were given as input, `total_tokens = prompt_tokens + completion_tokens` + on the latest thunk effectively measures *the size of the conversation + after that turn*, not just one isolated call. So the gate fires once + cumulative context size crosses `threshold`. + + Caveats: + + - Components appended *after* the last thunk (e.g. a tool response in + the same turn) are not yet reflected in the reading — there is a + one-turn lag, negligible unless a single tool call adds a very large + payload. + - When the inner compactor shrinks the context, the *next* model call + will produce a smaller `prompt_tokens`, so the gate will close + again. The threshold is not a high-water mark. + - Returns the input unchanged if no thunk with usage is found yet + (typical before the first model call completes). + + Args: + inner (Compactor): The compactor to invoke once the threshold is + exceeded. + threshold (int): Trigger the inner compactor when the conversation's + measured token size (most recent thunk's `total_tokens`) + exceeds this value. `0` or negative disables the gate (the + inner is never invoked). + """ + + def __init__(self, inner: Compactor, *, threshold: int) -> None: + """Initialize with the inner compactor and token threshold.""" + self.inner = inner + self.threshold = threshold + + def compact( + self, ctx: ChatContext, *, backend: Backend | None = None + ) -> ChatContext: + """Forward to `inner.compact` only when `ctx` exceeds the threshold. + + Args: + ctx: The context to potentially compact. + backend: Forwarded to the inner compactor. 
+ + Returns: + `inner.compact(ctx, backend=backend)` when the recovered token + count exceeds `self.threshold`, otherwise `ctx` unchanged. + """ + if self.threshold <= 0: + return ctx + tokens = _last_usage_tokens(ctx) + if tokens is None or tokens <= self.threshold: + return ctx + return self.inner.compact(ctx, backend=backend) + + +_DEFAULT_SUMMARY_PROMPT = ( + "You are summarizing a conversation to maintain context within token " + "limits.\n\n" + "Provide a concise summary that:\n" + "- Preserves specific facts, numbers, names, URLs, and key data\n" + "- Notes which tools were called and what results were obtained\n" + "- Highlights key decisions, findings, and unresolved issues\n" + "- Is structured clearly so the conversation can continue seamlessly\n\n" + "Conversation to summarize:\n{conversation}" +) + + +def _run_coro_blocking(coro): # type: ignore[no-untyped-def] + """Run an awaitable to completion regardless of the calling context. + + - Outside any event loop: `asyncio.run(coro)`. + - Inside a running event loop: spawn a worker thread that runs a fresh + event loop with `asyncio.run` and block until it returns. + + Used by sync compactors that need to call async backend code (e.g. + :class:`LLMSummarizeCompactor`). + + Warning: + When called from inside a running event loop (e.g. `react()`), the + second branch above blocks the calling thread — and therefore the + loop — for the full duration of the coroutine. **Nothing else on the + loop can make progress** while the worker runs: scheduled callbacks, + telemetry flushers, cancellation signals, other sessions sharing the + loop, periodic keepalives — all are stalled. Acceptable for a + strictly serial flow like ReACT (the next iteration cannot start + until compaction finishes anyway), but unsafe if the loop has + concurrent tasks that need to keep running. + + Backends that hold *per-loop* resources may behave unexpectedly. 
+ :class:`httpx.AsyncClient`, for instance, is bound to the event + loop on which it was created; the coroutine here runs on a fresh + loop inside a worker thread, so any async resource captured in a + closure or stored on a backend instance from the outer loop cannot + be used directly. Typical symptoms are `RuntimeError: Event loop is + closed`, an `attached to a different loop` error, or a hung request. + + The long-term fix is an async variant on the :class:`Compactor` + protocol so callers can `await` natively instead of bridging + through a worker thread. Until then, only invoke compactors that + need a backend from contexts where this trade-off is acceptable + (typically: inside `react`, in a manual `compact()` call between + turns, or from a synchronous script). + """ + import asyncio + import concurrent.futures + + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(coro) + + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + return pool.submit(asyncio.run, coro).result() + + +class LLMSummarizeCompactor: + """Replace old body components with an LLM-generated summary, keep last `keep_n` verbatim. + + Implements the sync :class:`Compactor` protocol. The compactor's body + needs to call the (async) backend; that async work is hidden inside the + sync `compact` method via :func:`_run_coro_blocking`. The pinned + prefix (chosen by `pin_predicate`) is preserved unchanged; body + components older than the last `keep_n` are flattened into a single + `Message(role="user")` whose content is a structured summary; the + last `keep_n` body components are kept verbatim. + + Default `pin_predicate` is :func:`pin_nothing`, which means the entire + conversation participates in summarisation. For react workflows pass + :func:`mellea.stdlib.components.react.pin_react_initiator` so the goal + and tool registration survive untouched. 
+ + Note: + This class does NOT inherit :class:`InlineCompactor`, so it cannot be + passed to `ChatContext(compactor=...)` directly — that would invoke + the backend on every `add()`. Use via `react(compactor=...)`, + wrap in :class:`ThresholdCompactor` (which gates by token usage), or + call `compact(ctx, backend=...)` manually. + + Note: + Summarisation is text-only and lossy for multimodal or heavy-tool + sessions. Image and document attachments on `Message` components + are noted by count only ("[N image(s) attached]") rather than + reproduced; `ModelOutputThunk` entries that carry only tool calls + (`value is None`) render the call name and arguments. If your + application depends on faithful preservation of attachments or + full tool-call payloads across compaction, prefer + :class:`WindowCompactor` (which keeps recent components verbatim) + or implement a domain-specific :class:`Compactor`. + + Args: + default_backend (Backend): Backend used by `compact()` when the + caller does not supply one. Required: `LLMSummarizeCompactor` + cannot do its job without a backend at compaction time. A + `backend=` kwarg passed to `compact()` overrides this default + for that call only. + keep_n (int): Number of recent body components to keep verbatim. + `0` summarises everything below the prefix. + pin_predicate (PinPredicate): Function that decides the prefix + boundary. Defaults to :func:`pin_nothing`. + prompt_template (str | None): Custom summary prompt. Must contain + the literal `{conversation}` placeholder, which is filled in + with a textual rendering of the body to summarise. Defaults to + a generic conversation-summary template. + model_options (dict | None): Forwarded to `mfuncs.aact` for the + summarisation call. Use this to set a real `max_tokens` budget + (most local backends default to 256-512, which silently truncates + long summaries) or any other backend-specific knob. 
Note: + :func:`react_summary_prompt`'s `max_tokens_hint` adds only a + soft prompt-side nudge; pair it with `model_options={"max_tokens": N}` + for hard enforcement. + """ + + def __init__( + self, + *, + default_backend: Backend, + keep_n: int = 5, + pin_predicate: PinPredicate = pin_nothing, + prompt_template: str | None = None, + model_options: dict | None = None, + ) -> None: + """Initialize with a default backend, recent-body window, pin predicate, prompt, and model options.""" + if keep_n < 0: + raise ValueError("LLMSummarizeCompactor keep_n must be non-negative") + template = ( + prompt_template if prompt_template is not None else _DEFAULT_SUMMARY_PROMPT + ) + if "{conversation}" not in template: + raise ValueError( + "LLMSummarizeCompactor prompt_template must contain '{conversation}'" + ) + self.default_backend = default_backend + self.keep_n = keep_n + self.pin_predicate = pin_predicate + self.prompt_template = template + self.model_options = model_options + + def compact( + self, ctx: ChatContext, *, backend: Backend | None = None + ) -> ChatContext: + """Return a context with the prefix, an LLM summary, and recent body components. + + Args: + ctx: The chat context to compact. + backend: Backend used to generate the summary. When `None` the + `default_backend` set at construction is used instead. + + Returns: + A new `ChatContext` containing the prefix, a single summary + `Message` produced by the backend, and the most-recent + `keep_n` body components verbatim. Returns `ctx` unchanged + when the body is already at or below `keep_n` in length, or + when the backend call fails (see Note). + + Note: + Compaction is best-effort: if the backend call raises (rate + limit, network error, timeout, etc.) the exception is caught, a + warning is logged, and `ctx` is returned unchanged. The next + `compact()` invocation will retry. `KeyboardInterrupt` and + other `BaseException`s propagate so users can still interrupt + a stuck loop. 
+ """ + backend = backend or self.default_backend + + full = ctx.as_list() + pin_end = self.pin_predicate(full) + body = full[pin_end:] + if len(body) <= self.keep_n: + return ctx + + try: + return _run_coro_blocking(self._async_compact(ctx, backend)) + except Exception as exc: + from mellea.core.utils import MelleaLogger + + MelleaLogger.get_logger().warning( + "LLMSummarizeCompactor: summarisation backend call failed " + "(%s: %s); returning context unchanged. The conversation will " + "keep growing until the next successful compaction.", + type(exc).__name__, + exc, + ) + return ctx + + async def _async_compact(self, ctx: ChatContext, backend: Backend) -> ChatContext: + """Async core — renders the body, calls the backend, rebuilds the context.""" + # Lazy imports to keep this module free of mellea.stdlib.components dependencies. + from mellea.stdlib import functional as mfuncs + from mellea.stdlib.components.chat import Message, ToolMessage + from mellea.stdlib.context.chat import _rebuild_chat_context + from mellea.stdlib.context.simple import SimpleContext + + full = ctx.as_list() + pin_end = self.pin_predicate(full) + prefix = full[:pin_end] + body = full[pin_end:] + + old = body[: -self.keep_n] if self.keep_n > 0 else body + recent = body[-self.keep_n :] if self.keep_n > 0 else [] + + # Render `old` to text the LLM can consume. This is intentionally a + # text-only rendering: image and document attachments on Messages are + # noted as markers (count only) rather than reproduced, and tool-call + # arguments are stringified. The summary is lossy for multimodal and + # heavy-tool sessions by design — see class docstring. 
+ lines: list[str] = [] + for c in old: + if isinstance(c, ToolMessage): + lines.append(f"tool ({c.name}): {c.content}") + elif isinstance(c, Message): + attachments: list[str] = [] + imgs = getattr(c, "_images", None) + if imgs: + attachments.append(f"[{len(imgs)} image(s) attached]") + docs = getattr(c, "_docs", None) + if docs: + attachments.append(f"[{len(docs)} document(s) attached]") + attached = (" " + " ".join(attachments)) if attachments else "" + lines.append(f"{c.role}: {c.content}{attached}") + elif isinstance(c, ModelOutputThunk): + if c.value: + lines.append(f"assistant: {c.value}") + elif c.tool_calls: + rendered = ", ".join( + f"{name}({dict(tc.args)})" for name, tc in c.tool_calls.items() + ) + lines.append(f"assistant called tools: {rendered}") + # else: thunk with neither value nor tool_calls is skipped — + # nothing useful to summarise and a literal "" marker + # tends to show up verbatim in the resulting summary. + elif isinstance(c, CBlock): + lines.append(str(c)) + else: + # Catch-all for `Component` subclasses that aren't `Message`/ + # `ToolMessage`/`ModelOutputThunk` (e.g. `ReactInitiator`). + # Without special handling these would render as the default + # `<… object at 0x…>` repr and the summary would lose all + # information that the entry existed at all. Emit at minimum + # the type name plus a `content` attribute if present, so + # the summariser sees a marker. 
+ content = getattr(c, "content", None) + if content is not None: + lines.append(f"<{type(c).__name__}: {content}>") + else: + lines.append(f"<{type(c).__name__}>") + + prompt = self.prompt_template.format(conversation="\n".join(lines)) + result, _ = await mfuncs.aact( + action=Message(role="user", content=prompt), + context=SimpleContext(), + backend=backend, + requirements=[], + strategy=None, + model_options=self.model_options, + await_result=True, + # Internal framework call: silence aact's context-type warning so + # it stays quiet if the context argument is later changed to a + # non-SimpleContext. Matches react.py's pattern. + silence_context_type_warning=True, + ) + + summary_message = Message( + role="user", content=f"[CONTEXT SUMMARY]\n{result.value or ''}" + ) + compacted = [*prefix, summary_message, *recent] + return _rebuild_chat_context(compacted, compactor=ctx._compactor) diff --git a/mellea/stdlib/context/simple.py b/mellea/stdlib/context/simple.py new file mode 100644 index 000000000..6726c5d28 --- /dev/null +++ b/mellea/stdlib/context/simple.py @@ -0,0 +1,32 @@ +"""Stateless single-turn context (no history is forwarded to the model).""" + +from __future__ import annotations + +from mellea.core import CBlock, Component, Context + + +class SimpleContext(Context): + """A `SimpleContext` is a context in which each interaction is a separate and independent turn. The history of all previous turns is NOT saved.""" + + def add(self, c: Component | CBlock) -> SimpleContext: + """Add a new component or CBlock to the context and return the updated context. + + Args: + c (Component | CBlock): The component or content block to record. + + Returns: + SimpleContext: A new `SimpleContext` containing only the added entry; + prior history is not retained. + """ + return SimpleContext.from_previous(self, c) + + def view_for_generation(self) -> list[Component | CBlock] | None: + """Return an empty list, since `SimpleContext` does not pass history to the model. 
+ + Each call to the model is treated as a stateless, independent exchange. + No prior turns are forwarded. + + Returns: + list[Component | CBlock] | None: Always an empty list. + """ + return [] diff --git a/mellea/stdlib/frameworks/react.py b/mellea/stdlib/frameworks/react.py index 9b523be58..1ff66b354 100644 --- a/mellea/stdlib/frameworks/react.py +++ b/mellea/stdlib/frameworks/react.py @@ -21,7 +21,7 @@ ReactInitiator, ReactThought, ) -from mellea.stdlib.context import ChatContext +from mellea.stdlib.context import ChatContext, Compactor async def react( @@ -36,6 +36,7 @@ async def react( model_options: dict | None = None, tools: list[AbstractMelleaTool] | None, loop_budget: int = 10, + compactor: Compactor | None = None, ) -> tuple[ComputedModelOutputThunk[str], ChatContext]: """Asynchronous ReACT pattern (Think -> Act -> Observe -> Repeat Until Done); attempts to accomplish the provided goal given the provided tools. @@ -47,6 +48,14 @@ async def react( model_options: additional model options, which will upsert into the model/backend's defaults. tools: the list of tools to use loop_budget: the number of steps allowed; use -1 for unlimited + compactor: optional sync ``Compactor`` invoked once per turn after the + tool observation. Use this for strategies that should fire at turn + boundaries rather than on every component append (per-add + compaction is configured on ``context`` itself). Compose with + :func:`mellea.stdlib.components.react.pin_react_initiator` to + preserve the goal across compactions. Compactors that need to + call the backend (e.g. ``LLMSummarizeCompactor``) hide the async + work behind their sync ``compact`` method internally. Returns: A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`. @@ -129,4 +138,8 @@ async def react( step._underlying_value = str(tool_responses[0].content) return step, context + # Per-turn compaction hook (terminal turns skip this since `is_final` returned). 
+ if compactor is not None: + context = compactor.compact(context, backend=backend) + raise RuntimeError(f"could not complete react loop in {loop_budget} iterations") diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index 4563a7549..f6d23ac02 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -258,7 +258,7 @@ class MelleaSession: id (str): Unique session UUID assigned at construction. """ - ctx: Context + # ``ctx`` is exposed as a property below; backing field is ``_ctx``. def __init__(self, backend: Backend, ctx: Context | None = None): """Initialize MelleaSession with a backend and optional conversation context.""" @@ -266,13 +266,33 @@ def __init__(self, backend: Backend, ctx: Context | None = None): self.id = str(uuid.uuid4()) self.backend = backend - self.ctx: Context = ctx if ctx is not None else SimpleContext() + # Bypass the ctx setter so the initial assignment doesn't count as an + # interaction. + self._ctx: Context = ctx if ctx is not None else SimpleContext() + self._interaction_count: int = 0 self._session_logger = MelleaLogger.get_logger() self._context_token = None self._log_context_token = None self._session_span = None self._exit_stack: contextlib.ExitStack | None = None + @property + def ctx(self) -> Context: + """The session's current conversation context.""" + return self._ctx + + @ctx.setter + def ctx(self, value: Context) -> None: + """Replace the context and count this as one interaction. + + Every model-interaction code path in this class assigns to ``self.ctx`` + with the post-interaction context, so each setter call is exactly one + interaction. Lifecycle paths that swap the context wholesale (``reset``) + write to ``self._ctx`` directly to bypass this counter. 
+ """ + self._ctx = value + self._interaction_count += 1 + def __enter__(self): """Enter context manager and set this session as the current global session.""" # Start a session span that will last for the entire context manager lifetime @@ -365,7 +385,9 @@ def reset(self): _run_async_in_thread( invoke_hook(HookType.SESSION_RESET, payload, backend=self.backend) ) - self.ctx = self.ctx.reset_to_new() + # Bypass the setter — a reset is a lifecycle event, not an interaction. + self._ctx = self._ctx.reset_to_new() + self._interaction_count = 0 def cleanup(self) -> None: """Clean up session resources and deregister session-scoped plugins.""" @@ -373,7 +395,7 @@ def cleanup(self) -> None: from ..plugins.hooks.session import SessionCleanupPayload payload = SessionCleanupPayload( - context=self.ctx, interaction_count=len(self.ctx.as_list()) + context=self.ctx, interaction_count=self._interaction_count ) _run_async_in_thread( invoke_hook(HookType.SESSION_CLEANUP, payload, backend=self.backend) diff --git a/test/stdlib/frameworks/test_react_framework.py b/test/stdlib/frameworks/test_react_framework.py index e121a91f5..449160ce2 100644 --- a/test/stdlib/frameworks/test_react_framework.py +++ b/test/stdlib/frameworks/test_react_framework.py @@ -231,5 +231,229 @@ async def test_react_rejects_non_chat_context(): await react(goal="g", context=Mock(), backend=Mock(), tools=None) +# --- compaction integration --- + + +def test_pin_react_initiator_finds_initiator(): + from mellea.stdlib.components.chat import Message + from mellea.stdlib.components.react import pin_react_initiator + + components = [ + Message("system", "sys"), + ReactInitiator("solve x", []), + Message("user", "step 1"), + ] + # Pinned prefix = system + initiator = first two indices. 
+ assert pin_react_initiator(components) == 2 + + +def test_pin_react_initiator_returns_zero_when_absent(): + from mellea.stdlib.components.chat import Message + from mellea.stdlib.components.react import pin_react_initiator + + components = [Message("user", "a"), Message("assistant", "b")] + assert pin_react_initiator(components) == 0 + + +def test_react_summary_prompt_default(): + """Without a goal the prompt has no GOAL: line and contains {conversation}.""" + from mellea.stdlib.components.react import react_summary_prompt + + prompt = react_summary_prompt() + assert "{conversation}" in prompt + assert "GOAL:" not in prompt + assert "research progress" in prompt + assert "search queries" in prompt + assert "dead ends" in prompt + + +def test_react_summary_prompt_with_goal(): + """Goal is interpolated and the prompt still has the {conversation} placeholder.""" + from mellea.stdlib.components.react import react_summary_prompt + + prompt = react_summary_prompt(goal="find papers on context compaction") + assert "GOAL: find papers on context compaction" in prompt + assert "{conversation}" in prompt + + +def test_react_summary_prompt_escapes_braces_in_goal(): + """Braces in the goal must survive str.format() in LLMSummarizeCompactor.""" + from mellea.stdlib.components.react import react_summary_prompt + + prompt = react_summary_prompt(goal="solve {x: 1, y: 2}") + # After str.format(conversation=...), the goal should appear with literal braces. + rendered = prompt.format(conversation="") + assert "GOAL: solve {x: 1, y: 2}" in rendered + assert "" in rendered + + +def test_react_summary_prompt_works_with_llm_summarize_compactor(): + """The factory's output passes LLMSummarizeCompactor's template validation.""" + from mellea.stdlib.components.react import react_summary_prompt + from mellea.stdlib.context import LLMSummarizeCompactor + + # Should not raise on construction (template contains {conversation}). 
+ # Backend value is unused in this validation-only test; any non-None object + # satisfies the required default_backend kwarg. + backend = object() + LLMSummarizeCompactor( + default_backend=backend, # type: ignore[arg-type] + prompt_template=react_summary_prompt(goal="g"), + ) + LLMSummarizeCompactor( + default_backend=backend, # type: ignore[arg-type] + prompt_template=react_summary_prompt(), + ) + LLMSummarizeCompactor( + default_backend=backend, # type: ignore[arg-type] + prompt_template=react_summary_prompt(goal="g", max_tokens_hint=2000), + ) + + +def test_react_summary_prompt_max_tokens_hint_omitted_by_default(): + """Without a hint, the prompt is byte-identical to the un-hinted form.""" + from mellea.stdlib.components.react import react_summary_prompt + + prompt = react_summary_prompt(goal="g") + prompt_explicit_none = react_summary_prompt(goal="g", max_tokens_hint=None) + assert prompt == prompt_explicit_none + assert "Be at most" not in prompt + assert "tokens (roughly" not in prompt + + +def test_react_summary_prompt_max_tokens_hint_injects_bullet(): + """Positive hint adds a bullet with token + word estimates.""" + from mellea.stdlib.components.react import react_summary_prompt + + prompt = react_summary_prompt(goal="g", max_tokens_hint=2000) + # The bullet sits after "structured clearly" and before "Context to summarize:". + assert "- Be at most ~2000 tokens (roughly 1500 words)" in prompt + assert "Prioritize density" in prompt + # Ordering: structured-clearly bullet comes before the length bullet, + # length bullet comes before the conversation marker. 
+ sc_idx = prompt.index("structured clearly") + bullet_idx = prompt.index("Be at most ~2000") + conv_idx = prompt.index("Context to summarize:") + assert sc_idx < bullet_idx < conv_idx + + +def test_react_summary_prompt_max_tokens_hint_zero_or_negative_omits_bullet(): + """Non-positive hint values are treated as no hint.""" + from mellea.stdlib.components.react import react_summary_prompt + + base = react_summary_prompt() + assert react_summary_prompt(max_tokens_hint=0) == base + assert react_summary_prompt(max_tokens_hint=-1) == base + + +def test_react_summary_prompt_max_tokens_hint_word_estimate_scales(): + """Word estimate uses the ~0.75 words/token heuristic (int truncation).""" + from mellea.stdlib.components.react import react_summary_prompt + + # 1000 tokens → 750 words; 4000 → 3000. + assert "~1000 tokens (roughly 750 words)" in react_summary_prompt( + max_tokens_hint=1000 + ) + assert "~4000 tokens (roughly 3000 words)" in react_summary_prompt( + max_tokens_hint=4000 + ) + + +@pytest.mark.asyncio +async def test_react_invokes_per_turn_compactor(): + """The ``compactor=`` hook runs once per turn after the tool observation.""" + search = _make_tool("search", "found it") + backend = ScriptedBackend( + [ + _tool_call_turn("search", search, "step 1"), + _tool_call_turn("search", search, "step 2"), + _final_answer_call("done"), + ] + ) + + calls = [] + + class RecordingCompactor: + def compact(self, ctx, *, backend=None): + calls.append(len(ctx.as_list())) + return ctx # no-op compaction; we just observe + + result, _ctx = await react( + goal="find info", + context=ChatContext(), + backend=backend, + tools=[search], + loop_budget=10, + compactor=RecordingCompactor(), + ) + + # Two non-terminal turns each invoke the compactor; the final turn skips it. + assert result.value == "done" + assert len(calls) == 2 + # Per-turn context monotonically grows in this trace. 
+ assert calls[0] < calls[1] + + +@pytest.mark.asyncio +async def test_react_runs_llm_summarize_compactor(): + """LLMSummarizeCompactor.compact is sync (hides async internally), so react() + just calls it like any other sync Compactor. + """ + from mellea.stdlib.components.react import pin_react_initiator + from mellea.stdlib.context import LLMSummarizeCompactor + + search = _make_tool("search", "found it") + backend = ScriptedBackend( + [_tool_call_turn("search", search, "step 1"), _final_answer_call("done")] + ) + + # keep_n large → no actual summarisation fires; the test verifies that + # the sync compact() method is callable from inside the async react() + # loop without exception. + result, ctx = await react( + goal="find info", + context=ChatContext(window_size=10_000), + backend=backend, + tools=[search], + loop_budget=10, + compactor=LLMSummarizeCompactor( + default_backend=backend, keep_n=1000, pin_predicate=pin_react_initiator + ), + ) + assert result.value == "done" + assert any(isinstance(c, ReactInitiator) for c in ctx.as_list()) + + +@pytest.mark.asyncio +async def test_react_compactor_can_actually_compact(): + """A real WindowCompactor wired in via the per-turn hook truncates context.""" + from mellea.stdlib.components.react import pin_react_initiator + from mellea.stdlib.context import WindowCompactor + + search = _make_tool("search", "found it") + backend = ScriptedBackend( + [ + _tool_call_turn("search", search, "step 1"), + _tool_call_turn("search", search, "step 2"), + _tool_call_turn("search", search, "step 3"), + _final_answer_call("done"), + ] + ) + + result, ctx = await react( + goal="find info", + # Permissive per-add window so we isolate the per-turn compactor's effect. + context=ChatContext(window_size=10_000), + backend=backend, + tools=[search], + loop_budget=10, + compactor=WindowCompactor(size=2, pin_predicate=pin_react_initiator), + ) + + # The ReactInitiator must survive thanks to pin_react_initiator. 
+ assert any(isinstance(c, ReactInitiator) for c in ctx.as_list()) + assert result.value == "done" + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/stdlib/test_base_context.py b/test/stdlib/test_base_context.py index 2fccb11fd..b552fbf32 100644 --- a/test/stdlib/test_base_context.py +++ b/test/stdlib/test_base_context.py @@ -4,8 +4,8 @@ from mellea.stdlib.context import ChatContext, SimpleContext -def context_construction(cls: type[Context]): - tree0 = cls() +def context_construction(cls: type[Context], **kwargs): + tree0 = cls(**kwargs) tree1 = tree0.add(CBlock("abc")) assert tree1.previous_node == tree0 @@ -15,11 +15,13 @@ def context_construction(cls: type[Context]): def test_context_construction(): context_construction(SimpleContext) + # ChatContext defaults to compactor=None (no compaction), so the linked-list + # shape is identical to the pre-compaction behaviour. context_construction(ChatContext) -def large_context_construction(cls: type[Context]): - root = cls() +def large_context_construction(cls: type[Context], **kwargs): + root = cls(**kwargs) full_graph: Context = root for i in range(1000): @@ -31,7 +33,9 @@ def large_context_construction(cls: type[Context]): def test_large_context_construction(): large_context_construction(SimpleContext) - large_context_construction(ChatContext) + # ChatContext now applies real compaction at add() time; pass a window + # large enough that all 1000 components survive. + large_context_construction(ChatContext, window_size=2000) def test_render_view_for_simple_context(): @@ -48,7 +52,9 @@ def test_render_view_for_chat_context(): ctx = ChatContext(window_size=3) for i in range(5): ctx = ctx.add(CBlock(f"a {i}")) - assert len(ctx.as_list()) == 5, "Adding 5 items to context should result in 5 items" + # Compaction is now applied at add() time, so as_list and view_for_generation + # both reflect the sliding window of 3. 
+ assert len(ctx.as_list()) == 3, "WindowCompactor(3) should keep 3 items" assert len(ctx.view_for_generation()) == 3, "Render size should be 3" # type: ignore diff --git a/test/stdlib/test_compactor.py b/test/stdlib/test_compactor.py new file mode 100644 index 000000000..970817387 --- /dev/null +++ b/test/stdlib/test_compactor.py @@ -0,0 +1,764 @@ +"""Tests for the ``Compactor`` protocol, ``WindowCompactor``, ``ThresholdCompactor``.""" + +from __future__ import annotations + +import pytest + +from mellea.core.base import ModelOutputThunk +from mellea.stdlib.components.chat import Message +from mellea.stdlib.context import ( + ChatContext, + Compactor, + LLMSummarizeCompactor, + PinPredicate, + ThresholdCompactor, + WindowCompactor, + pin_nothing, + pin_system, + pin_system_and_initial_user, +) +from mellea.stdlib.context.compactor import _last_usage_tokens + + +def _msg(i: int) -> Message: + return Message(role="user", content=f"m{i}") + + +def _thunk(total_tokens: int, value: str = "") -> ModelOutputThunk: + """Build a ModelOutputThunk with a populated usage dict.""" + mot = ModelOutputThunk(value=value) + mot.generation.usage = { + "prompt_tokens": total_tokens, + "completion_tokens": 0, + "total_tokens": total_tokens, + } + return mot + + +class TestChatContextDefaults: + def test_default_has_no_compactor(self): + # Compaction is opt-in: bare ChatContext() retains full history. 
+ ctx = ChatContext() + assert ctx._compactor is None + + def test_default_keeps_full_history(self): + ctx = ChatContext() + for i in range(20): + ctx = ctx.add(_msg(i)) + assert len(ctx.as_list()) == 20 + + def test_window_size_arg_constructs_window_compactor(self): + ctx = ChatContext(window_size=3) + assert isinstance(ctx._compactor, WindowCompactor) + assert ctx._compactor.size == 3 + + def test_passing_both_args_raises(self): + with pytest.raises(ValueError): + ChatContext(compactor=WindowCompactor(size=2), window_size=3) + + def test_explicit_compactor_overrides_default(self): + comp = WindowCompactor(size=2) + ctx = ChatContext(compactor=comp) + assert ctx._compactor is comp + + +class TestInlineCompactorGuard: + """ChatContext only accepts InlineCompactor instances.""" + + def test_rejects_llm_summarize_compactor_directly(self, scripted_summary_backend): + # Attaching LLMSummarizeCompactor would invoke the backend on every add(). + comp = LLMSummarizeCompactor(default_backend=scripted_summary_backend) + with pytest.raises(TypeError, match="requires an InlineCompactor"): + ChatContext(compactor=comp) + + def test_accepts_threshold_wrapping_window(self): + # ThresholdCompactor is an InlineCompactor regardless of inner. + wrapped = ThresholdCompactor(WindowCompactor(size=5), threshold=1000) + ctx = ChatContext(compactor=wrapped) + assert ctx._compactor is wrapped + + def test_accepts_threshold_wrapping_llm_summarize(self, scripted_summary_backend): + # Wrapped is acceptable: ThresholdCompactor gates inner by token usage, + # so backend isn't called on every add(). Inner's default_backend covers + # the actual summarisation when the gate trips. 
+ wrapped = ThresholdCompactor( + LLMSummarizeCompactor(default_backend=scripted_summary_backend, keep_n=2), + threshold=1000, + ) + ctx = ChatContext(compactor=wrapped) + assert ctx._compactor is wrapped + + def test_accepts_window_compactor(self): + comp = WindowCompactor(size=5) + ctx = ChatContext(compactor=comp) + assert ctx._compactor is comp + + def test_rejects_non_inline_duck_typed_compactor(self): + class FakeCompactor: + def compact(self, ctx, *, backend=None): + return ctx + + with pytest.raises(TypeError, match="requires an InlineCompactor"): + ChatContext(compactor=FakeCompactor()) # type: ignore[arg-type] + + +class TestWindowCompactor: + def test_compact_keeps_last_n(self): + ctx = ChatContext(window_size=3) + for i in range(7): + ctx = ctx.add(_msg(i)) + items = ctx.as_list() + assert len(items) == 3 + assert [m.content for m in items] == ["m4", "m5", "m6"] + + def test_compact_does_not_mutate_original(self): + # Build with a permissive window so all 3 items are retained, then + # apply a tighter compactor manually (Pattern 2). 
+ ctx = ChatContext(window_size=10_000) + ctx = ctx.add(_msg(0)) + ctx = ctx.add(_msg(1)) + ctx = ctx.add(_msg(2)) + before_compact = [m.content for m in ctx.as_list()] + compacted = WindowCompactor(size=2).compact(ctx) + # original unchanged + assert [m.content for m in ctx.as_list()] == before_compact + # compacted is shorter and a different object + assert compacted is not ctx + assert len(compacted.as_list()) == 2 + + def test_compact_preserves_compactor_on_result(self): + comp = WindowCompactor(size=2) + ctx = ChatContext(compactor=comp) + ctx = ctx.add(_msg(0)).add(_msg(1)).add(_msg(2)) + # subsequent adds keep using the same compactor + ctx = ctx.add(_msg(3)) + assert ctx._compactor is comp + assert len(ctx.as_list()) == 2 + + def test_view_for_generation_no_double_truncation(self): + ctx = ChatContext(window_size=3) + for i in range(7): + ctx = ctx.add(_msg(i)) + # add() already compacted; view should match the linear history exactly + view = ctx.view_for_generation() + assert view is not None + assert [m.content for m in view] == [m.content for m in ctx.as_list()] + + def test_negative_size_raises(self): + with pytest.raises(ValueError): + WindowCompactor(size=-1) + + def test_size_zero_clears_body(self): + # Regression: `[-0:]` evaluates to `[0:]` in Python, which would keep + # the entire body instead of nothing. size=0 must keep zero body items. + ctx = ChatContext(window_size=10_000) + for i in range(5): + ctx = ctx.add(_msg(i)) + result = WindowCompactor(size=0).compact(ctx) + assert result.as_list() == [] + + def test_size_zero_keeps_pinned_prefix(self): + ctx = ChatContext(window_size=10_000) + ctx = ctx.add(Message(role="system", content="sys")) + for i in range(3): + ctx = ctx.add(_msg(i)) + # Default pin_predicate=pin_system → system stays, body cleared. 
+ result = WindowCompactor(size=0).compact(ctx) + items = result.as_list() + assert len(items) == 1 + assert items[0].content == "sys" + + def test_pins_leading_system_message(self): + ctx = ChatContext(window_size=10_000) + ctx = ctx.add(Message(role="system", content="You are helpful.")) + for i in range(5): + ctx = ctx.add(_msg(i)) + # Apply WindowCompactor(size=2) manually — keep system + last 2 body. + result = WindowCompactor(size=2).compact(ctx) + items = result.as_list() + assert len(items) == 3 + assert isinstance(items[0], Message) and items[0].role == "system" + assert [m.content for m in items[1:]] == ["m3", "m4"] + + def test_pins_multiple_leading_system_messages(self): + ctx = ChatContext(window_size=10_000) + ctx = ctx.add(Message(role="system", content="sys1")) + ctx = ctx.add(Message(role="system", content="sys2")) + for i in range(5): + ctx = ctx.add(_msg(i)) + result = WindowCompactor(size=2).compact(ctx) + items = result.as_list() + assert [m.content for m in items[:2]] == ["sys1", "sys2"] + assert [m.content for m in items[2:]] == ["m3", "m4"] + + def test_does_not_pin_non_contiguous_system(self): + # System message in the middle is NOT pinned — only the contiguous prefix. + ctx = ChatContext(window_size=10_000) + ctx = ctx.add(_msg(0)) # body starts here + ctx = ctx.add(Message(role="system", content="late-sys")) + for i in range(1, 6): + ctx = ctx.add(_msg(i)) + result = WindowCompactor(size=2).compact(ctx) + items = result.as_list() + assert len(items) == 2 + assert "late-sys" not in [getattr(m, "content", None) for m in items] + + def test_no_system_message_pure_last_n(self): + # Without any system prefix, behaviour is pure last-N (matches Phase 2 semantics). 
+ ctx = ChatContext(window_size=10_000) + for i in range(7): + ctx = ctx.add(_msg(i)) + result = WindowCompactor(size=3).compact(ctx) + items = result.as_list() + assert [m.content for m in items] == ["m4", "m5", "m6"] + + +class TestCompactorProtocol: + def test_user_class_satisfies_protocol_via_inline_marker(self): + """A user class structurally matching Compactor and inheriting InlineCompactor + is accepted by ChatContext.""" + from mellea.stdlib.context import InlineCompactor + + class Identity(InlineCompactor): + def compact(self, ctx, *, backend=None): + return ctx + + c = Identity() + ctx = ChatContext(compactor=c) + ctx = ctx.add(_msg(0)) + # Identity returns ctx unchanged, so we still see m0 + assert [m.content for m in ctx.as_list()] == ["m0"] + + def test_pattern_2_manual_compaction(self): + """Pattern 2: caller invokes compactor.compact() directly.""" + comp = WindowCompactor(size=2) + # context with no auto-compaction would be tricky to construct under the + # new defaults; instead use a window large enough that auto-compaction + # never fires, then apply comp manually. 
+ ctx = ChatContext(window_size=100) + for i in range(5): + ctx = ctx.add(_msg(i)) + assert len(ctx.as_list()) == 5 + ctx2 = comp.compact(ctx) + assert len(ctx2.as_list()) == 2 + # original still untouched + assert len(ctx.as_list()) == 5 + + +class TestLastUsageTokens: + def test_no_thunk_returns_none(self): + ctx = ChatContext(window_size=100).add(_msg(0)) + assert _last_usage_tokens(ctx) is None + + def test_thunk_without_usage_returns_none(self): + ctx = ChatContext(window_size=100).add(_msg(0)).add(ModelOutputThunk(value="x")) + assert _last_usage_tokens(ctx) is None + + def test_reads_total_tokens(self): + ctx = ChatContext(window_size=100).add(_msg(0)).add(_thunk(150)) + assert _last_usage_tokens(ctx) == 150 + + def test_falls_back_to_prompt_plus_completion(self): + mot = ModelOutputThunk(value="x") + mot.generation.usage = {"prompt_tokens": 40, "completion_tokens": 20} + ctx = ChatContext(window_size=100).add(_msg(0)).add(mot) + assert _last_usage_tokens(ctx) == 60 + + def test_uses_most_recent_thunk(self): + ctx = ( + ChatContext(window_size=100).add(_thunk(100)).add(_msg(0)).add(_thunk(500)) + ) + assert _last_usage_tokens(ctx) == 500 + + +class TestThresholdCompactor: + def test_below_threshold_returns_input(self): + inner = WindowCompactor(size=2) + gated = ThresholdCompactor(inner, threshold=1000) + ctx = ChatContext(window_size=100).add(_msg(0)).add(_thunk(50)) + # 5 components but inner not invoked because token count (50) <= threshold (1000) + for i in range(1, 6): + ctx = ctx.add(_msg(i)) + result = gated.compact(ctx) + assert result is ctx + + def test_above_threshold_runs_inner(self): + inner = WindowCompactor(size=2) + gated = ThresholdCompactor(inner, threshold=100) + # Build a context with the last thunk reporting >threshold tokens. + ctx = ChatContext(window_size=100) + for i in range(5): + ctx = ctx.add(_msg(i)) + ctx = ctx.add(_thunk(500)) + result = gated.compact(ctx) + # Inner was invoked → only last 2 components retained. 
+ assert len(result.as_list()) == 2 + + def test_no_thunk_no_compaction(self): + """No thunk means no usage info — gate stays closed.""" + inner = WindowCompactor(size=2) + gated = ThresholdCompactor(inner, threshold=100) + ctx = ChatContext(window_size=100) + for i in range(5): + ctx = ctx.add(_msg(i)) + result = gated.compact(ctx) + assert result is ctx + + def test_zero_threshold_disables_gate(self): + inner = WindowCompactor(size=2) + gated = ThresholdCompactor(inner, threshold=0) + ctx = ChatContext(window_size=100).add(_msg(0)).add(_thunk(10_000)) + result = gated.compact(ctx) + # Threshold 0 means "never trigger" — input passes through. + assert result is ctx + + +class TestPinPredicates: + def test_pin_nothing(self): + assert pin_nothing([_msg(0), _msg(1)]) == 0 + assert pin_nothing([]) == 0 + + def test_pin_system_zero_when_no_system(self): + assert pin_system([_msg(0), _msg(1)]) == 0 + + def test_pin_system_counts_contiguous(self): + components = [ + Message(role="system", content="s1"), + Message(role="system", content="s2"), + _msg(0), + Message(role="system", content="late-s"), # not pinned — non-contiguous + ] + assert pin_system(components) == 2 + + def test_pin_system_and_initial_user_with_both(self): + components = [ + Message(role="system", content="s1"), + Message(role="user", content="goal"), + Message(role="assistant", content="ack"), + ] + assert pin_system_and_initial_user(components) == 2 + + def test_pin_system_and_initial_user_no_user(self): + components = [ + Message(role="system", content="s1"), + Message(role="assistant", content="x"), + ] + # First non-system is "assistant", not "user" — not pinned beyond system. 
+ assert pin_system_and_initial_user(components) == 1 + + def test_pin_system_and_initial_user_user_only(self): + components = [ + Message(role="user", content="goal"), + Message(role="assistant", content="ok"), + ] + assert pin_system_and_initial_user(components) == 1 + + +class TestWindowCompactorPredicate: + def test_pin_nothing_pure_last_n(self): + comp = WindowCompactor(size=2, pin_predicate=pin_nothing) + ctx = ChatContext(window_size=10_000) + ctx = ctx.add(Message(role="system", content="sys")) + for i in range(5): + ctx = ctx.add(_msg(i)) + result = comp.compact(ctx) + items = result.as_list() + assert len(items) == 2 + # System is dropped because predicate returned 0. + assert "sys" not in [getattr(m, "content", None) for m in items] + + def test_pin_system_and_initial_user_protects_first_user(self): + comp = WindowCompactor(size=2, pin_predicate=pin_system_and_initial_user) + ctx = ChatContext(window_size=10_000) + ctx = ctx.add(Message(role="system", content="sys")) + ctx = ctx.add(Message(role="user", content="goal")) + for i in range(6): + ctx = ctx.add(_msg(i)) + result = comp.compact(ctx) + items = result.as_list() + # prefix (sys + goal) + last 2 body = 4 + assert len(items) == 4 + assert items[0].content == "sys" + assert items[1].content == "goal" + + def test_custom_predicate(self): + # Predicate that pins the first 3 components unconditionally. 
+ def pin_first_3(components): + return min(3, len(components)) + + comp = WindowCompactor(size=2, pin_predicate=pin_first_3) + ctx = ChatContext(window_size=10_000) + for i in range(8): + ctx = ctx.add(_msg(i)) + result = comp.compact(ctx) + items = result.as_list() + # prefix (m0, m1, m2) + last 2 of body (m6, m7) = 5 + assert [m.content for m in items] == ["m0", "m1", "m2", "m6", "m7"] + + +# --------------------------------------------------------------------------- # +# LLMSummarizeCompactor # +# --------------------------------------------------------------------------- # + + +@pytest.fixture +def scripted_summary_backend(): + """Lazy-built fake backend that returns a fixed summary on each generate call.""" + from collections.abc import Sequence + + from mellea.core.backend import Backend, BaseModelSubclass + from mellea.core.base import C, GenerateLog + + class FakeBackend(Backend): + def __init__(self, summary: str = "SUMMARY-OF-OLD") -> None: + self.summary = summary + self.calls = 0 + self.last_action_content: str | None = None + self.last_model_options: dict | None = None + + async def _generate_from_context( + self, + action, + ctx, + *, + format=None, + model_options=None, + tool_calls: bool = False, + ): + self.calls += 1 + self.last_action_content = getattr(action, "content", str(action)) + self.last_model_options = model_options + mot = ModelOutputThunk(value=self.summary) + mot._generate_log = GenerateLog(is_final_result=True) + return mot, ctx.add(action).add(mot) + + async def generate_from_raw( + self, + actions, + ctx, + *, + format=None, + model_options=None, + tool_calls: bool = False, + ): + raise NotImplementedError + + return FakeBackend() + + +class TestLLMSummarizeCompactor: + def test_negative_keep_n_raises(self, scripted_summary_backend): + with pytest.raises(ValueError): + LLMSummarizeCompactor(default_backend=scripted_summary_backend, keep_n=-1) + + def test_prompt_template_must_have_placeholder(self, scripted_summary_backend): + 
with pytest.raises(ValueError, match="conversation"): + LLMSummarizeCompactor( + default_backend=scripted_summary_backend, + prompt_template="no placeholder here", + ) + + def test_default_backend_is_required(self): + with pytest.raises(TypeError, match="default_backend"): + LLMSummarizeCompactor() # type: ignore[call-arg] + + def test_compact_is_sync(self, scripted_summary_backend): + import inspect + + comp = LLMSummarizeCompactor(default_backend=scripted_summary_backend) + # Sync from the outside even though the implementation calls async backend code. + assert not inspect.iscoroutinefunction(comp.compact) + + def test_uses_default_backend_when_call_omits_one(self, scripted_summary_backend): + comp = LLMSummarizeCompactor(default_backend=scripted_summary_backend, keep_n=1) + ctx = ChatContext(window_size=10_000) + for i in range(4): + ctx = ctx.add(_msg(i)) + # No backend kwarg → falls back to default_backend. + result = comp.compact(ctx) + items = result.as_list() + assert "[CONTEXT SUMMARY]" in items[0].content + assert scripted_summary_backend.calls == 1 + + def test_call_time_backend_overrides_default(self, scripted_summary_backend): + from mellea.core.backend import Backend + from mellea.core.base import GenerateLog + + class OtherBackend(Backend): + def __init__(self) -> None: + self.calls = 0 + + async def _generate_from_context( + self, + action, + ctx, + *, + format=None, + model_options=None, + tool_calls: bool = False, + ): + self.calls += 1 + mot = ModelOutputThunk(value="OTHER-SUMMARY") + mot._generate_log = GenerateLog(is_final_result=True) + return mot, ctx.add(action).add(mot) + + async def generate_from_raw(self, *a, **kw): + raise NotImplementedError + + other = OtherBackend() + comp = LLMSummarizeCompactor(default_backend=scripted_summary_backend, keep_n=1) + ctx = ChatContext(window_size=10_000) + for i in range(4): + ctx = ctx.add(_msg(i)) + result = comp.compact(ctx, backend=other) + items = result.as_list() + # Caller-supplied backend 
wins. + assert "OTHER-SUMMARY" in items[0].content + assert other.calls == 1 + assert scripted_summary_backend.calls == 0 + + def test_short_body_is_noop(self, scripted_summary_backend): + comp = LLMSummarizeCompactor(default_backend=scripted_summary_backend, keep_n=5) + ctx = ChatContext(window_size=10_000) + for i in range(3): + ctx = ctx.add(_msg(i)) + result = comp.compact(ctx, backend=scripted_summary_backend) + # body length (3) <= keep_n (5) → no-op, backend not called + assert result is ctx + assert scripted_summary_backend.calls == 0 + + def test_backend_failure_returns_ctx_unchanged_and_logs( + self, scripted_summary_backend, caplog + ): + """Compaction is best-effort: backend errors must not propagate.""" + import logging + + from mellea.core.backend import Backend + from mellea.core.base import GenerateLog + + class BrokenBackend(Backend): + async def _generate_from_context( + self, + action, + ctx, + *, + format=None, + model_options=None, + tool_calls: bool = False, + ): + raise RuntimeError("simulated rate limit") + + async def generate_from_raw(self, *a, **kw): + raise NotImplementedError + + comp = LLMSummarizeCompactor(default_backend=BrokenBackend(), keep_n=1) + ctx = ChatContext(window_size=10_000) + for i in range(4): + ctx = ctx.add(_msg(i)) + + with caplog.at_level(logging.WARNING): + result = comp.compact(ctx) + + # ctx returned unchanged — same object, original history intact. + assert result is ctx + assert [m.content for m in result.as_list()] == ["m0", "m1", "m2", "m3"] + # Warning logged with context for debugging. 
+ assert any( + "summarisation backend call failed" in rec.message + and "RuntimeError" in rec.message + for rec in caplog.records + ) + + def test_renders_thunk_without_value_using_tool_calls( + self, scripted_summary_backend + ): + """Tool-call-only thunks (value=None) render the call name + args, not 'None'.""" + from mellea.core.base import ModelToolCall + + # The compactor's rendering only reads ``name``/``args`` off the + # ModelToolCall, never invokes ``func`` — pass None to skip + # AbstractMelleaTool's abstract-method requirements. + tool_call = ModelToolCall( + name="search", + func=None, # type: ignore[arg-type] + args={"q": "papers"}, + ) + thunk = ModelOutputThunk(value=None, tool_calls={"search": tool_call}) + + comp = LLMSummarizeCompactor(default_backend=scripted_summary_backend, keep_n=1) + ctx = ChatContext(window_size=10_000) + for i in range(2): + ctx = ctx.add(_msg(i)) + ctx = ctx.add(thunk) + ctx = ctx.add(_msg(2)) # so the thunk falls into `old`, not `recent` + + comp.compact(ctx) + rendered = scripted_summary_backend.last_action_content + assert rendered is not None + assert "assistant called tools: search" in rendered + assert "'q': 'papers'" in rendered + # Old "assistant: None" failure mode must not appear. + assert "assistant: None" not in rendered + + def test_renders_thunk_with_no_value_and_no_tool_calls( + self, scripted_summary_backend + ): + """A thunk with neither value nor tool_calls is skipped entirely — no + '' marker, no 'assistant: None'.""" + thunk = ModelOutputThunk(value=None) + + comp = LLMSummarizeCompactor(default_backend=scripted_summary_backend, keep_n=1) + ctx = ChatContext(window_size=10_000) + for i in range(2): + ctx = ctx.add(_msg(i)) + ctx = ctx.add(thunk) + ctx = ctx.add(_msg(2)) + + comp.compact(ctx) + rendered = scripted_summary_backend.last_action_content + assert rendered is not None + assert "" not in rendered + assert "assistant: None" not in rendered + # The other turns still made it into the prompt. 
+        assert "user: m0" in rendered
+        assert "user: m1" in rendered
+
+    def test_catchall_renders_unknown_component_as_typed_marker(
+        self, scripted_summary_backend
+    ):
+        """Component subclasses that aren't Message/ToolMessage/ModelOutputThunk
+        emit a ``<TypeName>`` marker (here ``<_CustomMarker>``) instead of the
+        default object repr."""
+        from mellea.core import Component
+
+        class _CustomMarker(Component):
+            """Component without a ``content`` attribute."""
+
+            def parts(self):  # type: ignore[override]
+                return []
+
+            def format_for_llm(self):  # type: ignore[override]
+                return ""
+
+        comp = LLMSummarizeCompactor(default_backend=scripted_summary_backend, keep_n=1)
+        ctx = ChatContext(window_size=10_000)
+        ctx = ctx.add(_CustomMarker())  # in `old`
+        ctx = ctx.add(_msg(0))
+        ctx = ctx.add(_msg(99))  # in `recent`
+
+        comp.compact(ctx)
+        # The scripted backend fixture exposes the prompt text it was handed.
+        rendered = scripted_summary_backend.last_action_content
+        assert rendered is not None
+        # Type name appears explicitly; raw repr does NOT.
+        assert "<_CustomMarker>" in rendered
+        assert "object at 0x" not in rendered
+
+    def test_renders_message_with_attachments_as_markers(
+        self, scripted_summary_backend
+    ):
+        """Image/document attachments are noted by count; their contents are not reproduced."""
+        from mellea.stdlib.components.docs.document import Document
+
+        msg_with_imgs = Message(role="user", content="see these")
+        # Bypass the constructor to inject raw lists; the rendering path reads `_images`/`_docs`.
+        msg_with_imgs._images = ["IMGDATA1", "IMGDATA2"]  # type: ignore[assignment]
+        msg_with_docs = Message(role="user", content="and these")
+        msg_with_docs._docs = [Document(text="doc body")]  # type: ignore[assignment]
+
+        comp = LLMSummarizeCompactor(default_backend=scripted_summary_backend, keep_n=1)
+        ctx = ChatContext(window_size=10_000)
+        ctx = ctx.add(msg_with_imgs)
+        ctx = ctx.add(msg_with_docs)
+        ctx = ctx.add(_msg(99))  # keeps msg_with_imgs/docs in `old`, this in `recent`
+
+        comp.compact(ctx)
+        rendered = scripted_summary_backend.last_action_content
+        assert rendered is not None
+        assert "[2 image(s) attached]" in rendered
+        assert "[1 document(s) attached]" in rendered
+        # Image bytes are NOT in the rendered prompt.
+        assert "IMGDATA1" not in rendered
+
+    def test_model_options_forwarded_to_backend(self, scripted_summary_backend):
+        """model_options set at construction reach the backend's generate call."""
+        comp = LLMSummarizeCompactor(
+            default_backend=scripted_summary_backend,
+            keep_n=1,
+            model_options={"max_tokens": 4096, "temperature": 0.0},
+        )
+        ctx = ChatContext(window_size=10_000)
+        for i in range(4):
+            ctx = ctx.add(_msg(i))
+        comp.compact(ctx)
+        # The fixture records the options it received; they must match exactly.
+        assert scripted_summary_backend.last_model_options == {
+            "max_tokens": 4096,
+            "temperature": 0.0,
+        }
+
+    def test_model_options_default_is_empty(self, scripted_summary_backend):
+        """When model_options is not set, the backend receives no caller-supplied
+        options (falsy: None or {}); upstream defaults govern."""
+        comp = LLMSummarizeCompactor(default_backend=scripted_summary_backend, keep_n=1)
+        ctx = ChatContext(window_size=10_000)
+        for i in range(4):
+            ctx = ctx.add(_msg(i))
+        comp.compact(ctx)
+        assert not scripted_summary_backend.last_model_options
+
+    def test_summarises_old_keeps_recent(self, scripted_summary_backend):
+        """With keep_n=2 and six messages, the result is one summary item followed
+        by the last two messages verbatim, and the scripted backend records
+        exactly one call."""
+        comp = LLMSummarizeCompactor(default_backend=scripted_summary_backend, keep_n=2)
+        ctx = ChatContext(window_size=10_000)
+        for i in range(6):
+            ctx = ctx.add(_msg(i))
+        result = comp.compact(ctx, backend=scripted_summary_backend)
+        items = result.as_list()
+        # summary (1) + last 2 verbatim = 3
+        assert len(items) == 3
+        assert "[CONTEXT SUMMARY]" in items[0].content
+        assert items[1].content == "m4"
+        assert items[2].content == "m5"
+        assert scripted_summary_backend.calls == 1
+
+    def test_pin_predicate_preserves_prefix(self, scripted_summary_backend):
+        """With pin_predicate=pin_system, the leading system message stays first,
+        followed by the summary and then the last message verbatim."""
+        comp = LLMSummarizeCompactor(
+            default_backend=scripted_summary_backend, keep_n=1, pin_predicate=pin_system
+        )
+        ctx = ChatContext(window_size=10_000)
+        ctx = ctx.add(Message(role="system", content="sys"))
+        for i in range(4):
+            ctx = ctx.add(_msg(i))
+        result = comp.compact(ctx, backend=scripted_summary_backend)
+        items = result.as_list()
+        # system (pinned) + summary + last 1 verbatim = 3
+        assert items[0].role == "system"
+        assert items[0].content == "sys"
+        assert "[CONTEXT SUMMARY]" in items[1].content
+        assert items[2].content == "m3"
+
+    def test_does_not_mutate_original(self, scripted_summary_backend):
+        """compact() leaves the contents of the input context unchanged."""
+        comp = LLMSummarizeCompactor(default_backend=scripted_summary_backend, keep_n=1)
+        ctx = ChatContext(window_size=10_000)
+        for i in range(4):
+            ctx = ctx.add(_msg(i))
+        # Snapshot the message contents so they can be compared after compaction.
+        before = [m.content for m in ctx.as_list()]
+        comp.compact(ctx, backend=scripted_summary_backend)
+        assert [m.content for m in ctx.as_list()] == before
+
+    def test_satisfies_compactor_protocol(self, scripted_summary_backend):
+        """An LLMSummarizeCompactor can be bound to a ``Compactor``-annotated name
+        and exposes a callable ``compact``."""
+        comp: Compactor = LLMSummarizeCompactor(
+            default_backend=scripted_summary_backend
+        )
+        # Just a typing-level check that the assignment is accepted; at runtime
+        # the only thing asserted is that `compact` is callable.
+        assert callable(comp.compact)
+
+    @pytest.mark.asyncio
+    async def test_works_inside_running_event_loop(self, scripted_summary_backend):
+        """compact() is callable from within an async function — uses worker thread."""
+        comp = LLMSummarizeCompactor(default_backend=scripted_summary_backend, keep_n=1)
+        ctx = ChatContext(window_size=10_000)
+        for i in range(4):
+            ctx = ctx.add(_msg(i))
+        # No await: this is a sync call from inside an async test.
+        result = comp.compact(ctx, backend=scripted_summary_backend)
+        items = result.as_list()
+        assert "[CONTEXT SUMMARY]" in items[0].content
+        assert items[1].content == "m3"