From a7cc837a05b7944bf06c22795d390adf8821b681 Mon Sep 17 00:00:00 2001
From: SexyERIC0723
Date: Fri, 20 Mar 2026 17:32:14 +0000
Subject: [PATCH] fix: replace deprecated torch.cuda.amp with torch.amp

`torch.cuda.amp.autocast` and `torch.cuda.amp.GradScaler` were deprecated
in PyTorch 2.4 and will be removed in a future release. Replace them with
the device-explicit `torch.amp.autocast('cuda', ...)` and
`torch.amp.GradScaler('cuda', ...)` equivalents.

Files changed:
- rfdiffusion/Track_module.py (decorator on Str2Str.forward)
- env/SE3Transformer/se3_transformer/runtime/inference.py
- env/SE3Transformer/se3_transformer/runtime/training.py (2 instances)
---
 env/SE3Transformer/se3_transformer/runtime/inference.py | 2 +-
 env/SE3Transformer/se3_transformer/runtime/training.py  | 4 ++--
 rfdiffusion/Track_module.py                             | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/env/SE3Transformer/se3_transformer/runtime/inference.py b/env/SE3Transformer/se3_transformer/runtime/inference.py
index 21e9125b..4ecc47fb 100644
--- a/env/SE3Transformer/se3_transformer/runtime/inference.py
+++ b/env/SE3Transformer/se3_transformer/runtime/inference.py
@@ -49,7 +49,7 @@ def evaluate(model: nn.Module,
         for callback in callbacks:
             callback.on_batch_start()
 
-        with torch.cuda.amp.autocast(enabled=args.amp):
+        with torch.amp.autocast('cuda', enabled=args.amp):
             pred = model(*input)
 
             for callback in callbacks:
diff --git a/env/SE3Transformer/se3_transformer/runtime/training.py b/env/SE3Transformer/se3_transformer/runtime/training.py
index 53122779..5931be16 100644
--- a/env/SE3Transformer/se3_transformer/runtime/training.py
+++ b/env/SE3Transformer/se3_transformer/runtime/training.py
@@ -90,7 +90,7 @@ def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimi
         for callback in callbacks:
             callback.on_batch_start()
 
-        with torch.cuda.amp.autocast(enabled=args.amp):
+        with torch.amp.autocast('cuda', enabled=args.amp):
             pred = model(*inputs)
             loss = loss_fn(pred, target) / args.accumulate_grad_batches
 
@@ -127,7 +127,7 @@ def train(model: nn.Module,
         model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
 
     model.train()
-    grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
+    grad_scaler = torch.amp.GradScaler('cuda', enabled=args.amp)
     if args.optimizer == 'adam':
         optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
                               weight_decay=args.weight_decay)
diff --git a/rfdiffusion/Track_module.py b/rfdiffusion/Track_module.py
index 27511e5d..727c0a8a 100644
--- a/rfdiffusion/Track_module.py
+++ b/rfdiffusion/Track_module.py
@@ -233,7 +233,7 @@ def reset_parameter(self):
         nn.init.zeros_(self.embed_e1.bias)
         nn.init.zeros_(self.embed_e2.bias)
 
-    @torch.cuda.amp.autocast(enabled=False)
+    @torch.amp.autocast('cuda', enabled=False)
     def forward(self, msa, pair, R_in, T_in, xyz, state, idx, motif_mask,
                 cyclic_reses=None, top_k=64, eps=1e-5):
         B, N, L = msa.shape[:3]
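
Note (not part of the patch): a minimal sketch of the old -> new AMP API in
isolation, assuming a hypothetical model, optimizer, and data; only the
torch.amp calls mirror the three patterns changed above (context manager,
GradScaler, and decorator).

    import torch
    import torch.nn as nn

    model = nn.Linear(16, 4).cuda()                  # hypothetical model
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    loss_fn = nn.MSELoss()

    # Old (deprecated): scaler = torch.cuda.amp.GradScaler(enabled=True)
    scaler = torch.amp.GradScaler('cuda', enabled=True)

    x = torch.randn(8, 16, device='cuda')            # hypothetical batch
    y = torch.randn(8, 4, device='cuda')

    optimizer.zero_grad()
    # Old (deprecated): with torch.cuda.amp.autocast(enabled=True):
    with torch.amp.autocast('cuda', enabled=True):   # forward in mixed precision
        loss = loss_fn(model(x), y)
    scaler.scale(loss).backward()                    # scale loss so fp16 grads don't underflow
    scaler.step(optimizer)                           # unscales grads, then steps
    scaler.update()                                  # adapts the scale factor

    # autocast also works as a decorator; enabled=False forces fp32 inside,
    # the pattern applied to Str2Str.forward in Track_module.py:
    @torch.amp.autocast('cuda', enabled=False)
    def fp32_norm(t: torch.Tensor) -> torch.Tensor:  # hypothetical helper
        return t.float().norm()

The new spellings behave identically to the old ones; they just take the
device as an explicit first argument, so the same code path can later target
other backends. They require a reasonably recent PyTorch 2.x, the same
versions that emit the deprecation warning for the torch.cuda.amp forms.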