Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
use crate::aggregates::group_values::GroupValues;
use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano};
use arrow::array::{
ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, NullBufferBuilder, PrimitiveArray,
cast::AsArray,
Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, NullBufferBuilder,
PrimitiveArray, cast::AsArray,
};
use arrow::datatypes::{DataType, i256};
use datafusion_common::Result;
use datafusion_common::hash_utils::RandomState;
use datafusion_common::hash_utils::{RandomState, with_hashes};
use datafusion_execution::memory_pool::proxy::VecAllocExt;
use datafusion_expr::EmitTo;
use half::f16;
Expand Down Expand Up @@ -81,13 +81,12 @@ hash_float!(f16, f32, f64);
pub struct GroupValuesPrimitive<T: ArrowPrimitiveType> {
/// The data type of the output array
data_type: DataType,
/// Stores the `(group_index, hash)` based on the hash of its value
/// Stores the `(group_index, value)` based on the hash of its value
///
/// We also store `hash` is for reducing cost of rehashing. Such cost
/// is obvious in high cardinality group by situation.
/// More details can see:
/// <https://github.com/apache/datafusion/issues/15961>
map: HashTable<(usize, u64)>,
/// Storing the value inline avoids an indirection into `values` during
/// hash-table probing (one fewer cache miss per probe) and allows
/// direct equality comparison without a stored hash.
map: HashTable<(usize, T::Native)>,
/// The group index of the null value if any
null_group: Option<usize>,
/// The values for each group index
Expand Down Expand Up @@ -117,42 +116,79 @@ where
assert_eq!(cols.len(), 1);
groups.clear();

for v in cols[0].as_primitive::<T>() {
let group_id = match v {
None => *self.null_group.get_or_insert_with(|| {
let group_id = self.values.len();
self.values.push(Default::default());
group_id
}),
Some(key) => {
let state = &self.random_state;
let hash = key.hash(state);
let insert = self.map.entry(
// Destructure to avoid borrow conflicts between random_state
// (borrowed by with_hashes) and map/values/null_group (mutated in callback)
let Self {
map,
null_group,
values,
random_state,
..
} = self;

// Phase 1: Compute all hashes in a tight loop (via with_hashes)
// Phase 2: Probe hash table using pre-computed hashes
with_hashes([&cols[0]], random_state, |hashes| {
let arr = cols[0].as_primitive::<T>();

if arr.null_count() == 0 {
for (idx, &key) in arr.values().iter().enumerate() {
let hash = hashes[idx];
let insert = map.entry(
hash,
|&(g, h)| unsafe {
hash == h && self.values.get_unchecked(g).is_eq(key)
},
|&(_, h)| h,
|&(_, v)| v.is_eq(key),
|&(_, v)| v.hash(random_state),
);

match insert {
let group_id = match insert {
hashbrown::hash_table::Entry::Occupied(o) => o.get().0,
hashbrown::hash_table::Entry::Vacant(v) => {
let g = self.values.len();
v.insert((g, hash));
self.values.push(key);
let g = values.len();
v.insert((g, key));
values.push(key);
g
}
}
};
groups.push(group_id);
}
};
groups.push(group_id)
}
} else {
for (idx, v) in arr.iter().enumerate() {
let group_id = match v {
None => *null_group.get_or_insert_with(|| {
let group_id = values.len();
values.push(Default::default());
group_id
}),
Some(key) => {
let hash = hashes[idx];
let insert = map.entry(
hash,
|&(_, v)| v.is_eq(key),
|&(_, v)| v.hash(random_state),
);
match insert {
hashbrown::hash_table::Entry::Occupied(o) => {
o.get().0
}
hashbrown::hash_table::Entry::Vacant(v) => {
let g = values.len();
v.insert((g, key));
values.push(key);
g
}
}
}
};
groups.push(group_id);
}
}
Ok(())
})?;
Ok(())
}

fn size(&self) -> usize {
self.map.capacity() * size_of::<(usize, u64)>() + self.values.allocated_size()
self.map.capacity() * size_of::<(usize, T::Native)>()
+ self.values.allocated_size()
}

fn is_empty(&self) -> bool {
Expand Down Expand Up @@ -219,6 +255,6 @@ where
self.values.clear();
self.values.shrink_to(num_rows);
self.map.clear();
self.map.shrink_to(num_rows, |_| 0); // hasher does not matter since the map is cleared
self.map.shrink_to(num_rows, |&(_, v)| v.hash(&self.random_state));
}
}
Loading