diff --git a/encodings/alp/public-api.lock b/encodings/alp/public-api.lock index 4105a593e53..8416610a039 100644 --- a/encodings/alp/public-api.lock +++ b/encodings/alp/public-api.lock @@ -256,10 +256,6 @@ impl vortex_array::compute::between::BetweenKernel for vortex_alp::ALPVTable pub fn vortex_alp::ALPVTable::between(&self, array: &vortex_alp::ALPArray, lower: &dyn vortex_array::array::Array, upper: &dyn vortex_array::array::Array, options: &vortex_array::compute::between::BetweenOptions) -> vortex_error::VortexResult> -impl vortex_array::compute::compare::CompareKernel for vortex_alp::ALPVTable - -pub fn vortex_alp::ALPVTable::compare(&self, lhs: &vortex_alp::ALPArray, rhs: &dyn vortex_array::array::Array, operator: vortex_array::compute::compare::Operator) -> vortex_error::VortexResult> - impl vortex_array::compute::mask::MaskKernel for vortex_alp::ALPVTable pub fn vortex_alp::ALPVTable::mask(&self, array: &vortex_alp::ALPArray, filter_mask: &vortex_mask::Mask) -> vortex_error::VortexResult @@ -268,6 +264,10 @@ impl vortex_array::compute::nan_count::NaNCountKernel for vortex_alp::ALPVTable pub fn vortex_alp::ALPVTable::nan_count(&self, array: &vortex_alp::ALPArray) -> vortex_error::VortexResult +impl vortex_array::expr::exprs::binary::compare::CompareKernel for vortex_alp::ALPVTable + +pub fn vortex_alp::ALPVTable::compare(lhs: &vortex_alp::ALPArray, rhs: &dyn vortex_array::array::Array, operator: vortex_array::compute::compare::Operator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> + impl vortex_array::expr::exprs::cast::kernel::CastReduce for vortex_alp::ALPVTable pub fn vortex_alp::ALPVTable::cast(array: &vortex_alp::ALPArray, dtype: &vortex_dtype::dtype::DType) -> vortex_error::VortexResult> diff --git a/encodings/alp/src/alp/compute/compare.rs b/encodings/alp/src/alp/compute/compare.rs index 6f04b58841b..e90aa3cd81a 100644 --- a/encodings/alp/src/alp/compute/compare.rs +++ b/encodings/alp/src/alp/compute/compare.rs @@ -5,13 +5,12 @@ use std::fmt::Debug; use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::ConstantArray; -use vortex_array::compute::CompareKernel; -use vortex_array::compute::CompareKernelAdapter; use vortex_array::compute::Operator; use vortex_array::compute::compare; -use vortex_array::register_kernel; +use vortex_array::expr::CompareKernel; use vortex_dtype::NativePType; use vortex_error::VortexResult; use vortex_error::vortex_bail; @@ -27,10 +26,10 @@ use crate::match_each_alp_float_ptype; impl CompareKernel for ALPVTable { fn compare( - &self, lhs: &ALPArray, rhs: &dyn Array, operator: Operator, + _ctx: &mut ExecutionCtx, ) -> VortexResult> { if lhs.patches().is_some() { // TODO(joe): support patches @@ -66,8 +65,6 @@ impl CompareKernel for ALPVTable { } } -register_kernel!(CompareKernelAdapter(ALPVTable).lift()); - /// We can compare a scalar to an ALPArray by encoding the scalar into the ALP domain and comparing /// the encoded value to the encoded values in the ALPArray. There are fixups when the value doesn't /// encode into the ALP domain. diff --git a/encodings/alp/src/alp/rules.rs b/encodings/alp/src/alp/rules.rs index f90739c1733..57d177f99f8 100644 --- a/encodings/alp/src/alp/rules.rs +++ b/encodings/alp/src/alp/rules.rs @@ -5,12 +5,14 @@ use vortex_array::arrays::FilterExecuteAdaptor; use vortex_array::arrays::SliceExecuteAdaptor; use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::compute::CastReduceAdaptor; +use vortex_array::expr::CompareExecuteAdaptor; use vortex_array::kernel::ParentKernelSet; use vortex_array::optimizer::rules::ParentRuleSet; use crate::ALPVTable; pub(super) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ + ParentKernelSet::lift(&CompareExecuteAdaptor(ALPVTable)), ParentKernelSet::lift(&FilterExecuteAdaptor(ALPVTable)), ParentKernelSet::lift(&SliceExecuteAdaptor(ALPVTable)), ParentKernelSet::lift(&TakeExecuteAdaptor(ALPVTable)), diff --git a/encodings/datetime-parts/public-api.lock b/encodings/datetime-parts/public-api.lock index b7938753428..b43cbafad7a 100644 --- a/encodings/datetime-parts/public-api.lock +++ b/encodings/datetime-parts/public-api.lock @@ -114,10 +114,6 @@ impl vortex_array::arrays::slice::SliceReduce for vortex_datetime_parts::DateTim pub fn vortex_datetime_parts::DateTimePartsVTable::slice(array: &Self::Array, range: core::ops::range::Range) -> vortex_error::VortexResult> -impl vortex_array::compute::compare::CompareKernel for vortex_datetime_parts::DateTimePartsVTable - -pub fn vortex_datetime_parts::DateTimePartsVTable::compare(&self, lhs: &vortex_datetime_parts::DateTimePartsArray, rhs: &dyn vortex_array::array::Array, operator: vortex_array::compute::compare::Operator) -> vortex_error::VortexResult> - impl vortex_array::compute::is_constant::IsConstantKernel for vortex_datetime_parts::DateTimePartsVTable pub fn vortex_datetime_parts::DateTimePartsVTable::is_constant(&self, array: &vortex_datetime_parts::DateTimePartsArray, opts: &vortex_array::compute::is_constant::IsConstantOpts) -> vortex_error::VortexResult> @@ -126,6 +122,10 @@ impl vortex_array::compute::mask::MaskKernel for vortex_datetime_parts::DateTime pub fn vortex_datetime_parts::DateTimePartsVTable::mask(&self, array: &vortex_datetime_parts::DateTimePartsArray, mask_array: &vortex_mask::Mask) -> vortex_error::VortexResult +impl vortex_array::expr::exprs::binary::compare::CompareKernel for vortex_datetime_parts::DateTimePartsVTable + +pub fn vortex_datetime_parts::DateTimePartsVTable::compare(lhs: &vortex_datetime_parts::DateTimePartsArray, rhs: &dyn vortex_array::array::Array, operator: vortex_array::compute::compare::Operator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> + impl vortex_array::expr::exprs::cast::kernel::CastReduce for vortex_datetime_parts::DateTimePartsVTable pub fn vortex_datetime_parts::DateTimePartsVTable::cast(array: &vortex_datetime_parts::DateTimePartsArray, dtype: &vortex_dtype::dtype::DType) -> vortex_error::VortexResult> diff --git a/encodings/datetime-parts/src/compute/compare.rs b/encodings/datetime-parts/src/compute/compare.rs index 2885b35de26..77188e6515e 100644 --- a/encodings/datetime-parts/src/compute/compare.rs +++ b/encodings/datetime-parts/src/compute/compare.rs @@ -3,16 +3,15 @@ use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::ConstantArray; use vortex_array::builtins::ArrayBuiltins; -use vortex_array::compute::CompareKernel; -use vortex_array::compute::CompareKernelAdapter; use vortex_array::compute::Operator; use vortex_array::compute::and_kleene; use vortex_array::compute::compare; use vortex_array::compute::or_kleene; -use vortex_array::register_kernel; +use vortex_array::expr::CompareKernel; use vortex_dtype::DType; use vortex_dtype::Nullability; use vortex_dtype::datetime::Timestamp; @@ -24,13 +23,11 @@ use crate::array::DateTimePartsVTable; use crate::timestamp; impl CompareKernel for DateTimePartsVTable { - /// Compares two arrays and returns a new boolean array with the result of the comparison. - /// Or, returns None if comparison is not supported. fn compare( - &self, lhs: &DateTimePartsArray, rhs: &dyn Array, operator: Operator, + _ctx: &mut ExecutionCtx, ) -> VortexResult> { let Some(rhs_const) = rhs.as_constant() else { return Ok(None); @@ -71,8 +68,6 @@ impl CompareKernel for DateTimePartsVTable { } } -register_kernel!(CompareKernelAdapter(DateTimePartsVTable).lift()); - fn compare_eq( lhs: &DateTimePartsArray, ts_parts: ×tamp::TimestampParts, diff --git a/encodings/datetime-parts/src/compute/kernel.rs b/encodings/datetime-parts/src/compute/kernel.rs index 9c95c3439ca..301d8580340 100644 --- a/encodings/datetime-parts/src/compute/kernel.rs +++ b/encodings/datetime-parts/src/compute/kernel.rs @@ -2,11 +2,12 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_array::arrays::TakeExecuteAdaptor; +use vortex_array::expr::CompareExecuteAdaptor; use vortex_array::kernel::ParentKernelSet; use crate::DateTimePartsVTable; -pub(crate) const PARENT_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor( - DateTimePartsVTable, - ))]); +pub(crate) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ + ParentKernelSet::lift(&CompareExecuteAdaptor(DateTimePartsVTable)), + ParentKernelSet::lift(&TakeExecuteAdaptor(DateTimePartsVTable)), +]); diff --git a/encodings/decimal-byte-parts/public-api.lock b/encodings/decimal-byte-parts/public-api.lock index 42212cf004b..a2117734a75 100644 --- a/encodings/decimal-byte-parts/public-api.lock +++ b/encodings/decimal-byte-parts/public-api.lock @@ -64,10 +64,6 @@ impl vortex_array::arrays::slice::SliceReduce for vortex_decimal_byte_parts::Dec pub fn vortex_decimal_byte_parts::DecimalBytePartsVTable::slice(array: &vortex_decimal_byte_parts::DecimalBytePartsArray, range: core::ops::range::Range) -> vortex_error::VortexResult> -impl vortex_array::compute::compare::CompareKernel for vortex_decimal_byte_parts::DecimalBytePartsVTable - -pub fn vortex_decimal_byte_parts::DecimalBytePartsVTable::compare(&self, lhs: &Self::Array, rhs: &dyn vortex_array::array::Array, operator: vortex_array::compute::compare::Operator) -> vortex_error::VortexResult> - impl vortex_array::compute::is_constant::IsConstantKernel for vortex_decimal_byte_parts::DecimalBytePartsVTable pub fn vortex_decimal_byte_parts::DecimalBytePartsVTable::is_constant(&self, array: &vortex_decimal_byte_parts::DecimalBytePartsArray, opts: &vortex_array::compute::is_constant::IsConstantOpts) -> vortex_error::VortexResult> @@ -76,6 +72,10 @@ impl vortex_array::compute::mask::MaskKernel for vortex_decimal_byte_parts::Deci pub fn vortex_decimal_byte_parts::DecimalBytePartsVTable::mask(&self, array: &vortex_decimal_byte_parts::DecimalBytePartsArray, mask_array: &vortex_mask::Mask) -> vortex_error::VortexResult +impl vortex_array::expr::exprs::binary::compare::CompareKernel for vortex_decimal_byte_parts::DecimalBytePartsVTable + +pub fn vortex_decimal_byte_parts::DecimalBytePartsVTable::compare(lhs: &Self::Array, rhs: &dyn vortex_array::array::Array, operator: vortex_array::compute::compare::Operator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> + impl vortex_array::expr::exprs::cast::kernel::CastReduce for vortex_decimal_byte_parts::DecimalBytePartsVTable pub fn vortex_decimal_byte_parts::DecimalBytePartsVTable::cast(array: &vortex_decimal_byte_parts::DecimalBytePartsArray, dtype: &vortex_dtype::dtype::DType) -> vortex_error::VortexResult> diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/compare.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/compare.rs index 2ff776c36b3..0aa77420b5f 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/compare.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/compare.rs @@ -5,12 +5,11 @@ use Sign::Negative; use num_traits::NumCast; use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::arrays::ConstantArray; -use vortex_array::compute::CompareKernel; -use vortex_array::compute::CompareKernelAdapter; use vortex_array::compute::Operator; use vortex_array::compute::compare; -use vortex_array::register_kernel; +use vortex_array::expr::CompareKernel; use vortex_dtype::IntegerPType; use vortex_dtype::Nullability; use vortex_dtype::PType; @@ -28,10 +27,10 @@ use crate::decimal_byte_parts::compute::compare::Sign::Positive; impl CompareKernel for DecimalBytePartsVTable { fn compare( - &self, lhs: &Self::Array, rhs: &dyn Array, operator: Operator, + _ctx: &mut ExecutionCtx, ) -> VortexResult> { let Some(rhs_const) = rhs.as_constant() else { return Ok(None); @@ -132,8 +131,6 @@ where }) } -register_kernel!(CompareKernelAdapter(DecimalBytePartsVTable).lift()); - #[cfg(test)] mod tests { use vortex_array::Array; diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/kernel.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/kernel.rs index a802fad5db1..b53d03359b5 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/kernel.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/kernel.rs @@ -2,11 +2,12 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_array::arrays::TakeExecuteAdaptor; +use vortex_array::expr::CompareExecuteAdaptor; use vortex_array::kernel::ParentKernelSet; use crate::DecimalBytePartsVTable; -pub(crate) const PARENT_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor( - DecimalBytePartsVTable, - ))]); +pub(crate) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ + ParentKernelSet::lift(&CompareExecuteAdaptor(DecimalBytePartsVTable)), + ParentKernelSet::lift(&TakeExecuteAdaptor(DecimalBytePartsVTable)), +]); diff --git a/encodings/fastlanes/public-api.lock b/encodings/fastlanes/public-api.lock index d1fcdeef9db..7206974aa4b 100644 --- a/encodings/fastlanes/public-api.lock +++ b/encodings/fastlanes/public-api.lock @@ -464,10 +464,6 @@ impl vortex_array::arrays::slice::SliceReduce for vortex_fastlanes::FoRVTable pub fn vortex_fastlanes::FoRVTable::slice(array: &Self::Array, range: core::ops::range::Range) -> vortex_error::VortexResult> -impl vortex_array::compute::compare::CompareKernel for vortex_fastlanes::FoRVTable - -pub fn vortex_fastlanes::FoRVTable::compare(&self, lhs: &vortex_fastlanes::FoRArray, rhs: &dyn vortex_array::array::Array, operator: vortex_array::compute::compare::Operator) -> vortex_error::VortexResult> - impl vortex_array::compute::is_constant::IsConstantKernel for vortex_fastlanes::FoRVTable pub fn vortex_fastlanes::FoRVTable::is_constant(&self, array: &vortex_fastlanes::FoRArray, opts: &vortex_array::compute::is_constant::IsConstantOpts) -> vortex_error::VortexResult> @@ -478,6 +474,10 @@ pub fn vortex_fastlanes::FoRVTable::is_sorted(&self, array: &vortex_fastlanes::F pub fn vortex_fastlanes::FoRVTable::is_strict_sorted(&self, array: &vortex_fastlanes::FoRArray) -> vortex_error::VortexResult> +impl vortex_array::expr::exprs::binary::compare::CompareKernel for vortex_fastlanes::FoRVTable + +pub fn vortex_fastlanes::FoRVTable::compare(lhs: &vortex_fastlanes::FoRArray, rhs: &dyn vortex_array::array::Array, operator: vortex_array::compute::compare::Operator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> + impl vortex_array::expr::exprs::cast::kernel::CastReduce for vortex_fastlanes::FoRVTable pub fn vortex_fastlanes::FoRVTable::cast(array: &vortex_fastlanes::FoRArray, dtype: &vortex_dtype::dtype::DType) -> vortex_error::VortexResult> diff --git a/encodings/fastlanes/src/for/compute/compare.rs b/encodings/fastlanes/src/for/compute/compare.rs index 6ae51867ad8..cccdc6daceb 100644 --- a/encodings/fastlanes/src/for/compute/compare.rs +++ b/encodings/fastlanes/src/for/compute/compare.rs @@ -6,12 +6,11 @@ use std::ops::Shr; use num_traits::WrappingSub; use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::arrays::ConstantArray; -use vortex_array::compute::CompareKernel; -use vortex_array::compute::CompareKernelAdapter; use vortex_array::compute::Operator; use vortex_array::compute::compare; -use vortex_array::register_kernel; +use vortex_array::expr::CompareKernel; use vortex_dtype::NativePType; use vortex_dtype::Nullability; use vortex_dtype::match_each_integer_ptype; @@ -26,10 +25,10 @@ use crate::FoRVTable; impl CompareKernel for FoRVTable { fn compare( - &self, lhs: &FoRArray, rhs: &dyn Array, operator: Operator, + _ctx: &mut ExecutionCtx, ) -> VortexResult> { if let Some(constant) = rhs.as_constant() && let Some(constant) = constant.as_primitive_opt() @@ -39,7 +38,7 @@ impl CompareKernel for FoRVTable { lhs, constant .typed_value::() - .vortex_expect("null scalar handled in top-level"), + .vortex_expect("null scalar handled in adaptor"), rhs.dtype().nullability(), operator, ); @@ -50,8 +49,6 @@ impl CompareKernel for FoRVTable { } } -register_kernel!(CompareKernelAdapter(FoRVTable).lift()); - fn compare_constant( lhs: &FoRArray, mut rhs: T, diff --git a/encodings/fastlanes/src/for/vtable/kernels.rs b/encodings/fastlanes/src/for/vtable/kernels.rs index 60a009afe15..96e1f010178 100644 --- a/encodings/fastlanes/src/for/vtable/kernels.rs +++ b/encodings/fastlanes/src/for/vtable/kernels.rs @@ -2,9 +2,12 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_array::arrays::TakeExecuteAdaptor; +use vortex_array::expr::CompareExecuteAdaptor; use vortex_array::kernel::ParentKernelSet; use crate::FoRVTable; -pub(crate) const PARENT_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(FoRVTable))]); +pub(crate) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ + ParentKernelSet::lift(&CompareExecuteAdaptor(FoRVTable)), + ParentKernelSet::lift(&TakeExecuteAdaptor(FoRVTable)), +]); diff --git a/encodings/fsst/benches/fsst_compress.rs b/encodings/fsst/benches/fsst_compress.rs index c0dd99db5f0..d5f8660be8c 100644 --- a/encodings/fsst/benches/fsst_compress.rs +++ b/encodings/fsst/benches/fsst_compress.rs @@ -10,6 +10,8 @@ use rand::Rng; use rand::SeedableRng; use rand::rngs::StdRng; use vortex_array::IntoArray; +use vortex_array::LEGACY_SESSION; +use vortex_array::RecursiveCanonical; use vortex_array::VortexSessionExecute; use vortex_array::arrays::ChunkedArray; use vortex_array::arrays::ConstantArray; @@ -87,9 +89,18 @@ fn pushdown_compare(bencher: Bencher, (string_count, avg_len, unique_chars): (us let constant = ConstantArray::new(Scalar::from(&b"const"[..]), array.len()); bencher - .with_inputs(|| (&fsst_array, &constant)) - .bench_refs(|(fsst_array, constant)| { - compare(fsst_array.as_ref(), constant.as_ref(), Operator::Eq).unwrap(); + .with_inputs(|| { + ( + &fsst_array, + &constant, + LEGACY_SESSION.create_execution_ctx(), + ) + }) + .bench_refs(|(fsst_array, constant, ctx)| { + compare(fsst_array.as_ref(), constant.as_ref(), Operator::Eq) + .unwrap() + .execute::(ctx) + .unwrap(); }) } @@ -104,14 +115,22 @@ fn canonicalize_compare( let constant = ConstantArray::new(Scalar::from(&b"const"[..]), array.len()); bencher - .with_inputs(|| (&fsst_array, &constant)) - .bench_refs(|(fsst_array, constant)| { + .with_inputs(|| { + ( + &fsst_array, + &constant, + LEGACY_SESSION.create_execution_ctx(), + ) + }) + .bench_refs(|(fsst_array, constant, ctx)| { compare( fsst_array.to_canonical().unwrap().as_ref(), constant.as_ref(), Operator::Eq, ) .unwrap() + .execute::(ctx) + .unwrap(); }); } @@ -135,14 +154,16 @@ fn chunked_canonicalize_into( ) { let array = generate_chunked_test_data(chunk_size, string_count, avg_len, unique_chars); - bencher.with_inputs(|| &array).bench_refs(|array| { - let mut builder = - VarBinViewBuilder::with_capacity(DType::Binary(Nullability::NonNullable), array.len()); - array - .append_to_builder(&mut builder, &mut SESSION.create_execution_ctx()) - .unwrap(); - builder.finish() - }); + bencher + .with_inputs(|| (&array, SESSION.create_execution_ctx())) + .bench_refs(|(array, ctx)| { + let mut builder = VarBinViewBuilder::with_capacity( + DType::Binary(Nullability::NonNullable), + array.len(), + ); + array.append_to_builder(&mut builder, ctx).unwrap(); + builder.finish() + }); } #[divan::bench(args = CHUNKED_BENCH_ARGS)] diff --git a/encodings/fsst/public-api.lock b/encodings/fsst/public-api.lock index dfefa5d92ec..00d030a5a23 100644 --- a/encodings/fsst/public-api.lock +++ b/encodings/fsst/public-api.lock @@ -104,9 +104,9 @@ impl vortex_array::arrays::slice::SliceReduce for vortex_fsst::FSSTVTable pub fn vortex_fsst::FSSTVTable::slice(array: &Self::Array, range: core::ops::range::Range) -> vortex_error::VortexResult> -impl vortex_array::compute::compare::CompareKernel for vortex_fsst::FSSTVTable +impl vortex_array::expr::exprs::binary::compare::CompareKernel for vortex_fsst::FSSTVTable -pub fn vortex_fsst::FSSTVTable::compare(&self, lhs: &vortex_fsst::FSSTArray, rhs: &dyn vortex_array::array::Array, operator: vortex_array::compute::compare::Operator) -> vortex_error::VortexResult> +pub fn vortex_fsst::FSSTVTable::compare(lhs: &vortex_fsst::FSSTArray, rhs: &dyn vortex_array::array::Array, operator: vortex_array::compute::compare::Operator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> impl vortex_array::expr::exprs::cast::kernel::CastReduce for vortex_fsst::FSSTVTable diff --git a/encodings/fsst/src/compute/compare.rs b/encodings/fsst/src/compute/compare.rs index 24ff9e9f3a1..7dc0c79f1de 100644 --- a/encodings/fsst/src/compute/compare.rs +++ b/encodings/fsst/src/compute/compare.rs @@ -3,16 +3,15 @@ use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::ToCanonical; use vortex_array::arrays::BoolArray; use vortex_array::arrays::ConstantArray; -use vortex_array::compute::CompareKernel; -use vortex_array::compute::CompareKernelAdapter; use vortex_array::compute::Operator; use vortex_array::compute::compare; use vortex_array::compute::compare_lengths_to_empty; -use vortex_array::register_kernel; +use vortex_array::expr::CompareKernel; use vortex_array::validity::Validity; use vortex_buffer::BitBuffer; use vortex_buffer::ByteBuffer; @@ -28,10 +27,10 @@ use crate::FSSTVTable; impl CompareKernel for FSSTVTable { fn compare( - &self, lhs: &FSSTArray, rhs: &dyn Array, operator: Operator, + _ctx: &mut ExecutionCtx, ) -> VortexResult> { match rhs.as_constant() { Some(constant) => compare_fsst_constant(lhs, &constant, operator), @@ -41,8 +40,6 @@ impl CompareKernel for FSSTVTable { } } -register_kernel!(CompareKernelAdapter(FSSTVTable).lift()); - /// Specialized compare function implementation used when performing against a constant fn compare_fsst_constant( left: &FSSTArray, diff --git a/encodings/fsst/src/kernel.rs b/encodings/fsst/src/kernel.rs index e3e11a1ed5e..d304e5dc653 100644 --- a/encodings/fsst/src/kernel.rs +++ b/encodings/fsst/src/kernel.rs @@ -3,11 +3,13 @@ use vortex_array::arrays::FilterExecuteAdaptor; use vortex_array::arrays::TakeExecuteAdaptor; +use vortex_array::expr::CompareExecuteAdaptor; use vortex_array::kernel::ParentKernelSet; use crate::FSSTVTable; pub(super) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ + ParentKernelSet::lift(&CompareExecuteAdaptor(FSSTVTable)), ParentKernelSet::lift(&FilterExecuteAdaptor(FSSTVTable)), ParentKernelSet::lift(&TakeExecuteAdaptor(FSSTVTable)), ]); diff --git a/encodings/runend/public-api.lock b/encodings/runend/public-api.lock index ea9bd0e14f6..8b146026548 100644 --- a/encodings/runend/public-api.lock +++ b/encodings/runend/public-api.lock @@ -122,10 +122,6 @@ impl vortex_array::arrays::filter::kernel::FilterKernel for vortex_runend::RunEn pub fn vortex_runend::RunEndVTable::filter(array: &vortex_runend::RunEndArray, mask: &vortex_mask::Mask, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> -impl vortex_array::compute::compare::CompareKernel for vortex_runend::RunEndVTable - -pub fn vortex_runend::RunEndVTable::compare(&self, lhs: &vortex_runend::RunEndArray, rhs: &dyn vortex_array::array::Array, operator: vortex_array::compute::compare::Operator) -> vortex_error::VortexResult> - impl vortex_array::compute::is_constant::IsConstantKernel for vortex_runend::RunEndVTable pub fn vortex_runend::RunEndVTable::is_constant(&self, array: &Self::Array, opts: &vortex_array::compute::is_constant::IsConstantOpts) -> vortex_error::VortexResult> @@ -140,6 +136,10 @@ impl vortex_array::compute::min_max::MinMaxKernel for vortex_runend::RunEndVTabl pub fn vortex_runend::RunEndVTable::min_max(&self, array: &vortex_runend::RunEndArray) -> vortex_error::VortexResult> +impl vortex_array::expr::exprs::binary::compare::CompareKernel for vortex_runend::RunEndVTable + +pub fn vortex_runend::RunEndVTable::compare(lhs: &vortex_runend::RunEndArray, rhs: &dyn vortex_array::array::Array, operator: vortex_array::compute::compare::Operator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> + impl vortex_array::expr::exprs::cast::kernel::CastReduce for vortex_runend::RunEndVTable pub fn vortex_runend::RunEndVTable::cast(array: &vortex_runend::RunEndArray, dtype: &vortex_dtype::dtype::DType) -> vortex_error::VortexResult> diff --git a/encodings/runend/src/compute/compare.rs b/encodings/runend/src/compute/compare.rs index 4161898af61..1d95bf11fa5 100644 --- a/encodings/runend/src/compute/compare.rs +++ b/encodings/runend/src/compute/compare.rs @@ -3,14 +3,13 @@ use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::ToCanonical; use vortex_array::arrays::ConstantArray; -use vortex_array::compute::CompareKernel; -use vortex_array::compute::CompareKernelAdapter; use vortex_array::compute::Operator; use vortex_array::compute::compare; -use vortex_array::register_kernel; +use vortex_array::expr::CompareKernel; use vortex_error::VortexResult; use crate::RunEndArray; @@ -19,10 +18,10 @@ use crate::compress::runend_decode_bools; impl CompareKernel for RunEndVTable { fn compare( - &self, lhs: &RunEndArray, rhs: &dyn Array, operator: Operator, + _ctx: &mut ExecutionCtx, ) -> VortexResult> { // If the RHS is constant, then we just need to compare against our encoded values. if let Some(const_scalar) = rhs.as_constant() { @@ -45,8 +44,6 @@ impl CompareKernel for RunEndVTable { } } -register_kernel!(CompareKernelAdapter(RunEndVTable).lift()); - #[cfg(test)] mod test { use vortex_array::IntoArray; diff --git a/encodings/runend/src/kernel.rs b/encodings/runend/src/kernel.rs index 4873d9e15b1..e74d8338f1a 100644 --- a/encodings/runend/src/kernel.rs +++ b/encodings/runend/src/kernel.rs @@ -11,6 +11,7 @@ use vortex_array::arrays::FilterExecuteAdaptor; use vortex_array::arrays::SliceArray; use vortex_array::arrays::SliceVTable; use vortex_array::arrays::TakeExecuteAdaptor; +use vortex_array::expr::CompareExecuteAdaptor; use vortex_array::kernel::ExecuteParentKernel; use vortex_array::kernel::ParentKernelSet; use vortex_error::VortexResult; @@ -20,6 +21,7 @@ use crate::RunEndVTable; use crate::compute::take_from::RunEndVTableTakeFrom; pub(super) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ + ParentKernelSet::lift(&CompareExecuteAdaptor(RunEndVTable)), ParentKernelSet::lift(&RunEndSliceKernel), ParentKernelSet::lift(&FilterExecuteAdaptor(RunEndVTable)), ParentKernelSet::lift(&TakeExecuteAdaptor(RunEndVTable)), diff --git a/encodings/sequence/public-api.lock b/encodings/sequence/public-api.lock index 3a8034b3fad..b939289103e 100644 --- a/encodings/sequence/public-api.lock +++ b/encodings/sequence/public-api.lock @@ -78,10 +78,6 @@ impl vortex_array::arrays::slice::SliceReduce for vortex_sequence::SequenceVTabl pub fn vortex_sequence::SequenceVTable::slice(array: &Self::Array, range: core::ops::range::Range) -> vortex_error::VortexResult> -impl vortex_array::compute::compare::CompareKernel for vortex_sequence::SequenceVTable - -pub fn vortex_sequence::SequenceVTable::compare(&self, lhs: &vortex_sequence::SequenceArray, rhs: &dyn vortex_array::array::Array, operator: vortex_array::compute::compare::Operator) -> vortex_error::VortexResult> - impl vortex_array::compute::is_sorted::IsSortedKernel for vortex_sequence::SequenceVTable pub fn vortex_sequence::SequenceVTable::is_sorted(&self, array: &vortex_sequence::SequenceArray) -> vortex_error::VortexResult> @@ -96,6 +92,10 @@ impl vortex_array::compute::min_max::MinMaxKernel for vortex_sequence::SequenceV pub fn vortex_sequence::SequenceVTable::min_max(&self, array: &vortex_sequence::SequenceArray) -> vortex_error::VortexResult> +impl vortex_array::expr::exprs::binary::compare::CompareKernel for vortex_sequence::SequenceVTable + +pub fn vortex_sequence::SequenceVTable::compare(lhs: &vortex_sequence::SequenceArray, rhs: &dyn vortex_array::array::Array, operator: vortex_array::compute::compare::Operator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> + impl vortex_array::expr::exprs::cast::kernel::CastReduce for vortex_sequence::SequenceVTable pub fn vortex_sequence::SequenceVTable::cast(array: &vortex_sequence::SequenceArray, dtype: &vortex_dtype::dtype::DType) -> vortex_error::VortexResult> diff --git a/encodings/sequence/src/compute/compare.rs b/encodings/sequence/src/compute/compare.rs index fc652216a15..a18b3fd6f7d 100644 --- a/encodings/sequence/src/compute/compare.rs +++ b/encodings/sequence/src/compute/compare.rs @@ -3,11 +3,11 @@ use vortex_array::Array; use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; use vortex_array::arrays::BoolArray; use vortex_array::arrays::ConstantArray; -use vortex_array::compute::CompareKernel; use vortex_array::compute::Operator; -use vortex_array::validity::Validity; +use vortex_array::expr::CompareKernel; use vortex_buffer::BitBuffer; use vortex_dtype::NativePType; use vortex_dtype::Nullability; @@ -22,14 +22,15 @@ use crate::array::SequenceVTable; impl CompareKernel for SequenceVTable { fn compare( - &self, lhs: &SequenceArray, rhs: &dyn Array, operator: Operator, + _ctx: &mut ExecutionCtx, ) -> VortexResult> { + // TODO(joe): support other operators (NotEq, Lt, Lte, Gt, Gte) in encoded space. if operator != Operator::Eq { return Ok(None); - }; + } let Some(constant) = rhs.as_constant() else { return Ok(None); @@ -43,13 +44,13 @@ impl CompareKernel for SequenceVTable { constant .as_primitive() .pvalue() - .vortex_expect("non-null constant"), + .vortex_expect("null constant handled in adaptor"), ); let nullability = lhs.dtype().nullability() | rhs.dtype().nullability(); let validity = match nullability { - Nullability::NonNullable => Validity::NonNullable, - Nullability::Nullable => Validity::AllValid, + Nullability::NonNullable => vortex_array::validity::Validity::NonNullable, + Nullability::Nullable => vortex_array::validity::Validity::AllValid, }; if let Some(set_idx) = set_idx { diff --git a/encodings/sequence/src/kernel.rs b/encodings/sequence/src/kernel.rs index 6dacc4d01ca..4217d3b1823 100644 --- a/encodings/sequence/src/kernel.rs +++ b/encodings/sequence/src/kernel.rs @@ -1,568 +1,15 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use vortex_array::ArrayRef; -use vortex_array::ExecutionCtx; -use vortex_array::IntoArray; -use vortex_array::arrays::BoolArray; -use vortex_array::arrays::ConstantArray; -use vortex_array::arrays::ConstantVTable; -use vortex_array::arrays::ExactScalarFn; use vortex_array::arrays::FilterExecuteAdaptor; -use vortex_array::arrays::ScalarFnArrayView; -use vortex_array::arrays::ScalarFnVTable; use vortex_array::arrays::TakeExecuteAdaptor; -use vortex_array::compute::Operator; -use vortex_array::expr::Binary; -use vortex_array::kernel::ExecuteParentKernel; +use vortex_array::expr::CompareExecuteAdaptor; use vortex_array::kernel::ParentKernelSet; -use vortex_buffer::bitbuffer; -use vortex_buffer::buffer; -use vortex_dtype::DType; -use vortex_dtype::NativePType; -use vortex_dtype::Nullability; -use vortex_dtype::match_each_integer_ptype; -use vortex_error::VortexExpect; -use vortex_error::VortexResult; -use vortex_runend::RunEndArray; -use vortex_scalar::PValue; -use vortex_scalar::Scalar; -use crate::SequenceArray; use crate::SequenceVTable; -use crate::compute::compare::find_intersection_scalar; pub(crate) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ - ParentKernelSet::lift(&SequenceCompareKernel), + ParentKernelSet::lift(&CompareExecuteAdaptor(SequenceVTable)), ParentKernelSet::lift(&FilterExecuteAdaptor(SequenceVTable)), ParentKernelSet::lift(&TakeExecuteAdaptor(SequenceVTable)), ]); - -/// Kernel to execute comparison operations directly on a sequence array. -#[derive(Debug)] -struct SequenceCompareKernel; - -impl ExecuteParentKernel for SequenceCompareKernel { - type Parent = ExactScalarFn; - - fn execute_parent( - &self, - array: &SequenceArray, - parent: ScalarFnArrayView<'_, Binary>, - child_idx: usize, - ctx: &mut ExecutionCtx, - ) -> VortexResult> { - // Only handle comparison operators - let Some(cmp_op) = parent.options.maybe_cmp_operator() else { - return Ok(None); - }; - - // Get the ScalarFnArray to access children - let Some(scalar_fn_array) = parent.as_opt::() else { - return Ok(None); - }; - let children = scalar_fn_array.children(); - - // Determine which operand is the constant and which is the sequence - let (cmp_op, constant) = match child_idx { - 0 => { - // sequence is lhs, check if rhs is constant - let rhs = &children[1]; - let Some(constant) = rhs.as_opt::() else { - return Ok(None); - }; - (cmp_op, constant) - } - 1 => { - // sequence is rhs, swap the operator and check if lhs is constant - let lhs = &children[0]; - let Some(constant) = lhs.as_opt::() else { - return Ok(None); - }; - // Swap the operator since we're reversing operand order - (cmp_op.swap(), constant) - } - _ => return Ok(None), - }; - - let constant_pvalue = constant.scalar().as_primitive().pvalue(); - let Some(constant_pvalue) = constant_pvalue else { - // Constant is null - result is all null for comparisons - let nullability = array.dtype().nullability() | constant.dtype().nullability(); - let result_array = - ConstantArray::new(Scalar::null(DType::Bool(nullability)), array.len).to_array(); - return Ok(Some(result_array)); - }; - - let nullability = array.dtype().nullability() | constant.dtype().nullability(); - - // For Eq and NotEq, use specialized logic - if cmp_op == Operator::Eq { - return compare_eq_neq(array, constant_pvalue, nullability, false, ctx); - } - if cmp_op == Operator::NotEq { - return compare_eq_neq(array, constant_pvalue, nullability, true, ctx); - } - - // For ordering comparisons, find the transition point - compare_ordering(array, constant_pvalue, cmp_op, nullability, ctx) - } -} - -/// Compare sequence to constant for equality/inequality. -/// When `negate` is false, returns true where sequence == constant. -/// When `negate` is true, returns true where sequence != constant. -fn compare_eq_neq( - array: &SequenceArray, - constant: PValue, - nullability: Nullability, - negate: bool, - _ctx: &mut ExecutionCtx, -) -> VortexResult> { - // For Eq: match_val=true, default_val=false - // For NotEq: match_val=false, default_val=true - let match_val = !negate; - let not_match_val = negate; - - // Check if there exists an integer solution to const = base + idx * multiplier - let Some(set_idx) = - find_intersection_scalar(array.base(), array.multiplier(), array.len, constant) - else { - return Ok(Some( - ConstantArray::new(Scalar::bool(not_match_val, nullability), array.len).into_array(), - )); - }; - let idx = set_idx as u64; - let len = array.len as u64; - - if len == 1 && set_idx == 0 { - let result_array = - ConstantArray::new(Scalar::bool(match_val, nullability), array.len).to_array(); - return Ok(Some(result_array)); - } - - let (ends, values) = if idx == 0 { - let ends = buffer![1u64, len].into_array(); - let values = - BoolArray::new(bitbuffer![match_val, not_match_val], nullability.into()).into_array(); - (ends, values) - } else if idx == len - 1 { - let ends = buffer![idx, len].into_array(); - let values = - BoolArray::new(bitbuffer![not_match_val, match_val], nullability.into()).into_array(); - (ends, values) - } else { - let ends = buffer![idx, idx + 1, len].into_array(); - let values = BoolArray::new( - bitbuffer![not_match_val, match_val, not_match_val], - nullability.into(), - ) - .into_array(); - (ends, values) - }; - Ok(Some(RunEndArray::try_new(ends, values)?.into_array())) -} - -fn compare_ordering( - array: &SequenceArray, - constant: PValue, - operator: Operator, - nullability: Nullability, - _ctx: &mut ExecutionCtx, -) -> VortexResult> { - let transition = find_transition_point( - array.base(), - array.multiplier(), - array.len, - constant, - operator, - ); - - let result_array = match transition { - Transition::AllTrue => { - ConstantArray::new(Scalar::bool(true, nullability), array.len).to_array() - } - Transition::AllFalse => { - ConstantArray::new(Scalar::bool(false, nullability), array.len).to_array() - } - Transition::FalseToTrue(idx) => { - // [0..idx) is false, [idx..len) is true - let ends = buffer![idx as u64, array.len as u64].into_array(); - let values = BoolArray::new(bitbuffer![false, true], nullability.into()).into_array(); - RunEndArray::try_new(ends, values)?.into_array() - } - Transition::TrueToFalse(idx) => { - // [0..idx) is true, [idx..len) is false - let ends = buffer![idx as u64, array.len as u64].into_array(); - let values = BoolArray::new(bitbuffer![true, false], nullability.into()).into_array(); - RunEndArray::try_new(ends, values)?.into_array() - } - }; - - Ok(Some(result_array)) -} - -enum Transition { - AllTrue, - AllFalse, - FalseToTrue(usize), - TrueToFalse(usize), -} - -fn find_transition_point( - base: PValue, - multiplier: PValue, - len: usize, - constant: PValue, - operator: Operator, -) -> Transition { - match_each_integer_ptype!(base.ptype(), |P| { - find_transition_point_typed::

( - base.cast::

(), - multiplier.cast::

(), - len, - constant.cast::

(), - operator, - ) - }) -} - -fn find_transition_point_typed( - base: P, - multiplier: P, - len: usize, - constant: P, - operator: Operator, -) -> Transition { - if len == 0 { - return Transition::AllFalse; - } - - let last_idx = P::from_usize(len - 1).vortex_expect("len must fit into type"); - let first_value = base; - let last_value = base + multiplier * last_idx; - - let first_result = eval_comparison(first_value, constant, operator); - let last_result = eval_comparison(last_value, constant, operator); - - if first_result && last_result { - return Transition::AllTrue; - } - if !first_result && !last_result { - return Transition::AllFalse; - } - - // There's a transition point - find it using binary search - let transition_idx = binary_search_transition(base, multiplier, len, constant, operator); - - if first_result { - Transition::TrueToFalse(transition_idx) - } else { - Transition::FalseToTrue(transition_idx) - } -} - -fn eval_comparison(lhs: P, rhs: P, operator: Operator) -> bool { - match operator { - Operator::Lt => lhs.is_lt(rhs), - Operator::Lte => lhs.is_le(rhs), - Operator::Gt => lhs.is_gt(rhs), - Operator::Gte => lhs.is_ge(rhs), - Operator::Eq => lhs.is_eq(rhs), - Operator::NotEq => !lhs.is_eq(rhs), - } -} - -fn binary_search_transition( - base: P, - multiplier: P, - len: usize, - constant: P, - operator: Operator, -) -> usize { - let first_result = eval_comparison(base, constant, operator); - - let mut lo = 0usize; - let mut hi = len; - - while lo < hi { - let mid = lo + (hi - lo) / 2; - let mid_p = P::from_usize(mid).vortex_expect("idx must fit into type"); - let value = base + multiplier * mid_p; - let result = eval_comparison(value, constant, operator); - - if result == first_result { - lo = mid + 1; - } else { - hi = mid; - } - } - - lo -} - -#[cfg(test)] -mod tests { - use vortex_array::ToCanonical; - use vortex_array::arrays::BoolArray; - use vortex_array::arrays::ConstantArray; - use vortex_array::arrays::ScalarFnArrayExt; - use vortex_array::assert_arrays_eq; - use vortex_array::expr::Binary; - use vortex_array::expr::Operator as ExprOperator; - use vortex_array::validity::Validity; - use vortex_buffer::BitBuffer; - use vortex_buffer::bitbuffer; - use vortex_dtype::DType; - use vortex_dtype::Nullability; - use vortex_dtype::Nullability::NonNullable; - use vortex_dtype::PType; - use vortex_error::VortexResult; - use vortex_scalar::Scalar; - - use crate::SequenceArray; - - #[test] - fn test_sequence_eq_neq_constant() -> VortexResult<()> { - let len = 1; - let seq = SequenceArray::typed_new(5i64, 1, NonNullable, len)?.to_array(); - let constant = ConstantArray::new(5i64, len).to_array(); - - let compare_array = - Binary.try_new_array(len, ExprOperator::NotEq, [seq.clone(), constant.clone()])?; - - let result = compare_array; - let bool_result = result.to_bool(); - - let expected = BitBuffer::from(vec![false]); - assert_eq!(bool_result.to_bit_buffer(), expected); - - let compare_array = Binary.try_new_array(len, ExprOperator::Eq, [seq, constant])?; - - let result = compare_array; - let bool_result = result.to_bool(); - - let expected = BitBuffer::from(vec![true]); - assert_eq!(bool_result.to_bit_buffer(), expected); - Ok(()) - } - - #[test] - fn test_sequence_gte_constant() -> VortexResult<()> { - let seq = SequenceArray::typed_new(0i64, 1, NonNullable, 10)?.to_array(); - let constant = ConstantArray::new( - Scalar::try_new( - DType::Primitive(PType::I64, Nullability::Nullable), - Some(5i64.into()), - ) - .unwrap(), - 10, - ) - .to_array(); - - let compare_array = Binary.try_new_array(10, ExprOperator::Gte, [seq, constant])?; - - let result = compare_array; - let bool_result = result.to_bool(); - - let expected = BoolArray::new( - bitbuffer![ - false, false, false, false, false, true, true, true, true, true, - ], - Validity::AllValid, - ); - assert_arrays_eq!(bool_result, expected); - Ok(()) - } - - #[test] - fn test_sequence_lt_constant() -> VortexResult<()> { - let seq = SequenceArray::typed_new(0i64, 1, NonNullable, 10)?.to_array(); - let constant = ConstantArray::new(5i64, 10).to_array(); - - let compare_array = Binary.try_new_array(10, ExprOperator::Lt, [seq, constant])?; - - let result = compare_array; - let bool_result = result.to_bool(); - - let expected = BitBuffer::from(vec![ - true, true, true, true, true, false, false, false, false, false, - ]); - assert_eq!(bool_result.to_bit_buffer(), expected); - Ok(()) - } - - #[test] - fn test_sequence_lte_constant() -> VortexResult<()> { - let seq = SequenceArray::typed_new(0i64, 1, NonNullable, 10)?.to_array(); - let constant = ConstantArray::new(5i64, 10).to_array(); - - let compare_array = Binary.try_new_array(10, ExprOperator::Lte, [seq, constant])?; - - let result = compare_array; - let bool_result = result.to_bool(); - - // [0,1,2,3,4,5,6,7,8,9] <= 5 - let expected = BitBuffer::from(vec![ - true, true, true, true, true, true, false, false, false, false, - ]); - assert_eq!(bool_result.to_bit_buffer(), expected); - Ok(()) - } - - #[test] - fn test_sequence_gt_constant() -> VortexResult<()> { - let seq = SequenceArray::typed_new(0i64, 1, NonNullable, 10)?.to_array(); - let constant = ConstantArray::new(5i64, 10).to_array(); - - let compare_array = Binary.try_new_array(10, ExprOperator::Gt, [seq, constant])?; - - let result = compare_array; - let bool_result = result.to_bool(); - - // [0,1,2,3,4,5,6,7,8,9] > 5 - let expected = BitBuffer::from(vec![ - false, false, false, false, false, false, true, true, true, true, - ]); - assert_eq!(bool_result.to_bit_buffer(), expected); - Ok(()) - } - - #[test] - fn test_constant_gte_sequence() -> VortexResult<()> { - // Test when constant is on the left side - let constant = ConstantArray::new(5i64, 10).to_array(); - let seq = SequenceArray::typed_new(0i64, 1, NonNullable, 10)?.to_array(); - - let compare_array = Binary.try_new_array(10, ExprOperator::Gte, [constant, seq])?; - - let result = compare_array; - let bool_result = result.to_bool(); - - // 5 >= [0,1,2,3,4,5,6,7,8,9] - let expected = BitBuffer::from(vec![ - true, true, true, true, true, true, false, false, false, false, - ]); - assert_eq!(bool_result.to_bit_buffer(), expected); - Ok(()) - } - - #[test] - fn test_sequence_eq_constant() -> VortexResult<()> { - let seq = SequenceArray::typed_new(0i64, 1, NonNullable, 10)?.to_array(); - let constant = ConstantArray::new(5i64, 10).to_array(); - - let compare_array = Binary.try_new_array(10, ExprOperator::Eq, [seq, constant])?; - - let result = compare_array; - let bool_result = result.to_bool(); - - let expected = BitBuffer::from(vec![ - false, false, false, false, false, true, false, false, false, false, - ]); - assert_eq!(bool_result.to_bit_buffer(), expected); - Ok(()) - } - - #[test] - fn test_sequence_not_eq_constant() -> VortexResult<()> { - let seq = SequenceArray::typed_new(0i64, 1, NonNullable, 10)?.to_array(); - let constant = ConstantArray::new(5i64, 10).to_array(); - - let compare_array = Binary.try_new_array(10, ExprOperator::NotEq, [seq, constant])?; - - let result = compare_array; - let bool_result = result.to_bool(); - - let expected = BitBuffer::from(vec![ - true, true, true, true, true, false, true, true, true, true, - ]); - assert_eq!(bool_result.to_bit_buffer(), expected); - Ok(()) - } - - #[test] - fn test_sequence_all_true() -> VortexResult<()> { - let seq = SequenceArray::typed_new(10i64, 1, NonNullable, 5)?.to_array(); - let constant = ConstantArray::new(5i64, 5).to_array(); - - let compare_array = Binary.try_new_array(5, ExprOperator::Gt, [seq, constant])?; - - let result = compare_array; - let bool_result = result.to_bool(); - - let expected = BitBuffer::from(vec![true, true, true, true, true]); - assert_eq!(bool_result.to_bit_buffer(), expected); - Ok(()) - } - - #[test] - fn test_sequence_all_false() -> VortexResult<()> { - let seq = SequenceArray::typed_new(0i64, 1, NonNullable, 5)?.to_array(); - let constant = ConstantArray::new(100i64, 5).to_array(); - - let compare_array = Binary.try_new_array(5, ExprOperator::Gt, [seq, constant])?; - - let result = compare_array; - let bool_result = result.to_bool(); - - let expected = BitBuffer::from(vec![false, false, false, false, false]); - assert_eq!(bool_result.to_bit_buffer(), expected); - Ok(()) - } - - #[test] - fn test_sequence_multiplier_2_gte() -> VortexResult<()> { - // Sequence: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] - let seq = SequenceArray::typed_new(0i64, 2, NonNullable, 10)?.to_array(); - let constant = ConstantArray::new(10i64, 10).to_array(); - - let compare_array = Binary.try_new_array(10, ExprOperator::Gte, [seq, constant])?; - - let result = compare_array; - let bool_result = result.to_bool(); - - // [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] >= 10 - let expected = BitBuffer::from(vec![ - false, false, false, false, false, true, true, true, true, true, - ]); - assert_eq!(bool_result.to_bit_buffer(), expected); - Ok(()) - } - - #[test] - fn test_sequence_multiplier_3_eq() -> VortexResult<()> { - // Sequence: [5, 8, 11, 14, 17, 20, 23, 26] - let seq = SequenceArray::typed_new(5i64, 3, NonNullable, 8)?.to_array(); - let constant = ConstantArray::new(14i64, 8).to_array(); - - let compare_array = Binary.try_new_array(8, ExprOperator::Eq, [seq, constant])?; - - let result = compare_array; - let bool_result = result.to_bool(); - - // 14 is at index 3: (14 - 5) / 3 = 3 - let expected = BitBuffer::from(vec![false, false, false, true, false, false, false, false]); - assert_eq!(bool_result.to_bit_buffer(), expected); - Ok(()) - } - - #[test] - fn test_sequence_negative_multiplier_lt() -> VortexResult<()> { - // Sequence: [100, 90, 80, 70, 60, 50, 40, 30, 20, 10] - let seq = SequenceArray::typed_new(100i64, -10, NonNullable, 10)?.to_array(); - let constant = ConstantArray::new(50i64, 10).to_array(); - - let compare_array = Binary.try_new_array(10, ExprOperator::Lt, [seq, constant])?; - - let result = compare_array; - let bool_result = result.to_bool(); - - // [100, 90, 80, 70, 60, 50, 40, 30, 20, 10] < 50 - let expected = BitBuffer::from(vec![ - false, false, false, false, false, false, true, true, true, true, - ]); - assert_eq!(bool_result.to_bit_buffer(), expected); - Ok(()) - } -} diff --git a/vortex-array/benches/compare.rs b/vortex-array/benches/compare.rs index b6830ebf9e2..7c20cff6f39 100644 --- a/vortex-array/benches/compare.rs +++ b/vortex-array/benches/compare.rs @@ -8,11 +8,14 @@ use rand::Rng; use rand::SeedableRng; use rand::distr::Uniform; use rand::prelude::StdRng; +use vortex_array::Canonical; use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; use vortex_array::arrays::BoolArray; use vortex_array::compute::Operator; use vortex_array::compute::compare; use vortex_buffer::Buffer; +use vortex_session::VortexSession; fn main() { divan::main(); @@ -27,10 +30,15 @@ fn compare_bool(bencher: Bencher) { let arr1 = BoolArray::from_iter((0..ARRAY_SIZE).map(|_| rng.sample(range) == 0)).into_array(); let arr2 = BoolArray::from_iter((0..ARRAY_SIZE).map(|_| rng.sample(range) == 0)).into_array(); + let session = VortexSession::empty(); bencher - .with_inputs(|| (&arr1, &arr2)) - .bench_refs(|(arr1, arr2)| compare(*arr1, *arr2, Operator::Gte).unwrap()); + .with_inputs(|| (&arr1, &arr2, session.create_execution_ctx())) + .bench_refs(|input| { + compare(input.0, input.1, Operator::Gte) + .unwrap() + .execute::(&mut input.2) + }); } #[divan::bench] @@ -47,8 +55,13 @@ fn compare_int(bencher: Bencher) { .map(|_| rng.sample(range)) .collect::>() .into_array(); + let session = VortexSession::empty(); bencher - .with_inputs(|| (&arr1, &arr2)) - .bench_refs(|(arr1, arr2)| compare(*arr1, *arr2, Operator::Gte).unwrap()); + .with_inputs(|| (&arr1, &arr2, session.create_execution_ctx())) + .bench_refs(|input| { + compare(input.0, input.1, Operator::Gte) + .unwrap() + .execute::(&mut input.2) + }); } diff --git a/vortex-array/benches/dict_compare.rs b/vortex-array/benches/dict_compare.rs index d4412967ea6..e5bdcede59c 100644 --- a/vortex-array/benches/dict_compare.rs +++ b/vortex-array/benches/dict_compare.rs @@ -6,6 +6,7 @@ use std::str::from_utf8; use vortex_array::Canonical; +use vortex_array::RecursiveCanonical; use vortex_array::VortexSessionExecute; use vortex_array::accessor::ArrayAccessor; use vortex_array::arrays::ConstantArray; @@ -50,15 +51,20 @@ fn bench_compare_primitive(bencher: divan::Bencher, (len, uniqueness): (usize, u let primitive_arr = gen_primitive_for_dict::(len, uniqueness); let dict = dict_encode(primitive_arr.as_ref()).unwrap(); let value = primitive_arr.as_slice::()[0]; + let session = VortexSession::empty(); - bencher.with_inputs(|| &dict).bench_refs(|dict| { - compare( - dict.as_ref(), - ConstantArray::new(value, len).as_ref(), - Operator::Eq, - ) - .unwrap() - }) + bencher + .with_inputs(|| (&dict, session.create_execution_ctx())) + .bench_refs(|(dict, ctx)| { + compare( + dict.as_ref(), + ConstantArray::new(value, len).as_ref(), + Operator::Eq, + ) + .unwrap() + .execute::(ctx) + .unwrap() + }) } #[divan::bench(args = LENGTH_AND_UNIQUE_VALUES)] @@ -67,15 +73,20 @@ fn bench_compare_varbin(bencher: divan::Bencher, (len, uniqueness): (usize, usiz let dict = dict_encode(varbin_arr.as_ref()).unwrap(); let bytes = varbin_arr.with_iterator(|i| i.next().unwrap().unwrap().to_vec()); let value = from_utf8(bytes.as_slice()).unwrap(); + let session = VortexSession::empty(); - bencher.with_inputs(|| &dict).bench_refs(|dict| { - compare( - dict.as_ref(), - ConstantArray::new(value, len).as_ref(), - Operator::Eq, - ) - .unwrap() - }) + bencher + .with_inputs(|| (&dict, session.create_execution_ctx())) + .bench_refs(|(dict, ctx)| { + compare( + dict.as_ref(), + ConstantArray::new(value, len).as_ref(), + Operator::Eq, + ) + .unwrap() + .execute::(ctx) + .unwrap() + }) } #[divan::bench(args = LENGTH_AND_UNIQUE_VALUES)] @@ -84,14 +95,20 @@ fn bench_compare_varbinview(bencher: divan::Bencher, (len, uniqueness): (usize, let dict = dict_encode(varbinview_arr.as_ref()).unwrap(); let bytes = varbinview_arr.with_iterator(|i| i.next().unwrap().unwrap().to_vec()); let value = from_utf8(bytes.as_slice()).unwrap(); - bencher.with_inputs(|| &dict).bench_refs(|dict| { - compare( - dict.as_ref(), - ConstantArray::new(value, len).as_ref(), - Operator::Eq, - ) - .unwrap() - }) + let session = VortexSession::empty(); + + bencher + .with_inputs(|| (&dict, session.create_execution_ctx())) + .bench_refs(|(dict, ctx)| { + compare( + dict.as_ref(), + ConstantArray::new(value, len).as_ref(), + Operator::Eq, + ) + .unwrap() + .execute::(ctx) + .unwrap() + }) } const CODES_AND_VALUES_LENGTHS: &[(usize, usize)] = &[ @@ -117,13 +134,14 @@ fn bench_compare_sliced_dict_primitive( let value = primitive_arr.as_slice::()[0]; let session = VortexSession::empty(); - bencher.with_inputs(|| &dict).bench_refs(|dict| { - let mut ctx = session.create_execution_ctx(); - dict.apply(&eq(root(), lit(value))) - .unwrap() - .execute::(&mut ctx) - .unwrap() - }) + bencher + .with_inputs(|| (&dict, session.create_execution_ctx())) + .bench_refs(|(dict, ctx)| { + dict.apply(&eq(root(), lit(value))) + .unwrap() + .execute::(ctx) + .unwrap() + }) } #[divan::bench(args = CODES_AND_VALUES_LENGTHS)] @@ -138,11 +156,12 @@ fn bench_compare_sliced_dict_varbinview( let value = from_utf8(bytes.as_slice()).unwrap(); let session = VortexSession::empty(); - bencher.with_inputs(|| &dict).bench_refs(|dict| { - let mut ctx = session.create_execution_ctx(); - dict.apply(&eq(root(), lit(value))) - .unwrap() - .execute::(&mut ctx) - .unwrap() - }) + bencher + .with_inputs(|| (&dict, session.create_execution_ctx())) + .bench_refs(|(dict, ctx)| { + dict.apply(&eq(root(), lit(value))) + .unwrap() + .execute::(ctx) + .unwrap() + }) } diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 4d897dad959..950242c2e61 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -144,10 +144,6 @@ impl vortex_array::arrays::TakeExecute for vortex_array::arrays::DictVTable pub fn vortex_array::arrays::DictVTable::take(array: &vortex_array::arrays::DictArray, indices: &dyn vortex_array::Array, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> -impl vortex_array::compute::CompareKernel for vortex_array::arrays::DictVTable - -pub fn vortex_array::arrays::DictVTable::compare(&self, lhs: &vortex_array::arrays::DictArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator) -> vortex_error::VortexResult> - impl vortex_array::compute::IsConstantKernel for vortex_array::arrays::DictVTable pub fn vortex_array::arrays::DictVTable::is_constant(&self, array: &vortex_array::arrays::DictArray, opts: &vortex_array::compute::IsConstantOpts) -> vortex_error::VortexResult> @@ -166,6 +162,10 @@ impl vortex_array::expr::CastReduce for vortex_array::arrays::DictVTable pub fn vortex_array::arrays::DictVTable::cast(array: &vortex_array::arrays::DictArray, dtype: &vortex_dtype::dtype::DType) -> vortex_error::VortexResult> +impl vortex_array::expr::CompareKernel for vortex_array::arrays::DictVTable + +pub fn vortex_array::arrays::DictVTable::compare(lhs: &vortex_array::arrays::DictArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + impl vortex_array::expr::FillNullKernel for vortex_array::arrays::DictVTable pub fn vortex_array::arrays::DictVTable::fill_null(array: &vortex_array::arrays::DictArray, fill_value: &vortex_scalar::scalar::Scalar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> @@ -620,10 +620,6 @@ impl vortex_array::arrays::TakeExecute for vortex_array::arrays::ChunkedVTable pub fn vortex_array::arrays::ChunkedVTable::take(array: &vortex_array::arrays::ChunkedArray, indices: &dyn vortex_array::Array, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> -impl vortex_array::compute::CompareKernel for vortex_array::arrays::ChunkedVTable - -pub fn vortex_array::arrays::ChunkedVTable::compare(&self, lhs: &vortex_array::arrays::ChunkedArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator) -> vortex_error::VortexResult> - impl vortex_array::compute::IsConstantKernel for vortex_array::arrays::ChunkedVTable pub fn vortex_array::arrays::ChunkedVTable::is_constant(&self, array: &vortex_array::arrays::ChunkedArray, opts: &vortex_array::compute::IsConstantOpts) -> vortex_error::VortexResult> @@ -790,10 +786,6 @@ impl vortex_array::arrays::TakeReduce for vortex_array::arrays::ConstantVTable pub fn vortex_array::arrays::ConstantVTable::take(array: &vortex_array::arrays::ConstantArray, indices: &dyn vortex_array::Array) -> vortex_error::VortexResult> -impl vortex_array::compute::CompareKernel for vortex_array::arrays::ConstantVTable - -pub fn vortex_array::arrays::ConstantVTable::compare(&self, lhs: &vortex_array::arrays::ConstantArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator) -> vortex_error::VortexResult> - impl vortex_array::compute::MaskKernel for vortex_array::arrays::ConstantVTable pub fn vortex_array::arrays::ConstantVTable::mask(&self, array: &vortex_array::arrays::ConstantArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult @@ -1206,10 +1198,6 @@ impl vortex_array::arrays::TakeExecute for vortex_array::arrays::DictVTable pub fn vortex_array::arrays::DictVTable::take(array: &vortex_array::arrays::DictArray, indices: &dyn vortex_array::Array, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> -impl vortex_array::compute::CompareKernel for vortex_array::arrays::DictVTable - -pub fn vortex_array::arrays::DictVTable::compare(&self, lhs: &vortex_array::arrays::DictArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator) -> vortex_error::VortexResult> - impl vortex_array::compute::IsConstantKernel for vortex_array::arrays::DictVTable pub fn vortex_array::arrays::DictVTable::is_constant(&self, array: &vortex_array::arrays::DictArray, opts: &vortex_array::compute::IsConstantOpts) -> vortex_error::VortexResult> @@ -1228,6 +1216,10 @@ impl vortex_array::expr::CastReduce for vortex_array::arrays::DictVTable pub fn vortex_array::arrays::DictVTable::cast(array: &vortex_array::arrays::DictArray, dtype: &vortex_dtype::dtype::DType) -> vortex_error::VortexResult> +impl vortex_array::expr::CompareKernel for vortex_array::arrays::DictVTable + +pub fn vortex_array::arrays::DictVTable::compare(lhs: &vortex_array::arrays::DictArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + impl vortex_array::expr::FillNullKernel for vortex_array::arrays::DictVTable pub fn vortex_array::arrays::DictVTable::fill_null(array: &vortex_array::arrays::DictArray, fill_value: &vortex_scalar::scalar::Scalar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> @@ -1396,10 +1388,6 @@ impl vortex_array::arrays::TakeExecute for vortex_array::arrays::ExtensionVTable pub fn vortex_array::arrays::ExtensionVTable::take(array: &vortex_array::arrays::ExtensionArray, indices: &dyn vortex_array::Array, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> -impl vortex_array::compute::CompareKernel for vortex_array::arrays::ExtensionVTable - -pub fn vortex_array::arrays::ExtensionVTable::compare(&self, lhs: &vortex_array::arrays::ExtensionArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator) -> vortex_error::VortexResult> - impl vortex_array::compute::IsConstantKernel for vortex_array::arrays::ExtensionVTable pub fn vortex_array::arrays::ExtensionVTable::is_constant(&self, array: &vortex_array::arrays::ExtensionArray, opts: &vortex_array::compute::IsConstantOpts) -> vortex_error::VortexResult> @@ -1426,6 +1414,10 @@ impl vortex_array::expr::CastReduce for vortex_array::arrays::ExtensionVTable pub fn vortex_array::arrays::ExtensionVTable::cast(array: &vortex_array::arrays::ExtensionArray, dtype: &vortex_dtype::dtype::DType) -> vortex_error::VortexResult> +impl vortex_array::expr::CompareKernel for vortex_array::arrays::ExtensionVTable + +pub fn vortex_array::arrays::ExtensionVTable::compare(lhs: &vortex_array::arrays::ExtensionArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + impl vortex_array::vtable::BaseArrayVTable for vortex_array::arrays::ExtensionVTable pub fn vortex_array::arrays::ExtensionVTable::array_eq(array: &vortex_array::arrays::ExtensionArray, other: &vortex_array::arrays::ExtensionArray, precision: vortex_array::Precision) -> bool @@ -2236,10 +2228,6 @@ impl vortex_array::arrays::TakeExecute for vortex_array::arrays::MaskedVTable pub fn vortex_array::arrays::MaskedVTable::take(array: &vortex_array::arrays::MaskedArray, indices: &dyn vortex_array::Array, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> -impl vortex_array::compute::CompareKernel for vortex_array::arrays::MaskedVTable - -pub fn vortex_array::arrays::MaskedVTable::compare(&self, lhs: &vortex_array::arrays::MaskedArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator) -> vortex_error::VortexResult> - impl vortex_array::compute::MaskKernel for vortex_array::arrays::MaskedVTable pub fn vortex_array::arrays::MaskedVTable::mask(&self, array: &vortex_array::arrays::MaskedArray, mask_arg: &vortex_mask::Mask) -> vortex_error::VortexResult @@ -3616,10 +3604,6 @@ impl vortex_array::arrays::TakeExecute for vortex_array::arrays::VarBinVTable pub fn vortex_array::arrays::VarBinVTable::take(array: &vortex_array::arrays::VarBinArray, indices: &dyn vortex_array::Array, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> -impl vortex_array::compute::CompareKernel for vortex_array::arrays::VarBinVTable - -pub fn vortex_array::arrays::VarBinVTable::compare(&self, lhs: &vortex_array::arrays::VarBinArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator) -> vortex_error::VortexResult> - impl vortex_array::compute::IsConstantKernel for vortex_array::arrays::VarBinVTable pub fn vortex_array::arrays::VarBinVTable::is_constant(&self, array: &vortex_array::arrays::VarBinArray, opts: &vortex_array::compute::IsConstantOpts) -> vortex_error::VortexResult> @@ -3642,6 +3626,10 @@ impl vortex_array::expr::CastReduce for vortex_array::arrays::VarBinVTable pub fn vortex_array::arrays::VarBinVTable::cast(array: &vortex_array::arrays::VarBinArray, dtype: &vortex_dtype::dtype::DType) -> vortex_error::VortexResult> +impl vortex_array::expr::CompareKernel for vortex_array::arrays::VarBinVTable + +pub fn vortex_array::arrays::VarBinVTable::compare(lhs: &vortex_array::arrays::VarBinArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + impl vortex_array::vtable::BaseArrayVTable for vortex_array::arrays::VarBinVTable pub fn vortex_array::arrays::VarBinVTable::array_eq(array: &vortex_array::arrays::VarBinArray, other: &vortex_array::arrays::VarBinArray, precision: vortex_array::Precision) -> bool @@ -5646,10 +5634,6 @@ impl core::marker::Copy for vortex_array::compute::Operator impl core::marker::StructuralPartialEq for vortex_array::compute::Operator -impl vortex_array::compute::Options for vortex_array::compute::Operator - -pub fn vortex_array::compute::Operator::as_any(&self) -> &dyn core::any::Any - pub enum vortex_array::compute::Output pub vortex_array::compute::Output::Array(vortex_array::ArrayRef) @@ -5810,24 +5794,6 @@ pub type vortex_array::expr::CastReduceAdaptor::Parent = vortex_array::arrays pub fn vortex_array::expr::CastReduceAdaptor::reduce_parent(&self, array: &::Array, parent: vortex_array::arrays::ScalarFnArrayView<'_, vortex_array::expr::Cast>, _child_idx: usize) -> vortex_error::VortexResult> -pub struct vortex_array::compute::CompareKernelAdapter(pub V) - -impl vortex_array::compute::CompareKernelAdapter - -pub const fn vortex_array::compute::CompareKernelAdapter::lift(&'static self) -> vortex_array::compute::CompareKernelRef - -impl core::fmt::Debug for vortex_array::compute::CompareKernelAdapter - -pub fn vortex_array::compute::CompareKernelAdapter::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result - -impl vortex_array::compute::Kernel for vortex_array::compute::CompareKernelAdapter - -pub fn vortex_array::compute::CompareKernelAdapter::invoke(&self, args: &vortex_array::compute::InvocationArgs<'_>) -> vortex_error::VortexResult> - -pub struct vortex_array::compute::CompareKernelRef(_) - -impl inventory::Collect for vortex_array::compute::CompareKernelRef - pub struct vortex_array::compute::ComputeFn impl vortex_array::compute::ComputeFn @@ -6240,34 +6206,6 @@ impl vortex_array::expr::CastReduce for vortex_array::arrays::VarBinViewVTable pub fn vortex_array::arrays::VarBinViewVTable::cast(array: &vortex_array::arrays::VarBinViewArray, dtype: &vortex_dtype::dtype::DType) -> vortex_error::VortexResult> -pub trait vortex_array::compute::CompareKernel: vortex_array::vtable::VTable - -pub fn vortex_array::compute::CompareKernel::compare(&self, lhs: &Self::Array, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator) -> vortex_error::VortexResult> - -impl vortex_array::compute::CompareKernel for vortex_array::arrays::ChunkedVTable - -pub fn vortex_array::arrays::ChunkedVTable::compare(&self, lhs: &vortex_array::arrays::ChunkedArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator) -> vortex_error::VortexResult> - -impl vortex_array::compute::CompareKernel for vortex_array::arrays::ConstantVTable - -pub fn vortex_array::arrays::ConstantVTable::compare(&self, lhs: &vortex_array::arrays::ConstantArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator) -> vortex_error::VortexResult> - -impl vortex_array::compute::CompareKernel for vortex_array::arrays::DictVTable - -pub fn vortex_array::arrays::DictVTable::compare(&self, lhs: &vortex_array::arrays::DictArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator) -> vortex_error::VortexResult> - -impl vortex_array::compute::CompareKernel for vortex_array::arrays::ExtensionVTable - -pub fn vortex_array::arrays::ExtensionVTable::compare(&self, lhs: &vortex_array::arrays::ExtensionArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator) -> vortex_error::VortexResult> - -impl vortex_array::compute::CompareKernel for vortex_array::arrays::MaskedVTable - -pub fn vortex_array::arrays::MaskedVTable::compare(&self, lhs: &vortex_array::arrays::MaskedArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator) -> vortex_error::VortexResult> - -impl vortex_array::compute::CompareKernel for vortex_array::arrays::VarBinVTable - -pub fn vortex_array::arrays::VarBinVTable::compare(&self, lhs: &vortex_array::arrays::VarBinArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator) -> vortex_error::VortexResult> - pub trait vortex_array::compute::ComputeFnVTable: 'static + core::marker::Send + core::marker::Sync pub fn vortex_array::compute::ComputeFnVTable::invoke(&self, args: &vortex_array::compute::InvocationArgs<'_>, kernels: &[arcref::ArcRef]) -> vortex_error::VortexResult @@ -6468,10 +6406,6 @@ impl vor pub fn vortex_array::compute::BetweenKernelAdapter::invoke(&self, args: &vortex_array::compute::InvocationArgs<'_>) -> vortex_error::VortexResult> -impl vortex_array::compute::Kernel for vortex_array::compute::CompareKernelAdapter - -pub fn vortex_array::compute::CompareKernelAdapter::invoke(&self, args: &vortex_array::compute::InvocationArgs<'_>) -> vortex_error::VortexResult> - impl vortex_array::compute::Kernel for vortex_array::compute::IsConstantKernelAdapter pub fn vortex_array::compute::IsConstantKernelAdapter::invoke(&self, args: &vortex_array::compute::InvocationArgs<'_>) -> vortex_error::VortexResult> @@ -6664,10 +6598,6 @@ impl vortex_array::compute::Options for vortex_array::compute::IsConstantOpts pub fn vortex_array::compute::IsConstantOpts::as_any(&self) -> &dyn core::any::Any -impl vortex_array::compute::Options for vortex_array::compute::Operator - -pub fn vortex_array::compute::Operator::as_any(&self) -> &dyn core::any::Any - impl vortex_array::compute::Options for vortex_scalar::typed_view::primitive::numeric_operator::NumericOperator pub fn vortex_scalar::typed_view::primitive::numeric_operator::NumericOperator::as_any(&self) -> &dyn core::any::Any @@ -7936,6 +7866,8 @@ pub fn vortex_array::expr::Operator::inverse(self) -> core::option::Option pub fn vortex_array::expr::Operator::is_arithmetic(&self) -> bool +pub fn vortex_array::expr::Operator::is_comparison(&self) -> bool + pub fn vortex_array::expr::Operator::logical_inverse(self) -> core::option::Option pub fn vortex_array::expr::Operator::maybe_cmp_operator(self) -> core::option::Option @@ -8126,6 +8058,22 @@ pub type vortex_array::expr::CastReduceAdaptor::Parent = vortex_array::arrays pub fn vortex_array::expr::CastReduceAdaptor::reduce_parent(&self, array: &::Array, parent: vortex_array::arrays::ScalarFnArrayView<'_, vortex_array::expr::Cast>, _child_idx: usize) -> vortex_error::VortexResult> +pub struct vortex_array::expr::CompareExecuteAdaptor(pub V) + +impl core::default::Default for vortex_array::expr::CompareExecuteAdaptor + +pub fn vortex_array::expr::CompareExecuteAdaptor::default() -> vortex_array::expr::CompareExecuteAdaptor + +impl core::fmt::Debug for vortex_array::expr::CompareExecuteAdaptor + +pub fn vortex_array::expr::CompareExecuteAdaptor::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl vortex_array::kernel::ExecuteParentKernel for vortex_array::expr::CompareExecuteAdaptor where V: vortex_array::expr::CompareKernel + +pub type vortex_array::expr::CompareExecuteAdaptor::Parent = vortex_array::arrays::ExactScalarFn + +pub fn vortex_array::expr::CompareExecuteAdaptor::execute_parent(&self, array: &::Array, parent: vortex_array::arrays::ScalarFnArrayView<'_, vortex_array::expr::Binary>, child_idx: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + pub struct vortex_array::expr::DynamicComparison impl vortex_array::expr::VTable for vortex_array::expr::DynamicComparison @@ -9136,6 +9084,22 @@ impl vortex_array::expr::CastReduce for vortex_array::arrays::VarBinViewVTable pub fn vortex_array::arrays::VarBinViewVTable::cast(array: &vortex_array::arrays::VarBinViewArray, dtype: &vortex_dtype::dtype::DType) -> vortex_error::VortexResult> +pub trait vortex_array::expr::CompareKernel: vortex_array::vtable::VTable + +pub fn vortex_array::expr::CompareKernel::compare(lhs: &Self::Array, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + +impl vortex_array::expr::CompareKernel for vortex_array::arrays::DictVTable + +pub fn vortex_array::arrays::DictVTable::compare(lhs: &vortex_array::arrays::DictArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + +impl vortex_array::expr::CompareKernel for vortex_array::arrays::ExtensionVTable + +pub fn vortex_array::arrays::ExtensionVTable::compare(lhs: &vortex_array::arrays::ExtensionArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + +impl vortex_array::expr::CompareKernel for vortex_array::arrays::VarBinVTable + +pub fn vortex_array::arrays::VarBinVTable::compare(lhs: &vortex_array::arrays::VarBinArray, rhs: &dyn vortex_array::Array, operator: vortex_array::compute::Operator, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + pub trait vortex_array::expr::DynExprVTable: 'static + core::marker::Send + core::marker::Sync + vortex_array::expr::vtable::private::Sealed pub fn vortex_array::expr::DynExprVTable::arity(&self, options: &dyn core::any::Any) -> vortex_array::expr::Arity @@ -10072,6 +10036,12 @@ pub type vortex_array::expr::CastExecuteAdaptor::Parent = vortex_array::array pub fn vortex_array::expr::CastExecuteAdaptor::execute_parent(&self, array: &::Array, parent: ::Match, _child_idx: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +impl vortex_array::kernel::ExecuteParentKernel for vortex_array::expr::CompareExecuteAdaptor where V: vortex_array::expr::CompareKernel + +pub type vortex_array::expr::CompareExecuteAdaptor::Parent = vortex_array::arrays::ExactScalarFn + +pub fn vortex_array::expr::CompareExecuteAdaptor::execute_parent(&self, array: &::Array, parent: vortex_array::arrays::ScalarFnArrayView<'_, vortex_array::expr::Binary>, child_idx: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + impl vortex_array::kernel::ExecuteParentKernel for vortex_array::expr::FillNullExecuteAdaptor where V: vortex_array::expr::FillNullKernel pub type vortex_array::expr::FillNullExecuteAdaptor::Parent = vortex_array::arrays::ExactScalarFn @@ -10966,7 +10936,7 @@ pub fn vortex_array::validity::Validity::all_invalid(&self, len: usize) -> vorte pub fn vortex_array::validity::Validity::all_valid(&self, len: usize) -> vortex_error::VortexResult -pub fn vortex_array::validity::Validity::and(self, rhs: vortex_array::validity::Validity) -> vortex_array::validity::Validity +pub fn vortex_array::validity::Validity::and(self, rhs: vortex_array::validity::Validity) -> vortex_error::VortexResult pub fn vortex_array::validity::Validity::as_array(&self) -> core::option::Option<&vortex_array::ArrayRef> diff --git a/vortex-array/src/arrays/bool/compute/rules.rs b/vortex-array/src/arrays/bool/compute/rules.rs index d885804f89c..ab3f5ef50b9 100644 --- a/vortex-array/src/arrays/bool/compute/rules.rs +++ b/vortex-array/src/arrays/bool/compute/rules.rs @@ -46,7 +46,7 @@ impl ArrayParentReduceRule for BoolMaskedValidityRule { Ok(Some( BoolArray::new( array.to_bit_buffer(), - array.validity().clone().and(parent.validity().clone()), + array.validity().clone().and(parent.validity().clone())?, ) .into_array(), )) diff --git a/vortex-array/src/arrays/chunked/compute/compare.rs b/vortex-array/src/arrays/chunked/compute/compare.rs deleted file mode 100644 index d005527e5cb..00000000000 --- a/vortex-array/src/arrays/chunked/compute/compare.rs +++ /dev/null @@ -1,65 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_error::VortexResult; - -use crate::Array; -use crate::ArrayRef; -use crate::arrays::ChunkedArray; -use crate::arrays::ChunkedVTable; -use crate::builders::ArrayBuilder; -use crate::builders::BoolBuilder; -use crate::compute::CompareKernel; -use crate::compute::CompareKernelAdapter; -use crate::compute::Operator; -use crate::compute::compare; -use crate::register_kernel; - -impl CompareKernel for ChunkedVTable { - fn compare( - &self, - lhs: &ChunkedArray, - rhs: &dyn Array, - operator: Operator, - ) -> VortexResult> { - let mut idx = 0; - - let mut bool_builder = BoolBuilder::with_capacity( - // nullable <= non-nullable - (lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into(), - lhs.len(), - ); - - for chunk in lhs.non_empty_chunks() { - let sliced = rhs.slice(idx..idx + chunk.len())?; - let cmp_result = compare(chunk, &sliced, operator)?; - - bool_builder.extend_from_array(&cmp_result); - idx += chunk.len(); - } - - Ok(Some(bool_builder.finish())) - } -} - -register_kernel!(CompareKernelAdapter(ChunkedVTable).lift()); - -#[cfg(test)] -mod tests { - use vortex_buffer::Buffer; - - use super::*; - use crate::IntoArray; - - #[test] - fn empty_compare() { - let base = Buffer::::empty().into_array(); - let chunked = - ChunkedArray::try_new(vec![base.clone(), base.clone()], base.dtype().clone()).unwrap(); - let chunked_empty = ChunkedArray::try_new(vec![], base.dtype().clone()).unwrap(); - - let r = compare(chunked.as_ref(), chunked_empty.as_ref(), Operator::Eq).unwrap(); - - assert!(r.is_empty()); - } -} diff --git a/vortex-array/src/arrays/chunked/compute/mod.rs b/vortex-array/src/arrays/chunked/compute/mod.rs index 4212e0a2025..f4e954c88cc 100644 --- a/vortex-array/src/arrays/chunked/compute/mod.rs +++ b/vortex-array/src/arrays/chunked/compute/mod.rs @@ -2,7 +2,6 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors mod cast; -mod compare; mod fill_null; mod filter; mod is_constant; diff --git a/vortex-array/src/arrays/constant/compute/compare.rs b/vortex-array/src/arrays/constant/compute/compare.rs deleted file mode 100644 index 078fb27ef75..00000000000 --- a/vortex-array/src/arrays/constant/compute/compare.rs +++ /dev/null @@ -1,36 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_error::VortexResult; - -use crate::Array; -use crate::ArrayRef; -use crate::IntoArray; -use crate::arrays::ConstantArray; -use crate::arrays::ConstantVTable; -use crate::compute::CompareKernel; -use crate::compute::CompareKernelAdapter; -use crate::compute::Operator; -use crate::compute::scalar_cmp; -use crate::register_kernel; - -impl CompareKernel for ConstantVTable { - fn compare( - &self, - lhs: &ConstantArray, - rhs: &dyn Array, - operator: Operator, - ) -> VortexResult> { - // We only support comparing a constant array to another constant array. - // For all other encodings, we assume the constant is on the RHS. - if let Some(const_scalar) = rhs.as_constant() { - let lhs_scalar = lhs.scalar(); - let scalar = scalar_cmp(lhs_scalar, &const_scalar, operator); - return Ok(Some(ConstantArray::new(scalar, lhs.len()).into_array())); - } - - Ok(None) - } -} - -register_kernel!(CompareKernelAdapter(ConstantVTable).lift()); diff --git a/vortex-array/src/arrays/constant/compute/mod.rs b/vortex-array/src/arrays/constant/compute/mod.rs index 96e21b0efd1..2384df41a88 100644 --- a/vortex-array/src/arrays/constant/compute/mod.rs +++ b/vortex-array/src/arrays/constant/compute/mod.rs @@ -2,7 +2,6 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors mod cast; -mod compare; mod fill_null; mod filter; mod mask; diff --git a/vortex-array/src/arrays/decimal/compute/rules.rs b/vortex-array/src/arrays/decimal/compute/rules.rs index f60e2f7519c..70130bf3e16 100644 --- a/vortex-array/src/arrays/decimal/compute/rules.rs +++ b/vortex-array/src/arrays/decimal/compute/rules.rs @@ -48,7 +48,7 @@ impl ArrayParentReduceRule for DecimalMaskedValidityRule { DecimalArray::new_unchecked( array.buffer::(), array.decimal_dtype(), - array.validity().clone().and(parent.validity().clone()), + array.validity().clone().and(parent.validity().clone())?, ) } .into_array() diff --git a/vortex-array/src/arrays/dict/compute/compare.rs b/vortex-array/src/arrays/dict/compute/compare.rs index ed5b3dd45f7..2fbba2c29e1 100644 --- a/vortex-array/src/arrays/dict/compute/compare.rs +++ b/vortex-array/src/arrays/dict/compute/compare.rs @@ -7,20 +7,19 @@ use super::DictArray; use super::DictVTable; use crate::Array; use crate::ArrayRef; +use crate::ExecutionCtx; use crate::IntoArray; use crate::arrays::ConstantArray; -use crate::compute::CompareKernel; -use crate::compute::CompareKernelAdapter; use crate::compute::Operator; use crate::compute::compare; -use crate::register_kernel; +use crate::expr::CompareKernel; impl CompareKernel for DictVTable { fn compare( - &self, lhs: &DictArray, rhs: &dyn Array, operator: Operator, + _ctx: &mut ExecutionCtx, ) -> VortexResult> { // if we have more values than codes, it is faster to canonicalise first. if lhs.values().len() > lhs.codes().len() { @@ -51,108 +50,3 @@ impl CompareKernel for DictVTable { Ok(None) } } - -register_kernel!(CompareKernelAdapter(DictVTable).lift()); -#[cfg(test)] -mod tests { - use vortex_buffer::buffer; - use vortex_dtype::Nullability; - use vortex_mask::Mask; - use vortex_scalar::Scalar; - - use crate::IntoArray; - use crate::arrays::BoolArray; - use crate::arrays::ConstantArray; - use crate::arrays::PrimitiveArray; - use crate::arrays::dict::DictArray; - use crate::assert_arrays_eq; - use crate::compute::Operator; - use crate::compute::compare; - use crate::validity::Validity; - - #[test] - fn test_compare_value() { - let dict = DictArray::try_new( - buffer![0u32, 1, 2].into_array(), - buffer![1i32, 2, 3].into_array(), - ) - .unwrap(); - - let res = compare( - dict.as_ref(), - ConstantArray::new(Scalar::from(1i32), 3).as_ref(), - Operator::Eq, - ) - .unwrap(); - assert_arrays_eq!(res, BoolArray::from_iter([true, false, false])); - } - - #[test] - fn test_compare_non_eq() { - let dict = DictArray::try_new( - buffer![0u32, 1, 2].into_array(), - buffer![1i32, 2, 3].into_array(), - ) - .unwrap(); - - let res = compare( - dict.as_ref(), - ConstantArray::new(Scalar::from(1i32), 3).as_ref(), - Operator::Gt, - ) - .unwrap(); - assert_arrays_eq!(res, BoolArray::from_iter([false, true, true])); - } - - #[test] - fn test_compare_nullable() { - let dict = DictArray::try_new( - PrimitiveArray::new( - buffer![0u32, 1, 2], - Validity::from_iter([false, true, false]), - ) - .into_array(), - PrimitiveArray::new(buffer![1i32, 2, 3], Validity::AllValid).into_array(), - ) - .unwrap(); - - let res = compare( - dict.as_ref(), - ConstantArray::new(Scalar::primitive(4i32, Nullability::Nullable), 3).as_ref(), - Operator::Eq, - ) - .unwrap(); - assert_arrays_eq!(res, BoolArray::from_iter([None, Some(false), None])); - assert_eq!(res.dtype().nullability(), Nullability::Nullable); - assert_eq!( - res.validity_mask().unwrap(), - Mask::from_iter([false, true, false]) - ); - } - - #[test] - fn test_compare_null_values() { - let dict = DictArray::try_new( - buffer![0u32, 1, 2].into_array(), - PrimitiveArray::new( - buffer![1i32, 2, 0], - Validity::from_iter([true, true, false]), - ) - .into_array(), - ) - .unwrap(); - - let res = compare( - dict.as_ref(), - ConstantArray::new(Scalar::primitive(4i32, Nullability::NonNullable), 3).as_ref(), - Operator::Eq, - ) - .unwrap(); - assert_arrays_eq!(res, BoolArray::from_iter([Some(false), Some(false), None])); - assert_eq!(res.dtype().nullability(), Nullability::Nullable); - assert_eq!( - res.validity_mask().unwrap(), - Mask::from_iter([true, true, false]) - ); - } -} diff --git a/vortex-array/src/arrays/dict/vtable/kernel.rs b/vortex-array/src/arrays/dict/vtable/kernel.rs index 32a9c328532..5ab633d0b21 100644 --- a/vortex-array/src/arrays/dict/vtable/kernel.rs +++ b/vortex-array/src/arrays/dict/vtable/kernel.rs @@ -3,10 +3,12 @@ use crate::arrays::DictVTable; use crate::arrays::TakeExecuteAdaptor; +use crate::expr::CompareExecuteAdaptor; use crate::expr::FillNullExecuteAdaptor; use crate::kernel::ParentKernelSet; pub(super) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ + ParentKernelSet::lift(&CompareExecuteAdaptor(DictVTable)), ParentKernelSet::lift(&TakeExecuteAdaptor(DictVTable)), ParentKernelSet::lift(&FillNullExecuteAdaptor(DictVTable)), ]); diff --git a/vortex-array/src/arrays/extension/compute/compare.rs b/vortex-array/src/arrays/extension/compute/compare.rs index fd3425937e0..aac0226bae2 100644 --- a/vortex-array/src/arrays/extension/compute/compare.rs +++ b/vortex-array/src/arrays/extension/compute/compare.rs @@ -5,21 +5,20 @@ use vortex_error::VortexResult; use crate::Array; use crate::ArrayRef; +use crate::ExecutionCtx; use crate::arrays::ConstantArray; use crate::arrays::ExtensionArray; use crate::arrays::ExtensionVTable; -use crate::compute::CompareKernel; -use crate::compute::CompareKernelAdapter; +use crate::compute; use crate::compute::Operator; -use crate::compute::{self}; -use crate::register_kernel; +use crate::expr::CompareKernel; impl CompareKernel for ExtensionVTable { fn compare( - &self, lhs: &ExtensionArray, rhs: &dyn Array, operator: Operator, + _ctx: &mut ExecutionCtx, ) -> VortexResult> { // If the RHS is a constant, we can extract the storage scalar. if let Some(const_ext) = rhs.as_constant() { @@ -41,5 +40,3 @@ impl CompareKernel for ExtensionVTable { Ok(None) } } - -register_kernel!(CompareKernelAdapter(ExtensionVTable).lift()); diff --git a/vortex-array/src/arrays/extension/vtable/kernel.rs b/vortex-array/src/arrays/extension/vtable/kernel.rs index 41f4f8a0f9a..a5c07cbf2f1 100644 --- a/vortex-array/src/arrays/extension/vtable/kernel.rs +++ b/vortex-array/src/arrays/extension/vtable/kernel.rs @@ -3,7 +3,10 @@ use crate::arrays::ExtensionVTable; use crate::arrays::TakeExecuteAdaptor; +use crate::expr::CompareExecuteAdaptor; use crate::kernel::ParentKernelSet; -pub(super) const PARENT_KERNELS: ParentKernelSet = - ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(ExtensionVTable))]); +pub(super) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ + ParentKernelSet::lift(&CompareExecuteAdaptor(ExtensionVTable)), + ParentKernelSet::lift(&TakeExecuteAdaptor(ExtensionVTable)), +]); diff --git a/vortex-array/src/arrays/masked/compute/compare.rs b/vortex-array/src/arrays/masked/compute/compare.rs deleted file mode 100644 index 0a9f86aa245..00000000000 --- a/vortex-array/src/arrays/masked/compute/compare.rs +++ /dev/null @@ -1,143 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_error::VortexResult; - -use crate::Array; -use crate::ArrayRef; -use crate::IntoArray; -use crate::arrays::BoolArray; -use crate::arrays::MaskedArray; -use crate::arrays::MaskedVTable; -use crate::canonical::ToCanonical; -use crate::compute::CompareKernel; -use crate::compute::CompareKernelAdapter; -use crate::compute::Operator; -use crate::compute::compare; -use crate::register_kernel; -use crate::vtable::ValidityHelper; - -impl CompareKernel for MaskedVTable { - fn compare( - &self, - lhs: &MaskedArray, - rhs: &dyn Array, - operator: Operator, - ) -> VortexResult> { - // Compare the child arrays - let compare_result = compare(&lhs.child, rhs, operator)?; - - // Get the boolean buffer from the comparison result - let bool_array = compare_result.to_bool(); - let combined_validity = bool_array.validity().clone().and(lhs.validity().clone()); - - // Return a plain BoolArray with the combined validity - Ok(Some( - BoolArray::new(bool_array.to_bit_buffer(), combined_validity).into_array(), - )) - } -} - -register_kernel!(CompareKernelAdapter(MaskedVTable).lift()); - -#[cfg(test)] -mod tests { - use vortex_dtype::Nullability; - use vortex_mask::Mask; - use vortex_scalar::Scalar; - - use crate::IntoArray; - use crate::arrays::BoolArray; - use crate::arrays::ConstantArray; - use crate::arrays::MaskedArray; - use crate::arrays::PrimitiveArray; - use crate::assert_arrays_eq; - use crate::compute::Operator; - use crate::compute::compare; - use crate::validity::Validity; - - #[test] - fn test_compare_value() { - let masked = MaskedArray::try_new( - PrimitiveArray::from_iter([1i32, 2, 3]).into_array(), - Validity::AllValid, - ) - .unwrap(); - - let res = compare( - masked.as_ref(), - ConstantArray::new(Scalar::from(2i32), 3).as_ref(), - Operator::Eq, - ) - .unwrap(); - assert_arrays_eq!( - res, - BoolArray::from_iter([Some(false), Some(true), Some(false)]) - ); - } - - #[test] - fn test_compare_non_eq() { - let masked = MaskedArray::try_new( - PrimitiveArray::from_iter([1i32, 2, 3]).into_array(), - Validity::AllValid, - ) - .unwrap(); - - let res = compare( - masked.as_ref(), - ConstantArray::new(Scalar::from(2i32), 3).as_ref(), - Operator::Gt, - ) - .unwrap(); - assert_arrays_eq!( - res, - BoolArray::from_iter([Some(false), Some(false), Some(true)]) - ); - } - - #[test] - fn test_compare_nullable() { - // MaskedArray with nulls - let masked = MaskedArray::try_new( - PrimitiveArray::from_iter([1i32, 2, 3]).into_array(), - Validity::from_iter([false, true, false]), - ) - .unwrap(); - - let res = compare( - masked.as_ref(), - ConstantArray::new(Scalar::primitive(2i32, Nullability::Nullable), 3).as_ref(), - Operator::Eq, - ) - .unwrap(); - assert_arrays_eq!(res, BoolArray::from_iter([None, Some(true), None])); - assert_eq!(res.dtype().nullability(), Nullability::Nullable); - assert_eq!( - res.validity_mask().unwrap(), - Mask::from_iter([false, true, false]) - ); - } - - #[test] - fn test_compare_with_null_rhs() { - // MaskedArray with some nulls - let masked = MaskedArray::try_new( - PrimitiveArray::from_iter([1i32, 2, 3]).into_array(), - Validity::from_iter([true, true, false]), - ) - .unwrap(); - - // RHS has a null value - let rhs = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]); - - let res = compare(masked.as_ref(), rhs.as_ref(), Operator::Eq).unwrap(); - assert_arrays_eq!(res, BoolArray::from_iter([Some(true), None, None])); - assert_eq!(res.dtype().nullability(), Nullability::Nullable); - // Validity is union of both: lhs=[T,T,F], rhs=[T,F,T] => result=[T,F,F] - assert_eq!( - res.validity_mask().unwrap(), - Mask::from_iter([true, false, false]) - ); - } -} diff --git a/vortex-array/src/arrays/masked/compute/mod.rs b/vortex-array/src/arrays/masked/compute/mod.rs index 86840ca4e2a..404ced01ea3 100644 --- a/vortex-array/src/arrays/masked/compute/mod.rs +++ b/vortex-array/src/arrays/masked/compute/mod.rs @@ -1,7 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -mod compare; mod filter; mod mask; pub(crate) mod rules; diff --git a/vortex-array/src/arrays/masked/vtable/kernel.rs b/vortex-array/src/arrays/masked/vtable/kernel.rs index 869e8601aba..cacfbd98cd4 100644 --- a/vortex-array/src/arrays/masked/vtable/kernel.rs +++ b/vortex-array/src/arrays/masked/vtable/kernel.rs @@ -5,5 +5,6 @@ use crate::arrays::MaskedVTable; use crate::arrays::TakeExecuteAdaptor; use crate::kernel::ParentKernelSet; +// TODO(joe): add CompareExecuteAdaptor to push comparisons through the mask without canonicalizing. pub(super) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(MaskedVTable))]); diff --git a/vortex-array/src/arrays/primitive/compute/rules.rs b/vortex-array/src/arrays/primitive/compute/rules.rs index e2bcc53327f..9bf20fd5cf6 100644 --- a/vortex-array/src/arrays/primitive/compute/rules.rs +++ b/vortex-array/src/arrays/primitive/compute/rules.rs @@ -44,7 +44,7 @@ impl ArrayParentReduceRule for PrimitiveMaskedValidityRule { PrimitiveArray::new_unchecked_from_handle( array.buffer_handle().clone(), array.ptype(), - array.validity().clone().and(parent.validity().clone()), + array.validity().clone().and(parent.validity().clone())?, ) } .into_array() diff --git a/vortex-array/src/arrays/varbin/compute/compare.rs b/vortex-array/src/arrays/varbin/compute/compare.rs index 7dd9b9d48e1..50491952580 100644 --- a/vortex-array/src/arrays/varbin/compute/compare.rs +++ b/vortex-array/src/arrays/varbin/compute/compare.rs @@ -16,6 +16,7 @@ use vortex_error::vortex_err; use crate::Array; use crate::ArrayRef; +use crate::ExecutionCtx; use crate::IntoArray; use crate::ToCanonical; use crate::arrays::BoolArray; @@ -24,21 +25,19 @@ use crate::arrays::VarBinArray; use crate::arrays::VarBinVTable; use crate::arrow::Datum; use crate::arrow::from_arrow_array_with_len; -use crate::compute::CompareKernel; -use crate::compute::CompareKernelAdapter; use crate::compute::Operator; use crate::compute::compare; use crate::compute::compare_lengths_to_empty; -use crate::register_kernel; +use crate::expr::CompareKernel; use crate::vtable::ValidityHelper; // This implementation exists so we can have custom translation of RHS to arrow that's not the same as IntoCanonical impl CompareKernel for VarBinVTable { fn compare( - &self, lhs: &VarBinArray, rhs: &dyn Array, operator: Operator, + _ctx: &mut ExecutionCtx, ) -> VortexResult> { if let Some(rhs_const) = rhs.as_constant() { let nullable = lhs.dtype().is_nullable() || rhs_const.dtype().is_nullable(); @@ -122,8 +121,6 @@ impl CompareKernel for VarBinVTable { } } -register_kernel!(CompareKernelAdapter(VarBinVTable).lift()); - fn compare_offsets_to_empty( offsets: PrimitiveArray, operator: Operator, diff --git a/vortex-array/src/arrays/varbin/vtable/kernel.rs b/vortex-array/src/arrays/varbin/vtable/kernel.rs index 6a94dffb2f8..7c07e795740 100644 --- a/vortex-array/src/arrays/varbin/vtable/kernel.rs +++ b/vortex-array/src/arrays/varbin/vtable/kernel.rs @@ -4,9 +4,11 @@ use crate::arrays::TakeExecuteAdaptor; use crate::arrays::VarBinVTable; use crate::arrays::filter::FilterExecuteAdaptor; +use crate::expr::CompareExecuteAdaptor; use crate::kernel::ParentKernelSet; pub(super) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::new(&[ + ParentKernelSet::lift(&CompareExecuteAdaptor(VarBinVTable)), ParentKernelSet::lift(&FilterExecuteAdaptor(VarBinVTable)), ParentKernelSet::lift(&TakeExecuteAdaptor(VarBinVTable)), ]); diff --git a/vortex-array/src/compute/compare.rs b/vortex-array/src/compute/compare.rs index 44c7282f7f2..c773a1234d2 100644 --- a/vortex-array/src/compute/compare.rs +++ b/vortex-array/src/compute/compare.rs @@ -2,68 +2,39 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use core::fmt; -use std::any::Any; use std::cmp::Ordering; use std::fmt::Display; use std::fmt::Formatter; -use std::sync::LazyLock; -use arcref::ArcRef; -use arrow_array::Array as ArrowArray; use arrow_array::BooleanArray; use arrow_buffer::NullBuffer; -use arrow_ord::cmp; use arrow_ord::ord::make_comparator; use arrow_schema::SortOptions; use vortex_buffer::BitBuffer; use vortex_dtype::DType; use vortex_dtype::IntegerPType; use vortex_dtype::Nullability; -use vortex_error::VortexError; -use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_error::vortex_err; use vortex_scalar::Scalar; use crate::Array; use crate::ArrayRef; -use crate::Canonical; use crate::IntoArray; -use crate::arrays::ConstantArray; -use crate::arrays::ConstantVTable; -use crate::arrow::Datum; -use crate::arrow::IntoArrowArray; -use crate::arrow::from_arrow_array_with_len; -use crate::compute::ComputeFn; -use crate::compute::ComputeFnVTable; -use crate::compute::InvocationArgs; -use crate::compute::Kernel; -use crate::compute::Options; -use crate::compute::Output; -use crate::vtable::VTable; - -static COMPARE_FN: LazyLock = LazyLock::new(|| { - let compute = ComputeFn::new("compare".into(), ArcRef::new_ref(&Compare)); - for kernel in inventory::iter:: { - compute.register_kernel(kernel.0.clone()); - } - compute -}); - -pub(crate) fn warm_up_vtable() -> usize { - COMPARE_FN.kernels().len() -} +use crate::arrays::ScalarFnArray; +use crate::expr::Binary; +use crate::expr::ScalarFn; /// Compares two arrays and returns a new boolean array with the result of the comparison. -/// Or, returns None if comparison is not supported for these arrays. +/// +/// The returned array is lazy (a [`ScalarFnArray`]) and will be evaluated on demand. pub fn compare(left: &dyn Array, right: &dyn Array, operator: Operator) -> VortexResult { - COMPARE_FN - .invoke(&InvocationArgs { - inputs: &[left.into(), right.into()], - options: &operator, - })? - .unwrap_array() + let expr_op: crate::expr::operators::Operator = operator.into(); + Ok(ScalarFnArray::try_new( + ScalarFn::new(Binary, expr_op), + vec![left.to_array(), right.to_array()], + left.len(), + )? + .into_array()) } #[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Hash)] @@ -121,183 +92,6 @@ impl Operator { } } -pub struct CompareKernelRef(ArcRef); -inventory::collect!(CompareKernelRef); - -pub trait CompareKernel: VTable { - fn compare( - &self, - lhs: &Self::Array, - rhs: &dyn Array, - operator: Operator, - ) -> VortexResult>; -} - -#[derive(Debug)] -pub struct CompareKernelAdapter(pub V); - -impl CompareKernelAdapter { - pub const fn lift(&'static self) -> CompareKernelRef { - CompareKernelRef(ArcRef::new_ref(self)) - } -} - -impl Kernel for CompareKernelAdapter { - fn invoke(&self, args: &InvocationArgs) -> VortexResult> { - let inputs = CompareArgs::try_from(args)?; - let Some(array) = inputs.lhs.as_opt::() else { - return Ok(None); - }; - Ok(V::compare(&self.0, array, inputs.rhs, inputs.operator)?.map(|array| array.into())) - } -} - -struct Compare; - -impl ComputeFnVTable for Compare { - fn invoke( - &self, - args: &InvocationArgs, - kernels: &[ArcRef], - ) -> VortexResult { - let CompareArgs { lhs, rhs, operator } = CompareArgs::try_from(args)?; - - let return_dtype = self.return_dtype(args)?; - - if lhs.is_empty() { - return Ok(Canonical::empty(&return_dtype).into_array().into()); - } - - let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false); - let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false); - if left_constant_null || right_constant_null { - return Ok(ConstantArray::new(Scalar::null(return_dtype), lhs.len()) - .into_array() - .into()); - } - - let right_is_constant = rhs.is::(); - - // Always try to put constants on the right-hand side so encodings can optimise themselves. - if lhs.is::() && !right_is_constant { - return Ok(compare(rhs, lhs, operator.swap())?.into()); - } - - // First try lhs op rhs, then invert and try again. - for kernel in kernels { - if let Some(output) = kernel.invoke(args)? { - return Ok(output); - } - } - - // Try inverting the operator and swapping the arguments - let inverted_args = InvocationArgs { - inputs: &[rhs.into(), lhs.into()], - options: &operator.swap(), - }; - for kernel in kernels { - if let Some(output) = kernel.invoke(&inverted_args)? { - return Ok(output); - } - } - - // Only log missing compare implementation if there's possibly better one than arrow, - // i.e. lhs isn't arrow or rhs isn't arrow or constant - if !(lhs.is_arrow() && (rhs.is_arrow() || right_is_constant)) { - tracing::debug!( - "No compare implementation found for LHS {}, RHS {}, and operator {} (or inverse)", - lhs.encoding_id(), - rhs.encoding_id(), - operator, - ); - } - - // Fallback to arrow on canonical types - Ok(arrow_compare(lhs, rhs, operator)?.into()) - } - - fn return_dtype(&self, args: &InvocationArgs) -> VortexResult { - let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?; - - if !lhs.dtype().eq_ignore_nullability(rhs.dtype()) { - if lhs.dtype().is_float() && rhs.dtype().is_float() { - vortex_bail!( - "Cannot compare different floating-point types ({}, {}). Consider using cast.", - lhs.dtype(), - rhs.dtype(), - ); - } - if lhs.dtype().is_int() && rhs.dtype().is_int() { - vortex_bail!( - "Cannot compare different fixed-width types ({}, {}). Consider using cast.", - lhs.dtype(), - rhs.dtype() - ); - } - vortex_bail!( - "Cannot compare different DTypes {} and {}", - lhs.dtype(), - rhs.dtype() - ); - } - - Ok(DType::Bool( - lhs.dtype().nullability() | rhs.dtype().nullability(), - )) - } - - fn return_len(&self, args: &InvocationArgs) -> VortexResult { - let CompareArgs { lhs, rhs, .. } = CompareArgs::try_from(args)?; - if lhs.len() != rhs.len() { - vortex_bail!( - "Compare operations only support arrays of the same length, got {} and {}", - lhs.len(), - rhs.len() - ); - } - Ok(lhs.len()) - } - - fn is_elementwise(&self) -> bool { - true - } -} - -struct CompareArgs<'a> { - lhs: &'a dyn Array, - rhs: &'a dyn Array, - operator: Operator, -} - -impl Options for Operator { - fn as_any(&self) -> &dyn Any { - self - } -} - -impl<'a> TryFrom<&InvocationArgs<'a>> for CompareArgs<'a> { - type Error = VortexError; - - fn try_from(value: &InvocationArgs<'a>) -> Result { - if value.inputs.len() != 2 { - vortex_bail!("Expected 2 inputs, found {}", value.inputs.len()); - } - let lhs = value.inputs[0] - .array() - .ok_or_else(|| vortex_err!("Expected first input to be an array"))?; - let rhs = value.inputs[1] - .array() - .ok_or_else(|| vortex_err!("Expected second input to be an array"))?; - let operator = *value - .options - .as_any() - .downcast_ref::() - .vortex_expect("Expected options to be an operator"); - - Ok(CompareArgs { lhs, rhs, operator }) - } -} - /// Helper function to compare empty values with arrays that have external value length information /// like `VarBin`. pub fn compare_lengths_to_empty(lengths: I, op: Operator) -> BitBuffer @@ -347,48 +141,6 @@ pub(crate) fn compare_nested_arrow_arrays( Ok(BooleanArray::new(values, nulls)) } -/// Implementation of `CompareFn` using the Arrow crate. -fn arrow_compare( - left: &dyn Array, - right: &dyn Array, - operator: Operator, -) -> VortexResult { - assert_eq!(left.len(), right.len()); - - let nullable = left.dtype().is_nullable() || right.dtype().is_nullable(); - - // Arrow's vectorized comparison kernels (`cmp::eq`, etc.) are faster but don't support nested - // types. For nested types, we fall back to `make_comparator` which does element-wise - // comparison. - let array = if left.dtype().is_nested() || right.dtype().is_nested() { - let rhs = right.to_array().into_arrow_preferred()?; - let lhs = left.to_array().into_arrow(rhs.data_type())?; - - assert!( - lhs.data_type().equals_datatype(rhs.data_type()), - "lhs data_type: {}, rhs data_type: {}", - lhs.data_type(), - rhs.data_type() - ); - - compare_nested_arrow_arrays(lhs.as_ref(), rhs.as_ref(), operator)? - } else { - // Fast path: use vectorized kernels for primitive types. - let lhs = Datum::try_new(left)?; - let rhs = Datum::try_new_with_target_datatype(right, lhs.data_type())?; - - match operator { - Operator::Eq => cmp::eq(&lhs, &rhs)?, - Operator::NotEq => cmp::neq(&lhs, &rhs)?, - Operator::Gt => cmp::gt(&lhs, &rhs)?, - Operator::Gte => cmp::gt_eq(&lhs, &rhs)?, - Operator::Lt => cmp::lt(&lhs, &rhs)?, - Operator::Lte => cmp::lt_eq(&lhs, &rhs)?, - } - }; - from_arrow_array_with_len(&array, left.len(), nullable) -} - pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar { if lhs.is_null() | rhs.is_null() { Scalar::null(DType::Bool(Nullability::Nullable)) @@ -424,9 +176,6 @@ mod tests { use crate::arrays::VarBinArray; use crate::arrays::VarBinViewArray; use crate::assert_arrays_eq; - use crate::expr::get_item; - use crate::expr::lt; - use crate::expr::root; use crate::test_harness::to_int_indices; use crate::validity::Validity; @@ -480,15 +229,10 @@ mod tests { let left = ConstantArray::new(Scalar::from(2u32), 10); let right = ConstantArray::new(Scalar::from(10u32), 10); - let compare = compare(left.as_ref(), right.as_ref(), Operator::Gt).unwrap(); - let res = compare.as_constant().unwrap(); - assert_eq!(res.as_bool().value(), Some(false)); - assert_eq!(compare.len(), 10); - - let compare = arrow_compare(&left.into_array(), &right.into_array(), Operator::Gt).unwrap(); - let res = compare.as_constant().unwrap(); - assert_eq!(res.as_bool().value(), Some(false)); - assert_eq!(compare.len(), 10); + let result = compare(left.as_ref(), right.as_ref(), Operator::Gt).unwrap(); + assert_eq!(result.len(), 10); + let scalar = result.scalar_at(0).unwrap(); + assert_eq!(scalar.as_bool().value(), Some(false)); } #[rstest] @@ -656,58 +400,4 @@ mod tests { assert!(result.scalar_at(1).unwrap().is_valid()); assert!(result.scalar_at(2).unwrap().is_valid()); } - - #[test] - fn test_different_floats_error_messages() { - let result = compare( - &buffer![0.0f32].into_array(), - &buffer![0.0f64].into_array(), - Operator::Lt, - ); - assert!(result.as_ref().is_err_and(|err| { - err.to_string() - .contains("Cannot compare different floating-point types") - })); - - let expr = lt(get_item("l", root()), get_item("r", root())); - let array = StructArray::from_fields(&[ - ("l", buffer![0.0f32].into_array()), - ("r", buffer![0.0f64].into_array()), - ]) - .unwrap() - .into_array(); - // Force evaluation by calling scalar_at - let result = array.apply(&expr).and_then(|arr| arr.scalar_at(0)); - assert!(result.as_ref().is_err_and(|err| { - err.to_string() - .contains("Cannot compare different floating-point types") - })); - } - - #[test] - fn test_different_ints_error_messages() { - let result = compare( - &buffer![0u8].into_array(), - &buffer![0u16].into_array(), - Operator::Lt, - ); - assert!(result.as_ref().is_err_and(|err| { - err.to_string() - .contains("Cannot compare different fixed-width types") - })); - - let expr = lt(get_item("l", root()), get_item("r", root())); - let array = StructArray::from_fields(&[ - ("l", buffer![0u8].into_array()), - ("r", buffer![0u16].into_array()), - ]) - .unwrap() - .into_array(); - // Force evaluation by calling scalar_at - let result = array.apply(&expr).and_then(|arr| arr.scalar_at(0)); - assert!(result.as_ref().is_err_and(|err| { - err.to_string() - .contains("Cannot compare different fixed-width types") - })); - } } diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index 2aed159bbc1..87298c1e6e5 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -94,7 +94,6 @@ pub struct ComputeFn { pub fn warm_up_vtables() { #[allow(unused_qualifications)] between::warm_up_vtable(); - compare::warm_up_vtable(); is_constant::warm_up_vtable(); is_sorted::warm_up_vtable(); list_contains::warm_up_vtable(); diff --git a/vortex-array/src/expr/exprs/binary/compare.rs b/vortex-array/src/expr/exprs/binary/compare.rs new file mode 100644 index 00000000000..cd44365b2cf --- /dev/null +++ b/vortex-array/src/expr/exprs/binary/compare.rs @@ -0,0 +1,184 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use arrow_array::BooleanArray; +use arrow_ord::cmp; +use vortex_error::VortexResult; +use vortex_scalar::Scalar; + +use crate::Array; +use crate::ArrayRef; +use crate::Canonical; +use crate::ExecutionCtx; +use crate::IntoArray; +use crate::arrays::ConstantArray; +use crate::arrays::ConstantVTable; +use crate::arrays::ExactScalarFn; +use crate::arrays::ScalarFnArrayView; +use crate::arrays::ScalarFnVTable; +use crate::arrow::Datum; +use crate::arrow::IntoArrowArray; +use crate::arrow::from_arrow_array_with_len; +use crate::compute::Operator; +use crate::compute::compare_nested_arrow_arrays; +use crate::compute::scalar_cmp; +use crate::expr::Binary; +use crate::kernel::ExecuteParentKernel; +use crate::vtable::VTable; + +/// Trait for encoding-specific comparison kernels that operate in encoded space. +/// +/// Implementations can compare an encoded array against another array (typically a constant) +/// without first decompressing. The adaptor normalizes operand order so `array` is always +/// the left-hand side, swapping the operator when necessary. +pub trait CompareKernel: VTable { + fn compare( + lhs: &Self::Array, + rhs: &dyn Array, + operator: Operator, + ctx: &mut ExecutionCtx, + ) -> VortexResult>; +} + +/// Adaptor that bridges [`CompareKernel`] implementations to [`ExecuteParentKernel`]. +/// +/// When a `ScalarFnArray(Binary, cmp_op)` wraps a child that implements `CompareKernel`, +/// this adaptor extracts the comparison operator and other operand, normalizes operand order +/// (swapping the operator if the encoded array is on the RHS), and delegates to the kernel. +#[derive(Default, Debug)] +pub struct CompareExecuteAdaptor(pub V); + +impl ExecuteParentKernel for CompareExecuteAdaptor +where + V: CompareKernel, +{ + type Parent = ExactScalarFn; + + fn execute_parent( + &self, + array: &V::Array, + parent: ScalarFnArrayView<'_, Binary>, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + // Only handle comparison operators + let Some(cmp_op) = parent.options.maybe_cmp_operator() else { + return Ok(None); + }; + + // Get the ScalarFnArray to access children + let Some(scalar_fn_array) = parent.as_opt::() else { + return Ok(None); + }; + let children = scalar_fn_array.children(); + + // Normalize so `array` is always LHS, swapping the operator if needed + // TODO(joe): should be go this here or in the Rule/Kernel + let (cmp_op, other) = match child_idx { + 0 => (cmp_op, &children[1]), + 1 => (cmp_op.swap(), &children[0]), + _ => return Ok(None), + }; + + let len = array.len(); + let nullable = array.dtype().is_nullable() || other.dtype().is_nullable(); + + // Empty array → empty bool result + if len == 0 { + return Ok(Some( + Canonical::empty(&vortex_dtype::DType::Bool(nullable.into())).into_array(), + )); + } + + // Null constant on either side → all-null bool result + if other.as_constant().is_some_and(|s| s.is_null()) { + return Ok(Some( + ConstantArray::new( + Scalar::null(vortex_dtype::DType::Bool(nullable.into())), + len, + ) + .into_array(), + )); + } + + V::compare(array, other.as_ref(), cmp_op, ctx) + } +} + +/// Execute a compare operation between two arrays. +/// +/// This is the entry point for compare operations from the binary expression. +/// Handles empty, constant-null, and constant-constant directly, otherwise falls back to Arrow. +pub(crate) fn execute_compare( + lhs: &dyn Array, + rhs: &dyn Array, + op: Operator, +) -> VortexResult { + let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable(); + + if lhs.is_empty() { + return Ok(Canonical::empty(&vortex_dtype::DType::Bool(nullable.into())).into_array()); + } + + let left_constant_null = lhs.as_constant().map(|l| l.is_null()).unwrap_or(false); + let right_constant_null = rhs.as_constant().map(|r| r.is_null()).unwrap_or(false); + if left_constant_null || right_constant_null { + return Ok(ConstantArray::new( + Scalar::null(vortex_dtype::DType::Bool(nullable.into())), + lhs.len(), + ) + .into_array()); + } + + // Constant-constant fast path + if let (Some(lhs_const), Some(rhs_const)) = ( + lhs.as_opt::(), + rhs.as_opt::(), + ) { + let result = scalar_cmp(lhs_const.scalar(), rhs_const.scalar(), op); + return Ok(ConstantArray::new(result, lhs.len()).into_array()); + } + + arrow_compare_arrays(lhs, rhs, op) +} + +/// Fall back to Arrow for comparison. +fn arrow_compare_arrays( + left: &dyn Array, + right: &dyn Array, + operator: Operator, +) -> VortexResult { + assert_eq!(left.len(), right.len()); + + let nullable = left.dtype().is_nullable() || right.dtype().is_nullable(); + + // Arrow's vectorized comparison kernels don't support nested types. + // For nested types, fall back to `make_comparator` which does element-wise comparison. + let array: BooleanArray = if left.dtype().is_nested() || right.dtype().is_nested() { + let rhs = right.to_array().into_arrow_preferred()?; + let lhs = left.to_array().into_arrow(rhs.data_type())?; + + assert!( + lhs.data_type().equals_datatype(rhs.data_type()), + "lhs data_type: {}, rhs data_type: {}", + lhs.data_type(), + rhs.data_type() + ); + + compare_nested_arrow_arrays(lhs.as_ref(), rhs.as_ref(), operator)? + } else { + // Fast path: use vectorized kernels for primitive types. + let lhs = Datum::try_new(left)?; + let rhs = Datum::try_new_with_target_datatype(right, lhs.data_type())?; + + match operator { + Operator::Eq => cmp::eq(&lhs, &rhs)?, + Operator::NotEq => cmp::neq(&lhs, &rhs)?, + Operator::Gt => cmp::gt(&lhs, &rhs)?, + Operator::Gte => cmp::gt_eq(&lhs, &rhs)?, + Operator::Lt => cmp::lt(&lhs, &rhs)?, + Operator::Lte => cmp::lt_eq(&lhs, &rhs)?, + } + }; + from_arrow_array_with_len(&array, left.len(), nullable) +} diff --git a/vortex-array/src/expr/exprs/binary/mod.rs b/vortex-array/src/expr/exprs/binary/mod.rs index c66eb1a3548..16166e088f9 100644 --- a/vortex-array/src/expr/exprs/binary/mod.rs +++ b/vortex-array/src/expr/exprs/binary/mod.rs @@ -14,7 +14,6 @@ use vortex_session::VortexSession; use crate::ArrayRef; use crate::compute; use crate::compute::BooleanOperator; -use crate::compute::compare; use crate::expr::Arity; use crate::expr::ChildName; use crate::expr::ExecutionArgs; @@ -29,6 +28,8 @@ use crate::expr::stats::Stat; mod boolean; pub(crate) use boolean::*; +mod compare; +pub use compare::*; mod numeric; pub(crate) use numeric::*; @@ -99,6 +100,14 @@ impl VTable for Binary { ); } + if operator.is_comparison() + && !lhs.eq_ignore_nullability(rhs) + && !lhs.is_extension() + && !rhs.is_extension() + { + vortex_bail!("Cannot compare different DTypes {} and {}", lhs, rhs); + } + Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into())) } @@ -108,12 +117,12 @@ impl VTable for Binary { }; match op { - Operator::Eq => compare(lhs, rhs, compute::Operator::Eq), - Operator::NotEq => compare(lhs, rhs, compute::Operator::NotEq), - Operator::Lt => compare(lhs, rhs, compute::Operator::Lt), - Operator::Lte => compare(lhs, rhs, compute::Operator::Lte), - Operator::Gt => compare(lhs, rhs, compute::Operator::Gt), - Operator::Gte => compare(lhs, rhs, compute::Operator::Gte), + Operator::Eq => execute_compare(lhs, rhs, compute::Operator::Eq), + Operator::NotEq => execute_compare(lhs, rhs, compute::Operator::NotEq), + Operator::Lt => execute_compare(lhs, rhs, compute::Operator::Lt), + Operator::Lte => execute_compare(lhs, rhs, compute::Operator::Lte), + Operator::Gt => execute_compare(lhs, rhs, compute::Operator::Gt), + Operator::Gte => execute_compare(lhs, rhs, compute::Operator::Gte), Operator::And => execute_boolean(lhs, rhs, BooleanOperator::AndKleene), Operator::Or => execute_boolean(lhs, rhs, BooleanOperator::OrKleene), Operator::Add => execute_numeric(lhs, rhs, vortex_scalar::NumericOperator::Add), @@ -554,6 +563,7 @@ mod tests { use super::*; use crate::assert_arrays_eq; + use crate::compute::compare; use crate::expr::Expression; use crate::expr::exprs::get_item::col; use crate::expr::exprs::literal::lit; diff --git a/vortex-array/src/expr/exprs/operators.rs b/vortex-array/src/expr/exprs/operators.rs index c2084c08f6b..5121c506d2c 100644 --- a/vortex-array/src/expr/exprs/operators.rs +++ b/vortex-array/src/expr/exprs/operators.rs @@ -185,6 +185,13 @@ impl Operator { pub fn is_arithmetic(&self) -> bool { matches!(self, Self::Add | Self::Sub | Self::Mul | Self::Div) } + + pub fn is_comparison(&self) -> bool { + matches!( + self, + Self::Eq | Self::NotEq | Self::Gt | Self::Gte | Self::Lt | Self::Lte + ) + } } impl From for Operator { diff --git a/vortex-array/src/validity.rs b/vortex-array/src/validity.rs index 91060f3f8e8..09f148166c8 100644 --- a/vortex-array/src/validity.rs +++ b/vortex-array/src/validity.rs @@ -4,7 +4,6 @@ //! Array validity and nullability behavior, used by arrays and compute functions. use std::fmt::Debug; -use std::ops::BitAnd; use std::ops::Range; use vortex_buffer::BitBuffer; @@ -28,8 +27,11 @@ use crate::IntoArray; use crate::ToCanonical; use crate::arrays::BoolArray; use crate::arrays::ConstantArray; +use crate::arrays::ScalarFnArrayExt; use crate::builtins::ArrayBuiltins; use crate::compute::sum; +use crate::expr::Binary; +use crate::expr::Operator; use crate::patches::Patches; /// Validity information for an array @@ -263,8 +265,8 @@ impl Validity { /// Logically & two Validity values of the same length #[inline] - pub fn and(self, rhs: Validity) -> Validity { - match (self, rhs) { + pub fn and(self, rhs: Validity) -> VortexResult { + Ok(match (self, rhs) { // Should be pretty clear (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable, // Any `AllInvalid` makes the output all invalid values @@ -280,15 +282,9 @@ impl Validity { | (Validity::AllValid, Validity::AllValid) => Validity::AllValid, // Here we actually have to do some work (Validity::Array(lhs), Validity::Array(rhs)) => { - let lhs = lhs.to_bool(); - let rhs = rhs.to_bool(); - - let lhs = lhs.to_bit_buffer(); - let rhs = rhs.to_bit_buffer(); - - Validity::from(lhs.bitand(rhs)) + Validity::Array(Binary.try_new_array(lhs.len(), Operator::And, [lhs, rhs])?) } - } + }) } pub fn patch(