diff --git a/tailscale.go b/tailscale.go index 13d8ef9..605506d 100644 --- a/tailscale.go +++ b/tailscale.go @@ -40,15 +40,10 @@ type server struct { lastErr string } -func getServer(sd C.int) (*server, error) { +func getServer(sd C.int) *server { servers.mu.Lock() defer servers.mu.Unlock() - - s := servers.m[sd] - if s == nil { - return nil, fmt.Errorf("tsnetc: unknown server descriptors %d (of %d servers)", sd, len(servers.m)) - } - return s, nil + return servers.m[sd] } // listeners tracks all the tsnet_listener objects allocated via tsnet_listen. @@ -107,20 +102,20 @@ func TsnetNewServer() C.int { //export TsnetStart func TsnetStart(sd C.int) C.int { - s, err := getServer(sd) - if err != nil { - return s.recErr(err) + s := getServer(sd) + if s == nil { + return C.EBADF } return s.recErr(s.s.Start()) } //export TsnetUp func TsnetUp(sd C.int) C.int { - s, err := getServer(sd) - if err != nil { - return s.recErr(err) + s := getServer(sd) + if s == nil { + return C.EBADF } - _, err = s.s.Up(context.Background()) // cancellation is via TsnetClose + _, err := s.s.Up(context.Background()) // cancellation is via TsnetClose return s.recErr(err) } @@ -169,7 +164,7 @@ func TsnetGetIps(sd C.int, buf *C.char, buflen C.size_t) C.int { ip4, ip6 := s.s.TailscaleIPs() joined := strings.Join([]string{ip4.String(), ip6.String()}, ",") n := copy(out, joined) - if len(out) < len(joined)-1 { + if n >= len(out) { out[len(out)-1] = '\x00' // always NUL-terminate return C.ERANGE } @@ -195,7 +190,7 @@ func TsnetErrmsg(sd C.int, buf *C.char, buflen C.size_t) C.int { return C.EBADF } n := copy(out, s.lastErr) - if len(out) < len(s.lastErr)-1 { + if n >= len(out) { out[len(out)-1] = '\x00' // always NUL-terminate return C.ERANGE } @@ -205,9 +200,9 @@ func TsnetErrmsg(sd C.int, buf *C.char, buflen C.size_t) C.int { //export TsnetListen func TsnetListen(sd C.int, network, addr *C.char, listenerOut *C.int) C.int { - s, err := getServer(sd) - if err != nil { - return s.recErr(err) + s := getServer(sd) + if s == nil { + return C.EBADF } ln, err := s.s.Listen(C.GoString(network), C.GoString(addr)) @@ -380,7 +375,7 @@ func TsnetGetRemoteAddr(listener C.int, conn C.int, buf *C.char, buflen C.size_t ip := extractIP(addr.String()) n := copy(out, ip) - if len(out) < len(ip)-1 { + if n >= len(out) { out[len(out)-1] = '\x00' // always NUL-terminate return C.ERANGE } @@ -397,15 +392,15 @@ func extractIP(ipWithPort string) string { //export TsnetDial func TsnetDial(sd C.int, network, addr *C.char, connOut *C.int) C.int { - s, err := getServer(sd) - if err != nil { - return s.recErr(err) + s := getServer(sd) + if s == nil { + return C.EBADF } netConn, err := s.s.Dial(context.Background(), C.GoString(network), C.GoString(addr)) if err != nil { return s.recErr(err) } - if newConn(s, netConn, connOut); err != nil { + if err := newConn(s, netConn, connOut); err != nil { return s.recErr(err) } return 0 @@ -413,9 +408,9 @@ func TsnetDial(sd C.int, network, addr *C.char, connOut *C.int) C.int { //export TsnetSetDir func TsnetSetDir(sd C.int, str *C.char) C.int { - s, err := getServer(sd) - if err != nil { - return s.recErr(err) + s := getServer(sd) + if s == nil { + return C.EBADF } s.s.Dir = C.GoString(str) return 0 @@ -423,9 +418,9 @@ func TsnetSetDir(sd C.int, str *C.char) C.int { //export TsnetSetHostname func TsnetSetHostname(sd C.int, str *C.char) C.int { - s, err := getServer(sd) - if err != nil { - return s.recErr(err) + s := getServer(sd) + if s == nil { + return C.EBADF } s.s.Hostname = C.GoString(str) return 0 @@ -433,9 +428,9 @@ func TsnetSetHostname(sd C.int, str *C.char) C.int { //export TsnetSetAuthKey func TsnetSetAuthKey(sd C.int, str *C.char) C.int { - s, err := getServer(sd) - if err != nil { - return s.recErr(err) + s := getServer(sd) + if s == nil { + return C.EBADF } s.s.AuthKey = C.GoString(str) return 0 @@ -443,9 +438,9 @@ func TsnetSetAuthKey(sd C.int, str *C.char) C.int { //export TsnetSetControlURL func TsnetSetControlURL(sd C.int, str *C.char) C.int { - s, err := getServer(sd) - if err != nil { - return s.recErr(err) + s := getServer(sd) + if s == nil { + return C.EBADF } s.s.ControlURL = C.GoString(str) return 0 @@ -453,9 +448,9 @@ func TsnetSetControlURL(sd C.int, str *C.char) C.int { //export TsnetSetEphemeral func TsnetSetEphemeral(sd C.int, e int) C.int { - s, err := getServer(sd) - if err != nil { - return s.recErr(err) + s := getServer(sd) + if s == nil { + return C.EBADF } if e == 0 { s.s.Ephemeral = false @@ -467,9 +462,9 @@ func TsnetSetEphemeral(sd C.int, e int) C.int { //export TsnetSetLogFD func TsnetSetLogFD(sd, fd C.int) C.int { - s, err := getServer(sd) - if err != nil { - return s.recErr(err) + s := getServer(sd) + if s == nil { + return C.EBADF } if fd == -1 { s.s.Logf = logger.Discard @@ -501,9 +496,9 @@ func TsnetLoopback(sd C.int, addrOut *C.char, addrLen C.size_t, proxyOut *C.char *localOut = '\x00' *proxyOut = '\x00' - s, err := getServer(sd) - if err != nil { - return s.recErr(err) + s := getServer(sd) + if s == nil { + return C.EBADF } addr, proxyCred, localAPICred, err := s.s.Loopback() if err != nil { @@ -515,11 +510,13 @@ func TsnetLoopback(sd C.int, addrOut *C.char, addrLen C.size_t, proxyOut *C.char if len(localAPICred) != 32 { return s.recErr(fmt.Errorf("libtailscale: len(localAPICred)=%d, want 32", len(localAPICred))) } - if len(addr)+1 > int(addrLen) { - return s.recErr(fmt.Errorf("libtailscale: loopback addr of %d bytes is too long for addrlen %d", len(addr), addrLen)) - } + out := unsafe.Slice((*byte)(unsafe.Pointer(addrOut)), addrLen) n := copy(out, addr) + if n >= len(out) { + out[len(out)-1] = '\x00' // always NUL-terminate + return C.ERANGE + } out[n] = '\x00' // proxyOut and localOut are non-nil and 33 bytes long because @@ -536,9 +533,9 @@ func TsnetLoopback(sd C.int, addrOut *C.char, addrLen C.size_t, proxyOut *C.char //export TsnetEnableFunnelToLocalhostPlaintextHttp1 func TsnetEnableFunnelToLocalhostPlaintextHttp1(sd C.int, localhostPort C.int) C.int { - s, err := getServer(sd) - if err != nil { - return s.recErr(err) + s := getServer(sd) + if s == nil { + return C.EBADF } ctx := context.Background()