Skip to content

Commit e1fc43e

Browse files
kausvmeta-codesync[bot]
authored andcommitted
Support 2D weights permute for strided keys (#5145)
Summary: Pull Request resolved: #5145 X-link: https://github.com/facebookresearch/FBGEMM/pull/2144 Support 2D weights in 1D length permute tensor kernel. This kernel is invoked on variable strides per rank. Similar to D80466458 https://github.com/meta-pytorch/torchrec/blob/main/torchrec/sparse/jagged_tensor.py#L3111-L3134 2D weights are needed for write dist where the weights are actually embedding values. Reviewed By: spcyppt, q10 Differential Revision: D87261479 fbshipit-source-id: ae56ff53c07f99f167abfec4b472aa57bd989ce8
1 parent 8189ad4 commit e1fc43e

File tree

2 files changed

+157
-7
lines changed

2 files changed

+157
-7
lines changed

fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ __global__ __launch_bounds__(kMaxThreads) void permute_1D_data_kernel(
4040
const offsets_t* __restrict__ input_offsets,
4141
const offsets_t* __restrict__ output_offsets,
4242
indices_t* __restrict__ permuted_indices,
43-
weights_t* __restrict__ permuted_weights) {
43+
weights_t* __restrict__ permuted_weights,
44+
int32_t weights_columns) {
4445
auto b_t_start = blockIdx.x * blockDim.y + threadIdx.y;
4546
const auto stride = gridDim.x * blockDim.y;
4647
for (int b_t = b_t_start; b_t < permuted_lengths_size; b_t += stride) {
@@ -55,7 +56,10 @@ __global__ __launch_bounds__(kMaxThreads) void permute_1D_data_kernel(
5556
for (auto i = threadIdx.x; i < segment_length; i += blockDim.x) {
5657
permuted_indices[output_start + i] = indices[input_start + i];
5758
if (has_weight) {
58-
permuted_weights[output_start + i] = weights[input_start + i];
59+
for (int col = 0; col < weights_columns; ++col) {
60+
permuted_weights[(output_start + i) * weights_columns + col] =
61+
weights[(input_start + i) * weights_columns + col];
62+
}
5963
}
6064
}
6165
}
@@ -139,8 +143,16 @@ permute_1D_sparse_data_cuda(
139143
if (weights.has_value()) {
140144
const Tensor weights_value = weights.value();
141145
const auto weights_value_contig = weights_value.contiguous();
142-
permuted_weights =
143-
at::empty(permuted_indices_size, weights_value.options());
146+
int32_t weights_columns = 1;
147+
if (weights_value.dense_dim() > 1) {
148+
weights_columns = weights_value.size(1);
149+
permuted_weights = at::empty(
150+
{permuted_indices_size, weights_columns},
151+
weights_value.options());
152+
} else {
153+
permuted_weights =
154+
at::empty(permuted_indices_size, weights_value.options());
155+
}
144156
FBGEMM_DISPATCH_ALL_TYPES_AND_DOUBLE(
145157
weights_value.scalar_type(),
146158
"permute_1D_data_kernel_3",
@@ -164,7 +176,8 @@ permute_1D_sparse_data_cuda(
164176
input_offsets.data_ptr<offsets_t>(),
165177
output_offsets.data_ptr<offsets_t>(),
166178
permuted_indices.data_ptr<indices_t>(),
167-
permuted_weights.data_ptr<weights_t>());
179+
permuted_weights.data_ptr<weights_t>(),
180+
weights_columns);
168181
}); // for each weights_t
169182
} else {
170183
FBGEMM_LAUNCH_KERNEL(
@@ -185,7 +198,8 @@ permute_1D_sparse_data_cuda(
185198
input_offsets.data_ptr<offsets_t>(),
186199
output_offsets.data_ptr<offsets_t>(),
187200
permuted_indices.data_ptr<indices_t>(),
188-
nullptr);
201+
nullptr,
202+
0);
189203
}
190204
}); // for each indices_t
191205
}); // for each offsets_t

fbgemm_gpu/test/sparse/permute_sparse_features_test.py

Lines changed: 137 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,12 @@ def permute_sparse_features_ref_(
7373
)
7474
@settings(max_examples=20, deadline=None)
7575
def test_permute_sparse_features(
76-
self, B: int, T: int, L: int, long_index: bool, has_weight: bool
76+
self,
77+
B: int,
78+
T: int,
79+
L: int,
80+
long_index: bool,
81+
has_weight: bool,
7782
) -> None:
7883
index_dtype = torch.int64 if long_index else torch.int32
7984
lengths = torch.randint(low=1, high=L, size=(T, B)).type(index_dtype)
@@ -193,6 +198,137 @@ def test_permute_sparse_features_with_repeats(
193198
assert permuted_weights_cpu is None
194199

195200

201+
class Permute1DSparseFeaturesTest(unittest.TestCase):
202+
@unittest.skipIf(*gpu_unavailable)
203+
@given(
204+
T=st.integers(min_value=1, max_value=20),
205+
L=st.integers(min_value=2, max_value=20),
206+
long_index=st.booleans(),
207+
has_weight=st.booleans(),
208+
weight_columns=st.integers(min_value=1, max_value=20),
209+
)
210+
@settings(
211+
max_examples=20,
212+
deadline=None,
213+
)
214+
def test_permute_1D_sparse_data(
215+
self,
216+
T: int,
217+
L: int,
218+
long_index: bool,
219+
has_weight: bool,
220+
weight_columns: int,
221+
) -> None:
222+
# Setup: Choose index data type based on test parameter
223+
index_dtype = torch.int64 if long_index else torch.int32
224+
225+
# Create 1D lengths tensor representing sparse feature counts for T features
226+
lengths = torch.randint(
227+
low=1,
228+
high=L,
229+
size=(T,), # 1D tensor with T elements
230+
device=torch.accelerator.current_accelerator(),
231+
).type(index_dtype)
232+
233+
# Create optional 2D weights tensor with dimensions [total_indices, weight_columns]
234+
weights = (
235+
torch.rand(
236+
int(lengths.sum().item()),
237+
weight_columns,
238+
device=torch.accelerator.current_accelerator(),
239+
).float()
240+
if has_weight
241+
else None
242+
)
243+
244+
# Create indices tensor containing sparse feature indices
245+
indices = torch.randint(
246+
low=1,
247+
high=int(1e5),
248+
size=cast(tuple[int, ...], (lengths.sum().item(),)),
249+
device=torch.accelerator.current_accelerator(),
250+
).type(index_dtype)
251+
252+
# Create random permutation for shuffling features
253+
permute_list = list(range(T))
254+
random.shuffle(permute_list)
255+
permute = torch.IntTensor(permute_list).to(
256+
device=torch.accelerator.current_accelerator()
257+
)
258+
# Execute: Call the permute_1D_sparse_data operation
259+
(
260+
lengths_actual,
261+
values_actual,
262+
weights_actual,
263+
) = torch.ops.fbgemm.permute_1D_sparse_data(
264+
permute, lengths, indices, weights, indices.numel()
265+
)
266+
267+
# Assert: Verify that the lengths were correctly permuted
268+
# The permuted lengths should match the original lengths indexed by the permutation
269+
self.assertTrue(
270+
torch.equal(
271+
lengths_actual, torch.index_select(lengths, dim=0, index=permute)
272+
)
273+
)
274+
275+
# Track the current position in the permuted output for validation
276+
permuted_cumulated_index = 0
277+
278+
# Compute cumulative offsets to locate each feature's data in the original arrays
279+
# Prepend a zero to get offsets: [0, lengths[0], lengths[0]+lengths[1], ...]
280+
cumulative_indices = torch.cumsum(
281+
torch.cat(
282+
(
283+
torch.zeros((1,), dtype=index_dtype, device=lengths.device),
284+
lengths,
285+
)
286+
),
287+
dim=0,
288+
)
289+
290+
# Verify each feature's data was correctly permuted
291+
for i in range(T):
292+
# Get the original feature index that should appear at position i in the permuted output
293+
permuted_index = permute[i]
294+
295+
# Assert: Verify that the indices for this feature were correctly copied
296+
# Compare the segment in the permuted output against the original segment
297+
self.assertTrue(
298+
torch.equal(
299+
values_actual[
300+
permuted_cumulated_index : permuted_cumulated_index
301+
+ lengths[permuted_index]
302+
],
303+
indices[
304+
cumulative_indices[permuted_index] : lengths[permuted_index]
305+
+ cumulative_indices[permuted_index]
306+
],
307+
)
308+
)
309+
310+
# Assert: If weights are present, verify they were also correctly permuted
311+
if has_weight and weights is not None:
312+
self.assertTrue(
313+
torch.equal(
314+
weights_actual[
315+
permuted_cumulated_index : permuted_cumulated_index
316+
+ lengths[permuted_index]
317+
],
318+
weights[
319+
cumulative_indices[permuted_index] : lengths[permuted_index]
320+
+ cumulative_indices[permuted_index]
321+
],
322+
)
323+
)
324+
else:
325+
# Assert: If no weights were provided, ensure the output also has no weights
326+
assert weights_actual is None
327+
328+
# Move to the next segment in the permuted output
329+
permuted_cumulated_index += lengths[permuted_index]
330+
331+
196332
class Permute2DSparseFeaturesTest(unittest.TestCase):
197333
@unittest.skipIf(*gpu_unavailable)
198334
def test_permute_2D_sparse_data(self) -> None:

0 commit comments

Comments
 (0)