@@ -38,8 +38,24 @@ struct ProblemSize {
3838 std::vector<int64_t > dilation;
3939 bool operator ==(const ProblemSize& ps) const {
4040 return activation_shape == ps.activation_shape &&
41- filter_shape == ps.filter_shape && padding == ps.padding &&
42- stride == ps.stride && dilation == ps.dilation ;
41+ filter_shape == ps.filter_shape ;
42+ }
43+ void print () const {
44+ // clang-format off
45+ std::cout << " actv: " // [N, D, H, W, C]
46+ << activation_shape[0 ] << " ,"
47+ << activation_shape[1 ] << " ,"
48+ << activation_shape[2 ] << " ,"
49+ << activation_shape[3 ] << " ,"
50+ << activation_shape[4 ] << " ,"
51+ << " filter: " // [K, T, R, S, C]
52+ << filter_shape[0 ] << " ,"
53+ << filter_shape[1 ] << " ,"
54+ << filter_shape[2 ] << " ,"
55+ << filter_shape[3 ] << " ,"
56+ << filter_shape[4 ] << " ,"
57+ << std::endl;
58+ // clang-format on
4359 }
4460};
4561
@@ -59,42 +75,43 @@ struct ProblemSizeHash {
5975 };
6076 hash_combine (seed, vec_hash (ps.activation_shape ));
6177 hash_combine (seed, vec_hash (ps.filter_shape ));
62- hash_combine (seed, vec_hash (ps.padding ));
63- hash_combine (seed, vec_hash (ps.stride ));
64- hash_combine (seed, vec_hash (ps.dilation ));
78+ // hash_combine(seed, vec_hash(ps.padding));
79+ // hash_combine(seed, vec_hash(ps.stride));
80+ // hash_combine(seed, vec_hash(ps.dilation));
6581 return seed;
6682 }
6783};
6884
6985// clang-format off
7086std::unordered_map<ProblemSize, Kernel_f8f8bf16_conv, ProblemSizeHash> kernel_map = {
71- {{{1 ,6 ,32 ,48 ,48 }, {48 ,1 ,1 ,1 ,48 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_128x128x128_1x1x1},
72- {{{1 ,3 ,34 ,50 ,48 }, {1024 ,3 ,3 ,3 ,48 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x256x128_4x1x1},
73- {{{1 ,3 ,34 ,50 ,1024 }, {1024 ,3 ,3 ,3 ,1024 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_128x256x128_2x1x1},
74- {{{1 ,3 ,66 ,98 ,1024 }, {1024 ,3 ,3 ,3 ,1024 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x256x128_4x1x1},
75- {{{1 ,3 ,130 ,194 ,1024 }, {512 ,3 ,3 ,3 ,1024 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_512x256x128_4x1x1},
76- {{{1 ,3 ,130 ,194 ,512 }, {512 ,3 ,3 ,3 ,512 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x512x128_2x2x1},
77- {{{1 ,1 ,128 ,192 ,1024 }, {512 ,1 ,1 ,1 ,1024 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_128x256x128_2x1x1},
78- {{{1 ,3 ,258 ,386 ,512 }, {256 ,3 ,3 ,3 ,512 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_512x256x128_4x1x1},
79- {{{1 ,3 ,258 ,386 ,256 }, {256 ,3 ,3 ,3 ,256 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x256x128_2x1x1},
80- {{{1 ,1 ,256 ,384 ,512 }, {256 ,1 ,1 ,1 ,512 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x256x128_2x1x1},
81- // {{{1,3,258,386,256}, {12,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_invalid},
82- {{{1 ,3 ,32 ,48 ,1024 }, {2048 ,3 ,1 ,1 ,1024 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_128x256x128_2x2x1},
83- {{{1 ,4 ,66 ,98 ,1024 }, {1024 ,3 ,3 ,3 ,1024 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x512x128_2x2x1},
84- {{{1 ,4 ,64 ,96 ,1024 }, {2048 ,3 ,1 ,1 ,1024 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x256x128_2x1x1},
85- {{{1 ,6 ,130 ,194 ,1024 }, {512 ,3 ,3 ,3 ,1024 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x256x128_2x1x1},
86- {{{1 ,6 ,130 ,194 ,512 }, {512 ,3 ,3 ,3 ,512 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x256x128_2x1x1},
87- {{{1 ,4 ,128 ,192 ,1024 }, {512 ,1 ,1 ,1 ,1024 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_128x256x128_2x1x1},
88- {{{1 ,6 ,258 ,386 ,512 }, {256 ,3 ,3 ,3 ,512 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x256x128_2x1x1},
89- {{{1 ,6 ,258 ,386 ,256 }, {256 ,3 ,3 ,3 ,256 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x256x128_2x1x1},
90- {{{1 ,4 ,256 ,384 ,512 }, {256 ,1 ,1 ,1 ,512 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_512x256x128_4x1x1},
91- // {{{1,6,258,386,256}, {12,3,3,3,256}, {0, 0, 0}, {1, 1, 1}, {1, 1, 1}}, f8f8bf16_conv_invalid},
92- {{{1 ,1 ,64 ,96 ,1024 }, {1024 ,1 ,3 ,3 ,1024 }, {1 , 1 , 1 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x256x128_2x1x1},
93- {{{1 ,1 ,128 ,192 ,1024 }, {1024 ,1 ,3 ,3 ,1024 }, {1 , 1 , 1 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x256x128_2x1x1},
94- {{{1 ,1 ,256 ,384 ,512 }, {512 ,1 ,3 ,3 ,512 }, {1 , 1 , 1 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x256x128_2x1x1},
95- {{{2 ,1 ,64 ,96 ,1024 }, {1024 ,1 ,3 ,3 ,1024 }, {1 , 1 , 1 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x256x128_2x1x1},
96- {{{4 ,1 ,128 ,192 ,1024 }, {1024 ,1 ,3 ,3 ,1024 }, {1 , 1 , 1 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x256x128_2x1x1},
97- {{{4 ,1 ,256 ,384 ,512 }, {512 ,1 ,3 ,3 ,512 }, {1 , 1 , 1 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x256x128_2x1x1}
87+ {{{1 ,1 ,192 ,128 ,1024 }, {512 ,1 ,1 ,1 ,1024 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_128x256x128_2x1x1},
88+ {{{1 ,1 ,192 ,128 ,160 }, {320 ,1 ,1 ,1 ,160 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x256x128_2x1x1},
89+ {{{1 ,1 ,384 ,256 ,512 }, {256 ,1 ,1 ,1 ,512 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_512x256x128_4x1x1},
90+ {{{1 ,1 ,96 ,64 ,320 }, {640 ,1 ,1 ,1 ,320 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_128x256x128_1x2x1},
91+ {{{1 ,3 ,194 ,130 ,1024 }, {512 ,3 ,3 ,3 ,1024 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_512x256x128_4x1x1},
92+ {{{1 ,3 ,194 ,130 ,160 }, {320 ,3 ,3 ,3 ,160 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_512x256x128_4x1x1},
93+ {{{1 ,3 ,194 ,130 ,320 }, {320 ,3 ,3 ,3 ,320 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x128x128_4x1x1},
94+ {{{1 ,3 ,194 ,130 ,512 }, {512 ,3 ,3 ,3 ,512 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_512x256x128_4x1x1},
95+ {{{1 ,3 ,386 ,258 ,160 }, {160 ,3 ,3 ,3 ,160 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_512x256x128_4x1x1},
96+ {{{1 ,3 ,386 ,258 ,256 }, {256 ,3 ,3 ,3 ,256 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x256x128_2x1x1},
97+ {{{1 ,3 ,386 ,258 ,512 }, {256 ,3 ,3 ,3 ,512 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_512x256x128_4x1x1},
98+ {{{1 ,3 ,48 ,32 ,1024 }, {2048 ,3 ,1 ,1 ,1024 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x128x128_4x1x1},
99+ {{{1 ,3 ,50 ,34 ,1024 }, {1024 ,3 ,3 ,3 ,1024 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_128x256x128_2x1x1},
100+ {{{1 ,3 ,50 ,34 ,48 }, {1024 ,3 ,3 ,3 ,48 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x256x128_4x1x1},
101+ {{{1 ,3 ,50 ,34 ,640 }, {640 ,3 ,3 ,3 ,640 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x128x128_4x1x1},
102+ {{{1 ,3 ,50 ,34 ,640 }, {96 ,3 ,3 ,3 ,640 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x256x128_4x2x1},
103+ {{{1 ,3 ,98 ,66 ,1024 }, {1024 ,3 ,3 ,3 ,1024 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x256x128_4x1x1},
104+ {{{1 ,3 ,98 ,66 ,1024 }, {1024 ,3 ,3 ,3 ,1024 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x256x128_4x1x1},
105+ {{{1 ,3 ,98 ,66 ,320 }, {640 ,3 ,3 ,3 ,320 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x256x128_2x1x1},
106+ {{{1 ,3 ,98 ,66 ,640 }, {640 ,3 ,3 ,3 ,640 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_128x256x128_2x1x1},
107+ {{{1 ,4 ,192 ,128 ,1024 }, {512 ,1 ,1 ,1 ,1024 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_128x256x128_2x1x1},
108+ {{{1 ,4 ,384 ,256 ,512 }, {256 ,1 ,1 ,1 ,512 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_512x256x128_4x1x1},
109+ {{{1 ,4 ,96 ,64 ,1024 }, {2048 ,3 ,1 ,1 ,1024 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_512x256x128_4x1x1},
110+ {{{1 ,4 ,98 ,66 ,1024 }, {1024 ,3 ,3 ,3 ,1024 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_512x256x128_4x1x1},
111+ {{{1 ,6 ,194 ,130 ,1024 }, {512 ,3 ,3 ,3 ,1024 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x256x128_2x1x1},
112+ {{{1 ,6 ,194 ,130 ,512 }, {512 ,3 ,3 ,3 ,512 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x256x128_2x1x1},
113+ {{{1 ,6 ,386 ,258 ,256 }, {256 ,3 ,3 ,3 ,256 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x256x128_2x1x1},
114+ {{{1 ,6 ,386 ,258 ,512 }, {256 ,3 ,3 ,3 ,512 }, {0 , 0 , 0 }, {1 , 1 , 1 }, {1 , 1 , 1 }}, f8f8bf16_conv_256x256x128_2x1x1},
98115};
99116// clang-format on
100117
@@ -114,6 +131,9 @@ Kernel_f8f8bf16_conv get_kernel_via_heuristic(
114131 auto it = kernel_map.find (ps);
115132 if (it != kernel_map.end ()) {
116133 return it->second ;
134+ } else {
135+ std::cout << " warning: not found" ;
136+ ps.print ();
117137 }
118138 // Fallback kernel
119139 return f8f8bf16_conv_256x256x128_2x1x1;
0 commit comments