Skip to content

Commit f3a73d5

Browse files
committed
Make the entire crate compatible with the array reference type
1 parent 2b07b9e commit f3a73d5

File tree

15 files changed

+120
-226
lines changed

15 files changed

+120
-226
lines changed

Cargo.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ keywords = ["array", "multidimensional", "statistics", "matrix", "ndarray"]
1616
categories = ["data-structures", "science"]
1717

1818
[dependencies]
19-
ndarray = "0.16.0"
19+
ndarray = "0.17.1"
2020
noisy_float = "0.2.0"
2121
num-integer = "0.1"
2222
num-traits = "0.2"
@@ -25,10 +25,10 @@ itertools = { version = "0.13", default-features = false }
2525
indexmap = "2.4"
2626

2727
[dev-dependencies]
28-
ndarray = { version = "0.16.1", features = ["approx"] }
28+
ndarray = { version = "0.17.1", features = ["approx"] }
2929
criterion = "0.3"
3030
quickcheck = { version = "0.9.2", default-features = false }
31-
ndarray-rand = "0.15.0"
31+
ndarray-rand = "0.16.0"
3232
approx = "0.5"
3333
quickcheck_macros = "1.0.0"
3434
num-bigint = "0.4.0"

benches/deviation.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ fn sq_l2_dist(c: &mut Criterion) {
1212
group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
1313
for len in &lens {
1414
group.bench_with_input(format!("{}", len), len, |b, &len| {
15-
let data = Array::random(len, Uniform::new(0.0, 1.0));
16-
let data2 = Array::random(len, Uniform::new(0.0, 1.0));
15+
let data = Array::random(len, Uniform::new(0.0, 1.0).unwrap());
16+
let data2 = Array::random(len, Uniform::new(0.0, 1.0).unwrap());
1717

1818
b.iter(|| black_box(data.sq_l2_dist(&data2).unwrap()))
1919
});

benches/summary_statistics.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ fn weighted_std(c: &mut Criterion) {
1212
group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
1313
for len in &lens {
1414
group.bench_with_input(format!("{}", len), len, |b, &len| {
15-
let data = Array::random(len, Uniform::new(0.0, 1.0));
16-
let mut weights = Array::random(len, Uniform::new(0.0, 1.0));
15+
let data = Array::random(len, Uniform::new(0.0, 1.0).unwrap());
16+
let mut weights = Array::random(len, Uniform::new(0.0, 1.0).unwrap());
1717
weights /= weights.sum();
1818
b.iter_batched(
1919
|| data.clone(),

src/correlation.rs

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
use crate::errors::EmptyInput;
22
use ndarray::prelude::*;
3-
use ndarray::Data;
43
use num_traits::{Float, FromPrimitive};
54

6-
/// Extension trait for `ArrayBase` providing functions
5+
/// Extension trait for `ndarray` providing functions
76
/// to compute different correlation measures.
8-
pub trait CorrelationExt<A, S>
9-
where
10-
S: Data<Elem = A>,
11-
{
7+
pub trait CorrelationExt<A> {
128
/// Return the covariance matrix `C` for a 2-dimensional
139
/// array of observations `M`.
1410
///
@@ -125,10 +121,7 @@ where
125121
private_decl! {}
126122
}
127123

128-
impl<A: 'static, S> CorrelationExt<A, S> for ArrayBase<S, Ix2>
129-
where
130-
S: Data<Elem = A>,
131-
{
124+
impl<A: 'static> CorrelationExt<A> for ArrayRef2<A> {
132125
fn cov(&self, ddof: A) -> Result<Array2<A>, EmptyInput>
133126
where
134127
A: Float + FromPrimitive,
@@ -147,7 +140,7 @@ where
147140
let mean = self.mean_axis(observation_axis);
148141
match mean {
149142
Some(mean) => {
150-
let denoised = self - &mean.insert_axis(observation_axis);
143+
let denoised = self - mean.insert_axis(observation_axis);
151144
let covariance = denoised.dot(&denoised.t());
152145
Ok(covariance.mapv_into(|x| x / dof))
153146
}
@@ -208,7 +201,7 @@ mod cov_tests {
208201
let n_observations = 4;
209202
let a = Array::random(
210203
(n_random_variables, n_observations),
211-
Uniform::new(-bound.abs(), bound.abs()),
204+
Uniform::new(-bound.abs(), bound.abs()).unwrap(),
212205
);
213206
let covariance = a.cov(1.).unwrap();
214207
abs_diff_eq!(covariance, &covariance.t(), epsilon = 1e-8)
@@ -219,7 +212,10 @@ mod cov_tests {
219212
fn test_invalid_ddof() {
220213
let n_random_variables = 3;
221214
let n_observations = 4;
222-
let a = Array::random((n_random_variables, n_observations), Uniform::new(0., 10.));
215+
let a = Array::random(
216+
(n_random_variables, n_observations),
217+
Uniform::new(0., 10.).unwrap(),
218+
);
223219
let invalid_ddof = (n_observations as f64) + rand::random::<f64>().abs();
224220
let _ = a.cov(invalid_ddof);
225221
}
@@ -299,7 +295,7 @@ mod pearson_correlation_tests {
299295
let n_observations = 4;
300296
let a = Array::random(
301297
(n_random_variables, n_observations),
302-
Uniform::new(-bound.abs(), bound.abs()),
298+
Uniform::new(-bound.abs(), bound.abs()).unwrap(),
303299
);
304300
let pearson_correlation = a.pearson_correlation().unwrap();
305301
abs_diff_eq!(

src/deviation.rs

Lines changed: 36 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
1-
use ndarray::{ArrayBase, Data, Dimension, Zip};
1+
use ndarray::{ArrayRef, Dimension, Zip};
22
use num_traits::{Signed, ToPrimitive};
33
use std::convert::Into;
44
use std::ops::AddAssign;
55

66
use crate::errors::MultiInputError;
77

8-
/// An extension trait for `ArrayBase` providing functions
8+
/// An extension trait for `ndarray` providing functions
99
/// to compute different deviation measures.
10-
pub trait DeviationExt<A, S, D>
10+
pub trait DeviationExt<A, D>
1111
where
12-
S: Data<Elem = A>,
1312
D: Dimension,
1413
{
1514
/// Counts the number of indices at which the elements of the arrays `self`
@@ -19,10 +18,9 @@ where
1918
///
2019
/// * `MultiInputError::EmptyInput` if `self` is empty
2120
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
22-
fn count_eq<T>(&self, other: &ArrayBase<T, D>) -> Result<usize, MultiInputError>
21+
fn count_eq(&self, other: &ArrayRef<A, D>) -> Result<usize, MultiInputError>
2322
where
24-
A: PartialEq,
25-
T: Data<Elem = A>;
23+
A: PartialEq;
2624

2725
/// Counts the number of indices at which the elements of the arrays `self`
2826
/// and `other` are not equal.
@@ -31,10 +29,9 @@ where
3129
///
3230
/// * `MultiInputError::EmptyInput` if `self` is empty
3331
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
34-
fn count_neq<T>(&self, other: &ArrayBase<T, D>) -> Result<usize, MultiInputError>
32+
fn count_neq(&self, other: &ArrayRef<A, D>) -> Result<usize, MultiInputError>
3533
where
36-
A: PartialEq,
37-
T: Data<Elem = A>;
34+
A: PartialEq;
3835

3936
/// Computes the [squared L2 distance] between `self` and `other`.
4037
///
@@ -52,10 +49,9 @@ where
5249
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
5350
///
5451
/// [squared L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance#Squared_Euclidean_distance
55-
fn sq_l2_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<A, MultiInputError>
52+
fn sq_l2_dist(&self, other: &ArrayRef<A, D>) -> Result<A, MultiInputError>
5653
where
57-
A: AddAssign + Clone + Signed,
58-
T: Data<Elem = A>;
54+
A: AddAssign + Clone + Signed;
5955

6056
/// Computes the [L2 distance] between `self` and `other`.
6157
///
@@ -75,10 +71,9 @@ where
7571
/// **Panics** if the type cast from `A` to `f64` fails.
7672
///
7773
/// [L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance
78-
fn l2_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
74+
fn l2_dist(&self, other: &ArrayRef<A, D>) -> Result<f64, MultiInputError>
7975
where
80-
A: AddAssign + Clone + Signed + ToPrimitive,
81-
T: Data<Elem = A>;
76+
A: AddAssign + Clone + Signed + ToPrimitive;
8277

8378
/// Computes the [L1 distance] between `self` and `other`.
8479
///
@@ -96,10 +91,9 @@ where
9691
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
9792
///
9893
/// [L1 distance]: https://en.wikipedia.org/wiki/Taxicab_geometry
99-
fn l1_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<A, MultiInputError>
94+
fn l1_dist(&self, other: &ArrayRef<A, D>) -> Result<A, MultiInputError>
10095
where
101-
A: AddAssign + Clone + Signed,
102-
T: Data<Elem = A>;
96+
A: AddAssign + Clone + Signed;
10397

10498
/// Computes the [L∞ distance] between `self` and `other`.
10599
///
@@ -116,10 +110,9 @@ where
116110
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
117111
///
118112
/// [L∞ distance]: https://en.wikipedia.org/wiki/Chebyshev_distance
119-
fn linf_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<A, MultiInputError>
113+
fn linf_dist(&self, other: &ArrayRef<A, D>) -> Result<A, MultiInputError>
120114
where
121-
A: Clone + PartialOrd + Signed,
122-
T: Data<Elem = A>;
115+
A: Clone + PartialOrd + Signed;
123116

124117
/// Computes the [mean absolute error] between `self` and `other`.
125118
///
@@ -139,10 +132,9 @@ where
139132
/// **Panics** if the type cast from `A` to `f64` fails.
140133
///
141134
/// [mean absolute error]: https://en.wikipedia.org/wiki/Mean_absolute_error
142-
fn mean_abs_err<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
135+
fn mean_abs_err(&self, other: &ArrayRef<A, D>) -> Result<f64, MultiInputError>
143136
where
144-
A: AddAssign + Clone + Signed + ToPrimitive,
145-
T: Data<Elem = A>;
137+
A: AddAssign + Clone + Signed + ToPrimitive;
146138

147139
/// Computes the [mean squared error] between `self` and `other`.
148140
///
@@ -162,10 +154,9 @@ where
162154
/// **Panics** if the type cast from `A` to `f64` fails.
163155
///
164156
/// [mean squared error]: https://en.wikipedia.org/wiki/Mean_squared_error
165-
fn mean_sq_err<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
157+
fn mean_sq_err(&self, other: &ArrayRef<A, D>) -> Result<f64, MultiInputError>
166158
where
167-
A: AddAssign + Clone + Signed + ToPrimitive,
168-
T: Data<Elem = A>;
159+
A: AddAssign + Clone + Signed + ToPrimitive;
169160

170161
/// Computes the unnormalized [root-mean-square error] between `self` and `other`.
171162
///
@@ -183,10 +174,9 @@ where
183174
/// **Panics** if the type cast from `A` to `f64` fails.
184175
///
185176
/// [root-mean-square error]: https://en.wikipedia.org/wiki/Root-mean-square_deviation
186-
fn root_mean_sq_err<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
177+
fn root_mean_sq_err(&self, other: &ArrayRef<A, D>) -> Result<f64, MultiInputError>
187178
where
188-
A: AddAssign + Clone + Signed + ToPrimitive,
189-
T: Data<Elem = A>;
179+
A: AddAssign + Clone + Signed + ToPrimitive;
190180

191181
/// Computes the [peak signal-to-noise ratio] between `self` and `other`.
192182
///
@@ -205,27 +195,24 @@ where
205195
/// **Panics** if the type cast from `A` to `f64` fails.
206196
///
207197
/// [peak signal-to-noise ratio]: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
208-
fn peak_signal_to_noise_ratio<T>(
198+
fn peak_signal_to_noise_ratio(
209199
&self,
210-
other: &ArrayBase<T, D>,
200+
other: &ArrayRef<A, D>,
211201
maxv: A,
212202
) -> Result<f64, MultiInputError>
213203
where
214-
A: AddAssign + Clone + Signed + ToPrimitive,
215-
T: Data<Elem = A>;
204+
A: AddAssign + Clone + Signed + ToPrimitive;
216205

217206
private_decl! {}
218207
}
219208

220-
impl<A, S, D> DeviationExt<A, S, D> for ArrayBase<S, D>
209+
impl<A, D> DeviationExt<A, D> for ArrayRef<A, D>
221210
where
222-
S: Data<Elem = A>,
223211
D: Dimension,
224212
{
225-
fn count_eq<T>(&self, other: &ArrayBase<T, D>) -> Result<usize, MultiInputError>
213+
fn count_eq(&self, other: &ArrayRef<A, D>) -> Result<usize, MultiInputError>
226214
where
227215
A: PartialEq,
228-
T: Data<Elem = A>,
229216
{
230217
return_err_if_empty!(self);
231218
return_err_unless_same_shape!(self, other);
@@ -241,18 +228,16 @@ where
241228
Ok(count)
242229
}
243230

244-
fn count_neq<T>(&self, other: &ArrayBase<T, D>) -> Result<usize, MultiInputError>
231+
fn count_neq(&self, other: &ArrayRef<A, D>) -> Result<usize, MultiInputError>
245232
where
246233
A: PartialEq,
247-
T: Data<Elem = A>,
248234
{
249235
self.count_eq(other).map(|n_eq| self.len() - n_eq)
250236
}
251237

252-
fn sq_l2_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<A, MultiInputError>
238+
fn sq_l2_dist(&self, other: &ArrayRef<A, D>) -> Result<A, MultiInputError>
253239
where
254240
A: AddAssign + Clone + Signed,
255-
T: Data<Elem = A>,
256241
{
257242
return_err_if_empty!(self);
258243
return_err_unless_same_shape!(self, other);
@@ -268,10 +253,9 @@ where
268253
Ok(result)
269254
}
270255

271-
fn l2_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
256+
fn l2_dist(&self, other: &ArrayRef<A, D>) -> Result<f64, MultiInputError>
272257
where
273258
A: AddAssign + Clone + Signed + ToPrimitive,
274-
T: Data<Elem = A>,
275259
{
276260
let sq_l2_dist = self
277261
.sq_l2_dist(other)?
@@ -281,10 +265,9 @@ where
281265
Ok(sq_l2_dist.sqrt())
282266
}
283267

284-
fn l1_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<A, MultiInputError>
268+
fn l1_dist(&self, other: &ArrayRef<A, D>) -> Result<A, MultiInputError>
285269
where
286270
A: AddAssign + Clone + Signed,
287-
T: Data<Elem = A>,
288271
{
289272
return_err_if_empty!(self);
290273
return_err_unless_same_shape!(self, other);
@@ -299,10 +282,9 @@ where
299282
Ok(result)
300283
}
301284

302-
fn linf_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<A, MultiInputError>
285+
fn linf_dist(&self, other: &ArrayRef<A, D>) -> Result<A, MultiInputError>
303286
where
304287
A: Clone + PartialOrd + Signed,
305-
T: Data<Elem = A>,
306288
{
307289
return_err_if_empty!(self);
308290
return_err_unless_same_shape!(self, other);
@@ -320,10 +302,9 @@ where
320302
Ok(max)
321303
}
322304

323-
fn mean_abs_err<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
305+
fn mean_abs_err(&self, other: &ArrayRef<A, D>) -> Result<f64, MultiInputError>
324306
where
325307
A: AddAssign + Clone + Signed + ToPrimitive,
326-
T: Data<Elem = A>,
327308
{
328309
let l1_dist = self
329310
.l1_dist(other)?
@@ -334,10 +315,9 @@ where
334315
Ok(l1_dist / n)
335316
}
336317

337-
fn mean_sq_err<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
318+
fn mean_sq_err(&self, other: &ArrayRef<A, D>) -> Result<f64, MultiInputError>
338319
where
339320
A: AddAssign + Clone + Signed + ToPrimitive,
340-
T: Data<Elem = A>,
341321
{
342322
let sq_l2_dist = self
343323
.sq_l2_dist(other)?
@@ -348,23 +328,21 @@ where
348328
Ok(sq_l2_dist / n)
349329
}
350330

351-
fn root_mean_sq_err<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
331+
fn root_mean_sq_err(&self, other: &ArrayRef<A, D>) -> Result<f64, MultiInputError>
352332
where
353333
A: AddAssign + Clone + Signed + ToPrimitive,
354-
T: Data<Elem = A>,
355334
{
356335
let msd = self.mean_sq_err(other)?;
357336
Ok(msd.sqrt())
358337
}
359338

360-
fn peak_signal_to_noise_ratio<T>(
339+
fn peak_signal_to_noise_ratio(
361340
&self,
362-
other: &ArrayBase<T, D>,
341+
other: &ArrayRef<A, D>,
363342
maxv: A,
364343
) -> Result<f64, MultiInputError>
365344
where
366345
A: AddAssign + Clone + Signed + ToPrimitive,
367-
T: Data<Elem = A>,
368346
{
369347
let maxv_f = maxv.to_f64().expect("failed cast from type A to f64");
370348
let msd = self.mean_sq_err(&other)?;

0 commit comments

Comments
 (0)