@@ -2296,6 +2296,9 @@ __global__ void merge_multi_chunks_decoder_kernel(
   const int bid = blockIdx.x, hid = blockIdx.y;
   __shared__ T smem[bdy * HEAD_DIM];
   __shared__ float md_smem[bdy * 2];
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  cudaGridDependencySynchronize();
+#endif
   const int start_token_idx = cu_seqlens_q[bid];
   const int seq_len_q = seq_lens_q[bid];
   if (seq_len_q == 0) return;
@@ -2332,6 +2335,10 @@ __global__ void merge_multi_chunks_decoder_kernel(
   } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
     m = -3.0e+30f;
   }
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  cudaGridDependencySynchronize();
+#endif
+
 #pragma unroll 2
   for (int i = ty; i < num_chunks_this_seq; i += bdy) {
     uint32_t offset = (bid * num_chunks + i) * num_heads + hid;
@@ -2397,6 +2404,9 @@ __global__ void merge_multi_chunks_decoder_kernel(
         out_vec,
         &out[(start_token_idx * num_heads + hid) * head_dim + vid * vec_size]);
   }
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  cudaTriggerProgrammaticLaunchCompletion();
+#endif
 }

 template <typename T,
@@ -2433,6 +2443,9 @@ __global__ void merge_multi_chunks_v2_kernel(
   const int hid = blockIdx.y;
   __shared__ T smem[bdy * HEAD_DIM];
   __shared__ float md_smem[bdy * 2];
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  cudaGridDependencySynchronize();
+#endif
   for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
     const uint32_t bid = batch_id_per_token[qid];
     if (bid == -1) {
@@ -2569,4 +2582,7 @@ __global__ void merge_multi_chunks_v2_kernel(
     }
     __syncthreads();
   }
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  cudaTriggerProgrammaticLaunchCompletion();
+#endif
 }
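
These device-side hints only change scheduling when the dependent kernel opts in to CUDA's Programmatic Dependent Launch (PDL), available on sm_90 (Hopper) with CUDA 11.8+; otherwise they are effectively no-ops and ordinary stream ordering applies. The sketch below is not from this repo — the producer/consumer kernels, buffer names, and sizes are illustrative — but it shows the host-side `cudaLaunchAttributeProgrammaticStreamSerialization` setup that pairs with the `cudaGridDependencySynchronize()` / `cudaTriggerProgrammaticLaunchCompletion()` calls added in this patch.

```cuda
#include <cuda_runtime.h>

// Illustrative producer: fills a buffer, then signals that dependent
// grids may begin launching before all of its blocks have exited.
__global__ void producer_kernel(float* buf, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) buf[i] = 2.0f * i;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  cudaTriggerProgrammaticLaunchCompletion();
#endif
}

// Illustrative consumer: waits until the producer grid's memory
// operations are visible before reading the buffer, mirroring the
// cudaGridDependencySynchronize() placement in the merge kernels.
__global__ void consumer_kernel(const float* buf, float* out, int n) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  cudaGridDependencySynchronize();
#endif
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = buf[i] + 1.0f;
}

int main() {
  const int n = 1 << 20, block = 256, grid = (n + block - 1) / block;
  float *buf = nullptr, *out = nullptr;
  cudaMalloc(&buf, n * sizeof(float));
  cudaMalloc(&out, n * sizeof(float));
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  producer_kernel<<<grid, block, 0, stream>>>(buf, n);

  // The consumer opts in to PDL: its launch may overlap the tail of the
  // producer, and cudaGridDependencySynchronize() supplies the ordering.
  cudaLaunchAttribute attr{};
  attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attr.val.programmaticStreamSerialization = 1;

  cudaLaunchConfig_t cfg{};
  cfg.gridDim = grid;
  cfg.blockDim = block;
  cfg.stream = stream;
  cfg.attrs = &attr;
  cfg.numAttrs = 1;
  cudaLaunchKernelEx(&cfg, consumer_kernel,
                     static_cast<const float*>(buf), out, n);

  cudaStreamSynchronize(stream);
  cudaFree(buf);
  cudaFree(out);
  cudaStreamDestroy(stream);
  return 0;
}
```

In the patch itself the merge kernels play both roles: they synchronize on the grids that produced the per-chunk partial results before combining them, and trigger completion once their outputs are stored so a following dependent kernel can start its blocks early.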