From f70bfda92f7380aa5d147511950593c5a4d20b57 Mon Sep 17 00:00:00 2001
From: Evan Hicks
Date: Tue, 4 Nov 2025 15:59:33 -0500
Subject: [PATCH 01/17] feat(v2): Postgres storage adapter

This adds a Postgres storage adapter for the taskbroker, along with a
configuration option for choosing which adapter to use. The adapter also
works with AlloyDB.

In Postgres the keyword `offset` is reserved, so the column is named
`kafka_offset` in the Postgres tables and aliased back to `offset` in
queries.

The tests were updated to run with both the SQLite and Postgres adapters
using the rstest crate. The `create_test_store` function was updated to be
the standard for all tests, and to allow choosing between a SQLite and a
Postgres DB. A `remove_db` function was added to the trait and the existing
adapters, since the tests create a unique Postgres database on every run
that must be cleaned up.
---
 Cargo.lock                                  |  49 ++
 Cargo.toml                                  |   3 +-
 Dockerfile                                  |   3 +-
 default_migrations/0001_create_database.sql |   1 +
 .../0001_create_inflight_activations.sql    |  20 +
 src/config.rs                               |  11 +
 src/grpc/server_tests.rs                    |  73 +-
 src/kafka/deserialize_activation.rs         |   2 +-
 src/kafka/inflight_activation_writer.rs     | 157 ++--
 src/main.rs                                 |  26 +-
 src/store/inflight_activation.rs            |  97 ++-
 src/store/inflight_activation_tests.rs      | 375 ++++++---
 src/store/mod.rs                            |   1 +
 src/store/postgres_activation_store.rs      | 732 ++++++++++++++++++
 src/test_utils.rs                           |  71 +-
 src/upkeep.rs                               | 127 +--
 16 files changed, 1431 insertions(+), 317 deletions(-)
 create mode 100644 default_migrations/0001_create_database.sql
 create mode 100644 pg_migrations/0001_create_inflight_activations.sql
 create mode 100644 src/store/postgres_activation_store.rs

diff --git a/Cargo.lock b/Cargo.lock
index b933e361..c93a4d1f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -833,6 +833,12 @@ version = "0.3.31"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
 
+[[package]]
+name = "futures-timer"
+version = "3.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
+
 [[package]]
 name = "futures-util"
 version = "0.3.31"
@@ -890,6 +896,12 @@ version = "0.31.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
 
+[[package]]
+name = "glob"
+version = "0.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280"
+
 [[package]]
 name = "h2"
 version = "0.4.12"
@@ -2133,6 +2145,12 @@ version = "0.8.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "caf4aa5b0f434c91fe5c7f1ecb6a5ece2130b02ad2a590589dda5146df959001"
 
+[[package]]
+name = "relative-path"
+version = "1.9.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2"
+
 [[package]]
 name = "reqwest"
 version = "0.12.23"
@@ -2191,6 +2209,36 @@ dependencies = [
  "zeroize",
 ]
 
+[[package]]
+name = "rstest"
+version = "0.23.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0a2c585be59b6b5dd66a9d2084aa1d8bd52fbdb806eafdeffb52791147862035"
+dependencies = [
+ "futures",
+ "futures-timer",
+ "rstest_macros",
+ "rustc_version",
+]
+
+[[package]]
+name = "rstest_macros"
+version = "0.23.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "825ea780781b15345a146be27eaefb05085e337e869bff01b4306a4fd4a9ad5a"
+dependencies = [
+ "cfg-if",
+ "glob",
+ "proc-macro-crate",
+ "proc-macro2",
+ "quote",
+ "regex",
+ "relative-path",
+ "rustc_version",
+ "syn",
+ "unicode-ident",
+]
+
 [[package]]
 name = "rustc-demangle"
 version = "0.1.26"
@@ -2884,6 +2932,7 @@ dependencies = [
  "prost-types",
  "rand 0.8.5",
  "rdkafka",
+ "rstest",
  "sentry",
  "sentry_protos",
  "serde",
diff --git a/Cargo.toml b/Cargo.toml
index d785ec99..5d928197 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -44,7 +44,7 @@ sentry_protos = "0.4.11"
 serde = "1.0.214"
 serde_yaml = "0.9.34"
 sha2 = "0.10.8"
-sqlx = { version = "0.8.3", features = ["sqlite", "runtime-tokio", "chrono"] }
+sqlx = { version = "0.8.3", features = ["sqlite", "runtime-tokio", "chrono", "postgres"] }
 tokio = { version = "1.43.1", features = ["full"] }
 tokio-stream = { version = "0.1.16", features = ["full"] }
 tokio-util = "0.7.12"
@@ -61,6 +61,7 @@ uuid = { version = "1.11.0", features = ["v4"] }
 
 [dev-dependencies]
 criterion = { version = "0.5.1", features = ["async_tokio"] }
+rstest = "0.23"
 
 [[bench]]
 name = "store_bench"
diff --git a/Dockerfile b/Dockerfile
index 125447c4..700da096 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -3,7 +3,7 @@
 # recent enough version of protobuf-compiler
 FROM rust:1-bookworm AS build
 
-RUN apt-get update && apt-get upgrade -y
+RUN apt-get update && apt-get upgrade -y
 RUN apt-get install -y cmake pkg-config libssl-dev librdkafka-dev protobuf-compiler
 
 RUN USER=root cargo new --bin taskbroker
@@ -17,6 +17,7 @@ ENV TASKBROKER_VERSION=$TASKBROKER_GIT_REVISION
 COPY ./Cargo.lock ./Cargo.lock
 COPY ./Cargo.toml ./Cargo.toml
 COPY ./migrations ./migrations
+COPY ./pg_migrations ./pg_migrations
 COPY ./benches ./benches
 
 # Build dependencies in a way they can be cached
diff --git a/default_migrations/0001_create_database.sql b/default_migrations/0001_create_database.sql
new file mode 100644
index 00000000..00d61748
--- /dev/null
+++ b/default_migrations/0001_create_database.sql
@@ -0,0 +1 @@
+CREATE DATABASE taskbroker;
diff --git a/pg_migrations/0001_create_inflight_activations.sql b/pg_migrations/0001_create_inflight_activations.sql
new file mode 100644
index 00000000..80b552db
--- /dev/null
+++ b/pg_migrations/0001_create_inflight_activations.sql
@@ -0,0 +1,20 @@
+-- PostgreSQL equivalent of the inflight_taskactivations table
+CREATE TABLE IF NOT EXISTS inflight_taskactivations (
+    id TEXT NOT NULL PRIMARY KEY,
+    activation BYTEA NOT NULL,
+    partition INTEGER NOT NULL,
+    kafka_offset BIGINT NOT NULL,
+    added_at TIMESTAMPTZ NOT NULL,
+    received_at TIMESTAMPTZ NOT NULL,
+    processing_attempts INTEGER NOT NULL,
+    expires_at TIMESTAMPTZ,
+    delay_until TIMESTAMPTZ,
+    processing_deadline_duration INTEGER NOT NULL,
+    processing_deadline TIMESTAMPTZ,
+    status TEXT NOT NULL,
+    at_most_once BOOLEAN NOT NULL DEFAULT FALSE,
+    application TEXT NOT NULL DEFAULT '',
+    namespace TEXT NOT NULL,
+    taskname TEXT NOT NULL,
+    on_attempts_exceeded INTEGER NOT NULL DEFAULT 1
+);
diff --git a/src/config.rs b/src/config.rs
index d4bbaf2b..af1d2fc1 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -121,6 +121,14 @@ pub struct Config {
     /// The number of ms for timeouts when publishing messages to kafka.
     pub kafka_send_timeout_ms: u64,
 
+    /// The database adapter to use: "sqlite" or "postgres".
+    pub database_adapter: &'static str,
+    /// The URL of the postgres database to use for the inflight activation store.
+ pub pg_url: String, + + /// The name of the postgres database to use for the inflight activation store. + pub pg_database_name: String, + /// The path to the sqlite database pub db_path: String, @@ -256,6 +264,9 @@ impl Default for Config { kafka_auto_offset_reset: "latest".to_owned(), kafka_send_timeout_ms: 500, db_path: "./taskbroker-inflight.sqlite".to_owned(), + database_adapter: "sqlite", + pg_url: "postgres://postgres:password@sentry-postgres-1:5432/".to_owned(), + pg_database_name: "taskbroker".to_owned(), db_write_failure_backoff_ms: 4000, db_insert_batch_max_len: 256, db_insert_batch_max_size: 16_000_000, diff --git a/src/grpc/server_tests.rs b/src/grpc/server_tests.rs index 1c9c9279..c72b1b50 100644 --- a/src/grpc/server_tests.rs +++ b/src/grpc/server_tests.rs @@ -1,6 +1,6 @@ use crate::grpc::server::TaskbrokerServer; -use crate::store::inflight_activation::InflightActivationStore; use prost::Message; +use rstest::rstest; use sentry_protos::taskbroker::v1::consumer_service_server::ConsumerService; use sentry_protos::taskbroker::v1::{ FetchNextTask, GetTaskRequest, SetTaskStatusRequest, TaskActivation, @@ -10,8 +10,11 @@ use tonic::{Code, Request}; use crate::test_utils::{create_test_store, make_activations}; #[tokio::test] -async fn test_get_task() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_get_task(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let service = TaskbrokerServer { store }; let request = GetTaskRequest { namespace: None, @@ -25,9 +28,12 @@ async fn test_get_task() { } #[tokio::test] +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] #[allow(deprecated)] -async fn test_set_task_status() { - let store = create_test_store().await; +async fn test_set_task_status(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let service = TaskbrokerServer { store }; let request = SetTaskStatusRequest { id: "test_task".to_string(), @@ -41,9 +47,12 @@ async fn test_set_task_status() { } #[tokio::test] +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] #[allow(deprecated)] -async fn test_set_task_status_invalid() { - let store = create_test_store().await; +async fn test_set_task_status_invalid(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let service = TaskbrokerServer { store }; let request = SetTaskStatusRequest { id: "test_task".to_string(), @@ -61,9 +70,12 @@ async fn test_set_task_status_invalid() { } #[tokio::test] +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] #[allow(deprecated)] -async fn test_get_task_success() { - let store = create_test_store().await; +async fn test_get_task_success(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let activations = make_activations(1); store.store(activations).await.unwrap(); @@ -81,9 +93,12 @@ async fn test_get_task_success() { } #[tokio::test] +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] #[allow(deprecated)] -async fn test_get_task_with_application_success() { - let store = create_test_store().await; +async fn test_get_task_with_application_success(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut activations = make_activations(2); let mut payload = TaskActivation::decode(&activations[1].activation as &[u8]).unwrap(); @@ -108,9 +123,12 @@ async fn test_get_task_with_application_success() { } #[tokio::test] +#[rstest] +#[case::sqlite("sqlite")] 
+#[case::postgres("postgres")] #[allow(deprecated)] -async fn test_get_task_with_namespace_requires_application() { - let store = create_test_store().await; +async fn test_get_task_with_namespace_requires_application(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let activations = make_activations(2); let namespace = activations[0].namespace.clone(); @@ -129,9 +147,12 @@ async fn test_get_task_with_namespace_requires_application() { } #[tokio::test] +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] #[allow(deprecated)] -async fn test_set_task_status_success() { - let store = create_test_store().await; +async fn test_set_task_status_success(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let activations = make_activations(2); store.store(activations).await.unwrap(); @@ -157,6 +178,7 @@ async fn test_set_task_status_success() { }), }; let response = service.set_task_status(Request::new(request)).await; + println!("response: {:?}", response); assert!(response.is_ok()); let resp = response.unwrap(); assert!(resp.get_ref().task.is_some()); @@ -165,9 +187,12 @@ async fn test_set_task_status_success() { } #[tokio::test] +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] #[allow(deprecated)] -async fn test_set_task_status_with_application() { - let store = create_test_store().await; +async fn test_set_task_status_with_application(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut activations = make_activations(2); let mut payload = TaskActivation::decode(&activations[1].activation as &[u8]).unwrap(); @@ -199,9 +224,12 @@ async fn test_set_task_status_with_application() { } #[tokio::test] +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] #[allow(deprecated)] -async fn test_set_task_status_with_application_no_match() { - let store = create_test_store().await; +async fn test_set_task_status_with_application_no_match(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut activations = make_activations(2); let mut payload = TaskActivation::decode(&activations[1].activation as &[u8]).unwrap(); @@ -228,9 +256,12 @@ async fn test_set_task_status_with_application_no_match() { } #[tokio::test] +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] #[allow(deprecated)] -async fn test_set_task_status_with_namespace_requires_application() { - let store = create_test_store().await; +async fn test_set_task_status_with_namespace_requires_application(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let activations = make_activations(2); let namespace = activations[0].namespace.clone(); diff --git a/src/kafka/deserialize_activation.rs b/src/kafka/deserialize_activation.rs index 891bc4f8..d0ea263e 100644 --- a/src/kafka/deserialize_activation.rs +++ b/src/kafka/deserialize_activation.rs @@ -87,7 +87,7 @@ pub fn new( added_at: Utc::now(), received_at: activation_time, processing_deadline: None, - processing_deadline_duration: activation.processing_deadline_duration as u32, + processing_deadline_duration: activation.processing_deadline_duration as i32, processing_attempts: 0, expires_at, delay_until, diff --git a/src/kafka/inflight_activation_writer.rs b/src/kafka/inflight_activation_writer.rs index ff48d5c6..d89bdf24 100644 --- a/src/kafka/inflight_activation_writer.rs +++ b/src/kafka/inflight_activation_writer.rs @@ -80,7 +80,6 @@ impl Reducer for InflightActivationWriter { self.batch.take(); return Ok(Some(())); } - 
// Check if writing the batch would exceed the limits let exceeded_pending_limit = self .store @@ -145,7 +144,6 @@ impl Reducer for InflightActivationWriter { "reason" => reason, ) .increment(1); - return Ok(None); } @@ -206,22 +204,23 @@ mod tests { use chrono::{DateTime, Utc}; use prost::Message; use prost_types::Timestamp; + use rstest::rstest; use std::collections::HashMap; + use crate::test_utils::create_test_store; use sentry_protos::taskbroker::v1::OnAttemptsExceeded; use sentry_protos::taskbroker::v1::TaskActivation; - use std::sync::Arc; - use crate::store::inflight_activation::{ - InflightActivationStatus, InflightActivationStore, InflightActivationStoreConfig, - SqliteActivationStore, - }; + use crate::store::inflight_activation::InflightActivationStatus; use crate::test_utils::generate_unique_namespace; use crate::test_utils::make_activations; - use crate::test_utils::{create_integration_config, generate_temp_filename}; #[tokio::test] - async fn test_writer_flush_batch() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_writer_flush_batch(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let writer_config = ActivationWriterConfig { db_max_size: None, max_buf_len: 100, @@ -230,17 +229,7 @@ mod tests { max_delay_activations: 10, write_failure_backoff_ms: 4000, }; - let mut writer = InflightActivationWriter::new( - Arc::new( - SqliteActivationStore::new( - &generate_temp_filename(), - InflightActivationStoreConfig::from_config(&create_integration_config()), - ) - .await - .unwrap(), - ), - writer_config, - ); + let mut writer = InflightActivationWriter::new(store, writer_config); let received_at = Timestamp { seconds: 0, @@ -333,29 +322,24 @@ mod tests { .await .unwrap(); assert_eq!(count_pending + count_delay, 2); + writer.store.remove_db().await.unwrap(); } #[tokio::test] - async fn test_writer_flush_only_pending() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_writer_flush_only_pending(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let writer_config = ActivationWriterConfig { db_max_size: None, max_buf_len: 100, max_pending_activations: 10, max_processing_activations: 10, - max_delay_activations: 0, + max_delay_activations: 10, write_failure_backoff_ms: 4000, }; - let mut writer = InflightActivationWriter::new( - Arc::new( - SqliteActivationStore::new( - &generate_temp_filename(), - InflightActivationStoreConfig::from_config(&create_integration_config()), - ) - .await - .unwrap(), - ), - writer_config, - ); + let mut writer = InflightActivationWriter::new(store, writer_config); let received_at = Timestamp { seconds: 0, @@ -402,10 +386,15 @@ mod tests { writer.flush().await.unwrap(); let count_pending = writer.store.count_pending_activations().await.unwrap(); assert_eq!(count_pending, 1); + writer.store.remove_db().await.unwrap(); } #[tokio::test] - async fn test_writer_flush_only_delay() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_writer_flush_only_delay(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let writer_config = ActivationWriterConfig { db_max_size: None, max_buf_len: 100, @@ -414,17 +403,7 @@ mod tests { max_delay_activations: 10, write_failure_backoff_ms: 4000, }; - let mut writer = InflightActivationWriter::new( - Arc::new( - SqliteActivationStore::new( - &generate_temp_filename(), - InflightActivationStoreConfig::from_config(&create_integration_config()), - ) - 
.await - .unwrap(), - ), - writer_config, - ); + let mut writer = InflightActivationWriter::new(store, writer_config); let received_at = Timestamp { seconds: 0, @@ -475,10 +454,15 @@ mod tests { .await .unwrap(); assert_eq!(count_delay, 1); + writer.store.remove_db().await.unwrap(); } #[tokio::test] - async fn test_writer_backpressure_pending_limit_reached() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_writer_backpressure_pending_limit_reached(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let writer_config = ActivationWriterConfig { db_max_size: None, max_buf_len: 100, @@ -487,17 +471,7 @@ mod tests { max_delay_activations: 0, write_failure_backoff_ms: 4000, }; - let mut writer = InflightActivationWriter::new( - Arc::new( - SqliteActivationStore::new( - &generate_temp_filename(), - InflightActivationStoreConfig::from_config(&create_integration_config()), - ) - .await - .unwrap(), - ), - writer_config, - ); + let mut writer = InflightActivationWriter::new(store, writer_config); let received_at = Timestamp { seconds: 0, @@ -591,10 +565,17 @@ mod tests { .await .unwrap(); assert_eq!(count_delay, 0); + writer.store.remove_db().await.unwrap(); } #[tokio::test] - async fn test_writer_backpressure_only_delay_limit_reached_and_entire_batch_is_pending() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_writer_backpressure_only_delay_limit_reached_and_entire_batch_is_pending( + #[case] adapter: &str, + ) { + let store = create_test_store(adapter).await; let writer_config = ActivationWriterConfig { db_max_size: None, max_buf_len: 100, @@ -603,17 +584,7 @@ mod tests { max_delay_activations: 0, write_failure_backoff_ms: 4000, }; - let mut writer = InflightActivationWriter::new( - Arc::new( - SqliteActivationStore::new( - &generate_temp_filename(), - InflightActivationStoreConfig::from_config(&create_integration_config()), - ) - .await - .unwrap(), - ), - writer_config, - ); + let mut writer = InflightActivationWriter::new(store, writer_config); let received_at = Timestamp { seconds: 0, @@ -707,10 +678,15 @@ mod tests { .await .unwrap(); assert_eq!(count_delay, 0); + writer.store.remove_db().await.unwrap(); } #[tokio::test] - async fn test_writer_backpressure_processing_limit_reached() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_writer_backpressure_processing_limit_reached(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let writer_config = ActivationWriterConfig { db_max_size: None, max_buf_len: 100, @@ -719,14 +695,6 @@ mod tests { max_delay_activations: 0, write_failure_backoff_ms: 4000, }; - let store = Arc::new( - SqliteActivationStore::new( - &generate_temp_filename(), - InflightActivationStoreConfig::from_config(&create_integration_config()), - ) - .await - .unwrap(), - ); let received_at = Timestamp { seconds: 0, @@ -866,10 +834,17 @@ mod tests { .unwrap(); // Only the existing processing activation should remain, new ones should be blocked assert_eq!(count_processing, 1); + // TODO: Because the store and the writer both access the DB, both need to be cleaned up. + // Uncomment this when we figure out how to do that cleanly. 
+ // writer.store.remove_db().await.unwrap(); } #[tokio::test] - async fn test_writer_backpressure_db_size_limit_reached() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_writer_backpressure_db_size_limit_reached(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let writer_config = ActivationWriterConfig { // 200 rows is ~50KB db_max_size: Some(50_000), @@ -879,14 +854,6 @@ mod tests { max_delay_activations: 0, write_failure_backoff_ms: 4000, }; - let store = Arc::new( - SqliteActivationStore::new( - &generate_temp_filename(), - InflightActivationStoreConfig::from_config(&create_integration_config()), - ) - .await - .unwrap(), - ); let first_round = make_activations(200); store.store(first_round).await.unwrap(); assert!(store.db_size().await.unwrap() > 50_000); @@ -901,10 +868,15 @@ mod tests { let count_pending = writer.store.count_pending_activations().await.unwrap(); assert_eq!(count_pending, 200); + writer.store.remove_db().await.unwrap(); } #[tokio::test] - async fn test_writer_flush_empty_batch() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_writer_flush_empty_batch(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let writer_config = ActivationWriterConfig { db_max_size: None, max_buf_len: 100, @@ -913,17 +885,10 @@ mod tests { max_delay_activations: 10, write_failure_backoff_ms: 4000, }; - let store = Arc::new( - SqliteActivationStore::new( - &generate_temp_filename(), - InflightActivationStoreConfig::from_config(&create_integration_config()), - ) - .await - .unwrap(), - ); let mut writer = InflightActivationWriter::new(store.clone(), writer_config); writer.reduce(vec![]).await.unwrap(); let flush_result = writer.flush().await.unwrap(); assert!(flush_result.is_some()); + writer.store.remove_db().await.unwrap(); } } diff --git a/src/main.rs b/src/main.rs index 5f0eee52..c1221d1c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -33,6 +33,9 @@ use taskbroker::runtime_config::RuntimeConfigManager; use taskbroker::store::inflight_activation::{ InflightActivationStore, InflightActivationStoreConfig, SqliteActivationStore, }; +use taskbroker::store::postgres_activation_store::{ + PostgresActivationStore, PostgresActivationStoreConfig, +}; use taskbroker::{Args, get_version}; use tonic_health::ServingStatus; @@ -62,13 +65,21 @@ async fn main() -> Result<(), Error> { logging::init(logging::LoggingConfig::from_config(&config)); metrics::init(metrics::MetricsConfig::from_config(&config)); - let store: Arc = Arc::new( - SqliteActivationStore::new( - &config.db_path, - InflightActivationStoreConfig::from_config(&config), - ) - .await?, - ); + + let store: Arc = match config.database_adapter { + "sqlite" => Arc::new( + SqliteActivationStore::new( + &config.db_path, + InflightActivationStoreConfig::from_config(&config), + ) + .await?, + ), + "postgres" => Arc::new( + PostgresActivationStore::new(PostgresActivationStoreConfig::from_config(&config)) + .await?, + ), + _ => panic!("Invalid database adapter: {}", config.database_adapter), + }; // If this is an environment where the topics might not exist, check and create them. 
if config.create_missing_topics {
@@ -80,6 +91,7 @@
         )
         .await?;
     }
+
     if config.full_vacuum_on_start {
         info!("Running full vacuum on database");
         match store.full_vacuum_db().await {
diff --git a/src/store/inflight_activation.rs b/src/store/inflight_activation.rs
index 3668bc40..7d0034ab 100644
--- a/src/store/inflight_activation.rs
+++ b/src/store/inflight_activation.rs
@@ -1,6 +1,10 @@
+use anyhow::{Error, anyhow};
+use sqlx::postgres::PgQueryResult;
+use std::fmt::Result as FmtResult;
+use std::fmt::{Display, Formatter};
 use std::{str::FromStr, time::Instant};
 
-use anyhow::{Error, anyhow};
+use crate::config::Config;
 use async_trait::async_trait;
 use chrono::{DateTime, Utc};
 use libsqlite3_sys::{
@@ -23,8 +27,6 @@ use sqlx::{
 };
 use tracing::{instrument, warn};
 
-use crate::config::Config;
-
 /// The members of this enum should be synced with the members
 /// of InflightActivationStatus in sentry_protos
 #[derive(Clone, Copy, Debug, PartialEq, Eq, Type)]
@@ -39,6 +41,36 @@ pub enum InflightActivationStatus {
     Delay,
 }
 
+impl Display for InflightActivationStatus {
+    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
+        write!(f, "{:?}", self)
+    }
+}
+
+impl FromStr for InflightActivationStatus {
+    type Err = String;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        if s == "Unspecified" {
+            Ok(InflightActivationStatus::Unspecified)
+        } else if s == "Pending" {
+            Ok(InflightActivationStatus::Pending)
+        } else if s == "Processing" {
+            Ok(InflightActivationStatus::Processing)
+        } else if s == "Failure" {
+            Ok(InflightActivationStatus::Failure)
+        } else if s == "Retry" {
+            Ok(InflightActivationStatus::Retry)
+        } else if s == "Complete" {
+            Ok(InflightActivationStatus::Complete)
+        } else if s == "Delay" {
+            Ok(InflightActivationStatus::Delay)
+        } else {
+            Err(format!("Unknown inflight activation status string: {}", s))
+        }
+    }
+}
+
 impl InflightActivationStatus {
     /// Is the current value a 'conclusion' status that can be supplied over GRPC.
     pub fn is_conclusion(&self) -> bool {
@@ -93,7 +125,7 @@ pub struct InflightActivation {
     /// The duration in seconds that a worker has to complete task execution.
     /// When an activation is moved from pending -> processing a result is expected
     /// in this many seconds.
-    pub processing_deadline_duration: u32,
+    pub processing_deadline_duration: i32,
 
     /// If the task has specified an expiry, this is the timestamp after which the task should be removed from inflight store
     pub expires_at: Option<DateTime<Utc>>,
@@ -145,31 +177,39 @@ impl From<SqliteQueryResult> for QueryResult {
     }
 }
 
+impl From<PgQueryResult> for QueryResult {
+    fn from(value: PgQueryResult) -> Self {
+        Self {
+            rows_affected: value.rows_affected(),
+        }
+    }
+}
+
 pub struct FailedTasksForwarder {
     pub to_discard: Vec<(String, Vec<u8>)>,
    pub to_deadletter: Vec<(String, Vec<u8>)>,
 }
 
 #[derive(Debug, FromRow)]
-struct TableRow {
-    id: String,
-    activation: Vec<u8>,
-    partition: i32,
-    offset: i64,
-    added_at: DateTime<Utc>,
-    received_at: DateTime<Utc>,
-    processing_attempts: i32,
-    expires_at: Option<DateTime<Utc>>,
-    delay_until: Option<DateTime<Utc>>,
-    processing_deadline_duration: u32,
-    processing_deadline: Option<DateTime<Utc>>,
-    status: InflightActivationStatus,
-    at_most_once: bool,
-    application: String,
-    namespace: String,
-    taskname: String,
+pub struct TableRow {
+    pub id: String,
+    pub activation: Vec<u8>,
+    pub partition: i32,
+    pub offset: i64,
+    pub added_at: DateTime<Utc>,
+    pub received_at: DateTime<Utc>,
+    pub processing_attempts: i32,
+    pub expires_at: Option<DateTime<Utc>>,
+    pub delay_until: Option<DateTime<Utc>>,
+    pub processing_deadline_duration: i32,
+    pub processing_deadline: Option<DateTime<Utc>>,
+    pub status: String,
+    pub at_most_once: bool,
+    pub application: String,
+    pub namespace: String,
+    pub taskname: String,
     #[sqlx(try_from = "i32")]
-    on_attempts_exceeded: OnAttemptsExceeded,
+    pub on_attempts_exceeded: OnAttemptsExceeded,
 }
 
 impl TryFrom<InflightActivation> for TableRow {
@@ -188,7 +228,7 @@
             delay_until: value.delay_until,
             processing_deadline_duration: value.processing_deadline_duration,
             processing_deadline: value.processing_deadline,
-            status: value.status,
+            status: value.status.to_string(),
             at_most_once: value.at_most_once,
             application: value.application,
             namespace: value.namespace,
@@ -203,7 +243,7 @@ impl From<TableRow> for InflightActivation {
         Self {
             id: value.id,
             activation: value.activation,
-            status: value.status,
+            status: InflightActivationStatus::from_str(&value.status).unwrap(),
             partition: value.partition,
             offset: value.offset,
             added_at: value.added_at,
@@ -360,6 +400,9 @@
 
     /// Remove killswitched tasks
     async fn remove_killswitched(&self, killswitched_tasks: Vec<String>) -> Result<u64, Error>;
+
+    /// Remove the database, used only in tests
+    async fn remove_db(&self) -> Result<(), Error>;
 }
 
 pub struct SqliteActivationStore {
@@ -656,6 +699,7 @@ impl InflightActivationStore for SqliteActivationStore {
             .into_iter()
             .map(TableRow::try_from)
             .collect::<Result<Vec<_>, _>>()?;
+
         let query = query_builder
             .push_values(rows, |mut b, row| {
                 b.push_bind(row.id);
@@ -1180,4 +1224,9 @@
 
         Ok(query.rows_affected())
     }
+
+    // Used in tests
+    async fn remove_db(&self) -> Result<(), Error> {
+        Ok(())
+    }
 }
diff --git a/src/store/inflight_activation_tests.rs b/src/store/inflight_activation_tests.rs
index 83d71c65..657bb3e2 100644
--- a/src/store/inflight_activation_tests.rs
+++ b/src/store/inflight_activation_tests.rs
@@ -1,5 +1,8 @@
 use prost::Message;
+use rstest::rstest;
+use sqlx::{QueryBuilder, Sqlite};
 use std::collections::{HashMap, HashSet};
+use std::fs;
 use std::io::Error;
 use std::path::Path;
 use std::sync::Arc;
@@ -19,8 +22,6 @@ use chrono::{DateTime, SubsecRound, TimeZone, Utc};
 use sentry_protos::taskbroker::v1::{
     OnAttemptsExceeded, RetryState, TaskActivation, TaskActivationStatus,
 };
-use sqlx::{QueryBuilder, Sqlite};
-use std::fs;
 use 
tokio::sync::broadcast; use tokio::task::JoinSet; @@ -64,7 +65,7 @@ fn test_inflightactivation_status_from() { } #[tokio::test] -async fn test_create_db() { +async fn test_sqlite_create_db() { assert!( SqliteActivationStore::new( &generate_temp_filename(), @@ -76,34 +77,50 @@ async fn test_create_db() { } #[tokio::test] -async fn test_store() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_store(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let batch = make_activations(2); assert!(store.store(batch).await.is_ok()); let result = store.count().await; assert_eq!(result.unwrap(), 2); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_store_duplicate_id_in_batch() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_store_duplicate_id_in_batch(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(2); // Coerce a conflict batch[0].id = "id_0".into(); batch[1].id = "id_0".into(); - assert!(store.store(batch).await.is_ok()); + let first_result = store.store(batch).await; + assert!( + first_result.is_ok(), + "{}", + first_result.err().unwrap().to_string() + ); let result = store.count().await; assert_eq!(result.unwrap(), 1); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_store_duplicate_id_between_batches() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_store_duplicate_id_between_batches(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let batch = make_activations(2); assert!(store.store(batch.clone()).await.is_ok()); @@ -118,11 +135,15 @@ async fn test_store_duplicate_id_between_batches() { let second_count = store.count().await; assert_eq!(second_count.unwrap(), 2); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_get_pending_activation() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_get_pending_activation(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let batch = make_activations(2); assert!(store.store(batch.clone()).await.is_ok()); @@ -149,11 +170,15 @@ async fn test_get_pending_activation() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test(flavor = "multi_thread", worker_threads = 32)] -async fn test_get_pending_activation_with_race() { - let store = Arc::new(create_test_store().await); +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_get_pending_activation_with_race(#[case] adapter: &str) { + let store = Arc::new(create_test_store(adapter).await); let namespace = generate_unique_namespace(); const NUM_CONCURRENT_WRITES: u32 = 2000; @@ -192,11 +217,15 @@ async fn test_get_pending_activation_with_race() { .collect(); assert_eq!(res.len(), NUM_CONCURRENT_WRITES as usize); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_get_pending_activation_with_namespace() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_get_pending_activation_with_namespace(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(2); batch[1].namespace = "other_namespace".into(); @@ -212,11 +241,15 @@ async fn 
test_get_pending_activation_with_namespace() { assert_eq!(result.status, InflightActivationStatus::Processing); assert!(result.processing_deadline.unwrap() > Utc::now()); assert_eq!(result.namespace, "other_namespace"); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_get_pending_activation_from_multiple_namespaces() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_get_pending_activation_from_multiple_namespaces(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(4); batch[0].namespace = "ns1".into(); @@ -233,17 +266,21 @@ async fn test_get_pending_activation_from_multiple_namespaces() { .unwrap(); assert_eq!(result.len(), 2); - assert_eq!(result[0].id, "id_1"); - assert_eq!(result[0].namespace, "ns2"); - assert_eq!(result[0].status, InflightActivationStatus::Processing); assert_eq!(result[1].id, "id_2"); assert_eq!(result[1].namespace, "ns3"); assert_eq!(result[1].status, InflightActivationStatus::Processing); + assert_eq!(result[0].id, "id_1"); + assert_eq!(result[0].namespace, "ns2"); + assert_eq!(result[0].status, InflightActivationStatus::Processing); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_get_pending_activation_with_namespace_requires_application() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_get_pending_activation_with_namespace_requires_application(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(2); batch[1].namespace = "other_namespace".into(); @@ -268,11 +305,24 @@ async fn test_get_pending_activation_with_namespace_requires_application() { activations.len(), "should find 1 activation with a matching namespace" ); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_get_pending_activation_skip_expires() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_get_pending_activation_skip_expires(#[case] adapter: &str) { + let store = create_test_store(adapter).await; + + assert_counts( + StatusCount { + pending: 0, + ..StatusCount::default() + }, + store.as_ref(), + ) + .await; let mut batch = make_activations(1); batch[0].expires_at = Some(Utc::now() - Duration::from_secs(100)); @@ -291,16 +341,21 @@ async fn test_get_pending_activation_skip_expires() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_get_pending_activation_earliest() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_get_pending_activation_earliest(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(2); batch[0].added_at = Utc.with_ymd_and_hms(2024, 6, 24, 0, 0, 0).unwrap(); batch[1].added_at = Utc.with_ymd_and_hms(1998, 6, 24, 0, 0, 0).unwrap(); - assert!(store.store(batch.clone()).await.is_ok()); + let ret = store.store(batch.clone()).await; + assert!(ret.is_ok(), "{}", ret.err().unwrap().to_string()); let result = store .get_pending_activation(None, None) @@ -311,11 +366,15 @@ async fn test_get_pending_activation_earliest() { result.added_at, Utc.with_ymd_and_hms(1998, 6, 24, 0, 0, 0).unwrap() ); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_get_pending_activation_fetches_application() { - let store = 
create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_get_pending_activation_fetches_application(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(1); batch[0].application = "hammers".into(); @@ -332,11 +391,15 @@ async fn test_get_pending_activation_fetches_application() { assert_eq!(result.status, InflightActivationStatus::Processing); assert!(result.processing_deadline.unwrap() > Utc::now()); assert_eq!(result.application, "hammers"); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_get_pending_activation_with_application() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_get_pending_activation_with_application(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(2); batch[1].application = "hammers".into(); @@ -364,11 +427,15 @@ async fn test_get_pending_activation_with_application() { let result_opt = store.get_pending_activation(None, None).await.unwrap(); assert!(result_opt.is_some(), "one pending activation in '' left"); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_get_pending_activation_with_application_and_namespace() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_get_pending_activation_with_application_and_namespace(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(3); batch[0].namespace = "target".into(); @@ -400,11 +467,15 @@ async fn test_get_pending_activation_with_application_and_namespace() { assert_eq!(result.id, "id_2"); assert_eq!(result.application, "hammers"); assert_eq!(result.namespace, "not-target"); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_count_pending_activations() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_count_pending_activations(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(3); batch[0].status = InflightActivationStatus::Processing; @@ -420,11 +491,15 @@ async fn test_count_pending_activations() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn set_activation_status() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_set_activation_status(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let batch = make_activations(2); assert!(store.store(batch).await.is_ok()); @@ -514,37 +589,37 @@ async fn set_activation_status() { let inflight = result_opt.unwrap(); assert_eq!(inflight.id, "id_0"); assert_eq!(inflight.status, InflightActivationStatus::Complete); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_set_processing_deadline() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_set_processing_deadline(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let batch = make_activations(1); assert!(store.store(batch.clone()).await.is_ok()); - let deadline = Utc::now(); - assert!( - store - .set_processing_deadline("id_0", Some(deadline)) - .await - .is_ok() - ); + let deadline = Utc::now().round_subsecs(0); + let result = 
store.set_processing_deadline("id_0", Some(deadline)).await; + assert!(result.is_ok(), "query error: {:?}", result.err().unwrap()); let result = store.get_by_id("id_0").await.unwrap().unwrap(); assert_eq!( - result - .processing_deadline - .unwrap() - .round_subsecs(0) - .timestamp(), + result.processing_deadline.unwrap().timestamp(), deadline.timestamp() - ) + ); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_delete_activation() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_delete_activation(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let batch = make_activations(2); assert!(store.store(batch).await.is_ok()); @@ -563,11 +638,15 @@ async fn test_delete_activation() { assert!(store.delete_activation("id_1").await.is_ok()); let result = store.count().await; assert_eq!(result.unwrap(), 0); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_get_retry_activations() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_get_retry_activations(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let batch = make_activations(2); assert!(store.store(batch.clone()).await.is_ok()); @@ -608,11 +687,15 @@ async fn test_get_retry_activations() { for record in retries.iter() { assert_eq!(record.status, InflightActivationStatus::Retry); } + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_handle_processing_deadline() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_handle_processing_deadline(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(2); batch[1].status = InflightActivationStatus::Processing; @@ -648,11 +731,15 @@ async fn test_handle_processing_deadline() { let count = store.handle_processing_deadline().await; assert!(count.is_ok()); assert_eq!(count.unwrap(), 0); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_handle_processing_deadline_multiple_tasks() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_handle_processing_deadline_multiple_tasks(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(2); batch[0].status = InflightActivationStatus::Processing; @@ -681,11 +768,15 @@ async fn test_handle_processing_deadline_multiple_tasks() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_handle_processing_at_most_once() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_handle_processing_at_most_once(#[case] adapter: &str) { + let store = create_test_store(adapter).await; // Both records are past processing deadlines let mut batch = make_activations(2); @@ -731,11 +822,15 @@ async fn test_handle_processing_at_most_once() { let task = store.get_by_id(&batch[1].id).await.unwrap().unwrap(); assert_eq!(task.status, InflightActivationStatus::Failure); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_handle_processing_deadline_discard_after() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_handle_processing_deadline_discard_after(#[case] adapter: &str) { + let 
store = create_test_store(adapter).await; let mut batch = make_activations(2); batch[1].status = InflightActivationStatus::Processing; @@ -773,11 +868,15 @@ async fn test_handle_processing_deadline_discard_after() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_handle_processing_deadline_deadletter_after() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_handle_processing_deadline_deadletter_after(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(2); batch[1].status = InflightActivationStatus::Processing; @@ -815,11 +914,15 @@ async fn test_handle_processing_deadline_deadletter_after() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_handle_processing_deadline_no_retries_remaining() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_handle_processing_deadline_no_retries_remaining(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(2); batch[1].status = InflightActivationStatus::Processing; @@ -857,12 +960,16 @@ async fn test_handle_processing_deadline_no_retries_remaining() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_processing_attempts_exceeded() { +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_processing_attempts_exceeded(#[case] adapter: &str) { let config = create_integration_config(); - let store = create_test_store().await; + let store = create_test_store(adapter).await; let mut batch = make_activations(3); batch[0].status = InflightActivationStatus::Pending; @@ -899,11 +1006,15 @@ async fn test_processing_attempts_exceeded() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_remove_completed() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_remove_completed(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut records = make_activations(3); records[0].status = InflightActivationStatus::Complete; @@ -956,11 +1067,15 @@ async fn test_remove_completed() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_remove_completed_multiple_gaps() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_remove_completed_multiple_gaps(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut records = make_activations(4); // only record 1 can be removed @@ -1027,11 +1142,15 @@ async fn test_remove_completed_multiple_gaps() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_handle_failed_tasks() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_handle_failed_tasks(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut records = make_activations(4); // deadletter @@ -1113,11 +1232,15 @@ async fn test_handle_failed_tasks() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_mark_completed() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] 
+#[case::postgres("postgres")] +async fn test_mark_completed(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let records = make_activations(3); assert!(store.store(records.clone()).await.is_ok()); @@ -1145,11 +1268,15 @@ async fn test_mark_completed() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_handle_expires_at() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_handle_expires_at(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(3); // All expired tasks should be removed, regardless of order or other tasks. @@ -1168,7 +1295,11 @@ async fn test_handle_expires_at() { .await; let result = store.handle_expires_at().await; - assert!(result.is_ok()); + assert!( + result.is_ok(), + "handle_expires_at should be ok {:?}", + result + ); assert_eq!(result.unwrap(), 2); assert_counts( StatusCount { @@ -1178,11 +1309,15 @@ async fn test_handle_expires_at() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_remove_killswitched() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_remove_killswitched(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let mut batch = make_activations(6); batch[0].taskname = "task_to_be_killswitched_one".to_string(); @@ -1216,11 +1351,15 @@ async fn test_remove_killswitched() { store.as_ref(), ) .await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_clear() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_clear(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let namespace = generate_unique_namespace(); #[allow(deprecated)] @@ -1275,28 +1414,37 @@ async fn test_clear() { assert!(store.clear().await.is_ok()); assert_eq!(store.count().await.unwrap(), 0); assert_counts(StatusCount::default(), store.as_ref()).await; + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_full_vacuum() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_full_vacuum(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let batch = make_activations(2); assert!(store.store(batch).await.is_ok()); let result = store.full_vacuum_db().await; assert!(result.is_ok()); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_vacuum_db_no_limit() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_vacuum_db_no_limit(#[case] adapter: &str) { + let store = create_test_store(adapter).await; let batch = make_activations(2); assert!(store.store(batch).await.is_ok()); let result = store.vacuum_db().await; assert!(result.is_ok()); + store.remove_db().await.unwrap(); } #[tokio::test] @@ -1320,8 +1468,11 @@ async fn test_vacuum_db_incremental() { } #[tokio::test] -async fn test_db_size() { - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_db_size(#[case] adapter: &str) { + let store = create_test_store(adapter).await; assert!(store.db_size().await.is_ok()); let first_size = store.db_size().await.unwrap(); @@ -1333,12 +1484,16 @@ async fn test_db_size() { let second_size = 
store.db_size().await.unwrap(); assert!(second_size > first_size, "should have more bytes now"); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_pending_activation_max_lag_no_pending() { +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_pending_activation_max_lag_no_pending(#[case] adapter: &str) { let now = Utc::now(); - let store = create_test_store().await; + let store = create_test_store(adapter).await; // No activations, max lag is 0 assert_eq!(0.0, store.pending_activation_max_lag(&now).await); @@ -1348,12 +1503,16 @@ async fn test_pending_activation_max_lag_no_pending() { // No pending activations, max lag is 0 assert_eq!(0.0, store.pending_activation_max_lag(&now).await); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_pending_activation_max_lag_use_oldest() { +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_pending_activation_max_lag_use_oldest(#[case] adapter: &str) { let now = Utc::now(); - let store = create_test_store().await; + let store = create_test_store(adapter).await; let mut pending = make_activations(2); pending[0].received_at = now - Duration::from_secs(10); @@ -1363,12 +1522,16 @@ async fn test_pending_activation_max_lag_use_oldest() { let result = store.pending_activation_max_lag(&now).await; assert!(11.0 < result, "Should not get the small record"); assert!(result < 501.0, "Should not get an inflated value"); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_pending_activation_max_lag_ignore_processing_attempts() { - let now = Utc::now(); - let store = create_test_store().await; +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_pending_activation_max_lag_ignore_processing_attempts(#[case] adapter: &str) { + let now = Utc::now().round_subsecs(0); + let store = create_test_store(adapter).await; let mut pending = make_activations(2); pending[0].received_at = now - Duration::from_secs(10); @@ -1377,14 +1540,17 @@ async fn test_pending_activation_max_lag_ignore_processing_attempts() { assert!(store.store(pending).await.is_ok()); let result = store.pending_activation_max_lag(&now).await; - assert!(10.00 < result); - assert!(result < 11.00); + assert_eq!(result, 10.0, "max lag: {result:?}"); + store.remove_db().await.unwrap(); } #[tokio::test] -async fn test_pending_activation_max_lag_account_for_delayed() { +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +async fn test_pending_activation_max_lag_account_for_delayed(#[case] adapter: &str) { let now = Utc::now(); - let store = create_test_store().await; + let store = create_test_store(adapter).await; let mut pending = make_activations(2); // delayed tasks are received well before they become pending @@ -1395,7 +1561,8 @@ async fn test_pending_activation_max_lag_account_for_delayed() { let result = store.pending_activation_max_lag(&now).await; assert!(22.00 < result, "result: {result}"); - assert!(result < 23.00, "result: {result}"); + assert!(result < 24.00, "result: {result}"); + store.remove_db().await.unwrap(); } #[tokio::test] diff --git a/src/store/mod.rs b/src/store/mod.rs index dcc0f255..deb05655 100644 --- a/src/store/mod.rs +++ b/src/store/mod.rs @@ -1,3 +1,4 @@ pub mod inflight_activation; #[cfg(test)] pub mod inflight_activation_tests; +pub mod postgres_activation_store; diff --git a/src/store/postgres_activation_store.rs b/src/store/postgres_activation_store.rs new file mode 100644 index 00000000..363bbda1 --- /dev/null +++ 
b/src/store/postgres_activation_store.rs @@ -0,0 +1,732 @@ +use crate::store::inflight_activation::{ + FailedTasksForwarder, InflightActivation, InflightActivationStatus, InflightActivationStore, + QueryResult, TableRow, +}; +use anyhow::{Error, anyhow}; +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use sentry_protos::taskbroker::v1::OnAttemptsExceeded; +use sqlx::{ + Pool, Postgres, QueryBuilder, Row, + pool::PoolConnection, + postgres::{PgConnectOptions, PgPool, PgPoolOptions, PgRow}, +}; +use std::{str::FromStr, time::Instant}; +use tracing::{instrument, warn}; + +use crate::config::Config; + +pub async fn create_postgres_pool( + url: &str, + database_name: &str, +) -> Result<(Pool, Pool), Error> { + let conn_str = url.to_owned() + "/" + database_name; + let read_pool = PgPoolOptions::new() + .max_connections(64) + .connect_with(PgConnectOptions::from_str(&conn_str)?) + .await?; + + let write_pool = PgPoolOptions::new() + .max_connections(64) + .connect_with(PgConnectOptions::from_str(&conn_str)?) + .await?; + Ok((read_pool, write_pool)) +} + +pub async fn create_default_postgres_pool(url: &str) -> Result, Error> { + let conn_str = url.to_owned() + "/postgres"; + let read_pool = PgPoolOptions::new() + .max_connections(64) + .connect_with(PgConnectOptions::from_str(&conn_str)?) + .await?; + Ok(read_pool) +} + +pub struct PostgresActivationStoreConfig { + pub pg_url: String, + pub pg_database_name: String, + pub max_processing_attempts: usize, + pub processing_deadline_grace_sec: u64, + pub vacuum_page_count: Option, + pub enable_sqlite_status_metrics: bool, +} + +impl PostgresActivationStoreConfig { + pub fn from_config(config: &Config) -> Self { + Self { + pg_url: config.pg_url.clone(), + pg_database_name: config.pg_database_name.clone(), + max_processing_attempts: config.max_processing_attempts, + vacuum_page_count: config.vacuum_page_count, + processing_deadline_grace_sec: config.processing_deadline_grace_sec, + enable_sqlite_status_metrics: config.enable_sqlite_status_metrics, + } + } +} + +pub struct PostgresActivationStore { + read_pool: PgPool, + write_pool: PgPool, + config: PostgresActivationStoreConfig, +} + +impl PostgresActivationStore { + async fn acquire_write_conn_metric( + &self, + caller: &'static str, + ) -> Result, Error> { + let start = Instant::now(); + let conn = self.write_pool.acquire().await?; + metrics::histogram!("postgres.write.acquire_conn", "fn" => caller).record(start.elapsed()); + Ok(conn) + } + + pub async fn new(config: PostgresActivationStoreConfig) -> Result { + let default_pool = create_default_postgres_pool(&config.pg_url).await?; + + // Create the database if it doesn't exist + let row: (bool,) = sqlx::query_as( + "SELECT EXISTS ( SELECT 1 FROM pg_catalog.pg_database WHERE datname = $1 )", + ) + .bind(&config.pg_database_name) + .fetch_one(&default_pool) + .await?; + + if !row.0 { + println!("Creating database {}", &config.pg_database_name); + sqlx::query(format!("CREATE DATABASE {}", &config.pg_database_name).as_str()) + .bind(&config.pg_database_name) + .execute(&default_pool) + .await?; + } + // Close the default pool + default_pool.close().await; + + let (read_pool, write_pool) = + create_postgres_pool(&config.pg_url, &config.pg_database_name).await?; + sqlx::migrate!("./pg_migrations").run(&write_pool).await?; + + Ok(Self { + read_pool, + write_pool, + config, + }) + } +} + +#[async_trait] +impl InflightActivationStore for PostgresActivationStore { + /// Trigger incremental vacuum to reclaim free pages in the database. 
+
+#[async_trait]
+impl InflightActivationStore for PostgresActivationStore {
+    /// Trigger incremental vacuum to reclaim free pages in the database.
+    /// Depending on config data, will either vacuum a set number of
+    /// pages or attempt to reclaim all free pages.
+    #[instrument(skip_all)]
+    async fn vacuum_db(&self) -> Result<(), Error> {
+        // TODO: Remove
+        Ok(())
+    }
+
+    /// Perform a full vacuum on the database.
+    async fn full_vacuum_db(&self) -> Result<(), Error> {
+        // TODO: Remove
+        Ok(())
+    }
+
+    /// Get the size of the database in bytes based on PostgreSQL metadata queries.
+    async fn db_size(&self) -> Result<u64, Error> {
+        let row_result: (i64,) = sqlx::query_as("SELECT pg_database_size($1) as size")
+            .bind(&self.config.pg_database_name)
+            .fetch_one(&self.read_pool)
+            .await?;
+        if row_result.0 < 0 {
+            return Ok(0);
+        }
+        Ok(row_result.0 as u64)
+    }
+
+    /// Get an activation by id. Primarily used for testing
+    async fn get_by_id(&self, id: &str) -> Result<Option<InflightActivation>, Error> {
+        let row_result: Option<TableRow> = sqlx::query_as(
+            "
+            SELECT id,
+                activation,
+                partition,
+                kafka_offset AS offset,
+                added_at,
+                received_at,
+                processing_attempts,
+                expires_at,
+                delay_until,
+                processing_deadline_duration,
+                processing_deadline,
+                status,
+                at_most_once,
+                application,
+                namespace,
+                taskname,
+                on_attempts_exceeded
+            FROM inflight_taskactivations
+            WHERE id = $1
+            ",
+        )
+        .bind(id)
+        .fetch_optional(&self.read_pool)
+        .await?;
+
+        let Some(row) = row_result else {
+            return Ok(None);
+        };
+
+        Ok(Some(row.into()))
+    }
+
+    #[instrument(skip_all)]
+    async fn store(&self, batch: Vec<InflightActivation>) -> Result<QueryResult, Error> {
+        if batch.is_empty() {
+            return Ok(QueryResult { rows_affected: 0 });
+        }
+        let mut query_builder = QueryBuilder::<Postgres>::new(
+            "
+            INSERT INTO inflight_taskactivations
+            (
+                id,
+                activation,
+                partition,
+                kafka_offset,
+                added_at,
+                received_at,
+                processing_attempts,
+                expires_at,
+                delay_until,
+                processing_deadline_duration,
+                processing_deadline,
+                status,
+                at_most_once,
+                application,
+                namespace,
+                taskname,
+                on_attempts_exceeded
+            )
+            ",
+        );
+        let rows = batch
+            .into_iter()
+            .map(TableRow::try_from)
+            .collect::<Result<Vec<TableRow>, _>>()?;
+        let query = query_builder
+            .push_values(rows, |mut b, row| {
+                b.push_bind(row.id);
+                b.push_bind(row.activation);
+                b.push_bind(row.partition);
+                b.push_bind(row.offset);
+                b.push_bind(row.added_at);
+                b.push_bind(row.received_at);
+                b.push_bind(row.processing_attempts);
+                b.push_bind(Some(row.expires_at));
+                b.push_bind(Some(row.delay_until));
+                b.push_bind(row.processing_deadline_duration);
+                if let Some(deadline) = row.processing_deadline {
+                    b.push_bind(deadline);
+                } else {
+                    // Add a literal null
+                    b.push("null");
+                }
+                b.push_bind(row.status);
+                b.push_bind(row.at_most_once);
+                b.push_bind(row.application);
+                b.push_bind(row.namespace);
+                b.push_bind(row.taskname);
+                b.push_bind(row.on_attempts_exceeded as i32);
+            })
+            .push(" ON CONFLICT(id) DO NOTHING")
+            .build();
+        let mut conn = self.acquire_write_conn_metric("store").await?;
+        Ok(query.execute(&mut *conn).await?.into())
+    }
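
A note on the `ON CONFLICT(id) DO NOTHING` clause above: re-storing a batch that was already written (for example after a Kafka redelivery) is a no-op rather than an error. A hypothetical sketch, using the `make_activations` helper from `test_utils`:

    // Illustrative sketch only, not part of the patch: duplicate ids are skipped.
    async fn demo_idempotent_store(store: &PostgresActivationStore) -> Result<(), Error> {
        let batch = make_activations(2);
        let first = store.store(batch.clone()).await?;
        assert_eq!(first.rows_affected, 2);
        let second = store.store(batch).await?; // same ids as before
        assert_eq!(second.rows_affected, 0);
        Ok(())
    }
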
+
+    #[instrument(skip_all)]
+    async fn get_pending_activation(
+        &self,
+        application: Option<&str>,
+        namespace: Option<&str>,
+    ) -> Result<Option<InflightActivation>, Error> {
+        // Convert single namespace to vector for internal use
+        let namespaces = namespace.map(|ns| vec![ns.to_string()]);
+
+        // If a namespace filter is used, an application must also be used.
+        if namespaces.is_some() && application.is_none() {
+            warn!(
+                "Received request for namespaced task without application. namespaces = {namespaces:?}"
+            );
+            return Ok(None);
+        }
+        let result = self
+            .get_pending_activations_from_namespaces(application, namespaces.as_deref(), Some(1))
+            .await?;
+        if result.is_empty() {
+            return Ok(None);
+        }
+        Ok(Some(result[0].clone()))
+    }
+
+    /// Get a pending activation from specified namespaces
+    /// If namespaces is None, gets from any namespace
+    /// If namespaces is Some(&[...]), gets from those namespaces
+    #[instrument(skip_all)]
+    async fn get_pending_activations_from_namespaces(
+        &self,
+        application: Option<&str>,
+        namespaces: Option<&[String]>,
+        limit: Option<u32>,
+    ) -> Result<Vec<InflightActivation>, Error> {
+        let now = Utc::now();
+
+        let grace_period = self.config.processing_deadline_grace_sec;
+        let mut query_builder = QueryBuilder::new(
+            "WITH selected_activations AS (
+                SELECT id
+                FROM inflight_taskactivations
+                WHERE status = ",
+        );
+        query_builder.push_bind(InflightActivationStatus::Pending.to_string());
+        query_builder.push(" AND (expires_at IS NULL OR expires_at > ");
+        query_builder.push_bind(now);
+        query_builder.push(")");
+
+        // Handle application & namespace filtering
+        if let Some(value) = application {
+            query_builder.push(" AND application =");
+            query_builder.push_bind(value);
+        }
+        if let Some(namespaces) = namespaces
+            && !namespaces.is_empty()
+        {
+            query_builder.push(" AND namespace IN (");
+            let mut separated = query_builder.separated(", ");
+            for namespace in namespaces.iter() {
+                separated.push_bind(namespace);
+            }
+            query_builder.push(")");
+        }
+        query_builder.push(" ORDER BY added_at");
+        if let Some(limit) = limit {
+            query_builder.push(" LIMIT ");
+            query_builder.push_bind(limit);
+        }
+        query_builder.push(" FOR UPDATE SKIP LOCKED)");
+        query_builder.push(format!(
+            "UPDATE inflight_taskactivations
+            SET
+                processing_deadline = now() + (processing_deadline_duration * interval '1 second') + (interval '{grace_period} seconds'),
+                status = "
+        ));
+        query_builder.push_bind(InflightActivationStatus::Processing.to_string());
+        query_builder.push(" FROM selected_activations ");
+        query_builder.push(" WHERE inflight_taskactivations.id = selected_activations.id");
+        query_builder.push(" RETURNING *, kafka_offset AS offset");
+
+        let mut conn = self
+            .acquire_write_conn_metric("get_pending_activation")
+            .await?;
+        let rows: Vec<TableRow> = query_builder
+            .build_query_as::<TableRow>()
+            .fetch_all(&mut *conn)
+            .await?;
+
+        Ok(rows.into_iter().map(|row| row.into()).collect())
+    }
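
The builder above assembles a single statement that claims work atomically: `FOR UPDATE SKIP LOCKED` lets concurrent brokers skip rows another transaction has already locked instead of blocking on them. Roughly, the generated SQL has this shape (a simplified sketch with bound parameters inlined and the optional filters omitted):

    // Simplified shape of the claim query built above; for reference only.
    const CLAIM_QUERY_SKETCH: &str = r#"
        WITH selected_activations AS (
            SELECT id FROM inflight_taskactivations
            WHERE status = 'pending'
              AND (expires_at IS NULL OR expires_at > now())
            ORDER BY added_at
            LIMIT 1
            FOR UPDATE SKIP LOCKED
        )
        UPDATE inflight_taskactivations
        SET processing_deadline = now() + (processing_deadline_duration * interval '1 second'),
            status = 'processing'
        FROM selected_activations
        WHERE inflight_taskactivations.id = selected_activations.id
        RETURNING *, kafka_offset AS offset
    "#;
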
+
+    /// Get the age of the oldest pending activation in seconds.
+    /// Only activations with status=pending and processing_attempts=0 are considered
+    /// as we are interested in latency to the *first* attempt.
+    /// Tasks with delay_until set, will have their age adjusted based on their
+    /// delay time. No tasks = 0 lag
+    async fn pending_activation_max_lag(&self, now: &DateTime<Utc>) -> f64 {
+        let result = sqlx::query(
+            "SELECT received_at, delay_until
+            FROM inflight_taskactivations
+            WHERE status = $1
+            AND processing_attempts = 0
+            ORDER BY received_at ASC
+            LIMIT 1
+            ",
+        )
+        .bind(InflightActivationStatus::Pending.to_string())
+        .fetch_one(&self.read_pool)
+        .await;
+        if let Ok(row) = result {
+            let received_at: DateTime<Utc> = row.get("received_at");
+            let delay_until: Option<DateTime<Utc>> = row.get("delay_until");
+            let millis = now.signed_duration_since(received_at).num_milliseconds()
+                - delay_until.map_or(0, |delay_time| {
+                    delay_time
+                        .signed_duration_since(received_at)
+                        .num_milliseconds()
+                });
+            millis as f64 / 1000.0
+        } else {
+            // If we couldn't find a row, there is no latency.
+            0.0
+        }
+    }
+
+    #[instrument(skip_all)]
+    async fn count_pending_activations(&self) -> Result<usize, Error> {
+        self.count_by_status(InflightActivationStatus::Pending)
+            .await
+    }
+
+    #[instrument(skip_all)]
+    async fn count_by_status(&self, status: InflightActivationStatus) -> Result<usize, Error> {
+        let result =
+            sqlx::query("SELECT COUNT(*) as count FROM inflight_taskactivations WHERE status = $1")
+                .bind(status.to_string())
+                .fetch_one(&self.read_pool)
+                .await?;
+        Ok(result.get::<i64, _>("count") as usize)
+    }
+
+    async fn count(&self) -> Result<usize, Error> {
+        let result = sqlx::query("SELECT COUNT(*) as count FROM inflight_taskactivations")
+            .fetch_one(&self.read_pool)
+            .await?;
+        Ok(result.get::<i64, _>("count") as usize)
+    }
+
+    /// Update the status of a specific activation
+    #[instrument(skip_all)]
+    async fn set_status(
+        &self,
+        id: &str,
+        status: InflightActivationStatus,
+    ) -> Result<Option<InflightActivation>, Error> {
+        let mut conn = self.acquire_write_conn_metric("set_status").await?;
+        let result: Option<TableRow> = sqlx::query_as(
+            "UPDATE inflight_taskactivations SET status = $1 WHERE id = $2 RETURNING *, kafka_offset AS offset",
+        )
+        .bind(status.to_string())
+        .bind(id)
+        .fetch_optional(&mut *conn)
+        .await?;
+        println!("result: {:?}", result);
+        let Some(row) = result else {
+            return Ok(None);
+        };
+
+        Ok(Some(row.into()))
+    }
+
+    #[instrument(skip_all)]
+    async fn set_processing_deadline(
+        &self,
+        id: &str,
+        deadline: Option<DateTime<Utc>>,
+    ) -> Result<(), Error> {
+        let mut conn = self
+            .acquire_write_conn_metric("set_processing_deadline")
+            .await?;
+        sqlx::query("UPDATE inflight_taskactivations SET processing_deadline = $1 WHERE id = $2")
+            .bind(deadline.unwrap())
+            .bind(id)
+            .execute(&mut *conn)
+            .await?;
+        Ok(())
+    }
+
+    #[instrument(skip_all)]
+    async fn delete_activation(&self, id: &str) -> Result<(), Error> {
+        let mut conn = self.acquire_write_conn_metric("delete_activation").await?;
+        sqlx::query("DELETE FROM inflight_taskactivations WHERE id = $1")
+            .bind(id)
+            .execute(&mut *conn)
+            .await?;
+        Ok(())
+    }
+
+    #[instrument(skip_all)]
+    async fn get_retry_activations(&self) -> Result<Vec<InflightActivation>, Error> {
+        Ok(sqlx::query_as(
+            "
+            SELECT id,
+                activation,
+                partition,
+                kafka_offset AS offset,
+                added_at,
+                received_at,
+                processing_attempts,
+                expires_at,
+                delay_until,
+                processing_deadline_duration,
+                processing_deadline,
+                status,
+                at_most_once,
+                application,
+                namespace,
+                taskname,
+                on_attempts_exceeded
+            FROM inflight_taskactivations
+            WHERE status = $1
+            ",
+        )
+        .bind(InflightActivationStatus::Retry.to_string())
+        .fetch_all(&self.read_pool)
+        .await?
+        .into_iter()
+        .map(|row: TableRow| row.into())
+        .collect())
+    }
+
+    // Used in tests
+    async fn clear(&self) -> Result<(), Error> {
+        let mut conn = self.acquire_write_conn_metric("clear").await?;
+        sqlx::query("TRUNCATE TABLE inflight_taskactivations")
+            .execute(&mut *conn)
+            .await?;
+
+        Ok(())
+    }
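
To make the delay adjustment in `pending_activation_max_lag` concrete: a task received 30 seconds ago whose `delay_until` was 20 seconds after receipt only counts 10 seconds of lag, since it was not eligible to run during the delay window. In miniature:

    // Worked example: lag = (now - received_at) - (delay_until - received_at)
    fn demo_lag_adjustment() {
        let now_minus_received_ms: i64 = 30_000; // received 30s ago
        let delay_window_ms: i64 = 20_000; // delay_until was 20s after receipt
        let lag_secs = (now_minus_received_ms - delay_window_ms) as f64 / 1000.0;
        assert_eq!(lag_secs, 10.0);
    }
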
+
+    /// Update tasks that are in processing and have exceeded their processing deadline
+    /// Exceeding a processing deadline does not consume a retry as we don't know
+    /// if a worker took the task and was killed, or failed.
+    #[instrument(skip_all)]
+    async fn handle_processing_deadline(&self) -> Result<u64, Error> {
+        let now = Utc::now();
+        let mut atomic = self.write_pool.begin().await?;
+
+        // Idempotent tasks that fail their processing deadlines go directly to failure
+        // there are no retries, as the worker will reject the task due to idempotency keys.
+        let most_once_result = sqlx::query(
+            "UPDATE inflight_taskactivations
+            SET processing_deadline = null, status = $1
+            WHERE processing_deadline < $2 AND at_most_once = TRUE AND status = $3",
+        )
+        .bind(InflightActivationStatus::Failure.to_string())
+        .bind(now)
+        .bind(InflightActivationStatus::Processing.to_string())
+        .execute(&mut *atomic)
+        .await;
+
+        let mut processing_deadline_modified_rows = 0;
+        if let Ok(query_res) = most_once_result {
+            processing_deadline_modified_rows = query_res.rows_affected();
+        }
+
+        // Update non-idempotent tasks.
+        // Increment processing_attempts by 1 and reset processing_deadline to null.
+        let result = sqlx::query(
+            "UPDATE inflight_taskactivations
+            SET processing_deadline = null, status = $1, processing_attempts = processing_attempts + 1
+            WHERE processing_deadline < $2 AND status = $3",
+        )
+        .bind(InflightActivationStatus::Pending.to_string())
+        .bind(now)
+        .bind(InflightActivationStatus::Processing.to_string())
+        .execute(&mut *atomic)
+        .await;
+
+        atomic.commit().await?;
+
+        if let Ok(query_res) = result {
+            processing_deadline_modified_rows += query_res.rows_affected();
+            return Ok(processing_deadline_modified_rows);
+        }
+
+        Err(anyhow!("Could not update tasks past processing_deadline"))
+    }
+
+    /// Update tasks that have exceeded their max processing attempts.
+    /// These tasks are set to status=failure and will be handled by handle_failed_tasks accordingly.
+    #[instrument(skip_all)]
+    async fn handle_processing_attempts(&self) -> Result<u64, Error> {
+        let mut conn = self
+            .acquire_write_conn_metric("handle_processing_attempts")
+            .await?;
+        let processing_attempts_result = sqlx::query(
+            "UPDATE inflight_taskactivations
+            SET status = $1
+            WHERE processing_attempts >= $2 AND status = $3",
+        )
+        .bind(InflightActivationStatus::Failure.to_string())
+        .bind(self.config.max_processing_attempts as i32)
+        .bind(InflightActivationStatus::Pending.to_string())
+        .execute(&mut *conn)
+        .await;
+
+        if let Ok(query_res) = processing_attempts_result {
+            return Ok(query_res.rows_affected());
+        }
+
+        Err(anyhow!("Could not update tasks past max processing_attempts"))
+    }
+
+    /// Perform upkeep work for tasks that are past expires_at deadlines
+    ///
+    /// Tasks that are pending and past their expires_at deadline are
+    /// deleted from the store, as expired work should not be executed.
+    ///
+    /// The number of impacted records is returned in a Result.
+    #[instrument(skip_all)]
+    async fn handle_expires_at(&self) -> Result<u64, Error> {
+        let now = Utc::now();
+        let mut conn = self.acquire_write_conn_metric("handle_expires_at").await?;
+        let query = sqlx::query(
+            "DELETE FROM inflight_taskactivations WHERE status = $1 AND expires_at IS NOT NULL AND expires_at < $2",
+        )
+        .bind(InflightActivationStatus::Pending.to_string())
+        .bind(now)
+        .execute(&mut *conn)
+        .await?;
+
+        Ok(query.rows_affected())
+    }
+
+    /// Perform upkeep work for tasks that are past delay_until deadlines
+    ///
+    /// Tasks that are delayed and past their delay_until deadline are updated
+    /// to have status=pending so that they can be executed by workers
+    ///
+    /// The number of impacted records is returned in a Result.
+    #[instrument(skip_all)]
+    async fn handle_delay_until(&self) -> Result<u64, Error> {
+        let now = Utc::now();
+        let mut conn = self.acquire_write_conn_metric("handle_delay_until").await?;
+        let update_result = sqlx::query(
+            r#"UPDATE inflight_taskactivations
+            SET status = $1
+            WHERE delay_until IS NOT NULL AND delay_until < $2 AND status = $3
+            "#,
+        )
+        .bind(InflightActivationStatus::Pending.to_string())
+        .bind(now)
+        .bind(InflightActivationStatus::Delay.to_string())
+        .execute(&mut *conn)
+        .await?;
+
+        Ok(update_result.rows_affected())
+    }
+
+    /// Perform upkeep work related to status=failure
+    ///
+    /// Activations that are status=failure need to either be discarded by setting status=complete
+    /// or need to be moved to deadletter and are returned in the Result.
+    /// Once dead-lettered tasks have been added to Kafka those tasks can have their status set to
+    /// complete.
+    #[instrument(skip_all)]
+    async fn handle_failed_tasks(&self) -> Result<FailedTasksForwarder, Error> {
+        let mut atomic = self.write_pool.begin().await?;
+
+        let failed_tasks: Vec<PgRow> =
+            sqlx::query("SELECT id, activation, on_attempts_exceeded FROM inflight_taskactivations WHERE status = $1")
+                .bind(InflightActivationStatus::Failure.to_string())
+                .fetch_all(&mut *atomic)
+                .await?
+                .into_iter()
+                .collect();
+
+        let mut forwarder = FailedTasksForwarder {
+            to_discard: vec![],
+            to_deadletter: vec![],
+        };
+
+        for record in failed_tasks.iter() {
+            let activation_data: &[u8] = record.get("activation");
+            let id: String = record.get("id");
+            // We could be deadlettering because of activation.expires
+            // when a task expires we still deadletter if configured.
+            let on_attempts_exceeded_val: i32 = record.get("on_attempts_exceeded");
+            let on_attempts_exceeded: OnAttemptsExceeded =
+                on_attempts_exceeded_val.try_into().unwrap();
+            if on_attempts_exceeded == OnAttemptsExceeded::Discard
+                || on_attempts_exceeded == OnAttemptsExceeded::Unspecified
+            {
+                forwarder.to_discard.push((id, activation_data.to_vec()))
+            } else if on_attempts_exceeded == OnAttemptsExceeded::Deadletter {
+                forwarder.to_deadletter.push((id, activation_data.to_vec()))
+            }
+        }
+
+        if !forwarder.to_discard.is_empty() {
+            let mut query_builder = QueryBuilder::new("UPDATE inflight_taskactivations ");
+            query_builder
+                .push("SET status = ")
+                .push_bind(InflightActivationStatus::Complete.to_string())
+                .push(" WHERE id IN (");
+
+            let mut separated = query_builder.separated(", ");
+            for (id, _body) in forwarder.to_discard.iter() {
+                separated.push_bind(id);
+            }
+            separated.push_unseparated(")");
+
+            query_builder.build().execute(&mut *atomic).await?;
+        }
+
+        atomic.commit().await?;
+
+        Ok(forwarder)
+    }
+
+    /// Mark a collection of tasks as complete by id
+    #[instrument(skip_all)]
+    async fn mark_completed(&self, ids: Vec<String>) -> Result<u64, Error> {
+        let mut query_builder = QueryBuilder::new("UPDATE inflight_taskactivations ");
+        query_builder
+            .push("SET status = ")
+            .push_bind(InflightActivationStatus::Complete.to_string())
+            .push(" WHERE id IN (");
+
+        let mut separated = query_builder.separated(", ");
+        for id in ids.iter() {
+            separated.push_bind(id);
+        }
+        separated.push_unseparated(")");
+        let mut conn = self.acquire_write_conn_metric("mark_completed").await?;
+        let result = query_builder.build().execute(&mut *conn).await?;
+
+        Ok(result.rows_affected())
+    }
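
For context, the `FailedTasksForwarder` returned by `handle_failed_tasks` is consumed by the upkeep loop roughly like this; a hypothetical sketch, since the real publishing logic lives in `src/upkeep.rs`:

    // Illustrative sketch only, not part of the patch.
    async fn demo_deadletter_flow(store: &PostgresActivationStore) -> Result<(), Error> {
        let forwarder = store.handle_failed_tasks().await?;
        // Discarded tasks were already marked complete inside the transaction above.
        let ids: Vec<String> = forwarder
            .to_deadletter
            .iter()
            .map(|(id, _activation)| id.clone())
            .collect();
        // ... publish each deadletter payload to Kafka here ...
        store.mark_completed(ids).await?;
        Ok(())
    }
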
+
+    /// Remove completed tasks.
+    /// This method is a garbage collector for the inflight task store.
+    #[instrument(skip_all)]
+    async fn remove_completed(&self) -> Result<u64, Error> {
+        let mut conn = self.acquire_write_conn_metric("remove_completed").await?;
+        let query = sqlx::query("DELETE FROM inflight_taskactivations WHERE status = $1")
+            .bind(InflightActivationStatus::Complete.to_string())
+            .execute(&mut *conn)
+            .await?;
+
+        Ok(query.rows_affected())
+    }
+
+    /// Remove killswitched tasks.
+    #[instrument(skip_all)]
+    async fn remove_killswitched(&self, killswitched_tasks: Vec<String>) -> Result<u64, Error> {
+        let mut query_builder =
+            QueryBuilder::new("DELETE FROM inflight_taskactivations WHERE taskname IN (");
+        let mut separated = query_builder.separated(", ");
+        for taskname in killswitched_tasks.iter() {
+            separated.push_bind(taskname);
+        }
+        separated.push_unseparated(")");
+        let mut conn = self
+            .acquire_write_conn_metric("remove_killswitched")
+            .await?;
+        let query = query_builder.build().execute(&mut *conn).await?;
+
+        Ok(query.rows_affected())
+    }
+
+    // Used in tests
+    async fn remove_db(&self) -> Result<(), Error> {
+        self.read_pool.close().await;
+        self.write_pool.close().await;
+        let default_pool = create_default_postgres_pool(&self.config.pg_url).await?;
+        let _ = sqlx::query(format!("DROP DATABASE {}", &self.config.pg_database_name).as_str())
+            .bind(&self.config.pg_database_name)
+            .execute(&default_pool)
+            .await;
+        default_pool.close().await;
+        Ok(())
+    }
+}
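
Because each Postgres-backed test run creates its own uniquely named database, the dual-adapter tests in this patch all follow the same shape, ending with `remove_db` to drop it:

    // The pattern used throughout the updated tests.
    #[tokio::test]
    #[rstest]
    #[case::sqlite("sqlite")]
    #[case::postgres("postgres")]
    async fn test_example(#[case] adapter: &str) {
        let store = create_test_store(adapter).await;
        // ... exercise and assert against `store` ...
        store.remove_db().await.unwrap(); // no-op for SQLite, drops the PG database
    }
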
diff --git a/src/test_utils.rs b/src/test_utils.rs
index 9df8ba05..eefc7540 100644
--- a/src/test_utils.rs
+++ b/src/test_utils.rs
@@ -1,13 +1,12 @@
 use futures::StreamExt;
 use prost::Message as ProstMessage;
-use rand::Rng;
 use rdkafka::{
     Message,
     admin::{AdminClient, AdminOptions, NewTopic, TopicReplication},
-    consumer::{Consumer, StreamConsumer},
+    consumer::{CommitMode, Consumer, StreamConsumer},
     producer::FutureProducer,
 };
-use std::{collections::HashMap, sync::Arc};
+use std::{collections::HashMap, env::var, sync::Arc};
 use uuid::Uuid;
 
 use crate::{
@@ -16,14 +15,25 @@ use crate::{
         InflightActivation, InflightActivationStatus, InflightActivationStore,
         InflightActivationStoreConfig, SqliteActivationStore,
     },
+    store::postgres_activation_store::{PostgresActivationStore, PostgresActivationStoreConfig},
 };
 use chrono::{Timelike, Utc};
 use sentry_protos::taskbroker::v1::{OnAttemptsExceeded, RetryState, TaskActivation};
 
-/// Generate a unique filename for isolated SQLite databases.
+pub fn get_pg_url() -> String {
+    var("TASKBROKER_PG_URL").unwrap_or("postgres://postgres:password@localhost:5432/".to_string())
+}
+
+pub fn get_pg_database_name() -> String {
+    let random_name = format!("a{}", Uuid::new_v4().to_string().replace("-", ""));
+    var("TASKBROKER_PG_DATABASE_NAME").unwrap_or(random_name)
+}
+
 pub fn generate_temp_filename() -> String {
-    let mut rng = rand::thread_rng();
-    format!("/var/tmp/{}-{}.sqlite", Utc::now(), rng.r#gen::<u64>())
+    format!(
+        "/tmp/taskbroker-test-{}",
+        Uuid::new_v4().to_string().replace("-", "")
+    )
 }
 
 /// Generate a unique alphanumeric string for namespaces (and possibly other purposes).
@@ -90,15 +100,25 @@ pub fn create_config() -> Arc<Config> {
 }
 
 /// Create an InflightActivationStore instance
-pub async fn create_test_store() -> Arc<dyn InflightActivationStore> {
-    Arc::new(
-        SqliteActivationStore::new(
-            &generate_temp_filename(),
-            InflightActivationStoreConfig::from_config(&create_integration_config()),
-        )
-        .await
-        .unwrap(),
-    )
+pub async fn create_test_store(adapter: &str) -> Arc<dyn InflightActivationStore> {
+    match adapter {
+        "sqlite" => Arc::new(
+            SqliteActivationStore::new(
+                &generate_temp_filename(),
+                InflightActivationStoreConfig::from_config(&create_integration_config()),
+            )
+            .await
+            .unwrap(),
+        ) as Arc<dyn InflightActivationStore>,
+        "postgres" => Arc::new(
+            PostgresActivationStore::new(PostgresActivationStoreConfig::from_config(
+                &create_integration_config(),
+            ))
+            .await
+            .unwrap(),
+        ) as Arc<dyn InflightActivationStore>,
+        _ => panic!("Invalid adapter: {}", adapter),
+    }
 }
 
 /// Create a Config instance that uses a testing topic
@@ -106,6 +126,8 @@ pub async fn create_test_store(adapter: &str) -> Arc<dyn InflightActivationStore> {
 /// with [`reset_topic`]
 pub fn create_integration_config() -> Arc<Config> {
     let config = Config {
+        pg_url: get_pg_url(),
+        pg_database_name: get_pg_database_name(),
         kafka_topic: "taskbroker-test".into(),
         kafka_auto_offset_reset: "earliest".into(),
         ..Config::default()
@@ -114,6 +136,18 @@ pub fn create_integration_config() -> Arc<Config> {
     Arc::new(config)
 }
 
+pub fn create_integration_config_with_topic(topic: String) -> Arc<Config> {
+    let config = Config {
+        pg_url: get_pg_url(),
+        pg_database_name: get_pg_database_name(),
+        kafka_topic: topic,
+        kafka_auto_offset_reset: "earliest".into(),
+        ..Config::default()
+    };
+
+    Arc::new(config)
+}
+
 /// Create a kafka producer for a given config
 pub fn create_producer(config: Arc<Config>) -> Arc<FutureProducer> {
     let producer: FutureProducer = config
@@ -166,6 +200,7 @@ pub async fn consume_topic(
     let mut stream = consumer.stream();
 
     let mut results: Vec<TaskActivation> = vec![];
+    let mut last_message = None;
     let start = Utc::now();
     loop {
         let current = Utc::now();
@@ -187,8 +222,12 @@ pub async fn consume_topic(
         let payload = message.payload().expect("Could not fetch message payload");
         let activation = TaskActivation::decode(payload).unwrap();
         results.push(activation);
+        last_message = Some(message);
+    }
+    // Commit the last message's offset so subsequent calls start from the next message
+    if let Some(msg) = last_message {
+        consumer.commit_message(&msg, CommitMode::Sync).unwrap();
     }
-
     results
 }
diff --git a/src/upkeep.rs b/src/upkeep.rs
index dbcfeb1e..fc485e13 100644
--- a/src/upkeep.rs
+++ b/src/upkeep.rs
@@ -510,6 +510,7 @@ mod tests {
     use chrono::{DateTime, TimeDelta, TimeZone, Utc};
     use prost::Message;
     use prost_types::Timestamp;
+    use rstest::rstest;
     use sentry_protos::taskbroker::v1::{OnAttemptsExceeded, RetryState, TaskActivation};
     use std::sync::Arc;
     use std::time::Duration;
@@ -520,29 +521,15 @@ mod tests {
     use crate::{
         config::Config,
         runtime_config::RuntimeConfigManager,
-        store::inflight_activation::{
-            InflightActivationStatus, InflightActivationStore, InflightActivationStoreConfig,
-            SqliteActivationStore,
-        },
+        store::inflight_activation::InflightActivationStatus,
         test_utils::{
             StatusCount, assert_counts, consume_topic, create_config, create_integration_config,
-            create_producer, generate_temp_filename, make_activations, replace_retry_state,
-            reset_topic,
+            create_integration_config_with_topic, create_producer, create_test_store,
+            make_activations, replace_retry_state, reset_topic,
         },
         upkeep::{create_retry_activation, do_upkeep},
     };
 
-    async fn create_inflight_store() -> Arc<dyn InflightActivationStore> {
-        let url = generate_temp_filename();
-        let config = create_integration_config();
-
-        Arc::new(
- SqliteActivationStore::new(&url, InflightActivationStoreConfig::from_config(&config)) - .await - .unwrap(), - ) - } - #[tokio::test] async fn test_retry_activation_sets_delay_with_delay_on_retry() { let inflight = make_activations(1).remove(0); @@ -625,14 +612,17 @@ mod tests { } #[tokio::test] - async fn test_retry_activation_is_appended_to_kafka() { - let config = create_integration_config(); + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_retry_activation_is_appended_to_kafka(#[case] adapter: &str) { + let config = create_integration_config_with_topic(format!("taskbroker-test-{}", adapter)); let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); reset_topic(config.clone()).await; let start_time = Utc::now(); let mut last_vacuum = Instant::now(); - let store = create_inflight_store().await; + let store = create_test_store(adapter).await; let producer = create_producer(config.clone()); let mut records = make_activations(2); @@ -706,10 +696,13 @@ mod tests { } #[tokio::test] - async fn test_processing_deadline_retains_future_deadline() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_processing_deadline_retains_future_deadline(#[case] adapter: &str) { let config = create_config(); let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); - let store = create_inflight_store().await; + let store = create_test_store(adapter).await; let producer = create_producer(config.clone()); let start_time = Utc::now() - Duration::from_secs(90); let mut last_vacuum = Instant::now(); @@ -741,10 +734,13 @@ mod tests { } #[tokio::test] - async fn test_processing_deadline_skip_past_deadline_after_startup() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_processing_deadline_skip_past_deadline_after_startup(#[case] adapter: &str) { let config = create_config(); let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); - let store = create_inflight_store().await; + let store = create_test_store(adapter).await; let producer = create_producer(config.clone()); let mut batch = make_activations(2); @@ -792,10 +788,13 @@ mod tests { } #[tokio::test] - async fn test_processing_deadline_updates_past_deadline() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_processing_deadline_updates_past_deadline(#[case] adapter: &str) { let config = create_config(); let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); - let store = create_inflight_store().await; + let store = create_test_store(adapter).await; let producer = create_producer(config.clone()); let start_time = Utc::now() - Duration::from_secs(90); let mut last_vacuum = Instant::now(); @@ -805,6 +804,7 @@ mod tests { batch[1].status = InflightActivationStatus::Processing; batch[1].processing_deadline = Some(Utc.with_ymd_and_hms(2024, 11, 14, 21, 22, 23).unwrap()); + batch[1].processing_attempts = 0; assert!(store.store(batch.clone()).await.is_ok()); // Should start off with one in processing @@ -815,6 +815,13 @@ mod tests { .unwrap(), 1 ); + assert_eq!( + store + .count_by_status(InflightActivationStatus::Pending) + .await + .unwrap(), + 1 + ); let result_context = do_upkeep( config, @@ -826,6 +833,7 @@ mod tests { ) .await; + println!("result_context: {:?}", result_context); // 0 processing, 2 pending now assert_eq!(result_context.processing_deadline_reset, 1); assert_counts( @@ -840,10 +848,13 @@ mod tests { } #[tokio::test] - async fn 
test_processing_deadline_discard_at_most_once() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_processing_deadline_discard_at_most_once(#[case] adapter: &str) { let config = create_config(); let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); - let store = create_inflight_store().await; + let store = create_test_store(adapter).await; let producer = create_producer(config.clone()); let start_time = Utc::now() - Duration::from_secs(90); let mut last_vacuum = Instant::now(); @@ -890,10 +901,13 @@ mod tests { } #[tokio::test] - async fn test_processing_attempts_exceeded_discard() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_processing_attempts_exceeded_discard(#[case] adapter: &str) { let config = create_config(); let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); - let store = create_inflight_store().await; + let store = create_test_store(adapter).await; let producer = create_producer(config.clone()); let start_time = Utc::now(); let mut last_vacuum = Instant::now(); @@ -941,12 +955,15 @@ mod tests { } #[tokio::test] - async fn test_remove_at_remove_failed_publish_to_kafka() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_remove_at_remove_failed_publish_to_kafka(#[case] adapter: &str) { let config = create_integration_config(); let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); reset_topic(config.clone()).await; - let store = create_inflight_store().await; + let store = create_test_store(adapter).await; let producer = create_producer(config.clone()); let start_time = Utc::now(); let mut last_vacuum = Instant::now(); @@ -992,10 +1009,13 @@ mod tests { } #[tokio::test] - async fn test_remove_failed_discard() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_remove_failed_discard(#[case] adapter: &str) { let config = create_config(); let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); - let store = create_inflight_store().await; + let store = create_test_store(adapter).await; let producer = create_producer(config.clone()); let start_time = Utc::now(); let mut last_vacuum = Instant::now(); @@ -1033,10 +1053,13 @@ mod tests { } #[tokio::test] - async fn test_expired_discard() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_expired_discard(#[case] adapter: &str) { let config = create_config(); let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); - let store = create_inflight_store().await; + let store = create_test_store(adapter).await; let producer = create_producer(config.clone()); let start_time = Utc::now(); let mut last_vacuum = Instant::now(); @@ -1100,10 +1123,13 @@ mod tests { } #[tokio::test] - async fn test_delay_elapsed() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_delay_elapsed(#[case] adapter: &str) { let config = create_config(); let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); - let store = create_inflight_store().await; + let store = create_test_store(adapter).await; let producer = create_producer(config.clone()); let start_time = Utc::now(); let mut last_vacuum = Instant::now(); @@ -1195,7 +1221,10 @@ mod tests { } #[tokio::test] - async fn test_forward_demoted_namespaces() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_forward_demoted_namespaces(#[case] adapter: &str) { // 
Create runtime config with demoted namespaces let config = create_config(); let test_yaml = r#" @@ -1209,7 +1238,7 @@ demoted_namespaces: fs::write(test_path, test_yaml).await.unwrap(); let runtime_config = Arc::new(RuntimeConfigManager::new(Some(test_path.to_string())).await); let producer = create_producer(config.clone()); - let store = create_inflight_store().await; + let store = create_test_store(adapter).await; let start_time = Utc::now(); let mut last_vacuum = Instant::now(); @@ -1246,11 +1275,14 @@ demoted_namespaces: 2, "two tasks should be marked as complete" ); - fs::remove_file(test_path).await.unwrap(); + let _ = fs::remove_file(test_path).await; } #[tokio::test] - async fn test_remove_killswitched() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_remove_killswitched(#[case] adapter: &str) { let config = create_config(); let test_yaml = r#" drop_task_killswitch: @@ -1263,7 +1295,7 @@ demoted_namespaces: let runtime_config = Arc::new(RuntimeConfigManager::new(Some(test_path.to_string())).await); let producer = create_producer(config.clone()); - let store = create_inflight_store().await; + let store = create_test_store(adapter).await; let start_time = Utc::now(); let mut last_vacuum = Instant::now(); @@ -1294,11 +1326,14 @@ demoted_namespaces: 3 ); - fs::remove_file(test_path).await.unwrap(); + let _ = fs::remove_file(test_path).await; } #[tokio::test] - async fn test_full_vacuum_on_upkeep() { + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_full_vacuum_on_upkeep(#[case] adapter: &str) { let raw_config = Config { full_vacuum_on_start: true, ..Default::default() @@ -1306,7 +1341,7 @@ demoted_namespaces: let config = Arc::new(raw_config); let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); - let store = create_inflight_store().await; + let store = create_test_store(adapter).await; let producer = create_producer(config.clone()); let start_time = Utc::now() - Duration::from_secs(90); let mut last_vacuum = Instant::now() - Duration::from_secs(60); From 04b13bd36825a32579e2b5f1e608b029f2158662 Mon Sep 17 00:00:00 2001 From: Evan Hicks Date: Tue, 13 Jan 2026 16:45:17 -0500 Subject: [PATCH 02/17] remove unnecessary Option --- src/store/postgres_activation_store.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/store/postgres_activation_store.rs b/src/store/postgres_activation_store.rs index 363bbda1..dd86697b 100644 --- a/src/store/postgres_activation_store.rs +++ b/src/store/postgres_activation_store.rs @@ -221,8 +221,8 @@ impl InflightActivationStore for PostgresActivationStore { b.push_bind(row.added_at); b.push_bind(row.received_at); b.push_bind(row.processing_attempts); - b.push_bind(Some(row.expires_at)); - b.push_bind(Some(row.delay_until)); + b.push_bind(row.expires_at); + b.push_bind(row.delay_until); b.push_bind(row.processing_deadline_duration); if let Some(deadline) = row.processing_deadline { b.push_bind(deadline); From 10196557c84bb25b5849e5c05a7ec1116f016ec4 Mon Sep 17 00:00:00 2001 From: Evan Hicks Date: Wed, 14 Jan 2026 16:59:49 -0500 Subject: [PATCH 03/17] some fixes --- .../0001_create_inflight_activations.sql | 2 +- src/config.rs | 29 ++++++++++-- src/grpc/server_tests.rs | 1 - src/main.rs | 8 ++-- src/store/postgres_activation_store.rs | 46 +++++++++++-------- src/test_utils.rs | 2 + src/upkeep.rs | 1 - 7 files changed, 59 insertions(+), 30 deletions(-) diff --git a/pg_migrations/0001_create_inflight_activations.sql 
b/pg_migrations/0001_create_inflight_activations.sql index 80b552db..ee8b26a4 100644 --- a/pg_migrations/0001_create_inflight_activations.sql +++ b/pg_migrations/0001_create_inflight_activations.sql @@ -13,7 +13,7 @@ CREATE TABLE IF NOT EXISTS inflight_taskactivations ( processing_deadline TIMESTAMPTZ, status TEXT NOT NULL, at_most_once BOOLEAN NOT NULL DEFAULT FALSE, - application TEXT NOT NULL DEFAULT '', + application TEXT NOT NULL, namespace TEXT NOT NULL, taskname TEXT NOT NULL, on_attempts_exceeded INTEGER NOT NULL DEFAULT 1 diff --git a/src/config.rs b/src/config.rs index af1d2fc1..8c163c57 100644 --- a/src/config.rs +++ b/src/config.rs @@ -8,6 +8,16 @@ use std::{borrow::Cow, collections::BTreeMap}; use crate::{Args, logging::LogFormat}; +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Deserialize, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum DatabaseAdapter { + /// SQLite database adapter + Sqlite, + + /// PostgreSQL database adapter + Postgres, +} + #[derive(PartialEq, Debug, Deserialize, Serialize)] pub struct Config { /// The sentry DSN to use for error reporting. @@ -121,7 +131,13 @@ pub struct Config { /// The number of ms for timeouts when publishing messages to kafka. pub kafka_send_timeout_ms: u64, - pub database_adapter: &'static str, + /// The database adapter to use for the inflight activation store. + pub database_adapter: DatabaseAdapter, + + /// Whether to run the migrations on the database. + /// This is only used by the postgres database adapter, since + /// in production the migrations shouldn't be run by the taskbroker. + pub run_migrations: bool, /// The url of the postgres database to use for the inflight activation store. pub pg_url: String, @@ -264,7 +280,8 @@ impl Default for Config { kafka_auto_offset_reset: "latest".to_owned(), kafka_send_timeout_ms: 500, db_path: "./taskbroker-inflight.sqlite".to_owned(), - database_adapter: "sqlite", + database_adapter: DatabaseAdapter::Sqlite, + run_migrations: false, pg_url: "postgres://postgres:password@sentry-postgres-1:5432/".to_owned(), pg_database_name: "taskbroker".to_owned(), db_write_failure_backoff_ms: 4000, @@ -404,7 +421,7 @@ impl Provider for Config { mod tests { use std::{borrow::Cow, collections::BTreeMap}; - use super::Config; + use super::{Config, DatabaseAdapter}; use crate::{Args, logging::LogFormat}; use figment::Jail; @@ -437,11 +454,12 @@ mod tests { log_format: json statsd_addr: 127.0.0.1:8126 default_metrics_tags: - key_1: value_1 + key_1: value_1 kafka_cluster: 10.0.0.1:9092,10.0.0.2:9092 kafka_topic: error-tasks kafka_deadletter_topic: error-tasks-dlq kafka_auto_offset_reset: earliest + database_adapter: postgres db_path: ./taskbroker-error.sqlite db_max_size: 3000000000 max_pending_count: 512 @@ -475,6 +493,7 @@ mod tests { assert_eq!(config.kafka_session_timeout_ms, 6000.to_owned()); assert_eq!(config.kafka_topic, "error-tasks".to_owned()); assert_eq!(config.kafka_deadletter_topic, "error-tasks-dlq".to_owned()); + assert_eq!(config.database_adapter, DatabaseAdapter::Postgres); assert_eq!(config.db_path, "./taskbroker-error.sqlite".to_owned()); assert_eq!(config.max_pending_count, 512); assert_eq!(config.max_processing_count, 512); @@ -491,11 +510,13 @@ mod tests { fn test_from_args_env_and_args() { Jail::expect_with(|jail| { jail.set_env("TASKBROKER_LOG_FILTER", "error"); + jail.set_env("TASKBROKER_DATABASE_ADAPTER", "postgres"); jail.set_env("TASKBROKER_MAX_PROCESSING_ATTEMPTS", "5"); let args = Args { config: None }; let config = Config::from_args(&args).unwrap(); 
         assert_eq!(config.log_filter, "error");
+        assert_eq!(config.database_adapter, DatabaseAdapter::Postgres);
         assert_eq!(config.max_processing_attempts, 5);
 
         Ok(())
diff --git a/src/grpc/server_tests.rs b/src/grpc/server_tests.rs
index c72b1b50..b99911a2 100644
--- a/src/grpc/server_tests.rs
+++ b/src/grpc/server_tests.rs
@@ -178,7 +178,6 @@ async fn test_set_task_status_success(#[case] adapter: &str) {
         }),
     };
     let response = service.set_task_status(Request::new(request)).await;
-    println!("response: {:?}", response);
     assert!(response.is_ok());
     let resp = response.unwrap();
     assert!(resp.get_ref().task.is_some());
diff --git a/src/main.rs b/src/main.rs
index c1221d1c..9a58ef39 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -15,7 +15,7 @@ use tracing::{debug, error, info, warn};
 use sentry_protos::taskbroker::v1::consumer_service_server::ConsumerServiceServer;
 
 use taskbroker::SERVICE_NAME;
-use taskbroker::config::Config;
+use taskbroker::config::{Config, DatabaseAdapter};
 use taskbroker::grpc::auth_middleware::AuthLayer;
 use taskbroker::grpc::metrics_middleware::MetricsLayer;
 use taskbroker::grpc::server::TaskbrokerServer;
@@ -67,18 +67,18 @@ async fn main() -> Result<(), Error> {
     metrics::init(metrics::MetricsConfig::from_config(&config));
 
     let store: Arc<dyn InflightActivationStore> = match config.database_adapter {
-        "sqlite" => Arc::new(
+        DatabaseAdapter::Sqlite => Arc::new(
             SqliteActivationStore::new(
                 &config.db_path,
                 InflightActivationStoreConfig::from_config(&config),
             )
             .await?,
         ),
-        "postgres" => Arc::new(
+        DatabaseAdapter::Postgres => Arc::new(
             PostgresActivationStore::new(PostgresActivationStoreConfig::from_config(&config))
                 .await?,
         ),
-        _ => panic!("Invalid database adapter: {}", config.database_adapter),
+        _ => panic!("Invalid database adapter: {:?}", config.database_adapter),
     };
 
     // If this is an environment where the topics might not exist, check and create them.
diff --git a/src/store/postgres_activation_store.rs b/src/store/postgres_activation_store.rs index dd86697b..3af96e15 100644 --- a/src/store/postgres_activation_store.rs +++ b/src/store/postgres_activation_store.rs @@ -45,6 +45,7 @@ pub async fn create_default_postgres_pool(url: &str) -> Result, E pub struct PostgresActivationStoreConfig { pub pg_url: String, pub pg_database_name: String, + pub run_migrations: bool, pub max_processing_attempts: usize, pub processing_deadline_grace_sec: u64, pub vacuum_page_count: Option, @@ -56,6 +57,7 @@ impl PostgresActivationStoreConfig { Self { pg_url: config.pg_url.clone(), pg_database_name: config.pg_database_name.clone(), + run_migrations: config.run_migrations, max_processing_attempts: config.max_processing_attempts, vacuum_page_count: config.vacuum_page_count, processing_deadline_grace_sec: config.processing_deadline_grace_sec, @@ -82,29 +84,35 @@ impl PostgresActivationStore { } pub async fn new(config: PostgresActivationStoreConfig) -> Result { - let default_pool = create_default_postgres_pool(&config.pg_url).await?; - - // Create the database if it doesn't exist - let row: (bool,) = sqlx::query_as( - "SELECT EXISTS ( SELECT 1 FROM pg_catalog.pg_database WHERE datname = $1 )", - ) - .bind(&config.pg_database_name) - .fetch_one(&default_pool) - .await?; + if config.run_migrations { + let default_pool = create_default_postgres_pool(&config.pg_url).await?; + + // Create the database if it doesn't exist + let row: (bool,) = sqlx::query_as( + "SELECT EXISTS ( SELECT 1 FROM pg_catalog.pg_database WHERE datname = $1 )", + ) + .bind(&config.pg_database_name) + .fetch_one(&default_pool) + .await?; - if !row.0 { - println!("Creating database {}", &config.pg_database_name); - sqlx::query(format!("CREATE DATABASE {}", &config.pg_database_name).as_str()) - .bind(&config.pg_database_name) - .execute(&default_pool) - .await?; + if !row.0 { + println!("Creating database {}", &config.pg_database_name); + sqlx::query(format!("CREATE DATABASE {}", &config.pg_database_name).as_str()) + .bind(&config.pg_database_name) + .execute(&default_pool) + .await?; + } + // Close the default pool + default_pool.close().await; } - // Close the default pool - default_pool.close().await; let (read_pool, write_pool) = create_postgres_pool(&config.pg_url, &config.pg_database_name).await?; - sqlx::migrate!("./pg_migrations").run(&write_pool).await?; + + if config.run_migrations { + println!("Running migrations on database"); + sqlx::migrate!("./pg_migrations").run(&write_pool).await?; + } Ok(Self { read_pool, @@ -407,7 +415,7 @@ impl InflightActivationStore for PostgresActivationStore { .bind(id) .fetch_optional(&mut *conn) .await?; - println!("result: {:?}", result); + let Some(row) = result else { return Ok(None); }; diff --git a/src/test_utils.rs b/src/test_utils.rs index eefc7540..16a6d99a 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -128,6 +128,7 @@ pub fn create_integration_config() -> Arc { let config = Config { pg_url: get_pg_url(), pg_database_name: get_pg_database_name(), + run_migrations: true, kafka_topic: "taskbroker-test".into(), kafka_auto_offset_reset: "earliest".into(), ..Config::default() @@ -140,6 +141,7 @@ pub fn create_integration_config_with_topic(topic: String) -> Arc { let config = Config { pg_url: get_pg_url(), pg_database_name: get_pg_database_name(), + run_migrations: true, kafka_topic: topic, kafka_auto_offset_reset: "earliest".into(), ..Config::default() diff --git a/src/upkeep.rs b/src/upkeep.rs index fc485e13..19c452d1 100644 --- 
a/src/upkeep.rs +++ b/src/upkeep.rs @@ -833,7 +833,6 @@ mod tests { ) .await; - println!("result_context: {:?}", result_context); // 0 processing, 2 pending now assert_eq!(result_context.processing_deadline_reset, 1); assert_counts( From 001ad1ad1e86d4a714f4288bfb4be5077a889eb2 Mon Sep 17 00:00:00 2001 From: Evan Hicks Date: Thu, 15 Jan 2026 11:40:12 -0500 Subject: [PATCH 04/17] updates --- default_migrations/0001_create_database.sql | 1 - src/config.rs | 2 +- src/store/postgres_activation_store.rs | 6 +++--- test_forward_task_due_to_demoted_namespace.yaml | 7 +++++++ 4 files changed, 11 insertions(+), 5 deletions(-) delete mode 100644 default_migrations/0001_create_database.sql create mode 100644 test_forward_task_due_to_demoted_namespace.yaml diff --git a/default_migrations/0001_create_database.sql b/default_migrations/0001_create_database.sql deleted file mode 100644 index 00d61748..00000000 --- a/default_migrations/0001_create_database.sql +++ /dev/null @@ -1 +0,0 @@ -CREATE DATABASE taskbroker; diff --git a/src/config.rs b/src/config.rs index 8c163c57..8ee0436e 100644 --- a/src/config.rs +++ b/src/config.rs @@ -283,7 +283,7 @@ impl Default for Config { database_adapter: DatabaseAdapter::Sqlite, run_migrations: false, pg_url: "postgres://postgres:password@sentry-postgres-1:5432/".to_owned(), - pg_database_name: "taskbroker".to_owned(), + pg_database_name: "default".to_owned(), db_write_failure_backoff_ms: 4000, db_insert_batch_max_len: 256, db_insert_batch_max_size: 16_000_000, diff --git a/src/store/postgres_activation_store.rs b/src/store/postgres_activation_store.rs index 3af96e15..d87100a0 100644 --- a/src/store/postgres_activation_store.rs +++ b/src/store/postgres_activation_store.rs @@ -501,8 +501,8 @@ impl InflightActivationStore for PostgresActivationStore { let now = Utc::now(); let mut atomic = self.write_pool.begin().await?; - // Idempotent tasks that fail their processing deadlines go directly to failure - // there are no retries, as the worker will reject the task due to idempotency keys. + // At-most-once tasks that fail their processing deadlines go directly to failure + // there are no retries, as the worker will reject the task due to at_most_once keys. let most_once_result = sqlx::query( "UPDATE inflight_taskactivations SET processing_deadline = null, status = $1 @@ -519,7 +519,7 @@ impl InflightActivationStore for PostgresActivationStore { processing_deadline_modified_rows = query_res.rows_affected(); } - // Update non-idempotent tasks. + // Update regular tasks. // Increment processing_attempts by 1 and reset processing_deadline to null. 
let result = sqlx::query( "UPDATE inflight_taskactivations diff --git a/test_forward_task_due_to_demoted_namespace.yaml b/test_forward_task_due_to_demoted_namespace.yaml new file mode 100644 index 00000000..b3dcb5c3 --- /dev/null +++ b/test_forward_task_due_to_demoted_namespace.yaml @@ -0,0 +1,7 @@ + +drop_task_killswitch: + - +demoted_namespaces: + - bad_namespace +demoted_topic_cluster: 0.0.0.0:9092 +demoted_topic: taskworker-demoted From 010bab31b5f725ee0677fa0edbf6d4942763ca26 Mon Sep 17 00:00:00 2001 From: Evan Hicks Date: Thu, 15 Jan 2026 13:01:20 -0500 Subject: [PATCH 05/17] nit --- src/store/postgres_activation_store.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/store/postgres_activation_store.rs b/src/store/postgres_activation_store.rs index d87100a0..ad1f67da 100644 --- a/src/store/postgres_activation_store.rs +++ b/src/store/postgres_activation_store.rs @@ -35,11 +35,11 @@ pub async fn create_postgres_pool( pub async fn create_default_postgres_pool(url: &str) -> Result, Error> { let conn_str = url.to_owned() + "/postgres"; - let read_pool = PgPoolOptions::new() + let default_pool = PgPoolOptions::new() .max_connections(64) .connect_with(PgConnectOptions::from_str(&conn_str)?) .await?; - Ok(read_pool) + Ok(default_pool) } pub struct PostgresActivationStoreConfig { From 51e7649e0515f0477818204ca8e10d3610e224e4 Mon Sep 17 00:00:00 2001 From: Evan Hicks Date: Thu, 15 Jan 2026 14:27:08 -0500 Subject: [PATCH 06/17] lint --- src/main.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main.rs b/src/main.rs index 9a58ef39..7970939d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -78,7 +78,6 @@ async fn main() -> Result<(), Error> { PostgresActivationStore::new(PostgresActivationStoreConfig::from_config(&config)) .await?, ), - _ => panic!("Invalid database adapter: {:?}", config.database_adapter), }; // If this is an environment where the topics might not exist, check and create them. 
From a7eaa6fb62ec6d0a177425f148a138e688046dbd Mon Sep 17 00:00:00 2001 From: Evan Hicks Date: Thu, 15 Jan 2026 14:48:53 -0500 Subject: [PATCH 07/17] add postgres to devservices dependencies --- devservices/config.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/devservices/config.yml b/devservices/config.yml index 839aa20c..d1d701de 100644 --- a/devservices/config.yml +++ b/devservices/config.yml @@ -15,12 +15,14 @@ x-sentry-service-config: repo_name: sentry-shared-redis branch: main repo_link: https://github.com/getsentry/sentry-shared-redis.git + postgres: + description: Shared instance of postgres used by sentry services taskbroker: description: Taskbroker service modes: default: [kafka] client: [kafka, redis] - containerized: [kafka, redis, taskbroker] + containerized: [kafka, redis, postgres, taskbroker] x-programs: devserver: From 234e722d67f4cbedd7903d2f135cb0b3fd035147 Mon Sep 17 00:00:00 2001 From: Evan Hicks Date: Thu, 15 Jan 2026 15:23:30 -0500 Subject: [PATCH 08/17] try to get postgres working --- devservices/config.yml | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/devservices/config.yml b/devservices/config.yml index d1d701de..b40ce7f6 100644 --- a/devservices/config.yml +++ b/devservices/config.yml @@ -20,8 +20,8 @@ x-sentry-service-config: taskbroker: description: Taskbroker service modes: - default: [kafka] - client: [kafka, redis] + default: [kafka, postgres] + client: [kafka, redis, postgres] containerized: [kafka, redis, postgres, taskbroker] x-programs: @@ -42,6 +42,38 @@ services: - orchestrator=devservices restart: unless-stopped platform: linux/amd64 + services: + postgres: + image: ghcr.io/getsentry/image-mirror-library-postgres:14-alpine + environment: + POSTGRES_HOST_AUTH_METHOD: trust + POSTGRES_DB: sentry + command: + [ + postgres, + -c, + wal_level=logical, + -c, + max_replication_slots=1, + -c, + max_wal_senders=1, + ] + healthcheck: + test: pg_isready -U postgres + interval: 5s + timeout: 5s + retries: 3 + networks: + - devservices + volumes: + - postgres-data:/var/lib/postgresql/data + ports: + - 127.0.0.1:5432:5432 + extra_hosts: + - host.docker.internal:host-gateway + labels: + - orchestrator=devservices + restart: unless-stopped networks: devservices: From 7ed5a7af76c96b9ba89c56c99d45e8b176c7389e Mon Sep 17 00:00:00 2001 From: Evan Hicks Date: Thu, 15 Jan 2026 15:29:34 -0500 Subject: [PATCH 09/17] fix --- devservices/config.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/devservices/config.yml b/devservices/config.yml index b40ce7f6..9c831372 100644 --- a/devservices/config.yml +++ b/devservices/config.yml @@ -42,7 +42,6 @@ services: - orchestrator=devservices restart: unless-stopped platform: linux/amd64 - services: postgres: image: ghcr.io/getsentry/image-mirror-library-postgres:14-alpine environment: From d08af4fac4a1b164f1bbb4cb2b39e2fc409fb738 Mon Sep 17 00:00:00 2001 From: Evan Hicks Date: Thu, 15 Jan 2026 15:33:17 -0500 Subject: [PATCH 10/17] volume --- devservices/config.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/devservices/config.yml b/devservices/config.yml index 9c831372..a44fd43f 100644 --- a/devservices/config.yml +++ b/devservices/config.yml @@ -78,3 +78,7 @@ networks: devservices: name: devservices external: true + + +volumes: + postgres-data: From 4c29aee5d226e3f35d6853cf088b01708201f4f3 Mon Sep 17 00:00:00 2001 From: Evan Hicks Date: Thu, 15 Jan 2026 15:41:27 -0500 Subject: [PATCH 11/17] fix test --- src/config.rs | 2 +- 1 file changed, 
1 insertion(+), 1 deletion(-) diff --git a/src/config.rs b/src/config.rs index 8ee0436e..046de4f7 100644 --- a/src/config.rs +++ b/src/config.rs @@ -454,7 +454,7 @@ mod tests { log_format: json statsd_addr: 127.0.0.1:8126 default_metrics_tags: - key_1: value_1 + key_1: value_1 kafka_cluster: 10.0.0.1:9092,10.0.0.2:9092 kafka_topic: error-tasks kafka_deadletter_topic: error-tasks-dlq From c8322ad72374b2c0d5388c1f4fcd51cd364f75a3 Mon Sep 17 00:00:00 2001 From: Evan Hicks Date: Thu, 15 Jan 2026 15:42:59 -0500 Subject: [PATCH 12/17] use repo --- devservices/config.yml | 39 ++++----------------------------------- 1 file changed, 4 insertions(+), 35 deletions(-) diff --git a/devservices/config.yml b/devservices/config.yml index a44fd43f..adcda65a 100644 --- a/devservices/config.yml +++ b/devservices/config.yml @@ -17,6 +17,10 @@ x-sentry-service-config: repo_link: https://github.com/getsentry/sentry-shared-redis.git postgres: description: Shared instance of postgres used by sentry services + remote: + repo_name: sentry-shared-postgres + branch: main + repo_link: https://github.com/getsentry/sentry-shared-postgres.git taskbroker: description: Taskbroker service modes: @@ -42,43 +46,8 @@ services: - orchestrator=devservices restart: unless-stopped platform: linux/amd64 - postgres: - image: ghcr.io/getsentry/image-mirror-library-postgres:14-alpine - environment: - POSTGRES_HOST_AUTH_METHOD: trust - POSTGRES_DB: sentry - command: - [ - postgres, - -c, - wal_level=logical, - -c, - max_replication_slots=1, - -c, - max_wal_senders=1, - ] - healthcheck: - test: pg_isready -U postgres - interval: 5s - timeout: 5s - retries: 3 - networks: - - devservices - volumes: - - postgres-data:/var/lib/postgresql/data - ports: - - 127.0.0.1:5432:5432 - extra_hosts: - - host.docker.internal:host-gateway - labels: - - orchestrator=devservices - restart: unless-stopped networks: devservices: name: devservices external: true - - -volumes: - postgres-data: From 086f6b54cf900f2038b8af3cd30643a2c083ae6e Mon Sep 17 00:00:00 2001 From: Evan Hicks Date: Thu, 15 Jan 2026 16:34:02 -0500 Subject: [PATCH 13/17] remove duplicate code --- src/store/inflight_activation.rs | 65 ++++++++----------- src/store/postgres_activation_store.rs | 33 +--------- ...forward_task_due_to_demoted_namespace.yaml | 7 -- 3 files changed, 27 insertions(+), 78 deletions(-) delete mode 100644 test_forward_task_due_to_demoted_namespace.yaml diff --git a/src/store/inflight_activation.rs b/src/store/inflight_activation.rs index 7d0034ab..8ec95b8a 100644 --- a/src/store/inflight_activation.rs +++ b/src/store/inflight_activation.rs @@ -332,7 +332,25 @@ pub trait InflightActivationStore: Send + Sync { &self, application: Option<&str>, namespace: Option<&str>, - ) -> Result, Error>; + ) -> Result, Error> { + // Convert single namespace to vector for internal use + let namespaces = namespace.map(|ns| vec![ns.to_string()]); + + // If a namespace filter is used, an application must also be used. + if namespaces.is_some() && application.is_none() { + warn!( + "Received request for namespaced task without application. 
namespaces = {namespaces:?}"
+            );
+            return Ok(None);
+        }
+        let result = self
+            .get_pending_activations_from_namespaces(application, namespaces.as_deref(), Some(1))
+            .await?;
+        if result.is_empty() {
+            return Ok(None);
+        }
+        Ok(Some(result[0].clone()))
+    }
 
     /// Get pending activations from specified namespaces
     async fn get_pending_activations_from_namespaces(
@@ -346,7 +364,10 @@ pub trait InflightActivationStore: Send + Sync {
     async fn pending_activation_max_lag(&self, now: &DateTime<Utc>) -> f64;
 
     /// Count activations with Pending status
-    async fn count_pending_activations(&self) -> Result<usize, Error>;
+    async fn count_pending_activations(&self) -> Result<usize, Error> {
+        self.count_by_status(InflightActivationStatus::Pending)
+            .await
+    }
 
     /// Count activations by status
     async fn count_by_status(&self, status: InflightActivationStatus) -> Result<usize, Error>;
@@ -402,7 +423,9 @@ pub trait InflightActivationStore: Send + Sync {
     async fn remove_killswitched(&self, killswitched_tasks: Vec<String>) -> Result<u64, Error>;
 
     /// Remove the database, used only in tests
-    async fn remove_db(&self) -> Result<(), Error>;
+    async fn remove_db(&self) -> Result<(), Error> {
+        Ok(())
+    }
 }
 
 pub struct SqliteActivationStore {
@@ -752,31 +775,6 @@ impl InflightActivationStore for SqliteActivationStore {
 
         meta_result
     }
 
-    #[instrument(skip_all)]
-    async fn get_pending_activation(
-        &self,
-        application: Option<&str>,
-        namespace: Option<&str>,
-    ) -> Result<Option<InflightActivation>, Error> {
-        // Convert single namespace to vector for internal use
-        let namespaces = namespace.map(|ns| vec![ns.to_string()]);
-
-        // If a namespace filter is used, an application must also be used.
-        if namespaces.is_some() && application.is_none() {
-            warn!(
-                "Received request for namespaced task without application. namespaces = {namespaces:?}"
-            );
-            return Ok(None);
-        }
-        let result = self
-            .get_pending_activations_from_namespaces(application, namespaces.as_deref(), Some(1))
-            .await?;
-        if result.is_empty() {
-            return Ok(None);
-        }
-        Ok(Some(result[0].clone()))
-    }
-
     /// Get a pending activation from specified namespaces
     /// If namespaces is None, gets from any namespace
     /// If namespaces is Some(&[...]), gets from those namespaces
@@ -879,12 +877,6 @@ impl InflightActivationStore for SqliteActivationStore {
         }
     }
 
-    #[instrument(skip_all)]
-    async fn count_pending_activations(&self) -> Result<usize, Error> {
-        self.count_by_status(InflightActivationStatus::Pending)
-            .await
-    }
-
     #[instrument(skip_all)]
     async fn count_by_status(&self, status: InflightActivationStatus) -> Result<usize, Error> {
         let result =
@@ -1224,9 +1216,4 @@ impl InflightActivationStore for SqliteActivationStore {
 
         Ok(query.rows_affected())
     }
-
-    // Used in tests
-    async fn remove_db(&self) -> Result<(), Error> {
-        Ok(())
-    }
 }
diff --git a/src/store/postgres_activation_store.rs b/src/store/postgres_activation_store.rs
index ad1f67da..fdfc493b 100644
--- a/src/store/postgres_activation_store.rs
+++ b/src/store/postgres_activation_store.rs
@@ -12,7 +12,7 @@ use sqlx::{
     postgres::{PgConnectOptions, PgPool, PgPoolOptions, PgRow},
 };
 use std::{str::FromStr, time::Instant};
-use tracing::{instrument, warn};
+use tracing::instrument;
 
 use crate::config::Config;
 
@@ -251,31 +251,6 @@ impl InflightActivationStore for PostgresActivationStore {
         Ok(query.execute(&mut *conn).await?.into())
     }
 
-    #[instrument(skip_all)]
-    async fn get_pending_activation(
-        &self,
-        application: Option<&str>,
-        namespace: Option<&str>,
-    ) -> Result<Option<InflightActivation>, Error> {
-        // Convert single namespace to vector for internal use
-        let namespaces = namespace.map(|ns| vec![ns.to_string()]);
-
-        // If a namespace filter
is used, an application must also be used. - if namespaces.is_some() && application.is_none() { - warn!( - "Received request for namespaced task without application. namespaces = {namespaces:?}" - ); - return Ok(None); - } - let result = self - .get_pending_activations_from_namespaces(application, namespaces.as_deref(), Some(1)) - .await?; - if result.is_empty() { - return Ok(None); - } - Ok(Some(result[0].clone())) - } - /// Get a pending activation from specified namespaces /// If namespaces is None, gets from any namespace /// If namespaces is Some(&[...]), gets from those namespaces @@ -377,12 +352,6 @@ impl InflightActivationStore for PostgresActivationStore { } } - #[instrument(skip_all)] - async fn count_pending_activations(&self) -> Result { - self.count_by_status(InflightActivationStatus::Pending) - .await - } - #[instrument(skip_all)] async fn count_by_status(&self, status: InflightActivationStatus) -> Result { let result = diff --git a/test_forward_task_due_to_demoted_namespace.yaml b/test_forward_task_due_to_demoted_namespace.yaml deleted file mode 100644 index b3dcb5c3..00000000 --- a/test_forward_task_due_to_demoted_namespace.yaml +++ /dev/null @@ -1,7 +0,0 @@ - -drop_task_killswitch: - - -demoted_namespaces: - - bad_namespace -demoted_topic_cluster: 0.0.0.0:9092 -demoted_topic: taskworker-demoted From c1b9f8ee5ac1d7e584696d190f869da927dc63f0 Mon Sep 17 00:00:00 2001 From: Evan Hicks Date: Thu, 15 Jan 2026 16:43:33 -0500 Subject: [PATCH 14/17] imports --- src/kafka/inflight_activation_writer.rs | 27 ++++----------- src/store/inflight_activation.rs | 10 +++--- src/store/inflight_activation_tests.rs | 44 +++++++++---------------- src/test_utils.rs | 32 +++++------------- 4 files changed, 35 insertions(+), 78 deletions(-) diff --git a/src/kafka/inflight_activation_writer.rs b/src/kafka/inflight_activation_writer.rs index b9eace94..cf953727 100644 --- a/src/kafka/inflight_activation_writer.rs +++ b/src/kafka/inflight_activation_writer.rs @@ -200,31 +200,16 @@ impl Reducer for InflightActivationWriter { #[cfg(test)] mod tests { - use super::{ActivationWriterConfig, InflightActivation, InflightActivationWriter, Reducer}; - use chrono::{DateTime, Utc}; - use prost::Message; - use prost_types::Timestamp; + use chrono::DateTime; use rstest::rstest; - use std::collections::HashMap; - - use crate::test_utils::create_test_store; - use sentry_protos::taskbroker::v1::OnAttemptsExceeded; - use sentry_protos::taskbroker::v1::TaskActivation; use super::{ActivationWriterConfig, InflightActivationWriter, Reducer}; - use crate::store::inflight_activation::InflightActivationStatus; - - use chrono::DateTime; - use std::sync::Arc; - - use crate::store::inflight_activation::InflightActivationBuilder; - use crate::store::inflight_activation::{ - InflightActivationStatus, InflightActivationStore, InflightActivationStoreConfig, - SqliteActivationStore, + use crate::{ + store::inflight_activation::{InflightActivationBuilder, InflightActivationStatus}, + test_utils::{ + TaskActivationBuilder, create_test_store, generate_unique_namespace, make_activations, + }, }; - use crate::test_utils::TaskActivationBuilder; - use crate::test_utils::generate_unique_namespace; - use crate::test_utils::make_activations; #[tokio::test] #[rstest] diff --git a/src/store/inflight_activation.rs b/src/store/inflight_activation.rs index 2cad357f..9de6bf2b 100644 --- a/src/store/inflight_activation.rs +++ b/src/store/inflight_activation.rs @@ -1,10 +1,4 @@ use anyhow::{Error, anyhow}; -use sqlx::postgres::PgQueryResult; -use 
std::fmt::Result as FmtResult; -use std::fmt::{Display, Formatter}; -use std::{str::FromStr, time::Instant}; - -use crate::config::Config; use async_trait::async_trait; use chrono::{DateTime, Utc}; use derive_builder::Builder; @@ -21,14 +15,18 @@ use sqlx::{ ConnectOptions, FromRow, Pool, QueryBuilder, Row, Sqlite, Type, migrate::MigrateDatabase, pool::{PoolConnection, PoolOptions}, + postgres::PgQueryResult, sqlite::{ SqliteAutoVacuum, SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqliteQueryResult, SqliteRow, SqliteSynchronous, }, }; +use std::fmt::{Display, Formatter, Result as FmtResult}; use std::{str::FromStr, time::Instant}; use tracing::{instrument, warn}; +use crate::config::Config; + /// The members of this enum should be synced with the members /// of InflightActivationStatus in sentry_protos #[derive(Clone, Copy, Debug, PartialEq, Eq, Type)] diff --git a/src/store/inflight_activation_tests.rs b/src/store/inflight_activation_tests.rs index b332ff9d..e14ba1e4 100644 --- a/src/store/inflight_activation_tests.rs +++ b/src/store/inflight_activation_tests.rs @@ -1,34 +1,22 @@ -use prost::Message; -use rstest::rstest; -use sqlx::{QueryBuilder, Sqlite}; -use std::collections::HashSet; -use std::collections::{HashMap, HashSet}; -use std::fs; -use std::io::Error; -use std::path::Path; -use std::sync::Arc; -use std::time::Duration; - -use crate::config::Config; -use crate::store::inflight_activation::{ - InflightActivationBuilder, InflightActivationStatus, InflightActivationStore, - InflightActivationStoreConfig, QueryResult, SqliteActivationStore, create_sqlite_pool, -}; -use crate::test_utils::{StatusCount, TaskActivationBuilder}; -use crate::test_utils::{ - assert_counts, create_integration_config, create_test_store, generate_temp_filename, - generate_unique_namespace, make_activations, make_activations_with_namespace, - replace_retry_state, -}; use chrono::{DateTime, SubsecRound, TimeZone, Utc}; -use sentry_protos::taskbroker::v1::{ - OnAttemptsExceeded, RetryState, TaskActivation, TaskActivationStatus, -}; +use rstest::rstest; use sentry_protos::taskbroker::v1::{OnAttemptsExceeded, RetryState, TaskActivationStatus}; use sqlx::{QueryBuilder, Sqlite}; -use std::fs; -use tokio::sync::broadcast; -use tokio::task::JoinSet; +use std::{collections::HashSet, fs, io::Error, path::Path, sync::Arc, time::Duration}; +use tokio::{sync::broadcast, task::JoinSet}; + +use crate::{ + config::Config, + store::inflight_activation::{ + InflightActivationBuilder, InflightActivationStatus, InflightActivationStore, + InflightActivationStoreConfig, QueryResult, SqliteActivationStore, create_sqlite_pool, + }, + test_utils::{ + StatusCount, TaskActivationBuilder, assert_counts, create_integration_config, + create_test_store, generate_temp_filename, generate_unique_namespace, make_activations, + make_activations_with_namespace, replace_retry_state, + }, +}; #[test] fn test_inflightactivation_status_is_completion() { diff --git a/src/test_utils.rs b/src/test_utils.rs index 8a9c34d5..392b48c1 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -1,42 +1,28 @@ use chrono::Utc; use futures::StreamExt; use prost::Message as ProstMessage; +use prost_types::Timestamp; use rdkafka::{ Message, admin::{AdminClient, AdminOptions, NewTopic, TopicReplication}, consumer::{CommitMode, Consumer, StreamConsumer}, producer::FutureProducer, }; -use std::{collections::HashMap, env::var, sync::Arc}; +use sentry_protos::taskbroker::v1::{self, OnAttemptsExceeded, RetryState, TaskActivation}; +use 
 use uuid::Uuid;
 
 use crate::{
     config::Config,
-    store::inflight_activation::{
-        InflightActivation, InflightActivationStatus, InflightActivationStore,
-        InflightActivationStoreConfig, SqliteActivationStore,
+    store::{
+        inflight_activation::{
+            InflightActivation, InflightActivationBuilder, InflightActivationStatus,
+            InflightActivationStore, InflightActivationStoreConfig, SqliteActivationStore,
+        },
+        postgres_activation_store::{PostgresActivationStore, PostgresActivationStoreConfig},
     },
-    store::postgres_activation_store::{PostgresActivationStore, PostgresActivationStoreConfig},
-};
-use prost_types::Timestamp;
-use rand::Rng;
-use rdkafka::Message;
-use rdkafka::admin::{AdminClient, AdminOptions, NewTopic, TopicReplication};
-use rdkafka::consumer::{Consumer, StreamConsumer};
-use rdkafka::producer::FutureProducer;
-use sentry_protos::taskbroker::v1::{self, OnAttemptsExceeded, RetryState, TaskActivation};
-use uuid::Uuid;
-
-use crate::config::Config;
-use crate::store::inflight_activation::{
-    InflightActivation, InflightActivationBuilder, InflightActivationStatus,
-    InflightActivationStore, InflightActivationStoreConfig, SqliteActivationStore,
 };
-use std::collections::HashMap;
-use std::sync::Arc;
-use std::time::SystemTime;
-
 /// Builder for `TaskActivation`. We cannot generate a builder automatically because `TaskActivation` is defined in `sentry-protos`.
 pub struct TaskActivationBuilder {
     pub id: Option<String>,

From 05125fe263525da6860762ab873dc85fdfa05e08 Mon Sep 17 00:00:00 2001
From: Evan Hicks
Date: Fri, 16 Jan 2026 15:40:07 -0500
Subject: [PATCH 15/17] feat(postgres): Change the postgres adapter to be
 partition aware

Have the postgres adapter fetch and do upkeep only on activations that
belong to the partitions the consumer is assigned. The broker can still
update tasks outside its partitions, in case a worker is connected to a
broker that is then rebalanced.

Change the consumer to pass the partitions to the store whenever
partitions are assigned.

This was originally tested with PARTITION BY, but that requires
manually keeping track of the partition tables, which isn't a desired
behaviour.
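
[Review note, not part of the patch] The mechanism is small enough to sketch
in isolation. Assuming sqlx 0.8 and i32 Kafka partition ids, each
partition-scoped query gains a filter roughly like the following; the free
function name `push_partition_filter` is hypothetical, and the adapter below
implements the same logic as its `add_partition_condition` method:

    use sqlx::{Postgres, QueryBuilder};

    // Append " WHERE partition IN (...)" or " AND partition IN (...)" to the
    // query, binding each assigned partition. With no assignment yet, the
    // query is left unfiltered.
    fn push_partition_filter(
        builder: &mut QueryBuilder<'_, Postgres>,
        partitions: &[i32],
        first_condition: bool,
    ) {
        if partitions.is_empty() {
            return;
        }
        builder.push(if first_condition { " WHERE" } else { " AND" });
        builder.push(" partition IN (");
        let mut separated = builder.separated(", ");
        for partition in partitions {
            separated.push_bind(*partition);
        }
        builder.push(")");
    }

Binding the ids instead of formatting them into the SQL keeps the statement
parameterized, and the empty-assignment fallback means queries still work
before the first rebalance delivers an assignment.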
---
 .../0001_create_inflight_activations.sql |   2 +
 src/kafka/consumer.rs                    |  10 +
 src/main.rs                              |   1 +
 src/store/inflight_activation.rs         |  55 ++--
 src/store/inflight_activation_tests.rs   |  99 +++++++
 src/store/postgres_activation_store.rs   | 247 +++++++++++-------
 src/test_utils.rs                        |  18 +-
 7 files changed, 309 insertions(+), 123 deletions(-)

diff --git a/pg_migrations/0001_create_inflight_activations.sql b/pg_migrations/0001_create_inflight_activations.sql
index ee8b26a4..d5cacb22 100644
--- a/pg_migrations/0001_create_inflight_activations.sql
+++ b/pg_migrations/0001_create_inflight_activations.sql
@@ -18,3 +18,5 @@ CREATE TABLE IF NOT EXISTS inflight_taskactivations (
     taskname TEXT NOT NULL,
     on_attempts_exceeded INTEGER NOT NULL DEFAULT 1
 );
+
+CREATE INDEX idx_activation_partition ON inflight_taskactivations (partition);
diff --git a/src/kafka/consumer.rs b/src/kafka/consumer.rs
index 16d9b1e8..24ee265e 100644
--- a/src/kafka/consumer.rs
+++ b/src/kafka/consumer.rs
@@ -1,3 +1,4 @@
+use crate::store::inflight_activation::InflightActivationStore;
 use anyhow::{Error, anyhow};
 use futures::{
     Stream, StreamExt,
@@ -44,6 +45,7 @@ use tracing::{debug, error, info, instrument, warn};
 pub async fn start_consumer(
     topics: &[&str],
     kafka_client_config: &ClientConfig,
+    activation_store: Arc<dyn InflightActivationStore>,
     spawn_actors: impl FnMut(
         Arc>,
         &BTreeSet<(String, i32)>,
@@ -68,6 +70,7 @@
     handle_events(
         consumer,
         event_receiver,
+        activation_store,
         client_shutdown_sender,
         spawn_actors,
     )
@@ -340,6 +343,7 @@ enum ConsumerState {
 pub async fn handle_events(
     consumer: Arc>,
     events: UnboundedReceiver<(Event, SyncSender<()>)>,
+    activation_store: Arc<dyn InflightActivationStore>,
     shutdown_client: oneshot::Sender<()>,
     mut spawn_actors: impl FnMut(
         Arc>,
@@ -372,6 +376,12 @@
         state = match (state, event) {
             (ConsumerState::Ready, Event::Assign(tpl)) => {
                 metrics::gauge!("arroyo.consumer.current_partitions").set(tpl.len() as f64);
+                // Note: This assumes we only process one topic per consumer.
+                let mut partitions = Vec::<i32>::new();
+                for (_, partition) in tpl.iter() {
+                    partitions.push(*partition);
+                }
+                activation_store.assign_partitions(partitions).unwrap();
                 ConsumerState::Consuming(spawn_actors(consumer.clone(), &tpl), tpl)
             }
             (ConsumerState::Ready, Event::Revoke(_)) => {
diff --git a/src/main.rs b/src/main.rs
index 7970939d..744de735 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -162,6 +162,7 @@ async fn main() -> Result<(), Error> {
     start_consumer(
         &[&consumer_config.kafka_topic],
         &consumer_config.kafka_consumer_config(),
+        consumer_store.clone(),
         processing_strategy!({
             err: OsStreamWriter::new(
diff --git a/src/store/inflight_activation.rs b/src/store/inflight_activation.rs
index 9de6bf2b..dc81f699 100644
--- a/src/store/inflight_activation.rs
+++ b/src/store/inflight_activation.rs
@@ -337,21 +337,13 @@ impl InflightActivationStoreConfig {
 #[async_trait]
 pub trait InflightActivationStore: Send + Sync {
-    /// Trigger incremental vacuum to reclaim free pages in the database
-    async fn vacuum_db(&self) -> Result<(), Error>;
-
-    /// Perform a full vacuum on the database
-    async fn full_vacuum_db(&self) -> Result<(), Error>;
-
-    /// Get the size of the database in bytes
-    async fn db_size(&self) -> Result<u64, Error>;
-
-    /// Get an activation by id
-    async fn get_by_id(&self, id: &str) -> Result<Option<InflightActivation>, Error>;
-
+    /// CONSUMER OPERATIONS
     /// Store a batch of activations
     async fn store(&self, batch: Vec<InflightActivation>) -> Result<QueryResult, Error>;
 
+    fn assign_partitions(&self, partitions: Vec<i32>) -> Result<(), Error>;
+
+    /// SERVER OPERATIONS
     /// Get a single pending activation, optionally filtered by namespace
     async fn get_pending_activation(
         &self,
@@ -385,6 +377,14 @@
         limit: Option<usize>,
     ) -> Result<Vec<InflightActivation>, Error>;
 
+    /// Update the status of a specific activation
+    async fn set_status(
+        &self,
+        id: &str,
+        status: InflightActivationStatus,
+    ) -> Result<Option<InflightActivation>, Error>;
+
+    /// COUNT OPERATIONS
     /// Get the age of the oldest pending activation in seconds
     async fn pending_activation_max_lag(&self, now: &DateTime<Utc>) -> f64;
 
@@ -400,12 +400,9 @@
     /// Count all activations
     async fn count(&self) -> Result<usize, Error>;
 
-    /// Update the status of a specific activation
-    async fn set_status(
-        &self,
-        id: &str,
-        status: InflightActivationStatus,
-    ) -> Result<Option<InflightActivation>, Error>;
+    /// ACTIVATION OPERATIONS
+    /// Get an activation by id
+    async fn get_by_id(&self, id: &str) -> Result<Option<InflightActivation>, Error>;
 
     /// Set the processing deadline for a specific activation
     async fn set_processing_deadline(
@@ -417,12 +414,20 @@
     /// Delete an activation by id
     async fn delete_activation(&self, id: &str) -> Result<(), Error>;
 
+    /// DATABASE OPERATIONS
+    /// Trigger incremental vacuum to reclaim free pages in the database
+    async fn vacuum_db(&self) -> Result<(), Error>;
+
+    /// Perform a full vacuum on the database
+    async fn full_vacuum_db(&self) -> Result<(), Error>;
+
+    /// Get the size of the database in bytes
+    async fn db_size(&self) -> Result<u64, Error>;
+
+    /// UPKEEP OPERATIONS
     /// Get all activations with status Retry
     async fn get_retry_activations(&self) -> Result<Vec<InflightActivation>, Error>;
 
-    /// Clear all activations from the store
-    async fn clear(&self) -> Result<(), Error>;
-
     /// Update tasks that exceeded their processing deadline
     async fn handle_processing_deadline(&self) -> Result<u64, Error>;
 
@@ -447,6 +452,10 @@
     /// Remove killswitched tasks
     async fn remove_killswitched(&self, killswitched_tasks: Vec<String>) -> Result<u64, Error>;
 
+    /// TEST OPERATIONS
+    /// Clear all activations from the store
+    async fn clear(&self) -> Result<(), Error>;
+
     /// Remove the database, used only in tests
     async fn remove_db(&self) -> Result<(), Error> {
         Ok(())
@@ -714,6 +723,10 @@ impl InflightActivationStore for SqliteActivationStore {
         Ok(Some(row.into()))
     }
 
+    fn assign_partitions(&self, partitions: Vec<i32>) -> Result<(), Error> {
+        Ok(())
+    }
+
     #[instrument(skip_all)]
     async fn store(&self, batch: Vec<InflightActivation>) -> Result<QueryResult, Error> {
         if batch.is_empty() {
diff --git a/src/store/inflight_activation_tests.rs b/src/store/inflight_activation_tests.rs
index e14ba1e4..151b7f44 100644
--- a/src/store/inflight_activation_tests.rs
+++ b/src/store/inflight_activation_tests.rs
@@ -585,6 +585,105 @@ async fn test_set_activation_status(#[case] adapter: &str) {
     store.remove_db().await.unwrap();
 }
 
+#[tokio::test]
+#[rstest]
+#[case::postgres("postgres")]
+async fn test_set_activation_status_with_partitions(#[case] adapter: &str) {
+    let store = create_test_store(adapter).await;
+
+    let mut batch = make_activations(2);
+    batch[1].partition = 1;
+    assert!(store.store(batch).await.is_ok());
+    assert_counts(
+        StatusCount {
+            pending: 1,
+            ..StatusCount::default()
+        },
+        store.as_ref(),
+    )
+    .await;
+
+    assert!(
+        store
+            .set_status("id_0", InflightActivationStatus::Failure)
+            .await
+            .is_ok()
+    );
+    assert_counts(
+        StatusCount {
+            failure: 1,
+            ..StatusCount::default()
+        },
+        store.as_ref(),
+    )
+    .await;
+
+    assert!(
+        store
+            .set_status("id_0", InflightActivationStatus::Pending)
+            .await
+            .is_ok()
+    );
+    assert_counts(
+        StatusCount {
+            pending: 1,
+            ..StatusCount::default()
+        },
+        store.as_ref(),
+    )
+    .await;
+    assert!(
+        store
+            .set_status("id_0", InflightActivationStatus::Failure)
+            .await
+            .is_ok()
+    );
+    assert!(
+        store
+            .set_status("id_1", InflightActivationStatus::Failure)
+            .await
+            .is_ok()
+    );
+    // The broker can update the status of an activation in a different partition, but
+    // it still should not be counted in its upkeep.
+    assert_counts(
+        StatusCount {
+            pending: 0,
+            failure: 1,
+            ..StatusCount::default()
+        },
+        store.as_ref(),
+    )
+    .await;
+    assert!(
+        store
+            .get_pending_activation(None, None)
+            .await
+            .unwrap()
+            .is_none()
+    );
+
+    let result = store
+        .set_status("not_there", InflightActivationStatus::Complete)
+        .await;
+    assert!(result.is_ok(), "no query error");
+
+    let activation = result.unwrap();
+    assert!(activation.is_none(), "no activation found");
+
+    let result = store
+        .set_status("id_0", InflightActivationStatus::Complete)
+        .await;
+    assert!(result.is_ok(), "no query error");
+
+    let result_opt = result.unwrap();
+    assert!(result_opt.is_some(), "activation should be returned");
+    let inflight = result_opt.unwrap();
+    assert_eq!(inflight.id, "id_0");
+    assert_eq!(inflight.status, InflightActivationStatus::Complete);
+    store.remove_db().await.unwrap();
+}
+
 #[tokio::test]
 #[rstest]
 #[case::sqlite("sqlite")]
diff --git a/src/store/postgres_activation_store.rs b/src/store/postgres_activation_store.rs
index fdfc493b..09f710d4 100644
--- a/src/store/postgres_activation_store.rs
+++ b/src/store/postgres_activation_store.rs
@@ -11,6 +11,8 @@ use sqlx::{
     pool::PoolConnection,
     postgres::{PgConnectOptions, PgPool, PgPoolOptions, PgRow},
 };
+use std::collections::HashMap;
+use std::sync::RwLock;
 use std::{str::FromStr, time::Instant};
 use tracing::instrument;
 
@@ -70,6 +72,7 @@ pub struct PostgresActivationStore {
     read_pool: PgPool,
     write_pool: PgPool,
     config: PostgresActivationStoreConfig,
+    partitions: RwLock<Vec<i32>>,
 }
 
 impl PostgresActivationStore {
@@ -118,8 +121,29 @@
             read_pool,
             write_pool,
             config,
+            partitions: RwLock::new(vec![]),
         })
     }
+
+    /// Add the partition condition to the query builder in a thread-safe manner
+    fn add_partition_condition(
+        &self,
+        query_builder: &mut QueryBuilder<'_, Postgres>,
+        first_condition: bool,
+    ) {
+        let partitions = self.partitions.read().unwrap();
+        let condition = if first_condition { "WHERE" } else { "AND" };
+        if !partitions.is_empty() {
+            query_builder.push(" ");
+            query_builder.push(condition);
+            query_builder.push(" partition IN (");
+            let mut separated = query_builder.separated(", ");
+            for partition in partitions.iter() {
+                separated.push_bind(*partition);
+            }
+            query_builder.push(")");
+        }
+    }
 }
 
 #[async_trait]
 impl InflightActivationStore for PostgresActivationStore {
@@ -187,6 +211,13 @@ impl InflightActivationStore for PostgresActivationStore {
         Ok(Some(row.into()))
     }
 
+    fn assign_partitions(&self, partitions: Vec<i32>) -> Result<(), Error> {
+        let mut write_guard = self.partitions.write().unwrap();
+        write_guard.clear();
+        write_guard.extend(partitions);
+        Ok(())
+    }
+
     #[instrument(skip_all)]
     async fn store(&self, batch: Vec<InflightActivation>) -> Result<QueryResult, Error> {
         if batch.is_empty() {
@@ -275,6 +306,8 @@ impl InflightActivationStore for PostgresActivationStore {
         query_builder.push_bind(now);
         query_builder.push(")");
 
+        self.add_partition_condition(&mut query_builder, false);
+
         // Handle application & namespace filtering
         if let Some(value) = application {
             query_builder.push(" AND application =");
@@ -324,21 +357,25 @@ impl InflightActivationStore for PostgresActivationStore {
     /// Tasks with delay_until set, will have their age adjusted based on their
     /// delay time. No tasks = 0 lag
     async fn pending_activation_max_lag(&self, now: &DateTime<Utc>) -> f64 {
-        let result = sqlx::query(
+        let mut query_builder = QueryBuilder::new(
             "SELECT received_at, delay_until
             FROM inflight_taskactivations
-            WHERE status = $1
-            AND processing_attempts = 0
-            ORDER BY received_at ASC
-            LIMIT 1
-            ",
-        )
-        .bind(InflightActivationStatus::Pending.to_string())
-        .fetch_one(&self.read_pool)
-        .await;
+            WHERE status = ",
+        );
+        query_builder.push_bind(InflightActivationStatus::Pending.to_string());
+        query_builder.push(" AND processing_attempts = 0");
+
+        self.add_partition_condition(&mut query_builder, false);
+
+        query_builder.push(" ORDER BY received_at ASC LIMIT 1");
+
+        let result = query_builder
+            .build_query_as::<(DateTime<Utc>, Option<DateTime<Utc>>)>()
+            .fetch_one(&self.read_pool)
+            .await;
         if let Ok(row) = result {
-            let received_at: DateTime<Utc> = row.get("received_at");
-            let delay_until: Option<DateTime<Utc>> = row.get("delay_until");
+            let received_at: DateTime<Utc> = row.0;
+            let delay_until: Option<DateTime<Utc>> = row.1;
             let millis = now.signed_duration_since(received_at).num_milliseconds()
                 - delay_until.map_or(0, |delay_time| {
                     delay_time
@@ -354,19 +391,27 @@ impl InflightActivationStore for PostgresActivationStore {
     #[instrument(skip_all)]
     async fn count_by_status(&self, status: InflightActivationStatus) -> Result<usize, Error> {
-        let result =
-            sqlx::query("SELECT COUNT(*) as count FROM inflight_taskactivations WHERE status = $1")
-                .bind(status.to_string())
-                .fetch_one(&self.read_pool)
-                .await?;
-        Ok(result.get::<i64, _>("count") as usize)
+        let mut query_builder = QueryBuilder::new(
+            "SELECT COUNT(*) as count FROM inflight_taskactivations WHERE status = ",
+        );
+        query_builder.push_bind(status.to_string());
+        self.add_partition_condition(&mut query_builder, false);
+        let result = query_builder
+            .build_query_as::<(i64,)>()
+            .fetch_one(&self.read_pool)
+            .await?;
+        Ok(result.0 as usize)
     }
 
     async fn count(&self) -> Result<usize, Error> {
-        let result = sqlx::query("SELECT COUNT(*) as count FROM inflight_taskactivations")
+        let mut query_builder =
+            QueryBuilder::new("SELECT COUNT(*) as count FROM inflight_taskactivations");
+        self.add_partition_condition(&mut query_builder, true);
+        let result = query_builder
+            .build_query_as::<(i64,)>()
             .fetch_one(&self.read_pool)
             .await?;
-        Ok(result.get::<i64, _>("count") as usize)
+        Ok(result.0 as usize)
     }
 
     /// Update the status of a specific activation
@@ -421,9 +466,8 @@ impl InflightActivationStore for PostgresActivationStore {
     #[instrument(skip_all)]
     async fn get_retry_activations(&self) -> Result<Vec<InflightActivation>, Error> {
-        Ok(sqlx::query_as(
-            "
-            SELECT id,
+        let mut query_builder = QueryBuilder::new(
+            "SELECT id,
                 activation,
                 partition,
                 kafka_offset AS offset,
                 added_at,
                 received_at,
                 processing_attempts,
                 expires_at,
                 delay_until,
                 processing_deadline_duration,
                 processing_deadline,
                 status,
                 at_most_once,
                 application,
                 namespace,
                 taskname,
                 on_attempts_exceeded
             FROM inflight_taskactivations
-            WHERE status = $1
-            ",
-        )
-        .bind(InflightActivationStatus::Retry.to_string())
-        .fetch_all(&self.read_pool)
-        .await?
-        .into_iter()
-        .map(|row: TableRow| row.into())
-        .collect())
+            WHERE status = ",
+        );
+        query_builder.push_bind(InflightActivationStatus::Retry.to_string());
+        self.add_partition_condition(&mut query_builder, false);
+
+        Ok(query_builder
+            .build_query_as::<TableRow>()
+            .fetch_all(&self.read_pool)
+            .await?
+            .into_iter()
+            .map(|row| row.into())
+            .collect())
     }
 
     // Used in tests
@@ -472,16 +519,19 @@ impl InflightActivationStore for PostgresActivationStore {
 
         // At-most-once tasks that fail their processing deadlines go directly to failure
         // there are no retries, as the worker will reject the task due to at_most_once keys.
-        let most_once_result = sqlx::query(
+        let mut query_builder = QueryBuilder::new(
             "UPDATE inflight_taskactivations
-            SET processing_deadline = null, status = $1
-            WHERE processing_deadline < $2 AND at_most_once = TRUE AND status = $3",
-        )
-        .bind(InflightActivationStatus::Failure.to_string())
-        .bind(now)
-        .bind(InflightActivationStatus::Processing.to_string())
-        .execute(&mut *atomic)
-        .await;
+            SET processing_deadline = null, status = ",
+        );
+        query_builder.push_bind(InflightActivationStatus::Failure.to_string());
+        query_builder.push(" WHERE processing_deadline < ");
+        query_builder.push_bind(now);
+        query_builder.push(" AND at_most_once = TRUE AND status = ");
+        query_builder.push_bind(InflightActivationStatus::Processing.to_string());
+
+        self.add_partition_condition(&mut query_builder, false);
+
+        let most_once_result = query_builder.build().execute(&mut *atomic).await;
 
         let mut processing_deadline_modified_rows = 0;
         if let Ok(query_res) = most_once_result {
@@ -490,16 +540,19 @@ impl InflightActivationStore for PostgresActivationStore {
 
         // Update regular tasks.
         // Increment processing_attempts by 1 and reset processing_deadline to null.
-        let result = sqlx::query(
+        let mut query_builder = QueryBuilder::new(
             "UPDATE inflight_taskactivations
-            SET processing_deadline = null, status = $1, processing_attempts = processing_attempts + 1
-            WHERE processing_deadline < $2 AND status = $3",
-        )
-        .bind(InflightActivationStatus::Pending.to_string())
-        .bind(now)
-        .bind(InflightActivationStatus::Processing.to_string())
-        .execute(&mut *atomic)
-        .await;
+            SET processing_deadline = null, status = ",
+        );
+        query_builder.push_bind(InflightActivationStatus::Pending.to_string());
+        query_builder.push(", processing_attempts = processing_attempts + 1");
+        query_builder.push(" WHERE processing_deadline < ");
+        query_builder.push_bind(now);
+        query_builder.push(" AND status = ");
+        query_builder.push_bind(InflightActivationStatus::Processing.to_string());
+        self.add_partition_condition(&mut query_builder, false);
+
+        let result = query_builder.build().execute(&mut *atomic).await;
 
         atomic.commit().await?;
 
@@ -518,16 +571,17 @@ impl InflightActivationStore for PostgresActivationStore {
         let mut conn = self
             .acquire_write_conn_metric("handle_processing_attempts")
             .await?;
-        let processing_attempts_result = sqlx::query(
+        let mut query_builder = QueryBuilder::new(
             "UPDATE inflight_taskactivations
-            SET status = $1
-            WHERE processing_attempts >= $2 AND status = $3",
-        )
-        .bind(InflightActivationStatus::Failure.to_string())
-        .bind(self.config.max_processing_attempts as i32)
-        .bind(InflightActivationStatus::Pending.to_string())
-        .execute(&mut *conn)
-        .await;
+            SET status = ",
+        );
+        query_builder.push_bind(InflightActivationStatus::Failure.to_string());
+        query_builder.push(" WHERE processing_attempts >= ");
+        query_builder.push_bind(self.config.max_processing_attempts as i32);
+        query_builder.push(" AND status = ");
+        query_builder.push_bind(InflightActivationStatus::Pending.to_string());
+        self.add_partition_condition(&mut query_builder, false);
+        let processing_attempts_result = query_builder.build().execute(&mut *conn).await;
 
         if let Ok(query_res) = processing_attempts_result {
             return Ok(query_res.rows_affected());
@@ -546,15 +600,15 @@ impl InflightActivationStore for PostgresActivationStore {
     async fn handle_expires_at(&self) -> Result<u64, Error> {
         let now = Utc::now();
         let mut conn = self.acquire_write_conn_metric("handle_expires_at").await?;
-        let query = sqlx::query(
-            "DELETE FROM inflight_taskactivations WHERE status = $1 AND expires_at IS NOT NULL AND expires_at < $2",
-        )
-        .bind(InflightActivationStatus::Pending.to_string())
-        .bind(now)
-        .execute(&mut *conn)
-        .await?;
+        let mut query_builder =
+            QueryBuilder::new("DELETE FROM inflight_taskactivations WHERE status = ");
+        query_builder.push_bind(InflightActivationStatus::Pending.to_string());
+        query_builder.push(" AND expires_at IS NOT NULL AND expires_at < ");
+        query_builder.push_bind(now);
+        self.add_partition_condition(&mut query_builder, false);
+        let result = query_builder.build().execute(&mut *conn).await?;
 
-        Ok(query.rows_affected())
+        Ok(result.rows_affected())
     }
 
     /// Perform upkeep work for tasks that are past delay_until deadlines
@@ -567,17 +621,18 @@ impl InflightActivationStore for PostgresActivationStore {
     async fn handle_delay_until(&self) -> Result<u64, Error> {
         let now = Utc::now();
         let mut conn = self.acquire_write_conn_metric("handle_delay_until").await?;
-        let update_result = sqlx::query(
-            r#"UPDATE inflight_taskactivations
-            SET status = $1
-            WHERE delay_until IS NOT NULL AND delay_until < $2 AND status = $3
-            "#,
-        )
-        .bind(InflightActivationStatus::Pending.to_string())
-        .bind(now)
-        .bind(InflightActivationStatus::Delay.to_string())
-        .execute(&mut *conn)
-        .await?;
+
+        let mut query_builder = QueryBuilder::new(
+            "UPDATE inflight_taskactivations
+            SET status = ",
+        );
+        query_builder.push_bind(InflightActivationStatus::Pending.to_string());
+        query_builder.push(" WHERE delay_until IS NOT NULL AND delay_until < ");
+        query_builder.push_bind(now);
+        query_builder.push(" AND status = ");
+        query_builder.push_bind(InflightActivationStatus::Delay.to_string());
+        self.add_partition_condition(&mut query_builder, false);
+        let update_result = query_builder.build().execute(&mut *conn).await?;
 
         Ok(update_result.rows_affected())
     }
@@ -592,13 +647,15 @@ impl InflightActivationStore for PostgresActivationStore {
     async fn handle_failed_tasks(&self) -> Result<FailedTasksForwarder, Error> {
         let mut atomic = self.write_pool.begin().await?;
 
-        let failed_tasks: Vec<PgRow> =
-            sqlx::query("SELECT id, activation, on_attempts_exceeded FROM inflight_taskactivations WHERE status = $1")
-                .bind(InflightActivationStatus::Failure.to_string())
-                .fetch_all(&mut *atomic)
-                .await?
-                .into_iter()
-                .collect();
+        let mut query_builder = QueryBuilder::new(
+            "SELECT id, activation, on_attempts_exceeded FROM inflight_taskactivations WHERE status = ",
+        );
+        query_builder.push_bind(InflightActivationStatus::Failure.to_string());
+        self.add_partition_condition(&mut query_builder, false);
+        let failed_tasks = query_builder
+            .build_query_as::<(String, Vec<u8>, i32)>()
+            .fetch_all(&mut *atomic)
+            .await?;
 
         let mut forwarder = FailedTasksForwarder {
             to_discard: vec![],
             to_deadletter: vec![],
         };
 
         for record in failed_tasks.iter() {
-            let activation_data: &[u8] = record.get("activation");
-            let id: String = record.get("id");
+            let activation_data: &[u8] = record.1.as_slice();
+            let id: String = record.0.clone();
 
             // We could be deadlettering because of activation.expires
             // when a task expires we still deadletter if configured.
-            let on_attempts_exceeded_val: i32 = record.get("on_attempts_exceeded");
-            let on_attempts_exceeded: OnAttemptsExceeded =
-                on_attempts_exceeded_val.try_into().unwrap();
+            let on_attempts_exceeded: OnAttemptsExceeded = record.2.try_into().unwrap();
             if on_attempts_exceeded == OnAttemptsExceeded::Discard
                 || on_attempts_exceeded == OnAttemptsExceeded::Unspecified
             {
@@ -668,12 +723,13 @@ impl InflightActivationStore for PostgresActivationStore {
     #[instrument(skip_all)]
     async fn remove_completed(&self) -> Result<u64, Error> {
         let mut conn = self.acquire_write_conn_metric("remove_completed").await?;
-        let query = sqlx::query("DELETE FROM inflight_taskactivations WHERE status = $1")
-            .bind(InflightActivationStatus::Complete.to_string())
-            .execute(&mut *conn)
-            .await?;
+        let mut query_builder =
+            QueryBuilder::new("DELETE FROM inflight_taskactivations WHERE status = ");
+        query_builder.push_bind(InflightActivationStatus::Complete.to_string());
+        self.add_partition_condition(&mut query_builder, false);
+        let result = query_builder.build().execute(&mut *conn).await?;
 
-        Ok(query.rows_affected())
+        Ok(result.rows_affected())
     }
 
     /// Remove killswitched tasks.
@@ -686,6 +742,7 @@ impl InflightActivationStore for PostgresActivationStore {
             separated.push_bind(taskname);
         }
         separated.push_unseparated(")");
+        self.add_partition_condition(&mut query_builder, false);
        let mut conn = self
             .acquire_write_conn_metric("remove_killswitched")
             .await?;
diff --git a/src/test_utils.rs b/src/test_utils.rs
index 392b48c1..d5bd3310 100644
--- a/src/test_utils.rs
+++ b/src/test_utils.rs
@@ -257,13 +257,17 @@ pub async fn create_test_store(adapter: &str) -> Arc<dyn InflightActivationStore> {
         ) as Arc<dyn InflightActivationStore>,
-        "postgres" => Arc::new(
-            PostgresActivationStore::new(PostgresActivationStoreConfig::from_config(
-                &create_integration_config(),
-            ))
-            .await
-            .unwrap(),
-        ) as Arc<dyn InflightActivationStore>,
+        "postgres" => {
+            let store = Arc::new(
+                PostgresActivationStore::new(PostgresActivationStoreConfig::from_config(
+                    &create_integration_config(),
+                ))
+                .await
+                .unwrap(),
+            ) as Arc<dyn InflightActivationStore>;
+            store.assign_partitions(vec![0]).unwrap();
+            store
+        }
         _ => panic!("Invalid adapter: {}", adapter),
     }
 }

From d2e86f6235b5b14094d43bd3cf989fcbf79e3fc6 Mon Sep 17 00:00:00 2001
From: Evan Hicks
Date: Thu, 29 Jan 2026 16:51:29 -0500
Subject: [PATCH 16/17] cleanup

---
 src/store/postgres_activation_store.rs | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/store/postgres_activation_store.rs b/src/store/postgres_activation_store.rs
index 09f710d4..0214761b 100644
--- a/src/store/postgres_activation_store.rs
+++ b/src/store/postgres_activation_store.rs
@@ -7,11 +7,10 @@ use async_trait::async_trait;
 use chrono::{DateTime, Utc};
 use sentry_protos::taskbroker::v1::OnAttemptsExceeded;
 use sqlx::{
-    Pool, Postgres, QueryBuilder, Row,
+    Pool, Postgres, QueryBuilder,
     pool::PoolConnection,
-    postgres::{PgConnectOptions, PgPool, PgPoolOptions, PgRow},
+    postgres::{PgConnectOptions, PgPool, PgPoolOptions},
 };
-use std::collections::HashMap;
 use std::sync::RwLock;
 use std::{str::FromStr, time::Instant};
 use tracing::instrument;

From 94b41c23fc6c7f7cd5f43232c4497302efc30bdb Mon Sep 17 00:00:00 2001
From: Evan Hicks
Date: Thu, 29 Jan 2026 16:54:21 -0500
Subject: [PATCH 17/17] warn

---
 src/store/inflight_activation.rs | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/store/inflight_activation.rs b/src/store/inflight_activation.rs
index dc81f699..b76392c4 100644
--- a/src/store/inflight_activation.rs
+++ b/src/store/inflight_activation.rs
@@ -724,6 +724,7 @@ impl InflightActivationStore for SqliteActivationStore {
     }
 
     fn assign_partitions(&self, partitions: Vec<i32>) -> Result<(), Error> {
+        warn!("assign_partitions: {:?}", partitions);
         Ok(())
     }
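
[Review note, not part of the series] Every converted test follows the same
rstest shape; a minimal sketch of the pattern, assuming the `create_test_store`
and `make_activations` helpers and the `remove_db` method added in this series
(the test name is illustrative):

    use rstest::rstest;

    #[tokio::test]
    #[rstest]
    #[case::sqlite("sqlite")]
    #[case::postgres("postgres")]
    async fn test_store_roundtrip(#[case] adapter: &str) {
        // Builds a SqliteActivationStore or PostgresActivationStore behind
        // the InflightActivationStore trait object.
        let store = create_test_store(adapter).await;

        let batch = make_activations(1);
        assert!(store.store(batch).await.is_ok());
        assert_eq!(store.count().await.unwrap(), 1);

        // Postgres tests create a unique database per run, so clean it up.
        store.remove_db().await.unwrap();
    }

Each `#[case]` expands into its own test, so the SQLite and Postgres runs pass
or fail independently and are reported as separate cases.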