diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs
index 4686648fb1e3d..409ef2d70be78 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs
@@ -18,12 +18,12 @@
 use crate::aggregates::group_values::GroupValues;
 use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano};
 use arrow::array::{
-    ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, NullBufferBuilder, PrimitiveArray,
-    cast::AsArray,
+    Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, NullBufferBuilder,
+    PrimitiveArray, cast::AsArray,
 };
 use arrow::datatypes::{DataType, i256};
 use datafusion_common::Result;
-use datafusion_common::hash_utils::RandomState;
+use datafusion_common::hash_utils::{RandomState, with_hashes};
 use datafusion_execution::memory_pool::proxy::VecAllocExt;
 use datafusion_expr::EmitTo;
 use half::f16;
@@ -81,13 +81,12 @@ hash_float!(f16, f32, f64);
 pub struct GroupValuesPrimitive<T: ArrowPrimitiveType> {
     /// The data type of the output array
     data_type: DataType,
-    /// Stores the `(group_index, hash)` based on the hash of its value
+    /// Stores the `(group_index, value)` based on the hash of its value
     ///
-    /// We also store `hash` is for reducing cost of rehashing. Such cost
-    /// is obvious in high cardinality group by situation.
-    /// More details can see:
-    /// <https://github.com/apache/datafusion/issues/15961>
-    map: HashTable<(usize, u64)>,
+    /// Storing the value inline avoids an indirection into `values` during
+    /// hash-table probing (one fewer cache miss per probe) and allows
+    /// direct equality comparison without a stored hash.
+    map: HashTable<(usize, T::Native)>,
     /// The group index of the null value if any
     null_group: Option<usize>,
     /// The values for each group index
@@ -117,42 +116,79 @@ where
         assert_eq!(cols.len(), 1);
         groups.clear();
-        for v in cols[0].as_primitive::<T>() {
-            let group_id = match v {
-                None => *self.null_group.get_or_insert_with(|| {
-                    let group_id = self.values.len();
-                    self.values.push(Default::default());
-                    group_id
-                }),
-                Some(key) => {
-                    let state = &self.random_state;
-                    let hash = key.hash(state);
-                    let insert = self.map.entry(
+        // Destructure to avoid borrow conflicts between random_state
+        // (borrowed by with_hashes) and map/values/null_group (mutated in callback)
+        let Self {
+            map,
+            null_group,
+            values,
+            random_state,
+            ..
+        } = self;
+
+        // Phase 1: Compute all hashes in a tight loop (via with_hashes)
+        // Phase 2: Probe hash table using pre-computed hashes
+        with_hashes([&cols[0]], random_state, |hashes| {
+            let arr = cols[0].as_primitive::<T>();
+
+            if arr.null_count() == 0 {
+                for (idx, &key) in arr.values().iter().enumerate() {
+                    let hash = hashes[idx];
+                    let insert = map.entry(
                         hash,
-                        |&(g, h)| unsafe {
-                            hash == h && self.values.get_unchecked(g).is_eq(key)
-                        },
-                        |&(_, h)| h,
+                        |&(_, v)| v.is_eq(key),
+                        |&(_, v)| v.hash(random_state),
                     );
-
-                    match insert {
+                    let group_id = match insert {
                         hashbrown::hash_table::Entry::Occupied(o) => o.get().0,
                         hashbrown::hash_table::Entry::Vacant(v) => {
-                            let g = self.values.len();
-                            v.insert((g, hash));
-                            self.values.push(key);
+                            let g = values.len();
+                            v.insert((g, key));
+                            values.push(key);
                             g
                         }
-                    }
+                    };
+                    groups.push(group_id);
                 }
-            };
-            groups.push(group_id)
-        }
+            } else {
+                for (idx, v) in arr.iter().enumerate() {
+                    let group_id = match v {
+                        None => *null_group.get_or_insert_with(|| {
+                            let group_id = values.len();
+                            values.push(Default::default());
+                            group_id
+                        }),
+                        Some(key) => {
+                            let hash = hashes[idx];
+                            let insert = map.entry(
+                                hash,
+                                |&(_, v)| v.is_eq(key),
+                                |&(_, v)| v.hash(random_state),
+                            );
+                            match insert {
+                                hashbrown::hash_table::Entry::Occupied(o) => {
+                                    o.get().0
+                                }
+                                hashbrown::hash_table::Entry::Vacant(v) => {
+                                    let g = values.len();
+                                    v.insert((g, key));
+                                    values.push(key);
+                                    g
+                                }
+                            }
+                        }
+                    };
+                    groups.push(group_id);
+                }
+            }
+            Ok(())
+        })?;
         Ok(())
     }
 
     fn size(&self) -> usize {
-        self.map.capacity() * size_of::<(usize, u64)>() + self.values.allocated_size()
+        self.map.capacity() * size_of::<(usize, T::Native)>()
+            + self.values.allocated_size()
     }
 
     fn is_empty(&self) -> bool {
@@ -219,6 +255,6 @@ where
         self.values.clear();
         self.values.shrink_to(num_rows);
         self.map.clear();
-        self.map.shrink_to(num_rows, |_| 0); // hasher does not matter since the map is cleared
+        self.map.shrink_to(num_rows, |&(_, v)| v.hash(&self.random_state));
    }
 }