diff --git a/Cargo.lock b/Cargo.lock index b6736b11b082d..ad6e0b733ced3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5509,7 +5509,7 @@ dependencies = [ "mz-frontegg-auth", "mz-frontegg-client", "mz-ore", - "mz-sql-parser", + "mz-postgres-util", "open", "openssl-probe", "reqwest", @@ -6303,6 +6303,7 @@ dependencies = [ "mz-build-info", "mz-cloud-resources", "mz-ore", + "mz-postgres-util", "mz-server-core", "mz-sql-parser", "mz-tls-util", @@ -6476,6 +6477,7 @@ dependencies = [ "mz-pgtest", "mz-pgwire", "mz-pgwire-common", + "mz-postgres-util", "mz-prof-http", "mz-repr", "mz-secrets", @@ -6667,6 +6669,7 @@ dependencies = [ "mz-build-tools", "mz-ore", "mz-pgrepr", + "mz-postgres-util", "mz-sql-parser", "openssl", "postgres-openssl", @@ -8492,6 +8495,7 @@ dependencies = [ "mz-persist-client", "mz-persist-types", "mz-pgrepr", + "mz-postgres-util", "mz-repr", "mz-sql", "mz-sql-parser", diff --git a/clippy.toml b/clippy.toml index f0adb6ac4d12a..7dae30aad2d45 100644 --- a/clippy.toml +++ b/clippy.toml @@ -53,6 +53,20 @@ disallowed-methods = [ { path = "aws_config::from_env", reason = "use the `mz_aws_config::defaults` function instead" }, { path = "aws_config::load_from_env", reason = "use the `mz_aws_config::defaults` function instead" }, { path = "aws_sdk_s3::Client::new", reason = "use the `mz_aws_util::s3::new_client` function instead" }, + { path = "tokio_postgres::Client::simple_query", reason = "use `mz_postgres_util::simple_query` with `Sql` instead" }, + { path = "tokio_postgres::Client::batch_execute", reason = "use `mz_postgres_util::batch_execute` with `Sql` instead" }, + { path = "tokio_postgres::Client::query", reason = "use `mz_postgres_util::query`/`query_prepared` wrappers instead" }, + { path = "tokio_postgres::Client::query_one", reason = "use `mz_postgres_util::query_one`/`query_one_prepared` wrappers instead" }, + { path = "tokio_postgres::Client::query_opt", reason = "use `mz_postgres_util::query_opt`/`query_opt_prepared` wrappers instead" }, + { path 
= "tokio_postgres::Client::execute", reason = "use `mz_postgres_util::execute`/`execute_prepared` wrappers instead" }, + { path = "tokio_postgres::Transaction::simple_query", reason = "use `Sql` wrappers from `mz_postgres_util` instead" }, + { path = "tokio_postgres::Transaction::batch_execute", reason = "use `Sql` wrappers from `mz_postgres_util` instead" }, + { path = "tokio_postgres::Transaction::query", reason = "use `mz_postgres_util::query`/`query_prepared` wrappers instead" }, + { path = "tokio_postgres::Transaction::query_one", reason = "use `mz_postgres_util::query_one`/`query_one_prepared` wrappers instead" }, + { path = "tokio_postgres::Transaction::query_opt", reason = "use `mz_postgres_util::query_opt`/`query_opt_prepared` wrappers instead" }, + { path = "tokio_postgres::Transaction::execute", reason = "use `mz_postgres_util::execute`/`execute_prepared` wrappers instead" }, + { path = "postgres::Client::simple_query", reason = "avoid direct string SQL; use `Sql` composition and wrappers where possible" }, + { path = "postgres::Client::batch_execute", reason = "avoid direct string SQL; use `Sql` composition and wrappers where possible" }, # Prevent access to Differential APIs that want to use the default trace or use a default name, or where we offer # our own wrapper diff --git a/src/adapter/src/catalog.rs b/src/adapter/src/catalog.rs index c2939dd27f6cd..ff3a8612dfeb4 100644 --- a/src/adapter/src/catalog.rs +++ b/src/adapter/src/catalog.rs @@ -2328,6 +2328,7 @@ mod tests { use itertools::Itertools; use mz_catalog::memory::objects::CatalogItem; + use mz_postgres_util::{query, sql}; use tokio_postgres::NoTls; use tokio_postgres::types::Type; use uuid::Uuid; @@ -2877,8 +2878,9 @@ mod tests { name: String, } - let pg_proc: BTreeMap<_, _> = client - .query( + let pg_proc: BTreeMap<_, _> = query( + &client, + sql!( "SELECT p.oid, proname, @@ -2886,30 +2888,33 @@ mod tests { prorettype, proretset FROM pg_proc p - JOIN pg_namespace n ON p.pronamespace = 
n.oid", - &[], - ) - .await - .expect("pg query failed") - .into_iter() - .map(|row| { - let oid: u32 = row.get("oid"); - let pg_proc = PgProc { - name: row.get("proname"), - arg_oids: row.get("proargtypes"), - ret_oid: row.get("prorettype"), - ret_set: row.get("proretset"), - }; - (oid, pg_proc) - }) - .collect(); + JOIN pg_namespace n ON p.pronamespace = n.oid" + ), + &[], + ) + .await + .expect("pg query failed") + .into_iter() + .map(|row| { + let oid: u32 = row.get("oid"); + let pg_proc = PgProc { + name: row.get("proname"), + arg_oids: row.get("proargtypes"), + ret_oid: row.get("prorettype"), + ret_set: row.get("proretset"), + }; + (oid, pg_proc) + }) + .collect(); - let pg_type: BTreeMap<_, _> = client - .query( - "SELECT oid, typname, typtype::text, typelem, typarray, typinput::oid, typreceive::oid as typreceive FROM pg_type", - &[], - ) - .await + let pg_type: BTreeMap<_, _> = query( + &client, + sql!( + "SELECT oid, typname, typtype::text, typelem, typarray, typinput::oid, typreceive::oid as typreceive FROM pg_type" + ), + &[], + ) + .await .expect("pg query failed") .into_iter() .map(|row| { @@ -2926,20 +2931,23 @@ mod tests { }) .collect(); - let pg_oper: BTreeMap<_, _> = client - .query("SELECT oid, oprname, oprresult FROM pg_operator", &[]) - .await - .expect("pg query failed") - .into_iter() - .map(|row| { - let oid: u32 = row.get("oid"); - let pg_oper = PgOper { - name: row.get("oprname"), - oprresult: row.get("oprresult"), - }; - (oid, pg_oper) - }) - .collect(); + let pg_oper: BTreeMap<_, _> = query( + &client, + sql!("SELECT oid, oprname, oprresult FROM pg_operator"), + &[], + ) + .await + .expect("pg query failed") + .into_iter() + .map(|row| { + let oid: u32 = row.get("oid"); + let pg_oper = PgOper { + name: row.get("oprname"), + oprresult: row.get("oprresult"), + }; + (oid, pg_oper) + }) + .collect(); let conn_catalog = catalog.for_system_session(); let resolve_type_oid = |item: &str| { diff --git a/src/balancerd/tests/server.rs 
b/src/balancerd/tests/server.rs index de32c702d46a2..929f43b5017b0 100644 --- a/src/balancerd/tests/server.rs +++ b/src/balancerd/tests/server.rs @@ -47,6 +47,7 @@ use uuid::Uuid; #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] async fn test_balancer() { let ca = Ca::new_root("test ca").unwrap(); let (server_cert, server_key) = ca diff --git a/src/environmentd/Cargo.toml b/src/environmentd/Cargo.toml index 89b5fdd24fdad..2079de39216bc 100644 --- a/src/environmentd/Cargo.toml +++ b/src/environmentd/Cargo.toml @@ -69,6 +69,7 @@ mz-persist-client = { path = "../persist-client" } mz-pgrepr = { path = "../pgrepr" } mz-pgwire = { path = "../pgwire" } mz-pgwire-common = { path = "../pgwire-common" } +mz-postgres-util = { path = "../postgres-util" } mz-prof-http = { path = "../prof-http" } mz-repr = { path = "../repr" } mz-secrets = { path = "../secrets" } diff --git a/src/environmentd/src/test_util.rs b/src/environmentd/src/test_util.rs index a418ab783711a..a1b221ad8c6d5 100644 --- a/src/environmentd/src/test_util.rs +++ b/src/environmentd/src/test_util.rs @@ -51,6 +51,9 @@ use mz_persist_client::PersistLocation; use mz_persist_client::cache::PersistClientCache; use mz_persist_client::cfg::{CONSENSUS_CONNECTION_POOL_MAX_SIZE, PersistConfig}; use mz_persist_client::rpc::PersistGrpcPubSubServer; +use mz_postgres_util::{ + Sql, batch_execute as pg_batch_execute, execute as pg_execute, query_one as pg_query_one, sql, +}; use mz_secrets::SecretsController; use mz_server_core::listeners::{ AllowedRoles, AuthenticatorKind, BaseListenerConfig, HttpRoutesEnabled, @@ -668,12 +671,18 @@ impl Listeners { panic!("connection error: {}", err); }; }); - client - .batch_execute(&format!( - "CREATE SCHEMA IF NOT EXISTS consensus_{seed}; - CREATE SCHEMA IF NOT EXISTS tsoracle_{seed};" - )) - .await?; + let consensus_schema = sql!("consensus_{}", seed); + let tsoracle_schema = 
sql!("tsoracle_{}", seed); + pg_batch_execute( + &client, + sql!( + "CREATE SCHEMA IF NOT EXISTS {}; + CREATE SCHEMA IF NOT EXISTS {};", + consensus_schema, + tsoracle_schema, + ), + ) + .await?; ( format!("{cockroach_url}?options=--search_path=consensus_{seed}") .parse() @@ -888,10 +897,8 @@ impl TestServer { let internal_client = self.connect().internal().await.unwrap(); for flag in flags { - internal_client - .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag)) - .await - .unwrap(); + let query = sql!("ALTER SYSTEM SET {} = true;", Sql::ident(flag)); + pg_batch_execute(&internal_client, query).await.unwrap(); } } @@ -899,10 +906,8 @@ impl TestServer { let internal_client = self.connect().internal().await.unwrap(); for flag in flags { - internal_client - .batch_execute(&format!("ALTER SYSTEM SET {} = false;", flag)) - .await - .unwrap(); + let query = sql!("ALTER SYSTEM SET {} = false;", Sql::ident(flag)); + pg_batch_execute(&internal_client, query).await.unwrap(); } } @@ -1231,9 +1236,11 @@ impl TestServerWithRuntime { let mut internal_client = self.connect_internal(postgres::NoTls).unwrap(); for flag in flags { - internal_client - .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag)) - .unwrap(); + let query = sql!("ALTER SYSTEM SET {} = true;", Sql::ident(flag)); + // This uses the synchronous `postgres::Client`; wrappers are async + // and currently only defined for tokio-postgres clients. + #[allow(clippy::disallowed_methods)] + internal_client.batch_execute(query.as_str()).unwrap(); } } @@ -1242,9 +1249,11 @@ impl TestServerWithRuntime { let mut internal_client = self.connect_internal(postgres::NoTls).unwrap(); for flag in flags { - internal_client - .batch_execute(&format!("ALTER SYSTEM SET {} = false;", flag)) - .unwrap(); + let query = sql!("ALTER SYSTEM SET {} = false;", Sql::ident(flag)); + // This uses the synchronous `postgres::Client`; wrappers are async + // and currently only defined for tokio-postgres clients. 
+ #[allow(clippy::disallowed_methods)] + internal_client.batch_execute(query.as_str()).unwrap(); } } @@ -1357,8 +1366,11 @@ pub async fn insert_with_deterministic_timestamps( let mut current_ts = get_explain_timestamp(table, &client_read).await; - let insert_query = format!("INSERT INTO {table} VALUES {values}"); + let insert_query = format!("INSERT INTO {} VALUES {values}", Sql::ident(table)); + // The `values` fragment is raw SQL text in test code and cannot currently + // be represented as a composable `Sql` fragment. + #[allow(clippy::disallowed_methods)] let write_future = client_write.execute(&insert_query, &[]); let timestamp_interval = tokio::time::interval(Duration::from_millis(1)); @@ -1397,6 +1409,9 @@ pub async fn get_explain_timestamp_determination( from_suffix: &str, client: &Client, ) -> Result, anyhow::Error> { + // `from_suffix` is a raw SQL suffix used by this test helper and cannot + // currently be represented as a composable `Sql` fragment. + #[allow(clippy::disallowed_methods)] let row = client .query_one( &format!("EXPLAIN TIMESTAMP AS JSON FOR SELECT * FROM {from_suffix}"), @@ -1459,66 +1474,97 @@ pub async fn create_postgres_source_with_table<'a>( }); // Create table in Postgres with publication. 
- let _ = pg_client - .execute(&format!("DROP TABLE IF EXISTS {table_name};"), &[]) - .await - .unwrap(); - let _ = pg_client - .execute(&format!("DROP PUBLICATION IF EXISTS {source_name};"), &[]) - .await - .unwrap(); - let _ = pg_client - .execute(&format!("CREATE TABLE {table_name} {table_schema};"), &[]) - .await - .unwrap(); - let _ = pg_client - .execute( - &format!("ALTER TABLE {table_name} REPLICA IDENTITY FULL;"), - &[], - ) - .await - .unwrap(); + let _ = pg_execute( + &pg_client, + sql!("DROP TABLE IF EXISTS {};", Sql::ident(table_name)), + &[], + ) + .await + .unwrap(); + let _ = pg_execute( + &pg_client, + sql!("DROP PUBLICATION IF EXISTS {};", Sql::ident(source_name)), + &[], + ) + .await + .unwrap(); + // `table_schema` is a raw schema fragment in this test helper and cannot + // currently be represented as a composable `Sql` fragment. + #[allow(clippy::disallowed_methods)] let _ = pg_client .execute( - &format!("CREATE PUBLICATION {source_name} FOR TABLE {table_name};"), + format!("CREATE TABLE {} {table_schema};", Sql::ident(table_name)).as_str(), &[], ) .await .unwrap(); + let _ = pg_execute( + &pg_client, + sql!( + "ALTER TABLE {} REPLICA IDENTITY FULL;", + Sql::ident(table_name) + ), + &[], + ) + .await + .unwrap(); + let _ = pg_execute( + &pg_client, + sql!( + "CREATE PUBLICATION {} FOR TABLE {};", + Sql::ident(source_name), + Sql::ident(table_name) + ), + &[], + ) + .await + .unwrap(); // Create postgres source in Materialize. 
let mut connection_str = format!("HOST '{host}', PORT {port}, USER {user}, DATABASE {db_name}"); if let Some(password) = password { let password = std::str::from_utf8(password).unwrap(); - mz_client - .batch_execute(&format!("CREATE SECRET s AS '{password}'")) - .await - .unwrap(); - connection_str = format!("{connection_str}, PASSWORD SECRET s"); - } - mz_client - .batch_execute(&format!( - "CREATE CONNECTION pgconn TO POSTGRES ({connection_str})" - )) - .await - .unwrap(); - mz_client - .batch_execute(&format!( - "CREATE SOURCE {source_name} - FROM POSTGRES - CONNECTION pgconn - (PUBLICATION '{source_name}')" - )) + pg_batch_execute( + mz_client, + sql!("CREATE SECRET s AS {}", Sql::literal(password)), + ) .await .unwrap(); + connection_str = format!("{connection_str}, PASSWORD SECRET s"); + } + // `connection_str` is a raw connection-option fragment generated for tests + // and cannot currently be represented as a composable `Sql` fragment. + #[allow(clippy::disallowed_methods)] mz_client - .batch_execute(&format!( - "CREATE TABLE {table_name} - FROM SOURCE {source_name} - (REFERENCE {table_name});" - )) + .batch_execute(format!("CREATE CONNECTION pgconn TO POSTGRES ({connection_str})").as_str()) .await .unwrap(); + pg_batch_execute( + mz_client, + sql!( + "CREATE SOURCE {} \ + FROM POSTGRES \ + CONNECTION pgconn \ + (PUBLICATION {})", + Sql::ident(source_name), + Sql::literal(source_name), + ), + ) + .await + .unwrap(); + pg_batch_execute( + mz_client, + sql!( + "CREATE TABLE {} \ + FROM SOURCE {} \ + (REFERENCE {});", + Sql::ident(table_name), + Sql::ident(source_name), + Sql::ident(table_name), + ), + ) + .await + .unwrap(); let table_name = table_name.to_string(); let source_name = source_name.to_string(); @@ -1526,23 +1572,30 @@ pub async fn create_postgres_source_with_table<'a>( pg_client, move |mz_client: &'a Client, pg_client: &'a Client| { let f: Pin + 'a>> = Box::pin(async move { - mz_client - .batch_execute(&format!("DROP SOURCE {source_name} 
CASCADE;")) - .await - .unwrap(); - mz_client - .batch_execute("DROP CONNECTION pgconn;") + pg_batch_execute( + mz_client, + sql!("DROP SOURCE {} CASCADE;", Sql::ident(&source_name)), + ) + .await + .unwrap(); + pg_batch_execute(mz_client, sql!("DROP CONNECTION pgconn;")) .await .unwrap(); - let _ = pg_client - .execute(&format!("DROP PUBLICATION {source_name};"), &[]) - .await - .unwrap(); - let _ = pg_client - .execute(&format!("DROP TABLE {table_name};"), &[]) - .await - .unwrap(); + let _ = pg_execute( + pg_client, + sql!("DROP PUBLICATION {};", Sql::ident(&source_name)), + &[], + ) + .await + .unwrap(); + let _ = pg_execute( + pg_client, + sql!("DROP TABLE {};", Sql::ident(&table_name)), + &[], + ) + .await + .unwrap(); }); f }, @@ -1550,22 +1603,23 @@ pub async fn create_postgres_source_with_table<'a>( } pub async fn wait_for_pg_table_population(mz_client: &Client, view_name: &str, source_rows: i64) { - let current_isolation = mz_client - .query_one("SHOW transaction_isolation", &[]) + let current_isolation = pg_query_one(mz_client, sql!("SHOW transaction_isolation"), &[]) .await .unwrap() .get::<_, String>(0); - mz_client - .batch_execute("SET transaction_isolation = SERIALIZABLE") + pg_batch_execute(mz_client, sql!("SET transaction_isolation = SERIALIZABLE")) .await .unwrap(); Retry::default() .retry_async(|_| async move { - let rows = mz_client - .query_one(&format!("SELECT COUNT(*) FROM {view_name};"), &[]) - .await - .unwrap() - .get::<_, i64>(0); + let rows = pg_query_one( + mz_client, + sql!("SELECT COUNT(*) FROM {};", Sql::ident(view_name)), + &[], + ) + .await + .unwrap() + .get::<_, i64>(0); if rows == source_rows { Ok(()) } else { @@ -1576,12 +1630,15 @@ pub async fn wait_for_pg_table_population(mz_client: &Client, view_name: &str, s }) .await .unwrap(); - mz_client - .batch_execute(&format!( - "SET transaction_isolation = '{current_isolation}'" - )) - .await - .unwrap(); + pg_batch_execute( + mz_client, + sql!( + "SET transaction_isolation = {}", 
+ Sql::literal(&current_isolation), + ), + ) + .await + .unwrap(); } // Initializes a websocket connection. Returns the init messages before the initial ReadyForQuery. diff --git a/src/environmentd/tests/auth.rs b/src/environmentd/tests/auth.rs index 8c5d49f0abe4e..640a1efb40828 100644 --- a/src/environmentd/tests/auth.rs +++ b/src/environmentd/tests/auth.rs @@ -159,6 +159,7 @@ fn assert_http_rejected() -> Assert, String)>> { })) } +#[allow(clippy::disallowed_methods)] async fn run_tests<'a>(header: &str, server: &test_util::TestServer, tests: &[TestCase<'a>]) { println!("==> {}", header); for test in tests { @@ -438,6 +439,7 @@ async fn run_tests<'a>(header: &str, server: &test_util::TestServer, tests: &[Te #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` +#[allow(clippy::disallowed_methods)] async fn test_auth_expiry() { // This function verifies that the background expiry refresh task runs. This // is done by starting a web server that awaits the refresh request, which the @@ -1701,6 +1703,7 @@ async fn test_auth_oidc_audience_optional() { /// the OIDC authenticator falls back to password authentication. #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] +#[allow(clippy::disallowed_methods)] async fn test_auth_oidc_password_fallback() { let ca = Ca::new_root("test ca").unwrap(); let encoding_key = String::from_utf8(ca.pkey.private_key_to_pem_pkcs8().unwrap()).unwrap(); @@ -1786,6 +1789,7 @@ async fn test_auth_oidc_password_fallback() { /// runtime changes to the oidc_issuer and oidc_audience system parameters.
#[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] +#[allow(clippy::disallowed_methods)] async fn test_auth_oidc_config_switch() { let ca1 = Ca::new_root("test ca 1").unwrap(); let (server_cert, server_key) = ca1 @@ -2685,6 +2689,7 @@ async fn test_auth_admin_superuser() { #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` +#[allow(clippy::disallowed_methods)] async fn test_auth_admin_superuser_revoked() { let ca = Ca::new_root("test ca").unwrap(); let (server_cert, server_key) = ca @@ -2845,6 +2850,7 @@ async fn test_auth_admin_superuser_revoked() { #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` +#[allow(clippy::disallowed_methods)] async fn test_auth_deduplication() { let ca = Ca::new_root("test ca").unwrap(); let (server_cert, server_key) = ca @@ -3017,6 +3023,7 @@ async fn test_auth_deduplication() { #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` +#[allow(clippy::disallowed_methods)] async fn test_refresh_task_metrics() { let ca = Ca::new_root("test ca").unwrap(); let (server_cert, server_key) = ca @@ -3153,6 +3160,7 @@ async fn test_refresh_task_metrics() { #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` +#[allow(clippy::disallowed_methods)] async fn test_superuser_can_alter_cluster() { let ca = Ca::new_root("test ca").unwrap(); let (server_cert, server_key) = ca @@ -3304,6 +3312,7 @@ async fn test_superuser_can_alter_cluster() { #[mz_ore::test(tokio::test(flavor = 
"multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` +#[allow(clippy::disallowed_methods)] async fn test_refresh_dropped_session() { let ca = Ca::new_root("test ca").unwrap(); let (server_cert, server_key) = ca @@ -3471,6 +3480,7 @@ async fn test_refresh_dropped_session() { #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` +#[allow(clippy::disallowed_methods)] async fn test_refresh_dropped_session_lru() { let ca = Ca::new_root("test ca").unwrap(); let (server_cert, server_key) = ca @@ -3672,6 +3682,7 @@ async fn test_refresh_dropped_session_lru() { #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` +#[allow(clippy::disallowed_methods)] async fn test_transient_auth_failures() { let ca = Ca::new_root("test ca").unwrap(); let (server_cert, server_key) = ca @@ -3796,6 +3807,7 @@ async fn test_transient_auth_failures() { #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` +#[allow(clippy::disallowed_methods)] async fn test_transient_auth_failure_on_refresh() { let ca = Ca::new_root("test ca").unwrap(); let (server_cert, server_key) = ca @@ -3931,6 +3943,7 @@ async fn test_transient_auth_failure_on_refresh() { #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` +#[allow(clippy::disallowed_methods)] async fn test_password_auth() { let metrics_registry = MetricsRegistry::new(); @@ -4021,6 +4034,7 @@ async fn test_password_auth() { 
#[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` +#[allow(clippy::disallowed_methods)] async fn test_sasl_auth() { let metrics_registry = MetricsRegistry::new(); @@ -4076,6 +4090,7 @@ async fn test_sasl_auth() { #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` +#[allow(clippy::disallowed_methods)] async fn test_sasl_auth_failure() { let metrics_registry = MetricsRegistry::new(); @@ -4122,6 +4137,7 @@ async fn test_sasl_auth_failure() { #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` +#[allow(clippy::disallowed_methods)] async fn test_password_auth_superuser() { let metrics_registry = MetricsRegistry::new(); @@ -4177,6 +4193,7 @@ async fn test_password_auth_superuser() { #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` +#[allow(clippy::disallowed_methods)] async fn test_password_auth_alter_role() { let metrics_registry = MetricsRegistry::new(); @@ -4462,6 +4479,7 @@ async fn test_password_auth_http() { /// because internal_user_metadata was hardcoded to None. 
#[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `OPENSSL_init_ssl` on OS `linux` +#[allow(clippy::disallowed_methods)] async fn test_password_auth_http_superuser() { let metrics_registry = MetricsRegistry::new(); diff --git a/src/environmentd/tests/pgwire.rs b/src/environmentd/tests/pgwire.rs index bbe61eeaf4092..ff7f647a641c6 100644 --- a/src/environmentd/tests/pgwire.rs +++ b/src/environmentd/tests/pgwire.rs @@ -33,6 +33,7 @@ use postgres_array::{Array, Dimension}; use tokio::sync::mpsc; #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_bind_params() { let server = test_util::TestHarness::default() .unsafe_mode() @@ -262,6 +263,7 @@ fn test_read_many_rows() { } #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] +#[allow(clippy::disallowed_methods)] async fn test_conn_startup() { let server = test_util::TestHarness::default().start().await; let client = server.connect().await.unwrap(); @@ -432,6 +434,7 @@ async fn test_conn_startup() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_conn_user() { let server = test_util::TestHarness::default().start_blocking(); @@ -466,6 +469,7 @@ fn test_conn_user() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_simple_query_no_hang() { let server = test_util::TestHarness::default().start_blocking(); let mut client = server.connect(postgres::NoTls).unwrap(); @@ -540,6 +544,7 @@ fn test_copy() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_arrays() { let server = test_util::TestHarness::default() .unsafe_mode() @@ -593,6 +598,7 @@ fn test_arrays() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_record_types() { let server = test_util::TestHarness::default().start_blocking(); let mut client = server.connect(postgres::NoTls).unwrap(); diff --git a/src/environmentd/tests/server.rs b/src/environmentd/tests/server.rs index 
bf6e92e085e7a..20b47e019d292 100644 --- a/src/environmentd/tests/server.rs +++ b/src/environmentd/tests/server.rs @@ -72,8 +72,8 @@ use tungstenite::{Error, Message, Utf8Bytes}; use uuid::Uuid; // Allow the use of banned rdkafka methods, because we are just in tests. -#[allow(clippy::disallowed_methods)] #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_persistence() { let data_dir = tempfile::tempdir().unwrap(); let harness = test_util::TestHarness::default() @@ -196,6 +196,7 @@ impl TestServerWithStatementLoggingChecks { /// Helper to get statement logging record counts from the metrics registry. /// Returns (sampled_true_count, sampled_false_count). +#[allow(clippy::disallowed_methods)] fn get_statement_logging_record_counts( server: &TestServerWithStatementLoggingChecks, ) -> (u64, u64) { @@ -229,6 +230,7 @@ fn get_statement_logging_record_counts( } impl Drop for TestServerWithStatementLoggingChecks { + #[allow(clippy::disallowed_methods)] fn drop(&mut self) { // Don't run checks if we're already panicking, as this could mask the original error. if std::thread::panicking() { @@ -324,6 +326,7 @@ fn setup_statement_logging( // Test that we log various kinds of statement whose execution terminates in the coordinator. 
#[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_statement_logging_immediate() { let (server, mut client) = setup_statement_logging(1.0, 1.0, ""); @@ -468,6 +471,7 @@ fn test_statement_logging_immediate() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_statement_logging_basic() { let (server, mut client) = setup_statement_logging(1.0, 1.0, ""); client.execute("SELECT 1", &[]).unwrap(); @@ -666,6 +670,7 @@ ORDER BY mseh.began_at", ); } +#[allow(clippy::disallowed_methods)] fn run_throttling_test(use_prepared_statement: bool) { // The `target_data_rate` should be // - high enough so that the `SELECT 1` queries get throttled (even with high CPU load due to @@ -742,6 +747,7 @@ fn test_statement_logging_prepared_statement_throttling() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_statement_logging_subscribes() { let (server, mut client) = setup_statement_logging(1.0, 1.0, ""); let cancel_token = client.cancel_token(); @@ -834,6 +840,7 @@ fn test_statement_logging_subscribes() { /// Relies on two assumptions: /// (1) that the effective sampling rate for the session is 50%, /// (2) that we are using the deterministic testing RNG. +#[allow(clippy::disallowed_methods)] fn test_statement_logging_sampling_inner( server: TestServerWithStatementLoggingChecks, mut client: postgres::Client, @@ -913,6 +920,7 @@ fn test_statement_logging_sampling_constrained() { /// We set `sample_rate=0.0` so no statements are actually sampled/logged, but the /// unsampled_bytes metric still gets incremented for every executed statement. #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_statement_logging_unsampled_metrics() { // Use sample_rate=0.0 so statements are not sampled, but unsampled_bytes metric is still tracked. 
let (server, mut client) = setup_statement_logging(1.0, 0.0, ""); @@ -985,6 +993,7 @@ fn test_statement_logging_unsampled_metrics() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_enable_internal_statement_logging() { let (server, mut client) = setup_statement_logging_core( 1.0, @@ -1027,6 +1036,7 @@ WHERE authenticated_user='mz_system'", // Test the POST and WS server endpoints. #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_http_sql() { // Datadriven directives for WebSocket are "ws-text" and "ws-binary" to send // text or binary websocket messages that are the input. Output is @@ -1165,6 +1175,7 @@ fn test_http_sql() { // Test that the server properly handles cancellation requests. #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_cancel_long_running_query() { let server = test_util::TestHarness::default() .unsafe_mode() @@ -1206,6 +1217,7 @@ fn test_cancel_long_running_query() { .expect("simple query succeeds after cancellation"); } +#[allow(clippy::disallowed_methods)] fn test_cancellation_cancels_dataflows(query: &str) { // Query that returns how many dataflows are currently installed. // Accounts for the presence of introspection subscribe dataflows by ignoring those. @@ -1285,6 +1297,7 @@ fn test_cancel_insert_select() { ); } +#[allow(clippy::disallowed_methods)] fn test_closing_connection_cancels_dataflows(query: String) { // Query that returns how many dataflows are currently installed. // Accounts for the presence of introspection subscribe dataflows by ignoring those. @@ -1378,6 +1391,7 @@ fn test_closing_connection_for_insert_select() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_storage_usage_collection_interval() { /// Waits for the next storage collection to occur, then returns the /// timestamp at which the collection occurred. 
The timestamp of the last @@ -1519,6 +1533,7 @@ fn test_storage_usage_collection_interval() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_storage_usage_updates_between_restarts() { let data_dir = tempfile::tempdir().unwrap(); let storage_usage_collection_interval = Duration::from_secs(3); @@ -1569,6 +1584,7 @@ fn test_storage_usage_updates_between_restarts() { // Test that all rows for a single collection use the same timestamp. #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_storage_usage_collection_interval_timestamps() { let storage_interval_s = 2; let server = test_util::TestHarness::default() @@ -1626,6 +1642,7 @@ fn test_storage_usage_collection_interval_timestamps() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_old_storage_usage_records_are_reaped_on_restart() { let now = Arc::new(Mutex::new(0)); let now_fn = { @@ -1726,6 +1743,7 @@ fn test_old_storage_usage_records_are_reaped_on_restart() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_storage_usage_records_are_not_cleared_on_restart() { let data_dir = tempfile::tempdir().unwrap(); let collection_interval = Duration::from_secs(1); @@ -1811,6 +1829,7 @@ fn test_storage_usage_records_are_not_cleared_on_restart() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_default_cluster_sizes() { let server = test_util::TestHarness::default() .with_builtin_system_cluster_replica_size("scale=1,workers=1".to_string()) @@ -1844,6 +1863,7 @@ fn test_default_cluster_sizes() { #[mz_ore::test] #[ignore] // TODO: Reenable when https://github.com/MaterializeInc/database-issues/issues/6931 is fixed +#[allow(clippy::disallowed_methods)] fn test_max_request_size() { let statement = "SELECT $1::text"; let statement_size = statement.bytes().count(); @@ -1894,6 +1914,7 @@ fn test_max_request_size() { #[mz_ore::test] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] fn test_max_statement_batch_size() { let statement = 
"SELECT 1;"; let statement_size = statement.bytes().count(); @@ -1980,6 +2001,7 @@ fn test_max_statement_batch_size() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_mz_system_user_admin() { let server = test_util::TestHarness::default().start_blocking(); let mut client = server @@ -2198,6 +2220,7 @@ fn test_http_options_param() { #[mz_ore::test] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] fn test_max_connections_on_all_interfaces() { let query = "SELECT 1"; let server = test_util::TestHarness::default() @@ -2343,6 +2366,7 @@ fn test_max_connections_on_all_interfaces() { // Test max_connections and superuser_reserved_connections. #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] +#[allow(clippy::disallowed_methods)] async fn test_max_connections_limits() { let ca = Ca::new_root("test ca").unwrap(); let (server_cert, server_key) = ca @@ -2552,6 +2576,7 @@ async fn test_max_connections_limits() { #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] async fn test_concurrent_id_reuse() { let server = test_util::TestHarness::default().start().await; @@ -2798,6 +2823,7 @@ fn test_internal_ws_auth() { #[mz_ore::test] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] fn test_leader_promotion_always_using_deploy_generation() { let tmpdir = TempDir::new().unwrap(); let harness = test_util::TestHarness::default() @@ -2843,6 +2869,7 @@ fn test_leader_promotion_always_using_deploy_generation() { #[mz_ore::test(tokio::test(flavor = "multi_thread"))] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] async fn test_leader_promotion_mixed_code_version() { let tmpdir = TempDir::new().unwrap(); let this_version = mz_environmentd::BUILD_INFO.semver_version(); @@ -2904,6 +2931,7 @@ async fn test_leader_promotion_mixed_code_version() { // Test that websockets observe 
cancellation. #[mz_ore::test] #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `epoll_wait` on OS `linux` +#[allow(clippy::disallowed_methods)] fn test_cancel_ws() { let server = test_util::TestHarness::default().start_blocking(); let mut client = server.connect(postgres::NoTls).unwrap(); @@ -2948,6 +2976,7 @@ fn test_cancel_ws() { #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] async fn smoketest_webhook_source() { let server = test_util::TestHarness::default().start().await; let client = server.connect().await.unwrap(); @@ -3067,6 +3096,7 @@ async fn smoketest_webhook_source() { #[mz_ore::test] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] fn test_invalid_webhook_body() { let server = test_util::TestHarness::default().start_blocking(); @@ -3150,6 +3180,7 @@ fn test_invalid_webhook_body() { #[mz_ore::test] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] fn test_webhook_duplicate_headers() { let server = test_util::TestHarness::default().start_blocking(); @@ -3190,6 +3221,7 @@ fn test_webhook_duplicate_headers() { // Test that websockets observe cancellation and leave the transaction in an idle state. #[mz_ore::test] #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `epoll_wait` on OS `linux` +#[allow(clippy::disallowed_methods)] fn test_github_20262() { let server = test_util::TestHarness::default().start_blocking(); let mut client = server.connect(postgres::NoTls).unwrap(); @@ -3257,6 +3289,7 @@ fn test_github_20262() { // See database-issues#6134. 
#[mz_ore::test] #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `epoll_wait` on OS `linux` +#[allow(clippy::disallowed_methods)] fn test_cancel_read_then_write() { let server = test_util::TestHarness::default() .unsafe_mode() @@ -3404,6 +3437,7 @@ async fn test_http_metrics() { #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 2))] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] async fn webhook_concurrent_actions() { let server = test_util::TestHarness::default().start().await; let client = server.connect().await.unwrap(); @@ -3558,6 +3592,7 @@ async fn webhook_concurrent_actions() { #[mz_ore::test] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] fn webhook_concurrency_limit() { let concurrency_limit = 15; let server = test_util::TestHarness::default() @@ -3648,6 +3683,7 @@ fn webhook_concurrency_limit() { #[mz_ore::test] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] fn webhook_too_large_request() { let metrics_registry = MetricsRegistry::new(); let server = test_util::TestHarness::default() @@ -3719,6 +3755,7 @@ fn webhook_too_large_request() { #[mz_ore::test] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] fn test_webhook_url_notice() { let server = test_util::TestHarness::default().start_blocking(); let (tx, mut rx) = futures::channel::mpsc::unbounded(); @@ -3782,6 +3819,7 @@ fn test_webhook_url_notice() { #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 2))] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] async fn webhook_concurrent_swap() { let server = test_util::TestHarness::default().start().await; let mut client = server.connect().await.unwrap(); @@ -3951,6 +3989,7 @@ async fn webhook_concurrent_swap() { #[mz_ore::test] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] fn copy_from() { let server = 
test_util::TestHarness::default().start_blocking(); let mut client = server.connect(postgres::NoTls).unwrap(); @@ -3976,6 +4015,7 @@ fn copy_from() { // Test that a cluster dropped mid transaction results in an error. #[mz_ore::test] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] fn concurrent_cluster_drop() { let server = test_util::TestHarness::default().start_blocking(); let mut txn_client = server.connect(postgres::NoTls).unwrap(); @@ -4012,6 +4052,7 @@ fn concurrent_cluster_drop() { // Test connection ID properties. #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] async fn test_connection_id() { let harness = test_util::TestHarness::default(); let envid = harness.environment_id.organization_id().as_u128(); @@ -4033,6 +4074,7 @@ async fn test_connection_id() { // Test connection ID properties. #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] async fn test_github_25388() { let server = test_util::TestHarness::default() .unsafe_mode() @@ -4134,6 +4176,7 @@ async fn test_github_25388() { #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] async fn test_webhook_source_batch_interval() { let server = test_util::TestHarness::default().start().await; let client = server.connect().await.unwrap(); @@ -4287,6 +4330,7 @@ async fn test_startup_cluster_notice_with_http_options() { #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] async fn test_startup_cluster_notice() { let server = test_util::TestHarness::default().start().await; @@ -4403,6 +4447,7 @@ async fn test_startup_cluster_notice() { #[mz_ore::test] #[cfg_attr(miri, ignore)] // too slow 
+#[allow(clippy::disallowed_methods)] fn test_durable_oids() { let data_dir = tempfile::tempdir().unwrap(); let harness = test_util::TestHarness::default().data_directory(data_dir.path()); @@ -4442,6 +4487,7 @@ fn test_durable_oids() { #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] async fn test_double_encoded_json() { let server = test_util::TestHarness::default().start().await; let client = server.connect().await.expect("success"); @@ -4548,6 +4594,7 @@ async fn test_double_encoded_json() { // Tests cert reloading of environmentd. #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] async fn test_cert_reloading() { let ca = Ca::new_root("test ca").unwrap(); let (server_cert, server_key) = ca @@ -4753,6 +4800,7 @@ async fn test_cert_reloading() { #[mz_ore::test] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] fn test_builtin_connection_alterations_are_preserved_across_restarts() { let data_dir = tempfile::tempdir().unwrap(); let harness = test_util::TestHarness::default() @@ -4852,6 +4900,7 @@ fn test_builtin_connection_alterations_are_preserved_across_restarts() { #[mz_ore::test] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] fn test_webhook_request_compression() { let server = test_util::TestHarness::default().start_blocking(); let mut client = server.connect(postgres::NoTls).unwrap(); diff --git a/src/environmentd/tests/sql.rs b/src/environmentd/tests/sql.rs index 0d466e1596750..fa406be6b9eb3 100644 --- a/src/environmentd/tests/sql.rs +++ b/src/environmentd/tests/sql.rs @@ -109,6 +109,7 @@ impl MockHttpServer { } #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] +#[allow(clippy::disallowed_methods)] async fn test_no_block() { // We manually time out the test because it's better than 
relying on CI to time out, because // an actual failure (as opposed to a CI timeout) causes `services.log` to be uploaded. @@ -211,6 +212,7 @@ async fn test_no_block() { /// Test that dropping a connection while a source is undergoing purification /// does not crash the server. #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] +#[allow(clippy::disallowed_methods)] async fn test_drop_connection_race() { let server = test_util::TestHarness::default() .unsafe_mode() @@ -311,6 +313,7 @@ async fn test_drop_connection_race() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_time() { let server = test_util::TestHarness::default() .unsafe_mode() @@ -360,6 +363,7 @@ fn test_time() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_subscribe_consolidation() { let server = test_util::TestHarness::default().start_blocking(); let mut client_writes = server.connect(postgres::NoTls).unwrap(); @@ -389,6 +393,7 @@ fn test_subscribe_consolidation() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_subscribe_negative_diffs() { let server = test_util::TestHarness::default().start_blocking(); let mut client_writes = server.connect(postgres::NoTls).unwrap(); @@ -440,6 +445,7 @@ fn test_subscribe_negative_diffs() { } #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] +#[allow(clippy::disallowed_methods)] async fn test_empty_subscribe_notice() { let server = test_util::TestHarness::default() .with_now(NOW_ZERO.clone()) @@ -479,6 +485,7 @@ async fn test_empty_subscribe_notice() { } #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] +#[allow(clippy::disallowed_methods)] async fn test_empty_subscribe_error() { let server = test_util::TestHarness::default() .with_now(NOW_ZERO.clone()) @@ -500,6 +507,7 @@ async fn test_empty_subscribe_error() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_subscribe_basic() { let server = test_util::TestHarness::default() 
.unsafe_mode() @@ -663,6 +671,7 @@ fn test_subscribe_basic() { /// batches and we won't yet insert a second row, we know that if we've seen a /// data row we will also see one progressed message. #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_subscribe_progress() { let server = test_util::TestHarness::default().start_blocking(); @@ -780,6 +789,7 @@ fn test_subscribe_progress() { // Verifies that subscribing to non-nullable columns with progress information // turns them into nullable columns. See database-issues#1946. #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_subscribe_progress_non_nullable_columns() { let server = test_util::TestHarness::default().start_blocking(); let mut client_writes = server.connect(postgres::NoTls).unwrap(); @@ -830,6 +840,7 @@ fn test_subscribe_progress_non_nullable_columns() { /// Verifies that we get continuous progress messages, regardless of if we /// receive data or not. #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_subcribe_continuous_progress() { let server = test_util::TestHarness::default().start_blocking(); let mut client_writes = server.connect(postgres::NoTls).unwrap(); @@ -914,6 +925,7 @@ fn test_subcribe_continuous_progress() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_subscribe_fetch_timeout() { let server = test_util::TestHarness::default().start_blocking(); let mut client = server.connect(postgres::NoTls).unwrap(); @@ -1010,6 +1022,7 @@ fn test_subscribe_fetch_timeout() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_subscribe_fetch_wait() { let server = test_util::TestHarness::default().start_blocking(); let mut client = server.connect(postgres::NoTls).unwrap(); @@ -1072,6 +1085,7 @@ fn test_subscribe_fetch_wait() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_subscribe_empty_upper_frontier() { let server = test_util::TestHarness::default().start_blocking(); let mut client = server.connect(postgres::NoTls).unwrap(); 
@@ -1092,6 +1106,7 @@ fn test_subscribe_empty_upper_frontier() { // Tests that a client that launches a non-terminating SUBSCRIBE and disconnects // does not keep the server alive forever. #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 2))] +#[allow(clippy::disallowed_methods)] async fn test_subscribe_shutdown() { let server = test_util::TestHarness::default().start().await; @@ -1120,6 +1135,7 @@ async fn test_subscribe_shutdown() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_subscribe_table_rw_timestamps() { let server = test_util::TestHarness::default().start_blocking(); let mut client_interactive = server.connect(postgres::NoTls).unwrap(); @@ -1204,6 +1220,7 @@ fn test_subscribe_table_rw_timestamps() { // Tests that temporary views created by one connection cannot be viewed // by another connection. #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_temporary_views() { let server = test_util::TestHarness::default().start_blocking(); let mut client_a = server.connect(postgres::NoTls).unwrap(); @@ -1234,6 +1251,7 @@ fn test_temporary_views() { // Test EXPLAIN TIMESTAMP with tables. 
#[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_explain_timestamp_table() { let server = test_util::TestHarness::default().start_blocking(); let mut client = server.connect(postgres::NoTls).unwrap(); @@ -1269,6 +1287,7 @@ lower: // Test `EXPLAIN TIMESTAMP AS JSON` #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_explain_timestamp_json() { let server = test_util::TestHarness::default().start_blocking(); let mut client = server.connect(postgres::NoTls).unwrap(); @@ -1291,6 +1310,7 @@ fn test_explain_timestamp_json() { // // GitHub issue # 18950 #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_transactional_explain_timestamps() { let server = test_util::TestHarness::default() .unsafe_mode() @@ -1414,6 +1434,7 @@ fn test_transactional_explain_timestamps() { #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(coverage, ignore)] // https://github.com/MaterializeInc/database-issues/issues/5600 #[ignore] // TODO: Reenable when https://github.com/MaterializeInc/database-issues/issues/8491 is fixed +#[allow(clippy::disallowed_methods)] async fn test_utilization_hold() { const THIRTY_DAYS_MS: u64 = 30 * 24 * 60 * 60 * 1000; // `mz_catalog_server` tests indexes, `quickstart` tests tables. @@ -1551,6 +1572,7 @@ async fn test_utilization_hold() { // of cancelled (sends a pgwire cancel request on a new connection). #[ignore] // TODO(necaris): Re-enable this as soon as possible #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 2))] +#[allow(clippy::disallowed_methods)] async fn test_github_12546() { let server = test_util::TestHarness::default() .with_propagate_crashes(false) @@ -1602,6 +1624,7 @@ async fn test_github_12546() { /// Regression test for database-issues#3721. 
#[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_github_3721() { let server = test_util::TestHarness::default().start_blocking(); @@ -1659,6 +1682,7 @@ fn test_github_3721() { #[mz_ore::test] // Tests github issue database-issues#3761 +#[allow(clippy::disallowed_methods)] fn test_subscribe_outlive_cluster() { let server = test_util::TestHarness::default().start_blocking(); @@ -1693,6 +1717,7 @@ fn test_subscribe_outlive_cluster() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_read_then_write_serializability() { let server = test_util::TestHarness::default().start_blocking(); @@ -1743,6 +1768,7 @@ fn test_read_then_write_serializability() { } #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] +#[allow(clippy::disallowed_methods)] async fn test_timestamp_recovery() { let now = Arc::new(Mutex::new(1)); let now_fn = { @@ -1788,6 +1814,7 @@ async fn test_strong_session_serializability() { test_session_linearizability("strong session serializable").await; } +#[allow(clippy::disallowed_methods)] async fn test_session_linearizability(isolation_level: &str) { // Set the timestamp to zero for deterministic initial timestamps. let now = Arc::new(Mutex::new(0)); @@ -1930,6 +1957,7 @@ fn test_internal_users() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_internal_users_cluster() { let server = test_util::TestHarness::default().start_blocking(); @@ -1951,6 +1979,7 @@ fn test_internal_users_cluster() { // Tests that you can have simultaneous connections on the internal and external ports without // crashing #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_internal_ports() { let server = test_util::TestHarness::default().start_blocking(); @@ -2008,6 +2037,7 @@ fn test_internal_ports() { // doesn't allow you to specify a connection and expect a failure which is // needed for this test. 
#[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_alter_system_invalid_param() { let server = test_util::TestHarness::default().start_blocking(); @@ -2037,6 +2067,7 @@ fn test_alter_system_invalid_param() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_concurrent_writes() { let server = test_util::TestHarness::default().start_blocking(); @@ -2094,6 +2125,7 @@ fn test_concurrent_writes() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_load_generator() { let server = test_util::TestHarness::default() .unsafe_mode() @@ -2140,6 +2172,7 @@ fn test_load_generator() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_support_user_permissions() { let server = test_util::TestHarness::default().start_blocking(); @@ -2183,6 +2216,7 @@ fn test_support_user_permissions() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_idle_in_transaction_session_timeout() { #[track_caller] fn assert_db_error(error: tokio_postgres::Error, message: &str) { @@ -2246,6 +2280,7 @@ fn test_idle_in_transaction_session_timeout() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_peek_on_dropped_cluster() { let server = test_util::TestHarness::default() .unsafe_mode() @@ -2303,6 +2338,7 @@ fn test_peek_on_dropped_cluster() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_emit_timestamp_notice() { let server = test_util::TestHarness::default().start_blocking(); @@ -2369,6 +2405,7 @@ fn test_emit_timestamp_notice() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_isolation_level_notice() { let server = test_util::TestHarness::default().start_blocking(); @@ -2404,6 +2441,7 @@ fn test_isolation_level_notice() { } #[test] // allow(test-attribute) +#[allow(clippy::disallowed_methods)] fn test_emit_tracing_notice() { let server = test_util::TestHarness::default() .with_enable_tracing(true) @@ -2443,6 +2481,7 @@ fn test_emit_tracing_notice() { } #[mz_ore::test] 
+#[allow(clippy::disallowed_methods)] fn test_subscribe_on_dropped_source() { fn test_subscribe_on_dropped_source_inner( server: &TestServerWithRuntime, @@ -2536,6 +2575,7 @@ fn test_subscribe_on_dropped_source() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_dont_drop_sinks_twice() { let server = test_util::TestHarness::default().start_blocking(); @@ -2581,6 +2621,7 @@ fn test_dont_drop_sinks_twice() { // This can almost be tested with SLT using the simple directive, but // we have no way to disconnect sessions using SLT. #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_mz_sessions() { let server = test_util::TestHarness::default().start_blocking(); @@ -2715,6 +2756,7 @@ fn test_mz_sessions() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_auto_run_on_introspection_feature_enabled() { // unsafe_mode enables the feature as a whole let server = test_util::TestHarness::default() @@ -2809,6 +2851,7 @@ fn test_auto_run_on_introspection_feature_enabled() { const INTROSPECTION_NOTICE: &str = "results from querying these objects depend on the current values of the `cluster` and `cluster_replica` session variables"; #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_auto_run_on_introspection_feature_disabled() { // unsafe_mode enables the feature as a whole let server = test_util::TestHarness::default() @@ -2885,6 +2928,7 @@ fn test_auto_run_on_introspection_feature_disabled() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_auto_run_on_introspection_per_replica_relations() { // unsafe_mode enables the feature as a whole let server = test_util::TestHarness::default() @@ -2964,6 +3008,7 @@ fn test_auto_run_on_introspection_per_replica_relations() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_pg_cancel_backend() { mz_ore::test::init_logging(); let server = test_util::TestHarness::default().start_blocking(); @@ -3087,6 +3132,7 @@ fn test_pg_cancel_backend() { // Test params in 
interesting places. #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_params() { mz_ore::test::init_logging(); let server = test_util::TestHarness::default().start_blocking(); @@ -3141,6 +3187,7 @@ fn test_params() { // Test pg_cancel_backend after the authenticated role is dropped. #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_pg_cancel_dropped_role() { let server = test_util::TestHarness::default().start_blocking(); let dropped_role = "r1"; @@ -3191,6 +3238,7 @@ fn test_pg_cancel_dropped_role() { } #[mz_ore::test] +#[allow(clippy::disallowed_methods)] fn test_peek_on_dropped_indexed_view() { let server = test_util::TestHarness::default().start_blocking(); @@ -3273,6 +3321,7 @@ fn test_peek_on_dropped_indexed_view() { /// Test AS OF in EXPLAIN. This output will only differ from the non-ASOF versions with RETAIN /// HISTORY, where the object and its indexes have differing compaction policies. #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] +#[allow(clippy::disallowed_methods)] async fn test_explain_as_of() { // TODO: This would be better in testdrive, but we'd need to support negative intervals in AS // OF first. @@ -3310,6 +3359,7 @@ async fn test_explain_as_of() { // Test that RETAIN HISTORY results in the since and upper being separated by the specified amount. 
#[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[ignore] // TODO: Reenable when database-issues#7450 is fixed +#[allow(clippy::disallowed_methods)] async fn test_retain_history() { let server = test_util::TestHarness::default().start().await; let client = server.connect().await.unwrap(); @@ -3445,6 +3495,7 @@ async fn test_retain_history() { } #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] +#[allow(clippy::disallowed_methods)] async fn test_temporal_static_queries() { let server = test_util::TestHarness::default().start().await; let client = server.connect().await.unwrap(); @@ -3545,6 +3596,7 @@ async fn test_temporal_static_queries() { #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] async fn test_constant_materialized_view() { let server = test_util::TestHarness::default().start().await; let client = server.connect().await.unwrap(); @@ -3581,6 +3633,7 @@ async fn test_constant_materialized_view() { #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] async fn test_explain_timestamp_blocking() { let server = test_util::TestHarness::default().start().await; server @@ -3638,6 +3691,7 @@ async fn test_explain_timestamp_on_const_with_temporal() { #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] async fn test_cancel_linearize_reads() { let server = test_util::TestHarness::default().start().await; server @@ -3685,6 +3739,7 @@ async fn test_cancel_linearize_reads() { #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] async fn test_cancel_linearize_read_then_writes() { let server = 
test_util::TestHarness::default().start().await; server @@ -3750,6 +3805,7 @@ async fn test_cancel_linearize_read_then_writes() { // Test that builtin objects are created in the schemas they advertise in builtin.rs. #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] async fn test_builtin_schemas() { let server = test_util::TestHarness::default().start().await; let client = server.connect().await.unwrap(); @@ -3795,6 +3851,7 @@ async fn test_builtin_schemas() { // others to fail. #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] async fn test_serialized_ddl_serial() { let server = test_util::TestHarness::default().start().await; let mut handles = Vec::new(); @@ -3833,6 +3890,7 @@ async fn test_serialized_ddl_serial() { // Test that serial DDLs are cancellable. #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] async fn test_serialized_ddl_cancel() { let server = test_util::TestHarness::default() .unsafe_mode() diff --git a/src/environmentd/tests/timezones.rs b/src/environmentd/tests/timezones.rs index ede5bd7420894..1e9d32f223808 100644 --- a/src/environmentd/tests/timezones.rs +++ b/src/environmentd/tests/timezones.rs @@ -37,6 +37,7 @@ use mz_ore::assert_none; use mz_pgrepr::Interval; #[mz_ore::test(tokio::test(flavor = "multi_thread", worker_threads = 1))] +#[allow(clippy::disallowed_methods)] async fn test_pg_timezone_abbrevs() { let server = test_util::TestHarness::default() .unsafe_mode() diff --git a/src/environmentd/tests/tracing.rs b/src/environmentd/tests/tracing.rs index e190dd7b244ca..f04e92c96fc4b 100644 --- a/src/environmentd/tests/tracing.rs +++ b/src/environmentd/tests/tracing.rs @@ -16,6 +16,7 @@ use tracing_capture::SharedStorage; // Test that expected 
spans are generated for various queries. #[tokio::test(flavor = "multi_thread", worker_threads = 2)] // allow(test-attribute) #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] async fn test_expected_spans() { let storage = SharedStorage::default(); // Sets of expected span names and SQL statements that should produce that span. @@ -88,6 +89,7 @@ async fn test_expected_spans() { // of the most common code paths. This is not guaranteed to exhaustively test all of our tracing. #[tokio::test(flavor = "multi_thread", worker_threads = 2)] // allow(test-attribute) #[cfg_attr(miri, ignore)] // too slow +#[allow(clippy::disallowed_methods)] async fn test_secrets() { let storage = SharedStorage::default(); diff --git a/src/fivetran-destination/Cargo.toml b/src/fivetran-destination/Cargo.toml index 356fdf971907e..26407121e3e8e 100644 --- a/src/fivetran-destination/Cargo.toml +++ b/src/fivetran-destination/Cargo.toml @@ -18,6 +18,7 @@ futures = "0.3.31" itertools = "0.14.0" mz-ore = { path = "../ore", features = ["cli", "id_gen"], default-features = false } mz-pgrepr = { path = "../pgrepr", default-features = false } +mz-postgres-util = { path = "../postgres-util" } mz-sql-parser = { path = "../sql-parser", default-features = false } openssl = { version = "0.10.75", features = ["vendored"] } postgres-openssl = "0.5.2" diff --git a/src/fivetran-destination/src/destination.rs b/src/fivetran-destination/src/destination.rs index 50b022f63038a..d5e6634fdfb04 100644 --- a/src/fivetran-destination/src/destination.rs +++ b/src/fivetran-destination/src/destination.rs @@ -7,11 +7,9 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. 
-use std::borrow::Cow; - use futures::Future; use mz_ore::retry::{Retry, RetryResult}; -use postgres_protocol::escape; +use mz_postgres_util::{Sql, sql}; use tonic::{Request, Response, Status}; use crate::error::{Context, OpError, OpErrorKind}; @@ -249,24 +247,22 @@ fn to_grpc(response: Result) -> Result, Status> { /// Metadata about a column that is relevant to operations peformed by the destination. #[derive(Debug)] struct ColumnMetadata { - /// Name of the column in the destination table, with necessary characters escaped. - escaped_name: String, + /// Name of the column in the destination table, escaped as an identifier. + ident: Sql, /// Type of the column in the destination table. - ty: Cow<'static, str>, + ty: Sql, /// Is this column a primary key. is_primary: bool, } impl ColumnMetadata { - /// Returns a [`String`] that is suitable for use when creating a table. - fn to_column_def(&self) -> String { - let mut def = format!("{} {}", self.escaped_name, self.ty); - + /// Returns SQL suitable for use in a `CREATE TABLE` column definition. 
+ fn to_column_def(&self) -> Sql { if self.is_primary { - def.push_str(" NOT NULL"); + sql!("{} {} NOT NULL", self.ident.clone(), self.ty.clone()) + } else { + sql!("{} {}", self.ident.clone(), self.ty.clone()) } - - def } } @@ -274,28 +270,27 @@ impl TryFrom<&crate::fivetran_sdk::Column> for ColumnMetadata { type Error = OpError; fn try_from(value: &crate::fivetran_sdk::Column) -> Result { - let escaped_name = escape::escape_identifier(&value.name); - - let mut ty: Cow<'static, str> = utils::to_materialize_type(value.r#type())?.into(); + let ident = Sql::ident(&value.name); + let base_ty = Sql::new(utils::to_materialize_type(value.r#type())?); let params = value.params.as_ref().and_then(|p| p.params.as_ref()); - match (value.r#type(), params) { + let ty = match (value.r#type(), params) { (DataType::Decimal, Some(Params::Decimal(DecimalParams { precision, scale }))) => { - ty.to_mut().push_str(&format!("({}, {})", precision, scale)); + sql!("{}({}, {})", base_ty, *precision, *scale) } (other, Some(Params::Decimal(DecimalParams { .. }))) => { let msg = format!("found decimal params for {other:?} data type!"); return Err(OpError::new(OpErrorKind::InvariantViolated(msg))); } // TODO(parkmycar): Also handle the string datatype params here. - _ => (), + _ => base_ty, }; let is_primary = value.primary_key || value.name.to_lowercase() == FIVETRAN_SYSTEM_COLUMN_ID; Ok(ColumnMetadata { - escaped_name, + ident, ty, is_primary, }) diff --git a/src/fivetran-destination/src/destination/config.rs b/src/fivetran-destination/src/destination/config.rs index 56a09cc032d3b..498723514f7db 100644 --- a/src/fivetran-destination/src/destination/config.rs +++ b/src/fivetran-destination/src/destination/config.rs @@ -7,6 +7,7 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. 
+use mz_postgres_util::{query_one, sql}; use openssl::ssl::{SslConnector, SslMethod}; use openssl::x509::X509; use openssl::x509::store::X509StoreBuilder; @@ -111,13 +112,13 @@ async fn test_connect(config: BTreeMap) -> Result<(), OpError> { async fn test_permissions(config: BTreeMap) -> Result<(), OpError> { let (dbname, client) = connect(config).await?; - let row = client - .query_one( - "SELECT has_database_privilege($1, 'CREATE') OR mz_is_superuser() AS has_create", - &[&dbname], - ) - .await - .context("querying privileges")?; + let row = query_one( + &client, + sql!("SELECT has_database_privilege($1, 'CREATE') OR mz_is_superuser() AS has_create"), + &[&dbname], + ) + .await + .context("querying privileges")?; let has_create: bool = row.get("has_create"); if !has_create { diff --git a/src/fivetran-destination/src/destination/ddl.rs b/src/fivetran-destination/src/destination/ddl.rs index 5572ca369ec59..4feebe8d73bb1 100644 --- a/src/fivetran-destination/src/destination/ddl.rs +++ b/src/fivetran-destination/src/destination/ddl.rs @@ -7,10 +7,8 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. 
-use itertools::Itertools; use mz_pgrepr::Type; -use mz_sql_parser::ast::{Ident, UnresolvedItemName}; -use postgres_protocol::escape; +use mz_postgres_util::{Sql, batch_execute, execute, query, sql}; use tokio_postgres::Client; use crate::destination::{ColumnMetadata, FIVETRAN_SYSTEM_COLUMN_DELETE, config}; @@ -37,18 +35,20 @@ pub async fn describe_table( table: &str, ) -> Result, OpError> { let table_id = { - let rows = client - .query( + let rows = query( + client, + sql!( r#"SELECT t.id FROM mz_tables t JOIN mz_schemas s ON s.id = t.schema_id JOIN mz_databases d ON d.id = s.database_id WHERE d.name = $1 AND s.name = $2 AND t.name = $3 - "#, - &[&database, &schema, &table], - ) - .await - .context("fetching table ID")?; + "# + ), + &[&database, &schema, &table], + ) + .await + .context("fetching table ID")?; match &*rows { [] => return Ok(None), @@ -63,7 +63,10 @@ pub async fn describe_table( }; let columns = { - let stmt = r#"SELECT + let rows = query( + client, + sql!( + r#"SELECT name, type_oid, type_mod, @@ -72,12 +75,12 @@ pub async fn describe_table( LEFT JOIN mz_internal.mz_comments AS coms ON cols.id = coms.id AND cols.position = coms.object_sub_id WHERE cols.id = $2 - ORDER BY cols.position ASC"#; - - let rows = client - .query(stmt, &[&PRIMARY_KEY_MAGIC_STRING, &table_id]) - .await - .context("fetching table columns")?; + ORDER BY cols.position ASC"# + ), + &[&PRIMARY_KEY_MAGIC_STRING, &table_id], + ) + .await + .context("fetching table columns")?; let mut columns = vec![]; for row in rows { @@ -108,10 +111,8 @@ pub async fn describe_table( pub async fn handle_create_table(request: CreateTableRequest) -> Result<(), OpError> { let table = request.table.ok_or(OpErrorKind::FieldMissing("table"))?; - - let schema = Ident::new(&request.schema_name)?; - let qualified_table_name = - UnresolvedItemName::qualified(&[schema.clone(), Ident::new(&table.name)?]); + let schema_name = Sql::ident(&request.schema_name); + let qualified_table_name = sql!("{}.{}", 
schema_name.clone(), Sql::ident(&table.name)); let mut total_columns = table.columns; // We want to make sure the deleted system column is always provided. @@ -138,14 +139,17 @@ pub async fn handle_create_table(request: CreateTableRequest) -> Result<(), OpEr .map(ColumnMetadata::try_from) .collect::, OpError>>()?; - let defs = columns.iter().map(|col| col.to_column_def()).join(","); - let sql = format!( - r#"BEGIN; CREATE SCHEMA IF NOT EXISTS {schema}; COMMIT; - BEGIN; CREATE TABLE {qualified_table_name} ({defs}); COMMIT;"#, + let defs = Sql::join(columns.iter().map(|col| col.to_column_def()), ","); + let create_table_sql = sql!( + "BEGIN; CREATE SCHEMA IF NOT EXISTS {}; COMMIT; \ + BEGIN; CREATE TABLE {} ({}); COMMIT;", + schema_name, + qualified_table_name.clone(), + defs ); let (_dbname, client) = config::connect(request.configuration).await?; - client.batch_execute(&sql).await?; + batch_execute(&client, create_table_sql).await?; // TODO(parkmycar): This is an ugly hack! // @@ -153,13 +157,13 @@ pub async fn handle_create_table(request: CreateTableRequest) -> Result<(), OpEr // those columns as primary keys. But Materialize doesn't support primary keys, so we need to // store this metadata somewhere else. For now we do it in a COMMENT. 
for column in columns.iter().filter(|col| col.is_primary) { - let stmt = format!( - "COMMENT ON COLUMN {qualified_table_name}.{column_name} IS {magic_comment}", - column_name = column.escaped_name, - magic_comment = escape::escape_literal(PRIMARY_KEY_MAGIC_STRING), + let stmt = sql!( + "COMMENT ON COLUMN {}.{} IS {}", + qualified_table_name.clone(), + column.ident.clone(), + Sql::literal(PRIMARY_KEY_MAGIC_STRING), ); - client - .execute(&stmt, &[]) + execute(&client, stmt, &[]) .await .context("setting magic primary key comment")?; } diff --git a/src/fivetran-destination/src/destination/dml.rs b/src/fivetran-destination/src/destination/dml.rs index 67703b8a53bd7..22daa2a189261 100644 --- a/src/fivetran-destination/src/destination/dml.rs +++ b/src/fivetran-destination/src/destination/dml.rs @@ -14,9 +14,7 @@ use std::time::{Duration, SystemTime}; use async_compression::tokio::bufread::{GzipDecoder, ZstdDecoder}; use futures::{StreamExt, TryStreamExt}; -use itertools::Itertools; -use mz_sql_parser::ast::{Ident, UnresolvedItemName}; -use postgres_protocol::escape; +use mz_postgres_util::{Sql, execute, query_one, sql}; use prost::bytes::{BufMut, BytesMut}; use sha2::{Digest, Sha256}; use tokio::fs::File; @@ -50,19 +48,21 @@ pub async fn handle_truncate_table(request: TruncateRequest) -> Result<(), OpErr let (_dbname, client) = config::connect(request.configuration).await?; - let exists_stmt = r#" - SELECT EXISTS( - SELECT 1 FROM mz_tables t - LEFT JOIN mz_schemas s - ON t.schema_id = s.id - WHERE s.name = $1 AND t.name = $2 - )"# - .to_string(); - let exists: bool = client - .query_one(&exists_stmt, &[&request.schema_name, &request.table_name]) - .await - .map(|row| row.get(0)) - .context("checking existence")?; + let exists: bool = query_one( + &client, + sql!( + r#"SELECT EXISTS( + SELECT 1 FROM mz_tables t + LEFT JOIN mz_schemas s + ON t.schema_id = s.id + WHERE s.name = $1 AND t.name = $2 + )"# + ), + &[&request.schema_name, &request.table_name], + ) + .await + 
.map(|row| row.get(0)) + .context("checking existence")?; // Truncates can happen at any point in time, even if the table hasn't been created yet. We // want to no-op in this case. @@ -70,23 +70,22 @@ pub async fn handle_truncate_table(request: TruncateRequest) -> Result<(), OpErr return Ok(()); } - let sql = match request.soft { - None => format!( + let query = match request.soft { + None => sql!( "DELETE FROM {}.{} WHERE {} < $1", - escape::escape_identifier(&request.schema_name), - escape::escape_identifier(&request.table_name), - escape::escape_identifier(&request.synced_column), + Sql::ident(&request.schema_name), + Sql::ident(&request.table_name), + Sql::ident(&request.synced_column), ), - Some(soft) => format!( + Some(soft) => sql!( "UPDATE {}.{} SET {} = true WHERE {} < $1", - escape::escape_identifier(&request.schema_name), - escape::escape_identifier(&request.table_name), - escape::escape_identifier(&soft.deleted_column), - escape::escape_identifier(&request.synced_column), + Sql::ident(&request.schema_name), + Sql::ident(&request.table_name), + Sql::ident(&soft.deleted_column), + Sql::ident(&request.synced_column), ), }; - client - .execute(&sql, &[&delete_before]) + execute(&client, query, &[&delete_before]) .await .context("truncating")?; @@ -248,35 +247,36 @@ async fn replace_files( return Ok(()); } - let qualified_table_name = format!( - "{}.{}", - escape::escape_identifier(schema), - escape::escape_identifier(&table.name) - ); + let qualified_table_name = sql!("{}.{}", Sql::ident(schema), Sql::ident(&table.name)); // First delete all of the matching rows. 
- let matching_cols = columns.iter().filter(|col| col.is_primary); - let delete_stmt = format!( - r#" - DELETE FROM {qualified_table_name} - WHERE ({cols}) IN ( - SELECT {cols} - FROM {qualified_temp_table_name} - )"#, - cols = matching_cols.map(|col| &col.escaped_name).join(","), + let matching_cols = Sql::join( + columns + .iter() + .filter(|col| col.is_primary) + .map(|col| col.ident.clone()), + ",", ); - let rows_changed = client.execute(&delete_stmt, &[]).await?; + let delete_stmt = sql!( + "DELETE FROM {} WHERE ({}) IN (SELECT {} FROM {})", + qualified_table_name.clone(), + matching_cols.clone(), + matching_cols, + qualified_temp_table_name.clone(), + ); + let rows_changed = execute(client, delete_stmt, &[]).await?; tracing::info!(rows_changed, "deleted rows from {qualified_table_name}"); // Then re-insert rows. - let insert_stmt = format!( - r#" - INSERT INTO {qualified_table_name} ({cols}) - SELECT {cols} FROM {qualified_temp_table_name} - "#, - cols = columns.iter().map(|col| &col.escaped_name).join(","), + let all_cols = Sql::join(columns.iter().map(|col| col.ident.clone()), ","); + let insert_stmt = sql!( + "INSERT INTO {} ({}) SELECT {} FROM {}", + qualified_table_name.clone(), + all_cols.clone(), + all_cols, + qualified_temp_table_name.clone(), ); - let rows_changed = client.execute(&insert_stmt, &[]).await?; + let rows_changed = execute(client, insert_stmt, &[]).await?; tracing::info!(rows_changed, "inserted rows to {qualified_table_name}"); // Clear out our scratch table. @@ -297,37 +297,37 @@ async fn update_files( ) -> Result<(), OpError> { // TODO(benesch): this is hideously inefficient. 
- let mut assignments = vec![]; - let mut filters = vec![]; + let mut assignments: Vec = vec![]; + let mut filters: Vec = vec![]; for (i, column) in table.columns.iter().enumerate() { + let name = Sql::ident(&column.name); + let param = Sql::param(i + 1); if column.primary_key { - filters.push(format!( - "{} = ${}", - escape::escape_identifier(&column.name), - i + 1 - )); + filters.push(sql!("{} = {}", name, param)); } else { - assignments.push(format!( - "{name} = CASE ${p}::text WHEN {unmodified_string} THEN {name} ELSE ${p}::{ty} END", - name = escape::escape_identifier(&column.name), - p = i + 1, - unmodified_string = escape::escape_literal(&file_config.unmodified_string), - ty = utils::to_materialize_type(column.r#type())?, + assignments.push(sql!( + "{} = CASE {}::text WHEN {} THEN {} ELSE {}::{} END", + name.clone(), + param.clone(), + Sql::literal(&file_config.unmodified_string), + name, + param, + Sql::new(utils::to_materialize_type(column.r#type())?), )); } } - let update_stmt = format!( + let update_stmt = sql!( "UPDATE {}.{} SET {} WHERE {}", - escape::escape_identifier(schema), - escape::escape_identifier(&table.name), - assignments.join(","), - filters.join(" AND "), + Sql::ident(schema), + Sql::ident(&table.name), + Sql::join(assignments, ","), + Sql::join(filters, " AND "), ); let update_stmt = client - .prepare(&update_stmt) + .prepare(update_stmt.as_str()) .await .context("preparing update statement")?; @@ -406,33 +406,34 @@ async fn delete_files( // HACKY: We want to update the "_fivetran_synced" column for all of the rows we marked as // deleted, but don't have a way to read from the temp table that would allow this in an // `UPDATE` statement. 
- let synced_time_stmt = format!( - "SELECT MAX({synced_col}) FROM {qualified_temp_table_name}", - synced_col = escape::escape_identifier(FIVETRAN_SYSTEM_COLUMN_SYNCED) + let synced_time_stmt = sql!( + "SELECT MAX({}) FROM {}", + Sql::ident(FIVETRAN_SYSTEM_COLUMN_SYNCED), + qualified_temp_table_name.clone(), ); - let synced_time: SystemTime = client - .query_one(&synced_time_stmt, &[]) + let synced_time_row = query_one(client, synced_time_stmt, &[]) .await - .and_then(|row| row.try_get(0)) .context("get MAX _fivetran_synced")?; - - let qualified_table_name = - UnresolvedItemName::qualified(&[Ident::new(schema)?, Ident::new(&table.name)?]); - let matching_cols = columns.iter().filter(|col| col.is_primary); - let merge_stmt = format!( - r#" - UPDATE {qualified_table_name} - SET {deleted_col} = true, {synced_col} = $1 - WHERE ({cols}) IN ( - SELECT {cols} - FROM {qualified_temp_table_name} - )"#, - deleted_col = escape::escape_identifier(FIVETRAN_SYSTEM_COLUMN_DELETE), - synced_col = escape::escape_identifier(FIVETRAN_SYSTEM_COLUMN_SYNCED), - cols = matching_cols.map(|col| &col.escaped_name).join(","), + let synced_time: SystemTime = synced_time_row.try_get(0)?; + + let qualified_table_name = sql!("{}.{}", Sql::ident(schema), Sql::ident(&table.name)); + let matching_cols = Sql::join( + columns + .iter() + .filter(|col| col.is_primary) + .map(|col| col.ident.clone()), + ",", + ); + let merge_stmt = sql!( + "UPDATE {} SET {} = true, {} = $1 WHERE ({}) IN (SELECT {} FROM {})", + qualified_table_name.clone(), + Sql::ident(FIVETRAN_SYSTEM_COLUMN_DELETE), + Sql::ident(FIVETRAN_SYSTEM_COLUMN_SYNCED), + matching_cols.clone(), + matching_cols, + qualified_temp_table_name.clone(), ); - let total_count = client - .execute(&merge_stmt, &[&synced_time]) + let total_count = execute(client, merge_stmt, &[&synced_time]) .await .context("update deletes")?; tracing::info!(?total_count, "altered rows in {qualified_table_name}"); @@ -446,19 +447,20 @@ async fn delete_files( #[must_use = 
"Need to clear the scratch table once you're done using it."] struct ScratchTableGuard<'a> { client: &'a tokio_postgres::Client, - qualified_name: UnresolvedItemName, + qualified_name: Sql, } impl<'a> ScratchTableGuard<'a> { /// Deletes all the rows from the associated scratch table. async fn clear(self) -> Result<(), OpError> { - let clear_table_stmt = format!("DELETE FROM {}", self.qualified_name); - let rows_cleared = self - .client - .execute(&clear_table_stmt, &[]) - .await - .map_err(OpErrorKind::TemporaryResource) - .context("scratch table guard")?; + let rows_cleared = execute( + self.client, + sql!("DELETE FROM {}", self.qualified_name.clone()), + &[], + ) + .await + .map_err(OpErrorKind::from) + .context("scratch table guard")?; tracing::info!(?rows_cleared, table_name = %self.qualified_name, "guard cleared table"); Ok(()) @@ -472,22 +474,20 @@ async fn get_scratch_table<'a>( schema: &str, table: &Table, client: &'a tokio_postgres::Client, -) -> Result< - ( - UnresolvedItemName, - Vec, - ScratchTableGuard<'a>, - ), - OpError, -> { +) -> Result<(Sql, Vec, ScratchTableGuard<'a>), OpError> { static SCRATCH_TABLE_SCHEMA: &str = "_mz_fivetran_scratch"; - let create_schema_stmt = format!("CREATE SCHEMA IF NOT EXISTS {SCRATCH_TABLE_SCHEMA}"); - client - .execute(&create_schema_stmt, &[]) - .await - .map_err(OpErrorKind::TemporaryResource) - .context("creating scratch schema")?; + execute( + client, + sql!( + "CREATE SCHEMA IF NOT EXISTS {}", + Sql::ident(SCRATCH_TABLE_SCHEMA) + ), + &[], + ) + .await + .map_err(OpErrorKind::from) + .context("creating scratch schema")?; // To make sure the table name is unique, and under the Materialize identifier limits, we name // the scratch table with a hash. 
@@ -495,10 +495,11 @@ async fn get_scratch_table<'a>( hasher.update(&format!("{database}.{schema}.{}", table.name)); let scratch_table_name = format!("{:x}", hasher.finalize()); - let qualified_scratch_table_name = UnresolvedItemName::qualified(&[ - Ident::new(SCRATCH_TABLE_SCHEMA).context("scratch schema")?, - Ident::new(&scratch_table_name).context("scratch table_name")?, - ]); + let qualified_scratch_table_name = sql!( + "{}.{}", + Sql::ident(SCRATCH_TABLE_SCHEMA), + Sql::ident(&scratch_table_name) + ); let columns = table .columns @@ -507,28 +508,37 @@ async fn get_scratch_table<'a>( .collect::, OpError>>()?; let create_scratch_table = || async { - let defs = columns.iter().map(|col| col.to_column_def()).join(","); - let create_table_stmt = format!("CREATE TABLE {qualified_scratch_table_name} ({defs})"); - client - .execute(&create_table_stmt, &[]) - .await - .map_err(OpErrorKind::TemporaryResource) - .context("creating scratch table")?; + let defs = Sql::join(columns.iter().map(|col| col.to_column_def()), ","); + execute( + client, + sql!( + "CREATE TABLE {} ({})", + qualified_scratch_table_name.clone(), + defs + ), + &[], + ) + .await + .map_err(OpErrorKind::from) + .context("creating scratch table")?; // Leave a COMMENT on the scratch table for debug-ability. 
let comment = format!( "Fivetran scratch table for {database}.{schema}.{}", table.name ); - let comment_stmt = format!( - "COMMENT ON TABLE {qualified_scratch_table_name} IS {comment}", - comment = escape::escape_literal(&comment) - ); - client - .execute(&comment_stmt, &[]) - .await - .map_err(OpErrorKind::TemporaryResource) - .context("comment scratch table")?; + execute( + client, + sql!( + "COMMENT ON TABLE {} IS {}", + qualified_scratch_table_name.clone(), + Sql::literal(&comment) + ), + &[], + ) + .await + .map_err(OpErrorKind::from) + .context("comment scratch table")?; Ok::<_, OpError>(()) }; @@ -573,12 +583,14 @@ async fn get_scratch_table<'a>( "recreate scratch table", ); - let drop_table_stmt = format!("DROP TABLE {qualified_scratch_table_name}"); - client - .execute(&drop_table_stmt, &[]) - .await - .map_err(OpErrorKind::TemporaryResource) - .context("dropping scratch table")?; + execute( + client, + sql!("DROP TABLE {}", qualified_scratch_table_name.clone()), + &[], + ) + .await + .map_err(OpErrorKind::from) + .context("dropping scratch table")?; create_scratch_table().await.context("recreate table")?; } else { @@ -589,25 +601,32 @@ async fn get_scratch_table<'a>( "clear and reuse scratch table", ); - let clear_table_stmt = format!("DELETE FROM {qualified_scratch_table_name}"); - let rows_cleared = client - .execute(&clear_table_stmt, &[]) - .await - .map_err(OpErrorKind::TemporaryResource) - .context("clearing scratch table")?; + let rows_cleared = execute( + client, + sql!("DELETE FROM {}", qualified_scratch_table_name.clone()), + &[], + ) + .await + .map_err(OpErrorKind::from) + .context("clearing scratch table")?; tracing::info!(?rows_cleared, %qualified_scratch_table_name, "cleared table"); } } } // Verify that our table is empty. 
- let count_stmt = format!("SELECT COUNT(*) FROM {qualified_scratch_table_name}"); - let rows: i64 = client - .query_one(&count_stmt, &[]) - .await - .map(|row| row.get(0)) - .map_err(OpErrorKind::TemporaryResource) - .context("validate scratch table")?; + let rows: i64 = query_one( + client, + sql!( + "SELECT COUNT(*) FROM {}", + qualified_scratch_table_name.clone() + ), + &[], + ) + .await + .map(|row| row.get(0)) + .map_err(OpErrorKind::from) + .context("validate scratch table")?; if rows != 0 { return Err(OpErrorKind::InvariantViolated(format!( "scratch table had non-zero number of rows: {rows}" @@ -632,7 +651,7 @@ async fn copy_files( files: &[String], client: &tokio_postgres::Client, table: &Table, - temporary_table: &UnresolvedItemName, + temporary_table: &Sql, ) -> Result { let mut total_row_count = 0; @@ -641,11 +660,12 @@ async fn copy_files( tracing::info!(?path, "starting copy"); // Create a Sink which we can stream the CSV files into. - let copy_in_stmt = format!( - "COPY {temporary_table} FROM STDIN WITH (FORMAT CSV, HEADER false, NULL {null_value})", - null_value = escape::escape_literal(&file_config.null_string), + let copy_in_stmt = sql!( + "COPY {} FROM STDIN WITH (FORMAT CSV, HEADER false, NULL {})", + temporary_table.clone(), + Sql::literal(&file_config.null_string), ); - let sink = client.copy_in(©_in_stmt).await?; + let sink = client.copy_in(copy_in_stmt.as_str()).await?; let mut sink = std::pin::pin!(sink); { diff --git a/src/fivetran-destination/src/error.rs b/src/fivetran-destination/src/error.rs index 62db3b1eed534..00ffe90419265 100644 --- a/src/fivetran-destination/src/error.rs +++ b/src/fivetran-destination/src/error.rs @@ -159,6 +159,17 @@ impl OpErrorKind { } } +impl From for OpErrorKind { + fn from(err: mz_postgres_util::PostgresError) -> Self { + match err { + mz_postgres_util::PostgresError::Postgres(err) => OpErrorKind::TemporaryResource(err), + mz_postgres_util::PostgresError::Io(err) => OpErrorKind::Filesystem(err), + 
mz_postgres_util::PostgresError::PostgresSsl(err) => OpErrorKind::Crypto(err), + other => OpErrorKind::InvariantViolated(other.to_string()), + } + } +} + pub trait Context { type TransformType; diff --git a/src/mz-debug/Cargo.toml b/src/mz-debug/Cargo.toml index 50e4001877339..c58cbc85931e5 100644 --- a/src/mz-debug/Cargo.toml +++ b/src/mz-debug/Cargo.toml @@ -20,6 +20,7 @@ kube = { version = "3.0.1", default-features = false, features = ["client", "run mz-build-info = { path = "../build-info" } mz-cloud-resources = { path = "../cloud-resources"} mz-ore = { path = "../ore", features = ["cli", "test"] } +mz-postgres-util = { path = "../postgres-util" } mz-server-core = { path = "../server-core"} mz-sql-parser = { path = "../sql-parser" } mz-tls-util = { path = "../tls-util" } diff --git a/src/mz-debug/src/system_catalog_dumper.rs b/src/mz-debug/src/system_catalog_dumper.rs index 4a1e163c957e1..096c990baed89 100644 --- a/src/mz-debug/src/system_catalog_dumper.rs +++ b/src/mz-debug/src/system_catalog_dumper.rs @@ -21,7 +21,7 @@ use anyhow::{Context as _, Result}; use csv_async::AsyncSerializer; use futures::TryStreamExt; -use mz_sql_parser::ast::display::escaped_string_literal; +use mz_postgres_util::{Sql, execute, query, sql}; use mz_tls_util::make_tls; use std::fmt; use std::path::PathBuf; @@ -41,6 +41,23 @@ use mz_ore::task::{self, JoinHandle}; use postgres_openssl::{MakeTlsConnector, TlsStream}; use tracing::{info, warn}; +async fn execute_sql( + transaction: &Transaction<'_>, + query: Sql, +) -> Result { + execute(transaction, query, &[]).await +} + +async fn simple_query_sql( + transaction: &Transaction<'_>, + query: Sql, +) -> Result, tokio_postgres::Error> { + // `simple_query` is only available on concrete driver types and is required + // here to avoid statement preparation for cursor `FETCH` in subscribe flows. 
+ #[allow(clippy::disallowed_methods)] + transaction.simple_query(query.as_str()).await +} + #[derive(Debug, Clone)] pub enum RelationCategory { /// For relations that belong in the `mz_introspection` schema. @@ -551,9 +568,6 @@ static PG_QUERY_TIMEOUT: Duration = Duration::from_secs(20); /// If a cluster replica has more than this many errors, we skip it. static MAX_CLUSTER_REPLICA_ERROR_COUNT: usize = 3; -static SET_SEARCH_PATH_QUERY: &str = "SET search_path = mz_internal, mz_catalog, mz_introspection"; -static SELECT_CLUSTER_REPLICAS_QUERY: &str = "SELECT c.name as cluster_name, cr.name as replica_name FROM mz_clusters AS c JOIN mz_cluster_replicas AS cr ON c.id = cr.cluster_id;"; - #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ClusterReplica { pub cluster_name: String, @@ -631,12 +645,12 @@ pub async fn create_postgres_connection( pub async fn write_copy_stream( transaction: &Transaction<'_>, - copy_query: &str, + copy_query: &Sql, file: &mut tokio::fs::File, relation_name: &str, ) -> Result<(), anyhow::Error> { let copy_stream = transaction - .copy_out(copy_query) + .copy_out(copy_query.as_str()) .await .context(format!("Failed to COPY TO for {}", relation_name))? .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)); @@ -662,23 +676,25 @@ pub async fn copy_relation_to_csv( let mut writer = AsyncSerializer::from_writer(file); writer.serialize(column_names).await?; - transaction - .execute( - &format!("DECLARE c CURSOR FOR SUBSCRIBE TO {}", relation.name), - &[], - ) + let declare_query = sql!( + "DECLARE c CURSOR FOR SUBSCRIBE TO {}", + Sql::ident(relation.name) + ); + execute_sql(transaction, declare_query) .await .context("Failed to declare cursor")?; // We need to use simple_query, otherwise tokio-postgres will run an introspection SELECT query to figure out the types since it'll // try to prepare the query. This causes an error since SUBSCRIBEs and SELECT queries are not allowed to be executed in the same transaction. 
// Thus we use simple_query to avoid the introspection query. - let rows = transaction + let rows = simple_query_sql( + transaction, // We use a timeout of '1' to receive the snapshot of the current state. A timeout of '0' will return no results. // We also don't care if we get more than just the snapshot. - .simple_query("FETCH ALL FROM c WITH (TIMEOUT '1')") - .await - .context("Failed to fetch all from cursor")?; + sql!("FETCH ALL FROM c WITH (TIMEOUT '1')"), + ) + .await + .context("Failed to fetch all from cursor")?; for row in rows { if let SimpleQueryMessage::Row(row) = row { @@ -693,9 +709,9 @@ pub async fn copy_relation_to_csv( // TODO (SangJunBak): Use `WITH (HEADER TRUE)` once database-issues#2846 is implemented. file.write_all((column_names.join(",") + "\n").as_bytes()) .await?; - let copy_query = format!( + let copy_query = sql!( "COPY (SELECT * FROM {}) TO STDOUT WITH (FORMAT CSV)", - relation.name + Sql::ident(relation.name) ); write_copy_stream(transaction, ©_query, &mut file, relation.name).await?; } @@ -711,8 +727,8 @@ pub async fn query_column_names( ) -> Result, anyhow::Error> { let relation_name = relation.name; // We query the column names to write the header row of the CSV file. - let mut column_names = pg_client - .query(&format!("SHOW COLUMNS FROM {}", &relation_name), &[]) + let show_columns_query = sql!("SHOW COLUMNS FROM {}", Sql::ident(relation_name)); + let mut column_names = query(pg_client, show_columns_query, &[]) .await .context(format!("Failed to get column names for {}", relation_name))? .into_iter() @@ -745,27 +761,21 @@ pub async fn query_relation( // Some queries (i.e. mz_introspection relations) require the cluster and replica to be set. 
if let Some(cluster_replica) = &cluster_replica { - transaction - .execute( - &format!( - "SET LOCAL CLUSTER = {}", - escaped_string_literal(&cluster_replica.cluster_name) - ), - &[], - ) + let cluster_query = sql!( + "SET LOCAL CLUSTER = {}", + Sql::literal(&cluster_replica.cluster_name) + ); + execute_sql(transaction, cluster_query) .await .context(format!( "Failed to set cluster to {}", cluster_replica.cluster_name ))?; - transaction - .execute( - &format!( - "SET LOCAL CLUSTER_REPLICA = {}", - escaped_string_literal(&cluster_replica.replica_name) - ), - &[], - ) + let replica_query = sql!( + "SET LOCAL CLUSTER_REPLICA = {}", + Sql::literal(&cluster_replica.replica_name) + ); + execute_sql(transaction, replica_query) .await .context(format!( "Failed to set cluster replica to {}", @@ -810,13 +820,26 @@ impl SystemCatalogDumper { let handle = task::spawn(|| "postgres-connection", pg_conn); // Set search path to system catalog tables - pg_client - .execute(SET_SEARCH_PATH_QUERY, &[]) - .await - .context("Failed to set search path")?; + execute( + &pg_client, + sql!("SET search_path = mz_internal, mz_catalog, mz_introspection"), + &[], + ) + .await + .context("Failed to set search path")?; // We need to get all cluster replicas to dump introspection relations. - let cluster_replicas = match pg_client.query(SELECT_CLUSTER_REPLICAS_QUERY, &[]).await { + let cluster_replicas = match query( + &pg_client, + sql!( + "SELECT c.name as cluster_name, cr.name as replica_name \ + FROM mz_clusters AS c \ + JOIN mz_cluster_replicas AS cr ON c.id = cr.cluster_id;" + ), + &[], + ) + .await + { Ok(rows) => rows .into_iter() .map(|row| { @@ -871,21 +894,22 @@ impl SystemCatalogDumper { // For custom queries, create a temporary view so the retry loop // can treat them identically to basic relations. 
- if let RelationCategory::Custom { sql } = &relation.category { + if let RelationCategory::Custom { sql: custom_sql } = &relation.category { let pg_client_lock = pg_client.lock().await; - pg_client_lock - .execute( - &format!( - "CREATE OR REPLACE TEMPORARY VIEW {} AS {}", - relation.name, sql - ), - &[], - ) - .await - .context(format!( - "Failed to create temporary view for {}", - relation.name - ))?; + execute( + &*pg_client_lock, + sql!( + "CREATE OR REPLACE TEMPORARY VIEW {} AS {}", + Sql::ident(relation.name), + Sql::new(*custom_sql) + ), + &[], + ) + .await + .context(format!( + "Failed to create temporary view for {}", + relation.name + ))?; } if let Err(err) = retry::Retry::default() diff --git a/src/mz/Cargo.toml b/src/mz/Cargo.toml index a5352ae008b18..1ffeb6cd17077 100644 --- a/src/mz/Cargo.toml +++ b/src/mz/Cargo.toml @@ -22,7 +22,7 @@ mz-frontegg-client = { path = "../frontegg-client" } mz-frontegg-auth = { path = "../frontegg-auth" } mz-build-info = { path = "../build-info" } mz-ore = { path = "../ore", features = ["async", "cli", "test"] } -mz-sql-parser = { path = "../sql-parser" } +mz-postgres-util = { path = "../postgres-util" } open = "5.3.3" openssl-probe = "0.1.6" hyper = "1.4.1" diff --git a/src/mz/src/command/secret.rs b/src/mz/src/command/secret.rs index e91469bc78d82..e055aecc09be1 100644 --- a/src/mz/src/command/secret.rs +++ b/src/mz/src/command/secret.rs @@ -19,10 +19,7 @@ use std::io::{self, Write}; -use mz_sql_parser::ast::{ - Ident, - display::{AstDisplay, escaped_string_literal}, -}; +use mz_postgres_util::{Sql, sql}; use crate::{context::RegionContext, error::Error}; @@ -72,19 +69,19 @@ pub async fn create( let mut client = cx.sql_client().shell(®ion_info, user, None); // Build the queries to create the secret. 
- let mut commands: Vec = vec![]; + let mut commands: Vec = vec![]; + let name = Sql::ident(name); if let Some(database) = database { client.args(vec!["-d", database]); } if let Some(schema) = schema { - let schema = Ident::new_unchecked(schema).to_ast_string_simple(); - commands.push(format!("SET search_path TO {schema};")); + commands.push(sql!("SET search_path TO {}", Sql::ident(schema))); } - let buffer = escaped_string_literal(&buffer).to_string(); - let name = Ident::new_unchecked(name).to_ast_string_simple(); + // Treat stdin as a literal value to avoid command injection through `psql -c`. + let value = Sql::literal(&buffer); if force { // Rather than checking if the SECRET exists, do an upsert. @@ -93,18 +90,19 @@ pub async fn create( // The alternative is passing two `-c` commands to psql. // Otherwise if the SECRET exists `psql` will display a NOTICE message. - commands.push("SET client_min_messages TO WARNING;".to_string()); - commands.push(format!( + commands.push(sql!("SET client_min_messages TO WARNING;")); + commands.push(sql!( "CREATE SECRET IF NOT EXISTS {} AS {};", - name, buffer + name.clone(), + value.clone() )); - commands.push(format!("ALTER SECRET {} AS {};", name, buffer)); + commands.push(sql!("ALTER SECRET {} AS {};", name.clone(), value.clone())); } else { - commands.push(format!("CREATE SECRET {} AS {};", name, buffer)); + commands.push(sql!("CREATE SECRET {} AS {};", name, value)); } commands.iter().for_each(|c| { - client.args(vec!["-c", c]); + client.args(vec!["-c", c.as_str()]); }); let output = client diff --git a/src/persist/src/postgres.rs b/src/persist/src/postgres.rs index 909d1d23f7348..e448aa34b1ccd 100644 --- a/src/persist/src/postgres.rs +++ b/src/persist/src/postgres.rs @@ -30,6 +30,7 @@ use mz_postgres_client::metrics::PostgresClientMetrics; use mz_postgres_client::{PostgresClient, PostgresClientConfig, PostgresClientKnobs}; use postgres_protocol::escape::escape_identifier; use tokio_postgres::error::SqlState; +use 
tokio_postgres::{Row, Statement}; use tracing::{info, warn}; use crate::error::Error; @@ -68,6 +69,40 @@ const CRDB_SCHEMA_OPTIONS: &str = "WITH (sql_stats_automatic_collection_enabled // See: https://www.cockroachlabs.com/docs/stable/configure-zone.html#variables const CRDB_CONFIGURE_ZONE: &str = "ALTER TABLE consensus CONFIGURE ZONE USING gc.ttlseconds = 600"; +/// NOTE: `mz-persist` intentionally does not depend on `mz-postgres-util`. +/// These helpers are the only direct driver-call boundary in this module. +async fn pg_batch_execute(client: &Object, query: &str) -> Result<(), tokio_postgres::Error> { + #[allow(clippy::disallowed_methods)] + client.batch_execute(query).await +} + +async fn pg_query_prepared( + client: &Object, + statement: &Statement, + params: &[&(dyn ToSql + Sync)], +) -> Result, tokio_postgres::Error> { + #[allow(clippy::disallowed_methods)] + client.query(statement, params).await +} + +async fn pg_query_opt_prepared( + client: &Object, + statement: &Statement, + params: &[&(dyn ToSql + Sync)], +) -> Result, tokio_postgres::Error> { + #[allow(clippy::disallowed_methods)] + client.query_opt(statement, params).await +} + +async fn pg_execute_prepared( + client: &Object, + statement: &Statement, + params: &[&(dyn ToSql + Sync)], +) -> Result { + #[allow(clippy::disallowed_methods)] + client.execute(statement, params).await +} + impl ToSql for SeqNo { fn to_sql( &self, @@ -249,12 +284,14 @@ impl PostgresConsensus { let client = postgres_client.get_connection().await?; - let mode = match client - .batch_execute(&format!( + let mode = match pg_batch_execute( + &client, + &format!( "{}; {}{}; {};", create_schema, SCHEMA, CRDB_SCHEMA_OPTIONS, CRDB_CONFIGURE_ZONE, - )) - .await + ), + ) + .await { Ok(()) => PostgresMode::CockroachDB, Err(e) if e.code() == Some(&SqlState::INSUFFICIENT_PRIVILEGE) => { @@ -279,9 +316,7 @@ impl PostgresConsensus { }; if mode != PostgresMode::CockroachDB { - client - .batch_execute(&format!("{}; {};", create_schema, 
SCHEMA)) - .await?; + pg_batch_execute(&client, &format!("{}; {};", create_schema, SCHEMA)).await?; } Ok(PostgresConsensus { @@ -297,13 +332,12 @@ impl PostgresConsensus { pub async fn drop_and_recreate(&self) -> Result<(), ExternalError> { // this could be a TRUNCATE if we're confident the db won't reuse any state let client = self.get_connection().await?; - client.execute("DROP TABLE consensus", &[]).await?; - let crdb_mode = match client - .batch_execute(&format!( - "{}{}; {}", - SCHEMA, CRDB_SCHEMA_OPTIONS, CRDB_CONFIGURE_ZONE, - )) - .await + pg_batch_execute(&client, "DROP TABLE consensus").await?; + let crdb_mode = match pg_batch_execute( + &client, + &format!("{}{}; {}", SCHEMA, CRDB_SCHEMA_OPTIONS, CRDB_CONFIGURE_ZONE,), + ) + .await { Ok(()) => true, Err(e) if e.code() == Some(&SqlState::INSUFFICIENT_PRIVILEGE) => { @@ -326,7 +360,7 @@ impl PostgresConsensus { }; if !crdb_mode { - client.execute(SCHEMA, &[]).await?; + pg_batch_execute(&client, SCHEMA).await?; } Ok(()) } @@ -361,7 +395,7 @@ impl Consensus for PostgresConsensus { let row = { let client = self.get_connection().await?; let statement = client.prepare_cached(q).await?; - client.query_opt(&statement, &[&key]).await? + pg_query_opt_prepared(&client, &statement, &[&key]).await? }; let row = match row { None => return Ok(None), @@ -433,12 +467,12 @@ impl Consensus for PostgresConsensus { }; let client = self.get_connection().await?; let statement = client.prepare_cached(q).await?; - client - .execute( - &statement, - &[&key, &new.seqno, &new.data.as_ref(), &expected], - ) - .await? + pg_execute_prepared( + &client, + &statement, + &[&key, &new.seqno, &new.data.as_ref(), &expected], + ) + .await? } else { // Insert the new row as long as no other row exists for the same shard. 
let q = "INSERT INTO consensus SELECT $1, $2, $3 WHERE @@ -448,8 +482,7 @@ impl Consensus for PostgresConsensus { ON CONFLICT DO NOTHING"; let client = self.get_connection().await?; let statement = client.prepare_cached(q).await?; - client - .execute(&statement, &[&key, &new.seqno, &new.data.as_ref()]) + pg_execute_prepared(&client, &statement, &[&key, &new.seqno, &new.data.as_ref()]) .await? }; @@ -478,7 +511,7 @@ impl Consensus for PostgresConsensus { let rows = { let client = self.get_connection().await?; let statement = client.prepare_cached(q).await?; - client.query(&statement, &[&key, &from, &limit]).await? + pg_query_prepared(&client, &statement, &[&key, &from, &limit]).await? }; let mut results = Vec::with_capacity(rows.len()); @@ -545,7 +578,7 @@ impl Consensus for PostgresConsensus { let result = { let client = self.get_connection().await?; let statement = client.prepare_cached(q).await?; - client.execute(&statement, &[&key, &seqno]).await? + pg_execute_prepared(&client, &statement, &[&key, &seqno]).await? }; if result == 0 { // We weren't able to successfully truncate any rows inspect head to diff --git a/src/postgres-client/src/lib.rs b/src/postgres-client/src/lib.rs index 1d6933cd76c58..a801e083d6ba2 100644 --- a/src/postgres-client/src/lib.rs +++ b/src/postgres-client/src/lib.rs @@ -139,9 +139,15 @@ impl PostgresClient { connections_created.inc(); Box::pin(async move { debug!("opened new consensus postgres connection"); - client.batch_execute( + // This hook must return `tokio_postgres::Error`; using + // `mz_postgres_util` wrappers would change the error type. 
+ #[allow(clippy::disallowed_methods)] + client + .batch_execute( "SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL SERIALIZABLE", - ).await.map_err(|e| HookError::Abort(HookErrorCause::Backend(e))) + ) + .await + .map_err(|e| HookError::Abort(HookErrorCause::Backend(e))) }) })) .pre_recycle(Hook::sync_fn(move |_client, conn_metrics| { diff --git a/src/postgres-util/src/lib.rs b/src/postgres-util/src/lib.rs index dcb5847f1811d..66317a7831231 100644 --- a/src/postgres-util/src/lib.rs +++ b/src/postgres-util/src/lib.rs @@ -28,7 +28,11 @@ pub mod tunnel; pub use tunnel::{Client, Config, DEFAULT_SNAPSHOT_STATEMENT_TIMEOUT, TunnelConfig}; pub mod query; -pub use query::simple_query_opt; +pub use query::{ + Sql, SqlFormatError, batch_execute, execute, execute_prepared, query, query_one, + query_one_prepared, query_opt, query_opt_prepared, query_prepared, simple_query, + simple_query_opt, +}; /// An error representing pg, ssh, ssl, and other failures. #[derive(Debug, thiserror::Error)] diff --git a/src/postgres-util/src/query.rs b/src/postgres-util/src/query.rs index c215f09cf8dc4..838ecb4376051 100644 --- a/src/postgres-util/src/query.rs +++ b/src/postgres-util/src/query.rs @@ -7,16 +7,286 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. -use tokio_postgres::{Client, SimpleQueryMessage, SimpleQueryRow}; +#![allow(clippy::disallowed_methods)] + +use std::borrow::Cow; +use std::fmt::{self, Write}; +use std::ops::{Add, AddAssign}; + +use mz_sql_parser::ast::display::{AstDisplay, escaped_string_literal}; +use tokio_postgres::types::ToSql; +use tokio_postgres::{Client, GenericClient, Row, SimpleQueryMessage, SimpleQueryRow, Statement}; use crate::PostgresError; +/// A composable SQL query string. +/// +/// Use [`crate::sql!`] for static SQL fragments and [`Sql::ident`]/[`Sql::literal`] for +/// dynamic values. This mirrors psycopg's split between trusted SQL text and escaped +/// values. 
+#[derive(Clone, Debug, Eq, PartialEq)] +pub struct Sql(Cow<'static, str>); + +#[doc(hidden)] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum SqlTemplateError { + InvalidOpenBrace, + InvalidCloseBrace, +} + +#[doc(hidden)] +pub const fn sql_template_placeholder_count(template: &str) -> Result { + let bytes = template.as_bytes(); + let mut i = 0; + let mut count = 0; + + while i < bytes.len() { + match bytes[i] { + b'{' => { + if i + 1 >= bytes.len() { + return Err(SqlTemplateError::InvalidOpenBrace); + } + match bytes[i + 1] { + b'{' => i += 2, + b'}' => { + count += 1; + i += 2; + } + _ => return Err(SqlTemplateError::InvalidOpenBrace), + } + } + b'}' => { + if i + 1 >= bytes.len() { + return Err(SqlTemplateError::InvalidCloseBrace); + } + match bytes[i + 1] { + b'}' => i += 2, + _ => return Err(SqlTemplateError::InvalidCloseBrace), + } + } + _ => i += 1, + } + } + + Ok(count) +} + +#[derive(Clone, Debug, Eq, PartialEq, thiserror::Error)] +pub enum SqlFormatError { + #[error("SQL format string contains an invalid '{{' sequence")] + InvalidOpenBrace, + #[error("SQL format string contains an invalid '}}' sequence")] + InvalidCloseBrace, + #[error("SQL format string expected more arguments")] + MissingArgument, + #[error("SQL format string received too many arguments")] + ExtraArgument, +} + +impl Sql { + /// Creates a SQL fragment from a static SQL string. + pub fn new(sql: &'static str) -> Self { + Self(Cow::Borrowed(sql)) + } + + /// Creates a SQL fragment by escaping a SQL identifier. + pub fn ident(ident: &str) -> Self { + // PostgreSQL identifiers are escaped by surrounding with double quotes + // and doubling any embedded double quotes. + let mut out = String::with_capacity(ident.len() + 2); + out.push('"'); + for ch in ident.chars() { + if ch == '"' { + out.push('"'); + } + out.push(ch); + } + out.push('"'); + Self(Cow::Owned(out)) + } + + /// Creates a SQL fragment by escaping a SQL literal. 
+ pub fn literal(literal: &str) -> Self { + Self(Cow::Owned( + escaped_string_literal(literal).to_ast_string_simple(), + )) + } + + /// Creates a SQL fragment for a PostgreSQL positional parameter (e.g. `$1`). + pub fn param(index: usize) -> Self { + let mut out = String::new(); + out.push('$'); + let _ = write!(out, "{index}"); + Self(Cow::Owned(out)) + } + + /// Joins SQL fragments with a static separator. + pub fn join(parts: impl IntoIterator, separator: &'static str) -> Self { + let mut iter = parts.into_iter(); + let Some(first) = iter.next() else { + return Self(Cow::Borrowed("")); + }; + + let mut out = first.0; + for part in iter { + out.to_mut().push_str(separator); + out.to_mut().push_str(part.as_str()); + } + Self(out) + } + + /// Formats this SQL fragment by replacing each `{}` with the next SQL argument. + /// + /// Use `{{` and `}}` to escape literal braces. + pub fn format(self, args: impl IntoIterator) -> Result { + let mut args = args.into_iter(); + let mut out = String::with_capacity(self.0.len()); + let mut chars = self.0.chars().peekable(); + + while let Some(ch) = chars.next() { + match ch { + '{' => match chars.peek() { + Some('{') => { + chars.next(); + out.push('{'); + } + Some('}') => { + chars.next(); + let arg = args.next().ok_or(SqlFormatError::MissingArgument)?; + out.push_str(arg.as_str()); + } + _ => return Err(SqlFormatError::InvalidOpenBrace), + }, + '}' => match chars.peek() { + Some('}') => { + chars.next(); + out.push('}'); + } + _ => return Err(SqlFormatError::InvalidCloseBrace), + }, + _ => out.push(ch), + } + } + + if args.next().is_some() { + return Err(SqlFormatError::ExtraArgument); + } + Ok(Sql(Cow::Owned(out))) + } + + #[doc(hidden)] + pub fn format_unchecked(self, args: impl IntoIterator) -> Self { + let mut args = args.into_iter(); + let mut out = String::with_capacity(self.0.len()); + let mut chars = self.0.chars().peekable(); + + while let Some(ch) = chars.next() { + match ch { + '{' => match 
chars.next().expect("validated in sql! macro") { + '{' => out.push('{'), + '}' => { + let arg = args.next().expect("validated in sql! macro"); + out.push_str(arg.as_str()); + } + _ => unreachable!("validated in sql! macro"), + }, + '}' => match chars.next().expect("validated in sql! macro") { + '}' => out.push('}'), + _ => unreachable!("validated in sql! macro"), + }, + _ => out.push(ch), + } + } + debug_assert!(args.next().is_none(), "validated in sql! macro"); + Sql(Cow::Owned(out)) + } + + /// Returns the underlying SQL string. + pub fn as_str(&self) -> &str { + self.0.as_ref() + } + + /// Consumes this value and returns the SQL string. + pub fn into_string(self) -> String { + self.0.into_owned() + } +} + +impl fmt::Display for Sql { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +impl Add for Sql { + type Output = Sql; + + fn add(mut self, rhs: Sql) -> Self::Output { + self += rhs; + self + } +} + +impl AddAssign for Sql { + fn add_assign(&mut self, rhs: Sql) { + self.0.to_mut().push_str(rhs.as_str()); + } +} + +macro_rules! impl_from_integer_for_sql { + ($($t:ty),+ $(,)?) => { + $( + impl From<$t> for Sql { + fn from(value: $t) -> Self { + Sql(Cow::Owned(value.to_string())) + } + } + )+ + }; +} + +impl_from_integer_for_sql!(i16, i32, i64, isize, u16, u32, u64, usize); + +impl From for Sql { + fn from(value: tokio_postgres::types::PgLsn) -> Self { + Sql(Cow::Owned(value.to_string())) + } +} + +#[macro_export] +macro_rules! sql { + ($template:literal $(,)?) => { + $crate::Sql::new($template) + }; + ($template:literal, $($arg:expr),+ $(,)?) 
=> {{ + const __SQL_FORMAT_ARG_COUNT: usize = <[()]>::len(&[$($crate::sql!(@unit $arg)),*]); + const __SQL_FORMAT_PLACEHOLDER_COUNT: usize = + match $crate::query::sql_template_placeholder_count($template) { + Ok(n) => n, + Err($crate::query::SqlTemplateError::InvalidOpenBrace) => { + panic!("sql!: invalid '{{' in SQL template") + } + Err($crate::query::SqlTemplateError::InvalidCloseBrace) => { + panic!("sql!: invalid '}}' in SQL template") + } + }; + const _: () = { + if __SQL_FORMAT_ARG_COUNT != __SQL_FORMAT_PLACEHOLDER_COUNT { + panic!("sql!: placeholder count does not match arguments"); + } + }; + + $crate::Sql::new($template).format_unchecked([$($crate::Sql::from($arg)),*]) + }}; + (@unit $_arg:expr) => { () }; +} + /// Runs the given query using the client and expects at most a single row to be returned. pub async fn simple_query_opt( client: &Client, - query: &str, + query: Sql, ) -> Result, PostgresError> { - let result = client.simple_query(query).await?; + let result = simple_query(client, query).await?; let mut rows = result.into_iter().filter_map(|msg| match msg { SimpleQueryMessage::Row(row) => Some(row), _ => None, @@ -27,3 +297,146 @@ pub async fn simple_query_opt( _ => Err(PostgresError::UnexpectedRow), } } + +/// Runs a simple query and returns all protocol messages. +pub async fn simple_query( + client: &Client, + query: Sql, +) -> Result, PostgresError> { + Ok(client.simple_query(query.as_str()).await?) +} + +/// Runs a query and returns all resulting rows. +pub async fn query( + client: &C, + query: Sql, + params: &[&(dyn ToSql + Sync)], +) -> Result, PostgresError> { + Ok(client.query(query.as_str(), params).await?) +} + +/// Runs a prepared query and returns all resulting rows. +pub async fn query_prepared( + client: &C, + statement: &Statement, + params: &[&(dyn ToSql + Sync)], +) -> Result, PostgresError> { + Ok(client.query(statement, params).await?) +} + +/// Runs a query and returns exactly one row. 
+pub async fn query_one( + client: &C, + query: Sql, + params: &[&(dyn ToSql + Sync)], +) -> Result { + Ok(client.query_one(query.as_str(), params).await?) +} + +/// Runs a prepared query and returns exactly one row. +pub async fn query_one_prepared( + client: &C, + statement: &Statement, + params: &[&(dyn ToSql + Sync)], +) -> Result { + Ok(client.query_one(statement, params).await?) +} + +/// Runs a query and returns at most one row. +pub async fn query_opt( + client: &C, + query: Sql, + params: &[&(dyn ToSql + Sync)], +) -> Result, PostgresError> { + Ok(client.query_opt(query.as_str(), params).await?) +} + +/// Runs a prepared query and returns at most one row. +pub async fn query_opt_prepared( + client: &C, + statement: &Statement, + params: &[&(dyn ToSql + Sync)], +) -> Result, PostgresError> { + Ok(client.query_opt(statement, params).await?) +} + +/// Runs a query and returns the number of affected rows. +pub async fn execute( + client: &C, + query: Sql, + params: &[&(dyn ToSql + Sync)], +) -> Result { + Ok(client.execute(query.as_str(), params).await?) +} + +/// Runs a prepared query and returns the number of affected rows. +pub async fn execute_prepared( + client: &C, + statement: &Statement, + params: &[&(dyn ToSql + Sync)], +) -> Result { + Ok(client.execute(statement, params).await?) +} + +/// Runs one or more SQL statements with no returned rows. +pub async fn batch_execute( + client: &C, + query: Sql, +) -> Result<(), PostgresError> { + Ok(client.batch_execute(query.as_str()).await?) 
+} + +#[cfg(test)] +mod tests { + use super::{Sql, SqlFormatError}; + + #[mz_ore::test] + fn sql_identifier_escaping() { + assert_eq!(Sql::ident("a").as_str(), "\"a\""); + assert_eq!(Sql::ident("a\"b").as_str(), "\"a\"\"b\""); + } + + #[mz_ore::test] + fn sql_literal_escaping() { + assert_eq!(Sql::literal("a").as_str(), "'a'"); + assert_eq!(Sql::literal("a'b").as_str(), "'a''b'"); + } + + #[mz_ore::test] + fn sql_format_composes_fragments() { + let query = Sql::new("SELECT * FROM {} WHERE col = {}") + .format([Sql::ident("my_table"), Sql::literal("v")]) + .expect("valid template"); + assert_eq!(query.as_str(), "SELECT * FROM \"my_table\" WHERE col = 'v'"); + } + + #[mz_ore::test] + fn sql_format_errors_on_invalid_placeholders() { + let err = Sql::new("SELECT {x}") + .format([Sql::ident("t")]) + .expect_err("invalid format"); + assert_eq!(err, SqlFormatError::InvalidOpenBrace); + } + + #[mz_ore::test] + fn sql_macro_composes_fragments() { + let query = crate::sql!( + "SELECT * FROM {} WHERE col = {}", + Sql::ident("my_table"), + Sql::literal("v") + ); + assert_eq!(query.as_str(), "SELECT * FROM \"my_table\" WHERE col = 'v'"); + } + + #[mz_ore::test] + fn sql_macro_escaped_braces() { + let query = crate::sql!("SELECT '{{}}' AS braces, {} AS t", Sql::ident("col")); + assert_eq!(query.as_str(), "SELECT '{}' AS braces, \"col\" AS t"); + } + + #[mz_ore::test] + fn sql_macro_static_literal() { + let query = crate::sql!("SELECT '{not_a_placeholder}'"); + assert_eq!(query.as_str(), "SELECT '{not_a_placeholder}'"); + } +} diff --git a/src/postgres-util/src/replication.rs b/src/postgres-util/src/replication.rs index 9f40761063aa7..442134e95a870 100644 --- a/src/postgres-util/src/replication.rs +++ b/src/postgres-util/src/replication.rs @@ -9,7 +9,6 @@ use std::str::FromStr; -use mz_sql_parser::ast::{Ident, display::AstDisplay}; use tokio_postgres::{ Client, types::{Oid, PgLsn}, @@ -17,7 +16,7 @@ use tokio_postgres::{ use mz_ssh_util::tunnel_manager::SshTunnelManager; -use 
crate::{Config, PostgresError, simple_query_opt}; +use crate::{Config, PostgresError, Sql, query, query_one, simple_query, simple_query_opt}; #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] pub enum WalLevel { @@ -59,31 +58,33 @@ fn test_wal_level_max() { } pub async fn get_wal_level(client: &Client) -> Result { - let wal_level = client.query_one("SHOW wal_level", &[]).await?; + let wal_level = query_one(client, crate::sql!("SHOW wal_level"), &[]).await?; let wal_level: String = wal_level.get("wal_level"); Ok(WalLevel::from_str(&wal_level)?) } pub async fn get_max_wal_senders(client: &Client) -> Result { - let max_wal_senders = client - .query_one( - "SELECT CAST(current_setting('max_wal_senders') AS int8) AS max_wal_senders", - &[], - ) - .await?; + let max_wal_senders = query_one( + client, + crate::sql!("SELECT CAST(current_setting('max_wal_senders') AS int8) AS max_wal_senders"), + &[], + ) + .await?; Ok(max_wal_senders.get("max_wal_senders")) } pub async fn available_replication_slots(client: &Client) -> Result { - let available_replication_slots = client - .query_one( + let available_replication_slots = query_one( + client, + crate::sql!( "SELECT CAST(current_setting('max_replication_slots') AS int8) - (SELECT count(*) FROM pg_catalog.pg_replication_slots) - AS available_replication_slots;", - &[], - ) - .await?; + AS available_replication_slots;" + ), + &[], + ) + .await?; let available_replication_slots: i64 = available_replication_slots.get("available_replication_slots"); @@ -95,12 +96,12 @@ pub async fn available_replication_slots(client: &Client) -> Result pub async fn bypass_rls_attribute(client: &Client) -> Result { - let rls_attribute = client - .query_one( - "SELECT rolbypassrls FROM pg_roles WHERE rolname = CURRENT_USER;", - &[], - ) - .await?; + let rls_attribute = query_one( + client, + crate::sql!("SELECT rolbypassrls FROM pg_roles WHERE rolname = CURRENT_USER;"), + &[], + ) + .await?; Ok(rls_attribute.get("rolbypassrls")) } @@ -118,8 
+119,9 @@ pub async fn validate_no_rls_policies( if table_oids.is_empty() { return Ok(()); } - let tables_with_rls_for_user = client - .query( + let tables_with_rls_for_user = query( + client, + crate::sql!( "SELECT format('%I.%I', pc.relnamespace::regnamespace, pc.relname) AS qualified_name FROM pg_policy pp @@ -127,11 +129,11 @@ pub async fn validate_no_rls_policies( WHERE polrelid = ANY($1::oid[]) AND - (0 = ANY(polroles) OR CURRENT_USER::regrole::oid = ANY(polroles));", - &[&table_oids], - ) - .await - .map_err(PostgresError::from)?; + (0 = ANY(polroles) OR CURRENT_USER::regrole::oid = ANY(polroles));" + ), + &[&table_oids], + ) + .await?; let mut tables_with_rls_for_user = tables_with_rls_for_user .into_iter() @@ -158,12 +160,12 @@ pub async fn drop_replication_slots( .await?; let replication_client = config.connect_replication(ssh_tunnel_manager).await?; for (slot, should_wait) in slots { - let rows = client - .query( - "SELECT active_pid FROM pg_replication_slots WHERE slot_name = $1::TEXT", - &[&slot], - ) - .await?; + let rows = query( + &*client, + crate::sql!("SELECT active_pid FROM pg_replication_slots WHERE slot_name = $1::TEXT"), + &[&slot], + ) + .await?; match &*rows { [] => { // DROP_REPLICATION_SLOT will error if the slot does not exist @@ -181,16 +183,17 @@ pub async fn drop_replication_slots( // active backend and drop the slot. 
let active_pid: Option = row.get("active_pid"); if let Some(active_pid) = active_pid { - client - .simple_query(&format!("SELECT pg_terminate_backend({active_pid})")) - .await?; + let query = crate::sql!("SELECT pg_terminate_backend({})", active_pid); + simple_query(&*client, query).await?; } - let wait_str = if *should_wait { " WAIT" } else { "" }; - let slot = Ident::new_unchecked(*slot).to_ast_string_simple(); - replication_client - .simple_query(&format!("DROP_REPLICATION_SLOT {slot}{wait_str}")) - .await?; + let wait = if *should_wait { + crate::sql!(" WAIT") + } else { + crate::sql!("") + }; + let query = crate::sql!("DROP_REPLICATION_SLOT {}{}", Sql::ident(slot), wait); + simple_query(&*replication_client, query).await?; } _ => { return Err(PostgresError::Generic(anyhow::anyhow!( @@ -204,8 +207,11 @@ pub async fn drop_replication_slots( } pub async fn get_timeline_id(client: &Client) -> Result { - if let Some(r) = - simple_query_opt(client, "SELECT timeline_id FROM pg_control_checkpoint()").await? + if let Some(r) = simple_query_opt( + client, + crate::sql!("SELECT timeline_id FROM pg_control_checkpoint()"), + ) + .await? 
{ r.get("timeline_id") .expect("Returns a row with a timeline ID") @@ -224,7 +230,7 @@ pub async fn get_timeline_id(client: &Client) -> Result { } pub async fn get_current_wal_lsn(client: &Client) -> Result { - let row = client.query_one("SELECT pg_current_wal_lsn()", &[]).await?; + let row = query_one(client, crate::sql!("SELECT pg_current_wal_lsn()"), &[]).await?; let lsn: PgLsn = row.get(0); Ok(lsn) diff --git a/src/postgres-util/src/schemas.rs b/src/postgres-util/src/schemas.rs index a36397c198944..65eeb39176831 100644 --- a/src/postgres-util/src/schemas.rs +++ b/src/postgres-util/src/schemas.rs @@ -15,20 +15,23 @@ use tokio_postgres::Client; use tokio_postgres::types::Oid; use crate::desc::{PostgresColumnDesc, PostgresKeyDesc, PostgresSchemaDesc, PostgresTableDesc}; -use crate::{PostgresError, simple_query_opt}; +use crate::{PostgresError, query, simple_query_opt}; pub async fn get_schemas(client: &Client) -> Result, PostgresError> { - Ok(client - .query("SELECT oid, nspname, nspowner FROM pg_namespace", &[]) - .await? - .into_iter() - .map(|row| { - let oid: Oid = row.get("oid"); - let name: String = row.get("nspname"); - let owner: Oid = row.get("nspowner"); - PostgresSchemaDesc { oid, name, owner } - }) - .collect::>()) + Ok(query( + client, + crate::sql!("SELECT oid, nspname, nspowner FROM pg_namespace"), + &[], + ) + .await? + .into_iter() + .map(|row| { + let oid: Oid = row.get("oid"); + let name: String = row.get("nspname"); + let owner: Oid = row.get("nspowner"); + PostgresSchemaDesc { oid, name, owner } + }) + .collect::>()) } /// Get the major version of the PostgreSQL server. @@ -36,8 +39,11 @@ pub async fn get_pg_major_version(client: &Client) -> Result // server_version_num is an integer like 140005 for version 14.5 // NOTE: We use the statement SELECT instead of SHOW because older Aurora // versions don't support SHOW via a replication channel. 
- let query = "SELECT current_setting('server_version_num')"; - let row = simple_query_opt(client, query).await?; + let row = simple_query_opt( + client, + crate::sql!("SELECT current_setting('server_version_num')"), + ) + .await?; let version_num: u32 = row .and_then(|r| r.get("current_setting").map(|s| s.parse().ok())) .flatten() @@ -68,19 +74,19 @@ pub async fn publication_info( ) -> Result, PostgresError> { let server_major_version = get_pg_major_version(client).await?; - client - .query( - "SELECT oid FROM pg_publication WHERE pubname = $1", - &[&publication], - ) - .await - .map_err(PostgresError::from)? - .get(0) - .ok_or_else(|| PostgresError::PublicationMissing(publication.to_string()))?; + query( + client, + crate::sql!("SELECT oid FROM pg_publication WHERE pubname = $1"), + &[&publication], + ) + .await? + .get(0) + .ok_or_else(|| PostgresError::PublicationMissing(publication.to_string()))?; let tables = if let Some(oids) = oids { - client - .query( + query( + client, + crate::sql!( "SELECT c.oid, p.schemaname, p.tablename FROM @@ -90,13 +96,15 @@ pub async fn publication_info( c.relname = p.tablename AND n.nspname = p.schemaname WHERE p.pubname = $1 - AND c.oid = ANY ($2)", - &[&publication, &oids], - ) - .await + AND c.oid = ANY ($2)" + ), + &[&publication, &oids], + ) + .await } else { - client - .query( + query( + client, + crate::sql!( "SELECT c.oid, p.schemaname, p.tablename FROM @@ -105,23 +113,41 @@ pub async fn publication_info( JOIN pg_publication_tables AS p ON c.relname = p.tablename AND n.nspname = p.schemaname WHERE - p.pubname = $1", - &[&publication], - ) - .await + p.pubname = $1" + ), + &[&publication], + ) + .await }?; // The Postgres replication protocol does not support GENERATED columns // so we exclude them from this query. But not all Postgres-like // databases have the `pg_attribute.attgenerated` column. 
- let attgenerated = if server_major_version >= 12 { - "a.attgenerated = ''" + let pg_columns = if server_major_version >= 12 { + crate::sql!( + " + SELECT + a.attrelid AS table_oid, + a.attname AS name, + a.atttypid AS typoid, + a.attnum AS colnum, + a.atttypmod AS typmod, + a.attnotnull AS not_null, + b.oid IS NOT NULL AS primary_key + FROM pg_catalog.pg_attribute a + LEFT JOIN pg_catalog.pg_constraint b + ON a.attrelid = b.conrelid + AND b.contype = 'p' + AND a.attnum = ANY (b.conkey) + WHERE a.attnum > 0::pg_catalog.int2 + AND NOT a.attisdropped + AND a.attgenerated = '' + AND a.attrelid = ANY ($1) + ORDER BY a.attnum" + ) } else { - "true" - }; - - let pg_columns = format!( - " + crate::sql!( + " SELECT a.attrelid AS table_oid, a.attname AS name, @@ -137,10 +163,11 @@ pub async fn publication_info( AND a.attnum = ANY (b.conkey) WHERE a.attnum > 0::pg_catalog.int2 AND NOT a.attisdropped - AND {attgenerated} + AND true AND a.attrelid = ANY ($1) ORDER BY a.attnum" - ); + ) + }; let table_oids = tables .iter() @@ -148,7 +175,7 @@ pub async fn publication_info( .collect::>(); let mut columns: BTreeMap> = BTreeMap::new(); - for row in client.query(&pg_columns, &[&table_oids]).await? { + for row in query(client, pg_columns, &[&table_oids]).await? { let table_oid: Oid = row.get("table_oid"); let name: String = row.get("name"); let type_oid = row.get("typoid"); @@ -171,20 +198,36 @@ pub async fn publication_info( // PG 15 adds UNIQUE NULLS NOT DISTINCT, which would let us use `UNIQUE` constraints over // nullable columns as keys; i.e. aligns a PG index's NULL handling with an arrangement's // keys. 
For more info, see https://www.postgresql.org/about/featurematrix/detail/392/ - let nulls_not_distinct = if server_major_version >= 15 { - "pg_index.indnullsnotdistinct" + let pg_keys = if server_major_version >= 15 { + crate::sql!( + " + SELECT + pg_constraint.conrelid AS table_oid, + pg_constraint.oid, + pg_constraint.conkey, + pg_constraint.conname, + pg_constraint.contype = 'p' AS is_primary, + pg_index.indnullsnotdistinct AS nulls_not_distinct + FROM + pg_constraint + JOIN + pg_index + ON pg_index.indexrelid = pg_constraint.conindid + WHERE + pg_constraint.conrelid = ANY ($1) + AND + pg_constraint.contype = ANY (ARRAY['p', 'u']);" + ) } else { - "false" - }; - let pg_keys = format!( - " + crate::sql!( + " SELECT pg_constraint.conrelid AS table_oid, pg_constraint.oid, pg_constraint.conkey, pg_constraint.conname, pg_constraint.contype = 'p' AS is_primary, - {nulls_not_distinct} AS nulls_not_distinct + false AS nulls_not_distinct FROM pg_constraint JOIN @@ -194,10 +237,11 @@ pub async fn publication_info( pg_constraint.conrelid = ANY ($1) AND pg_constraint.contype = ANY (ARRAY['p', 'u']);" - ); + ) + }; let mut keys: BTreeMap> = BTreeMap::new(); - for row in client.query(&pg_keys, &[&table_oids]).await? { + for row in query(client, pg_keys, &[&table_oids]).await? { let table_oid: Oid = row.get("table_oid"); let oid: Oid = row.get("oid"); let cols: Vec = row.get("conkey"); diff --git a/src/regexp/src/lib.rs b/src/regexp/src/lib.rs index 6378e36edd13f..7c016c4f14930 100644 --- a/src/regexp/src/lib.rs +++ b/src/regexp/src/lib.rs @@ -77,6 +77,8 @@ mod tests { for input in inputs { for re in ®exps { let regex = build_regex(re, "").unwrap(); + // This test cross-checks against the sync `postgres` crate, + // while `mz_postgres_util` wrappers target async tokio-postgres. 
let pg: Vec = client .query_one("select regexp_split_to_array($1, $2)", &[&input, re]) .unwrap() diff --git a/src/sql/src/pure/postgres.rs b/src/sql/src/pure/postgres.rs index 7ca235509a7fc..d06ae497e0e33 100644 --- a/src/sql/src/pure/postgres.rs +++ b/src/sql/src/pure/postgres.rs @@ -689,15 +689,16 @@ pub(crate) fn generate_column_casts( } mod privileges { - use mz_postgres_util::PostgresError; + use mz_postgres_util::{PostgresError, query, sql}; use super::*; use crate::plan::PlanError; use crate::pure::PgSourcePurificationError; async fn check_schema_privileges(client: &Client, table_oids: &[Oid]) -> Result<(), PlanError> { - let invalid_schema_privileges_rows = client - .query( + let invalid_schema_privileges_rows = query( + client, + sql!( " WITH distinct_namespace AS ( SELECT @@ -709,11 +710,11 @@ mod privileges { SELECT d.schema_name FROM distinct_namespace AS d WHERE - NOT has_schema_privilege(CURRENT_USER::TEXT, d.oid, 'usage')", - &[&table_oids], - ) - .await - .map_err(PostgresError::from)?; + NOT has_schema_privilege(CURRENT_USER::TEXT, d.oid, 'usage')" + ), + &[&table_oids], + ) + .await?; let mut invalid_schema_privileges = invalid_schema_privileges_rows .into_iter() @@ -746,8 +747,9 @@ mod privileges { ) -> Result<(), PlanError> { check_schema_privileges(client, table_oids).await?; - let invalid_table_privileges_rows = client - .query( + let invalid_table_privileges_rows = query( + client, + sql!( " SELECT format('%I.%I', n.nspname, c.relname) AS schema_qualified_table_name @@ -756,11 +758,11 @@ mod privileges { pg_class c ON c.oid = oids.oid JOIN pg_namespace n ON c.relnamespace = n.oid - WHERE NOT has_table_privilege(CURRENT_USER::text, c.oid, 'select')", - &[&table_oids], - ) - .await - .map_err(PostgresError::from)?; + WHERE NOT has_table_privilege(CURRENT_USER::text, c.oid, 'select')" + ), + &[&table_oids], + ) + .await?; let mut invalid_table_privileges = invalid_table_privileges_rows .into_iter() @@ -801,7 +803,7 @@ mod privileges { } mod 
replica_identity { - use mz_postgres_util::PostgresError; + use mz_postgres_util::{query, sql}; use super::*; use crate::plan::PlanError; @@ -812,8 +814,9 @@ mod replica_identity { client: &Client, table_oids: &[Oid], ) -> Result<(), PlanError> { - let invalid_replica_identity_rows = client - .query( + let invalid_replica_identity_rows = query( + client, + sql!( " SELECT format('%I.%I', n.nspname, c.relname) AS schema_qualified_table_name @@ -822,11 +825,11 @@ mod replica_identity { pg_class c ON c.oid = oids.oid JOIN pg_namespace n ON c.relnamespace = n.oid - WHERE relreplident != 'f' OR relreplident IS NULL;", - &[&table_oids], - ) - .await - .map_err(PostgresError::from)?; + WHERE relreplident != 'f' OR relreplident IS NULL;" + ), + &[&table_oids], + ) + .await?; let mut invalid_replica_identity = invalid_replica_identity_rows .into_iter() diff --git a/src/sqllogictest/src/runner.rs b/src/sqllogictest/src/runner.rs index 36eb91b6f073a..43360e1030384 100644 --- a/src/sqllogictest/src/runner.rs +++ b/src/sqllogictest/src/runner.rs @@ -805,6 +805,7 @@ impl<'a> Runner<'a> { .await } + #[allow(clippy::disallowed_methods)] async fn reset_database(&mut self) -> Result<(), anyhow::Error> { let inner = self.inner.as_mut().expect("RunnerInner missing"); @@ -998,6 +999,9 @@ impl<'a> RunnerInner<'a> { panic!("connection error: {}", e); } }); + // `prefix` is generated by sqllogictest harness configuration and + // cannot be represented with composable `Sql`. + #[allow(clippy::disallowed_methods)] client .batch_execute(&format!( "DROP SCHEMA IF EXISTS {prefix}_tsoracle CASCADE; @@ -1319,6 +1323,7 @@ impl<'a> RunnerInner<'a> { /// Set features that should be enabled regardless of whether reset-server was /// called. These features may be set conditionally depending on the run configuration. 
+ #[allow(clippy::disallowed_methods)] async fn ensure_fixed_features(&self) -> Result<(), anyhow::Error> { // We turn on enable_reduce_mfp_fusion, as we wish // to get as much coverage of these features as we can. @@ -1334,6 +1339,7 @@ impl<'a> RunnerInner<'a> { Ok(()) } + #[allow(clippy::disallowed_methods)] async fn run_record<'r>( &mut self, record: &'r Record<'r>, @@ -1425,6 +1431,7 @@ impl<'a> RunnerInner<'a> { } } + #[allow(clippy::disallowed_methods)] async fn run_statement<'r>( &self, expected_error: Option<&'r str>, @@ -1476,6 +1483,7 @@ impl<'a> RunnerInner<'a> { } } + #[allow(clippy::disallowed_methods)] async fn prepare_query<'r>( &self, sql: &str, @@ -1549,6 +1557,7 @@ impl<'a> RunnerInner<'a> { })) } + #[allow(clippy::disallowed_methods)] async fn execute_query<'r>( &self, sql: &str, @@ -1692,6 +1701,7 @@ impl<'a> RunnerInner<'a> { Ok(Outcome::Success) } + #[allow(clippy::disallowed_methods)] async fn execute_view_inner<'r>( &self, sql: &str, @@ -1728,6 +1738,7 @@ impl<'a> RunnerInner<'a> { Ok(tentative_outcome) } + #[allow(clippy::disallowed_methods)] async fn execute_view<'r>( &self, sql: &str, @@ -1871,6 +1882,7 @@ impl<'a> RunnerInner<'a> { } } + #[allow(clippy::disallowed_methods)] async fn run_simple<'r>( &mut self, conn: Option<&'r str>, diff --git a/src/storage/src/source/postgres.rs b/src/storage/src/source/postgres.rs index 99544cda80c4f..7eb387e37c34a 100644 --- a/src/storage/src/source/postgres.rs +++ b/src/storage/src/source/postgres.rs @@ -90,10 +90,8 @@ use mz_expr::{EvalError, MirScalarExpr}; use mz_ore::cast::CastFrom; use mz_ore::error::ErrorExt; use mz_postgres_util::desc::PostgresTableDesc; -use mz_postgres_util::{Client, PostgresError, simple_query_opt}; +use mz_postgres_util::{Client, PostgresError, Sql, query_opt, simple_query_opt, sql}; use mz_repr::{Datum, Diff, GlobalId, Row}; -use mz_sql_parser::ast::Ident; -use mz_sql_parser::ast::display::AstDisplay; use mz_storage_types::errors::{DataflowError, SourceError, 
SourceErrorDetails}; use mz_storage_types::sources::postgres::CastType; use mz_storage_types::sources::{ @@ -387,14 +385,16 @@ impl From for DataflowError { } async fn ensure_replication_slot(client: &Client, slot: &str) -> Result<(), TransientError> { - // Note: Using unchecked here is okay because we're using it in a SQL query. - let slot = Ident::new_unchecked(slot).to_ast_string_simple(); - let query = format!("CREATE_REPLICATION_SLOT {slot} LOGICAL \"pgoutput\" NOEXPORT_SNAPSHOT"); - match simple_query_opt(client, &query).await { + let slot = Sql::ident(slot); + let query = sql!( + "CREATE_REPLICATION_SLOT {} LOGICAL \"pgoutput\" NOEXPORT_SNAPSHOT", + slot.clone() + ); + match simple_query_opt(client, query).await { Ok(_) => Ok(()), // If the slot already exists that's still ok Err(PostgresError::Postgres(err)) if err.code() == Some(&SqlState::DUPLICATE_OBJECT) => { - tracing::trace!("replication slot {slot} already existed"); + tracing::trace!(slot = %slot, "replication slot already existed"); Ok(()) } Err(err) => Err(TransientError::PostgresError(err)), @@ -418,9 +418,16 @@ async fn fetch_slot_metadata( interval: Duration, ) -> Result { loop { - let query = "SELECT active_pid, confirmed_flush_lsn - FROM pg_replication_slots WHERE slot_name = $1"; - let Some(row) = client.query_opt(query, &[&slot]).await? else { + let Some(row) = query_opt( + &**client, + sql!( + "SELECT active_pid, confirmed_flush_lsn \ + FROM pg_replication_slots WHERE slot_name = $1" + ), + &[&slot], + ) + .await? + else { return Err(TransientError::MissingReplicationSlot); }; @@ -444,8 +451,7 @@ async fn fetch_slot_metadata( /// Fetch the `pg_current_wal_lsn`, used to report metrics. 
async fn fetch_max_lsn(client: &Client) -> Result { - let query = "SELECT pg_current_wal_lsn()"; - let row = simple_query_opt(client, query).await?; + let row = simple_query_opt(client, sql!("SELECT pg_current_wal_lsn()")).await?; match row.and_then(|row| { row.get("pg_current_wal_lsn") diff --git a/src/storage/src/source/postgres/replication.rs b/src/storage/src/source/postgres/replication.rs index faf37c7e13acf..8b4db8e780acb 100644 --- a/src/storage/src/source/postgres/replication.rs +++ b/src/storage/src/source/postgres/replication.rs @@ -85,10 +85,8 @@ use mz_dyncfg::ConfigSet; use mz_ore::cast::CastFrom; use mz_ore::future::InTask; use mz_postgres_util::PostgresError; -use mz_postgres_util::{Client, simple_query_opt}; +use mz_postgres_util::{Client, Sql, execute, query_opt, simple_query_opt, sql}; use mz_repr::{Datum, DatumVec, Diff, Row}; -use mz_sql_parser::ast::Ident; -use mz_sql_parser::ast::display::{AstDisplay, escaped_string_literal}; use mz_storage_types::dyncfgs::PG_SCHEMA_VALIDATION_INTERVAL; use mz_storage_types::dyncfgs::PG_SOURCE_VALIDATE_TIMELINE; use mz_storage_types::errors::DataflowError; @@ -252,9 +250,12 @@ pub(crate) fn render>( "replication slot already in use; will attempt to kill existing connection", ); - match metadata_client - .execute("SELECT pg_terminate_backend($1)", &[&active_pid]) - .await + match execute( + &**metadata_client, + sql!("SELECT pg_terminate_backend($1)"), + &[&active_pid], + ) + .await { Ok(_) => { tracing::info!( @@ -652,7 +653,7 @@ async fn raw_stream<'a>( // Note: We must use the metadata client here which is NOT in replication mode. Some Aurora // Postgres versions disallow SHOW commands from within replication connection. // See: https://github.com/readysettech/readyset/discussions/28#discussioncomment-4405671 - let row = simple_query_opt(&*metadata_client, "SHOW wal_sender_timeout;") + let row = simple_query_opt(&*metadata_client, sql!("SHOW wal_sender_timeout;")) .await? 
.unwrap(); let wal_sender_timeout = match row.get("wal_sender_timeout") { @@ -684,13 +685,13 @@ async fn raw_stream<'a>( // Postgres will return all transactions that commit *at or after* after the provided LSN, // following the timely upper semantics. let lsn = PgLsn::from(resume_lsn.offset); - let query = format!( - r#"START_REPLICATION SLOT "{}" LOGICAL {} ("proto_version" '1', "publication_names" {})"#, - Ident::new_unchecked(slot).to_ast_string_simple(), + let query = sql!( + "START_REPLICATION SLOT {} LOGICAL {} (\"proto_version\" '1', \"publication_names\" {})", + Sql::ident(slot), lsn, - escaped_string_literal(publication), + Sql::literal(publication) ); - let copy_stream = match replication_client.copy_both_simple(&query).await { + let copy_stream = match replication_client.copy_both_simple(query.as_str()).await { Ok(copy_stream) => copy_stream, Err(err) if err.code() == Some(&SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE) => { return Ok(Err(DefiniteError::InvalidReplicationSlot)); @@ -1043,12 +1044,12 @@ async fn ensure_publication_exists( publication: &str, ) -> Result, TransientError> { // Figure out the last written LSN and then add one to convert it into an upper. 
- let result = client - .query_opt( - "SELECT 1 FROM pg_publication WHERE pubname = $1;", - &[&publication], - ) - .await?; + let result = query_opt( + &**client, + sql!("SELECT 1 FROM pg_publication WHERE pubname = $1;"), + &[&publication], + ) + .await?; match result { Some(_) => Ok(Ok(())), None => Ok(Err(DefiniteError::PublicationDropped( diff --git a/src/storage/src/source/postgres/snapshot.rs b/src/storage/src/source/postgres/snapshot.rs index a31e4f69ab974..313b6a565011e 100644 --- a/src/storage/src/source/postgres/snapshot.rs +++ b/src/storage/src/source/postgres/snapshot.rs @@ -168,17 +168,12 @@ use std::time::Duration; use anyhow::bail; use differential_dataflow::AsCollection; use futures::{StreamExt as _, TryStreamExt}; -use itertools::Itertools; use mz_ore::cast::CastFrom; use mz_ore::future::InTask; use mz_postgres_util::desc::PostgresTableDesc; use mz_postgres_util::schemas::get_pg_major_version; -use mz_postgres_util::{Client, Config, PostgresError, simple_query_opt}; +use mz_postgres_util::{Client, Config, PostgresError, Sql, simple_query, simple_query_opt, sql}; use mz_repr::{Datum, DatumVec, Diff, Row}; -use mz_sql_parser::ast::{ - Ident, - display::{AstDisplay, escaped_string_literal}, -}; use mz_storage_types::connections::ConnectionContext; use mz_storage_types::errors::DataflowError; use mz_storage_types::parameters::PgSourceSnapshotConfig; @@ -309,13 +304,9 @@ async fn estimate_table_block_counts( return Ok(BTreeMap::new()); } - // Query relpages for all tables at once - let oid_list = table_oids - .iter() - .map(|oid| oid.to_string()) - .collect::>() - .join(","); - let query = format!( + // Query relpages for all tables at once. 
+ let oid_list = Sql::join(table_oids.iter().copied().map(Sql::from), ","); + let query = sql!( "SELECT oid, relpages FROM pg_class WHERE oid IN ({})", oid_list ); @@ -327,7 +318,7 @@ async fn estimate_table_block_counts( } // Execute the query and collect results - let rows = client.simple_query(&query).await?; + let rows = simple_query(client, query).await?; for msg in rows { if let tokio_postgres::SimpleQueryMessage::Row(row) = msg { let oid: u32 = row.get("oid").unwrap().parse().unwrap(); @@ -664,32 +655,30 @@ pub(crate) fn render>( ctid_range ); - // To handle quoted/keyword names, we can use `Ident`'s AST printing, which - // emulate's PG's rules for name formatting. - let namespace = Ident::new_unchecked(&info.desc.namespace) - .to_ast_string_stable(); - let table = Ident::new_unchecked(&info.desc.name) - .to_ast_string_stable(); - let column_list = info - .desc - .columns - .iter() - .map(|c| Ident::new_unchecked(&c.name).to_ast_string_stable()) - .join(","); - + let namespace = Sql::ident(&info.desc.namespace); + let table = Sql::ident(&info.desc.name); + let column_list = + Sql::join(info.desc.columns.iter().map(|c| Sql::ident(&c.name)), ","); let ctid_filter = match ctid_range.end_block { - Some(end) => format!( + Some(end) => sql!( "WHERE ctid >= '({},0)'::tid AND ctid < '({},0)'::tid", - ctid_range.start_block, end + ctid_range.start_block, + end + ), + None => sql!( + "WHERE ctid >= '({},0)'::tid", + ctid_range.start_block ), - None => format!("WHERE ctid >= '({},0)'::tid", ctid_range.start_block), }; - let query = format!( - "COPY (SELECT {column_list} FROM {namespace}.{table} {ctid_filter}) \ - TO STDOUT (FORMAT TEXT, DELIMITER '\t')" + let query = sql!( + "COPY (SELECT {} FROM {}.{} {}) TO STDOUT (FORMAT TEXT, DELIMITER '\t')", + column_list, + namespace, + table, + ctid_filter ); - let mut stream = pin!(client.copy_out_simple(&query).await?); + let mut stream = pin!(client.copy_out_simple(query.as_str()).await?); let mut snapshot_staged = 0; let 
mut update = @@ -746,10 +735,10 @@ pub(crate) fn render>( *snapshot_cap_set = CapabilitySet::new(); while snapshot_input.next().await.is_some() {} trace!(%id, "timely-{worker_id} (leader) comitting COPY transaction"); - client.simple_query("COMMIT").await?; + simple_query(&client, sql!("COMMIT")).await?; } else { trace!(%id, "timely-{worker_id} comitting COPY transaction"); - client.simple_query("COMMIT").await?; + simple_query(&client, sql!("COMMIT")).await?; *snapshot_cap_set = CapabilitySet::new(); } drop(client); @@ -818,7 +807,7 @@ async fn export_snapshot( Ok(ok) => Ok(ok), Err(err) => { // We don't want to leave the client inside a failed tx - client.simple_query("ROLLBACK;").await?; + simple_query(client, sql!("ROLLBACK;")).await?; Err(err) } } @@ -829,16 +818,24 @@ async fn export_snapshot_inner( slot: &str, temporary: bool, ) -> Result<(String, MzOffset), TransientError> { - client - .simple_query("BEGIN READ ONLY ISOLATION LEVEL REPEATABLE READ;") - .await?; + simple_query( + client, + sql!("BEGIN READ ONLY ISOLATION LEVEL REPEATABLE READ;"), + ) + .await?; - // Note: Using unchecked here is okay because we're using it in a SQL query. 
- let slot = Ident::new_unchecked(slot).to_ast_string_simple(); - let temporary_str = if temporary { " TEMPORARY" } else { "" }; - let query = - format!("CREATE_REPLICATION_SLOT {slot}{temporary_str} LOGICAL \"pgoutput\" USE_SNAPSHOT"); - let row = match simple_query_opt(client, &query).await { + let query = if temporary { + sql!( + "CREATE_REPLICATION_SLOT {} TEMPORARY LOGICAL \"pgoutput\" USE_SNAPSHOT", + Sql::ident(slot) + ) + } else { + sql!( + "CREATE_REPLICATION_SLOT {} LOGICAL \"pgoutput\" USE_SNAPSHOT", + Sql::ident(slot) + ) + }; + let row = match simple_query_opt(client, query).await { Ok(row) => Ok(row.unwrap()), Err(PostgresError::Postgres(err)) if err.code() == Some(&SqlState::DUPLICATE_OBJECT) => { return Err(TransientError::ReplicationSlotAlreadyExists); @@ -856,7 +853,7 @@ async fn export_snapshot_inner( .checked_sub(1) .expect("consistent point is always non-zero"); - let row = simple_query_opt(client, "SELECT pg_export_snapshot();") + let row = simple_query_opt(client, sql!("SELECT pg_export_snapshot();")) .await? .unwrap(); let snapshot = row.get("pg_export_snapshot").unwrap().to_owned(); @@ -867,23 +864,24 @@ async fn export_snapshot_inner( /// Starts a read-only transaction on the SQL session of `client` at a the consistent LSN point of /// `snapshot`. async fn use_snapshot(client: &Client, snapshot: &str) -> Result<(), TransientError> { - client - .simple_query("BEGIN READ ONLY ISOLATION LEVEL REPEATABLE READ;") - .await?; - let query = format!( - "SET TRANSACTION SNAPSHOT {};", - escaped_string_literal(snapshot) - ); - client.simple_query(&query).await?; + simple_query( + client, + sql!("BEGIN READ ONLY ISOLATION LEVEL REPEATABLE READ;"), + ) + .await?; + let query = sql!("SET TRANSACTION SNAPSHOT {};", Sql::literal(snapshot)); + simple_query(client, query).await?; Ok(()) } async fn set_statement_timeout(client: &Client, timeout: Duration) -> Result<(), TransientError> { // Value is known to accept milliseconds w/o units. 
// https://www.postgresql.org/docs/current/runtime-config-client.html - client - .simple_query(&format!("SET statement_timeout = {}", timeout.as_millis())) - .await?; + let query = sql!( + "SET statement_timeout = {}", + Sql::literal(&timeout.as_millis().to_string()) + ); + simple_query(client, query).await?; Ok(()) } @@ -920,11 +918,12 @@ async fn report_snapshot_size( let Some((_, info)) = outputs.first_key_value() else { continue; }; - let table = format!( + let table = sql!( "{}.{}", - Ident::new_unchecked(info.desc.namespace.clone()).to_ast_string_simple(), - Ident::new_unchecked(info.desc.name.clone()).to_ast_string_simple() - ); + Sql::ident(&info.desc.namespace), + Sql::ident(&info.desc.name) + ) + .into_string(); let stats = collect_table_statistics( client, snapshot_config, @@ -951,28 +950,24 @@ struct TableStatistics { async fn collect_table_statistics( client: &Client, config: PgSourceSnapshotConfig, - namespace: &str, - table_name: &str, + schema: &str, + table: &str, oid: u32, ) -> Result { use mz_ore::metrics::MetricsFutureExt; let mut stats = TableStatistics::default(); - let table = format!( - "{}.{}", - Ident::new_unchecked(namespace).to_ast_string_simple(), - Ident::new_unchecked(table_name).to_ast_string_simple() - ); - let estimate_row = simple_query_opt( - client, - &format!("SELECT reltuples::bigint AS estimate_count FROM pg_class WHERE oid = '{oid}'"), - ) - .wall_time() - .set_at(&mut stats.count_latency) - .await?; + let estimate_query = sql!( + "SELECT reltuples::bigint AS estimate_count FROM pg_class WHERE oid = {}", + Sql::literal(&oid.to_string()) + ); + let estimate_row = simple_query_opt(client, estimate_query) + .wall_time() + .set_at(&mut stats.count_latency) + .await?; stats.count = match estimate_row { Some(row) => row.get("estimate_count").unwrap().parse().unwrap_or(0), - None => bail!("failed to get estimate count for {table}"), + None => bail!("failed to get estimate count for {schema}.{table}"), }; // If the estimate is low 
enough we can attempt to get an exact count. Note that not yet @@ -980,13 +975,18 @@ async fn collect_table_statistics( // large. We accept this risk and we offer the feature flag as an escape hatch if it becomes // problematic. if config.collect_strict_count && stats.count < 1_000_000 { - let count_row = simple_query_opt(client, &format!("SELECT count(*) as count from {table}")) + let count_query = sql!( + "SELECT count(*) as count from {}.{}", + Sql::ident(schema), + Sql::ident(table) + ); + let count_row = simple_query_opt(client, count_query) .wall_time() .set_at(&mut stats.count_latency) .await?; stats.count = match count_row { Some(row) => row.get("count").unwrap().parse().unwrap(), - None => bail!("failed to get count for {table}"), + None => bail!("failed to get count for {schema}.{table}"), } } diff --git a/src/testdrive/Cargo.toml b/src/testdrive/Cargo.toml index 0d36dc2510604..18a2e4e5d9ced 100644 --- a/src/testdrive/Cargo.toml +++ b/src/testdrive/Cargo.toml @@ -48,6 +48,7 @@ mz-ore = { path = "../ore", features = ["async"] } mz-persist-types = { path = "../persist-types" } mz-persist-client = { path = "../persist-client" } mz-pgrepr = { path = "../pgrepr" } +mz-postgres-util = { path = "../postgres-util" } mz-repr = { path = "../repr" } mz-sql = { path = "../sql" } mz-sql-parser = { path = "../sql-parser" } diff --git a/src/testdrive/src/action.rs b/src/testdrive/src/action.rs index 2100d65271f78..f1b537a2ee2ae 100644 --- a/src/testdrive/src/action.rs +++ b/src/testdrive/src/action.rs @@ -38,6 +38,10 @@ use mz_persist_client::cache::PersistClientCache; use mz_persist_client::cfg::PersistConfig; use mz_persist_client::rpc::PubSubClientConnection; use mz_persist_client::{PersistClient, PersistLocation}; +use mz_postgres_util::{ + Sql, batch_execute as pg_batch_execute, query as pg_query, query_one as pg_query_one, + sql as pg_sql, +}; use mz_sql::catalog::EnvironmentId; use mz_tls_util::make_tls; use rdkafka::ClientConfig; @@ -458,22 +462,23 @@ impl State 
{ ) .await?; - let version = inner_client - .query_one("SELECT mz_version_num()", &[]) + let version = pg_query_one(&inner_client, pg_sql!("SELECT mz_version_num()"), &[]) .await .context("getting version of materialize") .map(|row| row.get::<_, i32>(0))?; - let semver = inner_client - .query_one("SELECT right(split_part(mz_version(), ' ', 1), -1)", &[]) - .await - .context("getting semver of materialize") - .map(|row| row.get::<_, String>(0))? - .parse::() - .context("parsing semver of materialize")?; + let semver = pg_query_one( + &inner_client, + pg_sql!("SELECT right(split_part(mz_version(), ' ', 1), -1)"), + &[], + ) + .await + .context("getting semver of materialize") + .map(|row| row.get::<_, String>(0))? + .parse::() + .context("parsing semver of materialize")?; - inner_client - .batch_execute("ALTER SYSTEM RESET ALL") + pg_batch_execute(&inner_client, pg_sql!("ALTER SYSTEM RESET ALL")) .await .context("resetting materialize state: ALTER SYSTEM RESET ALL")?; @@ -485,10 +490,15 @@ impl State { } else { "enable_unsafe_functions" }; - let res = inner_client - .batch_execute(&format!("ALTER SYSTEM SET {enable_unsafe_functions} = on")) - .await - .context("enabling dangerous functions"); + let res = pg_batch_execute( + &inner_client, + pg_sql!( + "ALTER SYSTEM SET {} = on", + Sql::ident(enable_unsafe_functions) + ), + ) + .await + .context("enabling dangerous functions"); if let Err(e) = res { match e.root_cause().downcast_ref::() { Some(e) if *e.code() == SqlState::CANT_CHANGE_RUNTIME_PARAM => { @@ -502,8 +512,7 @@ impl State { } } - for row in inner_client - .query("SHOW DATABASES", &[]) + for row in pg_query(&inner_client, pg_sql!("SHOW DATABASES"), &[]) .await .context("resetting materialize state: SHOW DATABASES")? 
{ @@ -511,15 +520,14 @@ impl State { if db_name.starts_with("testdrive_no_reset_") { continue; } - let query = format!( - "DROP DATABASE {}", - postgres_protocol::escape::escape_identifier(&db_name) - ); - sql::print_query(&query, None); - inner_client.batch_execute(&query).await.context(format!( - "resetting materialize state: DROP DATABASE {}", - db_name, - ))?; + let drop_database = pg_sql!("DROP DATABASE {}", Sql::ident(&db_name)); + sql::print_query(drop_database.as_str(), None); + pg_batch_execute(&inner_client, drop_database) + .await + .context(format!( + "resetting materialize state: DROP DATABASE {}", + db_name, + ))?; } // Get all user clusters not running any objects owned by users @@ -551,8 +559,7 @@ impl State { AND owner_id LIKE 'u%';"; - let inactive_clusters = inner_client - .query(inactive_user_clusters, &[]) + let inactive_clusters = pg_query(&inner_client, Sql::new(inactive_user_clusters), &[]) .await .context("resetting materialize state: inactive_user_clusters")?; @@ -565,81 +572,91 @@ impl State { if cluster_name.starts_with("testdrive_no_reset_") { continue; } - let query = format!( - "DROP CLUSTER {}", - postgres_protocol::escape::escape_identifier(&cluster_name) - ); - sql::print_query(&query, None); - inner_client.batch_execute(&query).await.context(format!( - "resetting materialize state: DROP CLUSTER {}", - cluster_name, - ))?; + let drop_cluster = pg_sql!("DROP CLUSTER {}", Sql::ident(&cluster_name)); + sql::print_query(drop_cluster.as_str(), None); + pg_batch_execute(&inner_client, drop_cluster) + .await + .context(format!( + "resetting materialize state: DROP CLUSTER {}", + cluster_name, + ))?; } - inner_client - .batch_execute("CREATE DATABASE materialize") + pg_batch_execute(&inner_client, pg_sql!("CREATE DATABASE materialize")) .await .context("resetting materialize state: CREATE DATABASE materialize")?; // Attempt to remove all users but the current user. 
Old versions of // Materialize did not support roles, so this degrades gracefully if // mz_roles does not exist. - if let Ok(rows) = inner_client.query("SELECT name FROM mz_roles", &[]).await { + if let Ok(rows) = pg_query(&inner_client, pg_sql!("SELECT name FROM mz_roles"), &[]).await { for row in rows { let role_name: String = row.get(0); if role_name == self.materialize.user || role_name.starts_with("mz_") { continue; } - let query = format!( - "DROP ROLE {}", - postgres_protocol::escape::escape_identifier(&role_name) - ); - sql::print_query(&query, None); - inner_client.batch_execute(&query).await.context(format!( - "resetting materialize state: DROP ROLE {}", - role_name, - ))?; + let drop_role = pg_sql!("DROP ROLE {}", Sql::ident(&role_name)); + sql::print_query(drop_role.as_str(), None); + pg_batch_execute(&inner_client, drop_role) + .await + .context(format!( + "resetting materialize state: DROP ROLE {}", + role_name, + ))?; } } // Alter materialize user with all system privileges. - inner_client - .batch_execute(&format!( + pg_batch_execute( + &inner_client, + pg_sql!( "GRANT ALL PRIVILEGES ON SYSTEM TO {}", - self.materialize.user - )) - .await?; + Sql::ident(&self.materialize.user) + ), + ) + .await?; // Grant initial privileges. 
- inner_client - .batch_execute("GRANT USAGE ON DATABASE materialize TO PUBLIC") - .await?; - inner_client - .batch_execute(&format!( + pg_batch_execute( + &inner_client, + pg_sql!("GRANT USAGE ON DATABASE materialize TO PUBLIC"), + ) + .await?; + pg_batch_execute( + &inner_client, + pg_sql!( "GRANT ALL PRIVILEGES ON DATABASE materialize TO {}", - self.materialize.user - )) - .await?; - inner_client - .batch_execute(&format!( + Sql::ident(&self.materialize.user) + ), + ) + .await?; + pg_batch_execute( + &inner_client, + pg_sql!( "GRANT ALL PRIVILEGES ON SCHEMA materialize.public TO {}", - self.materialize.user - )) - .await?; + Sql::ident(&self.materialize.user) + ), + ) + .await?; let cluster = match version { ..=8199 => "default", 8200.. => "quickstart", }; - inner_client - .batch_execute(&format!("GRANT USAGE ON CLUSTER {cluster} TO PUBLIC")) - .await?; - inner_client - .batch_execute(&format!( - "GRANT ALL PRIVILEGES ON CLUSTER {cluster} TO {}", - self.materialize.user - )) - .await?; + pg_batch_execute( + &inner_client, + pg_sql!("GRANT USAGE ON CLUSTER {} TO PUBLIC", Sql::ident(cluster)), + ) + .await?; + pg_batch_execute( + &inner_client, + pg_sql!( + "GRANT ALL PRIVILEGES ON CLUSTER {} TO {}", + Sql::ident(cluster), + Sql::ident(&self.materialize.user) + ), + ) + .await?; Ok(()) } @@ -1185,6 +1202,8 @@ async fn create_materialize_state( util::postgres::config_url(&config.materialize_internal_pgconfig)?; for (key, value) in &config.materialize_params { + // Session parameter values are raw SQL fragments from testdrive config. + #[allow(clippy::disallowed_methods)] pgclient .batch_execute(&format!("SET {key} = {value}")) .await @@ -1227,8 +1246,7 @@ async fn create_materialize_state( materialize_internal_url.host_str().unwrap(), config.materialize_internal_http_port ); - let environment_id = pgclient - .query_one("SELECT mz_environment_id()", &[]) + let environment_id = pg_query_one(&pgclient, pg_sql!("SELECT mz_environment_id()"), &[]) .await? 
.get::<_, String>(0) .parse() diff --git a/src/testdrive/src/action/consistency.rs b/src/testdrive/src/action/consistency.rs index a0dbe1e016326..beb54479206ec 100644 --- a/src/testdrive/src/action/consistency.rs +++ b/src/testdrive/src/action/consistency.rs @@ -19,6 +19,7 @@ use crate::parser::{BuiltinCommand, LineReader, parse}; use anyhow::{Context, anyhow, bail}; use mz_ore::retry::{Retry, RetryResult}; use mz_persist_client::{PersistLocation, ShardId}; +use mz_postgres_util::{query_one, sql}; use reqwest::StatusCode; use serde::{Deserialize, Serialize}; @@ -276,8 +277,7 @@ async fn check_statement_logging(orig_state: &State) -> Result<(), anyhow::Error .await .context("connecting as mz_system to query enable_rbac_checks")?; - let row = client - .query_one("SHOW enable_rbac_checks", &[]) + let row = query_one(&client, sql!("SHOW enable_rbac_checks"), &[]) .await .context("querying enable_rbac_checks")?; diff --git a/src/testdrive/src/action/kafka/verify_data.rs b/src/testdrive/src/action/kafka/verify_data.rs index 6c8d2c24be817..15da628ea60c8 100644 --- a/src/testdrive/src/action/kafka/verify_data.rs +++ b/src/testdrive/src/action/kafka/verify_data.rs @@ -12,6 +12,7 @@ use std::time::Duration; use std::{cmp, str}; use anyhow::{Context, bail, ensure}; +use mz_postgres_util::{Sql, query_one, sql}; use rdkafka::consumer::{Consumer, StreamConsumer}; use rdkafka::error::KafkaError; use rdkafka::message::{Headers, Message}; @@ -76,27 +77,25 @@ struct Record { } async fn get_topic(sink: &str, topic_field: &str, state: &State) -> Result { - let query = format!( + let query = sql!( "SELECT {} FROM mz_sinks JOIN mz_kafka_sinks \ - ON mz_sinks.id = mz_kafka_sinks.id \ - JOIN mz_schemas s ON s.id = mz_sinks.schema_id \ - LEFT JOIN mz_databases d ON d.id = s.database_id \ - WHERE d.name = $1 \ - AND s.name = $2 \ - AND mz_sinks.name = $3", - topic_field + ON mz_sinks.id = mz_kafka_sinks.id \ + JOIN mz_schemas s ON s.id = mz_sinks.schema_id \ + LEFT JOIN mz_databases d ON 
d.id = s.database_id \ + WHERE d.name = $1 \ + AND s.name = $2 \ + AND mz_sinks.name = $3", + Sql::ident(topic_field) ); let sink_fields: Vec<&str> = sink.split('.').collect(); - let result = state - .materialize - .pgclient - .query_one( - query.as_str(), - &[&sink_fields[0], &sink_fields[1], &sink_fields[2]], - ) - .await - .context("retrieving topic name")? - .get(topic_field); + let result = query_one( + &state.materialize.pgclient, + query, + &[&sink_fields[0], &sink_fields[1], &sink_fields[2]], + ) + .await + .context("retrieving topic name")? + .get(topic_field); Ok(result) } diff --git a/src/testdrive/src/action/kafka/verify_topic.rs b/src/testdrive/src/action/kafka/verify_topic.rs index 272c4bfcaab62..c5e5a33ec73c7 100644 --- a/src/testdrive/src/action/kafka/verify_topic.rs +++ b/src/testdrive/src/action/kafka/verify_topic.rs @@ -14,6 +14,7 @@ use std::time::Duration; use anyhow::{Context, bail}; use mz_ore::collections::CollectionExt; use mz_ore::retry::Retry; +use mz_postgres_util::{Sql, query_one, sql}; use rdkafka::admin::{AdminClient, AdminOptions, ResourceSpecifier}; use crate::action::{ControlFlow, State}; @@ -25,27 +26,25 @@ enum Topic { } async fn get_topic(sink: &str, topic_field: &str, state: &State) -> Result { - let query = format!( + let query = sql!( "SELECT {} FROM mz_sinks JOIN mz_kafka_sinks \ - ON mz_sinks.id = mz_kafka_sinks.id \ - JOIN mz_schemas s ON s.id = mz_sinks.schema_id \ - LEFT JOIN mz_databases d ON d.id = s.database_id \ - WHERE d.name = $1 \ - AND s.name = $2 \ - AND mz_sinks.name = $3", - topic_field + ON mz_sinks.id = mz_kafka_sinks.id \ + JOIN mz_schemas s ON s.id = mz_sinks.schema_id \ + LEFT JOIN mz_databases d ON d.id = s.database_id \ + WHERE d.name = $1 \ + AND s.name = $2 \ + AND mz_sinks.name = $3", + Sql::ident(topic_field) ); let sink_fields: Vec<&str> = sink.split('.').collect(); - let result = state - .materialize - .pgclient - .query_one( - query.as_str(), - &[&sink_fields[0], &sink_fields[1], 
&sink_fields[2]], - ) - .await - .context("retrieving topic name")? - .get(topic_field); + let result = query_one( + &state.materialize.pgclient, + query, + &[&sink_fields[0], &sink_fields[1], &sink_fields[2]], + ) + .await + .context("retrieving topic name")? + .get(topic_field); Ok(result) } diff --git a/src/testdrive/src/action/postgres/execute.rs b/src/testdrive/src/action/postgres/execute.rs index d96cee2b2f752..1a2891d8194ea 100644 --- a/src/testdrive/src/action/postgres/execute.rs +++ b/src/testdrive/src/action/postgres/execute.rs @@ -18,6 +18,9 @@ use crate::util::postgres::postgres_client; async fn execute_input(cmd: BuiltinCommand, client: &Client) -> Result<(), anyhow::Error> { for query in cmd.input { println!(">> {}", query); + // `query` is raw SQL from testdrive input and may contain multiple + // statements; this command intentionally forwards it verbatim. + #[allow(clippy::disallowed_methods)] client .batch_execute(&query) .await diff --git a/src/testdrive/src/action/postgres/verify_slot.rs b/src/testdrive/src/action/postgres/verify_slot.rs index cc1cd0c651007..4e3f9fdf0d769 100644 --- a/src/testdrive/src/action/postgres/verify_slot.rs +++ b/src/testdrive/src/action/postgres/verify_slot.rs @@ -12,6 +12,7 @@ use std::time::Duration; use anyhow::{Context, bail}; use mz_ore::retry::Retry; +use mz_postgres_util::{query, sql}; use crate::action::{ControlFlow, State}; use crate::parser::BuiltinCommand; @@ -33,13 +34,13 @@ pub async fn run_verify_slot( .max_duration(cmp::max(state.default_timeout, Duration::from_secs(60))) .retry_async_canceling(|_| async { println!(">> checking for postgres replication slot {}", &slot); - let rows = client - .query( - "SELECT active_pid FROM pg_replication_slots WHERE slot_name LIKE $1::TEXT", - &[&slot], - ) - .await - .context("querying postgres for replication slot")?; + let rows = query( + &client, + sql!("SELECT active_pid FROM pg_replication_slots WHERE slot_name LIKE $1::TEXT"), + &[&slot], + ) + .await + 
.context("querying postgres for replication slot")?; if rows.len() != 1 { bail!( diff --git a/src/testdrive/src/action/set.rs b/src/testdrive/src/action/set.rs index 12efcee296950..c15397ee8d0d6 100644 --- a/src/testdrive/src/action/set.rs +++ b/src/testdrive/src/action/set.rs @@ -10,6 +10,7 @@ use std::cmp; use anyhow::{Context, bail}; +use mz_postgres_util::query_one_prepared; use regex::Regex; use tokio::fs; @@ -106,10 +107,14 @@ pub async fn run_set_from_sql( let var = cmd.args.string("var")?; cmd.args.done()?; - let row = state + let query = cmd.input.join("\n"); + let statement = state .materialize .pgclient - .query_one(&cmd.input.join("\n"), &[]) + .prepare(&query) + .await + .context("preparing query")?; + let row = query_one_prepared(&state.materialize.pgclient, &statement, &[]) .await .context("running query")?; if row.columns().len() != 1 { diff --git a/src/testdrive/src/action/skip_if.rs b/src/testdrive/src/action/skip_if.rs index 7b24f03194348..bf2456d2e51c7 100644 --- a/src/testdrive/src/action/skip_if.rs +++ b/src/testdrive/src/action/skip_if.rs @@ -8,6 +8,7 @@ // by the Apache License, Version 2.0. use anyhow::{Context, bail}; +use mz_postgres_util::query_one_prepared; use tokio_postgres::types::Type; use crate::action::{ControlFlow, State}; @@ -26,10 +27,7 @@ pub async fn run_skip_if(cmd: BuiltinCommand, state: &State) -> Result, expected_hint: Option<&ErrorMatcher>, ) -> Result<(), anyhow::Error> { + // `query` is raw SQL from testdrive input and may include statements that + // cannot be represented as composable `Sql`. 
+ #[allow(clippy::disallowed_methods)] match state.materialize.pgclient.query(query, &[]).await { Ok(_) => bail!("query succeeded, but expected {}", expected_error), Err(err) => match err.source().and_then(|err| err.downcast_ref::()) { diff --git a/src/testdrive/src/action/version_check.rs b/src/testdrive/src/action/version_check.rs index 42ce7ee0405f6..281f86df76572 100644 --- a/src/testdrive/src/action/version_check.rs +++ b/src/testdrive/src/action/version_check.rs @@ -8,6 +8,7 @@ // by the Apache License, Version 2.0. use anyhow::{Context, bail}; +use mz_postgres_util::query_one_prepared; use tokio_postgres::types::Type; use crate::action::State; @@ -30,10 +31,7 @@ pub async fn run_version_check( *stmt.columns()[0].type_() ); } - let actual_version: i32 = state - .materialize - .pgclient - .query_one(&stmt, &[]) + let actual_version: i32 = query_one_prepared(&state.materialize.pgclient, &stmt, &[]) .await .context("executing version-check query failed")? .get(0); diff --git a/src/timestamp-oracle/src/postgres_oracle.rs b/src/timestamp-oracle/src/postgres_oracle.rs index 4880a29a190c3..c9c4f5b4a021e 100644 --- a/src/timestamp-oracle/src/postgres_oracle.rs +++ b/src/timestamp-oracle/src/postgres_oracle.rs @@ -19,6 +19,8 @@ use std::time::{Duration, SystemTime}; use async_trait::async_trait; use deadpool_postgres::tokio_postgres::Config; use deadpool_postgres::tokio_postgres::error::SqlState; +use deadpool_postgres::tokio_postgres::types::ToSql; +use deadpool_postgres::tokio_postgres::{Row, Statement}; use deadpool_postgres::{Object, PoolError}; use dec::Decimal; use mz_adapter_types::timestamp_oracle::{ @@ -74,6 +76,52 @@ const CRDB_SCHEMA_OPTIONS: &str = "WITH (sql_stats_automatic_collection_enabled const CRDB_CONFIGURE_ZONE: &str = "ALTER TABLE timestamp_oracle CONFIGURE ZONE USING gc.ttlseconds = 600;"; +/// NOTE: `mz-timestamp-oracle` currently keeps its Postgres surface local; it +/// does not use `mz-postgres-util` wrappers. 
+async fn pg_batch_execute( + client: &Object, + query: &str, +) -> Result<(), deadpool_postgres::tokio_postgres::Error> { + #[allow(clippy::disallowed_methods)] + client.batch_execute(query).await +} + +async fn pg_query_one_prepared( + client: &Object, + statement: &Statement, + params: &[&(dyn ToSql + Sync)], +) -> Result { + #[allow(clippy::disallowed_methods)] + client.query_one(statement, params).await +} + +async fn pg_execute_prepared( + client: &Object, + statement: &Statement, + params: &[&(dyn ToSql + Sync)], +) -> Result { + #[allow(clippy::disallowed_methods)] + client.execute(statement, params).await +} + +async fn pg_txn_query_prepared( + txn: &deadpool_postgres::Transaction<'_>, + statement: &Statement, + params: &[&(dyn ToSql + Sync)], +) -> Result, deadpool_postgres::tokio_postgres::Error> { + #[allow(clippy::disallowed_methods)] + txn.query(statement, params).await +} + +async fn pg_txn_query_one_prepared( + txn: &deadpool_postgres::Transaction<'_>, + statement: &Statement, + params: &[&(dyn ToSql + Sync)], +) -> Result { + #[allow(clippy::disallowed_methods)] + txn.query_one(statement, params).await +} + /// A [`TimestampOracle`] backed by "Postgres". 
#[derive(Debug)] pub struct PostgresTimestampOracle @@ -561,12 +609,14 @@ where let client = postgres_client.get_connection().await?; - let crdb_mode = match client - .batch_execute(&format!( + let crdb_mode = match pg_batch_execute( + &client, + &format!( "{}; {}{}; {}", create_schema, SCHEMA, CRDB_SCHEMA_OPTIONS, CRDB_CONFIGURE_ZONE, - )) - .await + ), + ) + .await { Ok(()) => true, Err(e) @@ -583,9 +633,7 @@ where }; if !crdb_mode { - client - .batch_execute(&format!("{}; {};", create_schema, SCHEMA)) - .await?; + pg_batch_execute(&client, &format!("{}; {};", create_schema, SCHEMA)).await?; } let oracle = PostgresTimestampOracle { @@ -609,12 +657,12 @@ where let statement = client.prepare_cached(q).await?; let initially_coerced = Self::ts_to_decimal(initially); - let _ = client - .execute( - &statement, - &[&oracle.timeline, &initially_coerced, &initially_coerced], - ) - .await?; + let _ = pg_execute_prepared( + &client, + &statement, + &[&oracle.timeline, &initially_coerced, &initially_coerced], + ) + .await?; // Forward timestamps to what we're given from outside. 
Remember, // the above query will only create the row at the initial timestamp @@ -666,7 +714,7 @@ where SELECT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = 'timestamp_oracle' AND table_schema = CURRENT_SCHEMA); "#; let statement = txn.prepare(q).await?; - let exists_row = txn.query_one(&statement, &[]).await?; + let exists_row = pg_txn_query_one_prepared(&txn, &statement, &[]).await?; let exists: bool = exists_row.try_get("exists").expect("missing exists column"); if !exists { return Ok(Vec::new()); @@ -676,7 +724,7 @@ where SELECT timeline, GREATEST(read_ts, write_ts) as ts FROM timestamp_oracle; "#; let statement = txn.prepare(q).await?; - let rows = txn.query(&statement, &[]).await?; + let rows = pg_txn_query_prepared(&txn, &statement, &[]).await?; txn.commit().await?; @@ -718,9 +766,9 @@ where "#; let client = self.get_connection().await?; let statement = client.prepare_cached(q).await?; - let result = client - .query_one(&statement, &[&self.timeline, &proposed_next_ts]) - .await?; + let result = + pg_query_one_prepared(&client, &statement, &[&self.timeline, &proposed_next_ts]) + .await?; let write_ts: Numeric = result.try_get("write_ts").expect("missing column write_ts"); let write_ts = Self::decimal_to_ts(write_ts); @@ -747,7 +795,7 @@ where "#; let client = self.get_connection().await?; let statement = client.prepare_cached(q).await?; - let result = client.query_one(&statement, &[&self.timeline]).await?; + let result = pg_query_one_prepared(&client, &statement, &[&self.timeline]).await?; let write_ts: Numeric = result.try_get("write_ts").expect("missing column write_ts"); let write_ts = Self::decimal_to_ts(write_ts); @@ -768,7 +816,7 @@ where "#; let client = self.get_connection().await?; let statement = client.prepare_cached(q).await?; - let result = client.query_one(&statement, &[&self.timeline]).await?; + let result = pg_query_one_prepared(&client, &statement, &[&self.timeline]).await?; let read_ts: Numeric = 
result.try_get("read_ts").expect("missing column read_ts"); let read_ts = Self::decimal_to_ts(read_ts); @@ -795,9 +843,7 @@ where let statement = client.prepare_cached(q).await?; let write_ts = Self::ts_to_decimal(write_ts); - let _ = client - .execute(&statement, &[&self.timeline, &write_ts]) - .await?; + let _ = pg_execute_prepared(&client, &statement, &[&self.timeline, &write_ts]).await?; debug!( timeline = ?self.timeline, diff --git a/test/metabase/smoketest/src/bin/metabase-smoketest.rs b/test/metabase/smoketest/src/bin/metabase-smoketest.rs index b20d9ba2d7b8e..af31ddbb1de9a 100644 --- a/test/metabase/smoketest/src/bin/metabase-smoketest.rs +++ b/test/metabase/smoketest/src/bin/metabase-smoketest.rs @@ -107,6 +107,9 @@ async fn main() -> Result<(), anyhow::Error> { mz_ore::test::init_logging(); let pgclient = connect_materialized().await?; + // Test fixture bootstrap query with fixed SQL text; this smoketest talks + // directly to the driver. + #[allow(clippy::disallowed_methods)] pgclient .batch_execute( "CREATE OR REPLACE MATERIALIZED VIEW orders (id, date, quantity, total) AS diff --git a/test/test-util/src/mz_client.rs b/test/test-util/src/mz_client.rs index 74c36efaa6df7..704853f8259c9 100644 --- a/test/test-util/src/mz_client.rs +++ b/test/test-util/src/mz_client.rs @@ -37,6 +37,8 @@ pub async fn client(host: &str, port: u16) -> Result { /// Try running PostgresSQL's `query` function, checking for a common /// Materialize error in `check_error`. +// This helper intentionally executes caller-provided test SQL. +#[allow(clippy::disallowed_methods)] pub async fn try_query(mz_client: &Client, query: &str, delay: Duration) -> Result> { loop { let timer = std::time::Instant::now(); @@ -50,6 +52,8 @@ pub async fn try_query(mz_client: &Client, query: &str, delay: Duration) -> Resu /// Try running PostgreSQL's `query_one` function, checking for a common /// Materialize error in `check_error`. +// This helper intentionally executes caller-provided test SQL. 
+#[allow(clippy::disallowed_methods)] pub async fn try_query_one(mz_client: &Client, query: &str, delay: Duration) -> Result { loop { let timer = std::time::Instant::now(); @@ -89,6 +93,8 @@ async fn delay_for(elapsed: Duration, delay: Duration) { } /// Run Materialize's `SHOW SOURCES` command +// Test utility uses direct driver query for a fixed introspection statement. +#[allow(clippy::disallowed_methods)] pub async fn show_sources(mz_client: &Client) -> Result> { let mut res = Vec::new(); for row in mz_client.query("SHOW SOURCES", &[]).await? { @@ -99,6 +105,8 @@ pub async fn show_sources(mz_client: &Client) -> Result> { } /// Delete a source and all dependent views, if the source exists +// Name is supplied by test code and this utility intentionally forwards SQL. +#[allow(clippy::disallowed_methods)] pub async fn drop_source(mz_client: &Client, name: &str) -> Result<()> { let q = format!("DROP SOURCE IF EXISTS {} CASCADE", name); debug!("deleting source=> {}", q); @@ -107,6 +115,8 @@ pub async fn drop_source(mz_client: &Client, name: &str) -> Result<()> { } /// Delete a table and all dependent views, if the table exists +// Name is supplied by test code and this utility intentionally forwards SQL. +#[allow(clippy::disallowed_methods)] pub async fn drop_table(mz_client: &Client, name: &str) -> Result<()> { let q = format!("DROP TABLE IF EXISTS {} CASCADE", name); debug!("deleting table=> {}", q); @@ -115,6 +125,8 @@ pub async fn drop_table(mz_client: &Client, name: &str) -> Result<()> { } /// Delete an index +// Name is supplied by test code and this utility intentionally forwards SQL. +#[allow(clippy::disallowed_methods)] pub async fn drop_index(mz_client: &Client, name: &str) -> Result<()> { let q = format!("DROP INDEX {}", name); debug!("deleting index=> {}", q); @@ -123,6 +135,8 @@ pub async fn drop_index(mz_client: &Client, name: &str) -> Result<()> { } /// Run PostgreSQL's `execute` function +// This helper intentionally executes caller-provided test SQL. 
+#[allow(clippy::disallowed_methods)] pub async fn execute(mz_client: &Client, query: &str) -> Result { debug!("exec=> {}", query); Ok(mz_client.execute(query, &[]).await?)