Skip to content

Commit 2d462e2

Browse files
authored
Catch missing cudnn error (#873)
1 parent 83f6212 commit 2d462e2

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

helion/_testing.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -458,8 +458,11 @@ def run_example(
458458
atol: Absolute tolerance for correctness check (default: 1e-1)
459459
bwd: Whether to also test backward pass (default: False)
460460
"""
461-
torch.backends.cuda.matmul.fp32_precision = "tf32"
462-
torch.backends.cudnn.conv.fp32_precision = "tf32" # type: ignore[reportAttributeAccessIssue]
461+
try:
462+
torch.backends.cuda.matmul.fp32_precision = "tf32"
463+
torch.backends.cudnn.conv.fp32_precision = "tf32" # type: ignore[reportAttributeAccessIssue]
464+
except AttributeError: # No cudnn available
465+
torch.set_float32_matmul_precision("high") # older deprecated API
463466

464467
# Normalize to dict format
465468
kernels = kernel_fn if isinstance(kernel_fn, dict) else {kernel_name: kernel_fn}

0 commit comments

Comments
 (0)