Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 11 additions & 24 deletions encodings/alp/src/alp/compute/mask.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,25 @@
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use vortex_array::ArrayRef;
use vortex_array::compute::MaskKernel;
use vortex_array::compute::MaskKernelAdapter;
use vortex_array::compute::mask;
use vortex_array::register_kernel;
use vortex_array::IntoArray;
use vortex_array::arrays::MaskedArray;
use vortex_array::compute::MaskReduce;
use vortex_array::validity::Validity;
use vortex_error::VortexResult;
use vortex_mask::Mask;

use crate::ALPArray;
use crate::ALPVTable;

impl MaskKernel for ALPVTable {
fn mask(&self, array: &ALPArray, filter_mask: &Mask) -> VortexResult<ArrayRef> {
let masked_encoded = mask(array.encoded(), filter_mask)?;
let masked_patches = array
.patches()
.map(|p| p.mask(filter_mask))
.transpose()?
.flatten()
.map(|patches| {
patches.cast_values(
&array
.dtype()
.with_nullability(masked_encoded.dtype().nullability()),
)
})
.transpose()?;
Ok(ALPArray::new(masked_encoded, array.exponents(), masked_patches).to_array())
impl MaskReduce for ALPVTable {
fn mask(array: &ALPArray, validity: &Validity) -> VortexResult<Option<ArrayRef>> {
let masked_encoded =
MaskedArray::try_new(array.encoded().clone(), validity.clone())?.into_array();
Ok(Some(
ALPArray::new(masked_encoded, array.exponents(), array.patches().cloned()).to_array(),
))
}
}

register_kernel!(MaskKernelAdapter(ALPVTable).lift());

#[cfg(test)]
mod test {
use rstest::rstest;
Expand Down
7 changes: 5 additions & 2 deletions encodings/alp/src/alp/rules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use vortex_array::arrays::FilterExecuteAdaptor;
use vortex_array::arrays::SliceExecuteAdaptor;
use vortex_array::arrays::TakeExecuteAdaptor;
use vortex_array::compute::CastReduceAdaptor;
use vortex_array::compute::MaskReduceAdaptor;
use vortex_array::kernel::ParentKernelSet;
use vortex_array::optimizer::rules::ParentRuleSet;

Expand All @@ -16,5 +17,7 @@ pub(super) const PARENT_KERNELS: ParentKernelSet<ALPVTable> = ParentKernelSet::n
ParentKernelSet::lift(&TakeExecuteAdaptor(ALPVTable)),
]);

pub(super) const RULES: ParentRuleSet<ALPVTable> =
ParentRuleSet::new(&[ParentRuleSet::lift(&CastReduceAdaptor(ALPVTable))]);
pub(super) const RULES: ParentRuleSet<ALPVTable> = ParentRuleSet::new(&[
ParentRuleSet::lift(&CastReduceAdaptor(ALPVTable)),
ParentRuleSet::lift(&MaskReduceAdaptor(ALPVTable)),
]);
36 changes: 18 additions & 18 deletions encodings/alp/src/alp_rd/compute/mask.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,32 @@

use vortex_array::ArrayRef;
use vortex_array::IntoArray;
use vortex_array::compute::MaskKernel;
use vortex_array::compute::MaskKernelAdapter;
use vortex_array::compute::mask;
use vortex_array::register_kernel;
use vortex_array::arrays::MaskedArray;
use vortex_array::compute::MaskReduce;
use vortex_array::validity::Validity;
use vortex_error::VortexResult;
use vortex_mask::Mask;

use crate::ALPRDArray;
use crate::ALPRDVTable;

impl MaskKernel for ALPRDVTable {
fn mask(&self, array: &ALPRDArray, filter_mask: &Mask) -> VortexResult<ArrayRef> {
Ok(ALPRDArray::try_new(
array.dtype().as_nullable(),
mask(array.left_parts(), filter_mask)?,
array.left_parts_dictionary().clone(),
array.right_parts().clone(),
array.right_bit_width(),
array.left_parts_patches().cloned(),
)?
.into_array())
impl MaskReduce for ALPRDVTable {
fn mask(array: &ALPRDArray, validity: &Validity) -> VortexResult<Option<ArrayRef>> {
let masked_left_parts =
MaskedArray::try_new(array.left_parts().clone(), validity.clone())?.into_array();
Ok(Some(
ALPRDArray::try_new(
array.dtype().as_nullable(),
masked_left_parts,
array.left_parts_dictionary().clone(),
array.right_parts().clone(),
array.right_bit_width(),
array.left_parts_patches().cloned(),
)?
.into_array(),
))
}
}

register_kernel!(MaskKernelAdapter(ALPRDVTable).lift());

#[cfg(test)]
mod tests {
use rstest::rstest;
Expand Down
7 changes: 5 additions & 2 deletions encodings/alp/src/alp_rd/rules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use vortex_array::compute::CastReduceAdaptor;
use vortex_array::compute::MaskReduceAdaptor;
use vortex_array::optimizer::rules::ParentRuleSet;

use crate::alp_rd::ALPRDVTable;

pub(crate) static RULES: ParentRuleSet<ALPRDVTable> =
ParentRuleSet::new(&[ParentRuleSet::lift(&CastReduceAdaptor(ALPRDVTable))]);
pub(crate) static RULES: ParentRuleSet<ALPRDVTable> = ParentRuleSet::new(&[
ParentRuleSet::lift(&CastReduceAdaptor(ALPRDVTable)),
ParentRuleSet::lift(&MaskReduceAdaptor(ALPRDVTable)),
]);
20 changes: 11 additions & 9 deletions encodings/bytebool/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,12 @@ use vortex_array::IntoArray;
use vortex_array::ToCanonical;
use vortex_array::arrays::TakeExecute;
use vortex_array::compute::CastReduce;
use vortex_array::compute::MaskKernel;
use vortex_array::compute::MaskKernelAdapter;
use vortex_array::register_kernel;
use vortex_array::compute::MaskReduce;
use vortex_array::validity::Validity;
use vortex_array::vtable::ValidityHelper;
use vortex_dtype::DType;
use vortex_dtype::match_each_integer_ptype;
use vortex_error::VortexResult;
use vortex_mask::Mask;

use super::ByteBoolArray;
use super::ByteBoolVTable;
Expand Down Expand Up @@ -44,14 +42,18 @@ impl CastReduce for ByteBoolVTable {
}
}

impl MaskKernel for ByteBoolVTable {
fn mask(&self, array: &ByteBoolArray, mask: &Mask) -> VortexResult<ArrayRef> {
Ok(ByteBoolArray::new(array.buffer().clone(), array.validity().mask(mask)).into_array())
impl MaskReduce for ByteBoolVTable {
fn mask(array: &ByteBoolArray, validity: &Validity) -> VortexResult<Option<ArrayRef>> {
Ok(Some(
ByteBoolArray::new(
array.buffer().clone(),
array.validity().clone().and(validity.clone()),
)
.into_array(),
))
}
}

register_kernel!(MaskKernelAdapter(ByteBoolVTable).lift());

impl TakeExecute for ByteBoolVTable {
fn take(
array: &ByteBoolArray,
Expand Down
4 changes: 3 additions & 1 deletion encodings/bytebool/src/rules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

use vortex_array::arrays::SliceReduceAdaptor;
use vortex_array::compute::CastReduceAdaptor;
use vortex_array::compute::MaskReduceAdaptor;
use vortex_array::optimizer::rules::ParentRuleSet;

use crate::ByteBoolVTable;

pub(crate) static RULES: ParentRuleSet<ByteBoolVTable> = ParentRuleSet::new(&[
ParentRuleSet::lift(&SliceReduceAdaptor(ByteBoolVTable)),
ParentRuleSet::lift(&CastReduceAdaptor(ByteBoolVTable)),
ParentRuleSet::lift(&MaskReduceAdaptor(ByteBoolVTable)),
ParentRuleSet::lift(&SliceReduceAdaptor(ByteBoolVTable)),
]);
47 changes: 17 additions & 30 deletions encodings/datetime-parts/src/compute/mask.rs
Original file line number Diff line number Diff line change
@@ -1,41 +1,28 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use vortex_array::Array;
use vortex_array::ArrayRef;
use vortex_array::compute::MaskKernel;
use vortex_array::compute::MaskKernelAdapter;
use vortex_array::compute::mask;
use vortex_array::register_kernel;
use vortex_array::IntoArray;
use vortex_array::arrays::MaskedArray;
use vortex_array::compute::MaskReduce;
use vortex_array::validity::Validity;
use vortex_error::VortexResult;
use vortex_mask::Mask;

use crate::DateTimePartsArray;
use crate::DateTimePartsVTable;

impl MaskKernel for DateTimePartsVTable {
fn mask(&self, array: &DateTimePartsArray, mask_array: &Mask) -> VortexResult<ArrayRef> {
// DateTimePartsArray has specific constraints:
// - days nullability must match the dtype
// - seconds and subseconds must always be non-nullable
//
// When masking, we can't make seconds/subseconds nullable.
// Instead, we'll keep the same values but the overall array becomes nullable
// through the days component.

let masked_days = mask(array.days(), mask_array)?;
assert!(masked_days.dtype().is_nullable());

// Keep seconds and subseconds unchanged since they must remain non-nullable
let seconds = array.seconds().clone();
let subseconds = array.subseconds().clone();

// Update the dtype to reflect the new nullability of days
let new_dtype = array.dtype().as_nullable();

DateTimePartsArray::try_new(new_dtype, masked_days, seconds, subseconds)
.map(|a| a.to_array())
impl MaskReduce for DateTimePartsVTable {
fn mask(array: &DateTimePartsArray, validity: &Validity) -> VortexResult<Option<ArrayRef>> {
let masked_days =
MaskedArray::try_new(array.days().clone(), validity.clone())?.into_array();
Ok(Some(
DateTimePartsArray::try_new(
array.dtype().as_nullable(),
masked_days,
array.seconds().clone(),
array.subseconds().clone(),
)?
.into_array(),
))
}
}

register_kernel!(MaskKernelAdapter(DateTimePartsVTable).lift());
2 changes: 2 additions & 0 deletions encodings/datetime-parts/src/compute/rules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use vortex_array::arrays::ScalarFnArray;
use vortex_array::arrays::SliceReduceAdaptor;
use vortex_array::builtins::ArrayBuiltins;
use vortex_array::compute::CastReduceAdaptor;
use vortex_array::compute::MaskReduceAdaptor;
use vortex_array::expr::Between;
use vortex_array::expr::Binary;
use vortex_array::optimizer::ArrayOptimizer;
Expand All @@ -33,6 +34,7 @@ pub(crate) const PARENT_RULES: ParentRuleSet<DateTimePartsVTable> = ParentRuleSe
ParentRuleSet::lift(&DTPComparisonPushDownRule),
ParentRuleSet::lift(&CastReduceAdaptor(DateTimePartsVTable)),
ParentRuleSet::lift(&FilterReduceAdaptor(DateTimePartsVTable)),
ParentRuleSet::lift(&MaskReduceAdaptor(DateTimePartsVTable)),
ParentRuleSet::lift(&SliceReduceAdaptor(DateTimePartsVTable)),
]);

Expand Down
21 changes: 10 additions & 11 deletions encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mask.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,20 @@
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use vortex_array::ArrayRef;
use vortex_array::compute::MaskKernel;
use vortex_array::compute::MaskKernelAdapter;
use vortex_array::compute::mask;
use vortex_array::register_kernel;
use vortex_array::IntoArray;
use vortex_array::arrays::MaskedArray;
use vortex_array::compute::MaskReduce;
use vortex_array::validity::Validity;
use vortex_error::VortexResult;
use vortex_mask::Mask;

use crate::DecimalBytePartsArray;
use crate::DecimalBytePartsVTable;

impl MaskKernel for DecimalBytePartsVTable {
fn mask(&self, array: &DecimalBytePartsArray, mask_array: &Mask) -> VortexResult<ArrayRef> {
let masked = mask(&array.msp, mask_array)?;
DecimalBytePartsArray::try_new(masked, *array.decimal_dtype()).map(|a| a.to_array())
impl MaskReduce for DecimalBytePartsVTable {
fn mask(array: &DecimalBytePartsArray, validity: &Validity) -> VortexResult<Option<ArrayRef>> {
let masked_msp = MaskedArray::try_new(array.msp.clone(), validity.clone())?.into_array();
Ok(Some(
DecimalBytePartsArray::try_new(masked_msp, *array.decimal_dtype())?.into_array(),
))
}
}

register_kernel!(MaskKernelAdapter(DecimalBytePartsVTable).lift());
2 changes: 2 additions & 0 deletions encodings/decimal-byte-parts/src/decimal_byte_parts/rules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use vortex_array::arrays::FilterReduceAdaptor;
use vortex_array::arrays::FilterVTable;
use vortex_array::arrays::SliceReduceAdaptor;
use vortex_array::compute::CastReduceAdaptor;
use vortex_array::compute::MaskReduceAdaptor;
use vortex_array::optimizer::rules::ArrayParentReduceRule;
use vortex_array::optimizer::rules::ParentRuleSet;
use vortex_error::VortexResult;
Expand All @@ -20,6 +21,7 @@ pub(super) const PARENT_RULES: ParentRuleSet<DecimalBytePartsVTable> = ParentRul
ParentRuleSet::lift(&DecimalBytePartsFilterPushDownRule),
ParentRuleSet::lift(&CastReduceAdaptor(DecimalBytePartsVTable)),
ParentRuleSet::lift(&FilterReduceAdaptor(DecimalBytePartsVTable)),
ParentRuleSet::lift(&MaskReduceAdaptor(DecimalBytePartsVTable)),
ParentRuleSet::lift(&SliceReduceAdaptor(DecimalBytePartsVTable)),
]);

Expand Down
18 changes: 8 additions & 10 deletions encodings/zigzag/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ use vortex_array::ArrayRef;
use vortex_array::ExecutionCtx;
use vortex_array::IntoArray;
use vortex_array::arrays::FilterReduce;
use vortex_array::arrays::MaskedArray;
use vortex_array::arrays::TakeExecute;
use vortex_array::compute::MaskKernel;
use vortex_array::compute::MaskKernelAdapter;
use vortex_array::compute::mask;
use vortex_array::register_kernel;
use vortex_array::compute::MaskReduce;
use vortex_array::validity::Validity;
use vortex_error::VortexResult;
use vortex_mask::Mask;

Expand All @@ -37,15 +36,14 @@ impl TakeExecute for ZigZagVTable {
}
}

impl MaskKernel for ZigZagVTable {
fn mask(&self, array: &ZigZagArray, filter_mask: &Mask) -> VortexResult<ArrayRef> {
let encoded = mask(array.encoded(), filter_mask)?;
Ok(ZigZagArray::try_new(encoded)?.into_array())
impl MaskReduce for ZigZagVTable {
fn mask(array: &ZigZagArray, validity: &Validity) -> VortexResult<Option<ArrayRef>> {
let masked_encoded =
MaskedArray::try_new(array.encoded().clone(), validity.clone())?.into_array();
Ok(Some(ZigZagArray::try_new(masked_encoded)?.into_array()))
}
}

register_kernel!(MaskKernelAdapter(ZigZagVTable).lift());

pub(crate) trait ZigZagEncoded {
type Int: zigzag::ZigZag;
}
Expand Down
2 changes: 2 additions & 0 deletions encodings/zigzag/src/rules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
use vortex_array::arrays::FilterReduceAdaptor;
use vortex_array::arrays::SliceReduceAdaptor;
use vortex_array::compute::CastReduceAdaptor;
use vortex_array::compute::MaskReduceAdaptor;
use vortex_array::optimizer::rules::ParentRuleSet;

use crate::ZigZagVTable;

pub(crate) static RULES: ParentRuleSet<ZigZagVTable> = ParentRuleSet::new(&[
ParentRuleSet::lift(&CastReduceAdaptor(ZigZagVTable)),
ParentRuleSet::lift(&FilterReduceAdaptor(ZigZagVTable)),
ParentRuleSet::lift(&MaskReduceAdaptor(ZigZagVTable)),
ParentRuleSet::lift(&SliceReduceAdaptor(ZigZagVTable)),
]);
5 changes: 3 additions & 2 deletions fuzz/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,6 @@ pub fn run_fuzz_action(fuzz_action: FuzzArrayAction) -> crate::error::VortexFuzz
use vortex_array::arrays::ConstantArray;
use vortex_array::builtins::ArrayBuiltins;
use vortex_array::compute::compare;
use vortex_array::compute::mask;
use vortex_array::compute::min_max;
use vortex_array::compute::sum;
let FuzzArrayAction { array, actions } = fuzz_action;
Expand Down Expand Up @@ -642,7 +641,9 @@ pub fn run_fuzz_action(fuzz_action: FuzzArrayAction) -> crate::error::VortexFuzz
assert_array_eq(&expected.array(), &current_array, i)?;
}
Action::Mask(mask_val) => {
current_array = mask(&current_array, &mask_val)
current_array = current_array
.as_ref()
.mask(&mask_val)
.vortex_expect("mask operation should succeed in fuzz test");
assert_array_eq(&expected.array(), &current_array, i)?;
}
Expand Down
Loading
Loading