Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions clients/python/src/taskbroker_client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(
metrics_class: str | MetricsBackend = "taskbroker_client.metrics.NoOpMetricsBackend",
at_most_once_store: AtMostOnceStore | None = None,
) -> None:
self.name = name
self.metrics = self._build_metrics(metrics_class)
self._config = {
"rpc_secret": None,
Expand Down
6 changes: 5 additions & 1 deletion clients/python/src/taskbroker_client/worker/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ class TaskbrokerClient:
def __init__(
self,
hosts: list[str],
application: str,
metrics: MetricsBackend,
max_tasks_before_rebalance: int = DEFAULT_REBALANCE_AFTER,
max_consecutive_unavailable_errors: int = DEFAULT_CONSECUTIVE_UNAVAILABLE_ERRORS,
Expand All @@ -140,6 +141,7 @@ def __init__(
grpc_config: str | None = None,
) -> None:
assert len(hosts) > 0, "You must provide at least one RPC host to connect to"
self._application = application
self._hosts = hosts
self._rpc_secret = rpc_secret
self._metrics = metrics
Expand Down Expand Up @@ -259,7 +261,7 @@ def get_task(self, namespace: str | None = None) -> InflightTaskActivation | Non
"""
self._emit_health_check()

request = GetTaskRequest(namespace=namespace)
request = GetTaskRequest(application=self._application, namespace=namespace)
try:
host, stub = self._get_cur_stub()
with self._metrics.timer("taskworker.get_task.rpc", tags={"host": host}):
Expand Down Expand Up @@ -299,6 +301,8 @@ def update_task(
The return value is the next task that should be executed.
"""
self._emit_health_check()
if fetch_next_task is not None:
fetch_next_task.application = self._application

self._metrics.incr(
"taskworker.client.fetch_next", tags={"next": fetch_next_task is not None}
Expand Down
1 change: 1 addition & 0 deletions clients/python/src/taskbroker_client/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(

self.client = TaskbrokerClient(
hosts=broker_hosts,
application=app.name,
metrics=app.metrics,
max_tasks_before_rebalance=rebalance_after,
health_check_settings=(
Expand Down
99 changes: 74 additions & 25 deletions clients/python/tests/worker/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
TASK_ACTIVATION_STATUS_COMPLETE,
TASK_ACTIVATION_STATUS_RETRY,
FetchNextTask,
GetTaskRequest,
GetTaskResponse,
SetTaskStatusRequest,
SetTaskStatusResponse,
TaskActivation,
)
Expand Down Expand Up @@ -61,6 +63,8 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:

if isinstance(res.response, Exception):
raise res.response
elif callable(res.response):
return res.response(*args, **kwargs)
return res.response

def with_call(self, *args: Any, **kwargs: Any) -> Any:
Expand Down Expand Up @@ -91,7 +95,7 @@ def unary_unary(
def add_response(
self,
path: str,
resp: Message | Exception,
resp: Callable[[Any], Any] | Message | Exception,
metadata: tuple[tuple[str, str | bytes], ...] | None = None,
) -> None:
self._responses[path].append(MockServiceCall(response=resp, metadata=metadata))
Expand Down Expand Up @@ -130,7 +134,7 @@ def test_make_broker_hosts() -> None:

def test_init_no_hosts() -> None:
with pytest.raises(AssertionError) as err:
TaskbrokerClient(hosts=[], metrics=NoOpMetricsBackend())
TaskbrokerClient(hosts=[], application="sentry", metrics=NoOpMetricsBackend())
assert "You must provide at least one RPC host" in str(err)


Expand Down Expand Up @@ -167,6 +171,7 @@ def test_health_check_is_debounced() -> None:
health_check_path = Path(f"/tmp/{''.join(random.choices(string.ascii_letters, k=16))}")
client = TaskbrokerClient(
hosts=["localhost-0:50051"],
application="sentry",
metrics=NoOpMetricsBackend(),
health_check_settings=HealthCheckSettings(health_check_path, 1),
)
Expand All @@ -183,10 +188,11 @@ def test_health_check_is_debounced() -> None:


def test_get_task_ok() -> None:
channel = MockChannel()
channel.add_response(
"/sentry_protos.taskbroker.v1.ConsumerService/GetTask",
GetTaskResponse(
def get_task_response(request: GetTaskRequest) -> GetTaskResponse:
assert request.application == "sentry"
assert request.namespace == ""

return GetTaskResponse(
task=TaskActivation(
id="abc123",
namespace="testing",
Expand All @@ -195,11 +201,18 @@ def test_get_task_ok() -> None:
headers={},
processing_deadline_duration=10,
)
),
)

channel = MockChannel()
channel.add_response(
"/sentry_protos.taskbroker.v1.ConsumerService/GetTask",
get_task_response,
)
with patch("taskbroker_client.worker.client.grpc.insecure_channel") as mock_channel:
mock_channel.return_value = channel
client = TaskbrokerClient(["localhost-0:50051"], metrics=NoOpMetricsBackend())
client = TaskbrokerClient(
hosts=["localhost-0:50051"], application="sentry", metrics=NoOpMetricsBackend()
)
result = client.get_task()

assert result
Expand Down Expand Up @@ -228,7 +241,8 @@ def test_get_task_writes_to_health_check_file() -> None:
mock_channel.return_value = channel
health_check_path = Path(f"/tmp/{''.join(random.choices(string.ascii_letters, k=16))}")
client = TaskbrokerClient(
["localhost-0:50051"],
hosts=["localhost-0:50051"],
application="sentry",
metrics=NoOpMetricsBackend(),
health_check_settings=HealthCheckSettings(health_check_path, 3),
)
Expand All @@ -253,15 +267,18 @@ def test_get_task_with_interceptor() -> None:
metadata=(
(
"sentry-signature",
"3202702605c1b65055c28e7c78a5835e760830cff3e9f995eb7ad5f837130b1f",
"556b2e74f2d5a1d0134b1f803c9bfaa8467bbd8e4cb510a9856c5e2ef2b66a21",
),
),
)
secret = '["a long secret value","notused"]'
with patch("taskbroker_client.worker.client.grpc.insecure_channel") as mock_channel:
mock_channel.return_value = channel
client = TaskbrokerClient(
["localhost-0:50051"], metrics=NoOpMetricsBackend(), rpc_secret=secret
hosts=["localhost-0:50051"],
application="sentry",
metrics=NoOpMetricsBackend(),
rpc_secret=secret,
)
result = client.get_task()

Expand Down Expand Up @@ -289,7 +306,9 @@ def test_get_task_with_namespace() -> None:
with patch("taskbroker_client.worker.client.grpc.insecure_channel") as mock_channel:
mock_channel.return_value = channel
client = TaskbrokerClient(
hosts=make_broker_hosts("localhost:50051", num_brokers=1), metrics=NoOpMetricsBackend()
hosts=make_broker_hosts("localhost:50051", num_brokers=1),
application="sentry",
metrics=NoOpMetricsBackend(),
)
result = client.get_task(namespace="testing")

Expand All @@ -307,7 +326,9 @@ def test_get_task_not_found() -> None:
)
with patch("taskbroker_client.worker.client.grpc.insecure_channel") as mock_channel:
mock_channel.return_value = channel
client = TaskbrokerClient(["localhost:50051"], metrics=NoOpMetricsBackend())
client = TaskbrokerClient(
hosts=["localhost:50051"], application="sentry", metrics=NoOpMetricsBackend()
)
result = client.get_task()

assert result is None
Expand All @@ -321,7 +342,9 @@ def test_get_task_failure() -> None:
)
with patch("taskbroker_client.worker.client.grpc.insecure_channel") as mock_channel:
mock_channel.return_value = channel
client = TaskbrokerClient(["localhost:50051"], metrics=NoOpMetricsBackend())
client = TaskbrokerClient(
hosts=["localhost:50051"], application="name", metrics=NoOpMetricsBackend()
)
with pytest.raises(grpc.RpcError):
client.get_task()

Expand All @@ -346,6 +369,7 @@ def test_update_task_writes_to_health_check_file() -> None:
health_check_path = Path(f"/tmp/{''.join(random.choices(string.ascii_letters, k=16))}")
client = TaskbrokerClient(
make_broker_hosts("localhost:50051", num_brokers=1),
application="sentry",
metrics=NoOpMetricsBackend(),
health_check_settings=HealthCheckSettings(
health_check_path, DEFAULT_WORKER_HEALTH_CHECK_SEC_PER_TOUCH
Expand All @@ -359,10 +383,12 @@ def test_update_task_writes_to_health_check_file() -> None:


def test_update_task_ok_with_next() -> None:
channel = MockChannel()
channel.add_response(
"/sentry_protos.taskbroker.v1.ConsumerService/SetTaskStatus",
SetTaskStatusResponse(
def update_task_response(request: SetTaskStatusRequest) -> SetTaskStatusResponse:
assert request.fetch_next_task
assert request.fetch_next_task.application == "sentry"
assert request.fetch_next_task.namespace == ""

return SetTaskStatusResponse(
task=TaskActivation(
id="abc123",
namespace="testing",
Expand All @@ -371,12 +397,19 @@ def test_update_task_ok_with_next() -> None:
headers={},
processing_deadline_duration=10,
)
),
)

channel = MockChannel()
channel.add_response(
"/sentry_protos.taskbroker.v1.ConsumerService/SetTaskStatus",
update_task_response,
)
with patch("taskbroker_client.worker.client.grpc.insecure_channel") as mock_channel:
mock_channel.return_value = channel
client = TaskbrokerClient(
make_broker_hosts("localhost:50051", num_brokers=1), metrics=NoOpMetricsBackend()
make_broker_hosts("localhost:50051", num_brokers=1),
application="sentry",
metrics=NoOpMetricsBackend(),
)
assert set(client._host_to_stubs.keys()) == {"localhost-0:50051"}
result = client.update_task(
Expand Down Expand Up @@ -407,7 +440,9 @@ def test_update_task_ok_with_next_namespace() -> None:
with patch("taskbroker_client.worker.client.grpc.insecure_channel") as mock_channel:
mock_channel.return_value = channel
client = TaskbrokerClient(
make_broker_hosts("localhost:50051", num_brokers=1), metrics=NoOpMetricsBackend()
make_broker_hosts("localhost:50051", num_brokers=1),
application="sentry",
metrics=NoOpMetricsBackend(),
)
result = client.update_task(
ProcessingResult(
Expand All @@ -431,7 +466,9 @@ def test_update_task_ok_no_next() -> None:
with patch("taskbroker_client.worker.client.grpc.insecure_channel") as mock_channel:
mock_channel.return_value = channel
client = TaskbrokerClient(
make_broker_hosts("localhost:50051", num_brokers=1), metrics=NoOpMetricsBackend()
make_broker_hosts("localhost:50051", num_brokers=1),
application="sentry",
metrics=NoOpMetricsBackend(),
)
result = client.update_task(
ProcessingResult(
Expand All @@ -453,7 +490,9 @@ def test_update_task_not_found() -> None:
)
with patch("taskbroker_client.worker.client.grpc.insecure_channel") as mock_channel:
mock_channel.return_value = channel
client = TaskbrokerClient(["localhost-0:50051"], metrics=NoOpMetricsBackend())
client = TaskbrokerClient(
["localhost-0:50051"], application="sentry", metrics=NoOpMetricsBackend()
)
result = client.update_task(
ProcessingResult(
task_id="abc123",
Expand All @@ -474,7 +513,9 @@ def test_update_task_unavailable_retain_task_to_host() -> None:
)
with patch("taskbroker_client.worker.client.grpc.insecure_channel") as mock_channel:
mock_channel.return_value = channel
client = TaskbrokerClient(["localhost-0:50051"], metrics=NoOpMetricsBackend())
client = TaskbrokerClient(
["localhost-0:50051"], application="sentry", metrics=NoOpMetricsBackend()
)
with pytest.raises(MockGrpcError) as err:
client.update_task(
ProcessingResult(
Expand Down Expand Up @@ -572,6 +613,7 @@ def test_client_loadbalance() -> None:
]
client = TaskbrokerClient(
hosts=make_broker_hosts(host_prefix="localhost:50051", num_brokers=4),
application="sentry",
metrics=NoOpMetricsBackend(),
max_tasks_before_rebalance=1,
)
Expand Down Expand Up @@ -662,6 +704,7 @@ def test_client_loadbalance_on_notfound() -> None:
]
client = TaskbrokerClient(
hosts=make_broker_hosts(host_prefix="localhost:50051", num_brokers=3),
application="sentry",
metrics=NoOpMetricsBackend(),
max_tasks_before_rebalance=30,
)
Expand Down Expand Up @@ -727,6 +770,7 @@ def test_client_loadbalance_on_unavailable() -> None:
]
client = TaskbrokerClient(
hosts=make_broker_hosts(host_prefix="localhost:50051", num_brokers=2),
application="sentry",
metrics=NoOpMetricsBackend(),
max_consecutive_unavailable_errors=3,
)
Expand Down Expand Up @@ -784,6 +828,7 @@ def test_client_single_host_unavailable() -> None:
mock_channel.return_value = channel
client = TaskbrokerClient(
hosts=["localhost-0:50051"],
application="sentry",
metrics=NoOpMetricsBackend(),
max_consecutive_unavailable_errors=3,
temporary_unavailable_host_timeout=2,
Expand Down Expand Up @@ -829,7 +874,10 @@ def test_client_reset_errors_after_success() -> None:
with patch("taskbroker_client.worker.client.grpc.insecure_channel") as mock_channel:
mock_channel.return_value = channel
client = TaskbrokerClient(
["localhost:50051"], metrics=NoOpMetricsBackend(), max_consecutive_unavailable_errors=3
["localhost:50051"],
application="sentry",
metrics=NoOpMetricsBackend(),
max_consecutive_unavailable_errors=3,
)

with pytest.raises(grpc.RpcError, match="host is unavailable"):
Expand Down Expand Up @@ -885,6 +933,7 @@ def mock_time() -> float:
mock_channel.return_value = channel
client = TaskbrokerClient(
["localhost:50051"],
application="sentry",
metrics=NoOpMetricsBackend(),
max_consecutive_unavailable_errors=3,
temporary_unavailable_host_timeout=10,
Expand Down
Loading