2 files changed, +21 −4 lines
```diff
@@ -281,10 +281,8 @@ def build_model_graph(self):
         # we basically want to remove noops in here
         prev = torch._inductor.config.pattern_matcher
         torch._inductor.config.pattern_matcher = False
-        try:
-            gm = joint_graph_passes(gm)
-        finally:
-            torch._inductor.config.pattern_matcher = prev
+        gm = joint_graph_passes(gm)
+        torch._inductor.config.pattern_matcher = prev
         remove_assert_ops(gm.graph)
         gm.graph.eliminate_dead_code()
         gm.recompile()
```
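Note that the hunk above drops the try/finally that restored the global `torch._inductor.config.pattern_matcher` flag even when `joint_graph_passes` raised; after this change, an exception leaves the flag stuck at `False` for the rest of the process. If that guarantee matters, a context manager gives the same safety without the nesting. A minimal sketch, assuming only the `torch._inductor.config.pattern_matcher` attribute used above; the helper name `pattern_matcher_disabled` is hypothetical:

```python
from contextlib import contextmanager

import torch


@contextmanager
def pattern_matcher_disabled():
    # Hypothetical helper: save the global inductor setting, disable it
    # for the duration of the block, and restore it on exit, even if
    # the body raises.
    prev = torch._inductor.config.pattern_matcher
    torch._inductor.config.pattern_matcher = False
    try:
        yield
    finally:
        torch._inductor.config.pattern_matcher = prev


# usage, mirroring the hunk above:
# with pattern_matcher_disabled():
#     gm = joint_graph_passes(gm)
```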
```diff
@@ -851,3 +851,22 @@ def input_fn():
 autop.add_output_constraints([x_sharding])
 
 sharding_placement = autop.optimize_placement()
+parallel_mod = autop.apply_placement(sharding_placement)
+
+# run weight init on our sharded DTensor params
+parallel_mod.to_empty(device="cuda")
+parallel_mod.init_weights(init_std=0.02, buffer_device="cuda")  # maybe not the correct value
+
+# now let's run it
+x = (
+    torch.randn(
+        (bs // mesh.shape[0], seqlen, dim),
+        device=torch.device("cuda"),
+        dtype=torch.bfloat16,
+    ),
+)
+out = parallel_mod(*x)
+out.backward(torch.randn_like(out))
+print("All good!")
```