Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 160 additions & 9 deletions protocol/http/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,45 @@ package http

import (
std_bufio "bufio"
"bytes"
"context"
"encoding/base64"
"io"
"net"
"net/http"
"sort"
"strings"
"time"

"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/auth"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/pipe"
)

const defaultProxyAuthRetryTimeout = 5 * time.Second

type HTTPServerOptions struct {
ProxyAuthRetryTimeout time.Duration
Logger logger.ContextLogger
}

func normalizeHTTPServerOptions(options HTTPServerOptions) HTTPServerOptions {
if options.ProxyAuthRetryTimeout <= 0 {
options.ProxyAuthRetryTimeout = defaultProxyAuthRetryTimeout
}
if options.Logger == nil {
options.Logger = logger.NOP()
}
return options
}

func HandleConnectionEx(
ctx context.Context,
conn net.Conn,
Expand All @@ -29,11 +50,36 @@ func HandleConnectionEx(
source M.Socksaddr,
onClose N.CloseHandlerFunc,
) error {
return HandleConnectionExWithOptions(ctx, conn, reader, authenticator, handler, source, onClose, HTTPServerOptions{})
}

func HandleConnectionExWithOptions(
ctx context.Context,
conn net.Conn,
reader *std_bufio.Reader,
authenticator *auth.Authenticator,
handler N.TCPConnectionHandlerEx,
source M.Socksaddr,
onClose N.CloseHandlerFunc,
options HTTPServerOptions,
) error {
options = normalizeHTTPServerOptions(options)
missingProxyAuthorizationRetried := false
waitingRetryProxyAuthentication := false
for {
request, err := ReadRequest(reader)
if err != nil {
return E.Cause(err, "read http request")
}
printRequestHeaders(ctx, options.Logger, request)
if waitingRetryProxyAuthentication {
waitingRetryProxyAuthentication = false
err = conn.SetReadDeadline(time.Time{})
if err != nil {
return E.Cause(err, "clear retry-proxy-authentication timeout")
}
}
retryMissingProxyAuthorization := shouldRetryMissingProxyAuthorization(request)
if authenticator != nil {
var (
username string
Expand All @@ -55,18 +101,29 @@ func HandleConnectionEx(
}
if !authOk {
// Since no one else is using the library, use a fixed realm until rewritten
err = responseWith(
proxyAuthRequiredResponse := responseWith(
request, http.StatusProxyAuthRequired,
"Proxy-Authenticate", `Basic realm="sing-box" charset="UTF-8"`,
).Write(conn)
"Proxy-Authenticate", `Basic realm="sing-box", charset="UTF-8"`,
)
printResponseHeaders(ctx, options.Logger, proxyAuthRequiredResponse)
err = writeResponseBuffered(conn, proxyAuthRequiredResponse)
if err != nil {
return err
return E.Cause(err, "write proxy authentication required response")
}
if username != "" {
return E.New("http: authentication failed, username=", username, ", password=", password)
} else if authorization != "" {
return E.New("http: authentication failed, Proxy-Authorization=", authorization)
} else {
if retryMissingProxyAuthorization && !missingProxyAuthorizationRetried {
missingProxyAuthorizationRetried = true
err = conn.SetReadDeadline(time.Now().Add(options.ProxyAuthRetryTimeout))
if err != nil {
return E.Cause(err, "set retry-proxy-authentication timeout")
}
waitingRetryProxyAuthentication = true
continue
}
return E.New("http: authentication failed, no Proxy-Authorization header")
}
}
Expand Down Expand Up @@ -133,7 +190,7 @@ func HandleConnectionEx(
}
return bufio.CopyConn(ctx, conn, serverConn)
} else {
err = handleHTTPConnection(ctx, handler, conn, request, source)
err = handleHTTPConnection(ctx, handler, conn, request, source, options.Logger)
if err != nil {
return err
}
Expand All @@ -145,9 +202,11 @@ func handleHTTPConnection(
ctx context.Context,
handler N.TCPConnectionHandlerEx,
conn net.Conn,
request *http.Request, source M.Socksaddr,
request *http.Request,
source M.Socksaddr,
contextLogger logger.ContextLogger,
) error {
keepAlive := !(request.ProtoMajor == 1 && request.ProtoMinor == 0) && strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive"
keepAlive := isProxyKeepAlive(request)
request.RequestURI = ""

removeHopByHopHeaders(request.Header)
Expand All @@ -160,7 +219,9 @@ func handleHTTPConnection(
}

if request.URL.Scheme == "" || request.URL.Host == "" {
return responseWith(request, http.StatusBadRequest).Write(conn)
badRequestResponse := responseWith(request, http.StatusBadRequest)
printResponseHeaders(ctx, contextLogger, badRequestResponse)
return badRequestResponse.Write(conn)
}

var innerErr common.TypedValue[error]
Expand All @@ -186,7 +247,9 @@ func handleHTTPConnection(
response, err := httpClient.Do(request.WithContext(requestCtx))
if err != nil {
cancel()
return E.Errors(innerErr.Load(), err, responseWith(request, http.StatusBadGateway).Write(conn))
badGatewayResponse := responseWith(request, http.StatusBadGateway)
printResponseHeaders(ctx, contextLogger, badGatewayResponse)
return E.Errors(innerErr.Load(), err, badGatewayResponse.Write(conn))
}

removeHopByHopHeaders(response.Header)
Expand All @@ -198,6 +261,7 @@ func handleHTTPConnection(
}

response.Close = !keepAlive
printResponseHeaders(ctx, contextLogger, response)

err = response.Write(conn)
if err != nil {
Expand All @@ -212,6 +276,77 @@ func handleHTTPConnection(
return nil
}

func isProxyKeepAlive(request *http.Request) bool {
connection := request.Header.Get("Connection")
proxyConnection := request.Header.Get("Proxy-Connection")

if request.ProtoMajor > 1 || (request.ProtoMajor == 1 && request.ProtoMinor >= 1) {
// HTTP/1.1+ connections are persistent unless explicitly closed.
return !hasHeaderToken(connection, "close") && !hasHeaderToken(proxyConnection, "close")
}

if request.ProtoMajor == 1 && request.ProtoMinor == 0 {
// HTTP/1.0 defaults to close unless keep-alive is requested.
if hasHeaderToken(connection, "close") || hasHeaderToken(proxyConnection, "close") {
return false
}
return hasHeaderToken(connection, "keep-alive") || hasHeaderToken(proxyConnection, "keep-alive")
}

return false
}

func hasHeaderToken(headerValue string, token string) bool {
for _, h := range strings.Split(headerValue, ",") {
if strings.EqualFold(strings.TrimSpace(h), token) {
return true
}
}
return false
}

func shouldRetryMissingProxyAuthorization(request *http.Request) bool {
return isProxyKeepAlive(request) &&
!request.Close &&
!hasHeaderToken(request.Header.Get("Connection"), "upgrade") &&
request.ContentLength == 0 &&
len(request.TransferEncoding) == 0
}

func printRequestHeaders(ctx context.Context, contextLogger logger.ContextLogger, request *http.Request) {
contextLogger.TraceContext(ctx, "request protocol: ", request.Proto)
printHeaders(ctx, contextLogger, "request", request.Header)
}

func printResponseHeaders(ctx context.Context, contextLogger logger.ContextLogger, response *http.Response) {
contextLogger.TraceContext(ctx, "response: protocol=", response.Proto, " status=", response.StatusCode)
printHeaders(ctx, contextLogger, "response", response.Header)
}

func printHeaders(ctx context.Context, contextLogger logger.ContextLogger, kind string, header http.Header) {
keys := make([]string, 0, len(header))
for key := range header {
keys = append(keys, key)
}
sort.Strings(keys)
for _, key := range keys {
redacted := shouldRedactHeaderValue(key)
for _, value := range header[key] {
if redacted {
value = "[redacted]"
}
contextLogger.TraceContext(ctx, kind, " header: ", key, ": ", value)
}
}
}

func shouldRedactHeaderValue(headerKey string) bool {
return strings.EqualFold(headerKey, "Authorization") ||
strings.EqualFold(headerKey, "Proxy-Authorization") ||
strings.EqualFold(headerKey, "Cookie") ||
strings.EqualFold(headerKey, "Set-Cookie")
}

func removeHopByHopHeaders(header http.Header) {
// Strip hop-by-hop header based on RFC:
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html#sec13.5.1
Expand Down Expand Up @@ -252,6 +387,22 @@ func removeExtraHTTPHostPort(req *http.Request) {
req.URL.Host = host
}

func writeResponseBuffered(conn net.Conn, response *http.Response) error {
var responseBuffer bytes.Buffer
err := response.Write(&responseBuffer)
if err != nil {
return err
}
n, err := conn.Write(responseBuffer.Bytes())
if err != nil {
return err
}
if n != responseBuffer.Len() {
return io.ErrShortWrite
}
return nil
}

func responseWith(request *http.Request, statusCode int, headers ...string) *http.Response {
var header http.Header
if len(headers) > 0 {
Expand Down
Loading