diff --git a/cmd/src/proxy.go b/cmd/src/proxy.go new file mode 100644 index 0000000000..4e6d895c65 --- /dev/null +++ b/cmd/src/proxy.go @@ -0,0 +1,104 @@ +package main + +import ( + "flag" + "fmt" + "io" + "log" + "os" + + "github.com/sourcegraph/sourcegraph/lib/errors" + "github.com/sourcegraph/src-cli/internal/cmderrors" + "github.com/sourcegraph/src-cli/internal/srcproxy" +) + +func init() { + // NOTE: this is an experimental command. It isn't advertised in -help + + flagSet := flag.NewFlagSet("proxy", flag.ExitOnError) + usageFunc := func() { + fmt.Fprintf(flag.CommandLine.Output(), `'src proxy' starts a local reverse proxy to your Sourcegraph instance. + +USAGE + src [-v] proxy [-addr :7777] [-insecure-skip-verify] [-server-cert cert.pem -server-key key.pem] [-log-file path] [client-ca.pem] + +By default, proxied requests use SRC_ACCESS_TOKEN via: + Authorization: token SRC_ACCESS_TOKEN + +If a client CA certificate path is provided, proxy runs in mTLS sudo mode: + 1. Serves HTTPS and requires a client certificate signed by the provided CA. + 2. Reads the first email SAN from the presented client certificate. + 3. Looks up the Sourcegraph user by that email. + 4. Proxies requests with: + Authorization: token-sudo token="TOKEN",user="USERNAME" + +Server certificate options: + -server-cert and -server-key can be used to provide the TLS certificate + and key used by the local proxy server. If omitted in cert mode, an + ephemeral self-signed server certificate is generated. +`) + } + + var ( + addrFlag = flagSet.String("addr", ":7777", "Address on which to serve") + insecureSkipVerifyFlag = flagSet.Bool("insecure-skip-verify", false, "Skip validation of TLS certificates against trusted chains") + serverCertFlag = flagSet.String("server-cert", "", "Path to TLS server certificate for local proxy listener") + serverKeyFlag = flagSet.String("server-key", "", "Path to TLS server private key for local proxy listener") + logFileFlag = flagSet.String("log-file", "", "Path to log file. If not set, logs are written to stderr") + ) + + handler := func(args []string) error { + if err := flagSet.Parse(args); err != nil { + return err + } + + var clientCAPath string + switch flagSet.NArg() { + case 0: + case 1: + clientCAPath = flagSet.Arg(0) + default: + return cmderrors.Usage("requires zero or one positional argument: path to client CA certificate") + } + if (*serverCertFlag == "") != (*serverKeyFlag == "") { + return cmderrors.Usage("both -server-cert and -server-key must be provided together") + } + + logOutput := io.Writer(os.Stderr) + var logF *os.File + if *logFileFlag != "" { + var err error + logF, err = os.Create(*logFileFlag) + if err != nil { + return errors.Wrap(err, "open log file") + } + defer func() { _ = logF.Close() }() + logOutput = logF + } + + dbug := log.New(io.Discard, "", log.LstdFlags) + if *verbose { + dbug = log.New(logOutput, "DBUG proxy: ", log.LstdFlags) + } + + s := &srcproxy.Serve{ + Addr: *addrFlag, + Endpoint: cfg.Endpoint, + AccessToken: cfg.AccessToken, + ClientCAPath: clientCAPath, + ServerCertPath: *serverCertFlag, + ServerKeyPath: *serverKeyFlag, + InsecureSkipVerify: *insecureSkipVerifyFlag, + AdditionalHeaders: cfg.AdditionalHeaders, + Info: log.New(logOutput, "proxy: ", log.LstdFlags), + Debug: dbug, + } + return s.Start() + } + + commands = append(commands, &command{ + flagSet: flagSet, + handler: handler, + usageFunc: usageFunc, + }) +} diff --git a/internal/srcproxy/README.md b/internal/srcproxy/README.md new file mode 100644 index 0000000000..0e1c305c51 --- /dev/null +++ b/internal/srcproxy/README.md @@ -0,0 +1,71 @@ +# srcproxy + +`src proxy` is a local reverse proxy for Sourcegraph with two auth modes. + +## Auth Modes + +- Default mode (no CA arg): forwards requests with `Authorization: token `. +- mTLS sudo mode (with CA arg): requires client certificate, extracts first email SAN, resolves Sourcegraph user by email, then forwards with `token-sudo`. + +## Run + +```bash +# default mode +src proxy + +# mTLS sudo mode +src -v proxy \ + -server-cert ./internal/srcproxy/test-certs/server.pem \ + -server-key ./internal/srcproxy/test-certs/server.key \ + ./internal/srcproxy/test-certs/ca.pem +``` + +## Logging + +- `-v` enables request-level debug logging. +- `-log-file ` writes logs to a file. +- Without `-log-file`, logs go to stderr. + +Example: + +```bash +src -v proxy -log-file ./proxy.log ./internal/srcproxy/test-certs/ca.pem +``` + +## Request Format + +GraphQL requests should use JSON: + +```bash +curl -k \ + -H 'Content-Type: application/json' \ + --cert ./internal/srcproxy/test-certs/client.pem \ + --key ./internal/srcproxy/test-certs/client.key \ + https://localhost:7777/.api/graphql \ + -d '{"query":"{ currentUser { username } }"}' +``` + +If `Content-Type` is omitted with `curl -d`, curl sends `application/x-www-form-urlencoded`, which Sourcegraph GraphQL rejects. + +## Important Routing Behavior + +The proxy rewrites upstream `Host` to `SRC_ENDPOINT` host. + +This is required for name-based routing (for example Caddy virtual hosts). If `Host` is forwarded as `localhost:`, some upstream setups return `200` with empty body from a default vhost instead of Sourcegraph GraphQL. + +## mTLS Certificate Requirements + +- Client cert must chain to the CA file passed as positional arg. +- Client cert must include an email SAN. +- The email SAN must map to an existing Sourcegraph user. + +## Troubleshooting + +- `HTTP 200` with empty body: + upstream host routing mismatch. Confirm proxy is current and `Host` rewrite is in place. +- `no client certificate presented`: + client did not send cert/key or CA trust does not match. +- `client certificate does not contain an email SAN`: + regenerate client cert with email SAN. +- `no Sourcegraph user found for certificate email`: + cert email is not a Sourcegraph user email. diff --git a/internal/srcproxy/gen-test-certs.sh b/internal/srcproxy/gen-test-certs.sh new file mode 100755 index 0000000000..0f569793ff --- /dev/null +++ b/internal/srcproxy/gen-test-certs.sh @@ -0,0 +1,86 @@ +#!/usr/bin/env bash +# gen-test-certs.sh — generate certs for testing `src proxy` mTLS mode +# +# Usage: ./gen-test-certs.sh [email] [output-dir] +# email: email SAN to embed in client cert (default: alice@example.com) +# output-dir: where to write files (default: ./test-certs) +set -euo pipefail + +EMAIL="${1:-alice@example.com}" +DIR="${2:-./test-certs}" +mkdir -p "$DIR" + +echo "==> Generating certs in $DIR (email: $EMAIL)" + +# ── 1. CA ──────────────────────────────────────────────────────────────────── +openssl genrsa -out "$DIR/ca.key" 2048 2>/dev/null + +openssl req -new -x509 -days 1 \ + -key "$DIR/ca.key" \ + -out "$DIR/ca.pem" \ + -subj "/CN=Test Client CA" 2>/dev/null + +echo " ca.pem / ca.key" + +# ── 2. Server cert (so you can pass it to the proxy and trust it in curl) ──── +openssl genrsa -out "$DIR/server.key" 2048 2>/dev/null + +openssl req -new \ + -key "$DIR/server.key" \ + -out "$DIR/server.csr" \ + -subj "/CN=localhost" 2>/dev/null + +openssl x509 -req -days 1 \ + -in "$DIR/server.csr" \ + -signkey "$DIR/server.key" \ + -out "$DIR/server.pem" \ + -extfile <(printf 'subjectAltName=DNS:localhost,IP:127.0.0.1') 2>/dev/null + +echo " server.pem / server.key" + +# ── 3. Client cert with email SAN signed by the CA ─────────────────────────── +openssl genrsa -out "$DIR/client.key" 2048 2>/dev/null + +openssl req -new \ + -key "$DIR/client.key" \ + -out "$DIR/client.csr" \ + -subj "/CN=test-client" 2>/dev/null + +openssl x509 -req -days 1 \ + -in "$DIR/client.csr" \ + -CA "$DIR/ca.pem" \ + -CAkey "$DIR/ca.key" \ + -CAcreateserial \ + -out "$DIR/client.pem" \ + -extfile <(printf "subjectAltName=email:%s" "$EMAIL") 2>/dev/null + +echo " client.pem / client.key (email SAN: $EMAIL)" + +# Confirm the SAN is present +echo "" +echo "==> Verifying email SAN in client cert:" +openssl x509 -in "$DIR/client.pem" -noout -text \ + | grep -A1 "Subject Alternative Name" + +echo "" +echo "==> Done. Next steps:" +echo "" +echo " # 1. Start the proxy (in another terminal):" +echo " export SRC_ENDPOINT=https://sourcegraph.example.com" +echo " export SRC_ACCESS_TOKEN=" +echo " go run ./cmd/src proxy \\" +echo " -server-cert $DIR/server.pem \\" +echo " -server-key $DIR/server.key \\" +echo " $DIR/ca.pem" +echo "" +echo " # 2. Send a request via curl using the client cert:" +echo " curl --cacert $DIR/server.pem \\" +echo " --cert $DIR/client.pem \\" +echo " --key $DIR/client.key \\" +echo " https://localhost:7777/.api/graphql \\" +echo " -d '{\"query\":\"{ currentUser { username } }\"}'" +echo "" +echo " # Or skip server cert verification with -k:" +echo " curl -k --cert $DIR/client.pem --key $DIR/client.key \\" +echo " https://localhost:7777/.api/graphql \\" +echo " -d '{\"query\":\"{ currentUser { username } }\"}'" diff --git a/internal/srcproxy/serve.go b/internal/srcproxy/serve.go new file mode 100644 index 0000000000..ab2f151192 --- /dev/null +++ b/internal/srcproxy/serve.go @@ -0,0 +1,371 @@ +package srcproxy + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/json" + "encoding/pem" + "fmt" + "io" + "log" + "math/big" + "net" + "net/http" + "net/http/httputil" + "net/url" + "os" + "strings" + "sync" + "time" + + "github.com/sourcegraph/sourcegraph/lib/errors" +) + +type Serve struct { + Addr string + Endpoint string + AccessToken string + ClientCAPath string + ServerCertPath string + ServerKeyPath string + InsecureSkipVerify bool + AdditionalHeaders map[string]string + HTTPClient *http.Client + Info *log.Logger + Debug *log.Logger + + mu sync.Mutex + userByEmail map[string]string + httpClient *http.Client + baseAuthMode string +} + +func (s *Serve) Start() error { + if s.AccessToken == "" { + return errors.New("SRC_ACCESS_TOKEN must be set") + } + if s.Endpoint == "" { + return errors.New("SRC_ENDPOINT must be set") + } + + s.httpClient = s.HTTPClient + if s.httpClient == nil { + s.httpClient = &http.Client{Transport: http.DefaultTransport.(*http.Transport).Clone()} + if s.InsecureSkipVerify { + s.httpClient.Transport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } + } + + s.userByEmail = map[string]string{} + + if s.ClientCAPath == "" { + s.baseAuthMode = "access-token" + } else { + if err := s.verifySudoCapability(context.Background()); err != nil { + return err + } + s.baseAuthMode = "mtls-sudo" + } + + endpointURL, err := url.Parse(strings.TrimRight(s.Endpoint, "/")) + if err != nil { + return errors.Wrap(err, "parse endpoint") + } + + proxy := s.newReverseProxy(endpointURL) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.Debug.Printf("incoming request method=%s host=%s path=%s remote=%s", r.Method, r.Host, r.URL.RequestURI(), r.RemoteAddr) + authHeader, err := s.authorizationForRequest(r) + if err != nil { + s.Debug.Printf("authorization failed method=%s path=%s err=%s", r.Method, r.URL.Path, err) + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + r.Header.Set("Authorization", authHeader) + s.Debug.Printf("proxying request method=%s path=%s upstream_host=%s", r.Method, r.URL.Path, endpointURL.Host) + proxy.ServeHTTP(w, r) + }) + + ln, err := net.Listen("tcp", s.Addr) + if err != nil { + return errors.Wrap(err, "listen") + } + + if s.ClientCAPath != "" { + tlsCfg, err := s.mtlsServerTLSConfig() + if err != nil { + return err + } + ln = tls.NewListener(ln, tlsCfg) + } + + s.Addr = ln.Addr().String() + if s.ClientCAPath == "" { + s.Info.Printf("listening on http://%s", s.Addr) + } else { + s.Info.Printf("listening on https://%s", s.Addr) + s.Info.Printf("mTLS client CA: %s", s.ClientCAPath) + } + s.Info.Printf("proxying requests to %s", s.Endpoint) + s.Info.Printf("auth mode: %s", s.baseAuthMode) + + if err := (&http.Server{Handler: handler}).Serve(ln); err != nil { + return errors.Wrap(err, "serve") + } + + return nil +} + +func (s *Serve) newReverseProxy(endpointURL *url.URL) *httputil.ReverseProxy { + transport := http.DefaultTransport.(*http.Transport).Clone() + if s.InsecureSkipVerify { + transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } + + proxy := httputil.NewSingleHostReverseProxy(endpointURL) + proxy.Transport = transport + + upstreamDirector := proxy.Director + proxy.Director = func(r *http.Request) { + upstreamDirector(r) + // Rewrite Host for name-based routing upstream (e.g. Caddy vhosts). + r.Host = endpointURL.Host + for key, value := range s.AdditionalHeaders { + r.Header.Set(key, value) + } + } + proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { + s.Debug.Printf("proxy request failed method=%s path=%s err=%s", r.Method, r.URL.Path, err) + http.Error(w, "proxy request failed: "+err.Error(), http.StatusBadGateway) + } + return proxy +} + +func (s *Serve) authorizationForRequest(r *http.Request) (string, error) { + if s.ClientCAPath == "" { + return "token " + s.AccessToken, nil + } + + email, err := emailFromPeerCertificates(r.TLS) + if err != nil { + return "", err + } + + username, err := s.lookupUsernameByEmail(r.Context(), email) + if err != nil { + return "", err + } + + return fmt.Sprintf(`token-sudo token=%q,user=%q`, s.AccessToken, username), nil +} + +func (s *Serve) verifySudoCapability(ctx context.Context) error { + const queryCurrentUser = `query CurrentUser { currentUser { username } }` + + var current struct { + CurrentUser *struct { + Username string `json:"username"` + } `json:"currentUser"` + } + if err := doGraphQL(ctx, s.httpClient, s.Endpoint, "token "+s.AccessToken, queryCurrentUser, nil, ¤t); err != nil { + return errors.Wrap(err, "verify base access token") + } + if current.CurrentUser == nil || current.CurrentUser.Username == "" { + return errors.New("unable to resolve current user from access token") + } + + sudoAuth := fmt.Sprintf(`token-sudo token=%q,user=%q`, s.AccessToken, current.CurrentUser.Username) + if err := doGraphQL(ctx, s.httpClient, s.Endpoint, sudoAuth, queryCurrentUser, nil, ¤t); err != nil { + return errors.Wrap(err, "verify token has site-admin:sudo scope") + } + + return nil +} + +func (s *Serve) lookupUsernameByEmail(ctx context.Context, email string) (string, error) { + s.mu.Lock() + if username, ok := s.userByEmail[email]; ok { + s.mu.Unlock() + return username, nil + } + s.mu.Unlock() + + username, err := lookupUserByEmail(ctx, s.httpClient, s.Endpoint, s.AccessToken, email) + if err != nil { + return "", err + } + + s.mu.Lock() + s.userByEmail[email] = username + s.mu.Unlock() + return username, nil +} + +func (s *Serve) mtlsServerTLSConfig() (*tls.Config, error) { + caData, err := os.ReadFile(s.ClientCAPath) + if err != nil { + return nil, errors.Wrap(err, "read mTLS client CA certificate") + } + + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(caData) { + return nil, errors.New("failed to parse mTLS client CA certificate") + } + + serverCert, err := s.loadOrGenerateServerCert() + if err != nil { + return nil, err + } + + return &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: pool, + MinVersion: tls.VersionTLS12, + }, nil +} + +func (s *Serve) loadOrGenerateServerCert() (tls.Certificate, error) { + if s.ServerCertPath != "" || s.ServerKeyPath != "" { + cert, err := tls.LoadX509KeyPair(s.ServerCertPath, s.ServerKeyPath) + if err != nil { + return tls.Certificate{}, errors.Wrap(err, "load server TLS certificate/key") + } + return cert, nil + } + + cert, err := generateEphemeralServerCert() + if err != nil { + return tls.Certificate{}, errors.Wrap(err, "generate server TLS certificate") + } + return cert, nil +} + +func generateEphemeralServerCert() (tls.Certificate, error) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return tls.Certificate{}, err + } + + notBefore := time.Now().Add(-1 * time.Hour) + notAfter := time.Now().Add(24 * time.Hour) + + tpl := &x509.Certificate{ + SerialNumber: big.NewInt(time.Now().UnixNano()), + Subject: pkix.Name{ + CommonName: "src-proxy", + }, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + DNSNames: []string{"localhost"}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, tpl, tpl, &key.PublicKey, key) + if err != nil { + return tls.Certificate{}, err + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + + return tls.X509KeyPair(certPEM, keyPEM) +} + +func emailFromPeerCertificates(tlsState *tls.ConnectionState) (string, error) { + if tlsState == nil || len(tlsState.PeerCertificates) == 0 { + return "", errors.New("no client certificate presented") + } + cert := tlsState.PeerCertificates[0] + if len(cert.EmailAddresses) == 0 { + return "", errors.New("client certificate does not contain an email SAN") + } + return cert.EmailAddresses[0], nil +} + +func lookupUserByEmail(ctx context.Context, client *http.Client, endpoint, accessToken, email string) (string, error) { + const query = `query LookupUserByEmail($email: String) { + user(email: $email) { + username + } +}` + + var result struct { + User *struct { + Username string `json:"username"` + } `json:"user"` + } + if err := doGraphQL(ctx, client, endpoint, "token "+accessToken, query, map[string]any{"email": email}, &result); err != nil { + return "", err + } + if result.User == nil || result.User.Username == "" { + return "", errors.New("no Sourcegraph user found for certificate email") + } + return result.User.Username, nil +} + +func doGraphQL(ctx context.Context, client *http.Client, endpoint, authorizationHeader, query string, variables map[string]any, result any) error { + payload, err := json.Marshal(map[string]any{ + "query": query, + "variables": variables, + }) + if err != nil { + return errors.Wrap(err, "marshal GraphQL request") + } + + req, err := http.NewRequestWithContext(ctx, "POST", strings.TrimRight(endpoint, "/")+"/.api/graphql", bytes.NewReader(payload)) + if err != nil { + return errors.Wrap(err, "create GraphQL request") + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", authorizationHeader) + + resp, err := client.Do(req) + if err != nil { + return errors.Wrap(err, "perform GraphQL request") + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return errors.Wrap(err, "read GraphQL response") + } + + if resp.StatusCode != http.StatusOK { + return errors.Newf("GraphQL request failed with status %s: %s", resp.Status, strings.TrimSpace(string(body))) + } + + var envelope struct { + Data json.RawMessage `json:"data"` + Errors []struct { + Message string `json:"message"` + } `json:"errors"` + } + if err := json.Unmarshal(body, &envelope); err != nil { + return errors.Wrap(err, "decode GraphQL response envelope") + } + + if len(envelope.Errors) > 0 { + var messages []string + for _, graphqlErr := range envelope.Errors { + messages = append(messages, graphqlErr.Message) + } + return errors.New(strings.Join(messages, "; ")) + } + + if result == nil { + return nil + } + if err := json.Unmarshal(envelope.Data, result); err != nil { + return errors.Wrap(err, "decode GraphQL response data") + } + return nil +} diff --git a/internal/srcproxy/serve_test.go b/internal/srcproxy/serve_test.go new file mode 100644 index 0000000000..9b151f0aec --- /dev/null +++ b/internal/srcproxy/serve_test.go @@ -0,0 +1,257 @@ +package srcproxy + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/json" + "encoding/pem" + "io" + "log" + "math/big" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestAuthorizationForRequest_DefaultMode(t *testing.T) { + t.Parallel() + + s := &Serve{ + AccessToken: "test-token", + Info: log.New(ioDiscard{}, "", 0), + Debug: log.New(ioDiscard{}, "", 0), + } + + authHeader, err := s.authorizationForRequest(&http.Request{}) + if err != nil { + t.Fatalf("authorizationForRequest() error = %v", err) + } + if got, want := authHeader, "token test-token"; got != want { + t.Fatalf("auth header = %q, want %q", got, want) + } +} + +func TestAuthorizationForRequest_MTLSSuccess(t *testing.T) { + t.Parallel() + + client := &http.Client{Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + var payload struct { + Query string `json:"query"` + } + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read request body: %v", err) + } + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("unmarshal request body: %v", err) + } + + if !strings.Contains(payload.Query, "LookupUserByEmail") { + t.Fatalf("unexpected query: %s", payload.Query) + } + if got, want := r.Header.Get("Authorization"), "token token-1"; got != want { + t.Fatalf("auth header = %q, want %q", got, want) + } + + return testResponse(`{"data":{"user":{"username":"alice"}}}`), nil + })} + + s := &Serve{ + Endpoint: "https://example.com", + AccessToken: "token-1", + ClientCAPath: "ca.pem", + httpClient: client, + userByEmail: map[string]string{}, + Info: log.New(ioDiscard{}, "", 0), + Debug: log.New(ioDiscard{}, "", 0), + } + + req := &http.Request{TLS: tlsStateWithEmail("alice@example.com")} + authHeader, err := s.authorizationForRequest(req) + if err != nil { + t.Fatalf("authorizationForRequest() error = %v", err) + } + if got, want := authHeader, `token-sudo token="token-1",user="alice"`; got != want { + t.Fatalf("auth header = %q, want %q", got, want) + } +} + +func TestVerifySudoCapability(t *testing.T) { + t.Parallel() + + var step int + client := &http.Client{Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + switch step { + case 0: + if got, want := r.Header.Get("Authorization"), "token token-1"; got != want { + t.Fatalf("step 0 auth = %q, want %q", got, want) + } + step++ + return testResponse(`{"data":{"currentUser":{"username":"site-admin"}}}`), nil + case 1: + if got, want := r.Header.Get("Authorization"), `token-sudo token="token-1",user="site-admin"`; got != want { + t.Fatalf("step 1 auth = %q, want %q", got, want) + } + step++ + return testResponse(`{"data":{"currentUser":{"username":"site-admin"}}}`), nil + default: + t.Fatalf("unexpected extra call") + return nil, nil + } + })} + + s := &Serve{ + Endpoint: "https://example.com", + AccessToken: "token-1", + httpClient: client, + Info: log.New(ioDiscard{}, "", 0), + Debug: log.New(ioDiscard{}, "", 0), + } + + if err := s.verifySudoCapability(context.Background()); err != nil { + t.Fatalf("verifySudoCapability() error = %v", err) + } + if got, want := step, 2; got != want { + t.Fatalf("steps = %d, want %d", got, want) + } +} + +func TestLoadOrGenerateServerCert_FromFiles(t *testing.T) { + t.Parallel() + + certPath, keyPath := writeServerCertKeyPair(t) + s := &Serve{ + ServerCertPath: certPath, + ServerKeyPath: keyPath, + } + + cert, err := s.loadOrGenerateServerCert() + if err != nil { + t.Fatalf("loadOrGenerateServerCert() error = %v", err) + } + if len(cert.Certificate) == 0 { + t.Fatal("expected loaded certificate chain") + } +} + +func TestEmailFromPeerCertificates(t *testing.T) { + t.Parallel() + + tlsState := tlsStateWithEmail("alice@example.com") + email, err := emailFromPeerCertificates(tlsState) + if err != nil { + t.Fatalf("emailFromPeerCertificates() error = %v", err) + } + if got, want := email, "alice@example.com"; got != want { + t.Fatalf("email = %q, want %q", got, want) + } +} + +func TestNewReverseProxy_RewritesHostAndHeaders(t *testing.T) { + t.Parallel() + + endpointURL, err := url.Parse("https://sourcegraph.test:3443") + if err != nil { + t.Fatalf("parse endpoint URL: %v", err) + } + + s := &Serve{ + AdditionalHeaders: map[string]string{"X-Test": "1"}, + Info: log.New(ioDiscard{}, "", 0), + Debug: log.New(ioDiscard{}, "", 0), + } + proxy := s.newReverseProxy(endpointURL) + + req, err := http.NewRequest("POST", "https://localhost:7777/.api/graphql", nil) + if err != nil { + t.Fatalf("new request: %v", err) + } + req.Host = "localhost:7777" + + proxy.Director(req) + + if got, want := req.URL.Host, "sourcegraph.test:3443"; got != want { + t.Fatalf("URL host = %q, want %q", got, want) + } + if got, want := req.Host, "sourcegraph.test:3443"; got != want { + t.Fatalf("host header = %q, want %q", got, want) + } + if got, want := req.Header.Get("X-Test"), "1"; got != want { + t.Fatalf("X-Test header = %q, want %q", got, want) + } +} + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return f(r) +} + +func testResponse(body string) *http.Response { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: make(http.Header), + } +} + +type ioDiscard struct{} + +func (ioDiscard) Write(p []byte) (int, error) { + return len(p), nil +} + +func tlsStateWithEmail(email string) *tls.ConnectionState { + return &tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{{EmailAddresses: []string{email}}}, + } +} + +func writeServerCertKeyPair(t *testing.T) (certPath, keyPath string) { + t.Helper() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate key: %v", err) + } + + tpl := &x509.Certificate{ + SerialNumber: big.NewInt(time.Now().UnixNano()), + Subject: pkix.Name{ + CommonName: "localhost", + }, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + DNSNames: []string{"localhost"}, + } + + der, err := x509.CreateCertificate(rand.Reader, tpl, tpl, &key.PublicKey, key) + if err != nil { + t.Fatalf("create certificate: %v", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + + dir := t.TempDir() + certPath = filepath.Join(dir, "server-cert.pem") + keyPath = filepath.Join(dir, "server-key.pem") + if err := os.WriteFile(certPath, certPEM, 0644); err != nil { + t.Fatalf("write cert: %v", err) + } + if err := os.WriteFile(keyPath, keyPEM, 0600); err != nil { + t.Fatalf("write key: %v", err) + } + + return certPath, keyPath +}