diff --git a/airbyte_cdk/entrypoint.py b/airbyte_cdk/entrypoint.py index 57820f005..d7049b9e3 100644 --- a/airbyte_cdk/entrypoint.py +++ b/airbyte_cdk/entrypoint.py @@ -7,13 +7,17 @@ import ipaddress import json import logging +import os import os.path import socket import sys import tempfile +import threading +import time +import traceback from collections import defaultdict from functools import wraps -from typing import Any, DefaultDict, Iterable, List, Mapping, Optional +from typing import Any, Callable, DefaultDict, Iterable, List, Mapping, Optional from urllib.parse import urlparse import orjson @@ -51,6 +55,144 @@ _HAS_LOGGED_FOR_SERIALIZATION_ERROR = False +class _DeadlockDiagnostics: + def __init__( + self, + *, + interval_seconds: float = 30.0, + stall_interval_count: int = 3, + dump_repeat_interval_count: int = 10, + time_fn: Callable[[], float] = time.monotonic, + write_fn: Callable[[bytes], None] | None = None, + ) -> None: + self._interval_seconds = interval_seconds + self._stall_interval_count = stall_interval_count + self._dump_repeat_interval_count = dump_repeat_interval_count + self._time_fn = time_fn + self._write_fn = write_fn or self._write_to_stderr + self._stop = threading.Event() + self._lock = threading.Lock() + self._thread: threading.Thread | None = None + self._start_time = self._time_fn() + self._messages_written = 0 + self._bytes_written = 0 + self._print_blocked = False + self._print_blocked_since = 0.0 + self._last_messages_written = 0 + self._stall_count = 0 + + def start(self) -> None: + self._thread = threading.Thread( + target=self._run, + name="airbyte-deadlock-diagnostics", + daemon=True, + ) + self._thread.start() + + def stop(self) -> None: + self._stop.set() + + def mark_print_started(self) -> None: + with self._lock: + self._print_blocked = True + self._print_blocked_since = self._time_fn() + + def mark_print_finished(self) -> None: + with self._lock: + self._print_blocked = False + + def record_message(self, data: str) -> 
None: + with self._lock: + self._messages_written += 1 + self._bytes_written += len(data.encode()) + + def emit_heartbeat(self) -> None: + now = self._time_fn() + with self._lock: + messages_written = self._messages_written + bytes_written = self._bytes_written + print_blocked = self._print_blocked + print_blocked_since = self._print_blocked_since + + if messages_written == self._last_messages_written and messages_written > 0: + self._stall_count += 1 + else: + self._stall_count = 0 + self._last_messages_written = messages_written + stall_count = self._stall_count + + line = self._heartbeat_line( + now, + messages_written, + bytes_written, + print_blocked, + print_blocked_since, + ) + if self._should_dump_threads(stall_count): + line += self._thread_dump() + + try: + self._write_fn(line.encode()) + except OSError: + return + + def _run(self) -> None: + while not self._stop.wait(timeout=self._interval_seconds): + self.emit_heartbeat() + + def _heartbeat_line( + self, + now: float, + messages_written: int, + bytes_written: int, + print_blocked: bool, + print_blocked_since: float, + ) -> str: + blocked = "YES" if print_blocked else "NO" + blocked_duration = ( + f" blocked_since={now - print_blocked_since:.0f}s" if print_blocked else "" + ) + return ( + f"STDOUT_HEARTBEAT: t={now - self._start_time:.0f}s " + f"msgs={messages_written} bytes={bytes_written} " + f"print_blocked={blocked}{blocked_duration}" + f"{self._queue_stats()}\n" + ) + + def _queue_stats(self) -> str: + from airbyte_cdk.sources.concurrent_source.queue_registry import get_queue + + queue = get_queue() + if queue is None: + return "" + + try: + return f" queue_size={queue.qsize()} queue_full={queue.full()}" + except NotImplementedError: + return "" + + def _should_dump_threads(self, stall_count: int) -> bool: + if stall_count == self._stall_interval_count: + return True + return ( + stall_count > self._stall_interval_count + and (stall_count - self._stall_interval_count) % 
self._dump_repeat_interval_count == 0 + ) + + def _thread_dump(self) -> str: + thread_names = {thread.ident: thread.name for thread in threading.enumerate()} + lines = ["=== THREAD DUMP (stall detected) ===\n"] + for thread_id, frame in sys._current_frames().items(): + lines.append(f"\nThread {thread_names.get(thread_id, 'unknown')} ({thread_id}):\n") + lines.extend(traceback.format_stack(frame)) + lines.append("=== END THREAD DUMP ===\n") + return "".join(lines) + + @staticmethod + def _write_to_stderr(data: bytes) -> None: + os.write(2, data) + + class AirbyteEntrypoint(object): def __init__(self, source: Source): init_uncaught_exception_handler(logger) @@ -391,13 +533,25 @@ def _emit_queued_messages(self, source: Source) -> Iterable[AirbyteMessage]: def launch(source: Source, args: List[str]) -> None: source_entrypoint = AirbyteEntrypoint(source) parsed_args = source_entrypoint.parse_args(args) + diagnostics = _DeadlockDiagnostics() + diagnostics.start() + # temporarily removes the PrintBuffer because we're seeing weird print behavior for concurrent syncs # Refer to: https://github.com/airbytehq/oncall/issues/6235 - with PRINT_BUFFER: - for message in source_entrypoint.run(parsed_args): - # simply printing is creating issues for concurrent CDK as Python uses different two instructions to print: one for the message and - # the other for the break line. Adding `\n` to the message ensure that both are printed at the same time - print(f"{message}\n", end="") + try: + with PRINT_BUFFER: + for message in source_entrypoint.run(parsed_args): + # simply printing is creating issues for concurrent CDK as Python uses different two instructions to print: one for the message and + # the other for the break line. 
def launch(source: Source, args: List[str]) -> None:
    """Parse ``args``, run the entrypoint for ``source`` and print each message it yields.

    A ``_DeadlockDiagnostics`` heartbeat is started before the run and stopped
    when the run ends, whether it finishes normally or raises.
    """
    entrypoint = AirbyteEntrypoint(source)
    cli_args = entrypoint.parse_args(args)

    watchdog = _DeadlockDiagnostics()
    watchdog.start()

    def _print_tracked(line: str) -> None:
        # Bracket the print so a heartbeat sampled while it is in flight
        # reports print_blocked=YES together with how long it has been running.
        watchdog.mark_print_started()
        try:
            print(line, end="")
        finally:
            watchdog.mark_print_finished()
        # Reached only when print() returned normally, so the counters
        # describe messages that were actually handed to stdout.
        watchdog.record_message(line)

    # NOTE(review): the previous comment here said the PrintBuffer had been
    # temporarily removed (https://github.com/airbytehq/oncall/issues/6235),
    # yet the buffer is in use below — confirm which statement is current.
    try:
        with PRINT_BUFFER:
            for message in entrypoint.run(cli_args):
                # The line break is embedded in the string instead of left to
                # print(): print() emits the text and the newline as two
                # separate writes, which can interleave under the concurrent CDK.
                _print_tracked(f"{message}\n")
    finally:
        watchdog.stop()
self._logger, self._message_repository), + ), + max_concurrent_partition_generators=self._initial_number_partitions_to_generate, + ) - # Enqueue initial partition generation tasks - yield from self._submit_initial_partition_generators(concurrent_stream_processor) + # Enqueue initial partition generation tasks + yield from self._submit_initial_partition_generators(concurrent_stream_processor) - # Read from the queue until all partitions were generated and read - yield from self._consume_from_queue( - self._queue, - concurrent_stream_processor, - ) - self._threadpool.check_for_errors_and_shutdown() - self._logger.info("Finished syncing") + # Read from the queue until all partitions were generated and read + yield from self._consume_from_queue( + self._queue, + concurrent_stream_processor, + ) + self._threadpool.check_for_errors_and_shutdown() + self._logger.info("Finished syncing") + finally: + unregister_queue() def _submit_initial_partition_generators( self, concurrent_stream_processor: ConcurrentReadProcessor diff --git a/airbyte_cdk/sources/concurrent_source/queue_registry.py b/airbyte_cdk/sources/concurrent_source/queue_registry.py new file mode 100644 index 000000000..97345eee5 --- /dev/null +++ b/airbyte_cdk/sources/concurrent_source/queue_registry.py @@ -0,0 +1,24 @@ +# +# Copyright (c) 2026 Airbyte, Inc., all rights reserved. 
# Process-wide slot holding the queue of the concurrent read currently in
# progress, or None when no read is registered. Written by the reading code,
# read by diagnostics.
_queue: Optional[Queue[QueueItem]] = None


def register_queue(queue: Queue[QueueItem]) -> None:
    """Publish ``queue`` as the current queue, replacing any previous registration."""
    global _queue
    _queue = queue


def get_queue() -> Optional[Queue[QueueItem]]:
    """Return the currently registered queue, or ``None`` when there is none."""
    return _queue


def unregister_queue(queue: Optional[Queue[QueueItem]] = None) -> None:
    """Clear the registration.

    Args:
        queue: When given, the slot is cleared only if it still holds this
            exact queue object. This lets an owner whose cleanup runs late
            (for example a ``finally`` block in a generator that is finalized
            after a newer read has registered its own queue) avoid dropping
            the newer registration. When omitted, the slot is cleared
            unconditionally, as before.
    """
    global _queue
    if queue is None or _queue is queue:
        _queue = None
def test_launch_starts_deadlock_diagnostics(mocker):
    """``launch`` drives the diagnostics lifecycle exactly once for a one-message run."""
    fake_diagnostics = MagicMock()
    spec_message = AirbyteMessage(
        type=Type.SPEC, spec=ConnectorSpecification(connectionSpecification={})
    )

    # Swap the diagnostics class for a factory returning the mock, and make the
    # entrypoint yield exactly one SPEC message without touching a real source.
    mocker.patch.object(entrypoint_module, "_DeadlockDiagnostics", return_value=fake_diagnostics)
    mocker.patch.object(AirbyteEntrypoint, "parse_args", return_value=Namespace(command="spec"))
    mocker.patch.object(AirbyteEntrypoint, "run", return_value=iter([spec_message]))

    entrypoint_module.launch(MockSource(), ["spec"])

    # One message means one start, one print bracket, one recorded message, one stop.
    for hook_name in ("start", "mark_print_started", "mark_print_finished"):
        getattr(fake_diagnostics, hook_name).assert_called_once_with()
    fake_diagnostics.record_message.assert_called_once()
    fake_diagnostics.stop.assert_called_once_with()