Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ jobs:
- name: Set up PostgreSQL
uses: ikalnytskyi/action-setup-postgres@v8
with:
username: postgres
password: postgres
database: taskiqpsqlpy
username: taskiq_psqlpy
password: look_in_vault
database: taskiq_psqlpy
id: postgres
- name: Set up uv and enable cache
id: setup-uv
Expand Down
10 changes: 10 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Local development stack: one PostgreSQL instance whose credentials match
# the library's default DSNs and the CI workflow configuration.
services:
  postgres:
    container_name: taskiq_psqlpy
    image: postgres:18
    environment:
      POSTGRES_DB: taskiq_psqlpy
      POSTGRES_USER: taskiq_psqlpy
      # Development-only credential (mirrored in the default DSN strings);
      # real secrets are kept in the vault, as the placeholder name hints.
      POSTGRES_PASSWORD: look_in_vault
    ports:
      # Expose PostgreSQL on its default port for local test runs.
      - "5432:5432"
17 changes: 15 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,15 @@ dev = [
{include-group = "lint"},
{include-group = "test"},
"pre-commit>=4.5.0",
"anyio>=4.12.0",
]
test = [
"pytest>=9.0.1",
"pytest-cov>=7.0.0",
"pytest-env>=1.2.0",
"pytest-xdist>=3.8.0",
"pytest-asyncio>=1.3.0",
"polyfactory>=3.1.0",
"sqlalchemy-utils>=0.42.1",
]
lint = [
"black>=25.11.0",
Expand Down Expand Up @@ -80,7 +82,7 @@ module-root = ""
module-name = "taskiq_psqlpy"

[tool.ruff]
line-length = 88
line-length = 120

[tool.ruff.lint]
# List of enabled rulsets.
Expand Down Expand Up @@ -147,3 +149,14 @@ allow-magic-value-types = ["int", "str", "float"]

[tool.ruff.lint.flake8-bugbear]
extend-immutable-calls = ["taskiq_dependencies.Depends", "taskiq.TaskiqDepends"]

[tool.pytest.ini_options]
# Put the project root on sys.path so tests import taskiq_psqlpy directly.
pythonpath = [
    "."
]
# Run async tests without explicit @pytest.mark.asyncio marks.
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
# Registered markers used to split fast unit tests from tests that need
# real infrastructure (the docker-compose PostgreSQL instance).
markers = [
    "unit: marks unit tests",
    "integration: marks tests with real infrastructure env",
]
4 changes: 4 additions & 0 deletions taskiq_psqlpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
"""Public package interface for taskiq-psqlpy."""

from taskiq_psqlpy.broker import PSQLPyBroker
from taskiq_psqlpy.result_backend import PSQLPyResultBackend
from taskiq_psqlpy.schedule_source import PSQLPyScheduleSource

# Names re-exported as the package's public API.
__all__ = [
    "PSQLPyBroker",
    "PSQLPyResultBackend",
    "PSQLPyScheduleSource",
]
225 changes: 225 additions & 0 deletions taskiq_psqlpy/broker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
import asyncio
import logging
import typing as tp
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from datetime import datetime

import psqlpy
from psqlpy.exceptions import ConnectionExecuteError
from psqlpy.extra_types import JSONB
from taskiq import AckableMessage, AsyncBroker, AsyncResultBackend, BrokerMessage

from taskiq_psqlpy.queries import (
CLAIM_MESSAGE_QUERY,
CREATE_MESSAGE_TABLE_QUERY,
DELETE_MESSAGE_QUERY,
INSERT_MESSAGE_QUERY,
)

logger = logging.getLogger("taskiq.psqlpy_broker")
_T = tp.TypeVar("_T")


@dataclass
class MessageRow:
    """Message in db table.

    Instances are produced via ``Row.as_class(MessageRow)`` from the
    ``RETURNING *`` of the claim query, so the fields here must keep the
    same names as the columns of the messages table.
    """

    id: int  # SERIAL primary key assigned on insert
    task_id: str  # taskiq task id from the broker message
    task_name: str  # registered task name
    message: str  # serialized broker message body (stored as TEXT)
    labels: JSONB  # task labels stored in the JSONB column
    status: str  # 'pending' until claimed, then 'processing'
    created_at: datetime  # row creation timestamp (set by the DB default)


class PSQLPyBroker(AsyncBroker):
    """Broker that uses PostgreSQL and PSQLPy with LISTEN/NOTIFY.

    Messages are inserted into a PostgreSQL table and workers are woken
    up via ``NOTIFY`` on :attr:`channel_name`; a worker claims a message
    with a compare-and-set on the row's ``status`` column, so each
    message is handed to exactly one worker.
    """

    # Created in startup(); kept as None until then so that shutdown()
    # can be called safely even if startup() never ran (or failed).
    _read_conn: psqlpy.Connection | None
    _write_pool: psqlpy.ConnectionPool | None
    _listener: psqlpy.Listener | None

    def __init__(
        self,
        dsn: (
            str | tp.Callable[[], str]
        ) = "postgresql://taskiq_psqlpy:look_in_vault@localhost:5432/taskiq_psqlpy",
        result_backend: AsyncResultBackend[_T] | None = None,
        task_id_generator: tp.Callable[[], str] | None = None,
        channel_name: str = "taskiq",
        table_name: str = "taskiq_messages",
        max_retry_attempts: int = 5,
        read_kwargs: dict[str, tp.Any] | None = None,
        write_kwargs: dict[str, tp.Any] | None = None,
    ) -> None:
        """
        Construct a new broker.

        Args:
            dsn: connection string to PostgreSQL, or callable returning one.
            result_backend: Custom result backend.
            task_id_generator: Custom task_id generator.
            channel_name: Name of the channel to listen on.
            table_name: Name of the table to store messages.
            max_retry_attempts: Maximum number of message processing attempts.
            read_kwargs: Additional arguments for read connection creation.
            write_kwargs: Additional arguments for write pool creation.
        """
        super().__init__(
            result_backend=result_backend,
            task_id_generator=task_id_generator,
        )
        self._dsn: str | tp.Callable[[], str] = dsn
        self.channel_name: str = channel_name
        self.table_name: str = table_name
        self.read_kwargs: dict[str, tp.Any] = read_kwargs or {}
        self.write_kwargs: dict[str, tp.Any] = write_kwargs or {}
        self.max_retry_attempts: int = max_retry_attempts
        self._queue: asyncio.Queue[str] | None = None
        # Assign the connection attributes up front: the class-level
        # annotations alone do not create them, and shutdown() must not
        # raise AttributeError when startup() was never executed.
        self._read_conn = None
        self._write_pool = None
        self._listener = None
        # Strong references to pending delayed-notification tasks; a bare
        # create_task() result may be garbage-collected before it fires.
        self._delay_tasks: set[asyncio.Task[None]] = set()

    @property
    def dsn(self) -> str:
        """
        Get the DSN string.

        Returns:
            A string with dsn (the callable form is resolved lazily here).
        """
        if callable(self._dsn):
            return self._dsn()
        return self._dsn

    def _ensure_pool(self) -> psqlpy.ConnectionPool:
        """Return the write pool, failing fast if startup() has not run."""
        if self._write_pool is None:
            raise RuntimeError("Broker is not running; call startup() first.")
        return self._write_pool

    async def startup(self) -> None:
        """Initialize the broker: connections, table, and LISTEN channel."""
        await super().startup()
        self._read_conn = await psqlpy.connect(
            dsn=self.dsn,
            **self.read_kwargs,
        )
        self._write_pool = psqlpy.ConnectionPool(
            dsn=self.dsn,
            **self.write_kwargs,
        )

        # create messages table if it doesn't exist
        async with self._write_pool.acquire() as conn:
            await conn.execute(CREATE_MESSAGE_TABLE_QUERY.format(self.table_name))

        # listen to notification channel
        self._listener = self._write_pool.listener()
        await self._listener.add_callback(self.channel_name, self._notification_handler)
        await self._listener.startup()
        self._listener.listen()

        self._queue = asyncio.Queue()

    async def shutdown(self) -> None:
        """Close all connections on shutdown."""
        await super().shutdown()
        # The pool is about to close, so notifications still waiting out
        # their delay could not be delivered anyway: cancel them.
        for task in self._delay_tasks:
            task.cancel()
        self._delay_tasks.clear()
        if self._read_conn is not None:
            self._read_conn.close()
        if self._write_pool is not None:
            self._write_pool.close()
        if self._listener is not None:
            self._listener.abort_listen()
            await self._listener.shutdown()

    async def _notification_handler(
        self,
        connection: psqlpy.Connection,
        payload: str,
        channel: str,
        process_id: int,
    ) -> None:
        """
        Handle NOTIFY messages by queueing the payload (a message id).

        https://psqlpy-python.github.io/components/listener.html#usage
        """
        logger.debug("Received notification on channel %s: %s", channel, payload)
        if self._queue is not None:
            self._queue.put_nowait(payload)

    async def kick(self, message: BrokerMessage) -> None:
        """
        Send message to the channel.

        Inserts the message into the database and sends a NOTIFY
        (immediately, or after the delay given in the "delay" label).

        :param message: Message to send.
        """
        pool = self._ensure_pool()
        async with pool.acquire() as conn:
            # insert message into db table
            message_inserted_id = tp.cast(
                "int",
                await conn.fetch_val(
                    INSERT_MESSAGE_QUERY.format(self.table_name),
                    [
                        message.task_id,
                        message.task_name,
                        message.message.decode(),
                        JSONB(message.labels),
                    ],
                ),
            )

            delay_value = tp.cast("str | None", message.labels.get("delay"))
            if delay_value is not None:
                delay_seconds = int(delay_value)
                # Keep a strong reference so the pending task survives
                # until it fires; drop it from the set once it is done.
                task = asyncio.create_task(
                    self._schedule_notification(message_inserted_id, delay_seconds),
                )
                self._delay_tasks.add(task)
                task.add_done_callback(self._delay_tasks.discard)
            else:
                # Send NOTIFY with the message ID as payload. pg_notify()
                # lets both channel and payload travel as bind parameters
                # instead of being interpolated into the SQL text.
                _ = await conn.execute(
                    "SELECT pg_notify($1, $2)",
                    [self.channel_name, str(message_inserted_id)],
                )

    async def _schedule_notification(self, message_id: int, delay_seconds: int) -> None:
        """Schedule a notification to be sent after a delay."""
        await asyncio.sleep(delay_seconds)
        async with self._ensure_pool().acquire() as conn:
            # Send NOTIFY with the message ID as payload (parameterized).
            _ = await conn.execute(
                "SELECT pg_notify($1, $2)",
                [self.channel_name, str(message_id)],
            )

    async def listen(self) -> AsyncGenerator[AckableMessage, None]:
        """
        Listen to the channel.

        Yields messages as they are received.

        :yields: AckableMessage instances.
        """
        while True:
            try:
                payload = await self._queue.get()  # type: ignore[union-attr]
                message_id = int(payload)  # payload is the message id
                try:
                    async with self._ensure_pool().acquire() as conn:
                        claimed_message = await conn.fetch_row(
                            CLAIM_MESSAGE_QUERY.format(self.table_name),
                            [message_id],
                        )
                except ConnectionExecuteError:  # message was claimed by another worker
                    continue
                message_row_result = tp.cast(
                    "MessageRow",
                    tp.cast("object", claimed_message.as_class(MessageRow)),
                )
                message_data = message_row_result.message.encode()

                async def ack(*, _message_id: int = message_id) -> None:
                    # Acking deletes the row; an unacked row stays in
                    # 'processing' for manual inspection or requeueing.
                    async with self._ensure_pool().acquire() as conn:
                        _ = await conn.execute(
                            DELETE_MESSAGE_QUERY.format(self.table_name),
                            [_message_id],
                        )

                yield AckableMessage(data=message_data, ack=ack)
            except Exception:
                # Keep the consumer loop alive: one malformed payload or a
                # transient DB error must not kill the whole worker.
                logger.exception("Error processing message")
                continue
54 changes: 54 additions & 0 deletions taskiq_psqlpy/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,57 @@
DELETE_RESULT_QUERY = """
DELETE FROM {} WHERE task_id = $1
"""

CREATE_SCHEDULES_TABLE_QUERY = """
CREATE TABLE IF NOT EXISTS {} (
id UUID PRIMARY KEY,
task_name VARCHAR(100) NOT NULL,
schedule JSONB NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
);
"""

INSERT_SCHEDULE_QUERY = """
INSERT INTO {} (id, task_name, schedule)
VALUES ($1, $2, $3)
ON CONFLICT (id) DO UPDATE
SET task_name = EXCLUDED.task_name,
schedule = EXCLUDED.schedule,
updated_at = NOW();
"""

SELECT_SCHEDULES_QUERY = """
SELECT id, task_name, schedule
FROM {};
"""

DELETE_ALL_SCHEDULES_QUERY = """
DELETE FROM {};
"""

DELETE_SCHEDULE_QUERY = """
DELETE FROM {} WHERE id = $1;
"""

CREATE_MESSAGE_TABLE_QUERY = """
CREATE TABLE IF NOT EXISTS {} (
id SERIAL PRIMARY KEY,
task_id VARCHAR NOT NULL,
task_name VARCHAR NOT NULL,
message TEXT NOT NULL,
labels JSONB NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
);
"""

INSERT_MESSAGE_QUERY = """
INSERT INTO {} (task_id, task_name, message, labels)
VALUES ($1, $2, $3, $4)
RETURNING id
"""

CLAIM_MESSAGE_QUERY = "UPDATE {} SET status = 'processing' WHERE id = $1 AND status = 'pending' RETURNING *"

DELETE_MESSAGE_QUERY = "DELETE FROM {} WHERE id = $1"
4 changes: 3 additions & 1 deletion taskiq_psqlpy/result_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ class PSQLPyResultBackend(AsyncResultBackend[_ReturnType]):

def __init__(
self,
dsn: str | None = "postgres://postgres:postgres@localhost:5432/postgres",
dsn: (
str | None
) = "postgresql://taskiq_psqlpy:look_in_vault@localhost:5432/taskiq_psqlpy",
keep_results: bool = True,
table_name: str = "taskiq_results",
field_for_task_id: Literal["VarChar", "Text"] = "VarChar",
Expand Down
Loading