Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 160 additions & 6 deletions airbyte_cdk/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@
import ipaddress
import json
import logging
import os
import os.path
import socket
import sys
import tempfile
import threading
import time
import traceback
from collections import defaultdict
from functools import wraps
from typing import Any, DefaultDict, Iterable, List, Mapping, Optional
from typing import Any, Callable, DefaultDict, Iterable, List, Mapping, Optional
from urllib.parse import urlparse

import orjson
Expand Down Expand Up @@ -51,6 +55,144 @@
_HAS_LOGGED_FOR_SERIALIZATION_ERROR = False


class _DeadlockDiagnostics:
def __init__(
self,
*,
interval_seconds: float = 30.0,
stall_interval_count: int = 3,
dump_repeat_interval_count: int = 10,
time_fn: Callable[[], float] = time.monotonic,
write_fn: Callable[[bytes], None] | None = None,
) -> None:
self._interval_seconds = interval_seconds
self._stall_interval_count = stall_interval_count
self._dump_repeat_interval_count = dump_repeat_interval_count
self._time_fn = time_fn
self._write_fn = write_fn or self._write_to_stderr
self._stop = threading.Event()
self._lock = threading.Lock()
self._thread: threading.Thread | None = None
self._start_time = self._time_fn()
self._messages_written = 0
self._bytes_written = 0
self._print_blocked = False
self._print_blocked_since = 0.0
self._last_messages_written = 0
self._stall_count = 0

def start(self) -> None:
self._thread = threading.Thread(
target=self._run,
name="airbyte-deadlock-diagnostics",
daemon=True,
)
self._thread.start()

def stop(self) -> None:
self._stop.set()

def mark_print_started(self) -> None:
with self._lock:
self._print_blocked = True
self._print_blocked_since = self._time_fn()

def mark_print_finished(self) -> None:
with self._lock:
self._print_blocked = False

def record_message(self, data: str) -> None:
with self._lock:
self._messages_written += 1
self._bytes_written += len(data.encode())

def emit_heartbeat(self) -> None:
now = self._time_fn()
with self._lock:
messages_written = self._messages_written
bytes_written = self._bytes_written
print_blocked = self._print_blocked
print_blocked_since = self._print_blocked_since

if messages_written == self._last_messages_written and messages_written > 0:
self._stall_count += 1
else:
self._stall_count = 0
self._last_messages_written = messages_written
stall_count = self._stall_count

line = self._heartbeat_line(
now,
messages_written,
bytes_written,
print_blocked,
print_blocked_since,
)
if self._should_dump_threads(stall_count):
line += self._thread_dump()

try:
self._write_fn(line.encode())
except OSError:
return

def _run(self) -> None:
while not self._stop.wait(timeout=self._interval_seconds):
self.emit_heartbeat()

def _heartbeat_line(
self,
now: float,
messages_written: int,
bytes_written: int,
print_blocked: bool,
print_blocked_since: float,
) -> str:
blocked = "YES" if print_blocked else "NO"
blocked_duration = (
f" blocked_since={now - print_blocked_since:.0f}s" if print_blocked else ""
)
return (
f"STDOUT_HEARTBEAT: t={now - self._start_time:.0f}s "
f"msgs={messages_written} bytes={bytes_written} "
f"print_blocked={blocked}{blocked_duration}"
f"{self._queue_stats()}\n"
)

def _queue_stats(self) -> str:
from airbyte_cdk.sources.concurrent_source.queue_registry import get_queue

queue = get_queue()
if queue is None:
return ""

try:
return f" queue_size={queue.qsize()} queue_full={queue.full()}"
except NotImplementedError:
return ""

def _should_dump_threads(self, stall_count: int) -> bool:
if stall_count == self._stall_interval_count:
return True
return (
stall_count > self._stall_interval_count
and (stall_count - self._stall_interval_count) % self._dump_repeat_interval_count == 0
)

def _thread_dump(self) -> str:
thread_names = {thread.ident: thread.name for thread in threading.enumerate()}
lines = ["=== THREAD DUMP (stall detected) ===\n"]
for thread_id, frame in sys._current_frames().items():
lines.append(f"\nThread {thread_names.get(thread_id, 'unknown')} ({thread_id}):\n")
lines.extend(traceback.format_stack(frame))
lines.append("=== END THREAD DUMP ===\n")
return "".join(lines)

@staticmethod
def _write_to_stderr(data: bytes) -> None:
os.write(2, data)


class AirbyteEntrypoint(object):
def __init__(self, source: Source):
init_uncaught_exception_handler(logger)
Expand Down Expand Up @@ -391,13 +533,25 @@ def _emit_queued_messages(self, source: Source) -> Iterable[AirbyteMessage]:
def launch(source: Source, args: List[str]) -> None:
    """Run the connector for ``args`` and print every produced message to stdout.

    A ``_DeadlockDiagnostics`` heartbeat is started before the run and stopped
    when the run ends, whether it finishes normally or raises. Each print is
    bracketed with ``mark_print_started`` / ``mark_print_finished`` so the
    heartbeat can report a print that is still in flight, and each message that
    printed successfully is counted with ``record_message``.
    """
    entrypoint = AirbyteEntrypoint(source)
    parsed_args = entrypoint.parse_args(args)
    diagnostics = _DeadlockDiagnostics()
    diagnostics.start()

    def _print_tracked(data: str) -> None:
        # The "finished" mark is set even when print raises; the message is only
        # counted when print returned normally.
        diagnostics.mark_print_started()
        try:
            print(data, end="")
        finally:
            diagnostics.mark_print_finished()
        diagnostics.record_message(data)

    # Refer to: https://github.com/airbytehq/oncall/issues/6235
    # NOTE(review): the earlier comment here said the PrintBuffer was temporarily
    # removed, yet the output is still wrapped in PRINT_BUFFER - confirm which is intended.
    try:
        with PRINT_BUFFER:
            for message in entrypoint.run(parsed_args):
                # The newline is part of the printed string (with end="") so that
                # message and line break go out in a single print; printing them
                # separately caused issues for the concurrent CDK.
                _print_tracked(f"{message}\n")
    finally:
        diagnostics.stop()


def _init_internal_request_filter() -> None:
Expand Down
49 changes: 27 additions & 22 deletions airbyte_cdk/sources/concurrent_source/concurrent_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import (
PartitionGenerationCompletedSentinel,
)
from airbyte_cdk.sources.concurrent_source.queue_registry import register_queue, unregister_queue
from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException
from airbyte_cdk.sources.concurrent_source.thread_pool_manager import ThreadPoolManager
from airbyte_cdk.sources.message import InMemoryMessageRepository, MessageRepository
Expand Down Expand Up @@ -110,30 +111,34 @@ def read(
streams: List[AbstractStream],
) -> Iterator[AirbyteMessage]:
self._logger.info("Starting syncing")
concurrent_stream_processor = ConcurrentReadProcessor(
streams,
PartitionEnqueuer(self._queue, self._threadpool),
self._threadpool,
self._logger,
self._slice_logger,
self._message_repository,
PartitionReader(
self._queue,
PartitionLogger(self._slice_logger, self._logger, self._message_repository),
),
max_concurrent_partition_generators=self._initial_number_partitions_to_generate,
)
register_queue(self._queue)
try:
concurrent_stream_processor = ConcurrentReadProcessor(
streams,
PartitionEnqueuer(self._queue, self._threadpool),
self._threadpool,
self._logger,
self._slice_logger,
self._message_repository,
PartitionReader(
self._queue,
PartitionLogger(self._slice_logger, self._logger, self._message_repository),
),
max_concurrent_partition_generators=self._initial_number_partitions_to_generate,
)

# Enqueue initial partition generation tasks
yield from self._submit_initial_partition_generators(concurrent_stream_processor)
# Enqueue initial partition generation tasks
yield from self._submit_initial_partition_generators(concurrent_stream_processor)

# Read from the queue until all partitions were generated and read
yield from self._consume_from_queue(
self._queue,
concurrent_stream_processor,
)
self._threadpool.check_for_errors_and_shutdown()
self._logger.info("Finished syncing")
# Read from the queue until all partitions were generated and read
yield from self._consume_from_queue(
self._queue,
concurrent_stream_processor,
)
self._threadpool.check_for_errors_and_shutdown()
self._logger.info("Finished syncing")
finally:
unregister_queue()

def _submit_initial_partition_generators(
self, concurrent_stream_processor: ConcurrentReadProcessor
Expand Down
24 changes: 24 additions & 0 deletions airbyte_cdk/sources/concurrent_source/queue_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#
# Copyright (c) 2026 Airbyte, Inc., all rights reserved.
#

from queue import Queue
from typing import Optional

from airbyte_cdk.sources.streams.concurrent.partitions.types import QueueItem

# Process-wide slot holding the queue that diagnostics may inspect; None when
# no queue is registered.
_queue: Optional[Queue[QueueItem]] = None


def register_queue(queue: Queue[QueueItem]) -> None:
    """Make ``queue`` the registered queue, replacing any previous registration."""
    global _queue
    _queue = queue


def get_queue() -> Optional[Queue[QueueItem]]:
    """Return the currently registered queue, or ``None`` if there is none."""
    return _queue


def unregister_queue(queue: Optional[Queue[QueueItem]] = None) -> None:
    """Clear the registered queue.

    Args:
        queue: When given, the slot is cleared only if it still holds this exact
            queue object. This lets an owner release its own registration
            without erasing one made later by someone else (for example when the
            cleanup runs late, from a generator's ``finally`` block). When
            omitted, the slot is cleared unconditionally.
    """
    global _queue
    if queue is None or _queue is queue:
        _queue = None
73 changes: 73 additions & 0 deletions unit_tests/test_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from argparse import Namespace
from collections import defaultdict
from copy import deepcopy
from queue import Queue
from typing import Any, List, Mapping, MutableMapping, Union
from unittest import mock
from unittest.mock import MagicMock, patch
Expand Down Expand Up @@ -944,3 +945,75 @@ def test_memory_failfast_flushes_queued_state_before_raising(mocker):
with pytest.raises(AirbyteTracedException) as exc_info:
next(gen)
assert exc_info.value is fail_fast_exc


def test_deadlock_diagnostics_heartbeat_includes_stall_thread_dump():
    """After one message and four heartbeats, only the fourth must carry a dump marker pair.

    The fake clock yields 100s for construction and then 130s, 160s, ... for the
    heartbeats, so the first heartbeat reports ``t=30s``.
    """
    written: list[bytes] = []
    clock = iter([100.0, 130.0, 160.0, 190.0, 220.0])
    diagnostics = entrypoint_module._DeadlockDiagnostics(
        interval_seconds=30.0,
        stall_interval_count=3,
        time_fn=lambda: next(clock),
        write_fn=written.append,
    )

    diagnostics.record_message("record\n")
    # First heartbeat sees progress; the next three see none.
    for _ in range(4):
        diagnostics.emit_heartbeat()

    decoded = [chunk.decode() for chunk in written]
    first_heartbeat, _, _, stalled_heartbeat = decoded
    assert "STDOUT_HEARTBEAT: t=30s msgs=1 bytes=7 print_blocked=NO" in first_heartbeat
    assert "=== THREAD DUMP (stall detected) ===" in stalled_heartbeat
    assert "=== END THREAD DUMP ===" in stalled_heartbeat


def test_deadlock_diagnostics_heartbeat_reports_queue_stats():
    """A registered, full, one-slot queue must show up in the heartbeat line."""
    from airbyte_cdk.sources.concurrent_source.queue_registry import (
        register_queue,
        unregister_queue,
    )

    written: list[bytes] = []
    full_queue: Queue = Queue(maxsize=1)
    full_queue.put("record")

    register_queue(full_queue)
    try:
        entrypoint_module._DeadlockDiagnostics(
            interval_seconds=30.0,
            time_fn=lambda: 30.0,
            write_fn=written.append,
        ).emit_heartbeat()
    finally:
        # Always clear the registry so later tests start without a queue.
        unregister_queue()

    heartbeat = written[0].decode()
    assert "queue_size=1 queue_full=True" in heartbeat


def test_launch_starts_deadlock_diagnostics(mocker):
    """``launch`` must drive the full diagnostics lifecycle for a single message."""
    diagnostics = MagicMock()
    spec_message = AirbyteMessage(
        type=Type.SPEC, spec=ConnectorSpecification(connectionSpecification={})
    )
    mocker.patch.object(entrypoint_module, "_DeadlockDiagnostics", return_value=diagnostics)
    mocker.patch.object(AirbyteEntrypoint, "parse_args", return_value=Namespace(command="spec"))
    mocker.patch.object(AirbyteEntrypoint, "run", return_value=iter([spec_message]))

    entrypoint_module.launch(MockSource(), ["spec"])

    # One message was produced, so each per-message hook fires exactly once.
    diagnostics.start.assert_called_once_with()
    diagnostics.mark_print_started.assert_called_once_with()
    diagnostics.mark_print_finished.assert_called_once_with()
    diagnostics.record_message.assert_called_once()
    diagnostics.stop.assert_called_once_with()
Loading