
Commit 59c4d17

Hack around missing dtypes in compute estimation and handle grouped_mm cases with invalid strides
The grouped_mm case should be handled in the sharding propagation, and those cases should just be removed, I think.
1 parent 9f9f00a commit 59c4d17

File tree

1 file changed (+16, -9 lines)


autoparallel/compute_estimation.py

Lines changed: 16 additions & 9 deletions
@@ -147,12 +147,13 @@ def _get_device_tflops(dtype):
             f"Unsupported device: {device_name}. Supported devices: {[limit.name for limit in DEVICE_LIMITS]}"
         )
 
-    if dtype not in device_limit.gemm_tflops:
-        raise ValueError(
-            f"Dtype {dtype} not supported on {device_limit.name}. Supported dtypes: {list(device_limit.gemm_tflops.keys())}"
-        )
+    # TODO: add proper support for int64 etc
+    # if dtype not in device_limit.gemm_tflops:
+    #     raise ValueError(
+    #         f"Dtype {dtype} not supported on {device_limit.name}. Supported dtypes: {list(device_limit.gemm_tflops.keys())}"
+    #     )
 
-    return device_limit.gemm_tflops[dtype]
+    return device_limit.gemm_tflops.get(dtype, 1)
 
 
 def _get_sharded_shape(spec):
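
For illustration only (not part of the commit), a minimal standalone sketch of the dtype fallback above, assuming a made-up gemm_tflops table; toy_gemm_tflops and get_tflops are hypothetical names and the TFLOPS values are placeholders:

import torch

# Illustrative stand-in for device_limit.gemm_tflops; the numbers are placeholders.
toy_gemm_tflops = {torch.float32: 20.0, torch.bfloat16: 300.0}

def get_tflops(dtype):
    # Mirrors gemm_tflops.get(dtype, 1): unknown dtypes (e.g. int64) fall back to a
    # nominal 1 TFLOPS instead of raising ValueError.
    return toy_gemm_tflops.get(dtype, 1)

print(get_tflops(torch.bfloat16))  # 300.0 -> value from the table
print(get_tflops(torch.int64))     # 1     -> nominal fallback for a missing dtype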
@@ -205,10 +206,16 @@ def estimate_strategy_runtime_cost(node, strategy):
 
     # TODO: maybe cache the flop_counter to avoid recreating it
     # all the time
-    with FlopCounterMode(display=False) as flop_counter:
-        node.target(*args, **kwargs)
-
-    flops = flop_counter.get_total_flops()
+    try:
+        with FlopCounterMode(display=False) as flop_counter:
+            node.target(*args, **kwargs)
+
+        flops = flop_counter.get_total_flops()
+    except RuntimeError as exc:
+        if node.target == torch.ops.aten._grouped_mm.default:
+            flops = float("inf")
+        else:
+            raise exc
 
     # TODO: fix this
     dtype = strategy.input_specs[0].tensor_meta.dtype
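
Similarly, a hedged sketch of the grouped_mm workaround above; estimate_flops is a hypothetical helper rather than the function in this file, and the final call uses torch.mm as a stand-in for a well-behaved op:

import torch
from torch.utils.flop_counter import FlopCounterMode

def estimate_flops(target, *args, **kwargs):
    # Hypothetical helper mirroring the try/except added to estimate_strategy_runtime_cost.
    try:
        with FlopCounterMode(display=False) as flop_counter:
            target(*args, **kwargs)
        return flop_counter.get_total_flops()
    except RuntimeError:
        # _grouped_mm can fail for some sharded inputs (invalid strides); treat that
        # strategy as infinitely expensive so it is never selected, instead of
        # aborting the whole estimation.
        if target == torch.ops.aten._grouped_mm.default:
            return float("inf")
        raise

# A well-formed op is counted normally: 2 * 64 * 32 * 16 FLOPs for this matmul.
print(estimate_flops(torch.mm, torch.randn(64, 32), torch.randn(32, 16)))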
