diff --git a/docker/.env.example b/docker/.env.example index c9d8e714e..9883ab022 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -172,6 +172,14 @@ POLAR_DB_USE_MULTI_DB=false # PolarDB connection pool size POLARDB_POOL_MAX_CONN=100 +# bytehouse endpoint/host +BYTEHOUSE_HOST=localhost +BYTEHOUSE_PORT=19000 +BYTEHOUSE_USER=bytehouse +BYTEHOUSE_PASSWORD=xxxxxx:xxxxxxxxx +BYTEHOUSE_DB_NAME=test_shared_memos_db +BYTEHOUSE_USE_MULTI_DB=false + ## Related configurations of Redis # Reddimq sends scheduling information and synchronization information for some variables MEMSCHEDULER_REDIS_HOST= # fallback keys if not using the global ones diff --git a/src/memos/api/config.py b/src/memos/api/config.py index c68deae5a..e799eecbf 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -860,6 +860,31 @@ def get_mysql_config() -> dict[str, Any]: "charset": os.getenv("MYSQL_CHARSET", "utf8mb4"), } + @staticmethod + def get_bytehouse_config(user_id: str | None = None) -> dict[str, Any]: + """Get ByteHouse configuration.""" + use_multi_db = os.getenv("BYTEHOUSE_USE_MULTI_DB", "false").lower() == "true" + + user_name = ( + f"memos{user_id.replace('-', '')}" if user_id else "memos_default" + ) + if use_multi_db: + db_name = user_name + else: + db_name = os.getenv("BYTEHOUSE_DB_NAME", "shared_memos_db") + + return { + "host": os.getenv("BYTEHOUSE_HOST", "localhost"), + "port": int(os.getenv("BYTEHOUSE_PORT", "9000")), + "user": os.getenv("BYTEHOUSE_USER", "default"), + "password": os.getenv("BYTEHOUSE_PASSWORD", ""), + "db_name": db_name, + "user_name": user_name, + "use_multi_db": use_multi_db, + "auto_create": True, + "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", "1024")), + } + @staticmethod + def get_scheduler_config() -> dict[str, Any]: """Get scheduler configuration.""" @@ -1132,11 +1157,13 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene else None 
) postgres_config = APIConfig.get_postgres_config(user_id=user_id) + bytehouse_config = APIConfig.get_bytehouse_config(user_id=user_id) graph_db_backend_map = { "neo4j-community": neo4j_community_config, "neo4j": neo4j_config, "polardb": polardb_config, "postgres": postgres_config, + "bytehouse": bytehouse_config, } # Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars graph_db_backend = os.getenv( @@ -1210,11 +1237,14 @@ def get_default_cube_config() -> "GeneralMemCubeConfig | None": neo4j_config = APIConfig.get_neo4j_config(user_id="default") polardb_config = APIConfig.get_polardb_config(user_id="default") postgres_config = APIConfig.get_postgres_config(user_id="default") + bytehouse_config = APIConfig.get_bytehouse_config(user_id="default") + graph_db_backend_map = { "neo4j-community": neo4j_community_config, "neo4j": neo4j_config, "polardb": polardb_config, "postgres": postgres_config, + "bytehouse": bytehouse_config, } internet_config = ( APIConfig.get_internet_config() diff --git a/src/memos/api/handlers/config_builders.py b/src/memos/api/handlers/config_builders.py index d29429fc9..14e9cb14b 100644 --- a/src/memos/api/handlers/config_builders.py +++ b/src/memos/api/handlers/config_builders.py @@ -41,6 +41,7 @@ def build_graph_db_config(user_id: str = "default") -> dict[str, Any]: "neo4j": APIConfig.get_neo4j_config(user_id=user_id), "polardb": APIConfig.get_polardb_config(user_id=user_id), "postgres": APIConfig.get_postgres_config(user_id=user_id), + "bytehouse": APIConfig.get_bytehouse_config(user_id=user_id), } # Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py index 98de09812..a23911885 100644 --- a/src/memos/configs/graph_db.py +++ b/src/memos/configs/graph_db.py @@ -241,6 +241,56 @@ def validate_config(self): return self +class ByteHouseGraphDBConfig(BaseConfig): + """ + ByteHouse configuration for MemOS. 
+ + Schema: + - memos_memories: Main table for memory nodes (id, memory, properties JSONB, embedding vector) + - memos_edges: Edge table for relationships (source_id, target_id, type) + + Example: + --- + host = "bytehouse" + port = 9000 + user = "n8n" + password = "secret" + db_name = "n8n" + user_name = "default" + """ + + host: str = Field(..., description="Database host") + port: int = Field(default=9000, description="Database port") + user: str = Field(..., description="Database user") + password: str = Field(..., description="Database password") + db_name: str = Field(..., description="Database name") + user_name: str = Field( + default="bytehouse", + description="Logical user/tenant ID for data isolation", + ) + use_multi_db: bool = Field( + default=False, + description="If False: use single database with logical isolation by user_name", + ) + auto_create: bool = Field( + default=False, + description="Whether to auto-create the database if it does not exist", + ) + embedding_dimension: int = Field( + default=1024, + description="Dimension of vector embedding (1024 for all-MiniLM-L6-v2)", + ) + + @model_validator(mode="after") + def validate_config(self): + """Validate config.""" + if not self.db_name: + raise ValueError("`db_name` must be provided") + if not self.use_multi_db and not self.user_name: + raise ValueError("In single-database mode, `user_name` must be provided") + return self + + class GraphDBConfigFactory(BaseModel): backend: str = Field(..., description="Backend for graph database") config: dict[str, Any] = Field(..., description="Configuration for the graph database backend") @@ -250,6 +300,7 @@ class GraphDBConfigFactory(BaseModel): "neo4j-community": Neo4jCommunityGraphDBConfig, "polardb": PolarDBGraphDBConfig, "postgres": PostgresGraphDBConfig, + "bytehouse": ByteHouseGraphDBConfig, } @field_validator("backend") diff --git a/src/memos/graph_dbs/bytehouse.py b/src/memos/graph_dbs/bytehouse.py new file mode 100644 index 000000000..add9745fd 
--- /dev/null +++ b/src/memos/graph_dbs/bytehouse.py @@ -0,0 +1,1002 @@ +import json + +from datetime import datetime +from typing import Any, Literal, Tuple + +from memos.configs.graph_db import ByteHouseGraphDBConfig +from memos.log import get_logger +from memos.dependency import require_python_package +from memos.graph_dbs.base import BaseGraphDB +from memos.utils import timed + + +logger = get_logger(__name__) + + +def _prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]: + """Ensure metadata has proper datetime fields and normalized types.""" + now = datetime.now().isoformat() + metadata.setdefault("created_at", now) + metadata.setdefault("updated_at", now) + + # Normalize embedding type + embedding = metadata.get("embedding") + if embedding and isinstance(embedding, list): + metadata["embedding"] = [float(x) for x in embedding] + + return metadata + + +class ByteHouseGraphDB(BaseGraphDB): + @require_python_package( + import_name="clickhouse_connect", + install_command="pip install clickhouse-connect", + install_link="https://pypi.org/project/clickhouse-connect/", + ) + def __init__(self, config: ByteHouseGraphDBConfig): + """Initialize ByteHouse connection.""" + import clickhouse_connect + + self.config = config + self.db_name = config.db_name + self.user_name = config.user_name + self.embedding_dimension = config.embedding_dimension + + logger.info( + f"Connecting to ByteHouse: {config.host}:{config.port}/{config.db_name}" + ) + + # Create ClickHouse client + self.client = clickhouse_connect.get_client( + host=config.host, + port=config.port, + username=config.user, + password=config.password, + secure=True, + compress=False, + autogenerate_session_id=False, + ) + + # Initialize schema and tables + self._init_schema() + + @timed + def _init_schema(self): + """Create schema and tables if they don't exist.""" + try: + # Create schema + self.client.command(f"CREATE DATABASE IF NOT EXISTS {self.db_name}") + + # Create memories table + dim = 
self.embedding_dimension + self.client.command( + f""" + CREATE TABLE IF NOT EXISTS {self.db_name}.memories ( + user_name String, + id String, + memory String, + properties String, + embedding Array(Float32) DEFAULT arrayWithConstant({dim},0), + created_at DateTime DEFAULT now(), + updated_at DateTime DEFAULT now(), + INDEX vec_idx_embedding embedding TYPE HNSW_SQ('DIM={dim}', 'METRIC=COSINE') GRANULARITY 1, + INDEX idx_id id TYPE inverted('noop','{{"version":"v2"}}') GRANULARITY 64, + INDEX idx_memory memory TYPE inverted('standard','{{"version":"v4"}}') GRANULARITY 64 + ) ENGINE = CnchMergeTree() + PARTITION BY user_name + ORDER BY created_at + UNIQUE KEY id SETTINGS index_granularity = 128, index_granularity_bytes = 0, + enable_unique_partial_update = 1, partition_level_unique_keys = 1 + """ + ) + + # Create edges table + self.client.command( + f""" + CREATE TABLE IF NOT EXISTS {self.db_name}.edges ( + user_name String, + id String, + source_id String, + target_id String, + edge_type String, + created_at DateTime DEFAULT now(), + INDEX idx_source_id source_id TYPE inverted('noop','{{"version":"v2"}}') GRANULARITY 1, + INDEX idx_target_id target_id TYPE inverted('noop','{{"version":"v2"}}') GRANULARITY 1 + ) ENGINE = CnchMergeTree() + PARTITION BY user_name + ORDER BY created_at + UNIQUE KEY id + SETTINGS enable_unique_partial_update = 1, partition_level_unique_keys = 1 + """ + ) + + logger.info(f"Schema {self.db_name} initialized successfully") + except Exception as e: + logger.error(f"Failed to init schema: {e}") + raise + + def _parse_row(self, row, include_embedding: bool = False) -> dict[str, Any]: + """Parse database row to node dict.""" + idx = 0 + result_id = row[idx] + idx += 1 + result_memory = row[idx] or "" + idx += 1 + props = json.loads(row[idx] or "{}") + idx += 1 + props["created_at"] = row[idx].isoformat() if row[idx] else None + idx += 1 + props["updated_at"] = row[idx].isoformat() if row[idx] else None + idx += 1 + + result = { + "id": 
result_id, + "memory": result_memory, + "metadata": props, + } + + if include_embedding and idx < len(row): + result["metadata"]["embedding"] = row[idx] + return result + + def _build_user_name_and_kb_ids_condition( + self, user_name: str, knowledgebase_ids: list[str] + ) -> str: + """Build ClickHouse condition for user_name and knowledgebase_ids.""" + if not knowledgebase_ids: + return f"user_name = '{user_name}'" + else: + return "user_name IN ['" + "','".join(knowledgebase_ids) + "','" + user_name + "']" + + def close(self): + """Close the ClickHouse client connection.""" + if hasattr(self, "client") and self.client: + self.client.close() + logger.info("ByteHouse connection closed") + + # ========================================================================= + # Node Management + # ========================================================================= + + @timed + def add_node( + self, + id: str, + memory: str, + metadata: dict[str, Any], + user_name: str | None = None, + ) -> None: + """Add a memory node.""" + user_name = user_name or self.user_name + metadata = _prepare_node_metadata(metadata.copy()) + + # Extract embedding + embedding = metadata.pop("embedding", None) + + # Parse ISO strings to datetime objects + created_at = datetime.now() + updated_at = datetime.now() + + # If embedding is empty, fill with zeros + if not embedding: + embedding = [0.0 for _ in range(self.embedding_dimension)] + + # Serialize sources if present + if metadata.get("sources"): + metadata["sources"] = [ + json.dumps(s) if not isinstance(s, str) else s + for s in metadata["sources"] + ] + + try: + self.client.insert( + f"{self.db_name}.memories", + [ + ( + id, + memory, + json.dumps(metadata), + embedding, + user_name, + created_at, + updated_at, + ) + ], + column_names=[ + "id", + "memory", + "properties", + "embedding", + "user_name", + "created_at", + "updated_at", + ], + column_type_names=[ + "String", + "String", + "String", + "Array(Float32)", + "String", + "DateTime", + 
"DateTime", + ], + ) + except Exception as e: + logger.error(f"Failed to add node: {e}") + raise + + @timed + def add_nodes_batch( + self, nodes: list[dict[str, Any]], user_name: str | None = None + ) -> None: + """Batch add memory nodes.""" + for node in nodes: + self.add_node( + id=node["id"], + memory=node["memory"], + metadata=node.get("metadata", {}), + user_name=user_name, + ) + + @timed + def update_node( + self, id: str, fields: dict[str, Any], user_name: str | None = None + ) -> None: + """Update node fields using ByteHouse partial column update.""" + user_name = user_name or self.user_name + if not fields: + return + + current = self.get_node(id, user_name=user_name) + if not current: + return + + # Merge properties + props = current.get("metadata", {}).copy() + embedding = fields.pop("embedding", None) + memory = fields.pop("memory", current.get("memory", "")) + props.update(fields) + props["updated_at"] = datetime.now().isoformat() + + try: + if embedding: + self.client.insert( + f"{self.db_name}.memories", + [ + ( + id, + memory, + json.dumps(props), + embedding, + user_name, + datetime.now(), + ) + ], + column_names=[ + "id", + "memory", + "properties", + "embedding", + "user_name", + "updated_at", + ], + column_type_names=[ + "String", + "String", + "String", + "Array(Float32)", + "String", + "DateTime", + ], + ) + + else: + self.client.insert( + f"{self.db_name}.memories", + [(id, memory, json.dumps(props), user_name, datetime.now())], + column_names=[ + "id", + "memory", + "properties", + "user_name", + "updated_at", + ], + column_type_names=[ + "String", + "String", + "String", + "String", + "DateTime", + ], + ) + + except Exception as e: + logger.error(f"Failed to update node: {e}") + raise + + @timed + def delete_node(self, id: str, user_name: str | None = None) -> None: + """Delete a node and its edges using _delete_flag_.""" + user_name = user_name or self.user_name + try: + self.client.insert( + f"{self.db_name}.memories", + [(id, user_name, 
1)], + column_names=["id", "user_name", "_delete_flag_"], + column_type_names=["String", "String", "UInt8"], + ) + + # Get related edges unique keys + edge_result = self.client.query( + f""" + SELECT id + FROM {self.db_name}.edges + WHERE (source_id = '{id}' OR target_id = '{id}') AND user_name = '{user_name}' + """, + ).result_set + + # Delete edges using INSERT with _delete_flag_ = 1 (only unique key needed) + if edge_result: + edge_values = [(edge_id, user_name, 1) for (edge_id,) in edge_result] + self.client.insert( + f"{self.db_name}.edges", + edge_values, + column_names=["id", "user_name", "_delete_flag_"], + column_type_names=["String", "String", "UInt8"], + ) + except Exception as e: + logger.error(f"Failed to delete node: {e}") + raise + + @timed + def get_node( + self, id: str, include_embedding: bool = False, user_name: str | None = None + ) -> dict[str, Any] | None: + """Get a single node by ID.""" + user_name = user_name or self.user_name + try: + cols = "id, memory, properties, created_at, updated_at" + if include_embedding: + cols += ", embedding" + result = self.client.query( + f"SELECT {cols} FROM {self.db_name}.memories WHERE id = '{id}' AND user_name = '{user_name}'", + ).result_set + if not result: + return None + return self._parse_row(result[0], include_embedding) + except Exception as e: + logger.error(f"Failed to get node: {e}") + return None + + @timed + def get_nodes( + self, + ids: list[str], + include_embedding: bool = False, + user_name: str | None = None, + **kwargs, + ) -> list[dict[str, Any]]: + """Get multiple nodes by IDs.""" + if not ids: + return [] + user_name = user_name or self.user_name + try: + cols = "id, memory, properties, created_at, updated_at" + if include_embedding: + cols += ", embedding" + ids_str = "','".join(ids) + result = self.client.query( + f"SELECT {cols} FROM {self.db_name}.memories WHERE id IN ('{ids_str}') AND user_name = '{user_name}'", + ).result_set + return [self._parse_row(row, include_embedding) for 
row in result] + except Exception as e: + logger.error(f"Failed to get nodes: {e}") + return [] + + @timed + def add_edge( + self, source_id: str, target_id: str, type: str, user_name: str | None = None + ) -> None: + """Create an edge between nodes.""" + user_name = user_name or self.user_name + edge_id = f"{source_id}_{target_id}_{type}" + + try: + self.client.insert( + f"{self.db_name}.edges", + [(edge_id, source_id, target_id, type, user_name, datetime.now())], + column_names=[ + "id", + "source_id", + "target_id", + "edge_type", + "user_name", + "created_at", + ], + column_type_names=[ + "String", + "String", + "String", + "String", + "String", + "DateTime", + ], + ) + except Exception as e: + logger.error(f"Failed to add edge: {e}") + raise + + @timed + def delete_edge( + self, source_id: str, target_id: str, type: str, user_name: str | None = None + ) -> None: + """Delete an edge using _delete_flag_.""" + user_name = user_name or self.user_name + try: + edge_id = f"{source_id}_{target_id}_{type}" + + self.client.insert( + f"{self.db_name}.edges", + [(user_name, edge_id, 1)], + column_names=["user_name", "id", "_delete_flag_"], + column_type_names=["String", "String", "UInt8"], + ) + except Exception as e: + logger.error(f"Failed to delete edge: {e}") + raise + + @timed + def edge_exists( + self, + source_id: str, + target_id: str, + type: str = "ANY", + direction: str = "OUTGOING", + user_name: str | None = None, + ) -> bool: + """ + Check if an edge exists between two nodes. + Args: + source_id: ID of the source node. + target_id: ID of the target node. + type: Relationship type. Use "ANY" to match any relationship type. + direction: Direction of the edge. + Use "OUTGOING" (default), "INCOMING", or "ANY". + user_name (str, optional): User name for filtering in non-multi-db mode + Returns: + True if the edge exists, otherwise False. 
+ """ + user_name = user_name or self.user_name + + try: + if direction == "ANY": + # Check both directions + if type == "ANY": + result = self.client.query( + f""" + SELECT 1 FROM {self.db_name}.edges + WHERE ((source_id = '{source_id}' AND target_id = '{target_id}') OR + (source_id = '{target_id}' AND target_id = '{source_id}')) + AND user_name = '{user_name}' + LIMIT 1 + """, + ).result_set + else: + result = self.client.query( + f""" + SELECT 1 FROM {self.db_name}.edges + WHERE ((source_id = '{source_id}' AND target_id = '{target_id}') OR + (source_id = '{target_id}' AND target_id = '{source_id}')) + AND edge_type = '{type}' AND user_name = '{user_name}' + LIMIT 1 + """, + ).result_set + else: + # Handle INCOMING direction by swapping source and target + if direction == "INCOMING": + source_id, target_id = target_id, source_id + + if type == "ANY": + result = self.client.query( + f""" + SELECT 1 FROM {self.db_name}.edges + WHERE source_id = '{source_id}' AND target_id = '{target_id}' + AND user_name = '{user_name}' + LIMIT 1 + """, + ).result_set + else: + result = self.client.query( + f""" + SELECT 1 FROM {self.db_name}.edges + WHERE source_id = '{source_id}' AND target_id = '{target_id}' + AND edge_type = '{type}' AND user_name = '{user_name}' + LIMIT 1 + """, + ).result_set + return len(result) > 0 + except Exception as e: + logger.error(f"Failed to check edge existence: {e}") + return False + + @timed + def get_neighbors( + self, + id: str, + type: str, + direction: Literal["in", "out", "both"] = "out", + **kwargs, + ) -> list[str]: + """Get neighboring node IDs.""" + user_name = kwargs.get("user_name") or self.user_name + try: + if direction == "out": + result = self.client.query( + f""" + SELECT + target_id + FROM {self.db_name}.edges + WHERE source_id = '{id}' AND edge_type = '{type}' AND user_name = '{user_name}' + """, + ).result_set + return [row[0] for row in result] + elif direction == "in": + result = self.client.query( + f""" + SELECT + source_id 
+ FROM {self.db_name}.edges + WHERE target_id = '{id}' AND edge_type = '{type}' AND user_name = '{user_name}' + """, + ).result_set + return [row[0] for row in result] + else: # both + result = self.client.query( + f""" + SELECT + target_id, + source_id + FROM {self.db_name}.edges + WHERE ((source_id = '{id}' AND edge_type = '{type}') + OR (target_id = '{id}' AND edge_type = '{type}')) + AND user_name = '{user_name}' + """, + ).result_set + # Remove duplicates and self loop + result_set = set() + for row in result: + result_set.add(row[0]) + result_set.add(row[1]) + result_set.remove(id) + return list(result_set) + + except Exception as e: + logger.error(f"Failed to get neighbors: {e}") + return [] + + def get_path( + self, source_id: str, target_id: str, max_depth: int = 3, **kwargs + ) -> list[str]: + """Get the path of nodes from source to target within a limited depth""" + raise NotImplementedError + + def get_subgraph( + self, + center_id: str, + depth: int = 2, + center_status: str = "activated", + user_name: str | None = None, + ) -> dict[str, Any]: + """Get subgraph around center node using iterative BFS.""" + raise NotImplementedError + + def get_context_chain(self, id: str, type: str = "FOLLOWS", **kwargs) -> list[str]: + """Get ordered chain following relationship type.""" + return self.get_neighbors(id, type, "out", **kwargs) + + @timed + def search_by_embedding( + self, + vector: list[float], + top_k: int = 5, + scope: str | None = None, + status: str | None = None, + threshold: float | None = None, + search_filter: dict | None = None, + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + **kwargs, + ) -> list[dict]: + """Search nodes by vector similarity using ByteHouse vector search.""" + user_name = user_name or self.user_name + + if not user_name: + return [] + + # Build WHERE clause and parameters + conditions = [self._build_user_name_and_kb_ids_condition(user_name, knowledgebase_ids)] + + 
params = {"vector": vector, "top_k": top_k, "user_name": user_name} + + if scope: + conditions.append( + "JSONExtractString(properties, 'memory_type') = {scope:String}" + ) + params["scope"] = scope + + if status: + conditions.append( + "JSONExtractString(properties, 'status') = {status:String}" + ) + params["status"] = status + else: + conditions.append( + "(JSONExtractString(properties, 'status') IN ['activated', ''])" + ) + + if search_filter: + for k, v in search_filter.items(): + param_name = f"filter_{k}" + conditions.append( + f"JSONExtractString(properties, '{k}') = {{{param_name}:String}}" + ) + params[param_name] = str(v) + + where_clause = " AND ".join(conditions) + + try: + # ByteHouse vector search using cosineDistance with QueryContext + qc = self.client.create_query_context( + query=f""" + SELECT + id, + 1 - distance / 2 as score + FROM + {self.db_name}.memories + PREWHERE + {where_clause} + ORDER BY + cosineDistance(embedding, {{vector:Array(Float32)}}) AS distance ASC + LIMIT + {{top_k:UInt32}} + SETTINGS enable_new_ann=1 + """, + parameters=params, + ) + query_result = self.client.query(context=qc) + result = query_result.result_set + + results = [] + for row in result: + score = float(row[1]) + if threshold is None or score >= threshold: + results.append({"id": row[0], "score": score}) + return results + except Exception as e: + logger.error(f"Failed to search by embedding: {e}") + return [] + + @timed + def get_by_metadata( + self, + filters: list[dict[str, Any]], + status: str | None = None, + user_name: str | None = None, + filter: dict | None = None, + knowledgebase_ids: list[str] | None = None, + user_name_flag: bool = True, + ) -> list[str]: + """Get node IDs matching metadata filters.""" + user_name = user_name or self.user_name + + conditions = [] + params = {} + + conditions.append( + self._build_user_name_and_kb_ids_condition(user_name, knowledgebase_ids) + ) + + if status: + conditions.append(f"JSONExtractString(properties, 'status') = 
'{status}'") + + where_clause = " AND ".join(conditions) + + try: + result = self.client.query( + f"SELECT id FROM {self.db_name}.memories WHERE {where_clause}" + ).result_set + return [row[0] for row in result] + except Exception as e: + logger.error(f"Failed to get by metadata: {e}") + return [] + + def get_structure_optimization_candidates( + self, scope: str, include_embedding: bool = False, **kwargs + ) -> list[dict]: + """Find isolated nodes (no edges).""" + user_name = kwargs.get("user_name") or self.user_name + try: + cols = "m.id, m.memory, m.properties, m.created_at, m.updated_at" + if include_embedding: + cols += ", m.embedding" + result = self.client.query( + f""" + SELECT {cols} + FROM {self.db_name}.memories m + LEFT JOIN {self.db_name}.edges e1 ON m.id = e1.source_id AND m.user_name = e1.user_name + LEFT JOIN {self.db_name}.edges e2 ON m.id = e2.target_id AND m.user_name = e2.user_name + WHERE JSONExtractString(m.properties, 'memory_type') = '{scope}' + AND m.user_name = '{user_name}' + AND (JSONExtractString(m.properties, 'status') IN ['activated', '']) + AND e1.id IS NULL + AND e2.id IS NULL + """, + ).result_set + return [self._parse_row(row, include_embedding) for row in result] + except Exception as e: + logger.error(f"Failed to get structure optimization candidates: {e}") + return [] + + def deduplicate_nodes(self, **kwargs) -> None: + """Not implemented - handled at application level.""" + + def detect_conflicts(self, **kwargs) -> list[Tuple[str, str]]: + """Not implemented.""" + return [] + + def merge_nodes(self, id1: str, id2: str, **kwargs) -> str: + """Not implemented.""" + raise NotImplementedError + + def clear(self, user_name: str | None = None) -> None: + """Clear all data for user using ALTER Drop""" + user_name = user_name or self.user_name + if not user_name: + return + + try: + self.client.command( + f""" + ALTER TABLE {self.db_name}.memories DROP PARTITION '{user_name}' + """, + ) + + self.client.command( + f""" + ALTER TABLE 
{self.db_name}.edges DROP PARTITION '{user_name}' + """, + ) + + except Exception as e: + logger.error(f"Failed to clear data: {e}") + raise + + def export_graph(self, include_embedding: bool = False, **kwargs) -> dict[str, Any]: + """Export all data.""" + user_name = kwargs.get("user_name") or self.user_name + try: + # Get nodes + cols = "id, memory, properties, created_at, updated_at" + if include_embedding: + cols += ", embedding" + result = self.client.query( + f""" + SELECT {cols} FROM {self.db_name}.memories + WHERE user_name = '{user_name}' + ORDER BY created_at DESC + """, + ).result_set + nodes = [self._parse_row(row, include_embedding) for row in result] + + # Get edges + node_ids = [n["id"] for n in nodes] + edges = [] + if node_ids: + node_ids_str = "','".join(node_ids) + edge_result = self.client.query( + f""" + SELECT source_id, target_id, edge_type + FROM {self.db_name}.edges + WHERE (source_id IN ('{node_ids_str}') OR target_id IN ('{node_ids_str}')) AND user_name = '{user_name}' + """, + ).result_set + edges = [ + {"source": row[0], "target": row[1], "type": row[2]} + for row in edge_result + ] + + return { + "nodes": nodes, + "edges": edges, + "total_nodes": len(nodes), + "total_edges": len(edges), + } + except Exception as e: + logger.error(f"Failed to export graph: {e}") + return {"nodes": [], "edges": [], "total_nodes": 0, "total_edges": 0} + + def import_graph(self, data: dict[str, Any]) -> None: + """Import graph data.""" + try: + for node in data.get("nodes", []): + self.add_node( + id=node["id"], + memory=node.get("memory", ""), + metadata=node.get("metadata", {}), + ) + + for edge in data.get("edges", []): + self.add_edge( + source_id=edge["source"], + target_id=edge["target"], + type=edge["type"], + ) + except Exception as e: + logger.error(f"Failed to import graph: {e}") + raise + + def get_all_memory_items( + self, + scope: str, + include_embedding: bool = False, + status: str | None = None, + filter: dict | None = None, + 
knowledgebase_ids: list[str] | None = None, + **kwargs, + ) -> list[dict]: + """Get all memory items of a specific type.""" + user_name = kwargs.get("user_name") or self.user_name + + conditions = [ + f"JSONExtractString(properties, 'memory_type') = '{scope}'", + f"user_name = '{user_name}'", + ] + + if status: + conditions.append(f"JSONExtractString(properties, 'status') = '{status}'") + + if filter: + for key, value in filter.items(): + conditions.append(f"JSONExtractString(properties, '{key}') = '{value}'") + + where_clause = " AND ".join(conditions) + + try: + cols = "id, memory, properties, created_at, updated_at" + if include_embedding: + cols += ", embedding" + result = self.client.query( + f"SELECT {cols} FROM {self.db_name}.memories WHERE {where_clause}", + ).result_set + return [self._parse_row(row, include_embedding) for row in result] + except Exception as e: + logger.error(f"Failed to get all memory items: {e}") + return [] + + @timed + def remove_oldest_memory( + self, memory_type: str, keep_latest: int, user_name: str | None = None + ) -> None: + """ + Remove all memories of a given type except the latest `keep_latest` entries. + + Args: + memory_type: Memory type (e.g., 'WorkingMemory', 'LongTermMemory'). + keep_latest: Number of latest entries to keep. + user_name: User to filter by. 
+ """ + user_name = user_name or self.user_name + try: + # query count + count_result = self.client.query( + f"SELECT COUNT(*) FROM {self.db_name}.memories WHERE JSONExtractString(properties, 'memory_type') = '{memory_type}' AND user_name = '{user_name}'" + ).result_set + total_count = count_result[0][0] + + if total_count <= keep_latest: + return + + # Get IDs of memories to delete (all except the latest keep_latest) + result = self.client.query( + f""" + SELECT id + FROM {self.db_name}.memories + WHERE JSONExtractString(properties, 'memory_type') = '{memory_type}' + AND user_name = '{user_name}' + ORDER BY created_at ASC + LIMIT {total_count - keep_latest} + """, + ).result_set + + # Delete memories using _delete_flag_ + if result: + memory_ids = [row[0] for row in result] + for memory_id in memory_ids: + self.client.insert( + f"{self.db_name}.memories", + [(memory_id, user_name, 1)], + column_names=["id", "user_name", "_delete_flag_"], + column_type_names=["String", "String", "UInt8"], + ) + + # Get related edges unique keys + edge_result = self.client.query( + f""" + SELECT id + FROM {self.db_name}.edges + WHERE (source_id = '{memory_id}' OR target_id = '{memory_id}') AND user_name = '{user_name}' + """, + ).result_set + + # Delete edges using INSERT with _delete_flag_ = 1 + if edge_result: + edge_values = [ + (edge_id, user_name, 1) for (edge_id,) in edge_result + ] + self.client.insert( + f"{self.db_name}.edges", + edge_values, + column_names=["id", "user_name", "_delete_flag_"], + column_type_names=["String", "String", "UInt8"], + ) + except Exception as e: + logger.error(f"Failed to remove oldest memory: {e}") + raise + + @timed + def get_grouped_counts( + self, + group_fields: list[str], + where_clause: str = "", + params: dict[str, Any] | None = None, + user_name: str | None = None, + ) -> list[dict[str, Any]]: + """ + Count nodes grouped by specified fields. 
+ + Args: + group_fields: Fields to group by, e.g., ["memory_type", "status"] + where_clause: Extra WHERE condition + params: Parameters for WHERE clause + user_name: User to filter by + + Returns: + list[dict]: e.g., [{'memory_type': 'WorkingMemory', 'count': 10}, ...] + """ + user_name = user_name or self.user_name + try: + # Build GROUP BY clause + group_by_clause = ", ".join([f"JSONExtractString(properties, '{field}') as {field}" for field in group_fields]) + + # Build WHERE clause + where_conditions = [f"user_name = '{user_name}'"] + if where_clause: + where_conditions.append(where_clause) + where_clause_full = " WHERE " + " AND ".join(where_conditions) if where_conditions else "" + + # Execute query + query = f""" + SELECT + {group_by_clause}, + COUNT(*) as count + FROM {self.db_name}.memories + {where_clause_full} + GROUP BY {group_by_clause} + """ + + result = self.client.query(query).result_set + + # Parse results + counts = [] + for row in result: + count_dict = {} + for i, field in enumerate(group_fields): + count_dict[field] = row[i] + count_dict['count'] = row[len(group_fields)] + counts.append(count_dict) + + return counts + except Exception as e: + logger.error(f"Failed to get grouped counts: {e}") + return [] diff --git a/src/memos/graph_dbs/factory.py b/src/memos/graph_dbs/factory.py index 93b5971ec..94592d8b9 100644 --- a/src/memos/graph_dbs/factory.py +++ b/src/memos/graph_dbs/factory.py @@ -6,6 +6,7 @@ from memos.graph_dbs.neo4j_community import Neo4jCommunityGraphDB from memos.graph_dbs.polardb import PolarDBGraphDB from memos.graph_dbs.postgres import PostgresGraphDB +from memos.graph_dbs.bytehouse import ByteHouseGraphDB class GraphStoreFactory(BaseGraphDB): @@ -16,6 +17,7 @@ class GraphStoreFactory(BaseGraphDB): "neo4j-community": Neo4jCommunityGraphDB, "polardb": PolarDBGraphDB, "postgres": PostgresGraphDB, + "bytehouse": ByteHouseGraphDB, } @classmethod diff --git a/tests/graph_dbs/test_bytehouse.py b/tests/graph_dbs/test_bytehouse.py new 
"""Integration tests for the ByteHouse graph-DB backend.

These tests require a reachable ByteHouse instance configured via the
BYTEHOUSE_* environment variables (see docker/.env.example). Each test
uses its own ``user_name`` namespace and clears it on exit so runs are
independent and repeatable.
"""

import pytest

import os
import uuid
import random

from datetime import datetime
from dotenv import load_dotenv

from memos.configs.graph_db import ByteHouseGraphDBConfig
from memos.graph_dbs.bytehouse import ByteHouseGraphDB, _prepare_node_metadata

# Fix: the original imported load_dotenv but never called it, so BYTEHOUSE_*
# values defined only in a local .env file were never loaded and the config
# fixture fell back to bare process environment / defaults.
load_dotenv()


# ──────────────────────────────────────────────────────────────────────────────
# Helper functions
# ──────────────────────────────────────────────────────────────────────────────


def generate_random_vector(dimension: int) -> list[float]:
    """Return a random vector of ``dimension`` components drawn from [0, 1]."""
    return [random.uniform(0, 1) for _ in range(dimension)]


# ──────────────────────────────────────────────────────────────────────────────
# Setup and configuration
# ──────────────────────────────────────────────────────────────────────────────


@pytest.fixture
def config():
    """Build a ByteHouse config from the environment with a unique test user."""
    return ByteHouseGraphDBConfig(
        host=os.getenv("BYTEHOUSE_HOST"),
        port=int(os.getenv("BYTEHOUSE_PORT", "8123")),
        user=os.getenv("BYTEHOUSE_USER"),
        password=os.getenv("BYTEHOUSE_PASSWORD"),
        db_name=os.getenv("BYTEHOUSE_DB_NAME", "shared_memos_db"),
        use_multi_db=os.getenv("BYTEHOUSE_USE_MULTI_DB", "false").lower() == "true",
        # Unique per-run namespace so parallel/repeated runs cannot collide.
        user_name=f"__test_{uuid.uuid4().hex[:8]}",
        embedding_dimension=2048,
    )


@pytest.fixture
def graph_db(config):
    """Connected ByteHouseGraphDB instance under test."""
    return ByteHouseGraphDB(config)


# -----------------------------
# Test Node
# -----------------------------


def test_add_and_get_node(graph_db):
    """A node written with add_node is retrievable with its metadata and embedding."""
    node_id = str(uuid.uuid4())
    memory = "test content add and get"
    metadata = {
        "memory_type": "WorkingMemory",
        "embedding": generate_random_vector(2048),
        "tags": ["test"],
    }

    graph_db.add_node(node_id, memory, metadata, user_name="test_add")

    result = graph_db.get_node(node_id, include_embedding=True, user_name="test_add")

    assert result is not None
    assert result["id"] == node_id
    assert result["memory"] == memory
    assert result["metadata"]["memory_type"] == "WorkingMemory"
    assert "embedding" in result["metadata"]
    assert len(result["metadata"]["embedding"]) == len(metadata["embedding"])

    graph_db.clear(user_name="test_add")


def test_update_node(graph_db):
    """update_node overwrites the targeted metadata field."""
    node_id = str(uuid.uuid4())
    memory = "test content update"
    metadata = {
        "memory_type": "WorkingMemory",
        "embedding": generate_random_vector(2048),
        "tags": ["test"],
    }

    graph_db.add_node(node_id, memory, metadata, user_name="test_update")
    graph_db.update_node(node_id, {"tags": ["updated"]}, user_name="test_update")

    result = graph_db.get_node(node_id, user_name="test_update")
    assert result["metadata"]["tags"] == ["updated"]

    graph_db.clear(user_name="test_update")


def test_delete_node(graph_db):
    """delete_node removes the node so a subsequent get returns None."""
    node_id = str(uuid.uuid4())
    memory = "test content delete"
    metadata = {
        "memory_type": "WorkingMemory",
        "embedding": generate_random_vector(2048),
        "tags": ["test"],
    }

    graph_db.add_node(node_id, memory, metadata, user_name="test_delete")

    result = graph_db.get_node(node_id, user_name="test_delete")
    assert result is not None

    graph_db.delete_node(node_id, user_name="test_delete")

    result = graph_db.get_node(node_id, user_name="test_delete")
    assert result is None

    # Fix: the original test left its namespace dirty; clear like every
    # other test does so repeated runs start from a clean state.
    graph_db.clear(user_name="test_delete")


# -----------------------------
# Test Edge
# -----------------------------


def test_add_and_get_edge(graph_db):
    """An added edge is visible via edge_exists, both typed and as 'ANY'."""
    source_id = str(uuid.uuid4())
    target_id = str(uuid.uuid4())
    edge_type = "RELATED"

    # Create source node
    graph_db.add_node(
        source_id,
        "Source node",
        {"memory_type": "WorkingMemory"},
        user_name="test_edge",
    )
    # Create target node
    graph_db.add_node(
        target_id,
        "Target node",
        {"memory_type": "WorkingMemory"},
        user_name="test_edge",
    )

    # Add edge between nodes
    graph_db.add_edge(source_id, target_id, edge_type, user_name="test_edge")

    # Verify edge exists
    assert graph_db.edge_exists(source_id, target_id, edge_type, user_name="test_edge")
    assert graph_db.edge_exists(source_id, target_id, "ANY", user_name="test_edge")

    # Clean up
    graph_db.clear(user_name="test_edge")


def test_delete_edge(graph_db):
    """delete_edge removes exactly the typed edge between two nodes."""
    source_id = str(uuid.uuid4())
    target_id = str(uuid.uuid4())
    edge_type = "RELATED"

    # Create source node
    graph_db.add_node(
        source_id,
        "Source node",
        {"memory_type": "WorkingMemory"},
        user_name="test_delete_edge",
    )
    # Create target node
    graph_db.add_node(
        target_id,
        "Target node",
        {"memory_type": "WorkingMemory"},
        user_name="test_delete_edge",
    )

    # Add edge
    graph_db.add_edge(source_id, target_id, edge_type, user_name="test_delete_edge")
    assert graph_db.edge_exists(
        source_id, target_id, edge_type, user_name="test_delete_edge"
    )

    # Delete edge
    graph_db.delete_edge(source_id, target_id, edge_type, user_name="test_delete_edge")

    # Verify edge doesn't exist
    assert not graph_db.edge_exists(
        source_id, target_id, edge_type, user_name="test_delete_edge"
    )

    # Clean up
    graph_db.clear(user_name="test_delete_edge")


def test_get_neighbors(graph_db):
    """get_neighbors honours the 'out', 'in', and 'both' directions."""
    node1 = str(uuid.uuid4())
    node2 = str(uuid.uuid4())
    node3 = str(uuid.uuid4())
    edge_type = "RELATED"

    # Create nodes
    graph_db.add_node(
        node1, "Node 1", {"memory_type": "WorkingMemory"}, user_name="test_neighbors"
    )
    graph_db.add_node(
        node2, "Node 2", {"memory_type": "WorkingMemory"}, user_name="test_neighbors"
    )
    graph_db.add_node(
        node3, "Node 3", {"memory_type": "WorkingMemory"}, user_name="test_neighbors"
    )

    # Add edges: node1 -> node2, node3 -> node1
    graph_db.add_edge(node1, node2, edge_type, user_name="test_neighbors")
    graph_db.add_edge(node3, node1, edge_type, user_name="test_neighbors")

    # Test direction 'out' - should return node2
    out_neighbors = graph_db.get_neighbors(
        node1, edge_type, direction="out", user_name="test_neighbors"
    )
    assert len(out_neighbors) == 1
    assert node2 in out_neighbors

    # Test direction 'in' - should return node3
    in_neighbors = graph_db.get_neighbors(
        node1, edge_type, direction="in", user_name="test_neighbors"
    )
    assert len(in_neighbors) == 1
    assert node3 in in_neighbors

    # Test direction 'both' - should return both node2 and node3
    both_neighbors = graph_db.get_neighbors(
        node1, edge_type, direction="both", user_name="test_neighbors"
    )
    assert len(both_neighbors) == 2
    assert node2 in both_neighbors
    assert node3 in both_neighbors

    # Clean up
    graph_db.clear(user_name="test_neighbors")


# -----------------------------
# Test Search
# -----------------------------


def test_search_by_embedding(graph_db):
    """Vector search returns the stored node with a positive similarity score."""
    node_id = str(uuid.uuid4())
    memory = "test content search by embedding"
    metadata = {
        "memory_type": "WorkingMemory",
        "embedding": generate_random_vector(2048),
        "tags": ["test"],
    }

    graph_db.add_node(node_id, memory, metadata, user_name="test_search")

    # NOTE: both vectors have all components in [0, 1], so their similarity
    # is strictly positive even though the query vector is random.
    result = graph_db.search_by_embedding(
        generate_random_vector(2048), top_k=1, user_name="test_search"
    )
    assert result is not None
    assert len(result) == 1
    assert result[0]["id"] == node_id
    assert result[0]["score"] > 0

    graph_db.clear(user_name="test_search")


# -----------------------------
# Test Memory Management
# -----------------------------


def test_remove_oldest_memory(graph_db):
    """remove_oldest_memory keeps only the newest ``keep_latest`` items of a type."""
    # Create multiple memories of the same type
    memory_type = "WorkingMemory"
    user_name = "test_remove_oldest"

    # Create 5 test memories
    for i in range(5):
        node_id = str(uuid.uuid4())
        memory = f"Test memory {i}"
        metadata = {
            "memory_type": memory_type,
            "embedding": generate_random_vector(2048),
            "tags": ["test"],
        }
        graph_db.add_node(node_id, memory, metadata, user_name=user_name)

    # Keep only the latest 2 memories
    graph_db.remove_oldest_memory(memory_type, keep_latest=2, user_name=user_name)

    # Get all remaining memories
    all_memories = graph_db.get_all_memory_items(memory_type, user_name=user_name)

    # Verify only 2 memories remain
    assert len(all_memories) == 2

    # Clean up
    graph_db.clear(user_name=user_name)


def test_get_grouped_counts(graph_db):
    """get_grouped_counts groups by one or several metadata fields."""
    user_name = "test_grouped_counts"

    # Create memories with different types and statuses:
    # 2 distinct memory_type values, and 4 distinct (memory_type, status) pairs.
    memory_types = [
        "WorkingMemory",
        "LongTermMemory",
        "WorkingMemory",
        "LongTermMemory",
    ]
    statuses = ["active", "inactive", "inactive", "active"]

    for i, (memory_type, status) in enumerate(zip(memory_types, statuses)):
        node_id = str(uuid.uuid4())
        memory = f"Test memory {i}"
        metadata = {
            "memory_type": memory_type,
            "status": status,
            "embedding": generate_random_vector(2048),
        }
        graph_db.add_node(node_id, memory, metadata, user_name=user_name)

    # Test grouping by memory_type
    counts_by_type = graph_db.get_grouped_counts(["memory_type"], user_name=user_name)
    assert len(counts_by_type) == 2

    # Test grouping by both memory_type and status
    counts_by_type_and_status = graph_db.get_grouped_counts(
        ["memory_type", "status"], user_name=user_name
    )
    assert len(counts_by_type_and_status) == 4

    # Clean up
    graph_db.clear(user_name=user_name)