Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 65 additions & 10 deletions datafusion/common/src/hash_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,14 @@ where
let mut buffer = cell.try_borrow_mut()
.map_err(|_| _internal_datafusion_err!("with_hashes cannot be called reentrantly on the same thread"))?;

// Ensure buffer has sufficient length, clearing old values
// Ensure buffer has sufficient capacity without zero-filling.
// create_hashes writes all positions (including null sentinels),
// so pre-zeroing is unnecessary.
buffer.clear();
buffer.resize(required_size, 0);
buffer.reserve(required_size);
// SAFETY: create_hashes will write every position in the buffer
// (null positions get a consistent sentinel hash).
unsafe { buffer.set_len(required_size) };

// Create hashes in the buffer - this consumes the iterator
create_hashes(iter, random_state, &mut buffer[..required_size])?;
Expand Down Expand Up @@ -244,6 +249,10 @@ fn hash_array_primitive<T>(
hashes_buffer[i] = hasher.finish();
}
} else {
// Fill with null sentinel, then overwrite valid positions.
// This allows callers to skip pre-zeroing the buffer.
let null_hash = random_state.hash_one(1u8);
hashes_buffer.fill(null_hash);
for i in array.nulls().unwrap().valid_indices() {
let value = unsafe { array.value_unchecked(i) };
hashes_buffer[i] = value.hash_one(random_state);
Expand Down Expand Up @@ -289,6 +298,10 @@ fn hash_array<T>(
combine_hashes(value.hash_one(random_state), hashes_buffer[i]);
}
} else {
// Fill with null sentinel, then overwrite valid positions.
// This allows callers to skip pre-zeroing the buffer.
let null_hash = random_state.hash_one(1u8);
hashes_buffer.fill(null_hash);
for i in array.nulls().unwrap().valid_indices() {
let value = unsafe { array.value_unchecked(i) };
hashes_buffer[i] = value.hash_one(random_state);
Expand Down Expand Up @@ -331,9 +344,13 @@ fn hash_string_view_array_inner<
}
};

let null_hash = random_state.hash_one(1u8);
let hashes_and_views = hashes_buffer.iter_mut().zip(array.views().iter());
for (i, (hash, &v)) in hashes_and_views.enumerate() {
if HAS_NULLS && array.is_null(i) {
if !REHASH {
*hash = null_hash;
}
continue;
}
let view_len = v as u32;
Expand Down Expand Up @@ -447,6 +464,7 @@ fn hash_dictionary_inner<
let mut dict_hashes = vec![0; dict_values.len()];
create_hashes([dict_values], random_state, &mut dict_hashes)?;

let null_hash = random_state.hash_one(1u8);
if HAS_NULL_KEYS {
for (hash, key) in hashes_buffer.iter_mut().zip(array.keys().iter()) {
if let Some(key) = key {
Expand All @@ -457,7 +475,11 @@ fn hash_dictionary_inner<
} else {
*hash = dict_hashes[idx];
}
} else if !MULTI_COL {
*hash = null_hash;
}
} else if !MULTI_COL {
*hash = null_hash;
}
}
} else {
Expand All @@ -469,6 +491,8 @@ fn hash_dictionary_inner<
} else {
*hash = dict_hashes[idx];
}
} else if !MULTI_COL {
*hash = null_hash;
}
}
}
Expand Down Expand Up @@ -916,6 +940,10 @@ fn hash_run_array_inner<
let end_in_slice = (absolute_run_end - array_offset).min(array_len);

if HAS_NULL_VALUES && sliced_values.is_null(adjusted_physical_index) {
if !REHASH {
let null_hash = random_state.hash_one(1u8);
hashes_buffer[start_in_slice..end_in_slice].fill(null_hash);
}
start_in_slice = end_in_slice;
continue;
}
Expand Down Expand Up @@ -1103,11 +1131,34 @@ where
for (i, array) in arrays.into_iter().enumerate() {
// combine hashes with `combine_hashes` for all columns besides the first
let rehash = i >= 1;
hash_single_array(array.as_dyn_array(), random_state, hashes_buffer, rehash)?;
let arr = array.as_dyn_array();
// Complex types (struct, list, map, union) always combine with
// existing hash values rather than initializing them, so the buffer
// must be zeroed when they appear as the first column.
if !rehash && needs_zero_init(arr.data_type()) {
hashes_buffer.fill(0);
}
hash_single_array(arr, random_state, hashes_buffer, rehash)?;
}
Ok(hashes_buffer)
}

/// Reports whether `dt` is one of the nested types (struct, the list
/// variants, map, union) whose hash functions have no `rehash=false` path.
///
/// Those hashers always combine with whatever is already in the hash buffer
/// instead of initializing it, so the buffer has to be zero-initialized
/// before such a type is hashed.
fn needs_zero_init(dt: &DataType) -> bool {
    // Every list-shaped container, regardless of offset width or layout.
    let is_list_like = matches!(
        dt,
        DataType::List(..)
            | DataType::LargeList(..)
            | DataType::ListView(..)
            | DataType::LargeListView(..)
            | DataType::FixedSizeList(..)
    );
    // The remaining nested types.
    let is_other_nested = matches!(
        dt,
        DataType::Struct(..) | DataType::Map(..) | DataType::Union(..)
    );
    is_list_like || is_other_nested
}

#[cfg(test)]
mod tests {
use std::sync::Arc;
Expand Down Expand Up @@ -1190,11 +1241,12 @@ mod tests {
create_hashes(&[binary_array], &random_state, &mut binary_hashes)
.unwrap();

// Null values result in a zero hash,
// Null values result in a consistent null sentinel hash
let null_hash = random_state.hash_one(1u8);
for (val, hash) in binary.iter().zip(binary_hashes.iter()) {
match val {
Some(_) => assert_ne!(*hash, 0),
None => assert_eq!(*hash, 0),
None => assert_eq!(*hash, null_hash),
}
}

Expand Down Expand Up @@ -1260,11 +1312,12 @@ mod tests {
let mut dict_hashes = vec![0; strings.len()];
create_hashes(&[dict_array], &random_state, &mut dict_hashes).unwrap();

// Null values result in a zero hash,
// Null values result in a consistent null sentinel hash
let null_hash = random_state.hash_one(1u8);
for (val, hash) in strings.iter().zip(string_hashes.iter()) {
match val {
Some(_) => assert_ne!(*hash, 0),
None => assert_eq!(*hash, 0),
None => assert_eq!(*hash, null_hash),
}
}

Expand Down Expand Up @@ -1377,11 +1430,12 @@ mod tests {
let mut dict_hashes = vec![0; strings.len()];
create_hashes(&[dict_array], &random_state, &mut dict_hashes).unwrap();

// Null values result in a zero hash,
// Null values result in a consistent null sentinel hash
let null_hash = random_state.hash_one(1u8);
for (val, hash) in strings.iter().zip(string_hashes.iter()) {
match val {
Some(_) => assert_ne!(*hash, 0),
None => assert_eq!(*hash, 0),
None => assert_eq!(*hash, null_hash),
}
}

Expand Down Expand Up @@ -2047,10 +2101,11 @@ mod tests {
&mut hashes,
)?;

let null_hash = random_state.hash_one(1u8);
assert_eq!(hashes[0], hashes[1]);
assert_ne!(hashes[0], 0);
assert_eq!(hashes[2], hashes[3]);
assert_eq!(hashes[2], 0);
assert_eq!(hashes[2], null_hash);
assert_eq!(hashes[4], hashes[5]);
assert_ne!(hashes[4], 0);
assert_ne!(hashes[0], hashes[4]);
Expand Down
64 changes: 47 additions & 17 deletions datafusion/physical-expr-common/src/binary_view_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl ArrowBytesViewSet {

/// Inserts each value from `values` into the set
pub fn insert(&mut self, values: &ArrayRef) {
fn make_payload_fn(_value: Option<&[u8]>) {}
fn make_payload_fn() {}
fn observe_payload_fn(_payload: ()) {}
self.0
.insert_if_new(values, make_payload_fn, observe_payload_fn);
Expand Down Expand Up @@ -209,7 +209,7 @@ where
make_payload_fn: MP,
observe_payload_fn: OP,
) where
MP: FnMut(Option<&[u8]>) -> V,
MP: FnMut() -> V,
OP: FnMut(V),
{
// Sanity check array type
Expand Down Expand Up @@ -248,7 +248,7 @@ where
mut make_payload_fn: MP,
mut observe_payload_fn: OP,
) where
MP: FnMut(Option<&[u8]>) -> V,
MP: FnMut() -> V,
OP: FnMut(V),
B: ByteViewType,
{
Expand All @@ -266,6 +266,35 @@ where

// Get raw views buffer for direct comparison
let input_views = values.views();
let input_buffers = values.data_buffers();

// Decode input value bytes directly from view + buffers,
// avoiding the overhead of values.value(i) accessor.
let input_value_bytes = |idx: usize| -> &[u8] {
let view = input_views[idx];
let len = view as u32;
if len <= 12 {
// Inline: bytes are stored at offset 4 in the view.
// Reference the view in input_views (not a stack copy)
// so the returned slice has a valid lifetime.
// SAFETY: input_views[idx] is valid for the function's lifetime,
// and the inline data occupies bytes 4..4+len of the u128 view.
unsafe {
let ptr = (input_views.as_ptr().add(idx)) as *const u8;
std::slice::from_raw_parts(ptr.add(4), len as usize)
}
} else {
let byte_view = ByteView::from(view);
let buf_idx = byte_view.buffer_index as usize;
let offset = byte_view.offset as usize;
// SAFETY: view comes from a valid array
unsafe {
input_buffers
.get_unchecked(buf_idx)
.get_unchecked(offset..offset + len as usize)
}
}
};

// Ensure lengths are equivalent
assert_eq!(values.len(), self.hashes_buffer.len());
Expand All @@ -279,7 +308,7 @@ where
let payload = if let Some(&(payload, _offset)) = self.null.as_ref() {
payload
} else {
let payload = make_payload_fn(None);
let payload = make_payload_fn();
let null_index = self.views.len();
self.views.push(0);
self.nulls.append_null();
Expand Down Expand Up @@ -329,8 +358,7 @@ where
} else {
&in_progress[offset..offset + stored_len]
};
let input_value: &[u8] = values.value(i).as_ref();
stored_value == input_value
stored_value == input_value_bytes(i)
})
.map(|entry| entry.payload)
};
Expand All @@ -339,11 +367,18 @@ where
payload
} else {
// no existing value, make a new one
let value: &[u8] = values.value(i).as_ref();
let payload = make_payload_fn(Some(value));

// Create view pointing to our buffers
let new_view = self.append_value(value);
let (new_view, payload) = if len <= 12 {
// Inline string: the view is self-contained, no need
// to decode bytes or copy to buffers — just reuse the
// input view directly.
self.views.push(view_u128);
self.nulls.append_non_null();
(view_u128, make_payload_fn())
} else {
let value = input_value_bytes(i);
let new_view = self.append_value(value);
(new_view, make_payload_fn())
};
let new_header = Entry {
view: new_view,
hash,
Expand Down Expand Up @@ -726,16 +761,12 @@ mod tests {
}

// insert the values into the map, recording what we did
let mut seen_new_strings = vec![];
let mut seen_indexes = vec![];
self.map.insert_if_new(
&arr,
|s| {
let value = s
.map(|s| String::from_utf8(s.to_vec()).expect("Non utf8 string"));
|| {
let index = next_index;
next_index += 1;
seen_new_strings.push(value);
TestPayload { index }
},
|payload| {
Expand All @@ -744,7 +775,6 @@ mod tests {
);

assert_eq!(actual_seen_indexes, seen_indexes);
assert_eq!(actual_new_strings, seen_new_strings);
}

/// Call `self.map.into_array()` validating that the strings are in the same
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl GroupValues for GroupValuesBytesView {
self.map.insert_if_new(
arr,
// called for each new group
|_value| {
|| {
// assign new group index on each insert
let group_idx = self.num_groups;
self.num_groups += 1;
Expand Down
Loading