diff --git a/gigl/distributed/base_dist_loader.py b/gigl/distributed/base_dist_loader.py index ed0783ef6..900d45279 100644 --- a/gigl/distributed/base_dist_loader.py +++ b/gigl/distributed/base_dist_loader.py @@ -9,11 +9,13 @@ """ import math +import queue import sys +import threading import time from collections import Counter, defaultdict from dataclasses import dataclass -from typing import Callable, Optional, Union +from typing import Callable, Final, Optional, Union import torch from graphlearn_torch.channel import RemoteReceivingChannel, ShmChannel @@ -51,9 +53,55 @@ logger = Logger() +_COLLATION_SENTINEL: Final = object() # Signals end-of-epoch to consumer + DEFAULT_NUM_CPU_THREADS = 2 +class TimingStats: + """Accumulates timing measurements for profiling and outputs a summary.""" + + def __init__(self, name: str): + self._name = name + self._totals: dict[str, float] = defaultdict(float) + self._counts: dict[str, int] = defaultdict(int) + self._mins: dict[str, float] = {} + self._maxs: dict[str, float] = {} + self._order: list[str] = [] + + def record(self, key: str, elapsed: float) -> None: + if key not in self._totals: + self._order.append(key) + self._totals[key] += elapsed + self._counts[key] += 1 + if key not in self._mins or elapsed < self._mins[key]: + self._mins[key] = elapsed + if key not in self._maxs or elapsed > self._maxs[key]: + self._maxs[key] = elapsed + + def summary(self) -> str: + lines = [ + f"\n{'=' * 80}", + f" {self._name} — Timing Summary", + f"{'=' * 80}", + ] + for key in self._order: + total = self._totals[key] + count = self._counts[key] + avg = total / count if count > 0 else 0 + min_v = self._mins.get(key, 0) + max_v = self._maxs.get(key, 0) + if count == 1: + lines.append(f" {key:<45s} {total:>10.4f}s") + continue + lines.append( + f" {key:<45s} total={total:>10.4f}s n={count:>6d} " + f"avg={avg:>8.4f}s min={min_v:>8.4f}s max={max_v:>8.4f}s" + ) + lines.append(f"{'=' * 80}\n") + return "\n".join(lines) + + # We don't see logs for graph store mode for whatever reason. # TOOD(#442): Revert this once the GCP issues are resolved. def _flush() -> None: @@ -114,6 +162,12 @@ class BaseDistLoader(DistLoader): ``batch_index * process_start_gap_seconds`` before dispatching. Only applies to graph store mode. Defaults to ``None`` (no staggering). + background_collation_queue_size: If set to a positive integer, enables + background collation in a daemon thread. The collation of sampled + messages (via ``_collate_fn``) is performed in a background thread, + overlapping with GPU training. The value controls the maximum number + of pre-collated batches buffered in memory. ``None`` disables + background collation (default behavior). """ @staticmethod @@ -220,6 +274,7 @@ def __init__( producer: Union[DistSamplingProducer, Callable[..., int]], sampler_options: SamplerOptions, process_start_gap_seconds: float = 60.0, + background_collation_queue_size: Optional[int] = None, max_concurrent_producer_inits: Optional[int] = None, ): if max_concurrent_producer_inits is None: @@ -229,6 +284,30 @@ def __init__( # Will be set to False once connections are initialized. self._shutdowned = True + # --- Background collation setup (validate early, before heavy init) --- + if ( + background_collation_queue_size is not None + and background_collation_queue_size < 1 + ): + raise ValueError( + f"background_collation_queue_size must be >= 1 if provided, " + f"got {background_collation_queue_size}" + ) + self._background_collation_queue_size = background_collation_queue_size + self._collation_thread: Optional[threading.Thread] = None + self._collated_queue: Optional[queue.Queue] = None + self._collation_stop_event: Optional[threading.Event] = None + + # --- Timing instrumentation --- + mode_label = ( + "background_collation" + if background_collation_queue_size is not None + else "synchronous" + ) + self._timing = TimingStats(f"BaseDistLoader ({mode_label})") + self._epoch_start_time: Optional[float] = None + self._log_timing_every_n_batches: Final[int] = 10 + # Store dataset metadata for subclass _collate_fn usage self._is_homogeneous_with_labeled_edge_type = ( dataset_schema.is_homogeneous_with_labeled_edge_type @@ -799,10 +878,181 @@ def _init_graph_store_connections( ) _flush() + # --- Background collation methods --- + + def _maybe_log_timing(self) -> None: + """Log timing summary periodically and at end of epoch.""" + if self._num_recv % self._log_timing_every_n_batches == 0: + logger.info(self._timing.summary()) + _flush() + + def __next__(self): # type: ignore[override] + """Returns the next collated batch. + + When background collation is enabled, retrieves pre-collated results + from the bounded queue. Otherwise, falls back to the synchronous + path (replicated from GLT ``DistLoader``). + + Returns: + A ``Data`` or ``HeteroData`` batch. + + Raises: + StopIteration: When the epoch is exhausted. + """ + if self._background_collation_queue_size is not None: + return self._next_from_background_collation() + # Original synchronous path (replicated from GLT DistLoader) + if self._num_recv == self._num_expected: + logger.info( + f"[sync] Epoch done. Total batches: {self._num_recv}, " + f"epoch wall time: {time.time() - (self._epoch_start_time or 0.0):.2f}s" + ) + logger.info(self._timing.summary()) + _flush() + raise StopIteration + t0 = time.time() + if self._with_channel: + msg = self._channel.recv() + else: + msg = self._collocated_producer.sample() + t_recv = time.time() + self._timing.record("sync/recv", t_recv - t0) + + result = self._collate_fn(msg) + t_collate = time.time() + self._timing.record("sync/collate_fn", t_collate - t_recv) + self._timing.record("sync/total_next", t_collate - t0) + + self._num_recv += 1 + self._maybe_log_timing() + return result + + def _next_from_background_collation(self): + """Retrieves the next pre-collated batch from the background queue. + + Returns: + A ``Data`` or ``HeteroData`` batch. + + Raises: + StopIteration: On sentinel or when expected count is reached. + """ + assert self._collated_queue is not None + t0 = time.time() + qsize_before = self._collated_queue.qsize() + item = self._collated_queue.get() + t_get = time.time() + self._timing.record("bg_consumer/queue_get", t_get - t0) + self._timing.record("bg_consumer/queue_size_at_get", qsize_before) + + if item is _COLLATION_SENTINEL: + logger.info( + f"[bg] Epoch done. Total batches: {self._num_recv}, " + f"epoch wall time: {time.time() - (self._epoch_start_time or 0.0):.2f}s" + ) + logger.info(self._timing.summary()) + _flush() + raise StopIteration + if isinstance(item, BaseException): + raise item + self._num_recv += 1 + self._maybe_log_timing() + return item + + def _collation_worker(self) -> None: + """Target function for the background collation daemon thread. + + Continuously receives messages from the channel (or collocated + producer) and runs ``_collate_fn``, placing collated results into + ``_collated_queue``. Exits when the epoch batch count is reached, + a ``StopIteration`` is received from the channel, or the stop event + is set. + """ + assert self._collated_queue is not None + assert self._collation_stop_event is not None + num_produced = 0 + try: + while True: + # For finite epochs, exit after producing all expected batches + if ( + self._num_expected != float("inf") + and num_produced >= self._num_expected + ): + self._collated_queue.put(_COLLATION_SENTINEL) + return + + # Receive next sampled message + t0 = time.time() + try: + if self._with_channel: + msg = self._channel.recv() + else: + msg = self._collocated_producer.sample() + except StopIteration: + self._collated_queue.put(_COLLATION_SENTINEL) + return + t_recv = time.time() + self._timing.record("bg_producer/recv", t_recv - t0) + + # Check stop event between recv and collate + if self._collation_stop_event.is_set(): + return + + result = self._collate_fn(msg) + t_collate = time.time() + self._timing.record("bg_producer/collate_fn", t_collate - t_recv) + + self._collated_queue.put(result) + t_put = time.time() + self._timing.record("bg_producer/queue_put", t_put - t_collate) + self._timing.record("bg_producer/total_iteration", t_put - t0) + + num_produced += 1 + except Exception as e: + self._collated_queue.put(e) + + def _start_collation_thread(self) -> None: + """Creates and starts a fresh background collation thread.""" + assert self._background_collation_queue_size is not None + self._collation_stop_event = threading.Event() + self._collated_queue = queue.Queue( + maxsize=self._background_collation_queue_size + ) + self._collation_thread = threading.Thread( + target=self._collation_worker, daemon=True + ) + self._collation_thread.start() + + def _stop_collation_thread(self) -> None: + """Stops the background collation thread if it is running. + + Sets the stop event and drains the queue to unblock the worker + if it is blocked on ``queue.put()``. Joins the thread with a + 10-second timeout. + """ + if self._collation_thread is None or not self._collation_thread.is_alive(): + return + assert self._collation_stop_event is not None + assert self._collated_queue is not None + + self._collation_stop_event.set() + # Drain the queue to unblock the worker if it's blocked on put() + while True: + try: + self._collated_queue.get_nowait() + except queue.Empty: + break + self._collation_thread.join(timeout=10.0) + if self._collation_thread.is_alive(): + logger.warning( + "Background collation thread did not terminate within 10 seconds." + ) + # Overwrite DistLoader.shutdown to so we can use our own shutdown and rpc calls def shutdown(self) -> None: if self._shutdowned: return + if self._background_collation_queue_size is not None: + self._stop_collation_thread() if self._is_collocated_worker: self._collocated_producer.shutdown() elif self._is_mp_worker: @@ -821,6 +1071,25 @@ def shutdown(self) -> None: # Overwrite DistLoader.__iter__ to so we can use our own __iter__ and rpc calls def __iter__(self) -> Self: + if self._background_collation_queue_size is not None: + self._stop_collation_thread() + + # Log previous epoch timing (if any) and reset for new epoch + if self._epoch > 0: + logger.info( + f"[iter] Resetting for epoch {self._epoch}. " f"Previous epoch timing:" + ) + logger.info(self._timing.summary()) + _flush() + + mode_label = ( + "background_collation" + if self._background_collation_queue_size is not None + else "synchronous" + ) + self._timing = TimingStats(f"BaseDistLoader ({mode_label}) epoch={self._epoch}") + self._epoch_start_time = time.time() + self._num_recv = 0 if self._is_collocated_worker: self._collocated_producer.reset() @@ -841,4 +1110,8 @@ def __iter__(self) -> Self: torch.futures.wait_all(rpc_futures) self._channel.reset() self._epoch += 1 + + if self._background_collation_queue_size is not None: + self._start_collation_thread() + return self diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index f96fd3bc0..22d1cb423 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -97,6 +97,7 @@ def __init__( context: Optional[DistributedContext] = None, # TODO: (svij) Deprecate this local_process_rank: Optional[int] = None, # TODO: (svij) Deprecate this local_process_world_size: Optional[int] = None, # TODO: (svij) Deprecate this + background_collation_queue_size: Optional[int] = None, ): """ Neighbor loader for Anchor Based Link Prediction (ABLP) tasks. @@ -215,6 +216,12 @@ def __init__( context (deprecated - will be removed soon) (Optional[DistributedContext]): Distributed context information of the current process. local_process_rank (deprecated - will be removed soon) (int): The local rank of the current process within a node. local_process_world_size (deprecated - will be removed soon) (int): The total number of processes within a node. + background_collation_queue_size (Optional[int]): If set to a positive + integer, enables background collation in a daemon thread. The + collation of sampled messages is performed in a background thread, + overlapping with GPU training. The value controls the maximum + number of pre-collated batches buffered in memory. ``None`` + disables background collation (default behavior). """ # Set self._shutdowned right away, that way if we throw here, and __del__ is called, @@ -385,6 +392,7 @@ def __init__( producer=producer, sampler_options=sampler_options, process_start_gap_seconds=process_start_gap_seconds, + background_collation_queue_size=background_collation_queue_size, max_concurrent_producer_inits=max_concurrent_producer_inits, ) diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index e47075343..3a514463c 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -1,4 +1,5 @@ import sys +import time from collections import abc from itertools import count from typing import Callable, Optional, Tuple, Union @@ -92,6 +93,7 @@ def __init__( num_cpu_threads: Optional[int] = None, shuffle: bool = False, drop_last: bool = False, + background_collation_queue_size: Optional[int] = None, sampler_options: Optional[SamplerOptions] = None, ): """ @@ -166,6 +168,12 @@ def __init__( Defaults to `2` if set to `None` when using cpu training/inference. shuffle (bool): Whether to shuffle the input nodes. (default: ``False``). drop_last (bool): Whether to drop the last incomplete batch. (default: ``False``). + background_collation_queue_size (Optional[int]): If set to a positive + integer, enables background collation in a daemon thread. The + collation of sampled messages is performed in a background thread, + overlapping with GPU training. The value controls the maximum + number of pre-collated batches buffered in memory. ``None`` + disables background collation (default behavior). sampler_options (Optional[SamplerOptions]): Controls which sampler class is instantiated. Pass ``KHopNeighborSamplerOptions`` to use the built-in sampler, or ``CustomSamplerOptions`` to dynamically import a custom sampler class. @@ -294,6 +302,7 @@ def __init__( producer=producer, sampler_options=sampler_options, process_start_gap_seconds=process_start_gap_seconds, + background_collation_queue_size=background_collation_queue_size, max_concurrent_producer_inits=max_concurrent_producer_inits, ) @@ -529,6 +538,7 @@ def _setup_for_colocated( ) def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: + t0 = time.time() # Extract user-defined metadata before super()._collate_fn, which # calls GLT's to_hetero_data. to_hetero_data misinterprets #META. keys # as edge types and fails when edge_dir="out" (tries to call @@ -536,16 +546,24 @@ def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: # TODO (mkolodner-sc): Remove once GLT's to_hetero_data is fixed. metadata, stripped_msg = extract_metadata(msg, self.to_device) data = super()._collate_fn(stripped_msg) + t_base = time.time() + self._timing.record("collate/glt_base_collate", t_base - t0) data = set_missing_features( data=data, node_feature_info=self._node_feature_info, edge_feature_info=self._edge_feature_info, device=self.to_device, ) + t_feat = time.time() + self._timing.record("collate/set_missing_features", t_feat - t_base) + if isinstance(data, HeteroData): data = strip_label_edges(data) + t_strip = time.time() + self._timing.record("collate/strip_label_edges", t_strip - t_feat) if self._is_homogeneous_with_labeled_edge_type: data = labeled_to_homogeneous(DEFAULT_HOMOGENEOUS_EDGE_TYPE, data) + self._timing.record("collate/labeled_to_homogeneous", time.time() - t_feat) if isinstance(self._sampler_options, PPRSamplerOptions): matched, metadata = extract_edge_type_metadata( @@ -562,4 +580,6 @@ def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: # data object so downstream code can access them via attribute lookup. for key, value in metadata.items(): data[key] = value + + self._timing.record("collate/total", time.time() - t0) return data diff --git a/tests/integration/distributed/graph_store/graph_store_integration_test.py b/tests/integration/distributed/graph_store/graph_store_integration_test.py index a92bdd29d..4963bccc5 100644 --- a/tests/integration/distributed/graph_store/graph_store_integration_test.py +++ b/tests/integration/distributed/graph_store/graph_store_integration_test.py @@ -13,6 +13,7 @@ import torch import torch.multiprocessing as mp +from examples.link_prediction.models import init_example_gigl_homogeneous_model from torch_geometric.data import Data, HeteroData from gigl.common import Uri, UriFactory @@ -36,6 +37,8 @@ GraphStoreInfo, ) from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper +from gigl.src.common.utils.model import load_state_dict_from_uri from gigl.src.mocking.lib.versioning import get_mocked_dataset_artifact_metadata from gigl.src.mocking.mocking_assets.mocked_datasets_for_pipeline_tests import ( CORA_USER_DEFINED_NODE_ANCHOR_MOCKED_DATASET_INFO, @@ -976,7 +979,7 @@ def test_homogeneous_training(self): ), ) - def test_multiple_loaders_in_graph_store(self): + def _test_multiple_loaders_in_graph_store(self): """Test that multiple loader instances (2 ABLP + 2 DistNeighborLoader) can work in parallel, followed by another (ABLP, DistNeighborLoader) pair sequentially. """ diff --git a/tests/unit/distributed/background_collation_test.py b/tests/unit/distributed/background_collation_test.py new file mode 100644 index 000000000..0f5400430 --- /dev/null +++ b/tests/unit/distributed/background_collation_test.py @@ -0,0 +1,372 @@ +import torch +import torch.multiprocessing as mp +from absl.testing import absltest +from graphlearn_torch.distributed import shutdown_rpc +from torch_geometric.data import Data + +from gigl.distributed.dist_dataset import DistDataset +from gigl.distributed.distributed_neighborloader import DistNeighborLoader +from gigl.types.graph import FeaturePartitionData, GraphPartitionData, PartitionOutput +from gigl.utils.iterator import InfiniteIterator +from tests.test_assets.distributed.utils import create_test_process_group +from tests.test_assets.test_case import TestCase + +# GLT requires subclasses of DistNeighborLoader to be run in a separate process. +# Otherwise, we may run into segmentation fault or other memory issues. +# Calling these functions in separate processes also allows us to use shutdown_rpc() +# to ensure cleanup of ports, providing stronger guarantees of isolation between tests. + + +# We require each of these functions to accept local_rank as the first argument +# since we use mp.spawn with `nprocs=1`. + + +def _run_background_collation_batch_count( + _: int, + dataset: DistDataset, + expected_data_count: int, + queue_size: int, +) -> None: + """Verifies that a loader with background collation produces the correct batch count.""" + create_test_process_group() + loader = DistNeighborLoader( + dataset=dataset, + num_neighbors=[2, 2], + pin_memory_device=torch.device("cpu"), + background_collation_queue_size=queue_size, + ) + + count = 0 + for datum in loader: + assert isinstance(datum, Data) + count += 1 + + assert ( + count == expected_data_count + ), f"Expected {expected_data_count} batches, but got {count}." + + shutdown_rpc() + + +def _run_background_collation_multiple_epochs( + _: int, + dataset: DistDataset, + expected_data_count: int, +) -> None: + """Verifies thread restart across epoch boundaries via __iter__.""" + create_test_process_group() + loader = DistNeighborLoader( + dataset=dataset, + num_neighbors=[2, 2], + pin_memory_device=torch.device("cpu"), + background_collation_queue_size=2, + ) + + for epoch in range(3): + count = 0 + for datum in loader: + assert isinstance(datum, Data) + count += 1 + assert ( + count == expected_data_count + ), f"Epoch {epoch}: expected {expected_data_count} batches, got {count}." + + shutdown_rpc() + + +def _run_background_collation_early_break( + _: int, + dataset: DistDataset, + expected_data_count: int, +) -> None: + """Verifies that breaking mid-epoch and re-iterating works correctly.""" + create_test_process_group() + loader = DistNeighborLoader( + dataset=dataset, + num_neighbors=[2, 2], + pin_memory_device=torch.device("cpu"), + background_collation_queue_size=2, + ) + + # Partial iteration — break after 5 batches + count = 0 + for datum in loader: + assert isinstance(datum, Data) + count += 1 + if count >= 5: + break + + assert count == 5, f"Expected 5 batches in partial epoch, got {count}." + + # Full second epoch should still produce the correct count + count = 0 + for datum in loader: + assert isinstance(datum, Data) + count += 1 + assert ( + count == expected_data_count + ), f"Expected {expected_data_count} batches in full epoch, got {count}." + + shutdown_rpc() + + +def _run_background_collation_shutdown( + _: int, + dataset: DistDataset, +) -> None: + """Verifies that partial iteration then shutdown() terminates cleanly.""" + create_test_process_group() + loader = DistNeighborLoader( + dataset=dataset, + num_neighbors=[2, 2], + pin_memory_device=torch.device("cpu"), + background_collation_queue_size=2, + ) + + # Partial iteration + count = 0 + for datum in loader: + assert isinstance(datum, Data) + count += 1 + if count >= 3: + break + + loader.shutdown() + shutdown_rpc() + + +def _run_background_collation_queue_size_one( + _: int, + dataset: DistDataset, + expected_data_count: int, +) -> None: + """Verifies that the minimum pipeline depth (queue_size=1) works correctly.""" + create_test_process_group() + loader = DistNeighborLoader( + dataset=dataset, + num_neighbors=[2, 2], + pin_memory_device=torch.device("cpu"), + background_collation_queue_size=1, + ) + + count = 0 + for datum in loader: + assert isinstance(datum, Data) + count += 1 + + assert ( + count == expected_data_count + ), f"Expected {expected_data_count} batches, but got {count}." + + shutdown_rpc() + + +def _run_background_collation_infinite_iterator( + _: int, + dataset: DistDataset, + max_num_batches: int, +) -> None: + """Verifies that background collation works with InfiniteIterator.""" + create_test_process_group() + loader = DistNeighborLoader( + dataset=dataset, + num_neighbors=[2, 2], + pin_memory_device=torch.device("cpu"), + background_collation_queue_size=2, + ) + + infinite_loader: InfiniteIterator = InfiniteIterator(loader) + + count = 0 + for datum in infinite_loader: + assert isinstance(datum, Data) + count += 1 + if count == max_num_batches: + break + + assert count == max_num_batches, f"Expected {max_num_batches} batches, got {count}." + + shutdown_rpc() + + +class TestBackgroundCollation(TestCase): + def setUp(self) -> None: + super().setUp() + self._world_size = 1 + + def tearDown(self) -> None: + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + super().tearDown() + + def test_produces_same_batch_count(self) -> None: + """Loader with background collation produces correct number of batches.""" + expected_data_count = 18 + partition_output = PartitionOutput( + node_partition_book=torch.zeros(18), + edge_partition_book=None, + partitioned_edge_index=GraphPartitionData( + edge_index=torch.tensor([[10], [15]]), edge_ids=None + ), + partitioned_edge_features=None, + partitioned_node_features=None, + partitioned_negative_labels=None, + partitioned_positive_labels=None, + partitioned_node_labels=None, + ) + dataset = DistDataset(rank=0, world_size=1, edge_dir="out") + dataset.build(partition_output=partition_output) + + mp.spawn( + fn=_run_background_collation_batch_count, + args=(dataset, expected_data_count, 4), + ) + + def test_multiple_epochs(self) -> None: + """Thread restart across epoch boundaries via __iter__.""" + expected_data_count = 18 + partition_output = PartitionOutput( + node_partition_book=torch.zeros(18), + edge_partition_book=None, + partitioned_edge_index=GraphPartitionData( + edge_index=torch.tensor([[10], [15]]), edge_ids=None + ), + partitioned_edge_features=None, + partitioned_node_features=None, + partitioned_negative_labels=None, + partitioned_positive_labels=None, + partitioned_node_labels=None, + ) + dataset = DistDataset(rank=0, world_size=1, edge_dir="out") + dataset.build(partition_output=partition_output) + + mp.spawn( + fn=_run_background_collation_multiple_epochs, + args=(dataset, expected_data_count), + ) + + def test_early_break(self) -> None: + """Break mid-epoch, then iterate again — no hang or corruption.""" + expected_data_count = 18 + partition_output = PartitionOutput( + node_partition_book=torch.zeros(18), + edge_partition_book=None, + partitioned_edge_index=GraphPartitionData( + edge_index=torch.tensor([[10], [15]]), edge_ids=None + ), + partitioned_edge_features=None, + partitioned_node_features=None, + partitioned_negative_labels=None, + partitioned_positive_labels=None, + partitioned_node_labels=None, + ) + dataset = DistDataset(rank=0, world_size=1, edge_dir="out") + dataset.build(partition_output=partition_output) + + mp.spawn( + fn=_run_background_collation_early_break, + args=(dataset, expected_data_count), + ) + + def test_shutdown(self) -> None: + """Partial iteration then shutdown() — clean termination.""" + partition_output = PartitionOutput( + node_partition_book=torch.zeros(18), + edge_partition_book=None, + partitioned_edge_index=GraphPartitionData( + edge_index=torch.tensor([[10], [15]]), edge_ids=None + ), + partitioned_edge_features=None, + partitioned_node_features=None, + partitioned_negative_labels=None, + partitioned_positive_labels=None, + partitioned_node_labels=None, + ) + dataset = DistDataset(rank=0, world_size=1, edge_dir="out") + dataset.build(partition_output=partition_output) + + mp.spawn( + fn=_run_background_collation_shutdown, + args=(dataset,), + ) + + def test_invalid_queue_size(self) -> None: + """background_collation_queue_size=0 raises ValueError.""" + partition_output = PartitionOutput( + node_partition_book=torch.zeros(5), + edge_partition_book=torch.zeros(5), + partitioned_edge_index=GraphPartitionData( + edge_index=torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]]), + edge_ids=None, + ), + partitioned_node_features=FeaturePartitionData( + feats=torch.zeros(10, 2), ids=torch.arange(10) + ), + partitioned_edge_features=None, + partitioned_positive_labels=None, + partitioned_negative_labels=None, + partitioned_node_labels=None, + ) + dataset = DistDataset(rank=0, world_size=1, edge_dir="in") + dataset.build(partition_output=partition_output) + + create_test_process_group() + with self.assertRaises(ValueError): + DistNeighborLoader( + dataset=dataset, + num_neighbors=[2, 2], + pin_memory_device=torch.device("cpu"), + background_collation_queue_size=0, + ) + + def test_queue_size_one(self) -> None: + """Minimum pipeline depth works correctly.""" + expected_data_count = 18 + partition_output = PartitionOutput( + node_partition_book=torch.zeros(18), + edge_partition_book=None, + partitioned_edge_index=GraphPartitionData( + edge_index=torch.tensor([[10], [15]]), edge_ids=None + ), + partitioned_edge_features=None, + partitioned_node_features=None, + partitioned_negative_labels=None, + partitioned_positive_labels=None, + partitioned_node_labels=None, + ) + dataset = DistDataset(rank=0, world_size=1, edge_dir="out") + dataset.build(partition_output=partition_output) + + mp.spawn( + fn=_run_background_collation_queue_size_one, + args=(dataset, expected_data_count), + ) + + def test_infinite_iterator_with_background_collation(self) -> None: + """Background collation works correctly with InfiniteIterator.""" + partition_output = PartitionOutput( + node_partition_book=torch.zeros(18), + edge_partition_book=None, + partitioned_edge_index=GraphPartitionData( + edge_index=torch.tensor([[10], [15]]), edge_ids=None + ), + partitioned_edge_features=None, + partitioned_node_features=None, + partitioned_negative_labels=None, + partitioned_positive_labels=None, + partitioned_node_labels=None, + ) + dataset = DistDataset(rank=0, world_size=1, edge_dir="out") + dataset.build(partition_output=partition_output) + + # Iterate across more than one epoch + max_num_batches = 18 * 2 + + mp.spawn( + fn=_run_background_collation_infinite_iterator, + args=(dataset, max_num_batches), + ) + + +if __name__ == "__main__": + absltest.main()