From 6eaef81c65b93427d90f5c38a58990b804df57f5 Mon Sep 17 00:00:00 2001 From: aviveise Date: Wed, 3 Sep 2025 23:05:38 +0300 Subject: [PATCH 01/23] making units not mendatory if models not available --- diffsynth/pipelines/wan_video_new.py | 88 ++++++++++++++++++---------- diffsynth/utils/__init__.py | 14 ++++- 2 files changed, 70 insertions(+), 32 deletions(-) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 660a38e7..89663801 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -48,22 +48,6 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=Non self.in_iteration_models_2 = ("dit2", "motion_controller", "vace") self.unit_runner = PipelineUnitRunner() self.units = [ - WanVideoUnit_ShapeChecker(), - WanVideoUnit_NoiseInitializer(), - WanVideoUnit_PromptEmbedder(), - WanVideoUnit_S2V(), - WanVideoUnit_InputVideoEmbedder(), - WanVideoUnit_ImageEmbedderVAE(), - WanVideoUnit_ImageEmbedderCLIP(), - WanVideoUnit_ImageEmbedderFused(), - WanVideoUnit_FunControl(), - WanVideoUnit_FunReference(), - WanVideoUnit_FunCameraControl(), - WanVideoUnit_SpeedControl(), - WanVideoUnit_VACE(), - WanVideoUnit_UnifiedSequenceParallel(), - WanVideoUnit_TeaCache(), - WanVideoUnit_CfgMerger(), ] self.post_units = [ WanVideoPostUnit_S2V(), @@ -367,9 +351,10 @@ def from_pretrained( pipe.width_division_factor = pipe.vae.upsampling_factor * 2 # Initialize tokenizer - tokenizer_config.download_if_necessary(use_usp=use_usp) - pipe.prompter.fetch_models(pipe.text_encoder) - pipe.prompter.fetch_tokenizer(tokenizer_config.path) + if pipe.text_encoder is not None: + tokenizer_config.download_if_necessary(use_usp=use_usp) + pipe.prompter.fetch_models(pipe.text_encoder) + pipe.prompter.fetch_tokenizer(tokenizer_config.path) if audio_processor_config is not None: audio_processor_config.download_if_necessary(use_usp=use_usp) @@ -377,8 +362,30 @@ def from_pretrained( pipe.audio_processor = Wav2Vec2Processor.from_pretrained(audio_processor_config.path) # Unified Sequence Parallel if use_usp: pipe.enable_usp() + + pipe.initalize_units() + return pipe + def initalize_units(self): + self.units = [ + WanVideoUnit_ShapeChecker(), + WanVideoUnit_NoiseInitializer(), + WanVideoUnit_PromptEmbedder(), + WanVideoUnit_S2V(), + WanVideoUnit_InputVideoEmbedder(), + WanVideoUnit_ImageEmbedderVAE(), + WanVideoUnit_ImageEmbedderCLIP(), + WanVideoUnit_ImageEmbedderFused(), + WanVideoUnit_FunControl(), + WanVideoUnit_FunReference(), + WanVideoUnit_FunCameraControl(), + WanVideoUnit_SpeedControl(), + WanVideoUnit_VACE(), + WanVideoUnit_UnifiedSequenceParallel(), + WanVideoUnit_TeaCache(), + WanVideoUnit_CfgMerger() + ] @torch.no_grad() def __call__( @@ -519,7 +526,8 @@ def __call__( class WanVideoUnit_ShapeChecker(PipelineUnit): def __init__(self): - super().__init__(input_params=("height", "width", "num_frames")) + super().__init__(input_params=("height", "width", "num_frames"), + output_params=("height", "width", "num_frames")) def process(self, pipe: WanVideoPipeline, height, width, num_frames): height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames) @@ -529,7 +537,8 @@ def process(self, pipe: WanVideoPipeline, height, width, num_frames): class WanVideoUnit_NoiseInitializer(PipelineUnit): def __init__(self): - super().__init__(input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image")) + super().__init__(input_params=("height", "width", "num_frames", "seed", 
"rand_device", "vace_reference_image"), + output_params=("noise")) def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image): length = (num_frames - 1) // 4 + 1 @@ -547,6 +556,7 @@ class WanVideoUnit_InputVideoEmbedder(PipelineUnit): def __init__(self): super().__init__( input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"), + output_params=("latents"), onload_model_names=("vae",) ) @@ -574,6 +584,7 @@ def __init__(self): seperate_cfg=True, input_params_posi={"prompt": "prompt", "positive": "positive"}, input_params_nega={"prompt": "negative_prompt", "positive": "positive"}, + output_params=("context"), onload_model_names=("text_encoder",) ) @@ -582,8 +593,6 @@ def process(self, pipe: WanVideoPipeline, prompt, positive) -> dict: prompt_emb = pipe.prompter.encode_prompt(prompt, positive=positive, device=pipe.device) return {"context": prompt_emb} - - class WanVideoUnit_ImageEmbedder(PipelineUnit): """ Deprecated @@ -591,6 +600,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit): def __init__(self): super().__init__( input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + output_params=("clip_context", "y"), onload_model_names=("image_encoder", "vae") ) @@ -629,6 +639,7 @@ class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit): def __init__(self): super().__init__( input_params=("input_image", "end_image", "height", "width"), + output_params=("clip_feature"), onload_model_names=("image_encoder",) ) @@ -651,6 +662,7 @@ class WanVideoUnit_ImageEmbedderVAE(PipelineUnit): def __init__(self): super().__init__( input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + output_params=("y"), onload_model_names=("vae",) ) @@ -688,7 +700,9 @@ class WanVideoUnit_ImageEmbedderFused(PipelineUnit): def __init__(self): super().__init__( input_params=("input_image", "latents", "height", "width", "tiled", "tile_size", "tile_stride"), - onload_model_names=("vae",) + output_params=("latents", "fuse_vae_embedding_in_latents", "first_frame_latents"), + onload_model_names=("vae",), + mendatory=False ) def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride): @@ -706,7 +720,9 @@ class WanVideoUnit_FunControl(PipelineUnit): def __init__(self): super().__init__( input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y", "latents"), - onload_model_names=("vae",) + output_params=("clip_feature", "y"), + onload_model_names=("vae",), + mendatory=False ) def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y, latents): @@ -731,7 +747,9 @@ class WanVideoUnit_FunReference(PipelineUnit): def __init__(self): super().__init__( input_params=("reference_image", "height", "width", "reference_image"), - onload_model_names=("vae",) + output_params=("reference_latents", "clip_feature"), + onload_model_names=("vae",), + mendatory=False ) def process(self, pipe: WanVideoPipeline, reference_image, height, width): @@ -753,7 +771,9 @@ class WanVideoUnit_FunCameraControl(PipelineUnit): def __init__(self): super().__init__( input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image", "tiled", "tile_size", "tile_stride"), - onload_model_names=("vae",) + 
output_params=("control_camera_latents_input", "y"), + onload_model_names=("vae",), + mendatory=False ) def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image, tiled, tile_size, tile_stride): @@ -801,7 +821,8 @@ def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_cont class WanVideoUnit_SpeedControl(PipelineUnit): def __init__(self): - super().__init__(input_params=("motion_bucket_id",)) + super().__init__(input_params=("motion_bucket_id",), + mendatory=False) def process(self, pipe: WanVideoPipeline, motion_bucket_id): if motion_bucket_id is None: @@ -815,7 +836,9 @@ class WanVideoUnit_VACE(PipelineUnit): def __init__(self): super().__init__( input_params=("vace_video", "vace_video_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"), - onload_model_names=("vae",) + output_params=("vace_context", "vace_scale"), + onload_model_names=("vae",), + mendatory=False ) def process( @@ -914,7 +937,9 @@ class WanVideoUnit_S2V(PipelineUnit): def __init__(self): super().__init__( take_over=True, - onload_model_names=("audio_encoder", "vae",) + onload_model_names=("audio_encoder", "vae",), + output_params=("audio_embeds", "s2v_pose_latents", "motion_latents"), + mendatory=False ) def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames, fps=16, audio_embeds=None, return_all=False): @@ -994,7 +1019,8 @@ def pre_calculate_audio_pose(pipe: WanVideoPipeline, input_audio=None, audio_sam class WanVideoPostUnit_S2V(PipelineUnit): def __init__(self): - super().__init__(input_params=("latents", "motion_latents", "drop_motion_frames")) + super().__init__(input_params=("latents", "motion_latents", "drop_motion_frames"), + output_params=("latents")) def process(self, pipe: WanVideoPipeline, latents, motion_latents, drop_motion_frames): if pipe.audio_encoder is None or motion_latents is None or drop_motion_frames: diff --git a/diffsynth/utils/__init__.py b/diffsynth/utils/__init__.py index ec3c7270..9ff2f05c 100644 --- a/diffsynth/utils/__init__.py +++ b/diffsynth/utils/__init__.py @@ -227,14 +227,18 @@ def __init__( input_params: tuple[str] = None, input_params_posi: dict[str, str] = None, input_params_nega: dict[str, str] = None, - onload_model_names: tuple[str] = None + output_params: tuple[str] = None, + onload_model_names: tuple[str] = None, + mendatory: bool = True ): self.seperate_cfg = seperate_cfg self.take_over = take_over self.input_params = input_params self.input_params_posi = input_params_posi self.input_params_nega = input_params_nega + self.output_params = output_params self.onload_model_names = onload_model_names + self.mendatory = mendatory def process(self, pipe: BasePipeline, inputs: dict, positive=True, **kwargs) -> dict: @@ -247,6 +251,14 @@ def __init__(self): pass def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]: + + + # Skip if any of the onload_model_names is not in the pipe + if any(pipe.getattr(model) is None for model in unit.onload_model_names): + if unit.mendatory and any(output_param not in inputs_shared.keys() for output_param in unit.output_params): + raise ValueError(f"The output parameters {unit.output_params} are not in the inputs_shared. 
Please check the pipeline unit {unit}.") + return inputs_shared, inputs_posi, inputs_nega + if unit.take_over: # Let the pipeline unit take over this function. inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega) From 587906351e60d78b4c3a07333784bdf6808e8c9d Mon Sep 17 00:00:00 2001 From: aviveise Date: Thu, 4 Sep 2025 13:08:37 +0300 Subject: [PATCH 02/23] adding encoding only flag --- diffsynth/pipelines/wan_video_new.py | 8 +++++--- diffsynth/utils/__init__.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 89663801..ae02f3aa 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -74,7 +74,7 @@ def training_loss(self, **inputs): loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) loss = loss * self.scheduler.training_weight(timestep) - return loss + return loss, noise_pred def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5): @@ -560,7 +560,7 @@ def __init__(self): onload_model_names=("vae",) ) - def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image): + def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image, pre_encoding=False): if input_video is None: return {"latents": noise} pipe.load_models_to_device(["vae"]) @@ -570,7 +570,9 @@ def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, vace_reference_image = pipe.preprocess_video([vace_reference_image]) vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device) input_latents = torch.concat([vace_reference_latents, input_latents], dim=2) - if pipe.scheduler.training: + if pre_encoding: + return {"input_latents": input_latents} + elif pipe.scheduler.training: return {"latents": noise, "input_latents": input_latents} else: latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) diff --git a/diffsynth/utils/__init__.py b/diffsynth/utils/__init__.py index 9ff2f05c..bc4357f0 100644 --- a/diffsynth/utils/__init__.py +++ b/diffsynth/utils/__init__.py @@ -254,7 +254,7 @@ def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, # Skip if any of the onload_model_names is not in the pipe - if any(pipe.getattr(model) is None for model in unit.onload_model_names): + if any(pipe.getattr(model) is None for model in unit.onload_model_names) or any(input_param not in inputs_shared.keys() for input_param in unit.input_params): if unit.mendatory and any(output_param not in inputs_shared.keys() for output_param in unit.output_params): raise ValueError(f"The output parameters {unit.output_params} are not in the inputs_shared. 
Please check the pipeline unit {unit}.") return inputs_shared, inputs_posi, inputs_nega From bb800bba59128437660560f74a8693264ca448ef Mon Sep 17 00:00:00 2001 From: aviveise Date: Thu, 4 Sep 2025 21:43:04 +0300 Subject: [PATCH 03/23] fixing --- diffsynth/pipelines/wan_video_new.py | 124 +++++++++++++++------------ diffsynth/utils/__init__.py | 11 --- 2 files changed, 67 insertions(+), 68 deletions(-) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index ae02f3aa..5b2bd27a 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -30,7 +30,7 @@ class WanVideoPipeline(BasePipeline): - def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None): + def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None, offline_preprocessing=False): super().__init__( device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1 @@ -47,8 +47,9 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=Non self.in_iteration_models = ("dit", "motion_controller", "vace") self.in_iteration_models_2 = ("dit2", "motion_controller", "vace") self.unit_runner = PipelineUnitRunner() - self.units = [ - ] + + self.initalize_units(offline_preprocessing=offline_preprocessing) + self.post_units = [ WanVideoPostUnit_S2V(), ] @@ -363,30 +364,43 @@ def from_pretrained( # Unified Sequence Parallel if use_usp: pipe.enable_usp() - pipe.initalize_units() - return pipe - def initalize_units(self): - self.units = [ - WanVideoUnit_ShapeChecker(), - WanVideoUnit_NoiseInitializer(), - WanVideoUnit_PromptEmbedder(), - WanVideoUnit_S2V(), - WanVideoUnit_InputVideoEmbedder(), - WanVideoUnit_ImageEmbedderVAE(), - WanVideoUnit_ImageEmbedderCLIP(), - WanVideoUnit_ImageEmbedderFused(), - WanVideoUnit_FunControl(), - WanVideoUnit_FunReference(), - WanVideoUnit_FunCameraControl(), - WanVideoUnit_SpeedControl(), - WanVideoUnit_VACE(), - WanVideoUnit_UnifiedSequenceParallel(), - WanVideoUnit_TeaCache(), - WanVideoUnit_CfgMerger() + def initalize_units(self, offline_preprocessing=False): + + if not offline_preprocessing: + self.units = [ + WanVideoUnit_ShapeChecker(), + WanVideoUnit_NoiseInitializer(), + WanVideoUnit_PromptEmbedder(), + WanVideoUnit_S2V(), + WanVideoUnit_InputVideoEmbedder(), + WanVideoUnit_ImageEmbedderVAE(), + WanVideoUnit_ImageEmbedderCLIP(), + WanVideoUnit_ImageEmbedderFused(), + WanVideoUnit_FunControl(), + WanVideoUnit_FunReference(), + WanVideoUnit_FunCameraControl(), + WanVideoUnit_SpeedControl(), + WanVideoUnit_VACE(), + WanVideoUnit_UnifiedSequenceParallel(), + WanVideoUnit_TeaCache(), + WanVideoUnit_CfgMerger() + ] + else: + self.units = [ + WanVideoUnit_ShapeChecker(), + WanVideoUnit_NoiseInitializer(), + WanVideoUnit_ImageEmbedderFusingOnly(), + WanVideoUnit_FunControl(), + WanVideoUnit_FunReference(), + WanVideoUnit_FunCameraControl(), + WanVideoUnit_SpeedControl(), + WanVideoUnit_VACE(), + WanVideoUnit_UnifiedSequenceParallel(), + WanVideoUnit_TeaCache(), + WanVideoUnit_CfgMerger() ] - @torch.no_grad() def __call__( self, @@ -526,8 +540,7 @@ def __call__( class WanVideoUnit_ShapeChecker(PipelineUnit): def __init__(self): - super().__init__(input_params=("height", "width", "num_frames"), - output_params=("height", "width", "num_frames")) + super().__init__(input_params=("height", "width", "num_frames")) def process(self, pipe: WanVideoPipeline, height, width, num_frames): height, width, 
num_frames = pipe.check_resize_height_width(height, width, num_frames) @@ -537,8 +550,7 @@ def process(self, pipe: WanVideoPipeline, height, width, num_frames): class WanVideoUnit_NoiseInitializer(PipelineUnit): def __init__(self): - super().__init__(input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image"), - output_params=("noise")) + super().__init__(input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image")) def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image): length = (num_frames - 1) // 4 + 1 @@ -556,11 +568,10 @@ class WanVideoUnit_InputVideoEmbedder(PipelineUnit): def __init__(self): super().__init__( input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"), - output_params=("latents"), onload_model_names=("vae",) ) - def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image, pre_encoding=False): + def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image): if input_video is None: return {"latents": noise} pipe.load_models_to_device(["vae"]) @@ -570,8 +581,6 @@ def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, vace_reference_image = pipe.preprocess_video([vace_reference_image]) vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device) input_latents = torch.concat([vace_reference_latents, input_latents], dim=2) - if pre_encoding: - return {"input_latents": input_latents} elif pipe.scheduler.training: return {"latents": noise, "input_latents": input_latents} else: @@ -586,7 +595,6 @@ def __init__(self): seperate_cfg=True, input_params_posi={"prompt": "prompt", "positive": "positive"}, input_params_nega={"prompt": "negative_prompt", "positive": "positive"}, - output_params=("context"), onload_model_names=("text_encoder",) ) @@ -602,7 +610,6 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit): def __init__(self): super().__init__( input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), - output_params=("clip_context", "y"), onload_model_names=("image_encoder", "vae") ) @@ -641,7 +648,6 @@ class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit): def __init__(self): super().__init__( input_params=("input_image", "end_image", "height", "width"), - output_params=("clip_feature"), onload_model_names=("image_encoder",) ) @@ -664,7 +670,6 @@ class WanVideoUnit_ImageEmbedderVAE(PipelineUnit): def __init__(self): super().__init__( input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), - output_params=("y"), onload_model_names=("vae",) ) @@ -702,29 +707,42 @@ class WanVideoUnit_ImageEmbedderFused(PipelineUnit): def __init__(self): super().__init__( input_params=("input_image", "latents", "height", "width", "tiled", "tile_size", "tile_stride"), - output_params=("latents", "fuse_vae_embedding_in_latents", "first_frame_latents"), - onload_model_names=("vae",), - mendatory=False + onload_model_names=("vae",) ) - def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride): + def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride, pre_encoding=False): if input_image is None or not 
pipe.dit.fuse_vae_embedding_in_latents: return {} pipe.load_models_to_device(self.onload_model_names) image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1) z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) - latents[:, :, 0: 1] = z - return {"latents": latents, "fuse_vae_embedding_in_latents": True, "first_frame_latents": z} + if pre_encoding: + return {"first_frame_latents": z} + else: + latents[:, :, 0: 1] = z + return {"latents": latents, "fuse_vae_embedding_in_latents": True, "first_frame_latents": z} + +class WanVideoUnit_ImageEmbedderFusingOnly(PipelineUnit): + """ + Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B. + """ + def __init__(self): + super().__init__( + input_params=("latents","first_frame_latents") + ) + def process(self, pipe: WanVideoPipeline, latents,first_frame_latents): + if first_frame_latents is None or not pipe.dit.fuse_vae_embedding_in_latents: + return {} + latents[:, :, 0: 1] = first_frame_latents + return {"latents": latents, "fuse_vae_embedding_in_latents": True} class WanVideoUnit_FunControl(PipelineUnit): def __init__(self): super().__init__( input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y", "latents"), - output_params=("clip_feature", "y"), - onload_model_names=("vae",), - mendatory=False + onload_model_names=("vae",) ) def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y, latents): @@ -749,9 +767,7 @@ class WanVideoUnit_FunReference(PipelineUnit): def __init__(self): super().__init__( input_params=("reference_image", "height", "width", "reference_image"), - output_params=("reference_latents", "clip_feature"), - onload_model_names=("vae",), - mendatory=False + onload_model_names=("vae",) ) def process(self, pipe: WanVideoPipeline, reference_image, height, width): @@ -773,9 +789,7 @@ class WanVideoUnit_FunCameraControl(PipelineUnit): def __init__(self): super().__init__( input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image", "tiled", "tile_size", "tile_stride"), - output_params=("control_camera_latents_input", "y"), - onload_model_names=("vae",), - mendatory=False + onload_model_names=("vae",) ) def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image, tiled, tile_size, tile_stride): @@ -838,9 +852,7 @@ class WanVideoUnit_VACE(PipelineUnit): def __init__(self): super().__init__( input_params=("vace_video", "vace_video_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"), - output_params=("vace_context", "vace_scale"), - onload_model_names=("vae",), - mendatory=False + onload_model_names=("vae",) ) def process( @@ -941,7 +953,6 @@ def __init__(self): take_over=True, onload_model_names=("audio_encoder", "vae",), output_params=("audio_embeds", "s2v_pose_latents", "motion_latents"), - mendatory=False ) def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames, fps=16, audio_embeds=None, return_all=False): @@ -1021,8 +1032,7 @@ def pre_calculate_audio_pose(pipe: WanVideoPipeline, input_audio=None, audio_sam class WanVideoPostUnit_S2V(PipelineUnit): def __init__(self): - 
super().__init__(input_params=("latents", "motion_latents", "drop_motion_frames"), - output_params=("latents")) + super().__init__(input_params=("latents", "motion_latents", "drop_motion_frames")) def process(self, pipe: WanVideoPipeline, latents, motion_latents, drop_motion_frames): if pipe.audio_encoder is None or motion_latents is None or drop_motion_frames: diff --git a/diffsynth/utils/__init__.py b/diffsynth/utils/__init__.py index bc4357f0..94603bc6 100644 --- a/diffsynth/utils/__init__.py +++ b/diffsynth/utils/__init__.py @@ -227,18 +227,14 @@ def __init__( input_params: tuple[str] = None, input_params_posi: dict[str, str] = None, input_params_nega: dict[str, str] = None, - output_params: tuple[str] = None, onload_model_names: tuple[str] = None, - mendatory: bool = True ): self.seperate_cfg = seperate_cfg self.take_over = take_over self.input_params = input_params self.input_params_posi = input_params_posi self.input_params_nega = input_params_nega - self.output_params = output_params self.onload_model_names = onload_model_names - self.mendatory = mendatory def process(self, pipe: BasePipeline, inputs: dict, positive=True, **kwargs) -> dict: @@ -251,13 +247,6 @@ def __init__(self): pass def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]: - - - # Skip if any of the onload_model_names is not in the pipe - if any(pipe.getattr(model) is None for model in unit.onload_model_names) or any(input_param not in inputs_shared.keys() for input_param in unit.input_params): - if unit.mendatory and any(output_param not in inputs_shared.keys() for output_param in unit.output_params): - raise ValueError(f"The output parameters {unit.output_params} are not in the inputs_shared. Please check the pipeline unit {unit}.") - return inputs_shared, inputs_posi, inputs_nega if unit.take_over: # Let the pipeline unit take over this function. 
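Taken together, patches 01-03 move the pipeline toward offline preprocessing: instead of marking individual units as optional at run time, the pipeline installs a reduced unit list when offline_preprocessing=True and adds lightweight units that splice precomputed tensors into the latents. The standalone sketch below illustrates that pattern; it mirrors the logic of WanVideoUnit_ImageEmbedderFusingOnly introduced above, but the simplified class name and the tensor shapes are made up for illustration and are not the repository's actual API.

import torch

class FusingOnlyUnit:
    # Sketch of the fusing-only idea: reuse a precomputed first-frame latent
    # instead of loading the VAE and encoding the input image at run time.
    input_params = ("latents", "first_frame_latents")

    def process(self, latents, first_frame_latents):
        if first_frame_latents is None:
            return {}  # nothing was precomputed; leave the latents untouched
        latents[:, :, 0:1] = first_frame_latents  # overwrite the first temporal slice
        return {"latents": latents, "fuse_vae_embedding_in_latents": True}

# Illustrative shapes only: (batch, channels, frames, height/8, width/8),
# with frames = (num_frames - 1) // 4 + 1 as in WanVideoUnit_NoiseInitializer.
latents = torch.randn(1, 16, 21, 60, 104)
first_frame_latents = torch.randn(1, 16, 1, 60, 104)
out = FusingOnlyUnit().process(latents, first_frame_latents)
assert out["fuse_vae_embedding_in_latents"] and out["latents"].shape == latents.shape

The same division of labor reappears later in the series (for example the pass-through unit for precomputed input_latents added in patch 07), so the unit list chosen in initalize_units effectively acts as the switch between online encoding and offline consumption of cached latents.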
From 9af74729e93d858322ffe5120a7691ce70ac34e9 Mon Sep 17 00:00:00 2001 From: aviveise Date: Thu, 4 Sep 2025 21:53:55 +0300 Subject: [PATCH 04/23] fix in method --- diffsynth/pipelines/wan_video_new.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 5b2bd27a..85cfed02 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -951,8 +951,7 @@ class WanVideoUnit_S2V(PipelineUnit): def __init__(self): super().__init__( take_over=True, - onload_model_names=("audio_encoder", "vae",), - output_params=("audio_embeds", "s2v_pose_latents", "motion_latents"), + onload_model_names=("audio_encoder", "vae",) ) def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames, fps=16, audio_embeds=None, return_all=False): From b6ecbd9b54d05f493435a2164fc6e69ba0800660 Mon Sep 17 00:00:00 2001 From: aviveise Date: Thu, 4 Sep 2025 21:56:12 +0300 Subject: [PATCH 05/23] fix in method --- diffsynth/pipelines/wan_video_new.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 85cfed02..f46a6560 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -837,8 +837,7 @@ def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_cont class WanVideoUnit_SpeedControl(PipelineUnit): def __init__(self): - super().__init__(input_params=("motion_bucket_id",), - mendatory=False) + super().__init__(input_params=("motion_bucket_id",)) def process(self, pipe: WanVideoPipeline, motion_bucket_id): if motion_bucket_id is None: From c4ce0482c2293ecae5c4a55fc61dbc7f3995d1a8 Mon Sep 17 00:00:00 2001 From: aviveise Date: Sun, 7 Sep 2025 13:53:21 +0300 Subject: [PATCH 06/23] adding offline preprocessing to from_pretrained method --- diffsynth/pipelines/wan_video_new.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index f46a6560..7ff6c80e 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -304,6 +304,7 @@ def from_pretrained( audio_processor_config: ModelConfig = None, redirect_common_files: bool = True, use_usp=False, + offline_preprocessing=False ): # Redirect model path if redirect_common_files: @@ -320,7 +321,7 @@ def from_pretrained( model_config.model_id = redirect_dict[model_config.origin_file_pattern] # Initialize pipeline - pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) + pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype, offline_preprocessing=offline_preprocessing) if use_usp: pipe.initialize_usp() # Download and load models From df68cbbe3f2ad40925daf331fa724e669a4b34e3 Mon Sep 17 00:00:00 2001 From: aviveise Date: Sun, 7 Sep 2025 21:14:53 +0300 Subject: [PATCH 07/23] adding passive input video unit --- diffsynth/pipelines/wan_video_new.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 7ff6c80e..ca56366f 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -392,6 +392,7 @@ def initalize_units(self, offline_preprocessing=False): self.units = [ WanVideoUnit_ShapeChecker(), WanVideoUnit_NoiseInitializer(), + WanVideoUnit_InputVideoEmbedderPassThrough(), WanVideoUnit_ImageEmbedderFusingOnly(), 
WanVideoUnit_FunControl(), WanVideoUnit_FunReference(), @@ -588,6 +589,18 @@ def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) return {"latents": latents} +class WanVideoUnit_InputVideoEmbedderPassThrough(PipelineUnit): + def __init__(self): + super().__init__(input_params=("input_latents", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image")) + + def process(self, pipe: WanVideoPipeline, input_latents, noise, tiled, tile_size, tile_stride, vace_reference_image): + if input_latents is None: + return {"latents": noise} + elif pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents} class WanVideoUnit_PromptEmbedder(PipelineUnit): From 9187e9d6441f30ad79c48125cbcfaa4dd667153d Mon Sep 17 00:00:00 2001 From: aviveise Date: Tue, 9 Sep 2025 13:44:52 +0300 Subject: [PATCH 08/23] adding tea cache to wan 2.2 with wan 2.1 coefficients --- diffsynth/pipelines/wan_video_new.py | 54 +++++++++++-------- .../model_inference/Wan2.2-S2V-14B.py | 2 + 2 files changed, 35 insertions(+), 21 deletions(-) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index ca56366f..f2d7e9f6 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -1360,6 +1360,7 @@ def model_fn_wans2v( use_gradient_checkpointing_offload=False, use_gradient_checkpointing=False, use_unified_sequence_parallel=False, + tea_cache=None, ): if use_unified_sequence_parallel: import torch.distributed as dist @@ -1398,6 +1399,11 @@ def model_fn_wans2v( t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)).unsqueeze(2).transpose(0, 2) + if tea_cache is not None: + tea_cache_update = tea_cache.check(dit, x, t_mod) + else: + tea_cache_update = False + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: world_size, sp_rank = get_sequence_parallel_world_size(), get_sequence_parallel_rank() assert x.shape[1] % world_size == 0, f"the dimension after chunk must be divisible by world size, but got {x.shape[1]} and {get_sequence_parallel_world_size()}" @@ -1406,14 +1412,28 @@ def model_fn_wans2v( seq_len_x_list = [min(max(0, seq_len_x - seg_idxs[i]), x.shape[1]) for i in range(len(seg_idxs)-1)] seq_len_x = seq_len_x_list[sp_rank] - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - return custom_forward + if tea_cache_update: + x = tea_cache.update(x) + else: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward - for block_id, block in enumerate(dit.blocks): - if use_gradient_checkpointing_offload: - with torch.autograd.graph.save_on_cpu(): + for block_id, block in enumerate(dit.blocks): + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, seq_len_x, pre_compute_freqs[0], + use_reentrant=False, + ) + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), + x, + use_reentrant=False, + ) + elif use_gradient_checkpointing: x = 
torch.utils.checkpoint.checkpoint( create_custom_forward(block), x, context, t_mod, seq_len_x, pre_compute_freqs[0], @@ -1424,20 +1444,12 @@ def custom_forward(*inputs): x, use_reentrant=False, ) - elif use_gradient_checkpointing: - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - x, context, t_mod, seq_len_x, pre_compute_freqs[0], - use_reentrant=False, - ) - x = torch.utils.checkpoint.checkpoint( - create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), - x, - use_reentrant=False, - ) - else: - x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0]) - x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel) + else: + x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0]) + x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel) + + if tea_cache is not None: + tea_cache.store(x) if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: x = get_sp_group().all_gather(x, dim=1) diff --git a/examples/wanvideo/model_inference/Wan2.2-S2V-14B.py b/examples/wanvideo/model_inference/Wan2.2-S2V-14B.py index bb93871a..426276f8 100644 --- a/examples/wanvideo/model_inference/Wan2.2-S2V-14B.py +++ b/examples/wanvideo/model_inference/Wan2.2-S2V-14B.py @@ -45,6 +45,8 @@ audio_sample_rate=sample_rate, input_audio=input_audio, num_inference_steps=40, + tea_cache_l1_thresh=0.05, + tea_cache_model_id="Wan2.1-I2V-14B-480P", ) save_video_with_audio(video[1:], "video_with_audio.mp4", audio_path, fps=16, quality=5) From 64b3508eb572d66241ac1cf46787e34253337ec7 Mon Sep 17 00:00:00 2001 From: aviveise Date: Tue, 9 Sep 2025 13:52:04 +0300 Subject: [PATCH 09/23] fix --- diffsynth/pipelines/wan_video_new.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index f2d7e9f6..35c155d7 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -1226,6 +1226,7 @@ def model_fn_wan_video( use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, use_gradient_checkpointing=use_gradient_checkpointing, use_unified_sequence_parallel=use_unified_sequence_parallel, + tea_cache=tea_cache, ) if use_unified_sequence_parallel: @@ -1447,7 +1448,7 @@ def custom_forward(*inputs): else: x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0]) x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel) - + if tea_cache is not None: tea_cache.store(x) From 440ce5ba01869fc0de80dcd010dc65bcee0c9a5b Mon Sep 17 00:00:00 2001 From: aviveise Date: Thu, 11 Sep 2025 11:17:35 +0300 Subject: [PATCH 10/23] renaming dit inputs --- diffsynth/models/wan_video_dit.py | 52 ++++++++++---------- diffsynth/models/wan_video_dit_s2v.py | 70 +++++++++++++-------------- 2 files changed, 61 insertions(+), 61 deletions(-) diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 1a54728f..1aea7ef8 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -211,7 +211,7 @@ def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) self.gate = GateModule() - def forward(self, x, context, t_mod, freqs): + def forward(self, 
hidden_states, encoder_hidden_states, t_mod, freqs): has_seq = len(t_mod.shape) == 4 chunk_dim = 2 if has_seq else 1 # msa: multi-head self-attention mlp: multi-layer perceptron @@ -222,12 +222,12 @@ def forward(self, x, context, t_mod, freqs): shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2), shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2), ) - input_x = modulate(self.norm1(x), shift_msa, scale_msa) - x = self.gate(x, gate_msa, self.self_attn(input_x, freqs)) - x = x + self.cross_attn(self.norm3(x), context) - input_x = modulate(self.norm2(x), shift_mlp, scale_mlp) - x = self.gate(x, gate_mlp, self.ffn(input_x)) - return x + input_x = modulate(self.norm1(hidden_states), shift_msa, scale_msa) + hidden_states = self.gate(hidden_states, gate_msa, self.self_attn(input_x, freqs)) + hidden_states = hidden_states + self.cross_attn(self.norm3(hidden_states), encoder_hidden_states) + input_x = modulate(self.norm2(hidden_states), shift_mlp, scale_mlp) + hidden_states = self.gate(hidden_states, gate_mlp, self.ffn(input_x)) + return hidden_states class MLP(torch.nn.Module): @@ -244,10 +244,10 @@ def __init__(self, in_dim, out_dim, has_pos_emb=False): if has_pos_emb: self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280))) - def forward(self, x): + def forward(self, hidden_states): if self.has_pos_emb: - x = x + self.emb_pos.to(dtype=x.dtype, device=x.device) - return self.proj(x) + hidden_states = hidden_states + self.emb_pos.to(dtype=hidden_states.dtype, device=hidden_states.device) + return self.proj(hidden_states) class Head(nn.Module): @@ -259,14 +259,14 @@ def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps self.head = nn.Linear(dim, out_dim * math.prod(patch_size)) self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) - def forward(self, x, t_mod): + def forward(self, hidden_states, t_mod): if len(t_mod.shape) == 3: shift, scale = (self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(2)).chunk(2, dim=2) - x = (self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2))) + hidden_states = (self.head(self.norm(hidden_states) * (1 + scale.squeeze(2)) + shift.squeeze(2))) else: shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1) - x = (self.head(self.norm(x) * (1 + scale) + shift)) - return x + hidden_states = (self.head(self.norm(hidden_states) * (1 + scale) + shift)) + return hidden_states class WanModel(torch.nn.Module): @@ -354,9 +354,9 @@ def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor): ) def forward(self, - x: torch.Tensor, + hidden_states: torch.Tensor, timestep: torch.Tensor, - context: torch.Tensor, + encoder_hidden_states: torch.Tensor, clip_feature: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, use_gradient_checkpointing: bool = False, @@ -366,20 +366,20 @@ def forward(self, t = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, timestep)) t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) - context = self.text_embedding(context) + context = self.text_embedding(encoder_hidden_states) if self.has_image_input: - x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) # (b, c_x + c_y, f, h, w) clip_embdding = self.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) - x, (f, h, w) = self.patchify(x) + hidden_states, (f, h, w) = self.patchify(hidden_states) freqs = torch.cat([ 
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) - ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + ], dim=-1).reshape(f * h * w, 1, -1).to(hidden_states.device) def create_custom_forward(module): def custom_forward(*inputs): @@ -390,23 +390,23 @@ def custom_forward(*inputs): if self.training and use_gradient_checkpointing: if use_gradient_checkpointing_offload: with torch.autograd.graph.save_on_cpu(): - x = torch.utils.checkpoint.checkpoint( + hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), x, context, t_mod, freqs, use_reentrant=False, ) else: - x = torch.utils.checkpoint.checkpoint( + hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), x, context, t_mod, freqs, use_reentrant=False, ) else: - x = block(x, context, t_mod, freqs) + hidden_states = block(hidden_states, context, t_mod, freqs) - x = self.head(x, t) - x = self.unpatchify(x, (f, h, w)) - return x + hidden_states = self.head(hidden_states, t) + hidden_states = self.unpatchify(hidden_states, (f, h, w)) + return hidden_states @staticmethod def state_dict_converter(): diff --git a/diffsynth/models/wan_video_dit_s2v.py b/diffsynth/models/wan_video_dit_s2v.py index 70881e6d..556276bd 100644 --- a/diffsynth/models/wan_video_dit_s2v.py +++ b/diffsynth/models/wan_video_dit_s2v.py @@ -341,7 +341,7 @@ def forward(self, features): class WanS2VDiTBlock(DiTBlock): - def forward(self, x, context, t_mod, seq_len_x, freqs): + def forward(self, hidden_states, encoder_hidden_states, t_mod, seq_len_x, freqs): t_mod = (self.modulation.unsqueeze(2).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1) # t_mod[:, :, 0] for x, t_mod[:, :, 1] for other like ref, motion, etc. 
t_mod = [ @@ -349,12 +349,12 @@ def forward(self, x, context, t_mod, seq_len_x, freqs): for element in t_mod ] shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = t_mod - input_x = modulate(self.norm1(x), shift_msa, scale_msa) - x = self.gate(x, gate_msa, self.self_attn(input_x, freqs)) - x = x + self.cross_attn(self.norm3(x), context) - input_x = modulate(self.norm2(x), shift_mlp, scale_mlp) - x = self.gate(x, gate_mlp, self.ffn(input_x)) - return x + input_x = modulate(self.norm1(hidden_states), shift_msa, scale_msa) + hidden_states = self.gate(hidden_states, gate_msa, self.self_attn(input_x, freqs)) + hidden_states = hidden_states + self.cross_attn(self.norm3(hidden_states), encoder_hidden_states) + input_x = modulate(self.norm2(hidden_states), shift_mlp, scale_mlp) + hidden_states = self.gate(hidden_states, gate_mlp, self.ffn(input_x)) + return hidden_states class WanS2VModel(torch.nn.Module): @@ -505,7 +505,7 @@ def forward( self, latents, timestep, - context, + encoder_hidden_states, audio_input, motion_latents, pose_cond, @@ -513,33 +513,33 @@ def forward( use_gradient_checkpointing=False ): origin_ref_latents = latents[:, :, 0:1] - x = latents[:, :, 1:] + hidden_states = latents[:, :, 1:] # context embedding - context = self.text_embedding(context) + encoder_hidden_states = self.text_embedding(encoder_hidden_states) # audio encode audio_emb_global, merged_audio_emb = self.cal_audio_emb(audio_input) # x and pose_cond - pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond - x, (f, h, w) = self.patchify(self.patch_embedding(x) + self.cond_encoder(pose_cond)) # torch.Size([1, 29120, 5120]) - seq_len_x = x.shape[1] + pose_cond = torch.zeros_like(hidden_states) if pose_cond is None else pose_cond + hidden_states, (f, h, w) = self.patchify(self.patch_embedding(hidden_states) + self.cond_encoder(pose_cond)) # torch.Size([1, 29120, 5120]) + seq_len_x = hidden_states.shape[1] # reference image ref_latents, (rf, rh, rw) = self.patchify(self.patch_embedding(origin_ref_latents)) # torch.Size([1, 1456, 5120]) grid_sizes = self.get_grid_sizes((f, h, w), (rf, rh, rw)) - x = torch.cat([x, ref_latents], dim=1) + hidden_states = torch.cat([hidden_states, ref_latents], dim=1) # mask - mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device) + mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(hidden_states.device) # freqs pre_compute_freqs = rope_precompute( - x.detach().view(1, x.size(1), self.num_heads, self.dim // self.num_heads), grid_sizes, self.freqs, start=None + hidden_states.detach().view(1, hidden_states.size(1), self.num_heads, self.dim // self.num_heads), grid_sizes, self.freqs, start=None ) # motion - x, pre_compute_freqs, mask = self.inject_motion(x, pre_compute_freqs, mask, motion_latents, add_last_motion=2) + hidden_states, pre_compute_freqs, mask = self.inject_motion(hidden_states, pre_compute_freqs, mask, motion_latents, add_last_motion=2) - x = x + self.trainable_cond_mask(mask).to(x.dtype) + hidden_states = hidden_states + self.trainable_cond_mask(mask).to(hidden_states.dtype) # t_mod timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)]) @@ -554,45 +554,45 @@ def custom_forward(*inputs): for block_id, block in enumerate(self.blocks): if use_gradient_checkpointing_offload: with torch.autograd.graph.save_on_cpu(): - x = torch.utils.checkpoint.checkpoint( + hidden_states = 
torch.utils.checkpoint.checkpoint( create_custom_forward(block), - x, - context, + hidden_states, + encoder_hidden_states, t_mod, seq_len_x, pre_compute_freqs[0], use_reentrant=False, ) - x = torch.utils.checkpoint.checkpoint( + hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), - x, + hidden_states, use_reentrant=False, ) elif use_gradient_checkpointing: - x = torch.utils.checkpoint.checkpoint( + hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), - x, - context, + hidden_states, + encoder_hidden_states, t_mod, seq_len_x, pre_compute_freqs[0], use_reentrant=False, ) - x = torch.utils.checkpoint.checkpoint( + hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), - x, + hidden_states, use_reentrant=False, ) else: - x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0]) - x = self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x) + hidden_states = block(hidden_states, encoder_hidden_states, t_mod, seq_len_x, pre_compute_freqs[0]) + hidden_states = self.after_transformer_block(block_id, hidden_states, audio_emb_global, merged_audio_emb, seq_len_x) - x = x[:, :seq_len_x] - x = self.head(x, t[:-1]) - x = self.unpatchify(x, (f, h, w)) + hidden_states = hidden_states[:, :seq_len_x] + hidden_states = self.head(hidden_states, t[:-1]) + hidden_states = self.unpatchify(hidden_states, (f, h, w)) # make compatible with wan video - x = torch.cat([origin_ref_latents, x], dim=2) - return x + hidden_states = torch.cat([origin_ref_latents, hidden_states], dim=2) + return hidden_states @staticmethod def state_dict_converter(): From 98ffebfa336b338987e0b0227804158ec0299159 Mon Sep 17 00:00:00 2001 From: aviveise Date: Thu, 11 Sep 2025 11:22:19 +0300 Subject: [PATCH 11/23] renaming dit inputs --- diffsynth/models/wan_video_dit_s2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffsynth/models/wan_video_dit_s2v.py b/diffsynth/models/wan_video_dit_s2v.py index 556276bd..ff75ec32 100644 --- a/diffsynth/models/wan_video_dit_s2v.py +++ b/diffsynth/models/wan_video_dit_s2v.py @@ -345,7 +345,7 @@ def forward(self, hidden_states, encoder_hidden_states, t_mod, seq_len_x, freqs) t_mod = (self.modulation.unsqueeze(2).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1) # t_mod[:, :, 0] for x, t_mod[:, :, 1] for other like ref, motion, etc. 
t_mod = [ - torch.cat([element[:, :, 0].expand(1, seq_len_x, x.shape[-1]), element[:, :, 1].expand(1, x.shape[1] - seq_len_x, x.shape[-1])], dim=1) + torch.cat([element[:, :, 0].expand(1, seq_len_x, hidden_states.shape[-1]), element[:, :, 1].expand(1, hidden_states.shape[1] - seq_len_x, hidden_states.shape[-1])], dim=1) for element in t_mod ] shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = t_mod From cfcf4dacfe44ba38b7c23db22b41c72dd53badb9 Mon Sep 17 00:00:00 2001 From: aviveise Date: Thu, 11 Sep 2025 12:09:08 +0300 Subject: [PATCH 12/23] using dit forward --- diffsynth/models/wan_video_dit_s2v.py | 85 ++++++++++++++++++--------- diffsynth/pipelines/wan_video_new.py | 14 ++--- 2 files changed, 65 insertions(+), 34 deletions(-) diff --git a/diffsynth/models/wan_video_dit_s2v.py b/diffsynth/models/wan_video_dit_s2v.py index ff75ec32..ac412429 100644 --- a/diffsynth/models/wan_video_dit_s2v.py +++ b/diffsynth/models/wan_video_dit_s2v.py @@ -510,8 +510,17 @@ def forward( motion_latents, pose_cond, use_gradient_checkpointing_offload=False, - use_gradient_checkpointing=False + use_gradient_checkpointing=False, + use_unified_sequence_parallel=False, + tea_cache=None, ): + if use_unified_sequence_parallel: + import torch.distributed as dist + from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) + + origin_ref_latents = latents[:, :, 0:1] hidden_states = latents[:, :, 1:] @@ -546,14 +555,45 @@ def forward( t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep)) t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2) - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - return custom_forward - - for block_id, block in enumerate(self.blocks): - if use_gradient_checkpointing_offload: - with torch.autograd.graph.save_on_cpu(): + if tea_cache is not None: + tea_cache_update = tea_cache.check(self, hidden_states, t_mod) + else: + tea_cache_update = False + + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + world_size, sp_rank = get_sequence_parallel_world_size(), get_sequence_parallel_rank() + assert hidden_states.shape[1] % world_size == 0, f"the dimension after chunk must be divisible by world size, but got {hidden_states.shape[1]} and {get_sequence_parallel_world_size()}" + hidden_states = torch.chunk(hidden_states, world_size, dim=1)[sp_rank] + seg_idxs = [0] + list(torch.cumsum(torch.tensor([hidden_states.shape[1]] * world_size), dim=0).cpu().numpy()) + seq_len_x_list = [min(max(0, seq_len_x - seg_idxs[i]), hidden_states.shape[1]) for i in range(len(seg_idxs)-1)] + seq_len_x = seq_len_x_list[sp_rank] + + if tea_cache_update: + hidden_states = tea_cache.update(hidden_states) + else: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + for block_id, block in enumerate(self.blocks): + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + t_mod, + seq_len_x, + pre_compute_freqs[0], + use_reentrant=False, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), + hidden_states, + use_reentrant=False, + ) + elif use_gradient_checkpointing: 
hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, @@ -568,24 +608,15 @@ def custom_forward(*inputs): hidden_states, use_reentrant=False, ) - elif use_gradient_checkpointing: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - t_mod, - seq_len_x, - pre_compute_freqs[0], - use_reentrant=False, - ) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), - hidden_states, - use_reentrant=False, - ) - else: - hidden_states = block(hidden_states, encoder_hidden_states, t_mod, seq_len_x, pre_compute_freqs[0]) - hidden_states = self.after_transformer_block(block_id, hidden_states, audio_emb_global, merged_audio_emb, seq_len_x) + else: + hidden_states = block(hidden_states, encoder_hidden_states, t_mod, seq_len_x, pre_compute_freqs[0]) + hidden_states = self.after_transformer_block(block_id, hidden_states, audio_emb_global, merged_audio_emb, seq_len_x) + + if tea_cache is not None: + tea_cache.store(hidden_states) + + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + hidden_states = get_sp_group().all_gather(hidden_states, dim=1) hidden_states = hidden_states[:, :seq_len_x] hidden_states = self.head(hidden_states, t[:-1]) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 35c155d7..0473b358 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -1061,7 +1061,7 @@ def __init__(self, num_inference_steps, rel_l1_thresh, model_id): self.previous_modulated_input = None self.rel_l1_thresh = rel_l1_thresh self.previous_residual = None - self.previous_hidden_states = None + self.previous_hidden_states = None self.coefficients_dict = { "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02], @@ -1212,17 +1212,17 @@ def model_fn_wan_video( tensor_names=["latents", "y"], batch_size=2 if cfg_merge else 1 ) + # wan2.2 s2v if audio_embeds is not None: - return model_fn_wans2v( - dit=dit, + return dit( latents=latents, timestep=timestep, - context=context, - audio_embeds=audio_embeds, + encoder_hidden_states=context, + audio_input=audio_embeds, motion_latents=motion_latents, - s2v_pose_latents=s2v_pose_latents, - drop_motion_frames=drop_motion_frames, + pose_cond=s2v_pose_latents, + motion_latents=drop_motion_frames, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, use_gradient_checkpointing=use_gradient_checkpointing, use_unified_sequence_parallel=use_unified_sequence_parallel, From 38d594ab3302c87877d581a3684c1c6436bdfa3f Mon Sep 17 00:00:00 2001 From: aviveise Date: Thu, 11 Sep 2025 12:18:31 +0300 Subject: [PATCH 13/23] using dit forward --- diffsynth/pipelines/wan_video_new.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 0473b358..4d381c16 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -1212,17 +1212,16 @@ def model_fn_wan_video( tensor_names=["latents", "y"], batch_size=2 if cfg_merge else 1 ) - + # wan2.2 s2v if audio_embeds is not None: return dit( latents=latents, timestep=timestep, encoder_hidden_states=context, - audio_input=audio_embeds, - motion_latents=motion_latents, + audio_input=audio_embeds, 
pose_cond=s2v_pose_latents, - motion_latents=drop_motion_frames, + motion_latents=motion_latents, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, use_gradient_checkpointing=use_gradient_checkpointing, use_unified_sequence_parallel=use_unified_sequence_parallel, From 7aedfd98ddd724428615a74eb722567a53992da6 Mon Sep 17 00:00:00 2001 From: aviveise Date: Fri, 19 Sep 2025 23:46:17 +0300 Subject: [PATCH 14/23] fixing usp --- .../distributed/xdit_context_parallel.py | 52 +++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/diffsynth/distributed/xdit_context_parallel.py b/diffsynth/distributed/xdit_context_parallel.py index 4887e2f1..07351a6e 100644 --- a/diffsynth/distributed/xdit_context_parallel.py +++ b/diffsynth/distributed/xdit_context_parallel.py @@ -40,9 +40,9 @@ def rope_apply(x, freqs, num_heads): return x_out.to(x.dtype) def usp_dit_forward(self, - x: torch.Tensor, + latents: torch.Tensor, timestep: torch.Tensor, - context: torch.Tensor, + encoder_hidden_states: torch.Tensor, clip_feature: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, use_gradient_checkpointing: bool = False, @@ -52,20 +52,20 @@ def usp_dit_forward(self, t = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, timestep)) t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) - context = self.text_embedding(context) + encoder_hidden_states = self.text_embedding(encoder_hidden_states) if self.has_image_input: - x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) + latents = torch.cat([latents, y], dim=1) # (b, c_x + c_y, f, h, w) clip_embdding = self.img_emb(clip_feature) - context = torch.cat([clip_embdding, context], dim=1) + encoder_hidden_states = torch.cat([clip_embdding, encoder_hidden_states], dim=1) - x, (f, h, w) = self.patchify(x) + latents, (f, h, w) = self.patchify(latents) freqs = torch.cat([ self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) - ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + ], dim=-1).reshape(f * h * w, 1, -1).to(latents.device) def create_custom_forward(module): def custom_forward(*inputs): @@ -73,44 +73,44 @@ def custom_forward(*inputs): return custom_forward # Context Parallel - chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1) + chunks = torch.chunk(latents, get_sequence_parallel_world_size(), dim=1) pad_shape = chunks[0].shape[1] - chunks[-1].shape[1] chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks] - x = chunks[get_sequence_parallel_rank()] + latents = chunks[get_sequence_parallel_rank()] for block in self.blocks: if self.training and use_gradient_checkpointing: if use_gradient_checkpointing_offload: with torch.autograd.graph.save_on_cpu(): - x = torch.utils.checkpoint.checkpoint( + latents = torch.utils.checkpoint.checkpoint( create_custom_forward(block), - x, context, t_mod, freqs, + latents, encoder_hidden_states, t_mod, freqs, use_reentrant=False, ) else: - x = torch.utils.checkpoint.checkpoint( + latents = torch.utils.checkpoint.checkpoint( create_custom_forward(block), - x, context, t_mod, freqs, + latents, encoder_hidden_states, t_mod, freqs, use_reentrant=False, ) else: - x = block(x, context, t_mod, freqs) + latents = block(latents, encoder_hidden_states, t_mod, freqs) - x = self.head(x, t) + latents = self.head(latents, t) # Context Parallel - x = get_sp_group().all_gather(x, dim=1) - 
x = x[:, :-pad_shape] if pad_shape > 0 else x + latents = get_sp_group().all_gather(latents, dim=1) + latents = latents[:, :-pad_shape] if pad_shape > 0 else latents # unpatchify - x = self.unpatchify(x, (f, h, w)) - return x + latents = self.unpatchify(latents, (f, h, w)) + return latents -def usp_attn_forward(self, x, freqs): - q = self.norm_q(self.q(x)) - k = self.norm_k(self.k(x)) - v = self.v(x) +def usp_attn_forward(self, latents, freqs): + q = self.norm_q(self.q(latents)) + k = self.norm_k(self.k(latents)) + v = self.v(latents) q = rope_apply(q, freqs, self.num_heads) k = rope_apply(k, freqs, self.num_heads) @@ -118,14 +118,14 @@ def usp_attn_forward(self, x, freqs): k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads) v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads) - x = xFuserLongContextAttention()( + latents = xFuserLongContextAttention()( None, query=q, key=k, value=v, ) - x = x.flatten(2) + latents = latents.flatten(2) del q, k, v torch.cuda.empty_cache() - return self.o(x) \ No newline at end of file + return self.o(latents) \ No newline at end of file From b1a62a3ee2e384e1e9178678e447fa2a5254e6d0 Mon Sep 17 00:00:00 2001 From: aviveise Date: Fri, 19 Sep 2025 23:52:48 +0300 Subject: [PATCH 15/23] disabling usp dit method overide --- diffsynth/pipelines/wan_video_new.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 4d381c16..8acf9c09 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -286,7 +286,7 @@ def enable_usp(self): for block in self.dit.blocks: block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) - self.dit.forward = types.MethodType(usp_dit_forward, self.dit) + #self.dit.forward = types.MethodType(usp_dit_forward, self.dit) if self.dit2 is not None: for block in self.dit2.blocks: block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) From 5aa73634339c736f1cd26e77d32f2c41044dc9ed Mon Sep 17 00:00:00 2001 From: aviveise Date: Sun, 21 Sep 2025 09:59:12 +0300 Subject: [PATCH 16/23] adding prints for debug --- diffsynth/distributed/xdit_context_parallel.py | 2 ++ diffsynth/models/wan_video_dit_s2v.py | 3 +++ diffsynth/pipelines/wan_video_new.py | 2 ++ 3 files changed, 7 insertions(+) diff --git a/diffsynth/distributed/xdit_context_parallel.py b/diffsynth/distributed/xdit_context_parallel.py index 07351a6e..9b0e6aaf 100644 --- a/diffsynth/distributed/xdit_context_parallel.py +++ b/diffsynth/distributed/xdit_context_parallel.py @@ -49,6 +49,7 @@ def usp_dit_forward(self, use_gradient_checkpointing_offload: bool = False, **kwargs, ): + print("usp_dit_forward") t = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, timestep)) t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) @@ -108,6 +109,7 @@ def custom_forward(*inputs): def usp_attn_forward(self, latents, freqs): + print("usp_attn_forward") q = self.norm_q(self.q(latents)) k = self.norm_k(self.k(latents)) v = self.v(latents) diff --git a/diffsynth/models/wan_video_dit_s2v.py b/diffsynth/models/wan_video_dit_s2v.py index ac412429..db6f1c62 100644 --- a/diffsynth/models/wan_video_dit_s2v.py +++ b/diffsynth/models/wan_video_dit_s2v.py @@ -519,6 +519,7 @@ def forward( from xfuser.core.distributed import (get_sequence_parallel_rank, get_sequence_parallel_world_size, get_sp_group) + print("use_unified_sequence_parallel") origin_ref_latents = latents[:, :, 0:1] @@ -567,6 +568,7 @@ def forward( seg_idxs = 
[0] + list(torch.cumsum(torch.tensor([hidden_states.shape[1]] * world_size), dim=0).cpu().numpy()) seq_len_x_list = [min(max(0, seq_len_x - seg_idxs[i]), hidden_states.shape[1]) for i in range(len(seg_idxs)-1)] seq_len_x = seq_len_x_list[sp_rank] + print("use_unified_sequence_parallel seq_len_x", sp_rank, world_size) if tea_cache_update: hidden_states = tea_cache.update(hidden_states) @@ -617,6 +619,7 @@ def custom_forward(*inputs): if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: hidden_states = get_sp_group().all_gather(hidden_states, dim=1) + print("use_unified_sequence_parallel all_gather", sp_rank, world_size) hidden_states = hidden_states[:, :seq_len_x] hidden_states = self.head(hidden_states, t[:-1]) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 8acf9c09..8f335c22 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -919,7 +919,9 @@ def __init__(self): def process(self, pipe: WanVideoPipeline): if hasattr(pipe, "use_unified_sequence_parallel"): if pipe.use_unified_sequence_parallel: + print("use_unified_sequence_parallel true") return {"use_unified_sequence_parallel": True} + return {} From 6cb86d75e0682d981d9e15d201805bcad0540c87 Mon Sep 17 00:00:00 2001 From: aviveise Date: Sun, 21 Sep 2025 10:06:55 +0300 Subject: [PATCH 17/23] print --- diffsynth/models/wan_video_dit_s2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffsynth/models/wan_video_dit_s2v.py b/diffsynth/models/wan_video_dit_s2v.py index db6f1c62..2254897d 100644 --- a/diffsynth/models/wan_video_dit_s2v.py +++ b/diffsynth/models/wan_video_dit_s2v.py @@ -519,7 +519,7 @@ def forward( from xfuser.core.distributed import (get_sequence_parallel_rank, get_sequence_parallel_world_size, get_sp_group) - print("use_unified_sequence_parallel") + print("use_unified_sequence_parallel", dist.is_initialized(), dist.get_world_size()") origin_ref_latents = latents[:, :, 0:1] From defe06d34877b828af24f049e33b0289da46ca1d Mon Sep 17 00:00:00 2001 From: aviveise Date: Sun, 21 Sep 2025 10:07:50 +0300 Subject: [PATCH 18/23] print --- diffsynth/models/wan_video_dit_s2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffsynth/models/wan_video_dit_s2v.py b/diffsynth/models/wan_video_dit_s2v.py index 2254897d..9ee2fc9f 100644 --- a/diffsynth/models/wan_video_dit_s2v.py +++ b/diffsynth/models/wan_video_dit_s2v.py @@ -519,7 +519,7 @@ def forward( from xfuser.core.distributed import (get_sequence_parallel_rank, get_sequence_parallel_world_size, get_sp_group) - print("use_unified_sequence_parallel", dist.is_initialized(), dist.get_world_size()") + print("use_unified_sequence_parallel", dist.is_initialized(), dist.get_world_size()) origin_ref_latents = latents[:, :, 0:1] From 982f81044f18d27e2b1062dbfd5dc5c36cbe6fa9 Mon Sep 17 00:00:00 2001 From: aviveise Date: Sun, 21 Sep 2025 10:24:34 +0300 Subject: [PATCH 19/23] fix in seq_len_x_global --- diffsynth/models/wan_video_dit_s2v.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/diffsynth/models/wan_video_dit_s2v.py b/diffsynth/models/wan_video_dit_s2v.py index 9ee2fc9f..8ee3e425 100644 --- a/diffsynth/models/wan_video_dit_s2v.py +++ b/diffsynth/models/wan_video_dit_s2v.py @@ -534,7 +534,7 @@ def forward( # x and pose_cond pose_cond = torch.zeros_like(hidden_states) if pose_cond is None else pose_cond hidden_states, (f, h, w) = self.patchify(self.patch_embedding(hidden_states) + 
self.cond_encoder(pose_cond)) # torch.Size([1, 29120, 5120]) - seq_len_x = hidden_states.shape[1] + seq_len_x = seq_len_x_global = hidden_states.shape[1] # reference image ref_latents, (rf, rh, rw) = self.patchify(self.patch_embedding(origin_ref_latents)) # torch.Size([1, 1456, 5120]) @@ -612,7 +612,7 @@ def custom_forward(*inputs): ) else: hidden_states = block(hidden_states, encoder_hidden_states, t_mod, seq_len_x, pre_compute_freqs[0]) - hidden_states = self.after_transformer_block(block_id, hidden_states, audio_emb_global, merged_audio_emb, seq_len_x) + hidden_states = self.after_transformer_block(block_id, hidden_states, audio_emb_global, merged_audio_emb, seq_len_x_global) if tea_cache is not None: tea_cache.store(hidden_states) @@ -621,7 +621,7 @@ def custom_forward(*inputs): hidden_states = get_sp_group().all_gather(hidden_states, dim=1) print("use_unified_sequence_parallel all_gather", sp_rank, world_size) - hidden_states = hidden_states[:, :seq_len_x] + hidden_states = hidden_states[:, :seq_len_x_global] hidden_states = self.head(hidden_states, t[:-1]) hidden_states = self.unpatchify(hidden_states, (f, h, w)) # make compatible with wan video From b1d9717237783f14aa1ecbf61029af98602d9dff Mon Sep 17 00:00:00 2001 From: aviveise Date: Sun, 21 Sep 2025 16:52:58 +0300 Subject: [PATCH 20/23] fix in dit model --- diffsynth/models/wan_video_dit_s2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffsynth/models/wan_video_dit_s2v.py b/diffsynth/models/wan_video_dit_s2v.py index 8ee3e425..88219426 100644 --- a/diffsynth/models/wan_video_dit_s2v.py +++ b/diffsynth/models/wan_video_dit_s2v.py @@ -612,7 +612,7 @@ def custom_forward(*inputs): ) else: hidden_states = block(hidden_states, encoder_hidden_states, t_mod, seq_len_x, pre_compute_freqs[0]) - hidden_states = self.after_transformer_block(block_id, hidden_states, audio_emb_global, merged_audio_emb, seq_len_x_global) + hidden_states = self.after_transformer_block(block_id, hidden_states, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel=use_unified_sequence_parallel) if tea_cache is not None: tea_cache.store(hidden_states) From 144a475f181cec035787e1dac62c761b2e981e1f Mon Sep 17 00:00:00 2001 From: aviveise Date: Sun, 21 Sep 2025 16:57:36 +0300 Subject: [PATCH 21/23] removing prints --- diffsynth/models/wan_video_dit_s2v.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/diffsynth/models/wan_video_dit_s2v.py b/diffsynth/models/wan_video_dit_s2v.py index 88219426..6e4c5ee9 100644 --- a/diffsynth/models/wan_video_dit_s2v.py +++ b/diffsynth/models/wan_video_dit_s2v.py @@ -519,7 +519,6 @@ def forward( from xfuser.core.distributed import (get_sequence_parallel_rank, get_sequence_parallel_world_size, get_sp_group) - print("use_unified_sequence_parallel", dist.is_initialized(), dist.get_world_size()) origin_ref_latents = latents[:, :, 0:1] @@ -568,7 +567,6 @@ def forward( seg_idxs = [0] + list(torch.cumsum(torch.tensor([hidden_states.shape[1]] * world_size), dim=0).cpu().numpy()) seq_len_x_list = [min(max(0, seq_len_x - seg_idxs[i]), hidden_states.shape[1]) for i in range(len(seg_idxs)-1)] seq_len_x = seq_len_x_list[sp_rank] - print("use_unified_sequence_parallel seq_len_x", sp_rank, world_size) if tea_cache_update: hidden_states = tea_cache.update(hidden_states) @@ -619,7 +617,6 @@ def custom_forward(*inputs): if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: hidden_states = get_sp_group().all_gather(hidden_states, dim=1) - 
print("use_unified_sequence_parallel all_gather", sp_rank, world_size) hidden_states = hidden_states[:, :seq_len_x_global] hidden_states = self.head(hidden_states, t[:-1]) From 0e6e3e28fd417e202e4d6c21ae73b6a9443e517c Mon Sep 17 00:00:00 2001 From: aviveise Date: Sun, 21 Sep 2025 17:00:11 +0300 Subject: [PATCH 22/23] removing prints --- diffsynth/distributed/xdit_context_parallel.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/diffsynth/distributed/xdit_context_parallel.py b/diffsynth/distributed/xdit_context_parallel.py index 9b0e6aaf..07351a6e 100644 --- a/diffsynth/distributed/xdit_context_parallel.py +++ b/diffsynth/distributed/xdit_context_parallel.py @@ -49,7 +49,6 @@ def usp_dit_forward(self, use_gradient_checkpointing_offload: bool = False, **kwargs, ): - print("usp_dit_forward") t = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, timestep)) t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) @@ -109,7 +108,6 @@ def custom_forward(*inputs): def usp_attn_forward(self, latents, freqs): - print("usp_attn_forward") q = self.norm_q(self.q(latents)) k = self.norm_k(self.k(latents)) v = self.v(latents) From 8b4c85973e906ffe289857bf57b8f364e76056c3 Mon Sep 17 00:00:00 2001 From: aviveise Date: Wed, 29 Oct 2025 16:33:54 +0200 Subject: [PATCH 23/23] adding fps to predict --- diffsynth/pipelines/wan_video_new.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 8f335c22..f9266d50 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -270,7 +270,10 @@ def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=No def initialize_usp(self): import torch.distributed as dist from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment - dist.init_process_group(backend="nccl", init_method="env://") + + if not dist.is_initialized(): + dist.init_process_group(backend="nccl", init_method="env://") + init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size()) initialize_model_parallel( sequence_parallel_degree=dist.get_world_size(), @@ -464,6 +467,7 @@ def __call__( tea_cache_model_id: Optional[str] = "", # progress_bar progress_bar_cmd=tqdm, + fps: Optional[int] = 16, ): # Scheduler self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) @@ -492,6 +496,7 @@ def __call__( "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride, "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video, + "fps": fps, } for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) @@ -969,7 +974,7 @@ def __init__(self): onload_model_names=("audio_encoder", "vae",) ) - def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames, fps=16, audio_embeds=None, return_all=False): + def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames, fps, audio_embeds=None, return_all=False): if audio_embeds is not None: return {"audio_embeds": audio_embeds} pipe.load_models_to_device(["audio_encoder"]) @@ -1023,8 +1028,8 @@ def process(self, pipe: 
WanVideoPipeline, inputs_shared, inputs_posi, inputs_neg num_frames, height, width, tiled, tile_size, tile_stride = inputs_shared.get("num_frames"), inputs_shared.get("height"), inputs_shared.get("width"), inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride") input_audio, audio_embeds, audio_sample_rate = inputs_shared.pop("input_audio"), inputs_shared.pop("audio_embeds"), inputs_shared.get("audio_sample_rate") s2v_pose_video, s2v_pose_latents, motion_video = inputs_shared.pop("s2v_pose_video"), inputs_shared.pop("s2v_pose_latents"), inputs_shared.pop("motion_video") - - audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames, audio_embeds=audio_embeds) + fps = inputs_shared.get("fps") + audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames, fps, audio_embeds=audio_embeds) inputs_posi.update(audio_input_positive) inputs_nega.update({"audio_embeds": 0.0 * audio_input_positive["audio_embeds"]})
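
Note on [PATCH 14/23] ("fixing usp"): besides renaming x/context to latents/encoder_hidden_states, usp_dit_forward splits the latent sequence across ranks with torch.chunk, right-pads the last chunk so every rank sees the same length, runs the transformer blocks on the local slice, then all-gathers along the sequence dimension and trims the padding. Below is a minimal single-process sketch of that round trip, assuming plain PyTorch and no xfuser; split_pad and gather_unpad are illustrative helpers, not functions from the codebase.

    import torch
    import torch.nn.functional as F

    def split_pad(x, world_size):
        # Chunk a (b, s, d) tensor along the sequence dim and right-pad every chunk
        # to the size of the first one, mirroring the split in usp_dit_forward.
        chunks = torch.chunk(x, world_size, dim=1)
        pad = chunks[0].shape[1] - chunks[-1].shape[1]
        chunks = [F.pad(c, (0, 0, 0, chunks[0].shape[1] - c.shape[1]), value=0) for c in chunks]
        return chunks, pad

    def gather_unpad(per_rank_outputs, pad):
        # Stand-in for get_sp_group().all_gather(..., dim=1) followed by the un-padding step.
        gathered = torch.cat(per_rank_outputs, dim=1)
        return gathered[:, :-pad] if pad > 0 else gathered

    x = torch.randn(1, 10, 8)                  # 10 tokens, not divisible by 4 ranks
    per_rank, pad = split_pad(x, world_size=4)
    # each rank would run its transformer blocks on its own per_rank[rank] slice here
    y = gather_unpad(per_rank, pad)
    assert torch.equal(x, y)                   # the split/gather round trip preserves the sequence

The padding step matters because all_gather requires every rank to contribute a tensor of identical shape; trimming afterwards restores the original sequence length.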
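
Note on [PATCH 19/23] ("fix in seq_len_x_global"): once the sequence is chunked per rank, seq_len_x is recomputed as the number of video tokens visible to the local rank, so it can no longer be used to trim the all-gathered output or to drive the audio injection, which need the pre-split length. The patch keeps both values: seq_len_x (per rank) for the blocks and seq_len_x_global for after_transformer_block and the final hidden_states[:, :seq_len_x_global] trim. The sketch below is equivalent to the seg_idxs computation in the diff, using the token counts from its comments (29120 video tokens, 1456 reference tokens) and an assumed world size of 4.

    def per_rank_video_len(seq_len_x_global, chunk_len, world_size):
        # How many of each rank's tokens belong to the video part of the
        # (video + reference) sequence, matching the seg_idxs bookkeeping.
        seg_idxs = [i * chunk_len for i in range(world_size + 1)]
        return [min(max(0, seq_len_x_global - seg_idxs[i]), chunk_len)
                for i in range(world_size)]

    seq_len_x_global = 29120                       # video tokens before chunking
    total = seq_len_x_global + 1456                # plus the reference-image tokens
    chunk_len = total // 4                         # 7644 tokens per rank (divides evenly here)
    print(per_rank_video_len(seq_len_x_global, chunk_len, 4))   # [7644, 7644, 7644, 6188]

In this example the last rank holds the 1456 reference tokens, so only 6188 of its 7644 tokens count as video tokens; the other ranks are fully covered.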
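
Note on [PATCH 23/23] ("adding fps to predict"): fps becomes an explicit argument of WanVideoPipeline.__call__ (default 16, matching the value previously hard-coded in the S2V audio unit) and is forwarded through inputs_shared to process_audio; the same commit guards dist.init_process_group so initialize_usp does not fail when torchrun or an earlier component has already created the default process group. A hedged usage sketch follows; pipe and audio_waveform are placeholders, and only the argument names that appear in inputs_shared in the diff are assumed.

    import torch.distributed as dist

    def ensure_default_process_group():
        # The guard added to initialize_usp: only create the default group if it
        # has not already been created by torchrun or an earlier component.
        if not dist.is_initialized():
            dist.init_process_group(backend="nccl", init_method="env://")

    def generate(pipe, audio_waveform):
        # Hypothetical wrapper showing the new fps argument; before this patch the
        # S2V audio unit always assumed 16 fps regardless of the requested rate.
        return pipe(
            prompt="a person speaking to the camera",
            input_audio=audio_waveform,        # 1-D waveform sampled at audio_sample_rate
            audio_sample_rate=16000,
            num_frames=81,
            fps=25,
        )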