Skip to content

Commit 605c525

Browse files
kausv authored and facebook-github-bot committed
Support 2D weights permute for strided keys
Summary: X-link: facebookresearch/FBGEMM#2144. Support 2D weights in the 1D length permute tensor kernel. This kernel is invoked on variable strides per rank (see https://www.internalfb.com/code/fbsource/[5e1d1c3734c75d0664ba817ee05ce9a91a1f02e4]/fbcode/torchrec/sparse/jagged_tensor.py?lines=3111-3134). 2D weights are needed for write dist, where the weights are actually embedding values. Differential Revision: D87261479
1 parent bc6d968 commit 605c525

File tree

2 files changed

+131
-8
lines changed

2 files changed

+131
-8
lines changed

fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ __global__ __launch_bounds__(kMaxThreads) void permute_1D_data_kernel(
4040
const offsets_t* __restrict__ input_offsets,
4141
const offsets_t* __restrict__ output_offsets,
4242
indices_t* __restrict__ permuted_indices,
43-
weights_t* __restrict__ permuted_weights) {
43+
weights_t* __restrict__ permuted_weights,
44+
int32_t weights_columns) {
4445
auto b_t_start = blockIdx.x * blockDim.y + threadIdx.y;
4546
const auto stride = gridDim.x * blockDim.y;
4647
for (int b_t = b_t_start; b_t < permuted_lengths_size; b_t += stride) {
@@ -55,7 +56,10 @@ __global__ __launch_bounds__(kMaxThreads) void permute_1D_data_kernel(
5556
for (auto i = threadIdx.x; i < segment_length; i += blockDim.x) {
5657
permuted_indices[output_start + i] = indices[input_start + i];
5758
if (has_weight) {
58-
permuted_weights[output_start + i] = weights[input_start + i];
59+
for (int col = 0; col < weights_columns; ++col) {
60+
permuted_weights[(output_start + i) * weights_columns + col] =
61+
weights[(input_start + i) * weights_columns + col];
62+
}
5963
}
6064
}
6165
}
@@ -138,9 +142,14 @@ permute_1D_sparse_data_cuda(
138142
using indices_t = scalar_t;
139143
if (weights.has_value()) {
140144
const Tensor weights_value = weights.value();
145+
int32_t weights_columns = 1;
146+
if (weights_value.dense_dim() > 1) {
147+
weights_columns = weights_value.size(1);
148+
}
141149
const auto weights_value_contig = weights_value.contiguous();
142-
permuted_weights =
143-
at::empty(permuted_indices_size, weights_value.options());
150+
permuted_weights = at::empty(
151+
{permuted_indices_size, weights_columns},
152+
weights_value.options());
144153
FBGEMM_DISPATCH_ALL_TYPES_AND_DOUBLE(
145154
weights_value.scalar_type(),
146155
"permute_1D_data_kernel_3",
@@ -164,7 +173,8 @@ permute_1D_sparse_data_cuda(
164173
input_offsets.data_ptr<offsets_t>(),
165174
output_offsets.data_ptr<offsets_t>(),
166175
permuted_indices.data_ptr<indices_t>(),
167-
permuted_weights.data_ptr<weights_t>());
176+
permuted_weights.data_ptr<weights_t>(),
177+
weights_columns);
168178
}); // for each weights_t
169179
} else {
170180
FBGEMM_LAUNCH_KERNEL(
@@ -185,7 +195,8 @@ permute_1D_sparse_data_cuda(
185195
input_offsets.data_ptr<offsets_t>(),
186196
output_offsets.data_ptr<offsets_t>(),
187197
permuted_indices.data_ptr<indices_t>(),
188-
nullptr);
198+
nullptr,
199+
0);
189200
}
190201
}); // for each indices_t
191202
}); // for each offsets_t

fbgemm_gpu/test/sparse/permute_sparse_features_test.py

Lines changed: 114 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import hypothesis.strategies as st
1717
import torch
1818

19-
from hypothesis import given, settings
19+
from hypothesis import example, given, Phase, settings, Verbosity
2020

2121
from .common import extend_test_class, open_source, permute_indices_ref_
2222

@@ -73,7 +73,12 @@ def permute_sparse_features_ref_(
7373
)
7474
@settings(max_examples=20, deadline=None)
7575
def test_permute_sparse_features(
76-
self, B: int, T: int, L: int, long_index: bool, has_weight: bool
76+
self,
77+
B: int,
78+
T: int,
79+
L: int,
80+
long_index: bool,
81+
has_weight: bool,
7782
) -> None:
7883
index_dtype = torch.int64 if long_index else torch.int32
7984
lengths = torch.randint(low=1, high=L, size=(T, B)).type(index_dtype)
@@ -193,6 +198,112 @@ def test_permute_sparse_features_with_repeats(
193198
assert permuted_weights_cpu is None
194199

195200

201+
class Permute1DSparseFeaturesTest(unittest.TestCase):
    """Accelerator tests for ``torch.ops.fbgemm.permute_1D_sparse_data``."""

    @unittest.skipIf(*gpu_unavailable)
    @given(
        B=st.integers(min_value=1, max_value=20),
        T=st.integers(min_value=1, max_value=20),
        L=st.integers(min_value=2, max_value=20),
        long_index=st.booleans(),
        has_weight=st.booleans(),
        weight_columns=st.integers(min_value=1, max_value=20),
    )
    @settings(
        max_examples=20,
        deadline=None,
    )
    def test_permute_1D_sparse_data(
        self,
        B: int,
        T: int,
        L: int,
        long_index: bool,
        has_weight: bool,
        weight_columns: int,
    ) -> None:
        """Permute ``T`` jagged segments and compare each one to the input.

        Args:
            B: Drawn by hypothesis but not read by this test.
            T: Number of segments, i.e. the length of ``lengths`` and of the
                permutation.
            L: Exclusive upper bound for each random segment length (the
                lower bound is 1).
            long_index: Use ``torch.int64`` for lengths/indices when true,
                ``torch.int32`` otherwise.
            has_weight: When true, a ``(sum(lengths), weight_columns)`` float
                weights tensor is passed to the op; otherwise ``None``.
            weight_columns: Size of the second dimension of the weights.
        """
        device = torch.accelerator.current_accelerator()
        index_dtype = torch.int64 if long_index else torch.int32

        # Random tensors are created in the order lengths -> weights ->
        # indices so the random streams are consumed in a fixed order.
        lengths = torch.randint(
            low=1,
            high=L,
            size=(T,),  # 1D
            device=device,
        ).type(index_dtype)
        total_length = int(lengths.sum().item())

        weights = None
        if has_weight:
            weights = torch.rand(
                total_length,
                weight_columns,
                device=device,
            ).float()

        indices = torch.randint(
            low=1,
            high=int(1e5),
            size=(total_length,),
            device=device,
        ).type(index_dtype)

        # Random permutation of the T segments.
        segment_order = list(range(T))
        random.shuffle(segment_order)
        permute = torch.IntTensor(segment_order).to(device=device)

        # NOTE(review): the last argument is presumably the expected size of
        # the permuted output - confirm against the operator schema.
        lengths_actual, values_actual, weights_actual = (
            torch.ops.fbgemm.permute_1D_sparse_data(
                permute, lengths, indices, weights, indices.numel()
            )
        )

        expected_lengths = torch.index_select(lengths, dim=0, index=permute)
        self.assertTrue(torch.equal(lengths_actual, expected_lengths))

        # Exclusive prefix sums: cumulative_indices[t] is where segment t
        # starts in the un-permuted input.
        leading_zero = torch.zeros((1,), dtype=index_dtype, device=lengths.device)
        cumulative_indices = torch.cumsum(
            torch.cat((leading_zero, lengths)),
            dim=0,
        )

        out_begin = 0
        for i in range(T):
            source_segment = permute[i]
            segment_length = lengths[source_segment]
            in_begin = cumulative_indices[source_segment]
            in_end = segment_length + in_begin
            out_end = out_begin + segment_length

            # Output segment i must be a verbatim copy of the input segment
            # selected by the permutation.
            self.assertTrue(
                torch.equal(
                    values_actual[out_begin:out_end],
                    indices[in_begin:in_end],
                )
            )
            if has_weight and weights is not None:
                # Slicing along dim 0 compares whole rows of the 2D weights.
                self.assertTrue(
                    torch.equal(
                        weights_actual[out_begin:out_end],
                        weights[in_begin:in_end],
                    )
                )
            else:
                assert weights_actual is None
            out_begin = out_end
305+
306+
196307
class Permute2DSparseFeaturesTest(unittest.TestCase):
197308
@unittest.skipIf(*gpu_unavailable)
198309
def test_permute_2D_sparse_data(self) -> None:
@@ -234,6 +345,7 @@ def test_permute_2D_sparse_data(self) -> None:
234345

235346

236347
extend_test_class(PermuteSparseFeaturesTest)
348+
# extend_test_class(Permute1DSparseFeaturesTest)
237349

238350
if __name__ == "__main__":
239351
unittest.main()

0 commit comments

Comments
 (0)