Skip to content
Merged
12 changes: 6 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 5 additions & 3 deletions vortex-array/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -13086,15 +13086,17 @@ pub fn V::try_match<'a>(array: &'a vortex_array::ArrayRef) -> core::option::Opti

pub mod vortex_array::normalize

pub enum vortex_array::normalize::Operation
pub enum vortex_array::normalize::Operation<'a>

pub vortex_array::normalize::Operation::Error

pub vortex_array::normalize::Operation::Execute(&'a mut vortex_array::ExecutionCtx)

pub struct vortex_array::normalize::NormalizeOptions<'a>

pub vortex_array::normalize::NormalizeOptions::allowed: &'a vortex_array::session::ArrayRegistry
pub vortex_array::normalize::NormalizeOptions::allowed: &'a vortex_utils::aliases::hash_set::HashSet<vortex_session::registry::Id>

pub vortex_array::normalize::NormalizeOptions::operation: vortex_array::normalize::Operation
pub vortex_array::normalize::NormalizeOptions::operation: vortex_array::normalize::Operation<'a>

pub mod vortex_array::optimizer

Expand Down
207 changes: 195 additions & 12 deletions vortex-array/src/normalize.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use itertools::Itertools;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_session::registry::Id;
use vortex_utils::aliases::hash_set::HashSet;

use crate::ArrayRef;
use crate::session::ArrayRegistry;
use crate::ExecutionCtx;

/// Options for normalizing an array.
pub struct NormalizeOptions<'a> {
/// The set of allowed array encodings (in addition to the canonical ones) that are permitted
/// in the normalized array.
pub allowed: &'a ArrayRegistry,
pub allowed: &'a HashSet<Id>,
/// The operation to perform when a non-allowed encoding is encountered.
pub operation: Operation,
pub operation: Operation<'a>,
}

/// The operation to perform when a non-allowed encoding is encountered.
pub enum Operation {
pub enum Operation<'a> {
Error,
// TODO(joe): add into canonical variant
Execute(&'a mut ExecutionCtx),
}

impl ArrayRef {
Expand All @@ -30,14 +30,18 @@ impl ArrayRef {
/// This operation performs a recursive traversal of the array. Any non-allowed encoding is
/// normalized per the configured operation.
pub fn normalize(self, options: &mut NormalizeOptions) -> VortexResult<ArrayRef> {
let array_ids = options.allowed.ids().collect_vec();
self.normalize_with_error(&array_ids)?;
// Note this takes ownership so we can at a later date remove non-allowed encodings.
Ok(self)
match &mut options.operation {
Operation::Error => {
self.normalize_with_error(options.allowed)?;
// Note this takes ownership so we can at a later date remove non-allowed encodings.
Ok(self)
}
Operation::Execute(ctx) => self.normalize_with_execution(options.allowed, ctx),
}
}

fn normalize_with_error(&self, allowed: &[Id]) -> VortexResult<()> {
if !allowed.contains(&self.encoding_id()) {
fn normalize_with_error(&self, allowed: &HashSet<Id>) -> VortexResult<()> {
if !self.is_allowed_encoding(allowed) {
vortex_bail!(AssertionFailed: "normalize forbids encoding ({})", self.encoding_id())
}

Expand All @@ -46,4 +50,183 @@ impl ArrayRef {
}
Ok(())
}

fn normalize_with_execution(
self,
allowed: &HashSet<Id>,
ctx: &mut ExecutionCtx,
) -> VortexResult<ArrayRef> {
let mut normalized = self;

// Top-first execute the array tree while we hit non-allowed encodings.
while !normalized.is_allowed_encoding(allowed) {
normalized = normalized.execute(ctx)?;
}

// Now we've normalized the root, we need to ensure the children are normalized also.
let slots = normalized.slots();
let mut normalized_slots = Vec::with_capacity(slots.len());
let mut any_slot_changed = false;

for slot in slots {
match slot {
Some(child) => {
let normalized_child = child.clone().normalize(&mut NormalizeOptions {
allowed,
operation: Operation::Execute(ctx),
})?;
any_slot_changed |= !ArrayRef::ptr_eq(child, &normalized_child);
normalized_slots.push(Some(normalized_child));
}
None => normalized_slots.push(None),
}
}

if any_slot_changed {
normalized = normalized.with_slots(normalized_slots)?;
}

Ok(normalized)
}

fn is_allowed_encoding(&self, allowed: &HashSet<Id>) -> bool {
allowed.contains(&self.encoding_id()) || self.is_canonical()
}
}

#[cfg(test)]
mod tests {
use vortex_error::VortexResult;
use vortex_session::VortexSession;
use vortex_utils::aliases::hash_set::HashSet;

use super::NormalizeOptions;
use super::Operation;
use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::arrays::Dict;
use crate::arrays::DictArray;
use crate::arrays::Primitive;
use crate::arrays::PrimitiveArray;
use crate::arrays::Slice;
use crate::arrays::SliceArray;
use crate::arrays::StructArray;
use crate::assert_arrays_eq;
use crate::validity::Validity;

#[test]
fn normalize_with_execution_keeps_parent_when_children_are_unchanged() -> VortexResult<()> {
let field = PrimitiveArray::from_iter(0i32..4).into_array();
let array = StructArray::try_new(
["field"].into(),
vec![field.clone()],
field.len(),
Validity::NonNullable,
)?
.into_array();
let allowed = HashSet::from_iter([array.encoding_id(), field.encoding_id()]);
let mut ctx = ExecutionCtx::new(VortexSession::empty());

let normalized = array.clone().normalize(&mut NormalizeOptions {
allowed: &allowed,
operation: Operation::Execute(&mut ctx),
})?;

assert!(ArrayRef::ptr_eq(&array, &normalized));
Ok(())
}

#[test]
fn normalize_with_error_allows_canonical_arrays() -> VortexResult<()> {
let field = PrimitiveArray::from_iter(0i32..4).into_array();
let array = StructArray::try_new(
["field"].into(),
vec![field.clone()],
field.len(),
Validity::NonNullable,
)?
.into_array();
let allowed = HashSet::default();

let normalized = array.clone().normalize(&mut NormalizeOptions {
allowed: &allowed,
operation: Operation::Error,
})?;

assert!(ArrayRef::ptr_eq(&array, &normalized));
Ok(())
}

#[test]
fn normalize_with_execution_rebuilds_parent_when_a_child_changes() -> VortexResult<()> {
let unchanged = PrimitiveArray::from_iter(0i32..4).into_array();
let sliced =
SliceArray::new(PrimitiveArray::from_iter(10i32..20).into_array(), 2..6).into_array();
let array = StructArray::try_new(
["lhs", "rhs"].into(),
vec![unchanged.clone(), sliced],
unchanged.len(),
Validity::NonNullable,
)?
.into_array();
let allowed = HashSet::from_iter([array.encoding_id(), unchanged.encoding_id()]);
let mut ctx = ExecutionCtx::new(VortexSession::empty());

let normalized = array.clone().normalize(&mut NormalizeOptions {
allowed: &allowed,
operation: Operation::Execute(&mut ctx),
})?;

assert!(!ArrayRef::ptr_eq(&array, &normalized));

let original_children = array.children();
let normalized_children = normalized.children();
assert!(ArrayRef::ptr_eq(
&original_children[0],
&normalized_children[0]
));
assert!(!ArrayRef::ptr_eq(
&original_children[1],
&normalized_children[1]
));
assert_arrays_eq!(normalized_children[1], PrimitiveArray::from_iter(12i32..16));

Ok(())
}

#[test]
fn normalize_slice_of_dict_returns_dict() -> VortexResult<()> {
let codes = PrimitiveArray::from_iter(vec![0u32, 1, 0, 1, 2]).into_array();
let values = PrimitiveArray::from_iter(vec![10i32, 20, 30]).into_array();
let dict = DictArray::try_new(codes, values)?.into_array();

// Slice the dict array to get a SliceArray wrapping a DictArray.
let sliced = SliceArray::new(dict, 1..4).into_array();
assert_eq!(sliced.encoding_id(), Slice::ID);

let allowed = HashSet::from_iter([Dict::ID, Primitive::ID]);
let mut ctx = ExecutionCtx::new(VortexSession::empty());

println!("sliced {}", sliced.display_tree());

let normalized = sliced.normalize(&mut NormalizeOptions {
allowed: &allowed,
operation: Operation::Execute(&mut ctx),
})?;

println!("after {}", normalized.display_tree());

// The normalized result should be a DictArray, not a SliceArray.
assert_eq!(normalized.encoding_id(), Dict::ID);
assert_eq!(normalized.len(), 3);

// Verify the data: codes [1,0,1] -> values [20, 10, 20]
assert_arrays_eq!(
normalized.to_canonical()?,
PrimitiveArray::from_iter(vec![20i32, 10, 20])
);

Ok(())
}
}
7 changes: 4 additions & 3 deletions vortex-cuda/src/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use futures::FutureExt;
use futures::StreamExt;
use futures::future::BoxFuture;
use vortex::array::ArrayContext;
use vortex::array::ArrayId;
use vortex::array::ArrayRef;
use vortex::array::DeserializeMetadata;
use vortex::array::MaskFuture;
Expand All @@ -28,7 +29,6 @@ use vortex::array::normalize::NormalizeOptions;
use vortex::array::normalize::Operation;
use vortex::array::serde::SerializeOptions;
use vortex::array::serde::SerializedArray;
use vortex::array::session::ArrayRegistry;
use vortex::array::stats::StatsSetRef;
use vortex::buffer::BufferString;
use vortex::buffer::ByteBuffer;
Expand Down Expand Up @@ -63,6 +63,7 @@ use vortex::scalar::upper_bound;
use vortex::session::VortexSession;
use vortex::session::registry::ReadContext;
use vortex::utils::aliases::hash_map::HashMap;
use vortex::utils::aliases::hash_set::HashSet;

/// A buffer inlined into layout metadata for host-side access.
#[derive(Clone, prost::Message)]
Expand Down Expand Up @@ -390,7 +391,7 @@ pub struct CudaFlatLayoutStrategy {
/// Maximum length of variable length statistics.
pub max_variable_length_statistics_size: usize,
/// Optional set of allowed array encodings for normalization.
pub allowed_encodings: Option<ArrayRegistry>,
pub allowed_encodings: Option<HashSet<ArrayId>>,
}

impl Default for CudaFlatLayoutStrategy {
Expand All @@ -414,7 +415,7 @@ impl CudaFlatLayoutStrategy {
self
}

pub fn with_allow_encodings(mut self, allow_encodings: ArrayRegistry) -> Self {
pub fn with_allow_encodings(mut self, allow_encodings: HashSet<ArrayId>) -> Self {
self.allowed_encodings = Some(allow_encodings);
self
}
Expand Down
4 changes: 2 additions & 2 deletions vortex-file/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ impl vortex_file::WriteStrategyBuilder

pub fn vortex_file::WriteStrategyBuilder::build(self) -> alloc::sync::Arc<dyn vortex_layout::strategy::LayoutStrategy>

pub fn vortex_file::WriteStrategyBuilder::with_allow_encodings(self, allow_encodings: vortex_array::session::ArrayRegistry) -> Self
pub fn vortex_file::WriteStrategyBuilder::with_allow_encodings(self, allow_encodings: vortex_utils::aliases::hash_set::HashSet<vortex_array::array::ArrayId>) -> Self

pub fn vortex_file::WriteStrategyBuilder::with_btrblocks_builder(self, builder: vortex_btrblocks::builder::BtrBlocksCompressorBuilder) -> Self

Expand Down Expand Up @@ -396,7 +396,7 @@ pub const vortex_file::VERSION: u16

pub const vortex_file::VORTEX_FILE_EXTENSION: &str

pub static vortex_file::ALLOWED_ENCODINGS: std::sync::lazy_lock::LazyLock<vortex_array::session::ArrayRegistry>
pub static vortex_file::ALLOWED_ENCODINGS: std::sync::lazy_lock::LazyLock<vortex_utils::aliases::hash_set::HashSet<vortex_array::array::ArrayId>>

pub trait vortex_file::OpenOptionsSessionExt: vortex_array::session::ArraySessionExt + vortex_layout::session::LayoutSessionExt + vortex_io::session::RuntimeSessionExt

Expand Down
Loading
Loading