Merged
1 change: 1 addition & 0 deletions src/memos/chunkers/sentence_chunker.py
@@ -53,4 +53,5 @@ def chunk(self, text: str) -> list[str] | list[Chunk]:
chunks.append(chunk)

logger.debug(f"Generated {len(chunks)} chunks from input text")

return chunks
177 changes: 88 additions & 89 deletions src/memos/mem_reader/multi_modal_struct.py
@@ -18,7 +18,7 @@
from memos.templates.mem_reader_prompts import MEMORY_MERGE_PROMPT_EN, MEMORY_MERGE_PROMPT_ZH
from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN, TOOL_TRAJECTORY_PROMPT_ZH
from memos.types import MessagesType
from memos.utils import timed
from memos.utils import timed, timed_stage


if TYPE_CHECKING:
@@ -75,6 +75,30 @@ def __init__(self, config: MultiModalStructMemReaderConfig):
direct_markdown_hostnames=direct_markdown_hostnames,
)

def _embed_memory_items(self, items: list[TextualMemoryItem]) -> None:
"""Compute embeddings for a list of memory items in-place.

Attempts a single batch call first; falls back to per-item calls if the
batch fails. Errors are logged but never raised so callers always
continue normally.
"""
valid = [w for w in items if w and w.memory]
if not valid:
return
texts = [w.memory for w in valid]
try:
embeddings = self.embedder.embed(texts)
for w, emb in zip(valid, embeddings, strict=True):
w.metadata.embedding = emb
except Exception as e:
logger.error(f"[MultiModalStruct] Error batch computing embeddings: {e}")
logger.warning("[EMBED_FALLBACK] batch_size=%d", len(texts))
for w in valid:
try:
w.metadata.embedding = self.embedder.embed([w.memory])[0]
except Exception as e2:
logger.error(f"[MultiModalStruct] Error computing embedding for item: {e2}")

def _split_large_memory_item(
self, item: TextualMemoryItem, max_tokens: int
) -> list[TextualMemoryItem]:
@@ -203,13 +227,8 @@ def _concat_multi_modal_memories(
# If only one item after processing, compute embedding and return
if len(processed_items) == 1:
single_item = processed_items[0]
if single_item and single_item.memory:
try:
single_item.metadata.embedding = self.embedder.embed([single_item.memory])[0]
except Exception as e:
logger.error(
f"[MultiModalStruct] Error computing embedding for single item: {e}"
)
with timed_stage("add", "embedding", window_count=1):
self._embed_memory_items([single_item])
return processed_items

windows = []
@@ -260,31 +279,8 @@ def _concat_multi_modal_memories(
windows.append(window)

# Batch compute embeddings for all windows
if windows:
# Collect all valid windows that need embedding
valid_windows = [w for w in windows if w and w.memory]

if valid_windows:
# Collect all texts that need embedding
texts_to_embed = [w.memory for w in valid_windows]

# Batch compute all embeddings at once
try:
embeddings = self.embedder.embed(texts_to_embed)
# Fill embeddings back into memory items
for window, embedding in zip(valid_windows, embeddings, strict=True):
window.metadata.embedding = embedding
except Exception as e:
logger.error(f"[MultiModalStruct] Error batch computing embeddings: {e}")
# Fallback: compute embeddings individually
for window in valid_windows:
if window.memory:
try:
window.metadata.embedding = self.embedder.embed([window.memory])[0]
except Exception as e2:
logger.error(
f"[MultiModalStruct] Error computing embedding for item: {e2}"
)
with timed_stage("add", "embedding", window_count=len(windows)):
self._embed_memory_items(windows)

return windows

@@ -984,49 +980,49 @@ def _process_multi_modal_data(
# Must pop here to avoid adding it to info; only used in sync fine mode
custom_tags = info.pop("custom_tags", None) if isinstance(info, dict) else None

# Use MultiModalParser to parse the scene data
# If it's a list, parse each item; otherwise parse as single message
if isinstance(scene_data_info, list):
# Pre-expand multimodal messages
expanded_messages = self._expand_multimodal_messages(scene_data_info)

# Parse each message in the list
all_memory_items = []
# Use thread pool to parse each message in parallel, but keep the original order
with ContextThreadPoolExecutor(max_workers=30) as executor:
# submit tasks and keep the original order
futures = [
executor.submit(
self.multi_modal_parser.parse,
msg,
info,
mode="fast",
need_emb=False,
**kwargs,
)
for msg in expanded_messages
]
# collect results in original order
for future in futures:
try:
items = future.result()
all_memory_items.extend(items)
except Exception as e:
logger.error(f"[MultiModalFine] Error in parallel parsing: {e}")
else:
# Parse as single message
all_memory_items = self.multi_modal_parser.parse(
scene_data_info, info, mode="fast", need_emb=False, **kwargs
)
fast_memory_items = self._concat_multi_modal_memories(all_memory_items)
# Stage: parse — parallel message parsing + sliding-window aggregation
with timed_stage("add", "parse") as ts_parse:
if isinstance(scene_data_info, list):
expanded_messages = self._expand_multimodal_messages(scene_data_info)
ts_parse.set(msg_count=len(expanded_messages))

all_memory_items = []
with ContextThreadPoolExecutor(max_workers=30) as executor:
futures = [
executor.submit(
self.multi_modal_parser.parse,
msg,
info,
mode="fast",
need_emb=False,
**kwargs,
)
for msg in expanded_messages
]
for future in futures:
try:
items = future.result()
all_memory_items.extend(items)
except Exception as e:
logger.error(f"[MultiModalFine] Error in parallel parsing: {e}")
else:
ts_parse.set(msg_count=1)
all_memory_items = self.multi_modal_parser.parse(
scene_data_info, info, mode="fast", need_emb=False, **kwargs
)

fast_memory_items = self._concat_multi_modal_memories(all_memory_items)
ts_parse.set(window_count=len(fast_memory_items))

if mode == "fast":
return fast_memory_items
else:
non_file_url_fast_items = [
item for item in fast_memory_items if not self._is_file_url_only_item(item)
]

# Part A: call llm in parallel using thread pool
# Stage: llm_extract — fine mode 4-way parallel LLM + per-source serial
non_file_url_fast_items = [
item for item in fast_memory_items if not self._is_file_url_only_item(item)
]

with timed_stage("add", "llm_extract") as ts_llm:
fine_memory_items = []

with ContextThreadPoolExecutor(max_workers=4) as executor:
@@ -1057,7 +1053,6 @@ def _process_multi_modal_data(
**kwargs,
)

# Collect results
fine_memory_items_string_parser = future_string.result()
fine_memory_items_tool_trajectory_parser = future_tool.result()
fine_memory_items_skill_memory_parser = future_skill.result()
@@ -1068,21 +1063,25 @@
fine_memory_items.extend(fine_memory_items_skill_memory_parser)
fine_memory_items.extend(fine_memory_items_pref_parser)

# Part B: get fine multimodal items
for fast_item in fast_memory_items:
sources = fast_item.metadata.sources
for source in sources:
lang = getattr(source, "lang", "en")
items = self.multi_modal_parser.process_transfer(
source,
context_items=[fast_item],
custom_tags=custom_tags,
info=info,
lang=lang,
user_context=kwargs.get("user_context"),
)
fine_memory_items.extend(items)
return fine_memory_items
# Part B: per-source serial processing
with timed_stage("add", "per_source") as ts_ps:
for fast_item in fast_memory_items:
sources = fast_item.metadata.sources
for source in sources:
lang = getattr(source, "lang", "en")
items = self.multi_modal_parser.process_transfer(
source,
context_items=[fast_item],
custom_tags=custom_tags,
info=info,
lang=lang,
user_context=kwargs.get("user_context"),
)
fine_memory_items.extend(items)

ts_llm.set(fine_memory_count=len(fine_memory_items), per_source_ms=ts_ps.duration_ms)

return fine_memory_items

@timed
def _process_transfer_multi_modal_data(
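The new stage timer, `timed_stage`, is imported from `memos.utils` but its implementation is not part of this diff. From the call sites above it behaves as a context manager that takes a category, a stage name, and optional keyword fields, and yields an object exposing `set(**fields)` and a `duration_ms` attribute. A minimal sketch consistent with that usage, assuming a plain logging backend (everything below is an assumption, not the actual `memos.utils` code):

```python
# A minimal sketch of the contract this PR assumes for timed_stage; the real
# implementation lives in memos.utils and is not shown in this diff, so the
# names and logging format below are assumptions.
import logging
import time
from contextlib import contextmanager

logger = logging.getLogger(__name__)


class _StageTimer:
    """Carries extra fields attached via set() and the measured duration."""

    def __init__(self, **fields) -> None:
        self.fields: dict = dict(fields)
        self.duration_ms: float = 0.0

    def set(self, **fields) -> None:
        self.fields.update(fields)


@contextmanager
def timed_stage(category: str, stage: str, **fields):
    """Time a named stage and log its duration with any attached fields."""
    timer = _StageTimer(**fields)
    start = time.perf_counter()
    try:
        yield timer
    finally:
        timer.duration_ms = (time.perf_counter() - start) * 1000
        logger.info(
            "[%s] stage=%s duration_ms=%.1f fields=%s",
            category,
            stage,
            timer.duration_ms,
            timer.fields,
        )
```

Measuring inside `finally` matters for the call sites above: `ts_ps.duration_ms` is read after the `per_source` block has exited, and a stage should still record its duration when the wrapped code raises.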
20 changes: 13 additions & 7 deletions src/memos/multi_mem_cube/composite_cube.py
@@ -6,6 +6,7 @@

from memos.context.context import ContextThreadPoolExecutor
from memos.multi_mem_cube.views import MemCubeView
from memos.utils import timed_stage


if TYPE_CHECKING:
@@ -27,13 +28,18 @@ class CompositeCubeView(MemCubeView):

def add_memories(self, add_req: APIADDRequest) -> list[dict[str, Any]]:
all_results: list[dict[str, Any]] = []

# fast mode: for each cube view, add memories
# maybe add more strategies in add_req.async_mode
for view in self.cube_views:
self.logger.info(f"[CompositeCubeView] fan-out add to cube={view.cube_id}")
results = view.add_memories(add_req)
all_results.extend(results)
cube_count = len(self.cube_views)

with timed_stage("add", "multi_cube", cube_count=cube_count):
for idx, view in enumerate(self.cube_views):
self.logger.info(
"[CompositeCubeView] fan-out add to cube=%s (%d/%d)",
view.cube_id,
idx + 1,
cube_count,
)
results = view.add_memories(add_req)
all_results.extend(results)

return all_results

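With the three files combined, a single add request that fans out through `CompositeCubeView` produces nested stage timings. The snippet below is purely illustrative: it reuses the hypothetical `timed_stage` sketch above, the stage names mirror this PR, and the counts are placeholder values rather than real measurements.

```python
# Illustrative nesting of the stages introduced in this PR for a two-cube
# composite; assumes the timed_stage sketch above is in scope. Real field
# values come from the parsed messages, not these placeholders.
with timed_stage("add", "multi_cube", cube_count=2):
    for _cube in range(2):
        with timed_stage("add", "parse") as ts_parse:
            ts_parse.set(msg_count=4, window_count=3)
        with timed_stage("add", "embedding", window_count=3):
            pass  # _embed_memory_items(windows)
        with timed_stage("add", "llm_extract") as ts_llm:
            with timed_stage("add", "per_source") as ts_ps:
                pass  # per-source process_transfer calls
            ts_llm.set(fine_memory_count=7, per_source_ms=ts_ps.duration_ms)
```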