@@ -61,6 +61,115 @@ __global__ __launch_bounds__(kMaxThreads) void permute_1D_data_kernel(
   }
 }

+// Vectorized kernel for permuting the indices and weights. Used for permutation
+// of sparse data. Uses vec4 loads for improved memory bandwidth.
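+// One segment is handled per (blockIdx.x, threadIdx.y) pair; the threadIdx.x
+// lanes of that row copy the segment's elements cooperatively.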
+template <
+    bool has_weight,
+    typename offsets_t,
+    typename indices_t,
+    typename weights_t>
+__global__ __launch_bounds__(kMaxThreads) void permute_1D_data_kernel_vec(
+    int32_t permuted_indices_size,
+    int32_t permuted_lengths_size,
+    const indices_t* __restrict__ indices,
+    const weights_t* __restrict__ weights,
+    const int32_t* __restrict__ permute,
+    const offsets_t* __restrict__ input_offsets,
+    const offsets_t* __restrict__ output_offsets,
+    indices_t* __restrict__ permuted_indices,
+    weights_t* __restrict__ permuted_weights) {
+  // Select vector types based on element size (vec4 for 4× bandwidth)
+  using indices_vec4_t =
+      typename std::conditional<sizeof(indices_t) == 8, long4, float4>::type;
+  using weights_vec4_t =
+      typename std::conditional<sizeof(weights_t) == 8, long4, float4>::type;
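+  // float4 (16 B) carries four 4-byte elements; long4 (32 B) carries four
+  // 8-byte elements. Both CUDA vector types are 16-byte aligned.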
+
+  const auto b_t_start = blockIdx.x * blockDim.y + threadIdx.y;
+  const auto stride = gridDim.x * blockDim.y;
+
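+  // Grid-stride loop: every segment is covered regardless of grid size.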
+  for (int b_t = b_t_start; b_t < permuted_lengths_size; b_t += stride) {
+    // Read offsets once - use int32_t for segment_length as it fits in 32 bits
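+    // output_offsets holds one start offset per segment, so the last
+    // segment's end is the total permuted_indices_size, not offsets[b_t + 1].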
+    const offsets_t output_start = output_offsets[b_t];
+    const offsets_t output_end = (b_t == permuted_lengths_size - 1)
+        ? permuted_indices_size
+        : output_offsets[b_t + 1];
+    const int32_t segment_length =
+        static_cast<int32_t>(output_end - output_start);
+    const offsets_t input_start = input_offsets[permute[b_t]];
+
+    // Compute pointers
+    indices_t* __restrict__ indices_dst_ptr = permuted_indices + output_start;
+    const indices_t* __restrict__ indices_src_ptr = indices + input_start;
+    weights_t* __restrict__ weights_dst_ptr =
+        has_weight ? permuted_weights + output_start : nullptr;
+    const weights_t* __restrict__ weights_src_ptr =
+        has_weight ? weights + input_start : nullptr;
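+    // has_weight is a compile-time template parameter, so the weight branches
+    // compile away entirely in the unweighted instantiation.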
+
+    // Check alignment once per segment
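+    // (addr & (alignof(vec4_t) - 1)) == 0 tests 16-byte alignment; the
+    // offsets are cumulative, so each segment's base alignment can differ
+    // and must be re-checked here.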
+    const bool indices_vec4_aligned =
+        (sizeof(indices_t) == 4 || sizeof(indices_t) == 8) &&
+        (reinterpret_cast<uintptr_t>(indices_dst_ptr) &
+         (alignof(indices_vec4_t) - 1)) == 0 &&
+        (reinterpret_cast<uintptr_t>(indices_src_ptr) &
+         (alignof(indices_vec4_t) - 1)) == 0;
+
+    // Size guard mirrors the indices check above: weight types narrower than
+    // 4 bytes (if ever dispatched) must take the scalar path, since the vec4
+    // reinterpret below would otherwise copy the wrong number of bytes.
+    const bool weights_vec4_aligned = !has_weight ||
+        ((sizeof(weights_t) == 4 || sizeof(weights_t) == 8) &&
+         (reinterpret_cast<uintptr_t>(weights_dst_ptr) &
+          (alignof(weights_vec4_t) - 1)) == 0 &&
+         (reinterpret_cast<uintptr_t>(weights_src_ptr) &
+          (alignof(weights_vec4_t) - 1)) == 0);
+
+    if (indices_vec4_aligned && weights_vec4_aligned) {
+      // Vectorized path - process both indices and weights together
+      const int32_t vec4_count = segment_length / 4;
+      const int32_t remainder = segment_length & 3; // segment_length % 4
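+      // e.g. segment_length = 10 -> two vec4 copies (8 elements) plus a
+      // 2-element tail handled below by the first `remainder` lanes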
+
+      auto indices_dst = reinterpret_cast<indices_vec4_t*>(indices_dst_ptr);
+      auto indices_src =
+          reinterpret_cast<const indices_vec4_t*>(indices_src_ptr);
+
+      if (has_weight) {
+        auto weights_dst = reinterpret_cast<weights_vec4_t*>(weights_dst_ptr);
+        auto weights_src =
+            reinterpret_cast<const weights_vec4_t*>(weights_src_ptr);
+
+        // copy both indices and weights
+        #pragma unroll
+        for (auto i = threadIdx.x; i < vec4_count; i += blockDim.x) {
+          indices_dst[i] = indices_src[i];
+          weights_dst[i] = weights_src[i];
+        }
+        // Handle remainder elements (0-3 elements)
+        if (threadIdx.x < remainder) {
+          const auto offset = vec4_count * 4 + threadIdx.x;
+          indices_dst_ptr[offset] = indices_src_ptr[offset];
+          weights_dst_ptr[offset] = weights_src_ptr[offset];
+        }
+      } else {
+        // copy only indices
+        #pragma unroll
+        for (auto i = threadIdx.x; i < vec4_count; i += blockDim.x) {
+          indices_dst[i] = indices_src[i];
+        }
+
+        // Handle remainder elements (0-3 elements)
+        if (threadIdx.x < remainder) {
+          const auto offset = vec4_count * 4 + threadIdx.x;
+          indices_dst_ptr[offset] = indices_src_ptr[offset];
+        }
+      }
+    } else {
+      // Scalar fallback path
+      for (auto i = threadIdx.x; i < segment_length; i += blockDim.x) {
+        indices_dst_ptr[i] = indices_src_ptr[i];
+        if (has_weight) {
+          weights_dst_ptr[i] = weights_src_ptr[i];
+        }
+      }
+    }
+  }
+}
+
 DLL_PUBLIC std::tuple<Tensor, Tensor, std::optional<Tensor>>
 permute_1D_sparse_data_cuda(
     const Tensor& permute,
@@ -124,17 +233,17 @@ permute_1D_sparse_data_cuda(
     permuted_indices_size = output_offsets[-1].item<int64_t>();
   }

-  constexpr int32_t BT_blocks = 32;
-  dim3 threads_2(32, BT_blocks);
+  constexpr int32_t BT_blocks = 16;
+  dim3 threads_2(64, BT_blocks);
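+  // 64 threadIdx.x lanes per segment feed the vec4 copies; 16 segments per
+  // block keeps the block at 64 * 16 = 1024 threads.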
   const auto blocks_2 =
       cuda_calc_xblock_count(permuted_lengths_size, BT_blocks);
   permuted_indices = at::empty(permuted_indices_size, indices.options());

   AT_DISPATCH_INDEX_TYPES(
-      input_offsets.scalar_type(), "permute_1D_data_kernel_1", [&] {
+      input_offsets.scalar_type(), "permute_1D_data_kernel_vec_1", [&] {
         using offsets_t = index_t;
         FBGEMM_DISPATCH_ALL_TYPES(
-            indices.scalar_type(), "permute_1D_data_kernel_2", [&] {
+            indices.scalar_type(), "permute_1D_data_kernel_vec_2", [&] {
               using indices_t = scalar_t;
               if (weights.has_value()) {
                 const Tensor weights_value = weights.value();
@@ -143,11 +252,11 @@ permute_1D_sparse_data_cuda(
                     at::empty(permuted_indices_size, weights_value.options());
                 FBGEMM_DISPATCH_ALL_TYPES_AND_DOUBLE(
                     weights_value.scalar_type(),
-                    "permute_1D_data_kernel_3",
+                    "permute_1D_data_kernel_vec_3",
                     [&] {
                       using weights_t = scalar_t;
                       FBGEMM_LAUNCH_KERNEL(
-                          (permute_1D_data_kernel<
+                          (permute_1D_data_kernel_vec<
                               true,
                               offsets_t,
                               indices_t,
@@ -168,7 +277,7 @@ permute_1D_sparse_data_cuda(
                     }); // for each weights_t
               } else {
                 FBGEMM_LAUNCH_KERNEL(
-                    (permute_1D_data_kernel<
+                    (permute_1D_data_kernel_vec<
                         false,
                         offsets_t,
                         indices_t,