Skip to content

Commit b8eafbc

Browse files
authored
Merge pull request #31 from taskbadger/sk/record-args
record task args
2 parents d93ae1d + dc0ca58 commit b8eafbc

File tree

4 files changed

+198
-2
lines changed

4 files changed

+198
-2
lines changed

taskbadger/celery.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import collections
22
import functools
3+
import json
34
import logging
45

56
import celery
@@ -10,6 +11,7 @@
1011
task_retry,
1112
task_success,
1213
)
14+
from kombu import serialization
1315

1416
from .internal.models import StatusEnum
1517
from .mug import Badger
@@ -18,7 +20,7 @@
1820

1921
KWARG_PREFIX = "taskbadger_"
2022
TB_KWARGS_ARG = f"{KWARG_PREFIX}kwargs"
21-
IGNORE_ARGS = {TB_KWARGS_ARG, f"{KWARG_PREFIX}task", f"{KWARG_PREFIX}task_id"}
23+
IGNORE_ARGS = {TB_KWARGS_ARG, f"{KWARG_PREFIX}task", f"{KWARG_PREFIX}task_id", f"{KWARG_PREFIX}record_task_args"}
2224
TB_TASK_ID = f"{KWARG_PREFIX}task_id"
2325

2426
TERMINAL_STATES = {
@@ -124,6 +126,8 @@ def apply_async(self, *args, **kwargs):
124126
if Badger.is_configured():
125127
headers["taskbadger_track"] = True
126128
headers[TB_KWARGS_ARG] = tb_kwargs
129+
if "record_task_args" in tb_kwargs:
130+
headers["taskbadger_record_task_args"] = tb_kwargs.pop("record_task_args")
127131

128132
result = super().apply_async(*args, **kwargs)
129133

@@ -187,6 +191,20 @@ def task_publish_handler(sender=None, headers=None, body=None, **kwargs):
187191
kwargs["status"] = StatusEnum.PENDING
188192
name = kwargs.pop("name", headers["task"])
189193

194+
global_record_task_args = celery_system and celery_system.record_task_args
195+
if headers.get("taskbadger_record_task_args", global_record_task_args):
196+
data = {
197+
"celery_task_args": body[0],
198+
"celery_task_kwargs": body[1],
199+
}
200+
try:
201+
_, _, value = serialization.dumps(data, serializer="json")
202+
data = json.loads(value)
203+
except Exception:
204+
log.error("Error serializing task arguments for task '%s'", name)
205+
else:
206+
kwargs.setdefault("data", {}).update(data)
207+
190208
task = create_task_safe(name, **kwargs)
191209
if task:
192210
meta = {TB_TASK_ID: task.id}

taskbadger/systems/celery.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
class CelerySystemIntegration(System):
77
identifier = "celery"
88

9-
def __init__(self, auto_track_tasks=True, includes=None, excludes=None):
9+
def __init__(self, auto_track_tasks=True, includes=None, excludes=None, record_task_args=False):
1010
"""
1111
Args:
1212
auto_track_tasks: Automatically track all Celery tasks regardless of whether they are using the
@@ -16,10 +16,12 @@ def __init__(self, auto_track_tasks=True, includes=None, excludes=None):
1616
matches both an include and an exclude, it will be excluded.
1717
excludes: A list of task names to exclude from tracking. As with `includes`, these can be either
1818
the full task name or a regular expression. Exclusions take precedence over inclusions.
19+
record_task_args: Record the arguments passed to each task.
1920
"""
2021
self.auto_track_tasks = auto_track_tasks
2122
self.includes = includes
2223
self.excludes = excludes
24+
self.record_task_args = record_task_args
2325

2426
if auto_track_tasks:
2527
# Importing this here ensures that the Celery signal handlers are registered

tests/test_celery.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import celery
1515
import pytest
16+
from kombu.utils.json import register_type
1617

1718
from taskbadger import Action, EmailIntegration, StatusEnum
1819
from taskbadger.celery import Task
@@ -111,6 +112,109 @@ def add_with_task_args(self, a, b):
111112
create.assert_called_once_with("new_name", value_max=10, actions=actions, status=StatusEnum.PENDING)
112113

113114

115+
def test_celery_record_args(celery_session_app, celery_session_worker, bind_settings):
116+
@celery_session_app.task(bind=True, base=Task)
117+
def add_with_task_args(self, a, b):
118+
assert self.taskbadger_task is not None
119+
return a + b
120+
121+
celery_session_worker.reload()
122+
123+
with (
124+
mock.patch("taskbadger.celery.create_task_safe") as create,
125+
mock.patch("taskbadger.celery.update_task_safe"),
126+
mock.patch("taskbadger.celery.get_task"),
127+
):
128+
create.return_value = task_for_test()
129+
130+
result = add_with_task_args.apply_async(
131+
(2, 2),
132+
taskbadger_name="new_name",
133+
taskbadger_value_max=10,
134+
taskbadger_kwargs={"data": {"foo": "bar"}},
135+
taskbadger_record_task_args=True,
136+
)
137+
assert result.get(timeout=10, propagate=True) == 4
138+
139+
create.assert_called_once_with(
140+
"new_name",
141+
value_max=10,
142+
data={"foo": "bar", "celery_task_args": [2, 2], "celery_task_kwargs": {}},
143+
status=StatusEnum.PENDING,
144+
)
145+
146+
147+
def test_celery_record_task_kwargs(celery_session_app, celery_session_worker, bind_settings):
148+
@celery_session_app.task(bind=True, base=Task)
149+
def add_with_task_kwargs(self, a, b, c=0):
150+
assert self.taskbadger_task is not None
151+
return a + b + c
152+
153+
celery_session_worker.reload()
154+
155+
with (
156+
mock.patch("taskbadger.celery.create_task_safe") as create,
157+
mock.patch("taskbadger.celery.update_task_safe"),
158+
mock.patch("taskbadger.celery.get_task"),
159+
):
160+
create.return_value = task_for_test()
161+
162+
actions = [Action("stale", integration=EmailIntegration(to="test@test.com"))]
163+
result = add_with_task_kwargs.delay(
164+
2,
165+
2,
166+
c=3,
167+
taskbadger_name="new_name",
168+
taskbadger_value_max=10,
169+
taskbadger_kwargs={"actions": actions},
170+
taskbadger_record_task_args=True,
171+
)
172+
assert result.get(timeout=10, propagate=True) == 7
173+
174+
create.assert_called_once_with(
175+
"new_name",
176+
value_max=10,
177+
data={"celery_task_args": [2, 2], "celery_task_kwargs": {"c": 3}},
178+
actions=actions,
179+
status=StatusEnum.PENDING,
180+
)
181+
182+
183+
def test_celery_record_task_args_custom_serialization(celery_session_app, celery_session_worker, bind_settings):
184+
class A:
185+
def __init__(self, a, b):
186+
self.a = a
187+
self.b = b
188+
189+
register_type(A, "A", lambda o: [o.a, o.b], lambda o: A(*o))
190+
191+
@celery_session_app.task(bind=True, base=Task)
192+
def add_task_custom_serialization(self, a):
193+
assert self.taskbadger_task is not None
194+
return a.a + a.b
195+
196+
celery_session_worker.reload()
197+
198+
with (
199+
mock.patch("taskbadger.celery.create_task_safe") as create,
200+
mock.patch("taskbadger.celery.update_task_safe"),
201+
mock.patch("taskbadger.celery.get_task"),
202+
):
203+
create.return_value = task_for_test()
204+
205+
result = add_task_custom_serialization.delay(
206+
A(2, 2),
207+
taskbadger_record_task_args=True,
208+
)
209+
assert result.get(timeout=10, propagate=True) == 4
210+
211+
create.assert_called_once_with(
212+
"tests.test_celery.add_task_custom_serialization",
213+
data={"celery_task_args": [{"__type__": "A", "__value__": [2, 2]}], "celery_task_kwargs": {}},
214+
status=StatusEnum.PENDING,
215+
)
216+
217+
114218
def test_celery_task_with_args_in_decorator(celery_session_app, celery_session_worker, bind_settings):
115219
@celery_session_app.task(
116220
bind=True,

tests/test_celery_system_integration.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import pytest
1717
from celery.signals import task_prerun
1818

19+
from taskbadger import StatusEnum
20+
from taskbadger.celery import Task
1921
from taskbadger.mug import Badger, Settings
2022
from taskbadger.systems.celery import CelerySystemIntegration
2123
from tests.utils import task_for_test
@@ -73,6 +75,76 @@ def add_normal(self, a, b):
7375
assert Badger.current.session().client is None
7476

7577

78+
@pytest.mark.usefixtures("_bind_settings_with_system")
79+
def test_celery_record_task_args(celery_session_app, celery_session_worker):
80+
@celery_session_app.task(bind=True)
81+
def add_normal(self, a, b):
82+
assert self.request.get("taskbadger_task_id") is not None, "missing task in request"
83+
assert not hasattr(self, "taskbadger_task")
84+
assert Badger.current.session().client is not None, "missing client"
85+
return a + b
86+
87+
celery_session_worker.reload()
88+
89+
celery_system = Badger.current.settings.get_system_by_id("celery")
90+
celery_system.record_task_args = True
91+
92+
with (
93+
mock.patch("taskbadger.celery.create_task_safe") as create,
94+
mock.patch("taskbadger.celery.update_task_safe") as update,
95+
mock.patch("taskbadger.celery.get_task") as get_task,
96+
):
97+
tb_task = task_for_test()
98+
create.return_value = tb_task
99+
result = add_normal.delay(2, 2)
100+
assert result.info.get("taskbadger_task_id") == tb_task.id
101+
assert result.get(timeout=10, propagate=True) == 4
102+
103+
create.assert_called_once_with(
104+
"tests.test_celery_system_integration.add_normal",
105+
status=StatusEnum.PENDING,
106+
data={"celery_task_args": [2, 2], "celery_task_kwargs": {}},
107+
)
108+
assert get_task.call_count == 1
109+
assert update.call_count == 2
110+
assert Badger.current.session().client is None
111+
112+
113+
@pytest.mark.usefixtures("_bind_settings_with_system")
114+
def test_celery_record_task_args_local_override(celery_session_app, celery_session_worker):
115+
"""Test that passing `taskbadger_record_task_args` overrides the integration value"""
116+
117+
@celery_session_app.task(bind=True, base=Task)
118+
def add_normal_with_override(self, a, b):
119+
assert self.request.get("taskbadger_task_id") is not None, "missing task in request"
120+
assert hasattr(self, "taskbadger_task")
121+
assert Badger.current.session().client is not None, "missing client"
122+
return a + b
123+
124+
celery_session_worker.reload()
125+
126+
celery_system = Badger.current.settings.get_system_by_id("celery")
127+
celery_system.record_task_args = True
128+
129+
with (
130+
mock.patch("taskbadger.celery.create_task_safe") as create,
131+
mock.patch("taskbadger.celery.update_task_safe") as update,
132+
mock.patch("taskbadger.celery.get_task") as get_task,
133+
):
134+
tb_task = task_for_test()
135+
create.return_value = tb_task
136+
result = add_normal_with_override.delay(2, 2, taskbadger_record_task_args=False)
137+
assert result.info.get("taskbadger_task_id") == tb_task.id
138+
assert result.get(timeout=10, propagate=True) == 4
139+
140+
create.assert_called_once_with(
141+
"tests.test_celery_system_integration.add_normal_with_override", status=StatusEnum.PENDING
142+
)
143+
assert get_task.call_count == 1
144+
assert update.call_count == 2
145+
assert Badger.current.session().client is None
146+
147+
76148
@pytest.mark.parametrize(
77149
("include", "exclude", "expected"),
78150
[

0 commit comments

Comments
 (0)