Skip to content

Commit e5c9164

Browse files
StrongerXifacebook-github-bot
authored andcommitted
Always trace into tensor subclass __torch_function__ (#149792)
Summary: This patch effectively ignores traceable_tensor_subclasses, allowing Dynamo to always try tracing into the `__torch_function__` of tensor subclass. This helps us with 2 things: 1. allowing users to directly benefit from better compilation of tensor subclass, by just upgrading pytorch, without having to change legacy library code (see earlier patches in the stack for examples). 2. potentially exposing more issues in compiling tensor subclass, so we can get signals and improve them. As a consequence, it exposed and fixes 2 subtle bugs: 1. In `build_torch_function_fn`, we could get `torch._C._disabled_torch_function_impl` because we have a `Parameter` subclass without `__torch_function__` override or if we have a tensor subclass with `__torch_dispatch__` override. We graph break on this for now, and plan to add support -- the logic for simulating `torch._C._disabled_torch_function_impl` is already in `SuperVariable`, we just need to reuse it. 2. Sometimes we create `SyntheticLocalSource` and need to remove all the guards installed on it, but we only removed the ones whose source _is_ the created synthetic source `s`, but forgot about chained source like `s.foo`, this showed up as `SYNTHETIC_LOCAL['tmp_0'].__torch_function__.__func__`. X-link: pytorch/pytorch#149792 Approved by: https://github.com/jansel, https://github.com/mlazos ghstack dependencies: #149482, #149483, #149484 Reviewed By: clee2000 Differential Revision: D72351386 fbshipit-source-id: 5018265b1e27035875d031fb635ee0b0deb326c2
1 parent 5186143 commit e5c9164

File tree

1 file changed

+1
-0
lines changed
  • userbenchmark/dynamo/dynamobench/_dynamo

1 file changed

+1
-0
lines changed

userbenchmark/dynamo/dynamobench/_dynamo/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1408,6 +1408,7 @@ def clean_for_json(d: dict[str, Any]) -> dict[str, Any]:
14081408
"reorderable_logging_functions",
14091409
"ignore_logger_methods",
14101410
"traceable_tensor_subclasses",
1411+
"nontraceable_tensor_subclasses",
14111412
"_custom_ops_profile",
14121413
}
14131414

0 commit comments

Comments
 (0)