Skip to content

Commit 2de45ae

Browse files
bgavrilMSfengga
andauthored
Instance discovery remains cloud local on known clouds (#875)
* Instance discovery remains cloud local on known clouds * More changes and address PR comments * Try to fix tests --------- Co-authored-by: Feng Gao <fengga@microsoft.com>
1 parent 09c8821 commit 2de45ae

File tree

7 files changed

+348
-46
lines changed

7 files changed

+348
-46
lines changed

msal/application.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@
1111

1212
from .oauth2cli import Client, JwtAssertionCreator
1313
from .oauth2cli.oidc import decode_part
14-
from .authority import Authority, WORLD_WIDE
14+
from .authority import (
15+
Authority,
16+
WORLD_WIDE,
17+
_get_instance_discovery_endpoint,
18+
_get_instance_discovery_host,
19+
)
1520
from .mex import send_request as mex_send_request
1621
from .wstrust_request import send_request as wst_send_request
1722
from .wstrust_response import *
@@ -671,7 +676,7 @@ def __init__(
671676
self._region_detected = None
672677
self.client, self._regional_client = self._build_client(
673678
client_credential, self.authority)
674-
self.authority_groups = None
679+
self.authority_groups = {}
675680
self._telemetry_buffer = {}
676681
self._telemetry_lock = Lock()
677682
_msal_extension_check()
@@ -1304,9 +1309,16 @@ def _find_msal_accounts(self, environment):
13041309
}
13051310
return list(grouped_accounts.values())
13061311

1307-
def _get_instance_metadata(self): # This exists so it can be mocked in unit test
1312+
def _get_instance_metadata(self, instance): # This exists so it can be mocked in unit test
1313+
instance_discovery_host = _get_instance_discovery_host(instance)
13081314
resp = self.http_client.get(
1309-
"https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize", # TBD: We may extend this to use self._instance_discovery endpoint
1315+
_get_instance_discovery_endpoint(instance),
1316+
params={
1317+
'api-version': '1.1',
1318+
'authorization_endpoint': (
1319+
"https://{}/common/oauth2/authorize".format(instance_discovery_host)
1320+
),
1321+
},
13101322
headers={'Accept': 'application/json'})
13111323
resp.raise_for_status()
13121324
return json.loads(resp.text)['metadata']
@@ -1318,10 +1330,10 @@ def _get_authority_aliases(self, instance):
13181330
# Then it is an ADFS/B2C/known_authority_hosts situation
13191331
# which may not reach the central endpoint, so we skip it.
13201332
return []
1321-
if not self.authority_groups:
1322-
self.authority_groups = [
1323-
set(group['aliases']) for group in self._get_instance_metadata()]
1324-
for group in self.authority_groups:
1333+
if instance not in self.authority_groups:
1334+
self.authority_groups[instance] = [
1335+
set(group['aliases']) for group in self._get_instance_metadata(instance)]
1336+
for group in self.authority_groups[instance]:
13251337
if instance in group:
13261338
return [alias for alias in group if alias != instance]
13271339
return []

msal/authority.py

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,38 +9,28 @@
99
# Endpoints were copied from here
1010
# https://docs.microsoft.com/en-us/azure/active-directory/develop/authentication-national-cloud#azure-ad-authentication-endpoints
1111
AZURE_US_GOVERNMENT = "login.microsoftonline.us"
12-
AZURE_CHINA = "login.chinacloudapi.cn"
12+
DEPRECATED_AZURE_CHINA = "login.chinacloudapi.cn"
1313
AZURE_PUBLIC = "login.microsoftonline.com"
14+
AZURE_GOV_FR = "login.sovcloud-identity.fr"
15+
AZURE_GOV_DE = "login.sovcloud-identity.de"
16+
AZURE_GOV_SG = "login.sovcloud-identity.sg"
1417

1518
WORLD_WIDE = 'login.microsoftonline.com' # There was an alias login.windows.net
16-
WELL_KNOWN_AUTHORITY_HOSTS = set([
19+
WELL_KNOWN_AUTHORITY_HOSTS = frozenset([
1720
WORLD_WIDE,
18-
AZURE_CHINA,
19-
'login-us.microsoftonline.com',
20-
AZURE_US_GOVERNMENT,
21-
])
22-
23-
# Trusted issuer hosts for OIDC issuer validation
24-
# Includes all well-known Microsoft identity provider hosts and national clouds
25-
TRUSTED_ISSUER_HOSTS = frozenset([
26-
# Global/Public cloud
27-
"login.microsoftonline.com",
2821
"login.microsoft.com",
2922
"login.windows.net",
3023
"sts.windows.net",
31-
# China cloud
32-
"login.chinacloudapi.cn",
24+
DEPRECATED_AZURE_CHINA,
3325
"login.partner.microsoftonline.cn",
34-
# Germany cloud (legacy)
35-
"login.microsoftonline.de",
36-
# US Government clouds
37-
"login.microsoftonline.us",
26+
"login.microsoftonline.de", # deprecated
27+
'login-us.microsoftonline.com',
28+
AZURE_US_GOVERNMENT,
3829
"login.usgovcloudapi.net",
39-
"login-us.microsoftonline.com",
40-
"https://login.sovcloud-identity.fr", # AzureBleu
41-
"https://login.sovcloud-identity.de", # AzureDelos
42-
"https://login.sovcloud-identity.sg", # AzureGovSG
43-
])
30+
AZURE_GOV_FR,
31+
AZURE_GOV_DE,
32+
AZURE_GOV_SG,
33+
])
4434

4535
WELL_KNOWN_B2C_HOSTS = [
4636
"b2clogin.com",
@@ -52,6 +42,15 @@
5242
_CIAM_DOMAIN_SUFFIX = ".ciamlogin.com"
5343

5444

45+
def _get_instance_discovery_host(instance):
46+
return instance if instance in WELL_KNOWN_AUTHORITY_HOSTS else WORLD_WIDE
47+
48+
49+
def _get_instance_discovery_endpoint(instance):
50+
return 'https://{}/common/discovery/instance'.format(
51+
_get_instance_discovery_host(instance))
52+
53+
5554
class AuthorityBuilder(object):
5655
def __init__(self, instance, tenant):
5756
"""A helper to save caller from doing string concatenation.
@@ -157,10 +156,8 @@ def _initialize_entra_authority(
157156
) or (len(parts) == 3 and parts[2].lower().startswith("b2c_"))
158157
self._is_known_to_developer = self.is_adfs or self._is_b2c or not validate_authority
159158
is_known_to_microsoft = self.instance in WELL_KNOWN_AUTHORITY_HOSTS
160-
instance_discovery_endpoint = 'https://{}/common/discovery/instance'.format( # Note: This URL seemingly returns V1 endpoint only
161-
WORLD_WIDE # Historically using WORLD_WIDE. Could use self.instance too
162-
# See https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadInstanceDiscovery.cs#L101-L103
163-
# and https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadAuthority.cs#L19-L33
159+
instance_discovery_endpoint = _get_instance_discovery_endpoint( # Note: This URL seemingly returns V1 endpoint only
160+
self.instance
164161
) if instance_discovery in (None, True) else instance_discovery
165162
if instance_discovery_endpoint and not (
166163
is_known_to_microsoft or self._is_known_to_developer):
@@ -172,8 +169,8 @@ def _initialize_entra_authority(
172169
if payload.get("error") == "invalid_instance":
173170
raise ValueError(
174171
"invalid_instance: "
175-
"The authority you provided, %s, is not whitelisted. "
176-
"If it is indeed your legit customized domain name, "
172+
"The authority you provided, %s, is not known. "
173+
"If it is a valid domain name known to you, "
177174
"you can turn off this check by passing in "
178175
"instance_discovery=False"
179176
% authority_url)
@@ -230,7 +227,7 @@ def has_valid_issuer(self):
230227
return False
231228

232229
# Case 2: Issuer is from a trusted Microsoft host - O(1) lookup
233-
if issuer_host in TRUSTED_ISSUER_HOSTS:
230+
if issuer_host in WELL_KNOWN_AUTHORITY_HOSTS:
234231
return True
235232

236233
# Case 3: Regional variant check - O(1) lookup
@@ -240,7 +237,7 @@ def has_valid_issuer(self):
240237
potential_base = issuer_host[dot_index + 1:]
241238
if "." not in issuer_host[:dot_index]:
242239
# 3a: Base host is a trusted Microsoft host
243-
if potential_base in TRUSTED_ISSUER_HOSTS:
240+
if potential_base in WELL_KNOWN_AUTHORITY_HOSTS:
244241
return True
245242
# 3b: Issuer has a region prefix on the authority host
246243
# e.g. issuer=us.someweb.com, authority=someweb.com

tests/http_client.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,44 @@ def raise_for_status(self):
4040
if self._raw_resp is not None: # Turns out `if requests.response` won't work
4141
# cause it would be True when 200<=status<400
4242
self._raw_resp.raise_for_status()
43+
44+
45+
class RecordingHttpClient(object):
46+
def __init__(self):
47+
self.get_calls = []
48+
self.post_calls = []
49+
self._get_routes = []
50+
self._post_routes = []
51+
52+
def add_get_route(self, matcher, responder):
53+
self._get_routes.append((matcher, responder))
54+
55+
def add_post_route(self, matcher, responder):
56+
self._post_routes.append((matcher, responder))
57+
58+
def get(self, url, params=None, headers=None, **kwargs):
59+
call = {
60+
"url": url,
61+
"params": params,
62+
"headers": headers,
63+
"kwargs": kwargs,
64+
}
65+
self.get_calls.append(call)
66+
for matcher, responder in self._get_routes:
67+
if matcher(call):
68+
return responder(call)
69+
return MinimalResponse(status_code=404, text="")
70+
71+
def post(self, url, params=None, data=None, headers=None, **kwargs):
72+
call = {
73+
"url": url,
74+
"params": params,
75+
"data": data,
76+
"headers": headers,
77+
"kwargs": kwargs,
78+
}
79+
self.post_calls.append(call)
80+
for matcher, responder in self._post_routes:
81+
if matcher(call):
82+
return responder(call)
83+
return MinimalResponse(status_code=404, text="")

tests/test_authority.py

Lines changed: 102 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import msal
88
from msal.authority import *
9-
from msal.authority import _CIAM_DOMAIN_SUFFIX, TRUSTED_ISSUER_HOSTS # Explicitly import private/new constants
9+
from msal.authority import _CIAM_DOMAIN_SUFFIX
1010
from tests import unittest
1111
from tests.http_client import MinimalHttpClient
1212

@@ -37,10 +37,90 @@ def _test_authority_builder(self, host, tenant):
3737
c.close()
3838

3939
def test_wellknown_host_and_tenant(self):
40-
# Assert all well known authority hosts are using their own "common" tenant
40+
# This test makes real HTTP calls to authority endpoints.
41+
# It is intentionally network-based to validate reachable hosts end-to-end.
42+
excluded_hosts = {
43+
DEPRECATED_AZURE_CHINA,
44+
"login.microsoftonline.de", # deprecated
45+
"login.microsoft.com", # issuer-only in this test context
46+
"login.windows.net", # issuer-only in this test context
47+
"sts.windows.net", # issuer-only in this test context
48+
"login.partner.microsoftonline.cn", # issuer-only in this test context
49+
"login.usgovcloudapi.net", # issuer-only in this test context
50+
AZURE_GOV_FR, # currently unreachable in this environment
51+
AZURE_GOV_DE, # currently unreachable in this environment
52+
AZURE_GOV_SG, # currently unreachable in this environment
53+
}
54+
for host in WELL_KNOWN_AUTHORITY_HOSTS:
55+
if host in excluded_hosts:
56+
continue
57+
self._test_given_host_and_tenant(host, "common")
58+
59+
@patch("msal.authority._instance_discovery")
60+
@patch("msal.authority.tenant_discovery")
61+
def test_new_sovereign_hosts_should_build_authority_endpoints(
62+
self, tenant_discovery_mock, instance_discovery_mock):
63+
for host in WELL_KNOWN_AUTHORITY_HOSTS:
64+
tenant_discovery_mock.return_value = {
65+
"authorization_endpoint": "https://{}/common/oauth2/v2.0/authorize".format(host),
66+
"token_endpoint": "https://{}/common/oauth2/v2.0/token".format(host),
67+
"issuer": "https://{}/common/v2.0".format(host),
68+
}
69+
instance_discovery_mock.return_value = {
70+
"tenant_discovery_endpoint": (
71+
"https://{}/common/v2.0/.well-known/openid-configuration".format(host)
72+
),
73+
}
74+
c = MinimalHttpClient()
75+
a = Authority(AuthorityBuilder(host, "common"), c)
76+
self.assertEqual(
77+
a.authorization_endpoint,
78+
"https://{}/common/oauth2/v2.0/authorize".format(host))
79+
self.assertEqual(
80+
a.token_endpoint,
81+
"https://{}/common/oauth2/v2.0/token".format(host))
82+
c.close()
83+
84+
@patch("msal.authority._instance_discovery")
85+
@patch("msal.authority.tenant_discovery")
86+
def test_known_authority_should_use_same_host_and_skip_instance_discovery(
87+
self, tenant_discovery_mock, instance_discovery_mock):
4188
for host in WELL_KNOWN_AUTHORITY_HOSTS:
42-
if host != AZURE_CHINA: # It is prone to ConnectionError
43-
self._test_given_host_and_tenant(host, "common")
89+
tenant_discovery_mock.return_value = {
90+
"authorization_endpoint": "https://{}/common/oauth2/v2.0/authorize".format(host),
91+
"token_endpoint": "https://{}/common/oauth2/v2.0/token".format(host),
92+
"issuer": "https://{}/common/v2.0".format(host),
93+
}
94+
c = MinimalHttpClient()
95+
Authority("https://{}/common".format(host), c)
96+
c.close()
97+
98+
instance_discovery_mock.assert_not_called()
99+
tenant_discovery_endpoint = tenant_discovery_mock.call_args[0][0]
100+
self.assertTrue(
101+
tenant_discovery_endpoint.startswith(
102+
"https://{}/common/v2.0/.well-known/openid-configuration".format(host)))
103+
104+
@patch("msal.authority._instance_discovery")
105+
@patch("msal.authority.tenant_discovery")
106+
def test_unknown_authority_should_use_world_wide_instance_discovery_endpoint(
107+
self, tenant_discovery_mock, instance_discovery_mock):
108+
tenant_discovery_mock.return_value = {
109+
"authorization_endpoint": "https://example.com/tenant/oauth2/v2.0/authorize",
110+
"token_endpoint": "https://example.com/tenant/oauth2/v2.0/token",
111+
"issuer": "https://example.com/tenant/v2.0",
112+
}
113+
instance_discovery_mock.return_value = {
114+
"tenant_discovery_endpoint": "https://example.com/tenant/v2.0/.well-known/openid-configuration",
115+
}
116+
117+
c = MinimalHttpClient()
118+
Authority("https://example.com/tenant", c)
119+
c.close()
120+
121+
self.assertEqual(
122+
"https://{}/common/discovery/instance".format(WORLD_WIDE),
123+
instance_discovery_mock.call_args[0][2])
44124

45125
def test_wellknown_host_and_tenant_using_new_authority_builder(self):
46126
self._test_authority_builder(AZURE_PUBLIC, "consumers")
@@ -276,7 +356,24 @@ def test_by_default_a_known_to_microsoft_authority_should_skip_validation_but_st
276356
app = msal.ClientApplication("id", authority="https://login.microsoftonline.com/common")
277357
known_to_microsoft_validation.assert_not_called()
278358
app.get_accounts() # This could make an instance metadata call for authority aliases
279-
instance_metadata.assert_called_once_with()
359+
instance_metadata.assert_called_once_with("login.microsoftonline.com")
360+
361+
def test_by_default_a_sovereign_known_authority_should_use_cloud_local_instance_metadata(
362+
self, instance_metadata, known_to_microsoft_validation, _):
363+
app = msal.ClientApplication("id", authority="https://login.microsoftonline.us/common")
364+
known_to_microsoft_validation.assert_not_called()
365+
app.get_accounts() # This could make an instance metadata call for authority aliases
366+
instance_metadata.assert_called_once_with("login.microsoftonline.us")
367+
368+
def test_fr_known_authority_should_still_work_when_instance_metadata_has_no_alias_entry(
369+
self, instance_metadata, known_to_microsoft_validation, _):
370+
app = msal.ClientApplication("id", authority="https://{}/common".format(AZURE_GOV_FR))
371+
known_to_microsoft_validation.assert_not_called()
372+
373+
accounts = app.get_accounts()
374+
375+
self.assertEqual([], accounts)
376+
instance_metadata.assert_called_once_with(AZURE_GOV_FR)
280377

281378
def test_validate_authority_boolean_should_skip_validation_and_instance_metadata(
282379
self, instance_metadata, known_to_microsoft_validation, _):

tests/test_e2e.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1152,7 +1152,7 @@ class WorldWideRegionalEndpointTestCase(LabBasedTestCase):
11521152
These tests verify that MSAL correctly routes requests to regional vs global endpoints.
11531153
"""
11541154
region = "westus"
1155-
timeout = 2 # Short timeout makes this test case responsive on non-VM
1155+
timeout = 5 # Short timeout makes this test case responsive on non-VM
11561156

11571157
def _test_acquire_token_for_client(self, configured_region, expected_region):
11581158
"""This is the only grant supported by regional endpoint, for now.

tests/test_mi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ def test_arc_by_file_existence_on_linux(self, mocked_exists):
462462

463463
@patch("msal.managed_identity.os.path.exists", return_value=True)
464464
@patch("msal.managed_identity.sys.platform", new="win32")
465-
@patch.dict(os.environ, {"ProgramFiles": "C:\Program Files"})
465+
@patch.dict(os.environ, {"ProgramFiles": r"C:\Program Files"})
466466
def test_arc_by_file_existence_on_windows(self, mocked_exists):
467467
self.assertEqual(get_managed_identity_source(), AZURE_ARC)
468468
mocked_exists.assert_called_with(

0 commit comments

Comments
 (0)