diff --git a/docs/architecture.md b/docs/architecture.md index 8b207bd..c57d807 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -165,7 +165,7 @@ vulnerabilities ( On PostgreSQL, `INTEGER PRIMARY KEY` becomes `SERIAL`, `DATETIME` becomes `TIMESTAMP`, `INTEGER DEFAULT 0` booleans become `BOOLEAN DEFAULT FALSE`, and size/count columns use `BIGINT`. -The `MigrateSchema()` function handles backward compatibility with older git-pkgs databases by adding missing columns via `ALTER TABLE` as needed. +The `MigrateSchema()` function handles backward compatibility with older git-pkgs databases by running named migrations that add missing columns and tables. See [migrations.md](migrations.md) for how to add new schema changes. **Key operations:** - `GetPackageByPURL()` - Look up package by PURL diff --git a/docs/migrations.md b/docs/migrations.md new file mode 100644 index 0000000..21ecb2e --- /dev/null +++ b/docs/migrations.md @@ -0,0 +1,51 @@ +# Database Migrations + +Schema changes are tracked in a `migrations` table. Each migration has a name and a function. On startup, `MigrateSchema()` loads the set of already-applied names in one query and runs anything new. + +Fresh databases created via `Create()` get the full schema and all migrations are recorded as already applied. + +## Adding a migration + +In `internal/database/schema.go`: + +1. Write a migration function: + +```go +func migrateAddWidgetColumn(db *DB) error { + hasCol, err := db.HasColumn("packages", "widget") + if err != nil { + return fmt.Errorf("checking column widget: %w", err) + } + if !hasCol { + colType := "TEXT" + if db.dialect == DialectPostgres { + colType = "TEXT" // adjust if types differ + } + if _, err := db.Exec(fmt.Sprintf("ALTER TABLE packages ADD COLUMN widget %s", colType)); err != nil { + return fmt.Errorf("adding column widget: %w", err) + } + } + return nil +} +``` + +2. 
Append it to the `migrations` slice with the next sequential prefix: + +```go +var migrations = []migration{ + {"001_add_packages_enrichment_columns", migrateAddPackagesEnrichmentColumns}, + {"002_add_versions_enrichment_columns", migrateAddVersionsEnrichmentColumns}, + {"003_ensure_artifacts_table", migrateEnsureArtifactsTable}, + {"004_ensure_vulnerabilities_table", migrateEnsureVulnerabilitiesTable}, + {"005_add_widget_column", migrateAddWidgetColumn}, // new +} +``` + +3. Add the same column to both `schemaSQLite` and `schemaPostgres` at the top of the file so fresh databases start with the full schema. + +## Rules + +- Migration functions must be idempotent. Use `HasColumn`/`HasTable` checks or `IF NOT EXISTS` clauses so they're safe to run against a database that already has the change. +- Handle both SQLite and Postgres dialects. Common differences: `DATETIME` vs `TIMESTAMP`, `INTEGER DEFAULT 0` vs `BOOLEAN DEFAULT FALSE`, `INTEGER PRIMARY KEY` vs `SERIAL PRIMARY KEY`. +- Never reorder or rename existing entries. The name string is the migration's identity in the database. +- Never remove old migrations from the list. They won't run on already-migrated databases, but they need to exist for older databases upgrading for the first time. 
diff --git a/internal/database/database_test.go b/internal/database/database_test.go index e85e937..6fca4ea 100644 --- a/internal/database/database_test.go +++ b/internal/database/database_test.go @@ -651,58 +651,159 @@ func TestMigrationFromOldSchema(t *testing.T) { } defer func() { _ = db.Close() }() - // Try to run queries that require new columns - these should fail without migration - t.Run("queries should fail without migration", func(t *testing.T) { - _, err := db.GetEnrichmentStats() - if err == nil { - t.Error("GetEnrichmentStats: expected error querying enriched_at column, got nil") - } - - _, err = db.GetPackageByEcosystemName("npm", "test-package") - if err == nil { - t.Error("GetPackageByEcosystemName: expected error querying registry_url column, got nil") - } - - // SearchPackages should work even with old schema because it uses sql.NullString - // for nullable columns, which can handle NULL values properly - _, err = db.SearchPackages("test", "", 10, 0) - if err != nil { - t.Errorf("SearchPackages: unexpected error with old schema: %v", err) - } - }) + // Queries that require new columns should fail without migration + if _, err := db.GetEnrichmentStats(); err == nil { + t.Error("GetEnrichmentStats: expected error querying enriched_at column, got nil") + } + if _, err := db.GetPackageByEcosystemName("npm", "test-package"); err == nil { + t.Error("GetPackageByEcosystemName: expected error querying registry_url column, got nil") + } + // SearchPackages should work even with old schema because it uses sql.NullString + if _, err := db.SearchPackages("test", "", 10, 0); err != nil { + t.Errorf("SearchPackages: unexpected error with old schema: %v", err) + } // Run migration - t.Run("migrate schema", func(t *testing.T) { - if err := db.MigrateSchema(); err != nil { - t.Fatalf("MigrateSchema failed: %v", err) - } - }) + if err := db.MigrateSchema(); err != nil { + t.Fatalf("MigrateSchema failed: %v", err) + } // Verify queries work after migration - 
t.Run("queries should work after migration", func(t *testing.T) { - stats, err := db.GetEnrichmentStats() - if err != nil { - t.Errorf("GetEnrichmentStats failed after migration: %v", err) - } - if stats == nil { - t.Error("GetEnrichmentStats returned nil after migration") - } + stats, err := db.GetEnrichmentStats() + if err != nil { + t.Errorf("GetEnrichmentStats failed after migration: %v", err) + } + if stats == nil { + t.Error("GetEnrichmentStats returned nil after migration") + } - pkg, err := db.GetPackageByEcosystemName("npm", "test-package") - if err != nil { - t.Errorf("GetPackageByEcosystemName failed after migration: %v", err) + pkg, err := db.GetPackageByEcosystemName("npm", "test-package") + if err != nil { + t.Errorf("GetPackageByEcosystemName failed after migration: %v", err) + } + if pkg == nil { + t.Fatal("GetPackageByEcosystemName returned nil after migration") + } + if pkg.Name != "test-package" { + t.Errorf("expected package name test-package, got %s", pkg.Name) + } + + // Verify migrations were recorded + applied, err := db.appliedMigrations() + if err != nil { + t.Fatalf("appliedMigrations failed: %v", err) + } + for _, m := range migrations { + if !applied[m.name] { + t.Errorf("migration %s not recorded as applied", m.name) } - if pkg == nil { - t.Fatal("GetPackageByEcosystemName returned nil after migration") + } + + // Running again should be a no-op + if err := db.MigrateSchema(); err != nil { + t.Fatalf("second MigrateSchema failed: %v", err) + } +} + +func TestFreshDatabaseRecordsMigrations(t *testing.T) { + dir := t.TempDir() + dbPath := filepath.Join(dir, "fresh.db") + + db, err := Create(dbPath) + if err != nil { + t.Fatalf("Create failed: %v", err) + } + defer func() { _ = db.Close() }() + + applied, err := db.appliedMigrations() + if err != nil { + t.Fatalf("appliedMigrations failed: %v", err) + } + + for _, m := range migrations { + if !applied[m.name] { + t.Errorf("migration %s not recorded in fresh database", m.name) } - if 
pkg.Name != "test-package" { - t.Errorf("expected package name test-package, got %s", pkg.Name) + } +} + +func TestMigrateSchemaSkipsApplied(t *testing.T) { + dir := t.TempDir() + dbPath := filepath.Join(dir, "test.db") + + db, err := Create(dbPath) + if err != nil { + t.Fatalf("Create failed: %v", err) + } + defer func() { _ = db.Close() }() + + // All migrations are already recorded from Create. Running MigrateSchema + // should return without running any migration functions. + if err := db.MigrateSchema(); err != nil { + t.Fatalf("MigrateSchema failed: %v", err) + } + + // Verify count hasn't changed (no duplicate inserts) + var count int + if err := db.Get(&count, "SELECT COUNT(*) FROM migrations"); err != nil { + t.Fatalf("counting migrations failed: %v", err) + } + if count != len(migrations) { + t.Errorf("expected %d migrations, got %d", len(migrations), count) + } +} + +func TestMigrateSchemaUpgradeFromFullyMigrated(t *testing.T) { + dir := t.TempDir() + dbPath := filepath.Join(dir, "existing.db") + + // Simulate an existing proxy database that has the full current schema + // but no migrations table (i.e. it was running the previous version). 
+ sqlDB, err := sql.Open("sqlite", dbPath) + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + + if _, err := sqlDB.Exec(schemaSQLite); err != nil { + t.Fatalf("failed to create schema: %v", err) + } + // Drop the migrations table that schemaSQLite now includes + if _, err := sqlDB.Exec("DROP TABLE migrations"); err != nil { + t.Fatalf("failed to drop migrations table: %v", err) + } + if _, err := sqlDB.Exec("INSERT INTO schema_info (version) VALUES (1)"); err != nil { + t.Fatalf("failed to set schema version: %v", err) + } + if err := sqlDB.Close(); err != nil { + t.Fatalf("failed to close database: %v", err) + } + + db, err := Open(dbPath) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + defer func() { _ = db.Close() }() + + // This should create the migrations table and record all migrations + // without altering any tables (everything already exists). + if err := db.MigrateSchema(); err != nil { + t.Fatalf("MigrateSchema failed: %v", err) + } + + applied, err := db.appliedMigrations() + if err != nil { + t.Fatalf("appliedMigrations failed: %v", err) + } + for _, m := range migrations { + if !applied[m.name] { + t.Errorf("migration %s not recorded after upgrade", m.name) } + } - // Note: SearchPackages not tested here because old timestamp data - // stored as strings can't be scanned into time.Time. This is a data - // migration issue, not a schema migration issue. 
- }) + // Second run should be the fast path (single SELECT) + if err := db.MigrateSchema(); err != nil { + t.Fatalf("second MigrateSchema failed: %v", err) + } } func TestConcurrentWrites(t *testing.T) { @@ -890,3 +991,26 @@ func TestSearchPackagesWithValues(t *testing.T) { t.Errorf("expected 10 hits, got %d", result.Hits) } } + +func BenchmarkMigrateSchemaFullyMigrated(b *testing.B) { + dir := b.TempDir() + dbPath := filepath.Join(dir, "bench.db") + + db, err := Create(dbPath) + if err != nil { + b.Fatalf("Create failed: %v", err) + } + defer func() { _ = db.Close() }() + + // First call to ensure everything is migrated + if err := db.MigrateSchema(); err != nil { + b.Fatalf("initial MigrateSchema failed: %v", err) + } + + b.ResetTimer() + for b.Loop() { + if err := db.MigrateSchema(); err != nil { + b.Fatalf("MigrateSchema failed: %v", err) + } + } +} diff --git a/internal/database/schema.go b/internal/database/schema.go index 496a129..233357f 100644 --- a/internal/database/schema.go +++ b/internal/database/schema.go @@ -1,6 +1,10 @@ package database -import "fmt" +import ( + "fmt" + "strings" + "time" +) const postgresTimestamp = "TIMESTAMP" @@ -86,6 +90,11 @@ CREATE TABLE IF NOT EXISTS vulnerabilities ( ); CREATE UNIQUE INDEX IF NOT EXISTS idx_vulns_id_pkg ON vulnerabilities(vuln_id, ecosystem, package_name); CREATE INDEX IF NOT EXISTS idx_vulns_ecosystem_pkg ON vulnerabilities(ecosystem, package_name); + +CREATE TABLE IF NOT EXISTS migrations ( + name TEXT NOT NULL PRIMARY KEY, + applied_at DATETIME NOT NULL +); ` var schemaPostgres = ` @@ -166,6 +175,11 @@ CREATE TABLE IF NOT EXISTS vulnerabilities ( ); CREATE UNIQUE INDEX IF NOT EXISTS idx_vulns_id_pkg ON vulnerabilities(vuln_id, ecosystem, package_name); CREATE INDEX IF NOT EXISTS idx_vulns_ecosystem_pkg ON vulnerabilities(ecosystem, package_name); + +CREATE TABLE IF NOT EXISTS migrations ( + name TEXT NOT NULL PRIMARY KEY, + applied_at TIMESTAMP NOT NULL +); ` // schemaArtifactsOnly contains just the 
artifacts table for adding to existing git-pkgs databases. @@ -232,6 +246,11 @@ func (db *DB) CreateSchema() error { return fmt.Errorf("setting schema version: %w", err) } + // Record all migrations as applied since the full schema is already current. + if err := db.recordAllMigrations(); err != nil { + return fmt.Errorf("recording migrations: %w", err) + } + return db.OptimizeForReads() } @@ -292,24 +311,135 @@ func (db *DB) HasColumn(table, column string) (bool, error) { return exists, err } -// MigrateSchema adds missing columns to existing tables for backward compatibility. +// migration represents a named schema migration. +type migration struct { + name string + fn func(db *DB) error +} + +// migrations is the ordered list of all schema migrations. See +// docs/migrations.md for how to add new ones. +var migrations = []migration{ + {"001_add_packages_enrichment_columns", migrateAddPackagesEnrichmentColumns}, + {"002_add_versions_enrichment_columns", migrateAddVersionsEnrichmentColumns}, + {"003_ensure_artifacts_table", migrateEnsureArtifactsTable}, + {"004_ensure_vulnerabilities_table", migrateEnsureVulnerabilitiesTable}, +} + +// isTableNotFound returns true if the error indicates a missing table. +// SQLite returns "no such table: X", Postgres returns "relation \"X\" does not exist". Both "relation" and "does not exist" are required for Postgres, because a bare "does not exist" also appears in missing-column, missing-role and missing-database errors, which must not be mistaken for a pre-migration database. +func isTableNotFound(err error) bool { + msg := err.Error() + return strings.Contains(msg, "no such table") || + (strings.Contains(msg, "relation") && strings.Contains(msg, "does not exist")) +} + +// createMigrationsTable creates the migrations table.
+func (db *DB) createMigrationsTable() error { + var ts string + if db.dialect == DialectPostgres { + ts = "TIMESTAMP" + } else { + ts = "DATETIME" + } + + query := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS migrations ( + name TEXT NOT NULL PRIMARY KEY, + applied_at %s NOT NULL + )`, ts) + + if _, err := db.Exec(query); err != nil { + return fmt.Errorf("creating migrations table: %w", err) + } + return nil +} + +// appliedMigrations returns the set of migration names that have been recorded. +// Returns nil if the migrations table does not exist yet. +func (db *DB) appliedMigrations() (map[string]bool, error) { + var names []string + err := db.Select(&names, "SELECT name FROM migrations") + if err != nil { + // Table doesn't exist yet — this is a pre-migration database. + if isTableNotFound(err) { + return nil, nil + } + return nil, fmt.Errorf("loading applied migrations: %w", err) + } + + applied := make(map[string]bool, len(names)) + for _, name := range names { + applied[name] = true + } + return applied, nil +} + +// recordMigration inserts a migration name into the migrations table. A name that is already recorded is left untouched (ON CONFLICT DO NOTHING), so a second process racing through the same migration does not fail on the primary key. +func (db *DB) recordMigration(name string) error { + query := db.Rebind("INSERT INTO migrations (name, applied_at) VALUES (?, ?) ON CONFLICT (name) DO NOTHING") + if _, err := db.Exec(query, name, time.Now().UTC()); err != nil { + return fmt.Errorf("recording migration %s: %w", name, err) + } + return nil +} + +// recordAllMigrations marks every known migration as applied. +func (db *DB) recordAllMigrations() error { + for _, m := range migrations { + if err := db.recordMigration(m.name); err != nil { + return err + } + } + return nil +} + +// MigrateSchema applies any unapplied migrations in order. +// For a fully migrated database this executes a single SELECT query.
func (db *DB) MigrateSchema() error { - // Check and add missing columns to packages table - packagesColumns := map[string]string{ - "registry_url": "TEXT", - "supplier_name": "TEXT", - "supplier_type": "TEXT", - "source": "TEXT", - "enriched_at": "DATETIME", - "vulns_synced_at": "DATETIME", + applied, err := db.appliedMigrations() + if err != nil { + return err + } + + // If the migrations table didn't exist, create it now. + if applied == nil { + if err := db.createMigrationsTable(); err != nil { + return err + } + applied = make(map[string]bool) + } + + for _, m := range migrations { + if applied[m.name] { + continue + } + if err := m.fn(db); err != nil { + return fmt.Errorf("migration %s: %w", m.name, err) + } + if err := db.recordMigration(m.name); err != nil { + return err + } + } + + return nil +} + +func migrateAddPackagesEnrichmentColumns(db *DB) error { + columns := map[string]string{ + "registry_url": "TEXT", + "supplier_name": "TEXT", + "supplier_type": "TEXT", + "source": "TEXT", + "enriched_at": "DATETIME", + "vulns_synced_at": "DATETIME", } if db.dialect == DialectPostgres { - packagesColumns["enriched_at"] = postgresTimestamp - packagesColumns["vulns_synced_at"] = postgresTimestamp + columns["enriched_at"] = postgresTimestamp + columns["vulns_synced_at"] = postgresTimestamp } - for column, colType := range packagesColumns { + for column, colType := range columns { hasCol, err := db.HasColumn("packages", column) if err != nil { return fmt.Errorf("checking column %s: %w", column, err) @@ -321,9 +451,11 @@ func (db *DB) MigrateSchema() error { } } } + return nil +} - // Check and add missing columns to versions table - versionsColumns := map[string]string{ +func migrateAddVersionsEnrichmentColumns(db *DB) error { + columns := map[string]string{ "integrity": "TEXT", "yanked": "INTEGER DEFAULT 0", "source": "TEXT", @@ -331,11 +463,11 @@ func (db *DB) MigrateSchema() error { } if db.dialect == DialectPostgres { - versionsColumns["yanked"] = "BOOLEAN 
DEFAULT FALSE" - versionsColumns["enriched_at"] = postgresTimestamp + columns["yanked"] = "BOOLEAN DEFAULT FALSE" + columns["enriched_at"] = postgresTimestamp } - for column, colType := range versionsColumns { + for column, colType := range columns { hasCol, err := db.HasColumn("versions", column) if err != nil { return fmt.Errorf("checking column %s: %w", column, err) @@ -347,62 +479,64 @@ func (db *DB) MigrateSchema() error { } } } + return nil +} - // Ensure artifacts table exists - if err := db.EnsureArtifactsTable(); err != nil { - return fmt.Errorf("ensuring artifacts table: %w", err) - } +func migrateEnsureArtifactsTable(db *DB) error { + return db.EnsureArtifactsTable() +} - // Ensure vulnerabilities table exists +func migrateEnsureVulnerabilitiesTable(db *DB) error { hasVulns, err := db.HasTable("vulnerabilities") if err != nil { return fmt.Errorf("checking vulnerabilities table: %w", err) } - if !hasVulns { - var vulnSchema string - if db.dialect == DialectPostgres { - vulnSchema = ` - CREATE TABLE vulnerabilities ( - id SERIAL PRIMARY KEY, - vuln_id TEXT NOT NULL, - ecosystem TEXT NOT NULL, - package_name TEXT NOT NULL, - severity TEXT, - summary TEXT, - fixed_version TEXT, - cvss_score REAL, - "references" TEXT, - fetched_at TIMESTAMP, - created_at TIMESTAMP, - updated_at TIMESTAMP - ); - CREATE UNIQUE INDEX IF NOT EXISTS idx_vulns_id_pkg ON vulnerabilities(vuln_id, ecosystem, package_name); - CREATE INDEX IF NOT EXISTS idx_vulns_ecosystem_pkg ON vulnerabilities(ecosystem, package_name); - ` - } else { - vulnSchema = ` - CREATE TABLE vulnerabilities ( - id INTEGER PRIMARY KEY, - vuln_id TEXT NOT NULL, - ecosystem TEXT NOT NULL, - package_name TEXT NOT NULL, - severity TEXT, - summary TEXT, - fixed_version TEXT, - cvss_score REAL, - "references" TEXT, - fetched_at DATETIME, - created_at DATETIME, - updated_at DATETIME - ); - CREATE UNIQUE INDEX IF NOT EXISTS idx_vulns_id_pkg ON vulnerabilities(vuln_id, ecosystem, package_name); - CREATE INDEX IF NOT 
EXISTS idx_vulns_ecosystem_pkg ON vulnerabilities(ecosystem, package_name); - ` - } - if _, err := db.Exec(vulnSchema); err != nil { - return fmt.Errorf("creating vulnerabilities table: %w", err) - } + if hasVulns { + return nil } + var vulnSchema string + if db.dialect == DialectPostgres { + vulnSchema = ` + CREATE TABLE vulnerabilities ( + id SERIAL PRIMARY KEY, + vuln_id TEXT NOT NULL, + ecosystem TEXT NOT NULL, + package_name TEXT NOT NULL, + severity TEXT, + summary TEXT, + fixed_version TEXT, + cvss_score REAL, + "references" TEXT, + fetched_at TIMESTAMP, + created_at TIMESTAMP, + updated_at TIMESTAMP + ); + CREATE UNIQUE INDEX IF NOT EXISTS idx_vulns_id_pkg ON vulnerabilities(vuln_id, ecosystem, package_name); + CREATE INDEX IF NOT EXISTS idx_vulns_ecosystem_pkg ON vulnerabilities(ecosystem, package_name); + ` + } else { + vulnSchema = ` + CREATE TABLE vulnerabilities ( + id INTEGER PRIMARY KEY, + vuln_id TEXT NOT NULL, + ecosystem TEXT NOT NULL, + package_name TEXT NOT NULL, + severity TEXT, + summary TEXT, + fixed_version TEXT, + cvss_score REAL, + "references" TEXT, + fetched_at DATETIME, + created_at DATETIME, + updated_at DATETIME + ); + CREATE UNIQUE INDEX IF NOT EXISTS idx_vulns_id_pkg ON vulnerabilities(vuln_id, ecosystem, package_name); + CREATE INDEX IF NOT EXISTS idx_vulns_ecosystem_pkg ON vulnerabilities(ecosystem, package_name); + ` + } + if _, err := db.Exec(vulnSchema); err != nil { + return fmt.Errorf("creating vulnerabilities table: %w", err) + } return nil }