diff --git a/src/memos/api/config.py b/src/memos/api/config.py
index c68deae5a..69efedeb3 100644
--- a/src/memos/api/config.py
+++ b/src/memos/api/config.py
@@ -369,6 +369,27 @@ def get_memreader_config() -> dict[str, Any]:
 
         return {"backend": "openai", "config": config}
 
+    @staticmethod
+    def get_qwen_llm_config() -> dict[str, Any] | None:
+        """Get Qwen LLM configuration, or None when QWEN_API_KEY is unset."""
+        if not os.getenv("QWEN_API_KEY"):
+            return None
+        return {
+            "backend": "qwen",
+            "config": {
+                "model_name_or_path": os.getenv("QWEN_MODEL", "qwen-flash"),
+                "temperature": float(os.getenv("QWEN_TEMPERATURE", "0.8")),
+                "max_tokens": int(os.getenv("QWEN_MAX_TOKENS", "8000")),
+                "top_p": float(os.getenv("QWEN_TOP_P", "0.9")),
+                "top_k": int(os.getenv("QWEN_TOP_K", "50")),
+                "remove_think_prefix": os.getenv("QWEN_REMOVE_THINK_PREFIX", "true").lower()
+                == "true",
+                "api_key": os.getenv("QWEN_API_KEY", ""),
+                "api_base": os.getenv("QWEN_API_BASE", ""),
+                "model_schema": os.getenv("QWEN_MODEL_SCHEMA", "memos.configs.llm.QwenLLMConfig"),
+            },
+        }
+
     @staticmethod
     def get_memreader_general_llm_config() -> dict[str, Any]:
         """Get general LLM configuration for non-chat/doc tasks.
@@ -639,6 +660,7 @@ def get_oss_config() -> dict[str, Any] | None:
 
         return config
 
+    @staticmethod
     def get_internet_config() -> dict[str, Any]:
         """Get internet retriever configuration.
@@ -705,8 +727,9 @@ def get_internet_config() -> dict[str, Any]:
 
     @staticmethod
     def get_nli_config() -> dict[str, Any]:
-        """Get NLI model configuration."""
+        """Get relation-judge configuration for memory-version candidate matching."""
         return {
+            "provider": os.getenv("MEM_VERSION_RELATION_JUDGE_PROVIDER", "llm"),
             "base_url": os.getenv("NLI_MODEL_BASE_URL", "http://localhost:32532"),
         }
@@ -872,7 +895,7 @@ def get_scheduler_config() -> dict[str, Any]:
             ),
             "context_window_size": int(os.getenv("MOS_SCHEDULER_CONTEXT_WINDOW_SIZE", "5")),
             "thread_pool_max_workers": int(
-                os.getenv("MOS_SCHEDULER_THREAD_POOL_MAX_WORKERS", "200")
+                os.getenv("MOS_SCHEDULER_THREAD_POOL_MAX_WORKERS", "50")
             ),
             "consume_interval_seconds": float(
                 os.getenv("MOS_SCHEDULER_CONSUME_INTERVAL_SECONDS", "0.01")
@@ -952,6 +975,7 @@ def get_product_default_config() -> dict[str, Any]:
                 "backend": reader_config["backend"],
                 "config": {
                     "llm": APIConfig.get_memreader_config(),
+                    "qwen_llm": APIConfig.get_qwen_llm_config(),
                     # General LLM for non-chat/doc tasks (hallucination filter, rewrite, merge, etc.)
"general_llm": APIConfig.get_memreader_general_llm_config(), # Image parser LLM (requires vision model) @@ -986,6 +1009,7 @@ def get_product_default_config() -> dict[str, Any]: "SKILLS_LOCAL_DIR", "/tmp/upload_skill_memory/" ), }, + "memory_version_switch": os.getenv("MEM_READER_MEM_VERSION_SWITCH", "off"), }, }, "enable_textual_memory": True, diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index e9ed4f955..4240545f6 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -15,6 +15,7 @@ from memos.multi_mem_cube.composite_cube import CompositeCubeView from memos.multi_mem_cube.single_cube import SingleCubeView from memos.multi_mem_cube.views import MemCubeView +from memos.plugins.hooks import hookable from memos.types import MessageList @@ -37,6 +38,7 @@ def __init__(self, dependencies: HandlerDependencies): "naive_mem_cube", "mem_reader", "mem_scheduler", "feedback_server" ) + @hookable("add") def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse: """ Main handler for add memories endpoint. diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index a01fffef8..3536cee09 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -32,14 +32,14 @@ from memos.mem_scheduler.orm_modules.base_model import BaseDBManager from memos.mem_scheduler.scheduler_factory import SchedulerFactory from memos.memories.textual.simple_tree import SimpleTreeTextMemory -from memos.memories.textual.tree_text_memory.organize.history_manager import MemoryHistoryManager from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer +from memos.plugins.component_bootstrap import build_plugin_context +from memos.plugins.manager import plugin_manager if TYPE_CHECKING: from memos.memories.textual.tree import TreeTextMemory -from memos.extras.nli_model.client import NLIClient from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( InternetRetrieverFactory, @@ -122,6 +122,12 @@ def init_server() -> dict[str, Any]: existing code that uses the components. 
""" logger.info("Initializing MemOS server components...") + logger.info( + "[INIT_SERVER] env_MEMSCHEDULER_STREAM_KEY_PREFIX=%s, env_MEMSCHEDULER_REDIS_STREAM_KEY_PREFIX=%s, env_POLAR_DB_DB_NAME=%s", + os.getenv("MEMSCHEDULER_STREAM_KEY_PREFIX"), + os.getenv("MEMSCHEDULER_REDIS_STREAM_KEY_PREFIX"), + os.getenv("POLAR_DB_DB_NAME"), + ) # Initialize Redis client first as it is a core dependency for features like scheduler status tracking if os.getenv("MEMSCHEDULER_USE_REDIS_QUEUE", "False").lower() == "true": @@ -169,10 +175,25 @@ def init_server() -> dict[str, Any]: else None ) embedder = EmbedderFactory.from_config(embedder_config) - nli_client = NLIClient(base_url=nli_client_config["base_url"]) - memory_history_manager = MemoryHistoryManager(nli_client=nli_client, graph_db=graph_db) + + plugin_context = build_plugin_context( + graph_db=graph_db, + embedder=embedder, + default_cube_config=default_cube_config, + nli_client_config=nli_client_config, + mem_reader_config=mem_reader_config, + reranker_config=reranker_config, + feedback_reranker_config=feedback_reranker_config, + internet_retriever_config=internet_retriever_config, + ) + plugin_manager.discover() + plugin_manager.init_components(plugin_context) + # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection) - mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db) + mem_reader = MemReaderFactory.from_config( + mem_reader_config, + graph_db=graph_db, + ) reranker = RerankerFactory.from_config(reranker_config) feedback_reranker = RerankerFactory.from_config(feedback_reranker_config) internet_retriever = InternetRetrieverFactory.from_config( @@ -303,6 +324,4 @@ def init_server() -> dict[str, Any]: "feedback_server": feedback_server, "redis_client": redis_client, "deepsearch_agent": deepsearch_agent, - "nli_client": nli_client, - "memory_history_manager": memory_history_manager, } diff --git a/src/memos/api/handlers/formatters_handler.py b/src/memos/api/handlers/formatters_handler.py index ee88ae639..c8024baa3 100644 --- a/src/memos/api/handlers/formatters_handler.py +++ b/src/memos/api/handlers/formatters_handler.py @@ -141,6 +141,7 @@ def separate_knowledge_and_conversation_mem(memories: list[dict[str, Any]]): sources = item.get("metadata", {}).get("sources", []) if ( item["metadata"]["memory_type"] != "RawFileMemory" + and sources and len(sources) > 0 and "type" in sources[0] and sources[0]["type"] == "file" diff --git a/src/memos/api/server_api.py b/src/memos/api/server_api.py index 1f1e6ccde..a9afe554c 100644 --- a/src/memos/api/server_api.py +++ b/src/memos/api/server_api.py @@ -9,13 +9,21 @@ from memos.api.exceptions import APIExceptionHandler from memos.api.middleware.request_context import RequestContextMiddleware from memos.api.routers.server_router import router as server_router +from memos.plugins.manager import plugin_manager load_dotenv() +plugin_manager.discover() + # Configure logging logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) +logger.info( + "[SERVER_API] load_dotenv completed. 
env_MEMSCHEDULER_STREAM_KEY_PREFIX=%s, env_MEMSCHEDULER_REDIS_STREAM_KEY_PREFIX=%s", + os.getenv("MEMSCHEDULER_STREAM_KEY_PREFIX"), + os.getenv("MEMSCHEDULER_REDIS_STREAM_KEY_PREFIX"), +) app = FastAPI( title="MemOS Server REST APIs", @@ -49,6 +57,8 @@ def health_check(): # Fallback for unknown errors app.exception_handler(Exception)(APIExceptionHandler.global_exception_handler) +plugin_manager.init_app(app) + if __name__ == "__main__": import argparse diff --git a/src/memos/chunkers/sentence_chunker.py b/src/memos/chunkers/sentence_chunker.py index e695d0d9a..37e943151 100644 --- a/src/memos/chunkers/sentence_chunker.py +++ b/src/memos/chunkers/sentence_chunker.py @@ -53,4 +53,5 @@ def chunk(self, text: str) -> list[str] | list[Chunk]: chunks.append(chunk) logger.debug(f"Generated {len(chunks)} chunks from input text") + return chunks diff --git a/src/memos/configs/mem_reader.py b/src/memos/configs/mem_reader.py index d4844d73f..9ed791fa8 100644 --- a/src/memos/configs/mem_reader.py +++ b/src/memos/configs/mem_reader.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, ClassVar +from typing import Any, ClassVar, Literal from pydantic import ConfigDict, Field, field_validator, model_validator @@ -76,6 +76,13 @@ class MultiModalStructMemReaderConfig(BaseMemReaderConfig): default=None, description="Skills directory for the MemReader", ) + memory_version_switch: Literal["on", "off"] = Field( + default="off", + description="Turn on memory version or off", + ) + + # Allow passing additional fields without raising validation errors + model_config = ConfigDict(extra="allow", strict=True) class StrategyStructMemReaderConfig(BaseMemReaderConfig): diff --git a/src/memos/extras/nli_model/client.py b/src/memos/extras/nli_model/client.py deleted file mode 100644 index a02dae9f6..000000000 --- a/src/memos/extras/nli_model/client.py +++ /dev/null @@ -1,61 +0,0 @@ -import logging - -import requests - -from memos.extras.nli_model.types import NLIResult - - -logger = logging.getLogger(__name__) - - -class NLIClient: - """ - Client for interacting with the deployed NLI model service. - """ - - def __init__(self, base_url: str = "http://localhost:32532"): - self.base_url = base_url.rstrip("/") - self.session = requests.Session() - - def compare_one_to_many(self, source: str, targets: list[str]) -> list[NLIResult]: - """ - Compare one source text against multiple target memories using the NLI service. - - Args: - source: The new memory content. - targets: List of existing memory contents to compare against. - - Returns: - List of NLIResult corresponding to each target. - """ - if not targets: - return [] - - url = f"{self.base_url}/compare_one_to_many" - # Match schemas.CompareRequest - payload = {"source": source, "targets": targets} - - try: - response = self.session.post(url, json=payload, timeout=30) - response.raise_for_status() - data = response.json() - - # Match schemas.CompareResponse - results_str = data.get("results", []) - - results = [] - for res_str in results_str: - try: - results.append(NLIResult(res_str)) - except ValueError: - logger.warning( - f"[NLIClient] Unknown result: {res_str}, defaulting to UNRELATED" - ) - results.append(NLIResult.UNRELATED) - - return results - - except requests.RequestException as e: - logger.error(f"[NLIClient] Request failed: {e}") - # Fallback: if NLI fails, assume all are Unrelated to avoid blocking the flow. 
-            return [NLIResult.UNRELATED] * len(targets)
diff --git a/src/memos/extras/nli_model/server/README.md b/src/memos/extras/nli_model/server/README.md
deleted file mode 100644
index f6886e0e4..000000000
--- a/src/memos/extras/nli_model/server/README.md
+++ /dev/null
@@ -1,69 +0,0 @@
-# NLI Model Server
-
-This directory contains the standalone server for the Natural Language Inference (NLI) model used by MemOS.
-
-## Prerequisites
-
-- Python 3.10+
-- CUDA-capable GPU (Recommended for performance)
-- `torch` and `transformers` libraries (required for the server)
-
-## Running the Server
-
-You can run the server using the module syntax from the project root to ensure imports work correctly.
-
-### 1. Basic Start
-```bash
-python -m memos.extras.nli_model.server.serve
-```
-
-### 2. Configuration
-You can configure the server by editing config.py:
-
-- `HOST`: The host to bind to (default: `0.0.0.0`)
-- `PORT`: The port to bind to (default: `32532`)
-- `NLI_DEVICE`: The device to run the model on.
-  - `cuda` (Default, uses cuda:0 if available, else fallback to mps/cpu)
-  - `cuda:0` (Specific GPU)
-  - `mps` (Apple Silicon)
-  - `cpu` (CPU)
-
-## API Usage
-
-### Compare One to Many
-**POST** `/compare_one_to_many`
-
-**Request Body:**
-```json
-{
-  "source": "I just ate an apple.",
-  "targets": [
-    "I ate a fruit.",
-    "I hate apples.",
-    "The sky is blue."
-  ]
-}
-```
-
-## Testing
-
-An end-to-end example script is provided to verify the server's functionality. This script starts the server locally and runs a client request to verify the NLI logic.
-
-### End-to-End Test
-
-Run the example script from the project root:
-
-```bash
-python examples/extras/nli_e2e_example.py
-```
-
-**Response:**
-```json
-{
-  "results": [
-    "Duplicate",      // Entailment
-    "Contradiction",  // Contradiction
-    "Unrelated"       // Neutral
-  ]
-}
-```
diff --git a/src/memos/extras/nli_model/server/config.py b/src/memos/extras/nli_model/server/config.py
deleted file mode 100644
index d2e12175d..000000000
--- a/src/memos/extras/nli_model/server/config.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import logging
-
-
-NLI_MODEL_NAME = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"
-
-# Configuration
-# You can set the device directly here.
-# Examples:
-# - "cuda"   : Use default GPU (cuda:0) if available, else auto-fallback
-# - "cuda:0" : Use specific GPU
-# - "mps"    : Use Apple Silicon GPU (if available)
-# - "cpu"    : Use CPU
-NLI_DEVICE = "cuda"
-NLI_MODEL_HOST = "0.0.0.0"
-NLI_MODEL_PORT = 32532
-
-# Configure logging for NLI Server
-logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s | %(name)s - %(levelname)s - %(message)s",
-    handlers=[logging.StreamHandler(), logging.FileHandler("nli_server.log")],
-)
-logger = logging.getLogger("nli_server")
diff --git a/src/memos/extras/nli_model/server/handler.py b/src/memos/extras/nli_model/server/handler.py
deleted file mode 100644
index 3e98ddeb0..000000000
--- a/src/memos/extras/nli_model/server/handler.py
+++ /dev/null
@@ -1,186 +0,0 @@
-import re
-
-from memos.extras.nli_model.server.config import NLI_MODEL_NAME, logger
-from memos.extras.nli_model.types import NLIResult
-
-
-# Placeholder for lazy imports
-torch = None
-AutoModelForSequenceClassification = None
-AutoTokenizer = None
-
-
-def _map_label_to_result(raw: str) -> NLIResult:
-    t = raw.lower()
-    if "entail" in t:
-        return NLIResult.DUPLICATE
-    if "contrad" in t or "refut" in t:
-        return NLIResult.CONTRADICTION
-    # Neutral or unknown
-    return NLIResult.UNRELATED
-
-
-def _clean_temporal_markers(s: str) -> str:
-    # Remove temporal/aspect markers that might cause contradiction
-    # Chinese markers (simple replace is usually okay as they are characters)
-    zh_markers = ["刚刚", "曾经", "正在", "目前", "现在"]
-    for m in zh_markers:
-        s = s.replace(m, "")
-
-    # English markers (need word boundaries to avoid "snow" -> "s")
-    en_markers = ["just", "once", "currently", "now"]
-    pattern = r"\b(" + "|".join(en_markers) + r")\b"
-    s = re.sub(pattern, "", s, flags=re.IGNORECASE)
-
-    # Cleanup extra spaces
-    s = re.sub(r"\s+", " ", s).strip()
-    return s
-
-
-class NLIHandler:
-    """
-    NLI Model Handler for inference.
-    Requires `torch` and `transformers` to be installed.
-    """
-
-    def __init__(self, device: str = "cpu", use_fp16: bool = True, use_compile: bool = True):
-        global torch, AutoModelForSequenceClassification, AutoTokenizer
-        try:
-            import torch
-
-            from transformers import AutoModelForSequenceClassification, AutoTokenizer
-        except ImportError as e:
-            raise ImportError(
-                "NLIHandler requires 'torch' and 'transformers'. "
-                "Please install them via 'pip install torch transformers' or use the requirements.txt."
-            ) from e
-
-        self.device = self._resolve_device(device)
-        logger.info(f"Final resolved device: {self.device}")
-
-        # Set defaults based on device if not explicitly provided
-        is_cuda = "cuda" in self.device
-        if not is_cuda:
-            use_fp16 = False
-            use_compile = False
-
-        self.tokenizer = AutoTokenizer.from_pretrained(NLI_MODEL_NAME)
-
-        model_kwargs = {}
-        if use_fp16 and is_cuda:
-            model_kwargs["torch_dtype"] = torch.float16
-
-        self.model = AutoModelForSequenceClassification.from_pretrained(
-            NLI_MODEL_NAME, **model_kwargs
-        ).to(self.device)
-        self.model.eval()
-
-        self.id2label = {int(k): v for k, v in self.model.config.id2label.items()}
-        self.softmax = torch.nn.Softmax(dim=-1).to(self.device)
-
-        if use_compile and hasattr(torch, "compile"):
-            logger.info("Compiling model with torch.compile...")
-            self.model = torch.compile(self.model)
-
-    def _resolve_device(self, device: str) -> str:
-        d = device.strip().lower()
-
-        has_cuda = torch.cuda.is_available()
-        has_mps = torch.backends.mps.is_available() if hasattr(torch.backends, "mps") else False
-
-        if d == "cpu":
-            return "cpu"
-
-        if d.startswith("cuda"):
-            if has_cuda:
-                if d == "cuda":
-                    return "cuda:0"
-                return d
-
-            # Fallback if CUDA not available
-            if has_mps:
-                logger.warning(
-                    f"Device '{device}' requested but CUDA not available. Fallback to MPS."
-                )
-                return "mps"
-
-            logger.warning(
-                f"Device '{device}' requested but CUDA/MPS not available. Fallback to CPU."
-            )
-            return "cpu"
-
-        if d == "mps":
-            if has_mps:
-                return "mps"
-
-            logger.warning(f"Device '{device}' requested but MPS not available. Fallback to CPU.")
-            return "cpu"
-
-        # Fallback / Auto-detect for other cases (e.g. "gpu" or unknown)
-        if has_cuda:
-            return "cuda:0"
-        if has_mps:
-            return "mps"
-
-        return "cpu"
-
-    def predict_batch(self, premises: list[str], hypotheses: list[str]) -> list[NLIResult]:
-        # Clean inputs
-        premises = [_clean_temporal_markers(p) for p in premises]
-        hypotheses = [_clean_temporal_markers(h) for h in hypotheses]
-
-        # Batch tokenize with padding
-        inputs = self.tokenizer(
-            premises, hypotheses, return_tensors="pt", truncation=True, max_length=512, padding=True
-        ).to(self.device)
-        with torch.no_grad():
-            out = self.model(**inputs)
-            probs = self.softmax(out.logits)
-
-        results = []
-        for p in probs:
-            idx = int(torch.argmax(p).item())
-            res = self.id2label.get(idx, str(idx))
-            results.append(_map_label_to_result(res))
-        return results
-
-    def compare_one_to_many(self, source: str, targets: list[str]) -> list[NLIResult]:
-        """
-        Compare one source text against multiple target memories efficiently using batch processing.
-        Performs bidirectional checks (Source <-> Target) for each pair.
-        """
-        if not targets:
-            return []
-
-        n = len(targets)
-        # Construct batch:
-        # First n pairs: Source -> Target_i
-        # Next n pairs: Target_i -> Source
-        premises = [source] * n + targets
-        hypotheses = targets + [source] * n
-
-        # Run single large batch inference
-        raw_results = self.predict_batch(premises, hypotheses)
-
-        # Split results back
-        results_ab = raw_results[:n]
-        results_ba = raw_results[n:]
-
-        final_results = []
-        for i in range(n):
-            res_ab = results_ab[i]
-            res_ba = results_ba[i]
-
-            # 1. Any Contradiction -> Contradiction (Sensitive detection, filtered by LLM later)
-            if res_ab == NLIResult.CONTRADICTION or res_ba == NLIResult.CONTRADICTION:
-                final_results.append(NLIResult.CONTRADICTION)
-
-            # 2. Any Entailment -> Duplicate (as per user requirement)
-            elif res_ab == NLIResult.DUPLICATE or res_ba == NLIResult.DUPLICATE:
-                final_results.append(NLIResult.DUPLICATE)
-
-            # 3. Otherwise (Both Neutral) -> Unrelated
-            else:
-                final_results.append(NLIResult.UNRELATED)
-
-        return final_results
diff --git a/src/memos/extras/nli_model/server/serve.py b/src/memos/extras/nli_model/server/serve.py
deleted file mode 100644
index 0ed9eae65..000000000
--- a/src/memos/extras/nli_model/server/serve.py
+++ /dev/null
@@ -1,44 +0,0 @@
-from contextlib import asynccontextmanager
-
-import uvicorn
-
-from fastapi import FastAPI, HTTPException
-
-from memos.extras.nli_model.server.config import NLI_DEVICE, NLI_MODEL_HOST, NLI_MODEL_PORT
-from memos.extras.nli_model.server.handler import NLIHandler
-from memos.extras.nli_model.types import CompareRequest, CompareResponse
-
-
-# Global handler instance
-nli_handler: NLIHandler | None = None
-
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    global nli_handler
-    nli_handler = NLIHandler(device=NLI_DEVICE)
-    yield
-    # Clean up if needed
-    nli_handler = None
-
-
-app = FastAPI(lifespan=lifespan)
-
-
-@app.post("/compare_one_to_many", response_model=CompareResponse)
-async def compare_one_to_many(request: CompareRequest):
-    if nli_handler is None:
-        raise HTTPException(status_code=503, detail="Model not loaded")
-    try:
-        results = nli_handler.compare_one_to_many(request.source, request.targets)
-        return CompareResponse(results=results)
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e)) from e
-
-
-def start_server(host: str = "0.0.0.0", port: int = 32532):
-    uvicorn.run(app, host=host, port=port)
-
-
-if __name__ == "__main__":
-    start_server(host=NLI_MODEL_HOST, port=NLI_MODEL_PORT)
diff --git a/src/memos/extras/nli_model/types.py b/src/memos/extras/nli_model/types.py
deleted file mode 100644
index 619f8f508..000000000
--- a/src/memos/extras/nli_model/types.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from enum import Enum
-
-from pydantic import BaseModel
-
-
-class NLIResult(Enum):
-    DUPLICATE = "Duplicate"
-    CONTRADICTION = "Contradiction"
-    UNRELATED = "Unrelated"
-
-
-class CompareRequest(BaseModel):
-    source: str
-    targets: list[str]
-
-
-class CompareResponse(BaseModel):
-    results: list[NLIResult]
diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index abfae7710..856f94f2a 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -899,13 +899,7 @@ def get_node(
                 logger.info(
                     f" polardb [get_node] get_node completed time in {elapsed_time:.2f}s"
                 )
-                return self._parse_node(
-                    {
-                        "id": id,
-                        "memory": properties.get("memory", ""),
-                        **properties,
-                    }
-                )
+                return self._parse_node(properties)
 
             return None
         except Exception as e:
@@ -976,15 +970,7 @@ def get_nodes(self, ids: list[str], user_name: str, **kwargs) -> list[dict[str,
                     properties["embedding"] = embedding
             except (json.JSONDecodeError, TypeError):
                 logger.warning(f"Failed to parse embedding for node {node_id}")
-            nodes.append(
-                self._parse_node(
-                    {
-                        "id": properties.get("id", node_id),
-                        "memory": properties.get("memory", ""),
-                        "metadata": properties,
-                    }
-                )
-            )
+            nodes.append(self._parse_node(properties))
         return nodes
 
     @timed
@@ -1534,7 +1520,7 @@ def search_by_keywords_like(
         user_name_conditions = self._build_user_name_and_kb_ids_conditions_sql(
             user_name=user_name,
             knowledgebase_ids=knowledgebase_ids,
-            default_user_name=self.config.user_name,
+            default_user_name=self._get_config_value("user_name"),
         )
 
         # Add OR condition if we have any user_name conditions
@@ -1633,7 +1619,7 @@ def search_by_keywords_tfidf(
         user_name_conditions = self._build_user_name_and_kb_ids_conditions_sql(
             user_name=user_name,
             knowledgebase_ids=knowledgebase_ids,
-            default_user_name=self.config.user_name,
+            default_user_name=self._get_config_value("user_name"),
         )
 
         # Add OR condition if we have any user_name conditions
@@ -1751,7 +1737,7 @@ def search_by_fulltext(
         user_name_conditions = self._build_user_name_and_kb_ids_conditions_sql(
             user_name=user_name,
             knowledgebase_ids=knowledgebase_ids,
-            default_user_name=self.config.user_name,
+            default_user_name=self._get_config_value("user_name"),
         )
 
         if user_name_conditions:
@@ -1873,7 +1859,7 @@ def search_by_embedding(
         user_name_conditions = self._build_user_name_and_kb_ids_conditions_sql(
             user_name=user_name,
             knowledgebase_ids=knowledgebase_ids,
-            default_user_name=self.config.user_name,
+            default_user_name=self._get_config_value("user_name"),
         )
 
         if user_name_conditions:
diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py
index a57a40676..991541156 100644
--- a/src/memos/mem_feedback/feedback.py
+++ b/src/memos/mem_feedback/feedback.py
@@ -33,6 +33,8 @@
     extract_working_binding_ids,
 )
 from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import StopwordManager
+from memos.plugins.hook_defs import H
+from memos.plugins.hooks import trigger_single_hook
 
 
 if TYPE_CHECKING:
@@ -245,6 +247,7 @@ def _single_add_operation(
             datetime.now().isoformat()
         )
         to_add_memory.metadata.background = new_memory_item.metadata.background
+        to_add_memory.metadata.sources = new_memory_item.metadata.sources
 
         added_ids = self._retry_db_operation(
             lambda: self.memory_manager.add([to_add_memory], user_name=user_name, use_batch=False)
@@ -288,33 +291,43 @@ def _single_update_operation(
         new_memory_item.memory = operation["text"]
         new_memory_item.metadata.embedding = self._batch_embed([operation["text"]])[0]
 
-        if memory_type == "WorkingMemory":
-            fields = {
-                "memory": new_memory_item.memory,
-                "key": new_memory_item.metadata.key,
-                "tags": new_memory_item.metadata.tags,
-                "embedding": new_memory_item.metadata.embedding,
-                "background": new_memory_item.metadata.background,
-                "covered_history": old_memory_item.id,
-            }
-            self.graph_store.update_node(old_memory_item.id, fields=fields, user_name=user_name)
-            item_id = old_memory_item.id
-        else:
-            done = self._single_add_operation(
-                old_memory_item, new_memory_item, user_id, user_name, async_mode
+        if getattr(self.mem_reader, "memory_version_switch", "off") != "on":
+            if memory_type == "WorkingMemory":
+                fields = {
+                    "memory": new_memory_item.memory,
+                    "key": new_memory_item.metadata.key,
+                    "tags": new_memory_item.metadata.tags,
+                    "embedding": new_memory_item.metadata.embedding,
+                    "background": new_memory_item.metadata.background,
+                    "covered_history": old_memory_item.id,
+                }
+                self.graph_store.update_node(old_memory_item.id, fields=fields, user_name=user_name)
+                item_id = old_memory_item.id
+            else:
+                done = self._single_add_operation(
+                    old_memory_item, new_memory_item, user_id, user_name, async_mode
+                )
+                item_id = done.get("id")
+                self.graph_store.update_node(
+                    item_id, {"covered_history": old_memory_item.id}, user_name=user_name
+                )
+                self.graph_store.update_node(
+                    old_memory_item.id, {"status": "archived"}, user_name=user_name
+                )
+
+            logger.info(
+                f"[Memory Feedback UPDATE] New Add:{item_id} | Set archived:{old_memory_item.id} | memory_type: {memory_type}"
             )
-            item_id = done.get("id")
-            self.graph_store.update_node(
item_id, {"covered_history": old_memory_item.id}, user_name=user_name + else: + item_id = self._single_update_operation_with_versions( + old_memory_item=old_memory_item, + new_memory_item=new_memory_item, + user_name=user_name, ) - self.graph_store.update_node( - old_memory_item.id, {"status": "archived"}, user_name=user_name + logger.info( + f"[Memory Feedback UPDATE] Updated:{item_id} | history appended | memory_type: {old_memory_item.metadata.memory_type}" ) - logger.info( - f"[Memory Feedback UPDATE] New Add:{item_id} | Set archived:{old_memory_item.id} | memory_type: {memory_type}" - ) - return { "id": item_id, "text": new_memory_item.memory, @@ -323,6 +336,78 @@ def _single_update_operation( "origin_memory": old_memory_item.memory, } + def _single_update_operation_with_versions( + self, + old_memory_item: TextualMemoryItem, + new_memory_item: TextualMemoryItem, + user_name: str, + ) -> str: + try: + updated_item, archived_item, archived_metadata, updated_fields = trigger_single_hook( + H.MEMORY_VERSION_APPLY_FEEDBACK_UPDATE, + old_item=old_memory_item, + new_item=new_memory_item, + user_name=user_name, + ) + except Exception as e: + logger.warning( + "[Memory Feedback UPDATE] history fallback for %s: %s", old_memory_item.id, e + ) + updated_item = old_memory_item.model_copy(deep=True) + updated_item.memory = new_memory_item.memory + updated_item.metadata.key = new_memory_item.metadata.key + updated_item.metadata.tags = new_memory_item.metadata.tags + updated_item.metadata.background = new_memory_item.metadata.background + if getattr(new_memory_item.metadata, "sources", None) is not None: + current_sources = list(updated_item.metadata.sources or []) + updated_item.metadata.sources = ( + list(new_memory_item.metadata.sources or []) + current_sources + ) + if getattr(new_memory_item.metadata, "embedding", None) is not None: + updated_item.metadata.embedding = new_memory_item.metadata.embedding + if updated_item.metadata.memory_type == "PreferenceMemory": + updated_item.metadata.preference = updated_item.memory + updated_fields = { + "memory": updated_item.memory, + "key": updated_item.metadata.key, + "tags": updated_item.metadata.tags, + "embedding": updated_item.metadata.embedding, + "background": updated_item.metadata.background, + "sources": [ + source.model_dump(exclude_none=True) + if hasattr(source, "model_dump") + else source + for source in (updated_item.metadata.sources or []) + ], + "covered_history": old_memory_item.id, + } + archived_item = None + archived_metadata = None + + if archived_item and archived_metadata: + try: + self.graph_store.add_node( + id=archived_item.id, + memory=archived_item.memory, + metadata=archived_metadata, + user_name=user_name, + ) + except Exception as e: + logger.warning( + "[Memory Feedback UPDATE] archive add failed for %s: %s", + old_memory_item.id, + e, + ) + self._retry_db_operation( + lambda: self.graph_store.update_node( + id=updated_item.id, + fields=updated_fields, + user_name=user_name, + ) + ) + self._del_working_binding(user_name, [old_memory_item]) + return updated_item.id + def _del_working_binding(self, user_name, mem_items: list[TextualMemoryItem]) -> set[str]: """Delete working memory bindings""" bindings_to_delete = extract_working_binding_ids(mem_items) @@ -331,9 +416,7 @@ def _del_working_binding(self, user_name, mem_items: list[TextualMemoryItem]) -> f"[Memory Feedback UPDATE] Extracted {len(bindings_to_delete)} working_binding ids to cleanup: {list(bindings_to_delete)}" ) - delete_ids = [] - if bindings_to_delete: - 
delete_ids = list({bindings_to_delete}) + delete_ids = list(bindings_to_delete) for mid in delete_ids: try: @@ -346,6 +429,7 @@ def _del_working_binding(self, user_name, mem_items: list[TextualMemoryItem]) -> logger.warning( f"[0107 Feedback Core:_del_working_binding] TreeTextMemory.delete_hard: failed to delete {mid}: {e}" ) + return bindings_to_delete def semantics_feedback( self, @@ -476,9 +560,22 @@ def semantics_feedback( f"[0107 Feedback Core: semantics_feedback] Operation failed for {original_op}: {e}", exc_info=True, ) - if update_results: - updated_ids = [item["archived_id"] for item in update_results] - self._del_working_binding(updated_ids, user_name) + if update_results and getattr(self.mem_reader, "memory_version_switch", "off") != "on": + archived_ids = [item["archived_id"] for item in update_results] + archived_items = [] + for aid in archived_ids: + try: + node = self.graph_store.get_node(aid, user_name=user_name) + if node: + archived_items.append(TextualMemoryItem(**node)) + except Exception as e: + logger.warning( + "[Memory Feedback] Failed to fetch archived item %s for working_binding cleanup: %s", + aid, + e, + ) + if archived_items: + self._del_working_binding(user_name, archived_items) return {"record": {"add": add_results, "update": update_results}} @@ -1066,7 +1163,14 @@ def check_validity(item): tags=tags, key=key, embedding=embedding, - sources=[{"type": "chat"}], + sources=[ + { + "type": "feedback", + "role": "user", + "chat_time": feedback_time, + "content": feedback_content, + } + ], background=background, type="fine", info=info, diff --git a/src/memos/mem_os/utils/format_utils.py b/src/memos/mem_os/utils/format_utils.py index f6e33bb31..3991d0667 100644 --- a/src/memos/mem_os/utils/format_utils.py +++ b/src/memos/mem_os/utils/format_utils.py @@ -1399,5 +1399,18 @@ def clean_json_response(response: str) -> str: Returns: str: Clean JSON string without markdown formatting + + Raises: + ValueError: If ``response`` is None. This is almost always an upstream + failure (LLM call returned None instead of raising) and surfacing + it here is much easier to diagnose than the implicit + ``AttributeError: 'NoneType' object has no attribute 'replace'`` + that would otherwise be thrown. """ + if response is None: + raise ValueError( + "clean_json_response received None — upstream LLM call likely " + "failed silently (check timed_with_status / generate() error " + "handling)." 
+        )
     return response.replace("```json", "").replace("```", "").strip()
diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py
index 092f29ac6..26287ff4f 100644
--- a/src/memos/mem_reader/multi_modal_struct.py
+++ b/src/memos/mem_reader/multi_modal_struct.py
@@ -2,6 +2,7 @@
 import json
 import re
 import traceback
+import uuid
 
 from typing import TYPE_CHECKING, Any
 
@@ -15,10 +16,12 @@
 from memos.mem_reader.simple_struct import PROMPT_DICT, SimpleStructMemReader
 from memos.mem_reader.utils import parse_json_result
 from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
+from memos.plugins.hook_defs import H
+from memos.plugins.hooks import trigger_single_hook
 from memos.templates.mem_reader_prompts import MEMORY_MERGE_PROMPT_EN, MEMORY_MERGE_PROMPT_ZH
 from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN, TOOL_TRAJECTORY_PROMPT_ZH
 from memos.types import MessagesType
-from memos.utils import timed
+from memos.utils import timed, timed_stage
 
 
 if TYPE_CHECKING:
@@ -58,6 +61,8 @@ def __init__(self, config: MultiModalStructMemReaderConfig):
         simple_config = SimpleStructMemReaderConfig(**config_dict)
         super().__init__(simple_config)
 
+        self.memory_version_switch = getattr(config, "memory_version_switch", "off")
+
         # Image parser LLM (requires vision model)
         # Falls back to general_llm if not configured (general_llm itself falls back to main llm)
         self.image_parser_llm = (
@@ -75,6 +80,30 @@ def __init__(self, config: MultiModalStructMemReaderConfig):
             direct_markdown_hostnames=direct_markdown_hostnames,
         )
 
+    def _embed_memory_items(self, items: list[TextualMemoryItem]) -> None:
+        """Compute embeddings for a list of memory items in-place.
+
+        Attempts a single batch call first; falls back to per-item calls if the
+        batch fails. Errors are logged but never raised so callers always
+        continue normally.
+        """
+        valid = [w for w in items if w and w.memory]
+        if not valid:
+            return
+        texts = [w.memory for w in valid]
+        try:
+            embeddings = self.embedder.embed(texts)
+            for w, emb in zip(valid, embeddings, strict=True):
+                w.metadata.embedding = emb
+        except Exception as e:
+            logger.error(f"[MultiModalStruct] Error batch computing embeddings: {e}")
+            logger.warning("[EMBED_FALLBACK] batch_size=%d", len(texts))
+            for w in valid:
+                try:
+                    w.metadata.embedding = self.embedder.embed([w.memory])[0]
+                except Exception as e2:
+                    logger.error(f"[MultiModalStruct] Error computing embedding for item: {e2}")
+
     def _split_large_memory_item(
         self, item: TextualMemoryItem, max_tokens: int
     ) -> list[TextualMemoryItem]:
@@ -100,20 +129,32 @@ def _split_large_memory_item(
         try:
             chunks = self.chunker.chunk(item_text)
             split_items = []
-
-            def _create_chunk_item(chunk):
-                # Chunk objects have a 'text' attribute
-                chunk_text = chunk.text
+            source_info = dict(item.metadata.info or {})
+            source_internal_info = dict(item.metadata.internal_info or {})
+            ingest_batch_id = str(source_internal_info.get("ingest_batch_id") or uuid.uuid4())
+            chunk_total = len(chunks)
+
+            def _create_chunk_item(chunk_idx: int, chunk):
+                # Different chunkers are not fully consistent:
+                # some return Chunk-like objects with `.text`, while others return raw strings.
+                chunk_text = chunk.text if hasattr(chunk, "text") else chunk
                 if not chunk_text or not chunk_text.strip():
                     return None
+                chunk_info = {
+                    "user_id": item.metadata.user_id,
+                    "session_id": item.metadata.session_id,
+                    **source_info,
+                }
+                chunk_internal_info = {
+                    **source_internal_info,
+                    "ingest_batch_id": ingest_batch_id,
+                    "chunk_index": chunk_idx,
+                    "chunk_total": chunk_total,
+                }
                 # Create a new memory item for each chunk, preserving original metadata
                 split_item = self._make_memory_item(
                     value=chunk_text,
-                    info={
-                        "user_id": item.metadata.user_id,
-                        "session_id": item.metadata.session_id,
-                        **(item.metadata.info or {}),
-                    },
+                    info=chunk_info,
                     memory_type=item.metadata.memory_type,
                     tags=item.metadata.tags or [],
                     key=item.metadata.key,
@@ -121,11 +162,15 @@ def _create_chunk_item(chunk):
                     background=item.metadata.background or "",
                     need_embed=False,
                 )
+                split_item.metadata.internal_info = chunk_internal_info
                 return split_item
 
             # Use thread pool to parallel process chunks, but keep the original order
             with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
-                futures = [executor.submit(_create_chunk_item, chunk) for chunk in chunks]
+                futures = [
+                    executor.submit(_create_chunk_item, chunk_idx, chunk)
+                    for chunk_idx, chunk in enumerate(chunks)
+                ]
                 for future in futures:
                     split_item = future.result()
                     if split_item is not None:
@@ -203,13 +248,8 @@ def _concat_multi_modal_memories(
         # If only one item after processing, compute embedding and return
         if len(processed_items) == 1:
             single_item = processed_items[0]
-            if single_item and single_item.memory:
-                try:
-                    single_item.metadata.embedding = self.embedder.embed([single_item.memory])[0]
-                except Exception as e:
-                    logger.error(
-                        f"[MultiModalStruct] Error computing embedding for single item: {e}"
-                    )
+            with timed_stage("add", "embedding", window_count=1):
+                self._embed_memory_items([single_item])
             return processed_items
 
         windows = []
@@ -260,31 +300,8 @@ def _concat_multi_modal_memories(
             windows.append(window)
 
         # Batch compute embeddings for all windows
-        if windows:
-            # Collect all valid windows that need embedding
-            valid_windows = [w for w in windows if w and w.memory]
-
-            if valid_windows:
-                # Collect all texts that need embedding
-                texts_to_embed = [w.memory for w in valid_windows]
-
-                # Batch compute all embeddings at once
-                try:
-                    embeddings = self.embedder.embed(texts_to_embed)
-                    # Fill embeddings back into memory items
-                    for window, embedding in zip(valid_windows, embeddings, strict=True):
-                        window.metadata.embedding = embedding
-                except Exception as e:
-                    logger.error(f"[MultiModalStruct] Error batch computing embeddings: {e}")
-                    # Fallback: compute embeddings individually
-                    for window in valid_windows:
-                        if window.memory:
-                            try:
-                                window.metadata.embedding = self.embedder.embed([window.memory])[0]
-                            except Exception as e2:
-                                logger.error(
-                                    f"[MultiModalStruct] Error computing embedding for item: {e2}"
-                                )
+        with timed_stage("add", "embedding", window_count=len(windows)):
+            self._embed_memory_items(windows)
 
         return windows
 
@@ -309,6 +326,7 @@ def _build_window_from_items(
         all_sources = []
         roles = set()
         aggregated_file_ids: list[str] = []
+        ingest_batch_ids: set[str] = set()
 
         for item in items:
             if item.memory:
@@ -337,6 +355,11 @@ def _build_window_from_items(
                 for fid in item_file_ids:
                     if fid and fid not in aggregated_file_ids:
                         aggregated_file_ids.append(fid)
+            item_internal_info = getattr(metadata, "internal_info", None)
+            if isinstance(item_internal_info, dict):
+                batch_id = item_internal_info.get("ingest_batch_id")
+                if batch_id:
+                    ingest_batch_ids.add(str(batch_id))
 
         # Determine memory_type based on roles (same logic as simple_struct)
         # UserMemory if only user role, else LongTermMemory
@@ -371,7 +394,6 @@ def _build_window_from_items(
         info_ = info.copy()
         user_id = info_.pop("user_id", "")
         session_id = info_.pop("session_id", "")
-
         # Create memory item without embedding (set to None, will be filled in batch)
         aggregated_item = TextualMemoryItem(
             memory=merged_text,
@@ -392,6 +414,10 @@ def _build_window_from_items(
                 **extra_kwargs,
             ),
         )
+        if len(ingest_batch_ids) == 1:
+            aggregated_item.metadata.internal_info = {
+                "ingest_batch_id": next(iter(ingest_batch_ids))
+            }
 
         return aggregated_item
 
@@ -461,6 +487,9 @@ def _get_llm_response(
             if self.config.remove_prompt_example and examples:
                 prompt = prompt.replace(examples, "")
+
+        logger.info(f"[MultiModalParser] Process String Fine Prompt: {prompt}")
+
         messages = [{"role": "user", "content": prompt}]
         try:
             response_text = self.llm.generate(messages)
@@ -509,6 +538,7 @@ def _get_maybe_merged_memory(
         sources: list,
         **kwargs,
     ) -> dict:
+        # TODO: delete this function
         """
         Check if extracted memory should be merged with similar existing memories.
        If merge is needed, return merged memory dict with merged_from field.
@@ -523,102 +553,7 @@ def _get_maybe_merged_memory(
         Returns:
             Memory dict (possibly merged) with merged_from field if merged
         """
-        # If no graph_db or user_name, return original
-        if not self.graph_db or "user_name" not in kwargs:
-            return extracted_memory_dict
-        user_name = kwargs.get("user_name")
-
-        # Detect language
-        lang = "en"
-        if sources:
-            for source in sources:
-                if hasattr(source, "lang") and source.lang:
-                    lang = source.lang
-                    break
-                elif isinstance(source, dict) and source.get("lang"):
-                    lang = source.get("lang")
-                    break
-        if lang is None:
-            lang = detect_lang(mem_text)
-
-        # Search for similar memories
-        merge_threshold = kwargs.get("merge_similarity_threshold", 0.3)
-
-        try:
-            search_results = self.graph_db.search_by_embedding(
-                vector=self.embedder.embed(mem_text)[0],
-                top_k=20,
-                status="activated",
-                threshold=merge_threshold,
-                user_name=user_name,
-            )
-
-            if not search_results:
-                return extracted_memory_dict
-
-            # Get full memory details
-            similar_memory_ids = [r["id"] for r in search_results if r.get("id")]
-            similar_memories_list = [
-                self.graph_db.get_node(mem_id, include_embedding=False, user_name=user_name)
-                for mem_id in similar_memory_ids
-            ]
-
-            # Filter out None and mode:fast memories
-            filtered_similar = []
-            for mem in similar_memories_list:
-                if not mem:
-                    continue
-                mem_metadata = mem.get("metadata", {})
-                tags = mem_metadata.get("tags", [])
-                if isinstance(tags, list) and "mode:fast" in tags:
-                    continue
-                filtered_similar.append(
-                    {
-                        "id": mem.get("id"),
-                        "memory": mem.get("memory", ""),
-                    }
-                )
-            logger.info(
-                f"Valid similar memories for {mem_text} is "
-                f"{len(filtered_similar)}: {filtered_similar}"
-            )
-
-            if not filtered_similar:
-                return extracted_memory_dict
-
-            # Create a temporary TextualMemoryItem for merge check
-            temp_memory_item = TextualMemoryItem(
-                memory=mem_text,
-                metadata=TreeNodeTextualMemoryMetadata(
-                    user_id="",
-                    session_id="",
-                    memory_type=extracted_memory_dict.get("memory_type", "LongTermMemory"),
-                    status="activated",
-                    tags=extracted_memory_dict.get("tags", []),
-                    key=extracted_memory_dict.get("key", ""),
-                ),
-            )
-
-            # Try to merge with LLM
-            merge_result = self._merge_memories_with_llm(
-                temp_memory_item, filtered_similar, lang=lang
-            )
-
-            if merge_result:
-                # Return merged memory dict
-                merged_dict = extracted_memory_dict.copy()
-                merged_content = merge_result.get("value", mem_text)
-                merged_dict["value"] = merged_content
-                merged_from_ids = merge_result.get("merged_from", [])
-                merged_dict["merged_from"] = merged_from_ids
-                return merged_dict
-            else:
-                return extracted_memory_dict
-
-        except Exception as e:
-            logger.error(f"[MultiModalFine] Error in get_maybe_merged_memory: {e}")
-            # On error, return original
-            return extracted_memory_dict
+        return extracted_memory_dict
 
     def _merge_memories_with_llm(
         self,
@@ -720,6 +655,35 @@ def _process_one_item(
         # Determine prompt type based on sources
         prompt_type = self._determine_prompt_type(sources)
 
+        # ========== Stage 0: Memory version async extraction/update pipeline ==========
+        if getattr(self, "memory_version_switch", "off") == "on":
+            try:
+                user_name = kwargs.get("user_name")
+                should_use_version_pipeline = trigger_single_hook(
+                    H.MEMORY_VERSION_PREPARE_UPDATES,
+                    item=fast_item,
+                    user_name=user_name,
+                    judge_llm=self.general_llm,
+                )
+                if should_use_version_pipeline:
+                    lang = detect_lang(kwargs.get("chat_history") or mem_str)
+                    custom_tags_prompt_template = PROMPT_DICT["custom_tags"][lang]
+                    new_items = trigger_single_hook(
+                        H.MEMORY_VERSION_APPLY_UPDATES,
+                        item=fast_item,
+                        user_name=user_name,
+                        version_llm=self.qwen_llm,
+                        merge_llm=self.general_llm,
+                        custom_tags=custom_tags,
+                        custom_tags_prompt_template=custom_tags_prompt_template,
+                        timeout_sec=30,
+                    )
+                    return new_items
+            except RuntimeError as ex:
+                logger.warning(f"[MultiModalFine] Memory version hook unavailable: {ex}")
+            except Exception as ex:
+                logger.warning(f"[MultiModalFine] Fine memory version pipeline failed: {ex}")
+
         # ========== Stage 1: Normal extraction (without reference) ==========
         try:
             resp = self._get_llm_response(mem_str, custom_tags, sources, prompt_type)
@@ -730,14 +694,15 @@ def _process_one_item(
         if resp.get("memory list", []):
             for m in resp.get("memory list", []):
                 try:
-                    # Check and merge with similar memories if needed
-                    m_maybe_merged = self._get_maybe_merged_memory(
-                        extracted_memory_dict=m,
-                        mem_text=m.get("value", ""),
-                        sources=sources,
-                        original_query=mem_str,
-                        **kwargs,
-                    )
+                    m_maybe_merged = m
+                    if getattr(self, "memory_version_switch", "off") != "on":
+                        m_maybe_merged = self._get_maybe_merged_memory(
+                            extracted_memory_dict=m,
+                            mem_text=m.get("value", ""),
+                            sources=sources,
+                            original_query=mem_str,
+                            **kwargs,
+                        )
                     # Normalize memory_type (same as simple_struct)
                     memory_type = (
                         m_maybe_merged.get("memory_type", "LongTermMemory")
@@ -755,8 +720,10 @@ def _process_one_item(
                         background=resp.get("summary", ""),
                         **extra_kwargs,
                     )
-                    # Add merged_from to info if present
-                    if "merged_from" in m_maybe_merged:
+                    if (
+                        getattr(self, "memory_version_switch", "off") != "on"
+                        and "merged_from" in m_maybe_merged
+                    ):
                         node.metadata.info = node.metadata.info or {}
                         node.metadata.info["merged_from"] = m_maybe_merged["merged_from"]
                     fine_items.append(node)
@@ -765,13 +732,15 @@ def _process_one_item(
         elif resp.get("value") and resp.get("key"):
             try:
                 # Check and merge with similar memories if needed
-                resp_maybe_merged = self._get_maybe_merged_memory(
-                    extracted_memory_dict=resp,
-                    mem_text=resp.get("value", "").strip(),
-                    sources=sources,
-                    original_query=mem_str,
-                    **kwargs,
-                )
+                resp_maybe_merged = resp
+                if getattr(self, "memory_version_switch", "off") != "on":
+                    resp_maybe_merged = self._get_maybe_merged_memory(
+                        extracted_memory_dict=resp,
+                        mem_text=resp.get("value", "").strip(),
+                        sources=sources,
+                        original_query=mem_str,
+                        **kwargs,
+                    )
                 node = self._make_memory_item(
                     value=resp_maybe_merged.get("value", "").strip(),
                     info=info_per_item,
@@ -782,8 +751,10 @@ def _process_one_item(
                     background=resp.get("summary", ""),
                     **extra_kwargs,
                 )
-                # Add merged_from to info if present
-                if "merged_from" in resp_maybe_merged:
+                if (
+                    getattr(self, "memory_version_switch", "off") != "on"
+                    and "merged_from" in resp_maybe_merged
+                ):
                     node.metadata.info = node.metadata.info or {}
                     node.metadata.info["merged_from"] = resp_maybe_merged["merged_from"]
                 fine_items.append(node)
@@ -984,49 +955,49 @@ def _process_multi_modal_data(
         # must pop here, avoid add to info, only used in sync fine mode
         custom_tags = info.pop("custom_tags", None) if isinstance(info, dict) else None
 
-        # Use MultiModalParser to parse the scene data
-        # If it's a list, parse each item; otherwise parse as single message
-        if isinstance(scene_data_info, list):
-            # Pre-expand multimodal messages
-            expanded_messages = self._expand_multimodal_messages(scene_data_info)
-
-            # Parse each message in the list
-            all_memory_items = []
-            # Use thread pool to parse each message in parallel, but keep the original order
-            with ContextThreadPoolExecutor(max_workers=30) as executor:
-                # submit tasks and keep the original order
-                futures = [
-                    executor.submit(
-                        self.multi_modal_parser.parse,
-                        msg,
-                        info,
-                        mode="fast",
-                        need_emb=False,
-                        **kwargs,
-                    )
-                    for msg in expanded_messages
-                ]
-                # collect results in original order
-                for future in futures:
-                    try:
-                        items = future.result()
-                        all_memory_items.extend(items)
-                    except Exception as e:
-                        logger.error(f"[MultiModalFine] Error in parallel parsing: {e}")
-        else:
-            # Parse as single message
-            all_memory_items = self.multi_modal_parser.parse(
-                scene_data_info, info, mode="fast", need_emb=False, **kwargs
-            )
-        fast_memory_items = self._concat_multi_modal_memories(all_memory_items)
+        # Stage: parse — parallel message parsing + sliding-window aggregation
+        with timed_stage("add", "parse") as ts_parse:
+            if isinstance(scene_data_info, list):
+                expanded_messages = self._expand_multimodal_messages(scene_data_info)
+                ts_parse.set(msg_count=len(expanded_messages))
+
+                all_memory_items = []
+                with ContextThreadPoolExecutor(max_workers=30) as executor:
+                    futures = [
+                        executor.submit(
+                            self.multi_modal_parser.parse,
+                            msg,
+                            info,
+                            mode="fast",
+                            need_emb=False,
+                            **kwargs,
+                        )
+                        for msg in expanded_messages
+                    ]
+                    for future in futures:
+                        try:
+                            items = future.result()
+                            all_memory_items.extend(items)
+                        except Exception as e:
+                            logger.error(f"[MultiModalFine] Error in parallel parsing: {e}")
+            else:
+                ts_parse.set(msg_count=1)
+                all_memory_items = self.multi_modal_parser.parse(
+                    scene_data_info, info, mode="fast", need_emb=False, **kwargs
+                )
+
+            fast_memory_items = self._concat_multi_modal_memories(all_memory_items)
+            ts_parse.set(window_count=len(fast_memory_items))
+
         if mode == "fast":
             return fast_memory_items
-        else:
-            non_file_url_fast_items = [
-                item for item in fast_memory_items if not self._is_file_url_only_item(item)
-            ]
 
-            # Part A: call llm in parallel using thread pool
+        # Stage: llm_extract — fine mode 4-way parallel LLM + per-source serial
+        non_file_url_fast_items = [
+            item for item in fast_memory_items if not self._is_file_url_only_item(item)
+        ]
+
+        with timed_stage("add", "llm_extract") as ts_llm:
             fine_memory_items = []
 
             with ContextThreadPoolExecutor(max_workers=4) as executor:
@@ -1057,7 +1028,6 @@ def _process_multi_modal_data(
                         **kwargs,
                     )
 
-                # Collect results
                 fine_memory_items_string_parser = future_string.result()
                 fine_memory_items_tool_trajectory_parser = future_tool.result()
                 fine_memory_items_skill_memory_parser = future_skill.result()
@@ -1068,21 +1038,25 @@ def _process_multi_modal_data(
                 fine_memory_items.extend(fine_memory_items_skill_memory_parser)
                 fine_memory_items.extend(fine_memory_items_pref_parser)
 
-            # Part B: get fine multimodal items
-            for fast_item in fast_memory_items:
-                sources = fast_item.metadata.sources
-                for source in sources:
-                    lang = getattr(source, "lang", "en")
-                    items = self.multi_modal_parser.process_transfer(
-                        source,
-                        context_items=[fast_item],
-                        custom_tags=custom_tags,
-                        info=info,
-                        lang=lang,
-                        user_context=kwargs.get("user_context"),
-                    )
-                    fine_memory_items.extend(items)
-            return fine_memory_items
+            # Part B: per-source serial processing
+            with timed_stage("add", "per_source") as ts_ps:
+                for fast_item in fast_memory_items:
+                    sources = fast_item.metadata.sources
+                    for source in sources:
+                        lang = getattr(source, "lang", "en")
+                        items = self.multi_modal_parser.process_transfer(
+                            source,
+                            context_items=[fast_item],
+                            custom_tags=custom_tags,
+                            info=info,
+                            lang=lang,
+                            user_context=kwargs.get("user_context"),
+                        )
+                        fine_memory_items.extend(items)
+
+            ts_llm.set(fine_memory_count=len(fine_memory_items), per_source_ms=ts_ps.duration_ms)
+
+        return fine_memory_items
 
     @timed
     def _process_transfer_multi_modal_data(
diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py
index f26be360c..3c82350df 100644
--- a/src/memos/mem_reader/simple_struct.py
+++ b/src/memos/mem_reader/simple_struct.py
@@ -11,6 +11,7 @@
 from memos import log
 from memos.chunkers import ChunkerFactory
+from memos.configs.llm import LLMConfigFactory
 from memos.configs.mem_reader import SimpleStructMemReaderConfig
 from memos.context.context import ContextThreadPoolExecutor
 from memos.embedders.factory import EmbedderFactory
@@ -183,6 +184,15 @@ def __init__(self, config: SimpleStructMemReaderConfig):
             if config.general_llm is not None
             else self.llm
         )
+        self.qwen_llm = None
+        qwen_llm_config = getattr(config, "qwen_llm", None)
+        if qwen_llm_config:
+            try:
+                if isinstance(qwen_llm_config, dict):
+                    qwen_llm_config = LLMConfigFactory.model_validate(qwen_llm_config)
+                self.qwen_llm = LLMFactory.from_config(qwen_llm_config)
+            except Exception as e:
+                logger.warning(f"[LLM] Qwen initialization failed: {e}")
         self.embedder = EmbedderFactory.from_config(config.embedder)
         self.chunker = ChunkerFactory.from_config(config.chunker)
         self.save_rawfile = self.chunker.config.save_rawfile
diff --git a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
index 671190e6f..56614737d 100644
--- a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
+++ b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
@@ -24,6 +24,8 @@
     InternetRetrieverFactory,
 )
 from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer
+from memos.plugins.component_bootstrap import build_plugin_context
+from memos.plugins.manager import plugin_manager
 
 
 if TYPE_CHECKING:
@@ -171,6 +173,16 @@ def build_internet_retriever_config() -> dict[str, Any]:
     return InternetRetrieverConfigFactory.model_validate(APIConfig.get_internet_config())
 
 
+def build_nli_client_config() -> dict[str, Any]:
+    """
+    Build NLI client configuration.
+
+    Returns:
+        NLI client configuration dictionary
+    """
+    return APIConfig.get_nli_config()
+
+
 def _get_default_memory_size(cube_config: Any) -> dict[str, int]:
     """
     Get default memory size configuration.
@@ -246,6 +258,7 @@ def init_components() -> dict[str, Any]:
     graph_db_config = build_graph_db_config()
     llm_config = build_llm_config()
     embedder_config = build_embedder_config()
+    nli_client_config = build_nli_client_config()
     mem_reader_config = build_mem_reader_config()
     reranker_config = build_reranker_config()
     feedback_reranker_config = build_feedback_reranker_config()
@@ -257,8 +270,25 @@ def init_components() -> dict[str, Any]:
     graph_db = GraphStoreFactory.from_config(graph_db_config)
     llm = LLMFactory.from_config(llm_config)
     embedder = EmbedderFactory.from_config(embedder_config)
+
+    plugin_manager.discover()
+    plugin_context = build_plugin_context(
+        graph_db=graph_db,
+        embedder=embedder,
+        default_cube_config=default_cube_config,
+        nli_client_config=nli_client_config,
+        mem_reader_config=mem_reader_config,
+        reranker_config=reranker_config,
+        feedback_reranker_config=feedback_reranker_config,
+        internet_retriever_config=internet_retriever_config,
+    )
+    plugin_manager.init_components(plugin_context)
+
     # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection)
-    mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db)
+    mem_reader = MemReaderFactory.from_config(
+        mem_reader_config,
+        graph_db=graph_db,
+    )
     reranker = RerankerFactory.from_config(reranker_config)
     feedback_reranker = RerankerFactory.from_config(feedback_reranker_config)
     internet_retriever = InternetRetrieverFactory.from_config(
diff --git a/src/memos/mem_scheduler/schemas/task_schemas.py b/src/memos/mem_scheduler/schemas/task_schemas.py
index af0f2f233..5b52b6c22 100644
--- a/src/memos/mem_scheduler/schemas/task_schemas.py
+++ b/src/memos/mem_scheduler/schemas/task_schemas.py
@@ -75,11 +75,30 @@ class TaskPriorityLevel(Enum):
 
 
 # task queue
-DEFAULT_STREAM_KEY_PREFIX = os.getenv(
-    "MEMSCHEDULER_STREAM_KEY_PREFIX", "scheduler:messages:stream:v2.0"
+DEFAULT_STREAM_KEY_PREFIX = "scheduler:messages:stream:v2.0"
+
+logger.info(
+    "[TASK_SCHEMAS] Module imported. env_MEMSCHEDULER_STREAM_KEY_PREFIX=%s, default_stream_prefix=%s",
+    os.getenv("MEMSCHEDULER_STREAM_KEY_PREFIX"),
+    DEFAULT_STREAM_KEY_PREFIX,
 )
 
 
+def get_stream_key_prefix() -> str:
+    """Resolve the scheduler stream prefix at runtime.
+
+    This must not be evaluated at import time because some server entrypoints
+    load `.env` after importing router/handler modules.
+    """
+    resolved = os.getenv("MEMSCHEDULER_STREAM_KEY_PREFIX", DEFAULT_STREAM_KEY_PREFIX)
+    logger.info(
+        "[TASK_SCHEMAS] Resolved stream prefix at runtime. "
+        "env_MEMSCHEDULER_STREAM_KEY_PREFIX=%s, resolved_stream_prefix=%s",
+        os.getenv("MEMSCHEDULER_STREAM_KEY_PREFIX"),
+        resolved,
+    )
+    return resolved
+
+
 # ============== Running Tasks ==============
 class RunningTaskItem(BaseModel, DictConversionMixin):
     """Data class for tracking running tasks in SchedulerDispatcher."""
diff --git a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
index 690c8d123..02cd59e8c 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
@@ -548,6 +548,9 @@ def stats(self) -> dict[str, int]:
         running = 0
         try:
             with self._task_lock:
+                done = {f for f in self._futures if f.done()}
+                if done:
+                    self._futures -= done
                 inflight = len(self._futures)
         except Exception:
             inflight = 0
diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py
index 20dbb63b2..e0baf63ff 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py
@@ -211,44 +211,48 @@ def _process_memories_with_reader(
             )
             logger.info("Added %s Rawfile memories.", len(raw_file_mem_group))
 
-        # Mark merged_from memories as archived when provided in memory metadata
-        summary_memories = [
-            memory
-            for memory in flattened_memories
-            if memory.metadata.memory_type != "RawFileMemory"
-        ]
-        if mem_reader.graph_db:
-            for memory in summary_memories:
-                merged_from = (memory.metadata.info or {}).get("merged_from")
-                if merged_from:
-                    old_ids = (
-                        merged_from
-                        if isinstance(merged_from, (list | tuple | set))
-                        else [merged_from]
-                    )
-                    for old_id in old_ids:
-                        try:
-                            mem_reader.graph_db.update_node(
-                                str(old_id), {"status": "archived"}, user_name=user_name
-                            )
-                            logger.info(
-                                "[Scheduler] Archived merged_from memory: %s",
-                                old_id,
-                            )
-                        except Exception as e:
-                            logger.warning(
-                                "[Scheduler] Failed to archive merged_from memory %s: %s",
-                                old_id,
-                                e,
-                            )
-        else:
-            has_merged_from = any(
-                (m.metadata.info or {}).get("merged_from") for m in summary_memories
-            )
-            if has_merged_from:
-                logger.warning(
-                    "[Scheduler] merged_from provided but graph_db is unavailable; skip archiving."
-                )
+        # Fall back to the simple merged_from deduplication logic when the memory-version switch is off
+        if getattr(mem_reader, "memory_version_switch", "off") != "on":
+            # Mark merged_from memories as archived when provided in memory metadata
+            summary_memories = [
+                memory
+                for memory in flattened_memories
+                if memory.metadata.memory_type != "RawFileMemory"
+            ]
+            if mem_reader.graph_db:
+                for memory in summary_memories:
+                    merged_from = (memory.metadata.info or {}).get("merged_from")
+                    if merged_from:
+                        old_ids = (
+                            merged_from
+                            if isinstance(merged_from, (list | tuple | set))
+                            else [merged_from]
+                        )
+                        for old_id in old_ids:
+                            try:
+                                mem_reader.graph_db.update_node(
+                                    str(old_id),
+                                    {"status": "archived"},
+                                    user_name=user_name,
+                                )
+                                logger.info(
+                                    "[Scheduler] Archived merged_from memory: %s",
+                                    old_id,
+                                )
+                            except Exception as e:
+                                logger.warning(
+                                    "[Scheduler] Failed to archive merged_from memory %s: %s",
+                                    old_id,
+                                    e,
+                                )
+            else:
+                has_merged_from = any(
+                    (m.metadata.info or {}).get("merged_from") for m in summary_memories
+                )
+                if has_merged_from:
+                    logger.warning(
+                        "[Scheduler] merged_from provided but graph_db is unavailable; skip archiving."
+ ) cloud_env = is_cloud_env() if cloud_env: @@ -386,10 +390,34 @@ def _process_memories_with_reader( delete_ids = list(dict.fromkeys(delete_ids)) if delete_ids: try: - text_mem.delete(delete_ids, user_name=user_name) - logger.info( - "Delete raw/working mem_ids: %s for user_name: %s", delete_ids, user_name - ) + if getattr(mem_reader, "memory_version_switch", "off") != "on": + text_mem.delete(delete_ids, user_name=user_name) + logger.info( + "Delete raw/working mem_ids: %s for user_name: %s", + delete_ids, + user_name, + ) + else: + # change to soft-delete for mem versions + flattened_memories = [] + if processed_memories and len(processed_memories) > 0: + for memory_list in processed_memories: + flattened_memories.extend(memory_list) + allowed_types = ["UserMemory", "LongTermMemory"] + text_mem.soft_delete( + delete_ids, + user_name, + [ + mem.id + for mem in flattened_memories + if mem.metadata.memory_type in allowed_types + ], + ) + logger.info( + "Soft delete raw/working mem_ids: %s for user_name: %s", + delete_ids, + user_name, + ) except Exception as e: logger.warning("Failed to delete some mem_ids %s: %s", delete_ids, e) else: diff --git a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py index 33d007313..beae965aa 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py @@ -13,7 +13,7 @@ from memos.log import get_logger from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem -from memos.mem_scheduler.schemas.task_schemas import DEFAULT_STREAM_KEY_PREFIX +from memos.mem_scheduler.schemas.task_schemas import get_stream_key_prefix from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule @@ -26,7 +26,7 @@ class SchedulerLocalQueue(RedisSchedulerModule): def __init__( self, maxsize: int = 0, - stream_key_prefix: str = DEFAULT_STREAM_KEY_PREFIX, + stream_key_prefix: str | None = None, orchestrator: SchedulerOrchestrator | None = None, status_tracker: TaskStatusTracker | None = None, ): @@ -42,7 +42,7 @@ def __init__( """ super().__init__() - self.stream_key_prefix = stream_key_prefix or "local_queue" + self.stream_key_prefix = stream_key_prefix or get_stream_key_prefix() self.max_internal_message_queue_size = maxsize @@ -56,7 +56,9 @@ def __init__( self._message_handler: Callable[[ScheduleMessageItem], None] | None = None logger.info( - f"SchedulerLocalQueue initialized with max_internal_message_queue_size={self.max_internal_message_queue_size}" + "SchedulerLocalQueue initialized with max_internal_message_queue_size=%s, stream_prefix=%s", + self.max_internal_message_queue_size, + self.stream_key_prefix, ) def get_stream_key(self, user_id: str, mem_cube_id: str, task_label: str) -> str: diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index 1277c5465..561d7931f 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -19,9 +19,9 @@ from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import ( 
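# [editor's note] Sketch of the soft-delete semantics used above when the
# memory-version switch is "on": raw/working nodes keep their rows but are
# marked deleted, with evolve_to recording their fine-extracted successors.
# The dict store and helper are illustrative stand-ins for the graph store:
store = {
    "raw-1": {"status": "activated"},
    "fine-1": {"status": "activated"},
}

def soft_delete(store: dict, ids: list[str], evolve_to_ids: list[str] | None = None) -> None:
    fields = {"status": "deleted"}
    if evolve_to_ids:
        fields["evolve_to"] = evolve_to_ids
    for node_id in ids:
        store[node_id].update(fields)  # graph_store.update_node(...) in the real code

soft_delete(store, ["raw-1"], evolve_to_ids=["fine-1"])
assert store["raw-1"] == {"status": "deleted", "evolve_to": ["fine-1"]}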
DEFAULT_STREAM_INACTIVITY_DELETE_SECONDS, - DEFAULT_STREAM_KEY_PREFIX, DEFAULT_STREAM_KEYS_REFRESH_INTERVAL_SEC, DEFAULT_STREAM_RECENT_ACTIVE_SECONDS, + get_stream_key_prefix, ) from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker @@ -45,10 +45,7 @@ class SchedulerRedisQueue(RedisSchedulerModule): def __init__( self, - stream_key_prefix: str = os.getenv( - "MEMSCHEDULER_REDIS_STREAM_KEY_PREFIX", - DEFAULT_STREAM_KEY_PREFIX, - ), + stream_key_prefix: str | None = None, orchestrator: SchedulerOrchestrator | None = None, consumer_group: str = "scheduler_group", consumer_name: str | None = "scheduler_consumer", @@ -68,8 +65,12 @@ def __init__( auto_delete_acked: Whether to automatically delete acknowledged messages from stream """ super().__init__() + resolved_stream_key_prefix = stream_key_prefix or os.getenv( + "MEMSCHEDULER_REDIS_STREAM_KEY_PREFIX", + get_stream_key_prefix(), + ) # Stream configuration - self.stream_key_prefix = stream_key_prefix + self.stream_key_prefix = resolved_stream_key_prefix # Precompile regex for prefix filtering to reduce repeated compilation overhead self.stream_prefix_regex_pattern = re.compile(f"^{re.escape(self.stream_key_prefix)}:") self.consumer_group = consumer_group @@ -93,14 +94,26 @@ def __init__( self.task_broker_flush_bar = 10 self._refill_lock = threading.Lock() self._refill_thread: ContextThread | None = None + self._refill_in_progress = False + self._refill_thread_start: float = 0.0 + self._refill_thread_timeout: float = float( + os.getenv("MEMSCHEDULER_REDIS_REFILL_TIMEOUT_SEC", "30") or 30 + ) # Track empty streams first-seen time to avoid zombie keys self._empty_stream_seen_times: dict[str, float] = {} self._empty_stream_seen_lock = threading.Lock() logger.info( - f"[REDIS_QUEUE] Initialized with stream_prefix='{self.stream_key_prefix}', " - f"consumer_group='{self.consumer_group}', consumer_name='{self.consumer_name}'" + "[REDIS_QUEUE] Initialized with stream_prefix='%s', " + "consumer_group='%s', consumer_name='%s', " + "env_MEMSCHEDULER_STREAM_KEY_PREFIX='%s', " + "env_MEMSCHEDULER_REDIS_STREAM_KEY_PREFIX='%s'", + self.stream_key_prefix, + self.consumer_group, + self.consumer_name, + os.getenv("MEMSCHEDULER_STREAM_KEY_PREFIX"), + os.getenv("MEMSCHEDULER_REDIS_STREAM_KEY_PREFIX"), ) # Auto-initialize Redis connection @@ -110,8 +123,11 @@ def __init__( self.seen_streams = set() - # Task Orchestrator - self.message_pack_cache = deque() + # Task Orchestrator — cap in-memory cache to avoid unbounded growth + self._cache_max_packs = int(os.getenv("MEMSCHEDULER_REDIS_CACHE_MAX_PACKS", "50") or 50) + self.message_pack_cache: deque[list[ScheduleMessageItem]] = deque( + maxlen=self._cache_max_packs + ) self.orchestrator = SchedulerOrchestrator() if orchestrator is None else orchestrator @@ -349,23 +365,51 @@ def task_broker( def _async_refill_cache(self, batch_size: int) -> None: """Background thread to refill message cache without blocking get_messages.""" try: - logger.debug(f"Starting async cache refill with batch_size={batch_size}") + with self._refill_lock: + remaining = self._cache_max_packs - len(self.message_pack_cache) + if remaining <= 0: + logger.debug("Async refill skipped: cache already at capacity") + return + self._refill_in_progress = True + + logger.debug( + f"Starting async cache refill with batch_size={batch_size}, remaining_capacity={remaining}" + ) new_packs = self.task_broker(consume_batch_size=batch_size) - 
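# [editor's note] Behavior note on the deque(maxlen=...) cache introduced
# above: a full bounded deque evicts from the opposite end instead of
# rejecting the append. Quick stdlib demo:
from collections import deque

cache: deque[int] = deque(maxlen=2)
cache.append(1)
cache.append(2)
cache.append(3)
assert list(cache) == [2, 3]  # 1 was silently evicted, not rejected
# This is presumably why the refill path also checks remaining capacity
# before appending: eviction could strand packs already read from Redis.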
logger.debug(f"task_broker returned {len(new_packs)} packs") + with self._refill_lock: + added = 0 for pack in new_packs: - if pack: # Only add non-empty packs + if pack: self.message_pack_cache.append(pack) - logger.debug(f"Added pack with {len(pack)} messages to cache") - logger.debug(f"Cache refill complete, cache size now: {len(self.message_pack_cache)}") + added += 1 + if added >= remaining: + break + logger.debug( + f"Cache refill complete, added={added}, cache size now: {len(self.message_pack_cache)}" + ) except Exception as e: logger.warning(f"Async cache refill failed: {e}", exc_info=True) + finally: + with self._refill_lock: + self._refill_in_progress = False + + def _is_refill_thread_available(self) -> bool: + """Check whether a new refill thread can be started.""" + if self._refill_thread is None or not self._refill_thread.is_alive(): + return True + if (time.time() - self._refill_thread_start) > self._refill_thread_timeout: + logger.warning( + f"Refill thread has been running for >{self._refill_thread_timeout}s, treating as stale" + ) + return True + return False def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]: if self.message_pack_cache: - # Trigger async refill if below threshold (non-blocking) - if len(self.message_pack_cache) < self.task_broker_flush_bar and ( - self._refill_thread is None or not self._refill_thread.is_alive() + if ( + len(self.message_pack_cache) < self.task_broker_flush_bar + and self._is_refill_thread_available() ): logger.debug( f"Triggering async cache refill: cache size {len(self.message_pack_cache)} < {self.task_broker_flush_bar}" @@ -373,14 +417,26 @@ def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]: self._refill_thread = ContextThread( target=self._async_refill_cache, args=(batch_size,), name="redis-cache-refill" ) + self._refill_thread_start = time.time() self._refill_thread.start() else: logger.debug(f"The size of message_pack_cache is {len(self.message_pack_cache)}") else: - new_packs = self.task_broker(consume_batch_size=batch_size) - for pack in new_packs: - if pack: # Only add non-empty packs - self.message_pack_cache.append(pack) + should_fetch = False + with self._refill_lock: + if not self.message_pack_cache and not self._refill_in_progress: + self._refill_in_progress = True + should_fetch = True + if should_fetch: + try: + new_packs = self.task_broker(consume_batch_size=batch_size) + with self._refill_lock: + for pack in new_packs: + if pack: + self.message_pack_cache.append(pack) + finally: + with self._refill_lock: + self._refill_in_progress = False if len(self.message_pack_cache) == 0: return [] else: @@ -443,12 +499,17 @@ def put( with self._stream_keys_lock: if stream_key not in self.seen_streams: self.seen_streams.add(stream_key) - self._ensure_consumer_group(stream_key=stream_key) + need_create_group = True + else: + need_create_group = False if stream_key not in self._stream_keys_cache: self._stream_keys_cache.append(stream_key) self._stream_keys_last_refresh = time.time() + if need_create_group: + self._ensure_consumer_group(stream_key=stream_key) + message.stream_key = stream_key # Convert message to dictionary for Redis storage @@ -1054,14 +1115,9 @@ def get_stream_keys(self, stream_key_prefix: str | None = None) -> list[str]: with self._stream_keys_lock: cache_snapshot = list(self._stream_keys_cache) - # Validate that cached keys conform to the expected prefix - escaped_prefix = re.escape(effective_prefix) - regex_pattern = f"^{escaped_prefix}:" - for key in cache_snapshot: - if not 
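# [editor's note] Why the prefix patterns above are built with re.escape():
# stream prefixes can contain regex metacharacters ("." in "v2.0"), so an
# unescaped f-string pattern would over-match. Self-contained demo:
import re

prefix = "scheduler:messages:stream:v2.0"
unsafe = re.compile(f"^{prefix}:")           # "." matches any character
safe = re.compile(f"^{re.escape(prefix)}:")  # "." matches a literal dot only
key = "scheduler:messages:stream:v2X0:user:cube:label"
assert unsafe.match(key) is not None  # false positive
assert safe.match(key) is None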
re.match(regex_pattern, key): - logger.error( - f"[REDIS_QUEUE] Cached stream key '{key}' does not match prefix '{effective_prefix}:'" - ) + if effective_prefix != self.stream_key_prefix: + pattern = re.compile(f"^{re.escape(effective_prefix)}:") + cache_snapshot = [k for k in cache_snapshot if pattern.match(k)] return cache_snapshot @@ -1211,7 +1267,7 @@ def __del__(self): @property def unfinished_tasks(self) -> int: - return self.qsize() + return self.size() def _scan_candidate_stream_keys( self, @@ -1396,6 +1452,23 @@ def _update_stream_cache_with_log( self._stream_keys_cache = active_stream_keys self._stream_keys_last_refresh = time.time() cache_count = len(self._stream_keys_cache) + + active_set = set(active_stream_keys) + stale = self.seen_streams - active_set + if stale: + self.seen_streams -= stale + logger.debug(f"Pruned {len(stale)} stale entries from seen_streams") + + candidate_set = set(candidate_keys) + with self._empty_stream_seen_lock: + orphaned = [k for k in self._empty_stream_seen_times if k not in candidate_set] + for k in orphaned: + del self._empty_stream_seen_times[k] + if orphaned: + logger.debug( + f"Pruned {len(orphaned)} orphaned entries from _empty_stream_seen_times" + ) + logger.debug( f"Refreshed stream keys cache: {cache_count} active keys, " f"{deleted_count} deleted, {len(candidate_keys)} candidates examined." diff --git a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py index 1f2e81bef..6bcf0023c 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py @@ -48,6 +48,12 @@ def __init__( self.memos_message_queue = SchedulerLocalQueue(maxsize=self.maxsize) self.disabled_handlers = disabled_handlers + logger.info( + "[SCHEDULE_TASK_QUEUE] Initialized queue wrapper. use_redis_queue=%s, queue_type=%s, stream_prefix=%s", + self.use_redis_queue, + type(self.memos_message_queue).__name__, + getattr(self.memos_message_queue, "stream_key_prefix", None), + ) def set_status_tracker(self, status_tracker: TaskStatusTracker) -> None: """ diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index a9b2c43a4..f34cf1efd 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -69,9 +69,9 @@ class ArchivedTextualMemory(BaseModel): memory: str | None = Field( default_factory=lambda: "", description="The content of the archived version of the memory." ) - update_type: Literal["conflict", "duplicate", "extract", "unrelated"] = Field( + update_type: Literal["conflict", "duplicate", "extract", "unrelated", "feedback"] = Field( default="unrelated", - description="The type of the memory (e.g., `conflict`, `duplicate`, `extract`, `unrelated`).", + description="The type of the memory (e.g., `conflict`, `duplicate`, `extract`, `unrelated`, `feedback`).", ) archived_memory_id: str | None = Field( default=None, @@ -106,15 +106,15 @@ class TextualMemoryMetadata(BaseModel): default=None, description="Whether or not the memory was created in fast mode, carrying raw memory contents that haven't been edited by llms yet.", ) - evolve_to: list[str] | None = Field( + evolve_to: list[str] = Field( default_factory=list, - description="Only valid if a node was once a (raw)fast node. 
Recording which new memory nodes it 'evolves' to after llm extraction.", + description="Recording which new memory nodes it 'evolves' to after llm extraction.", ) - version: int | None = Field( - default=None, + version: int = Field( + default=1, description="The version of the memory. Will be incremented when the memory is updated.", ) - history: list[ArchivedTextualMemory] | None = Field( + history: list[ArchivedTextualMemory] = Field( default_factory=list, description="Storing the archived versions of the memory. Only preserving core information of each version.", ) @@ -146,6 +146,10 @@ class TextualMemoryMetadata(BaseModel): default=None, description="Arbitrary key-value pairs for additional metadata.", ) + internal_info: dict | None = Field( + default=None, + description="Internal algorithm metadata reserved for system use.", + ) model_config = ConfigDict(extra="allow") diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 8c896f538..426a2359b 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -354,6 +354,14 @@ def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem """Get a memory by its ID.""" result = self.graph_store.get_node(memory_id, user_name=user_name) if result is None: + logger.warning( + "[TreeTextMemory.get] Memory not found. memory_id=%s, lookup_user_name=%s, graph_store=%s, db_name=%s, config_user_name=%s", + memory_id, + user_name, + type(self.graph_store).__name__, + getattr(self.graph_store, "db_name", None), + getattr(getattr(self.graph_store, "config", None), "user_name", None), + ) raise ValueError(f"Memory with ID {memory_id} not found") metadata_dict = result.get("metadata", {}) return TextualMemoryItem( @@ -609,3 +617,33 @@ def add_graph_edges( future.result() except Exception as e: logger.exception("Add edge error: ", exc_info=e) + + def soft_delete( + self, + memory_ids: list[str], + user_name: str, + evolve_to_ids: list[str] | None = None, + ) -> None: + # for ruff check... + if not evolve_to_ids: + update_fields = {"status": "deleted"} + else: + update_fields = {"status": "deleted", "evolve_to": evolve_to_ids} + + # Execute the actual marking operation - in db. 
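# [editor's note] Hedged sketch (standalone pydantic models mirroring, but not
# importing, the item.py models above) of how the now non-optional
# version/history fields compose: bumping a memory's version archives the
# previous content first. update_memory() is illustrative, not a MemOS API.
from pydantic import BaseModel, Field

class ArchivedVersion(BaseModel):
    version: int
    memory: str
    update_type: str = "unrelated"  # "feedback" is newly allowed by this diff

class Memory(BaseModel):
    memory: str
    version: int = 1  # defaults to 1 now instead of None
    history: list[ArchivedVersion] = Field(default_factory=list)

def update_memory(item: Memory, new_text: str, update_type: str) -> None:
    # archive the outgoing version, then bump and overwrite in place
    item.history.append(
        ArchivedVersion(version=item.version, memory=item.memory, update_type=update_type)
    )
    item.version += 1
    item.memory = new_text

m = Memory(memory="User likes apples.")
update_memory(m, "User prefers pears.", "conflict")
assert m.version == 2 and m.history[0].memory == "User likes apples."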
+ with ContextThreadPoolExecutor() as executor: + futures = [] + for mid in memory_ids: + futures.append( + executor.submit( + self.graph_store.update_node, + id=mid, + fields=update_fields, + user_name=user_name, + ) + ) + + # Wait for all tasks to complete and raise any exceptions + for future in futures: + future.result() + return diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py deleted file mode 100644 index 98094877c..000000000 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ /dev/null @@ -1,168 +0,0 @@ -import logging - -from typing import Literal - -from memos.context.context import ContextThreadPoolExecutor -from memos.extras.nli_model.client import NLIClient -from memos.extras.nli_model.types import NLIResult -from memos.graph_dbs.base import BaseGraphDB -from memos.memories.textual.item import ArchivedTextualMemory, TextualMemoryItem - - -logger = logging.getLogger(__name__) - -CONFLICT_MEMORY_TITLE = "[possibly conflicting memories]" -DUPLICATE_MEMORY_TITLE = "[possibly duplicate memories]" - - -def _append_related_content( - new_item: TextualMemoryItem, duplicates: list[str], conflicts: list[str] -) -> None: - """ - Append duplicate and conflict memory contents to the new item's memory text, - truncated to avoid excessive length. - """ - max_per_item_len = 200 - max_section_len = 1000 - - def _format_section(title: str, items: list[str]) -> str: - if not items: - return "" - - section_content = "" - for mem in items: - # Truncate individual item - snippet = mem[:max_per_item_len] + "..." if len(mem) > max_per_item_len else mem - # Check total section length - if len(section_content) + len(snippet) + 5 > max_section_len: - section_content += "\n- ... (more items truncated)" - break - section_content += f"\n- {snippet}" - - return f"\n\n{title}:{section_content}" - - append_text = "" - append_text += _format_section(CONFLICT_MEMORY_TITLE, conflicts) - append_text += _format_section(DUPLICATE_MEMORY_TITLE, duplicates) - - if append_text: - new_item.memory += append_text - - -def _detach_related_content(new_item: TextualMemoryItem) -> None: - """ - Detach duplicate and conflict memory contents from the new item's memory text. - """ - markers = [f"\n\n{CONFLICT_MEMORY_TITLE}:", f"\n\n{DUPLICATE_MEMORY_TITLE}:"] - - cut_index = -1 - for marker in markers: - idx = new_item.memory.find(marker) - if idx != -1 and (cut_index == -1 or idx < cut_index): - cut_index = idx - - if cut_index != -1: - new_item.memory = new_item.memory[:cut_index] - - return - - -class MemoryHistoryManager: - def __init__(self, nli_client: NLIClient, graph_db: BaseGraphDB) -> None: - """ - Initialize the MemoryHistoryManager. - - Args: - nli_client: NLIClient for conflict/duplicate detection. - graph_db: GraphDB instance for marking operations during history management. - """ - self.nli_client = nli_client - self.graph_db = graph_db - - def resolve_history_via_nli( - self, new_item: TextualMemoryItem, related_items: list[TextualMemoryItem] - ) -> list[TextualMemoryItem]: - """ - Detect relationships (Duplicate/Conflict) between the new item and related items using NLI, - and attach them as history to the new fast item. - - Args: - new_item: The new memory item being added. - related_items: Existing memory items that might be related. - - Returns: - List of duplicate or conflicting memory items judged by the NLI service. - """ - if not related_items: - return [] - - # 1. 
Call NLI - nli_results = self.nli_client.compare_one_to_many( - new_item.memory, [r.memory for r in related_items] - ) - - # 2. Process results and attach to history - duplicate_memories = [] - conflict_memories = [] - - for r_item, nli_res in zip(related_items, nli_results, strict=False): - if nli_res == NLIResult.DUPLICATE: - update_type = "duplicate" - duplicate_memories.append(r_item.memory) - elif nli_res == NLIResult.CONTRADICTION: - update_type = "conflict" - conflict_memories.append(r_item.memory) - else: - update_type = "unrelated" - - # Safely get created_at, fallback to updated_at - created_at = getattr(r_item.metadata, "created_at", None) or r_item.metadata.updated_at - - archived = ArchivedTextualMemory( - version=r_item.metadata.version or 1, - is_fast=r_item.metadata.is_fast or False, - memory=r_item.memory, - update_type=update_type, - archived_memory_id=r_item.id, - created_at=created_at, - ) - new_item.metadata.history.append(archived) - logger.info( - f"[Chunker: MemoryHistoryManager] Archived related memory {r_item.id} as {update_type} for new item {new_item.id}" - ) - - # 3. Concat duplicate/conflict memories to new_item.memory - # We will mark those old memories as invisible during fine processing, this op helps to avoid information loss. - _append_related_content(new_item, duplicate_memories, conflict_memories) - - return duplicate_memories + conflict_memories - - def mark_memory_status( - self, - memory_items: list[TextualMemoryItem], - status: Literal["activated", "resolving", "archived", "deleted"], - user_name: str | None = None, - ) -> None: - """ - Support status marking operations during history management. Common usages are: - 1. Mark conflict/duplicate old memories' status as "resolving", - to make them invisible to /search api, but still visible for PreUpdateRetriever. - 2. Mark resolved memories' status as "activated", to restore their visibility. - """ - # Execute the actual marking operation - in db. - with ContextThreadPoolExecutor() as executor: - futures = [] - for mem in memory_items: - futures.append( - executor.submit( - self.graph_db.update_node, - id=mem.id, - fields={"status": status}, - user_name=user_name, - ) - ) - - # Wait for all tasks to complete and raise any exceptions - for future in futures: - future.result() - return diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index ecd58f309..1fe6a4811 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -191,10 +191,12 @@ def _add_memories_batch( ) metadata_dict = memory.metadata.model_dump(exclude_none=True) metadata_dict["updated_at"] = datetime.now().isoformat() + metadata_dict["working_binding"] = working_id # Add working_binding for fast mode tags = metadata_dict.get("tags") or [] if "mode:fast" in tags: + metadata_dict["is_fast"] = True # Temporal fix prev_bg = metadata_dict.get("background", "") or "" binding_line = f"[working_binding:{working_id}] direct built from raw inputs" metadata_dict["background"] = ( @@ -234,6 +236,8 @@ def _submit_batches(nodes: list[dict], node_kind: str) -> None: exc_info=e, ) + # TODO: working id is same with item.id, need to fix, currently stop adding WorkingMemories here. 
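# [editor's note] soft_delete() in tree.py and the removed mark_memory_status()
# share one pattern: fan per-node updates out to a thread pool, then call
# future.result() on each so worker exceptions surface in the caller. Generic
# sketch (plain ThreadPoolExecutor standing in for the context-propagating one):
from concurrent.futures import ThreadPoolExecutor

def fan_out_updates(update_node, ids: list[str], fields: dict) -> None:
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(update_node, id=i, fields=fields) for i in ids]
        for future in futures:
            future.result()  # re-raises the first worker exception

store: dict[str, dict] = {"a": {}, "b": {}}
fan_out_updates(lambda id, fields: store[id].update(fields), ["a", "b"], {"status": "deleted"})
assert store["a"]["status"] == "deleted" and store["b"]["status"] == "deleted"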
+ # here used to be: _submit_batches(working_nodes, "WorkingMemory") _submit_batches(graph_nodes, "graph memory") if graph_node_ids and self.is_reorganize: @@ -319,6 +323,7 @@ def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = Non ids: list[str] = [] futures = [] + # TODO: working id is same with item.id, need to fix working_id = memory.id if hasattr(memory, "id") else memory.id or str(uuid.uuid4()) with ContextThreadPoolExecutor(max_workers=2, thread_name_prefix="mem") as ex: @@ -395,8 +400,11 @@ def _add_to_graph_memory( node_id = memory.id if hasattr(memory, "id") else str(uuid.uuid4()) # Step 2: Add new node to graph metadata_dict = memory.metadata.model_dump(exclude_none=True) + if working_binding: + metadata_dict["working_binding"] = working_binding tags = metadata_dict.get("tags") or [] if working_binding and ("mode:fast" in tags): + metadata_dict["is_fast"] = True # Temporal fix prev_bg = metadata_dict.get("background", "") or "" binding_line = f"[working_binding:{working_binding}] direct built from raw inputs" if prev_bg: diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py b/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py deleted file mode 100644 index cb77d2243..000000000 --- a/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py +++ /dev/null @@ -1,264 +0,0 @@ -import concurrent.futures -import re - -from typing import Any - -from memos.context.context import ContextThreadPoolExecutor -from memos.log import get_logger -from memos.mem_reader.read_multi_modal.utils import detect_lang -from memos.memories.textual.item import TextualMemoryItem -from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer - - -logger = get_logger(__name__) - - -class PreUpdateRetriever: - def __init__(self, graph_db, embedder): - """ - The PreUpdateRetriever is designed for the /add phase . - It serves to recall potentially duplicate/conflict memories against the new content that's being added. - - Args: - graph_db: The graph database instance (Neo4j, PolarDB, etc.) - embedder: The embedder instance for vector search - """ - self.graph_db = graph_db - self.embedder = embedder - # Use existing tokenizer for keyword extraction - self.tokenizer = FastTokenizer(use_jieba=True, use_stopwords=True) - - def _adjust_perspective(self, text: str, role: str, lang: str) -> str: - """ - For better search result, we adjust the perspective - from 1st person to 3rd person based on role and language. 
- "I" -> "User" (if role is user) - "I" -> "Assistant" (if role is assistant) - """ - if not role: - return text - - role = role.lower() - replacements = [] - - # Determine replacements based on language and role - if lang == "zh": - if role == "user": - replacements = [("我", "用户")] - elif role == "assistant": - replacements = [("我", "助手")] - else: # default to en - if role == "user": - replacements = [ - (r"\bI\b", "User"), - (r"\bme\b", "User"), - (r"\bmy\b", "User's"), - (r"\bmine\b", "User's"), - (r"\bmyself\b", "User himself"), - ] - elif role == "assistant": - replacements = [ - (r"\bI\b", "Assistant"), - (r"\bme\b", "Assistant"), - (r"\bmy\b", "Assistant's"), - (r"\bmine\b", "Assistant's"), - (r"\bmyself\b", "Assistant himself"), - ] - - adjusted_text = text - for pattern, repl in replacements: - if lang == "zh": - adjusted_text = adjusted_text.replace(pattern, repl) - else: - adjusted_text = re.sub(pattern, repl, adjusted_text, flags=re.IGNORECASE) - - return adjusted_text - - def _preprocess_query(self, item: TextualMemoryItem) -> str: - """ - Preprocess the query item: - 1. Extract language and role from metadata/sources - 2. Adjust perspective (I -> User/Assistant) based on role/lang - """ - raw_text = item.memory or "" - if not raw_text.strip(): - return "" - - # Extract lang/role - lang = None - role = None - sources = item.metadata.sources - - if sources: - source_list = sources if isinstance(sources, list) else [sources] - for source in source_list: - if hasattr(source, "lang") and source.lang: - lang = source.lang - elif isinstance(source, dict) and source.get("lang"): - lang = source.get("lang") - - if hasattr(source, "role") and source.role: - role = source.role - elif isinstance(source, dict) and source.get("role"): - role = source.get("role") - - if lang and role: - break - - if lang is None: - lang = detect_lang(raw_text) - - # Adjust perspective - return self._adjust_perspective(raw_text, role, lang) - - def _get_full_memories( - self, candidate_ids: list[str], user_name: str - ) -> list[TextualMemoryItem]: - """ - Retrieve full memories for given candidate ids. - """ - full_recalled_memories = self.graph_db.get_nodes(candidate_ids, user_name=user_name) - return [TextualMemoryItem.from_dict(item) for item in full_recalled_memories] - - def vector_search( - self, - query_text: str, - query_embedding: list[float] | None, - user_name: str, - top_k: int, - search_filter: dict[str, Any] | None = None, - threshold: float = 0.5, - ) -> list[dict]: - try: - # Use pre-computed embedding if available (matches raw/clean query) - # Otherwise embed the switched query for better semantic match - q_embed = query_embedding if query_embedding else self.embedder.embed([query_text])[0] - - # Assuming graph_db.search_by_embedding returns list of dicts or items - results = self.graph_db.search_by_embedding( - vector=q_embed, - top_k=top_k, - status=None, - threshold=threshold, - user_name=user_name, - filter=search_filter, - ) - return results - except Exception as e: - logger.error(f"[PreUpdateRetriever] Vector search failed: {e}") - return [] - - def keyword_search( - self, - query_text: str, - user_name: str, - top_k: int, - search_filter: dict[str, Any] | None = None, - ) -> list[dict]: - try: - # 1. Tokenize using existing tokenizer - keywords = self.tokenizer.tokenize_mixed(query_text) - if not keywords: - return [] - - results = [] - - # 2. 
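# [editor's note] The deleted _adjust_perspective leaned on word-boundary
# regexes for English so that standalone "I"/"my" are rewritten while words
# like "interesting" stay untouched. Minimal reproduction of that behavior:
import re

def to_third_person(text: str, role: str = "user") -> str:
    subject = "User" if role == "user" else "Assistant"
    out = re.sub(r"\bI\b", subject, text, flags=re.IGNORECASE)
    out = re.sub(r"\bmy\b", f"{subject}'s", out, flags=re.IGNORECASE)
    return out

assert to_third_person("I think my idea is interesting") == \
    "User think User's idea is interesting"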
Try search_by_keywords_tfidf (PolarDB specific) - if hasattr(self.graph_db, "search_by_keywords_tfidf"): - try: - results = self.graph_db.search_by_keywords_tfidf( - query_words=keywords, user_name=user_name, filter=search_filter - ) - except Exception as e: - logger.warning(f"[PreUpdateRetriever] search_by_keywords_tfidf failed: {e}") - - # 3. Fallback to search_by_fulltext - if not results and hasattr(self.graph_db, "search_by_fulltext"): - try: - results = self.graph_db.search_by_fulltext( - query_words=keywords, top_k=top_k, user_name=user_name, filter=search_filter - ) - except Exception as e: - logger.warning(f"[PreUpdateRetriever] search_by_fulltext failed: {e}") - - return results[:top_k] - - except Exception as e: - logger.error(f"[PreUpdateRetriever] Keyword search failed: {e}") - return [] - - def retrieve( - self, item: TextualMemoryItem, user_name: str, top_k: int = 10, sim_threshold: float = 0.5 - ) -> list[TextualMemoryItem]: - """ - Recall related memories for a TextualMemoryItem using hybrid search (Vector + Keyword). - Might actually return top_k ~ 2top_k items. - Designed for low latency. - - Args: - item: The memory item to find related memories for - user_name: User identifier for scoping search - top_k: Max number of results to return - sim_threshold: minimal similarity threshold for vector search - - Returns: - List of TextualMemoryItem - """ - # 1. Preprocess - switched_query = self._preprocess_query(item) - - # 2. Recall - futures = [] - common_filter = { - "status": {"in": ["activated", "resolving"]}, - "memory_type": {"in": ["LongTermMemory", "UserMemory", "WorkingMemory"]}, - } - - with ContextThreadPoolExecutor(max_workers=3, thread_name_prefix="fast_recall") as executor: - # Task A: Vector Search (Semantic) - query_embedding = ( - item.metadata.embedding if hasattr(item.metadata, "embedding") else None - ) - futures.append( - executor.submit( - self.vector_search, - switched_query, - query_embedding, - user_name, - top_k, - common_filter, - sim_threshold, - ) - ) - - # Task B: Keyword Search - futures.append( - executor.submit( - self.keyword_search, switched_query, user_name, top_k, common_filter - ) - ) - - # 3. Collect Results - retrieved_ids = set() # for deduplicating ids - for future in concurrent.futures.as_completed(futures): - try: - res = future.result() - if not res: - continue - - for r in res: - retrieved_ids.add(r["id"]) - - except Exception as e: - logger.error(f"[PreUpdateRetriever] Search future task failed: {e}") - - retrieved_ids = list(retrieved_ids) - - if not retrieved_ids: - return [] - - # 4. Retrieve full memories to from just ids - # TODO: We should modify the db functions to support returning arbitrary fields, instead of search twice. 
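# [editor's note] For readers tracking what this file deletion removes: the
# core of PreUpdateRetriever.retrieve() was this hybrid-recall shape - run
# vector and keyword search concurrently, union the ids, then hydrate once.
# Generic sketch with caller-supplied search backends (not the MemOS API):
from concurrent.futures import ThreadPoolExecutor, as_completed

def hybrid_recall(vector_search, keyword_search, hydrate, query: str, top_k: int = 10):
    ids: set[str] = set()
    with ThreadPoolExecutor(max_workers=2) as pool:
        futures = [
            pool.submit(vector_search, query, top_k),
            pool.submit(keyword_search, query, top_k),
        ]
        for future in as_completed(futures):
            ids.update(hit["id"] for hit in future.result())  # dedupe across channels
    return hydrate(list(ids)) if ids else []

hits = hybrid_recall(
    lambda q, k: [{"id": "m1"}],
    lambda q, k: [{"id": "m1"}, {"id": "m2"}],
    lambda ids: sorted(ids),
    "apples",
)
assert hits == ["m1", "m2"]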
- final_memories = self._get_full_memories(retrieved_ids, user_name) - - return final_memories diff --git a/src/memos/multi_mem_cube/composite_cube.py b/src/memos/multi_mem_cube/composite_cube.py index 0d2d460e9..11c5245ba 100644 --- a/src/memos/multi_mem_cube/composite_cube.py +++ b/src/memos/multi_mem_cube/composite_cube.py @@ -6,6 +6,7 @@ from memos.context.context import ContextThreadPoolExecutor from memos.multi_mem_cube.views import MemCubeView +from memos.utils import timed_stage if TYPE_CHECKING: @@ -27,13 +28,18 @@ class CompositeCubeView(MemCubeView): def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]: all_results: list[dict[str, Any]] = [] - - # fast mode: for each cube view, add memories - # maybe add more strategies in add_req.async_mode - for view in self.cube_views: - self.logger.info(f"[CompositeCubeView] fan-out add to cube={view.cube_id}") - results = view.add_memories(add_req) - all_results.extend(results) + cube_count = len(self.cube_views) + + with timed_stage("add", "multi_cube", cube_count=cube_count): + for idx, view in enumerate(self.cube_views): + self.logger.info( + "[CompositeCubeView] fan-out add to cube=%s (%d/%d)", + view.cube_id, + idx + 1, + cube_count, + ) + results = view.add_memories(add_req) + all_results.extend(results) return all_results diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 355cb8cee..22d3a253c 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -31,7 +31,7 @@ SearchMode, UserContext, ) -from memos.utils import timed +from memos.utils import timed, timed_stage logger = get_logger(__name__) @@ -692,25 +692,25 @@ def _process_text_mem( extract_mode, add_req.mode, ) - init_time = time.time() - # Extract memories - memories_local = self.mem_reader.get_memory( - [add_req.messages], - type="chat", - info={ - **(add_req.info or {}), - "custom_tags": add_req.custom_tags, - "user_id": add_req.user_id, - "session_id": target_session_id, - }, - mode=extract_mode, - user_name=user_context.mem_cube_id, - chat_history=add_req.chat_history, - user_context=user_context, - ) - self.logger.info( - f"Time for get_memory in extract mode {extract_mode}: {time.time() - init_time}" - ) + process_start = time.perf_counter() + + # Stage 1+2: parse + embedding (logged inside get_memory via timed_stage) + with timed_stage("add", "get_memory", cube_id=self.cube_id) as ts_gm: + memories_local = self.mem_reader.get_memory( + [add_req.messages], + type="chat", + info={ + **(add_req.info or {}), + "custom_tags": add_req.custom_tags, + "user_id": add_req.user_id, + "session_id": target_session_id, + }, + mode=extract_mode, + user_name=user_context.mem_cube_id, + chat_history=add_req.chat_history, + user_context=user_context, + ) + get_memory_ms = ts_gm.duration_ms flattened_local = [mm for m in memories_local for mm in m] # Explicitly set source_doc_id to metadata if present in info @@ -719,73 +719,112 @@ def _process_text_mem( for memory in flattened_local: memory.metadata.source_doc_id = source_doc_id - self.logger.info(f"Memory extraction completed for user {add_req.user_id}") - # Add memories to text_mem mem_group = [ memory for memory in flattened_local if memory.metadata.memory_type != "RawFileMemory" ] - mem_ids_local: list[str] = self.naive_mem_cube.text_mem.add( - mem_group, - user_name=user_context.mem_cube_id, - ) - - self.logger.info( - f"Added {len(mem_ids_local)} memories for user {add_req.user_id} " - f"in session {add_req.session_id}: 
{mem_ids_local}" - ) - # Add raw file nodes and edges - if self.mem_reader.save_rawfile and extract_mode == "fine": - raw_file_mem_group = [ - memory - for memory in flattened_local - if memory.metadata.memory_type == "RawFileMemory" - ] - self.naive_mem_cube.text_mem.add_rawfile_nodes_n_edges( - raw_file_mem_group, - mem_ids_local, - user_id=add_req.user_id, + # Stage 3: write_db + with timed_stage("add", "write_db", cube_id=self.cube_id) as ts_db: + mem_ids_local: list[str] = self.naive_mem_cube.text_mem.add( + mem_group, user_name=user_context.mem_cube_id, ) - # Schedule async/sync tasks: async process raw chunk memory | sync only send messages - self._schedule_memory_tasks( - add_req=add_req, - user_context=user_context, - mem_ids=mem_ids_local, - sync_mode=sync_mode, - ) + self.logger.info( + f"Added {len(mem_ids_local)} memories for user {add_req.user_id} " + f"in session {add_req.session_id}: {mem_ids_local}" + ) - # Mark merged_from memories as archived when provided in add_req.info - if sync_mode == "sync" and extract_mode == "fine": - for memory in flattened_local: - merged_from = (memory.metadata.info or {}).get("merged_from") - if merged_from: - old_ids = ( - merged_from - if isinstance(merged_from, (list | tuple | set)) - else [merged_from] - ) - if self.mem_reader and self.mem_reader.graph_db: - for old_id in old_ids: - try: - self.mem_reader.graph_db.update_node( - str(old_id), - {"status": "archived"}, - user_name=user_context.mem_cube_id, - ) - self.logger.info( - f"[SingleCubeView] Archived merged_from memory: {old_id}" - ) - except Exception as e: - self.logger.warning( - f"[SingleCubeView] Failed to archive merged_from memory {old_id}: {e}" - ) - else: - self.logger.warning( - "[SingleCubeView] merged_from provided but graph_db is unavailable; skip archiving." + # Add raw file nodes and edges + if self.mem_reader.save_rawfile and extract_mode == "fine": + raw_file_mem_group = [ + memory + for memory in flattened_local + if memory.metadata.memory_type == "RawFileMemory" + ] + self.naive_mem_cube.text_mem.add_rawfile_nodes_n_edges( + raw_file_mem_group, + mem_ids_local, + user_id=add_req.user_id, + user_name=user_context.mem_cube_id, + ) + ts_db.set(memory_count=len(mem_ids_local)) + write_db_ms = ts_db.duration_ms + + # Stage 4: schedule + with timed_stage("add", "schedule", cube_id=self.cube_id) as ts_sched: + self._schedule_memory_tasks( + add_req=add_req, + user_context=user_context, + mem_ids=mem_ids_local, + sync_mode=sync_mode, + ) + + # Mark merged_from memories as archived when provided in add_req.info + if ( + sync_mode == "sync" + and extract_mode == "fine" + and ( + not hasattr(self.mem_reader, "memory_version_switch") + or self.mem_reader.memory_version_switch != "on" + ) + ): + for memory in flattened_local: + merged_from = (memory.metadata.info or {}).get("merged_from") + if merged_from: + old_ids = ( + merged_from + if isinstance(merged_from, (list | tuple | set)) + else [merged_from] ) + if self.mem_reader and self.mem_reader.graph_db: + for old_id in old_ids: + try: + self.mem_reader.graph_db.update_node( + str(old_id), + {"status": "archived"}, + user_name=user_context.mem_cube_id, + ) + self.logger.info( + f"[SingleCubeView] Archived merged_from memory: {old_id}" + ) + except Exception as e: + self.logger.warning( + f"[SingleCubeView] Failed to archive merged_from memory {old_id}: {e}" + ) + else: + self.logger.warning( + "[SingleCubeView] merged_from provided but graph_db is unavailable; skip archiving." 
+ ) + schedule_ms = ts_sched.duration_ms + + # Summary rollup — total_ms is the outer wall-clock, not a new stage + total_ms = int((time.perf_counter() - process_start) * 1000) + input_msg_count = len(add_req.messages) if add_req.messages else 0 + memory_count = len(mem_ids_local) + est_input_tokens = ( + sum( + len(str(m.get("content", ""))) if isinstance(m, dict) else len(str(m)) + for m in (add_req.messages or []) + ) + // 4 + ) + timed_stage.emit_now( + "add", + "summary", + cube_id=self.cube_id, + sync_mode=sync_mode, + extract_mode=extract_mode, + input_msg_count=input_msg_count, + est_input_tokens=est_input_tokens, + memory_count=memory_count, + get_memory_ms=get_memory_ms, + write_db_ms=write_db_ms, + schedule_ms=schedule_ms, + total_ms=total_ms, + per_item_ms=total_ms // max(memory_count, 1), + ) # Format results uniformly text_memories = [ diff --git a/src/memos/plugins/__init__.py b/src/memos/plugins/__init__.py new file mode 100644 index 000000000..0a0f8cde3 --- /dev/null +++ b/src/memos/plugins/__init__.py @@ -0,0 +1,20 @@ +from memos.plugins.base import MemOSPlugin +from memos.plugins.hook_defs import H, HookSpec, all_hook_specs, define_hook, get_hook_spec +from memos.plugins.hooks import hookable, register_hook, register_hooks, trigger_hook +from memos.plugins.manager import PluginManager, plugin_manager + + +__all__ = [ + "H", + "HookSpec", + "MemOSPlugin", + "PluginManager", + "all_hook_specs", + "define_hook", + "get_hook_spec", + "hookable", + "plugin_manager", + "register_hook", + "register_hooks", + "trigger_hook", +] diff --git a/src/memos/plugins/base.py b/src/memos/plugins/base.py new file mode 100644 index 000000000..e486cfc72 --- /dev/null +++ b/src/memos/plugins/base.py @@ -0,0 +1,75 @@ +"""MemOS plugin base class — all plugins must inherit from this class.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from collections.abc import Callable + + from fastapi import FastAPI + from starlette.middleware.base import BaseHTTPMiddleware + + +class MemOSPlugin: + """MemOS plugin base class. + + Provides three unified registration methods. Plugin developers need only + inherit from this class and register capabilities via self.register_* + in init_app. 
+ """ + + name: str = "unnamed" + version: str = "0.0.0" + description: str = "" + + _app: FastAPI | None = None + + # ------------------------------------------------------------------ # + # Registration methods — called by plugins in init_app + # ------------------------------------------------------------------ # + + def register_router(self, router, **kwargs) -> None: + """Register a router.""" + self._app.include_router(router, **kwargs) + + def register_middleware(self, middleware_cls: type[BaseHTTPMiddleware], **kwargs) -> None: + """Register middleware.""" + self._app.add_middleware(middleware_cls, **kwargs) + + def register_hook(self, name: str, callback: Callable) -> None: + """Register a single Hook callback.""" + from memos.plugins.hooks import register_hook + + register_hook(name, callback) + + def register_hooks(self, names: list[str], callback: Callable) -> None: + """Batch-register the same callback to multiple Hook points.""" + from memos.plugins.hooks import register_hooks + + register_hooks(names, callback) + + # ------------------------------------------------------------------ # + # Internal methods — called by PluginManager, plugin developers need not care + # ------------------------------------------------------------------ # + + def _bind_app(self, app: FastAPI) -> None: + """Bind FastAPI instance so that register_* methods are available.""" + self._app = app + + # ------------------------------------------------------------------ # + # Lifecycle methods — override in subclasses + # ------------------------------------------------------------------ # + + def on_load(self) -> None: + """Called after the plugin is discovered. Used for initialization logic, e.g. checking dependencies, reading config.""" + + def init_app(self) -> None: + """Called after FastAPI app is bound. Register routes, middleware, and Hooks via self.register_* here.""" + + def init_components(self, context: dict) -> None: + """Called during server bootstrap to contribute runtime components.""" + + def on_shutdown(self) -> None: + """Called when the service shuts down. Used for resource cleanup.""" diff --git a/src/memos/plugins/component_bootstrap.py b/src/memos/plugins/component_bootstrap.py new file mode 100644 index 000000000..151e37215 --- /dev/null +++ b/src/memos/plugins/component_bootstrap.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from typing import Any + + +def build_plugin_context( + *, + graph_db: Any, + embedder: Any, + default_cube_config: Any, + nli_client_config: dict[str, Any], + mem_reader_config: Any, + reranker_config: Any, + feedback_reranker_config: Any, + internet_retriever_config: Any, +) -> dict[str, Any]: + return { + "shared": { + "graph_db": graph_db, + "embedder": embedder, + }, + "configs": { + "default_cube_config": default_cube_config, + "nli_client_config": nli_client_config, + "mem_reader_config": mem_reader_config, + "reranker_config": reranker_config, + "feedback_reranker_config": feedback_reranker_config, + "internet_retriever_config": internet_retriever_config, + }, + } diff --git a/src/memos/plugins/hook_defs.py b/src/memos/plugins/hook_defs.py new file mode 100644 index 000000000..030d5292f --- /dev/null +++ b/src/memos/plugins/hook_defs.py @@ -0,0 +1,132 @@ +"""Hook declaration registry — single source of truth for CE repo Hook points. + +The @hookable decorator automatically declares its before/after Hooks; no need to manually define_hook. +Hooks triggered by custom trigger_hook must be explicitly declared in this file. 
+ +Plugin-owned Hooks should be declared within each plugin package, not in this file. +""" + +from __future__ import annotations + +import logging + +from dataclasses import dataclass + + +logger = logging.getLogger(__name__) + +_specs: dict[str, HookSpec] = {} + + +@dataclass(frozen=True) +class HookSpec: + """Hook spec definition.""" + + name: str + description: str + params: list[str] + pipe_key: str | None = None + + +def define_hook( + name: str, + *, + description: str, + params: list[str], + pipe_key: str | None = None, +) -> None: + """Declare a Hook point. Skips if already exists (idempotent).""" + if name in _specs: + return + _specs[name] = HookSpec( + name=name, + description=description, + params=params, + pipe_key=pipe_key, + ) + logger.debug("Hook defined: %s (pipe_key=%s)", name, pipe_key) + + +def get_hook_spec(name: str) -> HookSpec | None: + return _specs.get(name) + + +def all_hook_specs() -> dict[str, HookSpec]: + """Return all declared Hooks (including @hookable auto-declared + plugin-declared).""" + return dict(_specs) + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# CE Hook name constants +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + + +class H: + """CE Hook name constants. Plugin-owned Hook constants should be defined within the plugin package.""" + + # @hookable("add") — AddHandler.handle_add_memories + ADD_BEFORE = "add.before" + ADD_AFTER = "add.after" + + # @hookable("search") — SearchHandler.handle_search_memories + SEARCH_BEFORE = "search.before" + SEARCH_AFTER = "search.after" + + # Custom Hook (manually triggered via trigger_hook) + ADD_MEMORIES_POST_PROCESS = "add.memories.post_process" + + # mem_reader — generic extension point before LLM extraction + MEM_READER_PRE_EXTRACT = "mem_reader.pre_extract" + + # memory version — single-provider business hooks + MEMORY_VERSION_PREPARE_UPDATES = "memory_version.prepare_updates" + MEMORY_VERSION_APPLY_UPDATES = "memory_version.apply_updates" + MEMORY_VERSION_APPLY_FEEDBACK_UPDATE = "memory_version.apply_feedback_update" + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# CE custom Hook declarations (@hookable-generated ones need not be declared here) +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +define_hook( + H.ADD_MEMORIES_POST_PROCESS, + description="Post-process result after add_memories returns, before constructing Response", + params=["request", "result"], + pipe_key="result", +) + +define_hook( + H.MEM_READER_PRE_EXTRACT, + description="Customize prompt before mem_reader LLM extraction", + params=["prompt", "prompt_type", "mem_str", "lang", "sources"], + pipe_key="prompt", +) + +define_hook( + H.MEMORY_VERSION_PREPARE_UPDATES, + description=( + "Prepare memory-version candidates and decide whether extraction should continue " + "through the version pipeline" + ), + params=["item", "user_name", "judge_llm"], +) + +define_hook( + H.MEMORY_VERSION_APPLY_UPDATES, + description="Apply memory-version updates during mem_reader extraction", + params=[ + "item", + "user_name", + "version_llm", + "merge_llm", + "custom_tags", + "custom_tags_prompt_template", + "timeout_sec", + ], +) + +define_hook( + H.MEMORY_VERSION_APPLY_FEEDBACK_UPDATE, + description="Apply memory-version update semantics during feedback update", + params=["old_item", "new_item", "user_name"], +) diff --git a/src/memos/plugins/hooks.py b/src/memos/plugins/hooks.py new file mode 100644 index 000000000..d5b64ce82 --- /dev/null +++ 
b/src/memos/plugins/hooks.py @@ -0,0 +1,150 @@ +"""Hook runtime — registration, triggering, and @hookable decorator.""" + +from __future__ import annotations + +import asyncio +import logging + +from collections import defaultdict +from functools import wraps +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from collections.abc import Callable + + +logger = logging.getLogger(__name__) + +_hooks: dict[str, list[Callable]] = defaultdict(list) + + +def register_hook(name: str, callback: Callable) -> None: + """Register a hook callback. Undeclared hook names will log a warning.""" + from memos.plugins.hook_defs import get_hook_spec + + if get_hook_spec(name) is None: + logger.warning( + "Registering callback for undeclared hook: %s (callback=%s)", + name, + getattr(callback, "__qualname__", repr(callback)), + ) + _hooks[name].append(callback) + logger.debug( + "Hook registered: %s -> %s", + name, + getattr(callback, "__qualname__", repr(callback)), + ) + + +def register_hooks(names: list[str], callback: Callable) -> None: + """Batch-register the same callback to multiple hook points.""" + for name in names: + register_hook(name, callback) + + +def trigger_hook(name: str, **kwargs: Any) -> Any: + """Trigger a hook, invoking all registered callbacks in order. + + - Zero overhead when no callbacks are registered + - Undeclared hook names will log a warning and be skipped + - pipe_key is auto-fetched from HookSpec, supports piped return value passing + """ + from memos.plugins.hook_defs import get_hook_spec + + spec = get_hook_spec(name) + if spec is None: + logger.warning("Undeclared hook triggered: %s — ignored", name) + return None + + pipe_key = spec.pipe_key + + for cb in _hooks.get(name, []): + try: + rv = cb(**kwargs) + if pipe_key is not None and rv is not None: + kwargs[pipe_key] = rv + except Exception: + logger.exception( + "Hook %s callback %s failed", + name, + getattr(cb, "__qualname__", repr(cb)), + ) + + return kwargs.get(pipe_key) if pipe_key else None + + +def trigger_single_hook(name: str, **kwargs: Any) -> Any: + """Trigger a hook that must be implemented by exactly one callback.""" + from memos.plugins.hook_defs import get_hook_spec + + spec = get_hook_spec(name) + if spec is None: + raise RuntimeError(f"Undeclared hook triggered: {name}") + + callbacks = _hooks.get(name, []) + if not callbacks: + raise RuntimeError(f"No plugin registered required hook: {name}") + if len(callbacks) > 1: + raise RuntimeError(f"Multiple plugins registered single-provider hook: {name}") + + cb = callbacks[0] + try: + return cb(**kwargs) + except Exception: + logger.exception( + "Single hook %s callback %s failed", + name, + getattr(cb, "__qualname__", repr(cb)), + ) + raise + + +def hookable(name: str): + """Decorator: automatically triggers name.before / name.after hook before and after the method. + + Auto-declares before/after Hooks (idempotent); no need to manually define_hook in hook_defs.py. + Supports piped return values: before can modify request, after can modify result. + Compatible with both sync and async methods. 
+ """ + from memos.plugins.hook_defs import define_hook + + define_hook( + f"{name}.before", + description=f"Before {name} executes; can modify request", + params=["request"], + pipe_key="request", + ) + define_hook( + f"{name}.after", + description=f"After {name} executes; can modify result", + params=["request", "result"], + pipe_key="result", + ) + + def decorator(func): + if asyncio.iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(self, request, *args, **kwargs): + rv = trigger_hook(f"{name}.before", request=request) + request = rv if rv is not None else request + result = await func(self, request, *args, **kwargs) + rv = trigger_hook(f"{name}.after", request=request, result=result) + result = rv if rv is not None else result + return result + + return async_wrapper + + @wraps(func) + def sync_wrapper(self, request, *args, **kwargs): + rv = trigger_hook(f"{name}.before", request=request) + request = rv if rv is not None else request + result = func(self, request, *args, **kwargs) + rv = trigger_hook(f"{name}.after", request=request, result=result) + result = rv if rv is not None else result + return result + + return sync_wrapper + + return decorator diff --git a/src/memos/plugins/manager.py b/src/memos/plugins/manager.py new file mode 100644 index 000000000..5f0397dc1 --- /dev/null +++ b/src/memos/plugins/manager.py @@ -0,0 +1,90 @@ +"""Plugin manager — discover, load, and manage MemOS plugins.""" + +from __future__ import annotations + +import importlib.metadata +import logging + +from typing import TYPE_CHECKING + +from memos.plugins.base import MemOSPlugin + + +if TYPE_CHECKING: + from fastapi import FastAPI + +logger = logging.getLogger(__name__) + +ENTRY_POINT_GROUP = "memos.plugins" + + +class PluginManager: + """Discover, load, and manage MemOS plugins.""" + + def __init__(self): + self._plugins: dict[str, MemOSPlugin] = {} + self._discovered = False + + @property + def plugins(self) -> dict[str, MemOSPlugin]: + return dict(self._plugins) + + def discover(self) -> None: + """Discover and load all installed plugins via entry_points.""" + if self._discovered: + return + + try: + eps = importlib.metadata.entry_points() + if hasattr(eps, "select"): + plugin_eps = eps.select(group=ENTRY_POINT_GROUP) + else: + plugin_eps = eps.get(ENTRY_POINT_GROUP, []) + except Exception: + logger.exception("Failed to query entry_points") + return + + for ep in plugin_eps: + try: + plugin_cls = ep.load() + plugin = plugin_cls() + if not isinstance(plugin, MemOSPlugin): + logger.warning("Plugin %s does not extend MemOSPlugin, skipped", ep.name) + continue + plugin.on_load() + self._plugins[plugin.name] = plugin + logger.info("Plugin discovered: %s v%s", plugin.name, plugin.version) + except Exception: + logger.exception("Failed to load plugin: %s", ep.name) + + self._discovered = True + + def init_components(self, context: dict) -> None: + """Initialize runtime components contributed by loaded plugins.""" + for plugin in self._plugins.values(): + try: + plugin.init_components(context) + logger.info("Plugin components initialized: %s", plugin.name) + except Exception: + logger.exception("Failed to init plugin components: %s", plugin.name) + + def init_app(self, app: FastAPI) -> None: + """Bind app and initialize all loaded plugins.""" + for plugin in self._plugins.values(): + try: + plugin._bind_app(app) + plugin.init_app() + logger.info("Plugin initialized: %s", plugin.name) + except Exception: + logger.exception("Failed to init plugin: %s", plugin.name) + + def shutdown(self) -> 
None: + """Shut down all plugins and release resources.""" + for plugin in self._plugins.values(): + try: + plugin.on_shutdown() + except Exception: + logger.exception("Failed to shutdown plugin: %s", plugin.name) + + +plugin_manager = PluginManager() diff --git a/src/memos/utils.py b/src/memos/utils.py index fd6d4eaf9..f7111f8ad 100644 --- a/src/memos/utils.py +++ b/src/memos/utils.py @@ -2,12 +2,138 @@ import time import traceback +from contextlib import ContextDecorator +from typing import Any + from memos.log import get_logger logger = get_logger(__name__) +class timed_stage(ContextDecorator): # noqa: N801 + """Unified timing helper for business-stage instrumentation. + + Works as **both** a context-manager and a decorator - one tool for all + timing needs. + + Context-manager (when the stage is a *code block* inside a function):: + + with timed_stage("add", "parse", cube_id=cube_id) as ts: + items = self._parse(...) + ts.set(msg_count=10, window_count=len(windows)) + + Decorator (when the stage is *an entire function*):: + + @timed_stage("add", "write_db") + def _write_to_db(self, ...): + ... + + Decorator with dynamic fields extracted from arguments:: + + @timed_stage("search", "recall", + extra=lambda self, req, **kw: {"cube_id": self.cube_id}) + def _vector_recall(self, req, ...): + ... + + Output format (SLS-friendly, one-line structured log):: + + [STAGE] biz=add stage=parse cube_id=xxx duration_ms=150 msg_count=10 + """ + + def __init__( + self, + biz: str = "", + stage: str = "", + *, + extra: dict[str, Any] | None = None, + level: str = "info", + **fields: Any, + ): + self._biz = biz + self._stage = stage + self._extra_factory = extra if callable(extra) else None + self._static_extra = extra if isinstance(extra, dict) else None + self._level = level + self._fields: dict[str, Any] = dict(fields) + self._start: float = 0.0 + self.duration_ms: int = 0 + + # -- context-manager protocol ------------------------------------------ + + def __enter__(self): + self._start = time.perf_counter() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.duration_ms = int((time.perf_counter() - self._start) * 1000) + self._emit(self.duration_ms, exc_type) + return False + + # -- decorator protocol (extends ContextDecorator) --------------------- + + def __call__(self, func=None): + if func is None: + return super().__call__(func) + + @functools.wraps(func) + def wrapper(*args, **kwargs): + if self._extra_factory is not None: + try: + dynamic = self._extra_factory(*args, **kwargs) + if dynamic: + self._fields.update(dynamic) + except Exception as e: + logger.warning("[STAGE] extra callback error: %r", e) + + stage_name = self._stage or func.__name__ + self._stage = stage_name + self._start = time.perf_counter() + try: + return func(*args, **kwargs) + finally: + self.duration_ms = int((time.perf_counter() - self._start) * 1000) + self._emit(self.duration_ms) + + return wrapper + + # -- public API -------------------------------------------------------- + + def set(self, **fields: Any): + """Add / overwrite fields after execution (e.g. counts only known after the block runs).""" + self._fields.update(fields) + + @staticmethod + def emit_now(biz: str, stage: str, **fields: Any): + """Fire a one-shot structured log without timing (e.g. 
diff --git a/src/memos/utils.py b/src/memos/utils.py
index fd6d4eaf9..f7111f8ad 100644
--- a/src/memos/utils.py
+++ b/src/memos/utils.py
@@ -2,12 +2,138 @@ import time
 import traceback
 
+from contextlib import ContextDecorator
+from typing import Any
+
 from memos.log import get_logger
 
 
 logger = get_logger(__name__)
 
 
+class timed_stage(ContextDecorator):  # noqa: N801
+    """Unified timing helper for business-stage instrumentation.
+
+    Works as **both** a context-manager and a decorator - one tool for all
+    timing needs.
+
+    Context-manager (when the stage is a *code block* inside a function)::
+
+        with timed_stage("add", "parse", cube_id=cube_id) as ts:
+            items = self._parse(...)
+            ts.set(msg_count=10, window_count=len(windows))
+
+    Decorator (when the stage is *an entire function*)::
+
+        @timed_stage("add", "write_db")
+        def _write_to_db(self, ...):
+            ...
+
+    Decorator with dynamic fields extracted from arguments::
+
+        @timed_stage("search", "recall",
+                     extra=lambda self, req, **kw: {"cube_id": self.cube_id})
+        def _vector_recall(self, req, ...):
+            ...
+
+    Output format (SLS-friendly, one-line structured log)::
+
+        [STAGE] biz=add stage=parse cube_id=xxx duration_ms=150 msg_count=10
+    """
+
+    def __init__(
+        self,
+        biz: str = "",
+        stage: str = "",
+        *,
+        extra: dict[str, Any] | None = None,
+        level: str = "info",
+        **fields: Any,
+    ):
+        self._biz = biz
+        self._stage = stage
+        self._extra_factory = extra if callable(extra) else None
+        self._static_extra = extra if isinstance(extra, dict) else None
+        self._level = level
+        self._fields: dict[str, Any] = dict(fields)
+        self._start: float = 0.0
+        self.duration_ms: int = 0
+
+    # -- context-manager protocol ------------------------------------------
+
+    def __enter__(self):
+        self._start = time.perf_counter()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.duration_ms = int((time.perf_counter() - self._start) * 1000)
+        self._emit(self.duration_ms, exc_type)
+        return False
+
+    # -- decorator protocol (extends ContextDecorator) ---------------------
+
+    def __call__(self, func=None):
+        if func is None:
+            return super().__call__(func)
+
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            if self._extra_factory is not None:
+                try:
+                    dynamic = self._extra_factory(*args, **kwargs)
+                    if dynamic:
+                        self._fields.update(dynamic)
+                except Exception as e:
+                    logger.warning("[STAGE] extra callback error: %r", e)
+
+            stage_name = self._stage or func.__name__
+            self._stage = stage_name
+            self._start = time.perf_counter()
+            try:
+                return func(*args, **kwargs)
+            finally:
+                self.duration_ms = int((time.perf_counter() - self._start) * 1000)
+                self._emit(self.duration_ms)
+
+        return wrapper
+
+    # -- public API --------------------------------------------------------
+
+    def set(self, **fields: Any):
+        """Add / overwrite fields after execution (e.g. counts only known after the block runs)."""
+        self._fields.update(fields)
+
+    @staticmethod
+    def emit_now(biz: str, stage: str, **fields: Any):
+        """Fire a one-shot structured log without timing (e.g. summary rollups)."""
+        parts = [f"biz={biz}", f"stage={stage}"]
+        for k, v in fields.items():
+            parts.append(f"{k}={v}")
+        logger.info("[STAGE] " + " ".join(parts))
+
+    # -- internals ---------------------------------------------------------
+
+    def _emit(self, duration_ms: int, exc_type=None):
+        parts: list[str] = []
+        if self._biz:
+            parts.append(f"biz={self._biz}")
+        if self._stage:
+            parts.append(f"stage={self._stage}")
+        parts.append(f"duration_ms={duration_ms}")
+
+        if self._static_extra:
+            self._fields.update(self._static_extra)
+
+        for k, v in self._fields.items():
+            parts.append(f"{k}={v}")
+
+        if exc_type is not None:
+            parts.append(f"error={exc_type.__name__}")
+
+        msg = "[STAGE] " + " ".join(parts)
+        getattr(logger, self._level, logger.info)(msg)
+
+
 def timed_with_status(
     func=None,
     *,
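The intended composition shows up in the add-path tests later in this change: each stage logs its own `[STAGE]` line, and the caller reads `duration_ms` off each exited context to build a final `summary` rollup via `emit_now`. A condensed sketch of that pattern; the pipeline body and field names are illustrative:

```python
from memos.utils import timed_stage


def add_pipeline(messages):
    with timed_stage("add", "get_memory", cube_id="cube_test") as t1:
        items = list(messages)  # stand-in for extraction work
    with timed_stage("add", "write_db", memory_count=len(items)) as t2:
        pass  # stand-in for the DB write
    # One-shot rollup: no duration of its own, only aggregated fields.
    timed_stage.emit_now(
        "add",
        "summary",
        memory_count=len(items),
        get_memory_ms=t1.duration_ms,
        write_db_ms=t2.duration_ms,
        total_ms=t1.duration_ms + t2.duration_ms,
    )


add_pipeline([{"role": "user", "content": "hi"}])
# Emits, e.g.:
# [STAGE] biz=add stage=get_memory duration_ms=0 cube_id=cube_test
# [STAGE] biz=add stage=write_db duration_ms=0 memory_count=1
# [STAGE] biz=add stage=summary memory_count=1 get_memory_ms=0 write_db_ms=0 total_ms=0
```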
diff --git a/tests/extras/nli_model/test_client_integration.py b/tests/extras/nli_model/test_client_integration.py
deleted file mode 100644
index 5beff14a0..000000000
--- a/tests/extras/nli_model/test_client_integration.py
+++ /dev/null
@@ -1,129 +0,0 @@
-import threading
-import time
-import unittest
-
-from unittest.mock import MagicMock, patch
-
-import requests
-import uvicorn
-
-from memos.extras.nli_model.client import NLIClient
-from memos.extras.nli_model.server.serve import app
-from memos.extras.nli_model.types import NLIResult
-
-
-# We need to mock the NLIHandler to avoid loading the heavy model
-# but we want to run the real FastAPI server.
-class TestNLIClientIntegration(unittest.TestCase):
-    server_thread = None
-    stop_server = False
-    port = 32533  # Use a different port for testing
-
-    @classmethod
-    def setUpClass(cls):
-        # Patch the lifespan to inject a mock handler instead of real NLIHandler
-        cls.mock_handler = MagicMock()
-        cls.mock_handler.compare_one_to_many.return_value = [
-            NLIResult.DUPLICATE,
-            NLIResult.CONTRADICTION,
-        ]
-
-        # We need to patch the module where lifespan is defined/used or modify the global variable
-        # Since 'app' is already imported, we can patch the global nli_handler in serve.py
-        # But lifespan sets it on startup.
-
-        # Let's patch NLIHandler class in serve.py so when lifespan instantiates it, it gets our mock
-        cls.handler_patcher = patch("memos.extras.nli_model.server.serve.NLIHandler")
-        cls.MockHandlerClass = cls.handler_patcher.start()
-        cls.MockHandlerClass.return_value = cls.mock_handler
-
-        # Start server in a thread
-        def run_server():
-            # Disable logs for uvicorn to keep test output clean
-            config = uvicorn.Config(app, host="127.0.0.1", port=cls.port, log_level="error")
-            cls.server = uvicorn.Server(config)
-            cls.server.run()
-
-        cls.server_thread = threading.Thread(target=run_server, daemon=True)
-        cls.server_thread.start()
-
-        # Wait for server to be ready
-        cls._wait_for_server()
-
-    @classmethod
-    def tearDownClass(cls):
-        # Stop the server
-        if hasattr(cls, "server"):
-            cls.server.should_exit = True
-        if cls.server_thread:
-            cls.server_thread.join(timeout=5)
-
-        cls.handler_patcher.stop()
-
-    @classmethod
-    def _wait_for_server(cls):
-        url = f"http://127.0.0.1:{cls.port}/docs"
-        retries = 20
-        for _ in range(retries):
-            try:
-                response = requests.get(url)
-                if response.status_code == 200:
-                    return
-            except requests.ConnectionError:
-                pass
-            time.sleep(0.1)
-        raise RuntimeError("Server failed to start")
-
-    def setUp(self):
-        self.client = NLIClient(base_url=f"http://127.0.0.1:{self.port}")
-        # Reset mock calls before each test
-        self.mock_handler.reset_mock()
-        # Ensure default behavior
-        self.mock_handler.compare_one_to_many.return_value = [
-            NLIResult.DUPLICATE,
-            NLIResult.CONTRADICTION,
-        ]
-
-    def test_real_server_compare_one_to_many(self):
-        source = "I like apples."
-        targets = ["I love fruit.", "I hate apples."]
-
-        results = self.client.compare_one_to_many(source, targets)
-
-        # Verify result
-        self.assertEqual(len(results), 2)
-        self.assertEqual(results[0], NLIResult.DUPLICATE)
-        self.assertEqual(results[1], NLIResult.CONTRADICTION)
-
-        # Verify server received the request
-        self.mock_handler.compare_one_to_many.assert_called_once()
-        args, _ = self.mock_handler.compare_one_to_many.call_args
-        self.assertEqual(args[0], source)
-        self.assertEqual(args[1], targets)
-
-    def test_real_server_empty_targets(self):
-        source = "I like apples."
-        targets = []
-
-        results = self.client.compare_one_to_many(source, targets)
-
-        self.assertEqual(results, [])
-        # Should not call handler because client handles empty list
-        self.mock_handler.compare_one_to_many.assert_not_called()
-
-    def test_real_server_handler_error(self):
-        # Simulate handler error
-        self.mock_handler.compare_one_to_many.side_effect = ValueError("Something went wrong")
-
-        source = "I like apples."
-        targets = ["I love fruit."]
-
-        # Client should catch 500 and return UNRELATED
-        results = self.client.compare_one_to_many(source, targets)
-
-        self.assertEqual(len(results), 1)
-        self.assertEqual(results[0], NLIResult.UNRELATED)
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/tests/mem_reader/test_project_id_propagation.py b/tests/mem_reader/test_project_id_propagation.py
index 5a17910ca..bf55aca46 100644
--- a/tests/mem_reader/test_project_id_propagation.py
+++ b/tests/mem_reader/test_project_id_propagation.py
@@ -53,6 +53,7 @@ def _make_fast_item(
     manager_user_id: str | None = MANAGER_USER_ID,
     project_id: str | None = PROJECT_ID,
     role: str = "user",
+    internal_info: dict | None = None,
 ) -> TextualMemoryItem:
     return TextualMemoryItem(
         memory=memory,
@@ -63,6 +64,7 @@
             sources=[SourceMessage(type="chat", role=role, content=memory)],
             manager_user_id=manager_user_id,
             project_id=project_id,
+            internal_info=internal_info,
         ),
     )
@@ -216,6 +218,8 @@ def setUp(self):
         self.reader.graph_db = MagicMock()
         self.reader.oss_config = None
         self.reader.skills_dir_config = None
+        self.reader.memory_version_switch = "off"
+        self.reader.qwen_llm = MagicMock()
 
     # -- _build_window_from_items --------------------------------------------
     def test_build_window_propagates_project_id(self):
@@ -255,6 +259,50 @@ def test_build_window_picks_first_nonempty(self):
         self.assertIsNotNone(result)
         _assert_fields(self, result)
 
+    def test_split_large_memory_item_assigns_shared_ingest_batch_id(self):
+        self.reader._count_tokens = MagicMock(return_value=999)
+        self.reader.chunker.chunk.return_value = ["chunk one", "chunk two"]
+
+        def fake_make_memory_item(
+            *,
+            value,
+            info,
+            memory_type,
+            tags,
+            key,
+            sources,
+            background,
+            need_embed,
+        ):
+            return TextualMemoryItem(
+                memory=value,
+                metadata=TreeNodeTextualMemoryMetadata(
+                    user_id=info["user_id"],
+                    session_id=info["session_id"],
+                    memory_type=memory_type,
+                    tags=tags,
+                    key=key,
+                    sources=sources,
+                    background=background,
+                ),
+            )
+
+        self.reader._make_memory_item = MagicMock(side_effect=fake_make_memory_item)
+        source_item = _make_fast_item("very long source", internal_info={"origin": "doc"})
+
+        result = self.reader._split_large_memory_item(source_item, max_tokens=10)
+
+        self.assertEqual(len(result), 2)
+        batch_ids = {
+            item.metadata.internal_info["ingest_batch_id"]
+            for item in result
+            if item.metadata.internal_info and item.metadata.internal_info.get("ingest_batch_id")
+        }
+        self.assertEqual(len(batch_ids), 1)
+        self.assertEqual({item.metadata.internal_info["chunk_index"] for item in result}, {0, 1})
+        self.assertEqual({item.metadata.internal_info["chunk_total"] for item in result}, {2})
+        self.assertEqual({item.metadata.internal_info["origin"] for item in result}, {"doc"})
+
     # -- _process_string_fine ------------------------------------------------
     def test_process_string_fine_propagates_fields(self):
         """Fine string extraction must carry project_id/manager_user_id
diff --git a/tests/memories/textual/test_history_manager.py b/tests/memories/textual/test_history_manager.py
deleted file mode 100644
index a6ac186b7..000000000
--- a/tests/memories/textual/test_history_manager.py
+++ /dev/null
@@ -1,137 +0,0 @@
-import uuid
-
-from unittest.mock import MagicMock
-
-import pytest
-
-from memos.extras.nli_model.client import NLIClient
-from memos.extras.nli_model.types import NLIResult
-from memos.graph_dbs.base import BaseGraphDB
-from memos.memories.textual.item import (
-    TextualMemoryItem,
-    TextualMemoryMetadata,
-)
-from memos.memories.textual.tree_text_memory.organize.history_manager import (
-    MemoryHistoryManager,
-    _append_related_content,
-    _detach_related_content,
-)
-
-
-@pytest.fixture
-def mock_nli_client():
-    client = MagicMock(spec=NLIClient)
-    return client
-
-
-@pytest.fixture
-def mock_graph_db():
-    return MagicMock(spec=BaseGraphDB)
-
-
-@pytest.fixture
-def history_manager(mock_nli_client, mock_graph_db):
-    return MemoryHistoryManager(nli_client=mock_nli_client, graph_db=mock_graph_db)
-
-
-def test_detach_related_content():
-    original_memory = "This is the original memory content."
-    item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata())
-
-    duplicates = ["Duplicate 1", "Duplicate 2"]
-    conflicts = ["Conflict 1", "Conflict 2"]
-
-    # 1. Append content
-    _append_related_content(item, duplicates, conflicts)
-
-    # Verify content was appended
-    assert item.memory != original_memory
-    assert "[possibly conflicting memories]" in item.memory
-    assert "[possibly duplicate memories]" in item.memory
-    assert "Duplicate 1" in item.memory
-    assert "Conflict 1" in item.memory
-
-    # 2. Detach content
-    _detach_related_content(item)
-
-    # 3. Verify content is restored
-    assert item.memory == original_memory
-
-
-def test_detach_only_conflicts():
-    original_memory = "Original memory."
-    item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata())
-
-    duplicates = []
-    conflicts = ["Conflict A"]
-
-    _append_related_content(item, duplicates, conflicts)
-    assert "Conflict A" in item.memory
-    assert "Duplicate" not in item.memory
-
-    _detach_related_content(item)
-    assert item.memory == original_memory
-
-
-def test_detach_only_duplicates():
-    original_memory = "Original memory."
-    item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata())
-
-    duplicates = ["Duplicate A"]
-    conflicts = []
-
-    _append_related_content(item, duplicates, conflicts)
-    assert "Duplicate A" in item.memory
-    assert "Conflict" not in item.memory
-
-    _detach_related_content(item)
-    assert item.memory == original_memory
-
-
-def test_truncation(history_manager, mock_nli_client):
-    # Setup
-    new_item = TextualMemoryItem(memory="Test")
-    long_memory = "A" * 300
-    related_item = TextualMemoryItem(memory=long_memory)
-
-    mock_nli_client.compare_one_to_many.return_value = [NLIResult.DUPLICATE]
-
-    # Action
-    history_manager.resolve_history_via_nli(new_item, [related_item])
-
-    # Assert
-    assert "possibly duplicate memories" in new_item.memory
-    assert "..." in new_item.memory  # Should be truncated
-    assert len(new_item.memory) < 1000  # Ensure reasonable length
-
-
-def test_empty_related_items(history_manager, mock_nli_client):
-    new_item = TextualMemoryItem(memory="Test")
-    history_manager.resolve_history_via_nli(new_item, [])
-
-    mock_nli_client.compare_one_to_many.assert_not_called()
-    assert new_item.metadata.history is None or len(new_item.metadata.history) == 0
-
-
-def test_mark_memory_status(history_manager, mock_graph_db):
-    # Setup
-    id1 = uuid.uuid4().hex
-    id2 = uuid.uuid4().hex
-    id3 = uuid.uuid4().hex
-    items = [
-        TextualMemoryItem(memory="M1", id=id1),
-        TextualMemoryItem(memory="M2", id=id2),
-        TextualMemoryItem(memory="M3", id=id3),
-    ]
-    status = "resolving"
-
-    # Action
-    history_manager.mark_memory_status(items, status)
-
-    # Assert
-    assert mock_graph_db.update_node.call_count == 3
-
-    # Verify we called it correctly (user_name=None is passed by mark_memory_status)
-    mock_graph_db.update_node.assert_any_call(id=id1, fields={"status": status}, user_name=None)
-    mock_graph_db.update_node.assert_any_call(id=id2, fields={"status": status}, user_name=None)
-    mock_graph_db.update_node.assert_any_call(id=id3, fields={"status": status}, user_name=None)
diff --git a/tests/memories/textual/test_pre_update_retriever.py b/tests/memories/textual/test_pre_update_retriever.py
deleted file mode 100644
index 6bed90abb..000000000
--- a/tests/memories/textual/test_pre_update_retriever.py
+++ /dev/null
@@ -1,150 +0,0 @@
-import unittest
-import uuid
-
-from dotenv import load_dotenv
-
-from memos.api.handlers.config_builders import build_embedder_config, build_graph_db_config
-from memos.embedders.factory import EmbedderFactory
-from memos.graph_dbs.factory import GraphStoreFactory
-from memos.memories.textual.item import (
-    SourceMessage,
-    TextualMemoryItem,
-    TreeNodeTextualMemoryMetadata,
-)
-from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever
-
-
-# Load environment variables
-load_dotenv()
-
-
-class TestPreUpdateRecaller(unittest.TestCase):
-    @classmethod
-    def setUpClass(cls):
-        # Initialize graph_db and embedder using factories
-        # We assume environment variables are set for these to work
-        try:
-            cls.graph_db_config = build_graph_db_config()
-            cls.graph_db = GraphStoreFactory.from_config(cls.graph_db_config)
-
-            cls.embedder_config = build_embedder_config()
-            cls.embedder = EmbedderFactory.from_config(cls.embedder_config)
-        except Exception as e:
-            raise unittest.SkipTest(
-                f"Skipping test because initialization failed (likely missing env vars): {e}"
-            ) from e
-
-        cls.recaller = PreUpdateRetriever(cls.graph_db, cls.embedder)
-
-        # Use a unique user name to isolate tests
-        cls.user_name = "test_pre_update_recaller_user_" + str(uuid.uuid4())[:8]
-
-    def setUp(self):
-        # Add some data to the db
-        self.added_ids = []
-
-        # Create a memory item to add
-        self.memory_text = "The user likes to eat apples."
-        self.embedding = self.embedder.embed([self.memory_text])[0]
-
-        # We use dictionary for metadata to simulate what might be passed or stored
-        # But wait, add_node expects metadata as a dict usually.
-        metadata = {
-            "memory_type": "LongTermMemory",
-            "status": "activated",
-            "embedding": self.embedding,
-            "created_at": "2023-01-01T00:00:00",
-            "updated_at": "2023-01-01T00:00:00",
-            "tags": ["food", "fruit"],
-            "key": "user_preference",
-            "sources": [],
-        }
-
-        node_id = str(uuid.uuid4())
-        self.graph_db.add_node(node_id, self.memory_text, metadata, user_name=self.user_name)
-        self.added_ids.append(node_id)
-
-        # Add another one
-        self.memory_text_2 = "The user has a dog named Rex."
-        self.embedding_2 = self.embedder.embed([self.memory_text_2])[0]
-        metadata_2 = {
-            "memory_type": "LongTermMemory",
-            "status": "activated",
-            "embedding": self.embedding_2,
-            "created_at": "2023-01-01T00:00:00",
-            "updated_at": "2023-01-01T00:00:00",
-            "tags": ["pet", "dog"],
-            "key": "user_pet",
-            "sources": [],
-        }
-        node_id_2 = str(uuid.uuid4())
-        self.graph_db.add_node(node_id_2, self.memory_text_2, metadata_2, user_name=self.user_name)
-        self.added_ids.append(node_id_2)
-
-    def tearDown(self):
-        """Clean up test data."""
-        for node_id in self.added_ids:
-            try:
-                self.graph_db.delete_node(node_id, user_name=self.user_name)
-            except Exception as e:
-                print(f"Error deleting node {node_id}: {e}")
-
-    def test_recall_vector_search(self):
-        """Test recalling using vector search (implicit in recall method)."""
-        # "I like apples" -> perspective adjustment should match "The user likes to eat apples"
-        query_text = "I like apples"
-
-        # Create metadata with source to trigger perspective adjustment
-        # role="user" means "I" -> "User"
-        source = SourceMessage(role="user", lang="en")
-        metadata = TreeNodeTextualMemoryMetadata(sources=[source], memory_type="WorkingMemory")
-
-        item = TextualMemoryItem(memory=query_text, metadata=metadata)
-
-        # The recall method does both vector and keyword search
-        results = self.recaller.retrieve(item, self.user_name, top_k=5)
-
-        # Verify we got results
-        self.assertTrue(len(results) > 0, "Should return at least one result")
-        found_texts = [r.memory for r in results]
-
-        # Check if the relevant memory is found
-        # "The user likes to eat apples." should be found.
-        # We check for "apples" to be safe
-        self.assertTrue(
-            any("apples" in t for t in found_texts),
-            f"Expected 'apples' in results, got: {found_texts}",
-        )
-
-    def test_recall_keyword_search(self):
-        """Test recalling where keyword search might be more relevant."""
-        # "Rex" is a specific name
-        query_text = "What is the name of my dog?"
-        source = SourceMessage(role="user", lang="en")
-        metadata = TreeNodeTextualMemoryMetadata(sources=[source], memory_type="WorkingMemory")
-
-        item = TextualMemoryItem(memory=query_text, metadata=metadata)
-
-        results = self.recaller.retrieve(item, self.user_name, top_k=5)
-
-        found_texts = [r.memory for r in results]
-        self.assertTrue(
-            any("Rex" in t for t in found_texts), f"Expected 'Rex' in results, got: {found_texts}"
-        )
-
-    def test_perspective_adjustment(self):
-        """Unit test for the _adjust_perspective method specifically."""
-        text = "I went to the store myself."
-        adjusted = self.recaller._adjust_perspective(text, "user", "en")
-        # I -> User, myself -> User himself
-        self.assertIn("User", adjusted)
-        self.assertIn("User himself", adjusted)
-
-        text_zh = "我喜欢吃苹果"
-        adjusted_zh = self.recaller._adjust_perspective(text_zh, "user", "zh")
-        # 我 -> 用户
-        self.assertIn("用户", adjusted_zh)
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/tests/memories/textual/test_pre_update_retriever_latency.py b/tests/memories/textual/test_pre_update_retriever_latency.py
deleted file mode 100644
index f4a359de9..000000000
--- a/tests/memories/textual/test_pre_update_retriever_latency.py
+++ /dev/null
@@ -1,183 +0,0 @@
-import time
-import unittest
-import uuid
-
-import numpy as np
-
-from dotenv import load_dotenv
-
-from memos.api.handlers.config_builders import build_embedder_config, build_graph_db_config
-from memos.embedders.factory import EmbedderFactory
-from memos.graph_dbs.factory import GraphStoreFactory
-from memos.memories.textual.item import (
-    SourceMessage,
-    TextualMemoryItem,
-    TreeNodeTextualMemoryMetadata,
-)
-from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever
-
-
-# Load environment variables
-load_dotenv()
-
-
-class TestPreUpdateRecallerLatency(unittest.TestCase):
-    """
-    Performance and latency tests for PreUpdateRetriever.
-    These tests are designed to measure latency and might take longer to run.
-    """
-
-    @classmethod
-    def setUpClass(cls):
-        # Initialize graph_db and embedder using factories
-        try:
-            cls.graph_db_config = build_graph_db_config()
-            cls.graph_db = GraphStoreFactory.from_config(cls.graph_db_config)
-
-            cls.embedder_config = build_embedder_config()
-            cls.embedder = EmbedderFactory.from_config(cls.embedder_config)
-        except Exception as e:
-            raise unittest.SkipTest(
-                f"Skipping test because initialization failed (likely missing env vars): {e}"
-            ) from e
-
-        cls.recaller = PreUpdateRetriever(cls.graph_db, cls.embedder)
-
-        # Use a unique user name to isolate tests
-        cls.user_name = "test_pre_update_recaller_latency_user_" + str(uuid.uuid4())[:8]
-
-    def setUp(self):
-        # Add a substantial amount of data for latency testing
-        self.added_ids = []
-        self.num_items = 20
-
-        print(f"\nPopulating database with {self.num_items} items for latency test...")
-        for i in range(self.num_items):
-            text = f"This is memory item number {i}. The user might enjoy topic {i % 5}."
-            embedding = self.embedder.embed([text])[0]
-            metadata = {
-                "memory_type": "LongTermMemory",
-                "status": "activated",
-                "embedding": embedding,
-                "created_at": "2023-01-01T00:00:00",
-                "updated_at": "2023-01-01T00:00:00",
-                "tags": [f"tag_{i}"],
-                "key": f"key_{i}",
-                "sources": [],
-            }
-            node_id = str(uuid.uuid4())
-            self.graph_db.add_node(node_id, text, metadata, user_name=self.user_name)
-            self.added_ids.append(node_id)
-
-    def tearDown(self):
-        """Clean up test data."""
-        print("Cleaning up test data...")
-        for node_id in self.added_ids:
-            try:
-                self.graph_db.delete_node(node_id, user_name=self.user_name)
-            except Exception as e:
-                print(f"Error deleting node {node_id}: {e}")
-
-    def measure_network_rtt(self, trials=10):
-        """Measure average network round-trip time."""
-        print(f"Measuring Network RTT (using {trials} probes)...")
-        latencies = []
-
-        # Try to use raw driver for minimal overhead if available (Neo4j specific)
-        if hasattr(self.graph_db, "driver") and hasattr(self.graph_db, "db_name"):
-            print("Using Neo4j driver for direct ping...")
-            try:
-                with self.graph_db.driver.session(database=self.graph_db.db_name) as session:
-                    # Warmup
-                    session.run("RETURN 1").single()
-
-                    for _ in range(trials):
-                        start = time.time()
-                        session.run("RETURN 1").single()
-                        latencies.append((time.time() - start) * 1000)
-            except Exception as e:
-                print(f"Direct driver ping failed: {e}. Falling back to get_node.")
-                latencies = []
-
-        if not latencies:
-            # Fallback to get_node with non-existent ID
-            print("Using get_node for ping...")
-            for _ in range(trials):
-                probe_id = str(uuid.uuid4())
-                start = time.time()
-                self.graph_db.get_node(probe_id, user_name=self.user_name)
-                latencies.append((time.time() - start) * 1000)
-
-        avg_rtt = np.mean(latencies)
-        print(f"Average Network RTT: {avg_rtt:.2f} ms")
-        return avg_rtt
-
-    def test_recall_latency(self):
-        """Test and report recall latency statistics."""
-        avg_rtt = self.measure_network_rtt()
-
-        queries = [
-            "I enjoy topic 1",
-            "What about topic 3?",
-            "Do I have any preferences?",
-            "Tell me about memory item 5",
-        ]
-
-        latencies = []
-
-        # Warmup
-        print("Warming up...")
-        warmup_item = TextualMemoryItem(
-            memory="warmup query",
-            metadata=TreeNodeTextualMemoryMetadata(
-                sources=[SourceMessage(role="user", lang="en")], memory_type="WorkingMemory"
-            ),
-        )
-        self.recaller.retrieve(warmup_item, self.user_name, top_k=5)
-
-        print(f"Running {len(queries)} queries...")
-        for q in queries:
-            # Pre-calculate embedding to exclude from latency measurement
-            q_embedding = self.embedder.embed([q])[0]
-
-            item = TextualMemoryItem(
-                memory=q,
-                metadata=TreeNodeTextualMemoryMetadata(
-                    sources=[SourceMessage(role="user", lang="en")],
-                    memory_type="WorkingMemory",
-                    embedding=q_embedding,
-                ),
-            )
-
-            start_time = time.time()
-            results = self.recaller.retrieve(item, self.user_name, top_k=5)
-            end_time = time.time()
-
-            duration_ms = (end_time - start_time) * 1000
-            latencies.append(duration_ms)
-            print(f"Query: '{q}' -> Found {len(results)} results in {duration_ms:.2f} ms")
-
-            # Assert that we actually found results (sanity check)
-            if "preferences" not in q:  # The preferences query might return 0
-                self.assertTrue(len(results) > 0, f"Expected results for query: {q}")
-
-        # Report Results
-        avg_latency = np.mean(latencies)
-        p95_latency = np.percentile(latencies, 95)
-        min_latency = np.min(latencies)
-        max_latency = np.max(latencies)
-        internal_processing = avg_latency - avg_rtt
-
-        print("\n--- Latency Results ---")
-        print(f"Average Network RTT: {avg_rtt:.2f} ms")
-        print(f"Average Total Latency: {avg_latency:.2f} ms")
-        print(f"Estimated Internal Processing: {internal_processing:.2f} ms")
-        print(f"95th Percentile: {p95_latency:.2f} ms")
-        print(f"Min Latency: {min_latency:.2f} ms")
-        print(f"Max Latency: {max_latency:.2f} ms")
-
-        self.assertLess(internal_processing, 200, "Internal processing should be under 200ms")
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/src/memos/extras/nli_model/__init__.py b/tests/plugins/__init__.py
similarity index 100%
rename from src/memos/extras/nli_model/__init__.py
rename to tests/plugins/__init__.py
diff --git a/tests/plugins/conftest.py b/tests/plugins/conftest.py
new file mode 100644
index 000000000..6a1a16b68
--- /dev/null
+++ b/tests/plugins/conftest.py
@@ -0,0 +1,17 @@
+"""Ensure @hookable-generated hooks are declared for core framework tests.
+
+In production, @hookable("add") runs at import time of add_handler.py,
+declaring add.before / add.after. Core framework tests don't import handler
+modules (to avoid heavy dependencies), so we trigger declarations here.
+
+Plugin-specific hooks are declared in each plugin's own tests/conftest.py.
+"""
+
+from memos.plugins.hooks import hookable
+
+
+hookable("add")
+hookable("search")
+hookable("chat")
+hookable("feedback")
+hookable("memory.get")
diff --git a/src/memos/extras/nli_model/server/__init__.py b/tests/plugins/run_plugin_server.py
similarity index 100%
rename from src/memos/extras/nli_model/server/__init__.py
rename to tests/plugins/run_plugin_server.py
diff --git a/tests/plugins/test_plugin_demo.py b/tests/plugins/test_plugin_demo.py
new file mode 100644
index 000000000..0eb65b208
--- /dev/null
+++ b/tests/plugins/test_plugin_demo.py
@@ -0,0 +1,474 @@
+"""
+Plugin system core framework tests.
+
+Covers generic capabilities of the memos.plugins package (independent of specific plugin implementations):
+1. Hook declaration registry (hook_defs)
+2. Hook registration and triggering / pipe_key pipeline return value
+3. @hookable decorator (sync + async + auto-declaration + pipeline return value)
+4. MemOSPlugin base class register_* methods
+
+Plugin-specific functional tests are located in each plugin package:
+    extensions/memos_demo_plugin/tests/
+"""
+
+import asyncio
+import logging
+
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+
+# ========================================================================= #
+# 1. Hook declaration registry (hook_defs)
+# ========================================================================= #
+
+
+class TestHookDefs:
+    def test_define_hook_and_get_spec(self):
+        from memos.plugins.hook_defs import define_hook, get_hook_spec
+
+        define_hook(
+            "test.custom.hook",
+            description="test hook",
+            params=["request", "result"],
+            pipe_key="result",
+        )
+
+        spec = get_hook_spec("test.custom.hook")
+        assert spec is not None
+        assert spec.name == "test.custom.hook"
+        assert spec.params == ["request", "result"]
+        assert spec.pipe_key == "result"
+
+    def test_define_hook_is_idempotent(self):
+        from memos.plugins.hook_defs import define_hook, get_hook_spec
+
+        define_hook("test.idempotent", description="first", params=["a"], pipe_key="a")
+        define_hook("test.idempotent", description="second", params=["b"], pipe_key="b")
+
+        spec = get_hook_spec("test.idempotent")
+        assert spec.description == "first"
+
+    def test_get_hook_spec_returns_none_for_unknown(self):
+        from memos.plugins.hook_defs import get_hook_spec
+
+        assert get_hook_spec("definitely.does.not.exist") is None
+
+    def test_all_hook_specs_includes_custom(self):
+        from memos.plugins.hook_defs import H, all_hook_specs
+
+        specs = all_hook_specs()
+        assert H.ADD_MEMORIES_POST_PROCESS in specs
+
+    def test_h_constants(self):
+        from memos.plugins.hook_defs import H
+
+        assert H.ADD_BEFORE == "add.before"
+        assert H.ADD_AFTER == "add.after"
+        assert H.SEARCH_BEFORE == "search.before"
+        assert H.SEARCH_AFTER == "search.after"
+        assert H.ADD_MEMORIES_POST_PROCESS == "add.memories.post_process"
+
+
+# ========================================================================= #
+# 2. Hook registration and triggering / pipe_key pipeline return value
+# ========================================================================= #
+
+
+class TestHookMechanism:
+    def setup_method(self):
+        from memos.plugins.hooks import _hooks
+
+        _hooks.clear()
+
+    def test_register_and_trigger(self):
+        from memos.plugins.hooks import register_hook, trigger_hook
+
+        captured = {}
+
+        def my_callback(*, request, **kwargs):
+            captured["request"] = request
+
+        register_hook("add.before", my_callback)
+        trigger_hook("add.before", request="test_request")
+
+        assert captured["request"] == "test_request"
+
+    def test_register_hooks_batch(self):
+        from memos.plugins.hooks import register_hooks, trigger_hook
+
+        call_count = 0
+
+        def my_callback(**kwargs):
+            nonlocal call_count
+            call_count += 1
+
+        register_hooks(["add.before", "search.before"], my_callback)
+        trigger_hook("add.before")
+        trigger_hook("search.before")
+
+        assert call_count == 2
+
+    def test_trigger_undeclared_hook_returns_none(self):
+        from memos.plugins.hooks import trigger_hook
+
+        result = trigger_hook("nonexistent.undeclared.hook", request="anything")
+        assert result is None
+
+    def test_hook_exception_does_not_propagate(self):
+        from memos.plugins.hook_defs import define_hook
+        from memos.plugins.hooks import register_hook, trigger_hook
+
+        define_hook("test.exception", description="test", params=["x"])
+
+        results = []
+
+        def bad_callback(**kwargs):
+            raise ValueError("intentional error")
+
+        def good_callback(**kwargs):
+            results.append("ok")
+
+        register_hook("test.exception", bad_callback)
+        register_hook("test.exception", good_callback)
+        trigger_hook("test.exception", x=1)
+
+        assert results == ["ok"]
+
+    def test_trigger_hook_pipe_key_returns_modified_value(self):
+        from memos.plugins.hook_defs import define_hook
+        from memos.plugins.hooks import register_hook, trigger_hook
+
+        define_hook(
+            "test.pipe",
+            description="pipe test",
+            params=["request", "result"],
+            pipe_key="result",
+        )
+
+        def double_result(*, request, result, **kwargs):
+            return result * 2
+
+        register_hook("test.pipe", double_result)
+        rv = trigger_hook("test.pipe", request="req", result=5)
+
+        assert rv == 10
+
+    def test_trigger_hook_pipe_key_chains_callbacks(self):
+        from memos.plugins.hook_defs import define_hook
+        from memos.plugins.hooks import register_hook, trigger_hook
+
+        define_hook(
+            "test.chain",
+            description="chain test",
+            params=["result"],
+            pipe_key="result",
+        )
+
+        def add_one(*, result, **kwargs):
+            return result + 1
+
+        def add_ten(*, result, **kwargs):
+            return result + 10
+
+        register_hook("test.chain", add_one)
+        register_hook("test.chain", add_ten)
+
+        rv = trigger_hook("test.chain", result=0)
+        assert rv == 11
+
+    def test_trigger_hook_pipe_key_none_callback_no_modify(self):
+        from memos.plugins.hook_defs import define_hook
+        from memos.plugins.hooks import register_hook, trigger_hook
+
+        define_hook(
+            "test.noop",
+            description="noop test",
+            params=["result"],
+            pipe_key="result",
+        )
+
+        def noop(*, result, **kwargs):
+            return None  # explicitly return None — should not modify
+
+        register_hook("test.noop", noop)
+        rv = trigger_hook("test.noop", result="original")
+
+        assert rv == "original"
+
+    def test_trigger_hook_notification_mode(self):
+        from memos.plugins.hook_defs import define_hook
+        from memos.plugins.hooks import register_hook, trigger_hook
+
+        define_hook(
+            "test.notify",
+            description="notification test",
+            params=["data"],
+            pipe_key=None,
+        )
+
+        captured = []
+
+        def observer(*, data, **kwargs):
+            captured.append(data)
+
+        register_hook("test.notify", observer)
+        rv = trigger_hook("test.notify", data="hello")
+
+        assert rv is None
+        assert captured == ["hello"]
+
+    def test_trigger_single_hook_returns_value(self):
+        from memos.plugins.hook_defs import define_hook
+        from memos.plugins.hooks import register_hook, trigger_single_hook
+
+        define_hook("test.single", description="single", params=["value"])
+
+        def handler(*, value, **kwargs):
+            return value + 1
+
+        register_hook("test.single", handler)
+
+        assert trigger_single_hook("test.single", value=1) == 2
+
+    def test_trigger_single_hook_requires_exactly_one_callback(self):
+        from memos.plugins.hook_defs import define_hook
+        from memos.plugins.hooks import register_hook, trigger_single_hook
+
+        define_hook("test.single.count", description="single", params=["value"])
+
+        def handler_a(*, value, **kwargs):
+            return value + 1
+
+        def handler_b(*, value, **kwargs):
+            return value + 2
+
+        register_hook("test.single.count", handler_a)
+        register_hook("test.single.count", handler_b)
+
+        try:
+            trigger_single_hook("test.single.count", value=1)
+        except RuntimeError as exc:
+            assert "Multiple plugins" in str(exc)
+        else:
+            raise AssertionError("Expected RuntimeError for multiple callbacks")
+
+
+# ========================================================================= #
+# 3. @hookable decorator
+# ========================================================================= #
+
+
+class TestHookableDecorator:
+    def setup_method(self):
+        from memos.plugins.hooks import _hooks
+
+        _hooks.clear()
+
+    def test_hookable_auto_declares_specs(self):
+        from memos.plugins.hook_defs import get_hook_spec
+        from memos.plugins.hooks import hookable
+
+        @hookable("auto_test")
+        def dummy(self, request):
+            return request
+
+        before_spec = get_hook_spec("auto_test.before")
+        after_spec = get_hook_spec("auto_test.after")
+
+        assert before_spec is not None
+        assert before_spec.pipe_key == "request"
+        assert after_spec is not None
+        assert after_spec.pipe_key == "result"
+
+    def test_hookable_sync(self):
+        from memos.plugins.hooks import hookable, register_hook
+
+        events = []
+
+        def on_before(*, request, **kwargs):
+            events.append(("before", request))
+
+        def on_after(*, request, result, **kwargs):
+            events.append(("after", result))
+
+        register_hook("sync_demo.before", on_before)
+        register_hook("sync_demo.after", on_after)
+
+        class FakeHandler:
+            @hookable("sync_demo")
+            def do_work(self, request):
+                return f"processed:{request}"
+
+        result = FakeHandler().do_work("my_input")
+
+        assert result == "processed:my_input"
+        assert events == [("before", "my_input"), ("after", "processed:my_input")]
+
+    def test_hookable_async(self):
+        from memos.plugins.hooks import hookable, register_hook
+
+        events = []
+
+        def on_before(*, request, **kwargs):
+            events.append("before")
+
+        def on_after(*, request, result, **kwargs):
+            events.append("after")
+
+        register_hook("async_demo.before", on_before)
+        register_hook("async_demo.after", on_after)
+
+        class FakeHandler:
+            @hookable("async_demo")
+            async def do_work(self, request):
+                return "async_result"
+
+        result = asyncio.run(FakeHandler().do_work("req"))
+
+        assert result == "async_result"
+        assert events == ["before", "after"]
+
+    def test_hookable_before_can_modify_request(self):
+        from memos.plugins.hooks import hookable, register_hook
+
+        def rewrite_request(*, request, **kwargs):
+            return "modified_request"
+
+        register_hook("modify_req.before", rewrite_request)
+
+        class FakeHandler:
+            @hookable("modify_req")
+            def do_work(self, request):
+                return f"got:{request}"
+
+        result = FakeHandler().do_work("original")
+        assert result == "got:modified_request"
+
+    def test_hookable_after_can_modify_result(self):
+        from memos.plugins.hooks import hookable, register_hook
+
+        def rewrite_result(*, request, result, **kwargs):
+            return f"{result}+modified"
+
+        register_hook("modify_res.after", rewrite_result)
+
+        class FakeHandler:
+            @hookable("modify_res")
+            def do_work(self, request):
+                return "original_result"
+
+        result = FakeHandler().do_work("req")
+        assert result == "original_result+modified"
+
+    def test_hookable_falsy_return_preserved(self):
+        """ensure empty list / 0 / empty string are not treated as None"""
+        from memos.plugins.hooks import hookable, register_hook
+
+        def return_empty_list(*, request, result, **kwargs):
+            return []
+
+        register_hook("falsy_test.after", return_empty_list)
+
+        class FakeHandler:
+            @hookable("falsy_test")
+            def do_work(self, request):
+                return [1, 2, 3]
+
+        result = FakeHandler().do_work("req")
+        assert result == []
+
+
+# ========================================================================= #
+# 4. Base class register_* methods
+# ========================================================================= #
+
+
+class TestBaseClassRegisterMethods:
+    def setup_method(self):
+        from memos.plugins.hooks import _hooks
+
+        _hooks.clear()
+
+    def test_register_router(self):
+        from fastapi import APIRouter
+
+        from memos.plugins.base import MemOSPlugin
+
+        app = FastAPI()
+        plugin = MemOSPlugin()
+        plugin._bind_app(app)
+
+        router = APIRouter(prefix="/test")
+
+        @router.get("/ping")
+        async def ping():
+            return {"pong": True}
+
+        plugin.register_router(router)
+
+        paths = [r.path for r in app.routes]
+        assert "/test/ping" in paths
+
+    def test_register_middleware(self):
+        from starlette.middleware.base import BaseHTTPMiddleware
+
+        from memos.plugins.base import MemOSPlugin
+
+        class NoopMiddleware(BaseHTTPMiddleware):
+            async def dispatch(self, request, call_next):
+                return await call_next(request)
+
+        app = FastAPI()
+
+        @app.get("/x")
+        async def x():
+            return {}
+
+        plugin = MemOSPlugin()
+        plugin._bind_app(app)
+        plugin.register_middleware(NoopMiddleware)
+
+        client = TestClient(app)
+        resp = client.get("/x")
+        assert resp.status_code == 200
+
+    def test_register_hook(self):
+        from memos.plugins.base import MemOSPlugin
+        from memos.plugins.hook_defs import define_hook
+        from memos.plugins.hooks import trigger_hook
+
+        define_hook("test.reg.event", description="test", params=["x"])
+
+        called = []
+        plugin = MemOSPlugin()
+        plugin._bind_app(FastAPI())
+        plugin.register_hook("test.reg.event", lambda **kw: called.append(True))
+
+        trigger_hook("test.reg.event", x=1)
+        assert called == [True]
+
+    def test_register_hooks_batch(self):
+        from memos.plugins.base import MemOSPlugin
+        from memos.plugins.hook_defs import define_hook
+        from memos.plugins.hooks import trigger_hook
+
+        define_hook("batch.a", description="a", params=["x"])
+        define_hook("batch.b", description="b", params=["x"])
+
+        count = 0
+
+        def cb(**kw):
+            nonlocal count
+            count += 1
+
+        plugin = MemOSPlugin()
+        plugin._bind_app(FastAPI())
+        plugin.register_hooks(["batch.a", "batch.b"], cb)
+
+        trigger_hook("batch.a", x=1)
+        trigger_hook("batch.b", x=2)
+        assert count == 2
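The tests below recover fields from `[STAGE]` lines by splitting on whitespace and the first `=`, the same convention `_emit` uses to build them. A standalone sketch of that round trip; the helper name is illustrative:

```python
def parse_stage_line(line: str) -> dict[str, str]:
    """Parse one '[STAGE] k=v k=v ...' line into a dict (mirrors the test helpers)."""
    assert line.startswith("[STAGE] ")
    fields = {}
    for part in line.removeprefix("[STAGE] ").split():
        if "=" in part:
            key, value = part.split("=", 1)
            fields[key] = value
    return fields


line = "[STAGE] biz=add stage=summary memory_count=1 total_ms=12"
assert parse_stage_line(line) == {
    "biz": "add",
    "stage": "summary",
    "memory_count": "1",
    "total_ms": "12",
}
```

The format assumes field values contain no spaces, which holds for the counters and identifiers emitted here.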
+""" + +import logging +import uuid + +from dataclasses import dataclass +from typing import Any +from unittest.mock import MagicMock + +import pytest + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _stage_logs(caplog) -> list[str]: + return [r.message for r in caplog.records if r.message.startswith("[STAGE]")] + + +def _make_add_req(**overrides): + from memos.api.product_models import APIADDRequest + + defaults = { + "user_id": "test_user", + "messages": [ + {"role": "user", "content": "remember this"}, + {"role": "assistant", "content": "ok"}, + ], + } + defaults.update(overrides) + return APIADDRequest(**defaults) + + +def _make_memory_item(memory_text: str = "hello world"): + from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata + + return TextualMemoryItem( + id=str(uuid.uuid4()), + memory=memory_text, + metadata=TreeNodeTextualMemoryMetadata( + user_id="u1", + session_id="s1", + memory_type="WorkingMemory", + sources=[], + info={}, + ), + ) + + +@dataclass +class _FakeSingleCube: + """Minimal stub that records calls for CompositeCubeView tests.""" + + cube_id: str + result: list[dict[str, Any]] + + def add_memories(self, add_req): + return list(self.result) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def single_cube_view(): + """Build a SingleCubeView with fully-mocked dependencies.""" + from memos.multi_mem_cube.single_cube import SingleCubeView + + mem_item = _make_memory_item() + + mock_mem_reader = MagicMock() + mock_mem_reader.get_memory.return_value = [[mem_item]] + mock_mem_reader.save_rawfile = False + + mock_text_mem = MagicMock() + mock_text_mem.add.return_value = [mem_item.id] + mock_text_mem.mode = "async" + + mock_naive_cube = MagicMock() + mock_naive_cube.text_mem = mock_text_mem + + mock_scheduler = MagicMock() + + view = SingleCubeView( + cube_id="cube_test", + naive_mem_cube=mock_naive_cube, + mem_reader=mock_mem_reader, + mem_scheduler=mock_scheduler, + logger=logging.getLogger("test.single_cube"), + searcher=None, + feedback_server=None, + ) + return view, mem_item + + +# =========================================================================== +# SingleCubeView — async + fast (the most common path) +# =========================================================================== + + +class TestSingleCubeAddAsyncFast: + def test_returns_correct_business_result(self, single_cube_view): + view, mem_item = single_cube_view + add_req = _make_add_req(async_mode="async") + + results = view.add_memories(add_req) + + assert len(results) == 1 + assert results[0]["memory_id"] == mem_item.id + assert results[0]["cube_id"] == "cube_test" + assert results[0]["memory_type"] == "WorkingMemory" + + def test_mem_reader_called_with_fast_mode(self, single_cube_view): + view, _ = single_cube_view + add_req = _make_add_req(async_mode="async") + + view.add_memories(add_req) + + view.mem_reader.get_memory.assert_called_once() + call_kwargs = view.mem_reader.get_memory.call_args + assert call_kwargs.kwargs["mode"] == "fast" + + def test_text_mem_add_called(self, single_cube_view): + view, _ = single_cube_view + add_req = _make_add_req(async_mode="async") + + view.add_memories(add_req) + + view.naive_mem_cube.text_mem.add.assert_called_once() + + def 
test_scheduler_called(self, single_cube_view): + view, _ = single_cube_view + add_req = _make_add_req(async_mode="async") + + view.add_memories(add_req) + + view.mem_scheduler.submit_messages.assert_called_once() + + def test_stage_logs_emitted(self, single_cube_view, caplog): + view, _ = single_cube_view + add_req = _make_add_req(async_mode="async") + + with caplog.at_level(logging.INFO): + view.add_memories(add_req) + + logs = _stage_logs(caplog) + + stage_names = [] + for log_line in logs: + for part in log_line.split(): + if part.startswith("stage="): + stage_names.append(part.split("=", 1)[1]) + + assert "get_memory" in stage_names + assert "write_db" in stage_names + assert "schedule" in stage_names + assert "summary" in stage_names + + def test_summary_contains_all_fields(self, single_cube_view, caplog): + view, _ = single_cube_view + add_req = _make_add_req(async_mode="async") + + with caplog.at_level(logging.INFO): + view.add_memories(add_req) + + summary = [log for log in _stage_logs(caplog) if "stage=summary" in log] + assert len(summary) == 1 + s = summary[0] + for field in [ + "cube_id=", + "sync_mode=", + "extract_mode=", + "input_msg_count=", + "est_input_tokens=", + "memory_count=", + "get_memory_ms=", + "write_db_ms=", + "schedule_ms=", + "total_ms=", + "per_item_ms=", + ]: + assert field in s, f"Missing field '{field}' in summary: {s}" + + def test_summary_values_are_consistent(self, single_cube_view, caplog): + view, _ = single_cube_view + add_req = _make_add_req(async_mode="async") + + with caplog.at_level(logging.INFO): + view.add_memories(add_req) + + summary = next(log for log in _stage_logs(caplog) if "stage=summary" in log) + fields = {} + for part in summary.split(): + if "=" in part: + k, v = part.split("=", 1) + fields[k] = v + + assert fields["sync_mode"] == "async" + assert fields["extract_mode"] == "fast" + assert fields["input_msg_count"] == "2" + assert fields["memory_count"] == "1" + assert int(fields["total_ms"]) >= 0 + assert int(fields["get_memory_ms"]) >= 0 + assert int(fields["write_db_ms"]) >= 0 + assert int(fields["schedule_ms"]) >= 0 + + def test_write_db_reports_memory_count(self, single_cube_view, caplog): + view, _ = single_cube_view + add_req = _make_add_req(async_mode="async") + + with caplog.at_level(logging.INFO): + view.add_memories(add_req) + + write_db = [log for log in _stage_logs(caplog) if "stage=write_db" in log] + assert len(write_db) == 1 + assert "memory_count=1" in write_db[0] + + def test_get_memory_has_cube_id(self, single_cube_view, caplog): + view, _ = single_cube_view + add_req = _make_add_req(async_mode="async") + + with caplog.at_level(logging.INFO): + view.add_memories(add_req) + + gm = [log for log in _stage_logs(caplog) if "stage=get_memory" in log] + assert len(gm) == 1 + assert "cube_id=cube_test" in gm[0] + + def test_schedule_has_cube_id(self, single_cube_view, caplog): + view, _ = single_cube_view + add_req = _make_add_req(async_mode="async") + + with caplog.at_level(logging.INFO): + view.add_memories(add_req) + + sched = [log for log in _stage_logs(caplog) if "stage=schedule" in log] + assert len(sched) == 1 + assert "cube_id=cube_test" in sched[0] + + +# =========================================================================== +# SingleCubeView — sync + fast +# =========================================================================== + + +class TestSingleCubeAddSyncFast: + def test_sync_fast_returns_result(self, single_cube_view): + view, mem_item = single_cube_view + add_req = _make_add_req(async_mode="sync", 
mode="fast") + + results = view.add_memories(add_req) + + assert len(results) == 1 + assert results[0]["memory_id"] == mem_item.id + + def test_sync_fast_summary_fields(self, single_cube_view, caplog): + view, _ = single_cube_view + add_req = _make_add_req(async_mode="sync", mode="fast") + + with caplog.at_level(logging.INFO): + view.add_memories(add_req) + + summary = [log for log in _stage_logs(caplog) if "stage=summary" in log] + assert len(summary) == 1 + assert "sync_mode=sync" in summary[0] + assert "extract_mode=fast" in summary[0] + + +# =========================================================================== +# SingleCubeView — zero memories edge case +# =========================================================================== + + +class TestSingleCubeEdgeCases: + def test_zero_memories_does_not_crash(self, caplog): + """When get_memory returns no items, process should still complete.""" + from memos.multi_mem_cube.single_cube import SingleCubeView + + mock_mem_reader = MagicMock() + mock_mem_reader.get_memory.return_value = [[]] + mock_mem_reader.save_rawfile = False + + mock_text_mem = MagicMock() + mock_text_mem.add.return_value = [] + mock_text_mem.mode = "async" + + mock_naive_cube = MagicMock() + mock_naive_cube.text_mem = mock_text_mem + + view = SingleCubeView( + cube_id="cube_empty", + naive_mem_cube=mock_naive_cube, + mem_reader=mock_mem_reader, + mem_scheduler=MagicMock(), + logger=logging.getLogger("test.edge"), + searcher=None, + ) + + add_req = _make_add_req(async_mode="async") + with caplog.at_level(logging.INFO): + results = view.add_memories(add_req) + + assert results == [] + summary = [log for log in _stage_logs(caplog) if "stage=summary" in log] + assert len(summary) == 1 + assert "memory_count=0" in summary[0] + + def test_multiple_memories_returned(self, caplog): + """Multiple memory items should all appear in results.""" + from memos.multi_mem_cube.single_cube import SingleCubeView + + items = [_make_memory_item(f"mem_{i}") for i in range(3)] + + mock_mem_reader = MagicMock() + mock_mem_reader.get_memory.return_value = [items] + mock_mem_reader.save_rawfile = False + + mock_text_mem = MagicMock() + mock_text_mem.add.return_value = [it.id for it in items] + mock_text_mem.mode = "async" + + mock_naive_cube = MagicMock() + mock_naive_cube.text_mem = mock_text_mem + + view = SingleCubeView( + cube_id="cube_multi", + naive_mem_cube=mock_naive_cube, + mem_reader=mock_mem_reader, + mem_scheduler=MagicMock(), + logger=logging.getLogger("test.multi"), + searcher=None, + ) + + add_req = _make_add_req(async_mode="async") + with caplog.at_level(logging.INFO): + results = view.add_memories(add_req) + + assert len(results) == 3 + summary = [log for log in _stage_logs(caplog) if "stage=summary" in log] + assert "memory_count=3" in summary[0] + + +# =========================================================================== +# CompositeCubeView — multi_cube stage +# =========================================================================== + + +class TestCompositeCubeAdd: + def test_single_cube_emits_multi_cube_log(self, caplog): + from memos.multi_mem_cube.composite_cube import CompositeCubeView + + fake = _FakeSingleCube(cube_id="c1", result=[{"m": 1}]) + composite = CompositeCubeView( + cube_views=[fake], + logger=logging.getLogger("test.composite"), + ) + + add_req = _make_add_req() + with caplog.at_level(logging.INFO): + results = composite.add_memories(add_req) + + assert len(results) == 1 + multi = [log for log in _stage_logs(caplog) if "stage=multi_cube" in 
+    def test_multi_cube_emits_stage_with_duration(self, caplog):
+        from memos.multi_mem_cube.composite_cube import CompositeCubeView
+
+        fake1 = _FakeSingleCube(cube_id="c1", result=[{"m": 1}])
+        fake2 = _FakeSingleCube(cube_id="c2", result=[{"m": 2}])
+        composite = CompositeCubeView(
+            cube_views=[fake1, fake2],
+            logger=logging.getLogger("test.composite"),
+        )
+
+        add_req = _make_add_req()
+        with caplog.at_level(logging.INFO):
+            results = composite.add_memories(add_req)
+
+        assert len(results) == 2
+        multi = [log for log in _stage_logs(caplog) if "stage=multi_cube" in log]
+        assert len(multi) == 1
+        assert "cube_count=2" in multi[0]
+        assert "duration_ms=" in multi[0]
+
+    def test_fan_out_results_aggregated(self):
+        from memos.multi_mem_cube.composite_cube import CompositeCubeView
+
+        fake1 = _FakeSingleCube(cube_id="c1", result=[{"a": 1}, {"a": 2}])
+        fake2 = _FakeSingleCube(cube_id="c2", result=[{"b": 3}])
+        composite = CompositeCubeView(
+            cube_views=[fake1, fake2],
+            logger=logging.getLogger("test.composite"),
+        )
+
+        add_req = _make_add_req()
+        results = composite.add_memories(add_req)
+
+        assert len(results) == 3
diff --git a/tests/test_utils_timing.py b/tests/test_utils_timing.py
new file mode 100644
index 000000000..b4d5cb989
--- /dev/null
+++ b/tests/test_utils_timing.py
@@ -0,0 +1,390 @@
+"""Tests for memos.utils timing utilities: timed_stage, timed, timed_with_status.
+
+Covers:
+    - timed_stage: context-manager, decorator, emit_now, duration_ms propagation,
+      set(), error logging, extra callback, static extra dict
+    - timed: regression (return value, threshold, log_prefix, log=False)
+    - timed_with_status: regression (success, failure+fallback, log_args, log_extra_args)
+"""
+
+import logging
+import time
+
+import pytest
+
+from memos.utils import timed, timed_stage, timed_with_status
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _collect_stage_logs(caplog):
+    """Return all log messages that start with '[STAGE]'."""
+    return [r.message for r in caplog.records if r.message.startswith("[STAGE]")]
+
+
+def _collect_timer_logs(caplog):
+    """Return all log messages that start with '[TIMER]'."""
+    return [r.message for r in caplog.records if r.message.startswith("[TIMER]")]
+
+
+def _collect_timer_with_status_logs(caplog):
+    """Return all log messages that start with '[TIMER_WITH_STATUS]'."""
+    return [r.message for r in caplog.records if r.message.startswith("[TIMER_WITH_STATUS]")]
+
+
+# ===========================================================================
+# timed_stage — context manager
+# ===========================================================================
+
+
+class TestTimedStageContextManager:
+    def test_basic_log_output(self, caplog):
+        with caplog.at_level(logging.INFO), timed_stage("add", "parse", cube_id="c1"):
+            pass
+        logs = _collect_stage_logs(caplog)
+        assert len(logs) == 1
+        assert "biz=add" in logs[0]
+        assert "stage=parse" in logs[0]
+        assert "cube_id=c1" in logs[0]
+        assert "duration_ms=" in logs[0]
+
+    def test_duration_ms_is_populated(self, caplog):
+        with caplog.at_level(logging.INFO), timed_stage("add", "embedding") as ts:
+            time.sleep(0.05)
+        assert ts.duration_ms >= 40  # at least ~50ms minus jitter
+
+    def test_set_adds_fields(self, caplog):
+        with caplog.at_level(logging.INFO), timed_stage("add", "write_db") as ts:
+            ts.set(memory_count=7)
+        logs = _collect_stage_logs(caplog)
+        assert "memory_count=7" in logs[0]
+
+    def test_set_overwrites_fields(self, caplog):
+        with caplog.at_level(logging.INFO), timed_stage("add", "write_db", memory_count=0) as ts:
+            ts.set(memory_count=5)
+        logs = _collect_stage_logs(caplog)
+        assert "memory_count=5" in logs[0]
+        assert "memory_count=0" not in logs[0]
+
+    def test_exception_logged_but_propagated(self, caplog):
+        with (
+            caplog.at_level(logging.INFO),
+            pytest.raises(ValueError, match="boom"),
+            timed_stage("add", "parse"),
+        ):
+            raise ValueError("boom")
+        logs = _collect_stage_logs(caplog)
+        assert len(logs) == 1
+        assert "error=ValueError" in logs[0]
+
+    def test_duration_ms_available_after_exception(self, caplog):
+        with caplog.at_level(logging.INFO):
+            ts = timed_stage("add", "parse")
+            with pytest.raises(RuntimeError), ts:
+                time.sleep(0.02)
+                raise RuntimeError("fail")
+        assert ts.duration_ms >= 15
+
+    def test_no_biz_no_stage(self, caplog):
+        """Empty biz/stage should not emit those fields."""
+        with caplog.at_level(logging.INFO), timed_stage(x=1):
+            pass
+        logs = _collect_stage_logs(caplog)
+        assert "biz=" not in logs[0]
+        assert "stage=" not in logs[0]
+        assert "x=1" in logs[0]
+
+    def test_duration_ms_readable_downstream(self):
+        """Downstream code can reference ts.duration_ms for summary rollup."""
+        with timed_stage("add", "get_memory") as ts:
+            time.sleep(0.01)
+        get_memory_ms = ts.duration_ms
+        assert isinstance(get_memory_ms, int)
+        assert get_memory_ms >= 5
+
+
+# ===========================================================================
+# timed_stage — decorator
+# ===========================================================================
+
+
+class TestTimedStageDecorator:
+    def test_decorator_basic(self, caplog):
+        @timed_stage("search", "recall")
+        def do_recall():
+            return [1, 2, 3]
+
+        with caplog.at_level(logging.INFO):
+            result = do_recall()
+        assert result == [1, 2, 3]
+        logs = _collect_stage_logs(caplog)
+        assert len(logs) == 1
+        assert "biz=search" in logs[0]
+        assert "stage=recall" in logs[0]
+
+    def test_decorator_uses_func_name_when_no_stage(self, caplog):
+        @timed_stage("add")
+        def my_custom_func():
+            pass
+
+        with caplog.at_level(logging.INFO):
+            my_custom_func()
+        logs = _collect_stage_logs(caplog)
+        assert "stage=my_custom_func" in logs[0]
+
+    def test_decorator_preserves_function_metadata(self):
+        @timed_stage("add", "test")
+        def documented_func():
+            """I have a docstring."""
+            return 42
+
+        assert documented_func.__name__ == "documented_func"
+        assert documented_func.__doc__ == "I have a docstring."
+
+    def test_decorator_with_extra_callback(self, caplog):
+        class Service:
+            cube_id = "cube_abc"
+
+            @timed_stage("add", "write_db", extra=lambda self: {"cube_id": self.cube_id})
+            def write(self):
+                return "ok"
+
+        svc = Service()
+        with caplog.at_level(logging.INFO):
+            result = svc.write()
+        assert result == "ok"
+        logs = _collect_stage_logs(caplog)
+        assert "cube_id=cube_abc" in logs[0]
+
+    def test_decorator_extra_callback_error_does_not_break(self, caplog):
+        @timed_stage("add", "risky", extra=lambda: (_ for _ in ()).throw(RuntimeError("oops")))
+        def risky_func():
+            return "still works"
+
+        with caplog.at_level(logging.WARNING):
+            result = risky_func()
+        assert result == "still works"
+
+    def test_decorator_exception_propagated(self, caplog):
+        @timed_stage("add", "crash")
+        def crasher():
+            raise TypeError("type error")
+
+        with caplog.at_level(logging.INFO), pytest.raises(TypeError, match="type error"):
+            crasher()
+        logs = _collect_stage_logs(caplog)
+        assert "error=TypeError" not in logs[0]  # decorator path doesn't pass exc_type
+
+
+# ===========================================================================
+# timed_stage.emit_now
+# ===========================================================================
+
+
+class TestTimedStageEmitNow:
+    def test_emit_now_basic(self, caplog):
+        with caplog.at_level(logging.INFO):
+            timed_stage.emit_now("add", "summary", total_ms=1200, per_item_ms=240)
+        logs = _collect_stage_logs(caplog)
+        assert len(logs) == 1
+        assert "biz=add" in logs[0]
+        assert "stage=summary" in logs[0]
+        assert "total_ms=1200" in logs[0]
+        assert "per_item_ms=240" in logs[0]
+        assert "duration_ms" not in logs[0]
+
+    def test_emit_now_no_extra_fields(self, caplog):
+        with caplog.at_level(logging.INFO):
+            timed_stage.emit_now("search", "summary")
+        logs = _collect_stage_logs(caplog)
+        assert logs[0] == "[STAGE] biz=search stage=summary"
+
+
+# ===========================================================================
+# timed_stage — static extra dict
+# ===========================================================================
+
+
+class TestTimedStageStaticExtra:
+    def test_static_extra_dict(self, caplog):
+        with caplog.at_level(logging.INFO), timed_stage("add", "parse", extra={"env": "prod"}):
+            pass
+        logs = _collect_stage_logs(caplog)
+        assert "env=prod" in logs[0]
+
+
+# ===========================================================================
+# timed — regression tests (original behavior must not change)
+# ===========================================================================
+
+
+class TestTimedRegression:
+    def test_return_value_preserved(self):
+        @timed
+        def add(a, b):
+            return a + b
+
+        assert add(1, 2) == 3
+
+    def test_no_log_below_threshold(self, caplog):
+        """@timed only logs when elapsed >= 100ms."""
+
+        @timed
+        def fast_func():
+            return "fast"
+
+        with caplog.at_level(logging.INFO):
+            result = fast_func()
+        assert result == "fast"
+        logs = _collect_timer_logs(caplog)
+        assert len(logs) == 0
+
+    def test_log_above_threshold(self, caplog):
+        @timed
+        def slow_func():
+            time.sleep(0.12)
+            return "slow"
+
+        with caplog.at_level(logging.INFO):
+            result = slow_func()
+        assert result == "slow"
+        logs = _collect_timer_logs(caplog)
+        assert len(logs) == 1
+        assert "slow_func" in logs[0]
+
+    def test_log_false_disables(self, caplog):
+        @timed(log=False)
+        def no_log_func():
+            time.sleep(0.12)
+            return 99
+
+        with caplog.at_level(logging.INFO):
+            result = no_log_func()
+        assert result == 99
+        logs = _collect_timer_logs(caplog)
+        assert len(logs) == 0
+
+    def test_log_prefix(self, caplog):
+        @timed(log_prefix="MY_PREFIX")
+        def prefixed():
+            time.sleep(0.12)
+            return True
+
+        with caplog.at_level(logging.INFO):
+            prefixed()
+        logs = _collect_timer_logs(caplog)
+        assert "MY_PREFIX" in logs[0]
+
+    def test_both_decorator_forms(self):
+        """@timed and @timed() should both work."""
+
+        @timed
+        def bare():
+            return 1
+
+        @timed()
+        def parens():
+            return 2
+
+        assert bare() == 1
+        assert parens() == 2
+
+
+# ===========================================================================
+# timed_with_status — regression tests
+# ===========================================================================
+
+
+class TestTimedWithStatusRegression:
+    def test_success_logging(self, caplog):
+        @timed_with_status
+        def ok_func():
+            return "hello"
+
+        with caplog.at_level(logging.INFO):
+            result = ok_func()
+        assert result == "hello"
+        logs = _collect_timer_with_status_logs(caplog)
+        assert len(logs) == 1
+        assert "status: SUCCESS" in logs[0]
+        assert "ok_func" in logs[0]
+
+    def test_failure_logging_no_fallback(self, caplog):
+        @timed_with_status
+        def fail_func():
+            raise RuntimeError("bad")
+
+        with caplog.at_level(logging.INFO):
+            fail_func()
+        logs = _collect_timer_with_status_logs(caplog)
+        assert len(logs) == 1
+        assert "status: FAILED" in logs[0]
+        assert "RuntimeError" in logs[0]
+
+    def test_failure_with_fallback(self, caplog):
+        @timed_with_status(fallback=lambda e, *a, **kw: "fallback_val")
+        def fail_func():
+            raise RuntimeError("bad")
+
+        with caplog.at_level(logging.INFO):
+            result = fail_func()
+        assert result == "fallback_val"
+        logs = _collect_timer_with_status_logs(caplog)
+        assert "status: FAILED" in logs[0]
+
+    def test_log_prefix(self, caplog):
+        @timed_with_status(log_prefix="CUSTOM")
+        def prefixed():
+            return 1
+
+        with caplog.at_level(logging.INFO):
+            prefixed()
+        logs = _collect_timer_with_status_logs(caplog)
+        assert "CUSTOM" in logs[0]
+
+    def test_log_args(self, caplog):
+        @timed_with_status(log_args=["user_id"])
+        def with_args(user_id="u1"):
+            return user_id
+
+        with caplog.at_level(logging.INFO):
+            with_args(user_id="u42")
+        logs = _collect_timer_with_status_logs(caplog)
+        assert "user_id=u42" in logs[0]
+
+    def test_log_extra_args_dict(self, caplog):
+        @timed_with_status(log_extra_args={"region": "us-west"})
+        def with_extra():
+            return True
+
+        with caplog.at_level(logging.INFO):
+            with_extra()
+        logs = _collect_timer_with_status_logs(caplog)
+        assert "region=us-west" in logs[0]
+
+    def test_log_extra_args_callable(self, caplog):
+        @timed_with_status(log_extra_args=lambda *a, **kw: {"dynamic": "yes"})
+        def with_dynamic():
+            return True
+
+        with caplog.at_level(logging.INFO):
+            with_dynamic()
+        logs = _collect_timer_with_status_logs(caplog)
+        assert "dynamic=yes" in logs[0]
+
+    def test_both_decorator_forms(self):
+        """@timed_with_status and @timed_with_status() should both work."""
+
+        @timed_with_status
+        def bare():
+            return 1
+
+        @timed_with_status()
+        def parens():
+            return 2
+
+        assert bare() == 1
+        assert parens() == 2