
Commit 9585a54

royren622 authored and facebook-github-bot committed
accelerate permute_1D_data_kernel
Summary: Accelerate permute_1D_data_kernel using vectorization, a ~3x speedup (18.20 ms -> 6.1616 ms; see benchmark below).

Reviewed By: arsatis

Differential Revision: D86492866
1 parent 922ca39 commit 9585a54

3 files changed: +379, -13 lines

3 files changed

+379
-13
lines changed

fbgemm_gpu/bench/sparse_ops_benchmark.py

Lines changed: 105 additions & 0 deletions
@@ -993,6 +993,111 @@ def ben(fn, name, ad_indices, ad_lengths, batch_offsets, num_ads_in_batch):
     ben(pass_4, "pass_4", ad_indices, ad_lengths, batch_offsets, num_ads_in_batch)
 
 
+@cli.command()
+@click.option("--num-segments", default=100)
+@click.option("--max-segment-length", default=10000)
+@click.option(
+    "--index-dtype", type=click.Choice(["int", "int64", "float"]), default="float"
+)
+@click.option("--has-weight", is_flag=True, default=False)
+@click.option("--device", type=click.Choice(["cpu", "cuda"]), default="cuda")
+def permute_1d_sparse_data_bench(
+    num_segments: int,
+    max_segment_length: int,
+    index_dtype: str,
+    has_weight: bool,
+    device: str,
+) -> None:
+    """Benchmark permute_1D_sparse_data operator.
+
+    This operator permutes sparse features (indices and optional weights) according
+    to a given permutation. Commonly used in recommendation systems to reorder
+    embedding tables.
+    """
+    if index_dtype == "int":
+        index_dtype = torch.int32
+    elif index_dtype == "int64":
+        index_dtype = torch.int64
+    elif index_dtype == "float":
+        index_dtype = torch.float32
+    else:
+        raise RuntimeError(f"Does not support data type {index_dtype}")
+
+    # Generate variable-length segments to test vectorization
+    emb_dim = 256
+    lengths = (
+        torch.randint(
+            low=max_segment_length // 2,
+            high=max_segment_length,
+            size=(num_segments,),
+            dtype=torch.int32,
+            device=device,
+        )
+        * emb_dim
+    )
+    total_indices = int(lengths.sum().item())
+    # Generate indices
+    if index_dtype == torch.float32:
+        indices = torch.rand(total_indices, dtype=index_dtype, device=device)
+    else:
+        indices = torch.randint(
+            low=0,
+            high=2**31 - 1,
+            size=(total_indices,),
+            dtype=index_dtype,
+            device=device,
+        )
+
+    # Generate optional weights
+    weights = (
+        torch.rand(total_indices, dtype=torch.float32, device=device)
+        if has_weight
+        else None
+    )
+    # Generate random permutation
+    permute_list = list(range(num_segments))
+    random.shuffle(permute_list)
+    permute = torch.IntTensor(permute_list).to(device)
+    # Benchmark the operation
+    time, (permuted_lengths, permuted_indices, permuted_weights) = (
+        benchmark_torch_function(
+            torch.ops.fbgemm.permute_1D_sparse_data,
+            (permute, lengths, indices, weights, None),
+            num_warmups=100,
+            iters=1000,
+        )
+    )
+
+    # Calculate memory bandwidth
+    num_bytes = (
+        permute.numel() * permute.element_size()
+        + lengths.numel() * lengths.element_size()
+        + indices.numel() * indices.element_size()
+        + permuted_lengths.numel() * permuted_lengths.element_size()
+        + permuted_indices.numel() * permuted_indices.element_size()
+    )
+    if has_weight:
+        assert weights is not None
+        assert permuted_weights is not None
+        num_bytes += (
+            weights.numel() * weights.element_size()  # pyre-ignore [16]
+            + permuted_weights.numel() * permuted_weights.element_size()
+        )
+
+    logging.info(
+        f"permute_1D_sparse_data_bench ("
+        f"num_segments={num_segments}, "
+        f"max_segment_length={max_segment_length}, "
+        f"total_indices={total_indices}, "
+        f"dtype={index_dtype}, "
+        f"with_weights={has_weight}, "
+        f"device={device})"
+    )
+    logging.info(
+        f"fbgemm_gpu time: {time * 1000:.5f} ms ({num_bytes / time / 1e9:.5f} GB/s)"
+    )
+
+
 @cli.command()
 @click.option("--row-size", default=2560000)
 @click.option("--batch-size", default=2048)

fbgemm_gpu/src/sparse_ops/sparse_permute_1d.cu

Lines changed: 116 additions & 7 deletions
@@ -61,6 +61,115 @@ __global__ __launch_bounds__(kMaxThreads) void permute_1D_data_kernel(
   }
 }
 
+// Vectorized kernel for permuting the indices and weights. Used for permutation
+// of sparse data. Uses vec4 loads for improved memory bandwidth.
+template <
+    bool has_weight,
+    typename offsets_t,
+    typename indices_t,
+    typename weights_t>
+__global__ __launch_bounds__(kMaxThreads) void permute_1D_data_kernel_vec(
+    int32_t permuted_indices_size,
+    int32_t permuted_lengths_size,
+    const indices_t* __restrict__ indices,
+    const weights_t* __restrict__ weights,
+    const int32_t* __restrict__ permute,
+    const offsets_t* __restrict__ input_offsets,
+    const offsets_t* __restrict__ output_offsets,
+    indices_t* __restrict__ permuted_indices,
+    weights_t* __restrict__ permuted_weights) {
+  // Select vector types based on element size (vec4 for 4× bandwidth)
+  using indices_vec4_t =
+      typename std::conditional<sizeof(indices_t) == 8, long4, float4>::type;
+  using weights_vec4_t =
+      typename std::conditional<sizeof(weights_t) == 8, long4, float4>::type;
+
+  const auto b_t_start = blockIdx.x * blockDim.y + threadIdx.y;
+  const auto stride = gridDim.x * blockDim.y;
+
+  for (int b_t = b_t_start; b_t < permuted_lengths_size; b_t += stride) {
+    // Read offsets once - use int32_t for segment_length as it fits in 32 bits
+    const offsets_t output_start = output_offsets[b_t];
+    const offsets_t output_end = (b_t == permuted_lengths_size - 1)
+        ? permuted_indices_size
+        : output_offsets[b_t + 1];
+    const int32_t segment_length =
+        static_cast<int32_t>(output_end - output_start);
+    const offsets_t input_start = input_offsets[permute[b_t]];
+
+    // Compute pointers
+    indices_t* __restrict__ indices_dst_ptr = permuted_indices + output_start;
+    const indices_t* __restrict__ indices_src_ptr = indices + input_start;
+    weights_t* __restrict__ weights_dst_ptr =
+        has_weight ? permuted_weights + output_start : nullptr;
+    const weights_t* __restrict__ weights_src_ptr =
+        has_weight ? weights + input_start : nullptr;
+
+    // Check alignment once per segment
+    const bool indices_vec4_aligned =
+        (sizeof(indices_t) == 4 || sizeof(indices_t) == 8) &&
+        (reinterpret_cast<uintptr_t>(indices_dst_ptr) &
+         (alignof(indices_vec4_t) - 1)) == 0 &&
+        (reinterpret_cast<uintptr_t>(indices_src_ptr) &
+         (alignof(indices_vec4_t) - 1)) == 0;
+
+    const bool weights_vec4_aligned = !has_weight ||
+        ((reinterpret_cast<uintptr_t>(weights_dst_ptr) &
+          (alignof(weights_vec4_t) - 1)) == 0 &&
+         (reinterpret_cast<uintptr_t>(weights_src_ptr) &
+          (alignof(weights_vec4_t) - 1)) == 0);
+
+    if (indices_vec4_aligned && weights_vec4_aligned) {
+      // Vectorized path - process both indices and weights together
+      const int32_t vec4_count = segment_length / 4;
+      const int32_t remainder = segment_length & 3; // segment_length % 4
+
+      auto indices_dst = reinterpret_cast<indices_vec4_t*>(indices_dst_ptr);
+      auto indices_src =
+          reinterpret_cast<const indices_vec4_t*>(indices_src_ptr);
+
+      if (has_weight) {
+        auto weights_dst = reinterpret_cast<weights_vec4_t*>(weights_dst_ptr);
+        auto weights_src =
+            reinterpret_cast<const weights_vec4_t*>(weights_src_ptr);
+
+        // copy both indices and weights
+#pragma unroll
+        for (auto i = threadIdx.x; i < vec4_count; i += blockDim.x) {
+          indices_dst[i] = indices_src[i];
+          weights_dst[i] = weights_src[i];
+        }
+        // Handle remainder elements (0-3 elements)
+        if (threadIdx.x < remainder) {
+          const auto offset = vec4_count * 4 + threadIdx.x;
+          indices_dst_ptr[offset] = indices_src_ptr[offset];
+          weights_dst_ptr[offset] = weights_src_ptr[offset];
+        }
+      } else {
+        // copy only indices
+#pragma unroll
+        for (auto i = threadIdx.x; i < vec4_count; i += blockDim.x) {
+          indices_dst[i] = indices_src[i];
+        }
+
+        // Handle remainder elements (0-3 elements)
+        if (threadIdx.x < remainder) {
+          const auto offset = vec4_count * 4 + threadIdx.x;
+          indices_dst_ptr[offset] = indices_src_ptr[offset];
+        }
+      }
+    } else {
+      // Scalar fallback path
+      for (auto i = threadIdx.x; i < segment_length; i += blockDim.x) {
+        indices_dst_ptr[i] = indices_src_ptr[i];
+        if (has_weight) {
+          weights_dst_ptr[i] = weights_src_ptr[i];
+        }
+      }
+    }
+  }
+}
+
 DLL_PUBLIC std::tuple<Tensor, Tensor, std::optional<Tensor>>
 permute_1D_sparse_data_cuda(
     const Tensor& permute,
@@ -124,17 +233,17 @@ permute_1D_sparse_data_cuda(
     permuted_indices_size = output_offsets[-1].item<int64_t>();
   }
 
-  constexpr int32_t BT_blocks = 32;
-  dim3 threads_2(32, BT_blocks);
+  constexpr int32_t BT_blocks = 16;
+  dim3 threads_2(64, BT_blocks);
   const auto blocks_2 =
       cuda_calc_xblock_count(permuted_lengths_size, BT_blocks);
   permuted_indices = at::empty(permuted_indices_size, indices.options());
 
   AT_DISPATCH_INDEX_TYPES(
-      input_offsets.scalar_type(), "permute_1D_data_kernel_1", [&] {
+      input_offsets.scalar_type(), "permute_1D_data_kernel_vec_1", [&] {
         using offsets_t = index_t;
         FBGEMM_DISPATCH_ALL_TYPES(
-            indices.scalar_type(), "permute_1D_data_kernel_2", [&] {
+            indices.scalar_type(), "permute_1D_data_kernel_vec_2", [&] {
              using indices_t = scalar_t;
              if (weights.has_value()) {
                const Tensor weights_value = weights.value();
@@ -143,11 +252,11 @@ permute_1D_sparse_data_cuda(
                     at::empty(permuted_indices_size, weights_value.options());
                 FBGEMM_DISPATCH_ALL_TYPES_AND_DOUBLE(
                     weights_value.scalar_type(),
-                    "permute_1D_data_kernel_3",
+                    "permute_1D_data_kernel_vec_3",
                     [&] {
                       using weights_t = scalar_t;
                       FBGEMM_LAUNCH_KERNEL(
-                          (permute_1D_data_kernel<
+                          (permute_1D_data_kernel_vec<
                               true,
                               offsets_t,
                               indices_t,
@@ -168,7 +277,7 @@ permute_1D_sparse_data_cuda(
                     }); // for each weights_t
               } else {
                 FBGEMM_LAUNCH_KERNEL(
-                    (permute_1D_data_kernel<
+                    (permute_1D_data_kernel_vec<
                         false,
                         offsets_t,
                         indices_t,
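
The core trick above is a standard CUDA vectorization pattern: when source and destination pointers are 16-byte aligned, 4- and 8-byte elements can be reinterpreted as float4/long4 so that each thread moves 16 bytes per load/store, with a scalar tail for the last segment_length % 4 elements and a scalar fallback for misaligned segments. A minimal self-contained sketch of that pattern (not from this commit; a plain grid-stride copy rather than the per-segment permutation above):

    #include <cstdint>
    #include <cuda_runtime.h>

    __global__ void copy_f32_vec4(
        const float* __restrict__ src, float* __restrict__ dst, int n) {
      const int tid = blockIdx.x * blockDim.x + threadIdx.x;
      const int stride = gridDim.x * blockDim.x;
      // One alignment check for the whole array; alignof(float4) == 16.
      const bool aligned =
          (reinterpret_cast<uintptr_t>(src) & (alignof(float4) - 1)) == 0 &&
          (reinterpret_cast<uintptr_t>(dst) & (alignof(float4) - 1)) == 0;
      if (aligned) {
        const int vec4_count = n / 4;
        const auto* src4 = reinterpret_cast<const float4*>(src);
        auto* dst4 = reinterpret_cast<float4*>(dst);
        // Vectorized path: each iteration moves 16 bytes per thread.
        for (int i = tid; i < vec4_count; i += stride) {
          dst4[i] = src4[i];
        }
        // Scalar tail: the last n % 4 elements.
        for (int i = vec4_count * 4 + tid; i < n; i += stride) {
          dst[i] = src[i];
        }
      } else {
        // Scalar fallback when either pointer is misaligned.
        for (int i = tid; i < n; i += stride) {
          dst[i] = src[i];
        }
      }
    }

The kernel in this commit performs the alignment check per segment rather than once, since each permuted segment starts at its own offset and may or may not be 16-byte aligned; the widened thread group (dim3 threads_2(64, BT_blocks)) presumably pairs with this, giving each segment more lanes for the vectorized copy.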
