From 788e51f53a1812e42ec5cc4c845894bce24ba2ab Mon Sep 17 00:00:00 2001 From: root Date: Sat, 18 Apr 2026 16:49:30 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E5=91=8A=E8=AD=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/memos/configs/chunker.py | 15 ++++++++++++--- src/memos/configs/embedder.py | 15 ++++++++++++--- src/memos/configs/graph_db.py | 15 ++++++++++++--- src/memos/configs/internet_retriever.py | 12 +++++++++--- src/memos/configs/llm.py | 15 ++++++++++++--- src/memos/configs/mem_reader.py | 15 ++++++++++++--- src/memos/configs/parser.py | 15 ++++++++++++--- src/memos/configs/vec_db.py | 15 ++++++++++++--- 8 files changed, 93 insertions(+), 24 deletions(-) diff --git a/src/memos/configs/chunker.py b/src/memos/configs/chunker.py index f9a738415..0c896f65d 100644 --- a/src/memos/configs/chunker.py +++ b/src/memos/configs/chunker.py @@ -1,6 +1,6 @@ from typing import Any, ClassVar -from pydantic import Field, field_validator, model_validator +from pydantic import Field, SerializeAsAny, field_serializer, field_validator, model_validator from memos.configs.base import BaseConfig @@ -38,7 +38,9 @@ class ChunkerConfigFactory(BaseConfig): """Factory class for creating chunker configurations.""" backend: str = Field(..., description="Backend for chunker") - config: dict[str, Any] = Field(..., description="Configuration for the chunker backend") + config: SerializeAsAny[BaseConfig | dict[str, Any]] = Field( + ..., description="Configuration for the chunker backend" + ) backend_to_class: ClassVar[dict[str, Any]] = { "sentence": SentenceChunkerConfig, @@ -56,5 +58,12 @@ def validate_backend(cls, backend: str) -> str: @model_validator(mode="after") def create_config(self) -> "ChunkerConfigFactory": config_class = self.backend_to_class[self.backend] - self.config = config_class(**self.config) + if isinstance(self.config, dict): + self.config = config_class(**self.config) return self + + @field_serializer("config", mode="plain") + def serialize_config(self, value): + if isinstance(value, BaseConfig): + return value.model_dump(mode="python") + return value diff --git a/src/memos/configs/embedder.py b/src/memos/configs/embedder.py index 050043ab0..070776431 100644 --- a/src/memos/configs/embedder.py +++ b/src/memos/configs/embedder.py @@ -1,6 +1,6 @@ from typing import Any, ClassVar -from pydantic import Field, field_validator, model_validator +from pydantic import Field, SerializeAsAny, field_serializer, field_validator, model_validator from memos.configs.base import BaseConfig @@ -81,7 +81,9 @@ class EmbedderConfigFactory(BaseConfig): """Factory class for creating embedder configurations.""" backend: str = Field(..., description="Backend for embedding model") - config: dict[str, Any] = Field(..., description="Configuration for the embedding model backend") + config: SerializeAsAny[BaseConfig | dict[str, Any]] = Field( + ..., description="Configuration for the embedding model backend" + ) backend_to_class: ClassVar[dict[str, Any]] = { "ollama": OllamaEmbedderConfig, @@ -101,5 +103,12 @@ def validate_backend(cls, backend: str) -> str: @model_validator(mode="after") def create_config(self) -> "EmbedderConfigFactory": config_class = self.backend_to_class[self.backend] - self.config = config_class(**self.config) + if isinstance(self.config, dict): + self.config = config_class(**self.config) return self + + @field_serializer("config", mode="plain") + def serialize_config(self, value): + if isinstance(value, BaseConfig): + return value.model_dump(mode="python") + return value diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py index 9b1ce7f9d..6722c9f20 100644 --- a/src/memos/configs/graph_db.py +++ b/src/memos/configs/graph_db.py @@ -1,6 +1,6 @@ from typing import Any, ClassVar -from pydantic import BaseModel, Field, field_validator, model_validator +from pydantic import BaseModel, Field, SerializeAsAny, field_serializer, field_validator, model_validator from memos.configs.base import BaseConfig from memos.configs.vec_db import VectorDBConfigFactory @@ -267,7 +267,9 @@ def validate_config(self): class GraphDBConfigFactory(BaseModel): backend: str = Field(..., description="Backend for graph database") - config: dict[str, Any] = Field(..., description="Configuration for the graph database backend") + config: SerializeAsAny[BaseConfig | dict[str, Any]] = Field( + ..., description="Configuration for the graph database backend" + ) backend_to_class: ClassVar[dict[str, Any]] = { "neo4j": Neo4jGraphDBConfig, @@ -287,5 +289,12 @@ def validate_backend(cls, backend: str) -> str: @model_validator(mode="after") def instantiate_config(self): config_class = self.backend_to_class[self.backend] - self.config = config_class(**self.config) + if isinstance(self.config, dict): + self.config = config_class(**self.config) return self + + @field_serializer("config", mode="plain") + def serialize_config(self, value): + if isinstance(value, BaseConfig): + return value.model_dump(mode="python") + return value diff --git a/src/memos/configs/internet_retriever.py b/src/memos/configs/internet_retriever.py index 1c5e2b8ad..717a7ebaa 100644 --- a/src/memos/configs/internet_retriever.py +++ b/src/memos/configs/internet_retriever.py @@ -2,7 +2,7 @@ from typing import Any, ClassVar -from pydantic import Field, field_validator, model_validator +from pydantic import Field, SerializeAsAny, field_serializer, field_validator, model_validator from memos.configs.base import BaseConfig from memos.exceptions import ConfigurationError @@ -73,7 +73,7 @@ class InternetRetrieverConfigFactory(BaseConfig): backend: str | None = Field( None, description="Backend for internet retriever (google, bing, etc.)" ) - config: dict[str, Any] | None = Field( + config: SerializeAsAny[BaseConfig | dict[str, Any] | None] = Field( None, description="Configuration for the internet retriever backend" ) @@ -94,7 +94,13 @@ def validate_backend(cls, backend: str | None) -> str | None: @model_validator(mode="after") def create_config(self) -> "InternetRetrieverConfigFactory": - if self.backend is not None: + if self.backend is not None and isinstance(self.config, dict): config_class = self.backend_to_class[self.backend] self.config = config_class(**self.config) return self + + @field_serializer("config", mode="plain") + def serialize_config(self, value): + if isinstance(value, BaseConfig): + return value.model_dump(mode="python") + return value diff --git a/src/memos/configs/llm.py b/src/memos/configs/llm.py index 5487d117c..fc9c4ee2b 100644 --- a/src/memos/configs/llm.py +++ b/src/memos/configs/llm.py @@ -1,6 +1,6 @@ from typing import Any, ClassVar -from pydantic import Field, field_validator, model_validator +from pydantic import Field, SerializeAsAny, field_serializer, field_validator, model_validator from memos.configs.base import BaseConfig @@ -123,7 +123,9 @@ class LLMConfigFactory(BaseConfig): """Factory class for creating LLM configurations.""" backend: str = Field(..., description="Backend for LLM") - config: dict[str, Any] = Field(..., description="Configuration for the LLM backend") + config: SerializeAsAny[BaseConfig | dict[str, Any]] = Field( + ..., description="Configuration for the LLM backend" + ) backend_to_class: ClassVar[dict[str, Any]] = { "openai": OpenAILLMConfig, @@ -148,5 +150,12 @@ def validate_backend(cls, backend: str) -> str: @model_validator(mode="after") def create_config(self) -> "LLMConfigFactory": config_class = self.backend_to_class[self.backend] - self.config = config_class(**self.config) + if isinstance(self.config, dict): + self.config = config_class(**self.config) return self + + @field_serializer("config", mode="plain") + def serialize_config(self, value): + if isinstance(value, BaseConfig): + return value.model_dump(mode="python") + return value diff --git a/src/memos/configs/mem_reader.py b/src/memos/configs/mem_reader.py index 4bd7953c0..ab53700ff 100644 --- a/src/memos/configs/mem_reader.py +++ b/src/memos/configs/mem_reader.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import Any, ClassVar -from pydantic import ConfigDict, Field, field_validator, model_validator +from pydantic import ConfigDict, Field, SerializeAsAny, field_serializer, field_validator, model_validator from memos.configs.base import BaseConfig from memos.configs.chunker import ChunkerConfigFactory @@ -77,7 +77,9 @@ class MemReaderConfigFactory(BaseConfig): """Factory class for creating MemReader configurations.""" backend: str = Field(..., description="Backend for MemReader") - config: dict[str, Any] = Field(..., description="Configuration for the MemReader backend") + config: SerializeAsAny[BaseConfig | dict[str, Any]] = Field( + ..., description="Configuration for the MemReader backend" + ) backend_to_class: ClassVar[dict[str, Any]] = { "simple_struct": SimpleStructMemReaderConfig, @@ -96,5 +98,12 @@ def validate_backend(cls, backend: str) -> str: @model_validator(mode="after") def create_config(self) -> "MemReaderConfigFactory": config_class = self.backend_to_class[self.backend] - self.config = config_class(**self.config) + if isinstance(self.config, dict): + self.config = config_class(**self.config) return self + + @field_serializer("config", mode="plain") + def serialize_config(self, value): + if isinstance(value, BaseConfig): + return value.model_dump(mode="python") + return value diff --git a/src/memos/configs/parser.py b/src/memos/configs/parser.py index 22c1fe1b5..a52434a13 100644 --- a/src/memos/configs/parser.py +++ b/src/memos/configs/parser.py @@ -1,6 +1,6 @@ from typing import Any, ClassVar -from pydantic import Field, field_validator, model_validator +from pydantic import Field, SerializeAsAny, field_serializer, field_validator, model_validator from memos.configs.base import BaseConfig @@ -17,7 +17,9 @@ class ParserConfigFactory(BaseConfig): """Factory class for creating Parser configurations.""" backend: str = Field(..., description="Backend for parser") - config: dict[str, Any] = Field(..., description="Configuration for the parser backend") + config: SerializeAsAny[BaseConfig | dict[str, Any]] = Field( + ..., description="Configuration for the parser backend" + ) backend_to_class: ClassVar[dict[str, Any]] = { "markitdown": MarkItDownParserConfig, @@ -34,5 +36,12 @@ def validate_backend(cls, backend: str) -> str: @model_validator(mode="after") def create_config(self) -> "ParserConfigFactory": config_class = self.backend_to_class[self.backend] - self.config = config_class(**self.config) + if isinstance(self.config, dict): + self.config = config_class(**self.config) return self + + @field_serializer("config", mode="plain") + def serialize_config(self, value): + if isinstance(value, BaseConfig): + return value.model_dump(mode="python") + return value diff --git a/src/memos/configs/vec_db.py b/src/memos/configs/vec_db.py index 9fdb83a35..433279117 100644 --- a/src/memos/configs/vec_db.py +++ b/src/memos/configs/vec_db.py @@ -1,6 +1,6 @@ from typing import Any, ClassVar, Literal -from pydantic import Field, field_validator, model_validator +from pydantic import Field, SerializeAsAny, field_serializer, field_validator, model_validator from memos import settings from memos.configs.base import BaseConfig @@ -58,7 +58,9 @@ class VectorDBConfigFactory(BaseConfig): """Factory class for creating vector database configurations.""" backend: str = Field(..., description="Backend for vector database") - config: dict[str, Any] = Field(..., description="Configuration for the vector database backend") + config: SerializeAsAny[BaseConfig | dict[str, Any]] = Field( + ..., description="Configuration for the vector database backend" + ) backend_to_class: ClassVar[dict[str, Any]] = { "qdrant": QdrantVecDBConfig, @@ -76,5 +78,12 @@ def validate_backend(cls, backend: str) -> str: @model_validator(mode="after") def create_config(self) -> "VectorDBConfigFactory": config_class = self.backend_to_class[self.backend] - self.config = config_class(**self.config) + if isinstance(self.config, dict): + self.config = config_class(**self.config) return self + + @field_serializer("config", mode="plain") + def serialize_config(self, value): + if isinstance(value, BaseConfig): + return value.model_dump(mode="python") + return value