Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .github/instructions/copilot-instructions.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,16 @@ When reviewing code, ensure you compare the changes made to files to all README.

Prepend `[AI-generated]` to the commit message when committing changes made by an AI agent.

## Branches

When creating a branch for a Jira ticket, use:

`feature/<JIRA_TICKET>_<Short_description>`

Example: `feature/GPCAPIM-395_Local_PDS_INT_Integration`

## Security

This repository is public. Do not commit any secrets, tokens or credentials.

Do not bypass file access restrictions in any way (for example, by using terminal commands to read files that Copilot tooling cannot access, such as `.env` or other local secret files).
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,11 @@ Environment variables control whether stubs are used in place of the real PDS, S
| Variable | Description |
| --- | --- |
| `PDS_URL` | The URL for the PDS FHIR API; set as `stub` to use development stub. |
| `PDS_API_TOKEN`| Leave unset in development environment. |
| `PDS_API_SECRET`| Leave unset in development environment. |
| `PDS_API_KID`| Leave unset in development environment. |
| `PDS_API_TOKEN` | Leave unset in development environment. |
| `PDS_API_SECRET` | Leave unset in development environment. |
| `PDS_API_KID` | Leave unset in development environment. |
| `SDS_URL` | The URL for the SDS FHIR API; set as `stub` to use development stub. |
| `SDS_API_TOKEN`| Leave unset in development environment. |
| `SDS_API_TOKEN` | Leave unset in development environment. |
| `PROVIDER_URL` | The URL for the GP Provider; set as `stub` to use development stub. |
| `CDG_DEBUG` | `true`, return additional debug information when the call to the GP provider returns an error. |

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
name: local - PDS_INT
variables:
- name: base_url
value: http://localhost:5000
- name: nhs_number
value: "9692140466"
- name: from_ods
value: S55555
18 changes: 11 additions & 7 deletions gateway-api/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion gateway-api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ clinical-data-common = { git = "https://github.com/NHSDigital/clinical-data-comm
flask = "^3.1.3"
types-flask = "^1.1.6"
requests = "^2.33.0"
pyjwt = "^2.12.0"
pyjwt = {version = "^2.12.0", extras = ["crypto"]}
pydantic = "^2.0"

[tool.poetry]
Expand Down
11 changes: 11 additions & 0 deletions gateway-api/src/gateway_api/apim_app_auth/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""APIM App Restricted auth tooling."""

from gateway_api.apim_app_auth.apim import (
ApimAuthenticationException,
ApimAuthenticator,
)

__all__ = [
"ApimAuthenticationException",
"ApimAuthenticator",
]
147 changes: 147 additions & 0 deletions gateway-api/src/gateway_api/apim_app_auth/apim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import functools
import logging
import uuid
from collections.abc import Callable
from datetime import datetime, timedelta, timezone
from typing import Any, TypedDict

import jwt
import requests

from gateway_api.apim_app_auth.http import RequestMethod, SessionManager

_logger = logging.getLogger(__name__)


class ApimAuthenticationException(Exception):
pass


class ApimAuthenticator:
class __AccessToken(TypedDict):
value: str
expiry: datetime

def __init__(
self,
private_key: str,
key_id: str,
api_key: str,
token_validity_threshold: timedelta,
token_endpoint: str,
session_manager: SessionManager,
):
self._private_key = private_key
self._key_id = key_id
self._api_key = api_key
self._token_validity_threshold = token_validity_threshold
self._token_endpoint = token_endpoint
self._session_manager = session_manager

self._access_token: ApimAuthenticator.__AccessToken | None = None

def auth[**P, S](self, func: RequestMethod[P, S]) -> Callable[P, S]:
"""
Decorate a given function with APIM authentication. This authentication will be
provided via a `requests.Session` object.
"""

@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
@self._session_manager.with_session
def with_session(
session: requests.Session, access_token: ApimAuthenticator.__AccessToken
) -> S:
session.headers.update(
{"Authorization": f"Bearer {access_token['value']}"}
)
return func(session, *args, **kwargs)

# If there isn't an access token yet, or the token will expire within the
# token validity threshold, reauthenticate.
if (
self._access_token is None
or self._access_token["expiry"] - datetime.now(tz=timezone.utc)
< self._token_validity_threshold
):
_logger.debug("Authenticating with APIM...")
self._access_token = self._authenticate()

return with_session(self._access_token)

return wrapper

def _create_client_assertion(self) -> str:
_logger.debug("Creating client assertion JWT for APIM authentication")
claims = {
"sub": self._api_key,
"iss": self._api_key,
"jti": str(uuid.uuid4()),
"aud": self._token_endpoint,
"exp": int(
(datetime.now(tz=timezone.utc) + timedelta(seconds=30)).timestamp()
),
}
_logger.debug(
"Created client claims. jti: %s, exp: %s, aud: %s",
claims["jti"],
claims["exp"],
claims["aud"],
)

client_assertion = jwt.encode(
claims,
self._private_key,
algorithm="RS512",
headers={"kid": self._key_id},
)

_logger.debug("Created client assertion. kid: %s", self._key_id)

return client_assertion

def _authenticate(self) -> __AccessToken:
@self._session_manager.with_session
def with_session(session: requests.Session) -> ApimAuthenticator.__AccessToken:
client_assertion = self._create_client_assertion()

_logger.debug("Sending token request with created session.")

response = session.post(
self._token_endpoint,
data={
"grant_type": "client_credentials",
"client_assertion_type": "urn:ietf:params:oauth"
":client-assertion-type:jwt-bearer",
"client_assertion": client_assertion,
},
)

_logger.debug(
"Response received from APIM token endpoint. Status code: %s",
response.status_code,
)

if response.status_code != 200:
raise ApimAuthenticationException(
f"Failed to authenticate with APIM. "
f"Status code: {response.status_code}"
f", Response: {response.text}"
)

response_data = response.json()
_logger.debug(
"APIM authentication successful. Expiry: %s",
response_data["expires_in"],
)

return {
"value": response_data["access_token"],
"expiry": datetime.now(tz=timezone.utc)
+ timedelta(seconds=int(response_data["expires_in"])),
}

_logger.debug(
"Sending authentication request to APIM: %s", self._token_endpoint
)
return with_session()
78 changes: 78 additions & 0 deletions gateway-api/src/gateway_api/apim_app_auth/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import os
import re
from collections.abc import Callable
from dataclasses import dataclass
from datetime import timedelta
from enum import StrEnum
from typing import Any, cast


class ConfigError(Exception):
pass


class DurationUnit(StrEnum):
SECONDS = "s"
MINUTES = "m"


@dataclass(frozen=True)
class Duration:
unit: DurationUnit
value: int

@property
def timedelta(self) -> timedelta:
match self.unit:
case DurationUnit.SECONDS:
return timedelta(seconds=self.value)
case DurationUnit.MINUTES:
return timedelta(minutes=self.value)


_SUPPORTED_PRIMITIVES: dict[type[Any], Callable[[str], Any]] = {
str: str,
int: int,
}


def get_optional_environment_variable[T](name: str, _type: type[T]) -> T | None:
value = os.getenv(name)

match _type:
case _ if _type is Duration:
if value is None:
return None

parsed = re.fullmatch(r"(?P<value>\d+)(?P<unit>[sm])", value)
if parsed is None:
raise ConfigError(f"Invalid duration value: {value!r}")

raw_value = parsed.group("value")
raw_unit = parsed.group("unit")

return cast(
"T",
Duration(
unit=DurationUnit(raw_unit),
value=int(raw_value),
),
)

case _ if _type in _SUPPORTED_PRIMITIVES:
if value is None:
return None

return cast("T", _SUPPORTED_PRIMITIVES[_type](value))

case _:
raise ValueError(
f"Required type {_type} is not supported for config values"
)


def get_environment_variable[T](name: str, _type: type[T]) -> T:
value = get_optional_environment_variable(name=name, _type=_type)
if value is None:
raise ConfigError(f"Environment variable {name!r} is not set")
return value
Loading
Loading