Skip to content

Commit 8e99af5

Browse files
authored
make is_tf32_env() safer (#6839)
Following #6816 ### Description make `is_tf32_env()` safer. check `cuda` to prevent fallthrough case when `pynvml` is not found ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Qingpeng Li <qingpeng9802@gmail.com>
1 parent ca96867 commit 8e99af5

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def is_tf32_env():
180180
global _tf32_enabled
181181
if _tf32_enabled is None:
182182
_tf32_enabled = False
183-
if detect_default_tf32() or torch.backends.cuda.matmul.allow_tf32:
183+
if torch.cuda.is_available() and (detect_default_tf32() or torch.backends.cuda.matmul.allow_tf32):
184184
try:
185185
# with TF32 enabled, the speed is ~8x faster, but the precision has ~2 digits less in the result
186186
g_gpu = torch.Generator(device="cuda")

0 commit comments

Comments
 (0)