Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/DnsClient/DnsMessageHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ internal enum DnsMessageHandleType
TCP
}

internal abstract class DnsMessageHandler
internal abstract class DnsMessageHandler : IDisposable
{
public abstract DnsMessageHandleType Type { get; }

Expand Down Expand Up @@ -174,5 +174,16 @@ public virtual DnsResponseMessage GetResponseMessage(ArraySegment<byte> response

return response;
}

protected virtual void Dispose(bool disposing)
{
// Nothing to do in base class.
}

public void Dispose()
{
Dispose(true);
GC.SuppressFinalize(this);
}
}
}
160 changes: 137 additions & 23 deletions src/DnsClient/DnsTcpMessageHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,22 @@ namespace DnsClient
{
internal class DnsTcpMessageHandler : DnsMessageHandler
{
private bool _disposedValue = false;
private readonly ConcurrentDictionary<IPEndPoint, ClientPool> _pools = new ConcurrentDictionary<IPEndPoint, ClientPool>();

public override DnsMessageHandleType Type { get; } = DnsMessageHandleType.TCP;

public override DnsResponseMessage Query(IPEndPoint server, DnsRequestMessage request, TimeSpan timeout)
{
CancellationToken cancellationToken = default;
if (_disposedValue)
{
throw new ObjectDisposedException(nameof(DnsTcpMessageHandler));
}

using var cts = timeout.TotalMilliseconds != Timeout.Infinite && timeout.TotalMilliseconds < int.MaxValue ?
new CancellationTokenSource(timeout) : null;

cancellationToken = cts?.Token ?? default;
var cancellationToken = cts?.Token ?? default;

ClientPool pool;
while (!_pools.TryGetValue(server, out pool))
Expand All @@ -36,7 +40,7 @@ public override DnsResponseMessage Query(IPEndPoint server, DnsRequestMessage re

cancellationToken.ThrowIfCancellationRequested();

var entry = pool.GetNextClient();
var entry = pool.GetNextClient(cancellationToken);

using var cancelCallback = cancellationToken.Register(() =>
{
Expand Down Expand Up @@ -79,6 +83,11 @@ public override async Task<DnsResponseMessage> QueryAsync(
DnsRequestMessage request,
CancellationToken cancellationToken)
{
if (_disposedValue)
{
throw new ObjectDisposedException(nameof(DnsTcpMessageHandler));
}

cancellationToken.ThrowIfCancellationRequested();

ClientPool pool;
Expand All @@ -87,7 +96,7 @@ public override async Task<DnsResponseMessage> QueryAsync(
_pools.TryAdd(server, new ClientPool(true, server));
}

var entry = pool.GetNextClient();
var entry = await pool.GetNextClientAsync(cancellationToken).ConfigureAwait(false);

using var cancelCallback = cancellationToken.Register(() =>
{
Expand Down Expand Up @@ -301,9 +310,24 @@ private async Task<DnsResponseMessage> QueryAsyncInternal(TcpClient client, DnsR
return DnsResponseMessage.Combine(responses);
}

private class ClientPool : IDisposable
protected override void Dispose(bool disposing)
{
if (disposing && !_disposedValue)
{
_disposedValue = true;

foreach (var entry in _pools)
{
entry.Value.Dispose();
}
}

base.Dispose(disposing);
}

private sealed class ClientPool : IDisposable
{
private bool _disposedValue;
private bool _disposedValue = false;
private readonly bool _enablePool;
private ConcurrentQueue<ClientEntry> _clients = new ConcurrentQueue<ClientEntry>();
private readonly IPEndPoint _endpoint;
Expand All @@ -314,7 +338,66 @@ public ClientPool(bool enablePool, IPEndPoint endpoint)
_endpoint = endpoint;
}

public ClientEntry GetNextClient()
public ClientEntry GetNextClient(CancellationToken cancellationToken)
{
if (_disposedValue)
{
throw new ObjectDisposedException(nameof(ClientPool));
}

ClientEntry entry = null;
if (_enablePool)
{
while (entry == null && !TryDequeue(out entry))
{
entry = ConnectNew(cancellationToken);
}
}
else
{
entry = ConnectNew(cancellationToken);
}

return entry;
}

private ClientEntry ConnectNew(CancellationToken cancellationToken)
{
var newClient = new TcpClient(_endpoint.AddressFamily)
{
LingerState = new LingerOption(true, 0)
};

bool gotCanceled = false;
cancellationToken.Register(() =>
{
gotCanceled = true;
newClient.Dispose();
});

try
{
newClient.Connect(_endpoint.Address, _endpoint.Port);
}
catch (Exception) when (gotCanceled)
{
throw new OperationCanceledException("Connection timed out.", cancellationToken);
}
catch (Exception)
{
try
{
newClient.Dispose();
}
catch { }

throw;
}

return new ClientEntry(newClient, _endpoint);
}

public async Task<ClientEntry> GetNextClientAsync(CancellationToken cancellationToken)
{
if (_disposedValue)
{
Expand All @@ -326,17 +409,57 @@ public ClientEntry GetNextClient()
{
while (entry == null && !TryDequeue(out entry))
{
entry = new ClientEntry(new TcpClient(_endpoint.AddressFamily) { LingerState = new LingerOption(true, 0) }, _endpoint);
entry = await ConnectNewAsync(cancellationToken).ConfigureAwait(false);
}
}
else
{
entry = new ClientEntry(new TcpClient(_endpoint.AddressFamily), _endpoint);
entry = await ConnectNewAsync(cancellationToken).ConfigureAwait(false);
}

return entry;
}

private async Task<ClientEntry> ConnectNewAsync(CancellationToken cancellationToken)
{
var newClient = new TcpClient(_endpoint.AddressFamily)
{
LingerState = new LingerOption(true, 0)
};

#if NET6_0_OR_GREATER
await newClient.ConnectAsync(_endpoint.Address, _endpoint.Port, cancellationToken).ConfigureAwait(false);
#else

bool gotCanceled = false;
cancellationToken.Register(() =>
{
gotCanceled = true;
newClient.Dispose();
});

try
{
await newClient.ConnectAsync(_endpoint.Address, _endpoint.Port).ConfigureAwait(false);
}
catch (Exception) when (gotCanceled)
{
throw new OperationCanceledException("Connection timed out.", cancellationToken);
}
catch (Exception)
{
try
{
newClient.Dispose();
}
catch { }

throw;
}
#endif
return new ClientEntry(newClient, _endpoint);
}

public void Enqueue(ClientEntry entry)
{
if (_disposedValue)
Expand Down Expand Up @@ -390,29 +513,20 @@ public bool TryDequeue(out ClientEntry entry)
return result;
}

protected virtual void Dispose(bool disposing)
public void Dispose()
{
if (!_disposedValue)
{
if (disposing)
_disposedValue = true;
foreach (var entry in _clients)
{
foreach (var entry in _clients)
{
entry.DisposeClient();
}

_clients = new ConcurrentQueue<ClientEntry>();
entry.DisposeClient();
}

_disposedValue = true;
_clients = new ConcurrentQueue<ClientEntry>();
}
}

public void Dispose()
{
Dispose(true);
}

public class ClientEntry
{
public ClientEntry(TcpClient client, IPEndPoint endpoint)
Expand Down
20 changes: 17 additions & 3 deletions src/DnsClient/LookupClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ namespace DnsClient
/// ]]>
/// </code>
/// </example>
public class LookupClient : ILookupClient, IDnsQuery
public sealed class LookupClient : ILookupClient, IDnsQuery, IDisposable
{
private const int LogEventStartQuery = 1;
private const int LogEventQuery = 2;
Expand All @@ -66,6 +66,7 @@ public class LookupClient : ILookupClient, IDnsQuery
private readonly SkipWorker _skipper;

private IReadOnlyCollection<NameServer> _resolvedNameServers;
private bool _disposedValue;

/// <inheritdoc/>
public IReadOnlyCollection<NameServer> NameServers => Settings.NameServers;
Expand Down Expand Up @@ -212,7 +213,7 @@ internal LookupClient(LookupClientOptions options, DnsMessageHandler udpHandler

// Setting up name servers.
// Using manually configured ones and/or auto resolved ones.
IReadOnlyCollection<NameServer> servers = _originalOptions.NameServers?.ToArray() ?? new NameServer[0];
IReadOnlyCollection<NameServer> servers = _originalOptions.NameServers?.ToArray() ?? Array.Empty<NameServer>();

if (options.AutoResolveNameServers)
{
Expand Down Expand Up @@ -269,7 +270,9 @@ private void CheckResolvedNameservers()
}

_resolvedNameServers = newServers;
var servers = _originalOptions.NameServers.Concat(_resolvedNameServers).ToArray();
IReadOnlyCollection<NameServer> servers = _originalOptions.NameServers.Concat(_resolvedNameServers).ToArray();
servers = NameServer.ValidateNameServers(servers, _logger);

Settings = new LookupClientSettings(_originalOptions, servers);
}
catch (Exception ex)
Expand Down Expand Up @@ -1634,6 +1637,17 @@ public void MaybeDoWork()
}
}
}

/// <inheritdoc/>
public void Dispose()
{
if (!_disposedValue)
{
_disposedValue = true;
_tcpFallbackHandler?.Dispose();
_messageHandler?.Dispose();
}
}
}

internal class LookupClientAudit
Expand Down
4 changes: 2 additions & 2 deletions test-other/OldReference/TestLookupClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public void TestOtherTypes()

public void TestQuery_1_1()
{
var client = new LookupClient();
var client = new LookupClient(NameServer.GooglePublicDns.Address);
client.Query("domain", QueryType.A);
client.Query("domain", QueryType.A, QueryClass.IN);
client.QueryReverse(IPAddress.Loopback);
Expand All @@ -76,7 +76,7 @@ public void TestQuery_1_1()

public async Task TestQueryAsync_1_1()
{
var client = new LookupClient();
var client = new LookupClient(NameServer.GooglePublicDns.Address);
await client.QueryAsync("domain", QueryType.A).ConfigureAwait(false);
await client.QueryAsync("domain", QueryType.A, QueryClass.IN).ConfigureAwait(false);
await client.QueryAsync("domain", QueryType.A, cancellationToken: default).ConfigureAwait(false);
Expand Down
Loading