Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -695,8 +695,27 @@ def ti_heartbeat(
bind_contextvars(ti_id=str(task_instance_id))
log.debug("Processing heartbeat", hostname=ti_payload.hostname, pid=ti_payload.pid)

# Hot path: since heartbeating a task is a very common operation, we try to do minimize the number of queries
# and DB round trips as much as possible.
# Hot path: in the common case the TI is still running on the same host and pid,
# so we can update last_heartbeat_at directly without first taking a row lock.
fast_path_result = cast(
"CursorResult[Any]",
session.execute(
update(TI)
.where(
TI.id == task_instance_id,
TI.state == TaskInstanceState.RUNNING,
TI.hostname == ti_payload.hostname,
TI.pid == ti_payload.pid,
)
.values(last_heartbeat_at=timezone.utcnow())
.execution_options(synchronize_session=False)
),
)
if fast_path_result.rowcount is not None and fast_path_result.rowcount > 0:
log.debug("Heartbeat updated via fast path")
return

log.debug("Heartbeat fast path missed; falling back to diagnostic checks")

old = select(TI.state, TI.hostname, TI.pid).where(TI.id == task_instance_id).with_for_update()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,30 @@
DEFAULT_RENDERED_MAP_INDEX = "test rendered map index"


def _where_column_keys(statement) -> set[str]:
whereclause = getattr(statement, "whereclause", None)
if whereclause is None:
return set()

keys: set[str] = set()
stack = [whereclause]
while stack:
clause = stack.pop()
left = getattr(clause, "left", None)
if left is not None and getattr(left, "key", None) is not None:
keys.add(left.key)
stack.extend(clause.get_children())
return keys


def _is_task_instance_update(statement) -> bool:
return getattr(statement, "is_update", False) and statement.table.name == TaskInstance.__table__.name


def _is_select_for_update(statement) -> bool:
return getattr(statement, "is_select", False) and "FOR UPDATE" in str(statement.compile()).upper()


def _create_asset_aliases(session, num: int = 2) -> None:
asset_aliases = [
AssetAliasModel(
Expand Down Expand Up @@ -1929,6 +1953,144 @@ def test_ti_heartbeat_update(self, client, session, create_task_instance, time_m
session.refresh(ti)
assert ti.last_heartbeat_at == time_now.add(minutes=10)

def test_ti_heartbeat_fast_path_skips_fallback(
self, client, session, create_task_instance, monkeypatch, time_machine
):
"""When the fast-path UPDATE succeeds, the fallback path does not run."""
time_now = timezone.parse("2024-10-31T12:00:00Z")
time_machine.move_to(time_now, tick=False)

ti = create_task_instance(
task_id="test_ti_heartbeat_fast_path_skips_fallback",
state=State.RUNNING,
hostname="random-hostname",
pid=1547,
last_heartbeat_at=time_now,
session=session,
)
session.commit()

new_time = time_now.add(minutes=10)
time_machine.move_to(new_time, tick=False)

original_execute = Session.execute
task_instance_updates = []
for_update_selects = []

def counting_execute(session_obj, statement, *args, **kwargs):
if _is_task_instance_update(statement):
task_instance_updates.append(statement)
if _is_select_for_update(statement):
for_update_selects.append(statement)
return original_execute(session_obj, statement, *args, **kwargs)

monkeypatch.setattr(Session, "execute", counting_execute)

response = client.put(
f"/execution/task-instances/{ti.id}/heartbeat",
json={"hostname": "random-hostname", "pid": 1547},
)

assert response.status_code == 204
assert len(task_instance_updates) == 1
assert _where_column_keys(task_instance_updates[0]) == {"id", "state", "hostname", "pid"}
assert len(for_update_selects) == 0
session.refresh(ti)
assert ti.last_heartbeat_at == new_time

def test_ti_heartbeat_fallback_updates_on_fast_path_miss(
self, client, session, create_task_instance, monkeypatch, time_machine
):
"""When the fast-path UPDATE returns rowcount=0 the fallback path should
still update last_heartbeat_at."""
time_now = timezone.parse("2024-10-31T12:00:00Z")
time_machine.move_to(time_now, tick=False)

ti = create_task_instance(
task_id="test_ti_heartbeat_fallback_updates_on_fast_path_miss",
state=State.RUNNING,
hostname="random-hostname",
pid=1547,
last_heartbeat_at=time_now,
session=session,
)
session.commit()

new_time = time_now.add(minutes=10)
time_machine.move_to(new_time, tick=False)

original_execute = Session.execute
fast_path_intercepted = False

def execute_with_fast_path_miss(session_obj, statement, *args, **kwargs):
nonlocal fast_path_intercepted
if (
not fast_path_intercepted
and getattr(statement, "is_update", False)
and statement.table.name == TaskInstance.__table__.name
):
fast_path_intercepted = True
return mock.MagicMock(rowcount=0)
return original_execute(session_obj, statement, *args, **kwargs)

monkeypatch.setattr(Session, "execute", execute_with_fast_path_miss)

response = client.put(
f"/execution/task-instances/{ti.id}/heartbeat",
json={"hostname": "random-hostname", "pid": 1547},
)

assert response.status_code == 204
assert fast_path_intercepted
session.refresh(ti)
assert ti.last_heartbeat_at == new_time

def test_ti_heartbeat_fallback_updates_on_unknown_fast_path_rowcount(
self, client, session, create_task_instance, monkeypatch, time_machine
):
"""A truthy-but-unknown rowcount must not be treated as fast-path success."""
time_now = timezone.parse("2024-10-31T12:00:00Z")
time_machine.move_to(time_now, tick=False)

ti = create_task_instance(
task_id="test_ti_heartbeat_fallback_updates_on_unknown_fast_path_rowcount",
state=State.RUNNING,
hostname="random-hostname",
pid=1547,
last_heartbeat_at=time_now,
session=session,
)
session.commit()

new_time = time_now.add(minutes=10)
time_machine.move_to(new_time, tick=False)

original_execute = Session.execute
fast_path_intercepted = False

def execute_with_unknown_fast_path_rowcount(session_obj, statement, *args, **kwargs):
nonlocal fast_path_intercepted
if (
not fast_path_intercepted
and getattr(statement, "is_update", False)
and statement.table.name == TaskInstance.__table__.name
):
fast_path_intercepted = True
return mock.MagicMock(rowcount=-1)
return original_execute(session_obj, statement, *args, **kwargs)

monkeypatch.setattr(Session, "execute", execute_with_unknown_fast_path_rowcount)

response = client.put(
f"/execution/task-instances/{ti.id}/heartbeat",
json={"hostname": "random-hostname", "pid": 1547},
)

assert response.status_code == 204
assert fast_path_intercepted
session.refresh(ti)
assert ti.last_heartbeat_at == new_time


class TestTIPutRTIF:
def setup_method(self):
Expand Down
Loading