diff --git a/README.md b/README.md index 8da1165..40ee5b5 100644 --- a/README.md +++ b/README.md @@ -460,6 +460,44 @@ proxy serve [flags] proxy [flags] # same as 'proxy serve' ``` +### mirror + +Pre-populate the cache from PURLs, SBOM files, or entire registries. Useful for ensuring offline availability or warming the cache before deployments. + +```bash +# Mirror specific package versions +proxy mirror pkg:npm/lodash@4.17.21 pkg:cargo/serde@1.0.0 + +# Mirror all versions of a package +proxy mirror pkg:npm/lodash + +# Mirror from a CycloneDX or SPDX SBOM +proxy mirror --sbom sbom.cdx.json + +# Preview what would be mirrored +proxy mirror --dry-run pkg:npm/lodash + +# Control parallelism +proxy mirror --concurrency 8 pkg:npm/lodash@4.17.21 +``` + +The mirror command accepts the same storage and database flags as `serve`. Already-cached artifacts are skipped. + +A mirror API is also available on a running server when enabled with `mirror_api: true` (or `PROXY_MIRROR_API=true`); it is disabled by default: + +```bash +# Start a mirror job +curl -X POST http://localhost:8080/api/mirror \ + -H "Content-Type: application/json" \ + -d '{"purls": ["pkg:npm/lodash@4.17.21"]}' + +# Check job status +curl http://localhost:8080/api/mirror/mirror-1 + +# Cancel a running job +curl -X DELETE http://localhost:8080/api/mirror/mirror-1 +``` + ### stats Show cache statistics without running the server. @@ -534,6 +572,14 @@ Recently cached: | `GET /debian/*` | Debian/APT repository protocol | | `GET /rpm/*` | RPM/Yum repository protocol | +### Mirror API + +| Endpoint | Description | +|----------|-------------| +| `POST /api/mirror` | Start a mirror job (JSON body with `purls`) | +| `GET /api/mirror/{id}` | Get job status and progress | +| `DELETE /api/mirror/{id}` | Cancel a running job | + ### Enrichment API The proxy provides REST endpoints for package metadata enrichment, vulnerability scanning, and outdated detection. 
diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 4f13ea2..0268e9e 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -16,6 +16,7 @@ // // serve Start the proxy server (default if no command given) // stats Show cache statistics +// mirror Pre-populate cache from PURLs, SBOMs, or registries // // Serve Flags: // @@ -100,7 +101,11 @@ import ( "github.com/git-pkgs/proxy/internal/config" "github.com/git-pkgs/proxy/internal/database" + "github.com/git-pkgs/proxy/internal/handler" + "github.com/git-pkgs/proxy/internal/mirror" "github.com/git-pkgs/proxy/internal/server" + "github.com/git-pkgs/proxy/internal/storage" + "github.com/git-pkgs/registries/fetch" ) const defaultTopN = 10 @@ -124,6 +129,10 @@ func main() { os.Args = append(os.Args[:1], os.Args[2:]...) runStats() return + case "mirror": + os.Args = append(os.Args[:1], os.Args[2:]...) + runMirror() + return case "-version", "--version": fmt.Printf("proxy %s (%s)\n", Version, Commit) os.Exit(0) @@ -145,6 +154,7 @@ Usage: proxy [command] [flags] Commands: serve Start the proxy server (default) stats Show cache statistics + mirror Pre-populate cache from PURLs, SBOMs, or registries Run 'proxy -help' for more information on a command. 
@@ -340,6 +350,151 @@ func runStats() { } } +func runMirror() { + fs := flag.NewFlagSet("mirror", flag.ExitOnError) + configPath := fs.String("config", "", "Path to configuration file") + storageURL := fs.String("storage-url", "", "Storage URL (file:// or s3://)") + databaseDriver := fs.String("database-driver", "", "Database driver: sqlite or postgres") + databasePath := fs.String("database-path", "", "Path to SQLite database file") + databaseURL := fs.String("database-url", "", "PostgreSQL connection URL") + sbomPath := fs.String("sbom", "", "Path to CycloneDX or SPDX SBOM file") + concurrency := fs.Int("concurrency", 4, "Number of parallel downloads") //nolint:mnd // default concurrency + dryRun := fs.Bool("dry-run", false, "Show what would be mirrored without downloading") + + fs.Usage = func() { + fmt.Fprintf(os.Stderr, "git-pkgs proxy - Pre-populate cache\n\n") + fmt.Fprintf(os.Stderr, "Usage: proxy mirror [flags] [purl...]\n\n") + fmt.Fprintf(os.Stderr, "Examples:\n") + fmt.Fprintf(os.Stderr, " proxy mirror pkg:npm/lodash@4.17.21\n") + fmt.Fprintf(os.Stderr, " proxy mirror --sbom sbom.cdx.json\n") + fmt.Fprintf(os.Stderr, " proxy mirror pkg:npm/lodash # all versions\n\n") + fmt.Fprintf(os.Stderr, "Flags:\n") + fs.PrintDefaults() + } + + _ = fs.Parse(os.Args[1:]) + purls := fs.Args() + + // Determine source + var source mirror.Source + switch { + case *sbomPath != "": + source = &mirror.SBOMSource{Path: *sbomPath} + case len(purls) > 0: + source = &mirror.PURLSource{PURLs: purls} + default: + fmt.Fprintf(os.Stderr, "error: provide PURLs or --sbom\n") + fs.Usage() + os.Exit(1) + } + + // Load config + cfg, err := loadConfig(*configPath) + if err != nil { + fmt.Fprintf(os.Stderr, "error loading config: %v\n", err) + os.Exit(1) + } + cfg.LoadFromEnv() + + if *storageURL != "" { + cfg.Storage.URL = *storageURL + } + if *databaseDriver != "" { + cfg.Database.Driver = *databaseDriver + } + if *databasePath != "" { + cfg.Database.Path = *databasePath + } + if 
*databaseURL != "" { + cfg.Database.URL = *databaseURL + } + + if err := cfg.Validate(); err != nil { + fmt.Fprintf(os.Stderr, "invalid configuration: %v\n", err) + os.Exit(1) + } + + logger := setupLogger("info", "text") + + // Open database + var db *database.DB + switch cfg.Database.Driver { + case "postgres": + db, err = database.OpenPostgresOrCreate(cfg.Database.URL) + default: + db, err = database.OpenOrCreate(cfg.Database.Path) + } + if err != nil { + fmt.Fprintf(os.Stderr, "error opening database: %v\n", err) + os.Exit(1) + } + defer func() { _ = db.Close() }() + + if err := db.MigrateSchema(); err != nil { + _ = db.Close() + fmt.Fprintf(os.Stderr, "error migrating schema: %v\n", err) + os.Exit(1) //nolint:gocritic // db closed above + } + + // Open storage + sURL := cfg.Storage.URL + if sURL == "" { + sURL = "file://" + cfg.Storage.Path //nolint:staticcheck // backwards compat + } + store, err := storage.OpenBucket(context.Background(), sURL) + if err != nil { + _ = db.Close() + fmt.Fprintf(os.Stderr, "error opening storage: %v\n", err) + os.Exit(1) //nolint:gocritic // db closed above + } + + // Build proxy (reuses same pipeline as serve) + fetcher := fetch.NewFetcher() + resolver := fetch.NewResolver() + proxy := handler.NewProxy(db, store, fetcher, resolver, logger) + proxy.CacheMetadata = true // mirror always caches metadata + proxy.MetadataTTL = cfg.ParseMetadataTTL() + + m := mirror.New(proxy, db, store, logger, *concurrency) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + <-sigCh + cancel() + }() + + if *dryRun { + items, err := m.RunDryRun(ctx, source) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + fmt.Printf("Would mirror %d package versions:\n", len(items)) + for _, item := range items { + fmt.Printf(" %s\n", item) + } + return + } + + progress, err := m.Run(ctx, source) + if err != nil { + 
fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + + fmt.Printf("Mirror complete: %d downloaded, %d skipped (cached), %d failed, %s total\n", + progress.Completed, progress.Skipped, progress.Failed, formatSize(progress.Bytes)) + + if len(progress.Errors) > 0 { + fmt.Fprintf(os.Stderr, "\nErrors:\n") + for _, e := range progress.Errors { + fmt.Fprintf(os.Stderr, " %s/%s@%s: %s\n", e.Ecosystem, e.Name, e.Version, e.Error) + } + } +} + func printStats(db *database.DB, popular, recent int, asJSON bool) error { defer func() { _ = db.Close() }() diff --git a/docs/architecture.md b/docs/architecture.md index c57d807..81c41cf 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -161,6 +161,20 @@ vulnerabilities ( updated_at DATETIME ) -- indexes: (vuln_id, ecosystem, package_name) unique, (ecosystem, package_name) + +metadata_cache ( + id INTEGER PRIMARY KEY, + ecosystem TEXT NOT NULL, + name TEXT NOT NULL, + storage_path TEXT NOT NULL, + etag TEXT, + content_type TEXT, + size INTEGER, -- BIGINT on Postgres + fetched_at DATETIME, + created_at DATETIME, + updated_at DATETIME +) +-- indexes: (ecosystem, name) unique ``` On PostgreSQL, `INTEGER PRIMARY KEY` becomes `SERIAL`, `DATETIME` becomes `TIMESTAMP`, `INTEGER DEFAULT 0` booleans become `BOOLEAN DEFAULT FALSE`, and size/count columns use `BIGINT`. @@ -277,6 +291,12 @@ Version age filtering for supply chain attack mitigation. Configurable at global Package metadata enrichment. Fetches license, description, homepage, repository URL, and vulnerability data from upstream registries. Powers the `/api/` endpoints and the web UI's package detail pages. +### `internal/mirror` + +Selective package mirroring for pre-populating the proxy cache. Supports multiple input sources: individual PURLs (versioned or unversioned), CycloneDX/SPDX SBOM files, and full registry enumeration. 
Uses a bounded worker pool backed by `errgroup` to download artifacts in parallel, reusing `handler.Proxy.GetOrFetchArtifact()` for the actual fetch-and-cache work. + +The package also provides a `MetadataCache` for storing raw upstream metadata blobs so the proxy can serve metadata responses offline. The `JobStore` manages async mirror jobs exposed via the `/api/mirror` endpoints. + ### `internal/config` Configuration loading. @@ -326,10 +346,11 @@ Eviction can be implemented as: - Ensures clients fetch artifacts through proxy - Alternative: Let clients fetch directly, miss cache opportunity -**Why not cache metadata?** +**Why not cache metadata (by default)?** - Simplicity - no invalidation logic needed - Fresh data - new versions visible immediately - Metadata is small, upstream fetch is fast +- Set `cache_metadata: true` or use the mirror command to enable metadata caching for offline use via the `metadata_cache` table **Why stream artifacts?** - Memory efficient - don't load large files into RAM diff --git a/docs/configuration.md b/docs/configuration.md index 68ace5f..be196de 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -213,6 +213,65 @@ Currently supported for npm, PyPI, pub.dev, Composer, Cargo, NuGet, Conda, RubyG Note: Hex cooldown requires disabling registry signature verification since the proxy re-encodes the protobuf payload without the original signature. Set `HEX_NO_VERIFY_REPO_ORIGIN=1` or configure your repo with `no_verify: true`. +## Metadata Caching + +By default the proxy fetches metadata fresh from upstream on every request. Enable `cache_metadata` to store metadata responses in the database and storage backend for offline fallback. When upstream is unreachable, the proxy serves the last cached copy. ETag-based revalidation avoids re-downloading unchanged metadata. + +```yaml +cache_metadata: true +``` + +Or via environment variable: `PROXY_CACHE_METADATA=true`. 
+ +The `proxy mirror` command always enables metadata caching regardless of this setting. + +### Metadata TTL + +When metadata caching is enabled, `metadata_ttl` controls how long a cached response is considered fresh before revalidating with upstream. During the TTL window, cached metadata is served directly without contacting upstream, reducing latency and upstream load. + +```yaml +metadata_ttl: "5m" # default +``` + +Or via environment variable: `PROXY_METADATA_TTL=10m`. + +Set to `"0"` to always revalidate with upstream (ETag-based conditional requests still avoid re-downloading unchanged content). + +When upstream is unreachable and the cached entry is past its TTL, the proxy serves the stale cached copy with a `Warning: 110 - "Response is Stale"` header so clients can tell the data may be outdated. + +## Mirror API + +The `/api/mirror` endpoints are disabled by default. Enable them to allow starting mirror jobs via HTTP: + +```yaml +mirror_api: true +``` + +Or via environment variable: `PROXY_MIRROR_API=true`. + +When disabled, the endpoints are not registered, so requests to them return 404. + +## Mirror Command + +The `proxy mirror` command pre-populates the cache from various sources. It accepts the same storage and database flags as `serve`. 
+ +| Flag | Default | Description | +|------|---------|-------------| +| `--sbom` | | Path to CycloneDX or SPDX SBOM file | +| `--concurrency` | `4` | Number of parallel downloads | +| `--dry-run` | `false` | Show what would be mirrored without downloading | +| `--config` | | Path to configuration file | +| `--storage-url` | | Storage URL | +| `--database-driver` | | Database driver | +| `--database-path` | | SQLite database file | +| `--database-url` | | PostgreSQL connection URL | + +Positional arguments are treated as PURLs: + +```bash +proxy mirror pkg:npm/lodash@4.17.21 pkg:cargo/serde@1.0.0 +``` + ## Docker ### SQLite with Local Storage diff --git a/go.mod b/go.mod index 805edf0..07f6c55 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/git-pkgs/proxy go 1.25.6 require ( + github.com/CycloneDX/cyclonedx-go v0.10.0 github.com/git-pkgs/archives v0.2.2 github.com/git-pkgs/enrichment v0.2.1 github.com/git-pkgs/purl v0.1.10 @@ -15,8 +16,10 @@ require ( github.com/lib/pq v1.12.0 github.com/prometheus/client_golang v1.23.2 github.com/prometheus/client_model v0.6.2 + github.com/spdx/tools-golang v0.5.7 github.com/swaggo/swag v1.16.6 gocloud.dev v0.45.0 + golang.org/x/sync v0.20.0 google.golang.org/protobuf v1.36.11 gopkg.in/yaml.v3 v3.0.1 modernc.org/sqlite v1.47.0 @@ -52,6 +55,7 @@ require ( github.com/alfatraining/structtag v1.0.0 // indirect github.com/alingse/asasalint v0.0.11 // indirect github.com/alingse/nilnesserr v0.2.0 // indirect + github.com/anchore/go-struct-converter v0.1.0 // indirect github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect github.com/ashanbrown/forbidigo/v2 v2.3.0 // indirect github.com/ashanbrown/makezero/v2 v2.1.0 // indirect @@ -277,7 +281,6 @@ require ( golang.org/x/exp/typeparams v0.0.0-20260209203927-2842357ff358 // indirect golang.org/x/mod v0.33.0 // indirect golang.org/x/net v0.51.0 // indirect - golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.34.0 // indirect 
golang.org/x/tools v0.42.0 // indirect @@ -293,7 +296,7 @@ require ( modernc.org/memory v1.11.0 // indirect mvdan.cc/gofumpt v0.9.2 // indirect mvdan.cc/unparam v0.0.0-20251027182757-5beb8c8f8f15 // indirect - sigs.k8s.io/yaml v1.3.0 // indirect + sigs.k8s.io/yaml v1.6.0 // indirect ) tool github.com/golangci/golangci-lint/v2/cmd/golangci-lint diff --git a/go.sum b/go.sum index 6e68fb7..41d5e06 100644 --- a/go.sum +++ b/go.sum @@ -46,6 +46,8 @@ github.com/Antonboom/testifylint v1.6.4/go.mod h1:YO33FROXX2OoUfwjz8g+gUxQXio5i9 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk= github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= +github.com/CycloneDX/cyclonedx-go v0.10.0 h1:7xyklU7YD+CUyGzSFIARG18NYLsKVn4QFg04qSsu+7Y= +github.com/CycloneDX/cyclonedx-go v0.10.0/go.mod h1:vUvbCXQsEm48OI6oOlanxstwNByXjCZ2wuleUlwGEO8= github.com/Djarvur/go-err113 v0.1.1 h1:eHfopDqXRwAi+YmCUas75ZE0+hoBHJ2GQNLYRSxao4g= github.com/Djarvur/go-err113 v0.1.1/go.mod h1:IaWJdYFLg76t2ihfflPZnM1LIQszWOsFDh2hhhAVF6k= github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.30.0 h1:sBEjpZlNHzK1voKq9695PJSX2o5NEXl7/OL3coiIY0c= @@ -85,6 +87,8 @@ github.com/alingse/asasalint v0.0.11 h1:SFwnQXJ49Kx/1GghOFz1XGqHYKp21Kq1nHad/0WQ github.com/alingse/asasalint v0.0.11/go.mod h1:nCaoMhw7a9kSJObvQyVzNTPBDbNpdocqrSP7t/cW5+I= github.com/alingse/nilnesserr v0.2.0 h1:raLem5KG7EFVb4UIDAXgrv3N2JIaffeKNtcEXkEWd/w= github.com/alingse/nilnesserr v0.2.0/go.mod h1:1xJPrXonEtX7wyTq8Dytns5P2hNzoWymVUIaKm4HNFg= +github.com/anchore/go-struct-converter v0.1.0 h1:2rDRssAl6mgKBSLNiVCMADgZRhoqtw9dedlWa0OhD30= +github.com/anchore/go-struct-converter v0.1.0/go.mod h1:rYqSE9HbjzpHTI74vwPvae4ZVYZd1lue2ta6xHPdblA= github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ= github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod 
h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk= github.com/ashanbrown/forbidigo/v2 v2.3.0 h1:OZZDOchCgsX5gvToVtEBoV2UWbFfI6RKQTir2UZzSxo= @@ -144,6 +148,8 @@ github.com/bombsimon/wsl/v4 v4.7.0 h1:1Ilm9JBPRczjyUs6hvOPKvd7VL1Q++PL8M0SXBDf+j github.com/bombsimon/wsl/v4 v4.7.0/go.mod h1:uV/+6BkffuzSAVYD+yGyld1AChO7/EuLrCF/8xTiapg= github.com/bombsimon/wsl/v5 v5.6.0 h1:4z+/sBqC5vUmSp1O0mS+czxwH9+LKXtCWtHH9rZGQL8= github.com/bombsimon/wsl/v5 v5.6.0/go.mod h1:Uqt2EfrMj2NV8UGoN1f1Y3m0NpUVCsUdrNCdet+8LvU= +github.com/bradleyjkemp/cupaloy/v2 v2.8.0 h1:any4BmKE+jGIaMpnU8YgH/I2LPiLBufr6oMMlVBbn9M= +github.com/bradleyjkemp/cupaloy/v2 v2.8.0/go.mod h1:bm7JXdkRd4BHJk9HpwqAI8BoAY1lps46Enkdqw6aRX0= github.com/breml/bidichk v0.3.3 h1:WSM67ztRusf1sMoqH6/c4OBCUlRVTKq+CbSeo0R17sE= github.com/breml/bidichk v0.3.3/go.mod h1:ISbsut8OnjB367j5NseXEGGgO/th206dVa427kR8YTE= github.com/breml/errchkjson v0.4.1 h1:keFSS8D7A2T0haP9kzZTi7o26r7kE3vymjZNeNDRDwg= @@ -568,6 +574,8 @@ github.com/sonatard/noctx v0.4.0 h1:7MC/5Gg4SQ4lhLYR6mvOP6mQVSxCrdyiExo7atBs27o= github.com/sonatard/noctx v0.4.0/go.mod h1:64XdbzFb18XL4LporKXp8poqZtPKbCrqQ402CV+kJas= github.com/sourcegraph/go-diff v0.7.0 h1:9uLlrd5T46OXs5qpp8L/MTltk0zikUGi0sNNyCpA8G0= github.com/sourcegraph/go-diff v0.7.0/go.mod h1:iBszgVvyxdc8SFZ7gm69go2KDdt3ag071iBaWPF6cjs= +github.com/spdx/tools-golang v0.5.7 h1:+sWcKGnhwp3vLdMqPcLdA6QK679vd86cK9hQWH3AwCg= +github.com/spdx/tools-golang v0.5.7/go.mod h1:jg7w0LOpoNAw6OxKEzCoqPC2GCTj45LyTlVmXubDsYw= github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= github.com/spf13/cast v1.5.0 h1:rj3WzYc11XZaIZMPKmwP96zkFEnnAmV8s6XbB2aY32w= @@ -606,6 +614,8 @@ github.com/tenntenn/modver v1.0.1 h1:2klLppGhDgzJrScMpkj9Ujy3rXPUspSjAcev9tSEBgA github.com/tenntenn/modver v1.0.1/go.mod h1:bePIyQPb7UeioSRkw3Q0XeMhYZSMx9B8ePqg6SAMGH0= github.com/tenntenn/text/transform v0.0.0-20200319021203-7eef512accb3 
h1:f+jULpRQGxTSkNYKJ51yaw6ChIqO+Je8UqsTKN/cDag= github.com/tenntenn/text/transform v0.0.0-20200319021203-7eef512accb3/go.mod h1:ON8b8w4BN/kE1EOhwT0o+d62W65a6aPw1nouo9LMgyY= +github.com/terminalstatic/go-xsd-validate v0.1.6 h1:TenYeQ3eY631qNi1/cTmLH/s2slHPRKTTHT+XSHkepo= +github.com/terminalstatic/go-xsd-validate v0.1.6/go.mod h1:18lsvYFofBflqCrvo1umpABZ99+GneNTw2kEEc8UPJw= github.com/tetafro/godot v1.5.4 h1:u1ww+gqpRLiIA16yF2PV1CV1n/X3zhyezbNXC3E14Sg= github.com/tetafro/godot v1.5.4/go.mod h1:eOkMrVQurDui411nBY2FA05EYH01r14LuWY/NrVDVcU= github.com/timakin/bodyclose v0.0.0-20241222091800-1db5c5ca4d67 h1:9LPGD+jzxMlnk5r6+hJnar67cgpDIz/iyD+rfl5r2Vk= @@ -628,6 +638,12 @@ github.com/uudashr/gocognit v1.2.0 h1:3BU9aMr1xbhPlvJLSydKwdLN3tEUUrzPSSM8S4hDYR github.com/uudashr/gocognit v1.2.0/go.mod h1:k/DdKPI6XBZO1q7HgoV2juESI2/Ofj9AcHPZhBBdrTU= github.com/uudashr/iface v1.4.1 h1:J16Xl1wyNX9ofhpHmQ9h9gk5rnv2A6lX/2+APLTo0zU= github.com/uudashr/iface v1.4.1/go.mod h1:pbeBPlbuU2qkNDn0mmfrxP2X+wjPMIQAy+r1MBXSXtg= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f h1:J9EGpcZtP0E/raorCMxlFGSTBrsSlaDGf3jU/qvAE2c= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= +github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= +github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= github.com/xen0n/gosmopolitan v1.3.0 h1:zAZI1zefvo7gcpbCOrPSHJZJYA9ZgLfJqtKzZ5pHqQM= github.com/xen0n/gosmopolitan v1.3.0/go.mod h1:rckfr5T6o4lBtM1ga7mLGKZmLxswUoH1zxHgNXOsEt4= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= @@ -867,5 +883,5 @@ 
mvdan.cc/gofumpt v0.9.2 h1:zsEMWL8SVKGHNztrx6uZrXdp7AX8r421Vvp23sz7ik4= mvdan.cc/gofumpt v0.9.2/go.mod h1:iB7Hn+ai8lPvofHd9ZFGVg2GOr8sBUw1QUWjNbmIL/s= mvdan.cc/unparam v0.0.0-20251027182757-5beb8c8f8f15 h1:ssMzja7PDPJV8FStj7hq9IKiuiKhgz9ErWw+m68e7DI= mvdan.cc/unparam v0.0.0-20251027182757-5beb8c8f8f15/go.mod h1:4M5MMXl2kW6fivUT6yRGpLLPNfuGtU2Z0cPvFquGDYU= -sigs.k8s.io/yaml v1.3.0 h1:a2VclLzOGrwOHDiV8EfBGhvjHvP46CtW5j6POvhYGGo= -sigs.k8s.io/yaml v1.3.0/go.mod h1:GeOyir5tyXNByN85N/dRIT9es5UQNerPYEKK56eTBm8= +sigs.k8s.io/yaml v1.6.0 h1:G8fkbMSAFqgEFgh4b1wmtzDnioxFCUgTZhlbj5P9QYs= +sigs.k8s.io/yaml v1.6.0/go.mod h1:796bPqUfzR/0jLAl6XjHl3Ck7MiyVv8dbTdyT3/pMf4= diff --git a/internal/config/config.go b/internal/config/config.go index 3bc45af..ad0acc0 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -55,6 +55,7 @@ import ( "path/filepath" "strconv" "strings" + "time" "gopkg.in/yaml.v3" ) @@ -83,6 +84,20 @@ type Config struct { // Cooldown configures version age filtering to mitigate supply chain attacks. Cooldown CooldownConfig `json:"cooldown" yaml:"cooldown"` + + // CacheMetadata enables caching of upstream metadata responses for offline fallback. + // When enabled, metadata is stored in the database and storage backend. + // The mirror command always enables this regardless of this setting. + CacheMetadata bool `json:"cache_metadata" yaml:"cache_metadata"` + + // MetadataTTL is how long cached metadata is considered fresh before + // revalidating with upstream. Uses Go duration syntax (e.g. "5m", "1h"). + // Default: "5m". Set to "0" to always revalidate. + MetadataTTL string `json:"metadata_ttl" yaml:"metadata_ttl"` + + // MirrorAPI enables the /api/mirror endpoints for starting mirror jobs via HTTP. + // Disabled by default to prevent unauthenticated users from triggering downloads. + MirrorAPI bool `json:"mirror_api" yaml:"mirror_api"` } // CooldownConfig configures version cooldown periods. 
@@ -306,6 +321,15 @@ func (c *Config) LoadFromEnv() { if v := os.Getenv("PROXY_COOLDOWN_DEFAULT"); v != "" { c.Cooldown.Default = v } + if v := os.Getenv("PROXY_CACHE_METADATA"); v != "" { + c.CacheMetadata = v == "true" || v == "1" + } + if v := os.Getenv("PROXY_MIRROR_API"); v != "" { + c.MirrorAPI = v == "true" || v == "1" + } + if v := os.Getenv("PROXY_METADATA_TTL"); v != "" { + c.MetadataTTL = v + } } // Validate checks the configuration for errors. @@ -355,9 +379,34 @@ func (c *Config) Validate() error { } } + // Validate metadata TTL if specified + if c.MetadataTTL != "" && c.MetadataTTL != "0" { + if _, err := time.ParseDuration(c.MetadataTTL); err != nil { + return fmt.Errorf("invalid metadata_ttl %q: %w", c.MetadataTTL, err) + } + } + return nil } +const defaultMetadataTTL = 5 * time.Minute //nolint:mnd // sensible default + +// ParseMetadataTTL returns the metadata TTL duration. +// Returns 5 minutes if unset, 0 if explicitly disabled. +func (c *Config) ParseMetadataTTL() time.Duration { + if c.MetadataTTL == "" { + return defaultMetadataTTL + } + if c.MetadataTTL == "0" { + return 0 + } + d, err := time.ParseDuration(c.MetadataTTL) + if err != nil { + return defaultMetadataTTL + } + return d +} + // ParseSize parses a human-readable size string (e.g., "10GB", "500MB"). // Returns the size in bytes. 
func ParseSize(s string) (int64, error) { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 4dd1c17..6e8c3a0 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -4,6 +4,7 @@ import ( "os" "path/filepath" "testing" + "time" ) const ( @@ -301,3 +302,56 @@ func TestLoadFileNotFound(t *testing.T) { t.Error("expected error for nonexistent file") } } + +func TestParseMetadataTTL(t *testing.T) { + tests := []struct { + name string + ttl string + want time.Duration + }{ + {"empty defaults to 5m", "", 5 * time.Minute}, + {"explicit zero", "0", 0}, + {"10 minutes", "10m", 10 * time.Minute}, + {"1 hour", "1h", 1 * time.Hour}, + {"invalid defaults to 5m", "not-a-duration", 5 * time.Minute}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := Default() + cfg.MetadataTTL = tt.ttl + got := cfg.ParseMetadataTTL() + if got != tt.want { + t.Errorf("ParseMetadataTTL() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestValidateMetadataTTL(t *testing.T) { + cfg := Default() + cfg.MetadataTTL = "invalid" + if err := cfg.Validate(); err == nil { + t.Error("expected validation error for invalid metadata_ttl") + } + + cfg.MetadataTTL = "5m" + if err := cfg.Validate(); err != nil { + t.Errorf("unexpected error for valid metadata_ttl: %v", err) + } + + cfg.MetadataTTL = "0" + if err := cfg.Validate(); err != nil { + t.Errorf("unexpected error for zero metadata_ttl: %v", err) + } +} + +func TestLoadMetadataTTLFromEnv(t *testing.T) { + cfg := Default() + t.Setenv("PROXY_METADATA_TTL", "10m") + cfg.LoadFromEnv() + + if cfg.MetadataTTL != "10m" { + t.Errorf("MetadataTTL = %q, want %q", cfg.MetadataTTL, "10m") + } +} diff --git a/internal/database/metadata_cache_test.go b/internal/database/metadata_cache_test.go new file mode 100644 index 0000000..5701816 --- /dev/null +++ b/internal/database/metadata_cache_test.go @@ -0,0 +1,180 @@ +package database + +import ( + "database/sql" + "path/filepath" 
+ "testing" + "time" +) + +func setupMetadataCacheDB(t *testing.T) *DB { + t.Helper() + dbPath := filepath.Join(t.TempDir(), "test.db") + db, err := Create(dbPath) + if err != nil { + t.Fatalf("Create failed: %v", err) + } + if err := db.MigrateSchema(); err != nil { + t.Fatalf("MigrateSchema failed: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + return db +} + +func TestUpsertAndGetMetadataCache(t *testing.T) { + db := setupMetadataCacheDB(t) + + entry := &MetadataCacheEntry{ + Ecosystem: testEcosystemNPM, + Name: "lodash", + StoragePath: "_metadata/npm/lodash/metadata", + ETag: sql.NullString{String: `"abc123"`, Valid: true}, + ContentType: sql.NullString{String: "application/json", Valid: true}, + Size: sql.NullInt64{Int64: 1024, Valid: true}, + FetchedAt: sql.NullTime{Time: time.Now(), Valid: true}, + } + + err := db.UpsertMetadataCache(entry) + if err != nil { + t.Fatalf("UpsertMetadataCache() error = %v", err) + } + + got, err := db.GetMetadataCache(testEcosystemNPM, "lodash") + if err != nil { + t.Fatalf("GetMetadataCache() error = %v", err) + } + if got == nil { + t.Fatal("GetMetadataCache() returned nil") + } + + if got.Ecosystem != testEcosystemNPM { + t.Errorf("ecosystem = %q, want %q", got.Ecosystem, testEcosystemNPM) + } + if got.Name != "lodash" { + t.Errorf("name = %q, want %q", got.Name, "lodash") + } + if got.StoragePath != "_metadata/npm/lodash/metadata" { + t.Errorf("storage_path = %q, want %q", got.StoragePath, "_metadata/npm/lodash/metadata") + } + if !got.ETag.Valid || got.ETag.String != `"abc123"` { + t.Errorf("etag = %v, want %q", got.ETag, `"abc123"`) + } + if !got.ContentType.Valid || got.ContentType.String != "application/json" { + t.Errorf("content_type = %v, want %q", got.ContentType, "application/json") + } + if !got.Size.Valid || got.Size.Int64 != 1024 { + t.Errorf("size = %v, want 1024", got.Size) + } +} + +func TestGetMetadataCacheMiss(t *testing.T) { + db := setupMetadataCacheDB(t) + + got, err := 
db.GetMetadataCache(testEcosystemNPM, "nonexistent") + if err != nil { + t.Fatalf("GetMetadataCache() error = %v", err) + } + if got != nil { + t.Errorf("expected nil for cache miss, got %v", got) + } +} + +func TestUpsertMetadataCacheOverwrite(t *testing.T) { + db := setupMetadataCacheDB(t) + + // First insert + entry1 := &MetadataCacheEntry{ + Ecosystem: testEcosystemNPM, + Name: "lodash", + StoragePath: "_metadata/npm/lodash/metadata", + ETag: sql.NullString{String: `"v1"`, Valid: true}, + ContentType: sql.NullString{String: "application/json", Valid: true}, + Size: sql.NullInt64{Int64: 100, Valid: true}, + FetchedAt: sql.NullTime{Time: time.Now(), Valid: true}, + } + if err := db.UpsertMetadataCache(entry1); err != nil { + t.Fatalf("first UpsertMetadataCache() error = %v", err) + } + + // Second insert (same ecosystem+name, different etag and size) + entry2 := &MetadataCacheEntry{ + Ecosystem: testEcosystemNPM, + Name: "lodash", + StoragePath: "_metadata/npm/lodash/metadata", + ETag: sql.NullString{String: `"v2"`, Valid: true}, + ContentType: sql.NullString{String: "application/json", Valid: true}, + Size: sql.NullInt64{Int64: 200, Valid: true}, + FetchedAt: sql.NullTime{Time: time.Now(), Valid: true}, + } + if err := db.UpsertMetadataCache(entry2); err != nil { + t.Fatalf("second UpsertMetadataCache() error = %v", err) + } + + got, err := db.GetMetadataCache(testEcosystemNPM, "lodash") + if err != nil { + t.Fatalf("GetMetadataCache() error = %v", err) + } + if got == nil { + t.Fatal("expected entry after overwrite") + } + if got.ETag.String != `"v2"` { + t.Errorf("etag = %q, want %q", got.ETag.String, `"v2"`) + } + if got.Size.Int64 != 200 { + t.Errorf("size = %d, want 200", got.Size.Int64) + } +} + +func TestUpsertMetadataCacheNullableFields(t *testing.T) { + db := setupMetadataCacheDB(t) + + entry := &MetadataCacheEntry{ + Ecosystem: "pypi", + Name: "requests", + StoragePath: "_metadata/pypi/requests/metadata", + } + + if err := 
db.UpsertMetadataCache(entry); err != nil { + t.Fatalf("UpsertMetadataCache() error = %v", err) + } + + got, err := db.GetMetadataCache("pypi", "requests") + if err != nil { + t.Fatalf("GetMetadataCache() error = %v", err) + } + if got == nil { + t.Fatal("expected entry") + } + if got.ETag.Valid { + t.Error("expected null etag") + } + if got.ContentType.Valid { + t.Error("expected null content_type") + } + if got.Size.Valid { + t.Error("expected null size") + } +} + +func TestMetadataCacheTableCreatedByMigration(t *testing.T) { + // Create a DB without the metadata_cache table, then migrate + dbPath := filepath.Join(t.TempDir(), "test.db") + db, err := Create(dbPath) + if err != nil { + t.Fatalf("Create failed: %v", err) + } + defer func() { _ = db.Close() }() + + // MigrateSchema should create the metadata_cache table + if err := db.MigrateSchema(); err != nil { + t.Fatalf("MigrateSchema() error = %v", err) + } + + has, err := db.HasTable("metadata_cache") + if err != nil { + t.Fatalf("HasTable() error = %v", err) + } + if !has { + t.Error("metadata_cache table should exist after migration") + } +} diff --git a/internal/database/queries.go b/internal/database/queries.go index fc6a3b3..5d95596 100644 --- a/internal/database/queries.go +++ b/internal/database/queries.go @@ -887,3 +887,66 @@ func (db *DB) CountCachedPackages(ecosystem string) (int64, error) { err = db.Get(&count, query, args...) return count, err } + +// Metadata cache queries + +func (db *DB) GetMetadataCache(ecosystem, name string) (*MetadataCacheEntry, error) { + var entry MetadataCacheEntry + query := db.Rebind(` + SELECT id, ecosystem, name, storage_path, etag, content_type, + size, last_modified, fetched_at, created_at, updated_at + FROM metadata_cache WHERE ecosystem = ? AND name = ? 
+ `) + err := db.Get(&entry, query, ecosystem, name) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + return &entry, nil +} + +func (db *DB) UpsertMetadataCache(entry *MetadataCacheEntry) error { + now := time.Now() + var query string + + if db.dialect == DialectPostgres { + query = ` + INSERT INTO metadata_cache (ecosystem, name, storage_path, etag, content_type, + size, last_modified, fetched_at, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + ON CONFLICT(ecosystem, name) DO UPDATE SET + storage_path = EXCLUDED.storage_path, + etag = EXCLUDED.etag, + content_type = EXCLUDED.content_type, + size = EXCLUDED.size, + last_modified = EXCLUDED.last_modified, + fetched_at = EXCLUDED.fetched_at, + updated_at = EXCLUDED.updated_at + ` + } else { + query = ` + INSERT INTO metadata_cache (ecosystem, name, storage_path, etag, content_type, + size, last_modified, fetched_at, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ ON CONFLICT(ecosystem, name) DO UPDATE SET + storage_path = excluded.storage_path, + etag = excluded.etag, + content_type = excluded.content_type, + size = excluded.size, + last_modified = excluded.last_modified, + fetched_at = excluded.fetched_at, + updated_at = excluded.updated_at + ` + } + + _, err := db.Exec(query, + entry.Ecosystem, entry.Name, entry.StoragePath, entry.ETag, + entry.ContentType, entry.Size, entry.LastModified, entry.FetchedAt, now, now, + ) + if err != nil { + return fmt.Errorf("upserting metadata cache: %w", err) + } + return nil +} diff --git a/internal/database/schema.go b/internal/database/schema.go index 233357f..e6f284f 100644 --- a/internal/database/schema.go +++ b/internal/database/schema.go @@ -91,6 +91,21 @@ CREATE TABLE IF NOT EXISTS vulnerabilities ( CREATE UNIQUE INDEX IF NOT EXISTS idx_vulns_id_pkg ON vulnerabilities(vuln_id, ecosystem, package_name); CREATE INDEX IF NOT EXISTS idx_vulns_ecosystem_pkg ON vulnerabilities(ecosystem, package_name); +CREATE TABLE IF NOT EXISTS metadata_cache ( + id INTEGER PRIMARY KEY, + ecosystem TEXT NOT NULL, + name TEXT NOT NULL, + storage_path TEXT NOT NULL, + etag TEXT, + content_type TEXT, + size INTEGER, + last_modified DATETIME, + fetched_at DATETIME, + created_at DATETIME, + updated_at DATETIME +); +CREATE UNIQUE INDEX IF NOT EXISTS idx_metadata_eco_name ON metadata_cache(ecosystem, name); + CREATE TABLE IF NOT EXISTS migrations ( name TEXT NOT NULL PRIMARY KEY, applied_at DATETIME NOT NULL @@ -176,6 +191,21 @@ CREATE TABLE IF NOT EXISTS vulnerabilities ( CREATE UNIQUE INDEX IF NOT EXISTS idx_vulns_id_pkg ON vulnerabilities(vuln_id, ecosystem, package_name); CREATE INDEX IF NOT EXISTS idx_vulns_ecosystem_pkg ON vulnerabilities(ecosystem, package_name); +CREATE TABLE IF NOT EXISTS metadata_cache ( + id SERIAL PRIMARY KEY, + ecosystem TEXT NOT NULL, + name TEXT NOT NULL, + storage_path TEXT NOT NULL, + etag TEXT, + content_type TEXT, + size BIGINT, + last_modified TIMESTAMP, + fetched_at 
TIMESTAMP, + created_at TIMESTAMP, + updated_at TIMESTAMP +); +CREATE UNIQUE INDEX IF NOT EXISTS idx_metadata_eco_name ON metadata_cache(ecosystem, name); + CREATE TABLE IF NOT EXISTS migrations ( name TEXT NOT NULL PRIMARY KEY, applied_at TIMESTAMP NOT NULL @@ -324,6 +354,7 @@ var migrations = []migration{ {"002_add_versions_enrichment_columns", migrateAddVersionsEnrichmentColumns}, {"003_ensure_artifacts_table", migrateEnsureArtifactsTable}, {"004_ensure_vulnerabilities_table", migrateEnsureVulnerabilitiesTable}, + {"005_ensure_metadata_cache_table", migrateEnsureMetadataCacheTable}, } // isTableNotFound returns true if the error indicates a missing table. @@ -538,5 +569,62 @@ func migrateEnsureVulnerabilitiesTable(db *DB) error { if _, err := db.Exec(vulnSchema); err != nil { return fmt.Errorf("creating vulnerabilities table: %w", err) } + + return nil +} + +func migrateEnsureMetadataCacheTable(db *DB) error { + return db.EnsureMetadataCacheTable() +} + +// EnsureMetadataCacheTable creates the metadata_cache table if it doesn't exist. 
+func (db *DB) EnsureMetadataCacheTable() error { + has, err := db.HasTable("metadata_cache") + if err != nil { + return fmt.Errorf("checking metadata_cache table: %w", err) + } + if has { + return nil + } + + var schema string + if db.dialect == DialectPostgres { + schema = ` + CREATE TABLE metadata_cache ( + id SERIAL PRIMARY KEY, + ecosystem TEXT NOT NULL, + name TEXT NOT NULL, + storage_path TEXT NOT NULL, + etag TEXT, + content_type TEXT, + size BIGINT, + last_modified TIMESTAMP, + fetched_at TIMESTAMP, + created_at TIMESTAMP, + updated_at TIMESTAMP + ); + CREATE UNIQUE INDEX IF NOT EXISTS idx_metadata_eco_name ON metadata_cache(ecosystem, name); + ` + } else { + schema = ` + CREATE TABLE metadata_cache ( + id INTEGER PRIMARY KEY, + ecosystem TEXT NOT NULL, + name TEXT NOT NULL, + storage_path TEXT NOT NULL, + etag TEXT, + content_type TEXT, + size INTEGER, + last_modified DATETIME, + fetched_at DATETIME, + created_at DATETIME, + updated_at DATETIME + ); + CREATE UNIQUE INDEX IF NOT EXISTS idx_metadata_eco_name ON metadata_cache(ecosystem, name); + ` + } + if _, err := db.Exec(schema); err != nil { + return fmt.Errorf("creating metadata_cache table: %w", err) + } return nil } diff --git a/internal/database/types.go b/internal/database/types.go index f73bfb4..f5b718e 100644 --- a/internal/database/types.go +++ b/internal/database/types.go @@ -76,6 +76,21 @@ func (a *Artifact) IsCached() bool { return a.StoragePath.Valid && a.FetchedAt.Valid } +// MetadataCacheEntry represents a cached metadata blob for offline serving. 
+type MetadataCacheEntry struct { + ID int64 `db:"id" json:"id"` + Ecosystem string `db:"ecosystem" json:"ecosystem"` + Name string `db:"name" json:"name"` + StoragePath string `db:"storage_path" json:"storage_path"` + ETag sql.NullString `db:"etag" json:"etag,omitempty"` + ContentType sql.NullString `db:"content_type" json:"content_type,omitempty"` + Size sql.NullInt64 `db:"size" json:"size,omitempty"` + LastModified sql.NullTime `db:"last_modified" json:"last_modified,omitempty"` + FetchedAt sql.NullTime `db:"fetched_at" json:"fetched_at,omitempty"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` +} + // Vulnerability represents a cached vulnerability record. type Vulnerability struct { ID int64 `db:"id" json:"id"` diff --git a/internal/handler/cargo.go b/internal/handler/cargo.go index 9602fe6..5d7810c 100644 --- a/internal/handler/cargo.go +++ b/internal/handler/cargo.go @@ -3,8 +3,8 @@ package handler import ( "bufio" "encoding/json" + "errors" "fmt" - "io" "net/http" "strings" "time" @@ -88,44 +88,27 @@ func (h *CargoHandler) handleIndex(w http.ResponseWriter, r *http.Request) { h.proxy.Logger.Info("cargo index request", "crate", name) - // Build the index path indexPath := h.buildIndexPath(name) upstreamURL := fmt.Sprintf("%s/%s", h.indexURL, indexPath) - req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil) - if err != nil { - http.Error(w, "internal error", http.StatusInternalServerError) - return - } - - resp, err := h.proxy.HTTPClient.Do(req) + body, contentType, err := h.proxy.FetchOrCacheMetadata(r.Context(), "cargo", name, upstreamURL, "text/plain") if err != nil { + if errors.Is(err, ErrUpstreamNotFound) { + http.Error(w, "not found", http.StatusNotFound) + return + } h.proxy.Logger.Error("failed to fetch upstream index", "error", err) http.Error(w, "failed to fetch from upstream", http.StatusBadGateway) return } - defer func() { _ = resp.Body.Close() 
}() - if resp.StatusCode == http.StatusNotFound { - http.Error(w, "not found", http.StatusNotFound) - return + if contentType == "" { + contentType = "text/plain; charset=utf-8" } - if resp.StatusCode != http.StatusOK { - http.Error(w, fmt.Sprintf("upstream returned %d", resp.StatusCode), http.StatusBadGateway) - return - } - - // Copy headers and body - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - if etag := resp.Header.Get("ETag"); etag != "" { - w.Header().Set("ETag", etag) - } - if lastMod := resp.Header.Get("Last-Modified"); lastMod != "" { - w.Header().Set("Last-Modified", lastMod) - } - - h.applyCooldownFiltering(w, resp.Body) + w.Header().Set("Content-Type", contentType) + w.WriteHeader(http.StatusOK) + h.applyCooldownFiltering(w, body) } type crateIndexEntry struct { @@ -134,56 +117,45 @@ type crateIndexEntry struct { PublishTime string `json:"pubtime,omitempty"` } -func (h *CargoHandler) applyCooldownFiltering(downstreamResponse io.Writer, upstreamBody io.Reader) { +func (h *CargoHandler) applyCooldownFiltering(downstreamResponse http.ResponseWriter, body []byte) { if h.proxy.Cooldown == nil || !h.proxy.Cooldown.Enabled() { - // not using cooldowns, just copy the upstream to the downstream - _, _ = io.Copy(downstreamResponse, upstreamBody) + _, _ = downstreamResponse.Write(body) return } - // create a scanner on the body of the http response - requestScanner := bufio.NewScanner(upstreamBody) + scanner := bufio.NewScanner(strings.NewReader(string(body))) - // the response is newline-delimited JSON, loop through each line - for requestScanner.Scan() { - line := requestScanner.Text() + for scanner.Scan() { + line := scanner.Text() - // decode the line var crate crateIndexEntry err := json.Unmarshal([]byte(line), &crate) if err != nil { - // if there is an error parsing this line then exclude it and move to the next entry h.proxy.Logger.Error("failed to parse json entry in index", "error", err) continue } - // parse publish time publishedAt, 
err := time.Parse(time.RFC3339, crate.PublishTime) if crate.PublishTime == "" || err != nil { - // publish time is empty/missing/invalid, presumably was published before pubtime was added as a field - // write line to response _, _ = downstreamResponse.Write([]byte(line + "\n")) continue } - // make PURL cratePURL := purl.MakePURLString("cargo", crate.Name, "") if !h.proxy.Cooldown.IsAllowed("cargo", cratePURL, publishedAt) { - // crate is not allowed, move to next crate h.proxy.Logger.Info("cooldown: filtering cargo version", "crate", crate.Name, "version", crate.Version, "published", crate.PublishTime) continue } - // crate passes, write to response _, _ = downstreamResponse.Write([]byte(line + "\n")) } - if err := requestScanner.Err(); err != nil { + if err := scanner.Err(); err != nil { h.proxy.Logger.Error("error reading index response", "error", err) } } diff --git a/internal/handler/cargo_test.go b/internal/handler/cargo_test.go index 5e7f2e4..5ce81b6 100644 --- a/internal/handler/cargo_test.go +++ b/internal/handler/cargo_test.go @@ -1,7 +1,6 @@ package handler import ( - "bytes" "encoding/json" "log/slog" "net/http" @@ -196,9 +195,9 @@ func TestCargoCooldown(t *testing.T) { proxyURL: "http://localhost:8080", } - var outputBuffer bytes.Buffer - h.applyCooldownFiltering(&outputBuffer, strings.NewReader(testInput.String())) - output := outputBuffer.String() + recorder := httptest.NewRecorder() + h.applyCooldownFiltering(recorder, []byte(testInput.String())) + output := recorder.Body.String() if output != expectedOutput.String() { t.Errorf("output = %q, want %q", output, expectedOutput.String()) diff --git a/internal/handler/composer.go b/internal/handler/composer.go index d47a0f2..7936401 100644 --- a/internal/handler/composer.go +++ b/internal/handler/composer.go @@ -2,6 +2,7 @@ package handler import ( "encoding/json" + "errors" "fmt" "io" "net/http" @@ -87,34 +88,18 @@ func (h *ComposerHandler) handlePackageMetadata(w http.ResponseWriter, r *http.R 
h.proxy.Logger.Info("composer metadata request", "package", packageName) - // Fetch from repo.packagist.org (Composer v2 metadata) upstreamURL := fmt.Sprintf("%s/p2/%s/%s.json", h.repoURL, vendor, pkg) - req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil) - if err != nil { - http.Error(w, "failed to create request", http.StatusInternalServerError) - return - } - - resp, err := h.proxy.HTTPClient.Do(req) + body, _, err := h.proxy.FetchOrCacheMetadata(r.Context(), "composer", packageName, upstreamURL) if err != nil { + if errors.Is(err, ErrUpstreamNotFound) { + http.Error(w, "not found", http.StatusNotFound) + return + } h.proxy.Logger.Error("upstream request failed", "error", err) http.Error(w, "upstream request failed", http.StatusBadGateway) return } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - w.WriteHeader(resp.StatusCode) - _, _ = io.Copy(w, resp.Body) - return - } - - body, err := ReadMetadata(resp.Body) - if err != nil { - http.Error(w, "failed to read response", http.StatusInternalServerError) - return - } rewritten, err := h.rewriteMetadata(body) if err != nil { diff --git a/internal/handler/conan.go b/internal/handler/conan.go index b91f7c3..53f6428 100644 --- a/internal/handler/conan.go +++ b/internal/handler/conan.go @@ -43,8 +43,8 @@ func (h *ConanHandler) Routes() http.Handler { mux.HandleFunc("GET /v1/files/{name}/{version}/{user}/{channel}/{revision}/package/{pkgref}/{pkgrev}/{filename}", h.handlePackageFile) mux.HandleFunc("GET /v2/files/{name}/{version}/{user}/{channel}/{revision}/package/{pkgref}/{pkgrev}/{filename}", h.handlePackageFile) - // Proxy all other endpoints (metadata, search, etc.) - mux.HandleFunc("GET /", h.proxyUpstream) + // Proxy all other endpoints (metadata, search, etc.) 
with caching + mux.HandleFunc("GET /", h.proxyCached) return mux } @@ -147,6 +147,20 @@ func (h *ConanHandler) shouldCacheFile(filename string) bool { return false } +// proxyCached forwards a request with metadata caching. +func (h *ConanHandler) proxyCached(w http.ResponseWriter, r *http.Request) { + cacheKey := strings.TrimPrefix(r.URL.Path, "/") + cacheKey = strings.ReplaceAll(cacheKey, "/", "_") + if r.URL.RawQuery != "" { + cacheKey += "_" + r.URL.RawQuery + } + upstreamURL := h.upstreamURL + r.URL.Path + if r.URL.RawQuery != "" { + upstreamURL += "?" + r.URL.RawQuery + } + h.proxy.ProxyCached(w, r, upstreamURL, "conan", cacheKey, "*/*") +} + // proxyUpstream forwards a request to conan center without caching. func (h *ConanHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) { upstreamURL := h.upstreamURL + r.URL.Path diff --git a/internal/handler/conda.go b/internal/handler/conda.go index 46d8704..a986f01 100644 --- a/internal/handler/conda.go +++ b/internal/handler/conda.go @@ -37,7 +37,7 @@ func (h *CondaHandler) Routes() http.Handler { // Channel index (repodata) mux.HandleFunc("GET /{channel}/{arch}/repodata.json", h.handleRepodata) - mux.HandleFunc("GET /{channel}/{arch}/repodata.json.bz2", h.proxyUpstream) + mux.HandleFunc("GET /{channel}/{arch}/repodata.json.bz2", h.proxyCached) mux.HandleFunc("GET /{channel}/{arch}/current_repodata.json", h.handleRepodata) // Package downloads (cache these) @@ -127,7 +127,7 @@ func (h *CondaHandler) parseFilename(filename string) (name, version string) { // handleRepodata proxies repodata.json, applying cooldown filtering when enabled. 
func (h *CondaHandler) handleRepodata(w http.ResponseWriter, r *http.Request) { if h.proxy.Cooldown == nil || !h.proxy.Cooldown.Enabled() { - h.proxyUpstream(w, r) + h.proxyCached(w, r) return } @@ -232,6 +232,13 @@ func (h *CondaHandler) applyCooldownFiltering(body []byte) ([]byte, error) { return json.Marshal(repodata) } +// proxyCached forwards a metadata request with caching. +func (h *CondaHandler) proxyCached(w http.ResponseWriter, r *http.Request) { + cacheKey := strings.TrimPrefix(r.URL.Path, "/") + cacheKey = strings.ReplaceAll(cacheKey, "/", "_") + h.proxy.ProxyCached(w, r, h.upstreamURL+r.URL.Path, "conda", cacheKey, "*/*") +} + // proxyUpstream forwards a request to Anaconda without caching. func (h *CondaHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) { h.proxy.ProxyUpstream(w, r, h.upstreamURL+r.URL.Path, []string{"Accept-Encoding"}) diff --git a/internal/handler/cran.go b/internal/handler/cran.go index b8c3c3a..246fcaa 100644 --- a/internal/handler/cran.go +++ b/internal/handler/cran.go @@ -30,14 +30,14 @@ func (h *CRANHandler) Routes() http.Handler { mux := http.NewServeMux() // Package indexes - mux.HandleFunc("GET /src/contrib/PACKAGES", h.proxyUpstream) - mux.HandleFunc("GET /src/contrib/PACKAGES.gz", h.proxyUpstream) - mux.HandleFunc("GET /src/contrib/PACKAGES.rds", h.proxyUpstream) + mux.HandleFunc("GET /src/contrib/PACKAGES", h.proxyCached) + mux.HandleFunc("GET /src/contrib/PACKAGES.gz", h.proxyCached) + mux.HandleFunc("GET /src/contrib/PACKAGES.rds", h.proxyCached) // Binary package indexes - mux.HandleFunc("GET /bin/{platform}/contrib/{rversion}/PACKAGES", h.proxyUpstream) - mux.HandleFunc("GET /bin/{platform}/contrib/{rversion}/PACKAGES.gz", h.proxyUpstream) - mux.HandleFunc("GET /bin/{platform}/contrib/{rversion}/PACKAGES.rds", h.proxyUpstream) + mux.HandleFunc("GET /bin/{platform}/contrib/{rversion}/PACKAGES", h.proxyCached) + mux.HandleFunc("GET /bin/{platform}/contrib/{rversion}/PACKAGES.gz", h.proxyCached) + 
mux.HandleFunc("GET /bin/{platform}/contrib/{rversion}/PACKAGES.rds", h.proxyCached) // Source package downloads mux.HandleFunc("GET /src/contrib/{filename}", h.handleSourceDownload) @@ -150,6 +150,13 @@ func (h *CRANHandler) isBinaryPackage(filename string) bool { return strings.HasSuffix(filename, ".zip") || strings.HasSuffix(filename, ".tgz") } +// proxyCached forwards a metadata request with caching. +func (h *CRANHandler) proxyCached(w http.ResponseWriter, r *http.Request) { + cacheKey := strings.TrimPrefix(r.URL.Path, "/") + cacheKey = strings.ReplaceAll(cacheKey, "/", "_") + h.proxy.ProxyCached(w, r, h.upstreamURL+r.URL.Path, "cran", cacheKey, "*/*") +} + // proxyUpstream forwards a request to CRAN without caching. func (h *CRANHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) { h.proxy.ProxyUpstream(w, r, h.upstreamURL+r.URL.Path, []string{"Accept-Encoding"}) diff --git a/internal/handler/debian.go b/internal/handler/debian.go index 11a1979..b767f6d 100644 --- a/internal/handler/debian.go +++ b/internal/handler/debian.go @@ -93,7 +93,8 @@ func (h *DebianHandler) handlePackageDownload(w http.ResponseWriter, r *http.Req // handleMetadata proxies repository metadata files. // These change frequently so we don't cache them. func (h *DebianHandler) handleMetadata(w http.ResponseWriter, r *http.Request, path string) { - h.proxy.ProxyMetadata(w, r, fmt.Sprintf("%s/%s", h.upstreamURL, path), "debian") + cacheKey := strings.ReplaceAll(path, "/", "_") + h.proxy.ProxyCached(w, r, fmt.Sprintf("%s/%s", h.upstreamURL, path), "debian", cacheKey, "*/*") } // proxyFile proxies any file directly without caching. 
diff --git a/internal/handler/download_test.go b/internal/handler/download_test.go index d560b82..639e976 100644 --- a/internal/handler/download_test.go +++ b/internal/handler/download_test.go @@ -197,9 +197,8 @@ func TestGemHandler_UpstreamProxy(t *testing.T) { if string(body) != "upstream specs data" { t.Errorf("body = %q, want %q", body, "upstream specs data") } - if resp.Header.Get("X-Test") != "upstream" { - t.Errorf("missing upstream header") - } + // Metadata caching reads the response body into storage and serves it back, + // so arbitrary upstream headers are not forwarded. Content-Type is preserved. } func TestGemHandler_CacheMiss(t *testing.T) { diff --git a/internal/handler/gem.go b/internal/handler/gem.go index 8faae54..bdb4bb9 100644 --- a/internal/handler/gem.go +++ b/internal/handler/gem.go @@ -40,12 +40,12 @@ func (h *GemHandler) Routes() http.Handler { mux.HandleFunc("GET /gems/{filename}", h.handleDownload) // Specs indexes (compressed Ruby Marshal format) - mux.HandleFunc("GET /specs.4.8.gz", h.proxyUpstream) - mux.HandleFunc("GET /latest_specs.4.8.gz", h.proxyUpstream) - mux.HandleFunc("GET /prerelease_specs.4.8.gz", h.proxyUpstream) + mux.HandleFunc("GET /specs.4.8.gz", h.proxyCached) + mux.HandleFunc("GET /latest_specs.4.8.gz", h.proxyCached) + mux.HandleFunc("GET /prerelease_specs.4.8.gz", h.proxyCached) // Compact index (bundler 2.x+) - mux.HandleFunc("GET /versions", h.proxyUpstream) + mux.HandleFunc("GET /versions", h.proxyCached) mux.HandleFunc("GET /info/{name}", h.handleCompactIndex) // Quick index @@ -107,7 +107,7 @@ func (h *GemHandler) parseGemFilename(filename string) (name, version string) { // based on cooldown when enabled. 
func (h *GemHandler) handleCompactIndex(w http.ResponseWriter, r *http.Request) { if h.proxy.Cooldown == nil || !h.proxy.Cooldown.Enabled() { - h.proxyUpstream(w, r) + h.proxyCached(w, r) return } @@ -288,6 +288,13 @@ func (h *GemHandler) fetchFilteredVersions(r *http.Request, name string) (map[st return filtered, nil } +// proxyCached forwards a metadata request with caching. +func (h *GemHandler) proxyCached(w http.ResponseWriter, r *http.Request) { + upstreamURL := h.upstreamURL + r.URL.Path + cacheKey := strings.TrimPrefix(r.URL.Path, "/") + h.proxy.ProxyCached(w, r, upstreamURL, "gem", cacheKey, "*/*") +} + // proxyUpstream forwards a request to rubygems.org without caching. func (h *GemHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) { upstreamURL := h.upstreamURL + r.URL.Path diff --git a/internal/handler/go.go b/internal/handler/go.go index dd4e17d..955a89c 100644 --- a/internal/handler/go.go +++ b/internal/handler/go.go @@ -54,18 +54,19 @@ func (h *GoHandler) handleRequest(w http.ResponseWriter, r *http.Request) { module := path[:idx] rest := path[idx+4:] // after "/@v/" + decodedMod := decodeGoModule(module) switch { case rest == "list": // GET /{module}/@v/list - list versions - h.proxyUpstream(w, r) + h.proxyCached(w, r, decodedMod+"/@v/list") case strings.HasSuffix(rest, ".info"): // GET /{module}/@v/{version}.info - version metadata - h.proxyUpstream(w, r) + h.proxyCached(w, r, decodedMod+"/@v/"+rest) case strings.HasSuffix(rest, ".mod"): // GET /{module}/@v/{version}.mod - go.mod file - h.proxyUpstream(w, r) + h.proxyCached(w, r, decodedMod+"/@v/"+rest) case strings.HasSuffix(rest, ".zip"): // GET /{module}/@v/{version}.zip - source archive (cache this) @@ -80,7 +81,8 @@ func (h *GoHandler) handleRequest(w http.ResponseWriter, r *http.Request) { // Check for @latest if strings.HasSuffix(path, "/@latest") { - h.proxyUpstream(w, r) + module := strings.TrimSuffix(path, "/@latest") + h.proxyCached(w, r, decodeGoModule(module)+"/@latest") 
return } @@ -111,6 +113,11 @@ func (h *GoHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) { h.proxy.ProxyUpstream(w, r, h.upstreamURL+r.URL.Path, nil) } +// proxyCached forwards a request with metadata caching. +func (h *GoHandler) proxyCached(w http.ResponseWriter, r *http.Request, cacheKey string) { + h.proxy.ProxyCached(w, r, h.upstreamURL+r.URL.Path, "golang", cacheKey, "*/*") +} + // decodeGoModule decodes an encoded module path. // In the encoding, uppercase letters are represented as "!" followed by lowercase. func decodeGoModule(encoded string) string { diff --git a/internal/handler/handler.go b/internal/handler/handler.go index 109eacd..3f3291e 100644 --- a/internal/handler/handler.go +++ b/internal/handler/handler.go @@ -2,12 +2,15 @@ package handler import ( + "bytes" "context" "database/sql" + "errors" "fmt" "io" "log/slog" "net/http" + "strconv" "strings" "time" @@ -32,6 +35,8 @@ func containsPathTraversal(path string) bool { const defaultHTTPTimeout = 30 * time.Second +const contentTypeJSON = "application/json" + // maxMetadataSize is the maximum size of upstream metadata responses (50 MB). // Package metadata (e.g. npm with many versions) can be large, but unbounded // reads risk OOM if an upstream misbehaves. @@ -45,13 +50,15 @@ func ReadMetadata(r io.Reader) ([]byte, error) { // Proxy provides shared functionality for protocol handlers. type Proxy struct { - DB *database.DB - Storage storage.Storage - Fetcher fetch.FetcherInterface - Resolver *fetch.Resolver - Logger *slog.Logger - Cooldown *cooldown.Config - HTTPClient *http.Client + DB *database.DB + Storage storage.Storage + Fetcher fetch.FetcherInterface + Resolver *fetch.Resolver + Logger *slog.Logger + Cooldown *cooldown.Config + CacheMetadata bool + MetadataTTL time.Duration + HTTPClient *http.Client } // NewProxy creates a new Proxy with the given dependencies. 
@@ -311,33 +318,24 @@ func (p *Proxy) ProxyUpstream(w http.ResponseWriter, r *http.Request, upstreamUR _, _ = io.Copy(w, resp.Body) } -// ProxyMetadata forwards a metadata request to upstream, copying only specific response headers. -func (p *Proxy) ProxyMetadata(w http.ResponseWriter, r *http.Request, upstreamURL string, logLabel string) { - p.Logger.Debug(logLabel+" metadata request", "url", upstreamURL) - +// ProxyFile forwards a file request to upstream, copying all response headers. +func (p *Proxy) ProxyFile(w http.ResponseWriter, r *http.Request, upstreamURL string) { req, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, nil) if err != nil { http.Error(w, "failed to create request", http.StatusInternalServerError) return } - for _, header := range []string{"Accept", "Accept-Encoding", "If-Modified-Since", "If-None-Match"} { - if v := r.Header.Get(header); v != "" { - req.Header.Set(header, v) - } - } - resp, err := p.HTTPClient.Do(req) if err != nil { - p.Logger.Error("failed to fetch upstream metadata", "error", err) http.Error(w, "failed to fetch from upstream", http.StatusBadGateway) return } defer func() { _ = resp.Body.Close() }() - for _, header := range []string{"Content-Type", "Content-Length", "Last-Modified", "ETag"} { - if v := resp.Header.Get(header); v != "" { - w.Header().Set(header, v) + for key, values := range resp.Header { + for _, v := range values { + w.Header().Add(key, v) } } @@ -345,14 +343,308 @@ func (p *Proxy) ProxyMetadata(w http.ResponseWriter, r *http.Request, upstreamUR _, _ = io.Copy(w, resp.Body) } -// ProxyFile forwards a file request to upstream, copying all response headers. -func (p *Proxy) ProxyFile(w http.ResponseWriter, r *http.Request, upstreamURL string) { - req, err := http.NewRequestWithContext(r.Context(), r.Method, upstreamURL, nil) +// JSONError writes a JSON error response. 
+func JSONError(w http.ResponseWriter, status int, message string) { + w.Header().Set("Content-Type", contentTypeJSON) + w.WriteHeader(status) + _, _ = fmt.Fprintf(w, `{"error":%q}`, message) +} + +// ErrUpstreamNotFound indicates the upstream returned 404. +var ErrUpstreamNotFound = fmt.Errorf("upstream: not found") + +// errStale304 is returned when upstream sends 304 but the cached file is missing. +var errStale304 = fmt.Errorf("upstream returned 304 but cached file is missing") + +// metadataStoragePath builds a storage path for cached metadata. +func metadataStoragePath(ecosystem, cacheKey string) string { + return "_metadata/" + ecosystem + "/" + cacheKey + "/metadata" +} + +// FetchOrCacheMetadata fetches metadata from upstream with caching. +// On success it returns the raw response bytes and content type. +// If upstream fails and a cached copy exists, the cached version is returned. +// cacheKey is typically the package name but can include subpath components. +// Optional acceptHeaders specify the Accept header(s) to send; defaults to application/json. 
+func (p *Proxy) FetchOrCacheMetadata(ctx context.Context, ecosystem, cacheKey, upstreamURL string, acceptHeaders ...string) ([]byte, string, error) { + if containsPathTraversal(cacheKey) { + return nil, "", fmt.Errorf("invalid cache key: %q", cacheKey) + } + + storagePath := metadataStoragePath(ecosystem, cacheKey) + + // Check for existing cache entry (for ETag revalidation and TTL) + var entry *database.MetadataCacheEntry + if p.CacheMetadata && p.DB != nil { + entry, _ = p.DB.GetMetadataCache(ecosystem, cacheKey) + } + + // Serve from cache if within TTL (skip upstream entirely) + if entry != nil && p.MetadataTTL > 0 && entry.FetchedAt.Valid { + if time.Since(entry.FetchedAt.Time) < p.MetadataTTL { + cached, readErr := p.Storage.Open(ctx, entry.StoragePath) + if readErr == nil { + defer func() { _ = cached.Close() }() + data, readErr := ReadMetadata(cached) + if readErr == nil { + ct := contentTypeJSON + if entry.ContentType.Valid { + ct = entry.ContentType.String + } + return data, ct, nil + } + } + // Cache file missing/unreadable, fall through to upstream + } + } + + accept := contentTypeJSON + if len(acceptHeaders) > 0 && acceptHeaders[0] != "" { + accept = acceptHeaders[0] + } + + // Try upstream + body, contentType, etag, lastModified, err := p.fetchUpstreamMetadata(ctx, upstreamURL, entry, accept) + if errors.Is(err, errStale304) { + // 304 but cached file is gone; retry without ETag + body, contentType, etag, lastModified, err = p.fetchUpstreamMetadata(ctx, upstreamURL, nil, accept) + } + if err == nil { + if p.CacheMetadata { + p.cacheMetadataBlob(ctx, ecosystem, cacheKey, storagePath, body, contentType, etag, lastModified) + } + return body, contentType, nil + } + + // Upstream failed -- fall back to cache if available + if !p.CacheMetadata || entry == nil { + return nil, "", fmt.Errorf("upstream failed and no cached metadata: %w", err) + } + + p.Logger.Warn("upstream metadata fetch failed, checking cache", + "ecosystem", ecosystem, "key", cacheKey, 
"error", err) + + cached, readErr := p.Storage.Open(ctx, entry.StoragePath) + if readErr != nil { + return nil, "", fmt.Errorf("upstream failed and cached file missing: %w", err) + } + defer func() { _ = cached.Close() }() + + data, readErr := ReadMetadata(cached) + if readErr != nil { + return nil, "", fmt.Errorf("upstream failed and cached read error: %w", err) + } + + ct := contentTypeJSON + if entry.ContentType.Valid { + ct = entry.ContentType.String + } + p.Logger.Info("serving metadata from cache", + "ecosystem", ecosystem, "key", cacheKey) + return data, ct, nil +} + +// fetchUpstreamMetadata fetches metadata from upstream, using ETag for conditional revalidation. +// Returns the body, content type, ETag, upstream Last-Modified time, and any error. +func (p *Proxy) fetchUpstreamMetadata(ctx context.Context, upstreamURL string, entry *database.MetadataCacheEntry, accept string) ([]byte, string, string, time.Time, error) { + var zeroTime time.Time + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, upstreamURL, nil) + if err != nil { + return nil, "", "", zeroTime, fmt.Errorf("creating request: %w", err) + } + req.Header.Set("Accept", accept) + + if entry != nil && entry.ETag.Valid { + req.Header.Set("If-None-Match", entry.ETag.String) + } + + resp, err := p.HTTPClient.Do(req) + if err != nil { + return nil, "", "", zeroTime, fmt.Errorf("fetching metadata: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + // 304 Not Modified -- our cached copy is still good + if resp.StatusCode == http.StatusNotModified && entry != nil { + cached, readErr := p.Storage.Open(ctx, entry.StoragePath) + if readErr != nil { + return nil, "", "", zeroTime, errStale304 + } + defer func() { _ = cached.Close() }() + data, readErr := ReadMetadata(cached) + if readErr != nil { + return nil, "", "", zeroTime, errStale304 + } + ct := contentTypeJSON + if entry.ContentType.Valid { + ct = entry.ContentType.String + } + lm := zeroTime + if entry.LastModified.Valid { + lm 
= entry.LastModified.Time + } + return data, ct, entry.ETag.String, lm, nil + } + + if resp.StatusCode == http.StatusNotFound { + return nil, "", "", zeroTime, ErrUpstreamNotFound + } + if resp.StatusCode != http.StatusOK { + return nil, "", "", zeroTime, fmt.Errorf("upstream returned %d", resp.StatusCode) + } + + body, err := ReadMetadata(resp.Body) + if err != nil { + return nil, "", "", zeroTime, fmt.Errorf("reading response: %w", err) + } + + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + contentType = contentTypeJSON + } + + etag := resp.Header.Get("ETag") + + var lastModified time.Time + if lm := resp.Header.Get("Last-Modified"); lm != "" { + lastModified, _ = http.ParseTime(lm) + } + + return body, contentType, etag, lastModified, nil +} + +// cacheMetadataBlob stores metadata bytes in storage and updates the database. +func (p *Proxy) cacheMetadataBlob(ctx context.Context, ecosystem, cacheKey, storagePath string, data []byte, contentType, etag string, lastModified time.Time) { + if p.DB == nil || p.Storage == nil { + return + } + + size, _, err := p.Storage.Store(ctx, storagePath, bytes.NewReader(data)) + if err != nil { + p.Logger.Warn("failed to cache metadata", "ecosystem", ecosystem, "key", cacheKey, "error", err) + return + } + + _ = p.DB.UpsertMetadataCache(&database.MetadataCacheEntry{ + Ecosystem: ecosystem, + Name: cacheKey, + StoragePath: storagePath, + ETag: sql.NullString{String: etag, Valid: etag != ""}, + ContentType: sql.NullString{String: contentType, Valid: contentType != ""}, + Size: sql.NullInt64{Int64: size, Valid: true}, + LastModified: sql.NullTime{Time: lastModified, Valid: !lastModified.IsZero()}, + FetchedAt: sql.NullTime{Time: time.Now(), Valid: true}, + }) +} + +// cachedMeta holds cache validators and freshness state from a metadata cache entry. +type cachedMeta struct { + etag string + lastModified time.Time + stale bool +} + +// lookupCachedMeta retrieves cache validators for a metadata entry. 
+func (p *Proxy) lookupCachedMeta(ecosystem, cacheKey string) cachedMeta { + if p.DB == nil { + return cachedMeta{} + } + entry, err := p.DB.GetMetadataCache(ecosystem, cacheKey) + if err != nil || entry == nil { + return cachedMeta{} + } + var cm cachedMeta + if entry.ETag.Valid { + cm.etag = entry.ETag.String + } + if entry.LastModified.Valid { + cm.lastModified = entry.LastModified.Time + } + // If FetchedAt is older than TTL, upstream must have failed and + // we served from stale cache (successful fetches update FetchedAt). + if p.MetadataTTL > 0 && entry.FetchedAt.Valid && time.Since(entry.FetchedAt.Time) > p.MetadataTTL { + cm.stale = true + } + return cm +} + +// ProxyCached fetches metadata from upstream (with optional caching for offline fallback) +// and writes it to the response. Optional acceptHeaders specify the Accept header to send. +// When metadata caching is disabled, the response is streamed directly to avoid buffering +// large metadata responses (e.g. npm packages with many versions) in memory. +func (p *Proxy) ProxyCached(w http.ResponseWriter, r *http.Request, upstreamURL, ecosystem, cacheKey string, acceptHeaders ...string) { + if !p.CacheMetadata { + // Stream directly without buffering when caching is off. + p.proxyMetadataStream(w, r, upstreamURL, acceptHeaders...) + return + } + + body, contentType, err := p.FetchOrCacheMetadata(r.Context(), ecosystem, cacheKey, upstreamURL, acceptHeaders...) 
+ if err != nil { + if errors.Is(err, ErrUpstreamNotFound) { + http.Error(w, "not found", http.StatusNotFound) + return + } + p.Logger.Error("metadata fetch failed", "error", err) + http.Error(w, "failed to fetch from upstream", http.StatusBadGateway) + return + } + + cm := p.lookupCachedMeta(ecosystem, cacheKey) + + // Honor client conditional request headers + if cm.etag != "" { + if match := r.Header.Get("If-None-Match"); match != "" && match == cm.etag { + w.WriteHeader(http.StatusNotModified) + return + } + } + if !cm.lastModified.IsZero() { + if ims := r.Header.Get("If-Modified-Since"); ims != "" { + if t, err := http.ParseTime(ims); err == nil && !cm.lastModified.After(t) { + w.WriteHeader(http.StatusNotModified) + return + } + } + } + + w.Header().Set("Content-Type", contentType) + w.Header().Set("Content-Length", strconv.Itoa(len(body))) + if cm.etag != "" { + w.Header().Set("ETag", cm.etag) + } + if !cm.lastModified.IsZero() { + w.Header().Set("Last-Modified", cm.lastModified.UTC().Format(http.TimeFormat)) + } + if cm.stale { + w.Header().Set("Warning", `110 - "Response is Stale"`) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write(body) +} + +// proxyMetadataStream forwards an upstream metadata response by streaming it to the client +// without buffering the full body in memory. 
+func (p *Proxy) proxyMetadataStream(w http.ResponseWriter, r *http.Request, upstreamURL string, acceptHeaders ...string) { + req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil) if err != nil { http.Error(w, "failed to create request", http.StatusInternalServerError) return } + accept := contentTypeJSON + if len(acceptHeaders) > 0 && acceptHeaders[0] != "" { + accept = acceptHeaders[0] + } + req.Header.Set("Accept", accept) + + for _, header := range []string{"Accept-Encoding", "If-Modified-Since", "If-None-Match"} { + if v := r.Header.Get(header); v != "" { + req.Header.Set(header, v) + } + } + resp, err := p.HTTPClient.Do(req) if err != nil { http.Error(w, "failed to fetch from upstream", http.StatusBadGateway) @@ -360,9 +652,9 @@ func (p *Proxy) ProxyFile(w http.ResponseWriter, r *http.Request, upstreamURL st } defer func() { _ = resp.Body.Close() }() - for key, values := range resp.Header { - for _, v := range values { - w.Header().Add(key, v) + for _, header := range []string{"Content-Type", "Content-Length", "Last-Modified", "ETag"} { + if v := resp.Header.Get(header); v != "" { + w.Header().Set(header, v) } } @@ -370,13 +662,6 @@ func (p *Proxy) ProxyFile(w http.ResponseWriter, r *http.Request, upstreamURL st _, _ = io.Copy(w, resp.Body) } -// JSONError writes a JSON error response. -func JSONError(w http.ResponseWriter, status int, message string) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - _, _ = fmt.Fprintf(w, `{"error":%q}`, message) -} - // GetOrFetchArtifactFromURL retrieves an artifact from cache or fetches from a specific URL. // This is useful for registries where download URLs are determined from metadata. 
func (p *Proxy) GetOrFetchArtifactFromURL(ctx context.Context, ecosystem, name, version, filename, downloadURL string) (*CacheResult, error) { diff --git a/internal/handler/handler_test.go b/internal/handler/handler_test.go index 4c71319..78ed415 100644 --- a/internal/handler/handler_test.go +++ b/internal/handler/handler_test.go @@ -486,3 +486,302 @@ func TestNewProxy_NilLogger(t *testing.T) { t.Error("Logger should be set to default when nil is passed") } } + +const testLastModified = "Wed, 01 Jan 2025 12:00:00 GMT" + +// setupCachedProxy creates a Proxy with CacheMetadata enabled and an upstream +// test server that returns JSON with ETag and Last-Modified headers. +func setupCachedProxy(t *testing.T, upstreamETag, upstreamLastModified string) (*Proxy, *httptest.Server) { + t.Helper() + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if upstreamETag != "" { + w.Header().Set("ETag", upstreamETag) + } + if upstreamLastModified != "" { + w.Header().Set("Last-Modified", upstreamLastModified) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + t.Cleanup(upstream.Close) + + proxy, _, _, _ := setupTestProxy(t) + proxy.CacheMetadata = true + proxy.HTTPClient = upstream.Client() + + return proxy, upstream +} + +func TestProxyCached_SetsETagAndLastModified(t *testing.T) { + lm := testLastModified + proxy, upstream := setupCachedProxy(t, `"abc123"`, lm) + + // First request populates the cache + req := httptest.NewRequest(http.MethodGet, "/test", nil) + w := httptest.NewRecorder() + proxy.ProxyCached(w, req, upstream.URL+"/test", "test-eco", "test-key") + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", w.Code) + } + if got := w.Header().Get("ETag"); got != `"abc123"` { + t.Errorf("ETag = %q, want %q", got, `"abc123"`) + } + if got := w.Header().Get("Last-Modified"); got != lm { + t.Errorf("Last-Modified = %q, want %q", 
got, lm) + } + if got := w.Header().Get("Content-Length"); got != "11" { + t.Errorf("Content-Length = %q, want %q", got, "11") + } + if w.Body.String() != `{"ok":true}` { + t.Errorf("body = %q, want %q", w.Body.String(), `{"ok":true}`) + } +} + +func TestProxyCached_IfNoneMatch_Returns304(t *testing.T) { + proxy, upstream := setupCachedProxy(t, `"abc123"`, "") + + // Populate cache + req := httptest.NewRequest(http.MethodGet, "/test", nil) + w := httptest.NewRecorder() + proxy.ProxyCached(w, req, upstream.URL+"/test", "test-eco", "etag-key") + if w.Code != http.StatusOK { + t.Fatalf("initial request: status = %d, want 200", w.Code) + } + + // Conditional request with matching ETag + req = httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("If-None-Match", `"abc123"`) + w = httptest.NewRecorder() + proxy.ProxyCached(w, req, upstream.URL+"/test", "test-eco", "etag-key") + + if w.Code != http.StatusNotModified { + t.Errorf("conditional request: status = %d, want 304", w.Code) + } + if w.Body.Len() != 0 { + t.Errorf("304 response should have empty body, got %d bytes", w.Body.Len()) + } +} + +func TestProxyCached_IfNoneMatch_NonMatching_Returns200(t *testing.T) { + proxy, upstream := setupCachedProxy(t, `"abc123"`, "") + + // Populate cache + req := httptest.NewRequest(http.MethodGet, "/test", nil) + w := httptest.NewRecorder() + proxy.ProxyCached(w, req, upstream.URL+"/test", "test-eco", "etag-nm-key") + if w.Code != http.StatusOK { + t.Fatalf("initial request: status = %d, want 200", w.Code) + } + + // Conditional request with non-matching ETag + req = httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("If-None-Match", `"different"`) + w = httptest.NewRecorder() + proxy.ProxyCached(w, req, upstream.URL+"/test", "test-eco", "etag-nm-key") + + if w.Code != http.StatusOK { + t.Errorf("non-matching ETag: status = %d, want 200", w.Code) + } +} + +func TestProxyCached_IfModifiedSince_Returns304(t *testing.T) { + lm := testLastModified + 
proxy, upstream := setupCachedProxy(t, "", lm) + + // Populate cache + req := httptest.NewRequest(http.MethodGet, "/test", nil) + w := httptest.NewRecorder() + proxy.ProxyCached(w, req, upstream.URL+"/test", "test-eco", "lm-key") + if w.Code != http.StatusOK { + t.Fatalf("initial request: status = %d, want 200", w.Code) + } + + // Conditional request with If-Modified-Since equal to Last-Modified + req = httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("If-Modified-Since", lm) + w = httptest.NewRecorder() + proxy.ProxyCached(w, req, upstream.URL+"/test", "test-eco", "lm-key") + + if w.Code != http.StatusNotModified { + t.Errorf("conditional request: status = %d, want 304", w.Code) + } +} + +func TestProxyCached_IfModifiedSince_OlderDate_Returns200(t *testing.T) { + lm := testLastModified + proxy, upstream := setupCachedProxy(t, "", lm) + + // Populate cache + req := httptest.NewRequest(http.MethodGet, "/test", nil) + w := httptest.NewRecorder() + proxy.ProxyCached(w, req, upstream.URL+"/test", "test-eco", "lm-old-key") + if w.Code != http.StatusOK { + t.Fatalf("initial request: status = %d, want 200", w.Code) + } + + // Conditional request with If-Modified-Since older than Last-Modified + req = httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("If-Modified-Since", "Mon, 01 Dec 2024 12:00:00 GMT") + w = httptest.NewRecorder() + proxy.ProxyCached(w, req, upstream.URL+"/test", "test-eco", "lm-old-key") + + if w.Code != http.StatusOK { + t.Errorf("older If-Modified-Since: status = %d, want 200", w.Code) + } +} + +func TestProxyCached_NoValidators_OmitsHeaders(t *testing.T) { + proxy, upstream := setupCachedProxy(t, "", "") + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + w := httptest.NewRecorder() + proxy.ProxyCached(w, req, upstream.URL+"/test", "test-eco", "no-val-key") + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", w.Code) + } + if got := w.Header().Get("ETag"); got != "" { + t.Errorf("ETag 
should be empty when upstream has none, got %q", got) + } + if got := w.Header().Get("Last-Modified"); got != "" { + t.Errorf("Last-Modified should be empty when upstream has none, got %q", got) + } +} + +func TestFetchOrCacheMetadata_TTL_ServesFreshFromCache(t *testing.T) { + upstreamHits := 0 + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamHits++ + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"v":1}`)) + })) + t.Cleanup(upstream.Close) + + proxy, _, _, _ := setupTestProxy(t) + proxy.CacheMetadata = true + proxy.MetadataTTL = 1 * time.Hour + proxy.HTTPClient = upstream.Client() + + ctx := context.Background() + + // First request populates cache + body, _, err := proxy.FetchOrCacheMetadata(ctx, "test", "ttl-pkg", upstream.URL+"/pkg") + if err != nil { + t.Fatalf("first fetch: %v", err) + } + if string(body) != `{"v":1}` { + t.Errorf("body = %q, want %q", body, `{"v":1}`) + } + if upstreamHits != 1 { + t.Fatalf("expected 1 upstream hit, got %d", upstreamHits) + } + + // Second request within TTL should serve from cache without hitting upstream + body, _, err = proxy.FetchOrCacheMetadata(ctx, "test", "ttl-pkg", upstream.URL+"/pkg") + if err != nil { + t.Fatalf("second fetch: %v", err) + } + if string(body) != `{"v":1}` { + t.Errorf("body = %q, want %q", body, `{"v":1}`) + } + if upstreamHits != 1 { + t.Errorf("expected upstream to still be hit only once, got %d", upstreamHits) + } +} + +func TestFetchOrCacheMetadata_TTL_Zero_AlwaysRevalidates(t *testing.T) { + upstreamHits := 0 + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamHits++ + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"v":1}`)) + })) + t.Cleanup(upstream.Close) + + proxy, _, _, _ := setupTestProxy(t) + proxy.CacheMetadata = true + proxy.MetadataTTL = 0 // always revalidate + proxy.HTTPClient = upstream.Client() + + ctx := 
context.Background() + + _, _, err := proxy.FetchOrCacheMetadata(ctx, "test", "ttl0-pkg", upstream.URL+"/pkg") + if err != nil { + t.Fatalf("first fetch: %v", err) + } + + _, _, err = proxy.FetchOrCacheMetadata(ctx, "test", "ttl0-pkg", upstream.URL+"/pkg") + if err != nil { + t.Fatalf("second fetch: %v", err) + } + + if upstreamHits != 2 { + t.Errorf("expected 2 upstream hits with TTL=0, got %d", upstreamHits) + } +} + +func TestProxyCached_StaleWarningHeader(t *testing.T) { + requestCount := 0 + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + if requestCount == 1 { + // First request succeeds to populate cache + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"cached":true}`)) + return + } + // Subsequent requests fail to simulate upstream outage + w.WriteHeader(http.StatusBadGateway) + })) + t.Cleanup(upstream.Close) + + proxy, _, _, _ := setupTestProxy(t) + proxy.CacheMetadata = true + proxy.MetadataTTL = 1 * time.Millisecond // very short TTL so it expires immediately + proxy.HTTPClient = upstream.Client() + + // First request populates cache + req := httptest.NewRequest(http.MethodGet, "/test", nil) + w := httptest.NewRecorder() + proxy.ProxyCached(w, req, upstream.URL+"/test", "test-eco", "stale-key") + if w.Code != http.StatusOK { + t.Fatalf("initial request: status = %d, want 200", w.Code) + } + + // Wait for TTL to expire + time.Sleep(5 * time.Millisecond) + + // Second request: upstream fails, should serve stale cache with Warning header + req = httptest.NewRequest(http.MethodGet, "/test", nil) + w = httptest.NewRecorder() + proxy.ProxyCached(w, req, upstream.URL+"/test", "test-eco", "stale-key") + + if w.Code != http.StatusOK { + t.Fatalf("stale request: status = %d, want 200", w.Code) + } + if w.Body.String() != `{"cached":true}` { + t.Errorf("body = %q, want %q", w.Body.String(), `{"cached":true}`) + } + if got := w.Header().Get("Warning"); got != `110 - 
"Response is Stale"` { + t.Errorf("Warning = %q, want %q", got, `110 - "Response is Stale"`) + } +} + +func TestProxyCached_FreshResponse_NoWarningHeader(t *testing.T) { + proxy, upstream := setupCachedProxy(t, "", "") + proxy.MetadataTTL = 1 * time.Hour + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + w := httptest.NewRecorder() + proxy.ProxyCached(w, req, upstream.URL+"/test", "test-eco", "fresh-key") + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", w.Code) + } + if got := w.Header().Get("Warning"); got != "" { + t.Errorf("Warning should be empty for fresh response, got %q", got) + } +} diff --git a/internal/handler/hex.go b/internal/handler/hex.go index 990fb55..0f0c72e 100644 --- a/internal/handler/hex.go +++ b/internal/handler/hex.go @@ -41,9 +41,9 @@ func (h *HexHandler) Routes() http.Handler { // Package tarballs (cache these) mux.HandleFunc("GET /tarballs/{filename}", h.handleDownload) - // Registry resources (proxy without caching) - mux.HandleFunc("GET /names", h.proxyUpstream) - mux.HandleFunc("GET /versions", h.proxyUpstream) + // Registry resources (cached for offline) + mux.HandleFunc("GET /names", h.proxyCached) + mux.HandleFunc("GET /versions", h.proxyCached) mux.HandleFunc("GET /packages/{name}", h.handlePackages) // Public keys @@ -102,13 +102,13 @@ const hexAPIURL = "https://hex.pm" // Hex HTTP API concurrently. func (h *HexHandler) handlePackages(w http.ResponseWriter, r *http.Request) { if h.proxy.Cooldown == nil || !h.proxy.Cooldown.Enabled() { - h.proxyUpstream(w, r) + h.proxyCached(w, r) return } name := r.PathValue("name") if name == "" { - h.proxyUpstream(w, r) + h.proxyCached(w, r) return } @@ -417,6 +417,12 @@ func extractProtobufBytes(data []byte, fieldNum protowire.Number) ([]byte, error return nil, fmt.Errorf("field %d not found", fieldNum) } +// proxyCached forwards a request with metadata caching. 
+func (h *HexHandler) proxyCached(w http.ResponseWriter, r *http.Request) { + cacheKey := strings.TrimPrefix(r.URL.Path, "/") + h.proxy.ProxyCached(w, r, h.upstreamURL+r.URL.Path, "hex", cacheKey, "*/*") +} + // proxyUpstream forwards a request to hex.pm without caching. func (h *HexHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) { h.proxy.ProxyUpstream(w, r, h.upstreamURL+r.URL.Path, []string{"Accept"}) diff --git a/internal/handler/maven.go b/internal/handler/maven.go index 79da0c0..86664a2 100644 --- a/internal/handler/maven.go +++ b/internal/handler/maven.go @@ -51,8 +51,8 @@ func (h *MavenHandler) handleRequest(w http.ResponseWriter, r *http.Request) { filename := path.Base(urlPath) if h.isMetadataFile(filename) { - // Proxy metadata without caching - h.proxyUpstream(w, r) + cacheKey := strings.ReplaceAll(urlPath, "/", "_") + h.proxy.ProxyCached(w, r, h.upstreamURL+r.URL.Path, "maven", cacheKey, "*/*") return } diff --git a/internal/handler/npm.go b/internal/handler/npm.go index e0b0566..f15d626 100644 --- a/internal/handler/npm.go +++ b/internal/handler/npm.go @@ -2,6 +2,7 @@ package handler import ( "encoding/json" + "errors" "fmt" "net/http" "net/url" @@ -65,39 +66,18 @@ func (h *NPMHandler) handlePackageMetadata(w http.ResponseWriter, r *http.Reques h.proxy.Logger.Info("npm metadata request", "package", packageName) - // Fetch metadata from upstream upstreamURL := fmt.Sprintf("%s/%s", h.upstreamURL, url.PathEscape(packageName)) - req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil) + body, _, err := h.proxy.FetchOrCacheMetadata(r.Context(), "npm", packageName, upstreamURL) if err != nil { - JSONError(w, http.StatusInternalServerError, "failed to create request") - return - } - req.Header.Set("Accept", "application/json") - - resp, err := h.proxy.HTTPClient.Do(req) - if err != nil { - h.proxy.Logger.Error("failed to fetch upstream metadata", "error", err) + if errors.Is(err, ErrUpstreamNotFound) { + 
JSONError(w, http.StatusNotFound, "package not found") + return + } + h.proxy.Logger.Error("failed to fetch npm metadata", "error", err) JSONError(w, http.StatusBadGateway, "failed to fetch from upstream") return } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode == http.StatusNotFound { - JSONError(w, http.StatusNotFound, "package not found") - return - } - if resp.StatusCode != http.StatusOK { - JSONError(w, http.StatusBadGateway, fmt.Sprintf("upstream returned %d", resp.StatusCode)) - return - } - - // Parse and rewrite tarball URLs - body, err := ReadMetadata(resp.Body) - if err != nil { - JSONError(w, http.StatusInternalServerError, "failed to read response") - return - } rewritten, err := h.rewriteMetadata(packageName, body) if err != nil { diff --git a/internal/handler/nuget.go b/internal/handler/nuget.go index 21bcc46..615b0d2 100644 --- a/internal/handler/nuget.go +++ b/internal/handler/nuget.go @@ -2,6 +2,7 @@ package handler import ( "encoding/json" + "errors" "fmt" "io" "net/http" @@ -60,31 +61,16 @@ func (h *NuGetHandler) handleServiceIndex(w http.ResponseWriter, r *http.Request upstreamURL := h.upstreamURL + "/v3/index.json" - req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil) - if err != nil { - http.Error(w, "failed to create request", http.StatusInternalServerError) - return - } - - resp, err := h.proxy.HTTPClient.Do(req) + body, _, err := h.proxy.FetchOrCacheMetadata(r.Context(), "nuget", "_service_index", upstreamURL) if err != nil { + if errors.Is(err, ErrUpstreamNotFound) { + http.Error(w, "not found", http.StatusNotFound) + return + } h.proxy.Logger.Error("upstream request failed", "error", err) http.Error(w, "upstream request failed", http.StatusBadGateway) return } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - w.WriteHeader(resp.StatusCode) - _, _ = io.Copy(w, resp.Body) - return - } - - body, err := ReadMetadata(resp.Body) - if err != nil { - 
http.Error(w, "failed to read response", http.StatusInternalServerError) - return - } rewritten, err := h.rewriteServiceIndex(body) if err != nil { diff --git a/internal/handler/nuget_test.go b/internal/handler/nuget_test.go index 43d6f60..5dbb242 100644 --- a/internal/handler/nuget_test.go +++ b/internal/handler/nuget_test.go @@ -230,8 +230,9 @@ func TestNuGetHandleServiceIndexUpstreamError(t *testing.T) { w := httptest.NewRecorder() h.handleServiceIndex(w, req) - if w.Code != http.StatusInternalServerError { - t.Errorf("status = %d, want %d", w.Code, http.StatusInternalServerError) + // With metadata caching, upstream 500 is reported as 502 (bad gateway) + if w.Code != http.StatusBadGateway { + t.Errorf("status = %d, want %d", w.Code, http.StatusBadGateway) } } diff --git a/internal/handler/pub.go b/internal/handler/pub.go index b8f6207..60bbbad 100644 --- a/internal/handler/pub.go +++ b/internal/handler/pub.go @@ -2,8 +2,8 @@ package handler import ( "encoding/json" + "errors" "fmt" - "io" "net/http" "strings" "time" @@ -89,32 +89,16 @@ func (h *PubHandler) handlePackageMetadata(w http.ResponseWriter, r *http.Reques upstreamURL := fmt.Sprintf("%s/api/packages/%s", h.upstreamURL, name) - req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil) - if err != nil { - http.Error(w, "failed to create request", http.StatusInternalServerError) - return - } - req.Header.Set("Accept", "application/json") - - resp, err := h.proxy.HTTPClient.Do(req) + body, _, err := h.proxy.FetchOrCacheMetadata(r.Context(), "pub", name, upstreamURL) if err != nil { + if errors.Is(err, ErrUpstreamNotFound) { + http.Error(w, "not found", http.StatusNotFound) + return + } h.proxy.Logger.Error("upstream request failed", "error", err) http.Error(w, "upstream request failed", http.StatusBadGateway) return } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - w.WriteHeader(resp.StatusCode) - _, _ = io.Copy(w, resp.Body) - return - } - 
- body, err := ReadMetadata(resp.Body) - if err != nil { - http.Error(w, "failed to read response", http.StatusInternalServerError) - return - } rewritten, err := h.rewriteMetadata(name, body) if err != nil { diff --git a/internal/handler/pypi.go b/internal/handler/pypi.go index aac33a7..954adbf 100644 --- a/internal/handler/pypi.go +++ b/internal/handler/pypi.go @@ -4,6 +4,7 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -74,33 +75,18 @@ func (h *PyPIHandler) handleSimplePackage(w http.ResponseWriter, r *http.Request h.proxy.Logger.Info("pypi simple request", "package", name) upstreamURL := fmt.Sprintf("%s/simple/%s/", h.upstreamURL, name) + cacheKey := name + "/simple" - req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil) - if err != nil { - http.Error(w, "failed to create request", http.StatusInternalServerError) - return - } - req.Header.Set("Accept", "text/html") - - resp, err := h.proxy.HTTPClient.Do(req) + body, _, err := h.proxy.FetchOrCacheMetadata(r.Context(), "pypi", cacheKey, upstreamURL, "text/html") if err != nil { + if errors.Is(err, ErrUpstreamNotFound) { + http.Error(w, "not found", http.StatusNotFound) + return + } h.proxy.Logger.Error("upstream request failed", "error", err) http.Error(w, "upstream request failed", http.StatusBadGateway) return } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - w.WriteHeader(resp.StatusCode) - _, _ = io.Copy(w, resp.Body) - return - } - - body, err := ReadMetadata(resp.Body) - if err != nil { - http.Error(w, "failed to read response", http.StatusInternalServerError) - return - } // When cooldown is enabled, fetch JSON metadata to get version timestamps var filteredVersions map[string]bool @@ -221,7 +207,7 @@ func (h *PyPIHandler) handleJSON(w http.ResponseWriter, r *http.Request) { h.proxy.Logger.Info("pypi json request", "package", name) upstreamURL := fmt.Sprintf("%s/pypi/%s/json", 
h.upstreamURL, name) - h.proxyAndRewriteJSON(w, r, upstreamURL) + h.proxyAndRewriteJSON(w, r, upstreamURL, name+"/json") } // handleVersionJSON serves the JSON API version metadata. @@ -237,37 +223,21 @@ func (h *PyPIHandler) handleVersionJSON(w http.ResponseWriter, r *http.Request) h.proxy.Logger.Info("pypi version json request", "package", name, "version", version) upstreamURL := fmt.Sprintf("%s/pypi/%s/%s/json", h.upstreamURL, name, version) - h.proxyAndRewriteJSON(w, r, upstreamURL) + h.proxyAndRewriteJSON(w, r, upstreamURL, name+"/"+version) } // proxyAndRewriteJSON fetches JSON metadata and rewrites download URLs. -func (h *PyPIHandler) proxyAndRewriteJSON(w http.ResponseWriter, r *http.Request, upstreamURL string) { - req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstreamURL, nil) - if err != nil { - http.Error(w, "failed to create request", http.StatusInternalServerError) - return - } - req.Header.Set("Accept", "application/json") - - resp, err := h.proxy.HTTPClient.Do(req) +func (h *PyPIHandler) proxyAndRewriteJSON(w http.ResponseWriter, r *http.Request, upstreamURL, cacheKey string) { + body, _, err := h.proxy.FetchOrCacheMetadata(r.Context(), "pypi", cacheKey, upstreamURL) if err != nil { + if errors.Is(err, ErrUpstreamNotFound) { + http.Error(w, "not found", http.StatusNotFound) + return + } h.proxy.Logger.Error("upstream request failed", "error", err) http.Error(w, "upstream request failed", http.StatusBadGateway) return } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - w.WriteHeader(resp.StatusCode) - _, _ = io.Copy(w, resp.Body) - return - } - - body, err := ReadMetadata(resp.Body) - if err != nil { - http.Error(w, "failed to read response", http.StatusInternalServerError) - return - } rewritten, err := h.rewriteJSONMetadata(body) if err != nil { diff --git a/internal/handler/rpm.go b/internal/handler/rpm.go index 92da8b6..6440d0f 100644 --- a/internal/handler/rpm.go +++ 
b/internal/handler/rpm.go @@ -95,7 +95,8 @@ func (h *RPMHandler) handlePackageDownload(w http.ResponseWriter, r *http.Reques // handleMetadata proxies repository metadata files (repomd.xml, primary.xml.gz, etc.). // These change frequently so we don't cache them. func (h *RPMHandler) handleMetadata(w http.ResponseWriter, r *http.Request, path string) { - h.proxy.ProxyMetadata(w, r, fmt.Sprintf("%s/%s", h.upstreamURL, path), "rpm") + cacheKey := strings.ReplaceAll(path, "/", "_") + h.proxy.ProxyCached(w, r, fmt.Sprintf("%s/%s", h.upstreamURL, path), "rpm", cacheKey, "*/*") } // proxyFile proxies any file directly without caching. diff --git a/internal/mirror/job.go b/internal/mirror/job.go new file mode 100644 index 0000000..8915d4a --- /dev/null +++ b/internal/mirror/job.go @@ -0,0 +1,207 @@ +package mirror + +import ( + "context" + "crypto/rand" + "fmt" + "sync" + "time" +) + +// JobState represents the current state of a mirror job. +type JobState string + +const ( + JobStatePending JobState = "pending" + JobStateRunning JobState = "running" + JobStateComplete JobState = "complete" + JobStateFailed JobState = "failed" + JobStateCanceled JobState = "canceled" +) + +const jobTTL = 1 * time.Hour +const cleanupInterval = 5 * time.Minute //nolint:mnd // cleanup ticker + +// Job represents an async mirror operation. +type Job struct { + ID string `json:"id"` + State JobState `json:"state"` + Progress Progress `json:"progress"` + CreatedAt time.Time `json:"created_at"` + Error string `json:"error,omitempty"` + + cancel context.CancelFunc +} + +// JobRequest is the JSON body for starting a mirror job via the API. +type JobRequest struct { + PURLs []string `json:"purls,omitempty"` + Registry string `json:"registry,omitempty"` +} + +// JobStore manages in-memory mirror jobs. +type JobStore struct { + mu sync.RWMutex + jobs map[string]*Job + mirror *Mirror + parentCtx context.Context +} + +// NewJobStore creates a new job store. 
The parent context is used as the base +// for all job contexts so that jobs are canceled when the server shuts down. +func NewJobStore(ctx context.Context, m *Mirror) *JobStore { + return &JobStore{ + jobs: make(map[string]*Job), + mirror: m, + parentCtx: ctx, + } +} + +// Create starts a new mirror job and returns its ID. +func (js *JobStore) Create(req JobRequest) (string, error) { + source, err := js.sourceFromRequest(req) + if err != nil { + return "", err + } + + id := newJobID() + ctx, cancel := context.WithCancel(js.parentCtx) + + job := &Job{ + ID: id, + State: JobStatePending, + CreatedAt: time.Now(), + cancel: cancel, + } + + js.mu.Lock() + js.jobs[id] = job + js.mu.Unlock() + + go js.runJob(ctx, cancel, job, source) + + return id, nil +} + +// Get returns a snapshot of a job by ID. The returned copy is safe to +// serialize without holding the lock. +func (js *JobStore) Get(id string) *Job { + js.mu.RLock() + defer js.mu.RUnlock() + job := js.jobs[id] + if job == nil { + return nil + } + snapshot := *job + snapshot.cancel = nil // don't leak cancel func + if len(job.Progress.Errors) > 0 { + snapshot.Progress.Errors = make([]MirrorError, len(job.Progress.Errors)) + copy(snapshot.Progress.Errors, job.Progress.Errors) + } + return &snapshot +} + +// Cancel cancels a running job. +func (js *JobStore) Cancel(id string) bool { + js.mu.Lock() + defer js.mu.Unlock() + + job := js.jobs[id] + if job == nil || job.cancel == nil { + return false + } + + if job.State != JobStatePending && job.State != JobStateRunning { + return false + } + + job.cancel() + job.State = JobStateCanceled + return true +} + +// Cleanup removes completed/failed/canceled jobs older than jobTTL. 
+func (js *JobStore) Cleanup() { + js.mu.Lock() + defer js.mu.Unlock() + for id, job := range js.jobs { + if job.State == JobStateComplete || job.State == JobStateFailed || job.State == JobStateCanceled { + if time.Since(job.CreatedAt) > jobTTL { + delete(js.jobs, id) + } + } + } +} + +// StartCleanup runs periodic cleanup of old jobs until the context is canceled. +func (js *JobStore) StartCleanup(ctx context.Context) { + ticker := time.NewTicker(cleanupInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + js.Cleanup() + } + } +} + +func (js *JobStore) runJob(ctx context.Context, cancel context.CancelFunc, job *Job, source Source) { + defer cancel() + + js.mu.Lock() + if job.State == JobStateCanceled { + js.mu.Unlock() + return + } + job.State = JobStateRunning + js.mu.Unlock() + + progress, err := js.mirror.Run(ctx, source, func(p Progress) { + js.mu.Lock() + defer js.mu.Unlock() + if job.State == JobStateRunning { + job.Progress = p + } + }) + + js.mu.Lock() + defer js.mu.Unlock() + + // Cancel() may have already set the state; don't overwrite it. + if job.State == JobStateCanceled { + return + } + + if err != nil { + job.State = JobStateFailed + job.Error = err.Error() + return + } + + job.Progress = *progress + if progress.Failed > 0 && progress.Completed == 0 { + job.State = JobStateFailed + } else { + job.State = JobStateComplete + } +} + +func (js *JobStore) sourceFromRequest(req JobRequest) (Source, error) { //nolint:ireturn // interface return is the design + switch { + case len(req.PURLs) > 0: + return &PURLSource{PURLs: req.PURLs}, nil + case req.Registry != "": + return nil, fmt.Errorf("registry mirroring is not yet implemented; use purls instead") + default: + return nil, fmt.Errorf("request must include purls") + } +} + +// newJobID generates a random hex job ID. 
+func newJobID() string { + b := make([]byte, 16) //nolint:mnd // 128-bit ID + _, _ = rand.Read(b) + return fmt.Sprintf("%x", b) +} diff --git a/internal/mirror/job_test.go b/internal/mirror/job_test.go new file mode 100644 index 0000000..f7f2f1c --- /dev/null +++ b/internal/mirror/job_test.go @@ -0,0 +1,183 @@ +package mirror + +import ( + "context" + "testing" + "time" +) + +func TestJobStoreCreateAndGet(t *testing.T) { + m := setupTestMirror(t, 1) + js := NewJobStore(context.Background(), m) + + id, err := js.Create(JobRequest{ + PURLs: []string{"pkg:npm/lodash@4.17.21"}, + }) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + + if id == "" { + t.Fatal("expected non-empty job ID") + } + + // Wait for the job to start (it runs async) + time.Sleep(100 * time.Millisecond) + + job := js.Get(id) + if job == nil { + t.Fatal("Get() returned nil") + } + if job.ID != id { + t.Errorf("job ID = %q, want %q", job.ID, id) + } +} + +func TestJobStoreGetNotFound(t *testing.T) { + m := setupTestMirror(t, 1) + js := NewJobStore(context.Background(), m) + + job := js.Get("nonexistent") + if job != nil { + t.Errorf("expected nil for nonexistent job, got %v", job) + } +} + +func TestJobStoreCancelNotFound(t *testing.T) { + m := setupTestMirror(t, 1) + js := NewJobStore(context.Background(), m) + + if js.Cancel("nonexistent") { + t.Error("expected Cancel to return false for nonexistent job") + } +} + +func TestJobStoreCreateInvalidRequest(t *testing.T) { + m := setupTestMirror(t, 1) + js := NewJobStore(context.Background(), m) + + _, err := js.Create(JobRequest{}) + if err == nil { + t.Fatal("expected error for empty request") + } +} + +func TestJobStoreMultipleJobs(t *testing.T) { + m := setupTestMirror(t, 1) + js := NewJobStore(context.Background(), m) + + id1, err := js.Create(JobRequest{PURLs: []string{"pkg:npm/lodash@4.17.21"}}) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + + id2, err := js.Create(JobRequest{PURLs: 
[]string{"pkg:cargo/serde@1.0.0"}}) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + + if id1 == id2 { + t.Error("expected different job IDs") + } + + job1 := js.Get(id1) + job2 := js.Get(id2) + if job1 == nil || job2 == nil { + t.Fatal("expected both jobs to exist") + } +} + +func TestSourceFromRequestPURLs(t *testing.T) { + m := setupTestMirror(t, 1) + js := NewJobStore(context.Background(), m) + + source, err := js.sourceFromRequest(JobRequest{PURLs: []string{"pkg:npm/lodash@1.0.0"}}) + if err != nil { + t.Fatalf("sourceFromRequest() error = %v", err) + } + if _, ok := source.(*PURLSource); !ok { + t.Errorf("expected *PURLSource, got %T", source) + } +} + +func TestSourceFromRequestRegistryRejected(t *testing.T) { + m := setupTestMirror(t, 1) + js := NewJobStore(context.Background(), m) + + _, err := js.sourceFromRequest(JobRequest{Registry: "npm"}) + if err == nil { + t.Fatal("expected error for registry request") + } +} + +func TestJobStoreCleanup(t *testing.T) { + m := setupTestMirror(t, 1) + js := NewJobStore(context.Background(), m) + + // Add a completed job with old CreatedAt + js.mu.Lock() + js.jobs["old-job"] = &Job{ + ID: "old-job", + State: JobStateComplete, + CreatedAt: time.Now().Add(-2 * time.Hour), + } + js.jobs["recent-job"] = &Job{ + ID: "recent-job", + State: JobStateComplete, + CreatedAt: time.Now(), + } + js.jobs["running-job"] = &Job{ + ID: "running-job", + State: JobStateRunning, + CreatedAt: time.Now().Add(-2 * time.Hour), + } + js.mu.Unlock() + + js.Cleanup() + + if js.Get("old-job") != nil { + t.Error("expected old completed job to be cleaned up") + } + if js.Get("recent-job") == nil { + t.Error("expected recent completed job to be kept") + } + if js.Get("running-job") == nil { + t.Error("expected running job to be kept regardless of age") + } +} + +func TestJobStoreCancelPreservesStateAfterRunJob(t *testing.T) { + m := setupTestMirror(t, 1) + js := NewJobStore(context.Background(), m) + + // Create a job with a PURL that 
will fail (no real upstream in test) + id, err := js.Create(JobRequest{PURLs: []string{"pkg:npm/nonexistent-pkg@0.0.0"}}) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + + // Cancel immediately -- the job may already be running + js.Cancel(id) + + // Wait for runJob goroutine to finish + time.Sleep(200 * time.Millisecond) + + job := js.Get(id) + if job == nil { + t.Fatal("Get() returned nil") + } + if job.State != JobStateCanceled { + t.Errorf("state = %q, want %q (cancel should not be overwritten by runJob)", job.State, JobStateCanceled) + } +} + +func TestNewJobIDUnique(t *testing.T) { + ids := make(map[string]bool) + for range 100 { + id := newJobID() + if ids[id] { + t.Fatalf("duplicate job ID: %s", id) + } + ids[id] = true + } +} diff --git a/internal/mirror/mirror.go b/internal/mirror/mirror.go new file mode 100644 index 0000000..06b496f --- /dev/null +++ b/internal/mirror/mirror.go @@ -0,0 +1,228 @@ +// Package mirror provides selective package mirroring for pre-populating the proxy cache. +package mirror + +import ( + "context" + "fmt" + "log/slog" + "sync" + "sync/atomic" + "time" + + "github.com/git-pkgs/proxy/internal/database" + "github.com/git-pkgs/proxy/internal/handler" + "github.com/git-pkgs/proxy/internal/storage" + "golang.org/x/sync/errgroup" +) + +// Mirror pre-populates the proxy cache from various input sources. +type Mirror struct { + proxy *handler.Proxy + db *database.DB + storage storage.Storage + logger *slog.Logger + workers int +} + +// New creates a new Mirror with the given dependencies. +func New(proxy *handler.Proxy, db *database.DB, store storage.Storage, logger *slog.Logger, workers int) *Mirror { + if workers < 1 { + workers = 1 + } + return &Mirror{ + proxy: proxy, + db: db, + storage: store, + logger: logger, + workers: workers, + } +} + +// Progress tracks the state of a mirror operation. 
type Progress struct {
	Total     int64         `json:"total"`
	Completed int64         `json:"completed"`
	Skipped   int64         `json:"skipped"`
	Failed    int64         `json:"failed"`
	Bytes     int64         `json:"bytes"`
	Errors    []MirrorError `json:"errors,omitempty"`
	StartedAt time.Time     `json:"started_at"`
	Phase     string        `json:"phase"`
}

// MirrorError records a single failed mirror attempt.
type MirrorError struct {
	Ecosystem string `json:"ecosystem"`
	Name      string `json:"name"`
	Version   string `json:"version"`
	Error     string `json:"error"`
}

// progressTracker accumulates counters for a mirror run. Counters and the
// phase are atomics; the error list is guarded by mu.
type progressTracker struct {
	total     atomic.Int64
	completed atomic.Int64
	skipped   atomic.Int64
	failed    atomic.Int64
	bytes     atomic.Int64
	mu        sync.Mutex
	errors    []MirrorError
	startedAt time.Time
	phase     atomic.Value // string
}

// newProgressTracker returns a tracker stamped with the current time and
// starting in the "resolving" phase.
func newProgressTracker() *progressTracker {
	tracker := new(progressTracker)
	tracker.startedAt = time.Now()
	tracker.phase.Store("resolving")
	return tracker
}

const (
	// maxTrackedErrors caps how many individual failures are retained.
	maxTrackedErrors = 1000
	// progressReportInterval is the period between progress callbacks.
	progressReportInterval = 500 * time.Millisecond //nolint:mnd // progress update frequency
)

// addError records one failure; once maxTrackedErrors entries are stored,
// further errors are dropped.
func (pt *progressTracker) addError(eco, name, version, err string) {
	pt.mu.Lock()
	defer pt.mu.Unlock()

	if len(pt.errors) >= maxTrackedErrors {
		return
	}
	entry := MirrorError{Ecosystem: eco, Name: name, Version: version, Error: err}
	pt.errors = append(pt.errors, entry)
}

// snapshot returns a copy of the tracker's current state. The Errors slice is
// a fresh copy, so callers may keep or modify it freely.
func (pt *progressTracker) snapshot() Progress {
	snap := Progress{StartedAt: pt.startedAt}

	pt.mu.Lock()
	snap.Errors = append(make([]MirrorError, 0, len(pt.errors)), pt.errors...)
	pt.mu.Unlock()

	if phase, ok := pt.phase.Load().(string); ok {
		snap.Phase = phase
	}
	snap.Total = pt.total.Load()
	snap.Completed = pt.completed.Load()
	snap.Skipped = pt.skipped.Load()
	snap.Failed = pt.failed.Load()
	snap.Bytes = pt.bytes.Load()
	return snap
}

// ProgressFunc is called periodically with a snapshot of the current progress.
type ProgressFunc func(Progress)

// Run mirrors all packages from the source using a bounded worker pool.
// It returns the final progress when complete.
If onProgress is non-nil, +// it is called with progress snapshots as work proceeds. +func (m *Mirror) Run(ctx context.Context, source Source, onProgress ...ProgressFunc) (*Progress, error) { + tracker := newProgressTracker() + + // Collect items from source + var items []PackageVersion + tracker.phase.Store("resolving") + err := source.Enumerate(ctx, func(pv PackageVersion) error { + items = append(items, pv) + return nil + }) + if err != nil { + return nil, fmt.Errorf("enumerating packages: %w", err) + } + + tracker.total.Store(int64(len(items))) + tracker.phase.Store("downloading") + + // Start periodic progress reporting if a callback was provided + var progressFn ProgressFunc + if len(onProgress) > 0 && onProgress[0] != nil { + progressFn = onProgress[0] + } + progressDone := make(chan struct{}) + if progressFn != nil { + progressFn(tracker.snapshot()) // initial snapshot with total set + go func() { + ticker := time.NewTicker(progressReportInterval) + defer ticker.Stop() + for { + select { + case <-progressDone: + return + case <-ticker.C: + progressFn(tracker.snapshot()) + } + } + }() + } + + // Process items with bounded concurrency + g, gctx := errgroup.WithContext(ctx) + g.SetLimit(m.workers) + + for _, item := range items { + g.Go(func() (err error) { + defer func() { + if r := recover(); r != nil { + m.logger.Error("panic in mirror worker", "recover", r, + "ecosystem", item.Ecosystem, "name", item.Name, "version", item.Version) + tracker.failed.Add(1) + tracker.addError(item.Ecosystem, item.Name, item.Version, fmt.Sprintf("panic: %v", r)) + } + }() + m.mirrorOne(gctx, item, tracker) + return nil // never fail the group; errors are tracked + }) + } + + _ = g.Wait() + + close(progressDone) // stop the progress reporter goroutine + + tracker.phase.Store("complete") + p := tracker.snapshot() + + // Send final snapshot + if progressFn != nil { + progressFn(p) + } + + return &p, nil +} + +// RunDryRun enumerates what would be mirrored without downloading. 
+func (m *Mirror) RunDryRun(ctx context.Context, source Source) ([]PackageVersion, error) { + var items []PackageVersion + err := source.Enumerate(ctx, func(pv PackageVersion) error { + items = append(items, pv) + return nil + }) + return items, err +} + +func (m *Mirror) mirrorOne(ctx context.Context, pv PackageVersion, tracker *progressTracker) { + result, err := m.proxy.GetOrFetchArtifact(ctx, pv.Ecosystem, pv.Name, pv.Version, "") + if err != nil { + tracker.failed.Add(1) + tracker.addError(pv.Ecosystem, pv.Name, pv.Version, err.Error()) + m.logger.Warn("mirror failed", + "ecosystem", pv.Ecosystem, "name", pv.Name, "version", pv.Version, "error", err) + return + } + + _ = result.Reader.Close() + + if result.Cached { + tracker.skipped.Add(1) + m.logger.Debug("already cached", + "ecosystem", pv.Ecosystem, "name", pv.Name, "version", pv.Version) + } else { + tracker.completed.Add(1) + tracker.bytes.Add(result.Size) + m.logger.Info("mirrored", + "ecosystem", pv.Ecosystem, "name", pv.Name, "version", pv.Version, + "size", result.Size) + } +} diff --git a/internal/mirror/mirror_test.go b/internal/mirror/mirror_test.go new file mode 100644 index 0000000..1d7d30d --- /dev/null +++ b/internal/mirror/mirror_test.go @@ -0,0 +1,195 @@ +package mirror + +import ( + "context" + "log/slog" + "os" + "testing" + "time" + + "github.com/git-pkgs/proxy/internal/database" + "github.com/git-pkgs/proxy/internal/handler" + "github.com/git-pkgs/proxy/internal/storage" + "github.com/git-pkgs/registries/fetch" +) + +// setupTestMirror creates a Mirror with real DB and filesystem storage for integration tests. 
+func setupTestMirror(t *testing.T, workers int) *Mirror { + t.Helper() + + dbPath := t.TempDir() + "/test.db" + db, err := database.Create(dbPath) + if err != nil { + t.Fatalf("creating database: %v", err) + } + if err := db.MigrateSchema(); err != nil { + t.Fatalf("migrating schema: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + + storeDir := t.TempDir() + store, err := storage.OpenBucket(context.Background(), "file://"+storeDir) + if err != nil { + t.Fatalf("opening storage: %v", err) + } + + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelWarn})) + fetcher := fetch.NewFetcher() + resolver := fetch.NewResolver() + proxy := handler.NewProxy(db, store, fetcher, resolver, logger) + + return New(proxy, db, store, logger, workers) +} + +const testPackageLodash = "lodash" + +func TestMirrorRunEmptySource(t *testing.T) { + m := setupTestMirror(t, 2) + + source := &PURLSource{PURLs: []string{}} + progress, err := m.Run(context.Background(), source) + if err != nil { + t.Fatalf("Run() error = %v", err) + } + + if progress.Total != 0 { + t.Errorf("total = %d, want 0", progress.Total) + } + if progress.Phase != "complete" { + t.Errorf("phase = %q, want %q", progress.Phase, "complete") + } +} + +func TestMirrorRunDryRun(t *testing.T) { + m := setupTestMirror(t, 1) + + source := &PURLSource{ + PURLs: []string{ + "pkg:npm/lodash@4.17.21", + "pkg:cargo/serde@1.0.0", + }, + } + + items, err := m.RunDryRun(context.Background(), source) + if err != nil { + t.Fatalf("RunDryRun() error = %v", err) + } + + if len(items) != 2 { + t.Fatalf("got %d items, want 2", len(items)) + } + + // Dry run should not modify the database + stats, err := m.db.GetCacheStats() + if err != nil { + t.Fatalf("GetCacheStats() error = %v", err) + } + if stats.TotalArtifacts != 0 { + t.Errorf("artifacts = %d, want 0 (dry run should not cache)", stats.TotalArtifacts) + } +} + +func TestMirrorRunCanceled(t *testing.T) { + m := setupTestMirror(t, 1) + + ctx, 
cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + // Use a source that produces items but they'll all fail due to canceled context + source := &PURLSource{ + PURLs: []string{"pkg:npm/lodash@4.17.21"}, + } + + progress, err := m.Run(ctx, source) + if err != nil { + t.Fatalf("Run() error = %v", err) + } + + // With a canceled context, the fetch should fail + if progress.Failed != 1 { + t.Errorf("failed = %d, want 1", progress.Failed) + } +} + +func TestProgressTrackerSnapshot(t *testing.T) { + pt := newProgressTracker() + pt.total.Store(10) + pt.completed.Store(5) + pt.skipped.Store(3) + pt.failed.Store(2) + pt.bytes.Store(1024) + pt.phase.Store("downloading") + pt.addError("npm", testPackageLodash, "4.17.21", "fetch failed") + + snap := pt.snapshot() + if snap.Total != 10 { + t.Errorf("total = %d, want 10", snap.Total) + } + if snap.Completed != 5 { + t.Errorf("completed = %d, want 5", snap.Completed) + } + if snap.Skipped != 3 { + t.Errorf("skipped = %d, want 3", snap.Skipped) + } + if snap.Failed != 2 { + t.Errorf("failed = %d, want 2", snap.Failed) + } + if snap.Bytes != 1024 { + t.Errorf("bytes = %d, want 1024", snap.Bytes) + } + if snap.Phase != "downloading" { + t.Errorf("phase = %q, want %q", snap.Phase, "downloading") + } + if len(snap.Errors) != 1 { + t.Fatalf("errors = %d, want 1", len(snap.Errors)) + } + if snap.Errors[0].Name != testPackageLodash { + t.Errorf("error name = %q, want %q", snap.Errors[0].Name, testPackageLodash) + } + if snap.StartedAt.IsZero() { + t.Error("started_at should not be zero") + } +} + +func TestProgressTrackerConcurrentAccess(t *testing.T) { + pt := newProgressTracker() + done := make(chan struct{}) + + for range 10 { + go func() { + pt.completed.Add(1) + pt.addError("npm", "test", "1.0.0", "error") + _ = pt.snapshot() + done <- struct{}{} + }() + } + + timeout := time.After(5 * time.Second) + for range 10 { + select { + case <-done: + case <-timeout: + t.Fatal("timed out waiting for 
goroutines") + } + } + + snap := pt.snapshot() + if snap.Completed != 10 { + t.Errorf("completed = %d, want 10", snap.Completed) + } + if len(snap.Errors) != 10 { + t.Errorf("errors = %d, want 10", len(snap.Errors)) + } +} + +func TestNewMirrorDefaultWorkers(t *testing.T) { + m := New(nil, nil, nil, slog.Default(), 0) + if m.workers != 1 { + t.Errorf("workers = %d, want 1 (minimum)", m.workers) + } + + m = New(nil, nil, nil, slog.Default(), -5) + if m.workers != 1 { + t.Errorf("workers = %d, want 1 (minimum)", m.workers) + } +} diff --git a/internal/mirror/registry.go b/internal/mirror/registry.go new file mode 100644 index 0000000..6b2c449 --- /dev/null +++ b/internal/mirror/registry.go @@ -0,0 +1,16 @@ +package mirror + +import ( + "context" + "fmt" +) + +// RegistrySource enumerates all packages in a registry for full mirroring. +// Registry enumeration is not yet implemented for any ecosystem. +type RegistrySource struct { + Ecosystem string +} + +func (s *RegistrySource) Enumerate(_ context.Context, _ func(PackageVersion) error) error { + return fmt.Errorf("registry enumeration is not yet implemented for ecosystem %q", s.Ecosystem) +} diff --git a/internal/mirror/registry_test.go b/internal/mirror/registry_test.go new file mode 100644 index 0000000..363bfea --- /dev/null +++ b/internal/mirror/registry_test.go @@ -0,0 +1,46 @@ +package mirror + +import ( + "context" + "testing" +) + +func TestRegistrySourceUnsupported(t *testing.T) { + source := &RegistrySource{Ecosystem: "golang"} + err := source.Enumerate(context.Background(), func(pv PackageVersion) error { + return nil + }) + if err == nil { + t.Fatal("expected error for unsupported ecosystem") + } +} + +func TestRegistrySourceNPMNotImplemented(t *testing.T) { + source := &RegistrySource{Ecosystem: "npm"} + err := source.Enumerate(context.Background(), func(pv PackageVersion) error { + return nil + }) + if err == nil { + t.Fatal("expected not-implemented error") + } +} + +func 
TestRegistrySourcePyPINotImplemented(t *testing.T) { + source := &RegistrySource{Ecosystem: "pypi"} + err := source.Enumerate(context.Background(), func(pv PackageVersion) error { + return nil + }) + if err == nil { + t.Fatal("expected not-implemented error") + } +} + +func TestRegistrySourceCargoNotImplemented(t *testing.T) { + source := &RegistrySource{Ecosystem: "cargo"} + err := source.Enumerate(context.Background(), func(pv PackageVersion) error { + return nil + }) + if err == nil { + t.Fatal("expected not-implemented error") + } +} diff --git a/internal/mirror/source.go b/internal/mirror/source.go new file mode 100644 index 0000000..a6fa364 --- /dev/null +++ b/internal/mirror/source.go @@ -0,0 +1,190 @@ +package mirror + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "os" + + cdx "github.com/CycloneDX/cyclonedx-go" + "github.com/git-pkgs/purl" + "github.com/git-pkgs/registries" + _ "github.com/git-pkgs/registries/all" + "github.com/spdx/tools-golang/spdx" + spdxjson "github.com/spdx/tools-golang/json" + spdxtv "github.com/spdx/tools-golang/tagvalue" +) + +// PackageVersion identifies a specific package version to mirror. +type PackageVersion struct { + Ecosystem string + Name string + Version string +} + +func (pv PackageVersion) String() string { + return fmt.Sprintf("pkg:%s/%s@%s", pv.Ecosystem, pv.Name, pv.Version) +} + +// Source produces PackageVersion items for mirroring. +type Source interface { + Enumerate(ctx context.Context, fn func(PackageVersion) error) error +} + +// PURLSource yields packages from PURL strings. +// Versioned PURLs produce a single item. Unversioned PURLs look up all versions from the registry. 
+type PURLSource struct { + PURLs []string + RegClient *registries.Client +} + +func (s *PURLSource) Enumerate(ctx context.Context, fn func(PackageVersion) error) error { + client := s.RegClient + if client == nil { + client = registries.DefaultClient() + } + + for _, purlStr := range s.PURLs { + p, err := purl.Parse(purlStr) + if err != nil { + return fmt.Errorf("parsing PURL %q: %w", purlStr, err) + } + + ecosystem := purl.PURLTypeToEcosystem(p.Type) + name := p.Name + if p.Namespace != "" { + name = p.Namespace + "/" + p.Name + } + + if p.Version != "" { + if err := fn(PackageVersion{Ecosystem: ecosystem, Name: name, Version: p.Version}); err != nil { + return err + } + continue + } + + // Unversioned: enumerate all versions + versions, err := s.fetchVersions(ctx, client, ecosystem, name) + if err != nil { + return fmt.Errorf("fetching versions for %s/%s: %w", ecosystem, name, err) + } + for _, v := range versions { + if err := fn(PackageVersion{Ecosystem: ecosystem, Name: name, Version: v}); err != nil { + return err + } + } + } + return nil +} + +func (s *PURLSource) fetchVersions(ctx context.Context, client *registries.Client, ecosystem, name string) ([]string, error) { + reg, err := registries.New(purl.EcosystemToPURLType(ecosystem), "", client) + if err != nil { + return nil, err + } + versions, err := reg.FetchVersions(ctx, name) + if err != nil { + return nil, err + } + result := make([]string, len(versions)) + for i, v := range versions { + result[i] = v.Number + } + return result, nil +} + +// SBOMSource extracts package versions from a CycloneDX or SPDX SBOM file. 
+type SBOMSource struct { + Path string + RegClient *registries.Client +} + +func (s *SBOMSource) Enumerate(ctx context.Context, fn func(PackageVersion) error) error { + purls, err := s.extractPURLs() + if err != nil { + return fmt.Errorf("reading SBOM %s: %w", s.Path, err) + } + + inner := &PURLSource{PURLs: purls, RegClient: s.RegClient} + return inner.Enumerate(ctx, fn) +} + +func (s *SBOMSource) extractPURLs() ([]string, error) { + data, err := os.ReadFile(s.Path) + if err != nil { + return nil, err + } + + // Try CycloneDX first + if purls, err := extractCycloneDXPURLs(data); err == nil && len(purls) > 0 { + return purls, nil + } + + // Try SPDX JSON + if purls, err := extractSPDXJSONPURLs(data); err == nil && len(purls) > 0 { + return purls, nil + } + + // Try SPDX tag-value + if purls, err := extractSPDXTVPURLs(data); err == nil && len(purls) > 0 { + return purls, nil + } + + return nil, fmt.Errorf("could not parse SBOM as CycloneDX or SPDX") +} + +func extractCycloneDXPURLs(data []byte) ([]string, error) { + bom := new(cdx.BOM) + if err := json.Unmarshal(data, bom); err != nil { + // Try XML + decoder := cdx.NewBOMDecoder(bytes.NewReader(data), cdx.BOMFileFormatXML) + bom = new(cdx.BOM) + if err := decoder.Decode(bom); err != nil { + return nil, err + } + } + + if bom.Components == nil { + return nil, nil + } + + var purls []string + for _, c := range *bom.Components { + if c.PackageURL != "" { + purls = append(purls, c.PackageURL) + } + } + return purls, nil +} + +func extractSPDXJSONPURLs(data []byte) ([]string, error) { + doc, err := spdxjson.Read(bytes.NewReader(data)) + if err != nil { + return nil, err + } + return extractSPDXDocPURLs(doc), nil +} + +func extractSPDXTVPURLs(data []byte) ([]string, error) { + doc, err := spdxtv.Read(bytes.NewReader(data)) + if err != nil { + return nil, err + } + return extractSPDXDocPURLs(doc), nil +} + +func extractSPDXDocPURLs(doc *spdx.Document) []string { + if doc == nil { + return nil + } + var purls []string + 
for _, pkg := range doc.Packages { + for _, ref := range pkg.PackageExternalReferences { + if ref.RefType == "purl" { + purls = append(purls, ref.Locator) + } + } + } + return purls +} diff --git a/internal/mirror/source_test.go b/internal/mirror/source_test.go new file mode 100644 index 0000000..ce53acf --- /dev/null +++ b/internal/mirror/source_test.go @@ -0,0 +1,243 @@ +package mirror + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "testing" +) + +func TestPURLSourceVersioned(t *testing.T) { + source := &PURLSource{ + PURLs: []string{ + "pkg:npm/lodash@4.17.21", + "pkg:cargo/serde@1.0.0", + "pkg:pypi/requests@2.31.0", + }, + } + + var items []PackageVersion + err := source.Enumerate(context.Background(), func(pv PackageVersion) error { + items = append(items, pv) + return nil + }) + if err != nil { + t.Fatalf("Enumerate() error = %v", err) + } + + if len(items) != 3 { + t.Fatalf("got %d items, want 3", len(items)) + } + + expected := []PackageVersion{ + {Ecosystem: "npm", Name: "lodash", Version: "4.17.21"}, + {Ecosystem: "cargo", Name: "serde", Version: "1.0.0"}, + {Ecosystem: "pypi", Name: "requests", Version: "2.31.0"}, + } + + for i, want := range expected { + got := items[i] + if got.Ecosystem != want.Ecosystem || got.Name != want.Name || got.Version != want.Version { + t.Errorf("items[%d] = %v, want %v", i, got, want) + } + } +} + +func TestPURLSourceScopedPackage(t *testing.T) { + source := &PURLSource{ + PURLs: []string{"pkg:npm/%40babel/core@7.23.0"}, + } + + var items []PackageVersion + err := source.Enumerate(context.Background(), func(pv PackageVersion) error { + items = append(items, pv) + return nil + }) + if err != nil { + t.Fatalf("Enumerate() error = %v", err) + } + + if len(items) != 1 { + t.Fatalf("got %d items, want 1", len(items)) + } + + if items[0].Name != "@babel/core" { + t.Errorf("name = %q, want %q", items[0].Name, "@babel/core") + } + if items[0].Version != "7.23.0" { + t.Errorf("version = %q, want %q", 
items[0].Version, "7.23.0") + } +} + +func TestPURLSourceInvalid(t *testing.T) { + source := &PURLSource{ + PURLs: []string{"not-a-purl"}, + } + + err := source.Enumerate(context.Background(), func(pv PackageVersion) error { + return nil + }) + if err == nil { + t.Fatal("expected error for invalid PURL") + } +} + +func TestPURLSourceCallbackError(t *testing.T) { + source := &PURLSource{ + PURLs: []string{"pkg:npm/lodash@4.17.21"}, + } + + wantErr := context.Canceled + err := source.Enumerate(context.Background(), func(pv PackageVersion) error { + return wantErr + }) + if err != wantErr { + t.Fatalf("got error %v, want %v", err, wantErr) + } +} + +func TestPackageVersionString(t *testing.T) { + pv := PackageVersion{Ecosystem: "npm", Name: "lodash", Version: "4.17.21"} + got := pv.String() + want := "pkg:npm/lodash@4.17.21" + if got != want { + t.Errorf("String() = %q, want %q", got, want) + } +} + +func TestSBOMSourceCycloneDXJSON(t *testing.T) { + bom := map[string]any{ + "bomFormat": "CycloneDX", + "specVersion": "1.4", + "components": []map[string]any{ + {"type": "library", "name": "lodash", "version": "4.17.21", "purl": "pkg:npm/lodash@4.17.21"}, + {"type": "library", "name": "serde", "version": "1.0.0", "purl": "pkg:cargo/serde@1.0.0"}, + }, + } + + path := writeTempJSON(t, bom) + source := &SBOMSource{Path: path} + + var items []PackageVersion + err := source.Enumerate(context.Background(), func(pv PackageVersion) error { + items = append(items, pv) + return nil + }) + if err != nil { + t.Fatalf("Enumerate() error = %v", err) + } + + if len(items) != 2 { + t.Fatalf("got %d items, want 2", len(items)) + } + + if items[0].Ecosystem != "npm" || items[0].Name != "lodash" || items[0].Version != "4.17.21" { + t.Errorf("items[0] = %v", items[0]) + } + if items[1].Ecosystem != "cargo" || items[1].Name != "serde" || items[1].Version != "1.0.0" { + t.Errorf("items[1] = %v", items[1]) + } +} + +func TestSBOMSourceSPDXJSON(t *testing.T) { + doc := map[string]any{ + 
"spdxVersion": "SPDX-2.3", + "dataLicense": "CC0-1.0", + "SPDXID": "SPDXRef-DOCUMENT", + "name": "test", + "documentNamespace": "https://example.com/test", + "packages": []map[string]any{ + { + "SPDXID": "SPDXRef-Package", + "name": "lodash", + "version": "4.17.21", + "downloadLocation": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz", + "externalRefs": []map[string]any{ + { + "referenceCategory": "PACKAGE-MANAGER", + "referenceType": "purl", + "referenceLocator": "pkg:npm/lodash@4.17.21", + }, + }, + }, + }, + } + + path := writeTempJSON(t, doc) + source := &SBOMSource{Path: path} + + var items []PackageVersion + err := source.Enumerate(context.Background(), func(pv PackageVersion) error { + items = append(items, pv) + return nil + }) + if err != nil { + t.Fatalf("Enumerate() error = %v", err) + } + + if len(items) != 1 { + t.Fatalf("got %d items, want 1", len(items)) + } + + if items[0].Name != "lodash" || items[0].Version != "4.17.21" { + t.Errorf("items[0] = %v", items[0]) + } +} + +func TestSBOMSourceNonexistentFile(t *testing.T) { + source := &SBOMSource{Path: "/nonexistent/sbom.json"} + + err := source.Enumerate(context.Background(), func(pv PackageVersion) error { + return nil + }) + if err == nil { + t.Fatal("expected error for nonexistent file") + } +} + +func TestSBOMSourceInvalidFormat(t *testing.T) { + path := filepath.Join(t.TempDir(), "invalid.txt") + if err := os.WriteFile(path, []byte("this is not an SBOM"), 0644); err != nil { + t.Fatal(err) + } + + source := &SBOMSource{Path: path} + err := source.Enumerate(context.Background(), func(pv PackageVersion) error { + return nil + }) + if err == nil { + t.Fatal("expected error for invalid SBOM") + } +} + +func TestSBOMSourceEmptyCycloneDX(t *testing.T) { + bom := map[string]any{ + "bomFormat": "CycloneDX", + "specVersion": "1.4", + } + path := writeTempJSON(t, bom) + + // This should fall through to SPDX parsing, which will also fail, + // resulting in an error about not being able to parse + 
source := &SBOMSource{Path: path} + err := source.Enumerate(context.Background(), func(pv PackageVersion) error { + return nil + }) + if err == nil { + t.Fatal("expected error for empty SBOM") + } +} + +func writeTempJSON(t *testing.T, v any) string { + t.Helper() + data, err := json.Marshal(v) + if err != nil { + t.Fatal(err) + } + path := filepath.Join(t.TempDir(), "sbom.json") + if err := os.WriteFile(path, data, 0644); err != nil { + t.Fatal(err) + } + return path +} diff --git a/internal/server/mirror_api.go b/internal/server/mirror_api.go new file mode 100644 index 0000000..6a6a6ca --- /dev/null +++ b/internal/server/mirror_api.go @@ -0,0 +1,70 @@ +package server + +import ( + "encoding/json" + "net/http" + + "github.com/git-pkgs/proxy/internal/mirror" + "github.com/go-chi/chi/v5" +) + +// MirrorAPIHandler handles mirror API requests. +type MirrorAPIHandler struct { + jobs *mirror.JobStore +} + +// NewMirrorAPIHandler creates a new mirror API handler. +func NewMirrorAPIHandler(jobs *mirror.JobStore) *MirrorAPIHandler { + return &MirrorAPIHandler{jobs: jobs} +} + +// HandleCreate starts a new mirror job. +func (h *MirrorAPIHandler) HandleCreate(w http.ResponseWriter, r *http.Request) { + var req mirror.JobRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + writeJSON(w, map[string]string{"error": "invalid request body"}) + return + } + + id, err := h.jobs.Create(req) + if err != nil { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + writeJSON(w, map[string]string{"error": err.Error()}) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + writeJSON(w, map[string]string{"id": id}) +} + +// HandleGet returns the status of a mirror job. 
+func (h *MirrorAPIHandler) HandleGet(w http.ResponseWriter, r *http.Request) {
+	id := chi.URLParam(r, "id")
+	job := h.jobs.Get(id)
+	if job == nil {
+		w.Header().Set("Content-Type", "application/json")
+		w.WriteHeader(http.StatusNotFound)
+		writeJSON(w, map[string]string{"error": "job not found"})
+		return
+	}
+
+	w.Header().Set("Content-Type", "application/json")
+	writeJSON(w, job)
+}
+
+// HandleCancel cancels a running mirror job.
+func (h *MirrorAPIHandler) HandleCancel(w http.ResponseWriter, r *http.Request) {
+	id := chi.URLParam(r, "id")
+	if h.jobs.Cancel(id) {
+		w.Header().Set("Content-Type", "application/json")
+		writeJSON(w, map[string]string{"status": "canceled"})
+	} else {
+		w.Header().Set("Content-Type", "application/json")
+		w.WriteHeader(http.StatusNotFound)
+		writeJSON(w, map[string]string{"error": "job not found or not running"})
+	}
+}
diff --git a/internal/server/mirror_api_test.go b/internal/server/mirror_api_test.go
new file mode 100644
index 0000000..0e84da1
--- /dev/null
+++ b/internal/server/mirror_api_test.go
@@ -0,0 +1,180 @@
+package server
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"log/slog"
+	"net/http"
+	"net/http/httptest"
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/git-pkgs/proxy/internal/database"
+	"github.com/git-pkgs/proxy/internal/handler"
+	"github.com/git-pkgs/proxy/internal/mirror"
+	"github.com/git-pkgs/proxy/internal/storage"
+	"github.com/git-pkgs/registries/fetch"
+	"github.com/go-chi/chi/v5"
+)
+
+// setupMirrorAPI builds a MirrorAPIHandler backed by a temporary database and
+// file storage bucket. The context handed to the job store is canceled during
+// test cleanup, which signals any mirror job still running to stop.
+func setupMirrorAPI(t *testing.T) *MirrorAPIHandler {
+	t.Helper()
+
+	db, err := database.Create(filepath.Join(t.TempDir(), "test.db"))
+	if err != nil {
+		t.Fatalf("creating database: %v", err)
+	}
+	// Register the close before migrating so a failed migration cannot leak the handle.
+	t.Cleanup(func() { _ = db.Close() })
+	if err := db.MigrateSchema(); err != nil {
+		t.Fatalf("migrating schema: %v", err)
+	}
+
+	store, err := storage.OpenBucket(context.Background(), "file://"+t.TempDir())
+	if err != nil {
+		t.Fatalf("opening storage: %v", err)
+	}
+
+	logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelWarn}))
+	proxy := handler.NewProxy(db, store, fetch.NewFetcher(), fetch.NewResolver(), logger)
+
+	// Cleanups run last-in-first-out, so this cancel fires before the database
+	// cleanup registered above runs.
+	ctx, cancel := context.WithCancel(context.Background())
+	t.Cleanup(cancel)
+
+	js := mirror.NewJobStore(ctx, mirror.New(proxy, db, store, logger, 1))
+	return NewMirrorAPIHandler(js)
+}
+
+// newCreateRequest returns a POST /api/mirror request carrying jr as its JSON body.
+func newCreateRequest(t *testing.T, jr mirror.JobRequest) *http.Request {
+	t.Helper()
+
+	body, err := json.Marshal(jr)
+	if err != nil {
+		t.Fatalf("encoding request: %v", err)
+	}
+	return httptest.NewRequest(http.MethodPost, "/api/mirror", bytes.NewReader(body))
+}
+
+func TestMirrorAPICreateJob(t *testing.T) {
+	h := setupMirrorAPI(t)
+
+	w := httptest.NewRecorder()
+	h.HandleCreate(w, newCreateRequest(t, mirror.JobRequest{
+		PURLs: []string{"pkg:npm/lodash@4.17.21"},
+	}))
+
+	if w.Code != http.StatusAccepted {
+		t.Errorf("status = %d, want %d", w.Code, http.StatusAccepted)
+	}
+
+	var resp map[string]string
+	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+		t.Fatalf("decoding response: %v", err)
+	}
+	if resp["id"] == "" {
+		t.Error("expected non-empty job ID")
+	}
+}
+
+func TestMirrorAPICreateInvalidBody(t *testing.T) {
+	h := setupMirrorAPI(t)
+
+	req := httptest.NewRequest(http.MethodPost, "/api/mirror", bytes.NewReader([]byte("not json")))
+	w := httptest.NewRecorder()
+	h.HandleCreate(w, req)
+
+	if w.Code != http.StatusBadRequest {
+		t.Errorf("status = %d, want %d", w.Code, http.StatusBadRequest)
+	}
+}
+
+func TestMirrorAPICreateEmptyRequest(t *testing.T) {
+	h := setupMirrorAPI(t)
+
+	w := httptest.NewRecorder()
+	h.HandleCreate(w, newCreateRequest(t, mirror.JobRequest{}))
+
+	if w.Code != http.StatusBadRequest {
+		t.Errorf("status = %d, want %d", w.Code, http.StatusBadRequest)
+	}
+}
+
+func TestMirrorAPIGetNotFound(t *testing.T) {
+	h := setupMirrorAPI(t)
+
+	r := chi.NewRouter()
+	r.Get("/api/mirror/{id}", h.HandleGet)
+
+	req := httptest.NewRequest(http.MethodGet, "/api/mirror/nonexistent", nil)
+	w := httptest.NewRecorder()
+	r.ServeHTTP(w, req)
+
+	if w.Code != http.StatusNotFound {
+		t.Errorf("status = %d, want %d", w.Code, http.StatusNotFound)
+	}
+}
+
+func TestMirrorAPICancelNotFound(t *testing.T) {
+	h := setupMirrorAPI(t)
+
+	r := chi.NewRouter()
+	r.Delete("/api/mirror/{id}", h.HandleCancel)
+
+	req := httptest.NewRequest(http.MethodDelete, "/api/mirror/nonexistent", nil)
+	w := httptest.NewRecorder()
+	r.ServeHTTP(w, req)
+
+	if w.Code != http.StatusNotFound {
+		t.Errorf("status = %d, want %d", w.Code, http.StatusNotFound)
+	}
+}
+
+func TestMirrorAPICreateAndGetJob(t *testing.T) {
+	h := setupMirrorAPI(t)
+
+	// Create a job. Later steps need its ID, so any failure here is fatal.
+	createW := httptest.NewRecorder()
+	h.HandleCreate(createW, newCreateRequest(t, mirror.JobRequest{
+		PURLs: []string{"pkg:npm/lodash@4.17.21"},
+	}))
+	if createW.Code != http.StatusAccepted {
+		t.Fatalf("create status = %d, want %d", createW.Code, http.StatusAccepted)
+	}
+
+	var createResp map[string]string
+	if err := json.NewDecoder(createW.Body).Decode(&createResp); err != nil {
+		t.Fatalf("decoding create response: %v", err)
+	}
+	jobID := createResp["id"]
+	if jobID == "" {
+		t.Fatal("expected non-empty job ID")
+	}
+
+	// Get the job
+	r := chi.NewRouter()
+	r.Get("/api/mirror/{id}", h.HandleGet)
+
+	getReq := httptest.NewRequest(http.MethodGet, "/api/mirror/"+jobID, nil)
+	getW := httptest.NewRecorder()
+	r.ServeHTTP(getW, getReq)
+
+	if getW.Code != http.StatusOK {
+		t.Fatalf("status = %d, want %d", getW.Code, http.StatusOK)
+	}
+
+	var job mirror.Job
+	if err := json.NewDecoder(getW.Body).Decode(&job); err != nil {
+		t.Fatalf("decoding job: %v", err)
+	}
+	if job.ID != jobID {
+		t.Errorf("job ID = %q, want %q", job.ID, jobID)
+	}
+}
diff --git a/internal/server/server.go b/internal/server/server.go
index 429c988..5d544a2 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -53,6 +53,7 @@ import (
 	"github.com/git-pkgs/proxy/internal/enrichment"
 	"github.com/git-pkgs/proxy/internal/handler"
 	"github.com/git-pkgs/proxy/internal/metrics"
+	"github.com/git-pkgs/proxy/internal/mirror"
 	"github.com/git-pkgs/proxy/internal/storage"
 	"github.com/git-pkgs/purl"
 	"github.com/git-pkgs/registries/fetch"
@@ -77,6 +78,7 @@ type Server struct {
 	logger    *slog.Logger
 	http      *http.Server
 	templates *Templates
+	cancel    context.CancelFunc
 }
 
 // New creates a new Server with the given configuration.
@@ -144,6 +146,8 @@ func (s *Server) Start() error {
 	}
 	proxy := handler.NewProxy(s.db, s.storage, fetcher, resolver, s.logger)
 	proxy.Cooldown = cd
+	proxy.CacheMetadata = s.cfg.CacheMetadata
+	proxy.MetadataTTL = s.cfg.ParseMetadataTTL()
 
 	// Create router with Chi
 	r := chi.NewRouter()
@@ -228,6 +232,21 @@ func (s *Server) Start() error {
 	r.Get("/api/browse/{ecosystem}/*", s.handleBrowsePath)
 	r.Get("/api/compare/{ecosystem}/*", s.handleComparePath)
 
+	// Start background context (used by mirror jobs and cleanup)
+	bgCtx, bgCancel := context.WithCancel(context.Background())
+	s.cancel = bgCancel
+
+	// Mirror API endpoints (opt-in via mirror_api config or PROXY_MIRROR_API env)
+	if s.cfg.MirrorAPI {
+		mirrorSvc := mirror.New(proxy, s.db, s.storage, s.logger, 4) //nolint:mnd // default concurrency
+		jobStore := mirror.NewJobStore(bgCtx, mirrorSvc)
+		mirrorAPI := NewMirrorAPIHandler(jobStore)
+		r.Post("/api/mirror", mirrorAPI.HandleCreate)
+		r.Get("/api/mirror/{id}", mirrorAPI.HandleGet)
+		r.Delete("/api/mirror/{id}", mirrorAPI.HandleCancel)
+		go jobStore.StartCleanup(bgCtx)
+	}
+
 	s.http = &http.Server{
 		Addr:    s.cfg.Listen,
 		Handler: r,
@@ -241,8 +260,6 @@ func (s *Server) Start() error {
 		"base_url", s.cfg.BaseURL,
 		"storage", s.storage.URL(),
 		"database", s.cfg.Database.Path)
-
-	// Start background goroutine to update cache stats metrics
 	go s.updateCacheStatsMetrics()
 
 	return s.http.ListenAndServe()
@@ -274,6 +291,10 @@ func (s *Server) updateCacheStats() {
 func (s *Server) Shutdown(ctx context.Context) error {
 	s.logger.Info("shutting down server")
 
+	if s.cancel != nil {
+		s.cancel()
+	}
+
 	var errs []error
 
 	if s.http != nil {
diff --git a/internal/server/server_test.go b/internal/server/server_test.go
index 4c49035..be88bf6 100644
--- a/internal/server/server_test.go
+++ b/internal/server/server_test.go
@@ -320,10 +320,17 @@ func TestGoList(t *testing.T) {
 	w := httptest.NewRecorder()
 	ts.handler.ServeHTTP(w, req)
 
-	// The handler is mounted if we get a Go proxy error (not a generic 404)
-	body := w.Body.String()
-	if w.Code == http.StatusNotFound && !strings.Contains(body, "example.com") {
-		t.Errorf("go handler should be mounted, got status %d, body: %s", w.Code, body)
+	// The handler is mounted if we get a response from the proxy (404 from upstream
+	// or 502 from connection failure), not a chi router 404.
+	// With metadata caching, upstream 404 is cleanly returned as our own 404.
+	// NOTE(review): net/http's default 404 body ("404 page not found") also
+	// contains "not found"; if the router's fallback uses it, this check cannot
+	// tell the two apart. Confirm the handler's 404 body and tighten the match.
+	if w.Code == http.StatusNotFound {
+		body := w.Body.String()
+		if !strings.Contains(body, "not found") {
+			t.Errorf("go handler should be mounted, got status %d, body: %s", w.Code, body)
+		}
 	}
 }
 