feat: opt fa3 and flashinfer #1367

Open
blueswhen wants to merge 3 commits into main from opt_flashinfer

Conversation

@blueswhen

Collaborator

No description provided.

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces optimizations for FlashInfer attention decoding, including a fast host-based planning path (_fast_plan_tensor_core_decode) and conditional skipping of wrapper initialization when CUDA graph capture is not needed. However, the review highlights two important issues: first, removing the call to super().copy_for_decode_cuda_graph prevents updating captured tensors during CUDA graph replay, which can lead to incorrect attention outputs or crashes; second, if b_seq_len is already on CUDA, b_seq_len_cpu remains unpopulated, preventing the fast planning path from being utilized.

Comment on lines +242 to +261
         def copy_for_decode_cuda_graph(self, new_state: "FlashInferDecodeAttState"):
    -        super().copy_for_decode_cuda_graph(new_state)
    -        self.decode_wrapper.plan(
    -            new_state.kv_starts,
    -            new_state.kv_indices,
    -            new_state.kv_last_page_len_buffer,
    +        if new_state.kv_seq_lens_host is not None:
    +            # FlashInfer tensor-core decode updates its split-kv plan at 128-token
    +            # boundaries for this path. page_size is 1 here, so pages == tokens.
    +            skip_plan_key = tuple((seq_len + 127) // 128 for seq_len in new_state.kv_seq_lens_host.tolist())
    +            if getattr(self.decode_wrapper, "_skip_plan_key", None) == skip_plan_key:
    +                return
    +
    +        _fast_plan_tensor_core_decode(
    +            self.decode_wrapper,
                 new_state.backend.tp_q_head_num,
                 new_state.backend.tp_kv_head_num,
                 new_state.backend.head_dim,
                 1,
    -            q_data_type=new_state.backend.q_data_type,
    -            kv_data_type=new_state.backend.kv_data_type,
    -            non_blocking=True,
    +            new_state.kv_starts_host,
    +            new_state.kv_seq_lens_host,
    +            new_state.infer_state.max_kv_seq_len,
             )
    +        if new_state.kv_seq_lens_host is not None:
    +            self.decode_wrapper._skip_plan_key = skip_plan_key

Contributor

critical

The call to super().copy_for_decode_cuda_graph(new_state) was removed, which prevents the captured tensors (kv_starts, kv_indices, etc.) from being updated with the new values from new_state during CUDA graph replay. This will cause the attention mechanism to run with stale or incorrect KV cache indices and starts, leading to incorrect attention outputs or crashes.

Additionally, if new_state.kv_seq_lens_host is None, calling _fast_plan_tensor_core_decode will crash with a TypeError because kv_lens_arr_host is expected to be a list/tensor. We should restore the tensor copying and add a fallback to the standard plan method when host sequence lengths are not available.

    def copy_for_decode_cuda_graph(self, new_state: "FlashInferDecodeAttState"):
        super().copy_for_decode_cuda_graph(new_state)
        if new_state.kv_seq_lens_host is not None:
            # FlashInfer tensor-core decode updates its split-kv plan at 128-token
            # boundaries for this path. page_size is 1 here, so pages == tokens.
            skip_plan_key = tuple((seq_len + 127) // 128 for seq_len in new_state.kv_seq_lens_host.tolist())
            if getattr(self.decode_wrapper, "_skip_plan_key", None) == skip_plan_key:
                return

            _fast_plan_tensor_core_decode(
                self.decode_wrapper,
                new_state.backend.tp_q_head_num,
                new_state.backend.tp_kv_head_num,
                new_state.backend.head_dim,
                1,
                new_state.kv_starts_host,
                new_state.kv_seq_lens_host,
                new_state.infer_state.max_kv_seq_len,
            )
            self.decode_wrapper._skip_plan_key = skip_plan_key
        else:
            self.decode_wrapper.plan(
                new_state.kv_starts,
                new_state.kv_indices,
                new_state.kv_last_page_len_buffer,
                new_state.backend.tp_q_head_num,
                new_state.backend.tp_kv_head_num,
                new_state.backend.head_dim,
                1,
                q_data_type=new_state.backend.q_data_type,
                kv_data_type=new_state.backend.kv_data_type,
                non_blocking=True,
            )
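
The skip condition in the suggestion above can be illustrated in isolation. The following is a minimal sketch, assuming only what the code comment in the PR states: the plan changes at 128-token boundaries and page_size is 1. `PlanCache` and `SPLIT_KV_BOUNDARY` are hypothetical names introduced for illustration and are not part of FlashInfer or this PR.

```python
from typing import Optional, Sequence, Tuple

# Assumption taken from the PR's code comment: the split-kv plan only changes
# when a sequence crosses a 128-token boundary (page_size is 1, so pages == tokens).
SPLIT_KV_BOUNDARY = 128


def skip_plan_key(kv_seq_lens: Sequence[int], boundary: int = SPLIT_KV_BOUNDARY) -> Tuple[int, ...]:
    """Ceil-divide each sequence length by the boundary to get its bucket index."""
    return tuple((n + boundary - 1) // boundary for n in kv_seq_lens)


class PlanCache:
    """Hypothetical helper: remembers the last key and counts required re-plans."""

    def __init__(self) -> None:
        self.last_key: Optional[Tuple[int, ...]] = None
        self.replans = 0

    def needs_plan(self, kv_seq_lens: Sequence[int]) -> bool:
        key = skip_plan_key(kv_seq_lens)
        if key == self.last_key:
            return False  # every sequence is still in the same bucket
        self.last_key = key
        self.replans += 1
        return True


cache = PlanCache()
# Two requests decoding one token per step, starting at lengths 120 and 250.
for step in range(16):
    cache.needs_plan([120 + step, 250 + step])
print(cache.replans)  # number of steps that actually required a new plan
```

In this toy run, a new plan is needed only on the first step and on the steps where one of the two sequences crosses a multiple of 128; every other step reuses the previous plan.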

Comment on lines +68 to 70
        if not self.b_seq_len.is_cuda:
            self.b_seq_len_cpu = self.b_seq_len
        self.b_seq_len = self.b_seq_len.cuda(non_blocking=True)

Contributor

medium

If self.b_seq_len is already on CUDA when to_cuda is called, self.b_seq_len_cpu will remain None. This prevents the fast host-based planning path from being used. We can copy the tensor to CPU if it is already on CUDA to ensure self.b_seq_len_cpu is always populated.

        if not self.b_seq_len.is_cuda:
            self.b_seq_len_cpu = self.b_seq_len
        else:
            self.b_seq_len_cpu = self.b_seq_len.cpu()
        self.b_seq_len = self.b_seq_len.cuda(non_blocking=True)
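
The difference between the submitted branch and the suggested one can be sketched without a GPU. This is an illustrative sketch only: `StubTensor` is a hypothetical stand-in for `torch.Tensor` that tracks nothing but the device, and `split_host_device` is a hypothetical function that mirrors the two variants of the `to_cuda` logic quoted above.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class StubTensor:
    """Hypothetical stand-in for torch.Tensor; only records where the data lives."""

    device: str

    @property
    def is_cuda(self) -> bool:
        return self.device == "cuda"

    def cuda(self, non_blocking: bool = False) -> "StubTensor":
        # Like torch, return the same object when it is already on the GPU.
        return self if self.is_cuda else StubTensor("cuda")

    def cpu(self) -> "StubTensor":
        return self if not self.is_cuda else StubTensor("cpu")


def split_host_device(
    b_seq_len: StubTensor, populate_from_cuda: bool
) -> Tuple[Optional[StubTensor], StubTensor]:
    """Return (b_seq_len_cpu, b_seq_len on CUDA) for either variant of the branch."""
    b_seq_len_cpu: Optional[StubTensor] = None
    if not b_seq_len.is_cuda:
        b_seq_len_cpu = b_seq_len  # input is already a host tensor, keep it
    elif populate_from_cuda:
        b_seq_len_cpu = b_seq_len.cpu()  # reviewer's suggested else-branch
    return b_seq_len_cpu, b_seq_len.cuda(non_blocking=True)


as_submitted = split_host_device(StubTensor("cuda"), populate_from_cuda=False)
as_suggested = split_host_device(StubTensor("cuda"), populate_from_cuda=True)
print(as_submitted[0], as_suggested[0])
```

With a CUDA input, the submitted variant leaves the host copy as `None`, while the suggested variant fills it. One trade-off to weigh: in PyTorch, calling `.cpu()` on a CUDA tensor is a device-to-host copy that the host waits on, so the suggested branch buys access to the fast planning path at the cost of a synchronization.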

blueswhen force-pushed the opt_flashinfer branch 2 times, most recently from 2f40fc4 to 3377848, on June 24, 2026 03:56