diff --git a/RLTest/__main__.py b/RLTest/__main__.py index 9bb7143..05b3889 100644 --- a/RLTest/__main__.py +++ b/RLTest/__main__.py @@ -218,6 +218,11 @@ def do_normal_conn(self, line): '--use-slaves', action='store_const', const=True, default=False, help='run env with slaves enabled') +parser.add_argument( + '--replicas-per-shard', default=1, type=int, + help='Number of replicas per shard when --use-slaves is set ' + '(cluster mode only). Default: 1.') + parser.add_argument( '--shards-count', default=1, type=int, help='Number shards in bdb') @@ -516,6 +521,9 @@ def __init__(self): Defaults.logdir = self.args.log_dir Defaults.loglevel = self.args.log_level Defaults.use_slaves = self.args.use_slaves + if self.args.replicas_per_shard < 1: + raise Exception('--replicas-per-shard must be >= 1') + Defaults.replicas_per_shard = self.args.replicas_per_shard Defaults.num_shards = self.args.shards_count Defaults.shards_ports = self.args.shards_ports.split(',') if self.args.shards_ports is not None else None Defaults.cluster_address = self.args.cluster_address diff --git a/RLTest/env.py b/RLTest/env.py index eb630cc..99265ed 100644 --- a/RLTest/env.py +++ b/RLTest/env.py @@ -136,6 +136,7 @@ class Defaults: logdir = None loglevel = None use_slaves = False + replicas_per_shard = 1 num_shards = 1 external_addr = 'localhost:6379' use_unix = False @@ -161,6 +162,7 @@ def getKwargs(self): 'moduleArgs': self.module_args, 'port': self.port, 'useSlaves': self.use_slaves, + 'replicasPerShard': self.replicas_per_shard, 'useAof': self.use_aof, 'useRdbPreamble': self.use_rdb_preamble, 'dbDirPath': self.logdir, @@ -184,7 +186,8 @@ def getKwargs(self): class Env: RTestInstance = None - EnvCompareParams = ['module', 'moduleArgs', 'env', 'useSlaves', 'shardsCount', 'useAof', + EnvCompareParams = ['module', 'moduleArgs', 'env', 'useSlaves', 'replicasPerShard', + 'shardsCount', 'useAof', 'useRdbPreamble', 'forceTcp', 'enableDebugCommand', 'enableProtectedConfigs', 'enableModuleCommand', 'protocol', 'password'] @@ -203,7 +206,7 @@ def __init__(self, testName=None, testDescription=None, module=None, redisEnterpriseBinaryPath=None, noDefaultModuleArgs=False, clusterNodeTimeout = None, freshEnv=False, enableDebugCommand=None, enableModuleCommand=None, enableProtectedConfigs=None, protocol=None, terminateRetries=None, terminateRetrySecs=None, redisConfigFile=None, dualTLS=False, - startupGraceSecs=None): + startupGraceSecs=None, replicasPerShard=None): self.testName = testName if testName else Defaults.curr_test_name if self.testName is None: @@ -220,6 +223,11 @@ def __init__(self, testName=None, testDescription=None, module=None, self.moduleArgs = fix_modulesArgs(self.module, moduleArgs, Defaults.module_args) self.env = env if env else Defaults.env self.useSlaves = useSlaves if useSlaves else Defaults.use_slaves + # Per-test override is rare; default falls back to the global value set + # via --replicas-per-shard. Only meaningful for cluster mode with + # useSlaves; standalone always uses 1 slave. + self.replicasPerShard = (replicasPerShard if replicasPerShard is not None + else Defaults.replicas_per_shard) self.shardsCount = shardsCount if shardsCount else Defaults.num_shards self.decodeResponses = decodeResponses if decodeResponses else Defaults.decode_responses self.useAof = useAof if useAof else Defaults.use_aof @@ -351,6 +359,7 @@ def getEnvKwargs(self): 'modulePath': self.module, 'moduleArgs': self.moduleArgs, 'useSlaves': self.useSlaves, + 'replicasPerShard': self.replicasPerShard, 'decodeResponses': self.decodeResponses, 'useAof': self.useAof, 'useRdbPreamble': self.useRdbPreamble, diff --git a/RLTest/redis_cluster.py b/RLTest/redis_cluster.py index ab2e975..3bfaae9 100644 --- a/RLTest/redis_cluster.py +++ b/RLTest/redis_cluster.py @@ -1,6 +1,6 @@ from __future__ import print_function -from .redis_std import StandardEnv +from .redis_std import StandardEnv, MASTER, SLAVE from redis.cluster import ClusterNode import redis import time @@ -9,6 +9,14 @@ # Interval in seconds between status updates during cluster wait CLUSTER_STATUS_INTERVAL_SEC = 5 +# Brief pause to let master-to-master gossip settle before/after slave MEET. +GOSSIP_SETTLE_SEC = 0.5 +# Max attempts when issuing CLUSTER REPLICATE on a freshly-MEET'd slave; the +# slave may not yet have learned the master's node id via gossip. +REPLICATE_RETRY_MAX = 20 +# Sleep between CLUSTER REPLICATE retries. +REPLICATE_RETRY_INTERVAL_SEC = 0.25 + class ClusterEnv(object): def __init__(self, **kwargs): @@ -20,6 +28,10 @@ def __init__(self, **kwargs): self.password = kwargs['password'] self.shardsCount = kwargs.pop('shardsCount') useSlaves = kwargs.get('useSlaves', False) + # replicasPerShard is the number of replicas attached to each shard's + # master. Default 1 preserves prior single-replica behavior. Only + # honored when useSlaves is True; otherwise treated as 0. + self.replicasPerShard = kwargs.get('replicasPerShard', 1) if useSlaves else 0 self.useTLS = kwargs['useTLS'] self.decodeResponses = kwargs.get('decodeResponses', False) self.tlsPassphrase = kwargs.get('tlsPassphrase', None) @@ -29,14 +41,23 @@ def __init__(self, **kwargs): self.verbose = kwargs.get('verbose', False) self.clusterStartTimeout = kwargs.pop('clusterStartTimeout', 40) startPort = kwargs.pop('port', 10000) - totalRedises = self.shardsCount * (2 if useSlaves else 1) + # Each shard owns one master plus N replicas, so allocate + # shardsCount * (1 + replicasPerShard) total redises when useSlaves is + # set. With replicasPerShard == 1 this equals the previous 2x layout. + instancesPerShard = (1 + self.replicasPerShard) if useSlaves else 1 + totalRedises = self.shardsCount * instancesPerShard randomizePorts = kwargs.pop('randomizePorts', False) - for i in range(0, totalRedises, (2 if useSlaves else 1)): + # Per-shard port stride: must leave room for the master plus all of + # its replicas. The historical stride was 2 (matching 1 master + 1 + # slave); with N replicas it grows to 1 + N. When useSlaves is off the + # stride stays at 2 to match the pre-feature default. + portStride = (1 + self.replicasPerShard) if useSlaves else 2 + for i in range(0, totalRedises, instancesPerShard): port = 0 if randomizePorts else startPort shard = StandardEnv(port=port, serverId=(i + 1), clusterEnabled=True, **kwargs) self.shards.append(shard) - startPort += 2 + startPort += portStride def printEnvData(self, prefix=''): print(Colors.Yellow(prefix + 'info:')) @@ -69,6 +90,32 @@ def _countOk(self): ok += 1 return ok + @staticmethod + def _normalizeSlotsView(slots_view): + """Returns a hashable representation of CLUSTER SLOTS that is invariant + to per-shard replica ordering. + + Each CLUSTER SLOTS row is + [start, end, master_entry, replica_entry, replica_entry, ...] + Different masters can report the replica entries in different orders + when there is more than one replica per shard; that is not a real + disagreement on the topology. We sort replicas within each row before + comparing so multi-replica clusters can converge. + """ + if slots_view is None: + return None + normalized = [] + for row in slots_view: + if len(row) < 3: + normalized.append(tuple(row)) + continue + start, end, master_entry = row[0], row[1], row[2] + replicas = sorted((tuple(r) for r in row[3:]), + key=lambda r: tuple(repr(x) for x in r)) + normalized.append((start, end, tuple(master_entry)) + tuple(replicas)) + normalized.sort() + return tuple(normalized) + def _countAgreeSlots(self): """Returns count of shards that agree on slots view""" ok = 0 @@ -80,26 +127,73 @@ def _countAgreeSlots(self): except Exception as e: print('got error on cluster slots, will try again, %s' % str(e)) continue + normalized = self._normalizeSlotsView(slots_view) if first_view is None: - first_view = slots_view - if slots_view == first_view: + first_view = normalized + if normalized == first_view: ok += 1 return ok + def _expectedReplicasInSlots(self): + """Return total live slaves across all shards. + + Counts from shard.slaveProcesses (process-side), not from CLUSTER SLOTS. + This is robust to non-contiguous slot distributions during partial + migration -- the process count is authoritative regardless of how + CLUSTER SLOTS rows are partitioned. + + Each shard reports one slot-range row, and that row contains one + replica entry per running slave attached to that shard's master. + Returns 0 when no slaves are configured. + """ + total = 0 + for shard in self.shards: + if not getattr(shard, 'useSlaves', False): + continue + for proc in getattr(shard, 'slaveProcesses', []): + if proc is not None: + total += 1 + return total + + def _countReplicasInSlots(self): + """Returns the number of replica entries reported across all shards. + + Queries CLUSTER SLOTS from a master and counts the replica entries + (positions after the first master entry in each slot row). Only the + first master is queried because all masters should agree post-gossip. + """ + if not self.shards: + return 0 + con = self.shards[0].getConnection() + try: + slots_view = con.execute_command('CLUSTER', 'SLOTS') + except Exception as e: + print('got error on cluster slots, will try again, %s' % str(e)) + return 0 + count = 0 + for row in slots_view: + # row = [start, end, master_entry, replica_entry, replica_entry, ...] + if len(row) > 3: + count += len(row) - 3 + return count + def waitCluster(self, timeout_sec=40, verbose=True): st = time.time() last_status_time = st total_shards = len(self.shards) + expected_replicas = self._expectedReplicasInSlots() if verbose: - print(Colors.Yellow('Waiting for cluster to be ready (timeout: %d seconds, %d shards)...' % - (timeout_sec, total_shards))) + print(Colors.Yellow('Waiting for cluster to be ready (timeout: %d seconds, %d shards, %d replicas)...' % + (timeout_sec, total_shards, expected_replicas))) while st + timeout_sec > time.time(): ok_count = self._countOk() slots_count = self._countAgreeSlots() + replicas_count = self._countReplicasInSlots() if expected_replicas else 0 - if ok_count == total_shards and slots_count == total_shards: + if (ok_count == total_shards and slots_count == total_shards + and replicas_count >= expected_replicas): elapsed = time.time() - st if verbose: print(Colors.Green('Cluster is ready after %.1f seconds' % elapsed)) @@ -114,14 +208,96 @@ def waitCluster(self, timeout_sec=40, verbose=True): now = time.time() if verbose and (now - last_status_time) >= CLUSTER_STATUS_INTERVAL_SEC: elapsed = now - st - print(Colors.Yellow(' Cluster wait: %.1fs elapsed - %d/%d shards OK, %d/%d agree on slots...' % - (elapsed, ok_count, total_shards, slots_count, total_shards))) + if expected_replicas: + print(Colors.Yellow( + ' Cluster wait: %.1fs elapsed - %d/%d shards OK, %d/%d agree on slots, %d/%d replicas visible...' + % (elapsed, ok_count, total_shards, slots_count, total_shards, + replicas_count, expected_replicas))) + else: + print(Colors.Yellow(' Cluster wait: %.1fs elapsed - %d/%d shards OK, %d/%d agree on slots...' % + (elapsed, ok_count, total_shards, slots_count, total_shards))) last_status_time = now time.sleep(0.1) raise RuntimeError( "Cluster OK wait loop timed out after %s seconds" % timeout_sec) + def _attachSlavesToCluster(self): + """Attach slaves to their masters via CLUSTER MEET + CLUSTER REPLICATE. + + Each StandardEnv shard owns one master and (when useSlaves is True) + one or more slaves (controlled by --replicas-per-shard). Slaves were + booted with --cluster-enabled but no master link (see redis_std.py). + Now that masters have MEET'd each other and slots are assigned, MEET + every slave from its master so the slave joins gossip, then issue + CLUSTER REPLICATE on each slave's connection to attach it to the + master. + """ + total_shards = len(self.shards) + # Total live slaves across all shards. + total_slaves = sum( + sum(1 for p in getattr(s, 'slaveProcesses', []) if p is not None) + for s in self.shards + if getattr(s, 'useSlaves', False) + ) + if total_slaves == 0: + return + + if self.verbose: + print(Colors.Yellow('Attaching %d slave(s) to cluster...' % total_slaves)) + + # Briefly wait for masters to finish gossiping with each other. + time.sleep(GOSSIP_SETTLE_SEC) + + # Phase 1: MEET each slave from its master so the slave joins gossip. + master_node_ids = {} + for i, shard in enumerate(self.shards): + if not getattr(shard, 'useSlaves', False): + continue + live_slave_indices = [j for j, p in enumerate(shard.slaveProcesses) if p is not None] + if not live_slave_indices: + continue + master_conn = shard.getConnection() + master_node_id = master_conn.execute_command('CLUSTER', 'MYID') + if isinstance(master_node_id, bytes): + master_node_id = master_node_id.decode() + master_node_ids[i] = (master_node_id, live_slave_indices) + for j in live_slave_indices: + slave_port = shard.getSlavePort(j) + master_conn.execute_command('CLUSTER', 'MEET', '127.0.0.1', slave_port) + + # Allow gossip to propagate so each slave sees the master it will replicate. + time.sleep(GOSSIP_SETTLE_SEC) + + # Phase 2: CLUSTER REPLICATE on each slave connection. + for i, shard in enumerate(self.shards): + if i not in master_node_ids: + continue + master_node_id, live_slave_indices = master_node_ids[i] + for j in live_slave_indices: + slave_conn = shard.getSlaveConnection(j) + # Retry briefly to handle the race where the slave has not yet + # learned the master node id via gossip. + attached = False + last_err = None + for _ in range(REPLICATE_RETRY_MAX): + try: + slave_conn.execute_command('CLUSTER', 'REPLICATE', master_node_id) + attached = True + break + except Exception as e: + last_err = e + time.sleep(REPLICATE_RETRY_INTERVAL_SEC) + if not attached: + raise RuntimeError( + 'CLUSTER REPLICATE failed for shard %d/%d slave[%d]: %s' + % (i + 1, total_shards, j, last_err)) + if self.verbose: + label = ('slave' if len(live_slave_indices) == 1 + else 'slave[%d]' % j) + print(Colors.Yellow(' Attached %s for shard %d/%d (replicate %s)' % + (label, i + 1, total_shards, master_node_id[:8]))) + def startEnv(self, masters=True, slaves=True): if self.envIsUp == True: return # env is already up @@ -145,29 +321,40 @@ def startEnv(self, masters=True, slaves=True): if self.verbose: print(Colors.Yellow('Configuring cluster topology...')) slots_per_node = int(16384 / len(self.shards)) + 1 - for i, shard in enumerate(self.shards): - con = shard.getConnection() - for s in self.shards: - con.execute_command('CLUSTER', 'MEET', - '127.0.0.1', s.getMasterPort()) - - start_slot = i * slots_per_node - end_slot = start_slot + slots_per_node - if end_slot > 16384: - end_slot = 16384 + try: + for i, shard in enumerate(self.shards): + con = shard.getConnection() + for s in self.shards: + con.execute_command('CLUSTER', 'MEET', + '127.0.0.1', s.getMasterPort()) + + start_slot = i * slots_per_node + end_slot = start_slot + slots_per_node + if end_slot > 16384: + end_slot = 16384 + + try: + con.execute_command('CLUSTER', 'ADDSLOTS', *(str(x) + for x in range(start_slot, end_slot))) + except Exception as e: + print(Colors.Bred(' Error assigning slots %d-%d to shard %d: %s' % + (start_slot, end_slot - 1, i + 1, str(e)))) - try: - con.execute_command('CLUSTER', 'ADDSLOTS', *(str(x) - for x in range(start_slot, end_slot))) - except Exception as e: - print(Colors.Bred(' Error assigning slots %d-%d to shard %d: %s' % - (start_slot, end_slot - 1, i + 1, str(e)))) + if self.verbose: + print(Colors.Yellow(' Configured shard %d/%d (slots %d-%d)' % + (i + 1, total_shards, start_slot, min(end_slot - 1, 16383)))) - if self.verbose: - print(Colors.Yellow(' Configured shard %d/%d (slots %d-%d)' % - (i + 1, total_shards, start_slot, min(end_slot - 1, 16383)))) + # Attach slaves (if any) before waiting for cluster_state:ok so the + # final waitCluster call also covers replica readiness. + self._attachSlavesToCluster() - self.waitCluster(timeout_sec=self.clusterStartTimeout, verbose=self.verbose) + self.waitCluster(timeout_sec=self.clusterStartTimeout, verbose=self.verbose) + except Exception: + # Topology phase failures (waitCluster timeout, REPLICATE retry + # exhaustion, etc.) would otherwise leak every redis-server we + # already booted in the shard loop above. Tear them down. + self.stopEnv() + raise self.envIsUp = True self.envIsHealthy = True @@ -235,7 +422,11 @@ def getConnectionByKey(self, key, command): def addShardToCluster(self, redisBinaryPath, output_files_format, **kwargs): kwargs.pop('port') - port = self.shards[-1].port + 2 # use a fresh port + # Skip past the previous shard's master and all of its replicas so we + # land on a free port. With replicasPerShard==1 the stride is 2, + # matching the prior +2 hop. + port_stride = 1 + self.replicasPerShard if self.replicasPerShard else 2 + port = self.shards[-1].port + port_stride # use a fresh port self.shardsCount += 1 new_shard = StandardEnv(redisBinaryPath, port, outputFilesFormat=output_files_format, serverId=self.shardsCount, clusterEnabled=True, **kwargs) diff --git a/RLTest/redis_std.py b/RLTest/redis_std.py index 6ad0065..c0d775e 100644 --- a/RLTest/redis_std.py +++ b/RLTest/redis_std.py @@ -23,7 +23,7 @@ def __init__(self, redisBinaryPath, port=6379, modulePath=None, moduleArgs=None, useAof=False, useRdbPreamble=True, debugger=None, sanitizer=None, noCatch=False, noLog=False, unix=False, verbose=False, useTLS=False, tlsCertFile=None, tlsKeyFile=None, tlsCaCertFile=None, clusterNodeTimeout=None, tlsPassphrase=None, enableDebugCommand=False, protocol=2, terminateRetries=None, terminateRetrySecs=None, enableProtectedConfigs=False, enableModuleCommand=False, loglevel=None, - redisConfigFile=None, dualTLS=False, startupGraceSecs=0.1 + redisConfigFile=None, dualTLS=False, startupGraceSecs=0.1, replicasPerShard=1 ): self.uuid = uuid.uuid4().hex self.redisBinaryPath = os.path.expanduser(redisBinaryPath) if redisBinaryPath.startswith( @@ -33,6 +33,19 @@ def __init__(self, redisBinaryPath, port=6379, modulePath=None, moduleArgs=None, self.moduleArgs = fix_modulesArgs(self.modulePath, moduleArgs, haveSeqs=False) self.outputFilesFormat = self.uuid + '.' + outputFilesFormat self.useSlaves = useSlaves + # Number of replicas attached to this shard's master. Default 1 + # preserves the historical single-replica behavior. Multi-replica + # support (>1) is only meaningful for cluster mode; in standalone we + # cap at 1 to keep behavior unchanged. + if useSlaves: + if not clusterEnabled and replicasPerShard != 1: + # Standalone replication keeps a single replica; per design + # there is no clear use case for multiple chained replicas at + # this level. + replicasPerShard = 1 + self.replicasPerShard = replicasPerShard + else: + self.replicasPerShard = 0 self.masterServerId = serverId self.password = password self.clusterEnabled = clusterEnabled @@ -52,10 +65,23 @@ def __init__(self, redisBinaryPath, port=6379, modulePath=None, moduleArgs=None, self.masterStdout = None self.masterStderr = None self.masterExitCode = None - self.slaveProcess = None - self.slaveStdout = None - self.slaveStderr = None - self.slaveExitCode = None + # Slave state is stored as per-replica lists so a single shard can own + # more than one replica (controlled by replicasPerShard). The legacy + # scalar attributes (slaveProcess, slavePort, slaveStdout, slaveStderr, + # slaveExitCode, slaveServerId) remain accessible via @property + # accessors below for backward compatibility; they map to index 0. + n_slaves = self.replicasPerShard if self.useSlaves else 0 + self.slaveProcesses = [None] * n_slaves + self.slaveStdouts = [None] * n_slaves + self.slaveStderrs = [None] * n_slaves + self.slaveExitCodes = [None] * n_slaves + # Slave indices (0-based, into the per-shard lists above) that a test + # has deliberately shut down via SHUTDOWN/DEBUG SLEEP/etc. The + # corresponding slaveProcesses entry is cleared so that stopEnv() will + # not try to terminate an already-gone process, and the corresponding + # slaveExitCodes entry is forced to 0 so that checkExitCode() does not + # flag a missing exit code as a crash. See markSlaveDeadByTest(). + self._expectedDeadSlaves = set() self.verbose = verbose self.role = MASTER self.useTLS = useTLS @@ -76,13 +102,16 @@ def __init__(self, redisBinaryPath, port=6379, modulePath=None, moduleArgs=None, if port > 0: self.port = port - self.slavePort = port + 1 if self.useSlaves else 0 + # Allocate sequential ports right after the master port, one per + # replica. With replicasPerShard==1 this matches historical + # behavior (slavePort == port + 1). + self.slavePorts = [port + 1 + i for i in range(n_slaves)] if self.useSlaves else [] elif port == 0: self.port = get_random_port() - self.slavePort = get_random_port() if self.useSlaves else 0 + self.slavePorts = [get_random_port() for _ in range(n_slaves)] if self.useSlaves else [] else: self.port = -1 - self.slavePort = -1 + self.slavePorts = [-1] * n_slaves if self.useSlaves else [] if self.has_interactive_debugger and serverId > 1: assert self.noCatch and not self.useSlaves and not self.clusterEnabled @@ -123,16 +152,123 @@ def __init__(self, redisBinaryPath, port=6379, modulePath=None, moduleArgs=None, self.masterCmdArgs = self.createCmdArgs(MASTER) self.masterOSEnv = self.createCmdOSEnv(MASTER) + # Per-replica command args / env / server ids. The current slave index + # used by createCmdArgs/createCmdOSEnv/_getFileName is tracked via the + # transient self._slaveIdx attribute set during the loop below; this + # avoids changing the createCmdArgs/_getFileName signatures and + # therefore preserves backward compatibility. + self.slaveServerIds = [] + self.slaveCmdArgsList = [] + self.slaveOSEnvList = [] if self.useSlaves: - self.slaveServerId = serverId + 1 - self.slaveCmdArgs = self.createCmdArgs(SLAVE) - self.slaveOSEnv = self.createCmdOSEnv(SLAVE) + for i in range(self.replicasPerShard): + self._slaveIdx = i + self.slaveServerIds.append(serverId + 1 + i) + self.slaveCmdArgsList.append(self.createCmdArgs(SLAVE)) + self.slaveOSEnvList.append(self.createCmdOSEnv(SLAVE)) + self._slaveIdx = 0 self.envIsHealthy = True + # ---- Backward-compatibility scalar accessors ---- + # External callers (and the existing test suite) read/write scalar + # attributes such as self.slaveProcess, self.slavePort, self.slaveServerId, + # self.slaveStdout, self.slaveStderr, self.slaveExitCode, self.slaveCmdArgs + # and self.slaveOSEnv. These properties expose index 0 of the underlying + # per-replica lists so that pre-existing code continues to work unchanged + # whenever replicasPerShard == 1 (the default). + @property + def slaveProcess(self): + return self.slaveProcesses[0] if self.slaveProcesses else None + + @slaveProcess.setter + def slaveProcess(self, value): + if self.slaveProcesses: + self.slaveProcesses[0] = value + else: + self.slaveProcesses = [value] + + @property + def slavePort(self): + return self.slavePorts[0] if self.slavePorts else 0 + + @slavePort.setter + def slavePort(self, value): + if self.slavePorts: + self.slavePorts[0] = value + else: + self.slavePorts = [value] + + @property + def slaveServerId(self): + return self.slaveServerIds[0] if self.slaveServerIds else None + + @slaveServerId.setter + def slaveServerId(self, value): + if self.slaveServerIds: + self.slaveServerIds[0] = value + else: + self.slaveServerIds = [value] + + @property + def slaveStdout(self): + return self.slaveStdouts[0] if self.slaveStdouts else None + + @slaveStdout.setter + def slaveStdout(self, value): + if self.slaveStdouts: + self.slaveStdouts[0] = value + else: + self.slaveStdouts = [value] + + @property + def slaveStderr(self): + return self.slaveStderrs[0] if self.slaveStderrs else None + + @slaveStderr.setter + def slaveStderr(self, value): + if self.slaveStderrs: + self.slaveStderrs[0] = value + else: + self.slaveStderrs = [value] + + @property + def slaveExitCode(self): + return self.slaveExitCodes[0] if self.slaveExitCodes else None + + @slaveExitCode.setter + def slaveExitCode(self, value): + if self.slaveExitCodes: + self.slaveExitCodes[0] = value + else: + self.slaveExitCodes = [value] + + @property + def slaveCmdArgs(self): + return self.slaveCmdArgsList[0] if self.slaveCmdArgsList else None + + @property + def slaveOSEnv(self): + return self.slaveOSEnvList[0] if self.slaveOSEnvList else None + + def getNumSlaves(self): + """Returns the number of slaves configured for this shard.""" + return len(self.slaveProcesses) + def _getFileName(self, role, suffix): - return (self.outputFilesFormat + suffix) % ( - 'master-%d' % self.masterServerId if role == MASTER else 'slave-%d' % self.slaveServerId) + if role == MASTER: + tag = 'master-%d' % self.masterServerId + else: + # When createCmdArgs is called once per replica during __init__, + # self._slaveIdx points to the replica currently being built. After + # construction, callers that pass role=SLAVE expect the legacy + # single-slave tag (index 0). + idx = getattr(self, '_slaveIdx', 0) + server_id = (self.slaveServerIds[idx] + if self.slaveServerIds and idx < len(self.slaveServerIds) + else self.masterServerId + 1) + tag = 'slave-%d' % server_id + return (self.outputFilesFormat + suffix) % tag def _getValgrindFilePath(self, role): return os.path.join(self.dbDirPath, self._getFileName(role, '.valgrind.log')) @@ -221,13 +357,21 @@ def createCmdArgs(self, role): cmdArgs += ['--loglevel', self.loglevel] if self.outputFilesFormat is not None: cmdArgs += ['--dbfilename', self._getFileName(role, '.rdb')] - if role == SLAVE: + if role == SLAVE and not self.clusterEnabled: + # Standalone replication: tie the slave to its master at boot. cmdArgs += ['--slaveof', 'localhost', str(self.port)] if self.password: cmdArgs += ['--masterauth', self.password] + elif role == SLAVE and self.clusterEnabled: + # Cluster mode: do NOT use --slaveof. The slave will be attached + # to its master via CLUSTER REPLICATE after the cluster is formed + # (see redis_cluster.py::startEnv). Boot it as an empty + # cluster-enabled node so it can join gossip. + if self.password: + cmdArgs += ['--masterauth', self.password] if self.password: cmdArgs += ['--requirepass', self.password] - if self.clusterEnabled and role is not SLAVE: + if self.clusterEnabled: # creating .cluster.conf in /tmp as lock fails on NFS cmdArgs += ['--cluster-enabled', 'yes', '--cluster-config-file', '/tmp/' + self._getFileName(role, '.cluster.conf'), '--cluster-node-timeout', '5000' if self.clusterNodeTimeout is None else str(self.clusterNodeTimeout)] @@ -269,14 +413,26 @@ def waitForRedisToStart(self, con, proc): wait_for_conn(con, proc, retries=1000 if self.debugger else 200) self._waitForAOFChild(con) - def getPid(self, role): - return self.masterProcess.pid if role == MASTER else self.slaveProcess.pid - - def getPort(self, role): - return self.port if role == MASTER else self.slavePort - - def getServerId(self, role): - return self.masterServerId if role == MASTER else self.slaveServerId + def getPid(self, role, slaveIdx=0): + if role == MASTER: + return self.masterProcess.pid + idx = getattr(self, '_slaveIdx', slaveIdx) + return self.slaveProcesses[idx].pid + + def getPort(self, role, slaveIdx=0): + if role == MASTER: + return self.port + # During __init__, createCmdArgs(SLAVE) is invoked once per replica + # while self._slaveIdx walks through the index range; honor it so each + # replica gets its own port wired into its command line. + idx = getattr(self, '_slaveIdx', slaveIdx) + return self.slavePorts[idx] if self.slavePorts else 0 + + def getServerId(self, role, slaveIdx=0): + if role == MASTER: + return self.masterServerId + idx = getattr(self, '_slaveIdx', slaveIdx) + return self.slaveServerIds[idx] if self.slaveServerIds else None def _printEnvData(self, prefix='', role=MASTER): print(Colors.Yellow(prefix + 'pid: %d' % (self.getPid(role)))) @@ -307,14 +463,28 @@ def printEnvData(self, prefix=''): print(Colors.Yellow(prefix + 'master:')) self._printEnvData(prefix + '\t', MASTER) if self.useSlaves: - print(Colors.Yellow(prefix + 'slave:')) - self._printEnvData(prefix + '\t', SLAVE) + for i in range(self.getNumSlaves()): + label = 'slave:' if self.getNumSlaves() == 1 else 'slave[%d]:' % i + print(Colors.Yellow(prefix + label)) + # _printEnvData reads role-keyed scalars; switch self._slaveIdx + # so the helpers (getPid/getPort/getServerId/_getFileName) + # return the i-th replica's values for this print iteration. + old_idx = getattr(self, '_slaveIdx', 0) + self._slaveIdx = i + try: + self._printEnvData(prefix + '\t', SLAVE) + finally: + self._slaveIdx = old_idx def getInformationBeforeDispose(self): res = {} instances = [(MASTER, self.getConnection(), self.masterProcess)] if self.useSlaves: - instances.append((SLAVE, self.getSlaveConnection(), self.slaveProcess)) + for i in range(self.getNumSlaves()): + instances.append((SLAVE if i == 0 and self.getNumSlaves() == 1 + else '%s-%d' % (SLAVE, i), + self.getSlaveConnection(i), + self.slaveProcesses[i])) for role, conn, proc in instances: info = None try: @@ -330,7 +500,11 @@ def getInformationAfterDispose(self): res = {} instances = [(MASTER, self.masterStdout, self.masterStderr)] if self.useSlaves: - instances.append((SLAVE, self.slaveStdout, self.slaveStderr)) + for i in range(self.getNumSlaves()): + instances.append((SLAVE if i == 0 and self.getNumSlaves() == 1 + else '%s-%d' % (SLAVE, i), + self.slaveStdouts[i], + self.slaveStderrs[i])) for role, stdout, stderr in instances: stdoutStr = None stderrStr = None @@ -388,20 +562,28 @@ def startEnv(self, masters = True, slaves = True): self.waitForRedisToStart(con, self.masterProcess) else: self.masterProcess = None - if self.useSlaves and slaves and self.slaveProcess is None: - if self.verbose: - print(Colors.Green("Redis slave command: " + ' '.join(self.slaveCmdArgs))) - self.slaveProcess = subprocess.Popen(args=self.slaveCmdArgs, env=self.slaveOSEnv, cwd=self.dbDirPath, - **options) - time.sleep(self.startupGraceSecs) - if self._isAlive(self.slaveProcess): - con = self.getSlaveConnection() - self.waitForRedisToStart(con, self.slaveProcess) - else: - self.slaveProcess = None - - self.envIsUp = self.masterProcess is not None or self.slaveProcess is not None - self.envIsHealthy = self.masterProcess is not None and (self.slaveProcess is not None if self.useSlaves else True) + if self.useSlaves and slaves: + for i in range(self.getNumSlaves()): + if self.slaveProcesses[i] is not None: + continue + if self.verbose: + label = "Redis slave command" if self.getNumSlaves() == 1 \ + else "Redis slave[%d] command" % i + print(Colors.Green("%s: %s" % (label, ' '.join(self.slaveCmdArgsList[i])))) + self.slaveProcesses[i] = subprocess.Popen( + args=self.slaveCmdArgsList[i], env=self.slaveOSEnvList[i], + cwd=self.dbDirPath, **options) + time.sleep(self.startupGraceSecs) + if self._isAlive(self.slaveProcesses[i]): + con = self.getSlaveConnection(i) + self.waitForRedisToStart(con, self.slaveProcesses[i]) + else: + self.slaveProcesses[i] = None + + any_slave_alive = any(p is not None for p in self.slaveProcesses) + all_slaves_alive = all(p is not None for p in self.slaveProcesses) if self.slaveProcesses else True + self.envIsUp = self.masterProcess is not None or any_slave_alive + self.envIsHealthy = self.masterProcess is not None and (all_slaves_alive if self.useSlaves else True) # self.masterStdout = self.masterProcess.stdout if self.masterProcess else None # self.masterStderr = self.masterProcess.stderr if self.masterProcess else None @@ -418,8 +600,11 @@ def _isAlive(self, process): return True return False - def _segfault(self, role, retries=3): - process = self.masterProcess if role == MASTER else self.slaveProcess + def _segfault(self, role, retries=3, slaveIdx=0): + if role == MASTER: + process = self.masterProcess + else: + process = self.slaveProcesses[slaveIdx] if not self._isAlive(process): return for _ in range(retries): @@ -440,14 +625,32 @@ def _segfault(self, role, retries=3): def stopEnvWithSegFault(self, masters = True, slaves = True): if self.masterProcess is not None and masters is True: self._segfault(MASTER) - if self.useSlaves and self.slaveProcess is not None and slaves is True: - self._segfault(SLAVE) + if self.useSlaves and slaves is True: + for i in range(self.getNumSlaves()): + if self.slaveProcesses[i] is not None: + self._segfault(SLAVE, slaveIdx=i) + + def _stopProcess(self, role, slaveIdx=0): + if role == MASTER: + process = self.masterProcess + serverId = self.masterServerId + else: + process = self.slaveProcesses[slaveIdx] + serverId = self.slaveServerIds[slaveIdx] + # _getFileName(SLAVE, ...) reads self._slaveIdx to pick the correct + # log file when verbose_analyse_server_log is invoked below. + old_idx = getattr(self, '_slaveIdx', 0) + self._slaveIdx = slaveIdx + try: + return self.__stopProcessImpl(process, role, serverId, slaveIdx) + finally: + self._slaveIdx = old_idx - def _stopProcess(self, role): - process = self.masterProcess if role == MASTER else self.slaveProcess - serverId = self.masterServerId if role == MASTER else self.slaveServerId + def __stopProcessImpl(self, process, role, serverId, slaveIdx): if not self._isAlive(process): - if not self.has_interactive_debugger: + expected_dead = (role == SLAVE and + slaveIdx in self._expectedDeadSlaves) + if not self.has_interactive_debugger and not expected_dead: # on interactive debugger its expected that then process will not be alive print('\t' + Colors.Bred('process is not alive, might have crash durring test execution, ' 'check this out. server id : %s' % str(serverId))) @@ -494,7 +697,7 @@ def _stopProcess(self, role): if role == MASTER: self.masterExitCode = process.poll() else: - self.slaveExitCode = process.poll() + self.slaveExitCodes[slaveIdx] = process.poll() except OSError as e: print('\t' + Colors.Bred( 'OSError caught while waiting for {0} process to end: {1}'.format(role, e.__str__()))) @@ -520,31 +723,42 @@ def stopEnv(self, masters = True, slaves = True): if self.masterProcess is not None and masters is True: self._stopProcess(MASTER) self.masterProcess = None - if self.useSlaves and self.slaveProcess is not None and slaves is True: - self._stopProcess(SLAVE) - self.slaveProcess = None - self.envIsUp = self.masterProcess is not None or self.slaveProcess is not None - self.envIsHealthy = self.masterProcess is not None and (self.slaveProcess is not None if self.useSlaves else True) - - def _getConnection(self, role): - if self.useUnix: - return redis.StrictRedis(unix_socket_path=self.getUnixPath(role), - password=self.password, decode_responses=self.decodeResponses, protocol=self.protocol) - elif self.useTLS: - return redis.StrictRedis('localhost', self.getPort(role), - password=self.password, - ssl=True, - ssl_password=self.tlsPassphrase, - ssl_keyfile=self.getTLSKeyFile(), - ssl_certfile=self.getTLSCertFile(), - ssl_cert_reqs=None, - ssl_ca_certs=self.getTLSCACertFile(), - decode_responses=self.decodeResponses, - protocol=self.protocol - ) - else: - return redis.StrictRedis('localhost', self.getPort(role), - password=self.password, decode_responses=self.decodeResponses, protocol=self.protocol) + if self.useSlaves and slaves is True: + for i in range(self.getNumSlaves()): + if self.slaveProcesses[i] is not None: + self._stopProcess(SLAVE, slaveIdx=i) + self.slaveProcesses[i] = None + any_slave_alive = any(p is not None for p in self.slaveProcesses) + all_slaves_alive = all(p is not None for p in self.slaveProcesses) if self.slaveProcesses else True + self.envIsUp = self.masterProcess is not None or any_slave_alive + self.envIsHealthy = self.masterProcess is not None and (all_slaves_alive if self.useSlaves else True) + + def _getConnection(self, role, slaveIdx=0): + # When fetching a slave connection, temporarily steer the role-aware + # helpers (getPort/getUnixPath) at the requested replica index. + old_idx = getattr(self, '_slaveIdx', 0) + self._slaveIdx = slaveIdx + try: + if self.useUnix: + return redis.StrictRedis(unix_socket_path=self.getUnixPath(role), + password=self.password, decode_responses=self.decodeResponses, protocol=self.protocol) + elif self.useTLS: + return redis.StrictRedis('localhost', self.getPort(role), + password=self.password, + ssl=True, + ssl_password=self.tlsPassphrase, + ssl_keyfile=self.getTLSKeyFile(), + ssl_certfile=self.getTLSCertFile(), + ssl_cert_reqs=None, + ssl_ca_certs=self.getTLSCACertFile(), + decode_responses=self.decodeResponses, + protocol=self.protocol + ) + else: + return redis.StrictRedis('localhost', self.getPort(role), + password=self.password, decode_responses=self.decodeResponses, protocol=self.protocol) + finally: + self._slaveIdx = old_idx def getConnection(self, shardId=1): return self._getConnection(MASTER) @@ -553,11 +767,17 @@ def getConnection(self, shardId=1): def getOSSMasterNodesConnectionList(self): return [self.getConnection()] - def getSlaveConnection(self): + def getSlaveConnection(self, slaveIdx=0): if self.useSlaves: - return self._getConnection(SLAVE) + return self._getConnection(SLAVE, slaveIdx=slaveIdx) raise Exception('asked for slave connection but no slave exists') + def getSlavePort(self, slaveIdx=0): + """Returns the port for the slave at the given index (default 0).""" + if not self.useSlaves or not self.slavePorts: + raise Exception('asked for slave port but no slave exists') + return self.slavePorts[slaveIdx] + # List of nodes that initial bootstrapping can be done from def getMasterNodesList(self): node_info = {"host": None, "port": None, "unix_socket_path": None, "password": None} @@ -590,7 +810,8 @@ def dumpAndReload(self, restart=False, shardId=None, timeout_sec=0): conns = [] conns.append(self.getConnection()) if self.useSlaves: - conns.append(self.getSlaveConnection()) + for i in range(self.getNumSlaves()): + conns.append(self.getSlaveConnection(i)) if restart: for con in conns: self._waitForAOFChild(con) @@ -613,15 +834,58 @@ def broadcast(self, *cmd): except Exception as e: print(e) + def markSlaveDeadByTest(self, slave_idx): + """Mark a slave as deliberately shut down by the test. + + Tests that intentionally terminate a replica (for example by sending + SHUTDOWN NOSAVE to exercise a failover code path) must call this so + that RLTest's teardown does not subsequently flag the missing process + as a crash. After this call: + + * ``checkExitCode`` will treat ``slaveExitCodes[slave_idx]`` as 0 + rather than ``None``. + * ``stopEnv`` will skip the slave because its process entry is + cleared. + * If ``__stopProcessImpl`` is ever invoked for this slave (e.g. via + a direct ``_stopProcess`` call), the "process is not alive" + warning is suppressed. + + Parameters + ---------- + slave_idx : int + Zero-based index into the per-shard slave lists + (``slaveProcesses`` / ``slaveExitCodes`` / ``slaveServerIds``). + Note that this is NOT the absolute ``serverId`` printed by + RLTest in failure output; convert from the displayed id via + ``serverId - masterServerId - 1`` if needed, or look up the + index by port via ``slavePorts``. + """ + if not self.useSlaves: + raise ValueError('markSlaveDeadByTest called on a shard ' + 'without replicas') + slave_idx = int(slave_idx) + if slave_idx < 0 or slave_idx >= self.getNumSlaves(): + raise IndexError('slave_idx %d out of range [0, %d)' % + (slave_idx, self.getNumSlaves())) + self._expectedDeadSlaves.add(slave_idx) + self.slaveExitCodes[slave_idx] = 0 + self.slaveProcesses[slave_idx] = None + print('\t' + Colors.Yellow( + 'slave-%d marked dead by test (expected)' % slave_idx)) + def checkExitCode(self): ret = True if self.masterExitCode != 0: print('\t' + Colors.Bred('bad exit code for serverId %s' % str(self.masterServerId))) ret = False - if self.useSlaves and (self.slaveExitCode is None or self.slaveExitCode != 0): - print('\t' + Colors.Bred('bad exit code for serverId %s' % str(self.slaveServerId))) - ret = False + if self.useSlaves: + for i in range(self.getNumSlaves()): + exit_code = self.slaveExitCodes[i] + if exit_code is None or exit_code != 0: + print('\t' + Colors.Bred('bad exit code for serverId %s' % + str(self.slaveServerIds[i]))) + ret = False return ret def isUp(self): diff --git a/tests/unit/test_replicas_per_shard.py b/tests/unit/test_replicas_per_shard.py new file mode 100644 index 0000000..0b6150f --- /dev/null +++ b/tests/unit/test_replicas_per_shard.py @@ -0,0 +1,154 @@ +"""Tests for the --replicas-per-shard / replicasPerShard feature. + +The feature lets a cluster carry more than one replica per shard. It is +cluster-mode only; standalone mode is intentionally capped at a single +replica regardless of the requested value. +""" + +import os +import shutil +import tempfile +from unittest import TestCase + +from RLTest.env import Defaults +from RLTest.redis_cluster import ClusterEnv +from RLTest.redis_std import StandardEnv +from tests.unit.test_common import REDIS_BINARY + + +class TestReplicasPerShardStandardEnv(TestCase): + """StandardEnv-level checks for the per-replica list bookkeeping.""" + + def setUp(self): + self.test_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.test_dir) + + def test_default_replicas_per_shard_is_one(self): + env = StandardEnv(redisBinaryPath=REDIS_BINARY, + outputFilesFormat='%s-test', + dbDirPath=self.test_dir, + useSlaves=True) + assert env.replicasPerShard == 1 + assert env.getNumSlaves() == 1 + assert len(env.slaveProcesses) == 1 + assert len(env.slavePorts) == 1 + # Legacy scalar accessor still works. + assert env.slavePort == env.slavePorts[0] + + def test_standalone_caps_replicas_at_one(self): + # Standalone (non-cluster) must stay single-slave; the request for 3 + # replicas is clamped back to 1. + env = StandardEnv(redisBinaryPath=REDIS_BINARY, + outputFilesFormat='%s-test', + dbDirPath=self.test_dir, + useSlaves=True, + replicasPerShard=3) + assert env.replicasPerShard == 1 + assert env.getNumSlaves() == 1 + + def test_no_slaves_zeroes_replicas(self): + env = StandardEnv(redisBinaryPath=REDIS_BINARY, + outputFilesFormat='%s-test', + dbDirPath=self.test_dir, + useSlaves=False, + replicasPerShard=4) + assert env.replicasPerShard == 0 + assert env.getNumSlaves() == 0 + assert env.slavePorts == [] + + def test_cluster_shard_multi_replicas_allocates_ports(self): + # In cluster mode multi-replica is allowed; expect sequential ports + # immediately after the master port. + env = StandardEnv(redisBinaryPath=REDIS_BINARY, + outputFilesFormat='%s-test', + dbDirPath=self.test_dir, + useSlaves=True, + clusterEnabled=True, + replicasPerShard=3, + port=20000) + assert env.replicasPerShard == 3 + assert env.getNumSlaves() == 3 + assert env.slavePorts == [20001, 20002, 20003] + # Each replica gets its own server id, command line, and process slot. + assert len(set(env.slaveServerIds)) == 3 + assert len(env.slaveCmdArgsList) == 3 + assert len(env.slaveProcesses) == 3 + # getSlavePort returns the right port per index. + assert env.getSlavePort(0) == 20001 + assert env.getSlavePort(2) == 20003 + + +class TestReplicasPerShardClusterEnv(TestCase): + """End-to-end ClusterEnv tests that actually launch redis-server.""" + + def setUp(self): + self.test_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.test_dir) + + def _build_default_args(self): + default_args = Defaults().getKwargs() + default_args['dbDirPath'] = self.test_dir + return default_args + + def test_total_redises_three_shards_two_replicas(self): + # Construction-time check: 3 shards * (1 master + 2 replicas) = 9 redises. + default_args = self._build_default_args() + default_args['useSlaves'] = True + default_args['replicasPerShard'] = 2 + cluster_env = ClusterEnv(shardsCount=3, redisBinaryPath=REDIS_BINARY, + outputFilesFormat='%s-test', + randomizePorts=True, **default_args) + try: + assert len(cluster_env.shards) == 3 + assert cluster_env.replicasPerShard == 2 + total_redises = sum(1 + s.getNumSlaves() for s in cluster_env.shards) + assert total_redises == 9 + for shard in cluster_env.shards: + assert shard.getNumSlaves() == 2 + assert shard.replicasPerShard == 2 + finally: + cluster_env.stopEnv() + + def test_start_three_shards_two_replicas_and_cluster_slots(self): + # Run an actual cluster and verify CLUSTER SLOTS exposes 2 replica + # entries per slot-range row. + default_args = self._build_default_args() + default_args['useSlaves'] = True + default_args['replicasPerShard'] = 2 + cluster_env = ClusterEnv(shardsCount=3, redisBinaryPath=REDIS_BINARY, + outputFilesFormat='%s-test', + randomizePorts=True, **default_args) + try: + cluster_env.startEnv() + # All 9 processes alive. + for shard in cluster_env.shards: + assert shard.masterProcess is not None + assert shard.masterProcess.poll() is None + assert shard.getNumSlaves() == 2 + for proc in shard.slaveProcesses: + assert proc is not None + assert proc.poll() is None + # CLUSTER SLOTS should report 2 replicas per slot range. + master_conn = cluster_env.shards[0].getConnection() + slots_view = master_conn.execute_command('CLUSTER', 'SLOTS') + assert len(slots_view) == 3 + for row in slots_view: + # row = [start, end, master_entry, replica_entry, replica_entry] + assert len(row) == 5, ( + 'expected 2 replica entries per slot row, got row=%r' % (row,)) + # Each replica should be cluster-enabled. + for shard in cluster_env.shards: + for i in range(shard.getNumSlaves()): + replica_conn = shard.getSlaveConnection(i) + info = replica_conn.execute_command('INFO', 'cluster') + if isinstance(info, dict): + # decoded response client + assert info.get('cluster_enabled') in (1, '1', True) + else: + assert b'cluster_enabled:1' in info or 'cluster_enabled:1' in str(info) + finally: + cluster_env.stopEnv()