Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions vortex-array/src/arrays/decimal/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ impl DecimalArray {
decimal_dtype: DecimalDType,
validity: Validity,
) -> VortexResult<Self> {
Self::validate(&buffer, &validity)?;
Self::validate(&buffer, decimal_dtype, &validity)?;

// SAFETY: validate ensures all invariants are met.
Ok(unsafe { Self::new_unchecked(buffer, decimal_dtype, validity) })
Expand All @@ -136,16 +136,18 @@ impl DecimalArray {
///
/// The caller must ensure all of the following invariants are satisfied:
///
/// - The storage type `T` must be compatible with the precision (i.e., able to represent all
/// values of the declared precision).
/// - All non-null values in `buffer` must be representable within the specified precision.
/// - For example, with precision=5 and scale=2, all values must be in range [-999.99, 999.99].
/// For example, with precision=5 and scale=2, all values must be in range [-999.99, 999.99].
/// - If `validity` is [`Validity::Array`], its length must exactly equal `buffer.len()`.
pub unsafe fn new_unchecked<T: NativeDecimalType>(
buffer: Buffer<T>,
decimal_dtype: DecimalDType,
validity: Validity,
) -> Self {
#[cfg(debug_assertions)]
Self::validate(&buffer, &validity)
Self::validate(&buffer, decimal_dtype, &validity)
.vortex_expect("[Debug Assertion]: Invalid `DecimalArray` parameters");

Self {
Expand All @@ -162,8 +164,18 @@ impl DecimalArray {
/// This function checks all the invariants required by [`DecimalArray::new_unchecked`].
pub fn validate<T: NativeDecimalType>(
buffer: &Buffer<T>,
// TODO(connor): The decimal array storage type should be able to represent the entire
// domain of the decimal type.
_decimal_dtype: DecimalDType,
validity: &Validity,
) -> VortexResult<()> {
// vortex_ensure!(
// T::DECIMAL_TYPE.is_compatible_decimal_value_type(decimal_dtype),
// "Storage type {:?} cannot represent all values of precision {}",
// T::DECIMAL_TYPE,
// decimal_dtype.precision()
// );

if let Some(len) = validity.maybe_len() {
vortex_ensure!(
buffer.len() == len,
Expand Down
10 changes: 6 additions & 4 deletions vortex-array/src/arrays/decimal/compute/fill_null.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@ impl FillNullKernel for DecimalVTable {
let is_invalid = is_valid.to_bool().bit_buffer().not();
match_each_decimal_value_type!(array.values_type(), |T| {
let mut buffer = array.buffer::<T>().into_mut();
let fill_value = fill_value
.as_decimal()
let decimal_scalar = fill_value.as_decimal();
let decimal_value = decimal_scalar
.decimal_value()
.and_then(|v| v.cast::<T>())
.vortex_expect("top-level fill_null ensure non-null fill value");
.vortex_expect("fill_null requires a non-null fill value");
let fill_value = decimal_value
.cast::<T>()
.vortex_expect("fill value does not fit in array's decimal storage type");
for invalid_index in is_invalid.set_indices() {
buffer[invalid_index] = fill_value;
}
Expand Down
40 changes: 40 additions & 0 deletions vortex-array/src/compute/fill_null.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use std::sync::LazyLock;

use arcref::ArcRef;
use vortex_dtype::DType;
use vortex_dtype::DecimalType;
use vortex_dtype::match_each_decimal_value_type;
use vortex_error::VortexError;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_err;
Expand All @@ -15,6 +18,7 @@ use crate::Array;
use crate::ArrayRef;
use crate::IntoArray;
use crate::arrays::ConstantArray;
use crate::arrays::DecimalVTable;
use crate::compute::ComputeFn;
use crate::compute::ComputeFnVTable;
use crate::compute::InvocationArgs;
Expand Down Expand Up @@ -59,6 +63,14 @@ pub fn fill_null(array: &dyn Array, fill_value: &Scalar) -> VortexResult<ArrayRe
}

/// Kernel trait for replacing null values in an array with a fill value.
///
/// Implemented on an array's vtable type (for example, the decimal encoding provides
/// `impl FillNullKernel for DecimalVTable`).
pub trait FillNullKernel: VTable {
/// Replaces the null values of `array` with `fill_value`, returning the resulting array.
///
/// TODO(connor): Actually enforce these constraints (so that casts do not fail).
///
/// Implementations can assume that:
/// - The array contains a mix of valid and null values (it is neither all valid nor all
///   invalid).
/// - The fill value is non-null.
/// - For decimal arrays, the fill value can be successfully cast to the array's storage type.
///
/// NOTE(review): per the TODO above, these are assumptions rather than guarantees; not all of
/// them are enforced before a kernel is called, so an implementation that relies on them (for
/// example by unwrapping a cast) may panic if they are violated.
fn fill_null(&self, array: &Self::Array, fill_value: &Scalar) -> VortexResult<ArrayRef>;
}

Expand Down Expand Up @@ -110,6 +122,34 @@ impl ComputeFnVTable for FillNull {
vortex_bail!("Cannot fill_null with a null value")
}

/*
// For decimal arrays, validate that the fill value fits in the storage type.
if let Some(decimal_dtype) = array.dtype().as_decimal_opt() {
// Try to get the actual storage type from a DecimalArray. Otherwise, use the smallest
// type that can represent the precision.
let storage_type = array
.as_opt::<DecimalVTable>()
.map(|arr| arr.values_type())
.unwrap_or_else(|| DecimalType::smallest_decimal_value_type(decimal_dtype));
let decimal_value = fill_value
.as_decimal()
.decimal_value()
.vortex_expect("fill_null checked is_null above");

let fits = match_each_decimal_value_type!(storage_type, |T| {
decimal_value.cast::<T>().is_some()
});

if !fits {
vortex_bail!(
"fill value {} does not fit in array's decimal storage type {:?}",
decimal_value,
storage_type
)
}
}
*/

for kernel in kernels {
if let Some(output) = kernel.invoke(args)? {
return Ok(output);
Expand Down
137 changes: 111 additions & 26 deletions vortex-array/src/validity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,37 +272,47 @@ impl Validity {
indices: &dyn Array,
patches: &Validity,
) -> Self {
use Validity::*;

match (&self, patches) {
(Validity::NonNullable, Validity::NonNullable) => return Validity::NonNullable,
(Validity::NonNullable, _) => {
vortex_panic!("Can't patch a non-nullable validity with nullable validity")
(NonNullable, NonNullable | AllValid) => {
return NonNullable;
}
(NonNullable, Array(_) | AllInvalid) => {
vortex_panic!("Can't patch a non-nullable validity with null values")
}
(_, Validity::NonNullable) => {
vortex_panic!("Can't patch a nullable validity with non-nullable validity")

(AllValid | Array(_) | AllInvalid, NonNullable) => {
vortex_panic!("Can't patch a nullable validity with a non-nullable validity")
}
(Validity::AllValid, Validity::AllValid) => return Validity::AllValid,
(Validity::AllInvalid, Validity::AllInvalid) => return Validity::AllInvalid,
_ => {}

(AllValid, AllValid) => return AllValid,
(AllValid, Array(_) | AllInvalid) => {}

(AllInvalid, AllInvalid) => return AllInvalid,
(AllInvalid, AllValid | Array(_)) => {}

(Array(_), _) => {}
};

let own_nullability = if self == Validity::NonNullable {
let own_nullability = if self == NonNullable {
Nullability::NonNullable
} else {
Nullability::Nullable
};

let source = match self {
Validity::NonNullable => BoolArray::from(BitBuffer::new_set(len)),
Validity::AllValid => BoolArray::from(BitBuffer::new_set(len)),
Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(len)),
Validity::Array(a) => a.to_bool(),
NonNullable => BoolArray::from(BitBuffer::new_set(len)),
AllValid => BoolArray::from(BitBuffer::new_set(len)),
AllInvalid => BoolArray::from(BitBuffer::new_unset(len)),
Array(a) => a.to_bool(),
};

let patch_values = match patches {
Validity::NonNullable => BoolArray::from(BitBuffer::new_set(indices.len())),
Validity::AllValid => BoolArray::from(BitBuffer::new_set(indices.len())),
Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(indices.len())),
Validity::Array(a) => a.to_bool(),
NonNullable => BoolArray::from(BitBuffer::new_set(indices.len())),
AllValid => BoolArray::from(BitBuffer::new_set(indices.len())),
AllInvalid => BoolArray::from(BitBuffer::new_unset(indices.len())),
Array(a) => a.to_bool(),
};

let patches = Patches::new(
Expand Down Expand Up @@ -513,21 +523,96 @@ mod tests {
use crate::validity::Validity;

#[rstest]
#[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
#[case(Validity::AllValid, 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
#[case(
Validity::AllValid,
5,
&[2, 4],
Validity::AllValid,
Validity::AllValid
)]
#[case(
Validity::AllValid,
5,
&[2, 4],
Validity::AllInvalid,
Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
)]
#[case(
Validity::AllValid,
5,
&[2, 4],
Validity::Array(BoolArray::from_iter([true, false]).into_array()),
Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
)]
#[case(
Validity::AllInvalid,
5,
&[2, 4],
Validity::AllValid,
Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
)]
#[case(
Validity::AllInvalid,
5,
&[2, 4],
Validity::AllInvalid,
Validity::AllInvalid
)]
#[case(
Validity::AllInvalid,
5,
&[2, 4],
Validity::Array(BoolArray::from_iter([true, false]).into_array()),
Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
)]
#[case(
Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
5,
&[2, 4],
Validity::AllValid,
Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
)]
#[case(Validity::AllValid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
#[case(
Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
5,
&[2, 4],
Validity::AllInvalid,
Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
)]
#[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
#[case(
Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
5,
&[2, 4],
Validity::Array(BoolArray::from_iter([true, false]).into_array()),
Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
)]
#[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)]
#[case(Validity::AllInvalid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array())
#[case(
Validity::NonNullable,
5,
&[2, 4],
Validity::AllValid,
Validity::NonNullable
)]
#[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
#[case(
Validity::AllValid,
5,
&[2, 4],
Validity::NonNullable,
Validity::AllValid
)]
#[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array())
#[case(
Validity::AllInvalid,
5,
&[2, 4],
Validity::NonNullable,
Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array())
)]
#[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array())
#[case(
Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()),
5,
&[2, 4],
Validity::NonNullable,
Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array())
)]
fn patch_validity(
#[case] validity: Validity,
Expand Down
Loading