1- use ndarray:: { ArrayBase , Data , Dimension , Zip } ;
1+ use ndarray:: { ArrayRef , Dimension , Zip } ;
22use num_traits:: { Signed , ToPrimitive } ;
33use std:: convert:: Into ;
44use std:: ops:: AddAssign ;
55
66use crate :: errors:: MultiInputError ;
77
8- /// An extension trait for `ArrayBase ` providing functions
8+ /// An extension trait for `ndarray ` providing functions
99/// to compute different deviation measures.
10- pub trait DeviationExt < A , S , D >
10+ pub trait DeviationExt < A , D >
1111where
12- S : Data < Elem = A > ,
1312 D : Dimension ,
1413{
1514 /// Counts the number of indices at which the elements of the arrays `self`
1918 ///
2019 /// * `MultiInputError::EmptyInput` if `self` is empty
2120 /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
22- fn count_eq < T > ( & self , other : & ArrayBase < T , D > ) -> Result < usize , MultiInputError >
21+ fn count_eq ( & self , other : & ArrayRef < A , D > ) -> Result < usize , MultiInputError >
2322 where
24- A : PartialEq ,
25- T : Data < Elem = A > ;
23+ A : PartialEq ;
2624
2725 /// Counts the number of indices at which the elements of the arrays `self`
2826 /// and `other` are not equal.
3129 ///
3230 /// * `MultiInputError::EmptyInput` if `self` is empty
3331 /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
34- fn count_neq < T > ( & self , other : & ArrayBase < T , D > ) -> Result < usize , MultiInputError >
32+ fn count_neq ( & self , other : & ArrayRef < A , D > ) -> Result < usize , MultiInputError >
3533 where
36- A : PartialEq ,
37- T : Data < Elem = A > ;
34+ A : PartialEq ;
3835
3936 /// Computes the [squared L2 distance] between `self` and `other`.
4037 ///
5249 /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
5350 ///
5451 /// [squared L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance#Squared_Euclidean_distance
55- fn sq_l2_dist < T > ( & self , other : & ArrayBase < T , D > ) -> Result < A , MultiInputError >
52+ fn sq_l2_dist ( & self , other : & ArrayRef < A , D > ) -> Result < A , MultiInputError >
5653 where
57- A : AddAssign + Clone + Signed ,
58- T : Data < Elem = A > ;
54+ A : AddAssign + Clone + Signed ;
5955
6056 /// Computes the [L2 distance] between `self` and `other`.
6157 ///
7571 /// **Panics** if the type cast from `A` to `f64` fails.
7672 ///
7773 /// [L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance
78- fn l2_dist < T > ( & self , other : & ArrayBase < T , D > ) -> Result < f64 , MultiInputError >
74+ fn l2_dist ( & self , other : & ArrayRef < A , D > ) -> Result < f64 , MultiInputError >
7975 where
80- A : AddAssign + Clone + Signed + ToPrimitive ,
81- T : Data < Elem = A > ;
76+ A : AddAssign + Clone + Signed + ToPrimitive ;
8277
8378 /// Computes the [L1 distance] between `self` and `other`.
8479 ///
9691 /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
9792 ///
9893 /// [L1 distance]: https://en.wikipedia.org/wiki/Taxicab_geometry
99- fn l1_dist < T > ( & self , other : & ArrayBase < T , D > ) -> Result < A , MultiInputError >
94+ fn l1_dist ( & self , other : & ArrayRef < A , D > ) -> Result < A , MultiInputError >
10095 where
101- A : AddAssign + Clone + Signed ,
102- T : Data < Elem = A > ;
96+ A : AddAssign + Clone + Signed ;
10397
10498 /// Computes the [L∞ distance] between `self` and `other`.
10599 ///
@@ -116,10 +110,9 @@ where
116110 /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
117111 ///
118112 /// [L∞ distance]: https://en.wikipedia.org/wiki/Chebyshev_distance
119- fn linf_dist < T > ( & self , other : & ArrayBase < T , D > ) -> Result < A , MultiInputError >
113+ fn linf_dist ( & self , other : & ArrayRef < A , D > ) -> Result < A , MultiInputError >
120114 where
121- A : Clone + PartialOrd + Signed ,
122- T : Data < Elem = A > ;
115+ A : Clone + PartialOrd + Signed ;
123116
124117 /// Computes the [mean absolute error] between `self` and `other`.
125118 ///
@@ -139,10 +132,9 @@ where
139132 /// **Panics** if the type cast from `A` to `f64` fails.
140133 ///
141134 /// [mean absolute error]: https://en.wikipedia.org/wiki/Mean_absolute_error
142- fn mean_abs_err < T > ( & self , other : & ArrayBase < T , D > ) -> Result < f64 , MultiInputError >
135+ fn mean_abs_err ( & self , other : & ArrayRef < A , D > ) -> Result < f64 , MultiInputError >
143136 where
144- A : AddAssign + Clone + Signed + ToPrimitive ,
145- T : Data < Elem = A > ;
137+ A : AddAssign + Clone + Signed + ToPrimitive ;
146138
147139 /// Computes the [mean squared error] between `self` and `other`.
148140 ///
@@ -162,10 +154,9 @@ where
162154 /// **Panics** if the type cast from `A` to `f64` fails.
163155 ///
164156 /// [mean squared error]: https://en.wikipedia.org/wiki/Mean_squared_error
165- fn mean_sq_err < T > ( & self , other : & ArrayBase < T , D > ) -> Result < f64 , MultiInputError >
157+ fn mean_sq_err ( & self , other : & ArrayRef < A , D > ) -> Result < f64 , MultiInputError >
166158 where
167- A : AddAssign + Clone + Signed + ToPrimitive ,
168- T : Data < Elem = A > ;
159+ A : AddAssign + Clone + Signed + ToPrimitive ;
169160
170161 /// Computes the unnormalized [root-mean-square error] between `self` and `other`.
171162 ///
@@ -183,10 +174,9 @@ where
183174 /// **Panics** if the type cast from `A` to `f64` fails.
184175 ///
185176 /// [root-mean-square error]: https://en.wikipedia.org/wiki/Root-mean-square_deviation
186- fn root_mean_sq_err < T > ( & self , other : & ArrayBase < T , D > ) -> Result < f64 , MultiInputError >
177+ fn root_mean_sq_err ( & self , other : & ArrayRef < A , D > ) -> Result < f64 , MultiInputError >
187178 where
188- A : AddAssign + Clone + Signed + ToPrimitive ,
189- T : Data < Elem = A > ;
179+ A : AddAssign + Clone + Signed + ToPrimitive ;
190180
191181 /// Computes the [peak signal-to-noise ratio] between `self` and `other`.
192182 ///
@@ -205,27 +195,24 @@ where
205195 /// **Panics** if the type cast from `A` to `f64` fails.
206196 ///
207197 /// [peak signal-to-noise ratio]: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
208- fn peak_signal_to_noise_ratio < T > (
198+ fn peak_signal_to_noise_ratio (
209199 & self ,
210- other : & ArrayBase < T , D > ,
200+ other : & ArrayRef < A , D > ,
211201 maxv : A ,
212202 ) -> Result < f64 , MultiInputError >
213203 where
214- A : AddAssign + Clone + Signed + ToPrimitive ,
215- T : Data < Elem = A > ;
204+ A : AddAssign + Clone + Signed + ToPrimitive ;
216205
217206 private_decl ! { }
218207}
219208
220- impl < A , S , D > DeviationExt < A , S , D > for ArrayBase < S , D >
209+ impl < A , D > DeviationExt < A , D > for ArrayRef < A , D >
221210where
222- S : Data < Elem = A > ,
223211 D : Dimension ,
224212{
225- fn count_eq < T > ( & self , other : & ArrayBase < T , D > ) -> Result < usize , MultiInputError >
213+ fn count_eq ( & self , other : & ArrayRef < A , D > ) -> Result < usize , MultiInputError >
226214 where
227215 A : PartialEq ,
228- T : Data < Elem = A > ,
229216 {
230217 return_err_if_empty ! ( self ) ;
231218 return_err_unless_same_shape ! ( self , other) ;
@@ -241,18 +228,16 @@ where
241228 Ok ( count)
242229 }
243230
244- fn count_neq < T > ( & self , other : & ArrayBase < T , D > ) -> Result < usize , MultiInputError >
231+ fn count_neq ( & self , other : & ArrayRef < A , D > ) -> Result < usize , MultiInputError >
245232 where
246233 A : PartialEq ,
247- T : Data < Elem = A > ,
248234 {
249235 self . count_eq ( other) . map ( |n_eq| self . len ( ) - n_eq)
250236 }
251237
252- fn sq_l2_dist < T > ( & self , other : & ArrayBase < T , D > ) -> Result < A , MultiInputError >
238+ fn sq_l2_dist ( & self , other : & ArrayRef < A , D > ) -> Result < A , MultiInputError >
253239 where
254240 A : AddAssign + Clone + Signed ,
255- T : Data < Elem = A > ,
256241 {
257242 return_err_if_empty ! ( self ) ;
258243 return_err_unless_same_shape ! ( self , other) ;
@@ -268,10 +253,9 @@ where
268253 Ok ( result)
269254 }
270255
271- fn l2_dist < T > ( & self , other : & ArrayBase < T , D > ) -> Result < f64 , MultiInputError >
256+ fn l2_dist ( & self , other : & ArrayRef < A , D > ) -> Result < f64 , MultiInputError >
272257 where
273258 A : AddAssign + Clone + Signed + ToPrimitive ,
274- T : Data < Elem = A > ,
275259 {
276260 let sq_l2_dist = self
277261 . sq_l2_dist ( other) ?
@@ -281,10 +265,9 @@ where
281265 Ok ( sq_l2_dist. sqrt ( ) )
282266 }
283267
284- fn l1_dist < T > ( & self , other : & ArrayBase < T , D > ) -> Result < A , MultiInputError >
268+ fn l1_dist ( & self , other : & ArrayRef < A , D > ) -> Result < A , MultiInputError >
285269 where
286270 A : AddAssign + Clone + Signed ,
287- T : Data < Elem = A > ,
288271 {
289272 return_err_if_empty ! ( self ) ;
290273 return_err_unless_same_shape ! ( self , other) ;
@@ -299,10 +282,9 @@ where
299282 Ok ( result)
300283 }
301284
302- fn linf_dist < T > ( & self , other : & ArrayBase < T , D > ) -> Result < A , MultiInputError >
285+ fn linf_dist ( & self , other : & ArrayRef < A , D > ) -> Result < A , MultiInputError >
303286 where
304287 A : Clone + PartialOrd + Signed ,
305- T : Data < Elem = A > ,
306288 {
307289 return_err_if_empty ! ( self ) ;
308290 return_err_unless_same_shape ! ( self , other) ;
@@ -320,10 +302,9 @@ where
320302 Ok ( max)
321303 }
322304
323- fn mean_abs_err < T > ( & self , other : & ArrayBase < T , D > ) -> Result < f64 , MultiInputError >
305+ fn mean_abs_err ( & self , other : & ArrayRef < A , D > ) -> Result < f64 , MultiInputError >
324306 where
325307 A : AddAssign + Clone + Signed + ToPrimitive ,
326- T : Data < Elem = A > ,
327308 {
328309 let l1_dist = self
329310 . l1_dist ( other) ?
@@ -334,10 +315,9 @@ where
334315 Ok ( l1_dist / n)
335316 }
336317
337- fn mean_sq_err < T > ( & self , other : & ArrayBase < T , D > ) -> Result < f64 , MultiInputError >
318+ fn mean_sq_err ( & self , other : & ArrayRef < A , D > ) -> Result < f64 , MultiInputError >
338319 where
339320 A : AddAssign + Clone + Signed + ToPrimitive ,
340- T : Data < Elem = A > ,
341321 {
342322 let sq_l2_dist = self
343323 . sq_l2_dist ( other) ?
@@ -348,23 +328,21 @@ where
348328 Ok ( sq_l2_dist / n)
349329 }
350330
351- fn root_mean_sq_err < T > ( & self , other : & ArrayBase < T , D > ) -> Result < f64 , MultiInputError >
331+ fn root_mean_sq_err ( & self , other : & ArrayRef < A , D > ) -> Result < f64 , MultiInputError >
352332 where
353333 A : AddAssign + Clone + Signed + ToPrimitive ,
354- T : Data < Elem = A > ,
355334 {
356335 let msd = self . mean_sq_err ( other) ?;
357336 Ok ( msd. sqrt ( ) )
358337 }
359338
360- fn peak_signal_to_noise_ratio < T > (
339+ fn peak_signal_to_noise_ratio (
361340 & self ,
362- other : & ArrayBase < T , D > ,
341+ other : & ArrayRef < A , D > ,
363342 maxv : A ,
364343 ) -> Result < f64 , MultiInputError >
365344 where
366345 A : AddAssign + Clone + Signed + ToPrimitive ,
367- T : Data < Elem = A > ,
368346 {
369347 let maxv_f = maxv. to_f64 ( ) . expect ( "failed cast from type A to f64" ) ;
370348 let msd = self . mean_sq_err ( & other) ?;
0 commit comments