diff --git a/BitFaster.Caching.UnitTests/Atomic/AtomicFactoryCacheAlternateLookupTests.cs b/BitFaster.Caching.UnitTests/Atomic/AtomicFactoryCacheAlternateLookupTests.cs index 3811d8c0..af00322d 100644 --- a/BitFaster.Caching.UnitTests/Atomic/AtomicFactoryCacheAlternateLookupTests.cs +++ b/BitFaster.Caching.UnitTests/Atomic/AtomicFactoryCacheAlternateLookupTests.cs @@ -77,21 +77,19 @@ public void AlternateLookupGetOrAddUsesActualKeyOnMissAndHit() var factoryCalls = 0; ReadOnlySpan key = "42"; - alternate.GetOrAdd(key, key => + alternate.GetOrAdd(key, k => { factoryCalls++; - return $"value-{key}"; + return $"value-{k}"; }).Should().Be("value-42"); - alternate.GetOrAdd(key, (_, prefix) => + alternate.GetOrAdd(key, k => { factoryCalls++; - return prefix; - }, "unused").Should().Be("value-42"); + return "1"; + }).Should().Be("value-42"); factoryCalls.Should().Be(1); - cache.TryGet("42", out var value).Should().BeTrue(); - value.Should().Be("value-42"); } [Fact] @@ -114,8 +112,6 @@ public void AlternateLookupGetOrAddWithArgUsesActualKeyOnMissAndHit() }, "unused").Should().Be("value-42"); factoryCalls.Should().Be(1); - cache.TryGet("42", out var value).Should().BeTrue(); - value.Should().Be("value-42"); } [Fact] diff --git a/BitFaster.Caching/Atomic/AtomicFactoryCache.cs b/BitFaster.Caching/Atomic/AtomicFactoryCache.cs index ee47fb90..5a0e6cec 100644 --- a/BitFaster.Caching/Atomic/AtomicFactoryCache.cs +++ b/BitFaster.Caching/Atomic/AtomicFactoryCache.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Threading.Tasks; namespace BitFaster.Caching.Atomic { @@ -176,7 +177,8 @@ public IAlternateLookup GetAlternateLookup() where TAlternateKey : notnull, allows ref struct { var inner = cache.GetAlternateLookup(); - return new AlternateLookup(inner); + var comparer = (IAlternateEqualityComparer)cache.Comparer; + return new AlternateLookup(inner, comparer); } /// @@ -185,7 +187,8 @@ public bool TryGetAlternateLookup([MaybeNullWhen(false)] out IAlt { if (cache.TryGetAlternateLookup(out var inner)) { - lookup = new AlternateLookup(inner); + var comparer = (IAlternateEqualityComparer)cache.Comparer; + lookup = new AlternateLookup(inner, comparer); return true; } @@ -197,10 +200,12 @@ public bool TryGetAlternateLookup([MaybeNullWhen(false)] out IAlt where TAlternateKey : notnull, allows ref struct { private readonly IAlternateLookup> inner; + private readonly IAlternateEqualityComparer comparer; - internal AlternateLookup(IAlternateLookup> inner) + internal AlternateLookup(IAlternateLookup> inner, IAlternateEqualityComparer comparer) { this.inner = inner; + this.comparer = comparer; } public bool TryGet(TAlternateKey key, [MaybeNullWhen(false)] out V value) @@ -239,18 +244,28 @@ public void AddOrUpdate(TAlternateKey key, V value) public V GetOrAdd(TAlternateKey key, Func valueFactory) { - var atomicFactory = inner.GetOrAdd(key, - static (k, factory) => new AtomicFactory(factory(k)), - valueFactory); - return atomicFactory.ValueIfCreated!; + var atomicFactory = inner.GetOrAdd(key, _ => new AtomicFactory()); + + if (atomicFactory.IsValueCreated) + { + return atomicFactory.ValueIfCreated!; + } + + K actualKey = comparer.Create(key); + return atomicFactory.GetValue(actualKey, valueFactory); } public V GetOrAdd(TAlternateKey key, Func valueFactory, TArg factoryArgument) { - var atomicFactory = inner.GetOrAdd(key, - static (k, args) => new AtomicFactory(args.valueFactory(k, args.factoryArgument)), - (valueFactory, factoryArgument)); - return atomicFactory.ValueIfCreated!; + var atomicFactory = inner.GetOrAdd(key, _ => new AtomicFactory()); + + if (atomicFactory.IsValueCreated) + { + return atomicFactory.ValueIfCreated!; + } + + K actualKey = comparer.Create(key); + return atomicFactory.GetValue(actualKey, valueFactory, factoryArgument); } } #endif