diff --git a/src/memos/configs/chunker.py b/src/memos/configs/chunker.py index f9a738415..0c896f65d 100644 --- a/src/memos/configs/chunker.py +++ b/src/memos/configs/chunker.py @@ -1,6 +1,6 @@ from typing import Any, ClassVar -from pydantic import Field, field_validator, model_validator +from pydantic import Field, SerializeAsAny, field_serializer, field_validator, model_validator from memos.configs.base import BaseConfig @@ -38,7 +38,9 @@ class ChunkerConfigFactory(BaseConfig): """Factory class for creating chunker configurations.""" backend: str = Field(..., description="Backend for chunker") - config: dict[str, Any] = Field(..., description="Configuration for the chunker backend") + config: SerializeAsAny[BaseConfig | dict[str, Any]] = Field( + ..., description="Configuration for the chunker backend" + ) backend_to_class: ClassVar[dict[str, Any]] = { "sentence": SentenceChunkerConfig, @@ -56,5 +58,12 @@ def validate_backend(cls, backend: str) -> str: @model_validator(mode="after") def create_config(self) -> "ChunkerConfigFactory": config_class = self.backend_to_class[self.backend] - self.config = config_class(**self.config) + if isinstance(self.config, dict): + self.config = config_class(**self.config) return self + + @field_serializer("config", mode="plain") + def serialize_config(self, value): + if isinstance(value, BaseConfig): + return value.model_dump(mode="python") + return value diff --git a/src/memos/configs/embedder.py b/src/memos/configs/embedder.py index 050043ab0..070776431 100644 --- a/src/memos/configs/embedder.py +++ b/src/memos/configs/embedder.py @@ -1,6 +1,6 @@ from typing import Any, ClassVar -from pydantic import Field, field_validator, model_validator +from pydantic import Field, SerializeAsAny, field_serializer, field_validator, model_validator from memos.configs.base import BaseConfig @@ -81,7 +81,9 @@ class EmbedderConfigFactory(BaseConfig): """Factory class for creating embedder configurations.""" backend: str = Field(..., description="Backend for embedding model") - config: dict[str, Any] = Field(..., description="Configuration for the embedding model backend") + config: SerializeAsAny[BaseConfig | dict[str, Any]] = Field( + ..., description="Configuration for the embedding model backend" + ) backend_to_class: ClassVar[dict[str, Any]] = { "ollama": OllamaEmbedderConfig, @@ -101,5 +103,12 @@ def validate_backend(cls, backend: str) -> str: @model_validator(mode="after") def create_config(self) -> "EmbedderConfigFactory": config_class = self.backend_to_class[self.backend] - self.config = config_class(**self.config) + if isinstance(self.config, dict): + self.config = config_class(**self.config) return self + + @field_serializer("config", mode="plain") + def serialize_config(self, value): + if isinstance(value, BaseConfig): + return value.model_dump(mode="python") + return value diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py index 98de09812..0791571c4 100644 --- a/src/memos/configs/graph_db.py +++ b/src/memos/configs/graph_db.py @@ -1,6 +1,6 @@ from typing import Any, ClassVar -from pydantic import BaseModel, Field, field_validator, model_validator +from pydantic import BaseModel, Field, SerializeAsAny, field_serializer, field_validator, model_validator from memos.configs.base import BaseConfig from memos.configs.vec_db import VectorDBConfigFactory @@ -243,7 +243,9 @@ def validate_config(self): class GraphDBConfigFactory(BaseModel): backend: str = Field(..., description="Backend for graph database") - config: dict[str, Any] = Field(..., description="Configuration for the graph database backend") + config: SerializeAsAny[BaseConfig | dict[str, Any]] = Field( + ..., description="Configuration for the graph database backend" + ) backend_to_class: ClassVar[dict[str, Any]] = { "neo4j": Neo4jGraphDBConfig, @@ -262,5 +264,12 @@ def validate_backend(cls, backend: str) -> str: @model_validator(mode="after") def instantiate_config(self): config_class = self.backend_to_class[self.backend] - self.config = config_class(**self.config) + if isinstance(self.config, dict): + self.config = config_class(**self.config) return self + + @field_serializer("config", mode="plain") + def serialize_config(self, value): + if isinstance(value, BaseConfig): + return value.model_dump(mode="python") + return value diff --git a/src/memos/configs/internet_retriever.py b/src/memos/configs/internet_retriever.py index 562cfdd1f..bdc82d3af 100644 --- a/src/memos/configs/internet_retriever.py +++ b/src/memos/configs/internet_retriever.py @@ -2,7 +2,7 @@ from typing import Any, ClassVar -from pydantic import Field, field_validator, model_validator +from pydantic import Field, SerializeAsAny, field_serializer, field_validator, model_validator from memos.configs.base import BaseConfig from memos.exceptions import ConfigurationError @@ -86,7 +86,7 @@ class InternetRetrieverConfigFactory(BaseConfig): backend: str | None = Field( None, description="Backend for internet retriever (google, bing, etc.)" ) - config: dict[str, Any] | None = Field( + config: SerializeAsAny[BaseConfig | dict[str, Any] | None] = Field( None, description="Configuration for the internet retriever backend" ) @@ -108,7 +108,13 @@ def validate_backend(cls, backend: str | None) -> str | None: @model_validator(mode="after") def create_config(self) -> "InternetRetrieverConfigFactory": - if self.backend is not None: + if self.backend is not None and isinstance(self.config, dict): config_class = self.backend_to_class[self.backend] self.config = config_class(**self.config) return self + + @field_serializer("config", mode="plain") + def serialize_config(self, value): + if isinstance(value, BaseConfig): + return value.model_dump(mode="python") + return value diff --git a/src/memos/configs/llm.py b/src/memos/configs/llm.py index 81f7038fa..f8b418734 100644 --- a/src/memos/configs/llm.py +++ b/src/memos/configs/llm.py @@ -1,6 +1,6 @@ from typing import Any, ClassVar -from pydantic import Field, field_validator, model_validator +from pydantic import Field, SerializeAsAny, field_serializer, field_validator, model_validator from memos.configs.base import BaseConfig @@ -144,7 +144,9 @@ class LLMConfigFactory(BaseConfig): """Factory class for creating LLM configurations.""" backend: str = Field(..., description="Backend for LLM") - config: dict[str, Any] = Field(..., description="Configuration for the LLM backend") + config: SerializeAsAny[BaseConfig | dict[str, Any]] = Field( + ..., description="Configuration for the LLM backend" + ) backend_to_class: ClassVar[dict[str, Any]] = { "openai": OpenAILLMConfig, @@ -170,5 +172,12 @@ def validate_backend(cls, backend: str) -> str: @model_validator(mode="after") def create_config(self) -> "LLMConfigFactory": config_class = self.backend_to_class[self.backend] - self.config = config_class(**self.config) + if isinstance(self.config, dict): + self.config = config_class(**self.config) return self + + @field_serializer("config", mode="plain") + def serialize_config(self, value): + if isinstance(value, BaseConfig): + return value.model_dump(mode="python") + return value diff --git a/src/memos/configs/mem_reader.py b/src/memos/configs/mem_reader.py index d4844d73f..045b2f97f 100644 --- a/src/memos/configs/mem_reader.py +++ b/src/memos/configs/mem_reader.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import Any, ClassVar -from pydantic import ConfigDict, Field, field_validator, model_validator +from pydantic import ConfigDict, Field, SerializeAsAny, field_serializer, field_validator, model_validator from memos.configs.base import BaseConfig from memos.configs.chunker import ChunkerConfigFactory @@ -88,7 +88,9 @@ class MemReaderConfigFactory(BaseConfig): """Factory class for creating MemReader configurations.""" backend: str = Field(..., description="Backend for MemReader") - config: dict[str, Any] = Field(..., description="Configuration for the MemReader backend") + config: SerializeAsAny[BaseConfig | dict[str, Any]] = Field( + ..., description="Configuration for the MemReader backend" + ) backend_to_class: ClassVar[dict[str, Any]] = { "simple_struct": SimpleStructMemReaderConfig, @@ -107,5 +109,12 @@ def validate_backend(cls, backend: str) -> str: @model_validator(mode="after") def create_config(self) -> "MemReaderConfigFactory": config_class = self.backend_to_class[self.backend] - self.config = config_class(**self.config) + if isinstance(self.config, dict): + self.config = config_class(**self.config) return self + + @field_serializer("config", mode="plain") + def serialize_config(self, value): + if isinstance(value, BaseConfig): + return value.model_dump(mode="python") + return value diff --git a/src/memos/configs/parser.py b/src/memos/configs/parser.py index 22c1fe1b5..a52434a13 100644 --- a/src/memos/configs/parser.py +++ b/src/memos/configs/parser.py @@ -1,6 +1,6 @@ from typing import Any, ClassVar -from pydantic import Field, field_validator, model_validator +from pydantic import Field, SerializeAsAny, field_serializer, field_validator, model_validator from memos.configs.base import BaseConfig @@ -17,7 +17,9 @@ class ParserConfigFactory(BaseConfig): """Factory class for creating Parser configurations.""" backend: str = Field(..., description="Backend for parser") - config: dict[str, Any] = Field(..., description="Configuration for the parser backend") + config: SerializeAsAny[BaseConfig | dict[str, Any]] = Field( + ..., description="Configuration for the parser backend" + ) backend_to_class: ClassVar[dict[str, Any]] = { "markitdown": MarkItDownParserConfig, @@ -34,5 +36,12 @@ def validate_backend(cls, backend: str) -> str: @model_validator(mode="after") def create_config(self) -> "ParserConfigFactory": config_class = self.backend_to_class[self.backend] - self.config = config_class(**self.config) + if isinstance(self.config, dict): + self.config = config_class(**self.config) return self + + @field_serializer("config", mode="plain") + def serialize_config(self, value): + if isinstance(value, BaseConfig): + return value.model_dump(mode="python") + return value diff --git a/src/memos/configs/vec_db.py b/src/memos/configs/vec_db.py index 9fdb83a35..433279117 100644 --- a/src/memos/configs/vec_db.py +++ b/src/memos/configs/vec_db.py @@ -1,6 +1,6 @@ from typing import Any, ClassVar, Literal -from pydantic import Field, field_validator, model_validator +from pydantic import Field, SerializeAsAny, field_serializer, field_validator, model_validator from memos import settings from memos.configs.base import BaseConfig @@ -58,7 +58,9 @@ class VectorDBConfigFactory(BaseConfig): """Factory class for creating vector database configurations.""" backend: str = Field(..., description="Backend for vector database") - config: dict[str, Any] = Field(..., description="Configuration for the vector database backend") + config: SerializeAsAny[BaseConfig | dict[str, Any]] = Field( + ..., description="Configuration for the vector database backend" + ) backend_to_class: ClassVar[dict[str, Any]] = { "qdrant": QdrantVecDBConfig, @@ -76,5 +78,12 @@ def validate_backend(cls, backend: str) -> str: @model_validator(mode="after") def create_config(self) -> "VectorDBConfigFactory": config_class = self.backend_to_class[self.backend] - self.config = config_class(**self.config) + if isinstance(self.config, dict): + self.config = config_class(**self.config) return self + + @field_serializer("config", mode="plain") + def serialize_config(self, value): + if isinstance(value, BaseConfig): + return value.model_dump(mode="python") + return value