feat: add intra-card context parallelism for B=1 long sequences #9
Open
yyq0210 wants to merge 1 commit into
Summary
- `flash_kda/cp.py` implementing intra-card context parallelism (CP)
- `fwd_cp` interface in `flash_kda/__init__.py`
- `tests/test_cp.py` for correctness verification
Motivation
FlashKDA underutilizes the GPU at B=1 with long sequences (kernel2 grid has only H blocks), making it slower than Triton baselines. By splitting long sequences into sub-segments processed in parallel, we can significantly improve SM occupancy.
Approach
Simplified two-pass method inspired by FlashQLA's CP strategy:
- `_calc_cp_seqs()`: automatic sub-segment partitioning
  - `max_local_chunks = 2^round(log2(sqrt(H * total_chunks / SM_COUNT) * 3))`
  - enabled only when `Be * H <= SM_COUNT // 4`
- `_estimate_warmup_converges()`: analytic gate-decay check
  - `min_seg_len * max(A_log) < threshold`
- `fwd_cp()`: two-pass forward
  - `cp_h0[i+1] = ht_buffer[i]`

Key simplification vs FlashQLA: no transition matrix `mt` computation. When gate decay is sufficient (the precondition for enabling CP), the initial state contribution decays to zero within each sub-segment, making `ht[i]` the exact correct state, so no correction is needed.
Benchmark (H20, B=1, H=16, D=128)
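As a rough illustration of the partitioning formula and enable condition listed under Approach, here is a minimal Python sketch. The function names `max_local_chunks_for` and `cp_enabled` are hypothetical and not taken from `flash_kda/cp.py`; `Be` is used exactly as written in this PR without assuming its definition; the clamp of the exponent at zero is an added guard, and `SM_COUNT = 78` is an assumed value for H20.

```python
import math

SM_COUNT = 78  # assumed SM count for H20; query the device in real code


def max_local_chunks_for(num_heads: int, total_chunks: int, sm_count: int = SM_COUNT) -> int:
    """Power-of-two chunk budget per sub-segment, following the formula
    2^round(log2(sqrt(H * total_chunks / SM_COUNT) * 3))."""
    ratio = math.sqrt(num_heads * total_chunks / sm_count) * 3
    exponent = round(math.log2(ratio))
    # Added guard (not stated in the PR): never go below one chunk per sub-segment.
    return 2 ** max(exponent, 0)


def cp_enabled(be: int, num_heads: int, sm_count: int = SM_COUNT) -> bool:
    """Enable gate from the PR: CP is used only when Be * H <= SM_COUNT // 4."""
    return be * num_heads <= sm_count // 4


if __name__ == "__main__":
    for heads in (16, 32):
        print(heads, cp_enabled(1, heads), max_local_chunks_for(heads, 1024))
```

Under the assumed SM count, `SM_COUNT // 4` is 19, so `Be = 1, H = 16` passes the gate and `H = 32` does not, which is consistent with the H>=32 example given for H20.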
CP is automatically disabled when `Be * H > SM_COUNT // 4` (e.g., H>=32 on H20), adding zero overhead to existing workloads.
Correctness
When the gate-decay convergence condition is met, CP output is bit-identical to standard `fwd()` (max_diff = 0).
Test Plan
- `test_cp_disabled`: `auto_cp=False` produces output identical to `fwd()`
- `test_cp_correctness_single_seq`: correctness at T=8k
- `test_cp_correctness_long_seq`: correctness at T=64k
- `test_cp_final_state`: `final_state` correctly extracted from last sub-segment
- `test_cp_with_initial_state`: `initial_state` correctly propagated to first sub-segment
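
To make the two-pass idea from the Approach section concrete, the sketch below applies it to a toy scalar recurrence `h_t = a_t * h_{t-1} + x_t` in pure Python. This is not the FlashKDA kernel: `scan`, `decay_is_sufficient`, and `two_pass_scan` are hypothetical names, the log-space decay check is a stand-in for the PR's analytic gate-decay test, and the split into passes is one plausible reading of `cp_h0[i+1] = ht_buffer[i]`. The toy comparison uses a small tolerance rather than claiming bit-identical results.

```python
import math
import random


def scan(decay, x, h0=0.0):
    """Sequential reference: h_t = decay[t] * h_{t-1} + x[t]."""
    h, out = h0, []
    for a, v in zip(decay, x):
        h = a * h + v
        out.append(h)
    return out, h


def decay_is_sufficient(decay, seg_len, eps=1e-12):
    """Toy analogue of a gate-decay check: even the slowest decay must shrink
    an incoming state below eps over one sub-segment (checked in log space)."""
    return seg_len * math.log(max(decay)) < math.log(eps)


def two_pass_scan(decay, x, seg_len, h0=0.0):
    bounds = [(s, min(s + seg_len, len(x))) for s in range(0, len(x), seg_len)]
    # Pass 1: run every sub-segment independently (zero state except the first)
    # and keep only its final state.
    ht_buffer = []
    for i, (s, e) in enumerate(bounds):
        _, ht = scan(decay[s:e], x[s:e], h0 if i == 0 else 0.0)
        ht_buffer.append(ht)
    # State handoff: cp_h0[i+1] = ht_buffer[i]; the first segment keeps h0.
    cp_h0 = [h0] + ht_buffer[:-1]
    # Pass 2: rerun each sub-segment from its handed-off initial state.
    out, final_state = [], h0
    for (s, e), h_init in zip(bounds, cp_h0):
        seg_out, final_state = scan(decay[s:e], x[s:e], h_init)
        out.extend(seg_out)
    return out, final_state


if __name__ == "__main__":
    random.seed(0)
    T, seg_len = 512, 64
    decay = [random.uniform(0.3, 0.6) for _ in range(T)]
    x = [random.uniform(-1.0, 1.0) for _ in range(T)]
    ref_out, _ = scan(decay, x, h0=0.5)
    cp_out, _ = two_pass_scan(decay, x, seg_len, h0=0.5)
    print("decay ok:", decay_is_sufficient(decay, seg_len))
    print("max diff:", max(abs(a - b) for a, b in zip(ref_out, cp_out)))
```

In the sketch both passes loop over sub-segments sequentially; the point of the scheme is that the iterations within each pass are independent of one another, so they could run in parallel.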