Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion architecture/dataset-builders.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ Preparation (`_prepare_async_run`):
4. Constructs `CompletionTracker`, `RowGroupBufferManager`, `AsyncTaskScheduler`
5. Hooks `ProcessorRunner` for pre-batch and post-batch stages

`AsyncTaskScheduler` runs on a dedicated async loop with frontier-driven dispatch, task-admission leases, salvage rounds for failed tasks, and order-dependent locks for columns that must execute sequentially. Ready frontier tasks enter `FairTaskQueue`, are selected through virtual-time ordering, and are committed only after `TaskAdmissionController` acquires the required scheduler resources. Salvage-exhausted tasks are dropped except for rate-limit failures, which stay deferred and retry after cooldown/backoff so 429s delay records rather than discard them.
`AsyncTaskScheduler` runs on a dedicated async loop with frontier-driven dispatch, task-admission leases, salvage rounds for failed tasks, and order-dependent locks for columns that must execute sequentially. Ready frontier tasks enter `FairTaskQueue`, are selected through virtual-time ordering, and are committed only after `TaskAdmissionController` acquires the required scheduler resources. Salvage-exhausted tasks are dropped except for preserved retryable failures: provider rate limits and local request-admission queue timeouts stay deferred and retry after cooldown/backoff so scheduler-local pressure delays records rather than discarding them.

### Execution Graph

Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from data_designer.config.scheduling import SchedulingMetadata, SchedulingMetadataError
from data_designer.engine.dataset_builders.scheduling.resources import (
SchedulableTask,
SchedulerResourceKey,
SchedulerResourceRequest,
TaskGroupKey,
TaskGroupSpec,
request_scheduler_resource_key,
stable_task_id,
)
from data_designer.engine.dataset_builders.scheduling.task_model import Task
Expand Down Expand Up @@ -48,11 +50,16 @@ def __init__(
self._diagnostics: list[dict[str, object]] = []
for generator in dict.fromkeys(generators.values()):
self._metadata_by_generator_id[id(generator)] = self._resolve_metadata(generator)
self._request_resource_limits = self._build_request_resource_limits()

@property
def diagnostics(self) -> tuple[dict[str, object], ...]:
return tuple(self._diagnostics)

@property
def request_resource_limits(self) -> Mapping[SchedulerResourceKey, int]:
return dict(self._request_resource_limits)

def scheduling_for_task(self, task: Task, flow_identity: tuple[str, ...]) -> ResolvedTaskScheduling:
generator = self._generators[task.column]
metadata = self._metadata_by_generator_id[id(generator)]
Expand Down Expand Up @@ -100,16 +107,30 @@ def _resolved_from_metadata(
identity = (*metadata.identity, *flow_identity)
admitted_limit = max(1, min(self._model_group_limit_cap, self._model_group_limit_multiplier * weight))
request_resource_key = _request_resource_key(metadata)
resource_request = {"submission": 1, "llm_wait": 1}
if request_resource_key is not None:
resource_request[request_scheduler_resource_key(request_resource_key)] = 1
return ResolvedTaskScheduling(
group=TaskGroupSpec(
key=TaskGroupKey(kind=metadata.kind, identity=identity),
weight=float(weight),
admitted_limit=admitted_limit,
),
resource_request=SchedulerResourceRequest({"submission": 1, "llm_wait": 1}),
resource_request=SchedulerResourceRequest(resource_request),
request_resource_key=request_resource_key,
)

def _build_request_resource_limits(self) -> dict[SchedulerResourceKey, int]:
limits: dict[SchedulerResourceKey, int] = {}
for metadata in self._metadata_by_generator_id.values():
resource = _request_resource_key(metadata)
if resource is None:
continue
key = request_scheduler_resource_key(resource)
cap = max(1, metadata.weight)
limits[key] = min(limits.get(key, cap), cap)
return limits


def _request_resource_key(metadata: SchedulingMetadata) -> RequestResourceKey | None:
if metadata.kind != "model":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from data_designer.engine.dataset_builders.scheduling.task_model import Task
from data_designer.engine.models.request_admission.resources import RequestResourceKey

SchedulerResourceKey = Literal["submission", "llm_wait", "local"]
SchedulerResourceKey = str


@dataclass(frozen=True, order=True)
Expand Down Expand Up @@ -39,8 +39,8 @@ class SchedulerResourceRequest:

def __post_init__(self) -> None:
for resource, amount in self.amounts.items():
if resource not in {"submission", "llm_wait", "local"}:
raise ValueError(f"Unknown scheduler resource key: {resource!r}")
if not isinstance(resource, str) or not resource:
raise ValueError(f"Scheduler resource key must be a non-empty string, got {resource!r}.")
if not isinstance(amount, int) or amount <= 0:
raise ValueError(f"Scheduler resource amount for {resource!r} must be a positive integer.")

Expand All @@ -61,3 +61,8 @@ def stable_task_id(task: Task) -> str:
raw = f"{task.column}\0{task.row_group}\0{task.row_index}\0{task.task_type}".encode()
digest = hashlib.sha1(raw).hexdigest()[:16]
return f"task-{digest}"


def request_scheduler_resource_key(resource: RequestResourceKey) -> SchedulerResourceKey:
"""Return the scheduler task-stage resource for a provider/model request pool."""
return f"request:{resource.provider_name}/{resource.model_id}"
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class ProviderErrorKind(str, Enum):
NOT_FOUND = "not_found"
PERMISSION_DENIED = "permission_denied"
RATE_LIMIT = "rate_limit"
REQUEST_ADMISSION_TIMEOUT = "request_admission_timeout"
TIMEOUT = "timeout"
UNPROCESSABLE_ENTITY = "unprocessable_entity"
UNSUPPORTED_CAPABILITY = "unsupported_capability"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,7 @@ def _execute_sync_attempt(self, domain: RequestDomain, call: Callable[[], _T]) -
try:
lease = self._request_admission.acquire_sync(item)
except RequestAdmissionError as exc:
raise ProviderError(
kind=ProviderErrorKind.TIMEOUT,
message=str(exc),
provider_name=self._provider_name,
model_name=self._model_id,
) from exc
raise self._provider_error_from_request_admission(exc) from exc
try:
self._emit_model_event("model_request_started", item=item, lease=lease)
result = call()
Expand Down Expand Up @@ -169,12 +164,7 @@ async def _execute_async_attempt(self, domain: RequestDomain, call: Callable[[],
try:
lease = await self._request_admission.acquire_async(item)
except RequestAdmissionError as exc:
raise ProviderError(
kind=ProviderErrorKind.TIMEOUT,
message=str(exc),
provider_name=self._provider_name,
model_name=self._model_id,
) from exc
raise self._provider_error_from_request_admission(exc) from exc
except asyncio.CancelledError:
raise
try:
Expand Down Expand Up @@ -216,7 +206,7 @@ def _max_attempts(self) -> int:
def _should_retry(self, exc: ProviderError, attempt: int) -> bool:
if attempt >= self._max_attempts() - 1:
return False
if isinstance(exc.__cause__, RequestAdmissionError):
if exc.kind == ProviderErrorKind.REQUEST_ADMISSION_TIMEOUT:
return False
if exc.kind == ProviderErrorKind.RATE_LIMIT:
return False
Expand Down Expand Up @@ -249,6 +239,19 @@ def _release_provider_error(self, lease: RequestAdmissionLease, exc: ProviderErr
outcome = RequestReleaseOutcome(kind="provider_failure")
self._request_admission.release(lease, outcome)

def _provider_error_from_request_admission(self, exc: RequestAdmissionError) -> ProviderError:
kind = (
ProviderErrorKind.REQUEST_ADMISSION_TIMEOUT
if exc.decision.reason == "queue_timeout"
else ProviderErrorKind.TIMEOUT
)
return ProviderError(
kind=kind,
message=str(exc),
provider_name=self._provider_name,
model_name=self._model_id,
)

def _item(self, domain: RequestDomain) -> RequestAdmissionItem:
resolved = self._resource_resolver.resolve(
provider_name=self._provider_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ class ModelQuotaExceededError(DataDesignerError): ...
class ModelTimeoutError(DataDesignerError): ...


class ModelRequestAdmissionTimeoutError(ModelTimeoutError): ...


class ModelContextWindowExceededError(DataDesignerError): ...


Expand Down Expand Up @@ -303,6 +306,7 @@ def _raise_from_provider_error(
_KIND_MAP: dict[ProviderErrorKind, type[DataDesignerError]] = {
ProviderErrorKind.RATE_LIMIT: ModelRateLimitError,
ProviderErrorKind.QUOTA_EXCEEDED: ModelQuotaExceededError,
ProviderErrorKind.REQUEST_ADMISSION_TIMEOUT: ModelRequestAdmissionTimeoutError,
ProviderErrorKind.TIMEOUT: ModelTimeoutError,
ProviderErrorKind.NOT_FOUND: ModelNotFoundError,
ProviderErrorKind.PERMISSION_DENIED: ModelPermissionDeniedError,
Expand All @@ -321,6 +325,10 @@ def _raise_from_provider_error(
f"The request to model {model_name!r} timed out while {purpose}.",
"Check your connection and try again. You may need to increase the timeout setting for the model.",
),
ProviderErrorKind.REQUEST_ADMISSION_TIMEOUT: (
f"Local request admission for model {model_name!r} timed out while {purpose}; the provider request was not sent.",
"Reduce request concurrency or tune the model's max_parallel_requests to match the endpoint's real capacity. For async dataset generation, also consider lowering RunConfig.max_in_flight_tasks.",
),
ProviderErrorKind.NOT_FOUND: (
f"The specified model {model_name!r} could not be found while {purpose}.",
f"Check that the model name is correct and supported by your model provider {model_provider_name!r} and try again.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,13 @@ def test_task_scheduling_resolver_maps_model_metadata_to_model_resource() -> Non
assert schedulable.group.key.kind == "model"
assert schedulable.group.weight == 3.0
assert schedulable.group.admitted_limit == 6
assert schedulable.resource_request.amounts == {"submission": 1, "llm_wait": 1}
assert schedulable.resource_request.amounts == {
"submission": 1,
"llm_wait": 1,
"request:nvidia/nemotron": 1,
}
assert schedulable.request_resource_key == RequestResourceKey("nvidia", "nemotron", RequestDomain.CHAT)
assert resolver.request_resource_limits == {"request:nvidia/nemotron": 3}


def test_task_scheduling_resolver_records_safe_fallback_diagnostics() -> None:
Expand Down Expand Up @@ -127,6 +132,7 @@ def generate(self, data: object) -> object:
resolver = TaskSchedulingResolver({"answer": generator}) # type: ignore[arg-type]
schedulable = resolver.schedulable_task(_task(), ("answer",))
assert schedulable.request_resource_key == RequestResourceKey("nvidia", "endpoint", RequestDomain.CHAT)
assert resolver.request_resource_limits == {"request:nvidia/endpoint": 2}


def test_model_registry_generator_metadata_uses_custom_model_for_multi_endpoint_aliases() -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,15 @@ def test_scheduler_resource_request_defaults_to_submission() -> None:
assert request.amounts == {"submission": 1}


def test_scheduler_resource_request_rejects_unknown_resource() -> None:
with pytest.raises(ValueError, match="Unknown scheduler resource key"):
SchedulerResourceRequest({"gpu": 1}) # type: ignore[arg-type]
def test_scheduler_resource_request_accepts_dynamic_resource_keys() -> None:
request = SchedulerResourceRequest({"request:nvidia/nemotron": 1})

assert request.amounts == {"request:nvidia/nemotron": 1}


def test_scheduler_resource_request_rejects_empty_resource_key() -> None:
with pytest.raises(ValueError, match="non-empty string"):
SchedulerResourceRequest({"": 1})


def test_scheduler_resource_request_rejects_non_positive_amounts() -> None:
Expand Down
Loading
Loading