Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion clients/python/src/taskbroker_client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from taskbroker_client.retry import Retry
from taskbroker_client.router import TaskRouter
from taskbroker_client.task import Task
from taskbroker_client.types import AtMostOnceStore, ProducerFactory
from taskbroker_client.types import AtMostOnceStore, ContextHook, ProducerFactory


class TaskbrokerApp:
Expand All @@ -27,9 +27,11 @@ def __init__(
router_class: str | TaskRouter = "taskbroker_client.router.DefaultRouter",
metrics_class: str | MetricsBackend = "taskbroker_client.metrics.NoOpMetricsBackend",
at_most_once_store: AtMostOnceStore | None = None,
context_hooks: list[ContextHook] | None = None,
) -> None:
self.name = name
self.metrics = self._build_metrics(metrics_class)
self.context_hooks: list[ContextHook] = context_hooks or []
self._config = {
"rpc_secret": None,
"grpc_config": None,
Expand All @@ -41,6 +43,7 @@ def __init__(
producer_factory=producer_factory,
router=self._build_router(router_class),
metrics=self.metrics,
context_hooks=self.context_hooks,
)
self.at_most_once_store(at_most_once_store)

Expand Down
10 changes: 8 additions & 2 deletions clients/python/src/taskbroker_client/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from taskbroker_client.retry import Retry
from taskbroker_client.router import TaskRouter
from taskbroker_client.task import ExternalTask, P, R, Task
from taskbroker_client.types import ProducerFactory, ProducerProtocol
from taskbroker_client.types import ContextHook, ProducerFactory, ProducerProtocol

logger = logging.getLogger(__name__)

Expand All @@ -42,6 +42,7 @@ def __init__(
expires: int | datetime.timedelta | None = None,
processing_deadline_duration: int = DEFAULT_PROCESSING_DEADLINE,
app_feature: str | None = None,
context_hooks: list[ContextHook] | None = None,
):
self.name = name
self.application = application
Expand All @@ -50,6 +51,7 @@ def __init__(
self.default_expires = expires # seconds
self.default_processing_deadline_duration = processing_deadline_duration # seconds
self.app_feature = app_feature or name
self.context_hooks: list[ContextHook] = context_hooks or []
self._registered_tasks: dict[str, Task[Any, Any]] = {}
self._producers: dict[str, ProducerProtocol] = {}
self._producer_factory = producer_factory
Expand Down Expand Up @@ -175,7 +177,7 @@ def send_task(self, activation: TaskActivation, wait_for_delivery: bool = False)
)
# We know this type is futures.Future, but cannot assert so,
# because it is also mock.Mock in tests.
produce_future.add_done_callback( # type:ignore[union-attr]
produce_future.add_done_callback( # type: ignore[union-attr]
lambda future: self._handle_produce_future(
future=future,
tags={
Expand Down Expand Up @@ -289,13 +291,15 @@ def __init__(
producer_factory: ProducerFactory,
router: TaskRouter,
metrics: MetricsBackend,
context_hooks: list[ContextHook] | None = None,
) -> None:
self._application = application
self._namespaces: dict[str, TaskNamespace] = {}
self._external_namespaces: dict[str, ExternalNamespace] = {}
self._producer_factory = producer_factory
self._router = router
self._metrics = metrics
self._context_hooks: list[ContextHook] = context_hooks or []
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Falsy `or []` fallback breaks the hook-list reference chain

Medium Severity

The `context_hooks or []` pattern in `TaskRegistry.__init__` and `TaskNamespace.__init__` creates a new, disconnected empty list whenever the incoming list is empty (since `[]` is falsy in Python). This breaks the reference chain from `app.context_hooks` through to `namespace.context_hooks`. The dispatch side reads hooks from `self._namespace.context_hooks`, while the execution side reads from `app.context_hooks`, so any hooks appended to `app.context_hooks` after construction are visible on execution but silently ignored on dispatch. Using `context_hooks if context_hooks is not None else []` instead of the `or` pattern would preserve the reference to an empty list.

Additional Locations (1)
Fix in Cursor Fix in Web


def contains(self, name: str) -> bool:
    """Return True if ``name`` is a key in this registry's namespace mapping."""
    return name in self._namespaces
Expand Down Expand Up @@ -339,6 +343,7 @@ def create_namespace(
expires=expires,
processing_deadline_duration=processing_deadline_duration,
app_feature=app_feature,
context_hooks=self._context_hooks,
)
self._namespaces[name] = namespace

Expand Down Expand Up @@ -371,6 +376,7 @@ def create_external_namespace(
retry=retry,
expires=expires,
processing_deadline_duration=processing_deadline_duration,
context_hooks=self._context_hooks,
)
self._external_namespaces[key] = namespace
return namespace
Expand Down
9 changes: 5 additions & 4 deletions clients/python/src/taskbroker_client/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,11 @@ def __init__(
processing_deadline_duration or DEFAULT_PROCESSING_DEADLINE
)
if at_most_once and retry:
raise AssertionError(
"""
raise AssertionError("""
You cannot enable at_most_once and have retries defined.
Having retries enabled means that a task supports being executed
multiple times and thus cannot be idempotent.
"""
)
""")
self._retry = retry
self.at_most_once = at_most_once
self.wait_for_delivery = wait_for_delivery
Expand Down Expand Up @@ -179,6 +177,9 @@ def create_activation(
**headers,
}

for hook in self._namespace.context_hooks:
hook.on_dispatch(headers)

# Monitor config is patched in by the sentry_sdk
# however, taskworkers do not support the nested object,
# nor do they use it when creating checkins.
Expand Down
18 changes: 17 additions & 1 deletion clients/python/src/taskbroker_client/types.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
import contextlib
import dataclasses
from typing import Callable, Protocol
from collections.abc import MutableMapping
from typing import Any, Callable, Protocol

from arroyo.backends.abstract import ProducerFuture
from arroyo.backends.kafka import KafkaPayload
from arroyo.types import BrokerValue, Topic
from sentry_protos.taskbroker.v1.taskbroker_pb2 import TaskActivation, TaskActivationStatus


class ContextHook(Protocol):
    """
    Hook for propagating application context through task headers.

    on_dispatch: called at task creation time to inject context into headers.
    on_execute: called at task execution time, returns a context manager
    that restores context from headers for the duration of the task.

    Any object providing both methods with compatible signatures satisfies
    this protocol; no inheritance is required.
    """

    def on_dispatch(self, headers: MutableMapping[str, Any]) -> None:
        """Add this hook's context to ``headers`` in place; nothing is returned."""
        ...

    def on_execute(self, headers: dict[str, str]) -> contextlib.AbstractContextManager[None]:
        """Return a context manager that is held open while the task function runs."""
        ...


class AtMostOnceStore(Protocol):
"""
Interface for the at_most_once store used for idempotent task execution.
Expand Down
17 changes: 12 additions & 5 deletions clients/python/src/taskbroker_client/worker/workerchild.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import queue
import signal
import time
from collections.abc import Callable, Generator
from collections.abc import Callable, Generator, Sequence
from multiprocessing.synchronize import Event
from types import FrameType
from typing import Any
Expand All @@ -30,7 +30,7 @@
from taskbroker_client.retry import NoRetriesRemainingError
from taskbroker_client.state import clear_current_task, current_task, set_current_task
from taskbroker_client.task import Task
from taskbroker_client.types import InflightTaskActivation, ProcessingResult
from taskbroker_client.types import ContextHook, InflightTaskActivation, ProcessingResult

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -226,7 +226,7 @@ def handle_alarm(signum: int, frame: FrameType | None) -> None:
execution_start_time = time.time()
try:
with timeout_alarm(inflight.activation.processing_deadline_duration, handle_alarm):
_execute_activation(task_func, inflight.activation)
_execute_activation(task_func, inflight.activation, app.context_hooks)
next_state = TASK_ACTIVATION_STATUS_COMPLETE
except ProcessingDeadlineExceeded as err:
with sentry_sdk.isolation_scope() as scope:
Expand Down Expand Up @@ -308,7 +308,11 @@ def handle_alarm(signum: int, frame: FrameType | None) -> None:
inflight.host,
)

def _execute_activation(task_func: Task[Any, Any], activation: TaskActivation) -> None:
def _execute_activation(
task_func: Task[Any, Any],
activation: TaskActivation,
context_hooks: Sequence[ContextHook] = (),
) -> None:
"""Invoke a task function with the activation parameters."""
headers = {k: v for k, v in activation.headers.items()}
parameters = load_parameters(activation.parameters, headers)
Expand Down Expand Up @@ -363,7 +367,10 @@ def _execute_activation(task_func: Task[Any, Any], activation: TaskActivation) -
kwargs.pop("__start_time")

try:
task_func(*args, **kwargs)
with contextlib.ExitStack() as stack:
for hook in context_hooks:
stack.enter_context(hook.on_execute(headers))
task_func(*args, **kwargs)
transaction.set_status(SPANSTATUS.OK)
except Exception:
transaction.set_status(SPANSTATUS.INTERNAL_ERROR)
Expand Down
84 changes: 84 additions & 0 deletions clients/python/tests/test_task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import contextlib
import datetime
from collections.abc import MutableMapping
from typing import Any
from unittest.mock import patch

Expand Down Expand Up @@ -375,3 +377,85 @@ def with_parameters(one: str, two: int, org_id: int) -> None:
assert "sentry-monitor-config" not in result
assert "sentry-monitor-slug" in result
assert "sentry-monitor-check-in-id" in result


class StubContextHook:
"""Test hook that writes/reads a simple key."""

def on_dispatch(self, headers: MutableMapping[str, Any]) -> None:
headers["x-test-context"] = "dispatched"

def on_execute(self, headers: dict[str, str]) -> contextlib.AbstractContextManager[None]:
if "x-test-context" not in headers:
return contextlib.nullcontext()
# Store the value so the test can verify it was called
StubContextHook.last_executed = headers["x-test-context"]
return contextlib.nullcontext()


def test_context_hook_on_dispatch() -> None:
    """A hook attached to the namespace adds its header when an activation is built."""
    hooks = [StubContextHook()]
    namespace = TaskNamespace(
        name="tests",
        application="acme",
        producer_factory=producer_factory,
        router=DefaultRouter(),
        metrics=NoOpMetricsBackend(),
        retry=None,
        context_hooks=hooks,
    )

    @namespace.register(name="test.hooked")
    def hooked_task() -> None:
        return None

    built = hooked_task.create_activation([], {})
    marker = built.headers["x-test-context"]
    assert marker == "dispatched"


def test_context_hook_not_present_without_hooks() -> None:
    """A namespace built without hooks leaves the marker header out of activations."""
    namespace = TaskNamespace(
        name="tests",
        application="acme",
        producer_factory=producer_factory,
        router=DefaultRouter(),
        metrics=NoOpMetricsBackend(),
        retry=None,
    )

    @namespace.register(name="test.no_hooks")
    def no_hooks_task() -> None:
        return None

    built = no_hooks_task.create_activation([], {})
    header_names = built.headers
    assert "x-test-context" not in header_names


def test_context_hook_multiple_hooks() -> None:
    """Every hook attached to the namespace contributes its own header."""

    class AnotherHook:
        """Second hook writing a distinct header, so both contributions are visible."""

        def on_dispatch(self, headers: MutableMapping[str, Any]) -> None:
            headers["x-another"] = "also-here"

        def on_execute(self, headers: dict[str, str]) -> contextlib.AbstractContextManager[None]:
            return contextlib.nullcontext()

    hooks = [StubContextHook(), AnotherHook()]
    namespace = TaskNamespace(
        name="tests",
        application="acme",
        producer_factory=producer_factory,
        router=DefaultRouter(),
        metrics=NoOpMetricsBackend(),
        retry=None,
        context_hooks=hooks,
    )

    @namespace.register(name="test.multi_hooks")
    def multi_task() -> None:
        return None

    built = multi_task.create_activation([], {})
    assert built.headers["x-test-context"] == "dispatched"
    assert built.headers["x-another"] == "also-here"
56 changes: 56 additions & 0 deletions clients/python/tests/worker/test_worker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
import contextlib
import queue
import time
from multiprocessing import Event
Expand Down Expand Up @@ -761,3 +762,58 @@ def test_child_process_decompression(mock_capture_checkin: mock.MagicMock) -> No
assert result.task_id == COMPRESSED_TASK.activation.id
assert result.status == TASK_ACTIVATION_STATUS_COMPLETE
assert mock_capture_checkin.call_count == 0


def test_child_process_context_hooks() -> None:
    """Context hooks' on_execute is called with activation headers during task execution."""
    # Local import keeps this test independent of the module's import block.
    from collections.abc import MutableMapping

    executed_headers: list[dict[str, str]] = []

    class RecordingHook:
        """Hook that ignores dispatch and records a copy of the headers seen on execute."""

        # Parameter type matches the ContextHook protocol; a narrower
        # ``dict`` annotation would not be a structural match.
        def on_dispatch(self, headers: MutableMapping[str, Any]) -> None:
            pass

        def on_execute(self, headers: dict[str, str]) -> contextlib.AbstractContextManager[None]:
            # Copy so later mutation of the headers cannot change what was recorded.
            executed_headers.append(dict(headers))
            return contextlib.nullcontext()

    from examples.app import app

    hook = RecordingHook()
    # The hook is attached to the shared example app, so it must be removed
    # again in ``finally`` to avoid leaking into other tests.
    app.context_hooks.append(hook)

    try:
        activation_with_headers = InflightTaskActivation(
            host="localhost:50051",
            receive_timestamp=0,
            activation=TaskActivation(
                id="hook-test",
                taskname="examples.simple_task",
                namespace="examples",
                parameters='{"args": [], "kwargs": {}}',
                headers={"x-viewer-org": "42", "x-viewer-user": "7"},
                processing_deadline_duration=5,
            ),
        )

        todo: queue.Queue[InflightTaskActivation] = queue.Queue()
        processed: queue.Queue[ProcessingResult] = queue.Queue()
        shutdown = Event()

        todo.put(activation_with_headers)
        child_process(
            "examples.app:app",
            todo,
            processed,
            shutdown,
            max_task_count=1,
            processing_pool_name="test",
            process_type="fork",
        )

        # Bounded wait: if no result was produced, fail with queue.Empty
        # instead of blocking the test run forever.
        result = processed.get(timeout=5)
        assert result.status == TASK_ACTIVATION_STATUS_COMPLETE
        assert len(executed_headers) == 1
        assert executed_headers[0]["x-viewer-org"] == "42"
        assert executed_headers[0]["x-viewer-user"] == "7"
    finally:
        app.context_hooks.remove(hook)
Loading