diff --git a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py index 2309c7b4..b1d4a167 100644 --- a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py @@ -1152,6 +1152,66 @@ def _denormalize_latents( latents = latents * latents_std / scaling_factor + latents_mean return latents + def _decode_latents_to_video( + self, + latents: jax.Array, + generator: jax.Array, + batch_size: int, + decode_timestep: Union[float, List[float]], + decode_noise_scale: Optional[Union[float, List[float]]], + output_type: str, + replicate_vae: bool, + ): + if replicate_vae: + try: + mesh = latents.sharding.mesh + replicated_sharding = NamedSharding(mesh, P()) + # Replicate VAE weights + graphdef, state = nnx.split(self.vae) + state = jax.tree_util.tree_map( + lambda x: jax.lax.with_sharding_constraint(x, replicated_sharding) if isinstance(x, jax.Array) else x, state + ) + self.vae = nnx.merge(graphdef, state) + except Exception as e: # pylint: disable=broad-exception-caught + max_logging.log(f"[Tuning] Failed to replicate VAE weights: {e}") + + t0_video_vae = time.perf_counter() + with jax.named_scope("video_vae_decode"): + if getattr(self.vae.config, "timestep_conditioning", False): + noise = jax.random.normal(generator, latents.shape, dtype=latents.dtype) + + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = jnp.array(decode_timestep, dtype=latents.dtype) + decode_noise_scale = jnp.array(decode_noise_scale, dtype=latents.dtype)[:, None, None, None, None] + + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + latents = latents.astype(self.vae.dtype) + video = self.vae.decode(latents, temb=timestep, return_dict=False)[0] + else: + latents = latents.astype(self.vae.dtype) + video = self.vae.decode(latents, return_dict=False)[0] + + video = video.block_until_ready() + video_vae_time = time.perf_counter() - t0_video_vae + max_logging.log(f"Video VAE decode time: {video_vae_time:.2f}s") + + # VAE outputs (B, T, H, W, C), but video processor expects (B, C, T, H, W) + t0_video_post = time.perf_counter() + video = jax.experimental.multihost_utils.process_allgather(video, tiled=True) + video_np = np.array(video).transpose(0, 4, 1, 2, 3) + video = self.video_processor.postprocess_video(torch.from_numpy(video_np), output_type=output_type) + video_post_time = time.perf_counter() - t0_video_post + max_logging.log(f"Video Post-processing time (numpy+PIL): {video_post_time:.2f}s") + + return video, video_vae_time, video_post_time + @staticmethod def _normalize_audio_latents(latents: jax.Array, latents_mean: jax.Array, latents_std: jax.Array): latents_mean = latents_mean.astype(latents.dtype) @@ -1849,55 +1909,18 @@ def convert_to_vel(lat, x0, sig): except Exception as e: # pylint: disable=broad-exception-caught max_logging.log(f"[Tuning] Failed to apply replicate VAE latents sharding: {e}") - if replicate_vae: - try: - mesh = latents.sharding.mesh - replicated_sharding = NamedSharding(mesh, P()) - # Replicate VAE weights - graphdef, state = nnx.split(self.vae) - state = jax.tree_util.tree_map( - lambda x: jax.lax.with_sharding_constraint(x, replicated_sharding) if isinstance(x, jax.Array) else x, state - ) - self.vae = nnx.merge(graphdef, state) - except Exception as e: # pylint: disable=broad-exception-caught - max_logging.log(f"[Tuning] Failed to replicate VAE weights: {e}") - latent_processing_time += time.perf_counter() - t0_latent_processing timings["Latent Processing"] = latent_processing_time - t0_video_vae = time.perf_counter() - with jax.named_scope("video_vae_decode"): - if getattr(self.vae.config, "timestep_conditioning", False): - noise = jax.random.normal(generator, latents.shape, dtype=latents.dtype) - - if not isinstance(decode_timestep, list): - decode_timestep = [decode_timestep] * batch_size - if decode_noise_scale is None: - decode_noise_scale = decode_timestep - elif not isinstance(decode_noise_scale, list): - decode_noise_scale = [decode_noise_scale] * batch_size - - timestep = jnp.array(decode_timestep, dtype=latents.dtype) - decode_noise_scale = jnp.array(decode_noise_scale, dtype=latents.dtype)[:, None, None, None, None] - - latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise - - latents = latents.astype(self.vae.dtype) - video = self.vae.decode(latents, temb=timestep, return_dict=False)[0] - else: - latents = latents.astype(self.vae.dtype) - video = self.vae.decode(latents, return_dict=False)[0] - - video = video.block_until_ready() - video_vae_time = time.perf_counter() - t0_video_vae - max_logging.log(f"Video VAE decode time: {video_vae_time:.2f}s") - - # VAE outputs (B, T, H, W, C), but video processor expects (B, C, T, H, W) - t0_video_post = time.perf_counter() - video_np = np.array(video).transpose(0, 4, 1, 2, 3) - video = self.video_processor.postprocess_video(torch.from_numpy(video_np), output_type=output_type) - video_post_time = time.perf_counter() - t0_video_post - max_logging.log(f"Video Post-processing time (numpy+PIL): {video_post_time:.2f}s") + video, video_vae_time, video_post_time = self._decode_latents_to_video( + latents=latents, + generator=generator, + batch_size=batch_size, + decode_timestep=decode_timestep, + decode_noise_scale=decode_noise_scale, + output_type=output_type, + replicate_vae=replicate_vae, + ) # Decode Audio t0_audio_vae = time.perf_counter() @@ -1917,6 +1940,7 @@ def convert_to_vel(lat, x0, sig): audio = self._jitted_vocoder(self.vocoder, generated_mel_spectrograms) # Convert audio to numpy + audio = jax.experimental.multihost_utils.process_allgather(audio, tiled=True) audio = np.array(audio) vocoder_time = time.perf_counter() - t0_vocoder max_logging.log(f"Vocoder & Audio numpy time: {vocoder_time:.2f}s") diff --git a/src/maxdiffusion/pipelines/ltx2/pipeline_ltx2_latent_upsample.py b/src/maxdiffusion/pipelines/ltx2/pipeline_ltx2_latent_upsample.py index 4a775218..669fc952 100644 --- a/src/maxdiffusion/pipelines/ltx2/pipeline_ltx2_latent_upsample.py +++ b/src/maxdiffusion/pipelines/ltx2/pipeline_ltx2_latent_upsample.py @@ -17,6 +17,7 @@ import jax import jax.numpy as jnp import numpy as np +import torch from flax import nnx from flax.core.frozen_dict import FrozenDict from typing import Dict, List, Optional, Union @@ -250,8 +251,9 @@ def __call__( if video.dtype == jnp.bfloat16: video = video.astype(jnp.float32) - video = np.transpose(np.array(video), (0, 4, 1, 2, 3)) - video = self.video_processor.postprocess_video(video, output_type=output_type) + video = jax.experimental.multihost_utils.process_allgather(video, tiled=True) + video_np = np.transpose(np.array(video), (0, 4, 1, 2, 3)) + video = self.video_processor.postprocess_video(torch.from_numpy(video_np), output_type=output_type) if not return_dict: return (video,)