File tree Expand file tree Collapse file tree 1 file changed +16
-9
lines changed Expand file tree Collapse file tree 1 file changed +16
-9
lines changed Original file line number Diff line number Diff line change @@ -147,12 +147,13 @@ def _get_device_tflops(dtype):
147147 f"Unsupported device: { device_name } . Supported devices: { [limit .name for limit in DEVICE_LIMITS ]} "
148148 )
149149
150- if dtype not in device_limit .gemm_tflops :
151- raise ValueError (
152- f"Dtype { dtype } not supported on { device_limit .name } . Supported dtypes: { list (device_limit .gemm_tflops .keys ())} "
153- )
150+ # TODO: add proper support for int64 etc
151+ # if dtype not in device_limit.gemm_tflops:
152+ # raise ValueError(
153+ # f"Dtype {dtype} not supported on {device_limit.name}. Supported dtypes: {list(device_limit.gemm_tflops.keys())}"
154+ # )
154155
155- return device_limit .gemm_tflops [ dtype ]
156+ return device_limit .gemm_tflops . get ( dtype , 1 )
156157
157158
158159def _get_sharded_shape (spec ):
@@ -205,10 +206,16 @@ def estimate_strategy_runtime_cost(node, strategy):
205206
206207 # TODO: maybe cache the flop_counter to avoid recreating it
207208 # all the time
208- with FlopCounterMode (display = False ) as flop_counter :
209- node .target (* args , ** kwargs )
210-
211- flops = flop_counter .get_total_flops ()
209+ try :
210+ with FlopCounterMode (display = False ) as flop_counter :
211+ node .target (* args , ** kwargs )
212+
213+ flops = flop_counter .get_total_flops ()
214+ except RuntimeError as exc :
215+ if node .target == torch .ops .aten ._grouped_mm .default :
216+ flops = float ("inf" )
217+ else :
218+ raise exc
212219
213220 # TODO: fix this
214221 dtype = strategy .input_specs [0 ].tensor_meta .dtype
You can’t perform that action at this time.
0 commit comments