diff --git a/api/subscriptions/views.py b/api/subscriptions/views.py index 9e9ba419e10..52be8669231 100644 --- a/api/subscriptions/views.py +++ b/api/subscriptions/views.py @@ -9,6 +9,7 @@ from rest_framework.response import Response from framework.auth.oauth_scopes import CoreScopes +from framework import sentry from api.base.views import JSONAPIBaseView from api.base.filters import ListFilterMixin @@ -31,7 +32,7 @@ Guid, OSFUser, ) -from osf.models.notification_type import NotificationTypeEnum +from osf.models.notification_type import NotificationTypeEnum, NotificationType from osf.models.notification_subscription import NotificationSubscription @@ -174,12 +175,47 @@ def get_queryset(self): class AbstractProviderSubscriptionList(SubscriptionList): def get_queryset(self): - provider = AbstractProvider.objects.get(_id=self.kwargs['provider_id']) + # Get specific provider by _id and type to avoid potential conflicts with multiple provider with the same _id. + try: + provider = AbstractProvider.objects.get( + _id=self.kwargs['provider_id'], + type=self.provider_class._typedmodels_type, + ) + except AbstractProvider.MultipleObjectsReturned: + provider = AbstractProvider.objects.filter( + _id=self.kwargs['provider_id'], + type=self.provider_class._typedmodels_type, + ).first() + sentry.log_message(f''' + Multiple providers found with the same _id: [_id={self.kwargs["provider_id"]}, type={self.provider_class._typedmodels_type}]. + Using the first one with id={provider.id}.''') + + # Create missing default subscriptions for the provider if they don't exist. + notification_names = [ + NotificationTypeEnum.PROVIDER_NEW_PENDING_SUBMISSIONS.value, + NotificationTypeEnum.PROVIDER_NEW_PENDING_WITHDRAW_REQUESTS.value, + ] + content_type = ContentType.objects.get_for_model(provider.__class__) + + notification_types = NotificationType.objects.filter(name__in=notification_names) + for nt in notification_types: + NotificationSubscription.objects.get_or_create( + object_id=provider.id, + content_type=content_type, + user=self.request.user, + notification_type=nt, + defaults={ + '_is_digest': nt.is_digest_type, + 'message_frequency': 'instantly', + }, + ) + return NotificationSubscription.objects.filter( - object_id=provider, - content_type=ContentType.objects.get_for_model(provider.__class__), + object_id=provider.id, + content_type=content_type, user=self.request.user, - ) + notification_type__name__in=notification_names, + ).annotate(legacy_id=F('notification_type__name'), event_name=F('notification_type__name')) class SubscriptionDetail(JSONAPIBaseView, generics.RetrieveUpdateAPIView): view_name = 'notification-subscription-detail' @@ -222,7 +258,7 @@ def get_object(self): content_type=user_ct, then=Value(f'{user_guid}_global_reviews'), ), - default=Value(f'{user_guid}_global'), + default=F('notification_type__name'), output_field=CharField(), ), ) @@ -248,7 +284,7 @@ def update(self, request, *args, **kwargs): """ Update a notification subscription """ - self.get_object() + instance = self.get_object() if '_global_file_updated' in self.kwargs['subscription_id']: # Copy _global_file_updated subscription changes to all file subscriptions @@ -324,7 +360,14 @@ def update(self, request, *args, **kwargs): return Response(serializer.data) else: - return super().update(request, *args, **kwargs) + instance.event_name = instance.notification_type.name # Set event_name for serializer to use + + partial = kwargs.pop('partial', False) + serializer = self.get_serializer(instance, data=request.data, partial=partial) + serializer.is_valid(raise_exception=True) + self.perform_update(serializer) + + return Response(serializer.data) class AbstractProviderSubscriptionDetail(SubscriptionDetail):