From 9115a8f1856d019faf82f9ee567fb5f1b033683a Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Fri, 13 Mar 2026 14:35:50 -0400 Subject: [PATCH 1/8] Add type coercion to scalar functions and extension types Signed-off-by: Nicholas Gates --- vortex-array/src/dtype/coercion.rs | 452 +++++++++++++++++++ vortex-array/src/dtype/extension/erased.rs | 15 + vortex-array/src/dtype/extension/typed.rs | 18 + vortex-array/src/dtype/extension/vtable.rs | 21 + vortex-array/src/dtype/mod.rs | 1 + vortex-array/src/dtype/ptype.rs | 95 ++++ vortex-array/src/expr/transform/mod.rs | 2 + vortex-array/src/scalar_fn/erased.rs | 5 + vortex-array/src/scalar_fn/fns/binary/mod.rs | 12 + vortex-array/src/scalar_fn/typed.rs | 5 + vortex-array/src/scalar_fn/vtable.rs | 17 +- 11 files changed, 641 insertions(+), 2 deletions(-) create mode 100644 vortex-array/src/dtype/coercion.rs diff --git a/vortex-array/src/dtype/coercion.rs b/vortex-array/src/dtype/coercion.rs new file mode 100644 index 00000000000..b83d42f2177 --- /dev/null +++ b/vortex-array/src/dtype/coercion.rs @@ -0,0 +1,452 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Utilities for performing type coercion. + +use vortex_error::VortexResult; +use vortex_error::vortex_bail; + +use crate::dtype::DType; +use crate::dtype::PType; +use crate::dtype::decimal::DecimalDType; + +impl DType { + /// The core primitive — what type can hold both `self` and `other`? + pub fn least_supertype(&self, other: &DType) -> VortexResult { + // 1. Identity (ignoring nullability): return self with union nullability + if self.eq_ignore_nullability(other) { + return Ok(self.with_nullability(self.nullability() | other.nullability())); + } + + let union_null = self.nullability() | other.nullability(); + + // 2. Null + X: return X as nullable + if matches!(self, DType::Null) { + return Ok(other.as_nullable()); + } + if matches!(other, DType::Null) { + return Ok(self.as_nullable()); + } + + // 3. Bool + numeric: return the numeric type (with union nullability) + if self.is_boolean() && other.is_numeric() { + return Ok(other.with_nullability(union_null)); + } + if other.is_boolean() && self.is_numeric() { + return Ok(self.with_nullability(union_null)); + } + + // 4. Primitive + Primitive (different ptypes): delegate to PType::least_supertype + if let (DType::Primitive(lhs_p, _), DType::Primitive(rhs_p, _)) = (self, other) { + return match lhs_p.least_supertype(*rhs_p) { + Some(p) => Ok(DType::Primitive(p, union_null)), + None => vortex_bail!( + "No common supertype for primitive types {} and {}", + lhs_p, + rhs_p + ), + }; + } + + // 5. Decimal + Decimal: compute wider decimal + if let (DType::Decimal(lhs_d, _), DType::Decimal(rhs_d, _)) = (self, other) { + let d = decimal_least_supertype(*lhs_d, *rhs_d)?; + return Ok(DType::Decimal(d, union_null)); + } + + // 6. Decimal + integer Primitive: convert integer to Decimal, then widen + if let (DType::Decimal(dec, _), DType::Primitive(p, _)) = (self, other) + && p.is_int() + { + let int_dec = DecimalDType::new(integer_decimal_precision(*p), 0); + let d = decimal_least_supertype(*dec, int_dec)?; + return Ok(DType::Decimal(d, union_null)); + } + if let (DType::Primitive(p, _), DType::Decimal(dec, _)) = (self, other) + && p.is_int() + { + let int_dec = DecimalDType::new(integer_decimal_precision(*p), 0); + let d = decimal_least_supertype(int_dec, *dec)?; + return Ok(DType::Decimal(d, union_null)); + } + + // 7. Extension + anything: delegate to vtable + if let DType::Extension(ext) = self { + return match ext.least_supertype(other) { + Some(dt) => Ok(dt.with_nullability(union_null)), + None => vortex_bail!( + "No common supertype for extension type {} and {}", + ext.id(), + other + ), + }; + } + if let DType::Extension(ext) = other { + return match ext.least_supertype(self) { + Some(dt) => Ok(dt.with_nullability(union_null)), + None => vortex_bail!( + "No common supertype for {} and extension type {}", + self, + ext.id() + ), + }; + } + + // 8. Everything else: error + vortex_bail!("No common supertype for {} and {}", self, other) + } + + /// Fold over a slice — what type can hold all of these? + pub fn least_supertype_of(types: &[DType]) -> VortexResult { + types + .iter() + .try_fold(types[0].clone(), |acc, t| acc.least_supertype(t)) + } + + /// Is there any implicit coercion path from `other` to `self`? + pub fn can_coerce_from(&self, other: &DType) -> bool { + // 1. Same type (ignoring nullability): check nullability compatibility + if self.eq_ignore_nullability(other) { + return self.is_nullable() || !other.is_nullable(); + } + + // 2. Null → nullable target + if matches!(other, DType::Null) { + return self.is_nullable(); + } + + // 3. Bool → numeric + if other.is_boolean() && self.is_numeric() { + return self.is_nullable() || !other.is_nullable(); + } + + // 4. Primitive widening: true if least_supertype(source, target) == target + if let (DType::Primitive(..), DType::Primitive(..)) = (self, other) { + return other + .least_supertype(self) + .is_ok_and(|st| st.eq_ignore_nullability(self)) + && (self.is_nullable() || !other.is_nullable()); + } + + // 5. Decimal widening + if let (DType::Decimal(target, _), DType::Decimal(source, _)) = (self, other) { + let target_integral = target.precision() as i16 - target.scale() as i16; + let source_integral = source.precision() as i16 - source.scale() as i16; + return target_integral >= source_integral + && target.scale() >= source.scale() + && (self.is_nullable() || !other.is_nullable()); + } + + // 6. Integer → Decimal + if let (DType::Decimal(dec, _), DType::Primitive(p, _)) = (self, other) + && p.is_int() + { + let needed = integer_decimal_precision(*p); + let integral_digits = dec.precision() as i16 - dec.scale() as i16; + return integral_digits >= needed as i16 + && (self.is_nullable() || !other.is_nullable()); + } + + // 7. Extension: delegate to vtable + if let DType::Extension(ext) = self { + return ext.can_coerce_from(other); + } + + // 8. Everything else: false + false + } + + /// Convenience — is there a path from `self` to `other`? + pub fn can_coerce_to(&self, other: &DType) -> bool { + other.can_coerce_from(self) + } + + /// Are all types in the slice mutually coercible to a common type? + pub fn are_coercible(types: &[DType]) -> bool { + DType::least_supertype_of(types).is_ok() + } + + /// Can all types in the slice be coerced to a specific target? + pub fn all_coercible_to(types: &[DType], target: &DType) -> bool { + types.iter().all(|t| target.can_coerce_from(t)) + } + + /// Coerce a slice to a specific target — returns the vec of targets + /// if all are coercible, error if any are not. + pub fn coerce_all_to(types: &[DType], target: &DType) -> VortexResult> { + types + .iter() + .enumerate() + .map(|(i, t)| { + if target.can_coerce_from(t) { + Ok(target.clone()) + } else { + vortex_bail!("Cannot coerce {} to {} in position {}", t, target, i) + } + }) + .collect() + } + + /// Coerce a slice to their mutual least supertype. + pub fn coerce_to_supertype(types: &[DType]) -> VortexResult> { + let supertype = DType::least_supertype_of(types)?; + Ok(vec![supertype; types.len()]) + } + + /// Is this a numeric type (primitive int/float or decimal)? + pub fn is_numeric(&self) -> bool { + matches!(self, DType::Primitive(..) | DType::Decimal(..)) + } + + /// Is this a temporal type (date, time, timestamp, duration)? + pub fn is_temporal(&self) -> bool { + match self { + DType::Extension(ext) => { + use crate::dtype::extension::Matcher; + use crate::extension::datetime::AnyTemporal; + AnyTemporal::matches(ext) + } + _ => false, + } + } +} + +/// Maps integer PType widths to the minimum decimal precision needed. +fn integer_decimal_precision(ptype: PType) -> u8 { + match ptype { + PType::U8 | PType::I8 => 3, + PType::U16 | PType::I16 => 5, + PType::U32 | PType::I32 => 10, + PType::U64 | PType::I64 => 19, + _ => 19, + } +} + +/// Compute the least supertype of two decimal types using SQL-standard rules. +/// +/// Result precision = max(integral digits) + max(scale), capped at MAX_PRECISION. +fn decimal_least_supertype(a: DecimalDType, b: DecimalDType) -> VortexResult { + let a_integral = a.precision() as i16 - a.scale() as i16; + let b_integral = b.precision() as i16 - b.scale() as i16; + let max_integral = a_integral.max(b_integral); + let max_scale = a.scale().max(b.scale()); + let precision = u8::try_from(max_integral + max_scale as i16).map_err(|_| { + vortex_error::vortex_err!( + "Decimal supertype precision overflow for ({}, {}) and ({}, {})", + a.precision(), + a.scale(), + b.precision(), + b.scale() + ) + })?; + DecimalDType::try_new(precision, max_scale) +} + +#[cfg(test)] +mod tests { + use vortex_error::VortexResult; + + use crate::dtype::DType; + use crate::dtype::PType; + use crate::dtype::decimal::DecimalDType; + use crate::dtype::nullability::Nullability::NonNullable; + use crate::dtype::nullability::Nullability::Nullable; + + #[test] + fn is_numeric() { + assert!(DType::Primitive(PType::I32, NonNullable).is_numeric()); + assert!(DType::Primitive(PType::F64, NonNullable).is_numeric()); + assert!(DType::Decimal(DecimalDType::new(10, 2), NonNullable).is_numeric()); + assert!(!DType::Bool(NonNullable).is_numeric()); + assert!(!DType::Utf8(NonNullable).is_numeric()); + assert!(!DType::Null.is_numeric()); + } + + #[test] + fn least_supertype_identity() -> VortexResult<()> { + let i32_nn = DType::Primitive(PType::I32, NonNullable); + assert_eq!(i32_nn.least_supertype(&i32_nn)?, i32_nn); + Ok(()) + } + + #[test] + fn least_supertype_nullability_union() -> VortexResult<()> { + let i32_nn = DType::Primitive(PType::I32, NonNullable); + let i32_n = DType::Primitive(PType::I32, Nullable); + assert_eq!(i32_nn.least_supertype(&i32_n)?, i32_n); + assert_eq!(i32_n.least_supertype(&i32_nn)?, i32_n); + Ok(()) + } + + #[test] + fn least_supertype_null_absorption() -> VortexResult<()> { + let i32_nn = DType::Primitive(PType::I32, NonNullable); + assert_eq!( + DType::Null.least_supertype(&i32_nn)?, + DType::Primitive(PType::I32, Nullable) + ); + assert_eq!( + i32_nn.least_supertype(&DType::Null)?, + DType::Primitive(PType::I32, Nullable) + ); + Ok(()) + } + + #[test] + fn least_supertype_unsigned_widening() -> VortexResult<()> { + let u8_nn = DType::Primitive(PType::U8, NonNullable); + let u32_nn = DType::Primitive(PType::U32, NonNullable); + assert_eq!(u8_nn.least_supertype(&u32_nn)?, u32_nn); + Ok(()) + } + + #[test] + fn least_supertype_signed_widening() -> VortexResult<()> { + let i16_nn = DType::Primitive(PType::I16, NonNullable); + let i64_nn = DType::Primitive(PType::I64, NonNullable); + assert_eq!(i16_nn.least_supertype(&i64_nn)?, i64_nn); + Ok(()) + } + + #[test] + fn least_supertype_cross_family() -> VortexResult<()> { + let u8_nn = DType::Primitive(PType::U8, NonNullable); + let i8_nn = DType::Primitive(PType::I8, NonNullable); + assert_eq!( + u8_nn.least_supertype(&i8_nn)?, + DType::Primitive(PType::I16, NonNullable) + ); + Ok(()) + } + + #[test] + fn least_supertype_u64_i64_error() { + let u64_nn = DType::Primitive(PType::U64, NonNullable); + let i64_nn = DType::Primitive(PType::I64, NonNullable); + assert!(u64_nn.least_supertype(&i64_nn).is_err()); + } + + #[test] + fn least_supertype_int_float_promotion() -> VortexResult<()> { + let u8_nn = DType::Primitive(PType::U8, NonNullable); + let f32_nn = DType::Primitive(PType::F32, NonNullable); + assert_eq!(u8_nn.least_supertype(&f32_nn)?, f32_nn); + Ok(()) + } + + #[test] + fn least_supertype_i32_f32_to_f64() -> VortexResult<()> { + let i32_nn = DType::Primitive(PType::I32, NonNullable); + let f32_nn = DType::Primitive(PType::F32, NonNullable); + assert_eq!( + i32_nn.least_supertype(&f32_nn)?, + DType::Primitive(PType::F64, NonNullable) + ); + Ok(()) + } + + #[test] + fn least_supertype_bool_numeric() -> VortexResult<()> { + let bool_nn = DType::Bool(NonNullable); + let i32_nn = DType::Primitive(PType::I32, NonNullable); + assert_eq!(bool_nn.least_supertype(&i32_nn)?, i32_nn); + assert_eq!(i32_nn.least_supertype(&bool_nn)?, i32_nn); + Ok(()) + } + + #[test] + fn least_supertype_decimal_widening() -> VortexResult<()> { + let d1 = DType::Decimal(DecimalDType::new(10, 2), NonNullable); + let d2 = DType::Decimal(DecimalDType::new(15, 5), NonNullable); + let result = d1.least_supertype(&d2)?; + // integral digits: max(8, 10) = 10, max scale = 5, precision = 15 + assert_eq!( + result, + DType::Decimal(DecimalDType::new(15, 5), NonNullable) + ); + Ok(()) + } + + #[test] + fn least_supertype_incompatible_error() { + let utf8 = DType::Utf8(NonNullable); + let i32_nn = DType::Primitive(PType::I32, NonNullable); + assert!(utf8.least_supertype(&i32_nn).is_err()); + } + + #[test] + fn can_coerce_from_widening() { + let i32_nn = DType::Primitive(PType::I32, NonNullable); + let i64_nn = DType::Primitive(PType::I64, NonNullable); + assert!(i64_nn.can_coerce_from(&i32_nn)); + } + + #[test] + fn can_coerce_from_narrowing_rejected() { + let i32_nn = DType::Primitive(PType::I32, NonNullable); + let i64_nn = DType::Primitive(PType::I64, NonNullable); + assert!(!i32_nn.can_coerce_from(&i64_nn)); + } + + #[test] + fn can_coerce_from_nullability_constraints() { + let i32_nn = DType::Primitive(PType::I32, NonNullable); + let i32_n = DType::Primitive(PType::I32, Nullable); + assert!(i32_n.can_coerce_from(&i32_nn)); + assert!(!i32_nn.can_coerce_from(&i32_n)); + } + + #[test] + fn can_coerce_from_null() { + let i32_n = DType::Primitive(PType::I32, Nullable); + let i32_nn = DType::Primitive(PType::I32, NonNullable); + assert!(i32_n.can_coerce_from(&DType::Null)); + assert!(!i32_nn.can_coerce_from(&DType::Null)); + } + + #[test] + fn are_coercible_mixed() { + let types = [ + DType::Primitive(PType::I32, NonNullable), + DType::Primitive(PType::I64, NonNullable), + ]; + assert!(DType::are_coercible(&types)); + } + + #[test] + fn all_coercible_to_target() { + let types = [ + DType::Primitive(PType::I32, NonNullable), + DType::Primitive(PType::I16, NonNullable), + ]; + let target = DType::Primitive(PType::I64, NonNullable); + assert!(DType::all_coercible_to(&types, &target)); + } + + #[test] + fn coerce_to_supertype_works() -> VortexResult<()> { + let types = [ + DType::Primitive(PType::U8, NonNullable), + DType::Primitive(PType::I16, NonNullable), + ]; + let result = DType::coerce_to_supertype(&types)?; + // U8 + I16: unsigned_signed_supertype max_width=max(1,2)=2 => I32 + assert_eq!(result, vec![DType::Primitive(PType::I32, NonNullable); 2]); + Ok(()) + } + + #[test] + fn least_supertype_integer_decimal() -> VortexResult<()> { + let i32_nn = DType::Primitive(PType::I32, NonNullable); + let dec = DType::Decimal(DecimalDType::new(15, 5), NonNullable); + let result = i32_nn.least_supertype(&dec)?; + // int_dec for I32 = Decimal(10, 0). integral digits = 10. + // dec integral = 15 - 5 = 10. + // max_integral = 10, max_scale = 5, precision = 15 + assert_eq!( + result, + DType::Decimal(DecimalDType::new(15, 5), NonNullable) + ); + Ok(()) + } +} diff --git a/vortex-array/src/dtype/extension/erased.rs b/vortex-array/src/dtype/extension/erased.rs index 3f58cfadc49..e624c5c7ed2 100644 --- a/vortex-array/src/dtype/extension/erased.rs +++ b/vortex-array/src/dtype/extension/erased.rs @@ -99,6 +99,21 @@ impl ExtDTypeRef { pub(crate) fn validate_storage_value(&self, storage_value: &ScalarValue) -> VortexResult<()> { self.0.value_validate(storage_value) } + + /// Can a value of `other` be implicitly coerced into this extension type? + pub fn can_coerce_from(&self, other: &DType) -> bool { + self.0.coercion_can_coerce_from(other) + } + + /// Can this extension type be implicitly coerced into `other`? + pub fn can_coerce_to(&self, other: &DType) -> bool { + self.0.coercion_can_coerce_to(other) + } + + /// Compute the least supertype of this extension type and another type. + pub fn least_supertype(&self, other: &DType) -> Option { + self.0.coercion_least_supertype(other) + } } /// Methods for downcasting type-erased extension dtypes. diff --git a/vortex-array/src/dtype/extension/typed.rs b/vortex-array/src/dtype/extension/typed.rs index de843249727..003b8aa3393 100644 --- a/vortex-array/src/dtype/extension/typed.rs +++ b/vortex-array/src/dtype/extension/typed.rs @@ -123,6 +123,12 @@ pub(super) trait DynExtDType: 'static + Send + Sync + super::sealed::Sealed { /// Formats an extension scalar value using the current dtype for metadata context. fn value_display(&self, f: &mut fmt::Formatter<'_>, storage_value: &ScalarValue) -> fmt::Result; + /// Can a value of `other` be implicitly coerced into this extension type? + fn coercion_can_coerce_from(&self, other: &DType) -> bool; + /// Can this extension type be implicitly coerced into `other`? + fn coercion_can_coerce_to(&self, other: &DType) -> bool; + /// Compute the least supertype of this extension type and another type. + fn coercion_least_supertype(&self, other: &DType) -> Option; } impl DynExtDType for ExtDType { @@ -194,4 +200,16 @@ impl DynExtDType for ExtDType { ), } } + + fn coercion_can_coerce_from(&self, other: &DType) -> bool { + self.vtable.can_coerce_from(other) + } + + fn coercion_can_coerce_to(&self, other: &DType) -> bool { + self.vtable.can_coerce_to(other) + } + + fn coercion_least_supertype(&self, other: &DType) -> Option { + self.vtable.least_supertype(other) + } } diff --git a/vortex-array/src/dtype/extension/vtable.rs b/vortex-array/src/dtype/extension/vtable.rs index 8d324521b41..92271ea0ede 100644 --- a/vortex-array/src/dtype/extension/vtable.rs +++ b/vortex-array/src/dtype/extension/vtable.rs @@ -7,6 +7,7 @@ use std::hash::Hash; use vortex_error::VortexResult; +use crate::dtype::DType; use crate::dtype::extension::ExtDType; use crate::dtype::extension::ExtId; use crate::scalar::ScalarValue; @@ -38,6 +39,26 @@ pub trait ExtVTable: 'static + Sized + Send + Sync + Clone + Debug + Eq + Hash { /// Validate that the given storage type is compatible with this extension type. fn validate_dtype(&self, ext_dtype: &ExtDType) -> VortexResult<()>; + /// Can a value of `other` be implicitly widened into this type? + /// e.g. GeographyType might accept Point, LineString, etc. + fn can_coerce_from(&self, other: &DType) -> bool { + let _ = other; + false + } + + /// Can this type be implicitly widened into `other`? + fn can_coerce_to(&self, other: &DType) -> bool { + let _ = other; + false + } + + /// Given two types in a Uniform context, what is their least supertype? + /// Return None if no supertype exists. + fn least_supertype(&self, other: &DType) -> Option { + let _ = other; + None + } + // Methods related to the extension scalar values. /// Validate the given storage value is compatible with the extension type. diff --git a/vortex-array/src/dtype/mod.rs b/vortex-array/src/dtype/mod.rs index 6c5f58b966e..6395f316eb5 100644 --- a/vortex-array/src/dtype/mod.rs +++ b/vortex-array/src/dtype/mod.rs @@ -10,6 +10,7 @@ mod arbitrary; pub mod arrow; mod bigint; +mod coercion; mod decimal; mod dtype_impl; pub mod extension; diff --git a/vortex-array/src/dtype/ptype.rs b/vortex-array/src/dtype/ptype.rs index 842f2d585a1..be6be455801 100644 --- a/vortex-array/src/dtype/ptype.rs +++ b/vortex-array/src/dtype/ptype.rs @@ -817,6 +817,101 @@ impl PType { other } } + + /// Returns the least supertype (widest common type) of two primitive types, + /// or `None` if no lossless promotion exists. + /// + /// Rules: + /// - Same family (both unsigned, both signed, both float): pick the wider one. + /// - Unsigned + Signed crossover: promote to signed type one width-step wider. + /// - Int + Float: find the minimum float that losslessly represents the integer. + pub fn least_supertype(self, other: Self) -> Option { + if self == other { + return Some(self); + } + + match ( + self.is_unsigned_int(), + self.is_signed_int(), + self.is_float(), + ) { + // self is unsigned + (true, ..) => match ( + other.is_unsigned_int(), + other.is_signed_int(), + other.is_float(), + ) { + // unsigned + unsigned => wider unsigned + (true, ..) => Some(self.max_unsigned_ptype(other)), + // unsigned + signed => signed one step wider than max of both + (_, true, _) => Self::unsigned_signed_supertype(self, other), + // unsigned + float => int-to-float promotion + (_, _, true) => Self::int_float_supertype(self, other), + _ => None, + }, + // self is signed + (_, true, _) => match ( + other.is_unsigned_int(), + other.is_signed_int(), + other.is_float(), + ) { + // signed + unsigned => signed one step wider than max of both + (true, ..) => Self::unsigned_signed_supertype(other, self), + // signed + signed => wider signed + (_, true, _) => Some(self.max_signed_ptype(other)), + // signed + float => int-to-float promotion + (_, _, true) => Self::int_float_supertype(self, other), + _ => None, + }, + // self is float + (_, _, true) => match ( + other.is_unsigned_int(), + other.is_signed_int(), + other.is_float(), + ) { + // float + unsigned/signed => int-to-float promotion + (true, ..) | (_, true, _) => Self::int_float_supertype(other, self), + // float + float => wider float + (_, _, true) => { + if self.byte_width() >= other.byte_width() { + Some(self) + } else { + Some(other) + } + } + _ => None, + }, + _ => None, + } + } + + /// Promote unsigned + signed to a signed type one width-step wider than both. + fn unsigned_signed_supertype(unsigned: Self, signed: Self) -> Option { + let max_width = unsigned.byte_width().max(signed.byte_width()); + match max_width { + 1 => Some(Self::I16), + 2 => Some(Self::I32), + 4 => Some(Self::I64), + // U64 + I64 — no lossless 128-bit integer type + _ => None, + } + } + + /// Promote integer + float to the minimum float that losslessly represents the integer. + fn int_float_supertype(int: Self, float: Self) -> Option { + let min_float = match int.byte_width() { + 1 => Self::F16, // f16 has 11 bits mantissa, enough for 8-bit ints + 2 => Self::F32, // f32 has 24 bits mantissa, enough for 16-bit ints + 4 => Self::F64, // f64 has 53 bits mantissa, enough for 32-bit ints + _ => return None, // no standard float for 64-bit ints + }; + // Pick the wider of the required float and the given float + if float.byte_width() >= min_float.byte_width() { + Some(float) + } else { + Some(min_float) + } + } } impl Display for PType { diff --git a/vortex-array/src/expr/transform/mod.rs b/vortex-array/src/expr/transform/mod.rs index b1081ad5ca3..6013241d971 100644 --- a/vortex-array/src/expr/transform/mod.rs +++ b/vortex-array/src/expr/transform/mod.rs @@ -2,9 +2,11 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors //! A collection of transformations that can be applied to a [`crate::expr::Expression`]. +mod coerce; pub(crate) mod match_between; mod partition; mod replace; +pub use coerce::*; pub use partition::*; pub use replace::*; diff --git a/vortex-array/src/scalar_fn/erased.rs b/vortex-array/src/scalar_fn/erased.rs index a6c4a5a463f..154ad13bace 100644 --- a/vortex-array/src/scalar_fn/erased.rs +++ b/vortex-array/src/scalar_fn/erased.rs @@ -126,6 +126,11 @@ impl ScalarFnRef { self.0.return_dtype(arg_types) } + /// Coerce the argument types for this scalar function. + pub fn coerce_args(&self, arg_types: &[DType]) -> VortexResult> { + self.0.coerce_args(arg_types) + } + /// Transforms the expression into one representing the validity of this expression. pub fn validity(&self, expr: &Expression) -> VortexResult { Ok(self.0.validity(expr)?.unwrap_or_else(|| { diff --git a/vortex-array/src/scalar_fn/fns/binary/mod.rs b/vortex-array/src/scalar_fn/fns/binary/mod.rs index 661a535d969..06f1f08e4ab 100644 --- a/vortex-array/src/scalar_fn/fns/binary/mod.rs +++ b/vortex-array/src/scalar_fn/fns/binary/mod.rs @@ -98,6 +98,18 @@ impl ScalarFnVTable for Binary { write!(f, ")") } + fn coerce_args(&self, operator: &Self::Options, args: &[DType]) -> VortexResult> { + let lhs = &args[0]; + let rhs = &args[1]; + if operator.is_arithmetic() || operator.is_comparison() { + let supertype = lhs.least_supertype(rhs)?; + Ok(vec![supertype.clone(), supertype]) + } else { + // Boolean And/Or: no coercion + Ok(args.to_vec()) + } + } + fn return_dtype(&self, operator: &Operator, arg_dtypes: &[DType]) -> VortexResult { let lhs = &arg_dtypes[0]; let rhs = &arg_dtypes[1]; diff --git a/vortex-array/src/scalar_fn/typed.rs b/vortex-array/src/scalar_fn/typed.rs index ce3a3e377a4..88072ff4002 100644 --- a/vortex-array/src/scalar_fn/typed.rs +++ b/vortex-array/src/scalar_fn/typed.rs @@ -81,6 +81,7 @@ pub(super) trait DynScalarFn: 'static + Send + Sync + super::sealed::Sealed { // Bound methods — options accessed from self fn execute(&self, args: &dyn ExecutionArgs, ctx: &mut ExecutionCtx) -> VortexResult; fn return_dtype(&self, arg_types: &[DType]) -> VortexResult; + fn coerce_args(&self, arg_types: &[DType]) -> VortexResult>; fn reduce( &self, node: &dyn ReduceNode, @@ -174,6 +175,10 @@ impl DynScalarFn for ScalarFn { V::return_dtype(&self.vtable, &self.options, arg_dtypes) } + fn coerce_args(&self, arg_types: &[DType]) -> VortexResult> { + V::coerce_args(&self.vtable, &self.options, arg_types) + } + fn reduce( &self, node: &dyn ReduceNode, diff --git a/vortex-array/src/scalar_fn/vtable.rs b/vortex-array/src/scalar_fn/vtable.rs index c5cdac95a88..ce2cecf5df2 100644 --- a/vortex-array/src/scalar_fn/vtable.rs +++ b/vortex-array/src/scalar_fn/vtable.rs @@ -68,7 +68,7 @@ pub trait ScalarFnVTable: 'static + Sized + Clone + Send + Sync { /// Returns the name of the nth child of the expr. fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName; - /// Format this expression in nice human-readable SQL-style format + /// Format this expression in a nice human-readable SQL-style format /// /// The implementation should recursively format child expressions by calling /// `expr.child(i).fmt_sql(f)`. @@ -79,8 +79,21 @@ pub trait ScalarFnVTable: 'static + Sized + Clone + Send + Sync { f: &mut Formatter<'_>, ) -> fmt::Result; + /// Coerce the arguments of this function. + /// + /// This is optionally used by Vortex users when performing type coercion over a Vortex + /// expression. Note that direct Vortex query engine integrations (e.g. DuckDB, DataFusion, + /// etc.) do not perform type coercion and rely on the engine's own logical planner. + /// + /// Note that the default implementation simply returns the arguments without coercion, and it + /// is expected that the [`ScalarFnVTable::return_dtype`] call may still fail. + fn coerce_args(&self, options: &Self::Options, args: &[DType]) -> VortexResult> { + let _ = options; + Ok(args.to_vec()) + } + /// Compute the return [`DType`] of the expression if evaluated over the given input types. - fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult; + fn return_dtype(&self, options: &Self::Options, args: &[DType]) -> VortexResult; /// Execute the expression over the input arguments. /// From f210a04259a15ff78cd58dd2a476d278f05c17d9 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Fri, 13 Mar 2026 14:35:54 -0400 Subject: [PATCH 2/8] Add type coercion to scalar functions and extension types Signed-off-by: Nicholas Gates --- vortex-array/src/expr/transform/coerce.rs | 194 ++++++++++++++++++++++ 1 file changed, 194 insertions(+) create mode 100644 vortex-array/src/expr/transform/coerce.rs diff --git a/vortex-array/src/expr/transform/coerce.rs b/vortex-array/src/expr/transform/coerce.rs new file mode 100644 index 00000000000..1b6e9acd661 --- /dev/null +++ b/vortex-array/src/expr/transform/coerce.rs @@ -0,0 +1,194 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Expression-level type coercion pass. + +use vortex_error::VortexResult; + +use crate::dtype::DType; +use crate::expr::Expression; +use crate::expr::cast; +use crate::expr::traversal::NodeExt; +use crate::expr::traversal::Transformed; +use crate::scalar_fn::fns::literal::Literal; +use crate::scalar_fn::fns::root::Root; + +/// Rewrite an expression tree to insert casts where a scalar function's `coerce_args` demands +/// a different type than what the child currently produces. +/// +/// The rewrite is bottom-up: children are coerced first, then each parent node checks whether +/// its children match the coerced argument types. +pub fn coerce_expression(expr: Expression, scope: &DType) -> VortexResult { + // We capture scope by reference for the closure. + let scope = scope.clone(); + expr.transform_up(|node| { + // Leaf nodes (Root, Literal) have no children to coerce. + if node.is::() || node.is::() || node.children().is_empty() { + return Ok(Transformed::no(node)); + } + + // Compute the current child return types. + let child_dtypes: Vec = node + .children() + .iter() + .map(|c| c.return_dtype(&scope)) + .collect::>()?; + + // Ask the scalar function what types it wants. + let coerced_dtypes = node.scalar_fn().coerce_args(&child_dtypes)?; + + // If nothing changed, skip. + if child_dtypes == coerced_dtypes { + return Ok(Transformed::no(node)); + } + + // Build new children, inserting casts where needed. + let new_children: Vec = node + .children() + .iter() + .zip(coerced_dtypes.iter()) + .map(|(child, target)| { + let child_dtype = child.return_dtype(&scope)?; + if child_dtype.eq_ignore_nullability(target) + && child_dtype.nullability() == target.nullability() + { + Ok(child.clone()) + } else { + Ok(cast(child.clone(), target.clone())) + } + }) + .collect::>()?; + + let new_expr = node.with_children(new_children)?; + Ok(Transformed::yes(new_expr)) + }) + .map(|t| t.into_inner()) +} + +#[cfg(test)] +mod tests { + use vortex_error::VortexResult; + + use crate::dtype::DType; + use crate::dtype::Nullability::NonNullable; + use crate::dtype::PType; + use crate::dtype::StructFields; + use crate::expr::col; + use crate::expr::lit; + use crate::expr::transform::coerce::coerce_expression; + use crate::scalar::Scalar; + use crate::scalar_fn::ScalarFnVTableExt; + use crate::scalar_fn::fns::binary::Binary; + use crate::scalar_fn::fns::cast::Cast; + use crate::scalar_fn::fns::operators::Operator; + + fn test_scope() -> DType { + DType::Struct( + StructFields::new( + ["x", "y"].into(), + vec![ + DType::Primitive(PType::I32, NonNullable), + DType::Primitive(PType::I64, NonNullable), + ], + ), + NonNullable, + ) + } + + #[test] + fn mixed_type_comparison_inserts_cast() -> VortexResult<()> { + let scope = test_scope(); + // x (I32) < y (I64) => should cast x to I64 + let expr = Binary.new_expr(Operator::Lt, [col("x"), col("y")]); + let coerced = coerce_expression(expr, &scope)?; + + // The LHS child should now be a cast expression + assert!(coerced.child(0).is::()); + // The coerced LHS should return I64 + assert_eq!( + coerced.child(0).return_dtype(&scope)?, + DType::Primitive(PType::I64, NonNullable) + ); + // The RHS should be unchanged + assert!(!coerced.child(1).is::()); + Ok(()) + } + + #[test] + fn same_type_comparison_no_cast() -> VortexResult<()> { + let scope = test_scope(); + // x (I32) < x (I32) => no cast needed + let expr = Binary.new_expr(Operator::Lt, [col("x"), col("x")]); + let coerced = coerce_expression(expr, &scope)?; + + // Neither child should be a cast + assert!(!coerced.child(0).is::()); + assert!(!coerced.child(1).is::()); + Ok(()) + } + + #[test] + fn mixed_type_arithmetic_coerces_both() -> VortexResult<()> { + let scope = DType::Struct( + StructFields::new( + ["a", "b"].into(), + vec![ + DType::Primitive(PType::U8, NonNullable), + DType::Primitive(PType::I32, NonNullable), + ], + ), + NonNullable, + ); + // a (U8) + b (I32) => both should be coerced to I32 + // U8 + I32: unsigned_signed_supertype(U8, I32) => max(1,4)=4 => I64 + let expr = Binary.new_expr(Operator::Add, [col("a"), col("b")]); + let coerced = coerce_expression(expr, &scope)?; + + // LHS (U8) should be cast + assert!(coerced.child(0).is::()); + // Both should return the same supertype + let lhs_dt = coerced.child(0).return_dtype(&scope)?; + let rhs_dt = coerced.child(1).return_dtype(&scope)?; + assert_eq!(lhs_dt, rhs_dt); + Ok(()) + } + + #[test] + fn boolean_operators_no_coercion() -> VortexResult<()> { + let scope = DType::Struct( + StructFields::new( + ["p", "q"].into(), + vec![DType::Bool(NonNullable), DType::Bool(NonNullable)], + ), + NonNullable, + ); + let expr = Binary.new_expr(Operator::And, [col("p"), col("q")]); + let coerced = coerce_expression(expr, &scope)?; + + assert!(!coerced.child(0).is::()); + assert!(!coerced.child(1).is::()); + Ok(()) + } + + #[test] + fn literal_coercion() -> VortexResult<()> { + let scope = DType::Struct( + StructFields::new( + ["x"].into(), + vec![DType::Primitive(PType::I64, NonNullable)], + ), + NonNullable, + ); + // x (I64) + 1i32 => literal should be cast to I64 + let expr = Binary.new_expr(Operator::Add, [col("x"), lit(Scalar::from(1i32))]); + let coerced = coerce_expression(expr, &scope)?; + + // The RHS (literal) should be cast to I64 + assert!(coerced.child(1).is::()); + assert_eq!( + coerced.child(1).return_dtype(&scope)?, + DType::Primitive(PType::I64, NonNullable) + ); + Ok(()) + } +} From fee68a4074d70e59fa123074657e8891ed88b96b Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Fri, 13 Mar 2026 14:44:13 -0400 Subject: [PATCH 3/8] Add type coercion to scalar functions and extension types Signed-off-by: Nicholas Gates --- vortex-array/src/dtype/coercion.rs | 67 +++++++++++++++++++++ vortex-array/src/dtype/ptype.rs | 95 ------------------------------ 2 files changed, 67 insertions(+), 95 deletions(-) diff --git a/vortex-array/src/dtype/coercion.rs b/vortex-array/src/dtype/coercion.rs index b83d42f2177..7588d4f8d98 100644 --- a/vortex-array/src/dtype/coercion.rs +++ b/vortex-array/src/dtype/coercion.rs @@ -211,6 +211,73 @@ impl DType { } } +/// Returns the least supertype (widest common type) of two primitive types, +/// or `None` if no lossless promotion exists. +/// +/// Rules: +/// - Same type: return it. +/// - Same family (both unsigned, both signed, both float): pick the wider one. +/// - Unsigned + Signed: promote to a signed type one width-step wider than both. +/// - Int + Float: pick the smallest float that losslessly represents the integer, +/// then take the wider of that and the given float. +fn ptype_least_supertype(a: PType, b: PType) -> Option { + use PType::*; + + if a == b { + return Some(a); + } + + // Same family — pick the wider. + if a.is_unsigned_int() && b.is_unsigned_int() { + return Some(a.max_unsigned_ptype(b)); + } + if a.is_signed_int() && b.is_signed_int() { + return Some(a.max_signed_ptype(b)); + } + if a.is_float() && b.is_float() { + return if a.byte_width() >= b.byte_width() { + Some(a) + } else { + Some(b) + }; + } + + // Unsigned + Signed crossover — promote to signed one width-step wider. + let (unsigned, signed) = if a.is_unsigned_int() && b.is_signed_int() { + (a, b) + } else if a.is_signed_int() && b.is_unsigned_int() { + (b, a) + } else { + // Must be int + float (in either order). + let (int, float) = if a.is_float() { (b, a) } else { (a, b) }; + return int_float_supertype(int, float); + }; + + match unsigned.byte_width().max(signed.byte_width()) { + 1 => Some(I16), + 2 => Some(I32), + 4 => Some(I64), + _ => None, // U64 + I64 — no lossless 128-bit integer type + } +} + +/// Promote integer + float to the minimum float that losslessly represents the integer. +fn int_float_supertype(int: PType, float: PType) -> Option { + use PType::*; + + let min_float = match int.byte_width() { + 1 => F16, // f16 has 11-bit mantissa, enough for 8-bit ints + 2 => F32, // f32 has 24-bit mantissa, enough for 16-bit ints + 4 => F64, // f64 has 53-bit mantissa, enough for 32-bit ints + _ => return None, // no standard float for 64-bit ints + }; + if float.byte_width() >= min_float.byte_width() { + Some(float) + } else { + Some(min_float) + } +} + /// Maps integer PType widths to the minimum decimal precision needed. fn integer_decimal_precision(ptype: PType) -> u8 { match ptype { diff --git a/vortex-array/src/dtype/ptype.rs b/vortex-array/src/dtype/ptype.rs index be6be455801..842f2d585a1 100644 --- a/vortex-array/src/dtype/ptype.rs +++ b/vortex-array/src/dtype/ptype.rs @@ -817,101 +817,6 @@ impl PType { other } } - - /// Returns the least supertype (widest common type) of two primitive types, - /// or `None` if no lossless promotion exists. - /// - /// Rules: - /// - Same family (both unsigned, both signed, both float): pick the wider one. - /// - Unsigned + Signed crossover: promote to signed type one width-step wider. - /// - Int + Float: find the minimum float that losslessly represents the integer. - pub fn least_supertype(self, other: Self) -> Option { - if self == other { - return Some(self); - } - - match ( - self.is_unsigned_int(), - self.is_signed_int(), - self.is_float(), - ) { - // self is unsigned - (true, ..) => match ( - other.is_unsigned_int(), - other.is_signed_int(), - other.is_float(), - ) { - // unsigned + unsigned => wider unsigned - (true, ..) => Some(self.max_unsigned_ptype(other)), - // unsigned + signed => signed one step wider than max of both - (_, true, _) => Self::unsigned_signed_supertype(self, other), - // unsigned + float => int-to-float promotion - (_, _, true) => Self::int_float_supertype(self, other), - _ => None, - }, - // self is signed - (_, true, _) => match ( - other.is_unsigned_int(), - other.is_signed_int(), - other.is_float(), - ) { - // signed + unsigned => signed one step wider than max of both - (true, ..) => Self::unsigned_signed_supertype(other, self), - // signed + signed => wider signed - (_, true, _) => Some(self.max_signed_ptype(other)), - // signed + float => int-to-float promotion - (_, _, true) => Self::int_float_supertype(self, other), - _ => None, - }, - // self is float - (_, _, true) => match ( - other.is_unsigned_int(), - other.is_signed_int(), - other.is_float(), - ) { - // float + unsigned/signed => int-to-float promotion - (true, ..) | (_, true, _) => Self::int_float_supertype(other, self), - // float + float => wider float - (_, _, true) => { - if self.byte_width() >= other.byte_width() { - Some(self) - } else { - Some(other) - } - } - _ => None, - }, - _ => None, - } - } - - /// Promote unsigned + signed to a signed type one width-step wider than both. - fn unsigned_signed_supertype(unsigned: Self, signed: Self) -> Option { - let max_width = unsigned.byte_width().max(signed.byte_width()); - match max_width { - 1 => Some(Self::I16), - 2 => Some(Self::I32), - 4 => Some(Self::I64), - // U64 + I64 — no lossless 128-bit integer type - _ => None, - } - } - - /// Promote integer + float to the minimum float that losslessly represents the integer. - fn int_float_supertype(int: Self, float: Self) -> Option { - let min_float = match int.byte_width() { - 1 => Self::F16, // f16 has 11 bits mantissa, enough for 8-bit ints - 2 => Self::F32, // f32 has 24 bits mantissa, enough for 16-bit ints - 4 => Self::F64, // f64 has 53 bits mantissa, enough for 32-bit ints - _ => return None, // no standard float for 64-bit ints - }; - // Pick the wider of the required float and the given float - if float.byte_width() >= min_float.byte_width() { - Some(float) - } else { - Some(min_float) - } - } } impl Display for PType { From 09430b82b737942631bee77823ca902e30632a3c Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Fri, 13 Mar 2026 15:00:39 -0400 Subject: [PATCH 4/8] Add type coercion to scalar functions and extension types Signed-off-by: Nicholas Gates --- vortex-array/src/dtype/coercion.rs | 313 ++++++++----------- vortex-array/src/scalar_fn/fns/binary/mod.rs | 4 +- 2 files changed, 134 insertions(+), 183 deletions(-) diff --git a/vortex-array/src/dtype/coercion.rs b/vortex-array/src/dtype/coercion.rs index 7588d4f8d98..528d9b673b2 100644 --- a/vortex-array/src/dtype/coercion.rs +++ b/vortex-array/src/dtype/coercion.rs @@ -3,55 +3,113 @@ //! Utilities for performing type coercion. -use vortex_error::VortexResult; -use vortex_error::vortex_bail; - use crate::dtype::DType; use crate::dtype::PType; use crate::dtype::decimal::DecimalDType; +impl PType { + /// Returns the least supertype (widest common type) of two primitive types, + /// or `None` if no lossless promotion exists. + pub fn least_supertype(self, other: PType) -> Option { + if self == other { + return Some(self); + } + + // Same family — pick the wider. + if self.is_unsigned_int() && other.is_unsigned_int() { + return Some(self.max_unsigned_ptype(other)); + } + if self.is_signed_int() && other.is_signed_int() { + return Some(self.max_signed_ptype(other)); + } + if self.is_float() && other.is_float() { + return if self.byte_width() >= other.byte_width() { + Some(self) + } else { + Some(other) + }; + } + + // Unsigned + Signed crossover — promote to signed one width-step wider. + if self.is_unsigned_int() && other.is_signed_int() { + return Self::unsigned_signed_supertype(self, other); + } + if self.is_signed_int() && other.is_unsigned_int() { + return Self::unsigned_signed_supertype(other, self); + } + + // Int + Float — pick the smallest float that losslessly represents the integer. + let (int, float) = if self.is_float() { + (other, self) + } else { + (self, other) + }; + Self::int_float_supertype(int, float) + } + + fn unsigned_signed_supertype(unsigned: PType, signed: PType) -> Option { + use PType::*; + match unsigned.byte_width().max(signed.byte_width()) { + 1 => Some(I16), + 2 => Some(I32), + 4 => Some(I64), + _ => None, // U64 + I64 — no lossless 128-bit integer type + } + } + + fn int_float_supertype(int: PType, float: PType) -> Option { + use PType::*; + let min_float = match int.byte_width() { + 1 => F16, // f16 has 11-bit mantissa, enough for 8-bit ints + 2 => F32, // f32 has 24-bit mantissa, enough for 16-bit ints + 4 => F64, // f64 has 53-bit mantissa, enough for 32-bit ints + _ => return None, // no standard float for 64-bit ints + }; + if float.byte_width() >= min_float.byte_width() { + Some(float) + } else { + Some(min_float) + } + } +} + impl DType { /// The core primitive — what type can hold both `self` and `other`? - pub fn least_supertype(&self, other: &DType) -> VortexResult { + /// Returns `None` if no common supertype exists. + pub fn least_supertype(&self, other: &DType) -> Option { // 1. Identity (ignoring nullability): return self with union nullability if self.eq_ignore_nullability(other) { - return Ok(self.with_nullability(self.nullability() | other.nullability())); + return Some(self.with_nullability(self.nullability() | other.nullability())); } let union_null = self.nullability() | other.nullability(); // 2. Null + X: return X as nullable if matches!(self, DType::Null) { - return Ok(other.as_nullable()); + return Some(other.as_nullable()); } if matches!(other, DType::Null) { - return Ok(self.as_nullable()); + return Some(self.as_nullable()); } // 3. Bool + numeric: return the numeric type (with union nullability) if self.is_boolean() && other.is_numeric() { - return Ok(other.with_nullability(union_null)); + return Some(other.with_nullability(union_null)); } if other.is_boolean() && self.is_numeric() { - return Ok(self.with_nullability(union_null)); + return Some(self.with_nullability(union_null)); } // 4. Primitive + Primitive (different ptypes): delegate to PType::least_supertype if let (DType::Primitive(lhs_p, _), DType::Primitive(rhs_p, _)) = (self, other) { - return match lhs_p.least_supertype(*rhs_p) { - Some(p) => Ok(DType::Primitive(p, union_null)), - None => vortex_bail!( - "No common supertype for primitive types {} and {}", - lhs_p, - rhs_p - ), - }; + return lhs_p + .least_supertype(*rhs_p) + .map(|p| DType::Primitive(p, union_null)); } // 5. Decimal + Decimal: compute wider decimal if let (DType::Decimal(lhs_d, _), DType::Decimal(rhs_d, _)) = (self, other) { - let d = decimal_least_supertype(*lhs_d, *rhs_d)?; - return Ok(DType::Decimal(d, union_null)); + return decimal_least_supertype(*lhs_d, *rhs_d).map(|d| DType::Decimal(d, union_null)); } // 6. Decimal + integer Primitive: convert integer to Decimal, then widen @@ -59,47 +117,36 @@ impl DType { && p.is_int() { let int_dec = DecimalDType::new(integer_decimal_precision(*p), 0); - let d = decimal_least_supertype(*dec, int_dec)?; - return Ok(DType::Decimal(d, union_null)); + return decimal_least_supertype(*dec, int_dec).map(|d| DType::Decimal(d, union_null)); } if let (DType::Primitive(p, _), DType::Decimal(dec, _)) = (self, other) && p.is_int() { let int_dec = DecimalDType::new(integer_decimal_precision(*p), 0); - let d = decimal_least_supertype(int_dec, *dec)?; - return Ok(DType::Decimal(d, union_null)); + return decimal_least_supertype(int_dec, *dec).map(|d| DType::Decimal(d, union_null)); } // 7. Extension + anything: delegate to vtable if let DType::Extension(ext) = self { - return match ext.least_supertype(other) { - Some(dt) => Ok(dt.with_nullability(union_null)), - None => vortex_bail!( - "No common supertype for extension type {} and {}", - ext.id(), - other - ), - }; + return ext + .least_supertype(other) + .map(|dt| dt.with_nullability(union_null)); } if let DType::Extension(ext) = other { - return match ext.least_supertype(self) { - Some(dt) => Ok(dt.with_nullability(union_null)), - None => vortex_bail!( - "No common supertype for {} and extension type {}", - self, - ext.id() - ), - }; + return ext + .least_supertype(self) + .map(|dt| dt.with_nullability(union_null)); } - // 8. Everything else: error - vortex_bail!("No common supertype for {} and {}", self, other) + // 8. Everything else: no common supertype + None } /// Fold over a slice — what type can hold all of these? - pub fn least_supertype_of(types: &[DType]) -> VortexResult { + pub fn least_supertype_of(types: &[DType]) -> Option { types .iter() + .skip(1) .try_fold(types[0].clone(), |acc, t| acc.least_supertype(t)) } @@ -124,7 +171,7 @@ impl DType { if let (DType::Primitive(..), DType::Primitive(..)) = (self, other) { return other .least_supertype(self) - .is_ok_and(|st| st.eq_ignore_nullability(self)) + .is_some_and(|st| st.eq_ignore_nullability(self)) && (self.is_nullable() || !other.is_nullable()); } @@ -163,7 +210,7 @@ impl DType { /// Are all types in the slice mutually coercible to a common type? pub fn are_coercible(types: &[DType]) -> bool { - DType::least_supertype_of(types).is_ok() + DType::least_supertype_of(types).is_some() } /// Can all types in the slice be coerced to a specific target? @@ -172,25 +219,18 @@ impl DType { } /// Coerce a slice to a specific target — returns the vec of targets - /// if all are coercible, error if any are not. - pub fn coerce_all_to(types: &[DType], target: &DType) -> VortexResult> { + /// if all are coercible, `None` if any are not. + pub fn coerce_all_to(types: &[DType], target: &DType) -> Option> { types .iter() - .enumerate() - .map(|(i, t)| { - if target.can_coerce_from(t) { - Ok(target.clone()) - } else { - vortex_bail!("Cannot coerce {} to {} in position {}", t, target, i) - } - }) - .collect() + .all(|t| target.can_coerce_from(t)) + .then(|| vec![target.clone(); types.len()]) } /// Coerce a slice to their mutual least supertype. - pub fn coerce_to_supertype(types: &[DType]) -> VortexResult> { + pub fn coerce_to_supertype(types: &[DType]) -> Option> { let supertype = DType::least_supertype_of(types)?; - Ok(vec![supertype; types.len()]) + Some(vec![supertype; types.len()]) } /// Is this a numeric type (primitive int/float or decimal)? @@ -211,73 +251,6 @@ impl DType { } } -/// Returns the least supertype (widest common type) of two primitive types, -/// or `None` if no lossless promotion exists. -/// -/// Rules: -/// - Same type: return it. -/// - Same family (both unsigned, both signed, both float): pick the wider one. -/// - Unsigned + Signed: promote to a signed type one width-step wider than both. -/// - Int + Float: pick the smallest float that losslessly represents the integer, -/// then take the wider of that and the given float. -fn ptype_least_supertype(a: PType, b: PType) -> Option { - use PType::*; - - if a == b { - return Some(a); - } - - // Same family — pick the wider. - if a.is_unsigned_int() && b.is_unsigned_int() { - return Some(a.max_unsigned_ptype(b)); - } - if a.is_signed_int() && b.is_signed_int() { - return Some(a.max_signed_ptype(b)); - } - if a.is_float() && b.is_float() { - return if a.byte_width() >= b.byte_width() { - Some(a) - } else { - Some(b) - }; - } - - // Unsigned + Signed crossover — promote to signed one width-step wider. - let (unsigned, signed) = if a.is_unsigned_int() && b.is_signed_int() { - (a, b) - } else if a.is_signed_int() && b.is_unsigned_int() { - (b, a) - } else { - // Must be int + float (in either order). - let (int, float) = if a.is_float() { (b, a) } else { (a, b) }; - return int_float_supertype(int, float); - }; - - match unsigned.byte_width().max(signed.byte_width()) { - 1 => Some(I16), - 2 => Some(I32), - 4 => Some(I64), - _ => None, // U64 + I64 — no lossless 128-bit integer type - } -} - -/// Promote integer + float to the minimum float that losslessly represents the integer. -fn int_float_supertype(int: PType, float: PType) -> Option { - use PType::*; - - let min_float = match int.byte_width() { - 1 => F16, // f16 has 11-bit mantissa, enough for 8-bit ints - 2 => F32, // f32 has 24-bit mantissa, enough for 16-bit ints - 4 => F64, // f64 has 53-bit mantissa, enough for 32-bit ints - _ => return None, // no standard float for 64-bit ints - }; - if float.byte_width() >= min_float.byte_width() { - Some(float) - } else { - Some(min_float) - } -} - /// Maps integer PType widths to the minimum decimal precision needed. fn integer_decimal_precision(ptype: PType) -> u8 { match ptype { @@ -290,29 +263,17 @@ fn integer_decimal_precision(ptype: PType) -> u8 { } /// Compute the least supertype of two decimal types using SQL-standard rules. -/// -/// Result precision = max(integral digits) + max(scale), capped at MAX_PRECISION. -fn decimal_least_supertype(a: DecimalDType, b: DecimalDType) -> VortexResult { +fn decimal_least_supertype(a: DecimalDType, b: DecimalDType) -> Option { let a_integral = a.precision() as i16 - a.scale() as i16; let b_integral = b.precision() as i16 - b.scale() as i16; let max_integral = a_integral.max(b_integral); let max_scale = a.scale().max(b.scale()); - let precision = u8::try_from(max_integral + max_scale as i16).map_err(|_| { - vortex_error::vortex_err!( - "Decimal supertype precision overflow for ({}, {}) and ({}, {})", - a.precision(), - a.scale(), - b.precision(), - b.scale() - ) - })?; - DecimalDType::try_new(precision, max_scale) + let precision = u8::try_from(max_integral + max_scale as i16).ok()?; + DecimalDType::try_new(precision, max_scale).ok() } #[cfg(test)] mod tests { - use vortex_error::VortexResult; - use crate::dtype::DType; use crate::dtype::PType; use crate::dtype::decimal::DecimalDType; @@ -330,115 +291,105 @@ mod tests { } #[test] - fn least_supertype_identity() -> VortexResult<()> { + fn least_supertype_identity() { let i32_nn = DType::Primitive(PType::I32, NonNullable); - assert_eq!(i32_nn.least_supertype(&i32_nn)?, i32_nn); - Ok(()) + assert_eq!(i32_nn.least_supertype(&i32_nn).unwrap(), i32_nn); } #[test] - fn least_supertype_nullability_union() -> VortexResult<()> { + fn least_supertype_nullability_union() { let i32_nn = DType::Primitive(PType::I32, NonNullable); let i32_n = DType::Primitive(PType::I32, Nullable); - assert_eq!(i32_nn.least_supertype(&i32_n)?, i32_n); - assert_eq!(i32_n.least_supertype(&i32_nn)?, i32_n); - Ok(()) + assert_eq!(i32_nn.least_supertype(&i32_n).unwrap(), i32_n); + assert_eq!(i32_n.least_supertype(&i32_nn).unwrap(), i32_n); } #[test] - fn least_supertype_null_absorption() -> VortexResult<()> { + fn least_supertype_null_absorption() { let i32_nn = DType::Primitive(PType::I32, NonNullable); assert_eq!( - DType::Null.least_supertype(&i32_nn)?, + DType::Null.least_supertype(&i32_nn).unwrap(), DType::Primitive(PType::I32, Nullable) ); assert_eq!( - i32_nn.least_supertype(&DType::Null)?, + i32_nn.least_supertype(&DType::Null).unwrap(), DType::Primitive(PType::I32, Nullable) ); - Ok(()) } #[test] - fn least_supertype_unsigned_widening() -> VortexResult<()> { + fn least_supertype_unsigned_widening() { let u8_nn = DType::Primitive(PType::U8, NonNullable); let u32_nn = DType::Primitive(PType::U32, NonNullable); - assert_eq!(u8_nn.least_supertype(&u32_nn)?, u32_nn); - Ok(()) + assert_eq!(u8_nn.least_supertype(&u32_nn).unwrap(), u32_nn); } #[test] - fn least_supertype_signed_widening() -> VortexResult<()> { + fn least_supertype_signed_widening() { let i16_nn = DType::Primitive(PType::I16, NonNullable); let i64_nn = DType::Primitive(PType::I64, NonNullable); - assert_eq!(i16_nn.least_supertype(&i64_nn)?, i64_nn); - Ok(()) + assert_eq!(i16_nn.least_supertype(&i64_nn).unwrap(), i64_nn); } #[test] - fn least_supertype_cross_family() -> VortexResult<()> { + fn least_supertype_cross_family() { let u8_nn = DType::Primitive(PType::U8, NonNullable); let i8_nn = DType::Primitive(PType::I8, NonNullable); assert_eq!( - u8_nn.least_supertype(&i8_nn)?, + u8_nn.least_supertype(&i8_nn).unwrap(), DType::Primitive(PType::I16, NonNullable) ); - Ok(()) } #[test] - fn least_supertype_u64_i64_error() { + fn least_supertype_u64_i64_none() { let u64_nn = DType::Primitive(PType::U64, NonNullable); let i64_nn = DType::Primitive(PType::I64, NonNullable); - assert!(u64_nn.least_supertype(&i64_nn).is_err()); + assert!(u64_nn.least_supertype(&i64_nn).is_none()); } #[test] - fn least_supertype_int_float_promotion() -> VortexResult<()> { + fn least_supertype_int_float_promotion() { let u8_nn = DType::Primitive(PType::U8, NonNullable); let f32_nn = DType::Primitive(PType::F32, NonNullable); - assert_eq!(u8_nn.least_supertype(&f32_nn)?, f32_nn); - Ok(()) + assert_eq!(u8_nn.least_supertype(&f32_nn).unwrap(), f32_nn); } #[test] - fn least_supertype_i32_f32_to_f64() -> VortexResult<()> { + fn least_supertype_i32_f32_to_f64() { let i32_nn = DType::Primitive(PType::I32, NonNullable); let f32_nn = DType::Primitive(PType::F32, NonNullable); assert_eq!( - i32_nn.least_supertype(&f32_nn)?, + i32_nn.least_supertype(&f32_nn).unwrap(), DType::Primitive(PType::F64, NonNullable) ); - Ok(()) } #[test] - fn least_supertype_bool_numeric() -> VortexResult<()> { + fn least_supertype_bool_numeric() { let bool_nn = DType::Bool(NonNullable); let i32_nn = DType::Primitive(PType::I32, NonNullable); - assert_eq!(bool_nn.least_supertype(&i32_nn)?, i32_nn); - assert_eq!(i32_nn.least_supertype(&bool_nn)?, i32_nn); - Ok(()) + assert_eq!(bool_nn.least_supertype(&i32_nn).unwrap(), i32_nn); + assert_eq!(i32_nn.least_supertype(&bool_nn).unwrap(), i32_nn); } #[test] - fn least_supertype_decimal_widening() -> VortexResult<()> { + fn least_supertype_decimal_widening() { let d1 = DType::Decimal(DecimalDType::new(10, 2), NonNullable); let d2 = DType::Decimal(DecimalDType::new(15, 5), NonNullable); - let result = d1.least_supertype(&d2)?; + let result = d1.least_supertype(&d2).unwrap(); // integral digits: max(8, 10) = 10, max scale = 5, precision = 15 assert_eq!( result, DType::Decimal(DecimalDType::new(15, 5), NonNullable) ); - Ok(()) } #[test] - fn least_supertype_incompatible_error() { + fn least_supertype_incompatible_none() { let utf8 = DType::Utf8(NonNullable); let i32_nn = DType::Primitive(PType::I32, NonNullable); - assert!(utf8.least_supertype(&i32_nn).is_err()); + assert!(utf8.least_supertype(&i32_nn).is_none()); } #[test] @@ -491,22 +442,21 @@ mod tests { } #[test] - fn coerce_to_supertype_works() -> VortexResult<()> { + fn coerce_to_supertype_works() { let types = [ DType::Primitive(PType::U8, NonNullable), DType::Primitive(PType::I16, NonNullable), ]; - let result = DType::coerce_to_supertype(&types)?; + let result = DType::coerce_to_supertype(&types).unwrap(); // U8 + I16: unsigned_signed_supertype max_width=max(1,2)=2 => I32 assert_eq!(result, vec![DType::Primitive(PType::I32, NonNullable); 2]); - Ok(()) } #[test] - fn least_supertype_integer_decimal() -> VortexResult<()> { + fn least_supertype_integer_decimal() { let i32_nn = DType::Primitive(PType::I32, NonNullable); let dec = DType::Decimal(DecimalDType::new(15, 5), NonNullable); - let result = i32_nn.least_supertype(&dec)?; + let result = i32_nn.least_supertype(&dec).unwrap(); // int_dec for I32 = Decimal(10, 0). integral digits = 10. // dec integral = 15 - 5 = 10. // max_integral = 10, max_scale = 5, precision = 15 @@ -514,6 +464,5 @@ mod tests { result, DType::Decimal(DecimalDType::new(15, 5), NonNullable) ); - Ok(()) } } diff --git a/vortex-array/src/scalar_fn/fns/binary/mod.rs b/vortex-array/src/scalar_fn/fns/binary/mod.rs index 06f1f08e4ab..7c6a80704ee 100644 --- a/vortex-array/src/scalar_fn/fns/binary/mod.rs +++ b/vortex-array/src/scalar_fn/fns/binary/mod.rs @@ -102,7 +102,9 @@ impl ScalarFnVTable for Binary { let lhs = &args[0]; let rhs = &args[1]; if operator.is_arithmetic() || operator.is_comparison() { - let supertype = lhs.least_supertype(rhs)?; + let supertype = lhs.least_supertype(rhs).ok_or_else(|| { + vortex_error::vortex_err!("No common supertype for {} and {}", lhs, rhs) + })?; Ok(vec![supertype.clone(), supertype]) } else { // Boolean And/Or: no coercion From f24caa32ed87b959325286355c7eee76042a9cc3 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Fri, 13 Mar 2026 15:01:46 -0400 Subject: [PATCH 5/8] Add type coercion to scalar functions and extension types Signed-off-by: Nicholas Gates --- vortex-array/public-api.lock | 166 ++++++++++++++++++++++++++++++++++- 1 file changed, 165 insertions(+), 1 deletion(-) diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index e7e66efbe27..99d9c9a0d73 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -10162,6 +10162,10 @@ pub struct vortex_array::dtype::extension::ExtDTypeRef(_) impl vortex_array::dtype::extension::ExtDTypeRef +pub fn vortex_array::dtype::extension::ExtDTypeRef::can_coerce_from(&self, other: &vortex_array::dtype::DType) -> bool + +pub fn vortex_array::dtype::extension::ExtDTypeRef::can_coerce_to(&self, other: &vortex_array::dtype::DType) -> bool + pub fn vortex_array::dtype::extension::ExtDTypeRef::display_metadata(&self) -> impl core::fmt::Display + '_ pub fn vortex_array::dtype::extension::ExtDTypeRef::eq_ignore_nullability(&self, other: &Self) -> bool @@ -10170,6 +10174,8 @@ pub fn vortex_array::dtype::extension::ExtDTypeRef::id(&self) -> vortex_array::d pub fn vortex_array::dtype::extension::ExtDTypeRef::is_nullable(&self) -> bool +pub fn vortex_array::dtype::extension::ExtDTypeRef::least_supertype(&self, other: &vortex_array::dtype::DType) -> core::option::Option + pub fn vortex_array::dtype::extension::ExtDTypeRef::nullability(&self) -> vortex_array::dtype::Nullability pub fn vortex_array::dtype::extension::ExtDTypeRef::serialize_metadata(&self) -> vortex_error::VortexResult> @@ -10230,10 +10236,16 @@ pub type vortex_array::dtype::extension::ExtVTable::Metadata: 'static + core::ma pub type vortex_array::dtype::extension::ExtVTable::NativeValue<'a>: core::fmt::Display +pub fn vortex_array::dtype::extension::ExtVTable::can_coerce_from(&self, other: &vortex_array::dtype::DType) -> bool + +pub fn vortex_array::dtype::extension::ExtVTable::can_coerce_to(&self, other: &vortex_array::dtype::DType) -> bool + pub fn vortex_array::dtype::extension::ExtVTable::deserialize_metadata(&self, metadata: &[u8]) -> vortex_error::VortexResult pub fn vortex_array::dtype::extension::ExtVTable::id(&self) -> vortex_array::dtype::extension::ExtId +pub fn vortex_array::dtype::extension::ExtVTable::least_supertype(&self, other: &vortex_array::dtype::DType) -> core::option::Option + pub fn vortex_array::dtype::extension::ExtVTable::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> pub fn vortex_array::dtype::extension::ExtVTable::unpack_native<'a>(&self, ext_dtype: &'a vortex_array::dtype::extension::ExtDType, storage_value: &'a vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult @@ -10248,10 +10260,16 @@ pub type vortex_array::extension::datetime::Date::Metadata = vortex_array::exten pub type vortex_array::extension::datetime::Date::NativeValue<'a> = vortex_array::extension::datetime::DateValue +pub fn vortex_array::extension::datetime::Date::can_coerce_from(&self, other: &vortex_array::dtype::DType) -> bool + +pub fn vortex_array::extension::datetime::Date::can_coerce_to(&self, other: &vortex_array::dtype::DType) -> bool + pub fn vortex_array::extension::datetime::Date::deserialize_metadata(&self, metadata: &[u8]) -> vortex_error::VortexResult pub fn vortex_array::extension::datetime::Date::id(&self) -> vortex_array::dtype::extension::ExtId +pub fn vortex_array::extension::datetime::Date::least_supertype(&self, other: &vortex_array::dtype::DType) -> core::option::Option + pub fn vortex_array::extension::datetime::Date::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> pub fn vortex_array::extension::datetime::Date::unpack_native(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult @@ -10266,10 +10284,16 @@ pub type vortex_array::extension::datetime::Time::Metadata = vortex_array::exten pub type vortex_array::extension::datetime::Time::NativeValue<'a> = vortex_array::extension::datetime::TimeValue +pub fn vortex_array::extension::datetime::Time::can_coerce_from(&self, other: &vortex_array::dtype::DType) -> bool + +pub fn vortex_array::extension::datetime::Time::can_coerce_to(&self, other: &vortex_array::dtype::DType) -> bool + pub fn vortex_array::extension::datetime::Time::deserialize_metadata(&self, data: &[u8]) -> vortex_error::VortexResult pub fn vortex_array::extension::datetime::Time::id(&self) -> vortex_array::dtype::extension::ExtId +pub fn vortex_array::extension::datetime::Time::least_supertype(&self, other: &vortex_array::dtype::DType) -> core::option::Option + pub fn vortex_array::extension::datetime::Time::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> pub fn vortex_array::extension::datetime::Time::unpack_native(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult @@ -10284,10 +10308,16 @@ pub type vortex_array::extension::datetime::Timestamp::Metadata = vortex_array:: pub type vortex_array::extension::datetime::Timestamp::NativeValue<'a> = vortex_array::extension::datetime::TimestampValue<'a> +pub fn vortex_array::extension::datetime::Timestamp::can_coerce_from(&self, other: &vortex_array::dtype::DType) -> bool + +pub fn vortex_array::extension::datetime::Timestamp::can_coerce_to(&self, other: &vortex_array::dtype::DType) -> bool + pub fn vortex_array::extension::datetime::Timestamp::deserialize_metadata(&self, data: &[u8]) -> vortex_error::VortexResult pub fn vortex_array::extension::datetime::Timestamp::id(&self) -> vortex_array::dtype::extension::ExtId +pub fn vortex_array::extension::datetime::Timestamp::least_supertype(&self, other: &vortex_array::dtype::DType) -> core::option::Option + pub fn vortex_array::extension::datetime::Timestamp::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> pub fn vortex_array::extension::datetime::Timestamp::unpack_native<'a>(&self, ext_dtype: &'a vortex_array::dtype::extension::ExtDType, storage_value: &'a vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult @@ -10302,10 +10332,16 @@ pub type vortex_array::extension::uuid::Uuid::Metadata = vortex_array::extension pub type vortex_array::extension::uuid::Uuid::NativeValue<'a> = uuid::Uuid +pub fn vortex_array::extension::uuid::Uuid::can_coerce_from(&self, other: &vortex_array::dtype::DType) -> bool + +pub fn vortex_array::extension::uuid::Uuid::can_coerce_to(&self, other: &vortex_array::dtype::DType) -> bool + pub fn vortex_array::extension::uuid::Uuid::deserialize_metadata(&self, metadata: &[u8]) -> vortex_error::VortexResult pub fn vortex_array::extension::uuid::Uuid::id(&self) -> vortex_array::dtype::extension::ExtId +pub fn vortex_array::extension::uuid::Uuid::least_supertype(&self, other: &vortex_array::dtype::DType) -> core::option::Option + pub fn vortex_array::extension::uuid::Uuid::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> pub fn vortex_array::extension::uuid::Uuid::unpack_native<'a>(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, storage_value: &'a vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult @@ -10490,6 +10526,28 @@ pub fn vortex_array::dtype::DType::with_nullability(&self, nullability: vortex_a impl vortex_array::dtype::DType +pub fn vortex_array::dtype::DType::all_coercible_to(types: &[vortex_array::dtype::DType], target: &vortex_array::dtype::DType) -> bool + +pub fn vortex_array::dtype::DType::are_coercible(types: &[vortex_array::dtype::DType]) -> bool + +pub fn vortex_array::dtype::DType::can_coerce_from(&self, other: &vortex_array::dtype::DType) -> bool + +pub fn vortex_array::dtype::DType::can_coerce_to(&self, other: &vortex_array::dtype::DType) -> bool + +pub fn vortex_array::dtype::DType::coerce_all_to(types: &[vortex_array::dtype::DType], target: &vortex_array::dtype::DType) -> core::option::Option> + +pub fn vortex_array::dtype::DType::coerce_to_supertype(types: &[vortex_array::dtype::DType]) -> core::option::Option> + +pub fn vortex_array::dtype::DType::is_numeric(&self) -> bool + +pub fn vortex_array::dtype::DType::is_temporal(&self) -> bool + +pub fn vortex_array::dtype::DType::least_supertype(&self, other: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::dtype::DType::least_supertype_of(types: &[vortex_array::dtype::DType]) -> core::option::Option + +impl vortex_array::dtype::DType + pub fn vortex_array::dtype::DType::from_flatbuffer(buffer: vortex_flatbuffers::FlatBuffer, session: &vortex_session::VortexSession) -> vortex_error::VortexResult impl vortex_array::dtype::DType @@ -10886,6 +10944,10 @@ pub fn vortex_array::dtype::PType::from_i32(value: i32) -> core::option::Option< pub fn vortex_array::dtype::PType::is_valid(value: i32) -> bool +impl vortex_array::dtype::PType + +pub fn vortex_array::dtype::PType::least_supertype(self, other: vortex_array::dtype::PType) -> core::option::Option + impl core::clone::Clone for vortex_array::dtype::PType pub fn vortex_array::dtype::PType::clone(&self) -> vortex_array::dtype::PType @@ -13398,6 +13460,8 @@ impl core::fmt::Display for vortex_array::expr::transform pub fn vortex_array::expr::transform::PartitionedExpr::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result +pub fn vortex_array::expr::transform::coerce_expression(expr: vortex_array::expr::Expression, scope: &vortex_array::dtype::DType) -> vortex_error::VortexResult + pub fn vortex_array::expr::transform::partition(expr: vortex_array::expr::Expression, scope: &vortex_array::dtype::DType, annotate_fn: A) -> vortex_error::VortexResult::Annotation>> where ::Annotation: core::fmt::Display, vortex_array::dtype::FieldName: core::convert::From<::Annotation> pub fn vortex_array::expr::transform::replace(expr: vortex_array::expr::Expression, needle: &vortex_array::expr::Expression, replacement: vortex_array::expr::Expression) -> vortex_array::expr::Expression @@ -14134,10 +14198,16 @@ pub type vortex_array::extension::datetime::Date::Metadata = vortex_array::exten pub type vortex_array::extension::datetime::Date::NativeValue<'a> = vortex_array::extension::datetime::DateValue +pub fn vortex_array::extension::datetime::Date::can_coerce_from(&self, other: &vortex_array::dtype::DType) -> bool + +pub fn vortex_array::extension::datetime::Date::can_coerce_to(&self, other: &vortex_array::dtype::DType) -> bool + pub fn vortex_array::extension::datetime::Date::deserialize_metadata(&self, metadata: &[u8]) -> vortex_error::VortexResult pub fn vortex_array::extension::datetime::Date::id(&self) -> vortex_array::dtype::extension::ExtId +pub fn vortex_array::extension::datetime::Date::least_supertype(&self, other: &vortex_array::dtype::DType) -> core::option::Option + pub fn vortex_array::extension::datetime::Date::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> pub fn vortex_array::extension::datetime::Date::unpack_native(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult @@ -14184,10 +14254,16 @@ pub type vortex_array::extension::datetime::Time::Metadata = vortex_array::exten pub type vortex_array::extension::datetime::Time::NativeValue<'a> = vortex_array::extension::datetime::TimeValue +pub fn vortex_array::extension::datetime::Time::can_coerce_from(&self, other: &vortex_array::dtype::DType) -> bool + +pub fn vortex_array::extension::datetime::Time::can_coerce_to(&self, other: &vortex_array::dtype::DType) -> bool + pub fn vortex_array::extension::datetime::Time::deserialize_metadata(&self, data: &[u8]) -> vortex_error::VortexResult pub fn vortex_array::extension::datetime::Time::id(&self) -> vortex_array::dtype::extension::ExtId +pub fn vortex_array::extension::datetime::Time::least_supertype(&self, other: &vortex_array::dtype::DType) -> core::option::Option + pub fn vortex_array::extension::datetime::Time::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> pub fn vortex_array::extension::datetime::Time::unpack_native(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, storage_value: &vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult @@ -14236,10 +14312,16 @@ pub type vortex_array::extension::datetime::Timestamp::Metadata = vortex_array:: pub type vortex_array::extension::datetime::Timestamp::NativeValue<'a> = vortex_array::extension::datetime::TimestampValue<'a> +pub fn vortex_array::extension::datetime::Timestamp::can_coerce_from(&self, other: &vortex_array::dtype::DType) -> bool + +pub fn vortex_array::extension::datetime::Timestamp::can_coerce_to(&self, other: &vortex_array::dtype::DType) -> bool + pub fn vortex_array::extension::datetime::Timestamp::deserialize_metadata(&self, data: &[u8]) -> vortex_error::VortexResult pub fn vortex_array::extension::datetime::Timestamp::id(&self) -> vortex_array::dtype::extension::ExtId +pub fn vortex_array::extension::datetime::Timestamp::least_supertype(&self, other: &vortex_array::dtype::DType) -> core::option::Option + pub fn vortex_array::extension::datetime::Timestamp::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> pub fn vortex_array::extension::datetime::Timestamp::unpack_native<'a>(&self, ext_dtype: &'a vortex_array::dtype::extension::ExtDType, storage_value: &'a vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult @@ -14312,10 +14394,16 @@ pub type vortex_array::extension::uuid::Uuid::Metadata = vortex_array::extension pub type vortex_array::extension::uuid::Uuid::NativeValue<'a> = uuid::Uuid +pub fn vortex_array::extension::uuid::Uuid::can_coerce_from(&self, other: &vortex_array::dtype::DType) -> bool + +pub fn vortex_array::extension::uuid::Uuid::can_coerce_to(&self, other: &vortex_array::dtype::DType) -> bool + pub fn vortex_array::extension::uuid::Uuid::deserialize_metadata(&self, metadata: &[u8]) -> vortex_error::VortexResult pub fn vortex_array::extension::uuid::Uuid::id(&self) -> vortex_array::dtype::extension::ExtId +pub fn vortex_array::extension::uuid::Uuid::least_supertype(&self, other: &vortex_array::dtype::DType) -> core::option::Option + pub fn vortex_array::extension::uuid::Uuid::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> pub fn vortex_array::extension::uuid::Uuid::unpack_native<'a>(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, storage_value: &'a vortex_array::scalar::ScalarValue) -> vortex_error::VortexResult @@ -16828,6 +16916,8 @@ pub fn vortex_array::scalar_fn::fns::between::Between::arity(&self, _options: &S pub fn vortex_array::scalar_fn::fns::between::Between::child_name(&self, _instance: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::between::Between::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::between::Between::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::between::Between::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -16958,6 +17048,8 @@ pub fn vortex_array::scalar_fn::fns::binary::Binary::arity(&self, _options: &Sel pub fn vortex_array::scalar_fn::fns::binary::Binary::child_name(&self, _instance: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::binary::Binary::coerce_args(&self, operator: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::binary::Binary::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::binary::Binary::execute(&self, op: &vortex_array::scalar_fn::fns::operators::Operator, args: &dyn vortex_array::scalar_fn::ExecutionArgs, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -17042,6 +17134,8 @@ pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::arity(&self, options: pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::child_name(&self, options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::deserialize(&self, metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -17122,6 +17216,8 @@ pub fn vortex_array::scalar_fn::fns::cast::Cast::arity(&self, _options: &vortex_ pub fn vortex_array::scalar_fn::fns::cast::Cast::child_name(&self, _instance: &vortex_array::dtype::DType, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::cast::Cast::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::cast::Cast::deserialize(&self, _metadata: &[u8], session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::cast::Cast::execute(&self, target_dtype: &vortex_array::dtype::DType, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -17262,6 +17358,8 @@ pub fn vortex_array::scalar_fn::fns::dynamic::DynamicComparison::arity(&self, _o pub fn vortex_array::scalar_fn::fns::dynamic::DynamicComparison::child_name(&self, _instance: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::dynamic::DynamicComparison::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::dynamic::DynamicComparison::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::dynamic::DynamicComparison::execute(&self, data: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -17342,6 +17440,8 @@ pub fn vortex_array::scalar_fn::fns::fill_null::FillNull::arity(&self, _options: pub fn vortex_array::scalar_fn::fns::fill_null::FillNull::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::fill_null::FillNull::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::fill_null::FillNull::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::fill_null::FillNull::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -17450,6 +17550,8 @@ pub fn vortex_array::scalar_fn::fns::get_item::GetItem::arity(&self, _field_name pub fn vortex_array::scalar_fn::fns::get_item::GetItem::child_name(&self, _instance: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::get_item::GetItem::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::get_item::GetItem::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::get_item::GetItem::execute(&self, field_name: &vortex_array::dtype::FieldName, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -17494,6 +17596,8 @@ pub fn vortex_array::scalar_fn::fns::is_null::IsNull::arity(&self, _options: &Se pub fn vortex_array::scalar_fn::fns::is_null::IsNull::child_name(&self, _instance: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::is_null::IsNull::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::is_null::IsNull::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::is_null::IsNull::execute(&self, _data: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -17538,6 +17642,8 @@ pub fn vortex_array::scalar_fn::fns::like::Like::arity(&self, _options: &Self::O pub fn vortex_array::scalar_fn::fns::like::Like::child_name(&self, _instance: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::like::Like::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::like::Like::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::like::Like::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -17662,6 +17768,8 @@ pub fn vortex_array::scalar_fn::fns::list_contains::ListContains::arity(&self, _ pub fn vortex_array::scalar_fn::fns::list_contains::ListContains::child_name(&self, _instance: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::list_contains::ListContains::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::list_contains::ListContains::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::list_contains::ListContains::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -17746,6 +17854,8 @@ pub fn vortex_array::scalar_fn::fns::literal::Literal::arity(&self, _options: &S pub fn vortex_array::scalar_fn::fns::literal::Literal::child_name(&self, _instance: &Self::Options, _child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::literal::Literal::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::literal::Literal::deserialize(&self, _metadata: &[u8], session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::literal::Literal::execute(&self, scalar: &vortex_array::scalar::Scalar, args: &dyn vortex_array::scalar_fn::ExecutionArgs, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -17790,6 +17900,8 @@ pub fn vortex_array::scalar_fn::fns::mask::Mask::arity(&self, _options: &Self::O pub fn vortex_array::scalar_fn::fns::mask::Mask::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::mask::Mask::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::mask::Mask::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::mask::Mask::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -17966,6 +18078,8 @@ pub fn vortex_array::scalar_fn::fns::merge::Merge::arity(&self, _options: &Self: pub fn vortex_array::scalar_fn::fns::merge::Merge::child_name(&self, _instance: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::merge::Merge::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::merge::Merge::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::merge::Merge::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -18010,6 +18124,8 @@ pub fn vortex_array::scalar_fn::fns::not::Not::arity(&self, _options: &Self::Opt pub fn vortex_array::scalar_fn::fns::not::Not::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::not::Not::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::not::Not::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::not::Not::execute(&self, _data: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -18260,6 +18376,8 @@ pub fn vortex_array::scalar_fn::fns::pack::Pack::arity(&self, options: &Self::Op pub fn vortex_array::scalar_fn::fns::pack::Pack::child_name(&self, instance: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::pack::Pack::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::pack::Pack::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::pack::Pack::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -18334,6 +18452,8 @@ pub fn vortex_array::scalar_fn::fns::root::Root::arity(&self, _options: &Self::O pub fn vortex_array::scalar_fn::fns::root::Root::child_name(&self, _instance: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::root::Root::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::root::Root::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::root::Root::execute(&self, _data: &Self::Options, _args: &dyn vortex_array::scalar_fn::ExecutionArgs, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -18422,6 +18542,8 @@ pub fn vortex_array::scalar_fn::fns::select::Select::arity(&self, _options: &vor pub fn vortex_array::scalar_fn::fns::select::Select::child_name(&self, _instance: &vortex_array::scalar_fn::fns::select::FieldSelection, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::select::Select::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::select::Select::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::select::Select::execute(&self, selection: &vortex_array::scalar_fn::fns::select::FieldSelection, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -18466,6 +18588,8 @@ pub fn vortex_array::scalar_fn::fns::zip::Zip::arity(&self, _options: &Self::Opt pub fn vortex_array::scalar_fn::fns::zip::Zip::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::zip::Zip::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::zip::Zip::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::zip::Zip::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -18684,6 +18808,8 @@ pub fn vortex_array::scalar_fn::ScalarFnRef::as_(&self) -> core::option::Option<&::Options> +pub fn vortex_array::scalar_fn::ScalarFnRef::coerce_args(&self, arg_types: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::ScalarFnRef::downcast(self) -> alloc::sync::Arc> pub fn vortex_array::scalar_fn::ScalarFnRef::downcast_ref(&self) -> core::option::Option<&vortex_array::scalar_fn::ScalarFn> @@ -18842,6 +18968,8 @@ pub fn vortex_array::scalar_fn::ScalarFnVTable::arity(&self, options: &Self::Opt pub fn vortex_array::scalar_fn::ScalarFnVTable::child_name(&self, options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::ScalarFnVTable::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::ScalarFnVTable::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::ScalarFnVTable::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -18856,7 +18984,7 @@ pub fn vortex_array::scalar_fn::ScalarFnVTable::is_null_sensitive(&self, options pub fn vortex_array::scalar_fn::ScalarFnVTable::reduce(&self, options: &Self::Options, node: &dyn vortex_array::scalar_fn::ReduceNode, ctx: &dyn vortex_array::scalar_fn::ReduceCtx) -> vortex_error::VortexResult> -pub fn vortex_array::scalar_fn::ScalarFnVTable::return_dtype(&self, options: &Self::Options, arg_dtypes: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult +pub fn vortex_array::scalar_fn::ScalarFnVTable::return_dtype(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::ScalarFnVTable::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult>> @@ -18878,6 +19006,8 @@ pub fn vortex_array::scalar_fn::fns::between::Between::arity(&self, _options: &S pub fn vortex_array::scalar_fn::fns::between::Between::child_name(&self, _instance: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::between::Between::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::between::Between::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::between::Between::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -18914,6 +19044,8 @@ pub fn vortex_array::scalar_fn::fns::binary::Binary::arity(&self, _options: &Sel pub fn vortex_array::scalar_fn::fns::binary::Binary::child_name(&self, _instance: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::binary::Binary::coerce_args(&self, operator: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::binary::Binary::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::binary::Binary::execute(&self, op: &vortex_array::scalar_fn::fns::operators::Operator, args: &dyn vortex_array::scalar_fn::ExecutionArgs, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -18950,6 +19082,8 @@ pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::arity(&self, options: pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::child_name(&self, options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::deserialize(&self, metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::case_when::CaseWhen::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -18986,6 +19120,8 @@ pub fn vortex_array::scalar_fn::fns::cast::Cast::arity(&self, _options: &vortex_ pub fn vortex_array::scalar_fn::fns::cast::Cast::child_name(&self, _instance: &vortex_array::dtype::DType, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::cast::Cast::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::cast::Cast::deserialize(&self, _metadata: &[u8], session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::cast::Cast::execute(&self, target_dtype: &vortex_array::dtype::DType, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -19022,6 +19158,8 @@ pub fn vortex_array::scalar_fn::fns::dynamic::DynamicComparison::arity(&self, _o pub fn vortex_array::scalar_fn::fns::dynamic::DynamicComparison::child_name(&self, _instance: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::dynamic::DynamicComparison::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::dynamic::DynamicComparison::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::dynamic::DynamicComparison::execute(&self, data: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -19058,6 +19196,8 @@ pub fn vortex_array::scalar_fn::fns::fill_null::FillNull::arity(&self, _options: pub fn vortex_array::scalar_fn::fns::fill_null::FillNull::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::fill_null::FillNull::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::fill_null::FillNull::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::fill_null::FillNull::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -19094,6 +19234,8 @@ pub fn vortex_array::scalar_fn::fns::get_item::GetItem::arity(&self, _field_name pub fn vortex_array::scalar_fn::fns::get_item::GetItem::child_name(&self, _instance: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::get_item::GetItem::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::get_item::GetItem::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::get_item::GetItem::execute(&self, field_name: &vortex_array::dtype::FieldName, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -19130,6 +19272,8 @@ pub fn vortex_array::scalar_fn::fns::is_null::IsNull::arity(&self, _options: &Se pub fn vortex_array::scalar_fn::fns::is_null::IsNull::child_name(&self, _instance: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::is_null::IsNull::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::is_null::IsNull::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::is_null::IsNull::execute(&self, _data: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -19166,6 +19310,8 @@ pub fn vortex_array::scalar_fn::fns::like::Like::arity(&self, _options: &Self::O pub fn vortex_array::scalar_fn::fns::like::Like::child_name(&self, _instance: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::like::Like::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::like::Like::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::like::Like::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -19202,6 +19348,8 @@ pub fn vortex_array::scalar_fn::fns::list_contains::ListContains::arity(&self, _ pub fn vortex_array::scalar_fn::fns::list_contains::ListContains::child_name(&self, _instance: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::list_contains::ListContains::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::list_contains::ListContains::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::list_contains::ListContains::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -19238,6 +19386,8 @@ pub fn vortex_array::scalar_fn::fns::literal::Literal::arity(&self, _options: &S pub fn vortex_array::scalar_fn::fns::literal::Literal::child_name(&self, _instance: &Self::Options, _child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::literal::Literal::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::literal::Literal::deserialize(&self, _metadata: &[u8], session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::literal::Literal::execute(&self, scalar: &vortex_array::scalar::Scalar, args: &dyn vortex_array::scalar_fn::ExecutionArgs, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -19274,6 +19424,8 @@ pub fn vortex_array::scalar_fn::fns::mask::Mask::arity(&self, _options: &Self::O pub fn vortex_array::scalar_fn::fns::mask::Mask::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::mask::Mask::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::mask::Mask::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::mask::Mask::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -19310,6 +19462,8 @@ pub fn vortex_array::scalar_fn::fns::merge::Merge::arity(&self, _options: &Self: pub fn vortex_array::scalar_fn::fns::merge::Merge::child_name(&self, _instance: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::merge::Merge::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::merge::Merge::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::merge::Merge::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -19346,6 +19500,8 @@ pub fn vortex_array::scalar_fn::fns::not::Not::arity(&self, _options: &Self::Opt pub fn vortex_array::scalar_fn::fns::not::Not::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::not::Not::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::not::Not::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::not::Not::execute(&self, _data: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -19382,6 +19538,8 @@ pub fn vortex_array::scalar_fn::fns::pack::Pack::arity(&self, options: &Self::Op pub fn vortex_array::scalar_fn::fns::pack::Pack::child_name(&self, instance: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::pack::Pack::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::pack::Pack::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::pack::Pack::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -19418,6 +19576,8 @@ pub fn vortex_array::scalar_fn::fns::root::Root::arity(&self, _options: &Self::O pub fn vortex_array::scalar_fn::fns::root::Root::child_name(&self, _instance: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::root::Root::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::root::Root::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::root::Root::execute(&self, _data: &Self::Options, _args: &dyn vortex_array::scalar_fn::ExecutionArgs, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -19454,6 +19614,8 @@ pub fn vortex_array::scalar_fn::fns::select::Select::arity(&self, _options: &vor pub fn vortex_array::scalar_fn::fns::select::Select::child_name(&self, _instance: &vortex_array::scalar_fn::fns::select::FieldSelection, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::select::Select::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::select::Select::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::select::Select::execute(&self, selection: &vortex_array::scalar_fn::fns::select::FieldSelection, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult @@ -19490,6 +19652,8 @@ pub fn vortex_array::scalar_fn::fns::zip::Zip::arity(&self, _options: &Self::Opt pub fn vortex_array::scalar_fn::fns::zip::Zip::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::ChildName +pub fn vortex_array::scalar_fn::fns::zip::Zip::coerce_args(&self, options: &Self::Options, args: &[vortex_array::dtype::DType]) -> vortex_error::VortexResult> + pub fn vortex_array::scalar_fn::fns::zip::Zip::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_array::scalar_fn::fns::zip::Zip::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::ExecutionArgs, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult From 3250b1a81cc2d9b183b0e7b55cee73137c179537 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Fri, 13 Mar 2026 15:12:49 -0400 Subject: [PATCH 6/8] Add type coercion to scalar functions and extension types Signed-off-by: Nicholas Gates --- vortex-array/src/aggregate_fn/erased.rs | 5 +++++ vortex-array/src/aggregate_fn/typed.rs | 5 +++++ vortex-array/src/aggregate_fn/vtable.rs | 9 +++++++++ 3 files changed, 19 insertions(+) diff --git a/vortex-array/src/aggregate_fn/erased.rs b/vortex-array/src/aggregate_fn/erased.rs index 0d90a395499..78f182b6343 100644 --- a/vortex-array/src/aggregate_fn/erased.rs +++ b/vortex-array/src/aggregate_fn/erased.rs @@ -75,6 +75,11 @@ impl AggregateFnRef { AggregateFnOptions { inner: &*self.0 } } + /// Coerce the input type for this aggregate function. + pub fn coerce_args(&self, input_dtype: &DType) -> VortexResult { + self.0.coerce_args(input_dtype) + } + /// Compute the return [`DType`] per group given the input element type. pub fn return_dtype(&self, input_dtype: &DType) -> VortexResult { self.0.return_dtype(input_dtype) diff --git a/vortex-array/src/aggregate_fn/typed.rs b/vortex-array/src/aggregate_fn/typed.rs index fce90ecf33c..f8521738c69 100644 --- a/vortex-array/src/aggregate_fn/typed.rs +++ b/vortex-array/src/aggregate_fn/typed.rs @@ -39,6 +39,7 @@ pub(super) trait DynAggregateFn: 'static + Send + Sync + super::sealed::Sealed { fn id(&self) -> AggregateFnId; fn options_any(&self) -> &dyn Any; + fn coerce_args(&self, input_dtype: &DType) -> VortexResult; fn return_dtype(&self, input_dtype: &DType) -> VortexResult; fn state_dtype(&self, input_dtype: &DType) -> VortexResult; fn accumulator( @@ -84,6 +85,10 @@ impl DynAggregateFn for AggregateFnInner { &self.options } + fn coerce_args(&self, input_dtype: &DType) -> VortexResult { + V::coerce_args(&self.vtable, &self.options, input_dtype) + } + fn return_dtype(&self, input_dtype: &DType) -> VortexResult { V::return_dtype(&self.vtable, &self.options, input_dtype) } diff --git a/vortex-array/src/aggregate_fn/vtable.rs b/vortex-array/src/aggregate_fn/vtable.rs index 0e45e8a54fd..64db588d336 100644 --- a/vortex-array/src/aggregate_fn/vtable.rs +++ b/vortex-array/src/aggregate_fn/vtable.rs @@ -60,6 +60,15 @@ pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync { vortex_bail!("Aggregate function {} is not deserializable", self.id()); } + /// Coerce the input type for this aggregate function. + /// + /// This is optionally used by Vortex users when performing type coercion over a Vortex + /// expression. The default implementation returns the input type unchanged. + fn coerce_args(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult { + let _ = options; + Ok(input_dtype.clone()) + } + /// The return [`DType`] of the aggregate. fn return_dtype(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult; From c6b6da79e0e1594cb441fd285e728460ae23a14e Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Fri, 13 Mar 2026 15:15:49 -0400 Subject: [PATCH 7/8] Add type coercion to scalar functions and extension types Signed-off-by: Nicholas Gates --- vortex-array/public-api.lock | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 99d9c9a0d73..a54f285430a 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -60,6 +60,8 @@ pub type vortex_array::aggregate_fn::fns::sum::Sum::Partial = vortex_array::aggr pub fn vortex_array::aggregate_fn::fns::sum::Sum::accumulate(&self, partial: &mut Self::Partial, batch: &vortex_array::Columnar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> +pub fn vortex_array::aggregate_fn::fns::sum::Sum::coerce_args(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + pub fn vortex_array::aggregate_fn::fns::sum::Sum::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()> pub fn vortex_array::aggregate_fn::fns::sum::Sum::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult @@ -194,6 +196,8 @@ pub fn vortex_array::aggregate_fn::AggregateFnRef::as_(&self) -> core::option::Option<&::Options> +pub fn vortex_array::aggregate_fn::AggregateFnRef::coerce_args(&self, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + pub fn vortex_array::aggregate_fn::AggregateFnRef::id(&self) -> vortex_array::aggregate_fn::AggregateFnId pub fn vortex_array::aggregate_fn::AggregateFnRef::is(&self) -> bool @@ -288,6 +292,8 @@ pub type vortex_array::aggregate_fn::AggregateFnVTable::Partial: 'static + core: pub fn vortex_array::aggregate_fn::AggregateFnVTable::accumulate(&self, state: &mut Self::Partial, batch: &vortex_array::Columnar, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> +pub fn vortex_array::aggregate_fn::AggregateFnVTable::coerce_args(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + pub fn vortex_array::aggregate_fn::AggregateFnVTable::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()> pub fn vortex_array::aggregate_fn::AggregateFnVTable::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult @@ -318,6 +324,8 @@ pub type vortex_array::aggregate_fn::fns::sum::Sum::Partial = vortex_array::aggr pub fn vortex_array::aggregate_fn::fns::sum::Sum::accumulate(&self, partial: &mut Self::Partial, batch: &vortex_array::Columnar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> +pub fn vortex_array::aggregate_fn::fns::sum::Sum::coerce_args(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + pub fn vortex_array::aggregate_fn::fns::sum::Sum::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()> pub fn vortex_array::aggregate_fn::fns::sum::Sum::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult From f5d2adf2fe71516f35713594f75ed5ca076845f4 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Fri, 13 Mar 2026 15:38:08 -0400 Subject: [PATCH 8/8] Add type coercion to scalar functions and extension types Signed-off-by: Nicholas Gates --- vortex-array/public-api.lock | 54 ++++++------ vortex-array/src/dtype/extension/typed.rs | 6 +- vortex-array/src/dtype/extension/vtable.rs | 18 ++-- vortex-array/src/extension/datetime/date.rs | 44 ++++++++++ vortex-array/src/extension/datetime/time.rs | 33 +++++++ .../src/extension/datetime/timestamp.rs | 86 +++++++++++++++++++ 6 files changed, 205 insertions(+), 36 deletions(-) diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index a54f285430a..ed872aa095a 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -10244,15 +10244,15 @@ pub type vortex_array::dtype::extension::ExtVTable::Metadata: 'static + core::ma pub type vortex_array::dtype::extension::ExtVTable::NativeValue<'a>: core::fmt::Display -pub fn vortex_array::dtype::extension::ExtVTable::can_coerce_from(&self, other: &vortex_array::dtype::DType) -> bool +pub fn vortex_array::dtype::extension::ExtVTable::can_coerce_from(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool -pub fn vortex_array::dtype::extension::ExtVTable::can_coerce_to(&self, other: &vortex_array::dtype::DType) -> bool +pub fn vortex_array::dtype::extension::ExtVTable::can_coerce_to(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool pub fn vortex_array::dtype::extension::ExtVTable::deserialize_metadata(&self, metadata: &[u8]) -> vortex_error::VortexResult pub fn vortex_array::dtype::extension::ExtVTable::id(&self) -> vortex_array::dtype::extension::ExtId -pub fn vortex_array::dtype::extension::ExtVTable::least_supertype(&self, other: &vortex_array::dtype::DType) -> core::option::Option +pub fn vortex_array::dtype::extension::ExtVTable::least_supertype(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option pub fn vortex_array::dtype::extension::ExtVTable::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> @@ -10268,15 +10268,15 @@ pub type vortex_array::extension::datetime::Date::Metadata = vortex_array::exten pub type vortex_array::extension::datetime::Date::NativeValue<'a> = vortex_array::extension::datetime::DateValue -pub fn vortex_array::extension::datetime::Date::can_coerce_from(&self, other: &vortex_array::dtype::DType) -> bool +pub fn vortex_array::extension::datetime::Date::can_coerce_from(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool -pub fn vortex_array::extension::datetime::Date::can_coerce_to(&self, other: &vortex_array::dtype::DType) -> bool +pub fn vortex_array::extension::datetime::Date::can_coerce_to(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool pub fn vortex_array::extension::datetime::Date::deserialize_metadata(&self, metadata: &[u8]) -> vortex_error::VortexResult pub fn vortex_array::extension::datetime::Date::id(&self) -> vortex_array::dtype::extension::ExtId -pub fn vortex_array::extension::datetime::Date::least_supertype(&self, other: &vortex_array::dtype::DType) -> core::option::Option +pub fn vortex_array::extension::datetime::Date::least_supertype(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option pub fn vortex_array::extension::datetime::Date::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> @@ -10292,15 +10292,15 @@ pub type vortex_array::extension::datetime::Time::Metadata = vortex_array::exten pub type vortex_array::extension::datetime::Time::NativeValue<'a> = vortex_array::extension::datetime::TimeValue -pub fn vortex_array::extension::datetime::Time::can_coerce_from(&self, other: &vortex_array::dtype::DType) -> bool +pub fn vortex_array::extension::datetime::Time::can_coerce_from(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool -pub fn vortex_array::extension::datetime::Time::can_coerce_to(&self, other: &vortex_array::dtype::DType) -> bool +pub fn vortex_array::extension::datetime::Time::can_coerce_to(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool pub fn vortex_array::extension::datetime::Time::deserialize_metadata(&self, data: &[u8]) -> vortex_error::VortexResult pub fn vortex_array::extension::datetime::Time::id(&self) -> vortex_array::dtype::extension::ExtId -pub fn vortex_array::extension::datetime::Time::least_supertype(&self, other: &vortex_array::dtype::DType) -> core::option::Option +pub fn vortex_array::extension::datetime::Time::least_supertype(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option pub fn vortex_array::extension::datetime::Time::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> @@ -10316,15 +10316,15 @@ pub type vortex_array::extension::datetime::Timestamp::Metadata = vortex_array:: pub type vortex_array::extension::datetime::Timestamp::NativeValue<'a> = vortex_array::extension::datetime::TimestampValue<'a> -pub fn vortex_array::extension::datetime::Timestamp::can_coerce_from(&self, other: &vortex_array::dtype::DType) -> bool +pub fn vortex_array::extension::datetime::Timestamp::can_coerce_from(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool -pub fn vortex_array::extension::datetime::Timestamp::can_coerce_to(&self, other: &vortex_array::dtype::DType) -> bool +pub fn vortex_array::extension::datetime::Timestamp::can_coerce_to(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool pub fn vortex_array::extension::datetime::Timestamp::deserialize_metadata(&self, data: &[u8]) -> vortex_error::VortexResult pub fn vortex_array::extension::datetime::Timestamp::id(&self) -> vortex_array::dtype::extension::ExtId -pub fn vortex_array::extension::datetime::Timestamp::least_supertype(&self, other: &vortex_array::dtype::DType) -> core::option::Option +pub fn vortex_array::extension::datetime::Timestamp::least_supertype(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option pub fn vortex_array::extension::datetime::Timestamp::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> @@ -10340,15 +10340,15 @@ pub type vortex_array::extension::uuid::Uuid::Metadata = vortex_array::extension pub type vortex_array::extension::uuid::Uuid::NativeValue<'a> = uuid::Uuid -pub fn vortex_array::extension::uuid::Uuid::can_coerce_from(&self, other: &vortex_array::dtype::DType) -> bool +pub fn vortex_array::extension::uuid::Uuid::can_coerce_from(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool -pub fn vortex_array::extension::uuid::Uuid::can_coerce_to(&self, other: &vortex_array::dtype::DType) -> bool +pub fn vortex_array::extension::uuid::Uuid::can_coerce_to(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool pub fn vortex_array::extension::uuid::Uuid::deserialize_metadata(&self, metadata: &[u8]) -> vortex_error::VortexResult pub fn vortex_array::extension::uuid::Uuid::id(&self) -> vortex_array::dtype::extension::ExtId -pub fn vortex_array::extension::uuid::Uuid::least_supertype(&self, other: &vortex_array::dtype::DType) -> core::option::Option +pub fn vortex_array::extension::uuid::Uuid::least_supertype(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option pub fn vortex_array::extension::uuid::Uuid::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> @@ -14206,15 +14206,15 @@ pub type vortex_array::extension::datetime::Date::Metadata = vortex_array::exten pub type vortex_array::extension::datetime::Date::NativeValue<'a> = vortex_array::extension::datetime::DateValue -pub fn vortex_array::extension::datetime::Date::can_coerce_from(&self, other: &vortex_array::dtype::DType) -> bool +pub fn vortex_array::extension::datetime::Date::can_coerce_from(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool -pub fn vortex_array::extension::datetime::Date::can_coerce_to(&self, other: &vortex_array::dtype::DType) -> bool +pub fn vortex_array::extension::datetime::Date::can_coerce_to(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool pub fn vortex_array::extension::datetime::Date::deserialize_metadata(&self, metadata: &[u8]) -> vortex_error::VortexResult pub fn vortex_array::extension::datetime::Date::id(&self) -> vortex_array::dtype::extension::ExtId -pub fn vortex_array::extension::datetime::Date::least_supertype(&self, other: &vortex_array::dtype::DType) -> core::option::Option +pub fn vortex_array::extension::datetime::Date::least_supertype(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option pub fn vortex_array::extension::datetime::Date::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> @@ -14262,15 +14262,15 @@ pub type vortex_array::extension::datetime::Time::Metadata = vortex_array::exten pub type vortex_array::extension::datetime::Time::NativeValue<'a> = vortex_array::extension::datetime::TimeValue -pub fn vortex_array::extension::datetime::Time::can_coerce_from(&self, other: &vortex_array::dtype::DType) -> bool +pub fn vortex_array::extension::datetime::Time::can_coerce_from(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool -pub fn vortex_array::extension::datetime::Time::can_coerce_to(&self, other: &vortex_array::dtype::DType) -> bool +pub fn vortex_array::extension::datetime::Time::can_coerce_to(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool pub fn vortex_array::extension::datetime::Time::deserialize_metadata(&self, data: &[u8]) -> vortex_error::VortexResult pub fn vortex_array::extension::datetime::Time::id(&self) -> vortex_array::dtype::extension::ExtId -pub fn vortex_array::extension::datetime::Time::least_supertype(&self, other: &vortex_array::dtype::DType) -> core::option::Option +pub fn vortex_array::extension::datetime::Time::least_supertype(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option pub fn vortex_array::extension::datetime::Time::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> @@ -14320,15 +14320,15 @@ pub type vortex_array::extension::datetime::Timestamp::Metadata = vortex_array:: pub type vortex_array::extension::datetime::Timestamp::NativeValue<'a> = vortex_array::extension::datetime::TimestampValue<'a> -pub fn vortex_array::extension::datetime::Timestamp::can_coerce_from(&self, other: &vortex_array::dtype::DType) -> bool +pub fn vortex_array::extension::datetime::Timestamp::can_coerce_from(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool -pub fn vortex_array::extension::datetime::Timestamp::can_coerce_to(&self, other: &vortex_array::dtype::DType) -> bool +pub fn vortex_array::extension::datetime::Timestamp::can_coerce_to(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool pub fn vortex_array::extension::datetime::Timestamp::deserialize_metadata(&self, data: &[u8]) -> vortex_error::VortexResult pub fn vortex_array::extension::datetime::Timestamp::id(&self) -> vortex_array::dtype::extension::ExtId -pub fn vortex_array::extension::datetime::Timestamp::least_supertype(&self, other: &vortex_array::dtype::DType) -> core::option::Option +pub fn vortex_array::extension::datetime::Timestamp::least_supertype(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option pub fn vortex_array::extension::datetime::Timestamp::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> @@ -14402,15 +14402,15 @@ pub type vortex_array::extension::uuid::Uuid::Metadata = vortex_array::extension pub type vortex_array::extension::uuid::Uuid::NativeValue<'a> = uuid::Uuid -pub fn vortex_array::extension::uuid::Uuid::can_coerce_from(&self, other: &vortex_array::dtype::DType) -> bool +pub fn vortex_array::extension::uuid::Uuid::can_coerce_from(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool -pub fn vortex_array::extension::uuid::Uuid::can_coerce_to(&self, other: &vortex_array::dtype::DType) -> bool +pub fn vortex_array::extension::uuid::Uuid::can_coerce_to(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> bool pub fn vortex_array::extension::uuid::Uuid::deserialize_metadata(&self, metadata: &[u8]) -> vortex_error::VortexResult pub fn vortex_array::extension::uuid::Uuid::id(&self) -> vortex_array::dtype::extension::ExtId -pub fn vortex_array::extension::uuid::Uuid::least_supertype(&self, other: &vortex_array::dtype::DType) -> core::option::Option +pub fn vortex_array::extension::uuid::Uuid::least_supertype(&self, ext_dtype: &vortex_array::dtype::extension::ExtDType, other: &vortex_array::dtype::DType) -> core::option::Option pub fn vortex_array::extension::uuid::Uuid::serialize_metadata(&self, metadata: &Self::Metadata) -> vortex_error::VortexResult> diff --git a/vortex-array/src/dtype/extension/typed.rs b/vortex-array/src/dtype/extension/typed.rs index 003b8aa3393..baa3706b56b 100644 --- a/vortex-array/src/dtype/extension/typed.rs +++ b/vortex-array/src/dtype/extension/typed.rs @@ -202,14 +202,14 @@ impl DynExtDType for ExtDType { } fn coercion_can_coerce_from(&self, other: &DType) -> bool { - self.vtable.can_coerce_from(other) + self.vtable.can_coerce_from(self, other) } fn coercion_can_coerce_to(&self, other: &DType) -> bool { - self.vtable.can_coerce_to(other) + self.vtable.can_coerce_to(self, other) } fn coercion_least_supertype(&self, other: &DType) -> Option { - self.vtable.least_supertype(other) + self.vtable.least_supertype(self, other) } } diff --git a/vortex-array/src/dtype/extension/vtable.rs b/vortex-array/src/dtype/extension/vtable.rs index 92271ea0ede..120cfa02dc3 100644 --- a/vortex-array/src/dtype/extension/vtable.rs +++ b/vortex-array/src/dtype/extension/vtable.rs @@ -41,21 +41,27 @@ pub trait ExtVTable: 'static + Sized + Send + Sync + Clone + Debug + Eq + Hash { /// Can a value of `other` be implicitly widened into this type? /// e.g. GeographyType might accept Point, LineString, etc. - fn can_coerce_from(&self, other: &DType) -> bool { - let _ = other; + /// + /// Implementors only need to override one of `can_coerce_from` or `can_coerce_to` — both + /// exist so that either side of the coercion can provide the logic. + fn can_coerce_from(&self, ext_dtype: &ExtDType, other: &DType) -> bool { + let _ = (ext_dtype, other); false } /// Can this type be implicitly widened into `other`? - fn can_coerce_to(&self, other: &DType) -> bool { - let _ = other; + /// + /// Implementors only need to override one of `can_coerce_from` or `can_coerce_to` — both + /// exist so that either side of the coercion can provide the logic. + fn can_coerce_to(&self, ext_dtype: &ExtDType, other: &DType) -> bool { + let _ = (ext_dtype, other); false } /// Given two types in a Uniform context, what is their least supertype? /// Return None if no supertype exists. - fn least_supertype(&self, other: &DType) -> Option { - let _ = other; + fn least_supertype(&self, ext_dtype: &ExtDType, other: &DType) -> Option { + let _ = (ext_dtype, other); None } diff --git a/vortex-array/src/extension/datetime/date.rs b/vortex-array/src/extension/datetime/date.rs index 6d67051721b..57fe7ad87e0 100644 --- a/vortex-array/src/extension/datetime/date.rs +++ b/vortex-array/src/extension/datetime/date.rs @@ -106,6 +106,29 @@ impl ExtVTable for Date { Ok(()) } + fn can_coerce_from(&self, ext_dtype: &ExtDType, other: &DType) -> bool { + let DType::Extension(other_ext) = other else { + return false; + }; + let Some(other_unit) = other_ext.metadata_opt::() else { + return false; + }; + let our_unit = ext_dtype.metadata(); + // We can coerce from other if our unit is finer (<=) and nullability is compatible. + our_unit <= other_unit && (ext_dtype.storage_dtype().is_nullable() || !other.is_nullable()) + } + + fn least_supertype(&self, ext_dtype: &ExtDType, other: &DType) -> Option { + let DType::Extension(other_ext) = other else { + return None; + }; + let other_unit = other_ext.metadata_opt::()?; + let our_unit = ext_dtype.metadata(); + let finest = (*our_unit).min(*other_unit); + let union_null = ext_dtype.storage_dtype().nullability() | other.nullability(); + Some(DType::Extension(Date::new(finest, union_null).erased())) + } + fn unpack_native( &self, ext_dtype: &ExtDType, @@ -166,4 +189,25 @@ mod tests { let scalar = Scalar::new(dtype, Some(ScalarValue::Primitive(PValue::I32(365)))); assert_eq!(format!("{}", scalar.as_extension()), "1971-01-01"); } + + #[test] + fn least_supertype_date_units() { + use crate::dtype::Nullability::NonNullable; + + let days = DType::Extension(Date::new(TimeUnit::Days, NonNullable).erased()); + let ms = DType::Extension(Date::new(TimeUnit::Milliseconds, NonNullable).erased()); + let expected = DType::Extension(Date::new(TimeUnit::Milliseconds, NonNullable).erased()); + assert_eq!(days.least_supertype(&ms).unwrap(), expected); + assert_eq!(ms.least_supertype(&days).unwrap(), expected); + } + + #[test] + fn can_coerce_from_date() { + use crate::dtype::Nullability::NonNullable; + + let days = DType::Extension(Date::new(TimeUnit::Days, NonNullable).erased()); + let ms = DType::Extension(Date::new(TimeUnit::Milliseconds, NonNullable).erased()); + assert!(ms.can_coerce_from(&days)); + assert!(!days.can_coerce_from(&ms)); + } } diff --git a/vortex-array/src/extension/datetime/time.rs b/vortex-array/src/extension/datetime/time.rs index 59e7c98bc7e..eb53a0e0fe3 100644 --- a/vortex-array/src/extension/datetime/time.rs +++ b/vortex-array/src/extension/datetime/time.rs @@ -107,6 +107,28 @@ impl ExtVTable for Time { Ok(()) } + fn can_coerce_from(&self, ext_dtype: &ExtDType, other: &DType) -> bool { + let DType::Extension(other_ext) = other else { + return false; + }; + let Some(other_unit) = other_ext.metadata_opt::