optimize streaming #9425

Open
Jintao-Huang wants to merge 2 commits into modelscope:main from Jintao-Huang:update_streaming

Conversation

@Jintao-Huang
Collaborator

No description provided.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request optimizes the DataLoaderDispatcher by separating tensors from non-tensor structures (schemas) during scatter operations. Lightweight schemas are scattered via pickle, while tensors are transferred efficiently using asynchronous P2P isend and irecv operations. The review feedback highlights several critical improvements: keeping strong references to temporary tensors during asynchronous sends to prevent premature garbage collection, handling namedtuple instances and preserving custom dictionary subclasses during flattening and unflattening, and adding support for the hccl backend to enable efficient transfers on Huawei NPU devices.

Comment thread swift/dataloader/dispatcher.py
Comment on lines +19 to +34
def _flatten_for_scatter(obj, tensors):
    """Recursively separate tensors from a nested structure.

    Tensors are appended to `tensors` and replaced by _TensorMeta sentinels.
    The returned schema is lightweight and can be pickled efficiently.
    """
    if torch.is_tensor(obj):
        idx = len(tensors)
        tensors.append(obj)
        return _TensorMeta(idx, tuple(obj.shape), obj.dtype)
    elif isinstance(obj, dict):
        return {k: _flatten_for_scatter(v, tensors) for k, v in obj.items()}
    elif isinstance(obj, (tuple, list)):
        return type(obj)(_flatten_for_scatter(v, tensors) for v in obj)
    else:
        return obj

high

If the input data contains namedtuple instances (which are common in the PyTorch and Hugging Face ecosystems), calling type(obj)(...) with a generator will fail, because a namedtuple constructor expects one argument per field rather than a single iterable. Additionally, we should preserve the original dictionary subclass (such as BatchEncoding or OrderedDict) by using type(obj)(...) instead of returning a plain dict.

Suggested change
def _flatten_for_scatter(obj, tensors):
    """Recursively separate tensors from a nested structure.

    Tensors are appended to `tensors` and replaced by _TensorMeta sentinels.
    The returned schema is lightweight and can be pickled efficiently.
    """
    if torch.is_tensor(obj):
        idx = len(tensors)
        tensors.append(obj)
        return _TensorMeta(idx, tuple(obj.shape), obj.dtype)
    elif isinstance(obj, dict):
        return type(obj)({k: _flatten_for_scatter(v, tensors) for k, v in obj.items()})
    elif isinstance(obj, (tuple, list)):
        if hasattr(obj, '_fields'):
            return type(obj)(*(_flatten_for_scatter(v, tensors) for v in obj))
        return type(obj)(_flatten_for_scatter(v, tensors) for v in obj)
    else:
        return obj
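
A minimal, torch-free sketch of the two behaviours this comment describes. `Batch`, `rebuild_naive`, and `rebuild_safe` are hypothetical names introduced only for illustration; the behaviour shown is that of the Python standard library, not of the PR's code.

```python
from collections import OrderedDict, namedtuple

# Hypothetical container, used only for this illustration.
Batch = namedtuple("Batch", ["input_ids", "labels"])


def rebuild_naive(obj):
    """Mirror the original pattern: pass a generator to the container type."""
    return type(obj)(v for v in obj)


def rebuild_safe(obj):
    """Unpack the items for namedtuples, which take one argument per field."""
    if hasattr(obj, "_fields"):
        return type(obj)(*(v for v in obj))
    return type(obj)(v for v in obj)


batch = Batch(input_ids=[1, 2], labels=[0, 1])

try:
    rebuild_naive(batch)
    naive_failed = False
except TypeError:
    # The generator is bound to `input_ids` and `labels` is missing.
    naive_failed = True

rebuilt = rebuild_safe(batch)

# A dict comprehension always yields a plain dict; wrapping it keeps the subclass.
od = OrderedDict(a=1, b=2)
plain = {k: v for k, v in od.items()}
preserved = type(od)({k: v for k, v in od.items()})
```

The standard library also offers `type(obj)._make(iterable)` for namedtuples, which could replace the explicit unpacking if that reads better.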

Comment on lines +37 to +46
def _unflatten_from_scatter(schema, tensors):
    """Reconstruct the original nested structure from schema and flat tensors list."""
    if isinstance(schema, _TensorMeta):
        return tensors[schema.idx]
    elif isinstance(schema, dict):
        return {k: _unflatten_from_scatter(v, tensors) for k, v in schema.items()}
    elif isinstance(schema, (tuple, list)):
        return type(schema)(_unflatten_from_scatter(v, tensors) for v in schema)
    else:
        return schema

high

Similarly to _flatten_for_scatter, we should handle namedtuple instances and preserve dictionary subclasses when reconstructing the original nested structure.

Suggested change
def _unflatten_from_scatter(schema, tensors):
    """Reconstruct the original nested structure from schema and flat tensors list."""
    if isinstance(schema, _TensorMeta):
        return tensors[schema.idx]
    elif isinstance(schema, dict):
        return type(schema)({k: _unflatten_from_scatter(v, tensors) for k, v in schema.items()})
    elif isinstance(schema, (tuple, list)):
        if hasattr(schema, '_fields'):
            return type(schema)(*(_unflatten_from_scatter(v, tensors) for v in schema))
        return type(schema)(_unflatten_from_scatter(v, tensors) for v in schema)
    else:
        return schema
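
To see the flatten/unflatten pair work end to end, here is a self-contained round-trip sketch that runs without PyTorch. `FakeTensor` stands in for `torch.Tensor`, `TensorMeta` stands in for the PR's `_TensorMeta` (its fields here are an assumption), and `Pair` is a hypothetical namedtuple; none of these names come from the PR.

```python
from collections import OrderedDict, namedtuple
from dataclasses import dataclass


@dataclass(frozen=True)
class TensorMeta:
    """Stand-in for the PR's _TensorMeta sentinel (fields assumed)."""
    idx: int
    nbytes: int


class FakeTensor:
    """Stand-in for torch.Tensor so the sketch runs without PyTorch."""

    def __init__(self, data):
        self.data = bytes(data)


Pair = namedtuple("Pair", ["first", "second"])  # hypothetical container


def flatten(obj, tensors):
    """Replace each FakeTensor by a TensorMeta and collect the payloads."""
    if isinstance(obj, FakeTensor):
        tensors.append(obj)
        return TensorMeta(len(tensors) - 1, len(obj.data))
    if isinstance(obj, dict):
        return type(obj)({k: flatten(v, tensors) for k, v in obj.items()})
    if isinstance(obj, (tuple, list)):
        items = [flatten(v, tensors) for v in obj]
        return type(obj)(*items) if hasattr(obj, "_fields") else type(obj)(items)
    return obj


def unflatten(schema, tensors):
    """Inverse of flatten: put the payloads back where the sentinels are."""
    if isinstance(schema, TensorMeta):
        return tensors[schema.idx]
    if isinstance(schema, dict):
        return type(schema)({k: unflatten(v, tensors) for k, v in schema.items()})
    if isinstance(schema, (tuple, list)):
        items = [unflatten(v, tensors) for v in schema]
        return type(schema)(*items) if hasattr(schema, "_fields") else type(schema)(items)
    return schema


batch = OrderedDict(
    pair=Pair(FakeTensor(b"x" * 4096), FakeTensor(b"y" * 4096)),
    ids=[1, 2],
)
tensors = []
schema = flatten(batch, tensors)
restored = unflatten(schema, tensors)
```

The schema keeps only small sentinels in place of the payloads, which is what makes it cheap to pickle, while the container types survive the round trip.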

Comment thread swift/dataloader/dispatcher.py
@Jintao-Huang
Collaborator Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request optimizes the DataLoaderDispatcher._scatter_object_list method by separating tensors from non-tensor structures (schemas). The lightweight schemas are scattered via pickle, while the tensors are transferred using point-to-point communication (dist.isend and dist.irecv). The review feedback raises two issues. First, flattening and unflattening should preserve the original dictionary subclass (such as BatchEncoding) and support namedtuple, so that downstream code keeps working and TypeError exceptions are avoided. Second, the review states that point-to-point APIs (dist.isend/dist.irecv) must be given group-relative ranks rather than global ranks when a process group is specified, to avoid out-of-bounds errors or incorrect communication in sub-group configurations.

Comment on lines +29 to +32
    elif isinstance(obj, dict):
        return {k: _flatten_for_scatter(v, tensors) for k, v in obj.items()}
    elif isinstance(obj, (tuple, list)):
        return type(obj)(_flatten_for_scatter(v, tensors) for v in obj)

high

Preserve the original dictionary subclass (such as Hugging Face's BatchEncoding) and support namedtuple during flattening. Returning a plain dict for BatchEncoding will break downstream code that accesses keys as attributes (e.g., batch.input_ids). Additionally, instantiating a namedtuple with a generator directly (i.e., type(obj)(generator)) raises a TypeError because its __new__ expects separate positional arguments.

Suggested change
    elif isinstance(obj, dict):
        return type(obj)({k: _flatten_for_scatter(v, tensors) for k, v in obj.items()})
    elif isinstance(obj, (tuple, list)):
        if hasattr(obj, '_fields'):  # namedtuple
            return type(obj)(*(_flatten_for_scatter(v, tensors) for v in obj))
        return type(obj)(_flatten_for_scatter(v, tensors) for v in obj)
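
One caveat may be worth checking before applying `type(obj)({...})` to every mapping: not every dict subclass accepts a mapping as its first positional argument. `collections.defaultdict`, for example, expects a `default_factory` there. The sketch below shows one possible fallback; `rebuild_mapping` is a hypothetical helper, not part of the PR.

```python
from collections import OrderedDict, defaultdict


def rebuild_mapping(original, items):
    """Try to rebuild `items` as the same mapping type as `original`.

    Falls back to a plain dict if the subclass constructor rejects a mapping.
    """
    try:
        return type(original)(items)
    except TypeError:
        return dict(items)


ordered = rebuild_mapping(OrderedDict(a=1), {"a": 1})
fallback = rebuild_mapping(defaultdict(list), {"a": [1]})
```

Whether a silent fallback or an explicit error is preferable depends on which container types the dataloader is expected to produce.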

Comment on lines +41 to +44
    elif isinstance(schema, dict):
        return {k: _unflatten_from_scatter(v, tensors) for k, v in schema.items()}
    elif isinstance(schema, (tuple, list)):
        return type(schema)(_unflatten_from_scatter(v, tensors) for v in schema)

high

Preserve the original dictionary subclass and support namedtuple during unflattening to match the types of the original batch containers.

Suggested change
    elif isinstance(schema, dict):
        return type(schema)({k: _unflatten_from_scatter(v, tensors) for k, v in schema.items()})
    elif isinstance(schema, (tuple, list)):
        if hasattr(schema, '_fields'):  # namedtuple
            return type(schema)(*(_unflatten_from_scatter(v, tensors) for v in schema))
        return type(schema)(_unflatten_from_scatter(v, tensors) for v in schema)

Comment on lines +126 to +133
            for r in range(1, self.world_size):
                dst_rank = dist.get_global_rank(self.group, r)
                for t in per_rank_tensors[r]:
                    tensor = t.contiguous()
                    if scatter_device is not None:
                        tensor = tensor.to(scatter_device)
                    send_bufs.append(tensor)
                    handles.append(dist.isend(tensor, dst=dst_rank, group=self.group))

high

In PyTorch distributed, when a process group is specified in point-to-point communication APIs like dist.isend and dist.irecv, the destination/source rank must be the group-relative rank (local rank within the group), not the global rank. Using the global rank dst_rank with group=self.group will cause a ValueError (out of bounds) or incorrect/corrupted communication if self.group is a sub-group (e.g., in pipeline or tensor parallel setups).

Suggested change
            for r in range(1, self.world_size):
                for t in per_rank_tensors[r]:
                    tensor = t.contiguous()
                    if scatter_device is not None:
                        tensor = tensor.to(scatter_device)
                    send_bufs.append(tensor)
                    handles.append(dist.isend(tensor, dst=r, group=self.group))
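
To make the two numbering schemes concrete, here is a small torch-free sketch. `group_ranks` is a hypothetical list of the global ranks that form a sub-group, and the two helpers only mimic what `dist.get_global_rank` and `dist.get_group_rank` compute. Which scheme `dst` and `src` expect is worth confirming against the documentation of the installed PyTorch version before applying this suggestion: the PyTorch documentation describes `dst` in `dist.isend` as a rank on the global process group regardless of the `group` argument, and recent releases add separate `group_dst` and `group_src` keywords for group-relative ranks.

```python
def to_global_rank(group_ranks, group_rank):
    """Group-relative rank -> global rank (the mapping dist.get_global_rank performs)."""
    return group_ranks[group_rank]


def to_group_rank(group_ranks, global_rank):
    """Global rank -> group-relative rank (the mapping dist.get_group_rank performs)."""
    return group_ranks.index(global_rank)


# Hypothetical layout: 8 processes, with a sub-group made of global ranks 4..7.
group_ranks = [4, 5, 6, 7]

# The sender is group rank 0, which is global rank 4 in this layout.
src_global = to_global_rank(group_ranks, 0)

# (group-relative rank, global rank) for every receiver in the sub-group.
destinations = [
    (r, to_global_rank(group_ranks, r)) for r in range(1, len(group_ranks))
]
```

In the default world group the two numberings coincide, so a mismatch only shows up once the dispatcher runs on a sub-group.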

Comment on lines +160 to +163
            for meta in metas:
                recv_buf = torch.empty(meta.shape, dtype=meta.dtype, device=device)
                handles.append(dist.irecv(recv_buf, src=global_src_rank, group=self.group))
                my_tensors.append(recv_buf)

high

Similarly, for dist.irecv, when group=self.group is specified, the source rank must be the group-relative rank (which is 0 for the sender rank 0), not the global rank global_src_rank.

Suggested change
            for meta in metas:
                recv_buf = torch.empty(meta.shape, dtype=meta.dtype, device=device)
                handles.append(dist.irecv(recv_buf, src=0, group=self.group))
                my_tensors.append(recv_buf)

@Jintao-Huang
Collaborator Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request optimizes the DataLoaderDispatcher._scatter_object_list method by separating tensors from non-tensor structures (schemas), allowing lightweight schemas to be scattered via pickle and tensors to be transferred efficiently using point-to-point async communication (isend/irecv). The reviewer identified critical issues regarding the handling of namedtuple objects, which will raise a TypeError when instantiated with a generator expression, and the incorrect use of global ranks instead of group-relative ranks for distributed operations when a custom process group is specified.

Comment thread swift/dataloader/dispatcher.py
Comment thread swift/dataloader/dispatcher.py
Comment on lines 101 to +163
        global_src_rank = dist.get_global_rank(self.group, 0)
        dist.scatter_object_list(outputs, inputs, global_src_rank, group=self.group)
        return outputs[0]
        scatter_device = self._scatter_device

        if self.rank == 0:
            # Flatten each rank's data: separate tensors from schema
            schemas = []
            per_rank_tensors = []
            for item in inputs:
                if item is None:
                    schemas.append(None)
                    per_rank_tensors.append([])
                else:
                    tensors = []
                    schema = _flatten_for_scatter(item, tensors)
                    schemas.append(schema)
                    per_rank_tensors.append(tensors)

            # Scatter lightweight schemas (no tensor payload, fast pickle)
            schema_out = [None]
            dist.scatter_object_list(schema_out, schemas, global_src_rank, group=self.group)
            my_schema = schema_out[0]

            # Send tensors to other ranks via async P2P
            handles = []
            send_bufs = []  # keep tensors alive until sends complete
            for r in range(1, self.world_size):
                dst_rank = dist.get_global_rank(self.group, r)
                for t in per_rank_tensors[r]:
                    tensor = t.contiguous()
                    if scatter_device is not None:
                        tensor = tensor.to(scatter_device)
                    send_bufs.append(tensor)
                    handles.append(dist.isend(tensor, dst=dst_rank, group=self.group))

            # Rank 0 keeps its own tensors (move to device if needed)
            my_tensors = per_rank_tensors[0]
            if scatter_device is not None:
                my_tensors = [t.contiguous().to(scatter_device) for t in my_tensors]

            # Wait for all sends to complete
            for h in handles:
                h.wait()
            del send_bufs  # safe to release after all sends finished
        else:
            # Receive schema (lightweight)
            schema_out = [None]
            dist.scatter_object_list(schema_out, None, global_src_rank, group=self.group)
            my_schema = schema_out[0]

            if my_schema is None:
                return None

            # Receive tensors via async P2P (shape/dtype from _TensorMeta in schema)
            metas = []
            _collect_tensor_metas(my_schema, metas)
            metas.sort(key=lambda m: m.idx)
            device = scatter_device if scatter_device is not None else 'cpu'
            my_tensors = []
            handles = []
            for meta in metas:
                recv_buf = torch.empty(meta.shape, dtype=meta.dtype, device=device)
                handles.append(dist.irecv(recv_buf, src=global_src_rank, group=self.group))
                my_tensors.append(recv_buf)

high

In PyTorch Distributed, when a custom group is specified in collective or P2P operations (such as dist.scatter_object_list, dist.isend, and dist.irecv), the rank arguments (like src and dst) must be relative to the group (i.e., group-relative ranks), not global ranks.

Using global ranks with a non-WORLD group will cause runtime errors or incorrect communication. We should use 0 as the source rank and r as the destination rank directly, which also simplifies the code by removing the need for dist.get_global_rank calls.

        scatter_device = self._scatter_device

        if self.rank == 0:
            # Flatten each rank's data: separate tensors from schema
            schemas = []
            per_rank_tensors = []
            for item in inputs:
                if item is None:
                    schemas.append(None)
                    per_rank_tensors.append([])
                else:
                    tensors = []
                    schema = _flatten_for_scatter(item, tensors)
                    schemas.append(schema)
                    per_rank_tensors.append(tensors)

            # Scatter lightweight schemas (no tensor payload, fast pickle)
            schema_out = [None]
            dist.scatter_object_list(schema_out, schemas, src=0, group=self.group)
            my_schema = schema_out[0]

            # Send tensors to other ranks via async P2P
            handles = []
            send_bufs = []  # keep tensors alive until sends complete
            for r in range(1, self.world_size):
                for t in per_rank_tensors[r]:
                    tensor = t.contiguous()
                    if scatter_device is not None:
                        tensor = tensor.to(scatter_device)
                    send_bufs.append(tensor)
                    handles.append(dist.isend(tensor, dst=r, group=self.group))

            # Rank 0 keeps its own tensors (move to device if needed)
            my_tensors = per_rank_tensors[0]
            if scatter_device is not None:
                my_tensors = [t.contiguous().to(scatter_device) for t in my_tensors]

            # Wait for all sends to complete
            for h in handles:
                h.wait()
            del send_bufs  # safe to release after all sends finished
        else:
            # Receive schema (lightweight)
            schema_out = [None]
            dist.scatter_object_list(schema_out, None, src=0, group=self.group)
            my_schema = schema_out[0]

            if my_schema is None:
                return None

            # Receive tensors via async P2P (shape/dtype from _TensorMeta in schema)
            metas = []
            _collect_tensor_metas(my_schema, metas)
            metas.sort(key=lambda m: m.idx)
            device = scatter_device if scatter_device is not None else 'cpu'
            my_tensors = []
            handles = []
            for meta in metas:
                recv_buf = torch.empty(meta.shape, dtype=meta.dtype, device=device)
                handles.append(dist.irecv(recv_buf, src=0, group=self.group))
                my_tensors.append(recv_buf)
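
The helper `_collect_tensor_metas` is called above, but its body does not appear in this thread. The following is only a plausible sketch of such a traversal, written without PyTorch: `TensorMeta` and `collect_tensor_metas` are stand-ins, the field names are inferred from how `_TensorMeta` is constructed and used in the snippets above, and `dtype` is a plain string here rather than a `torch.dtype`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TensorMeta:
    """Assumed layout of the PR's _TensorMeta: index, shape, dtype."""
    idx: int
    shape: tuple
    dtype: str  # a string in this sketch; the PR stores a torch.dtype


def collect_tensor_metas(schema, metas):
    """Append every TensorMeta found in a nested schema to `metas`."""
    if isinstance(schema, TensorMeta):
        metas.append(schema)
    elif isinstance(schema, dict):
        for value in schema.values():
            collect_tensor_metas(value, metas)
    elif isinstance(schema, (tuple, list)):
        for value in schema:
            collect_tensor_metas(value, metas)


schema = {
    "labels": TensorMeta(1, (2,), "int64"),
    "inputs": [TensorMeta(0, (2, 8), "float32"), "not a tensor"],
}
metas = []
collect_tensor_metas(schema, metas)
# Sort by idx so receive buffers are posted in the order the sender appended tensors.
metas.sort(key=lambda m: m.idx)
order = [m.idx for m in metas]
```

Sorting by `idx` mirrors the `metas.sort(key=lambda m: m.idx)` call in the code above: the receives are untagged, so they need to be posted in the same order in which rank 0 issues the sends.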
