diff --git a/db/common.h b/db/common.h index 2533d7be1238..8b4e55b4fee0 100644 --- a/db/common.h +++ b/db/common.h @@ -66,6 +66,9 @@ struct db { /* Fatal if we try to write to db */ bool readonly; + + /* Set during migrations to prevent STRICT mode on table creation */ + bool in_migration; }; struct db_query { diff --git a/db/db_sqlite3.c b/db/db_sqlite3.c index ed63989d4f66..387002349b58 100644 --- a/db/db_sqlite3.c +++ b/db/db_sqlite3.c @@ -203,7 +203,23 @@ static bool db_sqlite3_setup(struct db *db, bool create) "PRAGMA foreign_keys = ON;", -1, &stmt, NULL); err = sqlite3_step(stmt); sqlite3_finalize(stmt); - return err == SQLITE_DONE; + + if (err != SQLITE_DONE) + return false; + + if (db->developer) { + sqlite3_prepare_v2(conn2sql(db->conn), + "PRAGMA trusted_schema = OFF;", -1, &stmt, NULL); + sqlite3_step(stmt); + sqlite3_finalize(stmt); + + sqlite3_prepare_v2(conn2sql(db->conn), + "PRAGMA cell_size_check = ON;", -1, &stmt, NULL); + sqlite3_step(stmt); + sqlite3_finalize(stmt); + } + + return true; } static bool db_sqlite3_query(struct db_stmt *stmt) @@ -211,8 +227,22 @@ static bool db_sqlite3_query(struct db_stmt *stmt) sqlite3_stmt *s; sqlite3 *conn = conn2sql(stmt->db->conn); int err; + const char *query = stmt->query->query; + char *modified_query = NULL; + + /* STRICT tables for developer mode, and not during upgrades. */ + if (stmt->db->developer && + !stmt->db->in_migration && + strncasecmp(query, "CREATE TABLE", 12) == 0 && + !strstr(query, "STRICT")) { + modified_query = tal_fmt(stmt, "%s STRICT", query); + query = modified_query; + } + + err = sqlite3_prepare_v2(conn, query, -1, &s, NULL); - err = sqlite3_prepare_v2(conn, stmt->query->query, -1, &s, NULL); + if (modified_query) + tal_free(modified_query); for (size_t i=0; iquery->placeholders; i++) { struct db_binding *b = &stmt->bindings[i]; diff --git a/db/utils.c b/db/utils.c index d6234179df5a..2091111089ce 100644 --- a/db/utils.c +++ b/db/utils.c @@ -364,6 +364,7 @@ struct db *db_open_(const tal_t *ctx, const char *filename, db->in_transaction = NULL; db->transaction_started = false; db->changes = NULL; + db->in_migration = false; /* This must be outside a transaction, so catch it */ assert(!db->in_transaction); diff --git a/devtools/sql-rewrite.py b/devtools/sql-rewrite.py index 03c358a643c7..4bee77a8714f 100755 --- a/devtools/sql-rewrite.py +++ b/devtools/sql-rewrite.py @@ -45,6 +45,8 @@ def rewrite_single(self, query): r'BIGINT': 'INTEGER', r'BIGINTEGER': 'INTEGER', r'BIGSERIAL': 'INTEGER', + r'VARCHAR(?:\(\d+\))?': 'TEXT', + r'\bINT\b': 'INTEGER', r'CURRENT_TIMESTAMP\(\)': "strftime('%s', 'now')", r'INSERT INTO[ \t]+(.*)[ \t]+ON CONFLICT.*DO NOTHING;': 'INSERT OR IGNORE INTO \\1;', # Rewrite "decode('abcd', 'hex')" to become "x'abcd'" diff --git a/tests/test_db.py b/tests/test_db.py index 47720dc581b2..237108df2b49 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -163,6 +163,14 @@ def test_scid_upgrade(node_factory, bitcoind): assert l1.db_query('SELECT scid FROM channels;') == [{'scid': scid_to_int('103x1x1')}] assert l1.db_query('SELECT failscid FROM payments;') == [{'failscid': scid_to_int('103x1x1')}] + faildetail_types = l1.db_query( + "SELECT id, typeof(faildetail) as type " + "FROM payments WHERE faildetail IS NOT NULL" + ) + for row in faildetail_types: + assert row['type'] == 'text', \ + f"Payment {row['id']}: faildetail has type {row['type']}, expected 'text'" + @unittest.skipIf(not COMPAT, "needs COMPAT to convert obsolete db") @unittest.skipIf(os.getenv('TEST_DB_PROVIDER', 'sqlite3') != 'sqlite3', "This test is based on a sqlite3 snapshot") @@ -642,3 +650,48 @@ def test_channel_htlcs_id_change(bitcoind, node_factory): # Make some HTLCS for amt in (100, 500, 1000, 5000, 10000, 50000, 100000): l1.pay(l3, amt) + + +@unittest.skipIf(os.getenv('TEST_DB_PROVIDER', 'sqlite3') != 'sqlite3', "STRICT tables are SQLite3 specific") +def test_sqlite_strict_mode(node_factory): + """Test that STRICT is appended to CREATE TABLE in developer mode.""" + l1 = node_factory.get_node(options={'developer': None}) + + tables = l1.db_query("SELECT name, sql FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'") + + strict_tables = [t for t in tables if t['sql'] and 'STRICT' in t['sql']] + assert len(strict_tables) > 0, f"Expected at least one STRICT table in developer mode, found none out of {len(tables)}" + + known_strict_tables = ['version', 'forwards', 'payments', 'local_anchors', 'addresses'] + for table_name in known_strict_tables: + table_sql = next((t['sql'] for t in tables if t['name'] == table_name), None) + if table_sql: + assert 'STRICT' in table_sql, f"Expected table '{table_name}' to be STRICT in developer mode" + + +@unittest.skipIf(os.getenv('TEST_DB_PROVIDER', 'sqlite3') != 'sqlite3', "SQLite3-specific test") +@unittest.skipIf(not COMPAT, "needs COMPAT to test old database upgrade") +@unittest.skipIf(TEST_NETWORK != 'regtest', "The network must match the DB snapshot") +def test_strict_mode_with_old_database(node_factory, bitcoind): + """Test old database upgrades work (STRICT not applied during migrations).""" + bitcoind.generate_block(1) + + l1 = node_factory.get_node(dbfile='oldstyle-scids.sqlite3.xz', + options={'database-upgrade': True, + 'developer': None}) + + assert l1.rpc.getinfo()['id'] is not None + + # Upgraded tables won't be STRICT (only fresh databases get STRICT). + strict_tables = l1.db_query( + "SELECT name FROM sqlite_master " + "WHERE type='table' AND sql LIKE '%STRICT%'" + ) + assert len(strict_tables) == 0, "Upgraded database should not have STRICT tables" + + # Verify BLOB->TEXT migration ran for faildetail cleanup. + result = l1.db_query( + "SELECT COUNT(*) as count FROM payments " + "WHERE typeof(faildetail) = 'blob'" + ) + assert result[0]['count'] == 0, "Found BLOB-typed faildetail after migration" diff --git a/wallet/db.c b/wallet/db.c index a0387220b119..5146aab295be 100644 --- a/wallet/db.c +++ b/wallet/db.c @@ -86,6 +86,8 @@ static void migrate_initialize_channel_htlcs_wait_indexes_and_fixup_forwards(str struct db *db); static void migrate_fail_pending_payments_without_htlcs(struct lightningd *ld, struct db *db); +static void migrate_fix_payments_faildetail_type(struct lightningd *ld, + struct db *db); /* Do not reorder or remove elements from this array, it is used to * migrate existing databases from a previous state, based on the @@ -1102,6 +1104,8 @@ static struct migration dbmigrations[] = { ")"), NULL}, {NULL, migrate_fail_pending_payments_without_htlcs}, {SQL("ALTER TABLE channels ADD withheld INTEGER DEFAULT 0;"), NULL}, + /* Fix BLOB→TEXT in payments.faildetail for old databases. */ + {NULL, migrate_fix_payments_faildetail_type}, }; /** @@ -1118,6 +1122,9 @@ static bool db_migrate(struct lightningd *ld, struct db *db, orig = current = db_get_version(db); available = ARRAY_SIZE(dbmigrations) - 1; + /* Disable STRICT for upgrades: legacy data may have wrong type affinity. */ + db->in_migration = (current != -1); + if (current == -1) log_info(ld->log, "Creating database"); else if (available < current) { @@ -1195,6 +1202,8 @@ struct db *db_setup(const tal_t *ctx, struct lightningd *ld, db_commit_transaction(db); + db->in_migration = false; + /* This needs to be done outside a transaction, apparently. * It's a good idea to do this every so often, and on db * upgrade is a reasonable time. */ @@ -2153,3 +2162,55 @@ static void migrate_fail_pending_payments_without_htlcs(struct lightningd *ld, db_bind_int(stmt, payment_status_in_db(PAYMENT_PENDING)); db_exec_prepared_v2(take(stmt)); } + +static void migrate_fix_payments_faildetail_type(struct lightningd *ld, + struct db *db) +{ + /* Historical databases may have BLOB-typed faildetail data. + * STRICT mode rejects this, so convert or NULL out invalid UTF-8. */ + struct db_stmt *stmt; + size_t fixed = 0, invalid = 0; + + stmt = db_prepare_v2(db, SQL("SELECT id, faildetail " + "FROM payments " + "WHERE typeof(faildetail) = 'blob'")); + db_query_prepared(stmt); + + while (db_step(stmt)) { + u64 id = db_col_u64(stmt, "id"); + const u8 *blob = db_col_blob(stmt, "faildetail"); + size_t len = db_col_bytes(stmt, "faildetail"); + struct db_stmt *upd; + + if (!utf8_check(blob, len)) { + log_unusual(ld->log, "Payment %"PRIu64": " + "Invalid UTF-8 in faildetail, setting to NULL", + id); + upd = db_prepare_v2(db, + SQL("UPDATE payments " + "SET faildetail = NULL " + "WHERE id = ?")); + db_bind_u64(upd, id); + db_exec_prepared_v2(take(upd)); + invalid++; + continue; + } + + char *text = tal_strndup(tmpctx, (char *)blob, len); + upd = db_prepare_v2(db, + SQL("UPDATE payments " + "SET faildetail = ? " + "WHERE id = ?")); + db_bind_text(upd, text); + db_bind_u64(upd, id); + db_exec_prepared_v2(take(upd)); + fixed++; + } + + tal_free(stmt); + + if (fixed > 0 || invalid > 0) + log_info(ld->log, "payments.faildetail migration: " + "%zu converted, %zu invalid UTF-8 nulled", + fixed, invalid); +}