diff --git a/BitFaster.Caching/Lfu/CmSketchCore.cs b/BitFaster.Caching/Lfu/CmSketchCore.cs index 9764c51b..3a659c64 100644 --- a/BitFaster.Caching/Lfu/CmSketchCore.cs +++ b/BitFaster.Caching/Lfu/CmSketchCore.cs @@ -37,6 +37,7 @@ public unsafe class CmSketchCore { private const long ResetMask = 0x7777777777777777L; private const long OneMask = 0x1111111111111111L; + private const nuint CacheLineAlignmentMask = 63; private long[] table; #if NET6_0_OR_GREATER @@ -152,7 +153,7 @@ private void EnsureCapacity(long maximumSize) table = GC.AllocateArray(Math.Max(BitOps.CeilingPowerOfTwo(maximum), 8) + pad, pinned); tableAddr = (long*)Unsafe.AsPointer(ref table[0]); - tableAddr = (long*)((long)tableAddr + (long)tableAddr % 64); + tableAddr = (long*)(((nuint)tableAddr + CacheLineAlignmentMask) & ~CacheLineAlignmentMask); blockMask = (int)((uint)(table.Length - pad) >> 3) - 1; } @@ -299,11 +300,22 @@ private unsafe int EstimateFrequencyAvx(T value) int counterHash = Rehash(blockHash); int block = (blockHash & blockMask) << 3; - Vector128 h = Avx2.ShiftRightLogicalVariable(Vector128.Create(counterHash).AsUInt32(), Vector128.Create(0U, 8U, 16U, 24U)).AsInt32(); - Vector128 index = Avx2.ShiftLeftLogical(Avx2.And(Avx2.ShiftRightLogical(h, 1), Vector128.Create(15)), 2); - Vector128 blockOffset = Avx2.Add(Avx2.Add(Vector128.Create(block), Avx2.And(h, Vector128.Create(1))), Vector128.Create(0, 2, 4, 6)); + int h0 = counterHash; + int h1 = counterHash >>> 8; + int h2 = counterHash >>> 16; + int h3 = counterHash >>> 24; - Vector256 indexLong = Avx2.PermuteVar8x32(Vector256.Create(index, Vector128.Zero), Vector256.Create(0, 4, 1, 5, 2, 5, 3, 7)).AsUInt64(); + int index0 = ((h0 >>> 1) & 15) << 2; + int index1 = ((h1 >>> 1) & 15) << 2; + int index2 = ((h2 >>> 1) & 15) << 2; + int index3 = ((h3 >>> 1) & 15) << 2; + + int lane0 = h0 & 1; + int lane1 = (h1 & 1) + 2; + int lane2 = h2 & 1; + int lane3 = (h3 & 1) + 2; + + Vector256 index = Vector256.Create((ulong)index0, (ulong)index1, (ulong)index2, (ulong)index3); #if NET6_0_OR_GREATER long* tablePtr = tableAddr; @@ -311,7 +323,21 @@ private unsafe int EstimateFrequencyAvx(T value) fixed (long* tablePtr = table) #endif { - Vector128 count = Avx2.PermuteVar8x32(Avx2.And(Avx2.ShiftRightLogicalVariable(Avx2.GatherVector256(tablePtr, blockOffset, 8), indexLong), Vector256.Create(0xfL)).AsInt32(), Vector256.Create(0, 2, 4, 6, 1, 3, 5, 7)) +#if NET6_0_OR_GREATER + Vector256 lower = Avx.LoadAlignedVector256(tablePtr + block); + Vector256 upper = Avx.LoadAlignedVector256(tablePtr + block + 4); +#else + Vector256 lower = Avx.LoadVector256(tablePtr + block); + Vector256 upper = Avx.LoadVector256(tablePtr + block + 4); +#endif + + Vector256 countVector = Vector256.Create( + (ulong)lower.GetElement(lane0), + (ulong)lower.GetElement(lane1), + (ulong)upper.GetElement(lane2), + (ulong)upper.GetElement(lane3)); + + Vector128 count = Avx2.PermuteVar8x32(Avx2.And(Avx2.ShiftRightLogicalVariable(countVector, index), Vector256.Create(0xfUL)).AsInt32(), Vector256.Create(0, 2, 4, 6, 1, 3, 5, 7)) .GetLower() .AsUInt16(); @@ -333,12 +359,32 @@ private unsafe void IncrementAvx(T value) int counterHash = Rehash(blockHash); int block = (blockHash & blockMask) << 3; - Vector128 h = Avx2.ShiftRightLogicalVariable(Vector128.Create(counterHash).AsUInt32(), Vector128.Create(0U, 8U, 16U, 24U)).AsInt32(); - Vector128 index = Avx2.ShiftLeftLogical(Avx2.And(Avx2.ShiftRightLogical(h, 1), Vector128.Create(15)), 2); - Vector128 blockOffset = Avx2.Add(Avx2.Add(Vector128.Create(block), Avx2.And(h, Vector128.Create(1))), Vector128.Create(0, 2, 4, 6)); + int h0 = counterHash; + int h1 = counterHash >>> 8; + int h2 = counterHash >>> 16; + int h3 = counterHash >>> 24; + + int index0 = ((h0 >>> 1) & 15) << 2; + int index1 = ((h1 >>> 1) & 15) << 2; + int index2 = ((h2 >>> 1) & 15) << 2; + int index3 = ((h3 >>> 1) & 15) << 2; + + int lane0 = h0 & 1; + int lane1 = (h1 & 1) + 2; + int lane2 = h2 & 1; + int lane3 = (h3 & 1) + 2; + + Vector256 laneOffsets = Vector256.Create(0L, 1L, 2L, 3L); + Vector256 lowerIndex = Vector256.Create((ulong)index0, (ulong)index0, (ulong)index1, (ulong)index1); + Vector256 upperIndex = Vector256.Create((ulong)index2, (ulong)index2, (ulong)index3, (ulong)index3); - Vector256 offsetLong = Avx2.PermuteVar8x32(Vector256.Create(index, Vector128.Zero), Vector256.Create(0, 4, 1, 5, 2, 5, 3, 7)).AsUInt64(); - Vector256 mask = Avx2.ShiftLeftLogicalVariable(Vector256.Create(0xfL), offsetLong); + Vector256 lowerLaneMask = Avx2.CompareEqual( + laneOffsets, + Vector256.Create((long)lane0, (long)lane0, (long)lane1, (long)lane1)); + + Vector256 upperLaneMask = Avx2.CompareEqual( + laneOffsets, + Vector256.Create((long)lane2, (long)lane2, (long)lane3, (long)lane3)); #if NET6_0_OR_GREATER long* tablePtr = tableAddr; @@ -346,18 +392,37 @@ private unsafe void IncrementAvx(T value) fixed (long* tablePtr = table) #endif { +#if NET6_0_OR_GREATER + Vector256 lower = Avx.LoadAlignedVector256(tablePtr + block); + Vector256 upper = Avx.LoadAlignedVector256(tablePtr + block + 4); +#else + Vector256 lower = Avx.LoadVector256(tablePtr + block); + Vector256 upper = Avx.LoadVector256(tablePtr + block + 4); +#endif + + Vector256 lowerMask = Avx2.And(Avx2.ShiftLeftLogicalVariable(Vector256.Create(0xfL), lowerIndex), lowerLaneMask); + Vector256 upperMask = Avx2.And(Avx2.ShiftLeftLogicalVariable(Vector256.Create(0xfL), upperIndex), upperLaneMask); + // Note masked is 'equal' - therefore use AndNot below - Vector256 masked = Avx2.CompareEqual(Avx2.And(Avx2.GatherVector256(tablePtr, blockOffset, 8), mask), mask); + Vector256 lowerMasked = Avx2.CompareEqual(Avx2.And(lower, lowerMask), lowerMask); + Vector256 upperMasked = Avx2.CompareEqual(Avx2.And(upper, upperMask), upperMask); // Mask to zero out non matches (add zero below) - first operand is NOT then AND result (order matters) - Vector256 inc = Avx2.AndNot(masked, Avx2.ShiftLeftLogicalVariable(Vector256.Create(1L), offsetLong)); + Vector256 lowerInc = Avx2.And(Avx2.AndNot(lowerMasked, Avx2.ShiftLeftLogicalVariable(Vector256.Create(1L), lowerIndex)), lowerLaneMask); + Vector256 upperInc = Avx2.And(Avx2.AndNot(upperMasked, Avx2.ShiftLeftLogicalVariable(Vector256.Create(1L), upperIndex)), upperLaneMask); + + bool wasInc = Avx2.MoveMask(Avx2.CompareEqual(Avx2.Or(lowerInc, upperInc).AsByte(), Vector256.Zero).AsByte()) != unchecked((int)(0b1111_1111_1111_1111_1111_1111_1111_1111)); - bool wasInc = Avx2.MoveMask(Avx2.CompareEqual(masked.AsByte(), Vector256.Zero).AsByte()) == unchecked((int)(0b1111_1111_1111_1111_1111_1111_1111_1111)); + lower = Avx2.Add(lower, lowerInc); + upper = Avx2.Add(upper, upperInc); - tablePtr[blockOffset.GetElement(0)] += inc.GetElement(0); - tablePtr[blockOffset.GetElement(1)] += inc.GetElement(1); - tablePtr[blockOffset.GetElement(2)] += inc.GetElement(2); - tablePtr[blockOffset.GetElement(3)] += inc.GetElement(3); +#if NET6_0_OR_GREATER + Avx.StoreAligned(tablePtr + block, lower); + Avx.StoreAligned(tablePtr + block + 4, upper); +#else + Avx.Store(tablePtr + block, lower); + Avx.Store(tablePtr + block + 4, upper); +#endif if (wasInc && (++size == sampleSize)) {