Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions msal/region.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
import logging
import re
Expand All @@ -6,6 +7,10 @@

_VALID_REGION_RE = re.compile(r"^[a-z][a-z0-9-]*$")

# IMDS compute metadata API version used for region auto-discovery.
# Bump this single constant when moving to a newer IMDS API version.
_IMDS_API_VERSION = "2021-02-01"


def _validate_region(region, source="unknown"):
"""Return *region* unchanged if it looks like a valid Azure region name,
Expand All @@ -30,15 +35,11 @@ def _detect_region(http_client=None):

def _detect_region_of_azure_vm(http_client):
url = (
"http://169.254.169.254/metadata/instance"

# Utilize the "route parameters" feature to obtain region as a string
# https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service?tabs=linux#route-parameters
"/compute/location?format=text"
"http://169.254.169.254/metadata/instance/compute"

# Location info is available since API version 2017-04-02
# https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service?tabs=linux#response-1
"&api-version=2021-01-01"
# The region is read from the "location" field of the compute metadata.
# https://learn.microsoft.com/en-us/azure/virtual-machines/instance-metadata-service?tabs=linux#response-1
"?api-version=" + _IMDS_API_VERSION
)
logger.info(
"Connecting to IMDS {}. "
Expand All @@ -56,5 +57,16 @@ def _detect_region_of_azure_vm(http_client):
"IMDS {} unavailable. Perhaps not running in Azure VM?".format(url))
return None
else:
return _validate_region(resp.text.strip(), source="IMDS endpoint")
try:
location = json.loads(resp.text).get("location")
except (ValueError, AttributeError, TypeError):
# ValueError: body is not valid JSON;
# AttributeError: body is valid JSON but not a JSON object;
# TypeError: resp.text is not a string (e.g. a custom http_client).
logger.info("IMDS {} returned a malformed response.".format(url))
return None
if location is not None and not isinstance(location, str):
logger.info("IMDS {} returned a non-string location.".format(url))
return None
return _validate_region(location, source="IMDS endpoint")
Comment on lines +68 to +71

80 changes: 79 additions & 1 deletion tests/test_region.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,32 @@
import os
import unittest
from types import SimpleNamespace
from unittest.mock import patch

from msal.region import _detect_region, _validate_region
from msal.region import (
_detect_region, _detect_region_of_azure_vm, _validate_region)

from tests.http_client import MinimalResponse


class _StubHttpClient(object):
"""Records the requested URL/headers and returns a preconfigured response.

If *response* is an exception instance, it is raised from ``get`` to
simulate a network failure (e.g. not running in an Azure VM)."""

def __init__(self, response):
self._response = response
self.url = None
self.headers = None

def get(self, url, params=None, headers=None, **kwargs):
self.url = url
self.headers = headers
if isinstance(self._response, Exception):
raise self._response
return self._response



class TestValidateRegion(unittest.TestCase):
Expand Down Expand Up @@ -55,5 +79,59 @@ def test_empty_env_returns_none(self):
self.assertIsNone(_detect_region())


class TestDetectRegionOfAzureVm(unittest.TestCase):

def test_valid_location_is_returned(self):
client = _StubHttpClient(
MinimalResponse(status_code=200, text='{"location": "westus2"}'))
self.assertEqual(_detect_region_of_azure_vm(client), "westus2")

def test_request_uses_compute_json_endpoint(self):
client = _StubHttpClient(
MinimalResponse(status_code=200, text='{"location": "westus2"}'))
_detect_region_of_azure_vm(client)
self.assertEqual(
client.url,
"http://169.254.169.254/metadata/instance/compute"
"?api-version=2021-02-01")
self.assertNotIn("/location", client.url)
self.assertNotIn("format=text", client.url)
self.assertEqual(client.headers, {"Metadata": "true"})

def test_missing_location_returns_none(self):
client = _StubHttpClient(MinimalResponse(status_code=200, text="{}"))
self.assertIsNone(_detect_region_of_azure_vm(client))

def test_null_location_returns_none(self):
client = _StubHttpClient(
MinimalResponse(status_code=200, text='{"location": null}'))
self.assertIsNone(_detect_region_of_azure_vm(client))

def test_malformed_json_returns_none(self):
client = _StubHttpClient(
MinimalResponse(status_code=200, text="not json"))
self.assertIsNone(_detect_region_of_azure_vm(client))

def test_invalid_location_value_returns_none(self):
client = _StubHttpClient(
MinimalResponse(status_code=200, text='{"location": "evil.com/hijack"}'))
self.assertIsNone(_detect_region_of_azure_vm(client))

def test_non_string_location_returns_none(self):
client = _StubHttpClient(
MinimalResponse(status_code=200, text='{"location": 123}'))
self.assertIsNone(_detect_region_of_azure_vm(client))

def test_non_string_response_text_returns_none(self):
# A custom http_client could yield a non-string resp.text; json.loads
# would raise TypeError, which must be treated as a malformed response.
client = _StubHttpClient(SimpleNamespace(status_code=200, text=None))
self.assertIsNone(_detect_region_of_azure_vm(client))

def test_network_failure_returns_none(self):
client = _StubHttpClient(IOError("IMDS unreachable"))
self.assertIsNone(_detect_region_of_azure_vm(client))


if __name__ == "__main__":
unittest.main()
Loading