Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@
from data_designer.engine.context import current_row_group, current_row_group_start_offset
from data_designer.engine.dataset_builders.errors import DatasetGenerationError
from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig
from data_designer.engine.dataset_builders.row_group_plan import (
RowGroupInput,
RowGroupPlanLike,
normalize_row_group_plan,
)
from data_designer.engine.dataset_builders.scheduling.completion import CompletionTracker, FrontierDelta
from data_designer.engine.dataset_builders.scheduling.queue import (
FairTaskQueue,
Expand Down Expand Up @@ -146,7 +151,7 @@ def __init__(
generators: dict[str, ColumnGenerator],
graph: ExecutionGraph,
tracker: CompletionTracker,
row_groups: list[tuple[int, int]],
row_groups: RowGroupInput,
buffer_manager: RowGroupBufferManager | None = None,
*,
max_concurrent_row_groups: int = 3,
Expand All @@ -170,7 +175,6 @@ def __init__(
progress_bar: bool = False,
scheduler_event_sink: SchedulerAdmissionEventSink | None = None,
run_id: str | None = None,
row_group_start_offsets: dict[int, int] | None = None,
initial_completed_records: int = 0,
adaptive_row_group_admission: bool = False,
adaptive_row_group_initial_target: int = 1,
Expand All @@ -180,7 +184,7 @@ def __init__(
self._generators = generators
self._graph = graph
self._tracker = tracker
self._row_groups = row_groups
self._row_groups: RowGroupPlanLike = normalize_row_group_plan(row_groups)
self._buffer_manager = buffer_manager

self._rg_semaphore = asyncio.Semaphore(max_concurrent_row_groups)
Expand Down Expand Up @@ -288,17 +292,12 @@ def __init__(
self._first_non_retryable_error: Exception | None = None
self._fatal_worker_error: BaseException | None = None

# Pre-compute row-group sizes for O(1) lookup
self._rg_size_map: dict[int, int] = dict(row_groups)
self._rg_start_offset_map: dict[int, int] = row_group_start_offsets or self._build_row_group_start_offsets(
row_groups
)
self._max_concurrent_row_groups = max_concurrent_row_groups
self._max_in_flight_tasks = max_in_flight_tasks
self._max_model_task_admission = max_model_task_admission
self._num_records = num_records
self._buffer_size = buffer_size
self._scheduled_records = sum(size for _, size in row_groups)
self._scheduled_records = self._row_groups.scheduled_total_rows
self._initial_completed_records = initial_completed_records
self._observed_max_row_groups_in_flight = 0
self._observed_max_task_leases_by_resource: dict[str, int] = {}
Expand Down Expand Up @@ -331,15 +330,6 @@ def __init__(
self._progress_bar = StickyProgressBar() if progress_bar else None
self._reporter = self._setup_async_progress_reporter(num_records, buffer_size, progress_interval)

@staticmethod
def _build_row_group_start_offsets(row_groups: list[tuple[int, int]]) -> dict[int, int]:
offsets: dict[int, int] = {}
next_offset = 0
for rg_id, rg_size in row_groups:
offsets[rg_id] = next_offset
next_offset += rg_size
return offsets

def _setup_async_progress_reporter(
self,
num_records: int,
Expand Down Expand Up @@ -531,17 +521,16 @@ def _scheduler_health_diagnostics(self, *, reason: str) -> dict[str, object]:
}

def _scheduler_job_diagnostics(self) -> dict[str, object]:
row_group_sizes = [size for _rg_id, size in self._row_groups]
strategies = {column: self._graph.get_strategy(column).value for column in self._graph.columns}
task_count_by_strategy = Counter(strategies.values())
return {
"run_id": self._run_id,
"num_records": self._num_records,
"buffer_size": self._buffer_size,
"row_group_count": len(self._row_groups),
"row_group_total_rows": sum(row_group_sizes),
"row_group_min_size": min(row_group_sizes, default=0),
"row_group_max_size": max(row_group_sizes, default=0),
"row_group_total_rows": self._row_groups.scheduled_total_rows,
"row_group_min_size": self._row_groups.row_group_min_size,
"row_group_max_size": self._row_groups.row_group_max_size,
"graph_column_count": len(self._graph.columns),
"graph_root_columns": tuple(self._graph.get_root_columns()),
"graph_depth": len(self._graph.get_longest_dependency_chain()),
Expand Down Expand Up @@ -859,8 +848,7 @@ def _task_flow_identity(self, task: Task) -> tuple[str, ...]:
def _max_admitted_rows_guardrail(self) -> int:
if self._num_records > 0 and self._buffer_size > 0:
return min(self._num_records, max(3 * self._buffer_size, 8192))
total_rows = sum(size for _rg_id, size in self._row_groups)
return max(1, total_rows)
return max(1, self._row_groups.scheduled_total_rows)

async def _wait_for_row_group_admission_capacity(self, row_group_size: int) -> None:
while True:
Expand Down Expand Up @@ -1543,7 +1531,7 @@ async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None:
seed_cols = self._seed_cols
if not seed_cols:
return
num_rgs = len(self._rg_size_map)
num_rgs = len(self._row_groups)
width = len(str(num_rgs))
logger.info(f"πŸš€ ({rg_id + 1:0{width}d}/{num_rgs}) Dispatching with {rg_size} records")
seen_instances: set[int] = set()
Expand All @@ -1567,7 +1555,7 @@ async def _execute_task_inner(self, task: Task, lease: TaskAdmissionLease, task_
"""Core task execution logic."""
num_rgs = len(self._row_groups)
token = current_row_group.set((task.row_group, num_rgs))
start_offset_token = current_row_group_start_offset.set(self._rg_start_offset_map.get(task.row_group))
start_offset_token = current_row_group_start_offset.set(self._get_rg_start_offset(task.row_group))
group = lease.item.group
identity_hash = hashlib.sha1("\0".join(group.key.identity).encode()).hexdigest()[:16]
correlation_token = runtime_correlation_provider.set(
Expand Down Expand Up @@ -1961,10 +1949,16 @@ async def _run_batch(self, task: Task, generator: ColumnGenerator) -> Any:

def _get_rg_size(self, row_group: int) -> int:
try:
return self._rg_size_map[row_group]
return self._row_groups.row_group_size(row_group)
except KeyError:
raise ValueError(f"Unknown row group: {row_group}") from None

def _get_rg_start_offset(self, row_group: int) -> int | None:
try:
return self._row_groups.row_group_start_offset(row_group)
except KeyError:
return None

def task_admission_snapshot(self) -> object:
"""Return the current scheduler task-admission snapshot for diagnostics."""
return self._task_admission.view()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@
from data_designer.engine.context import current_row_group, current_row_group_start_offset
from data_designer.engine.dataset_builders.errors import DatasetGenerationError
from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig
from data_designer.engine.dataset_builders.row_group_plan import (
CompactRowGroupPlan,
RowGroupInput,
RowGroupPlanLike,
normalize_row_group_plan,
)
from data_designer.engine.dataset_builders.utils.concurrency import ConcurrentThreadExecutor
from data_designer.engine.dataset_builders.utils.config_compiler import compile_dataset_builder_column_configs
from data_designer.engine.dataset_builders.utils.dataset_batch_manager import DatasetBatchManager
Expand Down Expand Up @@ -125,16 +131,11 @@ class RowGroupResumePlan:

Attributes:
total_row_groups: Total row group count for the full target (original + extension).
remaining_row_groups: ``(rg_id, rg_size)`` for groups not yet on disk, in id order.
row_group_start_offsets: ``rg_id -> planned start offset`` for the remaining
groups, computed from the original plan so completed-group sizes are
preserved (offsets are not recomputed from the remaining list, which
would shift them when there are holes).
remaining_row_groups: lazy plan of ``(rg_id, rg_size)`` for groups not yet on disk, in id order.
"""

total_row_groups: int
remaining_row_groups: list[tuple[int, int]]
row_group_start_offsets: dict[int, int]
remaining_row_groups: CompactRowGroupPlan


def build_row_group_resume_plan(
Expand All @@ -160,32 +161,20 @@ def build_row_group_resume_plan(
completed_ids: Row-group IDs already persisted on disk.

Returns:
A ``RowGroupResumePlan`` whose ``row_group_start_offsets`` are taken from
the full original plan, so the offset for ``rg_id`` is the same whether
or not earlier groups have completed. This is what lets ordered seed
generators seek to the correct row when resuming with holes.
A ``RowGroupResumePlan`` whose remaining row-group plan preserves full
original offsets, so the offset for ``rg_id`` is the same whether or not
earlier groups have completed. This is what lets ordered seed generators
seek to the correct row when resuming with holes.
"""
num_original_groups = -(-original_target // buffer_size)
extension_records = num_records - original_target
total_row_groups = num_original_groups + -(-extension_records // buffer_size)

def _rg_size(rg_id: int) -> int:
if rg_id < num_original_groups:
return min(buffer_size, original_target - rg_id * buffer_size)
ext_group_idx = rg_id - num_original_groups
return min(buffer_size, extension_records - ext_group_idx * buffer_size)

all_start_offsets: dict[int, int] = {}
next_offset = 0
for rg_id in range(total_row_groups):
all_start_offsets[rg_id] = next_offset
next_offset += _rg_size(rg_id)

remaining_row_groups = [(rg_id, _rg_size(rg_id)) for rg_id in range(total_row_groups) if rg_id not in completed_ids]
remaining_row_groups = CompactRowGroupPlan.resume(
original_target=original_target,
num_records=num_records,
buffer_size=buffer_size,
completed_ids=completed_ids,
)
return RowGroupResumePlan(
total_row_groups=total_row_groups,
total_row_groups=remaining_row_groups.total_row_groups,
remaining_row_groups=remaining_row_groups,
row_group_start_offsets={rg_id: all_start_offsets[rg_id] for rg_id, _ in remaining_row_groups},
)


Expand Down Expand Up @@ -571,6 +560,13 @@ def _load_resume_state(self, num_records: int, buffer_size: int) -> _ResumeState
"(you may extend the dataset beyond the original target). "
"Use resume=ResumeMode.NEVER to start a new run."
)
original_target_num_records = metadata.get("original_target_num_records", target_num_records)
if original_target_num_records > target_num_records:
raise DatasetGenerationError(
"πŸ›‘ Cannot resume: metadata.json has original_target_num_records="
f"{original_target_num_records}, which is greater than target_num_records={target_num_records}. "
"Start a fresh run with resume=ResumeMode.NEVER, or restore a valid metadata.json."
)

meta_buffer_size = metadata.get("buffer_size")
if meta_buffer_size != buffer_size:
Expand All @@ -593,7 +589,7 @@ def _load_resume_state(self, num_records: int, buffer_size: int) -> _ResumeState
actual_num_records=actual_num_records,
buffer_size=buffer_size,
target_num_records=target_num_records,
original_target_num_records=metadata.get("original_target_num_records", target_num_records),
original_target_num_records=original_target_num_records,
completed_row_groups=completed_row_groups,
)

Expand Down Expand Up @@ -916,8 +912,7 @@ def _build_async(
settings = self._resource_provider.run_config
trace_enabled = _is_async_trace_enabled(settings)

precomputed_row_groups: list[tuple[int, int]] | None = None
row_group_start_offsets: dict[int, int] | None = None
precomputed_row_groups: RowGroupInput | None = None
initial_actual_num_records = 0
initial_total_num_batches = 0
original_target = num_records # immutable original target; overridden on resume
Expand All @@ -944,20 +939,21 @@ def _build_async(
buffer_size=buffer_size,
completed_ids=completed_ids,
)
if len(completed_ids) >= resume_plan.total_row_groups:
remaining_row_group_count = len(resume_plan.remaining_row_groups)
completed_row_group_count = resume_plan.total_row_groups - remaining_row_group_count
if remaining_row_group_count == 0:
logger.warning(
"⚠️ Dataset is already complete β€” all row groups were found in the existing artifact "
"directory. Nothing to resume. Use resume=ResumeMode.NEVER if you want to generate a new dataset."
)
return False

logger.info(
f"▢️ Resuming async run: {len(completed_ids)} of {resume_plan.total_row_groups} row group(s) already "
f"complete ({initial_actual_num_records} records), skipping them."
f"▢️ Resuming async run: {completed_row_group_count} of {resume_plan.total_row_groups} row group(s) "
f"already complete ({initial_actual_num_records} records), skipping them."
)

precomputed_row_groups = resume_plan.remaining_row_groups
row_group_start_offsets = resume_plan.row_group_start_offsets

def finalize_row_group(rg_id: int) -> None:
def on_complete(final_path: Path | str | None) -> None:
Expand All @@ -982,7 +978,6 @@ def on_complete(final_path: Path | str | None) -> None:
disable_early_shutdown=settings.disable_early_shutdown,
trace=trace_enabled,
precomputed_row_groups=precomputed_row_groups,
row_group_start_offsets=row_group_start_offsets,
initial_actual_num_records=initial_actual_num_records,
initial_total_num_batches=initial_total_num_batches,
)
Expand Down Expand Up @@ -1051,8 +1046,7 @@ def _prepare_async_run(
shutdown_error_window: int = 10,
disable_early_shutdown: bool = False,
trace: bool = False,
precomputed_row_groups: list[tuple[int, int]] | None = None,
row_group_start_offsets: dict[int, int] | None = None,
precomputed_row_groups: RowGroupInput | None = None,
initial_actual_num_records: int = 0,
initial_total_num_batches: int = 0,
) -> tuple[AsyncTaskScheduler, RowGroupBufferManager]:
Expand All @@ -1078,16 +1072,9 @@ def _prepare_async_run(
gen.log_pre_generation()

if precomputed_row_groups is not None:
row_groups = precomputed_row_groups
row_groups: RowGroupPlanLike = normalize_row_group_plan(precomputed_row_groups)
else:
row_groups = []
remaining = num_records
rg_id = 0
while remaining > 0:
size = min(buffer_size, remaining)
row_groups.append((rg_id, size))
remaining -= size
rg_id += 1
row_groups = CompactRowGroupPlan.fresh(num_records=num_records, buffer_size=buffer_size)

tracker = CompletionTracker.with_graph(graph, row_groups)
buffer_manager = RowGroupBufferManager(
Expand Down Expand Up @@ -1143,7 +1130,6 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None:
trace=trace,
num_records=num_records,
buffer_size=buffer_size,
row_group_start_offsets=row_group_start_offsets,
initial_completed_records=initial_actual_num_records,
progress_interval=self._resource_provider.run_config.progress_interval,
progress_bar=self._resource_provider.run_config.progress_bar,
Expand Down
Loading
Loading