diff --git a/diffsynth/models/wan_video_vae.py b/diffsynth/models/wan_video_vae.py index 3c2181abf..3d5db6801 100644 --- a/diffsynth/models/wan_video_vae.py +++ b/diffsynth/models/wan_video_vae.py @@ -469,7 +469,7 @@ def __init__( def forward(self, x, feat_cache=None, feat_idx=[0]): x_copy = x.clone() for module in self.downsamples: - x = module(x, feat_cache, feat_idx) + x, feat_cache, feat_idx = module(x, feat_cache, feat_idx) return x + self.avg_shortcut(x_copy), feat_cache, feat_idx @@ -506,10 +506,10 @@ def __init__( def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): x_main = x.clone() for module in self.upsamples: - x_main = module(x_main, feat_cache, feat_idx) + x_main, feat_cache, feat_idx = module(x_main, feat_cache, feat_idx) if self.avg_shortcut is not None: x_shortcut = self.avg_shortcut(x, first_chunk) - return x_main + x_shortcut + return x_main + x_shortcut, feat_cache, feat_idx else: return x_main, feat_cache, feat_idx