Skip to content
101 changes: 83 additions & 18 deletions BitFaster.Caching/Lfu/CmSketchCore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public unsafe class CmSketchCore<T, I>
{
private const long ResetMask = 0x7777777777777777L;
private const long OneMask = 0x1111111111111111L;
private const nuint CacheLineAlignmentMask = 63;

private long[] table;
#if NET6_0_OR_GREATER
Expand Down Expand Up @@ -152,7 +153,7 @@ private void EnsureCapacity(long maximumSize)
table = GC.AllocateArray<long>(Math.Max(BitOps.CeilingPowerOfTwo(maximum), 8) + pad, pinned);

tableAddr = (long*)Unsafe.AsPointer(ref table[0]);
tableAddr = (long*)((long)tableAddr + (long)tableAddr % 64);
tableAddr = (long*)(((nuint)tableAddr + CacheLineAlignmentMask) & ~CacheLineAlignmentMask);

blockMask = (int)((uint)(table.Length - pad) >> 3) - 1;
}
Expand Down Expand Up @@ -299,19 +300,44 @@ private unsafe int EstimateFrequencyAvx(T value)
int counterHash = Rehash(blockHash);
int block = (blockHash & blockMask) << 3;

Vector128<int> h = Avx2.ShiftRightLogicalVariable(Vector128.Create(counterHash).AsUInt32(), Vector128.Create(0U, 8U, 16U, 24U)).AsInt32();
Vector128<int> index = Avx2.ShiftLeftLogical(Avx2.And(Avx2.ShiftRightLogical(h, 1), Vector128.Create(15)), 2);
Vector128<int> blockOffset = Avx2.Add(Avx2.Add(Vector128.Create(block), Avx2.And(h, Vector128.Create(1))), Vector128.Create(0, 2, 4, 6));
int h0 = counterHash;
int h1 = counterHash >>> 8;
int h2 = counterHash >>> 16;
int h3 = counterHash >>> 24;

Vector256<ulong> indexLong = Avx2.PermuteVar8x32(Vector256.Create(index, Vector128<int>.Zero), Vector256.Create(0, 4, 1, 5, 2, 5, 3, 7)).AsUInt64();
int index0 = ((h0 >>> 1) & 15) << 2;
int index1 = ((h1 >>> 1) & 15) << 2;
int index2 = ((h2 >>> 1) & 15) << 2;
int index3 = ((h3 >>> 1) & 15) << 2;

int lane0 = h0 & 1;
int lane1 = (h1 & 1) + 2;
int lane2 = h2 & 1;
int lane3 = (h3 & 1) + 2;

Vector256<ulong> index = Vector256.Create((ulong)index0, (ulong)index1, (ulong)index2, (ulong)index3);

#if NET6_0_OR_GREATER
long* tablePtr = tableAddr;
#else
fixed (long* tablePtr = table)
#endif
{
Vector128<ushort> count = Avx2.PermuteVar8x32(Avx2.And(Avx2.ShiftRightLogicalVariable(Avx2.GatherVector256(tablePtr, blockOffset, 8), indexLong), Vector256.Create(0xfL)).AsInt32(), Vector256.Create(0, 2, 4, 6, 1, 3, 5, 7))
#if NET6_0_OR_GREATER
Vector256<long> lower = Avx.LoadAlignedVector256(tablePtr + block);
Vector256<long> upper = Avx.LoadAlignedVector256(tablePtr + block + 4);
#else
Vector256<long> lower = Avx.LoadVector256(tablePtr + block);
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot can this use the load aligned vector method instead to make it faster?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated in cfbd9ac to use aligned AVX loads/stores on the pre-aligned table pointer, and corrected the pointer alignment math so those aligned accesses are safe.

Vector256<long> upper = Avx.LoadVector256(tablePtr + block + 4);
#endif

Vector256<ulong> countVector = Vector256.Create(
(ulong)lower.GetElement(lane0),
(ulong)lower.GetElement(lane1),
(ulong)upper.GetElement(lane2),
(ulong)upper.GetElement(lane3));

Vector128<ushort> count = Avx2.PermuteVar8x32(Avx2.And(Avx2.ShiftRightLogicalVariable(countVector, index), Vector256.Create(0xfUL)).AsInt32(), Vector256.Create(0, 2, 4, 6, 1, 3, 5, 7))
.GetLower()
.AsUInt16();

Expand All @@ -333,31 +359,70 @@ private unsafe void IncrementAvx(T value)
int counterHash = Rehash(blockHash);
int block = (blockHash & blockMask) << 3;

Vector128<int> h = Avx2.ShiftRightLogicalVariable(Vector128.Create(counterHash).AsUInt32(), Vector128.Create(0U, 8U, 16U, 24U)).AsInt32();
Vector128<int> index = Avx2.ShiftLeftLogical(Avx2.And(Avx2.ShiftRightLogical(h, 1), Vector128.Create(15)), 2);
Vector128<int> blockOffset = Avx2.Add(Avx2.Add(Vector128.Create(block), Avx2.And(h, Vector128.Create(1))), Vector128.Create(0, 2, 4, 6));
int h0 = counterHash;
int h1 = counterHash >>> 8;
int h2 = counterHash >>> 16;
int h3 = counterHash >>> 24;

int index0 = ((h0 >>> 1) & 15) << 2;
int index1 = ((h1 >>> 1) & 15) << 2;
int index2 = ((h2 >>> 1) & 15) << 2;
int index3 = ((h3 >>> 1) & 15) << 2;

int lane0 = h0 & 1;
int lane1 = (h1 & 1) + 2;
int lane2 = h2 & 1;
int lane3 = (h3 & 1) + 2;

Vector256<long> laneOffsets = Vector256.Create(0L, 1L, 2L, 3L);
Vector256<ulong> lowerIndex = Vector256.Create((ulong)index0, (ulong)index0, (ulong)index1, (ulong)index1);
Vector256<ulong> upperIndex = Vector256.Create((ulong)index2, (ulong)index2, (ulong)index3, (ulong)index3);

Vector256<ulong> offsetLong = Avx2.PermuteVar8x32(Vector256.Create(index, Vector128<int>.Zero), Vector256.Create(0, 4, 1, 5, 2, 5, 3, 7)).AsUInt64();
Vector256<long> mask = Avx2.ShiftLeftLogicalVariable(Vector256.Create(0xfL), offsetLong);
Vector256<long> lowerLaneMask = Avx2.CompareEqual(
laneOffsets,
Vector256.Create((long)lane0, (long)lane0, (long)lane1, (long)lane1));

Vector256<long> upperLaneMask = Avx2.CompareEqual(
laneOffsets,
Vector256.Create((long)lane2, (long)lane2, (long)lane3, (long)lane3));

#if NET6_0_OR_GREATER
long* tablePtr = tableAddr;
#else
fixed (long* tablePtr = table)
#endif
{
#if NET6_0_OR_GREATER
Vector256<long> lower = Avx.LoadAlignedVector256(tablePtr + block);
Vector256<long> upper = Avx.LoadAlignedVector256(tablePtr + block + 4);
#else
Vector256<long> lower = Avx.LoadVector256(tablePtr + block);
Vector256<long> upper = Avx.LoadVector256(tablePtr + block + 4);
#endif

Vector256<long> lowerMask = Avx2.And(Avx2.ShiftLeftLogicalVariable(Vector256.Create(0xfL), lowerIndex), lowerLaneMask);
Vector256<long> upperMask = Avx2.And(Avx2.ShiftLeftLogicalVariable(Vector256.Create(0xfL), upperIndex), upperLaneMask);

// Note masked is 'equal' - therefore use AndNot below
Vector256<long> masked = Avx2.CompareEqual(Avx2.And(Avx2.GatherVector256(tablePtr, blockOffset, 8), mask), mask);
Vector256<long> lowerMasked = Avx2.CompareEqual(Avx2.And(lower, lowerMask), lowerMask);
Vector256<long> upperMasked = Avx2.CompareEqual(Avx2.And(upper, upperMask), upperMask);

// Mask to zero out non matches (add zero below) - first operand is NOT then AND result (order matters)
Vector256<long> inc = Avx2.AndNot(masked, Avx2.ShiftLeftLogicalVariable(Vector256.Create(1L), offsetLong));
Vector256<long> lowerInc = Avx2.And(Avx2.AndNot(lowerMasked, Avx2.ShiftLeftLogicalVariable(Vector256.Create(1L), lowerIndex)), lowerLaneMask);
Vector256<long> upperInc = Avx2.And(Avx2.AndNot(upperMasked, Avx2.ShiftLeftLogicalVariable(Vector256.Create(1L), upperIndex)), upperLaneMask);

bool wasInc = Avx2.MoveMask(Avx2.CompareEqual(Avx2.Or(lowerInc, upperInc).AsByte(), Vector256<byte>.Zero).AsByte()) != unchecked((int)(0b1111_1111_1111_1111_1111_1111_1111_1111));

bool wasInc = Avx2.MoveMask(Avx2.CompareEqual(masked.AsByte(), Vector256<byte>.Zero).AsByte()) == unchecked((int)(0b1111_1111_1111_1111_1111_1111_1111_1111));
lower = Avx2.Add(lower, lowerInc);
upper = Avx2.Add(upper, upperInc);

tablePtr[blockOffset.GetElement(0)] += inc.GetElement(0);
tablePtr[blockOffset.GetElement(1)] += inc.GetElement(1);
tablePtr[blockOffset.GetElement(2)] += inc.GetElement(2);
tablePtr[blockOffset.GetElement(3)] += inc.GetElement(3);
#if NET6_0_OR_GREATER
Avx.StoreAligned(tablePtr + block, lower);
Avx.StoreAligned(tablePtr + block + 4, upper);
#else
Avx.Store(tablePtr + block, lower);
Avx.Store(tablePtr + block + 4, upper);
#endif

if (wasInc && (++size == sampleSize))
{
Expand Down