feat: opt fa3 and flashinfer #1367
Code Review
This pull request introduces optimizations for FlashInfer attention decoding, including a fast host-based planning path (_fast_plan_tensor_core_decode) and conditional skipping of wrapper initialization when CUDA graph capture is not needed. However, the review highlights two important issues: first, removing the call to super().copy_for_decode_cuda_graph prevents updating captured tensors during CUDA graph replay, which can lead to incorrect attention outputs or crashes; second, if b_seq_len is already on CUDA, b_seq_len_cpu remains unpopulated, preventing the fast planning path from being utilized.
     def copy_for_decode_cuda_graph(self, new_state: "FlashInferDecodeAttState"):
    -    super().copy_for_decode_cuda_graph(new_state)
    -    self.decode_wrapper.plan(
    -        new_state.kv_starts,
    -        new_state.kv_indices,
    -        new_state.kv_last_page_len_buffer,
    +    if new_state.kv_seq_lens_host is not None:
    +        # FlashInfer tensor-core decode updates its split-kv plan at 128-token
    +        # boundaries for this path. page_size is 1 here, so pages == tokens.
    +        skip_plan_key = tuple((seq_len + 127) // 128 for seq_len in new_state.kv_seq_lens_host.tolist())
    +        if getattr(self.decode_wrapper, "_skip_plan_key", None) == skip_plan_key:
    +            return
    +
    +    _fast_plan_tensor_core_decode(
    +        self.decode_wrapper,
             new_state.backend.tp_q_head_num,
             new_state.backend.tp_kv_head_num,
             new_state.backend.head_dim,
             1,
    -        q_data_type=new_state.backend.q_data_type,
    -        kv_data_type=new_state.backend.kv_data_type,
    -        non_blocking=True,
    +        new_state.kv_starts_host,
    +        new_state.kv_seq_lens_host,
    +        new_state.infer_state.max_kv_seq_len,
         )
    +    if new_state.kv_seq_lens_host is not None:
    +        self.decode_wrapper._skip_plan_key = skip_plan_key
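To make the skip key in the hunk above concrete, here is a minimal standalone sketch of the bucketing arithmetic. The function name `skip_plan_key` and the `block` parameter are hypothetical (the PR computes the key inline with a hard-coded 128), and whether an unchanged key really makes a replan unnecessary depends on FlashInfer internals that this thread does not show.

```python
def skip_plan_key(seq_lens, block=128):
    """Map each sequence length to ceil(len / block), as the PR's inline key does."""
    return tuple((n + block - 1) // block for n in seq_lens)


# Lengths that stay inside the same 128-token buckets produce the same key,
# so the PR's early return would skip replanning for the second batch.
key_a = skip_plan_key([1, 128, 129])
key_b = skip_plan_key([100, 127, 200])
same_buckets = key_a == key_b

# Growing one sequence past a 128-token boundary changes the key.
key_c = skip_plan_key([100, 129, 200])
crossed_boundary = key_b != key_c
```

Both `key_a` and `key_b` evaluate to `(1, 1, 2)`, while `key_c` is `(1, 2, 2)`. The key only captures bucket counts, so it says nothing about whether the captured KV index tensors are current, which is the concern raised in the review comment that follows.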
The call to super().copy_for_decode_cuda_graph(new_state) was removed, which prevents the captured tensors (kv_starts, kv_indices, etc.) from being updated with the new values from new_state during CUDA graph replay. This will cause the attention mechanism to run with stale or incorrect KV cache indices and starts, leading to incorrect attention outputs or crashes.
Additionally, if new_state.kv_seq_lens_host is None, calling _fast_plan_tensor_core_decode will crash with a TypeError because kv_lens_arr_host is expected to be a list/tensor. We should restore the tensor copying and add a fallback to the standard plan method when host sequence lengths are not available.
    def copy_for_decode_cuda_graph(self, new_state: "FlashInferDecodeAttState"):
        super().copy_for_decode_cuda_graph(new_state)
        if new_state.kv_seq_lens_host is not None:
            # FlashInfer tensor-core decode updates its split-kv plan at 128-token
            # boundaries for this path. page_size is 1 here, so pages == tokens.
            skip_plan_key = tuple((seq_len + 127) // 128 for seq_len in new_state.kv_seq_lens_host.tolist())
            if getattr(self.decode_wrapper, "_skip_plan_key", None) == skip_plan_key:
                return
            _fast_plan_tensor_core_decode(
                self.decode_wrapper,
                new_state.backend.tp_q_head_num,
                new_state.backend.tp_kv_head_num,
                new_state.backend.head_dim,
                1,
                new_state.kv_starts_host,
                new_state.kv_seq_lens_host,
                new_state.infer_state.max_kv_seq_len,
            )
            self.decode_wrapper._skip_plan_key = skip_plan_key
        else:
            self.decode_wrapper.plan(
                new_state.kv_starts,
                new_state.kv_indices,
                new_state.kv_last_page_len_buffer,
                new_state.backend.tp_q_head_num,
                new_state.backend.tp_kv_head_num,
                new_state.backend.head_dim,
                1,
                q_data_type=new_state.backend.q_data_type,
                kv_data_type=new_state.backend.kv_data_type,
                non_blocking=True,
            )

    if not self.b_seq_len.is_cuda:
        self.b_seq_len_cpu = self.b_seq_len
        self.b_seq_len = self.b_seq_len.cuda(non_blocking=True)
If self.b_seq_len is already on CUDA when to_cuda is called, self.b_seq_len_cpu will remain None. This prevents the fast host-based planning path from being used. We can copy the tensor to CPU if it is already on CUDA to ensure self.b_seq_len_cpu is always populated.
    if not self.b_seq_len.is_cuda:
        self.b_seq_len_cpu = self.b_seq_len
    else:
        self.b_seq_len_cpu = self.b_seq_len.cpu()
    self.b_seq_len = self.b_seq_len.cuda(non_blocking=True)

2f40fc4 to 3377848
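The control flow proposed in the first suggestion (fall back to the standard plan when host lengths are missing, skip replanning when the bucket key is unchanged) can be sketched without FlashInfer or CUDA. Everything below is a hypothetical stand-in: `StubWrapper` and `replan` are not part of the PR, and the two `calls.append(...)` lines only mark where `_fast_plan_tensor_core_decode(...)` and `decode_wrapper.plan(...)` would run. Resetting the cached key on the fallback path is an assumption added here, not something the suggested code does.

```python
class StubWrapper:
    """Hypothetical stand-in for the decode wrapper; it only records plan calls."""

    def __init__(self):
        self.calls = []
        self._skip_plan_key = None


def replan(wrapper, kv_seq_lens_host):
    """Return which path ran: 'fast', 'skip', or 'standard'."""
    # The real method would first copy new_state's tensors into the captured
    # buffers (the super() call the review asks to restore).
    if kv_seq_lens_host is None:
        wrapper.calls.append("standard")  # stands in for decode_wrapper.plan(...)
        # Assumption, not in the PR: drop the cached key so a later fast-path
        # call cannot skip on top of a plan written by the standard path.
        wrapper._skip_plan_key = None
        return "standard"

    key = tuple((n + 127) // 128 for n in kv_seq_lens_host)
    if wrapper._skip_plan_key == key:
        return "skip"
    wrapper.calls.append("fast")  # stands in for _fast_plan_tensor_core_decode(...)
    wrapper._skip_plan_key = key
    return "fast"


w = StubWrapper()
results = [
    replan(w, [100, 200]),  # first call always plans
    replan(w, [101, 201]),  # same 128-token buckets
    replan(w, [129, 201]),  # first sequence crossed a boundary
    replan(w, None),        # host lengths unavailable
    replan(w, [129, 201]),  # key was reset by the fallback
]
```

With these inputs `results` is `["fast", "skip", "fast", "standard", "fast"]`. Without the reset in the fallback branch, the last call would return `"skip"` even though the standard path had replaced the plan in between. Whether that matters in practice depends on how the two planning paths share wrapper state.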
No description provided.