diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index dc33066bd3b..56e96488167 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -6,5 +6,7 @@ //! similarity. pub mod fixed_shape; +pub mod vector; +pub mod matcher; pub mod scalar_fns; diff --git a/vortex-tensor/src/matcher.rs b/vortex-tensor/src/matcher.rs new file mode 100644 index 00000000000..bb79ad7447e --- /dev/null +++ b/vortex-tensor/src/matcher.rs @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Matcher for tensor-like extension types. + +use vortex::dtype::extension::ExtDTypeRef; +use vortex::dtype::extension::Matcher; + +use crate::fixed_shape::FixedShapeTensor; +use crate::fixed_shape::FixedShapeTensorMetadata; +use crate::vector::Vector; + +/// Matcher for any tensor-like extension type. +/// +/// Currently the different kinds of tensors that are available are: +/// +/// - `FixedShapeTensor` +/// - `Vector` +pub struct AnyTensor; + +/// The matched variant of a tensor-like extension type. +#[derive(Debug, PartialEq, Eq)] +pub enum TensorMatch<'a> { + /// A [`FixedShapeTensor`] extension type. + FixedShapeTensor(&'a FixedShapeTensorMetadata), + /// A [`Vector`] extension type. + Vector, +} + +impl Matcher for AnyTensor { + type Match<'a> = TensorMatch<'a>; + + fn try_match<'a>(item: &'a ExtDTypeRef) -> Option<Self::Match<'a>> { + if let Some(metadata) = item.metadata_opt::<FixedShapeTensor>() { + return Some(TensorMatch::FixedShapeTensor(metadata)); + } + if item.metadata_opt::<Vector>().is_some() { + return Some(TensorMatch::Vector); + } + None + } +} diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index f64ffc4cd11..d9c952b29bb 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -1,26 +1,23 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! 
Cosine similarity expression for [`FixedShapeTensor`](crate::fixed_shape::FixedShapeTensor) -//! arrays. +//! Cosine similarity expression for tensor-like extension arrays +//! ([`FixedShapeTensor`](crate::fixed_shape::FixedShapeTensor) and +//! [`Vector`](crate::vector::Vector)). use std::fmt::Formatter; use num_traits::Float; use vortex::array::ArrayRef; -use vortex::array::DynArray; use vortex::array::ExecutionCtx; use vortex::array::IntoArray; -use vortex::array::arrays::Constant; -use vortex::array::arrays::ConstantArray; -use vortex::array::arrays::Extension; use vortex::array::arrays::PrimitiveArray; use vortex::array::match_each_float_ptype; use vortex::dtype::DType; use vortex::dtype::NativePType; use vortex::dtype::Nullability; +use vortex::dtype::extension::Matcher; use vortex::error::VortexResult; -use vortex::error::vortex_bail; use vortex::error::vortex_ensure; use vortex::error::vortex_err; use vortex::expr::Expression; @@ -31,18 +28,23 @@ use vortex::scalar_fn::ExecutionArgs; use vortex::scalar_fn::ScalarFnId; use vortex::scalar_fn::ScalarFnVTable; -// TODO(connor): We will want to add implementations for unit normalized vectors and also vectors -// encoded in spherical coordinates. +use crate::matcher::AnyTensor; +use crate::scalar_fns::utils::extension_element_ptype; +use crate::scalar_fns::utils::extension_list_size; +use crate::scalar_fns::utils::extension_storage; +use crate::scalar_fns::utils::extract_flat_elements; + /// Cosine similarity between two columns. /// -/// For [`FixedShapeTensor`], computes `dot(a, b) / (||a|| * ||b||)` over the flat backing buffer of -/// each tensor. The shape and permutation do not affect the result because cosine similarity only -/// depends on the element values, not their logical arrangement. +/// Computes `dot(a, b) / (||a|| * ||b||)` over the flat backing buffer of each tensor or vector. 
+/// The shape and permutation do not affect the result because cosine similarity only depends on the +/// element values, not their logical arrangement. /// -/// Right now, both inputs must be [`FixedShapeTensor`] extension arrays with the same dtype and a -/// float element type. The output is a float column of the same float type. +/// Both inputs must be tensor-like extension arrays ([`FixedShapeTensor`] or [`Vector`]) with the +/// same dtype and a float element type. The output is a float column of the same float type. /// /// [`FixedShapeTensor`]: crate::fixed_shape::FixedShapeTensor +/// [`Vector`]: crate::vector::Vector #[derive(Clone)] pub struct CosineSimilarity; @@ -92,33 +94,22 @@ impl ScalarFnVTable for CosineSimilarity { // We don't need to look at rhs anymore since we know lhs and rhs are equal. - // Both inputs must be extension types. + // Both inputs must be tensor-like extension types. let lhs_ext = lhs.as_extension_opt().ok_or_else(|| { vortex_err!("cosine_similarity lhs must be an extension type, got {lhs}") })?; - // Extract the element dtype from the storage FixedSizeList. - let element_dtype = lhs_ext - .storage_dtype() - .as_fixed_size_list_element_opt() - .ok_or_else(|| { - vortex_err!( - "cosine_similarity storage dtype must be a FixedSizeList, got {}", - lhs_ext.storage_dtype() - ) - })?; - - // Element dtype must be a non-nullable float primitive. 
vortex_ensure!( - element_dtype.is_float(), - "cosine_similarity element dtype must be a float primitive, got {element_dtype}" + AnyTensor::matches(lhs_ext), + "cosine_similarity inputs must be an `AnyTensor`, got {lhs}" ); + + let ptype = extension_element_ptype(lhs_ext)?; vortex_ensure!( - !element_dtype.is_nullable(), - "cosine_similarity element dtype must be non-nullable" + ptype.is_float(), + "cosine_similarity element dtype must be a float primitive, got {ptype}" ); - let ptype = element_dtype.as_ptype(); let nullability = Nullability::from(lhs.is_nullable() || rhs.is_nullable()); Ok(DType::Primitive(ptype, nullability)) } @@ -140,32 +131,19 @@ impl ScalarFnVTable for CosineSimilarity { lhs.dtype() ) })?; - let DType::FixedSizeList(_, list_size, _) = ext.storage_dtype() else { - vortex_bail!("expected FixedSizeList storage dtype"); - }; - let list_size = *list_size as usize; + let list_size = extension_list_size(ext)?; // Extract the storage array from each extension input. We pass the storage (FSL) rather // than the extension array to avoid canonicalizing the extension wrapper. let lhs_storage = extension_storage(&lhs)?; let rhs_storage = extension_storage(&rhs)?; - // Extract the flat primitive elements from each tensor column. When an input is a - // `ConstantArray` (e.g., a literal query vector), we materialize only a single row - // instead of expanding it to the full row count. 
- let (lhs_elems, lhs_stride) = extract_flat_elements(&lhs_storage, list_size)?; - let (rhs_elems, rhs_stride) = extract_flat_elements(&rhs_storage, list_size)?; - - match_each_float_ptype!(lhs_elems.ptype(), |T| { - let lhs_slice = lhs_elems.as_slice::<T>(); - let rhs_slice = rhs_elems.as_slice::<T>(); + let lhs_flat = extract_flat_elements(&lhs_storage, list_size)?; + let rhs_flat = extract_flat_elements(&rhs_storage, list_size)?; + match_each_float_ptype!(lhs_flat.ptype(), |T| { let result: PrimitiveArray = (0..row_count) - .map(|i| { - let a = &lhs_slice[i * lhs_stride..i * lhs_stride + list_size]; - let b = &rhs_slice[i * rhs_stride..i * rhs_stride + list_size]; - cosine_similarity_row(a, b) - }) + .map(|i| cosine_similarity_row(lhs_flat.row::<T>(i), rhs_flat.row::<T>(i))) .collect(); Ok(result.into_array()) @@ -194,38 +172,6 @@ impl ScalarFnVTable for CosineSimilarity { } } -/// Extracts the storage array from an extension array without canonicalizing. -fn extension_storage(array: &ArrayRef) -> VortexResult<ArrayRef> { - let ext = array - .as_opt::<Extension>() - .ok_or_else(|| vortex_err!("cosine_similarity input must be an extension array"))?; - Ok(ext.storage_array().clone()) -} - -/// Extracts the flat primitive elements from a tensor storage array (FixedSizeList). -/// -/// When the input is a [`ConstantArray`] (e.g., a literal query vector), only a single row is -/// materialized to avoid expanding it to the full column length. Returns `(elements, stride)` -/// where `stride` is `list_size` for a full array and `0` for a constant. -fn extract_flat_elements( - storage: &ArrayRef, - list_size: usize, -) -> VortexResult<(PrimitiveArray, usize)> { - if let Some(constant) = storage.as_opt::<Constant>() { - // Rewrite the array as a length 1 array so when we canonicalize, we do not duplicate a - // huge amount of data. 
- let single = ConstantArray::new(constant.scalar().clone(), 1).into_array(); - let fsl = single.to_canonical()?.into_fixed_size_list(); - let elems = fsl.elements().to_canonical()?.into_primitive(); - Ok((elems, 0)) - } else { - // Otherwise we have to fully expand all of the data. - let fsl = storage.to_canonical()?.into_fixed_size_list(); - let elems = fsl.elements().to_canonical()?.into_primitive(); - Ok((elems, list_size)) - } -} - // TODO(connor): We should try to use a more performant library instead of doing this ourselves. /// Computes cosine similarity between two equal-length float slices. /// @@ -258,6 +204,7 @@ mod tests { use vortex::dtype::Nullability; use vortex::dtype::extension::ExtDType; use vortex::error::VortexResult; + use vortex::extension::EmptyMetadata; use vortex::scalar::Scalar; use vortex::scalar_fn::EmptyOptions; use vortex::scalar_fn::ScalarFn; @@ -265,6 +212,7 @@ mod tests { use crate::fixed_shape::FixedShapeTensor; use crate::fixed_shape::FixedShapeTensorMetadata; use crate::scalar_fns::cosine_similarity::CosineSimilarity; + use crate::vector::Vector; /// Builds a [`FixedShapeTensor`] extension array from flat f64 elements and a logical shape. /// @@ -460,4 +408,95 @@ mod tests { ); Ok(()) } + + /// Builds a [`Vector`] extension array from flat f64 elements and a vector dimension size. 
+ fn vector_array(dim: u32, elements: &[f64]) -> VortexResult<ArrayRef> { + let row_count = elements.len() / dim as usize; + + let elems: ArrayRef = Buffer::copy_from(elements).into_array(); + let fsl = FixedSizeListArray::new(elems, dim, Validity::NonNullable, row_count); + + let ext_dtype = ExtDType::<Vector>::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); + + Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + } + + #[test] + fn vector_unit_vectors() -> VortexResult<()> { + let lhs = vector_array( + 3, + &[ + 1.0, 0.0, 0.0, // vector 0 + 0.0, 1.0, 0.0, // vector 1 + ], + )?; + let rhs = vector_array( + 3, + &[ + 1.0, 0.0, 0.0, // vector 0 + 1.0, 0.0, 0.0, // vector 1 + ], + )?; + + // Row 0: identical -> 1.0, row 1: orthogonal -> 0.0. + assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]); + Ok(()) + } + + #[test] + fn vector_self_similarity() -> VortexResult<()> { + let arr = vector_array( + 4, + &[ + 1.0, 2.0, 3.0, 4.0, // vector 0 + 0.0, 1.0, 0.0, 0.0, // vector 1 + 5.0, 0.0, 5.0, 0.0, // vector 2 + ], + )?; + + assert_close( + &eval_cosine_similarity(arr.clone(), arr, 3)?, + &[1.0, 1.0, 1.0], + ); + Ok(()) + } + + /// Builds a [`Vector`] extension array whose storage is a [`ConstantArray`]. 
+ fn constant_vector_array(elements: &[f64], len: usize) -> VortexResult<ArrayRef> { + let element_dtype = DType::Primitive(vortex::dtype::PType::F64, Nullability::NonNullable); + + let children: Vec<Scalar> = elements + .iter() + .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) + .collect(); + let storage_scalar = + Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); + + let storage = ConstantArray::new(storage_scalar, len).into_array(); + + let ext_dtype = + ExtDType::<Vector>::try_new(EmptyMetadata, storage.dtype().clone())?.erased(); + + Ok(ExtensionArray::new(ext_dtype, storage).into_array()) + } + + #[test] + fn vector_constant_query() -> VortexResult<()> { + let data = vector_array( + 3, + &[ + 1.0, 0.0, 0.0, // vector 0 + 0.0, 1.0, 0.0, // vector 1 + 0.0, 0.0, 1.0, // vector 2 + 1.0, 0.0, 0.0, // vector 3 + ], + )?; + let query = constant_vector_array(&[1.0, 0.0, 0.0], 4)?; + + assert_close( + &eval_cosine_similarity(data, query, 4)?, + &[1.0, 0.0, 0.0, 1.0], + ); + Ok(()) + } +} diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs new file mode 100644 index 00000000000..bada964b7ef --- /dev/null +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -0,0 +1,298 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! L2 norm expression for tensor-like extension arrays +//! ([`FixedShapeTensor`](crate::fixed_shape::FixedShapeTensor) and +//! [`Vector`](crate::vector::Vector)). 
+ +use std::fmt::Formatter; + +use num_traits::Float; +use vortex::array::ArrayRef; +use vortex::array::ExecutionCtx; +use vortex::array::IntoArray; +use vortex::array::arrays::PrimitiveArray; +use vortex::array::match_each_float_ptype; +use vortex::dtype::DType; +use vortex::dtype::NativePType; +use vortex::dtype::Nullability; +use vortex::dtype::extension::Matcher; +use vortex::error::VortexResult; +use vortex::error::vortex_ensure; +use vortex::error::vortex_err; +use vortex::expr::Expression; +use vortex::scalar_fn::Arity; +use vortex::scalar_fn::ChildName; +use vortex::scalar_fn::EmptyOptions; +use vortex::scalar_fn::ExecutionArgs; +use vortex::scalar_fn::ScalarFnId; +use vortex::scalar_fn::ScalarFnVTable; + +use crate::matcher::AnyTensor; +use crate::scalar_fns::utils::extension_element_ptype; +use crate::scalar_fns::utils::extension_list_size; +use crate::scalar_fns::utils::extension_storage; +use crate::scalar_fns::utils::extract_flat_elements; + +/// L2 norm (Euclidean norm) of a tensor or vector column. +/// +/// Computes `||v|| = sqrt(sum(v_i^2))` over the flat backing buffer of each tensor-like type. +/// +/// The input must be a tensor-like extension array with a float element type. The output is a float +/// column of the same float type. 
+#[derive(Clone)] +pub struct L2Norm; + +impl ScalarFnVTable for L2Norm { + type Options = EmptyOptions; + + fn id(&self) -> ScalarFnId { + ScalarFnId::new_ref("vortex.l2_norm") + } + + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(1) + } + + fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName { + match child_idx { + 0 => ChildName::from("input"), + _ => unreachable!("L2Norm must have exactly one child"), + } + } + + fn fmt_sql( + &self, + _options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { + write!(f, "l2_norm(")?; + expr.child(0).fmt_sql(f)?; + write!(f, ")") + } + + fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> { + debug_assert_eq!(arg_dtypes.len(), 1); + + let input_dtype = &arg_dtypes[0]; + + // Input must be a tensor-like extension type. + let ext = input_dtype.as_extension_opt().ok_or_else(|| { + vortex_err!("l2_norm input must be an extension type, got {input_dtype}") + })?; + + vortex_ensure!( + AnyTensor::matches(ext), + "l2_norm input must be an `AnyTensor`, got {input_dtype}" + ); + + let ptype = extension_element_ptype(ext)?; + vortex_ensure!( + ptype.is_float(), + "l2_norm element dtype must be a float primitive, got {ptype}" + ); + + let nullability = Nullability::from(input_dtype.is_nullable()); + Ok(DType::Primitive(ptype, nullability)) + } + + fn execute( + &self, + _options: &Self::Options, + args: &dyn ExecutionArgs, + _ctx: &mut ExecutionCtx, + ) -> VortexResult<ArrayRef> { + let input = args.get(0)?; + let row_count = args.row_count(); + + // Get list size (dimensions) from the dtype. 
+ let ext = input.dtype().as_extension_opt().ok_or_else(|| { + vortex_err!( + "l2_norm input must be an extension type, got {}", + input.dtype() + ) + })?; + let list_size = extension_list_size(ext)?; + + let storage = extension_storage(&input)?; + let flat = extract_flat_elements(&storage, list_size)?; + + match_each_float_ptype!(flat.ptype(), |T| { + let result: PrimitiveArray = (0..row_count) + .map(|i| l2_norm_row(flat.row::<T>(i))) + .collect(); + + Ok(result.into_array()) + }) + } + + fn validity( + &self, + _options: &Self::Options, + expression: &Expression, + ) -> VortexResult<Option<Expression>> { + // The result is null if the input tensor is null. + Ok(Some(expression.child(0).validity()?)) + } + + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { + false + } + + fn is_fallible(&self, _options: &Self::Options) -> bool { + // Canonicalization of the storage array can fail. + true + } +} + +/// Computes the L2 norm (Euclidean norm) of a float slice. +/// +/// Returns `sqrt(sum(v_i^2))`. A zero-length or all-zero input produces `0.0`. +fn l2_norm_row<T: Float>(v: &[T]) -> T { + let mut sum_sq = T::zero(); + for &x in v { + sum_sq = sum_sq + x * x; + } + sum_sq.sqrt() +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + use vortex::array::ArrayRef; + use vortex::array::IntoArray; + use vortex::array::ToCanonical; + use vortex::array::arrays::ExtensionArray; + use vortex::array::arrays::FixedSizeListArray; + use vortex::array::arrays::ScalarFnArray; + use vortex::array::validity::Validity; + use vortex::buffer::Buffer; + use vortex::dtype::extension::ExtDType; + use vortex::error::VortexResult; + use vortex::extension::EmptyMetadata; + use vortex::scalar_fn::EmptyOptions; + use vortex::scalar_fn::ScalarFn; + + use crate::fixed_shape::FixedShapeTensor; + use crate::fixed_shape::FixedShapeTensorMetadata; + use crate::scalar_fns::l2_norm::L2Norm; + use crate::vector::Vector; + + /// Builds a [`FixedShapeTensor`] extension array from flat f64 elements and a logical shape. 
+ fn tensor_array(shape: &[usize], elements: &[f64]) -> VortexResult<ArrayRef> { + let list_size: u32 = shape.iter().product::<usize>().max(1).try_into().unwrap(); + let row_count = elements.len() / list_size as usize; + + let elems: ArrayRef = Buffer::copy_from(elements).into_array(); + let fsl = FixedSizeListArray::new(elems, list_size, Validity::NonNullable, row_count); + + let metadata = FixedShapeTensorMetadata::new(shape.to_vec()); + let ext_dtype = + ExtDType::<FixedShapeTensor>::try_new(metadata, fsl.dtype().clone())?.erased(); + + Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + } + + /// Builds a [`Vector`] extension array from flat f64 elements and a vector dimension size. + fn vector_array(dim: u32, elements: &[f64]) -> VortexResult<ArrayRef> { + let row_count = elements.len() / dim as usize; + + let elems: ArrayRef = Buffer::copy_from(elements).into_array(); + let fsl = FixedSizeListArray::new(elems, dim, Validity::NonNullable, row_count); + + let ext_dtype = ExtDType::<Vector>::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); + + Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + } + + /// Evaluates L2 norm on a tensor/vector array and returns the result as `Vec<f64>`. 
+ fn eval_l2_norm(input: ArrayRef, len: usize) -> VortexResult<Vec<f64>> { + let scalar_fn = ScalarFn::new(L2Norm, EmptyOptions).erased(); + let result = ScalarFnArray::try_new(scalar_fn, vec![input], len)?; + let prim = result.to_primitive(); + Ok(prim.as_slice::<f64>().to_vec()) + } + + #[track_caller] + fn assert_close(actual: &[f64], expected: &[f64]) { + assert_eq!( + actual.len(), + expected.len(), + "length mismatch: got {} elements, expected {}", + actual.len(), + expected.len() + ); + + for (i, (a, e)) in actual.iter().zip(expected).enumerate() { + assert!( + (a - e).abs() < 1e-10, + "element {i}: got {a}, expected {e} (diff = {})", + (a - e).abs() + ); + } + } + + #[test] + fn unit_vector_norm() -> VortexResult<()> { + let arr = tensor_array( + &[3], + &[ + 1.0, 0.0, 0.0, // unit x + 0.0, 1.0, 0.0, // unit y + 0.0, 0.0, 1.0, // unit z + ], + )?; + assert_close(&eval_l2_norm(arr, 3)?, &[1.0, 1.0, 1.0]); + Ok(()) + } + + #[rstest] + #[case::three_four_five(&[2], &[3.0, 4.0], &[5.0])] + #[case::zero_vector(&[3], &[0.0, 0.0, 0.0], &[0.0])] + #[case::single_element(&[1], &[7.0], &[7.0])] + #[case::negative_elements(&[2], &[-3.0, -4.0], &[5.0])] + fn known_norms( + #[case] shape: &[usize], + #[case] elements: &[f64], + #[case] expected: &[f64], + ) -> VortexResult<()> { + let arr = tensor_array(shape, elements)?; + assert_close(&eval_l2_norm(arr, 1)?, expected); + Ok(()) + } + + #[test] + fn multiple_rows() -> VortexResult<()> { + let arr = tensor_array( + &[3], + &[ + 3.0, 4.0, 0.0, // norm = 5.0 + 0.0, 0.0, 0.0, // norm = 0.0 + 1.0, 1.0, 1.0, // norm = sqrt(3) + ], + )?; + assert_close(&eval_l2_norm(arr, 3)?, &[5.0, 0.0, 3.0_f64.sqrt()]); + Ok(()) + } + + #[test] + fn vector_known_norm() -> VortexResult<()> { + let arr = vector_array(2, &[3.0, 4.0])?; + assert_close(&eval_l2_norm(arr, 1)?, &[5.0]); + Ok(()) + } + + #[test] + fn vector_multiple_rows() -> VortexResult<()> { + let arr = vector_array( + 3, + &[ + 1.0, 0.0, 0.0, // norm = 1.0 + 3.0, 4.0, 0.0, // norm = 5.0 + ], 
+    )?; + assert_close(&eval_l2_norm(arr, 2)?, &[1.0, 5.0]); + Ok(()) + } +} diff --git a/vortex-tensor/src/scalar_fns/mod.rs b/vortex-tensor/src/scalar_fns/mod.rs index 2797589e03f..2597f1115c8 100644 --- a/vortex-tensor/src/scalar_fns/mod.rs +++ b/vortex-tensor/src/scalar_fns/mod.rs @@ -1,4 +1,9 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +//! Scalar function expressions defined on tensor and tensor-like extension types. + pub mod cosine_similarity; +pub mod l2_norm; + +mod utils; diff --git a/vortex-tensor/src/scalar_fns/utils.rs b/vortex-tensor/src/scalar_fns/utils.rs new file mode 100644 index 00000000000..ca7ddb47b02 --- /dev/null +++ b/vortex-tensor/src/scalar_fns/utils.rs @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex::array::ArrayRef; +use vortex::array::IntoArray; +use vortex::array::arrays::Constant; +use vortex::array::arrays::ConstantArray; +use vortex::array::arrays::Extension; +use vortex::array::arrays::PrimitiveArray; +use vortex::dtype::DType; +use vortex::dtype::NativePType; +use vortex::dtype::PType; +use vortex::dtype::extension::ExtDTypeRef; +use vortex::error::VortexResult; +use vortex::error::vortex_bail; +use vortex::error::vortex_ensure; +use vortex::error::vortex_err; + +/// Extracts the list size from a tensor-like extension dtype. +/// +/// The storage dtype must be a `FixedSizeList`. +pub(crate) fn extension_list_size(ext: &ExtDTypeRef) -> VortexResult<usize> { + let DType::FixedSizeList(_, list_size, _) = ext.storage_dtype() else { + vortex_bail!( + "expected FixedSizeList storage dtype, got {}", + ext.storage_dtype() + ); + }; + + Ok(*list_size as usize) +} + +/// Extracts the float element [`PType`] from a tensor-like extension dtype. +/// +/// The storage dtype must be a `FixedSizeList` of non-nullable primitives. 
+pub(crate) fn extension_element_ptype(ext: &ExtDTypeRef) -> VortexResult<PType> { + let element_dtype = ext + .storage_dtype() + .as_fixed_size_list_element_opt() + .ok_or_else(|| { + vortex_err!( + "expected FixedSizeList storage dtype, got {}", + ext.storage_dtype() + ) + })?; + + vortex_ensure!( + !element_dtype.is_nullable(), + "element dtype must be non-nullable" + ); + + Ok(element_dtype.as_ptype()) +} + +/// Extracts the storage array from an extension array without canonicalizing. +pub(crate) fn extension_storage(array: &ArrayRef) -> VortexResult<ArrayRef> { + let ext = array + .as_opt::<Extension>() + .ok_or_else(|| vortex_err!("scalar_fn input must be an extension array"))?; + + Ok(ext.storage_array().clone()) +} + +/// The flat primitive elements of a tensor storage array, with typed row access. +/// +/// This struct hides the stride detail that arises from the [`ConstantArray`] optimization: a +/// constant input materializes only a single row (stride=0), while a full array uses +/// stride=list_size. +pub(crate) struct FlatElements { + elems: PrimitiveArray, + stride: usize, + list_size: usize, +} + +impl FlatElements { + /// Returns the [`PType`] of the underlying elements. + pub fn ptype(&self) -> PType { + self.elems.ptype() + } + + /// Returns the `i`-th row as a typed slice of length `list_size`. + pub fn row<T: NativePType>(&self, i: usize) -> &[T] { + let slice = self.elems.as_slice::<T>(); + &slice[i * self.stride..i * self.stride + self.list_size] + } +} + +/// Extracts the flat primitive elements from a tensor storage array (FixedSizeList). +/// +/// When the input is a [`ConstantArray`] (e.g., a literal query vector), only a single row is +/// materialized to avoid expanding it to the full column length. +pub(crate) fn extract_flat_elements( + storage: &ArrayRef, + list_size: usize, +) -> VortexResult<FlatElements> { + if let Some(constant) = storage.as_opt::<Constant>() { + // Rewrite the array as a length 1 array so when we canonicalize, we do not duplicate a huge + // amount of data. 
+ let single = ConstantArray::new(constant.scalar().clone(), 1).into_array(); + let fsl = single.to_canonical()?.into_fixed_size_list(); + let elems = fsl.elements().to_canonical()?.into_primitive(); + return Ok(FlatElements { + elems, + stride: 0, + list_size, + }); + } + + // Otherwise we have to fully expand all of the data. + let fsl = storage.to_canonical()?.into_fixed_size_list(); + let elems = fsl.elements().to_canonical()?.into_primitive(); + Ok(FlatElements { + elems, + stride: list_size, + list_size, + }) +} diff --git a/vortex-tensor/src/vector/mod.rs b/vortex-tensor/src/vector/mod.rs new file mode 100644 index 00000000000..181e08c4e84 --- /dev/null +++ b/vortex-tensor/src/vector/mod.rs @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Vector extension type for fixed-length float vectors (e.g., embeddings). + +/// The VTable for the vector extension type. +#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] +pub struct Vector; + +mod vtable; diff --git a/vortex-tensor/src/vector/vtable.rs b/vortex-tensor/src/vector/vtable.rs new file mode 100644 index 00000000000..a3206c5150e --- /dev/null +++ b/vortex-tensor/src/vector/vtable.rs @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex::dtype::DType; +use vortex::dtype::extension::ExtDType; +use vortex::dtype::extension::ExtId; +use vortex::dtype::extension::ExtVTable; +use vortex::error::VortexResult; +use vortex::error::vortex_bail; +use vortex::error::vortex_ensure; +use vortex::extension::EmptyMetadata; +use vortex::scalar::ScalarValue; + +use crate::vector::Vector; + +impl ExtVTable for Vector { + type Metadata = EmptyMetadata; + + // TODO(connor): This is just a placeholder for now. 
+ type NativeValue<'a> = &'a ScalarValue; + + fn id(&self) -> ExtId { + ExtId::new_ref("vortex.vector") + } + + fn serialize_metadata(&self, _metadata: &Self::Metadata) -> VortexResult<Vec<u8>> { + Ok(Vec::new()) + } + + fn deserialize_metadata(&self, _metadata: &[u8]) -> VortexResult<Self::Metadata> { + Ok(EmptyMetadata) + } + + fn validate_dtype(&self, ext_dtype: &ExtDType<Self>) -> VortexResult<()> { + let storage_dtype = ext_dtype.storage_dtype(); + let DType::FixedSizeList(element_dtype, _list_size, _nullability) = storage_dtype else { + vortex_bail!("Vector storage dtype must be a FixedSizeList, got {storage_dtype}"); + }; + + vortex_ensure!( + element_dtype.is_float(), + "Vector element dtype must be a float, got {element_dtype}" + ); + vortex_ensure!( + !element_dtype.is_nullable(), + "Vector element dtype must be non-nullable" + ); + + Ok(()) + } + + fn unpack_native<'a>( + &self, + _ext_dtype: &'a ExtDType<Self>, + storage_value: &'a ScalarValue, + ) -> VortexResult<Self::NativeValue<'a>> { + Ok(storage_value) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use rstest::rstest; + use vortex::dtype::DType; + use vortex::dtype::Nullability; + use vortex::dtype::PType; + use vortex::dtype::extension::ExtDType; + use vortex::dtype::extension::ExtVTable; + use vortex::error::VortexResult; + use vortex::extension::EmptyMetadata; + + use crate::vector::Vector; + + /// Constructs a `FixedSizeList` storage dtype with the given float [`PType`], list size, and + /// [`Nullability`]. 
+ fn vector_storage_dtype(ptype: PType, size: u32, nullability: Nullability) -> DType { + DType::FixedSizeList( + Arc::new(DType::Primitive(ptype, Nullability::NonNullable)), + size, + nullability, + ) + } + + #[rstest] + #[case::f16(PType::F16)] + #[case::f32(PType::F32)] + #[case::f64(PType::F64)] + fn validate_accepts_float_types(#[case] ptype: PType) -> VortexResult<()> { + let storage = vector_storage_dtype(ptype, 128, Nullability::NonNullable); + ExtDType::<Vector>::try_new(EmptyMetadata, storage)?; + Ok(()) + } + + #[rstest] + #[case::nullable(Nullability::Nullable)] + #[case::non_nullable(Nullability::NonNullable)] + fn validate_accepts_any_outer_nullability( + #[case] nullability: Nullability, + ) -> VortexResult<()> { + let storage = vector_storage_dtype(PType::F32, 128, nullability); + ExtDType::<Vector>::try_new(EmptyMetadata, storage)?; + Ok(()) + } + + #[test] + fn validate_rejects_non_fsl() { + let storage = DType::Primitive(PType::F32, Nullability::NonNullable); + assert!(ExtDType::<Vector>::try_new(EmptyMetadata, storage).is_err()); + } + + #[test] + fn validate_rejects_integer_elements() { + let storage = DType::FixedSizeList( + Arc::new(DType::Primitive(PType::U32, Nullability::NonNullable)), + 128, + Nullability::NonNullable, + ); + assert!(ExtDType::<Vector>::try_new(EmptyMetadata, storage).is_err()); + } + + #[test] + fn validate_rejects_nullable_elements() { + let storage = DType::FixedSizeList( + Arc::new(DType::Primitive(PType::F32, Nullability::Nullable)), + 128, + Nullability::NonNullable, + ); + assert!(ExtDType::<Vector>::try_new(EmptyMetadata, storage).is_err()); + } + + #[test] + fn roundtrip_metadata() -> VortexResult<()> { + let vtable = Vector; + let bytes = vtable.serialize_metadata(&EmptyMetadata)?; + assert!(bytes.is_empty()); + let deserialized = vtable.deserialize_metadata(&bytes)?; + assert_eq!(deserialized, EmptyMetadata); + Ok(()) + } +}