diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index 53b83b220..ce3b1d11e 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -28,6 +28,11 @@ from data_designer.engine.context import current_row_group, current_row_group_start_offset from data_designer.engine.dataset_builders.errors import DatasetGenerationError from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig +from data_designer.engine.dataset_builders.row_group_plan import ( + RowGroupInput, + RowGroupPlanLike, + normalize_row_group_plan, +) from data_designer.engine.dataset_builders.scheduling.completion import CompletionTracker, FrontierDelta from data_designer.engine.dataset_builders.scheduling.queue import ( FairTaskQueue, @@ -146,7 +151,7 @@ def __init__( generators: dict[str, ColumnGenerator], graph: ExecutionGraph, tracker: CompletionTracker, - row_groups: list[tuple[int, int]], + row_groups: RowGroupInput, buffer_manager: RowGroupBufferManager | None = None, *, max_concurrent_row_groups: int = 3, @@ -170,7 +175,6 @@ def __init__( progress_bar: bool = False, scheduler_event_sink: SchedulerAdmissionEventSink | None = None, run_id: str | None = None, - row_group_start_offsets: dict[int, int] | None = None, initial_completed_records: int = 0, adaptive_row_group_admission: bool = False, adaptive_row_group_initial_target: int = 1, @@ -180,7 +184,7 @@ def __init__( self._generators = generators self._graph = graph self._tracker = tracker - self._row_groups = row_groups + self._row_groups: RowGroupPlanLike = normalize_row_group_plan(row_groups) self._buffer_manager = buffer_manager self._rg_semaphore = asyncio.Semaphore(max_concurrent_row_groups) @@ -288,17 +292,12 @@ def __init__( self._first_non_retryable_error: Exception | None = None self._fatal_worker_error: BaseException | None = None - # Pre-compute row-group sizes for O(1) lookup - self._rg_size_map: dict[int, int] = dict(row_groups) - self._rg_start_offset_map: dict[int, int] = row_group_start_offsets or self._build_row_group_start_offsets( - row_groups - ) self._max_concurrent_row_groups = max_concurrent_row_groups self._max_in_flight_tasks = max_in_flight_tasks self._max_model_task_admission = max_model_task_admission self._num_records = num_records self._buffer_size = buffer_size - self._scheduled_records = sum(size for _, size in row_groups) + self._scheduled_records = self._row_groups.scheduled_total_rows self._initial_completed_records = initial_completed_records self._observed_max_row_groups_in_flight = 0 self._observed_max_task_leases_by_resource: dict[str, int] = {} @@ -331,15 +330,6 @@ def __init__( self._progress_bar = StickyProgressBar() if progress_bar else None self._reporter = self._setup_async_progress_reporter(num_records, buffer_size, progress_interval) - @staticmethod - def _build_row_group_start_offsets(row_groups: list[tuple[int, int]]) -> dict[int, int]: - offsets: dict[int, int] = {} - next_offset = 0 - for rg_id, rg_size in row_groups: - offsets[rg_id] = next_offset - next_offset += rg_size - return offsets - def _setup_async_progress_reporter( self, num_records: int, @@ -531,7 +521,6 @@ def _scheduler_health_diagnostics(self, *, reason: str) -> dict[str, object]: } def _scheduler_job_diagnostics(self) -> dict[str, object]: - row_group_sizes = [size for _rg_id, size in self._row_groups] strategies = {column: self._graph.get_strategy(column).value for column in self._graph.columns} task_count_by_strategy = Counter(strategies.values()) return { @@ -539,9 +528,9 @@ def _scheduler_job_diagnostics(self) -> dict[str, object]: "num_records": self._num_records, "buffer_size": self._buffer_size, "row_group_count": len(self._row_groups), - "row_group_total_rows": sum(row_group_sizes), - "row_group_min_size": min(row_group_sizes, default=0), - "row_group_max_size": max(row_group_sizes, default=0), + "row_group_total_rows": self._row_groups.scheduled_total_rows, + "row_group_min_size": self._row_groups.row_group_min_size, + "row_group_max_size": self._row_groups.row_group_max_size, "graph_column_count": len(self._graph.columns), "graph_root_columns": tuple(self._graph.get_root_columns()), "graph_depth": len(self._graph.get_longest_dependency_chain()), @@ -859,8 +848,7 @@ def _task_flow_identity(self, task: Task) -> tuple[str, ...]: def _max_admitted_rows_guardrail(self) -> int: if self._num_records > 0 and self._buffer_size > 0: return min(self._num_records, max(3 * self._buffer_size, 8192)) - total_rows = sum(size for _rg_id, size in self._row_groups) - return max(1, total_rows) + return max(1, self._row_groups.scheduled_total_rows) async def _wait_for_row_group_admission_capacity(self, row_group_size: int) -> None: while True: @@ -1543,7 +1531,7 @@ async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None: seed_cols = self._seed_cols if not seed_cols: return - num_rgs = len(self._rg_size_map) + num_rgs = len(self._row_groups) width = len(str(num_rgs)) logger.info(f"🚀 ({rg_id + 1:0{width}d}/{num_rgs}) Dispatching with {rg_size} records") seen_instances: set[int] = set() @@ -1567,7 +1555,7 @@ async def _execute_task_inner(self, task: Task, lease: TaskAdmissionLease, task_ """Core task execution logic.""" num_rgs = len(self._row_groups) token = current_row_group.set((task.row_group, num_rgs)) - start_offset_token = current_row_group_start_offset.set(self._rg_start_offset_map.get(task.row_group)) + start_offset_token = current_row_group_start_offset.set(self._get_rg_start_offset(task.row_group)) group = lease.item.group identity_hash = hashlib.sha1("\0".join(group.key.identity).encode()).hexdigest()[:16] correlation_token = runtime_correlation_provider.set( @@ -1961,10 +1949,16 @@ async def _run_batch(self, task: Task, generator: ColumnGenerator) -> Any: def _get_rg_size(self, row_group: int) -> int: try: - return self._rg_size_map[row_group] + return self._row_groups.row_group_size(row_group) except KeyError: raise ValueError(f"Unknown row group: {row_group}") from None + def _get_rg_start_offset(self, row_group: int) -> int | None: + try: + return self._row_groups.row_group_start_offset(row_group) + except KeyError: + return None + def task_admission_snapshot(self) -> object: """Return the current scheduler task-admission snapshot for diagnostics.""" return self._task_admission.view() diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py index 43f5d7e71..efa68edf5 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py @@ -38,6 +38,12 @@ from data_designer.engine.context import current_row_group, current_row_group_start_offset from data_designer.engine.dataset_builders.errors import DatasetGenerationError from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig +from data_designer.engine.dataset_builders.row_group_plan import ( + CompactRowGroupPlan, + RowGroupInput, + RowGroupPlanLike, + normalize_row_group_plan, +) from data_designer.engine.dataset_builders.utils.concurrency import ConcurrentThreadExecutor from data_designer.engine.dataset_builders.utils.config_compiler import compile_dataset_builder_column_configs from data_designer.engine.dataset_builders.utils.dataset_batch_manager import DatasetBatchManager @@ -125,16 +131,11 @@ class RowGroupResumePlan: Attributes: total_row_groups: Total row group count for the full target (original + extension). - remaining_row_groups: ``(rg_id, rg_size)`` for groups not yet on disk, in id order. - row_group_start_offsets: ``rg_id -> planned start offset`` for the remaining - groups, computed from the original plan so completed-group sizes are - preserved (offsets are not recomputed from the remaining list, which - would shift them when there are holes). + remaining_row_groups: lazy plan of ``(rg_id, rg_size)`` for groups not yet on disk, in id order. """ total_row_groups: int - remaining_row_groups: list[tuple[int, int]] - row_group_start_offsets: dict[int, int] + remaining_row_groups: CompactRowGroupPlan def build_row_group_resume_plan( @@ -160,32 +161,20 @@ def build_row_group_resume_plan( completed_ids: Row-group IDs already persisted on disk. Returns: - A ``RowGroupResumePlan`` whose ``row_group_start_offsets`` are taken from - the full original plan, so the offset for ``rg_id`` is the same whether - or not earlier groups have completed. This is what lets ordered seed - generators seek to the correct row when resuming with holes. + A ``RowGroupResumePlan`` whose remaining row-group plan preserves full + original offsets, so the offset for ``rg_id`` is the same whether or not + earlier groups have completed. This is what lets ordered seed generators + seek to the correct row when resuming with holes. """ - num_original_groups = -(-original_target // buffer_size) - extension_records = num_records - original_target - total_row_groups = num_original_groups + -(-extension_records // buffer_size) - - def _rg_size(rg_id: int) -> int: - if rg_id < num_original_groups: - return min(buffer_size, original_target - rg_id * buffer_size) - ext_group_idx = rg_id - num_original_groups - return min(buffer_size, extension_records - ext_group_idx * buffer_size) - - all_start_offsets: dict[int, int] = {} - next_offset = 0 - for rg_id in range(total_row_groups): - all_start_offsets[rg_id] = next_offset - next_offset += _rg_size(rg_id) - - remaining_row_groups = [(rg_id, _rg_size(rg_id)) for rg_id in range(total_row_groups) if rg_id not in completed_ids] + remaining_row_groups = CompactRowGroupPlan.resume( + original_target=original_target, + num_records=num_records, + buffer_size=buffer_size, + completed_ids=completed_ids, + ) return RowGroupResumePlan( - total_row_groups=total_row_groups, + total_row_groups=remaining_row_groups.total_row_groups, remaining_row_groups=remaining_row_groups, - row_group_start_offsets={rg_id: all_start_offsets[rg_id] for rg_id, _ in remaining_row_groups}, ) @@ -571,6 +560,13 @@ def _load_resume_state(self, num_records: int, buffer_size: int) -> _ResumeState "(you may extend the dataset beyond the original target). " "Use resume=ResumeMode.NEVER to start a new run." ) + original_target_num_records = metadata.get("original_target_num_records", target_num_records) + if original_target_num_records > target_num_records: + raise DatasetGenerationError( + "🛑 Cannot resume: metadata.json has original_target_num_records=" + f"{original_target_num_records}, which is greater than target_num_records={target_num_records}. " + "Start a fresh run with resume=ResumeMode.NEVER, or restore a valid metadata.json." + ) meta_buffer_size = metadata.get("buffer_size") if meta_buffer_size != buffer_size: @@ -593,7 +589,7 @@ def _load_resume_state(self, num_records: int, buffer_size: int) -> _ResumeState actual_num_records=actual_num_records, buffer_size=buffer_size, target_num_records=target_num_records, - original_target_num_records=metadata.get("original_target_num_records", target_num_records), + original_target_num_records=original_target_num_records, completed_row_groups=completed_row_groups, ) @@ -916,8 +912,7 @@ def _build_async( settings = self._resource_provider.run_config trace_enabled = _is_async_trace_enabled(settings) - precomputed_row_groups: list[tuple[int, int]] | None = None - row_group_start_offsets: dict[int, int] | None = None + precomputed_row_groups: RowGroupInput | None = None initial_actual_num_records = 0 initial_total_num_batches = 0 original_target = num_records # immutable original target; overridden on resume @@ -944,7 +939,9 @@ def _build_async( buffer_size=buffer_size, completed_ids=completed_ids, ) - if len(completed_ids) >= resume_plan.total_row_groups: + remaining_row_group_count = len(resume_plan.remaining_row_groups) + completed_row_group_count = resume_plan.total_row_groups - remaining_row_group_count + if remaining_row_group_count == 0: logger.warning( "⚠️ Dataset is already complete — all row groups were found in the existing artifact " "directory. Nothing to resume. Use resume=ResumeMode.NEVER if you want to generate a new dataset." @@ -952,12 +949,11 @@ def _build_async( return False logger.info( - f"▶️ Resuming async run: {len(completed_ids)} of {resume_plan.total_row_groups} row group(s) already " - f"complete ({initial_actual_num_records} records), skipping them." + f"▶️ Resuming async run: {completed_row_group_count} of {resume_plan.total_row_groups} row group(s) " + f"already complete ({initial_actual_num_records} records), skipping them." ) precomputed_row_groups = resume_plan.remaining_row_groups - row_group_start_offsets = resume_plan.row_group_start_offsets def finalize_row_group(rg_id: int) -> None: def on_complete(final_path: Path | str | None) -> None: @@ -982,7 +978,6 @@ def on_complete(final_path: Path | str | None) -> None: disable_early_shutdown=settings.disable_early_shutdown, trace=trace_enabled, precomputed_row_groups=precomputed_row_groups, - row_group_start_offsets=row_group_start_offsets, initial_actual_num_records=initial_actual_num_records, initial_total_num_batches=initial_total_num_batches, ) @@ -1051,8 +1046,7 @@ def _prepare_async_run( shutdown_error_window: int = 10, disable_early_shutdown: bool = False, trace: bool = False, - precomputed_row_groups: list[tuple[int, int]] | None = None, - row_group_start_offsets: dict[int, int] | None = None, + precomputed_row_groups: RowGroupInput | None = None, initial_actual_num_records: int = 0, initial_total_num_batches: int = 0, ) -> tuple[AsyncTaskScheduler, RowGroupBufferManager]: @@ -1078,16 +1072,9 @@ def _prepare_async_run( gen.log_pre_generation() if precomputed_row_groups is not None: - row_groups = precomputed_row_groups + row_groups: RowGroupPlanLike = normalize_row_group_plan(precomputed_row_groups) else: - row_groups = [] - remaining = num_records - rg_id = 0 - while remaining > 0: - size = min(buffer_size, remaining) - row_groups.append((rg_id, size)) - remaining -= size - rg_id += 1 + row_groups = CompactRowGroupPlan.fresh(num_records=num_records, buffer_size=buffer_size) tracker = CompletionTracker.with_graph(graph, row_groups) buffer_manager = RowGroupBufferManager( @@ -1143,7 +1130,6 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None: trace=trace, num_records=num_records, buffer_size=buffer_size, - row_group_start_offsets=row_group_start_offsets, initial_completed_records=initial_actual_num_records, progress_interval=self._resource_provider.run_config.progress_interval, progress_bar=self._resource_provider.run_config.progress_bar, diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/row_group_plan.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/row_group_plan.py new file mode 100644 index 000000000..1a23fabb6 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/row_group_plan.py @@ -0,0 +1,303 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Iterator, Sequence +from dataclasses import InitVar, dataclass, field +from typing import Protocol + + +class RowGroupPlanLike(Protocol): + """Shared scheduler-facing interface for row-group plans.""" + + @property + def scheduled_total_rows(self) -> int: ... + + @property + def row_group_min_size(self) -> int: ... + + @property + def row_group_max_size(self) -> int: ... + + def __iter__(self) -> Iterator[tuple[int, int]]: ... + + def __len__(self) -> int: ... + + def has_row_group(self, row_group: int) -> bool: ... + + def row_group_size(self, row_group: int) -> int: + """Return the scheduled size for ``row_group``. + + Raises: + KeyError: If ``row_group`` is not part of this plan. + """ + ... + + def row_group_start_offset(self, row_group: int) -> int: + """Return the original dataset start offset for ``row_group``. + + Raises: + KeyError: If ``row_group`` is not part of this plan. + """ + ... + + def describe_known_row_groups(self) -> str: ... + + +def _ceil_div(numerator: int, denominator: int) -> int: + if numerator <= 0: + return 0 + return -(-numerator // denominator) + + +@dataclass(frozen=True, slots=True) +class CompactRowGroupPlan: + """Lazy row-group plan for fresh and resumed async runs.""" + + original_target: int + num_records: int + buffer_size: int + completed_ids: InitVar[set[int] | frozenset[int]] = frozenset() + + _num_original_groups: int = field(init=False, repr=False) + _extension_records: int = field(init=False, repr=False) + _total_row_groups: int = field(init=False, repr=False) + _id_filter: frozenset[int] = field(init=False, repr=False) + _filter_includes_scheduled: bool = field(init=False, repr=False) + _scheduled_ids: tuple[int, ...] | None = field(init=False, repr=False) + _scheduled_count: int = field(init=False, repr=False) + _scheduled_total_rows: int = field(init=False, repr=False) + _scheduled_full_group_count: int = field(init=False, repr=False) + _partial_remaining_sizes: tuple[int, ...] = field(init=False, repr=False) + + def __post_init__(self, completed_ids: set[int] | frozenset[int]) -> None: + if self.original_target < 0: + raise ValueError("original_target must be non-negative.") + if self.num_records < 0: + raise ValueError("num_records must be non-negative.") + if self.num_records < self.original_target: + raise ValueError("num_records must be greater than or equal to original_target.") + if max(self.original_target, self.num_records) > 0 and self.buffer_size <= 0: + raise ValueError("buffer_size must be positive when row groups are present.") + + num_original_groups = _ceil_div(self.original_target, self.buffer_size) if self.buffer_size > 0 else 0 + extension_records = self.num_records - self.original_target + num_extension_groups = _ceil_div(extension_records, self.buffer_size) if self.buffer_size > 0 else 0 + total_row_groups = num_original_groups + num_extension_groups + + valid_completed_count = sum(1 for rg_id in completed_ids if 0 <= rg_id < total_row_groups) + # Keep the retained filter proportional to the smaller side of the resume frontier. + if valid_completed_count > total_row_groups // 2: + scheduled_ids = tuple(rg_id for rg_id in range(total_row_groups) if rg_id not in completed_ids) + id_filter = frozenset(scheduled_ids) + filter_includes_scheduled = True + scheduled_sizes = tuple( + self._row_group_size_for(rg_id, num_original_groups, extension_records) for rg_id in scheduled_ids + ) + scheduled_count = len(scheduled_ids) + scheduled_total_rows = sum(scheduled_sizes) + scheduled_full_group_count = sum(1 for size in scheduled_sizes if size == self.buffer_size) + partial_remaining_sizes = tuple(size for size in scheduled_sizes if size != self.buffer_size) + else: + id_filter = frozenset(rg_id for rg_id in completed_ids if 0 <= rg_id < total_row_groups) + filter_includes_scheduled = False + scheduled_ids = None + completed_rows = sum( + self._row_group_size_for(rg_id, num_original_groups, extension_records) for rg_id in id_filter + ) + scheduled_count = total_row_groups - len(id_filter) + scheduled_total_rows = self.num_records - completed_rows + completed_full_group_count = sum( + 1 + for rg_id in id_filter + if self._row_group_size_for(rg_id, num_original_groups, extension_records) == self.buffer_size + ) + scheduled_full_group_count = self._count_full_groups(extension_records) - completed_full_group_count + partial_remaining_sizes = tuple( + size + for rg_id, size in self._partial_group_sizes(num_original_groups, extension_records) + if rg_id not in id_filter + ) + + object.__setattr__(self, "_num_original_groups", num_original_groups) + object.__setattr__(self, "_extension_records", extension_records) + object.__setattr__(self, "_total_row_groups", total_row_groups) + object.__setattr__(self, "_id_filter", id_filter) + object.__setattr__(self, "_filter_includes_scheduled", filter_includes_scheduled) + object.__setattr__(self, "_scheduled_ids", scheduled_ids) + object.__setattr__(self, "_scheduled_count", scheduled_count) + object.__setattr__(self, "_scheduled_total_rows", scheduled_total_rows) + object.__setattr__(self, "_scheduled_full_group_count", scheduled_full_group_count) + object.__setattr__(self, "_partial_remaining_sizes", partial_remaining_sizes) + + @classmethod + def fresh(cls, *, num_records: int, buffer_size: int) -> CompactRowGroupPlan: + return cls(original_target=num_records, num_records=num_records, buffer_size=buffer_size) + + @classmethod + def resume( + cls, + *, + original_target: int, + num_records: int, + buffer_size: int, + completed_ids: set[int], + ) -> CompactRowGroupPlan: + return cls( + original_target=original_target, + num_records=num_records, + buffer_size=buffer_size, + completed_ids=frozenset(completed_ids), + ) + + def __iter__(self) -> Iterator[tuple[int, int]]: + if self._scheduled_ids is not None: + for rg_id in self._scheduled_ids: + yield rg_id, self.row_group_size(rg_id) + return + for rg_id in range(self._total_row_groups): + if rg_id not in self._id_filter: + yield rg_id, self.row_group_size(rg_id) + + def __len__(self) -> int: + return self._scheduled_count + + @property + def total_row_groups(self) -> int: + return self._total_row_groups + + @property + def scheduled_total_rows(self) -> int: + return self._scheduled_total_rows + + @property + def row_group_min_size(self) -> int: + if self._scheduled_count == 0: + return 0 + candidates = list(self._partial_remaining_sizes) + if self._scheduled_full_group_count > 0: + candidates.append(self.buffer_size) + return min(candidates) + + @property + def row_group_max_size(self) -> int: + if self._scheduled_count == 0: + return 0 + candidates = list(self._partial_remaining_sizes) + if self._scheduled_full_group_count > 0: + candidates.append(self.buffer_size) + return max(candidates) + + def has_row_group(self, row_group: int) -> bool: + if row_group < 0 or row_group >= self._total_row_groups: + return False + if self._filter_includes_scheduled: + return row_group in self._id_filter + return row_group not in self._id_filter + + def row_group_size(self, row_group: int) -> int: + if not self.has_row_group(row_group): + raise KeyError(row_group) + return self._row_group_size_for(row_group, self._num_original_groups, self._extension_records) + + def row_group_start_offset(self, row_group: int) -> int: + if not self.has_row_group(row_group): + raise KeyError(row_group) + if row_group < self._num_original_groups: + return row_group * self.buffer_size + return self.original_target + (row_group - self._num_original_groups) * self.buffer_size + + def describe_known_row_groups(self) -> str: + if self._scheduled_count == self._total_row_groups: + return f"0..{self._total_row_groups - 1}" if self._total_row_groups else "none" + return f"{self._scheduled_count} scheduled of {self._total_row_groups} total row groups" + + def _count_full_groups(self, extension_records: int) -> int: + if self.buffer_size <= 0: + return 0 + original_full = self.original_target // self.buffer_size + extension_full = extension_records // self.buffer_size if extension_records > 0 else 0 + return original_full + extension_full + + def _partial_group_sizes(self, num_original_groups: int, extension_records: int) -> tuple[tuple[int, int], ...]: + partials: list[tuple[int, int]] = [] + if self.original_target > 0 and self.original_target % self.buffer_size != 0: + partials.append((num_original_groups - 1, self.original_target % self.buffer_size)) + if extension_records > 0 and extension_records % self.buffer_size != 0: + partials.append( + ( + num_original_groups + _ceil_div(extension_records, self.buffer_size) - 1, + extension_records % self.buffer_size, + ) + ) + return tuple(partials) + + def _row_group_size_for(self, rg_id: int, num_original_groups: int, extension_records: int) -> int: + if rg_id < num_original_groups: + return min(self.buffer_size, self.original_target - rg_id * self.buffer_size) + ext_group_idx = rg_id - num_original_groups + return min(self.buffer_size, extension_records - ext_group_idx * self.buffer_size) + + +@dataclass(frozen=True, slots=True) +class ExplicitRowGroupPlan: + """Adapter for already-materialized row-group tuples used by tests and small callers.""" + + row_groups: tuple[tuple[int, int], ...] + + _sizes: dict[int, int] = field(init=False, repr=False) + _start_offsets: dict[int, int] = field(init=False, repr=False) + _scheduled_total_rows: int = field(init=False, repr=False) + + def __post_init__(self) -> None: + sizes: dict[int, int] = {} + start_offsets: dict[int, int] = {} + next_offset = 0 + for rg_id, rg_size in self.row_groups: + sizes[rg_id] = rg_size + start_offsets[rg_id] = next_offset + next_offset += rg_size + object.__setattr__(self, "_sizes", sizes) + object.__setattr__(self, "_start_offsets", start_offsets) + object.__setattr__(self, "_scheduled_total_rows", next_offset) + + def __iter__(self) -> Iterator[tuple[int, int]]: + return iter(self.row_groups) + + def __len__(self) -> int: + return len(self.row_groups) + + @property + def scheduled_total_rows(self) -> int: + return self._scheduled_total_rows + + @property + def row_group_min_size(self) -> int: + return min(self._sizes.values(), default=0) + + @property + def row_group_max_size(self) -> int: + return max(self._sizes.values(), default=0) + + def has_row_group(self, row_group: int) -> bool: + return row_group in self._sizes + + def row_group_size(self, row_group: int) -> int: + return self._sizes[row_group] + + def row_group_start_offset(self, row_group: int) -> int: + return self._start_offsets[row_group] + + def describe_known_row_groups(self) -> str: + known = sorted(self._sizes) + return str(known) + + +RowGroupInput = CompactRowGroupPlan | ExplicitRowGroupPlan | Sequence[tuple[int, int]] + + +def normalize_row_group_plan(row_groups: RowGroupInput) -> RowGroupPlanLike: + if isinstance(row_groups, CompactRowGroupPlan | ExplicitRowGroupPlan): + return row_groups + return ExplicitRowGroupPlan(tuple(row_groups)) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/completion.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/completion.py index b34ffe69a..855c91642 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/completion.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/completion.py @@ -8,6 +8,11 @@ from typing import TYPE_CHECKING from data_designer.config.column_configs import GenerationStrategy +from data_designer.engine.dataset_builders.row_group_plan import ( + RowGroupInput, + RowGroupPlanLike, + normalize_row_group_plan, +) from data_designer.engine.dataset_builders.scheduling.resources import stable_task_id from data_designer.engine.dataset_builders.scheduling.task_model import SliceRef, Task @@ -44,16 +49,16 @@ def __init__(self) -> None: self._dropped: dict[int, set[int]] = defaultdict(set) self._graph: ExecutionGraph | None = None - self._row_group_sizes: dict[int, int] = {} + self._row_group_plan: RowGroupPlanLike | None = None self._batch_complete: dict[int, set[str]] = defaultdict(set) self._frontier: set[Task] = set() @classmethod - def with_graph(cls, graph: ExecutionGraph, row_groups: list[tuple[int, int]]) -> CompletionTracker: + def with_graph(cls, graph: ExecutionGraph, row_groups: RowGroupInput) -> CompletionTracker: """Create a frontier-enabled tracker backed by an execution graph.""" tracker = cls() tracker._graph = graph - tracker._row_group_sizes = dict(row_groups) + tracker._row_group_plan = normalize_row_group_plan(row_groups) return tracker def mark_cell_complete(self, column: str, row_group: int, row_index: int) -> FrontierDelta: @@ -106,7 +111,7 @@ def is_column_complete_for_rg(self, column: str, row_group_index: int) -> bool: """Check if *column* has been fully completed for *row_group_index*.""" if column in self._batch_complete.get(row_group_index, set()): return True - rg_size = self._row_group_sizes.get(row_group_index, 0) + rg_size = self._row_group_size_or_default(row_group_index, default=0) if rg_size == 0: return False completed = self._completed.get(row_group_index, {}).get(column, set()) @@ -190,7 +195,9 @@ def seed_frontier(self) -> None: if self._graph is None: raise RuntimeError("This method requires a graph to be set.") for col in self._graph.get_root_columns(): - for rg_id, rg_size in self._row_group_sizes.items(): + if self._row_group_plan is None: + raise RuntimeError("This method requires row groups to be set.") + for rg_id, rg_size in self._row_group_plan: self.add_root_tasks(rg_id, rg_size, columns=(col,)) def add_root_tasks( @@ -244,7 +251,7 @@ def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None rg_completed = self._completed.get(row_group, {}) rg_dropped = self._dropped.get(row_group, set()) rg_batch_complete = self._batch_complete.get(row_group, set()) - rg_size = self._row_group_sizes[row_group] + rg_size = self._row_group_size(row_group) for down in sorted(self._graph.get_downstream_columns(column)): batch_ups, cell_ups = self._graph.split_upstream_by_strategy(down) @@ -295,7 +302,7 @@ def _reevaluate_batch_tasks(self, row_group: int) -> list[Task]: rg_completed = self._completed.get(row_group, {}) rg_dropped = self._dropped.get(row_group, set()) rg_batch_complete = self._batch_complete.get(row_group, set()) - rg_size = self._row_group_sizes[row_group] + rg_size = self._row_group_size(row_group) for col in self._graph.get_topological_order(): if self._graph.get_strategy(col) != GenerationStrategy.FULL_COLUMN: @@ -338,8 +345,20 @@ def _validate_row_group(self, row_group: int) -> int | None: """Validate row-group id in graph-enabled mode and return its expected size.""" if self._graph is None: return None - expected = self._row_group_sizes.get(row_group) - if expected is None: - known = sorted(self._row_group_sizes) - raise ValueError(f"Unknown row_group {row_group}. Known row_groups: {known}") - return expected + if self._row_group_plan is None: + raise RuntimeError("This method requires row groups to be set.") + if not self._row_group_plan.has_row_group(row_group): + raise ValueError( + f"Unknown row_group {row_group}. Known row_groups: {self._row_group_plan.describe_known_row_groups()}" + ) + return self._row_group_plan.row_group_size(row_group) + + def _row_group_size(self, row_group: int) -> int: + if self._row_group_plan is None: + raise RuntimeError("This method requires row groups to be set.") + return self._row_group_plan.row_group_size(row_group) + + def _row_group_size_or_default(self, row_group: int, *, default: int) -> int: + if self._row_group_plan is None or not self._row_group_plan.has_row_group(row_group): + return default + return self._row_group_plan.row_group_size(row_group) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py index 6f2c74c49..66ef898e6 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py @@ -4,6 +4,7 @@ from __future__ import annotations import math +import tracemalloc import warnings from types import SimpleNamespace from unittest.mock import MagicMock, Mock @@ -26,6 +27,7 @@ ) from data_designer.engine.dataset_builders.async_scheduler import AsyncTaskScheduler from data_designer.engine.dataset_builders.dataset_builder import DatasetBuilder +from data_designer.engine.dataset_builders.row_group_plan import CompactRowGroupPlan from data_designer.engine.dataset_builders.scheduling.completion import CompletionTracker, FrontierDelta from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager @@ -228,6 +230,44 @@ def __init__(self, **kwargs: object) -> None: assert captured_kwargs["max_model_task_admission"] == 64 +def test_prepare_async_run_uses_compact_plan_for_large_fresh_runs(monkeypatch: pytest.MonkeyPatch) -> None: + captured_kwargs: dict[str, object] = {} + + class _SpyScheduler: + def __init__(self, **kwargs: object) -> None: + captured_kwargs.update(kwargs) + + monkeypatch.setattr(builder_mod, "AsyncTaskScheduler", _SpyScheduler) + model_registry = MagicMock() + model_registry.request_admission = None + provider = SimpleNamespace( + model_registry=model_registry, + run_config=SimpleNamespace(max_in_flight_tasks=64, progress_interval=5.0, progress_bar=False), + ) + processor_runner = MagicMock() + processor_runner.has_processors_for.return_value = False + config = SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}) + builder = SimpleNamespace( + _column_configs=[config], + _processor_runner=processor_runner, + artifact_storage=MagicMock(), + _resource_provider=provider, + ) + generator = MockSeed(config=_expr_config("seed"), resource_provider=provider) + + tracemalloc.start() + try: + DatasetBuilder._prepare_async_run(builder, [generator], num_records=2_000_000, buffer_size=2) + _current, peak_bytes = tracemalloc.get_traced_memory() + finally: + tracemalloc.stop() + + row_groups = captured_kwargs["row_groups"] + assert isinstance(row_groups, CompactRowGroupPlan) + assert len(row_groups) == 1_000_000 + assert peak_bytes < 5 * 1024 * 1024 + + # -- Test that existing sync path is unaffected -------------------------------- diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py index d705f674b..aec89f4de 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py @@ -6,6 +6,7 @@ import asyncio import logging import time +import tracemalloc from collections.abc import Callable from types import SimpleNamespace from typing import Any @@ -37,6 +38,7 @@ from data_designer.engine.context import current_row_group_start_offset from data_designer.engine.dataset_builders.async_scheduler import AsyncTaskScheduler from data_designer.engine.dataset_builders.errors import DatasetGenerationError +from data_designer.engine.dataset_builders.row_group_plan import CompactRowGroupPlan from data_designer.engine.dataset_builders.scheduling.completion import CompletionTracker, FrontierDelta from data_designer.engine.dataset_builders.scheduling.task_admission import TaskAdmissionConfig, TaskAdmissionLease from data_designer.engine.dataset_builders.scheduling.task_model import Task @@ -425,6 +427,44 @@ def _make_storage() -> MagicMock: return storage +def test_scheduler_preparation_memory_stays_bounded_for_million_row_groups() -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="cell_out", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "cell_out": GenerationStrategy.CELL_BY_CELL, + } + generators = { + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "cell_out": MockCellGenerator(config=_expr_config("cell_out"), resource_provider=provider), + } + graph = ExecutionGraph.create(configs, strategies) + + tracemalloc.start() + try: + row_groups = CompactRowGroupPlan.fresh(num_records=2_000_000, buffer_size=2) + tracker = CompletionTracker.with_graph(graph, row_groups) + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + num_records=2_000_000, + buffer_size=2, + ) + _current, peak_bytes = tracemalloc.get_traced_memory() + finally: + tracemalloc.stop() + + assert len(row_groups) == 1_000_000 + assert row_groups.scheduled_total_rows == 2_000_000 + assert scheduler._scheduled_records == 2_000_000 + assert peak_bytes < 5 * 1024 * 1024 + + def _seed_plus_cell_setup( cell_generator: ColumnGenerator, num_records: int, @@ -589,7 +629,12 @@ async def test_scheduler_sets_row_group_start_offsets_for_generators() -> None: configs = [SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]})] strategies = {"seed": GenerationStrategy.FULL_COLUMN} generators = {"seed": _OffsetSeedGenerator(config=_expr_config("seed"), resource_provider=provider)} - row_groups = [(1, 1), (3, 1)] + row_groups = CompactRowGroupPlan.resume( + original_target=4, + num_records=4, + buffer_size=1, + completed_ids={0, 2}, + ) graph = ExecutionGraph.create(configs, strategies) tracker = CompletionTracker.with_graph(graph, row_groups) @@ -602,7 +647,6 @@ async def test_scheduler_sets_row_group_start_offsets_for_generators() -> None: tracker=tracker, row_groups=row_groups, buffer_manager=buffer_manager, - row_group_start_offsets={1: 1, 3: 3}, ) await scheduler.run() diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py index 5d66b6c48..4c9a5b32b 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py @@ -5,6 +5,7 @@ import json import logging +import tracemalloc from pathlib import Path from typing import TYPE_CHECKING from unittest.mock import Mock, patch @@ -31,6 +32,7 @@ from data_designer.engine.column_generators.generators.base import GenerationStrategy from data_designer.engine.dataset_builders.dataset_builder import DatasetBuilder, build_row_group_resume_plan from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError +from data_designer.engine.dataset_builders.row_group_plan import CompactRowGroupPlan from data_designer.engine.models.errors import ( FormattedLLMErrorMessage, ModelGenerationValidationFailureError, @@ -1876,6 +1878,26 @@ def test_build_resume_raises_when_num_records_below_original_target( builder.build(num_records=7, resume=ResumeMode.ALWAYS) +def test_build_resume_raises_when_original_target_metadata_exceeds_target( + stub_resource_provider, stub_test_config_builder, tmp_path +): + """resume=ALWAYS rejects corrupt metadata where original_target exceeds target.""" + dataset_dir = tmp_path / "dataset" + _write_metadata( + dataset_dir, + target_num_records=10, + original_target_num_records=20, + buffer_size=2, + num_completed_batches=2, + actual_num_records=4, + ) + _write_parquet_files(dataset_dir / "parquet-files", [0, 1]) + + builder = _make_resume_builder(stub_resource_provider, stub_test_config_builder, tmp_path, buffer_size=2) + with pytest.raises(DatasetGenerationError, match="original_target_num_records=20.*target_num_records=10"): + builder.build(num_records=10, resume=ResumeMode.ALWAYS) + + def test_build_resume_allows_larger_num_records(stub_resource_provider, stub_test_config_builder, tmp_path, caplog): """resume=ALWAYS succeeds when num_records > original target (extending the dataset).""" dataset_dir = tmp_path / "dataset" @@ -2436,8 +2458,58 @@ def test_row_group_resume_plan_keeps_original_offsets_for_remaining_groups() -> ) assert plan.total_row_groups == 4 - assert plan.remaining_row_groups == [(1, 1), (3, 1)] - assert plan.row_group_start_offsets == {1: 1, 3: 3} + assert not isinstance(plan.remaining_row_groups, list) + assert list(plan.remaining_row_groups) == [(1, 1), (3, 1)] + assert plan.remaining_row_groups.row_group_start_offset(1) == 1 + assert plan.remaining_row_groups.row_group_start_offset(3) == 3 + + +def test_compact_row_group_plan_rejects_negative_extension() -> None: + with pytest.raises(ValueError, match="num_records must be greater than or equal to original_target"): + CompactRowGroupPlan.resume( + original_target=10, + num_records=8, + buffer_size=2, + completed_ids=set(), + ) + + +def test_row_group_resume_plan_tracks_completed_ids_at_half_complete_boundary() -> None: + plan = CompactRowGroupPlan.resume( + original_target=12, + num_records=12, + buffer_size=2, + completed_ids={1, 2, 4}, + ) + + assert getattr(plan, "_filter_includes_scheduled") is False + assert getattr(plan, "_scheduled_ids") is None + assert list(plan) == [(0, 2), (3, 2), (5, 2)] + assert plan.scheduled_total_rows == 6 + assert plan.row_group_start_offset(5) == 10 + with pytest.raises(KeyError): + plan.row_group_size(1) + + +def test_row_group_resume_plan_stays_sparse_when_almost_complete() -> None: + completed_ids = set(range(999_998)) + + tracemalloc.start() + try: + plan = CompactRowGroupPlan.resume( + original_target=2_000_000, + num_records=2_000_000, + buffer_size=2, + completed_ids=completed_ids, + ) + remaining = list(plan) + current_bytes, _peak_bytes = tracemalloc.get_traced_memory() + finally: + tracemalloc.stop() + + assert remaining == [(999_998, 2), (999_999, 2)] + assert plan.scheduled_total_rows == 4 + assert current_bytes < 5 * 1024 * 1024 def test_initial_actual_num_records_uses_actual_parquet_rows_for_partial_row_group( @@ -2667,7 +2739,7 @@ def capturing_prepare(*args, **kwargs): builder.build(num_records=6, resume=ResumeMode.ALWAYS) # Only rg_id=1 remains; rg_id=0 and rg_id=2 are already on disk - assert captured["precomputed_row_groups"] == [(1, 2)] + assert list(captured["precomputed_row_groups"]) == [(1, 2)] def test_build_async_resume_extension_non_aligned_row_group_sizes( @@ -2712,7 +2784,7 @@ def capturing_prepare(*args, **kwargs): builder.build(num_records=7, resume=ResumeMode.ALWAYS) # rg_id=3 should have 2 records (7-5=2 extension records, buffer_size=2), not 1 - assert captured["precomputed_row_groups"] == [(3, 2)] + assert list(captured["precomputed_row_groups"]) == [(3, 2)] def test_build_async_resume_not_already_complete_when_extension_fits_in_slack(