From e61946baefe6d774686cc68ce1c6595a7e71da84 Mon Sep 17 00:00:00 2001 From: wadii Date: Fri, 3 Apr 2026 17:14:14 +0200 Subject: [PATCH 1/2] feat: backend-oauth-consent-screen --- api/app/urls.py | 6 + api/oauth2_metadata/serializers.py | 12 + api/oauth2_metadata/views.py | 100 ++++++- api/oauth2_test_server.mjs | 5 +- .../oauth2_metadata/test_authorize_view.py | 247 ++++++++++++++++++ 5 files changed, 365 insertions(+), 5 deletions(-) create mode 100644 api/tests/unit/oauth2_metadata/test_authorize_view.py diff --git a/api/app/urls.py b/api/app/urls.py index 5df2790853fc..0e3bf4e4ffb2 100644 --- a/api/app/urls.py +++ b/api/app/urls.py @@ -8,6 +8,7 @@ from oauth2_metadata.views import ( DynamicClientRegistrationView, + OAuthAuthorizeView, authorization_server_metadata, ) from users.views import password_reset_redirect @@ -56,6 +57,11 @@ "robots.txt", TemplateView.as_view(template_name="robots.txt", content_type="text/plain"), ), + path( + "api/v1/oauth/authorize/", + OAuthAuthorizeView.as_view(), + name="oauth-authorize", + ), path( "o/register/", DynamicClientRegistrationView.as_view(), diff --git a/api/oauth2_metadata/serializers.py b/api/oauth2_metadata/serializers.py index da0648743bb1..e373afaaf0d0 100644 --- a/api/oauth2_metadata/serializers.py +++ b/api/oauth2_metadata/serializers.py @@ -5,6 +5,18 @@ from oauth2_metadata.services import validate_redirect_uri + +class OAuthConsentSerializer(serializers.Serializer): # type: ignore[type-arg] + allow = serializers.BooleanField() + client_id = serializers.CharField() + redirect_uri = serializers.CharField() + response_type = serializers.CharField() + scope = serializers.CharField(required=False, default="mcp") + code_challenge = serializers.CharField() + code_challenge_method = serializers.CharField() + state = serializers.CharField(required=False, allow_blank=True, default="") + + # Allow ASCII letters, digits, spaces, hyphens, underscores, dots, and parentheses. # ASCII-only to prevent Unicode homoglyph spoofing on the consent screen. _CLIENT_NAME_RE = re.compile(r"^[\w\s.\-()]+$", re.ASCII) diff --git a/api/oauth2_metadata/views.py b/api/oauth2_metadata/views.py index 7a592016f66e..bd43a55b8e42 100644 --- a/api/oauth2_metadata/views.py +++ b/api/oauth2_metadata/views.py @@ -1,17 +1,23 @@ from typing import Any +from urllib.parse import urlencode, urlparse, urlunparse from django.conf import settings -from django.http import HttpRequest, JsonResponse +from django.http import HttpRequest, JsonResponse, QueryDict from django.views.decorators.csrf import csrf_exempt from django.views.decorators.http import require_GET +from oauth2_provider.exceptions import OAuthToolkitError +from oauth2_provider.models import get_application_model +from oauth2_provider.scopes import get_scopes_backend +from oauth2_provider.views.mixins import OAuthLibMixin +from rest_framework import status from rest_framework import status as drf_status -from rest_framework.permissions import AllowAny +from rest_framework.permissions import AllowAny, IsAuthenticated from rest_framework.request import Request from rest_framework.response import Response from rest_framework.throttling import ScopedRateThrottle from rest_framework.views import APIView -from oauth2_metadata.serializers import DCRRequestSerializer +from oauth2_metadata.serializers import DCRRequestSerializer, OAuthConsentSerializer from oauth2_metadata.services import create_oauth2_application @@ -46,6 +52,94 @@ def authorization_server_metadata(request: HttpRequest) -> JsonResponse: return JsonResponse(metadata) +class OAuthAuthorizeView(OAuthLibMixin, APIView): # type: ignore[misc] + """Validate an OAuth authorisation request and process consent decisions.""" + + permission_classes = [IsAuthenticated] + + def get(self, request: Request, *args: Any, **kwargs: Any) -> Response: + """Validate an authorisation request and return application info.""" + # Bridge DRF auth to Django request so DOT sees the authenticated user. + request._request.user = request.user # type: ignore[assignment] + + try: + scopes, credentials = self.validate_authorization_request(request._request) + except OAuthToolkitError as e: + oauthlib_error = e.oauthlib_error + return Response( + { + "error": getattr(oauthlib_error, "error", "invalid_request"), + "error_description": getattr(oauthlib_error, "description", str(e)), + }, + status=status.HTTP_400_BAD_REQUEST, + ) + + Application = get_application_model() + application = Application.objects.get( + client_id=credentials["client_id"], + ) + all_scopes = get_scopes_backend().get_all_scopes() + scopes_dict: dict[str, str] = {s: all_scopes.get(s, s) for s in scopes} + return Response( + { + "application": { + "name": application.name, + "client_id": application.client_id, + }, + "scopes": scopes_dict, + "redirect_uri": credentials.get("redirect_uri", ""), + "is_verified": bool(application.skip_authorization), + } + ) + + def post(self, request: Request, *args: Any, **kwargs: Any) -> Response: + """Process a consent decision and return the redirect URI.""" + serializer = OAuthConsentSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + data: dict[str, Any] = serializer.validated_data + allow: bool = data.pop("allow") + + # Bridge DRF auth to Django request so DOT sees the authenticated user. + request._request.user = request.user # type: ignore[assignment] + + # DOT's validate_authorization_request reads OAuth params from GET + # and also from request.get_full_path() which uses META['QUERY_STRING']. + query = QueryDict(mutable=True) + for key, value in data.items(): + query[key] = str(value) + request._request.GET = query + request._request.META["QUERY_STRING"] = query.urlencode() + + try: + scopes, credentials = self.validate_authorization_request(request._request) + except OAuthToolkitError as e: + oauthlib_error = e.oauthlib_error + return Response( + { + "error": getattr(oauthlib_error, "error", "invalid_request"), + "error_description": getattr(oauthlib_error, "description", str(e)), + }, + status=status.HTTP_400_BAD_REQUEST, + ) + + try: + scopes_str = " ".join(scopes) if isinstance(scopes, list) else scopes + uri, _headers, _body, _status = self.create_authorization_response( + request._request, scopes_str, credentials, allow + ) + except OAuthToolkitError: + # User denied access -- build the error redirect manually. + redirect_uri = credentials.get("redirect_uri", data.get("redirect_uri", "")) + state = credentials.get("state", data.get("state", "")) + error_params: dict[str, str] = {"error": "access_denied"} + if state: + error_params["state"] = state + parsed = urlparse(str(redirect_uri)) + uri = urlunparse(parsed._replace(query=urlencode(error_params))) + + return Response({"redirect_uri": uri}) + + class DynamicClientRegistrationView(APIView): """RFC 7591 Dynamic Client Registration endpoint.""" diff --git a/api/oauth2_test_server.mjs b/api/oauth2_test_server.mjs index d93af0f29bee..a8d3a0bdd8a9 100644 --- a/api/oauth2_test_server.mjs +++ b/api/oauth2_test_server.mjs @@ -1,8 +1,9 @@ import { createServer } from "node:http"; import { randomBytes, createHash } from "node:crypto"; -const CLIENT_ID = "B4wAl37pg9y1PRsIvAXZ14cTp0FpqpNCtMSI7ETC"; +const CLIENT_ID = "PVuLryS7ISh5gveydoLafTt02q1jMsCiwwOVoMy6"; const REDIRECT_URI = "http://localhost:3000/oauth/callback"; +const FRONTEND_URL = "http://localhost:8080"; const API_URL = "http://localhost:8000"; const PORT = 3000; @@ -13,7 +14,7 @@ const codeChallenge = createHash("sha256") .digest("base64url"); const authorizeUrl = - `${API_URL}/o/authorize/?` + + `${FRONTEND_URL}/oauth/authorize?` + new URLSearchParams({ response_type: "code", client_id: CLIENT_ID, diff --git a/api/tests/unit/oauth2_metadata/test_authorize_view.py b/api/tests/unit/oauth2_metadata/test_authorize_view.py new file mode 100644 index 000000000000..7dbe353f2389 --- /dev/null +++ b/api/tests/unit/oauth2_metadata/test_authorize_view.py @@ -0,0 +1,247 @@ +import base64 +import hashlib +import secrets +from urllib.parse import parse_qs, urlparse + +import pytest +from django.contrib.auth import get_user_model +from django.urls import reverse +from oauth2_provider.models import Application +from rest_framework import status +from rest_framework.test import APIClient + +AUTHORIZE_URL = "oauth-authorize" + +User = get_user_model() + + +def _pkce_pair() -> tuple[str, str]: + """Return (code_verifier, code_challenge) for S256 PKCE.""" + code_verifier = secrets.token_urlsafe(32) + digest = hashlib.sha256(code_verifier.encode()).digest() + code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode() + return code_verifier, code_challenge + + +@pytest.fixture() +def oauth_application(admin_user: User) -> Application: # type: ignore[valid-type] + return Application.objects.create( # type: ignore[no-any-return] + name="Test App", + user=admin_user, + client_type=Application.CLIENT_PUBLIC, + authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, + redirect_uris="https://example.com/callback", + ) + + +@pytest.fixture() +def auth_client(admin_user: User) -> APIClient: # type: ignore[valid-type] + client = APIClient() + client.force_authenticate(user=admin_user) + return client + + +@pytest.fixture() +def pkce_pair() -> tuple[str, str]: + return _pkce_pair() + + +def test_get__valid_params__returns_application_info( + auth_client: APIClient, + oauth_application: Application, + pkce_pair: tuple[str, str], +) -> None: + # Given + _verifier, challenge = pkce_pair + url = reverse(AUTHORIZE_URL) + + # When + response = auth_client.get( + url, + { + "client_id": oauth_application.client_id, + "response_type": "code", + "redirect_uri": "https://example.com/callback", + "scope": "mcp", + "code_challenge": challenge, + "code_challenge_method": "S256", + }, + ) + + # Then + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["application"]["name"] == "Test App" + assert data["application"]["client_id"] == oauth_application.client_id + assert "mcp" in data["scopes"] + assert data["redirect_uri"] == "https://example.com/callback" + assert data["is_verified"] is False + + +def test_get__invalid_client_id__returns_400( + auth_client: APIClient, + pkce_pair: tuple[str, str], + db: None, +) -> None: + # Given + _verifier, challenge = pkce_pair + url = reverse(AUTHORIZE_URL) + + # When + response = auth_client.get( + url, + { + "client_id": "nonexistent-client-id", + "response_type": "code", + "redirect_uri": "https://example.com/callback", + "scope": "mcp", + "code_challenge": challenge, + "code_challenge_method": "S256", + }, + ) + + # Then + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = response.json() + assert "error" in data + + +def test_get__unauthenticated__returns_401( + db: None, +) -> None: + # Given + client = APIClient() + url = reverse(AUTHORIZE_URL) + + # When + response = client.get( + url, + { + "client_id": "some-id", + "response_type": "code", + }, + ) + + # Then + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + +def test_post__allow_true__returns_redirect_with_code( + auth_client: APIClient, + oauth_application: Application, + pkce_pair: tuple[str, str], +) -> None: + # Given + _verifier, challenge = pkce_pair + url = reverse(AUTHORIZE_URL) + + # When + response = auth_client.post( + url, + { + "allow": True, + "client_id": oauth_application.client_id, + "response_type": "code", + "redirect_uri": "https://example.com/callback", + "scope": "mcp", + "code_challenge": challenge, + "code_challenge_method": "S256", + "state": "test-state", + }, + format="json", + ) + + # Then + assert response.status_code == status.HTTP_200_OK + data = response.json() + redirect_uri = data["redirect_uri"] + parsed = urlparse(redirect_uri) + query_params = parse_qs(parsed.query) + assert "code" in query_params + assert query_params["state"] == ["test-state"] + + +def test_post__allow_false__returns_redirect_with_error( + auth_client: APIClient, + oauth_application: Application, + pkce_pair: tuple[str, str], +) -> None: + # Given + _verifier, challenge = pkce_pair + url = reverse(AUTHORIZE_URL) + + # When + response = auth_client.post( + url, + { + "allow": False, + "client_id": oauth_application.client_id, + "response_type": "code", + "redirect_uri": "https://example.com/callback", + "scope": "mcp", + "code_challenge": challenge, + "code_challenge_method": "S256", + "state": "test-state", + }, + format="json", + ) + + # Then + assert response.status_code == status.HTTP_200_OK + data = response.json() + redirect_uri = data["redirect_uri"] + parsed = urlparse(redirect_uri) + query_params = parse_qs(parsed.query) + assert query_params["error"] == ["access_denied"] + assert query_params["state"] == ["test-state"] + + +def test_post__pkce_params_preserved__code_exchangeable( + auth_client: APIClient, + oauth_application: Application, +) -> None: + # Given + code_verifier, code_challenge = _pkce_pair() + authorize_url = reverse(AUTHORIZE_URL) + + # When -- obtain an authorisation code + response = auth_client.post( + authorize_url, + { + "allow": True, + "client_id": oauth_application.client_id, + "response_type": "code", + "redirect_uri": "https://example.com/callback", + "scope": "mcp", + "code_challenge": code_challenge, + "code_challenge_method": "S256", + }, + format="json", + ) + + assert response.status_code == status.HTTP_200_OK + redirect_uri = response.json()["redirect_uri"] + parsed = urlparse(redirect_uri) + query_params = parse_qs(parsed.query) + code = query_params["code"][0] + + # When -- exchange the code for a token + token_url = reverse("oauth2_provider:token") + token_client = APIClient() + token_response = token_client.post( + token_url, + { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": "https://example.com/callback", + "client_id": oauth_application.client_id, + "code_verifier": code_verifier, + }, + ) + + # Then + assert token_response.status_code == status.HTTP_200_OK + token_data = token_response.json() + assert "access_token" in token_data + assert "refresh_token" in token_data + assert token_data["token_type"] == "Bearer" From 74887310feb00ac8624ed072de93d5ec4cb601d7 Mon Sep 17 00:00:00 2001 From: wadii Date: Fri, 3 Apr 2026 17:16:55 +0200 Subject: [PATCH 2/2] feat: parametrize-tests --- .../oauth2_metadata/test_authorize_view.py | 97 +++++++++++-------- 1 file changed, 54 insertions(+), 43 deletions(-) diff --git a/api/tests/unit/oauth2_metadata/test_authorize_view.py b/api/tests/unit/oauth2_metadata/test_authorize_view.py index 7dbe353f2389..d10e55e3e6b9 100644 --- a/api/tests/unit/oauth2_metadata/test_authorize_view.py +++ b/api/tests/unit/oauth2_metadata/test_authorize_view.py @@ -34,6 +34,18 @@ def oauth_application(admin_user: User) -> Application: # type: ignore[valid-ty ) +@pytest.fixture() +def verified_oauth_application(admin_user: User) -> Application: # type: ignore[valid-type] + return Application.objects.create( # type: ignore[no-any-return] + name="Verified App", + user=admin_user, + client_type=Application.CLIENT_PUBLIC, + authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, + redirect_uris="https://example.com/callback", + skip_authorization=True, + ) + + @pytest.fixture() def auth_client(admin_user: User) -> APIClient: # type: ignore[valid-type] client = APIClient() @@ -78,10 +90,10 @@ def test_get__valid_params__returns_application_info( assert data["is_verified"] is False -def test_get__invalid_client_id__returns_400( +def test_get__verified_application__returns_is_verified_true( auth_client: APIClient, + verified_oauth_application: Application, pkce_pair: tuple[str, str], - db: None, ) -> None: # Given _verifier, challenge = pkce_pair @@ -91,7 +103,7 @@ def test_get__invalid_client_id__returns_400( response = auth_client.get( url, { - "client_id": "nonexistent-client-id", + "client_id": verified_oauth_application.client_id, "response_type": "code", "redirect_uri": "https://example.com/callback", "scope": "mcp", @@ -101,70 +113,71 @@ def test_get__invalid_client_id__returns_400( ) # Then - assert response.status_code == status.HTTP_400_BAD_REQUEST - data = response.json() - assert "error" in data + assert response.status_code == status.HTTP_200_OK + assert response.json()["is_verified"] is True -def test_get__unauthenticated__returns_401( +def test_get__invalid_client_id__returns_400( + auth_client: APIClient, + pkce_pair: tuple[str, str], db: None, ) -> None: # Given - client = APIClient() + _verifier, challenge = pkce_pair url = reverse(AUTHORIZE_URL) # When - response = client.get( + response = auth_client.get( url, { - "client_id": "some-id", + "client_id": "nonexistent-client-id", "response_type": "code", + "redirect_uri": "https://example.com/callback", + "scope": "mcp", + "code_challenge": challenge, + "code_challenge_method": "S256", }, ) # Then - assert response.status_code == status.HTTP_401_UNAUTHORIZED + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = response.json() + assert "error" in data -def test_post__allow_true__returns_redirect_with_code( - auth_client: APIClient, - oauth_application: Application, - pkce_pair: tuple[str, str], +@pytest.mark.parametrize("method", ["get", "post"]) +def test__unauthenticated__returns_401( + method: str, + db: None, ) -> None: # Given - _verifier, challenge = pkce_pair + client = APIClient() url = reverse(AUTHORIZE_URL) # When - response = auth_client.post( + response = getattr(client, method)( url, - { - "allow": True, - "client_id": oauth_application.client_id, - "response_type": "code", - "redirect_uri": "https://example.com/callback", - "scope": "mcp", - "code_challenge": challenge, - "code_challenge_method": "S256", - "state": "test-state", - }, - format="json", + {"client_id": "some-id", "response_type": "code"}, ) # Then - assert response.status_code == status.HTTP_200_OK - data = response.json() - redirect_uri = data["redirect_uri"] - parsed = urlparse(redirect_uri) - query_params = parse_qs(parsed.query) - assert "code" in query_params - assert query_params["state"] == ["test-state"] + assert response.status_code == status.HTTP_401_UNAUTHORIZED -def test_post__allow_false__returns_redirect_with_error( +@pytest.mark.parametrize( + "allow, expected_params", + [ + (True, {"state": ["test-state"]}), + (False, {"error": ["access_denied"], "state": ["test-state"]}), + ], + ids=["allow", "deny"], +) +def test_post__consent_decision__returns_redirect( auth_client: APIClient, oauth_application: Application, pkce_pair: tuple[str, str], + allow: bool, + expected_params: dict[str, list[str]], ) -> None: # Given _verifier, challenge = pkce_pair @@ -174,7 +187,7 @@ def test_post__allow_false__returns_redirect_with_error( response = auth_client.post( url, { - "allow": False, + "allow": allow, "client_id": oauth_application.client_id, "response_type": "code", "redirect_uri": "https://example.com/callback", @@ -188,12 +201,11 @@ def test_post__allow_false__returns_redirect_with_error( # Then assert response.status_code == status.HTTP_200_OK - data = response.json() - redirect_uri = data["redirect_uri"] + redirect_uri = response.json()["redirect_uri"] parsed = urlparse(redirect_uri) query_params = parse_qs(parsed.query) - assert query_params["error"] == ["access_denied"] - assert query_params["state"] == ["test-state"] + for key, value in expected_params.items(): + assert query_params[key] == value def test_post__pkce_params_preserved__code_exchangeable( @@ -204,7 +216,7 @@ def test_post__pkce_params_preserved__code_exchangeable( code_verifier, code_challenge = _pkce_pair() authorize_url = reverse(AUTHORIZE_URL) - # When -- obtain an authorisation code + # When response = auth_client.post( authorize_url, { @@ -225,7 +237,6 @@ def test_post__pkce_params_preserved__code_exchangeable( query_params = parse_qs(parsed.query) code = query_params["code"][0] - # When -- exchange the code for a token token_url = reverse("oauth2_provider:token") token_client = APIClient() token_response = token_client.post(