Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 51 additions & 8 deletions api/subscriptions/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from rest_framework.response import Response

from framework.auth.oauth_scopes import CoreScopes
from framework import sentry

from api.base.views import JSONAPIBaseView
from api.base.filters import ListFilterMixin
Expand All @@ -31,7 +32,7 @@
Guid,
OSFUser,
)
from osf.models.notification_type import NotificationTypeEnum
from osf.models.notification_type import NotificationTypeEnum, NotificationType
from osf.models.notification_subscription import NotificationSubscription


Expand Down Expand Up @@ -174,12 +175,47 @@ def get_queryset(self):

class AbstractProviderSubscriptionList(SubscriptionList):
def get_queryset(self):
provider = AbstractProvider.objects.get(_id=self.kwargs['provider_id'])
# Get specific provider by _id and type to avoid potential conflicts with multiple provider with the same _id.
try:
provider = AbstractProvider.objects.get(
_id=self.kwargs['provider_id'],
type=self.provider_class._typedmodels_type,
)
except AbstractProvider.MultipleObjectsReturned:
provider = AbstractProvider.objects.filter(
_id=self.kwargs['provider_id'],
type=self.provider_class._typedmodels_type,
).first()
sentry.log_message(f'''
Multiple providers found with the same _id: [_id={self.kwargs["provider_id"]}, type={self.provider_class._typedmodels_type}].
Using the first one with id={provider.id}.''')

# Create missing default subscriptions for the provider if they don't exist.
notification_names = [
NotificationTypeEnum.PROVIDER_NEW_PENDING_SUBMISSIONS.value,
NotificationTypeEnum.PROVIDER_NEW_PENDING_WITHDRAW_REQUESTS.value,
]
content_type = ContentType.objects.get_for_model(provider.__class__)

notification_types = NotificationType.objects.filter(name__in=notification_names)
for nt in notification_types:
NotificationSubscription.objects.get_or_create(
object_id=provider.id,
content_type=content_type,
user=self.request.user,
notification_type=nt,
defaults={
'_is_digest': nt.is_digest_type,
'message_frequency': 'instantly',
},
)

return NotificationSubscription.objects.filter(
object_id=provider,
content_type=ContentType.objects.get_for_model(provider.__class__),
object_id=provider.id,
content_type=content_type,
user=self.request.user,
)
notification_type__name__in=notification_names,
).annotate(legacy_id=F('notification_type__name'), event_name=F('notification_type__name'))

class SubscriptionDetail(JSONAPIBaseView, generics.RetrieveUpdateAPIView):
view_name = 'notification-subscription-detail'
Expand Down Expand Up @@ -222,7 +258,7 @@ def get_object(self):
content_type=user_ct,
then=Value(f'{user_guid}_global_reviews'),
),
default=Value(f'{user_guid}_global'),
default=F('notification_type__name'),
output_field=CharField(),
),
)
Expand All @@ -248,7 +284,7 @@ def update(self, request, *args, **kwargs):
"""
Update a notification subscription
"""
self.get_object()
instance = self.get_object()

if '_global_file_updated' in self.kwargs['subscription_id']:
# Copy _global_file_updated subscription changes to all file subscriptions
Expand Down Expand Up @@ -324,7 +360,14 @@ def update(self, request, *args, **kwargs):
return Response(serializer.data)

else:
return super().update(request, *args, **kwargs)
instance.event_name = instance.notification_type.name # Set event_name for serializer to use

partial = kwargs.pop('partial', False)
serializer = self.get_serializer(instance, data=request.data, partial=partial)
serializer.is_valid(raise_exception=True)
self.perform_update(serializer)

return Response(serializer.data)


class AbstractProviderSubscriptionDetail(SubscriptionDetail):
Expand Down
Loading