Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions gigl/distributed/dist_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
from gigl.utils.data_splitters import (
NodeAnchorLinkSplitter,
NodeSplitter,
    validate_max_labels_per_anchor_node,
)
from gigl.utils.share_memory import share_memory

Expand Down Expand Up @@ -149,7 +153,13 @@ def __init__(
self._degree_tensor: Optional[
Union[torch.Tensor, dict[EdgeType, torch.Tensor]]
] = degree_tensor
self._max_labels_per_anchor_node = validate_max_labels_per_anchor_node(
max_labels_per_anchor_node
)

# TODO (mkolodner-sc): Modify so that we don't need to rely on GLT's base variable naming (i.e. partition_idx, num_partitions) in favor of more clear
# naming (i.e. rank, world_size).
Expand Down Expand Up @@ -344,7 +354,13 @@ def max_labels_per_anchor_node(self) -> Optional[int]:
def max_labels_per_anchor_node(
self, new_max_labels_per_anchor_node: Optional[int]
) -> None:
self._max_labels_per_anchor_node = validate_max_labels_per_anchor_node(
new_max_labels_per_anchor_node
)

@property
def train_node_ids(
Expand Down
16 changes: 16 additions & 0 deletions gigl/distributed/graph_store/storage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
from gigl.utils.data_splitters import (
DistNodeAnchorLinkSplitter,
DistNodeSplitter,
    get_max_labels_per_anchor_node_from_runtime_args,
)

logger = Logger()
Expand Down Expand Up @@ -77,20 +81,32 @@ def build_storage_dataset(
``0.1`` selects 10 % of edges.
max_labels_per_anchor_node: Optional cap for how many labels to
materialize per anchor node when the storage server serves ABLP
input. If ``None``, this is inferred from the task config's
``trainer_args``.

Returns:
A partitioned :class:`DistDataset` ready to be served.
"""
gbml_config_pb_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri(
gbml_config_uri=task_config_uri
)
if max_labels_per_anchor_node is None:
max_labels_per_anchor_node = get_max_labels_per_anchor_node_from_runtime_args(
dict(gbml_config_pb_wrapper.trainer_config.trainer_args)
)
serialized_graph_metadata = convert_pb_to_serialized_graph_metadata(
preprocessed_metadata_pb_wrapper=gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper,
graph_metadata_pb_wrapper=gbml_config_pb_wrapper.graph_metadata_pb_wrapper,
tfrecord_uri_pattern=tf_record_uri_pattern,
)
dataset = build_dataset(
serialized_graph_metadata=serialized_graph_metadata,
sample_edge_direction=sample_edge_direction,
Expand Down
Loading