diff --git a/src/StackExchange.Redis/RedisValue.cs b/src/StackExchange.Redis/RedisValue.cs index fb2a6c0b0..3d0c5714e 100644 --- a/src/StackExchange.Redis/RedisValue.cs +++ b/src/StackExchange.Redis/RedisValue.cs @@ -469,6 +469,18 @@ internal string RawString() // memory / short-blob / sequence) compare by raw bytes in any combination if (IsBlob(xType) && IsBlob(yType)) return BlobSequenceEqual(x, y); + switch (xType) + { + case StorageType.Sequence when yType == StorageType.String: + return y.RawStringEquals(x.RawSequence()); + case StorageType.String when yType == StorageType.Sequence: + return x.RawStringEquals(y.RawSequence()); + case StorageType.ByteArray or StorageType.MemoryManager or StorageType.ShortBlob when yType == StorageType.String: + return y.RawStringEquals(x.UnsafeRawSpan(out _)); + case StorageType.String when yType is StorageType.ByteArray or StorageType.MemoryManager or StorageType.ShortBlob: + return x.RawStringEquals(y.UnsafeRawSpan(out _)); + } + // otherwise (anything involving a string), compare as strings return (string?)x == (string?)y; } @@ -559,12 +571,15 @@ internal static unsafe bool Equals(byte[]? x, byte[]? y) return true; } - private static int AddHashCode(ReadOnlySpan span, int acc) + // used by RedisKey, whose equality is byte-based (unlike RedisValue, which treats non-numeric + // buffers as strings - see GetHashCode(RedisValue)) + internal static int GetHashCode(ReadOnlySpan span) { unchecked { int len = span.Length; - Debug.Assert(len > 0); + if (len == 0) return 0; + var acc = 728271210; var span64 = MemoryMarshal.Cast(span); for (int i = 0; i < span64.Length; i++) @@ -582,16 +597,43 @@ private static int AddHashCode(ReadOnlySpan span, int acc) } } - // used by RedisKey, whose equality is byte-based (unlike RedisValue, which treats non-numeric - // buffers as strings - see GetHashCode(RedisValue)) - internal static int GetHashCode(ReadOnlySpan span) - { - if (span.Length == 0) return 0; + private bool RawStringEquals(ReadOnlySpan span) + { + string s = RawString(); + var length = s.Length; + if (length == 0) return span.IsEmpty; + var maxChars = Encoding.UTF8.GetMaxCharCount(span.Length); + if (length > maxChars) return false; + + byte[]? leased = null; + var maxBytes = Encoding.UTF8.GetMaxByteCount(length); + Span bytes = maxBytes <= StackByteLimit ? stackalloc byte[maxBytes] : (leased = ArrayPool.Shared.Rent(maxBytes)); + var written = Encoding.UTF8.GetBytes(s, bytes); + var result = span.SequenceEqual(bytes.Slice(0, written)); + if (leased is not null) ArrayPool.Shared.Return(leased); + return result; + } - return AddHashCode(span, HashCodeStart); + private bool RawStringEquals(ReadOnlySequence seq) + { + string s = RawString(); + var length = s.Length; + var seqLength = seq.Length; + if (length == 0) return seqLength == 0; + if (seq.Length > int.MaxValue) return false; + var maxChars = Encoding.UTF8.GetMaxCharCount(checked((int)seqLength)); + if (length > maxChars) return false; + + byte[]? leased = null; + var maxBytes = Encoding.UTF8.GetMaxByteCount(length); + Span bytes = maxBytes <= StackByteLimit ? stackalloc byte[maxBytes] : (leased = ArrayPool.Shared.Rent(maxBytes)); + var written = Encoding.UTF8.GetBytes(s, bytes); + var result = seq.SequenceEqual(bytes.Slice(0, written)); + if (leased is not null) ArrayPool.Shared.Return(leased); + return result; } - private const int HashCodeStart = 728271210; + private const int StackByteLimit = 512; internal void AssertNotNull() { @@ -1791,7 +1833,8 @@ public bool StartsWith(ReadOnlySpan value) return buffer.Slice(0, len).StartsWith(value); case StorageType.String: var s = RawString().AsSpan(); - if (s.Length < value.Length) return false; // not enough characters to match + // BUG if Not ASCII + // if (s.Length < value.Length) return false; // not enough characters to match if (s.Length > value.Length) s = s.Slice(0, value.Length); // only need to match the prefix var maxBytes = Encoding.UTF8.GetMaxByteCount(s.Length); byte[]? lease = null;