@@ -579,7 +579,7 @@ def random_matrix():
579579
580580
581581@pytest .mark .parametrize ("dtype" , _numeric_types )
582- def test_matmul_largish (dtype , random_matrix ):
582+ def test_matmul_largish_square (dtype , random_matrix ):
583583 q = get_queue_or_skip ()
584584 skip_if_dtype_not_supported (dtype , q )
585585
@@ -598,6 +598,7 @@ def test_matmul_largish(dtype, random_matrix):
598598 assert dpt .allclose (x1 , x2 , atol = tol , rtol = tol )
599599 assert dpt .allclose (x1 , dpt .asarray (x_np ), atol = tol , rtol = tol )
600600
601+ # check stided input
601602 m_np = m_np [:- 1 , :- 1 ]
602603 x_np = np .matmul (m_np .T , m_np )
603604
@@ -610,6 +611,40 @@ def test_matmul_largish(dtype, random_matrix):
610611 assert dpt .allclose (x1 , dpt .asarray (x_np ), atol = tol , rtol = tol )
611612
612613
614+ @pytest .mark .parametrize ("dtype" , _numeric_types )
615+ def test_matmul_largish_rect (dtype , random_matrix ):
616+ q = get_queue_or_skip ()
617+ skip_if_dtype_not_supported (dtype , q )
618+
619+ m_np = random_matrix .astype (dtype )[:, :- 1 ]
620+ x_np = np .matmul (m_np .T [:- 2 , :], m_np )
621+
622+ m = dpt .asarray (m_np )
623+ mmT = m .mT [:- 2 , :]
624+ mT = dpt .asarray (mmT , copy = True , order = "C" )
625+ x1 = dpt .matmul (mmT , m )
626+ x2 = dpt .matmul (mT , m )
627+
628+ tol = 0
629+ if dpt .isdtype (x2 .dtype , ("real floating" , "complex floating" )):
630+ tol = 32 * dpt .finfo (x2 .dtype ).eps
631+
632+ assert dpt .allclose (x1 , x2 , atol = tol , rtol = tol )
633+ assert dpt .allclose (x1 , dpt .asarray (x_np ), atol = tol , rtol = tol )
634+
635+ m_np = m_np [:- 1 , :- 1 ]
636+ x_np = np .matmul (m_np .T [:- 2 , :], m_np )
637+
638+ m = m [:- 1 , :- 1 ]
639+ mmT = m .mT [:- 2 , :]
640+ mT = dpt .asarray (mmT , copy = True , order = "C" )
641+ x1 = dpt .matmul (mmT , m )
642+ x2 = dpt .matmul (mT , m )
643+
644+ assert dpt .allclose (x1 , x2 , atol = tol , rtol = tol )
645+ assert dpt .allclose (x1 , dpt .asarray (x_np ), atol = tol , rtol = tol )
646+
647+
613648@pytest .mark .parametrize ("dtype" , _numeric_types )
614649def test_tensordot_outer (dtype ):
615650 q = get_queue_or_skip ()
0 commit comments