diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index 22c51189380..f58e4997615 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -32,7 +32,7 @@ use crate::matcher::AnyTensor; use crate::scalar_fns::ApproxOptions; use crate::scalar_fns::inner_product::InnerProduct; use crate::scalar_fns::l2_norm::L2Norm; -use crate::utils::extension_element_ptype; +use crate::utils::tensor_element_ptype; /// Cosine similarity between two columns. /// @@ -126,7 +126,7 @@ impl ScalarFnVTable for CosineSimilarity { "CosineSimilarity inputs must be an `AnyTensor`, got {lhs}" ); - let ptype = extension_element_ptype(lhs_ext)?; + let ptype = tensor_element_ptype(lhs_ext)?; vortex_ensure!( ptype.is_float(), "CosineSimilarity element dtype must be a float primitive, got {ptype}" diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index d142649600d..dbc9a7e52ad 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -31,9 +31,9 @@ use vortex_error::vortex_err; use crate::matcher::AnyTensor; use crate::scalar_fns::ApproxOptions; -use crate::utils::extension_element_ptype; -use crate::utils::extension_list_size; use crate::utils::extract_flat_elements; +use crate::utils::tensor_element_ptype; +use crate::utils::tensor_list_size; /// Inner product (dot product) between two columns. /// @@ -125,7 +125,7 @@ impl ScalarFnVTable for InnerProduct { "InnerProduct inputs must be an `AnyTensor`, got {lhs}" ); - let ptype = extension_element_ptype(lhs_ext)?; + let ptype = tensor_element_ptype(lhs_ext)?; vortex_ensure!( ptype.is_float(), "InnerProduct element dtype must be a float primitive, got {ptype}" @@ -153,7 +153,7 @@ impl ScalarFnVTable for InnerProduct { // Get list size from the dtype. Both sides have the same dtype (validated by // `return_dtype`). let ext = lhs.dtype().as_extension(); - let list_size = extension_list_size(ext)? as usize; + let list_size = tensor_list_size(ext)? as usize; // Extract the storage array from each extension input. We pass the storage (FSL) rather // than the extension array to avoid canonicalizing the extension wrapper. diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index ed29cc776b7..40b7758acbb 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -30,9 +30,9 @@ use vortex_error::vortex_err; use crate::matcher::AnyTensor; use crate::scalar_fns::ApproxOptions; -use crate::utils::extension_element_ptype; -use crate::utils::extension_list_size; use crate::utils::extract_flat_elements; +use crate::utils::tensor_element_ptype; +use crate::utils::tensor_list_size; /// L2 norm (Euclidean norm) of a tensor or vector column. /// @@ -107,7 +107,7 @@ impl ScalarFnVTable for L2Norm { "L2Norm input must be an `AnyTensor`, got {input_dtype}" ); - let ptype = extension_element_ptype(ext)?; + let ptype = tensor_element_ptype(ext)?; vortex_ensure!( ptype.is_float(), "L2Norm element dtype must be a float primitive, got {ptype}" @@ -130,7 +130,7 @@ impl ScalarFnVTable for L2Norm { // Get list size (dimensions) from the dtype (validated by `return_dtype`). let ext = input.dtype().as_extension(); - let list_size = extension_list_size(ext)? as usize; + let list_size = tensor_list_size(ext)? as usize; let storage = input.data().storage_array(); let flat = extract_flat_elements(storage, list_size, ctx)?; diff --git a/vortex-tensor/src/utils.rs b/vortex-tensor/src/utils.rs index 6a84f8bbc7f..e3f231be83b 100644 --- a/vortex-tensor/src/utils.rs +++ b/vortex-tensor/src/utils.rs @@ -20,7 +20,7 @@ use vortex_error::vortex_err; /// Extracts the list size from a tensor-like extension dtype. /// /// The storage dtype must be a `FixedSizeList`. -pub fn extension_list_size(ext: &ExtDTypeRef) -> VortexResult { +pub fn tensor_list_size(ext: &ExtDTypeRef) -> VortexResult { let DType::FixedSizeList(_, list_size, _) = ext.storage_dtype() else { vortex_bail!( "expected FixedSizeList storage dtype, got {}", @@ -34,7 +34,7 @@ pub fn extension_list_size(ext: &ExtDTypeRef) -> VortexResult { /// Extracts the float element [`PType`] from a tensor-like extension dtype. /// /// The storage dtype must be a `FixedSizeList` of non-nullable primitives. -pub fn extension_element_ptype(ext: &ExtDTypeRef) -> VortexResult { +pub fn tensor_element_ptype(ext: &ExtDTypeRef) -> VortexResult { let element_dtype = ext .storage_dtype() .as_fixed_size_list_element_opt()