From 996388ae2098ce4df1a33138cb18c3faedab32a6 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Sat, 25 Oct 2025 10:44:35 +0530 Subject: [PATCH 01/10] Changes for WAN 2.2 --- .../checkpointing/wan_checkpointer.py | 49 ++- src/maxdiffusion/configs/base_wan_27b.yml | 332 ++++++++++++++++++ src/maxdiffusion/generate_wan.py | 52 ++- src/maxdiffusion/models/wan/wan_utils.py | 9 +- .../pipelines/wan/wan_pipeline.py | 164 ++++++--- 5 files changed, 532 insertions(+), 74 deletions(-) create mode 100644 src/maxdiffusion/configs/base_wan_27b.yml diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py index 0dd493a3..1f8db8f7 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py @@ -34,7 +34,7 @@ class WanCheckpointer(ABC): def __init__(self, config, checkpoint_type): self.config = config self.checkpoint_type = checkpoint_type - self.opt_state = None + self.run_wan2_2 = config.run_wan2_2 if 'run_wan2_2' in self.config.__dict__ else False self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager( self.config.checkpoint_dir, @@ -60,23 +60,36 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic return None, None max_logging.log(f"Loading WAN checkpoint from step {step}") metadatas = self.checkpoint_manager.item_metadata(step) - transformer_metadata = metadatas.wan_state - abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata) - params_restore = ocp.args.PyTreeRestore( + + restore_args = {} + + low_state_metadata = metadatas.low_noise_transformer_state + abstract_tree_structure_low_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_state_metadata) + low_state_restore = ocp.args.PyTreeRestore( restore_args=jax.tree.map( lambda _: ocp.RestoreArgs(restore_type=np.ndarray), - abstract_tree_structure_params, + abstract_tree_structure_low_state, ) ) + restore_args["low_noise_transformer_state"] = low_state_restore + + if self.run_wan2_2: + high_state_metadata = metadatas.high_noise_transformer_state + abstract_tree_structure_high_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_state_metadata) + high_state_restore = ocp.args.PyTreeRestore( + restore_args=jax.tree.map( + lambda _: ocp.RestoreArgs(restore_type=np.ndarray), + abstract_tree_structure_high_state, + ) + ) + restore_args["high_noise_transformer_state"] = high_state_restore + + restore_args["wan_config"] = ocp.args.JsonRestore() max_logging.log("Restoring WAN checkpoint") restored_checkpoint = self.checkpoint_manager.restore( - directory=epath.Path(self.config.checkpoint_dir), step=step, - args=ocp.args.Composite( - wan_state=params_restore, - wan_config=ocp.args.JsonRestore(), - ), + args=ocp.args.Composite(**restore_args), ) max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") max_logging.log(f"restored checkpoint wan_state {restored_checkpoint.wan_state.keys()}") @@ -110,14 +123,22 @@ def config_to_json(model_or_config): max_logging.log(f"Saving checkpoint for step {train_step}") items = { - "wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)), + "wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)), } - items["wan_state"] = ocp.args.PyTreeSave(train_states) + if "low_noise_transformer" in train_states: + low_noise_state = train_states["low_noise_transformer"] + items["low_noise_transformer_state"] = ocp.args.PyTreeSave(low_noise_state) + if self.run_wan2_2: + if "high_noise_transformer" in train_states: + high_noise_state = train_states["high_noise_transformer"] + items["high_noise_transformer_state"] = ocp.args.PyTreeSave(high_noise_state) + # Save the checkpoint - self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) - max_logging.log(f"Checkpoint for step {train_step} saved.") + if len(items) > 1: + self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) + max_logging.log(f"Checkpoint for step {train_step} saved.") def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict): diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml new file mode 100644 index 00000000..81fc5914 --- /dev/null +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -0,0 +1,332 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This sentinel is a reminder to choose a real run name. +run_name: '' + +metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written. +# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/ +write_metrics: True + +timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written. +write_timing_metrics: True + +gcs_metrics: False +# If true save config to GCS in {base_output_directory}/{run_name}/ +save_config_to_gcs: False +log_period: 100 + +pretrained_model_name_or_path: 'Wan-AI/Wan2.2-T2V-A14B-Diffusers' +run_wan2_2: True + +# Overrides the transformer from pretrained_model_name_or_path +wan_transformer_pretrained_model_name_or_path: '' + +unet_checkpoint: '' +revision: '' +# This will convert the weights to this dtype. +# When running inference on TPUv5e, use weights_dtype: 'bfloat16' +weights_dtype: 'bfloat16' +# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) +activations_dtype: 'bfloat16' + +# Replicates vae across devices instead of using the model's sharding annotations for sharding. +replicate_vae: False + +# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision +# Options are "DEFAULT", "HIGH", "HIGHEST" +# fp32 activations and fp32 weights with HIGHEST will provide the best precision +# at the cost of time. +precision: "DEFAULT" +# Use jax.lax.scan for transformer layers +scan_layers: True + +# if False state is not jitted and instead replicate is called. This is good for debugging on single host +# It must be True for multi-host. +jit_initializers: True + +# Set true to load weights from pytorch +from_pt: True +split_head_dim: True +attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring +flash_min_seq_length: 4096 +dropout: 0.1 + +flash_block_sizes: { + "block_q" : 1024, + "block_kv_compute" : 256, + "block_kv" : 1024, + "block_q_dkv" : 1024, + "block_kv_dkv" : 1024, + "block_kv_dkv_compute" : 256, + "block_q_dq" : 1024, + "block_kv_dq" : 1024 +} +# Use on v6e +# flash_block_sizes: { +# "block_q" : 3024, +# "block_kv_compute" : 1024, +# "block_kv" : 2048, +# "block_q_dkv" : 3024, +# "block_kv_dkv" : 2048, +# "block_kv_dkv_compute" : 2048, +# "block_q_dq" : 3024, +# "block_kv_dq" : 2048 +# "use_fused_bwd_kernel": False, +# } +# GroupNorm groups +norm_num_groups: 32 + +# train text_encoder - Currently not supported for SDXL +train_text_encoder: False +text_encoder_learning_rate: 4.25e-6 + +# https://arxiv.org/pdf/2305.08891.pdf +snr_gamma: -1.0 + +timestep_bias: { + # a value of later will increase the frequence of the model's final training steps. + # none, earlier, later, range + strategy: "none", + # multiplier for bias, a value of 2.0 will double the weight of the bias, 0.5 will halve it. + multiplier: 1.0, + # when using strategy=range, the beginning (inclusive) timestep to bias. + begin: 0, + # when using strategy=range, the final step (inclusive) to bias. + end: 1000, + # portion of timesteps to bias. + # 0.5 will bias one half of the timesteps. Value of strategy determines + # whether the biased portions are in the earlier or later timesteps. + portion: 0.25 +} + +# Override parameters from checkpoints's scheduler. +diffusion_scheduler_config: { + _class_name: 'FlaxEulerDiscreteScheduler', + prediction_type: 'epsilon', + rescale_zero_terminal_snr: False, + timestep_spacing: 'trailing' +} + +# Output directory +# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" +base_output_directory: "" + +# Hardware +hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +skip_jax_distributed_system: False + +# Parallelism +mesh_axes: ['data', 'fsdp', 'tensor'] + +# batch : batch dimension of data and activations +# hidden : +# embed : attention qkv dense layer hidden dim named as embed +# heads : attention head dim = num_heads * head_dim +# length : attention sequence length +# temb_in : dense.shape[0] of resnet dense before conv +# out_c : dense.shape[1] of resnet dense before conv +# out_channels : conv.shape[-1] activation +# keep_1 : conv.shape[0] weight +# keep_2 : conv.shape[1] weight +# conv_in : conv.shape[2] weight +# conv_out : conv.shape[-1] weight +logical_axis_rules: [ + ['batch', 'data'], + ['activation_batch', 'data'], + ['activation_length', 'fsdp'], + + ['activation_heads', 'tensor'], + ['mlp','tensor'], + ['embed','fsdp'], + ['heads', 'tensor'], + ['norm', 'tensor'], + ['conv_batch', ['data','fsdp']], + ['out_channels', 'tensor'], + ['conv_out', 'fsdp'], + ] +data_sharding: [['data', 'fsdp', 'tensor']] + +# One axis for each parallelism type may hold a placeholder (-1) +# value to auto-shard based on available slices and devices. +# By default, product of the DCN axes should equal number of slices +# and product of the ICI axes should equal number of devices per slice. +dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded +dcn_fsdp_parallelism: -1 +dcn_tensor_parallelism: 1 +ici_data_parallelism: 1 +ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded +ici_tensor_parallelism: 1 + +allow_split_physical_axes: False + +# Dataset +# Replace with dataset path or train_data_dir. One has to be set. +dataset_name: 'diffusers/pokemon-gpt4-captions' +train_split: 'train' +dataset_type: 'tfrecord' +cache_latents_text_encoder_outputs: True +# cache_latents_text_encoder_outputs only apply to dataset_type="tf", +# only apply to small dataset that fits in memory +# prepare image latents and text encoder outputs +# Reduce memory consumption and reduce step time during training +# transformed dataset is saved at dataset_save_location +dataset_save_location: '' +load_tfrecord_cached: True +train_data_dir: '' +dataset_config_name: '' +jax_cache_dir: '' +hf_data_dir: '' +hf_train_files: '' +hf_access_token: '' +image_column: 'image' +caption_column: 'text' +resolution: 1024 +center_crop: False +random_flip: False +# If cache_latents_text_encoder_outputs is True +# the num_proc is set to 1 +tokenize_captions_num_proc: 4 +transform_images_num_proc: 4 +reuse_example_batch: False +enable_data_shuffling: True + +# Defines the type of gradient checkpoint to enable. +# NONE - means no gradient checkpoint +# FULL - means full gradient checkpoint, whenever possible (minimum memory usage) +# MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, +# except for ones that involve batch dimension - that means that all attention and projection +# layers will have gradient checkpoint, but not the backward with respect to the parameters. +# OFFLOAD_MATMUL_WITHOUT_BATCH - same as MATMUL_WITHOUT_BATCH but offload instead of recomputing. +# CUSTOM - set names to offload and save. +remat_policy: "NONE" +# For CUSTOM policy set below, current annotations are for: attn_output, query_proj, key_proj, value_proj +# xq_out, xk_out, ffn_activation +names_which_can_be_saved: [] +names_which_can_be_offloaded: [] + +# checkpoint every number of samples, -1 means don't checkpoint. +checkpoint_every: -1 +checkpoint_dir: "" +# enables one replica to read the ckpt then broadcast to the rest +enable_single_replica_ckpt_restoring: False + +# Training loop +learning_rate: 1.e-5 +scale_lr: False +max_train_samples: -1 +# max_train_steps takes priority over num_train_epochs. +max_train_steps: 1500 +num_train_epochs: 1 +seed: 0 +output_dir: 'sdxl-model-finetuned' +per_device_batch_size: 1.0 +# If global_batch_size % jax.device_count is not 0, use FSDP sharding. +global_batch_size: 0 + +# For creating tfrecords from dataset +tfrecords_dir: '' +no_records_per_shard: 0 +enable_eval_timesteps: False +timesteps_list: [125, 250, 375, 500, 625, 750, 875] +num_eval_samples: 420 + +warmup_steps_fraction: 0.1 +learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. +save_optimizer: False + +# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before +# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0. + +# AdamW optimizer parameters +adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. +adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. +adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. +adam_weight_decay: 0 # AdamW Weight decay +max_grad_norm: 1.0 + +enable_profiler: False +# Skip first n steps for profiling, to omit things like compilation and to give +# the iteration time a chance to stabilize. +skip_first_n_steps_for_profiler: 5 +profiler_steps: 10 + +# Generation parameters +prompt: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." +prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." +negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" +do_classifier_free_guidance: True +height: 480 +width: 832 +num_frames: 81 +flow_shift: 3.0 + +guidance_scale_low: 5.0 +guidance_scale_high: 8.0 +boundary_timestep: 15 + +# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf +guidance_rescale: 0.0 +num_inference_steps: 30 +fps: 24 +save_final_checkpoint: False + +# SDXL Lightning parameters +lightning_from_pt: True +# Empty or "ByteDance/SDXL-Lightning" to enable lightning. +lightning_repo: "" +# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning. +lightning_ckpt: "" + +# LoRA parameters +# Values are lists to support multiple LoRA loading during inference in the future. +lora_config: { + lora_model_name_or_path: [], + weight_name: [], + adapter_name: [], + scale: [], + from_pt: [] +} +# Ex with values: +# lora_config : { +# lora_model_name_or_path: ["ByteDance/Hyper-SD"], +# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"], +# adapter_name: ["hyper-sdxl"], +# scale: [0.7], +# from_pt: [True] +# } + +enable_mllog: False + +#controlnet +controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0' +controlnet_from_pt: True +controlnet_conditioning_scale: 0.5 +controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png' +quantization: '' +# Shard the range finding operation for quantization. By default this is set to number of slices. +quantization_local_shard_count: -1 +compile_topology_num_slices: -1 # Number of target slices, set to a positive integer. +use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the transformer of WAN will be quantized using qwix. +# Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80 +quantization_calibration_method: "absmax" +qwix_module_path: ".*" + +# Eval model on per eval_every steps. -1 means don't eval. +eval_every: -1 +eval_data_dir: "" +enable_generate_video_for_eval: False # This will increase the used TPU memory. +eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(timesteps_list). + +enable_ssim: False \ No newline at end of file diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 501dbf32..46fca5e8 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -19,10 +19,23 @@ from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline from maxdiffusion import pyconfig, max_logging, max_utils from absl import app +from absl import flags from maxdiffusion.utils import export_to_video from google.cloud import storage import flax +_MODEL_NAME = flags.DEFINE_enum( + "model_name", + default="wan2.1", + enum_values=["wan2.1", "wan2.2"], + help="The model version to run (wan2.1 or wan2.2). This determines the base config file.", +) + +CONFIG_BASE_DIR = "src/maxdiffusion/configs" +MODEL_CONFIG_MAP = { + "wan2.1": "base_wan_14b.yml", + "wan2.2": "base_wan_27b.yml", +} def upload_video_to_gcs(output_dir: str, video_path: str): """ @@ -80,7 +93,10 @@ def inference_generate_video(config, pipeline, filename_prefix=""): width=config.width, num_frames=config.num_frames, num_inference_steps=config.num_inference_steps, - guidance_scale=config.guidance_scale, + guidance_scale=guidance_scale, + guidance_scale_low=guidance_scale_low, + guidance_scale_high=guidance_scale_high, + boundary=boundary, ) max_logging.log(f"video {filename_prefix}, compile time: {(time.perf_counter() - s0)}") @@ -107,6 +123,10 @@ def run(config, pipeline=None, filename_prefix=""): # Using global_batch_size_to_train_on so not to create more config variables prompt = [config.prompt] * config.global_batch_size_to_train_on negative_prompt = [config.negative_prompt] * config.global_batch_size_to_train_on + guidance_scale = config.guidance_scale if 'guidance_scale' in config.__dict__ else 5 + guidance_scale_low = config.guidance_scale_low if 'guidance_scale_low' in config.__dict__ else 3 + guidance_scale_high = config.guidance_scale_high if 'guidance_scale_high' in config.__dict__ else 4 + boundary = config.boundary_timestep if 'boundary_timestep' in config.__dict__ else 875 max_logging.log( f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}" @@ -119,7 +139,10 @@ def run(config, pipeline=None, filename_prefix=""): width=config.width, num_frames=config.num_frames, num_inference_steps=config.num_inference_steps, - guidance_scale=config.guidance_scale, + guidance_scale=guidance_scale, + guidance_scale_low=guidance_scale_low, + guidance_scale_high=guidance_scale_high, + boundary=boundary, ) print("compile time: ", (time.perf_counter() - s0)) @@ -139,7 +162,10 @@ def run(config, pipeline=None, filename_prefix=""): width=config.width, num_frames=config.num_frames, num_inference_steps=config.num_inference_steps, - guidance_scale=config.guidance_scale, + guidance_scale=guidance_scale, + guidance_scale_low=guidance_scale_low, + guidance_scale_high=guidance_scale_high, + boundary=boundary, ) print("generation time: ", (time.perf_counter() - s0)) @@ -153,7 +179,10 @@ def run(config, pipeline=None, filename_prefix=""): width=config.width, num_frames=config.num_frames, num_inference_steps=config.num_inference_steps, - guidance_scale=config.guidance_scale, + guidance_scale=guidance_scale, + guidance_scale_low=guidance_scale_low, + guidance_scale_high=guidance_scale_high, + boundary=boundary, ) max_utils.deactivate_profiler(config) print("generation time: ", (time.perf_counter() - s0)) @@ -161,7 +190,20 @@ def run(config, pipeline=None, filename_prefix=""): def main(argv: Sequence[str]) -> None: - pyconfig.initialize(argv) + # Get the model name from the flag + model_key = _MODEL_NAME.value + config_filename = MODEL_CONFIG_MAP[model_key] + selected_yaml_path = os.path.join(CONFIG_BASE_DIR, config_filename) + + max_logging.log(f"Using model: {model_key}, loading base config: {selected_yaml_path}") + + # Construct argv for pyconfig.initialize + # argv[0] is the program name. + # Insert the selected YAML path at index 1. + # The rest of argv (argv[1:]) are the overrides. + argv_for_pyconfig = list(argv[:1]) + [selected_yaml_path] + list(argv[1:]) + + pyconfig.initialize(argv_for_pyconfig) flax.config.update("flax_always_shard_variable", False) run(pyconfig.config) diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index ec97abd3..191d8b61 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -184,6 +184,7 @@ def load_wan_transformer( hf_download: bool = True, num_layers: int = 40, scan_layers: bool = True, + subfolder: str = "", ): if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH: @@ -192,7 +193,7 @@ def load_wan_transformer( return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers) else: return load_base_wan_transformer( - pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers + pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers, subfolder ) @@ -203,9 +204,9 @@ def load_base_wan_transformer( hf_download: bool = True, num_layers: int = 40, scan_layers: bool = True, + subfolder: str = "", ): device = jax.local_devices(backend=device)[0] - subfolder = "transformer" filename = "diffusion_pytorch_model.safetensors.index.json" local_files = False if os.path.isdir(pretrained_model_name_or_path): @@ -236,7 +237,7 @@ def load_base_wan_transformer( else: ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file) # now get all the filenames for the model that need downloading - max_logging.log(f"Load and port Wan 2.1 transformer on {device}") + max_logging.log(f"Load and port {pretrained_model_name_or_path} {subfolder} on {device}") if ckpt_shard_path is not None: with safe_open(ckpt_shard_path, framework="pt") as f: @@ -281,7 +282,7 @@ def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: raise FileNotFoundError(f"File {ckpt_path} not found for local directory.") elif hf_download: ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename) - max_logging.log(f"Load and port Wan 2.1 VAE on {device}") + max_logging.log(f"Load and port {pretrained_model_name_or_path} VAE on {device}") with jax.default_device(device): if ckpt_path is not None: tensors = {} diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 55981be0..3f343845 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -89,7 +89,7 @@ def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.Variabl # For some reason, jitting this function increases the memory significantly, so instead manually move weights to device. def create_sharded_logical_transformer( - devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None + devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None, subfolder: str = "" ): def create_model(rngs: nnx.Rngs, wan_config: dict): @@ -100,7 +100,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): if restored_checkpoint: wan_config = restored_checkpoint["wan_config"] else: - wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder="transformer") + wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder=subfolder) wan_config["mesh"] = mesh wan_config["dtype"] = config.activations_dtype wan_config["weights_dtype"] = config.weights_dtype @@ -142,6 +142,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): "cpu", num_layers=wan_config["num_layers"], scan_layers=config.scan_layers, + subfolder=subfolder, ) params = jax.tree_util.tree_map_with_path( @@ -191,7 +192,8 @@ def __init__( self, tokenizer: AutoTokenizer, text_encoder: UMT5EncoderModel, - transformer: WanModel, + low_noise_transformer: WanModel, + high_noise_transformer: Optional[WanModel], vae: AutoencoderKLWan, vae_cache: AutoencoderKLWanCache, scheduler: FlaxUniPCMultistepScheduler, @@ -202,7 +204,8 @@ def __init__( ): self.tokenizer = tokenizer self.text_encoder = text_encoder - self.transformer = transformer + self.low_noise_transformer = low_noise_transformer + self.high_noise_transformer = high_noise_transformer self.vae = vae self.vae_cache = vae_cache self.scheduler = scheduler @@ -210,6 +213,7 @@ def __init__( self.devices_array = devices_array self.mesh = mesh self.config = config + self.run_wan2_2 = config.run_wan2_2 if 'run_wan2_2' in self.config.__dict__ else False self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 @@ -353,11 +357,10 @@ def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline @classmethod def load_transformer( - cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None - ): + cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None, subfolder="transformer"): with mesh: wan_transformer = create_sharded_logical_transformer( - devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint + devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder=subfolder ) return wan_transformer @@ -376,7 +379,9 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_ mesh = Mesh(devices_array, config.mesh_axes) rng = jax.random.key(config.seed) rngs = nnx.Rngs(rng) - transformer = None + run_wan2_2 = config.run_wan2_2 if 'run_wan2_2' in config.__dict__ else False + low_noise_transformer = None + high_noise_transformer = None tokenizer = None scheduler = None scheduler_state = None @@ -384,9 +389,9 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_ if not vae_only: if load_transformer: with mesh: - transformer = cls.load_transformer( - devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint - ) + low_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder="transformer") + if run_wan2_2: + high_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder="transformer_2") text_encoder = cls.load_text_encoder(config=config) tokenizer = cls.load_tokenizer(config=config) @@ -399,7 +404,8 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_ return WanPipeline( tokenizer=tokenizer, text_encoder=text_encoder, - transformer=transformer, + low_noise_transformer=low_noise_transformer, + high_noise_transformer=high_noise_transformer, vae=wan_vae, vae_cache=vae_cache, scheduler=scheduler, @@ -415,7 +421,9 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform mesh = Mesh(devices_array, config.mesh_axes) rng = jax.random.key(config.seed) rngs = nnx.Rngs(rng) - transformer = None + run_wan2_2 = config.run_wan2_2 if 'run_wan2_2' in config.__dict__ else False + low_noise_transformer = None + high_noise_transformer = None tokenizer = None scheduler = None scheduler_state = None @@ -423,8 +431,9 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform if not vae_only: if load_transformer: with mesh: - transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) - + low_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer") + if run_wan2_2: + high_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer_2") text_encoder = cls.load_text_encoder(config=config) tokenizer = cls.load_tokenizer(config=config) @@ -436,7 +445,8 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform pipeline = WanPipeline( tokenizer=tokenizer, text_encoder=text_encoder, - transformer=transformer, + low_noise_transformer=low_noise_transformer, + high_noise_transformer=high_noise_transformer, vae=wan_vae, vae_cache=vae_cache, scheduler=scheduler, @@ -446,7 +456,9 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform config=config, ) - pipeline.transformer = cls.quantize_transformer(config, pipeline.transformer, pipeline, mesh) + pipeline.low_noise_transformer = cls.quantize_transformer(config, pipeline.low_noise_transformer, pipeline, mesh) + if run_wan2_2: + pipeline.high_noise_transformer = cls.quantize_transformer(config, pipeline.high_noise_transformer, pipeline, mesh) return pipeline def _get_t5_prompt_embeds( @@ -546,6 +558,9 @@ def __call__( num_frames: int = 81, num_inference_steps: int = 50, guidance_scale: float = 5.0, + guidance_scale_low: float = 3.0, + guidance_scale_high: float = 4.0, + boundary: int = 875, num_videos_per_prompt: Optional[int] = 1, max_sequence_length: int = 512, latents: jax.Array = None, @@ -575,7 +590,7 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, ) - num_channel_latents = self.transformer.config.in_channels + num_channel_latents = self.low_noise_transformer.config.in_channels if latents is None: latents = self.prepare_latents( batch_size=batch_size, @@ -600,22 +615,31 @@ def __call__( self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape ) - graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...) + low_noise_graphdef, low_noise_state, low_noise_rest = nnx.split(self.low_noise_transformer, nnx.Param, ...) + high_noise_graphdef, high_noise_state, high_noise_rest = None, None, None + if self.run_wan2_2: + high_noise_graphdef, high_noise_state, high_noise_rest = nnx.split(self.high_noise_transformer, nnx.Param, ...) p_run_inference = partial( run_inference, + run_wan2_2=self.run_wan2_2, guidance_scale=guidance_scale, + guidance_scale_low=guidance_scale_low, + guidance_scale_high=guidance_scale_high, + boundary=boundary, num_inference_steps=num_inference_steps, scheduler=self.scheduler, scheduler_state=scheduler_state, - num_transformer_layers=self.transformer.config.num_layers, ) with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): latents = p_run_inference( - graphdef=graphdef, - sharded_state=state, - rest_of_state=rest_of_state, + low_noise_graphdef=low_noise_graphdef, + low_noise_state=low_noise_state, + low_noise_rest=low_noise_rest, + high_noise_graphdef=high_noise_graphdef, + high_noise_state=high_noise_state, + high_noise_rest=high_noise_rest, latents=latents, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, @@ -635,43 +659,74 @@ def __call__( return video -@partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale")) +@partial(jax.jit, static_argnames=("run_wan2_2", "guidance_scale", "guidance_scale_low", "guidance_scale_high", "boundary", "do_classifier_free_guidance")) def transformer_forward_pass( - graphdef, - sharded_state, - rest_of_state, - latents, - timestep, + low_noise_graphdef, + low_noise_state, + low_noise_rest, + high_noise_graphdef, + high_noise_state, + high_noise_rest, + latents, timestep, prompt_embeds, - do_classifier_free_guidance, - guidance_scale, + run_wan2_2: bool, + guidance_scale: float, + guidance_scale_low: float, + guidance_scale_high: float, + boundary: int, + do_classifier_free_guidance: bool, + t: jnp.array, ): - wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) - noise_pred = wan_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds) - if do_classifier_free_guidance: - bsz = latents.shape[0] // 2 - noise_uncond = noise_pred[bsz:] - noise_pred = noise_pred[:bsz] - noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) - latents = latents[:bsz] + low_noise_transformer = nnx.merge(low_noise_graphdef, low_noise_state, low_noise_rest) + noise_pred_low = low_noise_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds) + noise_pred = noise_pred_low + current_guide_scale = guidance_scale + if run_wan2_2: + high_noise_transformer = nnx.merge(high_noise_graphdef, high_noise_state, high_noise_rest) + noise_pred_high = high_noise_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds) + use_high_noise = jnp.greater_equal(t, boundary) + noise_pred = jax.lax.cond( + use_high_noise, + lambda: noise_pred_high, + lambda: noise_pred_low, + ) + current_guide_scale = jax.lax.cond( + use_high_noise, + lambda: guidance_scale_high, + lambda: guidance_scale_low, + ) - return noise_pred, latents + if do_classifier_free_guidance: + bsz = latents.shape[0] // 2 + noise_uncond = noise_pred[bsz:] + noise_pred = noise_pred[:bsz] + noise_pred = noise_uncond + current_guide_scale * (noise_pred - noise_uncond) + latents = latents[:bsz] + return noise_pred, latents def run_inference( - graphdef, - sharded_state, - rest_of_state, + low_noise_graphdef, + low_noise_state, + low_noise_rest, + high_noise_graphdef, + high_noise_state, + high_noise_rest, latents: jnp.array, prompt_embeds: jnp.array, negative_prompt_embeds: jnp.array, + run_wan2_2: bool, guidance_scale: float, + guidance_scale_low: float, + guidance_scale_high: float, + boundary: int, num_inference_steps: int, scheduler: FlaxUniPCMultistepScheduler, - num_transformer_layers: int, scheduler_state, ): do_classifier_free_guidance = guidance_scale > 1.0 + if run_wan2_2: + do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 if do_classifier_free_guidance: prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) for step in range(num_inference_steps): @@ -681,14 +736,21 @@ def run_inference( timestep = jnp.broadcast_to(t, latents.shape[0]) noise_pred, latents = transformer_forward_pass( - graphdef, - sharded_state, - rest_of_state, - latents, - timestep, + low_noise_graphdef, + low_noise_state, + low_noise_rest, + high_noise_graphdef, + high_noise_state, + high_noise_rest, + latents, timestep, prompt_embeds, - do_classifier_free_guidance=do_classifier_free_guidance, - guidance_scale=guidance_scale, + run_wan2_2, + guidance_scale, + guidance_scale_low, + guidance_scale_high, + boundary, + do_classifier_free_guidance, + t ) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() From c094a73bfd0a88f8c1add6fe9b85118fadd79950 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Sat, 25 Oct 2025 11:03:04 +0530 Subject: [PATCH 02/10] changes return type of checkpoint_loader to tuple --- src/maxdiffusion/generate_wan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 46fca5e8..c4b53b8d 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -115,7 +115,7 @@ def run(config, pipeline=None, filename_prefix=""): from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer checkpoint_loader = WanCheckpointer(config, "WAN_CHECKPOINT") - pipeline = checkpoint_loader.load_checkpoint() + pipeline, opt_state, step = checkpoint_loader.load_checkpoint() if pipeline is None: pipeline = WanPipeline.from_pretrained(config) s0 = time.perf_counter() From 33bf49c328e9d92a7e50180b1435191a13c2b624 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Sat, 25 Oct 2025 18:42:20 +0530 Subject: [PATCH 03/10] opt_state=None added --- src/maxdiffusion/checkpointing/wan_checkpointer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py index 1f8db8f7..4295e633 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py @@ -34,6 +34,7 @@ class WanCheckpointer(ABC): def __init__(self, config, checkpoint_type): self.config = config self.checkpoint_type = checkpoint_type + self.opt_state = None self.run_wan2_2 = config.run_wan2_2 if 'run_wan2_2' in self.config.__dict__ else False self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager( From 8a752e78d711e5b871baeddc2c2700696834b707 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Sat, 25 Oct 2025 21:10:39 +0530 Subject: [PATCH 04/10] added model_name in config file --- src/maxdiffusion/configs/base_wan_14b.yml | 1 + src/maxdiffusion/configs/base_wan_27b.yml | 2 +- src/maxdiffusion/pipelines/wan/wan_pipeline.py | 6 +++--- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 50e66964..8dea4e3a 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -28,6 +28,7 @@ save_config_to_gcs: False log_period: 100 pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers' +model_name: wan2.1 # Overrides the transformer from pretrained_model_name_or_path wan_transformer_pretrained_model_name_or_path: '' diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index 81fc5914..6d005bdd 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -28,7 +28,7 @@ save_config_to_gcs: False log_period: 100 pretrained_model_name_or_path: 'Wan-AI/Wan2.2-T2V-A14B-Diffusers' -run_wan2_2: True +model_name: wan2.2 # Overrides the transformer from pretrained_model_name_or_path wan_transformer_pretrained_model_name_or_path: '' diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 3f343845..69833842 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -213,7 +213,7 @@ def __init__( self.devices_array = devices_array self.mesh = mesh self.config = config - self.run_wan2_2 = config.run_wan2_2 if 'run_wan2_2' in self.config.__dict__ else False + self.run_wan2_2 = config.model_name == "wan2.2" self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 @@ -379,7 +379,7 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_ mesh = Mesh(devices_array, config.mesh_axes) rng = jax.random.key(config.seed) rngs = nnx.Rngs(rng) - run_wan2_2 = config.run_wan2_2 if 'run_wan2_2' in config.__dict__ else False + run_wan2_2 = config.model_name == "wan2.2" low_noise_transformer = None high_noise_transformer = None tokenizer = None @@ -421,7 +421,7 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform mesh = Mesh(devices_array, config.mesh_axes) rng = jax.random.key(config.seed) rngs = nnx.Rngs(rng) - run_wan2_2 = config.run_wan2_2 if 'run_wan2_2' in config.__dict__ else False + run_wan2_2 = config.model_name == "wan2.2" low_noise_transformer = None high_noise_transformer = None tokenizer = None From 1be0361b80702bc7cdba4dc022b489ca94bf4595 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Mon, 27 Oct 2025 15:04:47 +0530 Subject: [PATCH 05/10] double noise computation fixed --- .../pipelines/wan/wan_pipeline.py | 118 ++++++++---------- 1 file changed, 51 insertions(+), 67 deletions(-) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 69833842..7e62a236 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -213,7 +213,7 @@ def __init__( self.devices_array = devices_array self.mesh = mesh self.config = config - self.run_wan2_2 = config.model_name == "wan2.2" + self.model_name = config.model_name self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 @@ -379,7 +379,7 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_ mesh = Mesh(devices_array, config.mesh_axes) rng = jax.random.key(config.seed) rngs = nnx.Rngs(rng) - run_wan2_2 = config.model_name == "wan2.2" + model_name = config.model_name low_noise_transformer = None high_noise_transformer = None tokenizer = None @@ -390,7 +390,7 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_ if load_transformer: with mesh: low_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder="transformer") - if run_wan2_2: + if model_name == "wan2.2": high_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder="transformer_2") text_encoder = cls.load_text_encoder(config=config) @@ -421,7 +421,7 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform mesh = Mesh(devices_array, config.mesh_axes) rng = jax.random.key(config.seed) rngs = nnx.Rngs(rng) - run_wan2_2 = config.model_name == "wan2.2" + model_name = config.model_name low_noise_transformer = None high_noise_transformer = None tokenizer = None @@ -432,7 +432,7 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform if load_transformer: with mesh: low_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer") - if run_wan2_2: + if model_name == "wan2.2": high_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer_2") text_encoder = cls.load_text_encoder(config=config) tokenizer = cls.load_tokenizer(config=config) @@ -457,7 +457,7 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform ) pipeline.low_noise_transformer = cls.quantize_transformer(config, pipeline.low_noise_transformer, pipeline, mesh) - if run_wan2_2: + if model_name == "wan2.2": pipeline.high_noise_transformer = cls.quantize_transformer(config, pipeline.high_noise_transformer, pipeline, mesh) return pipeline @@ -617,12 +617,12 @@ def __call__( low_noise_graphdef, low_noise_state, low_noise_rest = nnx.split(self.low_noise_transformer, nnx.Param, ...) high_noise_graphdef, high_noise_state, high_noise_rest = None, None, None - if self.run_wan2_2: + if self.model_name == "wan2.2": high_noise_graphdef, high_noise_state, high_noise_rest = nnx.split(self.high_noise_transformer, nnx.Param, ...) p_run_inference = partial( run_inference, - run_wan2_2=self.run_wan2_2, + model_name=self.model_name, guidance_scale=guidance_scale, guidance_scale_low=guidance_scale_low, guidance_scale_high=guidance_scale_high, @@ -659,51 +659,27 @@ def __call__( return video -@partial(jax.jit, static_argnames=("run_wan2_2", "guidance_scale", "guidance_scale_low", "guidance_scale_high", "boundary", "do_classifier_free_guidance")) +@partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale")) def transformer_forward_pass( - low_noise_graphdef, - low_noise_state, - low_noise_rest, - high_noise_graphdef, - high_noise_state, - high_noise_rest, - latents, timestep, + graphdef, + sharded_state, + rest_of_state, + latents, + timestep, prompt_embeds, - run_wan2_2: bool, - guidance_scale: float, - guidance_scale_low: float, - guidance_scale_high: float, - boundary: int, - do_classifier_free_guidance: bool, - t: jnp.array, + do_classifier_free_guidance, + guidance_scale, ): - low_noise_transformer = nnx.merge(low_noise_graphdef, low_noise_state, low_noise_rest) - noise_pred_low = low_noise_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds) - noise_pred = noise_pred_low - current_guide_scale = guidance_scale - if run_wan2_2: - high_noise_transformer = nnx.merge(high_noise_graphdef, high_noise_state, high_noise_rest) - noise_pred_high = high_noise_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds) - use_high_noise = jnp.greater_equal(t, boundary) - noise_pred = jax.lax.cond( - use_high_noise, - lambda: noise_pred_high, - lambda: noise_pred_low, - ) - current_guide_scale = jax.lax.cond( - use_high_noise, - lambda: guidance_scale_high, - lambda: guidance_scale_low, - ) - - if do_classifier_free_guidance: - bsz = latents.shape[0] // 2 - noise_uncond = noise_pred[bsz:] - noise_pred = noise_pred[:bsz] - noise_pred = noise_uncond + current_guide_scale * (noise_pred - noise_uncond) - latents = latents[:bsz] + wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) + noise_pred = wan_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds) + if do_classifier_free_guidance: + bsz = latents.shape[0] // 2 + noise_uncond = noise_pred[bsz:] + noise_pred = noise_pred[:bsz] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + latents = latents[:bsz] - return noise_pred, latents + return noise_pred, latents def run_inference( low_noise_graphdef, @@ -715,7 +691,7 @@ def run_inference( latents: jnp.array, prompt_embeds: jnp.array, negative_prompt_embeds: jnp.array, - run_wan2_2: bool, + model_name: str, guidance_scale: float, guidance_scale_low: float, guidance_scale_high: float, @@ -725,32 +701,40 @@ def run_inference( scheduler_state, ): do_classifier_free_guidance = guidance_scale > 1.0 - if run_wan2_2: + if model_name == "wan2.2": do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 if do_classifier_free_guidance: prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) + + def low_noise_branch(operands): + latents, timestep, prompt_embeds = operands + return transformer_forward_pass( + low_noise_graphdef, low_noise_state, low_noise_rest, + latents, timestep, prompt_embeds, + do_classifier_free_guidance, guidance_scale_low + ) + + def high_noise_branch(operands): + latents, timestep, prompt_embeds = operands + return transformer_forward_pass( + high_noise_graphdef, high_noise_state, high_noise_rest, + latents, timestep, prompt_embeds, + do_classifier_free_guidance, guidance_scale_high + ) + for step in range(num_inference_steps): t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] if do_classifier_free_guidance: latents = jnp.concatenate([latents] * 2) timestep = jnp.broadcast_to(t, latents.shape[0]) - noise_pred, latents = transformer_forward_pass( - low_noise_graphdef, - low_noise_state, - low_noise_rest, - high_noise_graphdef, - high_noise_state, - high_noise_rest, - latents, timestep, - prompt_embeds, - run_wan2_2, - guidance_scale, - guidance_scale_low, - guidance_scale_high, - boundary, - do_classifier_free_guidance, - t + use_high_noise = jnp.greater_equal(t, boundary) + + noise_pred, latents = jax.lax.cond( + use_high_noise, + high_noise_branch, + low_noise_branch, + (latents, timestep, prompt_embeds) ) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() From 731b07bb9e21e8ef5a949d2d1462ea1c6da43bab Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Mon, 27 Oct 2025 15:14:10 +0530 Subject: [PATCH 06/10] support for wan2.1 in run_inference added --- src/maxdiffusion/pipelines/wan/wan_pipeline.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 7e62a236..e6d3df1d 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -728,6 +728,11 @@ def high_noise_branch(operands): latents = jnp.concatenate([latents] * 2) timestep = jnp.broadcast_to(t, latents.shape[0]) + if model_name == "wan2.1": + noise_pred, latents = low_noise_branch((latents, timestep, prompt_embeds)) + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + continue + use_high_noise = jnp.greater_equal(t, boundary) noise_pred, latents = jax.lax.cond( From 11d30fce2ba51a2d938e7848e601ec513ae63904 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Wed, 12 Nov 2025 00:08:27 +0530 Subject: [PATCH 07/10] Support for WAN 2.2 added --- README.md | 19 +- .../checkpointing/checkpointing_utils.py | 2 +- .../checkpointing/wan_checkpointer.py | 209 +++++-- src/maxdiffusion/configs/base_wan_27b.yml | 6 +- src/maxdiffusion/generate_wan.py | 123 ++--- .../pipelines/wan/wan_pipeline.py | 508 ++++++++++++------ .../tests/wan_checkpointer_test.py | 224 +++++++- 7 files changed, 766 insertions(+), 325 deletions(-) diff --git a/README.md b/README.md index 7d26dd5d..2f33c680 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ [![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml) # What's new? +- **`2025/11/11`**: Wan2.2 txt2vid generation is now supported - **`2025/10/10`**: Wan2.1 txt2vid training and generation is now supported. - **`2025/10/14`**: NVIDIA DGX Spark Flux support. - **`2025/8/14`**: LTX-Video img2vid generation is now supported. @@ -481,7 +482,23 @@ To generate images, run the following command: ```bash HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ - LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_14b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 + LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_14b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 + ``` + ## Wan2.2 + + Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage). + + ```bash + HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ + LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_27b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 + ``` + ## Wan2.2 + + Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage). + + ```bash + HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ + LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_27b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 ``` ## Flux diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py index 24c7b2ff..bbad3ad1 100644 --- a/src/maxdiffusion/checkpointing/checkpointing_utils.py +++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py @@ -61,7 +61,7 @@ def create_orbax_checkpoint_manager( if checkpoint_type == FLUX_CHECKPOINT: item_names = ("flux_state", "flux_config", "vae_state", "vae_config", "scheduler", "scheduler_config") elif checkpoint_type == WAN_CHECKPOINT: - item_names = ("wan_state", "wan_config") + item_names = ("low_noise_transformer_state", "high_noise_transformer_state", "wan_state", "wan_config") else: item_names = ( "unet_config", diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py index 4295e633..74710f4f 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py @@ -14,35 +14,50 @@ limitations under the License. """ -from abc import ABC +from abc import ABC, abstractmethod import json import jax import numpy as np -from typing import Optional, Tuple +from typing import Optional, Tuple, Type from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager) -from ..pipelines.wan.wan_pipeline import WanPipeline +from ..pipelines.wan.wan_pipeline import WanPipeline2_1, WanPipeline2_2 from .. import max_logging, max_utils import orbax.checkpoint as ocp from etils import epath + WAN_CHECKPOINT = "WAN_CHECKPOINT" class WanCheckpointer(ABC): + _SUBCLASS_MAP: dict[str, Type['WanCheckpointer']] = {} + + def __new__(cls, model_key: str, config, checkpoint_type: str = WAN_CHECKPOINT): + if cls is WanCheckpointer: + subclass = cls._SUBCLASS_MAP.get(model_key) + if subclass is None: + raise ValueError( + f"Unknown model_key: '{model_key}'. " + f"Supported keys are: {list(cls._SUBCLASS_MAP.keys())}" + ) + return super().__new__(subclass) + else: + return super().__new__(cls) - def __init__(self, config, checkpoint_type): + def __init__(self, model_key, config, checkpoint_type: str = WAN_CHECKPOINT): self.config = config self.checkpoint_type = checkpoint_type self.opt_state = None - self.run_wan2_2 = config.run_wan2_2 if 'run_wan2_2' in self.config.__dict__ else False - - self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager( - self.config.checkpoint_dir, - enable_checkpointing=True, - save_interval_steps=1, - checkpoint_type=checkpoint_type, - dataset_type=config.dataset_type, + + self.checkpoint_manager: ocp.CheckpointManager = ( + create_orbax_checkpoint_manager( + self.config.checkpoint_dir, + enable_checkpointing=True, + save_interval_steps=1, + checkpoint_type=checkpoint_type, + dataset_type=config.dataset_type, + ) ) def _create_optimizer(self, model, config, learning_rate): @@ -52,6 +67,25 @@ def _create_optimizer(self, model, config, learning_rate): tx = max_utils.create_optimizer(config, learning_rate_scheduler) return tx, learning_rate_scheduler + @abstractmethod + def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: + raise NotImplementedError + + @abstractmethod + def load_diffusers_checkpoint(self): + raise NotImplementedError + + @abstractmethod + def load_checkpoint(self, step=None) -> Tuple[Optional[WanPipeline2_1 | WanPipeline2_2], Optional[dict], Optional[int]]: + raise NotImplementedError + + @abstractmethod + def save_checkpoint(self, train_step, pipeline, train_states: dict): + raise NotImplementedError + + +class WanCheckpointer2_1(WanCheckpointer): + def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: if step is None: step = self.checkpoint_manager.latest_step() @@ -61,36 +95,23 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic return None, None max_logging.log(f"Loading WAN checkpoint from step {step}") metadatas = self.checkpoint_manager.item_metadata(step) - - restore_args = {} - - low_state_metadata = metadatas.low_noise_transformer_state - abstract_tree_structure_low_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_state_metadata) - low_state_restore = ocp.args.PyTreeRestore( + transformer_metadata = metadatas.wan_state + abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata) + params_restore = ocp.args.PyTreeRestore( restore_args=jax.tree.map( lambda _: ocp.RestoreArgs(restore_type=np.ndarray), - abstract_tree_structure_low_state, + abstract_tree_structure_params, ) ) - restore_args["low_noise_transformer_state"] = low_state_restore - - if self.run_wan2_2: - high_state_metadata = metadatas.high_noise_transformer_state - abstract_tree_structure_high_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_state_metadata) - high_state_restore = ocp.args.PyTreeRestore( - restore_args=jax.tree.map( - lambda _: ocp.RestoreArgs(restore_type=np.ndarray), - abstract_tree_structure_high_state, - ) - ) - restore_args["high_noise_transformer_state"] = high_state_restore - - restore_args["wan_config"] = ocp.args.JsonRestore() max_logging.log("Restoring WAN checkpoint") restored_checkpoint = self.checkpoint_manager.restore( + directory=epath.Path(self.config.checkpoint_dir), step=step, - args=ocp.args.Composite(**restore_args), + args=ocp.args.Composite( + wan_state=params_restore, + wan_config=ocp.args.JsonRestore(), + ), ) max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") max_logging.log(f"restored checkpoint wan_state {restored_checkpoint.wan_state.keys()}") @@ -99,24 +120,113 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic return restored_checkpoint, step def load_diffusers_checkpoint(self): - pipeline = WanPipeline.from_pretrained(self.config) + pipeline = WanPipeline2_1.from_pretrained(self.config) + return pipeline + + def load_checkpoint(self, step=None) -> Tuple[WanPipeline2_1, Optional[dict], Optional[int]]: + restored_checkpoint, step = self.load_wan_configs_from_orbax(step) + opt_state = None + if restored_checkpoint: + max_logging.log("Loading WAN pipeline from checkpoint") + pipeline = WanPipeline2_1.from_checkpoint(self.config, restored_checkpoint) + if "opt_state" in restored_checkpoint.wan_state.keys(): + opt_state = restored_checkpoint.wan_state["opt_state"] + else: + max_logging.log("No checkpoint found, loading default pipeline.") + pipeline = self.load_diffusers_checkpoint() + + return pipeline, opt_state, step + + def save_checkpoint(self, train_step, pipeline: WanPipeline2_1, train_states: dict): + """Saves the training state and model configurations.""" + + def config_to_json(model_or_config): + return json.loads(model_or_config.to_json_string()) + + max_logging.log(f"Saving checkpoint for step {train_step}") + items = { + "wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)), + } + + items["wan_state"] = ocp.args.PyTreeSave(train_states) + + # Save the checkpoint + self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) + max_logging.log(f"Checkpoint for step {train_step} saved.") + + +class WanCheckpointer2_2(WanCheckpointer): + + def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: + if step is None: + step = self.checkpoint_manager.latest_step() + max_logging.log(f"Latest WAN checkpoint step: {step}") + if step is None: + max_logging.log("No WAN checkpoint found.") + return None, None + max_logging.log(f"Loading WAN checkpoint from step {step}") + metadatas = self.checkpoint_manager.item_metadata(step) + + # Handle low_noise_transformer + low_noise_transformer_metadata = metadatas.low_noise_transformer_state + abstract_tree_structure_low_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata) + low_params_restore = ocp.args.PyTreeRestore( + restore_args=jax.tree.map( + lambda _: ocp.RestoreArgs(restore_type=np.ndarray), + abstract_tree_structure_low_params, + ) + ) + + # Handle high_noise_transformer + high_noise_transformer_metadata = metadatas.high_noise_transformer_state + abstract_tree_structure_high_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata) + high_params_restore = ocp.args.PyTreeRestore( + restore_args=jax.tree.map( + lambda _: ocp.RestoreArgs(restore_type=np.ndarray), + abstract_tree_structure_high_params, + ) + ) + + max_logging.log("Restoring WAN 2.2 checkpoint") + restored_checkpoint = self.checkpoint_manager.restore( + directory=epath.Path(self.config.checkpoint_dir), + step=step, + args=ocp.args.Composite( + low_noise_transformer_state=low_params_restore, + high_noise_transformer_state=high_params_restore, + wan_config=ocp.args.JsonRestore(), + ), + ) + max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") + max_logging.log(f"restored checkpoint low_noise_transformer_state {restored_checkpoint.low_noise_transformer_state.keys()}") + max_logging.log(f"restored checkpoint high_noise_transformer_state {restored_checkpoint.high_noise_transformer_state.keys()}") + max_logging.log(f"optimizer found in low_noise checkpoint {'opt_state' in restored_checkpoint.low_noise_transformer_state.keys()}") + max_logging.log(f"optimizer found in high_noise checkpoint {'opt_state' in restored_checkpoint.high_noise_transformer_state.keys()}") + max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}") + return restored_checkpoint, step + + def load_diffusers_checkpoint(self): + pipeline = WanPipeline2_2.from_pretrained(self.config) return pipeline - def load_checkpoint(self, step=None) -> Tuple[WanPipeline, Optional[dict], Optional[int]]: + def load_checkpoint(self, step=None) -> Tuple[WanPipeline2_2, Optional[dict], Optional[int]]: restored_checkpoint, step = self.load_wan_configs_from_orbax(step) opt_state = None if restored_checkpoint: max_logging.log("Loading WAN pipeline from checkpoint") - pipeline = WanPipeline.from_checkpoint(self.config, restored_checkpoint) - if "opt_state" in restored_checkpoint["wan_state"].keys(): - opt_state = restored_checkpoint["wan_state"]["opt_state"] + pipeline = WanPipeline2_2.from_checkpoint(self.config, restored_checkpoint) + # Check for optimizer state in either transformer + if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys(): + opt_state = restored_checkpoint.low_noise_transformer_state["opt_state"] + elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys(): + opt_state = restored_checkpoint.high_noise_transformer_state["opt_state"] else: max_logging.log("No checkpoint found, loading default pipeline.") pipeline = self.load_diffusers_checkpoint() return pipeline, opt_state, step - def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict): + def save_checkpoint(self, train_step, pipeline: WanPipeline2_2, train_states: dict): """Saves the training state and model configurations.""" def config_to_json(model_or_config): @@ -127,22 +237,17 @@ def config_to_json(model_or_config): "wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)), } - if "low_noise_transformer" in train_states: - low_noise_state = train_states["low_noise_transformer"] - items["low_noise_transformer_state"] = ocp.args.PyTreeSave(low_noise_state) + items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"]) + items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"]) - if self.run_wan2_2: - if "high_noise_transformer" in train_states: - high_noise_state = train_states["high_noise_transformer"] - items["high_noise_transformer_state"] = ocp.args.PyTreeSave(high_noise_state) - # Save the checkpoint - if len(items) > 1: - self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) - max_logging.log(f"Checkpoint for step {train_step} saved.") + self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) + max_logging.log(f"Checkpoint for step {train_step} saved.") +WanCheckpointer._SUBCLASS_MAP["wan2.1"] = WanCheckpointer2_1 +WanCheckpointer._SUBCLASS_MAP["wan2.2"] = WanCheckpointer2_2 -def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict): +def save_checkpoint_orig(self, train_step, pipeline, train_states: dict): """Saves the training state and model configurations.""" def config_to_json(model_or_config): diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index 6d005bdd..323a1a51 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -272,9 +272,9 @@ width: 832 num_frames: 81 flow_shift: 3.0 -guidance_scale_low: 5.0 -guidance_scale_high: 8.0 -boundary_timestep: 15 +guidance_scale_low: 3.0 +guidance_scale_high: 4.0 +boundary_timestep: 875 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index c4b53b8d..53a38ac4 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -17,25 +17,13 @@ import time import os from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline +from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer from maxdiffusion import pyconfig, max_logging, max_utils from absl import app -from absl import flags from maxdiffusion.utils import export_to_video from google.cloud import storage import flax -_MODEL_NAME = flags.DEFINE_enum( - "model_name", - default="wan2.1", - enum_values=["wan2.1", "wan2.2"], - help="The model version to run (wan2.1 or wan2.2). This determines the base config file.", -) - -CONFIG_BASE_DIR = "src/maxdiffusion/configs" -MODEL_CONFIG_MAP = { - "wan2.1": "base_wan_14b.yml", - "wan2.2": "base_wan_27b.yml", -} def upload_video_to_gcs(output_dir: str, video_path: str): """ @@ -76,6 +64,33 @@ def delete_file(file_path: str): jax.config.update("jax_use_shardy_partitioner", True) +def call_pipeline(config, pipeline, prompt, negative_prompt): + model_key = config.model_name + if model_key == "wan2.1": + return pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + height=config.height, + width=config.width, + num_frames=config.num_frames, + num_inference_steps=config.num_inference_steps, + guidance_scale=config.guidance_scale, + ) + elif model_key == "wan2.2": + return pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + height=config.height, + width=config.width, + num_frames=config.num_frames, + num_inference_steps=config.num_inference_steps, + guidance_scale_low=config.guidance_scale_low, + guidance_scale_high=config.guidance_scale_high, + boundary=config.boundary_timestep, + ) + else: + raise ValueError(f"Unsupported model_name in config: {model_key}") + def inference_generate_video(config, pipeline, filename_prefix=""): s0 = time.perf_counter() @@ -86,18 +101,7 @@ def inference_generate_video(config, pipeline, filename_prefix=""): f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}, video: {filename_prefix}" ) - videos = pipeline( - prompt=prompt, - negative_prompt=negative_prompt, - height=config.height, - width=config.width, - num_frames=config.num_frames, - num_inference_steps=config.num_inference_steps, - guidance_scale=guidance_scale, - guidance_scale_low=guidance_scale_low, - guidance_scale_high=guidance_scale_high, - boundary=boundary, - ) + videos = call_pipeline(config, pipeline, prompt, negative_prompt) max_logging.log(f"video {filename_prefix}, compile time: {(time.perf_counter() - s0)}") for i in range(len(videos)): @@ -112,38 +116,20 @@ def inference_generate_video(config, pipeline, filename_prefix=""): def run(config, pipeline=None, filename_prefix=""): print("seed: ", config.seed) - from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer - - checkpoint_loader = WanCheckpointer(config, "WAN_CHECKPOINT") - pipeline, opt_state, step = checkpoint_loader.load_checkpoint() - if pipeline is None: - pipeline = WanPipeline.from_pretrained(config) + model_key = config.model_name + checkpoint_loader = WanCheckpointer(model_key=model_key, config=config) + pipeline, _, _ = checkpoint_loader.load_checkpoint() + pipeline = WanPipeline.from_pretrained(model_key=model_key, config=config) s0 = time.perf_counter() # Using global_batch_size_to_train_on so not to create more config variables prompt = [config.prompt] * config.global_batch_size_to_train_on negative_prompt = [config.negative_prompt] * config.global_batch_size_to_train_on - guidance_scale = config.guidance_scale if 'guidance_scale' in config.__dict__ else 5 - guidance_scale_low = config.guidance_scale_low if 'guidance_scale_low' in config.__dict__ else 3 - guidance_scale_high = config.guidance_scale_high if 'guidance_scale_high' in config.__dict__ else 4 - boundary = config.boundary_timestep if 'boundary_timestep' in config.__dict__ else 875 max_logging.log( f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}" ) - - videos = pipeline( - prompt=prompt, - negative_prompt=negative_prompt, - height=config.height, - width=config.width, - num_frames=config.num_frames, - num_inference_steps=config.num_inference_steps, - guidance_scale=guidance_scale, - guidance_scale_low=guidance_scale_low, - guidance_scale_high=guidance_scale_high, - boundary=boundary, - ) + videos = call_pipeline(config, pipeline, prompt, negative_prompt) print("compile time: ", (time.perf_counter() - s0)) saved_video_path = [] @@ -155,55 +141,20 @@ def run(config, pipeline=None, filename_prefix=""): upload_video_to_gcs(os.path.join(config.output_dir, config.run_name), video_path) s0 = time.perf_counter() - videos = pipeline( - prompt=prompt, - negative_prompt=negative_prompt, - height=config.height, - width=config.width, - num_frames=config.num_frames, - num_inference_steps=config.num_inference_steps, - guidance_scale=guidance_scale, - guidance_scale_low=guidance_scale_low, - guidance_scale_high=guidance_scale_high, - boundary=boundary, - ) + videos = call_pipeline(config, pipeline, prompt, negative_prompt) print("generation time: ", (time.perf_counter() - s0)) s0 = time.perf_counter() if config.enable_profiler: max_utils.activate_profiler(config) - videos = pipeline( - prompt=prompt, - negative_prompt=negative_prompt, - height=config.height, - width=config.width, - num_frames=config.num_frames, - num_inference_steps=config.num_inference_steps, - guidance_scale=guidance_scale, - guidance_scale_low=guidance_scale_low, - guidance_scale_high=guidance_scale_high, - boundary=boundary, - ) + videos = call_pipeline(config, pipeline, prompt, negative_prompt) max_utils.deactivate_profiler(config) print("generation time: ", (time.perf_counter() - s0)) return saved_video_path def main(argv: Sequence[str]) -> None: - # Get the model name from the flag - model_key = _MODEL_NAME.value - config_filename = MODEL_CONFIG_MAP[model_key] - selected_yaml_path = os.path.join(CONFIG_BASE_DIR, config_filename) - - max_logging.log(f"Using model: {model_key}, loading base config: {selected_yaml_path}") - - # Construct argv for pyconfig.initialize - # argv[0] is the program name. - # Insert the selected YAML path at index 1. - # The rest of argv (argv[1:]) are the overrides. - argv_for_pyconfig = list(argv[:1]) + [selected_yaml_path] + list(argv[1:]) - - pyconfig.initialize(argv_for_pyconfig) + pyconfig.initialize(argv) flax.config.update("flax_always_shard_variable", False) run(pyconfig.config) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index e6d3df1d..e2bfd7e8 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union, Optional +from abc import abstractmethod +from typing import List, Union, Optional, Type from functools import partial import numpy as np import jax @@ -187,13 +188,11 @@ class WanPipeline: vae ([`AutoencoderKLWan`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. """ - + _SUBCLASS_MAP: dict[str, Type['WanPipeline']] = {} def __init__( self, tokenizer: AutoTokenizer, text_encoder: UMT5EncoderModel, - low_noise_transformer: WanModel, - high_noise_transformer: Optional[WanModel], vae: AutoencoderKLWan, vae_cache: AutoencoderKLWanCache, scheduler: FlaxUniPCMultistepScheduler, @@ -204,8 +203,6 @@ def __init__( ): self.tokenizer = tokenizer self.text_encoder = text_encoder - self.low_noise_transformer = low_noise_transformer - self.high_noise_transformer = high_noise_transformer self.vae = vae self.vae_cache = vae_cache self.scheduler = scheduler @@ -373,93 +370,6 @@ def load_scheduler(cls, config): ) return scheduler, scheduler_state - @classmethod - def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): - devices_array = max_utils.create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - rng = jax.random.key(config.seed) - rngs = nnx.Rngs(rng) - model_name = config.model_name - low_noise_transformer = None - high_noise_transformer = None - tokenizer = None - scheduler = None - scheduler_state = None - text_encoder = None - if not vae_only: - if load_transformer: - with mesh: - low_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder="transformer") - if model_name == "wan2.2": - high_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder="transformer_2") - - text_encoder = cls.load_text_encoder(config=config) - tokenizer = cls.load_tokenizer(config=config) - - scheduler, scheduler_state = cls.load_scheduler(config=config) - - with mesh: - wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) - - return WanPipeline( - tokenizer=tokenizer, - text_encoder=text_encoder, - low_noise_transformer=low_noise_transformer, - high_noise_transformer=high_noise_transformer, - vae=wan_vae, - vae_cache=vae_cache, - scheduler=scheduler, - scheduler_state=scheduler_state, - devices_array=devices_array, - mesh=mesh, - config=config, - ) - - @classmethod - def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): - devices_array = max_utils.create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - rng = jax.random.key(config.seed) - rngs = nnx.Rngs(rng) - model_name = config.model_name - low_noise_transformer = None - high_noise_transformer = None - tokenizer = None - scheduler = None - scheduler_state = None - text_encoder = None - if not vae_only: - if load_transformer: - with mesh: - low_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer") - if model_name == "wan2.2": - high_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer_2") - text_encoder = cls.load_text_encoder(config=config) - tokenizer = cls.load_tokenizer(config=config) - - scheduler, scheduler_state = cls.load_scheduler(config=config) - - with mesh: - wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) - - pipeline = WanPipeline( - tokenizer=tokenizer, - text_encoder=text_encoder, - low_noise_transformer=low_noise_transformer, - high_noise_transformer=high_noise_transformer, - vae=wan_vae, - vae_cache=vae_cache, - scheduler=scheduler, - scheduler_state=scheduler_state, - devices_array=devices_array, - mesh=mesh, - config=config, - ) - - pipeline.low_noise_transformer = cls.quantize_transformer(config, pipeline.low_noise_transformer, pipeline, mesh) - if model_name == "wan2.2": - pipeline.high_noise_transformer = cls.quantize_transformer(config, pipeline.high_noise_transformer, pipeline, mesh) - return pipeline def _get_t5_prompt_embeds( self, @@ -549,25 +459,86 @@ def prepare_latents( return latents - def __call__( - self, - prompt: Union[str, List[str]] = None, - negative_prompt: Union[str, List[str]] = None, - height: int = 480, - width: int = 832, - num_frames: int = 81, - num_inference_steps: int = 50, - guidance_scale: float = 5.0, - guidance_scale_low: float = 3.0, - guidance_scale_high: float = 4.0, - boundary: int = 875, - num_videos_per_prompt: Optional[int] = 1, - max_sequence_length: int = 512, - latents: jax.Array = None, - prompt_embeds: jax.Array = None, - negative_prompt_embeds: jax.Array = None, - vae_only: bool = False, - ): + def _denormalize_latents(self, latents: jax.Array) -> jax.Array: + """Denormalizes latents using VAE statistics.""" + latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1) + latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1) + latents = latents / latents_std + latents_mean + latents = latents.astype(jnp.float32) + return latents + + def _decode_latents_to_video(self, latents: jax.Array) -> np.ndarray: + """Decodes latents to video frames and postprocesses.""" + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + video = self.vae.decode(latents, self.vae_cache)[0] + + video = jnp.transpose(video, (0, 4, 1, 2, 3)) + video = jax.experimental.multihost_utils.process_allgather(video, tiled=True) + video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16) + return self.video_processor.postprocess_video(video, output_type="np") + + @classmethod + def _create_common_components(cls, config, vae_only=False): + devices_array = max_utils.create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + rng = jax.random.key(config.seed) + rngs = nnx.Rngs(rng) + + with mesh: + wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) + + components = { + "vae": wan_vae, "vae_cache": vae_cache, + "devices_array": devices_array, "rngs": rngs, "mesh": mesh, + "tokenizer": None, "text_encoder": None, "scheduler": None, "scheduler_state": None + } + + if not vae_only: + components["tokenizer"] = cls.load_tokenizer(config=config) + components["text_encoder"] = cls.load_text_encoder(config=config) + components["scheduler"], components["scheduler_state"] = cls.load_scheduler(config=config) + return components + + @classmethod + def _get_subclass(cls, model_key: str) -> Type['WanPipeline']: + subclass = cls._SUBCLASS_MAP.get(model_key) + if subclass is None: + raise ValueError( + f"Unknown model_key for WanPipeline: '{model_key}'. " + f"Supported keys are: {list(cls._SUBCLASS_MAP.keys())}" + ) + return subclass + + @classmethod + def from_checkpoint(cls, model_key: str, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): + subclass = cls._get_subclass(model_key) + return subclass.from_checkpoint(config, restored_checkpoint=restored_checkpoint, vae_only=vae_only, load_transformer=load_transformer) + + @classmethod + def from_pretrained(cls, model_key: str, config: HyperParameters, vae_only=False, load_transformer=True): + subclass = cls._get_subclass(model_key) + return subclass.from_pretrained(config, vae_only=vae_only, load_transformer=load_transformer) + + @abstractmethod + def _get_num_channel_latents(self) -> int: + """Returns the number of input channels for the transformer.""" + pass + + def _prepare_call_inputs( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + num_videos_per_prompt: Optional[int] = 1, + max_sequence_length: int = 512, + latents: jax.Array = None, + prompt_embeds: jax.Array = None, + negative_prompt_embeds: jax.Array = None, + vae_only: bool = False, + ): if not vae_only: if num_frames % self.vae_scale_factor_temporal != 1: max_logging.log( @@ -590,7 +561,7 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, ) - num_channel_latents = self.low_noise_transformer.config.in_channels + num_channel_latents = self._get_num_channel_latents() if latents is None: latents = self.prepare_latents( batch_size=batch_size, @@ -615,49 +586,235 @@ def __call__( self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape ) - low_noise_graphdef, low_noise_state, low_noise_rest = nnx.split(self.low_noise_transformer, nnx.Param, ...) - high_noise_graphdef, high_noise_state, high_noise_rest = None, None, None - if self.model_name == "wan2.2": - high_noise_graphdef, high_noise_state, high_noise_rest = nnx.split(self.high_noise_transformer, nnx.Param, ...) - - p_run_inference = partial( - run_inference, - model_name=self.model_name, - guidance_scale=guidance_scale, - guidance_scale_low=guidance_scale_low, - guidance_scale_high=guidance_scale_high, - boundary=boundary, - num_inference_steps=num_inference_steps, - scheduler=self.scheduler, - scheduler_state=scheduler_state, - ) + return latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - latents = p_run_inference( - low_noise_graphdef=low_noise_graphdef, - low_noise_state=low_noise_state, - low_noise_rest=low_noise_rest, - high_noise_graphdef=high_noise_graphdef, - high_noise_state=high_noise_state, - high_noise_rest=high_noise_rest, - latents=latents, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, + @abstractmethod + def __call__(self, **kwargs): + """Runs the inference pipeline.""" + pass + +class WanPipeline2_1(WanPipeline): + """Pipeline for WAN 2.1 with a single transformer.""" + def __init__(self, config: HyperParameters, transformer: Optional[WanModel], **kwargs): + super().__init__(config=config, **kwargs) + self.transformer = transformer + + @classmethod + def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_transformer=True): + common_components = cls._create_common_components(config, vae_only) + transformer = None + if not vae_only: + if load_transformer: + transformer = super().load_transformer( + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + rngs=common_components["rngs"], + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer" ) - latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1) - latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1) - latents = latents / latents_std + latents_mean - latents = latents.astype(jnp.float32) + + pipeline = cls( + tokenizer=common_components["tokenizer"], + text_encoder=common_components["text_encoder"], + transformer=transformer, + vae=common_components["vae"], + vae_cache=common_components["vae_cache"], + scheduler=common_components["scheduler"], + scheduler_state=common_components["scheduler_state"], + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + config=config, + ) + + return pipeline, transformer + + @classmethod + def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): + pipeline , transformer = cls._load_and_init(config, None, vae_only, load_transformer) + transformer = cls.quantize_transformer(config, transformer, pipeline, pipeline.mesh) + return pipeline + + @classmethod + def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): + pipeline, _ = cls._load_and_init(config, restored_checkpoint, vae_only, load_transformer) + return pipeline + + def _get_num_channel_latents(self) -> int: + return self.transformer.config.in_channels + + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + max_sequence_length: int = 512, + latents: Optional[jax.Array] = None, + prompt_embeds: Optional[jax.Array] = None, + negative_prompt_embeds: Optional[jax.Array] = None, + vae_only: bool = False, + ): + latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_call_inputs( + prompt, + negative_prompt, + height, + width, + num_frames, + num_inference_steps, + num_videos_per_prompt, + max_sequence_length, + latents, + prompt_embeds, + negative_prompt_embeds, + vae_only, + ) + + graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...) + + p_run_inference = partial( + run_inference_2_1, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + scheduler=self.scheduler, + scheduler_state=scheduler_state, + ) with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - video = self.vae.decode(latents, self.vae_cache)[0] + latents = p_run_inference( + graphdef=graphdef, + sharded_state=state, + rest_of_state=rest_of_state, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + latents = self._denormalize_latents(latents) + return self._decode_latents_to_video(latents) + +class WanPipeline2_2(WanPipeline): + """Pipeline for WAN 2.2 with dual transformers.""" + def __init__(self, config: HyperParameters, low_noise_transformer: Optional[WanModel], high_noise_transformer: Optional[WanModel], **kwargs): + super().__init__(config=config, **kwargs) + self.low_noise_transformer = low_noise_transformer + self.high_noise_transformer = high_noise_transformer + + @classmethod + def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_transformer=True): + common_components = cls._create_common_components(config, vae_only) + low_noise_transformer, high_noise_transformer = None, None + if not vae_only and load_transformer: + low_noise_transformer = super().load_transformer( + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + rngs=common_components["rngs"], + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer" + ) + high_noise_transformer = super().load_transformer( + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + rngs=common_components["rngs"], + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer_2" + ) + + pipeline = cls( + tokenizer=common_components["tokenizer"], + text_encoder=common_components["text_encoder"], + low_noise_transformer=low_noise_transformer, + high_noise_transformer=high_noise_transformer, + vae=common_components["vae"], + vae_cache=common_components["vae_cache"], + scheduler=common_components["scheduler"], + scheduler_state=common_components["scheduler_state"], + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + config=config, + ) + return pipeline, low_noise_transformer, high_noise_transformer + + @classmethod + def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): + pipeline, low_noise_transformer, high_noise_transformer = cls._load_and_init(config, None, vae_only, load_transformer) + low_noise_transformer = cls.quantize_transformer(config, low_noise_transformer, pipeline, pipeline.mesh) + high_noise_transformer = cls.quantize_transformer(config, high_noise_transformer, pipeline, pipeline.mesh) + return pipeline + + @classmethod + def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): + pipeline, low_noise_transformer, high_noise_transformer = cls._load_and_init(config, restored_checkpoint, vae_only, load_transformer) + return pipeline - video = jnp.transpose(video, (0, 4, 1, 2, 3)) - video = jax.experimental.multihost_utils.process_allgather(video, tiled=True) - video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16) - video = self.video_processor.postprocess_video(video, output_type="np") - return video + def _get_num_channel_latents(self) -> int: + return self.low_noise_transformer.config.in_channels + + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale_low: float = 3.0, + guidance_scale_high: float = 4.0, + boundary: int = 875, + num_videos_per_prompt: Optional[int] = 1, + max_sequence_length: int = 512, + latents: jax.Array = None, + prompt_embeds: jax.Array = None, + negative_prompt_embeds: jax.Array = None, + vae_only: bool = False, + ): + latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_call_inputs( + prompt, + negative_prompt, + height, + width, + num_frames, + num_inference_steps, + num_videos_per_prompt, + max_sequence_length, + latents, + prompt_embeds, + negative_prompt_embeds, + vae_only, + ) + + low_noise_graphdef, low_noise_state, low_noise_rest = nnx.split(self.low_noise_transformer, nnx.Param, ...) + high_noise_graphdef, high_noise_state, high_noise_rest = nnx.split(self.high_noise_transformer, nnx.Param, ...) + + p_run_inference = partial( + run_inference_2_2, + guidance_scale_low=guidance_scale_low, + guidance_scale_high=guidance_scale_high, + boundary=boundary, + num_inference_steps=num_inference_steps, + scheduler=self.scheduler, + scheduler_state=scheduler_state, + ) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + latents = p_run_inference( + low_noise_graphdef=low_noise_graphdef, + low_noise_state=low_noise_state, + low_noise_rest=low_noise_rest, + high_noise_graphdef=high_noise_graphdef, + high_noise_state=high_noise_state, + high_noise_rest=high_noise_rest, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + latents = self._denormalize_latents(latents) + return self._decode_latents_to_video(latents) @partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale")) def transformer_forward_pass( @@ -681,7 +838,42 @@ def transformer_forward_pass( return noise_pred, latents -def run_inference( +def run_inference_2_1( + graphdef, + sharded_state, + rest_of_state, + latents: jnp.array, + prompt_embeds: jnp.array, + negative_prompt_embeds: jnp.array, + guidance_scale: float, + num_inference_steps: int, + scheduler: FlaxUniPCMultistepScheduler, + scheduler_state, +): + do_classifier_free_guidance = guidance_scale > 1.0 + if do_classifier_free_guidance: + prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) + for step in range(num_inference_steps): + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + if do_classifier_free_guidance: + latents = jnp.concatenate([latents] * 2) + timestep = jnp.broadcast_to(t, latents.shape[0]) + + noise_pred, latents = transformer_forward_pass( + graphdef, + sharded_state, + rest_of_state, + latents, + timestep, + prompt_embeds, + do_classifier_free_guidance=do_classifier_free_guidance, + guidance_scale=guidance_scale, + ) + + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents + +def run_inference_2_2( low_noise_graphdef, low_noise_state, low_noise_rest, @@ -691,8 +883,6 @@ def run_inference( latents: jnp.array, prompt_embeds: jnp.array, negative_prompt_embeds: jnp.array, - model_name: str, - guidance_scale: float, guidance_scale_low: float, guidance_scale_high: float, boundary: int, @@ -700,9 +890,7 @@ def run_inference( scheduler: FlaxUniPCMultistepScheduler, scheduler_state, ): - do_classifier_free_guidance = guidance_scale > 1.0 - if model_name == "wan2.2": - do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 + do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 if do_classifier_free_guidance: prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) @@ -728,11 +916,6 @@ def high_noise_branch(operands): latents = jnp.concatenate([latents] * 2) timestep = jnp.broadcast_to(t, latents.shape[0]) - if model_name == "wan2.1": - noise_pred, latents = low_noise_branch((latents, timestep, prompt_embeds)) - latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - continue - use_high_noise = jnp.greater_equal(t, boundary) noise_pred, latents = jax.lax.cond( @@ -744,3 +927,6 @@ def high_noise_branch(operands): latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() return latents + +WanPipeline._SUBCLASS_MAP["wan2.1"] = WanPipeline2_1 +WanPipeline._SUBCLASS_MAP["wan2.2"] = WanPipeline2_2 diff --git a/src/maxdiffusion/tests/wan_checkpointer_test.py b/src/maxdiffusion/tests/wan_checkpointer_test.py index ab5b5ca3..554c8824 100644 --- a/src/maxdiffusion/tests/wan_checkpointer_test.py +++ b/src/maxdiffusion/tests/wan_checkpointer_test.py @@ -14,10 +14,10 @@ import unittest from unittest.mock import patch, MagicMock -from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer, WAN_CHECKPOINT +from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer2_1, WanCheckpointer2_2 - -class WanCheckpointerTest(unittest.TestCase): +class WanCheckpointer2_1Test(unittest.TestCase): + """Tests for WAN 2.1 checkpointer.""" def setUp(self): self.config = MagicMock() @@ -25,7 +25,7 @@ def setUp(self): self.config.dataset_type = "test_dataset" @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_1") def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager): mock_manager = MagicMock() mock_manager.latest_step.return_value = None @@ -34,7 +34,7 @@ def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager): mock_pipeline_instance = MagicMock() mock_wan_pipeline.from_pretrained.return_value = mock_pipeline_instance - checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT) + checkpointer = WanCheckpointer2_1(model_key=self.config.model_key, config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=None) mock_manager.latest_step.assert_called_once() @@ -44,7 +44,7 @@ def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager): self.assertIsNone(step) @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_1") def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manager): mock_manager = MagicMock() mock_manager.latest_step.return_value = 1 @@ -57,12 +57,6 @@ def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manag restored_mock.wan_config = {} restored_mock.keys.return_value = ["wan_state", "wan_config"] - def getitem_side_effect(key): - if key == "wan_state": - return restored_mock.wan_state - raise KeyError(key) - - restored_mock.__getitem__.side_effect = getitem_side_effect mock_manager.restore.return_value = restored_mock mock_create_manager.return_value = mock_manager @@ -70,7 +64,7 @@ def getitem_side_effect(key): mock_pipeline_instance = MagicMock() mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance - checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT) + checkpointer = WanCheckpointer2_1(model_key=self.config.model_key, config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) @@ -80,7 +74,7 @@ def getitem_side_effect(key): self.assertEqual(step, 1) @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_1") def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_manager): mock_manager = MagicMock() mock_manager.latest_step.return_value = 1 @@ -93,12 +87,102 @@ def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_man restored_mock.wan_config = {} restored_mock.keys.return_value = ["wan_state", "wan_config"] - def getitem_side_effect(key): - if key == "wan_state": - return restored_mock.wan_state - raise KeyError(key) + mock_manager.restore.return_value = restored_mock + + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointer2_1(model_key=self.config.model_key, config=self.config) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) + + mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) + mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) + self.assertEqual(pipeline, mock_pipeline_instance) + self.assertIsNotNone(opt_state) + self.assertEqual(opt_state["learning_rate"], 0.001) + self.assertEqual(step, 1) + + +class WanCheckpointer2_2Test(unittest.TestCase): + """Tests for WAN 2.2 checkpointer.""" + + def setUp(self): + self.config = MagicMock() + self.config.checkpoint_dir = "/tmp/wan_checkpoint_2_2_test" + self.config.dataset_type = "test_dataset" + + @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_2") + def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager): + """Test loading from pretrained when no checkpoint exists.""" + mock_manager = MagicMock() + mock_manager.latest_step.return_value = None + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline.from_pretrained.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointer2_2(model_key=self.config.model_key, config=self.config) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=None) + + mock_manager.latest_step.assert_called_once() + mock_wan_pipeline.from_pretrained.assert_called_once_with(self.config) + self.assertEqual(pipeline, mock_pipeline_instance) + self.assertIsNone(opt_state) + self.assertIsNone(step) + + @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_2") + def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manager): + """Test loading checkpoint without optimizer state.""" + mock_manager = MagicMock() + mock_manager.latest_step.return_value = 1 + metadata_mock = MagicMock() + metadata_mock.low_noise_transformer_state = {} + metadata_mock.high_noise_transformer_state = {} + mock_manager.item_metadata.return_value = metadata_mock + + restored_mock = MagicMock() + restored_mock.low_noise_transformer_state = {"params": {}} + restored_mock.high_noise_transformer_state = {"params": {}} + restored_mock.wan_config = {} + restored_mock.keys.return_value = ["low_noise_transformer_state", "high_noise_transformer_state", "wan_config"] + + mock_manager.restore.return_value = restored_mock + + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointer2_2(model_key=self.config.model_key, config=self.config) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) + + mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) + mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) + self.assertEqual(pipeline, mock_pipeline_instance) + self.assertIsNone(opt_state) + self.assertEqual(step, 1) + + @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_2") + def test_load_checkpoint_with_optimizer_in_low_noise(self, mock_wan_pipeline, mock_create_manager): + """Test loading checkpoint with optimizer state in low_noise_transformer.""" + mock_manager = MagicMock() + mock_manager.latest_step.return_value = 1 + metadata_mock = MagicMock() + metadata_mock.low_noise_transformer_state = {} + metadata_mock.high_noise_transformer_state = {} + mock_manager.item_metadata.return_value = metadata_mock + + restored_mock = MagicMock() + restored_mock.low_noise_transformer_state = {"params": {}, "opt_state": {"learning_rate": 0.001}} + restored_mock.high_noise_transformer_state = {"params": {}} + restored_mock.wan_config = {} + restored_mock.keys.return_value = ["low_noise_transformer_state", "high_noise_transformer_state", "wan_config"] - restored_mock.__getitem__.side_effect = getitem_side_effect mock_manager.restore.return_value = restored_mock mock_create_manager.return_value = mock_manager @@ -106,7 +190,7 @@ def getitem_side_effect(key): mock_pipeline_instance = MagicMock() mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance - checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT) + checkpointer = WanCheckpointer2_2(model_key=self.config.model_key, config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) @@ -116,6 +200,104 @@ def getitem_side_effect(key): self.assertEqual(opt_state["learning_rate"], 0.001) self.assertEqual(step, 1) + @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_2") + def test_load_checkpoint_with_optimizer_in_high_noise(self, mock_wan_pipeline, mock_create_manager): + """Test loading checkpoint with optimizer state in high_noise_transformer.""" + mock_manager = MagicMock() + mock_manager.latest_step.return_value = 1 + metadata_mock = MagicMock() + metadata_mock.low_noise_transformer_state = {} + metadata_mock.high_noise_transformer_state = {} + mock_manager.item_metadata.return_value = metadata_mock + + restored_mock = MagicMock() + restored_mock.low_noise_transformer_state = {"params": {}} + restored_mock.high_noise_transformer_state = {"params": {}, "opt_state": {"learning_rate": 0.002}} + restored_mock.wan_config = {} + restored_mock.keys.return_value = ["low_noise_transformer_state", "high_noise_transformer_state", "wan_config"] + + mock_manager.restore.return_value = restored_mock + + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointer2_2(model_key=self.config.model_key, config=self.config) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) + + mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) + mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) + self.assertEqual(pipeline, mock_pipeline_instance) + self.assertIsNotNone(opt_state) + self.assertEqual(opt_state["learning_rate"], 0.002) + self.assertEqual(step, 1) + + +class WanCheckpointerEdgeCasesTest(unittest.TestCase): + """Tests for edge cases and error handling.""" + + def setUp(self): + self.config = MagicMock() + self.config.checkpoint_dir = "/tmp/wan_checkpoint_edge_test" + self.config.dataset_type = "test_dataset" + + @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_1") + def test_load_checkpoint_with_explicit_none_step(self, mock_wan_pipeline, mock_create_manager): + """Test loading checkpoint with explicit None step falls back to latest.""" + mock_manager = MagicMock() + mock_manager.latest_step.return_value = 5 + metadata_mock = MagicMock() + metadata_mock.wan_state = {} + mock_manager.item_metadata.return_value = metadata_mock + + restored_mock = MagicMock() + restored_mock.wan_state = {"params": {}} + restored_mock.wan_config = {} + mock_manager.restore.return_value = restored_mock + + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointer2_1(model_key=self.config.model_key, config=self.config) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=None) + + mock_manager.latest_step.assert_called_once() + self.assertEqual(step, 5) + + @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_2") + def test_load_checkpoint_both_optimizers_present(self, mock_wan_pipeline, mock_create_manager): + """Test loading checkpoint when both transformers have optimizer state (prioritize low_noise).""" + mock_manager = MagicMock() + mock_manager.latest_step.return_value = 1 + metadata_mock = MagicMock() + metadata_mock.low_noise_transformer_state = {} + metadata_mock.high_noise_transformer_state = {} + mock_manager.item_metadata.return_value = metadata_mock + + restored_mock = MagicMock() + restored_mock.low_noise_transformer_state = {"params": {}, "opt_state": {"learning_rate": 0.001}} + restored_mock.high_noise_transformer_state = {"params": {}, "opt_state": {"learning_rate": 0.002}} + restored_mock.wan_config = {} + mock_manager.restore.return_value = restored_mock + + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointer2_2(model_key=self.config.model_key, config=self.config) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) + + # Should prioritize low_noise_transformer's optimizer state + self.assertIsNotNone(opt_state) + self.assertEqual(opt_state["learning_rate"], 0.001) + if __name__ == "__main__": - unittest.main() + unittest.main(verbosity=2) From ce17ed0f41b9944af6928d53df8a34fb2a488703 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Wed, 12 Nov 2025 00:12:50 +0530 Subject: [PATCH 08/10] Removed extra files --- .../checkpointing/wan_checkpointer2_2.py | 207 ----- .../pipelines/wan/wan_pipeline2_2.py | 725 ------------------ .../tests/wan_checkpointer2_2_test.py | 113 --- 3 files changed, 1045 deletions(-) delete mode 100644 src/maxdiffusion/checkpointing/wan_checkpointer2_2.py delete mode 100644 src/maxdiffusion/pipelines/wan/wan_pipeline2_2.py delete mode 100644 src/maxdiffusion/tests/wan_checkpointer2_2_test.py diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer2_2.py b/src/maxdiffusion/checkpointing/wan_checkpointer2_2.py deleted file mode 100644 index de8bb35d..00000000 --- a/src/maxdiffusion/checkpointing/wan_checkpointer2_2.py +++ /dev/null @@ -1,207 +0,0 @@ -""" - Copyright 2025 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" - -from abc import ABC -import json - -import jax -import numpy as np -from typing import Optional, Tuple -from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager) -from ..pipelines.wan.wan_pipeline2_2 import WanPipeline -from .. import max_logging, max_utils -import orbax.checkpoint as ocp -from etils import epath - -WAN_CHECKPOINT = "WAN_CHECKPOINT" - - -class WanCheckpointer(ABC): - - def __init__(self, config, checkpoint_type): - self.config = config - self.checkpoint_type = checkpoint_type - self.opt_state = None - - self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager( - self.config.checkpoint_dir, - enable_checkpointing=True, - save_interval_steps=1, - checkpoint_type=checkpoint_type, - dataset_type=config.dataset_type, - ) - - def _create_optimizer(self, model, config, learning_rate): - learning_rate_scheduler = max_utils.create_learning_rate_schedule( - learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps - ) - tx = max_utils.create_optimizer(config, learning_rate_scheduler) - return tx, learning_rate_scheduler - - def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: - if step is None: - step = self.checkpoint_manager.latest_step() - max_logging.log(f"Latest WAN checkpoint step: {step}") - if step is None: - max_logging.log("No WAN checkpoint found.") - return None, None - max_logging.log(f"Loading WAN checkpoint from step {step}") - metadatas = self.checkpoint_manager.item_metadata(step) - - low_noise_transformer_metadata = metadatas.low_noise_transformer_state - abstract_tree_structure_low_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata) - low_params_restore = ocp.args.PyTreeRestore( - restore_args=jax.tree.map( - lambda _: ocp.RestoreArgs(restore_type=np.ndarray), - abstract_tree_structure_low_params, - ) - ) - - high_noise_transformer_metadata = metadatas.high_noise_transformer_state - abstract_tree_structure_high_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata) - high_params_restore = ocp.args.PyTreeRestore( - restore_args=jax.tree.map( - lambda _: ocp.RestoreArgs(restore_type=np.ndarray), - abstract_tree_structure_high_params, - ) - ) - - max_logging.log("Restoring WAN checkpoint") - restored_checkpoint = self.checkpoint_manager.restore( - directory=epath.Path(self.config.checkpoint_dir), - step=step, - args=ocp.args.Composite( - low_noise_transformer_state=low_params_restore, - high_noise_transformer_state=high_params_restore, - wan_config=ocp.args.JsonRestore(), - ), - ) - max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") - max_logging.log(f"restored checkpoint low_noise_transformer_state {restored_checkpoint.low_noise_transformer_state.keys()}") - max_logging.log(f"restored checkpoint high_noise_transformer_state {restored_checkpoint.high_noise_transformer_state.keys()}") - max_logging.log(f"optimizer found in low_noise checkpoint {'opt_state' in restored_checkpoint.low_noise_transformer_state.keys()}") - max_logging.log(f"optimizer found in high_noise checkpoint {'opt_state' in restored_checkpoint.high_noise_transformer_state.keys()}") - max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}") - return restored_checkpoint, step - - def load_diffusers_checkpoint(self): - pipeline = WanPipeline.from_pretrained(self.config) - return pipeline - - def load_checkpoint(self, step=None) -> Tuple[WanPipeline, Optional[dict], Optional[int]]: - restored_checkpoint, step = self.load_wan_configs_from_orbax(step) - opt_state = None - if restored_checkpoint: - max_logging.log("Loading WAN pipeline from checkpoint") - pipeline = WanPipeline.from_checkpoint(self.config, restored_checkpoint) - # Check for optimizer state in either transformer - if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys(): - opt_state = restored_checkpoint.low_noise_transformer_state["opt_state"] - elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys(): - opt_state = restored_checkpoint.high_noise_transformer_state["opt_state"] - else: - max_logging.log("No checkpoint found, loading default pipeline.") - pipeline = self.load_diffusers_checkpoint() - - return pipeline, opt_state, step - - def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict): - """Saves the training state and model configurations.""" - - def config_to_json(model_or_config): - return json.loads(model_or_config.to_json_string()) - - max_logging.log(f"Saving checkpoint for step {train_step}") - items = { - "wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)), - } - - items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"]) - items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"]) - - # Save the checkpoint - self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) - max_logging.log(f"Checkpoint for step {train_step} saved.") - - -def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict): - """Saves the training state and model configurations.""" - - def config_to_json(model_or_config): - """ - only save the config that is needed and can be serialized to JSON. - """ - if not hasattr(model_or_config, "config"): - return None - source_config = dict(model_or_config.config) - - # 1. configs that can be serialized to JSON - SAFE_KEYS = [ - "_class_name", - "_diffusers_version", - "model_type", - "patch_size", - "num_attention_heads", - "attention_head_dim", - "in_channels", - "out_channels", - "text_dim", - "freq_dim", - "ffn_dim", - "num_layers", - "cross_attn_norm", - "qk_norm", - "eps", - "image_dim", - "added_kv_proj_dim", - "rope_max_seq_len", - "pos_embed_seq_len", - "flash_min_seq_length", - "flash_block_sizes", - "attention", - "_use_default_values", - ] - - # 2. save the config that are in the SAFE_KEYS list - clean_config = {} - for key in SAFE_KEYS: - if key in source_config: - clean_config[key] = source_config[key] - - # 3. deal with special data type and precision - if "dtype" in source_config and hasattr(source_config["dtype"], "name"): - clean_config["dtype"] = source_config["dtype"].name # e.g 'bfloat16' - - if "weights_dtype" in source_config and hasattr(source_config["weights_dtype"], "name"): - clean_config["weights_dtype"] = source_config["weights_dtype"].name - - if "precision" in source_config and isinstance(source_config["precision"]): - clean_config["precision"] = source_config["precision"].name # e.g. 'HIGHEST' - - return clean_config - - items_to_save = { - "transformer_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)), - } - - items_to_save["transformer_states"] = ocp.args.PyTreeSave(train_states) - - # Create CompositeArgs for Orbax - save_args = ocp.args.Composite(**items_to_save) - - # Save the checkpoint - self.checkpoint_manager.save(train_step, args=save_args) - max_logging.log(f"Checkpoint for step {train_step} saved.") diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline2_2.py deleted file mode 100644 index 0645aeeb..00000000 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline2_2.py +++ /dev/null @@ -1,725 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Union, Optional -from functools import partial -import numpy as np -import jax -import jax.numpy as jnp -from jax.sharding import Mesh, NamedSharding, PartitionSpec as P -import flax -import flax.linen as nn -from flax import nnx -from flax.linen import partitioning as nn_partitioning -from ...pyconfig import HyperParameters -from ... import max_logging -from ... import max_utils -from ...max_utils import get_flash_block_sizes, get_precision, device_put_replicated -from ...models.wan.wan_utils import load_wan_transformer, load_wan_vae -from ...models.wan.transformers.transformer_wan import WanModel -from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan, AutoencoderKLWanCache -from maxdiffusion.video_processor import VideoProcessor -from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState -from transformers import AutoTokenizer, UMT5EncoderModel -from maxdiffusion.utils.import_utils import is_ftfy_available -from maxdiffusion.maxdiffusion_utils import get_dummy_wan_inputs -import html -import re -import torch -import qwix - - -def cast_with_exclusion(path, x, dtype_to_cast): - """ - Casts arrays to dtype_to_cast, but keeps params from any 'norm' layer in float32. - """ - - exclusion_keywords = [ - "norm", # For all LayerNorm/GroupNorm layers - "condition_embedder", # The entire time/text conditioning module - "scale_shift_table", # Catches both the final and the AdaLN tables - ] - - path_str = ".".join(str(k.key) if isinstance(k, jax.tree_util.DictKey) else str(k) for k in path) - - if any(keyword in path_str.lower() for keyword in exclusion_keywords): - print("is_norm_path: ", path) - # Keep LayerNorm/GroupNorm weights and biases in full precision - return x.astype(jnp.float32) - else: - # Cast everything else to dtype_to_cast - return x.astype(dtype_to_cast) - - -def basic_clean(text): - if is_ftfy_available(): - import ftfy - - text = ftfy.fix_text(text) - text = html.unescape(html.unescape(text)) - return text.strip() - - -def whitespace_clean(text): - text = re.sub(r"\s+", " ", text) - text = text.strip() - return text - - -def prompt_clean(text): - text = whitespace_clean(basic_clean(text)) - return text - - -def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.VariableState: - vs.sharding_rules = logical_axis_rules - return vs - - -# For some reason, jitting this function increases the memory significantly, so instead manually move weights to device. -def create_sharded_logical_transformer( - devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None, subfolder: str = "" -): - - def create_model(rngs: nnx.Rngs, wan_config: dict): - wan_transformer = WanModel(**wan_config, rngs=rngs) - return wan_transformer - - # 1. Load config. - if restored_checkpoint: - wan_config = restored_checkpoint["wan_config"] - else: - wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder=subfolder) - wan_config["mesh"] = mesh - wan_config["dtype"] = config.activations_dtype - wan_config["weights_dtype"] = config.weights_dtype - wan_config["attention"] = config.attention - wan_config["precision"] = get_precision(config) - wan_config["flash_block_sizes"] = get_flash_block_sizes(config) - wan_config["remat_policy"] = config.remat_policy - wan_config["names_which_can_be_saved"] = config.names_which_can_be_saved - wan_config["names_which_can_be_offloaded"] = config.names_which_can_be_offloaded - wan_config["flash_min_seq_length"] = config.flash_min_seq_length - wan_config["dropout"] = config.dropout - wan_config["scan_layers"] = config.scan_layers - - # 2. eval_shape - will not use flops or create weights on device - # thus not using HBM memory. - p_model_factory = partial(create_model, wan_config=wan_config) - wan_transformer = nnx.eval_shape(p_model_factory, rngs=rngs) - graphdef, state, rest_of_state = nnx.split(wan_transformer, nnx.Param, ...) - - # 3. retrieve the state shardings, mapping logical names to mesh axis names. - logical_state_spec = nnx.get_partition_spec(state) - logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules) - logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding)) - params = state.to_pure_dict() - state = dict(nnx.to_flat_state(state)) - - # 4. Load pretrained weights and move them to device using the state shardings from (3) above. - # This helps with loading sharded weights directly into the accelerators without fist copying them - # all to one device and then distributing them, thus using low HBM memory. - if restored_checkpoint: - if "params" in restored_checkpoint["wan_state"]: # if checkpointed with optimizer - params = restored_checkpoint["wan_state"]["params"] - else: # if not checkpointed with optimizer - params = restored_checkpoint["wan_state"] - else: - params = load_wan_transformer( - config.wan_transformer_pretrained_model_name_or_path, - params, - "cpu", - num_layers=wan_config["num_layers"], - scan_layers=config.scan_layers, - subfolder=subfolder, - ) - - params = jax.tree_util.tree_map_with_path( - lambda path, x: cast_with_exclusion(path, x, dtype_to_cast=config.weights_dtype), params - ) - for path, val in flax.traverse_util.flatten_dict(params).items(): - if restored_checkpoint: - path = path[:-1] - sharding = logical_state_sharding[path].value - state[path].value = device_put_replicated(val, sharding) - state = nnx.from_flat_state(state) - - wan_transformer = nnx.merge(graphdef, state, rest_of_state) - return wan_transformer - - -@nnx.jit(static_argnums=(1,), donate_argnums=(0,)) -def create_sharded_logical_model(model, logical_axis_rules): - graphdef, state, rest_of_state = nnx.split(model, nnx.Param, ...) - p_add_sharding_rule = partial(_add_sharding_rule, logical_axis_rules=logical_axis_rules) - state = jax.tree.map(p_add_sharding_rule, state, is_leaf=lambda x: isinstance(x, nnx.VariableState)) - pspecs = nnx.get_partition_spec(state) - sharded_state = jax.lax.with_sharding_constraint(state, pspecs) - model = nnx.merge(graphdef, sharded_state, rest_of_state) - return model - - -class WanPipeline: - r""" - Pipeline for text-to-video generation using Wan. - - tokenizer ([`T5Tokenizer`]): - Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), - specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. - text_encoder ([`T5EncoderModel`]): - [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically - the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. - transformer ([`WanModel`]): - Conditional Transformer to denoise the input latents. - scheduler ([`FlaxUniPCMultistepScheduler`]): - A scheduler to be used in combination with `transformer` to denoise the encoded image latents. - vae ([`AutoencoderKLWan`]): - Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. - """ - - def __init__( - self, - tokenizer: AutoTokenizer, - text_encoder: UMT5EncoderModel, - low_noise_transformer: WanModel, - high_noise_transformer: WanModel, - vae: AutoencoderKLWan, - vae_cache: AutoencoderKLWanCache, - scheduler: FlaxUniPCMultistepScheduler, - scheduler_state: UniPCMultistepSchedulerState, - devices_array: np.array, - mesh: Mesh, - config: HyperParameters, - ): - self.tokenizer = tokenizer - self.text_encoder = text_encoder - self.low_noise_transformer = low_noise_transformer - self.high_noise_transformer = high_noise_transformer - self.vae = vae - self.vae_cache = vae_cache - self.scheduler = scheduler - self.scheduler_state = scheduler_state - self.devices_array = devices_array - self.mesh = mesh - self.config = config - - self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 - self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 - self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) - - self.p_run_inference = None - - @classmethod - def load_text_encoder(cls, config: HyperParameters): - text_encoder = UMT5EncoderModel.from_pretrained( - config.pretrained_model_name_or_path, - subfolder="text_encoder", - ) - return text_encoder - - @classmethod - def load_tokenizer(cls, config: HyperParameters): - tokenizer = AutoTokenizer.from_pretrained( - config.pretrained_model_name_or_path, - subfolder="tokenizer", - ) - return tokenizer - - @classmethod - def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): - - def create_model(rngs: nnx.Rngs, config: HyperParameters): - wan_vae = AutoencoderKLWan.from_config( - config.pretrained_model_name_or_path, - subfolder="vae", - rngs=rngs, - mesh=mesh, - dtype=jnp.float32, - weights_dtype=jnp.float32, - ) - return wan_vae - - # 1. eval shape - p_model_factory = partial(create_model, config=config) - wan_vae = nnx.eval_shape(p_model_factory, rngs=rngs) - graphdef, state = nnx.split(wan_vae, nnx.Param) - - # 2. retrieve the state shardings, mapping logical names to mesh axis names. - logical_state_spec = nnx.get_partition_spec(state) - logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules) - logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding)) - params = state.to_pure_dict() - state = dict(nnx.to_flat_state(state)) - - # 4. Load pretrained weights and move them to device using the state shardings from (3) above. - # This helps with loading sharded weights directly into the accelerators without fist copying them - # all to one device and then distributing them, thus using low HBM memory. - params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu") - params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) - for path, val in flax.traverse_util.flatten_dict(params).items(): - sharding = logical_state_sharding[path].value - if config.replicate_vae: - sharding = NamedSharding(mesh, P()) - state[path].value = device_put_replicated(val, sharding) - state = nnx.from_flat_state(state) - - wan_vae = nnx.merge(graphdef, state) - vae_cache = AutoencoderKLWanCache(wan_vae) - return wan_vae, vae_cache - - @classmethod - def get_basic_config(cls, dtype, config: HyperParameters): - rules = [ - qwix.QtRule( - module_path=config.qwix_module_path, - weight_qtype=dtype, - act_qtype=dtype, - op_names=("dot_general", "einsum", "conv_general_dilated"), - ) - ] - return rules - - @classmethod - def get_fp8_config(cls, config: HyperParameters): - """ - fp8 config rules with per-tensor calibration. - FLAX API (https://flax-linen.readthedocs.io/en/v0.10.6/guides/quantization/fp8_basics.html#flax-low-level-api): - The autodiff does not automatically use E5M2 for gradients and E4M3 for activations/weights during training, which is the recommended practice. - """ - rules = [ - qwix.QtRule( - module_path=config.qwix_module_path, - weight_qtype=jnp.float8_e4m3fn, - act_qtype=jnp.float8_e4m3fn, - bwd_qtype=jnp.float8_e5m2, - disable_channelwise_axes=True, # per_tensor calibration - weight_calibration_method=config.quantization_calibration_method, - act_calibration_method=config.quantization_calibration_method, - bwd_calibration_method=config.quantization_calibration_method, - op_names=("dot_general", "einsum"), - ), - qwix.QtRule( - module_path=config.qwix_module_path, - weight_qtype=jnp.float8_e4m3fn, # conv_general_dilated requires the same dtypes - act_qtype=jnp.float8_e4m3fn, - bwd_qtype=jnp.float8_e4m3fn, - disable_channelwise_axes=True, # per_tensor calibration - weight_calibration_method=config.quantization_calibration_method, - act_calibration_method=config.quantization_calibration_method, - bwd_calibration_method=config.quantization_calibration_method, - op_names=("conv_general_dilated"), - ), - ] - return rules - - @classmethod - def get_qt_provider(cls, config: HyperParameters) -> Optional[qwix.QtProvider]: - """Get quantization rules based on the config.""" - if not getattr(config, "use_qwix_quantization", False): - return None - - match config.quantization: - case "int8": - return qwix.QtProvider(cls.get_basic_config(jnp.int8, config)) - case "fp8": - return qwix.QtProvider(cls.get_basic_config(jnp.float8_e4m3fn, config)) - case "fp8_full": - return qwix.QtProvider(cls.get_fp8_config(config)) - return None - - @classmethod - def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline: "WanPipeline", mesh: Mesh): - """Quantizes the transformer model.""" - q_rules = cls.get_qt_provider(config) - if not q_rules: - return model - max_logging.log("Quantizing transformer with Qwix.") - - batch_size = jnp.ceil(config.per_device_batch_size * jax.local_device_count()).astype(jnp.int32) - latents, prompt_embeds, timesteps = get_dummy_wan_inputs(config, pipeline, batch_size) - model_inputs = (latents, timesteps, prompt_embeds) - with mesh: - quantized_model = qwix.quantize_model(model, q_rules, *model_inputs) - max_logging.log("Qwix Quantization complete.") - return quantized_model - - @classmethod - def load_transformer( - cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None, subfolder="transformer"): - with mesh: - wan_transformer = create_sharded_logical_transformer( - devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder=subfolder - ) - return wan_transformer - - @classmethod - def load_scheduler(cls, config): - scheduler, scheduler_state = FlaxUniPCMultistepScheduler.from_pretrained( - config.pretrained_model_name_or_path, - subfolder="scheduler", - flow_shift=config.flow_shift, # 5.0 for 720p, 3.0 for 480p - ) - return scheduler, scheduler_state - - @classmethod - def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): - devices_array = max_utils.create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - rng = jax.random.key(config.seed) - rngs = nnx.Rngs(rng) - low_noise_transformer = None - high_noise_transformer = None - tokenizer = None - scheduler = None - scheduler_state = None - text_encoder = None - if not vae_only: - if load_transformer: - with mesh: - low_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder="transformer") - high_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder="transformer_2") - - text_encoder = cls.load_text_encoder(config=config) - tokenizer = cls.load_tokenizer(config=config) - - scheduler, scheduler_state = cls.load_scheduler(config=config) - - with mesh: - wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) - - return WanPipeline( - tokenizer=tokenizer, - text_encoder=text_encoder, - low_noise_transformer=low_noise_transformer, - high_noise_transformer=high_noise_transformer, - vae=wan_vae, - vae_cache=vae_cache, - scheduler=scheduler, - scheduler_state=scheduler_state, - devices_array=devices_array, - mesh=mesh, - config=config, - ) - - @classmethod - def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): - devices_array = max_utils.create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - rng = jax.random.key(config.seed) - rngs = nnx.Rngs(rng) - low_noise_transformer = None - high_noise_transformer = None - tokenizer = None - scheduler = None - scheduler_state = None - text_encoder = None - if not vae_only: - if load_transformer: - with mesh: - low_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer") - high_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer_2") - text_encoder = cls.load_text_encoder(config=config) - tokenizer = cls.load_tokenizer(config=config) - - scheduler, scheduler_state = cls.load_scheduler(config=config) - - with mesh: - wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) - - pipeline = WanPipeline( - tokenizer=tokenizer, - text_encoder=text_encoder, - low_noise_transformer=low_noise_transformer, - high_noise_transformer=high_noise_transformer, - vae=wan_vae, - vae_cache=vae_cache, - scheduler=scheduler, - scheduler_state=scheduler_state, - devices_array=devices_array, - mesh=mesh, - config=config, - ) - - pipeline.low_noise_transformer = cls.quantize_transformer(config, pipeline.low_noise_transformer, pipeline, mesh) - pipeline.high_noise_transformer = cls.quantize_transformer(config, pipeline.high_noise_transformer, pipeline, mesh) - return pipeline - - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_videos_per_prompt: int = 1, - max_sequence_length: int = 226, - ): - prompt = [prompt] if isinstance(prompt, str) else prompt - prompt = [prompt_clean(u) for u in prompt] - batch_size = len(prompt) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask - seq_lens = mask.gt(0).sum(dim=1).long() - prompt_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state - prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 - ) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - - return prompt_embeds - - def encode_prompt( - self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - num_videos_per_prompt: int = 1, - max_sequence_length: int = 226, - prompt_embeds: jax.Array = None, - negative_prompt_embeds: jax.Array = None, - ): - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - if prompt_embeds is None: - prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - ) - prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=jnp.float32) - - if negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_embeds = self._get_t5_prompt_embeds( - prompt=negative_prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - ) - negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().numpy(), dtype=jnp.float32) - - return prompt_embeds, negative_prompt_embeds - - def prepare_latents( - self, - batch_size: int, - vae_scale_factor_temporal: int, - vae_scale_factor_spatial: int, - height: int = 480, - width: int = 832, - num_frames: int = 81, - num_channels_latents: int = 16, - ): - rng = jax.random.key(self.config.seed) - num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1 - shape = ( - batch_size, - num_channels_latents, - num_latent_frames, - int(height) // vae_scale_factor_spatial, - int(width) // vae_scale_factor_spatial, - ) - latents = jax.random.normal(rng, shape=shape, dtype=jnp.float32) - - return latents - - def __call__( - self, - prompt: Union[str, List[str]] = None, - negative_prompt: Union[str, List[str]] = None, - height: int = 480, - width: int = 832, - num_frames: int = 81, - num_inference_steps: int = 50, - guidance_scale_low: float = 3.0, - guidance_scale_high: float = 4.0, - boundary: int = 875, - num_videos_per_prompt: Optional[int] = 1, - max_sequence_length: int = 512, - latents: jax.Array = None, - prompt_embeds: jax.Array = None, - negative_prompt_embeds: jax.Array = None, - vae_only: bool = False, - ): - if not vae_only: - if num_frames % self.vae_scale_factor_temporal != 1: - max_logging.log( - f"`num_frames -1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." - ) - num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 - num_frames = max(num_frames, 1) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - prompt = [prompt] - - batch_size = len(prompt) - - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt=prompt, - negative_prompt=negative_prompt, - max_sequence_length=max_sequence_length, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - num_channel_latents = self.low_noise_transformer.config.in_channels - if latents is None: - latents = self.prepare_latents( - batch_size=batch_size, - vae_scale_factor_temporal=self.vae_scale_factor_temporal, - vae_scale_factor_spatial=self.vae_scale_factor_spatial, - height=height, - width=width, - num_frames=num_frames, - num_channels_latents=num_channel_latents, - ) - - data_sharding = NamedSharding(self.mesh, P()) - # Using global_batch_size_to_train_on so not to create more config variables - if self.config.global_batch_size_to_train_on // self.config.per_device_batch_size == 0: - data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) - - latents = jax.device_put(latents, data_sharding) - prompt_embeds = jax.device_put(prompt_embeds, data_sharding) - negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding) - - scheduler_state = self.scheduler.set_timesteps( - self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape - ) - - low_noise_graphdef, low_noise_state, low_noise_rest = nnx.split(self.low_noise_transformer, nnx.Param, ...) - high_noise_graphdef, high_noise_state, high_noise_rest = nnx.split(self.high_noise_transformer, nnx.Param, ...) - - p_run_inference = partial( - run_inference, - guidance_scale_low=guidance_scale_low, - guidance_scale_high=guidance_scale_high, - boundary=boundary, - num_inference_steps=num_inference_steps, - scheduler=self.scheduler, - scheduler_state=scheduler_state, - ) - - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - latents = p_run_inference( - low_noise_graphdef=low_noise_graphdef, - low_noise_state=low_noise_state, - low_noise_rest=low_noise_rest, - high_noise_graphdef=high_noise_graphdef, - high_noise_state=high_noise_state, - high_noise_rest=high_noise_rest, - latents=latents, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1) - latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1) - latents = latents / latents_std + latents_mean - latents = latents.astype(jnp.float32) - - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - video = self.vae.decode(latents, self.vae_cache)[0] - - video = jnp.transpose(video, (0, 4, 1, 2, 3)) - video = jax.experimental.multihost_utils.process_allgather(video, tiled=True) - video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16) - video = self.video_processor.postprocess_video(video, output_type="np") - return video - - -@partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale")) -def transformer_forward_pass( - graphdef, - sharded_state, - rest_of_state, - latents, - timestep, - prompt_embeds, - do_classifier_free_guidance, - guidance_scale, -): - wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) - noise_pred = wan_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds) - if do_classifier_free_guidance: - bsz = latents.shape[0] // 2 - noise_uncond = noise_pred[bsz:] - noise_pred = noise_pred[:bsz] - noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) - latents = latents[:bsz] - - return noise_pred, latents - -def run_inference( - low_noise_graphdef, - low_noise_state, - low_noise_rest, - high_noise_graphdef, - high_noise_state, - high_noise_rest, - latents: jnp.array, - prompt_embeds: jnp.array, - negative_prompt_embeds: jnp.array, - guidance_scale_low: float, - guidance_scale_high: float, - boundary: int, - num_inference_steps: int, - scheduler: FlaxUniPCMultistepScheduler, - scheduler_state, -): - do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 - if do_classifier_free_guidance: - prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) - - def low_noise_branch(operands): - latents, timestep, prompt_embeds = operands - return transformer_forward_pass( - low_noise_graphdef, low_noise_state, low_noise_rest, - latents, timestep, prompt_embeds, - do_classifier_free_guidance, guidance_scale_low - ) - - def high_noise_branch(operands): - latents, timestep, prompt_embeds = operands - return transformer_forward_pass( - high_noise_graphdef, high_noise_state, high_noise_rest, - latents, timestep, prompt_embeds, - do_classifier_free_guidance, guidance_scale_high - ) - - for step in range(num_inference_steps): - t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] - if do_classifier_free_guidance: - latents = jnp.concatenate([latents] * 2) - timestep = jnp.broadcast_to(t, latents.shape[0]) - use_high_noise = jnp.greater_equal(t, boundary) - - noise_pred, latents = jax.lax.cond( - use_high_noise, - high_noise_branch, - low_noise_branch, - (latents, timestep, prompt_embeds) - ) - - latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - return latents diff --git a/src/maxdiffusion/tests/wan_checkpointer2_2_test.py b/src/maxdiffusion/tests/wan_checkpointer2_2_test.py deleted file mode 100644 index 8e1fa0be..00000000 --- a/src/maxdiffusion/tests/wan_checkpointer2_2_test.py +++ /dev/null @@ -1,113 +0,0 @@ -""" - Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ - -import unittest -from unittest.mock import patch, MagicMock - -from maxdiffusion.checkpointing.wan_checkpointer2_2 import WanCheckpointer, WAN_CHECKPOINT - - -class WanCheckpointerTest(unittest.TestCase): - - def setUp(self): - self.config = MagicMock() - self.config.checkpoint_dir = "/tmp/wan_checkpoint_test" - self.config.dataset_type = "test_dataset" - - @patch("maxdiffusion.checkpointing.wan_checkpointer2_2.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer2_2.WanPipeline") - def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager): - mock_manager = MagicMock() - mock_manager.latest_step.return_value = None - mock_create_manager.return_value = mock_manager - - mock_pipeline_instance = MagicMock() - mock_wan_pipeline.from_pretrained.return_value = mock_pipeline_instance - - checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT) - pipeline, opt_state, step = checkpointer.load_checkpoint(step=None) - - mock_manager.latest_step.assert_called_once() - mock_wan_pipeline.from_pretrained.assert_called_once_with(self.config) - self.assertEqual(pipeline, mock_pipeline_instance) - self.assertIsNone(opt_state) - self.assertIsNone(step) - - @patch("maxdiffusion.checkpointing.wan_checkpointer2_2.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer2_2.WanPipeline") - def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manager): - mock_manager = MagicMock() - mock_manager.latest_step.return_value = 1 - metadata_mock = MagicMock() - metadata_mock.low_noise_transformer_state = {} - metadata_mock.high_noise_transformer_state = {} - mock_manager.item_metadata.return_value = metadata_mock - - restored_mock = MagicMock() - restored_mock.low_noise_transformer_state = {"params": {}} - restored_mock.high_noise_transformer_state = {"params": {}} - restored_mock.wan_config = {} - restored_mock.keys.return_value = ["low_noise_transformer_state", "high_noise_transformer_state", "wan_config"] - - mock_manager.restore.return_value = restored_mock - - mock_create_manager.return_value = mock_manager - - mock_pipeline_instance = MagicMock() - mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance - - checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT) - pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) - - mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) - mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) - self.assertEqual(pipeline, mock_pipeline_instance) - self.assertIsNone(opt_state) - self.assertEqual(step, 1) - - @patch("maxdiffusion.checkpointing.wan_checkpointer2_2.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer2_2.WanPipeline") - def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_manager): - mock_manager = MagicMock() - mock_manager.latest_step.return_value = 1 - metadata_mock = MagicMock() - metadata_mock.low_noise_transformer_state = {} - metadata_mock.high_noise_transformer_state = {} - mock_manager.item_metadata.return_value = metadata_mock - - restored_mock = MagicMock() - restored_mock.low_noise_transformer_state = {"params": {}, "opt_state": {"learning_rate": 0.001}} - restored_mock.high_noise_transformer_state = {"params": {}} - restored_mock.wan_config = {} - restored_mock.keys.return_value = ["low_noise_transformer_state", "high_noise_transformer_state", "wan_config"] - - mock_manager.restore.return_value = restored_mock - - mock_create_manager.return_value = mock_manager - - mock_pipeline_instance = MagicMock() - mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance - - checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT) - pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) - - mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) - mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) - self.assertEqual(pipeline, mock_pipeline_instance) - self.assertIsNotNone(opt_state) - self.assertEqual(opt_state["learning_rate"], 0.001) - self.assertEqual(step, 1) - - -if __name__ == "__main__": - unittest.main() From b7aad0aaa8d8bd3f7e1964f657d983f0a9aa2add Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Wed, 12 Nov 2025 08:51:07 +0530 Subject: [PATCH 09/10] Updated README and generate_wan.py --- README.md | 18 +----------------- src/maxdiffusion/generate_wan.py | 3 ++- 2 files changed, 3 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 6a4d2048..2deb8ba9 100644 --- a/README.md +++ b/README.md @@ -482,23 +482,7 @@ To generate images, run the following command: ```bash HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ - LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_14b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 - ``` - ## Wan2.2 - - Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage). - - ```bash - HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ - LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_27b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 - ``` - ## Wan2.2 - - Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage). - - ```bash - HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ - LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_27b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 + LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_14b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 ``` ## Wan2.2 diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 53a38ac4..fc3a3626 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -119,7 +119,8 @@ def run(config, pipeline=None, filename_prefix=""): model_key = config.model_name checkpoint_loader = WanCheckpointer(model_key=model_key, config=config) pipeline, _, _ = checkpoint_loader.load_checkpoint() - pipeline = WanPipeline.from_pretrained(model_key=model_key, config=config) + if pipeline is None: + pipeline = WanPipeline.from_pretrained(model_key=model_key, config=config) s0 = time.perf_counter() # Using global_batch_size_to_train_on so not to create more config variables From 16d657ab18984b0bc73af39e3b341146d017c01b Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Tue, 18 Nov 2025 18:28:57 +0530 Subject: [PATCH 10/10] Added tensorboard logging for inference metrics --- requirements_with_jax_ai_image.txt | 1 + src/maxdiffusion/generate_wan.py | 37 ++++++++++++++++++++++++++---- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/requirements_with_jax_ai_image.txt b/requirements_with_jax_ai_image.txt index 2a2287d6..c279edb8 100644 --- a/requirements_with_jax_ai_image.txt +++ b/requirements_with_jax_ai_image.txt @@ -30,6 +30,7 @@ orbax-checkpoint tokenizers==0.21.0 huggingface_hub>=0.30.2 transformers==4.48.1 +tokamax einops==0.8.0 sentencepiece aqtp diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index fc3a3626..d75fd7ee 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -115,8 +115,10 @@ def inference_generate_video(config, pipeline, filename_prefix=""): def run(config, pipeline=None, filename_prefix=""): - print("seed: ", config.seed) model_key = config.model_name + writer = max_utils.initialize_summary_writer(config) + if jax.process_index() == 0 and writer: + max_logging.log(f"TensorBoard logs will be written to: {config.tensorboard_dir}") checkpoint_loader = WanCheckpointer(model_key=model_key, config=config) pipeline, _, _ = checkpoint_loader.load_checkpoint() if pipeline is None: @@ -132,7 +134,19 @@ def run(config, pipeline=None, filename_prefix=""): ) videos = call_pipeline(config, pipeline, prompt, negative_prompt) - print("compile time: ", (time.perf_counter() - s0)) + max_logging.log("===================== Model details =======================") + max_logging.log(f"model name: {config.model_name}") + max_logging.log(f"model path: {config.pretrained_model_name_or_path}") + max_logging.log("model type: t2v") + max_logging.log(f"hardware: {jax.devices()[0].platform}") + max_logging.log(f"number of devices: {jax.device_count()}") + max_logging.log(f"per_device_batch_size: {config.per_device_batch_size}") + max_logging.log("============================================================") + + compile_time = time.perf_counter() - s0 + max_logging.log(f"compile_time: {compile_time}") + if writer and jax.process_index() == 0: + writer.add_scalar("inference/compile_time", compile_time, global_step=0) saved_video_path = [] for i in range(len(videos)): video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4" @@ -143,14 +157,27 @@ def run(config, pipeline=None, filename_prefix=""): s0 = time.perf_counter() videos = call_pipeline(config, pipeline, prompt, negative_prompt) - print("generation time: ", (time.perf_counter() - s0)) - + generation_time = time.perf_counter() - s0 + max_logging.log(f"generation_time: {generation_time}") + if writer and jax.process_index() == 0: + writer.add_scalar("inference/generation_time", generation_time, global_step=0) + num_devices = jax.device_count() + num_videos = num_devices * config.per_device_batch_size + if num_videos > 0: + generation_time_per_video = generation_time / num_videos + writer.add_scalar("inference/generation_time_per_video", generation_time_per_video, global_step=0) + max_logging.log(f"generation time per video: {generation_time_per_video}") + else: + max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.") s0 = time.perf_counter() if config.enable_profiler: max_utils.activate_profiler(config) videos = call_pipeline(config, pipeline, prompt, negative_prompt) max_utils.deactivate_profiler(config) - print("generation time: ", (time.perf_counter() - s0)) + generation_time_with_profiler = time.perf_counter() - s0 + max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}") + if writer and jax.process_index() == 0: + writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0) return saved_video_path