From 07564169cc90bfd031b5954e836a7b7d71ca88d5 Mon Sep 17 00:00:00 2001 From: "Eric W. Tramel" <1223539+eric-tramel@users.noreply.github.com> Date: Tue, 2 Jun 2026 00:26:31 +0000 Subject: [PATCH] fix: prevent request admission timeout row drops - Classify local request-admission queue timeouts separately from provider timeouts - Preserve request-admission timeouts through async salvage like rate limits - Bound model task admission by provider/model request capacity - Add regression coverage for Issue #725 Fixes #725 Signed-off-by: Eric W. Tramel <1223539+eric-tramel@users.noreply.github.com> --- architecture/dataset-builders.md | 2 +- .../dataset_builders/async_scheduler.py | 150 ++++---- .../dataset_builders/scheduling/resolver.py | 23 +- .../dataset_builders/scheduling/resources.py | 11 +- .../engine/models/clients/errors.py | 1 + .../models/clients/model_request_executor.py | 29 +- .../src/data_designer/engine/models/errors.py | 8 + .../scheduling/test_resolver.py | 8 +- .../scheduling/test_resources.py | 12 +- .../dataset_builders/test_async_scheduler.py | 325 +++++++++++++----- .../clients/test_model_request_executor.py | 59 +++- .../tests/engine/models/test_model_errors.py | 12 + 12 files changed, 466 insertions(+), 174 deletions(-) diff --git a/architecture/dataset-builders.md b/architecture/dataset-builders.md index 06c83b49c..70b6afed4 100644 --- a/architecture/dataset-builders.md +++ b/architecture/dataset-builders.md @@ -35,7 +35,7 @@ Preparation (`_prepare_async_run`): 4. Constructs `CompletionTracker`, `RowGroupBufferManager`, `AsyncTaskScheduler` 5. Hooks `ProcessorRunner` for pre-batch and post-batch stages -`AsyncTaskScheduler` runs on a dedicated async loop with frontier-driven dispatch, task-admission leases, salvage rounds for failed tasks, and order-dependent locks for columns that must execute sequentially. Ready frontier tasks enter `FairTaskQueue`, are selected through virtual-time ordering, and are committed only after `TaskAdmissionController` acquires the required scheduler resources. Salvage-exhausted tasks are dropped except for rate-limit failures, which stay deferred and retry after cooldown/backoff so 429s delay records rather than discard them. +`AsyncTaskScheduler` runs on a dedicated async loop with frontier-driven dispatch, task-admission leases, salvage rounds for failed tasks, and order-dependent locks for columns that must execute sequentially. Ready frontier tasks enter `FairTaskQueue`, are selected through virtual-time ordering, and are committed only after `TaskAdmissionController` acquires the required scheduler resources. Salvage-exhausted tasks are dropped except for preserved retryable failures: provider rate limits and local request-admission queue timeouts stay deferred and retry after cooldown/backoff so scheduler-local pressure delays records rather than discarding them. ### Execution Graph diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index 53b83b220..c3c5b7ee7 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -31,6 +31,7 @@ from data_designer.engine.dataset_builders.scheduling.completion import CompletionTracker, FrontierDelta from data_designer.engine.dataset_builders.scheduling.queue import ( FairTaskQueue, + QueueView, ) from data_designer.engine.dataset_builders.scheduling.resolver import TaskSchedulingResolver from data_designer.engine.dataset_builders.scheduling.resources import ( @@ -63,6 +64,7 @@ RETRYABLE_MODEL_ERRORS, GenerationValidationFailureError, ModelRateLimitError, + ModelRequestAdmissionTimeoutError, ) from data_designer.engine.models.request_admission.config import RequestAdmissionConfig from data_designer.engine.models.request_admission.resources import RequestResourceKey @@ -83,8 +85,12 @@ logger = logging.getLogger(__name__) MODEL_GROUP_ADMISSION_BACKLOG_MULTIPLIER: int = 2 -RATE_LIMIT_RESALVAGE_BACKOFF_S: float = 0.05 -RATE_LIMIT_PRESERVATION_WARNING_INTERVAL: int = 10 +RETRYABLE_RESALVAGE_BACKOFF_S: float = 0.05 +PRESERVED_RETRYABLE_WARNING_INTERVAL: int = 10 +PRESERVED_RETRYABLE_ERRORS: tuple[type[BaseException], ...] = ( + ModelRateLimitError, + ModelRequestAdmissionTimeoutError, +) # Degraded-provider WARN: emit at most one warning per interval when the # rolling fraction of retryable errors exceeds the threshold. Distinct from @@ -190,9 +196,13 @@ def __init__( model_group_limit_multiplier=MODEL_GROUP_ADMISSION_BACKLOG_MULTIPLIER, model_group_limit_cap=max_model_task_admission, ) + task_resource_limits = { + "llm_wait": max_model_task_admission, + **self._task_scheduling.request_resource_limits, + } admission_config = task_admission_config or TaskAdmissionConfig( submission_capacity=max_in_flight_tasks, - resource_limits={"llm_wait": max_model_task_admission}, + resource_limits=task_resource_limits, bounded_borrow=BoundedBorrowTaskAdmissionPolicyConfig(), ) self._task_admission = TaskAdmissionController(admission_config) @@ -252,8 +262,8 @@ def __init__( # Deferred retryable failures (retried in salvage rounds) self._deferred: list[Task] = [] self._deferred_errors: dict[Task, Exception] = {} - self._rate_limit_preservation_counts: Counter[Task] = Counter() - self._rate_limit_preservation_log_state: dict[int, tuple[int, int]] = {} + self._preserved_retryable_counts: Counter[Task] = Counter() + self._preserved_retryable_log_state: dict[int, tuple[int, int]] = {} # Tracing self._trace = trace @@ -497,6 +507,8 @@ def _record_observed_task_state(self) -> None: ) def _emit_scheduler_health_snapshot(self, reason: str) -> None: + if self._scheduler_event_sink is None: + return self._emit_scheduler_event( "scheduler_health_snapshot", diagnostics=self._scheduler_health_diagnostics(reason=reason), @@ -606,7 +618,11 @@ def _request_pressure_diagnostics(self) -> dict[str, object]: }, } - def _request_pressure_item_diagnostics(self, item: SchedulableTask) -> dict[str, object]: + def _request_pressure_item_diagnostics( + self, + item: SchedulableTask, + pressure_reason: str | None = None, + ) -> dict[str, object]: if item.request_resource_key is None or self._request_pressure_provider is None: return {"request_resource": None} snapshot = self._request_pressure_provider.snapshot(item.request_resource_key) @@ -616,7 +632,7 @@ def _request_pressure_item_diagnostics(self, item: SchedulableTask) -> dict[str, ) diagnostics: dict[str, object] = { "request_resource": _request_resource_label(item.request_resource_key), - "pressure_reason": self._request_pressure_reason(item), + "pressure_reason": pressure_reason if pressure_reason is not None else self._request_pressure_reason(item), "resource_snapshot": None, "provider_model_snapshot": None, } @@ -711,12 +727,12 @@ def _dispatch_queued_tasks(self) -> _DispatchOutcome: selection = self._fair_queue.select_next(self._is_dispatch_eligible) if selection is None: summary = self._task_admission.explain_blocked(self._fair_queue.view()) - if "group_cap" in summary.dominant_denial_reasons: + if summary.queued_count == 0: + event_kind = "queue_empty" + elif "group_cap" in summary.dominant_denial_reasons: event_kind = "group_capped" - elif summary.dominant_denial_reasons: - event_kind = "admission_blocked" else: - event_kind = "queue_empty" + event_kind = "admission_blocked" self._emit_scheduler_event( event_kind, diagnostics={ @@ -759,7 +775,11 @@ def _dispatch_queued_tasks(self) -> _DispatchOutcome: self._emit_scheduler_health_snapshot("queue_drained") return _DispatchOutcome(dispatched=dispatched) - def _is_dispatch_eligible(self, item: SchedulableTask, view: Any) -> bool: + def _is_dispatch_eligible( + self, + item: SchedulableTask, + view: QueueView, + ) -> bool: if not self._task_admission.is_eligible(item, view): return False if not self._request_pressure_advisory: @@ -786,7 +806,10 @@ def _is_dispatch_eligible(self, item: SchedulableTask, view: Any) -> bool: def _is_request_pressure_limited(self, item: SchedulableTask) -> bool: return self._request_pressure_reason(item) is not None - def _request_pressure_reason(self, item: SchedulableTask) -> str | None: + def _request_pressure_reason( + self, + item: SchedulableTask, + ) -> str | None: if item.request_resource_key is None or self._request_pressure_provider is None: return None snapshot = self._request_pressure_provider.snapshot(item.request_resource_key) @@ -794,11 +817,12 @@ def _request_pressure_reason(self, item: SchedulableTask) -> str | None: item.request_resource_key.provider_name, item.request_resource_key.model_id, ) - if ( + provider_model_limited = ( global_snapshot is not None and global_snapshot.static_cap > 0 and global_snapshot.aggregate_in_flight >= global_snapshot.static_cap - ): + ) + if provider_model_limited: return "provider_model_aggregate_cap" if snapshot is None: return None @@ -810,10 +834,11 @@ def _request_pressure_reason(self, item: SchedulableTask) -> str | None: return "resource_limit" return None - def _has_request_pressure_open_peer(self, item: SchedulableTask, view: Any) -> bool: - return self._request_pressure_open_peer(item, view) is not None - - def _request_pressure_open_peer(self, item: SchedulableTask, view: Any) -> SchedulableTask | None: + def _request_pressure_open_peer( + self, + item: SchedulableTask, + view: QueueView, + ) -> SchedulableTask | None: for peer in view.first_candidate_tasks_by_group.values(): if peer.task_id == item.task_id: continue @@ -1091,8 +1116,8 @@ async def _main_dispatch_loop( if self._deferred: await self._salvage_stalled_row_groups(seed_cols, has_pre_batch, all_columns) self._checkpoint_completed_row_groups(all_columns) - if self._has_rate_limited_deferred_tasks(): - await self._wait_before_rate_limit_resalvage() + if self._has_preserved_retryable_deferred_tasks(): + await self._wait_before_retryable_resalvage() continue break @@ -1102,6 +1127,8 @@ async def _main_dispatch_loop( self._run_seeds_complete_check(seed_cols) dispatch_outcome = self._dispatch_queued_tasks() + if dispatch_outcome.dispatched: + await asyncio.sleep(0) self._checkpoint_completed_row_groups(all_columns) self._maybe_update_adaptive_row_group_target() @@ -1113,7 +1140,7 @@ async def _main_dispatch_loop( await self._salvage_stalled_row_groups(seed_cols, has_pre_batch, all_columns) self._maybe_update_adaptive_row_group_target() if self._deferred and not self._in_flight: - await self._wait_before_rate_limit_resalvage() + await self._wait_before_retryable_resalvage() continue # Are we done? @@ -1185,6 +1212,8 @@ async def _drain_frontier(self, seed_cols: tuple[str, ...], has_pre_batch: bool) if has_pre_batch: self._run_seeds_complete_check(seed_cols) dispatch_outcome = self._dispatch_queued_tasks() + if dispatch_outcome.dispatched: + await asyncio.sleep(0) has_queued = self._fair_queue.has_queued_tasks if not has_queued and not self._in_flight: break @@ -1222,7 +1251,7 @@ async def _salvage_stalled_row_groups( width = len(str(num_rgs)) for rg_id in sorted(stalled_rgs): rg_deferred = [t for t in self._deferred if t.row_group == rg_id] - if not all(self._is_preserved_rate_limit_task(task) for task in rg_deferred): + if not all(self._is_preserved_retryable_task(task) for task in rg_deferred): logger.info(f"🔄 ({rg_id + 1:0{width}d}/{num_rgs}) Salvaging {len(rg_deferred)} deferred task(s)") # Partition deferred into stalled (retry now) and other (keep for later). @@ -1235,7 +1264,7 @@ async def _salvage_stalled_row_groups( exhausted = [t for t in self._deferred if t.row_group in stalled_rgs] newly_deferred = [t for t in self._deferred if t.row_group not in stalled_rgs] for task in exhausted: - if self._is_rate_limit_error(self._deferred_errors.get(task)): + if self._is_preserved_retryable_error(self._deferred_errors.get(task)): continue # If the row was already dropped by an earlier task in this loop, # the skip was already counted; don't also record a failure. @@ -1248,29 +1277,29 @@ async def _salvage_stalled_row_groups( rg_size = self._get_rg_size(task.row_group) self._drop_row_group(task.row_group, rg_size, exclude_columns={task.column}) self._deferred_errors.pop(task, None) - self._rate_limit_preservation_counts.pop(task, None) - rate_limited_exhausted = [ - task for task in exhausted if self._is_rate_limit_error(self._deferred_errors.get(task)) + self._preserved_retryable_counts.pop(task, None) + preserved_exhausted = [ + task for task in exhausted if self._is_preserved_retryable_error(self._deferred_errors.get(task)) ] - if rate_limited_exhausted: - self._record_rate_limit_preservations(rate_limited_exhausted, num_rgs, width) - self._deferred = other_deferred + newly_deferred + rate_limited_exhausted + if preserved_exhausted: + self._record_preserved_retryables(preserved_exhausted, num_rgs, width) + self._deferred = other_deferred + newly_deferred + preserved_exhausted self._checkpoint_completed_row_groups(all_columns) - async def _wait_before_rate_limit_resalvage(self) -> None: - """Pace repeated 429-only salvage cycles to avoid a hot loop.""" + async def _wait_before_retryable_resalvage(self) -> None: + """Pace repeated preserved retryable salvage cycles to avoid a hot loop.""" self._wake_event.clear() with contextlib.suppress(asyncio.TimeoutError): - await asyncio.wait_for(self._wake_event.wait(), timeout=self._rate_limit_resalvage_delay_seconds()) + await asyncio.wait_for(self._wake_event.wait(), timeout=self._retryable_resalvage_delay_seconds()) self._raise_if_fatal_worker_error() - def _rate_limit_resalvage_delay_seconds(self) -> float: + def _retryable_resalvage_delay_seconds(self) -> float: if self._request_pressure_provider is None: - return RATE_LIMIT_RESALVAGE_BACKOFF_S + return RETRYABLE_RESALVAGE_BACKOFF_S cooldowns = [] for task in self._deferred: - if not self._is_rate_limit_error(self._deferred_errors.get(task)): + if not self._is_preserved_retryable_error(self._deferred_errors.get(task)): continue resource = self._schedulable_task(task).request_resource_key if resource is None: @@ -1279,44 +1308,41 @@ def _rate_limit_resalvage_delay_seconds(self) -> float: if snapshot is not None and snapshot.cooldown_remaining_seconds > 0.0: cooldowns.append(snapshot.cooldown_remaining_seconds) if not cooldowns: - return RATE_LIMIT_RESALVAGE_BACKOFF_S - return max(RATE_LIMIT_RESALVAGE_BACKOFF_S, min(cooldowns)) + return RETRYABLE_RESALVAGE_BACKOFF_S + return max(RETRYABLE_RESALVAGE_BACKOFF_S, min(cooldowns)) - def _record_rate_limit_preservations(self, tasks: list[Task], num_rgs: int, width: int) -> None: + def _record_preserved_retryables(self, tasks: list[Task], num_rgs: int, width: int) -> None: by_rg: defaultdict[int, list[Task]] = defaultdict(list) for task in tasks: - self._rate_limit_preservation_counts[task] += 1 + self._preserved_retryable_counts[task] += 1 by_rg[task.row_group].append(task) for rg_id, rg_tasks in sorted(by_rg.items()): count = len(rg_tasks) - max_preservations = max(self._rate_limit_preservation_counts[task] for task in rg_tasks) - warning_bucket = max_preservations // RATE_LIMIT_PRESERVATION_WARNING_INTERVAL - last_count, last_warning_bucket = self._rate_limit_preservation_log_state.get(rg_id, (-1, -1)) + max_preservations = max(self._preserved_retryable_counts[task] for task in rg_tasks) + warning_bucket = max_preservations // PRESERVED_RETRYABLE_WARNING_INTERVAL + last_count, last_warning_bucket = self._preserved_retryable_log_state.get(rg_id, (-1, -1)) - if ( - max_preservations % RATE_LIMIT_PRESERVATION_WARNING_INTERVAL == 0 - and warning_bucket > last_warning_bucket - ): + if max_preservations % PRESERVED_RETRYABLE_WARNING_INTERVAL == 0 and warning_bucket > last_warning_bucket: logger.warning( - f"🔄 ({rg_id + 1:0{width}d}/{num_rgs}) Preserving {count} rate-limited task(s) after " + f"🔄 ({rg_id + 1:0{width}d}/{num_rgs}) Preserving {count} retryable task(s) after " f"{max_preservations} deferred salvage cycle(s); records will keep retrying." ) - self._rate_limit_preservation_log_state[rg_id] = (count, warning_bucket) + self._preserved_retryable_log_state[rg_id] = (count, warning_bucket) elif count != last_count: logger.info( - f"🔄 ({rg_id + 1:0{width}d}/{num_rgs}) Preserving {count} rate-limited task(s); " + f"🔄 ({rg_id + 1:0{width}d}/{num_rgs}) Preserving {count} retryable task(s); " "records will keep retrying." ) - self._rate_limit_preservation_log_state[rg_id] = (count, last_warning_bucket) + self._preserved_retryable_log_state[rg_id] = (count, last_warning_bucket) - def _is_preserved_rate_limit_task(self, task: Task) -> bool: - return self._rate_limit_preservation_counts.get(task, 0) > 0 and self._is_rate_limit_error( + def _is_preserved_retryable_task(self, task: Task) -> bool: + return self._preserved_retryable_counts.get(task, 0) > 0 and self._is_preserved_retryable_error( self._deferred_errors.get(task) ) - def _has_rate_limited_deferred_tasks(self) -> bool: - return any(self._is_rate_limit_error(self._deferred_errors.get(task)) for task in self._deferred) + def _has_preserved_retryable_deferred_tasks(self) -> bool: + return any(self._is_preserved_retryable_error(self._deferred_errors.get(task)) for task in self._deferred) def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: """Checkpoint any row groups that reached completion.""" @@ -1380,16 +1406,16 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: self._deferred_errors = { task: exc for task, exc in self._deferred_errors.items() if task.row_group not in checkpointed } - self._rate_limit_preservation_counts = Counter( + self._preserved_retryable_counts = Counter( { task: count - for task, count in self._rate_limit_preservation_counts.items() + for task, count in self._preserved_retryable_counts.items() if task.row_group not in checkpointed } ) - self._rate_limit_preservation_log_state = { + self._preserved_retryable_log_state = { row_group: state - for row_group, state in self._rate_limit_preservation_log_state.items() + for row_group, state in self._preserved_retryable_log_state.items() if row_group not in checkpointed } for rg_id in checkpointed: @@ -1736,7 +1762,7 @@ async def _execute_task_inner_impl(self, task: Task, lease: TaskAdmissionLease, self._dispatched.discard(task) if not retryable: self._deferred_errors.pop(task, None) - self._rate_limit_preservation_counts.pop(task, None) + self._preserved_retryable_counts.pop(task, None) if stateful_lock_acquired: self._stateful_locks[id(generator)].release() release_result = self._task_admission.release(lease) @@ -2108,8 +2134,8 @@ def _is_retryable(exc: BaseException) -> bool: return isinstance(exc, RETRYABLE_MODEL_ERRORS) @staticmethod - def _is_rate_limit_error(exc: BaseException | None) -> bool: - return isinstance(exc, ModelRateLimitError) + def _is_preserved_retryable_error(exc: BaseException | None) -> bool: + return isinstance(exc, PRESERVED_RETRYABLE_ERRORS) @staticmethod def _is_expected_non_retryable(exc: BaseException) -> bool: diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/resolver.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/resolver.py index c2f61e1e1..851ec152f 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/resolver.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/resolver.py @@ -10,9 +10,11 @@ from data_designer.config.scheduling import SchedulingMetadata, SchedulingMetadataError from data_designer.engine.dataset_builders.scheduling.resources import ( SchedulableTask, + SchedulerResourceKey, SchedulerResourceRequest, TaskGroupKey, TaskGroupSpec, + request_scheduler_resource_key, stable_task_id, ) from data_designer.engine.dataset_builders.scheduling.task_model import Task @@ -48,11 +50,16 @@ def __init__( self._diagnostics: list[dict[str, object]] = [] for generator in dict.fromkeys(generators.values()): self._metadata_by_generator_id[id(generator)] = self._resolve_metadata(generator) + self._request_resource_limits = self._build_request_resource_limits() @property def diagnostics(self) -> tuple[dict[str, object], ...]: return tuple(self._diagnostics) + @property + def request_resource_limits(self) -> Mapping[SchedulerResourceKey, int]: + return dict(self._request_resource_limits) + def scheduling_for_task(self, task: Task, flow_identity: tuple[str, ...]) -> ResolvedTaskScheduling: generator = self._generators[task.column] metadata = self._metadata_by_generator_id[id(generator)] @@ -100,16 +107,30 @@ def _resolved_from_metadata( identity = (*metadata.identity, *flow_identity) admitted_limit = max(1, min(self._model_group_limit_cap, self._model_group_limit_multiplier * weight)) request_resource_key = _request_resource_key(metadata) + resource_request = {"submission": 1, "llm_wait": 1} + if request_resource_key is not None: + resource_request[request_scheduler_resource_key(request_resource_key)] = 1 return ResolvedTaskScheduling( group=TaskGroupSpec( key=TaskGroupKey(kind=metadata.kind, identity=identity), weight=float(weight), admitted_limit=admitted_limit, ), - resource_request=SchedulerResourceRequest({"submission": 1, "llm_wait": 1}), + resource_request=SchedulerResourceRequest(resource_request), request_resource_key=request_resource_key, ) + def _build_request_resource_limits(self) -> dict[SchedulerResourceKey, int]: + limits: dict[SchedulerResourceKey, int] = {} + for metadata in self._metadata_by_generator_id.values(): + resource = _request_resource_key(metadata) + if resource is None: + continue + key = request_scheduler_resource_key(resource) + cap = max(1, metadata.weight) + limits[key] = min(limits.get(key, cap), cap) + return limits + def _request_resource_key(metadata: SchedulingMetadata) -> RequestResourceKey | None: if metadata.kind != "model": diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/resources.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/resources.py index 35a0ec18f..8d13fadd2 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/resources.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/resources.py @@ -11,7 +11,7 @@ from data_designer.engine.dataset_builders.scheduling.task_model import Task from data_designer.engine.models.request_admission.resources import RequestResourceKey -SchedulerResourceKey = Literal["submission", "llm_wait", "local"] +SchedulerResourceKey = str @dataclass(frozen=True, order=True) @@ -39,8 +39,8 @@ class SchedulerResourceRequest: def __post_init__(self) -> None: for resource, amount in self.amounts.items(): - if resource not in {"submission", "llm_wait", "local"}: - raise ValueError(f"Unknown scheduler resource key: {resource!r}") + if not isinstance(resource, str) or not resource: + raise ValueError(f"Scheduler resource key must be a non-empty string, got {resource!r}.") if not isinstance(amount, int) or amount <= 0: raise ValueError(f"Scheduler resource amount for {resource!r} must be a positive integer.") @@ -61,3 +61,8 @@ def stable_task_id(task: Task) -> str: raw = f"{task.column}\0{task.row_group}\0{task.row_index}\0{task.task_type}".encode() digest = hashlib.sha1(raw).hexdigest()[:16] return f"task-{digest}" + + +def request_scheduler_resource_key(resource: RequestResourceKey) -> SchedulerResourceKey: + """Return the scheduler task-stage resource for a provider/model request pool.""" + return f"request:{resource.provider_name}/{resource.model_id}" diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py index 8355ce85a..f0b79c28b 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py @@ -24,6 +24,7 @@ class ProviderErrorKind(str, Enum): NOT_FOUND = "not_found" PERMISSION_DENIED = "permission_denied" RATE_LIMIT = "rate_limit" + REQUEST_ADMISSION_TIMEOUT = "request_admission_timeout" TIMEOUT = "timeout" UNPROCESSABLE_ENTITY = "unprocessable_entity" UNSUPPORTED_CAPABILITY = "unsupported_capability" diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/model_request_executor.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/model_request_executor.py index 721afa41a..13bc1e71c 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/model_request_executor.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/model_request_executor.py @@ -121,12 +121,7 @@ def _execute_sync_attempt(self, domain: RequestDomain, call: Callable[[], _T]) - try: lease = self._request_admission.acquire_sync(item) except RequestAdmissionError as exc: - raise ProviderError( - kind=ProviderErrorKind.TIMEOUT, - message=str(exc), - provider_name=self._provider_name, - model_name=self._model_id, - ) from exc + raise self._provider_error_from_request_admission(exc) from exc try: self._emit_model_event("model_request_started", item=item, lease=lease) result = call() @@ -169,12 +164,7 @@ async def _execute_async_attempt(self, domain: RequestDomain, call: Callable[[], try: lease = await self._request_admission.acquire_async(item) except RequestAdmissionError as exc: - raise ProviderError( - kind=ProviderErrorKind.TIMEOUT, - message=str(exc), - provider_name=self._provider_name, - model_name=self._model_id, - ) from exc + raise self._provider_error_from_request_admission(exc) from exc except asyncio.CancelledError: raise try: @@ -216,7 +206,7 @@ def _max_attempts(self) -> int: def _should_retry(self, exc: ProviderError, attempt: int) -> bool: if attempt >= self._max_attempts() - 1: return False - if isinstance(exc.__cause__, RequestAdmissionError): + if exc.kind == ProviderErrorKind.REQUEST_ADMISSION_TIMEOUT: return False if exc.kind == ProviderErrorKind.RATE_LIMIT: return False @@ -249,6 +239,19 @@ def _release_provider_error(self, lease: RequestAdmissionLease, exc: ProviderErr outcome = RequestReleaseOutcome(kind="provider_failure") self._request_admission.release(lease, outcome) + def _provider_error_from_request_admission(self, exc: RequestAdmissionError) -> ProviderError: + kind = ( + ProviderErrorKind.REQUEST_ADMISSION_TIMEOUT + if exc.decision.reason == "queue_timeout" + else ProviderErrorKind.TIMEOUT + ) + return ProviderError( + kind=kind, + message=str(exc), + provider_name=self._provider_name, + model_name=self._model_id, + ) + def _item(self, domain: RequestDomain) -> RequestAdmissionItem: resolved = self._resource_resolver.resolve( provider_name=self._provider_name, diff --git a/packages/data-designer-engine/src/data_designer/engine/models/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/errors.py index dc054cff2..c289d5c62 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/errors.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/errors.py @@ -73,6 +73,9 @@ class ModelQuotaExceededError(DataDesignerError): ... class ModelTimeoutError(DataDesignerError): ... +class ModelRequestAdmissionTimeoutError(ModelTimeoutError): ... + + class ModelContextWindowExceededError(DataDesignerError): ... @@ -303,6 +306,7 @@ def _raise_from_provider_error( _KIND_MAP: dict[ProviderErrorKind, type[DataDesignerError]] = { ProviderErrorKind.RATE_LIMIT: ModelRateLimitError, ProviderErrorKind.QUOTA_EXCEEDED: ModelQuotaExceededError, + ProviderErrorKind.REQUEST_ADMISSION_TIMEOUT: ModelRequestAdmissionTimeoutError, ProviderErrorKind.TIMEOUT: ModelTimeoutError, ProviderErrorKind.NOT_FOUND: ModelNotFoundError, ProviderErrorKind.PERMISSION_DENIED: ModelPermissionDeniedError, @@ -321,6 +325,10 @@ def _raise_from_provider_error( f"The request to model {model_name!r} timed out while {purpose}.", "Check your connection and try again. You may need to increase the timeout setting for the model.", ), + ProviderErrorKind.REQUEST_ADMISSION_TIMEOUT: ( + f"Local request admission for model {model_name!r} timed out while {purpose}; the provider request was not sent.", + "Reduce request concurrency or tune the model's max_parallel_requests to match the endpoint's real capacity. For async dataset generation, also consider lowering RunConfig.max_in_flight_tasks.", + ), ProviderErrorKind.NOT_FOUND: ( f"The specified model {model_name!r} could not be found while {purpose}.", f"Check that the model name is correct and supported by your model provider {model_provider_name!r} and try again.", diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_resolver.py b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_resolver.py index d6dfcdbab..9da5dfc1b 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_resolver.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_resolver.py @@ -69,8 +69,13 @@ def test_task_scheduling_resolver_maps_model_metadata_to_model_resource() -> Non assert schedulable.group.key.kind == "model" assert schedulable.group.weight == 3.0 assert schedulable.group.admitted_limit == 6 - assert schedulable.resource_request.amounts == {"submission": 1, "llm_wait": 1} + assert schedulable.resource_request.amounts == { + "submission": 1, + "llm_wait": 1, + "request:nvidia/nemotron": 1, + } assert schedulable.request_resource_key == RequestResourceKey("nvidia", "nemotron", RequestDomain.CHAT) + assert resolver.request_resource_limits == {"request:nvidia/nemotron": 3} def test_task_scheduling_resolver_records_safe_fallback_diagnostics() -> None: @@ -127,6 +132,7 @@ def generate(self, data: object) -> object: resolver = TaskSchedulingResolver({"answer": generator}) # type: ignore[arg-type] schedulable = resolver.schedulable_task(_task(), ("answer",)) assert schedulable.request_resource_key == RequestResourceKey("nvidia", "endpoint", RequestDomain.CHAT) + assert resolver.request_resource_limits == {"request:nvidia/endpoint": 2} def test_model_registry_generator_metadata_uses_custom_model_for_multi_endpoint_aliases() -> None: diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_resources.py b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_resources.py index 935f2c074..5844a4aa1 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_resources.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_resources.py @@ -21,9 +21,15 @@ def test_scheduler_resource_request_defaults_to_submission() -> None: assert request.amounts == {"submission": 1} -def test_scheduler_resource_request_rejects_unknown_resource() -> None: - with pytest.raises(ValueError, match="Unknown scheduler resource key"): - SchedulerResourceRequest({"gpu": 1}) # type: ignore[arg-type] +def test_scheduler_resource_request_accepts_dynamic_resource_keys() -> None: + request = SchedulerResourceRequest({"request:nvidia/nemotron": 1}) + + assert request.amounts == {"request:nvidia/nemotron": 1} + + +def test_scheduler_resource_request_rejects_empty_resource_key() -> None: + with pytest.raises(ValueError, match="non-empty string"): + SchedulerResourceRequest({"": 1}) def test_scheduler_resource_request_rejects_non_positive_amounts() -> None: diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py index d705f674b..4bb6d5a7b 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py @@ -49,6 +49,7 @@ RETRYABLE_MODEL_ERRORS, ModelInternalServerError, ModelRateLimitError, + ModelRequestAdmissionTimeoutError, ModelTimeoutError, ) from data_designer.engine.models.request_admission.config import RequestAdmissionConfig @@ -1500,7 +1501,7 @@ async def test_rate_limit_errors_do_not_trigger_early_shutdown() -> None: @pytest.mark.asyncio(loop_scope="session") async def test_preserved_429_retries_after_unrelated_early_shutdown(monkeypatch: pytest.MonkeyPatch) -> None: """Early shutdown must not turn rate-limited deferred work into dropped rows.""" - monkeypatch.setattr(async_scheduler_module, "RATE_LIMIT_RESALVAGE_BACKOFF_S", 0) + monkeypatch.setattr(async_scheduler_module, "RETRYABLE_RESALVAGE_BACKOFF_S", 0) cell = MockRateLimitThenNonRetryableGenerator( config=_expr_config("cell_out"), resource_provider=_mock_provider(), @@ -2671,41 +2672,71 @@ async def test_scheduler_429_beyond_salvage_cap_is_delayed_not_dropped() -> None assert not any(tracker.is_dropped(0, row_index) for row_index in range(num_records)) assert scheduler._deferred == [] assert scheduler._deferred_errors == {} - assert scheduler._rate_limit_preservation_counts == {} - assert scheduler._rate_limit_preservation_log_state == {} + assert scheduler._preserved_retryable_counts == {} + assert scheduler._preserved_retryable_log_state == {} @pytest.mark.asyncio(loop_scope="session") -async def test_scheduler_paces_sustained_429_resalvage( +async def test_scheduler_request_admission_timeout_beyond_salvage_cap_is_delayed_not_dropped( monkeypatch: pytest.MonkeyPatch, - caplog: pytest.LogCaptureFixture, ) -> None: - """Pure 429 loops should wait between salvage cycles instead of spinning CPU-bound.""" + """Local request-admission timeouts may outlast salvage rounds without becoming row drops.""" provider = _mock_provider() - configs = [ - SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), - LLMTextColumnConfig(name="llm_col", prompt="{{ seed }}", model_alias=MODEL_ALIAS), - ] - strategies = { - "seed": GenerationStrategy.FULL_COLUMN, - "llm_col": GenerationStrategy.CELL_BY_CELL, - } - rate_limited = MockLLMBoundRateLimitGenerator( - config=_expr_config("llm_col"), + monkeypatch.setattr(async_scheduler_module, "RETRYABLE_RESALVAGE_BACKOFF_S", 0) + config = ExpressionColumnConfig(name="llm_col", expr="'x'", dtype="str") + generator = MockRetryableErrorGenerator( + config=config, resource_provider=provider, - rate_limit_failures=10_000, + error_factory=lambda: ModelRequestAdmissionTimeoutError("request admission queue timeout"), + retryable_failures=3, ) - generators: dict[str, ColumnGenerator] = { - "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), - "llm_col": rate_limited, - } - graph = ExecutionGraph.create(configs, strategies) + graph = ExecutionGraph.create([config], {"llm_col": GenerationStrategy.CELL_BY_CELL}) row_groups = [(0, 1)] tracker = CompletionTracker.with_graph(graph, row_groups) - storage = MagicMock() - storage.dataset_name = "test" - storage.get_file_paths.return_value = {} - buffer_mgr = RowGroupBufferManager(storage) + + scheduler = AsyncTaskScheduler( + generators={"llm_col": generator}, + graph=graph, + tracker=tracker, + row_groups=row_groups, + salvage_max_rounds=1, + ) + await asyncio.wait_for(scheduler.run(), timeout=5.0) + + assert tracker.is_row_group_complete(0, 1, ["llm_col"]) + assert not tracker.is_dropped(0, 0) + assert generator._calls == 4 + assert scheduler._deferred == [] + assert scheduler._deferred_errors == {} + + +@pytest.mark.parametrize("retryable_kind", ["rate_limit", "request_admission_timeout"]) +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_paces_sustained_preserved_retryable_resalvage( + retryable_kind: str, + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, +) -> None: + """Preserved retryable loops should wait between salvage cycles instead of spinning CPU-bound.""" + provider = _mock_provider() + retrying_generator = ( + MockLLMBoundRateLimitGenerator( + config=_expr_config("cell_out"), + resource_provider=provider, + rate_limit_failures=10_000, + ) + if retryable_kind == "rate_limit" + else MockRetryableErrorGenerator( + config=_expr_config("cell_out"), + resource_provider=provider, + error_factory=lambda: ModelRequestAdmissionTimeoutError("request admission queue timeout"), + retryable_failures=10_000, + ) + ) + generators, graph, row_groups, tracker, buffer_mgr, _storage = _seed_plus_cell_setup( + retrying_generator, + num_records=1, + ) scheduler = AsyncTaskScheduler( generators=generators, graph=graph, @@ -2731,24 +2762,24 @@ async def controlled_resalvage_wait() -> None: second_wait_started.set() await block_second_wait.wait() - monkeypatch.setattr(scheduler, "_wait_before_rate_limit_resalvage", controlled_resalvage_wait) + monkeypatch.setattr(scheduler, "_wait_before_retryable_resalvage", controlled_resalvage_wait) with caplog.at_level(logging.INFO, logger=async_scheduler_module.__name__): run_task = asyncio.create_task(scheduler.run()) try: await asyncio.wait_for(first_wait_started.wait(), timeout=5.0) - calls_after_preserve = rate_limited._calls + calls_after_preserve = retrying_generator._calls for _ in range(5): await asyncio.sleep(0) - assert rate_limited._calls == calls_after_preserve + assert retrying_generator._calls == calls_after_preserve assert not run_task.done() release_first_wait.set() await asyncio.wait_for(second_wait_started.wait(), timeout=5.0) assert wait_calls == 2 - assert rate_limited._calls > calls_after_preserve + assert retrying_generator._calls > calls_after_preserve finally: run_task.cancel() with pytest.raises(asyncio.CancelledError): @@ -2765,38 +2796,21 @@ async def controlled_resalvage_wait() -> None: @pytest.mark.asyncio(loop_scope="session") -async def test_scheduler_only_preserves_429_when_retryable_errors_exhaust( +async def test_scheduler_drops_non_preserved_retryable_errors_when_salvage_exhausts( monkeypatch: pytest.MonkeyPatch, ) -> None: - """Exhausted non-429 retryable errors keep the existing drop behavior.""" + """Exhausted retryable errors are dropped unless they are explicitly preserved.""" provider = _mock_provider() - monkeypatch.setattr(async_scheduler_module, "RATE_LIMIT_RESALVAGE_BACKOFF_S", 0) - configs = [ - SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), - LLMTextColumnConfig(name="llm_col", prompt="{{ seed }}", model_alias=MODEL_ALIAS), - ] - strategies = { - "seed": GenerationStrategy.FULL_COLUMN, - "llm_col": GenerationStrategy.CELL_BY_CELL, - } + monkeypatch.setattr(async_scheduler_module, "RETRYABLE_RESALVAGE_BACKOFF_S", 0) mixed = MockMixedRetryableGenerator( - config=_expr_config("llm_col"), + config=_expr_config("cell_out"), resource_provider=provider, rate_limit_failures=3, ) - generators: dict[str, ColumnGenerator] = { - "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), - "llm_col": mixed, - } - graph = ExecutionGraph.create(configs, strategies) - row_groups = [(0, 2)] - tracker = CompletionTracker.with_graph(graph, row_groups) - storage = MagicMock() - storage.dataset_name = "test" - storage.get_file_paths.return_value = {} - storage.write_batch_to_parquet_file.return_value = "/fake.parquet" - storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet" - buffer_mgr = RowGroupBufferManager(storage) + generators, graph, row_groups, tracker, buffer_mgr, _storage = _seed_plus_cell_setup( + mixed, + num_records=2, + ) scheduler = AsyncTaskScheduler( generators=generators, graph=graph, @@ -2808,14 +2822,14 @@ async def test_scheduler_only_preserves_429_when_retryable_errors_exhaust( await scheduler.run() - assert tracker.is_row_group_complete(0, 2, ["seed", "llm_col"]) + assert tracker.is_row_group_complete(0, 2, ["seed", "cell_out"]) assert not tracker.is_dropped(0, 0) assert tracker.is_dropped(0, 1) - assert buffer_mgr.get_row(0, 0)["llm_col"] == "mixed_ok_0" + assert buffer_mgr.get_row(0, 0)["cell_out"] == "mixed_ok_0" assert scheduler._deferred == [] assert scheduler._deferred_errors == {} - assert scheduler._rate_limit_preservation_counts == {} - assert scheduler._rate_limit_preservation_log_state == {} + assert scheduler._preserved_retryable_counts == {} + assert scheduler._preserved_retryable_log_state == {} def test_scheduler_rejects_zero_salvage_rounds() -> None: @@ -2833,7 +2847,14 @@ def test_scheduler_rejects_zero_salvage_rounds() -> None: ) -def test_rate_limit_resalvage_delay_uses_request_cooldown() -> None: +@pytest.mark.parametrize( + "error", + [ + ModelRateLimitError("429 Too Many Requests"), + ModelRequestAdmissionTimeoutError("request admission queue timeout"), + ], +) +def test_retryable_resalvage_delay_uses_request_cooldown(error: Exception) -> None: """Scheduler-level pacing should respect request-admission cooldown when available.""" provider = _mock_provider() config = ExpressionColumnConfig(name="llm_col", expr="'x'", dtype="str") @@ -2884,9 +2905,9 @@ def global_snapshots(self) -> dict[ProviderModelKey, object]: request_pressure_provider=PressureProvider(), ) scheduler._deferred = [task] - scheduler._deferred_errors[task] = ModelRateLimitError("429 Too Many Requests") + scheduler._deferred_errors[task] = error - assert scheduler._rate_limit_resalvage_delay_seconds() == pytest.approx(0.25) + assert scheduler._retryable_resalvage_delay_seconds() == pytest.approx(0.25) @pytest.mark.asyncio(loop_scope="session") @@ -3085,23 +3106,67 @@ def __init__( *args: Any, provider_name: str = "provider", model_id: str = "model", + request_weight: int = 1, **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) self._provider_name = provider_name self._model_id = model_id + self._request_weight = request_weight def get_scheduling_metadata(self) -> SchedulingMetadata: return SchedulingMetadata.model( self._provider_name, self._model_id, "chat", - weight=1, + weight=self._request_weight, ) +class GatedRequestAdmissionCellGenerator(SlowModelBoundCellGenerator): + """Model-bound generator that holds initial request-admission leases.""" + + def __init__( + self, + *args: Any, + request_admission: AdaptiveRequestAdmissionController, + hold_until_active: int, + provider_name: str = "provider", + model_id: str = "model", + **kwargs: Any, + ) -> None: + super().__init__(*args, provider_name=provider_name, model_id=model_id, **kwargs) + self._request_admission = request_admission + self._hold_until_active = hold_until_active + self._resource = RequestResourceKey(provider_name, model_id, RequestDomain.CHAT) + self._started_count = 0 + self._active_leases = 0 + self.initial_leases_acquired: asyncio.Event = asyncio.Event() + self.release_held_leases: asyncio.Event = asyncio.Event() + + async def agenerate(self, data: dict) -> dict: + item = RequestAdmissionItem(self._resource, RequestGroupSpec(self._resource)) + lease = await self._request_admission.acquire_async(item) + self._started_count += 1 + started_index = self._started_count + self._active_leases += 1 + if self._active_leases >= self._hold_until_active: + self.initial_leases_acquired.set() + try: + if started_index <= self._hold_until_active: + await self.release_held_leases.wait() + data[self.config.name] = f"gated_{started_index}" + return data + finally: + self._active_leases = max(0, self._active_leases - 1) + self._request_admission.release(lease, RequestReleaseOutcome(kind="success")) + + class _StaticRequestPressureProvider: - def __init__(self, snapshots: dict[RequestResourceKey, RequestPressureSnapshot]) -> None: + def __init__( + self, + snapshots: dict[RequestResourceKey, RequestPressureSnapshot], + ) -> None: self._snapshots = snapshots @property @@ -3114,7 +3179,7 @@ def snapshot(self, resource: RequestResourceKey) -> RequestPressureSnapshot | No def snapshots(self) -> dict[RequestResourceKey, RequestPressureSnapshot]: return dict(self._snapshots) - def global_snapshot(self, provider: str, model: str) -> None: + def global_snapshot(self, _provider: str, _model: str) -> None: return None def global_snapshots(self) -> dict[ProviderModelKey, object]: @@ -3147,6 +3212,44 @@ def _pressure_snapshot( ) +def _build_queued_model_pressure_scheduler( + *, + column: str = "pressured", + provider_name: str = "provider-a", + model_id: str = "model-a", + queued_rows: int = 5, + request_pressure_provider: Any, + scheduler_event_sink: InMemoryAdmissionEventSink | None = None, +) -> AsyncTaskScheduler: + provider = _mock_provider() + config = LLMTextColumnConfig(name=column, prompt="A", model_alias=MODEL_ALIAS) + graph = ExecutionGraph.create([config], {column: GenerationStrategy.CELL_BY_CELL}) + row_groups = [(0, queued_rows)] + generator = SlowModelBoundCellGenerator( + config=_expr_config(column), + resource_provider=provider, + provider_name=provider_name, + model_id=model_id, + delay=30.0, + ) + scheduler = AsyncTaskScheduler( + generators={column: generator}, + graph=graph, + tracker=CompletionTracker.with_graph(graph, row_groups), + row_groups=row_groups, + request_pressure_provider=request_pressure_provider, + request_pressure_advisory=True, + scheduler_event_sink=scheduler_event_sink, + ) + scheduler._rg_states[0] = SimpleNamespace(size=queued_rows, pre_batch_done=True, in_flight_count=0) + tasks = tuple( + scheduler._schedulable_task(Task(column=column, row_group=0, row_index=row_index, task_type="cell")) + for row_index in range(queued_rows) + ) + scheduler._fair_queue.enqueue(tasks) + return scheduler + + @pytest.mark.asyncio(loop_scope="session") async def test_scheduler_fair_admission_across_ready_columns() -> None: """A large ready frontier is admitted across columns instead of one column at a time.""" @@ -3768,7 +3871,7 @@ def test_scheduler_request_pressure_advisory_prefers_pressure_open_peer() -> Non request_pressure_advisory=True, scheduler_event_sink=(sink := InMemoryAdmissionEventSink()), ) - scheduler._rg_states[0] = SimpleNamespace(size=1, pre_batch_done=True) + scheduler._rg_states[0] = SimpleNamespace(size=1, pre_batch_done=True, in_flight_count=0) pressured = scheduler._schedulable_task(Task(column="pressured", row_group=0, row_index=0, task_type="cell")) open_task = scheduler._schedulable_task(Task(column="open", row_group=0, row_index=0, task_type="cell")) scheduler._fair_queue.enqueue((pressured, open_task)) @@ -3785,34 +3888,14 @@ def test_scheduler_request_pressure_advisory_prefers_pressure_open_peer() -> Non def test_scheduler_request_pressure_advisory_preserves_liveness_when_all_candidates_pressured() -> None: - provider = _mock_provider() - configs = [LLMTextColumnConfig(name="pressured", prompt="A", model_alias=MODEL_ALIAS)] - strategies = {"pressured": GenerationStrategy.CELL_BY_CELL} - generators: dict[str, ColumnGenerator] = { - "pressured": SlowModelBoundCellGenerator( - config=_expr_config("pressured"), - resource_provider=provider, - provider_name="provider-a", - model_id="model-a", - ), - } - graph = ExecutionGraph.create(configs, strategies) - tracker = CompletionTracker.with_graph(graph, [(0, 1)]) pressured_key = RequestResourceKey("provider-a", "model-a", RequestDomain.CHAT) pressure = _StaticRequestPressureProvider( {pressured_key: _pressure_snapshot(pressured_key, current_limit=1, in_flight=1, waiters=1)} ) - scheduler = AsyncTaskScheduler( - generators=generators, - graph=graph, - tracker=tracker, - row_groups=[(0, 1)], + scheduler = _build_queued_model_pressure_scheduler( + queued_rows=1, request_pressure_provider=pressure, - request_pressure_advisory=True, ) - scheduler._rg_states[0] = SimpleNamespace(size=1, pre_batch_done=True) - pressured = scheduler._schedulable_task(Task(column="pressured", row_group=0, row_index=0, task_type="cell")) - scheduler._fair_queue.enqueue((pressured,)) selection = scheduler._fair_queue.select_next(scheduler._is_dispatch_eligible) @@ -3820,6 +3903,72 @@ def test_scheduler_request_pressure_advisory_preserves_liveness_when_all_candida assert selection.item.payload.column == "pressured" +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_request_resource_admission_avoids_creating_waiters() -> None: + provider = _mock_provider() + sink = InMemoryAdmissionEventSink() + resource = RequestResourceKey("provider-a", "model-a", RequestDomain.CHAT) + request_admission = AdaptiveRequestAdmissionController( + RequestAdmissionConfig( + initial_limits={resource: 4}, + default_queue_wait_timeout_seconds=0.02, + ), + event_sink=sink, + ) + request_admission.register( + provider_name="provider-a", + model_id="model-a", + alias="primary", + max_parallel_requests=4, + ) + config = LLMTextColumnConfig(name="pressured", prompt="A", model_alias=MODEL_ALIAS) + generator = GatedRequestAdmissionCellGenerator( + config=_expr_config("pressured"), + resource_provider=provider, + request_admission=request_admission, + hold_until_active=3, + provider_name="provider-a", + model_id="model-a", + request_weight=4, + ) + graph = ExecutionGraph.create([config], {"pressured": GenerationStrategy.CELL_BY_CELL}) + row_groups = [(0, 6)] + tracker = CompletionTracker.with_graph(graph, row_groups) + scheduler = AsyncTaskScheduler( + generators={"pressured": generator}, + graph=graph, + tracker=tracker, + row_groups=row_groups, + max_in_flight_tasks=16, + max_model_task_admission=16, + request_pressure_provider=request_admission, + request_pressure_advisory=True, + scheduler_event_sink=sink, + ) + + run_task = asyncio.create_task(scheduler.run()) + try: + await asyncio.wait_for(generator.initial_leases_acquired.wait(), timeout=5.0) + for _ in range(5): + await asyncio.sleep(0) + + waiters = sum(snapshot.waiters for snapshot in request_admission.pressure.snapshots().values()) + assert waiters == 0 + assert len(scheduler._in_flight) == 3 + + generator.release_held_leases.set() + await asyncio.wait_for(run_task, timeout=5.0) + finally: + if not run_task.done(): + run_task.cancel() + with pytest.raises(asyncio.CancelledError): + await run_task + + assert tracker.is_row_group_complete(0, 6, ["pressured"]) + assert not any(tracker.is_dropped(0, row_index) for row_index in range(6)) + assert not any(event.event_kind == "request_wait_timeout" for event in sink.request_events) + + # -- Skip / conditional generation tests (async engine) ----------------------- diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_model_request_executor.py b/packages/data-designer-engine/tests/engine/models/clients/test_model_request_executor.py index 2806ae569..4dbae62e2 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_model_request_executor.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_model_request_executor.py @@ -21,8 +21,13 @@ ImageGenerationResponse, ImagePayload, ) -from data_designer.engine.models.request_admission.controller import AdaptiveRequestAdmissionController -from data_designer.engine.models.request_admission.resources import RequestDomain +from data_designer.engine.models.request_admission.config import RequestAdmissionConfig +from data_designer.engine.models.request_admission.controller import ( + AdaptiveRequestAdmissionController, + RequestAdmissionDenied, + RequestAdmissionError, +) +from data_designer.engine.models.request_admission.resources import RequestAdmissionItem, RequestDomain from data_designer.engine.observability import InMemoryAdmissionEventSink @@ -135,6 +140,16 @@ async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionRes return ChatCompletionResponse(AssistantMessage(content="ok")) +class _DenyingAdmissionController: + def __init__(self, reason: str) -> None: + self.reason = reason + self.acquire_sync_calls = 0 + + def acquire_sync(self, item: RequestAdmissionItem) -> object: + self.acquire_sync_calls += 1 + raise RequestAdmissionError(RequestAdmissionDenied(item=item, reason=self.reason)) + + def _executor() -> tuple[ModelRequestExecutor, AdaptiveRequestAdmissionController, _Client]: controller = AdaptiveRequestAdmissionController() controller.register(provider_name="nvidia", model_id="nemotron", alias="default", max_parallel_requests=1) @@ -217,6 +232,34 @@ def test_model_request_executor_does_not_retry_provider_timeout_without_status() assert client.calls == 1 +def test_model_request_executor_classifies_sync_request_admission_queue_timeout() -> None: + controller = AdaptiveRequestAdmissionController(RequestAdmissionConfig(default_queue_wait_timeout_seconds=0.0)) + controller.register(provider_name="nvidia", model_id="nemotron", alias="default", max_parallel_requests=1) + executor = ModelRequestExecutor(_Client(), controller, "nvidia", "nemotron") + + with pytest.raises(ProviderError) as exc_info: + executor.completion(ChatCompletionRequest(model="nemotron", messages=[])) + + assert exc_info.value.kind == ProviderErrorKind.REQUEST_ADMISSION_TIMEOUT + + +def test_model_request_executor_does_not_retry_non_queue_request_admission_errors() -> None: + controller = _DenyingAdmissionController("hard_policy_denial") + executor = ModelRequestExecutor( + _Client(), + controller, + "nvidia", + "nemotron", + retry_config=RetryConfig(max_retries=2, backoff_factor=0.0), + ) + + with pytest.raises(ProviderError) as exc_info: + executor.completion(ChatCompletionRequest(model="nemotron", messages=[])) + + assert exc_info.value.kind == ProviderErrorKind.TIMEOUT + assert controller.acquire_sync_calls == 1 + + @pytest.mark.asyncio(loop_scope="session") async def test_model_request_executor_retries_async_provider_503_with_fresh_leases() -> None: sink = InMemoryAdmissionEventSink() @@ -243,6 +286,18 @@ async def test_model_request_executor_retries_async_provider_503_with_fresh_leas assert {event.request_lease_id for event in acquired} == {event.request_lease_id for event in released} +@pytest.mark.asyncio(loop_scope="session") +async def test_model_request_executor_classifies_async_request_admission_queue_timeout() -> None: + controller = AdaptiveRequestAdmissionController(RequestAdmissionConfig(default_queue_wait_timeout_seconds=0.0)) + controller.register(provider_name="nvidia", model_id="nemotron", alias="default", max_parallel_requests=1) + executor = ModelRequestExecutor(_Client(), controller, "nvidia", "nemotron") + + with pytest.raises(ProviderError) as exc_info: + await executor.acompletion(ChatCompletionRequest(model="nemotron", messages=[])) + + assert exc_info.value.kind == ProviderErrorKind.REQUEST_ADMISSION_TIMEOUT + + @pytest.mark.asyncio(loop_scope="session") async def test_model_request_executor_releases_async_cancellation() -> None: class _SlowClient(_Client): diff --git a/packages/data-designer-engine/tests/engine/models/test_model_errors.py b/packages/data-designer-engine/tests/engine/models/test_model_errors.py index b72b26003..8873aceba 100644 --- a/packages/data-designer-engine/tests/engine/models/test_model_errors.py +++ b/packages/data-designer-engine/tests/engine/models/test_model_errors.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import re from typing import Any from unittest.mock import MagicMock @@ -22,6 +24,7 @@ ModelPermissionDeniedError, ModelQuotaExceededError, ModelRateLimitError, + ModelRequestAdmissionTimeoutError, ModelTimeoutError, ModelUnprocessableEntityError, ModelUnsupportedCapabilityError, @@ -128,6 +131,14 @@ ModelTimeoutError, f"Cause: The request to model '{stub_model_name}' timed out while {stub_purpose}.", ), + ( + ProviderError( + kind=ProviderErrorKind.REQUEST_ADMISSION_TIMEOUT, + message="Request admission failed", + ), + ModelRequestAdmissionTimeoutError, + f"Cause: Local request admission for model '{stub_model_name}' timed out while {stub_purpose}; the provider request was not sent.", + ), ( ProviderError( kind=ProviderErrorKind.NOT_FOUND, @@ -202,6 +213,7 @@ "authentication", "api_connection", "timeout", + "request_admission_timeout", "not_found", "internal_server", "unprocessable_entity",