
Commit 1ca0f1d

bdhirsh and wconstab authored
use passed in model's init_weights fn instead of random weight init (#20)
* use passed in model's init_weights fn instead of random weight init
* remove breakpoint
* lint
* lint fix
* update tests, make init_weights an explicit method for user to call
* plumb args/kwargs to init_weights
* ensure we parametrize with real tensors
* ensure init_weights() works when user calls to_empty
* ensure init_weights works with to_empty, put params on meta device
* minor fixes

---------

Co-authored-by: Will Constable <whc@meta.com>
1 parent 5726d7c commit 1ca0f1d

File tree: 5 files changed (+85 / -16 lines)

autoparallel/api.py

Lines changed: 45 additions & 9 deletions
@@ -16,6 +16,7 @@
 from torch._logging import trace_structured
 from torch._subclasses import FakeTensorMode
 from torch.distributed.tensor import DeviceMesh
+from torch.nn.utils import stateless
 
 from .apply_sharding import apply_sharding_to_model
 from .export_module import aot_export_module, apply_node_renaming
@@ -380,21 +381,56 @@ def apply_placement(self, sharding_placement=None):
         )
         self.parallel_gm = parallel_gm
 
-        param_names = [k.replace(".", "/") for k, _ in self.model.named_parameters()]
-        buffer_names = [k.replace(".", "/") for k, _ in self.model.named_buffers()]
+        param_names = [k for k, _ in self.model.named_parameters()]
+        buffer_names = [k for k, _ in self.model.named_buffers()]
+        param_names_no_fqns = [k.replace(".", "/") for k in param_names]
+        buffer_names_no_fqns = [k.replace(".", "/") for k in buffer_names]
         assert len(param_names) == len(sharded_weights)
         assert len(buffer_names) == len(sharded_buffers)
-        sharded_weights = {k: v for k, v in zip(param_names, sharded_weights)}
-        sharded_buffers = {k: v for k, v in zip(buffer_names, sharded_buffers)}
 
-        self.sharded_weights = sharded_weights
-        self.sharded_buffers = sharded_buffers
+        sharded_weights_no_fqns = {
+            k: v for k, v in zip(param_names_no_fqns, sharded_weights)
+        }
+        sharded_buffers_no_fqns = {
+            k: v for k, v in zip(buffer_names_no_fqns, sharded_buffers)
+        }
+
+        # TODO: preserve state dict properly in the generated nn.module
+        self.sharded_weights = sharded_weights_no_fqns
+        self.sharded_buffers = sharded_buffers_no_fqns
         self.parallel_model_fn, self.fwd_gm, self.bwd_gm = prepare_module(
             parallel_gm, self.spec, self.metadata.num_outputs
         )
 
-        sharded_weights = try_convert_fake_to_real(sharded_weights)
-        sharded_buffers = try_convert_fake_to_real(sharded_buffers)
-        self.parallel_model = self.parallel_model_fn(sharded_weights, sharded_buffers)
+        self.parallel_model = self.parallel_model_fn(
+            sharded_weights_no_fqns, sharded_buffers_no_fqns
+        )
+
+        # Right now we require a convention that the user model provides an init_weights method,
+        # although we could snoop for other methods too.
+        if hasattr(self.model, "init_weights"):
+
+            def init_weights(*args, **kwargs):
+                # TODO: once we have proper FQN support we should remove this
+                # Replace 'params.tok_embeddings/weight' -> 'tok_embeddings.weight'
+                # Replace 'buffers_.freqs_cis' -> 'freqs_cis'
+                sharded_params_buffers = {
+                    k.replace("params.", "")
+                    .replace("buffers_.", "")
+                    .replace("/", "."): v
+                    for k, v in self.parallel_model.state_dict().items()
+                }
+                with stateless._reparametrize_module(
+                    self.model, sharded_params_buffers
+                ):
+                    self.model.init_weights(*args, **kwargs)
+
+        else:
+            init_weights = None
+
+        # assign an init_weights method onto the output mod.
+        # all it does is sneakily run the original user mod's init_weights method,
+        # but with our new DTensor sharded params attached to the user module.
+        self.parallel_model.init_weights = init_weights
 
         return self.parallel_model

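Note: the init_weights wrapper above relies on the private torch.nn.utils.stateless._reparametrize_module API (the same call used in the diff) to run the user's original init code against the sharded DTensor parameters. The following is a minimal standalone sketch of just that mechanism, using a toy module and plain tensors standing in for the sharded DTensors; it is not AutoParallel code.

import torch
import torch.nn as nn
from torch.nn.utils import stateless


class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4, bias=False)

    def init_weights(self):
        # init code written against self.lin.weight, as in a normal model
        torch.nn.init.ones_(self.lin.weight)


m = Tiny()
# Stand-ins for the sharded parameters, keyed by cleaned-up FQNs ("lin.weight" here).
replacement = {"lin.weight": torch.zeros(4, 4)}
with stateless._reparametrize_module(m, replacement):
    m.init_weights()  # the in-place init lands on the swapped-in tensor

print(replacement["lin.weight"].sum())  # tensor(16.): the replacement got initialized
print(m.lin.weight.sum())               # original parameter left untouched (default init)
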
autoparallel/apply_sharding.py

Lines changed: 28 additions & 5 deletions
@@ -3,9 +3,11 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+import contextlib
 import operator
 
 import torch
+from torch._subclasses.fake_tensor import FakeTensor, unset_fake_temporarily
 from torch.distributed.tensor import DTensor
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
 from torch.distributed.tensor._redistribute import redistribute_local_tensor
@@ -155,7 +157,7 @@ def call_function(self, target, args, kwargs):
         return out
 
 
-def shard_nodes_given_placements(gm, sharding_placement, node_prefix):
+def shard_nodes_given_placements(gm, sharding_placement, node_prefix, *, meta=False):
     # NOTE: this relies my customized export_module
     nodes = [
         x for x in gm.graph.find_nodes(op="placeholder") if node_prefix in x.target
@@ -167,9 +169,21 @@ def shard_nodes_given_placements(gm, sharding_placement, node_prefix):
         # all tensors start as replicated
        curr_placement = (Replicate(),) * mesh.ndim
         tensor = node.meta["val"]
-        sharded_tensor = DTensor.from_local(tensor, mesh, curr_placement).redistribute(
-            mesh, tgt_spec.placements
-        )
+
+        if meta:
+            assert isinstance(
+                tensor, FakeTensor
+            ), f"only FakeTensor params supported for now, got {type(tensor)}"
+            ctx = unset_fake_temporarily
+            with ctx():
+                tensor = torch.randn(tensor.shape, dtype=tensor.dtype, device="meta")
+        else:
+            ctx = contextlib.nullcontext
+
+        with ctx():
+            sharded_tensor = DTensor.from_local(
+                tensor, mesh, curr_placement
+            ).redistribute(mesh, tgt_spec.placements)
         sharded_tensors.append(sharded_tensor)
     return sharded_tensors
 
@@ -189,4 +203,13 @@ def apply_sharding_to_model(gm, sharding_placement):
     args = [x.to_local() for x in args]
     parallel_gm = make_fx(interp.run)(*args)
 
-    return parallel_gm, sharded_params, sharded_buffers
+    # We put DTensor(meta_tensor) tensors in the state dict, as the user expects to be
+    # able to call parallel_mod.to_empty(device='cuda'). This does not work with FakeTensors.
+    sharded_meta_params = shard_nodes_given_placements(
+        gm, sharding_placement, "param", meta=True
+    )
+    sharded_meta_buffers = shard_nodes_given_placements(
+        gm, sharding_placement, "buffer", meta=True
+    )
+
+    return parallel_gm, sharded_meta_params, sharded_meta_buffers

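Note: with meta=True, shard_nodes_given_placements above rebuilds each traced FakeTensor as a plain tensor on the meta device before wrapping it in a DTensor, so the resulting module can later be materialized with to_empty(). Below is a rough standalone sketch of that FakeTensor-to-meta hop only (single process, no mesh or DTensor involved); in the diff the conversion runs while a fake mode may still be active, which is why the unset_fake_temporarily() guard is there, whereas in this sketch it is effectively a no-op.

import torch
from torch._subclasses import FakeTensorMode
from torch._subclasses.fake_tensor import FakeTensor, unset_fake_temporarily

with FakeTensorMode():
    fake = torch.randn(8, 4)  # a FakeTensor: carries shape/dtype/device, no data

assert isinstance(fake, FakeTensor)

with unset_fake_temporarily():  # step outside fake-tensor dispatch, as in the diff
    meta = torch.randn(fake.shape, dtype=fake.dtype, device="meta")

print(meta.is_meta, meta.shape)  # True torch.Size([8, 4]): ready for to_empty() later
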
autoparallel/utils.py

Lines changed: 0 additions & 1 deletion
@@ -67,7 +67,6 @@ def fill_missing_redistribute_cost(op, specs, out_strat):
     for strat in out_strat.strategies:
         # TODO: check me
         if strat.redistribute_cost is None:
-
             # TODO: the torch.ops.aten.slice.Tensor is wrong here and in the input_spec!!!!!
             handled_ops = {
                 torch.ops.aten.ones_like.default,

examples/example_autoparallel.py

Lines changed: 9 additions & 0 deletions
@@ -23,6 +23,12 @@ def __init__(self, nheads, dim1, dim2):
         self.w1 = nn.Linear(dim1, dim2, bias=bias)
         self.w2 = nn.Linear(dim2, dim1, bias=bias)
 
+    def init_weights(self):
+        for lin in [self.wq, self.wk, self.wv, self.wo, self.w1, self.w2]:
+            torch.nn.init.normal_(lin.weight)
+            if lin.bias is not None:
+                torch.nn.init.normal_(lin.bias)
+
     def forward(self, x):
         q = self.wq(x)
         k = self.wk(x)
@@ -94,6 +100,9 @@ def input_fn():
 sharding_placement = autop.optimize_placement()
 parallel_mod = autop.apply_placement(sharding_placement)
 
+# run weight init on our sharded DTensor params
+parallel_mod.init_weights()
+
 # now let's run it
 x = (torch.rand(bs // mesh.shape[0], seq_len, dim1, device="cuda"),)
 out = parallel_mod(*x)

examples/example_llama3.py

Lines changed: 3 additions & 1 deletion
@@ -475,7 +475,6 @@ def __init__(self, model_args: TransformerModelArgs):
             self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
         self.norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
         self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
-        self.init_weights()
 
     def init_weights(
         self,
@@ -628,6 +627,9 @@ def input_fn():
 print(f"Took {time.time() - t:.2f} s")
 parallel_mod = autop.apply_placement(sharding_placement)
 
+# run weight init on our sharded DTensor params
+parallel_mod.init_weights()
+
 # now let's run it
 x = (
     torch.randint(

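Note on dropping self.init_weights() from Transformer.__init__ above: the commit message calls out making init_weights work with to_empty and putting params on the meta device, and the second hunk defers init to an explicit parallel_mod.init_weights() call after apply_placement(). The following is a generic sketch of that meta / to_empty / deferred-init pattern in plain PyTorch, not code from this repo.

import torch
import torch.nn as nn

with torch.device("meta"):
    m = nn.Linear(16, 16)       # parameters allocated as meta tensors, no storage

print(m.weight.is_meta)         # True: nothing to initialize yet

m.to_empty(device="cpu")        # allocate real (uninitialized) storage
torch.nn.init.zeros_(m.weight)  # only now run the actual init on concrete tensors
print(m.weight.is_meta)         # False
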