From 3b0b8e51c2bfd6cded2aa7005d766018c9cd987f Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Fri, 13 Mar 2026 16:18:35 -0400 Subject: [PATCH 1/3] Support casting for extension arrays Signed-off-by: Nicholas Gates --- vortex-array/src/arrays/extension/array.rs | 6 +++ .../src/arrays/extension/compute/cast.rs | 36 ++++++++-------- vortex-array/src/arrays/extension/mod.rs | 4 ++ vortex-array/src/arrays/extension/view.rs | 28 +++++++++++++ vortex-array/src/dtype/extension/erased.rs | 26 ++++++++++++ vortex-array/src/dtype/extension/typed.rs | 41 +++++++++++++++++++ vortex-array/src/dtype/extension/vtable.rs | 24 +++++++++++ 7 files changed, 145 insertions(+), 20 deletions(-) create mode 100644 vortex-array/src/arrays/extension/view.rs diff --git a/vortex-array/src/arrays/extension/array.rs b/vortex-array/src/arrays/extension/array.rs index 799646817a9..75f677b4f40 100644 --- a/vortex-array/src/arrays/extension/array.rs +++ b/vortex-array/src/arrays/extension/array.rs @@ -5,8 +5,10 @@ use vortex_error::VortexExpect; use vortex_error::VortexResult; use crate::ArrayRef; +use crate::arrays::extension::view::ExtArray; use crate::dtype::DType; use crate::dtype::extension::ExtDTypeRef; +use crate::dtype::extension::ExtVTable; use crate::stats::ArrayStats; /// An extension array that wraps another array with additional type information. @@ -125,4 +127,8 @@ impl ExtensionArray { pub fn storage_array(&self) -> &ArrayRef { &self.storage_array } + + pub fn downcast_ref(&self) -> Option> { + ExtArray::try_new(self) + } } diff --git a/vortex-array/src/arrays/extension/compute/cast.rs b/vortex-array/src/arrays/extension/compute/cast.rs index 1b02b479438..f666269f05a 100644 --- a/vortex-array/src/arrays/extension/compute/cast.rs +++ b/vortex-array/src/arrays/extension/compute/cast.rs @@ -11,28 +11,24 @@ use crate::scalar_fn::fns::cast::CastReduce; impl CastReduce for Extension { fn cast(array: &ExtensionArray, dtype: &DType) -> vortex_error::VortexResult> { - if !array.dtype().eq_ignore_nullability(dtype) { - return Ok(None); + // Try casting the internal storage of the extension array. + if array.dtype().eq_ignore_nullability(dtype) { + let DType::Extension(ext_dtype) = dtype else { + unreachable!("Already verified we have an extension dtype"); + }; + + if let Some(new_storage) = array + .storage_array() + .cast(ext_dtype.storage_dtype().clone())? + { + return Ok(Some( + ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(), + )); + }; } - let DType::Extension(ext_dtype) = dtype else { - unreachable!("Already verified we have an extension dtype"); - }; - - let new_storage = match array - .storage_array() - .cast(ext_dtype.storage_dtype().clone()) - { - Ok(arr) => arr, - Err(e) => { - tracing::warn!("Failed to cast storage array: {e}"); - return Ok(None); - } - }; - - Ok(Some( - ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(), - )) + // Otherwise we defer to the extension vtable. + array.ext_dtype().cast_from_ext(array, dtype) } } diff --git a/vortex-array/src/arrays/extension/mod.rs b/vortex-array/src/arrays/extension/mod.rs index 703871bf311..efb0e124f92 100644 --- a/vortex-array/src/arrays/extension/mod.rs +++ b/vortex-array/src/arrays/extension/mod.rs @@ -4,7 +4,11 @@ mod array; pub use array::ExtensionArray; +mod view; +pub use view::ExtArray; + pub(crate) mod compute; mod vtable; + pub use vtable::Extension; diff --git a/vortex-array/src/arrays/extension/view.rs b/vortex-array/src/arrays/extension/view.rs new file mode 100644 index 00000000000..efe5e6c4f4a --- /dev/null +++ b/vortex-array/src/arrays/extension/view.rs @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use crate::ArrayRef; +use crate::arrays::ExtensionArray; +use crate::dtype::extension::ExtDType; +use crate::dtype::extension::ExtVTable; + +/// A typed view of an extension array. +pub struct ExtArray<'a, V: ExtVTable> { + ext_dtype: &'a ExtDType, + array: &'a ExtensionArray, +} + +impl<'a, V: ExtVTable> ExtArray<'a, V> { + pub fn try_new(array: &'a ExtensionArray) -> Option { + let ext_dtype = array.ext_dtype().downcast_ref::()?; + Some(Self { ext_dtype, array }) + } + + pub fn ext_dtype(&self) -> &ExtDType { + self.ext_dtype + } + + pub fn storage_array(&self) -> &ArrayRef { + self.array.storage_array() + } +} diff --git a/vortex-array/src/dtype/extension/erased.rs b/vortex-array/src/dtype/extension/erased.rs index e624c5c7ed2..de77db9173a 100644 --- a/vortex-array/src/dtype/extension/erased.rs +++ b/vortex-array/src/dtype/extension/erased.rs @@ -9,10 +9,12 @@ use std::hash::Hash; use std::hash::Hasher; use std::sync::Arc; +use arrow_array::ArrayRef; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_err; +use crate::arrays::ExtensionArray; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::extension::ExtDType; @@ -114,6 +116,25 @@ impl ExtDTypeRef { pub fn least_supertype(&self, other: &DType) -> Option { self.0.coercion_least_supertype(other) } + + /// Attempt to cast the extension array to the target dtype. + pub fn cast_into_ext(&self, array: &ArrayRef) -> VortexResult> { + self.0.cast_into_ext(array, self) + } + + /// Attempt to cast the extension array to the target dtype. + pub fn cast_from_ext( + &self, + array: &ExtensionArray, + target: &DType, + ) -> VortexResult> { + if let Some(ext_dtype) = array.dtype().as_extension_opt() { + if ext_dtype != self { + return Ok(None); + } + }; + self.0.cast_from_ext(array, target) + } } /// Methods for downcasting type-erased extension dtypes. @@ -166,6 +187,11 @@ impl ExtDTypeRef { }) .vortex_expect("Failed to downcast ExtDTypeRef") } + + /// Downcast to the concrete [`ExtDType`]. + pub fn downcast_ref(&self) -> Option<&ExtDType> { + self.0.as_any().downcast_ref::>() + } } impl PartialEq for ExtDTypeRef { diff --git a/vortex-array/src/dtype/extension/typed.rs b/vortex-array/src/dtype/extension/typed.rs index baa3706b56b..27e0f275c04 100644 --- a/vortex-array/src/dtype/extension/typed.rs +++ b/vortex-array/src/dtype/extension/typed.rs @@ -15,7 +15,10 @@ use std::sync::Arc; use vortex_error::VortexExpect; use vortex_error::VortexResult; +use vortex_error::vortex_err; +use crate::ArrayRef; +use crate::arrays::ExtensionArray; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::extension::ExtDTypeRef; @@ -129,6 +132,18 @@ pub(super) trait DynExtDType: 'static + Send + Sync + super::sealed::Sealed { fn coercion_can_coerce_to(&self, other: &DType) -> bool; /// Compute the least supertype of this extension type and another type. fn coercion_least_supertype(&self, other: &DType) -> Option; + /// Attempt to cast an array into this extension dtype. + fn cast_into_ext( + &self, + array: &ArrayRef, + target: &ExtDTypeRef, + ) -> VortexResult>; + /// Attempt to cast this extension array into a target dtype. + fn cast_from_ext( + &self, + array: &ExtensionArray, + target: &DType, + ) -> VortexResult>; } impl DynExtDType for ExtDType { @@ -212,4 +227,30 @@ impl DynExtDType for ExtDType { fn coercion_least_supertype(&self, other: &DType) -> Option { self.vtable.least_supertype(self, other) } + + fn cast_into_ext( + &self, + array: &ArrayRef, + target: &ExtDTypeRef, + ) -> VortexResult> { + self.vtable.cast_into_ext( + array, + target + .downcast_ref() + .ok_or_else(|| vortex_err!("Target is not an extension dtype"))?, + ) + } + + fn cast_from_ext( + &self, + array: &ExtensionArray, + target: &DType, + ) -> VortexResult> { + self.vtable.cast_from_ext( + &array + .downcast_ref() + .ok_or_else(|| vortex_err!("Array is not an extension array of this type"))?, + target, + ) + } } diff --git a/vortex-array/src/dtype/extension/vtable.rs b/vortex-array/src/dtype/extension/vtable.rs index 120cfa02dc3..8c99190c94d 100644 --- a/vortex-array/src/dtype/extension/vtable.rs +++ b/vortex-array/src/dtype/extension/vtable.rs @@ -7,6 +7,8 @@ use std::hash::Hash; use vortex_error::VortexResult; +use crate::ArrayRef; +use crate::arrays::extension::ExtArray; use crate::dtype::DType; use crate::dtype::extension::ExtDType; use crate::dtype::extension::ExtId; @@ -96,4 +98,26 @@ pub trait ExtVTable: 'static + Sized + Send + Sync + Clone + Debug + Eq + Hash { ext_dtype: &'a ExtDType, storage_value: &'a ScalarValue, ) -> VortexResult>; + + /// Cast an array into this extension DType. + /// + /// Returns `None` if the cast is not possible. + fn cast_into_ext( + &self, + array: &ArrayRef, + target: &ExtDType, + ) -> VortexResult> { + let _ = (array, target); + Ok(None) + } + + /// Cast an array of this extension DType into another DType. + fn cast_from_ext( + &self, + array: &ExtArray, + target: &DType, + ) -> VortexResult> { + let _ = (array, target); + Ok(None) + } } From e2223ca17ae6407ded42933ccca12570c212e727 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Fri, 13 Mar 2026 21:32:34 -0400 Subject: [PATCH 2/3] Extension Casting Signed-off-by: Nicholas Gates --- .../src/arrays/extension/compute/cast.rs | 84 ++++++++++++++----- vortex-array/src/dtype/coercion.rs | 7 ++ vortex-array/src/dtype/extension/erased.rs | 12 +-- vortex-array/src/dtype/extension/vtable.rs | 8 ++ vortex-array/src/extension/datetime/date.rs | 77 +++++++++++++++++ vortex-array/src/extension/datetime/time.rs | 76 +++++++++++++++++ .../src/extension/datetime/timestamp.rs | 49 +++++++++++ vortex-array/src/extension/datetime/unit.rs | 11 +++ vortex-array/src/scalar_fn/fns/cast/mod.rs | 7 ++ 9 files changed, 306 insertions(+), 25 deletions(-) diff --git a/vortex-array/src/arrays/extension/compute/cast.rs b/vortex-array/src/arrays/extension/compute/cast.rs index f666269f05a..2296c1c17bc 100644 --- a/vortex-array/src/arrays/extension/compute/cast.rs +++ b/vortex-array/src/arrays/extension/compute/cast.rs @@ -11,20 +11,19 @@ use crate::scalar_fn::fns::cast::CastReduce; impl CastReduce for Extension { fn cast(array: &ExtensionArray, dtype: &DType) -> vortex_error::VortexResult> { - // Try casting the internal storage of the extension array. + // Fast path: same extension type (ignoring nullability), just cast the storage. if array.dtype().eq_ignore_nullability(dtype) { let DType::Extension(ext_dtype) = dtype else { unreachable!("Already verified we have an extension dtype"); }; - if let Some(new_storage) = array + let new_storage = array .storage_array() - .cast(ext_dtype.storage_dtype().clone())? - { - return Ok(Some( - ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(), - )); - }; + .cast(ext_dtype.storage_dtype().clone())?; + + return Ok(Some( + ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(), + )); } // Otherwise we defer to the extension vtable. @@ -82,21 +81,68 @@ mod tests { } #[test] - fn cast_different_ext_dtype() { - let original_dtype = + fn cast_timestamp_ms_to_ns() { + let source_dtype = Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased(); - // Note NS here instead of MS let target_dtype = Timestamp::new(TimeUnit::Nanoseconds, Nullability::NonNullable).erased(); + let storage = buffer![1i64, 2, 3].into_array(); + let arr = ExtensionArray::new(source_dtype, storage).into_array(); + + let result = arr.cast(DType::Extension(target_dtype.clone())).unwrap(); + assert_eq!(result.dtype(), &DType::Extension(target_dtype)); + + // Verify values were scaled: ms → ns is ×1_000_000 + let ext = result.to_canonical().unwrap().as_extension().clone(); + let prim = ext + .storage_array() + .to_canonical() + .unwrap() + .as_primitive() + .clone(); + assert_eq!(prim.as_slice::(), &[1_000_000, 2_000_000, 3_000_000]); + } + + #[test] + fn cast_timestamp_s_to_us() { + let source_dtype = Timestamp::new(TimeUnit::Seconds, Nullability::NonNullable).erased(); + let target_dtype = + Timestamp::new(TimeUnit::Microseconds, Nullability::NonNullable).erased(); + + let storage = buffer![10i64, 20].into_array(); + let arr = ExtensionArray::new(source_dtype, storage).into_array(); + + let result = arr.cast(DType::Extension(target_dtype)).unwrap(); + let ext = result.to_canonical().unwrap().as_extension().clone(); + let prim = ext + .storage_array() + .to_canonical() + .unwrap() + .as_primitive() + .clone(); + assert_eq!(prim.as_slice::(), &[10_000_000, 20_000_000]); + } + + #[test] + fn cast_timestamp_tz_mismatch_fails() { + use std::sync::Arc; + + let utc_dtype = Timestamp::new_with_tz( + TimeUnit::Seconds, + Some(Arc::from("UTC")), + Nullability::NonNullable, + ) + .erased(); + let no_tz_dtype = Timestamp::new(TimeUnit::Nanoseconds, Nullability::NonNullable).erased(); + let storage = buffer![1i64].into_array(); - let arr = ExtensionArray::new(original_dtype, storage); - - assert!( - arr.into_array() - .cast(DType::Extension(target_dtype)) - .and_then(|a| a.to_canonical().map(|c| c.into_array())) - .is_err() - ); + let arr = ExtensionArray::new(utc_dtype, storage).into_array(); + + // Timezone mismatch: cast creates a lazy expression, error surfaces on evaluation. + let result = arr + .cast(DType::Extension(no_tz_dtype)) + .and_then(|a| a.to_canonical().map(|c| c.into_array())); + assert!(result.is_err()); } #[rstest] diff --git a/vortex-array/src/dtype/coercion.rs b/vortex-array/src/dtype/coercion.rs index 528d9b673b2..fee93900a0e 100644 --- a/vortex-array/src/dtype/coercion.rs +++ b/vortex-array/src/dtype/coercion.rs @@ -205,6 +205,13 @@ impl DType { /// Convenience — is there a path from `self` to `other`? pub fn can_coerce_to(&self, other: &DType) -> bool { + if let DType::Extension(ext) = self { + // Extension types can define coercions in either direction, so check both. + if ext.can_coerce_to(other) { + return true; + } + }; + other.can_coerce_from(self) } diff --git a/vortex-array/src/dtype/extension/erased.rs b/vortex-array/src/dtype/extension/erased.rs index de77db9173a..688344d965b 100644 --- a/vortex-array/src/dtype/extension/erased.rs +++ b/vortex-array/src/dtype/extension/erased.rs @@ -9,11 +9,11 @@ use std::hash::Hash; use std::hash::Hasher; use std::sync::Arc; -use arrow_array::ArrayRef; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_err; +use crate::ArrayRef; use crate::arrays::ExtensionArray; use crate::dtype::DType; use crate::dtype::Nullability; @@ -128,11 +128,11 @@ impl ExtDTypeRef { array: &ExtensionArray, target: &DType, ) -> VortexResult> { - if let Some(ext_dtype) = array.dtype().as_extension_opt() { - if ext_dtype != self { - return Ok(None); - } - }; + if let Some(ext_dtype) = array.dtype().as_extension_opt() + && ext_dtype != self + { + return Ok(None); + } self.0.cast_from_ext(array, target) } } diff --git a/vortex-array/src/dtype/extension/vtable.rs b/vortex-array/src/dtype/extension/vtable.rs index 8c99190c94d..736eb1a5fc0 100644 --- a/vortex-array/src/dtype/extension/vtable.rs +++ b/vortex-array/src/dtype/extension/vtable.rs @@ -101,6 +101,9 @@ pub trait ExtVTable: 'static + Sized + Send + Sync + Clone + Debug + Eq + Hash { /// Cast an array into this extension DType. /// + /// Note that this function does not take an execution context. It is expected that the + /// implementation can be expressed via operations on the storage array. + /// /// Returns `None` if the cast is not possible. fn cast_into_ext( &self, @@ -112,6 +115,11 @@ pub trait ExtVTable: 'static + Sized + Send + Sync + Clone + Debug + Eq + Hash { } /// Cast an array of this extension DType into another DType. + /// + /// Note that this function does not take an execution context. It is expected that the + /// implementation can be expressed via operations on the storage array. + /// + /// Returns `None` if the cast is not possible. fn cast_from_ext( &self, array: &ExtArray, diff --git a/vortex-array/src/extension/datetime/date.rs b/vortex-array/src/extension/datetime/date.rs index 57fe7ad87e0..2b0a139ea66 100644 --- a/vortex-array/src/extension/datetime/date.rs +++ b/vortex-array/src/extension/datetime/date.rs @@ -10,6 +10,12 @@ use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_error::vortex_err; +use crate::ArrayRef; +use crate::IntoArray; +use crate::arrays::ConstantArray; +use crate::arrays::ExtensionArray; +use crate::arrays::extension::ExtArray; +use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::PType; @@ -18,6 +24,7 @@ use crate::dtype::extension::ExtId; use crate::dtype::extension::ExtVTable; use crate::extension::datetime::TimeUnit; use crate::scalar::ScalarValue; +use crate::scalar_fn::fns::operators::Operator; /// The Unix epoch date (1970-01-01). const EPOCH: jiff::civil::Date = jiff::civil::Date::constant(1970, 1, 1); @@ -129,6 +136,48 @@ impl ExtVTable for Date { Some(DType::Extension(Date::new(finest, union_null).erased())) } + fn cast_from_ext( + &self, + array: &ExtArray, + target: &DType, + ) -> VortexResult> { + let DType::Extension(target_ext) = target else { + return Ok(None); + }; + let Some(target_unit) = target_ext.metadata_opt::() else { + return Ok(None); + }; + let source_unit = array.ext_dtype().metadata(); + + let source_nanos = source_unit.nanos_per_unit(); + let target_nanos = target_unit.nanos_per_unit(); + + // Cast storage to target ptype first (e.g. i32 → i64 for Days → Ms). + let storage = array + .storage_array() + .cast(target_ext.storage_dtype().clone())?; + + let storage = if source_nanos == target_nanos { + storage + } else if source_nanos > target_nanos { + let factor = source_nanos / target_nanos; + storage.binary( + ConstantArray::new(factor, storage.len()).into_array(), + Operator::Mul, + )? + } else { + let factor = target_nanos / source_nanos; + storage.binary( + ConstantArray::new(factor, storage.len()).into_array(), + Operator::Div, + )? + }; + + Ok(Some( + ExtensionArray::new(target_ext.clone(), storage).into_array(), + )) + } + fn unpack_native( &self, ext_dtype: &ExtDType, @@ -210,4 +259,32 @@ mod tests { assert!(ms.can_coerce_from(&days)); assert!(!days.can_coerce_from(&ms)); } + + #[test] + fn cast_date_days_to_ms() { + use vortex_buffer::buffer; + + use crate::IntoArray; + use crate::arrays::ExtensionArray; + use crate::builtins::ArrayBuiltins; + use crate::dtype::Nullability::NonNullable; + + let source_dtype = Date::new(TimeUnit::Days, NonNullable).erased(); + let target_dtype = Date::new(TimeUnit::Milliseconds, NonNullable).erased(); + + // 1 day and 2 days since epoch + let storage = buffer![1i32, 2].into_array(); + let arr = ExtensionArray::new(source_dtype, storage).into_array(); + + let result = arr.cast(DType::Extension(target_dtype)).unwrap(); + let ext = result.to_canonical().unwrap().as_extension().clone(); + let prim = ext + .storage_array() + .to_canonical() + .unwrap() + .as_primitive() + .clone(); + // Days → Ms: ×86_400_000 + assert_eq!(prim.as_slice::(), &[86_400_000i64, 172_800_000]); + } } diff --git a/vortex-array/src/extension/datetime/time.rs b/vortex-array/src/extension/datetime/time.rs index eb53a0e0fe3..0e7cc44dd8b 100644 --- a/vortex-array/src/extension/datetime/time.rs +++ b/vortex-array/src/extension/datetime/time.rs @@ -10,6 +10,12 @@ use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_error::vortex_err; +use crate::ArrayRef; +use crate::IntoArray; +use crate::arrays::ConstantArray; +use crate::arrays::ExtensionArray; +use crate::arrays::extension::ExtArray; +use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::PType; @@ -18,6 +24,7 @@ use crate::dtype::extension::ExtId; use crate::dtype::extension::ExtVTable; use crate::extension::datetime::TimeUnit; use crate::scalar::ScalarValue; +use crate::scalar_fn::fns::operators::Operator; /// Time DType. #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] @@ -129,6 +136,48 @@ impl ExtVTable for Time { Some(DType::Extension(Time::new(finest, union_null).erased())) } + fn cast_from_ext( + &self, + array: &ExtArray, + target: &DType, + ) -> VortexResult> { + let DType::Extension(target_ext) = target else { + return Ok(None); + }; + let Some(target_unit) = target_ext.metadata_opt::