diff --git a/BitFaster.Caching.UnitTests/Atomic/AtomicFactoryAsyncCacheAsyncAlternateLookupTests.cs b/BitFaster.Caching.UnitTests/Atomic/AtomicFactoryAsyncCacheAsyncAlternateLookupTests.cs index ba5524f2..d54ee99c 100644 --- a/BitFaster.Caching.UnitTests/Atomic/AtomicFactoryAsyncCacheAsyncAlternateLookupTests.cs +++ b/BitFaster.Caching.UnitTests/Atomic/AtomicFactoryAsyncCacheAsyncAlternateLookupTests.cs @@ -157,6 +157,25 @@ public async Task AsyncAlternateLookupGetOrAddAsyncWithArgUsesActualKeyOnMissAnd cache.TryGet("42", out var value).Should().BeTrue(); value.Should().Be("value-42"); } + + [Fact] + public async Task AsyncAlternateLookupGetOrAddAsyncWithRefStructArgUsesActualKeyOnMissAndHit() + { + var alternate = cache.GetAsyncAlternateLookup>(); + var factoryCalls = 0; + + var result = await alternate.GetOrAddAsync("42".AsSpan(), static (key, argument) => Task.FromResult($"{key}-{argument.Length}"), "xx".AsSpan()); + result.Should().Be("42-2"); + + result = await alternate.GetOrAddAsync("42".AsSpan(), (key, argument) => + { + factoryCalls++; + return Task.FromResult($"{key}-{argument.Length}"); + }, "ignored".AsSpan()); + + result.Should().Be("42-2"); + factoryCalls.Should().Be(0); + } } } #endif diff --git a/BitFaster.Caching.UnitTests/Atomic/AtomicFactoryAsyncCacheTests.cs b/BitFaster.Caching.UnitTests/Atomic/AtomicFactoryAsyncCacheTests.cs index 13afc71e..a9662486 100644 --- a/BitFaster.Caching.UnitTests/Atomic/AtomicFactoryAsyncCacheTests.cs +++ b/BitFaster.Caching.UnitTests/Atomic/AtomicFactoryAsyncCacheTests.cs @@ -69,6 +69,16 @@ public async Task WhenItemIsAddedWithArgValueIsCorrect() value.Should().Be(3); } +#if NET9_0_OR_GREATER + [Fact] + public async Task GetOrAddAsyncWithRefStructArgWhenValueMissingReturnsCreatedValue() + { + var value = await this.cache.GetOrAddAsync(1, static (key, argument) => Task.FromResult(key + argument.Length), "xx".AsSpan()); + + value.Should().Be(3); + } +#endif + #if NET9_0_OR_GREATER [Fact] public void ComparerReturnsConfiguredComparer() diff --git a/BitFaster.Caching.UnitTests/Atomic/AtomicFactoryCacheAlternateLookupTests.cs b/BitFaster.Caching.UnitTests/Atomic/AtomicFactoryCacheAlternateLookupTests.cs index af00322d..935cab14 100644 --- a/BitFaster.Caching.UnitTests/Atomic/AtomicFactoryCacheAlternateLookupTests.cs +++ b/BitFaster.Caching.UnitTests/Atomic/AtomicFactoryCacheAlternateLookupTests.cs @@ -114,6 +114,25 @@ public void AlternateLookupGetOrAddWithArgUsesActualKeyOnMissAndHit() factoryCalls.Should().Be(1); } + [Fact] + public void AlternateLookupGetOrAddWithRefStructArgUsesActualKeyOnMissAndHit() + { + var alternate = cache.GetAlternateLookup>(); + var factoryCalls = 0; + + var result = alternate.GetOrAdd("42".AsSpan(), static (key, argument) => $"{key}-{argument.Length}", "xx".AsSpan()); + result.Should().Be("42-2"); + + result = alternate.GetOrAdd("42".AsSpan(), (key, argument) => + { + factoryCalls++; + return $"{key}-{argument.Length}"; + }, "ignored".AsSpan()); + + result.Should().Be("42-2"); + factoryCalls.Should().Be(0); + } + [Fact] public void AlternateLookupTryUpdateReturnsFalseForMissingKeyAndUpdatesExistingValue() { diff --git a/BitFaster.Caching.UnitTests/Atomic/AtomicFactoryCacheTests.cs b/BitFaster.Caching.UnitTests/Atomic/AtomicFactoryCacheTests.cs index 4433ab70..6a88ae01 100644 --- a/BitFaster.Caching.UnitTests/Atomic/AtomicFactoryCacheTests.cs +++ b/BitFaster.Caching.UnitTests/Atomic/AtomicFactoryCacheTests.cs @@ -68,6 +68,14 @@ public void WhenItemIsAddedWithArgValueIsCorrect() value.Should().Be(3); } +#if NET9_0_OR_GREATER + [Fact] + public void GetOrAddWithRefStructArgWhenValueMissingReturnsCreatedValue() + { + this.cache.GetOrAdd(1, static (key, argument) => key + argument.Length, "xx".AsSpan()).Should().Be(3); + } +#endif + #if NET9_0_OR_GREATER [Fact] public void ComparerReturnsConfiguredComparer() diff --git a/BitFaster.Caching.UnitTests/CacheTests.cs b/BitFaster.Caching.UnitTests/CacheTests.cs index 3f000f3a..b93f05f8 100644 --- a/BitFaster.Caching.UnitTests/CacheTests.cs +++ b/BitFaster.Caching.UnitTests/CacheTests.cs @@ -1,8 +1,9 @@  -using System; -using System.Collections.Generic; -using System.Threading.Tasks; -using FluentAssertions; +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using BitFaster.Caching.Lru; +using FluentAssertions; using Moq; using Xunit; @@ -11,10 +12,10 @@ namespace BitFaster.Caching.UnitTests // Tests for interface default implementations. public class CacheTests { - // backcompat: remove conditional compile -#if NETCOREAPP3_0_OR_GREATER - [Fact] - public void WhenCacheInterfaceDefaultGetOrAddFallback() + // backcompat: remove conditional compile +#if NETCOREAPP3_0_OR_GREATER && !NET9_0_OR_GREATER + [Fact] + public void WhenCacheInterfaceDefaultGetOrAddFallback() { var cache = new Mock>(); cache.CallBase = true; @@ -51,8 +52,8 @@ public void WhenCacheInterfaceDefaultTryRemoveKeyValueThrows() } [Fact] - public async Task WhenAsyncCacheInterfaceDefaultGetOrAddFallback() - { + public async Task WhenAsyncCacheInterfaceDefaultGetOrAddFallback() + { var cache = new Mock>(); cache.CallBase = true; @@ -64,11 +65,14 @@ public async Task WhenAsyncCacheInterfaceDefaultGetOrAddFallback() (k, a) => Task.FromResult(k + a), 2); - r.Should().Be(3); - } - - [Fact] - public void WhenAsyncCacheInterfaceDefaultTryRemoveKeyThrows() + r.Should().Be(3); + } +#endif + +#if NETCOREAPP3_0_OR_GREATER + + [Fact] + public void WhenAsyncCacheInterfaceDefaultTryRemoveKeyThrows() { var cache = new Mock>(); cache.CallBase = true; @@ -138,6 +142,26 @@ public async Task WhenScopedAsyncCacheInterfaceDefaultGetOrAddFallback() } #if NET9_0_OR_GREATER + [Fact] + public void GetOrAddWithRefStructArgViaCacheInterfaceWhenValueMissingReturnsCreatedValue() + { + ICache cache = new ClassicLru(3); + + var value = cache.GetOrAdd(1, static (key, argument) => key + argument.Length, "xx".AsSpan()); + + value.Should().Be(3); + } + + [Fact] + public async Task GetOrAddAsyncWithRefStructArgViaAsyncCacheInterfaceWhenValueMissingReturnsCreatedValue() + { + IAsyncCache cache = new ClassicLru(3); + + var value = await cache.GetOrAddAsync(1, static (key, argument) => Task.FromResult(key + argument.Length), "xx".AsSpan()); + + value.Should().Be(3); + } + [Fact] public void WhenCacheInterfaceDefaultGetAlternateLookupThrows() { diff --git a/BitFaster.Caching.UnitTests/Lfu/ConcurrentLfuTests.cs b/BitFaster.Caching.UnitTests/Lfu/ConcurrentLfuTests.cs index 215ce4c9..6d5d5832 100644 --- a/BitFaster.Caching.UnitTests/Lfu/ConcurrentLfuTests.cs +++ b/BitFaster.Caching.UnitTests/Lfu/ConcurrentLfuTests.cs @@ -79,14 +79,24 @@ public void WhenKeyIsRequestedItIsCreatedAndCached() } [Fact] - public void WhenKeyIsRequestedWithArgItIsCreatedAndCached() - { - var result1 = cache.GetOrAdd(1, valueFactory.Create, 9); - var result2 = cache.GetOrAdd(1, valueFactory.Create, 17); + public void WhenKeyIsRequestedWithArgItIsCreatedAndCached() + { + var result1 = cache.GetOrAdd(1, valueFactory.Create, 9); + var result2 = cache.GetOrAdd(1, valueFactory.Create, 17); - valueFactory.timesCalled.Should().Be(1); - result1.Should().Be(result2); - } + valueFactory.timesCalled.Should().Be(1); + result1.Should().Be(result2); + } + +#if NET9_0_OR_GREATER + [Fact] + public void GetOrAddWithRefStructArgWhenValueMissingReturnsCreatedValue() + { + var result = cache.GetOrAdd(1, static (key, argument) => key + argument.Length, "xx".AsSpan()); + + result.Should().Be(3); + } +#endif [Fact] public async Task WhenKeyIsRequesteItIsCreatedAndCachedAsync() @@ -99,14 +109,24 @@ public async Task WhenKeyIsRequesteItIsCreatedAndCachedAsync() } [Fact] - public async Task WhenKeyIsRequestedWithArgItIsCreatedAndCachedAsync() - { - var result1 = await cache.GetOrAddAsync(1, valueFactory.CreateAsync, 9); - var result2 = await cache.GetOrAddAsync(1, valueFactory.CreateAsync, 17); + public async Task WhenKeyIsRequestedWithArgItIsCreatedAndCachedAsync() + { + var result1 = await cache.GetOrAddAsync(1, valueFactory.CreateAsync, 9); + var result2 = await cache.GetOrAddAsync(1, valueFactory.CreateAsync, 17); - valueFactory.timesCalled.Should().Be(1); - result1.Should().Be(result2); - } + valueFactory.timesCalled.Should().Be(1); + result1.Should().Be(result2); + } + +#if NET9_0_OR_GREATER + [Fact] + public async Task GetOrAddAsyncWithRefStructArgWhenValueMissingReturnsCreatedValue() + { + var result = await cache.GetOrAddAsync(1, static (key, argument) => Task.FromResult(key + argument.Length), "xx".AsSpan()); + + result.Should().Be(3); + } +#endif [Fact] public void WhenItemsAddedExceedsCapacityItemsAreDiscarded() diff --git a/BitFaster.Caching.UnitTests/Lru/ClassicLruTests.cs b/BitFaster.Caching.UnitTests/Lru/ClassicLruTests.cs index d3dc13f1..2e861c6a 100644 --- a/BitFaster.Caching.UnitTests/Lru/ClassicLruTests.cs +++ b/BitFaster.Caching.UnitTests/Lru/ClassicLruTests.cs @@ -203,6 +203,16 @@ public void WhenKeyIsRequestedWithArgItIsCreatedAndCached() result1.Should().Be(result2); } +#if NET9_0_OR_GREATER + [Fact] + public void GetOrAddWithRefStructArgWhenValueMissingReturnsCreatedValue() + { + var result = lru.GetOrAdd(1, static (key, argument) => $"{key}-{argument.Length}", "xx".AsSpan()); + + result.Should().Be("1-2"); + } +#endif + [Fact] public async Task WhenKeyIsRequesteItIsCreatedAndCachedAsync() { @@ -223,6 +233,16 @@ public async Task WhenKeyIsRequestedWithArgItIsCreatedAndCachedAsync() result1.Should().Be(result2); } +#if NET9_0_OR_GREATER + [Fact] + public async Task GetOrAddAsyncWithRefStructArgWhenValueMissingReturnsCreatedValue() + { + var result = await lru.GetOrAddAsync(1, static (key, argument) => Task.FromResult($"{key}-{argument.Length}"), "xx".AsSpan()); + + result.Should().Be("1-2"); + } +#endif + [Fact] public void WhenDifferentKeysAreRequestedValueIsCreatedForEach() { diff --git a/BitFaster.Caching.UnitTests/Lru/ConcurrentLruTests.cs b/BitFaster.Caching.UnitTests/Lru/ConcurrentLruTests.cs index fee2a917..3643001d 100644 --- a/BitFaster.Caching.UnitTests/Lru/ConcurrentLruTests.cs +++ b/BitFaster.Caching.UnitTests/Lru/ConcurrentLruTests.cs @@ -141,12 +141,30 @@ public void WhenItemIsAddedCountIsCorrect() } [Fact] - public async Task WhenItemIsAddedCountIsCorrectAsync() - { - lru.Count.Should().Be(0); - await lru.GetOrAddAsync(0, valueFactory.CreateAsync); - lru.Count.Should().Be(1); - } + public async Task WhenItemIsAddedCountIsCorrectAsync() + { + lru.Count.Should().Be(0); + await lru.GetOrAddAsync(0, valueFactory.CreateAsync); + lru.Count.Should().Be(1); + } + +#if NET9_0_OR_GREATER + [Fact] + public void GetOrAddWithRefStructArgWhenValueMissingReturnsCreatedValue() + { + var result = lru.GetOrAdd(1, static (key, argument) => $"{key}-{argument.Length}", "xx".AsSpan()); + + result.Should().Be("1-2"); + } + + [Fact] + public async Task GetOrAddAsyncWithRefStructArgWhenValueMissingReturnsCreatedValue() + { + var result = await lru.GetOrAddAsync(1, static (key, argument) => Task.FromResult($"{key}-{argument.Length}"), "xx".AsSpan()); + + result.Should().Be("1-2"); + } +#endif [Fact] public void WhenItemsAddedKeysContainsTheKeys() diff --git a/BitFaster.Caching/Atomic/AsyncAtomicFactory.cs b/BitFaster.Caching/Atomic/AsyncAtomicFactory.cs index ff48ce40..a27943c6 100644 --- a/BitFaster.Caching/Atomic/AsyncAtomicFactory.cs +++ b/BitFaster.Caching/Atomic/AsyncAtomicFactory.cs @@ -53,8 +53,8 @@ public async ValueTask GetValueAsync(K key, Func> valueFactory) } return await CreateValueAsync(key, new AsyncValueFactory(valueFactory)).ConfigureAwait(false); - } - + } + /// /// Gets the value. If is false, calling will force initialization via the parameter. /// @@ -63,15 +63,24 @@ public async ValueTask GetValueAsync(K key, Func> valueFactory) /// The value factory to use to create the value when it is not initialized. /// The value factory argument. /// The value. - public async ValueTask GetValueAsync(K key, Func> valueFactory, TArg factoryArgument) - { - if (initializer == null) - { - return value!; - } - - return await CreateValueAsync(key, new AsyncValueFactoryArg(valueFactory, factoryArgument)).ConfigureAwait(false); - } +#if NET9_0_OR_GREATER + public ValueTask GetValueAsync(K key, Func> valueFactory, TArg factoryArgument) + where TArg : allows ref struct +#else + public ValueTask GetValueAsync(K key, Func> valueFactory, TArg factoryArgument) +#endif + { + if (initializer == null) + { + return new ValueTask(value!); + } + +#if NET9_0_OR_GREATER + return CreateValueAsync(key, new RefAsyncValueFactoryArg(valueFactory, factoryArgument)); +#else + return CreateValueAsync(key, new AsyncValueFactoryArg(valueFactory, factoryArgument)); +#endif + } /// /// Gets a value indicating whether the value has been initialized. @@ -122,52 +131,95 @@ public override int GetHashCode() return ValueIfCreated!.GetHashCode(); } - private async ValueTask CreateValueAsync(K key, TFactory valueFactory) where TFactory : struct, IAsyncValueFactory - { - var init = Volatile.Read(ref initializer); - - if (init != null) - { - value = await init.CreateValueAsync(key, valueFactory).ConfigureAwait(false); - Volatile.Write(ref initializer, null); - } - - return value!; - } + private ValueTask CreateValueAsync(K key, TFactory valueFactory) +#if NET9_0_OR_GREATER + where TFactory : struct, IAsyncValueFactory, allows ref struct +#else + where TFactory : struct, IAsyncValueFactory +#endif + { + var init = Volatile.Read(ref initializer); + + if (init != null) + { + var createdValue = init.CreateValueAsync(key, valueFactory); + + if (createdValue.IsCompletedSuccessfully) + { + value = createdValue.Result; + Volatile.Write(ref initializer, null); + } + else + { + return AwaitCreatedValueAsync(createdValue); + } + } + + return new ValueTask(value!); + } + + private async ValueTask AwaitCreatedValueAsync(ValueTask createdValue) + { + value = await createdValue.ConfigureAwait(false); + Volatile.Write(ref initializer, null); + return value!; + } private class Initializer { private bool isInitialized; private Task? valueTask; - public async ValueTask CreateValueAsync(K key, TFactory valueFactory) where TFactory : struct, IAsyncValueFactory - { - var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - var synchronizedTask = DoubleCheck(tcs.Task); - - if (ReferenceEquals(synchronizedTask, tcs.Task)) - { - try - { - var value = await valueFactory.CreateAsync(key).ConfigureAwait(false); - tcs.SetResult(value); - - return value; - } - catch (Exception ex) - { - Volatile.Write(ref isInitialized, false); - tcs.SetException(ex); - - // always await the task to avoid unobserved task exceptions - normal case is that no other thread is waiting. - // this will re-throw the exception. - await tcs.Task.ConfigureAwait(false); - } - } - - return await synchronizedTask.ConfigureAwait(false); - } + public ValueTask CreateValueAsync(K key, TFactory valueFactory) +#if NET9_0_OR_GREATER + where TFactory : struct, IAsyncValueFactory, allows ref struct +#else + where TFactory : struct, IAsyncValueFactory +#endif + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var synchronizedTask = DoubleCheck(tcs.Task); + + if (ReferenceEquals(synchronizedTask, tcs.Task)) + { + Task factoryTask; + + try + { + factoryTask = valueFactory.CreateAsync(key); + } + catch (Exception ex) + { + Volatile.Write(ref isInitialized, false); + tcs.SetException(ex); + return new ValueTask(tcs.Task); + } + + return CompleteSynchronizedTaskAsync(factoryTask, tcs); + } + + return new ValueTask(synchronizedTask); + } + + private async ValueTask CompleteSynchronizedTaskAsync(Task factoryTask, TaskCompletionSource tcs) + { + try + { + var createdValue = await factoryTask.ConfigureAwait(false); + tcs.SetResult(createdValue); + return createdValue; + } + catch (Exception ex) + { + Volatile.Write(ref isInitialized, false); + tcs.SetException(ex); + + // always await the task to avoid unobserved task exceptions - normal case is that no other thread is waiting. + // this will re-throw the exception. + return await tcs.Task.ConfigureAwait(false); + } + } #pragma warning disable CA2002 // Do not lock on objects with weak identity private Task DoubleCheck(Task value) diff --git a/BitFaster.Caching/Atomic/AtomicFactory.cs b/BitFaster.Caching/Atomic/AtomicFactory.cs index 8aaf8e3f..73179352 100644 --- a/BitFaster.Caching/Atomic/AtomicFactory.cs +++ b/BitFaster.Caching/Atomic/AtomicFactory.cs @@ -53,8 +53,8 @@ public V GetValue(K key, Func valueFactory) } return CreateValue(key, new ValueFactory(valueFactory)); - } - + } + /// /// Gets the value. If is false, calling will force initialization via the parameter. /// @@ -63,15 +63,24 @@ public V GetValue(K key, Func valueFactory) /// The value factory to use to create the value when it is not initialized. /// The value factory argument. /// The value. - public V GetValue(K key, Func valueFactory, TArg factoryArgument) - { - if (initializer == null) - { - return value!; - } - - return CreateValue(key, new ValueFactoryArg(valueFactory, factoryArgument)); - } +#if NET9_0_OR_GREATER + public V GetValue(K key, Func valueFactory, TArg factoryArgument) + where TArg : allows ref struct +#else + public V GetValue(K key, Func valueFactory, TArg factoryArgument) +#endif + { + if (initializer == null) + { + return value!; + } + +#if NET9_0_OR_GREATER + return CreateValue(key, new RefValueFactoryArg(valueFactory, factoryArgument)); +#else + return CreateValue(key, new ValueFactoryArg(valueFactory, factoryArgument)); +#endif + } /// /// Gets a value indicating whether the value has been initialized. @@ -106,8 +115,13 @@ public V? ValueIfCreated /// This mitigates lock convoys where many queued threads will fail slowly one by one, introducing delays /// and multiplying the number of calls to the failing resource. /// - private V CreateValue(K key, TFactory valueFactory) where TFactory : struct, IValueFactory - { + private V CreateValue(K key, TFactory valueFactory) +#if NET9_0_OR_GREATER + where TFactory : struct, IValueFactory, allows ref struct +#else + where TFactory : struct, IValueFactory +#endif + { var init = Volatile.Read(ref initializer); if (init != null) @@ -163,8 +177,13 @@ private class Initializer private V? value; private ExceptionDispatchInfo? exceptionDispatch; - public V CreateValue(K key, TFactory valueFactory) where TFactory : struct, IValueFactory - { + public V CreateValue(K key, TFactory valueFactory) +#if NET9_0_OR_GREATER + where TFactory : struct, IValueFactory, allows ref struct +#else + where TFactory : struct, IValueFactory +#endif + { lock (this) { if (isInitialized) diff --git a/BitFaster.Caching/Atomic/AtomicFactoryAsyncCache.cs b/BitFaster.Caching/Atomic/AtomicFactoryAsyncCache.cs index fcfe26c8..04569399 100644 --- a/BitFaster.Caching/Atomic/AtomicFactoryAsyncCache.cs +++ b/BitFaster.Caching/Atomic/AtomicFactoryAsyncCache.cs @@ -79,8 +79,8 @@ public ValueTask GetOrAddAsync(K key, Func> valueFactory) { var synchronized = cache.GetOrAdd(key, _ => new AsyncAtomicFactory()); return synchronized.GetValueAsync(key, valueFactory); - } - + } + /// /// Adds a key/value pair to the cache if the key does not already exist. Returns the new value, or the /// existing value if the key already exists. @@ -90,10 +90,15 @@ public ValueTask GetOrAddAsync(K key, Func> valueFactory) /// The factory function used to asynchronously generate a value for the key. /// An argument value to pass into valueFactory. /// A task that represents the asynchronous GetOrAdd operation. - public ValueTask GetOrAddAsync(K key, Func> valueFactory, TArg factoryArgument) - { - var synchronized = cache.GetOrAdd(key, _ => new AsyncAtomicFactory()); - return synchronized.GetValueAsync(key, valueFactory, factoryArgument); +#if NET9_0_OR_GREATER + public ValueTask GetOrAddAsync(K key, Func> valueFactory, TArg factoryArgument) + where TArg : allows ref struct +#else + public ValueTask GetOrAddAsync(K key, Func> valueFactory, TArg factoryArgument) +#endif + { + var synchronized = cache.GetOrAdd(key, _ => new AsyncAtomicFactory()); + return synchronized.GetValueAsync(key, valueFactory, factoryArgument); } /// @@ -237,9 +242,10 @@ public ValueTask GetOrAddAsync(TAlternateKey key, Func> valueFacto return synchronized.GetValueAsync(actualKey, valueFactory); } - public ValueTask GetOrAddAsync(TAlternateKey key, Func> valueFactory, TArg factoryArgument) - { - var synchronized = inner.GetOrAdd(key, _ => new AsyncAtomicFactory()); + public ValueTask GetOrAddAsync(TAlternateKey key, Func> valueFactory, TArg factoryArgument) + where TArg : allows ref struct + { + var synchronized = inner.GetOrAdd(key, _ => new AsyncAtomicFactory()); if (synchronized.IsValueCreated) { diff --git a/BitFaster.Caching/Atomic/AtomicFactoryCache.cs b/BitFaster.Caching/Atomic/AtomicFactoryCache.cs index 5a0e6cec..9669c0bb 100644 --- a/BitFaster.Caching/Atomic/AtomicFactoryCache.cs +++ b/BitFaster.Caching/Atomic/AtomicFactoryCache.cs @@ -79,8 +79,8 @@ public V GetOrAdd(K key, Func valueFactory) { var atomicFactory = cache.GetOrAdd(key, _ => new AtomicFactory()); return atomicFactory.GetValue(key, valueFactory); - } - + } + /// /// Adds a key/value pair to the cache if the key does not already exist. Returns the new value, or the /// existing value if the key already exists. @@ -91,10 +91,15 @@ public V GetOrAdd(K key, Func valueFactory) /// An argument value to pass into valueFactory. /// The value for the key. This will be either the existing value for the key if the key is already /// in the cache, or the new value if the key was not in the cache. - public V GetOrAdd(K key, Func valueFactory, TArg factoryArgument) - { - var atomicFactory = cache.GetOrAdd(key, _ => new AtomicFactory()); - return atomicFactory.GetValue(key, valueFactory, factoryArgument); +#if NET9_0_OR_GREATER + public V GetOrAdd(K key, Func valueFactory, TArg factoryArgument) + where TArg : allows ref struct +#else + public V GetOrAdd(K key, Func valueFactory, TArg factoryArgument) +#endif + { + var atomicFactory = cache.GetOrAdd(key, _ => new AtomicFactory()); + return atomicFactory.GetValue(key, valueFactory, factoryArgument); } /// @@ -255,9 +260,10 @@ public V GetOrAdd(TAlternateKey key, Func valueFactory) return atomicFactory.GetValue(actualKey, valueFactory); } - public V GetOrAdd(TAlternateKey key, Func valueFactory, TArg factoryArgument) - { - var atomicFactory = inner.GetOrAdd(key, _ => new AtomicFactory()); + public V GetOrAdd(TAlternateKey key, Func valueFactory, TArg factoryArgument) + where TArg : allows ref struct + { + var atomicFactory = inner.GetOrAdd(key, _ => new AtomicFactory()); if (atomicFactory.IsValueCreated) { diff --git a/BitFaster.Caching/Atomic/ConcurrentDictionaryExtensions.cs b/BitFaster.Caching/Atomic/ConcurrentDictionaryExtensions.cs index f332435f..ed2692bf 100644 --- a/BitFaster.Caching/Atomic/ConcurrentDictionaryExtensions.cs +++ b/BitFaster.Caching/Atomic/ConcurrentDictionaryExtensions.cs @@ -33,9 +33,12 @@ public static V GetOrAdd(this ConcurrentDictionary> /// The function used to generate a value for the key. /// An argument value to pass into valueFactory. /// The value for the key. This will be either the existing value for the key if the key is already in the dictionary, or the new value if the key was not in the dictionary. - public static V GetOrAdd(this ConcurrentDictionary> dictionary, K key, Func valueFactory, TArg factoryArgument) - where K : notnull - { + public static V GetOrAdd(this ConcurrentDictionary> dictionary, K key, Func valueFactory, TArg factoryArgument) + where K : notnull +#if NET9_0_OR_GREATER + where TArg : allows ref struct +#endif + { var atomicFactory = dictionary.GetOrAdd(key, _ => new AtomicFactory()); return atomicFactory.GetValue(key, valueFactory, factoryArgument); } @@ -62,9 +65,12 @@ public static ValueTask GetOrAddAsync(this ConcurrentDictionaryThe function used to generate a value for the key. /// An argument value to pass into valueFactory. /// The value for the key. This will be either the existing value for the key if the key is already in the dictionary, or the new value if the key was not in the dictionary. - public static ValueTask GetOrAddAsync(this ConcurrentDictionary> dictionary, K key, Func> valueFactory, TArg factoryArgument) - where K : notnull - { + public static ValueTask GetOrAddAsync(this ConcurrentDictionary> dictionary, K key, Func> valueFactory, TArg factoryArgument) + where K : notnull +#if NET9_0_OR_GREATER + where TArg : allows ref struct +#endif + { var asyncAtomicFactory = dictionary.GetOrAdd(key, _ => new AsyncAtomicFactory()); return asyncAtomicFactory.GetValueAsync(key, valueFactory, factoryArgument); } diff --git a/BitFaster.Caching/IAlternateLookup.cs b/BitFaster.Caching/IAlternateLookup.cs index 1c8a1778..beb67e11 100644 --- a/BitFaster.Caching/IAlternateLookup.cs +++ b/BitFaster.Caching/IAlternateLookup.cs @@ -62,7 +62,8 @@ public interface IAlternateLookup /// The value factory, invoked with the actual cache key when a value must be created. /// The factory argument. /// The cached value. - TValue GetOrAdd(TAlternateKey key, Func valueFactory, TArg factoryArgument); + TValue GetOrAdd(TAlternateKey key, Func valueFactory, TArg factoryArgument) + where TArg : allows ref struct; } } #endif diff --git a/BitFaster.Caching/IAsyncAlternateLookup.cs b/BitFaster.Caching/IAsyncAlternateLookup.cs index a8b2efaf..3ddc131c 100644 --- a/BitFaster.Caching/IAsyncAlternateLookup.cs +++ b/BitFaster.Caching/IAsyncAlternateLookup.cs @@ -63,7 +63,8 @@ public interface IAsyncAlternateLookup /// The factory function used to asynchronously generate a value, invoked with the actual cache key. /// An argument value to pass into valueFactory. /// A task that represents the asynchronous GetOrAdd operation. - ValueTask GetOrAddAsync(TAlternateKey key, Func> valueFactory, TArg factoryArgument); + ValueTask GetOrAddAsync(TAlternateKey key, Func> valueFactory, TArg factoryArgument) + where TArg : allows ref struct; } } #endif diff --git a/BitFaster.Caching/IAsyncCache.cs b/BitFaster.Caching/IAsyncCache.cs index ce65cc18..0d776d2f 100644 --- a/BitFaster.Caching/IAsyncCache.cs +++ b/BitFaster.Caching/IAsyncCache.cs @@ -73,7 +73,12 @@ public interface IAsyncCache : IEnumerable> /// An argument value to pass into valueFactory. /// A task that represents the asynchronous GetOrAdd operation. /// The default implementation given here is the fallback that provides backwards compatibility for classes that implement ICache on prior versions - ValueTask GetOrAddAsync(K key, Func> valueFactory, TArg factoryArgument) => this.GetOrAddAsync(key, k => valueFactory(k, factoryArgument)); +#if NET9_0_OR_GREATER + ValueTask GetOrAddAsync(K key, Func> valueFactory, TArg factoryArgument) + where TArg : allows ref struct; +#else + ValueTask GetOrAddAsync(K key, Func> valueFactory, TArg factoryArgument) => this.GetOrAddAsync(key, k => valueFactory(k, factoryArgument)); +#endif /// /// Attempts to remove and return the value that has the specified key. diff --git a/BitFaster.Caching/ICache.cs b/BitFaster.Caching/ICache.cs index 97068348..adbddc7e 100644 --- a/BitFaster.Caching/ICache.cs +++ b/BitFaster.Caching/ICache.cs @@ -74,7 +74,12 @@ public interface ICache : IEnumerable> /// The value for the key. This will be either the existing value for the key if the key is already /// in the cache, or the new value if the key was not in the cache. /// The default implementation given here is the fallback that provides backwards compatibility for classes that implement ICache on prior versions - V GetOrAdd(K key, Func valueFactory, TArg factoryArgument) => this.GetOrAdd(key, k => valueFactory(k, factoryArgument)); +#if NET9_0_OR_GREATER + V GetOrAdd(K key, Func valueFactory, TArg factoryArgument) + where TArg : allows ref struct; +#else + V GetOrAdd(K key, Func valueFactory, TArg factoryArgument) => this.GetOrAdd(key, k => valueFactory(k, factoryArgument)); +#endif /// /// Attempts to remove and return the value that has the specified key. diff --git a/BitFaster.Caching/IValueFactory.cs b/BitFaster.Caching/IValueFactory.cs index e4550931..9d26ff53 100644 --- a/BitFaster.Caching/IValueFactory.cs +++ b/BitFaster.Caching/IValueFactory.cs @@ -71,6 +71,26 @@ public V Create(K key) } } +#if NET9_0_OR_GREATER + internal readonly ref struct RefValueFactoryArg : IValueFactory + where TArg : allows ref struct + { + private readonly Func factory; + private readonly TArg arg; + + public RefValueFactoryArg(Func factory, TArg arg) + { + this.factory = factory; + this.arg = arg; + } + + public V Create(K key) + { + return this.factory(key, this.arg); + } + } +#endif + /// /// Represents an async cache value factory. /// @@ -137,4 +157,24 @@ public Task CreateAsync(K key) return this.factory(key, arg); } } + +#if NET9_0_OR_GREATER + internal readonly ref struct RefAsyncValueFactoryArg : IAsyncValueFactory + where TArg : allows ref struct + { + private readonly Func> factory; + private readonly TArg arg; + + public RefAsyncValueFactoryArg(Func> factory, TArg arg) + { + this.factory = factory; + this.arg = arg; + } + + public Task CreateAsync(K key) + { + return this.factory(key, this.arg); + } + } +#endif } diff --git a/BitFaster.Caching/Lfu/ConcurrentLfu.cs b/BitFaster.Caching/Lfu/ConcurrentLfu.cs index 627cd913..2beae4c1 100644 --- a/BitFaster.Caching/Lfu/ConcurrentLfu.cs +++ b/BitFaster.Caching/Lfu/ConcurrentLfu.cs @@ -127,25 +127,35 @@ public void Clear() public V GetOrAdd(K key, Func valueFactory) { return core.GetOrAdd(key, valueFactory); - } - + } + /// - public V GetOrAdd(K key, Func valueFactory, TArg factoryArgument) - { - return core.GetOrAdd(key, valueFactory, factoryArgument); - } +#if NET9_0_OR_GREATER + public V GetOrAdd(K key, Func valueFactory, TArg factoryArgument) + where TArg : allows ref struct +#else + public V GetOrAdd(K key, Func valueFactory, TArg factoryArgument) +#endif + { + return core.GetOrAdd(key, valueFactory, factoryArgument); + } /// public ValueTask GetOrAddAsync(K key, Func> valueFactory) { return core.GetOrAddAsync(key, valueFactory); - } - + } + /// - public ValueTask GetOrAddAsync(K key, Func> valueFactory, TArg factoryArgument) - { - return core.GetOrAddAsync(key, valueFactory, factoryArgument); - } +#if NET9_0_OR_GREATER + public ValueTask GetOrAddAsync(K key, Func> valueFactory, TArg factoryArgument) + where TArg : allows ref struct +#else + public ValueTask GetOrAddAsync(K key, Func> valueFactory, TArg factoryArgument) +#endif + { + return core.GetOrAddAsync(key, valueFactory, factoryArgument); + } /// public void Trim(int itemCount) diff --git a/BitFaster.Caching/Lfu/ConcurrentLfuCore.cs b/BitFaster.Caching/Lfu/ConcurrentLfuCore.cs index abf1fc84..e0efc95c 100644 --- a/BitFaster.Caching/Lfu/ConcurrentLfuCore.cs +++ b/BitFaster.Caching/Lfu/ConcurrentLfuCore.cs @@ -216,8 +216,13 @@ public V GetOrAdd(K key, Func valueFactory) } } - public V GetOrAdd(K key, Func valueFactory, TArg factoryArgument) - { +#if NET9_0_OR_GREATER + public V GetOrAdd(K key, Func valueFactory, TArg factoryArgument) + where TArg : allows ref struct +#else + public V GetOrAdd(K key, Func valueFactory, TArg factoryArgument) +#endif + { while (true) { if (this.TryGet(key, out V? value)) @@ -250,22 +255,56 @@ public async ValueTask GetOrAddAsync(K key, Func> valueFactory) } } - public async ValueTask GetOrAddAsync(K key, Func> valueFactory, TArg factoryArgument) - { - while (true) - { - if (this.TryGet(key, out V? value)) - { - return value; - } - - value = await valueFactory(key, factoryArgument).ConfigureAwait(false); - if (this.TryAdd(key, value)) - { - return value; - } - } - } +#if NET9_0_OR_GREATER + public ValueTask GetOrAddAsync(K key, Func> valueFactory, TArg factoryArgument) + where TArg : allows ref struct +#else + public async ValueTask GetOrAddAsync(K key, Func> valueFactory, TArg factoryArgument) +#endif + { +#if NET9_0_OR_GREATER + if (this.TryGet(key, out V? value)) + { + return new ValueTask(value); + } + + return GetOrAddAsyncCore(key, valueFactory(key, factoryArgument)); +#else + while (true) + { + if (this.TryGet(key, out V? value)) + { + return value; + } + + value = await valueFactory(key, factoryArgument).ConfigureAwait(false); + if (this.TryAdd(key, value)) + { + return value; + } + } +#endif + } + +#if NET9_0_OR_GREATER + private async ValueTask GetOrAddAsyncCore(K key, Task valueTask) + { + var value = await valueTask.ConfigureAwait(false); + + while (true) + { + if (this.TryAdd(key, value)) + { + return value; + } + + if (this.TryGet(key, out V? existing)) + { + return existing; + } + } + } +#endif public bool TryGet(K key, [MaybeNullWhen(false)] out V value) { @@ -1130,8 +1169,9 @@ public V GetOrAdd(TAlternateKey key, Func valueFactory) } } - public V GetOrAdd(TAlternateKey key, Func valueFactory, TArg factoryArgument) - { + public V GetOrAdd(TAlternateKey key, Func valueFactory, TArg factoryArgument) + where TArg : allows ref struct + { while (true) { if (this.TryGet(key, out var value)) @@ -1162,8 +1202,9 @@ public ValueTask GetOrAddAsync(TAlternateKey key, Func> valueFacto return GetOrAddAsyncSlow(actualKey, task); } - public ValueTask GetOrAddAsync(TAlternateKey key, Func> valueFactory, TArg factoryArgument) - { + public ValueTask GetOrAddAsync(TAlternateKey key, Func> valueFactory, TArg factoryArgument) + where TArg : allows ref struct + { if (this.TryGet(key, out var value)) { return new ValueTask(value); diff --git a/BitFaster.Caching/Lfu/ConcurrentTLfu.cs b/BitFaster.Caching/Lfu/ConcurrentTLfu.cs index 363885de..5d2a4844 100644 --- a/BitFaster.Caching/Lfu/ConcurrentTLfu.cs +++ b/BitFaster.Caching/Lfu/ConcurrentTLfu.cs @@ -78,25 +78,35 @@ public void Clear() public V GetOrAdd(K key, Func valueFactory) { return core.GetOrAdd(key, valueFactory); - } - + } + /// - public V GetOrAdd(K key, Func valueFactory, TArg factoryArgument) - { - return core.GetOrAdd(key, valueFactory, factoryArgument); - } +#if NET9_0_OR_GREATER + public V GetOrAdd(K key, Func valueFactory, TArg factoryArgument) + where TArg : allows ref struct +#else + public V GetOrAdd(K key, Func valueFactory, TArg factoryArgument) +#endif + { + return core.GetOrAdd(key, valueFactory, factoryArgument); + } /// public ValueTask GetOrAddAsync(K key, Func> valueFactory) { return core.GetOrAddAsync(key, valueFactory); - } - + } + /// - public ValueTask GetOrAddAsync(K key, Func> valueFactory, TArg factoryArgument) - { - return core.GetOrAddAsync(key, valueFactory, factoryArgument); - } +#if NET9_0_OR_GREATER + public ValueTask GetOrAddAsync(K key, Func> valueFactory, TArg factoryArgument) + where TArg : allows ref struct +#else + public ValueTask GetOrAddAsync(K key, Func> valueFactory, TArg factoryArgument) +#endif + { + return core.GetOrAddAsync(key, valueFactory, factoryArgument); + } /// public void Trim(int itemCount) diff --git a/BitFaster.Caching/Lru/ClassicLru.cs b/BitFaster.Caching/Lru/ClassicLru.cs index b132ca40..44716c08 100644 --- a/BitFaster.Caching/Lru/ClassicLru.cs +++ b/BitFaster.Caching/Lru/ClassicLru.cs @@ -172,8 +172,8 @@ public V GetOrAdd(K key, Func valueFactory) } return this.GetOrAdd(key, valueFactory); - } - + } + /// /// Adds a key/value pair to the cache if the key does not already exist. Returns the new value, or the /// existing value if the key already exists. @@ -184,8 +184,13 @@ public V GetOrAdd(K key, Func valueFactory) /// An argument value to pass into valueFactory. /// The value for the key. This will be either the existing value for the key if the key is already /// in the cache, or the new value if the key was not in the cache. - public V GetOrAdd(K key, Func valueFactory, TArg factoryArgument) - { +#if NET9_0_OR_GREATER + public V GetOrAdd(K key, Func valueFactory, TArg factoryArgument) + where TArg : allows ref struct +#else + public V GetOrAdd(K key, Func valueFactory, TArg factoryArgument) +#endif + { if (this.TryGet(key, out var value)) { return value; @@ -217,8 +222,8 @@ public async ValueTask GetOrAddAsync(K key, Func> valueFactory) } return await this.GetOrAddAsync(key, valueFactory).ConfigureAwait(false); - } - + } + /// /// Adds a key/value pair to the cache if the key does not already exist. Returns the new value, or the /// existing value if the key already exists. @@ -228,22 +233,51 @@ public async ValueTask GetOrAddAsync(K key, Func> valueFactory) /// The factory function used to asynchronously generate a value for the key. /// An argument value to pass into valueFactory. /// A task that represents the asynchronous GetOrAdd operation. - public async ValueTask GetOrAddAsync(K key, Func> valueFactory, TArg factoryArgument) - { - if (this.TryGet(key, out var value)) - { - return value; - } - - value = await valueFactory(key, factoryArgument).ConfigureAwait(false); - - if (TryAdd(key, value)) - { - return value; - } - - return await this.GetOrAddAsync(key, valueFactory, factoryArgument).ConfigureAwait(false); - } +#if NET9_0_OR_GREATER + public ValueTask GetOrAddAsync(K key, Func> valueFactory, TArg factoryArgument) + where TArg : allows ref struct +#else + public async ValueTask GetOrAddAsync(K key, Func> valueFactory, TArg factoryArgument) +#endif + { + if (this.TryGet(key, out var value)) + { + return new ValueTask(value); + } + +#if NET9_0_OR_GREATER + return GetOrAddAsyncCore(key, valueFactory(key, factoryArgument)); +#else + value = await valueFactory(key, factoryArgument).ConfigureAwait(false); + + if (TryAdd(key, value)) + { + return value; + } + + return await this.GetOrAddAsync(key, valueFactory, factoryArgument).ConfigureAwait(false); +#endif + } + +#if NET9_0_OR_GREATER + private async ValueTask GetOrAddAsyncCore(K key, Task valueTask) + { + var value = await valueTask.ConfigureAwait(false); + + while (true) + { + if (TryAdd(key, value)) + { + return value; + } + + if (this.TryGet(key, out var existing)) + { + return existing; + } + } + } +#endif /// /// Attempts to remove the specified key value pair. @@ -622,8 +656,9 @@ public V GetOrAdd(TAlternateKey key, Func valueFactory) } } - public V GetOrAdd(TAlternateKey key, Func valueFactory, TArg factoryArgument) - { + public V GetOrAdd(TAlternateKey key, Func valueFactory, TArg factoryArgument) + where TArg : allows ref struct + { while (true) { if (this.TryGet(key, out var value)) @@ -654,8 +689,9 @@ public ValueTask GetOrAddAsync(TAlternateKey key, Func> valueFacto return GetOrAddAsyncSlow(actualKey, task); } - public ValueTask GetOrAddAsync(TAlternateKey key, Func> valueFactory, TArg factoryArgument) - { + public ValueTask GetOrAddAsync(TAlternateKey key, Func> valueFactory, TArg factoryArgument) + where TArg : allows ref struct + { if (this.TryGet(key, out var value)) { return new ValueTask(value); diff --git a/BitFaster.Caching/Lru/ConcurrentLruCore.cs b/BitFaster.Caching/Lru/ConcurrentLruCore.cs index cd8f6cd8..ebc321e2 100644 --- a/BitFaster.Caching/Lru/ConcurrentLruCore.cs +++ b/BitFaster.Caching/Lru/ConcurrentLruCore.cs @@ -224,8 +224,8 @@ public V GetOrAdd(K key, Func valueFactory) return value; } } - } - + } + /// /// Adds a key/value pair to the cache if the key does not already exist. Returns the new value, or the /// existing value if the key already exists. @@ -236,8 +236,13 @@ public V GetOrAdd(K key, Func valueFactory) /// An argument value to pass into valueFactory. /// The value for the key. This will be either the existing value for the key if the key is already /// in the cache, or the new value if the key was not in the cache. - public V GetOrAdd(K key, Func valueFactory, TArg factoryArgument) - { +#if NET9_0_OR_GREATER + public V GetOrAdd(K key, Func valueFactory, TArg factoryArgument) + where TArg : allows ref struct +#else + public V GetOrAdd(K key, Func valueFactory, TArg factoryArgument) +#endif + { while (true) { if (this.TryGet(key, out var value)) @@ -274,8 +279,8 @@ public async ValueTask GetOrAddAsync(K key, Func> valueFactory) return value; } } - } - + } + /// /// Adds a key/value pair to the cache if the key does not already exist. Returns the new value, or the /// existing value if the key already exists. @@ -285,24 +290,58 @@ public async ValueTask GetOrAddAsync(K key, Func> valueFactory) /// The factory function used to asynchronously generate a value for the key. /// An argument value to pass into valueFactory. /// A task that represents the asynchronous GetOrAdd operation. - public async ValueTask GetOrAddAsync(K key, Func> valueFactory, TArg factoryArgument) - { - while (true) - { - if (this.TryGet(key, out var value)) - { - return value; - } - - // The value factory may be called concurrently for the same key, but the first write to the dictionary wins. - value = await valueFactory(key, factoryArgument).ConfigureAwait(false); - - if (TryAdd(key, value)) - { - return value; - } - } - } +#if NET9_0_OR_GREATER + public ValueTask GetOrAddAsync(K key, Func> valueFactory, TArg factoryArgument) + where TArg : allows ref struct +#else + public async ValueTask GetOrAddAsync(K key, Func> valueFactory, TArg factoryArgument) +#endif + { +#if NET9_0_OR_GREATER + if (this.TryGet(key, out var value)) + { + return new ValueTask(value); + } + + return GetOrAddAsyncCore(key, valueFactory(key, factoryArgument)); +#else + while (true) + { + if (this.TryGet(key, out var value)) + { + return value; + } + + // The value factory may be called concurrently for the same key, but the first write to the dictionary wins. + value = await valueFactory(key, factoryArgument).ConfigureAwait(false); + + if (TryAdd(key, value)) + { + return value; + } + } +#endif + } + +#if NET9_0_OR_GREATER + private async ValueTask GetOrAddAsyncCore(K key, Task valueTask) + { + var value = await valueTask.ConfigureAwait(false); + + while (true) + { + if (TryAdd(key, value)) + { + return value; + } + + if (this.TryGet(key, out var existing)) + { + return existing; + } + } + } +#endif /// /// Attempts to remove the specified key value pair. @@ -1056,8 +1095,9 @@ public V GetOrAdd(TAlternateKey key, Func valueFactory) } } - public V GetOrAdd(TAlternateKey key, Func valueFactory, TArg factoryArgument) - { + public V GetOrAdd(TAlternateKey key, Func valueFactory, TArg factoryArgument) + where TArg : allows ref struct + { while (true) { if (this.TryGet(key, out var value)) @@ -1088,8 +1128,9 @@ public ValueTask GetOrAddAsync(TAlternateKey key, Func> valueFacto return GetOrAddAsyncSlow(actualKey, task); } - public ValueTask GetOrAddAsync(TAlternateKey key, Func> valueFactory, TArg factoryArgument) - { + public ValueTask GetOrAddAsync(TAlternateKey key, Func> valueFactory, TArg factoryArgument) + where TArg : allows ref struct + { if (this.TryGet(key, out var value)) { return new ValueTask(value);