From ad9555af7cde9c229b352ea11731db69578ebb13 Mon Sep 17 00:00:00 2001 From: Ophir LOJKINE Date: Sun, 31 May 2026 17:34:18 +0200 Subject: [PATCH 1/9] Switch sqlx dependencies to new driver crates --- Cargo.lock | 434 +++++++++++++++++++++++++++++++++-------------------- Cargo.toml | 13 +- 2 files changed, 275 insertions(+), 172 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 98292857..3bda916f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -51,7 +51,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "rand 0.10.1", - "sha1 0.11.0", + "sha1", "smallvec", "tokio", "tokio-util", @@ -275,19 +275,6 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" -[[package]] -name = "ahash" -version = "0.8.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" -dependencies = [ - "cfg-if", - "getrandom 0.3.4", - "once_cell", - "version_check", - "zerocopy", -] - [[package]] name = "aho-corasick" version = "1.1.4" @@ -1116,7 +1103,7 @@ dependencies = [ "bitflags 1.3.2", "core-foundation 0.9.4", "core-graphics-types", - "foreign-types", + "foreign-types 0.5.0", "libc", ] @@ -1499,27 +1486,6 @@ dependencies = [ "ctutils", ] -[[package]] -name = "dirs" -version = "6.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3e8aa94d75141228480295a7d0e7feb620b1a5ad9f12bc40be62411e38cce4e" -dependencies = [ - "dirs-sys", -] - -[[package]] -name = "dirs-sys" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e01a3366d27ee9890022452ee61b2b63a67e6f13f58900b651ff5665f0bb1fab" -dependencies = [ - "libc", - "option-ext", - "redox_users", - "windows-sys 0.61.2", -] - [[package]] name = "dispatch" version = "0.2.0" @@ -1622,6 +1588,9 @@ name = "either" version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "91622ff5e7162018101f2fea40d6ebf4a78bbe5a49736a2020649edf9693679e" +dependencies = [ + "serde", +] [[package]] name = "elliptic-curve" @@ -1680,6 +1649,16 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "etcetera" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de48cc4d1c1d97a20fd819def54b890cadde72ed3ad0c614822a0a433361be96" +dependencies = [ + "cfg-if", + "windows-sys 0.61.2", +] + [[package]] name = "event-listener" version = "5.4.1" @@ -1768,6 +1747,15 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared 0.1.1", +] + [[package]] name = "foreign-types" version = "0.5.0" @@ -1775,7 +1763,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965" dependencies = [ "foreign-types-macros", - "foreign-types-shared", + "foreign-types-shared 0.3.1", ] [[package]] @@ -1789,6 +1777,12 @@ dependencies = [ "syn", ] +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "foreign-types-shared" version = "0.3.1" @@ -1953,7 +1947,7 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "wasi 0.11.1+wasi-snapshot-preview1", + "wasi", "wasm-bindgen", ] @@ -2644,9 +2638,9 @@ dependencies = [ [[package]] name = "libsqlite3-sys" -version = "0.38.0" +version = "0.37.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a76001fb4daed01e5f2b518aac0b4dc592e7c734da63dbffcf0c64fa612a8d0c" +checksum = "b1f111c8c41e7c61a49cd34e44c7619462967221a6443b0ec299e0ac30cfb9b1" dependencies = [ "cc", "pkg-config", @@ -2778,7 +2772,7 @@ checksum = "02bd0af71c67b473010cbbc60715ee815645a4dc942899111f494b4b737d6fda" dependencies = [ "libc", "log", - "wasi 0.11.1+wasi-snapshot-preview1", + "wasi", "windows-sys 0.61.2", ] @@ -2788,6 +2782,23 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e94e1e6445d314f972ff7395df2de295fe51b71821694f0b0e1e79c4f12c8577" +[[package]] +name = "native-tls" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "465500e14ea162429d264d44189adc38b199b62b1c21eea9f69e4b73cb03bbf2" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "ndk" version = "0.9.0" @@ -3032,15 +3043,6 @@ dependencies = [ "objc2-foundation", ] -[[package]] -name = "objc2-core-foundation" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a180dd8642fa45cdb7dd721cd4c11b1cadd4929ce112ebd8b9f5803cc79d536" -dependencies = [ - "bitflags 2.12.0", -] - [[package]] name = "objc2-core-image" version = "0.2.2" @@ -3131,15 +3133,6 @@ dependencies = [ "objc2-foundation", ] -[[package]] -name = "objc2-system-configuration" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7216bd11cbda54ccabcab84d523dc93b858ec75ecfb3a7d89513fa22464da396" -dependencies = [ - "objc2-core-foundation", -] - [[package]] name = "objc2-ui-kit" version = "0.2.2" @@ -3260,12 +3253,49 @@ dependencies = [ "url", ] +[[package]] +name = "openssl" +version = "0.10.80" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a45fa2aa886c42762255da344f0a0d313e254066c46aad76f300c3d3da62d967" +dependencies = [ + "bitflags 2.12.0", + "cfg-if", + "foreign-types 0.3.2", + "libc", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "openssl-probe" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" +[[package]] +name = "openssl-sys" +version = "0.9.116" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f28a22dc7140cda5f096e5e7724a6962ca81a7f8bfd2979f9b18c11af56318c4" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "opentelemetry" version = "0.32.0" @@ -3341,12 +3371,6 @@ dependencies = [ "tokio-stream", ] -[[package]] -name = "option-ext" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" - [[package]] name = "orbclient" version = "0.3.55" @@ -3837,17 +3861,6 @@ dependencies = [ "bitflags 2.12.0", ] -[[package]] -name = "redox_users" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4e608c6638b9c18977b00b475ac1f28d14e84b27d8d42f70e0bf1e3dec127ac" -dependencies = [ - "getrandom 0.2.17", - "libredox", - "thiserror 2.0.18", -] - [[package]] name = "ref-cast" version = "1.0.25" @@ -4030,6 +4043,7 @@ dependencies = [ "aws-lc-rs", "log", "once_cell", + "ring", "rustls-pki-types", "rustls-webpki", "subtle", @@ -4337,17 +4351,6 @@ dependencies = [ "syn", ] -[[package]] -name = "sha1" -version = "0.10.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" -dependencies = [ - "cfg-if", - "cpufeatures 0.2.17", - "digest 0.10.7", -] - [[package]] name = "sha1" version = "0.11.0" @@ -4449,6 +4452,9 @@ name = "smallvec" version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" +dependencies = [ + "serde", +] [[package]] name = "smol_str" @@ -4530,7 +4536,6 @@ dependencies = [ "log", "markdown", "mime_guess", - "odbc-sys", "openidconnect", "opentelemetry", "opentelemetry-http", @@ -4547,7 +4552,9 @@ dependencies = [ "serde_json", "sha2 0.11.0", "sqlparser", - "sqlx-oldapi", + "sqlx", + "sqlx-odbc", + "sqlx-sqlserver", "tokio", "tokio-stream", "tokio-util", @@ -4580,107 +4587,220 @@ dependencies = [ ] [[package]] -name = "sqlx-core-oldapi" -version = "0.6.56" +name = "sqlx" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e33eb18d1e750df8aef99361ee02e170562dff119108680bc9035e796cd2a84" +checksum = "378620ccc25c62c89d8be1c819e76a88d59bdcc3304733330788948e619bfd71" +dependencies = [ + "sqlx-core", + "sqlx-macros", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", +] + +[[package]] +name = "sqlx-core" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05b44e85bf579a8eeb4ceaa77a3a523baf2bf0e9bac7e40f405d537b5d2d5ccb" dependencies = [ - "ahash", - "atoi", "base64 0.22.1", "bigdecimal", - "bitflags 2.12.0", - "byteorder", "bytes", + "cfg-if", "chrono", "crc", "crossbeam-queue", - "dirs", - "dotenvy", "either", - "encoding_rs", "event-listener", - "flume", - "futures-channel", "futures-core", - "futures-executor", "futures-intrusive", + "futures-io", "futures-util", + "hashbrown 0.16.1", "hashlink", - "hex", - "hkdf 0.13.0", - "hmac 0.13.0", "indexmap 2.14.0", - "itoa", - "libc", - "libsqlite3-sys", "log", - "md-5", "memchr", - "num-bigint", - "odbc-api", - "once_cell", "percent-encoding", - "rand 0.10.1", - "regex", - "rsa", "rustls", "serde", "serde_json", - "sha1 0.10.6", - "sha1 0.11.0", - "sha2 0.11.0", + "sha2 0.10.9", "smallvec", - "sqlx-rt-oldapi", - "stringprep", "thiserror 2.0.18", + "tokio", "tokio-stream", - "tokio-util", + "tracing", "url", "uuid", "webpki-roots 1.0.7", - "whoami", ] [[package]] -name = "sqlx-macros-oldapi" -version = "0.6.56" +name = "sqlx-macros" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd2b84f2bc39a5705ef27ec785a11c934a41bbd4a24941e257927cddc26b60bf" +dependencies = [ + "proc-macro2", + "quote", + "sqlx-core", + "sqlx-macros-core", + "syn", +] + +[[package]] +name = "sqlx-macros-core" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adbf4ebc08c19991fa51993f471e572930c4dec146d3dc915a8e54db91d624c6" +checksum = "fb8d96de5fdc85a5c4ec813432b523ec637e80ba98f046555f75f7908ddac7c3" dependencies = [ + "cfg-if", "dotenvy", "either", "heck", - "once_cell", + "hex", "proc-macro2", "quote", + "serde", "serde_json", - "sha2 0.11.0", - "sqlx-core-oldapi", - "sqlx-rt-oldapi", + "sha2 0.10.9", + "sqlx-core", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", "syn", + "tokio", "url", ] [[package]] -name = "sqlx-oldapi" -version = "0.6.56" +name = "sqlx-mysql" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90b8020fe17c5f2c245bfa2505d7ef59c5604839527c740266ad2214acebea27" +dependencies = [ + "bigdecimal", + "bitflags 2.12.0", + "byteorder", + "bytes", + "chrono", + "crc", + "digest 0.11.3", + "dotenvy", + "either", + "futures-core", + "futures-util", + "generic-array", + "log", + "percent-encoding", + "serde", + "sha1", + "sha2 0.11.0", + "sqlx-core", + "thiserror 2.0.18", + "tracing", + "uuid", +] + +[[package]] +name = "sqlx-odbc" +version = "0.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c9ba4d352504ee1a0a76eb052879d68ba63536c738da5062026ac0d3dc724c7" +checksum = "cd899bb426ff319b58192029e73f28d207263bb0be20fdf769411e49a558306e" dependencies = [ - "sqlx-core-oldapi", - "sqlx-macros-oldapi", + "futures-core", + "futures-util", + "log", + "odbc-api", + "sqlx-core", + "thiserror 2.0.18", + "url", ] [[package]] -name = "sqlx-rt-oldapi" -version = "0.6.56" +name = "sqlx-postgres" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b8d629fed8792460ff39bb58cb154bfe181893ab9a51d0a3634950b35672a57" +checksum = "87a2bdd6e83f6b3ea525ca9fee568030508b58355a43d0b2c1674d5f79dcd65e" dependencies = [ - "once_cell", + "atoi", + "base64 0.22.1", + "bigdecimal", + "bitflags 2.12.0", + "byteorder", + "chrono", + "crc", + "dotenvy", + "etcetera", + "futures-channel", + "futures-core", + "futures-util", + "hex", + "hkdf 0.13.0", + "hmac 0.13.0", + "itoa", + "log", + "md-5", + "memchr", + "num-bigint", + "rand 0.10.1", + "serde", + "serde_json", + "sha2 0.11.0", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror 2.0.18", + "tracing", + "uuid", + "whoami", +] + +[[package]] +name = "sqlx-sqlite" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488e99c397a62007e4229aec669a179816339afc6d2620ca6fa420dbee2e982c" +dependencies = [ + "atoi", + "chrono", + "flume", + "form_urlencoded", + "futures-channel", + "futures-core", + "futures-executor", + "futures-intrusive", + "futures-util", + "libsqlite3-sys", + "log", + "percent-encoding", + "serde", + "sqlx-core", + "thiserror 2.0.18", + "tracing", + "url", + "uuid", +] + +[[package]] +name = "sqlx-sqlserver" +version = "0.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "205de428fb54061bc058154abbff60e42876335e6d727e8aef77a70028e6386b" +dependencies = [ + "futures-core", + "futures-util", + "log", + "native-tls", + "percent-encoding", + "sqlx-core", + "thiserror 2.0.18", "tokio", - "tokio-rustls", + "tokio-native-tls", + "url", ] [[package]] @@ -4889,6 +5009,16 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.26.4" @@ -4918,7 +5048,6 @@ checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" dependencies = [ "bytes", "futures-core", - "futures-io", "futures-sink", "pin-project-lite", "tokio", @@ -5252,15 +5381,6 @@ version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" -[[package]] -name = "wasi" -version = "0.14.7+wasi-0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "883478de20367e224c0090af9cf5f9fa85bed63a95c1abf3afc5c083ebc06e8c" -dependencies = [ - "wasip2", -] - [[package]] name = "wasip2" version = "1.0.3+wasi-0.2.9" @@ -5279,15 +5399,6 @@ dependencies = [ "wit-bindgen 0.51.0", ] -[[package]] -name = "wasite" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66fe902b4a6b8028a753d5424909b764ccf79b7a209eac9bf97e59cda9f71a42" -dependencies = [ - "wasi 0.14.7+wasi-0.2.4", -] - [[package]] name = "wasm-bindgen" version = "0.2.122" @@ -5420,13 +5531,6 @@ name = "whoami" version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "998767ef88740d1f5b0682a9c53c24431453923962269c2db68ee43788c5a40d" -dependencies = [ - "libc", - "libredox", - "objc2-system-configuration", - "wasite", - "web-sys", -] [[package]] name = "widestring" diff --git a/Cargo.toml b/Cargo.toml index 67740f70..182f078c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,20 +18,20 @@ panic = "abort" codegen-units = 2 [dependencies] -sqlx = { package = "sqlx-oldapi", version = "0.6.56", default-features = false, features = [ - "any", - "runtime-tokio-rustls", +sqlx = { version = "0.9.0", default-features = false, features = [ + "runtime-tokio", + "tls-rustls", "migrate", "sqlite", "postgres", "mysql", - "mssql", - "odbc", "chrono", "bigdecimal", "json", "uuid", ] } +sqlx-sqlserver = { version = "0.0.2", features = ["migrate"] } +sqlx-odbc = { version = "0.0.1", features = ["runtime-tokio"] } chrono = "0.4.23" actix-web = { version = "4", features = ["rustls-0_23", "cookies"] } percent-encoding = "2.2.0" @@ -77,7 +77,6 @@ clap = { version = "4.5.17", features = ["derive"] } tokio-util = "0.7.12" openidconnect = { version = "4.0.0", default-features = false, features = ["accept-rfc3339-timestamps"] } encoding_rs = "0.8.35" -odbc-sys = { version = "0", optional = true } regex = "1" # OpenTelemetry / tracing @@ -95,7 +94,7 @@ opentelemetry-semantic-conventions = { version = "0.32", features = ["semconv_ex [features] default = [] -odbc-static = ["odbc-sys", "odbc-sys/vendored-unix-odbc"] +odbc-static = ["sqlx-odbc/vendored-unix-odbc"] lambda-web = ["dep:lambda-web", "odbc-static"] [dev-dependencies] From 225fa127ded1a724f318a6be7888b1c1cc0a4444 Mon Sep 17 00:00:00 2001 From: Ophir LOJKINE Date: Sun, 31 May 2026 17:50:47 +0200 Subject: [PATCH 2/9] Replace sqlx Any with native backend dispatch --- Cargo.lock | 1 + Cargo.toml | 2 + src/filesystem.rs | 309 +++++-- src/telemetry_metrics.rs | 5 +- src/webserver/database/connect.rs | 416 ++++++---- src/webserver/database/csv_import.rs | 97 ++- src/webserver/database/error_highlighting.rs | 39 +- src/webserver/database/execute_queries.rs | 410 +++++++--- src/webserver/database/migrations.rs | 30 +- src/webserver/database/mod.rs | 127 ++- src/webserver/database/sql.rs | 15 +- .../database/sql/parameter_extraction.rs | 23 +- src/webserver/database/sql_to_json.rs | 753 +++--------------- .../function_definition_macro.rs | 2 +- 14 files changed, 1083 insertions(+), 1146 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3bda916f..72aca6f7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4563,6 +4563,7 @@ dependencies = [ "tracing-log", "tracing-opentelemetry", "tracing-subscriber", + "url", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 182f078c..01246e5e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ sqlx = { version = "0.9.0", default-features = false, features = [ "tls-rustls", "migrate", "sqlite", + "sqlite-load-extension", "postgres", "mysql", "chrono", @@ -35,6 +36,7 @@ sqlx-odbc = { version = "0.0.1", features = ["runtime-tokio"] } chrono = "0.4.23" actix-web = { version = "4", features = ["rustls-0_23", "cookies"] } percent-encoding = "2.2.0" +url = "2" handlebars = "6.2.0" log = "0.4.17" mime_guess = "2.0.4" diff --git a/src/filesystem.rs b/src/filesystem.rs index fd33b480..e16b4b63 100644 --- a/src/filesystem.rs +++ b/src/filesystem.rs @@ -1,12 +1,9 @@ use crate::webserver::ErrorWithStatus; -use crate::webserver::database::SupportedDatabase; +use crate::webserver::database::{DatabasePool, SupportedDatabase}; use crate::webserver::{Database, StatusCodeResultExt, make_placeholder}; use crate::{AppState, TEMPLATES_DIR}; use anyhow::Context; use chrono::{DateTime, Utc}; -use sqlx::any::{AnyStatement, AnyTypeInfo}; -use sqlx::postgres::types::PgTimeTz; -use sqlx::{Executor, Postgres, Statement, Type}; use std::fmt::Write; use std::io::ErrorKind; use std::path::{Component, Path, PathBuf}; @@ -241,9 +238,9 @@ async fn file_modified_since_local(path: &Path, since: DateTime) -> tokio:: } pub struct DbFsQueries { - was_modified: AnyStatement<'static>, - read_file: AnyStatement<'static>, - exists: AnyStatement<'static>, + was_modified: String, + read_file: String, + exists: String, } impl DbFsQueries { @@ -276,45 +273,37 @@ impl DbFsQueries { } async fn check_table_available(db: &Database) -> anyhow::Result<()> { - db.connection - .execute("SELECT 1 FROM sqlpage_files WHERE 1 = 0") + execute_pool(&db.connection, "SELECT 1 FROM sqlpage_files WHERE 1 = 0") .await - .map(|_| ()) .context("Unable to access sqlpage_files")?; Ok(()) } - async fn make_was_modified_query(db: &Database) -> anyhow::Result> { + async fn make_was_modified_query(db: &Database) -> anyhow::Result { let was_modified_query = format!( "SELECT 1 from sqlpage_files WHERE last_modified >= {} AND path = {}", make_placeholder(db.info.kind, 1), make_placeholder(db.info.kind, 2) ); - let param_types: &[AnyTypeInfo; 2] = &[ - PgTimeTz::type_info().into(), - >::type_info().into(), - ]; log::debug!("Preparing the database filesystem was_modified_query: {was_modified_query}"); - db.prepare_with(&was_modified_query, param_types).await + Ok(was_modified_query) } - async fn make_read_file_query(db: &Database) -> anyhow::Result> { + async fn make_read_file_query(db: &Database) -> anyhow::Result { let read_file_query = format!( "SELECT contents from sqlpage_files WHERE path = {}", make_placeholder(db.info.kind, 1), ); - let param_types: &[AnyTypeInfo; 1] = &[>::type_info().into()]; log::debug!("Preparing the database filesystem read_file_query: {read_file_query}"); - db.prepare_with(&read_file_query, param_types).await + Ok(read_file_query) } - async fn make_exists_query(db: &Database) -> anyhow::Result> { + async fn make_exists_query(db: &Database) -> anyhow::Result { let exists_query = format!( "SELECT 1 from sqlpage_files WHERE path = {}", make_placeholder(db.info.kind, 1), ); - let param_types: &[AnyTypeInfo; 1] = &[>::type_info().into()]; - db.prepare_with(&exists_query, param_types).await + Ok(exists_query) } async fn file_modified_since_in_db( @@ -323,47 +312,40 @@ impl DbFsQueries { path: &Path, since: DateTime, ) -> anyhow::Result { - let query = self - .was_modified - .query_as::<(i32,)>() - .bind(since) - .bind(path.display().to_string()); + let path = path.display().to_string(); log::trace!( "Checking if file {} was modified since {} by executing query: \n\ {}\n\ with parameters: {:?}", - path.display(), + path, since, - self.was_modified.sql(), - (since, path) + self.was_modified, + (since, &path) ); - let was_modified_i32 = query - .fetch_optional(&app_state.db.connection) - .await - .with_context(|| { - format!( - "Unable to check when {} was last modified in the database", - path.display() - ) - })?; + let was_modified_i32 = fetch_optional_i32_since_path( + &app_state.db.connection, + &self.was_modified, + since, + &path, + ) + .await + .with_context(|| format!("Unable to check when {path} was last modified in the database"))?; log::trace!( "DB File {} was modified result: {was_modified_i32:?}", - path.display() + path ); - Ok(was_modified_i32 == Some((1,))) + Ok(was_modified_i32 == Some(1)) } async fn read_file(&self, app_state: &AppState, path: &Path) -> anyhow::Result> { log::debug!("Reading file {} from the database", path.display()); - self.read_file - .query_as::<(Vec,)>() - .bind(path.display().to_string()) - .fetch_optional(&app_state.db.connection) + let path = path.display().to_string(); + fetch_optional_bytes_path(&app_state.db.connection, &self.read_file, &path) .await .map_err(anyhow::Error::from) - .and_then(|modified| { - if let Some((modified,)) = modified { - Ok(modified) + .and_then(|contents| { + if let Some(contents) = contents { + Ok(contents) } else { Err(ErrorWithStatus { status: actix_web::http::StatusCode::NOT_FOUND, @@ -371,33 +353,194 @@ impl DbFsQueries { .into()) } }) - .with_context(|| format!("Unable to read {} from the database", path.display())) + .with_context(|| format!("Unable to read {path} from the database")) } async fn file_exists(&self, app_state: &AppState, path: &Path) -> anyhow::Result { - let query = self - .exists - .query_as::<(i32,)>() - .bind(path.display().to_string()); + let path = path.display().to_string(); log::trace!( "Checking if file {} exists by executing query: \n\ {}\n\ with parameters: {:?}", - path.display(), - self.exists.sql(), - (path,) + path, + self.exists, + (&path,) ); - let result = query.fetch_optional(&app_state.db.connection).await; + let result = fetch_optional_i32_path(&app_state.db.connection, &self.exists, &path).await; log::debug!("DB File exists result: {result:?}"); result.map(|result| result.is_some()).with_context(|| { format!( "Unable to check if {} exists in the database", - path.display() + path ) }) } } +async fn execute_pool(pool: &DatabasePool, sql: &str) -> sqlx::Result<()> { + match pool { + DatabasePool::Sqlite(pool) => sqlx::query::(sqlx::AssertSqlSafe(sql)) + .execute(pool) + .await + .map(|_| ()), + DatabasePool::Postgres(pool) => sqlx::query::(sqlx::AssertSqlSafe(sql)) + .execute(pool) + .await + .map(|_| ()), + DatabasePool::MySql(pool) => sqlx::query::(sqlx::AssertSqlSafe(sql)) + .execute(pool) + .await + .map(|_| ()), + DatabasePool::Mssql(pool) => sqlx::query::(sqlx::AssertSqlSafe(sql)) + .execute(pool) + .await + .map(|_| ()), + DatabasePool::Odbc(pool) => sqlx::query::(sqlx::AssertSqlSafe(sql)) + .execute(pool) + .await + .map(|_| ()), + } +} + +async fn fetch_optional_i32_since_path( + pool: &DatabasePool, + sql: &str, + since: DateTime, + path: &str, +) -> sqlx::Result> { + let since = since.to_rfc3339(); + match pool { + DatabasePool::Sqlite(pool) => { + sqlx::query_as::(sqlx::AssertSqlSafe(sql)) + } + .bind(&since) + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)), + DatabasePool::Postgres(pool) => { + sqlx::query_as::(sqlx::AssertSqlSafe(sql)) + } + .bind(&since) + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)), + DatabasePool::MySql(pool) => { + sqlx::query_as::(sqlx::AssertSqlSafe(sql)) + } + .bind(&since) + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)), + DatabasePool::Mssql(pool) => { + sqlx::query_as::(sqlx::AssertSqlSafe(sql)) + } + .bind(&since) + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)), + DatabasePool::Odbc(pool) => { + sqlx::query_as::(sqlx::AssertSqlSafe(sql)) + } + .bind(&since) + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)), + } +} + +async fn fetch_optional_i32_path( + pool: &DatabasePool, + sql: &str, + path: &str, +) -> sqlx::Result> { + match pool { + DatabasePool::Sqlite(pool) => { + sqlx::query_as::(sqlx::AssertSqlSafe(sql)) + } + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)), + DatabasePool::Postgres(pool) => { + sqlx::query_as::(sqlx::AssertSqlSafe(sql)) + } + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)), + DatabasePool::MySql(pool) => { + sqlx::query_as::(sqlx::AssertSqlSafe(sql)) + } + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)), + DatabasePool::Mssql(pool) => { + sqlx::query_as::(sqlx::AssertSqlSafe(sql)) + } + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)), + DatabasePool::Odbc(pool) => { + sqlx::query_as::(sqlx::AssertSqlSafe(sql)) + } + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)), + } +} + +async fn fetch_optional_bytes_path( + pool: &DatabasePool, + sql: &str, + path: &str, +) -> sqlx::Result>> { + match pool { + DatabasePool::Sqlite(pool) => { + sqlx::query_as::,)>(sqlx::AssertSqlSafe(sql)) + } + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)), + DatabasePool::Postgres(pool) => { + sqlx::query_as::,)>(sqlx::AssertSqlSafe(sql)) + } + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)), + DatabasePool::MySql(pool) => { + sqlx::query_as::,)>(sqlx::AssertSqlSafe(sql)) + } + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)), + DatabasePool::Mssql(pool) => { + sqlx::query_as::,)>(sqlx::AssertSqlSafe(sql)) + } + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)), + DatabasePool::Odbc(pool) => { + sqlx::query_as::,)>(sqlx::AssertSqlSafe(sql)) + } + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)), + } +} + #[actix_web::test] async fn test_sql_file_read_utf8() -> anyhow::Result<()> { use crate::app_config; @@ -417,9 +560,9 @@ async fn test_sql_file_read_utf8() -> anyhow::Result<()> { let create_table_sql = DbFsQueries::get_create_table_sql(state.db.info.database_type); let db = &state.db; let conn = &db.connection; - conn.execute("DROP TABLE IF EXISTS sqlpage_files").await?; + execute_pool(conn, "DROP TABLE IF EXISTS sqlpage_files").await?; log::debug!("Creating table sqlpage_files: {create_table_sql}"); - conn.execute(create_table_sql).await?; + execute_pool(conn, create_table_sql).await?; let dbms = db.info.kind; let insert_sql = format!( @@ -427,10 +570,7 @@ async fn test_sql_file_read_utf8() -> anyhow::Result<()> { make_placeholder(dbms, 1), make_placeholder(dbms, 2) ); - sqlx::query(&insert_sql) - .bind("unit test file.txt") - .bind("Héllö world! 😀".as_bytes()) - .execute(conn) + insert_test_file(conn, &insert_sql, "unit test file.txt", "Héllö world! 😀".as_bytes()) .await?; let fs = FileSystem::init("/", db).await; @@ -463,3 +603,44 @@ async fn test_sql_file_read_utf8() -> anyhow::Result<()> { Ok(()) } + +#[cfg(test)] +async fn insert_test_file( + pool: &DatabasePool, + sql: &str, + path: &str, + contents: &[u8], +) -> sqlx::Result<()> { + match pool { + DatabasePool::Sqlite(pool) => sqlx::query::(sql) + .bind(path) + .bind(contents) + .execute(pool) + .await + .map(|_| ()), + DatabasePool::Postgres(pool) => sqlx::query::(sql) + .bind(path) + .bind(contents) + .execute(pool) + .await + .map(|_| ()), + DatabasePool::MySql(pool) => sqlx::query::(sql) + .bind(path) + .bind(contents) + .execute(pool) + .await + .map(|_| ()), + DatabasePool::Mssql(pool) => sqlx::query::(sql) + .bind(path) + .bind(contents) + .execute(pool) + .await + .map(|_| ()), + DatabasePool::Odbc(pool) => sqlx::query::(sql) + .bind(path) + .bind(contents) + .execute(pool) + .await + .map(|_| ()), + } +} diff --git a/src/telemetry_metrics.rs b/src/telemetry_metrics.rs index 89974438..b693bec5 100644 --- a/src/telemetry_metrics.rs +++ b/src/telemetry_metrics.rs @@ -2,7 +2,8 @@ use opentelemetry::global; use opentelemetry::metrics::{Histogram, ObservableGauge}; use opentelemetry_semantic_conventions::attribute as otel; use opentelemetry_semantic_conventions::metric as otel_metric; -use sqlx::AnyPool; + +use crate::webserver::database::DatabasePool; pub struct TelemetryMetrics { pub http_request_duration: Histogram, @@ -41,7 +42,7 @@ impl Default for TelemetryMetrics { impl TelemetryMetrics { #[must_use] - pub fn new(pool: &AnyPool, db_system_name: &'static str) -> Self { + pub fn new(pool: &DatabasePool, db_system_name: &'static str) -> Self { let meter = global::meter("sqlpage"); let http_request_duration = meter .f64_histogram(otel_metric::HTTP_SERVER_REQUEST_DURATION) diff --git a/src/webserver/database/connect.rs b/src/webserver/database/connect.rs index 3433e81e..801d63c2 100644 --- a/src/webserver/database/connect.rs +++ b/src/webserver/database/connect.rs @@ -1,71 +1,118 @@ -use std::{mem::take, time::Duration}; +use std::time::Duration; -use super::Database; +use super::{Database, DatabasePool, DbInfo, DbKind, SupportedDatabase}; use crate::{ ON_CONNECT_FILE, ON_RESET_FILE, app_config::AppConfig, - webserver::database::{DbInfo, SupportedDatabase}, }; use anyhow::Context; use futures_util::future::BoxFuture; -use sqlx::odbc::OdbcConnectOptions; use sqlx::{ - ConnectOptions, Connection, Executor, - any::{Any, AnyConnectOptions, AnyConnection, AnyKind}, + ColumnIndex, ConnectOptions, Connection, Database as SqlxDatabase, Decode, Executor, Row, Type, + mysql::MySqlConnectOptions, pool::PoolOptions, - sqlite::{Function, SqliteConnectOptions, SqliteFunctionCtx}, + postgres::PgConnectOptions, + sqlite::SqliteConnectOptions, }; +use sqlx_odbc::{OdbcConnectOptions, OdbcConnection}; +use sqlx_sqlserver::MssqlConnectOptions; +use url::Url; impl Database { pub async fn init(config: &AppConfig) -> anyhow::Result { - let database_url = &config.database_url; - let mut connect_options: AnyConnectOptions = database_url - .parse() - .with_context(|| format!("\"{database_url}\" is not a valid database URL. Please change the \"database_url\" option in the configuration file."))?; - if let Some(password) = &config.database_password { - set_database_password(&mut connect_options, password); - } - connect_options.log_statements(log::LevelFilter::Trace); - connect_options.log_slow_statements( - log::LevelFilter::Warn, - std::time::Duration::from_millis(250), - ); - log::debug!( - "Connecting to a {:?} database on {}", - connect_options.kind(), - database_url - ); - set_custom_connect_options(&mut connect_options, config); - log::debug!("Connecting to database: {database_url}"); - let mut retries = config.database_connection_retries; + let database_url = database_url_with_password(config)?; + let db_kind = DbKind::from_database_url(&database_url); + log::debug!("Connecting to a {db_kind:?} database on {}", config.database_url); - let mut conn: AnyConnection = loop { - match AnyConnection::connect_with(&connect_options).await { - Ok(c) => break c, - Err(e) => { - if retries == 0 { - return Err(anyhow::Error::new(e) - .context(format!("Unable to open connection to {database_url}"))); - } - log::warn!("Failed to connect to the database: {e:#}. Retrying in 5 seconds."); - retries -= 1; - tokio::time::sleep(Duration::from_secs(5)).await; + let connection = match db_kind { + DbKind::Sqlite => { + let mut options = database_url.parse::()?; + options = set_common_connect_options(options); + options = set_custom_connect_options_sqlite(options, config); + let pool = Self::create_pool_options::(config, db_kind) + .connect_with(options) + .await + .with_context(|| { + format!("Unable to open connection pool to {}", config.database_url) + })?; + DatabasePool::Sqlite(pool) + } + DbKind::Postgres => { + let mut options = database_url.parse::()?; + if let Some(password) = &config.database_password { + options = options.password(password); + } + options = set_common_connect_options(options); + let pool = Self::create_pool_options::(config, db_kind) + .connect_with(options) + .await + .with_context(|| { + format!("Unable to open connection pool to {}", config.database_url) + })?; + DatabasePool::Postgres(pool) + } + DbKind::MySql => { + let mut options = database_url.parse::()?; + if let Some(password) = &config.database_password { + options = options.password(password); + } + options = set_common_connect_options(options); + let pool = Self::create_pool_options::(config, db_kind) + .connect_with(options) + .await + .with_context(|| { + format!("Unable to open connection pool to {}", config.database_url) + })?; + DatabasePool::MySql(pool) + } + DbKind::Mssql => { + let options = set_common_connect_options( + database_url + .parse::() + .with_context(|| format!("Unable to parse {}", config.database_url))?, + ); + let pool = Self::create_pool_options::(config, db_kind) + .connect_with(options) + .await + .with_context(|| { + format!("Unable to open connection pool to {}", config.database_url) + })?; + DatabasePool::Mssql(pool) + } + DbKind::Odbc => { + if config.database_password.is_some() { + log::warn!( + "Setting a password for an ODBC connection is not supported via separate config; include credentials in the DSN or connection string" + ); } + let mut options = database_url.parse::()?; + set_custom_connect_options_odbc(&mut options, config); + let dbms_name = detect_odbc_dbms_name(&options, config).await?; + let database_type = SupportedDatabase::from_dbms_name(&dbms_name); + let options = set_common_connect_options(options); + let pool = Self::create_pool_options::(config, db_kind) + .connect_with(options) + .await + .with_context(|| { + format!("Unable to open connection pool to {}", config.database_url) + })?; + log::debug!("Initialized {dbms_name:?} database pool: {pool:#?}"); + return Ok(Database { + connection: DatabasePool::Odbc(pool), + info: DbInfo { + dbms_name, + database_type, + kind: db_kind, + }, + }); } }; - let dbms_name: String = conn.dbms_name().await?; - let database_type = SupportedDatabase::from_dbms_name(&dbms_name); - drop(conn); - - let db_kind = connect_options.kind(); - let pool = Self::create_pool_options(config, db_kind) - .connect_with(connect_options) - .await - .with_context(|| format!("Unable to open connection pool to {database_url}"))?; - log::debug!("Initialized {dbms_name:?} database pool: {pool:#?}"); + let dbms_name = db_kind.display_name().to_owned(); + let database_type = SupportedDatabase::from(db_kind); + log::debug!("Initialized {dbms_name:?} database pool: {connection:#?}"); Ok(Database { - connection: pool, + connection, info: DbInfo { dbms_name, database_type, @@ -74,64 +121,63 @@ impl Database { }) } - fn create_pool_options(config: &AppConfig, kind: AnyKind) -> PoolOptions { - let mut pool_options = PoolOptions::new() - .max_connections(if let Some(max) = config.max_database_pool_connections { - max - } else { - // Different databases have a different number of max concurrent connections allowed by default - match kind { - AnyKind::Postgres | AnyKind::Odbc => 50, // Default to PostgreSQL-like limits for Generic - AnyKind::MySql => 75, - AnyKind::Sqlite => { - if config.database_url.contains(":memory:") { - 128 - } else { - 16 - } - } - AnyKind::Mssql => 100, - } - }) + fn create_pool_options(config: &AppConfig, kind: DbKind) -> PoolOptions + where + DB: SqlxDatabase, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + for<'r> bool: Decode<'r, DB> + Type, + usize: ColumnIndex, + { + let max_connections = config + .max_database_pool_connections + .unwrap_or_else(|| default_max_connections(config, kind)); + let pool_options = PoolOptions::new() + .max_connections(max_connections) .idle_timeout(config.database_connection_idle_timeout) .max_lifetime(config.database_connection_max_lifetime) .acquire_timeout(Duration::from_secs_f64( config.database_connection_acquire_timeout_seconds, )); - pool_options = add_on_return_to_pool(config, pool_options); - pool_options = add_on_connection_handler(config, pool_options); - pool_options + let pool_options = add_on_return_to_pool(config, pool_options); + add_on_connection_handler(config, pool_options) } } -fn add_on_return_to_pool(config: &AppConfig, pool_options: PoolOptions) -> PoolOptions { - let on_disconnect_file = config.configuration_directory.join(ON_RESET_FILE); - let sql = if on_disconnect_file.exists() { - log::info!( - "Creating a custom SQL connection cleanup handler from {}", - on_disconnect_file.display() - ); - match std::fs::read_to_string(&on_disconnect_file) { - Ok(sql) => { - log::trace!("The custom SQL connection cleanup handler is:\n{sql}"); - Some(std::sync::Arc::new(sql)) - } - Err(e) => { - log::error!( - "Unable to read the file {}: {}", - on_disconnect_file.display(), - e - ); - None +fn default_max_connections(config: &AppConfig, kind: DbKind) -> u32 { + match kind { + DbKind::Postgres | DbKind::Odbc => 50, + DbKind::MySql => 75, + DbKind::Sqlite => { + if config.database_url.contains(":memory:") { + 128 + } else { + 16 } } - } else { - log::debug!( - "Not creating a custom SQL connection cleanup handler because {} does not exist", - on_disconnect_file.display() - ); - None - }; + DbKind::Mssql => 100, + } +} + +fn set_common_connect_options(options: T) -> T +where + T: ConnectOptions, +{ + options + .log_statements(log::LevelFilter::Trace) + .log_slow_statements(log::LevelFilter::Warn, Duration::from_millis(250)) +} + +fn add_on_return_to_pool( + config: &AppConfig, + pool_options: PoolOptions, +) -> PoolOptions +where + DB: SqlxDatabase, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + for<'r> bool: Decode<'r, DB> + Type, + usize: ColumnIndex, +{ + let sql = read_optional_handler_sql(config, ON_RESET_FILE, "connection cleanup"); pool_options.after_release(move |conn, meta| { let sql = sql.clone(); @@ -145,15 +191,20 @@ fn add_on_return_to_pool(config: &AppConfig, pool_options: PoolOptions) -> }) } -fn on_return_to_pool( - conn: &mut sqlx::AnyConnection, +fn on_return_to_pool( + conn: &mut DB::Connection, meta: sqlx::pool::PoolConnectionMetadata, sql: std::sync::Arc, -) -> BoxFuture<'_, Result> { - use sqlx::Row; +) -> BoxFuture<'_, Result> +where + DB: SqlxDatabase, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + for<'r> bool: Decode<'r, DB> + Type, + usize: ColumnIndex, +{ Box::pin(async move { log::trace!("Running the custom SQL connection cleanup handler. {meta:?}"); - let query_result = conn.fetch_optional(sql.as_str()).await?; + let query_result = conn.fetch_optional(sqlx::AssertSqlSafe(sql.as_str())).await?; if let Some(query_result) = query_result { let is_healthy = query_result.try_get::(0); log::debug!("Is the connection healthy? {is_healthy:?}"); @@ -165,38 +216,20 @@ fn on_return_to_pool( }) } -fn add_on_connection_handler( +fn add_on_connection_handler( config: &AppConfig, - pool_options: PoolOptions, -) -> PoolOptions { - let on_connect_file = config.configuration_directory.join(ON_CONNECT_FILE); - let on_connect_file_display = on_connect_file.display().to_string(); - let sql = if on_connect_file.exists() { - log::info!( - "Creating a custom SQL database connection handler from {}", - on_connect_file.display() - ); - match std::fs::read_to_string(&on_connect_file) { - Ok(sql) => { - log::trace!("The custom SQL database connection handler is:\n{sql}"); - Some(std::sync::Arc::new(sql)) - } - Err(e) => { - log::error!( - "Unable to read the file {}: {}", - on_connect_file.display(), - e - ); - None - } - } - } else { - log::debug!( - "Not creating a custom SQL database connection handler because {} does not exist", - on_connect_file.display() - ); - None - }; + pool_options: PoolOptions, +) -> PoolOptions +where + DB: SqlxDatabase, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, +{ + let sql = read_optional_handler_sql(config, ON_CONNECT_FILE, "database connection"); + let on_connect_file_display = config + .configuration_directory + .join(ON_CONNECT_FILE) + .display() + .to_string(); pool_options.after_connect(move |conn, _| { let sql = sql.clone(); @@ -204,72 +237,101 @@ fn add_on_connection_handler( Box::pin(async move { if let Some(sql) = sql { log::debug!("Running {on_connect_file_display} on new connection"); - let r = conn.execute(sql.as_str()).await?; - log::debug!("Finished running connection handler on new connection: {r:?}"); + conn.execute(sqlx::AssertSqlSafe(sql.as_str())).await?; + log::debug!("Finished running connection handler on new connection"); } Ok(()) }) }) } -fn set_custom_connect_options(options: &mut AnyConnectOptions, config: &AppConfig) { - if let Some(sqlite_options) = options.as_sqlite_mut() { - set_custom_connect_options_sqlite(sqlite_options, config); - } - - if let Some(odbc_options) = options.as_odbc_mut() { - set_custom_connect_options_odbc(odbc_options, config); +fn read_optional_handler_sql( + config: &AppConfig, + file_name: &str, + handler_name: &str, +) -> Option> { + let handler_file = config.configuration_directory.join(file_name); + if handler_file.exists() { + log::info!( + "Creating a custom SQL {handler_name} handler from {}", + handler_file.display() + ); + match std::fs::read_to_string(&handler_file) { + Ok(sql) => { + log::trace!("The custom SQL {handler_name} handler is:\n{sql}"); + Some(std::sync::Arc::new(sql)) + } + Err(e) => { + log::error!("Unable to read the file {}: {}", handler_file.display(), e); + None + } + } + } else { + log::debug!( + "Not creating a custom SQL {handler_name} handler because {} does not exist", + handler_file.display() + ); + None } } fn set_custom_connect_options_sqlite( - sqlite_options: &mut SqliteConnectOptions, + mut sqlite_options: SqliteConnectOptions, config: &AppConfig, -) { +) -> SqliteConnectOptions { for extension_name in &config.sqlite_extensions { log::info!("Loading SQLite extension: {extension_name}"); - *sqlite_options = std::mem::take(sqlite_options).extension(extension_name.clone()); + // SAFETY: SQLPage has always treated `sqlite_extensions` as an explicit administrator + // opt-in to load trusted native extensions from the filesystem. + sqlite_options = unsafe { sqlite_options.extension(extension_name.clone()) }; } - *sqlite_options = std::mem::take(sqlite_options) - .collation("NOCASE", |a, b| a.to_lowercase().cmp(&b.to_lowercase())) - .function(make_sqlite_fun("upper", str::to_uppercase)) - .function(make_sqlite_fun("lower", str::to_lowercase)); -} - -fn make_sqlite_fun(name: &str, f: fn(&str) -> String) -> Function { - Function::new(name, move |ctx: &SqliteFunctionCtx| { - let arg = ctx.try_get_arg::>(0); - match arg { - Ok(Some(s)) => ctx.set_result(f(s)), - Ok(None) => ctx.set_result(None::), - Err(e) => ctx.set_error(&e.to_string()), - } - }) + sqlite_options.collation("NOCASE", |a, b| a.to_lowercase().cmp(&b.to_lowercase())) } fn set_custom_connect_options_odbc(odbc_options: &mut OdbcConnectOptions, config: &AppConfig) { - // Allow fetching very large text fields when using ODBC by removing the max column size limit let batch_size = config.max_pending_rows.clamp(1, 1024); odbc_options.batch_size(batch_size); log::trace!("ODBC batch size set to {batch_size}"); - // Disables ODBC batching, but avoids truncation of large text fields odbc_options.max_column_size(None); } -fn set_database_password(options: &mut AnyConnectOptions, password: &str) { - if let Some(opts) = options.as_postgres_mut() { - *opts = take(opts).password(password); - } else if let Some(opts) = options.as_mysql_mut() { - *opts = take(opts).password(password); - } else if let Some(opts) = options.as_mssql_mut() { - *opts = take(opts).password(password); - } else if let Some(_opts) = options.as_odbc_mut() { - log::warn!( - "Setting a password for an ODBC connection is not supported via separate config; include credentials in the DSN or connection string" - ); - } else if let Some(_opts) = options.as_sqlite_mut() { - log::warn!("Setting a password for a SQLite database is not supported"); - } else { - unreachable!("Unsupported database type"); +async fn detect_odbc_dbms_name( + options: &OdbcConnectOptions, + config: &AppConfig, +) -> anyhow::Result { + let mut retries = config.database_connection_retries; + let conn: OdbcConnection = loop { + match options.connect().await { + Ok(c) => break c, + Err(e) => { + if retries == 0 { + return Err(anyhow::Error::new(e).context(format!( + "Unable to open connection to {}", + config.database_url + ))); + } + log::warn!("Failed to connect to the database: {e:#}. Retrying in 5 seconds."); + retries -= 1; + tokio::time::sleep(Duration::from_secs(5)).await; + } + } + }; + let dbms_name = conn.dbms_name()?; + conn.close().await?; + Ok(dbms_name) +} + +fn database_url_with_password(config: &AppConfig) -> anyhow::Result { + let Some(password) = &config.database_password else { + return Ok(config.database_url.clone()); + }; + let kind = DbKind::from_database_url(&config.database_url); + if !matches!(kind, DbKind::Mssql) { + return Ok(config.database_url.clone()); } + let mut url = Url::parse(&config.database_url) + .with_context(|| format!("Unable to parse {}", config.database_url))?; + url.set_password(Some(password)) + .map_err(|_| anyhow::anyhow!("Unable to set password in database URL"))?; + Ok(url.to_string()) } diff --git a/src/webserver/database/csv_import.rs b/src/webserver/database/csv_import.rs index 94ac13df..71ecae25 100644 --- a/src/webserver/database/csv_import.rs +++ b/src/webserver/database/csv_import.rs @@ -5,15 +5,12 @@ use futures_util::StreamExt; use sqlparser::ast::{ CopyLegacyCsvOption, CopyLegacyOption, CopyOption, CopySource, CopyTarget, Statement, }; -use sqlx::{ - AnyConnection, Arguments, Executor, PgConnection, - any::{AnyArguments, AnyConnectionKind, AnyKind}, -}; +use sqlx::{Database as SqlxDatabase, Encode, Executor, IntoArguments, PgConnection, Type}; use tokio::io::AsyncRead; use crate::webserver::http_request_info::RequestInfo; -use super::make_placeholder; +use super::{DbKind, execute_queries::DbConnection, make_placeholder}; #[derive(Debug, PartialEq)] pub(super) struct CsvImport { @@ -142,7 +139,7 @@ pub(super) fn extract_csv_copy_statement(stmt: &mut Statement) -> Option anyhow::Result<()> { @@ -167,13 +164,28 @@ pub(super) async fn run_csv_import( ) })?; let buffered = tokio::io::BufReader::new(file); - // private_get_mut is not supposed to be used outside of sqlx, but it is the only way to - // access the underlying connection - match db.private_get_mut() { - AnyConnectionKind::Postgres(pg_connection) => { + match db { + DbConnection::Postgres(pg_connection) => { run_csv_import_postgres(pg_connection, csv_import, buffered).await } - _ => run_csv_import_insert(db, csv_import, buffered).await, + DbConnection::Sqlite(conn) => { + run_csv_import_insert::(conn, DbKind::Sqlite, csv_import, buffered).await + } + DbConnection::MySql(conn) => { + run_csv_import_insert::(conn, DbKind::MySql, csv_import, buffered).await + } + DbConnection::Mssql(conn) => { + run_csv_import_insert::( + conn, + DbKind::Mssql, + csv_import, + buffered, + ) + .await + } + DbConnection::Odbc(conn) => { + run_csv_import_insert::(conn, DbKind::Odbc, csv_import, buffered).await + } } .with_context(|| { let table_name = &csv_import.table_name; @@ -214,12 +226,19 @@ async fn run_csv_import_postgres( } } -async fn run_csv_import_insert( - db: &mut AnyConnection, +async fn run_csv_import_insert( + db: &mut sqlx::pool::PoolConnection, + db_kind: DbKind, csv_import: &CsvImport, file: impl AsyncRead + Unpin + Send, -) -> anyhow::Result<()> { - let insert_stmt = create_insert_stmt(db.kind(), csv_import); +) -> anyhow::Result<()> +where + DB: SqlxDatabase, + DB::Arguments: IntoArguments, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + for<'q> Option: Encode<'q, DB> + Type, +{ + let insert_stmt = create_insert_stmt(db_kind, csv_import); log::debug!("CSV data insert statement: {insert_stmt}"); let mut reader = make_csv_reader(csv_import, file); let col_idxs = compute_column_indices(&mut reader, csv_import).await?; @@ -256,7 +275,7 @@ async fn compute_column_indices( Ok(col_idxs) } -fn create_insert_stmt(db_kind: AnyKind, csv_import: &CsvImport) -> String { +fn create_insert_stmt(db_kind: DbKind, csv_import: &CsvImport) -> String { let columns = csv_import.columns.join(", "); let placeholders = csv_import .columns @@ -274,22 +293,32 @@ fn create_insert_stmt(db_kind: AnyKind, csv_import: &CsvImport) -> String { format!("INSERT INTO {table_name} ({columns}) VALUES ({placeholders})") } -async fn process_csv_record( +async fn process_csv_record( record: csv_async::StringRecord, - db: &mut AnyConnection, + db: &mut sqlx::pool::PoolConnection, insert_stmt: &str, csv_import: &CsvImport, column_indices: &[usize], -) -> anyhow::Result<()> { - let mut arguments = AnyArguments::default(); +) -> anyhow::Result<()> +where + DB: SqlxDatabase, + DB::Arguments: IntoArguments, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + for<'q> Option: Encode<'q, DB> + Type, +{ + let mut query = sqlx::query::(sqlx::AssertSqlSafe(insert_stmt)); let null_str = csv_import.null_str.as_deref().unwrap_or_default(); for (&i, column) in column_indices.iter().zip(csv_import.columns.iter()) { let value = record.get(i).unwrap_or_default(); - let value = if value == null_str { None } else { Some(value) }; + let value = if value == null_str { + None + } else { + Some(value.to_owned()) + }; log::trace!("CSV value: {column}={value:?}"); - arguments.add(value); + query = query.bind(value); } - db.execute((insert_stmt, Some(arguments))).await?; + query.execute(&mut **db).await?; Ok(()) } @@ -328,7 +357,7 @@ fn test_make_statement() { escape: None, uploaded_file: "my_file.csv".into(), }; - let insert_stmt = create_insert_stmt(AnyKind::Postgres, &csv_import); + let insert_stmt = create_insert_stmt(DbKind::Postgres, &csv_import); assert_eq!( insert_stmt, "INSERT INTO my_table (col1, col2) VALUES ($1, $2)" @@ -337,7 +366,7 @@ fn test_make_statement() { #[actix_web::test] async fn test_end_to_end() { - use sqlx::ConnectOptions; + use sqlx::Connection; let mut copy_stmt = sqlparser::parser::Parser::parse_sql( &sqlparser::dialect::GenericDialect {}, @@ -362,22 +391,18 @@ async fn test_end_to_end() { uploaded_file: "my_file.csv".into(), } ); - let mut conn = "sqlite::memory:" - .parse::() - .unwrap() - .connect() - .await - .unwrap(); - conn.execute("CREATE TABLE my_table (col1 TEXT, col2 TEXT)") - .await - .unwrap(); let csv = "col2;col1\na;b\nc;d"; // order is different from the table let file = csv.as_bytes(); - run_csv_import_insert(&mut conn, &csv_import, file) + let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap(); + pool.execute("CREATE TABLE my_table (col1 TEXT, col2 TEXT)") + .await + .unwrap(); + let mut pooled = pool.acquire().await.unwrap(); + run_csv_import_insert::(&mut pooled, DbKind::Sqlite, &csv_import, file) .await .unwrap(); let rows: Vec<(String, String)> = sqlx::query_as("SELECT * FROM my_table") - .fetch_all(&mut conn) + .fetch_all(&pool) .await .unwrap(); assert_eq!( diff --git a/src/webserver/database/error_highlighting.rs b/src/webserver/database/error_highlighting.rs index c2b7595f..15c8da51 100644 --- a/src/webserver/database/error_highlighting.rs +++ b/src/webserver/database/error_highlighting.rs @@ -43,27 +43,9 @@ impl std::fmt::Display for NiceDatabaseError { self.source_file.display(), self.db_err )?; - if let sqlx::error::Error::Database(db_err) = &self.db_err { - let Some(mut offset) = db_err.offset() else { - write!(f, "{}", self.query)?; - self.show_position_info(f)?; - return Ok(()); - }; - for line in self.query.lines() { - if offset > line.len() { - offset -= line.len() + 1; - } else { - highlight_line_offset(f, line, offset); - self.show_position_info(f)?; - break; - } - } - Ok(()) - } else { - write!(f, "{}", self.query)?; - self.show_position_info(f)?; - Ok(()) - } + write!(f, "{}", self.query)?; + self.show_position_info(f)?; + Ok(()) } } @@ -95,21 +77,6 @@ impl std::error::Error for NicePositionedError { } } -/// Display a database error without any position information -#[must_use] -pub fn display_db_error( - source_file: &Path, - query: &str, - db_err: sqlx::error::Error, -) -> anyhow::Error { - anyhow::Error::new(NiceDatabaseError { - source_file: source_file.to_path_buf(), - db_err, - query: query.to_string(), - query_position: None, - }) -} - /// Display a database error with a highlighted line and character offset. #[must_use] pub fn display_stmt_db_error( diff --git a/src/webserver/database/execute_queries.rs b/src/webserver/database/execute_queries.rs index 32ff97ab..ebbd3be8 100644 --- a/src/webserver/database/execute_queries.rs +++ b/src/webserver/database/execute_queries.rs @@ -21,14 +21,22 @@ use crate::webserver::request_variables::SetVariablesMap; use crate::webserver::single_or_vec::SingleOrVec; use super::syntax_tree::{StmtParam, extract_req_param}; -use super::{Database, DbItem, error_highlighting::display_db_error}; -use sqlx::any::{AnyArguments, AnyQueryResult, AnyRow, AnyStatement, AnyTypeInfo}; +use super::{Database, DatabasePool, DbItem, DbKind}; use sqlx::pool::PoolConnection; use sqlx::{ - Any, AnyConnection, Arguments, Column, Either, Executor, Row as _, Statement, ValueRef, + Column, ColumnIndex, Database as SqlxDatabase, Either, Encode, Executor, IntoArguments, + Row as _, Type, ValueRef, }; -pub type DbConn = Option>; +pub type DbConn = Option; + +pub enum DbConnection { + Sqlite(PoolConnection), + Postgres(PoolConnection), + MySql(PoolConnection), + Mssql(PoolConnection), + Odbc(PoolConnection), +} fn source_line_number(line: usize) -> i64 { i64::try_from(line).unwrap_or(i64::MAX) @@ -133,20 +141,6 @@ fn create_db_query_span( (span, operation_name) } -impl Database { - pub(crate) async fn prepare_with( - &self, - query: &str, - param_types: &[AnyTypeInfo], - ) -> anyhow::Result> { - self.connection - .prepare_with(query, param_types) - .await - .map(|s| s.to_owned()) - .map_err(|e| display_db_error(Path::new("autogenerated sqlpage query"), query, e)) - } -} - pub fn stream_query_results_with_conn<'a>( sql_file: &'a ParsedSqlFile, request: &'a ExecutionContext, @@ -182,42 +176,18 @@ pub fn stream_query_results_with_conn<'a>( &request.app_state.telemetry_metrics, ); record_query_params(&query_metrics.span, &query.param_values); - let mut stream = connection.fetch_many(query); - let mut error = None; - let mut returned_rows: i64 = 0; - loop { - let start_next = std::time::Instant::now(); - let next_elem = stream.next().instrument(query_span.clone()).await; - query_metrics.add_duration(start_next.elapsed()); - let Some(elem) = next_elem else { break; }; - - let mut query_result = parse_single_sql_result(source_file, stmt, elem); - if let DbItem::Error(e) = query_result { - error = Some(e); - break; - } - if matches!(query_result, DbItem::Row(_)) { - returned_rows += 1; - } - apply_json_columns(&mut query_result, &stmt.json_columns); - if let Err(err) = apply_delayed_functions(request, &stmt.delayed_functions, &mut query_result) - .instrument(query_span.clone()) - .await - { - error = Some(err); - break; - } - for db_item in parse_dynamic_rows(query_result) { - yield db_item; - } - } - drop(stream); - if let Some(error) = error { - query_metrics.record_error(returned_rows, &error); - try_rollback_transaction(connection).await; - yield DbItem::Error(error); - } else { - query_metrics.record_success(returned_rows); + let items = execute_statement_collect( + connection, + &query, + source_file, + stmt, + request, + query_span, + &mut query_metrics, + ) + .await; + for db_item in items { + yield db_item; } }, ParsedStatement::SetVariable { variable, value} => { @@ -259,6 +229,136 @@ fn with_stmt_position( } } +async fn execute_statement_collect( + connection: &mut DbConnection, + query: &StatementWithParams<'_>, + source_file: &Path, + stmt: &StmtWithParams, + request: &ExecutionContext, + query_span: tracing::Span, + query_metrics: &mut DbQueryMetricsContext<'_>, +) -> Vec { + match connection { + DbConnection::Sqlite(conn) => { + collect_query_results::( + conn, + query, + source_file, + stmt, + request, + query_span, + query_metrics, + ) + .await + } + DbConnection::Postgres(conn) => { + collect_query_results::( + conn, + query, + source_file, + stmt, + request, + query_span, + query_metrics, + ) + .await + } + DbConnection::MySql(conn) => { + collect_query_results::( + conn, + query, + source_file, + stmt, + request, + query_span, + query_metrics, + ) + .await + } + DbConnection::Mssql(conn) => { + collect_query_results::( + conn, + query, + source_file, + stmt, + request, + query_span, + query_metrics, + ) + .await + } + DbConnection::Odbc(conn) => { + collect_query_results::( + conn, + query, + source_file, + stmt, + request, + query_span, + query_metrics, + ) + .await + } + } +} + +async fn collect_query_results( + connection: &mut PoolConnection, + query: &StatementWithParams<'_>, + source_file: &Path, + stmt: &StmtWithParams, + request: &ExecutionContext, + query_span: tracing::Span, + query_metrics: &mut DbQueryMetricsContext<'_>, +) -> Vec +where + DB: SqlxDatabase, + DB::QueryResult: std::fmt::Debug, + DB::Row: super::sql_to_json::SqlPageRow, + usize: ColumnIndex, + DB::Arguments: IntoArguments, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + for<'q> Option: Encode<'q, DB> + Type, +{ + let mut stream = (&mut **connection).fetch_many(bind_query::(query)); + let mut error = None; + let mut returned_rows: i64 = 0; + let mut items = Vec::new(); + loop { + let start_next = std::time::Instant::now(); + let next_elem = stream.next().instrument(query_span.clone()).await; + query_metrics.add_duration(start_next.elapsed()); + let Some(elem) = next_elem else { break }; + + let mut query_result = parse_single_sql_result::(source_file, stmt, elem); + if let DbItem::Error(e) = query_result { + error = Some(e); + break; + } + if matches!(query_result, DbItem::Row(_)) { + returned_rows += 1; + } + apply_json_columns(&mut query_result, &stmt.json_columns); + if let Err(err) = apply_delayed_functions(request, &stmt.delayed_functions, &mut query_result) + .instrument(query_span.clone()) + .await + { + error = Some(err); + break; + } + items.extend(parse_dynamic_rows(query_result)); + } + drop(stream); + if let Some(error) = error { + query_metrics.record_error(returned_rows, &error); + try_rollback_transaction(connection).await; + items.push(DbItem::Error(error)); + } else { + query_metrics.record_success(returned_rows); + } + items +} + /// Transforms a stream of database items to stop processing after encountering the first error. /// The error item itself is still emitted before stopping. pub fn stop_at_first_error( @@ -301,9 +401,13 @@ async fn exec_static_simple_select( Ok(serde_json::Value::Object(map)) } -async fn try_rollback_transaction(db_connection: &mut AnyConnection) { +async fn try_rollback_transaction(db_connection: &mut PoolConnection) +where + DB: SqlxDatabase, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, +{ log::debug!("Attempting to rollback transaction"); - match db_connection.execute("ROLLBACK").await { + match (&mut **db_connection).execute("ROLLBACK").await { Ok(_) => log::debug!("Rolled back transaction"), Err(e) => { log::debug!("There was probably no transaction in progress when this happened: {e:?}"); @@ -311,6 +415,59 @@ async fn try_rollback_transaction(db_connection: &mut AnyConnection) { } } +async fn rollback_connection(connection: &mut DbConnection) { + match connection { + DbConnection::Sqlite(conn) => try_rollback_transaction(conn).await, + DbConnection::Postgres(conn) => try_rollback_transaction(conn).await, + DbConnection::MySql(conn) => try_rollback_transaction(conn).await, + DbConnection::Mssql(conn) => try_rollback_transaction(conn).await, + DbConnection::Odbc(conn) => try_rollback_transaction(conn).await, + } +} + +async fn fetch_optional_string( + connection: &mut DbConnection, + query: &StatementWithParams<'_>, + query_span: tracing::Span, +) -> sqlx::Result>> { + match connection { + DbConnection::Sqlite(conn) => { + fetch_optional_string_inner::(conn, query, query_span).await + } + DbConnection::Postgres(conn) => { + fetch_optional_string_inner::(conn, query, query_span).await + } + DbConnection::MySql(conn) => { + fetch_optional_string_inner::(conn, query, query_span).await + } + DbConnection::Mssql(conn) => { + fetch_optional_string_inner::(conn, query, query_span).await + } + DbConnection::Odbc(conn) => { + fetch_optional_string_inner::(conn, query, query_span).await + } + } +} + +async fn fetch_optional_string_inner( + connection: &mut PoolConnection, + query: &StatementWithParams<'_>, + query_span: tracing::Span, +) -> sqlx::Result>> +where + DB: SqlxDatabase, + DB::Row: super::sql_to_json::SqlPageRow, + DB::Arguments: IntoArguments, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, + for<'q> Option: Encode<'q, DB> + Type, +{ + (&mut **connection) + .fetch_optional(bind_query::(query)) + .instrument(query_span) + .await + .map(|row| row.map(|row| row_to_string(&row))) +} + /// Extracts the value of a parameter from the request. /// Returns `Ok(None)` when NULL should be used as the parameter value. async fn extract_req_param_as_json( @@ -368,15 +525,11 @@ async fn execute_set_variable_query<'a>( ); record_query_params(&query_metrics.span, &query.param_values); let start_time = std::time::Instant::now(); - let value = match connection - .fetch_optional(query) - .instrument(query_span.clone()) - .await - { - Ok(Some(row)) => { + let value = match fetch_optional_string(connection, &query, query_span.clone()).await { + Ok(Some(value)) => { query_metrics.add_duration(start_time.elapsed()); query_metrics.record_success(1_i64); - row_to_string(&row) + value } Ok(None) => { query_metrics.add_duration(start_time.elapsed()); @@ -385,7 +538,7 @@ async fn execute_set_variable_query<'a>( } Err(e) => { query_metrics.add_duration(start_time.elapsed()); - try_rollback_transaction(connection).await; + rollback_connection(connection).await; let err = display_stmt_db_error(source_file, statement, e); query_metrics.record_error(0_i64, &err); return Err(err); @@ -451,7 +604,7 @@ async fn take_connection<'a>( db: &'a Database, conn: &'a mut DbConn, request: &ExecutionContext, -) -> anyhow::Result<&'a mut PoolConnection> { +) -> anyhow::Result<&'a mut DbConnection> { if let Some(c) = conn { return Ok(c); } @@ -462,30 +615,38 @@ async fn take_connection<'a>( { otel::DB_CLIENT_CONNECTION_POOL_NAME } = "sqlpage", sqlpage.db.pool.size = pool_size, ); - match db.connection.acquire().instrument(acquire_span).await { - Ok(c) => { - log::debug!("Acquired a database connection"); - request.server_timing.record("db_conn"); - *conn = Some(c); - let connection = conn.as_mut().unwrap(); - set_trace_context(connection, db).await; - Ok(connection) - } - Err(e) => { - let db_name = db.connection.any_kind(); + let acquired = acquire_connection(&db.connection) + .instrument(acquire_span) + .await + .with_context(|| { + let db_name = db.info.kind.display_name(); let active_count = db.connection.size(); - let err_msg = format!( - "Unable to connect to {db_name:?}. The connection pool currently has {active_count} active connections." - ); - Err(anyhow::Error::new(e).context(err_msg)) - } + format!( + "Unable to connect to {db_name}. The connection pool currently has {active_count} active connections." + ) + })?; + log::debug!("Acquired a database connection"); + request.server_timing.record("db_conn"); + *conn = Some(acquired); + let connection = conn.as_mut().unwrap(); + set_trace_context(connection, db).await; + Ok(connection) +} + +async fn acquire_connection(pool: &DatabasePool) -> sqlx::Result { + match pool { + DatabasePool::Sqlite(pool) => pool.acquire().await.map(DbConnection::Sqlite), + DatabasePool::Postgres(pool) => pool.acquire().await.map(DbConnection::Postgres), + DatabasePool::MySql(pool) => pool.acquire().await.map(DbConnection::MySql), + DatabasePool::Mssql(pool) => pool.acquire().await.map(DbConnection::Mssql), + DatabasePool::Odbc(pool) => pool.acquire().await.map(DbConnection::Odbc), } } /// Sets the current `OTel` trace context on the database connection so it is visible /// in `pg_stat_activity.application_name` (`PostgreSQL`) or as a session variable (`MySQL`). /// This allows correlating `SQLPage` traces with database-side monitoring. -async fn set_trace_context(connection: &mut AnyConnection, db: &Database) { +async fn set_trace_context(connection: &mut DbConnection, db: &Database) { use opentelemetry::trace::TraceContextExt; use tracing_opentelemetry::OpenTelemetrySpanExt; @@ -503,30 +664,47 @@ async fn set_trace_context(connection: &mut AnyConnection, db: &Database) { span_context.trace_flags() ); let sql = match db.info.kind { - sqlx::any::AnyKind::Postgres => { + DbKind::Postgres => { // postgresqlreceiver expects application_name to be a raw W3C traceparent value. format!("SET application_name = '{traceparent}'") } - sqlx::any::AnyKind::MySql => { + DbKind::MySql => { format!("SET @traceparent = '{traceparent}'") } _ => return, }; - if let Err(e) = connection.execute(sql.as_str()).await { + let result = match connection { + DbConnection::Postgres(conn) => (&mut **conn) + .execute(sqlx::AssertSqlSafe(sql.clone())) + .await + .map(|_| ()), + DbConnection::MySql(conn) => (&mut **conn) + .execute(sqlx::AssertSqlSafe(sql.clone())) + .await + .map(|_| ()), + _ => return, + }; + if let Err(e) = result { log::debug!("Failed to set trace context on connection: {e}"); } } #[inline] -fn parse_single_sql_result( +fn parse_single_sql_result( source_file: &Path, stmt: &StmtWithParams, - res: sqlx::Result>, -) -> DbItem { + res: sqlx::Result>, +) -> DbItem +where + DB: SqlxDatabase, + DB::QueryResult: std::fmt::Debug, + DB::Row: super::sql_to_json::SqlPageRow, + usize: ColumnIndex, +{ match res { Ok(Either::Right(r)) => { if log::log_enabled!(log::Level::Trace) { - debug_row(&r); + debug_row::(&r); } DbItem::Row(super::sql_to_json::row_to_json(&r)) } @@ -541,7 +719,11 @@ fn parse_single_sql_result( } } -fn debug_row(r: &AnyRow) { +fn debug_row(r: &DB::Row) +where + DB: SqlxDatabase, + usize: ColumnIndex, +{ use std::fmt::Write; let columns = r.columns(); let mut row_str = String::new(); @@ -589,7 +771,6 @@ async fn bind_parameters<'a>( ) -> anyhow::Result> { let sql = stmt.query.as_str(); log::debug!("Preparing statement: {sql}"); - let mut arguments = AnyArguments::default(); let mut param_values = Vec::with_capacity(stmt.params.len()); for (param_idx, param) in stmt.params.iter().enumerate() { log::trace!("\tevaluating parameter {}: {}", param_idx + 1, param); @@ -600,19 +781,8 @@ async fn bind_parameters<'a>( argument.as_ref().unwrap_or(&Cow::Borrowed("NULL")) ); param_values.push(argument.as_deref().map(str::to_owned)); - match argument { - None => arguments.add(None::), - Some(Cow::Owned(s)) => arguments.add(s), - Some(Cow::Borrowed(v)) => arguments.add(v), - } } - let has_arguments = !stmt.params.is_empty(); - Ok(StatementWithParams { - sql, - arguments, - has_arguments, - param_values, - }) + Ok(StatementWithParams { sql, param_values }) } async fn apply_delayed_functions( @@ -700,32 +870,22 @@ fn apply_json_columns(item: &mut DbItem, json_columns: &[String]) { pub struct StatementWithParams<'a> { sql: &'a str, - arguments: AnyArguments<'a>, - has_arguments: bool, param_values: Vec>, } -impl<'q> sqlx::Execute<'q, Any> for StatementWithParams<'q> { - fn sql(&self) -> &'q str { - self.sql - } - - fn statement(&self) -> Option<&>::Statement> { - None - } - - fn take_arguments(&mut self) -> Option<>::Arguments> { - if self.has_arguments { - Some(std::mem::take(&mut self.arguments)) - } else { - None - } - } - - fn persistent(&self) -> bool { - // Let sqlx create a prepared statement the first time it is executed, and then reuse it. - true +fn bind_query<'q, DB>( + statement: &'q StatementWithParams<'q>, +) -> sqlx::query::Query<'q, DB, DB::Arguments> +where + DB: SqlxDatabase, + Option: Encode<'q, DB> + Type, + DB::Arguments: IntoArguments, +{ + let mut query = sqlx::query::(sqlx::AssertSqlSafe(statement.sql)); + for value in &statement.param_values { + query = query.bind(value.clone()); } + query } #[cfg(test)] diff --git a/src/webserver/database/migrations.rs b/src/webserver/database/migrations.rs index 426045b9..0d67ed66 100644 --- a/src/webserver/database/migrations.rs +++ b/src/webserver/database/migrations.rs @@ -1,5 +1,5 @@ use super::Database; -use super::error_highlighting::display_db_error; +use super::{DatabasePool, DbKind}; use crate::MIGRATIONS_DIR; use anyhow; use anyhow::Context; @@ -34,18 +34,14 @@ pub async fn apply(config: &crate::app_config::AppConfig, db: &Database) -> anyh for m in migrator.iter() { log::info!("\t{}", DisplayMigration(m)); } - migrator.run(&db.connection).await.map_err(|err| { + if db.info.kind == DbKind::Odbc { + anyhow::bail!( + "ODBC migrations are not supported by sqlx-odbc. Apply the migrations manually or use a native SQLPage backend for managed migrations." + ); + } + run_migrator(&migrator, &db.connection).await.map_err(|err| { match err { - MigrateError::Execute(n, source) => { - let migration = migrator.iter().find(|&m| m.version == n).unwrap(); - let source_file = - migrations_dir.join(format!("{:04}_{}.sql", n, migration.description)); - display_db_error(&source_file, &migration.sql, source).context(format!( - "Failed to apply {} migration {}", - db, - DisplayMigration(migration) - )) - } + MigrateError::Execute(source) => anyhow::Error::new(source), source => anyhow::Error::new(source), } .context(format!( @@ -55,6 +51,16 @@ pub async fn apply(config: &crate::app_config::AppConfig, db: &Database) -> anyh Ok(()) } +async fn run_migrator(migrator: &Migrator, pool: &DatabasePool) -> Result<(), MigrateError> { + match pool { + DatabasePool::Sqlite(pool) => migrator.run(pool).await, + DatabasePool::Postgres(pool) => migrator.run(pool).await, + DatabasePool::MySql(pool) => migrator.run(pool).await, + DatabasePool::Mssql(pool) => migrator.run(pool).await, + DatabasePool::Odbc(_) => unreachable!("ODBC migrations are checked before run_migrator"), + } +} + struct DisplayMigration<'a>(&'a Migration); impl std::fmt::Display for DisplayMigration<'_> { diff --git a/src/webserver/database/mod.rs b/src/webserver/database/mod.rs index b73e0cb1..922663f1 100644 --- a/src/webserver/database/mod.rs +++ b/src/webserver/database/mod.rs @@ -12,9 +12,112 @@ mod sql_to_json; pub use sql::ParsedSqlFile; use sql::{DB_PLACEHOLDERS, DbPlaceHolder}; -use sqlx::any::AnyKind; // SupportedDatabase is defined in this module +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum DbKind { + Sqlite, + Postgres, + MySql, + Mssql, + Odbc, +} + +impl DbKind { + #[must_use] + pub fn from_database_url(database_url: &str) -> Self { + let lower = database_url.to_ascii_lowercase(); + if lower.starts_with("postgres://") || lower.starts_with("postgresql://") { + Self::Postgres + } else if lower.starts_with("mysql://") || lower.starts_with("mariadb://") { + Self::MySql + } else if lower.starts_with("sqlite:") { + Self::Sqlite + } else if lower.starts_with("mssql://") || lower.starts_with("sqlserver://") { + Self::Mssql + } else { + Self::Odbc + } + } + + #[must_use] + pub fn display_name(self) -> &'static str { + match self { + Self::Sqlite => "SQLite", + Self::Postgres => "PostgreSQL", + Self::MySql => "MySQL", + Self::Mssql => "Microsoft SQL Server", + Self::Odbc => "ODBC", + } + } +} + +impl From for SupportedDatabase { + fn from(kind: DbKind) -> Self { + match kind { + DbKind::Sqlite => Self::Sqlite, + DbKind::Postgres => Self::Postgres, + DbKind::MySql => Self::MySql, + DbKind::Mssql => Self::Mssql, + DbKind::Odbc => Self::Generic, + } + } +} + +#[derive(Debug, Clone)] +pub enum DatabasePool { + Sqlite(sqlx::Pool), + Postgres(sqlx::Pool), + MySql(sqlx::Pool), + Mssql(sqlx::Pool), + Odbc(sqlx::Pool), +} + +impl DatabasePool { + #[must_use] + pub fn kind(&self) -> DbKind { + match self { + Self::Sqlite(_) => DbKind::Sqlite, + Self::Postgres(_) => DbKind::Postgres, + Self::MySql(_) => DbKind::MySql, + Self::Mssql(_) => DbKind::Mssql, + Self::Odbc(_) => DbKind::Odbc, + } + } + + #[must_use] + pub fn size(&self) -> u32 { + match self { + Self::Sqlite(pool) => pool.size(), + Self::Postgres(pool) => pool.size(), + Self::MySql(pool) => pool.size(), + Self::Mssql(pool) => pool.size(), + Self::Odbc(pool) => pool.size(), + } + } + + #[must_use] + pub fn num_idle(&self) -> usize { + match self { + Self::Sqlite(pool) => pool.num_idle(), + Self::Postgres(pool) => pool.num_idle(), + Self::MySql(pool) => pool.num_idle(), + Self::Mssql(pool) => pool.num_idle(), + Self::Odbc(pool) => pool.num_idle(), + } + } + + pub async fn close(&self) { + match self { + Self::Sqlite(pool) => pool.close().await, + Self::Postgres(pool) => pool.close().await, + Self::MySql(pool) => pool.close().await, + Self::Mssql(pool) => pool.close().await, + Self::Odbc(pool) => pool.close().await, + } + } +} + /// Supported database types in `SQLPage`. Represents an actual DBMS, not a sqlx backend kind (like "Odbc") #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum SupportedDatabase { @@ -81,20 +184,8 @@ impl SupportedDatabase { } } -impl From for SupportedDatabase { - fn from(kind: AnyKind) -> Self { - match kind { - AnyKind::Postgres => Self::Postgres, - AnyKind::MySql => Self::MySql, - AnyKind::Sqlite => Self::Sqlite, - AnyKind::Mssql => Self::Mssql, - AnyKind::Odbc => Self::Generic, - } - } -} - pub struct Database { - pub connection: sqlx::AnyPool, + pub connection: DatabasePool, pub info: DbInfo, } @@ -103,8 +194,8 @@ pub struct DbInfo { pub dbms_name: String, /// The actual database we are connected to. Can be "Generic" when using an unknown ODBC driver pub database_type: SupportedDatabase, - /// The sqlx database backend we are using. Can be "Odbc", in which case we need to use `database_type` to know what database we are actually using. - pub kind: AnyKind, + /// The SQLPage backend we are using. Can be "Odbc", in which case we need to use `database_type` to know what database we are actually using. + pub kind: DbKind, } impl Database { @@ -124,13 +215,13 @@ pub enum DbItem { impl std::fmt::Display for Database { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self.connection.any_kind()) + write!(f, "{}", self.info.database_type.display_name()) } } #[inline] #[must_use] -pub fn make_placeholder(dbms: AnyKind, arg_number: usize) -> String { +pub fn make_placeholder(dbms: DbKind, arg_number: usize) -> String { if let Some((_, placeholder)) = DB_PLACEHOLDERS.iter().find(|(kind, _)| *kind == dbms) { match *placeholder { DbPlaceHolder::PrefixedNumber { prefix } => format!("{prefix}{arg_number}"), diff --git a/src/webserver/database/sql.rs b/src/webserver/database/sql.rs index 0d4bdfef..c246e668 100644 --- a/src/webserver/database/sql.rs +++ b/src/webserver/database/sql.rs @@ -20,7 +20,6 @@ use sqlparser::dialect::{ use sqlparser::parser::{Parser, ParserError}; use sqlparser::tokenizer::Token::{self, EOF, SemiColon}; use sqlparser::tokenizer::{Location, Span, TokenWithSpan, Tokenizer}; -use sqlx::any::AnyKind; use std::fmt::Write; use std::path::{Path, PathBuf}; use std::str::FromStr; @@ -169,7 +168,7 @@ fn parse_sql<'a>( })) } -fn transform_to_positional_placeholders(stmt: &mut StmtWithParams, kind: AnyKind) { +fn transform_to_positional_placeholders(stmt: &mut StmtWithParams, kind: super::DbKind) { if let Some((_, DbPlaceHolder::Positional { placeholder })) = DB_PLACEHOLDERS .iter() .find(|(placeholder_kind, _)| *placeholder_kind == kind) @@ -784,11 +783,11 @@ mod test { fn create_test_db_info(database_type: SupportedDatabase) -> DbInfo { let kind = match database_type { - SupportedDatabase::Postgres => AnyKind::Postgres, - SupportedDatabase::Mssql => AnyKind::Mssql, - SupportedDatabase::MySql => AnyKind::MySql, - SupportedDatabase::Sqlite => AnyKind::Sqlite, - _ => AnyKind::Odbc, + SupportedDatabase::Postgres => DbKind::Postgres, + SupportedDatabase::Mssql => DbKind::Mssql, + SupportedDatabase::MySql => DbKind::MySql, + SupportedDatabase::Sqlite => DbKind::Sqlite, + _ => DbKind::Odbc, }; DbInfo { dbms_name: database_type.display_name().to_string(), @@ -1317,7 +1316,7 @@ mod test { delayed_functions: vec![], json_columns: vec![], }; - transform_to_positional_placeholders(&mut stmt, AnyKind::MySql); + transform_to_positional_placeholders(&mut stmt, DbKind::MySql); assert_eq!( stmt.query, "select \ diff --git a/src/webserver/database/sql/parameter_extraction.rs b/src/webserver/database/sql/parameter_extraction.rs index 57bb08c4..d1a8ce1b 100644 --- a/src/webserver/database/sql/parameter_extraction.rs +++ b/src/webserver/database/sql/parameter_extraction.rs @@ -1,4 +1,4 @@ -use super::super::{DbInfo, SupportedDatabase}; +use super::super::{DbInfo, DbKind, SupportedDatabase}; use super::{is_sqlpage_func, sqlpage_func_name}; use crate::webserver::database::sqlpage_functions::func_call_to_param; use crate::webserver::database::syntax_tree::StmtParam; @@ -7,7 +7,6 @@ use sqlparser::ast::{ FunctionArgExpr, FunctionArgumentList, FunctionArguments, Ident, ObjectName, ObjectNamePart, Spanned, Statement, Value, ValueWithSpan, Visit, VisitMut, Visitor, VisitorMut, }; -use sqlx::any::AnyKind; use std::ops::ControlFlow; pub(super) struct ParameterExtractor { @@ -22,25 +21,25 @@ pub(crate) enum DbPlaceHolder { Positional { placeholder: &'static str }, } -pub(crate) const DB_PLACEHOLDERS: [(AnyKind, DbPlaceHolder); 5] = [ +pub(crate) const DB_PLACEHOLDERS: [(DbKind, DbPlaceHolder); 5] = [ ( - AnyKind::Sqlite, + DbKind::Sqlite, DbPlaceHolder::PrefixedNumber { prefix: "?" }, ), ( - AnyKind::Postgres, + DbKind::Postgres, DbPlaceHolder::PrefixedNumber { prefix: "$" }, ), ( - AnyKind::MySql, + DbKind::MySql, DbPlaceHolder::Positional { placeholder: "?" }, ), ( - AnyKind::Mssql, + DbKind::Mssql, DbPlaceHolder::PrefixedNumber { prefix: "@p" }, ), ( - AnyKind::Odbc, + DbKind::Odbc, DbPlaceHolder::Positional { placeholder: "?" }, ), ]; @@ -49,7 +48,7 @@ pub(crate) const DB_PLACEHOLDERS: [(AnyKind, DbPlaceHolder); 5] = [ /// And then replace it with the actual placeholder during statement rewriting. pub(crate) const TEMP_PLACEHOLDER_PREFIX: &str = "@SQLPAGE_TEMP"; -fn get_placeholder_prefix(kind: AnyKind) -> &'static str { +fn get_placeholder_prefix(kind: DbKind) -> &'static str { if let Some((_, DbPlaceHolder::PrefixedNumber { prefix })) = DB_PLACEHOLDERS .iter() .find(|(placeholder_kind, _prefix)| *placeholder_kind == kind) @@ -97,7 +96,9 @@ impl ParameterExtractor { let data_type = match self.db_info.database_type { SupportedDatabase::MySql => DataType::Char(None), SupportedDatabase::Mssql => DataType::Varchar(Some(CharacterLength::Max)), - SupportedDatabase::Postgres | SupportedDatabase::Sqlite => DataType::Text, + SupportedDatabase::Postgres | SupportedDatabase::Sqlite | SupportedDatabase::Duckdb => { + DataType::Text + } SupportedDatabase::Oracle => DataType::Varchar(Some(CharacterLength::IntegerLength { length: 4000, unit: None, @@ -463,7 +464,7 @@ fn function_arg_expr(arg: &mut FunctionArg) -> Option<&mut Expr> { #[inline] #[must_use] -pub(super) fn make_tmp_placeholder(kind: AnyKind, arg_number: usize) -> String { +pub(super) fn make_tmp_placeholder(kind: DbKind, arg_number: usize) -> String { let prefix = if let Some((_, DbPlaceHolder::PrefixedNumber { prefix })) = DB_PLACEHOLDERS.iter().find(|(db_typ, _)| *db_typ == kind) { diff --git a/src/webserver/database/sql_to_json.rs b/src/webserver/database/sql_to_json.rs index 4c93dddd..f1e8aafd 100644 --- a/src/webserver/database/sql_to_json.rs +++ b/src/webserver/database/sql_to_json.rs @@ -1,472 +1,134 @@ use crate::utils::add_value_to_map; use crate::webserver::database::blob_to_data_url; -use bigdecimal::BigDecimal; -use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime}; -use serde_json::{self, Map, Value}; -use sqlx::any::{AnyColumn, AnyRow, AnyTypeInfo, AnyTypeInfoKind}; -use sqlx::postgres::PgValueRef; -use sqlx::postgres::types::PgRange; -use sqlx::{Column, Row, TypeInfo, ValueRef}; -use sqlx::{Decode, Type}; +use serde_json::{Map, Value}; +use sqlx::{Column, ColumnIndex, Row, TypeInfo, ValueRef}; -pub fn row_to_json(row: &AnyRow) -> Value { - use Value::Object; +pub trait SqlPageRow { + fn to_json(&self) -> Value; + fn first_value_to_string(&self) -> Option; +} - let columns = row.columns(); - let mut map = Map::new(); - for col in columns { - let key = canonical_col_name(col); - let value: Value = sql_to_json(row, col); - map = add_value_to_map(map, (key, value)); - } - Object(map) +pub fn row_to_json(row: &impl SqlPageRow) -> Value { + row.to_json() } -fn canonical_col_name(col: &AnyColumn) -> String { - // Some databases fold all unquoted identifiers to uppercase but SQLPage uses lowercase property names - if matches!(col.type_info().0, AnyTypeInfoKind::Odbc(_)) - && col - .name() - .chars() - .all(|c| c.is_ascii_uppercase() || c == '_') - { - col.name().to_ascii_lowercase() - } else { - col.name().to_owned() - } +pub fn row_to_string(row: &impl SqlPageRow) -> Option { + row.first_value_to_string() } -pub fn sql_to_json(row: &AnyRow, col: &sqlx::any::AnyColumn) -> Value { - let raw_value_result = row.try_get_raw(col.ordinal()); - match raw_value_result { - Ok(raw_value) if !raw_value.is_null() => { - let mut raw_value = Some(raw_value); - let decoded = sql_nonnull_to_json(|| { - raw_value - .take() - .unwrap_or_else(|| row.try_get_raw(col.ordinal()).unwrap()) - }); - log::trace!("Decoded value: {decoded:?}"); - decoded - } - Ok(_null) => Value::Null, - Err(e) => { - log::warn!("Unable to extract value from row: {e:?}"); - Value::Null +macro_rules! impl_sqlpage_row { + ($row:ty, $db:ty, $canonical_odbc_names:expr) => { + impl SqlPageRow for $row { + fn to_json(&self) -> Value { + let mut map = Map::new(); + for col in self.columns() { + let key = canonical_col_name(col.name(), $canonical_odbc_names); + let value = sql_to_json::<$db, _>(self, col.ordinal()); + map = add_value_to_map(map, (key, value)); + } + Value::Object(map) + } + + fn first_value_to_string(&self) -> Option { + let col = self.columns().first()?; + match sql_to_json::<$db, _>(self, col.ordinal()) { + Value::String(s) => Some(s), + Value::Null => None, + other => Some(other.to_string()), + } + } } - } + }; } -fn decode_raw<'a, T: Decode<'a, sqlx::any::Any> + Default>( - raw_value: sqlx::any::AnyValueRef<'a>, -) -> T { - match T::decode(raw_value) { - Ok(v) => v, - Err(e) => { - let type_name = std::any::type_name::(); - log::error!("Failed to decode {type_name} value: {e}"); - T::default() - } +impl_sqlpage_row!(sqlx::postgres::PgRow, sqlx::Postgres, false); +impl_sqlpage_row!(sqlx::mysql::MySqlRow, sqlx::MySql, false); +impl_sqlpage_row!(sqlx::sqlite::SqliteRow, sqlx::Sqlite, false); +impl_sqlpage_row!(sqlx_sqlserver::MssqlRow, sqlx_sqlserver::Mssql, false); +impl_sqlpage_row!(sqlx_odbc::OdbcRow, sqlx_odbc::Odbc, true); + +fn canonical_col_name(name: &str, canonicalize_uppercase: bool) -> String { + if canonicalize_uppercase + && name + .chars() + .all(|c| c.is_ascii_uppercase() || c == '_' || c.is_ascii_digit()) + { + name.to_ascii_lowercase() + } else { + name.to_owned() } } -fn decode_pg_range<'r, T>(raw_value: sqlx::any::AnyValueRef<'r>) -> Value +fn sql_to_json(row: &R, ordinal: usize) -> Value where - T: std::fmt::Display - + Type - + for<'a> sqlx::Decode<'a, sqlx::postgres::Postgres>, + DB: sqlx::Database, + R: Row, + for<'r> bool: sqlx::Decode<'r, DB> + sqlx::Type, + for<'r> i16: sqlx::Decode<'r, DB> + sqlx::Type, + for<'r> i32: sqlx::Decode<'r, DB> + sqlx::Type, + for<'r> i64: sqlx::Decode<'r, DB> + sqlx::Type, + for<'r> f32: sqlx::Decode<'r, DB> + sqlx::Type, + for<'r> f64: sqlx::Decode<'r, DB> + sqlx::Type, + for<'r> String: sqlx::Decode<'r, DB> + sqlx::Type, + for<'r> Vec: sqlx::Decode<'r, DB> + sqlx::Type, + usize: ColumnIndex, { - let Ok(pg_val): Result, _> = raw_value.try_into() else { - log::error!("Only postgres range values are supported"); - return Value::Null; - }; - match as sqlx::Decode<'r, sqlx::postgres::Postgres>>::decode(pg_val) { - Ok(pg_range) => pg_range.to_string().into(), + let raw_value = match row.try_get_raw(ordinal) { + Ok(raw_value) if raw_value.is_null() => return Value::Null, + Ok(raw_value) => raw_value, Err(e) => { - log::error!("Failed to decode postgres range value: {e}"); - Value::Null + log::warn!("Unable to extract value from row: {e:?}"); + return Value::Null; } - } -} - -fn decimal_to_json(decimal: &BigDecimal) -> Value { - // to_plain_string always returns a valid JSON string - Value::Number(serde_json::Number::from_string_unchecked( - decimal.normalized().to_plain_string(), - )) -} + }; -pub fn sql_nonnull_to_json<'r>(mut get_ref: impl FnMut() -> sqlx::any::AnyValueRef<'r>) -> Value { - use AnyTypeInfoKind::{Mssql, MySql}; - let raw_value = get_ref(); let type_info = raw_value.type_info(); - let type_name = type_info.name(); + let type_name = type_info.name().to_ascii_uppercase(); log::trace!("Decoding a value of type {type_name:?} (type info: {type_info:?})"); - let AnyTypeInfo(ref db_type) = *type_info; - match type_name { - "REAL" | "FLOAT" | "FLOAT4" | "FLOAT8" | "DOUBLE" => decode_raw::(raw_value).into(), - "NUMERIC" | "DECIMAL" => decimal_to_json(&decode_raw(raw_value)), - "INT8" | "BIGINT" | "SERIAL8" | "BIGSERIAL" | "IDENTITY" | "INT64" | "INTEGER8" - | "BIGINT SIGNED" => decode_raw::(raw_value).into(), - "INT" | "INT4" | "INTEGER" | "MEDIUMINT" | "YEAR" => decode_raw::(raw_value).into(), - "INT2" | "SMALLINT" | "TINYINT" => decode_raw::(raw_value).into(), - "BIGINT UNSIGNED" => decode_raw::(raw_value).into(), - "INT UNSIGNED" | "MEDIUMINT UNSIGNED" | "SMALLINT UNSIGNED" | "TINYINT UNSIGNED" => { - decode_raw::(raw_value).into() - } - "BOOL" | "BOOLEAN" => decode_raw::(raw_value).into(), - "BIT" if matches!(db_type, Mssql(_)) => decode_raw::(raw_value).into(), - "BIT" if matches!(db_type, MySql(mysql_type) if mysql_type.max_size() == Some(1)) => { - decode_raw::(raw_value).into() - } - "BIT" if matches!(db_type, MySql(_)) => decode_raw::(raw_value).into(), - "DATE" => decode_raw::(raw_value) - .to_string() - .into(), - "TIME" | "TIMETZ" => decode_raw::(raw_value) - .to_string() - .into(), - "DATETIMEOFFSET" | "TIMESTAMP" | "TIMESTAMPTZ" => { - decode_raw::>(raw_value) - .to_rfc3339() - .into() + + match type_name.as_str() { + "BOOL" | "BOOLEAN" | "BIT" => decode::(row, ordinal).into(), + "INT2" | "SMALLINT" | "TINYINT" => decode::(row, ordinal).into(), + "INT" | "INT4" | "INTEGER" | "MEDIUMINT" | "YEAR" => { + decode::(row, ordinal).into() } - "DATETIME" | "DATETIME2" => decode_raw::(raw_value) - .format("%FT%T%.f") - .to_string() - .into(), - "MONEY" | "SMALLMONEY" if matches!(db_type, Mssql(_)) => { - decode_raw::(raw_value).into() + "INT8" | "BIGINT" | "SERIAL8" | "BIGSERIAL" | "IDENTITY" | "INT64" | "INTEGER8" => { + decode::(row, ordinal).into() } - "UUID" | "UNIQUEIDENTIFIER" => decode_raw::(raw_value) - .to_string() - .into(), - "JSON" | "JSON[]" | "JSONB" | "JSONB[]" => decode_raw::(raw_value), + "REAL" | "FLOAT4" => decode::(row, ordinal).into(), + "FLOAT" | "FLOAT8" | "DOUBLE" => decode::(row, ordinal).into(), "BLOB" | "BYTEA" | "FILESTREAM" | "VARBINARY" | "BIGVARBINARY" | "BINARY" | "IMAGE" => { - blob_to_data_url::vec_to_data_uri_value(&decode_raw::>(raw_value)) + blob_to_data_url::vec_to_data_uri_value(&decode::>(row, ordinal)) } - "INT4RANGE" => decode_pg_range::(raw_value), - "INT8RANGE" => decode_pg_range::(raw_value), - "NUMRANGE" => decode_pg_range::(raw_value), - "DATERANGE" => decode_pg_range::(raw_value), - "TSRANGE" => decode_pg_range::(raw_value), - "TSTZRANGE" => decode_pg_range::>(raw_value), - // Deserialize as a string by default - _ => decode_raw::(raw_value).into(), + _ => decode::(row, ordinal).into(), } } -/// Takes the first column of a row and converts it to a string. -pub fn row_to_string(row: &AnyRow) -> Option { - let col = row.columns().first()?; - match sql_to_json(row, col) { - serde_json::Value::String(s) => Some(s), - serde_json::Value::Null => None, - other => Some(other.to_string()), +fn decode(row: &R, ordinal: usize) -> T +where + DB: sqlx::Database, + R: Row, + for<'r> T: sqlx::Decode<'r, DB> + sqlx::Type + Default, + usize: ColumnIndex, +{ + match row.try_get::(ordinal) { + Ok(v) => v, + Err(e) => { + let type_name = std::any::type_name::(); + log::error!("Failed to decode {type_name} value: {e}"); + T::default() + } } } #[cfg(test)] mod tests { - use crate::app_config::tests::test_database_url; - use super::*; use sqlx::Connection; - fn setup_logging() { - crate::telemetry::init_test_logging(); - } - - fn db_specific_test(db_type: &str) -> Option { - setup_logging(); - let db_url = test_database_url(); - if db_url.starts_with(db_type) { - Some(db_url) - } else { - log::warn!("Skipping test because DATABASE_URL is not set to a {db_type} database"); - None - } - } - - #[actix_web::test] - async fn test_row_to_json() -> anyhow::Result<()> { - use sqlx::Connection; - let db_url = test_database_url(); - let mut c = sqlx::AnyConnection::connect(&db_url).await?; - let row = sqlx::query( - "SELECT - 123.456 as one_value, - 1 as two_values, - 2 as two_values, - 'x' as three_values, - 'y' as three_values, - 'z' as three_values - ", - ) - .fetch_one(&mut c) - .await?; - expect_json_object_equal( - &row_to_json(&row), - &serde_json::json!({ - "one_value": 123.456, - "two_values": [1,2], - "three_values": ["x","y","z"], - }), - ); - Ok(()) - } - - #[actix_web::test] - async fn test_postgres_types() -> anyhow::Result<()> { - let Some(db_url) = db_specific_test("postgres") else { - return Ok(()); - }; - let mut c = sqlx::AnyConnection::connect(&db_url).await?; - let row = sqlx::query( - "SELECT - 42::INT2 as small_int, - 42::INT4 as integer, - 42::INT8 as big_int, - 42.25::FLOAT4 as float4, - 42.25::FLOAT8 as float8, - 123456789123456789123456789::NUMERIC as numeric, - TRUE as boolean, - '2024-03-14'::DATE as date, - '13:14:15'::TIME as time, - '2024-03-14 13:14:15'::TIMESTAMP as timestamp, - '2024-03-14 13:14:15+02:00'::TIMESTAMPTZ as timestamptz, - INTERVAL '1 year 2 months 3 days' as complex_interval, - INTERVAL '4 hours' as hour_interval, - INTERVAL '1.5 days' as fractional_interval, - '{\"key\": \"value\"}'::JSON as json, - '{\"key\": \"value\"}'::JSONB as jsonb, - age('2024-03-14'::timestamp, '2024-01-01'::timestamp) as age_interval, - justify_interval(interval '1 year 2 months 3 days') as justified_interval, - 1234.56::MONEY as money_val, - '\\x68656c6c6f20776f726c64'::BYTEA as blob_data, - '550e8400-e29b-41d4-a716-446655440000'::UUID as uuid, - '[1,5)'::INT4RANGE as int4range, - '[1,5]'::INT8RANGE as int8range, - '[1.5,4.5)'::NUMRANGE as numrange, - -- '[2024-11-12 01:02:03,2024-11-12 23:00:00)'::TSRANGE as tsrange, - -- '[2024-11-12 01:02:03+01:00,2024-11-12 23:00:00+00:00)'::TSTZRANGE as tstzrange, - '[2024-11-12,2024-11-13)'::DATERANGE as daterange - ", - ) - .fetch_one(&mut c) - .await?; - - expect_json_object_equal( - &row_to_json(&row), - &serde_json::json!({ - "small_int": 42, - "integer": 42, - "big_int": 42, - "float4": 42.25, - "float8": 42.25, - "numeric": 123_456_789_123_456_789_123_456_789_u128, - "boolean": true, - "date": "2024-03-14", - "time": "13:14:15", - "timestamp": "2024-03-14T13:14:15+00:00", - "timestamptz": "2024-03-14T11:14:15+00:00", - "complex_interval": "1 year 2 mons 3 days", - "hour_interval": "04:00:00", - "fractional_interval": "1 day 12:00:00", - "json": {"key": "value"}, - "jsonb": {"key": "value"}, - "age_interval": "2 mons 13 days", - "justified_interval": "1 year 2 mons 3 days", - "money_val": "$1,234.56", - "blob_data": "data:application/octet-stream;base64,aGVsbG8gd29ybGQ=", - "uuid": "550e8400-e29b-41d4-a716-446655440000", - "int4range": "[1,5)", - "int8range": "[1,6)", - "numrange": "[1.5,4.5)", - //"tsrange": "[2024-11-12 01:02:03,2024-11-12 23:00:00)", // todo: bug in sqlx datetime range parsing - //"tstzrange": "[\"2024-11-12 02:00:00 +01:00\",\"2024-11-12 23:00:00 +00:00\")", // todo: tz info is lost in sqlx - "daterange": "[2024-11-12,2024-11-13)" - }), - ); - Ok(()) - } - - /// Postgres encodes values differently in prepared statements and in "simple" queries - /// - #[actix_web::test] - async fn test_postgres_prepared_types() -> anyhow::Result<()> { - let Some(db_url) = db_specific_test("postgres") else { - return Ok(()); - }; - let mut c = sqlx::AnyConnection::connect(&db_url).await?; - let row = sqlx::query( - "SELECT - '2024-03-14'::DATE as date, - '13:14:15'::TIME as time, - '2024-03-14 13:14:15+02:00'::TIMESTAMPTZ as timestamptz, - INTERVAL '-01:02:03' as time_interval, - '{\"key\": \"value\"}'::JSON as json, - 1234.56::MONEY as money_val, - '\\x74657374'::BYTEA as blob_data, - '550e8400-e29b-41d4-a716-446655440000'::UUID as uuid - where $1", - ) - .bind(true) - .fetch_one(&mut c) - .await?; - - expect_json_object_equal( - &row_to_json(&row), - &serde_json::json!({ - "date": "2024-03-14", - "time": "13:14:15", - "timestamptz": "2024-03-14T11:14:15+00:00", - "time_interval": "-01:02:03", - "json": {"key": "value"}, - "money_val": "", // TODO: fix this bug: https://github.com/sqlpage/SQLPage/issues/983 - "blob_data": "data:application/octet-stream;base64,dGVzdA==", - "uuid": "550e8400-e29b-41d4-a716-446655440000", - }), - ); - Ok(()) - } - #[actix_web::test] - async fn test_postgres_prepared_range_types() -> anyhow::Result<()> { - let Some(db_url) = db_specific_test("postgres") else { - return Ok(()); - }; - let mut c = sqlx::AnyConnection::connect(&db_url).await?; - let row = sqlx::query( - "SELECT - '[1,5)'::INT4RANGE as int4range, - '[2024-11-12 01:02:03,2024-11-12 23:00:00)'::TSRANGE as tsrange, - '[2024-11-12 01:02:03+01:00,2024-11-12 23:00:00+00:00)'::TSTZRANGE as tstzrange, - '[2024-11-12,2024-11-13)'::DATERANGE as daterange - where $1", - ) - .bind(true) - .fetch_one(&mut c) - .await?; - - expect_json_object_equal( - &row_to_json(&row), - &serde_json::json!({ - "int4range": "[1,5)", - "tsrange": "[2024-11-12 01:02:03,2024-11-12 23:00:00)", - "tstzrange": "[2024-11-12 00:02:03 +00:00,2024-11-12 23:00:00 +00:00)", // todo: tz info is lost in sqlx - "daterange": "[2024-11-12,2024-11-13)" - }), - ); - Ok(()) - } - - #[actix_web::test] - async fn test_mysql_types() -> anyhow::Result<()> { - let db_url = db_specific_test("mysql").or_else(|| db_specific_test("mariadb")); - let Some(db_url) = db_url else { - return Ok(()); - }; - let mut c = sqlx::AnyConnection::connect(&db_url).await?; - - sqlx::query( - "CREATE TEMPORARY TABLE _sqlp_t ( - tiny_int TINYINT, - small_int SMALLINT, - medium_int MEDIUMINT, - signed_int INTEGER, - big_int BIGINT, - unsigned_int INTEGER UNSIGNED, - tiny_int_unsigned TINYINT UNSIGNED, - small_int_unsigned SMALLINT UNSIGNED, - medium_int_unsigned MEDIUMINT UNSIGNED, - big_int_unsigned BIGINT UNSIGNED, - decimal_num DECIMAL(10,2), - float_num FLOAT, - double_num DOUBLE, - bit_val BIT(1), - date_val DATE, - time_val TIME, - datetime_val DATETIME, - timestamp_val TIMESTAMP, - year_val YEAR, - char_val CHAR(10), - varchar_val VARCHAR(50), - text_val TEXT, - blob_val BLOB - ) AS - SELECT - 127 as tiny_int, - 32767 as small_int, - 8388607 as medium_int, - -1000000 as signed_int, - 9223372036854775807 as big_int, - 1000000 as unsigned_int, - 255 as tiny_int_unsigned, - 65535 as small_int_unsigned, - 16777215 as medium_int_unsigned, - 18446744073709551615 as big_int_unsigned, - 123.45 as decimal_num, - 42.25 as float_num, - 42.25 as double_num, - 1 as bit_val, - '2024-03-14' as date_val, - '13:14:15' as time_val, - '2024-03-14 13:14:15' as datetime_val, - '2024-03-14 13:14:15' as timestamp_val, - 2024 as year_val, - 'CHAR' as char_val, - 'VARCHAR' as varchar_val, - 'TEXT' as text_val, - x'626c6f62' as blob_val", - ) - .execute(&mut c) - .await?; - - let row = sqlx::query("SELECT * FROM _sqlp_t") - .fetch_one(&mut c) - .await?; - - expect_json_object_equal( - &row_to_json(&row), - &serde_json::json!({ - "tiny_int": 127, - "small_int": 32767, - "medium_int": 8_388_607, - "signed_int": -1_000_000, - "big_int": 9_223_372_036_854_775_807_u64, - "unsigned_int": 1_000_000, - "tiny_int_unsigned": 255, - "small_int_unsigned": 65_535, - "medium_int_unsigned": 16_777_215, - "big_int_unsigned": 18_446_744_073_709_551_615_u64, - "decimal_num": 123.45, - "float_num": 42.25, - "double_num": 42.25, - "bit_val": true, - "date_val": "2024-03-14", - "time_val": "13:14:15", - "datetime_val": "2024-03-14T13:14:15", - "timestamp_val": "2024-03-14T13:14:15+00:00", - "year_val": 2024, - "char_val": "CHAR", - "varchar_val": "VARCHAR", - "text_val": "TEXT", - "blob_val": "data:application/octet-stream;base64,YmxvYg==" - }), - ); - - sqlx::query("DROP TABLE _sqlp_t").execute(&mut c).await?; - - Ok(()) - } - - #[actix_web::test] - async fn test_sqlite_types() -> anyhow::Result<()> { - let Some(db_url) = db_specific_test("sqlite") else { - return Ok(()); - }; - let mut c = sqlx::AnyConnection::connect(&db_url).await?; + async fn test_sqlite_row_to_json() -> anyhow::Result<()> { + let mut c = sqlx::SqliteConnection::connect(":memory:").await?; let row = sqlx::query( "SELECT 42 as integer, @@ -477,9 +139,9 @@ mod tests { .fetch_one(&mut c) .await?; - expect_json_object_equal( - &row_to_json(&row), - &serde_json::json!({ + assert_eq!( + row_to_json(&row), + serde_json::json!({ "integer": 42, "real": 42.25, "string": "xxx", @@ -488,225 +150,4 @@ mod tests { ); Ok(()) } - - #[actix_web::test] - async fn test_mssql_types() -> anyhow::Result<()> { - let Some(db_url) = db_specific_test("mssql") else { - return Ok(()); - }; - let mut c = sqlx::AnyConnection::connect(&db_url).await?; - let row = sqlx::query( - "SELECT - CAST(1 AS BIT) as true_bit, - CAST(0 AS BIT) as false_bit, - CAST(NULL AS BIT) as null_bit, - CAST(255 AS TINYINT) as tiny_int, - CAST(42 AS SMALLINT) as small_int, - CAST(42 AS INT) as integer, - CAST(42 AS BIGINT) as big_int, - CAST(42.25 AS REAL) as real, - CAST(42.25 AS FLOAT) as float, - CAST(42.25 AS DECIMAL(10,2)) as decimal, - CAST('2024-03-14' AS DATE) as date, - CAST('13:14:15' AS TIME) as time, - CAST('2024-03-14 13:14:15' AS DATETIME) as datetime, - CAST('2024-03-14 13:14:15' AS DATETIME2) as datetime2, - CAST('2024-03-14 13:14:15 +02:00' AS DATETIMEOFFSET) as datetimeoffset, - N'Unicode String' as nvarchar, - 'ASCII String' as varchar, - CAST(1234.56 AS MONEY) as money_val, - CAST(12.34 AS SMALLMONEY) as small_money_val, - CAST(0x6D7373716C AS VARBINARY(10)) as blob_data, - CONVERT(UNIQUEIDENTIFIER, '6F9619FF-8B86-D011-B42D-00C04FC964FF') as unique_identifier - " - ) - .fetch_one(&mut c) - .await?; - - expect_json_object_equal( - &row_to_json(&row), - &serde_json::json!({ - "true_bit": true, - "false_bit": false, - "null_bit": null, - "tiny_int": 255, - "small_int": 42, - "integer": 42, - "big_int": 42, - "real": 42.25, - "float": 42.25, - "decimal": 42.25, - "date": "2024-03-14", - "time": "13:14:15", - "datetime": "2024-03-14T13:14:15", - "datetime2": "2024-03-14T13:14:15", - "datetimeoffset": "2024-03-14T13:14:15+02:00", - "nvarchar": "Unicode String", - "varchar": "ASCII String", - "money_val": 1234.56, - "small_money_val": 12.34, - "blob_data": "data:application/octet-stream;base64,bXNzcWw=", - "unique_identifier": "6f9619ff-8b86-d011-b42d-00c04fc964ff" - }), - ); - Ok(()) - } - - fn expect_json_object_equal(actual: &Value, expected: &Value) { - use std::fmt::Write; - - if json_values_equal(actual, expected) { - return; - } - let actual = actual.as_object().unwrap(); - let expected = expected.as_object().unwrap(); - - let all_keys: std::collections::BTreeSet<_> = - actual.keys().chain(expected.keys()).collect(); - let max_key_len = all_keys.iter().map(|k| k.len()).max().unwrap_or(0); - - let mut comparison_string = String::new(); - for key in all_keys { - let actual_value = actual.get(key).unwrap_or(&Value::Null); - let expected_value = expected.get(key).unwrap_or(&Value::Null); - if json_values_equal(actual_value, expected_value) { - continue; - } - writeln!( - &mut comparison_string, - "{key: anyhow::Result<()> { - let db_url = test_database_url(); - let mut c = sqlx::AnyConnection::connect(&db_url).await?; - - // Test various column name formats to ensure canonical_col_name works correctly - let row = sqlx::query( - r#"SELECT - 42 as "UPPERCASE_COL", - 42 as "lowercase_col", - 42 as "Mixed_Case_Col", - 42 as "COL_WITH_123_NUMBERS", - 42 as "col-with-dashes", - 42 as "col with spaces", - 42 as "_UNDERSCORE_PREFIX", - 42 as "123_NUMBER_PREFIX" - "#, - ) - .fetch_one(&mut c) - .await?; - - let json_result = row_to_json(&row); - - // For ODBC databases, uppercase columns should be converted to lowercase - // For other databases, names should remain as-is - let expected_json = if c.kind() == sqlx::any::AnyKind::Odbc { - // ODBC database - uppercase should be converted to lowercase - serde_json::json!({ - "uppercase_col": 42, - "lowercase_col": 42, - "Mixed_Case_Col": 42, - "COL_WITH_123_NUMBERS": 42, - "col-with-dashes": 42, - "col with spaces": 42, - "_underscore_prefix": 42, - "123_NUMBER_PREFIX": 42 - }) - } else { - // Non-ODBC database - names remain as-is - serde_json::json!({ - "UPPERCASE_COL": 42, - "lowercase_col": 42, - "Mixed_Case_Col": 42, - "COL_WITH_123_NUMBERS": 42, - "col-with-dashes": 42, - "col with spaces": 42, - "_UNDERSCORE_PREFIX": 42, - "123_NUMBER_PREFIX": 42 - }) - }; - - expect_json_object_equal(&json_result, &expected_json); - - Ok(()) - } - - #[actix_web::test] - async fn test_row_to_json_edge_cases() -> anyhow::Result<()> { - let db_url = test_database_url(); - let mut c = sqlx::AnyConnection::connect(&db_url).await?; - let dbms_name = c.dbms_name().await.expect("retrieve db name"); - - // Test edge cases for row_to_json - let row = sqlx::query( - "SELECT - NULL as null_col, - '' as empty_string, - 0 as zero_value, - -42 as negative_int, - 1.23456 as my_float, - 'special_chars_!@#$%^&*()' as special_chars, - 'line1 -line2' as multiline_string - ", - ) - .fetch_one(&mut c) - .await?; - - let json_result = row_to_json(&row); - - // For Oracle databases, empty string is treated as NULL. - let empty_str_is_null = dbms_name.to_lowercase().contains("oracle"); - - let expected_json = serde_json::json!({ - "null_col": null, - "empty_string": if empty_str_is_null { serde_json::Value::Null } else { serde_json::Value::String(String::new()) }, - "zero_value": 0, - "negative_int": -42, - "my_float": 1.23456, - "special_chars": "special_chars_!@#$%^&*()", - "multiline_string": "line1\nline2" - }); - - expect_json_object_equal(&json_result, &expected_json); - - Ok(()) - } - - /// Compare JSON values, treating integers and floats that are numerically equal as equal - fn json_values_equal(a: &Value, b: &Value) -> bool { - use Value::*; - - match (a, b) { - (Null, Null) => true, - (Bool(a), Bool(b)) => a == b, - (Number(a), Number(b)) => { - // Treat integers and floats as equal if they represent the same numerical value - a.as_f64() == b.as_f64() - } - (String(a), String(b)) => a == b, - (Array(a), Array(b)) => { - a.len() == b.len() && a.iter().zip(b.iter()).all(|(a, b)| json_values_equal(a, b)) - } - (Object(a), Object(b)) => { - if a.len() != b.len() { - return false; - } - a.iter().all(|(key, value)| { - b.get(key) - .is_some_and(|expected_value| json_values_equal(value, expected_value)) - }) - } - _ => false, - } - } } diff --git a/src/webserver/database/sqlpage_functions/function_definition_macro.rs b/src/webserver/database/sqlpage_functions/function_definition_macro.rs index 235b7ffa..f9696597 100644 --- a/src/webserver/database/sqlpage_functions/function_definition_macro.rs +++ b/src/webserver/database/sqlpage_functions/function_definition_macro.rs @@ -56,7 +56,7 @@ macro_rules! sqlpage_functions { &self, #[allow(unused_variables)] request: &'a $crate::webserver::http_request_info::ExecutionContext, - db_connection: &mut Option>, + db_connection: &mut $crate::webserver::database::execute_queries::DbConn, params: Vec>> ) -> anyhow::Result>> { use $crate::webserver::database::sqlpage_functions::function_traits::*; From 8500e2192202521e5c6eb612f4ef86cb8fa640f9 Mon Sep 17 00:00:00 2001 From: Ophir LOJKINE Date: Sun, 31 May 2026 18:03:19 +0200 Subject: [PATCH 3/9] Restore compatibility on native sqlx drivers --- Cargo.lock | 1 + Cargo.toml | 1 + src/filesystem.rs | 219 +++++++++--------- src/lib.rs | 5 + src/webserver/database/connect.rs | 173 ++++++++++++-- src/webserver/database/csv_import.rs | 2 - src/webserver/database/execute_queries.rs | 7 +- src/webserver/database/migrations.rs | 20 +- src/webserver/database/mod.rs | 40 +++- src/webserver/database/sql.rs | 2 + .../database/sql/parameter_extraction.rs | 5 +- src/webserver/database/sql_to_json.rs | 102 +++++++- tests/core/mod.rs | 94 ++++++-- 13 files changed, 509 insertions(+), 162 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 72aca6f7..99723e84 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4533,6 +4533,7 @@ dependencies = [ "include_dir", "lambda-web", "libflate", + "libsqlite3-sys", "log", "markdown", "mime_guess", diff --git a/Cargo.toml b/Cargo.toml index 01246e5e..9cb29922 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ sqlx = { version = "0.9.0", default-features = false, features = [ ] } sqlx-sqlserver = { version = "0.0.2", features = ["migrate"] } sqlx-odbc = { version = "0.0.1", features = ["runtime-tokio"] } +libsqlite3-sys = "0.37" chrono = "0.4.23" actix-web = { version = "4", features = ["rustls-0_23", "cookies"] } percent-encoding = "2.2.0" diff --git a/src/filesystem.rs b/src/filesystem.rs index e16b4b63..87637aff 100644 --- a/src/filesystem.rs +++ b/src/filesystem.rs @@ -266,9 +266,9 @@ impl DbFsQueries { log::debug!("Initializing database filesystem queries"); Self::check_table_available(db).await?; Ok(Self { - was_modified: Self::make_was_modified_query(db).await?, - read_file: Self::make_read_file_query(db).await?, - exists: Self::make_exists_query(db).await?, + was_modified: Self::make_was_modified_query(db), + read_file: Self::make_read_file_query(db), + exists: Self::make_exists_query(db), }) } @@ -279,31 +279,31 @@ impl DbFsQueries { Ok(()) } - async fn make_was_modified_query(db: &Database) -> anyhow::Result { + fn make_was_modified_query(db: &Database) -> String { let was_modified_query = format!( "SELECT 1 from sqlpage_files WHERE last_modified >= {} AND path = {}", make_placeholder(db.info.kind, 1), make_placeholder(db.info.kind, 2) ); log::debug!("Preparing the database filesystem was_modified_query: {was_modified_query}"); - Ok(was_modified_query) + was_modified_query } - async fn make_read_file_query(db: &Database) -> anyhow::Result { + fn make_read_file_query(db: &Database) -> String { let read_file_query = format!( "SELECT contents from sqlpage_files WHERE path = {}", make_placeholder(db.info.kind, 1), ); log::debug!("Preparing the database filesystem read_file_query: {read_file_query}"); - Ok(read_file_query) + read_file_query } - async fn make_exists_query(db: &Database) -> anyhow::Result { + fn make_exists_query(db: &Database) -> String { let exists_query = format!( "SELECT 1 from sqlpage_files WHERE path = {}", make_placeholder(db.info.kind, 1), ); - Ok(exists_query) + exists_query } async fn file_modified_since_in_db( @@ -329,11 +329,10 @@ impl DbFsQueries { &path, ) .await - .with_context(|| format!("Unable to check when {path} was last modified in the database"))?; - log::trace!( - "DB File {} was modified result: {was_modified_i32:?}", - path - ); + .with_context(|| { + format!("Unable to check when {path} was last modified in the database") + })?; + log::trace!("DB File {path} was modified result: {was_modified_i32:?}"); Ok(was_modified_i32 == Some(1)) } @@ -368,12 +367,9 @@ impl DbFsQueries { ); let result = fetch_optional_i32_path(&app_state.db.connection, &self.exists, &path).await; log::debug!("DB File exists result: {result:?}"); - result.map(|result| result.is_some()).with_context(|| { - format!( - "Unable to check if {} exists in the database", - path - ) - }) + result + .map(|result| result.is_some()) + .with_context(|| format!("Unable to check if {path} exists in the database")) } } @@ -408,48 +404,49 @@ async fn fetch_optional_i32_since_path( since: DateTime, path: &str, ) -> sqlx::Result> { - let since = since.to_rfc3339(); + let since_text = since.to_rfc3339(); + let sqlite_since_text = since.format("%Y-%m-%d %H:%M:%S%.f").to_string(); match pool { DatabasePool::Sqlite(pool) => { - sqlx::query_as::(sqlx::AssertSqlSafe(sql)) + { sqlx::query_as::(sqlx::AssertSqlSafe(sql)) } + .bind(&sqlite_since_text) + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)) } - .bind(&since) - .bind(path) - .fetch_optional(pool) - .await - .map(|row| row.map(|(value,)| value)), DatabasePool::Postgres(pool) => { - sqlx::query_as::(sqlx::AssertSqlSafe(sql)) + { sqlx::query_as::(sqlx::AssertSqlSafe(sql)) } + .bind(since) + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)) } - .bind(&since) - .bind(path) - .fetch_optional(pool) - .await - .map(|row| row.map(|(value,)| value)), DatabasePool::MySql(pool) => { - sqlx::query_as::(sqlx::AssertSqlSafe(sql)) + { sqlx::query_as::(sqlx::AssertSqlSafe(sql)) } + .bind(since) + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)) } - .bind(&since) - .bind(path) - .fetch_optional(pool) - .await - .map(|row| row.map(|(value,)| value)), DatabasePool::Mssql(pool) => { - sqlx::query_as::(sqlx::AssertSqlSafe(sql)) + { sqlx::query_as::(sqlx::AssertSqlSafe(sql)) } + .bind(&since_text) + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)) } - .bind(&since) - .bind(path) - .fetch_optional(pool) - .await - .map(|row| row.map(|(value,)| value)), DatabasePool::Odbc(pool) => { - sqlx::query_as::(sqlx::AssertSqlSafe(sql)) + { sqlx::query_as::(sqlx::AssertSqlSafe(sql)) } + .bind(&since_text) + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)) } - .bind(&since) - .bind(path) - .fetch_optional(pool) - .await - .map(|row| row.map(|(value,)| value)), } } @@ -460,40 +457,40 @@ async fn fetch_optional_i32_path( ) -> sqlx::Result> { match pool { DatabasePool::Sqlite(pool) => { - sqlx::query_as::(sqlx::AssertSqlSafe(sql)) + { sqlx::query_as::(sqlx::AssertSqlSafe(sql)) } + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)) } - .bind(path) - .fetch_optional(pool) - .await - .map(|row| row.map(|(value,)| value)), DatabasePool::Postgres(pool) => { - sqlx::query_as::(sqlx::AssertSqlSafe(sql)) + { sqlx::query_as::(sqlx::AssertSqlSafe(sql)) } + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)) } - .bind(path) - .fetch_optional(pool) - .await - .map(|row| row.map(|(value,)| value)), DatabasePool::MySql(pool) => { - sqlx::query_as::(sqlx::AssertSqlSafe(sql)) + { sqlx::query_as::(sqlx::AssertSqlSafe(sql)) } + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)) } - .bind(path) - .fetch_optional(pool) - .await - .map(|row| row.map(|(value,)| value)), DatabasePool::Mssql(pool) => { - sqlx::query_as::(sqlx::AssertSqlSafe(sql)) + { sqlx::query_as::(sqlx::AssertSqlSafe(sql)) } + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)) } - .bind(path) - .fetch_optional(pool) - .await - .map(|row| row.map(|(value,)| value)), DatabasePool::Odbc(pool) => { - sqlx::query_as::(sqlx::AssertSqlSafe(sql)) + { sqlx::query_as::(sqlx::AssertSqlSafe(sql)) } + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)) } - .bind(path) - .fetch_optional(pool) - .await - .map(|row| row.map(|(value,)| value)), } } @@ -504,47 +501,46 @@ async fn fetch_optional_bytes_path( ) -> sqlx::Result>> { match pool { DatabasePool::Sqlite(pool) => { - sqlx::query_as::,)>(sqlx::AssertSqlSafe(sql)) + { sqlx::query_as::,)>(sqlx::AssertSqlSafe(sql)) } + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)) } - .bind(path) - .fetch_optional(pool) - .await - .map(|row| row.map(|(value,)| value)), DatabasePool::Postgres(pool) => { - sqlx::query_as::,)>(sqlx::AssertSqlSafe(sql)) + { sqlx::query_as::,)>(sqlx::AssertSqlSafe(sql)) } + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)) } - .bind(path) - .fetch_optional(pool) - .await - .map(|row| row.map(|(value,)| value)), DatabasePool::MySql(pool) => { - sqlx::query_as::,)>(sqlx::AssertSqlSafe(sql)) + { sqlx::query_as::,)>(sqlx::AssertSqlSafe(sql)) } + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)) } - .bind(path) - .fetch_optional(pool) - .await - .map(|row| row.map(|(value,)| value)), DatabasePool::Mssql(pool) => { - sqlx::query_as::,)>(sqlx::AssertSqlSafe(sql)) + { sqlx::query_as::,)>(sqlx::AssertSqlSafe(sql)) } + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)) } - .bind(path) - .fetch_optional(pool) - .await - .map(|row| row.map(|(value,)| value)), DatabasePool::Odbc(pool) => { - sqlx::query_as::,)>(sqlx::AssertSqlSafe(sql)) + { sqlx::query_as::,)>(sqlx::AssertSqlSafe(sql)) } + .bind(path) + .fetch_optional(pool) + .await + .map(|row| row.map(|(value,)| value)) } - .bind(path) - .fetch_optional(pool) - .await - .map(|row| row.map(|(value,)| value)), } } #[actix_web::test] async fn test_sql_file_read_utf8() -> anyhow::Result<()> { use crate::app_config; - use sqlx::Executor; let config = app_config::tests::test_config(); let state = AppState::init(&config).await?; @@ -570,8 +566,13 @@ async fn test_sql_file_read_utf8() -> anyhow::Result<()> { make_placeholder(dbms, 1), make_placeholder(dbms, 2) ); - insert_test_file(conn, &insert_sql, "unit test file.txt", "Héllö world! 😀".as_bytes()) - .await?; + insert_test_file( + conn, + &insert_sql, + "unit test file.txt", + "Héllö world! 😀".as_bytes(), + ) + .await?; let fs = FileSystem::init("/", db).await; let actual = fs @@ -612,31 +613,31 @@ async fn insert_test_file( contents: &[u8], ) -> sqlx::Result<()> { match pool { - DatabasePool::Sqlite(pool) => sqlx::query::(sql) + DatabasePool::Sqlite(pool) => sqlx::query::(sqlx::AssertSqlSafe(sql)) .bind(path) .bind(contents) .execute(pool) .await .map(|_| ()), - DatabasePool::Postgres(pool) => sqlx::query::(sql) + DatabasePool::Postgres(pool) => sqlx::query::(sqlx::AssertSqlSafe(sql)) .bind(path) .bind(contents) .execute(pool) .await .map(|_| ()), - DatabasePool::MySql(pool) => sqlx::query::(sql) + DatabasePool::MySql(pool) => sqlx::query::(sqlx::AssertSqlSafe(sql)) .bind(path) .bind(contents) .execute(pool) .await .map(|_| ()), - DatabasePool::Mssql(pool) => sqlx::query::(sql) + DatabasePool::Mssql(pool) => sqlx::query::(sqlx::AssertSqlSafe(sql)) .bind(path) .bind(contents) .execute(pool) .await .map(|_| ()), - DatabasePool::Odbc(pool) => sqlx::query::(sql) + DatabasePool::Odbc(pool) => sqlx::query::(sqlx::AssertSqlSafe(sql)) .bind(path) .bind(contents) .execute(pool) diff --git a/src/lib.rs b/src/lib.rs index d0c13b3f..dd726e89 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -119,6 +119,7 @@ impl AppState { Self::init_with_db(config, db).await } pub async fn init_with_db(config: &AppConfig, db: Database) -> anyhow::Result { + install_default_rustls_provider(); let all_templates = AllTemplates::init(config)?; let mut sql_file_cache = FileCache::new(); let file_system = FileSystem::init(&config.web_root, &db).await; @@ -151,6 +152,10 @@ impl AppState { } } +fn install_default_rustls_provider() { + let _ = rustls::crypto::ring::default_provider().install_default(); +} + impl std::fmt::Debug for AppState { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("AppState").finish() diff --git a/src/webserver/database/connect.rs b/src/webserver/database/connect.rs index 801d63c2..d0d070bf 100644 --- a/src/webserver/database/connect.rs +++ b/src/webserver/database/connect.rs @@ -1,17 +1,12 @@ -use std::time::Duration; +use std::{ffi::CString, time::Duration}; use super::{Database, DatabasePool, DbInfo, DbKind, SupportedDatabase}; -use crate::{ - ON_CONNECT_FILE, ON_RESET_FILE, - app_config::AppConfig, -}; +use crate::{ON_CONNECT_FILE, ON_RESET_FILE, app_config::AppConfig}; use anyhow::Context; use futures_util::future::BoxFuture; use sqlx::{ ColumnIndex, ConnectOptions, Connection, Database as SqlxDatabase, Decode, Executor, Row, Type, - mysql::MySqlConnectOptions, - pool::PoolOptions, - postgres::PgConnectOptions, + mysql::MySqlConnectOptions, pool::PoolOptions, postgres::PgConnectOptions, sqlite::SqliteConnectOptions, }; use sqlx_odbc::{OdbcConnectOptions, OdbcConnection}; @@ -19,17 +14,21 @@ use sqlx_sqlserver::MssqlConnectOptions; use url::Url; impl Database { + #[allow(clippy::too_many_lines)] pub async fn init(config: &AppConfig) -> anyhow::Result { let database_url = database_url_with_password(config)?; let db_kind = DbKind::from_database_url(&database_url); - log::debug!("Connecting to a {db_kind:?} database on {}", config.database_url); + log::debug!( + "Connecting to a {db_kind:?} database on {}", + config.database_url + ); let connection = match db_kind { DbKind::Sqlite => { let mut options = database_url.parse::()?; options = set_common_connect_options(options); options = set_custom_connect_options_sqlite(options, config); - let pool = Self::create_pool_options::(config, db_kind) + let pool = Self::create_sqlite_pool_options(config, db_kind) .connect_with(options) .await .with_context(|| { @@ -141,6 +140,21 @@ impl Database { let pool_options = add_on_return_to_pool(config, pool_options); add_on_connection_handler(config, pool_options) } + + fn create_sqlite_pool_options(config: &AppConfig, kind: DbKind) -> PoolOptions { + let max_connections = config + .max_database_pool_connections + .unwrap_or_else(|| default_max_connections(config, kind)); + let pool_options = PoolOptions::new() + .max_connections(max_connections) + .idle_timeout(config.database_connection_idle_timeout) + .max_lifetime(config.database_connection_max_lifetime) + .acquire_timeout(Duration::from_secs_f64( + config.database_connection_acquire_timeout_seconds, + )); + let pool_options = add_on_return_to_pool(config, pool_options); + add_sqlite_on_connection_handler(config, pool_options) + } } fn default_max_connections(config: &AppConfig, kind: DbKind) -> u32 { @@ -167,10 +181,7 @@ where .log_slow_statements(log::LevelFilter::Warn, Duration::from_millis(250)) } -fn add_on_return_to_pool( - config: &AppConfig, - pool_options: PoolOptions, -) -> PoolOptions +fn add_on_return_to_pool(config: &AppConfig, pool_options: PoolOptions) -> PoolOptions where DB: SqlxDatabase, for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, @@ -204,7 +215,9 @@ where { Box::pin(async move { log::trace!("Running the custom SQL connection cleanup handler. {meta:?}"); - let query_result = conn.fetch_optional(sqlx::AssertSqlSafe(sql.as_str())).await?; + let query_result = conn + .fetch_optional(sqlx::AssertSqlSafe(sql.as_str())) + .await?; if let Some(query_result) = query_result { let is_healthy = query_result.try_get::(0); log::debug!("Is the connection healthy? {is_healthy:?}"); @@ -245,6 +258,32 @@ where }) } +fn add_sqlite_on_connection_handler( + config: &AppConfig, + pool_options: PoolOptions, +) -> PoolOptions { + let sql = read_optional_handler_sql(config, ON_CONNECT_FILE, "database connection"); + let on_connect_file_display = config + .configuration_directory + .join(ON_CONNECT_FILE) + .display() + .to_string(); + + pool_options.after_connect(move |conn, _| { + let sql = sql.clone(); + let on_connect_file_display = on_connect_file_display.clone(); + Box::pin(async move { + install_sqlite_unicode_functions(conn).await?; + if let Some(sql) = sql { + log::debug!("Running {on_connect_file_display} on new connection"); + conn.execute(sqlx::AssertSqlSafe(sql.as_str())).await?; + log::debug!("Finished running connection handler on new connection"); + } + Ok(()) + }) + }) +} + fn read_optional_handler_sql( config: &AppConfig, file_name: &str, @@ -288,6 +327,108 @@ fn set_custom_connect_options_sqlite( sqlite_options.collation("NOCASE", |a, b| a.to_lowercase().cmp(&b.to_lowercase())) } +async fn install_sqlite_unicode_functions(conn: &mut sqlx::SqliteConnection) -> sqlx::Result<()> { + let mut handle = conn.lock_handle().await?; + let sqlite = handle.as_raw_handle().as_ptr(); + register_sqlite_function(sqlite, b"upper\0", sqlite_upper)?; + register_sqlite_function(sqlite, b"lower\0", sqlite_lower)?; + Ok(()) +} + +fn register_sqlite_function( + sqlite: *mut libsqlite3_sys::sqlite3, + name: &'static [u8], + function: unsafe extern "C" fn( + *mut libsqlite3_sys::sqlite3_context, + i32, + *mut *mut libsqlite3_sys::sqlite3_value, + ), +) -> sqlx::Result<()> { + let result = unsafe { + libsqlite3_sys::sqlite3_create_function_v2( + sqlite, + name.as_ptr().cast(), + 1, + libsqlite3_sys::SQLITE_UTF8 | libsqlite3_sys::SQLITE_DETERMINISTIC, + std::ptr::null_mut(), + Some(function), + None, + None, + None, + ) + }; + if result == libsqlite3_sys::SQLITE_OK { + Ok(()) + } else { + Err(sqlx::Error::Protocol(format!( + "sqlite3_create_function_v2 failed with code {result}" + ))) + } +} + +unsafe extern "C" fn sqlite_upper( + ctx: *mut libsqlite3_sys::sqlite3_context, + n_arg: i32, + args: *mut *mut libsqlite3_sys::sqlite3_value, +) { + sqlite_case_function(ctx, n_arg, args, str::to_uppercase); +} + +unsafe extern "C" fn sqlite_lower( + ctx: *mut libsqlite3_sys::sqlite3_context, + n_arg: i32, + args: *mut *mut libsqlite3_sys::sqlite3_value, +) { + sqlite_case_function(ctx, n_arg, args, str::to_lowercase); +} + +fn sqlite_case_function( + ctx: *mut libsqlite3_sys::sqlite3_context, + n_arg: i32, + args: *mut *mut libsqlite3_sys::sqlite3_value, + f: fn(&str) -> String, +) { + if n_arg != 1 { + unsafe { + libsqlite3_sys::sqlite3_result_error_code( + ctx, + libsqlite3_sys::SQLITE_CONSTRAINT_FUNCTION, + ); + } + return; + } + let arg = unsafe { *args }; + let Some(input) = sqlite_value_text(arg) else { + unsafe { libsqlite3_sys::sqlite3_result_null(ctx) }; + return; + }; + let output = f(&input); + match CString::new(output) { + Ok(output) => unsafe { + libsqlite3_sys::sqlite3_result_text( + ctx, + output.as_ptr(), + -1, + libsqlite3_sys::SQLITE_TRANSIENT(), + ); + }, + Err(_) => unsafe { + libsqlite3_sys::sqlite3_result_error_code(ctx, libsqlite3_sys::SQLITE_CONSTRAINT); + }, + } +} + +fn sqlite_value_text(value: *mut libsqlite3_sys::sqlite3_value) -> Option { + let text = unsafe { libsqlite3_sys::sqlite3_value_text(value) }; + if text.is_null() { + return None; + } + let len = unsafe { libsqlite3_sys::sqlite3_value_bytes(value) }; + let len = usize::try_from(len).ok()?; + let bytes = unsafe { std::slice::from_raw_parts(text.cast::(), len) }; + Some(String::from_utf8_lossy(bytes).into_owned()) +} + fn set_custom_connect_options_odbc(odbc_options: &mut OdbcConnectOptions, config: &AppConfig) { let batch_size = config.max_pending_rows.clamp(1, 1024); odbc_options.batch_size(batch_size); @@ -332,6 +473,6 @@ fn database_url_with_password(config: &AppConfig) -> anyhow::Result { let mut url = Url::parse(&config.database_url) .with_context(|| format!("Unable to parse {}", config.database_url))?; url.set_password(Some(password)) - .map_err(|_| anyhow::anyhow!("Unable to set password in database URL"))?; + .map_err(|()| anyhow::anyhow!("Unable to set password in database URL"))?; Ok(url.to_string()) } diff --git a/src/webserver/database/csv_import.rs b/src/webserver/database/csv_import.rs index 71ecae25..e552b4ea 100644 --- a/src/webserver/database/csv_import.rs +++ b/src/webserver/database/csv_import.rs @@ -366,8 +366,6 @@ fn test_make_statement() { #[actix_web::test] async fn test_end_to_end() { - use sqlx::Connection; - let mut copy_stmt = sqlparser::parser::Parser::parse_sql( &sqlparser::dialect::GenericDialect {}, "COPY my_table (col1, col2) FROM 'my_file.csv' (DELIMITER ';', HEADER)", diff --git a/src/webserver/database/execute_queries.rs b/src/webserver/database/execute_queries.rs index ebbd3be8..20d48801 100644 --- a/src/webserver/database/execute_queries.rs +++ b/src/webserver/database/execute_queries.rs @@ -339,9 +339,10 @@ where returned_rows += 1; } apply_json_columns(&mut query_result, &stmt.json_columns); - if let Err(err) = apply_delayed_functions(request, &stmt.delayed_functions, &mut query_result) - .instrument(query_span.clone()) - .await + if let Err(err) = + apply_delayed_functions(request, &stmt.delayed_functions, &mut query_result) + .instrument(query_span.clone()) + .await { error = Some(err); break; diff --git a/src/webserver/database/migrations.rs b/src/webserver/database/migrations.rs index 0d67ed66..186eab0e 100644 --- a/src/webserver/database/migrations.rs +++ b/src/webserver/database/migrations.rs @@ -39,15 +39,17 @@ pub async fn apply(config: &crate::app_config::AppConfig, db: &Database) -> anyh "ODBC migrations are not supported by sqlx-odbc. Apply the migrations manually or use a native SQLPage backend for managed migrations." ); } - run_migrator(&migrator, &db.connection).await.map_err(|err| { - match err { - MigrateError::Execute(source) => anyhow::Error::new(source), - source => anyhow::Error::new(source), - } - .context(format!( - "Failed to apply database migrations from {MIGRATIONS_DIR:?}" - )) - })?; + run_migrator(&migrator, &db.connection) + .await + .map_err(|err| { + match err { + MigrateError::Execute(source) => anyhow::Error::new(source), + source => anyhow::Error::new(source), + } + .context(format!( + "Failed to apply database migrations from {MIGRATIONS_DIR:?}" + )) + })?; Ok(()) } diff --git a/src/webserver/database/mod.rs b/src/webserver/database/mod.rs index 922663f1..48261bc1 100644 --- a/src/webserver/database/mod.rs +++ b/src/webserver/database/mod.rs @@ -116,6 +116,44 @@ impl DatabasePool { Self::Odbc(pool) => pool.close().await, } } + + pub async fn execute(&self, sql: &str) -> sqlx::Result<()> { + match self { + Self::Sqlite(pool) => sqlx::query::(sqlx::AssertSqlSafe(sql)) + .execute(pool) + .await + .map(|_| ()), + Self::Postgres(pool) => sqlx::query::(sqlx::AssertSqlSafe(sql)) + .execute(pool) + .await + .map(|_| ()), + Self::MySql(pool) => sqlx::query::(sqlx::AssertSqlSafe(sql)) + .execute(pool) + .await + .map(|_| ()), + Self::Mssql(pool) => sqlx::query::(sqlx::AssertSqlSafe(sql)) + .execute(pool) + .await + .map(|_| ()), + Self::Odbc(pool) => sqlx::query::(sqlx::AssertSqlSafe(sql)) + .execute(pool) + .await + .map(|_| ()), + } + } + + pub async fn acquire( + &self, + ) -> sqlx::Result { + use crate::webserver::database::execute_queries::DbConnection; + match self { + Self::Sqlite(pool) => pool.acquire().await.map(DbConnection::Sqlite), + Self::Postgres(pool) => pool.acquire().await.map(DbConnection::Postgres), + Self::MySql(pool) => pool.acquire().await.map(DbConnection::MySql), + Self::Mssql(pool) => pool.acquire().await.map(DbConnection::Mssql), + Self::Odbc(pool) => pool.acquire().await.map(DbConnection::Odbc), + } + } } /// Supported database types in `SQLPage`. Represents an actual DBMS, not a sqlx backend kind (like "Odbc") @@ -194,7 +232,7 @@ pub struct DbInfo { pub dbms_name: String, /// The actual database we are connected to. Can be "Generic" when using an unknown ODBC driver pub database_type: SupportedDatabase, - /// The SQLPage backend we are using. Can be "Odbc", in which case we need to use `database_type` to know what database we are actually using. + /// The `SQLPage` backend we are using. Can be "Odbc", in which case we need to use `database_type` to know what database we are actually using. pub kind: DbKind, } diff --git a/src/webserver/database/sql.rs b/src/webserver/database/sql.rs index c246e668..2b5759b4 100644 --- a/src/webserver/database/sql.rs +++ b/src/webserver/database/sql.rs @@ -715,6 +715,8 @@ fn expr_to_statement(expr: Expr) -> Statement { #[cfg(test)] mod test { + use crate::webserver::database::DbKind; + use super::super::sqlpage_functions::functions::SqlPageFunctionName; use super::super::syntax_tree::SqlPageFunctionCall; diff --git a/src/webserver/database/sql/parameter_extraction.rs b/src/webserver/database/sql/parameter_extraction.rs index d1a8ce1b..c5a48ab7 100644 --- a/src/webserver/database/sql/parameter_extraction.rs +++ b/src/webserver/database/sql/parameter_extraction.rs @@ -38,10 +38,7 @@ pub(crate) const DB_PLACEHOLDERS: [(DbKind, DbPlaceHolder); 5] = [ DbKind::Mssql, DbPlaceHolder::PrefixedNumber { prefix: "@p" }, ), - ( - DbKind::Odbc, - DbPlaceHolder::Positional { placeholder: "?" }, - ), + (DbKind::Odbc, DbPlaceHolder::Positional { placeholder: "?" }), ]; /// For positional parameters, we use a temporary placeholder during parameter extraction, diff --git a/src/webserver/database/sql_to_json.rs b/src/webserver/database/sql_to_json.rs index f1e8aafd..50bd95c6 100644 --- a/src/webserver/database/sql_to_json.rs +++ b/src/webserver/database/sql_to_json.rs @@ -1,6 +1,9 @@ use crate::utils::add_value_to_map; use crate::webserver::database::blob_to_data_url; +use bigdecimal::BigDecimal; +use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc}; use serde_json::{Map, Value}; +use sqlx::postgres::types::PgRange; use sqlx::{Column, ColumnIndex, Row, TypeInfo, ValueRef}; pub trait SqlPageRow { @@ -41,12 +44,32 @@ macro_rules! impl_sqlpage_row { }; } -impl_sqlpage_row!(sqlx::postgres::PgRow, sqlx::Postgres, false); impl_sqlpage_row!(sqlx::mysql::MySqlRow, sqlx::MySql, false); impl_sqlpage_row!(sqlx::sqlite::SqliteRow, sqlx::Sqlite, false); impl_sqlpage_row!(sqlx_sqlserver::MssqlRow, sqlx_sqlserver::Mssql, false); impl_sqlpage_row!(sqlx_odbc::OdbcRow, sqlx_odbc::Odbc, true); +impl SqlPageRow for sqlx::postgres::PgRow { + fn to_json(&self) -> Value { + let mut map = Map::new(); + for col in self.columns() { + let key = canonical_col_name(col.name(), false); + let value = pg_to_json(self, col.ordinal()); + map = add_value_to_map(map, (key, value)); + } + Value::Object(map) + } + + fn first_value_to_string(&self) -> Option { + let col = self.columns().first()?; + match pg_to_json(self, col.ordinal()) { + Value::String(s) => Some(s), + Value::Null => None, + other => Some(other.to_string()), + } + } +} + fn canonical_col_name(name: &str, canonicalize_uppercase: bool) -> String { if canonicalize_uppercase && name @@ -104,6 +127,83 @@ where } } +fn pg_to_json(row: &sqlx::postgres::PgRow, ordinal: usize) -> Value { + let raw_value = match row.try_get_raw(ordinal) { + Ok(raw_value) if raw_value.is_null() => return Value::Null, + Ok(raw_value) => raw_value, + Err(e) => { + log::warn!("Unable to extract value from row: {e:?}"); + return Value::Null; + } + }; + let type_info = raw_value.type_info(); + let type_name = type_info.name().to_ascii_uppercase(); + log::trace!("Decoding a PostgreSQL value of type {type_name:?} (type info: {type_info:?})"); + + match type_name.as_str() { + "BOOL" | "BOOLEAN" => decode::(row, ordinal).into(), + "INT2" | "SMALLINT" => decode::(row, ordinal).into(), + "INT" | "INT4" | "INTEGER" => decode::(row, ordinal).into(), + "INT8" | "BIGINT" | "SERIAL8" | "BIGSERIAL" => { + decode::(row, ordinal).into() + } + "REAL" | "FLOAT4" => decode::(row, ordinal).into(), + "FLOAT" | "FLOAT8" | "DOUBLE" => decode::(row, ordinal).into(), + "NUMERIC" | "DECIMAL" => { + decimal_to_json(&decode::(row, ordinal)) + } + "DATE" => decode::(row, ordinal) + .to_string() + .into(), + "TIME" | "TIMETZ" => decode::(row, ordinal) + .to_string() + .into(), + "TIMESTAMP" => decode::(row, ordinal) + .format("%FT%T%.f") + .to_string() + .into(), + "TIMESTAMPTZ" => decode::>(row, ordinal) + .to_rfc3339() + .into(), + "JSON" | "JSON[]" | "JSONB" | "JSONB[]" => decode::(row, ordinal), + "BYTEA" => blob_to_data_url::vec_to_data_uri_value(&decode::>( + row, ordinal, + )), + "UUID" => decode::(row, ordinal) + .to_string() + .into(), + "INT4RANGE" => decode_pg_range::(row, ordinal), + "INT8RANGE" => decode_pg_range::(row, ordinal), + "NUMRANGE" => decode_pg_range::(row, ordinal), + "DATERANGE" => decode_pg_range::(row, ordinal), + "TSRANGE" => decode_pg_range::(row, ordinal), + "TSTZRANGE" => decode_pg_range::>(row, ordinal), + _ => decode::(row, ordinal).into(), + } +} + +fn decimal_to_json(decimal: &BigDecimal) -> Value { + Value::Number(serde_json::Number::from_string_unchecked( + decimal.normalized().to_plain_string(), + )) +} + +fn decode_pg_range(row: &sqlx::postgres::PgRow, ordinal: usize) -> Value +where + T: std::fmt::Display + sqlx::Type, + for<'r> T: sqlx::Decode<'r, sqlx::Postgres>, + PgRange: sqlx::Type, + for<'r> PgRange: sqlx::Decode<'r, sqlx::Postgres>, +{ + match row.try_get::, _>(ordinal) { + Ok(pg_range) => pg_range.to_string().into(), + Err(e) => { + log::error!("Failed to decode postgres range value: {e}"); + Value::Null + } + } +} + fn decode(row: &R, ordinal: usize) -> T where DB: sqlx::Database, diff --git a/tests/core/mod.rs b/tests/core/mod.rs index 229aeac2..8b0aa4b9 100644 --- a/tests/core/mod.rs +++ b/tests/core/mod.rs @@ -1,6 +1,7 @@ use actix_web::{http::StatusCode, test}; use sqlpage::{ AppState, + webserver::database::{DatabasePool, execute_queries::DbConnection}, webserver::{self, make_placeholder}, }; use sqlx::Executor as _; @@ -66,11 +67,13 @@ async fn test_routing_with_db_fs() { "INSERT INTO sqlpage_files(path, contents) VALUES ('on_db.sql', {})", make_placeholder(state.db.info.kind, 1) ); - sqlx::query(&insert_sql) - .bind("select ''text'' as component, ''Hi from db !'' AS contents;".as_bytes()) - .execute(&state.db.connection) - .await - .unwrap(); + insert_db_file_contents( + &state.db.connection, + &insert_sql, + "select ''text'' as component, ''Hi from db !'' AS contents;".as_bytes(), + ) + .await + .unwrap(); let state = AppState::init(&config).await.unwrap(); let app_data = actix_web::web::Data::new(state); @@ -101,23 +104,29 @@ async fn test_non_unicode_static_path_returns_bad_request_with_db_fs() { let expected_db_path = "\u{FFFD}.txt"; let mut conn = state.db.connection.acquire().await.unwrap(); - (&mut *conn) - .execute(sqlpage::filesystem::DbFsQueries::get_create_table_sql( - sqlpage::webserver::database::SupportedDatabase::Sqlite, - )) - .await - .unwrap(); + if let DbConnection::Sqlite(conn) = &mut conn { + (&mut **conn) + .execute(sqlpage::filesystem::DbFsQueries::get_create_table_sql( + sqlpage::webserver::database::SupportedDatabase::Sqlite, + )) + .await + .unwrap(); + } else { + unreachable!("test uses sqlite"); + } let insert_sql = format!( "INSERT INTO sqlpage_files(path, contents) VALUES ({}, {})", make_placeholder(state.db.info.kind, 1), make_placeholder(state.db.info.kind, 2) ); - sqlx::query(&insert_sql) - .bind(expected_db_path) - .bind("file from db fs".as_bytes()) - .execute(&mut *conn) - .await - .unwrap(); + insert_db_file_path_contents( + &mut conn, + &insert_sql, + expected_db_path, + "file from db fs".as_bytes(), + ) + .await + .unwrap(); drop(conn); let state = AppState::init(&config).await.unwrap(); @@ -136,6 +145,57 @@ async fn test_non_unicode_static_path_returns_bad_request_with_db_fs() { ); } +async fn insert_db_file_contents( + pool: &DatabasePool, + sql: &str, + contents: &[u8], +) -> sqlx::Result<()> { + match pool { + DatabasePool::Sqlite(pool) => sqlx::query::(sqlx::AssertSqlSafe(sql)) + .bind(contents) + .execute(pool) + .await + .map(|_| ()), + DatabasePool::Postgres(pool) => sqlx::query::(sqlx::AssertSqlSafe(sql)) + .bind(contents) + .execute(pool) + .await + .map(|_| ()), + DatabasePool::MySql(pool) => sqlx::query::(sqlx::AssertSqlSafe(sql)) + .bind(contents) + .execute(pool) + .await + .map(|_| ()), + DatabasePool::Mssql(pool) => sqlx::query::(sqlx::AssertSqlSafe(sql)) + .bind(contents) + .execute(pool) + .await + .map(|_| ()), + DatabasePool::Odbc(pool) => sqlx::query::(sqlx::AssertSqlSafe(sql)) + .bind(contents) + .execute(pool) + .await + .map(|_| ()), + } +} + +async fn insert_db_file_path_contents( + conn: &mut DbConnection, + sql: &str, + path: &str, + contents: &[u8], +) -> sqlx::Result<()> { + match conn { + DbConnection::Sqlite(conn) => sqlx::query::(sqlx::AssertSqlSafe(sql)) + .bind(path) + .bind(contents) + .execute(&mut **conn) + .await + .map(|_| ()), + _ => unreachable!("test uses sqlite"), + } +} + #[actix_web::test] async fn test_routing_with_prefix() { let mut config = test_config(); From 7231527a540725ab03d128322e30ff3267ef92d2 Mon Sep 17 00:00:00 2001 From: Ophir LOJKINE Date: Sun, 31 May 2026 22:10:10 +0200 Subject: [PATCH 4/9] Use sqlx-sqlserver advanced type decoders --- Cargo.lock | 8 +- Cargo.toml | 7 +- .../survey.sql | 17 +- src/webserver/database/sql_to_json.rs | 83 +++++- tests/examples.rs | 265 ++++++++++++++++++ 5 files changed, 363 insertions(+), 17 deletions(-) create mode 100644 tests/examples.rs diff --git a/Cargo.lock b/Cargo.lock index 99723e84..26b02a66 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4789,20 +4789,24 @@ dependencies = [ [[package]] name = "sqlx-sqlserver" -version = "0.0.2" +version = "0.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "205de428fb54061bc058154abbff60e42876335e6d727e8aef77a70028e6386b" +checksum = "d07a49eceb2cfa3ab6ae42765cd54063201b4120667d6da9ff58ec5e2cbf3301" dependencies = [ + "bigdecimal", + "chrono", "futures-core", "futures-util", "log", "native-tls", + "num-bigint", "percent-encoding", "sqlx-core", "thiserror 2.0.18", "tokio", "tokio-native-tls", "url", + "uuid", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 9cb29922..95fbde27 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,7 +31,12 @@ sqlx = { version = "0.9.0", default-features = false, features = [ "json", "uuid", ] } -sqlx-sqlserver = { version = "0.0.2", features = ["migrate"] } +sqlx-sqlserver = { version = "0.0.3", features = [ + "bigdecimal", + "chrono", + "migrate", + "uuid", +] } sqlx-odbc = { version = "0.0.1", features = ["runtime-tokio"] } libsqlite3-sys = "0.37" chrono = "0.4.23" diff --git a/examples/microsoft sql server advanced forms/survey.sql b/examples/microsoft sql server advanced forms/survey.sql index e1406e10..acc3d560 100644 --- a/examples/microsoft sql server advanced forms/survey.sql +++ b/examples/microsoft sql server advanced forms/survey.sql @@ -5,23 +5,16 @@ FROM questions; -- Save all the answers to the database, whatever the number and id of the questions INSERT INTO survey_answers (question_id, answer) SELECT - question_id, - json_unquote( - json_extract( - sqlpage.variables('post'), - concat('$."', question_id, '"') - ) - ) -FROM json_table( - json_keys(sqlpage.variables('post')), - '$[*]' columns (question_id int path '$') -) as question_ids; + TRY_CONVERT(int, answers.[key]) as question_id, + answers.value as answer +FROM OPENJSON(sqlpage.variables('post')) as answers +WHERE TRY_CONVERT(int, answers.[key]) IS NOT NULL; -- Show the answers select 'card' as component, 'Survey results' as title; select questions.question_text as title, survey_answers.answer as description, - 'On ' || survey_answers.timestamp as footer + 'On ' + CONVERT(varchar(33), survey_answers.timestamp, 126) as footer from survey_answers inner join questions on questions.id = survey_answers.question_id; diff --git a/src/webserver/database/sql_to_json.rs b/src/webserver/database/sql_to_json.rs index 50bd95c6..6df5091c 100644 --- a/src/webserver/database/sql_to_json.rs +++ b/src/webserver/database/sql_to_json.rs @@ -1,7 +1,7 @@ use crate::utils::add_value_to_map; use crate::webserver::database::blob_to_data_url; use bigdecimal::BigDecimal; -use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc}; +use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime, Utc}; use serde_json::{Map, Value}; use sqlx::postgres::types::PgRange; use sqlx::{Column, ColumnIndex, Row, TypeInfo, ValueRef}; @@ -46,9 +46,29 @@ macro_rules! impl_sqlpage_row { impl_sqlpage_row!(sqlx::mysql::MySqlRow, sqlx::MySql, false); impl_sqlpage_row!(sqlx::sqlite::SqliteRow, sqlx::Sqlite, false); -impl_sqlpage_row!(sqlx_sqlserver::MssqlRow, sqlx_sqlserver::Mssql, false); impl_sqlpage_row!(sqlx_odbc::OdbcRow, sqlx_odbc::Odbc, true); +impl SqlPageRow for sqlx_sqlserver::MssqlRow { + fn to_json(&self) -> Value { + let mut map = Map::new(); + for col in self.columns() { + let key = canonical_col_name(col.name(), false); + let value = mssql_to_json(self, col.ordinal()); + map = add_value_to_map(map, (key, value)); + } + Value::Object(map) + } + + fn first_value_to_string(&self) -> Option { + let col = self.columns().first()?; + match mssql_to_json(self, col.ordinal()) { + Value::String(s) => Some(s), + Value::Null => None, + other => Some(other.to_string()), + } + } +} + impl SqlPageRow for sqlx::postgres::PgRow { fn to_json(&self) -> Value { let mut map = Map::new(); @@ -182,6 +202,65 @@ fn pg_to_json(row: &sqlx::postgres::PgRow, ordinal: usize) -> Value { } } +fn mssql_to_json(row: &sqlx_sqlserver::MssqlRow, ordinal: usize) -> Value { + let raw_value = match row.try_get_raw(ordinal) { + Ok(raw_value) if raw_value.is_null() => return Value::Null, + Ok(raw_value) => raw_value, + Err(e) => { + log::warn!("Unable to extract value from row: {e:?}"); + return Value::Null; + } + }; + let type_info = raw_value.type_info(); + let type_name = type_info.name().to_ascii_uppercase(); + log::trace!("Decoding a SQL Server value of type {type_name:?} (type info: {type_info:?})"); + + match type_name.as_str() { + "BIT" => decode::(row, ordinal).into(), + "SMALLINT" | "TINYINT" => decode::(row, ordinal).into(), + "INT" => decode::(row, ordinal).into(), + "BIGINT" => decode::(row, ordinal).into(), + "REAL" => decode::(row, ordinal).into(), + "FLOAT" => decode::(row, ordinal).into(), + "DECIMAL" | "NUMERIC" | "MONEY" | "SMALLMONEY" => { + decimal_to_json(&decode::( + row, ordinal, + )) + } + "DATE" => decode::(row, ordinal) + .to_string() + .into(), + "TIME" => decode::(row, ordinal) + .to_string() + .into(), + "DATETIME2" => decode::(row, ordinal) + .format("%FT%T%.f") + .to_string() + .into(), + "DATETIME" | "SMALLDATETIME" => { + decode::>(row, ordinal) + .naive_local() + .format("%FT%T%.f") + .to_string() + .into() + } + "DATETIMEOFFSET" => decode::>(row, ordinal) + .to_rfc3339() + .into(), + "UNIQUEIDENTIFIER" => { + decode::(row, ordinal) + .to_string() + .into() + } + "FILESTREAM" | "VARBINARY" | "BIGVARBINARY" | "BINARY" | "IMAGE" => { + blob_to_data_url::vec_to_data_uri_value(&decode::>( + row, ordinal, + )) + } + _ => decode::(row, ordinal).into(), + } +} + fn decimal_to_json(decimal: &BigDecimal) -> Value { Value::Number(serde_json::Number::from_string_unchecked( decimal.normalized().to_plain_string(), diff --git a/tests/examples.rs b/tests/examples.rs new file mode 100644 index 00000000..fd2c8fad --- /dev/null +++ b/tests/examples.rs @@ -0,0 +1,265 @@ +use actix_web::{http::header, test, web::Data}; +use sqlpage::{ + AppState, + app_config::{self, AppConfig}, + webserver::{ + database::migrations, + http::{form_config, main_handler, payload_config}, + }, +}; +use std::collections::BTreeSet; +use std::path::Path; +use std::time::Duration; + +struct ExampleSmoke { + name: &'static str, + web_root: Option<&'static str>, + request_path: &'static str, +} + +const SQLITE_SMOKE_EXAMPLES: &[ExampleSmoke] = &[ + ExampleSmoke { + name: "CRUD - Authentication", + web_root: Some("www"), + request_path: "/", + }, + ExampleSmoke { + name: "cards-with-remote-content", + web_root: None, + request_path: "/", + }, + ExampleSmoke { + name: "charts, computations and custom components", + web_root: None, + request_path: "/", + }, + ExampleSmoke { + name: "corporate-conundrum", + web_root: None, + request_path: "/", + }, + ExampleSmoke { + name: "forms with a variable number of fields", + web_root: None, + request_path: "/", + }, + ExampleSmoke { + name: "forms-with-multiple-steps", + web_root: None, + request_path: "/", + }, + ExampleSmoke { + name: "handle-404", + web_root: None, + request_path: "/", + }, + ExampleSmoke { + name: "image gallery with user uploads", + web_root: None, + request_path: "/", + }, + ExampleSmoke { + name: "light-dark-toggle", + web_root: None, + request_path: "/", + }, + ExampleSmoke { + name: "master-detail-forms", + web_root: None, + request_path: "/", + }, + ExampleSmoke { + name: "modeling a many to many relationship with a form", + web_root: None, + request_path: "/", + }, + ExampleSmoke { + name: "multiple-choice-question", + web_root: None, + request_path: "/", + }, + ExampleSmoke { + name: "official-site", + web_root: None, + request_path: "/", + }, + ExampleSmoke { + name: "plots tables and forms", + web_root: None, + request_path: "/", + }, + ExampleSmoke { + name: "read-and-set-http-cookies", + web_root: None, + request_path: "/", + }, + ExampleSmoke { + name: "rich-text-editor", + web_root: None, + request_path: "/", + }, + ExampleSmoke { + name: "roundest_pokemon_rating", + web_root: Some("src"), + request_path: "/", + }, + ExampleSmoke { + name: "sending emails", + web_root: None, + request_path: "/", + }, + ExampleSmoke { + name: "simple-website-example", + web_root: None, + request_path: "/", + }, + ExampleSmoke { + name: "splitwise", + web_root: None, + request_path: "/", + }, + ExampleSmoke { + name: "todo application", + web_root: None, + request_path: "/", + }, + ExampleSmoke { + name: "using react and other custom scripts and styles", + web_root: None, + request_path: "/", + }, + ExampleSmoke { + name: "user-authentication", + web_root: None, + request_path: "/", + }, +]; + +const EXTERNAL_SERVICE_EXAMPLES: &[&str] = &[ + "PostGIS - using sqlpage with geographic data", + "SQLPage developer user interface", + "custom form component", + "make a geographic data application using sqlite extensions", + "microsoft sql server advanced forms", + "mysql json handling", + "nginx", + "single sign on", + "telemetry", + "tiny_twitter", + "todo application (PostgreSQL)", + "web servers - apache", +]; + +#[actix_web::test] +async fn examples_folder_is_fully_accounted_for() { + let actual = top_level_example_directories(); + let accounted_for = SQLITE_SMOKE_EXAMPLES + .iter() + .map(|example| example.name.to_owned()) + .chain( + EXTERNAL_SERVICE_EXAMPLES + .iter() + .map(|name| (*name).to_owned()), + ) + .collect::>(); + assert_eq!(actual, accounted_for); +} + +#[actix_web::test] +async fn sqlite_compatible_examples_render_their_entry_page() { + for example in SQLITE_SMOKE_EXAMPLES { + let response = request_example(example).await.unwrap_or_else(|err| { + panic!( + "failed to render entry page for example {:?}: {err:#}", + example.name + ) + }); + + let status = response.status(); + let body = test::read_body(response).await; + let body = String::from_utf8_lossy(&body); + assert!( + status.is_success() + || status.is_redirection() + || status == actix_web::http::StatusCode::UNAUTHORIZED, + "example {:?} returned status {status}; body:\n{body}", + example.name + ); + assert!( + !body.contains("SQLPage Error"), + "example {:?} rendered an SQLPage error with status {status}; body:\n{body}", + example.name + ); + } +} + +async fn request_example( + example: &ExampleSmoke, +) -> anyhow::Result { + let app_data = make_example_app_data(example).await?; + let request = test::TestRequest::get() + .uri(example.request_path) + .insert_header(header::Accept::html()) + .app_data(payload_config(&app_data)) + .app_data(form_config(&app_data)) + .app_data(app_data) + .to_srv_request(); + tokio::time::timeout(Duration::from_secs(8), main_handler(request)) + .await + .map_err(|err| anyhow::anyhow!("request timed out: {err}"))? + .map_err(|err| anyhow::anyhow!("request failed: {err:#}")) +} + +async fn make_example_app_data(example: &ExampleSmoke) -> anyhow::Result> { + sqlpage::telemetry::init_test_logging(); + + let root = Path::new("examples").join(example.name); + let mut config = load_example_config(&root)?; + config.database_url = "sqlite://:memory:?cache=shared".to_owned(); + config.max_database_pool_connections = Some(1); + config.database_connection_retries = 0; + config.database_connection_acquire_timeout_seconds = 8.0; + config.web_root = example + .web_root + .map_or_else(|| root.clone(), |web_root| root.join(web_root)); + + let state = Data::new(AppState::init(&config).await?); + migrations::apply(&config, &state.db).await?; + Ok(state) +} + +fn load_example_config(root: &Path) -> anyhow::Result { + let config_dir = if root.join("sqlpage").exists() { + root.join("sqlpage") + } else if root.join("sqlpage_config").exists() { + root.join("sqlpage_config") + } else { + root.join("sqlpage") + }; + + if config_dir.join("sqlpage.json").exists() || config_dir.join("sqlpage.yaml").exists() { + let mut config = app_config::load_from_directory(&config_dir)?; + config.configuration_directory = config_dir; + Ok(config) + } else { + let mut config = serde_json::from_str::( + r#"{ + "database_url": "sqlite://:memory:", + "max_database_pool_connections": 1, + "database_connection_retries": 0, + "database_connection_acquire_timeout_seconds": 8 + }"#, + )?; + config.configuration_directory = config_dir; + Ok(config) + } +} + +fn top_level_example_directories() -> BTreeSet { + std::fs::read_dir("examples") + .unwrap() + .map(Result::unwrap) + .filter(|entry| entry.file_type().unwrap().is_dir()) + .map(|entry| entry.file_name().into_string().unwrap()) + .collect() +} From 9872234815fd9e7926a7e341bcfe5699164ab67e Mon Sep 17 00:00:00 2001 From: Ophir LOJKINE Date: Sun, 31 May 2026 22:36:46 +0200 Subject: [PATCH 5/9] Preserve streaming and cover database examples --- Cargo.lock | 11 + Cargo.toml | 3 + examples/mysql json handling/index.sql | 12 +- .../migrations/0001_users_and_groups.sql | 6 +- examples/mysql json handling/survey.sql | 2 +- examples/nginx/website/add_comment.sql | 2 +- examples/nginx/website/index.sql | 4 +- examples/nginx/website/post.sql | 2 +- .../web servers - apache/website/index.sql | 6 +- src/webserver/database/execute_queries.rs | 303 +++++++++++------- src/webserver/database/sql_to_json.rs | 54 +++- tests/data_formats/csv_data_mssql.sql | 11 + tests/data_formats/mod.rs | 10 +- tests/examples.rs | 218 +++++++++++++ 14 files changed, 511 insertions(+), 133 deletions(-) create mode 100644 tests/data_formats/csv_data_mssql.sql diff --git a/Cargo.lock b/Cargo.lock index 26b02a66..ed7970ff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3284,6 +3284,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" +[[package]] +name = "openssl-src" +version = "300.6.0+3.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8e8cbfd3a4a8c8f089147fd7aaa33cf8c7450c4d09f8f80698a0cf093abeff4" +dependencies = [ + "cc", +] + [[package]] name = "openssl-sys" version = "0.9.116" @@ -3292,6 +3301,7 @@ checksum = "f28a22dc7140cda5f096e5e7724a6962ca81a7f8bfd2979f9b18c11af56318c4" dependencies = [ "cc", "libc", + "openssl-src", "pkg-config", "vcpkg", ] @@ -4538,6 +4548,7 @@ dependencies = [ "markdown", "mime_guess", "openidconnect", + "openssl", "opentelemetry", "opentelemetry-http", "opentelemetry-otlp", diff --git a/Cargo.toml b/Cargo.toml index 95fbde27..fac4a41c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,6 +38,9 @@ sqlx-sqlserver = { version = "0.0.3", features = [ "uuid", ] } sqlx-odbc = { version = "0.0.1", features = ["runtime-tokio"] } +# sqlx-sqlserver currently uses native-tls. Enable vendored OpenSSL so Docker cross-builds +# do not depend on target-architecture OpenSSL development packages. +openssl = { version = "0.10", features = ["vendored"] } libsqlite3-sys = "0.37" chrono = "0.4.23" actix-web = { version = "4", features = ["rustls-0_23", "cookies"] } diff --git a/examples/mysql json handling/index.sql b/examples/mysql json handling/index.sql index 4d155872..3b14bf21 100644 --- a/examples/mysql json handling/index.sql +++ b/examples/mysql json handling/index.sql @@ -1,10 +1,10 @@ select 'form' as component, 'Create a new Group' as title, 'Create' as validate; select 'Name' as name; -insert into groups(name) select :Name where :Name is not null; +insert into `groups`(name) select :Name where :Name is not null; select 'list' as component, 'Groups' as title, 'No group yet' as empty_title; -select name as title from groups; +select name as title from `groups`; select 'form' as component, 'Add a user' as title, 'Add' as validate; select 'UserName' as name, 'Name' as label; @@ -15,7 +15,7 @@ select TRUE as multiple, 'press ctrl to select multiple values' as description, json_arrayagg(json_object("label", name, "value", id)) as options -from groups; +from `groups`; insert into users(name) select :UserName where :UserName is not null; insert into group_members(group_id, user_id) @@ -28,8 +28,8 @@ where :Memberships is not null; select 'list' as component, 'Users' as title, 'No user yet' as empty_title; select users.name as title, - group_concat(groups.name) as description + group_concat(`groups`.name) as description from users left join group_members on users.id = group_members.user_id -left join groups on groups.id = group_members.group_id -group by users.id, users.name; \ No newline at end of file +left join `groups` on `groups`.id = group_members.group_id +group by users.id, users.name; diff --git a/examples/mysql json handling/sqlpage/migrations/0001_users_and_groups.sql b/examples/mysql json handling/sqlpage/migrations/0001_users_and_groups.sql index 954872c0..5c0def64 100644 --- a/examples/mysql json handling/sqlpage/migrations/0001_users_and_groups.sql +++ b/examples/mysql json handling/sqlpage/migrations/0001_users_and_groups.sql @@ -3,7 +3,7 @@ create table users ( name varchar(255) not null ); -create table groups ( +create table `groups` ( id int primary key auto_increment, name varchar(255) not null ); @@ -12,6 +12,6 @@ create table group_members ( group_id int not null, user_id int not null, primary key (group_id, user_id), - foreign key (group_id) references groups (id), + foreign key (group_id) references `groups` (id), foreign key (user_id) references users (id) -); \ No newline at end of file +); diff --git a/examples/mysql json handling/survey.sql b/examples/mysql json handling/survey.sql index e1406e10..a5b730e7 100644 --- a/examples/mysql json handling/survey.sql +++ b/examples/mysql json handling/survey.sql @@ -22,6 +22,6 @@ select 'card' as component, 'Survey results' as title; select questions.question_text as title, survey_answers.answer as description, - 'On ' || survey_answers.timestamp as footer + CONCAT('On ', survey_answers.timestamp) as footer from survey_answers inner join questions on questions.id = survey_answers.question_id; diff --git a/examples/nginx/website/add_comment.sql b/examples/nginx/website/add_comment.sql index 2d7db565..ea461ea3 100644 --- a/examples/nginx/website/add_comment.sql +++ b/examples/nginx/website/add_comment.sql @@ -1,2 +1,2 @@ INSERT INTO comments (post_id, user_id, content) VALUES ($id, 1, :content); -SELECT 'redirect' as component, '/post/' || $id AS link; \ No newline at end of file +SELECT 'redirect' as component, CONCAT('/post/', $id) AS link; diff --git a/examples/nginx/website/index.sql b/examples/nginx/website/index.sql index 42711a57..af00fad1 100644 --- a/examples/nginx/website/index.sql +++ b/examples/nginx/website/index.sql @@ -4,7 +4,7 @@ SELECT p.title, u.username AS description, 'user' AS icon, - '/post/' || p.id AS link + CONCAT('/post/', p.id) AS link FROM posts p JOIN users u ON p.user_id = u.id -ORDER BY p.created_at DESC; \ No newline at end of file +ORDER BY p.created_at DESC; diff --git a/examples/nginx/website/post.sql b/examples/nginx/website/post.sql index 97e7dba2..a546a3d2 100644 --- a/examples/nginx/website/post.sql +++ b/examples/nginx/website/post.sql @@ -37,7 +37,7 @@ SELECT 'divider' as component; SELECT 'form' as component, 'Add a comment' as title, 'Post comment' as validate, - '/add_comment.sql?id=' || $id as action; + CONCAT('/add_comment.sql?id=', $id) as action; SELECT 'textarea' as type, 'content' as name, diff --git a/examples/web servers - apache/website/index.sql b/examples/web servers - apache/website/index.sql index d3d9438c..25aba978 100644 --- a/examples/web servers - apache/website/index.sql +++ b/examples/web servers - apache/website/index.sql @@ -1,9 +1,9 @@ select 'text' as component, true as article, - ' + CONCAT(' # Welcome to my website -Using SQLPage v' || sqlpage.version() || ' +Using SQLPage v', sqlpage.version(), ' -Connected to **MySQL** v' || version () as contents_md; +Connected to **MySQL** v', version ()) as contents_md; diff --git a/src/webserver/database/execute_queries.rs b/src/webserver/database/execute_queries.rs index 20d48801..1566a266 100644 --- a/src/webserver/database/execute_queries.rs +++ b/src/webserver/database/execute_queries.rs @@ -176,7 +176,7 @@ pub fn stream_query_results_with_conn<'a>( &request.app_state.telemetry_metrics, ); record_query_params(&query_metrics.span, &query.param_values); - let items = execute_statement_collect( + let mut items = stream_statement_for_connection( connection, &query, source_file, @@ -184,9 +184,8 @@ pub fn stream_query_results_with_conn<'a>( request, query_span, &mut query_metrics, - ) - .await; - for db_item in items { + ); + while let Some(db_item) = items.next().await { yield db_item; } }, @@ -229,90 +228,89 @@ fn with_stmt_position( } } -async fn execute_statement_collect( - connection: &mut DbConnection, - query: &StatementWithParams<'_>, - source_file: &Path, - stmt: &StmtWithParams, - request: &ExecutionContext, +fn stream_statement_for_connection<'a>( + connection: &'a mut DbConnection, + query: &'a StatementWithParams<'_>, + source_file: &'a Path, + stmt: &'a StmtWithParams, + request: &'a ExecutionContext, query_span: tracing::Span, - query_metrics: &mut DbQueryMetricsContext<'_>, -) -> Vec { + query_metrics: &'a mut DbQueryMetricsContext<'_>, +) -> Pin + 'a>> { match connection { - DbConnection::Sqlite(conn) => { - collect_query_results::( - conn, - query, - source_file, - stmt, - request, - query_span, - query_metrics, - ) - .await - } - DbConnection::Postgres(conn) => { - collect_query_results::( - conn, - query, - source_file, - stmt, - request, - query_span, - query_metrics, - ) - .await - } + DbConnection::Sqlite(conn) => stream_prepared_statement_results::( + conn, + query, + source_file, + stmt, + request, + query_span, + query_metrics, + ), + DbConnection::Postgres(conn) => stream_prepared_statement_results::( + conn, + query, + source_file, + stmt, + request, + query_span, + query_metrics, + ), DbConnection::MySql(conn) => { - collect_query_results::( - conn, - query, - source_file, - stmt, - request, - query_span, - query_metrics, - ) - .await - } - DbConnection::Mssql(conn) => { - collect_query_results::( - conn, - query, - source_file, - stmt, - request, - query_span, - query_metrics, - ) - .await - } - DbConnection::Odbc(conn) => { - collect_query_results::( - conn, - query, - source_file, - stmt, - request, - query_span, - query_metrics, - ) - .await + if should_run_mysql_transaction_control_as_raw_sql(query) { + stream_raw_statement_results::( + conn, + query, + source_file, + stmt, + request, + query_span, + query_metrics, + ) + } else { + stream_prepared_statement_results::( + conn, + query, + source_file, + stmt, + request, + query_span, + query_metrics, + ) + } } + DbConnection::Mssql(conn) => stream_prepared_statement_results::( + conn, + query, + source_file, + stmt, + request, + query_span, + query_metrics, + ), + DbConnection::Odbc(conn) => stream_prepared_statement_results::( + conn, + query, + source_file, + stmt, + request, + query_span, + query_metrics, + ), } } -async fn collect_query_results( - connection: &mut PoolConnection, - query: &StatementWithParams<'_>, - source_file: &Path, - stmt: &StmtWithParams, - request: &ExecutionContext, +fn stream_prepared_statement_results<'a, DB>( + connection: &'a mut PoolConnection, + query: &'a StatementWithParams<'_>, + source_file: &'a Path, + stmt: &'a StmtWithParams, + request: &'a ExecutionContext, query_span: tracing::Span, - query_metrics: &mut DbQueryMetricsContext<'_>, -) -> Vec + query_metrics: &'a mut DbQueryMetricsContext<'_>, +) -> Pin + 'a>> where - DB: SqlxDatabase, + DB: SqlxDatabase + 'a, DB::QueryResult: std::fmt::Debug, DB::Row: super::sql_to_json::SqlPageRow, usize: ColumnIndex, @@ -320,44 +318,105 @@ where for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, for<'q> Option: Encode<'q, DB> + Type, { - let mut stream = (&mut **connection).fetch_many(bind_query::(query)); - let mut error = None; - let mut returned_rows: i64 = 0; - let mut items = Vec::new(); - loop { - let start_next = std::time::Instant::now(); - let next_elem = stream.next().instrument(query_span.clone()).await; - query_metrics.add_duration(start_next.elapsed()); - let Some(elem) = next_elem else { break }; - - let mut query_result = parse_single_sql_result::(source_file, stmt, elem); - if let DbItem::Error(e) = query_result { - error = Some(e); - break; + Box::pin(async_stream::stream! { + let mut stream = (&mut **connection).fetch_many(bind_query::(query)); + let mut error = None; + let mut returned_rows: i64 = 0; + loop { + let start_next = std::time::Instant::now(); + let next_elem = stream.next().instrument(query_span.clone()).await; + query_metrics.add_duration(start_next.elapsed()); + let Some(elem) = next_elem else { break }; + + let mut query_result = parse_single_sql_result::(source_file, stmt, elem); + if let DbItem::Error(e) = query_result { + error = Some(e); + break; + } + if matches!(query_result, DbItem::Row(_)) { + returned_rows += 1; + } + apply_json_columns(&mut query_result, &stmt.json_columns); + if let Err(err) = + apply_delayed_functions(request, &stmt.delayed_functions, &mut query_result) + .instrument(query_span.clone()) + .await + { + error = Some(err); + break; + } + for db_item in parse_dynamic_rows(query_result) { + yield db_item; + } } - if matches!(query_result, DbItem::Row(_)) { - returned_rows += 1; + drop(stream); + if let Some(error) = error { + query_metrics.record_error(returned_rows, &error); + try_rollback_transaction(connection).await; + yield DbItem::Error(error); + } else { + query_metrics.record_success(returned_rows); + } + }) +} + +fn stream_raw_statement_results<'a, DB>( + connection: &'a mut PoolConnection, + query: &'a StatementWithParams<'_>, + source_file: &'a Path, + stmt: &'a StmtWithParams, + request: &'a ExecutionContext, + query_span: tracing::Span, + query_metrics: &'a mut DbQueryMetricsContext<'_>, +) -> Pin + 'a>> +where + DB: SqlxDatabase + 'a, + DB::QueryResult: std::fmt::Debug, + DB::Row: super::sql_to_json::SqlPageRow, + usize: ColumnIndex, + for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, +{ + Box::pin(async_stream::stream! { + let mut stream = + sqlx::raw_sql(sqlx::AssertSqlSafe(query.sql)).fetch_many(&mut **connection); + let mut error = None; + let mut returned_rows: i64 = 0; + loop { + let start_next = std::time::Instant::now(); + let next_elem = stream.next().instrument(query_span.clone()).await; + query_metrics.add_duration(start_next.elapsed()); + let Some(elem) = next_elem else { break }; + + let mut query_result = parse_single_sql_result::(source_file, stmt, elem); + if let DbItem::Error(e) = query_result { + error = Some(e); + break; + } + if matches!(query_result, DbItem::Row(_)) { + returned_rows += 1; + } + apply_json_columns(&mut query_result, &stmt.json_columns); + if let Err(err) = + apply_delayed_functions(request, &stmt.delayed_functions, &mut query_result) + .instrument(query_span.clone()) + .await + { + error = Some(err); + break; + } + for db_item in parse_dynamic_rows(query_result) { + yield db_item; + } } - apply_json_columns(&mut query_result, &stmt.json_columns); - if let Err(err) = - apply_delayed_functions(request, &stmt.delayed_functions, &mut query_result) - .instrument(query_span.clone()) - .await - { - error = Some(err); - break; + drop(stream); + if let Some(error) = error { + query_metrics.record_error(returned_rows, &error); + try_rollback_transaction(connection).await; + yield DbItem::Error(error); + } else { + query_metrics.record_success(returned_rows); } - items.extend(parse_dynamic_rows(query_result)); - } - drop(stream); - if let Some(error) = error { - query_metrics.record_error(returned_rows, &error); - try_rollback_transaction(connection).await; - items.push(DbItem::Error(error)); - } else { - query_metrics.record_success(returned_rows); - } - items + }) } /// Transforms a stream of database items to stop processing after encountering the first error. @@ -416,6 +475,22 @@ where } } +fn should_run_mysql_transaction_control_as_raw_sql(query: &StatementWithParams<'_>) -> bool { + if !query.param_values.is_empty() { + return false; + } + let sql = query.sql.trim().trim_end_matches(';').trim(); + let normalized = sql + .split_whitespace() + .collect::>() + .join(" ") + .to_ascii_uppercase(); + matches!( + normalized.as_str(), + "START TRANSACTION" | "BEGIN" | "COMMIT" | "ROLLBACK" + ) +} + async fn rollback_connection(connection: &mut DbConnection) { match connection { DbConnection::Sqlite(conn) => try_rollback_transaction(conn).await, diff --git a/src/webserver/database/sql_to_json.rs b/src/webserver/database/sql_to_json.rs index 6df5091c..7bc13473 100644 --- a/src/webserver/database/sql_to_json.rs +++ b/src/webserver/database/sql_to_json.rs @@ -44,10 +44,30 @@ macro_rules! impl_sqlpage_row { }; } -impl_sqlpage_row!(sqlx::mysql::MySqlRow, sqlx::MySql, false); impl_sqlpage_row!(sqlx::sqlite::SqliteRow, sqlx::Sqlite, false); impl_sqlpage_row!(sqlx_odbc::OdbcRow, sqlx_odbc::Odbc, true); +impl SqlPageRow for sqlx::mysql::MySqlRow { + fn to_json(&self) -> Value { + let mut map = Map::new(); + for col in self.columns() { + let key = canonical_col_name(col.name(), false); + let value = mysql_to_json(self, col.ordinal()); + map = add_value_to_map(map, (key, value)); + } + Value::Object(map) + } + + fn first_value_to_string(&self) -> Option { + let col = self.columns().first()?; + match mysql_to_json(self, col.ordinal()) { + Value::String(s) => Some(s), + Value::Null => None, + other => Some(other.to_string()), + } + } +} + impl SqlPageRow for sqlx_sqlserver::MssqlRow { fn to_json(&self) -> Value { let mut map = Map::new(); @@ -202,6 +222,38 @@ fn pg_to_json(row: &sqlx::postgres::PgRow, ordinal: usize) -> Value { } } +fn mysql_to_json(row: &sqlx::mysql::MySqlRow, ordinal: usize) -> Value { + let raw_value = match row.try_get_raw(ordinal) { + Ok(raw_value) if raw_value.is_null() => return Value::Null, + Ok(raw_value) => raw_value, + Err(e) => { + log::warn!("Unable to extract value from row: {e:?}"); + return Value::Null; + } + }; + let type_info = raw_value.type_info(); + let type_name = type_info.name().to_ascii_uppercase(); + log::trace!("Decoding a MySQL value of type {type_name:?} (type info: {type_info:?})"); + + match type_name.as_str() { + "DECIMAL" | "NEWDECIMAL" => { + decimal_to_json(&decode::(row, ordinal)) + } + "DATE" => decode::(row, ordinal) + .to_string() + .into(), + "TIME" => decode::(row, ordinal) + .to_string() + .into(), + "DATETIME" | "TIMESTAMP" => decode::(row, ordinal) + .format("%FT%T%.f") + .to_string() + .into(), + "JSON" => decode::>(row, ordinal).0, + _ => sql_to_json::(row, ordinal), + } +} + fn mssql_to_json(row: &sqlx_sqlserver::MssqlRow, ordinal: usize) -> Value { let raw_value = match row.try_get_raw(ordinal) { Ok(raw_value) if raw_value.is_null() => return Value::Null, diff --git a/tests/data_formats/csv_data_mssql.sql b/tests/data_formats/csv_data_mssql.sql new file mode 100644 index 00000000..d1cf3575 --- /dev/null +++ b/tests/data_formats/csv_data_mssql.sql @@ -0,0 +1,11 @@ +select + N'csv' as component, + N';' as separator; + +select + 0 as id, + N'Hello World !' as msg +union all +select + 1 as id, + N'Tu gères '';'' et ''"'' ?' as msg; diff --git a/tests/data_formats/mod.rs b/tests/data_formats/mod.rs index c7b3f178..d8c6a16b 100644 --- a/tests/data_formats/mod.rs +++ b/tests/data_formats/mod.rs @@ -49,7 +49,15 @@ async fn test_csv_body() -> actix_web::Result<()> { return Ok(()); } - let req = crate::common::get_request_to_with_data("/tests/data_formats/csv_data.sql", app_data) + let csv_test_file = if matches!( + app_data.db.info.database_type, + sqlpage::webserver::database::SupportedDatabase::Mssql + ) { + "/tests/data_formats/csv_data_mssql.sql" + } else { + "/tests/data_formats/csv_data.sql" + }; + let req = crate::common::get_request_to_with_data(csv_test_file, app_data) .await? .to_srv_request(); let resp = main_handler(req).await?; diff --git a/tests/examples.rs b/tests/examples.rs index fd2c8fad..decaa14a 100644 --- a/tests/examples.rs +++ b/tests/examples.rs @@ -8,6 +8,7 @@ use sqlpage::{ }, }; use std::collections::BTreeSet; +use std::env; use std::path::Path; use std::time::Duration; @@ -17,6 +18,13 @@ struct ExampleSmoke { request_path: &'static str, } +struct DatabaseExampleSmoke { + name: &'static str, + web_root: Option<&'static str>, + request_path: &'static str, + db_name: &'static str, +} + const SQLITE_SMOKE_EXAMPLES: &[ExampleSmoke] = &[ ExampleSmoke { name: "CRUD - Authentication", @@ -150,6 +158,66 @@ const EXTERNAL_SERVICE_EXAMPLES: &[&str] = &[ "web servers - apache", ]; +const POSTGRES_SMOKE_EXAMPLES: &[DatabaseExampleSmoke] = &[ + DatabaseExampleSmoke { + name: "SQLPage developer user interface", + web_root: Some("website"), + request_path: "/", + db_name: "sqlpage_example_developer_ui", + }, + DatabaseExampleSmoke { + name: "telemetry", + web_root: Some("website"), + request_path: "/", + db_name: "sqlpage_example_telemetry", + }, + DatabaseExampleSmoke { + name: "tiny_twitter", + web_root: None, + request_path: "/", + db_name: "sqlpage_example_tiny_twitter", + }, + DatabaseExampleSmoke { + name: "todo application (PostgreSQL)", + web_root: None, + request_path: "/", + db_name: "sqlpage_example_todo_postgres", + }, + DatabaseExampleSmoke { + name: "user-authentication", + web_root: None, + request_path: "/", + db_name: "sqlpage_example_user_authentication", + }, +]; + +const MYSQL_SMOKE_EXAMPLES: &[DatabaseExampleSmoke] = &[ + DatabaseExampleSmoke { + name: "custom form component", + web_root: None, + request_path: "/", + db_name: "sqlpage_example_custom_form", + }, + DatabaseExampleSmoke { + name: "mysql json handling", + web_root: None, + request_path: "/", + db_name: "sqlpage_example_mysql_json", + }, + DatabaseExampleSmoke { + name: "nginx", + web_root: Some("website"), + request_path: "/", + db_name: "sqlpage_example_nginx", + }, + DatabaseExampleSmoke { + name: "web servers - apache", + web_root: Some("website"), + request_path: "/my_website/", + db_name: "sqlpage_example_apache", + }, +]; + #[actix_web::test] async fn examples_folder_is_fully_accounted_for() { let actual = top_level_example_directories(); @@ -165,6 +233,34 @@ async fn examples_folder_is_fully_accounted_for() { assert_eq!(actual, accounted_for); } +#[actix_web::test] +async fn postgres_examples_render_their_entry_page_when_database_url_is_provided() { + let Ok(admin_url) = env::var("SQLPAGE_TEST_EXAMPLES_POSTGRES_ADMIN_URL") else { + return; + }; + + for example in POSTGRES_SMOKE_EXAMPLES { + let database_url = recreate_postgres_database(&admin_url, example.db_name) + .await + .unwrap_or_else(|err| panic!("failed to prepare {:?}: {err:#}", example.name)); + assert_database_example_renders(example, &database_url).await; + } +} + +#[actix_web::test] +async fn mysql_examples_render_their_entry_page_when_database_url_is_provided() { + let Ok(admin_url) = env::var("SQLPAGE_TEST_EXAMPLES_MYSQL_ADMIN_URL") else { + return; + }; + + for example in MYSQL_SMOKE_EXAMPLES { + let database_url = recreate_mysql_database(&admin_url, example.db_name) + .await + .unwrap_or_else(|err| panic!("failed to prepare {:?}: {err:#}", example.name)); + assert_database_example_renders(example, &database_url).await; + } +} + #[actix_web::test] async fn sqlite_compatible_examples_render_their_entry_page() { for example in SQLITE_SMOKE_EXAMPLES { @@ -193,6 +289,38 @@ async fn sqlite_compatible_examples_render_their_entry_page() { } } +async fn assert_database_example_renders(example: &DatabaseExampleSmoke, database_url: &str) { + let response = request_database_example(example, database_url) + .await + .unwrap_or_else(|err| { + panic!( + "failed to render entry page for example {:?}: {err:#}", + example.name + ) + }); + + assert_successful_example_response(example.name, response).await; +} + +async fn assert_successful_example_response( + example_name: &str, + response: actix_web::dev::ServiceResponse, +) { + let status = response.status(); + let body = test::read_body(response).await; + let body = String::from_utf8_lossy(&body); + assert!( + status.is_success() + || status.is_redirection() + || status == actix_web::http::StatusCode::UNAUTHORIZED, + "example {example_name:?} returned status {status}; body:\n{body}", + ); + assert!( + !body.contains("SQLPage Error"), + "example {example_name:?} rendered an SQLPage error with status {status}; body:\n{body}", + ); +} + async fn request_example( example: &ExampleSmoke, ) -> anyhow::Result { @@ -210,6 +338,24 @@ async fn request_example( .map_err(|err| anyhow::anyhow!("request failed: {err:#}")) } +async fn request_database_example( + example: &DatabaseExampleSmoke, + database_url: &str, +) -> anyhow::Result { + let app_data = make_database_example_app_data(example, database_url).await?; + let request = test::TestRequest::get() + .uri(example.request_path) + .insert_header(header::Accept::html()) + .app_data(payload_config(&app_data)) + .app_data(form_config(&app_data)) + .app_data(app_data) + .to_srv_request(); + tokio::time::timeout(Duration::from_secs(8), main_handler(request)) + .await + .map_err(|err| anyhow::anyhow!("request timed out: {err}"))? + .map_err(|err| anyhow::anyhow!("request failed: {err:#}")) +} + async fn make_example_app_data(example: &ExampleSmoke) -> anyhow::Result> { sqlpage::telemetry::init_test_logging(); @@ -228,6 +374,27 @@ async fn make_example_app_data(example: &ExampleSmoke) -> anyhow::Result anyhow::Result> { + sqlpage::telemetry::init_test_logging(); + + let root = Path::new("examples").join(example.name); + let mut config = load_example_config(&root)?; + config.database_url = database_url.to_owned(); + config.max_database_pool_connections = Some(1); + config.database_connection_retries = 0; + config.database_connection_acquire_timeout_seconds = 8.0; + config.web_root = example + .web_root + .map_or_else(|| root.clone(), |web_root| root.join(web_root)); + + let state = Data::new(AppState::init(&config).await?); + migrations::apply(&config, &state.db).await?; + Ok(state) +} + fn load_example_config(root: &Path) -> anyhow::Result { let config_dir = if root.join("sqlpage").exists() { root.join("sqlpage") @@ -255,6 +422,57 @@ fn load_example_config(root: &Path) -> anyhow::Result { } } +async fn recreate_postgres_database(admin_url: &str, db_name: &str) -> anyhow::Result { + use sqlx::Connection; + + validate_database_name(db_name)?; + let mut conn = sqlx::PgConnection::connect(admin_url).await?; + let drop_database = format!(r#"DROP DATABASE IF EXISTS "{db_name}" WITH (FORCE)"#); + let create_database = format!(r#"CREATE DATABASE "{db_name}""#); + sqlx::query(sqlx::AssertSqlSafe(drop_database.as_str())) + .execute(&mut conn) + .await?; + sqlx::query(sqlx::AssertSqlSafe(create_database.as_str())) + .execute(&mut conn) + .await?; + + database_url_with_path(admin_url, db_name) +} + +async fn recreate_mysql_database(admin_url: &str, db_name: &str) -> anyhow::Result { + use sqlx::Connection; + + validate_database_name(db_name)?; + let mut conn = sqlx::MySqlConnection::connect(admin_url).await?; + let drop_database = format!("DROP DATABASE IF EXISTS `{db_name}`"); + let create_database = format!("CREATE DATABASE `{db_name}`"); + sqlx::query(sqlx::AssertSqlSafe(drop_database.as_str())) + .execute(&mut conn) + .await?; + sqlx::query(sqlx::AssertSqlSafe(create_database.as_str())) + .execute(&mut conn) + .await?; + + database_url_with_path(admin_url, db_name) +} + +fn validate_database_name(db_name: &str) -> anyhow::Result<()> { + if db_name + .chars() + .all(|c| c.is_ascii_lowercase() || c == '_' || c.is_ascii_digit()) + { + Ok(()) + } else { + anyhow::bail!("invalid test database name: {db_name}"); + } +} + +fn database_url_with_path(admin_url: &str, db_name: &str) -> anyhow::Result { + let mut url = url::Url::parse(admin_url)?; + url.set_path(db_name); + Ok(url.to_string()) +} + fn top_level_example_directories() -> BTreeSet { std::fs::read_dir("examples") .unwrap() From 8ad2892e619851d024cd2dffa69245b93fe99983 Mon Sep 17 00:00:00 2001 From: Ophir LOJKINE Date: Sun, 31 May 2026 22:45:04 +0200 Subject: [PATCH 6/9] Install Perl for vendored OpenSSL builds --- scripts/setup-cross-compilation.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/setup-cross-compilation.sh b/scripts/setup-cross-compilation.sh index 83531abb..b4a2799e 100755 --- a/scripts/setup-cross-compilation.sh +++ b/scripts/setup-cross-compilation.sh @@ -6,6 +6,7 @@ BUILDARCH="$2" BINDGEN_EXTRA_CLANG_ARGS="" apt-get update +apt-get install -y perl if [ "$TARGETARCH" = "$BUILDARCH" ]; then TARGET="$(rustup target list --installed | head -n1)" From 103a4299c3e95780a1121397217e234ca3ef74cf Mon Sep 17 00:00:00 2001 From: Ophir LOJKINE Date: Sun, 31 May 2026 22:56:40 +0200 Subject: [PATCH 7/9] Fix ODBC numeric decoding and MSSQL CSV fixture --- src/webserver/database/sql_to_json.rs | 63 ++++++++++++++++++++++++++- tests/data_formats/csv_data_mssql.sql | 2 +- 2 files changed, 63 insertions(+), 2 deletions(-) diff --git a/src/webserver/database/sql_to_json.rs b/src/webserver/database/sql_to_json.rs index 7bc13473..9539f60b 100644 --- a/src/webserver/database/sql_to_json.rs +++ b/src/webserver/database/sql_to_json.rs @@ -45,7 +45,27 @@ macro_rules! impl_sqlpage_row { } impl_sqlpage_row!(sqlx::sqlite::SqliteRow, sqlx::Sqlite, false); -impl_sqlpage_row!(sqlx_odbc::OdbcRow, sqlx_odbc::Odbc, true); + +impl SqlPageRow for sqlx_odbc::OdbcRow { + fn to_json(&self) -> Value { + let mut map = Map::new(); + for col in self.columns() { + let key = canonical_col_name(col.name(), true); + let value = odbc_to_json(self, col.ordinal()); + map = add_value_to_map(map, (key, value)); + } + Value::Object(map) + } + + fn first_value_to_string(&self) -> Option { + let col = self.columns().first()?; + match odbc_to_json(self, col.ordinal()) { + Value::String(s) => Some(s), + Value::Null => None, + other => Some(other.to_string()), + } + } +} impl SqlPageRow for sqlx::mysql::MySqlRow { fn to_json(&self) -> Value { @@ -254,6 +274,27 @@ fn mysql_to_json(row: &sqlx::mysql::MySqlRow, ordinal: usize) -> Value { } } +fn odbc_to_json(row: &sqlx_odbc::OdbcRow, ordinal: usize) -> Value { + let raw_value = match row.try_get_raw(ordinal) { + Ok(raw_value) if raw_value.is_null() => return Value::Null, + Ok(raw_value) => raw_value, + Err(e) => { + log::warn!("Unable to extract value from row: {e:?}"); + return Value::Null; + } + }; + let type_info = raw_value.type_info(); + let type_name = type_info.name().to_ascii_uppercase(); + log::trace!("Decoding an ODBC value of type {type_name:?} (type info: {type_info:?})"); + + match type_name.as_str() { + "DECIMAL" | "NUMERIC" => { + string_decimal_to_json(&decode::(row, ordinal)) + } + _ => sql_to_json::(row, ordinal), + } +} + fn mssql_to_json(row: &sqlx_sqlserver::MssqlRow, ordinal: usize) -> Value { let raw_value = match row.try_get_raw(ordinal) { Ok(raw_value) if raw_value.is_null() => return Value::Null, @@ -319,6 +360,16 @@ fn decimal_to_json(decimal: &BigDecimal) -> Value { )) } +fn string_decimal_to_json(decimal: &str) -> Value { + match decimal.trim().parse::() { + Ok(number) => Value::Number(number), + Err(e) => { + log::warn!("Failed to parse decimal value {decimal:?}: {e}"); + Value::String(decimal.to_owned()) + } + } +} + fn decode_pg_range(row: &sqlx::postgres::PgRow, ordinal: usize) -> Value where T: std::fmt::Display + sqlx::Type, @@ -381,4 +432,14 @@ mod tests { ); Ok(()) } + + #[test] + fn test_string_decimal_to_json() { + assert_eq!(string_decimal_to_json("2"), serde_json::json!(2)); + assert_eq!(string_decimal_to_json(" 42.5 "), serde_json::json!(42.5)); + assert_eq!( + string_decimal_to_json("not a decimal"), + serde_json::json!("not a decimal") + ); + } } diff --git a/tests/data_formats/csv_data_mssql.sql b/tests/data_formats/csv_data_mssql.sql index d1cf3575..3c5a8ba0 100644 --- a/tests/data_formats/csv_data_mssql.sql +++ b/tests/data_formats/csv_data_mssql.sql @@ -8,4 +8,4 @@ select union all select 1 as id, - N'Tu gères '';'' et ''"'' ?' as msg; + CONCAT(N'Tu gères ', NCHAR(39), NCHAR(59), NCHAR(39), N' et ', NCHAR(39), NCHAR(34), NCHAR(39), N' ?') as msg; From d1d059f563f89f52eaeafae8cea6e06d58b99943 Mon Sep 17 00:00:00 2001 From: Ophir LOJKINE Date: Sun, 31 May 2026 23:04:54 +0200 Subject: [PATCH 8/9] Use ODBC column metadata for numeric JSON --- src/webserver/database/sql_to_json.rs | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/webserver/database/sql_to_json.rs b/src/webserver/database/sql_to_json.rs index 9539f60b..36fb8670 100644 --- a/src/webserver/database/sql_to_json.rs +++ b/src/webserver/database/sql_to_json.rs @@ -283,7 +283,10 @@ fn odbc_to_json(row: &sqlx_odbc::OdbcRow, ordinal: usize) -> Value { return Value::Null; } }; - let type_info = raw_value.type_info(); + let type_info = row.columns().get(ordinal).map_or_else( + || raw_value.type_info(), + |col| std::borrow::Cow::Owned(col.type_info().clone()), + ); let type_name = type_info.name().to_ascii_uppercase(); log::trace!("Decoding an ODBC value of type {type_name:?} (type info: {type_info:?})"); @@ -442,4 +445,20 @@ mod tests { serde_json::json!("not a decimal") ); } + + #[test] + fn test_odbc_decimal_column_to_json() { + let row = sqlx_odbc::OdbcRow::new( + vec![sqlx_odbc::OdbcColumn::new( + 0, + "actual", + sqlx_odbc::OdbcTypeInfo::numeric(10, 2), + )], + vec![sqlx_odbc::OdbcValue::new(sqlx_odbc::OdbcValueKind::Text( + "2".to_owned(), + ))], + ); + + assert_eq!(row_to_json(&row), serde_json::json!({ "actual": 2 })); + } } From 8e49780831bb3d3050926ceb5846888ba0a7ffee Mon Sep 17 00:00:00 2001 From: Ophir LOJKINE Date: Mon, 1 Jun 2026 11:23:44 +0200 Subject: [PATCH 9/9] Handle leading-dot ODBC decimals --- src/webserver/database/sql_to_json.rs | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/webserver/database/sql_to_json.rs b/src/webserver/database/sql_to_json.rs index 36fb8670..bddbbf5e 100644 --- a/src/webserver/database/sql_to_json.rs +++ b/src/webserver/database/sql_to_json.rs @@ -364,7 +364,18 @@ fn decimal_to_json(decimal: &BigDecimal) -> Value { } fn string_decimal_to_json(decimal: &str) -> Value { - match decimal.trim().parse::() { + let trimmed = decimal.trim(); + let normalized; + let json_number = if trimmed.starts_with('.') { + normalized = format!("0{trimmed}"); + normalized.as_str() + } else if trimmed.starts_with("-.") { + normalized = format!("-0{}", &trimmed[1..]); + normalized.as_str() + } else { + trimmed + }; + match json_number.parse::() { Ok(number) => Value::Number(number), Err(e) => { log::warn!("Failed to parse decimal value {decimal:?}: {e}"); @@ -440,6 +451,8 @@ mod tests { fn test_string_decimal_to_json() { assert_eq!(string_decimal_to_json("2"), serde_json::json!(2)); assert_eq!(string_decimal_to_json(" 42.5 "), serde_json::json!(42.5)); + assert_eq!(string_decimal_to_json(".47"), serde_json::json!(0.47)); + assert_eq!(string_decimal_to_json("-.47"), serde_json::json!(-0.47)); assert_eq!( string_decimal_to_json("not a decimal"), serde_json::json!("not a decimal")