Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ __global__ __launch_bounds__(kMaxThreads) void permute_1D_data_kernel(
const offsets_t* __restrict__ input_offsets,
const offsets_t* __restrict__ output_offsets,
indices_t* __restrict__ permuted_indices,
weights_t* __restrict__ permuted_weights) {
weights_t* __restrict__ permuted_weights,
int32_t weights_columns) {
auto b_t_start = blockIdx.x * blockDim.y + threadIdx.y;
const auto stride = gridDim.x * blockDim.y;
for (int b_t = b_t_start; b_t < permuted_lengths_size; b_t += stride) {
Expand All @@ -55,7 +56,10 @@ __global__ __launch_bounds__(kMaxThreads) void permute_1D_data_kernel(
for (auto i = threadIdx.x; i < segment_length; i += blockDim.x) {
permuted_indices[output_start + i] = indices[input_start + i];
if (has_weight) {
permuted_weights[output_start + i] = weights[input_start + i];
for (int col = 0; col < weights_columns; ++col) {
permuted_weights[(output_start + i) * weights_columns + col] =
weights[(input_start + i) * weights_columns + col];
}
}
}
}
Expand Down Expand Up @@ -139,8 +143,16 @@ permute_1D_sparse_data_cuda(
if (weights.has_value()) {
const Tensor weights_value = weights.value();
const auto weights_value_contig = weights_value.contiguous();
permuted_weights =
at::empty(permuted_indices_size, weights_value.options());
int32_t weights_columns = 1;
if (weights_value.dense_dim() > 1) {
weights_columns = weights_value.size(1);
permuted_weights = at::empty(
{permuted_indices_size, weights_columns},
weights_value.options());
} else {
permuted_weights =
at::empty(permuted_indices_size, weights_value.options());
}
FBGEMM_DISPATCH_ALL_TYPES_AND_DOUBLE(
weights_value.scalar_type(),
"permute_1D_data_kernel_3",
Expand All @@ -164,7 +176,8 @@ permute_1D_sparse_data_cuda(
input_offsets.data_ptr<offsets_t>(),
output_offsets.data_ptr<offsets_t>(),
permuted_indices.data_ptr<indices_t>(),
permuted_weights.data_ptr<weights_t>());
permuted_weights.data_ptr<weights_t>(),
weights_columns);
}); // for each weights_t
} else {
FBGEMM_LAUNCH_KERNEL(
Expand All @@ -185,7 +198,8 @@ permute_1D_sparse_data_cuda(
input_offsets.data_ptr<offsets_t>(),
output_offsets.data_ptr<offsets_t>(),
permuted_indices.data_ptr<indices_t>(),
nullptr);
nullptr,
0);
}
}); // for each indices_t
}); // for each offsets_t
Expand Down
138 changes: 137 additions & 1 deletion fbgemm_gpu/test/sparse/permute_sparse_features_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,12 @@ def permute_sparse_features_ref_(
)
@settings(max_examples=20, deadline=None)
def test_permute_sparse_features(
self, B: int, T: int, L: int, long_index: bool, has_weight: bool
self,
B: int,
T: int,
L: int,
long_index: bool,
has_weight: bool,
) -> None:
index_dtype = torch.int64 if long_index else torch.int32
lengths = torch.randint(low=1, high=L, size=(T, B)).type(index_dtype)
Expand Down Expand Up @@ -193,6 +198,137 @@ def test_permute_sparse_features_with_repeats(
assert permuted_weights_cpu is None


class Permute1DSparseFeaturesTest(unittest.TestCase):
    """Property-based test for ``torch.ops.fbgemm.permute_1D_sparse_data``."""

    @unittest.skipIf(*gpu_unavailable)
    @given(
        T=st.integers(min_value=1, max_value=20),
        L=st.integers(min_value=2, max_value=20),
        long_index=st.booleans(),
        has_weight=st.booleans(),
        weight_columns=st.integers(min_value=1, max_value=20),
    )
    @settings(
        max_examples=20,
        deadline=None,
    )
    def test_permute_1D_sparse_data(
        self,
        T: int,
        L: int,
        long_index: bool,
        has_weight: bool,
        weight_columns: int,
    ) -> None:
        """Permute T variable-length segments and compare against the inputs.

        Args:
            T: number of segments (features) to permute.
            L: exclusive upper bound for each segment length; lengths are
                drawn from ``[1, L)``.
            long_index: use ``torch.int64`` for lengths/indices when True,
                ``torch.int32`` otherwise.
            has_weight: whether a weights tensor is passed to the op.
            weight_columns: number of columns of the 2D weights tensor
                (``[total_indices, weight_columns]``).
        """
        device = torch.accelerator.current_accelerator()
        index_dtype = torch.int64 if long_index else torch.int32

        # 1D lengths tensor: one entry per segment.
        lengths = torch.randint(
            low=1,
            high=L,
            size=(T,),
            device=device,
        ).type(index_dtype)

        # Pull the lengths to the host once; every slice bound used in the
        # checks below is then a plain Python int instead of a 0-dim device
        # tensor that has to be synchronized on each use.
        lengths_host = lengths.tolist()
        total_indices = sum(lengths_host)

        # Optional 2D weights tensor: one row of `weight_columns` per index.
        weights = (
            torch.rand(
                total_indices,
                weight_columns,
                device=device,
            ).float()
            if has_weight
            else None
        )

        indices = torch.randint(
            low=1,
            high=int(1e5),
            size=(total_indices,),
            device=device,
        ).type(index_dtype)

        # Random permutation of the segment order (int32, like IntTensor).
        permute_list = list(range(T))
        random.shuffle(permute_list)
        permute = torch.tensor(permute_list, dtype=torch.int32, device=device)

        (
            lengths_actual,
            values_actual,
            weights_actual,
        ) = torch.ops.fbgemm.permute_1D_sparse_data(
            permute, lengths, indices, weights, indices.numel()
        )

        # Output lengths must be the input lengths gathered by `permute`.
        self.assertTrue(
            torch.equal(
                lengths_actual, torch.index_select(lengths, dim=0, index=permute)
            )
        )

        # Weight presence does not depend on the segment, so check it once.
        # Use unittest assertions: a bare `assert` is removed under
        # `python -O`, which would leave this case unchecked.
        if weights is None:
            self.assertIsNone(weights_actual)
        else:
            self.assertIsNotNone(weights_actual)

        # Exclusive prefix sum of the lengths: start of each segment in the
        # original (un-permuted) arrays.
        input_starts = [0]
        for segment_length in lengths_host:
            input_starts.append(input_starts[-1] + segment_length)

        # Walk the permuted output segment by segment.
        output_start = 0
        for source_segment in permute_list:
            segment_length = lengths_host[source_segment]
            input_start = input_starts[source_segment]
            output_end = output_start + segment_length
            input_end = input_start + segment_length

            # The indices of the segment must be copied verbatim.
            self.assertTrue(
                torch.equal(
                    values_actual[output_start:output_end],
                    indices[input_start:input_end],
                )
            )

            # Slicing the leading dimension compares whole weight rows.
            if weights is not None:
                self.assertTrue(
                    torch.equal(
                        weights_actual[output_start:output_end],
                        weights[input_start:input_end],
                    )
                )

            output_start = output_end


class Permute2DSparseFeaturesTest(unittest.TestCase):
@unittest.skipIf(*gpu_unavailable)
def test_permute_2D_sparse_data(self) -> None:
Expand Down
Loading