Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 45 additions & 30 deletions pymongo/asynchronous/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,9 +853,12 @@ async def _reset(
# publishing the PoolClearedEvent.
if close:
if not _IS_SYNC:
await asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets], # type: ignore[func-returns-value]
return_exceptions=True,
# Shield the closing of connections to avoid leaks
await asyncio.shield(
asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets], # type: ignore[func-returns-value]
return_exceptions=True,
)
)
else:
for conn in sockets:
Expand Down Expand Up @@ -890,9 +893,12 @@ async def _reset(
interrupt_connections=interrupt_connections,
)
if not _IS_SYNC:
await asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets], # type: ignore[func-returns-value]
return_exceptions=True,
# Shield the closing of connections to avoid leaks
await asyncio.shield(
asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets], # type: ignore[func-returns-value]
return_exceptions=True,
)
)
else:
for conn in sockets:
Expand Down Expand Up @@ -1065,34 +1071,43 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A
raise

conn = AsyncConnection(networking_interface, self, self.address, conn_id, self.is_sdam) # type: ignore[arg-type]
async with self.lock:
self.active_contexts.add(conn.cancel_context)
self.active_contexts.discard(tmp_context)
if tmp_context.cancelled:
conn.cancel_context.cancel()
completed_hello = False
try:
if not self.is_sdam:
await conn.hello()
completed_hello = True
self.is_writable = conn.is_writable
if handler:
handler.contribute_socket(conn, completed_handshake=False)

await conn.authenticate()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as e:
async with self.lock:
self.active_contexts.discard(conn.cancel_context)
if not completed_hello:
self._handle_connection_error(e)
await conn.close_conn(ConnectionClosedReason.ERROR)
raise
self.active_contexts.add(conn.cancel_context)
self.active_contexts.discard(tmp_context)
if tmp_context.cancelled:
conn.cancel_context.cancel()
completed_hello = False
try:
if not self.is_sdam:
await conn.hello()
completed_hello = True
self.is_writable = conn.is_writable
if handler:
handler.contribute_socket(conn, completed_handshake=False)

await conn.authenticate()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as e:
async with self.lock:
self.active_contexts.discard(conn.cancel_context)
if not completed_hello:
self._handle_connection_error(e)
await conn.close_conn(ConnectionClosedReason.ERROR)
raise

if handler:
await handler.client._topology.receive_cluster_time(conn._cluster_time)
if handler:
await handler.client._topology.receive_cluster_time(conn._cluster_time)

return conn
return conn
# Catch cancellations that interrupt outside the inner try block above
except BaseException:
if not conn.closed:
try:
await conn.close_conn(ConnectionClosedReason.ERROR)
except BaseException: # noqa: S110
pass
raise
Comment on lines +1103 to +1110
Comment on lines +1103 to +1110

@contextlib.asynccontextmanager
async def checkout(
Expand Down
93 changes: 55 additions & 38 deletions pymongo/pool_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s
sock = socket.socket(af, socktype, proto)
# Fallback when SOCK_CLOEXEC isn't available.
_set_non_inheritable_non_atomic(sock.fileno())
sock_returned = False
try:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
# CSOT: apply timeout to socket connect.
Expand All @@ -223,14 +224,18 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s
asyncio.get_running_loop().sock_connect(sock, sa), timeout=timeout
)
sock.settimeout(timeout)
# Set immediately before return. Do not insert an await between this and the return
sock_returned = True
return sock
except asyncio.TimeoutError as e:
sock.close()
err = socket.timeout("timed out")
err.__cause__ = e
except OSError as e:
sock.close()
err = e # type: ignore[assignment]
finally:
# Always close the socket if it wasn't returned to avoid leaks.
if not sock_returned:
sock.close()
Comment on lines +235 to +238

if err is not None:
raise err
Expand Down Expand Up @@ -309,46 +314,58 @@ async def _configured_protocol_interface(
sock = await _async_create_connection(address, options)
ssl_context = options._ssl_context
timeout = options.socket_timeout

if ssl_context is None:
return AsyncNetworkingInterface(
await asyncio.get_running_loop().create_connection(
lambda: PyMongoProtocol(timeout=timeout), sock=sock
)
)

host = address[0]
# Create the Protocol early to prevent asyncio resource leaks during cleanup path
protocol = PyMongoProtocol(timeout=timeout)
sock_adopted = False
try:
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
transport, protocol = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload]
lambda: PyMongoProtocol(timeout=timeout),
sock=sock,
server_hostname=host,
ssl=ssl_context,
)
except _CertificateError:
# Raise _CertificateError directly like we do after match_hostname
# below.
raise
except (OSError, *SSLErrors) as exc:
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
if ssl_context is None:
result = await asyncio.get_running_loop().create_connection(lambda: protocol, sock=sock)
sock_adopted = True
return AsyncNetworkingInterface(result)

host = address[0]
try:
ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined,unused-ignore]
# We have to pass hostname / ip address to wrap_socket
# to use SSLContext.check_hostname.
transport, _ = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload]
lambda: protocol,
sock=sock,
server_hostname=host,
ssl=ssl_context,
)
sock_adopted = True
except _CertificateError:
transport.abort()
# Raise _CertificateError directly like we do after match_hostname
# below.
raise

return AsyncNetworkingInterface((transport, protocol))
except (OSError, *SSLErrors) as exc:
# We raise AutoReconnect for transient and permanent SSL handshake
# failures alike. Permanent handshake failures, like protocol
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(
address, exc, "SSL handshake failed: ", timeout_details=details
)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined,unused-ignore]
except _CertificateError:
transport.abort()
raise

return AsyncNetworkingInterface((transport, protocol))
finally:
if not sock_adopted:
# If the protocol owns the transport, it also adopted the socket and needs to be cleaned up from the transport
if protocol.transport is not None:
protocol.transport.abort()
# Otherwise the socket was never adopted, close it directly
else:
sock.close()


def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
Expand Down
75 changes: 45 additions & 30 deletions pymongo/synchronous/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,9 +851,12 @@ def _reset(
# publishing the PoolClearedEvent.
if close:
if not _IS_SYNC:
asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets], # type: ignore[func-returns-value]
return_exceptions=True,
# Shield the closing of connections to avoid leaks
asyncio.shield(
asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets], # type: ignore[func-returns-value]
return_exceptions=True,
)
)
else:
for conn in sockets:
Expand Down Expand Up @@ -888,9 +891,12 @@ def _reset(
interrupt_connections=interrupt_connections,
)
if not _IS_SYNC:
asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets], # type: ignore[func-returns-value]
return_exceptions=True,
# Shield the closing of connections to avoid leaks
asyncio.shield(
asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets], # type: ignore[func-returns-value]
return_exceptions=True,
)
)
else:
for conn in sockets:
Expand Down Expand Up @@ -1061,34 +1067,43 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect
raise

conn = Connection(networking_interface, self, self.address, conn_id, self.is_sdam) # type: ignore[arg-type]
with self.lock:
self.active_contexts.add(conn.cancel_context)
self.active_contexts.discard(tmp_context)
if tmp_context.cancelled:
conn.cancel_context.cancel()
completed_hello = False
try:
if not self.is_sdam:
conn.hello()
completed_hello = True
self.is_writable = conn.is_writable
if handler:
handler.contribute_socket(conn, completed_handshake=False)

conn.authenticate()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as e:
with self.lock:
self.active_contexts.discard(conn.cancel_context)
if not completed_hello:
self._handle_connection_error(e)
conn.close_conn(ConnectionClosedReason.ERROR)
raise
self.active_contexts.add(conn.cancel_context)
self.active_contexts.discard(tmp_context)
if tmp_context.cancelled:
conn.cancel_context.cancel()
completed_hello = False
try:
if not self.is_sdam:
conn.hello()
completed_hello = True
self.is_writable = conn.is_writable
if handler:
handler.contribute_socket(conn, completed_handshake=False)

conn.authenticate()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as e:
with self.lock:
self.active_contexts.discard(conn.cancel_context)
if not completed_hello:
self._handle_connection_error(e)
conn.close_conn(ConnectionClosedReason.ERROR)
raise

if handler:
handler.client._topology.receive_cluster_time(conn._cluster_time)
if handler:
handler.client._topology.receive_cluster_time(conn._cluster_time)

return conn
return conn
# Catch cancellations that interrupt outside the inner try block above
except BaseException:
if not conn.closed:
try:
conn.close_conn(ConnectionClosedReason.ERROR)
except BaseException: # noqa: S110
pass
raise
Comment on lines +1099 to +1106

@contextlib.contextmanager
def checkout(
Expand Down
5 changes: 0 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,6 @@ filterwarnings = [
# pytest-asyncio known issue: https://github.com/pytest-dev/pytest-asyncio/issues/1032
"module:.*WindowsSelectorEventLoopPolicy:DeprecationWarning",
"module:.*et_event_loop_policy:DeprecationWarning",
# TODO: Remove as part of PYTHON-3923.
"module:unclosed <socket.socket:ResourceWarning",
"module:unclosed <ssl.SSLSocket:ResourceWarning",
"module:unclosed <socket object:ResourceWarning",
"module:unclosed transport:ResourceWarning",
# pytest-asyncio known issue: https://github.com/pytest-dev/pytest-asyncio/issues/724
"module:unclosed event loop:ResourceWarning",
# https://github.com/dateutil/dateutil/issues/1314
Expand Down
2 changes: 2 additions & 0 deletions test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1217,6 +1217,8 @@ def setup():

def teardown():
global_knobs.disable()
if client_context.client is not None:
client_context.client.close()
garbage = []
for g in gc.garbage:
garbage.append(f"GARBAGE: {g!r}")
Expand Down
2 changes: 2 additions & 0 deletions test/asynchronous/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,6 +1233,8 @@ async def async_setup():

async def async_teardown():
global_knobs.disable()
if async_client_context.client is not None:
await async_client_context.client.close()
garbage = []
for g in gc.garbage:
garbage.append(f"GARBAGE: {g!r}")
Expand Down
1 change: 1 addition & 0 deletions test/asynchronous/test_pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ async def create_pool(self, pair=None, *args, **kwargs):
kwargs["server_api"] = pool_options.server_api
pool = Pool(pair, PoolOptions(*args, **kwargs))
await pool.ready()
self.addAsyncCleanup(pool.close)
return pool


Expand Down
1 change: 1 addition & 0 deletions test/test_pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def create_pool(self, pair=None, *args, **kwargs):
kwargs["server_api"] = pool_options.server_api
pool = Pool(pair, PoolOptions(*args, **kwargs))
pool.ready()
self.addCleanup(pool.close)
return pool


Expand Down
Loading