From df86041797ec21b39379fff6eeaacd97b292b287 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Thu, 12 Mar 2026 14:23:14 +0000 Subject: [PATCH 1/3] Faster true count using AVX intrinsics Signed-off-by: Robert Kruszewski --- vortex-buffer/benches/vortex_bitbuffer.rs | 11 +- vortex-buffer/src/bit/buf.rs | 3 +- vortex-buffer/src/bit/count_ones.rs | 228 ++++++++++++++++++++++ vortex-buffer/src/bit/mod.rs | 1 + 4 files changed, 240 insertions(+), 3 deletions(-) create mode 100644 vortex-buffer/src/bit/count_ones.rs diff --git a/vortex-buffer/benches/vortex_bitbuffer.rs b/vortex-buffer/benches/vortex_bitbuffer.rs index cc84bf2c2e6..1c31c32759c 100644 --- a/vortex-buffer/benches/vortex_bitbuffer.rs +++ b/vortex-buffer/benches/vortex_bitbuffer.rs @@ -26,6 +26,11 @@ impl FromIterator for Arrow { const INPUT_SIZE: &[usize] = &[128, 1024, 2048, 16_384, 65_536]; +#[inline] +fn true_count_pattern(i: usize) -> bool { + (i.is_multiple_of(3)) ^ (i.is_multiple_of(11)) +} + #[cfg(not(codspeed))] #[divan::bench(args = INPUT_SIZE)] fn from_iter_arrow(n: usize) { @@ -160,7 +165,7 @@ fn slice_arrow_buffer(bencher: Bencher, length: usize) { #[divan::bench(args = INPUT_SIZE)] fn true_count_vortex_buffer(bencher: Bencher, length: usize) { - let buffer = BitBuffer::from_iter((0..length).map(|i| i % 2 == 0)); + let buffer = BitBuffer::from_iter((0..length).map(true_count_pattern)); bencher .with_inputs(|| &buffer) .bench_refs(|buffer| buffer.true_count()) @@ -168,7 +173,9 @@ fn true_count_vortex_buffer(bencher: Bencher, length: usize) { #[divan::bench(args = INPUT_SIZE)] fn true_count_arrow_buffer(bencher: Bencher, length: usize) { - let buffer = Arrow(BooleanBuffer::from_iter((0..length).map(|i| i % 2 == 0))); + let buffer = Arrow(BooleanBuffer::from_iter( + (0..length).map(true_count_pattern), + )); bencher .with_inputs(|| &buffer) .bench_refs(|buffer| buffer.0.count_set_bits()); diff --git a/vortex-buffer/src/bit/buf.rs b/vortex-buffer/src/bit/buf.rs index 87b4054774d..7a714111f27 100644 --- a/vortex-buffer/src/bit/buf.rs +++ b/vortex-buffer/src/bit/buf.rs @@ -21,6 +21,7 @@ use crate::bit::BitIndexIterator; use crate::bit::BitIterator; use crate::bit::BitSliceIterator; use crate::bit::UnalignedBitChunk; +use crate::bit::count_ones::count_ones; use crate::bit::get_bit_unchecked; use crate::bit::ops::bitwise_binary_op; use crate::bit::ops::bitwise_unary_op; @@ -316,7 +317,7 @@ impl BitBuffer { /// Get the number of set bits in the buffer. pub fn true_count(&self) -> usize { - self.unaligned_chunks().count_ones() + count_ones(self.buffer.as_slice(), self.offset, self.len) } /// Get the number of unset bits in the buffer. diff --git a/vortex-buffer/src/bit/count_ones.rs b/vortex-buffer/src/bit/count_ones.rs new file mode 100644 index 00000000000..fec9d72054a --- /dev/null +++ b/vortex-buffer/src/bit/count_ones.rs @@ -0,0 +1,228 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +#[cfg(target_arch = "x86_64")] +use vortex_error::VortexExpect; + +#[inline] +pub fn count_ones(bytes: &[u8], offset: usize, len: usize) -> usize { + if bytes.is_empty() { + return 0; + } + + let (head, middle, tail) = align_offset_len(bytes, offset, len); + + let mut count = head.map_or(0, |v| v.count_ones() as usize); + + if !middle.is_empty() { + count += count_ones_aligned(middle); + } + + count + tail.map_or(0, |v| v.count_ones() as usize) +} + +#[inline] +fn align_offset_len(bytes: &[u8], offset: usize, len: usize) -> (Option, &[u8], Option) { + let start_byte = offset / 8; + let start_bit = offset % 8; + let end_bit = offset + len; + let end_byte = end_bit / 8; + let head = (start_bit != 0).then(|| { + let start_len = (8 - start_bit).min(len); + mask_byte(bytes[start_byte], start_bit, start_len) + }); + + let middle_start = start_byte + usize::from(start_bit != 0); + let middle_end = end_byte; + let middle = if middle_start < middle_end { + &bytes[middle_start..middle_end] + } else { + &[] + }; + + let consumed = if start_bit != 0 { + (8 - start_bit).min(len) + } else { + 0 + } + middle.len() * 8; + let tail_len = len - consumed; + let tail = (tail_len != 0).then(|| mask_byte(bytes[middle_end], 0, tail_len)); + + (head, middle, tail) +} + +#[inline] +fn mask_byte(byte: u8, bit_offset: usize, bit_len: usize) -> u8 { + debug_assert!(bit_offset < 8); + debug_assert!(bit_len <= 8 - bit_offset); + + let shifted = byte >> bit_offset; + let mask = if bit_len == 8 { + u8::MAX + } else { + (1u8 << bit_len) - 1 + }; + + shifted & mask +} + +#[inline] +fn count_ones_aligned(bytes: &[u8]) -> usize { + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx512f") + && is_x86_feature_detected!("avx512vpopcntdq") + && bytes.len() >= 64 + { + // SAFETY: Runtime detection guarantees the required target features. + return unsafe { count_ones_aligned_avx512(bytes) }; + } + + if is_x86_feature_detected!("avx2") && bytes.len() >= 32 { + // SAFETY: Runtime detection guarantees the required target features. + return unsafe { count_ones_aligned_avx2(bytes) }; + } + } + + count_ones_aligned_scalar(bytes) +} + +#[inline] +fn count_ones_aligned_scalar(bytes: &[u8]) -> usize { + let (words, tail) = bytes.as_chunks::<8>(); + let count = words + .iter() + .map(|word| u64::from_le_bytes(*word).count_ones() as usize) + .sum::(); + + count + + tail + .iter() + .map(|byte| byte.count_ones() as usize) + .sum::() +} + +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +unsafe fn count_ones_aligned_avx2(bytes: &[u8]) -> usize { + use std::arch::x86_64::__m256i; + use std::arch::x86_64::_mm256_add_epi8; + use std::arch::x86_64::_mm256_add_epi64; + use std::arch::x86_64::_mm256_and_si256; + use std::arch::x86_64::_mm256_loadu_si256; + use std::arch::x86_64::_mm256_sad_epu8; + use std::arch::x86_64::_mm256_set1_epi8; + use std::arch::x86_64::_mm256_setr_epi8; + use std::arch::x86_64::_mm256_setzero_si256; + use std::arch::x86_64::_mm256_shuffle_epi8; + use std::arch::x86_64::_mm256_srli_epi16; + use std::arch::x86_64::_mm256_storeu_si256; + + #[inline] + unsafe fn byte_popcount(chunk: __m256i, mask: __m256i, lookup: __m256i) -> __m256i { + let lo = unsafe { _mm256_and_si256(chunk, mask) }; + let hi = unsafe { _mm256_and_si256(_mm256_srli_epi16(chunk, 4), mask) }; + unsafe { + _mm256_add_epi8( + _mm256_shuffle_epi8(lookup, lo), + _mm256_shuffle_epi8(lookup, hi), + ) + } + } + + let lookup = _mm256_setr_epi8( + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, + 3, 4, + ); + let mask = _mm256_set1_epi8(0x0f); + let zero = _mm256_setzero_si256(); + let mut accum = _mm256_setzero_si256(); + let mut index = 0; + + while index + 128 <= bytes.len() { + for lane in 0..4 { + let ptr = unsafe { bytes.as_ptr().add(index + lane * 32) }.cast::<__m256i>(); + let chunk = unsafe { _mm256_loadu_si256(ptr) }; + let counts = unsafe { byte_popcount(chunk, mask, lookup) }; + accum = _mm256_add_epi64(accum, _mm256_sad_epu8(counts, zero)); + } + index += 128; + } + + while index + 32 <= bytes.len() { + let ptr = unsafe { bytes.as_ptr().add(index) }.cast::<__m256i>(); + let chunk = unsafe { _mm256_loadu_si256(ptr) }; + let counts = unsafe { byte_popcount(chunk, mask, lookup) }; + accum = _mm256_add_epi64(accum, _mm256_sad_epu8(counts, zero)); + index += 32; + } + + let mut lanes = [0u64; 4]; + unsafe { _mm256_storeu_si256(lanes.as_mut_ptr().cast::<__m256i>(), accum) }; + + usize::try_from(lanes.iter().sum::()).vortex_expect("true_count doesn't fit in usize") + + count_ones_aligned_scalar(&bytes[index..]) +} + +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx512f,avx512vpopcntdq")] +unsafe fn count_ones_aligned_avx512(bytes: &[u8]) -> usize { + use std::arch::x86_64::__m512i; + use std::arch::x86_64::_mm512_add_epi64; + use std::arch::x86_64::_mm512_loadu_si512; + use std::arch::x86_64::_mm512_popcnt_epi64; + use std::arch::x86_64::_mm512_setzero_si512; + use std::arch::x86_64::_mm512_storeu_si512; + + let mut accum = _mm512_setzero_si512(); + let mut index = 0; + + while index + 64 <= bytes.len() { + let ptr = unsafe { bytes.as_ptr().add(index) }.cast::<__m512i>(); + let chunk = unsafe { _mm512_loadu_si512(ptr) }; + accum = _mm512_add_epi64(accum, _mm512_popcnt_epi64(chunk)); + index += 64; + } + + let mut lanes = [0u64; 8]; + unsafe { _mm512_storeu_si512(lanes.as_mut_ptr().cast::<__m512i>(), accum) }; + + usize::try_from(lanes.iter().sum::()).vortex_expect("true_count doesn't fit in usize") + + count_ones_aligned_scalar(&bytes[index..]) +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + + use crate::BitBuffer; + + #[rstest] + fn test_count_ones_matches_iteration_for_slices( + #[values( + 0usize, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, + 23, 24, 25, 26, 27, 28, 29, 30 + )] + offset: usize, + #[values( + 0usize, 1, 2, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 255, 256, 257, 513 + )] + slice_len: usize, + ) { + let len = 513; + let buf = BitBuffer::collect_bool(len + 31, |i| (i % 3 == 0) ^ (i % 11 == 0)); + + if offset + slice_len > buf.len() { + return; + } + + let sliced = buf.slice(offset..offset + slice_len); + let expected = sliced.iter().filter(|bit| *bit).count(); + + assert_eq!( + sliced.true_count(), + expected, + "offset={offset} len={slice_len}" + ); + } +} diff --git a/vortex-buffer/src/bit/mod.rs b/vortex-buffer/src/bit/mod.rs index 5ca932c0187..034be84a18c 100644 --- a/vortex-buffer/src/bit/mod.rs +++ b/vortex-buffer/src/bit/mod.rs @@ -10,6 +10,7 @@ mod arrow; mod buf; mod buf_mut; +mod count_ones; mod macros; mod ops; From 9a5d9d830262925a0fc0bc79da7688ed998dd4e8 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Fri, 13 Mar 2026 14:11:14 +0000 Subject: [PATCH 2/3] original Signed-off-by: Robert Kruszewski --- vortex-buffer/src/bit/count_ones.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vortex-buffer/src/bit/count_ones.rs b/vortex-buffer/src/bit/count_ones.rs index fec9d72054a..2fb3250b311 100644 --- a/vortex-buffer/src/bit/count_ones.rs +++ b/vortex-buffer/src/bit/count_ones.rs @@ -70,15 +70,15 @@ fn mask_byte(byte: u8, bit_offset: usize, bit_len: usize) -> u8 { fn count_ones_aligned(bytes: &[u8]) -> usize { #[cfg(target_arch = "x86_64")] { - if is_x86_feature_detected!("avx512f") + if bytes.len() >= 64 + && is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512vpopcntdq") - && bytes.len() >= 64 { // SAFETY: Runtime detection guarantees the required target features. return unsafe { count_ones_aligned_avx512(bytes) }; } - if is_x86_feature_detected!("avx2") && bytes.len() >= 32 { + if bytes.len() >= 32 && is_x86_feature_detected!("avx2") { // SAFETY: Runtime detection guarantees the required target features. return unsafe { count_ones_aligned_avx2(bytes) }; } From d3c062daee8a6f155d5814654426f76175ce6970 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Sat, 14 Mar 2026 02:04:38 +0000 Subject: [PATCH 3/3] trythis --- vortex-buffer/src/bit/count_ones.rs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/vortex-buffer/src/bit/count_ones.rs b/vortex-buffer/src/bit/count_ones.rs index 2fb3250b311..37cbdbd49c1 100644 --- a/vortex-buffer/src/bit/count_ones.rs +++ b/vortex-buffer/src/bit/count_ones.rs @@ -70,6 +70,24 @@ fn mask_byte(byte: u8, bit_offset: usize, bit_len: usize) -> u8 { fn count_ones_aligned(bytes: &[u8]) -> usize { #[cfg(target_arch = "x86_64")] { + // When the target feature is guaranteed at compile time, skip runtime detection. + #[cfg(all(target_feature = "avx512f", target_feature = "avx512vpopcntdq"))] + if bytes.len() >= 64 { + // SAFETY: Compile-time target feature guarantees availability. + return unsafe { count_ones_aligned_avx512(bytes) }; + } + + #[cfg(all( + not(all(target_feature = "avx512f", target_feature = "avx512vpopcntdq")), + target_feature = "avx2" + ))] + if bytes.len() >= 32 { + // SAFETY: Compile-time target feature guarantees availability. + return unsafe { count_ones_aligned_avx2(bytes) }; + } + + // Fall back to runtime detection when features aren't compile-time guaranteed. + #[cfg(not(all(target_feature = "avx512f", target_feature = "avx512vpopcntdq")))] if bytes.len() >= 64 && is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512vpopcntdq") @@ -78,6 +96,7 @@ fn count_ones_aligned(bytes: &[u8]) -> usize { return unsafe { count_ones_aligned_avx512(bytes) }; } + #[cfg(not(target_feature = "avx2"))] if bytes.len() >= 32 && is_x86_feature_detected!("avx2") { // SAFETY: Runtime detection guarantees the required target features. return unsafe { count_ones_aligned_avx2(bytes) };