diff --git a/.gitignore b/.gitignore
index 6d1b19f..1fee62f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -87,6 +87,8 @@ Thumbs.db
 # Coverage Results #
 ####################
 coverage.txt
+.coverdata/
+pubsub-sub-bench-test
 
 # Profiler Results #
 ####################
diff --git a/Makefile b/Makefile
index 628f2fd..e85c9cc 100644
--- a/Makefile
+++ b/Makefile
@@ -33,6 +33,11 @@ build-race:
 	$(GOBUILDRACE) \
 	-ldflags=$(LDFLAGS) .
 
+build-cover:
+	@echo "Building binary with coverage instrumentation..."
+	$(GOBUILD) -cover \
+	-ldflags=$(LDFLAGS) .
+
 checkfmt:
 	@echo 'Checking gofmt';\
 	bash -c "diff -u <(echo -n) <(go fmt .)";\
@@ -52,9 +57,21 @@ fmt:
 get:
 	$(GOGET) -t -v ./...
 
-test: get
+test: get build-cover
 	$(GOFMT) ./...
-	$(GOTEST) -race -covermode=atomic ./...
+	@rm -rf .coverdata
+	@mkdir -p .coverdata
+	$(GOTEST) -v -race -covermode=atomic ./...
 
-coverage: get test
-	$(GOTEST) -race -coverprofile=coverage.txt -covermode=atomic .
+coverage: get build-cover
+	$(GOFMT) ./...
+	@rm -rf .coverdata
+	@mkdir -p .coverdata
+	$(GOTEST) -v -race -covermode=atomic .
+	@if [ -d .coverdata ] && [ -n "$$(ls -A .coverdata 2>/dev/null)" ]; then \
+		echo "Converting coverage data..."; \
+		go tool covdata textfmt -i=.coverdata -o coverage.txt; \
+	else \
+		echo "No coverage data found, creating empty coverage file"; \
+		touch coverage.txt; \
+	fi
diff --git a/README.md b/README.md
index 9be80f5..62f8f0c 100644
--- a/README.md
+++ b/README.md
@@ -97,7 +97,7 @@ Usage of ./pubsub-sub-bench:
   -cpuprofile string
     	write cpu profile to file
   -data-size int
-    	Payload size in bytes. In RTT mode, timestamp (13 bytes) + space + padding to reach this size. (default 128)
+    	Payload size in bytes. In RTT mode, timestamp (19 bytes) + space + padding to reach this size. (default 128)
   -host string
     	redis host. (default "127.0.0.1")
   -json-out-file string
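Note on the payload layout referenced above: with `UnixNano()` the timestamp prefix grows from 13 to 19 bytes, and `-data-size` still bounds the whole message. A minimal standalone sketch of the `<timestamp> <padding>` round trip follows; the padding byte `'x'` is an assumption, since the real padding string is built inside `publisherRoutine` and is not shown in this patch:

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
	"time"
)

const timestampSize = 19 // UnixNano() of present-day clocks renders as 19 digits

// buildPayload mimics the publisher: "<timestamp> <padding>" padded out to dataSize.
func buildPayload(dataSize int) string {
	ts := strconv.FormatInt(time.Now().UnixNano(), 10)
	if dataSize <= timestampSize+1 {
		return ts // no room for the space + padding
	}
	return ts + " " + strings.Repeat("x", dataSize-timestampSize-1)
}

// parseTimestamp mimics the subscriber: the first 19 bytes are the timestamp.
func parseTimestamp(payload string) (int64, error) {
	if len(payload) > timestampSize {
		payload = payload[:timestampSize]
	}
	return strconv.ParseInt(payload, 10, 64)
}

func main() {
	p := buildPayload(128)
	sent, _ := parseTimestamp(p)
	fmt.Printf("len=%d rtt=%dns\n", len(p), time.Now().UnixNano()-sent)
}
```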
(default "127.0.0.1") -json-out-file string diff --git a/go.mod b/go.mod index 7be750d..2e5b4ea 100644 --- a/go.mod +++ b/go.mod @@ -1,16 +1,16 @@ module github.com/RedisLabs/pubsub-sub-bench -go 1.23.0 +go 1.24.0 toolchain go1.24.1 require ( github.com/HdrHistogram/hdrhistogram-go v1.1.2 - github.com/redis/go-redis/v9 v9.0.5 - golang.org/x/time v0.11.0 + github.com/redis/go-redis/v9 v9.16.0 + golang.org/x/time v0.14.0 ) require ( - github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect ) diff --git a/go.sum b/go.sum index b1b423a..e7cdf66 100644 --- a/go.sum +++ b/go.sum @@ -3,12 +3,12 @@ github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym github.com/HdrHistogram/hdrhistogram-go v1.1.2 h1:5IcZpTvzydCQeHzK4Ef/D5rrSqwxob0t8PQPMybUNFM= github.com/HdrHistogram/hdrhistogram-go v1.1.2/go.mod h1:yDgFjdqOqDEKOvasDdhWNXYg9BVp4O+o5f6V/ehm6Oo= github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= -github.com/bsm/ginkgo/v2 v2.7.0 h1:ItPMPH90RbmZJt5GtkcNvIRuGEdwlBItdNVoyzaNQao= -github.com/bsm/ginkgo/v2 v2.7.0/go.mod h1:AiKlXPm7ItEHNc/2+OkrNG4E0ITzojb9/xWzvQ9XZ9w= -github.com/bsm/gomega v1.26.0 h1:LhQm+AFcgV2M0WyKroMASzAzCAJVpAxQXv4SaI9a69Y= -github.com/bsm/gomega v1.26.0/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= -github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= -github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -27,8 +27,8 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/redis/go-redis/v9 v9.0.5 h1:CuQcn5HIEeK7BgElubPP8CGtE0KakrnbBSTLjathl5o= -github.com/redis/go-redis/v9 v9.0.5/go.mod h1:WqMKv5vnQbRuZstUwxQI195wHy+t4PuXDOjzMvcuQHk= +github.com/redis/go-redis/v9 v9.16.0 h1:OotgqgLSRCmzfqChbQyG1PHC3tLNR89DG4jdOERSEP4= +github.com/redis/go-redis/v9 v9.16.0/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -52,8 +52,8 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys 
diff --git a/subscriber.go b/subscriber.go
index 2eabefd..fdc376a 100644
--- a/subscriber.go
+++ b/subscriber.go
@@ -31,7 +31,7 @@ const (
 	redisTLSCert               = "tls_cert"
 	redisTLSKey                = "tls_key"
 	redisTLSInsecureSkipVerify = "tls_insecure_skip_verify"
-	timestampSize              = 13 // UnixMilli() produces 13-digit number
+	timestampSize              = 19 // UnixNano() produces 19-digit number
 )
 
 const Inf = rate.Limit(math.MaxFloat64)
@@ -58,7 +58,7 @@ type testResult struct {
 	Addresses []string `json:"Addresses"`
 }
 
-func publisherRoutine(clientName string, channels []string, mode string, measureRTT bool, verbose bool, dataSize int, ctx context.Context, wg *sync.WaitGroup, client *redis.Client, useLimiter bool, rateLimiter *rate.Limiter) {
+func publisherRoutine(clientName string, channels []string, mode string, measureRTT bool, verbose bool, dataSize int, ctx context.Context, wg *sync.WaitGroup, client *redis.Client, useLimiter bool, rateLimiter *rate.Limiter, publishLatencyChannel chan int64, subscriberCountChannel chan int64) {
 	defer wg.Done()
 
 	if verbose {
@@ -75,7 +75,7 @@ func publisherRoutine(clientName string, channels []string, mode string, measure
 	// Pre-generate payload once per goroutine
 	// For RTT mode: we'll use a template with padding that we'll prepend timestamp to
-	// Timestamp format: 13 bytes (e.g., "1762249648882")
+	// Timestamp format: 19 bytes (e.g., "1762259663660769761")
 	// Format: "<timestamp> <padding>" to reach dataSize
 	var paddingPayload string
 	if measureRTT && dataSize > timestampSize+1 {
@@ -101,7 +101,7 @@ func publisherRoutine(clientName string, channels []string, mode string, measure
 			time.Sleep(r.Delay())
 		}
 		if measureRTT {
-			now := time.Now().UnixMilli()
+			now := time.Now().UnixNano()
 			if dataSize > timestampSize+1 {
 				// Format: "<timestamp> <padding>"
 				msg = strconv.FormatInt(int64(now), 10) + " " + paddingPayload
@@ -112,15 +112,29 @@
 		} else {
 			msg = paddingPayload
 		}
+
+		// Measure publish latency
+		startPublish := time.Now().UnixNano()
+		var subscriberCount int64
 		var err error
 		switch mode {
 		case "spublish":
-			err = client.SPublish(ctx, ch, msg).Err()
+			subscriberCount, err = client.SPublish(ctx, ch, msg).Result()
 		default:
-			err = client.Publish(ctx, ch, msg).Err()
+			subscriberCount, err = client.Publish(ctx, ch, msg).Result()
 		}
+		publishLatency := time.Now().UnixNano() - startPublish
+
 		if err != nil {
-			log.Printf("Error publishing to channel %s: %v", ch, err)
+			log.Printf("Publisher %s: error publishing to channel %s: %v", clientName, ch, err)
+			// Don't send metrics on error, but still count the message attempt
+		} else {
+			// Send metrics to channels
+			publishLatencyChannel <- publishLatency
+			subscriberCountChannel <- subscriberCount
+			if verbose {
+				log.Printf("Published to %s: %d subscribers, latency: %d ns", ch, subscriberCount, publishLatency)
+			}
 		}
 		atomic.AddUint64(&totalMessages, 1)
 	}
@@ -205,19 +219,42 @@ func subscriberRoutine(clientName, mode string, channels []string, verbose bool,
 		// Handle messages
 		msg, err := pubsub.ReceiveMessage(ctx)
 		if err != nil {
-			// Handle Redis connection errors, e.g., reconnect immediately
+			// Handle Redis connection errors
 			if err == redis.Nil || err == context.DeadlineExceeded || err == context.Canceled {
 				continue
 			}
-			panic(err)
+			// Connection error (EOF, network error, etc.) - attempt to reconnect
+			log.Printf("Subscriber %s: connection error: %v - attempting to reconnect\n", clientName, err)
+
+			// Close the bad connection
+			if pubsub != nil {
+				pubsub.Close()
+				atomic.AddInt64(&totalSubscribedChannels, int64(-len(channels)))
+			}
+
+			// Wait a bit before reconnecting
+			time.Sleep(100 * time.Millisecond)
+
+			// Resubscribe
+			switch mode {
+			case "ssubscribe":
+				pubsub = client.SSubscribe(ctx, channels...)
+			default:
+				pubsub = client.Subscribe(ctx, channels...)
+			}
+			atomic.AddInt64(&totalSubscribedChannels, int64(len(channels)))
+			atomic.AddUint64(&totalConnects, 1)
+
+			log.Printf("Subscriber %s: reconnected successfully\n", clientName)
+			continue
 		}
 		if verbose {
 			log.Println(fmt.Sprintf("received message in channel %s. Message: %s", msg.Channel, msg.Payload))
 		}
 		if measureRTT {
-			now := time.Now().UnixMicro()
+			now := time.Now().UnixNano()
 			// Extract timestamp from payload (format: "<timestamp> <padding>" or just "<timestamp>")
-			// Timestamp is always 13 bytes (UnixMilli)
+			// Timestamp is always 19 bytes (UnixNano)
 			timestampStr := msg.Payload
 			if len(msg.Payload) > timestampSize {
 				timestampStr = msg.Payload[:timestampSize]
@@ -226,7 +263,7 @@
 				rtt := now - ts
 				rttLatencyChannel <- rtt
 				if verbose {
-					log.Printf("RTT measured: %d ms\n", rtt/1000)
+					log.Printf("RTT measured: %d ns\n", rtt)
 				}
 			} else {
 				log.Printf("Invalid timestamp in message: %s, err: %v\n", timestampStr, err)
@@ -244,7 +281,7 @@ func main() {
 	rps := flag.Int64("rps", 0, "Max rps for publisher mode. If 0 no limit is applied and the DB is stressed up to maximum.")
 	rpsburst := flag.Int64("rps-burst", 0, "Max rps burst for publisher mode. If 0 the allowed burst will be the amount of clients.")
 	password := flag.String("a", "", "Password for Redis Auth.")
-	dataSize := flag.Int("data-size", 128, "Payload size in bytes. In RTT mode, timestamp (13 bytes) + space + padding to reach this size.")
+	dataSize := flag.Int("data-size", 128, "Payload size in bytes. In RTT mode, timestamp (19 bytes) + space + padding to reach this size.")
 	mode := flag.String("mode", "subscribe", "Mode: 'subscribe', 'ssubscribe', 'publish', or 'spublish'.")
 	username := flag.String("user", "", "Used to send ACL style 'AUTH username pass'. Needs -a.")
 	subscribers_placement := flag.String("subscribers-placement-per-channel", "dense", "(dense,sparse) dense - Place all subscribers to channel in a specific shard. sparse- spread the subscribers across as many shards possible, in a round-robin manner.")
@@ -414,7 +451,9 @@ func main() {
 		}
 		pprof.StartCPUProfile(f)
 	}
-	rttLatencyChannel := make(chan int64, 100000) // Channel for RTT measurements. buffer of 100K messages to process
+	rttLatencyChannel := make(chan int64, 1000000)      // Channel for RTT measurements. buffer of 1M messages to process
+	publishLatencyChannel := make(chan int64, 1000000)  // Channel for publish latency measurements
+	subscriberCountChannel := make(chan int64, 1000000) // Channel for subscriber count tracking
 	totalCreatedClients := 0
 	if strings.Contains(*mode, "publish") {
 		var requestRate = Inf
@@ -472,7 +511,7 @@
 			}
 
 			wg.Add(1)
-			go publisherRoutine(publisherName, channels, *mode, *measureRTT, *verbose, *dataSize, ctx, &wg, client, useRateLimiter, rateLimiter)
+			go publisherRoutine(publisherName, channels, *mode, *measureRTT, *verbose, *dataSize, ctx, &wg, client, useRateLimiter, rateLimiter, publishLatencyChannel, subscriberCountChannel)
 			atomic.AddInt64(&totalPublishers, 1)
 			atomic.AddUint64(&totalConnects, 1)
 		}
@@ -548,7 +587,7 @@ func main() {
 	w := new(tabwriter.Writer)
 	tick := time.NewTicker(time.Duration(*client_update_tick) * time.Second)
-	closed, start_time, duration, totalMessages, messageRateTs, rttValues := updateCLI(tick, c, total_messages, w, *test_time, *measureRTT, *mode, rttLatencyChannel, *verbose)
+	closed, start_time, duration, totalMessages, messageRateTs, rttValues, publishLatencyValues, subscriberCountValues := updateCLI(tick, c, total_messages, w, *test_time, *measureRTT, *mode, rttLatencyChannel, publishLatencyChannel, subscriberCountChannel, *verbose)
 	messageRate := float64(totalMessages) / float64(duration.Seconds())
 
 	if *cpuprofile != "" {
@@ -558,22 +597,60 @@ func main() {
 	fmt.Fprintf(w, "Mode: %s\n", *mode)
 	fmt.Fprintf(w, "Total Duration: %f Seconds\n", duration.Seconds())
 	fmt.Fprintf(w, "Message Rate: %f msg/sec\n", messageRate)
-	if *measureRTT && (*mode != "publish" && *mode != "spublish") {
-		hist := hdrhistogram.New(1, 10_000_000, 3) // 1us to 10s, 3 sig digits
-		for _, rtt := range rttValues {
-			_ = hist.RecordValue(rtt)
+
+	if strings.Contains(*mode, "publish") {
+		// Publisher mode: show publish latency and subscriber count stats
+		if len(publishLatencyValues) > 0 {
+			hist := hdrhistogram.New(1, 10_000_000_000, 3) // 1ns to 10s, 3 sig digits
+			for _, latency := range publishLatencyValues {
+				_ = hist.RecordValue(latency)
+			}
+			avg := hist.Mean()
+			p50 := hist.ValueAtQuantile(50.0)
+			p95 := hist.ValueAtQuantile(95.0)
+			p99 := hist.ValueAtQuantile(99.0)
+			p999 := hist.ValueAtQuantile(99.9)
+			fmt.Fprintf(w, "Avg Publish Latency %.3f ms\n", avg/1000000.0)
+			fmt.Fprintf(w, "P50 Publish Latency %.3f ms\n", float64(p50)/1000000.0)
+			fmt.Fprintf(w, "P95 Publish Latency %.3f ms\n", float64(p95)/1000000.0)
+			fmt.Fprintf(w, "P99 Publish Latency %.3f ms\n", float64(p99)/1000000.0)
+			fmt.Fprintf(w, "P999 Publish Latency %.3f ms\n", float64(p999)/1000000.0)
+		}
+
+		if len(subscriberCountValues) > 0 {
+			hist := hdrhistogram.New(0, 1_000_000, 3) // 0 to 1M subscribers, 3 sig digits
+			for _, count := range subscriberCountValues {
+				_ = hist.RecordValue(count)
+			}
+			avg := hist.Mean()
+			p50 := hist.ValueAtQuantile(50.0)
+			p95 := hist.ValueAtQuantile(95.0)
+			p99 := hist.ValueAtQuantile(99.0)
+			p999 := hist.ValueAtQuantile(99.9)
+			fmt.Fprintf(w, "Avg Subscribers %.1f (per-node in cluster mode)\n", avg)
+			fmt.Fprintf(w, "P50 Subscribers %d\n", p50)
+			fmt.Fprintf(w, "P95 Subscribers %d\n", p95)
+			fmt.Fprintf(w, "P99 Subscribers %d\n", p99)
+			fmt.Fprintf(w, "P999 Subscribers %d\n", p999)
+		}
+	} else if *measureRTT {
+		// Subscriber mode with RTT measurement
+		if len(rttValues) > 0 {
+			hist := hdrhistogram.New(1, 10_000_000_000, 3) // 1ns to 10s, 3 sig digits
+			for _, rtt := range rttValues {
+				_ = hist.RecordValue(rtt)
+			}
+			avg := hist.Mean()
+			p50 := hist.ValueAtQuantile(50.0)
+			p95 := hist.ValueAtQuantile(95.0)
+			p99 := hist.ValueAtQuantile(99.0)
+			p999 := hist.ValueAtQuantile(99.9)
+			fmt.Fprintf(w, "Avg RTT %.3f ms\n", avg/1000000.0)
+			fmt.Fprintf(w, "P50 RTT %.3f ms\n", float64(p50)/1000000.0)
+			fmt.Fprintf(w, "P95 RTT %.3f ms\n", float64(p95)/1000000.0)
+			fmt.Fprintf(w, "P99 RTT %.3f ms\n", float64(p99)/1000000.0)
+			fmt.Fprintf(w, "P999 RTT %.3f ms\n", float64(p999)/1000000.0)
 		}
-		avg := hist.Mean()
-		p50 := hist.ValueAtQuantile(50.0)
-		p95 := hist.ValueAtQuantile(95.0)
-		p99 := hist.ValueAtQuantile(99.0)
-		p999 := hist.ValueAtQuantile(99.9)
-		fmt.Fprintf(w, "Avg RTT %.3f ms\n", avg/1000.0)
-		fmt.Fprintf(w, "P50 RTT %.3f ms\n", float64(p50)/1000.0)
-		fmt.Fprintf(w, "P95 RTT %.3f ms\n", float64(p95)/1000.0)
-		fmt.Fprintf(w, "P99 RTT %.3f ms\n", float64(p99)/1000.0)
-		fmt.Fprintf(w, "P999 RTT %.3f ms\n", float64(p999)/1000.0)
-	} else {
 	}
 	fmt.Fprintf(w, "#################################################\n")
 	fmt.Fprint(w, "\r\n")
@@ -656,8 +733,10 @@ func updateCLI(
 	measureRTT bool,
 	mode string,
 	rttLatencyChannel chan int64,
+	publishLatencyChannel chan int64,
+	subscriberCountChannel chan int64,
 	verbose bool,
-) (bool, time.Time, time.Duration, uint64, []float64, []int64) {
+) (bool, time.Time, time.Duration, uint64, []float64, []int64, []int64, []int64) {
 
 	start := time.Now()
 	prevTime := time.Now()
@@ -666,27 +745,28 @@ func updateCLI(
 	messageRateTs := []float64{}
 	tickRttValues := []int64{}
 	rttValues := []int64{}
+	tickPublishLatencyValues := []int64{}
+	publishLatencyValues := []int64{}
+	tickSubscriberCountValues := []int64{}
+	subscriberCountValues := []int64{}
 	w.Init(os.Stdout, 25, 0, 1, ' ', tabwriter.AlignRight)
 
 	// Header
-	if measureRTT {
-		fmt.Fprint(w, "Test Time\tTotal Messages\t Message Rate \tConnect Rate \t")
+	fmt.Fprint(w, "Test Time\tTotal Messages\t Message Rate \tConnect Rate \t")
 
-		if strings.Contains(mode, "subscribe") {
-			fmt.Fprint(w, "Active subscriptions\t")
-		} else {
-			fmt.Fprint(w, "Active publishers\t")
+	if strings.Contains(mode, "subscribe") {
+		fmt.Fprint(w, "Active subscriptions\t")
+		if measureRTT {
+			fmt.Fprint(w, "Avg RTT (ms)\t")
 		}
-		fmt.Fprint(w, "Avg RTT (ms)\t\n")
 	} else {
-		fmt.Fprint(w, "Test Time\tTotal Messages\t Message Rate \tConnect Rate \t")
-		if strings.Contains(mode, "subscribe") {
-			fmt.Fprint(w, "Active subscriptions\t\n")
-		} else {
-			fmt.Fprint(w, "Active publishers\t\n")
-		}
+		// Publisher mode
+		fmt.Fprint(w, "Active publishers\t")
+		fmt.Fprint(w, "Pub Latency (ms)\t")
+		fmt.Fprint(w, "Avg Subs per channel per node\t")
 	}
+	fmt.Fprint(w, "\n")
 	w.Flush()
 
 	// Main loop
@@ -696,6 +776,14 @@
 			rttValues = append(rttValues, rtt)
 			tickRttValues = append(tickRttValues, rtt)
 
+		case publishLatency := <-publishLatencyChannel:
+			publishLatencyValues = append(publishLatencyValues, publishLatency)
+			tickPublishLatencyValues = append(tickPublishLatencyValues, publishLatency)
+
+		case subscriberCount := <-subscriberCountChannel:
+			subscriberCountValues = append(subscriberCountValues, subscriberCount)
+			tickSubscriberCountValues = append(tickSubscriberCountValues, subscriberCount)
+
 		case <-tick.C:
 			now := time.Now()
 			took := now.Sub(prevTime)
@@ -725,7 +813,7 @@
 				if verbose {
 					fmt.Printf("[DEBUG] Test time reached! Stopping after %.2f seconds\n", elapsed.Seconds())
 				}
-				return true, start, time.Since(start), totalMessages, messageRateTs, rttValues
+				return true, start, time.Since(start), totalMessages, messageRateTs, rttValues, publishLatencyValues, subscriberCountValues
 			}
 		}
@@ -738,14 +826,43 @@
 				fmt.Fprintf(w, "%d\t", atomic.LoadInt64(&totalPublishers))
 			}
 
-			if measureRTT {
+			// For publisher mode, show publish latency instead of RTT
+			if strings.Contains(mode, "publish") {
+				var avgPublishLatencyMs float64
+				if len(tickPublishLatencyValues) > 0 {
+					var total int64
+					for _, v := range tickPublishLatencyValues {
+						total += v
+					}
+					avgPublishLatencyMs = float64(total) / float64(len(tickPublishLatencyValues)) / 1000000.0
+					tickPublishLatencyValues = tickPublishLatencyValues[:0]
+					fmt.Fprintf(w, "%.3f\t", avgPublishLatencyMs)
+				} else {
+					fmt.Fprintf(w, "--\t")
+				}
+
+				// Show average subscriber count
+				var avgSubscriberCount float64
+				if len(tickSubscriberCountValues) > 0 {
+					var total int64
+					for _, v := range tickSubscriberCountValues {
+						total += v
+					}
+					avgSubscriberCount = float64(total) / float64(len(tickSubscriberCountValues))
+					tickSubscriberCountValues = tickSubscriberCountValues[:0]
+					fmt.Fprintf(w, "%.1f\t", avgSubscriberCount)
+				} else {
+					fmt.Fprintf(w, "--\t")
+				}
+			} else if measureRTT {
+				// For subscriber mode with RTT measurement
 				var avgRTTms float64
 				if len(tickRttValues) > 0 {
 					var total int64
 					for _, v := range tickRttValues {
 						total += v
 					}
-					avgRTTms = float64(total) / float64(len(tickRttValues)) / 1000.0
+					avgRTTms = float64(total) / float64(len(tickRttValues)) / 1000000.0
 					tickRttValues = tickRttValues[:0]
 					fmt.Fprintf(w, "%.3f\t", avgRTTms)
 				} else {
@@ -757,12 +874,12 @@
 			w.Flush()
 
 			if message_limit > 0 && totalMessages >= uint64(message_limit) {
-				return true, start, time.Since(start), totalMessages, messageRateTs, rttValues
+				return true, start, time.Since(start), totalMessages, messageRateTs, rttValues, publishLatencyValues, subscriberCountValues
 			}
 
 		case <-c:
 			fmt.Println("received Ctrl-c - shutting down")
-			return true, start, time.Since(start), totalMessages, messageRateTs, rttValues
+			return true, start, time.Since(start), totalMessages, messageRateTs, rttValues, publishLatencyValues, subscriberCountValues
 		}
 	}
 }
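A note on the histogram bounds used above: hdrhistogram-go rejects samples above the configured maximum, so once recorded values switch from microseconds to nanoseconds the upper bound has to grow by the same factor of 1000 (10_000_000_000 ns = 10 s). A small self-contained illustration:

```go
package main

import (
	"fmt"
	"time"

	hdrhistogram "github.com/HdrHistogram/hdrhistogram-go"
)

func main() {
	// Track nanosecond latencies from 1 ns to 10 s at 3 significant digits.
	hist := hdrhistogram.New(1, 10_000_000_000, 3)

	samples := []time.Duration{250 * time.Microsecond, 3 * time.Millisecond, 40 * time.Millisecond}
	for _, d := range samples {
		// RecordValue returns an error for out-of-range values; with an upper
		// bound of only 10_000_000 (10 ms in ns) the 40 ms sample would be dropped.
		if err := hist.RecordValue(d.Nanoseconds()); err != nil {
			fmt.Println("dropped sample:", err)
		}
	}

	fmt.Printf("p50=%.3f ms p99=%.3f ms\n",
		float64(hist.ValueAtQuantile(50))/1e6,
		float64(hist.ValueAtQuantile(99))/1e6)
}
```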
diff --git a/subscriber_test.go b/subscriber_test.go
new file mode 100644
index 0000000..477fc50
--- /dev/null
+++ b/subscriber_test.go
@@ -0,0 +1,444 @@
+package main
+
+import (
+	"bytes"
+	"context"
+	"fmt"
+	"os"
+	"os/exec"
+	"syscall"
+	"testing"
+	"time"
+
+	"github.com/redis/go-redis/v9"
+)
+
+// TestMain sets up the test environment
+func TestMain(m *testing.M) {
+	// Create coverage directory
+	coverDir := ".coverdata"
+	os.MkdirAll(coverDir, 0755)
+
+	// Check if binary exists (should be built by make)
+	if _, err := os.Stat("./pubsub-sub-bench"); err != nil {
+		fmt.Fprintf(os.Stderr, "Binary ./pubsub-sub-bench not found. Run 'make build-cover' first.\n")
+		os.Exit(1)
+	}
+
+	// Run tests
+	exitCode := m.Run()
+
+	os.Exit(exitCode)
+}
+
+func getBinaryPath() string {
+	// Use the binary built by make
+	return "./pubsub-sub-bench"
+}
+
+func getTestConnectionDetails() (string, string) {
+	value, exists := os.LookupEnv("REDIS_TEST_HOST")
+	host := "127.0.0.1"
+	port := "6379"
+	password := ""
+	valuePassword, existsPassword := os.LookupEnv("REDIS_TEST_PASSWORD")
+	if exists && value != "" {
+		host = value
+	}
+	valuePort, existsPort := os.LookupEnv("REDIS_TEST_PORT")
+	if existsPort && valuePort != "" {
+		port = valuePort
+	}
+	if existsPassword && valuePassword != "" {
+		password = valuePassword
+	}
+	return host + ":" + port, password
+}
+
+func TestSubscriberMode(t *testing.T) {
+	var tests = []struct {
+		name         string
+		wantExitCode int
+		args         []string
+		timeout      time.Duration
+	}{
+		{
+			"simple subscribe",
+			0,
+			[]string{
+				"--host", "127.0.0.1",
+				"--port", "6379",
+				"--mode", "subscribe",
+				"--clients", "2",
+				"--channel-minimum", "1",
+				"--channel-maximum", "2",
+			},
+			2 * time.Second, // Just verify it can connect and subscribe
+		},
+		{
+			"ssubscribe mode",
+			0,
+			[]string{
+				"--host", "127.0.0.1",
+				"--port", "6379",
+				"--mode", "ssubscribe",
+				"--clients", "2",
+				"--channel-minimum", "1",
+				"--channel-maximum", "2",
+			},
+			2 * time.Second,
+		},
+		{
+			"subscribe with RTT",
+			0,
+			[]string{
+				"--host", "127.0.0.1",
+				"--port", "6379",
+				"--mode", "subscribe",
+				"--clients", "2",
+				"--channel-minimum", "1",
+				"--channel-maximum", "2",
+				"--measure-rtt-latency",
+			},
+			2 * time.Second,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			cmd := exec.Command(getBinaryPath(), tt.args...)
+			cmd.Env = os.Environ()
+			cmd.Env = append(cmd.Env, "GOCOVERDIR=.coverdata")
+			var out bytes.Buffer
+			cmd.Stdout = &out
+			cmd.Stderr = &out
+
+			// Start the command
+			err := cmd.Start()
+			if err != nil {
+				t.Fatalf("Failed to start command: %v", err)
+			}
+
+			// Wait for timeout, then kill
+			time.Sleep(tt.timeout)
+			cmd.Process.Signal(os.Interrupt)
+
+			// Wait for process to finish
+			err = cmd.Wait()
+			exitCode := 0
+			if err != nil {
+				if exitError, ok := err.(*exec.ExitError); ok {
+					ws := exitError.Sys().(syscall.WaitStatus)
+					exitCode = ws.ExitStatus()
+				}
+			}
+
+			if exitCode != tt.wantExitCode {
+				t.Errorf("got exit code = %v, want %v\nOutput: %s", exitCode, tt.wantExitCode, out.String())
+			}
+		})
+	}
+}
+
+func TestPublisherMode(t *testing.T) {
+	hostPort, password := getTestConnectionDetails()
+
+	// Create a Redis client for verification
+	client := redis.NewClient(&redis.Options{
+		Addr:     hostPort,
+		Password: password,
+		DB:       0,
+	})
+	defer client.Close()
+
+	ctx := context.Background()
+
+	// Test connection
+	if err := client.Ping(ctx).Err(); err != nil {
+		t.Skipf("Redis not available at %s: %v", hostPort, err)
+	}
+
+	var tests = []struct {
+		name         string
+		wantExitCode int
+		args         []string
+	}{
+		{
+			"simple publish",
+			0,
+			[]string{
+				"--host", "127.0.0.1",
+				"--port", "6379",
+				"--mode", "publish",
+				"--clients", "2",
+				"--channel-minimum", "1",
+				"--channel-maximum", "2",
+				"--test-time", "1",
+				"--data-size", "128",
+			},
+		},
+		{
+			"publish with rate limit",
+			0,
+			[]string{
+				"--host", "127.0.0.1",
+				"--port", "6379",
+				"--mode", "publish",
+				"--clients", "2",
+				"--channel-minimum", "1",
+				"--channel-maximum", "2",
+				"--test-time", "1",
+				"--rps", "100",
+				"--data-size", "256",
+			},
+		},
+ "--mode", "publish", + "--clients", "2", + "--channel-minimum", "1", + "--channel-maximum", "2", + "--test-time", "1", + "--measure-rtt-latency", + "--data-size", "512", + }, + }, + { + "spublish mode", + 0, + []string{ + "--host", "127.0.0.1", + "--port", "6379", + "--mode", "spublish", + "--clients", "2", + "--channel-minimum", "1", + "--channel-maximum", "2", + "--test-time", "1", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := exec.Command(getBinaryPath(), tt.args...) + cmd.Env = os.Environ() + cmd.Env = append(cmd.Env, "GOCOVERDIR=.coverdata") + var out bytes.Buffer + cmd.Stdout = &out + cmd.Stderr = &out + + // Run the command and wait for it to complete (--test-time will make it exit) + err := cmd.Run() + exitCode := 0 + if err != nil { + if exitError, ok := err.(*exec.ExitError); ok { + ws := exitError.Sys().(syscall.WaitStatus) + exitCode = ws.ExitStatus() + } + } + + if exitCode != tt.wantExitCode { + t.Errorf("got exit code = %v, want %v\nOutput: %s", exitCode, tt.wantExitCode, out.String()) + } + }) + } +} + +func TestPublisherSubscriberIntegration(t *testing.T) { + hostPort, password := getTestConnectionDetails() + + // Create a Redis client for verification + client := redis.NewClient(&redis.Options{ + Addr: hostPort, + Password: password, + DB: 0, + }) + defer client.Close() + + ctx := context.Background() + + // Test connection + if err := client.Ping(ctx).Err(); err != nil { + t.Skipf("Redis not available at %s: %v", hostPort, err) + } + + t.Run("publisher and subscriber together", func(t *testing.T) { + // Start subscriber first + subCmd := exec.Command(getBinaryPath(), + "--host", "127.0.0.1", + "--port", "6379", + "--mode", "subscribe", + "--clients", "2", + "--channel-minimum", "1", + "--channel-maximum", "2", + "--test-time", "2", + ) + subCmd.Env = os.Environ() + subCmd.Env = append(subCmd.Env, "GOCOVERDIR=.coverdata") + var subOut bytes.Buffer + subCmd.Stdout = &subOut + subCmd.Stderr = &subOut + + err := subCmd.Start() + if err != nil { + t.Fatalf("Failed to start subscriber: %v", err) + } + + // Give subscriber time to connect + time.Sleep(500 * time.Millisecond) + + // Start publisher (will run for 1 second and exit) + pubCmd := exec.Command(getBinaryPath(), + "--host", "127.0.0.1", + "--port", "6379", + "--mode", "publish", + "--clients", "1", + "--channel-minimum", "1", + "--channel-maximum", "2", + "--test-time", "1", + "--rps", "50", + "--data-size", "128", + ) + pubCmd.Env = os.Environ() + pubCmd.Env = append(pubCmd.Env, "GOCOVERDIR=.coverdata") + var pubOut bytes.Buffer + pubCmd.Stdout = &pubOut + pubCmd.Stderr = &pubOut + + // Run publisher and wait for it to complete + err = pubCmd.Run() + pubExitCode := 0 + if err != nil { + if exitError, ok := err.(*exec.ExitError); ok { + ws := exitError.Sys().(syscall.WaitStatus) + pubExitCode = ws.ExitStatus() + } + } + + // Stop subscriber + time.Sleep(500 * time.Millisecond) + subCmd.Process.Signal(os.Interrupt) + err = subCmd.Wait() + subExitCode := 0 + if err != nil { + if exitError, ok := err.(*exec.ExitError); ok { + ws := exitError.Sys().(syscall.WaitStatus) + subExitCode = ws.ExitStatus() + } + } + + if pubExitCode != 0 { + t.Errorf("Publisher exit code = %v, want 0\nOutput: %s", pubExitCode, pubOut.String()) + } + if subExitCode != 0 { + t.Errorf("Subscriber exit code = %v, want 0\nOutput: %s", subExitCode, subOut.String()) + } + + t.Logf("Subscriber output:\n%s", subOut.String()) + t.Logf("Publisher output:\n%s", pubOut.String()) + }) +} + +func TestErrorCases(t 
+func TestErrorCases(t *testing.T) {
+	var tests = []struct {
+		name         string
+		wantExitCode int
+		args         []string
+	}{
+		{
+			"invalid mode",
+			1,
+			[]string{
+				"--host", "127.0.0.1",
+				"--port", "6379",
+				"--mode", "invalid_mode",
+			},
+		},
+		{
+			"invalid port",
+			1,
+			[]string{
+				"--host", "127.0.0.1",
+				"--port", "99999",
+				"--mode", "subscribe",
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			cmd := exec.Command(getBinaryPath(), tt.args...)
+			cmd.Env = os.Environ()
+			cmd.Env = append(cmd.Env, "GOCOVERDIR=.coverdata")
+			var out bytes.Buffer
+			cmd.Stdout = &out
+			cmd.Stderr = &out
+
+			err := cmd.Run()
+			exitCode := 0
+			if err != nil {
+				if exitError, ok := err.(*exec.ExitError); ok {
+					ws := exitError.Sys().(syscall.WaitStatus)
+					exitCode = ws.ExitStatus()
+				}
+			}
+
+			// For error cases, we expect non-zero exit code
+			if tt.wantExitCode != 0 && exitCode == 0 {
+				t.Errorf("expected non-zero exit code, got 0\nOutput: %s", out.String())
+			} else if tt.wantExitCode == 0 && exitCode != 0 {
+				t.Errorf("expected exit code 0, got %d\nOutput: %s", exitCode, out.String())
+			}
+		})
+	}
+}
+
+func TestDataSizeVariations(t *testing.T) {
+	var tests = []struct {
+		name         string
+		dataSize     string
+		wantExitCode int
+	}{
+		{"small payload 64 bytes", "64", 0},
+		{"medium payload 512 bytes", "512", 0},
+		{"large payload 4096 bytes", "4096", 0},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			cmd := exec.Command(getBinaryPath(),
+				"--host", "127.0.0.1",
+				"--port", "6379",
+				"--mode", "publish",
+				"--clients", "1",
+				"--channel-minimum", "1",
+				"--channel-maximum", "1",
+				"--test-time", "1",
+				"--data-size", tt.dataSize,
+			)
+			cmd.Env = os.Environ()
+			cmd.Env = append(cmd.Env, "GOCOVERDIR=.coverdata")
+			var out bytes.Buffer
+			cmd.Stdout = &out
+			cmd.Stderr = &out
+
+			// Run the command and wait for it to complete (--test-time will make it exit)
+			err := cmd.Run()
+			exitCode := 0
+			if err != nil {
+				if exitError, ok := err.(*exec.ExitError); ok {
+					ws := exitError.Sys().(syscall.WaitStatus)
+					exitCode = ws.ExitStatus()
+				}
+			}
+
+			if exitCode != tt.wantExitCode {
+				t.Errorf("got exit code = %v, want %v\nOutput: %s", exitCode, tt.wantExitCode, out.String())
+			}
+		})
+	}
+}
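One possible follow-up for these tests: the exit-code plumbing via `syscall.WaitStatus` compiles only on Unix-like systems. Since Go 1.12, `os/exec` exposes the same information portably; a sketch of the alternative (not part of this patch):

```go
package main

import (
	"fmt"
	"os/exec"
)

func main() {
	err := exec.Command("false").Run()

	// (*exec.ExitError).ExitCode() works on every platform and returns -1 when
	// the process was killed by a signal, which callers may want to treat
	// differently from a non-zero exit.
	exitCode := 0
	if err != nil {
		if exitError, ok := err.(*exec.ExitError); ok {
			exitCode = exitError.ExitCode()
		}
	}
	fmt.Println("exit code:", exitCode)
}
```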