Commit 760cc7d

zpcore and fmassa authored
Support of device ordering (#95)
* Trying out a new version of redistribute_local_tensor. Taken from pytorch/pytorch#160266, but I'm hitting an assertion for now
* Update ordered sharding
* Relocate the redistribute tensor function (JAX-style mapping of tensor dims to mesh dims)
* Fix loss curve mismatch
* Improve ordering logic and bring back _optimize_same_nd_sharding_as_1d
* Lint
* Fix small bug
* Address review feedback
* Fix CI

Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
1 parent 939635a commit 760cc7d
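For readers landing on this commit without context: "device ordering" refers to the order in which mesh dimensions are applied when one tensor dimension is sharded across several mesh axes; that order decides which axis produces the outer (coarser) split and therefore which device ends up holding which slice. The snippet below is a plain-Python toy sketch of that effect, not code from this repository; the mesh axis names and the helper are made up for illustration.

# Toy sketch (not from this repo): the mesh axis applied first gives the outer
# split, so reversing the axis order moves data between devices even though the
# set of placements is the same.
def shard_over_mesh(data, mesh_shape, axis_order):
    # Split `data` once per mesh axis, in the requested order; keys record the
    # (axis, index) coordinates of the device holding each shard.
    shards = {(): data}
    for axis in axis_order:
        n = mesh_shape[axis]
        new_shards = {}
        for coord, chunk in shards.items():
            step = len(chunk) // n
            for i in range(n):
                new_shards[coord + ((axis, i),)] = chunk[i * step : (i + 1) * step]
        shards = new_shards
    return shards

data = list(range(8))
mesh = {"dp": 2, "tp": 2}

print(shard_over_mesh(data, mesh, ("dp", "tp")))  # device (dp=0, tp=1) holds [2, 3]
print(shard_over_mesh(data, mesh, ("tp", "dp")))  # device (dp=0, tp=1) holds [4, 5]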

5 files changed: +942 −65 lines changed


.github/workflows/test_torchtitan.yml

Lines changed: 1 addition & 1 deletion
@@ -47,4 +47,4 @@ jobs:
           --model.name llama3_auto_parallel \
           --parallelism.tensor_parallel_degree 4 \
           --training.dataset c4 \
-          --training.compile
+          --compile.enable

autoparallel/apply_sharding.py

Lines changed: 14 additions & 9 deletions
@@ -51,23 +51,28 @@ def run_node(self, n):
         self._curr_node = n
         return super().run_node(n)

-    def redistribute_tensor(self, arg, curr_spec, tgt_spec, src_tgt_nodes=None):
+    def redistribute_tensor(self, arg, curr_spec, tgt_spec, node=None):
         tgt_placements = tuple(
             p if not p.is_partial() else Replicate() for p in tgt_spec.placements
         )
         x = arg
+        if node in self.param_placement_order and self.param_placement_order[node][1]:
+            assert curr_spec.placements != tgt_spec.placements
         if curr_spec.placements != tgt_spec.placements:
             tgt_spec_c = DTensorSpec(
                 tgt_spec.mesh, tgt_placements, tensor_meta=tgt_spec.tensor_meta
             )
-            placement_order = None
-            if (
-                src_tgt_nodes is not None
-                and src_tgt_nodes in self.param_placement_order
-            ):
-                placement_order = self.param_placement_order[src_tgt_nodes]
+            origin_order = None
+            tgt_order = None
+            if node in self.param_placement_order:
+                tgt_order, do_reorder = self.param_placement_order[node]
+                origin_order = tgt_order[::-1] if do_reorder else tgt_order
             x = ordered_redistribute_local_tensor(
-                arg, curr_spec, tgt_spec_c, placement_order
+                arg,
+                curr_spec,
+                tgt_spec_c,
+                src_placement_order=origin_order,
+                tgt_placement_order=tgt_order,
             )
         return x

@@ -109,7 +114,7 @@ def redistribute_args(self, args):
         for n, arg, curr_spec, tgt_spec in zip(
             all_input_nodes, flat_args_t, curr_specs, tgt_specs
         ):
-            x = self.redistribute_tensor(arg, curr_spec, tgt_spec, (node, n))
+            x = self.redistribute_tensor(arg, curr_spec, tgt_spec, node)
             new_flat_args_t.append(x)
             self.tgt_spec = tgt_spec
         new_flat_args = []
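For readability, the ordering logic added above can be restated in isolation. The helper below is a sketch that mirrors the new code path, assuming param_placement_order maps a graph node to a (tgt_order, do_reorder) pair as in the diff; resolve_placement_orders is a made-up name, and ordered_redistribute_local_tensor (defined elsewhere in this repository) is the eventual consumer of the two orders.

# Standalone restatement of the ordering logic added in redistribute_tensor.
# `param_placement_order` is assumed to map a graph node to a tuple
# (tgt_order, do_reorder), where tgt_order is a sequence of mesh-dim indices.
def resolve_placement_orders(param_placement_order, node):
    """Return (src_order, tgt_order) to pass to the ordered redistribution."""
    if node not in param_placement_order:
        return None, None
    tgt_order, do_reorder = param_placement_order[node]
    # When reordering is requested, the source traverses the mesh dims in the
    # opposite order from the target; otherwise the two orders coincide.
    src_order = tgt_order[::-1] if do_reorder else tgt_order
    return src_order, tgt_order

# Toy check with mesh-dim indices (0, 1):
print(resolve_placement_orders({"w1": ((0, 1), True)}, "w1"))   # ((1, 0), (0, 1))
print(resolve_placement_orders({"w1": ((0, 1), False)}, "w1"))  # ((0, 1), (0, 1))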

0 commit comments
