@@ -1480,19 +1480,74 @@ class GemmBatchFunctorThreadNM_vecm
14801480 }
14811481};
14821482
1483- template < typename resT> struct GemmBatchFunctorThreadNM_vecm_HyperParameters
1483+ struct GemmBatchFunctorThreadNM_vecm_HyperParameters
14841484{
1485- static constexpr std::uint32_t wi_delta_n = 4 ;
1486- static constexpr std::uint32_t wi_delta_m_vecs = 1 ;
1487- static constexpr std::uint32_t m_vec_size = 4 ;
1485+ private:
1486+ std::uint32_t wi_delta_n = 2 ;
1487+ std::uint32_t wi_delta_m_vecs = 4 ;
1488+ std::uint32_t m_vec_size = 1 ;
1489+
1490+ public:
1491+ constexpr GemmBatchFunctorThreadNM_vecm_HyperParameters ();
1492+ constexpr GemmBatchFunctorThreadNM_vecm_HyperParameters (
1493+ std::uint32_t wi_delta_n_,
1494+ std::uint32_t wi_delta_m_vecs_,
1495+ std::uint32_t m_vec_size_)
1496+ : wi_delta_n(wi_delta_n_), wi_delta_m_vecs(wi_delta_m_vecs_),
1497+ m_vec_size(m_vec_size_)
1498+ {
1499+ }
1500+
1501+ constexpr std::uint32_t get_wi_delta_n () const
1502+ {
1503+ return wi_delta_n;
1504+ }
1505+ constexpr std::uint32_t get_wi_delta_m_vecs () const
1506+ {
1507+ return wi_delta_m_vecs;
1508+ }
1509+ constexpr std::uint32_t get_m_vec_size () const
1510+ {
1511+ return m_vec_size;
1512+ }
14881513};
14891514
1490- template <typename T >
1491- struct GemmBatchFunctorThreadNM_vecm_HyperParameters <std:: complex <T>>
1515+ template <typename resT >
1516+ struct GemmBatchFunctorThreadNM_vecm_HyperParametersSelector
14921517{
1493- static constexpr std::uint32_t wi_delta_n = 2 ;
1494- static constexpr std::uint32_t wi_delta_m_vecs = 2 ;
1495- static constexpr std::uint32_t m_vec_size = 1 ;
1518+ constexpr GemmBatchFunctorThreadNM_vecm_HyperParametersSelector () {}
1519+
1520+ constexpr GemmBatchFunctorThreadNM_vecm_HyperParameters get () const
1521+ {
1522+ if constexpr (sizeof (resT) == 1 ) {
1523+ // 1 * 8 * 2 * 4 == 64
1524+ return GemmBatchFunctorThreadNM_vecm_HyperParameters (8 , 2 , 4 );
1525+ }
1526+ else if constexpr (sizeof (resT) == 2 ) {
1527+ // 2 * 4 * 2 * 4 == 64
1528+ return GemmBatchFunctorThreadNM_vecm_HyperParameters (4 , 2 , 4 );
1529+ }
1530+ else if constexpr (sizeof (resT) == 4 ) {
1531+ // 4 * 4 * 1 * 4 == 64
1532+ return GemmBatchFunctorThreadNM_vecm_HyperParameters (4 , 1 , 4 );
1533+ }
1534+ else if constexpr (sizeof (resT) == 8 ) {
1535+ // 8 * 2 * 1 * 4 == 64
1536+ if constexpr (std::is_same_v<resT, std::complex <float >>) {
1537+ return GemmBatchFunctorThreadNM_vecm_HyperParameters (2 , 4 , 1 );
1538+ }
1539+ else {
1540+ return GemmBatchFunctorThreadNM_vecm_HyperParameters (2 , 1 , 4 );
1541+ }
1542+ }
1543+ else if constexpr (std::is_same_v<resT, std::complex <double >>) {
1544+ // 16 * 2 * 2 * 1 == 64
1545+ return GemmBatchFunctorThreadNM_vecm_HyperParameters (2 , 2 , 1 );
1546+ }
1547+ else {
1548+ return GemmBatchFunctorThreadNM_vecm_HyperParameters (2 , 2 , 1 );
1549+ }
1550+ }
14961551};
14971552
14981553template <typename T1,
@@ -1572,11 +1627,14 @@ sycl::event _gemm_batch_new_nm_impl(sycl::queue &exec_q,
15721627 const ResIndexerT &res_indexer,
15731628 std::vector<sycl::event> const &depends)
15741629{
1575- using parametersT = GemmBatchFunctorThreadNM_vecm_HyperParameters<resTy>;
1630+ constexpr GemmBatchFunctorThreadNM_vecm_HyperParametersSelector<resTy>
1631+ selector{};
1632+ constexpr auto hyper_params = selector.get ();
15761633
1577- constexpr std::uint32_t wi_delta_n = parametersT::wi_delta_n;
1578- constexpr std::uint32_t wi_delta_m_vecs = parametersT::wi_delta_m_vecs;
1579- constexpr std::uint32_t m_vec_size = parametersT::m_vec_size;
1634+ constexpr std::uint32_t wi_delta_n = hyper_params.get_wi_delta_n ();
1635+ constexpr std::uint32_t wi_delta_m_vecs =
1636+ hyper_params.get_wi_delta_m_vecs ();
1637+ constexpr std::uint32_t m_vec_size = hyper_params.get_m_vec_size ();
15801638
15811639 constexpr std::uint32_t wi_total_delta_m = wi_delta_m_vecs * m_vec_size;
15821640
@@ -3078,7 +3136,7 @@ gemm_batch_new_nm_impl(sycl::queue &exec_q,
30783136 sycl::event gemm_ev = gemm_detail::_gemm_batch_new_nm_impl<
30793137 lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT,
30803138 OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>(
3081- exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, m, k , batch_indexer,
3139+ exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m , batch_indexer,
30823140 lhs_indexer, rhs_indexer, res_indexer, depends);
30833141
30843142 return gemm_ev;
@@ -3643,41 +3701,67 @@ sycl::event gemm_new_nm_impl(sycl::queue &exec_q,
36433701 sycl::event gemm_ev = gemm_detail::_gemm_batch_new_nm_impl<
36443702 lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT,
36453703 OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>(
3646- exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, m, k ,
3704+ exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m ,
36473705 batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends);
36483706
36493707 return gemm_ev;
36503708}
36513709
36523710template <typename lhsTy, typename rhsTy, typename resTy>
36533711sycl::event
3654- gemm_new_nm_contig_impl (sycl::queue &exec_q,
3655- const lhsTy *lhs_tp,
3656- const rhsTy *rhs_tp,
3657- resTy *res_tp,
3658- size_t n,
3659- size_t k,
3660- size_t m,
3661- std::vector<sycl::event> const &depends = {})
3712+ gemm_batch_new_nm_contig_impl (sycl::queue &exec_q,
3713+ const lhsTy *lhs_tp,
3714+ const rhsTy *rhs_tp,
3715+ resTy *res_tp,
3716+ const size_t batch_nelems,
3717+ const size_t n,
3718+ const size_t k,
3719+ const size_t m,
3720+ std::vector<sycl::event> const &depends = {})
36623721{
36633722 using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
36643723 constexpr OuterInnerDimsIndexerT lhs_indexer{};
36653724 constexpr OuterInnerDimsIndexerT rhs_indexer{};
36663725 constexpr OuterInnerDimsIndexerT res_indexer{};
36673726
3668- using BatchDimsIndexerT =
3669- dpctl::tensor::offset_utils::ThreeZeroOffsets_Indexer;
3670- constexpr BatchDimsIndexerT batch_indexer{};
3671-
36723727 constexpr size_t single_batch_nelems = 1 ;
36733728
3674- sycl::event gemm_ev = gemm_detail::_gemm_batch_new_nm_impl<
3675- lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT,
3676- OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>(
3677- exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, m, k,
3678- batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends);
3729+ if (batch_nelems == single_batch_nelems) {
3730+ using BatchDimsIndexerT =
3731+ dpctl::tensor::offset_utils::ThreeZeroOffsets_Indexer;
3732+ constexpr BatchDimsIndexerT batch_indexer{};
36793733
3680- return gemm_ev;
3734+ sycl::event gemm_ev = gemm_detail::_gemm_batch_new_nm_impl<
3735+ lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT,
3736+ OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>(
3737+ exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m,
3738+ batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends);
3739+
3740+ return gemm_ev;
3741+ }
3742+ else {
3743+ using dpctl::tensor::offset_utils::Strided1DIndexer;
3744+ using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer;
3745+ using BatchDimsIndexerT =
3746+ ThreeOffsets_CombinedIndexer<Strided1DIndexer, Strided1DIndexer,
3747+ Strided1DIndexer>;
3748+
3749+ using dpctl::tensor::offset_utils::Strided1DIndexer;
3750+
3751+ const ssize_t ss_batch_nelems = static_cast <ssize_t >(batch_nelems);
3752+ const BatchDimsIndexerT batch_indexer (
3753+ Strided1DIndexer{0 , ss_batch_nelems, static_cast <ssize_t >(n * k)},
3754+ Strided1DIndexer{0 , ss_batch_nelems, static_cast <ssize_t >(k * m)},
3755+ Strided1DIndexer{0 , ss_batch_nelems, static_cast <ssize_t >(n * m)});
3756+
3757+ sycl::event gemm_ev = gemm_detail::_gemm_batch_new_nm_impl<
3758+ lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT,
3759+ OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>(
3760+ exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m,
3761+ batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends);
3762+
3763+ return gemm_ev;
3764+ }
36813765}
36823766
36833767template <typename lhsTy, typename rhsTy, typename resTy>
@@ -3705,8 +3789,8 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q,
37053789 const size_t max_nm = std::max (n, m);
37063790
37073791 if (min_nm > 0 && (max_nm >= ((64 * 1024 ) / min_nm))) {
3708- return gemm_new_nm_contig_impl <lhsTy, rhsTy, resTy>(
3709- exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends);
3792+ return gemm_batch_new_nm_contig_impl <lhsTy, rhsTy, resTy>(
3793+ exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends);
37103794 }
37113795
37123796 if (k == 0 ) {
@@ -4518,8 +4602,10 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q,
45184602 const size_t max_nm = std::max (n, m);
45194603
45204604 if (min_nm > 0 && (max_nm >= ((64 * 1024 ) / min_nm))) {
4521- return gemm_new_nm_contig_impl<lhsTy, rhsTy, resTy>(
4522- exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends);
4605+ constexpr size_t single_batch_nelems = 1 ;
4606+ return gemm_batch_new_nm_contig_impl<lhsTy, rhsTy, resTy>(
4607+ exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m,
4608+ depends);
45234609 }
45244610
45254611 if (k == 0 ) {
0 commit comments