diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 3db6e4ee91b..757515fea64 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -28,6 +28,138 @@ pub fn vortex_array::arrays::PrimitiveArray::with_iterator(&self, f: F) -> pub mod vortex_array::aggregate_fn +pub mod vortex_array::aggregate_fn::combined + +pub struct vortex_array::aggregate_fn::combined::Combined(pub T) + +impl vortex_array::aggregate_fn::combined::Combined + +pub fn vortex_array::aggregate_fn::combined::Combined::new(inner: T) -> Self + +impl core::clone::Clone for vortex_array::aggregate_fn::combined::Combined + +pub fn vortex_array::aggregate_fn::combined::Combined::clone(&self) -> vortex_array::aggregate_fn::combined::Combined + +impl core::fmt::Debug for vortex_array::aggregate_fn::combined::Combined + +pub fn vortex_array::aggregate_fn::combined::Combined::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::combined::Combined + +pub type vortex_array::aggregate_fn::combined::Combined::Options = vortex_array::aggregate_fn::combined::PairOptions<<::Left as vortex_array::aggregate_fn::AggregateFnVTable>::Options, <::Right as vortex_array::aggregate_fn::AggregateFnVTable>::Options> + +pub type vortex_array::aggregate_fn::combined::Combined::Partial = (<::Left as vortex_array::aggregate_fn::AggregateFnVTable>::Partial, <::Right as vortex_array::aggregate_fn::AggregateFnVTable>::Partial) + +pub fn vortex_array::aggregate_fn::combined::Combined::accumulate(&self, _state: &mut Self::Partial, _batch: &vortex_array::Columnar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::combined::Combined::coerce_args(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::combined::Combined::deserialize(&self, metadata: &[u8], session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::empty_partial(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::finalize(&self, states: vortex_array::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::finalize_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::combined::Combined::is_saturated(&self, partial: &Self::Partial) -> bool + +pub fn vortex_array::aggregate_fn::combined::Combined::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::combined::Combined::reset(&self, partial: &mut Self::Partial) + +pub fn vortex_array::aggregate_fn::combined::Combined::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::combined::Combined::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_array::aggregate_fn::combined::Combined::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::try_accumulate(&self, state: &mut Self::Partial, batch: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + +pub struct vortex_array::aggregate_fn::combined::PairOptions(pub L, pub R) + +impl core::marker::StructuralPartialEq for vortex_array::aggregate_fn::combined::PairOptions + +impl core::clone::Clone for vortex_array::aggregate_fn::combined::PairOptions + +pub fn vortex_array::aggregate_fn::combined::PairOptions::clone(&self) -> vortex_array::aggregate_fn::combined::PairOptions + +impl core::cmp::Eq for vortex_array::aggregate_fn::combined::PairOptions + +impl core::cmp::PartialEq for vortex_array::aggregate_fn::combined::PairOptions + +pub fn vortex_array::aggregate_fn::combined::PairOptions::eq(&self, other: &vortex_array::aggregate_fn::combined::PairOptions) -> bool + +impl core::fmt::Debug for vortex_array::aggregate_fn::combined::PairOptions + +pub fn vortex_array::aggregate_fn::combined::PairOptions::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::fmt::Display for vortex_array::aggregate_fn::combined::PairOptions + +pub fn vortex_array::aggregate_fn::combined::PairOptions::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for vortex_array::aggregate_fn::combined::PairOptions + +pub fn vortex_array::aggregate_fn::combined::PairOptions::hash<__H: core::hash::Hasher>(&self, state: &mut __H) + +pub trait vortex_array::aggregate_fn::combined::BinaryCombined: 'static + core::marker::Send + core::marker::Sync + core::clone::Clone + +pub type vortex_array::aggregate_fn::combined::BinaryCombined::Left: vortex_array::aggregate_fn::AggregateFnVTable + +pub type vortex_array::aggregate_fn::combined::BinaryCombined::Right: vortex_array::aggregate_fn::AggregateFnVTable + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::coerce_args(&self, options: &vortex_array::aggregate_fn::combined::CombinedOptions, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::deserialize(&self, metadata: &[u8], session: &vortex_session::VortexSession) -> vortex_error::VortexResult> + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::finalize(&self, left: vortex_array::ArrayRef, right: vortex_array::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::left(&self) -> Self::Left + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::left_name(&self) -> &'static str + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::return_dtype(&self, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::right(&self) -> Self::Right + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::right_name(&self) -> &'static str + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::serialize(&self, options: &vortex_array::aggregate_fn::combined::CombinedOptions) -> vortex_error::VortexResult>> + +impl vortex_array::aggregate_fn::combined::BinaryCombined for vortex_array::aggregate_fn::fns::mean::Mean + +pub type vortex_array::aggregate_fn::fns::mean::Mean::Left = vortex_array::aggregate_fn::fns::sum::Sum + +pub type vortex_array::aggregate_fn::fns::mean::Mean::Right = vortex_array::aggregate_fn::fns::count::Count + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::coerce_args(&self, _options: &vortex_array::aggregate_fn::combined::PairOptions<::Options, ::Options>, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::deserialize(&self, metadata: &[u8], session: &vortex_session::VortexSession) -> vortex_error::VortexResult> + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::finalize(&self, sum: vortex_array::ArrayRef, count: vortex_array::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::left(&self) -> vortex_array::aggregate_fn::fns::sum::Sum + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::left_name(&self) -> &'static str + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::return_dtype(&self, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::right(&self) -> vortex_array::aggregate_fn::fns::count::Count + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::right_name(&self) -> &'static str + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::serialize(&self, _options: &vortex_array::aggregate_fn::combined::CombinedOptions) -> vortex_error::VortexResult>> + +pub type vortex_array::aggregate_fn::combined::CombinedOptions = vortex_array::aggregate_fn::combined::PairOptions<<::Left as vortex_array::aggregate_fn::AggregateFnVTable>::Options, <::Right as vortex_array::aggregate_fn::AggregateFnVTable>::Options> + pub mod vortex_array::aggregate_fn::fns pub mod vortex_array::aggregate_fn::fns::count @@ -342,6 +474,50 @@ pub struct vortex_array::aggregate_fn::fns::last::LastPartial pub fn vortex_array::aggregate_fn::fns::last::last(array: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult +pub mod vortex_array::aggregate_fn::fns::mean + +pub struct vortex_array::aggregate_fn::fns::mean::Mean + +impl vortex_array::aggregate_fn::fns::mean::Mean + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::combined() -> vortex_array::aggregate_fn::combined::Combined + +impl core::clone::Clone for vortex_array::aggregate_fn::fns::mean::Mean + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::clone(&self) -> vortex_array::aggregate_fn::fns::mean::Mean + +impl core::fmt::Debug for vortex_array::aggregate_fn::fns::mean::Mean + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl vortex_array::aggregate_fn::combined::BinaryCombined for vortex_array::aggregate_fn::fns::mean::Mean + +pub type vortex_array::aggregate_fn::fns::mean::Mean::Left = vortex_array::aggregate_fn::fns::sum::Sum + +pub type vortex_array::aggregate_fn::fns::mean::Mean::Right = vortex_array::aggregate_fn::fns::count::Count + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::coerce_args(&self, _options: &vortex_array::aggregate_fn::combined::PairOptions<::Options, ::Options>, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::deserialize(&self, metadata: &[u8], session: &vortex_session::VortexSession) -> vortex_error::VortexResult> + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::finalize(&self, sum: vortex_array::ArrayRef, count: vortex_array::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::left(&self) -> vortex_array::aggregate_fn::fns::sum::Sum + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::left_name(&self) -> &'static str + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::return_dtype(&self, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::right(&self) -> vortex_array::aggregate_fn::fns::count::Count + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::right_name(&self) -> &'static str + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::serialize(&self, _options: &vortex_array::aggregate_fn::combined::CombinedOptions) -> vortex_error::VortexResult>> + +pub fn vortex_array::aggregate_fn::fns::mean::mean(array: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + pub mod vortex_array::aggregate_fn::fns::min_max pub struct vortex_array::aggregate_fn::fns::min_max::MinMax @@ -1068,6 +1244,42 @@ pub fn vortex_array::aggregate_fn::fns::sum::Sum::to_scalar(&self, partial: &Sel pub fn vortex_array::aggregate_fn::fns::sum::Sum::try_accumulate(&self, _state: &mut Self::Partial, _batch: &vortex_array::ArrayRef, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult +impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::combined::Combined + +pub type vortex_array::aggregate_fn::combined::Combined::Options = vortex_array::aggregate_fn::combined::PairOptions<<::Left as vortex_array::aggregate_fn::AggregateFnVTable>::Options, <::Right as vortex_array::aggregate_fn::AggregateFnVTable>::Options> + +pub type vortex_array::aggregate_fn::combined::Combined::Partial = (<::Left as vortex_array::aggregate_fn::AggregateFnVTable>::Partial, <::Right as vortex_array::aggregate_fn::AggregateFnVTable>::Partial) + +pub fn vortex_array::aggregate_fn::combined::Combined::accumulate(&self, _state: &mut Self::Partial, _batch: &vortex_array::Columnar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::combined::Combined::coerce_args(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::combined::Combined::deserialize(&self, metadata: &[u8], session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::empty_partial(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::finalize(&self, states: vortex_array::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::finalize_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::combined::Combined::is_saturated(&self, partial: &Self::Partial) -> bool + +pub fn vortex_array::aggregate_fn::combined::Combined::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::combined::Combined::reset(&self, partial: &mut Self::Partial) + +pub fn vortex_array::aggregate_fn::combined::Combined::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::combined::Combined::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_array::aggregate_fn::combined::Combined::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::try_accumulate(&self, state: &mut Self::Partial, batch: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + pub trait vortex_array::aggregate_fn::AggregateFnVTableExt: vortex_array::aggregate_fn::AggregateFnVTable pub fn vortex_array::aggregate_fn::AggregateFnVTableExt::bind(&self, options: Self::Options) -> vortex_array::aggregate_fn::AggregateFnRef diff --git a/vortex-array/src/aggregate_fn/accumulator_grouped.rs b/vortex-array/src/aggregate_fn/accumulator_grouped.rs index 24f8f0157ec..1d887b6eb21 100644 --- a/vortex-array/src/aggregate_fn/accumulator_grouped.rs +++ b/vortex-array/src/aggregate_fn/accumulator_grouped.rs @@ -311,7 +311,7 @@ impl GroupedAccumulator { if validity.value(i) { let group = elements.slice(offset..offset + size)?; accumulator.accumulate(&group, ctx)?; - states.append_scalar(&accumulator.finish()?)?; + states.append_scalar(&accumulator.flush()?)?; } else { states.append_null() } diff --git a/vortex-array/src/aggregate_fn/combined.rs b/vortex-array/src/aggregate_fn/combined.rs new file mode 100644 index 00000000000..9ecc0dcc8d1 --- /dev/null +++ b/vortex-array/src/aggregate_fn/combined.rs @@ -0,0 +1,256 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Generic adapter for aggregates whose result is computed from two child +//! aggregate functions, e.g. `Mean = Sum / Count`. + +use std::fmt::Debug; +use std::fmt::Display; +use std::fmt::Formatter; +use std::fmt::{self}; +use std::hash::Hash; + +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_err; +use vortex_session::VortexSession; + +use crate::ArrayRef; +use crate::Columnar; +use crate::ExecutionCtx; +use crate::aggregate_fn::AggregateFnId; +use crate::aggregate_fn::AggregateFnVTable; +use crate::builtins::ArrayBuiltins; +use crate::dtype::DType; +use crate::dtype::FieldName; +use crate::dtype::FieldNames; +use crate::dtype::Nullability; +use crate::dtype::StructFields; +use crate::scalar::Scalar; + +/// Pair of options for the two children of a [`BinaryCombined`] aggregate. +/// +/// Wrapper around `(L, R)` because the [`AggregateFnVTable::Options`] bound +/// requires `Display`, which tuples don't implement. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct PairOptions(pub L, pub R); + +impl Display for PairOptions { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "({}, {})", self.0, self.1) + } +} + +// Convenience aliases so signatures stay readable. +type LeftOptions = <::Left as AggregateFnVTable>::Options; +type RightOptions = <::Right as AggregateFnVTable>::Options; +type LeftPartial = <::Left as AggregateFnVTable>::Partial; +type RightPartial = <::Right as AggregateFnVTable>::Partial; +/// Combined options for a [`BinaryCombined`] aggregate. +pub type CombinedOptions = PairOptions, RightOptions>; + +/// Declare an aggregate function in terms of two child aggregates. +pub trait BinaryCombined: 'static + Send + Sync + Clone { + /// The left child aggregate vtable. + type Left: AggregateFnVTable; + /// The right child aggregate vtable. + type Right: AggregateFnVTable; + + /// Stable identifier for the combined aggregate. + fn id(&self) -> AggregateFnId; + + /// Construct the left child vtable. + fn left(&self) -> Self::Left; + + /// Construct the right child vtable. + fn right(&self) -> Self::Right; + + /// Field name for the left child in the partial struct dtype. + fn left_name(&self) -> &'static str { + "left" + } + + /// Field name for the right child in the partial struct dtype. + fn right_name(&self) -> &'static str { + "right" + } + + /// Return type of the combined aggregate. + fn return_dtype(&self, input_dtype: &DType) -> Option; + + /// Combine the finalized left and right results into the final aggregate. + fn finalize(&self, left: ArrayRef, right: ArrayRef) -> VortexResult; + + /// Serialize the options for this combined aggregate. Default: not serializable. + fn serialize(&self, options: &CombinedOptions) -> VortexResult>> { + let _ = options; + Ok(None) + } + + /// Deserialize the options for this combined aggregate. Default: bails. + fn deserialize( + &self, + metadata: &[u8], + session: &VortexSession, + ) -> VortexResult> { + let _ = (metadata, session); + vortex_bail!( + "Combined aggregate function {} is not deserializable", + BinaryCombined::id(self) + ); + } + + /// Coerce the input type. Default: chains `right.coerce_args(left.coerce_args(input))`. + fn coerce_args( + &self, + options: &CombinedOptions, + input_dtype: &DType, + ) -> VortexResult { + let left_coerced = self.left().coerce_args(&options.0, input_dtype)?; + self.right().coerce_args(&options.1, &left_coerced) + } +} + +/// Adapter that exposes any [`BinaryCombined`] as an [`AggregateFnVTable`]. +#[derive(Clone, Debug)] +pub struct Combined(pub T); + +impl Combined { + /// Construct a new combined aggregate vtable. + pub fn new(inner: T) -> Self { + Self(inner) + } +} + +impl AggregateFnVTable for Combined { + type Options = CombinedOptions; + type Partial = (LeftPartial, RightPartial); + + fn id(&self) -> AggregateFnId { + self.0.id() + } + + fn serialize(&self, options: &Self::Options) -> VortexResult>> { + BinaryCombined::serialize(&self.0, options) + } + + fn deserialize(&self, metadata: &[u8], session: &VortexSession) -> VortexResult { + BinaryCombined::deserialize(&self.0, metadata, session) + } + + fn coerce_args(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult { + BinaryCombined::coerce_args(&self.0, options, input_dtype) + } + + fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option { + BinaryCombined::return_dtype(&self.0, input_dtype) + } + + fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option { + let l = self.0.left().partial_dtype(&options.0, input_dtype)?; + let r = self.0.right().partial_dtype(&options.1, input_dtype)?; + Some(struct_dtype(self.0.left_name(), self.0.right_name(), l, r)) + } + + fn empty_partial( + &self, + options: &Self::Options, + input_dtype: &DType, + ) -> VortexResult { + Ok(( + self.0.left().empty_partial(&options.0, input_dtype)?, + self.0.right().empty_partial(&options.1, input_dtype)?, + )) + } + + fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> { + if other.is_null() { + return Ok(()); + } + let s = other.as_struct(); + let lname = self.0.left_name(); + let rname = self.0.right_name(); + let l_field = s + .field(lname) + .ok_or_else(|| vortex_err!("BinaryCombined partial missing `{}` field", lname))?; + let r_field = s + .field(rname) + .ok_or_else(|| vortex_err!("BinaryCombined partial missing `{}` field", rname))?; + self.0.left().combine_partials(&mut partial.0, l_field)?; + self.0.right().combine_partials(&mut partial.1, r_field)?; + Ok(()) + } + + fn to_scalar(&self, partial: &Self::Partial) -> VortexResult { + let l_scalar = self.0.left().to_scalar(&partial.0)?; + let r_scalar = self.0.right().to_scalar(&partial.1)?; + let dtype = struct_dtype( + self.0.left_name(), + self.0.right_name(), + l_scalar.dtype().clone(), + r_scalar.dtype().clone(), + ); + Ok(Scalar::struct_(dtype, vec![l_scalar, r_scalar])) + } + + fn reset(&self, partial: &mut Self::Partial) { + self.0.left().reset(&mut partial.0); + self.0.right().reset(&mut partial.1); + } + + fn is_saturated(&self, partial: &Self::Partial) -> bool { + self.0.left().is_saturated(&partial.0) && self.0.right().is_saturated(&partial.1) + } + + /// Fans out to each child's `try_accumulate`, falling back to `accumulate` + /// against a lazily-canonicalized batch. We always claim to handle the + /// batch ourselves so [`Self::accumulate`] is unreachable — this is the + /// same trick `Count` uses to opt out of the canonicalization path. + fn try_accumulate( + &self, + state: &mut Self::Partial, + batch: &ArrayRef, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + let mut canonical: Option = None; + if !self.0.left().try_accumulate(&mut state.0, batch, ctx)? { + let c = canonical.insert(batch.clone().execute::(ctx)?); + self.0.left().accumulate(&mut state.0, c, ctx)?; + } + if !self.0.right().try_accumulate(&mut state.1, batch, ctx)? { + let c = match canonical.as_ref() { + Some(c) => c, + None => canonical.insert(batch.clone().execute::(ctx)?), + }; + self.0.right().accumulate(&mut state.1, c, ctx)?; + } + Ok(true) + } + + fn accumulate( + &self, + _state: &mut Self::Partial, + _batch: &Columnar, + _ctx: &mut ExecutionCtx, + ) -> VortexResult<()> { + unreachable!("Combined::try_accumulate handles all batches") + } + + fn finalize(&self, states: ArrayRef) -> VortexResult { + let l_field = states.get_item(FieldName::from(self.0.left_name()))?; + let r_field = states.get_item(FieldName::from(self.0.right_name()))?; + let l_finalized = self.0.left().finalize(l_field)?; + let r_finalized = self.0.right().finalize(r_field)?; + BinaryCombined::finalize(&self.0, l_finalized, r_finalized) + } +} + +fn struct_dtype(left_name: &str, right_name: &str, left: DType, right: DType) -> DType { + DType::Struct( + StructFields::new( + FieldNames::from_iter([FieldName::from(left_name), FieldName::from(right_name)]), + vec![left, right], + ), + Nullability::NonNullable, + ) +} diff --git a/vortex-array/src/aggregate_fn/fns/mean/mod.rs b/vortex-array/src/aggregate_fn/fns/mean/mod.rs new file mode 100644 index 00000000000..c57707a441b --- /dev/null +++ b/vortex-array/src/aggregate_fn/fns/mean/mod.rs @@ -0,0 +1,271 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexResult; + +use crate::ArrayRef; +use crate::ExecutionCtx; +use crate::aggregate_fn::Accumulator; +use crate::aggregate_fn::AggregateFnId; +use crate::aggregate_fn::AggregateFnVTable; +use crate::aggregate_fn::DynAccumulator; +use crate::aggregate_fn::EmptyOptions; +use crate::aggregate_fn::combined::BinaryCombined; +use crate::aggregate_fn::combined::Combined; +use crate::aggregate_fn::combined::CombinedOptions; +use crate::aggregate_fn::combined::PairOptions; +use crate::aggregate_fn::fns::count::Count; +use crate::aggregate_fn::fns::sum::Sum; +use crate::builtins::ArrayBuiltins; +use crate::dtype::DType; +use crate::dtype::DecimalDType; +use crate::dtype::MAX_PRECISION; +use crate::dtype::MAX_SCALE; +use crate::dtype::Nullability; +use crate::dtype::PType; +use crate::scalar::Scalar; +use crate::scalar_fn::fns::operators::Operator; + +/// Compute the arithmetic mean of an array. +/// +/// See [`Mean`] for details. +pub fn mean(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { + let mut acc = Accumulator::try_new( + Mean::combined(), + PairOptions(EmptyOptions, EmptyOptions), + array.dtype().clone(), + )?; + acc.accumulate(array, ctx)?; + acc.finish() +} + +/// Compute the arithmetic mean of an array. +/// +/// Implemented as `Sum / Count` via [`BinaryCombined`]. +/// +/// Coercion / return type: +/// - Booleans and primitive numeric types are coerced to `f64` and the result +/// is a nullable `f64`. +/// - Decimals are kept as decimals with widened precision and scale +/// (`+4` each, capped at [`MAX_PRECISION`] / [`MAX_SCALE`]), matching +/// DataFusion's `coerce_avg_type`. +#[derive(Clone, Debug)] +pub struct Mean; + +impl Mean { + pub fn combined() -> Combined { + Combined(Mean) + } +} + +impl BinaryCombined for Mean { + type Left = Sum; + type Right = Count; + + fn id(&self) -> AggregateFnId { + AggregateFnId::new_ref("vortex.mean") + } + + fn left(&self) -> Sum { + Sum + } + + fn right(&self) -> Count { + Count + } + + fn left_name(&self) -> &'static str { + "sum" + } + + fn right_name(&self) -> &'static str { + "count" + } + + fn return_dtype(&self, input_dtype: &DType) -> Option { + Some(mean_output_dtype(input_dtype)?.with_nullability(Nullability::Nullable)) + } + + fn finalize(&self, sum: ArrayRef, count: ArrayRef) -> VortexResult { + let target = match sum.dtype() { + DType::Decimal(..) => sum.dtype().with_nullability(Nullability::Nullable), + _ => DType::Primitive(PType::F64, Nullability::Nullable), + }; + let sum_cast = sum.cast(target.clone())?; + let count_cast = count.cast(target)?; + sum_cast.binary(count_cast, Operator::Div) + } + + fn serialize(&self, _options: &CombinedOptions) -> VortexResult>> { + unimplemented!("Mean is not yet serializable"); + } + + fn coerce_args( + &self, + _options: &PairOptions< + ::Options, + ::Options, + >, + input_dtype: &DType, + ) -> VortexResult { + // Advisory hint for query planners: where possible, cast input to the + // type we're going to compute the mean in. + Ok(coerced_input_dtype(input_dtype).unwrap_or_else(|| input_dtype.clone())) + } +} + +/// Hint for callers: what to cast the input to before accumulation. +/// +/// - Bool stays as bool — `Sum` has a native bool path and bool → f64 isn't +/// currently a direct cast in vortex. +/// - Primitive numerics → `f64` so the sum and finalize work without overflow. +/// - Decimals → decimal with widened precision and scale (`+4` each, capped), +/// matching DataFusion's `coerce_avg_type`. +fn coerced_input_dtype(input_dtype: &DType) -> Option { + match input_dtype { + DType::Bool(_) => Some(input_dtype.clone()), + DType::Primitive(_, n) => Some(DType::Primitive(PType::F64, *n)), + DType::Decimal(d, n) => { + let new_precision = u8::min(MAX_PRECISION, d.precision().saturating_add(4)); + let new_scale = i8::min(MAX_SCALE, d.scale().saturating_add(4)); + Some(DType::Decimal( + DecimalDType::new(new_precision, new_scale), + *n, + )) + } + _ => None, + } +} + +fn mean_output_dtype(input_dtype: &DType) -> Option { + match input_dtype { + DType::Bool(_) | DType::Primitive(..) => { + Some(DType::Primitive(PType::F64, Nullability::Nullable)) + } + DType::Decimal(d, _) => { + let new_precision = u8::min(MAX_PRECISION, d.precision().saturating_add(10)); + Some(DType::Decimal( + DecimalDType::new(new_precision, d.scale()), + Nullability::Nullable, + )) + } + _ => None, + } +} + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + use vortex_error::VortexResult; + + use super::*; + use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; + use crate::arrays::BoolArray; + use crate::arrays::ChunkedArray; + use crate::arrays::ConstantArray; + use crate::arrays::DecimalArray; + use crate::arrays::PrimitiveArray; + use crate::scalar::DecimalValue; + use crate::validity::Validity; + + #[test] + fn mean_all_valid() -> VortexResult<()> { + let array = PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0, 4.0, 5.0], Validity::NonNullable) + .into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array, &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), Some(3.0)); + Ok(()) + } + + #[test] + fn mean_with_nulls() -> VortexResult<()> { + let array = PrimitiveArray::from_option_iter([Some(2.0f64), None, Some(4.0)]).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array, &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), Some(3.0)); + Ok(()) + } + + #[test] + fn mean_integers() -> VortexResult<()> { + let array = PrimitiveArray::new(buffer![10i32, 20, 30], Validity::NonNullable).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array, &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), Some(20.0)); + Ok(()) + } + + #[test] + fn mean_bool() -> VortexResult<()> { + let array: BoolArray = [true, false, true, true].into_iter().collect(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array.into_array(), &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), Some(0.75)); + Ok(()) + } + + #[test] + fn mean_constant_non_null() -> VortexResult<()> { + let array = ConstantArray::new(5.0f64, 4); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array.into_array(), &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), Some(5.0)); + Ok(()) + } + + #[test] + fn mean_chunked() -> VortexResult<()> { + let chunk1 = PrimitiveArray::from_option_iter([Some(1.0f64), None, Some(3.0)]); + let chunk2 = PrimitiveArray::from_option_iter([Some(5.0f64), None]); + let dtype = chunk1.dtype().clone(); + let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&chunked.into_array(), &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), Some(3.0)); + Ok(()) + } + + // TODO: vortex's cast kernel doesn't currently support `u64 → decimal`, + #[test] + #[ignore = "u64 → decimal cast not yet supported"] + fn mean_decimal() -> VortexResult<()> { + // 1.00, 2.00, 3.00 in decimal(5, 2). Mean = 2.00. + let values = buffer![100i128, 200i128, 300i128]; + let dt = DecimalDType::new(5, 2); + let array = DecimalArray::new(values, dt, Validity::NonNullable).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array, &mut ctx)?; + // `Sum` widens precision by +10, so the result lives in decimal(15, 2). + // 2.00 in scale=2 is the integer 200. + assert_eq!( + result.as_decimal().decimal_value(), + Some(DecimalValue::I128(200)) + ); + Ok(()) + } + + #[test] + fn mean_multi_batch() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); + let mut acc = Accumulator::try_new( + Mean::combined(), + PairOptions(EmptyOptions, EmptyOptions), + dtype, + )?; + + let batch1 = + PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0], Validity::NonNullable).into_array(); + acc.accumulate(&batch1, &mut ctx)?; + + let batch2 = PrimitiveArray::new(buffer![4.0f64, 5.0], Validity::NonNullable).into_array(); + acc.accumulate(&batch2, &mut ctx)?; + + let result = acc.finish()?; + assert_eq!(result.as_primitive().as_::(), Some(3.0)); + Ok(()) + } +} diff --git a/vortex-array/src/aggregate_fn/fns/mod.rs b/vortex-array/src/aggregate_fn/fns/mod.rs index 38d5340cd1f..f1281c18544 100644 --- a/vortex-array/src/aggregate_fn/fns/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/mod.rs @@ -6,6 +6,7 @@ pub mod first; pub mod is_constant; pub mod is_sorted; pub mod last; +pub mod mean; pub mod min_max; pub mod nan_count; pub mod sum; diff --git a/vortex-array/src/aggregate_fn/mod.rs b/vortex-array/src/aggregate_fn/mod.rs index a22006cb981..ba106ff63bf 100644 --- a/vortex-array/src/aggregate_fn/mod.rs +++ b/vortex-array/src/aggregate_fn/mod.rs @@ -32,6 +32,7 @@ pub use erased::*; mod options; pub use options::*; +pub mod combined; pub mod fns; pub mod kernels; pub mod proto;