diff --git a/AUTHORS b/AUTHORS index 2c067dd23..11d04e331 100644 --- a/AUTHORS +++ b/AUTHORS @@ -74,3 +74,4 @@ List of contributors, in chronological order: * JupiterRider (https://github.com/JupiterRider) * Agustin Henze (https://github.com/agustinhenze) * Tobias Assarsson (https://github.com/daedaluz) +* Ato Araki (https://github.com/atotto) diff --git a/go.mod b/go.mod index 53c5e78cb..6ef363050 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/aptly-dev/aptly -go 1.24 +go 1.24.0 require ( github.com/AlekSi/pointer v1.1.0 @@ -41,6 +41,7 @@ require ( ) require ( + cloud.google.com/go/compute/metadata v0.3.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect github.com/KyleBanks/depth v1.2.1 // indirect github.com/PuerkitoBio/purell v1.1.1 // indirect @@ -128,5 +129,6 @@ require ( github.com/swaggo/gin-swagger v1.6.0 github.com/swaggo/swag v1.16.3 go.etcd.io/etcd/client/v3 v3.5.15 + golang.org/x/oauth2 v0.33.0 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index 502f4b216..e53d552af 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= +cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= github.com/AlekSi/pointer v1.1.0 h1:SSDMPcXD9jSl8FPy9cRzoRaMJtm9g9ggGTxecRUbQoI= github.com/AlekSi/pointer v1.1.0/go.mod h1:y7BvfRI3wXPWKXEBhU71nbnIEEZX0QTSB2Bj48UJIZE= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0 h1:nyQWyZvwGTvunIMxi1Y9uXkcyr+I7TeNrr/foo4Kpk8= @@ -348,6 +350,8 @@ golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/oauth2 v0.33.0 h1:4Q+qn+E5z8gPRJfmRy7C2gGG3T4jIprK6aSYgTXGRpo= +golang.org/x/oauth2 v0.33.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= diff --git a/http/download.go b/http/download.go index 887f9b3bf..02b9275e4 100644 --- a/http/download.go +++ b/http/download.go @@ -44,6 +44,7 @@ func NewDownloader(downLimit int64, maxTries int, progress aptly.Progress) aptly transport.DisableCompression = true initTransport(&transport) transport.RegisterProtocol("ftp", &protocol.FTPRoundTripper{}) + transport.RegisterProtocol("ar+https", NewGCPRoundTripper(&transport)) downloader := &downloaderImpl{ progress: progress, diff --git a/http/gcp_auth.go b/http/gcp_auth.go new file mode 100644 index 000000000..d547da532 --- /dev/null +++ b/http/gcp_auth.go @@ -0,0 +1,64 @@ +package http + +import ( + "context" + "fmt" + "net/http" + "strings" + "sync" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" +) + +// gcpRoundTripper wraps http.RoundTripper to add Google Cloud authentication. +// It delays GCP authentication initialization until the first actual request is made. +// This avoids unnecessary credential loading when ar+https protocol is not actually used. +// +// It uses Application Default Credentials (ADC) which checks: +// 1. GOOGLE_APPLICATION_CREDENTIALS environment variable +// 2. gcloud auth application-default credentials +// 3. GCE/GKE metadata server +// See https://cloud.google.com/docs/authentication/application-default-credentials for usage details. +type gcpRoundTripper struct { + base http.RoundTripper + initOnce sync.Once + tokenSrc oauth2.TokenSource + initErr error +} + +func (t *gcpRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + // Lazy initialization: only initialize GCP credentials on first request + t.initOnce.Do(func() { + creds, err := google.FindDefaultCredentials(context.Background(), + "https://www.googleapis.com/auth/cloud-platform") + if err != nil { + t.initErr = fmt.Errorf("failed to find default credentials: %w", err) + return + } + t.tokenSrc = creds.TokenSource + }) + + reqCopy := req.Clone(req.Context()) + reqCopy.URL.Scheme = strings.TrimPrefix(reqCopy.URL.Scheme, "ar+") + + // Fall back to base transport if GCP auth initialization failed + if t.initErr != nil { + return t.base.RoundTrip(reqCopy) + } + + token, err := t.tokenSrc.Token() + if err != nil { + return nil, fmt.Errorf("failed to get OAuth2 token: %w", err) + } + token.SetAuthHeader(reqCopy) + + return t.base.RoundTrip(reqCopy) +} + +// NewGCPRoundTripper creates a new RoundTripper that handles GCP authentication for ar+https protocol. +func NewGCPRoundTripper(base http.RoundTripper) http.RoundTripper { + return &gcpRoundTripper{ + base: base, + } +} diff --git a/http/gcp_auth_test.go b/http/gcp_auth_test.go new file mode 100644 index 000000000..a706bdfe5 --- /dev/null +++ b/http/gcp_auth_test.go @@ -0,0 +1,110 @@ +package http + +import ( + "net/http" + "net/http/httptest" + "os" + "testing" + + "golang.org/x/oauth2" +) + +func TestGCPAuthTransport_RoundTrip(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth == "" { + t.Error("Expected Authorization header, got none") + } + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + transport := NewGCPRoundTripper(http.DefaultTransport) + + if os.Getenv("GOOGLE_APPLICATION_CREDENTIALS") == "" { + t.Skip("Skipping test: GOOGLE_APPLICATION_CREDENTIALS not set") + } + + client := &http.Client{Transport: transport} + resp, err := client.Get(ts.URL) + if err != nil { + t.Fatalf("Failed to make request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } +} + +func TestGCPAuthTransport_RoundTrip_with_dummy_tokenSource(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth != "Bearer dummy-token" { + t.Errorf("Expected Authorization header 'Bearer dummy-token', got '%s'", auth) + } + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + // Use a dummy token source for testing + transport := &gcpRoundTripper{ + base: http.DefaultTransport, + tokenSrc: oauth2.StaticTokenSource(&oauth2.Token{ + AccessToken: "dummy-token", + }), + } + transport.initOnce.Do(func() {}) // Mark as initialized for testing + + client := &http.Client{Transport: transport} + resp, err := client.Get(ts.URL) + if err != nil { + t.Fatalf("Failed to make request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } +} + +func TestGCPAuthTransport_RoundTrip_with_InvalidCredentials(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer ts.Close() + + // Create a temporary invalid credentials file + tmpFile, err := os.CreateTemp("", "invalid_credentials.json") + if err != nil { + t.Fatalf("Failed to create temp file: %s", err) + } + defer os.Remove(tmpFile.Name()) + if _, err := tmpFile.WriteString(`{"invalid": "data"}`); err != nil { + t.Fatalf("Failed to write to temp file: %s", err) + } + tmpFile.Close() + + defaultEnv := os.Getenv("GOOGLE_APPLICATION_CREDENTIALS") + os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", tmpFile.Name()) + defer os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", defaultEnv) + + transport := &gcpRoundTripper{ + base: http.DefaultTransport, + } + + client := &http.Client{Transport: transport} + resp, err := client.Get(ts.URL) + if err != nil { + t.Fatalf("Failed to make request: %s", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusForbidden { + t.Errorf("Expected status 403, got %d", resp.StatusCode) + } + + if transport.initErr == nil { + t.Error("Expected init error due to invalid credentials, got none") + } +}