Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 70 additions & 46 deletions src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,6 +1152,66 @@ def _denormalize_latents(
latents = latents * latents_std / scaling_factor + latents_mean
return latents

def _decode_latents_to_video(
self,
latents: jax.Array,
generator: jax.Array,
batch_size: int,
decode_timestep: Union[float, List[float]],
decode_noise_scale: Optional[Union[float, List[float]]],
output_type: str,
replicate_vae: bool,
):
if replicate_vae:
try:
mesh = latents.sharding.mesh
replicated_sharding = NamedSharding(mesh, P())
# Replicate VAE weights
graphdef, state = nnx.split(self.vae)
state = jax.tree_util.tree_map(
lambda x: jax.lax.with_sharding_constraint(x, replicated_sharding) if isinstance(x, jax.Array) else x, state
)
self.vae = nnx.merge(graphdef, state)
except Exception as e: # pylint: disable=broad-exception-caught
max_logging.log(f"[Tuning] Failed to replicate VAE weights: {e}")

t0_video_vae = time.perf_counter()
with jax.named_scope("video_vae_decode"):
if getattr(self.vae.config, "timestep_conditioning", False):
noise = jax.random.normal(generator, latents.shape, dtype=latents.dtype)

if not isinstance(decode_timestep, list):
decode_timestep = [decode_timestep] * batch_size
if decode_noise_scale is None:
decode_noise_scale = decode_timestep
elif not isinstance(decode_noise_scale, list):
decode_noise_scale = [decode_noise_scale] * batch_size

timestep = jnp.array(decode_timestep, dtype=latents.dtype)
decode_noise_scale = jnp.array(decode_noise_scale, dtype=latents.dtype)[:, None, None, None, None]

latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise

latents = latents.astype(self.vae.dtype)
video = self.vae.decode(latents, temb=timestep, return_dict=False)[0]
else:
latents = latents.astype(self.vae.dtype)
video = self.vae.decode(latents, return_dict=False)[0]

video = video.block_until_ready()
video_vae_time = time.perf_counter() - t0_video_vae
max_logging.log(f"Video VAE decode time: {video_vae_time:.2f}s")

# VAE outputs (B, T, H, W, C), but video processor expects (B, C, T, H, W)
t0_video_post = time.perf_counter()
video = jax.experimental.multihost_utils.process_allgather(video, tiled=True)
video_np = np.array(video).transpose(0, 4, 1, 2, 3)
video = self.video_processor.postprocess_video(torch.from_numpy(video_np), output_type=output_type)
video_post_time = time.perf_counter() - t0_video_post
max_logging.log(f"Video Post-processing time (numpy+PIL): {video_post_time:.2f}s")

return video, video_vae_time, video_post_time

@staticmethod
def _normalize_audio_latents(latents: jax.Array, latents_mean: jax.Array, latents_std: jax.Array):
latents_mean = latents_mean.astype(latents.dtype)
Expand Down Expand Up @@ -1849,55 +1909,18 @@ def convert_to_vel(lat, x0, sig):
except Exception as e: # pylint: disable=broad-exception-caught
max_logging.log(f"[Tuning] Failed to apply replicate VAE latents sharding: {e}")

if replicate_vae:
try:
mesh = latents.sharding.mesh
replicated_sharding = NamedSharding(mesh, P())
# Replicate VAE weights
graphdef, state = nnx.split(self.vae)
state = jax.tree_util.tree_map(
lambda x: jax.lax.with_sharding_constraint(x, replicated_sharding) if isinstance(x, jax.Array) else x, state
)
self.vae = nnx.merge(graphdef, state)
except Exception as e: # pylint: disable=broad-exception-caught
max_logging.log(f"[Tuning] Failed to replicate VAE weights: {e}")

latent_processing_time += time.perf_counter() - t0_latent_processing
timings["Latent Processing"] = latent_processing_time

t0_video_vae = time.perf_counter()
with jax.named_scope("video_vae_decode"):
if getattr(self.vae.config, "timestep_conditioning", False):
noise = jax.random.normal(generator, latents.shape, dtype=latents.dtype)

if not isinstance(decode_timestep, list):
decode_timestep = [decode_timestep] * batch_size
if decode_noise_scale is None:
decode_noise_scale = decode_timestep
elif not isinstance(decode_noise_scale, list):
decode_noise_scale = [decode_noise_scale] * batch_size

timestep = jnp.array(decode_timestep, dtype=latents.dtype)
decode_noise_scale = jnp.array(decode_noise_scale, dtype=latents.dtype)[:, None, None, None, None]

latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise

latents = latents.astype(self.vae.dtype)
video = self.vae.decode(latents, temb=timestep, return_dict=False)[0]
else:
latents = latents.astype(self.vae.dtype)
video = self.vae.decode(latents, return_dict=False)[0]

video = video.block_until_ready()
video_vae_time = time.perf_counter() - t0_video_vae
max_logging.log(f"Video VAE decode time: {video_vae_time:.2f}s")

# VAE outputs (B, T, H, W, C), but video processor expects (B, C, T, H, W)
t0_video_post = time.perf_counter()
video_np = np.array(video).transpose(0, 4, 1, 2, 3)
video = self.video_processor.postprocess_video(torch.from_numpy(video_np), output_type=output_type)
video_post_time = time.perf_counter() - t0_video_post
max_logging.log(f"Video Post-processing time (numpy+PIL): {video_post_time:.2f}s")
video, video_vae_time, video_post_time = self._decode_latents_to_video(
latents=latents,
generator=generator,
batch_size=batch_size,
decode_timestep=decode_timestep,
decode_noise_scale=decode_noise_scale,
output_type=output_type,
replicate_vae=replicate_vae,
)

# Decode Audio
t0_audio_vae = time.perf_counter()
Expand All @@ -1917,6 +1940,7 @@ def convert_to_vel(lat, x0, sig):
audio = self._jitted_vocoder(self.vocoder, generated_mel_spectrograms)

# Convert audio to numpy
audio = jax.experimental.multihost_utils.process_allgather(audio, tiled=True)
audio = np.array(audio)
vocoder_time = time.perf_counter() - t0_vocoder
max_logging.log(f"Vocoder & Audio numpy time: {vocoder_time:.2f}s")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import jax
import jax.numpy as jnp
import numpy as np
import torch
from flax import nnx
from flax.core.frozen_dict import FrozenDict
from typing import Dict, List, Optional, Union
Expand Down Expand Up @@ -250,8 +251,9 @@ def __call__(
if video.dtype == jnp.bfloat16:
video = video.astype(jnp.float32)

video = np.transpose(np.array(video), (0, 4, 1, 2, 3))
video = self.video_processor.postprocess_video(video, output_type=output_type)
video = jax.experimental.multihost_utils.process_allgather(video, tiled=True)
video_np = np.transpose(np.array(video), (0, 4, 1, 2, 3))
video = self.video_processor.postprocess_video(torch.from_numpy(video_np), output_type=output_type)

if not return_dict:
return (video,)
Expand Down
Loading