File tree Expand file tree Collapse file tree 1 file changed +6
-4
lines changed
dpctl/tensor/libtensor/include/kernels/linalg_functions Expand file tree Collapse file tree 1 file changed +6
-4
lines changed Original file line number Diff line number Diff line change @@ -1490,8 +1490,8 @@ template <typename resT> struct GemmBatchFunctorThreadNM_vecm_HyperParameters
14901490template <typename T>
14911491struct GemmBatchFunctorThreadNM_vecm_HyperParameters <std::complex <T>>
14921492{
1493- static constexpr std::uint32_t wi_delta_n = 4 ;
1494- static constexpr std::uint32_t wi_delta_m_vecs = 4 ;
1493+ static constexpr std::uint32_t wi_delta_n = 2 ;
1494+ static constexpr std::uint32_t wi_delta_m_vecs = 2 ;
14951495 static constexpr std::uint32_t m_vec_size = 1 ;
14961496};
14971497
@@ -1527,7 +1527,7 @@ get_wg_delta_m_and_wi_delta_k(const size_t slm_byte_size,
15271527 ? 64
15281528 : 32 * static_cast <std::uint32_t >(slm_max_rows / 32 );
15291529
1530- if ( !wi_delta_k) {
1530+ for (std:: uint32_t it = 0 ; !wi_delta_k && (it < 4 ); ++it ) {
15311531 wg_delta_m /= 2 ;
15321532
15331533 const size_t slm_max_rows =
@@ -1539,7 +1539,9 @@ get_wg_delta_m_and_wi_delta_k(const size_t slm_byte_size,
15391539 ? 64
15401540 : ((slm_max_rows >= 32 )
15411541 ? 32
1542- : 16 * static_cast <std::uint32_t >(slm_max_rows / 16 ));
1542+ : (slm_max_rows >= 16 ? 16
1543+ : 8 * static_cast <std::uint32_t >(
1544+ slm_max_rows / 8 )));
15431545 }
15441546
15451547 if (!wi_delta_k) {
You can’t perform that action at this time.
0 commit comments