From ce85696f6d1636aeea35ad18131138a756a6cea1 Mon Sep 17 00:00:00 2001 From: Greg Pstrucha <875316+gricha@users.noreply.github.com> Date: Fri, 3 Apr 2026 11:23:17 -0700 Subject: [PATCH 1/4] Add headers and hooks to taskbroker client --- clients/python/src/taskbroker_client/app.py | 5 ++++- .../python/src/taskbroker_client/registry.py | 7 ++++++- clients/python/src/taskbroker_client/task.py | 3 +++ clients/python/src/taskbroker_client/types.py | 15 +++++++++++++++ .../src/taskbroker_client/worker/workerchild.py | 17 ++++++++++++----- 5 files changed, 40 insertions(+), 7 deletions(-) diff --git a/clients/python/src/taskbroker_client/app.py b/clients/python/src/taskbroker_client/app.py index 4c2a6207..5f90add6 100644 --- a/clients/python/src/taskbroker_client/app.py +++ b/clients/python/src/taskbroker_client/app.py @@ -12,7 +12,7 @@ from taskbroker_client.retry import Retry from taskbroker_client.router import TaskRouter from taskbroker_client.task import Task -from taskbroker_client.types import AtMostOnceStore, ProducerFactory +from taskbroker_client.types import AtMostOnceStore, ContextHook, ProducerFactory class TaskbrokerApp: @@ -27,9 +27,11 @@ def __init__( router_class: str | TaskRouter = "taskbroker_client.router.DefaultRouter", metrics_class: str | MetricsBackend = "taskbroker_client.metrics.NoOpMetricsBackend", at_most_once_store: AtMostOnceStore | None = None, + context_hooks: list[ContextHook] | None = None, ) -> None: self.name = name self.metrics = self._build_metrics(metrics_class) + self.context_hooks: list[ContextHook] = context_hooks or [] self._config = { "rpc_secret": None, "grpc_config": None, @@ -41,6 +43,7 @@ def __init__( producer_factory=producer_factory, router=self._build_router(router_class), metrics=self.metrics, + context_hooks=self.context_hooks, ) self.at_most_once_store(at_most_once_store) diff --git a/clients/python/src/taskbroker_client/registry.py b/clients/python/src/taskbroker_client/registry.py index 6f505cf2..70f50e06 100644 --- a/clients/python/src/taskbroker_client/registry.py +++ b/clients/python/src/taskbroker_client/registry.py @@ -17,7 +17,7 @@ from taskbroker_client.retry import Retry from taskbroker_client.router import TaskRouter from taskbroker_client.task import ExternalTask, P, R, Task -from taskbroker_client.types import ProducerFactory, ProducerProtocol +from taskbroker_client.types import ContextHook, ProducerFactory, ProducerProtocol logger = logging.getLogger(__name__) @@ -42,6 +42,7 @@ def __init__( expires: int | datetime.timedelta | None = None, processing_deadline_duration: int = DEFAULT_PROCESSING_DEADLINE, app_feature: str | None = None, + context_hooks: list[ContextHook] | None = None, ): self.name = name self.application = application @@ -50,6 +51,7 @@ def __init__( self.default_expires = expires # seconds self.default_processing_deadline_duration = processing_deadline_duration # seconds self.app_feature = app_feature or name + self.context_hooks: list[ContextHook] = context_hooks or [] self._registered_tasks: dict[str, Task[Any, Any]] = {} self._producers: dict[str, ProducerProtocol] = {} self._producer_factory = producer_factory @@ -289,6 +291,7 @@ def __init__( producer_factory: ProducerFactory, router: TaskRouter, metrics: MetricsBackend, + context_hooks: list[ContextHook] | None = None, ) -> None: self._application = application self._namespaces: dict[str, TaskNamespace] = {} @@ -296,6 +299,7 @@ def __init__( self._producer_factory = producer_factory self._router = router self._metrics = metrics + self._context_hooks: list[ContextHook] = context_hooks or [] def contains(self, name: str) -> bool: return name in self._namespaces @@ -339,6 +343,7 @@ def create_namespace( expires=expires, processing_deadline_duration=processing_deadline_duration, app_feature=app_feature, + context_hooks=self._context_hooks, ) self._namespaces[name] = namespace diff --git a/clients/python/src/taskbroker_client/task.py b/clients/python/src/taskbroker_client/task.py index 5ec58e43..5dfa9ead 100644 --- a/clients/python/src/taskbroker_client/task.py +++ b/clients/python/src/taskbroker_client/task.py @@ -179,6 +179,9 @@ def create_activation( **headers, } + for hook in self._namespace.context_hooks: + hook.on_dispatch(headers) + # Monitor config is patched in by the sentry_sdk # however, taskworkers do not support the nested object, # nor do they use it when creating checkins. diff --git a/clients/python/src/taskbroker_client/types.py b/clients/python/src/taskbroker_client/types.py index 6ab5058e..ce00f5fd 100644 --- a/clients/python/src/taskbroker_client/types.py +++ b/clients/python/src/taskbroker_client/types.py @@ -1,3 +1,4 @@ +import contextlib import dataclasses from typing import Callable, Protocol @@ -7,6 +8,20 @@ from sentry_protos.taskbroker.v1.taskbroker_pb2 import TaskActivation, TaskActivationStatus +class ContextHook(Protocol): + """ + Hook for propagating application context through task headers. + + on_dispatch: called at task creation time to inject context into headers. + on_execute: called at task execution time, returns a context manager + that restores context from headers for the duration of the task. + """ + + def on_dispatch(self, headers: dict[str, str]) -> None: ... + + def on_execute(self, headers: dict[str, str]) -> contextlib.AbstractContextManager[None]: ... + + class AtMostOnceStore(Protocol): """ Interface for the at_most_once store used for idempotent task execution. diff --git a/clients/python/src/taskbroker_client/worker/workerchild.py b/clients/python/src/taskbroker_client/worker/workerchild.py index b35a346c..2bb5560e 100644 --- a/clients/python/src/taskbroker_client/worker/workerchild.py +++ b/clients/python/src/taskbroker_client/worker/workerchild.py @@ -6,7 +6,7 @@ import queue import signal import time -from collections.abc import Callable, Generator +from collections.abc import Callable, Generator, Sequence from multiprocessing.synchronize import Event from types import FrameType from typing import Any @@ -30,7 +30,7 @@ from taskbroker_client.retry import NoRetriesRemainingError from taskbroker_client.state import clear_current_task, current_task, set_current_task from taskbroker_client.task import Task -from taskbroker_client.types import InflightTaskActivation, ProcessingResult +from taskbroker_client.types import ContextHook, InflightTaskActivation, ProcessingResult logger = logging.getLogger(__name__) @@ -226,7 +226,7 @@ def handle_alarm(signum: int, frame: FrameType | None) -> None: execution_start_time = time.time() try: with timeout_alarm(inflight.activation.processing_deadline_duration, handle_alarm): - _execute_activation(task_func, inflight.activation) + _execute_activation(task_func, inflight.activation, app.context_hooks) next_state = TASK_ACTIVATION_STATUS_COMPLETE except ProcessingDeadlineExceeded as err: with sentry_sdk.isolation_scope() as scope: @@ -308,7 +308,11 @@ def handle_alarm(signum: int, frame: FrameType | None) -> None: inflight.host, ) - def _execute_activation(task_func: Task[Any, Any], activation: TaskActivation) -> None: + def _execute_activation( + task_func: Task[Any, Any], + activation: TaskActivation, + context_hooks: Sequence[ContextHook] = (), + ) -> None: """Invoke a task function with the activation parameters.""" headers = {k: v for k, v in activation.headers.items()} parameters = load_parameters(activation.parameters, headers) @@ -363,7 +367,10 @@ def _execute_activation(task_func: Task[Any, Any], activation: TaskActivation) - kwargs.pop("__start_time") try: - task_func(*args, **kwargs) + with contextlib.ExitStack() as stack: + for hook in context_hooks: + stack.enter_context(hook.on_execute(headers)) + task_func(*args, **kwargs) transaction.set_status(SPANSTATUS.OK) except Exception: transaction.set_status(SPANSTATUS.INTERNAL_ERROR) From 98839137378493bc4e9400114a5e25ebd193c77e Mon Sep 17 00:00:00 2001 From: Greg Pstrucha <875316+gricha@users.noreply.github.com> Date: Fri, 3 Apr 2026 11:36:57 -0700 Subject: [PATCH 2/4] fix: Pass context_hooks to ExternalNamespace and fix formatting ExternalNamespace was missing context_hooks, so on_dispatch hooks would not fire for tasks dispatched to external applications. Also fixes black formatting issues. Co-Authored-By: Claude Opus 4.6 --- clients/python/src/taskbroker_client/registry.py | 3 ++- clients/python/src/taskbroker_client/task.py | 6 ++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/clients/python/src/taskbroker_client/registry.py b/clients/python/src/taskbroker_client/registry.py index 70f50e06..686b1c5a 100644 --- a/clients/python/src/taskbroker_client/registry.py +++ b/clients/python/src/taskbroker_client/registry.py @@ -177,7 +177,7 @@ def send_task(self, activation: TaskActivation, wait_for_delivery: bool = False) ) # We know this type is futures.Future, but cannot assert so, # because it is also mock.Mock in tests. - produce_future.add_done_callback( # type:ignore[union-attr] + produce_future.add_done_callback( # type: ignore[union-attr] lambda future: self._handle_produce_future( future=future, tags={ @@ -376,6 +376,7 @@ def create_external_namespace( retry=retry, expires=expires, processing_deadline_duration=processing_deadline_duration, + context_hooks=self._context_hooks, ) self._external_namespaces[key] = namespace return namespace diff --git a/clients/python/src/taskbroker_client/task.py b/clients/python/src/taskbroker_client/task.py index 5dfa9ead..8688bf1f 100644 --- a/clients/python/src/taskbroker_client/task.py +++ b/clients/python/src/taskbroker_client/task.py @@ -60,13 +60,11 @@ def __init__( processing_deadline_duration or DEFAULT_PROCESSING_DEADLINE ) if at_most_once and retry: - raise AssertionError( - """ + raise AssertionError(""" You cannot enable at_most_once and have retries defined. Having retries enabled means that a task supports being executed multiple times and thus cannot be idempotent. - """ - ) + """) self._retry = retry self.at_most_once = at_most_once self.wait_for_delivery = wait_for_delivery From d9d90acded785fcb7e5a8a45fe061c6f639fdeef Mon Sep 17 00:00:00 2001 From: Greg Pstrucha <875316+gricha@users.noreply.github.com> Date: Fri, 3 Apr 2026 11:42:57 -0700 Subject: [PATCH 3/4] fix: Widen on_dispatch header type to match MutableMapping[str, Any] create_activation passes headers as MutableMapping[str, Any], not dict[str, str]. Values are stringified after hooks run. Co-Authored-By: Claude Opus 4.6 --- clients/python/src/taskbroker_client/types.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/clients/python/src/taskbroker_client/types.py b/clients/python/src/taskbroker_client/types.py index ce00f5fd..c4a3783f 100644 --- a/clients/python/src/taskbroker_client/types.py +++ b/clients/python/src/taskbroker_client/types.py @@ -1,6 +1,7 @@ import contextlib import dataclasses -from typing import Callable, Protocol +from collections.abc import MutableMapping +from typing import Any, Callable, Protocol from arroyo.backends.abstract import ProducerFuture from arroyo.backends.kafka import KafkaPayload @@ -17,7 +18,7 @@ class ContextHook(Protocol): that restores context from headers for the duration of the task. """ - def on_dispatch(self, headers: dict[str, str]) -> None: ... + def on_dispatch(self, headers: MutableMapping[str, Any]) -> None: ... def on_execute(self, headers: dict[str, str]) -> contextlib.AbstractContextManager[None]: ... From 32fcf31e2366853b471c3e52436225ed983f5e96 Mon Sep 17 00:00:00 2001 From: Greg Pstrucha <875316+gricha@users.noreply.github.com> Date: Fri, 3 Apr 2026 11:46:36 -0700 Subject: [PATCH 4/4] test: Add tests for context hook dispatch and execution Tests verify: - on_dispatch injects headers during create_activation - No headers injected when no hooks registered - Multiple hooks all get called - on_execute receives activation headers during child_process execution Co-Authored-By: Claude Opus 4.6 --- clients/python/tests/test_task.py | 84 ++++++++++++++++++++++ clients/python/tests/worker/test_worker.py | 56 +++++++++++++++ 2 files changed, 140 insertions(+) diff --git a/clients/python/tests/test_task.py b/clients/python/tests/test_task.py index 264f03fc..c74c44a8 100644 --- a/clients/python/tests/test_task.py +++ b/clients/python/tests/test_task.py @@ -1,4 +1,6 @@ +import contextlib import datetime +from collections.abc import MutableMapping from typing import Any from unittest.mock import patch @@ -375,3 +377,85 @@ def with_parameters(one: str, two: int, org_id: int) -> None: assert "sentry-monitor-config" not in result assert "sentry-monitor-slug" in result assert "sentry-monitor-check-in-id" in result + + +class StubContextHook: + """Test hook that writes/reads a simple key.""" + + def on_dispatch(self, headers: MutableMapping[str, Any]) -> None: + headers["x-test-context"] = "dispatched" + + def on_execute(self, headers: dict[str, str]) -> contextlib.AbstractContextManager[None]: + if "x-test-context" not in headers: + return contextlib.nullcontext() + # Store the value so the test can verify it was called + StubContextHook.last_executed = headers["x-test-context"] + return contextlib.nullcontext() + + +def test_context_hook_on_dispatch() -> None: + """Context hooks inject headers during create_activation.""" + ns = TaskNamespace( + name="tests", + application="acme", + producer_factory=producer_factory, + router=DefaultRouter(), + metrics=NoOpMetricsBackend(), + retry=None, + context_hooks=[StubContextHook()], + ) + + @ns.register(name="test.hooked") + def hooked_task() -> None: + pass + + activation = hooked_task.create_activation([], {}) + assert activation.headers["x-test-context"] == "dispatched" + + +def test_context_hook_not_present_without_hooks() -> None: + """Without hooks, no extra headers are injected.""" + ns = TaskNamespace( + name="tests", + application="acme", + producer_factory=producer_factory, + router=DefaultRouter(), + metrics=NoOpMetricsBackend(), + retry=None, + ) + + @ns.register(name="test.no_hooks") + def no_hooks_task() -> None: + pass + + activation = no_hooks_task.create_activation([], {}) + assert "x-test-context" not in activation.headers + + +def test_context_hook_multiple_hooks() -> None: + """Multiple hooks all get called.""" + + class AnotherHook: + def on_dispatch(self, headers: MutableMapping[str, Any]) -> None: + headers["x-another"] = "also-here" + + def on_execute(self, headers: dict[str, str]) -> contextlib.AbstractContextManager[None]: + return contextlib.nullcontext() + + ns = TaskNamespace( + name="tests", + application="acme", + producer_factory=producer_factory, + router=DefaultRouter(), + metrics=NoOpMetricsBackend(), + retry=None, + context_hooks=[StubContextHook(), AnotherHook()], + ) + + @ns.register(name="test.multi_hooks") + def multi_task() -> None: + pass + + activation = multi_task.create_activation([], {}) + assert activation.headers["x-test-context"] == "dispatched" + assert activation.headers["x-another"] == "also-here" diff --git a/clients/python/tests/worker/test_worker.py b/clients/python/tests/worker/test_worker.py index 5e41361e..f5dbc5d6 100644 --- a/clients/python/tests/worker/test_worker.py +++ b/clients/python/tests/worker/test_worker.py @@ -1,4 +1,5 @@ import base64 +import contextlib import queue import time from multiprocessing import Event @@ -761,3 +762,58 @@ def test_child_process_decompression(mock_capture_checkin: mock.MagicMock) -> No assert result.task_id == COMPRESSED_TASK.activation.id assert result.status == TASK_ACTIVATION_STATUS_COMPLETE assert mock_capture_checkin.call_count == 0 + + +def test_child_process_context_hooks() -> None: + """Context hooks' on_execute is called with activation headers during task execution.""" + executed_headers: list[dict[str, str]] = [] + + class RecordingHook: + def on_dispatch(self, headers: dict[str, Any]) -> None: + pass + + def on_execute(self, headers: dict[str, str]) -> contextlib.AbstractContextManager[None]: + executed_headers.append(dict(headers)) + return contextlib.nullcontext() + + from examples.app import app + + hook = RecordingHook() + app.context_hooks.append(hook) + + try: + activation_with_headers = InflightTaskActivation( + host="localhost:50051", + receive_timestamp=0, + activation=TaskActivation( + id="hook-test", + taskname="examples.simple_task", + namespace="examples", + parameters='{"args": [], "kwargs": {}}', + headers={"x-viewer-org": "42", "x-viewer-user": "7"}, + processing_deadline_duration=5, + ), + ) + + todo: queue.Queue[InflightTaskActivation] = queue.Queue() + processed: queue.Queue[ProcessingResult] = queue.Queue() + shutdown = Event() + + todo.put(activation_with_headers) + child_process( + "examples.app:app", + todo, + processed, + shutdown, + max_task_count=1, + processing_pool_name="test", + process_type="fork", + ) + + result = processed.get() + assert result.status == TASK_ACTIVATION_STATUS_COMPLETE + assert len(executed_headers) == 1 + assert executed_headers[0]["x-viewer-org"] == "42" + assert executed_headers[0]["x-viewer-user"] == "7" + finally: + app.context_hooks.remove(hook)