diff --git a/api/app/settings/common.py b/api/app/settings/common.py index e82fd1b8241a..0b667b0f5467 100644 --- a/api/app/settings/common.py +++ b/api/app/settings/common.py @@ -304,6 +304,7 @@ LOGIN_THROTTLE_RATE = env("LOGIN_THROTTLE_RATE", "20/min") +DCR_THROTTLE_RATE = env("DCR_THROTTLE_RATE", "10/min") SIGNUP_THROTTLE_RATE = env("SIGNUP_THROTTLE_RATE", "10000/min") USER_THROTTLE_RATE = env("USER_THROTTLE_RATE", default=None) MASTER_API_KEY_THROTTLE_RATE = env("MASTER_API_KEY_THROTTLE_RATE", default=None) @@ -322,6 +323,7 @@ "DEFAULT_THROTTLE_CLASSES": DEFAULT_THROTTLE_CLASSES, "DEFAULT_THROTTLE_RATES": { "login": LOGIN_THROTTLE_RATE, + "dcr_register": DCR_THROTTLE_RATE, "signup": SIGNUP_THROTTLE_RATE, "master_api_key": MASTER_API_KEY_THROTTLE_RATE, "mfa_code": "5/min", diff --git a/api/app/settings/test.py b/api/app/settings/test.py index 1f0ab33f395c..edcead60e710 100644 --- a/api/app/settings/test.py +++ b/api/app/settings/test.py @@ -7,6 +7,7 @@ REST_FRAMEWORK["DEFAULT_THROTTLE_CLASSES"] = ["core.throttling.UserRateThrottle"] REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"] = { "login": "100/min", + "dcr_register": "100/min", "mfa_code": "5/min", "invite": "10/min", "signup": "100/min", diff --git a/api/app/urls.py b/api/app/urls.py index b9a7e1181b32..5df2790853fc 100644 --- a/api/app/urls.py +++ b/api/app/urls.py @@ -6,7 +6,10 @@ from django.urls import include, path, re_path from django.views.generic.base import TemplateView -from oauth2_metadata.views import authorization_server_metadata +from oauth2_metadata.views import ( + DynamicClientRegistrationView, + authorization_server_metadata, +) from users.views import password_reset_redirect from . import views @@ -53,6 +56,11 @@ "robots.txt", TemplateView.as_view(template_name="robots.txt", content_type="text/plain"), ), + path( + "o/register/", + DynamicClientRegistrationView.as_view(), + name="oauth2-dcr-register", + ), # Authorize template view for testing: this will be moved to the frontend in following issues path("o/", include("oauth2_provider.urls", namespace="oauth2_provider")), ] diff --git a/api/oauth2_metadata/serializers.py b/api/oauth2_metadata/serializers.py new file mode 100644 index 000000000000..da0648743bb1 --- /dev/null +++ b/api/oauth2_metadata/serializers.py @@ -0,0 +1,79 @@ +import re + +from django.core.exceptions import ValidationError as DjangoValidationError +from rest_framework import serializers + +from oauth2_metadata.services import validate_redirect_uri + +# Allow ASCII letters, digits, spaces, hyphens, underscores, dots, and parentheses. +# ASCII-only to prevent Unicode homoglyph spoofing on the consent screen. +_CLIENT_NAME_RE = re.compile(r"^[\w\s.\-()]+$", re.ASCII) + + +class DCRRequestSerializer(serializers.Serializer[None]): + client_name = serializers.CharField(max_length=255, required=True) + redirect_uris = serializers.ListField( + child=serializers.URLField(), + min_length=1, + max_length=5, + required=True, + ) + grant_types = serializers.ListField( + child=serializers.CharField(), + required=False, + default=["authorization_code", "refresh_token"], + ) + response_types = serializers.ListField( + child=serializers.CharField(), + required=False, + default=["code"], + ) + token_endpoint_auth_method = serializers.CharField( + required=False, + default="none", + ) + + def validate_client_name(self, value: str) -> str: + if not _CLIENT_NAME_RE.match(value): + raise serializers.ValidationError( + "Client name may only contain letters, digits, spaces, " + "hyphens, underscores, dots, and parentheses." + ) + return value + + def validate_redirect_uris(self, value: list[str]) -> list[str]: + errors: list[str] = [] + for uri in value: + try: + validate_redirect_uri(uri) + except DjangoValidationError as e: + errors.append(str(e.message)) + if errors: + raise serializers.ValidationError(errors) + return value + + def validate_token_endpoint_auth_method(self, value: str) -> str: + if value != "none": + raise serializers.ValidationError( + "Only public clients are supported; " + "token_endpoint_auth_method must be 'none'." + ) + return value + + def validate_grant_types(self, value: list[str]) -> list[str]: + allowed = {"authorization_code", "refresh_token"} + invalid = set(value) - allowed + if invalid: + raise serializers.ValidationError( + f"Unsupported grant types: {', '.join(sorted(invalid))}" + ) + return value + + def validate_response_types(self, value: list[str]) -> list[str]: + allowed = {"code"} + invalid = set(value) - allowed + if invalid: + raise serializers.ValidationError( + f"Unsupported response types: {', '.join(sorted(invalid))}" + ) + return value diff --git a/api/oauth2_metadata/services.py b/api/oauth2_metadata/services.py new file mode 100644 index 000000000000..837c3a87fd71 --- /dev/null +++ b/api/oauth2_metadata/services.py @@ -0,0 +1,59 @@ +import logging +from urllib.parse import urlparse + +from django.core.exceptions import ValidationError +from oauth2_provider.models import Application + +logger = logging.getLogger(__name__) + + +def validate_redirect_uri(uri: str) -> str: + """Validate a single redirect URI per DCR policy. + + Rules: + - HTTPS required for all redirect URIs + - No wildcards, exact match only + - No fragment components + - localhost exception: http://localhost:*, http://127.0.0.1:*, and http://[::1]:* permitted + """ + parsed = urlparse(uri) + + if not parsed.scheme or not parsed.netloc: + raise ValidationError(f"Invalid URI: {uri}") + + if "*" in uri: + raise ValidationError(f"Wildcards are not permitted in redirect URIs: {uri}") + + if parsed.fragment: + raise ValidationError(f"Fragment components are not permitted: {uri}") + + is_localhost = parsed.hostname in ("localhost", "127.0.0.1", "::1") + + if parsed.scheme != "https" and not (parsed.scheme == "http" and is_localhost): + raise ValidationError( + f"HTTPS is required for redirect URIs (localhost excepted): {uri}" + ) + + return uri + + +def create_oauth2_application( + *, + client_name: str, + redirect_uris: list[str], +) -> Application: + """Create a public OAuth2 application for dynamic client registration.""" + application: Application = Application.objects.create( + name=client_name, + client_type=Application.CLIENT_PUBLIC, + authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, + client_secret="", + redirect_uris=" ".join(redirect_uris), + skip_authorization=False, + ) + logger.info( + "OAuth2 DCR: registered application %s (client_id=%s).", + client_name, + application.client_id, + ) + return application diff --git a/api/oauth2_metadata/tasks.py b/api/oauth2_metadata/tasks.py index 372078267f9e..b8620b95b9b7 100644 --- a/api/oauth2_metadata/tasks.py +++ b/api/oauth2_metadata/tasks.py @@ -1,9 +1,37 @@ +import logging from datetime import timedelta from django.core.management import call_command +from django.utils import timezone from task_processor.decorators import register_recurring_task +logger = logging.getLogger(__name__) + @register_recurring_task(run_every=timedelta(hours=24)) def clear_expired_oauth2_tokens() -> None: call_command("cleartokens") + + +@register_recurring_task(run_every=timedelta(hours=24)) +def cleanup_stale_oauth2_applications() -> None: + """Remove DCR applications that were never used to obtain a token. + + An application is considered stale if it was registered more than 14 days + ago and has no associated access tokens, refresh tokens, or grants. + """ + from django.db.models import Exists, OuterRef + from oauth2_provider.models import AccessToken, Application, Grant, RefreshToken + + threshold = timezone.now() - timedelta(days=14) + stale = Application.objects.filter( + created__lt=threshold, + user__isnull=True, # Only DCR-created apps (no user) + ).exclude( + Exists(AccessToken.objects.filter(application=OuterRef("pk"))) + | Exists(RefreshToken.objects.filter(application=OuterRef("pk"))) + | Exists(Grant.objects.filter(application=OuterRef("pk"))) + ) + count, _ = stale.delete() + if count: + logger.info("OAuth2 DCR cleanup: removed %d stale application(s).", count) diff --git a/api/oauth2_metadata/views.py b/api/oauth2_metadata/views.py index 25cbc77071d5..7a592016f66e 100644 --- a/api/oauth2_metadata/views.py +++ b/api/oauth2_metadata/views.py @@ -4,6 +4,15 @@ from django.http import HttpRequest, JsonResponse from django.views.decorators.csrf import csrf_exempt from django.views.decorators.http import require_GET +from rest_framework import status as drf_status +from rest_framework.permissions import AllowAny +from rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.throttling import ScopedRateThrottle +from rest_framework.views import APIView + +from oauth2_metadata.serializers import DCRRequestSerializer +from oauth2_metadata.services import create_oauth2_application @csrf_exempt @@ -35,3 +44,63 @@ def authorization_server_metadata(request: HttpRequest) -> JsonResponse: } return JsonResponse(metadata) + + +class DynamicClientRegistrationView(APIView): + """RFC 7591 Dynamic Client Registration endpoint.""" + + authentication_classes: list[type] = [] + permission_classes = [AllowAny] + throttle_classes = [ScopedRateThrottle] + throttle_scope = "dcr_register" + + # Map DRF serializer field names to RFC 7591 error codes. + _rfc7591_error_codes: dict[str, str] = { + "redirect_uris": "invalid_redirect_uri", + "client_name": "invalid_client_metadata", + "grant_types": "invalid_client_metadata", + "response_types": "invalid_client_metadata", + "token_endpoint_auth_method": "invalid_client_metadata", + } + + def post(self, request: Request) -> Response: + serializer = DCRRequestSerializer(data=request.data) + if not serializer.is_valid(): + return self._rfc7591_error_response(serializer.errors) + + data = serializer.validated_data + + application = create_oauth2_application( + client_name=data["client_name"], + redirect_uris=data["redirect_uris"], + ) + + return Response( + { + "client_id": application.client_id, + "client_name": application.name, + "redirect_uris": data["redirect_uris"], + "grant_types": data["grant_types"], + "response_types": data["response_types"], + "token_endpoint_auth_method": data["token_endpoint_auth_method"], + "client_id_issued_at": int(application.created.timestamp()), + }, + status=drf_status.HTTP_201_CREATED, + ) + + def _rfc7591_error_response(self, errors: dict[str, list[str]]) -> Response: + """Format validation errors per RFC 7591 section 3.2.2.""" + first_field = next(iter(errors)) + error_code = self._rfc7591_error_codes.get( + first_field, "invalid_client_metadata" + ) + messages = errors[first_field] + description = messages[0] if isinstance(messages[0], str) else str(messages[0]) + + return Response( + { + "error": error_code, + "error_description": description, + }, + status=drf_status.HTTP_400_BAD_REQUEST, + ) diff --git a/api/oauth2_test_server.mjs b/api/oauth2_test_server.mjs index 1d0029a59faf..d93af0f29bee 100644 --- a/api/oauth2_test_server.mjs +++ b/api/oauth2_test_server.mjs @@ -1,7 +1,7 @@ import { createServer } from "node:http"; import { randomBytes, createHash } from "node:crypto"; -const CLIENT_ID = "ZLsLu3hhJI4GlhNsGeFVC3K2U3QBGfXtmc0EcyiG"; +const CLIENT_ID = "B4wAl37pg9y1PRsIvAXZ14cTp0FpqpNCtMSI7ETC"; const REDIRECT_URI = "http://localhost:3000/oauth/callback"; const API_URL = "http://localhost:8000"; const PORT = 3000; diff --git a/api/tests/unit/oauth2_metadata/test_dcr.py b/api/tests/unit/oauth2_metadata/test_dcr.py new file mode 100644 index 000000000000..cb7fbc949a00 --- /dev/null +++ b/api/tests/unit/oauth2_metadata/test_dcr.py @@ -0,0 +1,281 @@ +from unittest.mock import patch + +import pytest +from django.core.exceptions import ValidationError +from django.urls import reverse +from oauth2_provider.models import Application +from rest_framework import status +from rest_framework.test import APIClient + +from oauth2_metadata.services import validate_redirect_uri + +DCR_URL = reverse("oauth2-dcr-register") + + +@pytest.fixture() +def api_client() -> APIClient: + return APIClient() + + +def _valid_payload(**overrides: object) -> dict[str, object]: + payload: dict[str, object] = { + "client_name": "Test MCP Client", + "redirect_uris": ["https://example.com/callback"], + } + payload.update(overrides) + return payload + + +@pytest.mark.django_db() +def test_dcr_register__valid_request__returns_201_with_client_id( + api_client: APIClient, +) -> None: + # Given + payload = _valid_payload() + + # When + response = api_client.post(DCR_URL, data=payload, format="json") + + # Then + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["client_id"] + assert data["client_name"] == "Test MCP Client" + assert data["redirect_uris"] == ["https://example.com/callback"] + assert data["grant_types"] == ["authorization_code", "refresh_token"] + assert data["response_types"] == ["code"] + assert data["token_endpoint_auth_method"] == "none" + assert isinstance(data["client_id_issued_at"], int) + + +@pytest.mark.django_db() +@pytest.mark.parametrize( + "redirect_uri", + [ + "http://localhost:8080/callback", + "http://127.0.0.1:3000/callback", + "http://[::1]:3000/callback", + "https://example.com/callback", + ], + ids=["localhost", "127.0.0.1", "::1", "https"], +) +def test_dcr_register__valid_redirect_uri__returns_201( + api_client: APIClient, + redirect_uri: str, +) -> None: + # Given + payload = _valid_payload(redirect_uris=[redirect_uri]) + + # When + response = api_client.post(DCR_URL, data=payload, format="json") + + # Then + assert response.status_code == status.HTTP_201_CREATED + + +@pytest.mark.django_db() +@pytest.mark.parametrize( + "client_name", + [ + "Claude Desktop (v2.1-beta)", + "My_App.test", + "Simple", + ], + ids=["special-chars", "underscores-dots", "simple"], +) +def test_dcr_register__valid_client_name__returns_201( + api_client: APIClient, + client_name: str, +) -> None: + # Given + payload = _valid_payload(client_name=client_name) + + # When + response = api_client.post(DCR_URL, data=payload, format="json") + + # Then + assert response.status_code == status.HTTP_201_CREATED + assert response.json()["client_name"] == client_name + + +@pytest.mark.django_db() +def test_dcr_register__defaults_applied__returns_expected_defaults( + api_client: APIClient, +) -> None: + # Given - only required fields + payload = _valid_payload() + + # When + response = api_client.post(DCR_URL, data=payload, format="json") + + # Then + data = response.json() + assert data["grant_types"] == ["authorization_code", "refresh_token"] + assert data["response_types"] == ["code"] + assert data["token_endpoint_auth_method"] == "none" + + +@pytest.mark.django_db() +def test_dcr_register__valid_request__creates_public_application_in_database( + api_client: APIClient, +) -> None: + # Given + payload = _valid_payload() + + # When + response = api_client.post(DCR_URL, data=payload, format="json") + + # Then + client_id = response.json()["client_id"] + application = Application.objects.get(client_id=client_id) + assert application.client_type == Application.CLIENT_PUBLIC + assert application.authorization_grant_type == Application.GRANT_AUTHORIZATION_CODE + assert application.name == "Test MCP Client" + assert "https://example.com/callback" in application.redirect_uris + assert application.user is None + assert application.skip_authorization is False + + +@pytest.mark.parametrize( + ("redirect_uris", "expected_fragment"), + [ + (["http://example.com/callback"], "HTTPS"), + (["https://example.com/callback#frag"], "Fragment"), + (["https://*.example.com/callback"], "valid URL"), + ([], "at least 1"), + ([f"https://example.com/cb{i}" for i in range(6)], "no more than 5"), + ], + ids=["http-non-localhost", "fragment", "wildcard", "empty-list", "too-many"], +) +def test_dcr_register__invalid_redirect_uris__returns_rfc7591_error( + api_client: APIClient, + redirect_uris: list[str], + expected_fragment: str, +) -> None: + # Given + payload = _valid_payload(redirect_uris=redirect_uris) + + # When + response = api_client.post(DCR_URL, data=payload, format="json") + + # Then + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = response.json() + assert data["error"] == "invalid_redirect_uri" + assert expected_fragment.lower() in data["error_description"].lower() + + +@pytest.mark.parametrize( + ("overrides", "expected_fragment"), + [ + ({"client_name": ""}, "letters"), + ({"client_name": " "}, "blank"), + ({"grant_types": ["implicit"]}, "grant type"), + ({"response_types": ["token"]}, "response type"), + ({"token_endpoint_auth_method": "client_secret_basic"}, "public clients"), + ], + ids=[ + "xss-client-name", + "blank-client-name", + "bad-grant-type", + "bad-response-type", + "bad-auth-method", + ], +) +def test_dcr_register__invalid_client_metadata__returns_rfc7591_error( + api_client: APIClient, + overrides: dict[str, object], + expected_fragment: str, +) -> None: + # Given + payload = _valid_payload(**overrides) + + # When + response = api_client.post(DCR_URL, data=payload, format="json") + + # Then + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = response.json() + assert data["error"] == "invalid_client_metadata" + assert expected_fragment.lower() in data["error_description"].lower() + + +@pytest.mark.parametrize( + ("payload", "expected_error"), + [ + ( + {"redirect_uris": ["https://example.com/callback"]}, + "invalid_client_metadata", + ), + ( + {"client_name": "Test"}, + "invalid_redirect_uri", + ), + ], + ids=["missing-client-name", "missing-redirect-uris"], +) +def test_dcr_register__missing_required_field__returns_rfc7591_error( + api_client: APIClient, + payload: dict[str, object], + expected_error: str, +) -> None: + # Given / When + response = api_client.post(DCR_URL, data=payload, format="json") + + # Then + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = response.json() + assert data["error"] == expected_error + assert "error_description" in data + + +def test_dcr_register__get_request__returns_405( + api_client: APIClient, +) -> None: + # Given / When + response = api_client.get(DCR_URL) + + # Then + assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED + + +@pytest.mark.django_db() +def test_dcr_register__rate_limited__returns_429( + api_client: APIClient, +) -> None: + # Given + payload = _valid_payload() + + with ( + patch( + "rest_framework.throttling.ScopedRateThrottle.allow_request", + return_value=False, + ), + patch( + "rest_framework.throttling.ScopedRateThrottle.wait", + return_value=60.0, + ), + ): + # When + response = api_client.post(DCR_URL, data=payload, format="json") + + # Then + assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS + + +@pytest.mark.parametrize( + ("uri", "expected_message"), + [ + ("not-a-uri", "Invalid URI"), + ("https://*.example.com/callback", "Wildcards"), + ], + ids=["invalid-uri", "wildcard"], +) +def test_validate_redirect_uri__invalid_input__raises_validation_error( + uri: str, + expected_message: str, +) -> None: + # Given / When + # Then + with pytest.raises(ValidationError, match=expected_message): + validate_redirect_uri(uri) diff --git a/api/tests/unit/oauth2_metadata/test_tasks.py b/api/tests/unit/oauth2_metadata/test_tasks.py index d5ea32e9bc0c..2ad72b75ea09 100644 --- a/api/tests/unit/oauth2_metadata/test_tasks.py +++ b/api/tests/unit/oauth2_metadata/test_tasks.py @@ -1,6 +1,15 @@ +from datetime import timedelta from unittest.mock import MagicMock -from oauth2_metadata.tasks import clear_expired_oauth2_tokens +import pytest +from django.contrib.auth.models import AbstractUser +from django.utils import timezone +from oauth2_provider.models import AccessToken, Application + +from oauth2_metadata.tasks import ( + cleanup_stale_oauth2_applications, + clear_expired_oauth2_tokens, +) def test_clear_expired_oauth2_tokens__called__invokes_cleartokens_command( @@ -14,3 +23,71 @@ def test_clear_expired_oauth2_tokens__called__invokes_cleartokens_command( # Then mock_call_command.assert_called_once_with("cleartokens") + + +@pytest.mark.django_db() +def test_cleanup_stale_oauth2_applications__old_app_with_no_token__deletes_it() -> None: + # Given + app = Application.objects.create( + name="Stale App", + client_type=Application.CLIENT_PUBLIC, + authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, + redirect_uris="https://example.com/callback", + ) + Application.objects.filter(pk=app.pk).update( + created=timezone.now() - timedelta(days=15), + ) + + # When + cleanup_stale_oauth2_applications() + + # Then + assert not Application.objects.filter(pk=app.pk).exists() + + +@pytest.mark.django_db() +def test_cleanup_stale_oauth2_applications__old_app_with_token__keeps_it( + admin_user: AbstractUser, +) -> None: + # Given + app = Application.objects.create( + name="Active App", + client_type=Application.CLIENT_PUBLIC, + authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, + redirect_uris="https://example.com/callback", + ) + Application.objects.filter(pk=app.pk).update( + created=timezone.now() - timedelta(days=15), + ) + AccessToken.objects.create( + user=admin_user, + application=app, + token="test-token", + expires=timezone.now() + timedelta(hours=1), + ) + + # When + cleanup_stale_oauth2_applications() + + # Then + assert Application.objects.filter(pk=app.pk).exists() + + +@pytest.mark.django_db() +def test_cleanup_stale_oauth2_applications__recent_app__keeps_it() -> None: + # Given - an app created 5 days ago with no tokens + app = Application.objects.create( + name="Recent App", + client_type=Application.CLIENT_PUBLIC, + authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, + redirect_uris="https://example.com/callback", + ) + Application.objects.filter(pk=app.pk).update( + created=timezone.now() - timedelta(days=5), + ) + + # When + cleanup_stale_oauth2_applications() + + # Then + assert Application.objects.filter(pk=app.pk).exists()