diff --git a/src/push/mod.rs b/src/push/mod.rs index 60408130..55bae746 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -1,20 +1,38 @@ use std::sync::Arc; use std::time::{Duration, Instant}; -use anyhow::Result; +use anyhow::{Context, Result}; use elegant_departure::get_shutdown_guard; use flume::{Receiver, SendError, Sender}; +use hmac::{Hmac, Mac}; use prost::Message; use sentry_protos::taskbroker::v1::worker_service_client::WorkerServiceClient; use sentry_protos::taskbroker::v1::{PushTaskRequest, TaskActivation}; +use sha2::Sha256; use tokio::task::JoinSet; use tonic::async_trait; +use tonic::metadata::MetadataValue; use tonic::transport::Channel; use tracing::{debug, error, info}; use crate::config::Config; use crate::store::inflight_activation::InflightActivation; +type HmacSha256 = Hmac; + +/// gRPC path for `WorkerService::PushTask` — keep in sync with `sentry_protos` generated client. +const WORKER_PUSH_TASK_PATH: &str = "/sentry_protos.taskbroker.v1.WorkerService/PushTask"; + +/// HMAC-SHA256(secret, grpc_path + ":" + message), hex-encoded. Matches Python `RequestSignatureInterceptor` and broker [`crate::grpc::auth_middleware`]. +fn sentry_signature_hex(secret: &str, grpc_path: &str, message: &[u8]) -> String { + let mut mac = + HmacSha256::new_from_slice(secret.as_bytes()).expect("HMAC accepts keys of any length"); + mac.update(grpc_path.as_bytes()); + mac.update(b":"); + mac.update(message); + hex::encode(mac.finalize().into_bytes()) +} + /// Error returned when enqueueing an activation for the push workers fails. #[derive(Debug)] #[allow(clippy::large_enum_variant)] @@ -30,13 +48,32 @@ pub enum PushError { #[async_trait] trait WorkerClient { /// Send a single `PushTaskRequest` to the worker service. - async fn send(&mut self, request: PushTaskRequest) -> Result<()>; + /// + /// When `grpc_shared_secret` is non-empty, signs with `grpc_shared_secret[0]` and sets `sentry-signature` metadata (same scheme as Python pull client and broker `AuthLayer`). + async fn send(&mut self, request: PushTaskRequest, grpc_shared_secret: &[String]) + -> Result<()>; } #[async_trait] impl WorkerClient for WorkerServiceClient { - async fn send(&mut self, request: PushTaskRequest) -> Result<()> { - self.push_task(request).await?; + async fn send( + &mut self, + request: PushTaskRequest, + grpc_shared_secret: &[String], + ) -> Result<()> { + let mut req = tonic::Request::new(request); + + if let Some(secret) = grpc_shared_secret.first() { + let body = req.get_ref().encode_to_vec(); + let signature = sentry_signature_hex(secret, WORKER_PUSH_TASK_PATH, &body); + let value = MetadataValue::try_from(signature.as_str()) + .context("sentry-signature metadata value must be valid ASCII")?; + req.metadata_mut().insert("sentry-signature", value); + } + + self.push_task(req) + .await + .map_err(|status| anyhow::anyhow!(status))?; Ok(()) } } @@ -67,9 +104,8 @@ impl PushPool { /// Spawn `config.push_threads` asynchronous tasks, each of which repeatedly moves pending activations from the channel to the worker service until the shutdown signal is received. pub async fn start(&self) -> Result<()> { - let mut push_pool: JoinSet> = crate::tokio::spawn_pool( - self.config.push_threads, - |_| { + let mut push_pool: JoinSet> = + crate::tokio::spawn_pool(self.config.push_threads, |_| { let endpoint = self.config.worker_endpoint.clone(); let receiver = self.receiver.clone(); @@ -81,6 +117,7 @@ impl PushPool { ); let timeout = Duration::from_millis(self.config.push_timeout_ms); + let grpc_shared_secret = self.config.grpc_shared_secret.clone(); async move { let mut worker = match WorkerServiceClient::connect(endpoint).await { @@ -112,7 +149,15 @@ impl PushPool { let id = activation.id.clone(); let callback_url = callback_url.clone(); - match push_task(&mut worker, activation, callback_url, timeout).await { + match push_task( + &mut worker, + activation, + callback_url, + timeout, + grpc_shared_secret.as_slice(), + ) + .await + { Ok(_) => debug!(task_id = %id, "Activation sent to worker"), // Once processing deadline expires, status will be set back to pending @@ -131,7 +176,15 @@ impl PushPool { let id = activation.id.clone(); let callback_url = callback_url.clone(); - match push_task(&mut worker, activation, callback_url, timeout).await { + match push_task( + &mut worker, + activation, + callback_url, + timeout, + grpc_shared_secret.as_slice(), + ) + .await + { Ok(_) => debug!(task_id = %id, "Activation sent to worker"), // Once processing deadline expires, status will be set back to pending @@ -145,8 +198,7 @@ impl PushPool { Ok(()) } - }, - ); + }); while let Some(result) = push_pool.join_next().await { match result { @@ -185,6 +237,7 @@ async fn push_task( activation: InflightActivation, callback_url: String, timeout: Duration, + grpc_shared_secret: &[String], ) -> Result<()> { let start = Instant::now(); @@ -196,7 +249,8 @@ async fn push_task( callback_url, }; - let result = match tokio::time::timeout(timeout, worker.send(request)).await { + let result = match tokio::time::timeout(timeout, worker.send(request, grpc_shared_secret)).await + { Ok(r) => r, Err(e) => Err(e.into()), }; diff --git a/src/push/tests.rs b/src/push/tests.rs index b2db0ba1..5b79b95b 100644 --- a/src/push/tests.rs +++ b/src/push/tests.rs @@ -31,7 +31,11 @@ impl MockWorkerClient { #[async_trait] impl WorkerClient for MockWorkerClient { - async fn send(&mut self, request: PushTaskRequest) -> Result<()> { + async fn send( + &mut self, + request: PushTaskRequest, + _grpc_shared_secret: &[String], + ) -> Result<()> { self.captured_requests.push(request); if self.should_fail { @@ -53,6 +57,7 @@ async fn push_task_returns_ok_on_client_success() { activation.clone(), callback_url.clone(), Duration::from_secs(5), + &[], ) .await; assert!(result.is_ok(), "push_task should succeed"); @@ -77,6 +82,7 @@ async fn push_task_returns_err_on_invalid_payload() { activation, "taskbroker:50051".to_string(), Duration::from_secs(5), + &[], ) .await; @@ -97,6 +103,7 @@ async fn push_task_propagates_client_error() { activation, "taskbroker:50051".to_string(), Duration::from_secs(5), + &[], ) .await; assert!(result.is_err(), "worker send errors should propagate"); @@ -139,3 +146,12 @@ async fn push_pool_submit_backpressures_when_queue_full() { "second submit should block when queue is full" ); } + +#[test] +fn sentry_signature_hex_matches_hmac_contract() { + let digest = sentry_signature_hex("super secret", "/test/path", b"hello"); + assert_eq!( + digest, + "6408482d9e6d4975ada4c0302fda813c5718e571e6f9a2d6e2803cb48528044e" + ); +}