Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions vortex-tensor/src/scalar_fns/cosine_similarity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use crate::matcher::AnyTensor;
use crate::scalar_fns::ApproxOptions;
use crate::scalar_fns::inner_product::InnerProduct;
use crate::scalar_fns::l2_norm::L2Norm;
use crate::utils::extension_element_ptype;
use crate::utils::tensor_element_ptype;

/// Cosine similarity between two columns.
///
Expand Down Expand Up @@ -126,7 +126,7 @@ impl ScalarFnVTable for CosineSimilarity {
"CosineSimilarity inputs must be an `AnyTensor`, got {lhs}"
);

let ptype = extension_element_ptype(lhs_ext)?;
let ptype = tensor_element_ptype(lhs_ext)?;
vortex_ensure!(
ptype.is_float(),
"CosineSimilarity element dtype must be a float primitive, got {ptype}"
Expand Down
8 changes: 4 additions & 4 deletions vortex-tensor/src/scalar_fns/inner_product.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ use vortex_error::vortex_err;

use crate::matcher::AnyTensor;
use crate::scalar_fns::ApproxOptions;
use crate::utils::extension_element_ptype;
use crate::utils::extension_list_size;
use crate::utils::extract_flat_elements;
use crate::utils::tensor_element_ptype;
use crate::utils::tensor_list_size;

/// Inner product (dot product) between two columns.
///
Expand Down Expand Up @@ -125,7 +125,7 @@ impl ScalarFnVTable for InnerProduct {
"InnerProduct inputs must be an `AnyTensor`, got {lhs}"
);

let ptype = extension_element_ptype(lhs_ext)?;
let ptype = tensor_element_ptype(lhs_ext)?;
vortex_ensure!(
ptype.is_float(),
"InnerProduct element dtype must be a float primitive, got {ptype}"
Expand Down Expand Up @@ -153,7 +153,7 @@ impl ScalarFnVTable for InnerProduct {
// Get list size from the dtype. Both sides have the same dtype (validated by
// `return_dtype`).
let ext = lhs.dtype().as_extension();
let list_size = extension_list_size(ext)? as usize;
let list_size = tensor_list_size(ext)? as usize;

// Extract the storage array from each extension input. We pass the storage (FSL) rather
// than the extension array to avoid canonicalizing the extension wrapper.
Expand Down
8 changes: 4 additions & 4 deletions vortex-tensor/src/scalar_fns/l2_norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ use vortex_error::vortex_err;

use crate::matcher::AnyTensor;
use crate::scalar_fns::ApproxOptions;
use crate::utils::extension_element_ptype;
use crate::utils::extension_list_size;
use crate::utils::extract_flat_elements;
use crate::utils::tensor_element_ptype;
use crate::utils::tensor_list_size;

/// L2 norm (Euclidean norm) of a tensor or vector column.
///
Expand Down Expand Up @@ -107,7 +107,7 @@ impl ScalarFnVTable for L2Norm {
"L2Norm input must be an `AnyTensor`, got {input_dtype}"
);

let ptype = extension_element_ptype(ext)?;
let ptype = tensor_element_ptype(ext)?;
vortex_ensure!(
ptype.is_float(),
"L2Norm element dtype must be a float primitive, got {ptype}"
Expand All @@ -130,7 +130,7 @@ impl ScalarFnVTable for L2Norm {

// Get list size (dimensions) from the dtype (validated by `return_dtype`).
let ext = input.dtype().as_extension();
let list_size = extension_list_size(ext)? as usize;
let list_size = tensor_list_size(ext)? as usize;

let storage = input.data().storage_array();
let flat = extract_flat_elements(storage, list_size, ctx)?;
Expand Down
4 changes: 2 additions & 2 deletions vortex-tensor/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use vortex_error::vortex_err;
/// Extracts the list size from a tensor-like extension dtype.
///
/// The storage dtype must be a `FixedSizeList`.
pub fn extension_list_size(ext: &ExtDTypeRef) -> VortexResult<u32> {
pub fn tensor_list_size(ext: &ExtDTypeRef) -> VortexResult<u32> {
let DType::FixedSizeList(_, list_size, _) = ext.storage_dtype() else {
vortex_bail!(
"expected FixedSizeList storage dtype, got {}",
Expand All @@ -34,7 +34,7 @@ pub fn extension_list_size(ext: &ExtDTypeRef) -> VortexResult<u32> {
/// Extracts the float element [`PType`] from a tensor-like extension dtype.
///
/// The storage dtype must be a `FixedSizeList` of non-nullable primitives.
pub fn extension_element_ptype(ext: &ExtDTypeRef) -> VortexResult<PType> {
pub fn tensor_element_ptype(ext: &ExtDTypeRef) -> VortexResult<PType> {
let element_dtype = ext
.storage_dtype()
.as_fixed_size_list_element_opt()
Expand Down
Loading