
Commit d5edb93

Make some progress on OmniFMv2 autoparallel (#36)
1 parent 0530113 commit d5edb93

File tree

1 file changed: +2 -7 lines changed


autoparallel/export_module.py

Lines changed: 2 additions & 7 deletions
@@ -134,13 +134,8 @@ def flattened_joint(*args):
     output_gradients = []
     for a, grad in zip(args, gradients):
         if isinstance(a, torch.Tensor) and a.requires_grad:
-            assert (
-                grad is not None
-            ), """\
-Found a parameter that did not receive a gradient.
-"This is most likely a bug, but if this needs to be supported please comment on this Github issue:
-https://github.com/pytorch/pytorch/issues/101192
-"""
+            if grad is None:
+                grad = torch.zeros_like(a)
             output_gradients.append(grad)
         else:
             assert grad is None
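
In effect, the patch replaces the hard "did not receive a gradient" assertion with a zero-gradient fallback: a tensor argument that requires grad but received no gradient is now handed a zeros tensor of matching shape and dtype instead of raising. Below is a minimal runnable sketch of the patched loop; collect_output_gradients is a hypothetical wrapper standing in for the relevant part of flattened_joint, not code from this repository.

import torch

def collect_output_gradients(args, gradients):
    # Hypothetical wrapper around the patched loop from flattened_joint.
    # Tensors that require grad but received no gradient get a zeros
    # tensor of the same shape/dtype instead of tripping an assertion.
    output_gradients = []
    for a, grad in zip(args, gradients):
        if isinstance(a, torch.Tensor) and a.requires_grad:
            if grad is None:
                grad = torch.zeros_like(a)  # the new fallback path
            output_gradients.append(grad)
        else:
            # Non-tensor / non-differentiable inputs still must have no grad.
            assert grad is None
    return output_gradients

# Usage: one parameter receives a gradient, the other does not.
p1 = torch.randn(3, requires_grad=True)
p2 = torch.randn(3, requires_grad=True)
grads = collect_output_gradients([p1, p2], [torch.ones(3), None])
assert torch.equal(grads[1], torch.zeros(3))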
