diff --git a/Cargo.lock b/Cargo.lock index d4d5236b130..0a22c17749f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4860,11 +4860,11 @@ dependencies = [ [[package]] name = "insta" -version = "1.46.3" +version = "1.47.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e82db8c87c7f1ccecb34ce0c24399b8a73081427f3c7c50a5d597925356115e4" +checksum = "7b4a6248eb93a4401ed2f37dfe8ea592d3cf05b7cf4f8efa867b6895af7e094e" dependencies = [ - "console 0.15.11", + "console 0.16.3", "once_cell", "similar", "tempfile", @@ -9325,12 +9325,12 @@ dependencies = [ [[package]] name = "terminal_size" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60b8cb979cb11c32ce1603f8137b22262a9d131aaa5c37b5678025f22b8becd0" +checksum = "230a1b821ccbd75b185820a1f1ff7b14d21da1e442e22c0863ea5f08771a8874" dependencies = [ "rustix 1.1.4", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 84b3fb9513d..16bc0216f1b 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -13086,15 +13086,17 @@ pub fn V::try_match<'a>(array: &'a vortex_array::ArrayRef) -> core::option::Opti pub mod vortex_array::normalize -pub enum vortex_array::normalize::Operation +pub enum vortex_array::normalize::Operation<'a> pub vortex_array::normalize::Operation::Error +pub vortex_array::normalize::Operation::Execute(&'a mut vortex_array::ExecutionCtx) + pub struct vortex_array::normalize::NormalizeOptions<'a> -pub vortex_array::normalize::NormalizeOptions::allowed: &'a vortex_array::session::ArrayRegistry +pub vortex_array::normalize::NormalizeOptions::allowed: &'a vortex_utils::aliases::hash_set::HashSet -pub vortex_array::normalize::NormalizeOptions::operation: vortex_array::normalize::Operation +pub vortex_array::normalize::NormalizeOptions::operation: vortex_array::normalize::Operation<'a> pub mod vortex_array::optimizer diff --git a/vortex-array/src/normalize.rs b/vortex-array/src/normalize.rs index 9a796e5485c..3235a480809 100644 --- a/vortex-array/src/normalize.rs +++ b/vortex-array/src/normalize.rs @@ -1,27 +1,27 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use itertools::Itertools; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_session::registry::Id; +use vortex_utils::aliases::hash_set::HashSet; use crate::ArrayRef; -use crate::session::ArrayRegistry; +use crate::ExecutionCtx; /// Options for normalizing an array. pub struct NormalizeOptions<'a> { /// The set of allowed array encodings (in addition to the canonical ones) that are permitted /// in the normalized array. - pub allowed: &'a ArrayRegistry, + pub allowed: &'a HashSet, /// The operation to perform when a non-allowed encoding is encountered. - pub operation: Operation, + pub operation: Operation<'a>, } /// The operation to perform when a non-allowed encoding is encountered. -pub enum Operation { +pub enum Operation<'a> { Error, - // TODO(joe): add into canonical variant + Execute(&'a mut ExecutionCtx), } impl ArrayRef { @@ -30,14 +30,18 @@ impl ArrayRef { /// This operation performs a recursive traversal of the array. Any non-allowed encoding is /// normalized per the configured operation. pub fn normalize(self, options: &mut NormalizeOptions) -> VortexResult { - let array_ids = options.allowed.ids().collect_vec(); - self.normalize_with_error(&array_ids)?; - // Note this takes ownership so we can at a later date remove non-allowed encodings. - Ok(self) + match &mut options.operation { + Operation::Error => { + self.normalize_with_error(options.allowed)?; + // Note this takes ownership so we can at a later date remove non-allowed encodings. + Ok(self) + } + Operation::Execute(ctx) => self.normalize_with_execution(options.allowed, ctx), + } } - fn normalize_with_error(&self, allowed: &[Id]) -> VortexResult<()> { - if !allowed.contains(&self.encoding_id()) { + fn normalize_with_error(&self, allowed: &HashSet) -> VortexResult<()> { + if !self.is_allowed_encoding(allowed) { vortex_bail!(AssertionFailed: "normalize forbids encoding ({})", self.encoding_id()) } @@ -46,4 +50,183 @@ impl ArrayRef { } Ok(()) } + + fn normalize_with_execution( + self, + allowed: &HashSet, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + let mut normalized = self; + + // Top-first execute the array tree while we hit non-allowed encodings. + while !normalized.is_allowed_encoding(allowed) { + normalized = normalized.execute(ctx)?; + } + + // Now we've normalized the root, we need to ensure the children are normalized also. + let slots = normalized.slots(); + let mut normalized_slots = Vec::with_capacity(slots.len()); + let mut any_slot_changed = false; + + for slot in slots { + match slot { + Some(child) => { + let normalized_child = child.clone().normalize(&mut NormalizeOptions { + allowed, + operation: Operation::Execute(ctx), + })?; + any_slot_changed |= !ArrayRef::ptr_eq(child, &normalized_child); + normalized_slots.push(Some(normalized_child)); + } + None => normalized_slots.push(None), + } + } + + if any_slot_changed { + normalized = normalized.with_slots(normalized_slots)?; + } + + Ok(normalized) + } + + fn is_allowed_encoding(&self, allowed: &HashSet) -> bool { + allowed.contains(&self.encoding_id()) || self.is_canonical() + } +} + +#[cfg(test)] +mod tests { + use vortex_error::VortexResult; + use vortex_session::VortexSession; + use vortex_utils::aliases::hash_set::HashSet; + + use super::NormalizeOptions; + use super::Operation; + use crate::ArrayRef; + use crate::ExecutionCtx; + use crate::IntoArray; + use crate::arrays::Dict; + use crate::arrays::DictArray; + use crate::arrays::Primitive; + use crate::arrays::PrimitiveArray; + use crate::arrays::Slice; + use crate::arrays::SliceArray; + use crate::arrays::StructArray; + use crate::assert_arrays_eq; + use crate::validity::Validity; + + #[test] + fn normalize_with_execution_keeps_parent_when_children_are_unchanged() -> VortexResult<()> { + let field = PrimitiveArray::from_iter(0i32..4).into_array(); + let array = StructArray::try_new( + ["field"].into(), + vec![field.clone()], + field.len(), + Validity::NonNullable, + )? + .into_array(); + let allowed = HashSet::from_iter([array.encoding_id(), field.encoding_id()]); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let normalized = array.clone().normalize(&mut NormalizeOptions { + allowed: &allowed, + operation: Operation::Execute(&mut ctx), + })?; + + assert!(ArrayRef::ptr_eq(&array, &normalized)); + Ok(()) + } + + #[test] + fn normalize_with_error_allows_canonical_arrays() -> VortexResult<()> { + let field = PrimitiveArray::from_iter(0i32..4).into_array(); + let array = StructArray::try_new( + ["field"].into(), + vec![field.clone()], + field.len(), + Validity::NonNullable, + )? + .into_array(); + let allowed = HashSet::default(); + + let normalized = array.clone().normalize(&mut NormalizeOptions { + allowed: &allowed, + operation: Operation::Error, + })?; + + assert!(ArrayRef::ptr_eq(&array, &normalized)); + Ok(()) + } + + #[test] + fn normalize_with_execution_rebuilds_parent_when_a_child_changes() -> VortexResult<()> { + let unchanged = PrimitiveArray::from_iter(0i32..4).into_array(); + let sliced = + SliceArray::new(PrimitiveArray::from_iter(10i32..20).into_array(), 2..6).into_array(); + let array = StructArray::try_new( + ["lhs", "rhs"].into(), + vec![unchanged.clone(), sliced], + unchanged.len(), + Validity::NonNullable, + )? + .into_array(); + let allowed = HashSet::from_iter([array.encoding_id(), unchanged.encoding_id()]); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + let normalized = array.clone().normalize(&mut NormalizeOptions { + allowed: &allowed, + operation: Operation::Execute(&mut ctx), + })?; + + assert!(!ArrayRef::ptr_eq(&array, &normalized)); + + let original_children = array.children(); + let normalized_children = normalized.children(); + assert!(ArrayRef::ptr_eq( + &original_children[0], + &normalized_children[0] + )); + assert!(!ArrayRef::ptr_eq( + &original_children[1], + &normalized_children[1] + )); + assert_arrays_eq!(normalized_children[1], PrimitiveArray::from_iter(12i32..16)); + + Ok(()) + } + + #[test] + fn normalize_slice_of_dict_returns_dict() -> VortexResult<()> { + let codes = PrimitiveArray::from_iter(vec![0u32, 1, 0, 1, 2]).into_array(); + let values = PrimitiveArray::from_iter(vec![10i32, 20, 30]).into_array(); + let dict = DictArray::try_new(codes, values)?.into_array(); + + // Slice the dict array to get a SliceArray wrapping a DictArray. + let sliced = SliceArray::new(dict, 1..4).into_array(); + assert_eq!(sliced.encoding_id(), Slice::ID); + + let allowed = HashSet::from_iter([Dict::ID, Primitive::ID]); + let mut ctx = ExecutionCtx::new(VortexSession::empty()); + + println!("sliced {}", sliced.display_tree()); + + let normalized = sliced.normalize(&mut NormalizeOptions { + allowed: &allowed, + operation: Operation::Execute(&mut ctx), + })?; + + println!("after {}", normalized.display_tree()); + + // The normalized result should be a DictArray, not a SliceArray. + assert_eq!(normalized.encoding_id(), Dict::ID); + assert_eq!(normalized.len(), 3); + + // Verify the data: codes [1,0,1] -> values [20, 10, 20] + assert_arrays_eq!( + normalized.to_canonical()?, + PrimitiveArray::from_iter(vec![20i32, 10, 20]) + ); + + Ok(()) + } } diff --git a/vortex-cuda/src/layout.rs b/vortex-cuda/src/layout.rs index 0a8efcaac5f..6ab2c08681b 100644 --- a/vortex-cuda/src/layout.rs +++ b/vortex-cuda/src/layout.rs @@ -14,6 +14,7 @@ use futures::FutureExt; use futures::StreamExt; use futures::future::BoxFuture; use vortex::array::ArrayContext; +use vortex::array::ArrayId; use vortex::array::ArrayRef; use vortex::array::DeserializeMetadata; use vortex::array::MaskFuture; @@ -28,7 +29,6 @@ use vortex::array::normalize::NormalizeOptions; use vortex::array::normalize::Operation; use vortex::array::serde::SerializeOptions; use vortex::array::serde::SerializedArray; -use vortex::array::session::ArrayRegistry; use vortex::array::stats::StatsSetRef; use vortex::buffer::BufferString; use vortex::buffer::ByteBuffer; @@ -63,6 +63,7 @@ use vortex::scalar::upper_bound; use vortex::session::VortexSession; use vortex::session::registry::ReadContext; use vortex::utils::aliases::hash_map::HashMap; +use vortex::utils::aliases::hash_set::HashSet; /// A buffer inlined into layout metadata for host-side access. #[derive(Clone, prost::Message)] @@ -390,7 +391,7 @@ pub struct CudaFlatLayoutStrategy { /// Maximum length of variable length statistics. pub max_variable_length_statistics_size: usize, /// Optional set of allowed array encodings for normalization. - pub allowed_encodings: Option, + pub allowed_encodings: Option>, } impl Default for CudaFlatLayoutStrategy { @@ -414,7 +415,7 @@ impl CudaFlatLayoutStrategy { self } - pub fn with_allow_encodings(mut self, allow_encodings: ArrayRegistry) -> Self { + pub fn with_allow_encodings(mut self, allow_encodings: HashSet) -> Self { self.allowed_encodings = Some(allow_encodings); self } diff --git a/vortex-file/public-api.lock b/vortex-file/public-api.lock index 4d896437a3b..fe545cfb26a 100644 --- a/vortex-file/public-api.lock +++ b/vortex-file/public-api.lock @@ -344,7 +344,7 @@ impl vortex_file::WriteStrategyBuilder pub fn vortex_file::WriteStrategyBuilder::build(self) -> alloc::sync::Arc -pub fn vortex_file::WriteStrategyBuilder::with_allow_encodings(self, allow_encodings: vortex_array::session::ArrayRegistry) -> Self +pub fn vortex_file::WriteStrategyBuilder::with_allow_encodings(self, allow_encodings: vortex_utils::aliases::hash_set::HashSet) -> Self pub fn vortex_file::WriteStrategyBuilder::with_btrblocks_builder(self, builder: vortex_btrblocks::builder::BtrBlocksCompressorBuilder) -> Self @@ -396,7 +396,7 @@ pub const vortex_file::VERSION: u16 pub const vortex_file::VORTEX_FILE_EXTENSION: &str -pub static vortex_file::ALLOWED_ENCODINGS: std::sync::lazy_lock::LazyLock +pub static vortex_file::ALLOWED_ENCODINGS: std::sync::lazy_lock::LazyLock> pub trait vortex_file::OpenOptionsSessionExt: vortex_array::session::ArraySessionExt + vortex_layout::session::LayoutSessionExt + vortex_io::session::RuntimeSessionExt diff --git a/vortex-file/src/strategy.rs b/vortex-file/src/strategy.rs index 9af6c1e9402..665e34684bd 100644 --- a/vortex-file/src/strategy.rs +++ b/vortex-file/src/strategy.rs @@ -10,6 +10,8 @@ use vortex_alp::ALP; // Compressed encodings from encoding crates // Canonical array encodings from vortex-array use vortex_alp::ALPRD; +use vortex_array::ArrayId; +use vortex_array::VTable; use vortex_array::arrays::Bool; use vortex_array::arrays::Chunked; use vortex_array::arrays::Constant; @@ -26,8 +28,6 @@ use vortex_array::arrays::Struct; use vortex_array::arrays::VarBin; use vortex_array::arrays::VarBinView; use vortex_array::dtype::FieldPath; -use vortex_array::session::ArrayRegistry; -use vortex_array::session::ArraySession; use vortex_btrblocks::BtrBlocksCompressorBuilder; use vortex_btrblocks::SchemeExt; use vortex_btrblocks::schemes::integer::IntDictScheme; @@ -59,6 +59,7 @@ use vortex_sparse::Sparse; #[cfg(feature = "unstable_encodings")] use vortex_tensor::encodings::turboquant::TurboQuant; use vortex_utils::aliases::hash_map::HashMap; +use vortex_utils::aliases::hash_set::HashSet; use vortex_zigzag::ZigZag; #[cfg(feature = "zstd")] use vortex_zstd::Zstd; @@ -71,51 +72,51 @@ const ONE_MEG: u64 = 1 << 20; /// /// This includes all canonical encodings from vortex-array plus all compressed /// encodings from the various encoding crates. -pub static ALLOWED_ENCODINGS: LazyLock = LazyLock::new(|| { - let session = ArraySession::empty(); +pub static ALLOWED_ENCODINGS: LazyLock> = LazyLock::new(|| { + let mut allowed = HashSet::new(); // Canonical encodings from vortex-array - session.register(Null); - session.register(Bool); - session.register(Primitive); - session.register(Decimal); - session.register(VarBin); - session.register(VarBinView); - session.register(List); - session.register(ListView); - session.register(FixedSizeList); - session.register(Struct); - session.register(Extension); - session.register(Chunked); - session.register(Constant); - session.register(Masked); - session.register(Dict); + allowed.insert(Null.id()); + allowed.insert(Bool.id()); + allowed.insert(Primitive.id()); + allowed.insert(Decimal.id()); + allowed.insert(VarBin.id()); + allowed.insert(VarBinView.id()); + allowed.insert(List.id()); + allowed.insert(ListView.id()); + allowed.insert(FixedSizeList.id()); + allowed.insert(Struct.id()); + allowed.insert(Extension.id()); + allowed.insert(Chunked.id()); + allowed.insert(Constant.id()); + allowed.insert(Masked.id()); + allowed.insert(Dict.id()); // Compressed encodings from encoding crates - session.register(ALP); - session.register(ALPRD); - session.register(BitPacked); - session.register(ByteBool); - session.register(DateTimeParts); - session.register(DecimalByteParts); - session.register(Delta); - session.register(FoR); - session.register(FSST); - session.register(Pco); - session.register(RLE); - session.register(RunEnd); - session.register(Sequence); - session.register(Sparse); + allowed.insert(ALP.id()); + allowed.insert(ALPRD.id()); + allowed.insert(BitPacked.id()); + allowed.insert(ByteBool.id()); + allowed.insert(DateTimeParts.id()); + allowed.insert(DecimalByteParts.id()); + allowed.insert(Delta.id()); + allowed.insert(FoR.id()); + allowed.insert(FSST.id()); + allowed.insert(Pco.id()); + allowed.insert(RLE.id()); + allowed.insert(RunEnd.id()); + allowed.insert(Sequence.id()); + allowed.insert(Sparse.id()); #[cfg(feature = "unstable_encodings")] - session.register(TurboQuant); - session.register(ZigZag); + allowed.insert(TurboQuant.id()); + allowed.insert(ZigZag.id()); #[cfg(feature = "zstd")] - session.register(Zstd); + allowed.insert(Zstd.id()); #[cfg(all(feature = "zstd", feature = "unstable_encodings"))] - session.register(ZstdBuffers); + allowed.insert(ZstdBuffers.id()); - session.registry().clone() + allowed }); /// How the compressor was configured on [`WriteStrategyBuilder`]. @@ -138,7 +139,7 @@ pub struct WriteStrategyBuilder { compressor: CompressorConfig, row_block_size: usize, field_writers: HashMap>, - allow_encodings: Option, + allow_encodings: Option>, flat_strategy: Option>, } @@ -175,7 +176,7 @@ impl WriteStrategyBuilder { } /// Override the allowed array encodings for normalization. - pub fn with_allow_encodings(mut self, allow_encodings: ArrayRegistry) -> Self { + pub fn with_allow_encodings(mut self, allow_encodings: HashSet) -> Self { self.allow_encodings = Some(allow_encodings); self } diff --git a/vortex-layout/public-api.lock b/vortex-layout/public-api.lock index 1b31f837935..cc11334cb0c 100644 --- a/vortex-layout/public-api.lock +++ b/vortex-layout/public-api.lock @@ -414,7 +414,7 @@ pub mod vortex_layout::layouts::flat::writer pub struct vortex_layout::layouts::flat::writer::FlatLayoutStrategy -pub vortex_layout::layouts::flat::writer::FlatLayoutStrategy::allowed_encodings: core::option::Option +pub vortex_layout::layouts::flat::writer::FlatLayoutStrategy::allowed_encodings: core::option::Option> pub vortex_layout::layouts::flat::writer::FlatLayoutStrategy::include_padding: bool @@ -422,7 +422,7 @@ pub vortex_layout::layouts::flat::writer::FlatLayoutStrategy::max_variable_lengt impl vortex_layout::layouts::flat::writer::FlatLayoutStrategy -pub fn vortex_layout::layouts::flat::writer::FlatLayoutStrategy::with_allow_encodings(self, allow_encodings: vortex_array::session::ArrayRegistry) -> Self +pub fn vortex_layout::layouts::flat::writer::FlatLayoutStrategy::with_allow_encodings(self, allow_encodings: vortex_utils::aliases::hash_set::HashSet) -> Self pub fn vortex_layout::layouts::flat::writer::FlatLayoutStrategy::with_include_padding(self, include_padding: bool) -> Self diff --git a/vortex-layout/src/layouts/flat/writer.rs b/vortex-layout/src/layouts/flat/writer.rs index 1af1a00cdc6..01e0d7760ab 100644 --- a/vortex-layout/src/layouts/flat/writer.rs +++ b/vortex-layout/src/layouts/flat/writer.rs @@ -4,6 +4,7 @@ use async_trait::async_trait; use futures::StreamExt; use vortex_array::ArrayContext; +use vortex_array::ArrayId; use vortex_array::dtype::DType; use vortex_array::expr::stats::Precision; use vortex_array::expr::stats::Stat; @@ -15,7 +16,6 @@ use vortex_array::scalar::ScalarTruncation; use vortex_array::scalar::lower_bound; use vortex_array::scalar::upper_bound; use vortex_array::serde::SerializeOptions; -use vortex_array::session::ArrayRegistry; use vortex_array::stats::StatsSetRef; use vortex_buffer::BufferString; use vortex_buffer::ByteBuffer; @@ -24,6 +24,7 @@ use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_io::runtime::Handle; use vortex_session::registry::ReadContext; +use vortex_utils::aliases::hash_set::HashSet; use crate::IntoLayout; use crate::LayoutRef; @@ -42,7 +43,7 @@ pub struct FlatLayoutStrategy { pub max_variable_length_statistics_size: usize, /// Optional set of allowed array encodings for normalization. /// If None, then all are allowed. - pub allowed_encodings: Option, + pub allowed_encodings: Option>, } impl Default for FlatLayoutStrategy { @@ -69,7 +70,7 @@ impl FlatLayoutStrategy { } /// Set the allowed array encodings for normalization. - pub fn with_allow_encodings(mut self, allow_encodings: ArrayRegistry) -> Self { + pub fn with_allow_encodings(mut self, allow_encodings: HashSet) -> Self { self.allowed_encodings = Some(allow_encodings); self } @@ -202,7 +203,6 @@ mod tests { use vortex_array::arrays::BoolArray; use vortex_array::arrays::Dict; use vortex_array::arrays::DictArray; - use vortex_array::arrays::Primitive; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::StructArray; use vortex_array::arrays::struct_::StructArrayExt; @@ -216,9 +216,7 @@ mod tests { use vortex_array::expr::stats::Precision; use vortex_array::expr::stats::Stat; use vortex_array::expr::stats::StatsProviderExt; - use vortex_array::session::ArrayRegistry; use vortex_array::validity::Validity; - use vortex_array::vtable::ArrayPluginRef; use vortex_array::vtable::VTable; use vortex_buffer::BitBufferMut; use vortex_buffer::buffer; @@ -227,6 +225,7 @@ mod tests { use vortex_io::runtime::single::block_on; use vortex_mask::AllOr; use vortex_mask::Mask; + use vortex_utils::aliases::hash_set::HashSet; use crate::LayoutStrategy; use crate::layouts::flat::writer::FlatLayoutStrategy; @@ -425,9 +424,8 @@ mod tests { let (layout, _segments) = { let segments = Arc::new(TestSegments::default()); let (ptr, eof) = SequenceId::root().split(); - // Only allow primitive encodings - filter arrays should fail. - let allowed = ArrayRegistry::default(); - allowed.register(Primitive::ID, Arc::new(Primitive) as ArrayPluginRef); + // Disallow all encodings so filter arrays fail normalization immediately. + let allowed = HashSet::default(); let layout = FlatLayoutStrategy::default() .with_allow_encodings(allowed) .write_stream( @@ -466,10 +464,9 @@ mod tests { let (layout, _segments) = { let segments = Arc::new(TestSegments::default()); let (ptr, eof) = SequenceId::root().split(); - // Only allow primitive encodings - filter arrays should fail. - let allowed = ArrayRegistry::default(); - allowed.register(Primitive.id(), Arc::new(Primitive) as ArrayPluginRef); - allowed.register(Dict.id(), Arc::new(Dict) as ArrayPluginRef); + // Only allow the dict encoding; canonical primitive children remain permitted. + let mut allowed = HashSet::default(); + allowed.insert(Dict.id()); let layout = FlatLayoutStrategy::default() .with_allow_encodings(allowed) .write_stream(