diff --git a/.gitignore b/.gitignore index c972d59de..51e2f7ab4 100644 --- a/.gitignore +++ b/.gitignore @@ -67,6 +67,8 @@ report/ cov-report/ .tox/ .nox/ +report/ +cov-report/ .coverage .coverage.* .cache diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 01f832079..69efedeb3 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -369,6 +369,26 @@ def get_memreader_config() -> dict[str, Any]: return {"backend": "openai", "config": config} + @staticmethod + def get_qwen_llm_config() -> dict[str, Any] | None: + if not os.getenv("QWEN_API_KEY"): + return None + return { + "backend": "qwen", + "config": { + "model_name_or_path": os.getenv("QWEN_MODEL", "qwen-flash"), + "temperature": float(os.getenv("QWEN_TEMPERATURE", "0.8")), + "max_tokens": int(os.getenv("QWEN_MAX_TOKENS", "8000")), + "top_p": float(os.getenv("QWEN_TOP_P", "0.9")), + "top_k": int(os.getenv("QWEN_TOP_K", "50")), + "remove_think_prefix": os.getenv("QWEN_REMOVE_THINK_PREFIX", "true").lower() + == "true", + "api_key": os.getenv("QWEN_API_KEY", ""), + "api_base": os.getenv("QWEN_API_BASE", ""), + "model_schema": os.getenv("QWEN_MODEL_SCHEMA", "memos.configs.llm.QwenLLMConfig"), + }, + } + @staticmethod def get_memreader_general_llm_config() -> dict[str, Any]: """Get general LLM configuration for non-chat/doc tasks. @@ -639,6 +659,7 @@ def get_oss_config() -> dict[str, Any] | None: return config + @staticmethod def get_internet_config() -> dict[str, Any]: """Get internet retriever configuration. 
@@ -705,8 +726,9 @@ def get_internet_config() -> dict[str, Any]: @staticmethod def get_nli_config() -> dict[str, Any]: - """Get NLI model configuration.""" + """Get relation-judge configuration for memory-version candidate matching.""" return { + "provider": os.getenv("MEM_VERSION_RELATION_JUDGE_PROVIDER", "llm"), "base_url": os.getenv("NLI_MODEL_BASE_URL", "http://localhost:32532"), } @@ -952,6 +974,7 @@ def get_product_default_config() -> dict[str, Any]: "backend": reader_config["backend"], "config": { "llm": APIConfig.get_memreader_config(), + "qwen_llm": APIConfig.get_qwen_llm_config(), # General LLM for non-chat/doc tasks (hallucination filter, rewrite, merge, etc.) "general_llm": APIConfig.get_memreader_general_llm_config(), # Image parser LLM (requires vision model) @@ -986,6 +1009,7 @@ def get_product_default_config() -> dict[str, Any]: "SKILLS_LOCAL_DIR", "/tmp/upload_skill_memory/" ), }, + "memory_version_switch": os.getenv("MEM_READER_MEM_VERSION_SWITCH", "off"), }, }, "enable_textual_memory": True, diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index e9ed4f955..4240545f6 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -15,6 +15,7 @@ from memos.multi_mem_cube.composite_cube import CompositeCubeView from memos.multi_mem_cube.single_cube import SingleCubeView from memos.multi_mem_cube.views import MemCubeView +from memos.plugins.hooks import hookable from memos.types import MessageList @@ -37,6 +38,7 @@ def __init__(self, dependencies: HandlerDependencies): "naive_mem_cube", "mem_reader", "mem_scheduler", "feedback_server" ) + @hookable("add") def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse: """ Main handler for add memories endpoint. 
diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index a01fffef8..3536cee09 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -32,14 +32,14 @@ from memos.mem_scheduler.orm_modules.base_model import BaseDBManager from memos.mem_scheduler.scheduler_factory import SchedulerFactory from memos.memories.textual.simple_tree import SimpleTreeTextMemory -from memos.memories.textual.tree_text_memory.organize.history_manager import MemoryHistoryManager from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer +from memos.plugins.component_bootstrap import build_plugin_context +from memos.plugins.manager import plugin_manager if TYPE_CHECKING: from memos.memories.textual.tree import TreeTextMemory -from memos.extras.nli_model.client import NLIClient from memos.mem_agent.deepsearch_agent import DeepSearchMemAgent from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( InternetRetrieverFactory, @@ -122,6 +122,12 @@ def init_server() -> dict[str, Any]: existing code that uses the components. 
""" logger.info("Initializing MemOS server components...") + logger.info( + "[INIT_SERVER] env_MEMSCHEDULER_STREAM_KEY_PREFIX=%s, env_MEMSCHEDULER_REDIS_STREAM_KEY_PREFIX=%s, env_POLAR_DB_DB_NAME=%s", + os.getenv("MEMSCHEDULER_STREAM_KEY_PREFIX"), + os.getenv("MEMSCHEDULER_REDIS_STREAM_KEY_PREFIX"), + os.getenv("POLAR_DB_DB_NAME"), + ) # Initialize Redis client first as it is a core dependency for features like scheduler status tracking if os.getenv("MEMSCHEDULER_USE_REDIS_QUEUE", "False").lower() == "true": @@ -169,10 +175,25 @@ def init_server() -> dict[str, Any]: else None ) embedder = EmbedderFactory.from_config(embedder_config) - nli_client = NLIClient(base_url=nli_client_config["base_url"]) - memory_history_manager = MemoryHistoryManager(nli_client=nli_client, graph_db=graph_db) + + plugin_context = build_plugin_context( + graph_db=graph_db, + embedder=embedder, + default_cube_config=default_cube_config, + nli_client_config=nli_client_config, + mem_reader_config=mem_reader_config, + reranker_config=reranker_config, + feedback_reranker_config=feedback_reranker_config, + internet_retriever_config=internet_retriever_config, + ) + plugin_manager.discover() + plugin_manager.init_components(plugin_context) + # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection) - mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db) + mem_reader = MemReaderFactory.from_config( + mem_reader_config, + graph_db=graph_db, + ) reranker = RerankerFactory.from_config(reranker_config) feedback_reranker = RerankerFactory.from_config(feedback_reranker_config) internet_retriever = InternetRetrieverFactory.from_config( @@ -303,6 +324,4 @@ def init_server() -> dict[str, Any]: "feedback_server": feedback_server, "redis_client": redis_client, "deepsearch_agent": deepsearch_agent, - "nli_client": nli_client, - "memory_history_manager": memory_history_manager, } diff --git a/src/memos/api/handlers/formatters_handler.py 
b/src/memos/api/handlers/formatters_handler.py index ee88ae639..c8024baa3 100644 --- a/src/memos/api/handlers/formatters_handler.py +++ b/src/memos/api/handlers/formatters_handler.py @@ -141,6 +141,7 @@ def separate_knowledge_and_conversation_mem(memories: list[dict[str, Any]]): sources = item.get("metadata", {}).get("sources", []) if ( item["metadata"]["memory_type"] != "RawFileMemory" + and sources and len(sources) > 0 and "type" in sources[0] and sources[0]["type"] == "file" diff --git a/src/memos/api/server_api.py b/src/memos/api/server_api.py index 1f1e6ccde..a9afe554c 100644 --- a/src/memos/api/server_api.py +++ b/src/memos/api/server_api.py @@ -9,13 +9,21 @@ from memos.api.exceptions import APIExceptionHandler from memos.api.middleware.request_context import RequestContextMiddleware from memos.api.routers.server_router import router as server_router +from memos.plugins.manager import plugin_manager load_dotenv() +plugin_manager.discover() + # Configure logging logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) +logger.info( + "[SERVER_API] load_dotenv completed. 
env_MEMSCHEDULER_STREAM_KEY_PREFIX=%s, env_MEMSCHEDULER_REDIS_STREAM_KEY_PREFIX=%s", + os.getenv("MEMSCHEDULER_STREAM_KEY_PREFIX"), + os.getenv("MEMSCHEDULER_REDIS_STREAM_KEY_PREFIX"), +) app = FastAPI( title="MemOS Server REST APIs", @@ -49,6 +57,8 @@ def health_check(): # Fallback for unknown errors app.exception_handler(Exception)(APIExceptionHandler.global_exception_handler) +plugin_manager.init_app(app) + if __name__ == "__main__": import argparse diff --git a/src/memos/configs/mem_reader.py b/src/memos/configs/mem_reader.py index d4844d73f..9ed791fa8 100644 --- a/src/memos/configs/mem_reader.py +++ b/src/memos/configs/mem_reader.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, ClassVar +from typing import Any, ClassVar, Literal from pydantic import ConfigDict, Field, field_validator, model_validator @@ -76,6 +76,13 @@ class MultiModalStructMemReaderConfig(BaseMemReaderConfig): default=None, description="Skills directory for the MemReader", ) + memory_version_switch: Literal["on", "off"] = Field( + default="off", + description="Turn on memory version or off", + ) + + # Allow passing additional fields without raising validation errors + model_config = ConfigDict(extra="allow", strict=True) class StrategyStructMemReaderConfig(BaseMemReaderConfig): diff --git a/src/memos/extras/nli_model/client.py b/src/memos/extras/nli_model/client.py deleted file mode 100644 index a02dae9f6..000000000 --- a/src/memos/extras/nli_model/client.py +++ /dev/null @@ -1,61 +0,0 @@ -import logging - -import requests - -from memos.extras.nli_model.types import NLIResult - - -logger = logging.getLogger(__name__) - - -class NLIClient: - """ - Client for interacting with the deployed NLI model service. 
- """ - - def __init__(self, base_url: str = "http://localhost:32532"): - self.base_url = base_url.rstrip("/") - self.session = requests.Session() - - def compare_one_to_many(self, source: str, targets: list[str]) -> list[NLIResult]: - """ - Compare one source text against multiple target memories using the NLI service. - - Args: - source: The new memory content. - targets: List of existing memory contents to compare against. - - Returns: - List of NLIResult corresponding to each target. - """ - if not targets: - return [] - - url = f"{self.base_url}/compare_one_to_many" - # Match schemas.CompareRequest - payload = {"source": source, "targets": targets} - - try: - response = self.session.post(url, json=payload, timeout=30) - response.raise_for_status() - data = response.json() - - # Match schemas.CompareResponse - results_str = data.get("results", []) - - results = [] - for res_str in results_str: - try: - results.append(NLIResult(res_str)) - except ValueError: - logger.warning( - f"[NLIClient] Unknown result: {res_str}, defaulting to UNRELATED" - ) - results.append(NLIResult.UNRELATED) - - return results - - except requests.RequestException as e: - logger.error(f"[NLIClient] Request failed: {e}") - # Fallback: if NLI fails, assume all are Unrelated to avoid blocking the flow. - return [NLIResult.UNRELATED] * len(targets) diff --git a/src/memos/extras/nli_model/server/README.md b/src/memos/extras/nli_model/server/README.md deleted file mode 100644 index f6886e0e4..000000000 --- a/src/memos/extras/nli_model/server/README.md +++ /dev/null @@ -1,69 +0,0 @@ -# NLI Model Server - -This directory contains the standalone server for the Natural Language Inference (NLI) model used by MemOS. - -## Prerequisites - -- Python 3.10+ -- CUDA-capable GPU (Recommended for performance) -- `torch` and `transformers` libraries (required for the server) - -## Running the Server - -You can run the server using the module syntax from the project root to ensure imports work correctly. 
- -### 1. Basic Start -```bash -python -m memos.extras.nli_model.server.serve -``` - -### 2. Configuration -You can configure the server by editing config.py: - -- `HOST`: The host to bind to (default: `0.0.0.0`) -- `PORT`: The port to bind to (default: `32532`) -- `NLI_DEVICE`: The device to run the model on. - - `cuda` (Default, uses cuda:0 if available, else fallback to mps/cpu) - - `cuda:0` (Specific GPU) - - `mps` (Apple Silicon) - - `cpu` (CPU) - -## API Usage - -### Compare One to Many -**POST** `/compare_one_to_many` - -**Request Body:** -```json -{ - "source": "I just ate an apple.", - "targets": [ - "I ate a fruit.", - "I hate apples.", - "The sky is blue." - ] -} -``` - -## Testing - -An end-to-end example script is provided to verify the server's functionality. This script starts the server locally and runs a client request to verify the NLI logic. - -### End-to-End Test - -Run the example script from the project root: - -```bash -python examples/extras/nli_e2e_example.py -``` - -**Response:** -```json -{ - "results": [ - "Duplicate", // Entailment - "Contradiction", // Contradiction - "Unrelated" // Neutral - ] -} -``` diff --git a/src/memos/extras/nli_model/server/config.py b/src/memos/extras/nli_model/server/config.py deleted file mode 100644 index d2e12175d..000000000 --- a/src/memos/extras/nli_model/server/config.py +++ /dev/null @@ -1,23 +0,0 @@ -import logging - - -NLI_MODEL_NAME = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli" - -# Configuration -# You can set the device directly here. 
-# Examples: -# - "cuda" : Use default GPU (cuda:0) if available, else auto-fallback -# - "cuda:0" : Use specific GPU -# - "mps" : Use Apple Silicon GPU (if available) -# - "cpu" : Use CPU -NLI_DEVICE = "cuda" -NLI_MODEL_HOST = "0.0.0.0" -NLI_MODEL_PORT = 32532 - -# Configure logging for NLI Server -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s | %(name)s - %(levelname)s - %(message)s", - handlers=[logging.StreamHandler(), logging.FileHandler("nli_server.log")], -) -logger = logging.getLogger("nli_server") diff --git a/src/memos/extras/nli_model/server/handler.py b/src/memos/extras/nli_model/server/handler.py deleted file mode 100644 index 3e98ddeb0..000000000 --- a/src/memos/extras/nli_model/server/handler.py +++ /dev/null @@ -1,186 +0,0 @@ -import re - -from memos.extras.nli_model.server.config import NLI_MODEL_NAME, logger -from memos.extras.nli_model.types import NLIResult - - -# Placeholder for lazy imports -torch = None -AutoModelForSequenceClassification = None -AutoTokenizer = None - - -def _map_label_to_result(raw: str) -> NLIResult: - t = raw.lower() - if "entail" in t: - return NLIResult.DUPLICATE - if "contrad" in t or "refut" in t: - return NLIResult.CONTRADICTION - # Neutral or unknown - return NLIResult.UNRELATED - - -def _clean_temporal_markers(s: str) -> str: - # Remove temporal/aspect markers that might cause contradiction - # Chinese markers (simple replace is usually okay as they are characters) - zh_markers = ["刚刚", "曾经", "正在", "目前", "现在"] - for m in zh_markers: - s = s.replace(m, "") - - # English markers (need word boundaries to avoid "snow" -> "s") - en_markers = ["just", "once", "currently", "now"] - pattern = r"\b(" + "|".join(en_markers) + r")\b" - s = re.sub(pattern, "", s, flags=re.IGNORECASE) - - # Cleanup extra spaces - s = re.sub(r"\s+", " ", s).strip() - return s - - -class NLIHandler: - """ - NLI Model Handler for inference. - Requires `torch` and `transformers` to be installed. 
- """ - - def __init__(self, device: str = "cpu", use_fp16: bool = True, use_compile: bool = True): - global torch, AutoModelForSequenceClassification, AutoTokenizer - try: - import torch - - from transformers import AutoModelForSequenceClassification, AutoTokenizer - except ImportError as e: - raise ImportError( - "NLIHandler requires 'torch' and 'transformers'. " - "Please install them via 'pip install torch transformers' or use the requirements.txt." - ) from e - - self.device = self._resolve_device(device) - logger.info(f"Final resolved device: {self.device}") - - # Set defaults based on device if not explicitly provided - is_cuda = "cuda" in self.device - if not is_cuda: - use_fp16 = False - use_compile = False - - self.tokenizer = AutoTokenizer.from_pretrained(NLI_MODEL_NAME) - - model_kwargs = {} - if use_fp16 and is_cuda: - model_kwargs["torch_dtype"] = torch.float16 - - self.model = AutoModelForSequenceClassification.from_pretrained( - NLI_MODEL_NAME, **model_kwargs - ).to(self.device) - self.model.eval() - - self.id2label = {int(k): v for k, v in self.model.config.id2label.items()} - self.softmax = torch.nn.Softmax(dim=-1).to(self.device) - - if use_compile and hasattr(torch, "compile"): - logger.info("Compiling model with torch.compile...") - self.model = torch.compile(self.model) - - def _resolve_device(self, device: str) -> str: - d = device.strip().lower() - - has_cuda = torch.cuda.is_available() - has_mps = torch.backends.mps.is_available() if hasattr(torch.backends, "mps") else False - - if d == "cpu": - return "cpu" - - if d.startswith("cuda"): - if has_cuda: - if d == "cuda": - return "cuda:0" - return d - - # Fallback if CUDA not available - if has_mps: - logger.warning( - f"Device '{device}' requested but CUDA not available. Fallback to MPS." - ) - return "mps" - - logger.warning( - f"Device '{device}' requested but CUDA/MPS not available. Fallback to CPU." 
- ) - return "cpu" - - if d == "mps": - if has_mps: - return "mps" - - logger.warning(f"Device '{device}' requested but MPS not available. Fallback to CPU.") - return "cpu" - - # Fallback / Auto-detect for other cases (e.g. "gpu" or unknown) - if has_cuda: - return "cuda:0" - if has_mps: - return "mps" - - return "cpu" - - def predict_batch(self, premises: list[str], hypotheses: list[str]) -> list[NLIResult]: - # Clean inputs - premises = [_clean_temporal_markers(p) for p in premises] - hypotheses = [_clean_temporal_markers(h) for h in hypotheses] - - # Batch tokenize with padding - inputs = self.tokenizer( - premises, hypotheses, return_tensors="pt", truncation=True, max_length=512, padding=True - ).to(self.device) - with torch.no_grad(): - out = self.model(**inputs) - probs = self.softmax(out.logits) - - results = [] - for p in probs: - idx = int(torch.argmax(p).item()) - res = self.id2label.get(idx, str(idx)) - results.append(_map_label_to_result(res)) - return results - - def compare_one_to_many(self, source: str, targets: list[str]) -> list[NLIResult]: - """ - Compare one source text against multiple target memories efficiently using batch processing. - Performs bidirectional checks (Source <-> Target) for each pair. - """ - if not targets: - return [] - - n = len(targets) - # Construct batch: - # First n pairs: Source -> Target_i - # Next n pairs: Target_i -> Source - premises = [source] * n + targets - hypotheses = targets + [source] * n - - # Run single large batch inference - raw_results = self.predict_batch(premises, hypotheses) - - # Split results back - results_ab = raw_results[:n] - results_ba = raw_results[n:] - - final_results = [] - for i in range(n): - res_ab = results_ab[i] - res_ba = results_ba[i] - - # 1. Any Contradiction -> Contradiction (Sensitive detection, filtered by LLM later) - if res_ab == NLIResult.CONTRADICTION or res_ba == NLIResult.CONTRADICTION: - final_results.append(NLIResult.CONTRADICTION) - - # 2. 
Any Entailment -> Duplicate (as per user requirement) - elif res_ab == NLIResult.DUPLICATE or res_ba == NLIResult.DUPLICATE: - final_results.append(NLIResult.DUPLICATE) - - # 3. Otherwise (Both Neutral) -> Unrelated - else: - final_results.append(NLIResult.UNRELATED) - - return final_results diff --git a/src/memos/extras/nli_model/server/serve.py b/src/memos/extras/nli_model/server/serve.py deleted file mode 100644 index 0ed9eae65..000000000 --- a/src/memos/extras/nli_model/server/serve.py +++ /dev/null @@ -1,44 +0,0 @@ -from contextlib import asynccontextmanager - -import uvicorn - -from fastapi import FastAPI, HTTPException - -from memos.extras.nli_model.server.config import NLI_DEVICE, NLI_MODEL_HOST, NLI_MODEL_PORT -from memos.extras.nli_model.server.handler import NLIHandler -from memos.extras.nli_model.types import CompareRequest, CompareResponse - - -# Global handler instance -nli_handler: NLIHandler | None = None - - -@asynccontextmanager -async def lifespan(app: FastAPI): - global nli_handler - nli_handler = NLIHandler(device=NLI_DEVICE) - yield - # Clean up if needed - nli_handler = None - - -app = FastAPI(lifespan=lifespan) - - -@app.post("/compare_one_to_many", response_model=CompareResponse) -async def compare_one_to_many(request: CompareRequest): - if nli_handler is None: - raise HTTPException(status_code=503, detail="Model not loaded") - try: - results = nli_handler.compare_one_to_many(request.source, request.targets) - return CompareResponse(results=results) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) from e - - -def start_server(host: str = "0.0.0.0", port: int = 32532): - uvicorn.run(app, host=host, port=port) - - -if __name__ == "__main__": - start_server(host=NLI_MODEL_HOST, port=NLI_MODEL_PORT) diff --git a/src/memos/extras/nli_model/types.py b/src/memos/extras/nli_model/types.py deleted file mode 100644 index 619f8f508..000000000 --- a/src/memos/extras/nli_model/types.py +++ /dev/null @@ -1,18 +0,0 @@ -from 
enum import Enum - -from pydantic import BaseModel - - -class NLIResult(Enum): - DUPLICATE = "Duplicate" - CONTRADICTION = "Contradiction" - UNRELATED = "Unrelated" - - -class CompareRequest(BaseModel): - source: str - targets: list[str] - - -class CompareResponse(BaseModel): - results: list[NLIResult] diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py index a57a40676..991541156 100644 --- a/src/memos/mem_feedback/feedback.py +++ b/src/memos/mem_feedback/feedback.py @@ -33,6 +33,8 @@ extract_working_binding_ids, ) from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import StopwordManager +from memos.plugins.hook_defs import H +from memos.plugins.hooks import trigger_single_hook if TYPE_CHECKING: @@ -245,6 +247,7 @@ def _single_add_operation( datetime.now().isoformat() ) to_add_memory.metadata.background = new_memory_item.metadata.background + to_add_memory.metadata.sources = new_memory_item.metadata.sources added_ids = self._retry_db_operation( lambda: self.memory_manager.add([to_add_memory], user_name=user_name, use_batch=False) @@ -288,33 +291,43 @@ def _single_update_operation( new_memory_item.memory = operation["text"] new_memory_item.metadata.embedding = self._batch_embed([operation["text"]])[0] - if memory_type == "WorkingMemory": - fields = { - "memory": new_memory_item.memory, - "key": new_memory_item.metadata.key, - "tags": new_memory_item.metadata.tags, - "embedding": new_memory_item.metadata.embedding, - "background": new_memory_item.metadata.background, - "covered_history": old_memory_item.id, - } - self.graph_store.update_node(old_memory_item.id, fields=fields, user_name=user_name) - item_id = old_memory_item.id - else: - done = self._single_add_operation( - old_memory_item, new_memory_item, user_id, user_name, async_mode + if getattr(self.mem_reader, "memory_version_switch", "off") != "on": + if memory_type == "WorkingMemory": + fields = { + "memory": new_memory_item.memory, + "key": 
new_memory_item.metadata.key, + "tags": new_memory_item.metadata.tags, + "embedding": new_memory_item.metadata.embedding, + "background": new_memory_item.metadata.background, + "covered_history": old_memory_item.id, + } + self.graph_store.update_node(old_memory_item.id, fields=fields, user_name=user_name) + item_id = old_memory_item.id + else: + done = self._single_add_operation( + old_memory_item, new_memory_item, user_id, user_name, async_mode + ) + item_id = done.get("id") + self.graph_store.update_node( + item_id, {"covered_history": old_memory_item.id}, user_name=user_name + ) + self.graph_store.update_node( + old_memory_item.id, {"status": "archived"}, user_name=user_name + ) + + logger.info( + f"[Memory Feedback UPDATE] New Add:{item_id} | Set archived:{old_memory_item.id} | memory_type: {memory_type}" ) - item_id = done.get("id") - self.graph_store.update_node( - item_id, {"covered_history": old_memory_item.id}, user_name=user_name + else: + item_id = self._single_update_operation_with_versions( + old_memory_item=old_memory_item, + new_memory_item=new_memory_item, + user_name=user_name, ) - self.graph_store.update_node( - old_memory_item.id, {"status": "archived"}, user_name=user_name + logger.info( + f"[Memory Feedback UPDATE] Updated:{item_id} | history appended | memory_type: {old_memory_item.metadata.memory_type}" ) - logger.info( - f"[Memory Feedback UPDATE] New Add:{item_id} | Set archived:{old_memory_item.id} | memory_type: {memory_type}" - ) - return { "id": item_id, "text": new_memory_item.memory, @@ -323,6 +336,78 @@ def _single_update_operation( "origin_memory": old_memory_item.memory, } + def _single_update_operation_with_versions( + self, + old_memory_item: TextualMemoryItem, + new_memory_item: TextualMemoryItem, + user_name: str, + ) -> str: + try: + updated_item, archived_item, archived_metadata, updated_fields = trigger_single_hook( + H.MEMORY_VERSION_APPLY_FEEDBACK_UPDATE, + old_item=old_memory_item, + new_item=new_memory_item, + 
user_name=user_name, + ) + except Exception as e: + logger.warning( + "[Memory Feedback UPDATE] history fallback for %s: %s", old_memory_item.id, e + ) + updated_item = old_memory_item.model_copy(deep=True) + updated_item.memory = new_memory_item.memory + updated_item.metadata.key = new_memory_item.metadata.key + updated_item.metadata.tags = new_memory_item.metadata.tags + updated_item.metadata.background = new_memory_item.metadata.background + if getattr(new_memory_item.metadata, "sources", None) is not None: + current_sources = list(updated_item.metadata.sources or []) + updated_item.metadata.sources = ( + list(new_memory_item.metadata.sources or []) + current_sources + ) + if getattr(new_memory_item.metadata, "embedding", None) is not None: + updated_item.metadata.embedding = new_memory_item.metadata.embedding + if updated_item.metadata.memory_type == "PreferenceMemory": + updated_item.metadata.preference = updated_item.memory + updated_fields = { + "memory": updated_item.memory, + "key": updated_item.metadata.key, + "tags": updated_item.metadata.tags, + "embedding": updated_item.metadata.embedding, + "background": updated_item.metadata.background, + "sources": [ + source.model_dump(exclude_none=True) + if hasattr(source, "model_dump") + else source + for source in (updated_item.metadata.sources or []) + ], + "covered_history": old_memory_item.id, + } + archived_item = None + archived_metadata = None + + if archived_item and archived_metadata: + try: + self.graph_store.add_node( + id=archived_item.id, + memory=archived_item.memory, + metadata=archived_metadata, + user_name=user_name, + ) + except Exception as e: + logger.warning( + "[Memory Feedback UPDATE] archive add failed for %s: %s", + old_memory_item.id, + e, + ) + self._retry_db_operation( + lambda: self.graph_store.update_node( + id=updated_item.id, + fields=updated_fields, + user_name=user_name, + ) + ) + self._del_working_binding(user_name, [old_memory_item]) + return updated_item.id + def 
_del_working_binding(self, user_name, mem_items: list[TextualMemoryItem]) -> set[str]: """Delete working memory bindings""" bindings_to_delete = extract_working_binding_ids(mem_items) @@ -331,9 +416,7 @@ def _del_working_binding(self, user_name, mem_items: list[TextualMemoryItem]) -> f"[Memory Feedback UPDATE] Extracted {len(bindings_to_delete)} working_binding ids to cleanup: {list(bindings_to_delete)}" ) - delete_ids = [] - if bindings_to_delete: - delete_ids = list({bindings_to_delete}) + delete_ids = list(bindings_to_delete) for mid in delete_ids: try: @@ -346,6 +429,7 @@ def _del_working_binding(self, user_name, mem_items: list[TextualMemoryItem]) -> logger.warning( f"[0107 Feedback Core:_del_working_binding] TreeTextMemory.delete_hard: failed to delete {mid}: {e}" ) + return bindings_to_delete def semantics_feedback( self, @@ -476,9 +560,22 @@ def semantics_feedback( f"[0107 Feedback Core: semantics_feedback] Operation failed for {original_op}: {e}", exc_info=True, ) - if update_results: - updated_ids = [item["archived_id"] for item in update_results] - self._del_working_binding(updated_ids, user_name) + if update_results and getattr(self.mem_reader, "memory_version_switch", "off") != "on": + archived_ids = [item["archived_id"] for item in update_results] + archived_items = [] + for aid in archived_ids: + try: + node = self.graph_store.get_node(aid, user_name=user_name) + if node: + archived_items.append(TextualMemoryItem(**node)) + except Exception as e: + logger.warning( + "[Memory Feedback] Failed to fetch archived item %s for working_binding cleanup: %s", + aid, + e, + ) + if archived_items: + self._del_working_binding(user_name, archived_items) return {"record": {"add": add_results, "update": update_results}} @@ -1066,7 +1163,14 @@ def check_validity(item): tags=tags, key=key, embedding=embedding, - sources=[{"type": "chat"}], + sources=[ + { + "type": "feedback", + "role": "user", + "chat_time": feedback_time, + "content": feedback_content, + } + ], 
background=background, type="fine", info=info, diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index 2edede76d..c1d65380f 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -15,6 +15,8 @@ from memos.mem_reader.simple_struct import PROMPT_DICT, SimpleStructMemReader from memos.mem_reader.utils import parse_json_result from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata +from memos.plugins.hook_defs import H +from memos.plugins.hooks import trigger_single_hook from memos.templates.mem_reader_prompts import MEMORY_MERGE_PROMPT_EN, MEMORY_MERGE_PROMPT_ZH from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN, TOOL_TRAJECTORY_PROMPT_ZH from memos.types import MessagesType @@ -58,6 +60,8 @@ def __init__(self, config: MultiModalStructMemReaderConfig): simple_config = SimpleStructMemReaderConfig(**config_dict) super().__init__(simple_config) + self.memory_version_switch = getattr(config, "memory_version_switch", "off") + # Image parser LLM (requires vision model) # Falls back to general_llm if not configured (general_llm itself falls back to main llm) self.image_parser_llm = ( @@ -458,6 +462,9 @@ def _get_llm_response( if self.config.remove_prompt_example and examples: prompt = prompt.replace(examples, "") + + logger.info(f"[MultiModalParser] Process String Fine Prompt: {prompt}") + messages = [{"role": "user", "content": prompt}] try: response_text = self.llm.generate(messages) @@ -506,6 +513,7 @@ def _get_maybe_merged_memory( sources: list, **kwargs, ) -> dict: + # TODO: delete this function """ Check if extracted memory should be merged with similar existing memories. If merge is needed, return merged memory dict with merged_from field. 
@@ -520,102 +528,7 @@ def _get_maybe_merged_memory( Returns: Memory dict (possibly merged) with merged_from field if merged """ - # If no graph_db or user_name, return original - if not self.graph_db or "user_name" not in kwargs: - return extracted_memory_dict - user_name = kwargs.get("user_name") - - # Detect language - lang = "en" - if sources: - for source in sources: - if hasattr(source, "lang") and source.lang: - lang = source.lang - break - elif isinstance(source, dict) and source.get("lang"): - lang = source.get("lang") - break - if lang is None: - lang = detect_lang(mem_text) - - # Search for similar memories - merge_threshold = kwargs.get("merge_similarity_threshold", 0.3) - - try: - search_results = self.graph_db.search_by_embedding( - vector=self.embedder.embed(mem_text)[0], - top_k=20, - status="activated", - threshold=merge_threshold, - user_name=user_name, - ) - - if not search_results: - return extracted_memory_dict - - # Get full memory details - similar_memory_ids = [r["id"] for r in search_results if r.get("id")] - similar_memories_list = [ - self.graph_db.get_node(mem_id, include_embedding=False, user_name=user_name) - for mem_id in similar_memory_ids - ] - - # Filter out None and mode:fast memories - filtered_similar = [] - for mem in similar_memories_list: - if not mem: - continue - mem_metadata = mem.get("metadata", {}) - tags = mem_metadata.get("tags", []) - if isinstance(tags, list) and "mode:fast" in tags: - continue - filtered_similar.append( - { - "id": mem.get("id"), - "memory": mem.get("memory", ""), - } - ) - logger.info( - f"Valid similar memories for {mem_text} is " - f"{len(filtered_similar)}: {filtered_similar}" - ) - - if not filtered_similar: - return extracted_memory_dict - - # Create a temporary TextualMemoryItem for merge check - temp_memory_item = TextualMemoryItem( - memory=mem_text, - metadata=TreeNodeTextualMemoryMetadata( - user_id="", - session_id="", - memory_type=extracted_memory_dict.get("memory_type", 
"LongTermMemory"), - status="activated", - tags=extracted_memory_dict.get("tags", []), - key=extracted_memory_dict.get("key", ""), - ), - ) - - # Try to merge with LLM - merge_result = self._merge_memories_with_llm( - temp_memory_item, filtered_similar, lang=lang - ) - - if merge_result: - # Return merged memory dict - merged_dict = extracted_memory_dict.copy() - merged_content = merge_result.get("value", mem_text) - merged_dict["value"] = merged_content - merged_from_ids = merge_result.get("merged_from", []) - merged_dict["merged_from"] = merged_from_ids - return merged_dict - else: - return extracted_memory_dict - - except Exception as e: - logger.error(f"[MultiModalFine] Error in get_maybe_merged_memory: {e}") - # On error, return original - return extracted_memory_dict + return extracted_memory_dict def _merge_memories_with_llm( self, @@ -717,6 +630,35 @@ def _process_one_item( # Determine prompt type based on sources prompt_type = self._determine_prompt_type(sources) + # ========== Stage 0: Memory version async extraction/update pipeline ========== + if getattr(self, "memory_version_switch", "off") == "on": + try: + user_name = kwargs.get("user_name") + should_use_version_pipeline = trigger_single_hook( + H.MEMORY_VERSION_PREPARE_UPDATES, + item=fast_item, + user_name=user_name, + judge_llm=self.general_llm, + ) + if should_use_version_pipeline: + lang = detect_lang(kwargs.get("chat_history") or mem_str) + custom_tags_prompt_template = PROMPT_DICT["custom_tags"][lang] + new_items = trigger_single_hook( + H.MEMORY_VERSION_APPLY_UPDATES, + item=fast_item, + user_name=user_name, + version_llm=self.qwen_llm, + merge_llm=self.general_llm, + custom_tags=custom_tags, + custom_tags_prompt_template=custom_tags_prompt_template, + timeout_sec=30, + ) + return new_items + except RuntimeError as ex: + logger.warning(f"[MultiModalFine] Memory version hook unavailable: {ex}") + except Exception as ex: + logger.warning(f"[MultiModalFine] Fine memory version pipeline failed: 
{ex}") + # ========== Stage 1: Normal extraction (without reference) ========== try: resp = self._get_llm_response(mem_str, custom_tags, sources, prompt_type) @@ -727,14 +669,15 @@ def _process_one_item( if resp.get("memory list", []): for m in resp.get("memory list", []): try: - # Check and merge with similar memories if needed - m_maybe_merged = self._get_maybe_merged_memory( - extracted_memory_dict=m, - mem_text=m.get("value", ""), - sources=sources, - original_query=mem_str, - **kwargs, - ) + m_maybe_merged = m + if getattr(self, "memory_version_switch", "off") != "on": + m_maybe_merged = self._get_maybe_merged_memory( + extracted_memory_dict=m, + mem_text=m.get("value", ""), + sources=sources, + original_query=mem_str, + **kwargs, + ) # Normalize memory_type (same as simple_struct) memory_type = ( m_maybe_merged.get("memory_type", "LongTermMemory") @@ -752,8 +695,10 @@ def _process_one_item( background=resp.get("summary", ""), **extra_kwargs, ) - # Add merged_from to info if present - if "merged_from" in m_maybe_merged: + if ( + getattr(self, "memory_version_switch", "off") != "on" + and "merged_from" in m_maybe_merged + ): node.metadata.info = node.metadata.info or {} node.metadata.info["merged_from"] = m_maybe_merged["merged_from"] fine_items.append(node) @@ -762,13 +707,15 @@ def _process_one_item( elif resp.get("value") and resp.get("key"): try: # Check and merge with similar memories if needed - resp_maybe_merged = self._get_maybe_merged_memory( - extracted_memory_dict=resp, - mem_text=resp.get("value", "").strip(), - sources=sources, - original_query=mem_str, - **kwargs, - ) + resp_maybe_merged = resp + if getattr(self, "memory_version_switch", "off") != "on": + resp_maybe_merged = self._get_maybe_merged_memory( + extracted_memory_dict=resp, + mem_text=resp.get("value", "").strip(), + sources=sources, + original_query=mem_str, + **kwargs, + ) node = self._make_memory_item( value=resp_maybe_merged.get("value", "").strip(), info=info_per_item, @@ -779,8 
+726,10 @@ def _process_one_item( background=resp.get("summary", ""), **extra_kwargs, ) - # Add merged_from to info if present - if "merged_from" in resp_maybe_merged: + if ( + getattr(self, "memory_version_switch", "off") != "on" + and "merged_from" in resp_maybe_merged + ): node.metadata.info = node.metadata.info or {} node.metadata.info["merged_from"] = resp_maybe_merged["merged_from"] fine_items.append(node) diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index f26be360c..3c82350df 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -11,6 +11,7 @@ from memos import log from memos.chunkers import ChunkerFactory +from memos.configs.llm import LLMConfigFactory from memos.configs.mem_reader import SimpleStructMemReaderConfig from memos.context.context import ContextThreadPoolExecutor from memos.embedders.factory import EmbedderFactory @@ -183,6 +184,15 @@ def __init__(self, config: SimpleStructMemReaderConfig): if config.general_llm is not None else self.llm ) + self.qwen_llm = None + qwen_llm_config = getattr(config, "qwen_llm", None) + if qwen_llm_config: + try: + if isinstance(qwen_llm_config, dict): + qwen_llm_config = LLMConfigFactory.model_validate(qwen_llm_config) + self.qwen_llm = LLMFactory.from_config(qwen_llm_config) + except Exception as e: + logger.warning(f"[LLM] Qwen initialization failed: {e}") self.embedder = EmbedderFactory.from_config(config.embedder) self.chunker = ChunkerFactory.from_config(config.chunker) self.save_rawfile = self.chunker.config.save_rawfile diff --git a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py index 671190e6f..56614737d 100644 --- a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py +++ b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py @@ -24,6 +24,8 @@ InternetRetrieverFactory, ) from 
memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer +from memos.plugins.component_bootstrap import build_plugin_context +from memos.plugins.manager import plugin_manager if TYPE_CHECKING: @@ -171,6 +173,16 @@ def build_internet_retriever_config() -> dict[str, Any]: return InternetRetrieverConfigFactory.model_validate(APIConfig.get_internet_config()) +def build_nli_client_config() -> dict[str, Any]: + """ + Build NLI client configuration. + + Returns: + NLI client configuration dictionary + """ + return APIConfig.get_nli_config() + + def _get_default_memory_size(cube_config: Any) -> dict[str, int]: """ Get default memory size configuration. @@ -246,6 +258,7 @@ def init_components() -> dict[str, Any]: graph_db_config = build_graph_db_config() llm_config = build_llm_config() embedder_config = build_embedder_config() + nli_client_config = build_nli_client_config() mem_reader_config = build_mem_reader_config() reranker_config = build_reranker_config() feedback_reranker_config = build_feedback_reranker_config() @@ -257,8 +270,25 @@ def init_components() -> dict[str, Any]: graph_db = GraphStoreFactory.from_config(graph_db_config) llm = LLMFactory.from_config(llm_config) embedder = EmbedderFactory.from_config(embedder_config) + + plugin_manager.discover() + plugin_context = build_plugin_context( + graph_db=graph_db, + embedder=embedder, + default_cube_config=default_cube_config, + nli_client_config=nli_client_config, + mem_reader_config=mem_reader_config, + reranker_config=reranker_config, + feedback_reranker_config=feedback_reranker_config, + internet_retriever_config=internet_retriever_config, + ) + plugin_manager.init_components(plugin_context) + # Pass graph_db to mem_reader for recall operations (deduplication, conflict detection) - mem_reader = MemReaderFactory.from_config(mem_reader_config, graph_db=graph_db) + mem_reader = MemReaderFactory.from_config( + mem_reader_config, + graph_db=graph_db, + ) reranker = 
RerankerFactory.from_config(reranker_config) feedback_reranker = RerankerFactory.from_config(feedback_reranker_config) internet_retriever = InternetRetrieverFactory.from_config( diff --git a/src/memos/mem_scheduler/schemas/task_schemas.py b/src/memos/mem_scheduler/schemas/task_schemas.py index af0f2f233..5b52b6c22 100644 --- a/src/memos/mem_scheduler/schemas/task_schemas.py +++ b/src/memos/mem_scheduler/schemas/task_schemas.py @@ -75,11 +75,30 @@ class TaskPriorityLevel(Enum): # task queue -DEFAULT_STREAM_KEY_PREFIX = os.getenv( - "MEMSCHEDULER_STREAM_KEY_PREFIX", "scheduler:messages:stream:v2.0" +DEFAULT_STREAM_KEY_PREFIX = "scheduler:messages:stream:v2.0" + +logger.info( + "[TASK_SCHEMAS] Module imported. env_MEMSCHEDULER_STREAM_KEY_PREFIX=%s, default_stream_prefix=%s", + os.getenv("MEMSCHEDULER_STREAM_KEY_PREFIX"), + DEFAULT_STREAM_KEY_PREFIX, ) +def get_stream_key_prefix() -> str: + """Resolve the scheduler stream prefix at runtime. + + This must not be evaluated at import time because some server entrypoints + load `.env` after importing router/handler modules. + """ + resolved = os.getenv("MEMSCHEDULER_STREAM_KEY_PREFIX", DEFAULT_STREAM_KEY_PREFIX) + logger.info( + "[TASK_SCHEMAS] Resolved stream prefix at runtime. 
env_MEMSCHEDULER_STREAM_KEY_PREFIX=%s, resolved_stream_prefix=%s", + os.getenv("MEMSCHEDULER_STREAM_KEY_PREFIX"), + resolved, + ) + return resolved + + # ============== Running Tasks ============== class RunningTaskItem(BaseModel, DictConversionMixin): """Data class for tracking running tasks in SchedulerDispatcher.""" diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py index 20dbb63b2..e0baf63ff 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py @@ -211,44 +211,48 @@ def _process_memories_with_reader( ) logger.info("Added %s Rawfile memories.", len(raw_file_mem_group)) - # Mark merged_from memories as archived when provided in memory metadata - summary_memories = [ - memory - for memory in flattened_memories - if memory.metadata.memory_type != "RawFileMemory" - ] - if mem_reader.graph_db: - for memory in summary_memories: - merged_from = (memory.metadata.info or {}).get("merged_from") - if merged_from: - old_ids = ( - merged_from - if isinstance(merged_from, (list | tuple | set)) - else [merged_from] - ) - for old_id in old_ids: - try: - mem_reader.graph_db.update_node( - str(old_id), {"status": "archived"}, user_name=user_name - ) - logger.info( - "[Scheduler] Archived merged_from memory: %s", - old_id, - ) - except Exception as e: - logger.warning( - "[Scheduler] Failed to archive merged_from memory %s: %s", - old_id, - e, - ) - else: - has_merged_from = any( - (m.metadata.info or {}).get("merged_from") for m in summary_memories - ) - if has_merged_from: - logger.warning( - "[Scheduler] merged_from provided but graph_db is unavailable; skip archiving." 
+ # fallback to simple deduplication logic when mem version switch is off + if getattr(mem_reader, "memory_version_switch", "off") != "on": + # Mark merged_from memories as archived when provided in memory metadata + summary_memories = [ + memory + for memory in flattened_memories + if memory.metadata.memory_type != "RawFileMemory" + ] + if mem_reader.graph_db: + for memory in summary_memories: + merged_from = (memory.metadata.info or {}).get("merged_from") + if merged_from: + old_ids = ( + merged_from + if isinstance(merged_from, (list | tuple | set)) + else [merged_from] + ) + for old_id in old_ids: + try: + mem_reader.graph_db.update_node( + str(old_id), + {"status": "archived"}, + user_name=user_name, + ) + logger.info( + "[Scheduler] Archived merged_from memory: %s", + old_id, + ) + except Exception as e: + logger.warning( + "[Scheduler] Failed to archive merged_from memory %s: %s", + old_id, + e, + ) + else: + has_merged_from = any( + (m.metadata.info or {}).get("merged_from") for m in summary_memories ) + if has_merged_from: + logger.warning( + "[Scheduler] merged_from provided but graph_db is unavailable; skip archiving." 
+ ) cloud_env = is_cloud_env() if cloud_env: @@ -386,10 +390,34 @@ def _process_memories_with_reader( delete_ids = list(dict.fromkeys(delete_ids)) if delete_ids: try: - text_mem.delete(delete_ids, user_name=user_name) - logger.info( - "Delete raw/working mem_ids: %s for user_name: %s", delete_ids, user_name - ) + if getattr(mem_reader, "memory_version_switch", "off") != "on": + text_mem.delete(delete_ids, user_name=user_name) + logger.info( + "Delete raw/working mem_ids: %s for user_name: %s", + delete_ids, + user_name, + ) + else: + # change to soft-delete for mem versions + flattened_memories = [] + if processed_memories and len(processed_memories) > 0: + for memory_list in processed_memories: + flattened_memories.extend(memory_list) + allowed_types = ["UserMemory", "LongTermMemory"] + text_mem.soft_delete( + delete_ids, + user_name, + [ + mem.id + for mem in flattened_memories + if mem.metadata.memory_type in allowed_types + ], + ) + logger.info( + "Soft delete raw/working mem_ids: %s for user_name: %s", + delete_ids, + user_name, + ) except Exception as e: logger.warning("Failed to delete some mem_ids %s: %s", delete_ids, e) else: diff --git a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py index 33d007313..beae965aa 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py @@ -13,7 +13,7 @@ from memos.log import get_logger from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem -from memos.mem_scheduler.schemas.task_schemas import DEFAULT_STREAM_KEY_PREFIX +from memos.mem_scheduler.schemas.task_schemas import get_stream_key_prefix from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker from 
memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule @@ -26,7 +26,7 @@ class SchedulerLocalQueue(RedisSchedulerModule): def __init__( self, maxsize: int = 0, - stream_key_prefix: str = DEFAULT_STREAM_KEY_PREFIX, + stream_key_prefix: str | None = None, orchestrator: SchedulerOrchestrator | None = None, status_tracker: TaskStatusTracker | None = None, ): @@ -42,7 +42,7 @@ def __init__( """ super().__init__() - self.stream_key_prefix = stream_key_prefix or "local_queue" + self.stream_key_prefix = stream_key_prefix or get_stream_key_prefix() self.max_internal_message_queue_size = maxsize @@ -56,7 +56,9 @@ def __init__( self._message_handler: Callable[[ScheduleMessageItem], None] | None = None logger.info( - f"SchedulerLocalQueue initialized with max_internal_message_queue_size={self.max_internal_message_queue_size}" + "SchedulerLocalQueue initialized with max_internal_message_queue_size=%s, stream_prefix=%s", + self.max_internal_message_queue_size, + self.stream_key_prefix, ) def get_stream_key(self, user_id: str, mem_cube_id: str, task_label: str) -> str: diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index a23e33c55..561d7931f 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -19,9 +19,9 @@ from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import ( DEFAULT_STREAM_INACTIVITY_DELETE_SECONDS, - DEFAULT_STREAM_KEY_PREFIX, DEFAULT_STREAM_KEYS_REFRESH_INTERVAL_SEC, DEFAULT_STREAM_RECENT_ACTIVE_SECONDS, + get_stream_key_prefix, ) from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker @@ -45,10 +45,7 @@ class SchedulerRedisQueue(RedisSchedulerModule): def __init__( self, - stream_key_prefix: 
str = os.getenv( - "MEMSCHEDULER_REDIS_STREAM_KEY_PREFIX", - DEFAULT_STREAM_KEY_PREFIX, - ), + stream_key_prefix: str | None = None, orchestrator: SchedulerOrchestrator | None = None, consumer_group: str = "scheduler_group", consumer_name: str | None = "scheduler_consumer", @@ -68,8 +65,12 @@ def __init__( auto_delete_acked: Whether to automatically delete acknowledged messages from stream """ super().__init__() + resolved_stream_key_prefix = stream_key_prefix or os.getenv( + "MEMSCHEDULER_REDIS_STREAM_KEY_PREFIX", + get_stream_key_prefix(), + ) # Stream configuration - self.stream_key_prefix = stream_key_prefix + self.stream_key_prefix = resolved_stream_key_prefix # Precompile regex for prefix filtering to reduce repeated compilation overhead self.stream_prefix_regex_pattern = re.compile(f"^{re.escape(self.stream_key_prefix)}:") self.consumer_group = consumer_group @@ -104,8 +105,15 @@ def __init__( self._empty_stream_seen_lock = threading.Lock() logger.info( - f"[REDIS_QUEUE] Initialized with stream_prefix='{self.stream_key_prefix}', " - f"consumer_group='{self.consumer_group}', consumer_name='{self.consumer_name}'" + "[REDIS_QUEUE] Initialized with stream_prefix='%s', " + "consumer_group='%s', consumer_name='%s', " + "env_MEMSCHEDULER_STREAM_KEY_PREFIX='%s', " + "env_MEMSCHEDULER_REDIS_STREAM_KEY_PREFIX='%s'", + self.stream_key_prefix, + self.consumer_group, + self.consumer_name, + os.getenv("MEMSCHEDULER_STREAM_KEY_PREFIX"), + os.getenv("MEMSCHEDULER_REDIS_STREAM_KEY_PREFIX"), ) # Auto-initialize Redis connection diff --git a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py index 1f2e81bef..6bcf0023c 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py @@ -48,6 +48,12 @@ def __init__( self.memos_message_queue = SchedulerLocalQueue(maxsize=self.maxsize) self.disabled_handlers = disabled_handlers + 
logger.info( + "[SCHEDULE_TASK_QUEUE] Initialized queue wrapper. use_redis_queue=%s, queue_type=%s, stream_prefix=%s", + self.use_redis_queue, + type(self.memos_message_queue).__name__, + getattr(self.memos_message_queue, "stream_key_prefix", None), + ) def set_status_tracker(self, status_tracker: TaskStatusTracker) -> None: """ diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index a9b2c43a4..b0f90b537 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -69,9 +69,9 @@ class ArchivedTextualMemory(BaseModel): memory: str | None = Field( default_factory=lambda: "", description="The content of the archived version of the memory." ) - update_type: Literal["conflict", "duplicate", "extract", "unrelated"] = Field( + update_type: Literal["conflict", "duplicate", "extract", "unrelated", "feedback"] = Field( default="unrelated", - description="The type of the memory (e.g., `conflict`, `duplicate`, `extract`, `unrelated`).", + description="The type of the memory (e.g., `conflict`, `duplicate`, `extract`, `unrelated`, `feedback`).", ) archived_memory_id: str | None = Field( default=None, @@ -106,15 +106,15 @@ class TextualMemoryMetadata(BaseModel): default=None, description="Whether or not the memory was created in fast mode, carrying raw memory contents that haven't been edited by llms yet.", ) - evolve_to: list[str] | None = Field( + evolve_to: list[str] = Field( default_factory=list, - description="Only valid if a node was once a (raw)fast node. Recording which new memory nodes it 'evolves' to after llm extraction.", + description="Recording which new memory nodes it 'evolves' to after llm extraction.", ) - version: int | None = Field( - default=None, + version: int = Field( + default=1, description="The version of the memory. 
Will be incremented when the memory is updated.", ) - history: list[ArchivedTextualMemory] | None = Field( + history: list[ArchivedTextualMemory] = Field( default_factory=list, description="Storing the archived versions of the memory. Only preserving core information of each version.", ) diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 8c896f538..426a2359b 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -354,6 +354,14 @@ def get(self, memory_id: str, user_name: str | None = None) -> TextualMemoryItem """Get a memory by its ID.""" result = self.graph_store.get_node(memory_id, user_name=user_name) if result is None: + logger.warning( + "[TreeTextMemory.get] Memory not found. memory_id=%s, lookup_user_name=%s, graph_store=%s, db_name=%s, config_user_name=%s", + memory_id, + user_name, + type(self.graph_store).__name__, + getattr(self.graph_store, "db_name", None), + getattr(getattr(self.graph_store, "config", None), "user_name", None), + ) raise ValueError(f"Memory with ID {memory_id} not found") metadata_dict = result.get("metadata", {}) return TextualMemoryItem( @@ -609,3 +617,33 @@ def add_graph_edges( future.result() except Exception as e: logger.exception("Add edge error: ", exc_info=e) + + def soft_delete( + self, + memory_ids: list[str], + user_name: str, + evolve_to_ids: list[str] | None = None, + ) -> None: + # for ruff check... + if not evolve_to_ids: + update_fields = {"status": "deleted"} + else: + update_fields = {"status": "deleted", "evolve_to": evolve_to_ids} + + # Execute the actual marking operation - in db. 
+ with ContextThreadPoolExecutor() as executor: + futures = [] + for mid in memory_ids: + futures.append( + executor.submit( + self.graph_store.update_node, + id=mid, + fields=update_fields, + user_name=user_name, + ) + ) + + # Wait for all tasks to complete and raise any exceptions + for future in futures: + future.result() + return diff --git a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py b/src/memos/memories/textual/tree_text_memory/organize/history_manager.py deleted file mode 100644 index 98094877c..000000000 --- a/src/memos/memories/textual/tree_text_memory/organize/history_manager.py +++ /dev/null @@ -1,168 +0,0 @@ -import logging - -from typing import Literal - -from memos.context.context import ContextThreadPoolExecutor -from memos.extras.nli_model.client import NLIClient -from memos.extras.nli_model.types import NLIResult -from memos.graph_dbs.base import BaseGraphDB -from memos.memories.textual.item import ArchivedTextualMemory, TextualMemoryItem - - -logger = logging.getLogger(__name__) - -CONFLICT_MEMORY_TITLE = "[possibly conflicting memories]" -DUPLICATE_MEMORY_TITLE = "[possibly duplicate memories]" - - -def _append_related_content( - new_item: TextualMemoryItem, duplicates: list[str], conflicts: list[str] -) -> None: - """ - Append duplicate and conflict memory contents to the new item's memory text, - truncated to avoid excessive length. - """ - max_per_item_len = 200 - max_section_len = 1000 - - def _format_section(title: str, items: list[str]) -> str: - if not items: - return "" - - section_content = "" - for mem in items: - # Truncate individual item - snippet = mem[:max_per_item_len] + "..." if len(mem) > max_per_item_len else mem - # Check total section length - if len(section_content) + len(snippet) + 5 > max_section_len: - section_content += "\n- ... 
(more items truncated)" - break - section_content += f"\n- {snippet}" - - return f"\n\n{title}:{section_content}" - - append_text = "" - append_text += _format_section(CONFLICT_MEMORY_TITLE, conflicts) - append_text += _format_section(DUPLICATE_MEMORY_TITLE, duplicates) - - if append_text: - new_item.memory += append_text - - -def _detach_related_content(new_item: TextualMemoryItem) -> None: - """ - Detach duplicate and conflict memory contents from the new item's memory text. - """ - markers = [f"\n\n{CONFLICT_MEMORY_TITLE}:", f"\n\n{DUPLICATE_MEMORY_TITLE}:"] - - cut_index = -1 - for marker in markers: - idx = new_item.memory.find(marker) - if idx != -1 and (cut_index == -1 or idx < cut_index): - cut_index = idx - - if cut_index != -1: - new_item.memory = new_item.memory[:cut_index] - - return - - -class MemoryHistoryManager: - def __init__(self, nli_client: NLIClient, graph_db: BaseGraphDB) -> None: - """ - Initialize the MemoryHistoryManager. - - Args: - nli_client: NLIClient for conflict/duplicate detection. - graph_db: GraphDB instance for marking operations during history management. - """ - self.nli_client = nli_client - self.graph_db = graph_db - - def resolve_history_via_nli( - self, new_item: TextualMemoryItem, related_items: list[TextualMemoryItem] - ) -> list[TextualMemoryItem]: - """ - Detect relationships (Duplicate/Conflict) between the new item and related items using NLI, - and attach them as history to the new fast item. - - Args: - new_item: The new memory item being added. - related_items: Existing memory items that might be related. - - Returns: - List of duplicate or conflicting memory items judged by the NLI service. - """ - if not related_items: - return [] - - # 1. Call NLI - nli_results = self.nli_client.compare_one_to_many( - new_item.memory, [r.memory for r in related_items] - ) - - # 2. 
Process results and attach to history - duplicate_memories = [] - conflict_memories = [] - - for r_item, nli_res in zip(related_items, nli_results, strict=False): - if nli_res == NLIResult.DUPLICATE: - update_type = "duplicate" - duplicate_memories.append(r_item.memory) - elif nli_res == NLIResult.CONTRADICTION: - update_type = "conflict" - conflict_memories.append(r_item.memory) - else: - update_type = "unrelated" - - # Safely get created_at, fallback to updated_at - created_at = getattr(r_item.metadata, "created_at", None) or r_item.metadata.updated_at - - archived = ArchivedTextualMemory( - version=r_item.metadata.version or 1, - is_fast=r_item.metadata.is_fast or False, - memory=r_item.memory, - update_type=update_type, - archived_memory_id=r_item.id, - created_at=created_at, - ) - new_item.metadata.history.append(archived) - logger.info( - f"[Chunker: MemoryHistoryManager] Archived related memory {r_item.id} as {update_type} for new item {new_item.id}" - ) - - # 3. Concat duplicate/conflict memories to new_item.memory - # We will mark those old memories as invisible during fine processing, this op helps to avoid information loss. - _append_related_content(new_item, duplicate_memories, conflict_memories) - - return duplicate_memories + conflict_memories - - def mark_memory_status( - self, - memory_items: list[TextualMemoryItem], - status: Literal["activated", "resolving", "archived", "deleted"], - user_name: str | None = None, - ) -> None: - """ - Support status marking operations during history management. Common usages are: - 1. Mark conflict/duplicate old memories' status as "resolving", - to make them invisible to /search api, but still visible for PreUpdateRetriever. - 2. Mark resolved memories' status as "activated", to restore their visibility. - """ - # Execute the actual marking operation - in db. 
- with ContextThreadPoolExecutor() as executor: - futures = [] - for mem in memory_items: - futures.append( - executor.submit( - self.graph_db.update_node, - id=mem.id, - fields={"status": status}, - user_name=user_name, - ) - ) - - # Wait for all tasks to complete and raise any exceptions - for future in futures: - future.result() - return diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index ecd58f309..1fe6a4811 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -191,10 +191,12 @@ def _add_memories_batch( ) metadata_dict = memory.metadata.model_dump(exclude_none=True) metadata_dict["updated_at"] = datetime.now().isoformat() + metadata_dict["working_binding"] = working_id # Add working_binding for fast mode tags = metadata_dict.get("tags") or [] if "mode:fast" in tags: + metadata_dict["is_fast"] = True # Temporary fix prev_bg = metadata_dict.get("background", "") or "" binding_line = f"[working_binding:{working_id}] direct built from raw inputs" metadata_dict["background"] = ( @@ -234,6 +236,8 @@ def _submit_batches(nodes: list[dict], node_kind: str) -> None: exc_info=e, ) + # TODO: working id is the same as item.id, need to fix; currently stop adding WorkingMemories here. 
+ # here used to be: _submit_batches(working_nodes, "WorkingMemory") _submit_batches(graph_nodes, "graph memory") if graph_node_ids and self.is_reorganize: @@ -319,6 +323,7 @@ def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = Non ids: list[str] = [] futures = [] + # TODO: working id is the same as item.id, need to fix working_id = memory.id if hasattr(memory, "id") else memory.id or str(uuid.uuid4()) with ContextThreadPoolExecutor(max_workers=2, thread_name_prefix="mem") as ex: @@ -395,8 +400,11 @@ def _add_to_graph_memory( node_id = memory.id if hasattr(memory, "id") else str(uuid.uuid4()) # Step 2: Add new node to graph metadata_dict = memory.metadata.model_dump(exclude_none=True) + if working_binding: + metadata_dict["working_binding"] = working_binding tags = metadata_dict.get("tags") or [] if working_binding and ("mode:fast" in tags): + metadata_dict["is_fast"] = True # Temporary fix prev_bg = metadata_dict.get("background", "") or "" binding_line = f"[working_binding:{working_binding}] direct built from raw inputs" if prev_bg: diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py b/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py deleted file mode 100644 index cb77d2243..000000000 --- a/src/memos/memories/textual/tree_text_memory/retrieve/pre_update.py +++ /dev/null @@ -1,264 +0,0 @@ -import concurrent.futures -import re - -from typing import Any - -from memos.context.context import ContextThreadPoolExecutor -from memos.log import get_logger -from memos.mem_reader.read_multi_modal.utils import detect_lang -from memos.memories.textual.item import TextualMemoryItem -from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer - - -logger = get_logger(__name__) - - -class PreUpdateRetriever: - def __init__(self, graph_db, embedder): - """ - The PreUpdateRetriever is designed for the /add phase . 
- It serves to recall potentially duplicate/conflict memories against the new content that's being added. - - Args: - graph_db: The graph database instance (Neo4j, PolarDB, etc.) - embedder: The embedder instance for vector search - """ - self.graph_db = graph_db - self.embedder = embedder - # Use existing tokenizer for keyword extraction - self.tokenizer = FastTokenizer(use_jieba=True, use_stopwords=True) - - def _adjust_perspective(self, text: str, role: str, lang: str) -> str: - """ - For better search result, we adjust the perspective - from 1st person to 3rd person based on role and language. - "I" -> "User" (if role is user) - "I" -> "Assistant" (if role is assistant) - """ - if not role: - return text - - role = role.lower() - replacements = [] - - # Determine replacements based on language and role - if lang == "zh": - if role == "user": - replacements = [("我", "用户")] - elif role == "assistant": - replacements = [("我", "助手")] - else: # default to en - if role == "user": - replacements = [ - (r"\bI\b", "User"), - (r"\bme\b", "User"), - (r"\bmy\b", "User's"), - (r"\bmine\b", "User's"), - (r"\bmyself\b", "User himself"), - ] - elif role == "assistant": - replacements = [ - (r"\bI\b", "Assistant"), - (r"\bme\b", "Assistant"), - (r"\bmy\b", "Assistant's"), - (r"\bmine\b", "Assistant's"), - (r"\bmyself\b", "Assistant himself"), - ] - - adjusted_text = text - for pattern, repl in replacements: - if lang == "zh": - adjusted_text = adjusted_text.replace(pattern, repl) - else: - adjusted_text = re.sub(pattern, repl, adjusted_text, flags=re.IGNORECASE) - - return adjusted_text - - def _preprocess_query(self, item: TextualMemoryItem) -> str: - """ - Preprocess the query item: - 1. Extract language and role from metadata/sources - 2. 
Adjust perspective (I -> User/Assistant) based on role/lang - """ - raw_text = item.memory or "" - if not raw_text.strip(): - return "" - - # Extract lang/role - lang = None - role = None - sources = item.metadata.sources - - if sources: - source_list = sources if isinstance(sources, list) else [sources] - for source in source_list: - if hasattr(source, "lang") and source.lang: - lang = source.lang - elif isinstance(source, dict) and source.get("lang"): - lang = source.get("lang") - - if hasattr(source, "role") and source.role: - role = source.role - elif isinstance(source, dict) and source.get("role"): - role = source.get("role") - - if lang and role: - break - - if lang is None: - lang = detect_lang(raw_text) - - # Adjust perspective - return self._adjust_perspective(raw_text, role, lang) - - def _get_full_memories( - self, candidate_ids: list[str], user_name: str - ) -> list[TextualMemoryItem]: - """ - Retrieve full memories for given candidate ids. - """ - full_recalled_memories = self.graph_db.get_nodes(candidate_ids, user_name=user_name) - return [TextualMemoryItem.from_dict(item) for item in full_recalled_memories] - - def vector_search( - self, - query_text: str, - query_embedding: list[float] | None, - user_name: str, - top_k: int, - search_filter: dict[str, Any] | None = None, - threshold: float = 0.5, - ) -> list[dict]: - try: - # Use pre-computed embedding if available (matches raw/clean query) - # Otherwise embed the switched query for better semantic match - q_embed = query_embedding if query_embedding else self.embedder.embed([query_text])[0] - - # Assuming graph_db.search_by_embedding returns list of dicts or items - results = self.graph_db.search_by_embedding( - vector=q_embed, - top_k=top_k, - status=None, - threshold=threshold, - user_name=user_name, - filter=search_filter, - ) - return results - except Exception as e: - logger.error(f"[PreUpdateRetriever] Vector search failed: {e}") - return [] - - def keyword_search( - self, - query_text: str, 
- user_name: str, - top_k: int, - search_filter: dict[str, Any] | None = None, - ) -> list[dict]: - try: - # 1. Tokenize using existing tokenizer - keywords = self.tokenizer.tokenize_mixed(query_text) - if not keywords: - return [] - - results = [] - - # 2. Try search_by_keywords_tfidf (PolarDB specific) - if hasattr(self.graph_db, "search_by_keywords_tfidf"): - try: - results = self.graph_db.search_by_keywords_tfidf( - query_words=keywords, user_name=user_name, filter=search_filter - ) - except Exception as e: - logger.warning(f"[PreUpdateRetriever] search_by_keywords_tfidf failed: {e}") - - # 3. Fallback to search_by_fulltext - if not results and hasattr(self.graph_db, "search_by_fulltext"): - try: - results = self.graph_db.search_by_fulltext( - query_words=keywords, top_k=top_k, user_name=user_name, filter=search_filter - ) - except Exception as e: - logger.warning(f"[PreUpdateRetriever] search_by_fulltext failed: {e}") - - return results[:top_k] - - except Exception as e: - logger.error(f"[PreUpdateRetriever] Keyword search failed: {e}") - return [] - - def retrieve( - self, item: TextualMemoryItem, user_name: str, top_k: int = 10, sim_threshold: float = 0.5 - ) -> list[TextualMemoryItem]: - """ - Recall related memories for a TextualMemoryItem using hybrid search (Vector + Keyword). - Might actually return top_k ~ 2top_k items. - Designed for low latency. - - Args: - item: The memory item to find related memories for - user_name: User identifier for scoping search - top_k: Max number of results to return - sim_threshold: minimal similarity threshold for vector search - - Returns: - List of TextualMemoryItem - """ - # 1. Preprocess - switched_query = self._preprocess_query(item) - - # 2. 
Recall - futures = [] - common_filter = { - "status": {"in": ["activated", "resolving"]}, - "memory_type": {"in": ["LongTermMemory", "UserMemory", "WorkingMemory"]}, - } - - with ContextThreadPoolExecutor(max_workers=3, thread_name_prefix="fast_recall") as executor: - # Task A: Vector Search (Semantic) - query_embedding = ( - item.metadata.embedding if hasattr(item.metadata, "embedding") else None - ) - futures.append( - executor.submit( - self.vector_search, - switched_query, - query_embedding, - user_name, - top_k, - common_filter, - sim_threshold, - ) - ) - - # Task B: Keyword Search - futures.append( - executor.submit( - self.keyword_search, switched_query, user_name, top_k, common_filter - ) - ) - - # 3. Collect Results - retrieved_ids = set() # for deduplicating ids - for future in concurrent.futures.as_completed(futures): - try: - res = future.result() - if not res: - continue - - for r in res: - retrieved_ids.add(r["id"]) - - except Exception as e: - logger.error(f"[PreUpdateRetriever] Search future task failed: {e}") - - retrieved_ids = list(retrieved_ids) - - if not retrieved_ids: - return [] - - # 4. Retrieve full memories to from just ids - # TODO: We should modify the db functions to support returning arbitrary fields, instead of search twice. 
- final_memories = self._get_full_memories(retrieved_ids, user_name) - - return final_memories diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 1a8b7092a..22d3a253c 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -762,7 +762,14 @@ def _process_text_mem( ) # Mark merged_from memories as archived when provided in add_req.info - if sync_mode == "sync" and extract_mode == "fine": + if ( + sync_mode == "sync" + and extract_mode == "fine" + and ( + not hasattr(self.mem_reader, "memory_version_switch") + or self.mem_reader.memory_version_switch != "on" + ) + ): for memory in flattened_local: merged_from = (memory.metadata.info or {}).get("merged_from") if merged_from: diff --git a/src/memos/plugins/__init__.py b/src/memos/plugins/__init__.py new file mode 100644 index 000000000..0a0f8cde3 --- /dev/null +++ b/src/memos/plugins/__init__.py @@ -0,0 +1,20 @@ +from memos.plugins.base import MemOSPlugin +from memos.plugins.hook_defs import H, HookSpec, all_hook_specs, define_hook, get_hook_spec +from memos.plugins.hooks import hookable, register_hook, register_hooks, trigger_hook +from memos.plugins.manager import PluginManager, plugin_manager + + +__all__ = [ + "H", + "HookSpec", + "MemOSPlugin", + "PluginManager", + "all_hook_specs", + "define_hook", + "get_hook_spec", + "hookable", + "plugin_manager", + "register_hook", + "register_hooks", + "trigger_hook", +] diff --git a/src/memos/plugins/base.py b/src/memos/plugins/base.py new file mode 100644 index 000000000..e486cfc72 --- /dev/null +++ b/src/memos/plugins/base.py @@ -0,0 +1,75 @@ +"""MemOS plugin base class — all plugins must inherit from this class.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from collections.abc import Callable + + from fastapi import FastAPI + from starlette.middleware.base import BaseHTTPMiddleware + + +class MemOSPlugin: + """MemOS 
plugin base class. + + Provides three unified registration methods. Plugin developers need only + inherit from this class and register capabilities via self.register_* + in init_app. + """ + + name: str = "unnamed" + version: str = "0.0.0" + description: str = "" + + _app: FastAPI | None = None + + # ------------------------------------------------------------------ # + # Registration methods — called by plugins in init_app + # ------------------------------------------------------------------ # + + def register_router(self, router, **kwargs) -> None: + """Register a router.""" + self._app.include_router(router, **kwargs) + + def register_middleware(self, middleware_cls: type[BaseHTTPMiddleware], **kwargs) -> None: + """Register middleware.""" + self._app.add_middleware(middleware_cls, **kwargs) + + def register_hook(self, name: str, callback: Callable) -> None: + """Register a single Hook callback.""" + from memos.plugins.hooks import register_hook + + register_hook(name, callback) + + def register_hooks(self, names: list[str], callback: Callable) -> None: + """Batch-register the same callback to multiple Hook points.""" + from memos.plugins.hooks import register_hooks + + register_hooks(names, callback) + + # ------------------------------------------------------------------ # + # Internal methods — called by PluginManager, plugin developers need not care + # ------------------------------------------------------------------ # + + def _bind_app(self, app: FastAPI) -> None: + """Bind FastAPI instance so that register_* methods are available.""" + self._app = app + + # ------------------------------------------------------------------ # + # Lifecycle methods — override in subclasses + # ------------------------------------------------------------------ # + + def on_load(self) -> None: + """Called after the plugin is discovered. Used for initialization logic, e.g. 
checking dependencies, reading config.""" + + def init_app(self) -> None: + """Called after FastAPI app is bound. Register routes, middleware, and Hooks via self.register_* here.""" + + def init_components(self, context: dict) -> None: + """Called during server bootstrap to contribute runtime components.""" + + def on_shutdown(self) -> None: + """Called when the service shuts down. Used for resource cleanup.""" diff --git a/src/memos/plugins/component_bootstrap.py b/src/memos/plugins/component_bootstrap.py new file mode 100644 index 000000000..151e37215 --- /dev/null +++ b/src/memos/plugins/component_bootstrap.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from typing import Any + + +def build_plugin_context( + *, + graph_db: Any, + embedder: Any, + default_cube_config: Any, + nli_client_config: dict[str, Any], + mem_reader_config: Any, + reranker_config: Any, + feedback_reranker_config: Any, + internet_retriever_config: Any, +) -> dict[str, Any]: + return { + "shared": { + "graph_db": graph_db, + "embedder": embedder, + }, + "configs": { + "default_cube_config": default_cube_config, + "nli_client_config": nli_client_config, + "mem_reader_config": mem_reader_config, + "reranker_config": reranker_config, + "feedback_reranker_config": feedback_reranker_config, + "internet_retriever_config": internet_retriever_config, + }, + } diff --git a/src/memos/plugins/hook_defs.py b/src/memos/plugins/hook_defs.py new file mode 100644 index 000000000..030d5292f --- /dev/null +++ b/src/memos/plugins/hook_defs.py @@ -0,0 +1,132 @@ +"""Hook declaration registry — single source of truth for CE repo Hook points. + +The @hookable decorator automatically declares its before/after Hooks; no need to manually define_hook. +Hooks triggered by custom trigger_hook must be explicitly declared in this file. + +Plugin-owned Hooks should be declared within each plugin package, not in this file. 
+""" + +from __future__ import annotations + +import logging + +from dataclasses import dataclass + + +logger = logging.getLogger(__name__) + +_specs: dict[str, HookSpec] = {} + + +@dataclass(frozen=True) +class HookSpec: + """Hook spec definition.""" + + name: str + description: str + params: list[str] + pipe_key: str | None = None + + +def define_hook( + name: str, + *, + description: str, + params: list[str], + pipe_key: str | None = None, +) -> None: + """Declare a Hook point. Skips if already exists (idempotent).""" + if name in _specs: + return + _specs[name] = HookSpec( + name=name, + description=description, + params=params, + pipe_key=pipe_key, + ) + logger.debug("Hook defined: %s (pipe_key=%s)", name, pipe_key) + + +def get_hook_spec(name: str) -> HookSpec | None: + return _specs.get(name) + + +def all_hook_specs() -> dict[str, HookSpec]: + """Return all declared Hooks (including @hookable auto-declared + plugin-declared).""" + return dict(_specs) + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# CE Hook name constants +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + + +class H: + """CE Hook name constants. 
Plugin-owned Hook constants should be defined within the plugin package.""" + + # @hookable("add") — AddHandler.handle_add_memories + ADD_BEFORE = "add.before" + ADD_AFTER = "add.after" + + # @hookable("search") — SearchHandler.handle_search_memories + SEARCH_BEFORE = "search.before" + SEARCH_AFTER = "search.after" + + # Custom Hook (manually triggered via trigger_hook) + ADD_MEMORIES_POST_PROCESS = "add.memories.post_process" + + # mem_reader — generic extension point before LLM extraction + MEM_READER_PRE_EXTRACT = "mem_reader.pre_extract" + + # memory version — single-provider business hooks + MEMORY_VERSION_PREPARE_UPDATES = "memory_version.prepare_updates" + MEMORY_VERSION_APPLY_UPDATES = "memory_version.apply_updates" + MEMORY_VERSION_APPLY_FEEDBACK_UPDATE = "memory_version.apply_feedback_update" + + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# CE custom Hook declarations (@hookable-generated ones need not be declared here) +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +define_hook( + H.ADD_MEMORIES_POST_PROCESS, + description="Post-process result after add_memories returns, before constructing Response", + params=["request", "result"], + pipe_key="result", +) + +define_hook( + H.MEM_READER_PRE_EXTRACT, + description="Customize prompt before mem_reader LLM extraction", + params=["prompt", "prompt_type", "mem_str", "lang", "sources"], + pipe_key="prompt", +) + +define_hook( + H.MEMORY_VERSION_PREPARE_UPDATES, + description=( + "Prepare memory-version candidates and decide whether extraction should continue " + "through the version pipeline" + ), + params=["item", "user_name", "judge_llm"], +) + +define_hook( + H.MEMORY_VERSION_APPLY_UPDATES, + description="Apply memory-version updates during mem_reader extraction", + params=[ + "item", + "user_name", + "version_llm", + "merge_llm", + "custom_tags", + "custom_tags_prompt_template", + "timeout_sec", + ], +) + +define_hook( + 
H.MEMORY_VERSION_APPLY_FEEDBACK_UPDATE, + description="Apply memory-version update semantics during feedback update", + params=["old_item", "new_item", "user_name"], +) diff --git a/src/memos/plugins/hooks.py b/src/memos/plugins/hooks.py new file mode 100644 index 000000000..d5b64ce82 --- /dev/null +++ b/src/memos/plugins/hooks.py @@ -0,0 +1,150 @@ +"""Hook runtime — registration, triggering, and @hookable decorator.""" + +from __future__ import annotations + +import asyncio +import logging + +from collections import defaultdict +from functools import wraps +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from collections.abc import Callable + + +logger = logging.getLogger(__name__) + +_hooks: dict[str, list[Callable]] = defaultdict(list) + + +def register_hook(name: str, callback: Callable) -> None: + """Register a hook callback. Undeclared hook names will log a warning.""" + from memos.plugins.hook_defs import get_hook_spec + + if get_hook_spec(name) is None: + logger.warning( + "Registering callback for undeclared hook: %s (callback=%s)", + name, + getattr(callback, "__qualname__", repr(callback)), + ) + _hooks[name].append(callback) + logger.debug( + "Hook registered: %s -> %s", + name, + getattr(callback, "__qualname__", repr(callback)), + ) + + +def register_hooks(names: list[str], callback: Callable) -> None: + """Batch-register the same callback to multiple hook points.""" + for name in names: + register_hook(name, callback) + + +def trigger_hook(name: str, **kwargs: Any) -> Any: + """Trigger a hook, invoking all registered callbacks in order. 
+ + - Zero overhead when no callbacks are registered + - Undeclared hook names will log a warning and be skipped + - pipe_key is auto-fetched from HookSpec, supports piped return value passing + """ + from memos.plugins.hook_defs import get_hook_spec + + spec = get_hook_spec(name) + if spec is None: + logger.warning("Undeclared hook triggered: %s — ignored", name) + return None + + pipe_key = spec.pipe_key + + for cb in _hooks.get(name, []): + try: + rv = cb(**kwargs) + if pipe_key is not None and rv is not None: + kwargs[pipe_key] = rv + except Exception: + logger.exception( + "Hook %s callback %s failed", + name, + getattr(cb, "__qualname__", repr(cb)), + ) + + return kwargs.get(pipe_key) if pipe_key else None + + +def trigger_single_hook(name: str, **kwargs: Any) -> Any: + """Trigger a hook that must be implemented by exactly one callback.""" + from memos.plugins.hook_defs import get_hook_spec + + spec = get_hook_spec(name) + if spec is None: + raise RuntimeError(f"Undeclared hook triggered: {name}") + + callbacks = _hooks.get(name, []) + if not callbacks: + raise RuntimeError(f"No plugin registered required hook: {name}") + if len(callbacks) > 1: + raise RuntimeError(f"Multiple plugins registered single-provider hook: {name}") + + cb = callbacks[0] + try: + return cb(**kwargs) + except Exception: + logger.exception( + "Single hook %s callback %s failed", + name, + getattr(cb, "__qualname__", repr(cb)), + ) + raise + + +def hookable(name: str): + """Decorator: automatically triggers name.before / name.after hook before and after the method. + + Auto-declares before/after Hooks (idempotent); no need to manually define_hook in hook_defs.py. + Supports piped return values: before can modify request, after can modify result. + Compatible with both sync and async methods. 
+ """ + from memos.plugins.hook_defs import define_hook + + define_hook( + f"{name}.before", + description=f"Before {name} executes; can modify request", + params=["request"], + pipe_key="request", + ) + define_hook( + f"{name}.after", + description=f"After {name} executes; can modify result", + params=["request", "result"], + pipe_key="result", + ) + + def decorator(func): + if asyncio.iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(self, request, *args, **kwargs): + rv = trigger_hook(f"{name}.before", request=request) + request = rv if rv is not None else request + result = await func(self, request, *args, **kwargs) + rv = trigger_hook(f"{name}.after", request=request, result=result) + result = rv if rv is not None else result + return result + + return async_wrapper + + @wraps(func) + def sync_wrapper(self, request, *args, **kwargs): + rv = trigger_hook(f"{name}.before", request=request) + request = rv if rv is not None else request + result = func(self, request, *args, **kwargs) + rv = trigger_hook(f"{name}.after", request=request, result=result) + result = rv if rv is not None else result + return result + + return sync_wrapper + + return decorator diff --git a/src/memos/plugins/manager.py b/src/memos/plugins/manager.py new file mode 100644 index 000000000..5f0397dc1 --- /dev/null +++ b/src/memos/plugins/manager.py @@ -0,0 +1,90 @@ +"""Plugin manager — discover, load, and manage MemOS plugins.""" + +from __future__ import annotations + +import importlib.metadata +import logging + +from typing import TYPE_CHECKING + +from memos.plugins.base import MemOSPlugin + + +if TYPE_CHECKING: + from fastapi import FastAPI + +logger = logging.getLogger(__name__) + +ENTRY_POINT_GROUP = "memos.plugins" + + +class PluginManager: + """Discover, load, and manage MemOS plugins.""" + + def __init__(self): + self._plugins: dict[str, MemOSPlugin] = {} + self._discovered = False + + @property + def plugins(self) -> dict[str, MemOSPlugin]: + return 
dict(self._plugins) + + def discover(self) -> None: + """Discover and load all installed plugins via entry_points.""" + if self._discovered: + return + + try: + eps = importlib.metadata.entry_points() + if hasattr(eps, "select"): + plugin_eps = eps.select(group=ENTRY_POINT_GROUP) + else: + plugin_eps = eps.get(ENTRY_POINT_GROUP, []) + except Exception: + logger.exception("Failed to query entry_points") + return + + for ep in plugin_eps: + try: + plugin_cls = ep.load() + plugin = plugin_cls() + if not isinstance(plugin, MemOSPlugin): + logger.warning("Plugin %s does not extend MemOSPlugin, skipped", ep.name) + continue + plugin.on_load() + self._plugins[plugin.name] = plugin + logger.info("Plugin discovered: %s v%s", plugin.name, plugin.version) + except Exception: + logger.exception("Failed to load plugin: %s", ep.name) + + self._discovered = True + + def init_components(self, context: dict) -> None: + """Initialize runtime components contributed by loaded plugins.""" + for plugin in self._plugins.values(): + try: + plugin.init_components(context) + logger.info("Plugin components initialized: %s", plugin.name) + except Exception: + logger.exception("Failed to init plugin components: %s", plugin.name) + + def init_app(self, app: FastAPI) -> None: + """Bind app and initialize all loaded plugins.""" + for plugin in self._plugins.values(): + try: + plugin._bind_app(app) + plugin.init_app() + logger.info("Plugin initialized: %s", plugin.name) + except Exception: + logger.exception("Failed to init plugin: %s", plugin.name) + + def shutdown(self) -> None: + """Shut down all plugins and release resources.""" + for plugin in self._plugins.values(): + try: + plugin.on_shutdown() + except Exception: + logger.exception("Failed to shutdown plugin: %s", plugin.name) + + +plugin_manager = PluginManager() diff --git a/tests/extras/nli_model/test_client_integration.py b/tests/extras/nli_model/test_client_integration.py deleted file mode 100644 index 5beff14a0..000000000 --- 
a/tests/extras/nli_model/test_client_integration.py +++ /dev/null @@ -1,129 +0,0 @@ -import threading -import time -import unittest - -from unittest.mock import MagicMock, patch - -import requests -import uvicorn - -from memos.extras.nli_model.client import NLIClient -from memos.extras.nli_model.server.serve import app -from memos.extras.nli_model.types import NLIResult - - -# We need to mock the NLIHandler to avoid loading the heavy model -# but we want to run the real FastAPI server. -class TestNLIClientIntegration(unittest.TestCase): - server_thread = None - stop_server = False - port = 32533 # Use a different port for testing - - @classmethod - def setUpClass(cls): - # Patch the lifespan to inject a mock handler instead of real NLIHandler - cls.mock_handler = MagicMock() - cls.mock_handler.compare_one_to_many.return_value = [ - NLIResult.DUPLICATE, - NLIResult.CONTRADICTION, - ] - - # We need to patch the module where lifespan is defined/used or modify the global variable - # Since 'app' is already imported, we can patch the global nli_handler in serve.py - # But lifespan sets it on startup. 
- - # Let's patch NLIHandler class in serve.py so when lifespan instantiates it, it gets our mock - cls.handler_patcher = patch("memos.extras.nli_model.server.serve.NLIHandler") - cls.MockHandlerClass = cls.handler_patcher.start() - cls.MockHandlerClass.return_value = cls.mock_handler - - # Start server in a thread - def run_server(): - # Disable logs for uvicorn to keep test output clean - config = uvicorn.Config(app, host="127.0.0.1", port=cls.port, log_level="error") - cls.server = uvicorn.Server(config) - cls.server.run() - - cls.server_thread = threading.Thread(target=run_server, daemon=True) - cls.server_thread.start() - - # Wait for server to be ready - cls._wait_for_server() - - @classmethod - def tearDownClass(cls): - # Stop the server - if hasattr(cls, "server"): - cls.server.should_exit = True - if cls.server_thread: - cls.server_thread.join(timeout=5) - - cls.handler_patcher.stop() - - @classmethod - def _wait_for_server(cls): - url = f"http://127.0.0.1:{cls.port}/docs" - retries = 20 - for _ in range(retries): - try: - response = requests.get(url) - if response.status_code == 200: - return - except requests.ConnectionError: - pass - time.sleep(0.1) - raise RuntimeError("Server failed to start") - - def setUp(self): - self.client = NLIClient(base_url=f"http://127.0.0.1:{self.port}") - # Reset mock calls before each test - self.mock_handler.reset_mock() - # Ensure default behavior - self.mock_handler.compare_one_to_many.return_value = [ - NLIResult.DUPLICATE, - NLIResult.CONTRADICTION, - ] - - def test_real_server_compare_one_to_many(self): - source = "I like apples." 
- targets = ["I love fruit.", "I hate apples."] - - results = self.client.compare_one_to_many(source, targets) - - # Verify result - self.assertEqual(len(results), 2) - self.assertEqual(results[0], NLIResult.DUPLICATE) - self.assertEqual(results[1], NLIResult.CONTRADICTION) - - # Verify server received the request - self.mock_handler.compare_one_to_many.assert_called_once() - args, _ = self.mock_handler.compare_one_to_many.call_args - self.assertEqual(args[0], source) - self.assertEqual(args[1], targets) - - def test_real_server_empty_targets(self): - source = "I like apples." - targets = [] - - results = self.client.compare_one_to_many(source, targets) - - self.assertEqual(results, []) - # Should not call handler because client handles empty list - self.mock_handler.compare_one_to_many.assert_not_called() - - def test_real_server_handler_error(self): - # Simulate handler error - self.mock_handler.compare_one_to_many.side_effect = ValueError("Something went wrong") - - source = "I like apples." 
- targets = ["I love fruit."] - - # Client should catch 500 and return UNRELATED - results = self.client.compare_one_to_many(source, targets) - - self.assertEqual(len(results), 1) - self.assertEqual(results[0], NLIResult.UNRELATED) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/mem_reader/test_project_id_propagation.py b/tests/mem_reader/test_project_id_propagation.py index 5a17910ca..65e276be6 100644 --- a/tests/mem_reader/test_project_id_propagation.py +++ b/tests/mem_reader/test_project_id_propagation.py @@ -216,6 +216,8 @@ def setUp(self): self.reader.graph_db = MagicMock() self.reader.oss_config = None self.reader.skills_dir_config = None + self.reader.memory_version_switch = "off" + self.reader.qwen_llm = MagicMock() # -- _build_window_from_items -------------------------------------------- def test_build_window_propagates_project_id(self): diff --git a/tests/memories/textual/test_history_manager.py b/tests/memories/textual/test_history_manager.py deleted file mode 100644 index a6ac186b7..000000000 --- a/tests/memories/textual/test_history_manager.py +++ /dev/null @@ -1,137 +0,0 @@ -import uuid - -from unittest.mock import MagicMock - -import pytest - -from memos.extras.nli_model.client import NLIClient -from memos.extras.nli_model.types import NLIResult -from memos.graph_dbs.base import BaseGraphDB -from memos.memories.textual.item import ( - TextualMemoryItem, - TextualMemoryMetadata, -) -from memos.memories.textual.tree_text_memory.organize.history_manager import ( - MemoryHistoryManager, - _append_related_content, - _detach_related_content, -) - - -@pytest.fixture -def mock_nli_client(): - client = MagicMock(spec=NLIClient) - return client - - -@pytest.fixture -def mock_graph_db(): - return MagicMock(spec=BaseGraphDB) - - -@pytest.fixture -def history_manager(mock_nli_client, mock_graph_db): - return MemoryHistoryManager(nli_client=mock_nli_client, graph_db=mock_graph_db) - - -def test_detach_related_content(): - original_memory = 
"This is the original memory content." - item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata()) - - duplicates = ["Duplicate 1", "Duplicate 2"] - conflicts = ["Conflict 1", "Conflict 2"] - - # 1. Append content - _append_related_content(item, duplicates, conflicts) - - # Verify content was appended - assert item.memory != original_memory - assert "[possibly conflicting memories]" in item.memory - assert "[possibly duplicate memories]" in item.memory - assert "Duplicate 1" in item.memory - assert "Conflict 1" in item.memory - - # 2. Detach content - _detach_related_content(item) - - # 3. Verify content is restored - assert item.memory == original_memory - - -def test_detach_only_conflicts(): - original_memory = "Original memory." - item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata()) - - duplicates = [] - conflicts = ["Conflict A"] - - _append_related_content(item, duplicates, conflicts) - assert "Conflict A" in item.memory - assert "Duplicate" not in item.memory - - _detach_related_content(item) - assert item.memory == original_memory - - -def test_detach_only_duplicates(): - original_memory = "Original memory." - item = TextualMemoryItem(memory=original_memory, metadata=TextualMemoryMetadata()) - - duplicates = ["Duplicate A"] - conflicts = [] - - _append_related_content(item, duplicates, conflicts) - assert "Duplicate A" in item.memory - assert "Conflict" not in item.memory - - _detach_related_content(item) - assert item.memory == original_memory - - -def test_truncation(history_manager, mock_nli_client): - # Setup - new_item = TextualMemoryItem(memory="Test") - long_memory = "A" * 300 - related_item = TextualMemoryItem(memory=long_memory) - - mock_nli_client.compare_one_to_many.return_value = [NLIResult.DUPLICATE] - - # Action - history_manager.resolve_history_via_nli(new_item, [related_item]) - - # Assert - assert "possibly duplicate memories" in new_item.memory - assert "..." 
in new_item.memory # Should be truncated - assert len(new_item.memory) < 1000 # Ensure reasonable length - - -def test_empty_related_items(history_manager, mock_nli_client): - new_item = TextualMemoryItem(memory="Test") - history_manager.resolve_history_via_nli(new_item, []) - - mock_nli_client.compare_one_to_many.assert_not_called() - assert new_item.metadata.history is None or len(new_item.metadata.history) == 0 - - -def test_mark_memory_status(history_manager, mock_graph_db): - # Setup - id1 = uuid.uuid4().hex - id2 = uuid.uuid4().hex - id3 = uuid.uuid4().hex - items = [ - TextualMemoryItem(memory="M1", id=id1), - TextualMemoryItem(memory="M2", id=id2), - TextualMemoryItem(memory="M3", id=id3), - ] - status = "resolving" - - # Action - history_manager.mark_memory_status(items, status) - - # Assert - assert mock_graph_db.update_node.call_count == 3 - - # Verify we called it correctly (user_name=None is passed by mark_memory_status) - mock_graph_db.update_node.assert_any_call(id=id1, fields={"status": status}, user_name=None) - mock_graph_db.update_node.assert_any_call(id=id2, fields={"status": status}, user_name=None) - mock_graph_db.update_node.assert_any_call(id=id3, fields={"status": status}, user_name=None) diff --git a/tests/memories/textual/test_pre_update_retriever.py b/tests/memories/textual/test_pre_update_retriever.py deleted file mode 100644 index 6bed90abb..000000000 --- a/tests/memories/textual/test_pre_update_retriever.py +++ /dev/null @@ -1,150 +0,0 @@ -import unittest -import uuid - -from dotenv import load_dotenv - -from memos.api.handlers.config_builders import build_embedder_config, build_graph_db_config -from memos.embedders.factory import EmbedderFactory -from memos.graph_dbs.factory import GraphStoreFactory -from memos.memories.textual.item import ( - SourceMessage, - TextualMemoryItem, - TreeNodeTextualMemoryMetadata, -) -from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever - - -# Load environment 
variables -load_dotenv() - - -class TestPreUpdateRecaller(unittest.TestCase): - @classmethod - def setUpClass(cls): - # Initialize graph_db and embedder using factories - # We assume environment variables are set for these to work - try: - cls.graph_db_config = build_graph_db_config() - cls.graph_db = GraphStoreFactory.from_config(cls.graph_db_config) - - cls.embedder_config = build_embedder_config() - cls.embedder = EmbedderFactory.from_config(cls.embedder_config) - except Exception as e: - raise unittest.SkipTest( - f"Skipping test because initialization failed (likely missing env vars): {e}" - ) from e - - cls.recaller = PreUpdateRetriever(cls.graph_db, cls.embedder) - - # Use a unique user name to isolate tests - cls.user_name = "test_pre_update_recaller_user_" + str(uuid.uuid4())[:8] - - def setUp(self): - # Add some data to the db - self.added_ids = [] - - # Create a memory item to add - self.memory_text = "The user likes to eat apples." - self.embedding = self.embedder.embed([self.memory_text])[0] - - # We use dictionary for metadata to simulate what might be passed or stored - # But wait, add_node expects metadata as a dict usually. - metadata = { - "memory_type": "LongTermMemory", - "status": "activated", - "embedding": self.embedding, - "created_at": "2023-01-01T00:00:00", - "updated_at": "2023-01-01T00:00:00", - "tags": ["food", "fruit"], - "key": "user_preference", - "sources": [], - } - - node_id = str(uuid.uuid4()) - self.graph_db.add_node(node_id, self.memory_text, metadata, user_name=self.user_name) - self.added_ids.append(node_id) - - # Add another one - self.memory_text_2 = "The user has a dog named Rex." 
- self.embedding_2 = self.embedder.embed([self.memory_text_2])[0] - metadata_2 = { - "memory_type": "LongTermMemory", - "status": "activated", - "embedding": self.embedding_2, - "created_at": "2023-01-01T00:00:00", - "updated_at": "2023-01-01T00:00:00", - "tags": ["pet", "dog"], - "key": "user_pet", - "sources": [], - } - node_id_2 = str(uuid.uuid4()) - self.graph_db.add_node(node_id_2, self.memory_text_2, metadata_2, user_name=self.user_name) - self.added_ids.append(node_id_2) - - def tearDown(self): - """Clean up test data.""" - for node_id in self.added_ids: - try: - self.graph_db.delete_node(node_id, user_name=self.user_name) - except Exception as e: - print(f"Error deleting node {node_id}: {e}") - - def test_recall_vector_search(self): - """Test recalling using vector search (implicit in recall method).""" - # "I like apples" -> perspective adjustment should match "The user likes to eat apples" - query_text = "I like apples" - - # Create metadata with source to trigger perspective adjustment - # role="user" means "I" -> "User" - source = SourceMessage(role="user", lang="en") - metadata = TreeNodeTextualMemoryMetadata(sources=[source], memory_type="WorkingMemory") - - item = TextualMemoryItem(memory=query_text, metadata=metadata) - - # The recall method does both vector and keyword search - results = self.recaller.retrieve(item, self.user_name, top_k=5) - - # Verify we got results - self.assertTrue(len(results) > 0, "Should return at least one result") - found_texts = [r.memory for r in results] - - # Check if the relevant memory is found - # "The user likes to eat apples." should be found. - # We check for "apples" to be safe - self.assertTrue( - any("apples" in t for t in found_texts), - f"Expected 'apples' in results, got: {found_texts}", - ) - - def test_recall_keyword_search(self): - """Test recalling where keyword search might be more relevant.""" - # "Rex" is a specific name - query_text = "What is the name of my dog?" 
- source = SourceMessage(role="user", lang="en") - metadata = TreeNodeTextualMemoryMetadata(sources=[source], memory_type="WorkingMemory") - - item = TextualMemoryItem(memory=query_text, metadata=metadata) - - results = self.recaller.retrieve(item, self.user_name, top_k=5) - - found_texts = [r.memory for r in results] - self.assertTrue( - any("Rex" in t for t in found_texts), f"Expected 'Rex' in results, got: {found_texts}" - ) - - def test_perspective_adjustment(self): - """Unit test for the _adjust_perspective method specifically.""" - text = "I went to the store myself." - adjusted = self.recaller._adjust_perspective(text, "user", "en") - # I -> User, myself -> User himself - self.assertIn("User", adjusted) - self.assertIn("User himself", adjusted) - - text_zh = "我喜欢吃苹果" - adjusted_zh = self.recaller._adjust_perspective(text_zh, "user", "zh") - # 我 -> 用户 - self.assertIn("用户", adjusted_zh) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/memories/textual/test_pre_update_retriever_latency.py b/tests/memories/textual/test_pre_update_retriever_latency.py deleted file mode 100644 index f4a359de9..000000000 --- a/tests/memories/textual/test_pre_update_retriever_latency.py +++ /dev/null @@ -1,183 +0,0 @@ -import time -import unittest -import uuid - -import numpy as np - -from dotenv import load_dotenv - -from memos.api.handlers.config_builders import build_embedder_config, build_graph_db_config -from memos.embedders.factory import EmbedderFactory -from memos.graph_dbs.factory import GraphStoreFactory -from memos.memories.textual.item import ( - SourceMessage, - TextualMemoryItem, - TreeNodeTextualMemoryMetadata, -) -from memos.memories.textual.tree_text_memory.retrieve.pre_update import PreUpdateRetriever - - -# Load environment variables -load_dotenv() - - -class TestPreUpdateRecallerLatency(unittest.TestCase): - """ - Performance and latency tests for PreUpdateRetriever. - These tests are designed to measure latency and might take longer to run. 
- """ - - @classmethod - def setUpClass(cls): - # Initialize graph_db and embedder using factories - try: - cls.graph_db_config = build_graph_db_config() - cls.graph_db = GraphStoreFactory.from_config(cls.graph_db_config) - - cls.embedder_config = build_embedder_config() - cls.embedder = EmbedderFactory.from_config(cls.embedder_config) - except Exception as e: - raise unittest.SkipTest( - f"Skipping test because initialization failed (likely missing env vars): {e}" - ) from e - - cls.recaller = PreUpdateRetriever(cls.graph_db, cls.embedder) - - # Use a unique user name to isolate tests - cls.user_name = "test_pre_update_recaller_latency_user_" + str(uuid.uuid4())[:8] - - def setUp(self): - # Add a substantial amount of data for latency testing - self.added_ids = [] - self.num_items = 20 - - print(f"\nPopulating database with {self.num_items} items for latency test...") - for i in range(self.num_items): - text = f"This is memory item number {i}. The user might enjoy topic {i % 5}." - embedding = self.embedder.embed([text])[0] - metadata = { - "memory_type": "LongTermMemory", - "status": "activated", - "embedding": embedding, - "created_at": "2023-01-01T00:00:00", - "updated_at": "2023-01-01T00:00:00", - "tags": [f"tag_{i}"], - "key": f"key_{i}", - "sources": [], - } - node_id = str(uuid.uuid4()) - self.graph_db.add_node(node_id, text, metadata, user_name=self.user_name) - self.added_ids.append(node_id) - - def tearDown(self): - """Clean up test data.""" - print("Cleaning up test data...") - for node_id in self.added_ids: - try: - self.graph_db.delete_node(node_id, user_name=self.user_name) - except Exception as e: - print(f"Error deleting node {node_id}: {e}") - - def measure_network_rtt(self, trials=10): - """Measure average network round-trip time.""" - print(f"Measuring Network RTT (using {trials} probes)...") - latencies = [] - - # Try to use raw driver for minimal overhead if available (Neo4j specific) - if hasattr(self.graph_db, "driver") and 
hasattr(self.graph_db, "db_name"): - print("Using Neo4j driver for direct ping...") - try: - with self.graph_db.driver.session(database=self.graph_db.db_name) as session: - # Warmup - session.run("RETURN 1").single() - - for _ in range(trials): - start = time.time() - session.run("RETURN 1").single() - latencies.append((time.time() - start) * 1000) - except Exception as e: - print(f"Direct driver ping failed: {e}. Falling back to get_node.") - latencies = [] - - if not latencies: - # Fallback to get_node with non-existent ID - print("Using get_node for ping...") - for _ in range(trials): - probe_id = str(uuid.uuid4()) - start = time.time() - self.graph_db.get_node(probe_id, user_name=self.user_name) - latencies.append((time.time() - start) * 1000) - - avg_rtt = np.mean(latencies) - print(f"Average Network RTT: {avg_rtt:.2f} ms") - return avg_rtt - - def test_recall_latency(self): - """Test and report recall latency statistics.""" - avg_rtt = self.measure_network_rtt() - - queries = [ - "I enjoy topic 1", - "What about topic 3?", - "Do I have any preferences?", - "Tell me about memory item 5", - ] - - latencies = [] - - # Warmup - print("Warming up...") - warmup_item = TextualMemoryItem( - memory="warmup query", - metadata=TreeNodeTextualMemoryMetadata( - sources=[SourceMessage(role="user", lang="en")], memory_type="WorkingMemory" - ), - ) - self.recaller.retrieve(warmup_item, self.user_name, top_k=5) - - print(f"Running {len(queries)} queries...") - for q in queries: - # Pre-calculate embedding to exclude from latency measurement - q_embedding = self.embedder.embed([q])[0] - - item = TextualMemoryItem( - memory=q, - metadata=TreeNodeTextualMemoryMetadata( - sources=[SourceMessage(role="user", lang="en")], - memory_type="WorkingMemory", - embedding=q_embedding, - ), - ) - - start_time = time.time() - results = self.recaller.retrieve(item, self.user_name, top_k=5) - end_time = time.time() - - duration_ms = (end_time - start_time) * 1000 - 
latencies.append(duration_ms) - print(f"Query: '{q}' -> Found {len(results)} results in {duration_ms:.2f} ms") - - # Assert that we actually found results (sanity check) - if "preferences" not in q: # The preferences query might return 0 - self.assertTrue(len(results) > 0, f"Expected results for query: {q}") - - # Report Results - avg_latency = np.mean(latencies) - p95_latency = np.percentile(latencies, 95) - min_latency = np.min(latencies) - max_latency = np.max(latencies) - internal_processing = avg_latency - avg_rtt - - print("\n--- Latency Results ---") - print(f"Average Network RTT: {avg_rtt:.2f} ms") - print(f"Average Total Latency: {avg_latency:.2f} ms") - print(f"Estimated Internal Processing: {internal_processing:.2f} ms") - print(f"95th Percentile: {p95_latency:.2f} ms") - print(f"Min Latency: {min_latency:.2f} ms") - print(f"Max Latency: {max_latency:.2f} ms") - - self.assertLess(internal_processing, 200, "Internal processing should be under 200ms") - - -if __name__ == "__main__": - unittest.main() diff --git a/src/memos/extras/nli_model/__init__.py b/tests/plugins/__init__.py similarity index 100% rename from src/memos/extras/nli_model/__init__.py rename to tests/plugins/__init__.py diff --git a/tests/plugins/conftest.py b/tests/plugins/conftest.py new file mode 100644 index 000000000..6a1a16b68 --- /dev/null +++ b/tests/plugins/conftest.py @@ -0,0 +1,17 @@ +"""Ensure @hookable-generated hooks are declared for core framework tests. + +In production, @hookable("add") runs at import time of add_handler.py, +declaring add.before / add.after. Core framework tests don't import handler +modules (to avoid heavy dependencies), so we trigger declarations here. + +Plugin-specific hooks are declared in each plugin's own tests/conftest.py. 
+""" + +from memos.plugins.hooks import hookable + + +hookable("add") +hookable("search") +hookable("chat") +hookable("feedback") +hookable("memory.get") diff --git a/src/memos/extras/nli_model/server/__init__.py b/tests/plugins/run_plugin_server.py similarity index 100% rename from src/memos/extras/nli_model/server/__init__.py rename to tests/plugins/run_plugin_server.py diff --git a/tests/plugins/test_plugin_demo.py b/tests/plugins/test_plugin_demo.py new file mode 100644 index 000000000..0eb65b208 --- /dev/null +++ b/tests/plugins/test_plugin_demo.py @@ -0,0 +1,474 @@ +""" +Plugin system core framework tests. + +Covers generic capabilities of the memos.plugins package (independent of specific plugin implementations): +1. Hook declaration registry (hook_defs) +2. Hook registration and triggering / pipe_key pipeline return value +3. @hookable decorator (sync + async + auto-declaration + pipeline return value) +4. MemOSPlugin base class register_* methods + +Plugin-specific functional tests are located in each plugin package: + extensions/memos_demo_plugin/tests/ +""" + +import asyncio +import logging + +from fastapi import FastAPI +from fastapi.testclient import TestClient + + +logging.basicConfig(level=logging.DEBUG) + + +# ========================================================================= # +# 1. 
Hook declaration registry (hook_defs) +# ========================================================================= # + + +class TestHookDefs: + def test_define_hook_and_get_spec(self): + from memos.plugins.hook_defs import define_hook, get_hook_spec + + define_hook( + "test.custom.hook", + description="test hook", + params=["request", "result"], + pipe_key="result", + ) + + spec = get_hook_spec("test.custom.hook") + assert spec is not None + assert spec.name == "test.custom.hook" + assert spec.params == ["request", "result"] + assert spec.pipe_key == "result" + + def test_define_hook_is_idempotent(self): + from memos.plugins.hook_defs import define_hook, get_hook_spec + + define_hook("test.idempotent", description="first", params=["a"], pipe_key="a") + define_hook("test.idempotent", description="second", params=["b"], pipe_key="b") + + spec = get_hook_spec("test.idempotent") + assert spec.description == "first" + + def test_get_hook_spec_returns_none_for_unknown(self): + from memos.plugins.hook_defs import get_hook_spec + + assert get_hook_spec("definitely.does.not.exist") is None + + def test_all_hook_specs_includes_custom(self): + from memos.plugins.hook_defs import H, all_hook_specs + + specs = all_hook_specs() + assert H.ADD_MEMORIES_POST_PROCESS in specs + + def test_h_constants(self): + from memos.plugins.hook_defs import H + + assert H.ADD_BEFORE == "add.before" + assert H.ADD_AFTER == "add.after" + assert H.SEARCH_BEFORE == "search.before" + assert H.SEARCH_AFTER == "search.after" + assert H.ADD_MEMORIES_POST_PROCESS == "add.memories.post_process" + + +# ========================================================================= # +# 2. 
Hook registration and triggering / pipe_key pipeline return value +# ========================================================================= # + + +class TestHookMechanism: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_register_and_trigger(self): + from memos.plugins.hooks import register_hook, trigger_hook + + captured = {} + + def my_callback(*, request, **kwargs): + captured["request"] = request + + register_hook("add.before", my_callback) + trigger_hook("add.before", request="test_request") + + assert captured["request"] == "test_request" + + def test_register_hooks_batch(self): + from memos.plugins.hooks import register_hooks, trigger_hook + + call_count = 0 + + def my_callback(**kwargs): + nonlocal call_count + call_count += 1 + + register_hooks(["add.before", "search.before"], my_callback) + trigger_hook("add.before") + trigger_hook("search.before") + + assert call_count == 2 + + def test_trigger_undeclared_hook_returns_none(self): + from memos.plugins.hooks import trigger_hook + + result = trigger_hook("nonexistent.undeclared.hook", request="anything") + assert result is None + + def test_hook_exception_does_not_propagate(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_hook + + define_hook("test.exception", description="test", params=["x"]) + + results = [] + + def bad_callback(**kwargs): + raise ValueError("intentional error") + + def good_callback(**kwargs): + results.append("ok") + + register_hook("test.exception", bad_callback) + register_hook("test.exception", good_callback) + trigger_hook("test.exception", x=1) + + assert results == ["ok"] + + def test_trigger_hook_pipe_key_returns_modified_value(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_hook + + define_hook( + "test.pipe", + description="pipe test", + params=["request", "result"], + pipe_key="result", + ) + + 
def double_result(*, request, result, **kwargs): + return result * 2 + + register_hook("test.pipe", double_result) + rv = trigger_hook("test.pipe", request="req", result=5) + + assert rv == 10 + + def test_trigger_hook_pipe_key_chains_callbacks(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_hook + + define_hook( + "test.chain", + description="chain test", + params=["result"], + pipe_key="result", + ) + + def add_one(*, result, **kwargs): + return result + 1 + + def add_ten(*, result, **kwargs): + return result + 10 + + register_hook("test.chain", add_one) + register_hook("test.chain", add_ten) + + rv = trigger_hook("test.chain", result=0) + assert rv == 11 + + def test_trigger_hook_pipe_key_none_callback_no_modify(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_hook + + define_hook( + "test.noop", + description="noop test", + params=["result"], + pipe_key="result", + ) + + def noop(*, result, **kwargs): + return None # explicitly return None — should not modify + + register_hook("test.noop", noop) + rv = trigger_hook("test.noop", result="original") + + assert rv == "original" + + def test_trigger_hook_notification_mode(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_hook + + define_hook( + "test.notify", + description="notification test", + params=["data"], + pipe_key=None, + ) + + captured = [] + + def observer(*, data, **kwargs): + captured.append(data) + + register_hook("test.notify", observer) + rv = trigger_hook("test.notify", data="hello") + + assert rv is None + assert captured == ["hello"] + + def test_trigger_single_hook_returns_value(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_single_hook + + define_hook("test.single", description="single", params=["value"]) + + def handler(*, value, 
**kwargs): + return value + 1 + + register_hook("test.single", handler) + + assert trigger_single_hook("test.single", value=1) == 2 + + def test_trigger_single_hook_requires_exactly_one_callback(self): + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import register_hook, trigger_single_hook + + define_hook("test.single.count", description="single", params=["value"]) + + def handler_a(*, value, **kwargs): + return value + 1 + + def handler_b(*, value, **kwargs): + return value + 2 + + register_hook("test.single.count", handler_a) + register_hook("test.single.count", handler_b) + + try: + trigger_single_hook("test.single.count", value=1) + except RuntimeError as exc: + assert "Multiple plugins" in str(exc) + else: + raise AssertionError("Expected RuntimeError for multiple callbacks") + + +# ========================================================================= # +# 3. @hookable decorator +# ========================================================================= # + + +class TestHookableDecorator: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_hookable_auto_declares_specs(self): + from memos.plugins.hook_defs import get_hook_spec + from memos.plugins.hooks import hookable + + @hookable("auto_test") + def dummy(self, request): + return request + + before_spec = get_hook_spec("auto_test.before") + after_spec = get_hook_spec("auto_test.after") + + assert before_spec is not None + assert before_spec.pipe_key == "request" + assert after_spec is not None + assert after_spec.pipe_key == "result" + + def test_hookable_sync(self): + from memos.plugins.hooks import hookable, register_hook + + events = [] + + def on_before(*, request, **kwargs): + events.append(("before", request)) + + def on_after(*, request, result, **kwargs): + events.append(("after", result)) + + register_hook("sync_demo.before", on_before) + register_hook("sync_demo.after", on_after) + + class FakeHandler: + 
@hookable("sync_demo") + def do_work(self, request): + return f"processed:{request}" + + result = FakeHandler().do_work("my_input") + + assert result == "processed:my_input" + assert events == [("before", "my_input"), ("after", "processed:my_input")] + + def test_hookable_async(self): + from memos.plugins.hooks import hookable, register_hook + + events = [] + + def on_before(*, request, **kwargs): + events.append("before") + + def on_after(*, request, result, **kwargs): + events.append("after") + + register_hook("async_demo.before", on_before) + register_hook("async_demo.after", on_after) + + class FakeHandler: + @hookable("async_demo") + async def do_work(self, request): + return "async_result" + + result = asyncio.run(FakeHandler().do_work("req")) + + assert result == "async_result" + assert events == ["before", "after"] + + def test_hookable_before_can_modify_request(self): + from memos.plugins.hooks import hookable, register_hook + + def rewrite_request(*, request, **kwargs): + return "modified_request" + + register_hook("modify_req.before", rewrite_request) + + class FakeHandler: + @hookable("modify_req") + def do_work(self, request): + return f"got:{request}" + + result = FakeHandler().do_work("original") + assert result == "got:modified_request" + + def test_hookable_after_can_modify_result(self): + from memos.plugins.hooks import hookable, register_hook + + def rewrite_result(*, request, result, **kwargs): + return f"{result}+modified" + + register_hook("modify_res.after", rewrite_result) + + class FakeHandler: + @hookable("modify_res") + def do_work(self, request): + return "original_result" + + result = FakeHandler().do_work("req") + assert result == "original_result+modified" + + def test_hookable_falsy_return_preserved(self): + """ensure empty list / 0 / empty string are not treated as None""" + from memos.plugins.hooks import hookable, register_hook + + def return_empty_list(*, request, result, **kwargs): + return [] + + 
register_hook("falsy_test.after", return_empty_list) + + class FakeHandler: + @hookable("falsy_test") + def do_work(self, request): + return [1, 2, 3] + + result = FakeHandler().do_work("req") + assert result == [] + + +# ========================================================================= # +# 4. Base class register_* methods +# ========================================================================= # + + +class TestBaseClassRegisterMethods: + def setup_method(self): + from memos.plugins.hooks import _hooks + + _hooks.clear() + + def test_register_router(self): + from fastapi import APIRouter + + from memos.plugins.base import MemOSPlugin + + app = FastAPI() + plugin = MemOSPlugin() + plugin._bind_app(app) + + router = APIRouter(prefix="/test") + + @router.get("/ping") + async def ping(): + return {"pong": True} + + plugin.register_router(router) + + paths = [r.path for r in app.routes] + assert "/test/ping" in paths + + def test_register_middleware(self): + from starlette.middleware.base import BaseHTTPMiddleware + + from memos.plugins.base import MemOSPlugin + + class NoopMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + return await call_next(request) + + app = FastAPI() + + @app.get("/x") + async def x(): + return {} + + plugin = MemOSPlugin() + plugin._bind_app(app) + plugin.register_middleware(NoopMiddleware) + + client = TestClient(app) + resp = client.get("/x") + assert resp.status_code == 200 + + def test_register_hook(self): + from memos.plugins.base import MemOSPlugin + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import trigger_hook + + define_hook("test.reg.event", description="test", params=["x"]) + + called = [] + plugin = MemOSPlugin() + plugin._bind_app(FastAPI()) + plugin.register_hook("test.reg.event", lambda **kw: called.append(True)) + + trigger_hook("test.reg.event", x=1) + assert called == [True] + + def test_register_hooks_batch(self): + from memos.plugins.base import 
MemOSPlugin + from memos.plugins.hook_defs import define_hook + from memos.plugins.hooks import trigger_hook + + define_hook("batch.a", description="a", params=["x"]) + define_hook("batch.b", description="b", params=["x"]) + + count = 0 + + def cb(**kw): + nonlocal count + count += 1 + + plugin = MemOSPlugin() + plugin._bind_app(FastAPI()) + plugin.register_hooks(["batch.a", "batch.b"], cb) + + trigger_hook("batch.a", x=1) + trigger_hook("batch.b", x=2) + assert count == 2