Skip to content

Commit dd6753c

Browse files
refactor action fetching method
1 parent e05a900 commit dd6753c

File tree

1 file changed

+79
-43
lines changed

1 file changed

+79
-43
lines changed

cogs/committee_actions_tracking.py

Lines changed: 79 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44
import random
55
from enum import Enum
6-
from typing import TYPE_CHECKING
6+
from typing import TYPE_CHECKING, overload
77

88
import discord
99
from django.core.exceptions import MultipleObjectsReturned, ObjectDoesNotExist, ValidationError
@@ -21,7 +21,7 @@
2121
)
2222

2323
if TYPE_CHECKING:
24-
from collections.abc import Sequence
24+
from collections.abc import Iterable, Sequence
2525
from collections.abc import Set as AbstractSet
2626
from logging import Logger
2727
from typing import Final
@@ -110,6 +110,74 @@ async def _create_action(
110110
) from create_action_error
111111
return action
112112

113+
@overload
114+
async def get_user_actions(
115+
self, action_user: discord.Member | discord.User, status: str
116+
) -> list[AssignedCommitteeAction]: ...
117+
118+
@overload
119+
async def get_user_actions(
120+
self, action_user: Iterable[discord.Member] | Iterable[discord.User], status: list[str]
121+
) -> dict[discord.Member, list[AssignedCommitteeAction]]: ...
122+
123+
async def get_user_actions(
124+
self,
125+
action_user: discord.Member
126+
| discord.User
127+
| Iterable[discord.Member]
128+
| Iterable[discord.User],
129+
status: str | list[str],
130+
) -> list[AssignedCommitteeAction] | dict[discord.Member, list[AssignedCommitteeAction]]:
131+
"""
132+
Get the actions for a given user.
133+
134+
Takes in the user and returns a list of their actions.
135+
"""
136+
if isinstance(action_user, (discord.User, discord.Member)):
137+
user_actions: list[AssignedCommitteeAction]
138+
139+
if not status:
140+
user_actions = [
141+
action
142+
async for action in await AssignedCommitteeAction.objects.afilter(
143+
(
144+
Q(status=Status.IN_PROGRESS.value)
145+
| Q(status=Status.BLOCKED.value)
146+
| Q(status=Status.NOT_STARTED.value)
147+
),
148+
discord_id=int(action_user.id),
149+
)
150+
]
151+
else:
152+
user_actions = [
153+
action
154+
async for action in await AssignedCommitteeAction.objects.afilter(
155+
status=status,
156+
discord_id=int(action_user.id),
157+
)
158+
]
159+
160+
return user_actions
161+
162+
actions: list[AssignedCommitteeAction] = [
163+
action async for action in AssignedCommitteeAction.objects.select_related().all()
164+
]
165+
166+
committee_actions: dict[discord.Member, list[AssignedCommitteeAction]] = {
167+
committee: [
168+
action
169+
for action in actions
170+
if str(action.discord_member) == DiscordMember.hash_discord_id(committee.id) # type: ignore[has-type]
171+
and action.status in status
172+
]
173+
for committee in action_user
174+
if isinstance(committee, discord.Member)
175+
}
176+
177+
return {
178+
committee: actions for committee, actions in committee_actions.items() if actions
179+
}
180+
113181

114182
class CommitteeActionsTrackingSlashCommandsCog(CommitteeActionsTrackingBaseCog):
115183
"""Cog class that defines the committee-actions tracking slash commands functionality."""
@@ -585,28 +653,10 @@ async def list_user_actions(
585653
else:
586654
action_member = ctx.user
587655

588-
user_actions: list[AssignedCommitteeAction]
589-
590-
if not status:
591-
user_actions = [
592-
action
593-
async for action in await AssignedCommitteeAction.objects.afilter(
594-
(
595-
Q(status=Status.IN_PROGRESS.value)
596-
| Q(status=Status.BLOCKED.value)
597-
| Q(status=Status.NOT_STARTED.value)
598-
),
599-
discord_id=int(action_member.id),
600-
)
601-
]
602-
else:
603-
user_actions = [
604-
action
605-
async for action in await AssignedCommitteeAction.objects.afilter(
606-
status=status,
607-
discord_id=int(action_member.id),
608-
)
609-
]
656+
user_actions: list[AssignedCommitteeAction] = await self.get_user_actions(
657+
action_user=action_member,
658+
status=status,
659+
)
610660

611661
if not user_actions:
612662
await ctx.respond(
@@ -743,10 +793,6 @@ async def list_all_actions(
743793
"""List all actions.""" # NOTE: this doesn't actually list *all* actions as it is possible for non-committee to be actioned.
744794
committee_role: discord.Role = await self.bot.committee_role
745795

746-
actions: list[AssignedCommitteeAction] = [
747-
action async for action in AssignedCommitteeAction.objects.select_related().all()
748-
]
749-
750796
desired_status: list[str] = (
751797
[status]
752798
if status
@@ -759,21 +805,11 @@ async def list_all_actions(
759805

760806
committee_members: list[discord.Member] = committee_role.members
761807

762-
committee_actions: dict[discord.Member, list[AssignedCommitteeAction]] = {
763-
committee: [
764-
action
765-
for action in actions
766-
if str(action.discord_member) == DiscordMember.hash_discord_id(committee.id) # type: ignore[has-type]
767-
and action.status in desired_status
768-
]
769-
for committee in committee_members
770-
}
771-
772-
filtered_committee_actions = {
773-
committee: actions for committee, actions in committee_actions.items() if actions
774-
}
808+
committee_actions: dict[discord.Member, list[AssignedCommitteeAction]] = (
809+
await self.get_user_actions(action_user=committee_members,status=desired_status)
810+
)
775811

776-
if not filtered_committee_actions:
812+
if not committee_actions:
777813
await ctx.respond(content="No one has any actions that match the request!")
778814
logger.debug("No actions found with the status filter: %s", status)
779815
return
@@ -782,7 +818,7 @@ async def list_all_actions(
782818
[
783819
f"\n{committee.mention if ping else committee}, Actions:"
784820
f"\n{', \n'.join(str(action.description) + f' ({AssignedCommitteeAction.Status(action.status).label})' for action in actions)}" # noqa: E501
785-
for committee, actions in filtered_committee_actions.items()
821+
for committee, actions in committee_actions.items()
786822
],
787823
)
788824

0 commit comments

Comments
 (0)