diff --git a/src/memos/chunkers/sentence_chunker.py b/src/memos/chunkers/sentence_chunker.py
index e695d0d9a..37e943151 100644
--- a/src/memos/chunkers/sentence_chunker.py
+++ b/src/memos/chunkers/sentence_chunker.py
@@ -53,4 +53,5 @@ def chunk(self, text: str) -> list[str] | list[Chunk]:
             chunks.append(chunk)
 
         logger.debug(f"Generated {len(chunks)} chunks from input text")
 
+        return chunks
diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py
index 092f29ac6..210826a98 100644
--- a/src/memos/mem_reader/multi_modal_struct.py
+++ b/src/memos/mem_reader/multi_modal_struct.py
@@ -18,7 +18,7 @@
 from memos.templates.mem_reader_prompts import MEMORY_MERGE_PROMPT_EN, MEMORY_MERGE_PROMPT_ZH
 from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN, TOOL_TRAJECTORY_PROMPT_ZH
 from memos.types import MessagesType
-from memos.utils import timed
+from memos.utils import timed, timed_stage
 
 
 if TYPE_CHECKING:
@@ -75,6 +75,30 @@ def __init__(self, config: MultiModalStructMemReaderConfig):
             direct_markdown_hostnames=direct_markdown_hostnames,
         )
 
+    def _embed_memory_items(self, items: list[TextualMemoryItem]) -> None:
+        """Compute embeddings for a list of memory items in-place.
+
+        Attempts a single batch call first; falls back to per-item calls if the
+        batch fails. Errors are logged but never raised, so callers always
+        continue normally.
+        """
+        valid = [w for w in items if w and w.memory]
+        if not valid:
+            return
+        texts = [w.memory for w in valid]
+        try:
+            embeddings = self.embedder.embed(texts)
+            for w, emb in zip(valid, embeddings, strict=True):
+                w.metadata.embedding = emb
+        except Exception as e:
+            logger.error(f"[MultiModalStruct] Error batch computing embeddings: {e}")
+            logger.warning("[EMBED_FALLBACK] batch_size=%d", len(texts))
+            for w in valid:
+                try:
+                    w.metadata.embedding = self.embedder.embed([w.memory])[0]
+                except Exception as e2:
+                    logger.error(f"[MultiModalStruct] Error computing embedding for item: {e2}")
+
     def _split_large_memory_item(
         self, item: TextualMemoryItem, max_tokens: int
     ) -> list[TextualMemoryItem]:
@@ -203,13 +227,8 @@ def _concat_multi_modal_memories(
         # If only one item after processing, compute embedding and return
         if len(processed_items) == 1:
             single_item = processed_items[0]
-            if single_item and single_item.memory:
-                try:
-                    single_item.metadata.embedding = self.embedder.embed([single_item.memory])[0]
-                except Exception as e:
-                    logger.error(
-                        f"[MultiModalStruct] Error computing embedding for single item: {e}"
-                    )
+            with timed_stage("add", "embedding", window_count=1):
+                self._embed_memory_items([single_item])
             return processed_items
 
         windows = []
@@ -260,31 +279,8 @@ def _concat_multi_modal_memories(
             windows.append(window)
 
         # Batch compute embeddings for all windows
-        if windows:
-            # Collect all valid windows that need embedding
-            valid_windows = [w for w in windows if w and w.memory]
-
-            if valid_windows:
-                # Collect all texts that need embedding
-                texts_to_embed = [w.memory for w in valid_windows]
-
-                # Batch compute all embeddings at once
-                try:
-                    embeddings = self.embedder.embed(texts_to_embed)
-                    # Fill embeddings back into memory items
-                    for window, embedding in zip(valid_windows, embeddings, strict=True):
-                        window.metadata.embedding = embedding
-                except Exception as e:
-                    logger.error(f"[MultiModalStruct] Error batch computing embeddings: {e}")
-                    # Fallback: compute embeddings individually
-                    for window in valid_windows:
-                        if window.memory:
-                            try:
-                                window.metadata.embedding = self.embedder.embed([window.memory])[0]
-                            except Exception as e2:
-                                logger.error(
-                                    f"[MultiModalStruct] Error computing embedding for item: {e2}"
-                                )
+        with timed_stage("add", "embedding", window_count=len(windows)):
+            self._embed_memory_items(windows)
 
         return windows
 
@@ -984,49 +980,49 @@ def _process_multi_modal_data(
         # must pop here, avoid add to info, only used in sync fine mode
         custom_tags = info.pop("custom_tags", None) if isinstance(info, dict) else None
 
-        # Use MultiModalParser to parse the scene data
-        # If it's a list, parse each item; otherwise parse as single message
-        if isinstance(scene_data_info, list):
-            # Pre-expand multimodal messages
-            expanded_messages = self._expand_multimodal_messages(scene_data_info)
-
-            # Parse each message in the list
-            all_memory_items = []
-            # Use thread pool to parse each message in parallel, but keep the original order
-            with ContextThreadPoolExecutor(max_workers=30) as executor:
-                # submit tasks and keep the original order
-                futures = [
-                    executor.submit(
-                        self.multi_modal_parser.parse,
-                        msg,
-                        info,
-                        mode="fast",
-                        need_emb=False,
-                        **kwargs,
-                    )
-                    for msg in expanded_messages
-                ]
-                # collect results in original order
-                for future in futures:
-                    try:
-                        items = future.result()
-                        all_memory_items.extend(items)
-                    except Exception as e:
-                        logger.error(f"[MultiModalFine] Error in parallel parsing: {e}")
-        else:
-            # Parse as single message
-            all_memory_items = self.multi_modal_parser.parse(
-                scene_data_info, info, mode="fast", need_emb=False, **kwargs
-            )
-        fast_memory_items = self._concat_multi_modal_memories(all_memory_items)
+        # Stage: parse — parallel message parsing + sliding-window aggregation
+        with timed_stage("add", "parse") as ts_parse:
+            if isinstance(scene_data_info, list):
+                expanded_messages = self._expand_multimodal_messages(scene_data_info)
+                ts_parse.set(msg_count=len(expanded_messages))
+
+                all_memory_items = []
+                with ContextThreadPoolExecutor(max_workers=30) as executor:
+                    futures = [
+                        executor.submit(
+                            self.multi_modal_parser.parse,
+                            msg,
+                            info,
+                            mode="fast",
+                            need_emb=False,
+                            **kwargs,
+                        )
+                        for msg in expanded_messages
+                    ]
+                    for future in futures:
+                        try:
+                            items = future.result()
+                            all_memory_items.extend(items)
+                        except Exception as e:
+                            logger.error(f"[MultiModalFine] Error in parallel parsing: {e}")
+            else:
+                ts_parse.set(msg_count=1)
+                all_memory_items = self.multi_modal_parser.parse(
+                    scene_data_info, info, mode="fast", need_emb=False, **kwargs
+                )
+
+            fast_memory_items = self._concat_multi_modal_memories(all_memory_items)
+            ts_parse.set(window_count=len(fast_memory_items))
+
         if mode == "fast":
             return fast_memory_items
-        else:
-            non_file_url_fast_items = [
-                item for item in fast_memory_items if not self._is_file_url_only_item(item)
-            ]
 
-            # Part A: call llm in parallel using thread pool
+        # Stage: llm_extract — fine mode 4-way parallel LLM + per-source serial
+        non_file_url_fast_items = [
+            item for item in fast_memory_items if not self._is_file_url_only_item(item)
+        ]
+
+        with timed_stage("add", "llm_extract") as ts_llm:
             fine_memory_items = []
             with ContextThreadPoolExecutor(max_workers=4) as executor:
@@ -1057,7 +1053,6 @@ def _process_multi_modal_data(
                         **kwargs,
                     )
 
-            # Collect results
             fine_memory_items_string_parser = future_string.result()
             fine_memory_items_tool_trajectory_parser = future_tool.result()
             fine_memory_items_skill_memory_parser = future_skill.result()
@@ -1068,21 +1063,25 @@ def _process_multi_modal_data(
             fine_memory_items.extend(fine_memory_items_skill_memory_parser)
             fine_memory_items.extend(fine_memory_items_pref_parser)
 
-            # Part B: get fine multimodal items
-            for fast_item in fast_memory_items:
-                sources = fast_item.metadata.sources
-                for source in sources:
-                    lang = getattr(source, "lang", "en")
-                    items = self.multi_modal_parser.process_transfer(
-                        source,
-                        context_items=[fast_item],
-                        custom_tags=custom_tags,
-                        info=info,
-                        lang=lang,
-                        user_context=kwargs.get("user_context"),
-                    )
-                    fine_memory_items.extend(items)
-            return fine_memory_items
+            # Part B: per-source serial processing
+            with timed_stage("add", "per_source") as ts_ps:
+                for fast_item in fast_memory_items:
+                    sources = fast_item.metadata.sources
+                    for source in sources:
+                        lang = getattr(source, "lang", "en")
+                        items = self.multi_modal_parser.process_transfer(
+                            source,
+                            context_items=[fast_item],
+                            custom_tags=custom_tags,
+                            info=info,
+                            lang=lang,
+                            user_context=kwargs.get("user_context"),
+                        )
+                        fine_memory_items.extend(items)
+
+            ts_llm.set(fine_memory_count=len(fine_memory_items), per_source_ms=ts_ps.duration_ms)
+
+        return fine_memory_items
 
     @timed
     def _process_transfer_multi_modal_data(
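The batch-then-fallback contract that `_embed_memory_items` centralizes above is easy to get subtly wrong, because a partial failure must not break the alignment between inputs and outputs. A minimal standalone sketch of the same pattern, with a hypothetical `Embedder` protocol standing in for `self.embedder`::

    from typing import Protocol


    class Embedder(Protocol):
        def embed(self, texts: list[str]) -> list[list[float]]: ...


    def embed_with_fallback(embedder: Embedder, texts: list[str]) -> list[list[float] | None]:
        """Try one batch call; degrade to per-item calls on failure.

        A failed item yields None instead of raising, mirroring the
        log-and-continue contract of _embed_memory_items.
        """
        try:
            # One round-trip for the whole batch is the happy path.
            return list(embedder.embed(texts))
        except Exception:
            out: list[list[float] | None] = []
            for text in texts:
                try:
                    out.append(embedder.embed([text])[0])
                except Exception:
                    out.append(None)  # keep positions aligned with inputs
            return out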
diff --git a/src/memos/multi_mem_cube/composite_cube.py b/src/memos/multi_mem_cube/composite_cube.py
index 0d2d460e9..11c5245ba 100644
--- a/src/memos/multi_mem_cube/composite_cube.py
+++ b/src/memos/multi_mem_cube/composite_cube.py
@@ -6,6 +6,7 @@
 
 from memos.context.context import ContextThreadPoolExecutor
 from memos.multi_mem_cube.views import MemCubeView
+from memos.utils import timed_stage
 
 
 if TYPE_CHECKING:
@@ -27,13 +28,18 @@ class CompositeCubeView(MemCubeView):
 
     def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]:
         all_results: list[dict[str, Any]] = []
-
-        # fast mode: for each cube view, add memories
-        # maybe add more strategies in add_req.async_mode
-        for view in self.cube_views:
-            self.logger.info(f"[CompositeCubeView] fan-out add to cube={view.cube_id}")
-            results = view.add_memories(add_req)
-            all_results.extend(results)
+        cube_count = len(self.cube_views)
+
+        with timed_stage("add", "multi_cube", cube_count=cube_count):
+            for idx, view in enumerate(self.cube_views):
+                self.logger.info(
+                    "[CompositeCubeView] fan-out add to cube=%s (%d/%d)",
+                    view.cube_id,
+                    idx + 1,
+                    cube_count,
+                )
+                results = view.add_memories(add_req)
+                all_results.extend(results)
 
         return all_results
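For reference, the fan-out block above emits one rollup line per add call; `_emit` writes biz, stage, duration_ms, then the constructor fields, so a two-cube fan-out would log something like (duration illustrative)::

    [STAGE] biz=add stage=multi_cube duration_ms=412 cube_count=2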
type="chat", + info={ + **(add_req.info or {}), + "custom_tags": add_req.custom_tags, + "user_id": add_req.user_id, + "session_id": target_session_id, + }, + mode=extract_mode, + user_name=user_context.mem_cube_id, + chat_history=add_req.chat_history, + user_context=user_context, + ) + get_memory_ms = ts_gm.duration_ms flattened_local = [mm for m in memories_local for mm in m] # Explicitly set source_doc_id to metadata if present in info @@ -719,73 +719,105 @@ def _process_text_mem( for memory in flattened_local: memory.metadata.source_doc_id = source_doc_id - self.logger.info(f"Memory extraction completed for user {add_req.user_id}") - # Add memories to text_mem mem_group = [ memory for memory in flattened_local if memory.metadata.memory_type != "RawFileMemory" ] - mem_ids_local: list[str] = self.naive_mem_cube.text_mem.add( - mem_group, - user_name=user_context.mem_cube_id, - ) - self.logger.info( - f"Added {len(mem_ids_local)} memories for user {add_req.user_id} " - f"in session {add_req.session_id}: {mem_ids_local}" - ) - - # Add raw file nodes and edges - if self.mem_reader.save_rawfile and extract_mode == "fine": - raw_file_mem_group = [ - memory - for memory in flattened_local - if memory.metadata.memory_type == "RawFileMemory" - ] - self.naive_mem_cube.text_mem.add_rawfile_nodes_n_edges( - raw_file_mem_group, - mem_ids_local, - user_id=add_req.user_id, + # Stage 3: write_db + with timed_stage("add", "write_db", cube_id=self.cube_id) as ts_db: + mem_ids_local: list[str] = self.naive_mem_cube.text_mem.add( + mem_group, user_name=user_context.mem_cube_id, ) - # Schedule async/sync tasks: async process raw chunk memory | sync only send messages - self._schedule_memory_tasks( - add_req=add_req, - user_context=user_context, - mem_ids=mem_ids_local, - sync_mode=sync_mode, - ) + self.logger.info( + f"Added {len(mem_ids_local)} memories for user {add_req.user_id} " + f"in session {add_req.session_id}: {mem_ids_local}" + ) - # Mark merged_from memories as archived when provided in add_req.info - if sync_mode == "sync" and extract_mode == "fine": - for memory in flattened_local: - merged_from = (memory.metadata.info or {}).get("merged_from") - if merged_from: - old_ids = ( - merged_from - if isinstance(merged_from, (list | tuple | set)) - else [merged_from] - ) - if self.mem_reader and self.mem_reader.graph_db: - for old_id in old_ids: - try: - self.mem_reader.graph_db.update_node( - str(old_id), - {"status": "archived"}, - user_name=user_context.mem_cube_id, - ) - self.logger.info( - f"[SingleCubeView] Archived merged_from memory: {old_id}" - ) - except Exception as e: - self.logger.warning( - f"[SingleCubeView] Failed to archive merged_from memory {old_id}: {e}" - ) - else: - self.logger.warning( - "[SingleCubeView] merged_from provided but graph_db is unavailable; skip archiving." 
diff --git a/src/memos/utils.py b/src/memos/utils.py
index fd6d4eaf9..f7111f8ad 100644
--- a/src/memos/utils.py
+++ b/src/memos/utils.py
@@ -2,12 +2,138 @@
 import time
 import traceback
 
+from collections.abc import Callable
+from contextlib import ContextDecorator
+from typing import Any
+
 from memos.log import get_logger
 
 
 logger = get_logger(__name__)
 
 
+class timed_stage(ContextDecorator):  # noqa: N801
+    """Unified timing helper for business-stage instrumentation.
+
+    Works as **both** a context-manager and a decorator: one tool for all
+    timing needs.
+
+    Context-manager (when the stage is a *code block* inside a function)::
+
+        with timed_stage("add", "parse", cube_id=cube_id) as ts:
+            items = self._parse(...)
+            ts.set(msg_count=10, window_count=len(windows))
+
+    Decorator (when the stage is *an entire function*)::
+
+        @timed_stage("add", "write_db")
+        def _write_to_db(self, ...):
+            ...
+
+    Decorator with dynamic fields extracted from arguments::
+
+        @timed_stage("search", "recall",
+                     extra=lambda self, req, **kw: {"cube_id": self.cube_id})
+        def _vector_recall(self, req, ...):
+            ...
+
+    Output format (SLS-friendly, one-line structured log)::
+
+        [STAGE] biz=add stage=parse cube_id=xxx duration_ms=150 msg_count=10
+    """
+
+    def __init__(
+        self,
+        biz: str = "",
+        stage: str = "",
+        *,
+        extra: dict[str, Any] | Callable[..., dict[str, Any]] | None = None,
+        level: str = "info",
+        **fields: Any,
+    ):
+        self._biz = biz
+        self._stage = stage
+        self._extra_factory = extra if callable(extra) else None
+        self._static_extra = extra if isinstance(extra, dict) else None
+        self._level = level
+        self._fields: dict[str, Any] = dict(fields)
+        self._start: float = 0.0
+        self.duration_ms: int = 0
+
+    # -- context-manager protocol ------------------------------------------
+
+    def __enter__(self):
+        self._start = time.perf_counter()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.duration_ms = int((time.perf_counter() - self._start) * 1000)
+        self._emit(self.duration_ms, exc_type)
+        return False
+
+    # -- decorator protocol (extends ContextDecorator) ---------------------
+
+    def __call__(self, func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            if self._extra_factory is not None:
+                try:
+                    dynamic = self._extra_factory(*args, **kwargs)
+                    if dynamic:
+                        self._fields.update(dynamic)
+                except Exception as e:
+                    logger.warning("[STAGE] extra callback error: %r", e)
+
+            stage_name = self._stage or func.__name__
+            self._stage = stage_name
+            self._start = time.perf_counter()
+            try:
+                return func(*args, **kwargs)
+            finally:
+                self.duration_ms = int((time.perf_counter() - self._start) * 1000)
+                self._emit(self.duration_ms)
+
+        return wrapper
+
+    # -- public API --------------------------------------------------------
+
+    def set(self, **fields: Any):
+        """Add or overwrite fields mid-stage (e.g. counts only known after the block's work has run)."""
+        self._fields.update(fields)
+
+    @staticmethod
+    def emit_now(biz: str, stage: str, **fields: Any):
+        """Fire a one-shot structured log without timing (e.g. summary rollups)."""
+        parts = [f"biz={biz}", f"stage={stage}"]
+        for k, v in fields.items():
+            parts.append(f"{k}={v}")
+        logger.info("[STAGE] " + " ".join(parts))
+
+    # -- internals ---------------------------------------------------------
+
+    def _emit(self, duration_ms: int, exc_type=None):
+        parts: list[str] = []
+        if self._biz:
+            parts.append(f"biz={self._biz}")
+        if self._stage:
+            parts.append(f"stage={self._stage}")
+        parts.append(f"duration_ms={duration_ms}")
+
+        if self._static_extra:
+            self._fields.update(self._static_extra)
+
+        for k, v in self._fields.items():
+            parts.append(f"{k}={v}")
+
+        if exc_type is not None:
+            parts.append(f"error={exc_type.__name__}")
+
+        msg = "[STAGE] " + " ".join(parts)
+        getattr(logger, self._level, logger.info)(msg)
+
+
 def timed_with_status(
     func=None,
     *,
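Beyond the docstring examples, a self-contained way to see the emitted format end to end (assuming the memos logger propagates to the root handler; values approximate)::

    import logging
    import time

    from memos.utils import timed_stage

    logging.basicConfig(level=logging.INFO, format="%(message)s")

    # Context-manager form: fields can still be attached after the work runs.
    with timed_stage("demo", "load", cube_id="c1") as ts:
        time.sleep(0.05)
        ts.set(item_count=3)
    # -> [STAGE] biz=demo stage=load duration_ms=50 cube_id=c1 item_count=3

    # duration_ms stays readable after exit, so stages can feed a rollup.
    timed_stage.emit_now("demo", "summary", load_ms=ts.duration_ms)
    # -> [STAGE] biz=demo stage=summary load_ms=50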
diff --git a/tests/test_add_stage_logging.py b/tests/test_add_stage_logging.py
new file mode 100644
index 000000000..c41948493
--- /dev/null
+++ b/tests/test_add_stage_logging.py
@@ -0,0 +1,415 @@
+"""Integration tests for the add-memories call chain after timed_stage refactoring.
+
+Validates:
+    1. SingleCubeView._process_text_mem returns correct business results (regression).
+    2. Each stage emits a [STAGE] log with expected biz/stage/fields.
+    3. Summary rollup emits all aggregated fields.
+    4. CompositeCubeView.add_memories emits multi_cube stage for >1 cubes.
+    5. Exceptions in stages do not swallow errors or corrupt results.
+
+NOTE: SingleCubeView / CompositeCubeView are imported lazily inside fixtures
+to work around a known circular import in memos.api.handlers.__init__.
+"""
+
+import logging
+import uuid
+
+from dataclasses import dataclass
+from typing import Any
+from unittest.mock import MagicMock
+
+import pytest
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _stage_logs(caplog) -> list[str]:
+    return [r.message for r in caplog.records if r.message.startswith("[STAGE]")]
+
+
+def _make_add_req(**overrides):
+    from memos.api.product_models import APIADDRequest
+
+    defaults = {
+        "user_id": "test_user",
+        "messages": [
+            {"role": "user", "content": "remember this"},
+            {"role": "assistant", "content": "ok"},
+        ],
+    }
+    defaults.update(overrides)
+    return APIADDRequest(**defaults)
+
+
+def _make_memory_item(memory_text: str = "hello world"):
+    from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
+
+    return TextualMemoryItem(
+        id=str(uuid.uuid4()),
+        memory=memory_text,
+        metadata=TreeNodeTextualMemoryMetadata(
+            user_id="u1",
+            session_id="s1",
+            memory_type="WorkingMemory",
+            sources=[],
+            info={},
+        ),
+    )
+
+
+@dataclass
+class _FakeSingleCube:
+    """Minimal stub cube view that returns canned results for CompositeCubeView tests."""
+
+    cube_id: str
+    result: list[dict[str, Any]]
+
+    def add_memories(self, add_req):
+        return list(self.result)
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture()
+def single_cube_view():
+    """Build a SingleCubeView with fully-mocked dependencies."""
+    from memos.multi_mem_cube.single_cube import SingleCubeView
+
+    mem_item = _make_memory_item()
+
+    mock_mem_reader = MagicMock()
+    mock_mem_reader.get_memory.return_value = [[mem_item]]
+    mock_mem_reader.save_rawfile = False
+
+    mock_text_mem = MagicMock()
+    mock_text_mem.add.return_value = [mem_item.id]
+    mock_text_mem.mode = "async"
+
+    mock_naive_cube = MagicMock()
+    mock_naive_cube.text_mem = mock_text_mem
+
+    mock_scheduler = MagicMock()
+
+    view = SingleCubeView(
+        cube_id="cube_test",
+        naive_mem_cube=mock_naive_cube,
+        mem_reader=mock_mem_reader,
+        mem_scheduler=mock_scheduler,
+        logger=logging.getLogger("test.single_cube"),
+        searcher=None,
+        feedback_server=None,
+    )
+    return view, mem_item
+
+
+# ===========================================================================
+# SingleCubeView — async + fast (the most common path)
+# ===========================================================================
+
+
+class TestSingleCubeAddAsyncFast:
+    def test_returns_correct_business_result(self, single_cube_view):
+        view, mem_item = single_cube_view
+        add_req = _make_add_req(async_mode="async")
+
+        results = view.add_memories(add_req)
+
+        assert len(results) == 1
+        assert results[0]["memory_id"] == mem_item.id
+        assert results[0]["cube_id"] == "cube_test"
+        assert results[0]["memory_type"] == "WorkingMemory"
+
+    def test_mem_reader_called_with_fast_mode(self, single_cube_view):
+        view, _ = single_cube_view
+        add_req = _make_add_req(async_mode="async")
+
+        view.add_memories(add_req)
+
+        view.mem_reader.get_memory.assert_called_once()
+        call_kwargs = view.mem_reader.get_memory.call_args
+        assert call_kwargs.kwargs["mode"] == "fast"
+
+    def test_text_mem_add_called(self, single_cube_view):
+        view, _ = single_cube_view
+        add_req = _make_add_req(async_mode="async")
+
+        view.add_memories(add_req)
+
+        view.naive_mem_cube.text_mem.add.assert_called_once()
+
+    def test_scheduler_called(self, single_cube_view):
+        view, _ = single_cube_view
+        add_req = _make_add_req(async_mode="async")
+
+        view.add_memories(add_req)
+
+        view.mem_scheduler.submit_messages.assert_called_once()
+
+    def test_stage_logs_emitted(self, single_cube_view, caplog):
+        view, _ = single_cube_view
+        add_req = _make_add_req(async_mode="async")
+
+        with caplog.at_level(logging.INFO):
+            view.add_memories(add_req)
+
+        logs = _stage_logs(caplog)
+
+        stage_names = []
+        for log_line in logs:
+            for part in log_line.split():
+                if part.startswith("stage="):
+                    stage_names.append(part.split("=", 1)[1])
+
+        assert "get_memory" in stage_names
+        assert "write_db" in stage_names
+        assert "schedule" in stage_names
+        assert "summary" in stage_names
+
+    def test_summary_contains_all_fields(self, single_cube_view, caplog):
+        view, _ = single_cube_view
+        add_req = _make_add_req(async_mode="async")
+
+        with caplog.at_level(logging.INFO):
+            view.add_memories(add_req)
+
+        summary = [log for log in _stage_logs(caplog) if "stage=summary" in log]
+        assert len(summary) == 1
+        s = summary[0]
+        for field in [
+            "cube_id=",
+            "sync_mode=",
+            "extract_mode=",
+            "input_msg_count=",
+            "est_input_tokens=",
+            "memory_count=",
+            "get_memory_ms=",
+            "write_db_ms=",
+            "schedule_ms=",
+            "total_ms=",
+            "per_item_ms=",
+        ]:
+            assert field in s, f"Missing field '{field}' in summary: {s}"
+
+    def test_summary_values_are_consistent(self, single_cube_view, caplog):
+        view, _ = single_cube_view
+        add_req = _make_add_req(async_mode="async")
+
+        with caplog.at_level(logging.INFO):
+            view.add_memories(add_req)
+
+        summary = next(log for log in _stage_logs(caplog) if "stage=summary" in log)
+        fields = {}
+        for part in summary.split():
+            if "=" in part:
+                k, v = part.split("=", 1)
+                fields[k] = v
+
+        assert fields["sync_mode"] == "async"
+        assert fields["extract_mode"] == "fast"
+        assert fields["input_msg_count"] == "2"
+        assert fields["memory_count"] == "1"
+        assert int(fields["total_ms"]) >= 0
+        assert int(fields["get_memory_ms"]) >= 0
+        assert int(fields["write_db_ms"]) >= 0
+        assert int(fields["schedule_ms"]) >= 0
+
+    def test_write_db_reports_memory_count(self, single_cube_view, caplog):
+        view, _ = single_cube_view
+        add_req = _make_add_req(async_mode="async")
+
+        with caplog.at_level(logging.INFO):
+            view.add_memories(add_req)
+
+        write_db = [log for log in _stage_logs(caplog) if "stage=write_db" in log]
+        assert len(write_db) == 1
+        assert "memory_count=1" in write_db[0]
+
+    def test_get_memory_has_cube_id(self, single_cube_view, caplog):
+        view, _ = single_cube_view
+        add_req = _make_add_req(async_mode="async")
+
+        with caplog.at_level(logging.INFO):
+            view.add_memories(add_req)
+
+        gm = [log for log in _stage_logs(caplog) if "stage=get_memory" in log]
+        assert len(gm) == 1
+        assert "cube_id=cube_test" in gm[0]
+
+    def test_schedule_has_cube_id(self, single_cube_view, caplog):
+        view, _ = single_cube_view
+        add_req = _make_add_req(async_mode="async")
+
+        with caplog.at_level(logging.INFO):
+            view.add_memories(add_req)
+
+        sched = [log for log in _stage_logs(caplog) if "stage=schedule" in log]
+        assert len(sched) == 1
+        assert "cube_id=cube_test" in sched[0]
+
+
+# ===========================================================================
+# SingleCubeView — sync + fast
+# ===========================================================================
+
+
+class TestSingleCubeAddSyncFast:
+    def test_sync_fast_returns_result(self, single_cube_view):
+        view, mem_item = single_cube_view
+        add_req = _make_add_req(async_mode="sync", mode="fast")
+
+        results = view.add_memories(add_req)
+
+        assert len(results) == 1
+        assert results[0]["memory_id"] == mem_item.id
+
+    def test_sync_fast_summary_fields(self, single_cube_view, caplog):
+        view, _ = single_cube_view
+        add_req = _make_add_req(async_mode="sync", mode="fast")
+
+        with caplog.at_level(logging.INFO):
+            view.add_memories(add_req)
+
+        summary = [log for log in _stage_logs(caplog) if "stage=summary" in log]
+        assert len(summary) == 1
+        assert "sync_mode=sync" in summary[0]
+        assert "extract_mode=fast" in summary[0]
+
+
+# ===========================================================================
+# SingleCubeView — zero memories edge case
+# ===========================================================================
+
+
+class TestSingleCubeEdgeCases:
+    def test_zero_memories_does_not_crash(self, caplog):
+        """When get_memory returns no items, process should still complete."""
+        from memos.multi_mem_cube.single_cube import SingleCubeView
+
+        mock_mem_reader = MagicMock()
+        mock_mem_reader.get_memory.return_value = [[]]
+        mock_mem_reader.save_rawfile = False
+
+        mock_text_mem = MagicMock()
+        mock_text_mem.add.return_value = []
+        mock_text_mem.mode = "async"
+
+        mock_naive_cube = MagicMock()
+        mock_naive_cube.text_mem = mock_text_mem
+
+        view = SingleCubeView(
+            cube_id="cube_empty",
+            naive_mem_cube=mock_naive_cube,
+            mem_reader=mock_mem_reader,
+            mem_scheduler=MagicMock(),
+            logger=logging.getLogger("test.edge"),
+            searcher=None,
+        )
+
+        add_req = _make_add_req(async_mode="async")
+        with caplog.at_level(logging.INFO):
+            results = view.add_memories(add_req)
+
+        assert results == []
+        summary = [log for log in _stage_logs(caplog) if "stage=summary" in log]
+        assert len(summary) == 1
+        assert "memory_count=0" in summary[0]
+
+    def test_multiple_memories_returned(self, caplog):
+        """Multiple memory items should all appear in results."""
+        from memos.multi_mem_cube.single_cube import SingleCubeView
+
+        items = [_make_memory_item(f"mem_{i}") for i in range(3)]
+
+        mock_mem_reader = MagicMock()
+        mock_mem_reader.get_memory.return_value = [items]
+        mock_mem_reader.save_rawfile = False
+
+        mock_text_mem = MagicMock()
+        mock_text_mem.add.return_value = [it.id for it in items]
+        mock_text_mem.mode = "async"
+
+        mock_naive_cube = MagicMock()
+        mock_naive_cube.text_mem = mock_text_mem
+
+        view = SingleCubeView(
+            cube_id="cube_multi",
+            naive_mem_cube=mock_naive_cube,
+            mem_reader=mock_mem_reader,
+            mem_scheduler=MagicMock(),
+            logger=logging.getLogger("test.multi"),
+            searcher=None,
+        )
+
+        add_req = _make_add_req(async_mode="async")
+        with caplog.at_level(logging.INFO):
+            results = view.add_memories(add_req)
+
+        assert len(results) == 3
+        summary = [log for log in _stage_logs(caplog) if "stage=summary" in log]
+        assert "memory_count=3" in summary[0]
+
+
+# ===========================================================================
+# CompositeCubeView — multi_cube stage
+# ===========================================================================
+
+
+class TestCompositeCubeAdd:
+    def test_single_cube_emits_multi_cube_log(self, caplog):
+        from memos.multi_mem_cube.composite_cube import CompositeCubeView
+
+        fake = _FakeSingleCube(cube_id="c1", result=[{"m": 1}])
+        composite = CompositeCubeView(
+            cube_views=[fake],
+            logger=logging.getLogger("test.composite"),
+        )
+
+        add_req = _make_add_req()
+        with caplog.at_level(logging.INFO):
+            results = composite.add_memories(add_req)
+
+        assert len(results) == 1
+        multi = [log for log in _stage_logs(caplog) if "stage=multi_cube" in log]
+        assert len(multi) == 1
+        assert "cube_count=1" in multi[0]
+
+    def test_multi_cube_emits_stage_with_duration(self, caplog):
+        from memos.multi_mem_cube.composite_cube import CompositeCubeView
+
+        fake1 = _FakeSingleCube(cube_id="c1", result=[{"m": 1}])
+        fake2 = _FakeSingleCube(cube_id="c2", result=[{"m": 2}])
+        composite = CompositeCubeView(
+            cube_views=[fake1, fake2],
+            logger=logging.getLogger("test.composite"),
+        )
+
+        add_req = _make_add_req()
+        with caplog.at_level(logging.INFO):
+            results = composite.add_memories(add_req)
+
+        assert len(results) == 2
+        multi = [log for log in _stage_logs(caplog) if "stage=multi_cube" in log]
+        assert len(multi) == 1
+        assert "cube_count=2" in multi[0]
+        assert "duration_ms=" in multi[0]
+
+    def test_fan_out_results_aggregated(self):
+        from memos.multi_mem_cube.composite_cube import CompositeCubeView
+
+        fake1 = _FakeSingleCube(cube_id="c1", result=[{"a": 1}, {"a": 2}])
+        fake2 = _FakeSingleCube(cube_id="c2", result=[{"b": 3}])
+        composite = CompositeCubeView(
+            cube_views=[fake1, fake2],
+            logger=logging.getLogger("test.composite"),
+        )
+
+        add_req = _make_add_req()
+        results = composite.add_memories(add_req)
+
+        assert len(results) == 3
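The suites parse [STAGE] lines inline with `split()`; if more tests adopt the format, a small shared helper (hypothetical, not part of this diff) would keep the `k=v` handling in one place. Note it inherits the format's own limitation: field values must not contain spaces::

    def parse_stage_line(line: str) -> dict[str, str]:
        """Turn '[STAGE] biz=add stage=summary total_ms=12' into a dict."""
        assert line.startswith("[STAGE] ")
        fields: dict[str, str] = {}
        for token in line[len("[STAGE] "):].split():
            key, _, value = token.partition("=")
            fields[key] = value  # values stay strings; cast where asserted
        return fields

    # Usage in a test:
    #     summary = parse_stage_line(summary_log)
    #     assert summary["memory_count"] == "1"
    #     assert int(summary["total_ms"]) >= 0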
diff --git a/tests/test_utils_timing.py b/tests/test_utils_timing.py
new file mode 100644
index 000000000..b4d5cb989
--- /dev/null
+++ b/tests/test_utils_timing.py
@@ -0,0 +1,390 @@
+"""Tests for memos.utils timing utilities: timed_stage, timed, timed_with_status.
+
+Covers:
+    - timed_stage: context-manager, decorator, emit_now, duration_ms propagation,
+      set(), error logging, extra callback, static extra dict
+    - timed: regression (return value, threshold, log_prefix, log=False)
+    - timed_with_status: regression (success, failure+fallback, log_args, log_extra_args)
+"""
+
+import logging
+import time
+
+import pytest
+
+from memos.utils import timed, timed_stage, timed_with_status
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _collect_stage_logs(caplog):
+    """Return all log messages that start with '[STAGE]'."""
+    return [r.message for r in caplog.records if r.message.startswith("[STAGE]")]
+
+
+def _collect_timer_logs(caplog):
+    """Return all log messages that start with '[TIMER]'."""
+    return [r.message for r in caplog.records if r.message.startswith("[TIMER]")]
+
+
+def _collect_timer_with_status_logs(caplog):
+    """Return all log messages that start with '[TIMER_WITH_STATUS]'."""
+    return [r.message for r in caplog.records if r.message.startswith("[TIMER_WITH_STATUS]")]
+
+
+# ===========================================================================
+# timed_stage — context manager
+# ===========================================================================
+
+
+class TestTimedStageContextManager:
+    def test_basic_log_output(self, caplog):
+        with caplog.at_level(logging.INFO), timed_stage("add", "parse", cube_id="c1"):
+            pass
+        logs = _collect_stage_logs(caplog)
+        assert len(logs) == 1
+        assert "biz=add" in logs[0]
+        assert "stage=parse" in logs[0]
+        assert "cube_id=c1" in logs[0]
+        assert "duration_ms=" in logs[0]
+
+    def test_duration_ms_is_populated(self, caplog):
+        with caplog.at_level(logging.INFO), timed_stage("add", "embedding") as ts:
+            time.sleep(0.05)
+        assert ts.duration_ms >= 40  # at least ~50ms minus jitter
+
+    def test_set_adds_fields(self, caplog):
+        with caplog.at_level(logging.INFO), timed_stage("add", "write_db") as ts:
+            ts.set(memory_count=7)
+        logs = _collect_stage_logs(caplog)
+        assert "memory_count=7" in logs[0]
+
+    def test_set_overwrites_fields(self, caplog):
+        with caplog.at_level(logging.INFO), timed_stage("add", "write_db", memory_count=0) as ts:
+            ts.set(memory_count=5)
+        logs = _collect_stage_logs(caplog)
+        assert "memory_count=5" in logs[0]
+        assert "memory_count=0" not in logs[0]
+
+    def test_exception_logged_but_propagated(self, caplog):
+        with (
+            caplog.at_level(logging.INFO),
+            pytest.raises(ValueError, match="boom"),
+            timed_stage("add", "parse"),
+        ):
+            raise ValueError("boom")
+        logs = _collect_stage_logs(caplog)
+        assert len(logs) == 1
+        assert "error=ValueError" in logs[0]
+
+    def test_duration_ms_available_after_exception(self, caplog):
+        with caplog.at_level(logging.INFO):
+            ts = timed_stage("add", "parse")
+            with pytest.raises(RuntimeError), ts:
+                time.sleep(0.02)
+                raise RuntimeError("fail")
+        assert ts.duration_ms >= 15
+
+    def test_no_biz_no_stage(self, caplog):
+        """Empty biz/stage should not emit those fields."""
+        with caplog.at_level(logging.INFO), timed_stage(x=1):
+            pass
+        logs = _collect_stage_logs(caplog)
+        assert "biz=" not in logs[0]
+        assert "stage=" not in logs[0]
+        assert "x=1" in logs[0]
+
+    def test_duration_ms_readable_downstream(self):
+        """Downstream code can reference ts.duration_ms for summary rollup."""
+        with timed_stage("add", "get_memory") as ts:
+            time.sleep(0.01)
+        get_memory_ms = ts.duration_ms
+        assert isinstance(get_memory_ms, int)
+        assert get_memory_ms >= 5
+
+
+# ===========================================================================
+# timed_stage — decorator
+# ===========================================================================
+
+
+class TestTimedStageDecorator:
+    def test_decorator_basic(self, caplog):
+        @timed_stage("search", "recall")
+        def do_recall():
+            return [1, 2, 3]
+
+        with caplog.at_level(logging.INFO):
+            result = do_recall()
+        assert result == [1, 2, 3]
+        logs = _collect_stage_logs(caplog)
+        assert len(logs) == 1
+        assert "biz=search" in logs[0]
+        assert "stage=recall" in logs[0]
+
+    def test_decorator_uses_func_name_when_no_stage(self, caplog):
+        @timed_stage("add")
+        def my_custom_func():
+            pass
+
+        with caplog.at_level(logging.INFO):
+            my_custom_func()
+        logs = _collect_stage_logs(caplog)
+        assert "stage=my_custom_func" in logs[0]
+
+    def test_decorator_preserves_function_metadata(self):
+        @timed_stage("add", "test")
+        def documented_func():
+            """I have a docstring."""
+            return 42
+
+        assert documented_func.__name__ == "documented_func"
+        assert documented_func.__doc__ == "I have a docstring."
+
+    def test_decorator_with_extra_callback(self, caplog):
+        class Service:
+            cube_id = "cube_abc"
+
+            @timed_stage("add", "write_db", extra=lambda self: {"cube_id": self.cube_id})
+            def write(self):
+                return "ok"
+
+        svc = Service()
+        with caplog.at_level(logging.INFO):
+            result = svc.write()
+        assert result == "ok"
+        logs = _collect_stage_logs(caplog)
+        assert "cube_id=cube_abc" in logs[0]
+
+    def test_decorator_extra_callback_error_does_not_break(self, caplog):
+        @timed_stage("add", "risky", extra=lambda: (_ for _ in ()).throw(RuntimeError("oops")))
+        def risky_func():
+            return "still works"
+
+        with caplog.at_level(logging.WARNING):
+            result = risky_func()
+        assert result == "still works"
+
+    def test_decorator_exception_propagated(self, caplog):
+        @timed_stage("add", "crash")
+        def crasher():
+            raise TypeError("type error")
+
+        with caplog.at_level(logging.INFO), pytest.raises(TypeError, match="type error"):
+            crasher()
+        logs = _collect_stage_logs(caplog)
+        assert "error=TypeError" not in logs[0]  # decorator path doesn't pass exc_type
+
+
+# ===========================================================================
+# timed_stage.emit_now
+# ===========================================================================
+
+
+class TestTimedStageEmitNow:
+    def test_emit_now_basic(self, caplog):
+        with caplog.at_level(logging.INFO):
+            timed_stage.emit_now("add", "summary", total_ms=1200, per_item_ms=240)
+        logs = _collect_stage_logs(caplog)
+        assert len(logs) == 1
+        assert "biz=add" in logs[0]
+        assert "stage=summary" in logs[0]
+        assert "total_ms=1200" in logs[0]
+        assert "per_item_ms=240" in logs[0]
+        assert "duration_ms" not in logs[0]
+
+    def test_emit_now_no_extra_fields(self, caplog):
+        with caplog.at_level(logging.INFO):
+            timed_stage.emit_now("search", "summary")
+        logs = _collect_stage_logs(caplog)
+        assert logs[0] == "[STAGE] biz=search stage=summary"
+
+
+# ===========================================================================
+# timed_stage — static extra dict
+# ===========================================================================
+
+
+class TestTimedStageStaticExtra:
+    def test_static_extra_dict(self, caplog):
+        with caplog.at_level(logging.INFO), timed_stage("add", "parse", extra={"env": "prod"}):
+            pass
+        logs = _collect_stage_logs(caplog)
+        assert "env=prod" in logs[0]
+
+
+# ===========================================================================
+# timed — regression tests (original behavior must not change)
+# ===========================================================================
+
+
+class TestTimedRegression:
+    def test_return_value_preserved(self):
+        @timed
+        def add(a, b):
+            return a + b
+
+        assert add(1, 2) == 3
+
+    def test_no_log_below_threshold(self, caplog):
+        """@timed only logs when elapsed >= 100ms."""
+
+        @timed
+        def fast_func():
+            return "fast"
+
+        with caplog.at_level(logging.INFO):
+            result = fast_func()
+        assert result == "fast"
+        logs = _collect_timer_logs(caplog)
+        assert len(logs) == 0
+
+    def test_log_above_threshold(self, caplog):
+        @timed
+        def slow_func():
+            time.sleep(0.12)
+            return "slow"
+
+        with caplog.at_level(logging.INFO):
+            result = slow_func()
+        assert result == "slow"
+        logs = _collect_timer_logs(caplog)
+        assert len(logs) == 1
+        assert "slow_func" in logs[0]
+
+    def test_log_false_disables(self, caplog):
+        @timed(log=False)
+        def no_log_func():
+            time.sleep(0.12)
+            return 99
+
+        with caplog.at_level(logging.INFO):
+            result = no_log_func()
+        assert result == 99
+        logs = _collect_timer_logs(caplog)
+        assert len(logs) == 0
+
+    def test_log_prefix(self, caplog):
+        @timed(log_prefix="MY_PREFIX")
+        def prefixed():
+            time.sleep(0.12)
+            return True
+
+        with caplog.at_level(logging.INFO):
+            prefixed()
+        logs = _collect_timer_logs(caplog)
+        assert "MY_PREFIX" in logs[0]
+
+    def test_both_decorator_forms(self):
+        """@timed and @timed() should both work."""
+
+        @timed
+        def bare():
+            return 1
+
+        @timed()
+        def parens():
+            return 2
+
+        assert bare() == 1
+        assert parens() == 2
+
+
+# ===========================================================================
+# timed_with_status — regression tests
+# ===========================================================================
+
+
+class TestTimedWithStatusRegression:
+    def test_success_logging(self, caplog):
+        @timed_with_status
+        def ok_func():
+            return "hello"
+
+        with caplog.at_level(logging.INFO):
+            result = ok_func()
+        assert result == "hello"
+        logs = _collect_timer_with_status_logs(caplog)
+        assert len(logs) == 1
+        assert "status: SUCCESS" in logs[0]
+        assert "ok_func" in logs[0]
+
+    def test_failure_logging_no_fallback(self, caplog):
+        @timed_with_status
+        def fail_func():
+            raise RuntimeError("bad")
+
+        with caplog.at_level(logging.INFO):
+            fail_func()
+        logs = _collect_timer_with_status_logs(caplog)
+        assert len(logs) == 1
+        assert "status: FAILED" in logs[0]
+        assert "RuntimeError" in logs[0]
+
+    def test_failure_with_fallback(self, caplog):
+        @timed_with_status(fallback=lambda e, *a, **kw: "fallback_val")
+        def fail_func():
+            raise RuntimeError("bad")
+
+        with caplog.at_level(logging.INFO):
+            result = fail_func()
+        assert result == "fallback_val"
+        logs = _collect_timer_with_status_logs(caplog)
+        assert "status: FAILED" in logs[0]
+
+    def test_log_prefix(self, caplog):
+        @timed_with_status(log_prefix="CUSTOM")
+        def prefixed():
+            return 1
+
+        with caplog.at_level(logging.INFO):
+            prefixed()
+        logs = _collect_timer_with_status_logs(caplog)
+        assert "CUSTOM" in logs[0]
+
+    def test_log_args(self, caplog):
+        @timed_with_status(log_args=["user_id"])
+        def with_args(user_id="u1"):
+            return user_id
+
+        with caplog.at_level(logging.INFO):
+            with_args(user_id="u42")
+        logs = _collect_timer_with_status_logs(caplog)
+        assert "user_id=u42" in logs[0]
+
+    def test_log_extra_args_dict(self, caplog):
+        @timed_with_status(log_extra_args={"region": "us-west"})
+        def with_extra():
+            return True
+
+        with caplog.at_level(logging.INFO):
+            with_extra()
+        logs = _collect_timer_with_status_logs(caplog)
+        assert "region=us-west" in logs[0]
+
+    def test_log_extra_args_callable(self, caplog):
+        @timed_with_status(log_extra_args=lambda *a, **kw: {"dynamic": "yes"})
+        def with_dynamic():
+            return True
+
+        with caplog.at_level(logging.INFO):
+            with_dynamic()
+        logs = _collect_timer_with_status_logs(caplog)
+        assert "dynamic=yes" in logs[0]
+
+    def test_both_decorator_forms(self):
+        """@timed_with_status and @timed_with_status() should both work."""
+
+        @timed_with_status
+        def bare():
+            return 1
+
+        @timed_with_status()
+        def parens():
+            return 2
+
+        assert bare() == 1
+        assert parens() == 2