Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions comfy/ldm/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@
raise e
exit(-1)

SAGE_ATTENTION3_IS_AVAILABLE = False
try:
from sageattn3 import sageattn3_blackwell
SAGE_ATTENTION3_IS_AVAILABLE = True
except ImportError as e:
pass

FLASH_ATTENTION_IS_AVAILABLE = False
try:
from flash_attn import flash_attn_func
Expand Down Expand Up @@ -563,6 +570,93 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
out = out.reshape(b, -1, heads * dim_head)
return out

@wrap_attn
def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
exception_fallback = False
if (q.device.type != "cuda" or
q.dtype not in (torch.float16, torch.bfloat16) or
mask is not None):
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
**kwargs
)

if skip_reshape:
B, H, L, D = q.shape
if H != heads:
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
**kwargs
)
q_s, k_s, v_s = q, k, v
N = q.shape[2]
dim_head = D
else:
B, N, inner_dim = q.shape
if inner_dim % heads != 0:
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=False,
skip_output_reshape=skip_output_reshape,
**kwargs
)
dim_head = inner_dim // heads

if dim_head >= 256 or N <= 1024:
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
**kwargs
)

if not skip_reshape:
q_s, k_s, v_s = map(
lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
(q, k, v),
)
B, H, L, D = q_s.shape

try:
out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
except Exception as e:
exception_fallback = True
logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)

if exception_fallback:
if not skip_reshape:
del q_s, k_s, v_s
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=False,
skip_output_reshape=skip_output_reshape,
**kwargs
)

if skip_reshape:
if not skip_output_reshape:
out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
else:
if skip_output_reshape:
pass
else:
out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)

return out

try:
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
Expand Down Expand Up @@ -650,6 +744,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
# register core-supported attention functions
if SAGE_ATTENTION_IS_AVAILABLE:
register_attention_function("sage", attention_sage)
if SAGE_ATTENTION3_IS_AVAILABLE:
register_attention_function("sage3", attention3_sage)
if FLASH_ATTENTION_IS_AVAILABLE:
register_attention_function("flash", attention_flash)
if model_management.xformers_enabled():
Expand Down