diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index ed872aa095a..b3bcf44808a 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -1688,6 +1688,34 @@ pub fn vortex_array::arrays::dict::take_canonical(values: vortex_array::Canonica pub mod vortex_array::arrays::extension +pub struct vortex_array::arrays::extension::ExactExtArray(_) + +impl core::default::Default for vortex_array::arrays::extension::ExactExtArray + +pub fn vortex_array::arrays::extension::ExactExtArray::default() -> vortex_array::arrays::extension::ExactExtArray + +impl core::fmt::Debug for vortex_array::arrays::extension::ExactExtArray + +pub fn vortex_array::arrays::extension::ExactExtArray::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl vortex_array::matcher::Matcher for vortex_array::arrays::extension::ExactExtArray + +pub type vortex_array::arrays::extension::ExactExtArray::Match<'a> = vortex_array::arrays::extension::ExtArray<'a, V> + +pub fn vortex_array::arrays::extension::ExactExtArray::matches(array: &dyn vortex_array::DynArray) -> bool + +pub fn vortex_array::arrays::extension::ExactExtArray::try_match(array: &dyn vortex_array::DynArray) -> core::option::Option + +pub struct vortex_array::arrays::extension::ExtArray<'a, V: vortex_array::dtype::extension::ExtVTable> + +impl<'a, V: vortex_array::dtype::extension::ExtVTable> vortex_array::arrays::extension::ExtArray<'a, V> + +pub fn vortex_array::arrays::extension::ExtArray<'a, V>::ext_dtype(&self) -> &vortex_array::dtype::extension::ExtDType + +pub fn vortex_array::arrays::extension::ExtArray<'a, V>::storage_array(&self) -> &vortex_array::ArrayRef + +pub fn vortex_array::arrays::extension::ExtArray<'a, V>::try_new(array: &'a vortex_array::arrays::ExtensionArray) -> core::option::Option + pub struct vortex_array::arrays::extension::Extension impl vortex_array::arrays::Extension @@ -1806,6 +1834,8 @@ pub struct vortex_array::arrays::extension::ExtensionArray impl vortex_array::arrays::ExtensionArray +pub fn vortex_array::arrays::ExtensionArray::downcast_ref(&self) -> core::option::Option> + pub fn vortex_array::arrays::ExtensionArray::ext_dtype(&self) -> &vortex_array::dtype::extension::ExtDTypeRef pub fn vortex_array::arrays::ExtensionArray::new(ext_dtype: vortex_array::dtype::extension::ExtDTypeRef, storage_array: vortex_array::ArrayRef) -> Self @@ -5620,6 +5650,8 @@ pub struct vortex_array::arrays::ExtensionArray impl vortex_array::arrays::ExtensionArray +pub fn vortex_array::arrays::ExtensionArray::downcast_ref(&self) -> core::option::Option> + pub fn vortex_array::arrays::ExtensionArray::ext_dtype(&self) -> &vortex_array::dtype::extension::ExtDTypeRef pub fn vortex_array::arrays::ExtensionArray::new(ext_dtype: vortex_array::dtype::extension::ExtDTypeRef, storage_array: vortex_array::ArrayRef) -> Self @@ -10196,6 +10228,8 @@ impl vortex_array::dtype::extension::ExtDTypeRef pub fn vortex_array::dtype::extension::ExtDTypeRef::downcast(self) -> alloc::sync::Arc> +pub fn vortex_array::dtype::extension::ExtDTypeRef::downcast_ref(&self) -> core::option::Option<&vortex_array::dtype::extension::ExtDType> + pub fn vortex_array::dtype::extension::ExtDTypeRef::is(&self) -> bool pub fn vortex_array::dtype::extension::ExtDTypeRef::metadata(&self) -> ::Match @@ -10250,10 +10284,14 @@ pub fn vortex_array::dtype::extension::ExtVTable::can_coerce_to(&self, ext_dtype pub fn vortex_array::dtype::extension::ExtVTable::deserialize_metadata(&self, metadata: &[u8]) -> vortex_error::VortexResult +pub fn vortex_array::dtype::extension::ExtVTable::execute_parent_array(&self, array: &vortex_array::arrays::extension::ExtArray<'_, Self>, parent: &vortex_array::ArrayRef, child_idx: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + pub fn vortex_array::dtype::extension::ExtVTable::id(&self) -> vortex_array::dtype::extension::ExtId pub fn vortex_array::dtype::extension::ExtVTable::least_supertype(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option +pub fn vortex_array::dtype::extension::ExtVTable::reduce_parent_array(&self, array: &vortex_array::arrays::extension::ExtArray<'_, Self>, parent: &vortex_array::ArrayRef, child_idx: usize) -> vortex_error::VortexResult> + pub fn vortex_array::dtype::extension::ExtVTable::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> pub fn vortex_array::dtype::extension::ExtVTable::unpack_native<'a>(&self, ext_dtype: &'a vortex_array::dtype::extension::ExtDType, storage_value: &'a vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult @@ -10274,10 +10312,14 @@ pub fn vortex_array::extension::datetime::Date::can_coerce_to(&self, ext_dtype: pub fn vortex_array::extension::datetime::Date::deserialize_metadata(&self, metadata: &[u8]) -> vortex_error::VortexResult +pub fn vortex_array::extension::datetime::Date::execute_parent_array(&self, array: &vortex_array::arrays::extension::ExtArray<'_, Self>, parent: &vortex_array::ArrayRef, child_idx: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + pub fn vortex_array::extension::datetime::Date::id(&self) -> vortex_array::dtype::extension::ExtId pub fn vortex_array::extension::datetime::Date::least_supertype(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option +pub fn vortex_array::extension::datetime::Date::reduce_parent_array(&self, array: &vortex_array::arrays::extension::ExtArray<'_, Self>, parent: &vortex_array::ArrayRef, child_idx: usize) -> vortex_error::VortexResult> + pub fn vortex_array::extension::datetime::Date::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> pub fn vortex_array::extension::datetime::Date::unpack_native(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult @@ -10298,10 +10340,14 @@ pub fn vortex_array::extension::datetime::Time::can_coerce_to(&self, ext_dtype: pub fn vortex_array::extension::datetime::Time::deserialize_metadata(&self, data: &[u8]) -> vortex_error::VortexResult +pub fn vortex_array::extension::datetime::Time::execute_parent_array(&self, array: &vortex_array::arrays::extension::ExtArray<'_, Self>, parent: &vortex_array::ArrayRef, child_idx: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + pub fn vortex_array::extension::datetime::Time::id(&self) -> vortex_array::dtype::extension::ExtId pub fn vortex_array::extension::datetime::Time::least_supertype(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option +pub fn vortex_array::extension::datetime::Time::reduce_parent_array(&self, array: &vortex_array::arrays::extension::ExtArray<'_, Self>, parent: &vortex_array::ArrayRef, child_idx: usize) -> vortex_error::VortexResult> + pub fn vortex_array::extension::datetime::Time::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> pub fn vortex_array::extension::datetime::Time::unpack_native(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult @@ -10322,10 +10368,14 @@ pub fn vortex_array::extension::datetime::Timestamp::can_coerce_to(&self, ext_dt pub fn vortex_array::extension::datetime::Timestamp::deserialize_metadata(&self, data: &[u8]) -> vortex_error::VortexResult +pub fn vortex_array::extension::datetime::Timestamp::execute_parent_array(&self, array: &vortex_array::arrays::extension::ExtArray<'_, Self>, parent: &vortex_array::ArrayRef, child_idx: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + pub fn vortex_array::extension::datetime::Timestamp::id(&self) -> vortex_array::dtype::extension::ExtId pub fn vortex_array::extension::datetime::Timestamp::least_supertype(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option +pub fn vortex_array::extension::datetime::Timestamp::reduce_parent_array(&self, array: &vortex_array::arrays::extension::ExtArray<'_, Self>, parent: &vortex_array::ArrayRef, child_idx: usize) -> vortex_error::VortexResult> + pub fn vortex_array::extension::datetime::Timestamp::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> pub fn vortex_array::extension::datetime::Timestamp::unpack_native<'a>(&self, ext_dtype: &'a vortex_array::dtype::extension::ExtDType, storage_value: &'a vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult @@ -10346,10 +10396,14 @@ pub fn vortex_array::extension::uuid::Uuid::can_coerce_to(&self, ext_dtype: &vor pub fn vortex_array::extension::uuid::Uuid::deserialize_metadata(&self, metadata: &[u8]) -> vortex_error::VortexResult +pub fn vortex_array::extension::uuid::Uuid::execute_parent_array(&self, array: &vortex_array::arrays::extension::ExtArray<'_, Self>, parent: &vortex_array::ArrayRef, child_idx: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + pub fn vortex_array::extension::uuid::Uuid::id(&self) -> vortex_array::dtype::extension::ExtId pub fn vortex_array::extension::uuid::Uuid::least_supertype(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option +pub fn vortex_array::extension::uuid::Uuid::reduce_parent_array(&self, array: &vortex_array::arrays::extension::ExtArray<'_, Self>, parent: &vortex_array::ArrayRef, child_idx: usize) -> vortex_error::VortexResult> + pub fn vortex_array::extension::uuid::Uuid::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> pub fn vortex_array::extension::uuid::Uuid::unpack_native<'a>(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, storage_value: &'a vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult @@ -14070,6 +14124,8 @@ pub vortex_array::extension::datetime::TimeUnit::Seconds = 3 impl vortex_array::extension::datetime::TimeUnit +pub fn vortex_array::extension::datetime::TimeUnit::nanos_per_unit(&self) -> i64 + pub fn vortex_array::extension::datetime::TimeUnit::to_jiff_span(&self, v: i64) -> vortex_error::VortexResult impl core::clone::Clone for vortex_array::extension::datetime::TimeUnit @@ -14212,10 +14268,14 @@ pub fn vortex_array::extension::datetime::Date::can_coerce_to(&self, ext_dtype: pub fn vortex_array::extension::datetime::Date::deserialize_metadata(&self, metadata: &[u8]) -> vortex_error::VortexResult +pub fn vortex_array::extension::datetime::Date::execute_parent_array(&self, array: &vortex_array::arrays::extension::ExtArray<'_, Self>, parent: &vortex_array::ArrayRef, child_idx: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + pub fn vortex_array::extension::datetime::Date::id(&self) -> vortex_array::dtype::extension::ExtId pub fn vortex_array::extension::datetime::Date::least_supertype(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option +pub fn vortex_array::extension::datetime::Date::reduce_parent_array(&self, array: &vortex_array::arrays::extension::ExtArray<'_, Self>, parent: &vortex_array::ArrayRef, child_idx: usize) -> vortex_error::VortexResult> + pub fn vortex_array::extension::datetime::Date::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> pub fn vortex_array::extension::datetime::Date::unpack_native(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult @@ -14268,10 +14328,14 @@ pub fn vortex_array::extension::datetime::Time::can_coerce_to(&self, ext_dtype: pub fn vortex_array::extension::datetime::Time::deserialize_metadata(&self, data: &[u8]) -> vortex_error::VortexResult +pub fn vortex_array::extension::datetime::Time::execute_parent_array(&self, array: &vortex_array::arrays::extension::ExtArray<'_, Self>, parent: &vortex_array::ArrayRef, child_idx: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + pub fn vortex_array::extension::datetime::Time::id(&self) -> vortex_array::dtype::extension::ExtId pub fn vortex_array::extension::datetime::Time::least_supertype(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option +pub fn vortex_array::extension::datetime::Time::reduce_parent_array(&self, array: &vortex_array::arrays::extension::ExtArray<'_, Self>, parent: &vortex_array::ArrayRef, child_idx: usize) -> vortex_error::VortexResult> + pub fn vortex_array::extension::datetime::Time::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> pub fn vortex_array::extension::datetime::Time::unpack_native(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult @@ -14326,10 +14390,14 @@ pub fn vortex_array::extension::datetime::Timestamp::can_coerce_to(&self, ext_dt pub fn vortex_array::extension::datetime::Timestamp::deserialize_metadata(&self, data: &[u8]) -> vortex_error::VortexResult +pub fn vortex_array::extension::datetime::Timestamp::execute_parent_array(&self, array: &vortex_array::arrays::extension::ExtArray<'_, Self>, parent: &vortex_array::ArrayRef, child_idx: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + pub fn vortex_array::extension::datetime::Timestamp::id(&self) -> vortex_array::dtype::extension::ExtId pub fn vortex_array::extension::datetime::Timestamp::least_supertype(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option +pub fn vortex_array::extension::datetime::Timestamp::reduce_parent_array(&self, array: &vortex_array::arrays::extension::ExtArray<'_, Self>, parent: &vortex_array::ArrayRef, child_idx: usize) -> vortex_error::VortexResult> + pub fn vortex_array::extension::datetime::Timestamp::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> pub fn vortex_array::extension::datetime::Timestamp::unpack_native<'a>(&self, ext_dtype: &'a vortex_array::dtype::extension::ExtDType, storage_value: &'a vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult @@ -14408,10 +14476,14 @@ pub fn vortex_array::extension::uuid::Uuid::can_coerce_to(&self, ext_dtype: &vor pub fn vortex_array::extension::uuid::Uuid::deserialize_metadata(&self, metadata: &[u8]) -> vortex_error::VortexResult +pub fn vortex_array::extension::uuid::Uuid::execute_parent_array(&self, array: &vortex_array::arrays::extension::ExtArray<'_, Self>, parent: &vortex_array::ArrayRef, child_idx: usize, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> + pub fn vortex_array::extension::uuid::Uuid::id(&self) -> vortex_array::dtype::extension::ExtId pub fn vortex_array::extension::uuid::Uuid::least_supertype(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option +pub fn vortex_array::extension::uuid::Uuid::reduce_parent_array(&self, array: &vortex_array::arrays::extension::ExtArray<'_, Self>, parent: &vortex_array::ArrayRef, child_idx: usize) -> vortex_error::VortexResult> + pub fn vortex_array::extension::uuid::Uuid::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> pub fn vortex_array::extension::uuid::Uuid::unpack_native<'a>(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, storage_value: &'a vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult @@ -14706,6 +14778,14 @@ pub fn vortex_array::arrays::scalar_fn::ExactScalarFn::matches(array: &dyn vo pub fn vortex_array::arrays::scalar_fn::ExactScalarFn::try_match(array: &dyn vortex_array::DynArray) -> core::option::Option +impl vortex_array::matcher::Matcher for vortex_array::arrays::extension::ExactExtArray + +pub type vortex_array::arrays::extension::ExactExtArray::Match<'a> = vortex_array::arrays::extension::ExtArray<'a, V> + +pub fn vortex_array::arrays::extension::ExactExtArray::matches(array: &dyn vortex_array::DynArray) -> bool + +pub fn vortex_array::arrays::extension::ExactExtArray::try_match(array: &dyn vortex_array::DynArray) -> core::option::Option + impl vortex_array::matcher::Matcher for V pub type V::Match<'a> = &'a ::Array diff --git a/vortex-array/src/arrays/extension/array.rs b/vortex-array/src/arrays/extension/array.rs index 799646817a9..75f677b4f40 100644 --- a/vortex-array/src/arrays/extension/array.rs +++ b/vortex-array/src/arrays/extension/array.rs @@ -5,8 +5,10 @@ use vortex_error::VortexExpect; use vortex_error::VortexResult; use crate::ArrayRef; +use crate::arrays::extension::view::ExtArray; use crate::dtype::DType; use crate::dtype::extension::ExtDTypeRef; +use crate::dtype::extension::ExtVTable; use crate::stats::ArrayStats; /// An extension array that wraps another array with additional type information. @@ -125,4 +127,8 @@ impl ExtensionArray { pub fn storage_array(&self) -> &ArrayRef { &self.storage_array } + + pub fn downcast_ref(&self) -> Option> { + ExtArray::try_new(self) + } } diff --git a/vortex-array/src/arrays/extension/compute/cast.rs b/vortex-array/src/arrays/extension/compute/cast.rs index 1b02b479438..383c2a64222 100644 --- a/vortex-array/src/arrays/extension/compute/cast.rs +++ b/vortex-array/src/arrays/extension/compute/cast.rs @@ -11,28 +11,24 @@ use crate::scalar_fn::fns::cast::CastReduce; impl CastReduce for Extension { fn cast(array: &ExtensionArray, dtype: &DType) -> vortex_error::VortexResult> { - if !array.dtype().eq_ignore_nullability(dtype) { - return Ok(None); + // Same extension type (ignoring nullability): just cast the storage nullability. + if array.dtype().eq_ignore_nullability(dtype) { + let DType::Extension(ext_dtype) = dtype else { + unreachable!("Already verified we have an extension dtype"); + }; + + let new_storage = array + .storage_array() + .cast(ext_dtype.storage_dtype().clone())?; + + return Ok(Some( + ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(), + )); } - let DType::Extension(ext_dtype) = dtype else { - unreachable!("Already verified we have an extension dtype"); - }; - - let new_storage = match array - .storage_array() - .cast(ext_dtype.storage_dtype().clone()) - { - Ok(arr) => arr, - Err(e) => { - tracing::warn!("Failed to cast storage array: {e}"); - return Ok(None); - } - }; - - Ok(Some( - ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(), - )) + // Type-specific casting (e.g. Timestamp(s) → Timestamp(ns)) is handled by + // ExtVTable::reduce_parent_array, which runs before this CastReduce fallback. + Ok(None) } } @@ -86,21 +82,68 @@ mod tests { } #[test] - fn cast_different_ext_dtype() { - let original_dtype = + fn cast_timestamp_ms_to_ns() { + let source_dtype = Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased(); - // Note NS here instead of MS let target_dtype = Timestamp::new(TimeUnit::Nanoseconds, Nullability::NonNullable).erased(); + let storage = buffer![1i64, 2, 3].into_array(); + let arr = ExtensionArray::new(source_dtype, storage).into_array(); + + let result = arr.cast(DType::Extension(target_dtype.clone())).unwrap(); + assert_eq!(result.dtype(), &DType::Extension(target_dtype)); + + // Verify values were scaled: ms → ns is ×1_000_000 + let ext = result.to_canonical().unwrap().as_extension().clone(); + let prim = ext + .storage_array() + .to_canonical() + .unwrap() + .as_primitive() + .clone(); + assert_eq!(prim.as_slice::(), &[1_000_000, 2_000_000, 3_000_000]); + } + + #[test] + fn cast_timestamp_s_to_us() { + let source_dtype = Timestamp::new(TimeUnit::Seconds, Nullability::NonNullable).erased(); + let target_dtype = + Timestamp::new(TimeUnit::Microseconds, Nullability::NonNullable).erased(); + + let storage = buffer![10i64, 20].into_array(); + let arr = ExtensionArray::new(source_dtype, storage).into_array(); + + let result = arr.cast(DType::Extension(target_dtype)).unwrap(); + let ext = result.to_canonical().unwrap().as_extension().clone(); + let prim = ext + .storage_array() + .to_canonical() + .unwrap() + .as_primitive() + .clone(); + assert_eq!(prim.as_slice::(), &[10_000_000, 20_000_000]); + } + + #[test] + fn cast_timestamp_tz_mismatch_fails() { + use std::sync::Arc; + + let utc_dtype = Timestamp::new_with_tz( + TimeUnit::Seconds, + Some(Arc::from("UTC")), + Nullability::NonNullable, + ) + .erased(); + let no_tz_dtype = Timestamp::new(TimeUnit::Nanoseconds, Nullability::NonNullable).erased(); + let storage = buffer![1i64].into_array(); - let arr = ExtensionArray::new(original_dtype, storage); - - assert!( - arr.into_array() - .cast(DType::Extension(target_dtype)) - .and_then(|a| a.to_canonical().map(|c| c.into_array())) - .is_err() - ); + let arr = ExtensionArray::new(utc_dtype, storage).into_array(); + + // Timezone mismatch: cast creates a lazy expression, error surfaces on evaluation. + let result = arr + .cast(DType::Extension(no_tz_dtype)) + .and_then(|a| a.to_canonical().map(|c| c.into_array())); + assert!(result.is_err()); } #[rstest] diff --git a/vortex-array/src/arrays/extension/mod.rs b/vortex-array/src/arrays/extension/mod.rs index 703871bf311..21f02159e9a 100644 --- a/vortex-array/src/arrays/extension/mod.rs +++ b/vortex-array/src/arrays/extension/mod.rs @@ -4,7 +4,12 @@ mod array; pub use array::ExtensionArray; +mod view; +pub use view::ExactExtArray; +pub use view::ExtArray; + pub(crate) mod compute; mod vtable; + pub use vtable::Extension; diff --git a/vortex-array/src/arrays/extension/view.rs b/vortex-array/src/arrays/extension/view.rs new file mode 100644 index 00000000000..5ddcbc1b4b4 --- /dev/null +++ b/vortex-array/src/arrays/extension/view.rs @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::marker::PhantomData; + +use crate::ArrayRef; +use crate::DynArray; +use crate::arrays::Extension; +use crate::arrays::ExtensionArray; +use crate::dtype::extension::ExtDType; +use crate::dtype::extension::ExtVTable; +use crate::matcher::Matcher; + +/// A typed view of an extension array. +pub struct ExtArray<'a, V: ExtVTable> { + ext_dtype: &'a ExtDType, + array: &'a ExtensionArray, +} + +impl<'a, V: ExtVTable> ExtArray<'a, V> { + pub fn try_new(array: &'a ExtensionArray) -> Option { + let ext_dtype = array.ext_dtype().downcast_ref::()?; + Some(Self { ext_dtype, array }) + } + + pub fn ext_dtype(&self) -> &ExtDType { + self.ext_dtype + } + + pub fn storage_array(&self) -> &ArrayRef { + self.array.storage_array() + } +} + +/// A matcher that matches an [`ExtensionArray`] with a specific [`ExtVTable`] type. +/// +/// Similar to [`ExactScalarFn`](crate::arrays::scalar_fn::ExactScalarFn) for scalar functions, +/// this provides typed access to the extension array view ([`ExtArray`]). +#[derive(Debug, Default)] +pub struct ExactExtArray(PhantomData); + +impl Matcher for ExactExtArray { + type Match<'a> = ExtArray<'a, V>; + + fn matches(array: &dyn DynArray) -> bool { + if let Some(ext_array) = array.as_opt::() { + ext_array.downcast_ref::().is_some() + } else { + false + } + } + + fn try_match(array: &dyn DynArray) -> Option> { + let ext_array = array.as_opt::()?; + ext_array.downcast_ref::() + } +} diff --git a/vortex-array/src/arrays/extension/vtable/mod.rs b/vortex-array/src/arrays/extension/vtable/mod.rs index 243a0ef35d8..903e50d2ce1 100644 --- a/vortex-array/src/arrays/extension/vtable/mod.rs +++ b/vortex-array/src/arrays/extension/vtable/mod.rs @@ -157,21 +157,30 @@ impl VTable for Extension { Ok(ExecutionStep::Done(array.clone().into_array())) } - fn reduce_parent( + fn execute_parent( array: &Self::Array, parent: &ArrayRef, child_idx: usize, + ctx: &mut ExecutionCtx, ) -> VortexResult> { - PARENT_RULES.evaluate(array, parent, child_idx) + if let Some(result) = array + .ext_dtype() + .execute_parent(array, parent, child_idx, ctx)? + { + return Ok(Some(result)); + } + PARENT_KERNELS.execute(array, parent, child_idx, ctx) } - fn execute_parent( + fn reduce_parent( array: &Self::Array, parent: &ArrayRef, child_idx: usize, - ctx: &mut ExecutionCtx, ) -> VortexResult> { - PARENT_KERNELS.execute(array, parent, child_idx, ctx) + if let Some(result) = array.ext_dtype().reduce_parent(array, parent, child_idx)? { + return Ok(Some(result)); + } + PARENT_RULES.evaluate(array, parent, child_idx) } } diff --git a/vortex-array/src/dtype/coercion.rs b/vortex-array/src/dtype/coercion.rs index 528d9b673b2..fee93900a0e 100644 --- a/vortex-array/src/dtype/coercion.rs +++ b/vortex-array/src/dtype/coercion.rs @@ -205,6 +205,13 @@ impl DType { /// Convenience — is there a path from `self` to `other`? pub fn can_coerce_to(&self, other: &DType) -> bool { + if let DType::Extension(ext) = self { + // Extension types can define coercions in either direction, so check both. + if ext.can_coerce_to(other) { + return true; + } + }; + other.can_coerce_from(self) } diff --git a/vortex-array/src/dtype/extension/erased.rs b/vortex-array/src/dtype/extension/erased.rs index e624c5c7ed2..c8726ff2f44 100644 --- a/vortex-array/src/dtype/extension/erased.rs +++ b/vortex-array/src/dtype/extension/erased.rs @@ -13,6 +13,9 @@ use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_err; +use crate::ArrayRef; +use crate::ExecutionCtx; +use crate::arrays::ExtensionArray; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::extension::ExtDType; @@ -114,6 +117,25 @@ impl ExtDTypeRef { pub fn least_supertype(&self, other: &DType) -> Option { self.0.coercion_least_supertype(other) } + + pub(crate) fn execute_parent( + &self, + array: &ExtensionArray, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + self.0.execute_parent_array(array, parent, child_idx, ctx) + } + + pub(crate) fn reduce_parent( + &self, + array: &ExtensionArray, + parent: &ArrayRef, + child_idx: usize, + ) -> VortexResult> { + self.0.reduce_parent_array(array, parent, child_idx) + } } /// Methods for downcasting type-erased extension dtypes. @@ -166,6 +188,11 @@ impl ExtDTypeRef { }) .vortex_expect("Failed to downcast ExtDTypeRef") } + + /// Downcast to the concrete [`ExtDType`]. + pub fn downcast_ref(&self) -> Option<&ExtDType> { + self.0.as_any().downcast_ref::>() + } } impl PartialEq for ExtDTypeRef { diff --git a/vortex-array/src/dtype/extension/typed.rs b/vortex-array/src/dtype/extension/typed.rs index baa3706b56b..05d4b2879e1 100644 --- a/vortex-array/src/dtype/extension/typed.rs +++ b/vortex-array/src/dtype/extension/typed.rs @@ -15,7 +15,11 @@ use std::sync::Arc; use vortex_error::VortexExpect; use vortex_error::VortexResult; +use vortex_error::vortex_err; +use crate::ArrayRef; +use crate::ExecutionCtx; +use crate::arrays::ExtensionArray; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::extension::ExtDTypeRef; @@ -129,6 +133,19 @@ pub(super) trait DynExtDType: 'static + Send + Sync + super::sealed::Sealed { fn coercion_can_coerce_to(&self, other: &DType) -> bool; /// Compute the least supertype of this extension type and another type. fn coercion_least_supertype(&self, other: &DType) -> Option; + fn execute_parent_array( + &self, + array: &ExtensionArray, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult>; + fn reduce_parent_array( + &self, + array: &ExtensionArray, + parent: &ArrayRef, + child_idx: usize, + ) -> VortexResult>; } impl DynExtDType for ExtDType { @@ -212,4 +229,36 @@ impl DynExtDType for ExtDType { fn coercion_least_supertype(&self, other: &DType) -> Option { self.vtable.least_supertype(self, other) } + + fn execute_parent_array( + &self, + array: &ExtensionArray, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + self.vtable.execute_parent_array( + &array + .downcast_ref() + .ok_or_else(|| vortex_err!("Array is not an extension array of this type"))?, + parent, + child_idx, + ctx, + ) + } + + fn reduce_parent_array( + &self, + array: &ExtensionArray, + parent: &ArrayRef, + child_idx: usize, + ) -> VortexResult> { + self.vtable.reduce_parent_array( + &array + .downcast_ref() + .ok_or_else(|| vortex_err!("Array is not an extension array of this type"))?, + parent, + child_idx, + ) + } } diff --git a/vortex-array/src/dtype/extension/vtable.rs b/vortex-array/src/dtype/extension/vtable.rs index 120cfa02dc3..83203b901ad 100644 --- a/vortex-array/src/dtype/extension/vtable.rs +++ b/vortex-array/src/dtype/extension/vtable.rs @@ -7,6 +7,9 @@ use std::hash::Hash; use vortex_error::VortexResult; +use crate::ArrayRef; +use crate::ExecutionCtx; +use crate::arrays::extension::ExtArray; use crate::dtype::DType; use crate::dtype::extension::ExtDType; use crate::dtype::extension::ExtId; @@ -65,8 +68,6 @@ pub trait ExtVTable: 'static + Sized + Send + Sync + Clone + Debug + Eq + Hash { None } - // Methods related to the extension scalar values. - /// Validate the given storage value is compatible with the extension type. /// /// By default, this calls [`unpack_native()`](ExtVTable::unpack_native) and discards the result. @@ -96,4 +97,31 @@ pub trait ExtVTable: 'static + Sized + Send + Sync + Clone + Debug + Eq + Hash { ext_dtype: &'a ExtDType, storage_value: &'a ScalarValue, ) -> VortexResult>; + + /// Execute the extension array's parent. + /// + /// See [`crate::vtable::VTable::execute_parent`] + fn execute_parent_array( + &self, + array: &ExtArray, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + let _ = (array, parent, child_idx, ctx); + Ok(None) + } + + /// Reduce the extension array's parent. + /// + /// See [`crate::vtable::VTable::reduce_parent`] + fn reduce_parent_array( + &self, + array: &ExtArray, + parent: &ArrayRef, + child_idx: usize, + ) -> VortexResult> { + let _ = (array, parent, child_idx); + Ok(None) + } } diff --git a/vortex-array/src/extension/datetime/date.rs b/vortex-array/src/extension/datetime/date.rs index 57fe7ad87e0..cf15fa87310 100644 --- a/vortex-array/src/extension/datetime/date.rs +++ b/vortex-array/src/extension/datetime/date.rs @@ -10,6 +10,13 @@ use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_error::vortex_err; +use crate::ArrayRef; +use crate::IntoArray; +use crate::arrays::ConstantArray; +use crate::arrays::ExtensionArray; +use crate::arrays::extension::ExtArray; +use crate::arrays::scalar_fn::ExactScalarFn; +use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::PType; @@ -17,7 +24,10 @@ use crate::dtype::extension::ExtDType; use crate::dtype::extension::ExtId; use crate::dtype::extension::ExtVTable; use crate::extension::datetime::TimeUnit; +use crate::matcher::Matcher; use crate::scalar::ScalarValue; +use crate::scalar_fn::fns::cast::Cast; +use crate::scalar_fn::fns::operators::Operator; /// The Unix epoch date (1970-01-01). const EPOCH: jiff::civil::Date = jiff::civil::Date::constant(1970, 1, 1); @@ -129,6 +139,55 @@ impl ExtVTable for Date { Some(DType::Extension(Date::new(finest, union_null).erased())) } + fn reduce_parent_array( + &self, + array: &ExtArray, + parent: &ArrayRef, + child_idx: usize, + ) -> VortexResult> { + let _ = child_idx; + let Some(cast_view) = ExactScalarFn::::try_match(parent.as_ref()) else { + return Ok(None); + }; + let target = cast_view.options; + + let DType::Extension(target_ext) = target else { + return Ok(None); + }; + let Some(target_unit) = target_ext.metadata_opt::() else { + return Ok(None); + }; + let source_unit = array.ext_dtype().metadata(); + + let source_nanos = source_unit.nanos_per_unit(); + let target_nanos = target_unit.nanos_per_unit(); + + // Cast storage to target ptype first (e.g. i32 → i64 for Days → Ms). + let storage = array + .storage_array() + .cast(target_ext.storage_dtype().clone())?; + + let storage = if source_nanos == target_nanos { + storage + } else if source_nanos > target_nanos { + let factor = source_nanos / target_nanos; + storage.binary( + ConstantArray::new(factor, storage.len()).into_array(), + Operator::Mul, + )? + } else { + let factor = target_nanos / source_nanos; + storage.binary( + ConstantArray::new(factor, storage.len()).into_array(), + Operator::Div, + )? + }; + + Ok(Some( + ExtensionArray::new(target_ext.clone(), storage).into_array(), + )) + } + fn unpack_native( &self, ext_dtype: &ExtDType, @@ -210,4 +269,32 @@ mod tests { assert!(ms.can_coerce_from(&days)); assert!(!days.can_coerce_from(&ms)); } + + #[test] + fn cast_date_days_to_ms() { + use vortex_buffer::buffer; + + use crate::IntoArray; + use crate::arrays::ExtensionArray; + use crate::builtins::ArrayBuiltins; + use crate::dtype::Nullability::NonNullable; + + let source_dtype = Date::new(TimeUnit::Days, NonNullable).erased(); + let target_dtype = Date::new(TimeUnit::Milliseconds, NonNullable).erased(); + + // 1 day and 2 days since epoch + let storage = buffer![1i32, 2].into_array(); + let arr = ExtensionArray::new(source_dtype, storage).into_array(); + + let result = arr.cast(DType::Extension(target_dtype)).unwrap(); + let ext = result.to_canonical().unwrap().as_extension().clone(); + let prim = ext + .storage_array() + .to_canonical() + .unwrap() + .as_primitive() + .clone(); + // Days → Ms: ×86_400_000 + assert_eq!(prim.as_slice::(), &[86_400_000i64, 172_800_000]); + } } diff --git a/vortex-array/src/extension/datetime/time.rs b/vortex-array/src/extension/datetime/time.rs index eb53a0e0fe3..63478902be8 100644 --- a/vortex-array/src/extension/datetime/time.rs +++ b/vortex-array/src/extension/datetime/time.rs @@ -10,6 +10,13 @@ use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_error::vortex_err; +use crate::ArrayRef; +use crate::IntoArray; +use crate::arrays::ConstantArray; +use crate::arrays::ExtensionArray; +use crate::arrays::extension::ExtArray; +use crate::arrays::scalar_fn::ExactScalarFn; +use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::PType; @@ -17,7 +24,10 @@ use crate::dtype::extension::ExtDType; use crate::dtype::extension::ExtId; use crate::dtype::extension::ExtVTable; use crate::extension::datetime::TimeUnit; +use crate::matcher::Matcher; use crate::scalar::ScalarValue; +use crate::scalar_fn::fns::cast::Cast; +use crate::scalar_fn::fns::operators::Operator; /// Time DType. #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] @@ -129,6 +139,55 @@ impl ExtVTable for Time { Some(DType::Extension(Time::new(finest, union_null).erased())) } + fn reduce_parent_array( + &self, + array: &ExtArray, + parent: &ArrayRef, + child_idx: usize, + ) -> VortexResult> { + let _ = child_idx; + let Some(cast_view) = ExactScalarFn::::try_match(parent.as_ref()) else { + return Ok(None); + }; + let target = cast_view.options; + + let DType::Extension(target_ext) = target else { + return Ok(None); + }; + let Some(target_unit) = target_ext.metadata_opt::