
Commit b2569b4

zjing14 authored and meta-codesync[bot] committed
Update fbgemm fp8 conv heuristic (#5118)
Summary:
Pull Request resolved: #5118
X-link: https://github.com/facebookresearch/FBGEMM/pull/2124

- Update fp8 conv heuristic for D86440061

Reviewed By: jwfromm
Differential Revision: D86558446
fbshipit-source-id: 6b8b4fff190b2f98181ac018576bdbbbe940d256
1 parent bfa83ec · commit b2569b4

File tree: 5 files changed, +172 −32 lines


fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_conv.cu

Lines changed: 52 additions & 32 deletions
@@ -38,8 +38,24 @@ struct ProblemSize {
   std::vector<int64_t> dilation;
   bool operator==(const ProblemSize& ps) const {
     return activation_shape == ps.activation_shape &&
-        filter_shape == ps.filter_shape && padding == ps.padding &&
-        stride == ps.stride && dilation == ps.dilation;
+        filter_shape == ps.filter_shape;
+  }
+  void print() const {
+    // clang-format off
+    std::cout << "actv: " // [N, D, H, W, C]
+              << activation_shape[0] << ","
+              << activation_shape[1] << ","
+              << activation_shape[2] << ","
+              << activation_shape[3] << ","
+              << activation_shape[4] << ","
+              << "filter: " // [K, T, R, S, C]
+              << filter_shape[0] << ","
+              << filter_shape[1] << ","
+              << filter_shape[2] << ","
+              << filter_shape[3] << ","
+              << filter_shape[4] << ","
+              << std::endl;
+    // clang-format on
   }
 };

@@ -59,42 +75,43 @@ struct ProblemSizeHash {
     };
     hash_combine(seed, vec_hash(ps.activation_shape));
     hash_combine(seed, vec_hash(ps.filter_shape));
-    hash_combine(seed, vec_hash(ps.padding));
-    hash_combine(seed, vec_hash(ps.stride));
-    hash_combine(seed, vec_hash(ps.dilation));
+    // hash_combine(seed, vec_hash(ps.padding));
+    // hash_combine(seed, vec_hash(ps.stride));
+    // hash_combine(seed, vec_hash(ps.dilation));
     return seed;
   }
 };

 // clang-format off
 std::unordered_map<ProblemSize, Kernel_f8f8bf16_conv, ProblemSizeHash> kernel_map = {
-    {{{1,6,32,48,48}, {48,1,1,1,48}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x128x128_1x1x1},
-    {{{1,3,34,50,48}, {1024,3,3,3,48}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x1x1},
-    {{{1,3,34,50,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x256x128_2x1x1},
-    {{{1,3,66,98,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x1x1},
-    {{{1,3,130,194,1024}, {512,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
-    {{{1,3,130,194,512}, {512,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x512x128_2x2x1},
-    {{{1,1,128,192,1024}, {512,1,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x256x128_2x1x1},
-    {{{1,3,258,386,512}, {256,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
-    {{{1,3,258,386,256}, {256,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
-    {{{1,1,256,384,512}, {256,1,1,1,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
-    //{{{1,3,258,386,256}, {12,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_invalid},
-    {{{1,3,32,48,1024}, {2048,3,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x256x128_2x2x1},
-    {{{1,4,66,98,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x512x128_2x2x1},
-    {{{1,4,64,96,1024}, {2048,3,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
-    {{{1,6,130,194,1024}, {512,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
-    {{{1,6,130,194,512}, {512,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
-    {{{1,4,128,192,1024}, {512,1,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x256x128_2x1x1},
-    {{{1,6,258,386,512}, {256,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
-    {{{1,6,258,386,256}, {256,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
-    {{{1,4,256,384,512}, {256,1,1,1,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
-    //{{{1,6,258,386,256}, {12,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_invalid},
-    {{{1,1,64,96,1024}, {1024,1,3,3,1024}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
-    {{{1,1,128,192,1024}, {1024,1,3,3,1024}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
-    {{{1,1,256,384,512}, {512,1,3,3,512}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
-    {{{2,1,64,96,1024}, {1024,1,3,3,1024}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
-    {{{4,1,128,192,1024}, {1024,1,3,3,1024}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
-    {{{4,1,256,384,512}, {512,1,3,3,512}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1}
+    {{{1,1,192,128,1024}, {512,1,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x256x128_2x1x1},
+    {{{1,1,192,128,160}, {320,1,1,1,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
+    {{{1,1,384,256,512}, {256,1,1,1,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
+    {{{1,1,96,64,320}, {640,1,1,1,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x256x128_1x2x1},
+    {{{1,3,194,130,1024}, {512,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
+    {{{1,3,194,130,160}, {320,3,3,3,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
+    {{{1,3,194,130,320}, {320,3,3,3,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x128x128_4x1x1},
+    {{{1,3,194,130,512}, {512,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
+    {{{1,3,386,258,160}, {160,3,3,3,160}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
+    {{{1,3,386,258,256}, {256,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
+    {{{1,3,386,258,512}, {256,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
+    {{{1,3,48,32,1024}, {2048,3,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x128x128_4x1x1},
+    {{{1,3,50,34,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x256x128_2x1x1},
+    {{{1,3,50,34,48}, {1024,3,3,3,48}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x1x1},
+    {{{1,3,50,34,640}, {640,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x128x128_4x1x1},
+    {{{1,3,50,34,640}, {96,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x2x1},
+    {{{1,3,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x1x1},
+    {{{1,3,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_4x1x1},
+    {{{1,3,98,66,320}, {640,3,3,3,320}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
+    {{{1,3,98,66,640}, {640,3,3,3,640}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x256x128_2x1x1},
+    {{{1,4,192,128,1024}, {512,1,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_128x256x128_2x1x1},
+    {{{1,4,384,256,512}, {256,1,1,1,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
+    {{{1,4,96,64,1024}, {2048,3,1,1,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
+    {{{1,4,98,66,1024}, {1024,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_512x256x128_4x1x1},
+    {{{1,6,194,130,1024}, {512,3,3,3,1024}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
+    {{{1,6,194,130,512}, {512,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
+    {{{1,6,386,258,256}, {256,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
+    {{{1,6,386,258,512}, {256,3,3,3,512}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_256x256x128_2x1x1},
 };
 // clang-format on

@@ -114,6 +131,9 @@ Kernel_f8f8bf16_conv get_kernel_via_heuristic(
   auto it = kernel_map.find(ps);
   if (it != kernel_map.end()) {
     return it->second;
+  } else {
+    std::cout << "warning: not found";
+    ps.print();
   }
   // Fallback kernel
   return f8f8bf16_conv_256x256x128_2x1x1;
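Note on the hunks above: the heuristic's cache key is now the activation and filter shapes only; padding, stride, and dilation were dropped from both operator== and the hash (the commented-out hash_combine calls keep the excluded fields visible for easy reinstatement), and a miss is now logged via the new print() before falling back. A minimal, self-contained sketch of the resulting lookup behavior; the Kernel alias, the 0x9e3779b9 combiner constant, and main() are illustrative assumptions, not FBGEMM's exact code:

#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-in for Kernel_f8f8bf16_conv (a function pointer in the real code).
using Kernel = std::string;

struct ProblemKey {
  std::vector<int64_t> activation_shape; // [N, D, H, W, C]
  std::vector<int64_t> filter_shape; // [K, T, R, S, C]
  bool operator==(const ProblemKey& o) const {
    // Padding/stride/dilation intentionally ignored, as in the diff.
    return activation_shape == o.activation_shape &&
        filter_shape == o.filter_shape;
  }
};

struct ProblemKeyHash {
  static void hash_combine(std::size_t& seed, std::size_t h) {
    // Boost-style combiner, matching the hash_combine idiom in the diff.
    seed ^= h + 0x9e3779b9 + (seed << 6) + (seed >> 2);
  }
  std::size_t operator()(const ProblemKey& k) const {
    auto vec_hash = [](const std::vector<int64_t>& v) {
      std::size_t h = 0;
      for (int64_t x : v) {
        hash_combine(h, std::hash<int64_t>{}(x));
      }
      return h;
    };
    std::size_t seed = 0;
    hash_combine(seed, vec_hash(k.activation_shape));
    hash_combine(seed, vec_hash(k.filter_shape));
    return seed;
  }
};

int main() {
  std::unordered_map<ProblemKey, Kernel, ProblemKeyHash> kernel_map = {
      {{{1, 1, 192, 128, 1024}, {512, 1, 1, 1, 1024}},
       "f8f8bf16_conv_128x256x128_2x1x1"},
  };
  // Any padding/stride/dilation combination now maps to the same entry.
  ProblemKey ps{{1, 1, 192, 128, 1024}, {512, 1, 1, 1, 1024}};
  auto it = kernel_map.find(ps);
  if (it != kernel_map.end()) {
    std::cout << it->second << std::endl;
  } else {
    std::cout << "warning: not found, using fallback" << std::endl;
  }
  return 0;
}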
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "f8f8bf16_conv_common.cuh"
10+
11+
namespace fbgemm_gpu {
12+
13+
at::Tensor f8f8bf16_conv_128x256x128_1x2x1(
14+
at::Tensor activation, // FP8 - NDHWC layout
15+
at::Tensor filter, // FP8 - KTRSC layout
16+
at::Tensor scale,
17+
std::vector<int64_t> padding, // [pad_d, pad_h, pad_w]
18+
std::vector<int64_t> stride, // [stride_d, stride_h, stride_w]
19+
std::vector<int64_t> dilation) { // [dilation_d, dilation_h, dilation_w]
20+
21+
return f8f8bf16_conv_impl<
22+
128,
23+
128,
24+
128,
25+
1,
26+
2,
27+
1,
28+
cutlass::conv::KernelImplicitTmaWarpSpecialized1SmSm100>(
29+
activation, filter, scale, padding, stride, dilation);
30+
}
31+
32+
} // namespace fbgemm_gpu
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "f8f8bf16_conv_common.cuh"
10+
11+
namespace fbgemm_gpu {
12+
13+
at::Tensor f8f8bf16_conv_256x128x128_4x1x1(
14+
at::Tensor activation, // FP8 - NDHWC layout
15+
at::Tensor filter, // FP8 - KTRSC layout
16+
at::Tensor scale,
17+
std::vector<int64_t> padding, // [pad_d, pad_h, pad_w]
18+
std::vector<int64_t> stride, // [stride_d, stride_h, stride_w]
19+
std::vector<int64_t> dilation) { // [dilation_d, dilation_h, dilation_w]
20+
21+
return f8f8bf16_conv_impl<
22+
128,
23+
128,
24+
128,
25+
4,
26+
1,
27+
1,
28+
cutlass::conv::KernelImplicitTmaWarpSpecialized2SmSm100>(
29+
activation, filter, scale, padding, stride, dilation);
30+
}
31+
32+
} // namespace fbgemm_gpu
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "f8f8bf16_conv_common.cuh"
10+
11+
namespace fbgemm_gpu {
12+
13+
at::Tensor f8f8bf16_conv_256x256x128_4x2x1(
14+
at::Tensor activation, // FP8 - NDHWC layout
15+
at::Tensor filter, // FP8 - KTRSC layout
16+
at::Tensor scale,
17+
std::vector<int64_t> padding, // [pad_d, pad_h, pad_w]
18+
std::vector<int64_t> stride, // [stride_d, stride_h, stride_w]
19+
std::vector<int64_t> dilation) { // [dilation_d, dilation_h, dilation_w]
20+
21+
return f8f8bf16_conv_impl<
22+
128,
23+
128,
24+
128,
25+
4,
26+
2,
27+
1,
28+
cutlass::conv::KernelImplicitTmaWarpSpecialized2SmSm100>(
29+
activation, filter, scale, padding, stride, dilation);
30+
}
31+
32+
} // namespace fbgemm_gpu

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_conv/f8f8bf16_conv_manifest.cuh

Lines changed: 24 additions & 0 deletions
@@ -26,6 +26,22 @@ at::Tensor f8f8bf16_conv_128x128x128_1x1x1(
     std::vector<int64_t> stride, // [stride_d, stride_h, stride_w]
     std::vector<int64_t> dilation);

+at::Tensor f8f8bf16_conv_128x256x128_1x2x1(
+    at::Tensor activation, // FP8 - NDHWC layout
+    at::Tensor filter, // FP8 - KTRSC layout
+    at::Tensor scale,
+    std::vector<int64_t> padding, // [pad_d, pad_h, pad_w]
+    std::vector<int64_t> stride, // [stride_d, stride_h, stride_w]
+    std::vector<int64_t> dilation);
+
+at::Tensor f8f8bf16_conv_256x128x128_4x1x1(
+    at::Tensor activation, // FP8 - NDHWC layout
+    at::Tensor filter, // FP8 - KTRSC layout
+    at::Tensor scale,
+    std::vector<int64_t> padding, // [pad_d, pad_h, pad_w]
+    std::vector<int64_t> stride, // [stride_d, stride_h, stride_w]
+    std::vector<int64_t> dilation);
+
 at::Tensor f8f8bf16_conv_128x256x128_2x1x1(
     at::Tensor activation, // FP8 - NDHWC layout
     at::Tensor filter, // FP8 - KTRSC layout

@@ -58,6 +74,14 @@ at::Tensor f8f8bf16_conv_256x256x128_4x1x1(
     std::vector<int64_t> stride, // [stride_d, stride_h, stride_w]
     std::vector<int64_t> dilation);

+at::Tensor f8f8bf16_conv_256x256x128_4x2x1(
+    at::Tensor activation, // FP8 - NDHWC layout
+    at::Tensor filter, // FP8 - KTRSC layout
+    at::Tensor scale,
+    std::vector<int64_t> padding, // [pad_d, pad_h, pad_w]
+    std::vector<int64_t> stride, // [stride_d, stride_h, stride_w]
+    std::vector<int64_t> dilation);
+
 at::Tensor f8f8bf16_conv_256x512x128_2x2x1(
     at::Tensor activation, // FP8 - NDHWC layout
     at::Tensor filter, // FP8 - KTRSC layout
