Commit b6e070d

bobrenjc93 authored and pobin6 committed
Specialize symfloats when getting fake value involves complex args (pytorch#140832)
Fixed `PYTORCH_TEST_WITH_DYNAMO=1 tlp python test/test_sparse_csr.py TestSparseCSRCPU.test_sampled_addmm_cpu_complex64` when `specialize_float=False`.

Pull Request resolved: pytorch#140832
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#140830
1 parent db9052f commit b6e070d

File tree

1 file changed: +3 −1 lines

torch/_dynamo/utils.py

Lines changed: 3 additions & 1 deletion
@@ -2217,7 +2217,9 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
             # no matter it's lazy module or not, we should copy to fake mode.
             nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode)
 
-        if node.name in ["interpolate", "is_integer", "wrapped_gradient"]:
+        if node.name in ["interpolate", "is_integer", "wrapped_gradient"] or any(
+            isinstance(a, complex) for a in args
+        ):
             # We need to specialize symfloats for now. Eventually we should do a tensorify pass in dynamo.
             args = tuple(
                 float(arg)
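The change above widens the condition under which Dynamo specializes symbolic floats: previously only a fixed allowlist of op names triggered it, and now the presence of any Python `complex` argument does as well. A minimal standalone sketch of that guard logic, using hypothetical helper names (`should_specialize`, `specialize_floats`) that are not part of the actual `torch/_dynamo/utils.py` API:

```python
# Ops that always force symfloat specialization (taken from the diff above).
SPECIALIZED_NAMES = ["interpolate", "is_integer", "wrapped_gradient"]

def should_specialize(node_name: str, args: tuple) -> bool:
    # The guard from the commit: specialize if the op is on the allowlist
    # OR any argument is a Python complex number.
    return node_name in SPECIALIZED_NAMES or any(
        isinstance(a, complex) for a in args
    )

def specialize_floats(args: tuple) -> tuple:
    # Simplified stand-in for the diff's `args = tuple(float(arg) ...)` step:
    # coerce float arguments to concrete floats, leave everything else alone.
    return tuple(float(a) if isinstance(a, float) else a for a in args)

# A complex argument now triggers specialization even for an unlisted op,
# which is what fixes the test_sampled_addmm_cpu_complex64 failure.
print(should_specialize("sampled_addmm", (1.0, 2 + 3j)))  # True
print(should_specialize("sampled_addmm", (1.0, 2.0)))     # False
```

This mirrors the intent stated in the in-line comment: specialization is a stopgap until a tensorify pass lands in Dynamo.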

0 commit comments