diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index e1687206d5547..674186a7c9bdc 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -695,8 +695,27 @@ def ti_heartbeat( bind_contextvars(ti_id=str(task_instance_id)) log.debug("Processing heartbeat", hostname=ti_payload.hostname, pid=ti_payload.pid) - # Hot path: since heartbeating a task is a very common operation, we try to do minimize the number of queries - # and DB round trips as much as possible. + # Hot path: in the common case the TI is still running on the same host and pid, + # so we can update last_heartbeat_at directly without first taking a row lock. + fast_path_result = cast( + "CursorResult[Any]", + session.execute( + update(TI) + .where( + TI.id == task_instance_id, + TI.state == TaskInstanceState.RUNNING, + TI.hostname == ti_payload.hostname, + TI.pid == ti_payload.pid, + ) + .values(last_heartbeat_at=timezone.utcnow()) + .execution_options(synchronize_session=False) + ), + ) + if fast_path_result.rowcount is not None and fast_path_result.rowcount > 0: + log.debug("Heartbeat updated via fast path") + return + + log.debug("Heartbeat fast path missed; falling back to diagnostic checks") old = select(TI.state, TI.hostname, TI.pid).where(TI.id == task_instance_id).with_for_update() diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index c6135711be9bf..f7e50b2d8c9ef 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -72,6 +72,30 @@ DEFAULT_RENDERED_MAP_INDEX = "test rendered map index" +def _where_column_keys(statement) -> set[str]: + whereclause = getattr(statement, "whereclause", None) + if whereclause is None: + return set() + + keys: set[str] = set() + stack = [whereclause] + while stack: + clause = stack.pop() + left = getattr(clause, "left", None) + if left is not None and getattr(left, "key", None) is not None: + keys.add(left.key) + stack.extend(clause.get_children()) + return keys + + +def _is_task_instance_update(statement) -> bool: + return getattr(statement, "is_update", False) and statement.table.name == TaskInstance.__table__.name + + +def _is_select_for_update(statement) -> bool: + return getattr(statement, "is_select", False) and "FOR UPDATE" in str(statement.compile()).upper() + + def _create_asset_aliases(session, num: int = 2) -> None: asset_aliases = [ AssetAliasModel( @@ -1929,6 +1953,144 @@ def test_ti_heartbeat_update(self, client, session, create_task_instance, time_m session.refresh(ti) assert ti.last_heartbeat_at == time_now.add(minutes=10) + def test_ti_heartbeat_fast_path_skips_fallback( + self, client, session, create_task_instance, monkeypatch, time_machine + ): + """When the fast-path UPDATE succeeds, the fallback path does not run.""" + time_now = timezone.parse("2024-10-31T12:00:00Z") + time_machine.move_to(time_now, tick=False) + + ti = create_task_instance( + task_id="test_ti_heartbeat_fast_path_skips_fallback", + state=State.RUNNING, + hostname="random-hostname", + pid=1547, + last_heartbeat_at=time_now, + session=session, + ) + session.commit() + + new_time = time_now.add(minutes=10) + time_machine.move_to(new_time, tick=False) + + original_execute = Session.execute + task_instance_updates = [] + for_update_selects = [] + + def counting_execute(session_obj, statement, *args, **kwargs): + if _is_task_instance_update(statement): + task_instance_updates.append(statement) + if _is_select_for_update(statement): + for_update_selects.append(statement) + return original_execute(session_obj, statement, *args, **kwargs) + + monkeypatch.setattr(Session, "execute", counting_execute) + + response = client.put( + f"/execution/task-instances/{ti.id}/heartbeat", + json={"hostname": "random-hostname", "pid": 1547}, + ) + + assert response.status_code == 204 + assert len(task_instance_updates) == 1 + assert _where_column_keys(task_instance_updates[0]) == {"id", "state", "hostname", "pid"} + assert len(for_update_selects) == 0 + session.refresh(ti) + assert ti.last_heartbeat_at == new_time + + def test_ti_heartbeat_fallback_updates_on_fast_path_miss( + self, client, session, create_task_instance, monkeypatch, time_machine + ): + """When the fast-path UPDATE returns rowcount=0 the fallback path should + still update last_heartbeat_at.""" + time_now = timezone.parse("2024-10-31T12:00:00Z") + time_machine.move_to(time_now, tick=False) + + ti = create_task_instance( + task_id="test_ti_heartbeat_fallback_updates_on_fast_path_miss", + state=State.RUNNING, + hostname="random-hostname", + pid=1547, + last_heartbeat_at=time_now, + session=session, + ) + session.commit() + + new_time = time_now.add(minutes=10) + time_machine.move_to(new_time, tick=False) + + original_execute = Session.execute + fast_path_intercepted = False + + def execute_with_fast_path_miss(session_obj, statement, *args, **kwargs): + nonlocal fast_path_intercepted + if ( + not fast_path_intercepted + and getattr(statement, "is_update", False) + and statement.table.name == TaskInstance.__table__.name + ): + fast_path_intercepted = True + return mock.MagicMock(rowcount=0) + return original_execute(session_obj, statement, *args, **kwargs) + + monkeypatch.setattr(Session, "execute", execute_with_fast_path_miss) + + response = client.put( + f"/execution/task-instances/{ti.id}/heartbeat", + json={"hostname": "random-hostname", "pid": 1547}, + ) + + assert response.status_code == 204 + assert fast_path_intercepted + session.refresh(ti) + assert ti.last_heartbeat_at == new_time + + def test_ti_heartbeat_fallback_updates_on_unknown_fast_path_rowcount( + self, client, session, create_task_instance, monkeypatch, time_machine + ): + """A truthy-but-unknown rowcount must not be treated as fast-path success.""" + time_now = timezone.parse("2024-10-31T12:00:00Z") + time_machine.move_to(time_now, tick=False) + + ti = create_task_instance( + task_id="test_ti_heartbeat_fallback_updates_on_unknown_fast_path_rowcount", + state=State.RUNNING, + hostname="random-hostname", + pid=1547, + last_heartbeat_at=time_now, + session=session, + ) + session.commit() + + new_time = time_now.add(minutes=10) + time_machine.move_to(new_time, tick=False) + + original_execute = Session.execute + fast_path_intercepted = False + + def execute_with_unknown_fast_path_rowcount(session_obj, statement, *args, **kwargs): + nonlocal fast_path_intercepted + if ( + not fast_path_intercepted + and getattr(statement, "is_update", False) + and statement.table.name == TaskInstance.__table__.name + ): + fast_path_intercepted = True + return mock.MagicMock(rowcount=-1) + return original_execute(session_obj, statement, *args, **kwargs) + + monkeypatch.setattr(Session, "execute", execute_with_unknown_fast_path_rowcount) + + response = client.put( + f"/execution/task-instances/{ti.id}/heartbeat", + json={"hostname": "random-hostname", "pid": 1547}, + ) + + assert response.status_code == 204 + assert fast_path_intercepted + session.refresh(ti) + assert ti.last_heartbeat_at == new_time + class TestTIPutRTIF: def setup_method(self):