Skip to content

Commit b1a350f

Browse files
committed
✨ Make it run
1 parent 8555732 commit b1a350f

File tree

11 files changed

+62
-32
lines changed

11 files changed

+62
-32
lines changed

discord/app/cache.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,12 @@
2323
"""
2424

2525
from collections import OrderedDict, defaultdict, deque
26-
from typing import Deque, Protocol, TypeVar
26+
from typing import TYPE_CHECKING, Deque, Protocol, TypeVar
2727

2828
from discord import utils
29-
from discord.app.state import ConnectionState
3029
from discord.member import Member
3130
from discord.message import Message
3231

33-
from ..abc import MessageableChannel, PrivateChannel
3432
from ..channel import DMChannel
3533
from ..emoji import AppEmoji, GuildEmoji
3634
from ..guild import Guild
@@ -45,10 +43,28 @@
4543
from ..ui.view import View
4644
from ..user import User
4745

46+
if TYPE_CHECKING:
47+
from discord.app.state import ConnectionState
48+
49+
from ..abc import MessageableChannel, PrivateChannel
50+
4851
T = TypeVar("T")
4952

5053

5154
class Cache(Protocol):
55+
def __init__(self):
56+
self.__state: ConnectionState | None = None
57+
58+
@property
59+
def _state(self) -> "ConnectionState":
60+
if self.__state is None:
61+
raise RuntimeError("Cache state has not been initialized.")
62+
return self.__state
63+
64+
@_state.setter
65+
def _state(self, state: "ConnectionState") -> None:
66+
self.__state = state
67+
5268
# users
5369
async def get_all_users(self) -> list[User]: ...
5470

@@ -114,17 +130,17 @@ async def store_poll(self, poll: Poll, message_id: int) -> None: ...
114130

115131
# private channels
116132

117-
async def get_private_channels(self) -> list[PrivateChannel]: ...
133+
async def get_private_channels(self) -> "list[PrivateChannel]": ...
118134

119-
async def get_private_channel(self, channel_id: int) -> PrivateChannel: ...
135+
async def get_private_channel(self, channel_id: int) -> "PrivateChannel": ...
120136

121-
async def get_private_channel_by_user(self, user_id: int) -> PrivateChannel | None: ...
137+
async def get_private_channel_by_user(self, user_id: int) -> "PrivateChannel | None": ...
122138

123-
async def store_private_channel(self, channel: PrivateChannel) -> None: ...
139+
async def store_private_channel(self, channel: "PrivateChannel") -> None: ...
124140

125141
# messages
126142

127-
async def store_message(self, message: MessagePayload, channel: MessageableChannel) -> Message: ...
143+
async def store_message(self, message: MessagePayload, channel: "MessageableChannel") -> Message: ...
128144

129145
async def upsert_message(self, message: Message) -> None: ...
130146

@@ -152,8 +168,8 @@ async def clear(self, views: bool = True) -> None: ...
152168

153169

154170
class MemoryCache(Cache):
155-
def __init__(self, max_messages: int | None = None, *, state: ConnectionState):
156-
self._state = state
171+
def __init__(self, max_messages: int | None = None) -> None:
172+
self.__state: ConnectionState | None = None
157173
self.max_messages = max_messages
158174
self._users: dict[int, User] = {}
159175
self._guilds: dict[int, Guild] = {}
@@ -312,10 +328,10 @@ async def store_poll(self, poll: Poll, message_id: int) -> None:
312328

313329
# private channels
314330

315-
async def get_private_channels(self) -> list[PrivateChannel]:
331+
async def get_private_channels(self) -> "list[PrivateChannel]":
316332
return list(self._private_channels.values())
317333

318-
async def get_private_channel(self, channel_id: int) -> PrivateChannel | None:
334+
async def get_private_channel(self, channel_id: int) -> "PrivateChannel | None":
319335
try:
320336
channel = self._private_channels[channel_id]
321337
except KeyError:
@@ -324,7 +340,7 @@ async def get_private_channel(self, channel_id: int) -> PrivateChannel | None:
324340
self._private_channels.move_to_end(channel_id)
325341
return channel
326342

327-
async def store_private_channel(self, channel: PrivateChannel) -> None:
343+
async def store_private_channel(self, channel: "PrivateChannel") -> None:
328344
channel_id = channel.id
329345
self._private_channels[channel_id] = channel
330346

@@ -336,15 +352,15 @@ async def store_private_channel(self, channel: PrivateChannel) -> None:
336352
if isinstance(channel, DMChannel) and channel.recipient:
337353
self._private_channels_by_user[channel.recipient.id] = channel
338354

339-
async def get_private_channel_by_user(self, user_id: int) -> PrivateChannel | None:
355+
async def get_private_channel_by_user(self, user_id: int) -> "PrivateChannel | None":
340356
return self._private_channels_by_user.get(user_id)
341357

342358
# messages
343359

344360
async def upsert_message(self, message: Message) -> None:
345361
self._messages.append(message)
346362

347-
async def store_message(self, message: MessagePayload, channel: MessageableChannel) -> Message:
363+
async def store_message(self, message: MessagePayload, channel: "MessageableChannel") -> Message:
348364
msg = await Message._from_data(state=self._state, channel=channel, data=message)
349365
self._messages.append(msg)
350366
return msg

discord/app/event_emitter.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,12 @@
2626
from abc import ABC, abstractmethod
2727
from asyncio import Future
2828
from collections import defaultdict
29-
from typing import Any, Callable, Self, TypeVar
29+
from typing import TYPE_CHECKING, Any, Callable, TypeVar
3030

31-
from .state import ConnectionState
31+
from typing_extensions import Self
32+
33+
if TYPE_CHECKING:
34+
from .state import ConnectionState
3235

3336
T = TypeVar("T", bound="Event")
3437

@@ -38,11 +41,11 @@ class Event(ABC):
3841

3942
@classmethod
4043
@abstractmethod
41-
async def __load__(cls, data: Any, state: ConnectionState) -> Self | None: ...
44+
async def __load__(cls, data: Any, state: "ConnectionState") -> Self | None: ...
4245

4346

4447
class EventEmitter:
45-
def __init__(self, state: ConnectionState) -> None:
48+
def __init__(self, state: "ConnectionState") -> None:
4649
self._listeners: dict[type[Event], list[Callable]] = {}
4750
self._events: dict[str, list[type[Event]]]
4851
self._wait_fors: dict[type[Event], list[Future]] = defaultdict(list)

discord/app/state.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ def __init__(
238238
self._activity: ActivityPayload | None = activity
239239
self._status: str | None = status
240240
self._intents: Intents = intents
241+
self._voice_clients: dict[int, VoiceClient] = {}
241242

242243
if not intents.members or cache_flags._empty:
243244
self.store_user = self.create_user # type: ignore
@@ -247,12 +248,13 @@ def __init__(
247248

248249
self.emitter = EventEmitter(self)
249250

250-
self.cache: Cache = self.cache
251+
self.cache: Cache = cache
252+
self.cache._state = self
251253

252254
async def clear(self, *, views: bool = True) -> None:
253255
self.user: ClientUser | None = None
254256
await self.cache.clear()
255-
self._voice_clients: dict[int, VoiceClient] = {}
257+
self._voice_clients = {}
256258

257259
async def process_chunk_requests(
258260
self, guild_id: int, nonce: str | None, members: list[Member], complete: bool

discord/client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
from . import utils
4141
from .activity import ActivityTypes, BaseActivity, create_activity
42+
from .app.cache import Cache, MemoryCache
4243
from .app.state import ConnectionState
4344
from .appinfo import AppInfo, PartialAppInfo
4445
from .application_role_connection import ApplicationRoleConnectionMetadata
@@ -309,6 +310,7 @@ def _get_state(self, **options: Any) -> ConnectionState:
309310
hooks=self._hooks,
310311
http=self.http,
311312
loop=self.loop,
313+
cache=MemoryCache(),
312314
**options,
313315
)
314316

@@ -1022,7 +1024,7 @@ async def get_stage_instance(self, id: int, /) -> StageInstance | None:
10221024
Optional[:class:`.StageInstance`]
10231025
The stage instance or ``None`` if not found.
10241026
"""
1025-
from .channel import StageChannel # noqa: PLC0415
1027+
from .channel import StageChannel
10261028

10271029
channel = await self._connection.get_channel(id)
10281030

discord/events/guild.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
import asyncio
2626
import copy
2727
import logging
28-
from typing import TYPE_CHECKING, Any, Self
28+
from typing import TYPE_CHECKING, Any
29+
30+
from typing_extensions import Self
2931

3032
from discord import Role
3133
from discord.app.event_emitter import Event

discord/guild.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,15 @@
3636
List,
3737
NamedTuple,
3838
Optional,
39-
Self,
4039
Sequence,
4140
Tuple,
4241
Union,
4342
cast,
4443
overload,
4544
)
4645

46+
from typing_extensions import Self
47+
4748
from . import abc, utils
4849
from .asset import Asset
4950
from .automod import AutoModAction, AutoModRule, AutoModTriggerMetadata
@@ -2565,7 +2566,7 @@ async def templates(self) -> list[Template]:
25652566
Forbidden
25662567
You don't have permissions to get the templates.
25672568
"""
2568-
from .template import Template # noqa: PLC0415
2569+
from .template import Template
25692570

25702571
data = await self._state.http.guild_templates(self.id)
25712572
return [await Template.from_data(data=d, state=self._state) for d in data]
@@ -2588,7 +2589,7 @@ async def webhooks(self) -> list[Webhook]:
25882589
You don't have permissions to get the webhooks.
25892590
"""
25902591

2591-
from .webhook import Webhook # noqa: PLC0415
2592+
from .webhook import Webhook # circular import
25922593

25932594
data = await self._state.http.guild_webhooks(self.id)
25942595
return [Webhook.from_state(d, state=self._state) for d in data]
@@ -2678,7 +2679,7 @@ async def create_template(self, *, name: str, description: str | utils.Undefined
26782679
description: :class:`str`
26792680
The description of the template.
26802681
"""
2681-
from .template import Template # noqa: PLC0415
2682+
from .template import Template # circular import
26822683

26832684
payload = {"name": name}
26842685

discord/member.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030
import itertools
3131
import sys
3232
from operator import attrgetter
33-
from typing import TYPE_CHECKING, Any, Self, TypeVar, Union
33+
from typing import TYPE_CHECKING, Any, TypeVar, Union
34+
35+
from typing_extensions import Self
3436

3537
import discord.abc
3638

discord/message.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,15 @@
3535
Any,
3636
Callable,
3737
ClassVar,
38-
Self,
3938
Sequence,
4039
TypeVar,
4140
Union,
4241
overload,
4342
)
4443
from urllib.parse import parse_qs, urlparse
4544

45+
from typing_extensions import Self
46+
4647
from . import utils
4748
from .channel import PartialMessageable
4849
from .components import _component_factory

discord/raw_models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
from __future__ import annotations
2727

2828
import datetime
29-
from typing import TYPE_CHECKING, Self
29+
from typing import TYPE_CHECKING
30+
31+
from typing_extensions import Self
3032

3133
from .automod import AutoModAction, AutoModTriggerType
3234
from .enums import (

discord/shard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
from .app.state import AutoShardedConnectionState
3535
from .backoff import ExponentialBackoff
36+
from .client import Client
3637
from .enums import Status
3738
from .errors import (
3839
ClientException,
@@ -48,7 +49,6 @@
4849
from .gateway import DiscordWebSocket
4950

5051
EI = TypeVar("EI", bound="EventItem")
51-
5252
__all__ = (
5353
"AutoShardedClient",
5454
"ShardInfo",

0 commit comments

Comments
 (0)