diff --git a/encodings/alp/src/alp/compute/mask.rs b/encodings/alp/src/alp/compute/mask.rs index f08b6499a20..cde1347b626 100644 --- a/encodings/alp/src/alp/compute/mask.rs +++ b/encodings/alp/src/alp/compute/mask.rs @@ -2,38 +2,25 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_array::ArrayRef; -use vortex_array::compute::MaskKernel; -use vortex_array::compute::MaskKernelAdapter; -use vortex_array::compute::mask; -use vortex_array::register_kernel; +use vortex_array::IntoArray; +use vortex_array::arrays::MaskedArray; +use vortex_array::compute::MaskReduce; +use vortex_array::validity::Validity; use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::ALPArray; use crate::ALPVTable; -impl MaskKernel for ALPVTable { - fn mask(&self, array: &ALPArray, filter_mask: &Mask) -> VortexResult { - let masked_encoded = mask(array.encoded(), filter_mask)?; - let masked_patches = array - .patches() - .map(|p| p.mask(filter_mask)) - .transpose()? - .flatten() - .map(|patches| { - patches.cast_values( - &array - .dtype() - .with_nullability(masked_encoded.dtype().nullability()), - ) - }) - .transpose()?; - Ok(ALPArray::new(masked_encoded, array.exponents(), masked_patches).to_array()) +impl MaskReduce for ALPVTable { + fn mask(array: &ALPArray, validity: &Validity) -> VortexResult> { + let masked_encoded = + MaskedArray::try_new(array.encoded().clone(), validity.clone())?.into_array(); + Ok(Some( + ALPArray::new(masked_encoded, array.exponents(), array.patches().cloned()).to_array(), + )) } } -register_kernel!(MaskKernelAdapter(ALPVTable).lift()); - #[cfg(test)] mod test { use rstest::rstest; diff --git a/encodings/alp/src/alp/rules.rs b/encodings/alp/src/alp/rules.rs index f90739c1733..df258535791 100644 --- a/encodings/alp/src/alp/rules.rs +++ b/encodings/alp/src/alp/rules.rs @@ -5,6 +5,7 @@ use vortex_array::arrays::FilterExecuteAdaptor; use vortex_array::arrays::SliceExecuteAdaptor; use vortex_array::arrays::TakeExecuteAdaptor; use vortex_array::compute::CastReduceAdaptor; +use vortex_array::compute::MaskReduceAdaptor; use vortex_array::kernel::ParentKernelSet; use vortex_array::optimizer::rules::ParentRuleSet; @@ -16,5 +17,7 @@ pub(super) const PARENT_KERNELS: ParentKernelSet = ParentKernelSet::n ParentKernelSet::lift(&TakeExecuteAdaptor(ALPVTable)), ]); -pub(super) const RULES: ParentRuleSet = - ParentRuleSet::new(&[ParentRuleSet::lift(&CastReduceAdaptor(ALPVTable))]); +pub(super) const RULES: ParentRuleSet = ParentRuleSet::new(&[ + ParentRuleSet::lift(&CastReduceAdaptor(ALPVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(ALPVTable)), +]); diff --git a/encodings/alp/src/alp_rd/compute/mask.rs b/encodings/alp/src/alp_rd/compute/mask.rs index dcda1e1e2e2..14ba5d9547b 100644 --- a/encodings/alp/src/alp_rd/compute/mask.rs +++ b/encodings/alp/src/alp_rd/compute/mask.rs @@ -3,32 +3,32 @@ use vortex_array::ArrayRef; use vortex_array::IntoArray; -use vortex_array::compute::MaskKernel; -use vortex_array::compute::MaskKernelAdapter; -use vortex_array::compute::mask; -use vortex_array::register_kernel; +use vortex_array::arrays::MaskedArray; +use vortex_array::compute::MaskReduce; +use vortex_array::validity::Validity; use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::ALPRDArray; use crate::ALPRDVTable; -impl MaskKernel for ALPRDVTable { - fn mask(&self, array: &ALPRDArray, filter_mask: &Mask) -> VortexResult { - Ok(ALPRDArray::try_new( - array.dtype().as_nullable(), - mask(array.left_parts(), filter_mask)?, - array.left_parts_dictionary().clone(), - array.right_parts().clone(), - array.right_bit_width(), - array.left_parts_patches().cloned(), - )? - .into_array()) +impl MaskReduce for ALPRDVTable { + fn mask(array: &ALPRDArray, validity: &Validity) -> VortexResult> { + let masked_left_parts = + MaskedArray::try_new(array.left_parts().clone(), validity.clone())?.into_array(); + Ok(Some( + ALPRDArray::try_new( + array.dtype().as_nullable(), + masked_left_parts, + array.left_parts_dictionary().clone(), + array.right_parts().clone(), + array.right_bit_width(), + array.left_parts_patches().cloned(), + )? + .into_array(), + )) } } -register_kernel!(MaskKernelAdapter(ALPRDVTable).lift()); - #[cfg(test)] mod tests { use rstest::rstest; diff --git a/encodings/alp/src/alp_rd/rules.rs b/encodings/alp/src/alp_rd/rules.rs index a7280e0bf8f..ed048e10829 100644 --- a/encodings/alp/src/alp_rd/rules.rs +++ b/encodings/alp/src/alp_rd/rules.rs @@ -2,9 +2,12 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_array::compute::CastReduceAdaptor; +use vortex_array::compute::MaskReduceAdaptor; use vortex_array::optimizer::rules::ParentRuleSet; use crate::alp_rd::ALPRDVTable; -pub(crate) static RULES: ParentRuleSet = - ParentRuleSet::new(&[ParentRuleSet::lift(&CastReduceAdaptor(ALPRDVTable))]); +pub(crate) static RULES: ParentRuleSet = ParentRuleSet::new(&[ + ParentRuleSet::lift(&CastReduceAdaptor(ALPRDVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(ALPRDVTable)), +]); diff --git a/encodings/bytebool/src/compute.rs b/encodings/bytebool/src/compute.rs index b40e3722ee0..f966a151016 100644 --- a/encodings/bytebool/src/compute.rs +++ b/encodings/bytebool/src/compute.rs @@ -9,14 +9,12 @@ use vortex_array::IntoArray; use vortex_array::ToCanonical; use vortex_array::arrays::TakeExecute; use vortex_array::compute::CastReduce; -use vortex_array::compute::MaskKernel; -use vortex_array::compute::MaskKernelAdapter; -use vortex_array::register_kernel; +use vortex_array::compute::MaskReduce; +use vortex_array::validity::Validity; use vortex_array::vtable::ValidityHelper; use vortex_dtype::DType; use vortex_dtype::match_each_integer_ptype; use vortex_error::VortexResult; -use vortex_mask::Mask; use super::ByteBoolArray; use super::ByteBoolVTable; @@ -44,14 +42,18 @@ impl CastReduce for ByteBoolVTable { } } -impl MaskKernel for ByteBoolVTable { - fn mask(&self, array: &ByteBoolArray, mask: &Mask) -> VortexResult { - Ok(ByteBoolArray::new(array.buffer().clone(), array.validity().mask(mask)).into_array()) +impl MaskReduce for ByteBoolVTable { + fn mask(array: &ByteBoolArray, validity: &Validity) -> VortexResult> { + Ok(Some( + ByteBoolArray::new( + array.buffer().clone(), + array.validity().clone().and(validity.clone()), + ) + .into_array(), + )) } } -register_kernel!(MaskKernelAdapter(ByteBoolVTable).lift()); - impl TakeExecute for ByteBoolVTable { fn take( array: &ByteBoolArray, diff --git a/encodings/bytebool/src/rules.rs b/encodings/bytebool/src/rules.rs index 52e9e32cefa..284989a3621 100644 --- a/encodings/bytebool/src/rules.rs +++ b/encodings/bytebool/src/rules.rs @@ -3,11 +3,13 @@ use vortex_array::arrays::SliceReduceAdaptor; use vortex_array::compute::CastReduceAdaptor; +use vortex_array::compute::MaskReduceAdaptor; use vortex_array::optimizer::rules::ParentRuleSet; use crate::ByteBoolVTable; pub(crate) static RULES: ParentRuleSet = ParentRuleSet::new(&[ - ParentRuleSet::lift(&SliceReduceAdaptor(ByteBoolVTable)), ParentRuleSet::lift(&CastReduceAdaptor(ByteBoolVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(ByteBoolVTable)), + ParentRuleSet::lift(&SliceReduceAdaptor(ByteBoolVTable)), ]); diff --git a/encodings/datetime-parts/src/compute/mask.rs b/encodings/datetime-parts/src/compute/mask.rs index ba9ba4b9acf..b820bbf3f31 100644 --- a/encodings/datetime-parts/src/compute/mask.rs +++ b/encodings/datetime-parts/src/compute/mask.rs @@ -1,41 +1,28 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use vortex_array::Array; use vortex_array::ArrayRef; -use vortex_array::compute::MaskKernel; -use vortex_array::compute::MaskKernelAdapter; -use vortex_array::compute::mask; -use vortex_array::register_kernel; +use vortex_array::IntoArray; +use vortex_array::arrays::MaskedArray; +use vortex_array::compute::MaskReduce; +use vortex_array::validity::Validity; use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::DateTimePartsArray; use crate::DateTimePartsVTable; -impl MaskKernel for DateTimePartsVTable { - fn mask(&self, array: &DateTimePartsArray, mask_array: &Mask) -> VortexResult { - // DateTimePartsArray has specific constraints: - // - days nullability must match the dtype - // - seconds and subseconds must always be non-nullable - // - // When masking, we can't make seconds/subseconds nullable. - // Instead, we'll keep the same values but the overall array becomes nullable - // through the days component. - - let masked_days = mask(array.days(), mask_array)?; - assert!(masked_days.dtype().is_nullable()); - - // Keep seconds and subseconds unchanged since they must remain non-nullable - let seconds = array.seconds().clone(); - let subseconds = array.subseconds().clone(); - - // Update the dtype to reflect the new nullability of days - let new_dtype = array.dtype().as_nullable(); - - DateTimePartsArray::try_new(new_dtype, masked_days, seconds, subseconds) - .map(|a| a.to_array()) +impl MaskReduce for DateTimePartsVTable { + fn mask(array: &DateTimePartsArray, validity: &Validity) -> VortexResult> { + let masked_days = + MaskedArray::try_new(array.days().clone(), validity.clone())?.into_array(); + Ok(Some( + DateTimePartsArray::try_new( + array.dtype().as_nullable(), + masked_days, + array.seconds().clone(), + array.subseconds().clone(), + )? + .into_array(), + )) } } - -register_kernel!(MaskKernelAdapter(DateTimePartsVTable).lift()); diff --git a/encodings/datetime-parts/src/compute/rules.rs b/encodings/datetime-parts/src/compute/rules.rs index 436636872e2..e1d5e855236 100644 --- a/encodings/datetime-parts/src/compute/rules.rs +++ b/encodings/datetime-parts/src/compute/rules.rs @@ -14,6 +14,7 @@ use vortex_array::arrays::ScalarFnArray; use vortex_array::arrays::SliceReduceAdaptor; use vortex_array::builtins::ArrayBuiltins; use vortex_array::compute::CastReduceAdaptor; +use vortex_array::compute::MaskReduceAdaptor; use vortex_array::expr::Between; use vortex_array::expr::Binary; use vortex_array::optimizer::ArrayOptimizer; @@ -33,6 +34,7 @@ pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSe ParentRuleSet::lift(&DTPComparisonPushDownRule), ParentRuleSet::lift(&CastReduceAdaptor(DateTimePartsVTable)), ParentRuleSet::lift(&FilterReduceAdaptor(DateTimePartsVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(DateTimePartsVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(DateTimePartsVTable)), ]); diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mask.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mask.rs index 04a91c5de17..9203e2a66b0 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mask.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mask.rs @@ -2,21 +2,20 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_array::ArrayRef; -use vortex_array::compute::MaskKernel; -use vortex_array::compute::MaskKernelAdapter; -use vortex_array::compute::mask; -use vortex_array::register_kernel; +use vortex_array::IntoArray; +use vortex_array::arrays::MaskedArray; +use vortex_array::compute::MaskReduce; +use vortex_array::validity::Validity; use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::DecimalBytePartsArray; use crate::DecimalBytePartsVTable; -impl MaskKernel for DecimalBytePartsVTable { - fn mask(&self, array: &DecimalBytePartsArray, mask_array: &Mask) -> VortexResult { - let masked = mask(&array.msp, mask_array)?; - DecimalBytePartsArray::try_new(masked, *array.decimal_dtype()).map(|a| a.to_array()) +impl MaskReduce for DecimalBytePartsVTable { + fn mask(array: &DecimalBytePartsArray, validity: &Validity) -> VortexResult> { + let masked_msp = MaskedArray::try_new(array.msp.clone(), validity.clone())?.into_array(); + Ok(Some( + DecimalBytePartsArray::try_new(masked_msp, *array.decimal_dtype())?.into_array(), + )) } } - -register_kernel!(MaskKernelAdapter(DecimalBytePartsVTable).lift()); diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/rules.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/rules.rs index 7c851149323..7762b50b0ec 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/rules.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/rules.rs @@ -9,6 +9,7 @@ use vortex_array::arrays::FilterReduceAdaptor; use vortex_array::arrays::FilterVTable; use vortex_array::arrays::SliceReduceAdaptor; use vortex_array::compute::CastReduceAdaptor; +use vortex_array::compute::MaskReduceAdaptor; use vortex_array::optimizer::rules::ArrayParentReduceRule; use vortex_array::optimizer::rules::ParentRuleSet; use vortex_error::VortexResult; @@ -20,6 +21,7 @@ pub(super) const PARENT_RULES: ParentRuleSet = ParentRul ParentRuleSet::lift(&DecimalBytePartsFilterPushDownRule), ParentRuleSet::lift(&CastReduceAdaptor(DecimalBytePartsVTable)), ParentRuleSet::lift(&FilterReduceAdaptor(DecimalBytePartsVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(DecimalBytePartsVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(DecimalBytePartsVTable)), ]); diff --git a/encodings/zigzag/src/compute/mod.rs b/encodings/zigzag/src/compute/mod.rs index 562af093215..22082e10f6a 100644 --- a/encodings/zigzag/src/compute/mod.rs +++ b/encodings/zigzag/src/compute/mod.rs @@ -8,11 +8,10 @@ use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::FilterReduce; +use vortex_array::arrays::MaskedArray; use vortex_array::arrays::TakeExecute; -use vortex_array::compute::MaskKernel; -use vortex_array::compute::MaskKernelAdapter; -use vortex_array::compute::mask; -use vortex_array::register_kernel; +use vortex_array::compute::MaskReduce; +use vortex_array::validity::Validity; use vortex_error::VortexResult; use vortex_mask::Mask; @@ -37,15 +36,14 @@ impl TakeExecute for ZigZagVTable { } } -impl MaskKernel for ZigZagVTable { - fn mask(&self, array: &ZigZagArray, filter_mask: &Mask) -> VortexResult { - let encoded = mask(array.encoded(), filter_mask)?; - Ok(ZigZagArray::try_new(encoded)?.into_array()) +impl MaskReduce for ZigZagVTable { + fn mask(array: &ZigZagArray, validity: &Validity) -> VortexResult> { + let masked_encoded = + MaskedArray::try_new(array.encoded().clone(), validity.clone())?.into_array(); + Ok(Some(ZigZagArray::try_new(masked_encoded)?.into_array())) } } -register_kernel!(MaskKernelAdapter(ZigZagVTable).lift()); - pub(crate) trait ZigZagEncoded { type Int: zigzag::ZigZag; } diff --git a/encodings/zigzag/src/rules.rs b/encodings/zigzag/src/rules.rs index 0b684b0c01d..0cdc6962976 100644 --- a/encodings/zigzag/src/rules.rs +++ b/encodings/zigzag/src/rules.rs @@ -4,6 +4,7 @@ use vortex_array::arrays::FilterReduceAdaptor; use vortex_array::arrays::SliceReduceAdaptor; use vortex_array::compute::CastReduceAdaptor; +use vortex_array::compute::MaskReduceAdaptor; use vortex_array::optimizer::rules::ParentRuleSet; use crate::ZigZagVTable; @@ -11,5 +12,6 @@ use crate::ZigZagVTable; pub(crate) static RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&CastReduceAdaptor(ZigZagVTable)), ParentRuleSet::lift(&FilterReduceAdaptor(ZigZagVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(ZigZagVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(ZigZagVTable)), ]); diff --git a/fuzz/src/array/mod.rs b/fuzz/src/array/mod.rs index aada877fdca..e9a0312bf23 100644 --- a/fuzz/src/array/mod.rs +++ b/fuzz/src/array/mod.rs @@ -553,7 +553,6 @@ pub fn run_fuzz_action(fuzz_action: FuzzArrayAction) -> crate::error::VortexFuzz use vortex_array::arrays::ConstantArray; use vortex_array::builtins::ArrayBuiltins; use vortex_array::compute::compare; - use vortex_array::compute::mask; use vortex_array::compute::min_max; use vortex_array::compute::sum; let FuzzArrayAction { array, actions } = fuzz_action; @@ -642,7 +641,9 @@ pub fn run_fuzz_action(fuzz_action: FuzzArrayAction) -> crate::error::VortexFuzz assert_array_eq(&expected.array(), ¤t_array, i)?; } Action::Mask(mask_val) => { - current_array = mask(¤t_array, &mask_val) + current_array = current_array + .as_ref() + .mask(&mask_val) .vortex_expect("mask operation should succeed in fuzz test"); assert_array_eq(&expected.array(), ¤t_array, i)?; } diff --git a/vortex-array/benches/dict_mask.rs b/vortex-array/benches/dict_mask.rs index 2dff5768431..a2b68c77601 100644 --- a/vortex-array/benches/dict_mask.rs +++ b/vortex-array/benches/dict_mask.rs @@ -7,12 +7,15 @@ use divan::Bencher; use rand::Rng; use rand::SeedableRng; use rand::rngs::StdRng; +use vortex_array::Canonical; use vortex_array::IntoArray; +use vortex_array::RecursiveCanonical; +use vortex_array::VortexSessionExecute; use vortex_array::arrays::DictArray; use vortex_array::arrays::PrimitiveArray; -use vortex_array::compute::mask; use vortex_array::compute::warm_up_vtables; use vortex_mask::Mask; +use vortex_session::VortexSession; fn main() { warm_up_vtables(); @@ -59,7 +62,15 @@ fn bench_dict_mask(bencher: Bencher, (fraction_valid, fraction_masked): (f64, f6 let values = PrimitiveArray::from_option_iter([None, Some(42i32)]).into_array(); let array = DictArray::try_new(codes, values).unwrap().into_array(); let filter_mask = filter_mask(len, fraction_masked, &mut rng); + let session = VortexSession::empty(); bencher .with_inputs(|| (&array, &filter_mask)) - .bench_refs(|(array, filter_mask)| mask(array.as_ref(), filter_mask).unwrap()); + .bench_refs(|(array, filter_mask)| { + let mut ctx = session.create_execution_ctx(); + array + .mask(filter_mask) + .unwrap() + .execute::(&mut ctx) + .unwrap() + }); } diff --git a/vortex-array/src/arrays/bool/compute/mask.rs b/vortex-array/src/arrays/bool/compute/mask.rs index b013cee08fe..53ae787266c 100644 --- a/vortex-array/src/arrays/bool/compute/mask.rs +++ b/vortex-array/src/arrays/bool/compute/mask.rs @@ -2,25 +2,27 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::BoolArray; use crate::arrays::BoolVTable; -use crate::compute::MaskKernel; -use crate::compute::MaskKernelAdapter; -use crate::register_kernel; +use crate::compute::MaskReduce; +use crate::validity::Validity; use crate::vtable::ValidityHelper; -impl MaskKernel for BoolVTable { - fn mask(&self, array: &BoolArray, mask: &Mask) -> VortexResult { - Ok(BoolArray::new(array.to_bit_buffer(), array.validity().mask(mask)).into_array()) +impl MaskReduce for BoolVTable { + fn mask(array: &BoolArray, validity: &Validity) -> VortexResult> { + Ok(Some( + BoolArray::new( + array.to_bit_buffer(), + array.validity().clone().and(validity.clone()), + ) + .into_array(), + )) } } -register_kernel!(MaskKernelAdapter(BoolVTable).lift()); - #[cfg(test)] mod test { use rstest::rstest; diff --git a/vortex-array/src/arrays/bool/compute/rules.rs b/vortex-array/src/arrays/bool/compute/rules.rs index d885804f89c..09b7d80da86 100644 --- a/vortex-array/src/arrays/bool/compute/rules.rs +++ b/vortex-array/src/arrays/bool/compute/rules.rs @@ -11,6 +11,7 @@ use crate::arrays::MaskedArray; use crate::arrays::MaskedVTable; use crate::arrays::SliceReduceAdaptor; use crate::compute::CastReduceAdaptor; +use crate::expr::MaskReduceAdaptor; use crate::optimizer::rules::ArrayParentReduceRule; use crate::optimizer::rules::ParentRuleSet; use crate::vtable::ValidityHelper; @@ -18,6 +19,7 @@ use crate::vtable::ValidityHelper; pub(crate) const RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&BoolMaskedValidityRule), ParentRuleSet::lift(&CastReduceAdaptor(BoolVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(BoolVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(BoolVTable)), ]); diff --git a/vortex-array/src/arrays/chunked/compute/mask.rs b/vortex-array/src/arrays/chunked/compute/mask.rs index 7c6aa7b87ea..eea4259de46 100644 --- a/vortex-array/src/arrays/chunked/compute/mask.rs +++ b/vortex-array/src/arrays/chunked/compute/mask.rs @@ -1,153 +1,44 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use itertools::Itertools as _; -use vortex_buffer::BitBuffer; -use vortex_buffer::BitBufferMut; -use vortex_dtype::DType; use vortex_error::VortexResult; -use vortex_mask::AllOr; -use vortex_mask::Mask; -use vortex_mask::MaskIter; -use vortex_scalar::Scalar; -use super::filter::ChunkFilter; -use super::filter::chunk_filters; -use super::filter::find_chunk_idx; use crate::ArrayRef; use crate::IntoArray; -use crate::arrays::BoolArray; use crate::arrays::ChunkedArray; use crate::arrays::ChunkedVTable; -use crate::arrays::ConstantArray; -use crate::arrays::chunked::compute::filter::FILTER_SLICES_SELECTIVITY_THRESHOLD; -use crate::builtins::ArrayBuiltins; -use crate::compute::MaskKernel; -use crate::compute::MaskKernelAdapter; -use crate::compute::mask; -use crate::register_kernel; +use crate::arrays::ScalarFnArrayExt; +use crate::compute::MaskReduce; +use crate::expr::EmptyOptions; +use crate::expr::mask::Mask as MaskExpr; use crate::validity::Validity; -impl MaskKernel for ChunkedVTable { - fn mask(&self, array: &ChunkedArray, mask: &Mask) -> VortexResult { - let new_dtype = array.dtype().as_nullable(); - let new_chunks = match mask.threshold_iter(FILTER_SLICES_SELECTIVITY_THRESHOLD) { - AllOr::All => unreachable!("handled in top-level mask"), - AllOr::None => unreachable!("handled in top-level mask"), - AllOr::Some(MaskIter::Indices(indices)) => mask_indices(array, indices, &new_dtype), - AllOr::Some(MaskIter::Slices(slices)) => { - mask_slices(array, slices.iter().cloned(), &new_dtype) - } - }?; - debug_assert_eq!(new_chunks.len(), array.nchunks()); - debug_assert_eq!( - new_chunks.iter().map(|x| x.len()).sum::(), - array.len() - ); - ChunkedArray::try_new(new_chunks, new_dtype).map(|c| c.into_array()) +impl MaskReduce for ChunkedVTable { + fn mask(array: &ChunkedArray, validity: &Validity) -> VortexResult> { + let Validity::Array(validity_array) = validity else { + // Precondition guarantees Array variant, but handle gracefully + return Ok(None); + }; + + let chunk_offsets = array.chunk_offsets(); + let new_chunks: Vec = array + .chunks() + .iter() + .enumerate() + .map(|(i, chunk)| { + let start: usize = chunk_offsets[i].try_into()?; + let end: usize = chunk_offsets[i + 1].try_into()?; + let chunk_mask = validity_array.slice(start..end)?; + MaskExpr.try_new_array(chunk.len(), EmptyOptions, [chunk.clone(), chunk_mask]) + }) + .collect::>()?; + + Ok(Some( + ChunkedArray::try_new(new_chunks, array.dtype().as_nullable())?.into_array(), + )) } } -register_kernel!(MaskKernelAdapter(ChunkedVTable).lift()); - -fn mask_indices( - array: &ChunkedArray, - indices: &[usize], - new_dtype: &DType, -) -> VortexResult> { - let mut new_chunks = Vec::with_capacity(array.nchunks()); - let mut current_chunk_id = 0; - let mut chunk_indices = Vec::::new(); - - let chunk_offsets = array.chunk_offsets(); - - for &set_index in indices { - let (chunk_id, index) = find_chunk_idx(set_index, &chunk_offsets)?; - if chunk_id != current_chunk_id { - let chunk = array.chunk(current_chunk_id).clone(); - let chunk_len = chunk.len(); - // chunk_indices contains indices to null out, but chunk.mask() expects - // mask=true to mean "retain". So we create a mask with bits set at indices - // to null, then invert it to get mask=true at indices to retain. - let mask = BoolArray::new( - !BitBuffer::from_indices(chunk_len, &chunk_indices), - Validity::NonNullable, - ) - .into_array(); - let masked_chunk = chunk.mask(mask)?; - // Advance the chunk forward, reset the chunk indices buffer. - chunk_indices = Vec::new(); - new_chunks.push(masked_chunk); - current_chunk_id += 1; - - while current_chunk_id < chunk_id { - // Chunks that are not affected by the mask, must still be casted to the correct dtype. - let chunk = array.chunk(current_chunk_id).cast(new_dtype.clone())?; - new_chunks.push(chunk); - current_chunk_id += 1; - } - } - - chunk_indices.push(index); - } - - if !chunk_indices.is_empty() { - let chunk = array.chunk(current_chunk_id).clone(); - let chunk_len = chunk.len(); - // Same inversion as above: invert the mask so mask=true means "retain" - let masked_chunk = chunk.mask( - BoolArray::new( - !BitBufferMut::from_indices(chunk_len, &chunk_indices).freeze(), - Validity::NonNullable, - ) - .into_array(), - )?; - new_chunks.push(masked_chunk); - current_chunk_id += 1; - } - - while current_chunk_id < array.nchunks() { - let chunk = array.chunk(current_chunk_id); - new_chunks.push(chunk.cast(new_dtype.clone())?); - current_chunk_id += 1; - } - - Ok(new_chunks) -} - -fn mask_slices( - array: &ChunkedArray, - slices: impl Iterator, - new_dtype: &DType, -) -> VortexResult> { - let chunked_filters = chunk_filters(array, slices)?; - - array - .chunks() - .iter() - .zip_eq(chunked_filters) - .map(|(chunk, chunk_filter)| -> VortexResult { - match chunk_filter { - ChunkFilter::All => { - // entire chunk is masked out - Ok( - ConstantArray::new(Scalar::null(new_dtype.clone()), chunk.len()) - .into_array(), - ) - } - ChunkFilter::None => { - // entire chunk is not affected by mask - chunk.cast(new_dtype.clone()) - } - ChunkFilter::Slices(slices) => { - // Slices of indices that must be set to null - mask(chunk, &Mask::from_slices(chunk.len(), slices)) - } - } - }) - .process_results(|iter| iter.collect::>()) -} - #[cfg(test)] mod test { use rstest::rstest; diff --git a/vortex-array/src/arrays/chunked/compute/rules.rs b/vortex-array/src/arrays/chunked/compute/rules.rs index 4f453a1cd05..782273798d5 100644 --- a/vortex-array/src/arrays/chunked/compute/rules.rs +++ b/vortex-array/src/arrays/chunked/compute/rules.rs @@ -14,6 +14,7 @@ use crate::arrays::ConstantVTable; use crate::arrays::ScalarFnArray; use crate::compute::CastReduceAdaptor; use crate::expr::FillNullReduceAdaptor; +use crate::expr::MaskReduceAdaptor; use crate::expr::ZipReduceAdaptor; use crate::optimizer::ArrayOptimizer; use crate::optimizer::rules::ArrayParentReduceRule; @@ -24,6 +25,7 @@ pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new ParentRuleSet::lift(&ChunkedUnaryScalarFnPushDownRule), ParentRuleSet::lift(&ChunkedConstantScalarFnPushDownRule), ParentRuleSet::lift(&FillNullReduceAdaptor(ChunkedVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(ChunkedVTable)), ParentRuleSet::lift(&ZipReduceAdaptor(ChunkedVTable)), ]); diff --git a/vortex-array/src/arrays/constant/compute/mask.rs b/vortex-array/src/arrays/constant/compute/mask.rs index a1ad217e2b7..f97160c7a06 100644 --- a/vortex-array/src/arrays/constant/compute/mask.rs +++ b/vortex-array/src/arrays/constant/compute/mask.rs @@ -1,32 +1,28 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use vortex_dtype::Nullability; use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::ConstantArray; use crate::arrays::ConstantVTable; use crate::arrays::MaskedArray; -use crate::compute::MaskKernel; -use crate::compute::MaskKernelAdapter; -use crate::register_kernel; +use crate::compute::MaskReduce; use crate::validity::Validity; -impl MaskKernel for ConstantVTable { - fn mask(&self, array: &ConstantArray, mask: &Mask) -> VortexResult { - MaskedArray::try_new( - array.to_array(), - Validity::from_mask(!mask, Nullability::Nullable), - ) - .map(|a| a.into_array()) +impl MaskReduce for ConstantVTable { + fn mask(array: &ConstantArray, validity: &Validity) -> VortexResult> { + if array.scalar.is_null() { + // Already all nulls, masking has no effect. + return Ok(Some(array.to_array())); + } + Ok(Some( + MaskedArray::try_new(array.to_array(), validity.clone())?.into_array(), + )) } } -register_kernel!(MaskKernelAdapter(ConstantVTable).lift()); - #[cfg(test)] mod test { use crate::arrays::ConstantArray; diff --git a/vortex-array/src/arrays/constant/compute/rules.rs b/vortex-array/src/arrays/constant/compute/rules.rs index 63d5344d54d..66de2d15fc8 100644 --- a/vortex-array/src/arrays/constant/compute/rules.rs +++ b/vortex-array/src/arrays/constant/compute/rules.rs @@ -14,6 +14,7 @@ use crate::arrays::SliceReduceAdaptor; use crate::arrays::TakeReduceAdaptor; use crate::compute::CastReduceAdaptor; use crate::expr::FillNullReduceAdaptor; +use crate::expr::MaskReduceAdaptor; use crate::expr::NotReduceAdaptor; use crate::optimizer::rules::ArrayParentReduceRule; use crate::optimizer::rules::ParentRuleSet; @@ -24,6 +25,7 @@ pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::ne ParentRuleSet::lift(&NotReduceAdaptor(ConstantVTable)), ParentRuleSet::lift(&FillNullReduceAdaptor(ConstantVTable)), ParentRuleSet::lift(&FilterReduceAdaptor(ConstantVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(ConstantVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(ConstantVTable)), ParentRuleSet::lift(&TakeReduceAdaptor(ConstantVTable)), ]); diff --git a/vortex-array/src/arrays/decimal/compute/mask.rs b/vortex-array/src/arrays/decimal/compute/mask.rs new file mode 100644 index 00000000000..0f1707d49ee --- /dev/null +++ b/vortex-array/src/arrays/decimal/compute/mask.rs @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_dtype::match_each_decimal_value_type; +use vortex_error::VortexResult; + +use crate::ArrayRef; +use crate::IntoArray; +use crate::arrays::DecimalArray; +use crate::arrays::DecimalVTable; +use crate::compute::MaskReduce; +use crate::validity::Validity; +use crate::vtable::ValidityHelper; + +impl MaskReduce for DecimalVTable { + fn mask(array: &DecimalArray, validity: &Validity) -> VortexResult> { + Ok(Some(match_each_decimal_value_type!( + array.values_type(), + |D| { + // SAFETY: only changing validity, not data structure + unsafe { + DecimalArray::new_unchecked( + array.buffer::(), + array.decimal_dtype(), + array.validity().clone().and(validity.clone()), + ) + } + .into_array() + } + ))) + } +} diff --git a/vortex-array/src/arrays/decimal/compute/mod.rs b/vortex-array/src/arrays/decimal/compute/mod.rs index 2af9e6bbdfc..11a509b240d 100644 --- a/vortex-array/src/arrays/decimal/compute/mod.rs +++ b/vortex-array/src/arrays/decimal/compute/mod.rs @@ -6,6 +6,7 @@ mod cast; mod fill_null; mod is_constant; mod is_sorted; +mod mask; mod min_max; pub mod rules; mod sum; diff --git a/vortex-array/src/arrays/decimal/compute/rules.rs b/vortex-array/src/arrays/decimal/compute/rules.rs index f60e2f7519c..d1103233d12 100644 --- a/vortex-array/src/arrays/decimal/compute/rules.rs +++ b/vortex-array/src/arrays/decimal/compute/rules.rs @@ -14,12 +14,14 @@ use crate::arrays::MaskedArray; use crate::arrays::MaskedVTable; use crate::arrays::SliceReduce; use crate::arrays::SliceReduceAdaptor; +use crate::expr::MaskReduceAdaptor; use crate::optimizer::rules::ArrayParentReduceRule; use crate::optimizer::rules::ParentRuleSet; use crate::vtable::ValidityHelper; pub(crate) static RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&DecimalMaskedValidityRule), + ParentRuleSet::lift(&MaskReduceAdaptor(DecimalVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(DecimalVTable)), ]); diff --git a/vortex-array/src/arrays/dict/compute/mask.rs b/vortex-array/src/arrays/dict/compute/mask.rs new file mode 100644 index 00000000000..fe94791e14e --- /dev/null +++ b/vortex-array/src/arrays/dict/compute/mask.rs @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexResult; + +use crate::ArrayRef; +use crate::IntoArray; +use crate::arrays::DictArray; +use crate::arrays::DictVTable; +use crate::arrays::MaskedArray; +use crate::compute::MaskReduce; +use crate::validity::Validity; + +impl MaskReduce for DictVTable { + fn mask(array: &DictArray, validity: &Validity) -> VortexResult> { + let masked_codes = + MaskedArray::try_new(array.codes().clone(), validity.clone())?.into_array(); + // SAFETY: masking codes doesn't change dict invariants + Ok(Some(unsafe { + DictArray::new_unchecked(masked_codes, array.values().clone()).into_array() + })) + } +} diff --git a/vortex-array/src/arrays/dict/compute/min_max.rs b/vortex-array/src/arrays/dict/compute/min_max.rs index 32a041784fc..c9161c35e70 100644 --- a/vortex-array/src/arrays/dict/compute/min_max.rs +++ b/vortex-array/src/arrays/dict/compute/min_max.rs @@ -10,7 +10,6 @@ use crate::Array as _; use crate::compute::MinMaxKernel; use crate::compute::MinMaxKernelAdapter; use crate::compute::MinMaxResult; -use crate::compute::mask; use crate::compute::min_max; use crate::register_kernel; @@ -28,7 +27,7 @@ impl MinMaxKernel for DictVTable { // Slow path: compute which values are unreferenced and mask them out let unreferenced_mask = Mask::from_buffer(array.compute_referenced_values_mask(false)?); - min_max(&mask(array.values(), &unreferenced_mask)?) + min_max(&array.values().mask(&unreferenced_mask)?) } } diff --git a/vortex-array/src/arrays/dict/compute/mod.rs b/vortex-array/src/arrays/dict/compute/mod.rs index 056b151ec06..97cd46fced7 100644 --- a/vortex-array/src/arrays/dict/compute/mod.rs +++ b/vortex-array/src/arrays/dict/compute/mod.rs @@ -7,6 +7,7 @@ mod fill_null; mod is_constant; mod is_sorted; mod like; +mod mask; mod min_max; pub(crate) mod rules; mod slice; diff --git a/vortex-array/src/arrays/dict/compute/rules.rs b/vortex-array/src/arrays/dict/compute/rules.rs index cc6df43b650..53c818560e8 100644 --- a/vortex-array/src/arrays/dict/compute/rules.rs +++ b/vortex-array/src/arrays/dict/compute/rules.rs @@ -19,6 +19,7 @@ use crate::arrays::SliceReduceAdaptor; use crate::builtins::ArrayBuiltins; use crate::compute::CastReduceAdaptor; use crate::expr::Cast; +use crate::expr::MaskReduceAdaptor; use crate::expr::Pack; use crate::optimizer::ArrayOptimizer; use crate::optimizer::rules::ArrayParentReduceRule; @@ -27,6 +28,7 @@ use crate::optimizer::rules::ParentRuleSet; pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&FilterReduceAdaptor(DictVTable)), ParentRuleSet::lift(&CastReduceAdaptor(DictVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(DictVTable)), ParentRuleSet::lift(&DictionaryScalarFnValuesPushDownRule), ParentRuleSet::lift(&DictionaryScalarFnCodesPullUpRule), ParentRuleSet::lift(&SliceReduceAdaptor(DictVTable)), diff --git a/vortex-array/src/arrays/extension/compute/mask.rs b/vortex-array/src/arrays/extension/compute/mask.rs index 1b44447b9dc..712d91bf01a 100644 --- a/vortex-array/src/arrays/extension/compute/mask.rs +++ b/vortex-array/src/arrays/extension/compute/mask.rs @@ -2,32 +2,35 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_error::VortexResult; -use vortex_mask::Mask; -use crate::Array; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::ExtensionArray; use crate::arrays::ExtensionVTable; -use crate::compute::MaskKernel; -use crate::compute::MaskKernelAdapter; -use crate::compute::mask; -use crate::register_kernel; +use crate::arrays::ScalarFnArrayExt; +use crate::compute::MaskReduce; +use crate::expr::EmptyOptions; +use crate::expr::mask::Mask as MaskExpr; +use crate::validity::Validity; -impl MaskKernel for ExtensionVTable { - fn mask(&self, array: &ExtensionArray, mask_array: &Mask) -> VortexResult { - // Use compute::mask directly since mask_array has compute::mask semantics (true=null) - let masked_storage = mask(array.storage(), mask_array)?; - assert!(masked_storage.dtype().is_nullable()); - - Ok(ExtensionArray::new( - array - .ext_dtype() - .with_nullability(masked_storage.dtype().nullability()), - masked_storage, - ) - .into_array()) +impl MaskReduce for ExtensionVTable { + fn mask(array: &ExtensionArray, validity: &Validity) -> VortexResult> { + let Validity::Array(keep_mask) = validity else { + return Ok(None); + }; + let masked_storage = MaskExpr.try_new_array( + array.storage().len(), + EmptyOptions, + [array.storage().clone(), keep_mask.clone()], + )?; + Ok(Some( + ExtensionArray::new( + array + .ext_dtype() + .with_nullability(masked_storage.dtype().nullability()), + masked_storage, + ) + .into_array(), + )) } } - -register_kernel!(MaskKernelAdapter(ExtensionVTable).lift()); diff --git a/vortex-array/src/arrays/extension/compute/rules.rs b/vortex-array/src/arrays/extension/compute/rules.rs index f18687d570a..ee3b883a9f7 100644 --- a/vortex-array/src/arrays/extension/compute/rules.rs +++ b/vortex-array/src/arrays/extension/compute/rules.rs @@ -12,6 +12,7 @@ use crate::arrays::FilterReduceAdaptor; use crate::arrays::FilterVTable; use crate::arrays::SliceReduceAdaptor; use crate::compute::CastReduceAdaptor; +use crate::expr::MaskReduceAdaptor; use crate::optimizer::rules::ArrayParentReduceRule; use crate::optimizer::rules::ParentRuleSet; @@ -19,6 +20,7 @@ pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::n ParentRuleSet::lift(&ExtensionFilterPushDownRule), ParentRuleSet::lift(&CastReduceAdaptor(ExtensionVTable)), ParentRuleSet::lift(&FilterReduceAdaptor(ExtensionVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(ExtensionVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(ExtensionVTable)), ]); diff --git a/vortex-array/src/arrays/fixed_size_list/compute/mask.rs b/vortex-array/src/arrays/fixed_size_list/compute/mask.rs index 26232addc37..3ca57bea4ea 100644 --- a/vortex-array/src/arrays/fixed_size_list/compute/mask.rs +++ b/vortex-array/src/arrays/fixed_size_list/compute/mask.rs @@ -2,34 +2,28 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::FixedSizeListArray; use crate::arrays::FixedSizeListVTable; -use crate::compute::MaskKernel; -use crate::compute::MaskKernelAdapter; -use crate::register_kernel; +use crate::compute::MaskReduce; +use crate::validity::Validity; use crate::vtable::ValidityHelper; -/// Mask implementation for [`FixedSizeListArray`]. -/// -/// Applies a validity mask to the array without modifying the underlying element data. -impl MaskKernel for FixedSizeListVTable { - fn mask(&self, array: &FixedSizeListArray, mask: &Mask) -> VortexResult { - // SAFETY: The only thing that changes here is the validity mask, which will have the same - // length. So as long as the original array is valid, this is also valid. - Ok(unsafe { - FixedSizeListArray::new_unchecked( - array.elements().clone(), - array.list_size(), - array.validity().mask(mask), - array.len(), - ) - } - .into_array()) +impl MaskReduce for FixedSizeListVTable { + fn mask(array: &FixedSizeListArray, validity: &Validity) -> VortexResult> { + // SAFETY: only changing validity, not data structure + Ok(Some( + unsafe { + FixedSizeListArray::new_unchecked( + array.elements().clone(), + array.list_size(), + array.validity().clone().and(validity.clone()), + array.len(), + ) + } + .into_array(), + )) } } - -register_kernel!(MaskKernelAdapter(FixedSizeListVTable).lift()); diff --git a/vortex-array/src/arrays/fixed_size_list/compute/rules.rs b/vortex-array/src/arrays/fixed_size_list/compute/rules.rs index 5ba1e777ec2..af3bc242a14 100644 --- a/vortex-array/src/arrays/fixed_size_list/compute/rules.rs +++ b/vortex-array/src/arrays/fixed_size_list/compute/rules.rs @@ -4,9 +4,11 @@ use crate::arrays::FixedSizeListVTable; use crate::arrays::SliceReduceAdaptor; use crate::compute::CastReduceAdaptor; +use crate::expr::MaskReduceAdaptor; use crate::optimizer::rules::ParentRuleSet; pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&CastReduceAdaptor(FixedSizeListVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(FixedSizeListVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(FixedSizeListVTable)), ]); diff --git a/vortex-array/src/arrays/list/compute/mask.rs b/vortex-array/src/arrays/list/compute/mask.rs index 877386133c1..7cb8355f404 100644 --- a/vortex-array/src/arrays/list/compute/mask.rs +++ b/vortex-array/src/arrays/list/compute/mask.rs @@ -2,26 +2,22 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::ListArray; use crate::arrays::ListVTable; -use crate::compute::MaskKernel; -use crate::compute::MaskKernelAdapter; -use crate::register_kernel; +use crate::compute::MaskReduce; +use crate::validity::Validity; use crate::vtable::ValidityHelper; -impl MaskKernel for ListVTable { - fn mask(&self, array: &ListArray, mask: &Mask) -> VortexResult { +impl MaskReduce for ListVTable { + fn mask(array: &ListArray, validity: &Validity) -> VortexResult> { ListArray::try_new( array.elements().clone(), array.offsets().clone(), - array.validity().mask(mask), + array.validity().clone().and(validity.clone()), ) - .map(|a| a.into_array()) + .map(|a| Some(a.into_array())) } } - -register_kernel!(MaskKernelAdapter(ListVTable).lift()); diff --git a/vortex-array/src/arrays/list/compute/rules.rs b/vortex-array/src/arrays/list/compute/rules.rs index 70900ea11e8..8ea7043fd1f 100644 --- a/vortex-array/src/arrays/list/compute/rules.rs +++ b/vortex-array/src/arrays/list/compute/rules.rs @@ -4,9 +4,11 @@ use crate::arrays::ListVTable; use crate::arrays::SliceReduceAdaptor; use crate::compute::CastReduceAdaptor; +use crate::expr::MaskReduceAdaptor; use crate::optimizer::rules::ParentRuleSet; pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&CastReduceAdaptor(ListVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(ListVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(ListVTable)), ]); diff --git a/vortex-array/src/arrays/listview/compute/mask.rs b/vortex-array/src/arrays/listview/compute/mask.rs index e34ec88527f..16fa41f89fb 100644 --- a/vortex-array/src/arrays/listview/compute/mask.rs +++ b/vortex-array/src/arrays/listview/compute/mask.rs @@ -2,32 +2,29 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::ListViewArray; use crate::arrays::ListViewVTable; -use crate::compute::MaskKernel; -use crate::compute::MaskKernelAdapter; -use crate::register_kernel; +use crate::compute::MaskReduce; +use crate::validity::Validity; use crate::vtable::ValidityHelper; -impl MaskKernel for ListViewVTable { - fn mask(&self, array: &ListViewArray, mask: &Mask) -> VortexResult { - // SAFETY: Since we are only masking the validity and everything else comes from an already - // valid `ListViewArray`, all of the invariants are still upheld. - Ok(unsafe { - ListViewArray::new_unchecked( - array.elements().clone(), - array.offsets().clone(), - array.sizes().clone(), - array.validity().mask(mask), - ) - .with_zero_copy_to_list(array.is_zero_copy_to_list()) - } - .into_array()) +impl MaskReduce for ListViewVTable { + fn mask(array: &ListViewArray, validity: &Validity) -> VortexResult> { + // SAFETY: only changing validity, not data structure + Ok(Some( + unsafe { + ListViewArray::new_unchecked( + array.elements().clone(), + array.offsets().clone(), + array.sizes().clone(), + array.validity().clone().and(validity.clone()), + ) + .with_zero_copy_to_list(array.is_zero_copy_to_list()) + } + .into_array(), + )) } } - -register_kernel!(MaskKernelAdapter(ListViewVTable).lift()); diff --git a/vortex-array/src/arrays/listview/compute/rules.rs b/vortex-array/src/arrays/listview/compute/rules.rs index 24c25fa1b3e..a293ae4655f 100644 --- a/vortex-array/src/arrays/listview/compute/rules.rs +++ b/vortex-array/src/arrays/listview/compute/rules.rs @@ -12,6 +12,7 @@ use crate::arrays::ListViewArray; use crate::arrays::ListViewVTable; use crate::arrays::SliceReduceAdaptor; use crate::compute::CastReduceAdaptor; +use crate::expr::MaskReduceAdaptor; use crate::optimizer::rules::ArrayParentReduceRule; use crate::optimizer::rules::ParentRuleSet; use crate::vtable::ValidityHelper; @@ -19,6 +20,7 @@ use crate::vtable::ValidityHelper; pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&ListViewFilterPushDown), ParentRuleSet::lift(&CastReduceAdaptor(ListViewVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(ListViewVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(ListViewVTable)), ]); diff --git a/vortex-array/src/arrays/masked/compute/mask.rs b/vortex-array/src/arrays/masked/compute/mask.rs index a8d3c3eacd2..2220d80f34a 100644 --- a/vortex-array/src/arrays/masked/compute/mask.rs +++ b/vortex-array/src/arrays/masked/compute/mask.rs @@ -2,29 +2,24 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_error::VortexResult; -use vortex_mask::Mask as MaskType; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::MaskedArray; use crate::arrays::MaskedVTable; -use crate::compute::MaskKernel; -use crate::compute::MaskKernelAdapter; -use crate::register_kernel; +use crate::compute::MaskReduce; +use crate::validity::Validity; use crate::vtable::ValidityHelper; -impl MaskKernel for MaskedVTable { - fn mask(&self, array: &MaskedArray, mask_arg: &MaskType) -> VortexResult { - // Combine the mask with the existing validity - // The child remains unchanged (no nulls), only validity is updated - let combined_validity = array.validity().mask(mask_arg); - - Ok(MaskedArray::try_new(array.child.clone(), combined_validity)?.into_array()) +impl MaskReduce for MaskedVTable { + fn mask(array: &MaskedArray, validity: &Validity) -> VortexResult> { + let combined_validity = array.validity().clone().and(validity.clone()); + Ok(Some( + MaskedArray::try_new(array.child.clone(), combined_validity)?.into_array(), + )) } } -register_kernel!(MaskKernelAdapter(MaskedVTable).lift()); - #[cfg(test)] mod tests { use rstest::rstest; diff --git a/vortex-array/src/arrays/masked/compute/rules.rs b/vortex-array/src/arrays/masked/compute/rules.rs index 597f1844bb4..b4b7a527259 100644 --- a/vortex-array/src/arrays/masked/compute/rules.rs +++ b/vortex-array/src/arrays/masked/compute/rules.rs @@ -4,9 +4,11 @@ use crate::arrays::FilterReduceAdaptor; use crate::arrays::MaskedVTable; use crate::arrays::SliceReduceAdaptor; +use crate::compute::MaskReduceAdaptor; use crate::optimizer::rules::ParentRuleSet; pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&FilterReduceAdaptor(MaskedVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(MaskedVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(MaskedVTable)), ]); diff --git a/vortex-array/src/arrays/masked/execute.rs b/vortex-array/src/arrays/masked/execute.rs index 00e803de64e..496aad3a58a 100644 --- a/vortex-array/src/arrays/masked/execute.rs +++ b/vortex-array/src/arrays/masked/execute.rs @@ -153,8 +153,11 @@ fn mask_validity_extension( // For extension arrays, we need to mask the underlying storage let storage = array.storage().clone().execute::(ctx)?; let masked_storage = mask_validity_canonical(storage, mask, ctx)?; + let masked_storage = masked_storage.into_array(); Ok(ExtensionArray::new( - array.ext_dtype().clone(), - masked_storage.into_array(), + array + .ext_dtype() + .with_nullability(masked_storage.dtype().nullability()), + masked_storage, )) } diff --git a/vortex-array/src/arrays/null/compute/mask.rs b/vortex-array/src/arrays/null/compute/mask.rs index c513efc3bd0..d578bfef6e9 100644 --- a/vortex-array/src/arrays/null/compute/mask.rs +++ b/vortex-array/src/arrays/null/compute/mask.rs @@ -2,19 +2,16 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::ArrayRef; use crate::arrays::NullArray; use crate::arrays::NullVTable; -use crate::compute::MaskKernel; -use crate::compute::MaskKernelAdapter; -use crate::register_kernel; +use crate::compute::MaskReduce; +use crate::validity::Validity; -impl MaskKernel for NullVTable { - fn mask(&self, array: &NullArray, _mask: &Mask) -> VortexResult { - Ok(array.to_array()) +impl MaskReduce for NullVTable { + fn mask(array: &NullArray, _validity: &Validity) -> VortexResult> { + // Null array is already all nulls, masking has no effect. + Ok(Some(array.to_array())) } } - -register_kernel!(MaskKernelAdapter(NullVTable).lift()); diff --git a/vortex-array/src/arrays/null/compute/rules.rs b/vortex-array/src/arrays/null/compute/rules.rs index 7fd25840dfd..645422884c2 100644 --- a/vortex-array/src/arrays/null/compute/rules.rs +++ b/vortex-array/src/arrays/null/compute/rules.rs @@ -6,11 +6,13 @@ use crate::arrays::NullVTable; use crate::arrays::SliceReduceAdaptor; use crate::arrays::TakeReduceAdaptor; use crate::compute::CastReduceAdaptor; +use crate::expr::MaskReduceAdaptor; use crate::optimizer::rules::ParentRuleSet; pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&FilterReduceAdaptor(NullVTable)), ParentRuleSet::lift(&CastReduceAdaptor(NullVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(NullVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(NullVTable)), ParentRuleSet::lift(&TakeReduceAdaptor(NullVTable)), ]); diff --git a/vortex-array/src/arrays/primitive/compute/mask.rs b/vortex-array/src/arrays/primitive/compute/mask.rs index 545d0847618..d4a58b775d4 100644 --- a/vortex-array/src/arrays/primitive/compute/mask.rs +++ b/vortex-array/src/arrays/primitive/compute/mask.rs @@ -2,35 +2,29 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::PrimitiveVTable; use crate::arrays::primitive::PrimitiveArray; -use crate::compute::MaskKernel; -use crate::compute::MaskKernelAdapter; -use crate::register_kernel; +use crate::compute::MaskReduce; +use crate::validity::Validity; use crate::vtable::ValidityHelper; -impl MaskKernel for PrimitiveVTable { - fn mask(&self, array: &PrimitiveArray, mask: &Mask) -> VortexResult { - let validity = array.validity().mask(mask); - +impl MaskReduce for PrimitiveVTable { + fn mask(array: &PrimitiveArray, validity: &Validity) -> VortexResult> { // SAFETY: validity and data buffer still have same length - Ok(unsafe { + Ok(Some(unsafe { PrimitiveArray::new_unchecked_from_handle( array.buffer_handle().clone(), array.ptype(), - validity, + array.validity().clone().and(validity.clone()), ) .into_array() - }) + })) } } -register_kernel!(MaskKernelAdapter(PrimitiveVTable).lift()); - #[cfg(test)] mod test { use rstest::rstest; diff --git a/vortex-array/src/arrays/primitive/compute/rules.rs b/vortex-array/src/arrays/primitive/compute/rules.rs index e2bcc53327f..edaa7593c13 100644 --- a/vortex-array/src/arrays/primitive/compute/rules.rs +++ b/vortex-array/src/arrays/primitive/compute/rules.rs @@ -11,12 +11,14 @@ use crate::arrays::MaskedVTable; use crate::arrays::PrimitiveArray; use crate::arrays::PrimitiveVTable; use crate::arrays::SliceReduceAdaptor; +use crate::compute::MaskReduceAdaptor; use crate::optimizer::rules::ArrayParentReduceRule; use crate::optimizer::rules::ParentRuleSet; use crate::vtable::ValidityHelper; pub(crate) const RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&PrimitiveMaskedValidityRule), + ParentRuleSet::lift(&MaskReduceAdaptor(PrimitiveVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(PrimitiveVTable)), ]); diff --git a/vortex-array/src/arrays/struct_/compute/mask.rs b/vortex-array/src/arrays/struct_/compute/mask.rs index 75f40f81b17..1b3e6cff588 100644 --- a/vortex-array/src/arrays/struct_/compute/mask.rs +++ b/vortex-array/src/arrays/struct_/compute/mask.rs @@ -2,28 +2,23 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::StructArray; use crate::arrays::StructVTable; -use crate::compute::MaskKernel; -use crate::compute::MaskKernelAdapter; -use crate::register_kernel; +use crate::compute::MaskReduce; +use crate::validity::Validity; use crate::vtable::ValidityHelper; -impl MaskKernel for StructVTable { - fn mask(&self, array: &StructArray, filter_mask: &Mask) -> VortexResult { - let validity = array.validity().mask(filter_mask); - +impl MaskReduce for StructVTable { + fn mask(array: &StructArray, validity: &Validity) -> VortexResult> { StructArray::try_new_with_dtype( array.unmasked_fields().clone(), array.struct_fields().clone(), array.len(), - validity, + array.validity().clone().and(validity.clone()), ) - .map(|a| a.into_array()) + .map(|a| Some(a.into_array())) } } -register_kernel!(MaskKernelAdapter(StructVTable).lift()); diff --git a/vortex-array/src/arrays/struct_/compute/rules.rs b/vortex-array/src/arrays/struct_/compute/rules.rs index ff2df148c23..940d5a54c78 100644 --- a/vortex-array/src/arrays/struct_/compute/rules.rs +++ b/vortex-array/src/arrays/struct_/compute/rules.rs @@ -15,6 +15,7 @@ use crate::arrays::SliceReduceAdaptor; use crate::arrays::StructArray; use crate::arrays::StructVTable; use crate::builtins::ArrayBuiltins; +use crate::compute::MaskReduceAdaptor; use crate::expr::Cast; use crate::expr::EmptyOptions; use crate::expr::GetItem; @@ -27,6 +28,7 @@ use crate::vtable::ValidityHelper; pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&StructCastPushDownRule), ParentRuleSet::lift(&StructGetItemRule), + ParentRuleSet::lift(&MaskReduceAdaptor(StructVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(StructVTable)), ]); diff --git a/vortex-array/src/arrays/varbin/compute/mask.rs b/vortex-array/src/arrays/varbin/compute/mask.rs index 5119ebbeb82..991f6d3c67a 100644 --- a/vortex-array/src/arrays/varbin/compute/mask.rs +++ b/vortex-array/src/arrays/varbin/compute/mask.rs @@ -2,31 +2,29 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::VarBinVTable; use crate::arrays::varbin::VarBinArray; -use crate::compute::MaskKernel; -use crate::compute::MaskKernelAdapter; -use crate::register_kernel; +use crate::compute::MaskReduce; +use crate::validity::Validity; use crate::vtable::ValidityHelper; -impl MaskKernel for VarBinVTable { - fn mask(&self, array: &VarBinArray, mask: &Mask) -> VortexResult { - Ok(VarBinArray::try_new( - array.offsets().clone(), - array.bytes().clone(), - array.dtype().as_nullable(), - array.validity().mask(mask), - )? - .into_array()) +impl MaskReduce for VarBinVTable { + fn mask(array: &VarBinArray, validity: &Validity) -> VortexResult> { + Ok(Some( + VarBinArray::try_new( + array.offsets().clone(), + array.bytes().clone(), + array.dtype().as_nullable(), + array.validity().clone().and(validity.clone()), + )? + .into_array(), + )) } } -register_kernel!(MaskKernelAdapter(VarBinVTable).lift()); - #[cfg(test)] mod test { use vortex_dtype::DType; diff --git a/vortex-array/src/arrays/varbin/compute/rules.rs b/vortex-array/src/arrays/varbin/compute/rules.rs index df9f7f7913e..478d97e4746 100644 --- a/vortex-array/src/arrays/varbin/compute/rules.rs +++ b/vortex-array/src/arrays/varbin/compute/rules.rs @@ -4,9 +4,11 @@ use crate::arrays::SliceReduceAdaptor; use crate::arrays::VarBinVTable; use crate::compute::CastReduceAdaptor; +use crate::expr::MaskReduceAdaptor; use crate::optimizer::rules::ParentRuleSet; pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&CastReduceAdaptor(VarBinVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(VarBinVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(VarBinVTable)), ]); diff --git a/vortex-array/src/arrays/varbinview/compute/mask.rs b/vortex-array/src/arrays/varbinview/compute/mask.rs index dd32dac5c8b..e2001605f4a 100644 --- a/vortex-array/src/arrays/varbinview/compute/mask.rs +++ b/vortex-array/src/arrays/varbinview/compute/mask.rs @@ -2,34 +2,32 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_error::VortexResult; -use vortex_mask::Mask; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::VarBinViewArray; use crate::arrays::VarBinViewVTable; -use crate::compute::MaskKernel; -use crate::compute::MaskKernelAdapter; -use crate::register_kernel; +use crate::compute::MaskReduce; +use crate::validity::Validity; use crate::vtable::ValidityHelper; -impl MaskKernel for VarBinViewVTable { - fn mask(&self, array: &VarBinViewArray, mask: &Mask) -> VortexResult { - // SAFETY: masking the validity does not affect the invariants +impl MaskReduce for VarBinViewVTable { + fn mask(array: &VarBinViewArray, validity: &Validity) -> VortexResult> { + // SAFETY: only changing validity, not data structure unsafe { - Ok(VarBinViewArray::new_handle_unchecked( - array.views_handle().clone(), - array.buffers().clone(), - array.dtype().as_nullable(), - array.validity().mask(mask), - ) - .into_array()) + Ok(Some( + VarBinViewArray::new_handle_unchecked( + array.views_handle().clone(), + array.buffers().clone(), + array.dtype().as_nullable(), + array.validity().clone().and(validity.clone()), + ) + .into_array(), + )) } } } -register_kernel!(MaskKernelAdapter(VarBinViewVTable).lift()); - #[cfg(test)] mod tests { use crate::arrays::VarBinViewArray; diff --git a/vortex-array/src/arrays/varbinview/compute/rules.rs b/vortex-array/src/arrays/varbinview/compute/rules.rs index 9b1900ef277..0bbd12d1540 100644 --- a/vortex-array/src/arrays/varbinview/compute/rules.rs +++ b/vortex-array/src/arrays/varbinview/compute/rules.rs @@ -3,9 +3,11 @@ use crate::arrays::SliceReduceAdaptor; use crate::arrays::VarBinViewVTable; use crate::compute::CastReduceAdaptor; +use crate::expr::MaskReduceAdaptor; use crate::optimizer::rules::ParentRuleSet; pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&CastReduceAdaptor(VarBinViewVTable)), + ParentRuleSet::lift(&MaskReduceAdaptor(VarBinViewVTable)), ParentRuleSet::lift(&SliceReduceAdaptor(VarBinViewVTable)), ]); diff --git a/vortex-array/src/compute/mask.rs b/vortex-array/src/compute/mask.rs index b92aa0ab3b8..af7808ca5c9 100644 --- a/vortex-array/src/compute/mask.rs +++ b/vortex-array/src/compute/mask.rs @@ -1,48 +1,31 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use std::sync::LazyLock; +use std::ops::Not; -use arcref::ArcRef; -use arrow_array::BooleanArray; -use vortex_dtype::DType; -use vortex_error::VortexError; use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_error::vortex_err; use vortex_mask::Mask; use vortex_scalar::Scalar; use crate::Array; use crate::ArrayRef; use crate::IntoArray; +use crate::arrays::BoolArray; use crate::arrays::ConstantArray; -use crate::arrow::FromArrowArray; -use crate::arrow::IntoArrowArray; +use crate::arrays::ScalarFnArrayExt; use crate::builtins::ArrayBuiltins; -use crate::compute::ComputeFn; -use crate::compute::ComputeFnVTable; -use crate::compute::InvocationArgs; -use crate::compute::Kernel; -use crate::compute::Output; -use crate::vtable::VTable; - -static MASK_FN: LazyLock = LazyLock::new(|| { - let compute = ComputeFn::new("mask".into(), ArcRef::new_ref(&MaskFn)); - for kernel in inventory::iter:: { - compute.register_kernel(kernel.0.clone()); - } - compute -}); - -pub(crate) fn warm_up_vtable() -> usize { - MASK_FN.kernels().len() -} +use crate::expr::EmptyOptions; +use crate::expr::mask::Mask as MaskExpr; +use crate::validity::Validity; /// Replace values with null where the mask is true. /// /// The returned array is nullable but otherwise has the same dtype and length as `array`. /// +/// This function returns a lazy `ScalarFnArray` wrapping the [`Mask`](crate::expr::mask::Mask) +/// expression that defers the actual masking operation until execution time. The mask is inverted +/// (true=mask-out becomes true=keep) and passed as a boolean child to the expression. +/// /// # Examples /// /// ``` @@ -69,132 +52,45 @@ pub(crate) fn warm_up_vtable() -> usize { /// # } /// ``` /// -pub fn mask(array: &dyn Array, mask: &Mask) -> VortexResult { - MASK_FN - .invoke(&InvocationArgs { - inputs: &[array.into(), mask.into()], - options: &(), - })? - .unwrap_array() -} - -pub struct MaskKernelRef(ArcRef); -inventory::collect!(MaskKernelRef); - -pub trait MaskKernel: VTable { - /// Replace masked values with null in array. - fn mask(&self, array: &Self::Array, mask: &Mask) -> VortexResult; -} - -#[derive(Debug)] -pub struct MaskKernelAdapter(pub V); - -impl MaskKernelAdapter { - pub const fn lift(&'static self) -> MaskKernelRef { - MaskKernelRef(ArcRef::new_ref(self)) +impl dyn Array + '_ { + /// Replace values with null where the mask is true. + /// + /// See the free function [`mask`] for full documentation. + pub fn mask(&self, mask: &Mask) -> VortexResult { + compute_mask(self, mask) } } -impl Kernel for MaskKernelAdapter { - fn invoke(&self, args: &InvocationArgs) -> VortexResult> { - let inputs = MaskArgs::try_from(args)?; - let Some(array) = inputs.array.as_opt::() else { - return Ok(None); - }; - Ok(Some(V::mask(&self.0, array, inputs.mask)?.into())) - } +pub fn mask(array: &dyn Array, mask: &Mask) -> VortexResult { + compute_mask(array, mask) } -struct MaskFn; - -impl ComputeFnVTable for MaskFn { - fn invoke( - &self, - args: &InvocationArgs, - kernels: &[ArcRef], - ) -> VortexResult { - let MaskArgs { array, mask } = MaskArgs::try_from(args)?; - - let mask_true_count = mask.true_count(); - if mask_true_count == 0 { - // Fast-path for empty mask - return Ok(array.to_array().cast(array.dtype().as_nullable())?.into()); - } +fn compute_mask(array: &dyn Array, mask: &Mask) -> VortexResult { + let mask_true_count = mask.true_count(); - if mask_true_count == mask.len() { - // Fast-path for full mask. - return Ok( - ConstantArray::new(Scalar::null(array.dtype().as_nullable()), array.len()) - .into_array() - .into(), - ); - } - - // Do nothing if the array is already all nulls. - if array.all_invalid()? { - return Ok(array.to_array().into()); - } - - for kernel in kernels { - if let Some(output) = kernel.invoke(args)? { - return Ok(output); - } - } - - // Fallback: implement using Arrow kernels. - tracing::debug!("No mask implementation found for {}", array.encoding_id()); - - let array_ref = array.to_array().into_arrow_preferred()?; - let mask = BooleanArray::new(mask.to_bit_buffer().into(), None); - - let masked = arrow_select::nullif::nullif(array_ref.as_ref(), &mask)?; - - Ok(ArrayRef::from_arrow(masked.as_ref(), true)?.into()) - } - - fn return_dtype(&self, args: &InvocationArgs) -> VortexResult { - let MaskArgs { array, .. } = MaskArgs::try_from(args)?; - Ok(array.dtype().as_nullable()) + if mask_true_count == 0 { + // Fast-path for empty mask: nothing to mask out. + return array.to_array().cast(array.dtype().as_nullable()); } - fn return_len(&self, args: &InvocationArgs) -> VortexResult { - let MaskArgs { array, mask } = MaskArgs::try_from(args)?; - - if mask.len() != array.len() { - vortex_bail!( - "mask.len() is {}, does not equal array.len() of {}", - mask.len(), - array.len() - ); - } - - Ok(mask.len()) + if mask_true_count == mask.len() { + // Fast-path for full mask: everything is masked out. + return Ok( + ConstantArray::new(Scalar::null(array.dtype().as_nullable()), array.len()).into_array(), + ); } - fn is_elementwise(&self) -> bool { - true + // Do nothing if the array is already all nulls. + if array.all_invalid()? { + return Ok(array.to_array()); } -} - -struct MaskArgs<'a> { - array: &'a dyn Array, - mask: &'a Mask, -} - -impl<'a> TryFrom<&InvocationArgs<'a>> for MaskArgs<'a> { - type Error = VortexError; - fn try_from(value: &InvocationArgs<'a>) -> Result { - if value.inputs.len() != 2 { - vortex_bail!("Mask function requires 2 arguments"); - } - let array = value.inputs[0] - .array() - .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?; - let mask = value.inputs[1] - .mask() - .ok_or_else(|| vortex_err!("Expected input 1 to be a mask"))?; - - Ok(MaskArgs { array, mask }) - } + // Lazy wrap: invert the mask (true=mask_out → true=keep) and create a ScalarFnArray + // wrapping the Mask expression. + let keep_mask = BoolArray::new(mask.to_bit_buffer().not(), Validity::NonNullable); + MaskExpr.try_new_array( + array.len(), + EmptyOptions, + [array.to_array(), keep_mask.into_array()], + ) } diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index 8d603de9a24..eb7977881e7 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -55,6 +55,10 @@ pub use crate::expr::FillNullExecuteAdaptor; pub use crate::expr::FillNullKernel; pub use crate::expr::FillNullReduce; pub use crate::expr::FillNullReduceAdaptor; +pub use crate::expr::MaskExecuteAdaptor; +pub use crate::expr::MaskKernel; +pub use crate::expr::MaskReduce; +pub use crate::expr::MaskReduceAdaptor; pub use crate::expr::NotExecuteAdaptor; pub use crate::expr::NotKernel; pub use crate::expr::NotReduce; @@ -101,7 +105,6 @@ pub fn warm_up_vtables() { is_sorted::warm_up_vtable(); like::warm_up_vtable(); list_contains::warm_up_vtable(); - mask::warm_up_vtable(); min_max::warm_up_vtable(); nan_count::warm_up_vtable(); sum::warm_up_vtable(); diff --git a/vortex-array/src/expr/exprs/get_item.rs b/vortex-array/src/expr/exprs/get_item.rs index ca0d93740f5..7f9bf4f4065 100644 --- a/vortex-array/src/expr/exprs/get_item.rs +++ b/vortex-array/src/expr/exprs/get_item.rs @@ -18,7 +18,6 @@ use vortex_session::VortexSession; use crate::ArrayRef; use crate::arrays::StructArray; use crate::builtins::ExprBuiltins; -use crate::compute::mask; use crate::expr::Arity; use crate::expr::ChildName; use crate::expr::EmptyOptions; @@ -116,7 +115,7 @@ impl VTable for GetItem { match input.dtype().nullability() { Nullability::NonNullable => Ok(field), - Nullability::Nullable => mask(&field, &input.validity_mask()?.not()), + Nullability::Nullable => field.mask(&input.validity_mask()?.not()), }? .execute(args.ctx) } diff --git a/vortex-array/src/expr/exprs/mask/kernel.rs b/vortex-array/src/expr/exprs/mask/kernel.rs new file mode 100644 index 00000000000..7b21ad5085a --- /dev/null +++ b/vortex-array/src/expr/exprs/mask/kernel.rs @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexResult; +use vortex_error::vortex_err; + +use crate::ArrayRef; +use crate::ExecutionCtx; +use crate::arrays::BoolVTable; +use crate::arrays::ExactScalarFn; +use crate::arrays::ScalarFnArrayView; +use crate::expr::Mask as MaskExpr; +use crate::kernel::ExecuteParentKernel; +use crate::optimizer::rules::ArrayParentReduceRule; +use crate::validity::Validity; +use crate::vtable::VTable; + +/// Mask an array without reading buffers. +/// +/// This trait is for mask implementations that can operate purely on array metadata and +/// structure without needing to read or execute on the underlying buffers. Implementations +/// should return `None` if masking requires buffer access. +/// +/// The validity represents the mask to intersect with the array's existing validity +/// (true=valid/keep, false=invalid/null-out). +/// +/// # Preconditions +/// +/// The validity is guaranteed to have the same length as the array and is guaranteed to +/// be a `Validity::Array` variant (i.e., neither `AllValid`, `AllInvalid`, nor `NonNullable`). +pub trait MaskReduce: VTable { + fn mask(array: &Self::Array, validity: &Validity) -> VortexResult>; +} + +/// Mask an array, potentially reading buffers. +/// +/// Unlike [`MaskReduce`], this trait is for mask implementations that may need to read +/// and execute on the underlying buffers to produce the masked result. +/// +/// # Preconditions +/// +/// The validity is guaranteed to have the same length as the array and is guaranteed to +/// be a `Validity::Array` variant (i.e., neither `AllValid`, `AllInvalid`, nor `NonNullable`). +pub trait MaskKernel: VTable { + fn mask( + array: &Self::Array, + validity: &Validity, + ctx: &mut ExecutionCtx, + ) -> VortexResult>; +} + +/// Adaptor that wraps a [`MaskReduce`] impl as an [`ArrayParentReduceRule`]. +#[derive(Default, Debug)] +pub struct MaskReduceAdaptor(pub V); + +impl ArrayParentReduceRule for MaskReduceAdaptor +where + V: MaskReduce, +{ + type Parent = ExactScalarFn; + + fn reduce_parent( + &self, + array: &V::Array, + parent: ScalarFnArrayView<'_, MaskExpr>, + child_idx: usize, + ) -> VortexResult> { + // Only reduce the input child (index 0), not the mask child (index 1). + if child_idx != 0 { + return Ok(None); + } + // The mask child (child 1) is a non-nullable BoolArray where true=keep. + // If it's not yet a BoolArray, we can't reduce without execution. + let mask_child = parent + .nth_child(1) + .ok_or_else(|| vortex_err!("Mask expression must have 2 children"))?; + let Some(mask_bool) = mask_child.as_opt::() else { + return Ok(None); + }; + let validity = Validity::Array(mask_bool.to_array()); + ::mask(array, &validity) + } +} + +/// Adaptor that wraps a [`MaskKernel`] impl as an [`ExecuteParentKernel`]. +#[derive(Default, Debug)] +pub struct MaskExecuteAdaptor(pub V); + +impl ExecuteParentKernel for MaskExecuteAdaptor +where + V: MaskKernel, +{ + type Parent = ExactScalarFn; + + fn execute_parent( + &self, + array: &V::Array, + parent: ScalarFnArrayView<'_, MaskExpr>, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + // Only execute the input child (index 0), not the mask child (index 1). + if child_idx != 0 { + return Ok(None); + } + let mask_child = parent + .nth_child(1) + .ok_or_else(|| vortex_err!("Mask expression must have 2 children"))?; + let validity = Validity::Array(mask_child); + ::mask(array, &validity, ctx) + } +} diff --git a/vortex-array/src/expr/exprs/mask.rs b/vortex-array/src/expr/exprs/mask/mod.rs similarity index 69% rename from vortex-array/src/expr/exprs/mask.rs rename to vortex-array/src/expr/exprs/mask/mod.rs index d93618829d1..882da1ec736 100644 --- a/vortex-array/src/expr/exprs/mask.rs +++ b/vortex-array/src/expr/exprs/mask/mod.rs @@ -1,9 +1,10 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +mod kernel; use std::fmt::Formatter; -use std::ops::Not; +pub use kernel::*; use vortex_dtype::DType; use vortex_dtype::Nullability; use vortex_error::VortexExpect; @@ -14,8 +15,13 @@ use vortex_scalar::Scalar; use vortex_session::VortexSession; use crate::ArrayRef; +use crate::Canonical; +use crate::IntoArray; use crate::arrays::BoolArray; -use crate::compute; +use crate::arrays::ConstantArray; +use crate::arrays::ConstantVTable; +use crate::arrays::mask_validity_canonical; +use crate::builtins::ArrayBuiltins; use crate::expr::Arity; use crate::expr::ChildName; use crate::expr::EmptyOptions; @@ -94,9 +100,11 @@ impl VTable for Mask { .try_into() .map_err(|_| vortex_err!("Wrong arg count"))?; - let mask_bool = mask_array.execute::(args.ctx)?; - let inverted = mask_bool.to_bit_buffer().not(); - compute::mask(&input, &vortex_mask::Mask::from(inverted))?.execute(args.ctx) + if let Some(result) = execute_constant(&input, &mask_array)? { + return Ok(result); + } + + execute_canonical(input, mask_array, args.ctx) } fn simplify( @@ -136,6 +144,51 @@ impl VTable for Mask { } } +/// Try to handle masking when at least one of the input or mask is a constant array. +/// +/// Returns `Ok(Some(result))` if the constant case was handled, `Ok(None)` if not. +fn execute_constant(input: &ArrayRef, mask_array: &ArrayRef) -> VortexResult> { + let len = input.len(); + + // Constant mask: avoid materializing the bool array entirely. + if let Some(constant_mask) = mask_array.as_opt::() { + let mask_value = constant_mask.scalar().as_bool().value().unwrap_or(false); + return if mask_value { + // Mask is all true (keep everything), just cast input to nullable. + input.cast(input.dtype().as_nullable()).map(Some) + } else { + // Mask is all false (mask everything out), return all nulls. + Ok(Some( + ConstantArray::new(Scalar::null(input.dtype().as_nullable()), len).into_array(), + )) + }; + } + + // Constant null input: masking changes nothing, result is still all-null. + if let Some(constant_input) = input.as_opt::() + && constant_input.scalar().is_null() + { + return Ok(Some( + ConstantArray::new(Scalar::null(input.dtype().as_nullable()), len).into_array(), + )); + } + + Ok(None) +} + +/// Execute the mask by materializing both inputs to their canonical forms. +fn execute_canonical( + input: ArrayRef, + mask_array: ArrayRef, + ctx: &mut crate::executor::ExecutionCtx, +) -> VortexResult { + let mask_bool = mask_array.execute::(ctx)?; + let validity_mask = vortex_mask::Mask::from(mask_bool.to_bit_buffer()); + + let canonical = input.execute::(ctx)?; + Ok(mask_validity_canonical(canonical, &validity_mask, ctx)?.into_array()) +} + /// Creates a mask expression that applies the given boolean mask to the input array. pub fn mask(array: Expression, mask: Expression) -> Expression { Mask.new_expr(EmptyOptions, [array, mask]) diff --git a/vortex-duckdb/src/exporter/struct_.rs b/vortex-duckdb/src/exporter/struct_.rs index 140a4ded143..03f67296153 100644 --- a/vortex-duckdb/src/exporter/struct_.rs +++ b/vortex-duckdb/src/exporter/struct_.rs @@ -8,7 +8,6 @@ use vortex::array::IntoArray; use vortex::array::arrays::StructArray; use vortex::array::arrays::StructArrayParts; use vortex::array::optimizer::ArrayOptimizer; -use vortex::compute::mask; use vortex::error::VortexResult; use vortex::mask::Mask; @@ -50,11 +49,7 @@ pub(crate) fn new_exporter( .map(|child| { if matches!(validity, Mask::Values(_)) { // TODO(joe): use new mask. - new_array_exporter( - mask(child, &validity.clone().not())?.optimize()?, - cache, - ctx, - ) + new_array_exporter(child.mask(&validity.clone().not())?.optimize()?, cache, ctx) } else { new_array_exporter(child.clone().into_array(), cache, ctx) } diff --git a/vortex-layout/src/layouts/struct_/reader.rs b/vortex-layout/src/layouts/struct_/reader.rs index acc9f07f1b8..770bfcc9400 100644 --- a/vortex-layout/src/layouts/struct_/reader.rs +++ b/vortex-layout/src/layouts/struct_/reader.rs @@ -351,7 +351,7 @@ impl LayoutReader for StructReader { let masked_fields: Vec = struct_array .unmasked_fields() .iter() - .map(|a| vortex_array::compute::mask(a.as_ref(), &mask)) + .map(|a| a.mask(&mask)) .try_collect()?; Ok(StructArray::try_new( @@ -364,7 +364,7 @@ impl LayoutReader for StructReader { } else { // If the root expression was not a pack or merge, e.g. if it's something like // a get_item, then we apply the validity directly to the result - vortex_array::compute::mask(array.as_ref(), &mask) + array.mask(&mask) } } else { projected.await diff --git a/vortex-layout/src/reader.rs b/vortex-layout/src/reader.rs index e26e33372e5..2936ec188bd 100644 --- a/vortex-layout/src/reader.rs +++ b/vortex-layout/src/reader.rs @@ -98,7 +98,7 @@ impl ArrayFutureExt for ArrayFuture { fn masked(self, mask: MaskFuture) -> Self { Box::pin(async move { let (array, mask) = try_join!(self, mask)?; - vortex_array::compute::mask(array.as_ref(), &mask) + array.mask(&mask) }) } } diff --git a/vortex-test/e2e/src/lib.rs b/vortex-test/e2e/src/lib.rs index 2e8d3c77be0..765d4de32bd 100644 --- a/vortex-test/e2e/src/lib.rs +++ b/vortex-test/e2e/src/lib.rs @@ -28,7 +28,7 @@ mod tests { #[cfg(feature = "unstable_encodings")] const EXPECTED_SIZE: usize = 216188; #[cfg(not(feature = "unstable_encodings"))] - const EXPECTED_SIZE: usize = 216156; + const EXPECTED_SIZE: usize = 216188; let futures: Vec<_> = (0..5) .map(|_| { let array = array.clone();