Skip to content

Commit 1dbcdfe

Browse files
committed
Fix grouped_mm stride issue
1 parent 334a830 commit 1dbcdfe

File tree

2 files changed

+12
-15
lines changed

2 files changed

+12
-15
lines changed

autoparallel/compute_estimation.py

Lines changed: 8 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -170,9 +170,10 @@ def _get_sharded_shape_stride(spec):
170170
if placement.is_shard():
171171
dim = placement.dim
172172
new_tensor_shape[dim] = (new_tensor_shape[dim] + mesh_size - 1) // mesh_size
173-
new_tensor_stride[dim] = (
174-
new_tensor_stride[dim] + mesh_size - 1
175-
) // mesh_size
173+
if dim - 1 > 0:
174+
new_tensor_stride[dim - 1] = (
175+
new_tensor_stride[dim - 1] + mesh_size - 1
176+
) // mesh_size
176177
return new_tensor_shape, new_tensor_stride
177178

178179

@@ -213,16 +214,10 @@ def estimate_strategy_runtime_cost(node, strategy):
213214

214215
# TODO: maybe cache the flop_counter to avoid recreating it
215216
# all the time
216-
try:
217-
with FlopCounterMode(display=False) as flop_counter:
218-
node.target(*args, **kwargs)
219-
220-
flops = flop_counter.get_total_flops()
221-
except RuntimeError as exc:
222-
if node.target == torch.ops.aten._grouped_mm.default:
223-
flops = float("inf")
224-
else:
225-
raise exc
217+
with FlopCounterMode(display=False) as flop_counter:
218+
node.target(*args, **kwargs)
219+
220+
flops = flop_counter.get_total_flops()
226221

227222
# TODO: fix this
228223
dtype = strategy.input_specs[0].tensor_meta.dtype

autoparallel/propagation_rules.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -508,9 +508,11 @@ def native_layer_norm_backward_rule(mesh, op_schema):
508508

509509
@register_opschema_rule(torch.ops.prims.convert_element_type.default)
510510
def convert_element_type_rule(mesh, op_schema):
511-
from torch.distributed.tensor._ops._tensor_ops import default_strategy
511+
from torch.distributed.tensor._ops._tensor_ops import (
512+
propagate_single_input_strategy,
513+
)
512514

513-
out_strat = default_strategy(op_schema)
515+
out_strat = propagate_single_input_strategy(op_schema)
514516
return out_strat
515517

516518

0 commit comments

Comments
 (0)