Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions vortex-tensor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,7 @@
//! similarity.

pub mod fixed_shape;
pub mod vector;

pub mod matcher;
pub mod scalar_fns;
42 changes: 42 additions & 0 deletions vortex-tensor/src/matcher.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! Matcher for tensor-like extension types.

use vortex::dtype::extension::ExtDTypeRef;
use vortex::dtype::extension::Matcher;

use crate::fixed_shape::FixedShapeTensor;
use crate::fixed_shape::FixedShapeTensorMetadata;
use crate::vector::Vector;

/// A [`Matcher`] that recognizes every tensor-like extension type in this crate.
///
/// The tensor kinds currently recognized are:
///
/// - [`FixedShapeTensor`]
/// - [`Vector`]
pub struct AnyTensor;

/// Result of a successful [`AnyTensor`] match, identifying which tensor kind was found.
#[derive(Debug, PartialEq, Eq)]
pub enum TensorMatch<'a> {
    /// The extension type is a [`FixedShapeTensor`]; carries its metadata.
    FixedShapeTensor(&'a FixedShapeTensorMetadata),
    /// The extension type is a [`Vector`].
    Vector,
}

impl Matcher for AnyTensor {
    type Match<'a> = TensorMatch<'a>;

    /// Tries each known tensor kind in order: [`FixedShapeTensor`] first
    /// (returning its metadata), then [`Vector`]. Returns `None` when the
    /// extension type is neither.
    fn try_match<'a>(item: &'a ExtDTypeRef) -> Option<Self::Match<'a>> {
        item.metadata_opt::<FixedShapeTensor>()
            .map(TensorMatch::FixedShapeTensor)
            .or_else(|| item.metadata_opt::<Vector>().map(|_| TensorMatch::Vector))
    }
}
203 changes: 121 additions & 82 deletions vortex-tensor/src/scalar_fns/cosine_similarity.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,23 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! Cosine similarity expression for [`FixedShapeTensor`](crate::fixed_shape::FixedShapeTensor)
//! arrays.
//! Cosine similarity expression for tensor-like extension arrays
//! ([`FixedShapeTensor`](crate::fixed_shape::FixedShapeTensor) and
//! [`Vector`](crate::vector::Vector)).

use std::fmt::Formatter;

use num_traits::Float;
use vortex::array::ArrayRef;
use vortex::array::DynArray;
use vortex::array::ExecutionCtx;
use vortex::array::IntoArray;
use vortex::array::arrays::Constant;
use vortex::array::arrays::ConstantArray;
use vortex::array::arrays::Extension;
use vortex::array::arrays::PrimitiveArray;
use vortex::array::match_each_float_ptype;
use vortex::dtype::DType;
use vortex::dtype::NativePType;
use vortex::dtype::Nullability;
use vortex::dtype::extension::Matcher;
use vortex::error::VortexResult;
use vortex::error::vortex_bail;
use vortex::error::vortex_ensure;
use vortex::error::vortex_err;
use vortex::expr::Expression;
Expand All @@ -31,18 +28,23 @@ use vortex::scalar_fn::ExecutionArgs;
use vortex::scalar_fn::ScalarFnId;
use vortex::scalar_fn::ScalarFnVTable;

// TODO(connor): We will want to add implementations for unit normalized vectors and also vectors
// encoded in spherical coordinates.
use crate::matcher::AnyTensor;
use crate::scalar_fns::utils::extension_element_ptype;
use crate::scalar_fns::utils::extension_list_size;
use crate::scalar_fns::utils::extension_storage;
use crate::scalar_fns::utils::extract_flat_elements;

/// Cosine similarity between two columns.
///
/// For [`FixedShapeTensor`], computes `dot(a, b) / (||a|| * ||b||)` over the flat backing buffer of
/// each tensor. The shape and permutation do not affect the result because cosine similarity only
/// depends on the element values, not their logical arrangement.
/// Computes `dot(a, b) / (||a|| * ||b||)` over the flat backing buffer of each tensor or vector.
/// The shape and permutation do not affect the result because cosine similarity only depends on the
/// element values, not their logical arrangement.
///
/// Right now, both inputs must be [`FixedShapeTensor`] extension arrays with the same dtype and a
/// float element type. The output is a float column of the same float type.
/// Both inputs must be tensor-like extension arrays ([`FixedShapeTensor`] or [`Vector`]) with the
/// same dtype and a float element type. The output is a float column of the same float type.
///
/// [`FixedShapeTensor`]: crate::fixed_shape::FixedShapeTensor
/// [`Vector`]: crate::vector::Vector
#[derive(Clone)]
pub struct CosineSimilarity;

Expand Down Expand Up @@ -92,33 +94,22 @@ impl ScalarFnVTable for CosineSimilarity {

// We don't need to look at rhs anymore since we know lhs and rhs are equal.

// Both inputs must be extension types.
// Both inputs must be tensor-like extension types.
let lhs_ext = lhs.as_extension_opt().ok_or_else(|| {
vortex_err!("cosine_similarity lhs must be an extension type, got {lhs}")
})?;

// Extract the element dtype from the storage FixedSizeList.
let element_dtype = lhs_ext
.storage_dtype()
.as_fixed_size_list_element_opt()
.ok_or_else(|| {
vortex_err!(
"cosine_similarity storage dtype must be a FixedSizeList, got {}",
lhs_ext.storage_dtype()
)
})?;

// Element dtype must be a non-nullable float primitive.
vortex_ensure!(
element_dtype.is_float(),
"cosine_similarity element dtype must be a float primitive, got {element_dtype}"
AnyTensor::matches(lhs_ext),
"cosine_similarity inputs must be an `AnyTensor`, got {lhs}"
);

let ptype = extension_element_ptype(lhs_ext)?;
vortex_ensure!(
!element_dtype.is_nullable(),
"cosine_similarity element dtype must be non-nullable"
ptype.is_float(),
"cosine_similarity element dtype must be a float primitive, got {ptype}"
);

let ptype = element_dtype.as_ptype();
let nullability = Nullability::from(lhs.is_nullable() || rhs.is_nullable());
Ok(DType::Primitive(ptype, nullability))
}
Expand All @@ -140,32 +131,19 @@ impl ScalarFnVTable for CosineSimilarity {
lhs.dtype()
)
})?;
let DType::FixedSizeList(_, list_size, _) = ext.storage_dtype() else {
vortex_bail!("expected FixedSizeList storage dtype");
};
let list_size = *list_size as usize;
let list_size = extension_list_size(ext)?;

// Extract the storage array from each extension input. We pass the storage (FSL) rather
// than the extension array to avoid canonicalizing the extension wrapper.
let lhs_storage = extension_storage(&lhs)?;
let rhs_storage = extension_storage(&rhs)?;

// Extract the flat primitive elements from each tensor column. When an input is a
// `ConstantArray` (e.g., a literal query vector), we materialize only a single row
// instead of expanding it to the full row count.
let (lhs_elems, lhs_stride) = extract_flat_elements(&lhs_storage, list_size)?;
let (rhs_elems, rhs_stride) = extract_flat_elements(&rhs_storage, list_size)?;

match_each_float_ptype!(lhs_elems.ptype(), |T| {
let lhs_slice = lhs_elems.as_slice::<T>();
let rhs_slice = rhs_elems.as_slice::<T>();
let lhs_flat = extract_flat_elements(&lhs_storage, list_size)?;
let rhs_flat = extract_flat_elements(&rhs_storage, list_size)?;

match_each_float_ptype!(lhs_flat.ptype(), |T| {
let result: PrimitiveArray = (0..row_count)
.map(|i| {
let a = &lhs_slice[i * lhs_stride..i * lhs_stride + list_size];
let b = &rhs_slice[i * rhs_stride..i * rhs_stride + list_size];
cosine_similarity_row(a, b)
})
.map(|i| cosine_similarity_row(lhs_flat.row::<T>(i), rhs_flat.row::<T>(i)))
.collect();

Ok(result.into_array())
Expand Down Expand Up @@ -194,38 +172,6 @@ impl ScalarFnVTable for CosineSimilarity {
}
}

/// Extracts the storage array from an extension array without canonicalizing.
fn extension_storage(array: &ArrayRef) -> VortexResult<ArrayRef> {
let ext = array
.as_opt::<Extension>()
.ok_or_else(|| vortex_err!("cosine_similarity input must be an extension array"))?;
Ok(ext.storage_array().clone())
}

/// Extracts the flat primitive elements from a tensor storage array (FixedSizeList).
///
/// When the input is a [`ConstantArray`] (e.g., a literal query vector), only a single row is
/// materialized to avoid expanding it to the full column length. Returns `(elements, stride)`
/// where `stride` is `list_size` for a full array and `0` for a constant.
fn extract_flat_elements(
storage: &ArrayRef,
list_size: usize,
) -> VortexResult<(PrimitiveArray, usize)> {
if let Some(constant) = storage.as_opt::<Constant>() {
// Rewrite the array as a length 1 array so when we canonicalize, we do not duplicate a
// huge amount of data.
let single = ConstantArray::new(constant.scalar().clone(), 1).into_array();
let fsl = single.to_canonical()?.into_fixed_size_list();
let elems = fsl.elements().to_canonical()?.into_primitive();
Ok((elems, 0))
} else {
// Otherwise we have to fully expand all of the data.
let fsl = storage.to_canonical()?.into_fixed_size_list();
let elems = fsl.elements().to_canonical()?.into_primitive();
Ok((elems, list_size))
}
}

// TODO(connor): We should try to use a more performant library instead of doing this ourselves.
/// Computes cosine similarity between two equal-length float slices.
///
Expand Down Expand Up @@ -258,13 +204,15 @@ mod tests {
use vortex::dtype::Nullability;
use vortex::dtype::extension::ExtDType;
use vortex::error::VortexResult;
use vortex::extension::EmptyMetadata;
use vortex::scalar::Scalar;
use vortex::scalar_fn::EmptyOptions;
use vortex::scalar_fn::ScalarFn;

use crate::fixed_shape::FixedShapeTensor;
use crate::fixed_shape::FixedShapeTensorMetadata;
use crate::scalar_fns::cosine_similarity::CosineSimilarity;
use crate::vector::Vector;

/// Builds a [`FixedShapeTensor`] extension array from flat f64 elements and a logical shape.
///
Expand Down Expand Up @@ -460,4 +408,95 @@ mod tests {
);
Ok(())
}

/// Constructs a [`Vector`] extension array from a flat slice of f64 values.
///
/// `dim` is the per-row dimensionality; `elements.len()` is expected to be a
/// multiple of `dim`, and the row count is derived from it.
fn vector_array(dim: u32, elements: &[f64]) -> VortexResult<ArrayRef> {
    let num_rows = elements.len() / dim as usize;

    let flat: ArrayRef = Buffer::copy_from(elements).into_array();
    let storage = FixedSizeListArray::new(flat, dim, Validity::NonNullable, num_rows);

    let dtype = ExtDType::<Vector>::try_new(EmptyMetadata, storage.dtype().clone())?.erased();
    Ok(ExtensionArray::new(dtype, storage.into_array()).into_array())
}

#[test]
fn vector_unit_vectors() -> VortexResult<()> {
    // lhs row 0 is identical to rhs row 0 (similarity 1.0); lhs row 1 is
    // orthogonal to rhs row 1 (similarity 0.0).
    let lhs = vector_array(3, &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0])?;
    let rhs = vector_array(3, &[1.0, 0.0, 0.0, 1.0, 0.0, 0.0])?;

    let result = eval_cosine_similarity(lhs, rhs, 2)?;
    assert_close(&result, &[1.0, 0.0]);
    Ok(())
}

#[test]
fn vector_self_similarity() -> VortexResult<()> {
    // Every non-zero vector compared against itself must yield similarity 1.0.
    let rows = [
        1.0, 2.0, 3.0, 4.0, // vector 0
        0.0, 1.0, 0.0, 0.0, // vector 1
        5.0, 0.0, 5.0, 0.0, // vector 2
    ];
    let arr = vector_array(4, &rows)?;

    let result = eval_cosine_similarity(arr.clone(), arr, 3)?;
    assert_close(&result, &[1.0, 1.0, 1.0]);
    Ok(())
}

/// Constructs a [`Vector`] extension array whose storage is a [`ConstantArray`],
/// i.e. all `len` rows hold the same vector value.
fn constant_vector_array(elements: &[f64], len: usize) -> VortexResult<ArrayRef> {
    let element_dtype = DType::Primitive(vortex::dtype::PType::F64, Nullability::NonNullable);

    let values = elements
        .iter()
        .map(|&v| Scalar::primitive(v, Nullability::NonNullable))
        .collect::<Vec<Scalar>>();
    let row_scalar = Scalar::fixed_size_list(element_dtype, values, Nullability::NonNullable);

    let storage = ConstantArray::new(row_scalar, len).into_array();
    let dtype = ExtDType::<Vector>::try_new(EmptyMetadata, storage.dtype().clone())?.erased();
    Ok(ExtensionArray::new(dtype, storage).into_array())
}

#[test]
fn vector_constant_query() -> VortexResult<()> {
    // Four data rows scored against a single constant query vector; rows 0 and
    // 3 equal the query (1.0), rows 1 and 2 are orthogonal to it (0.0).
    let data = vector_array(
        3,
        &[
            1.0, 0.0, 0.0, // vector 0
            0.0, 1.0, 0.0, // vector 1
            0.0, 0.0, 1.0, // vector 2
            1.0, 0.0, 0.0, // vector 3
        ],
    )?;
    let query = constant_vector_array(&[1.0, 0.0, 0.0], 4)?;

    let result = eval_cosine_similarity(data, query, 4)?;
    assert_close(&result, &[1.0, 0.0, 0.0, 1.0]);
    Ok(())
}
}
Loading
Loading