From 64df81441650782d839eb2b2235bc8096ef4ad6b Mon Sep 17 00:00:00 2001
From: Kaustubh Vartak
Date: Tue, 18 Nov 2025 11:53:13 -0800
Subject: [PATCH] Support 2D weights permute for strided keys (#5145)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/2144

Support 2D weights in the 1D length permute tensor kernel. This kernel is
invoked when strides vary per rank.

Similar to D80466458:
https://github.com/meta-pytorch/torchrec/blob/main/torchrec/sparse/jagged_tensor.py#L3111-L3134

2D weights are needed for the write dist path, where the weights are actually
embedding values.

Reviewed By: spcyppt, q10

Differential Revision: D87261479
---
 .../src/sparse_ops/sparse_permute_1d.cu      |  26 +++-
 .../sparse/permute_sparse_features_test.py   | 138 +++++++++++++++++-
 2 files changed, 157 insertions(+), 7 deletions(-)

diff --git a/fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu b/fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu
index 6e7ca51614..99ee97447f 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu
@@ -40,7 +40,8 @@ __global__ __launch_bounds__(kMaxThreads) void permute_1D_data_kernel(
     const offsets_t* __restrict__ input_offsets,
     const offsets_t* __restrict__ output_offsets,
     indices_t* __restrict__ permuted_indices,
-    weights_t* __restrict__ permuted_weights) {
+    weights_t* __restrict__ permuted_weights,
+    int32_t weights_columns) {
   auto b_t_start = blockIdx.x * blockDim.y + threadIdx.y;
   const auto stride = gridDim.x * blockDim.y;
   for (int b_t = b_t_start; b_t < permuted_lengths_size; b_t += stride) {
@@ -55,7 +56,10 @@ __global__ __launch_bounds__(kMaxThreads) void permute_1D_data_kernel(
     for (auto i = threadIdx.x; i < segment_length; i += blockDim.x) {
       permuted_indices[output_start + i] = indices[input_start + i];
       if (has_weight) {
-        permuted_weights[output_start + i] = weights[input_start + i];
+        for (int col = 0; col < weights_columns; ++col) {
+          permuted_weights[(output_start + i) * weights_columns + col] =
+              weights[(input_start + i) * weights_columns + col];
+        }
       }
     }
   }
@@ -139,8 +143,16 @@ permute_1D_sparse_data_cuda(
       if (weights.has_value()) {
         const Tensor weights_value = weights.value();
         const auto weights_value_contig = weights_value.contiguous();
-        permuted_weights =
-            at::empty(permuted_indices_size, weights_value.options());
+        int32_t weights_columns = 1;
+        if (weights_value.dense_dim() > 1) {
+          weights_columns = weights_value.size(1);
+          permuted_weights = at::empty(
+              {permuted_indices_size, weights_columns},
+              weights_value.options());
+        } else {
+          permuted_weights =
+              at::empty(permuted_indices_size, weights_value.options());
+        }
         FBGEMM_DISPATCH_ALL_TYPES_AND_DOUBLE(
             weights_value.scalar_type(),
             "permute_1D_data_kernel_3",
@@ -164,7 +176,8 @@ permute_1D_sparse_data_cuda(
                   input_offsets.data_ptr<offsets_t>(),
                   output_offsets.data_ptr<offsets_t>(),
                   permuted_indices.data_ptr<indices_t>(),
-                  permuted_weights.data_ptr<weights_t>());
+                  permuted_weights.data_ptr<weights_t>(),
+                  weights_columns);
             }); // for each weights_t
       } else {
         FBGEMM_LAUNCH_KERNEL(
@@ -185,7 +198,8 @@ permute_1D_sparse_data_cuda(
             input_offsets.data_ptr<offsets_t>(),
             output_offsets.data_ptr<offsets_t>(),
             permuted_indices.data_ptr<indices_t>(),
-            nullptr);
+            nullptr,
+            0);
       }
     }); // for each indices_t
   }); // for each offsets_t
diff --git a/fbgemm_gpu/test/sparse/permute_sparse_features_test.py b/fbgemm_gpu/test/sparse/permute_sparse_features_test.py
index 38c9c9581d..75010d052e 100644
--- a/fbgemm_gpu/test/sparse/permute_sparse_features_test.py
+++ b/fbgemm_gpu/test/sparse/permute_sparse_features_test.py
@@ -73,7 +73,12 @@ def permute_sparse_features_ref_(
     )
     @settings(max_examples=20, deadline=None)
     def test_permute_sparse_features(
-        self, B: int, T: int, L: int, long_index: bool, has_weight: bool
+        self,
+        B: int,
+        T: int,
+        L: int,
+        long_index: bool,
+        has_weight: bool,
     ) -> None:
         index_dtype = torch.int64 if long_index else torch.int32
         lengths = torch.randint(low=1, high=L, size=(T, B)).type(index_dtype)
@@ -193,6 +198,137 @@ def test_permute_sparse_features_with_repeats(
         assert permuted_weights_cpu is None
 
 
+class Permute1DSparseFeaturesTest(unittest.TestCase):
+    @unittest.skipIf(*gpu_unavailable)
+    @given(
+        T=st.integers(min_value=1, max_value=20),
+        L=st.integers(min_value=2, max_value=20),
+        long_index=st.booleans(),
+        has_weight=st.booleans(),
+        weight_columns=st.integers(min_value=1, max_value=20),
+    )
+    @settings(
+        max_examples=20,
+        deadline=None,
+    )
+    def test_permute_1D_sparse_data(
+        self,
+        T: int,
+        L: int,
+        long_index: bool,
+        has_weight: bool,
+        weight_columns: int,
+    ) -> None:
+        # Setup: Choose index data type based on test parameter
+        index_dtype = torch.int64 if long_index else torch.int32
+
+        # Create 1D lengths tensor representing sparse feature counts for T features
+        lengths = torch.randint(
+            low=1,
+            high=L,
+            size=(T,),  # 1D tensor with T elements
+            device=torch.accelerator.current_accelerator(),
+        ).type(index_dtype)
+
+        # Create optional 2D weights tensor with dimensions [total_indices, weight_columns]
+        weights = (
+            torch.rand(
+                int(lengths.sum().item()),
+                weight_columns,
+                device=torch.accelerator.current_accelerator(),
+            ).float()
+            if has_weight
+            else None
+        )
+
+        # Create indices tensor containing sparse feature indices
+        indices = torch.randint(
+            low=1,
+            high=int(1e5),
+            size=cast(tuple[int, ...], (lengths.sum().item(),)),
+            device=torch.accelerator.current_accelerator(),
+        ).type(index_dtype)
+
+        # Create random permutation for shuffling features
+        permute_list = list(range(T))
+        random.shuffle(permute_list)
+        permute = torch.IntTensor(permute_list).to(
+            device=torch.accelerator.current_accelerator()
+        )
+        # Execute: Call the permute_1D_sparse_data operation
+        (
+            lengths_actual,
+            values_actual,
+            weights_actual,
+        ) = torch.ops.fbgemm.permute_1D_sparse_data(
+            permute, lengths, indices, weights, indices.numel()
+        )
+
+        # Assert: Verify that the lengths were correctly permuted
+        # The permuted lengths should match the original lengths indexed by the permutation
+        self.assertTrue(
+            torch.equal(
+                lengths_actual, torch.index_select(lengths, dim=0, index=permute)
+            )
+        )
+
+        # Track the current position in the permuted output for validation
+        permuted_cumulated_index = 0
+
+        # Compute cumulative offsets to locate each feature's data in the original arrays
+        # Prepend a zero to get offsets: [0, lengths[0], lengths[0]+lengths[1], ...]
+        cumulative_indices = torch.cumsum(
+            torch.cat(
+                (
+                    torch.zeros((1,), dtype=index_dtype, device=lengths.device),
+                    lengths,
+                )
+            ),
+            dim=0,
+        )
+
+        # Verify each feature's data was correctly permuted
+        for i in range(T):
+            # Get the original feature index that should appear at position i in the permuted output
+            permuted_index = permute[i]
+
+            # Assert: Verify that the indices for this feature were correctly copied
+            # Compare the segment in the permuted output against the original segment
+            self.assertTrue(
+                torch.equal(
+                    values_actual[
+                        permuted_cumulated_index : permuted_cumulated_index
+                        + lengths[permuted_index]
+                    ],
+                    indices[
+                        cumulative_indices[permuted_index] : lengths[permuted_index]
+                        + cumulative_indices[permuted_index]
+                    ],
+                )
+            )
+
+            # Assert: If weights are present, verify they were also correctly permuted
+            if has_weight and weights is not None:
+                self.assertTrue(
+                    torch.equal(
+                        weights_actual[
+                            permuted_cumulated_index : permuted_cumulated_index
+                            + lengths[permuted_index]
+                        ],
+                        weights[
+                            cumulative_indices[permuted_index] : lengths[permuted_index]
+                            + cumulative_indices[permuted_index]
+                        ],
+                    )
+                )
+            else:
+                # Assert: If no weights were provided, ensure the output also has no weights
+                assert weights_actual is None
+
+            # Move to the next segment in the permuted output
+            permuted_cumulated_index += lengths[permuted_index]
+
+
 class Permute2DSparseFeaturesTest(unittest.TestCase):
     @unittest.skipIf(*gpu_unavailable)
    def test_permute_2D_sparse_data(self) -> None:
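--
Usage sketch (not part of the diff above): a minimal example of the new
2D-weights path, assuming FBGEMM is built with this change so that
torch.ops.fbgemm.permute_1D_sparse_data is registered (the op's signature is
taken from the test in this patch); the tensor values are illustrative only.

    import torch
    import fbgemm_gpu  # noqa: F401  # assumed to register torch.ops.fbgemm.*

    device = "cuda"
    # Three keys with 2, 1, and 3 indices respectively.
    lengths = torch.tensor([2, 1, 3], dtype=torch.int32, device=device)
    indices = torch.arange(6, dtype=torch.int32, device=device)
    # 2D weights: one 4-dim embedding row per index (the write dist case,
    # where weights carry embedding values rather than per-index scalars).
    weights = torch.rand(6, 4, device=device)
    permute = torch.tensor([2, 0, 1], dtype=torch.int32, device=device)

    permuted_lengths, permuted_indices, permuted_weights = (
        torch.ops.fbgemm.permute_1D_sparse_data(
            permute, lengths, indices, weights, indices.numel()
        )
    )
    # permuted_weights keeps shape [6, 4]: each embedding row moves with its
    # index, matching the per-column copy loop in permute_1D_data_kernel.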