-
Notifications
You must be signed in to change notification settings - Fork 270
Refactor vector type to reduce build times #3641
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c1a5bba
8b307fd
c3e573d
b731dc1
13a5177
f518830
cc0f42f
cffc2f5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,9 +34,48 @@ using f4_t = unsigned _BitInt(4); | |
| using f6_t = _BitInt(6); // e2m3 format | ||
| using bf6_t = unsigned _BitInt(6); // e3m2 format | ||
|
|
||
| // scalar_type | ||
| template <typename TV> | ||
| struct scalar_type; | ||
| // native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t, | ||
| // native types: bool | ||
| template <typename T> | ||
| inline constexpr bool is_native_type() | ||
| { | ||
| return is_same_v<T, double> || is_same_v<T, float> || is_same_v<T, half_t> || | ||
| is_same_v<T, bhalf_t> || is_same_v<T, int32_t> || is_same_v<T, uint32_t> || | ||
| is_same_v<T, int8_t> || is_same_v<T, uint8_t> || is_same_v<T, _BitInt(8)> || | ||
| is_same_v<T, unsigned _BitInt(8)> || is_same_v<T, bool>; | ||
| } | ||
|
|
||
| /** | ||
| * @brief Wrapper for native vector type | ||
| * @tparam T The element type of the vector | ||
| * @tparam Rank The number of elements in the vector | ||
| */ | ||
| template <typename T, index_t Rank> | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. using this attribute requires T to be a builtin. maybe it's possible to use some concepts here, conditionally if they are available
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, once we have ability to use concepts, we can make some requirements about attribute support. For now we assume Clang supports it. |
||
| using NativeVectorT = T __attribute__((ext_vector_type(Rank))); | ||
|
|
||
| /** | ||
| * @brief Mapping of incoming type to local native vector storage type and vector size | ||
| * @tparam T Incoming data type | ||
| */ | ||
| template <typename T> | ||
| struct scalar_type | ||
| { | ||
| // Basic data type mapping to unsigned _BitInt of appropriate size | ||
| using type = unsigned _BitInt(8 * sizeof(T)); | ||
| static constexpr index_t vector_size = 1; | ||
| }; | ||
|
|
||
| /** | ||
| * @brief scalar_type trait override for NativeVectorT | ||
| * @tparam T The vector type | ||
| * @tparam Rank The number of elements in the vector | ||
| */ | ||
| template <typename T, index_t Rank> | ||
| struct scalar_type<NativeVectorT<T, Rank>> | ||
| { | ||
| using type = T; | ||
| static constexpr index_t vector_size = Rank; | ||
| }; | ||
|
|
||
| struct f4x2_pk_t | ||
| { | ||
|
|
@@ -74,6 +113,39 @@ struct f4x2_pk_t | |
| } | ||
| }; | ||
|
|
||
| // TODO: Unfortunately, we cannot partially specialize scalar_type for vectors written | ||
| // in the following way: | ||
| // template<typename T, index_t Rank> | ||
| // struct scalar_type<T __attribute__((__vector_size__(sizeof(T) * Rank)))> | ||
| // { | ||
| // using type = T; | ||
| // static constexpr index_t vector_size = Rank; | ||
| // }; | ||
| // The compiler errors out with "partial specialization is not allowed for this type", | ||
| // claiming that the Rank is not a deducible parameter. This might be a compiler bug. | ||
| // Note the above type is classified differently from the NativeVectorT<T, Rank> alias, | ||
| // even though they are functionally equivalent and are trivially constructibe from each other. | ||
| // This is unfortunate, but we have to work around it because some LLVM builtins for some | ||
| // operations (e.g., mma) may return the former type. | ||
| // For now we have to explicitly specialize for each vector size we need. These are used | ||
| // in f6_pk_t below. | ||
|
|
||
| /// @brief scalar_type trait override for uint32_t vector of size 3 | ||
| template <> | ||
| struct scalar_type<uint32_t __attribute__((__vector_size__(sizeof(uint32_t) * 3)))> | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what's the difference between
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Compiler treats them separately. I think I've discovered some compiler bugs that I'll report later. You can play with those here: https://godbolt.org/z/v87KEfx7d
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Each one is trivially constructible from the other, however some of the llvm __builtins return types use the vector_size notation while we mostly use the ext_vector_type due to the ability to alias them properly.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's an interesting observation! I liked this doc regarding vector_size https://gcc.gnu.org/onlinedocs/gcc/Vector-Extensions.html, it explains some details. In godbolt I tried to experiment by defining vectors as: template<typename T, index_t Rank> And this way we can deduce template argument and test is passing. I'm not sure if it brings us back to the previous design though XD Regarding the NativeVectorT2, I would agree with compiler, we allocate a vector of floats 4 byte long, so effectively it is a vector of 1 float, so the test is failing. |
||
| { | ||
| using type = uint32_t; | ||
| static constexpr index_t vector_size = 3; | ||
| }; | ||
|
|
||
| /// @brief scalar_type trait override for uint32_t vector of size 6 | ||
| template <> | ||
| struct scalar_type<uint32_t __attribute__((__vector_size__(sizeof(uint32_t) * 6)))> | ||
| { | ||
| using type = uint32_t; | ||
| static constexpr index_t vector_size = 6; | ||
| }; | ||
|
|
||
| template <typename BitType, index_t pk_size> | ||
| struct f6_pk_t | ||
| { | ||
|
|
@@ -89,28 +161,48 @@ struct f6_pk_t | |
| static constexpr index_t vector_size = | ||
| (packed_size * num_bits_elem) / num_bits_vec_elem; // 3 or 6 element_type units | ||
|
|
||
| using storage_type = element_type __attribute__((ext_vector_type(vector_size))); | ||
| using storage_type = NativeVectorT<element_type, vector_size>; | ||
| storage_type data_{storage_type(0)}; // packed data | ||
|
|
||
| using type = f6_pk_t<BitType, packed_size>; | ||
|
|
||
| /** This class may trivially constructed by the following vector type alias | ||
| * for example from a result of an mma operation. This is primarily for internal use. | ||
| * @note f6x16_pk_t and f6x32_pk_t storage types, may be trivially constructed from | ||
| * uint32_t vectors of size 3 and 6 respectively for example from mma operation results. | ||
| * Unfortunately, unsigned int __attribute__((ext_vector_type(6))) a.k.a | ||
| * NativeVectorT<uint32_t, 6> is NOT the same as __attribute__((__vector_size__(6 * | ||
| * sizeof(unsigned int)))) unsigned int which is returned from the mma ops despite being | ||
| * functionally equivalent. This class may be trivially constructed from both, so we can steer | ||
| * the templated ctor below to only consider incoming vectors types other than our two storage | ||
| * types of interest. | ||
| */ | ||
| using storage_type_alias = | ||
| element_type __attribute__((__vector_size__(sizeof(element_type) * vector_size))); | ||
|
|
||
| __host__ __device__ constexpr f6_pk_t() {} | ||
| __host__ __device__ constexpr f6_pk_t(const storage_type& init) : data_{init} | ||
| { | ||
| // TODO: consider removing initialization similar to vector_type<T, 256> | ||
| } | ||
|
|
||
| // Initialize from a vector type with the same size as packed_size | ||
| template <typename T, typename = enable_if_t<scalar_type<T>::vector_size == packed_size>> | ||
| // Initialize from a vector type with the same size as packed_size. | ||
| // Exclude storage_type and storage_type_alias because these are trivially constructible. | ||
| template < | ||
| typename T, | ||
| typename = enable_if_t<!is_same_v<T, storage_type> && !is_same_v<T, storage_type_alias> && | ||
| scalar_type<T>::vector_size == packed_size>> | ||
| __host__ __device__ f6_pk_t(const T& v) | ||
| { | ||
| static_assert(scalar_type<T>::vector_size == packed_size, | ||
| "Input vector size must match packed_size."); | ||
| static_for<0, packed_size, 1>{}( | ||
| [&](auto i) { pack(v[static_cast<index_t>(i)], static_cast<index_t>(i)); }); | ||
| } | ||
|
|
||
| // Broadcast single initialization value to all packed elements | ||
| __host__ __device__ f6_pk_t(const int8_t v) | ||
| : f6_pk_t(static_cast<int8_t __attribute__((ext_vector_type(packed_size)))>(v)) | ||
| : f6_pk_t(static_cast<NativeVectorT<int8_t, packed_size>>(v)) | ||
| { | ||
| // TODO: consider removing initialization similar to vector_type<T, 256> | ||
| } | ||
|
|
@@ -191,27 +283,6 @@ struct pk_i4_t | |
| __host__ __device__ constexpr pk_i4_t(type init) : data{init} {} | ||
| }; | ||
|
|
||
| inline constexpr auto next_pow2(uint32_t x) | ||
| { | ||
| // Precondition: x > 1. | ||
| return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x; | ||
| } | ||
|
|
||
| // native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t, | ||
| // native types: bool | ||
| template <typename T> | ||
| inline constexpr bool is_native_type() | ||
| { | ||
| return is_same<T, double>::value || is_same<T, float>::value || is_same<T, half_t>::value || | ||
| is_same<T, bhalf_t>::value || is_same<T, int32_t>::value || | ||
| is_same<T, uint32_t>::value || is_same<T, int8_t>::value || is_same<T, uint8_t>::value || | ||
| is_same_v<T, _BitInt(8)> || is_same_v<T, unsigned _BitInt(8)> || is_same<T, bool>::value; | ||
| } | ||
|
|
||
| // scalar_type | ||
| template <typename TV> | ||
| struct scalar_type; | ||
|
|
||
| // is_scalar_type | ||
| template <typename TV> | ||
| struct is_scalar_type | ||
|
|
@@ -224,14 +295,13 @@ template <typename X, typename Y> | |
| using has_same_scalar_type = is_same<typename scalar_type<remove_cvref_t<X>>::type, | ||
| typename scalar_type<remove_cvref_t<Y>>::type>; | ||
|
|
||
| template <typename T, index_t N> | ||
| struct scalar_type<T __attribute__((ext_vector_type(N)))> | ||
| template <> | ||
| struct scalar_type<bool> | ||
| { | ||
| using type = T; | ||
| static constexpr index_t vector_size = N; | ||
| using type = bool; | ||
| static constexpr index_t vector_size = 1; | ||
| }; | ||
|
|
||
| // | ||
| template <> | ||
| struct scalar_type<double> | ||
| { | ||
|
|
@@ -293,86 +363,79 @@ struct scalar_type<int4_t> | |
| template <> | ||
| struct scalar_type<pk_i4_t> | ||
| { | ||
| using type = pk_i4_t; | ||
| using type = typename pk_i4_t::type; | ||
| static constexpr index_t vector_size = 1; | ||
| }; | ||
|
|
||
| template <> | ||
| struct scalar_type<f8_fnuz_t> | ||
| { | ||
| using type = f8_fnuz_t::data_type; | ||
| using type = typename f8_fnuz_t::data_type; | ||
| static constexpr index_t vector_size = 1; | ||
| }; | ||
|
|
||
| template <> | ||
| struct scalar_type<bf8_fnuz_t> | ||
| { | ||
| using type = bf8_fnuz_t::data_type; | ||
| using type = typename bf8_fnuz_t::data_type; | ||
| static constexpr index_t vector_size = 1; | ||
| }; | ||
|
|
||
| template <> | ||
| struct scalar_type<f8_ocp_t> | ||
| { | ||
| using type = f8_ocp_t::data_type; | ||
| using type = typename f8_ocp_t::data_type; | ||
| static constexpr index_t vector_size = 1; | ||
| }; | ||
|
|
||
| template <> | ||
| struct scalar_type<bf8_ocp_t> | ||
| { | ||
| using type = bf8_ocp_t::data_type; | ||
| using type = typename bf8_ocp_t::data_type; | ||
| static constexpr index_t vector_size = 1; | ||
| }; | ||
|
|
||
| #ifndef CK_CODE_GEN_RTC | ||
| template <> | ||
| struct scalar_type<e8m0_bexp_t> | ||
| { | ||
| using type = e8m0_bexp_t::type; | ||
| using type = typename e8m0_bexp_t::type; | ||
| static constexpr index_t vector_size = 1; | ||
| }; | ||
| #endif | ||
|
|
||
| template <> | ||
| struct scalar_type<f4x2_pk_t> | ||
| { | ||
| using type = f4x2_pk_t::type; | ||
| using type = typename f4x2_pk_t::type; | ||
| static constexpr index_t vector_size = 1; | ||
| }; | ||
|
|
||
| template <> | ||
| struct scalar_type<f6x32_pk_t> | ||
| { | ||
| using type = f6x32_pk_t::storage_type; | ||
| using type = typename f6x32_pk_t::storage_type; | ||
| static constexpr index_t vector_size = 1; | ||
| }; | ||
|
|
||
| template <> | ||
| struct scalar_type<bf6x32_pk_t> | ||
| { | ||
| using type = bf6x32_pk_t::storage_type; | ||
| using type = typename bf6x32_pk_t::storage_type; | ||
| static constexpr index_t vector_size = 1; | ||
| }; | ||
|
|
||
| template <> | ||
| struct scalar_type<f6x16_pk_t> | ||
| { | ||
| using type = f6x16_pk_t::storage_type; | ||
| using type = typename f6x16_pk_t::storage_type; | ||
| static constexpr index_t vector_size = 1; | ||
| }; | ||
|
|
||
| template <> | ||
| struct scalar_type<bf6x16_pk_t> | ||
| { | ||
| using type = bf6x16_pk_t::storage_type; | ||
| static constexpr index_t vector_size = 1; | ||
| }; | ||
|
|
||
| template <> | ||
| struct scalar_type<bool> | ||
| { | ||
| using type = bool; | ||
| using type = typename bf6x16_pk_t::storage_type; | ||
| static constexpr index_t vector_size = 1; | ||
| }; | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's the difference to std::is_arithmetic_v?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
std::is_arithmetic_v == false for _Float16 and _BitInt
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was original code, didn't want to change it :)