54 changes: 51 additions & 3 deletions diffsynth/diffusion/runner.py
@@ -3,7 +3,13 @@
from accelerate import Accelerator
from .training_module import DiffusionTrainingModule
from .logger import ModelLogger

import time
from ..utils.profiling.flops_profiler import (
print_model_profile,
get_flops,
profile_entire_model,
unprofile_entire_model,
)

def launch_training_task(
accelerator: Accelerator,
@@ -29,21 +35,63 @@ def launch_training_task(
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
model.to(device=accelerator.device)
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)


train_step = 0
profile_entire_model(model)

for epoch_id in range(num_epochs):
for data in tqdm(dataloader):
progress = tqdm(
dataloader,
disable=not accelerator.is_main_process,
desc=f"Epoch {epoch_id + 1}/{num_epochs}",
)

for data in progress:
iter_start = time.time()
if data is None:
continue

with accelerator.accumulate(model):
optimizer.zero_grad()

if dataset.load_from_cache:
loss = model({}, inputs=data)
else:
loss = model(data)

t5_Tflops, wan_Tflops, vae_Tflops = get_flops(model)
accelerator.backward(loss)
optimizer.step()

model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
scheduler.step()

torch.cuda.synchronize()
time_step = time.time() - iter_start
train_step += 1

total_flops = t5_Tflops + wan_Tflops + vae_Tflops
TFLOPS = total_flops * 3 / time_step

if accelerator.is_main_process:
postfix_dict = {
"Rank": f"{accelerator.process_index}",
"loss": f"{loss.item():.5f}",
"lr": f"{optimizer.param_groups[0]['lr']:.5e}",
"step/t": f"{time_step:.3f}",
"[t5] Tflops": f"{t5_Tflops:.3f}",
"[dit] Tflops": f"{wan_Tflops:.3f}",
"[vae] Tflops": f"{vae_Tflops:.3f}",
"TFLOPS": f"{TFLOPS:.3f}",
}
progress.set_postfix(postfix_dict)
log_msg = f"[Step {train_step:6d}] | " + " | ".join(f"{k}: {v}" for k, v in postfix_dict.items())
progress.write(log_msg)

if save_steps is None:
model_logger.on_epoch_end(accelerator, model, epoch_id)

unprofile_entire_model(model)
model_logger.on_training_end(accelerator, model, save_steps)
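
# Note on the throughput line above: the profiled counts are forward-pass flops, and the
# common rule of thumb is that the backward pass costs roughly twice the forward pass, so
# total_flops * 3 / time_step approximates the achieved Tflop/s of the full optimizer step.
# A minimal sketch of that arithmetic, with illustrative numbers rather than measured values:
forward_tflops = 25.0        # t5_Tflops + wan_Tflops + vae_Tflops for one step
step_time_s = 1.5            # wall-clock seconds, taken after torch.cuda.synchronize()
achieved_tflops = forward_tflops * 3 / step_time_s   # ~50 Tflop/s (fwd + ~2x fwd for bwd)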


7 changes: 6 additions & 1 deletion diffsynth/models/wan_video_text_encoder.py
@@ -70,6 +70,11 @@ def forward(self, x, context=None, mask=None, pos_bias=None):
k = self.k(context).view(b, -1, n, c)
v = self.v(context).view(b, -1, n, c)

# Record q/k/v shapes for the FLOPs calculation
self.q_shape = q.shape
self.k_shape = k.shape
self.v_shape = v.shape

# attention bias
attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
if pos_bias is not None:
@@ -327,4 +332,4 @@ def _clean(self, text):
text = whitespace_clean(basic_clean(text)).lower()
elif self.clean == 'canonicalize':
text = canonicalize(basic_clean(text))
return text
return text
6 changes: 6 additions & 0 deletions diffsynth/utils/profiling/__init__.py
@@ -0,0 +1,6 @@
from .flops_profiler import (
profile_entire_model,
unprofile_entire_model,
get_flops,
print_model_profile,
)
230 changes: 230 additions & 0 deletions diffsynth/utils/profiling/flops_profiler.py
@@ -0,0 +1,230 @@
import torch
import torch.nn as nn
from functools import wraps
import time
from torch.utils.flop_counter import conv_flop_count

def get_flops(model):
def get_module_flops(module):
if not hasattr(module, "__flops__"):
module.__flops__ = 0

flops = module.__flops__
# iterate over immediate child modules
for child in module.children():
flops += get_module_flops(child)
return flops

t5_flops = 0
wan_flops = 0
vae_flops = 0
for module in model.modules():
if module.__class__.__name__ == 'WanTextEncoder':
t5_flops = get_module_flops(module)
if module.__class__.__name__ == 'WanModel':
wan_flops = get_module_flops(module)
if module.__class__.__name__ == 'WanVideoVAE38':
vae_flops = get_module_flops(module)
return t5_flops / 1e12, wan_flops / 1e12, vae_flops / 1e12

def print_model_profile(model):
def get_module_flops(module):
if not hasattr(module, "__flops__"):
module.__flops__ = 0

flops = module.__flops__
# iterate over immediate child modules
for child in module.children():
flops += get_module_flops(child)
return flops

def get_module_duration(module):
if not hasattr(module, "__duration__"):
module.__duration__ = 0

duration = module.__duration__
if duration == 0: # e.g. ModuleList
for m in module.children():
duration += get_module_duration(m)
return duration

def flops_repr(module):
flops = get_module_flops(module)
duration = get_module_duration(module) * 1000
items = [
"{:,} flops".format(flops),
"{:.3f} ms".format(duration),
]
original_extra_repr = module.original_extra_repr()
if original_extra_repr:
items.append(original_extra_repr)
return ", ".join(items)

def add_extra_repr(module):
flops_extra_repr = flops_repr.__get__(module)
if module.extra_repr != flops_extra_repr:
module.original_extra_repr = module.extra_repr
module.extra_repr = flops_extra_repr
assert module.extra_repr != module.original_extra_repr

def del_extra_repr(module):
if hasattr(module, "original_extra_repr"):
module.extra_repr = module.original_extra_repr
del module.original_extra_repr

model.apply(add_extra_repr)
print(model)
model.apply(del_extra_repr)

def calculate_module_flops(module, *args, result=None, **kwargs):
module_type = module.__class__.__name__
module_original_fwd = module._original_forward.__name__

if module_type == 'RMSNorm':
x = args[0]
return x.numel() * 4

elif module_type == 'RMS_norm':
x = args[0]
return x.numel() * 4

elif module_type == 'Dropout':
x = args[0]
return x.numel() * 2

elif module_type == 'LayerNorm':
x = args[0]
has_affine = module.weight is not None
return x.numel() * (5 if has_affine else 4)

elif module_type == 'Linear':
x = args[0]
return x.numel() * module.weight.size(0) * 2

elif module_type == 'ReLU':
x = args[0]
return x.numel()

elif module_type == 'GELU':
x = args[0]
return x.numel()

elif module_type == 'SiLU':
x = args[0]
return x.numel()

elif module_type == 'Conv3d' or module_type == 'CausalConv3d' or module_type == 'Conv2d':
x_shape = args[0].shape
weight = getattr(module, 'weight', None)
w_shape = weight.shape
out_shape = result.shape

flops = conv_flop_count(
x_shape=x_shape,
w_shape=w_shape,
out_shape=out_shape,
transposed=False
)
return flops

# AttentionModule inputs are 3D tensors; USP attention inputs are 4D tensors.
#
# 3D shape:
# q [batch, target_seq_len, Dim]
# k [batch, source_seq_len, Dim]
# v [batch, source_seq_len, Dim]
# flops = (batch * target_seq_len * source_seq_len) * Dim * 2
# + (batch * target_seq_len * Dim) * source_seq_len * 2
# = 4 * (batch * target_seq_len * source_seq_len * Dim)
#
# 4D shape:
# q [batch, target_seq_len, head, dim]
# k [batch, source_seq_len, head, dim]
# v [batch, source_seq_len, head, dim]
# flops = 4 * (batch * target_seq_len * source_seq_len * head * dim)
#
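# Worked example for the 4D case (illustrative sizes, not taken from a specific
# config): with b=1, ts=ss=1024, n=16, dq=dv=64 the count is
#   4 * 1 * 1024 * 1024 * 16 * 64 = 4,294,967,296 ≈ 4.3 Gflops per call.
#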
elif module_type == 'AttentionModule':
q = args[0]
k = args[1]
v = args[2]

b, ts, dq = q.shape
_, ss, _ = k.shape
_, _, dv = v.shape
flops = (b * ts * ss * dq) * 2 + (b * ts * ss * dv) * 2
return flops

elif module_original_fwd == 'usp_attn_forward' or module_type == 'T5Attention':
q_shape = module.q_shape
k_shape = module.k_shape
v_shape = module.v_shape

b, ts, n, dq = q_shape
_, ss, _, _ = k_shape
_, _, _, dv = v_shape
flops = (b * ts * ss * n * dq) * 2 + (b * ts * ss * n * dv) * 2
return flops

elif module_type == 'GateModule':
x = args[0]
return x.numel() * 2

elif module_type == 'T5LayerNorm':
x = args[0]
return x.numel() * 4

# The 10x factor is a rough estimate of the per-element cost of torch.log.
# The lookup and gather steps in the relative position encoding involve no floating-point operations.
elif module_type == 'T5RelativeEmbedding':
lq = args[0]
lk = args[1]
return lq * lk * 10

else:
return 0

def flops_counter():
def decorator(forward_func):
@wraps(forward_func)
def wrapper(self, *args, **kwargs):
start_time = time.perf_counter()

result = forward_func(self, *args, **kwargs)

self.__flops__ = calculate_module_flops(self, *args, result=result, **kwargs)

end_time = time.perf_counter()
self.__duration__ = (end_time - start_time)

return result
return wrapper
return decorator

def wrap_existing_module(module):
# save original fwd
module._original_forward = module.forward

@flops_counter()
def profiled_forward(self, *args, **kwargs):
return module._original_forward(*args, **kwargs)

module.forward = profiled_forward.__get__(module, type(module))
return module

def profile_entire_model(model):
for name, module in model.named_modules():
wrap_existing_module(module)
return model

def unwrap_existing_module(module):
if hasattr(module, "_original_forward"):
module.forward = module._original_forward
del module._original_forward

return module

def unprofile_entire_model(model):
for name, module in model.named_modules():
unwrap_existing_module(module)
return model
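
# A minimal usage sketch of the profiler API defined above. The toy nn.Sequential model
# and tensor sizes are illustrative; in this PR the entry points are instead wired into
# launch_training_task in diffsynth/diffusion/runner.py.
import torch
import torch.nn as nn
from diffsynth.utils.profiling import (
    profile_entire_model,
    unprofile_entire_model,
    print_model_profile,
)

# Wrap every submodule's forward so __flops__ and __duration__ get recorded.
model = nn.Sequential(nn.Linear(256, 512), nn.GELU(), nn.Linear(512, 256))
profile_entire_model(model)

x = torch.randn(4, 256)
_ = model(x)                    # one forward pass populates the per-module counters

print_model_profile(model)      # prints the module tree annotated with flops and timings
unprofile_entire_model(model)   # restores the original forward methods

# Note that get_flops only aggregates the WanTextEncoder, WanModel, and WanVideoVAE38
# submodules, so it would return zeros for a toy model like this one.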

7 changes: 6 additions & 1 deletion diffsynth/utils/xfuser/xdit_context_parallel.py
@@ -128,6 +128,11 @@ def usp_attn_forward(self, x, freqs):
k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)

# Record q/k/v shapes for the FLOPs calculation
self.q_shape = q.shape
self.k_shape = k.shape
self.v_shape = v.shape

attn_type = AttnType.FA
ring_impl_type = "basic"
if IS_NPU_AVAILABLE:
@@ -143,4 +148,4 @@ def usp_attn_forward(self, x, freqs):

del q, k, v
getattr(torch, parse_device_type(x.device)).empty_cache()
return self.o(x)
return self.o(x)