Skip to content

Commit 11d30fc

Browse files
committed
Support for WAN 2.2 added
1 parent 731b07b commit 11d30fc

File tree

7 files changed

+766
-325
lines changed

7 files changed

+766
-325
lines changed

README.md

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
[![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml)
1818

1919
# What's new?
20+
- **`2025/11/11`**: Wan2.2 txt2vid generation is now supported
2021
- **`2025/10/10`**: Wan2.1 txt2vid training and generation is now supported.
2122
- **`2025/10/14`**: NVIDIA DGX Spark Flux support.
2223
- **`2025/8/14`**: LTX-Video img2vid generation is now supported.
@@ -481,7 +482,23 @@ To generate images, run the following command:
481482

482483
```bash
483484
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/
484-
LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_14b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445
485+
LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_14b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445
486+
```
487+
## Wan2.2
488+
489+
Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage).
490+
491+
```bash
492+
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/
493+
LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_27b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445
494+
```
495+
## Wan2.2
496+
497+
Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage).
498+
499+
```bash
500+
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/
501+
LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_27b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445
485502
```
486503

487504
## Flux

src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def create_orbax_checkpoint_manager(
6161
if checkpoint_type == FLUX_CHECKPOINT:
6262
item_names = ("flux_state", "flux_config", "vae_state", "vae_config", "scheduler", "scheduler_config")
6363
elif checkpoint_type == WAN_CHECKPOINT:
64-
item_names = ("wan_state", "wan_config")
64+
item_names = ("low_noise_transformer_state", "high_noise_transformer_state", "wan_state", "wan_config")
6565
else:
6666
item_names = (
6767
"unet_config",

src/maxdiffusion/checkpointing/wan_checkpointer.py

Lines changed: 157 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -14,35 +14,50 @@
1414
limitations under the License.
1515
"""
1616

17-
from abc import ABC
17+
from abc import ABC, abstractmethod
1818
import json
1919

2020
import jax
2121
import numpy as np
22-
from typing import Optional, Tuple
22+
from typing import Optional, Tuple, Type
2323
from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
24-
from ..pipelines.wan.wan_pipeline import WanPipeline
24+
from ..pipelines.wan.wan_pipeline import WanPipeline2_1, WanPipeline2_2
2525
from .. import max_logging, max_utils
2626
import orbax.checkpoint as ocp
2727
from etils import epath
2828

29+
2930
WAN_CHECKPOINT = "WAN_CHECKPOINT"
3031

3132

3233
class WanCheckpointer(ABC):
34+
_SUBCLASS_MAP: dict[str, Type['WanCheckpointer']] = {}
35+
36+
def __new__(cls, model_key: str, config, checkpoint_type: str = WAN_CHECKPOINT):
37+
if cls is WanCheckpointer:
38+
subclass = cls._SUBCLASS_MAP.get(model_key)
39+
if subclass is None:
40+
raise ValueError(
41+
f"Unknown model_key: '{model_key}'. "
42+
f"Supported keys are: {list(cls._SUBCLASS_MAP.keys())}"
43+
)
44+
return super().__new__(subclass)
45+
else:
46+
return super().__new__(cls)
3347

34-
def __init__(self, config, checkpoint_type):
48+
def __init__(self, model_key, config, checkpoint_type: str = WAN_CHECKPOINT):
3549
self.config = config
3650
self.checkpoint_type = checkpoint_type
3751
self.opt_state = None
38-
self.run_wan2_2 = config.run_wan2_2 if 'run_wan2_2' in self.config.__dict__ else False
39-
40-
self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager(
41-
self.config.checkpoint_dir,
42-
enable_checkpointing=True,
43-
save_interval_steps=1,
44-
checkpoint_type=checkpoint_type,
45-
dataset_type=config.dataset_type,
52+
53+
self.checkpoint_manager: ocp.CheckpointManager = (
54+
create_orbax_checkpoint_manager(
55+
self.config.checkpoint_dir,
56+
enable_checkpointing=True,
57+
save_interval_steps=1,
58+
checkpoint_type=checkpoint_type,
59+
dataset_type=config.dataset_type,
60+
)
4661
)
4762

4863
def _create_optimizer(self, model, config, learning_rate):
@@ -52,6 +67,25 @@ def _create_optimizer(self, model, config, learning_rate):
5267
tx = max_utils.create_optimizer(config, learning_rate_scheduler)
5368
return tx, learning_rate_scheduler
5469

70+
@abstractmethod
71+
def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
72+
raise NotImplementedError
73+
74+
@abstractmethod
75+
def load_diffusers_checkpoint(self):
76+
raise NotImplementedError
77+
78+
@abstractmethod
79+
def load_checkpoint(self, step=None) -> Tuple[Optional[WanPipeline2_1 | WanPipeline2_2], Optional[dict], Optional[int]]:
80+
raise NotImplementedError
81+
82+
@abstractmethod
83+
def save_checkpoint(self, train_step, pipeline, train_states: dict):
84+
raise NotImplementedError
85+
86+
87+
class WanCheckpointer2_1(WanCheckpointer):
88+
5589
def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
5690
if step is None:
5791
step = self.checkpoint_manager.latest_step()
@@ -61,36 +95,23 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
6195
return None, None
6296
max_logging.log(f"Loading WAN checkpoint from step {step}")
6397
metadatas = self.checkpoint_manager.item_metadata(step)
64-
65-
restore_args = {}
66-
67-
low_state_metadata = metadatas.low_noise_transformer_state
68-
abstract_tree_structure_low_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_state_metadata)
69-
low_state_restore = ocp.args.PyTreeRestore(
98+
transformer_metadata = metadatas.wan_state
99+
abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata)
100+
params_restore = ocp.args.PyTreeRestore(
70101
restore_args=jax.tree.map(
71102
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
72-
abstract_tree_structure_low_state,
103+
abstract_tree_structure_params,
73104
)
74105
)
75-
restore_args["low_noise_transformer_state"] = low_state_restore
76-
77-
if self.run_wan2_2:
78-
high_state_metadata = metadatas.high_noise_transformer_state
79-
abstract_tree_structure_high_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_state_metadata)
80-
high_state_restore = ocp.args.PyTreeRestore(
81-
restore_args=jax.tree.map(
82-
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
83-
abstract_tree_structure_high_state,
84-
)
85-
)
86-
restore_args["high_noise_transformer_state"] = high_state_restore
87-
88-
restore_args["wan_config"] = ocp.args.JsonRestore()
89106

90107
max_logging.log("Restoring WAN checkpoint")
91108
restored_checkpoint = self.checkpoint_manager.restore(
109+
directory=epath.Path(self.config.checkpoint_dir),
92110
step=step,
93-
args=ocp.args.Composite(**restore_args),
111+
args=ocp.args.Composite(
112+
wan_state=params_restore,
113+
wan_config=ocp.args.JsonRestore(),
114+
),
94115
)
95116
max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
96117
max_logging.log(f"restored checkpoint wan_state {restored_checkpoint.wan_state.keys()}")
@@ -99,24 +120,113 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
99120
return restored_checkpoint, step
100121

101122
def load_diffusers_checkpoint(self):
102-
pipeline = WanPipeline.from_pretrained(self.config)
123+
pipeline = WanPipeline2_1.from_pretrained(self.config)
124+
return pipeline
125+
126+
def load_checkpoint(self, step=None) -> Tuple[WanPipeline2_1, Optional[dict], Optional[int]]:
127+
restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
128+
opt_state = None
129+
if restored_checkpoint:
130+
max_logging.log("Loading WAN pipeline from checkpoint")
131+
pipeline = WanPipeline2_1.from_checkpoint(self.config, restored_checkpoint)
132+
if "opt_state" in restored_checkpoint.wan_state.keys():
133+
opt_state = restored_checkpoint.wan_state["opt_state"]
134+
else:
135+
max_logging.log("No checkpoint found, loading default pipeline.")
136+
pipeline = self.load_diffusers_checkpoint()
137+
138+
return pipeline, opt_state, step
139+
140+
def save_checkpoint(self, train_step, pipeline: WanPipeline2_1, train_states: dict):
141+
"""Saves the training state and model configurations."""
142+
143+
def config_to_json(model_or_config):
144+
return json.loads(model_or_config.to_json_string())
145+
146+
max_logging.log(f"Saving checkpoint for step {train_step}")
147+
items = {
148+
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)),
149+
}
150+
151+
items["wan_state"] = ocp.args.PyTreeSave(train_states)
152+
153+
# Save the checkpoint
154+
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
155+
max_logging.log(f"Checkpoint for step {train_step} saved.")
156+
157+
158+
class WanCheckpointer2_2(WanCheckpointer):
159+
160+
def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
161+
if step is None:
162+
step = self.checkpoint_manager.latest_step()
163+
max_logging.log(f"Latest WAN checkpoint step: {step}")
164+
if step is None:
165+
max_logging.log("No WAN checkpoint found.")
166+
return None, None
167+
max_logging.log(f"Loading WAN checkpoint from step {step}")
168+
metadatas = self.checkpoint_manager.item_metadata(step)
169+
170+
# Handle low_noise_transformer
171+
low_noise_transformer_metadata = metadatas.low_noise_transformer_state
172+
abstract_tree_structure_low_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata)
173+
low_params_restore = ocp.args.PyTreeRestore(
174+
restore_args=jax.tree.map(
175+
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
176+
abstract_tree_structure_low_params,
177+
)
178+
)
179+
180+
# Handle high_noise_transformer
181+
high_noise_transformer_metadata = metadatas.high_noise_transformer_state
182+
abstract_tree_structure_high_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata)
183+
high_params_restore = ocp.args.PyTreeRestore(
184+
restore_args=jax.tree.map(
185+
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
186+
abstract_tree_structure_high_params,
187+
)
188+
)
189+
190+
max_logging.log("Restoring WAN 2.2 checkpoint")
191+
restored_checkpoint = self.checkpoint_manager.restore(
192+
directory=epath.Path(self.config.checkpoint_dir),
193+
step=step,
194+
args=ocp.args.Composite(
195+
low_noise_transformer_state=low_params_restore,
196+
high_noise_transformer_state=high_params_restore,
197+
wan_config=ocp.args.JsonRestore(),
198+
),
199+
)
200+
max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
201+
max_logging.log(f"restored checkpoint low_noise_transformer_state {restored_checkpoint.low_noise_transformer_state.keys()}")
202+
max_logging.log(f"restored checkpoint high_noise_transformer_state {restored_checkpoint.high_noise_transformer_state.keys()}")
203+
max_logging.log(f"optimizer found in low_noise checkpoint {'opt_state' in restored_checkpoint.low_noise_transformer_state.keys()}")
204+
max_logging.log(f"optimizer found in high_noise checkpoint {'opt_state' in restored_checkpoint.high_noise_transformer_state.keys()}")
205+
max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
206+
return restored_checkpoint, step
207+
208+
def load_diffusers_checkpoint(self):
209+
pipeline = WanPipeline2_2.from_pretrained(self.config)
103210
return pipeline
104211

105-
def load_checkpoint(self, step=None) -> Tuple[WanPipeline, Optional[dict], Optional[int]]:
212+
def load_checkpoint(self, step=None) -> Tuple[WanPipeline2_2, Optional[dict], Optional[int]]:
106213
restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
107214
opt_state = None
108215
if restored_checkpoint:
109216
max_logging.log("Loading WAN pipeline from checkpoint")
110-
pipeline = WanPipeline.from_checkpoint(self.config, restored_checkpoint)
111-
if "opt_state" in restored_checkpoint["wan_state"].keys():
112-
opt_state = restored_checkpoint["wan_state"]["opt_state"]
217+
pipeline = WanPipeline2_2.from_checkpoint(self.config, restored_checkpoint)
218+
# Check for optimizer state in either transformer
219+
if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys():
220+
opt_state = restored_checkpoint.low_noise_transformer_state["opt_state"]
221+
elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys():
222+
opt_state = restored_checkpoint.high_noise_transformer_state["opt_state"]
113223
else:
114224
max_logging.log("No checkpoint found, loading default pipeline.")
115225
pipeline = self.load_diffusers_checkpoint()
116226

117227
return pipeline, opt_state, step
118228

119-
def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict):
229+
def save_checkpoint(self, train_step, pipeline: WanPipeline2_2, train_states: dict):
120230
"""Saves the training state and model configurations."""
121231

122232
def config_to_json(model_or_config):
@@ -127,22 +237,17 @@ def config_to_json(model_or_config):
127237
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)),
128238
}
129239

130-
if "low_noise_transformer" in train_states:
131-
low_noise_state = train_states["low_noise_transformer"]
132-
items["low_noise_transformer_state"] = ocp.args.PyTreeSave(low_noise_state)
240+
items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"])
241+
items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"])
133242

134-
if self.run_wan2_2:
135-
if "high_noise_transformer" in train_states:
136-
high_noise_state = train_states["high_noise_transformer"]
137-
items["high_noise_transformer_state"] = ocp.args.PyTreeSave(high_noise_state)
138-
139243
# Save the checkpoint
140-
if len(items) > 1:
141-
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
142-
max_logging.log(f"Checkpoint for step {train_step} saved.")
244+
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
245+
max_logging.log(f"Checkpoint for step {train_step} saved.")
143246

247+
WanCheckpointer._SUBCLASS_MAP["wan2.1"] = WanCheckpointer2_1
248+
WanCheckpointer._SUBCLASS_MAP["wan2.2"] = WanCheckpointer2_2
144249

145-
def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict):
250+
def save_checkpoint_orig(self, train_step, pipeline, train_states: dict):
146251
"""Saves the training state and model configurations."""
147252

148253
def config_to_json(model_or_config):

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -272,9 +272,9 @@ width: 832
272272
num_frames: 81
273273
flow_shift: 3.0
274274

275-
guidance_scale_low: 5.0
276-
guidance_scale_high: 8.0
277-
boundary_timestep: 15
275+
guidance_scale_low: 3.0
276+
guidance_scale_high: 4.0
277+
boundary_timestep: 875
278278

279279
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
280280
guidance_rescale: 0.0

0 commit comments

Comments
 (0)