optimize streaming #9425
Conversation
There was a problem hiding this comment.
Code Review
This pull request optimizes the DataLoaderDispatcher by separating tensors from non-tensor structures (schemas) during scatter operations. Lightweight schemas are scattered via pickle, while tensors are transferred efficiently using asynchronous P2P isend and irecv operations. The review feedback highlights several critical improvements: keeping strong references to temporary tensors during asynchronous sends to prevent premature garbage collection, handling namedtuple instances and preserving custom dictionary subclasses during flattening and unflattening, and adding support for the hccl backend to enable efficient transfers on Huawei NPU devices.
| def _flatten_for_scatter(obj, tensors): | ||
| """Recursively separate tensors from a nested structure. | ||
|
|
||
| Tensors are appended to `tensors` and replaced by _TensorMeta sentinels. | ||
| The returned schema is lightweight and can be pickled efficiently. | ||
| """ | ||
| if torch.is_tensor(obj): | ||
| idx = len(tensors) | ||
| tensors.append(obj) | ||
| return _TensorMeta(idx, tuple(obj.shape), obj.dtype) | ||
| elif isinstance(obj, dict): | ||
| return {k: _flatten_for_scatter(v, tensors) for k, v in obj.items()} | ||
| elif isinstance(obj, (tuple, list)): | ||
| return type(obj)(_flatten_for_scatter(v, tensors) for v in obj) | ||
| else: | ||
| return obj |
There was a problem hiding this comment.
If the input data contains namedtuple instances (which are common in PyTorch/HuggingFace ecosystems), calling type(obj)(...) with a generator will fail because namedtuple constructors expect individual field arguments rather than an iterable. Additionally, we should preserve the original dictionary subclass (like BatchEncoding or OrderedDict) by using type(obj)(...) instead of returning a plain dict.
| def _flatten_for_scatter(obj, tensors): | |
| """Recursively separate tensors from a nested structure. | |
| Tensors are appended to `tensors` and replaced by _TensorMeta sentinels. | |
| The returned schema is lightweight and can be pickled efficiently. | |
| """ | |
| if torch.is_tensor(obj): | |
| idx = len(tensors) | |
| tensors.append(obj) | |
| return _TensorMeta(idx, tuple(obj.shape), obj.dtype) | |
| elif isinstance(obj, dict): | |
| return {k: _flatten_for_scatter(v, tensors) for k, v in obj.items()} | |
| elif isinstance(obj, (tuple, list)): | |
| return type(obj)(_flatten_for_scatter(v, tensors) for v in obj) | |
| else: | |
| return obj | |
| def _flatten_for_scatter(obj, tensors): | |
| """Recursively separate tensors from a nested structure. | |
| Tensors are appended to tensors and replaced by _TensorMeta sentinels. | |
| The returned schema is lightweight and can be pickled efficiently. | |
| """ | |
| if torch.is_tensor(obj): | |
| idx = len(tensors) | |
| tensors.append(obj) | |
| return _TensorMeta(idx, tuple(obj.shape), obj.dtype) | |
| elif isinstance(obj, dict): | |
| return type(obj)({k: _flatten_for_scatter(v, tensors) for k, v in obj.items()}) | |
| elif isinstance(obj, (tuple, list)): | |
| if hasattr(obj, '_fields'): | |
| return type(obj)(*(_flatten_for_scatter(v, tensors) for v in obj)) | |
| return type(obj)(_flatten_for_scatter(v, tensors) for v in obj) | |
| else: | |
| return obj |
| def _unflatten_from_scatter(schema, tensors): | ||
| """Reconstruct the original nested structure from schema and flat tensors list.""" | ||
| if isinstance(schema, _TensorMeta): | ||
| return tensors[schema.idx] | ||
| elif isinstance(schema, dict): | ||
| return {k: _unflatten_from_scatter(v, tensors) for k, v in schema.items()} | ||
| elif isinstance(schema, (tuple, list)): | ||
| return type(schema)(_unflatten_from_scatter(v, tensors) for v in schema) | ||
| else: | ||
| return schema |
There was a problem hiding this comment.
Similarly to _flatten_for_scatter, we should handle namedtuple instances and preserve dictionary subclasses when reconstructing the original nested structure.
| def _unflatten_from_scatter(schema, tensors): | |
| """Reconstruct the original nested structure from schema and flat tensors list.""" | |
| if isinstance(schema, _TensorMeta): | |
| return tensors[schema.idx] | |
| elif isinstance(schema, dict): | |
| return {k: _unflatten_from_scatter(v, tensors) for k, v in schema.items()} | |
| elif isinstance(schema, (tuple, list)): | |
| return type(schema)(_unflatten_from_scatter(v, tensors) for v in schema) | |
| else: | |
| return schema | |
| def _unflatten_from_scatter(schema, tensors): | |
| """Reconstruct the original nested structure from schema and flat tensors list.""" | |
| if isinstance(schema, _TensorMeta): | |
| return tensors[schema.idx] | |
| elif isinstance(schema, dict): | |
| return type(schema)({k: _unflatten_from_scatter(v, tensors) for k, v in schema.items()}) | |
| elif isinstance(schema, (tuple, list)): | |
| if hasattr(schema, '_fields'): | |
| return type(schema)(*(_unflatten_from_scatter(v, tensors) for v in schema)) | |
| return type(schema)(_unflatten_from_scatter(v, tensors) for v in schema) | |
| else: | |
| return schema |
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request optimizes the DataLoaderDispatcher._scatter_object_list method by separating tensors from non-tensor structures (schemas). The lightweight schemas are scattered via pickle, while the tensors are transferred efficiently using point-to-point communication (dist.isend and dist.irecv). The review feedback identifies two critical issues. First, the flattening and unflattening processes should preserve original dictionary subclasses (such as BatchEncoding) and support namedtuple, so that downstream code is not broken and TypeError exceptions are avoided. Second, point-to-point communication APIs (dist.isend/dist.irecv) must use group-relative ranks instead of global ranks when a process group is specified, to avoid out-of-bounds errors or incorrect communication in sub-group configurations.
| elif isinstance(obj, dict): | ||
| return {k: _flatten_for_scatter(v, tensors) for k, v in obj.items()} | ||
| elif isinstance(obj, (tuple, list)): | ||
| return type(obj)(_flatten_for_scatter(v, tensors) for v in obj) |
There was a problem hiding this comment.
Preserve the original dictionary subclass (such as Hugging Face's BatchEncoding) and support namedtuple during flattening. Returning a plain dict for BatchEncoding will break downstream code that accesses keys as attributes (e.g., batch.input_ids). Additionally, instantiating a namedtuple with a generator directly (i.e., type(obj)(generator)) raises a TypeError because its __new__ expects separate positional arguments.
| elif isinstance(obj, dict): | |
| return {k: _flatten_for_scatter(v, tensors) for k, v in obj.items()} | |
| elif isinstance(obj, (tuple, list)): | |
| return type(obj)(_flatten_for_scatter(v, tensors) for v in obj) | |
| elif isinstance(obj, dict): | |
| return type(obj)({k: _flatten_for_scatter(v, tensors) for k, v in obj.items()}) | |
| elif isinstance(obj, (tuple, list)): | |
| if hasattr(obj, '_fields'): # namedtuple | |
| return type(obj)(*(_flatten_for_scatter(v, tensors) for v in obj)) | |
| return type(obj)(_flatten_for_scatter(v, tensors) for v in obj) |
| elif isinstance(schema, dict): | ||
| return {k: _unflatten_from_scatter(v, tensors) for k, v in schema.items()} | ||
| elif isinstance(schema, (tuple, list)): | ||
| return type(schema)(_unflatten_from_scatter(v, tensors) for v in schema) |
There was a problem hiding this comment.
Preserve the original dictionary subclass and support namedtuple during unflattening to match the types of the original batch containers.
| elif isinstance(schema, dict): | |
| return {k: _unflatten_from_scatter(v, tensors) for k, v in schema.items()} | |
| elif isinstance(schema, (tuple, list)): | |
| return type(schema)(_unflatten_from_scatter(v, tensors) for v in schema) | |
| elif isinstance(schema, dict): | |
| return type(schema)({k: _unflatten_from_scatter(v, tensors) for k, v in schema.items()}) | |
| elif isinstance(schema, (tuple, list)): | |
| if hasattr(schema, '_fields'): # namedtuple | |
| return type(schema)(*(_unflatten_from_scatter(v, tensors) for v in schema)) | |
| return type(schema)(_unflatten_from_scatter(v, tensors) for v in schema) |
| for r in range(1, self.world_size): | ||
| dst_rank = dist.get_global_rank(self.group, r) | ||
| for t in per_rank_tensors[r]: | ||
| tensor = t.contiguous() | ||
| if scatter_device is not None: | ||
| tensor = tensor.to(scatter_device) | ||
| send_bufs.append(tensor) | ||
| handles.append(dist.isend(tensor, dst=dst_rank, group=self.group)) |
There was a problem hiding this comment.
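In PyTorch distributed, when a process group is specified in point-to-point communication APIs like dist.isend and dist.irecv, the destination/source rank must be the group-relative rank (local rank within the group), not the global rank. Using the global rank dst_rank with group=self.group will cause a ValueError (out of bounds) or incorrect/corrupted communication if self.group is a sub-group (e.g., in pipeline or tensor parallel setups). (Note: this claim should be checked against the torch.distributed API reference before the suggestion is applied. The reference documents dst and src as ranks on the global process group regardless of the group argument, with group-relative ranks passed through the separate group_dst and group_src parameters in releases that provide them.)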
| for r in range(1, self.world_size): | |
| dst_rank = dist.get_global_rank(self.group, r) | |
| for t in per_rank_tensors[r]: | |
| tensor = t.contiguous() | |
| if scatter_device is not None: | |
| tensor = tensor.to(scatter_device) | |
| send_bufs.append(tensor) | |
| handles.append(dist.isend(tensor, dst=dst_rank, group=self.group)) | |
| for r in range(1, self.world_size): | |
| for t in per_rank_tensors[r]: | |
| tensor = t.contiguous() | |
| if scatter_device is not None: | |
| tensor = tensor.to(scatter_device) | |
| send_bufs.append(tensor) | |
| handles.append(dist.isend(tensor, dst=r, group=self.group)) |
| for meta in metas: | ||
| recv_buf = torch.empty(meta.shape, dtype=meta.dtype, device=device) | ||
| handles.append(dist.irecv(recv_buf, src=global_src_rank, group=self.group)) | ||
| my_tensors.append(recv_buf) |
There was a problem hiding this comment.
Similarly, for dist.irecv, when group=self.group is specified, the source rank must be the group-relative rank (which is 0 for the sender rank 0), not the global rank global_src_rank.
| for meta in metas: | |
| recv_buf = torch.empty(meta.shape, dtype=meta.dtype, device=device) | |
| handles.append(dist.irecv(recv_buf, src=global_src_rank, group=self.group)) | |
| my_tensors.append(recv_buf) | |
| for meta in metas: | |
| recv_buf = torch.empty(meta.shape, dtype=meta.dtype, device=device) | |
| handles.append(dist.irecv(recv_buf, src=0, group=self.group)) | |
| my_tensors.append(recv_buf) |
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request optimizes the DataLoaderDispatcher._scatter_object_list method by separating tensors from non-tensor structures (schemas), allowing lightweight schemas to be scattered via pickle and tensors to be transferred efficiently using point-to-point async communication (isend/irecv). The reviewer identified critical issues regarding the handling of namedtuple objects, which will raise a TypeError when instantiated with a generator expression, and the incorrect use of global ranks instead of group-relative ranks for distributed operations when a custom process group is specified.
| global_src_rank = dist.get_global_rank(self.group, 0) | ||
| dist.scatter_object_list(outputs, inputs, global_src_rank, group=self.group) | ||
| return outputs[0] | ||
| scatter_device = self._scatter_device | ||
|
|
||
| if self.rank == 0: | ||
| # Flatten each rank's data: separate tensors from schema | ||
| schemas = [] | ||
| per_rank_tensors = [] | ||
| for item in inputs: | ||
| if item is None: | ||
| schemas.append(None) | ||
| per_rank_tensors.append([]) | ||
| else: | ||
| tensors = [] | ||
| schema = _flatten_for_scatter(item, tensors) | ||
| schemas.append(schema) | ||
| per_rank_tensors.append(tensors) | ||
|
|
||
| # Scatter lightweight schemas (no tensor payload, fast pickle) | ||
| schema_out = [None] | ||
| dist.scatter_object_list(schema_out, schemas, global_src_rank, group=self.group) | ||
| my_schema = schema_out[0] | ||
|
|
||
| # Send tensors to other ranks via async P2P | ||
| handles = [] | ||
| send_bufs = [] # keep tensors alive until sends complete | ||
| for r in range(1, self.world_size): | ||
| dst_rank = dist.get_global_rank(self.group, r) | ||
| for t in per_rank_tensors[r]: | ||
| tensor = t.contiguous() | ||
| if scatter_device is not None: | ||
| tensor = tensor.to(scatter_device) | ||
| send_bufs.append(tensor) | ||
| handles.append(dist.isend(tensor, dst=dst_rank, group=self.group)) | ||
|
|
||
| # Rank 0 keeps its own tensors (move to device if needed) | ||
| my_tensors = per_rank_tensors[0] | ||
| if scatter_device is not None: | ||
| my_tensors = [t.contiguous().to(scatter_device) for t in my_tensors] | ||
|
|
||
| # Wait for all sends to complete | ||
| for h in handles: | ||
| h.wait() | ||
| del send_bufs # safe to release after all sends finished | ||
| else: | ||
| # Receive schema (lightweight) | ||
| schema_out = [None] | ||
| dist.scatter_object_list(schema_out, None, global_src_rank, group=self.group) | ||
| my_schema = schema_out[0] | ||
|
|
||
| if my_schema is None: | ||
| return None | ||
|
|
||
| # Receive tensors via async P2P (shape/dtype from _TensorMeta in schema) | ||
| metas = [] | ||
| _collect_tensor_metas(my_schema, metas) | ||
| metas.sort(key=lambda m: m.idx) | ||
| device = scatter_device if scatter_device is not None else 'cpu' | ||
| my_tensors = [] | ||
| handles = [] | ||
| for meta in metas: | ||
| recv_buf = torch.empty(meta.shape, dtype=meta.dtype, device=device) | ||
| handles.append(dist.irecv(recv_buf, src=global_src_rank, group=self.group)) | ||
| my_tensors.append(recv_buf) |
There was a problem hiding this comment.
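In PyTorch Distributed, when a custom group is specified in collective or P2P operations (such as dist.scatter_object_list, dist.isend, and dist.irecv), the rank arguments (like src and dst) must be relative to the group (i.e., group-relative ranks), not global ranks.
Using global ranks with a non-WORLD group will cause runtime errors or incorrect communication. We should use 0 as the source rank and r as the destination rank directly, which also simplifies the code by removing the need for dist.get_global_rank calls. (Note: verify this against the torch.distributed API reference before applying the suggestion. The reference documents the src and dst arguments of scatter_object_list, isend, and irecv as ranks on the global process group regardless of the group argument; group-relative ranks are passed through the separate group_src and group_dst parameters in releases that provide them. If so, passing 0 and r directly is only correct when self.group is the default WORLD group.)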
scatter_device = self._scatter_device
if self.rank == 0:
# Flatten each rank's data: separate tensors from schema
schemas = []
per_rank_tensors = []
for item in inputs:
if item is None:
schemas.append(None)
per_rank_tensors.append([])
else:
tensors = []
schema = _flatten_for_scatter(item, tensors)
schemas.append(schema)
per_rank_tensors.append(tensors)
# Scatter lightweight schemas (no tensor payload, fast pickle)
schema_out = [None]
dist.scatter_object_list(schema_out, schemas, src=0, group=self.group)
my_schema = schema_out[0]
# Send tensors to other ranks via async P2P
handles = []
send_bufs = [] # keep tensors alive until sends complete
for r in range(1, self.world_size):
for t in per_rank_tensors[r]:
tensor = t.contiguous()
if scatter_device is not None:
tensor = tensor.to(scatter_device)
send_bufs.append(tensor)
handles.append(dist.isend(tensor, dst=r, group=self.group))
# Rank 0 keeps its own tensors (move to device if needed)
my_tensors = per_rank_tensors[0]
if scatter_device is not None:
my_tensors = [t.contiguous().to(scatter_device) for t in my_tensors]
# Wait for all sends to complete
for h in handles:
h.wait()
del send_bufs # safe to release after all sends finished
else:
# Receive schema (lightweight)
schema_out = [None]
dist.scatter_object_list(schema_out, None, src=0, group=self.group)
my_schema = schema_out[0]
if my_schema is None:
return None
# Receive tensors via async P2P (shape/dtype from _TensorMeta in schema)
metas = []
_collect_tensor_metas(my_schema, metas)
metas.sort(key=lambda m: m.idx)
device = scatter_device if scatter_device is not None else 'cpu'
my_tensors = []
handles = []
for meta in metas:
recv_buf = torch.empty(meta.shape, dtype=meta.dtype, device=device)
handles.append(dist.irecv(recv_buf, src=0, group=self.group))
my_tensors.append(recv_buf)
No description provided.