
No flash and fused attention for Gemma4 26B-A4B #3006

@pavelgein

Describe the bug

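I am trying to fine-tune Gemma4 26B-A4B on long sequences (32K tokens) with Megatron-Bridge, and I get the logs below. The layers with head_dim 256 (sliding window, window_size (1023, 0)) run with FusedAttention. The layers with head_dim 512 (window_size (-1, 0)) fall back to UnfusedDotProductAttention: FlashAttention 2 and 3 reject head_dim 512, and FusedAttention reports that no backend supports the input.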

2026-05-18 12:51:26.3238 DEBUG:DotProductAttention:Running with config={'transformer_engine_version': '2.14.0+71bbefbf', 'compute_capability': 'sm90', 'flash_attn_version': '2.7.4.post1+nv26.2.44259020', 'flash_attn_3_version': '3.0.0', 'cudnn_version': '9.20.0', 'qkv_type': <class 'torch.Tensor'>, 'qkv_dtype': torch.bfloat16, 'qkv_layout': 'thd_thd_thd', 'batch_size': 4, 'num_heads': 2, 'num_gqa_groups': 1, 'max_seqlen_q': tensor(5743, dtype=torch.int32), 'max_seqlen_kv': tensor(5743, dtype=torch.int32), 'head_dim_qk': 256, 'head_dim_v': 256, 'attn_mask_type': 'padding_causal', 'window_size': (1023, 0), 'bottom_right_diagonal': False, 'alibi_slopes_shape': None, 'core_attention_bias_type': 'no_bias', 'core_attention_bias_shape': None, 'core_attention_bias_requires_grad': False, 'pad_between_seqs': False, 'attention_dropout': 0.0, 'context_parallel': False, 'cp_comm_type': 'p2p', 'deterministic': False, 'is_training': True, 'fp8': False, 'fp8_meta': {'fp8_checkpoint': False, 'fp8_group': None}, 'inference_params': None, 'softmax_type': 'vanilla', 'return_max_logit': False, 'cuda_graph': False, 'num_splits': 1}
2026-05-18 12:51:26.3640 DEBUG:DotProductAttention:Available backends = {FlashAttention=True (3.0.0), FusedAttention=True (sub-backend 1), UnfusedDotProductAttention=True}
2026-05-18 12:51:26.4042 DEBUG:DotProductAttention:Disabling FlashAttention to give FusedAttention preference on Hopper+ for performance reasons
2026-05-18 12:51:26.4441 DEBUG:DotProductAttention:Selected backend = FusedAttention (sub-backend 1)
2026-05-18 12:51:26.5250 INFO:DotProductAttention:Running with FusedAttention backend (sub-backend 1)
2026-05-18 12:51:30.5500 DEBUG:DotProductAttention:Running with config={'transformer_engine_version': '2.14.0+71bbefbf', 'compute_capability': 'sm90', 'flash_attn_version': '2.7.4.post1+nv26.2.44259020', 'flash_attn_3_version': '3.0.0', 'cudnn_version': '9.20.0', 'qkv_type': <class 'torch.Tensor'>, 'qkv_dtype': torch.bfloat16, 'qkv_layout': 'thd_thd_thd', 'batch_size': 4, 'num_heads': 2, 'num_gqa_groups': 1, 'max_seqlen_q': tensor(5743, dtype=torch.int32), 'max_seqlen_kv': tensor(5743, dtype=torch.int32), 'head_dim_qk': 512, 'head_dim_v': 512, 'attn_mask_type': 'padding_causal', 'window_size': (-1, 0), 'bottom_right_diagonal': False, 'alibi_slopes_shape': None, 'core_attention_bias_type': 'no_bias', 'core_attention_bias_shape': None, 'core_attention_bias_requires_grad': False, 'pad_between_seqs': False, 'attention_dropout': 0.0, 'context_parallel': False, 'cp_comm_type': 'p2p', 'deterministic': False, 'is_training': True, 'fp8': False, 'fp8_meta': {'fp8_checkpoint': False, 'fp8_group': None}, 'inference_params': None, 'softmax_type': 'vanilla', 'return_max_logit': False, 'cuda_graph': False, 'num_splits': 1}
2026-05-18 12:51:30.5501 DEBUG:DotProductAttention:Disabling FlashAttention 2 due to unsupported head_dim_qk and head_dim_v. Supported: head_dim_qk = head_dim_v, head_dim_qk %8 = 0, head_dim_qk <= 256 (>192 requires sm80/90/100+). Found: head_dim_qk = 512, head_dim_v = 512, on sm9.0.
2026-05-18 12:51:30.6309 DEBUG:DotProductAttention:Disabling FlashAttention 3 due to unsupported num_heads, num_gqa_groups, head_dim_qk, head_dim_v or qkv_dtype. Supported: head_dim_qk <= 256, and num_heads % num_gqa_groups = 0, and if head_dim_qk is different from head_dim_v, then (head_dim_qk must in (128, 192] and head_dim_v in (96, 128]) or (head_dim_qk <= 64 and head_dim_v <= 512), and if head_dim_qk is different from head_dim_v and head_dim_v > 256, then qkv_dtype requires fp16 and bf16 data type. Found: num_heads = 2, num_gqa_groups = 1, head_dim_qk = 512, head_dim_v = 512 and qkv_dtype = torch.bfloat16.
2026-05-18 12:51:30.7105 DEBUG:DotProductAttention:Disabling FusedAttention as no backend supports the provided input
2026-05-18 12:51:30.8310 DEBUG:DotProductAttention:Available backends = {FlashAttention=False, FusedAttention=False, UnfusedDotProductAttention=True}
2026-05-18 12:51:30.8701 DEBUG:DotProductAttention:Selected backend = UnfusedDotProductAttention
2026-05-18 12:51:31.0199 INFO:DotProductAttention:Running with UnfusedDotProductAttention backend
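To double-check why the head_dim 512 layers lose both flash backends, the head-dimension rules quoted in the FlashAttention 2 and FlashAttention 3 log messages can be restated as a small Python sketch. The function names are hypothetical. The rules are transcribed from the log text, not from TE's source. The dtype and architecture conditions are left out because this run is bf16 on sm90.

```python
# Sketch: head-dimension rules as quoted in the TE debug log above.
# Transcribed from the log messages, not from Transformer Engine's source;
# dtype and GPU-architecture conditions are intentionally omitted.


def fa2_head_dims_ok(head_dim_qk: int, head_dim_v: int) -> bool:
    # FlashAttention 2 (per the log): equal head dims, multiple of 8, at most 256.
    return head_dim_qk == head_dim_v and head_dim_qk % 8 == 0 and head_dim_qk <= 256


def fa3_head_dims_ok(head_dim_qk: int, head_dim_v: int,
                     num_heads: int, num_gqa_groups: int) -> bool:
    # FlashAttention 3 (per the log): head_dim_qk <= 256 and heads divisible by groups.
    if head_dim_qk > 256 or num_heads % num_gqa_groups != 0:
        return False
    if head_dim_qk != head_dim_v:
        # Mixed head dims are accepted only in the two ranges quoted in the log.
        return (128 < head_dim_qk <= 192 and 96 < head_dim_v <= 128) or (
            head_dim_qk <= 64 and head_dim_v <= 512
        )
    return True


# The two layer configurations seen in the log (num_heads=2, num_gqa_groups=1).
for d in (256, 512):
    print(d, fa2_head_dims_ok(d, d), fa3_head_dims_ok(d, d, 2, 1))
# 256 True True
# 512 False False
```

Under these quoted rules, the 256-dim sliding-window layers qualify for both flash backends, and the 512-dim layers qualify for neither.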

Because the head_dim 512 layers run on UnfusedDotProductAttention, TE tries to create a full attention mask (in get_full_mask), which leads to an OOM:

2026-05-18 12:52:23.2256 [rank6]:   File "/opt/venv/lib/python3.12/site-packages/transformer_engine/pytorch/attention/dot_product_attention/utils.py", line 1359, in get_full_mask
2026-05-18 12:52:23.2658 [rank6]:     attention_mask = torch.logical_or(
2026-05-18 12:52:23.3060 [rank6]:                      ^^^^^^^^^^^^^^^^^
2026-05-18 12:52:23.3458 [rank6]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 4096.00 GiB. GPU 6 has a total capacity of 79.18 GiB of which 54.50 GiB is free. Process 1720765 has 24.15 GiB memory in use. Process 1726543 has 522.00 MiB memory in use. Of the allocated memory 17.54 GiB is allocated by PyTorch, and 170.46 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
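The following back-of-the-envelope sketch shows why materializing a dense mask becomes fragile at long sequence lengths. It assumes a hypothetical dense_mask_bytes helper, a [batch, 1, s_q, s_kv] layout, and 1 byte per element (torch.bool). It does not try to reproduce the 4096 GiB request in the traceback, which depends on how get_full_mask builds its tensors internally.

```python
# Sketch: size of a dense attention mask, assuming a [batch, 1, s_q, s_kv]
# layout at 1 byte per element (bool). Illustrates quadratic growth only;
# this is not TE's get_full_mask logic.

GIB = 2**30


def dense_mask_bytes(batch: int, seqlen_q: int, seqlen_kv: int,
                     bytes_per_elem: int = 1) -> int:
    """Bytes needed for one dense mask of shape [batch, 1, seqlen_q, seqlen_kv]."""
    return batch * seqlen_q * seqlen_kv * bytes_per_elem


# Shapes from the log: batch_size=4, max_seqlen_q = max_seqlen_kv = 5743.
print(dense_mask_bytes(4, 5743, 5743) / GIB)
# One full 32K-token sequence.
print(dense_mask_bytes(1, 32768, 32768) / GIB)  # 1.0
```

Doubling the sequence length quadruples the result.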

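Steps/Code to reproduce bug

Fine-tune Gemma4 26B-A4B with Megatron-Bridge on long (32K-token) sequences in a container from the nvidia/nemo:26.04 image.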

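Expected behavior

The head_dim 512 layers also run with FlashAttention or FusedAttention, instead of falling back to UnfusedDotProductAttention and running out of memory.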

Environment overview

  • Environment location: Docker (container from the nvidia/nemo:26.04 image)

Environment details

Versions as reported in the logs above:

  • Python version: 3.12
  • Transformer Engine version: 2.14.0+71bbefbf
  • cuDNN version: 9.20.0
  • flash-attn version: 2.7.4.post1+nv26.2.44259020 (FlashAttention 3: 3.0.0)

Device details

  • GPU: Hopper (compute capability sm90) with 79.18 GiB total memory, as reported in the logs

Labels: bug
