39 changes: 22 additions & 17 deletions src/conductor/client/automator/task_handler.py
@@ -268,6 +268,8 @@ def __init__(
         self._monitor_thread: Optional[threading.Thread] = None
         self._restart_counts: List[int] = [0 for _ in self.workers]
         self._next_restart_at: List[float] = [0.0 for _ in self.workers]
+        # Lock to protect process list during concurrent access (monitor thread vs main thread)
+        self._process_lock = threading.Lock()
         logger.info("TaskHandler initialized")

     def __enter__(self):
@@ -280,8 +282,10 @@ def stop_processes(self) -> None:
         self._monitor_stop_event.set()
         if self._monitor_thread is not None and self._monitor_thread.is_alive():
             self._monitor_thread.join(timeout=2.0)
-        self.__stop_task_runner_processes()
-        self.__stop_metrics_provider_process()
+        # Lock to prevent race conditions with monitor thread
+        with self._process_lock:
+            self.__stop_task_runner_processes()
+            self.__stop_metrics_provider_process()
         logger.info("Stopped worker processes...")
         self.queue.put(None)
         self.logger_process.terminate()
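The ordering in stop_processes matters: _monitor_stop_event is set before _process_lock is taken, so a monitor pass that has not yet acquired the lock bails out at its stop-event check instead of restarting workers mid-shutdown. A minimal stand-alone sketch of that handshake, with simplified names rather than the actual TaskHandler code:

import threading

stop_event = threading.Event()
process_lock = threading.Lock()

def monitor_pass():
    if stop_event.is_set():
        return  # shutdown has begun; do not restart anything
    with process_lock:
        ...  # inspect dead processes and restart them

def stop():
    stop_event.set()       # 1) tell the monitor to stand down
    with process_lock:     # 2) wait out any in-flight monitor pass
        ...                # terminate and join worker processes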
@@ -381,20 +385,22 @@ def __monitor_loop(self) -> None:
     def __check_and_restart_processes(self) -> None:
         if self._monitor_stop_event.is_set():
             return
-        for i, process in enumerate(list(self.task_runner_processes)):
-            if process is None:
-                continue
-            if process.is_alive():
-                continue
-            exitcode = process.exitcode
-            if exitcode is None:
-                continue
-            worker = self.workers[i] if i < len(self.workers) else None
-            worker_name = worker.get_task_definition_name() if worker is not None else f"worker[{i}]"
-            logger.warning("Worker process exited (worker=%s, pid=%s, exitcode=%s)", worker_name, process.pid, exitcode)
-            if not self.restart_on_failure:
-                continue
-            self.__restart_worker_process(i)
+        # Lock to prevent race conditions with stop_processes
+        with self._process_lock:
+            for i, process in enumerate(list(self.task_runner_processes)):
+                if process is None:
+                    continue
+                if process.is_alive():
+                    continue
+                exitcode = process.exitcode
+                if exitcode is None:
+                    continue
+                worker = self.workers[i] if i < len(self.workers) else None
+                worker_name = worker.get_task_definition_name() if worker is not None else f"worker[{i}]"
+                logger.warning("Worker process exited (worker=%s, pid=%s, exitcode=%s)", worker_name, process.pid, exitcode)
+                if not self.restart_on_failure:
+                    continue
+                self.__restart_worker_process(i)

     def __restart_worker_process(self, index: int) -> None:
         if self._monitor_stop_event.is_set():
@@ -522,7 +528,6 @@ def __start_task_runner_processes(self):
         n = 0
         for i, task_runner_process in enumerate(self.task_runner_processes):
             task_runner_process.start()
-            print(f'task runner process {task_runner_process.name} started')
             worker = self.workers[i]
             paused_status = "PAUSED" if getattr(worker, "paused", False) else "ACTIVE"
             logger.debug("Started worker '%s' [%s]", worker.get_task_definition_name(), paused_status)
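The _restart_counts and _next_restart_at lists added in __init__ point at per-worker restart bookkeeping with a cool-down between attempts; the __restart_worker_process body itself is collapsed in this diff. The following is therefore only a hypothetical sketch of how such fields can gate restarts with exponential backoff (the class name, base_delay, and cap are invented for illustration):

import time
from typing import List

class RestartGateSketch:
    """Illustration only; not the library's code."""

    def __init__(self, num_workers: int) -> None:
        self._restart_counts: List[int] = [0 for _ in range(num_workers)]
        self._next_restart_at: List[float] = [0.0 for _ in range(num_workers)]

    def should_restart(self, index: int, base_delay: float = 1.0, cap: float = 60.0) -> bool:
        # Refuse while this worker is still inside its backoff window.
        now = time.monotonic()
        if now < self._next_restart_at[index]:
            return False
        # Double the delay after each successive restart, up to a ceiling.
        delay = min(cap, base_delay * (2 ** self._restart_counts[index]))
        self._restart_counts[index] += 1
        self._next_restart_at[index] = now + delay
        return True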
129 changes: 72 additions & 57 deletions src/conductor/client/telemetry/metrics_collector.py
@@ -1,5 +1,6 @@
 import logging
 import os
+import threading
 import time
 from collections import deque
 from typing import Any, ClassVar, Dict, List, Tuple
@@ -84,6 +85,11 @@ class MetricsCollector:
         dispatcher.register(PollStarted, metrics.on_poll_started)
         dispatcher.publish(PollStarted(...))

+    Thread Safety:
+        This class is thread-safe. Internal dictionaries (counters, gauges, histograms, etc.)
+        are protected by a lock to prevent race conditions when accessed from multiple threads
+        (e.g., worker threads and monitor threads).
+
     Note: Uses Python's Protocol for structural subtyping rather than explicit
     inheritance to avoid circular imports and maintain backward compatibility.
     """
@@ -99,6 +105,7 @@ def __init__(self, settings: MetricsSettings):
         self.quantile_data: Dict[str, deque] = {}  # metric_name+labels -> deque of values
         self.registry = None
         self.must_collect_metrics = False
+        self._lock = threading.RLock()  # Reentrant lock for thread-safe access to internal dictionaries

         if settings is None:
             return
@@ -504,12 +511,13 @@ def __increment_counter(
     ) -> None:
         if not self.must_collect_metrics:
             return
-        counter = self.__get_counter(
-            name=name,
-            documentation=documentation,
-            labelnames=[label.value for label in labels.keys()]
-        )
-        counter.labels(*labels.values()).inc()
+        with self._lock:
+            counter = self.__get_counter(
+                name=name,
+                documentation=documentation,
+                labelnames=[label.value for label in labels.keys()]
+            )
+            counter.labels(*labels.values()).inc()

     def __record_gauge(
         self,
@@ -520,12 +528,13 @@ def __record_gauge(
     ) -> None:
         if not self.must_collect_metrics:
             return
-        gauge = self.__get_gauge(
-            name=name,
-            documentation=documentation,
-            labelnames=[label.value for label in labels.keys()]
-        )
-        gauge.labels(*labels.values()).set(value)
+        with self._lock:
+            gauge = self.__get_gauge(
+                name=name,
+                documentation=documentation,
+                labelnames=[label.value for label in labels.keys()]
+            )
+            gauge.labels(*labels.values()).set(value)

     def __get_counter(
         self,
@@ -587,12 +596,13 @@ def __observe_histogram(
     ) -> None:
         if not self.must_collect_metrics:
             return
-        histogram = self.__get_histogram(
-            name=name,
-            documentation=documentation,
-            labelnames=[label.value for label in labels.keys()]
-        )
-        histogram.labels(*labels.values()).observe(value)
+        with self._lock:
+            histogram = self.__get_histogram(
+                name=name,
+                documentation=documentation,
+                labelnames=[label.value for label in labels.keys()]
+            )
+            histogram.labels(*labels.values()).observe(value)

     def __get_histogram(
         self,
@@ -630,12 +640,13 @@ def __observe_summary(
     ) -> None:
         if not self.must_collect_metrics:
             return
-        summary = self.__get_summary(
-            name=name,
-            documentation=documentation,
-            labelnames=[label.value for label in labels.keys()]
-        )
-        summary.labels(*labels.values()).observe(value)
+        with self._lock:
+            summary = self.__get_summary(
+                name=name,
+                documentation=documentation,
+                labelnames=[label.value for label in labels.keys()]
+            )
+            summary.labels(*labels.values()).observe(value)

     def __get_summary(
         self,
@@ -681,45 +692,46 @@ def __record_quantiles(
         if not self.must_collect_metrics:
             return

-        # Create a key for this metric+labels combination
-        label_values = tuple(labels.values())
-        data_key = f"{name}_{label_values}"
-
-        # Initialize data window if needed
-        if data_key not in self.quantile_data:
-            self.quantile_data[data_key] = deque(maxlen=self.QUANTILE_WINDOW_SIZE)
-
-        # Add new observation
-        self.quantile_data[data_key].append(value)
-
-        # Calculate and update quantiles
-        observations = sorted(self.quantile_data[data_key])
-        n = len(observations)
-
-        if n > 0:
-            quantiles = [0.5, 0.75, 0.9, 0.95, 0.99]
-            for q in quantiles:
-                quantile_value = self.__calculate_quantile(observations, q)
-
-                # Get or create gauge for this quantile
-                gauge = self.__get_quantile_gauge(
-                    name=name,
-                    documentation=documentation,
-                    labelnames=[label.value for label in labels.keys()] + ["quantile"],
-                    quantile=q
-                )
-
-                # Set gauge value with labels + quantile
-                gauge.labels(*labels.values(), str(q)).set(quantile_value)
-
-            # Also publish _count and _sum for proper summary metrics
-            self.__update_summary_aggregates(
-                name=name,
-                documentation=documentation,
-                labels=labels,
-                observations=list(self.quantile_data[data_key])
-            )
+        with self._lock:
+            # Create a key for this metric+labels combination
+            label_values = tuple(labels.values())
+            data_key = f"{name}_{label_values}"
+
+            # Initialize data window if needed
+            if data_key not in self.quantile_data:
+                self.quantile_data[data_key] = deque(maxlen=self.QUANTILE_WINDOW_SIZE)
+
+            # Add new observation
+            self.quantile_data[data_key].append(value)
+
+            # Calculate and update quantiles
+            observations = sorted(self.quantile_data[data_key])
+            n = len(observations)
+
+            if n > 0:
+                quantiles = [0.5, 0.75, 0.9, 0.95, 0.99]
+                for q in quantiles:
+                    quantile_value = self.__calculate_quantile(observations, q)
+
+                    # Get or create gauge for this quantile
+                    gauge = self.__get_quantile_gauge(
+                        name=name,
+                        documentation=documentation,
+                        labelnames=[label.value for label in labels.keys()] + ["quantile"],
+                        quantile=q
+                    )
+
+                    # Set gauge value with labels + quantile
+                    gauge.labels(*labels.values(), str(q)).set(quantile_value)
+
+                # Also publish _count and _sum for proper summary metrics
+                self.__update_summary_aggregates(
+                    name=name,
+                    documentation=documentation,
+                    labels=labels,
+                    observations=list(self.quantile_data[data_key])
+                )

     def __calculate_quantile(self, sorted_values: List[float], quantile: float) -> float:
         """Calculate quantile from sorted list of values."""
         if not sorted_values:
@@ -770,6 +782,9 @@ def __update_summary_aggregates(
         """
         Update _count and _sum gauges for proper summary metric format.
         This makes the metrics compatible with Prometheus summary type.
+
+        Note: This method should only be called while holding self._lock
+        (called from __record_quantiles which already holds the lock).
         """
         if not observations:
             return
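The lock is created as threading.RLock() rather than a plain Lock, presumably so that a locked section can reach another method that takes the same lock on the same thread without deadlocking. A minimal stand-alone illustration of the difference (hypothetical functions, not the library's code):

import threading

lock = threading.RLock()  # swap in threading.Lock() and outer() deadlocks

def inner():
    with lock:  # second acquisition by the same thread is fine for an RLock
        pass

def outer():
    with lock:
        inner()

outer()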
125 changes: 125 additions & 0 deletions tests/unit/telemetry/test_metrics_collector_thread_safety.py
@@ -0,0 +1,125 @@
import os
import tempfile
import threading
import time
import unittest
from pathlib import Path

from conductor.client.configuration.settings.metrics_settings import MetricsSettings
from conductor.client.telemetry.metrics_collector import MetricsCollector


class TestMetricsCollectorThreadSafety(unittest.TestCase):
    """Test thread safety of MetricsCollector."""

    def setUp(self):
        """Create temporary directory for metrics."""
        self.temp_dir = tempfile.mkdtemp()
        self.metrics_settings = MetricsSettings(directory=self.temp_dir)

    def tearDown(self):
        """Clean up temporary directory."""
        import shutil
        try:
            shutil.rmtree(self.temp_dir)
        except Exception:
            pass

    def test_concurrent_counter_increments(self):
        """Test that concurrent counter increments from multiple threads work correctly."""
        collector = MetricsCollector(self.metrics_settings)

        # Number of threads and increments per thread
        num_threads = 10
        increments_per_thread = 100

        # Track exceptions from threads
        exceptions = []

        def increment_task_poll():
            try:
                for i in range(increments_per_thread):
                    collector.increment_task_poll(f"task_type_{threading.current_thread().name}")
            except Exception as e:
                exceptions.append(e)

        # Create and start threads
        threads = []
        for i in range(num_threads):
            thread = threading.Thread(target=increment_task_poll, name=f"thread_{i}")
            threads.append(thread)
            thread.start()

        # Wait for all threads to complete
        for thread in threads:
            thread.join(timeout=5.0)

        # Verify no exceptions occurred
        self.assertEqual(len(exceptions), 0, f"Exceptions occurred during concurrent access: {exceptions}")

    def test_concurrent_mixed_metrics(self):
        """Test that concurrent mixed metric operations (counters, gauges, quantiles) work correctly."""
        collector = MetricsCollector(self.metrics_settings)

        num_threads = 5
        operations_per_thread = 50
        exceptions = []

        def mixed_operations():
            try:
                for i in range(operations_per_thread):
                    # Mix different metric types
                    collector.increment_task_poll("task_a")
                    collector.record_task_result_payload_size("task_a", 1024)
                    collector.record_task_execute_time("task_a", 0.123)
                    collector.increment_worker_restart("task_a")
            except Exception as e:
                exceptions.append(e)

        # Create and start threads
        threads = []
        for i in range(num_threads):
            thread = threading.Thread(target=mixed_operations, name=f"thread_{i}")
            threads.append(thread)
            thread.start()

        # Wait for all threads to complete
        for thread in threads:
            thread.join(timeout=10.0)

        # Verify no exceptions occurred
        self.assertEqual(len(exceptions), 0, f"Exceptions occurred during concurrent access: {exceptions}")

    def test_concurrent_quantile_recording(self):
        """Test that concurrent quantile recording works correctly."""
        collector = MetricsCollector(self.metrics_settings)

        num_threads = 5
        observations_per_thread = 50
        exceptions = []

        def record_quantiles():
            try:
                for i in range(observations_per_thread):
                    # Record execution time (which uses quantiles)
                    collector.record_task_execute_time("task_b", float(i) / 100.0)
            except Exception as e:
                exceptions.append(e)

        # Create and start threads
        threads = []
        for i in range(num_threads):
            thread = threading.Thread(target=record_quantiles, name=f"thread_{i}")
            threads.append(thread)
            thread.start()

        # Wait for all threads to complete
        for thread in threads:
            thread.join(timeout=10.0)

        # Verify no exceptions occurred
        self.assertEqual(len(exceptions), 0, f"Exceptions occurred during concurrent access: {exceptions}")


if __name__ == '__main__':
    unittest.main()
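These tests assert only that no exception escapes a worker thread; they do not verify final metric values. The spawn-and-join boilerplate repeated in each test could also be written more compactly with concurrent.futures, which re-raises worker exceptions automatically; a sketch, where the hammer helper is invented for illustration:

from concurrent.futures import ThreadPoolExecutor

def hammer(collector, num_threads=10, iterations=100):
    def work(tid):
        for _ in range(iterations):
            collector.increment_task_poll(f"task_type_{tid}")
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        futures = [pool.submit(work, t) for t in range(num_threads)]
        for future in futures:
            future.result()  # re-raises any exception from the worker thread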