From e993763ae1f29f6510bfbe0fd69004de80466b19 Mon Sep 17 00:00:00 2001 From: Pieter-Jan Hoedt Date: Tue, 10 Jun 2025 10:22:53 +0200 Subject: [PATCH 1/4] hot-fix for recurrent backward bug MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw.py b/mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw.py index 98897fb..ae96ff2 100644 --- a/mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw.py +++ b/mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw.py @@ -105,9 +105,9 @@ def mlstm_chunkwise_bw( vecN_out=vecN_out, # (B, NH, S) matDeltaC_last=matDeltaC_last, # (B, NH, DHQK, DHHV) qk_scale=qk_scale, - chunk_size=kernel_chunk_params.chunk_size_inter, + chunk_size=kernel_chunk_params.chunk_size_intra, eps=eps, - save_states_every_nth_chunk=kernel_chunk_params.save_states_every_nth_chunk, + save_states_every_nth_chunk=1, num_stages=num_stages_inter, num_warps=num_warps_inter, ) From a9edba279300132d49ce21973825708cd39c8e3f Mon Sep 17 00:00:00 2001 From: Pieter-Jan Hoedt Date: Tue, 10 Jun 2025 10:28:15 +0200 Subject: [PATCH 2/4] use scaM_inter shape to infer chunk size --- .../torch/chunkwise/triton_xl_chunk/bw.py | 2 -- .../chunkwise/triton_xl_chunk/bw_recurrent.py | 23 ++++--------------- 2 files changed, 4 insertions(+), 21 deletions(-) diff --git a/mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw.py b/mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw.py index ae96ff2..5b50492 100644 --- a/mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw.py +++ b/mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw.py @@ -105,9 +105,7 @@ def mlstm_chunkwise_bw( vecN_out=vecN_out, # (B, NH, S) matDeltaC_last=matDeltaC_last, # (B, NH, DHQK, DHHV) qk_scale=qk_scale, - chunk_size=kernel_chunk_params.chunk_size_intra, eps=eps, - save_states_every_nth_chunk=1, 
num_stages=num_stages_inter, num_warps=num_warps_inter, ) diff --git a/mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw_recurrent.py b/mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw_recurrent.py index 50359a8..cab6fd9 100644 --- a/mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw_recurrent.py +++ b/mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw_recurrent.py @@ -19,8 +19,6 @@ def mlstm_chunkwise__recurrent_bw_dC( vecN_out: torch.Tensor, # (B, NH, S) matDeltaC_last: torch.Tensor = None, # (B, NH, DHQK, DHHV) qk_scale: float = None, - chunk_size: int = 64, - save_states_every_nth_chunk: int = 1, num_warps: int | None = None, num_stages: int | None = None, eps: float = 0.0, @@ -31,31 +29,18 @@ def mlstm_chunkwise__recurrent_bw_dC( """ B, NH, S, DHQK, DHHV = *matQ.shape, matDeltaH.shape[-1] _dtype, _device = matQ.dtype, matQ.device - L = chunk_size + NC = scaM_inter.shape[-1] - 1 + L = S // NC assert is_power_of_2(L), "Chunk size must be a power of 2." assert S % L == 0, "S must be divisible by chunk_size." - NC = S // L - - assert ( - save_states_every_nth_chunk > 0 - ), "save_states_every_nth_chunk must be positive." - assert ( - save_states_every_nth_chunk <= NC - ), "save_states_every_nth_chunk must be <= NC." - - assert is_power_of_2( - save_states_every_nth_chunk - ), f"save_states_every_nth_chunk must be a power of 2. Got {save_states_every_nth_chunk}." 
if qk_scale is None: qk_scale = DHQK**-0.5 USE_LAST_STATE = matDeltaC_last is not None - num_chunks_saved = NC // save_states_every_nth_chunk - matDeltaC_states = torch.empty( - (B, NH, (num_chunks_saved + 1) * DHQK, DHHV), + (B, NH, (NC + 1) * DHQK, DHHV), dtype=torch.float32, device=_device, ) @@ -109,7 +94,7 @@ def mlstm_chunkwise__recurrent_bw_dC( L=L, siz_b_DHQK=siz_b_DHQK, siz_b_DHHV=siz_b_DHHV, - save_states_every_nth_chunk=save_states_every_nth_chunk, + save_states_every_nth_chunk=1, USE_LAST_STATE=USE_LAST_STATE, DTYPE=torch2triton_dtype(_dtype), EPS=eps, From 29bb53d7781b31c4986366915584b32127beef74 Mon Sep 17 00:00:00 2001 From: Pieter-Jan Hoedt Date: Tue, 10 Jun 2025 10:31:41 +0200 Subject: [PATCH 3/4] remove 'save_every_nth_chunk'-logic from kernel --- .../chunkwise/triton_xl_chunk/bw_recurrent.py | 1 - .../chunkwise/xl_chunk/bw_kernel_recurrent.py | 36 +++++++++---------- 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw_recurrent.py b/mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw_recurrent.py index cab6fd9..2f38ab5 100644 --- a/mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw_recurrent.py +++ b/mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw_recurrent.py @@ -94,7 +94,6 @@ def mlstm_chunkwise__recurrent_bw_dC( L=L, siz_b_DHQK=siz_b_DHQK, siz_b_DHHV=siz_b_DHHV, - save_states_every_nth_chunk=1, USE_LAST_STATE=USE_LAST_STATE, DTYPE=torch2triton_dtype(_dtype), EPS=eps, diff --git a/mlstm_kernels/triton/chunkwise/xl_chunk/bw_kernel_recurrent.py b/mlstm_kernels/triton/chunkwise/xl_chunk/bw_kernel_recurrent.py index b600e20..d068797 100644 --- a/mlstm_kernels/triton/chunkwise/xl_chunk/bw_kernel_recurrent.py +++ b/mlstm_kernels/triton/chunkwise/xl_chunk/bw_kernel_recurrent.py @@ -50,7 +50,6 @@ def mlstm_chunkwise__recurrent_bw_dC_kernel( L: tl.constexpr, siz_b_DHQK: tl.constexpr, siz_b_DHHV: tl.constexpr, - save_states_every_nth_chunk: tl.constexpr, USE_LAST_STATE: tl.constexpr, DTYPE: 
tl.constexpr = tl.float32, EPS: tl.constexpr = 1e-6, @@ -100,24 +99,23 @@ def mlstm_chunkwise__recurrent_bw_dC_kernel( order=(1, 0), ) # ? end pointers - if k % save_states_every_nth_chunk == 0: - idx_k_save = k // save_states_every_nth_chunk - # * store matDeltaC_k_val from previous iteration in HBM - matDeltaCstates_k_ptr = tl.make_block_ptr( - base=matDeltaC_states - + idx_b_NH * str_matDeltaC_states_B_NH - + idx_k_save * DHQK * DHHV, - shape=(DHQK, DHHV), - strides=(str_matDeltaC_states_NCDHQK, str_matDeltaC_states_DHHV), - offsets=(idx_b_DHQK * siz_b_DHQK, idx_b_DHHV * siz_b_DHHV), - block_shape=(siz_b_DHQK, siz_b_DHHV), - order=(1, 0), - ) - tl.store( - matDeltaCstates_k_ptr, - matDeltaC_k_val.to(tl.float32), - boundary_check=(0, 1), - ) + + # * store matDeltaC_k_val from previous iteration in HBM + matDeltaCstates_k_ptr = tl.make_block_ptr( + base=matDeltaC_states + + idx_b_NH * str_matDeltaC_states_B_NH + + k * DHQK * DHHV, + shape=(DHQK, DHHV), + strides=(str_matDeltaC_states_NCDHQK, str_matDeltaC_states_DHHV), + offsets=(idx_b_DHQK * siz_b_DHQK, idx_b_DHHV * siz_b_DHHV), + block_shape=(siz_b_DHQK, siz_b_DHHV), + order=(1, 0), + ) + tl.store( + matDeltaCstates_k_ptr, + matDeltaC_k_val.to(tl.float32), + boundary_check=(0, 1), + ) # * compute matDeltaC_km1_val # load scaG_k, vecB_k, scaM_inter_km1, scaM_inter_k, vecM_combine_k From 027f051acf43edb8b8bd74ee7c19d116ca3960e2 Mon Sep 17 00:00:00 2001 From: Pieter-Jan Hoedt Date: Tue, 10 Jun 2025 11:03:55 +0200 Subject: [PATCH 4/4] add test to avoid future errors --- .../test_chunkwise_triton_xl_chunk.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/torch/chunkwise/test_chunkwise_triton_xl_chunk.py b/tests/torch/chunkwise/test_chunkwise_triton_xl_chunk.py index a81ac69..47397c4 100644 --- a/tests/torch/chunkwise/test_chunkwise_triton_xl_chunk.py +++ b/tests/torch/chunkwise/test_chunkwise_triton_xl_chunk.py @@ -54,6 +54,46 @@ def 
test_triton_chunkwise_xl_chunk_vs_native_parallel_stablef_fp32( ) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="No GPU available.") +@pytest.mark.parametrize(["S", "B", "NH", "DHQK", "DHHV"], final_combinations) +def test_inter_vs_intra_chunks(S, B, NH, DHQK, DHHV): + torch.manual_seed(2025) + q = torch.randn(B, NH, S, DHQK, device="cuda", requires_grad=True) + k = torch.randn(B, NH, S, DHQK, device="cuda", requires_grad=True) + v = torch.randn(B, NH, S, DHHV, device="cuda", requires_grad=True) + i = torch.randn(B, NH, S, device="cuda", requires_grad=True) + f = torch.randn(B, NH, S, device="cuda", requires_grad=True) + + h_ref = mlstm_chunkwise__xl_chunk( + q, k, v, i, f, + chunk_size=128, chunk_size_inter=128, chunk_size_intra=128, + siz_b_L_parallel=64, siz_b_L_loop=64, + siz_b_DH_parallel=DHHV, siz_b_DH_loop=DHHV, + ) + + dh = torch.randn_like(h_ref) + dq_ref, dk_ref, dv_ref, di_ref, df_ref = torch.autograd.grad( + [h_ref], [q, k, v, i, f], [dh] + ) + + h = mlstm_chunkwise__xl_chunk( + q, k, v, i, f, + chunk_size=128, chunk_size_inter=64, chunk_size_intra=128, + siz_b_L_parallel=64, siz_b_L_loop=64, + siz_b_DH_parallel=DHHV, siz_b_DH_loop=DHHV + ) + dq, dk, dv, di, df = torch.autograd.grad( + [h], [q, k, v, i, f], [dh] + ) + + torch.testing.assert_close(h, h_ref) + torch.testing.assert_close(dq, dq_ref) + torch.testing.assert_close(dk, dk_ref) + torch.testing.assert_close(dv, dv_ref) + torch.testing.assert_close(di, di_ref) + torch.testing.assert_close(df, df_ref) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="No GPU available.") def test_state_passing(mlstm_state_passing_test, state_passing_qkvif): num_chunks = state_passing_qkvif[0].shape[2] // 64 # <- chunk size = 64