diff --git a/msal/region.py b/msal/region.py index 37e01f0d..3b5b4b28 100644 --- a/msal/region.py +++ b/msal/region.py @@ -1,3 +1,4 @@ +import json import os import logging import re @@ -6,6 +7,10 @@ _VALID_REGION_RE = re.compile(r"^[a-z][a-z0-9-]*$") +# IMDS compute metadata API version used for region auto-discovery. +# Bump this single constant when moving to a newer IMDS API version. +_IMDS_API_VERSION = "2021-02-01" + def _validate_region(region, source="unknown"): """Return *region* unchanged if it looks like a valid Azure region name, @@ -30,15 +35,11 @@ def _detect_region(http_client=None): def _detect_region_of_azure_vm(http_client): url = ( - "http://169.254.169.254/metadata/instance" - - # Utilize the "route parameters" feature to obtain region as a string - # https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service?tabs=linux#route-parameters - "/compute/location?format=text" + "http://169.254.169.254/metadata/instance/compute" - # Location info is available since API version 2017-04-02 - # https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service?tabs=linux#response-1 - "&api-version=2021-01-01" + # The region is read from the "location" field of the compute metadata. + # https://learn.microsoft.com/en-us/azure/virtual-machines/instance-metadata-service?tabs=linux#response-1 + "?api-version=" + _IMDS_API_VERSION ) logger.info( "Connecting to IMDS {}. " @@ -56,5 +57,16 @@ def _detect_region_of_azure_vm(http_client): "IMDS {} unavailable. Perhaps not running in Azure VM?".format(url)) return None else: - return _validate_region(resp.text.strip(), source="IMDS endpoint") + try: + location = json.loads(resp.text).get("location") + except (ValueError, AttributeError, TypeError): + # ValueError: body is not valid JSON; + # AttributeError: body is valid JSON but not a JSON object; + # TypeError: resp.text is not a string (e.g. a custom http_client). + logger.info("IMDS {} returned a malformed response.".format(url)) + return None + if location is not None and not isinstance(location, str): + logger.info("IMDS {} returned a non-string location.".format(url)) + return None + return _validate_region(location, source="IMDS endpoint") diff --git a/tests/test_region.py b/tests/test_region.py index c839f7c6..00b968c5 100644 --- a/tests/test_region.py +++ b/tests/test_region.py @@ -1,8 +1,32 @@ import os import unittest +from types import SimpleNamespace from unittest.mock import patch -from msal.region import _detect_region, _validate_region +from msal.region import ( + _detect_region, _detect_region_of_azure_vm, _validate_region) + +from tests.http_client import MinimalResponse + + +class _StubHttpClient(object): + """Records the requested URL/headers and returns a preconfigured response. + + If *response* is an exception instance, it is raised from ``get`` to + simulate a network failure (e.g. not running in an Azure VM).""" + + def __init__(self, response): + self._response = response + self.url = None + self.headers = None + + def get(self, url, params=None, headers=None, **kwargs): + self.url = url + self.headers = headers + if isinstance(self._response, Exception): + raise self._response + return self._response + class TestValidateRegion(unittest.TestCase): @@ -55,5 +79,59 @@ def test_empty_env_returns_none(self): self.assertIsNone(_detect_region()) +class TestDetectRegionOfAzureVm(unittest.TestCase): + + def test_valid_location_is_returned(self): + client = _StubHttpClient( + MinimalResponse(status_code=200, text='{"location": "westus2"}')) + self.assertEqual(_detect_region_of_azure_vm(client), "westus2") + + def test_request_uses_compute_json_endpoint(self): + client = _StubHttpClient( + MinimalResponse(status_code=200, text='{"location": "westus2"}')) + _detect_region_of_azure_vm(client) + self.assertEqual( + client.url, + "http://169.254.169.254/metadata/instance/compute" + "?api-version=2021-02-01") + self.assertNotIn("/location", client.url) + self.assertNotIn("format=text", client.url) + self.assertEqual(client.headers, {"Metadata": "true"}) + + def test_missing_location_returns_none(self): + client = _StubHttpClient(MinimalResponse(status_code=200, text="{}")) + self.assertIsNone(_detect_region_of_azure_vm(client)) + + def test_null_location_returns_none(self): + client = _StubHttpClient( + MinimalResponse(status_code=200, text='{"location": null}')) + self.assertIsNone(_detect_region_of_azure_vm(client)) + + def test_malformed_json_returns_none(self): + client = _StubHttpClient( + MinimalResponse(status_code=200, text="not json")) + self.assertIsNone(_detect_region_of_azure_vm(client)) + + def test_invalid_location_value_returns_none(self): + client = _StubHttpClient( + MinimalResponse(status_code=200, text='{"location": "evil.com/hijack"}')) + self.assertIsNone(_detect_region_of_azure_vm(client)) + + def test_non_string_location_returns_none(self): + client = _StubHttpClient( + MinimalResponse(status_code=200, text='{"location": 123}')) + self.assertIsNone(_detect_region_of_azure_vm(client)) + + def test_non_string_response_text_returns_none(self): + # A custom http_client could yield a non-string resp.text; json.loads + # would raise TypeError, which must be treated as a malformed response. + client = _StubHttpClient(SimpleNamespace(status_code=200, text=None)) + self.assertIsNone(_detect_region_of_azure_vm(client)) + + def test_network_failure_returns_none(self): + client = _StubHttpClient(IOError("IMDS unreachable")) + self.assertIsNone(_detect_region_of_azure_vm(client)) + + if __name__ == "__main__": unittest.main()