Skip to content

Commit 7337eea

Browse files
authored
[refactor] unnecessary lines in decode_latents in video pipelines (#6682)
* refactor decode latents in video pipelines * make fix-copies
1 parent f07899a commit 7337eea

File tree

5 files changed

+5
-60
lines changed

5 files changed

+5
-60
lines changed

src/diffusers/pipelines/animatediff/pipeline_animatediff.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -469,18 +469,7 @@ def decode_latents(self, latents):
469469
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
470470

471471
image = self.vae.decode(latents).sample
472-
video = (
473-
image[None, :]
474-
.reshape(
475-
(
476-
batch_size,
477-
num_frames,
478-
-1,
479-
)
480-
+ image.shape[2:]
481-
)
482-
.permute(0, 2, 1, 3, 4)
483-
)
472+
video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
484473
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
485474
video = video.float()
486475
return video

src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -445,18 +445,7 @@ def decode_latents(self, latents):
445445
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
446446

447447
image = self.vae.decode(latents).sample
448-
video = (
449-
image[None, :]
450-
.reshape(
451-
(
452-
batch_size,
453-
num_frames,
454-
-1,
455-
)
456-
+ image.shape[2:]
457-
)
458-
.permute(0, 2, 1, 3, 4)
459-
)
448+
video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
460449
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
461450
video = video.float()
462451
return video

src/diffusers/pipelines/pia/pipeline_pia.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -494,18 +494,7 @@ def decode_latents(self, latents):
494494
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
495495

496496
image = self.vae.decode(latents).sample
497-
video = (
498-
image[None, :]
499-
.reshape(
500-
(
501-
batch_size,
502-
num_frames,
503-
-1,
504-
)
505-
+ image.shape[2:]
506-
)
507-
.permute(0, 2, 1, 3, 4)
508-
)
497+
video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
509498
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
510499
video = video.float()
511500
return video

src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -384,18 +384,7 @@ def decode_latents(self, latents):
384384
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
385385

386386
image = self.vae.decode(latents).sample
387-
video = (
388-
image[None, :]
389-
.reshape(
390-
(
391-
batch_size,
392-
num_frames,
393-
-1,
394-
)
395-
+ image.shape[2:]
396-
)
397-
.permute(0, 2, 1, 3, 4)
398-
)
387+
video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
399388
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
400389
video = video.float()
401390
return video

src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -461,18 +461,7 @@ def decode_latents(self, latents):
461461
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
462462

463463
image = self.vae.decode(latents).sample
464-
video = (
465-
image[None, :]
466-
.reshape(
467-
(
468-
batch_size,
469-
num_frames,
470-
-1,
471-
)
472-
+ image.shape[2:]
473-
)
474-
.permute(0, 2, 1, 3, 4)
475-
)
464+
video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
476465
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
477466
video = video.float()
478467
return video

0 commit comments

Comments
 (0)