diff --git a/rating_api/models/db.py b/rating_api/models/db.py index 6929947..b0b7641 100644 --- a/rating_api/models/db.py +++ b/rating_api/models/db.py @@ -248,6 +248,40 @@ def order_by_like_diff(cls, asc_order: bool = False): else: return cls.like_dislike_diff.desc() + @hybrid_method + def has_reaction(self, user_id: int, react: Reaction) -> bool: + return any(reaction.user_id == user_id and reaction.reaction == react for reaction in self.reactions) + + @has_reaction.expression + def has_reaction(cls, user_id: int, react: Reaction): + return ( + select([true()]) + .where( + and_( + CommentReaction.comment_uuid == cls.uuid, + CommentReaction.user_id == user_id, + CommentReaction.reaction == react, + ) + ) + .exists() + ) + + @classmethod + def reactions_for_comments(cls, user_id: int, session, comments): + if not user_id or not comments: + return {} + comments_uuid = [c.uuid for c in comments] + result = ( + session.query(Comment.uuid, CommentReaction.reaction) + .join( + CommentReaction, and_(Comment.uuid == CommentReaction.comment_uuid, CommentReaction.user_id == user_id) + ) + .filter(Comment.uuid.in_(comments_uuid)) + .group_by(Comment.uuid, CommentReaction.reaction) + .all() + ) + return dict(result) + class LecturerUserComment(BaseDbModel): id: Mapped[int] = mapped_column(Integer, primary_key=True) diff --git a/rating_api/routes/comment.py b/rating_api/routes/comment.py index 2d5579e..cfac5a5 100644 --- a/rating_api/routes/comment.py +++ b/rating_api/routes/comment.py @@ -163,14 +163,18 @@ async def import_comments( @comment.get("/{uuid}", response_model=CommentGet) -async def get_comment(uuid: UUID) -> CommentGet: +async def get_comment(uuid: UUID, user=Depends(UnionAuth(auto_error=False, allow_none=False))) -> CommentGet: """ Возвращает комментарий по его UUID в базе данных RatingAPI """ comment: Comment = Comment.query(session=db.session).filter(Comment.uuid == uuid).one_or_none() if comment is None: raise ObjectNotFound(Comment, uuid) - return CommentGet.model_validate(comment) + base_data = CommentGet.model_validate(comment) + if user: + base_data.is_liked=comment.has_reaction(user.get("id"), Reaction.LIKE) + base_data.is_disliked=comment.has_reaction(user.get("id"), Reaction.DISLIKE) + return base_data @comment.get("", response_model=Union[CommentGetAll, CommentGetAllWithAllInfo, CommentGetAllWithStatus]) @@ -187,7 +191,7 @@ async def get_comments( unreviewed: bool = False, asc_order: bool = False, user=Depends(UnionAuth(scopes=["rating.comment.review"], auto_error=False, allow_none=False)), -) -> CommentGetAll: +) -> Union[CommentGetAll, CommentGetAllWithAllInfo, CommentGetAllWithStatus]: """ Scopes: `["rating.comment.review"]` @@ -251,8 +255,23 @@ async def get_comments( result.comments = [comment for comment in result.comments if comment.review_status is ReviewStatus.APPROVED] result.total = len(result.comments) - result.comments = [comment_validator.model_validate(comment) for comment in result.comments] + comments_with_like = [] + current_user_id = user.get("id") if user else None + + if current_user_id and result.comments: + user_reactions = Comment.reactions_for_comments(current_user_id, db.session, result.comments) + else: + user_reactions = {} + + for comment in result.comments: + base_data = comment_validator.model_validate(comment) + if current_user_id: + reaction = user_reactions.get(comment.uuid) + base_data.is_liked = (reaction == Reaction.LIKE) + base_data.is_disliked = (reaction == Reaction.DISLIKE) + comments_with_like.append(base_data) + result.comments = comments_with_like return result diff --git a/rating_api/schemas/models.py b/rating_api/schemas/models.py index b99fede..b6e1e91 100644 --- a/rating_api/schemas/models.py +++ b/rating_api/schemas/models.py @@ -26,6 +26,8 @@ class CommentGet(Base): like_count: int dislike_count: int user_fullname: str | None = None + is_liked: bool = False + is_disliked: bool = False class CommentGetWithStatus(CommentGet): diff --git a/tests/conftest.py b/tests/conftest.py index 3742f8f..afb64a8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -139,6 +139,23 @@ def comment(dbsession, lecturer): dbsession.commit() +@pytest.fixture +def comment_reaction(dbsession, comment): + created_reactions = [] + + def _create_reaction(user_id: int, react: Reaction): + reaction = CommentReaction(user_id=user_id, comment_uuid=comment.uuid, reaction=react) + dbsession.add(reaction) + dbsession.commit() + created_reactions.append(reaction) + + yield _create_reaction + + for reaction in created_reactions: + dbsession.delete(reaction) + dbsession.commit() + + @pytest.fixture def unreviewed_comment(dbsession, lecturer): _comment = Comment( diff --git a/tests/test_routes/test_comment.py b/tests/test_routes/test_comment.py index 5f62bba..3a82873 100644 --- a/tests/test_routes/test_comment.py +++ b/tests/test_routes/test_comment.py @@ -1,6 +1,5 @@ import datetime import logging -import uuid import pytest from starlette import status @@ -196,13 +195,39 @@ def test_create_comment(client, dbsession, lecturers, body, lecturer_n, response assert user_comment is not None -def test_get_comment(client, comment): +@pytest.mark.parametrize( + "reaction_data, expected_reaction, comment_user_id", + [ + (None, None, 0), + ((0, Reaction.LIKE), "is_liked", 0), # my like on my comment + ((0, Reaction.DISLIKE), "is_disliked", 0), + ((999, Reaction.LIKE), None, 0), # someone else's like on my comment + ((999, Reaction.DISLIKE), None, 0), + ((0, Reaction.LIKE), "is_liked", 999), # my like on someone else's comment + ((0, Reaction.DISLIKE), "is_disliked", 999), + ((333, Reaction.LIKE), None, 999), # someone else's like on another person's comment + ((333, Reaction.DISLIKE), None, 999), + (None, None, None), # anonymous + ], +) +def test_get_comment_with_reaction(client, comment, reaction_data, expected_reaction, comment_user_id, comment_reaction): + comment.user_id = comment_user_id + + if reaction_data: + user_id, reaction_type = reaction_data + comment_reaction(user_id, reaction_type) + response_comment = client.get(f'{url}/{comment.uuid}') - print("1") - assert response_comment.status_code == status.HTTP_200_OK - random_uuid = uuid.uuid4() - response = client.get(f'{url}/{random_uuid}') - assert response.status_code == status.HTTP_404_NOT_FOUND + + if response_comment: + data = response_comment.json() + if expected_reaction: + assert data[expected_reaction] + else: + assert data["is_liked"] == False + assert data["is_disliked"] == False + else: + assert response_comment.status_code == status.HTTP_404_NOT_FOUND @pytest.fixture