diff --git a/clippy.toml b/clippy.toml index 592ede3e6ad..4b6fc10fe76 100644 --- a/clippy.toml +++ b/clippy.toml @@ -12,4 +12,5 @@ disallowed-types = [ disallowed-methods = [ { path = "itertools::Itertools::counts", reason = "It uses the default hasher which is slow for primitives. Just inline the loop for better performance.", allow-invalid = true }, + { path = "std::result::Result::and", reason = "This method is a footgun, especially when working with `Result`.", allow-invalid = true }, ] diff --git a/vortex-tensor/Cargo.toml b/vortex-tensor/Cargo.toml index c80127bc7b3..f0b6670cc51 100644 --- a/vortex-tensor/Cargo.toml +++ b/vortex-tensor/Cargo.toml @@ -18,6 +18,7 @@ workspace = true [dependencies] vortex-array = { workspace = true } +vortex-buffer = { workspace = true } vortex-error = { workspace = true } vortex-session = { workspace = true } diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index 4a3b0d79fa4..cea02b69e38 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -126,6 +126,12 @@ pub mod vortex_tensor::scalar_fns::cosine_similarity pub struct vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity +impl vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity + +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::new(options: &vortex_tensor::scalar_fns::ApproxOptions) -> vortex_array::scalar_fn::typed::ScalarFn + +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::try_new_array(options: &vortex_tensor::scalar_fns::ApproxOptions, lhs: vortex_array::array::erased::ArrayRef, rhs: vortex_array::array::erased::ArrayRef, len: usize) -> vortex_error::VortexResult + impl core::clone::Clone for vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::clone(&self) -> vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity @@ -138,7 +144,7 @@ pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::arity(&se pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName -pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result @@ -152,10 +158,52 @@ pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::return_dt pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::validity(&self, _options: &Self::Options, expression: &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult> +pub mod vortex_tensor::scalar_fns::inner_product + +pub struct vortex_tensor::scalar_fns::inner_product::InnerProduct + +impl vortex_tensor::scalar_fns::inner_product::InnerProduct + +pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::new(options: &vortex_tensor::scalar_fns::ApproxOptions) -> vortex_array::scalar_fn::typed::ScalarFn + +pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::try_new_array(options: &vortex_tensor::scalar_fns::ApproxOptions, lhs: vortex_array::array::erased::ArrayRef, rhs: vortex_array::array::erased::ArrayRef, len: usize) -> vortex_error::VortexResult + +impl core::clone::Clone for vortex_tensor::scalar_fns::inner_product::InnerProduct + +pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::clone(&self) -> vortex_tensor::scalar_fns::inner_product::InnerProduct + +impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_tensor::scalar_fns::inner_product::InnerProduct + +pub type vortex_tensor::scalar_fns::inner_product::InnerProduct::Options = vortex_tensor::scalar_fns::ApproxOptions + +pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::arity(&self, _options: &Self::Options) -> vortex_array::scalar_fn::vtable::Arity + +pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName + +pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::id(&self) -> vortex_array::scalar_fn::ScalarFnId + +pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::is_fallible(&self, _options: &Self::Options) -> bool + +pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::is_null_sensitive(&self, _options: &Self::Options) -> bool + +pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::return_dtype(&self, _options: &Self::Options, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult + +pub fn vortex_tensor::scalar_fns::inner_product::InnerProduct::validity(&self, _options: &Self::Options, expression: &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult> + pub mod vortex_tensor::scalar_fns::l2_norm pub struct vortex_tensor::scalar_fns::l2_norm::L2Norm +impl vortex_tensor::scalar_fns::l2_norm::L2Norm + +pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::new(options: &vortex_tensor::scalar_fns::ApproxOptions) -> vortex_array::scalar_fn::typed::ScalarFn + +pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::try_new_array(options: &vortex_tensor::scalar_fns::ApproxOptions, child: vortex_array::array::erased::ArrayRef, len: usize) -> vortex_error::VortexResult + impl core::clone::Clone for vortex_tensor::scalar_fns::l2_norm::L2Norm pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::clone(&self) -> vortex_tensor::scalar_fns::l2_norm::L2Norm diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index 515c2f373ca..6b55389d8c9 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -11,6 +11,7 @@ use vortex_session::VortexSession; use crate::fixed_shape::FixedShapeTensor; use crate::scalar_fns::cosine_similarity::CosineSimilarity; +use crate::scalar_fns::inner_product::InnerProduct; use crate::scalar_fns::l2_norm::L2Norm; use crate::vector::Vector; @@ -29,5 +30,6 @@ pub fn initialize(session: &VortexSession) { session.dtypes().register(Vector); session.dtypes().register(FixedShapeTensor); session.scalar_fns().register(CosineSimilarity); + session.scalar_fns().register(InnerProduct); session.scalar_fns().register(L2Norm); } diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index 539d4ca4a7c..22c51189380 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -1,19 +1,18 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! Cosine similarity expression for tensor-like extension arrays -//! ([`FixedShapeTensor`](crate::fixed_shape::FixedShapeTensor) and -//! [`Vector`](crate::vector::Vector)). +//! Cosine similarity expression for tensor-like types. use std::fmt::Formatter; -use num_traits::Float; +use num_traits::Zero; use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; +use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::ScalarFnArray; use vortex_array::dtype::DType; -use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; use vortex_array::expr::Expression; use vortex_array::expr::and; @@ -21,18 +20,19 @@ use vortex_array::match_each_float_ptype; use vortex_array::scalar_fn::Arity; use vortex_array::scalar_fn::ChildName; use vortex_array::scalar_fn::ExecutionArgs; +use vortex_array::scalar_fn::ScalarFn; use vortex_array::scalar_fn::ScalarFnId; use vortex_array::scalar_fn::ScalarFnVTable; +use vortex_buffer::Buffer; use vortex_error::VortexResult; use vortex_error::vortex_ensure; use vortex_error::vortex_err; use crate::matcher::AnyTensor; use crate::scalar_fns::ApproxOptions; +use crate::scalar_fns::inner_product::InnerProduct; +use crate::scalar_fns::l2_norm::L2Norm; use crate::utils::extension_element_ptype; -use crate::utils::extension_list_size; -use crate::utils::extension_storage; -use crate::utils::extract_flat_elements; /// Cosine similarity between two columns. /// @@ -48,6 +48,30 @@ use crate::utils::extract_flat_elements; #[derive(Clone)] pub struct CosineSimilarity; +impl CosineSimilarity { + /// Creates a new [`ScalarFn`] wrapping the cosine similarity operation with the given + /// [`ApproxOptions`] controlling approximation behavior. + pub fn new(options: &ApproxOptions) -> ScalarFn { + ScalarFn::new(CosineSimilarity, options.clone()) + } + + /// Constructs a [`ScalarFnArray`] that lazily computes the cosine similarity between `lhs` and + /// `rhs`. + /// + /// # Errors + /// + /// Returns an error if the [`ScalarFnArray`] cannot be constructed (e.g. due to dtype + /// mismatches). + pub fn try_new_array( + options: &ApproxOptions, + lhs: ArrayRef, + rhs: ArrayRef, + len: usize, + ) -> VortexResult { + ScalarFnArray::try_new(CosineSimilarity::new(options).erased(), vec![lhs, rhs], len) + } +} + impl ScalarFnVTable for CosineSimilarity { type Options = ApproxOptions; @@ -114,37 +138,49 @@ impl ScalarFnVTable for CosineSimilarity { fn execute( &self, - _options: &Self::Options, + options: &Self::Options, args: &dyn ExecutionArgs, ctx: &mut ExecutionCtx, ) -> VortexResult { - let lhs = args.get(0)?; - let rhs = args.get(1)?; - let row_count = args.row_count(); - - // Get list size from the dtype. Both sides should have the same dtype. - let ext = lhs.dtype().as_extension_opt().ok_or_else(|| { - vortex_err!( - "cosine_similarity input must be an extension type, got {}", - lhs.dtype() - ) - })?; - let list_size = extension_list_size(ext)? as usize; - - // Extract the storage array from each extension input. We pass the storage (FSL) rather - // than the extension array to avoid canonicalizing the extension wrapper. - let lhs_storage = extension_storage(&lhs)?; - let rhs_storage = extension_storage(&rhs)?; - - let lhs_flat = extract_flat_elements(&lhs_storage, list_size, ctx)?; - let rhs_flat = extract_flat_elements(&rhs_storage, list_size, ctx)?; - - match_each_float_ptype!(lhs_flat.ptype(), |T| { - let result: PrimitiveArray = (0..row_count) - .map(|i| cosine_similarity_row(lhs_flat.row::(i), rhs_flat.row::(i))) + let lhs = args.get(0)?.execute::(ctx)?.into_array(); + let rhs = args.get(1)?.execute::(ctx)?.into_array(); + + let len = args.row_count(); + + // Compute combined validity. + let validity = lhs.validity()?.and(rhs.validity()?)?; + + // Compute inner product and norms as columnar operations, and propagate the options. + let norm_lhs_arr = L2Norm::try_new_array(options, lhs.clone(), len)?; + let norm_rhs_arr = L2Norm::try_new_array(options, rhs.clone(), len)?; + let dot_arr = InnerProduct::try_new_array(options, lhs, rhs, len)?; + + // Execute to get PrimitiveArrays. + let dot: PrimitiveArray = dot_arr.into_array().execute(ctx)?; + let norm_l: PrimitiveArray = norm_lhs_arr.into_array().execute(ctx)?; + let norm_r: PrimitiveArray = norm_rhs_arr.into_array().execute(ctx)?; + + // Divide element-wise, guarding against zero norms. + match_each_float_ptype!(dot.ptype(), |T| { + let dots = dot.as_slice::(); + let norms_l = norm_l.as_slice::(); + let norms_r = norm_r.as_slice::(); + let buffer: Buffer = (0..len) + .map(|i| { + // TODO(connor): Would it be better to make this a binary multiply? + // What happens when this overflows??? + let denom = norms_l[i] * norms_r[i]; + + if denom == T::zero() { + T::zero() + } else { + dots[i] / denom + } + }) .collect(); - Ok(result.into_array()) + // SAFETY: The buffer length equals `len`, which matches the source validity length. + Ok(unsafe { PrimitiveArray::new_unchecked(buffer, validity) }.into_array()) }) } @@ -169,30 +205,16 @@ impl ScalarFnVTable for CosineSimilarity { } } -// TODO(connor): We should try to use a more performant library instead of doing this ourselves. -/// Computes cosine similarity between two equal-length float slices. -/// -/// Returns `dot(a, b) / (||a|| * ||b||)`. When either vector has zero norm, this naturally -/// produces `NaN` via `0.0 / 0.0`, matching standard floating-point semantics. -fn cosine_similarity_row(a: &[T], b: &[T]) -> T { - let mut dot = T::zero(); - let mut norm_a = T::zero(); - let mut norm_b = T::zero(); - for i in 0..a.len() { - dot = dot + a[i] * b[i]; - norm_a = norm_a + a[i] * a[i]; - norm_b = norm_b + b[i] * b[i]; - } - dot / (norm_a.sqrt() * norm_b.sqrt()) -} - #[cfg(test)] mod tests { use rstest::rstest; use vortex_array::ArrayRef; + use vortex_array::IntoArray; use vortex_array::ToCanonical; + use vortex_array::arrays::MaskedArray; use vortex_array::arrays::ScalarFnArray; use vortex_array::scalar_fn::ScalarFn; + use vortex_array::validity::Validity; use vortex_error::VortexResult; use crate::scalar_fns::ApproxOptions; @@ -239,8 +261,8 @@ mod tests { #[case::opposite(&[3], &[1.0, 0.0, 0.0], &[-1.0, 0.0, 0.0], &[-1.0])] // dot=24, both magnitudes=5 -> 24/25 = 0.96. #[case::non_unit(&[2], &[3.0, 4.0], &[4.0, 3.0], &[0.96])] - // Zero vector -> 0/0 -> NaN. - #[case::zero_norm(&[2], &[0.0, 0.0], &[1.0, 0.0], &[f64::NAN])] + // Zero vector -> guarded to 0.0. + #[case::zero_norm(&[2], &[0.0, 0.0], &[1.0, 0.0], &[0.0])] fn single_row( #[case] shape: &[usize], #[case] lhs_elems: &[f64], @@ -367,4 +389,22 @@ mod tests { ); Ok(()) } + + #[test] + fn null_input_row() -> VortexResult<()> { + // 2 rows of dim-2 vectors. Row 1 of rhs is masked as null. + let lhs = tensor_array(&[2], &[3.0, 4.0, 1.0, 0.0])?; + let rhs = tensor_array(&[2], &[3.0, 4.0, 0.0, 1.0])?; + let rhs = MaskedArray::try_new(rhs, Validity::from_iter([true, false]))?.into_array(); + + let scalar_fn = ScalarFn::new(CosineSimilarity, ApproxOptions::Exact).erased(); + let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 2)?; + let prim = result.as_array().to_primitive(); + + // Row 0: self-similarity = 1.0, row 1: null. + assert!(prim.is_valid(0)?); + assert!(!prim.is_valid(1)?); + assert_close(&[prim.as_slice::()[0]], &[1.0]); + Ok(()) + } } diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs new file mode 100644 index 00000000000..d142649600d --- /dev/null +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -0,0 +1,335 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Inner product expression for tensor-like types. + +use std::fmt::Formatter; + +use num_traits::Float; +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::ScalarFnArray; +use vortex_array::dtype::DType; +use vortex_array::dtype::NativePType; +use vortex_array::dtype::Nullability; +use vortex_array::expr::Expression; +use vortex_array::expr::and; +use vortex_array::match_each_float_ptype; +use vortex_array::scalar_fn::Arity; +use vortex_array::scalar_fn::ChildName; +use vortex_array::scalar_fn::ExecutionArgs; +use vortex_array::scalar_fn::ScalarFn; +use vortex_array::scalar_fn::ScalarFnId; +use vortex_array::scalar_fn::ScalarFnVTable; +use vortex_buffer::Buffer; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_error::vortex_err; + +use crate::matcher::AnyTensor; +use crate::scalar_fns::ApproxOptions; +use crate::utils::extension_element_ptype; +use crate::utils::extension_list_size; +use crate::utils::extract_flat_elements; + +/// Inner product (dot product) between two columns. +/// +/// Computes `sum(a_i * b_i)` over the flat backing buffer of each tensor or vector. For vectors +/// this is the standard dot product; for higher-rank ([`FixedShapeTensor`]) arrays this is the +/// Frobenius inner product. +/// +/// Both inputs must be tensor-like extension arrays ([`FixedShapeTensor`] or [`Vector`]) with the +/// same dtype and a float element type. The output is a float column of the same float type. +/// +/// [`FixedShapeTensor`]: crate::fixed_shape::FixedShapeTensor +/// [`Vector`]: crate::vector::Vector +#[derive(Clone)] +pub struct InnerProduct; + +impl InnerProduct { + /// Creates a new [`ScalarFn`] wrapping the inner product operation with the given + /// [`ApproxOptions`] controlling approximation behavior. + pub fn new(options: &ApproxOptions) -> ScalarFn { + ScalarFn::new(InnerProduct, options.clone()) + } + + /// Constructs a [`ScalarFnArray`] that lazily computes the inner product between `lhs` and + /// `rhs`. + /// + /// # Errors + /// + /// Returns an error if the [`ScalarFnArray`] cannot be constructed (e.g. due to dtype + /// mismatches). + pub fn try_new_array( + options: &ApproxOptions, + lhs: ArrayRef, + rhs: ArrayRef, + len: usize, + ) -> VortexResult { + ScalarFnArray::try_new(InnerProduct::new(options).erased(), vec![lhs, rhs], len) + } +} + +impl ScalarFnVTable for InnerProduct { + type Options = ApproxOptions; + + fn id(&self) -> ScalarFnId { + ScalarFnId::new_ref("vortex.tensor.inner_product") + } + + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(2) + } + + fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName { + match child_idx { + 0 => ChildName::from("lhs"), + 1 => ChildName::from("rhs"), + _ => unreachable!("InnerProduct must have exactly two children"), + } + } + + fn fmt_sql( + &self, + _options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> std::fmt::Result { + write!(f, "inner_product(")?; + expr.child(0).fmt_sql(f)?; + write!(f, ", ")?; + expr.child(1).fmt_sql(f)?; + write!(f, ")") + } + + fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { + let lhs = &arg_dtypes[0]; + let rhs = &arg_dtypes[1]; + + // Both must have the same dtype (ignoring top-level nullability). + vortex_ensure!( + lhs.eq_ignore_nullability(rhs), + "InnerProduct requires both inputs to have the same dtype, got {lhs} and {rhs}" + ); + + // Both inputs must be tensor-like extension types. + let lhs_ext = lhs + .as_extension_opt() + .ok_or_else(|| vortex_err!("InnerProduct lhs must be an extension type, got {lhs}"))?; + + vortex_ensure!( + lhs_ext.is::(), + "InnerProduct inputs must be an `AnyTensor`, got {lhs}" + ); + + let ptype = extension_element_ptype(lhs_ext)?; + vortex_ensure!( + ptype.is_float(), + "InnerProduct element dtype must be a float primitive, got {ptype}" + ); + + let nullability = Nullability::from(lhs.is_nullable() || rhs.is_nullable()); + Ok(DType::Primitive(ptype, nullability)) + } + + fn execute( + &self, + _options: &Self::Options, + args: &dyn ExecutionArgs, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + let lhs: ExtensionArray = args.get(0)?.execute(ctx)?; + let rhs: ExtensionArray = args.get(1)?.execute(ctx)?; + + let row_count = args.row_count(); + + // Compute combined validity. + let rhs_validity = rhs.as_ref().validity()?; + let validity = lhs.as_ref().validity()?.and(rhs_validity)?; + + // Get list size from the dtype. Both sides have the same dtype (validated by + // `return_dtype`). + let ext = lhs.dtype().as_extension(); + let list_size = extension_list_size(ext)? as usize; + + // Extract the storage array from each extension input. We pass the storage (FSL) rather + // than the extension array to avoid canonicalizing the extension wrapper. + let lhs_storage = lhs.data().storage_array(); + let rhs_storage = rhs.data().storage_array(); + + let lhs_flat = extract_flat_elements(lhs_storage, list_size, ctx)?; + let rhs_flat = extract_flat_elements(rhs_storage, list_size, ctx)?; + + match_each_float_ptype!(lhs_flat.ptype(), |T| { + let buffer: Buffer = (0..row_count) + .map(|i| inner_product_row(lhs_flat.row::(i), rhs_flat.row::(i))) + .collect(); + + // SAFETY: The buffer length equals `row_count`, which matches the source validity + // length. + Ok(unsafe { PrimitiveArray::new_unchecked(buffer, validity) }.into_array()) + }) + } + + fn validity( + &self, + _options: &Self::Options, + expression: &Expression, + ) -> VortexResult> { + // The result is null if either input tensor is null. + let lhs_validity = expression.child(0).validity()?; + let rhs_validity = expression.child(1).validity()?; + + Ok(Some(and(lhs_validity, rhs_validity))) + } + + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { + false + } + + fn is_fallible(&self, _options: &Self::Options) -> bool { + false + } +} + +/// Computes the inner product (dot product) of two equal-length float slices. +/// +/// Returns `sum(a_i * b_i)`. +fn inner_product_row(a: &[T], b: &[T]) -> T { + a.iter() + .zip(b.iter()) + .map(|(&x, &y)| x * y) + .fold(T::zero(), |acc, v| acc + v) +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + use vortex_array::ArrayRef; + use vortex_array::IntoArray; + use vortex_array::ToCanonical; + use vortex_array::arrays::MaskedArray; + use vortex_array::arrays::PrimitiveArray; + use vortex_array::arrays::ScalarFnArray; + use vortex_array::scalar_fn::ScalarFn; + use vortex_array::validity::Validity; + use vortex_error::VortexResult; + + use crate::scalar_fns::ApproxOptions; + use crate::scalar_fns::inner_product::InnerProduct; + use crate::utils::test_helpers::assert_close; + use crate::utils::test_helpers::tensor_array; + use crate::utils::test_helpers::vector_array; + + /// Evaluates inner product between two tensor arrays and returns the result as `Vec`. + fn eval_inner_product(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult> { + let scalar_fn = ScalarFn::new(InnerProduct, ApproxOptions::Exact).erased(); + let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], len)?; + let prim = result.as_array().to_primitive(); + Ok(prim.as_slice::().to_vec()) + } + + /// Single-row inner product for various vector pairs. + #[rstest] + // Orthogonal: [1, 0] . [0, 1] = 0. + #[case::orthogonal(&[2], &[1.0, 0.0], &[0.0, 1.0], &[0.0])] + // Parallel: [3, 4] . [3, 4] = 9 + 16 = 25. + #[case::parallel(&[2], &[3.0, 4.0], &[3.0, 4.0], &[25.0])] + // Antiparallel: [1, 2] . [-1, -2] = -1 + -4 = -5. + #[case::antiparallel(&[2], &[1.0, 2.0], &[-1.0, -2.0], &[-5.0])] + // Scaled: [2, 0] . [3, 0] = 6. + #[case::scaled(&[2], &[2.0, 0.0], &[3.0, 0.0], &[6.0])] + fn single_row( + #[case] shape: &[usize], + #[case] lhs_elems: &[f64], + #[case] rhs_elems: &[f64], + #[case] expected: &[f64], + ) -> VortexResult<()> { + let lhs = tensor_array(shape, lhs_elems)?; + let rhs = tensor_array(shape, rhs_elems)?; + assert_close(&eval_inner_product(lhs, rhs, 1)?, expected); + Ok(()) + } + + #[test] + fn multiple_rows() -> VortexResult<()> { + let lhs = tensor_array( + &[3], + &[ + 1.0, 0.0, 0.0, // tensor 0 + 3.0, 4.0, 0.0, // tensor 1 + 1.0, 1.0, 1.0, // tensor 2 + ], + )?; + let rhs = tensor_array( + &[3], + &[ + 0.0, 1.0, 0.0, // tensor 0: dot = 0 + 3.0, 4.0, 0.0, // tensor 1: dot = 25 + 2.0, 2.0, 2.0, // tensor 2: dot = 6 + ], + )?; + assert_close(&eval_inner_product(lhs, rhs, 3)?, &[0.0, 25.0, 6.0]); + Ok(()) + } + + #[test] + fn vector_inner_product() -> VortexResult<()> { + let lhs = vector_array( + 2, + &[ + 3.0, 4.0, // vector 0 + 1.0, 0.0, // vector 1 + ], + )?; + let rhs = vector_array( + 2, + &[ + 3.0, 4.0, // vector 0: dot = 25 + 0.0, 1.0, // vector 1: dot = 0 + ], + )?; + assert_close(&eval_inner_product(lhs, rhs, 2)?, &[25.0, 0.0]); + Ok(()) + } + + #[test] + fn null_input_row() -> VortexResult<()> { + // 3 rows of dim-2 vectors. Row 1 of lhs is masked as null. + let lhs = tensor_array(&[2], &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])?; + let rhs = tensor_array(&[2], &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0])?; + let lhs = MaskedArray::try_new(lhs, Validity::from_iter([true, false, true]))?.into_array(); + + let scalar_fn = ScalarFn::new(InnerProduct, ApproxOptions::Exact).erased(); + let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 3)?; + let prim = result.as_array().to_primitive(); + + // Row 0: 1*7 + 2*8 = 23, row 1: null, row 2: 5*11 + 6*12 = 127. + assert!(prim.is_valid(0)?); + assert!(!prim.is_valid(1)?); + assert!(prim.is_valid(2)?); + assert_close(&[prim.as_slice::()[0]], &[23.0]); + assert_close(&[prim.as_slice::()[2]], &[127.0]); + Ok(()) + } + + #[test] + fn rejects_non_extension_dtype() { + let lhs = PrimitiveArray::from_iter([1.0_f64, 2.0]).into_array(); + let rhs = PrimitiveArray::from_iter([3.0_f64, 4.0]).into_array(); + let result = InnerProduct::try_new_array(&ApproxOptions::Exact, lhs, rhs, 2); + assert!(result.is_err()); + } + + #[test] + fn rejects_mismatched_dtypes() -> VortexResult<()> { + let lhs = tensor_array(&[2], &[1.0_f64, 2.0])?; + let rhs = vector_array(2, &[3.0_f64, 4.0])?; + let result = InnerProduct::try_new_array(&ApproxOptions::Exact, lhs, rhs, 1); + assert!(result.is_err()); + Ok(()) + } +} diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index f319a0df0bb..ed29cc776b7 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -1,9 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! L2 norm expression for tensor-like extension arrays -//! ([`FixedShapeTensor`](crate::fixed_shape::FixedShapeTensor) and -//! [`Vector`](crate::vector::Vector)). +//! L2 norm expression for tensor-like types. use std::fmt::Formatter; @@ -11,7 +9,9 @@ use num_traits::Float; use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; +use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::ScalarFnArray; use vortex_array::dtype::DType; use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; @@ -20,8 +20,10 @@ use vortex_array::match_each_float_ptype; use vortex_array::scalar_fn::Arity; use vortex_array::scalar_fn::ChildName; use vortex_array::scalar_fn::ExecutionArgs; +use vortex_array::scalar_fn::ScalarFn; use vortex_array::scalar_fn::ScalarFnId; use vortex_array::scalar_fn::ScalarFnVTable; +use vortex_buffer::Buffer; use vortex_error::VortexResult; use vortex_error::vortex_ensure; use vortex_error::vortex_err; @@ -30,7 +32,6 @@ use crate::matcher::AnyTensor; use crate::scalar_fns::ApproxOptions; use crate::utils::extension_element_ptype; use crate::utils::extension_list_size; -use crate::utils::extension_storage; use crate::utils::extract_flat_elements; /// L2 norm (Euclidean norm) of a tensor or vector column. @@ -42,6 +43,28 @@ use crate::utils::extract_flat_elements; #[derive(Clone)] pub struct L2Norm; +impl L2Norm { + /// Creates a new [`ScalarFn`] wrapping the L2 norm operation with the given [`ApproxOptions`] + /// controlling approximation behavior. + pub fn new(options: &ApproxOptions) -> ScalarFn { + ScalarFn::new(L2Norm, options.clone()) + } + + /// Constructs a [`ScalarFnArray`] that lazily computes the L2 norm over `child`. + /// + /// # Errors + /// + /// Returns an error if the [`ScalarFnArray`] cannot be constructed (e.g. due to dtype + /// mismatches). + pub fn try_new_array( + options: &ApproxOptions, + child: ArrayRef, + len: usize, + ) -> VortexResult { + ScalarFnArray::try_new(L2Norm::new(options).erased(), vec![child], len) + } +} + impl ScalarFnVTable for L2Norm { type Options = ApproxOptions; @@ -100,27 +123,26 @@ impl ScalarFnVTable for L2Norm { args: &dyn ExecutionArgs, ctx: &mut ExecutionCtx, ) -> VortexResult { - let input = args.get(0)?; + let input: ExtensionArray = args.get(0)?.execute(ctx)?; + let row_count = args.row_count(); + let validity = input.as_ref().validity()?; - // Get list size (dimensions) from the dtype. - let ext = input.dtype().as_extension_opt().ok_or_else(|| { - vortex_err!( - "l2_norm input must be an extension type, got {}", - input.dtype() - ) - })?; + // Get list size (dimensions) from the dtype (validated by `return_dtype`). + let ext = input.dtype().as_extension(); let list_size = extension_list_size(ext)? as usize; - let storage = extension_storage(&input)?; - let flat = extract_flat_elements(&storage, list_size, ctx)?; + let storage = input.data().storage_array(); + let flat = extract_flat_elements(storage, list_size, ctx)?; match_each_float_ptype!(flat.ptype(), |T| { - let result: PrimitiveArray = (0..row_count) + let buffer: Buffer = (0..row_count) .map(|i| l2_norm_row(flat.row::(i))) .collect(); - Ok(result.into_array()) + // SAFETY: The buffer length equals `row_count`, which matches the source validity + // length. + Ok(unsafe { PrimitiveArray::new_unchecked(buffer, validity) }.into_array()) }) } @@ -157,9 +179,12 @@ fn l2_norm_row(v: &[T]) -> T { mod tests { use rstest::rstest; use vortex_array::ArrayRef; + use vortex_array::IntoArray; use vortex_array::ToCanonical; + use vortex_array::arrays::MaskedArray; use vortex_array::arrays::ScalarFnArray; use vortex_array::scalar_fn::ScalarFn; + use vortex_array::validity::Validity; use vortex_error::VortexResult; use crate::scalar_fns::ApproxOptions; @@ -217,4 +242,21 @@ mod tests { assert_close(&eval_l2_norm(arr, 2)?, &[1.0, 5.0]); Ok(()) } + + #[test] + fn null_input_row() -> VortexResult<()> { + // 2 rows of dim-2 vectors. Row 1 is masked as null. + let arr = tensor_array(&[2], &[3.0, 4.0, 0.0, 0.0])?; + let arr = MaskedArray::try_new(arr, Validity::from_iter([true, false]))?.into_array(); + + let scalar_fn = ScalarFn::new(L2Norm, ApproxOptions::Exact).erased(); + let result = ScalarFnArray::try_new(scalar_fn, vec![arr], 2)?; + let prim = result.as_array().to_primitive(); + + // Row 0: norm = 5.0, row 1: null. + assert!(prim.is_valid(0)?); + assert!(!prim.is_valid(1)?); + assert_close(&[prim.as_slice::()[0]], &[5.0]); + Ok(()) + } } diff --git a/vortex-tensor/src/scalar_fns/mod.rs b/vortex-tensor/src/scalar_fns/mod.rs index c699b46cfca..b10fd335420 100644 --- a/vortex-tensor/src/scalar_fns/mod.rs +++ b/vortex-tensor/src/scalar_fns/mod.rs @@ -6,6 +6,7 @@ use std::fmt; pub mod cosine_similarity; +pub mod inner_product; pub mod l2_norm; /// Options for tensor-related expressions that might have error. diff --git a/vortex-tensor/src/utils.rs b/vortex-tensor/src/utils.rs index cd21cd65964..6a84f8bbc7f 100644 --- a/vortex-tensor/src/utils.rs +++ b/vortex-tensor/src/utils.rs @@ -6,7 +6,6 @@ use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::Constant; use vortex_array::arrays::ConstantArray; -use vortex_array::arrays::Extension; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::dtype::DType; @@ -54,15 +53,6 @@ pub fn extension_element_ptype(ext: &ExtDTypeRef) -> VortexResult { Ok(element_dtype.as_ptype()) } -/// Extracts the storage array from an extension array without canonicalizing. -pub fn extension_storage(array: &ArrayRef) -> VortexResult { - let ext = array - .as_opt::() - .ok_or_else(|| vortex_err!("scalar_fn input must be an extension array"))?; - - Ok(ext.storage_array().clone()) -} - /// The flat primitive elements of a tensor storage array, with typed row access. /// /// This struct hides the stride detail that arises from the [`ConstantArray`] optimization: a @@ -124,7 +114,6 @@ pub fn extract_flat_elements( #[cfg(test)] pub mod test_helpers { use vortex_array::ArrayRef; - use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::ExtensionArray; @@ -138,11 +127,7 @@ pub mod test_helpers { use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_error::VortexResult; - use vortex_error::vortex_err; - use super::extension_list_size; - use super::extension_storage; - use super::extract_flat_elements; use crate::fixed_shape::FixedShapeTensor; use crate::fixed_shape::FixedShapeTensorMetadata; use crate::vector::Vector; @@ -222,26 +207,6 @@ pub mod test_helpers { Ok(ExtensionArray::new(ext_dtype, storage).into_array()) } - #[expect(dead_code, reason = "TODO(connor): Use this!")] - /// Extracts the f64 rows from a [`Vector`] extension array. - /// - /// Returns a `Vec>` where each inner vec is one vector's elements. - pub fn extract_vector_rows( - array: &ArrayRef, - ctx: &mut ExecutionCtx, - ) -> VortexResult>> { - let ext = array - .dtype() - .as_extension_opt() - .ok_or_else(|| vortex_err!("expected Vector extension dtype, got {}", array.dtype()))?; - let list_size = extension_list_size(ext)? as usize; - let storage = extension_storage(array)?; - let flat = extract_flat_elements(&storage, list_size, ctx)?; - Ok((0..array.len()) - .map(|i| flat.row::(i).to_vec()) - .collect()) - } - /// Asserts that each element in `actual` is within `1e-10` of the corresponding `expected` /// value, with support for NaN (NaN == NaN is considered equal). #[track_caller]