Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions src/memos/configs/chunker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, ClassVar

from pydantic import Field, field_validator, model_validator
from pydantic import Field, SerializeAsAny, field_serializer, field_validator, model_validator

from memos.configs.base import BaseConfig

Expand Down Expand Up @@ -38,7 +38,9 @@ class ChunkerConfigFactory(BaseConfig):
"""Factory class for creating chunker configurations."""

backend: str = Field(..., description="Backend for chunker")
config: dict[str, Any] = Field(..., description="Configuration for the chunker backend")
config: SerializeAsAny[BaseConfig | dict[str, Any]] = Field(
..., description="Configuration for the chunker backend"
)

backend_to_class: ClassVar[dict[str, Any]] = {
"sentence": SentenceChunkerConfig,
Expand All @@ -56,5 +58,12 @@ def validate_backend(cls, backend: str) -> str:
@model_validator(mode="after")
def create_config(self) -> "ChunkerConfigFactory":
config_class = self.backend_to_class[self.backend]
self.config = config_class(**self.config)
if isinstance(self.config, dict):
self.config = config_class(**self.config)
return self

@field_serializer("config", mode="plain")
def serialize_config(self, value):
if isinstance(value, BaseConfig):
return value.model_dump(mode="python")
return value
15 changes: 12 additions & 3 deletions src/memos/configs/embedder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, ClassVar

from pydantic import Field, field_validator, model_validator
from pydantic import Field, SerializeAsAny, field_serializer, field_validator, model_validator

from memos.configs.base import BaseConfig

Expand Down Expand Up @@ -81,7 +81,9 @@ class EmbedderConfigFactory(BaseConfig):
"""Factory class for creating embedder configurations."""

backend: str = Field(..., description="Backend for embedding model")
config: dict[str, Any] = Field(..., description="Configuration for the embedding model backend")
config: SerializeAsAny[BaseConfig | dict[str, Any]] = Field(
..., description="Configuration for the embedding model backend"
)

backend_to_class: ClassVar[dict[str, Any]] = {
"ollama": OllamaEmbedderConfig,
Expand All @@ -101,5 +103,12 @@ def validate_backend(cls, backend: str) -> str:
@model_validator(mode="after")
def create_config(self) -> "EmbedderConfigFactory":
config_class = self.backend_to_class[self.backend]
self.config = config_class(**self.config)
if isinstance(self.config, dict):
self.config = config_class(**self.config)
return self

@field_serializer("config", mode="plain")
def serialize_config(self, value):
if isinstance(value, BaseConfig):
return value.model_dump(mode="python")
return value
15 changes: 12 additions & 3 deletions src/memos/configs/graph_db.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, ClassVar

from pydantic import BaseModel, Field, field_validator, model_validator
from pydantic import BaseModel, Field, SerializeAsAny, field_serializer, field_validator, model_validator

from memos.configs.base import BaseConfig
from memos.configs.vec_db import VectorDBConfigFactory
Expand Down Expand Up @@ -243,7 +243,9 @@ def validate_config(self):

class GraphDBConfigFactory(BaseModel):
backend: str = Field(..., description="Backend for graph database")
config: dict[str, Any] = Field(..., description="Configuration for the graph database backend")
config: SerializeAsAny[BaseConfig | dict[str, Any]] = Field(
..., description="Configuration for the graph database backend"
)

backend_to_class: ClassVar[dict[str, Any]] = {
"neo4j": Neo4jGraphDBConfig,
Expand All @@ -262,5 +264,12 @@ def validate_backend(cls, backend: str) -> str:
@model_validator(mode="after")
def instantiate_config(self):
config_class = self.backend_to_class[self.backend]
self.config = config_class(**self.config)
if isinstance(self.config, dict):
self.config = config_class(**self.config)
return self

@field_serializer("config", mode="plain")
def serialize_config(self, value):
if isinstance(value, BaseConfig):
return value.model_dump(mode="python")
return value
12 changes: 9 additions & 3 deletions src/memos/configs/internet_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Any, ClassVar

from pydantic import Field, field_validator, model_validator
from pydantic import Field, SerializeAsAny, field_serializer, field_validator, model_validator

from memos.configs.base import BaseConfig
from memos.exceptions import ConfigurationError
Expand Down Expand Up @@ -86,7 +86,7 @@ class InternetRetrieverConfigFactory(BaseConfig):
backend: str | None = Field(
None, description="Backend for internet retriever (google, bing, etc.)"
)
config: dict[str, Any] | None = Field(
config: SerializeAsAny[BaseConfig | dict[str, Any] | None] = Field(
None, description="Configuration for the internet retriever backend"
)

Expand All @@ -108,7 +108,13 @@ def validate_backend(cls, backend: str | None) -> str | None:

@model_validator(mode="after")
def create_config(self) -> "InternetRetrieverConfigFactory":
if self.backend is not None:
if self.backend is not None and isinstance(self.config, dict):
config_class = self.backend_to_class[self.backend]
self.config = config_class(**self.config)
return self

@field_serializer("config", mode="plain")
def serialize_config(self, value):
if isinstance(value, BaseConfig):
return value.model_dump(mode="python")
return value
15 changes: 12 additions & 3 deletions src/memos/configs/llm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, ClassVar

from pydantic import Field, field_validator, model_validator
from pydantic import Field, SerializeAsAny, field_serializer, field_validator, model_validator

from memos.configs.base import BaseConfig

Expand Down Expand Up @@ -144,7 +144,9 @@ class LLMConfigFactory(BaseConfig):
"""Factory class for creating LLM configurations."""

backend: str = Field(..., description="Backend for LLM")
config: dict[str, Any] = Field(..., description="Configuration for the LLM backend")
config: SerializeAsAny[BaseConfig | dict[str, Any]] = Field(
..., description="Configuration for the LLM backend"
)

backend_to_class: ClassVar[dict[str, Any]] = {
"openai": OpenAILLMConfig,
Expand All @@ -170,5 +172,12 @@ def validate_backend(cls, backend: str) -> str:
@model_validator(mode="after")
def create_config(self) -> "LLMConfigFactory":
config_class = self.backend_to_class[self.backend]
self.config = config_class(**self.config)
if isinstance(self.config, dict):
self.config = config_class(**self.config)
return self

@field_serializer("config", mode="plain")
def serialize_config(self, value):
if isinstance(value, BaseConfig):
return value.model_dump(mode="python")
return value
15 changes: 12 additions & 3 deletions src/memos/configs/mem_reader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime
from typing import Any, ClassVar

from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic import ConfigDict, Field, SerializeAsAny, field_serializer, field_validator, model_validator

from memos.configs.base import BaseConfig
from memos.configs.chunker import ChunkerConfigFactory
Expand Down Expand Up @@ -88,7 +88,9 @@ class MemReaderConfigFactory(BaseConfig):
"""Factory class for creating MemReader configurations."""

backend: str = Field(..., description="Backend for MemReader")
config: dict[str, Any] = Field(..., description="Configuration for the MemReader backend")
config: SerializeAsAny[BaseConfig | dict[str, Any]] = Field(
..., description="Configuration for the MemReader backend"
)

backend_to_class: ClassVar[dict[str, Any]] = {
"simple_struct": SimpleStructMemReaderConfig,
Expand All @@ -107,5 +109,12 @@ def validate_backend(cls, backend: str) -> str:
@model_validator(mode="after")
def create_config(self) -> "MemReaderConfigFactory":
config_class = self.backend_to_class[self.backend]
self.config = config_class(**self.config)
if isinstance(self.config, dict):
self.config = config_class(**self.config)
return self

@field_serializer("config", mode="plain")
def serialize_config(self, value):
if isinstance(value, BaseConfig):
return value.model_dump(mode="python")
return value
15 changes: 12 additions & 3 deletions src/memos/configs/parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, ClassVar

from pydantic import Field, field_validator, model_validator
from pydantic import Field, SerializeAsAny, field_serializer, field_validator, model_validator

from memos.configs.base import BaseConfig

Expand All @@ -17,7 +17,9 @@ class ParserConfigFactory(BaseConfig):
"""Factory class for creating Parser configurations."""

backend: str = Field(..., description="Backend for parser")
config: dict[str, Any] = Field(..., description="Configuration for the parser backend")
config: SerializeAsAny[BaseConfig | dict[str, Any]] = Field(
..., description="Configuration for the parser backend"
)

backend_to_class: ClassVar[dict[str, Any]] = {
"markitdown": MarkItDownParserConfig,
Expand All @@ -34,5 +36,12 @@ def validate_backend(cls, backend: str) -> str:
@model_validator(mode="after")
def create_config(self) -> "ParserConfigFactory":
config_class = self.backend_to_class[self.backend]
self.config = config_class(**self.config)
if isinstance(self.config, dict):
self.config = config_class(**self.config)
return self

@field_serializer("config", mode="plain")
def serialize_config(self, value):
if isinstance(value, BaseConfig):
return value.model_dump(mode="python")
return value
15 changes: 12 additions & 3 deletions src/memos/configs/vec_db.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, ClassVar, Literal

from pydantic import Field, field_validator, model_validator
from pydantic import Field, SerializeAsAny, field_serializer, field_validator, model_validator

from memos import settings
from memos.configs.base import BaseConfig
Expand Down Expand Up @@ -58,7 +58,9 @@ class VectorDBConfigFactory(BaseConfig):
"""Factory class for creating vector database configurations."""

backend: str = Field(..., description="Backend for vector database")
config: dict[str, Any] = Field(..., description="Configuration for the vector database backend")
config: SerializeAsAny[BaseConfig | dict[str, Any]] = Field(
..., description="Configuration for the vector database backend"
)

backend_to_class: ClassVar[dict[str, Any]] = {
"qdrant": QdrantVecDBConfig,
Expand All @@ -76,5 +78,12 @@ def validate_backend(cls, backend: str) -> str:
@model_validator(mode="after")
def create_config(self) -> "VectorDBConfigFactory":
config_class = self.backend_to_class[self.backend]
self.config = config_class(**self.config)
if isinstance(self.config, dict):
self.config = config_class(**self.config)
return self

@field_serializer("config", mode="plain")
def serialize_config(self, value):
if isinstance(value, BaseConfig):
return value.model_dump(mode="python")
return value