Skip to content

Commit aa9edb8

Browse files
Signal Handling with Graceful Shutdown with metrics (#325)
* Signal Handling with Graceful Shutdown with metrics. * Fixes per PR review: Interrupt logic is now encapsulated in the client_group class where it belongs.
1 parent 09052d3 commit aa9edb8

File tree

6 files changed

+161
-0
lines changed

6 files changed

+161
-0
lines changed

client.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,32 @@ void client_group::run(void)
646646
event_base_dispatch(m_base);
647647
}
648648

649+
void client_group::interrupt(void)
650+
{
651+
// Mark all clients as interrupted
652+
set_all_clients_interrupted();
653+
// Break the event loop to stop processing
654+
event_base_loopbreak(m_base);
655+
// Set end time for all clients as close as possible to the loop break
656+
finalize_all_clients();
657+
}
658+
659+
void client_group::finalize_all_clients(void)
660+
{
661+
for (std::vector<client*>::iterator i = m_clients.begin(); i != m_clients.end(); i++) {
662+
client* c = *i;
663+
c->set_end_time();
664+
}
665+
}
666+
667+
void client_group::set_all_clients_interrupted(void)
668+
{
669+
for (std::vector<client*>::iterator i = m_clients.begin(); i != m_clients.end(); i++) {
670+
client* c = *i;
671+
c->get_stats()->set_interrupted(true);
672+
}
673+
}
674+
649675
unsigned long int client_group::get_total_bytes(void)
650676
{
651677
unsigned long int total_bytes = 0;

client.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,9 @@ class client_group {
210210
int create_clients(int count);
211211
int prepare(void);
212212
void run(void);
213+
void interrupt(void);
214+
void finalize_all_clients(void);
215+
void set_all_clients_interrupted(void);
213216

214217
void write_client_stats(const char *prefix);
215218

memtier_benchmark.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include <errno.h>
3333
#include <sys/time.h>
3434
#include <sys/resource.h>
35+
#include <signal.h>
3536

3637
#ifdef USE_TLS
3738
#include <openssl/crypto.h>
@@ -65,6 +66,16 @@
6566

6667

6768
static int log_level = 0;
69+
70+
// Global flag for signal handling
71+
static volatile sig_atomic_t g_interrupted = 0;
72+
73+
// Signal handler for Ctrl+C
74+
static void sigint_handler(int signum)
75+
{
76+
(void)signum; // unused parameter
77+
g_interrupted = 1;
78+
}
6879
void benchmark_log_file_line(int level, const char *filename, unsigned int line, const char *fmt, ...)
6980
{
7081
if (level > log_level)
@@ -1329,6 +1340,25 @@ run_stats run_benchmark(int run_id, benchmark_config* cfg, object_generator* obj
13291340
active_threads = 0;
13301341
sleep(1);
13311342

1343+
// Check for Ctrl+C interrupt
1344+
if (g_interrupted) {
1345+
// Calculate elapsed time before interrupting
1346+
unsigned long int elapsed_duration = 0;
1347+
unsigned int thread_counter = 0;
1348+
for (std::vector<cg_thread*>::iterator i = threads.begin(); i != threads.end(); i++) {
1349+
thread_counter++;
1350+
float factor = ((float)(thread_counter - 1) / thread_counter);
1351+
elapsed_duration = factor * elapsed_duration + (float)(*i)->m_cg->get_duration_usec() / thread_counter;
1352+
}
1353+
fprintf(stderr, "\n[RUN #%u] Interrupted by user (Ctrl+C) after %.1f secs, stopping threads...\n",
1354+
run_id, (float)elapsed_duration / 1000000);
1355+
// Interrupt all threads (marks clients as interrupted, breaks event loops, and finalizes stats)
1356+
for (std::vector<cg_thread*>::iterator i = threads.begin(); i != threads.end(); i++) {
1357+
(*i)->m_cg->interrupt();
1358+
}
1359+
break;
1360+
}
1361+
13321362
unsigned long int total_ops = 0;
13331363
unsigned long int total_bytes = 0;
13341364
unsigned long int duration = 0;
@@ -1496,6 +1526,9 @@ static void cleanup_openssl(void)
14961526

14971527
int main(int argc, char *argv[])
14981528
{
1529+
// Install signal handler for Ctrl+C
1530+
signal(SIGINT, sigint_handler);
1531+
14991532
benchmark_config cfg = benchmark_config();
15001533
cfg.arbitrary_commands = new arbitrary_command_list();
15011534

run_stats.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ inline timeval timeval_factorial_average(timeval a, timeval b, unsigned int weig
112112

113113
run_stats::run_stats(benchmark_config *config) :
114114
m_config(config),
115+
m_interrupted(false),
115116
m_totals(),
116117
m_cur_stats(0)
117118
{
@@ -792,6 +793,11 @@ void run_stats::merge(const run_stats& other, int iteration)
792793
m_start_time = timeval_factorial_average( m_start_time, other.m_start_time, iteration );
793794
m_end_time = timeval_factorial_average( m_end_time, other.m_end_time, iteration );
794795

796+
// If any run was interrupted, mark the merged result as interrupted
797+
if (other.m_interrupted) {
798+
m_interrupted = true;
799+
}
800+
795801
// aggregate the one_second_stats vectors. this is not efficient
796802
// but it's not really important (small numbers, not realtime)
797803
for (std::list<one_second_stats>::const_iterator other_i = other.m_stats.begin();
@@ -1221,6 +1227,7 @@ void run_stats::print_json(json_handler *jsonhandler, arbitrary_command_list& co
12211227
jsonhandler->write_obj("Finish time","%lld", end_time_ms);
12221228
jsonhandler->write_obj("Total duration","%lld", end_time_ms-start_time_ms);
12231229
jsonhandler->write_obj("Time unit","\"%s\"","MILLISECONDS");
1230+
jsonhandler->write_obj("Interrupted","\"%s\"", m_interrupted ? "true" : "false");
12241231
jsonhandler->close_nesting();
12251232
}
12261233
std::vector<unsigned int> timestamps = get_one_sec_cmd_stats_timestamp();

run_stats.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ class run_stats {
9393

9494
struct timeval m_start_time;
9595
struct timeval m_end_time;
96+
bool m_interrupted;
9697

9798
totals m_totals;
9899

@@ -122,6 +123,8 @@ class run_stats {
122123
void setup_arbitrary_commands(size_t n_arbitrary_commands);
123124
void set_start_time(struct timeval* start_time);
124125
void set_end_time(struct timeval* end_time);
126+
void set_interrupted(bool interrupted) { m_interrupted = interrupted; }
127+
bool get_interrupted() const { return m_interrupted; }
125128

126129
void update_get_op(struct timeval* ts, unsigned int bytes_rx, unsigned int bytes_tx, unsigned int latency, unsigned int hits, unsigned int misses);
127130
void update_set_op(struct timeval* ts, unsigned int bytes_rx, unsigned int bytes_tx, unsigned int latency);

tests/tests_oss_simple_flow.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
import tempfile
22
import json
3+
import time
4+
import signal
5+
import subprocess
6+
import os
37
from include import *
48
from mb import Benchmark, RunConfig
59

@@ -907,3 +911,88 @@ def test_uri_invalid_database(env):
907911
# benchmark.run() should return False for invalid database number
908912
memtier_ok = benchmark.run()
909913
env.assertFalse(memtier_ok)
914+
915+
916+
def test_interrupt_signal_handling(env):
917+
"""Test that Ctrl+C (SIGINT) properly stops the benchmark and outputs correct statistics"""
918+
# Use a large number of requests so the test doesn't finish before we interrupt it
919+
benchmark_specs = {"name": env.testName, "args": ['--requests=1000000', '--hide-histogram']}
920+
addTLSArgs(benchmark_specs, env)
921+
config = get_default_memtier_config(threads=4, clients=50, requests=1000000)
922+
master_nodes_list = env.getMasterNodesList()
923+
924+
add_required_env_arguments(benchmark_specs, config, env, master_nodes_list)
925+
926+
# Create a temporary directory
927+
test_dir = tempfile.mkdtemp()
928+
config = RunConfig(test_dir, env.testName, config, {})
929+
ensure_clean_benchmark_folder(config.results_dir)
930+
931+
benchmark = Benchmark.from_json(config, benchmark_specs)
932+
933+
# Start the benchmark process manually so we can send SIGINT
934+
import logging
935+
logging.debug(' Command: %s', ' '.join(benchmark.args))
936+
937+
stderr_file = open(os.path.join(config.results_dir, 'mb.stderr'), 'wb')
938+
process = subprocess.Popen(
939+
stdin=None, stdout=subprocess.PIPE, stderr=stderr_file,
940+
executable=benchmark.binary, args=benchmark.args)
941+
942+
# Wait 3 seconds then send SIGINT
943+
time.sleep(3)
944+
process.send_signal(signal.SIGINT)
945+
946+
# Wait for process to finish
947+
_stdout, _ = process.communicate()
948+
stderr_file.close()
949+
950+
# Write stdout to file
951+
benchmark.write_file('mb.stdout', _stdout)
952+
953+
# Read stderr to check for interrupt message
954+
with open(os.path.join(config.results_dir, 'mb.stderr'), 'r') as stderr:
955+
stderr_content = stderr.read()
956+
# Check that the interrupt message is present and shows elapsed time
957+
env.assertTrue("Interrupted by user (Ctrl+C) after" in stderr_content)
958+
env.assertTrue("secs, stopping threads..." in stderr_content)
959+
960+
# Check JSON output
961+
json_filename = '{0}/mb.json'.format(config.results_dir)
962+
env.assertTrue(os.path.isfile(json_filename))
963+
964+
with open(json_filename) as results_json:
965+
results_dict = json.load(results_json)
966+
967+
# Check that Runtime section exists and has Interrupted flag
968+
env.assertTrue("ALL STATS" in results_dict)
969+
env.assertTrue("Runtime" in results_dict["ALL STATS"])
970+
runtime = results_dict["ALL STATS"]["Runtime"]
971+
972+
# Verify interrupted flag is set to "true"
973+
env.assertTrue("Interrupted" in runtime)
974+
env.assertEqual(runtime["Interrupted"], "true")
975+
976+
# Verify duration is reasonable (should be around 3 seconds, give or take)
977+
env.assertTrue("Total duration" in runtime)
978+
duration_ms = runtime["Total duration"]
979+
env.assertTrue(duration_ms >= 2000) # At least 2 seconds
980+
env.assertTrue(duration_ms <= 5000) # At most 5 seconds
981+
982+
# Verify that throughput metrics are NOT zero
983+
totals_metrics = results_dict["ALL STATS"]["Totals"]
984+
985+
# Check ops/sec is not zero
986+
env.assertTrue("Ops/sec" in totals_metrics)
987+
total_ops_sec = totals_metrics["Ops/sec"]
988+
env.assertTrue(total_ops_sec > 0)
989+
990+
# Check latency metrics are not zero
991+
env.assertTrue("Latency" in totals_metrics)
992+
total_latency = totals_metrics["Latency"]
993+
env.assertTrue(total_latency > 0)
994+
995+
# Check that we actually processed some operations
996+
env.assertTrue("Count" in totals_metrics)
997+
total_count = totals_metrics["Count"]
998+
env.assertTrue(total_count > 0)

0 commit comments

Comments
 (0)