Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions optimade/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pathlib import Path
from typing import TYPE_CHECKING

import requests
from requests.exceptions import SSLError

if TYPE_CHECKING:
Expand Down Expand Up @@ -135,7 +136,9 @@ def mongo_id_for_database(database_id: str, database_type: str) -> str:
return str(ObjectId(hash_bytes))


def get_providers(add_mongo_id: bool = False) -> list:
def get_providers(
add_mongo_id: bool = False, session: requests.Session | None = None
) -> list:
"""Retrieve Materials-Consortia providers (from https://providers.optimade.org/v1/links).

Fallback order if providers.optimade.org is not available:
Expand All @@ -148,18 +151,20 @@ def get_providers(add_mongo_id: bool = False) -> list:
Arguments:
add_mongo_id: Whether to populate the `_id` field of the provider with MongoDB
ObjectID.
session: An optional `requests.Session` to use for the request, allowing custom
HTTP configuration (e.g. proxies). Defaults to the module-level `requests`.

Returns:
List of raw JSON-decoded providers including MongoDB object IDs.

"""
import json

import requests
_get = session.get if session is not None else requests.get

for provider_list_url in PROVIDER_LIST_URLS:
try:
providers = requests.get(provider_list_url, timeout=10).json()
providers = _get(provider_list_url, timeout=10).json()
except (
requests.exceptions.ConnectionError,
requests.exceptions.ConnectTimeout,
Expand Down Expand Up @@ -210,6 +215,7 @@ def get_child_database_links(
obey_aggregate: bool = True,
headers: dict | None = None,
skip_ssl: bool = False,
session: requests.Session | None = None,
) -> list[LinksResource]:
"""For a provider, return a list of available child databases.

Expand All @@ -218,6 +224,8 @@ def get_child_database_links(
obey_aggregate: Whether to only return links that allow
aggregation.
headers: Additional HTTP headers to pass to the provider.
session: An optional `requests.Session` to use for the request, allowing custom
HTTP configuration (e.g. proxies). Defaults to the module-level `requests`.

Returns:
A list of the valid links entries from this provider that
Expand All @@ -228,20 +236,20 @@ def get_child_database_links(
invalid, or the request otherwise fails.

"""
import requests

from optimade.models.links import Aggregate, LinkType

_get = session.get if session is not None else requests.get

base_url = provider.pop("base_url")
if base_url is None:
raise RuntimeError(f"Provider {provider['id']} provides no base URL.")

links_endp = base_url + "/v1/links"
try:
links = requests.get(links_endp, timeout=10, headers=headers)
links = _get(links_endp, timeout=10, headers=headers)
except SSLError as exc:
if skip_ssl:
links = requests.get(links_endp, timeout=10, headers=headers, verify=False)
links = _get(links_endp, timeout=10, headers=headers, verify=False)
else:
raise RuntimeError(
f"SSL error when connecting to provider {provider['id']}. Use `skip_ssl` to ignore."
Expand Down Expand Up @@ -284,13 +292,17 @@ def get_all_databases(
exclude_databases: Container[str] | None = None,
progress: "rich.progress.Progress | None" = None,
skip_ssl: bool = False,
session: requests.Session | None = None,
) -> Iterable[str]:
"""Iterate through all databases reported by registered OPTIMADE providers.

Parameters:
include_providers: A set/container of provider IDs to include child databases for.
exclude_providers: A set/container of provider IDs to exclude child databases for.
exclude_databases: A set/container of specific database URLs to exclude.
session: An optional `requests.Session` to use for the underlying requests,
allowing custom HTTP configuration (e.g. proxies). Defaults to the
module-level `requests`.

Returns:
A generator of child database links that obey the given parameters.
Expand All @@ -309,14 +321,16 @@ def get_all_databases(
_task = None

with _progress:
for provider in get_providers():
for provider in get_providers(session=session):
if exclude_providers and provider["id"] in exclude_providers:
continue
if include_providers and provider["id"] not in include_providers:
continue

try:
links = get_child_database_links(provider, skip_ssl=skip_ssl)
links = get_child_database_links(
provider, skip_ssl=skip_ssl, session=session
)
for link in links:
if link.attributes.base_url:
if (
Expand Down
103 changes: 102 additions & 1 deletion tests/server/routers/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,21 @@
from unittest import mock

import pytest
from requests.exceptions import ConnectionError
import requests
from requests.exceptions import ConnectionError, SSLError


class _MockResponse:
def __init__(self, data: list | dict, status_code: int):
self.data = data
self.status_code = status_code

def json(self) -> list | dict:
return self.data

@property
def content(self) -> str:
return str(self.data)


def mocked_providers_list_response(
Expand Down Expand Up @@ -82,6 +96,93 @@ def test_get_all_databases():
assert list(get_all_databases())


def test_get_providers_uses_provided_session():
    """Check that `get_providers` fetches the provider list via a supplied session.

    A caller-supplied `requests.Session` must be the object that performs the
    provider-list request, so custom HTTP config (e.g. proxies) is honoured.

    https://github.com/Materials-Consortia/optimade-python-tools/issues/2275
    """
    from optimade.utils import get_providers

    fake_session = mock.MagicMock(spec=requests.Session)
    fake_session.get.return_value = mocked_providers_list_response()

    # Any fall-through to the global `requests.get` trips this guard and raises.
    guard = AssertionError("used requests.get, not the session")
    with mock.patch("requests.get", side_effect=guard):
        fetched = get_providers(session=fake_session)

    assert fake_session.get.called
    assert fetched


def test_get_child_database_links_uses_provided_session():
    """Check that `get_child_database_links` sends its request through a supplied session.

    https://github.com/Materials-Consortia/optimade-python-tools/issues/2275
    """
    from optimade.utils import get_child_database_links

    fake_session = mock.MagicMock(spec=requests.Session)
    fake_session.get.return_value = _MockResponse({"data": []}, 200)
    dummy_provider = {"id": "dummy", "base_url": "https://example.org"}

    # The module-level `requests.get` is booby-trapped: touching it fails the test.
    guard = AssertionError("used requests.get, not the session")
    with mock.patch("requests.get", side_effect=guard):
        child_links = get_child_database_links(dummy_provider, session=fake_session)

    assert child_links == []
    assert fake_session.get.called


def test_get_child_database_links_skip_ssl_uses_session():
    """Check that the `verify=False` retry after an SSL error also uses the session.

    The first `session.get` call raises `SSLError`; with `skip_ssl=True` the
    retry must go through the same supplied session rather than `requests.get`.

    https://github.com/Materials-Consortia/optimade-python-tools/issues/2275
    """
    from optimade.utils import get_child_database_links

    fake_session = mock.MagicMock(spec=requests.Session)
    # First call fails with an SSL error, second call succeeds with an empty payload.
    fake_session.get.side_effect = [
        SSLError("ssl boom"),
        _MockResponse({"data": []}, 200),
    ]
    dummy_provider = {"id": "dummy", "base_url": "https://example.org"}

    guard = AssertionError("used requests.get, not the session")
    with mock.patch("requests.get", side_effect=guard):
        child_links = get_child_database_links(
            dummy_provider, session=fake_session, skip_ssl=True
        )

    assert child_links == []
    assert fake_session.get.call_count == 2
    # `call_args` reflects the most recent call, i.e. the retry.
    last_call = fake_session.get.call_args
    assert last_call.kwargs.get("verify") is False


def test_get_all_databases_threads_session():
    """Check that `get_all_databases` passes its session on to both helper scrapers.

    https://github.com/Materials-Consortia/optimade-python-tools/issues/2275
    """
    from optimade import utils

    fake_session = mock.MagicMock(spec=requests.Session)
    dummy_provider = {"id": "dummy", "base_url": "https://example.org"}

    # Replace both helpers so only the argument forwarding is exercised.
    providers_patch = mock.patch.object(
        utils, "get_providers", return_value=[dummy_provider]
    )
    links_patch = mock.patch.object(utils, "get_child_database_links", return_value=[])

    with providers_patch as patched_providers, links_patch as patched_links:
        # Exhaust the iterable so the patched helpers are actually invoked.
        list(utils.get_all_databases(session=fake_session))

    assert patched_providers.call_args.kwargs.get("session") is fake_session
    assert patched_links.call_args.kwargs.get("session") is fake_session


def test_get_providers_warning(caplog, top_dir):
"""Make sure a warning is logged as a last resort."""
import copy
Expand Down