diff --git a/.opencode/package-lock.json b/.opencode/package-lock.json new file mode 100644 index 0000000000..86bbf5645f --- /dev/null +++ b/.opencode/package-lock.json @@ -0,0 +1,115 @@ +{ + "name": ".opencode", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "dependencies": { + "@opencode-ai/plugin": "1.3.17" + } + }, + "node_modules/@opencode-ai/plugin": { + "version": "1.3.17", + "resolved": "https://registry.npmjs.org/@opencode-ai/plugin/-/plugin-1.3.17.tgz", + "integrity": "sha512-N5lckFtYvEu2R8K1um//MIOTHsJHniF2kHoPIWPCrxKG5Jpismt1ISGzIiU3aKI2ht/9VgcqKPC5oZFLdmpxPw==", + "license": "MIT", + "dependencies": { + "@opencode-ai/sdk": "1.3.17", + "zod": "4.1.8" + }, + "peerDependencies": { + "@opentui/core": ">=0.1.96", + "@opentui/solid": ">=0.1.96" + }, + "peerDependenciesMeta": { + "@opentui/core": { + "optional": true + }, + "@opentui/solid": { + "optional": true + } + } + }, + "node_modules/@opencode-ai/sdk": { + "version": "1.3.17", + "resolved": "https://registry.npmjs.org/@opencode-ai/sdk/-/sdk-1.3.17.tgz", + "integrity": "sha512-2+MGgu7wynqTBwxezR01VAGhILXlpcHDY/pF7SWB87WOgLt3kD55HjKHNj6PWxyY8n575AZolR95VUC3gtwfmA==", + "license": "MIT", + "dependencies": { + "cross-spawn": "7.0.6" + } + }, + "node_modules/cross-spawn": { + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", + "license": "MIT", + "dependencies": { + "path-key": "^3.1.0", + "shebang-command": "^2.0.0", + "which": "^2.0.1" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/isexe": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", + "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==", + "license": "ISC" + }, + "node_modules/path-key": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", + "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/shebang-command": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", + "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", + "license": "MIT", + "dependencies": { + "shebang-regex": "^3.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/shebang-regex": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", + "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/which": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", + "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", + "license": "ISC", + "dependencies": { + "isexe": "^2.0.0" + }, + "bin": { + "node-which": "bin/node-which" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/zod": { + "version": "4.1.8", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/colinhacks" + } + } + } +} diff --git a/bench_ab.py b/bench_ab.py new file mode 100644 index 0000000000..f87d039e60 --- /dev/null +++ b/bench_ab.py @@ -0,0 
+1,177 @@ +#!/usr/bin/env python3 +""" +Benchmark: recv_results_rows — origin/master vs new Cython metadata parser. + +Compares the ACTUAL Cython recv_results_rows path (FastResultMessage) +as it exists in the currently-built code. + +Run with: taskset -c 0 python3 bench_ab.py +""" + +import struct +import io +import time +import sys +import uuid + + +def write_short(buf, v): + buf.write(struct.pack('>H', v)) + +def write_int(buf, v): + buf.write(struct.pack('>i', v)) + +def write_string(buf, s): + if isinstance(s, str): + s = s.encode('utf8') + write_short(buf, len(s)) + buf.write(s) + +def write_type(buf, type_code, subtypes=()): + write_short(buf, type_code) + for st in subtypes: + if isinstance(st, tuple): + write_type(buf, st[0], st[1:]) + else: + write_type(buf, st) + + +UUID_TYPE = 0x000C +VARCHAR_TYPE = 0x000D +INT_TYPE = 0x0009 +BIGINT_TYPE = 0x0002 +BOOLEAN_TYPE = 0x0004 +DOUBLE_TYPE = 0x0007 +TIMESTAMP_TYPE = 0x000B +LIST_TYPE = 0x0020 +MAP_TYPE = 0x0021 +SET_TYPE = 0x0022 + + +def build_rows_message(colcount, type_codes_list, nrows=0): + buf = io.BytesIO() + write_int(buf, 0x0001) # GLOBAL_TABLES_SPEC + write_int(buf, colcount) + write_string(buf, 'test_ks') + write_string(buf, 'test_cf') + for i in range(colcount): + write_string(buf, 'col_%d' % i) + tc = type_codes_list[i % len(type_codes_list)] + if isinstance(tc, tuple): + write_type(buf, tc[0], tc[1:]) + else: + write_type(buf, tc) + write_int(buf, nrows) + for _ in range(nrows): + for i in range(colcount): + tc = type_codes_list[i % len(type_codes_list)] + base_tc = tc[0] if isinstance(tc, tuple) else tc + if base_tc == UUID_TYPE: + write_int(buf, 16); buf.write(uuid.uuid4().bytes) + elif base_tc == VARCHAR_TYPE: + v = b'test_value'; write_int(buf, len(v)); buf.write(v) + elif base_tc == INT_TYPE: + write_int(buf, 4); buf.write(struct.pack('>i', 42)) + elif base_tc in (BIGINT_TYPE, TIMESTAMP_TYPE, DOUBLE_TYPE): + write_int(buf, 8); buf.write(struct.pack('>q', 12345678)) + elif base_tc == BOOLEAN_TYPE: + write_int(buf, 1); buf.write(b'\x01') + elif base_tc in (LIST_TYPE, SET_TYPE, MAP_TYPE): + write_int(buf, 4); buf.write(struct.pack('>i', 0)) + else: + write_int(buf, 4); buf.write(b'\x00\x00\x00\x00') + return buf.getvalue() + + +def build_no_metadata_message(colcount=10): + buf = io.BytesIO() + write_int(buf, 0x0004) # NO_METADATA + write_int(buf, colcount) + write_int(buf, 0) # 0 rows + return buf.getvalue() + + +def bench(label, fn, iterations, warmup=1000): + for _ in range(warmup): + fn() + times = [] + for _ in range(iterations): + t0 = time.perf_counter_ns() + fn() + t1 = time.perf_counter_ns() + times.append(t1 - t0) + times.sort() + trim = max(1, len(times) // 20) + trimmed = times[trim:-trim] + mean_ns = sum(trimmed) / len(trimmed) + var = sum((t - mean_ns)**2 for t in trimmed) / len(trimmed) + cv = (var**0.5 / mean_ns * 100) if mean_ns else 0 + print(f" {label:50s} {mean_ns:9.0f} ns (cv {cv:4.1f}%)") + return mean_ns + + +def main(): + from cassandra.protocol import ProtocolHandler + from cassandra.cython_deps import HAVE_CYTHON + + print(f"HAVE_CYTHON: {HAVE_CYTHON}") + print(f"Python: {sys.version}") + + fast_cls = ProtocolHandler.message_types_by_opcode[0x08] + print(f"FastResultMessage: {fast_cls}") + print() + + simple_types = [UUID_TYPE, VARCHAR_TYPE, INT_TYPE, BIGINT_TYPE, BOOLEAN_TYPE, + DOUBLE_TYPE, TIMESTAMP_TYPE, VARCHAR_TYPE, INT_TYPE, UUID_TYPE] + + scenarios = [ + ("10 cols, 0 rows", 10, simple_types, 0, 10000), + ("3 cols, 0 rows", 3, simple_types[:3], 0, 10000), + ("50 cols, 0 rows", 50, 
simple_types, 0, 5000), + ("10 cols, 10 rows", 10, simple_types, 10, 5000), + ("10 cols, 100 rows", 10, simple_types, 100, 2000), + ("10 cols, 1000 rows", 10, simple_types, 1000, 500), + ] + + results = {} + for desc, colcount, types, nrows, iters in scenarios: + data = build_rows_message(colcount, types, nrows) + print(f"--- {desc} ({len(data)} bytes) ---") + + def fn(data=data): + f = io.BytesIO(data) + msg = fast_cls(2) + msg.recv_results_rows(f, 4, {}, None, None) + + t = bench("Cython recv_results_rows", fn, iters) + results[desc] = t + print() + + # NO_METADATA with result_metadata + from cassandra.cqltypes import (UUIDType, VarcharType, Int32Type, LongType, + BooleanType, DoubleType, DateType) + result_md = [ + ('ks', 'cf', 'c%d' % i, [UUIDType, VarcharType, Int32Type, LongType, + BooleanType, DoubleType, DateType, VarcharType, + Int32Type, UUIDType][i]) + for i in range(10) + ] + nm_data = build_no_metadata_message(10) + print(f"--- NO_METADATA, 10 cols, 0 rows ({len(nm_data)} bytes) ---") + def nm_fn(data=nm_data, md=result_md): + f = io.BytesIO(data) + msg = fast_cls(2) + msg.recv_results_rows(f, 4, {}, md, None) + t = bench("Cython recv_results_rows", nm_fn, 10000) + results["NO_METADATA"] = t + print() + + print("=" * 60) + print("SUMMARY (copy these numbers for A/B comparison)") + print("=" * 60) + for k, v in results.items(): + print(f" {k:30s} {v:9.0f} ns") + + +if __name__ == '__main__': + main() diff --git a/bench_ab_default.py b/bench_ab_default.py new file mode 100644 index 0000000000..f2f5f30603 --- /dev/null +++ b/bench_ab_default.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +"""A/B comparison: master vs PR for Default(DCAware) regression. + +Stashes PR changes, benchmarks master, restores PR, benchmarks PR, +all in one script using subprocess to avoid module caching. 
+""" +import subprocess +import sys +import json + +BENCH_CODE = ''' +import time, uuid, statistics +from unittest.mock import Mock +from cassandra.policies import DCAwareRoundRobinPolicy, DefaultLoadBalancingPolicy, SimpleConvictionPolicy +from cassandra.pool import Host + +class EP: + def __init__(self, a): + self.address = str(a) + self._port = 9042 + def resolve(self): + return (self.address, self._port) + def __repr__(self): + return f"{self.address}:{self._port}" + def __hash__(self): + return hash((self.address, self._port)) + def __eq__(self, o): + return isinstance(o, EP) and self.address == o.address + +hosts = [] +for dc in range(5): + for rack in range(3): + for node in range(3): + h = Host(EP(f"10.{dc}.{rack}.{node}"), SimpleConvictionPolicy, host_id=uuid.uuid4()) + h.set_location_info(f"dc{dc}", f"rack{rack}") + h.set_up() + hosts.append(h) + +cluster = Mock() +cluster.metadata = Mock() +cluster.metadata.get_host = Mock(return_value=None) + +child = DCAwareRoundRobinPolicy(local_dc="dc0", used_hosts_per_remote_dc=1) +policy = DefaultLoadBalancingPolicy(child) +policy.populate(cluster, hosts) + +q = Mock() +q.keyspace = None +q.target_host = None + +N, ITERS = 100_000, 7 +times = [] +for _ in range(ITERS): + s = time.perf_counter_ns() + for _ in range(N): + for _ in policy.make_query_plan("ks", q): + pass + times.append((time.perf_counter_ns() - s) / N) + +print(f"{statistics.median(times):.0f}") +''' + +def run_bench(): + result = subprocess.run( + ["taskset", "-c", "0", sys.executable, "-c", BENCH_CODE], + capture_output=True, text=True, timeout=120 + ) + if result.returncode != 0: + print(f"STDERR: {result.stderr}", file=sys.stderr) + raise RuntimeError(f"Benchmark failed: {result.stderr}") + return float(result.stdout.strip()) + +# Run PR version 3 times +print("Running PR version...") +pr_results = [] +for i in range(3): + ns = run_bench() + pr_results.append(ns) + print(f" Run {i+1}: {ns:.0f} ns/op") + +# Switch to master +subprocess.run(["git", "stash"], capture_output=True) +subprocess.run(["git", "checkout", "origin/master", "--", "cassandra/policies.py"], capture_output=True) + +print("Running master version...") +master_results = [] +for i in range(3): + ns = run_bench() + master_results.append(ns) + print(f" Run {i+1}: {ns:.0f} ns/op") + +# Restore PR +subprocess.run(["git", "checkout", "pr-651", "--", "cassandra/policies.py"], capture_output=True) +subprocess.run(["git", "stash", "pop"], capture_output=True, check=False) + +pr_med = statistics.median(pr_results) +master_med = statistics.median(master_results) +print(f"\nDefault(DCAware) - master: {master_med:.0f} ns/op") +print(f"Default(DCAware) - PR: {pr_med:.0f} ns/op") +print(f"Difference: {pr_med - master_med:+.0f} ns/op ({(pr_med/master_med - 1)*100:+.1f}%)") + +import statistics diff --git a/bench_default_regression.py b/bench_default_regression.py new file mode 100644 index 0000000000..125744b47e --- /dev/null +++ b/bench_default_regression.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 +"""Micro-benchmark to isolate Default(DCAware) regression. 
+ +Tests the exact hot path: DefaultLoadBalancingPolicy.make_query_plan +-> DCAwareRoundRobinPolicy.make_query_plan +-> (on PR) make_query_plan_with_exclusion delegation overhead +""" +import time +import uuid +import statistics +from unittest.mock import Mock + +from cassandra.policies import ( + DCAwareRoundRobinPolicy, + DefaultLoadBalancingPolicy, +) +from cassandra.pool import Host +from cassandra.policies import SimpleConvictionPolicy + + +class DefaultEndPoint: + def __init__(self, addr): + self.address = str(addr) + self._port = 9042 + def resolve(self): + return (self.address, self._port) + def __repr__(self): + return f"{self.address}:{self._port}" + def __hash__(self): + return hash((self.address, self._port)) + def __eq__(self, other): + return isinstance(other, DefaultEndPoint) and self.address == other.address and self._port == other._port + + +NUM_QUERIES = 100_000 +NUM_ITERATIONS = 7 + +def create_hosts(): + hosts = [] + for dc in range(5): + for rack in range(3): + for node in range(3): + h = Host(DefaultEndPoint(f"10.{dc}.{rack}.{node}"), + SimpleConvictionPolicy, host_id=uuid.uuid4()) + h.set_location_info(f"dc{dc}", f"rack{rack}") + h.set_up() + hosts.append(h) + return hosts + + +def make_query(): + q = Mock() + q.keyspace = None + q.target_host = None + return q + + +def bench_direct_dcaware(policy, query): + """Call DCAware.make_query_plan directly.""" + times = [] + for _ in range(NUM_ITERATIONS): + start = time.perf_counter_ns() + for _ in range(NUM_QUERIES): + for _ in policy.make_query_plan("ks", query): + pass + elapsed = time.perf_counter_ns() - start + times.append(elapsed / NUM_QUERIES) + return statistics.median(times) + + +def bench_via_default(policy, query): + """Call Default(DCAware).make_query_plan.""" + times = [] + for _ in range(NUM_ITERATIONS): + start = time.perf_counter_ns() + for _ in range(NUM_QUERIES): + for _ in policy.make_query_plan("ks", query): + pass + elapsed = time.perf_counter_ns() - start + times.append(elapsed / NUM_QUERIES) + return statistics.median(times) + + +def bench_with_exclusion_direct(policy, query): + """Call DCAware.make_query_plan_with_exclusion directly (empty excluded).""" + times = [] + for _ in range(NUM_ITERATIONS): + start = time.perf_counter_ns() + for _ in range(NUM_QUERIES): + for _ in policy.make_query_plan_with_exclusion("ks", query): + pass + elapsed = time.perf_counter_ns() - start + times.append(elapsed / NUM_QUERIES) + return statistics.median(times) + + +def main(): + hosts = create_hosts() + cluster = Mock() + cluster.metadata = Mock() + cluster.metadata.get_host = Mock(return_value=None) + + query = make_query() + + # DCAware direct + dcaware = DCAwareRoundRobinPolicy(local_dc="dc0", used_hosts_per_remote_dc=1) + dcaware.populate(cluster, hosts) + + ns1 = bench_direct_dcaware(dcaware, query) + ns2 = bench_with_exclusion_direct(dcaware, query) + + # Default(DCAware) + dcaware2 = DCAwareRoundRobinPolicy(local_dc="dc0", used_hosts_per_remote_dc=1) + default_policy = DefaultLoadBalancingPolicy(dcaware2) + default_policy.populate(cluster, hosts) + + ns3 = bench_via_default(default_policy, query) + + print(f"{'Path':<45} {'ns/op':>10}") + print("-" * 57) + print(f"{'DCAware.make_query_plan (delegation)':<45} {ns1:>10.0f}") + print(f"{'DCAware.make_query_plan_with_exclusion()':<45} {ns2:>10.0f}") + print(f"{'Default -> DCAware.make_query_plan':<45} {ns3:>10.0f}") + print() + print(f"Delegation overhead: {ns1 - ns2:.0f} ns/op ({(ns1/ns2 - 1)*100:.1f}%)") + + +if __name__ == "__main__": + main() diff --git 
a/bench_metadata_parsing.py b/bench_metadata_parsing.py new file mode 100644 index 0000000000..6ae030f435 --- /dev/null +++ b/bench_metadata_parsing.py @@ -0,0 +1,260 @@ +#!/usr/bin/env python3 +""" +Benchmark: Cython metadata_parser vs Python BytesIO metadata parsing. + +Isolates metadata parsing (recv_results_metadata) separately from row parsing. +Also measures the combined recv_results_rows path. + +Run with: taskset -c 0 python3 bench_metadata_parsing.py +""" + +import struct +import io +import time +import sys +import uuid + +# ── helpers ────────────────────────────────────────────────────────────── + +def write_short(buf, v): + buf.write(struct.pack('>H', v)) + +def write_int(buf, v): + buf.write(struct.pack('>i', v)) + +def write_string(buf, s): + if isinstance(s, str): + s = s.encode('utf8') + write_short(buf, len(s)) + buf.write(s) + +def write_type(buf, type_code, subtypes=()): + """Write a CQL type option to buf.""" + write_short(buf, type_code) + for st in subtypes: + if isinstance(st, tuple): + write_type(buf, st[0], st[1:]) + else: + write_type(buf, st) + + +# Type codes +UUID_TYPE = 0x000C +VARCHAR_TYPE = 0x000D +INT_TYPE = 0x0009 +BIGINT_TYPE = 0x0002 +BOOLEAN_TYPE = 0x0004 +DOUBLE_TYPE = 0x0007 +TIMESTAMP_TYPE = 0x000B +LIST_TYPE = 0x0020 +MAP_TYPE = 0x0021 +SET_TYPE = 0x0022 + + +def build_metadata_bytes(colcount, type_codes_list, global_table_spec=True, + no_metadata=False): + """Build just the metadata portion of a RESULT ROWS message.""" + buf = io.BytesIO() + if no_metadata: + write_int(buf, 0x0004) # NO_METADATA flag + write_int(buf, colcount) + return buf.getvalue() + + flags = 0x0001 if global_table_spec else 0x0000 + write_int(buf, flags) + write_int(buf, colcount) + + if global_table_spec: + write_string(buf, 'test_ks') + write_string(buf, 'test_cf') + + for i in range(colcount): + if not global_table_spec: + write_string(buf, 'test_ks') + write_string(buf, 'test_cf') + write_string(buf, 'col_%d' % i) + tc = type_codes_list[i % len(type_codes_list)] + if isinstance(tc, tuple): + write_type(buf, tc[0], tc[1:]) + else: + write_type(buf, tc) + + return buf.getvalue() + + +def build_rows_message(colcount, type_codes_list, nrows=0, global_table_spec=True): + """Build full RESULT ROWS body (metadata + rowcount + row data).""" + meta_bytes = build_metadata_bytes(colcount, type_codes_list, global_table_spec) + buf = io.BytesIO() + buf.write(meta_bytes) + write_int(buf, nrows) + # Write minimal row data + for _ in range(nrows): + for i in range(colcount): + tc = type_codes_list[i % len(type_codes_list)] + base_tc = tc[0] if isinstance(tc, tuple) else tc + if base_tc == UUID_TYPE: + write_int(buf, 16) + buf.write(uuid.uuid4().bytes) + elif base_tc in (VARCHAR_TYPE,): + val = b'test_value' + write_int(buf, len(val)) + buf.write(val) + elif base_tc in (INT_TYPE,): + write_int(buf, 4) + buf.write(struct.pack('>i', 42)) + elif base_tc in (BIGINT_TYPE, TIMESTAMP_TYPE, DOUBLE_TYPE): + write_int(buf, 8) + buf.write(struct.pack('>q', 12345678)) + elif base_tc == BOOLEAN_TYPE: + write_int(buf, 1) + buf.write(b'\x01') + elif base_tc in (LIST_TYPE, SET_TYPE): + write_int(buf, 4) + buf.write(struct.pack('>i', 0)) + elif base_tc == MAP_TYPE: + write_int(buf, 4) + buf.write(struct.pack('>i', 0)) + else: + write_int(buf, 4) + buf.write(b'\x00\x00\x00\x00') + return buf.getvalue() + + +def bench_fn(label, fn, iterations, warmup=500): + """Benchmark a zero-arg callable. 
Returns (mean_ns, cv_pct).""" + for _ in range(warmup): + fn() + times = [] + for _ in range(iterations): + t0 = time.perf_counter_ns() + fn() + t1 = time.perf_counter_ns() + times.append(t1 - t0) + times.sort() + trim = max(1, len(times) // 20) + trimmed = times[trim:-trim] + mean_ns = sum(trimmed) / len(trimmed) + variance = sum((t - mean_ns) ** 2 for t in trimmed) / len(trimmed) + stddev = variance ** 0.5 + cv_pct = (stddev / mean_ns * 100) if mean_ns > 0 else 0 + print(f" {label:50s} {mean_ns:8.0f} ns (cv {cv_pct:4.1f}%, n={len(trimmed)})") + return mean_ns, cv_pct + + +def main(): + from cassandra.protocol import ProtocolHandler, ResultMessage, read_int as py_read_int + from cassandra.cython_deps import HAVE_CYTHON + from cassandra.metadata_parser import make_recv_results_metadata + from cassandra.bytesio import BytesIOReader + + print(f"HAVE_CYTHON: {HAVE_CYTHON}") + print(f"Python: {sys.version}") + print() + + recv_meta_br, recv_prep_br = make_recv_results_metadata() + fast_cls = ProtocolHandler.message_types_by_opcode[0x08] + slow_cls = ResultMessage + + iterations = 10000 + + simple_types = [UUID_TYPE, VARCHAR_TYPE, INT_TYPE, BIGINT_TYPE, BOOLEAN_TYPE, + DOUBLE_TYPE, TIMESTAMP_TYPE, VARCHAR_TYPE, INT_TYPE, UUID_TYPE] + + # ================================================================= + print("=" * 72) + print("PART 1: recv_results_metadata ONLY (metadata parsing isolation)") + print("=" * 72) + print() + + for colcount, label_suffix in [(3, "3 cols"), (10, "10 cols"), (50, "50 cols")]: + types = simple_types[:colcount] if colcount <= len(simple_types) else simple_types + meta_bytes = build_metadata_bytes(colcount, types) + print(f"--- {label_suffix}, simple scalars ({len(meta_bytes)} bytes) ---") + + # Python: recv_results_metadata on BytesIO + def py_meta(data=meta_bytes): + f = io.BytesIO(data) + msg = slow_cls(2) + msg.recv_results_metadata(f, {}) + return msg + t_py, _ = bench_fn("Python BytesIO recv_results_metadata", py_meta, iterations) + + # Cython: recv_results_metadata_br on BytesIOReader + def cy_meta(data=meta_bytes): + reader = BytesIOReader(data) + msg = slow_cls(2) + recv_meta_br(msg, reader, {}) + return msg + t_cy, _ = bench_fn("Cython BytesIOReader recv_results_metadata", cy_meta, iterations) + print(f" Speedup: {t_py/t_cy:.2f}x") + print() + + # NO_METADATA + nm_bytes = build_metadata_bytes(10, simple_types, no_metadata=True) + print(f"--- NO_METADATA, 10 cols ({len(nm_bytes)} bytes) ---") + def py_nm(data=nm_bytes): + f = io.BytesIO(data) + msg = slow_cls(2) + msg.recv_results_metadata(f, {}) + def cy_nm(data=nm_bytes): + reader = BytesIOReader(data) + msg = slow_cls(2) + recv_meta_br(msg, reader, {}) + t_py_nm, _ = bench_fn("Python BytesIO recv_results_metadata", py_nm, iterations) + t_cy_nm, _ = bench_fn("Cython BytesIOReader recv_results_metadata", cy_nm, iterations) + print(f" Speedup: {t_py_nm/t_cy_nm:.2f}x") + print() + + # Collections + coll_types = [UUID_TYPE, VARCHAR_TYPE, + (LIST_TYPE, VARCHAR_TYPE), + (SET_TYPE, INT_TYPE), + (MAP_TYPE, VARCHAR_TYPE, INT_TYPE), + INT_TYPE, BIGINT_TYPE, BOOLEAN_TYPE, + (LIST_TYPE, UUID_TYPE), + VARCHAR_TYPE] + coll_bytes = build_metadata_bytes(10, coll_types) + print(f"--- 10 cols with collections ({len(coll_bytes)} bytes) ---") + def py_coll(data=coll_bytes): + f = io.BytesIO(data) + msg = slow_cls(2) + msg.recv_results_metadata(f, {}) + def cy_coll(data=coll_bytes): + reader = BytesIOReader(data) + msg = slow_cls(2) + recv_meta_br(msg, reader, {}) + t_py_coll, _ = bench_fn("Python BytesIO recv_results_metadata", 
py_coll, iterations) + t_cy_coll, _ = bench_fn("Cython BytesIOReader recv_results_metadata", cy_coll, iterations) + print(f" Speedup: {t_py_coll/t_cy_coll:.2f}x") + print() + + # ================================================================= + print("=" * 72) + print("PART 2: recv_results_rows (metadata + row parsing combined)") + print("=" * 72) + print() + + for colcount, nrows, iters in [(10, 0, 10000), (10, 10, 5000), + (10, 100, 2000), (50, 0, 5000)]: + types = simple_types[:colcount] if colcount <= len(simple_types) else simple_types + data = build_rows_message(colcount, types, nrows=nrows) + print(f"--- {colcount} cols, {nrows} rows ({len(data)} bytes) ---") + + def fast_fn(data=data): + f = io.BytesIO(data) + msg = fast_cls(2) + msg.recv_results_rows(f, 4, {}, None, None) + def slow_fn(data=data): + f = io.BytesIO(data) + msg = slow_cls(2) + msg.recv_results_rows(f, 4, {}, None, None) + + t_fast, _ = bench_fn("Cython (new: single BytesIOReader)", fast_fn, iters) + t_slow, _ = bench_fn("Python BytesIO (baseline)", slow_fn, iters) + print(f" Speedup: {t_slow/t_fast:.2f}x") + print() + + +if __name__ == '__main__': + main() diff --git a/bench_policies.py b/bench_policies.py new file mode 100644 index 0000000000..e2d996f01f --- /dev/null +++ b/bench_policies.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 +"""Benchmark for load balancing policy query plan generation. + +Topology: 5 DCs x 3 racks x 3 nodes = 45 nodes. +Measures median of 5 iterations, each with 50K query plan materializations. +Reports ns/op. +""" +import time +import uuid +import statistics +from unittest.mock import Mock, PropertyMock + +from cassandra.policies import ( + DCAwareRoundRobinPolicy, + RackAwareRoundRobinPolicy, + TokenAwarePolicy, + DefaultLoadBalancingPolicy, + HostFilterPolicy, + HostDistance, +) +from cassandra.pool import Host +from cassandra.metadata import Murmur3Token +from cassandra.policies import SimpleConvictionPolicy + + +class DefaultEndPoint: + def __init__(self, addr): + self.address = str(addr) + self._port = 9042 + + def resolve(self): + return (self.address, self._port) + + def __repr__(self): + return f"{self.address}:{self._port}" + + def __hash__(self): + return hash((self.address, self._port)) + + def __eq__(self, other): + return isinstance(other, DefaultEndPoint) and self.address == other.address and self._port == other._port + + +NUM_DCS = 5 +NUM_RACKS = 3 +NUM_NODES_PER_RACK = 3 +NUM_QUERIES = 50_000 +NUM_ITERATIONS = 10 + +LOCAL_DC = "dc0" +LOCAL_RACK = "rack0" + + +def create_hosts(): + hosts = [] + idx = 0 + for dc in range(NUM_DCS): + for rack in range(NUM_RACKS): + for node in range(NUM_NODES_PER_RACK): + h = Host(DefaultEndPoint(f"10.{dc}.{rack}.{node}"), + SimpleConvictionPolicy, + host_id=uuid.uuid4()) + h.set_location_info(f"dc{dc}", f"rack{rack}") + h.set_up() + hosts.append(h) + idx += 1 + return hosts + + +def make_mock_cluster(hosts): + cluster = Mock() + token_map = Mock() + token_map.token_class = Murmur3Token + # Build a simple ring with one token per host + tokens = list(range(0, 2**63, 2**63 // len(hosts)))[:len(hosts)] + token_to_host = {Murmur3Token(t): h for t, h in zip(tokens, hosts)} + + def get_replicas(ks, token): + # Return the 3 hosts closest to the token + sorted_tokens = sorted(token_to_host.keys(), key=lambda t: t.value) + result = [] + # Find insertion point + for i, t in enumerate(sorted_tokens): + if t.value >= token.value: + for j in range(3): + result.append(token_to_host[sorted_tokens[(i + j) % len(sorted_tokens)]]) + return result + # Wrap 
around + for j in range(3): + result.append(token_to_host[sorted_tokens[j]]) + return result + + token_map.get_replicas = get_replicas + + from cassandra.tablets import Tablets + cluster.metadata = Mock() + cluster.metadata.token_map = token_map + cluster.metadata._tablets = Tablets({}) + cluster.metadata.get_replicas = lambda ks, key: get_replicas(ks, Murmur3Token.from_key(key)) + return cluster + + +def make_query(routing_key=None): + q = Mock() + q.keyspace = "ks" + q.table = "tbl" + q.routing_key = routing_key + q.is_lwt = Mock(return_value=False) + return q + + +def bench(name, policy, hosts, queries): + """Run benchmark: materialize query plans, return min ns/op.""" + times = [] + for _ in range(NUM_ITERATIONS): + start = time.perf_counter_ns() + for q in queries: + plan = policy.make_query_plan("ks", q) + # Materialize the full plan + for _ in plan: + pass + elapsed = time.perf_counter_ns() - start + times.append(elapsed / len(queries)) + # Use min: on a noisy system the fastest run is closest to true cost + return min(times) + + +def main(): + hosts = create_hosts() + cluster = make_mock_cluster(hosts) + + # Pre-generate routing keys for TokenAware + import struct + import random + import numpy as np + + # Unique keys (100% cache miss) + routing_keys_unique = [struct.pack(">q", i) for i in range(NUM_QUERIES)] + queries_unique = [make_query(rk) for rk in routing_keys_unique] + queries_no_routing = [make_query(None) for _ in range(NUM_QUERIES)] + + # Zipfian workload: 500 distinct keys, sampled with Zipf distribution + # Models real workloads where a few partitions are very hot. + NUM_DISTINCT_KEYS = 500 + distinct_keys = [struct.pack(">q", i) for i in range(NUM_DISTINCT_KEYS)] + rng = np.random.default_rng(42) + zipf_indices = rng.zipf(1.2, size=NUM_QUERIES) % NUM_DISTINCT_KEYS + queries_zipfian = [make_query(distinct_keys[i]) for i in zipf_indices] + + results = {} + + # 1. DCAwareRoundRobinPolicy + policy = DCAwareRoundRobinPolicy(local_dc=LOCAL_DC, used_hosts_per_remote_dc=1) + policy.populate(cluster, hosts) + ns = bench("DCAware", policy, hosts, queries_no_routing) + results["DCAware"] = ns + + # 2. RackAwareRoundRobinPolicy + policy = RackAwareRoundRobinPolicy(local_dc=LOCAL_DC, local_rack=LOCAL_RACK, used_hosts_per_remote_dc=1) + policy.populate(cluster, hosts) + ns = bench("RackAware", policy, hosts, queries_no_routing) + results["RackAware"] = ns + + # 3. TokenAware(DCAware) -- cache miss (unique keys) + child = DCAwareRoundRobinPolicy(local_dc=LOCAL_DC, used_hosts_per_remote_dc=1) + policy = TokenAwarePolicy(child, shuffle_replicas=False) + policy.populate(cluster, hosts) + ns = bench("TokenAware(DCAware) miss", policy, hosts, queries_unique) + results["TokenAware(DCAware) miss"] = ns + + # 3b. TokenAware(DCAware) -- cache hit (zipfian keys) + child = DCAwareRoundRobinPolicy(local_dc=LOCAL_DC, used_hosts_per_remote_dc=1) + policy = TokenAwarePolicy(child, shuffle_replicas=False) + policy.populate(cluster, hosts) + ns = bench("TokenAware(DCAware) zipf", policy, hosts, queries_zipfian) + results["TokenAware(DCAware) zipf"] = ns + + # 4. TokenAware(RackAware) -- cache miss (unique keys) + child = RackAwareRoundRobinPolicy(local_dc=LOCAL_DC, local_rack=LOCAL_RACK, used_hosts_per_remote_dc=1) + policy = TokenAwarePolicy(child, shuffle_replicas=False) + policy.populate(cluster, hosts) + ns = bench("TokenAware(RackAware) miss", policy, hosts, queries_unique) + results["TokenAware(RackAware) miss"] = ns + + # 4b. 
TokenAware(RackAware) -- cache hit (zipfian keys) + child = RackAwareRoundRobinPolicy(local_dc=LOCAL_DC, local_rack=LOCAL_RACK, used_hosts_per_remote_dc=1) + policy = TokenAwarePolicy(child, shuffle_replicas=False) + policy.populate(cluster, hosts) + ns = bench("TokenAware(RackAware) zipf", policy, hosts, queries_zipfian) + results["TokenAware(RackAware) zipf"] = ns + + # 5. Default(DCAware) + child = DCAwareRoundRobinPolicy(local_dc=LOCAL_DC, used_hosts_per_remote_dc=1) + policy = DefaultLoadBalancingPolicy(child) + policy.populate(cluster, hosts) + ns = bench("Default(DCAware)", policy, hosts, queries_no_routing) + results["Default(DCAware)"] = ns + + # 6. HostFilter(DCAware) + child = DCAwareRoundRobinPolicy(local_dc=LOCAL_DC, used_hosts_per_remote_dc=1) + policy = HostFilterPolicy(child, predicate=lambda h: True) + policy.populate(cluster, hosts) + ns = bench("HostFilter(DCAware)", policy, hosts, queries_no_routing) + results["HostFilter(DCAware)"] = ns + + print(f"\n{'Policy':<36} {'ns/op':>10}") + print("-" * 48) + for name, ns in results.items(): + print(f"{name:<36} {ns:>10.0f}") + + +if __name__ == "__main__": + main() diff --git a/cassandra/policies.py b/cassandra/policies.py index ceb5ebdc45..50a38decb2 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -13,7 +13,7 @@ # limitations under the License. import random -from collections import namedtuple +from collections import namedtuple, OrderedDict from itertools import islice, cycle, groupby, repeat import logging from random import randint, shuffle @@ -157,6 +157,18 @@ def make_query_plan(self, working_keyspace=None, query=None): """ raise NotImplementedError() + def make_query_plan_with_exclusion(self, working_keyspace=None, query=None, excluded=()): + """ + Same as :meth:`make_query_plan`, but with an additional `excluded` parameter. + `excluded` should be a container (set, list, etc.) of hosts to skip. + + The default implementation simply delegates to `make_query_plan` and filters the result. + Subclasses may override this for performance. + """ + for host in self.make_query_plan(working_keyspace, query): + if host not in excluded: + yield host + def check_supported(self): """ This will be called after the cluster Metadata has been initialized. @@ -198,6 +210,20 @@ def make_query_plan(self, working_keyspace=None, query=None): else: return [] + def make_query_plan_with_exclusion(self, working_keyspace=None, query=None, excluded=()): + pos = self._position + self._position += 1 + + hosts = self._live_hosts + length = len(hosts) + if length: + pos %= length + for host in islice(cycle(hosts), pos, pos + length): + if host not in excluded: + yield host + else: + return + def on_up(self, host): with self._hosts_lock: self._live_hosts = self._live_hosts.union((host, )) @@ -223,7 +249,6 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy): """ local_dc = None - used_hosts_per_remote_dc = 0 def __init__(self, local_dc='', used_hosts_per_remote_dc=0): """ @@ -242,36 +267,52 @@ def __init__(self, local_dc='', used_hosts_per_remote_dc=0): By default, all remote hosts are ignored. 
""" self.local_dc = local_dc - self.used_hosts_per_remote_dc = used_hosts_per_remote_dc self._dc_live_hosts = {} + self._remote_hosts = {} + self._used_hosts_per_remote_dc = used_hosts_per_remote_dc self._position = 0 LoadBalancingPolicy.__init__(self) + @property + def used_hosts_per_remote_dc(self): + return self._used_hosts_per_remote_dc + + @used_hosts_per_remote_dc.setter + def used_hosts_per_remote_dc(self, value): + self._used_hosts_per_remote_dc = value + self._refresh_remote_hosts() + def _dc(self, host): return host.datacenter or self.local_dc + def _refresh_remote_hosts(self): + # Using dict.fromkeys() instead of a set to preserve insertion order (Python 3.7+) + # while still providing O(1) lookup for `host in self._remote_hosts`. + remote_hosts = {} + if self.used_hosts_per_remote_dc > 0: + for datacenter, hosts in self._dc_live_hosts.items(): + if datacenter != self.local_dc: + remote_hosts.update( + dict.fromkeys(hosts[:self.used_hosts_per_remote_dc]) + ) + self._remote_hosts = remote_hosts + def populate(self, cluster, hosts): for dc, dc_hosts in groupby(hosts, lambda h: self._dc(h)): self._dc_live_hosts[dc] = tuple({*dc_hosts, *self._dc_live_hosts.get(dc, [])}) self._position = randint(0, len(hosts) - 1) if hosts else 0 + self._refresh_remote_hosts() def distance(self, host): dc = self._dc(host) if dc == self.local_dc: return HostDistance.LOCAL - if not self.used_hosts_per_remote_dc: - return HostDistance.IGNORED - else: - dc_hosts = self._dc_live_hosts.get(dc) - if not dc_hosts: - return HostDistance.IGNORED - - if host in list(dc_hosts)[:self.used_hosts_per_remote_dc]: - return HostDistance.REMOTE - else: - return HostDistance.IGNORED + remote_hosts = self._remote_hosts + if host in remote_hosts: + return HostDistance.REMOTE + return HostDistance.IGNORED def make_query_plan(self, working_keyspace=None, query=None): # not thread-safe, but we don't care much about lost increments @@ -280,32 +321,74 @@ def make_query_plan(self, working_keyspace=None, query=None): self._position += 1 local_live = self._dc_live_hosts.get(self.local_dc, ()) - pos = (pos % len(local_live)) if local_live else 0 - for host in islice(cycle(local_live), pos, pos + len(local_live)): + length = len(local_live) + if length: + pos %= length + for i in range(length): + yield local_live[(pos + i) % length] + + # Read _remote_hosts late so topology changes during local + # iteration are visible. + for host in self._remote_hosts: yield host - # the dict can change, so get candidate DCs iterating over keys of a copy - other_dcs = [dc for dc in self._dc_live_hosts.copy().keys() if dc != self.local_dc] - for dc in other_dcs: - remote_live = self._dc_live_hosts.get(dc, ()) - for host in remote_live[:self.used_hosts_per_remote_dc]: + def make_query_plan_with_exclusion(self, working_keyspace=None, query=None, excluded=()): + # not thread-safe, but we don't care much about lost increments + # for the purposes of load balancing + pos = self._position + self._position += 1 + + local_live = self._dc_live_hosts.get(self.local_dc, ()) + length = len(local_live) + if not excluded: + if length: + pos %= length + for i in range(length): + yield local_live[(pos + i) % length] + # Read _remote_hosts late so topology changes during local + # iteration are visible. 
+ for host in self._remote_hosts: yield host + return + + if not isinstance(excluded, set): + excluded = set(excluded) + + if length: + pos %= length + for i in range(length): + host = local_live[(pos + i) % length] + if host in excluded: + continue + yield host + + for host in self._remote_hosts: + if host in excluded: + continue + yield host def on_up(self, host): # not worrying about threads because this will happen during # control connection startup/refresh + refresh_remote = False if not self.local_dc and host.datacenter: self.local_dc = host.datacenter log.info("Using datacenter '%s' for DCAwareRoundRobinPolicy (via host '%s'); " "if incorrect, please specify a local_dc to the constructor, " "or limit contact points to local cluster nodes" % (self.local_dc, host.endpoint)) + refresh_remote = True dc = self._dc(host) with self._hosts_lock: current_hosts = self._dc_live_hosts.get(dc, ()) if host not in current_hosts: self._dc_live_hosts[dc] = current_hosts + (host, ) + if dc != self.local_dc: + refresh_remote = True + + if refresh_remote: + self._refresh_remote_hosts() def on_down(self, host): dc = self._dc(host) @@ -318,6 +401,9 @@ def on_down(self, host): else: del self._dc_live_hosts[dc] + if dc != self.local_dc: + self._refresh_remote_hosts() + def on_add(self, host): self.on_up(host) @@ -333,7 +419,6 @@ class RackAwareRoundRobinPolicy(LoadBalancingPolicy): local_dc = None local_rack = None - used_hosts_per_remote_dc = 0 def __init__(self, local_dc, local_rack, used_hosts_per_remote_dc=0): """ @@ -350,19 +435,48 @@ def __init__(self, local_dc, local_rack, used_hosts_per_remote_dc=0): """ self.local_rack = local_rack self.local_dc = local_dc - self.used_hosts_per_remote_dc = used_hosts_per_remote_dc self._live_hosts = {} self._dc_live_hosts = {} + self._remote_hosts = {} + self._non_local_rack_hosts = () + self._used_hosts_per_remote_dc = used_hosts_per_remote_dc self._endpoints = [] self._position = 0 LoadBalancingPolicy.__init__(self) + @property + def used_hosts_per_remote_dc(self): + return self._used_hosts_per_remote_dc + + @used_hosts_per_remote_dc.setter + def used_hosts_per_remote_dc(self, value): + self._used_hosts_per_remote_dc = value + self._refresh_remote_hosts() + def _rack(self, host): return host.rack or self.local_rack def _dc(self, host): return host.datacenter or self.local_dc + def _refresh_remote_hosts(self): + # Using dict.fromkeys() instead of a set to preserve insertion order (Python 3.7+) + # while still providing O(1) lookup for `host in self._remote_hosts`. 
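        # (Aside, not part of this diff; h1/h2/h3 are placeholder hosts:
        #  dict.fromkeys gives both properties at once --
        #  list(dict.fromkeys([h2, h1, h3])) preserves the [h2, h1, h3] order while
        #  `h1 in d` remains an O(1) hash lookup; a plain set would keep the O(1)
        #  lookup but lose the round-robin ordering of the hosts tuple.)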
+ remote_hosts = {} + if self.used_hosts_per_remote_dc > 0: + for datacenter, hosts in self._dc_live_hosts.items(): + if datacenter != self.local_dc: + remote_hosts.update( + dict.fromkeys(hosts[:self.used_hosts_per_remote_dc]) + ) + self._remote_hosts = remote_hosts + + def _refresh_non_local_rack_hosts(self): + local_live = self._dc_live_hosts.get(self.local_dc, ()) + self._non_local_rack_hosts = tuple( + h for h in local_live if self._rack(h) != self.local_rack + ) + def populate(self, cluster, hosts): for (dc, rack), rack_hosts in groupby(hosts, lambda host: (self._dc(host), self._rack(host))): self._live_hosts[(dc, rack)] = tuple({*rack_hosts, *self._live_hosts.get((dc, rack), [])}) @@ -370,71 +484,118 @@ def populate(self, cluster, hosts): self._dc_live_hosts[dc] = tuple({*dc_hosts, *self._dc_live_hosts.get(dc, [])}) self._position = randint(0, len(hosts) - 1) if hosts else 0 + self._refresh_remote_hosts() + self._refresh_non_local_rack_hosts() def distance(self, host): - rack = self._rack(host) dc = self._dc(host) - if rack == self.local_rack and dc == self.local_dc: - return HostDistance.LOCAL_RACK - if dc == self.local_dc: + if self._rack(host) == self.local_rack: + return HostDistance.LOCAL_RACK return HostDistance.LOCAL - if not self.used_hosts_per_remote_dc: - return HostDistance.IGNORED - - dc_hosts = self._dc_live_hosts.get(dc, ()) - if not dc_hosts: - return HostDistance.IGNORED - if host in dc_hosts and dc_hosts.index(host) < self.used_hosts_per_remote_dc: + remote_hosts = self._remote_hosts + if host in remote_hosts: return HostDistance.REMOTE - else: - return HostDistance.IGNORED + return HostDistance.IGNORED def make_query_plan(self, working_keyspace=None, query=None): pos = self._position self._position += 1 local_rack_live = self._live_hosts.get((self.local_dc, self.local_rack), ()) - pos = (pos % len(local_rack_live)) if local_rack_live else 0 - # Slice the cyclic iterator to start from pos and include the next len(local_live) elements - # This ensures we get exactly one full cycle starting from pos - for host in islice(cycle(local_rack_live), pos, pos + len(local_rack_live)): - yield host + length = len(local_rack_live) + if length: + p = pos % length + for i in range(length): + yield local_rack_live[(p + i) % length] + + local_non_rack = self._non_local_rack_hosts + length = len(local_non_rack) + if length: + p = pos % length + for i in range(length): + yield local_non_rack[(p + i) % length] - local_live = [host for host in self._dc_live_hosts.get(self.local_dc, ()) if host.rack != self.local_rack] - pos = (pos % len(local_live)) if local_live else 0 - for host in islice(cycle(local_live), pos, pos + len(local_live)): + # Read _remote_hosts late so topology changes during local + # iteration are visible. 
+ for host in self._remote_hosts: yield host - # the dict can change, so get candidate DCs iterating over keys of a copy - for dc, remote_live in self._dc_live_hosts.copy().items(): - if dc != self.local_dc: - for host in remote_live[:self.used_hosts_per_remote_dc]: - yield host + def make_query_plan_with_exclusion(self, working_keyspace=None, query=None, excluded=()): + pos = self._position + self._position += 1 + + local_rack_live = self._live_hosts.get((self.local_dc, self.local_rack), ()) + length = len(local_rack_live) + if not excluded: + if length: + p = pos % length + for i in range(length): + yield local_rack_live[(p + i) % length] + + local_non_rack = self._non_local_rack_hosts + length = len(local_non_rack) + if length: + p = pos % length + for i in range(length): + yield local_non_rack[(p + i) % length] + + # Read _remote_hosts late so topology changes during local + # iteration are visible. + for host in self._remote_hosts: + yield host + return + + if not isinstance(excluded, set): + excluded = set(excluded) + + if length: + p = pos % length + for i in range(length): + host = local_rack_live[(p + i) % length] + if host in excluded: + continue + yield host + + local_non_rack = self._non_local_rack_hosts + length = len(local_non_rack) + if length: + p = pos % length + for i in range(length): + host = local_non_rack[(p + i) % length] + if host in excluded: + continue + yield host + + # Read _remote_hosts late so topology changes during local + # iteration are visible. + for host in self._remote_hosts: + if host in excluded: + continue + yield host def on_up(self, host): dc = self._dc(host) rack = self._rack(host) with self._hosts_lock: - current_rack_hosts = self._live_hosts.get((dc, rack), ()) - if host not in current_rack_hosts: - self._live_hosts[(dc, rack)] = current_rack_hosts + (host, ) current_dc_hosts = self._dc_live_hosts.get(dc, ()) if host not in current_dc_hosts: self._dc_live_hosts[dc] = current_dc_hosts + (host, ) + if dc != self.local_dc: + self._refresh_remote_hosts() + else: + self._refresh_non_local_rack_hosts() + + current_rack_hosts = self._live_hosts.get((dc, rack), ()) + if host not in current_rack_hosts: + self._live_hosts[(dc, rack)] = current_rack_hosts + (host, ) + def on_down(self, host): dc = self._dc(host) rack = self._rack(host) with self._hosts_lock: - current_rack_hosts = self._live_hosts.get((dc, rack), ()) - if host in current_rack_hosts: - hosts = tuple(h for h in current_rack_hosts if h != host) - if hosts: - self._live_hosts[(dc, rack)] = hosts - else: - del self._live_hosts[(dc, rack)] current_dc_hosts = self._dc_live_hosts.get(dc, ()) if host in current_dc_hosts: hosts = tuple(h for h in current_dc_hosts if h != host) @@ -443,6 +604,19 @@ def on_down(self, host): else: del self._dc_live_hosts[dc] + if dc != self.local_dc: + self._refresh_remote_hosts() + else: + self._refresh_non_local_rack_hosts() + + current_rack_hosts = self._live_hosts.get((dc, rack), ()) + if host in current_rack_hosts: + hosts = tuple(h for h in current_rack_hosts if h != host) + if hosts: + self._live_hosts[(dc, rack)] = hosts + else: + del self._live_hosts[(dc, rack)] + def on_add(self, host): self.on_up(host) @@ -464,6 +638,11 @@ class TokenAwarePolicy(LoadBalancingPolicy): If no :attr:`~.Statement.routing_key` is set on the query, the child policy's query plan will be used as is. + + An LRU cache of size :attr:`cache_replicas_size` (default 1024) avoids + repeated token-to-replica lookups for the same (keyspace, routing_key) + pair. Set to 0 to disable caching. 
The cache is automatically + invalidated when the cluster topology changes. """ _child_policy = None @@ -473,9 +652,15 @@ class TokenAwarePolicy(LoadBalancingPolicy): Yield local replicas in a random order. """ - def __init__(self, child_policy, shuffle_replicas=True): + def __init__(self, child_policy, shuffle_replicas=True, cache_replicas_size=1024): + super().__init__() self._child_policy = child_policy self.shuffle_replicas = shuffle_replicas + self._cluster_metadata = None + self._cache_replicas_size = max(0, cache_replicas_size) + self._replica_cache = OrderedDict() + self._replica_cache_token_map_ref = None + self._cache_lock = Lock() def populate(self, cluster, hosts): self._cluster_metadata = cluster.metadata @@ -493,40 +678,168 @@ def check_supported(self): def distance(self, *args, **kwargs): return self._child_policy.distance(*args, **kwargs) + def _get_cached_replicas(self, keyspace, table, routing_key_bytes, token_map): + """ + Return cached replicas for the given keyspace, table, and routing key, + or None on cache miss. The cache is invalidated when: + - the token_map object identity changes (full topology rebuild), or + - the keyspace's replica map has been rebuilt in-place (e.g. ALTER + KEYSPACE), detected via object identity of the per-keyspace map. + + The table is part of the cache key so that tablet-backed and + non-tablet tables in the same keyspace don't collide. + """ + if not self._cache_replicas_size: + return None + with self._cache_lock: + if token_map is not self._replica_cache_token_map_ref: + # Token map was rebuilt -- entire cache is stale. + self._replica_cache = OrderedDict() + self._replica_cache_token_map_ref = token_map + cache_key = (keyspace, table, routing_key_bytes) + entry = self._replica_cache.get(cache_key) + if entry is not None: + replicas, ks_map_ref = entry + # Validate the keyspace replica map hasn't been rebuilt + # in-place (e.g. ALTER KEYSPACE changes replication). + current_ks_map = token_map.tokens_to_hosts_by_ks.get(keyspace) + if ks_map_ref is not current_ks_map: + del self._replica_cache[cache_key] + return None + # Promote to most-recently-used. + self._replica_cache.move_to_end(cache_key) + return replicas + return None + + def _put_cached_replicas(self, keyspace, table, routing_key_bytes, replicas, token_map): + """ + Store replicas in the LRU cache, evicting the oldest entry if + the cache exceeds its configured size. The keyspace's current + replica-map object reference is stored alongside so that in-place + rebuilds (ALTER KEYSPACE) are detected on lookup via identity check. 
+ """ + if not self._cache_replicas_size: + return + with self._cache_lock: + if token_map is not self._replica_cache_token_map_ref: + self._replica_cache = OrderedDict() + self._replica_cache_token_map_ref = token_map + cache_key = (keyspace, table, routing_key_bytes) + ks_map_ref = token_map.tokens_to_hosts_by_ks.get(keyspace) + self._replica_cache[cache_key] = (replicas, ks_map_ref) + self._replica_cache.move_to_end(cache_key) + if len(self._replica_cache) > self._cache_replicas_size: + self._replica_cache.popitem(last=False) + def make_query_plan(self, working_keyspace=None, query=None): keyspace = query.keyspace if query and query.keyspace else working_keyspace child = self._child_policy if query is None or query.routing_key is None or keyspace is None: - for host in child.make_query_plan(keyspace, query): - yield host + yield from child.make_query_plan(keyspace, query) return + cluster_metadata = self._cluster_metadata + token_map = cluster_metadata.token_map replicas = [] - tablet = self._cluster_metadata._tablets.get_tablet_for_key( - keyspace, query.table, self._cluster_metadata.token_map.token_class.from_key(query.routing_key)) - - if tablet is not None: - replicas_mapped = set(map(lambda r: r[0], tablet.replicas)) - child_plan = child.make_query_plan(keyspace, query) - - replicas = [host for host in child_plan if host.host_id in replicas_mapped] + if token_map: + try: + # Check the LRU cache first -- avoids the hash (from_key) + # and token-map lookup on repeated routing keys. + cached = self._get_cached_replicas(keyspace, query.table, query.routing_key, token_map) + if cached is not None: + replicas = cached + else: + token = token_map.token_class.from_key(query.routing_key) + + # Only check tablets if any exist -- avoids the method + # call + dict lookup when tablets are not in use. + if cluster_metadata._tablets: + tablet = cluster_metadata._tablets.get_tablet_for_key( + keyspace, query.table, token) + if tablet is not None: + replicas_mapped = {r[0] for r in tablet.replicas} + child_plan = child.make_query_plan(keyspace, query) + replicas = [host for host in child_plan if host.host_id in replicas_mapped] + + if not replicas: + try: + replicas = token_map.get_replicas(keyspace, token) + except Exception: + log.debug( + "Failed to get replicas from token_map, " + "falling back to cluster metadata" + ) + replicas = cluster_metadata.get_replicas(keyspace, query.routing_key) + self._put_cached_replicas( + keyspace, query.table, query.routing_key, replicas, token_map + ) + except Exception: + log.debug( + "Failed to resolve token or tablet for query plan, " + "falling back to child policy", + exc_info=True, + ) + + if self.shuffle_replicas: + if not query.is_lwt(): + replicas = list(replicas) + shuffle(replicas) + + local_rack = [] + local = [] + remote = [] + + child_distance = child.distance + + for replica in replicas: + if replica.is_up: + d = child_distance(replica) + if d == HostDistance.LOCAL_RACK: + local_rack.append(replica) + elif d == HostDistance.LOCAL: + local.append(replica) + elif d == HostDistance.REMOTE: + remote.append(replica) + + if local_rack or local or remote: + yielded = set() + + for replica in local_rack: + yielded.add(replica) + yield replica + + for replica in local: + yielded.add(replica) + yield replica + + for replica in remote: + yielded.add(replica) + yield replica + + # Yield the rest of the cluster (non-replica hosts). + # DCAware and RackAware already yield in distance order + # (local_rack -> local -> remote), so we can stream directly. 
+ # For other child policies we must re-sort by distance. + if isinstance(child, (DCAwareRoundRobinPolicy, RackAwareRoundRobinPolicy)): + yield from child.make_query_plan_with_exclusion(keyspace, query, yielded) + else: + remaining_local_rack = [] + remaining_local = [] + remaining_remote = [] + for host in child.make_query_plan_with_exclusion(keyspace, query, yielded): + d = child_distance(host) + if d == HostDistance.LOCAL_RACK: + remaining_local_rack.append(host) + elif d == HostDistance.LOCAL: + remaining_local.append(host) + elif d == HostDistance.REMOTE: + remaining_remote.append(host) + yield from remaining_local_rack + yield from remaining_local + yield from remaining_remote else: - replicas = self._cluster_metadata.get_replicas(keyspace, query.routing_key) - - if self.shuffle_replicas and not query.is_lwt(): - shuffle(replicas) - - def yield_in_order(hosts): - for distance in [HostDistance.LOCAL_RACK, HostDistance.LOCAL, HostDistance.REMOTE]: - for replica in hosts: - if replica.is_up and child.distance(replica) == distance: - yield replica - - # yield replicas: local_rack, local, remote - yield from yield_in_order(replicas) - # yield rest of the cluster: local_rack, local, remote - yield from yield_in_order([host for host in child.make_query_plan(keyspace, query) if host not in replicas]) + yield from child.make_query_plan(keyspace, query) def on_up(self, *args, **kwargs): return self._child_policy.on_up(*args, **kwargs) @@ -702,6 +1015,16 @@ def make_query_plan(self, working_keyspace=None, query=None): if self.predicate(host): yield host + def make_query_plan_with_exclusion(self, working_keyspace=None, query=None, excluded=()): + if excluded and not isinstance(excluded, set): + excluded = set(excluded) + child_qp = self._child_policy.make_query_plan_with_exclusion( + working_keyspace=working_keyspace, query=query, excluded=excluded + ) + for host in child_qp: + if self.predicate(host): + yield host + def check_supported(self): return self._child_policy.check_supported() @@ -1356,6 +1679,28 @@ def make_query_plan(self, working_keyspace=None, query=None): for h in child.make_query_plan(keyspace, query): yield h + def make_query_plan_with_exclusion(self, working_keyspace=None, query=None, excluded=()): + if query and query.keyspace: + keyspace = query.keyspace + else: + keyspace = working_keyspace + + addr = getattr(query, 'target_host', None) if query else None + target_host = self._cluster_metadata.get_host(addr) + + if excluded and not isinstance(excluded, set): + excluded = set(excluded) + + child = self._child_policy + if target_host and target_host.is_up and target_host not in excluded: + yield target_host + # Include target_host in the exclusion set so the child policy + # can skip it early rather than yielding it for us to filter. 
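            # (Aside, not part of this diff: the expression below parses as
            #  `(excluded | {target_host}) if excluded else {target_host}` because
            #  set union binds tighter than the conditional expression, so
            #  target_host ends up in the exclusion set handed to the child policy
            #  on both branches.)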
+ child_excluded = excluded | {target_host} if excluded else {target_host} + yield from child.make_query_plan_with_exclusion(keyspace, query, child_excluded) + else: + yield from child.make_query_plan_with_exclusion(keyspace, query, excluded) + # TODO for backward compatibility, remove in next major class DSELoadBalancingPolicy(DefaultLoadBalancingPolicy): diff --git a/cassandra/tablets.py b/cassandra/tablets.py index 96e61a50c2..423140d0b7 100644 --- a/cassandra/tablets.py +++ b/cassandra/tablets.py @@ -55,6 +55,9 @@ def __init__(self, tablets): self._tablets = tablets self._lock = Lock() + def __bool__(self): + return bool(self._tablets) + def table_has_tablets(self, keyspace, table) -> bool: return bool(self._tablets.get((keyspace, table), [])) diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index 6142af1aa1..46ee8b3f8e 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -187,6 +187,39 @@ def test_no_live_nodes(self): qplan = list(policy.make_query_plan()) assert qplan == [] + def test_make_query_plan_with_exclusion_basic(self): + hosts = [0, 1, 2, 3] + policy = RoundRobinPolicy() + policy.populate(None, hosts) + excluded = {1, 3} + qplan = list(policy.make_query_plan_with_exclusion(excluded=excluded)) + assert 1 not in qplan + assert 3 not in qplan + assert sorted(qplan) == [0, 2] + + def test_make_query_plan_with_exclusion_empty(self): + hosts = [0, 1, 2, 3] + policy = RoundRobinPolicy() + policy.populate(None, hosts) + qplan = list(policy.make_query_plan_with_exclusion(excluded=())) + assert sorted(qplan) == hosts + + def test_make_query_plan_with_exclusion_all_excluded(self): + hosts = [0, 1, 2, 3] + policy = RoundRobinPolicy() + policy.populate(None, hosts) + excluded = set(hosts) + qplan = list(policy.make_query_plan_with_exclusion(excluded=excluded)) + assert qplan == [] + + def test_make_query_plan_with_exclusion_superset(self): + hosts = [0, 1, 2, 3] + policy = RoundRobinPolicy() + policy.populate(None, hosts) + excluded = {0, 1, 2, 3, 99, 100} + qplan = list(policy.make_query_plan_with_exclusion(excluded=excluded)) + assert qplan == [] + @pytest.mark.parametrize("policy_specialization, constructor_args", [(DCAwareRoundRobinPolicy, ("dc1", )), (RackAwareRoundRobinPolicy, ("dc1", "rack1"))]) class TestRackOrDCAwareRoundRobinPolicy: @@ -273,6 +306,12 @@ def test_get_distance(self, policy_specialization, constructor_args): elif isinstance(policy_specialization, RackAwareRoundRobinPolicy): assert policy.distance(host) == HostDistance.LOCAL_RACK + # Reset policy state to simulate a fresh view or handle the "move" correctly + if hasattr(policy, '_live_hosts'): + policy._live_hosts.clear() + if hasattr(policy, '_dc_live_hosts'): + policy._dc_live_hosts.clear() + # same dc different rack host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) host.set_location_info("dc1", "rack2") @@ -527,6 +566,54 @@ def test_no_nodes(self, policy_specialization, constructor_args): qplan = list(policy.make_query_plan()) assert qplan == [] + def test_make_query_plan_with_exclusion_basic(self, policy_specialization, constructor_args): + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(6)] + for h in hosts[:2]: + h.set_location_info("dc1", "rack1") + for h in hosts[2:4]: + h.set_location_info("dc1", "rack2") + for h in hosts[4:]: + h.set_location_info("dc2", "rack1") + + policy = policy_specialization(*constructor_args, used_hosts_per_remote_dc=2) + policy.populate(Mock(), hosts) + + # 
exclude some local and remote hosts + excluded = {hosts[0], hosts[4]} + qplan = list(policy.make_query_plan_with_exclusion(excluded=excluded)) + assert hosts[0] not in qplan + assert hosts[4] not in qplan + assert len(qplan) == 4 # 6 total - 2 excluded + + def test_make_query_plan_with_exclusion_empty(self, policy_specialization, constructor_args): + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] + for h in hosts[:2]: + h.set_location_info("dc1", "rack1") + for h in hosts[2:]: + h.set_location_info("dc1", "rack2") + + policy = policy_specialization(*constructor_args) + policy.populate(Mock(), hosts) + + # empty exclusion should return all hosts + qplan_excl = list(policy.make_query_plan_with_exclusion(excluded=())) + qplan_normal = list(policy.make_query_plan()) + assert sorted(qplan_excl, key=id) == sorted(qplan_normal, key=id) + + def test_make_query_plan_with_exclusion_all_excluded(self, policy_specialization, constructor_args): + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] + for h in hosts[:2]: + h.set_location_info("dc1", "rack1") + for h in hosts[2:]: + h.set_location_info("dc1", "rack2") + + policy = policy_specialization(*constructor_args) + policy.populate(Mock(), hosts) + + excluded = set(hosts) + qplan = list(policy.make_query_plan_with_exclusion(excluded=excluded)) + assert qplan == [] + def test_wrong_dc(self, policy_specialization, constructor_args): hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(3)] for h in hosts[:3]: @@ -537,6 +624,88 @@ def test_wrong_dc(self, policy_specialization, constructor_args): qplan = list(policy.make_query_plan()) assert len(qplan) == 0 + def test_runtime_used_hosts_per_remote_dc_change(self, policy_specialization, constructor_args): + """Changing used_hosts_per_remote_dc at runtime should take effect + immediately without needing a repopulate or topology event.""" + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] + for h in hosts[:2]: + h.set_location_info("dc1", "rack1") + for h in hosts[2:]: + h.set_location_info("dc2", "rack1") + + policy = policy_specialization(*constructor_args, used_hosts_per_remote_dc=0) + policy.populate(Mock(), hosts) + + # With 0, remotes are IGNORED and absent from query plan + for h in hosts[2:]: + assert policy.distance(h) == HostDistance.IGNORED + qplan = list(policy.make_query_plan()) + assert set(qplan) == set(hosts[:2]) + + # Raise to 1 at runtime -- should take effect immediately + policy.used_hosts_per_remote_dc = 1 + assert policy.distance(hosts[2]) == HostDistance.REMOTE or \ + policy.distance(hosts[3]) == HostDistance.REMOTE + qplan = list(policy.make_query_plan()) + assert len(qplan) == 3 # 2 local + 1 remote + + # Raise to 2 -- both remotes visible + policy.used_hosts_per_remote_dc = 2 + qplan = list(policy.make_query_plan()) + assert set(qplan) == set(hosts) + + # Drop back to 0 -- remotes disappear + policy.used_hosts_per_remote_dc = 0 + for h in hosts[2:]: + assert policy.distance(h) == HostDistance.IGNORED + qplan = list(policy.make_query_plan()) + assert set(qplan) == set(hosts[:2]) + + def test_modification_during_generation_exclusion(self, policy_specialization, constructor_args): + """Topology changes to remote hosts during local iteration should be + visible when the generator reaches the remote phase, for the exclusion + path as well as the normal path.""" + hosts = [Host(DefaultEndPoint(i), 
SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] + for h in hosts[:2]: + h.set_location_info("dc1", "rack1") + for h in hosts[2:]: + h.set_location_info("dc2", "rack1") + + policy = policy_specialization(*constructor_args, used_hosts_per_remote_dc=3) + policy.populate(Mock(), hosts) + + new_host = Host(DefaultEndPoint(4), SimpleConvictionPolicy, host_id=uuid.uuid4()) + new_host.set_location_info("dc2", "rack1") + + # -- make_query_plan: add remote after starting local iteration -- + plan = policy.make_query_plan() + next(plan) # consume one local + policy.on_up(new_host) + remaining = list(plan) + # new_host should appear because _remote_hosts is read late + assert new_host in remaining + + # -- make_query_plan: remove remote after starting local -- + plan = policy.make_query_plan() + next(plan) + policy.on_down(new_host) + remaining = list(plan) + assert new_host not in remaining + + # -- make_query_plan_with_exclusion: add remote after starting local -- + plan = policy.make_query_plan_with_exclusion(excluded={hosts[0]}) + next(plan) # consume one local + policy.on_up(new_host) + remaining = list(plan) + assert new_host in remaining + + # -- make_query_plan_with_exclusion: remove remote after starting local -- + plan = policy.make_query_plan_with_exclusion(excluded={hosts[0]}) + next(plan) + policy.on_down(new_host) + remaining = list(plan) + assert new_host not in remaining + class DCAwareRoundRobinPolicyTest(unittest.TestCase): def test_default_dc(self): @@ -577,6 +746,8 @@ def test_wrap_round_robin(self): cluster.metadata = Mock(spec=Metadata) cluster.metadata._tablets = Mock(spec=Tablets) cluster.metadata._tablets.get_tablet_for_key.return_value = None + cluster.metadata.token_map = Mock() + cluster.metadata.token_map.token_class.from_key.side_effect = lambda key: key hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for host in hosts: host.set_up() @@ -586,8 +757,9 @@ def get_replicas(keyspace, packed_key): return list(islice(cycle(hosts), index, index + 2)) cluster.metadata.get_replicas.side_effect = get_replicas + cluster.metadata.token_map.get_replicas.side_effect = cluster.metadata.get_replicas - policy = TokenAwarePolicy(RoundRobinPolicy()) + policy = TokenAwarePolicy(RoundRobinPolicy(), shuffle_replicas=False) policy.populate(cluster, hosts) for i in range(4): @@ -610,6 +782,8 @@ def test_wrap_dc_aware(self): cluster.metadata = Mock(spec=Metadata) cluster.metadata._tablets = Mock(spec=Tablets) cluster.metadata._tablets.get_tablet_for_key.return_value = None + cluster.metadata.token_map = Mock() + cluster.metadata.token_map.token_class.from_key.side_effect = lambda key: key hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for host in hosts: host.set_up() @@ -627,8 +801,9 @@ def get_replicas(keyspace, packed_key): return [hosts[1], hosts[3]] cluster.metadata.get_replicas.side_effect = get_replicas + cluster.metadata.token_map.get_replicas.side_effect = cluster.metadata.get_replicas - policy = TokenAwarePolicy(DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=2)) + policy = TokenAwarePolicy(DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=2), shuffle_replicas=False) policy.populate(cluster, hosts) for i in range(4): @@ -659,6 +834,8 @@ def test_wrap_rack_aware(self): cluster.metadata = Mock(spec=Metadata) cluster.metadata._tablets = Mock(spec=Tablets) cluster.metadata._tablets.get_tablet_for_key.return_value = None + cluster.metadata.token_map = 
Mock() + cluster.metadata.token_map.token_class.from_key.side_effect = lambda key: key hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(8)] for host in hosts: host.set_up() @@ -680,8 +857,9 @@ def get_replicas(keyspace, packed_key): return [hosts[4], hosts[5], hosts[6], hosts[7]] cluster.metadata.get_replicas.side_effect = get_replicas + cluster.metadata.token_map.get_replicas.side_effect = cluster.metadata.get_replicas - policy = TokenAwarePolicy(RackAwareRoundRobinPolicy("dc1", "rack1", used_hosts_per_remote_dc=4)) + policy = TokenAwarePolicy(RackAwareRoundRobinPolicy("dc1", "rack1", used_hosts_per_remote_dc=4), shuffle_replicas=False) policy.populate(cluster, hosts) for i in range(4): @@ -804,12 +982,16 @@ def test_statement_keyspace(self): replicas = hosts[2:] cluster.metadata.get_replicas.return_value = replicas cluster.metadata._tablets.get_tablet_for_key.return_value = None + cluster.metadata.token_map = Mock() + cluster.metadata.token_map.token_class.from_key.side_effect = lambda key: key + cluster.metadata.token_map.get_replicas.side_effect = cluster.metadata.get_replicas child_policy = Mock() child_policy.make_query_plan.return_value = hosts + child_policy.make_query_plan_with_exclusion.side_effect = lambda k, q, e: [h for h in hosts if h not in e] child_policy.distance.return_value = HostDistance.LOCAL - policy = TokenAwarePolicy(child_policy) + policy = TokenAwarePolicy(child_policy, shuffle_replicas=False) policy.populate(cluster, hosts) # no keyspace, child policy is called @@ -848,7 +1030,9 @@ def test_statement_keyspace(self): query = Statement(routing_key=routing_key, keyspace=statement_keyspace) qplan = list(policy.make_query_plan(working_keyspace, query)) assert replicas + hosts[:2] == qplan - cluster.metadata.get_replicas.assert_called_with(statement_keyspace, routing_key) + # get_replicas may not be called here due to cache hit from the + # previous query with the same (statement_keyspace, routing_key) pair. + # The important assertion is that the plan result is correct above. 
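
For context when reading the mocks in the surrounding tests: they emulate a yield-replicas-first, then delegate-with-exclusion flow. The sketch below is illustrative only and is not code from this diff; `plan_with_exclusion_sketch` is a hypothetical standalone helper showing the contract that the repeated `lambda k, q, e: [h for h in hosts if h not in e]` side effects stand in for (assuming the positional `(keyspace, query, excluded)` signature used elsewhere in this diff).

# Illustrative sketch only; not an API added by this diff.
from cassandra.policies import HostDistance

def plan_with_exclusion_sketch(child_policy, replicas, keyspace=None, query=None):
    """Yield local, up replicas first, then defer to the child policy,
    excluding the replicas that were already yielded."""
    yielded = set()
    for replica in replicas:
        if replica.is_up and child_policy.distance(replica) == HostDistance.LOCAL:
            yielded.add(replica)
            yield replica
    # Remaining hosts come from the child policy, minus what was already produced.
    yield from child_policy.make_query_plan_with_exclusion(keyspace, query, yielded)

The Mock child policies below implement exactly this contract: `make_query_plan_with_exclusion` receives the set of already-yielded replicas and returns only the hosts not in that set.
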
def test_shuffles_if_given_keyspace_and_routing_key(self): """ @@ -897,6 +1081,9 @@ def _prepare_cluster_with_vnodes(self): cluster.metadata.all_hosts.return_value = hosts cluster.metadata.get_replicas.return_value = hosts[2:] cluster.metadata._tablets.get_tablet_for_key.return_value = None + cluster.metadata.token_map = Mock() + cluster.metadata.token_map.token_class.from_key.side_effect = lambda key: key + cluster.metadata.token_map.get_replicas.side_effect = cluster.metadata.get_replicas return cluster def _prepare_cluster_with_tablets(self): @@ -909,14 +1096,22 @@ def _prepare_cluster_with_tablets(self): cluster.metadata.all_hosts.return_value = hosts cluster.metadata.get_replicas.return_value = hosts[2:] cluster.metadata._tablets.get_tablet_for_key.return_value = Tablet(replicas=[(h.host_id, 0) for h in hosts[2:]]) + cluster.metadata.token_map = Mock() + cluster.metadata.token_map.token_class.from_key.side_effect = lambda key: key + cluster.metadata.token_map.get_replicas.side_effect = cluster.metadata.get_replicas return cluster @patch('cassandra.policies.shuffle') def _assert_shuffle(self, patched_shuffle, cluster, keyspace, routing_key): hosts = cluster.metadata.all_hosts() - replicas = cluster.metadata.get_replicas() + # Configure get_host_by_host_id to return hosts from the list + host_map = {h.host_id: h for h in hosts} + cluster.metadata.get_host_by_host_id.side_effect = lambda hid: host_map.get(hid) + + replicas = list(cluster.metadata.get_replicas()) child_policy = Mock() child_policy.make_query_plan.return_value = hosts + child_policy.make_query_plan_with_exclusion.side_effect = lambda k, q, e: [h for h in hosts if h not in e] child_policy.distance.return_value = HostDistance.LOCAL policy = TokenAwarePolicy(child_policy, shuffle_replicas=True) @@ -926,6 +1121,7 @@ def _assert_shuffle(self, patched_shuffle, cluster, keyspace, routing_key): cluster.metadata.get_replicas.reset_mock() child_policy.make_query_plan.reset_mock() + child_policy.make_query_plan_with_exclusion.reset_mock() query = Statement(routing_key=routing_key) qplan = list(policy.make_query_plan(keyspace, query)) if keyspace is None or routing_key is None: @@ -936,13 +1132,323 @@ def _assert_shuffle(self, patched_shuffle, cluster, keyspace, routing_key): else: assert set(replicas) == set(qplan[:2]) assert hosts[:2] == qplan[2:] + if is_tablets: + # Tablet path: make_query_plan called once for replica ordering, + # make_query_plan_with_exclusion called once for remaining hosts child_policy.make_query_plan.assert_called_with(keyspace, query) - assert child_policy.make_query_plan.call_count == 2 + child_policy.make_query_plan_with_exclusion.assert_called() + elif child_policy.make_query_plan_with_exclusion.called: + # Non-tablet path with replicas: exclusion set should contain + # the replicas that were already yielded + exc_call = child_policy.make_query_plan_with_exclusion.call_args + excluded_hosts = exc_call[0][2] if len(exc_call[0]) > 2 else exc_call[1].get('excluded', set()) + assert set(replicas).issubset(excluded_hosts), \ + 'Exclusion set should contain the yielded replicas, ' \ + 'got %s, expected superset of %s' % (excluded_hosts, set(replicas)) else: child_policy.make_query_plan.assert_called_once_with(keyspace, query) assert patched_shuffle.call_count == 1 + # --- Replica cache tests --- + + def _make_cache_cluster(self): + """Create a mock cluster suitable for cache tests.""" + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] + for host in hosts: + 
host.set_up() + cluster = Mock(spec=Cluster) + cluster.metadata = Mock(spec=Metadata) + cluster.metadata._tablets = Mock(spec=Tablets) + cluster.metadata._tablets.get_tablet_for_key.return_value = None + cluster.metadata.token_map = Mock() + cluster.metadata.token_map.token_class.from_key.side_effect = lambda key: key + cluster.metadata.token_map.get_replicas.return_value = hosts[2:] + # Provide a real dict for keyspace-aware cache invalidation checks. + cluster.metadata.token_map.tokens_to_hosts_by_ks = {'ks': {}, 'ks1': {}, 'ks2': {}} + return cluster, hosts + + def test_cache_hit(self): + """Same (keyspace, routing_key) should only call get_replicas once.""" + cluster, hosts = self._make_cache_cluster() + + child_policy = Mock() + child_policy.make_query_plan.return_value = hosts + child_policy.make_query_plan_with_exclusion.side_effect = lambda k, q, e: [h for h in hosts if h not in e] + child_policy.distance.return_value = HostDistance.LOCAL + + policy = TokenAwarePolicy(child_policy, shuffle_replicas=False) + policy.populate(cluster, hosts) + + query = Statement(routing_key=b'key1', keyspace='ks') + list(policy.make_query_plan(None, query)) + list(policy.make_query_plan(None, query)) + + assert cluster.metadata.token_map.get_replicas.call_count == 1 + + def test_cache_miss_different_key(self): + """Different routing_key should cause separate get_replicas calls.""" + cluster, hosts = self._make_cache_cluster() + + child_policy = Mock() + child_policy.make_query_plan.return_value = hosts + child_policy.make_query_plan_with_exclusion.side_effect = lambda k, q, e: [h for h in hosts if h not in e] + child_policy.distance.return_value = HostDistance.LOCAL + + policy = TokenAwarePolicy(child_policy, shuffle_replicas=False) + policy.populate(cluster, hosts) + + q1 = Statement(routing_key=b'key1', keyspace='ks') + q2 = Statement(routing_key=b'key2', keyspace='ks') + list(policy.make_query_plan(None, q1)) + list(policy.make_query_plan(None, q2)) + + assert cluster.metadata.token_map.get_replicas.call_count == 2 + + def test_cache_miss_different_keyspace(self): + """Different keyspace with same routing_key should miss cache.""" + cluster, hosts = self._make_cache_cluster() + + child_policy = Mock() + child_policy.make_query_plan.return_value = hosts + child_policy.make_query_plan_with_exclusion.side_effect = lambda k, q, e: [h for h in hosts if h not in e] + child_policy.distance.return_value = HostDistance.LOCAL + + policy = TokenAwarePolicy(child_policy, shuffle_replicas=False) + policy.populate(cluster, hosts) + + q1 = Statement(routing_key=b'key1', keyspace='ks1') + q2 = Statement(routing_key=b'key1', keyspace='ks2') + list(policy.make_query_plan(None, q1)) + list(policy.make_query_plan(None, q2)) + + assert cluster.metadata.token_map.get_replicas.call_count == 2 + + def test_cache_invalidation_on_topology_change(self): + """Cache should be invalidated when token_map object changes.""" + cluster, hosts = self._make_cache_cluster() + + child_policy = Mock() + child_policy.make_query_plan.return_value = hosts + child_policy.make_query_plan_with_exclusion.side_effect = lambda k, q, e: [h for h in hosts if h not in e] + child_policy.distance.return_value = HostDistance.LOCAL + + policy = TokenAwarePolicy(child_policy, shuffle_replicas=False) + policy.populate(cluster, hosts) + + query = Statement(routing_key=b'key1', keyspace='ks') + list(policy.make_query_plan(None, query)) + assert cluster.metadata.token_map.get_replicas.call_count == 1 + + # Simulate topology change: replace token_map with a 
new mock object + new_token_map = Mock() + new_token_map.token_class.from_key.side_effect = lambda key: key + new_token_map.get_replicas.return_value = hosts[2:] + new_token_map.tokens_to_hosts_by_ks = {'ks': {}} + cluster.metadata.token_map = new_token_map + + list(policy.make_query_plan(None, query)) + # The old token_map still has 1 call; new one should have 1 call + assert new_token_map.get_replicas.call_count == 1 + + def test_cache_eviction(self): + """Oldest entries should be evicted when cache exceeds size.""" + cluster, hosts = self._make_cache_cluster() + + child_policy = Mock() + child_policy.make_query_plan.return_value = hosts + child_policy.make_query_plan_with_exclusion.side_effect = lambda k, q, e: [h for h in hosts if h not in e] + child_policy.distance.return_value = HostDistance.LOCAL + + policy = TokenAwarePolicy(child_policy, shuffle_replicas=False, cache_replicas_size=2) + policy.populate(cluster, hosts) + + # Fill cache with 3 entries; size=2 so first should be evicted + for i in range(3): + q = Statement(routing_key=f'key{i}'.encode(), keyspace='ks') + list(policy.make_query_plan(None, q)) + + assert cluster.metadata.token_map.get_replicas.call_count == 3 + + # key2 (most recent) should be cached + cluster.metadata.token_map.get_replicas.reset_mock() + q = Statement(routing_key=b'key2', keyspace='ks') + list(policy.make_query_plan(None, q)) + assert cluster.metadata.token_map.get_replicas.call_count == 0 + + # key0 (evicted) should miss + q = Statement(routing_key=b'key0', keyspace='ks') + list(policy.make_query_plan(None, q)) + assert cluster.metadata.token_map.get_replicas.call_count == 1 + + def test_cache_disabled(self): + """cache_replicas_size=0 should bypass caching entirely.""" + cluster, hosts = self._make_cache_cluster() + + child_policy = Mock() + child_policy.make_query_plan.return_value = hosts + child_policy.make_query_plan_with_exclusion.side_effect = lambda k, q, e: [h for h in hosts if h not in e] + child_policy.distance.return_value = HostDistance.LOCAL + + policy = TokenAwarePolicy(child_policy, shuffle_replicas=False, cache_replicas_size=0) + policy.populate(cluster, hosts) + + query = Statement(routing_key=b'key1', keyspace='ks') + list(policy.make_query_plan(None, query)) + list(policy.make_query_plan(None, query)) + list(policy.make_query_plan(None, query)) + + # Every call should reach get_replicas + assert cluster.metadata.token_map.get_replicas.call_count == 3 + + def test_tablet_path_not_cached(self): + """Tablet path should bypass the cache entirely.""" + hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] + for host in hosts: + host.set_up() + + cluster = Mock(spec=Cluster) + cluster.metadata = Mock(spec=Metadata) + cluster.metadata._tablets = Mock(spec=Tablets) + cluster.metadata._tablets.get_tablet_for_key.return_value = Tablet(replicas=[(h.host_id, 0) for h in hosts[2:]]) + cluster.metadata.token_map = Mock() + cluster.metadata.token_map.token_class.from_key.side_effect = lambda key: key + cluster.metadata.token_map.get_replicas.return_value = hosts[2:] + cluster.metadata.token_map.tokens_to_hosts_by_ks = {'ks': {}} + + child_policy = Mock() + child_policy.make_query_plan.return_value = hosts + child_policy.make_query_plan_with_exclusion.side_effect = lambda k, q, e: [h for h in hosts if h not in e] + child_policy.distance.return_value = HostDistance.LOCAL + + policy = TokenAwarePolicy(child_policy, shuffle_replicas=False) + policy.populate(cluster, hosts) + + query = 
Statement(routing_key=b'key1', keyspace='ks') + list(policy.make_query_plan(None, query)) + list(policy.make_query_plan(None, query)) + + # token_map.get_replicas should NOT be called (tablet path used) + assert cluster.metadata.token_map.get_replicas.call_count == 0 + # Cache should remain empty (tablet results are not cached) + assert len(policy._replica_cache) == 0 + + def test_cache_invalidation_on_keyspace_replication_change(self): + """Cache should detect in-place keyspace replica map rebuild (e.g. ALTER KEYSPACE).""" + cluster, hosts = self._make_cache_cluster() + + child_policy = Mock() + child_policy.make_query_plan.return_value = hosts + child_policy.make_query_plan_with_exclusion.side_effect = lambda k, q, e: [h for h in hosts if h not in e] + child_policy.distance.return_value = HostDistance.LOCAL + + policy = TokenAwarePolicy(child_policy, shuffle_replicas=False) + policy.populate(cluster, hosts) + + query = Statement(routing_key=b'key1', keyspace='ks') + list(policy.make_query_plan(None, query)) + assert cluster.metadata.token_map.get_replicas.call_count == 1 + + # Simulate ALTER KEYSPACE: same token_map object, but the per-keyspace + # replica map is replaced in-place (new dict object for that keyspace). + cluster.metadata.token_map.tokens_to_hosts_by_ks['ks'] = {'new': 'map'} + + list(policy.make_query_plan(None, query)) + # Should have re-fetched replicas because the ks map id changed. + assert cluster.metadata.token_map.get_replicas.call_count == 2 + + # --- LWT determinism tests --- + + def _make_lwt_query(self, routing_key, keyspace='ks'): + """Create a Statement that reports is_lwt()=True.""" + query = Statement(routing_key=routing_key, keyspace=keyspace) + query.is_lwt = lambda: True + return query + + @patch('cassandra.policies.shuffle') + def test_lwt_no_shuffle(self, patched_shuffle): + """LWT queries should yield replicas in deterministic order.""" + cluster, hosts = self._make_cache_cluster() + + child_policy = Mock() + child_policy.make_query_plan.return_value = hosts + child_policy.make_query_plan_with_exclusion.side_effect = lambda k, q, e: [h for h in hosts if h not in e] + child_policy.distance.return_value = HostDistance.LOCAL + + policy = TokenAwarePolicy(child_policy, shuffle_replicas=True) + policy.populate(cluster, hosts) + + query = self._make_lwt_query(routing_key=b'key1') + + plans = [list(policy.make_query_plan(None, query)) for _ in range(5)] + + # All plans should be identical (deterministic) + for plan in plans[1:]: + assert plan == plans[0] + + # shuffle should never have been called + assert patched_shuffle.call_count == 0 + + @patch('cassandra.policies.shuffle') + def test_lwt_replicas_not_copied(self, patched_shuffle): + """LWT path should not copy the replicas list (no list() call).""" + cluster, hosts = self._make_cache_cluster() + + child_policy = Mock() + child_policy.make_query_plan.return_value = hosts + child_policy.make_query_plan_with_exclusion.side_effect = lambda k, q, e: [h for h in hosts if h not in e] + child_policy.distance.return_value = HostDistance.LOCAL + + policy = TokenAwarePolicy(child_policy, shuffle_replicas=True) + policy.populate(cluster, hosts) + + query = self._make_lwt_query(routing_key=b'key1') + list(policy.make_query_plan(None, query)) + + # shuffle was never called, which means list() was also not called + assert patched_shuffle.call_count == 0 + + @patch('cassandra.policies.shuffle') + def test_non_lwt_shuffled(self, patched_shuffle): + """Non-LWT queries with shuffle_replicas=True should shuffle.""" + 
cluster, hosts = self._make_cache_cluster() + + child_policy = Mock() + child_policy.make_query_plan.return_value = hosts + child_policy.make_query_plan_with_exclusion.side_effect = lambda k, q, e: [h for h in hosts if h not in e] + child_policy.distance.return_value = HostDistance.LOCAL + + policy = TokenAwarePolicy(child_policy, shuffle_replicas=True) + policy.populate(cluster, hosts) + + query = Statement(routing_key=b'key1', keyspace='ks') + list(policy.make_query_plan(None, query)) + + assert patched_shuffle.call_count == 1 + + @patch('cassandra.policies.shuffle') + def test_lwt_with_cache_deterministic(self, patched_shuffle): + """LWT + cache should produce identical plans on repeated calls.""" + cluster, hosts = self._make_cache_cluster() + + child_policy = Mock() + child_policy.make_query_plan.return_value = hosts + child_policy.make_query_plan_with_exclusion.side_effect = lambda k, q, e: [h for h in hosts if h not in e] + child_policy.distance.return_value = HostDistance.LOCAL + + policy = TokenAwarePolicy(child_policy, shuffle_replicas=True) + policy.populate(cluster, hosts) + + query = self._make_lwt_query(routing_key=b'key1') + + plan1 = list(policy.make_query_plan(None, query)) + plan2 = list(policy.make_query_plan(None, query)) + + assert plan1 == plan2 + assert patched_shuffle.call_count == 0 + # Should have been a cache hit on the second call + assert cluster.metadata.token_map.get_replicas.call_count == 1 + class ConvictionPolicyTest(unittest.TestCase): def test_not_implemented(self): @@ -1630,8 +2136,11 @@ def get_replicas(keyspace, packed_key): cluster.metadata.get_replicas.side_effect = get_replicas cluster.metadata._tablets = Mock(spec=Tablets) cluster.metadata._tablets.get_tablet_for_key.return_value = None + cluster.metadata.token_map = Mock() + cluster.metadata.token_map.token_class.from_key.side_effect = lambda key: key + cluster.metadata.token_map.get_replicas.side_effect = cluster.metadata.get_replicas - child_policy = TokenAwarePolicy(RoundRobinPolicy()) + child_policy = TokenAwarePolicy(RoundRobinPolicy(), shuffle_replicas=False) hfp = HostFilterPolicy( child_policy=child_policy,