diff --git a/cloud_pipelines_backend/launchers/interfaces.py b/cloud_pipelines_backend/launchers/interfaces.py index 2c1309c8..848c459b 100644 --- a/cloud_pipelines_backend/launchers/interfaces.py +++ b/cloud_pipelines_backend/launchers/interfaces.py @@ -125,6 +125,16 @@ def get_refreshed( ) -> typing_extensions.Self: raise NotImplementedError() + def transient_infra_failure_reason(self) -> str | None: + """Short human-readable reason if this container is wedged by a transient + infrastructure fault that warrants relaunching the task, else None. + + Default: None (no transient fault). Launchers override this to detect + launcher-specific transient faults (e.g. a wedged CSI sidecar). The + orchestrator relaunches the task in place when a reason is returned. + """ + return None + def get_log(self) -> str: raise NotImplementedError() diff --git a/cloud_pipelines_backend/launchers/kubernetes_launchers.py b/cloud_pipelines_backend/launchers/kubernetes_launchers.py index 25cd267b..d29c76a1 100644 --- a/cloud_pipelines_backend/launchers/kubernetes_launchers.py +++ b/cloud_pipelines_backend/launchers/kubernetes_launchers.py @@ -31,6 +31,18 @@ _MAX_INPUT_VALUE_SIZE = 10000 _MAIN_CONTAINER_NAME = "main" +# Detection of pods wedged by a transient GKE gcsfuse-sidecar failure. +# The GKE-injected gcsfuse sidecar runs a bucket-access-check pre-flight that +# resolves Workload Identity via the metadata server. When that endpoint is +# degraded the call times out, the sidecar fatals (exit 255), the GCS volume +# never mounts, and the main container wedges in CreateContainerConfigError +# until the run-level timeout cancels it. We surface that signature so the +# orchestrator can relaunch the task in place (usually on a healthier node). 
+_GCSFUSE_SIDECAR_CONTAINER_NAME = "gke-gcsfuse-sidecar" +_GCSFUSE_SIDECAR_FATAL_EXIT_CODE = 255 +_GCSFUSE_WEDGE_MESSAGE_SUBSTRING = "bucket access check" +_GCSFUSE_WEDGE_REASONS = frozenset({"ContainerStatusUnknown", "Error", "StartError"}) + # Kubernetes annotation keys. (Has strict naming policy. Single slash only etc.) _CLOUD_PIPELINES_KUBERNETES_ANNOTATION_KEY = "cloud-pipelines.net" @@ -919,6 +931,64 @@ def get_refreshed(self) -> "LaunchedKubernetesContainer": new_launched_container._debug_pod = pod return new_launched_container + def _detect_wedged_gcsfuse_sidecar(self) -> bool: + """Return True if the pod is wedged by a fatal gcsfuse-sidecar pre-flight. + + Signature: the gke-gcsfuse-sidecar init container terminated with exit + 255, a transient reason, and a bucket-access-check message, while the + main container has not started. See the module-level constants. + """ + pod_status: k8s_client_lib.V1PodStatus | None = self._debug_pod.status + if pod_status is None: + return False + # Never relaunch a pod whose main container has already started or ended; + # the wedge only matters while the main container is still blocked. 
+ main_container_state = self._get_main_container_state() + if main_container_state is not None and ( + main_container_state.running is not None + or main_container_state.terminated is not None + ): + return False + + init_container_statuses: list[k8s_client_lib.V1ContainerStatus] = ( + pod_status.init_container_statuses or [] + ) + for container_status in init_container_statuses: + if container_status.name != _GCSFUSE_SIDECAR_CONTAINER_NAME: + continue + state: k8s_client_lib.V1ContainerState | None = container_status.state + last_state: k8s_client_lib.V1ContainerState | None = ( + container_status.last_state + ) + terminated_states = [ + s.terminated + for s in (state, last_state) + if s is not None and s.terminated is not None + ] + for terminated in terminated_states: + message = terminated.message or "" + if ( + terminated.exit_code == _GCSFUSE_SIDECAR_FATAL_EXIT_CODE + and terminated.reason in _GCSFUSE_WEDGE_REASONS + and _GCSFUSE_WEDGE_MESSAGE_SUBSTRING in message.lower() + ): + return True + return False + + def transient_infra_failure_reason(self) -> str | None: + """Reason string when the pod is wedged by a transient gcsfuse-sidecar + bucket-access-check failure, else None. + + The orchestrator relaunches the task in place (same run/node) when this + returns a reason, so we do not mutate the pod here. 
+ """ + if self._detect_wedged_gcsfuse_sidecar(): + return ( + "gke-gcsfuse-sidecar bucket-access-check timeout " + "(GKE metadata server unreachable)" + ) + return None + def get_log(self) -> str: launcher = self._get_launcher() core_api_client = k8s_client_lib.CoreV1Api(api_client=launcher._api_client) diff --git a/cloud_pipelines_backend/orchestrator_sql.py b/cloud_pipelines_backend/orchestrator_sql.py index 1a81f456..324c22ea 100644 --- a/cloud_pipelines_backend/orchestrator_sql.py +++ b/cloud_pipelines_backend/orchestrator_sql.py @@ -32,6 +32,12 @@ DYNAMIC_DATA_SECRET_KEY = "secret" DYNAMIC_DATA_SECRET_NAME_KEY = "name" +# Per-node retry budget for transient infrastructure failures (e.g. a wedged +# gcsfuse sidecar). Each attempt is one fresh pod; the count is tracked on the +# ExecutionNode's extra_data so it survives across orchestrator polls/relaunches. +_MAX_TRANSIENT_INFRA_RETRIES = 2 +_TRANSIENT_INFRA_RETRY_COUNT_KEY = "transient_infra_retry_count" + class OrchestratorError(RuntimeError): pass @@ -753,6 +759,20 @@ def internal_process_one_running_execution( reloaded_launched_container: launcher_interfaces.LaunchedContainer = ( self._launcher.get_refreshed_launched_container_from_dict(launcher_data) ) + # Self-heal transient infra failures (e.g. a wedged gcsfuse sidecar) by + # relaunching the task in place: error out this attempt and re-queue the + # execution node so the normal launch path builds a fresh pod. 
def _handle_transient_infra_failure(
    self,
    *,
    session: orm.Session,
    container_execution: bts.ContainerExecution,
    launched_container: launcher_interfaces.LaunchedContainer,
    reason: str,
):
    """Give up on a wedged attempt and re-queue (or fail) its execution nodes.

    Steps, in order:

    1. Ask the launcher to terminate the container. A failure here is logged
       and otherwise ignored.
    2. Roll back the session and open a new transaction.
    3. Mark ``container_execution`` as SYSTEM_ERROR and stamp ``ended_at``.
    4. For every linked execution node, record the failure message, then
       either bump the per-node retry counter and set the node back to
       QUEUED, or, once ``_MAX_TRANSIENT_INFRA_RETRIES`` re-queues have
       already been used, mark the node SYSTEM_ERROR, record the error on it
       and mark its downstream executions as skipped.

    Args:
        session: ORM session used for the status updates.
        container_execution: The attempt whose container is wedged.
        launched_container: Handle used to terminate the wedged container.
        reason: Human-readable description of the transient fault; it is
            embedded in the recorded messages and log lines.
    """
    # Best effort only: the container may already be gone.
    try:
        launched_container.terminate()
    except Exception:
        _logger.exception(
            f"Failed to terminate wedged container execution {container_execution.id}; "
            "continuing with re-queue."
        )

    session.rollback()
    with session.begin():
        container_execution.status = bts.ContainerExecutionStatus.SYSTEM_ERROR
        container_execution.ended_at = _get_current_time()

        for execution_node in container_execution.execution_nodes:
            if execution_node.extra_data is None:
                execution_node.extra_data = {}
            # The counter lives on the node, not on the attempt, so it is
            # read back here on each later failure of the same node.
            retry_count = int(
                execution_node.extra_data.get(_TRANSIENT_INFRA_RETRY_COUNT_KEY, 0)
            )

            # The reason is recorded whether or not the node is re-queued.
            attempt_message = (
                f"Transient infrastructure failure: {reason}. "
                f"Attempt {retry_count + 1}/{_MAX_TRANSIENT_INFRA_RETRIES + 1}."
            )
            _record_orchestration_error_message(
                container_execution=container_execution,
                execution_nodes=[execution_node],
                message=attempt_message,
            )

            if retry_count >= _MAX_TRANSIENT_INFRA_RETRIES:
                # Budget exhausted: fail the node and skip what depends on it.
                execution_node.container_execution_status = (
                    bts.ContainerExecutionStatus.SYSTEM_ERROR
                )
                record_system_error_exception(
                    execution=execution_node,
                    exception=OrchestratorError(
                        f"Transient infrastructure failure persisted after "
                        f"{_MAX_TRANSIENT_INFRA_RETRIES} retries: {reason}"
                    ),
                )
                _mark_all_downstream_executions_as_skipped(
                    session=session, execution=execution_node
                )
                _logger.warning(
                    f"Execution {execution_node.id} still hitting transient infra "
                    f"failure ({reason}) after {retry_count} retries "
                    f"(max {_MAX_TRANSIENT_INFRA_RETRIES}); failing it and skipping "
                    "downstream."
                )
                continue

            # NOTE(review): this mutates the extra_data value in place;
            # presumably the column type tracks such mutations -- confirm.
            execution_node.extra_data[_TRANSIENT_INFRA_RETRY_COUNT_KEY] = (
                retry_count + 1
            )
            # Back to QUEUED so the node is picked up for launching again.
            execution_node.container_execution_status = (
                bts.ContainerExecutionStatus.QUEUED
            )
            _logger.info(
                f"Re-queuing execution {execution_node.id} after transient "
                f"infra failure ({reason}); retry "
                f"{retry_count + 1}/{_MAX_TRANSIENT_INFRA_RETRIES}."
            )
def _make_pod(
    *,
    sidecar_status: k8s_client_lib.V1ContainerStatus | None,
    main_status: k8s_client_lib.V1ContainerStatus | None = None,
) -> k8s_client_lib.V1Pod:
    """Build a Pending pod with a ``main`` container and a gcsfuse sidecar.

    Args:
        sidecar_status: Status reported for the sidecar init container. With
            ``None`` the pod reports no init-container statuses at all.
        main_status: Status reported for the main container. Defaults to the
            status produced by ``_make_main_container_status()``.
    """
    if main_status is None:
        main_status = _make_main_container_status()
    init_statuses = [] if sidecar_status is None else [sidecar_status]

    metadata = k8s_client_lib.V1ObjectMeta(
        name="task-abc-orig",
        generate_name="task-abc-",
        namespace="kueue-jobs-staging",
        annotations={"gke-gcsfuse/volumes": "true"},
    )
    sidecar_container = k8s_client_lib.V1Container(
        name=kubernetes_launchers._GCSFUSE_SIDECAR_CONTAINER_NAME
    )
    spec = k8s_client_lib.V1PodSpec(
        restart_policy="Never",
        containers=[k8s_client_lib.V1Container(name="main")],
        init_containers=[sidecar_container],
    )
    status = k8s_client_lib.V1PodStatus(
        phase="Pending",
        container_statuses=[main_status],
        init_container_statuses=init_statuses,
    )
    return k8s_client_lib.V1Pod(
        api_version="v1",
        kind="Pod",
        metadata=metadata,
        spec=spec,
        status=status,
    )


def _make_container(pod: k8s_client_lib.V1Pod) -> LaunchedKubernetesContainer:
    """Wrap ``pod`` in a launched-container object that has no launcher."""
    pod_metadata = pod.metadata
    return LaunchedKubernetesContainer(
        pod_name=pod_metadata.name,
        namespace=pod_metadata.namespace,
        output_uris={"out": "gs://bucket/out"},
        log_uri="gs://bucket/log",
        debug_pod=pod,
        launcher=None,
    )
def test_running_main_container_is_not_a_wedge():
    # A matching sidecar failure is ignored once the main container runs.
    running_main = _make_main_container_status(running=True)
    pod = _make_pod(sidecar_status=_make_sidecar_status(), main_status=running_main)
    assert _make_container(pod).transient_infra_failure_reason() is None


def test_terminated_main_container_is_not_a_wedge():
    # Likewise when the main container has already finished.
    finished_main = _make_main_container_status(terminated=True)
    pod = _make_pod(sidecar_status=_make_sidecar_status(), main_status=finished_main)
    assert _make_container(pod).transient_infra_failure_reason() is None


def test_sidecar_clean_exit_is_not_a_wedge():
    # Exit code 0 with reason "Completed" does not match the signature.
    clean_sidecar = _make_sidecar_status(exit_code=0, reason="Completed")
    pod = _make_pod(sidecar_status=clean_sidecar)
    assert _make_container(pod).transient_infra_failure_reason() is None


def test_unrelated_failure_message_is_not_a_wedge():
    # Exit code and reason match, but the message does not.
    other_failure = _make_sidecar_status(message="some other failure")
    pod = _make_pod(sidecar_status=other_failure)
    assert _make_container(pod).transient_infra_failure_reason() is None


def test_no_sidecar_is_not_a_wedge():
    # No init-container statuses at all.
    pod = _make_pod(sidecar_status=None)
    assert _make_container(pod).transient_infra_failure_reason() is None


def test_default_interface_reason_is_none():
    # The base interface default reports no transient failure.
    base_container = kubernetes_launchers.interfaces.LaunchedContainer()
    assert base_container.transient_infra_failure_reason() is None
class _FakeLaunchedContainer:
    """Stand-in launched container whose status is always PENDING.

    It reports the transient-failure reason it was built with (``None`` by
    default) and counts how many times ``terminate`` was called.
    """

    def __init__(self, *, reason: str | None = None):
        self.terminate_calls = 0
        self._reason = reason

    @property
    def status(self) -> launcher_interfaces.ContainerStatus:
        pending = launcher_interfaces.ContainerStatus.PENDING
        return pending

    def to_dict(self) -> dict:
        serialized = {"fake_launcher_data": True}
        return serialized

    def transient_infra_failure_reason(self) -> str | None:
        return self._reason

    def terminate(self) -> None:
        # Only counted; there is no real container to stop.
        self.terminate_calls = self.terminate_calls + 1

    def upload_log(self) -> None:
        return None


class _FakeLauncher:
    """Stand-in launcher that hands out ``_FakeLaunchedContainer`` objects.

    Launched and deserialized containers never report a fault. Refreshed
    containers report whatever ``self.reason`` holds at refresh time and are
    kept in ``refreshed_containers`` so tests can inspect them afterwards.
    """

    def __init__(self):
        self.launch_count = 0
        self.refreshed_containers: list[_FakeLaunchedContainer] = []
        self.reason: str | None = None

    def launch_container_task(self, **kwargs) -> _FakeLaunchedContainer:
        self.launch_count = self.launch_count + 1
        return _FakeLaunchedContainer()

    def deserialize_launched_container_from_dict(self, d) -> _FakeLaunchedContainer:
        return _FakeLaunchedContainer()

    def get_refreshed_launched_container_from_dict(self, d) -> _FakeLaunchedContainer:
        refreshed = _FakeLaunchedContainer(reason=self.reason)
        self.refreshed_containers.append(refreshed)
        return refreshed
def test_transient_infra_failure_requeues_node():
    make_session = _initialize_db_and_get_session_factory()
    fake_launcher = _FakeLauncher()
    fake_launcher.reason = "gke-gcsfuse-sidecar bucket-access-check timeout"
    service = _make_orchestrator(make_session, fake_launcher)
    _create_run(make_session)

    # A single sweep both launches the task and then sees it as wedged.
    service.process_each_queue_once()

    # The wedged attempt ends as SYSTEM_ERROR with an end timestamp.
    attempts = _all_container_executions(make_session)
    assert len(attempts) == 1
    wedged_attempt = attempts[0]
    assert wedged_attempt.status == bts.ContainerExecutionStatus.SYSTEM_ERROR
    assert wedged_attempt.ended_at is not None

    # Its container was asked to terminate exactly once.
    last_refreshed = fake_launcher.refreshed_containers[-1]
    assert last_refreshed.terminate_calls == 1

    # The node is queued again, with one retry counted and the reason stored.
    node = _the_execution_node(make_session)
    assert node.container_execution_status == bts.ContainerExecutionStatus.QUEUED
    assert node.extra_data[orchestrator_sql._TRANSIENT_INFRA_RETRY_COUNT_KEY] == 1
    recorded_message = node.extra_data[
        bts.EXECUTION_NODE_EXTRA_DATA_ORCHESTRATION_ERROR_MESSAGE_KEY
    ]
    assert fake_launcher.reason in recorded_message


def test_transient_infra_failure_relaunches_until_cap():
    make_session = _initialize_db_and_get_session_factory()
    fake_launcher = _FakeLauncher()
    fake_launcher.reason = "gke-gcsfuse-sidecar bucket-access-check timeout"
    service = _make_orchestrator(make_session, fake_launcher)
    _create_run(make_session)

    # One sweep per attempt: the initial launch plus every allowed retry.
    # Each fresh container wedges again, so the last sweep fails the node.
    max_retries = orchestrator_sql._MAX_TRANSIENT_INFRA_RETRIES
    expected_attempts = max_retries + 1
    for _ in range(expected_attempts):
        service.process_each_queue_once()

    assert fake_launcher.launch_count == expected_attempts

    # Each attempt has its own ContainerExecution, all ended as SYSTEM_ERROR.
    attempts = _all_container_executions(make_session)
    assert len(attempts) == expected_attempts
    attempt_statuses = [attempt.status for attempt in attempts]
    assert all(
        status == bts.ContainerExecutionStatus.SYSTEM_ERROR
        for status in attempt_statuses
    )

    # The node is failed and its retry counter stops at the cap.
    node = _the_execution_node(make_session)
    assert node.container_execution_status == bts.ContainerExecutionStatus.SYSTEM_ERROR
    retries_used = node.extra_data[orchestrator_sql._TRANSIENT_INFRA_RETRY_COUNT_KEY]
    assert retries_used == max_retries