Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions vortex-array/public-api.lock

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions vortex-array/src/arrays/extension/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ use vortex_error::VortexExpect;
use vortex_error::VortexResult;

use crate::ArrayRef;
use crate::arrays::extension::view::ExtArray;
use crate::dtype::DType;
use crate::dtype::extension::ExtDTypeRef;
use crate::dtype::extension::ExtVTable;
use crate::stats::ArrayStats;

/// An extension array that wraps another array with additional type information.
Expand Down Expand Up @@ -125,4 +127,8 @@ impl ExtensionArray {
pub fn storage_array(&self) -> &ArrayRef {
&self.storage_array
}

/// Returns a typed view of this array for the extension vtable `V`.
///
/// This is a thin wrapper around [`ExtArray::try_new`], so it yields `None` whenever
/// that constructor does (i.e. when this array's extension dtype cannot be downcast
/// to the concrete dtype for `V`).
pub fn downcast_ref<V: ExtVTable>(&self) -> Option<ExtArray<'_, V>> {
ExtArray::try_new(self)
}
}
105 changes: 74 additions & 31 deletions vortex-array/src/arrays/extension/compute/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,24 @@ use crate::scalar_fn::fns::cast::CastReduce;

impl CastReduce for Extension {
fn cast(array: &ExtensionArray, dtype: &DType) -> vortex_error::VortexResult<Option<ArrayRef>> {
if !array.dtype().eq_ignore_nullability(dtype) {
return Ok(None);
// Same extension type (ignoring nullability): just cast the storage nullability.
if array.dtype().eq_ignore_nullability(dtype) {
let DType::Extension(ext_dtype) = dtype else {
unreachable!("Already verified we have an extension dtype");
};

let new_storage = array
.storage_array()
.cast(ext_dtype.storage_dtype().clone())?;

return Ok(Some(
ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(),
));
}

let DType::Extension(ext_dtype) = dtype else {
unreachable!("Already verified we have an extension dtype");
};

let new_storage = match array
.storage_array()
.cast(ext_dtype.storage_dtype().clone())
{
Ok(arr) => arr,
Err(e) => {
tracing::warn!("Failed to cast storage array: {e}");
return Ok(None);
}
};

Ok(Some(
ExtensionArray::new(ext_dtype.clone(), new_storage).into_array(),
))
// Type-specific casting (e.g. Timestamp(s) → Timestamp(ns)) is handled by
// ExtVTable::reduce_parent_array, which runs before this CastReduce fallback.
Ok(None)
}
}

Expand Down Expand Up @@ -86,21 +82,68 @@ mod tests {
}

#[test]
fn cast_different_ext_dtype() {
let original_dtype =
fn cast_timestamp_ms_to_ns() {
let source_dtype =
Timestamp::new(TimeUnit::Milliseconds, Nullability::NonNullable).erased();
// Note NS here instead of MS
let target_dtype = Timestamp::new(TimeUnit::Nanoseconds, Nullability::NonNullable).erased();

let storage = buffer![1i64, 2, 3].into_array();
let arr = ExtensionArray::new(source_dtype, storage).into_array();

let result = arr.cast(DType::Extension(target_dtype.clone())).unwrap();
assert_eq!(result.dtype(), &DType::Extension(target_dtype));

// Verify values were scaled: ms → ns is ×1_000_000
let ext = result.to_canonical().unwrap().as_extension().clone();
let prim = ext
.storage_array()
.to_canonical()
.unwrap()
.as_primitive()
.clone();
assert_eq!(prim.as_slice::<i64>(), &[1_000_000, 2_000_000, 3_000_000]);
}

#[test]
fn cast_timestamp_s_to_us() {
    // Source values are whole seconds; the target type has microsecond precision.
    let seconds = Timestamp::new(TimeUnit::Seconds, Nullability::NonNullable).erased();
    let micros = Timestamp::new(TimeUnit::Microseconds, Nullability::NonNullable).erased();

    let input = ExtensionArray::new(seconds, buffer![10i64, 20].into_array()).into_array();
    let casted = input.cast(DType::Extension(micros)).unwrap();

    // Canonicalize the result, then its storage, to inspect the raw i64 values.
    let canonical_ext = casted.to_canonical().unwrap().as_extension().clone();
    let canonical_storage = canonical_ext
        .storage_array()
        .to_canonical()
        .unwrap()
        .as_primitive()
        .clone();

    // s → µs is ×1_000_000
    assert_eq!(
        canonical_storage.as_slice::<i64>(),
        &[10_000_000, 20_000_000]
    );
}

#[test]
fn cast_timestamp_tz_mismatch_fails() {
use std::sync::Arc;

let utc_dtype = Timestamp::new_with_tz(
TimeUnit::Seconds,
Some(Arc::from("UTC")),
Nullability::NonNullable,
)
.erased();
let no_tz_dtype = Timestamp::new(TimeUnit::Nanoseconds, Nullability::NonNullable).erased();

let storage = buffer![1i64].into_array();
let arr = ExtensionArray::new(original_dtype, storage);

assert!(
arr.into_array()
.cast(DType::Extension(target_dtype))
.and_then(|a| a.to_canonical().map(|c| c.into_array()))
.is_err()
);
let arr = ExtensionArray::new(utc_dtype, storage).into_array();

// Timezone mismatch: cast creates a lazy expression, error surfaces on evaluation.
let result = arr
.cast(DType::Extension(no_tz_dtype))
.and_then(|a| a.to_canonical().map(|c| c.into_array()));
assert!(result.is_err());
}

#[rstest]
Expand Down
5 changes: 5 additions & 0 deletions vortex-array/src/arrays/extension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
mod array;
pub use array::ExtensionArray;

mod view;
pub use view::ExactExtArray;
pub use view::ExtArray;

pub(crate) mod compute;

mod vtable;

pub use vtable::Extension;
57 changes: 57 additions & 0 deletions vortex-array/src/arrays/extension/view.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use std::marker::PhantomData;

use crate::ArrayRef;
use crate::DynArray;
use crate::arrays::Extension;
use crate::arrays::ExtensionArray;
use crate::dtype::extension::ExtDType;
use crate::dtype::extension::ExtVTable;
use crate::matcher::Matcher;

/// A typed view of an extension array.
///
/// Pairs a borrowed [`ExtensionArray`] with its extension dtype already downcast to the
/// concrete [`ExtDType<V>`], so callers can work with the typed dtype without repeating
/// the downcast. The view only borrows; it owns no data.
pub struct ExtArray<'a, V: ExtVTable> {
// The array's extension dtype, downcast to the concrete type for `V`.
ext_dtype: &'a ExtDType<V>,
// The underlying (type-erased) extension array this view was created from.
array: &'a ExtensionArray,
}

impl<'a, V: ExtVTable> ExtArray<'a, V> {
    /// Creates a typed view over `array`.
    ///
    /// Returns `None` when the array's extension dtype cannot be downcast to
    /// [`ExtDType<V>`].
    pub fn try_new(array: &'a ExtensionArray) -> Option<Self> {
        let ext_dtype = array.ext_dtype().downcast_ref::<V>()?;
        Some(Self { ext_dtype, array })
    }

    /// Returns the concrete extension dtype of the viewed array.
    ///
    /// The reference is valid for the full `'a` borrow of the underlying array, not
    /// just for the lifetime of this (usually temporary) view.
    pub fn ext_dtype(&self) -> &'a ExtDType<V> {
        self.ext_dtype
    }

    /// Returns the storage array of the viewed extension array.
    ///
    /// As with [`Self::ext_dtype`], the reference borrows from the underlying array
    /// for `'a` rather than from the view itself.
    pub fn storage_array(&self) -> &'a ArrayRef {
        self.array.storage_array()
    }
}

/// A matcher that matches an [`ExtensionArray`] with a specific [`ExtVTable`] type.
///
/// Similar to [`ExactScalarFn`](crate::arrays::scalar_fn::ExactScalarFn) for scalar functions,
/// this provides typed access to the extension array view ([`ExtArray<V>`]).
#[derive(Debug, Default)]
pub struct ExactExtArray<V: ExtVTable>(PhantomData<V>);

impl<V: ExtVTable> Matcher for ExactExtArray<V> {
    type Match<'a> = ExtArray<'a, V>;

    /// Returns `true` exactly when [`Self::try_match`] would produce a view, i.e. the
    /// array is an [`Extension`] array whose dtype downcasts to `V`.
    fn matches(array: &dyn DynArray) -> bool {
        Self::try_match(array).is_some()
    }

    /// Returns a typed [`ExtArray`] view if `array` is an [`Extension`] array whose
    /// extension dtype downcasts to `V`, and `None` otherwise.
    fn try_match(array: &dyn DynArray) -> Option<Self::Match<'_>> {
        array
            .as_opt::<Extension>()
            .and_then(|ext_array| ext_array.downcast_ref::<V>())
    }
}
19 changes: 14 additions & 5 deletions vortex-array/src/arrays/extension/vtable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,21 +157,30 @@ impl VTable for Extension {
Ok(ExecutionStep::Done(array.clone().into_array()))
}

fn reduce_parent(
fn execute_parent(
array: &Self::Array,
parent: &ArrayRef,
child_idx: usize,
ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
PARENT_RULES.evaluate(array, parent, child_idx)
if let Some(result) = array
.ext_dtype()
.execute_parent(array, parent, child_idx, ctx)?
{
return Ok(Some(result));
}
PARENT_KERNELS.execute(array, parent, child_idx, ctx)
}

fn execute_parent(
fn reduce_parent(
array: &Self::Array,
parent: &ArrayRef,
child_idx: usize,
ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
PARENT_KERNELS.execute(array, parent, child_idx, ctx)
if let Some(result) = array.ext_dtype().reduce_parent(array, parent, child_idx)? {
return Ok(Some(result));
}
PARENT_RULES.evaluate(array, parent, child_idx)
}
}

Expand Down
7 changes: 7 additions & 0 deletions vortex-array/src/dtype/coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,13 @@ impl DType {

/// Convenience — is there a path from `self` to `other`?
///
/// When `self` is an extension type, the extension's own `can_coerce_to` is consulted
/// first, since extension types can define coercions in either direction. If that does
/// not apply (or says no), the answer is whatever `other.can_coerce_from(self)` reports.
pub fn can_coerce_to(&self, other: &DType) -> bool {
    let extension_allows = matches!(self, DType::Extension(ext) if ext.can_coerce_to(other));

    extension_allows || other.can_coerce_from(self)
}

Expand Down
27 changes: 27 additions & 0 deletions vortex-array/src/dtype/extension/erased.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_err;

use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::arrays::ExtensionArray;
use crate::dtype::DType;
use crate::dtype::Nullability;
use crate::dtype::extension::ExtDType;
Expand Down Expand Up @@ -114,6 +117,25 @@ impl ExtDTypeRef {
pub fn least_supertype(&self, other: &DType) -> Option<DType> {
self.0.coercion_least_supertype(other)
}

/// Forwards a parent-execution request to the underlying extension dtype.
///
/// Delegates to the erased dtype's `execute_parent_array` with the same arguments and
/// returns its result unchanged.
pub(crate) fn execute_parent(
&self,
array: &ExtensionArray,
parent: &ArrayRef,
child_idx: usize,
ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
self.0.execute_parent_array(array, parent, child_idx, ctx)
}

/// Forwards a parent-reduction request to the underlying extension dtype.
///
/// Delegates to the erased dtype's `reduce_parent_array` with the same arguments and
/// returns its result unchanged.
pub(crate) fn reduce_parent(
&self,
array: &ExtensionArray,
parent: &ArrayRef,
child_idx: usize,
) -> VortexResult<Option<ArrayRef>> {
self.0.reduce_parent_array(array, parent, child_idx)
}
}

/// Methods for downcasting type-erased extension dtypes.
Expand Down Expand Up @@ -166,6 +188,11 @@ impl ExtDTypeRef {
})
.vortex_expect("Failed to downcast ExtDTypeRef")
}

/// Downcast to the concrete [`ExtDType`].
///
/// Returns `None` if the erased dtype is not an `ExtDType<V>`; unlike the
/// `vortex_expect`-based downcast above, this variant does not panic on a mismatch.
pub fn downcast_ref<V: ExtVTable>(&self) -> Option<&ExtDType<V>> {
self.0.as_any().downcast_ref::<ExtDType<V>>()
}
}

impl PartialEq for ExtDTypeRef {
Expand Down
49 changes: 49 additions & 0 deletions vortex-array/src/dtype/extension/typed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ use std::sync::Arc;

use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_err;

use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::arrays::ExtensionArray;
use crate::dtype::DType;
use crate::dtype::Nullability;
use crate::dtype::extension::ExtDTypeRef;
Expand Down Expand Up @@ -129,6 +133,19 @@ pub(super) trait DynExtDType: 'static + Send + Sync + super::sealed::Sealed {
fn coercion_can_coerce_to(&self, other: &DType) -> bool;
/// Compute the least supertype of this extension type and another type.
fn coercion_least_supertype(&self, other: &DType) -> Option<DType>;
/// Extension-type hook for executing `parent`, of which `array` is the child at
/// `child_idx`.
///
/// Returns `Ok(None)` when the extension type has nothing to contribute.
fn execute_parent_array(
&self,
array: &ExtensionArray,
parent: &ArrayRef,
child_idx: usize,
ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>>;
/// Extension-type hook for reducing `parent`, of which `array` is the child at
/// `child_idx`.
///
/// Returns `Ok(None)` when the extension type has nothing to contribute.
fn reduce_parent_array(
&self,
array: &ExtensionArray,
parent: &ArrayRef,
child_idx: usize,
) -> VortexResult<Option<ArrayRef>>;
}

impl<V: ExtVTable> DynExtDType for ExtDType<V> {
Expand Down Expand Up @@ -212,4 +229,36 @@ impl<V: ExtVTable> DynExtDType for ExtDType<V> {
fn coercion_least_supertype(&self, other: &DType) -> Option<DType> {
self.vtable.least_supertype(self, other)
}

fn execute_parent_array(
&self,
array: &ExtensionArray,
parent: &ArrayRef,
child_idx: usize,
ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
self.vtable.execute_parent_array(
&array
.downcast_ref()
.ok_or_else(|| vortex_err!("Array is not an extension array of this type"))?,
parent,
child_idx,
ctx,
)
}

fn reduce_parent_array(
&self,
array: &ExtensionArray,
parent: &ArrayRef,
child_idx: usize,
) -> VortexResult<Option<ArrayRef>> {
self.vtable.reduce_parent_array(
&array
.downcast_ref()
.ok_or_else(|| vortex_err!("Array is not an extension array of this type"))?,
parent,
child_idx,
)
}
}
Loading
Loading