diff --git a/mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw.py b/mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw.py
index 98897fb..5b50492 100644
--- a/mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw.py
+++ b/mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw.py
@@ -105,9 +105,7 @@ def mlstm_chunkwise_bw(
         vecN_out=vecN_out,  # (B, NH, S)
         matDeltaC_last=matDeltaC_last,  # (B, NH, DHQK, DHHV)
         qk_scale=qk_scale,
-        chunk_size=kernel_chunk_params.chunk_size_inter,
         eps=eps,
-        save_states_every_nth_chunk=kernel_chunk_params.save_states_every_nth_chunk,
         num_stages=num_stages_inter,
         num_warps=num_warps_inter,
     )
diff --git a/mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw_recurrent.py b/mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw_recurrent.py
index 50359a8..2f38ab5 100644
--- a/mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw_recurrent.py
+++ b/mlstm_kernels/torch/chunkwise/triton_xl_chunk/bw_recurrent.py
@@ -19,8 +19,6 @@ def mlstm_chunkwise__recurrent_bw_dC(
    vecN_out: torch.Tensor,  # (B, NH, S)
    matDeltaC_last: torch.Tensor = None,  # (B, NH, DHQK, DHHV)
    qk_scale: float = None,
-    chunk_size: int = 64,
-    save_states_every_nth_chunk: int = 1,
    num_warps: int | None = None,
    num_stages: int | None = None,
    eps: float = 0.0,
@@ -31,31 +29,18 @@ def mlstm_chunkwise__recurrent_bw_dC(
    """
    B, NH, S, DHQK, DHHV = *matQ.shape, matDeltaH.shape[-1]
    _dtype, _device = matQ.dtype, matQ.device
-    L = chunk_size
+    NC = scaM_inter.shape[-1] - 1
+    L = S // NC
    assert is_power_of_2(L), "Chunk size must be a power of 2."
    assert S % L == 0, "S must be divisible by chunk_size."
-    NC = S // L
-
-    assert (
-        save_states_every_nth_chunk > 0
-    ), "save_states_every_nth_chunk must be positive."
-    assert (
-        save_states_every_nth_chunk <= NC
-    ), "save_states_every_nth_chunk must be <= NC."
-
-    assert is_power_of_2(
-        save_states_every_nth_chunk
-    ), f"save_states_every_nth_chunk must be a power of 2. Got {save_states_every_nth_chunk}."
 
    if qk_scale is None:
        qk_scale = DHQK**-0.5
 
    USE_LAST_STATE = matDeltaC_last is not None
 
-    num_chunks_saved = NC // save_states_every_nth_chunk
-
    matDeltaC_states = torch.empty(
-        (B, NH, (num_chunks_saved + 1) * DHQK, DHHV),
+        (B, NH, (NC + 1) * DHQK, DHHV),
        dtype=torch.float32,
        device=_device,
    )
@@ -109,7 +94,6 @@ def mlstm_chunkwise__recurrent_bw_dC(
        L=L,
        siz_b_DHQK=siz_b_DHQK,
        siz_b_DHHV=siz_b_DHHV,
-        save_states_every_nth_chunk=save_states_every_nth_chunk,
        USE_LAST_STATE=USE_LAST_STATE,
        DTYPE=torch2triton_dtype(_dtype),
        EPS=eps,
diff --git a/mlstm_kernels/triton/chunkwise/xl_chunk/bw_kernel_recurrent.py b/mlstm_kernels/triton/chunkwise/xl_chunk/bw_kernel_recurrent.py
index b600e20..d068797 100644
--- a/mlstm_kernels/triton/chunkwise/xl_chunk/bw_kernel_recurrent.py
+++ b/mlstm_kernels/triton/chunkwise/xl_chunk/bw_kernel_recurrent.py
@@ -50,7 +50,6 @@ def mlstm_chunkwise__recurrent_bw_dC_kernel(
    L: tl.constexpr,
    siz_b_DHQK: tl.constexpr,
    siz_b_DHHV: tl.constexpr,
-    save_states_every_nth_chunk: tl.constexpr,
    USE_LAST_STATE: tl.constexpr,
    DTYPE: tl.constexpr = tl.float32,
    EPS: tl.constexpr = 1e-6,
@@ -100,24 +99,23 @@ def mlstm_chunkwise__recurrent_bw_dC_kernel(
            order=(1, 0),
        )
        # ? end pointers
-        if k % save_states_every_nth_chunk == 0:
-            idx_k_save = k // save_states_every_nth_chunk
-            # * store matDeltaC_k_val from previous iteration in HBM
-            matDeltaCstates_k_ptr = tl.make_block_ptr(
-                base=matDeltaC_states
-                + idx_b_NH * str_matDeltaC_states_B_NH
-                + idx_k_save * DHQK * DHHV,
-                shape=(DHQK, DHHV),
-                strides=(str_matDeltaC_states_NCDHQK, str_matDeltaC_states_DHHV),
-                offsets=(idx_b_DHQK * siz_b_DHQK, idx_b_DHHV * siz_b_DHHV),
-                block_shape=(siz_b_DHQK, siz_b_DHHV),
-                order=(1, 0),
-            )
-            tl.store(
-                matDeltaCstates_k_ptr,
-                matDeltaC_k_val.to(tl.float32),
-                boundary_check=(0, 1),
-            )
+
+        # * store matDeltaC_k_val from previous iteration in HBM
+        matDeltaCstates_k_ptr = tl.make_block_ptr(
+            base=matDeltaC_states
+            + idx_b_NH * str_matDeltaC_states_B_NH
+            + k * DHQK * DHHV,
+            shape=(DHQK, DHHV),
+            strides=(str_matDeltaC_states_NCDHQK, str_matDeltaC_states_DHHV),
+            offsets=(idx_b_DHQK * siz_b_DHQK, idx_b_DHHV * siz_b_DHHV),
+            block_shape=(siz_b_DHQK, siz_b_DHHV),
+            order=(1, 0),
+        )
+        tl.store(
+            matDeltaCstates_k_ptr,
+            matDeltaC_k_val.to(tl.float32),
+            boundary_check=(0, 1),
+        )
 
        # * compute matDeltaC_km1_val
        # load scaG_k, vecB_k, scaM_inter_km1, scaM_inter_k, vecM_combine_k
diff --git a/tests/torch/chunkwise/test_chunkwise_triton_xl_chunk.py b/tests/torch/chunkwise/test_chunkwise_triton_xl_chunk.py
index a81ac69..47397c4 100644
--- a/tests/torch/chunkwise/test_chunkwise_triton_xl_chunk.py
+++ b/tests/torch/chunkwise/test_chunkwise_triton_xl_chunk.py
@@ -54,6 +54,46 @@ def test_triton_chunkwise_xl_chunk_vs_native_parallel_stablef_fp32(
    )
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="No GPU available.")
+@pytest.mark.parametrize(["S", "B", "NH", "DHQK", "DHHV"], final_combinations)
+def test_inter_vs_intra_chunks(S, B, NH, DHQK, DHHV):
+    torch.manual_seed(2025)
+    q = torch.randn(B, NH, S, DHQK, device="cuda", requires_grad=True)
+    k = torch.randn(B, NH, S, DHQK, device="cuda", requires_grad=True)
+    v = torch.randn(B, NH, S, DHHV, device="cuda", requires_grad=True)
+    i = torch.randn(B, NH, S, device="cuda", requires_grad=True)
+    f = torch.randn(B, NH, S, device="cuda", requires_grad=True)
+
+    h_ref = mlstm_chunkwise__xl_chunk(
+        q, k, v, i, f,
+        chunk_size=128, chunk_size_inter=128, chunk_size_intra=128,
+        siz_b_L_parallel=64, siz_b_L_loop=64,
+        siz_b_DH_parallel=DHHV, siz_b_DH_loop=DHHV,
+    )
+
+    dh = torch.randn_like(h_ref)
+    dq_ref, dk_ref, dv_ref, di_ref, df_ref = torch.autograd.grad(
+        [h_ref], [q, k, v, i, f], [dh]
+    )
+
+    h = mlstm_chunkwise__xl_chunk(
+        q, k, v, i, f,
+        chunk_size=128, chunk_size_inter=64, chunk_size_intra=128,
+        siz_b_L_parallel=64, siz_b_L_loop=64,
+        siz_b_DH_parallel=DHHV, siz_b_DH_loop=DHHV
+    )
+    dq, dk, dv, di, df = torch.autograd.grad(
+        [h], [q, k, v, i, f], [dh]
+    )
+
+    torch.testing.assert_close(h, h_ref)
+    torch.testing.assert_close(dq, dq_ref)
+    torch.testing.assert_close(dk, dk_ref)
+    torch.testing.assert_close(dv, dv_ref)
+    torch.testing.assert_close(di, di_ref)
+    torch.testing.assert_close(df, df_ref)
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="No GPU available.")
 def test_state_passing(mlstm_state_passing_test, state_passing_qkvif):
    num_chunks = state_passing_qkvif[0].shape[2] // 64  # <- chunk size = 64
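For context, a minimal sketch (not part of the diff) of the invariant the reworked backward pass relies on: one inter-chunk state is saved per chunk plus the initial state, so scaM_inter has last dimension NC + 1 and the chunk size L can be recovered from the sequence length S. The helper name below is illustrative only.

# Illustrative helper only; mirrors the derivation added to
# mlstm_chunkwise__recurrent_bw_dC in this diff.
def derive_chunk_size(S: int, num_saved_states: int) -> int:
    NC = num_saved_states - 1  # saved states = one per chunk + the initial state
    assert S % NC == 0, "S must be divisible by the number of chunks."
    L = S // NC  # chunk size used by the dC recurrence
    assert L > 0 and L & (L - 1) == 0, "Chunk size must be a power of 2."
    return L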