From cf59b7cf178718a280782f5087fad58ccd1fbdbc Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Mon, 13 Apr 2026 14:35:26 +0100 Subject: [PATCH] [v3-2-test] Add fast-path heartbeat UPDATE to avoid row lock in the common case (#65029) * Add fast-path heartbeat UPDATE to avoid row lock in the common case The ti_heartbeat endpoint now attempts a single guarded UPDATE (matching id, state, hostname, and pid) before falling back to the existing SELECT FOR UPDATE path. When the task is still running on the expected host and pid, the endpoint returns immediately, eliminating the row lock and a round trip for the vast majority of heartbeat calls. * fixup! Add fast-path heartbeat UPDATE to avoid row lock in the common case * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Fix static checks * fixup! Fix static checks --------- (cherry picked from commit c97d1a510253db727d9e2bc237ce04e537341872) Co-authored-by: Ephraim Anierobi Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../execution_api/routes/task_instances.py | 23 ++- .../versions/head/test_task_instances.py | 162 ++++++++++++++++++ 2 files changed, 183 insertions(+), 2 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index e1687206d5547..674186a7c9bdc 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -695,8 +695,27 @@ def ti_heartbeat( bind_contextvars(ti_id=str(task_instance_id)) log.debug("Processing heartbeat", hostname=ti_payload.hostname, pid=ti_payload.pid) - # Hot path: since heartbeating a task is a very common operation, we try to do minimize the number of queries - # and DB round trips as much as possible. 
+ # Hot path: in the common case the TI is still running on the same host and pid, + # so we can update last_heartbeat_at directly without first taking a row lock. + fast_path_result = cast( + "CursorResult[Any]", + session.execute( + update(TI) + .where( + TI.id == task_instance_id, + TI.state == TaskInstanceState.RUNNING, + TI.hostname == ti_payload.hostname, + TI.pid == ti_payload.pid, + ) + .values(last_heartbeat_at=timezone.utcnow()) + .execution_options(synchronize_session=False) + ), + ) + if fast_path_result.rowcount is not None and fast_path_result.rowcount > 0: + log.debug("Heartbeat updated via fast path") + return + + log.debug("Heartbeat fast path missed; falling back to diagnostic checks") old = select(TI.state, TI.hostname, TI.pid).where(TI.id == task_instance_id).with_for_update() diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index c6135711be9bf..f7e50b2d8c9ef 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -72,6 +72,30 @@ DEFAULT_RENDERED_MAP_INDEX = "test rendered map index" +def _where_column_keys(statement) -> set[str]: + whereclause = getattr(statement, "whereclause", None) + if whereclause is None: + return set() + + keys: set[str] = set() + stack = [whereclause] + while stack: + clause = stack.pop() + left = getattr(clause, "left", None) + if left is not None and getattr(left, "key", None) is not None: + keys.add(left.key) + stack.extend(clause.get_children()) + return keys + + +def _is_task_instance_update(statement) -> bool: + return getattr(statement, "is_update", False) and statement.table.name == TaskInstance.__table__.name + + +def _is_select_for_update(statement) -> bool: + return getattr(statement, "is_select", False) and "FOR UPDATE" in 
str(statement.compile()).upper() + + def _create_asset_aliases(session, num: int = 2) -> None: asset_aliases = [ AssetAliasModel( @@ -1929,6 +1953,144 @@ def test_ti_heartbeat_update(self, client, session, create_task_instance, time_m session.refresh(ti) assert ti.last_heartbeat_at == time_now.add(minutes=10) + def test_ti_heartbeat_fast_path_skips_fallback( + self, client, session, create_task_instance, monkeypatch, time_machine + ): + """When the fast-path UPDATE succeeds, the fallback path does not run.""" + time_now = timezone.parse("2024-10-31T12:00:00Z") + time_machine.move_to(time_now, tick=False) + + ti = create_task_instance( + task_id="test_ti_heartbeat_fast_path_skips_fallback", + state=State.RUNNING, + hostname="random-hostname", + pid=1547, + last_heartbeat_at=time_now, + session=session, + ) + session.commit() + + new_time = time_now.add(minutes=10) + time_machine.move_to(new_time, tick=False) + + original_execute = Session.execute + task_instance_updates = [] + for_update_selects = [] + + def counting_execute(session_obj, statement, *args, **kwargs): + if _is_task_instance_update(statement): + task_instance_updates.append(statement) + if _is_select_for_update(statement): + for_update_selects.append(statement) + return original_execute(session_obj, statement, *args, **kwargs) + + monkeypatch.setattr(Session, "execute", counting_execute) + + response = client.put( + f"/execution/task-instances/{ti.id}/heartbeat", + json={"hostname": "random-hostname", "pid": 1547}, + ) + + assert response.status_code == 204 + assert len(task_instance_updates) == 1 + assert _where_column_keys(task_instance_updates[0]) == {"id", "state", "hostname", "pid"} + assert len(for_update_selects) == 0 + session.refresh(ti) + assert ti.last_heartbeat_at == new_time + + def test_ti_heartbeat_fallback_updates_on_fast_path_miss( + self, client, session, create_task_instance, monkeypatch, time_machine + ): + """When the fast-path UPDATE returns rowcount=0 the fallback path should + 
still update last_heartbeat_at.""" + time_now = timezone.parse("2024-10-31T12:00:00Z") + time_machine.move_to(time_now, tick=False) + + ti = create_task_instance( + task_id="test_ti_heartbeat_fallback_updates_on_fast_path_miss", + state=State.RUNNING, + hostname="random-hostname", + pid=1547, + last_heartbeat_at=time_now, + session=session, + ) + session.commit() + + new_time = time_now.add(minutes=10) + time_machine.move_to(new_time, tick=False) + + original_execute = Session.execute + fast_path_intercepted = False + + def execute_with_fast_path_miss(session_obj, statement, *args, **kwargs): + nonlocal fast_path_intercepted + if ( + not fast_path_intercepted + and getattr(statement, "is_update", False) + and statement.table.name == TaskInstance.__table__.name + ): + fast_path_intercepted = True + return mock.MagicMock(rowcount=0) + return original_execute(session_obj, statement, *args, **kwargs) + + monkeypatch.setattr(Session, "execute", execute_with_fast_path_miss) + + response = client.put( + f"/execution/task-instances/{ti.id}/heartbeat", + json={"hostname": "random-hostname", "pid": 1547}, + ) + + assert response.status_code == 204 + assert fast_path_intercepted + session.refresh(ti) + assert ti.last_heartbeat_at == new_time + + def test_ti_heartbeat_fallback_updates_on_unknown_fast_path_rowcount( + self, client, session, create_task_instance, monkeypatch, time_machine + ): + """A truthy-but-unknown rowcount must not be treated as fast-path success.""" + time_now = timezone.parse("2024-10-31T12:00:00Z") + time_machine.move_to(time_now, tick=False) + + ti = create_task_instance( + task_id="test_ti_heartbeat_fallback_updates_on_unknown_fast_path_rowcount", + state=State.RUNNING, + hostname="random-hostname", + pid=1547, + last_heartbeat_at=time_now, + session=session, + ) + session.commit() + + new_time = time_now.add(minutes=10) + time_machine.move_to(new_time, tick=False) + + original_execute = Session.execute + fast_path_intercepted = False + + def 
execute_with_unknown_fast_path_rowcount(session_obj, statement, *args, **kwargs): + nonlocal fast_path_intercepted + if ( + not fast_path_intercepted + and getattr(statement, "is_update", False) + and statement.table.name == TaskInstance.__table__.name + ): + fast_path_intercepted = True + return mock.MagicMock(rowcount=-1) + return original_execute(session_obj, statement, *args, **kwargs) + + monkeypatch.setattr(Session, "execute", execute_with_unknown_fast_path_rowcount) + + response = client.put( + f"/execution/task-instances/{ti.id}/heartbeat", + json={"hostname": "random-hostname", "pid": 1547}, + ) + + assert response.status_code == 204 + assert fast_path_intercepted + session.refresh(ti) + assert ti.last_heartbeat_at == new_time + class TestTIPutRTIF: def setup_method(self):