diff --git a/src/memos/api/config.py b/src/memos/api/config.py index c68deae5a..01f832079 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -872,7 +872,7 @@ def get_scheduler_config() -> dict[str, Any]: ), "context_window_size": int(os.getenv("MOS_SCHEDULER_CONTEXT_WINDOW_SIZE", "5")), "thread_pool_max_workers": int( - os.getenv("MOS_SCHEDULER_THREAD_POOL_MAX_WORKERS", "200") + os.getenv("MOS_SCHEDULER_THREAD_POOL_MAX_WORKERS", "50") ), "consume_interval_seconds": float( os.getenv("MOS_SCHEDULER_CONSUME_INTERVAL_SECONDS", "0.01") diff --git a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py index 690c8d123..02cd59e8c 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py +++ b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py @@ -548,6 +548,9 @@ def stats(self) -> dict[str, int]: running = 0 try: with self._task_lock: + done = {f for f in self._futures if f.done()} + if done: + self._futures -= done inflight = len(self._futures) except Exception: inflight = 0 diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index 1277c5465..a23e33c55 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -93,6 +93,11 @@ def __init__( self.task_broker_flush_bar = 10 self._refill_lock = threading.Lock() self._refill_thread: ContextThread | None = None + self._refill_in_progress = False + self._refill_thread_start: float = 0.0 + self._refill_thread_timeout: float = float( + os.getenv("MEMSCHEDULER_REDIS_REFILL_TIMEOUT_SEC", "30") or 30 + ) # Track empty streams first-seen time to avoid zombie keys self._empty_stream_seen_times: dict[str, float] = {} @@ -110,8 +115,11 @@ def __init__( self.seen_streams = set() - # Task Orchestrator - self.message_pack_cache = deque() + # Task Orchestrator — cap in-memory cache to avoid unbounded growth + self._cache_max_packs = int(os.getenv("MEMSCHEDULER_REDIS_CACHE_MAX_PACKS", "50") or 50) + self.message_pack_cache: deque[list[ScheduleMessageItem]] = deque( + maxlen=self._cache_max_packs + ) self.orchestrator = SchedulerOrchestrator() if orchestrator is None else orchestrator @@ -349,23 +357,51 @@ def task_broker( def _async_refill_cache(self, batch_size: int) -> None: """Background thread to refill message cache without blocking get_messages.""" try: - logger.debug(f"Starting async cache refill with batch_size={batch_size}") + with self._refill_lock: + remaining = self._cache_max_packs - len(self.message_pack_cache) + if remaining <= 0: + logger.debug("Async refill skipped: cache already at capacity") + return + self._refill_in_progress = True + + logger.debug( + f"Starting async cache refill with batch_size={batch_size}, remaining_capacity={remaining}" + ) new_packs = self.task_broker(consume_batch_size=batch_size) - logger.debug(f"task_broker returned {len(new_packs)} packs") + with self._refill_lock: + added = 0 for pack in new_packs: - if pack: # Only add non-empty packs + if pack: self.message_pack_cache.append(pack) - logger.debug(f"Added pack with {len(pack)} messages to cache") - logger.debug(f"Cache refill complete, cache size now: {len(self.message_pack_cache)}") + added += 1 + if added >= remaining: + break + logger.debug( + f"Cache refill complete, added={added}, cache size now: {len(self.message_pack_cache)}" + ) except Exception as e: logger.warning(f"Async cache refill failed: {e}", exc_info=True) + finally: + with self._refill_lock: + self._refill_in_progress = False + + def _is_refill_thread_available(self) -> bool: + """Check whether a new refill thread can be started.""" + if self._refill_thread is None or not self._refill_thread.is_alive(): + return True + if (time.time() - self._refill_thread_start) > self._refill_thread_timeout: + logger.warning( + f"Refill thread has been running for >{self._refill_thread_timeout}s, treating as stale" + ) + return True + return False def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]: if self.message_pack_cache: - # Trigger async refill if below threshold (non-blocking) - if len(self.message_pack_cache) < self.task_broker_flush_bar and ( - self._refill_thread is None or not self._refill_thread.is_alive() + if ( + len(self.message_pack_cache) < self.task_broker_flush_bar + and self._is_refill_thread_available() ): logger.debug( f"Triggering async cache refill: cache size {len(self.message_pack_cache)} < {self.task_broker_flush_bar}" @@ -373,14 +409,26 @@ def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]: self._refill_thread = ContextThread( target=self._async_refill_cache, args=(batch_size,), name="redis-cache-refill" ) + self._refill_thread_start = time.time() self._refill_thread.start() else: logger.debug(f"The size of message_pack_cache is {len(self.message_pack_cache)}") else: - new_packs = self.task_broker(consume_batch_size=batch_size) - for pack in new_packs: - if pack: # Only add non-empty packs - self.message_pack_cache.append(pack) + should_fetch = False + with self._refill_lock: + if not self.message_pack_cache and not self._refill_in_progress: + self._refill_in_progress = True + should_fetch = True + if should_fetch: + try: + new_packs = self.task_broker(consume_batch_size=batch_size) + with self._refill_lock: + for pack in new_packs: + if pack: + self.message_pack_cache.append(pack) + finally: + with self._refill_lock: + self._refill_in_progress = False if len(self.message_pack_cache) == 0: return [] else: @@ -443,12 +491,17 @@ def put( with self._stream_keys_lock: if stream_key not in self.seen_streams: self.seen_streams.add(stream_key) - self._ensure_consumer_group(stream_key=stream_key) + need_create_group = True + else: + need_create_group = False if stream_key not in self._stream_keys_cache: self._stream_keys_cache.append(stream_key) self._stream_keys_last_refresh = time.time() + if need_create_group: + self._ensure_consumer_group(stream_key=stream_key) + message.stream_key = stream_key # Convert message to dictionary for Redis storage @@ -1054,14 +1107,9 @@ def get_stream_keys(self, stream_key_prefix: str | None = None) -> list[str]: with self._stream_keys_lock: cache_snapshot = list(self._stream_keys_cache) - # Validate that cached keys conform to the expected prefix - escaped_prefix = re.escape(effective_prefix) - regex_pattern = f"^{escaped_prefix}:" - for key in cache_snapshot: - if not re.match(regex_pattern, key): - logger.error( - f"[REDIS_QUEUE] Cached stream key '{key}' does not match prefix '{effective_prefix}:'" - ) + if effective_prefix != self.stream_key_prefix: + pattern = re.compile(f"^{re.escape(effective_prefix)}:") + cache_snapshot = [k for k in cache_snapshot if pattern.match(k)] return cache_snapshot @@ -1211,7 +1259,7 @@ def __del__(self): @property def unfinished_tasks(self) -> int: - return self.qsize() + return self.size() def _scan_candidate_stream_keys( self, @@ -1396,6 +1444,23 @@ def _update_stream_cache_with_log( self._stream_keys_cache = active_stream_keys self._stream_keys_last_refresh = time.time() cache_count = len(self._stream_keys_cache) + + active_set = set(active_stream_keys) + stale = self.seen_streams - active_set + if stale: + self.seen_streams -= stale + logger.debug(f"Pruned {len(stale)} stale entries from seen_streams") + + candidate_set = set(candidate_keys) + with self._empty_stream_seen_lock: + orphaned = [k for k in self._empty_stream_seen_times if k not in candidate_set] + for k in orphaned: + del self._empty_stream_seen_times[k] + if orphaned: + logger.debug( + f"Pruned {len(orphaned)} orphaned entries from _empty_stream_seen_times" + ) + logger.debug( f"Refreshed stream keys cache: {cache_count} active keys, " f"{deleted_count} deleted, {len(candidate_keys)} candidates examined."