Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions cloud_pipelines_backend/launchers/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,16 @@ def get_refreshed(
) -> typing_extensions.Self:
raise NotImplementedError()

def transient_infra_failure_reason(self) -> str | None:
"""Short human-readable reason if this container is wedged by a transient
infrastructure fault that warrants relaunching the task, else None.

Default: None (no transient fault). Launchers override this to detect
launcher-specific transient faults (e.g. a wedged CSI sidecar). The
orchestrator relaunches the task in place when a reason is returned.
"""
return None

def get_log(self) -> str:
raise NotImplementedError()

Expand Down
70 changes: 70 additions & 0 deletions cloud_pipelines_backend/launchers/kubernetes_launchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@
_MAX_INPUT_VALUE_SIZE = 10000
_MAIN_CONTAINER_NAME = "main"

# Detection of pods wedged by a transient GKE gcsfuse-sidecar failure.
# The GKE-injected gcsfuse sidecar runs a bucket-access-check pre-flight that
# resolves Workload Identity via the metadata server. When that endpoint is
# degraded the call times out, the sidecar fatals (exit 255), the GCS volume
# never mounts, and the main container wedges in CreateContainerConfigError
# until the run-level timeout cancels it. We surface that signature so the
# orchestrator can relaunch the task in place (usually on a healthier node).
_GCSFUSE_SIDECAR_CONTAINER_NAME = "gke-gcsfuse-sidecar"
_GCSFUSE_SIDECAR_FATAL_EXIT_CODE = 255
_GCSFUSE_WEDGE_MESSAGE_SUBSTRING = "bucket access check"
_GCSFUSE_WEDGE_REASONS = frozenset({"ContainerStatusUnknown", "Error", "StartError"})


# Kubernetes annotation keys. (Has strict naming policy. Single slash only etc.)
_CLOUD_PIPELINES_KUBERNETES_ANNOTATION_KEY = "cloud-pipelines.net"
Expand Down Expand Up @@ -919,6 +931,64 @@ def get_refreshed(self) -> "LaunchedKubernetesContainer":
new_launched_container._debug_pod = pod
return new_launched_container

def _detect_wedged_gcsfuse_sidecar(self) -> bool:
"""Return True if the pod is wedged by a fatal gcsfuse-sidecar pre-flight.

Signature: the gke-gcsfuse-sidecar init container terminated with exit
255, a transient reason, and a bucket-access-check message, while the
main container has not started. See the module-level constants.
"""
pod_status: k8s_client_lib.V1PodStatus | None = self._debug_pod.status
if pod_status is None:
return False
# Never relaunch a pod whose main container has already started or ended;
# the wedge only matters while the main container is still blocked.
main_container_state = self._get_main_container_state()
if main_container_state is not None and (
main_container_state.running is not None
or main_container_state.terminated is not None
):
return False

init_container_statuses: list[k8s_client_lib.V1ContainerStatus] = (
pod_status.init_container_statuses or []
)
for container_status in init_container_statuses:
if container_status.name != _GCSFUSE_SIDECAR_CONTAINER_NAME:
continue
state: k8s_client_lib.V1ContainerState | None = container_status.state
last_state: k8s_client_lib.V1ContainerState | None = (
container_status.last_state
)
terminated_states = [
s.terminated
for s in (state, last_state)
if s is not None and s.terminated is not None
]
for terminated in terminated_states:
message = terminated.message or ""
if (
terminated.exit_code == _GCSFUSE_SIDECAR_FATAL_EXIT_CODE
and terminated.reason in _GCSFUSE_WEDGE_REASONS
and _GCSFUSE_WEDGE_MESSAGE_SUBSTRING in message.lower()
):
return True
return False

def transient_infra_failure_reason(self) -> str | None:
"""Reason string when the pod is wedged by a transient gcsfuse-sidecar
bucket-access-check failure, else None.

The orchestrator relaunches the task in place (same run/node) when this
returns a reason, so we do not mutate the pod here.
"""
if self._detect_wedged_gcsfuse_sidecar():
return (
"gke-gcsfuse-sidecar bucket-access-check timeout "
"(GKE metadata server unreachable)"
)
return None

def get_log(self) -> str:
launcher = self._get_launcher()
core_api_client = k8s_client_lib.CoreV1Api(api_client=launcher._api_client)
Expand Down
105 changes: 105 additions & 0 deletions cloud_pipelines_backend/orchestrator_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@
DYNAMIC_DATA_SECRET_KEY = "secret"
DYNAMIC_DATA_SECRET_NAME_KEY = "name"

# Per-node retry budget for transient infrastructure failures (e.g. a wedged
# gcsfuse sidecar). Each attempt is one fresh pod; the count is tracked on the
# ExecutionNode's extra_data so it survives across orchestrator polls/relaunches.
_MAX_TRANSIENT_INFRA_RETRIES = 2
_TRANSIENT_INFRA_RETRY_COUNT_KEY = "transient_infra_retry_count"


class OrchestratorError(RuntimeError):
pass
Expand Down Expand Up @@ -753,6 +759,20 @@ def internal_process_one_running_execution(
reloaded_launched_container: launcher_interfaces.LaunchedContainer = (
self._launcher.get_refreshed_launched_container_from_dict(launcher_data)
)
# Self-heal transient infra failures (e.g. a wedged gcsfuse sidecar) by
# relaunching the task in place: error out this attempt and re-queue the
# execution node so the normal launch path builds a fresh pod.
transient_infra_failure_reason = (
reloaded_launched_container.transient_infra_failure_reason()
)
if transient_infra_failure_reason is not None:
self._handle_transient_infra_failure(
session=session,
container_execution=container_execution,
launched_container=reloaded_launched_container,
reason=transient_infra_failure_reason,
)
return
current_time = _get_current_time()
# Saving the updated launcher data
reloaded_launcher_data = reloaded_launched_container.to_dict()
Expand Down Expand Up @@ -968,6 +988,91 @@ def _maybe_preload_value(
)
session.commit()

def _handle_transient_infra_failure(
self,
*,
session: orm.Session,
container_execution: bts.ContainerExecution,
launched_container: launcher_interfaces.LaunchedContainer,
reason: str,
):
"""Relaunch a task wedged by a transient infrastructure failure.

The current pod attempt is terminated and its ContainerExecution is
marked SYSTEM_ERROR (which both records the failed attempt and excludes
it from cache reuse). Each linked ExecutionNode is then re-queued so the
normal launch path builds a fresh pod, up to `_MAX_TRANSIENT_INFRA_RETRIES`
attempts; beyond that the node is failed (SYSTEM_ERROR) and its downstream
skipped so the run fails fast instead of hanging until its timeout.
"""
# Best-effort delete of the wedged pod; it may already be gone.
try:
launched_container.terminate()
except Exception:
_logger.exception(
f"Failed to terminate wedged container execution {container_execution.id}; "
"continuing with re-queue."
)

session.rollback()
with session.begin():
container_execution.status = bts.ContainerExecutionStatus.SYSTEM_ERROR
container_execution.ended_at = _get_current_time()

execution_nodes = container_execution.execution_nodes
for execution_node in execution_nodes:
if execution_node.extra_data is None:
execution_node.extra_data = {}
retry_count = int(
execution_node.extra_data.get(_TRANSIENT_INFRA_RETRY_COUNT_KEY, 0)
)

# Record the reason on the node either way, for observability.
_record_orchestration_error_message(
container_execution=container_execution,
execution_nodes=[execution_node],
message=(
f"Transient infrastructure failure: {reason}. "
f"Attempt {retry_count + 1}/{_MAX_TRANSIENT_INFRA_RETRIES + 1}."
),
)

if retry_count < _MAX_TRANSIENT_INFRA_RETRIES:
execution_node.extra_data[_TRANSIENT_INFRA_RETRY_COUNT_KEY] = (
retry_count + 1
)
# Re-queue: the queued processor will launch a fresh pod.
# The wedged (now SYSTEM_ERROR) execution is not a cache
# reuse candidate, so a new ContainerExecution is created.
execution_node.container_execution_status = (
bts.ContainerExecutionStatus.QUEUED
)
_logger.info(
f"Re-queuing execution {execution_node.id} after transient "
f"infra failure ({reason}); retry "
f"{retry_count + 1}/{_MAX_TRANSIENT_INFRA_RETRIES}."
)
else:
execution_node.container_execution_status = (
bts.ContainerExecutionStatus.SYSTEM_ERROR
)
record_system_error_exception(
execution=execution_node,
exception=OrchestratorError(
f"Transient infrastructure failure persisted after "
f"{_MAX_TRANSIENT_INFRA_RETRIES} retries: {reason}"
),
)
_mark_all_downstream_executions_as_skipped(
session=session, execution=execution_node
)
_logger.warning(
f"Execution {execution_node.id} still hitting transient infra "
f"failure ({reason}) after {retry_count} retries "
f"(max {_MAX_TRANSIENT_INFRA_RETRIES}); failing it and skipping "
"downstream."
)


def _get_direct_downstream_executions(
session: orm.Session, execution: bts.ExecutionNode
Expand Down
165 changes: 165 additions & 0 deletions tests/test_kubernetes_launchers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
"""Tests for cloud_pipelines_backend.launchers.kubernetes_launchers.

Focused on `LaunchedKubernetesContainer.transient_infra_failure_reason`, which
detects pods wedged by a transient GKE gcsfuse-sidecar bucket-access-check
failure so the orchestrator can relaunch the task in place. Detection is pure
(reads the cached pod status), so these tests run offline with no API client.
"""

from __future__ import annotations

from kubernetes import client as k8s_client_lib

from cloud_pipelines_backend.launchers import kubernetes_launchers
from cloud_pipelines_backend.launchers.kubernetes_launchers import (
LaunchedKubernetesContainer,
)


def _make_main_container_status(
*, running: bool = False, terminated: bool = False
) -> k8s_client_lib.V1ContainerStatus:
state = k8s_client_lib.V1ContainerState()
if running:
state.running = k8s_client_lib.V1ContainerStateRunning()
elif terminated:
state.terminated = k8s_client_lib.V1ContainerStateTerminated(exit_code=0)
else:
state.waiting = k8s_client_lib.V1ContainerStateWaiting(
reason="CreateContainerConfigError"
)
return k8s_client_lib.V1ContainerStatus(
name="main", ready=False, restart_count=0, image="img", image_id="", state=state
)


def _make_sidecar_status(
*,
exit_code: int = 255,
reason: str = "Error",
message: str = "Bucket access check failed for stg-oasis-tmp",
in_last_state: bool = True,
) -> k8s_client_lib.V1ContainerStatus:
terminated = k8s_client_lib.V1ContainerStateTerminated(
exit_code=exit_code, reason=reason, message=message
)
state = k8s_client_lib.V1ContainerState()
last_state = k8s_client_lib.V1ContainerState()
if in_last_state:
last_state.terminated = terminated
else:
state.terminated = terminated
return k8s_client_lib.V1ContainerStatus(
name=kubernetes_launchers._GCSFUSE_SIDECAR_CONTAINER_NAME,
ready=False,
restart_count=1,
image="gcsfuse",
image_id="",
state=state,
last_state=last_state,
)


def _make_pod(
*,
sidecar_status: k8s_client_lib.V1ContainerStatus | None,
main_status: k8s_client_lib.V1ContainerStatus | None = None,
) -> k8s_client_lib.V1Pod:
if main_status is None:
main_status = _make_main_container_status()
init_statuses = [sidecar_status] if sidecar_status is not None else []
return k8s_client_lib.V1Pod(
api_version="v1",
kind="Pod",
metadata=k8s_client_lib.V1ObjectMeta(
name="task-abc-orig",
generate_name="task-abc-",
namespace="kueue-jobs-staging",
annotations={"gke-gcsfuse/volumes": "true"},
),
spec=k8s_client_lib.V1PodSpec(
restart_policy="Never",
containers=[k8s_client_lib.V1Container(name="main")],
init_containers=[
k8s_client_lib.V1Container(
name=kubernetes_launchers._GCSFUSE_SIDECAR_CONTAINER_NAME
)
],
),
status=k8s_client_lib.V1PodStatus(
phase="Pending",
container_statuses=[main_status],
init_container_statuses=init_statuses,
),
)


def _make_container(pod: k8s_client_lib.V1Pod) -> LaunchedKubernetesContainer:
return LaunchedKubernetesContainer(
pod_name=pod.metadata.name,
namespace=pod.metadata.namespace,
output_uris={"out": "gs://bucket/out"},
log_uri="gs://bucket/log",
debug_pod=pod,
launcher=None,
)


def test_wedged_sidecar_returns_reason():
pod = _make_pod(sidecar_status=_make_sidecar_status())
container = _make_container(pod)

reason = container.transient_infra_failure_reason()

assert reason is not None
assert "gke-gcsfuse-sidecar" in reason


def test_wedge_in_current_state_is_detected():
pod = _make_pod(sidecar_status=_make_sidecar_status(in_last_state=False))
container = _make_container(pod)
assert container.transient_infra_failure_reason() is not None


def test_running_main_container_is_not_a_wedge():
pod = _make_pod(
sidecar_status=_make_sidecar_status(),
main_status=_make_main_container_status(running=True),
)
container = _make_container(pod)
assert container.transient_infra_failure_reason() is None


def test_terminated_main_container_is_not_a_wedge():
pod = _make_pod(
sidecar_status=_make_sidecar_status(),
main_status=_make_main_container_status(terminated=True),
)
container = _make_container(pod)
assert container.transient_infra_failure_reason() is None


def test_sidecar_clean_exit_is_not_a_wedge():
pod = _make_pod(
sidecar_status=_make_sidecar_status(exit_code=0, reason="Completed")
)
container = _make_container(pod)
assert container.transient_infra_failure_reason() is None


def test_unrelated_failure_message_is_not_a_wedge():
pod = _make_pod(sidecar_status=_make_sidecar_status(message="some other failure"))
container = _make_container(pod)
assert container.transient_infra_failure_reason() is None


def test_no_sidecar_is_not_a_wedge():
pod = _make_pod(sidecar_status=None)
container = _make_container(pod)
assert container.transient_infra_failure_reason() is None


def test_default_interface_reason_is_none():
# The base interface default reports no transient failure.
sentinel = kubernetes_launchers.interfaces.LaunchedContainer()
assert sentinel.transient_infra_failure_reason() is None
Loading
Loading