Commit b5eb863

get DS3 running forward, OOM at backward
1 parent 1dbcdfe commit b5eb863

File tree

autoparallel/api.py
examples/example_ds3.py

2 files changed (+21, -4 lines)

autoparallel/api.py

Lines changed: 2 additions & 4 deletions

@@ -281,10 +281,8 @@ def build_model_graph(self):
         # we basically want to remove noops in here
         prev = torch._inductor.config.pattern_matcher
         torch._inductor.config.pattern_matcher = False
-        try:
-            gm = joint_graph_passes(gm)
-        finally:
-            torch._inductor.config.pattern_matcher = prev
+        gm = joint_graph_passes(gm)
+        torch._inductor.config.pattern_matcher = prev
         remove_assert_ops(gm.graph)
         gm.graph.eliminate_dead_code()
         gm.recompile()
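
The removed try/finally guaranteed that torch._inductor.config.pattern_matcher was restored even if joint_graph_passes raised; the new version restores it only on the success path. A minimal sketch of an equivalent, exception-safe toggle written as a context manager (illustrative only, assuming torch._inductor is importable; not part of the autoparallel API):

import contextlib

import torch._inductor.config as inductor_config


@contextlib.contextmanager
def pattern_matcher_disabled():
    # Save the current flag, disable the inductor pattern matcher, and
    # restore the previous value even if the wrapped code raises.
    prev = inductor_config.pattern_matcher
    inductor_config.pattern_matcher = False
    try:
        yield
    finally:
        inductor_config.pattern_matcher = prev


# Hypothetical usage around the pass from the diff:
# with pattern_matcher_disabled():
#     gm = joint_graph_passes(gm)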

examples/example_ds3.py

Lines changed: 19 additions & 0 deletions

@@ -851,3 +851,22 @@ def input_fn():
 autop.add_output_constraints([x_sharding])

 sharding_placement = autop.optimize_placement()
+parallel_mod = autop.apply_placement(sharding_placement)
+
+# run weight init on our sharded DTensor params
+parallel_mod.to_empty(device="cuda")
+parallel_mod.init_weights(init_std=0.02, buffer_device="cuda")  # maybe not correct value
+
+# # now let's run it
+x = (
+    torch.randn(
+        # 0,
+        # args.vocab_size,
+        (bs // mesh.shape[0], seqlen, dim),
+        device=torch.device("cuda"),
+        dtype=torch.bfloat16
+    ),
+)
+out = parallel_mod(*x)
+out.backward(torch.randn_like(out))
+print("All good!")
