From ec56b2f3993813c33219585d4f6bc2eafbe85e1a Mon Sep 17 00:00:00 2001 From: Joseph Capriotti Date: Wed, 12 Nov 2025 10:01:02 -0700 Subject: [PATCH 1/2] add more templates for complex MulAdd and MulAddAssign traits --- src/lib.rs | 237 ++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 217 insertions(+), 20 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 661b67b..4fcd004 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -808,6 +808,67 @@ impl<'a, 'b, T: Clone + Num + MulAdd> MulAdd<&'b Complex> for &'a } } +// (a + i b) * (c + i 0) + (e + i f) == (a*c + e) + i (b*c + f) +impl> MulAdd> for Complex { + type Output = Complex; + + #[inline] + fn mul_add(self, other: T, add: Complex) -> Complex { + let re = self.re.mul_add(other.clone(), add.re); + let im = self.im.mul_add(other, add.im); + Complex::new(re, im) + } +} +impl<'a, 'b, T: Clone + Num + MulAdd> MulAdd<&'b T, &'b Complex> for &'a Complex { + type Output = Complex; + + #[inline] + fn mul_add(self, other: &T, add: &Complex) -> Complex { + self.clone().mul_add(other.clone(), add.clone()) + } +} + +// (a + i b) * (c + i d) + (e + i 0) == ((a*c + e) - b*d) + i (a*d + b*c) +impl> MulAdd, T> for Complex { + type Output = Complex; + + #[inline] + fn mul_add(self, other: Complex, add: T) -> Complex { + let re = self.re.clone().mul_add(other.re.clone(), add) + - (self.im.clone() * other.im.clone()); // FIXME: use mulsub when available in rust + let im = self.re.mul_add(other.im, self.im * other.re); + Complex::new(re, im) + } +} +impl<'a, 'b, T: Clone + Num + MulAdd> MulAdd<&'b Complex, &'b T> for &'a Complex { + type Output = Complex; + + #[inline] + fn mul_add(self, other: &Complex, add: &T) -> Complex { + self.clone().mul_add(other.clone(), add.clone()) + } +} + +// (a + i b) * (c + i 0) + (e + i 0) == (a*c + e) + i (b*c) +impl> MulAdd for Complex { + type Output = Complex; + + #[inline] + fn mul_add(self, other: T, add: T) -> Complex { + let re = self.re.mul_add(other.clone(), add); + let im = self.im * other; + Complex::new(re, im) + } +} +impl<'a, 'b, T: Clone + Num + MulAdd> MulAdd<&'b T, &'b T> for &'a Complex { + type Output = Complex; + + #[inline] + fn mul_add(self, other: &T, add: &T) -> Complex { + self.clone().mul_add(other.clone(), add.clone()) + } +} + forward_all_binop!(impl Div, div); // (a + i b) / (c + i d) == [(a + i b) * (c - i d)] / (c*c + d*d) @@ -902,6 +963,59 @@ mod opassign { } } + // (a + i b) * (c + i 0) + (e + i f) == (a*c + e) + i (b*c + f) + impl MulAddAssign> for Complex { + fn mul_add_assign(&mut self, other: T, add: Complex) { + self.re.mul_add_assign(other.clone(), add.re); // (a*c + e) + self.im.mul_add_assign(other, add.im); // (b*c + f) + } + } + + impl<'a, 'b, T: Clone + NumAssign + MulAddAssign> MulAddAssign<&'a T, &'b Complex> + for Complex + { + fn mul_add_assign(&mut self, other: &T, add: &Complex) { + self.mul_add_assign(other.clone(), add.clone()); + } + } + + // (a + i b) * (c + i d) + (e + i 0) == ((a*c + e) - b*d) + i (a*d + b*c) + impl MulAddAssign, T> for Complex { + fn mul_add_assign(&mut self, other: Complex, add: T) { + let a = self.re.clone(); + + self.re.mul_add_assign(other.re.clone(), add); // (a*c + e) + self.re -= self.im.clone() * other.im.clone(); // ((a*c + e) - b*d) + + self.im.mul_add_assign(other.re, a * other.im); // (b*c + a*d) + } + } + + impl<'a, 'b, T: Clone + NumAssign + MulAddAssign> MulAddAssign<&'a Complex, &'b T> + for Complex + { + fn mul_add_assign(&mut self, other: &Complex, add: &T) { + self.mul_add_assign(other.clone(), add.clone()); + } + } + + // (a + i b) * (c + i 0) + (e + i 0) == (a*c + e) + i (b*c) + impl MulAddAssign for Complex { + fn mul_add_assign(&mut self, other: T, add: T) { + + self.re.mul_add_assign(other.clone(), add); // (a*c + e) + self.im *= other; // b * c + } + } + + impl<'a, 'b, T: Clone + NumAssign + MulAddAssign> MulAddAssign<&'a T, &'b T> + for Complex + { + fn mul_add_assign(&mut self, other: &T, add: &T) { + self.mul_add_assign(other.clone(), add.clone()); + } + } + // (a + i b) / (c + i d) == [(a + i b) * (c - i d)] / (c*c + d*d) // == [(a*c + b*d) / (c*c + d*d)] + i [(b*c - a*d) / (c*c + d*d)] impl DivAssign for Complex { @@ -1672,6 +1786,15 @@ pub(crate) mod test { pub const _nan_neg1i: Complex64 = Complex::new(f64::NAN, -1.0); pub const _nan_nani: Complex64 = Complex::new(f64::NAN, f64::NAN); + + // Common integer contants + pub const _0_0i_i32: Complex = Complex { re: 0, im: 0 }; + pub const _1_0i_i32: Complex = Complex { re: 1, im: 0 }; + pub const _1_1i_i32: Complex = Complex { re: 1, im: 1 }; + pub const _0_1i_i32: Complex = Complex { re: 0, im: 1 }; + pub const _neg1_1i_i32: Complex = Complex { re: -1, im: 1 }; + pub const all_consts_i32: [Complex; 5] = [_0_0i_i32, _1_0i_i32, _1_1i_i32, _0_1i_i32, _neg1_1i_i32]; + #[test] fn test_consts() { // check our constants are what Complex::new creates @@ -2503,6 +2626,7 @@ pub(crate) mod test { mod complex_arithmetic { use super::{_05_05i, _0_0i, _0_1i, _1_0i, _1_1i, _4_2i, _neg1_1i, all_consts}; + use super::{_0_0i_i32, _0_1i_i32, _1_0i_i32, _1_1i_i32, _neg1_1i_i32, all_consts_i32}; use num_traits::{MulAdd, MulAddAssign, Zero}; #[test] @@ -2571,28 +2695,101 @@ pub(crate) mod test { } #[test] - fn test_mul_add() { - use super::Complex; - const _0_0i: Complex = Complex { re: 0, im: 0 }; - const _1_0i: Complex = Complex { re: 1, im: 0 }; - const _1_1i: Complex = Complex { re: 1, im: 1 }; - const _0_1i: Complex = Complex { re: 0, im: 1 }; - const _neg1_1i: Complex = Complex { re: -1, im: 1 }; - const all_consts: [Complex; 5] = [_0_0i, _1_0i, _1_1i, _0_1i, _neg1_1i]; - - assert_eq!(_1_0i.mul_add(_1_0i, _0_0i), _1_0i * _1_0i + _0_0i); - assert_eq!(_1_0i * _1_0i + _0_0i, _1_0i.mul_add(_1_0i, _0_0i)); - assert_eq!(_0_1i.mul_add(_0_1i, _0_1i), _neg1_1i); - assert_eq!(_1_0i.mul_add(_1_0i, _1_0i), _1_0i * _1_0i + _1_0i); - assert_eq!(_1_0i * _1_0i + _1_0i, _1_0i.mul_add(_1_0i, _1_0i)); + fn test_mul_add_complex_complex() { + assert_eq!(_1_0i_i32.mul_add(_1_0i_i32, _0_0i_i32), _1_0i_i32 * _1_0i_i32 + _0_0i_i32); + assert_eq!(_1_0i_i32 * _1_0i_i32 + _0_0i_i32, _1_0i_i32.mul_add(_1_0i_i32, _0_0i_i32)); + assert_eq!(_0_1i_i32.mul_add(_0_1i_i32, _0_1i_i32), _neg1_1i_i32); + assert_eq!(_1_0i_i32.mul_add(_1_0i_i32, _1_0i_i32), _1_0i_i32 * _1_0i_i32 + _1_0i_i32); + assert_eq!(_1_0i_i32 * _1_0i_i32 + _1_0i_i32, _1_0i_i32.mul_add(_1_0i_i32, _1_0i_i32)); + + let mut x = _1_0i_i32; + x.mul_add_assign(_1_0i_i32, _1_0i_i32); + assert_eq!(x, _1_0i_i32 * _1_0i_i32 + _1_0i_i32); + + for &a in &all_consts_i32 { + for &b in &all_consts_i32 { + for &c in &all_consts_i32 { + let abc = a * b + c; + assert_eq!(a.mul_add(b, c), abc); + let mut x = a; + x.mul_add_assign(b, c); + assert_eq!(x, abc); + } + } + } + } - let mut x = _1_0i; - x.mul_add_assign(_1_0i, _1_0i); - assert_eq!(x, _1_0i * _1_0i + _1_0i); + #[test] + fn test_mul_add_real_complex() { + const real_consts: [i32; 3] = [-1, 0, 1]; + + assert_eq!(_1_0i_i32.mul_add(1, _0_0i_i32), _1_0i_i32 * _1_0i_i32 + _0_0i_i32); + assert_eq!(_1_0i_i32 * _1_0i_i32 + _0_0i_i32, _1_0i_i32.mul_add(1, _0_0i_i32)); + assert_eq!(_1_0i_i32.mul_add(1, _1_0i_i32), _1_0i_i32 * _1_0i_i32 + _1_0i_i32); + assert_eq!(_1_0i_i32 * _1_0i_i32 + _1_0i_i32, _1_0i_i32.mul_add(1, _1_0i_i32)); + assert_eq!(_1_0i_i32.mul_add(0, _1_0i_i32), _1_0i_i32 * _0_0i_i32 + _1_0i_i32); + assert_eq!(_1_0i_i32 * _0_0i_i32 + _1_0i_i32, _1_0i_i32.mul_add(0, _1_0i_i32)); + + let mut x = _1_0i_i32; + x.mul_add_assign(1, _1_0i_i32); + assert_eq!(x, _1_0i_i32 * _1_0i_i32 + _1_0i_i32); + + for &a in &all_consts_i32 { + for &b in &[-1, 0, 1] { + for &c in &all_consts_i32 { + let abc = a * b + c; + assert_eq!(a.mul_add(b, c), abc); + let mut x = a; + x.mul_add_assign(b, c); + assert_eq!(x, abc); + } + } + } + } - for &a in &all_consts { - for &b in &all_consts { - for &c in &all_consts { + #[test] + fn test_mul_add_complex_real() { + assert_eq!(_1_0i_i32.mul_add(_1_0i_i32, 0), _1_0i_i32 * _1_0i_i32 + _0_0i_i32); + assert_eq!(_1_0i_i32 * _1_0i_i32 + _0_0i_i32, _1_0i_i32.mul_add(_1_0i_i32, 0)); + assert_eq!(_1_0i_i32.mul_add(_1_0i_i32, 1), _1_0i_i32 * _1_0i_i32 + _1_0i_i32); + assert_eq!(_1_0i_i32 * _1_0i_i32 + _1_0i_i32, _1_0i_i32.mul_add(_1_0i_i32, 1)); + assert_eq!(_1_0i_i32.mul_add(_0_0i_i32, 1), _1_0i_i32 * _0_0i_i32 + _1_0i_i32); + assert_eq!(_1_0i_i32 * _0_0i_i32 + _1_0i_i32, _1_0i_i32.mul_add(_0_0i_i32, 1)); + + let mut x = _1_0i_i32; + x.mul_add_assign(_1_0i_i32, 1); + assert_eq!(x, _1_0i_i32 * _1_0i_i32 + _1_0i_i32); + + for &a in &all_consts_i32 { + for &b in &all_consts_i32 { + for &c in &[-1, 0, 1] { + let abc = a * b + c; + assert_eq!(a.mul_add(b, c), abc); + let mut x = a; + x.mul_add_assign(b, c); + assert_eq!(x, abc); + } + } + } + } + + #[test] + fn test_mul_add_real_real() { + + assert_eq!(_1_0i_i32.mul_add(1, 0), _1_0i_i32 * _1_0i_i32 + _0_0i_i32); + assert_eq!(_1_0i_i32 * _1_0i_i32 + _0_0i_i32, _1_0i_i32.mul_add(1, 0)); + assert_eq!(_1_0i_i32.mul_add(1, 1), _1_0i_i32 * _1_0i_i32 + _1_0i_i32); + assert_eq!(_1_0i_i32 * _1_0i_i32 + _1_0i_i32, _1_0i_i32.mul_add(1, 1)); + assert_eq!(_1_0i_i32.mul_add(0, 1), _1_0i_i32 * _0_0i_i32 + _1_0i_i32); + assert_eq!(_1_0i_i32 * _0_0i_i32 + _1_0i_i32, _1_0i_i32.mul_add(0, 1)); + + let mut x = _1_0i_i32; + x.mul_add_assign(1, 1); + assert_eq!(x, _1_0i_i32 * _1_0i_i32 + _1_0i_i32); + + for &a in &all_consts_i32 { + for &b in &[-1, 0, 1] { + for &c in &[-1, 0, 1] { let abc = a * b + c; assert_eq!(a.mul_add(b, c), abc); let mut x = a; From 14f3f346b03b56e2eac1cff75d8fa5b7eb2bed6d Mon Sep 17 00:00:00 2001 From: Joseph Capriotti Date: Wed, 12 Nov 2025 10:10:42 -0700 Subject: [PATCH 2/2] add tests for floats --- src/lib.rs | 77 +++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 76 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 4fcd004..4c43983 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2670,7 +2670,7 @@ pub(crate) mod test { #[test] #[cfg(any(feature = "std", feature = "libm"))] - fn test_mul_add_float() { + fn test_mul_add_float_complex_complex() { assert_eq!(_05_05i.mul_add(_05_05i, _0_0i), _05_05i * _05_05i + _0_0i); assert_eq!(_05_05i * _05_05i + _0_0i, _05_05i.mul_add(_05_05i, _0_0i)); assert_eq!(_0_1i.mul_add(_0_1i, _0_1i), _neg1_1i); @@ -2694,6 +2694,81 @@ pub(crate) mod test { } } + #[test] + #[cfg(any(feature = "std", feature = "libm"))] + fn test_mul_add_float_real_complex() { + assert_eq!(_05_05i.mul_add(0.5, _0_0i), _05_05i * 0.5 + _0_0i); + assert_eq!(_05_05i * 0.5 + _0_0i, _05_05i.mul_add(0.5, _0_0i)); + assert_eq!(_1_0i.mul_add(1.0, _1_0i), _1_0i * 1.0 + _1_0i); + assert_eq!(_1_0i * 1.0 + _1_0i, _1_0i.mul_add(1.0, _1_0i)); + + let mut x = _1_0i; + x.mul_add_assign(1.0, _1_0i); + assert_eq!(x, _1_0i * 1.0 + _1_0i); + + for &a in &all_consts { + for &b in &[-1.0, -0.5, 0.0, 0.5, 1.0] { + for &c in &all_consts { + let abc = a * b + c; + assert_eq!(a.mul_add(b, c), abc); + let mut x = a; + x.mul_add_assign(b, c); + assert_eq!(x, abc); + } + } + } + } + + #[test] + #[cfg(any(feature = "std", feature = "libm"))] + fn test_mul_add_float_complex_real() { + assert_eq!(_05_05i.mul_add(_05_05i, 0.0), _05_05i * _05_05i + 0.0); + assert_eq!(_05_05i * _05_05i + 0.0, _05_05i.mul_add(_05_05i, 0.0)); + assert_eq!(_1_0i.mul_add(_1_0i, 1.0), _1_0i * _1_0i + 1.0); + assert_eq!(_1_0i * _1_0i + 1.0, _1_0i.mul_add(_1_0i, 1.0)); + + let mut x = _1_0i; + x.mul_add_assign(_1_0i, 1.0); + assert_eq!(x, _1_0i * _1_0i + 1.0); + + for &a in &all_consts { + for &b in &all_consts { + for &c in &[-1.0, -0.5, 0.0, 0.5, 1.0] { + let abc = a * b + c; + assert_eq!(a.mul_add(b, c), abc); + let mut x = a; + x.mul_add_assign(b, c); + assert_eq!(x, abc); + } + } + } + } + + #[test] + #[cfg(any(feature = "std", feature = "libm"))] + fn test_mul_add_float_real_real() { + assert_eq!(_05_05i.mul_add(0.5, 0.0), _05_05i * 0.5 + 0.0); + assert_eq!(_05_05i * 0.5 + 0.0, _05_05i.mul_add(0.5, 0.0)); + assert_eq!(_1_0i.mul_add(1.0, 1.0), _1_0i * 1.0 + 1.0); + assert_eq!(_1_0i * 1.0 + 1.0, _1_0i.mul_add(1.0, 1.0)); + + let mut x = _1_0i; + x.mul_add_assign(1.0, 1.0); + assert_eq!(x, _1_0i * 1.0 + 1.0); + + for &a in &all_consts { + for &b in &[-1.0, -0.5, 0.0, 0.5, 1.0] { + for &c in &[-1.0, -0.5, 0.0, 0.5, 1.0] { + let abc = a * b + c; + assert_eq!(a.mul_add(b, c), abc); + let mut x = a; + x.mul_add_assign(b, c); + assert_eq!(x, abc); + } + } + } + } + #[test] fn test_mul_add_complex_complex() { assert_eq!(_1_0i_i32.mul_add(_1_0i_i32, _0_0i_i32), _1_0i_i32 * _1_0i_i32 + _0_0i_i32);