Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions api/app/settings/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@


LOGIN_THROTTLE_RATE = env("LOGIN_THROTTLE_RATE", "20/min")
DCR_THROTTLE_RATE = env("DCR_THROTTLE_RATE", "10/min")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the window of this shoulld be much bigger? something like 500/month?

SIGNUP_THROTTLE_RATE = env("SIGNUP_THROTTLE_RATE", "10000/min")
USER_THROTTLE_RATE = env("USER_THROTTLE_RATE", default=None)
MASTER_API_KEY_THROTTLE_RATE = env("MASTER_API_KEY_THROTTLE_RATE", default=None)
Expand All @@ -322,6 +323,7 @@
"DEFAULT_THROTTLE_CLASSES": DEFAULT_THROTTLE_CLASSES,
"DEFAULT_THROTTLE_RATES": {
"login": LOGIN_THROTTLE_RATE,
"dcr_register": DCR_THROTTLE_RATE,
"signup": SIGNUP_THROTTLE_RATE,
"master_api_key": MASTER_API_KEY_THROTTLE_RATE,
"mfa_code": "5/min",
Expand Down
1 change: 1 addition & 0 deletions api/app/settings/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
REST_FRAMEWORK["DEFAULT_THROTTLE_CLASSES"] = ["core.throttling.UserRateThrottle"]
REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"] = {
"login": "100/min",
"dcr_register": "100/min",
"mfa_code": "5/min",
"invite": "10/min",
"signup": "100/min",
Expand Down
10 changes: 9 additions & 1 deletion api/app/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from django.urls import include, path, re_path
from django.views.generic.base import TemplateView

from oauth2_metadata.views import authorization_server_metadata
from oauth2_metadata.views import (
DynamicClientRegistrationView,
authorization_server_metadata,
)
from users.views import password_reset_redirect

from . import views
Expand Down Expand Up @@ -53,6 +56,11 @@
"robots.txt",
TemplateView.as_view(template_name="robots.txt", content_type="text/plain"),
),
path(
"o/register/",
DynamicClientRegistrationView.as_view(),
name="oauth2-dcr-register",
),
# Authorize template view for testing: this will be moved to the frontend in following issues
path("o/", include("oauth2_provider.urls", namespace="oauth2_provider")),
]
Expand Down
79 changes: 79 additions & 0 deletions api/oauth2_metadata/serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import re

from django.core.exceptions import ValidationError as DjangoValidationError
from rest_framework import serializers

from oauth2_metadata.services import validate_redirect_uri

# Allow ASCII letters, digits, spaces, hyphens, underscores, dots, and parentheses.
# ASCII-only to prevent Unicode homoglyph spoofing on the consent screen.
_CLIENT_NAME_RE = re.compile(r"^[\w\s.\-()]+$", re.ASCII)


class DCRRequestSerializer(serializers.Serializer[None]):
client_name = serializers.CharField(max_length=255, required=True)
redirect_uris = serializers.ListField(
child=serializers.URLField(),
min_length=1,
max_length=5,
required=True,
)
grant_types = serializers.ListField(
child=serializers.CharField(),
required=False,
default=["authorization_code", "refresh_token"],
)
response_types = serializers.ListField(
child=serializers.CharField(),
required=False,
default=["code"],
)
token_endpoint_auth_method = serializers.CharField(
required=False,
default="none",
)

def validate_client_name(self, value: str) -> str:
if not _CLIENT_NAME_RE.match(value):
raise serializers.ValidationError(
"Client name may only contain letters, digits, spaces, "
"hyphens, underscores, dots, and parentheses."
)
return value

def validate_redirect_uris(self, value: list[str]) -> list[str]:
errors: list[str] = []
for uri in value:
try:
validate_redirect_uri(uri)
except DjangoValidationError as e:
errors.append(str(e.message))
if errors:
raise serializers.ValidationError(errors)
return value

def validate_token_endpoint_auth_method(self, value: str) -> str:
if value != "none":
raise serializers.ValidationError(
"Only public clients are supported; "
"token_endpoint_auth_method must be 'none'."
)
return value

def validate_grant_types(self, value: list[str]) -> list[str]:
allowed = {"authorization_code", "refresh_token"}
invalid = set(value) - allowed
if invalid:
raise serializers.ValidationError(
f"Unsupported grant types: {', '.join(sorted(invalid))}"
)
return value

def validate_response_types(self, value: list[str]) -> list[str]:
allowed = {"code"}
invalid = set(value) - allowed
if invalid:
raise serializers.ValidationError(
f"Unsupported response types: {', '.join(sorted(invalid))}"
)
return value
59 changes: 59 additions & 0 deletions api/oauth2_metadata/services.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import logging
from urllib.parse import urlparse

from django.core.exceptions import ValidationError
from oauth2_provider.models import Application

logger = logging.getLogger(__name__)


def validate_redirect_uri(uri: str) -> str:
"""Validate a single redirect URI per DCR policy.

Rules:
- HTTPS required for all redirect URIs
- No wildcards, exact match only
- No fragment components
- localhost exception: http://localhost:*, http://127.0.0.1:*, and http://[::1]:* permitted
"""
parsed = urlparse(uri)

if not parsed.scheme or not parsed.netloc:
raise ValidationError(f"Invalid URI: {uri}")

if "*" in uri:
raise ValidationError(f"Wildcards are not permitted in redirect URIs: {uri}")

if parsed.fragment:
raise ValidationError(f"Fragment components are not permitted: {uri}")

is_localhost = parsed.hostname in ("localhost", "127.0.0.1", "::1")

if parsed.scheme != "https" and not (parsed.scheme == "http" and is_localhost):
raise ValidationError(
f"HTTPS is required for redirect URIs (localhost excepted): {uri}"
)

return uri


def create_oauth2_application(
*,
client_name: str,
redirect_uris: list[str],
) -> Application:
"""Create a public OAuth2 application for dynamic client registration."""
application: Application = Application.objects.create(
name=client_name,
client_type=Application.CLIENT_PUBLIC,
authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE,
client_secret="",
redirect_uris=" ".join(redirect_uris),
skip_authorization=False,
)
logger.info(
"OAuth2 DCR: registered application %s (client_id=%s).",
client_name,
application.client_id,
)
return application
28 changes: 28 additions & 0 deletions api/oauth2_metadata/tasks.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,37 @@
import logging
from datetime import timedelta

from django.core.management import call_command
from django.utils import timezone
from task_processor.decorators import register_recurring_task

logger = logging.getLogger(__name__)


@register_recurring_task(run_every=timedelta(hours=24))
def clear_expired_oauth2_tokens() -> None:
call_command("cleartokens")


@register_recurring_task(run_every=timedelta(hours=24))
def cleanup_stale_oauth2_applications() -> None:
"""Remove DCR applications that were never used to obtain a token.

An application is considered stale if it was registered more than 14 days
ago and has no associated access tokens, refresh tokens, or grants.
"""
from django.db.models import Exists, OuterRef
from oauth2_provider.models import AccessToken, Application, Grant, RefreshToken

threshold = timezone.now() - timedelta(days=14)
stale = Application.objects.filter(
created__lt=threshold,
user__isnull=True, # Only DCR-created apps (no user)
).exclude(
Exists(AccessToken.objects.filter(application=OuterRef("pk")))
| Exists(RefreshToken.objects.filter(application=OuterRef("pk")))
| Exists(Grant.objects.filter(application=OuterRef("pk")))
)
count, _ = stale.delete()
if count:
logger.info("OAuth2 DCR cleanup: removed %d stale application(s).", count)
69 changes: 69 additions & 0 deletions api/oauth2_metadata/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@
from django.http import HttpRequest, JsonResponse
from django.views.decorators.csrf import csrf_exempt
from django.views.decorators.http import require_GET
from rest_framework import status as drf_status
from rest_framework.permissions import AllowAny
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.throttling import ScopedRateThrottle
from rest_framework.views import APIView

from oauth2_metadata.serializers import DCRRequestSerializer
from oauth2_metadata.services import create_oauth2_application


@csrf_exempt
Expand Down Expand Up @@ -35,3 +44,63 @@ def authorization_server_metadata(request: HttpRequest) -> JsonResponse:
}

return JsonResponse(metadata)


class DynamicClientRegistrationView(APIView):
"""RFC 7591 Dynamic Client Registration endpoint."""

authentication_classes: list[type] = []
permission_classes = [AllowAny]
throttle_classes = [ScopedRateThrottle]
throttle_scope = "dcr_register"

# Map DRF serializer field names to RFC 7591 error codes.
_rfc7591_error_codes: dict[str, str] = {
"redirect_uris": "invalid_redirect_uri",
"client_name": "invalid_client_metadata",
"grant_types": "invalid_client_metadata",
"response_types": "invalid_client_metadata",
"token_endpoint_auth_method": "invalid_client_metadata",
}

def post(self, request: Request) -> Response:
serializer = DCRRequestSerializer(data=request.data)
if not serializer.is_valid():
return self._rfc7591_error_response(serializer.errors)

data = serializer.validated_data

application = create_oauth2_application(
client_name=data["client_name"],
redirect_uris=data["redirect_uris"],
)

return Response(
{
"client_id": application.client_id,
"client_name": application.name,
"redirect_uris": data["redirect_uris"],
"grant_types": data["grant_types"],
"response_types": data["response_types"],
"token_endpoint_auth_method": data["token_endpoint_auth_method"],
"client_id_issued_at": int(application.created.timestamp()),
},
status=drf_status.HTTP_201_CREATED,
)

def _rfc7591_error_response(self, errors: dict[str, list[str]]) -> Response:
"""Format validation errors per RFC 7591 section 3.2.2."""
first_field = next(iter(errors))
error_code = self._rfc7591_error_codes.get(
first_field, "invalid_client_metadata"
)
messages = errors[first_field]
description = messages[0] if isinstance(messages[0], str) else str(messages[0])

return Response(
{
"error": error_code,
"error_description": description,
},
status=drf_status.HTTP_400_BAD_REQUEST,
)
2 changes: 1 addition & 1 deletion api/oauth2_test_server.mjs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { createServer } from "node:http";
import { randomBytes, createHash } from "node:crypto";

const CLIENT_ID = "ZLsLu3hhJI4GlhNsGeFVC3K2U3QBGfXtmc0EcyiG";
const CLIENT_ID = "B4wAl37pg9y1PRsIvAXZ14cTp0FpqpNCtMSI7ETC";
const REDIRECT_URI = "http://localhost:3000/oauth/callback";
const API_URL = "http://localhost:8000";
const PORT = 3000;
Expand Down
Loading
Loading