diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index a5d5b28990..89157e9543 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -853,9 +853,12 @@ async def _reset( # publishing the PoolClearedEvent. if close: if not _IS_SYNC: - await asyncio.gather( - *[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets], # type: ignore[func-returns-value] - return_exceptions=True, + # Shield the closing of connections to avoid leaks + await asyncio.shield( + asyncio.gather( + *[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets], # type: ignore[func-returns-value] + return_exceptions=True, + ) ) else: for conn in sockets: @@ -890,9 +893,12 @@ async def _reset( interrupt_connections=interrupt_connections, ) if not _IS_SYNC: - await asyncio.gather( - *[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets], # type: ignore[func-returns-value] - return_exceptions=True, + # Shield the closing of connections to avoid leaks + await asyncio.shield( + asyncio.gather( + *[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets], # type: ignore[func-returns-value] + return_exceptions=True, + ) ) else: for conn in sockets: @@ -1065,34 +1071,43 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A raise conn = AsyncConnection(networking_interface, self, self.address, conn_id, self.is_sdam) # type: ignore[arg-type] - async with self.lock: - self.active_contexts.add(conn.cancel_context) - self.active_contexts.discard(tmp_context) - if tmp_context.cancelled: - conn.cancel_context.cancel() - completed_hello = False try: - if not self.is_sdam: - await conn.hello() - completed_hello = True - self.is_writable = conn.is_writable - if handler: - handler.contribute_socket(conn, completed_handshake=False) - - await conn.authenticate() - # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. - except BaseException as e: async with self.lock: - self.active_contexts.discard(conn.cancel_context) - if not completed_hello: - self._handle_connection_error(e) - await conn.close_conn(ConnectionClosedReason.ERROR) - raise + self.active_contexts.add(conn.cancel_context) + self.active_contexts.discard(tmp_context) + if tmp_context.cancelled: + conn.cancel_context.cancel() + completed_hello = False + try: + if not self.is_sdam: + await conn.hello() + completed_hello = True + self.is_writable = conn.is_writable + if handler: + handler.contribute_socket(conn, completed_handshake=False) + + await conn.authenticate() + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. + except BaseException as e: + async with self.lock: + self.active_contexts.discard(conn.cancel_context) + if not completed_hello: + self._handle_connection_error(e) + await conn.close_conn(ConnectionClosedReason.ERROR) + raise - if handler: - await handler.client._topology.receive_cluster_time(conn._cluster_time) + if handler: + await handler.client._topology.receive_cluster_time(conn._cluster_time) - return conn + return conn + # Catch cancellations that interrupt outside the inner try block above + except BaseException: + if not conn.closed: + try: + await conn.close_conn(ConnectionClosedReason.ERROR) + except BaseException: # noqa: S110 + pass + raise @contextlib.asynccontextmanager async def checkout( diff --git a/pymongo/pool_shared.py b/pymongo/pool_shared.py index a6f434885b..2e4522ea3f 100644 --- a/pymongo/pool_shared.py +++ b/pymongo/pool_shared.py @@ -207,6 +207,7 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s sock = socket.socket(af, socktype, proto) # Fallback when SOCK_CLOEXEC isn't available. _set_non_inheritable_non_atomic(sock.fileno()) + sock_returned = False try: sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) # CSOT: apply timeout to socket connect. @@ -223,14 +224,18 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s asyncio.get_running_loop().sock_connect(sock, sa), timeout=timeout ) sock.settimeout(timeout) + # Set immediately before return. Do not insert an await between this and the return + sock_returned = True return sock except asyncio.TimeoutError as e: - sock.close() err = socket.timeout("timed out") err.__cause__ = e except OSError as e: - sock.close() err = e # type: ignore[assignment] + finally: + # Always close the socket if it wasn't returned to avoid leaks. + if not sock_returned: + sock.close() if err is not None: raise err @@ -309,46 +314,58 @@ async def _configured_protocol_interface( sock = await _async_create_connection(address, options) ssl_context = options._ssl_context timeout = options.socket_timeout - - if ssl_context is None: - return AsyncNetworkingInterface( - await asyncio.get_running_loop().create_connection( - lambda: PyMongoProtocol(timeout=timeout), sock=sock - ) - ) - - host = address[0] + # Create the Protocol early to prevent asyncio resource leaks during cleanup path + protocol = PyMongoProtocol(timeout=timeout) + sock_adopted = False try: - # We have to pass hostname / ip address to wrap_socket - # to use SSLContext.check_hostname. - transport, protocol = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload] - lambda: PyMongoProtocol(timeout=timeout), - sock=sock, - server_hostname=host, - ssl=ssl_context, - ) - except _CertificateError: - # Raise _CertificateError directly like we do after match_hostname - # below. - raise - except (OSError, *SSLErrors) as exc: - # We raise AutoReconnect for transient and permanent SSL handshake - # failures alike. Permanent handshake failures, like protocol - # mismatch, will be turned into ServerSelectionTimeoutErrors later. - details = _get_timeout_details(options) - _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) - if ( - ssl_context.verify_mode - and not ssl_context.check_hostname - and not options.tls_allow_invalid_hostnames - ): + if ssl_context is None: + result = await asyncio.get_running_loop().create_connection(lambda: protocol, sock=sock) + sock_adopted = True + return AsyncNetworkingInterface(result) + + host = address[0] try: - ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined,unused-ignore] + # We have to pass hostname / ip address to wrap_socket + # to use SSLContext.check_hostname. + transport, _ = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload] + lambda: protocol, + sock=sock, + server_hostname=host, + ssl=ssl_context, + ) + sock_adopted = True except _CertificateError: - transport.abort() + # Raise _CertificateError directly like we do after match_hostname + # below. raise - - return AsyncNetworkingInterface((transport, protocol)) + except (OSError, *SSLErrors) as exc: + # We raise AutoReconnect for transient and permanent SSL handshake + # failures alike. Permanent handshake failures, like protocol + # mismatch, will be turned into ServerSelectionTimeoutErrors later. + details = _get_timeout_details(options) + _raise_connection_failure( + address, exc, "SSL handshake failed: ", timeout_details=details + ) + if ( + ssl_context.verify_mode + and not ssl_context.check_hostname + and not options.tls_allow_invalid_hostnames + ): + try: + ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined,unused-ignore] + except _CertificateError: + transport.abort() + raise + + return AsyncNetworkingInterface((transport, protocol)) + finally: + if not sock_adopted: + # If the protocol owns the transport, it also adopted the socket and needs to be cleaned up from the transport + if protocol.transport is not None: + protocol.transport.abort() + # Otherwise the socket was never adopted, close it directly + else: + sock.close() def _create_connection(address: _Address, options: PoolOptions) -> socket.socket: diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 25f2d08fe7..a3790cd9c5 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -851,9 +851,12 @@ def _reset( # publishing the PoolClearedEvent. if close: if not _IS_SYNC: - asyncio.gather( - *[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets], # type: ignore[func-returns-value] - return_exceptions=True, + # Shield the closing of connections to avoid leaks + asyncio.shield( + asyncio.gather( + *[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets], # type: ignore[func-returns-value] + return_exceptions=True, + ) ) else: for conn in sockets: @@ -888,9 +891,12 @@ def _reset( interrupt_connections=interrupt_connections, ) if not _IS_SYNC: - asyncio.gather( - *[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets], # type: ignore[func-returns-value] - return_exceptions=True, + # Shield the closing of connections to avoid leaks + asyncio.shield( + asyncio.gather( + *[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets], # type: ignore[func-returns-value] + return_exceptions=True, + ) ) else: for conn in sockets: @@ -1061,34 +1067,43 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect raise conn = Connection(networking_interface, self, self.address, conn_id, self.is_sdam) # type: ignore[arg-type] - with self.lock: - self.active_contexts.add(conn.cancel_context) - self.active_contexts.discard(tmp_context) - if tmp_context.cancelled: - conn.cancel_context.cancel() - completed_hello = False try: - if not self.is_sdam: - conn.hello() - completed_hello = True - self.is_writable = conn.is_writable - if handler: - handler.contribute_socket(conn, completed_handshake=False) - - conn.authenticate() - # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. - except BaseException as e: with self.lock: - self.active_contexts.discard(conn.cancel_context) - if not completed_hello: - self._handle_connection_error(e) - conn.close_conn(ConnectionClosedReason.ERROR) - raise + self.active_contexts.add(conn.cancel_context) + self.active_contexts.discard(tmp_context) + if tmp_context.cancelled: + conn.cancel_context.cancel() + completed_hello = False + try: + if not self.is_sdam: + conn.hello() + completed_hello = True + self.is_writable = conn.is_writable + if handler: + handler.contribute_socket(conn, completed_handshake=False) + + conn.authenticate() + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. + except BaseException as e: + with self.lock: + self.active_contexts.discard(conn.cancel_context) + if not completed_hello: + self._handle_connection_error(e) + conn.close_conn(ConnectionClosedReason.ERROR) + raise - if handler: - handler.client._topology.receive_cluster_time(conn._cluster_time) + if handler: + handler.client._topology.receive_cluster_time(conn._cluster_time) - return conn + return conn + # Catch cancellations that interrupt outside the inner try block above + except BaseException: + if not conn.closed: + try: + conn.close_conn(ConnectionClosedReason.ERROR) + except BaseException: # noqa: S110 + pass + raise @contextlib.contextmanager def checkout( diff --git a/pyproject.toml b/pyproject.toml index 9b3287834a..dd8a4955d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,11 +108,6 @@ filterwarnings = [ # pytest-asyncio known issue: https://github.com/pytest-dev/pytest-asyncio/issues/1032 "module:.*WindowsSelectorEventLoopPolicy:DeprecationWarning", "module:.*et_event_loop_policy:DeprecationWarning", - # TODO: Remove as part of PYTHON-3923. - "module:unclosed