feat: add intra-card context parallelism for B=1 long sequences #9

Open
yyq0210 wants to merge 1 commit into MoonshotAI:master from yyq0210:yyq/cp_opt

yyq0210 commented May 29, 2026

Summary

  • Add flash_kda/cp.py implementing intra-card context parallelism (CP)
  • Export fwd_cp interface in flash_kda/__init__.py
  • Add tests/test_cp.py for correctness verification

Motivation

FlashKDA underutilizes the GPU at B=1 with long sequences (kernel2 grid has only H blocks), making it slower than Triton baselines. By splitting long sequences into sub-segments processed in parallel, we can significantly improve SM occupancy.

Approach

Simplified two-pass method inspired by FlashQLA's CP strategy:
_calc_cp_seqs(): Automatic sub-segment partitioning

  • max_local_chunks = 2^round(log2(sqrt(H * total_chunks / SM_COUNT) * 3))
  • Enable condition: Be * H <= SM_COUNT // 4
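The two rules above can be written out as a minimal sketch. The function names and parameters here are hypothetical and not taken from flash_kda/cp.py; `be` stands for the `Be` term as written in this PR, whose definition is not given here (presumably an effective batch size). The formula is transcribed literally, with no clamping for very small inputs.

```python
import math


def max_local_chunks(num_heads, total_chunks, sm_count):
    """Power-of-two chunk cap per sub-segment, transcribed from the PR formula."""
    # sqrt(H * total_chunks / SM_COUNT) * 3, then rounded in log2 space
    ratio = math.sqrt(num_heads * total_chunks / sm_count) * 3
    return 2 ** round(math.log2(ratio))


def cp_enabled(be, num_heads, sm_count):
    """Enable condition: Be * H <= SM_COUNT // 4."""
    return be * num_heads <= sm_count // 4
```

Assuming an SM count of 78 (an assumption for illustration) and a hypothetical total_chunks of 1024, `max_local_chunks(16, 1024, 78)` evaluates to 32, `cp_enabled(1, 16, 78)` is true, and `cp_enabled(1, 32, 78)` is false.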

_estimate_warmup_converges(): Analytic gate-decay check

  • Uses A_log (per-head decay rate): min_seg_len * max(A_log) < threshold
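One possible reading of this check, as a sketch only: it assumes each entry holds a per-step log-decay for one head (negative values, so the largest entry is the slowest-decaying head) and that the threshold is a negative log-magnitude. Both the interpretation and the names below are assumptions; the actual semantics of A_log and the threshold value in flash_kda/cp.py may differ.

```python
import math


def estimate_warmup_converges(log_decay_per_head, min_seg_len, threshold):
    """Literal form of: min_seg_len * max(A_log) < threshold.

    Assumption: log_decay_per_head[i] is the per-step log-decay of head i,
    so min_seg_len * max(...) bounds the total log-decay of the slowest head
    over the shortest sub-segment.
    """
    return min_seg_len * max(log_decay_per_head) < threshold


# Hypothetical numbers: a threshold of -40 means "decayed below e^-40".
strong = estimate_warmup_converges([math.log(0.5), math.log(0.25)], 64, -40.0)
weak = estimate_warmup_converges([math.log(0.5), math.log(0.99)], 64, -40.0)
```

Under these assumptions a single slowly decaying head is enough to fail the check, which matches the use of max over heads in the stated condition.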

fwd_cp(): Two-pass forward

  • Pass 1: Compute all sub-segments with h0=0, capture final_state (ht_buffer)
  • State chaining: cp_h0[i+1] = ht_buffer[i]
  • Pass 2: Recompute all sub-segments with corrected initial_state

Key simplification vs FlashQLA: no transition matrix mt computation. When gate decay is sufficient (the precondition for enabling CP), the contribution of the initial state decays to zero within each sub-segment, so the Pass 1 final state ht[i] computed from h0=0 is already the correct state. The chained states therefore need no correction before Pass 2.
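The two-pass scheme and the decay argument can be illustrated on a toy scalar recurrence, h_t = decay * h_{t-1} + x_t. This is a sketch under stated assumptions, not the fwd_cp implementation: the real kernel works on matrix-valued states with per-token gates, and `run_segment` and `two_pass_cp` are hypothetical names. The segments are looped over sequentially here, whereas the point of CP is that each pass can process them in parallel.

```python
def run_segment(decay, x, h0):
    """Toy recurrence h_t = decay * h_{t-1} + x_t; returns (outputs, final_state)."""
    h, out = h0, []
    for x_t in x:
        h = decay * h + x_t
        out.append(h)
    return out, h


def two_pass_cp(decay, x, seg_len):
    segments = [x[i:i + seg_len] for i in range(0, len(x), seg_len)]
    # Pass 1: every sub-segment starts from a zero state; keep the final states.
    ht_buffer = [run_segment(decay, seg, 0.0)[1] for seg in segments]
    # State chaining: cp_h0[i+1] = ht_buffer[i]; the first sub-segment starts at zero.
    cp_h0 = [0.0] + ht_buffer[:-1]
    # Pass 2: recompute each sub-segment from its chained initial state.
    out, final_state = [], 0.0
    for seg, h_init in zip(segments, cp_h0):
        seg_out, final_state = run_segment(decay, seg, h_init)
        out.extend(seg_out)
    return out, final_state


x = [float((7 * t) % 5 + 1) for t in range(256)]

# Strong decay: 0.5**64 is far below double precision, so chaining matches.
ref_strong, _ = run_segment(0.5, x, 0.0)
cp_strong, _ = two_pass_cp(0.5, x, seg_len=64)
err_strong = max(abs(a - b) for a, b in zip(cp_strong, ref_strong))

# Weak decay: 0.99**64 is about 0.53, so the Pass 1 states are wrong.
ref_weak, _ = run_segment(0.99, x, 0.0)
cp_weak, _ = two_pass_cp(0.99, x, seg_len=64)
err_weak = max(abs(a - b) for a, b in zip(cp_weak, ref_weak))
```

In this toy setting the strong-decay error stays at floating-point noise while the weak-decay error is large, which is why skipping the transition matrix is only safe behind the convergence gate.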

Benchmark (H20, B=1, H=16, D=128)

| SeqLen | fwd (ms) | fwd_cp (ms) | Speedup |
|--------|----------|-------------|---------|
| 8k     | 1.24     | 1.10        | 1.13×   |
| 16k    | 2.45     | 1.92        | 1.28×   |
| 64k    | 9.64     | 6.74        | 1.43×   |
| 256k   | 38.47    | 29.04       | 1.32×   |

CP automatically disables when Be * H > SM_COUNT // 4 (e.g., H>=32 on H20), adding zero overhead to existing workloads.

Correctness

When the gate-decay convergence condition is met, CP output is bit-identical to standard fwd() (max_diff = 0).

Test Plan

  • test_cp_disabled: auto_cp=False produces output identical to fwd()
  • test_cp_correctness_single_seq: Correctness at T=8k
  • test_cp_correctness_long_seq: Correctness at T=64k
  • test_cp_final_state: final_state correctly extracted from last sub-segment
  • test_cp_with_initial_state: initial_state correctly propagated to first sub-segment

Run with: python tests/test_cp.py
