1414 limitations under the License.
1515"""
1616
17- from abc import ABC
17+ from abc import ABC , abstractmethod
1818import json
1919
2020import jax
2121import numpy as np
22- from typing import Optional , Tuple
22+ from typing import Optional , Tuple , Type
2323from maxdiffusion .checkpointing .checkpointing_utils import (create_orbax_checkpoint_manager )
24- from ..pipelines .wan .wan_pipeline import WanPipeline
24+ from ..pipelines .wan .wan_pipeline import WanPipeline2_1 , WanPipeline2_2
2525from .. import max_logging , max_utils
2626import orbax .checkpoint as ocp
2727from etils import epath
2828
29+
2930WAN_CHECKPOINT = "WAN_CHECKPOINT"
3031
3132
3233class WanCheckpointer (ABC ):
34+ _SUBCLASS_MAP : dict [str , Type ['WanCheckpointer' ]] = {}
35+
def __new__(cls, model_key: str, config, checkpoint_type: str = WAN_CHECKPOINT):
    """Allocate a checkpointer, dispatching on ``model_key`` for the base class.

    When invoked on ``WanCheckpointer`` itself, ``model_key`` is looked up in
    ``_SUBCLASS_MAP`` and an instance of the registered subclass is allocated.
    When invoked on any other class, that class is allocated directly and the
    registry is not consulted.

    Args:
      model_key: Key used to select the subclass from ``_SUBCLASS_MAP``.
      config: Not used during allocation.
      checkpoint_type: Not used during allocation.

    Returns:
      A new, not yet initialised instance.

    Raises:
      ValueError: If called on ``WanCheckpointer`` with a ``model_key`` that is
        not present in ``_SUBCLASS_MAP``.
    """
    # Direct construction of any other class skips the registry lookup.
    if cls is not WanCheckpointer:
        return super().__new__(cls)

    target = cls._SUBCLASS_MAP.get(model_key)
    if target is not None:
        return super().__new__(target)

    supported = list(cls._SUBCLASS_MAP.keys())
    raise ValueError(
        f"Unknown model_key: '{model_key} '. "
        f"Supported keys are: {supported} "
    )
3347
def __init__(self, model_key, config, checkpoint_type: str = WAN_CHECKPOINT):
    """Store the configuration and create the Orbax checkpoint manager.

    Args:
      model_key: Accepted so the signature matches ``__new__``; it is not read
        or stored here.
      config: Configuration object. ``checkpoint_dir`` and ``dataset_type`` are
        read from it, and the object itself is kept as ``self.config``.
      checkpoint_type: Forwarded to ``create_orbax_checkpoint_manager`` and kept
        as ``self.checkpoint_type``.
    """
    self.config = config
    self.checkpoint_type = checkpoint_type
    self.opt_state = None

    checkpoint_dir = self.config.checkpoint_dir
    manager_options = dict(
        enable_checkpointing=True,
        save_interval_steps=1,
        checkpoint_type=checkpoint_type,
        dataset_type=config.dataset_type,
    )
    self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager(
        checkpoint_dir, **manager_options
    )
4762
4863 def _create_optimizer (self , model , config , learning_rate ):
@@ -52,6 +67,25 @@ def _create_optimizer(self, model, config, learning_rate):
5267 tx = max_utils .create_optimizer (config , learning_rate_scheduler )
5368 return tx , learning_rate_scheduler
5469
@abstractmethod
def load_wan_configs_from_orbax(
    self, step: Optional[int]
) -> Tuple[Optional[dict], Optional[int]]:
    """Restore the saved checkpoint items for ``step`` (abstract).

    Args:
      step: Checkpoint step to restore, or ``None``.

    Returns:
      A ``(restored_checkpoint, step)`` pair; either element may be ``None``.

    Raises:
      NotImplementedError: Always; concrete subclasses must override this.
    """
    raise NotImplementedError
73+
@abstractmethod
def load_diffusers_checkpoint(self):
    """Build the default pipeline when no checkpoint is restored (abstract).

    Raises:
      NotImplementedError: Always; concrete subclasses must override this.
    """
    raise NotImplementedError
77+
@abstractmethod
def load_checkpoint(
    self, step=None
) -> Tuple[Optional[WanPipeline2_1 | WanPipeline2_2], Optional[dict], Optional[int]]:
    """Load a pipeline together with its optimizer state and step (abstract).

    Args:
      step: Checkpoint step to load; defaults to ``None``.

    Returns:
      A ``(pipeline, opt_state, step)`` triple; each element may be ``None``.

    Raises:
      NotImplementedError: Always; concrete subclasses must override this.
    """
    raise NotImplementedError
81+
@abstractmethod
def save_checkpoint(
    self,
    train_step,
    pipeline,
    train_states: dict,
):
    """Persist the training state for ``train_step`` (abstract).

    Args:
      train_step: Step under which the checkpoint is saved.
      pipeline: Pipeline whose model configuration is saved.
      train_states: Training state to save.

    Raises:
      NotImplementedError: Always; concrete subclasses must override this.
    """
    raise NotImplementedError
85+
86+
87+ class WanCheckpointer2_1 (WanCheckpointer ):
88+
5589 def load_wan_configs_from_orbax (self , step : Optional [int ]) -> Tuple [Optional [dict ], Optional [int ]]:
5690 if step is None :
5791 step = self .checkpoint_manager .latest_step ()
@@ -61,36 +95,23 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
6195 return None , None
6296 max_logging .log (f"Loading WAN checkpoint from step { step } " )
6397 metadatas = self .checkpoint_manager .item_metadata (step )
64-
65- restore_args = {}
66-
67- low_state_metadata = metadatas .low_noise_transformer_state
68- abstract_tree_structure_low_state = jax .tree_util .tree_map (ocp .utils .to_shape_dtype_struct , low_state_metadata )
69- low_state_restore = ocp .args .PyTreeRestore (
98+ transformer_metadata = metadatas .wan_state
99+ abstract_tree_structure_params = jax .tree_util .tree_map (ocp .utils .to_shape_dtype_struct , transformer_metadata )
100+ params_restore = ocp .args .PyTreeRestore (
70101 restore_args = jax .tree .map (
71102 lambda _ : ocp .RestoreArgs (restore_type = np .ndarray ),
72- abstract_tree_structure_low_state ,
103+ abstract_tree_structure_params ,
73104 )
74105 )
75- restore_args ["low_noise_transformer_state" ] = low_state_restore
76-
77- if self .run_wan2_2 :
78- high_state_metadata = metadatas .high_noise_transformer_state
79- abstract_tree_structure_high_state = jax .tree_util .tree_map (ocp .utils .to_shape_dtype_struct , high_state_metadata )
80- high_state_restore = ocp .args .PyTreeRestore (
81- restore_args = jax .tree .map (
82- lambda _ : ocp .RestoreArgs (restore_type = np .ndarray ),
83- abstract_tree_structure_high_state ,
84- )
85- )
86- restore_args ["high_noise_transformer_state" ] = high_state_restore
87-
88- restore_args ["wan_config" ] = ocp .args .JsonRestore ()
89106
90107 max_logging .log ("Restoring WAN checkpoint" )
91108 restored_checkpoint = self .checkpoint_manager .restore (
109+ directory = epath .Path (self .config .checkpoint_dir ),
92110 step = step ,
93- args = ocp .args .Composite (** restore_args ),
111+ args = ocp .args .Composite (
112+ wan_state = params_restore ,
113+ wan_config = ocp .args .JsonRestore (),
114+ ),
94115 )
95116 max_logging .log (f"restored checkpoint { restored_checkpoint .keys ()} " )
96117 max_logging .log (f"restored checkpoint wan_state { restored_checkpoint .wan_state .keys ()} " )
@@ -99,24 +120,113 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
99120 return restored_checkpoint , step
100121
def load_diffusers_checkpoint(self):
    """Return a ``WanPipeline2_1`` built by ``from_pretrained`` from ``self.config``."""
    return WanPipeline2_1.from_pretrained(self.config)
125+
def load_checkpoint(self, step=None) -> Tuple[WanPipeline2_1, Optional[dict], Optional[int]]:
    """Load a WAN 2.1 pipeline from Orbax, or the default pipeline if none is restored.

    Args:
      step: Step passed through to ``load_wan_configs_from_orbax``.

    Returns:
      ``(pipeline, opt_state, step)``. ``opt_state`` is the ``"opt_state"`` entry
      of the restored ``wan_state`` when present, otherwise ``None``. ``step`` is
      the value returned by ``load_wan_configs_from_orbax``.
    """
    restored_checkpoint, step = self.load_wan_configs_from_orbax(step)

    # Nothing restored: fall back to the default pipeline, with no optimizer state.
    if not restored_checkpoint:
        max_logging.log("No checkpoint found, loading default pipeline.")
        return self.load_diffusers_checkpoint(), None, step

    max_logging.log("Loading WAN pipeline from checkpoint")
    pipeline = WanPipeline2_1.from_checkpoint(self.config, restored_checkpoint)

    wan_state = restored_checkpoint.wan_state
    opt_state = wan_state["opt_state"] if "opt_state" in wan_state.keys() else None
    return pipeline, opt_state, step
139+
def save_checkpoint(self, train_step, pipeline: WanPipeline2_1, train_states: dict):
    """Save the transformer config and the training state for ``train_step``.

    Two items are written through ``self.checkpoint_manager``:
    ``wan_config`` (the JSON form of ``pipeline.transformer``) and
    ``wan_state`` (``train_states`` as a PyTree).

    Args:
      train_step: Step under which the checkpoint is saved.
      pipeline: Pipeline whose ``transformer`` provides ``to_json_string()``.
      train_states: Training state saved as ``wan_state``.
    """
    max_logging.log(f"Saving checkpoint for step {train_step} ")

    # Round-trip through JSON text so the config is stored as plain JSON data.
    transformer_config = json.loads(pipeline.transformer.to_json_string())
    checkpoint_args = ocp.args.Composite(
        wan_config=ocp.args.JsonSave(transformer_config),
        wan_state=ocp.args.PyTreeSave(train_states),
    )

    self.checkpoint_manager.save(train_step, args=checkpoint_args)
    max_logging.log(f"Checkpoint for step {train_step} saved.")
156+
157+
158+ class WanCheckpointer2_2 (WanCheckpointer ):
159+
def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
    """Restore both WAN 2.2 transformer states and the WAN config from Orbax.

    Args:
      step: Step to restore. When ``None``, the manager's latest step is used.

    Returns:
      ``(restored_checkpoint, step)``, or ``(None, None)`` when no step was
      given and the manager reports no latest step. This method only logs
      ``self.opt_state``; it does not assign it.
    """

    def numpy_restore_for(item_metadata):
        # Mirror the saved tree as shape/dtype structs, then ask for every
        # leaf to be restored as a NumPy array.
        abstract_tree = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, item_metadata)
        leaf_args = jax.tree.map(
            lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
            abstract_tree,
        )
        return ocp.args.PyTreeRestore(restore_args=leaf_args)

    if step is None:
        step = self.checkpoint_manager.latest_step()
        max_logging.log(f"Latest WAN checkpoint step: {step} ")
        if step is None:
            max_logging.log("No WAN checkpoint found.")
            return None, None

    max_logging.log(f"Loading WAN checkpoint from step {step} ")
    metadatas = self.checkpoint_manager.item_metadata(step)

    # Low-noise transformer first, then high-noise, as in the saved layout.
    low_params_restore = numpy_restore_for(metadatas.low_noise_transformer_state)
    high_params_restore = numpy_restore_for(metadatas.high_noise_transformer_state)

    max_logging.log("Restoring WAN 2.2 checkpoint")
    composite_args = ocp.args.Composite(
        low_noise_transformer_state=low_params_restore,
        high_noise_transformer_state=high_params_restore,
        wan_config=ocp.args.JsonRestore(),
    )
    restored_checkpoint = self.checkpoint_manager.restore(
        directory=epath.Path(self.config.checkpoint_dir),
        step=step,
        args=composite_args,
    )

    max_logging.log(f"restored checkpoint {restored_checkpoint.keys()} ")
    low_state = restored_checkpoint.low_noise_transformer_state
    max_logging.log(f"restored checkpoint low_noise_transformer_state {low_state.keys()} ")
    high_state = restored_checkpoint.high_noise_transformer_state
    max_logging.log(f"restored checkpoint high_noise_transformer_state {high_state.keys()} ")
    low_has_optimizer = "opt_state" in low_state.keys()
    max_logging.log(f"optimizer found in low_noise checkpoint {low_has_optimizer} ")
    high_has_optimizer = "opt_state" in high_state.keys()
    max_logging.log(f"optimizer found in high_noise checkpoint {high_has_optimizer} ")
    max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state} ")
    return restored_checkpoint, step
207+
def load_diffusers_checkpoint(self):
    """Return a ``WanPipeline2_2`` built by ``from_pretrained`` from ``self.config``."""
    return WanPipeline2_2.from_pretrained(self.config)
104211
def load_checkpoint(self, step=None) -> Tuple[WanPipeline2_2, Optional[dict], Optional[int]]:
    """Load a WAN 2.2 pipeline from Orbax, or the default pipeline if none is restored.

    Args:
      step: Step passed through to ``load_wan_configs_from_orbax``.

    Returns:
      ``(pipeline, opt_state, step)``. ``opt_state`` comes from the low-noise
      transformer state when it has an ``"opt_state"`` entry, otherwise from
      the high-noise state, otherwise it is ``None``. ``step`` is the value
      returned by ``load_wan_configs_from_orbax``.
    """
    restored_checkpoint, step = self.load_wan_configs_from_orbax(step)

    # Nothing restored: fall back to the default pipeline, with no optimizer state.
    if not restored_checkpoint:
        max_logging.log("No checkpoint found, loading default pipeline.")
        return self.load_diffusers_checkpoint(), None, step

    max_logging.log("Loading WAN pipeline from checkpoint")
    pipeline = WanPipeline2_2.from_checkpoint(self.config, restored_checkpoint)

    # The low-noise state takes precedence; the high-noise state is only
    # inspected when the low-noise one carries no optimizer state.
    opt_state = None
    for state_name in ("low_noise_transformer_state", "high_noise_transformer_state"):
        transformer_state = getattr(restored_checkpoint, state_name)
        if "opt_state" in transformer_state.keys():
            opt_state = transformer_state["opt_state"]
            break

    return pipeline, opt_state, step
118228
119- def save_checkpoint (self , train_step , pipeline : WanPipeline , train_states : dict ):
229+ def save_checkpoint (self , train_step , pipeline : WanPipeline2_2 , train_states : dict ):
120230 """Saves the training state and model configurations."""
121231
122232 def config_to_json (model_or_config ):
@@ -127,22 +237,17 @@ def config_to_json(model_or_config):
127237 "wan_config" : ocp .args .JsonSave (config_to_json (pipeline .low_noise_transformer )),
128238 }
129239
130- if "low_noise_transformer" in train_states :
131- low_noise_state = train_states ["low_noise_transformer" ]
132- items ["low_noise_transformer_state" ] = ocp .args .PyTreeSave (low_noise_state )
240+ items ["low_noise_transformer_state" ] = ocp .args .PyTreeSave (train_states ["low_noise_transformer" ])
241+ items ["high_noise_transformer_state" ] = ocp .args .PyTreeSave (train_states ["high_noise_transformer" ])
133242
134- if self .run_wan2_2 :
135- if "high_noise_transformer" in train_states :
136- high_noise_state = train_states ["high_noise_transformer" ]
137- items ["high_noise_transformer_state" ] = ocp .args .PyTreeSave (high_noise_state )
138-
139243 # Save the checkpoint
140- if len (items ) > 1 :
141- self .checkpoint_manager .save (train_step , args = ocp .args .Composite (** items ))
142- max_logging .log (f"Checkpoint for step { train_step } saved." )
244+ self .checkpoint_manager .save (train_step , args = ocp .args .Composite (** items ))
245+ max_logging .log (f"Checkpoint for step { train_step } saved." )
143246
# Register the concrete checkpointers under the model keys that
# WanCheckpointer.__new__ looks up in _SUBCLASS_MAP.
WanCheckpointer._SUBCLASS_MAP.update(
    {
        "wan2.1": WanCheckpointer2_1,
        "wan2.2": WanCheckpointer2_2,
    }
)
144249
145- def save_checkpoint_orig (self , train_step , pipeline : WanPipeline , train_states : dict ):
250+ def save_checkpoint_orig (self , train_step , pipeline , train_states : dict ):
146251 """Saves the training state and model configurations."""
147252
148253 def config_to_json (model_or_config ):
0 commit comments