@@ -422,25 +422,30 @@ def assert_fill(
422422
423423
424424def _real_float_strict_equals (out : Array , expected : Array ) -> bool :
425- assert hasattr (_xp , "signbit" ) # sanity check
426-
427425 nan_mask = xp .isnan (out )
428426 if not xp .all (nan_mask == xp .isnan (expected )):
429427 return False
428+ ignore_mask = nan_mask
429+
430+ # Test sign of zeroes if xp.signbit() available, otherwise ignore as it's
431+ # not that big of a deal for the perf costs.
432+ if api_version >= "2023.12" and hasattr (_xp , "signbit" ):
433+ out_zero_mask = out == 0
434+ out_sign_mask = xp .signbit (out )
435+ out_pos_zero_mask = out_zero_mask & out_sign_mask
436+ out_neg_zero_mask = out_zero_mask & ~ out_sign_mask
437+ expected_zero_mask = expected == 0
438+ expected_sign_mask = xp .signbit (expected )
439+ expected_pos_zero_mask = expected_zero_mask & expected_sign_mask
440+ expected_neg_zero_mask = expected_zero_mask & ~ expected_sign_mask
441+ pos_zero_match = out_pos_zero_mask == expected_pos_zero_mask
442+ neg_zero_match = out_neg_zero_mask == expected_neg_zero_mask
443+ if not (xp .all (pos_zero_match ) and xp .all (neg_zero_match )):
444+ return False
445+ ignore_mask |= out_zero_mask
430446
431- out_zero_mask = out == 0
432- out_sign_mask = xp .signbit (out )
433- out_pos_zero_mask = out_zero_mask & out_sign_mask
434- out_neg_zero_mask = out_zero_mask & ~ out_sign_mask
435- expected_zero_mask = expected == 0
436- expected_sign_mask = xp .signbit (expected )
437- expected_pos_zero_mask = expected_zero_mask & expected_sign_mask
438- expected_neg_zero_mask = expected_zero_mask & ~ expected_sign_mask
439- if not (xp .all (out_pos_zero_mask == expected_pos_zero_mask ) and xp .all (out_neg_zero_mask == expected_neg_zero_mask )):
440- return False
441-
442- ignore_mask = nan_mask | out_zero_mask
443447 replacement = xp .asarray (42 , dtype = out .dtype ) # i.e. an arbitrary non-zero value that equals itself
448+ assert replacement == replacement # sanity check
444449 match = xp .where (ignore_mask , replacement , out ) == xp .where (ignore_mask , replacement , expected )
445450 return xp .all (match )
446451
@@ -486,10 +491,10 @@ def assert_array_elements(
486491 f_func = f"[{ func_name } ({ fmt_kw (kw )} )]"
487492
488493 # First we try short-circuit for a successful assertion by using vectorised checks.
489- if out .dtype in dh .real_float_dtypes and api_version >= "2023.12" :
494+ if out .dtype in dh .real_float_dtypes :
490495 if _real_float_strict_equals (out , expected ):
491496 return
492- elif out .dtype in dh .complex_dtypes and api_version >= "2023.12" :
497+ elif out .dtype in dh .complex_dtypes :
493498 real_match = _real_float_strict_equals (out .real , expected .real )
494499 imag_match = _real_float_strict_equals (out .imag , expected .imag )
495500 if real_match and imag_match :
0 commit comments