diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
index 34eb6991d56..413f81c9cd9 100644
--- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
+++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
@@ -326,6 +326,7 @@ def extract_decoder_token_from_q(
     assert len(cu_seqlens_q.shape) == 1
     assert len(seq_lens_encoder.shape) == 1
     assert len(seq_lens_decoder.shape) == 1
+    assert seq_lens_encoder.shape == seq_lens_decoder.shape
 
     max_bsz = seq_lens_decoder.shape[0]
 
@@ -398,7 +399,7 @@ def insert_decoder_result_back(
     max_bsz = seq_lens_encoder.shape[0]
     hidden_dim = decoder_result.shape[-2] * decoder_result.shape[-1]
 
-    out = paddle.zeros([mixed_token_num, hidden_dim], dtype=decoder_result.dtype)
+    out = paddle.empty([mixed_token_num, hidden_dim], dtype=decoder_result.dtype)
 
     BLOCK_SIZE = triton.next_power_of_2(hidden_dim)
 
@@ -525,6 +526,7 @@ def __init__(
         self.useless_tensor = paddle.randn([1]).cast("int32")
         prop = paddle.device.cuda.get_device_properties()
         cc = prop.major * 10 + prop.minor
+        self.prop = prop
         self.is_blackwell = cc >= 100
 
         if self.flash_attn_func is None:
@@ -813,7 +815,7 @@ def forward_mixed(
                 self.max_seq_len,
             )
 
-            if self.is_blackwell:
+            if self.prop.major == 10:
                 # TODO support FA4
                 fmha_out = MLAAttentionBackend.mha_baseline(
                     q,
@@ -857,7 +859,7 @@ def forward_mixed(
                 speculate_decoder,
            )
 
-        if int(os.getenv("USE_FLASH_MLA", "0")) == 0:
+        if int(os.getenv("USE_FLASH_MLA", "0")) == 0 and self.prop.major == 9:
            assert self.num_heads <= 64, "paddle mla attention support failed"
            if self.heads_need_padding:
                q = paddle.nn.functional.pad(
@@ -910,17 +912,7 @@ def forward_mixed(
 
            return fmha_out
        else:
-            import flash_mla
-
-            decoder_q, cache_seqlens = extract_decoder_token_from_q(
-                q,
-                forward_meta.cu_seqlens_q,
-                forward_meta.seq_lens_encoder,
-                forward_meta.seq_lens_decoder,
-            )
-
-            tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata()
-            token_num = q.shape[0]
+            decoder_q = q
            decoder_q.reshape_([-1, 1, self.num_heads, 576])
            if self.heads_need_padding:
                padded_q = paddle.zeros(
@@ -933,22 +925,28 @@ def forward_mixed(
            assert new_cache_shape[1] == 1
            new_cache_shape[1], new_cache_shape[2] = new_cache_shape[2], new_cache_shape[1]
 
-            if self.is_blackwell:
+            if self.prop.major == 10:
+                # blackwell
                decoder_res = MLAAttentionBackend.mla_blackwell(
                    decoder_q,
                    latent_cache,
                    metadata.block_tables,
-                    cache_seqlens,
+                    forward_meta.cache_seqlens,
                    attn_softmax_scale=self.attn_softmax_scale,
                )
            else:
+
+                import flash_mla
+
+                tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata()
+
                decoder_res, _ = flash_mla.flash_mla_with_kvcache(
                    decoder_q,
                    # The kv cache storage layout in the external open-source repo differs from FD's
                    # Luckily the cached head count here is 1, so a direct view is enough; otherwise a lot of code up and down the stack would have to change!
                    latent_cache.view(new_cache_shape),
                    metadata.block_tables,
-                    cache_seqlens,
+                    forward_meta.cache_seqlens,
                    512,  # t.dv,
                    tile_scheduler_metadata,
                    num_splits,
@@ -958,15 +956,7 @@ def forward_mixed(
            if self.heads_need_padding:
                decoder_res = decoder_res[:, :, : self.num_heads, :].contiguous()
 
-            final_res = insert_decoder_result_back(
-                decoder_res,
-                forward_meta.cu_seqlens_q,
-                forward_meta.seq_lens_encoder,
-                forward_meta.seq_lens_decoder,
-                token_num,
-            )
-
-            return final_res
+            return decoder_res
 
    @staticmethod
    def mla_blackwell(decoder_q, latent_cache, block_table, cache_seqlens, attn_softmax_scale):
@@ -1016,11 +1006,6 @@ def mla_blackwell(decoder_q, latent_cache, block_table, cache_seqlens, attn_soft
        softmax_scale = attn_softmax_scale
        output_scale = 1.0
 
-        import sys
-
-        sys.path.insert(
-            0, "/root/paddlejob/workspace/env_run/output/zkk/cutlass/examples/python/CuTeDSL/blackwell/mla"
-        )
        from mla_decode_fp16 import BlackwellMultiHeadLatentAttentionForwardFP16
 
        mla = BlackwellMultiHeadLatentAttentionForwardFP16(
diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py
index 0ac87fa6dfa..2b34849ef5d 100644
--- a/fastdeploy/model_executor/models/deepseek_v3.py
+++ b/fastdeploy/model_executor/models/deepseek_v3.py
@@ -17,6 +17,7 @@ from __future__ import annotations
 
 import math
+import os
 import re
 from typing import Dict
 
@@ -344,6 +345,9 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None
 
        self.prefix = prefix
 
+        prop = paddle.device.cuda.get_device_properties()
+        self.prop = prop
+
    @staticmethod
    def yarn_get_mscale(scale=1, mscale=1):
        """ """
@@ -362,6 +366,8 @@ def forward(
            fused_read_cache_and_interleave,
        )
 
+        q_total_token_num = hidden_states.shape[0]
+
        attn_out = None
        if self.use_gated_attn:
            gate_out = self.gate(hidden_states)
@@ -438,6 +444,36 @@ def forward(
            attn_out = fmha_out
 
        if need_do_decode:  # max_dec_len_this_time
+
+            if int(os.getenv("USE_FLASH_MLA", "0")) == 0 and self.prop.major == 9:
+                pass
+            else:
+                from fastdeploy.model_executor.layers.attention.mla_attention_backend import (
+                    extract_decoder_token_from_q,
+                    insert_decoder_result_back,
+                )
+
+                decoder_query_nope, cache_seqlens = extract_decoder_token_from_q(
+                    query_nope.reshape([0, -1]),
+                    forward_meta.cu_seqlens_q,
+                    forward_meta.seq_lens_encoder,
+                    forward_meta.seq_lens_decoder,
+                )
+
+                decoder_query_pe, cache_seqlens = extract_decoder_token_from_q(
+                    query_pe.reshape([0, -1]),
+                    forward_meta.cu_seqlens_q,
+                    forward_meta.seq_lens_encoder,
+                    forward_meta.seq_lens_decoder,
+                )
+                assert decoder_query_nope.shape[0] == forward_meta.seq_lens_encoder.shape[0]
+                assert decoder_query_pe.shape[0] == forward_meta.seq_lens_encoder.shape[0]
+
+                forward_meta.cache_seqlens = cache_seqlens
+
+                query_nope = decoder_query_nope.reshape([0, -1, self.qk_nope_head_dim])
+                query_pe = decoder_query_pe.reshape([0, -1, self.qk_rope_head_dim])
+
            q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2])
 
            q_input = paddle.concat([q_nope_out, query_pe], axis=-1)
@@ -466,6 +502,17 @@ def forward(
                .reshape_([-1, self.num_attention_heads_tp * self.v_head_dim])
            )
 
+            if int(os.getenv("USE_FLASH_MLA", "0")) == 0 and self.prop.major == 9:
+                pass
+            else:
+                fmqa_out = insert_decoder_result_back(
+                    fmqa_out.reshape([0, 1, self.num_attention_heads_tp, self.v_head_dim]),
+                    forward_meta.cu_seqlens_q,
+                    forward_meta.seq_lens_encoder,
+                    forward_meta.seq_lens_decoder,
+                    q_total_token_num,
+                )
+
        if need_do_prefill:
            merge_prefill_decode_output(
                attn_out,
diff --git a/tests/operators/test_deepgemm_precision.py b/tests/operators/test_deepgemm_precision.py
index 34643a7aa4b..2fa25c034fa 100644
--- a/tests/operators/test_deepgemm_precision.py
+++ b/tests/operators/test_deepgemm_precision.py
@@ -40,13 +40,15 @@ def __init__(self):
        self.acc_dtype = cutlass.Float32
        self.num_acc_stage = 1
 
-        self.use_2cta_instrs = False
+        self.use_2cta_instrs = True
 
        self.cluster_shape_mnk = (2, 1, 1) if self.use_2cta_instrs else (1, 1, 1)
        self.cluster_shape_mn = (2, 1) if self.use_2cta_instrs else (1, 1)
        self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
        self.mma_tiler = (128, 128, 64)
 
+        self.num_ab_stage = 1
+
    @cute.jit
    def __call__(
        self,
@@ -54,6 +56,9 @@ def __call__(
        b: cute.Tensor,
        c: cute.Tensor,
    ):
+        M = a.shape[0]
+        N = b.shape[0]
+
        tiled_mma = sm100_utils.make_trivial_tiled_mma(
            cutlass.BFloat16,
            tcgen05.OperandMajorMode.K,
@@ -69,9 +74,15 @@ def __call__(
            (tiled_mma.thr_id.shape,),
        )  # ((2),1,1,1):((1),0,0,0)
 
+        self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
+        self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
 
-        a_smem_layout_staged = sm100_utils.make_smem_layout_a(tiled_mma, self.mma_tiler, cutlass.BFloat16, 1)
-        b_smem_layout_staged = sm100_utils.make_smem_layout_b(tiled_mma, self.mma_tiler, cutlass.BFloat16, 1)
+        a_smem_layout_staged = sm100_utils.make_smem_layout_a(
+            tiled_mma, self.mma_tiler, cutlass.BFloat16, self.num_ab_stage
+        )
+        b_smem_layout_staged = sm100_utils.make_smem_layout_b(
+            tiled_mma, self.mma_tiler, cutlass.BFloat16, self.num_ab_stage
+        )
 
        a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
        a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0))
@@ -107,13 +118,15 @@ def __call__(
            a,
            b,
            c,
+            tma_tensor_a,
+            tma_tensor_b,
            a_smem_layout_staged,
            b_smem_layout_staged,
            tma_atom_a,
            tma_atom_b,
            self.cluster_layout_vmnk,
        ).launch(
-            grid=self.cluster_shape_mnk,
+            grid=[M // self.mma_tiler[0] * self.cluster_shape_mn[0], N // self.mma_tiler[1], 1],
            block=[128, 1, 1],
            cluster=self.cluster_shape_mnk,
        )
@@ -126,6 +139,8 @@ def kernel(
        a,
        b,
        c,
+        tma_tensor_a,
+        tma_tensor_b,
        a_smem_layout_staged,
        b_smem_layout_staged,
        tma_atom_a,
@@ -139,6 +154,14 @@ def kernel(
        bidx, bidy, bidz = cute.arch.block_idx()
        mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
        is_leader_cta = mma_tile_coord_v == 0
+        cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
+
+        cta_coord = (bidx, bidy, bidz)
+        mma_tile_coord_mnl = (
+            cta_coord[0] // cute.size(tiled_mma.thr_id.shape),
+            cta_coord[1],
+            cta_coord[2],
+        )
 
        if warp_idx == 0:
            cpasync.prefetch_descriptor(tma_atom_a)
@@ -146,6 +169,7 @@ def kernel(
 
        @cute.struct
        class SharedStorage:
+            ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
            acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
            tmem_dealloc_mbar: cutlass.Int64
            tmem_holding_buf: cutlass.Int32
@@ -167,20 +191,57 @@ class SharedStorage:
            swizzle=b_smem_layout_staged.inner,
        )
 
-        tmem_alloc_barrier = pipeline.NamedBarrier(barrier_id=0, num_threads=self.threads_per_cta)
+        gA = cute.local_tile(tma_tensor_a, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None))
+        gB = cute.local_tile(tma_tensor_b, cute.slice_(self.mma_tiler, (0, None, None)), (None, None))
+        # note that after local_tile the shape is flattened!
 
-        # Tensor memory dealloc barrier init
-        tmem = utils.TmemAllocator(
-            storage.tmem_holding_buf,
-            barrier_for_retrieve=tmem_alloc_barrier,
-            is_two_cta=self.use_2cta_instrs,
-            two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar,
+        # k_tile_cnt is the number of iterations needed along the K direction!
+        k_tile_cnt = cute.size(gA, mode=[3])
+
+        thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
+        tCgA = thr_mma.partition_A(gA)
+        tCgB = thr_mma.partition_B(gB)
+
+        block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
+
+        a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape)
+
+        tAsA, tAgA = cpasync.tma_partition(
+            tma_atom_a,
+            block_in_cluster_coord_vmnk[2],
+            a_cta_layout,
+            cute.group_modes(sA, 0, 3),
+            cute.group_modes(tCgA, 0, 3),
        )
 
-        # Alloc tensor memory buffer
-        tmem.allocate(self.num_tmem_alloc_cols)
-        tmem.wait_for_alloc()
-        tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
+        b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape)
+
+        tBsB, tBgB = cpasync.tma_partition(
+            tma_atom_b,
+            block_in_cluster_coord_vmnk[1],
+            b_cta_layout,
+            cute.group_modes(sB, 0, 3),
+            cute.group_modes(tCgB, 0, 3),
+        )
+
+        # tensor<(0,0) o (((64,128),1),20,7):(((1@0,1@1),0),128@1,64@0)>
+        # tensor<(0,?{div=128}) o (((64,128),1),7):(((1@0,1@1),0),64@0)>
+        tAgA = tAgA[(None, mma_tile_coord_mnl[0], None)]
+        tBgB = tBgB[(None, mma_tile_coord_mnl[1], None)]
+
+        # Initialize mainloop ab_pipeline (barrier) and states
+        ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
+        num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
+        ab_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_tma_producer)
+        ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create(
+            barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
+            num_stages=self.num_ab_stage,
+            producer_group=ab_pipeline_producer_group,
+            consumer_group=ab_pipeline_consumer_group,
+            tx_count=self.num_tma_load_bytes,
+            cta_layout_vmnk=cluster_layout_vmnk,
+            defer_sync=True,
+        ).make_participants()
 
        # Initialize acc_pipeline (barrier) and states
        acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
@@ -197,15 +258,20 @@ class SharedStorage:
        acc_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
        acc_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
 
-        for i in cutlass.range(tidx, cute.cosize(sA), self.threads_per_cta):
-            if self.use_2cta_instrs:
-                sA[i] = a[bidx * 64 + i % 64, i // 64]
-                sB[i] = b[bidx * 64 + i % 64, i // 64]
-            else:
-                sA[i] = a[i]
-                sB[i] = b[i]
+        tmem_alloc_barrier = pipeline.NamedBarrier(barrier_id=0, num_threads=self.threads_per_cta)
+
+        # Tensor memory dealloc barrier init
+        tmem = utils.TmemAllocator(
+            storage.tmem_holding_buf,
+            barrier_for_retrieve=tmem_alloc_barrier,
+            is_two_cta=self.use_2cta_instrs,
+            two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar,
+        )
 
-        pipeline.sync(barrier_id=1)
+        # Alloc tensor memory buffer
+        tmem.allocate(self.num_tmem_alloc_cols)
+        tmem.wait_for_alloc()
+        tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
 
        tCrA = tiled_mma.make_fragment_A(sA)
        tCrB = tiled_mma.make_fragment_B(sB)
@@ -213,14 +279,54 @@ class SharedStorage:
 
        tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape)
        tCtAcc = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
 
-        if warp_idx == 0 and is_leader_cta:
-            blk_count = tCrA.shape[2]
-            tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
-            for i in cutlass.range_constexpr(blk_count):
-                cute.gemm(tiled_mma, tCtAcc, tCrA[None, None, i, 0], tCrB[None, None, i, 0], tCtAcc)
-                tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
+        if warp_idx == 0:
+
+            for k_tile_idx in cutlass.range(k_tile_cnt, unroll=1):
+
+                producer_handle = ab_producer.acquire_and_advance()
+
+                a_full_mcast_mask = None
+                b_full_mcast_mask = None
+
+                a_full_mcast_mask = cpasync.create_tma_multicast_mask(
+                    cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
+                )
+                b_full_mcast_mask = cpasync.create_tma_multicast_mask(
+                    cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1
+                )
+
+                # if is_leader_cta:
+                #     cute.printf(b_full_mcast_mask)
+
+                cute.copy(
+                    tma_atom_a,
+                    tAgA[(None, k_tile_idx)],
+                    tAsA[(None, 0)],
+                    tma_bar_ptr=producer_handle.barrier,
+                    mcast_mask=a_full_mcast_mask,
+                )
+
+                cute.copy(
+                    tma_atom_b,
+                    tBgB[(None, k_tile_idx)],
+                    tBsB[(None, 0)],
+                    tma_bar_ptr=producer_handle.barrier,
+                    mcast_mask=b_full_mcast_mask,
+                )
 
-            acc_pipeline.producer_commit(acc_producer_state)
+                if is_leader_cta:
+                    blk_count = tCrA.shape[2]
+
+                    consumer_handle = ab_consumer.wait_and_advance()
+
+                    for i in cutlass.range_constexpr(blk_count):
+                        cute.gemm(tiled_mma, tCtAcc, tCrA[None, None, i, 0], tCrB[None, None, i, 0], tCtAcc)
+                        tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
+
+                    consumer_handle.release()
+
+            if is_leader_cta:
+                acc_pipeline.producer_commit(acc_producer_state)
 
        acc_pipeline.consumer_wait(acc_consumer_state)
 
@@ -244,15 +350,18 @@ class SharedStorage:
 
        if self.use_2cta_instrs:
            for i in cutlass.range_constexpr(64):
-                c[tidx % 64 + 64 * bidx, i + tidx // 64 * 64] = (cutlass.BFloat16)(tTR_rAcc[i])
+                c[tidx % 64 + 64 * bidx, i + tidx // 64 * 64 + bidy * 128] = (cutlass.BFloat16)(tTR_rAcc[i])
        else:
            for i in cutlass.range_constexpr(128):
-                c[tidx, i] = (cutlass.BFloat16)(tTR_rAcc[i])
+                c[bidx * 128 + tidx, bidy * 128 + i] = (cutlass.BFloat16)(tTR_rAcc[i])
 
        pipeline.sync(barrier_id=2)
        tmem.relinquish_alloc_permit()
        tmem.free(tmem_ptr)
 
+        if warp_idx == 0:
+            ab_producer.tail()
+
 
 class TestDeepDenseGemm(unittest.TestCase):
    def setUp(self):
@@ -280,11 +389,12 @@ def two_invoke(self, M, N, K):
            my_res,
            options="--opt-level 2",
        )
-        compiled_mm(my_a, my_b, my_res)
+        for i in range(100):
+            compiled_mm(my_a, my_b, my_res)
 
        print(my_tensor)
-        print(my_tensor - baseline_out)
+        print(baseline_out)
 
        assert (my_tensor - baseline_out).abs().max().item() == 0.0
 
    def one_invoke(self, M, N, K):
@@ -348,6 +458,7 @@ def test_main(self):
        if prop.major != 10:
            return
        # import paddle.profiler as profiler
+
        # p = profiler.Profiler(
        #     targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU],
        #     on_trace_ready=profiler.export_chrome_tracing("./profile_log"),
@@ -355,10 +466,10 @@ def test_main(self):
        # p.start()
        # p.step()
 
-        self.one_invoke(128 * 20, 2048, 4096)
-        self.one_invoke(128 * 20, 2048, 2048)
+        # self.one_invoke(128 * 20, 2048, 4096)
+        # self.one_invoke(128 * 20, 2048, 2048)
 
-        self.two_invoke(128, 128, 64)
+        self.two_invoke(128 * 20, 128 * 20, 64 * 4)
 
        # p.stop()
 
diff --git a/tests/operators/test_flashmla_precision.py b/tests/operators/test_flashmla_precision.py
index e1e3a9a242f..5b74a23dc31 100644
--- a/tests/operators/test_flashmla_precision.py
+++ b/tests/operators/test_flashmla_precision.py
@@ -30,47 +30,66 @@ def setUp(self):
        pass
 
    def test_flashmla(self):
+        # import paddle.profiler as profiler
+        # p = profiler.Profiler(
+        #     targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU],
+        #     on_trace_ready=profiler.export_chrome_tracing("./profile_log"),
+        # )
+        # p.start()
+        # p.step()
+
        bsz = 128
        kv_len = 1000
+        page_size = 64
        decoder_q = paddle.randn([bsz, 1, 128, 576], dtype="bfloat16")
        cache_seqlens = paddle.zeros([bsz], dtype="int32") + kv_len
-        block_tables = paddle.arange((kv_len // 64 + 1) * bsz, dtype="int32").reshape([bsz, -1])
-        latent_cache = paddle.randn([10000, 1, 64, 576], dtype="bfloat16")
+        block_tables = paddle.arange((kv_len // page_size + 1) * bsz, dtype="int32").reshape([bsz, -1])
+        latent_cache = paddle.randn([10000, 1, page_size, 576], dtype="bfloat16")
 
        # copy from dsv3
        attn_softmax_scale = 0.1352337788608801
-        baseline_out = MLAAttentionBackend.flashmla_baseline(
-            decoder_q, latent_cache, block_tables, cache_seqlens, attn_softmax_scale
-        )
-
-        paddle.enable_compat(scope={"flash_mla"})  # Enable paddle.enable_compat before importing flash_mla
-        try:
-            import flash_mla
-        except ImportError:
-            print(100 * "Please install flash_mla first")
-            return
-
-        tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata()
-
-        new_cache_shape = latent_cache.shape
-        assert new_cache_shape[1] == 1
-        new_cache_shape[1], new_cache_shape[2] = new_cache_shape[2], new_cache_shape[1]
-
-        decoder_res, _ = flash_mla.flash_mla_with_kvcache(
-            decoder_q,
-            # The kv cache storage layout in the external open-source repo differs from FD's
-            # Luckily the cached head count here is 1, so a direct view is enough; otherwise a lot of code would have to change!
-            latent_cache.view(new_cache_shape),
-            block_tables,
-            cache_seqlens,
-            512,  # t.dv,
-            tile_scheduler_metadata,
-            num_splits,
-            softmax_scale=attn_softmax_scale,
-            causal=True,
-        )
+        for i in range(10):
+            baseline_out = MLAAttentionBackend.flashmla_baseline(
+                decoder_q, latent_cache, block_tables, cache_seqlens, attn_softmax_scale
+            )
+
+        prop = paddle.device.cuda.get_device_properties()
+        if prop.major == 10:
+
+            for i in range(10):
+                decoder_res = MLAAttentionBackend.mla_blackwell(
+                    decoder_q, latent_cache, block_tables, cache_seqlens, attn_softmax_scale
+                )
+        elif prop.major == 9:
+            paddle.enable_compat(scope={"flash_mla"})  # Enable paddle.enable_compat before importing flash_mla
+            try:
+                import flash_mla
+            except ImportError:
+                print(100 * "Please install flash_mla first")
+                return
+
+            tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata()
+
+            new_cache_shape = latent_cache.shape
+            assert new_cache_shape[1] == 1
+            new_cache_shape[1], new_cache_shape[2] = new_cache_shape[2], new_cache_shape[1]
+
+            decoder_res, _ = flash_mla.flash_mla_with_kvcache(
+                decoder_q,
+                # The kv cache storage layout in the external open-source repo differs from FD's
+                # Luckily the cached head count here is 1, so a direct view is enough; otherwise a lot of code would have to change!
+                latent_cache.view(new_cache_shape),
+                block_tables,
+                cache_seqlens,
+                512,  # t.dv,
+                tile_scheduler_metadata,
+                num_splits,
+                softmax_scale=attn_softmax_scale,
+                causal=True,
+            )
 
        max_diff = (decoder_res - baseline_out).abs().max().item()
+        print(decoder_res - baseline_out)
        self.assertLessEqual(max_diff, 0.1)