Skip to content

Commit 56b1761

Browse files
committed
more fixes
1 parent 67212c1 commit 56b1761

File tree

3 files changed

+9
-18
lines changed

3 files changed

+9
-18
lines changed

src/maxdiffusion/generate_wan.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,8 @@
1818
import os
1919
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
2020
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
21-
from functools import partial
2221
from maxdiffusion import pyconfig, max_logging, max_utils
2322
from absl import app
24-
from absl import flags
2523
from maxdiffusion.utils import export_to_video
2624
from google.cloud import storage
2725
import flax
@@ -127,7 +125,7 @@ def run(config, pipeline=None, filename_prefix=""):
127125
# Using global_batch_size_to_train_on so not to create more config variables
128126
prompt = [config.prompt] * config.global_batch_size_to_train_on
129127
negative_prompt = [config.negative_prompt] * config.global_batch_size_to_train_on
130-
128+
131129
max_logging.log(
132130
f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}"
133131
)
@@ -162,4 +160,4 @@ def main(argv: Sequence[str]) -> None:
162160

163161

164162
if __name__ == "__main__":
165-
app.run(main)
163+
app.run(main)

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from abc import ABC, abstractmethod
15+
from abc import abstractmethod
1616
from typing import List, Union, Optional, Type
1717
from functools import partial
1818
import numpy as np
@@ -466,7 +466,7 @@ def _denormalize_latents(self, latents: jax.Array) -> jax.Array:
466466
latents = latents / latents_std + latents_mean
467467
latents = latents.astype(jnp.float32)
468468
return latents
469-
469+
470470
def _decode_latents_to_video(self, latents: jax.Array) -> np.ndarray:
471471
"""Decodes latents to video frames and postprocesses."""
472472
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
@@ -508,7 +508,7 @@ def _get_subclass(cls, model_key: str) -> Type['WanPipeline']:
508508
f"Supported keys are: {list(cls._SUBCLASS_MAP.keys())}"
509509
)
510510
return subclass
511-
511+
512512
@classmethod
513513
def from_checkpoint(cls, model_key: str, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True):
514514
subclass = cls._get_subclass(model_key)
@@ -708,7 +708,6 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t
708708
common_components = cls._create_common_components(config, vae_only)
709709
low_noise_transformer, high_noise_transformer = None, None
710710
if not vae_only and load_transformer:
711-
rngs = nnx.Rngs(jax.random.key(config.seed))
712711
low_noise_transformer = super().load_transformer(
713712
devices_array=common_components["devices_array"],
714713
mesh=common_components["mesh"],

src/maxdiffusion/tests/wan_checkpointer_test.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,7 @@
1414
import unittest
1515
from unittest.mock import patch, MagicMock
1616

17-
from maxdiffusion.checkpointing.wan_checkpointer import (
18-
WanCheckpointer2_1,
19-
WanCheckpointer2_2,
20-
WAN_CHECKPOINT
21-
)
22-
17+
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer2_1, WanCheckpointer2_2
2318

2419
class WanCheckpointer2_1Test(unittest.TestCase):
2520
"""Tests for WAN 2.1 checkpointer."""
@@ -240,15 +235,14 @@ def test_load_checkpoint_with_optimizer_in_high_noise(self, mock_wan_pipeline, m
240235
self.assertEqual(step, 1)
241236

242237

243-
class WanCheckpointerFactoryTest(unittest.TestCase):
244-
"""Tests for checkpointer factory/selection logic."""
238+
class WanCheckpointerEdgeCasesTest(unittest.TestCase):
239+
"""Tests for edge cases and error handling."""
245240

246241
def setUp(self):
247242
self.config = MagicMock()
248-
self.config.checkpoint_dir = "/tmp/wan_checkpoint_factory_test"
243+
self.config.checkpoint_dir = "/tmp/wan_checkpoint_edge_test"
249244
self.config.dataset_type = "test_dataset"
250245

251-
252246
@patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager")
253247
@patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_1")
254248
def test_load_checkpoint_with_explicit_none_step(self, mock_wan_pipeline, mock_create_manager):

0 commit comments

Comments
 (0)