diff --git a/cmd/up/client/client.go b/cmd/up/client/client.go index 28935dc..6608e24 100644 --- a/cmd/up/client/client.go +++ b/cmd/up/client/client.go @@ -77,6 +77,12 @@ func validateDNSIP(s, field string) error { return nil } +func applyTLSClientCert(client *api.Client, certPath string) { + if client != nil && client.HTTPClient != nil { + client.HTTPClient.TLSClientCert = certPath + } +} + func ClientUpCmd() *cobra.Command { opts := ClientUpCmdOpts{} @@ -141,6 +147,7 @@ func clientUpMain(cmd *cobra.Command, opts *ClientUpCmdOpts, extraArgs []string) apiClient := api.FromContext(cmd.Context()) accountStore := config.AccountStoreFromContext(cmd.Context()) cfg := config.ConfigFromContext(cmd.Context()) + applyTLSClientCert(apiClient, opts.TlsClientCert) if runtime.GOOS == "windows" { err := errors.New("this command is currently unsupported on Windows") @@ -199,6 +206,7 @@ func clientUpMain(cmd *cobra.Command, opts *ClientUpCmdOpts, extraArgs []string) return err } } + applyTLSClientCert(healthClient, opts.TlsClientCert) healthOk, healthErr := healthClient.CheckHealth() if healthErr != nil || !healthOk { @@ -624,11 +632,11 @@ func clientUpMain(cmd *cobra.Command, opts *ClientUpCmdOpts, extraArgs []string) if enableAPI { _ = olm.StartApi() } - + // Run StartTunnel in a goroutine so org switching can restart it // without causing the CLI process to exit go olm.StartTunnel(tunnelConfig) - + // Block on context to keep process alive <-ctx.Done() logger.Info("Received shutdown signal, stopping tunnel") diff --git a/internal/api/client.go b/internal/api/client.go index 6ee4f8f..be74a80 100644 --- a/internal/api/client.go +++ b/internal/api/client.go @@ -2,18 +2,22 @@ package api import ( "bytes" + "crypto/tls" + "crypto/x509" "encoding/base64" "encoding/json" "fmt" "io" "net/http" "net/url" + "os" "strconv" "strings" "time" "github.com/fosrl/cli/internal/version" yaml "go.yaml.in/yaml/v3" + "software.sslmate.com/src/go-pkcs12" ) // ClientConfig holds configuration for creating a new client @@ -127,8 +131,9 @@ func (c *Client) request(method, endpoint string, payload interface{}, result in } // Create HTTP client and execute request - httpClient := &http.Client{ - Timeout: c.HTTPClient.Timeout, + httpClient, err := createHTTPClient(c.HTTPClient.Timeout, c.HTTPClient.TLSClientCert) + if err != nil { + return fmt.Errorf("failed to create HTTP client: %w", err) } resp, err := httpClient.Do(req) @@ -548,11 +553,56 @@ func setJSONResponseHeaders(req *http.Request, userAgent string) { req.Header.Set("User-Agent", userAgent) } -// createHTTPClient creates an HTTP client with the specified timeout -func createHTTPClient(timeout time.Duration) *http.Client { +// createHTTPClient creates an HTTP client with the specified timeout and optional mTLS client cert. +func createHTTPClient(timeout time.Duration, tlsClientCert string) (*http.Client, error) { + transport := http.DefaultTransport.(*http.Transport).Clone() + if tlsClientCert != "" { + tlsConfig, err := loadTLSClientConfig(tlsClientCert) + if err != nil { + return nil, err + } + transport.TLSClientConfig = tlsConfig + } + return &http.Client{ - Timeout: timeout, + Timeout: timeout, + Transport: transport, + }, nil +} + +func loadTLSClientConfig(path string) (*tls.Config, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read TLS client certificate %q: %w", path, err) } + + if cert, err := tls.X509KeyPair(data, data); err == nil { + if len(cert.Certificate) > 0 && cert.Leaf == nil { + if leaf, err := x509.ParseCertificate(cert.Certificate[0]); err == nil { + cert.Leaf = leaf + } + } + return &tls.Config{Certificates: []tls.Certificate{cert}}, nil + } + + privateKey, certificate, caCerts, err := pkcs12.DecodeChain(data, "") + if err != nil { + return nil, fmt.Errorf("load TLS client certificate %q: %w", path, err) + } + + clientCert := tls.Certificate{ + Certificate: make([][]byte, 0, 1+len(caCerts)), + PrivateKey: privateKey, + Leaf: certificate, + } + if certificate != nil { + clientCert.Certificate = append(clientCert.Certificate, certificate.Raw) + } + for _, caCert := range caCerts { + clientCert.Certificate = append(clientCert.Certificate, caCert.Raw) + } + + return &tls.Config{Certificates: []tls.Certificate{clientCert}}, nil } // parseAPIResponseBody parses the response body into an APIResponse struct @@ -647,7 +697,10 @@ func LoginWithCookie(client *Client, req LoginRequest) (*LoginResponse, string, client.Session.ApplyToRequest(httpReq) // Execute request - httpClient := createHTTPClient(client.HTTPClient.Timeout) + httpClient, err := createHTTPClient(client.HTTPClient.Timeout, client.HTTPClient.TLSClientCert) + if err != nil { + return nil, "", fmt.Errorf("failed to create HTTP client: %w", err) + } resp, err := httpClient.Do(httpReq) if err != nil { return nil, "", fmt.Errorf("request failed: %w", err) @@ -734,7 +787,10 @@ func StartDeviceWebAuth(client *Client, req DeviceWebAuthStartRequest) (*DeviceW client.Session.ApplyToRequest(httpReq) // Execute request - httpClient := createHTTPClient(client.HTTPClient.Timeout) + httpClient, err := createHTTPClient(client.HTTPClient.Timeout, client.HTTPClient.TLSClientCert) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP client: %w", err) + } resp, err := httpClient.Do(httpReq) if err != nil { return nil, fmt.Errorf("request failed: %w", err) @@ -792,7 +848,10 @@ func PollDeviceWebAuth(client *Client, code string) (*DeviceWebAuthPollResponse, client.Session.ApplyToRequest(httpReq) // Execute request - httpClient := createHTTPClient(client.HTTPClient.Timeout) + httpClient, err := createHTTPClient(client.HTTPClient.Timeout, client.HTTPClient.TLSClientCert) + if err != nil { + return nil, "", fmt.Errorf("failed to create HTTP client: %w", err) + } resp, err := httpClient.Do(httpReq) if err != nil { return nil, "", fmt.Errorf("request failed: %w", err) @@ -870,4 +929,3 @@ func (c *Client) ApplyBlueprint(orgID string, name string, blueprint string) (*A } return &response, nil } - diff --git a/internal/api/client_tls_test.go b/internal/api/client_tls_test.go new file mode 100644 index 0000000..0351f7f --- /dev/null +++ b/internal/api/client_tls_test.go @@ -0,0 +1,117 @@ +package api + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net/http" + "os" + "path/filepath" + "testing" + "time" + + "software.sslmate.com/src/go-pkcs12" +) + +func TestCreateHTTPClientWithTLSClientCert(t *testing.T) { + pemPath, p12Path := writeTestClientCerts(t) + + t.Run("pem", func(t *testing.T) { + client, err := createHTTPClient(time.Second, pemPath) + if err != nil { + t.Fatalf("createHTTPClient() error = %v", err) + } + assertTLSClientCert(t, client, "pangolin-client") + }) + + t.Run("pkcs12", func(t *testing.T) { + client, err := createHTTPClient(time.Second, p12Path) + if err != nil { + t.Fatalf("createHTTPClient() error = %v", err) + } + assertTLSClientCert(t, client, "pangolin-client") + }) +} + +func assertTLSClientCert(t *testing.T, client *http.Client, wantCommonName string) { + t.Helper() + + transport, ok := client.Transport.(*http.Transport) + if !ok { + t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport) + } + if transport.TLSClientConfig == nil { + t.Fatalf("transport.TLSClientConfig = nil") + } + if got := len(transport.TLSClientConfig.Certificates); got != 1 { + t.Fatalf("len(Certificates) = %d, want 1", got) + } + leaf := transport.TLSClientConfig.Certificates[0].Leaf + if leaf == nil { + t.Fatalf("certificate Leaf = nil") + } + if got := leaf.Subject.CommonName; got != wantCommonName { + t.Fatalf("certificate CommonName = %q, want %q", got, wantCommonName) + } +} + +func writeTestClientCerts(t *testing.T) (string, string) { + t.Helper() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("GenerateKey() error = %v", err) + } + + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "pangolin-client", + }, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + if err != nil { + t.Fatalf("CreateCertificate() error = %v", err) + } + cert, err := x509.ParseCertificate(certDER) + if err != nil { + t.Fatalf("ParseCertificate() error = %v", err) + } + + dir := t.TempDir() + + pemPath := filepath.Join(dir, "client.pem") + pemFile, err := os.Create(pemPath) + if err != nil { + t.Fatalf("Create(%q) error = %v", pemPath, err) + } + if err := pem.Encode(pemFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil { + t.Fatalf("encode certificate PEM: %v", err) + } + if err := pem.Encode(pemFile, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}); err != nil { + t.Fatalf("encode key PEM: %v", err) + } + if err := pemFile.Close(); err != nil { + t.Fatalf("close PEM file: %v", err) + } + + p12Path := filepath.Join(dir, "client.p12") + p12Data, err := pkcs12.Encode(rand.Reader, key, cert, nil, "") + if err != nil { + t.Fatalf("pkcs12.Encode() error = %v", err) + } + if err := os.WriteFile(p12Path, p12Data, 0o600); err != nil { + t.Fatalf("WriteFile(%q) error = %v", p12Path, err) + } + + return pemPath, p12Path +} diff --git a/internal/api/types.go b/internal/api/types.go index 9319f60..ee00008 100644 --- a/internal/api/types.go +++ b/internal/api/types.go @@ -17,7 +17,8 @@ type Client struct { // HTTPClient wraps the standard http.Client with additional configuration type HTTPClient struct { - Timeout time.Duration + Timeout time.Duration + TLSClientCert string } // RequestOptions contains optional parameters for API requests