Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions cmd/up/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ func validateDNSIP(s, field string) error {
return nil
}

func applyTLSClientCert(client *api.Client, certPath string) {
if client != nil && client.HTTPClient != nil {
client.HTTPClient.TLSClientCert = certPath
}
}

func ClientUpCmd() *cobra.Command {
opts := ClientUpCmdOpts{}

Expand Down Expand Up @@ -141,6 +147,7 @@ func clientUpMain(cmd *cobra.Command, opts *ClientUpCmdOpts, extraArgs []string)
apiClient := api.FromContext(cmd.Context())
accountStore := config.AccountStoreFromContext(cmd.Context())
cfg := config.ConfigFromContext(cmd.Context())
applyTLSClientCert(apiClient, opts.TlsClientCert)

if runtime.GOOS == "windows" {
err := errors.New("this command is currently unsupported on Windows")
Expand Down Expand Up @@ -199,6 +206,7 @@ func clientUpMain(cmd *cobra.Command, opts *ClientUpCmdOpts, extraArgs []string)
return err
}
}
applyTLSClientCert(healthClient, opts.TlsClientCert)

healthOk, healthErr := healthClient.CheckHealth()
if healthErr != nil || !healthOk {
Expand Down Expand Up @@ -624,11 +632,11 @@ func clientUpMain(cmd *cobra.Command, opts *ClientUpCmdOpts, extraArgs []string)
if enableAPI {
_ = olm.StartApi()
}

// Run StartTunnel in a goroutine so org switching can restart it
// without causing the CLI process to exit
go olm.StartTunnel(tunnelConfig)

// Block on context to keep process alive
<-ctx.Done()
logger.Info("Received shutdown signal, stopping tunnel")
Expand Down
76 changes: 67 additions & 9 deletions internal/api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,22 @@ package api

import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"time"

"github.com/fosrl/cli/internal/version"
yaml "go.yaml.in/yaml/v3"
"software.sslmate.com/src/go-pkcs12"
)

// ClientConfig holds configuration for creating a new client
Expand Down Expand Up @@ -127,8 +131,9 @@ func (c *Client) request(method, endpoint string, payload interface{}, result in
}

// Create HTTP client and execute request
httpClient := &http.Client{
Timeout: c.HTTPClient.Timeout,
httpClient, err := createHTTPClient(c.HTTPClient.Timeout, c.HTTPClient.TLSClientCert)
if err != nil {
return fmt.Errorf("failed to create HTTP client: %w", err)
}

resp, err := httpClient.Do(req)
Expand Down Expand Up @@ -548,11 +553,56 @@ func setJSONResponseHeaders(req *http.Request, userAgent string) {
req.Header.Set("User-Agent", userAgent)
}

// createHTTPClient creates an HTTP client with the specified timeout
func createHTTPClient(timeout time.Duration) *http.Client {
// createHTTPClient creates an HTTP client with the specified timeout and optional mTLS client cert.
func createHTTPClient(timeout time.Duration, tlsClientCert string) (*http.Client, error) {
transport := http.DefaultTransport.(*http.Transport).Clone()
if tlsClientCert != "" {
tlsConfig, err := loadTLSClientConfig(tlsClientCert)
if err != nil {
return nil, err
}
transport.TLSClientConfig = tlsConfig
}

return &http.Client{
Timeout: timeout,
Timeout: timeout,
Transport: transport,
}, nil
}

func loadTLSClientConfig(path string) (*tls.Config, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read TLS client certificate %q: %w", path, err)
}

if cert, err := tls.X509KeyPair(data, data); err == nil {
if len(cert.Certificate) > 0 && cert.Leaf == nil {
if leaf, err := x509.ParseCertificate(cert.Certificate[0]); err == nil {
cert.Leaf = leaf
}
}
return &tls.Config{Certificates: []tls.Certificate{cert}}, nil
}

privateKey, certificate, caCerts, err := pkcs12.DecodeChain(data, "")
if err != nil {
return nil, fmt.Errorf("load TLS client certificate %q: %w", path, err)
}

clientCert := tls.Certificate{
Certificate: make([][]byte, 0, 1+len(caCerts)),
PrivateKey: privateKey,
Leaf: certificate,
}
if certificate != nil {
clientCert.Certificate = append(clientCert.Certificate, certificate.Raw)
}
for _, caCert := range caCerts {
clientCert.Certificate = append(clientCert.Certificate, caCert.Raw)
}

return &tls.Config{Certificates: []tls.Certificate{clientCert}}, nil
}

// parseAPIResponseBody parses the response body into an APIResponse struct
Expand Down Expand Up @@ -647,7 +697,10 @@ func LoginWithCookie(client *Client, req LoginRequest) (*LoginResponse, string,
client.Session.ApplyToRequest(httpReq)

// Execute request
httpClient := createHTTPClient(client.HTTPClient.Timeout)
httpClient, err := createHTTPClient(client.HTTPClient.Timeout, client.HTTPClient.TLSClientCert)
if err != nil {
return nil, "", fmt.Errorf("failed to create HTTP client: %w", err)
}
resp, err := httpClient.Do(httpReq)
if err != nil {
return nil, "", fmt.Errorf("request failed: %w", err)
Expand Down Expand Up @@ -734,7 +787,10 @@ func StartDeviceWebAuth(client *Client, req DeviceWebAuthStartRequest) (*DeviceW
client.Session.ApplyToRequest(httpReq)

// Execute request
httpClient := createHTTPClient(client.HTTPClient.Timeout)
httpClient, err := createHTTPClient(client.HTTPClient.Timeout, client.HTTPClient.TLSClientCert)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client: %w", err)
}
resp, err := httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
Expand Down Expand Up @@ -792,7 +848,10 @@ func PollDeviceWebAuth(client *Client, code string) (*DeviceWebAuthPollResponse,
client.Session.ApplyToRequest(httpReq)

// Execute request
httpClient := createHTTPClient(client.HTTPClient.Timeout)
httpClient, err := createHTTPClient(client.HTTPClient.Timeout, client.HTTPClient.TLSClientCert)
if err != nil {
return nil, "", fmt.Errorf("failed to create HTTP client: %w", err)
}
resp, err := httpClient.Do(httpReq)
if err != nil {
return nil, "", fmt.Errorf("request failed: %w", err)
Expand Down Expand Up @@ -870,4 +929,3 @@ func (c *Client) ApplyBlueprint(orgID string, name string, blueprint string) (*A
}
return &response, nil
}

117 changes: 117 additions & 0 deletions internal/api/client_tls_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package api

import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net/http"
"os"
"path/filepath"
"testing"
"time"

"software.sslmate.com/src/go-pkcs12"
)

func TestCreateHTTPClientWithTLSClientCert(t *testing.T) {
pemPath, p12Path := writeTestClientCerts(t)

t.Run("pem", func(t *testing.T) {
client, err := createHTTPClient(time.Second, pemPath)
if err != nil {
t.Fatalf("createHTTPClient() error = %v", err)
}
assertTLSClientCert(t, client, "pangolin-client")
})

t.Run("pkcs12", func(t *testing.T) {
client, err := createHTTPClient(time.Second, p12Path)
if err != nil {
t.Fatalf("createHTTPClient() error = %v", err)
}
assertTLSClientCert(t, client, "pangolin-client")
})
}

func assertTLSClientCert(t *testing.T, client *http.Client, wantCommonName string) {
t.Helper()

transport, ok := client.Transport.(*http.Transport)
if !ok {
t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport)
}
if transport.TLSClientConfig == nil {
t.Fatalf("transport.TLSClientConfig = nil")
}
if got := len(transport.TLSClientConfig.Certificates); got != 1 {
t.Fatalf("len(Certificates) = %d, want 1", got)
}
leaf := transport.TLSClientConfig.Certificates[0].Leaf
if leaf == nil {
t.Fatalf("certificate Leaf = nil")
}
if got := leaf.Subject.CommonName; got != wantCommonName {
t.Fatalf("certificate CommonName = %q, want %q", got, wantCommonName)
}
}

func writeTestClientCerts(t *testing.T) (string, string) {
t.Helper()

key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("GenerateKey() error = %v", err)
}

template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: "pangolin-client",
},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
BasicConstraintsValid: true,
}

certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
if err != nil {
t.Fatalf("CreateCertificate() error = %v", err)
}
cert, err := x509.ParseCertificate(certDER)
if err != nil {
t.Fatalf("ParseCertificate() error = %v", err)
}

dir := t.TempDir()

pemPath := filepath.Join(dir, "client.pem")
pemFile, err := os.Create(pemPath)
if err != nil {
t.Fatalf("Create(%q) error = %v", pemPath, err)
}
if err := pem.Encode(pemFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil {
t.Fatalf("encode certificate PEM: %v", err)
}
if err := pem.Encode(pemFile, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}); err != nil {
t.Fatalf("encode key PEM: %v", err)
}
if err := pemFile.Close(); err != nil {
t.Fatalf("close PEM file: %v", err)
}

p12Path := filepath.Join(dir, "client.p12")
p12Data, err := pkcs12.Encode(rand.Reader, key, cert, nil, "")
if err != nil {
t.Fatalf("pkcs12.Encode() error = %v", err)
}
if err := os.WriteFile(p12Path, p12Data, 0o600); err != nil {
t.Fatalf("WriteFile(%q) error = %v", p12Path, err)
}

return pemPath, p12Path
}
3 changes: 2 additions & 1 deletion internal/api/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ type Client struct {

// HTTPClient wraps the standard http.Client with additional configuration
type HTTPClient struct {
Timeout time.Duration
Timeout time.Duration
TLSClientCert string
}

// RequestOptions contains optional parameters for API requests
Expand Down