diff --git a/src/conductor/client/automator/task_handler.py b/src/conductor/client/automator/task_handler.py index 342ec3211..08ef6961b 100644 --- a/src/conductor/client/automator/task_handler.py +++ b/src/conductor/client/automator/task_handler.py @@ -268,6 +268,8 @@ def __init__( self._monitor_thread: Optional[threading.Thread] = None self._restart_counts: List[int] = [0 for _ in self.workers] self._next_restart_at: List[float] = [0.0 for _ in self.workers] + # Lock to protect process list during concurrent access (monitor thread vs main thread) + self._process_lock = threading.Lock() logger.info("TaskHandler initialized") def __enter__(self): @@ -280,8 +282,10 @@ def stop_processes(self) -> None: self._monitor_stop_event.set() if self._monitor_thread is not None and self._monitor_thread.is_alive(): self._monitor_thread.join(timeout=2.0) - self.__stop_task_runner_processes() - self.__stop_metrics_provider_process() + # Lock to prevent race conditions with monitor thread + with self._process_lock: + self.__stop_task_runner_processes() + self.__stop_metrics_provider_process() logger.info("Stopped worker processes...") self.queue.put(None) self.logger_process.terminate() @@ -381,20 +385,22 @@ def __monitor_loop(self) -> None: def __check_and_restart_processes(self) -> None: if self._monitor_stop_event.is_set(): return - for i, process in enumerate(list(self.task_runner_processes)): - if process is None: - continue - if process.is_alive(): - continue - exitcode = process.exitcode - if exitcode is None: - continue - worker = self.workers[i] if i < len(self.workers) else None - worker_name = worker.get_task_definition_name() if worker is not None else f"worker[{i}]" - logger.warning("Worker process exited (worker=%s, pid=%s, exitcode=%s)", worker_name, process.pid, exitcode) - if not self.restart_on_failure: - continue - self.__restart_worker_process(i) + # Lock to prevent race conditions with stop_processes + with self._process_lock: + for i, process in enumerate(list(self.task_runner_processes)): + if process is None: + continue + if process.is_alive(): + continue + exitcode = process.exitcode + if exitcode is None: + continue + worker = self.workers[i] if i < len(self.workers) else None + worker_name = worker.get_task_definition_name() if worker is not None else f"worker[{i}]" + logger.warning("Worker process exited (worker=%s, pid=%s, exitcode=%s)", worker_name, process.pid, exitcode) + if not self.restart_on_failure: + continue + self.__restart_worker_process(i) def __restart_worker_process(self, index: int) -> None: if self._monitor_stop_event.is_set(): @@ -522,7 +528,6 @@ def __start_task_runner_processes(self): n = 0 for i, task_runner_process in enumerate(self.task_runner_processes): task_runner_process.start() - print(f'task runner process {task_runner_process.name} started') worker = self.workers[i] paused_status = "PAUSED" if getattr(worker, "paused", False) else "ACTIVE" logger.debug("Started worker '%s' [%s]", worker.get_task_definition_name(), paused_status) diff --git a/src/conductor/client/telemetry/metrics_collector.py b/src/conductor/client/telemetry/metrics_collector.py index 93677edb4..7e3b6c579 100644 --- a/src/conductor/client/telemetry/metrics_collector.py +++ b/src/conductor/client/telemetry/metrics_collector.py @@ -1,5 +1,6 @@ import logging import os +import threading import time from collections import deque from typing import Any, ClassVar, Dict, List, Tuple @@ -84,6 +85,11 @@ class MetricsCollector: dispatcher.register(PollStarted, metrics.on_poll_started) dispatcher.publish(PollStarted(...)) + Thread Safety: + This class is thread-safe. Internal dictionaries (counters, gauges, histograms, etc.) + are protected by a lock to prevent race conditions when accessed from multiple threads + (e.g., worker threads and monitor threads). + Note: Uses Python's Protocol for structural subtyping rather than explicit inheritance to avoid circular imports and maintain backward compatibility. """ @@ -99,6 +105,7 @@ def __init__(self, settings: MetricsSettings): self.quantile_data: Dict[str, deque] = {} # metric_name+labels -> deque of values self.registry = None self.must_collect_metrics = False + self._lock = threading.RLock() # Reentrant lock for thread-safe access to internal dictionaries if settings is None: return @@ -504,12 +511,13 @@ def __increment_counter( ) -> None: if not self.must_collect_metrics: return - counter = self.__get_counter( - name=name, - documentation=documentation, - labelnames=[label.value for label in labels.keys()] - ) - counter.labels(*labels.values()).inc() + with self._lock: + counter = self.__get_counter( + name=name, + documentation=documentation, + labelnames=[label.value for label in labels.keys()] + ) + counter.labels(*labels.values()).inc() def __record_gauge( self, @@ -520,12 +528,13 @@ def __record_gauge( ) -> None: if not self.must_collect_metrics: return - gauge = self.__get_gauge( - name=name, - documentation=documentation, - labelnames=[label.value for label in labels.keys()] - ) - gauge.labels(*labels.values()).set(value) + with self._lock: + gauge = self.__get_gauge( + name=name, + documentation=documentation, + labelnames=[label.value for label in labels.keys()] + ) + gauge.labels(*labels.values()).set(value) def __get_counter( self, @@ -587,12 +596,13 @@ def __observe_histogram( ) -> None: if not self.must_collect_metrics: return - histogram = self.__get_histogram( - name=name, - documentation=documentation, - labelnames=[label.value for label in labels.keys()] - ) - histogram.labels(*labels.values()).observe(value) + with self._lock: + histogram = self.__get_histogram( + name=name, + documentation=documentation, + labelnames=[label.value for label in labels.keys()] + ) + histogram.labels(*labels.values()).observe(value) def __get_histogram( self, @@ -630,12 +640,13 @@ def __observe_summary( ) -> None: if not self.must_collect_metrics: return - summary = self.__get_summary( - name=name, - documentation=documentation, - labelnames=[label.value for label in labels.keys()] - ) - summary.labels(*labels.values()).observe(value) + with self._lock: + summary = self.__get_summary( + name=name, + documentation=documentation, + labelnames=[label.value for label in labels.keys()] + ) + summary.labels(*labels.values()).observe(value) def __get_summary( self, @@ -681,45 +692,46 @@ def __record_quantiles( if not self.must_collect_metrics: return - # Create a key for this metric+labels combination - label_values = tuple(labels.values()) - data_key = f"{name}_{label_values}" - - # Initialize data window if needed - if data_key not in self.quantile_data: - self.quantile_data[data_key] = deque(maxlen=self.QUANTILE_WINDOW_SIZE) - - # Add new observation - self.quantile_data[data_key].append(value) - - # Calculate and update quantiles - observations = sorted(self.quantile_data[data_key]) - n = len(observations) + with self._lock: + # Create a key for this metric+labels combination + label_values = tuple(labels.values()) + data_key = f"{name}_{label_values}" + + # Initialize data window if needed + if data_key not in self.quantile_data: + self.quantile_data[data_key] = deque(maxlen=self.QUANTILE_WINDOW_SIZE) + + # Add new observation + self.quantile_data[data_key].append(value) + + # Calculate and update quantiles + observations = sorted(self.quantile_data[data_key]) + n = len(observations) + + if n > 0: + quantiles = [0.5, 0.75, 0.9, 0.95, 0.99] + for q in quantiles: + quantile_value = self.__calculate_quantile(observations, q) + + # Get or create gauge for this quantile + gauge = self.__get_quantile_gauge( + name=name, + documentation=documentation, + labelnames=[label.value for label in labels.keys()] + ["quantile"], + quantile=q + ) - if n > 0: - quantiles = [0.5, 0.75, 0.9, 0.95, 0.99] - for q in quantiles: - quantile_value = self.__calculate_quantile(observations, q) + # Set gauge value with labels + quantile + gauge.labels(*labels.values(), str(q)).set(quantile_value) - # Get or create gauge for this quantile - gauge = self.__get_quantile_gauge( + # Also publish _count and _sum for proper summary metrics + self.__update_summary_aggregates( name=name, documentation=documentation, - labelnames=[label.value for label in labels.keys()] + ["quantile"], - quantile=q + labels=labels, + observations=list(self.quantile_data[data_key]) ) - # Set gauge value with labels + quantile - gauge.labels(*labels.values(), str(q)).set(quantile_value) - - # Also publish _count and _sum for proper summary metrics - self.__update_summary_aggregates( - name=name, - documentation=documentation, - labels=labels, - observations=list(self.quantile_data[data_key]) - ) - def __calculate_quantile(self, sorted_values: List[float], quantile: float) -> float: """Calculate quantile from sorted list of values.""" if not sorted_values: @@ -770,6 +782,9 @@ def __update_summary_aggregates( """ Update _count and _sum gauges for proper summary metric format. This makes the metrics compatible with Prometheus summary type. + + Note: This method should only be called while holding self._lock + (called from __record_quantiles which already holds the lock). """ if not observations: return diff --git a/tests/unit/telemetry/test_metrics_collector_thread_safety.py b/tests/unit/telemetry/test_metrics_collector_thread_safety.py new file mode 100644 index 000000000..7fa61c5aa --- /dev/null +++ b/tests/unit/telemetry/test_metrics_collector_thread_safety.py @@ -0,0 +1,125 @@ +import os +import tempfile +import threading +import time +import unittest +from pathlib import Path + +from conductor.client.configuration.settings.metrics_settings import MetricsSettings +from conductor.client.telemetry.metrics_collector import MetricsCollector + + +class TestMetricsCollectorThreadSafety(unittest.TestCase): + """Test thread safety of MetricsCollector.""" + + def setUp(self): + """Create temporary directory for metrics.""" + self.temp_dir = tempfile.mkdtemp() + self.metrics_settings = MetricsSettings(directory=self.temp_dir) + + def tearDown(self): + """Clean up temporary directory.""" + import shutil + try: + shutil.rmtree(self.temp_dir) + except Exception: + pass + + def test_concurrent_counter_increments(self): + """Test that concurrent counter increments from multiple threads work correctly.""" + collector = MetricsCollector(self.metrics_settings) + + # Number of threads and increments per thread + num_threads = 10 + increments_per_thread = 100 + + # Track exceptions from threads + exceptions = [] + + def increment_task_poll(): + try: + for i in range(increments_per_thread): + collector.increment_task_poll(f"task_type_{threading.current_thread().name}") + except Exception as e: + exceptions.append(e) + + # Create and start threads + threads = [] + for i in range(num_threads): + thread = threading.Thread(target=increment_task_poll, name=f"thread_{i}") + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join(timeout=5.0) + + # Verify no exceptions occurred + self.assertEqual(len(exceptions), 0, f"Exceptions occurred during concurrent access: {exceptions}") + + def test_concurrent_mixed_metrics(self): + """Test that concurrent mixed metric operations (counters, gauges, quantiles) work correctly.""" + collector = MetricsCollector(self.metrics_settings) + + num_threads = 5 + operations_per_thread = 50 + exceptions = [] + + def mixed_operations(): + try: + for i in range(operations_per_thread): + # Mix different metric types + collector.increment_task_poll("task_a") + collector.record_task_result_payload_size("task_a", 1024) + collector.record_task_execute_time("task_a", 0.123) + collector.increment_worker_restart("task_a") + except Exception as e: + exceptions.append(e) + + # Create and start threads + threads = [] + for i in range(num_threads): + thread = threading.Thread(target=mixed_operations, name=f"thread_{i}") + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join(timeout=10.0) + + # Verify no exceptions occurred + self.assertEqual(len(exceptions), 0, f"Exceptions occurred during concurrent access: {exceptions}") + + def test_concurrent_quantile_recording(self): + """Test that concurrent quantile recording works correctly.""" + collector = MetricsCollector(self.metrics_settings) + + num_threads = 5 + observations_per_thread = 50 + exceptions = [] + + def record_quantiles(): + try: + for i in range(observations_per_thread): + # Record execution time (which uses quantiles) + collector.record_task_execute_time("task_b", float(i) / 100.0) + except Exception as e: + exceptions.append(e) + + # Create and start threads + threads = [] + for i in range(num_threads): + thread = threading.Thread(target=record_quantiles, name=f"thread_{i}") + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join(timeout=10.0) + + # Verify no exceptions occurred + self.assertEqual(len(exceptions), 0, f"Exceptions occurred during concurrent access: {exceptions}") + + +if __name__ == '__main__': + unittest.main()