diff --git a/rust/arrow/Cargo.toml b/rust/arrow/Cargo.toml index 0b14b5bfae8..efd68d5da54 100644 --- a/rust/arrow/Cargo.toml +++ b/rust/arrow/Cargo.toml @@ -106,10 +106,18 @@ harness = false name = "filter_kernels" harness = false +[[bench]] +name = "math_kernels" +harness = false + [[bench]] name = "take_kernels" harness = false +[[bench]] +name = "trigonometry_kernels" +harness = false + [[bench]] name = "length_kernel" harness = false diff --git a/rust/arrow/benches/math_kernels.rs b/rust/arrow/benches/math_kernels.rs new file mode 100644 index 00000000000..c0c0e68b5a3 --- /dev/null +++ b/rust/arrow/benches/math_kernels.rs @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#[macro_use] +extern crate criterion; +use criterion::Criterion; + +extern crate arrow; + +use arrow::datatypes::Float32Type; +use arrow::util::bench_util::*; +use arrow::{array::*, compute::*}; + +fn bench_powf(array: &Float32Array, raise: f32) { + criterion::black_box(powf_scalar(array, raise)); +} + +fn bench_powi(array: &Float32Array, raise: i32) { + criterion::black_box(powi(array, raise)); +} + +fn add_benchmark(c: &mut Criterion) { + let array = create_primitive_array::(512, 0.0); + let array_nulls = create_primitive_array::(512, 0.5); + + c.bench_function("powf(2.0) 512", |b| b.iter(|| bench_powf(&array, 2.0))); + c.bench_function("powf(4.0) 512", |b| b.iter(|| bench_powf(&array, 4.0))); + c.bench_function("powf(32.0) 512", |b| b.iter(|| bench_powf(&array, 32.0))); + c.bench_function("powf(2.0) nulls 512", |b| { + b.iter(|| bench_powf(&array_nulls, 2.0)) + }); + c.bench_function("powf(4.0) nulls 512", |b| { + b.iter(|| bench_powf(&array_nulls, 4.0)) + }); + c.bench_function("powf(32.0) nulls 512", |b| { + b.iter(|| bench_powf(&array_nulls, 32.0)) + }); + + c.bench_function("powi(2) 512", |b| b.iter(|| bench_powi(&array, 2))); + c.bench_function("powi(4) 512", |b| b.iter(|| bench_powi(&array, 4))); + c.bench_function("powi(32) 512", |b| b.iter(|| bench_powi(&array, 32))); + c.bench_function("powi(2) nulls 512", |b| { + b.iter(|| bench_powi(&array_nulls, 2)) + }); + c.bench_function("powi(4) nulls 512", |b| { + b.iter(|| bench_powi(&array_nulls, 4)) + }); + c.bench_function("powi(32) nulls 512", |b| { + b.iter(|| bench_powi(&array_nulls, 32)) + }); +} + +criterion_group!(benches, add_benchmark); +criterion_main!(benches); diff --git a/rust/arrow/benches/trigonometry_kernels.rs b/rust/arrow/benches/trigonometry_kernels.rs new file mode 100644 index 00000000000..37f3ba96c6d --- /dev/null +++ b/rust/arrow/benches/trigonometry_kernels.rs @@ -0,0 +1,80 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[macro_use] +extern crate criterion; +use criterion::Criterion; + +extern crate arrow; + +use arrow::{ + array::*, compute::*, datatypes::Float32Type, + util::bench_util::create_primitive_array, +}; + +fn bench_haversine( + lat_a: &Float32Array, + lng_a: &Float32Array, + lat_b: &Float32Array, + lng_b: &Float32Array, +) { + criterion::black_box(haversine(lat_a, lng_a, lat_b, lng_b, 6371000.0).unwrap()); +} +fn bench_sin(array: &Float32Array) { + criterion::black_box(sin(array)); +} + +fn bench_cos(array: &Float32Array) { + criterion::black_box(cos(array)); +} + +fn bench_tan(array: &Float32Array) { + criterion::black_box(tan(array)); +} + +fn add_benchmark(c: &mut Criterion) { + let lat_a = create_primitive_array(512, 0.0); + let lng_a = create_primitive_array(512, 0.0); + let lat_b = create_primitive_array(512, 0.0); + let lng_b = create_primitive_array(512, 0.0); + + c.bench_function("haversine_unary 512", |b| { + b.iter(|| bench_haversine(&lat_a, &lng_a, &lat_b, &lng_b)) + }); + + let lat_a = create_primitive_array(512, 0.5); + let lng_a = create_primitive_array(512, 0.5); + let lat_b = create_primitive_array(512, 0.5); + let lng_b = create_primitive_array(512, 0.5); + + c.bench_function("haversine_unary_nulls 512", |b| { + b.iter(|| bench_haversine(&lat_a, &lng_a, 
&lat_b, &lng_b)) + }); + + let array = create_primitive_array::(512, 0.0); + let array_nulls = create_primitive_array::(512, 0.5); + + c.bench_function("sin 512", |b| b.iter(|| bench_sin(&array))); + c.bench_function("cos 512", |b| b.iter(|| bench_cos(&array))); + c.bench_function("tan 512", |b| b.iter(|| bench_tan(&array))); + c.bench_function("sin nulls 512", |b| b.iter(|| bench_sin(&array_nulls))); + c.bench_function("cos nulls 512", |b| b.iter(|| bench_cos(&array_nulls))); + c.bench_function("tan nulls 512", |b| b.iter(|| bench_tan(&array_nulls))); +} + +criterion_group!(benches, add_benchmark); +criterion_main!(benches); diff --git a/rust/arrow/src/compute/kernels/arithmetic.rs b/rust/arrow/src/compute/kernels/arithmetic.rs index 067756662cf..cc508f4b6ed 100644 --- a/rust/arrow/src/compute/kernels/arithmetic.rs +++ b/rust/arrow/src/compute/kernels/arithmetic.rs @@ -22,7 +22,11 @@ //! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation //! [here](https://doc.rust-lang.org/stable/core/arch/) for more information. +#[cfg(simd)] +use std::borrow::BorrowMut; use std::ops::{Add, Div, Mul, Neg, Sub}; +#[cfg(simd)] +use std::slice::{ChunksExact, ChunksExactMut}; use std::sync::Arc; use num::{One, Zero}; @@ -30,16 +34,13 @@ use num::{One, Zero}; use crate::buffer::Buffer; #[cfg(simd)] use crate::buffer::MutableBuffer; -use crate::compute::{kernels::arity::unary, util::combine_option_bitmap}; +#[cfg(not(simd))] +use crate::compute::kernels::arity::unary; +use crate::compute::util::combine_option_bitmap; use crate::datatypes; use crate::datatypes::ArrowNumericType; use crate::error::{ArrowError, Result}; use crate::{array::*, util::bit_util}; -use num::traits::Pow; -#[cfg(simd)] -use std::borrow::BorrowMut; -#[cfg(simd)] -use std::slice::{ChunksExact, ChunksExactMut}; /// SIMD vectorized version of `unary_math_op` above specialized for signed numerical values. 
#[cfg(simd)] @@ -90,55 +91,6 @@ where Ok(PrimitiveArray::::from(Arc::new(data))) } -#[cfg(simd)] -fn simd_float_unary_math_op( - array: &PrimitiveArray, - simd_op: SIMD_OP, - scalar_op: SCALAR_OP, -) -> Result> -where - T: datatypes::ArrowFloatNumericType, - SIMD_OP: Fn(T::Simd) -> T::Simd, - SCALAR_OP: Fn(T::Native) -> T::Native, -{ - let lanes = T::lanes(); - let buffer_size = array.len() * std::mem::size_of::(); - - let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false); - - let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes); - let mut array_chunks = array.values().chunks_exact(lanes); - - result_chunks - .borrow_mut() - .zip(array_chunks.borrow_mut()) - .for_each(|(result_slice, input_slice)| { - let simd_input = T::load(input_slice); - let simd_result = T::unary_op(simd_input, &simd_op); - T::write(simd_result, result_slice); - }); - - let result_remainder = result_chunks.into_remainder(); - let array_remainder = array_chunks.remainder(); - - result_remainder.into_iter().zip(array_remainder).for_each( - |(scalar_result, scalar_input)| { - *scalar_result = scalar_op(*scalar_input); - }, - ); - - let data = ArrayData::new( - T::DATA_TYPE, - array.len(), - None, - array.data_ref().null_buffer().cloned(), - 0, - vec![result.into()], - vec![], - ); - Ok(PrimitiveArray::::from(Arc::new(data))) -} - /// Helper function to perform math lambda function on values from two arrays. If either /// left or right value is null then the output value is also null, so `1 + null` is /// `null`. @@ -558,28 +510,6 @@ where return Ok(unary(array, |x| -x)); } -/// Raise array with floating point values to the power of a scalar. 
-pub fn powf_scalar( - array: &PrimitiveArray, - raise: T::Native, -) -> Result> -where - T: datatypes::ArrowFloatNumericType, - T::Native: Pow, -{ - #[cfg(simd)] - { - let raise_vector = T::init(raise); - return simd_float_unary_math_op( - array, - |x| T::pow(x, raise_vector), - |x| x.pow(raise), - ); - } - #[cfg(not(simd))] - return Ok(unary(array, |x| x.pow(raise))); -} - /// Perform `left * right` operation on two arrays. If either left or right value is null /// then the result is also null. pub fn multiply( @@ -849,16 +779,4 @@ mod tests { .collect(); assert_eq!(expected, actual); } - - #[test] - fn test_primitive_array_raise_power_scalar() { - let a = Float64Array::from(vec![1.0, 2.0, 3.0]); - let actual = powf_scalar(&a, 2.0).unwrap(); - let expected = Float64Array::from(vec![1.0, 4.0, 9.0]); - assert_eq!(expected, actual); - let a = Float64Array::from(vec![Some(1.0), None, Some(3.0)]); - let actual = powf_scalar(&a, 2.0).unwrap(); - let expected = Float64Array::from(vec![Some(1.0), None, Some(9.0)]); - assert_eq!(expected, actual); - } } diff --git a/rust/arrow/src/compute/kernels/arity.rs b/rust/arrow/src/compute/kernels/arity.rs index 11139f83270..9a62004c048 100644 --- a/rust/arrow/src/compute/kernels/arity.rs +++ b/rust/arrow/src/compute/kernels/arity.rs @@ -17,12 +17,21 @@ //! Defines kernels suitable to perform operations to primitive arrays. 
+#[cfg(simd)] +use std::sync::Arc; + use crate::array::{Array, ArrayData, PrimitiveArray}; use crate::buffer::Buffer; +#[cfg(simd)] +use crate::buffer::MutableBuffer; use crate::datatypes::ArrowPrimitiveType; +#[cfg(simd)] +use crate::datatypes::{ArrowFloatNumericType, ArrowNumericType}; +#[cfg(simd)] +use std::borrow::BorrowMut; #[inline] -fn into_primitive_array_data( +pub(super) fn into_primitive_array_data( array: &PrimitiveArray, buffer: Buffer, ) -> ArrayData { @@ -72,3 +81,146 @@ where let data = into_primitive_array_data::<_, O>(array, buffer); PrimitiveArray::::from(std::sync::Arc::new(data)) } + +#[cfg(simd)] +pub(crate) fn simd_unary_float( + array: &PrimitiveArray, + simd_op: SIMD_OP, + scalar_op: SCALAR_OP, +) -> PrimitiveArray +where + T: ArrowFloatNumericType, + SIMD_OP: Fn(T::Simd) -> T::Simd, + SCALAR_OP: Fn(T::Native) -> T::Native, +{ + let lanes = T::lanes(); + let buffer_size = array.len() * std::mem::size_of::(); + + let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false); + + let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes); + let mut array_chunks = array.values().chunks_exact(lanes); + + result_chunks + .borrow_mut() + .zip(array_chunks.borrow_mut()) + .for_each(|(result_slice, input_slice)| { + let simd_input = T::load(input_slice); + let simd_result = T::unary_op(simd_input, &simd_op); + T::write(simd_result, result_slice); + }); + + let result_remainder = result_chunks.into_remainder(); + let array_remainder = array_chunks.remainder(); + + result_remainder.into_iter().zip(array_remainder).for_each( + |(scalar_result, scalar_input)| { + *scalar_result = scalar_op(*scalar_input); + }, + ); + + let data = ArrayData::new( + T::DATA_TYPE, + array.len(), + None, + array.data_ref().null_buffer().cloned(), + 0, + vec![result.into()], + vec![], + ); + PrimitiveArray::::from(Arc::new(data)) +} + +#[cfg(simd)] +pub(crate) fn simd_unary( + array: &PrimitiveArray, + simd_op: SIMD_OP, + scalar_op: 
SCALAR_OP, +) -> PrimitiveArray +where + T: ArrowNumericType, + SIMD_OP: Fn(T::Simd) -> T::Simd, + SCALAR_OP: Fn(T::Native) -> T::Native, +{ + let lanes = T::lanes(); + let buffer_size = array.len() * std::mem::size_of::(); + + let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false); + + let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes); + let mut array_chunks = array.values().chunks_exact(lanes); + + result_chunks + .borrow_mut() + .zip(array_chunks.borrow_mut()) + .for_each(|(result_slice, input_slice)| { + let simd_input = T::load(input_slice); + let simd_result = T::unary_op(simd_input, &simd_op); + T::write(simd_result, result_slice); + }); + + let result_remainder = result_chunks.into_remainder(); + let array_remainder = array_chunks.remainder(); + + result_remainder.into_iter().zip(array_remainder).for_each( + |(scalar_result, scalar_input)| { + *scalar_result = scalar_op(*scalar_input); + }, + ); + + let data = ArrayData::new( + T::DATA_TYPE, + array.len(), + None, + array.data_ref().null_buffer().cloned(), + 0, + vec![result.into()], + vec![], + ); + PrimitiveArray::::from(Arc::new(data)) +} + +#[macro_export] +macro_rules! float_unary { + ($name:ident) => { + pub fn $name(array: &PrimitiveArray) -> PrimitiveArray + where + T: ArrowNumericType + ArrowFloatNumericType, + T::Native: num::traits::Float, + { + return unary(array, |a| a.$name()); + } + }; +} + +#[macro_export] +macro_rules! float_unary_simd { + ($name:ident) => { + /// Calculate the `$name` of a floating point array + pub fn $name(array: &PrimitiveArray) -> PrimitiveArray + where + T: ArrowNumericType + ArrowFloatNumericType, + T::Native: num::traits::Float, + { + #[cfg(simd)] + { + return simd_unary(array, |x| T::$name(x), |x| x.$name()); + } + #[cfg(not(simd))] + return unary(array, |x| x.$name()); + } + }; +} + +#[macro_export] +macro_rules! 
unary { + ($name:ident) => { + pub fn $name(array: &PrimitiveArray) -> PrimitiveArray + where + T: ArrowNumericType, + T::Native: num::traits::Float, + { + return unary(array, |a| a.$name()); + } + }; +} diff --git a/rust/arrow/src/compute/kernels/math.rs b/rust/arrow/src/compute/kernels/math.rs new file mode 100644 index 00000000000..a0f5135c415 --- /dev/null +++ b/rust/arrow/src/compute/kernels/math.rs @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines basic math kernels for `PrimitiveArrays`. + +use num::{traits::Pow, Float}; + +use crate::datatypes::{ArrowFloatNumericType, ArrowNumericType}; +use crate::{array::*, float_unary_simd}; + +#[cfg(simd)] +use super::arity::simd_unary; +use super::arity::unary; + +float_unary_simd!(sqrt); + +/// Raise a floating point array to the power of a float scalar. 
+pub fn powf_scalar(array: &PrimitiveArray, raise: T::Native) -> PrimitiveArray +where + T: ArrowFloatNumericType, + T::Native: Pow, +{ + #[cfg(simd)] + { + let raise_vector = T::init(raise); + return simd_unary(array, |x| T::pow(x, raise_vector), |x| x.pow(raise)); + } + #[cfg(not(simd))] + return unary(array, |x| x.pow(raise)); +} + +/// Raise array with floating point values to the power of an integer scalar. +/// +/// This function currently has no SIMD equivalent, but is included because: +/// - `powi` is generally faster than `powf` +/// - If there is a SIMD implementation in future, +/// we might want to keep forward compatibility. +/// +/// If using SIMD, it will be quicker to use [`powf_scalar`] instead. +pub fn powi(array: &PrimitiveArray, raise: i32) -> PrimitiveArray +where + T: ArrowFloatNumericType, + T::Native: Pow, +{ + // Note: packed_simd doesn't support `pow` or `powi` + unary(array, |x| x.pow(raise)) +} + +/// Raise numeric array to the power of an integer scalar. +/// +/// Due to the use of [num::traits::Pow], this function can take both integer +/// and floating point arrays as an input. 
+pub fn pow(array: &PrimitiveArray, power: isize) -> PrimitiveArray +where + T: ArrowNumericType, + T::Native: Pow, +{ + unary::<_, _, T>(array, |x| x.pow(power)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_primitive_array_raise_power_scalar() { + let a = Float64Array::from(vec![1.0, 2.0, 3.0]); + let actual = powf_scalar(&a, 2.0); + let expected = Float64Array::from(vec![1.0, 4.0, 9.0]); + assert_eq!(expected, actual); + let a = Float64Array::from(vec![Some(1.0), None, Some(3.0)]); + let actual = powf_scalar(&a, 2.0); + let expected = Float64Array::from(vec![Some(1.0), None, Some(9.0)]); + assert_eq!(expected, actual); + } +} diff --git a/rust/arrow/src/compute/kernels/mod.rs b/rust/arrow/src/compute/kernels/mod.rs index 62d3642f7e8..5a72ff82108 100644 --- a/rust/arrow/src/compute/kernels/mod.rs +++ b/rust/arrow/src/compute/kernels/mod.rs @@ -27,9 +27,11 @@ pub mod concat; pub mod filter; pub mod length; pub mod limit; +pub mod math; pub mod sort; pub mod substring; pub mod take; pub mod temporal; +pub mod trigonometry; pub mod window; pub mod zip; diff --git a/rust/arrow/src/compute/kernels/trigonometry.rs b/rust/arrow/src/compute/kernels/trigonometry.rs new file mode 100644 index 00000000000..dc6e3bc6259 --- /dev/null +++ b/rust/arrow/src/compute/kernels/trigonometry.rs @@ -0,0 +1,188 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines trigonometry kernels for `PrimitiveArrays` that are +//! floats, restricted by the [num::Float] trait. + +use std::{iter::FromIterator, ops::Div}; + +use num::{traits::Pow, Float, Zero}; + +use super::math::*; + +use crate::array::*; +use crate::buffer::Buffer; +#[cfg(simd)] +use crate::compute::kernels::arity::simd_unary; +use crate::datatypes::*; +use crate::error::ArrowError; +use crate::error::Result; +use crate::{ + compute::{add, math_op, multiply, subtract}, + float_unary, float_unary_simd, +}; + +use super::arity::{into_primitive_array_data, unary}; + +float_unary!(to_degrees); +float_unary!(to_radians); +float_unary_simd!(sin); +float_unary_simd!(cos); +float_unary!(tan); +float_unary!(asin); +float_unary!(acos); +float_unary!(atan); +float_unary!(sinh); +float_unary!(cosh); +float_unary_simd!(tanh); +float_unary!(asinh); +float_unary!(acosh); +float_unary!(atanh); + +pub fn atan2( + left: &PrimitiveArray, + right: &PrimitiveArray, +) -> Result> +where + T: ArrowNumericType, + T::Native: num::traits::Float, +{ + math_op(left, right, |x, y| x.atan2(y)) +} + +/// Compute the sine and cosine of each value in the array, returning the results +/// as a `(sin, cos)` pair of arrays. 
+pub fn sin_cos(array: &PrimitiveArray) -> (PrimitiveArray, PrimitiveArray) +where + T: ArrowNumericType, + T::Native: num::traits::Float, +{ + let (sin, cos): (Vec, Vec) = + array.values().iter().map(|v| v.sin_cos()).unzip(); + // NOTE: due to `unzip` collecting and splitting the arrays, + // this might be slower than Buffer::from_trusted_len_iter + let sin_buffer = Buffer::from_iter(sin); + let cos_buffer = Buffer::from_iter(cos); + + let sin_data = into_primitive_array_data::<_, T>(array, sin_buffer); + let cos_data = into_primitive_array_data::<_, T>(array, cos_buffer); + ( + PrimitiveArray::::from(std::sync::Arc::new(sin_data)), + PrimitiveArray::::from(std::sync::Arc::new(cos_data)), + ) +} + +/// Calculate the Haversine distance between two geographic coordinates. +/// Based on the haversine crate. +/// +/// The distance returned is in meters, and the radius must be specified in meters. +pub fn haversine( + lat_a: &PrimitiveArray, + lng_a: &PrimitiveArray, + lat_b: &PrimitiveArray, + lng_b: &PrimitiveArray, + radius: impl num::traits::Float, +) -> Result> +where + T: ArrowPrimitiveType + ArrowNumericType + ArrowFloatNumericType, + T::Native: num::traits::Float + + Div + + Pow + + Pow + + Zero + + num::NumCast, +{ + // Check array lengths, must all equal + let len = lat_a.len(); + if lat_b.len() != len || lng_a.len() != len || lng_b.len() != len { + return Err(ArrowError::ComputeError( + "Cannot perform math operation on arrays of different length".to_string(), + )); + } + // These casts normally get optimized to f64 as f64 + let one = num::cast::<_, T::Native>(1.0).unwrap(); + let two = num::cast::<_, T::Native>(2.0).unwrap(); + let radius = num::cast::<_, T::Native>(radius).unwrap(); + let two_radius = two * radius; + + let lat_delta = to_radians(&subtract(lat_b, lat_a)?); + let lng_delta = to_radians(&subtract(lng_b, lng_a)?); + let lat_a_rad = to_radians(lat_a); + let lat_b_rad = to_radians(lat_b); + + let v1: PrimitiveArray = sin(&unary::<_, _, 
_>(&lat_delta, |x| x / two)); + let v2: PrimitiveArray = sin(&unary::<_, _, _>(&lng_delta, |x| x / two)); + + let a = add( + &powi(&v1, 2), + // This could be simplified if we had a ternary kernel that takes 3 args + // F(T::Native, T::Native, T::Native) -> T::Native + &multiply( + &powi(&v2, 2), + &math_op(&lat_a_rad, &lat_b_rad, |x, y| (x.cos() * y.cos()))?, + )?, + )?; + Ok(unary::<_, _, _>( + &atan2(&sqrt(&a), &sqrt(&unary::<_, _, _>(&a, |x| one - x)))?, + |x| x * two_radius, + )) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_haversine_non_null() { + let lat_a = Float64Array::from(vec![38.898556]); + let lng_a = Float64Array::from(vec![-77.037852]); + let lat_b = Float64Array::from(vec![38.897147]); + let lng_b = Float64Array::from(vec![-77.043934]); + + let radius = 6371000.0; + + let result = haversine(&lat_a, &lng_a, &lat_b, &lng_b, radius).unwrap(); + assert_eq!(result.value(0), 549.1557912038084); + } + + #[test] + fn test_haversine_null() { + // If any of the array slots is null, the result should be null + let lat_a = Float64Array::from(vec![38.898556; 4]); + let lng_a = + Float64Array::from(vec![Some(-77.037852), None, None, Some(-77.037852)]); + let lat_b = Float64Array::from(vec![ + Some(38.897147), + Some(38.897147), + None, + Some(38.897147), + ]); + let lng_b = Float64Array::from(vec![ + None, + Some(-77.043934), + Some(-77.043934), + Some(-77.043934), + ]); + + let radius = 6371000.0; + + let result = haversine(&lat_a, &lng_a, &lat_b, &lng_b, radius).unwrap(); + let expected = + Float64Array::from(vec![None, None, None, Some(549.1557912038084)]); + assert_eq!(&result, &expected); + } +} diff --git a/rust/arrow/src/compute/mod.rs b/rust/arrow/src/compute/mod.rs index 9de07388e9c..e7323d00a48 100644 --- a/rust/arrow/src/compute/mod.rs +++ b/rust/arrow/src/compute/mod.rs @@ -29,7 +29,9 @@ pub use self::kernels::comparison::*; pub use self::kernels::concat::*; pub use self::kernels::filter::*; pub use self::kernels::limit::*; 
+pub use self::kernels::math::*; pub use self::kernels::sort::*; pub use self::kernels::take::*; pub use self::kernels::temporal::*; +pub use self::kernels::trigonometry::*; pub use self::kernels::window::*; diff --git a/rust/arrow/src/datatypes.rs b/rust/arrow/src/datatypes.rs index 096a9305891..309d64dbdbc 100644 --- a/rust/arrow/src/datatypes.rs +++ b/rust/arrow/src/datatypes.rs @@ -922,6 +922,10 @@ make_signed_numeric_type!(Float64Type, f64x8); #[cfg(simd)] pub trait ArrowFloatNumericType: ArrowNumericType { fn pow(base: Self::Simd, raise: Self::Simd) -> Self::Simd; + fn sqrt(value: Self::Simd) -> Self::Simd; + fn cos(value: Self::Simd) -> Self::Simd; + fn sin(value: Self::Simd) -> Self::Simd; + fn tanh(value: Self::Simd) -> Self::Simd; } #[cfg(not(simd))] @@ -935,6 +939,26 @@ macro_rules! make_float_numeric_type { fn pow(base: Self::Simd, raise: Self::Simd) -> Self::Simd { base.powf(raise) } + + #[inline] + fn sqrt(value: Self::Simd) -> Self::Simd { + value.sqrt() + } + + #[inline] + fn cos(value: Self::Simd) -> Self::Simd { + value.cos() + } + + #[inline] + fn sin(value: Self::Simd) -> Self::Simd { + value.sin() + } + + #[inline] + fn tanh(value: Self::Simd) -> Self::Simd { + value.tanh() + } } #[cfg(not(simd))]