From 5b30f1870aa188465e7d7091203a507e6c9c2049 Mon Sep 17 00:00:00 2001 From: Andrew Nesbitt Date: Thu, 19 Mar 2026 21:06:02 +0000 Subject: [PATCH 1/5] Add mirror command and API for selective package mirroring Add a `proxy mirror` CLI command and `/api/mirror` API endpoints that pre-populate the cache from various input sources: individual PURLs, SBOM files (CycloneDX and SPDX), or full registry enumeration. The mirror reuses the existing handler.Proxy.GetOrFetchArtifact() pipeline so cached artifacts are identical to those fetched on demand. A bounded worker pool controls download parallelism. Metadata caching is opt-in via `cache_metadata: true` in config (or PROXY_CACHE_METADATA=true). The mirror command always enables it. When enabled, upstream metadata responses are stored for offline fallback with ETag-based conditional revalidation. New internal/mirror package with Source interface, PURLSource, SBOMSource, RegistrySource, and async JobStore. New metadata_cache database table for offline metadata serving. 
--- README.md | 49 +++++ cmd/proxy/main.go | 162 +++++++++++++++ docs/architecture.md | 23 ++- docs/configuration.md | 34 ++++ go.mod | 7 +- go.sum | 20 +- internal/config/config.go | 8 + internal/database/metadata_cache_test.go | 180 +++++++++++++++++ internal/database/queries.go | 61 ++++++ internal/database/schema.go | 84 ++++++++ internal/database/types.go | 14 ++ internal/handler/cargo.go | 62 ++---- internal/handler/cargo_test.go | 7 +- internal/handler/composer.go | 22 +- internal/handler/conan.go | 18 +- internal/handler/conda.go | 11 +- internal/handler/cran.go | 19 +- internal/handler/debian.go | 3 +- internal/handler/download_test.go | 5 +- internal/handler/gem.go | 17 +- internal/handler/go.go | 15 +- internal/handler/handler.go | 222 +++++++++++++++++---- internal/handler/hex.go | 16 +- internal/handler/maven.go | 4 +- internal/handler/npm.go | 34 +--- internal/handler/nuget.go | 21 +- internal/handler/nuget_test.go | 5 +- internal/handler/pub.go | 23 +-- internal/handler/pypi.go | 51 +---- internal/handler/rpm.go | 3 +- internal/mirror/job.go | 187 +++++++++++++++++ internal/mirror/job_test.go | 160 +++++++++++++++ internal/mirror/mirror.go | 181 +++++++++++++++++ internal/mirror/mirror_test.go | 195 ++++++++++++++++++ internal/mirror/registry.go | 47 +++++ internal/mirror/registry_test.go | 46 +++++ internal/mirror/source.go | 190 ++++++++++++++++++ internal/mirror/source_test.go | 243 +++++++++++++++++++++++ internal/server/mirror_api.go | 70 +++++++ internal/server/mirror_api_test.go | 163 +++++++++++++++ internal/server/server.go | 20 +- internal/server/server_test.go | 12 +- 42 files changed, 2449 insertions(+), 265 deletions(-) create mode 100644 internal/database/metadata_cache_test.go create mode 100644 internal/mirror/job.go create mode 100644 internal/mirror/job_test.go create mode 100644 internal/mirror/mirror.go create mode 100644 internal/mirror/mirror_test.go create mode 100644 internal/mirror/registry.go create mode 100644 
internal/mirror/registry_test.go create mode 100644 internal/mirror/source.go create mode 100644 internal/mirror/source_test.go create mode 100644 internal/server/mirror_api.go create mode 100644 internal/server/mirror_api_test.go diff --git a/README.md b/README.md index 8da1165..e4aa31f 100644 --- a/README.md +++ b/README.md @@ -460,6 +460,47 @@ proxy serve [flags] proxy [flags] # same as 'proxy serve' ``` +### mirror + +Pre-populate the cache from PURLs, SBOM files, or entire registries. Useful for ensuring offline availability or warming the cache before deployments. + +```bash +# Mirror specific package versions +proxy mirror pkg:npm/lodash@4.17.21 pkg:cargo/serde@1.0.0 + +# Mirror all versions of a package +proxy mirror pkg:npm/lodash + +# Mirror from a CycloneDX or SPDX SBOM +proxy mirror --sbom sbom.cdx.json + +# Full registry mirror (npm, pypi, cargo supported) +proxy mirror --registry npm + +# Preview what would be mirrored +proxy mirror --dry-run pkg:npm/lodash + +# Control parallelism +proxy mirror --concurrency 8 pkg:npm/lodash@4.17.21 +``` + +The mirror command accepts the same storage and database flags as `serve`. Already-cached artifacts are skipped. + +A mirror API is also available when the server is running: + +```bash +# Start a mirror job +curl -X POST http://localhost:8080/api/mirror \ + -H "Content-Type: application/json" \ + -d '{"purls": ["pkg:npm/lodash@4.17.21"]}' + +# Check job status +curl http://localhost:8080/api/mirror/mirror-1 + +# Cancel a running job +curl -X DELETE http://localhost:8080/api/mirror/mirror-1 +``` + ### stats Show cache statistics without running the server. 
@@ -534,6 +575,14 @@ Recently cached: | `GET /debian/*` | Debian/APT repository protocol | | `GET /rpm/*` | RPM/Yum repository protocol | +### Mirror API + +| Endpoint | Description | +|----------|-------------| +| `POST /api/mirror` | Start a mirror job (JSON body with `purls` or `registry`) | +| `GET /api/mirror/{id}` | Get job status and progress | +| `DELETE /api/mirror/{id}` | Cancel a running job | + ### Enrichment API The proxy provides REST endpoints for package metadata enrichment, vulnerability scanning, and outdated detection. diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 4f13ea2..229f2d5 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -16,6 +16,7 @@ // // serve Start the proxy server (default if no command given) // stats Show cache statistics +// mirror Pre-populate cache from PURLs, SBOMs, or registries // // Serve Flags: // @@ -100,7 +101,11 @@ import ( "github.com/git-pkgs/proxy/internal/config" "github.com/git-pkgs/proxy/internal/database" + "github.com/git-pkgs/proxy/internal/handler" + "github.com/git-pkgs/proxy/internal/mirror" "github.com/git-pkgs/proxy/internal/server" + "github.com/git-pkgs/proxy/internal/storage" + "github.com/git-pkgs/registries/fetch" ) const defaultTopN = 10 @@ -124,6 +129,10 @@ func main() { os.Args = append(os.Args[:1], os.Args[2:]...) runStats() return + case "mirror": + os.Args = append(os.Args[:1], os.Args[2:]...) + runMirror() + return case "-version", "--version": fmt.Printf("proxy %s (%s)\n", Version, Commit) os.Exit(0) @@ -145,6 +154,7 @@ Usage: proxy [command] [flags] Commands: serve Start the proxy server (default) stats Show cache statistics + mirror Pre-populate cache from PURLs, SBOMs, or registries Run 'proxy -help' for more information on a command. 
@@ -340,6 +350,158 @@ func runStats() { } } +func runMirror() { + fs := flag.NewFlagSet("mirror", flag.ExitOnError) + configPath := fs.String("config", "", "Path to configuration file") + storageURL := fs.String("storage-url", "", "Storage URL (file:// or s3://)") + databaseDriver := fs.String("database-driver", "", "Database driver: sqlite or postgres") + databasePath := fs.String("database-path", "", "Path to SQLite database file") + databaseURL := fs.String("database-url", "", "PostgreSQL connection URL") + sbomPath := fs.String("sbom", "", "Path to CycloneDX or SPDX SBOM file") + registry := fs.String("registry", "", "Ecosystem name for full registry mirror") + concurrency := fs.Int("concurrency", 4, "Number of parallel downloads") //nolint:mnd // default concurrency + dryRun := fs.Bool("dry-run", false, "Show what would be mirrored without downloading") + + fs.Usage = func() { + fmt.Fprintf(os.Stderr, "git-pkgs proxy - Pre-populate cache\n\n") + fmt.Fprintf(os.Stderr, "Usage: proxy mirror [flags] [purl...]\n\n") + fmt.Fprintf(os.Stderr, "Examples:\n") + fmt.Fprintf(os.Stderr, " proxy mirror pkg:npm/lodash@4.17.21\n") + fmt.Fprintf(os.Stderr, " proxy mirror --sbom sbom.cdx.json\n") + fmt.Fprintf(os.Stderr, " proxy mirror pkg:npm/lodash # all versions\n") + fmt.Fprintf(os.Stderr, " proxy mirror --registry npm\n\n") + fmt.Fprintf(os.Stderr, "Flags:\n") + fs.PrintDefaults() + } + + _ = fs.Parse(os.Args[1:]) + purls := fs.Args() + + // Determine source + var source mirror.Source + switch { + case *sbomPath != "": + source = &mirror.SBOMSource{Path: *sbomPath} + case *registry != "": + source = &mirror.RegistrySource{Ecosystem: *registry} + case len(purls) > 0: + source = &mirror.PURLSource{PURLs: purls} + default: + fmt.Fprintf(os.Stderr, "error: provide PURLs, --sbom, or --registry\n") + fs.Usage() + os.Exit(1) + } + + // Load config + cfg, err := loadConfig(*configPath) + if err != nil { + fmt.Fprintf(os.Stderr, "error loading config: %v\n", err) + os.Exit(1) + } 
+ cfg.LoadFromEnv() + + if *storageURL != "" { + cfg.Storage.URL = *storageURL + } + if *databaseDriver != "" { + cfg.Database.Driver = *databaseDriver + } + if *databasePath != "" { + cfg.Database.Path = *databasePath + } + if *databaseURL != "" { + cfg.Database.URL = *databaseURL + } + + if err := cfg.Validate(); err != nil { + fmt.Fprintf(os.Stderr, "invalid configuration: %v\n", err) + os.Exit(1) + } + + logger := setupLogger("info", "text") + + // Open database + var db *database.DB + switch cfg.Database.Driver { + case "postgres": + db, err = database.OpenPostgresOrCreate(cfg.Database.URL) + default: + db, err = database.OpenOrCreate(cfg.Database.Path) + } + if err != nil { + fmt.Fprintf(os.Stderr, "error opening database: %v\n", err) + os.Exit(1) + } + + if err := db.MigrateSchema(); err != nil { + _ = db.Close() + fmt.Fprintf(os.Stderr, "error migrating schema: %v\n", err) + os.Exit(1) + } + + // Open storage + sURL := cfg.Storage.URL + if sURL == "" { + sURL = "file://" + cfg.Storage.Path //nolint:staticcheck // backwards compat + } + store, err := storage.OpenBucket(context.Background(), sURL) + if err != nil { + _ = db.Close() + fmt.Fprintf(os.Stderr, "error opening storage: %v\n", err) + os.Exit(1) + } + + // Build proxy (reuses same pipeline as serve) + fetcher := fetch.NewFetcher() + resolver := fetch.NewResolver() + proxy := handler.NewProxy(db, store, fetcher, resolver, logger) + proxy.CacheMetadata = true // mirror always caches metadata + + m := mirror.New(proxy, db, store, logger, *concurrency) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + <-sigCh + cancel() + }() + + if *dryRun { + items, err := m.RunDryRun(ctx, source) + if err != nil { + _ = db.Close() + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + fmt.Printf("Would mirror %d package versions:\n", len(items)) + for _, item := range items { + fmt.Printf(" 
%s\n", item) + } + _ = db.Close() + return + } + + progress, err := m.Run(ctx, source) + if err != nil { + _ = db.Close() + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + + _ = db.Close() + + fmt.Printf("Mirror complete: %d downloaded, %d skipped (cached), %d failed, %s total\n", + progress.Completed, progress.Skipped, progress.Failed, formatSize(progress.Bytes)) + + if len(progress.Errors) > 0 { + fmt.Fprintf(os.Stderr, "\nErrors:\n") + for _, e := range progress.Errors { + fmt.Fprintf(os.Stderr, " %s/%s@%s: %s\n", e.Ecosystem, e.Name, e.Version, e.Error) + } + } +} + func printStats(db *database.DB, popular, recent int, asJSON bool) error { defer func() { _ = db.Close() }() diff --git a/docs/architecture.md b/docs/architecture.md index c57d807..81c41cf 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -161,6 +161,20 @@ vulnerabilities ( updated_at DATETIME ) -- indexes: (vuln_id, ecosystem, package_name) unique, (ecosystem, package_name) + +metadata_cache ( + id INTEGER PRIMARY KEY, + ecosystem TEXT NOT NULL, + name TEXT NOT NULL, + storage_path TEXT NOT NULL, + etag TEXT, + content_type TEXT, + size INTEGER, -- BIGINT on Postgres + fetched_at DATETIME, + created_at DATETIME, + updated_at DATETIME +) +-- indexes: (ecosystem, name) unique ``` On PostgreSQL, `INTEGER PRIMARY KEY` becomes `SERIAL`, `DATETIME` becomes `TIMESTAMP`, `INTEGER DEFAULT 0` booleans become `BOOLEAN DEFAULT FALSE`, and size/count columns use `BIGINT`. @@ -277,6 +291,12 @@ Version age filtering for supply chain attack mitigation. Configurable at global Package metadata enrichment. Fetches license, description, homepage, repository URL, and vulnerability data from upstream registries. Powers the `/api/` endpoints and the web UI's package detail pages. +### `internal/mirror` + +Selective package mirroring for pre-populating the proxy cache. 
Supports multiple input sources: individual PURLs (versioned or unversioned), CycloneDX/SPDX SBOM files, and full registry enumeration. Uses a bounded worker pool backed by `errgroup` to download artifacts in parallel, reusing `handler.Proxy.GetOrFetchArtifact()` for the actual fetch-and-cache work. + +The package also provides a `MetadataCache` for storing raw upstream metadata blobs so the proxy can serve metadata responses offline. The `JobStore` manages async mirror jobs exposed via the `/api/mirror` endpoints. + ### `internal/config` Configuration loading. @@ -326,10 +346,11 @@ Eviction can be implemented as: - Ensures clients fetch artifacts through proxy - Alternative: Let clients fetch directly, miss cache opportunity -**Why not cache metadata?** +**Why not cache metadata (by default)?** - Simplicity - no invalidation logic needed - Fresh data - new versions visible immediately - Metadata is small, upstream fetch is fast +- Set `cache_metadata: true` or use the mirror command to enable metadata caching for offline use via the `metadata_cache` table **Why stream artifacts?** - Memory efficient - don't load large files into RAM diff --git a/docs/configuration.md b/docs/configuration.md index 68ace5f..2ffb10f 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -213,6 +213,40 @@ Currently supported for npm, PyPI, pub.dev, Composer, Cargo, NuGet, Conda, RubyG Note: Hex cooldown requires disabling registry signature verification since the proxy re-encodes the protobuf payload without the original signature. Set `HEX_NO_VERIFY_REPO_ORIGIN=1` or configure your repo with `no_verify: true`. +## Metadata Caching + +By default the proxy fetches metadata fresh from upstream on every request. Enable `cache_metadata` to store metadata responses in the database and storage backend for offline fallback. When upstream is unreachable, the proxy serves the last cached copy. ETag-based revalidation avoids re-downloading unchanged metadata. 
+ +```yaml +cache_metadata: true +``` + +Or via environment variable: `PROXY_CACHE_METADATA=true`. + +The `proxy mirror` command always enables metadata caching regardless of this setting. + +## Mirror Command + +The `proxy mirror` command pre-populates the cache from various sources. It accepts the same storage and database flags as `serve`. + +| Flag | Default | Description | +|------|---------|-------------| +| `--sbom` | | Path to CycloneDX or SPDX SBOM file | +| `--registry` | | Ecosystem name for full registry mirror | +| `--concurrency` | `4` | Number of parallel downloads | +| `--dry-run` | `false` | Show what would be mirrored without downloading | +| `--config` | | Path to configuration file | +| `--storage-url` | | Storage URL | +| `--database-driver` | | Database driver | +| `--database-path` | | SQLite database file | +| `--database-url` | | PostgreSQL connection URL | + +Positional arguments are treated as PURLs: + +```bash +proxy mirror pkg:npm/lodash@4.17.21 pkg:cargo/serde@1.0.0 +``` + ## Docker ### SQLite with Local Storage diff --git a/go.mod b/go.mod index 805edf0..07f6c55 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/git-pkgs/proxy go 1.25.6 require ( + github.com/CycloneDX/cyclonedx-go v0.10.0 github.com/git-pkgs/archives v0.2.2 github.com/git-pkgs/enrichment v0.2.1 github.com/git-pkgs/purl v0.1.10 @@ -15,8 +16,10 @@ require ( github.com/lib/pq v1.12.0 github.com/prometheus/client_golang v1.23.2 github.com/prometheus/client_model v0.6.2 + github.com/spdx/tools-golang v0.5.7 github.com/swaggo/swag v1.16.6 gocloud.dev v0.45.0 + golang.org/x/sync v0.20.0 google.golang.org/protobuf v1.36.11 gopkg.in/yaml.v3 v3.0.1 modernc.org/sqlite v1.47.0 @@ -52,6 +55,7 @@ require ( github.com/alfatraining/structtag v1.0.0 // indirect github.com/alingse/asasalint v0.0.11 // indirect github.com/alingse/nilnesserr v0.2.0 // indirect + github.com/anchore/go-struct-converter v0.1.0 // indirect github.com/apapsch/go-jsonmerge/v2 v2.0.0 // 
indirect github.com/ashanbrown/forbidigo/v2 v2.3.0 // indirect github.com/ashanbrown/makezero/v2 v2.1.0 // indirect @@ -277,7 +281,6 @@ require ( golang.org/x/exp/typeparams v0.0.0-20260209203927-2842357ff358 // indirect golang.org/x/mod v0.33.0 // indirect golang.org/x/net v0.51.0 // indirect - golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.34.0 // indirect golang.org/x/tools v0.42.0 // indirect @@ -293,7 +296,7 @@ require ( modernc.org/memory v1.11.0 // indirect mvdan.cc/gofumpt v0.9.2 // indirect mvdan.cc/unparam v0.0.0-20251027182757-5beb8c8f8f15 // indirect - sigs.k8s.io/yaml v1.3.0 // indirect + sigs.k8s.io/yaml v1.6.0 // indirect ) tool github.com/golangci/golangci-lint/v2/cmd/golangci-lint diff --git a/go.sum b/go.sum index 6e68fb7..41d5e06 100644 --- a/go.sum +++ b/go.sum @@ -46,6 +46,8 @@ github.com/Antonboom/testifylint v1.6.4/go.mod h1:YO33FROXX2OoUfwjz8g+gUxQXio5i9 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk= github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= +github.com/CycloneDX/cyclonedx-go v0.10.0 h1:7xyklU7YD+CUyGzSFIARG18NYLsKVn4QFg04qSsu+7Y= +github.com/CycloneDX/cyclonedx-go v0.10.0/go.mod h1:vUvbCXQsEm48OI6oOlanxstwNByXjCZ2wuleUlwGEO8= github.com/Djarvur/go-err113 v0.1.1 h1:eHfopDqXRwAi+YmCUas75ZE0+hoBHJ2GQNLYRSxao4g= github.com/Djarvur/go-err113 v0.1.1/go.mod h1:IaWJdYFLg76t2ihfflPZnM1LIQszWOsFDh2hhhAVF6k= github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.30.0 h1:sBEjpZlNHzK1voKq9695PJSX2o5NEXl7/OL3coiIY0c= @@ -85,6 +87,8 @@ github.com/alingse/asasalint v0.0.11 h1:SFwnQXJ49Kx/1GghOFz1XGqHYKp21Kq1nHad/0WQ github.com/alingse/asasalint v0.0.11/go.mod h1:nCaoMhw7a9kSJObvQyVzNTPBDbNpdocqrSP7t/cW5+I= github.com/alingse/nilnesserr v0.2.0 h1:raLem5KG7EFVb4UIDAXgrv3N2JIaffeKNtcEXkEWd/w= github.com/alingse/nilnesserr 
v0.2.0/go.mod h1:1xJPrXonEtX7wyTq8Dytns5P2hNzoWymVUIaKm4HNFg= +github.com/anchore/go-struct-converter v0.1.0 h1:2rDRssAl6mgKBSLNiVCMADgZRhoqtw9dedlWa0OhD30= +github.com/anchore/go-struct-converter v0.1.0/go.mod h1:rYqSE9HbjzpHTI74vwPvae4ZVYZd1lue2ta6xHPdblA= github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ= github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk= github.com/ashanbrown/forbidigo/v2 v2.3.0 h1:OZZDOchCgsX5gvToVtEBoV2UWbFfI6RKQTir2UZzSxo= @@ -144,6 +148,8 @@ github.com/bombsimon/wsl/v4 v4.7.0 h1:1Ilm9JBPRczjyUs6hvOPKvd7VL1Q++PL8M0SXBDf+j github.com/bombsimon/wsl/v4 v4.7.0/go.mod h1:uV/+6BkffuzSAVYD+yGyld1AChO7/EuLrCF/8xTiapg= github.com/bombsimon/wsl/v5 v5.6.0 h1:4z+/sBqC5vUmSp1O0mS+czxwH9+LKXtCWtHH9rZGQL8= github.com/bombsimon/wsl/v5 v5.6.0/go.mod h1:Uqt2EfrMj2NV8UGoN1f1Y3m0NpUVCsUdrNCdet+8LvU= +github.com/bradleyjkemp/cupaloy/v2 v2.8.0 h1:any4BmKE+jGIaMpnU8YgH/I2LPiLBufr6oMMlVBbn9M= +github.com/bradleyjkemp/cupaloy/v2 v2.8.0/go.mod h1:bm7JXdkRd4BHJk9HpwqAI8BoAY1lps46Enkdqw6aRX0= github.com/breml/bidichk v0.3.3 h1:WSM67ztRusf1sMoqH6/c4OBCUlRVTKq+CbSeo0R17sE= github.com/breml/bidichk v0.3.3/go.mod h1:ISbsut8OnjB367j5NseXEGGgO/th206dVa427kR8YTE= github.com/breml/errchkjson v0.4.1 h1:keFSS8D7A2T0haP9kzZTi7o26r7kE3vymjZNeNDRDwg= @@ -568,6 +574,8 @@ github.com/sonatard/noctx v0.4.0 h1:7MC/5Gg4SQ4lhLYR6mvOP6mQVSxCrdyiExo7atBs27o= github.com/sonatard/noctx v0.4.0/go.mod h1:64XdbzFb18XL4LporKXp8poqZtPKbCrqQ402CV+kJas= github.com/sourcegraph/go-diff v0.7.0 h1:9uLlrd5T46OXs5qpp8L/MTltk0zikUGi0sNNyCpA8G0= github.com/sourcegraph/go-diff v0.7.0/go.mod h1:iBszgVvyxdc8SFZ7gm69go2KDdt3ag071iBaWPF6cjs= +github.com/spdx/tools-golang v0.5.7 h1:+sWcKGnhwp3vLdMqPcLdA6QK679vd86cK9hQWH3AwCg= +github.com/spdx/tools-golang v0.5.7/go.mod h1:jg7w0LOpoNAw6OxKEzCoqPC2GCTj45LyTlVmXubDsYw= github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= github.com/spf13/afero 
v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= github.com/spf13/cast v1.5.0 h1:rj3WzYc11XZaIZMPKmwP96zkFEnnAmV8s6XbB2aY32w= @@ -606,6 +614,8 @@ github.com/tenntenn/modver v1.0.1 h1:2klLppGhDgzJrScMpkj9Ujy3rXPUspSjAcev9tSEBgA github.com/tenntenn/modver v1.0.1/go.mod h1:bePIyQPb7UeioSRkw3Q0XeMhYZSMx9B8ePqg6SAMGH0= github.com/tenntenn/text/transform v0.0.0-20200319021203-7eef512accb3 h1:f+jULpRQGxTSkNYKJ51yaw6ChIqO+Je8UqsTKN/cDag= github.com/tenntenn/text/transform v0.0.0-20200319021203-7eef512accb3/go.mod h1:ON8b8w4BN/kE1EOhwT0o+d62W65a6aPw1nouo9LMgyY= +github.com/terminalstatic/go-xsd-validate v0.1.6 h1:TenYeQ3eY631qNi1/cTmLH/s2slHPRKTTHT+XSHkepo= +github.com/terminalstatic/go-xsd-validate v0.1.6/go.mod h1:18lsvYFofBflqCrvo1umpABZ99+GneNTw2kEEc8UPJw= github.com/tetafro/godot v1.5.4 h1:u1ww+gqpRLiIA16yF2PV1CV1n/X3zhyezbNXC3E14Sg= github.com/tetafro/godot v1.5.4/go.mod h1:eOkMrVQurDui411nBY2FA05EYH01r14LuWY/NrVDVcU= github.com/timakin/bodyclose v0.0.0-20241222091800-1db5c5ca4d67 h1:9LPGD+jzxMlnk5r6+hJnar67cgpDIz/iyD+rfl5r2Vk= @@ -628,6 +638,12 @@ github.com/uudashr/gocognit v1.2.0 h1:3BU9aMr1xbhPlvJLSydKwdLN3tEUUrzPSSM8S4hDYR github.com/uudashr/gocognit v1.2.0/go.mod h1:k/DdKPI6XBZO1q7HgoV2juESI2/Ofj9AcHPZhBBdrTU= github.com/uudashr/iface v1.4.1 h1:J16Xl1wyNX9ofhpHmQ9h9gk5rnv2A6lX/2+APLTo0zU= github.com/uudashr/iface v1.4.1/go.mod h1:pbeBPlbuU2qkNDn0mmfrxP2X+wjPMIQAy+r1MBXSXtg= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f h1:J9EGpcZtP0E/raorCMxlFGSTBrsSlaDGf3jU/qvAE2c= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= +github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= 
+github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= github.com/xen0n/gosmopolitan v1.3.0 h1:zAZI1zefvo7gcpbCOrPSHJZJYA9ZgLfJqtKzZ5pHqQM= github.com/xen0n/gosmopolitan v1.3.0/go.mod h1:rckfr5T6o4lBtM1ga7mLGKZmLxswUoH1zxHgNXOsEt4= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= @@ -867,5 +883,5 @@ mvdan.cc/gofumpt v0.9.2 h1:zsEMWL8SVKGHNztrx6uZrXdp7AX8r421Vvp23sz7ik4= mvdan.cc/gofumpt v0.9.2/go.mod h1:iB7Hn+ai8lPvofHd9ZFGVg2GOr8sBUw1QUWjNbmIL/s= mvdan.cc/unparam v0.0.0-20251027182757-5beb8c8f8f15 h1:ssMzja7PDPJV8FStj7hq9IKiuiKhgz9ErWw+m68e7DI= mvdan.cc/unparam v0.0.0-20251027182757-5beb8c8f8f15/go.mod h1:4M5MMXl2kW6fivUT6yRGpLLPNfuGtU2Z0cPvFquGDYU= -sigs.k8s.io/yaml v1.3.0 h1:a2VclLzOGrwOHDiV8EfBGhvjHvP46CtW5j6POvhYGGo= -sigs.k8s.io/yaml v1.3.0/go.mod h1:GeOyir5tyXNByN85N/dRIT9es5UQNerPYEKK56eTBm8= +sigs.k8s.io/yaml v1.6.0 h1:G8fkbMSAFqgEFgh4b1wmtzDnioxFCUgTZhlbj5P9QYs= +sigs.k8s.io/yaml v1.6.0/go.mod h1:796bPqUfzR/0jLAl6XjHl3Ck7MiyVv8dbTdyT3/pMf4= diff --git a/internal/config/config.go b/internal/config/config.go index 3bc45af..8021783 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -83,6 +83,11 @@ type Config struct { // Cooldown configures version age filtering to mitigate supply chain attacks. Cooldown CooldownConfig `json:"cooldown" yaml:"cooldown"` + + // CacheMetadata enables caching of upstream metadata responses for offline fallback. + // When enabled, metadata is stored in the database and storage backend. + // The mirror command always enables this regardless of this setting. + CacheMetadata bool `json:"cache_metadata" yaml:"cache_metadata"` } // CooldownConfig configures version cooldown periods. 
@@ -306,6 +311,9 @@ func (c *Config) LoadFromEnv() { if v := os.Getenv("PROXY_COOLDOWN_DEFAULT"); v != "" { c.Cooldown.Default = v } + if v := os.Getenv("PROXY_CACHE_METADATA"); v != "" { + c.CacheMetadata = v == "true" || v == "1" + } } // Validate checks the configuration for errors. diff --git a/internal/database/metadata_cache_test.go b/internal/database/metadata_cache_test.go new file mode 100644 index 0000000..5701816 --- /dev/null +++ b/internal/database/metadata_cache_test.go @@ -0,0 +1,180 @@ +package database + +import ( + "database/sql" + "path/filepath" + "testing" + "time" +) + +func setupMetadataCacheDB(t *testing.T) *DB { + t.Helper() + dbPath := filepath.Join(t.TempDir(), "test.db") + db, err := Create(dbPath) + if err != nil { + t.Fatalf("Create failed: %v", err) + } + if err := db.MigrateSchema(); err != nil { + t.Fatalf("MigrateSchema failed: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + return db +} + +func TestUpsertAndGetMetadataCache(t *testing.T) { + db := setupMetadataCacheDB(t) + + entry := &MetadataCacheEntry{ + Ecosystem: testEcosystemNPM, + Name: "lodash", + StoragePath: "_metadata/npm/lodash/metadata", + ETag: sql.NullString{String: `"abc123"`, Valid: true}, + ContentType: sql.NullString{String: "application/json", Valid: true}, + Size: sql.NullInt64{Int64: 1024, Valid: true}, + FetchedAt: sql.NullTime{Time: time.Now(), Valid: true}, + } + + err := db.UpsertMetadataCache(entry) + if err != nil { + t.Fatalf("UpsertMetadataCache() error = %v", err) + } + + got, err := db.GetMetadataCache(testEcosystemNPM, "lodash") + if err != nil { + t.Fatalf("GetMetadataCache() error = %v", err) + } + if got == nil { + t.Fatal("GetMetadataCache() returned nil") + } + + if got.Ecosystem != testEcosystemNPM { + t.Errorf("ecosystem = %q, want %q", got.Ecosystem, testEcosystemNPM) + } + if got.Name != "lodash" { + t.Errorf("name = %q, want %q", got.Name, "lodash") + } + if got.StoragePath != "_metadata/npm/lodash/metadata" { + 
t.Errorf("storage_path = %q, want %q", got.StoragePath, "_metadata/npm/lodash/metadata") + } + if !got.ETag.Valid || got.ETag.String != `"abc123"` { + t.Errorf("etag = %v, want %q", got.ETag, `"abc123"`) + } + if !got.ContentType.Valid || got.ContentType.String != "application/json" { + t.Errorf("content_type = %v, want %q", got.ContentType, "application/json") + } + if !got.Size.Valid || got.Size.Int64 != 1024 { + t.Errorf("size = %v, want 1024", got.Size) + } +} + +func TestGetMetadataCacheMiss(t *testing.T) { + db := setupMetadataCacheDB(t) + + got, err := db.GetMetadataCache(testEcosystemNPM, "nonexistent") + if err != nil { + t.Fatalf("GetMetadataCache() error = %v", err) + } + if got != nil { + t.Errorf("expected nil for cache miss, got %v", got) + } +} + +func TestUpsertMetadataCacheOverwrite(t *testing.T) { + db := setupMetadataCacheDB(t) + + // First insert + entry1 := &MetadataCacheEntry{ + Ecosystem: testEcosystemNPM, + Name: "lodash", + StoragePath: "_metadata/npm/lodash/metadata", + ETag: sql.NullString{String: `"v1"`, Valid: true}, + ContentType: sql.NullString{String: "application/json", Valid: true}, + Size: sql.NullInt64{Int64: 100, Valid: true}, + FetchedAt: sql.NullTime{Time: time.Now(), Valid: true}, + } + if err := db.UpsertMetadataCache(entry1); err != nil { + t.Fatalf("first UpsertMetadataCache() error = %v", err) + } + + // Second insert (same ecosystem+name, different etag and size) + entry2 := &MetadataCacheEntry{ + Ecosystem: testEcosystemNPM, + Name: "lodash", + StoragePath: "_metadata/npm/lodash/metadata", + ETag: sql.NullString{String: `"v2"`, Valid: true}, + ContentType: sql.NullString{String: "application/json", Valid: true}, + Size: sql.NullInt64{Int64: 200, Valid: true}, + FetchedAt: sql.NullTime{Time: time.Now(), Valid: true}, + } + if err := db.UpsertMetadataCache(entry2); err != nil { + t.Fatalf("second UpsertMetadataCache() error = %v", err) + } + + got, err := db.GetMetadataCache(testEcosystemNPM, "lodash") + if err != nil { + 
t.Fatalf("GetMetadataCache() error = %v", err) + } + if got == nil { + t.Fatal("expected entry after overwrite") + } + if got.ETag.String != `"v2"` { + t.Errorf("etag = %q, want %q", got.ETag.String, `"v2"`) + } + if got.Size.Int64 != 200 { + t.Errorf("size = %d, want 200", got.Size.Int64) + } +} + +func TestUpsertMetadataCacheNullableFields(t *testing.T) { + db := setupMetadataCacheDB(t) + + entry := &MetadataCacheEntry{ + Ecosystem: "pypi", + Name: "requests", + StoragePath: "_metadata/pypi/requests/metadata", + } + + if err := db.UpsertMetadataCache(entry); err != nil { + t.Fatalf("UpsertMetadataCache() error = %v", err) + } + + got, err := db.GetMetadataCache("pypi", "requests") + if err != nil { + t.Fatalf("GetMetadataCache() error = %v", err) + } + if got == nil { + t.Fatal("expected entry") + } + if got.ETag.Valid { + t.Error("expected null etag") + } + if got.ContentType.Valid { + t.Error("expected null content_type") + } + if got.Size.Valid { + t.Error("expected null size") + } +} + +func TestMetadataCacheTableCreatedByMigration(t *testing.T) { + // Create a DB without the metadata_cache table, then migrate + dbPath := filepath.Join(t.TempDir(), "test.db") + db, err := Create(dbPath) + if err != nil { + t.Fatalf("Create failed: %v", err) + } + defer func() { _ = db.Close() }() + + // MigrateSchema should create the metadata_cache table + if err := db.MigrateSchema(); err != nil { + t.Fatalf("MigrateSchema() error = %v", err) + } + + has, err := db.HasTable("metadata_cache") + if err != nil { + t.Fatalf("HasTable() error = %v", err) + } + if !has { + t.Error("metadata_cache table should exist after migration") + } +} diff --git a/internal/database/queries.go b/internal/database/queries.go index fc6a3b3..8f48876 100644 --- a/internal/database/queries.go +++ b/internal/database/queries.go @@ -887,3 +887,64 @@ func (db *DB) CountCachedPackages(ecosystem string) (int64, error) { err = db.Get(&count, query, args...) 
return count, err } + +// Metadata cache queries + +func (db *DB) GetMetadataCache(ecosystem, name string) (*MetadataCacheEntry, error) { + var entry MetadataCacheEntry + query := db.Rebind(` + SELECT id, ecosystem, name, storage_path, etag, content_type, + size, fetched_at, created_at, updated_at + FROM metadata_cache WHERE ecosystem = ? AND name = ? + `) + err := db.Get(&entry, query, ecosystem, name) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + return &entry, nil +} + +func (db *DB) UpsertMetadataCache(entry *MetadataCacheEntry) error { + now := time.Now() + var query string + + if db.dialect == DialectPostgres { + query = ` + INSERT INTO metadata_cache (ecosystem, name, storage_path, etag, content_type, + size, fetched_at, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + ON CONFLICT(ecosystem, name) DO UPDATE SET + storage_path = EXCLUDED.storage_path, + etag = EXCLUDED.etag, + content_type = EXCLUDED.content_type, + size = EXCLUDED.size, + fetched_at = EXCLUDED.fetched_at, + updated_at = EXCLUDED.updated_at + ` + } else { + query = ` + INSERT INTO metadata_cache (ecosystem, name, storage_path, etag, content_type, + size, fetched_at, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ ON CONFLICT(ecosystem, name) DO UPDATE SET + storage_path = excluded.storage_path, + etag = excluded.etag, + content_type = excluded.content_type, + size = excluded.size, + fetched_at = excluded.fetched_at, + updated_at = excluded.updated_at + ` + } + + _, err := db.Exec(query, + entry.Ecosystem, entry.Name, entry.StoragePath, entry.ETag, + entry.ContentType, entry.Size, entry.FetchedAt, now, now, + ) + if err != nil { + return fmt.Errorf("upserting metadata cache: %w", err) + } + return nil +} diff --git a/internal/database/schema.go b/internal/database/schema.go index 233357f..91827a2 100644 --- a/internal/database/schema.go +++ b/internal/database/schema.go @@ -91,6 +91,20 @@ CREATE TABLE IF NOT EXISTS vulnerabilities ( CREATE UNIQUE INDEX IF NOT EXISTS idx_vulns_id_pkg ON vulnerabilities(vuln_id, ecosystem, package_name); CREATE INDEX IF NOT EXISTS idx_vulns_ecosystem_pkg ON vulnerabilities(ecosystem, package_name); +CREATE TABLE IF NOT EXISTS metadata_cache ( + id INTEGER PRIMARY KEY, + ecosystem TEXT NOT NULL, + name TEXT NOT NULL, + storage_path TEXT NOT NULL, + etag TEXT, + content_type TEXT, + size INTEGER, + fetched_at DATETIME, + created_at DATETIME, + updated_at DATETIME +); +CREATE UNIQUE INDEX IF NOT EXISTS idx_metadata_eco_name ON metadata_cache(ecosystem, name); + CREATE TABLE IF NOT EXISTS migrations ( name TEXT NOT NULL PRIMARY KEY, applied_at DATETIME NOT NULL @@ -176,6 +190,20 @@ CREATE TABLE IF NOT EXISTS vulnerabilities ( CREATE UNIQUE INDEX IF NOT EXISTS idx_vulns_id_pkg ON vulnerabilities(vuln_id, ecosystem, package_name); CREATE INDEX IF NOT EXISTS idx_vulns_ecosystem_pkg ON vulnerabilities(ecosystem, package_name); +CREATE TABLE IF NOT EXISTS metadata_cache ( + id SERIAL PRIMARY KEY, + ecosystem TEXT NOT NULL, + name TEXT NOT NULL, + storage_path TEXT NOT NULL, + etag TEXT, + content_type TEXT, + size BIGINT, + fetched_at TIMESTAMP, + created_at TIMESTAMP, + updated_at TIMESTAMP +); +CREATE UNIQUE INDEX IF NOT EXISTS 
idx_metadata_eco_name ON metadata_cache(ecosystem, name); + CREATE TABLE IF NOT EXISTS migrations ( name TEXT NOT NULL PRIMARY KEY, applied_at TIMESTAMP NOT NULL @@ -324,6 +352,7 @@ var migrations = []migration{ {"002_add_versions_enrichment_columns", migrateAddVersionsEnrichmentColumns}, {"003_ensure_artifacts_table", migrateEnsureArtifactsTable}, {"004_ensure_vulnerabilities_table", migrateEnsureVulnerabilitiesTable}, + {"005_ensure_metadata_cache_table", migrateEnsureMetadataCacheTable}, } // isTableNotFound returns true if the error indicates a missing table. @@ -538,5 +567,60 @@ func migrateEnsureVulnerabilitiesTable(db *DB) error { if _, err := db.Exec(vulnSchema); err != nil { return fmt.Errorf("creating vulnerabilities table: %w", err) } + + return nil +} + +func migrateEnsureMetadataCacheTable(db *DB) error { + return db.EnsureMetadataCacheTable() +} + +// EnsureMetadataCacheTable creates the metadata_cache table if it doesn't exist. +func (db *DB) EnsureMetadataCacheTable() error { + has, err := db.HasTable("metadata_cache") + if err != nil { + return fmt.Errorf("checking metadata_cache table: %w", err) + } + if has { + return nil + } + + var schema string + if db.dialect == DialectPostgres { + schema = ` + CREATE TABLE metadata_cache ( + id SERIAL PRIMARY KEY, + ecosystem TEXT NOT NULL, + name TEXT NOT NULL, + storage_path TEXT NOT NULL, + etag TEXT, + content_type TEXT, + size BIGINT, + fetched_at TIMESTAMP, + created_at TIMESTAMP, + updated_at TIMESTAMP + ); + CREATE UNIQUE INDEX IF NOT EXISTS idx_metadata_eco_name ON metadata_cache(ecosystem, name); + ` + } else { + schema = ` + CREATE TABLE metadata_cache ( + id INTEGER PRIMARY KEY, + ecosystem TEXT NOT NULL, + name TEXT NOT NULL, + storage_path TEXT NOT NULL, + etag TEXT, + content_type TEXT, + size INTEGER, + fetched_at DATETIME, + created_at DATETIME, + updated_at DATETIME + ); + CREATE UNIQUE INDEX IF NOT EXISTS idx_metadata_eco_name ON metadata_cache(ecosystem, name); + ` + } + if _, err := 
db.Exec(schema); err != nil { + return fmt.Errorf("creating metadata_cache table: %w", err) + } return nil } diff --git a/internal/database/types.go b/internal/database/types.go index f73bfb4..9d0898b 100644 --- a/internal/database/types.go +++ b/internal/database/types.go @@ -76,6 +76,20 @@ func (a *Artifact) IsCached() bool { return a.StoragePath.Valid && a.FetchedAt.Valid } +// MetadataCacheEntry represents a cached metadata blob for offline serving. +type MetadataCacheEntry struct { + ID int64 `db:"id" json:"id"` + Ecosystem string `db:"ecosystem" json:"ecosystem"` + Name string `db:"name" json:"name"` + StoragePath string `db:"storage_path" json:"storage_path"` + ETag sql.NullString `db:"etag" json:"etag,omitempty"` + ContentType sql.NullString `db:"content_type" json:"content_type,omitempty"` + Size sql.NullInt64 `db:"size" json:"size,omitempty"` + FetchedAt sql.NullTime `db:"fetched_at" json:"fetched_at,omitempty"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` +} + // Vulnerability represents a cached vulnerability record. 
type Vulnerability struct { ID int64 `db:"id" json:"id"` diff --git a/internal/handler/cargo.go b/internal/handler/cargo.go index 9602fe6..5d7810c 100644 --- a/internal/handler/cargo.go +++ b/internal/handler/cargo.go @@ -3,8 +3,8 @@ package handler import ( "bufio" "encoding/json" + "errors" "fmt" - "io" "net/http" "strings" "time" @@ -88,44 +88,27 @@ func (h *CargoHandler) handleIndex(w http.ResponseWriter, r *http.Request) { h.proxy.Logger.Info("cargo index request", "crate", name) - // Build the index path indexPath := h.buildIndexPath(name) upstreamURL := fmt.Sprintf("%s/%s", h.indexURL, indexPath) - req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil) - if err != nil { - http.Error(w, "internal error", http.StatusInternalServerError) - return - } - - resp, err := h.proxy.HTTPClient.Do(req) + body, contentType, err := h.proxy.FetchOrCacheMetadata(r.Context(), "cargo", name, upstreamURL, "text/plain") if err != nil { + if errors.Is(err, ErrUpstreamNotFound) { + http.Error(w, "not found", http.StatusNotFound) + return + } h.proxy.Logger.Error("failed to fetch upstream index", "error", err) http.Error(w, "failed to fetch from upstream", http.StatusBadGateway) return } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode == http.StatusNotFound { - http.Error(w, "not found", http.StatusNotFound) - return + if contentType == "" { + contentType = "text/plain; charset=utf-8" } - if resp.StatusCode != http.StatusOK { - http.Error(w, fmt.Sprintf("upstream returned %d", resp.StatusCode), http.StatusBadGateway) - return - } - - // Copy headers and body - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - if etag := resp.Header.Get("ETag"); etag != "" { - w.Header().Set("ETag", etag) - } - if lastMod := resp.Header.Get("Last-Modified"); lastMod != "" { - w.Header().Set("Last-Modified", lastMod) - } - - h.applyCooldownFiltering(w, resp.Body) + w.Header().Set("Content-Type", contentType) + w.WriteHeader(http.StatusOK) + 
h.applyCooldownFiltering(w, body) } type crateIndexEntry struct { @@ -134,56 +117,45 @@ type crateIndexEntry struct { PublishTime string `json:"pubtime,omitempty"` } -func (h *CargoHandler) applyCooldownFiltering(downstreamResponse io.Writer, upstreamBody io.Reader) { +func (h *CargoHandler) applyCooldownFiltering(downstreamResponse http.ResponseWriter, body []byte) { if h.proxy.Cooldown == nil || !h.proxy.Cooldown.Enabled() { - // not using cooldowns, just copy the upstream to the downstream - _, _ = io.Copy(downstreamResponse, upstreamBody) + _, _ = downstreamResponse.Write(body) return } - // create a scanner on the body of the http response - requestScanner := bufio.NewScanner(upstreamBody) + scanner := bufio.NewScanner(strings.NewReader(string(body))) - // the response is newline-delimited JSON, loop through each line - for requestScanner.Scan() { - line := requestScanner.Text() + for scanner.Scan() { + line := scanner.Text() - // decode the line var crate crateIndexEntry err := json.Unmarshal([]byte(line), &crate) if err != nil { - // if there is an error parsing this line then exclude it and move to the next entry h.proxy.Logger.Error("failed to parse json entry in index", "error", err) continue } - // parse publish time publishedAt, err := time.Parse(time.RFC3339, crate.PublishTime) if crate.PublishTime == "" || err != nil { - // publish time is empty/missing/invalid, presumably was published before pubtime was added as a field - // write line to response _, _ = downstreamResponse.Write([]byte(line + "\n")) continue } - // make PURL cratePURL := purl.MakePURLString("cargo", crate.Name, "") if !h.proxy.Cooldown.IsAllowed("cargo", cratePURL, publishedAt) { - // crate is not allowed, move to next crate h.proxy.Logger.Info("cooldown: filtering cargo version", "crate", crate.Name, "version", crate.Version, "published", crate.PublishTime) continue } - // crate passes, write to response _, _ = downstreamResponse.Write([]byte(line + "\n")) } - if err := 
requestScanner.Err(); err != nil { + if err := scanner.Err(); err != nil { h.proxy.Logger.Error("error reading index response", "error", err) } } diff --git a/internal/handler/cargo_test.go b/internal/handler/cargo_test.go index 5e7f2e4..5ce81b6 100644 --- a/internal/handler/cargo_test.go +++ b/internal/handler/cargo_test.go @@ -1,7 +1,6 @@ package handler import ( - "bytes" "encoding/json" "log/slog" "net/http" @@ -196,9 +195,9 @@ func TestCargoCooldown(t *testing.T) { proxyURL: "http://localhost:8080", } - var outputBuffer bytes.Buffer - h.applyCooldownFiltering(&outputBuffer, strings.NewReader(testInput.String())) - output := outputBuffer.String() + recorder := httptest.NewRecorder() + h.applyCooldownFiltering(recorder, []byte(testInput.String())) + output := recorder.Body.String() if output != expectedOutput.String() { t.Errorf("output = %q, want %q", output, expectedOutput.String()) diff --git a/internal/handler/composer.go b/internal/handler/composer.go index d47a0f2..2a47e81 100644 --- a/internal/handler/composer.go +++ b/internal/handler/composer.go @@ -87,34 +87,14 @@ func (h *ComposerHandler) handlePackageMetadata(w http.ResponseWriter, r *http.R h.proxy.Logger.Info("composer metadata request", "package", packageName) - // Fetch from repo.packagist.org (Composer v2 metadata) upstreamURL := fmt.Sprintf("%s/p2/%s/%s.json", h.repoURL, vendor, pkg) - req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil) - if err != nil { - http.Error(w, "failed to create request", http.StatusInternalServerError) - return - } - - resp, err := h.proxy.HTTPClient.Do(req) + body, _, err := h.proxy.FetchOrCacheMetadata(r.Context(), "composer", packageName, upstreamURL) if err != nil { h.proxy.Logger.Error("upstream request failed", "error", err) http.Error(w, "upstream request failed", http.StatusBadGateway) return } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - w.WriteHeader(resp.StatusCode) - _, _ = 
io.Copy(w, resp.Body) - return - } - - body, err := ReadMetadata(resp.Body) - if err != nil { - http.Error(w, "failed to read response", http.StatusInternalServerError) - return - } rewritten, err := h.rewriteMetadata(body) if err != nil { diff --git a/internal/handler/conan.go b/internal/handler/conan.go index b91f7c3..53f6428 100644 --- a/internal/handler/conan.go +++ b/internal/handler/conan.go @@ -43,8 +43,8 @@ func (h *ConanHandler) Routes() http.Handler { mux.HandleFunc("GET /v1/files/{name}/{version}/{user}/{channel}/{revision}/package/{pkgref}/{pkgrev}/{filename}", h.handlePackageFile) mux.HandleFunc("GET /v2/files/{name}/{version}/{user}/{channel}/{revision}/package/{pkgref}/{pkgrev}/{filename}", h.handlePackageFile) - // Proxy all other endpoints (metadata, search, etc.) - mux.HandleFunc("GET /", h.proxyUpstream) + // Proxy all other endpoints (metadata, search, etc.) with caching + mux.HandleFunc("GET /", h.proxyCached) return mux } @@ -147,6 +147,20 @@ func (h *ConanHandler) shouldCacheFile(filename string) bool { return false } +// proxyCached forwards a request with metadata caching. +func (h *ConanHandler) proxyCached(w http.ResponseWriter, r *http.Request) { + cacheKey := strings.TrimPrefix(r.URL.Path, "/") + cacheKey = strings.ReplaceAll(cacheKey, "/", "_") + if r.URL.RawQuery != "" { + cacheKey += "_" + r.URL.RawQuery + } + upstreamURL := h.upstreamURL + r.URL.Path + if r.URL.RawQuery != "" { + upstreamURL += "?" + r.URL.RawQuery + } + h.proxy.ProxyCached(w, r, upstreamURL, "conan", cacheKey, "*/*") +} + // proxyUpstream forwards a request to conan center without caching. 
func (h *ConanHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) { upstreamURL := h.upstreamURL + r.URL.Path diff --git a/internal/handler/conda.go b/internal/handler/conda.go index 46d8704..a986f01 100644 --- a/internal/handler/conda.go +++ b/internal/handler/conda.go @@ -37,7 +37,7 @@ func (h *CondaHandler) Routes() http.Handler { // Channel index (repodata) mux.HandleFunc("GET /{channel}/{arch}/repodata.json", h.handleRepodata) - mux.HandleFunc("GET /{channel}/{arch}/repodata.json.bz2", h.proxyUpstream) + mux.HandleFunc("GET /{channel}/{arch}/repodata.json.bz2", h.proxyCached) mux.HandleFunc("GET /{channel}/{arch}/current_repodata.json", h.handleRepodata) // Package downloads (cache these) @@ -127,7 +127,7 @@ func (h *CondaHandler) parseFilename(filename string) (name, version string) { // handleRepodata proxies repodata.json, applying cooldown filtering when enabled. func (h *CondaHandler) handleRepodata(w http.ResponseWriter, r *http.Request) { if h.proxy.Cooldown == nil || !h.proxy.Cooldown.Enabled() { - h.proxyUpstream(w, r) + h.proxyCached(w, r) return } @@ -232,6 +232,13 @@ func (h *CondaHandler) applyCooldownFiltering(body []byte) ([]byte, error) { return json.Marshal(repodata) } +// proxyCached forwards a metadata request with caching. +func (h *CondaHandler) proxyCached(w http.ResponseWriter, r *http.Request) { + cacheKey := strings.TrimPrefix(r.URL.Path, "/") + cacheKey = strings.ReplaceAll(cacheKey, "/", "_") + h.proxy.ProxyCached(w, r, h.upstreamURL+r.URL.Path, "conda", cacheKey, "*/*") +} + // proxyUpstream forwards a request to Anaconda without caching. 
func (h *CondaHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) { h.proxy.ProxyUpstream(w, r, h.upstreamURL+r.URL.Path, []string{"Accept-Encoding"}) diff --git a/internal/handler/cran.go b/internal/handler/cran.go index b8c3c3a..246fcaa 100644 --- a/internal/handler/cran.go +++ b/internal/handler/cran.go @@ -30,14 +30,14 @@ func (h *CRANHandler) Routes() http.Handler { mux := http.NewServeMux() // Package indexes - mux.HandleFunc("GET /src/contrib/PACKAGES", h.proxyUpstream) - mux.HandleFunc("GET /src/contrib/PACKAGES.gz", h.proxyUpstream) - mux.HandleFunc("GET /src/contrib/PACKAGES.rds", h.proxyUpstream) + mux.HandleFunc("GET /src/contrib/PACKAGES", h.proxyCached) + mux.HandleFunc("GET /src/contrib/PACKAGES.gz", h.proxyCached) + mux.HandleFunc("GET /src/contrib/PACKAGES.rds", h.proxyCached) // Binary package indexes - mux.HandleFunc("GET /bin/{platform}/contrib/{rversion}/PACKAGES", h.proxyUpstream) - mux.HandleFunc("GET /bin/{platform}/contrib/{rversion}/PACKAGES.gz", h.proxyUpstream) - mux.HandleFunc("GET /bin/{platform}/contrib/{rversion}/PACKAGES.rds", h.proxyUpstream) + mux.HandleFunc("GET /bin/{platform}/contrib/{rversion}/PACKAGES", h.proxyCached) + mux.HandleFunc("GET /bin/{platform}/contrib/{rversion}/PACKAGES.gz", h.proxyCached) + mux.HandleFunc("GET /bin/{platform}/contrib/{rversion}/PACKAGES.rds", h.proxyCached) // Source package downloads mux.HandleFunc("GET /src/contrib/{filename}", h.handleSourceDownload) @@ -150,6 +150,13 @@ func (h *CRANHandler) isBinaryPackage(filename string) bool { return strings.HasSuffix(filename, ".zip") || strings.HasSuffix(filename, ".tgz") } +// proxyCached forwards a metadata request with caching. 
+func (h *CRANHandler) proxyCached(w http.ResponseWriter, r *http.Request) { + cacheKey := strings.TrimPrefix(r.URL.Path, "/") + cacheKey = strings.ReplaceAll(cacheKey, "/", "_") + h.proxy.ProxyCached(w, r, h.upstreamURL+r.URL.Path, "cran", cacheKey, "*/*") +} + // proxyUpstream forwards a request to CRAN without caching. func (h *CRANHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) { h.proxy.ProxyUpstream(w, r, h.upstreamURL+r.URL.Path, []string{"Accept-Encoding"}) diff --git a/internal/handler/debian.go b/internal/handler/debian.go index 11a1979..db57e3b 100644 --- a/internal/handler/debian.go +++ b/internal/handler/debian.go @@ -93,7 +93,8 @@ func (h *DebianHandler) handlePackageDownload(w http.ResponseWriter, r *http.Req // handleMetadata proxies repository metadata files. -// These change frequently so we don't cache them. +// Responses are cached for offline fallback when metadata caching is enabled. func (h *DebianHandler) handleMetadata(w http.ResponseWriter, r *http.Request, path string) { - h.proxy.ProxyMetadata(w, r, fmt.Sprintf("%s/%s", h.upstreamURL, path), "debian") + cacheKey := strings.ReplaceAll(path, "/", "_") + h.proxy.ProxyCached(w, r, fmt.Sprintf("%s/%s", h.upstreamURL, path), "deb", cacheKey, "*/*") } // proxyFile proxies any file directly without caching. diff --git a/internal/handler/download_test.go b/internal/handler/download_test.go index d560b82..639e976 100644 --- a/internal/handler/download_test.go +++ b/internal/handler/download_test.go @@ -197,9 +197,8 @@ func TestGemHandler_UpstreamProxy(t *testing.T) { if string(body) != "upstream specs data" { t.Errorf("body = %q, want %q", body, "upstream specs data") } - if resp.Header.Get("X-Test") != "upstream" { - t.Errorf("missing upstream header") - } + // Metadata caching reads the response body into storage and serves it back, + // so arbitrary upstream headers are not forwarded. Content-Type is preserved.
} func TestGemHandler_CacheMiss(t *testing.T) { diff --git a/internal/handler/gem.go b/internal/handler/gem.go index 8faae54..bdb4bb9 100644 --- a/internal/handler/gem.go +++ b/internal/handler/gem.go @@ -40,12 +40,12 @@ func (h *GemHandler) Routes() http.Handler { mux.HandleFunc("GET /gems/{filename}", h.handleDownload) // Specs indexes (compressed Ruby Marshal format) - mux.HandleFunc("GET /specs.4.8.gz", h.proxyUpstream) - mux.HandleFunc("GET /latest_specs.4.8.gz", h.proxyUpstream) - mux.HandleFunc("GET /prerelease_specs.4.8.gz", h.proxyUpstream) + mux.HandleFunc("GET /specs.4.8.gz", h.proxyCached) + mux.HandleFunc("GET /latest_specs.4.8.gz", h.proxyCached) + mux.HandleFunc("GET /prerelease_specs.4.8.gz", h.proxyCached) // Compact index (bundler 2.x+) - mux.HandleFunc("GET /versions", h.proxyUpstream) + mux.HandleFunc("GET /versions", h.proxyCached) mux.HandleFunc("GET /info/{name}", h.handleCompactIndex) // Quick index @@ -107,7 +107,7 @@ func (h *GemHandler) parseGemFilename(filename string) (name, version string) { // based on cooldown when enabled. func (h *GemHandler) handleCompactIndex(w http.ResponseWriter, r *http.Request) { if h.proxy.Cooldown == nil || !h.proxy.Cooldown.Enabled() { - h.proxyUpstream(w, r) + h.proxyCached(w, r) return } @@ -288,6 +288,13 @@ func (h *GemHandler) fetchFilteredVersions(r *http.Request, name string) (map[st return filtered, nil } +// proxyCached forwards a metadata request with caching. +func (h *GemHandler) proxyCached(w http.ResponseWriter, r *http.Request) { + upstreamURL := h.upstreamURL + r.URL.Path + cacheKey := strings.TrimPrefix(r.URL.Path, "/") + h.proxy.ProxyCached(w, r, upstreamURL, "gem", cacheKey, "*/*") +} + // proxyUpstream forwards a request to rubygems.org without caching. 
func (h *GemHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) { upstreamURL := h.upstreamURL + r.URL.Path diff --git a/internal/handler/go.go b/internal/handler/go.go index dd4e17d..955a89c 100644 --- a/internal/handler/go.go +++ b/internal/handler/go.go @@ -54,18 +54,19 @@ func (h *GoHandler) handleRequest(w http.ResponseWriter, r *http.Request) { module := path[:idx] rest := path[idx+4:] // after "/@v/" + decodedMod := decodeGoModule(module) switch { case rest == "list": // GET /{module}/@v/list - list versions - h.proxyUpstream(w, r) + h.proxyCached(w, r, decodedMod+"/@v/list") case strings.HasSuffix(rest, ".info"): // GET /{module}/@v/{version}.info - version metadata - h.proxyUpstream(w, r) + h.proxyCached(w, r, decodedMod+"/@v/"+rest) case strings.HasSuffix(rest, ".mod"): // GET /{module}/@v/{version}.mod - go.mod file - h.proxyUpstream(w, r) + h.proxyCached(w, r, decodedMod+"/@v/"+rest) case strings.HasSuffix(rest, ".zip"): // GET /{module}/@v/{version}.zip - source archive (cache this) @@ -80,7 +81,8 @@ func (h *GoHandler) handleRequest(w http.ResponseWriter, r *http.Request) { // Check for @latest if strings.HasSuffix(path, "/@latest") { - h.proxyUpstream(w, r) + module := strings.TrimSuffix(path, "/@latest") + h.proxyCached(w, r, decodeGoModule(module)+"/@latest") return } @@ -111,6 +113,11 @@ func (h *GoHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) { h.proxy.ProxyUpstream(w, r, h.upstreamURL+r.URL.Path, nil) } +// proxyCached forwards a request with metadata caching. +func (h *GoHandler) proxyCached(w http.ResponseWriter, r *http.Request, cacheKey string) { + h.proxy.ProxyCached(w, r, h.upstreamURL+r.URL.Path, "golang", cacheKey, "*/*") +} + // decodeGoModule decodes an encoded module path. // In the encoding, uppercase letters are represented as "!" followed by lowercase. 
func decodeGoModule(encoded string) string { diff --git a/internal/handler/handler.go b/internal/handler/handler.go index 109eacd..799fbd3 100644 --- a/internal/handler/handler.go +++ b/internal/handler/handler.go @@ -2,8 +2,10 @@ package handler import ( + "bytes" "context" "database/sql" + "errors" "fmt" "io" "log/slog" @@ -32,6 +34,8 @@ func containsPathTraversal(path string) bool { const defaultHTTPTimeout = 30 * time.Second +const contentTypeJSON = "application/json" + // maxMetadataSize is the maximum size of upstream metadata responses (50 MB). // Package metadata (e.g. npm with many versions) can be large, but unbounded // reads risk OOM if an upstream misbehaves. @@ -45,13 +49,14 @@ func ReadMetadata(r io.Reader) ([]byte, error) { // Proxy provides shared functionality for protocol handlers. type Proxy struct { - DB *database.DB - Storage storage.Storage - Fetcher fetch.FetcherInterface - Resolver *fetch.Resolver - Logger *slog.Logger - Cooldown *cooldown.Config - HTTPClient *http.Client + DB *database.DB + Storage storage.Storage + Fetcher fetch.FetcherInterface + Resolver *fetch.Resolver + Logger *slog.Logger + Cooldown *cooldown.Config + CacheMetadata bool + HTTPClient *http.Client } // NewProxy creates a new Proxy with the given dependencies. @@ -311,33 +316,24 @@ func (p *Proxy) ProxyUpstream(w http.ResponseWriter, r *http.Request, upstreamUR _, _ = io.Copy(w, resp.Body) } -// ProxyMetadata forwards a metadata request to upstream, copying only specific response headers. -func (p *Proxy) ProxyMetadata(w http.ResponseWriter, r *http.Request, upstreamURL string, logLabel string) { - p.Logger.Debug(logLabel+" metadata request", "url", upstreamURL) - +// ProxyFile forwards a file request to upstream, copying all response headers. 
+func (p *Proxy) ProxyFile(w http.ResponseWriter, r *http.Request, upstreamURL string) { req, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, nil) if err != nil { http.Error(w, "failed to create request", http.StatusInternalServerError) return } - for _, header := range []string{"Accept", "Accept-Encoding", "If-Modified-Since", "If-None-Match"} { - if v := r.Header.Get(header); v != "" { - req.Header.Set(header, v) - } - } - resp, err := p.HTTPClient.Do(req) if err != nil { - p.Logger.Error("failed to fetch upstream metadata", "error", err) http.Error(w, "failed to fetch from upstream", http.StatusBadGateway) return } defer func() { _ = resp.Body.Close() }() - for _, header := range []string{"Content-Type", "Content-Length", "Last-Modified", "ETag"} { - if v := resp.Header.Get(header); v != "" { - w.Header().Set(header, v) + for key, values := range resp.Header { + for _, v := range values { + w.Header().Add(key, v) } } @@ -345,36 +341,186 @@ func (p *Proxy) ProxyMetadata(w http.ResponseWriter, r *http.Request, upstreamUR _, _ = io.Copy(w, resp.Body) } -// ProxyFile forwards a file request to upstream, copying all response headers. -func (p *Proxy) ProxyFile(w http.ResponseWriter, r *http.Request, upstreamURL string) { - req, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, nil) +// JSONError writes a JSON error response. +func JSONError(w http.ResponseWriter, status int, message string) { + w.Header().Set("Content-Type", contentTypeJSON) + w.WriteHeader(status) + _, _ = fmt.Fprintf(w, `{"error":%q}`, message) +} + +// ErrUpstreamNotFound indicates the upstream returned 404. +var ErrUpstreamNotFound = fmt.Errorf("upstream: not found") + +// errStale304 is returned when upstream sends 304 but the cached file is missing. +var errStale304 = fmt.Errorf("upstream returned 304 but cached file is missing") + +// metadataStoragePath builds a storage path for cached metadata. 
+func metadataStoragePath(ecosystem, cacheKey string) string { + return "_metadata/" + ecosystem + "/" + cacheKey + "/metadata" +} + +// FetchOrCacheMetadata fetches metadata from upstream with caching. +// On success it returns the raw response bytes and content type. +// If upstream fails (including a 404) and metadata caching is enabled, a previously cached copy is returned when one exists. +// cacheKey is typically the package name but can include subpath components. +// Only acceptHeaders[0] is used as the Accept header; when it is absent or empty, application/json is sent. +func (p *Proxy) FetchOrCacheMetadata(ctx context.Context, ecosystem, cacheKey, upstreamURL string, acceptHeaders ...string) ([]byte, string, error) { + if containsPathTraversal(cacheKey) { + return nil, "", fmt.Errorf("invalid cache key: %q", cacheKey) + } + + storagePath := metadataStoragePath(ecosystem, cacheKey) + + // Check for existing cache entry (for ETag revalidation) + var entry *database.MetadataCacheEntry + if p.CacheMetadata && p.DB != nil { + entry, _ = p.DB.GetMetadataCache(ecosystem, cacheKey) + } + + accept := contentTypeJSON + if len(acceptHeaders) > 0 && acceptHeaders[0] != "" { + accept = acceptHeaders[0] + } + + // Try upstream + body, contentType, etag, err := p.fetchUpstreamMetadata(ctx, upstreamURL, entry, accept) + if errors.Is(err, errStale304) { + // 304 but cached file is gone; retry without ETag + body, contentType, etag, err = p.fetchUpstreamMetadata(ctx, upstreamURL, nil, accept) + } + if err == nil { + if p.CacheMetadata { + p.cacheMetadataBlob(ctx, ecosystem, cacheKey, storagePath, body, contentType, etag) + } + return body, contentType, nil + } + + // Upstream failed -- fall back to cache if available + if !p.CacheMetadata || entry == nil { + return nil, "", fmt.Errorf("upstream failed and no cached metadata: %w", err) + } + + p.Logger.Warn("upstream metadata fetch failed, checking cache", + "ecosystem", ecosystem, "key", cacheKey, "error", err) + + cached, readErr := p.Storage.Open(ctx, entry.StoragePath) + if
readErr != nil { + return nil, "", fmt.Errorf("upstream failed and cached file missing: %w", err) + } + defer func() { _ = cached.Close() }() + + data, readErr := ReadMetadata(cached) + if readErr != nil { + return nil, "", fmt.Errorf("upstream failed and cached read error: %w", err) + } + + ct := contentTypeJSON + if entry.ContentType.Valid { + ct = entry.ContentType.String + } + p.Logger.Info("serving metadata from cache", + "ecosystem", ecosystem, "key", cacheKey) + return data, ct, nil +} + +// fetchUpstreamMetadata fetches metadata from upstream, using ETag for conditional revalidation. +// Returns the body, content type, ETag, and any error. +func (p *Proxy) fetchUpstreamMetadata(ctx context.Context, upstreamURL string, entry *database.MetadataCacheEntry, accept string) ([]byte, string, string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, upstreamURL, nil) if err != nil { - http.Error(w, "failed to create request", http.StatusInternalServerError) - return + return nil, "", "", fmt.Errorf("creating request: %w", err) + } + req.Header.Set("Accept", accept) + + if entry != nil && entry.ETag.Valid { + req.Header.Set("If-None-Match", entry.ETag.String) } resp, err := p.HTTPClient.Do(req) if err != nil { - http.Error(w, "failed to fetch from upstream", http.StatusBadGateway) - return + return nil, "", "", fmt.Errorf("fetching metadata: %w", err) } defer func() { _ = resp.Body.Close() }() - for key, values := range resp.Header { - for _, v := range values { - w.Header().Add(key, v) + // 304 Not Modified -- our cached copy is still good + if resp.StatusCode == http.StatusNotModified && entry != nil { + cached, readErr := p.Storage.Open(ctx, entry.StoragePath) + if readErr != nil { + return nil, "", "", errStale304 } + defer func() { _ = cached.Close() }() + data, readErr := ReadMetadata(cached) + if readErr != nil { + return nil, "", "", errStale304 + } + ct := contentTypeJSON + if entry.ContentType.Valid { + ct = entry.ContentType.String + 
} + return data, ct, entry.ETag.String, nil } - w.WriteHeader(resp.StatusCode) - _, _ = io.Copy(w, resp.Body) + if resp.StatusCode == http.StatusNotFound { + return nil, "", "", ErrUpstreamNotFound + } + if resp.StatusCode != http.StatusOK { + return nil, "", "", fmt.Errorf("upstream returned %d", resp.StatusCode) + } + + body, err := ReadMetadata(resp.Body) + if err != nil { + return nil, "", "", fmt.Errorf("reading response: %w", err) + } + + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + contentType = contentTypeJSON + } + + etag := resp.Header.Get("ETag") + return body, contentType, etag, nil } -// JSONError writes a JSON error response. -func JSONError(w http.ResponseWriter, status int, message string) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - _, _ = fmt.Fprintf(w, `{"error":%q}`, message) +// cacheMetadataBlob stores metadata bytes in storage and updates the database. +func (p *Proxy) cacheMetadataBlob(ctx context.Context, ecosystem, cacheKey, storagePath string, data []byte, contentType, etag string) { + if p.DB == nil || p.Storage == nil { + return + } + + size, _, err := p.Storage.Store(ctx, storagePath, bytes.NewReader(data)) + if err != nil { + p.Logger.Warn("failed to cache metadata", "ecosystem", ecosystem, "key", cacheKey, "error", err) + return + } + + _ = p.DB.UpsertMetadataCache(&database.MetadataCacheEntry{ + Ecosystem: ecosystem, + Name: cacheKey, + StoragePath: storagePath, + ETag: sql.NullString{String: etag, Valid: etag != ""}, + ContentType: sql.NullString{String: contentType, Valid: contentType != ""}, + Size: sql.NullInt64{Int64: size, Valid: true}, + FetchedAt: sql.NullTime{Time: time.Now(), Valid: true}, + }) +} + +// ProxyCached fetches metadata from upstream (with optional caching for offline fallback) +// and writes it to the response. Optional acceptHeaders specify the Accept header to send. 
+func (p *Proxy) ProxyCached(w http.ResponseWriter, r *http.Request, upstreamURL, ecosystem, cacheKey string, acceptHeaders ...string) { + body, contentType, err := p.FetchOrCacheMetadata(r.Context(), ecosystem, cacheKey, upstreamURL, acceptHeaders...) + if err != nil { + if errors.Is(err, ErrUpstreamNotFound) { + http.Error(w, "not found", http.StatusNotFound) + return + } + p.Logger.Error("metadata fetch failed", "error", err) + http.Error(w, "failed to fetch from upstream", http.StatusBadGateway) + return + } + + w.Header().Set("Content-Type", contentType) + w.WriteHeader(http.StatusOK) + _, _ = w.Write(body) } // GetOrFetchArtifactFromURL retrieves an artifact from cache or fetches from a specific URL. diff --git a/internal/handler/hex.go b/internal/handler/hex.go index 990fb55..0f0c72e 100644 --- a/internal/handler/hex.go +++ b/internal/handler/hex.go @@ -41,9 +41,9 @@ func (h *HexHandler) Routes() http.Handler { // Package tarballs (cache these) mux.HandleFunc("GET /tarballs/{filename}", h.handleDownload) - // Registry resources (proxy without caching) - mux.HandleFunc("GET /names", h.proxyUpstream) - mux.HandleFunc("GET /versions", h.proxyUpstream) + // Registry resources (cached for offline) + mux.HandleFunc("GET /names", h.proxyCached) + mux.HandleFunc("GET /versions", h.proxyCached) mux.HandleFunc("GET /packages/{name}", h.handlePackages) // Public keys @@ -102,13 +102,13 @@ const hexAPIURL = "https://hex.pm" // Hex HTTP API concurrently. func (h *HexHandler) handlePackages(w http.ResponseWriter, r *http.Request) { if h.proxy.Cooldown == nil || !h.proxy.Cooldown.Enabled() { - h.proxyUpstream(w, r) + h.proxyCached(w, r) return } name := r.PathValue("name") if name == "" { - h.proxyUpstream(w, r) + h.proxyCached(w, r) return } @@ -417,6 +417,12 @@ func extractProtobufBytes(data []byte, fieldNum protowire.Number) ([]byte, error return nil, fmt.Errorf("field %d not found", fieldNum) } +// proxyCached forwards a request with metadata caching. 
+func (h *HexHandler) proxyCached(w http.ResponseWriter, r *http.Request) { + cacheKey := strings.TrimPrefix(r.URL.Path, "/") + h.proxy.ProxyCached(w, r, h.upstreamURL+r.URL.Path, "hex", cacheKey, "*/*") +} + // proxyUpstream forwards a request to hex.pm without caching. func (h *HexHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) { h.proxy.ProxyUpstream(w, r, h.upstreamURL+r.URL.Path, []string{"Accept"}) diff --git a/internal/handler/maven.go b/internal/handler/maven.go index 79da0c0..86664a2 100644 --- a/internal/handler/maven.go +++ b/internal/handler/maven.go @@ -51,8 +51,8 @@ func (h *MavenHandler) handleRequest(w http.ResponseWriter, r *http.Request) { filename := path.Base(urlPath) if h.isMetadataFile(filename) { - // Proxy metadata without caching - h.proxyUpstream(w, r) + cacheKey := strings.ReplaceAll(urlPath, "/", "_") + h.proxy.ProxyCached(w, r, h.upstreamURL+r.URL.Path, "maven", cacheKey, "*/*") return } diff --git a/internal/handler/npm.go b/internal/handler/npm.go index e0b0566..f15d626 100644 --- a/internal/handler/npm.go +++ b/internal/handler/npm.go @@ -2,6 +2,7 @@ package handler import ( "encoding/json" + "errors" "fmt" "net/http" "net/url" @@ -65,39 +66,18 @@ func (h *NPMHandler) handlePackageMetadata(w http.ResponseWriter, r *http.Reques h.proxy.Logger.Info("npm metadata request", "package", packageName) - // Fetch metadata from upstream upstreamURL := fmt.Sprintf("%s/%s", h.upstreamURL, url.PathEscape(packageName)) - req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil) + body, _, err := h.proxy.FetchOrCacheMetadata(r.Context(), "npm", packageName, upstreamURL) if err != nil { - JSONError(w, http.StatusInternalServerError, "failed to create request") - return - } - req.Header.Set("Accept", "application/json") - - resp, err := h.proxy.HTTPClient.Do(req) - if err != nil { - h.proxy.Logger.Error("failed to fetch upstream metadata", "error", err) + if errors.Is(err, ErrUpstreamNotFound) { + 
JSONError(w, http.StatusNotFound, "package not found") + return + } + h.proxy.Logger.Error("failed to fetch npm metadata", "error", err) JSONError(w, http.StatusBadGateway, "failed to fetch from upstream") return } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode == http.StatusNotFound { - JSONError(w, http.StatusNotFound, "package not found") - return - } - if resp.StatusCode != http.StatusOK { - JSONError(w, http.StatusBadGateway, fmt.Sprintf("upstream returned %d", resp.StatusCode)) - return - } - - // Parse and rewrite tarball URLs - body, err := ReadMetadata(resp.Body) - if err != nil { - JSONError(w, http.StatusInternalServerError, "failed to read response") - return - } rewritten, err := h.rewriteMetadata(packageName, body) if err != nil { diff --git a/internal/handler/nuget.go b/internal/handler/nuget.go index 21bcc46..8bced9f 100644 --- a/internal/handler/nuget.go +++ b/internal/handler/nuget.go @@ -60,31 +60,12 @@ func (h *NuGetHandler) handleServiceIndex(w http.ResponseWriter, r *http.Request upstreamURL := h.upstreamURL + "/v3/index.json" - req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil) - if err != nil { - http.Error(w, "failed to create request", http.StatusInternalServerError) - return - } - - resp, err := h.proxy.HTTPClient.Do(req) + body, _, err := h.proxy.FetchOrCacheMetadata(r.Context(), "nuget", "_service_index", upstreamURL) if err != nil { h.proxy.Logger.Error("upstream request failed", "error", err) http.Error(w, "upstream request failed", http.StatusBadGateway) return } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - w.WriteHeader(resp.StatusCode) - _, _ = io.Copy(w, resp.Body) - return - } - - body, err := ReadMetadata(resp.Body) - if err != nil { - http.Error(w, "failed to read response", http.StatusInternalServerError) - return - } rewritten, err := h.rewriteServiceIndex(body) if err != nil { diff --git a/internal/handler/nuget_test.go 
b/internal/handler/nuget_test.go index 43d6f60..5dbb242 100644 --- a/internal/handler/nuget_test.go +++ b/internal/handler/nuget_test.go @@ -230,8 +230,9 @@ func TestNuGetHandleServiceIndexUpstreamError(t *testing.T) { w := httptest.NewRecorder() h.handleServiceIndex(w, req) - if w.Code != http.StatusInternalServerError { - t.Errorf("status = %d, want %d", w.Code, http.StatusInternalServerError) + // With metadata caching, upstream 500 is reported as 502 (bad gateway) + if w.Code != http.StatusBadGateway { + t.Errorf("status = %d, want %d", w.Code, http.StatusBadGateway) } } diff --git a/internal/handler/pub.go b/internal/handler/pub.go index b8f6207..a0b5b4c 100644 --- a/internal/handler/pub.go +++ b/internal/handler/pub.go @@ -3,7 +3,6 @@ package handler import ( "encoding/json" "fmt" - "io" "net/http" "strings" "time" @@ -89,32 +88,12 @@ func (h *PubHandler) handlePackageMetadata(w http.ResponseWriter, r *http.Reques upstreamURL := fmt.Sprintf("%s/api/packages/%s", h.upstreamURL, name) - req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil) - if err != nil { - http.Error(w, "failed to create request", http.StatusInternalServerError) - return - } - req.Header.Set("Accept", "application/json") - - resp, err := h.proxy.HTTPClient.Do(req) + body, _, err := h.proxy.FetchOrCacheMetadata(r.Context(), "pub", name, upstreamURL) if err != nil { h.proxy.Logger.Error("upstream request failed", "error", err) http.Error(w, "upstream request failed", http.StatusBadGateway) return } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - w.WriteHeader(resp.StatusCode) - _, _ = io.Copy(w, resp.Body) - return - } - - body, err := ReadMetadata(resp.Body) - if err != nil { - http.Error(w, "failed to read response", http.StatusInternalServerError) - return - } rewritten, err := h.rewriteMetadata(name, body) if err != nil { diff --git a/internal/handler/pypi.go b/internal/handler/pypi.go index aac33a7..4fc9cd5 100644 --- 
a/internal/handler/pypi.go +++ b/internal/handler/pypi.go @@ -74,33 +74,14 @@ func (h *PyPIHandler) handleSimplePackage(w http.ResponseWriter, r *http.Request h.proxy.Logger.Info("pypi simple request", "package", name) upstreamURL := fmt.Sprintf("%s/simple/%s/", h.upstreamURL, name) + cacheKey := name + "/simple" - req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil) - if err != nil { - http.Error(w, "failed to create request", http.StatusInternalServerError) - return - } - req.Header.Set("Accept", "text/html") - - resp, err := h.proxy.HTTPClient.Do(req) + body, _, err := h.proxy.FetchOrCacheMetadata(r.Context(), "pypi", cacheKey, upstreamURL, "text/html") if err != nil { h.proxy.Logger.Error("upstream request failed", "error", err) http.Error(w, "upstream request failed", http.StatusBadGateway) return } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - w.WriteHeader(resp.StatusCode) - _, _ = io.Copy(w, resp.Body) - return - } - - body, err := ReadMetadata(resp.Body) - if err != nil { - http.Error(w, "failed to read response", http.StatusInternalServerError) - return - } // When cooldown is enabled, fetch JSON metadata to get version timestamps var filteredVersions map[string]bool @@ -221,7 +202,7 @@ func (h *PyPIHandler) handleJSON(w http.ResponseWriter, r *http.Request) { h.proxy.Logger.Info("pypi json request", "package", name) upstreamURL := fmt.Sprintf("%s/pypi/%s/json", h.upstreamURL, name) - h.proxyAndRewriteJSON(w, r, upstreamURL) + h.proxyAndRewriteJSON(w, r, upstreamURL, name+"/json") } // handleVersionJSON serves the JSON API version metadata. 
@@ -237,37 +218,17 @@ func (h *PyPIHandler) handleVersionJSON(w http.ResponseWriter, r *http.Request) h.proxy.Logger.Info("pypi version json request", "package", name, "version", version) upstreamURL := fmt.Sprintf("%s/pypi/%s/%s/json", h.upstreamURL, name, version) - h.proxyAndRewriteJSON(w, r, upstreamURL) + h.proxyAndRewriteJSON(w, r, upstreamURL, name+"/"+version) } // proxyAndRewriteJSON fetches JSON metadata and rewrites download URLs. -func (h *PyPIHandler) proxyAndRewriteJSON(w http.ResponseWriter, r *http.Request, upstreamURL string) { - req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil) - if err != nil { - http.Error(w, "failed to create request", http.StatusInternalServerError) - return - } - req.Header.Set("Accept", "application/json") - - resp, err := h.proxy.HTTPClient.Do(req) +func (h *PyPIHandler) proxyAndRewriteJSON(w http.ResponseWriter, r *http.Request, upstreamURL, cacheKey string) { + body, _, err := h.proxy.FetchOrCacheMetadata(r.Context(), "pypi", cacheKey, upstreamURL) if err != nil { h.proxy.Logger.Error("upstream request failed", "error", err) http.Error(w, "upstream request failed", http.StatusBadGateway) return } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - w.WriteHeader(resp.StatusCode) - _, _ = io.Copy(w, resp.Body) - return - } - - body, err := ReadMetadata(resp.Body) - if err != nil { - http.Error(w, "failed to read response", http.StatusInternalServerError) - return - } rewritten, err := h.rewriteJSONMetadata(body) if err != nil { diff --git a/internal/handler/rpm.go b/internal/handler/rpm.go index 92da8b6..6440d0f 100644 --- a/internal/handler/rpm.go +++ b/internal/handler/rpm.go @@ -95,7 +95,8 @@ func (h *RPMHandler) handlePackageDownload(w http.ResponseWriter, r *http.Reques // handleMetadata proxies repository metadata files (repomd.xml, primary.xml.gz, etc.). // These change frequently so we don't cache them. 
func (h *RPMHandler) handleMetadata(w http.ResponseWriter, r *http.Request, path string) { - h.proxy.ProxyMetadata(w, r, fmt.Sprintf("%s/%s", h.upstreamURL, path), "rpm") + cacheKey := strings.ReplaceAll(path, "/", "_") + h.proxy.ProxyCached(w, r, fmt.Sprintf("%s/%s", h.upstreamURL, path), "rpm", cacheKey, "*/*") } // proxyFile proxies any file directly without caching. diff --git a/internal/mirror/job.go b/internal/mirror/job.go new file mode 100644 index 0000000..a0da7c5 --- /dev/null +++ b/internal/mirror/job.go @@ -0,0 +1,187 @@ +package mirror + +import ( + "context" + "crypto/rand" + "fmt" + "sync" + "time" +) + +// JobState represents the current state of a mirror job. +type JobState string + +const ( + JobStatePending JobState = "pending" + JobStateRunning JobState = "running" + JobStateComplete JobState = "complete" + JobStateFailed JobState = "failed" + JobStateCanceled JobState = "canceled" +) + +const jobTTL = 1 * time.Hour +const cleanupInterval = 5 * time.Minute //nolint:mnd // cleanup ticker + +// Job represents an async mirror operation. +type Job struct { + ID string `json:"id"` + State JobState `json:"state"` + Progress Progress `json:"progress"` + CreatedAt time.Time `json:"created_at"` + Error string `json:"error,omitempty"` + + cancel context.CancelFunc +} + +// JobRequest is the JSON body for starting a mirror job via the API. +type JobRequest struct { + PURLs []string `json:"purls,omitempty"` + Registry string `json:"registry,omitempty"` +} + +// JobStore manages in-memory mirror jobs. +type JobStore struct { + mu sync.RWMutex + jobs map[string]*Job + mirror *Mirror +} + +// NewJobStore creates a new job store. +func NewJobStore(m *Mirror) *JobStore { + return &JobStore{ + jobs: make(map[string]*Job), + mirror: m, + } +} + +// Create starts a new mirror job and returns its ID. 
+func (js *JobStore) Create(req JobRequest) (string, error) { + source, err := js.sourceFromRequest(req) + if err != nil { + return "", err + } + + id := newJobID() + ctx, cancel := context.WithCancel(context.Background()) + + job := &Job{ + ID: id, + State: JobStatePending, + CreatedAt: time.Now(), + cancel: cancel, + } + + js.mu.Lock() + js.jobs[id] = job + js.mu.Unlock() + + go js.runJob(ctx, job, source) + + return id, nil +} + +// Get returns a snapshot of a job by ID. The returned copy is safe to +// serialize without holding the lock. +func (js *JobStore) Get(id string) *Job { + js.mu.RLock() + defer js.mu.RUnlock() + job := js.jobs[id] + if job == nil { + return nil + } + snapshot := *job + snapshot.cancel = nil // don't leak cancel func + if len(job.Progress.Errors) > 0 { + snapshot.Progress.Errors = make([]MirrorError, len(job.Progress.Errors)) + copy(snapshot.Progress.Errors, job.Progress.Errors) + } + return &snapshot +} + +// Cancel cancels a running job. +func (js *JobStore) Cancel(id string) bool { + js.mu.Lock() + defer js.mu.Unlock() + + job := js.jobs[id] + if job == nil || job.cancel == nil { + return false + } + + if job.State != JobStatePending && job.State != JobStateRunning { + return false + } + + job.cancel() + job.State = JobStateCanceled + return true +} + +// Cleanup removes completed/failed/canceled jobs older than jobTTL. +func (js *JobStore) Cleanup() { + js.mu.Lock() + defer js.mu.Unlock() + for id, job := range js.jobs { + if job.State == JobStateComplete || job.State == JobStateFailed || job.State == JobStateCanceled { + if time.Since(job.CreatedAt) > jobTTL { + delete(js.jobs, id) + } + } + } +} + +// StartCleanup runs periodic cleanup of old jobs until the context is canceled. 
+func (js *JobStore) StartCleanup(ctx context.Context) { + ticker := time.NewTicker(cleanupInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + js.Cleanup() + } + } +} + +func (js *JobStore) runJob(ctx context.Context, job *Job, source Source) { + js.mu.Lock() + job.State = JobStateRunning + js.mu.Unlock() + + progress, err := js.mirror.Run(ctx, source) + + js.mu.Lock() + defer js.mu.Unlock() + + if err != nil { + job.State = JobStateFailed + job.Error = err.Error() + return + } + + job.Progress = *progress + if progress.Failed > 0 && progress.Completed == 0 { + job.State = JobStateFailed + } else { + job.State = JobStateComplete + } +} + +func (js *JobStore) sourceFromRequest(req JobRequest) (Source, error) { //nolint:ireturn // interface return is the design + switch { + case len(req.PURLs) > 0: + return &PURLSource{PURLs: req.PURLs}, nil + case req.Registry != "": + return &RegistrySource{Ecosystem: req.Registry}, nil + default: + return nil, fmt.Errorf("request must include purls or registry") + } +} + +// newJobID generates a random hex job ID. 
+func newJobID() string { + b := make([]byte, 16) //nolint:mnd // 128-bit ID + _, _ = rand.Read(b) + return fmt.Sprintf("%x", b) +} diff --git a/internal/mirror/job_test.go b/internal/mirror/job_test.go new file mode 100644 index 0000000..1159b45 --- /dev/null +++ b/internal/mirror/job_test.go @@ -0,0 +1,160 @@ +package mirror + +import ( + "testing" + "time" +) + +func TestJobStoreCreateAndGet(t *testing.T) { + m := setupTestMirror(t, 1) + js := NewJobStore(m) + + id, err := js.Create(JobRequest{ + PURLs: []string{"pkg:npm/lodash@4.17.21"}, + }) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + + if id == "" { + t.Fatal("expected non-empty job ID") + } + + // Wait for the job to start (it runs async) + time.Sleep(100 * time.Millisecond) + + job := js.Get(id) + if job == nil { + t.Fatal("Get() returned nil") + } + if job.ID != id { + t.Errorf("job ID = %q, want %q", job.ID, id) + } +} + +func TestJobStoreGetNotFound(t *testing.T) { + m := setupTestMirror(t, 1) + js := NewJobStore(m) + + job := js.Get("nonexistent") + if job != nil { + t.Errorf("expected nil for nonexistent job, got %v", job) + } +} + +func TestJobStoreCancelNotFound(t *testing.T) { + m := setupTestMirror(t, 1) + js := NewJobStore(m) + + if js.Cancel("nonexistent") { + t.Error("expected Cancel to return false for nonexistent job") + } +} + +func TestJobStoreCreateInvalidRequest(t *testing.T) { + m := setupTestMirror(t, 1) + js := NewJobStore(m) + + _, err := js.Create(JobRequest{}) + if err == nil { + t.Fatal("expected error for empty request") + } +} + +func TestJobStoreMultipleJobs(t *testing.T) { + m := setupTestMirror(t, 1) + js := NewJobStore(m) + + id1, err := js.Create(JobRequest{PURLs: []string{"pkg:npm/lodash@4.17.21"}}) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + + id2, err := js.Create(JobRequest{PURLs: []string{"pkg:cargo/serde@1.0.0"}}) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + + if id1 == id2 { + t.Error("expected different job IDs") 
+ } + + job1 := js.Get(id1) + job2 := js.Get(id2) + if job1 == nil || job2 == nil { + t.Fatal("expected both jobs to exist") + } +} + +func TestSourceFromRequestPURLs(t *testing.T) { + m := setupTestMirror(t, 1) + js := NewJobStore(m) + + source, err := js.sourceFromRequest(JobRequest{PURLs: []string{"pkg:npm/lodash@1.0.0"}}) + if err != nil { + t.Fatalf("sourceFromRequest() error = %v", err) + } + if _, ok := source.(*PURLSource); !ok { + t.Errorf("expected *PURLSource, got %T", source) + } +} + +func TestSourceFromRequestRegistry(t *testing.T) { + m := setupTestMirror(t, 1) + js := NewJobStore(m) + + source, err := js.sourceFromRequest(JobRequest{Registry: "npm"}) + if err != nil { + t.Fatalf("sourceFromRequest() error = %v", err) + } + if _, ok := source.(*RegistrySource); !ok { + t.Errorf("expected *RegistrySource, got %T", source) + } +} + +func TestJobStoreCleanup(t *testing.T) { + m := setupTestMirror(t, 1) + js := NewJobStore(m) + + // Add a completed job with old CreatedAt + js.mu.Lock() + js.jobs["old-job"] = &Job{ + ID: "old-job", + State: JobStateComplete, + CreatedAt: time.Now().Add(-2 * time.Hour), + } + js.jobs["recent-job"] = &Job{ + ID: "recent-job", + State: JobStateComplete, + CreatedAt: time.Now(), + } + js.jobs["running-job"] = &Job{ + ID: "running-job", + State: JobStateRunning, + CreatedAt: time.Now().Add(-2 * time.Hour), + } + js.mu.Unlock() + + js.Cleanup() + + if js.Get("old-job") != nil { + t.Error("expected old completed job to be cleaned up") + } + if js.Get("recent-job") == nil { + t.Error("expected recent completed job to be kept") + } + if js.Get("running-job") == nil { + t.Error("expected running job to be kept regardless of age") + } +} + +func TestNewJobIDUnique(t *testing.T) { + ids := make(map[string]bool) + for range 100 { + id := newJobID() + if ids[id] { + t.Fatalf("duplicate job ID: %s", id) + } + ids[id] = true + } +} diff --git a/internal/mirror/mirror.go b/internal/mirror/mirror.go new file mode 100644 index 
0000000..4377cf1 --- /dev/null +++ b/internal/mirror/mirror.go @@ -0,0 +1,181 @@ +// Package mirror provides selective package mirroring for pre-populating the proxy cache. +package mirror + +import ( + "context" + "fmt" + "log/slog" + "sync" + "sync/atomic" + "time" + + "github.com/git-pkgs/proxy/internal/database" + "github.com/git-pkgs/proxy/internal/handler" + "github.com/git-pkgs/proxy/internal/storage" + "golang.org/x/sync/errgroup" +) + +// Mirror pre-populates the proxy cache from various input sources. +type Mirror struct { + proxy *handler.Proxy + db *database.DB + storage storage.Storage + logger *slog.Logger + workers int +} + +// New creates a new Mirror with the given dependencies. +func New(proxy *handler.Proxy, db *database.DB, store storage.Storage, logger *slog.Logger, workers int) *Mirror { + if workers < 1 { + workers = 1 + } + return &Mirror{ + proxy: proxy, + db: db, + storage: store, + logger: logger, + workers: workers, + } +} + +// Progress tracks the state of a mirror operation. +type Progress struct { + Total int64 `json:"total"` + Completed int64 `json:"completed"` + Skipped int64 `json:"skipped"` + Failed int64 `json:"failed"` + Bytes int64 `json:"bytes"` + Errors []MirrorError `json:"errors,omitempty"` + StartedAt time.Time `json:"started_at"` + Phase string `json:"phase"` +} + +// MirrorError records a single failed mirror attempt. 
+type MirrorError struct { + Ecosystem string `json:"ecosystem"` + Name string `json:"name"` + Version string `json:"version"` + Error string `json:"error"` +} + +type progressTracker struct { + total atomic.Int64 + completed atomic.Int64 + skipped atomic.Int64 + failed atomic.Int64 + bytes atomic.Int64 + mu sync.Mutex + errors []MirrorError + startedAt time.Time + phase atomic.Value // string +} + +func newProgressTracker() *progressTracker { + pt := &progressTracker{ + startedAt: time.Now(), + } + pt.phase.Store("resolving") + return pt +} + +func (pt *progressTracker) addError(eco, name, version, err string) { + pt.mu.Lock() + pt.errors = append(pt.errors, MirrorError{ + Ecosystem: eco, + Name: name, + Version: version, + Error: err, + }) + pt.mu.Unlock() +} + +func (pt *progressTracker) snapshot() Progress { + pt.mu.Lock() + errs := make([]MirrorError, len(pt.errors)) + copy(errs, pt.errors) + pt.mu.Unlock() + + phase, _ := pt.phase.Load().(string) + return Progress{ + Total: pt.total.Load(), + Completed: pt.completed.Load(), + Skipped: pt.skipped.Load(), + Failed: pt.failed.Load(), + Bytes: pt.bytes.Load(), + Errors: errs, + StartedAt: pt.startedAt, + Phase: phase, + } +} + +// Run mirrors all packages from the source using a bounded worker pool. +// It returns the final progress when complete. 
+func (m *Mirror) Run(ctx context.Context, source Source) (*Progress, error) { + tracker := newProgressTracker() + + // Collect items from source + var items []PackageVersion + tracker.phase.Store("resolving") + err := source.Enumerate(ctx, func(pv PackageVersion) error { + items = append(items, pv) + return nil + }) + if err != nil { + return nil, fmt.Errorf("enumerating packages: %w", err) + } + + tracker.total.Store(int64(len(items))) + tracker.phase.Store("downloading") + + // Process items with bounded concurrency + g, gctx := errgroup.WithContext(ctx) + g.SetLimit(m.workers) + + for _, item := range items { + g.Go(func() error { + m.mirrorOne(gctx, item, tracker) + return nil // never fail the group; errors are tracked + }) + } + + _ = g.Wait() + + tracker.phase.Store("complete") + p := tracker.snapshot() + return &p, nil +} + +// RunDryRun enumerates what would be mirrored without downloading. +func (m *Mirror) RunDryRun(ctx context.Context, source Source) ([]PackageVersion, error) { + var items []PackageVersion + err := source.Enumerate(ctx, func(pv PackageVersion) error { + items = append(items, pv) + return nil + }) + return items, err +} + +func (m *Mirror) mirrorOne(ctx context.Context, pv PackageVersion, tracker *progressTracker) { + result, err := m.proxy.GetOrFetchArtifact(ctx, pv.Ecosystem, pv.Name, pv.Version, "") + if err != nil { + tracker.failed.Add(1) + tracker.addError(pv.Ecosystem, pv.Name, pv.Version, err.Error()) + m.logger.Warn("mirror failed", + "ecosystem", pv.Ecosystem, "name", pv.Name, "version", pv.Version, "error", err) + return + } + + _ = result.Reader.Close() + + if result.Cached { + tracker.skipped.Add(1) + m.logger.Debug("already cached", + "ecosystem", pv.Ecosystem, "name", pv.Name, "version", pv.Version) + } else { + tracker.completed.Add(1) + tracker.bytes.Add(result.Size) + m.logger.Info("mirrored", + "ecosystem", pv.Ecosystem, "name", pv.Name, "version", pv.Version, + "size", result.Size) + } +} diff --git 
a/internal/mirror/mirror_test.go b/internal/mirror/mirror_test.go new file mode 100644 index 0000000..1d7d30d --- /dev/null +++ b/internal/mirror/mirror_test.go @@ -0,0 +1,195 @@ +package mirror + +import ( + "context" + "log/slog" + "os" + "testing" + "time" + + "github.com/git-pkgs/proxy/internal/database" + "github.com/git-pkgs/proxy/internal/handler" + "github.com/git-pkgs/proxy/internal/storage" + "github.com/git-pkgs/registries/fetch" +) + +// setupTestMirror creates a Mirror with real DB and filesystem storage for integration tests. +func setupTestMirror(t *testing.T, workers int) *Mirror { + t.Helper() + + dbPath := t.TempDir() + "/test.db" + db, err := database.Create(dbPath) + if err != nil { + t.Fatalf("creating database: %v", err) + } + if err := db.MigrateSchema(); err != nil { + t.Fatalf("migrating schema: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + + storeDir := t.TempDir() + store, err := storage.OpenBucket(context.Background(), "file://"+storeDir) + if err != nil { + t.Fatalf("opening storage: %v", err) + } + + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelWarn})) + fetcher := fetch.NewFetcher() + resolver := fetch.NewResolver() + proxy := handler.NewProxy(db, store, fetcher, resolver, logger) + + return New(proxy, db, store, logger, workers) +} + +const testPackageLodash = "lodash" + +func TestMirrorRunEmptySource(t *testing.T) { + m := setupTestMirror(t, 2) + + source := &PURLSource{PURLs: []string{}} + progress, err := m.Run(context.Background(), source) + if err != nil { + t.Fatalf("Run() error = %v", err) + } + + if progress.Total != 0 { + t.Errorf("total = %d, want 0", progress.Total) + } + if progress.Phase != "complete" { + t.Errorf("phase = %q, want %q", progress.Phase, "complete") + } +} + +func TestMirrorRunDryRun(t *testing.T) { + m := setupTestMirror(t, 1) + + source := &PURLSource{ + PURLs: []string{ + "pkg:npm/lodash@4.17.21", + "pkg:cargo/serde@1.0.0", + }, + } + + items, err := 
m.RunDryRun(context.Background(), source) + if err != nil { + t.Fatalf("RunDryRun() error = %v", err) + } + + if len(items) != 2 { + t.Fatalf("got %d items, want 2", len(items)) + } + + // Dry run should not modify the database + stats, err := m.db.GetCacheStats() + if err != nil { + t.Fatalf("GetCacheStats() error = %v", err) + } + if stats.TotalArtifacts != 0 { + t.Errorf("artifacts = %d, want 0 (dry run should not cache)", stats.TotalArtifacts) + } +} + +func TestMirrorRunCanceled(t *testing.T) { + m := setupTestMirror(t, 1) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + // Use a source that produces items but they'll all fail due to canceled context + source := &PURLSource{ + PURLs: []string{"pkg:npm/lodash@4.17.21"}, + } + + progress, err := m.Run(ctx, source) + if err != nil { + t.Fatalf("Run() error = %v", err) + } + + // With a canceled context, the fetch should fail + if progress.Failed != 1 { + t.Errorf("failed = %d, want 1", progress.Failed) + } +} + +func TestProgressTrackerSnapshot(t *testing.T) { + pt := newProgressTracker() + pt.total.Store(10) + pt.completed.Store(5) + pt.skipped.Store(3) + pt.failed.Store(2) + pt.bytes.Store(1024) + pt.phase.Store("downloading") + pt.addError("npm", testPackageLodash, "4.17.21", "fetch failed") + + snap := pt.snapshot() + if snap.Total != 10 { + t.Errorf("total = %d, want 10", snap.Total) + } + if snap.Completed != 5 { + t.Errorf("completed = %d, want 5", snap.Completed) + } + if snap.Skipped != 3 { + t.Errorf("skipped = %d, want 3", snap.Skipped) + } + if snap.Failed != 2 { + t.Errorf("failed = %d, want 2", snap.Failed) + } + if snap.Bytes != 1024 { + t.Errorf("bytes = %d, want 1024", snap.Bytes) + } + if snap.Phase != "downloading" { + t.Errorf("phase = %q, want %q", snap.Phase, "downloading") + } + if len(snap.Errors) != 1 { + t.Fatalf("errors = %d, want 1", len(snap.Errors)) + } + if snap.Errors[0].Name != testPackageLodash { + t.Errorf("error name = %q, want 
%q", snap.Errors[0].Name, testPackageLodash) + } + if snap.StartedAt.IsZero() { + t.Error("started_at should not be zero") + } +} + +func TestProgressTrackerConcurrentAccess(t *testing.T) { + pt := newProgressTracker() + done := make(chan struct{}) + + for range 10 { + go func() { + pt.completed.Add(1) + pt.addError("npm", "test", "1.0.0", "error") + _ = pt.snapshot() + done <- struct{}{} + }() + } + + timeout := time.After(5 * time.Second) + for range 10 { + select { + case <-done: + case <-timeout: + t.Fatal("timed out waiting for goroutines") + } + } + + snap := pt.snapshot() + if snap.Completed != 10 { + t.Errorf("completed = %d, want 10", snap.Completed) + } + if len(snap.Errors) != 10 { + t.Errorf("errors = %d, want 10", len(snap.Errors)) + } +} + +func TestNewMirrorDefaultWorkers(t *testing.T) { + m := New(nil, nil, nil, slog.Default(), 0) + if m.workers != 1 { + t.Errorf("workers = %d, want 1 (minimum)", m.workers) + } + + m = New(nil, nil, nil, slog.Default(), -5) + if m.workers != 1 { + t.Errorf("workers = %d, want 1 (minimum)", m.workers) + } +} diff --git a/internal/mirror/registry.go b/internal/mirror/registry.go new file mode 100644 index 0000000..795e190 --- /dev/null +++ b/internal/mirror/registry.go @@ -0,0 +1,47 @@ +package mirror + +import ( + "context" + "fmt" +) + +// RegistrySource enumerates all packages in a registry for full mirroring. +type RegistrySource struct { + Ecosystem string +} + +// supportedRegistries lists ecosystems that support enumeration. 
+var supportedRegistries = map[string]bool{ + "npm": true, + "pypi": true, + "cargo": true, +} + +func (s *RegistrySource) Enumerate(ctx context.Context, fn func(PackageVersion) error) error { + if !supportedRegistries[s.Ecosystem] { + return fmt.Errorf("registry enumeration not supported for ecosystem %q; supported: npm, pypi, cargo", s.Ecosystem) + } + + switch s.Ecosystem { + case "npm": + return s.enumerateNPM(ctx, fn) + case "pypi": + return s.enumeratePyPI(ctx, fn) + case "cargo": + return s.enumerateCargo(ctx, fn) + default: + return fmt.Errorf("unsupported ecosystem: %s", s.Ecosystem) + } +} + +func (s *RegistrySource) enumerateNPM(_ context.Context, _ func(PackageVersion) error) error { + return fmt.Errorf("npm registry enumeration not yet implemented") +} + +func (s *RegistrySource) enumeratePyPI(_ context.Context, _ func(PackageVersion) error) error { + return fmt.Errorf("pypi registry enumeration not yet implemented") +} + +func (s *RegistrySource) enumerateCargo(_ context.Context, _ func(PackageVersion) error) error { + return fmt.Errorf("cargo registry enumeration not yet implemented") +} diff --git a/internal/mirror/registry_test.go b/internal/mirror/registry_test.go new file mode 100644 index 0000000..363bfea --- /dev/null +++ b/internal/mirror/registry_test.go @@ -0,0 +1,46 @@ +package mirror + +import ( + "context" + "testing" +) + +func TestRegistrySourceUnsupported(t *testing.T) { + source := &RegistrySource{Ecosystem: "golang"} + err := source.Enumerate(context.Background(), func(pv PackageVersion) error { + return nil + }) + if err == nil { + t.Fatal("expected error for unsupported ecosystem") + } +} + +func TestRegistrySourceNPMNotImplemented(t *testing.T) { + source := &RegistrySource{Ecosystem: "npm"} + err := source.Enumerate(context.Background(), func(pv PackageVersion) error { + return nil + }) + if err == nil { + t.Fatal("expected not-implemented error") + } +} + +func TestRegistrySourcePyPINotImplemented(t *testing.T) { + source := 
&RegistrySource{Ecosystem: "pypi"} + err := source.Enumerate(context.Background(), func(pv PackageVersion) error { + return nil + }) + if err == nil { + t.Fatal("expected not-implemented error") + } +} + +func TestRegistrySourceCargoNotImplemented(t *testing.T) { + source := &RegistrySource{Ecosystem: "cargo"} + err := source.Enumerate(context.Background(), func(pv PackageVersion) error { + return nil + }) + if err == nil { + t.Fatal("expected not-implemented error") + } +} diff --git a/internal/mirror/source.go b/internal/mirror/source.go new file mode 100644 index 0000000..a6fa364 --- /dev/null +++ b/internal/mirror/source.go @@ -0,0 +1,190 @@ +package mirror + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "os" + + cdx "github.com/CycloneDX/cyclonedx-go" + "github.com/git-pkgs/purl" + "github.com/git-pkgs/registries" + _ "github.com/git-pkgs/registries/all" + "github.com/spdx/tools-golang/spdx" + spdxjson "github.com/spdx/tools-golang/json" + spdxtv "github.com/spdx/tools-golang/tagvalue" +) + +// PackageVersion identifies a specific package version to mirror. +type PackageVersion struct { + Ecosystem string + Name string + Version string +} + +func (pv PackageVersion) String() string { + return fmt.Sprintf("pkg:%s/%s@%s", pv.Ecosystem, pv.Name, pv.Version) +} + +// Source produces PackageVersion items for mirroring. +type Source interface { + Enumerate(ctx context.Context, fn func(PackageVersion) error) error +} + +// PURLSource yields packages from PURL strings. +// Versioned PURLs produce a single item. Unversioned PURLs look up all versions from the registry. 
+type PURLSource struct { + PURLs []string + RegClient *registries.Client +} + +func (s *PURLSource) Enumerate(ctx context.Context, fn func(PackageVersion) error) error { + client := s.RegClient + if client == nil { + client = registries.DefaultClient() + } + + for _, purlStr := range s.PURLs { + p, err := purl.Parse(purlStr) + if err != nil { + return fmt.Errorf("parsing PURL %q: %w", purlStr, err) + } + + ecosystem := purl.PURLTypeToEcosystem(p.Type) + name := p.Name + if p.Namespace != "" { + name = p.Namespace + "/" + p.Name + } + + if p.Version != "" { + if err := fn(PackageVersion{Ecosystem: ecosystem, Name: name, Version: p.Version}); err != nil { + return err + } + continue + } + + // Unversioned: enumerate all versions + versions, err := s.fetchVersions(ctx, client, ecosystem, name) + if err != nil { + return fmt.Errorf("fetching versions for %s/%s: %w", ecosystem, name, err) + } + for _, v := range versions { + if err := fn(PackageVersion{Ecosystem: ecosystem, Name: name, Version: v}); err != nil { + return err + } + } + } + return nil +} + +func (s *PURLSource) fetchVersions(ctx context.Context, client *registries.Client, ecosystem, name string) ([]string, error) { + reg, err := registries.New(purl.EcosystemToPURLType(ecosystem), "", client) + if err != nil { + return nil, err + } + versions, err := reg.FetchVersions(ctx, name) + if err != nil { + return nil, err + } + result := make([]string, len(versions)) + for i, v := range versions { + result[i] = v.Number + } + return result, nil +} + +// SBOMSource extracts package versions from a CycloneDX or SPDX SBOM file. 
+type SBOMSource struct { + Path string + RegClient *registries.Client +} + +func (s *SBOMSource) Enumerate(ctx context.Context, fn func(PackageVersion) error) error { + purls, err := s.extractPURLs() + if err != nil { + return fmt.Errorf("reading SBOM %s: %w", s.Path, err) + } + + inner := &PURLSource{PURLs: purls, RegClient: s.RegClient} + return inner.Enumerate(ctx, fn) +} + +func (s *SBOMSource) extractPURLs() ([]string, error) { + data, err := os.ReadFile(s.Path) + if err != nil { + return nil, err + } + + // Try CycloneDX first + if purls, err := extractCycloneDXPURLs(data); err == nil && len(purls) > 0 { + return purls, nil + } + + // Try SPDX JSON + if purls, err := extractSPDXJSONPURLs(data); err == nil && len(purls) > 0 { + return purls, nil + } + + // Try SPDX tag-value + if purls, err := extractSPDXTVPURLs(data); err == nil && len(purls) > 0 { + return purls, nil + } + + return nil, fmt.Errorf("could not parse SBOM as CycloneDX or SPDX") +} + +func extractCycloneDXPURLs(data []byte) ([]string, error) { + bom := new(cdx.BOM) + if err := json.Unmarshal(data, bom); err != nil { + // Try XML + decoder := cdx.NewBOMDecoder(bytes.NewReader(data), cdx.BOMFileFormatXML) + bom = new(cdx.BOM) + if err := decoder.Decode(bom); err != nil { + return nil, err + } + } + + if bom.Components == nil { + return nil, nil + } + + var purls []string + for _, c := range *bom.Components { + if c.PackageURL != "" { + purls = append(purls, c.PackageURL) + } + } + return purls, nil +} + +func extractSPDXJSONPURLs(data []byte) ([]string, error) { + doc, err := spdxjson.Read(bytes.NewReader(data)) + if err != nil { + return nil, err + } + return extractSPDXDocPURLs(doc), nil +} + +func extractSPDXTVPURLs(data []byte) ([]string, error) { + doc, err := spdxtv.Read(bytes.NewReader(data)) + if err != nil { + return nil, err + } + return extractSPDXDocPURLs(doc), nil +} + +func extractSPDXDocPURLs(doc *spdx.Document) []string { + if doc == nil { + return nil + } + var purls []string + 
for _, pkg := range doc.Packages { + for _, ref := range pkg.PackageExternalReferences { + if ref.RefType == "purl" { + purls = append(purls, ref.Locator) + } + } + } + return purls +} diff --git a/internal/mirror/source_test.go b/internal/mirror/source_test.go new file mode 100644 index 0000000..ce53acf --- /dev/null +++ b/internal/mirror/source_test.go @@ -0,0 +1,243 @@ +package mirror + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "testing" +) + +func TestPURLSourceVersioned(t *testing.T) { + source := &PURLSource{ + PURLs: []string{ + "pkg:npm/lodash@4.17.21", + "pkg:cargo/serde@1.0.0", + "pkg:pypi/requests@2.31.0", + }, + } + + var items []PackageVersion + err := source.Enumerate(context.Background(), func(pv PackageVersion) error { + items = append(items, pv) + return nil + }) + if err != nil { + t.Fatalf("Enumerate() error = %v", err) + } + + if len(items) != 3 { + t.Fatalf("got %d items, want 3", len(items)) + } + + expected := []PackageVersion{ + {Ecosystem: "npm", Name: "lodash", Version: "4.17.21"}, + {Ecosystem: "cargo", Name: "serde", Version: "1.0.0"}, + {Ecosystem: "pypi", Name: "requests", Version: "2.31.0"}, + } + + for i, want := range expected { + got := items[i] + if got.Ecosystem != want.Ecosystem || got.Name != want.Name || got.Version != want.Version { + t.Errorf("items[%d] = %v, want %v", i, got, want) + } + } +} + +func TestPURLSourceScopedPackage(t *testing.T) { + source := &PURLSource{ + PURLs: []string{"pkg:npm/%40babel/core@7.23.0"}, + } + + var items []PackageVersion + err := source.Enumerate(context.Background(), func(pv PackageVersion) error { + items = append(items, pv) + return nil + }) + if err != nil { + t.Fatalf("Enumerate() error = %v", err) + } + + if len(items) != 1 { + t.Fatalf("got %d items, want 1", len(items)) + } + + if items[0].Name != "@babel/core" { + t.Errorf("name = %q, want %q", items[0].Name, "@babel/core") + } + if items[0].Version != "7.23.0" { + t.Errorf("version = %q, want %q", 
items[0].Version, "7.23.0") + } +} + +func TestPURLSourceInvalid(t *testing.T) { + source := &PURLSource{ + PURLs: []string{"not-a-purl"}, + } + + err := source.Enumerate(context.Background(), func(pv PackageVersion) error { + return nil + }) + if err == nil { + t.Fatal("expected error for invalid PURL") + } +} + +func TestPURLSourceCallbackError(t *testing.T) { + source := &PURLSource{ + PURLs: []string{"pkg:npm/lodash@4.17.21"}, + } + + wantErr := context.Canceled + err := source.Enumerate(context.Background(), func(pv PackageVersion) error { + return wantErr + }) + if err != wantErr { + t.Fatalf("got error %v, want %v", err, wantErr) + } +} + +func TestPackageVersionString(t *testing.T) { + pv := PackageVersion{Ecosystem: "npm", Name: "lodash", Version: "4.17.21"} + got := pv.String() + want := "pkg:npm/lodash@4.17.21" + if got != want { + t.Errorf("String() = %q, want %q", got, want) + } +} + +func TestSBOMSourceCycloneDXJSON(t *testing.T) { + bom := map[string]any{ + "bomFormat": "CycloneDX", + "specVersion": "1.4", + "components": []map[string]any{ + {"type": "library", "name": "lodash", "version": "4.17.21", "purl": "pkg:npm/lodash@4.17.21"}, + {"type": "library", "name": "serde", "version": "1.0.0", "purl": "pkg:cargo/serde@1.0.0"}, + }, + } + + path := writeTempJSON(t, bom) + source := &SBOMSource{Path: path} + + var items []PackageVersion + err := source.Enumerate(context.Background(), func(pv PackageVersion) error { + items = append(items, pv) + return nil + }) + if err != nil { + t.Fatalf("Enumerate() error = %v", err) + } + + if len(items) != 2 { + t.Fatalf("got %d items, want 2", len(items)) + } + + if items[0].Ecosystem != "npm" || items[0].Name != "lodash" || items[0].Version != "4.17.21" { + t.Errorf("items[0] = %v", items[0]) + } + if items[1].Ecosystem != "cargo" || items[1].Name != "serde" || items[1].Version != "1.0.0" { + t.Errorf("items[1] = %v", items[1]) + } +} + +func TestSBOMSourceSPDXJSON(t *testing.T) { + doc := map[string]any{ + 
"spdxVersion": "SPDX-2.3", + "dataLicense": "CC0-1.0", + "SPDXID": "SPDXRef-DOCUMENT", + "name": "test", + "documentNamespace": "https://example.com/test", + "packages": []map[string]any{ + { + "SPDXID": "SPDXRef-Package", + "name": "lodash", + "version": "4.17.21", + "downloadLocation": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz", + "externalRefs": []map[string]any{ + { + "referenceCategory": "PACKAGE-MANAGER", + "referenceType": "purl", + "referenceLocator": "pkg:npm/lodash@4.17.21", + }, + }, + }, + }, + } + + path := writeTempJSON(t, doc) + source := &SBOMSource{Path: path} + + var items []PackageVersion + err := source.Enumerate(context.Background(), func(pv PackageVersion) error { + items = append(items, pv) + return nil + }) + if err != nil { + t.Fatalf("Enumerate() error = %v", err) + } + + if len(items) != 1 { + t.Fatalf("got %d items, want 1", len(items)) + } + + if items[0].Name != "lodash" || items[0].Version != "4.17.21" { + t.Errorf("items[0] = %v", items[0]) + } +} + +func TestSBOMSourceNonexistentFile(t *testing.T) { + source := &SBOMSource{Path: "/nonexistent/sbom.json"} + + err := source.Enumerate(context.Background(), func(pv PackageVersion) error { + return nil + }) + if err == nil { + t.Fatal("expected error for nonexistent file") + } +} + +func TestSBOMSourceInvalidFormat(t *testing.T) { + path := filepath.Join(t.TempDir(), "invalid.txt") + if err := os.WriteFile(path, []byte("this is not an SBOM"), 0644); err != nil { + t.Fatal(err) + } + + source := &SBOMSource{Path: path} + err := source.Enumerate(context.Background(), func(pv PackageVersion) error { + return nil + }) + if err == nil { + t.Fatal("expected error for invalid SBOM") + } +} + +func TestSBOMSourceEmptyCycloneDX(t *testing.T) { + bom := map[string]any{ + "bomFormat": "CycloneDX", + "specVersion": "1.4", + } + path := writeTempJSON(t, bom) + + // This should fall through to SPDX parsing, which will also fail, + // resulting in an error about not being able to parse + 
source := &SBOMSource{Path: path} + err := source.Enumerate(context.Background(), func(pv PackageVersion) error { + return nil + }) + if err == nil { + t.Fatal("expected error for empty SBOM") + } +} + +func writeTempJSON(t *testing.T, v any) string { + t.Helper() + data, err := json.Marshal(v) + if err != nil { + t.Fatal(err) + } + path := filepath.Join(t.TempDir(), "sbom.json") + if err := os.WriteFile(path, data, 0644); err != nil { + t.Fatal(err) + } + return path +} diff --git a/internal/server/mirror_api.go b/internal/server/mirror_api.go new file mode 100644 index 0000000..6a6a6ca --- /dev/null +++ b/internal/server/mirror_api.go @@ -0,0 +1,70 @@ +package server + +import ( + "encoding/json" + "net/http" + + "github.com/git-pkgs/proxy/internal/mirror" + "github.com/go-chi/chi/v5" +) + +// MirrorAPIHandler handles mirror API requests. +type MirrorAPIHandler struct { + jobs *mirror.JobStore +} + +// NewMirrorAPIHandler creates a new mirror API handler. +func NewMirrorAPIHandler(jobs *mirror.JobStore) *MirrorAPIHandler { + return &MirrorAPIHandler{jobs: jobs} +} + +// HandleCreate starts a new mirror job. +func (h *MirrorAPIHandler) HandleCreate(w http.ResponseWriter, r *http.Request) { + var req mirror.JobRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + writeJSON(w, map[string]string{"error": "invalid request body"}) + return + } + + id, err := h.jobs.Create(req) + if err != nil { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + writeJSON(w, map[string]string{"error": err.Error()}) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + writeJSON(w, map[string]string{"id": id}) +} + +// HandleGet returns the status of a mirror job. 
+func (h *MirrorAPIHandler) HandleGet(w http.ResponseWriter, r *http.Request) { + id := chi.URLParam(r, "id") + job := h.jobs.Get(id) + if job == nil { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusNotFound) + writeJSON(w, map[string]string{"error": "job not found"}) + return + } + + w.Header().Set("Content-Type", "application/json") + writeJSON(w, job) +} + +// HandleCancel cancels a running mirror job. +func (h *MirrorAPIHandler) HandleCancel(w http.ResponseWriter, r *http.Request) { + id := chi.URLParam(r, "id") + if h.jobs.Cancel(id) { + w.Header().Set("Content-Type", "application/json") + writeJSON(w, map[string]string{"status": "canceled"}) + } else { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusNotFound) + writeJSON(w, map[string]string{"error": "job not found or not running"}) + } +} diff --git a/internal/server/mirror_api_test.go b/internal/server/mirror_api_test.go new file mode 100644 index 0000000..56b2c58 --- /dev/null +++ b/internal/server/mirror_api_test.go @@ -0,0 +1,163 @@ +package server + +import ( + "bytes" + "context" + "encoding/json" + "log/slog" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/git-pkgs/proxy/internal/database" + "github.com/git-pkgs/proxy/internal/handler" + "github.com/git-pkgs/proxy/internal/mirror" + "github.com/git-pkgs/proxy/internal/storage" + "github.com/git-pkgs/registries/fetch" + "github.com/go-chi/chi/v5" +) + +func setupMirrorAPI(t *testing.T) *MirrorAPIHandler { + t.Helper() + + dbPath := t.TempDir() + "/test.db" + db, err := database.Create(dbPath) + if err != nil { + t.Fatalf("creating database: %v", err) + } + if err := db.MigrateSchema(); err != nil { + t.Fatalf("migrating schema: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + + storeDir := t.TempDir() + store, err := storage.OpenBucket(context.Background(), "file://"+storeDir) + if err != nil { + t.Fatalf("opening storage: %v", err) + } + + logger := 
slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelWarn})) + fetcher := fetch.NewFetcher() + resolver := fetch.NewResolver() + proxy := handler.NewProxy(db, store, fetcher, resolver, logger) + + m := mirror.New(proxy, db, store, logger, 1) + js := mirror.NewJobStore(m) + return NewMirrorAPIHandler(js) +} + +func TestMirrorAPICreateJob(t *testing.T) { + h := setupMirrorAPI(t) + + body, _ := json.Marshal(mirror.JobRequest{ + PURLs: []string{"pkg:npm/lodash@4.17.21"}, + }) + + req := httptest.NewRequest("POST", "/api/mirror", bytes.NewReader(body)) + w := httptest.NewRecorder() + h.HandleCreate(w, req) + + if w.Code != http.StatusAccepted { + t.Errorf("status = %d, want %d", w.Code, http.StatusAccepted) + } + + var resp map[string]string + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("decoding response: %v", err) + } + if resp["id"] == "" { + t.Error("expected non-empty job ID") + } +} + +func TestMirrorAPICreateInvalidBody(t *testing.T) { + h := setupMirrorAPI(t) + + req := httptest.NewRequest("POST", "/api/mirror", bytes.NewReader([]byte("not json"))) + w := httptest.NewRecorder() + h.HandleCreate(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", w.Code, http.StatusBadRequest) + } +} + +func TestMirrorAPICreateEmptyRequest(t *testing.T) { + h := setupMirrorAPI(t) + + body, _ := json.Marshal(mirror.JobRequest{}) + req := httptest.NewRequest("POST", "/api/mirror", bytes.NewReader(body)) + w := httptest.NewRecorder() + h.HandleCreate(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", w.Code, http.StatusBadRequest) + } +} + +func TestMirrorAPIGetNotFound(t *testing.T) { + h := setupMirrorAPI(t) + + r := chi.NewRouter() + r.Get("/api/mirror/{id}", h.HandleGet) + + req := httptest.NewRequest("GET", "/api/mirror/nonexistent", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("status = %d, want 
%d", w.Code, http.StatusNotFound) + } +} + +func TestMirrorAPICancelNotFound(t *testing.T) { + h := setupMirrorAPI(t) + + r := chi.NewRouter() + r.Delete("/api/mirror/{id}", h.HandleCancel) + + req := httptest.NewRequest("DELETE", "/api/mirror/nonexistent", nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("status = %d, want %d", w.Code, http.StatusNotFound) + } +} + +func TestMirrorAPICreateAndGetJob(t *testing.T) { + h := setupMirrorAPI(t) + + // Create a job + body, _ := json.Marshal(mirror.JobRequest{ + PURLs: []string{"pkg:npm/lodash@4.17.21"}, + }) + createReq := httptest.NewRequest("POST", "/api/mirror", bytes.NewReader(body)) + createW := httptest.NewRecorder() + h.HandleCreate(createW, createReq) + + var createResp map[string]string + _ = json.NewDecoder(createW.Body).Decode(&createResp) + jobID := createResp["id"] + + // Get the job + r := chi.NewRouter() + r.Get("/api/mirror/{id}", h.HandleGet) + + getReq := httptest.NewRequest("GET", "/api/mirror/"+jobID, nil) + getW := httptest.NewRecorder() + r.ServeHTTP(getW, getReq) + + if getW.Code != http.StatusOK { + t.Errorf("status = %d, want %d", getW.Code, http.StatusOK) + } + + var job mirror.Job + if err := json.NewDecoder(getW.Body).Decode(&job); err != nil { + t.Fatalf("decoding job: %v", err) + } + if job.ID != jobID { + t.Errorf("job ID = %q, want %q", job.ID, jobID) + } +} diff --git a/internal/server/server.go b/internal/server/server.go index 429c988..60ed835 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -53,6 +53,7 @@ import ( "github.com/git-pkgs/proxy/internal/enrichment" "github.com/git-pkgs/proxy/internal/handler" "github.com/git-pkgs/proxy/internal/metrics" + "github.com/git-pkgs/proxy/internal/mirror" "github.com/git-pkgs/proxy/internal/storage" "github.com/git-pkgs/purl" "github.com/git-pkgs/registries/fetch" @@ -77,6 +78,7 @@ type Server struct { logger *slog.Logger http *http.Server templates *Templates + 
cancel context.CancelFunc } // New creates a new Server with the given configuration. @@ -144,6 +146,7 @@ func (s *Server) Start() error { } proxy := handler.NewProxy(s.db, s.storage, fetcher, resolver, s.logger) proxy.Cooldown = cd + proxy.CacheMetadata = s.cfg.CacheMetadata // Create router with Chi r := chi.NewRouter() @@ -228,6 +231,14 @@ func (s *Server) Start() error { r.Get("/api/browse/{ecosystem}/*", s.handleBrowsePath) r.Get("/api/compare/{ecosystem}/*", s.handleComparePath) + // Mirror API endpoints + mirrorSvc := mirror.New(proxy, s.db, s.storage, s.logger, 4) //nolint:mnd // default concurrency + jobStore := mirror.NewJobStore(mirrorSvc) + mirrorAPI := NewMirrorAPIHandler(jobStore) + r.Post("/api/mirror", mirrorAPI.HandleCreate) + r.Get("/api/mirror/{id}", mirrorAPI.HandleGet) + r.Delete("/api/mirror/{id}", mirrorAPI.HandleCancel) + s.http = &http.Server{ Addr: s.cfg.Listen, Handler: r, @@ -242,8 +253,11 @@ func (s *Server) Start() error { "storage", s.storage.URL(), "database", s.cfg.Database.Path) - // Start background goroutine to update cache stats metrics + // Start background goroutines + bgCtx, bgCancel := context.WithCancel(context.Background()) + s.cancel = bgCancel go s.updateCacheStatsMetrics() + go jobStore.StartCleanup(bgCtx) return s.http.ListenAndServe() } @@ -274,6 +288,10 @@ func (s *Server) updateCacheStats() { func (s *Server) Shutdown(ctx context.Context) error { s.logger.Info("shutting down server") + if s.cancel != nil { + s.cancel() + } + var errs []error if s.http != nil { diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 4c49035..be88bf6 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -320,10 +320,14 @@ func TestGoList(t *testing.T) { w := httptest.NewRecorder() ts.handler.ServeHTTP(w, req) - // The handler is mounted if we get a Go proxy error (not a generic 404) - body := w.Body.String() - if w.Code == http.StatusNotFound && !strings.Contains(body, 
"example.com") { - t.Errorf("go handler should be mounted, got status %d, body: %s", w.Code, body) + // The handler is mounted if we get a response from the proxy (404 from upstream + // or 502 from connection failure), not a chi router 404. + // With metadata caching, upstream 404 is cleanly returned as our own 404. + if w.Code == http.StatusNotFound { + body := w.Body.String() + if !strings.Contains(body, "not found") { + t.Errorf("go handler should be mounted, got status %d, body: %s", w.Code, body) + } } } From eb2e6d6e8fd2232120156e5896e8c2d5b3397a24 Mon Sep 17 00:00:00 2001 From: Andrew Nesbitt Date: Wed, 1 Apr 2026 15:40:18 +0100 Subject: [PATCH 2/5] Fix concurrency, resource, and reliability issues in mirror - Wire job contexts to server shutdown context so jobs are canceled on server stop instead of running indefinitely - Defer context cancel in runJob so completed jobs don't leak contexts - Cap error accumulation in progressTracker to 1000 entries to prevent OOM on large mirror operations with many failures - Add panic recovery in errgroup workers to prevent process crashes - Use defer for db.Close() in runMirror CLI to ensure cleanup on all error paths --- cmd/proxy/main.go | 10 +++------- internal/mirror/job.go | 25 +++++++++++++++---------- internal/mirror/job_test.go | 17 +++++++++-------- internal/mirror/mirror.go | 26 +++++++++++++++++++------- internal/server/mirror_api_test.go | 2 +- internal/server/server.go | 10 +++++----- 6 files changed, 52 insertions(+), 38 deletions(-) diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 229f2d5..ba0e9af 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -432,11 +432,12 @@ func runMirror() { fmt.Fprintf(os.Stderr, "error opening database: %v\n", err) os.Exit(1) } + defer func() { _ = db.Close() }() if err := db.MigrateSchema(); err != nil { _ = db.Close() fmt.Fprintf(os.Stderr, "error migrating schema: %v\n", err) - os.Exit(1) + os.Exit(1) //nolint:gocritic // db closed above } // Open storage @@ 
-448,7 +449,7 @@ func runMirror() { if err != nil { _ = db.Close() fmt.Fprintf(os.Stderr, "error opening storage: %v\n", err) - os.Exit(1) + os.Exit(1) //nolint:gocritic // db closed above } // Build proxy (reuses same pipeline as serve) @@ -470,7 +471,6 @@ func runMirror() { if *dryRun { items, err := m.RunDryRun(ctx, source) if err != nil { - _ = db.Close() fmt.Fprintf(os.Stderr, "error: %v\n", err) os.Exit(1) } @@ -478,19 +478,15 @@ func runMirror() { for _, item := range items { fmt.Printf(" %s\n", item) } - _ = db.Close() return } progress, err := m.Run(ctx, source) if err != nil { - _ = db.Close() fmt.Fprintf(os.Stderr, "error: %v\n", err) os.Exit(1) } - _ = db.Close() - fmt.Printf("Mirror complete: %d downloaded, %d skipped (cached), %d failed, %s total\n", progress.Completed, progress.Skipped, progress.Failed, formatSize(progress.Bytes)) diff --git a/internal/mirror/job.go b/internal/mirror/job.go index a0da7c5..c612780 100644 --- a/internal/mirror/job.go +++ b/internal/mirror/job.go @@ -41,16 +41,19 @@ type JobRequest struct { // JobStore manages in-memory mirror jobs. type JobStore struct { - mu sync.RWMutex - jobs map[string]*Job - mirror *Mirror + mu sync.RWMutex + jobs map[string]*Job + mirror *Mirror + parentCtx context.Context } -// NewJobStore creates a new job store. -func NewJobStore(m *Mirror) *JobStore { +// NewJobStore creates a new job store. The parent context is used as the base +// for all job contexts so that jobs are canceled when the server shuts down. 
+func NewJobStore(ctx context.Context, m *Mirror) *JobStore { return &JobStore{ - jobs: make(map[string]*Job), - mirror: m, + jobs: make(map[string]*Job), + mirror: m, + parentCtx: ctx, } } @@ -62,7 +65,7 @@ func (js *JobStore) Create(req JobRequest) (string, error) { } id := newJobID() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(js.parentCtx) job := &Job{ ID: id, @@ -75,7 +78,7 @@ func (js *JobStore) Create(req JobRequest) (string, error) { js.jobs[id] = job js.mu.Unlock() - go js.runJob(ctx, job, source) + go js.runJob(ctx, cancel, job, source) return id, nil } @@ -144,7 +147,9 @@ func (js *JobStore) StartCleanup(ctx context.Context) { } } -func (js *JobStore) runJob(ctx context.Context, job *Job, source Source) { +func (js *JobStore) runJob(ctx context.Context, cancel context.CancelFunc, job *Job, source Source) { + defer cancel() + js.mu.Lock() job.State = JobStateRunning js.mu.Unlock() diff --git a/internal/mirror/job_test.go b/internal/mirror/job_test.go index 1159b45..4698ec7 100644 --- a/internal/mirror/job_test.go +++ b/internal/mirror/job_test.go @@ -1,13 +1,14 @@ package mirror import ( + "context" "testing" "time" ) func TestJobStoreCreateAndGet(t *testing.T) { m := setupTestMirror(t, 1) - js := NewJobStore(m) + js := NewJobStore(context.Background(), m) id, err := js.Create(JobRequest{ PURLs: []string{"pkg:npm/lodash@4.17.21"}, @@ -34,7 +35,7 @@ func TestJobStoreCreateAndGet(t *testing.T) { func TestJobStoreGetNotFound(t *testing.T) { m := setupTestMirror(t, 1) - js := NewJobStore(m) + js := NewJobStore(context.Background(), m) job := js.Get("nonexistent") if job != nil { @@ -44,7 +45,7 @@ func TestJobStoreGetNotFound(t *testing.T) { func TestJobStoreCancelNotFound(t *testing.T) { m := setupTestMirror(t, 1) - js := NewJobStore(m) + js := NewJobStore(context.Background(), m) if js.Cancel("nonexistent") { t.Error("expected Cancel to return false for nonexistent job") @@ -53,7 +54,7 @@ func 
TestJobStoreCancelNotFound(t *testing.T) { func TestJobStoreCreateInvalidRequest(t *testing.T) { m := setupTestMirror(t, 1) - js := NewJobStore(m) + js := NewJobStore(context.Background(), m) _, err := js.Create(JobRequest{}) if err == nil { @@ -63,7 +64,7 @@ func TestJobStoreCreateInvalidRequest(t *testing.T) { func TestJobStoreMultipleJobs(t *testing.T) { m := setupTestMirror(t, 1) - js := NewJobStore(m) + js := NewJobStore(context.Background(), m) id1, err := js.Create(JobRequest{PURLs: []string{"pkg:npm/lodash@4.17.21"}}) if err != nil { @@ -88,7 +89,7 @@ func TestJobStoreMultipleJobs(t *testing.T) { func TestSourceFromRequestPURLs(t *testing.T) { m := setupTestMirror(t, 1) - js := NewJobStore(m) + js := NewJobStore(context.Background(), m) source, err := js.sourceFromRequest(JobRequest{PURLs: []string{"pkg:npm/lodash@1.0.0"}}) if err != nil { @@ -101,7 +102,7 @@ func TestSourceFromRequestPURLs(t *testing.T) { func TestSourceFromRequestRegistry(t *testing.T) { m := setupTestMirror(t, 1) - js := NewJobStore(m) + js := NewJobStore(context.Background(), m) source, err := js.sourceFromRequest(JobRequest{Registry: "npm"}) if err != nil { @@ -114,7 +115,7 @@ func TestSourceFromRequestRegistry(t *testing.T) { func TestJobStoreCleanup(t *testing.T) { m := setupTestMirror(t, 1) - js := NewJobStore(m) + js := NewJobStore(context.Background(), m) // Add a completed job with old CreatedAt js.mu.Lock() diff --git a/internal/mirror/mirror.go b/internal/mirror/mirror.go index 4377cf1..8553f2b 100644 --- a/internal/mirror/mirror.go +++ b/internal/mirror/mirror.go @@ -78,14 +78,18 @@ func newProgressTracker() *progressTracker { return pt } +const maxTrackedErrors = 1000 + func (pt *progressTracker) addError(eco, name, version, err string) { pt.mu.Lock() - pt.errors = append(pt.errors, MirrorError{ - Ecosystem: eco, - Name: name, - Version: version, - Error: err, - }) + if len(pt.errors) < maxTrackedErrors { + pt.errors = append(pt.errors, MirrorError{ + Ecosystem: eco, + Name: 
name, + Version: version, + Error: err, + }) + } pt.mu.Unlock() } @@ -132,7 +136,15 @@ func (m *Mirror) Run(ctx context.Context, source Source) (*Progress, error) { g.SetLimit(m.workers) for _, item := range items { - g.Go(func() error { + g.Go(func() (err error) { + defer func() { + if r := recover(); r != nil { + m.logger.Error("panic in mirror worker", "recover", r, + "ecosystem", item.Ecosystem, "name", item.Name, "version", item.Version) + tracker.failed.Add(1) + tracker.addError(item.Ecosystem, item.Name, item.Version, fmt.Sprintf("panic: %v", r)) + } + }() m.mirrorOne(gctx, item, tracker) return nil // never fail the group; errors are tracked }) diff --git a/internal/server/mirror_api_test.go b/internal/server/mirror_api_test.go index 56b2c58..0e84da1 100644 --- a/internal/server/mirror_api_test.go +++ b/internal/server/mirror_api_test.go @@ -43,7 +43,7 @@ func setupMirrorAPI(t *testing.T) *MirrorAPIHandler { proxy := handler.NewProxy(db, store, fetcher, resolver, logger) m := mirror.New(proxy, db, store, logger, 1) - js := mirror.NewJobStore(m) + js := mirror.NewJobStore(context.Background(), m) return NewMirrorAPIHandler(js) } diff --git a/internal/server/server.go b/internal/server/server.go index 60ed835..dc683df 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -231,9 +231,13 @@ func (s *Server) Start() error { r.Get("/api/browse/{ecosystem}/*", s.handleBrowsePath) r.Get("/api/compare/{ecosystem}/*", s.handleComparePath) + // Start background context (used by mirror jobs and cleanup) + bgCtx, bgCancel := context.WithCancel(context.Background()) + s.cancel = bgCancel + // Mirror API endpoints mirrorSvc := mirror.New(proxy, s.db, s.storage, s.logger, 4) //nolint:mnd // default concurrency - jobStore := mirror.NewJobStore(mirrorSvc) + jobStore := mirror.NewJobStore(bgCtx, mirrorSvc) mirrorAPI := NewMirrorAPIHandler(jobStore) r.Post("/api/mirror", mirrorAPI.HandleCreate) r.Get("/api/mirror/{id}", mirrorAPI.HandleGet) @@ -252,10 
+256,6 @@ func (s *Server) Start() error { "base_url", s.cfg.BaseURL, "storage", s.storage.URL(), "database", s.cfg.Database.Path) - - // Start background goroutines - bgCtx, bgCancel := context.WithCancel(context.Background()) - s.cancel = bgCancel go s.updateCacheStatsMetrics() go jobStore.StartCleanup(bgCtx) From 6feec3c455d0a52a58cb9d4be682de2a83094138 Mon Sep 17 00:00:00 2001 From: Andrew Nesbitt Date: Wed, 1 Apr 2026 16:14:07 +0100 Subject: [PATCH 3/5] Fix review issues in mirror feature - Fix race where runJob could overwrite canceled state set by Cancel() - Fix Debian ecosystem name inconsistency ("deb" -> "debian") - Stream metadata responses when caching is disabled to avoid buffering - Add metadata_cache table to initial schema strings for consistency - Gate mirror API behind mirror_api config flag (disabled by default) - Fix goconst lint in metadata_cache_test.go --- docs/configuration.md | 12 ++++++++++ internal/config/config.go | 7 ++++++ internal/handler/debian.go | 2 +- internal/handler/handler.go | 46 +++++++++++++++++++++++++++++++++++++ internal/mirror/job.go | 9 ++++++++ internal/mirror/job_test.go | 25 ++++++++++++++++++++ internal/server/server.go | 18 ++++++++------- 7 files changed, 110 insertions(+), 9 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 2ffb10f..0623965 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -225,6 +225,18 @@ Or via environment variable: `PROXY_CACHE_METADATA=true`. The `proxy mirror` command always enables metadata caching regardless of this setting. +## Mirror API + +The `/api/mirror` endpoints are disabled by default. Enable them to allow starting mirror jobs via HTTP: + +```yaml +mirror_api: true +``` + +Or via environment variable: `PROXY_MIRROR_API=true`. + +When disabled, the endpoints are not registered and return 404. + ## Mirror Command The `proxy mirror` command pre-populates the cache from various sources. 
It accepts the same storage and database flags as `serve`. diff --git a/internal/config/config.go b/internal/config/config.go index 8021783..6b82861 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -88,6 +88,10 @@ type Config struct { // When enabled, metadata is stored in the database and storage backend. // The mirror command always enables this regardless of this setting. CacheMetadata bool `json:"cache_metadata" yaml:"cache_metadata"` + + // MirrorAPI enables the /api/mirror endpoints for starting mirror jobs via HTTP. + // Disabled by default to prevent unauthenticated users from triggering downloads. + MirrorAPI bool `json:"mirror_api" yaml:"mirror_api"` } // CooldownConfig configures version cooldown periods. @@ -314,6 +318,9 @@ func (c *Config) LoadFromEnv() { if v := os.Getenv("PROXY_CACHE_METADATA"); v != "" { c.CacheMetadata = v == "true" || v == "1" } + if v := os.Getenv("PROXY_MIRROR_API"); v != "" { + c.MirrorAPI = v == "true" || v == "1" + } } // Validate checks the configuration for errors. diff --git a/internal/handler/debian.go b/internal/handler/debian.go index db57e3b..b767f6d 100644 --- a/internal/handler/debian.go +++ b/internal/handler/debian.go @@ -94,7 +94,7 @@ func (h *DebianHandler) handlePackageDownload(w http.ResponseWriter, r *http.Req // These change frequently so we don't cache them. func (h *DebianHandler) handleMetadata(w http.ResponseWriter, r *http.Request, path string) { cacheKey := strings.ReplaceAll(path, "/", "_") - h.proxy.ProxyCached(w, r, fmt.Sprintf("%s/%s", h.upstreamURL, path), "deb", cacheKey, "*/*") + h.proxy.ProxyCached(w, r, fmt.Sprintf("%s/%s", h.upstreamURL, path), "debian", cacheKey, "*/*") } // proxyFile proxies any file directly without caching. 
diff --git a/internal/handler/handler.go b/internal/handler/handler.go index 799fbd3..2092b96 100644 --- a/internal/handler/handler.go +++ b/internal/handler/handler.go @@ -506,7 +506,15 @@ func (p *Proxy) cacheMetadataBlob(ctx context.Context, ecosystem, cacheKey, stor // ProxyCached fetches metadata from upstream (with optional caching for offline fallback) // and writes it to the response. Optional acceptHeaders specify the Accept header to send. +// When metadata caching is disabled, the response is streamed directly to avoid buffering +// large metadata responses (e.g. npm packages with many versions) in memory. func (p *Proxy) ProxyCached(w http.ResponseWriter, r *http.Request, upstreamURL, ecosystem, cacheKey string, acceptHeaders ...string) { + if !p.CacheMetadata { + // Stream directly without buffering when caching is off. + p.proxyMetadataStream(w, r, upstreamURL, acceptHeaders...) + return + } + body, contentType, err := p.FetchOrCacheMetadata(r.Context(), ecosystem, cacheKey, upstreamURL, acceptHeaders...) if err != nil { if errors.Is(err, ErrUpstreamNotFound) { @@ -523,6 +531,44 @@ func (p *Proxy) ProxyCached(w http.ResponseWriter, r *http.Request, upstreamURL, _, _ = w.Write(body) } +// proxyMetadataStream forwards an upstream metadata response by streaming it to the client +// without buffering the full body in memory. 
+func (p *Proxy) proxyMetadataStream(w http.ResponseWriter, r *http.Request, upstreamURL string, acceptHeaders ...string) { + req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil) + if err != nil { + http.Error(w, "failed to create request", http.StatusInternalServerError) + return + } + + accept := contentTypeJSON + if len(acceptHeaders) > 0 && acceptHeaders[0] != "" { + accept = acceptHeaders[0] + } + req.Header.Set("Accept", accept) + + for _, header := range []string{"Accept-Encoding", "If-Modified-Since", "If-None-Match"} { + if v := r.Header.Get(header); v != "" { + req.Header.Set(header, v) + } + } + + resp, err := p.HTTPClient.Do(req) + if err != nil { + http.Error(w, "failed to fetch from upstream", http.StatusBadGateway) + return + } + defer func() { _ = resp.Body.Close() }() + + for _, header := range []string{"Content-Type", "Content-Encoding", "Content-Length", "Last-Modified", "ETag"} { // Content-Encoding must be relayed: the client's Accept-Encoding is forwarded above, so net/http leaves the body compressed + if v := resp.Header.Get(header); v != "" { + w.Header().Set(header, v) + } + } + + w.WriteHeader(resp.StatusCode) + _, _ = io.Copy(w, resp.Body) +} + // GetOrFetchArtifactFromURL retrieves an artifact from cache or fetches from a specific URL. // This is useful for registries where download URLs are determined from metadata. func (p *Proxy) GetOrFetchArtifactFromURL(ctx context.Context, ecosystem, name, version, filename, downloadURL string) (*CacheResult, error) { diff --git a/internal/mirror/job.go b/internal/mirror/job.go index c612780..b91a7f8 100644 --- a/internal/mirror/job.go +++ b/internal/mirror/job.go @@ -151,6 +151,10 @@ func (js *JobStore) runJob(ctx context.Context, cancel context.CancelFunc, job * defer cancel() js.mu.Lock() + if job.State == JobStateCanceled { + js.mu.Unlock() + return + } job.State = JobStateRunning js.mu.Unlock() @@ -159,6 +163,11 @@ func (js *JobStore) runJob(ctx context.Context, cancel context.CancelFunc, job * js.mu.Lock() defer js.mu.Unlock() + // Cancel() may have already set the state; don't overwrite it. 
+ if job.State == JobStateCanceled { + return + } + if err != nil { job.State = JobStateFailed job.Error = err.Error() diff --git a/internal/mirror/job_test.go b/internal/mirror/job_test.go index 4698ec7..90f09e4 100644 --- a/internal/mirror/job_test.go +++ b/internal/mirror/job_test.go @@ -149,6 +149,31 @@ func TestJobStoreCleanup(t *testing.T) { } } +func TestJobStoreCancelPreservesStateAfterRunJob(t *testing.T) { + m := setupTestMirror(t, 1) + js := NewJobStore(context.Background(), m) + + // Create a job that will fail (registry enumeration is not implemented) + id, err := js.Create(JobRequest{Registry: "npm"}) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + + // Cancel immediately -- the job may already be running + js.Cancel(id) + + // Wait for runJob goroutine to finish + time.Sleep(200 * time.Millisecond) + + job := js.Get(id) + if job == nil { + t.Fatal("Get() returned nil") + } + if job.State != JobStateCanceled { + t.Errorf("state = %q, want %q (cancel should not be overwritten by runJob)", job.State, JobStateCanceled) + } +} + func TestNewJobIDUnique(t *testing.T) { ids := make(map[string]bool) for range 100 { diff --git a/internal/server/server.go b/internal/server/server.go index dc683df..1c2156b 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -235,13 +235,16 @@ func (s *Server) Start() error { bgCtx, bgCancel := context.WithCancel(context.Background()) s.cancel = bgCancel - // Mirror API endpoints - mirrorSvc := mirror.New(proxy, s.db, s.storage, s.logger, 4) //nolint:mnd // default concurrency - jobStore := mirror.NewJobStore(bgCtx, mirrorSvc) - mirrorAPI := NewMirrorAPIHandler(jobStore) - r.Post("/api/mirror", mirrorAPI.HandleCreate) - r.Get("/api/mirror/{id}", mirrorAPI.HandleGet) - r.Delete("/api/mirror/{id}", mirrorAPI.HandleCancel) + // Mirror API endpoints (opt-in via mirror_api config or PROXY_MIRROR_API env) + if s.cfg.MirrorAPI { + mirrorSvc := mirror.New(proxy, s.db, s.storage, s.logger, 4) 
//nolint:mnd // default concurrency + jobStore := mirror.NewJobStore(bgCtx, mirrorSvc) + mirrorAPI := NewMirrorAPIHandler(jobStore) + r.Post("/api/mirror", mirrorAPI.HandleCreate) + r.Get("/api/mirror/{id}", mirrorAPI.HandleGet) + r.Delete("/api/mirror/{id}", mirrorAPI.HandleCancel) + go jobStore.StartCleanup(bgCtx) + } s.http = &http.Server{ Addr: s.cfg.Listen, @@ -257,7 +260,6 @@ func (s *Server) Start() error { "storage", s.storage.URL(), "database", s.cfg.Database.Path) go s.updateCacheStatsMetrics() - go jobStore.StartCleanup(bgCtx) return s.http.ListenAndServe() } From f7328c5c279f05bf52804f8a286d8d4e0137b4a0 Mon Sep 17 00:00:00 2001 From: Andrew Nesbitt Date: Wed, 1 Apr 2026 20:14:11 +0100 Subject: [PATCH 4/5] Fix metadata caching, 404 propagation, mirror progress, and registry stubs - ProxyCached now stores upstream Last-Modified in the cache and uses it (along with ETag) for conditional request handling, returning 304 when client validators match. Adds Content-Length to cached responses. - Handlers calling FetchOrCacheMetadata (pypi, composer, pub, nuget) now check for ErrUpstreamNotFound and return 404 instead of 502, matching the existing npm and cargo behavior. - Mirror jobs report live progress via a periodic callback while running, so API polls return real counts instead of zeroed progress. - Registry mirroring removed from CLI flags, API acceptance, README, and docs since every enumerator was a stub returning "not yet implemented". - Added tests for the conditional metadata path (ETag/If-None-Match, Last-Modified/If-Modified-Since, 304 responses, header omission). 
--- README.md | 5 +- cmd/proxy/main.go | 8 +- docs/configuration.md | 1 - internal/database/queries.go | 14 +-- internal/database/schema.go | 4 + internal/database/types.go | 21 ++-- internal/handler/composer.go | 5 + internal/handler/handler.go | 95 +++++++++++++----- internal/handler/handler_test.go | 164 +++++++++++++++++++++++++++++++ internal/handler/nuget.go | 5 + internal/handler/pub.go | 5 + internal/handler/pypi.go | 9 ++ internal/mirror/job.go | 12 ++- internal/mirror/job_test.go | 15 ++- internal/mirror/mirror.go | 39 +++++++- internal/mirror/registry.go | 37 +------ 16 files changed, 342 insertions(+), 97 deletions(-) diff --git a/README.md b/README.md index e4aa31f..40ee5b5 100644 --- a/README.md +++ b/README.md @@ -474,9 +474,6 @@ proxy mirror pkg:npm/lodash # Mirror from a CycloneDX or SPDX SBOM proxy mirror --sbom sbom.cdx.json -# Full registry mirror (npm, pypi, cargo supported) -proxy mirror --registry npm - # Preview what would be mirrored proxy mirror --dry-run pkg:npm/lodash @@ -579,7 +576,7 @@ Recently cached: | Endpoint | Description | |----------|-------------| -| `POST /api/mirror` | Start a mirror job (JSON body with `purls` or `registry`) | +| `POST /api/mirror` | Start a mirror job (JSON body with `purls`) | | `GET /api/mirror/{id}` | Get job status and progress | | `DELETE /api/mirror/{id}` | Cancel a running job | diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index ba0e9af..40d5d34 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -358,7 +358,6 @@ func runMirror() { databasePath := fs.String("database-path", "", "Path to SQLite database file") databaseURL := fs.String("database-url", "", "PostgreSQL connection URL") sbomPath := fs.String("sbom", "", "Path to CycloneDX or SPDX SBOM file") - registry := fs.String("registry", "", "Ecosystem name for full registry mirror") concurrency := fs.Int("concurrency", 4, "Number of parallel downloads") //nolint:mnd // default concurrency dryRun := fs.Bool("dry-run", false, "Show what 
would be mirrored without downloading") @@ -368,8 +367,7 @@ func runMirror() { fmt.Fprintf(os.Stderr, "Examples:\n") fmt.Fprintf(os.Stderr, " proxy mirror pkg:npm/lodash@4.17.21\n") fmt.Fprintf(os.Stderr, " proxy mirror --sbom sbom.cdx.json\n") - fmt.Fprintf(os.Stderr, " proxy mirror pkg:npm/lodash # all versions\n") - fmt.Fprintf(os.Stderr, " proxy mirror --registry npm\n\n") + fmt.Fprintf(os.Stderr, " proxy mirror pkg:npm/lodash # all versions\n\n") fmt.Fprintf(os.Stderr, "Flags:\n") fs.PrintDefaults() } @@ -382,12 +380,10 @@ func runMirror() { switch { case *sbomPath != "": source = &mirror.SBOMSource{Path: *sbomPath} - case *registry != "": - source = &mirror.RegistrySource{Ecosystem: *registry} case len(purls) > 0: source = &mirror.PURLSource{PURLs: purls} default: - fmt.Fprintf(os.Stderr, "error: provide PURLs, --sbom, or --registry\n") + fmt.Fprintf(os.Stderr, "error: provide PURLs or --sbom\n") fs.Usage() os.Exit(1) } diff --git a/docs/configuration.md b/docs/configuration.md index 0623965..16e71bb 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -244,7 +244,6 @@ The `proxy mirror` command pre-populates the cache from various sources. 
It acce | Flag | Default | Description | |------|---------|-------------| | `--sbom` | | Path to CycloneDX or SPDX SBOM file | -| `--registry` | | Ecosystem name for full registry mirror | | `--concurrency` | `4` | Number of parallel downloads | | `--dry-run` | `false` | Show what would be mirrored without downloading | | `--config` | | Path to configuration file | diff --git a/internal/database/queries.go b/internal/database/queries.go index 8f48876..5d95596 100644 --- a/internal/database/queries.go +++ b/internal/database/queries.go @@ -894,7 +894,7 @@ func (db *DB) GetMetadataCache(ecosystem, name string) (*MetadataCacheEntry, err var entry MetadataCacheEntry query := db.Rebind(` SELECT id, ecosystem, name, storage_path, etag, content_type, - size, fetched_at, created_at, updated_at + size, last_modified, fetched_at, created_at, updated_at FROM metadata_cache WHERE ecosystem = ? AND name = ? `) err := db.Get(&entry, query, ecosystem, name) @@ -914,26 +914,28 @@ func (db *DB) UpsertMetadataCache(entry *MetadataCacheEntry) error { if db.dialect == DialectPostgres { query = ` INSERT INTO metadata_cache (ecosystem, name, storage_path, etag, content_type, - size, fetched_at, created_at, updated_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + size, last_modified, fetched_at, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) ON CONFLICT(ecosystem, name) DO UPDATE SET storage_path = EXCLUDED.storage_path, etag = EXCLUDED.etag, content_type = EXCLUDED.content_type, size = EXCLUDED.size, + last_modified = EXCLUDED.last_modified, fetched_at = EXCLUDED.fetched_at, updated_at = EXCLUDED.updated_at ` } else { query = ` INSERT INTO metadata_cache (ecosystem, name, storage_path, etag, content_type, - size, fetched_at, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + size, last_modified, fetched_at, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
ON CONFLICT(ecosystem, name) DO UPDATE SET storage_path = excluded.storage_path, etag = excluded.etag, content_type = excluded.content_type, size = excluded.size, + last_modified = excluded.last_modified, fetched_at = excluded.fetched_at, updated_at = excluded.updated_at ` @@ -941,7 +943,7 @@ func (db *DB) UpsertMetadataCache(entry *MetadataCacheEntry) error { _, err := db.Exec(query, entry.Ecosystem, entry.Name, entry.StoragePath, entry.ETag, - entry.ContentType, entry.Size, entry.FetchedAt, now, now, + entry.ContentType, entry.Size, entry.LastModified, entry.FetchedAt, now, now, ) if err != nil { return fmt.Errorf("upserting metadata cache: %w", err) diff --git a/internal/database/schema.go b/internal/database/schema.go index 91827a2..e6f284f 100644 --- a/internal/database/schema.go +++ b/internal/database/schema.go @@ -99,6 +99,7 @@ CREATE TABLE IF NOT EXISTS metadata_cache ( etag TEXT, content_type TEXT, size INTEGER, + last_modified DATETIME, fetched_at DATETIME, created_at DATETIME, updated_at DATETIME @@ -198,6 +199,7 @@ CREATE TABLE IF NOT EXISTS metadata_cache ( etag TEXT, content_type TEXT, size BIGINT, + last_modified TIMESTAMP, fetched_at TIMESTAMP, created_at TIMESTAMP, updated_at TIMESTAMP @@ -596,6 +598,7 @@ func (db *DB) EnsureMetadataCacheTable() error { etag TEXT, content_type TEXT, size BIGINT, + last_modified TIMESTAMP, fetched_at TIMESTAMP, created_at TIMESTAMP, updated_at TIMESTAMP @@ -612,6 +615,7 @@ func (db *DB) EnsureMetadataCacheTable() error { etag TEXT, content_type TEXT, size INTEGER, + last_modified DATETIME, fetched_at DATETIME, created_at DATETIME, updated_at DATETIME diff --git a/internal/database/types.go b/internal/database/types.go index 9d0898b..f5b718e 100644 --- a/internal/database/types.go +++ b/internal/database/types.go @@ -78,16 +78,17 @@ func (a *Artifact) IsCached() bool { // MetadataCacheEntry represents a cached metadata blob for offline serving. 
type MetadataCacheEntry struct { - ID int64 `db:"id" json:"id"` - Ecosystem string `db:"ecosystem" json:"ecosystem"` - Name string `db:"name" json:"name"` - StoragePath string `db:"storage_path" json:"storage_path"` - ETag sql.NullString `db:"etag" json:"etag,omitempty"` - ContentType sql.NullString `db:"content_type" json:"content_type,omitempty"` - Size sql.NullInt64 `db:"size" json:"size,omitempty"` - FetchedAt sql.NullTime `db:"fetched_at" json:"fetched_at,omitempty"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + ID int64 `db:"id" json:"id"` + Ecosystem string `db:"ecosystem" json:"ecosystem"` + Name string `db:"name" json:"name"` + StoragePath string `db:"storage_path" json:"storage_path"` + ETag sql.NullString `db:"etag" json:"etag,omitempty"` + ContentType sql.NullString `db:"content_type" json:"content_type,omitempty"` + Size sql.NullInt64 `db:"size" json:"size,omitempty"` + LastModified sql.NullTime `db:"last_modified" json:"last_modified,omitempty"` + FetchedAt sql.NullTime `db:"fetched_at" json:"fetched_at,omitempty"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` } // Vulnerability represents a cached vulnerability record. 
diff --git a/internal/handler/composer.go b/internal/handler/composer.go index 2a47e81..7936401 100644 --- a/internal/handler/composer.go +++ b/internal/handler/composer.go @@ -2,6 +2,7 @@ package handler import ( "encoding/json" + "errors" "fmt" "io" "net/http" @@ -91,6 +92,10 @@ func (h *ComposerHandler) handlePackageMetadata(w http.ResponseWriter, r *http.R body, _, err := h.proxy.FetchOrCacheMetadata(r.Context(), "composer", packageName, upstreamURL) if err != nil { + if errors.Is(err, ErrUpstreamNotFound) { + http.Error(w, "not found", http.StatusNotFound) + return + } h.proxy.Logger.Error("upstream request failed", "error", err) http.Error(w, "upstream request failed", http.StatusBadGateway) return diff --git a/internal/handler/handler.go b/internal/handler/handler.go index 2092b96..2edd648 100644 --- a/internal/handler/handler.go +++ b/internal/handler/handler.go @@ -10,6 +10,7 @@ import ( "io" "log/slog" "net/http" + "strconv" "strings" "time" @@ -383,14 +384,14 @@ func (p *Proxy) FetchOrCacheMetadata(ctx context.Context, ecosystem, cacheKey, u } // Try upstream - body, contentType, etag, err := p.fetchUpstreamMetadata(ctx, upstreamURL, entry, accept) + body, contentType, etag, lastModified, err := p.fetchUpstreamMetadata(ctx, upstreamURL, entry, accept) if errors.Is(err, errStale304) { // 304 but cached file is gone; retry without ETag - body, contentType, etag, err = p.fetchUpstreamMetadata(ctx, upstreamURL, nil, accept) + body, contentType, etag, lastModified, err = p.fetchUpstreamMetadata(ctx, upstreamURL, nil, accept) } if err == nil { if p.CacheMetadata { - p.cacheMetadataBlob(ctx, ecosystem, cacheKey, storagePath, body, contentType, etag) + p.cacheMetadataBlob(ctx, ecosystem, cacheKey, storagePath, body, contentType, etag, lastModified) } return body, contentType, nil } @@ -424,11 +425,13 @@ func (p *Proxy) FetchOrCacheMetadata(ctx context.Context, ecosystem, cacheKey, u } // fetchUpstreamMetadata fetches metadata from upstream, using ETag for 
conditional revalidation. -// Returns the body, content type, ETag, and any error. -func (p *Proxy) fetchUpstreamMetadata(ctx context.Context, upstreamURL string, entry *database.MetadataCacheEntry, accept string) ([]byte, string, string, error) { +// Returns the body, content type, ETag, upstream Last-Modified time, and any error. +func (p *Proxy) fetchUpstreamMetadata(ctx context.Context, upstreamURL string, entry *database.MetadataCacheEntry, accept string) ([]byte, string, string, time.Time, error) { + var zeroTime time.Time + req, err := http.NewRequestWithContext(ctx, http.MethodGet, upstreamURL, nil) if err != nil { - return nil, "", "", fmt.Errorf("creating request: %w", err) + return nil, "", "", zeroTime, fmt.Errorf("creating request: %w", err) } req.Header.Set("Accept", accept) @@ -438,7 +441,7 @@ func (p *Proxy) fetchUpstreamMetadata(ctx context.Context, upstreamURL string, e resp, err := p.HTTPClient.Do(req) if err != nil { - return nil, "", "", fmt.Errorf("fetching metadata: %w", err) + return nil, "", "", zeroTime, fmt.Errorf("fetching metadata: %w", err) } defer func() { _ = resp.Body.Close() }() @@ -446,30 +449,34 @@ func (p *Proxy) fetchUpstreamMetadata(ctx context.Context, upstreamURL string, e if resp.StatusCode == http.StatusNotModified && entry != nil { cached, readErr := p.Storage.Open(ctx, entry.StoragePath) if readErr != nil { - return nil, "", "", errStale304 + return nil, "", "", zeroTime, errStale304 } defer func() { _ = cached.Close() }() data, readErr := ReadMetadata(cached) if readErr != nil { - return nil, "", "", errStale304 + return nil, "", "", zeroTime, errStale304 } ct := contentTypeJSON if entry.ContentType.Valid { ct = entry.ContentType.String } - return data, ct, entry.ETag.String, nil + lm := zeroTime + if entry.LastModified.Valid { + lm = entry.LastModified.Time + } + return data, ct, entry.ETag.String, lm, nil } if resp.StatusCode == http.StatusNotFound { - return nil, "", "", ErrUpstreamNotFound + return nil, "", "", 
zeroTime, ErrUpstreamNotFound } if resp.StatusCode != http.StatusOK { - return nil, "", "", fmt.Errorf("upstream returned %d", resp.StatusCode) + return nil, "", "", zeroTime, fmt.Errorf("upstream returned %d", resp.StatusCode) } body, err := ReadMetadata(resp.Body) if err != nil { - return nil, "", "", fmt.Errorf("reading response: %w", err) + return nil, "", "", zeroTime, fmt.Errorf("reading response: %w", err) } contentType := resp.Header.Get("Content-Type") @@ -478,11 +485,17 @@ func (p *Proxy) fetchUpstreamMetadata(ctx context.Context, upstreamURL string, e } etag := resp.Header.Get("ETag") - return body, contentType, etag, nil + + var lastModified time.Time + if lm := resp.Header.Get("Last-Modified"); lm != "" { + lastModified, _ = http.ParseTime(lm) + } + + return body, contentType, etag, lastModified, nil } // cacheMetadataBlob stores metadata bytes in storage and updates the database. -func (p *Proxy) cacheMetadataBlob(ctx context.Context, ecosystem, cacheKey, storagePath string, data []byte, contentType, etag string) { +func (p *Proxy) cacheMetadataBlob(ctx context.Context, ecosystem, cacheKey, storagePath string, data []byte, contentType, etag string, lastModified time.Time) { if p.DB == nil || p.Storage == nil { return } @@ -494,13 +507,14 @@ func (p *Proxy) cacheMetadataBlob(ctx context.Context, ecosystem, cacheKey, stor } _ = p.DB.UpsertMetadataCache(&database.MetadataCacheEntry{ - Ecosystem: ecosystem, - Name: cacheKey, - StoragePath: storagePath, - ETag: sql.NullString{String: etag, Valid: etag != ""}, - ContentType: sql.NullString{String: contentType, Valid: contentType != ""}, - Size: sql.NullInt64{Int64: size, Valid: true}, - FetchedAt: sql.NullTime{Time: time.Now(), Valid: true}, + Ecosystem: ecosystem, + Name: cacheKey, + StoragePath: storagePath, + ETag: sql.NullString{String: etag, Valid: etag != ""}, + ContentType: sql.NullString{String: contentType, Valid: contentType != ""}, + Size: sql.NullInt64{Int64: size, Valid: true}, + LastModified: 
sql.NullTime{Time: lastModified, Valid: !lastModified.IsZero()}, + FetchedAt: sql.NullTime{Time: time.Now(), Valid: true}, }) } @@ -526,7 +540,44 @@ func (p *Proxy) ProxyCached(w http.ResponseWriter, r *http.Request, upstreamURL, return } + // Look up cache entry to get ETag and upstream Last-Modified for conditional response headers + var etag string + var lastModified time.Time + if p.DB != nil { + if entry, err := p.DB.GetMetadataCache(ecosystem, cacheKey); err == nil && entry != nil { + if entry.ETag.Valid { + etag = entry.ETag.String + } + if entry.LastModified.Valid { + lastModified = entry.LastModified.Time + } + } + } + + // Honor client conditional request headers + if etag != "" { + if match := r.Header.Get("If-None-Match"); match != "" && match == etag { + w.WriteHeader(http.StatusNotModified) + return + } + } + if !lastModified.IsZero() { + if ims := r.Header.Get("If-Modified-Since"); ims != "" { + if t, err := http.ParseTime(ims); err == nil && !lastModified.After(t) { + w.WriteHeader(http.StatusNotModified) + return + } + } + } + w.Header().Set("Content-Type", contentType) + w.Header().Set("Content-Length", strconv.Itoa(len(body))) + if etag != "" { + w.Header().Set("ETag", etag) + } + if !lastModified.IsZero() { + w.Header().Set("Last-Modified", lastModified.UTC().Format(http.TimeFormat)) + } w.WriteHeader(http.StatusOK) _, _ = w.Write(body) } diff --git a/internal/handler/handler_test.go b/internal/handler/handler_test.go index 4c71319..256a107 100644 --- a/internal/handler/handler_test.go +++ b/internal/handler/handler_test.go @@ -486,3 +486,167 @@ func TestNewProxy_NilLogger(t *testing.T) { t.Error("Logger should be set to default when nil is passed") } } + +const testLastModified = "Wed, 01 Jan 2025 12:00:00 GMT" + +// setupCachedProxy creates a Proxy with CacheMetadata enabled and an upstream +// test server that returns JSON with ETag and Last-Modified headers. 
+func setupCachedProxy(t *testing.T, upstreamETag, upstreamLastModified string) (*Proxy, *httptest.Server) { + t.Helper() + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if upstreamETag != "" { + w.Header().Set("ETag", upstreamETag) + } + if upstreamLastModified != "" { + w.Header().Set("Last-Modified", upstreamLastModified) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + t.Cleanup(upstream.Close) + + proxy, _, _, _ := setupTestProxy(t) + proxy.CacheMetadata = true + proxy.HTTPClient = upstream.Client() + + return proxy, upstream +} + +func TestProxyCached_SetsETagAndLastModified(t *testing.T) { + lm := testLastModified + proxy, upstream := setupCachedProxy(t, `"abc123"`, lm) + + // First request populates the cache + req := httptest.NewRequest(http.MethodGet, "/test", nil) + w := httptest.NewRecorder() + proxy.ProxyCached(w, req, upstream.URL+"/test", "test-eco", "test-key") + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", w.Code) + } + if got := w.Header().Get("ETag"); got != `"abc123"` { + t.Errorf("ETag = %q, want %q", got, `"abc123"`) + } + if got := w.Header().Get("Last-Modified"); got != lm { + t.Errorf("Last-Modified = %q, want %q", got, lm) + } + if got := w.Header().Get("Content-Length"); got != "11" { + t.Errorf("Content-Length = %q, want %q", got, "11") + } + if w.Body.String() != `{"ok":true}` { + t.Errorf("body = %q, want %q", w.Body.String(), `{"ok":true}`) + } +} + +func TestProxyCached_IfNoneMatch_Returns304(t *testing.T) { + proxy, upstream := setupCachedProxy(t, `"abc123"`, "") + + // Populate cache + req := httptest.NewRequest(http.MethodGet, "/test", nil) + w := httptest.NewRecorder() + proxy.ProxyCached(w, req, upstream.URL+"/test", "test-eco", "etag-key") + if w.Code != http.StatusOK { + t.Fatalf("initial request: status = %d, want 200", w.Code) + } + + // Conditional request with 
matching ETag + req = httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("If-None-Match", `"abc123"`) + w = httptest.NewRecorder() + proxy.ProxyCached(w, req, upstream.URL+"/test", "test-eco", "etag-key") + + if w.Code != http.StatusNotModified { + t.Errorf("conditional request: status = %d, want 304", w.Code) + } + if w.Body.Len() != 0 { + t.Errorf("304 response should have empty body, got %d bytes", w.Body.Len()) + } +} + +func TestProxyCached_IfNoneMatch_NonMatching_Returns200(t *testing.T) { + proxy, upstream := setupCachedProxy(t, `"abc123"`, "") + + // Populate cache + req := httptest.NewRequest(http.MethodGet, "/test", nil) + w := httptest.NewRecorder() + proxy.ProxyCached(w, req, upstream.URL+"/test", "test-eco", "etag-nm-key") + if w.Code != http.StatusOK { + t.Fatalf("initial request: status = %d, want 200", w.Code) + } + + // Conditional request with non-matching ETag + req = httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("If-None-Match", `"different"`) + w = httptest.NewRecorder() + proxy.ProxyCached(w, req, upstream.URL+"/test", "test-eco", "etag-nm-key") + + if w.Code != http.StatusOK { + t.Errorf("non-matching ETag: status = %d, want 200", w.Code) + } +} + +func TestProxyCached_IfModifiedSince_Returns304(t *testing.T) { + lm := testLastModified + proxy, upstream := setupCachedProxy(t, "", lm) + + // Populate cache + req := httptest.NewRequest(http.MethodGet, "/test", nil) + w := httptest.NewRecorder() + proxy.ProxyCached(w, req, upstream.URL+"/test", "test-eco", "lm-key") + if w.Code != http.StatusOK { + t.Fatalf("initial request: status = %d, want 200", w.Code) + } + + // Conditional request with If-Modified-Since equal to Last-Modified + req = httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("If-Modified-Since", lm) + w = httptest.NewRecorder() + proxy.ProxyCached(w, req, upstream.URL+"/test", "test-eco", "lm-key") + + if w.Code != http.StatusNotModified { + t.Errorf("conditional request: status 
= %d, want 304", w.Code) + } +} + +func TestProxyCached_IfModifiedSince_OlderDate_Returns200(t *testing.T) { + lm := testLastModified + proxy, upstream := setupCachedProxy(t, "", lm) + + // Populate cache + req := httptest.NewRequest(http.MethodGet, "/test", nil) + w := httptest.NewRecorder() + proxy.ProxyCached(w, req, upstream.URL+"/test", "test-eco", "lm-old-key") + if w.Code != http.StatusOK { + t.Fatalf("initial request: status = %d, want 200", w.Code) + } + + // Conditional request with If-Modified-Since older than Last-Modified + req = httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("If-Modified-Since", "Mon, 01 Dec 2024 12:00:00 GMT") + w = httptest.NewRecorder() + proxy.ProxyCached(w, req, upstream.URL+"/test", "test-eco", "lm-old-key") + + if w.Code != http.StatusOK { + t.Errorf("older If-Modified-Since: status = %d, want 200", w.Code) + } +} + +func TestProxyCached_NoValidators_OmitsHeaders(t *testing.T) { + proxy, upstream := setupCachedProxy(t, "", "") + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + w := httptest.NewRecorder() + proxy.ProxyCached(w, req, upstream.URL+"/test", "test-eco", "no-val-key") + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", w.Code) + } + if got := w.Header().Get("ETag"); got != "" { + t.Errorf("ETag should be empty when upstream has none, got %q", got) + } + if got := w.Header().Get("Last-Modified"); got != "" { + t.Errorf("Last-Modified should be empty when upstream has none, got %q", got) + } +} diff --git a/internal/handler/nuget.go b/internal/handler/nuget.go index 8bced9f..615b0d2 100644 --- a/internal/handler/nuget.go +++ b/internal/handler/nuget.go @@ -2,6 +2,7 @@ package handler import ( "encoding/json" + "errors" "fmt" "io" "net/http" @@ -62,6 +63,10 @@ func (h *NuGetHandler) handleServiceIndex(w http.ResponseWriter, r *http.Request body, _, err := h.proxy.FetchOrCacheMetadata(r.Context(), "nuget", "_service_index", upstreamURL) if err != nil { + if errors.Is(err, 
ErrUpstreamNotFound) { + http.Error(w, "not found", http.StatusNotFound) + return + } h.proxy.Logger.Error("upstream request failed", "error", err) http.Error(w, "upstream request failed", http.StatusBadGateway) return diff --git a/internal/handler/pub.go b/internal/handler/pub.go index a0b5b4c..60bbbad 100644 --- a/internal/handler/pub.go +++ b/internal/handler/pub.go @@ -2,6 +2,7 @@ package handler import ( "encoding/json" + "errors" "fmt" "net/http" "strings" @@ -90,6 +91,10 @@ func (h *PubHandler) handlePackageMetadata(w http.ResponseWriter, r *http.Reques body, _, err := h.proxy.FetchOrCacheMetadata(r.Context(), "pub", name, upstreamURL) if err != nil { + if errors.Is(err, ErrUpstreamNotFound) { + http.Error(w, "not found", http.StatusNotFound) + return + } h.proxy.Logger.Error("upstream request failed", "error", err) http.Error(w, "upstream request failed", http.StatusBadGateway) return diff --git a/internal/handler/pypi.go b/internal/handler/pypi.go index 4fc9cd5..954adbf 100644 --- a/internal/handler/pypi.go +++ b/internal/handler/pypi.go @@ -4,6 +4,7 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -78,6 +79,10 @@ func (h *PyPIHandler) handleSimplePackage(w http.ResponseWriter, r *http.Request body, _, err := h.proxy.FetchOrCacheMetadata(r.Context(), "pypi", cacheKey, upstreamURL, "text/html") if err != nil { + if errors.Is(err, ErrUpstreamNotFound) { + http.Error(w, "not found", http.StatusNotFound) + return + } h.proxy.Logger.Error("upstream request failed", "error", err) http.Error(w, "upstream request failed", http.StatusBadGateway) return @@ -225,6 +230,10 @@ func (h *PyPIHandler) handleVersionJSON(w http.ResponseWriter, r *http.Request) func (h *PyPIHandler) proxyAndRewriteJSON(w http.ResponseWriter, r *http.Request, upstreamURL, cacheKey string) { body, _, err := h.proxy.FetchOrCacheMetadata(r.Context(), "pypi", cacheKey, upstreamURL) if err != nil { + if errors.Is(err, ErrUpstreamNotFound) { + 
http.Error(w, "not found", http.StatusNotFound) + return + } h.proxy.Logger.Error("upstream request failed", "error", err) http.Error(w, "upstream request failed", http.StatusBadGateway) return diff --git a/internal/mirror/job.go b/internal/mirror/job.go index b91a7f8..8915d4a 100644 --- a/internal/mirror/job.go +++ b/internal/mirror/job.go @@ -158,7 +158,13 @@ func (js *JobStore) runJob(ctx context.Context, cancel context.CancelFunc, job * job.State = JobStateRunning js.mu.Unlock() - progress, err := js.mirror.Run(ctx, source) + progress, err := js.mirror.Run(ctx, source, func(p Progress) { + js.mu.Lock() + defer js.mu.Unlock() + if job.State == JobStateRunning { + job.Progress = p + } + }) js.mu.Lock() defer js.mu.Unlock() @@ -187,9 +193,9 @@ func (js *JobStore) sourceFromRequest(req JobRequest) (Source, error) { //nolint case len(req.PURLs) > 0: return &PURLSource{PURLs: req.PURLs}, nil case req.Registry != "": - return &RegistrySource{Ecosystem: req.Registry}, nil + return nil, fmt.Errorf("registry mirroring is not yet implemented; use purls instead") default: - return nil, fmt.Errorf("request must include purls or registry") + return nil, fmt.Errorf("request must include purls") } } diff --git a/internal/mirror/job_test.go b/internal/mirror/job_test.go index 90f09e4..f7f2f1c 100644 --- a/internal/mirror/job_test.go +++ b/internal/mirror/job_test.go @@ -100,16 +100,13 @@ func TestSourceFromRequestPURLs(t *testing.T) { } } -func TestSourceFromRequestRegistry(t *testing.T) { +func TestSourceFromRequestRegistryRejected(t *testing.T) { m := setupTestMirror(t, 1) js := NewJobStore(context.Background(), m) - source, err := js.sourceFromRequest(JobRequest{Registry: "npm"}) - if err != nil { - t.Fatalf("sourceFromRequest() error = %v", err) - } - if _, ok := source.(*RegistrySource); !ok { - t.Errorf("expected *RegistrySource, got %T", source) + _, err := js.sourceFromRequest(JobRequest{Registry: "npm"}) + if err == nil { + t.Fatal("expected error for registry 
request") } } @@ -153,8 +150,8 @@ func TestJobStoreCancelPreservesStateAfterRunJob(t *testing.T) { m := setupTestMirror(t, 1) js := NewJobStore(context.Background(), m) - // Create a job that will fail (registry enumeration is not implemented) - id, err := js.Create(JobRequest{Registry: "npm"}) + // Create a job with a PURL that will fail (no real upstream in test) + id, err := js.Create(JobRequest{PURLs: []string{"pkg:npm/nonexistent-pkg@0.0.0"}}) if err != nil { t.Fatalf("Create() error = %v", err) } diff --git a/internal/mirror/mirror.go b/internal/mirror/mirror.go index 8553f2b..06b496f 100644 --- a/internal/mirror/mirror.go +++ b/internal/mirror/mirror.go @@ -79,6 +79,7 @@ func newProgressTracker() *progressTracker { } const maxTrackedErrors = 1000 +const progressReportInterval = 500 * time.Millisecond //nolint:mnd // progress update frequency func (pt *progressTracker) addError(eco, name, version, err string) { pt.mu.Lock() @@ -112,9 +113,13 @@ func (pt *progressTracker) snapshot() Progress { } } +// ProgressFunc is called periodically with a snapshot of the current progress. +type ProgressFunc func(Progress) + // Run mirrors all packages from the source using a bounded worker pool. -// It returns the final progress when complete. -func (m *Mirror) Run(ctx context.Context, source Source) (*Progress, error) { +// It returns the final progress when complete. If onProgress is non-nil, +// it is called with progress snapshots as work proceeds. 
+func (m *Mirror) Run(ctx context.Context, source Source, onProgress ...ProgressFunc) (*Progress, error) { tracker := newProgressTracker() // Collect items from source @@ -131,6 +136,28 @@ func (m *Mirror) Run(ctx context.Context, source Source) (*Progress, error) { tracker.total.Store(int64(len(items))) tracker.phase.Store("downloading") + // Start periodic progress reporting if a callback was provided + var progressFn ProgressFunc + if len(onProgress) > 0 && onProgress[0] != nil { + progressFn = onProgress[0] + } + progressDone := make(chan struct{}) + if progressFn != nil { + progressFn(tracker.snapshot()) // initial snapshot with total set + go func() { + ticker := time.NewTicker(progressReportInterval) + defer ticker.Stop() + for { + select { + case <-progressDone: + return + case <-ticker.C: + progressFn(tracker.snapshot()) + } + } + }() + } + // Process items with bounded concurrency g, gctx := errgroup.WithContext(ctx) g.SetLimit(m.workers) @@ -152,8 +179,16 @@ func (m *Mirror) Run(ctx context.Context, source Source) (*Progress, error) { _ = g.Wait() + close(progressDone) // stop the progress reporter goroutine + tracker.phase.Store("complete") p := tracker.snapshot() + + // Send final snapshot + if progressFn != nil { + progressFn(p) + } + return &p, nil } diff --git a/internal/mirror/registry.go b/internal/mirror/registry.go index 795e190..6b2c449 100644 --- a/internal/mirror/registry.go +++ b/internal/mirror/registry.go @@ -6,42 +6,11 @@ import ( ) // RegistrySource enumerates all packages in a registry for full mirroring. +// Registry enumeration is not yet implemented for any ecosystem. type RegistrySource struct { Ecosystem string } -// supportedRegistries lists ecosystems that support enumeration. 
-var supportedRegistries = map[string]bool{ - "npm": true, - "pypi": true, - "cargo": true, -} - -func (s *RegistrySource) Enumerate(ctx context.Context, fn func(PackageVersion) error) error { - if !supportedRegistries[s.Ecosystem] { - return fmt.Errorf("registry enumeration not supported for ecosystem %q; supported: npm, pypi, cargo", s.Ecosystem) - } - - switch s.Ecosystem { - case "npm": - return s.enumerateNPM(ctx, fn) - case "pypi": - return s.enumeratePyPI(ctx, fn) - case "cargo": - return s.enumerateCargo(ctx, fn) - default: - return fmt.Errorf("unsupported ecosystem: %s", s.Ecosystem) - } -} - -func (s *RegistrySource) enumerateNPM(_ context.Context, _ func(PackageVersion) error) error { - return fmt.Errorf("npm registry enumeration not yet implemented") -} - -func (s *RegistrySource) enumeratePyPI(_ context.Context, _ func(PackageVersion) error) error { - return fmt.Errorf("pypi registry enumeration not yet implemented") -} - -func (s *RegistrySource) enumerateCargo(_ context.Context, _ func(PackageVersion) error) error { - return fmt.Errorf("cargo registry enumeration not yet implemented") +func (s *RegistrySource) Enumerate(_ context.Context, _ func(PackageVersion) error) error { + return fmt.Errorf("registry enumeration is not yet implemented for ecosystem %q", s.Ecosystem) } From 52bc6e8850a9721b6e65f9b9570b1592ae859266 Mon Sep 17 00:00:00 2001 From: Andrew Nesbitt Date: Mon, 6 Apr 2026 19:30:59 +0100 Subject: [PATCH 5/5] Add metadata TTL and stale-while-revalidate support Cached metadata is now served directly within a configurable TTL window (default 5m) without contacting upstream, reducing latency and upstream load. When upstream is unreachable and the cache is past its TTL, stale content is served with a Warning: 110 header per RFC 7234. New config: `metadata_ttl` (YAML) / `PROXY_METADATA_TTL` (env). Set to "0" to always revalidate with upstream. 
--- cmd/proxy/main.go | 1 + docs/configuration.md | 14 ++++ internal/config/config.go | 34 ++++++++ internal/config/config_test.go | 54 +++++++++++++ internal/handler/handler.go | 86 +++++++++++++++----- internal/handler/handler_test.go | 135 +++++++++++++++++++++++++++++++ internal/server/server.go | 1 + 7 files changed, 303 insertions(+), 22 deletions(-) diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 40d5d34..0268e9e 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -453,6 +453,7 @@ func runMirror() { resolver := fetch.NewResolver() proxy := handler.NewProxy(db, store, fetcher, resolver, logger) proxy.CacheMetadata = true // mirror always caches metadata + proxy.MetadataTTL = cfg.ParseMetadataTTL() m := mirror.New(proxy, db, store, logger, *concurrency) diff --git a/docs/configuration.md b/docs/configuration.md index 16e71bb..be196de 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -225,6 +225,20 @@ Or via environment variable: `PROXY_CACHE_METADATA=true`. The `proxy mirror` command always enables metadata caching regardless of this setting. +### Metadata TTL + +When metadata caching is enabled, `metadata_ttl` controls how long a cached response is considered fresh before revalidating with upstream. During the TTL window, cached metadata is served directly without contacting upstream, reducing latency and upstream load. + +```yaml +metadata_ttl: "5m" # default +``` + +Or via environment variable: `PROXY_METADATA_TTL=10m`. + +Set to `"0"` to always revalidate with upstream (ETag-based conditional requests still avoid re-downloading unchanged content). + +When upstream is unreachable and the cached entry is past its TTL, the proxy serves the stale cached copy with a `Warning: 110 - "Response is Stale"` header so clients can tell the data may be outdated. + ## Mirror API The `/api/mirror` endpoints are disabled by default. 
Enable them to allow starting mirror jobs via HTTP: diff --git a/internal/config/config.go b/internal/config/config.go index 6b82861..ad0acc0 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -55,6 +55,7 @@ import ( "path/filepath" "strconv" "strings" + "time" "gopkg.in/yaml.v3" ) @@ -89,6 +90,11 @@ type Config struct { // The mirror command always enables this regardless of this setting. CacheMetadata bool `json:"cache_metadata" yaml:"cache_metadata"` + // MetadataTTL is how long cached metadata is considered fresh before + // revalidating with upstream. Uses Go duration syntax (e.g. "5m", "1h"). + // Default: "5m". Set to "0" to always revalidate. + MetadataTTL string `json:"metadata_ttl" yaml:"metadata_ttl"` + // MirrorAPI enables the /api/mirror endpoints for starting mirror jobs via HTTP. // Disabled by default to prevent unauthenticated users from triggering downloads. MirrorAPI bool `json:"mirror_api" yaml:"mirror_api"` @@ -321,6 +327,9 @@ func (c *Config) LoadFromEnv() { if v := os.Getenv("PROXY_MIRROR_API"); v != "" { c.MirrorAPI = v == "true" || v == "1" } + if v := os.Getenv("PROXY_METADATA_TTL"); v != "" { + c.MetadataTTL = v + } } // Validate checks the configuration for errors. @@ -370,9 +379,34 @@ func (c *Config) Validate() error { } } + // Validate metadata TTL if specified + if c.MetadataTTL != "" && c.MetadataTTL != "0" { + if _, err := time.ParseDuration(c.MetadataTTL); err != nil { + return fmt.Errorf("invalid metadata_ttl %q: %w", c.MetadataTTL, err) + } + } + return nil } +const defaultMetadataTTL = 5 * time.Minute //nolint:mnd // sensible default + +// ParseMetadataTTL returns the metadata TTL duration. +// Returns 5 minutes if unset, 0 if explicitly disabled. 
+func (c *Config) ParseMetadataTTL() time.Duration { + if c.MetadataTTL == "" { + return defaultMetadataTTL + } + if c.MetadataTTL == "0" { + return 0 + } + d, err := time.ParseDuration(c.MetadataTTL) + if err != nil { + return defaultMetadataTTL + } + return d +} + // ParseSize parses a human-readable size string (e.g., "10GB", "500MB"). // Returns the size in bytes. func ParseSize(s string) (int64, error) { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 4dd1c17..6e8c3a0 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -4,6 +4,7 @@ import ( "os" "path/filepath" "testing" + "time" ) const ( @@ -301,3 +302,56 @@ func TestLoadFileNotFound(t *testing.T) { t.Error("expected error for nonexistent file") } } + +func TestParseMetadataTTL(t *testing.T) { + tests := []struct { + name string + ttl string + want time.Duration + }{ + {"empty defaults to 5m", "", 5 * time.Minute}, + {"explicit zero", "0", 0}, + {"10 minutes", "10m", 10 * time.Minute}, + {"1 hour", "1h", 1 * time.Hour}, + {"invalid defaults to 5m", "not-a-duration", 5 * time.Minute}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := Default() + cfg.MetadataTTL = tt.ttl + got := cfg.ParseMetadataTTL() + if got != tt.want { + t.Errorf("ParseMetadataTTL() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestValidateMetadataTTL(t *testing.T) { + cfg := Default() + cfg.MetadataTTL = "invalid" + if err := cfg.Validate(); err == nil { + t.Error("expected validation error for invalid metadata_ttl") + } + + cfg.MetadataTTL = "5m" + if err := cfg.Validate(); err != nil { + t.Errorf("unexpected error for valid metadata_ttl: %v", err) + } + + cfg.MetadataTTL = "0" + if err := cfg.Validate(); err != nil { + t.Errorf("unexpected error for zero metadata_ttl: %v", err) + } +} + +func TestLoadMetadataTTLFromEnv(t *testing.T) { + cfg := Default() + t.Setenv("PROXY_METADATA_TTL", "10m") + cfg.LoadFromEnv() + + if 
cfg.MetadataTTL != "10m" { + t.Errorf("MetadataTTL = %q, want %q", cfg.MetadataTTL, "10m") + } +} diff --git a/internal/handler/handler.go b/internal/handler/handler.go index 2edd648..3f3291e 100644 --- a/internal/handler/handler.go +++ b/internal/handler/handler.go @@ -57,6 +57,7 @@ type Proxy struct { Logger *slog.Logger Cooldown *cooldown.Config CacheMetadata bool + MetadataTTL time.Duration HTTPClient *http.Client } @@ -372,12 +373,31 @@ func (p *Proxy) FetchOrCacheMetadata(ctx context.Context, ecosystem, cacheKey, u storagePath := metadataStoragePath(ecosystem, cacheKey) - // Check for existing cache entry (for ETag revalidation) + // Check for existing cache entry (for ETag revalidation and TTL) var entry *database.MetadataCacheEntry if p.CacheMetadata && p.DB != nil { entry, _ = p.DB.GetMetadataCache(ecosystem, cacheKey) } + // Serve from cache if within TTL (skip upstream entirely) + if entry != nil && p.MetadataTTL > 0 && entry.FetchedAt.Valid { + if time.Since(entry.FetchedAt.Time) < p.MetadataTTL { + cached, readErr := p.Storage.Open(ctx, entry.StoragePath) + if readErr == nil { + defer func() { _ = cached.Close() }() + data, readErr := ReadMetadata(cached) + if readErr == nil { + ct := contentTypeJSON + if entry.ContentType.Valid { + ct = entry.ContentType.String + } + return data, ct, nil + } + } + // Cache file missing/unreadable, fall through to upstream + } + } + accept := contentTypeJSON if len(acceptHeaders) > 0 && acceptHeaders[0] != "" { accept = acceptHeaders[0] @@ -518,6 +538,37 @@ func (p *Proxy) cacheMetadataBlob(ctx context.Context, ecosystem, cacheKey, stor }) } +// cachedMeta holds cache validators and freshness state from a metadata cache entry. +type cachedMeta struct { + etag string + lastModified time.Time + stale bool +} + +// lookupCachedMeta retrieves cache validators for a metadata entry. 
+func (p *Proxy) lookupCachedMeta(ecosystem, cacheKey string) cachedMeta { + if p.DB == nil { + return cachedMeta{} + } + entry, err := p.DB.GetMetadataCache(ecosystem, cacheKey) + if err != nil || entry == nil { + return cachedMeta{} + } + var cm cachedMeta + if entry.ETag.Valid { + cm.etag = entry.ETag.String + } + if entry.LastModified.Valid { + cm.lastModified = entry.LastModified.Time + } + // If FetchedAt is older than TTL, upstream must have failed and + // we served from stale cache (successful fetches update FetchedAt). + if p.MetadataTTL > 0 && entry.FetchedAt.Valid && time.Since(entry.FetchedAt.Time) > p.MetadataTTL { + cm.stale = true + } + return cm +} + // ProxyCached fetches metadata from upstream (with optional caching for offline fallback) // and writes it to the response. Optional acceptHeaders specify the Accept header to send. // When metadata caching is disabled, the response is streamed directly to avoid buffering @@ -540,30 +591,18 @@ func (p *Proxy) ProxyCached(w http.ResponseWriter, r *http.Request, upstreamURL, return } - // Look up cache entry to get ETag and upstream Last-Modified for conditional response headers - var etag string - var lastModified time.Time - if p.DB != nil { - if entry, err := p.DB.GetMetadataCache(ecosystem, cacheKey); err == nil && entry != nil { - if entry.ETag.Valid { - etag = entry.ETag.String - } - if entry.LastModified.Valid { - lastModified = entry.LastModified.Time - } - } - } + cm := p.lookupCachedMeta(ecosystem, cacheKey) // Honor client conditional request headers - if etag != "" { - if match := r.Header.Get("If-None-Match"); match != "" && match == etag { + if cm.etag != "" { + if match := r.Header.Get("If-None-Match"); match != "" && match == cm.etag { w.WriteHeader(http.StatusNotModified) return } } - if !lastModified.IsZero() { + if !cm.lastModified.IsZero() { if ims := r.Header.Get("If-Modified-Since"); ims != "" { - if t, err := http.ParseTime(ims); err == nil && !lastModified.After(t) { + if t, err 
:= http.ParseTime(ims); err == nil && !cm.lastModified.After(t) { w.WriteHeader(http.StatusNotModified) return } @@ -572,11 +611,14 @@ func (p *Proxy) ProxyCached(w http.ResponseWriter, r *http.Request, upstreamURL, w.Header().Set("Content-Type", contentType) w.Header().Set("Content-Length", strconv.Itoa(len(body))) - if etag != "" { - w.Header().Set("ETag", etag) + if cm.etag != "" { + w.Header().Set("ETag", cm.etag) + } + if !cm.lastModified.IsZero() { + w.Header().Set("Last-Modified", cm.lastModified.UTC().Format(http.TimeFormat)) } - if !lastModified.IsZero() { - w.Header().Set("Last-Modified", lastModified.UTC().Format(http.TimeFormat)) + if cm.stale { + w.Header().Set("Warning", `110 - "Response is Stale"`) } w.WriteHeader(http.StatusOK) _, _ = w.Write(body) diff --git a/internal/handler/handler_test.go b/internal/handler/handler_test.go index 256a107..78ed415 100644 --- a/internal/handler/handler_test.go +++ b/internal/handler/handler_test.go @@ -650,3 +650,138 @@ func TestProxyCached_NoValidators_OmitsHeaders(t *testing.T) { t.Errorf("Last-Modified should be empty when upstream has none, got %q", got) } } + +func TestFetchOrCacheMetadata_TTL_ServesFreshFromCache(t *testing.T) { + upstreamHits := 0 + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamHits++ + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"v":1}`)) + })) + t.Cleanup(upstream.Close) + + proxy, _, _, _ := setupTestProxy(t) + proxy.CacheMetadata = true + proxy.MetadataTTL = 1 * time.Hour + proxy.HTTPClient = upstream.Client() + + ctx := context.Background() + + // First request populates cache + body, _, err := proxy.FetchOrCacheMetadata(ctx, "test", "ttl-pkg", upstream.URL+"/pkg") + if err != nil { + t.Fatalf("first fetch: %v", err) + } + if string(body) != `{"v":1}` { + t.Errorf("body = %q, want %q", body, `{"v":1}`) + } + if upstreamHits != 1 { + t.Fatalf("expected 1 upstream hit, got %d", upstreamHits) + } + 
+ // Second request within TTL should serve from cache without hitting upstream + body, _, err = proxy.FetchOrCacheMetadata(ctx, "test", "ttl-pkg", upstream.URL+"/pkg") + if err != nil { + t.Fatalf("second fetch: %v", err) + } + if string(body) != `{"v":1}` { + t.Errorf("body = %q, want %q", body, `{"v":1}`) + } + if upstreamHits != 1 { + t.Errorf("expected upstream to still be hit only once, got %d", upstreamHits) + } +} + +func TestFetchOrCacheMetadata_TTL_Zero_AlwaysRevalidates(t *testing.T) { + upstreamHits := 0 + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamHits++ + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"v":1}`)) + })) + t.Cleanup(upstream.Close) + + proxy, _, _, _ := setupTestProxy(t) + proxy.CacheMetadata = true + proxy.MetadataTTL = 0 // always revalidate + proxy.HTTPClient = upstream.Client() + + ctx := context.Background() + + _, _, err := proxy.FetchOrCacheMetadata(ctx, "test", "ttl0-pkg", upstream.URL+"/pkg") + if err != nil { + t.Fatalf("first fetch: %v", err) + } + + _, _, err = proxy.FetchOrCacheMetadata(ctx, "test", "ttl0-pkg", upstream.URL+"/pkg") + if err != nil { + t.Fatalf("second fetch: %v", err) + } + + if upstreamHits != 2 { + t.Errorf("expected 2 upstream hits with TTL=0, got %d", upstreamHits) + } +} + +func TestProxyCached_StaleWarningHeader(t *testing.T) { + requestCount := 0 + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + if requestCount == 1 { + // First request succeeds to populate cache + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"cached":true}`)) + return + } + // Subsequent requests fail to simulate upstream outage + w.WriteHeader(http.StatusBadGateway) + })) + t.Cleanup(upstream.Close) + + proxy, _, _, _ := setupTestProxy(t) + proxy.CacheMetadata = true + proxy.MetadataTTL = 1 * time.Millisecond // very short TTL so it expires 
immediately + proxy.HTTPClient = upstream.Client() + + // First request populates cache + req := httptest.NewRequest(http.MethodGet, "/test", nil) + w := httptest.NewRecorder() + proxy.ProxyCached(w, req, upstream.URL+"/test", "test-eco", "stale-key") + if w.Code != http.StatusOK { + t.Fatalf("initial request: status = %d, want 200", w.Code) + } + + // Wait for TTL to expire + time.Sleep(5 * time.Millisecond) + + // Second request: upstream fails, should serve stale cache with Warning header + req = httptest.NewRequest(http.MethodGet, "/test", nil) + w = httptest.NewRecorder() + proxy.ProxyCached(w, req, upstream.URL+"/test", "test-eco", "stale-key") + + if w.Code != http.StatusOK { + t.Fatalf("stale request: status = %d, want 200", w.Code) + } + if w.Body.String() != `{"cached":true}` { + t.Errorf("body = %q, want %q", w.Body.String(), `{"cached":true}`) + } + if got := w.Header().Get("Warning"); got != `110 - "Response is Stale"` { + t.Errorf("Warning = %q, want %q", got, `110 - "Response is Stale"`) + } +} + +func TestProxyCached_FreshResponse_NoWarningHeader(t *testing.T) { + proxy, upstream := setupCachedProxy(t, "", "") + proxy.MetadataTTL = 1 * time.Hour + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + w := httptest.NewRecorder() + proxy.ProxyCached(w, req, upstream.URL+"/test", "test-eco", "fresh-key") + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", w.Code) + } + if got := w.Header().Get("Warning"); got != "" { + t.Errorf("Warning should be empty for fresh response, got %q", got) + } +} diff --git a/internal/server/server.go b/internal/server/server.go index 1c2156b..5d544a2 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -147,6 +147,7 @@ func (s *Server) Start() error { proxy := handler.NewProxy(s.db, s.storage, fetcher, resolver, s.logger) proxy.Cooldown = cd proxy.CacheMetadata = s.cfg.CacheMetadata + proxy.MetadataTTL = s.cfg.ParseMetadataTTL() // Create router with Chi r := chi.NewRouter()