diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 6aeea53f4c..f20631f208 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -802,6 +802,7 @@ def __init__( fqdn = None srv_service_name = keyword_opts.get("srvservicename") srv_max_hosts = keyword_opts.get("srvmaxhosts") + srv_allowed_hosts_suffix = keyword_opts.get("srvallowedhostssuffix") if len([h for h in self._host if "/" in h]) > 1: raise ConfigurationError("host must not contain multiple MongoDB URIs") for entity in self._host: @@ -852,6 +853,8 @@ def __init__( srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") + if srv_allowed_hosts_suffix is None: + srv_allowed_hosts_suffix = opts.get("srvallowedhostssuffix") opts = self._normalize_and_validate_options(opts, self._seeds) # Username and password passed as kwargs override user info in URI. @@ -889,7 +892,9 @@ def __init__( self._retry_policy = _RetryPolicy(attempts=self._options.max_adaptive_retries) - self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name) + self._init_based_on_options( + self._seeds, srv_max_hosts, srv_service_name, srv_allowed_hosts_suffix + ) self._opened = False self._closed = False @@ -907,6 +912,7 @@ async def _resolve_srv(self) -> None: opts = common._CaseInsensitiveDictionary() srv_service_name = keyword_opts.get("srvservicename") srv_max_hosts = keyword_opts.get("srvmaxhosts") + srv_allowed_hosts_suffix = keyword_opts.get("srvallowedhostssuffix") for entity in self._host: # A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/' # it must be a URI, @@ -927,6 +933,7 @@ async def _resolve_srv(self) -> None: connect_timeout=timeout, srv_service_name=srv_service_name, srv_max_hosts=srv_max_hosts, + srv_allowed_hosts_suffix=srv_allowed_hosts_suffix, ) seeds.update(res["nodelist"]) opts = res["options"] @@ -959,6 +966,8 @@ async def _resolve_srv(self) -> None: srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") + if srv_allowed_hosts_suffix is None: + srv_allowed_hosts_suffix = opts.get("srvAllowedHostsSuffix") opts = self._normalize_and_validate_options(opts, seeds) # Username and password passed as kwargs override user info in URI. @@ -968,10 +977,16 @@ async def _resolve_srv(self) -> None: username, password, self._resolve_srv_info["dbase"], opts, _IS_SYNC ) - self._init_based_on_options(seeds, srv_max_hosts, srv_service_name) + self._init_based_on_options( + seeds, srv_max_hosts, srv_service_name, srv_allowed_hosts_suffix + ) def _init_based_on_options( - self, seeds: Collection[tuple[str, int]], srv_max_hosts: Any, srv_service_name: Any + self, + seeds: Collection[tuple[str, int]], + srv_max_hosts: Any, + srv_service_name: Any, + srv_allowed_hosts_suffix: Any, ) -> None: self._event_listeners = self._options.pool_options._event_listeners self._topology_settings = TopologySettings( @@ -990,6 +1005,7 @@ def _init_based_on_options( load_balanced=self._options.load_balanced, srv_service_name=srv_service_name, srv_max_hosts=srv_max_hosts, + srv_allowed_hosts_suffix=srv_allowed_hosts_suffix, server_monitoring_mode=self._options.server_monitoring_mode, topology_id=self._topology_settings._topology_id if self._topology_settings else None, ) diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index 45c12b219f..a0ee5e50ac 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -418,6 +418,7 @@ async def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]: self._fqdn, self._settings.pool_options.connect_timeout, self._settings.srv_service_name, + srv_allowed_hosts_suffix=self._settings.srv_allowed_hosts_suffix, ) seedlist, ttl = await resolver.get_hosts_and_min_ttl() if len(seedlist) == 0: diff --git a/pymongo/asynchronous/settings.py b/pymongo/asynchronous/settings.py index e3c2ee7fb3..0f3cfa2366 100644 --- a/pymongo/asynchronous/settings.py +++ b/pymongo/asynchronous/settings.py @@ -52,6 +52,7 @@ def __init__( load_balanced: Optional[bool] = None, srv_service_name: str = common.SRV_SERVICE_NAME, srv_max_hosts: int = 0, + srv_allowed_hosts_suffix: Optional[str] = None, server_monitoring_mode: str = common.SERVER_MONITORING_MODE, topology_id: Optional[ObjectId] = None, ): @@ -79,6 +80,7 @@ def __init__( self._load_balanced = load_balanced self._srv_service_name = srv_service_name self._srv_max_hosts = srv_max_hosts or 0 + self._srv_allowed_hosts_suffix = srv_allowed_hosts_suffix self._server_monitoring_mode = server_monitoring_mode if topology_id is not None: self._topology_id = topology_id @@ -156,6 +158,11 @@ def srv_max_hosts(self) -> int: """The srvMaxHosts.""" return self._srv_max_hosts + @property + def srv_allowed_hosts_suffix(self) -> Optional[str]: + """The srvAllowedHostsSuffix.""" + return self._srv_allowed_hosts_suffix + @property def server_monitoring_mode(self) -> str: """The serverMonitoringMode.""" diff --git a/pymongo/asynchronous/srv_resolver.py b/pymongo/asynchronous/srv_resolver.py index 2d99ef16c8..51b8586707 100644 --- a/pymongo/asynchronous/srv_resolver.py +++ b/pymongo/asynchronous/srv_resolver.py @@ -71,11 +71,15 @@ def __init__( connect_timeout: Optional[float], srv_service_name: str, srv_max_hosts: int = 0, + srv_allowed_hosts_suffix: Optional[str] = None, ): self.__fqdn = fqdn self.__srv = srv_service_name self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT self.__srv_max_hosts = srv_max_hosts or 0 + self.__srv_allowed_hosts_suffix = ( + "." + srv_allowed_hosts_suffix.lower().lstrip(".") if srv_allowed_hosts_suffix else None + ) # ensure there's a . at the beginning of the domain # Validate the fully qualified domain name. try: ipaddress.ip_address(fqdn) @@ -135,12 +139,16 @@ async def _get_srv_response_and_hosts( raise ConfigurationError( "Invalid SRV host: return address is identical to SRV hostname" ) - try: - nlist = srv_host.split(".")[1:][-self.__slen :] - except Exception as exc: - raise ConfigurationError(f"Invalid SRV host: {node[0]}") from exc - if self.__plist != nlist: - raise ConfigurationError(f"Invalid SRV host: {node[0]}") + if self.__srv_allowed_hosts_suffix is not None: + if not srv_host.endswith(self.__srv_allowed_hosts_suffix): + raise ConfigurationError(f"Invalid SRV host: {node[0]}") + else: + try: + nlist = srv_host.split(".")[1:][-self.__slen :] + except Exception as exc: + raise ConfigurationError(f"Invalid SRV host: {node[0]}") from exc + if self.__plist != nlist: + raise ConfigurationError(f"Invalid SRV host: {node[0]}") if self.__srv_max_hosts: nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes))) return results, nodes diff --git a/pymongo/asynchronous/uri_parser.py b/pymongo/asynchronous/uri_parser.py index c1d4b1d2f4..235e7e1dd3 100644 --- a/pymongo/asynchronous/uri_parser.py +++ b/pymongo/asynchronous/uri_parser.py @@ -48,6 +48,7 @@ async def parse_uri( connect_timeout: Optional[float] = None, srv_service_name: Optional[str] = None, srv_max_hosts: Optional[int] = None, + srv_allowed_hosts_suffix: Optional[str] = None, ) -> dict[str, Any]: """Parse and validate a MongoDB URI. @@ -116,6 +117,7 @@ async def parse_uri( connect_timeout, srv_service_name, srv_max_hosts, + srv_allowed_hosts_suffix, ) ) result["options"] = _make_options_case_sensitive(result["options"]) @@ -131,6 +133,7 @@ async def _parse_srv( connect_timeout: Optional[float] = None, srv_service_name: Optional[str] = None, srv_max_hosts: Optional[int] = None, + srv_allowed_hosts_suffix: Optional[str] = None, ) -> dict[str, Any]: if uri.startswith(SCHEME): is_srv = False @@ -158,6 +161,7 @@ async def _parse_srv( hosts = unquote_plus(hosts) srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") + srv_allowed_hosts_suffix = srv_allowed_hosts_suffix or options.get("srvAllowedHostsSuffix") if is_srv: nodes = split_hosts(hosts, default_port=None) fqdn, _port = nodes[0] @@ -165,7 +169,9 @@ async def _parse_srv( # Use the connection timeout. connectTimeoutMS passed as a keyword # argument overrides the same option passed in the connection string. connect_timeout = connect_timeout or options.get("connectTimeoutMS") - dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts) + dns_resolver = _SrvResolver( + fqdn, connect_timeout, srv_service_name, srv_max_hosts, srv_allowed_hosts_suffix + ) nodes = await dns_resolver.get_hosts() dns_options = await dns_resolver.get_options() if dns_options: diff --git a/pymongo/common.py b/pymongo/common.py index 08da34c7bf..70eda82322 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -718,6 +718,7 @@ def validate_server_monitoring_mode(option: str, value: str) -> str: "zlibcompressionlevel": validate_zlib_compression_level, "srvservicename": validate_string, "srvmaxhosts": validate_non_negative_integer, + "srvallowedhostssuffix": validate_string, "timeoutms": validate_timeoutms, "servermonitoringmode": validate_server_monitoring_mode, "maxadaptiveretries": validate_non_negative_integer, diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 6b7c5d9c98..339b8df4be 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -803,6 +803,7 @@ def __init__( fqdn = None srv_service_name = keyword_opts.get("srvservicename") srv_max_hosts = keyword_opts.get("srvmaxhosts") + srv_allowed_hosts_suffix = keyword_opts.get("srvallowedhostssuffix") if len([h for h in self._host if "/" in h]) > 1: raise ConfigurationError("host must not contain multiple MongoDB URIs") for entity in self._host: @@ -853,6 +854,8 @@ def __init__( srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") + if srv_allowed_hosts_suffix is None: + srv_allowed_hosts_suffix = opts.get("srvallowedhostssuffix") opts = self._normalize_and_validate_options(opts, self._seeds) # Username and password passed as kwargs override user info in URI. @@ -890,7 +893,9 @@ def __init__( self._retry_policy = _RetryPolicy(attempts=self._options.max_adaptive_retries) - self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name) + self._init_based_on_options( + self._seeds, srv_max_hosts, srv_service_name, srv_allowed_hosts_suffix + ) self._opened = False self._closed = False @@ -908,6 +913,7 @@ def _resolve_srv(self) -> None: opts = common._CaseInsensitiveDictionary() srv_service_name = keyword_opts.get("srvservicename") srv_max_hosts = keyword_opts.get("srvmaxhosts") + srv_allowed_hosts_suffix = keyword_opts.get("srvallowedhostssuffix") for entity in self._host: # A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/' # it must be a URI, @@ -928,6 +934,7 @@ def _resolve_srv(self) -> None: connect_timeout=timeout, srv_service_name=srv_service_name, srv_max_hosts=srv_max_hosts, + srv_allowed_hosts_suffix=srv_allowed_hosts_suffix, ) seeds.update(res["nodelist"]) opts = res["options"] @@ -960,6 +967,8 @@ def _resolve_srv(self) -> None: srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") + if srv_allowed_hosts_suffix is None: + srv_allowed_hosts_suffix = opts.get("srvAllowedHostsSuffix") opts = self._normalize_and_validate_options(opts, seeds) # Username and password passed as kwargs override user info in URI. @@ -969,10 +978,16 @@ def _resolve_srv(self) -> None: username, password, self._resolve_srv_info["dbase"], opts, _IS_SYNC ) - self._init_based_on_options(seeds, srv_max_hosts, srv_service_name) + self._init_based_on_options( + seeds, srv_max_hosts, srv_service_name, srv_allowed_hosts_suffix + ) def _init_based_on_options( - self, seeds: Collection[tuple[str, int]], srv_max_hosts: Any, srv_service_name: Any + self, + seeds: Collection[tuple[str, int]], + srv_max_hosts: Any, + srv_service_name: Any, + srv_allowed_hosts_suffix: Any, ) -> None: self._event_listeners = self._options.pool_options._event_listeners self._topology_settings = TopologySettings( @@ -991,6 +1006,7 @@ def _init_based_on_options( load_balanced=self._options.load_balanced, srv_service_name=srv_service_name, srv_max_hosts=srv_max_hosts, + srv_allowed_hosts_suffix=srv_allowed_hosts_suffix, server_monitoring_mode=self._options.server_monitoring_mode, topology_id=self._topology_settings._topology_id if self._topology_settings else None, ) diff --git a/pymongo/synchronous/monitor.py b/pymongo/synchronous/monitor.py index f395588814..9ecc42505c 100644 --- a/pymongo/synchronous/monitor.py +++ b/pymongo/synchronous/monitor.py @@ -416,6 +416,7 @@ def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]: self._fqdn, self._settings.pool_options.connect_timeout, self._settings.srv_service_name, + srv_allowed_hosts_suffix=self._settings.srv_allowed_hosts_suffix, ) seedlist, ttl = resolver.get_hosts_and_min_ttl() if len(seedlist) == 0: diff --git a/pymongo/synchronous/settings.py b/pymongo/synchronous/settings.py index bc664e8421..3fe2464354 100644 --- a/pymongo/synchronous/settings.py +++ b/pymongo/synchronous/settings.py @@ -52,6 +52,7 @@ def __init__( load_balanced: Optional[bool] = None, srv_service_name: str = common.SRV_SERVICE_NAME, srv_max_hosts: int = 0, + srv_allowed_hosts_suffix: Optional[str] = None, server_monitoring_mode: str = common.SERVER_MONITORING_MODE, topology_id: Optional[ObjectId] = None, ): @@ -79,6 +80,7 @@ def __init__( self._load_balanced = load_balanced self._srv_service_name = srv_service_name self._srv_max_hosts = srv_max_hosts or 0 + self._srv_allowed_hosts_suffix = srv_allowed_hosts_suffix self._server_monitoring_mode = server_monitoring_mode if topology_id is not None: self._topology_id = topology_id @@ -156,6 +158,11 @@ def srv_max_hosts(self) -> int: """The srvMaxHosts.""" return self._srv_max_hosts + @property + def srv_allowed_hosts_suffix(self) -> Optional[str]: + """The srvAllowedHostsSuffix.""" + return self._srv_allowed_hosts_suffix + @property def server_monitoring_mode(self) -> str: """The serverMonitoringMode.""" diff --git a/pymongo/synchronous/srv_resolver.py b/pymongo/synchronous/srv_resolver.py index de1a0fc9e6..be1ca16cf2 100644 --- a/pymongo/synchronous/srv_resolver.py +++ b/pymongo/synchronous/srv_resolver.py @@ -71,11 +71,15 @@ def __init__( connect_timeout: Optional[float], srv_service_name: str, srv_max_hosts: int = 0, + srv_allowed_hosts_suffix: Optional[str] = None, ): self.__fqdn = fqdn self.__srv = srv_service_name self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT self.__srv_max_hosts = srv_max_hosts or 0 + self.__srv_allowed_hosts_suffix = ( + "." + srv_allowed_hosts_suffix.lower().lstrip(".") if srv_allowed_hosts_suffix else None + ) # ensure there's a . at the beginning of the domain # Validate the fully qualified domain name. try: ipaddress.ip_address(fqdn) @@ -135,12 +139,16 @@ def _get_srv_response_and_hosts( raise ConfigurationError( "Invalid SRV host: return address is identical to SRV hostname" ) - try: - nlist = srv_host.split(".")[1:][-self.__slen :] - except Exception as exc: - raise ConfigurationError(f"Invalid SRV host: {node[0]}") from exc - if self.__plist != nlist: - raise ConfigurationError(f"Invalid SRV host: {node[0]}") + if self.__srv_allowed_hosts_suffix is not None: + if not srv_host.endswith(self.__srv_allowed_hosts_suffix): + raise ConfigurationError(f"Invalid SRV host: {node[0]}") + else: + try: + nlist = srv_host.split(".")[1:][-self.__slen :] + except Exception as exc: + raise ConfigurationError(f"Invalid SRV host: {node[0]}") from exc + if self.__plist != nlist: + raise ConfigurationError(f"Invalid SRV host: {node[0]}") if self.__srv_max_hosts: nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes))) return results, nodes diff --git a/pymongo/synchronous/uri_parser.py b/pymongo/synchronous/uri_parser.py index 8d86c91e90..9a6ab8c326 100644 --- a/pymongo/synchronous/uri_parser.py +++ b/pymongo/synchronous/uri_parser.py @@ -48,6 +48,7 @@ def parse_uri( connect_timeout: Optional[float] = None, srv_service_name: Optional[str] = None, srv_max_hosts: Optional[int] = None, + srv_allowed_hosts_suffix: Optional[str] = None, ) -> dict[str, Any]: """Parse and validate a MongoDB URI. @@ -116,6 +117,7 @@ def parse_uri( connect_timeout, srv_service_name, srv_max_hosts, + srv_allowed_hosts_suffix, ) ) result["options"] = _make_options_case_sensitive(result["options"]) @@ -131,6 +133,7 @@ def _parse_srv( connect_timeout: Optional[float] = None, srv_service_name: Optional[str] = None, srv_max_hosts: Optional[int] = None, + srv_allowed_hosts_suffix: Optional[str] = None, ) -> dict[str, Any]: if uri.startswith(SCHEME): is_srv = False @@ -158,6 +161,7 @@ def _parse_srv( hosts = unquote_plus(hosts) srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") + srv_allowed_hosts_suffix = srv_allowed_hosts_suffix or options.get("srvAllowedHostsSuffix") if is_srv: nodes = split_hosts(hosts, default_port=None) fqdn, _port = nodes[0] @@ -165,7 +169,9 @@ def _parse_srv( # Use the connection timeout. connectTimeoutMS passed as a keyword # argument overrides the same option passed in the connection string. connect_timeout = connect_timeout or options.get("connectTimeoutMS") - dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts) + dns_resolver = _SrvResolver( + fqdn, connect_timeout, srv_service_name, srv_max_hosts, srv_allowed_hosts_suffix + ) nodes = dns_resolver.get_hosts() dns_options = dns_resolver.get_options() if dns_options: diff --git a/pymongo/uri_parser_shared.py b/pymongo/uri_parser_shared.py index 62d862bcfa..efb2cef6c2 100644 --- a/pymongo/uri_parser_shared.py +++ b/pymongo/uri_parser_shared.py @@ -87,6 +87,7 @@ "socketTimeoutMS", "srvMaxHosts", "srvServiceName", + "srvAllowedHostsSuffix", "ssl", "tls", "tlsAllowInvalidCertificates", diff --git a/test/srv_seedlist/replica-set/srvAllowedHostsSuffix-mismatch.json b/test/srv_seedlist/replica-set/srvAllowedHostsSuffix-mismatch.json new file mode 100644 index 0000000000..56e26524c4 --- /dev/null +++ b/test/srv_seedlist/replica-set/srvAllowedHostsSuffix-mismatch.json @@ -0,0 +1,6 @@ +{ + "uri": "mongodb+srv://test12.test.build.10gen.cc/?srvAllowedHostsSuffix=test.build.10gen.cc", + "seeds": [], + "hosts": [], + "error": true +} diff --git a/test/srv_seedlist/replica-set/srvAllowedHostsSuffix-with_dot.json b/test/srv_seedlist/replica-set/srvAllowedHostsSuffix-with_dot.json new file mode 100644 index 0000000000..8ff14a8958 --- /dev/null +++ b/test/srv_seedlist/replica-set/srvAllowedHostsSuffix-with_dot.json @@ -0,0 +1,11 @@ +{ + "uri": "mongodb+srv://test12.test.build.10gen.cc/?srvAllowedHostsSuffix=.build.10gen.cc", + "seeds": [ + "localhost.build.10gen.cc:27017" + ], + "options": { + "srvAllowedHostsSuffix": ".build.10gen.cc", + "ssl": true + }, + "ping": false +} diff --git a/test/srv_seedlist/replica-set/srvAllowedHostsSuffix-without_dot_pass.json b/test/srv_seedlist/replica-set/srvAllowedHostsSuffix-without_dot_pass.json new file mode 100644 index 0000000000..3f4c1f1f71 --- /dev/null +++ b/test/srv_seedlist/replica-set/srvAllowedHostsSuffix-without_dot_pass.json @@ -0,0 +1,11 @@ +{ + "uri": "mongodb+srv://test12.test.build.10gen.cc/?srvAllowedHostsSuffix=build.10gen.cc", + "seeds": [ + "localhost.build.10gen.cc:27017" + ], + "options": { + "srvAllowedHostsSuffix": "build.10gen.cc", + "ssl": true + }, + "ping": false +}