From aeee33981839eca5aaf027d4f9080e11f62ba942 Mon Sep 17 00:00:00 2001 From: liwang Date: Wed, 28 Jan 2026 10:23:21 -0800 Subject: [PATCH] ZOOKEEPER-5015: Add admin server command to shed client connections by percentage Author: Li Wang --- .../main/resources/markdown/zookeeperAdmin.md | 5 + .../apache/zookeeper/server/ServerCnxn.java | 3 +- .../zookeeper/server/ServerCnxnFactory.java | 51 ++++ .../zookeeper/server/admin/Commands.java | 75 +++++ .../server/ServerCnxnFactoryTest.java | 156 +++++++++++ .../zookeeper/server/admin/CommandsTest.java | 10 + .../admin/ShedConnectionsCommandTest.java | 265 ++++++++++++++++++ 7 files changed, 564 insertions(+), 1 deletion(-) create mode 100644 zookeeper-server/src/test/java/org/apache/zookeeper/server/ServerCnxnFactoryTest.java create mode 100644 zookeeper-server/src/test/java/org/apache/zookeeper/server/admin/ShedConnectionsCommandTest.java diff --git a/zookeeper-docs/src/main/resources/markdown/zookeeperAdmin.md b/zookeeper-docs/src/main/resources/markdown/zookeeperAdmin.md index c1e1684aea7..808448d1034 100644 --- a/zookeeper-docs/src/main/resources/markdown/zookeeperAdmin.md +++ b/zookeeper-docs/src/main/resources/markdown/zookeeperAdmin.md @@ -2771,6 +2771,11 @@ Available commands include: Server information. Returns multiple fields giving a brief overview of server state. +* *shed_connections/shed* : + Attempts to shed approximately the specified percentage of connections. + Requires "percentage": (int) + Returns "connections_shed" (int) and "percentage_requested" (int) + * *snapshot/snap* : Takes a snapshot of the current server in the datadir and stream out data. Optional query parameter: diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/ServerCnxn.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/ServerCnxn.java index ebfd32afa4a..eb31b5e9254 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/ServerCnxn.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/ServerCnxn.java @@ -98,7 +98,8 @@ public enum DisconnectReason { AUTH_PROVIDER_NOT_FOUND("auth provider not found"), FAILED_HANDSHAKE("Unsuccessful handshake"), CLIENT_RATE_LIMIT("Client hits rate limiting threshold"), - CLIENT_CNX_LIMIT("Client hits connection limiting threshold"); + CLIENT_CNX_LIMIT("Client hits connection limiting threshold"), + SHED_CONNECTIONS_COMMAND("shed_connections_command"); String disconnectReason; diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/ServerCnxnFactory.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/ServerCnxnFactory.java index 85ad981e6ca..f63c1eec4ba 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/ServerCnxnFactory.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/ServerCnxnFactory.java @@ -26,6 +26,7 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ThreadLocalRandom; import java.util.function.Supplier; import javax.management.JMException; import javax.security.auth.callback.CallbackHandler; @@ -160,6 +161,56 @@ public final void setZooKeeperServer(ZooKeeperServer zks) { public abstract void closeAll(ServerCnxn.DisconnectReason reason); + /** + * Attempts to shed approximately the specified percentage of connections. + * + * @param percentage [0-100] percentage of connections to shed + * @return actual number of connections successfully closed (may vary due to randomness) + * @throws IllegalArgumentException if percentage not in [0, 100] + */ + public int shedConnections(final int percentage) { + if (percentage < 0 || percentage > 100) { + throw new IllegalArgumentException("percentage must be between 0 and 100, got: " + percentage); + } + + final int totalConnections = cnxns.size(); + if (percentage == 0 || totalConnections == 0) { + return 0; + } + + int actualShedCount = 0; + // For 100%, close all connections deterministically + if (percentage == 100) { + for (final ServerCnxn cnxn : cnxns) { + try { + cnxn.close(ServerCnxn.DisconnectReason.SHED_CONNECTIONS_COMMAND); + actualShedCount++; + } catch (final Exception e) { + LOG.warn("Failed to close connection for session 0x{}: {}", + Long.toHexString(cnxn.getSessionId()), e.getMessage()); + } + } + } else { + // For other percentages, use probabilistic approach + final ThreadLocalRandom random = ThreadLocalRandom.current(); + final double probability = percentage / 100.0; + + for (final ServerCnxn cnxn : cnxns) { + if (random.nextDouble() < probability) { + try { + cnxn.close(ServerCnxn.DisconnectReason.SHED_CONNECTIONS_COMMAND); + actualShedCount++; + } catch (final Exception e) { + LOG.warn("Failed to close connection for session 0x{}: {}", + Long.toHexString(cnxn.getSessionId()), e.getMessage()); + } + } + } + } + LOG.info("Shed {} out of {} connections ({}%)", actualShedCount, totalConnections, percentage); + return actualShedCount; + } + public static ServerCnxnFactory createFactory() throws IOException { String serverCnxnFactoryName = System.getProperty(ZOOKEEPER_SERVER_CNXN_FACTORY); if (serverCnxnFactoryName == null) { diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/admin/Commands.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/admin/Commands.java index ae7c4369137..af69008986c 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/admin/Commands.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/admin/Commands.java @@ -20,9 +20,12 @@ import static org.apache.zookeeper.server.persistence.FileSnap.SNAPSHOT_FILE_PREFIX; import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import java.io.File; import java.io.FileInputStream; +import java.io.IOException; import java.io.InputStream; import java.net.InetSocketAddress; import java.nio.charset.StandardCharsets; @@ -292,6 +295,7 @@ public static Command getCommand(String cmdName) { registerCommand(new RestoreCommand()); registerCommand(new RuokCommand()); registerCommand(new SetTraceMaskCommand()); + registerCommand(new ShedConnectionsCommand()); registerCommand(new SnapshotCommand()); registerCommand(new SrvrCommand()); registerCommand(new StatCommand()); @@ -863,6 +867,77 @@ public CommandResponse runGet(ZooKeeperServer zkServer, Map kwar } + /** + * Attempts to shed approximately the specified percentage of connections. + * + * Request: JSON input stream containing the following required field: + * - "percentage": Integer [0-100] - percentage of connections to attempt shedding + * value must be between 0 (no connections) and 100 (all connections). + * + * Response: JSON output stream containing: + * - "connections_shed": Integer - actual number of connections successfully closed + * may vary due to randomness. + * - "percentage_requested": Integer - the percentage that was requested + */ + public static class ShedConnectionsCommand extends PostCommand { + private static final String FIELD_PERCENTAGE = "percentage"; + + public ShedConnectionsCommand() { + super(Arrays.asList("shed_connections", "shed"), true, new AuthRequest(ZooDefs.Perms.ALL, ROOT_PATH)); + } + + @Override + public CommandResponse runPost(final ZooKeeperServer zkServer, final InputStream inputStream) { + final CommandResponse response = initializeResponse(); + + if (inputStream == null) { + response.setStatusCode(HttpServletResponse.SC_BAD_REQUEST); + response.put("error", "Request body is required"); + return response; + } + + try { + final ObjectMapper mapper = new ObjectMapper(); + final JsonNode jsonNode = mapper.readTree(inputStream); + + if (!jsonNode.has(FIELD_PERCENTAGE)) { + response.setStatusCode(HttpServletResponse.SC_BAD_REQUEST); + response.put("error", "Missing required field: " + FIELD_PERCENTAGE); + return response; + } + + final int percentage = jsonNode.get(FIELD_PERCENTAGE).asInt(); + if (percentage < 0 || percentage > 100) { + response.setStatusCode(HttpServletResponse.SC_BAD_REQUEST); + response.put("error", "Percentage must be between 0 and 100"); + return response; + } + + final ServerCnxnFactory factory = zkServer.getServerCnxnFactory(); + final ServerCnxnFactory secureFactory = zkServer.getSecureServerCnxnFactory(); + + int connectionsShed = 0; + if (percentage > 0) { + if (factory != null) { + connectionsShed += factory.shedConnections(percentage); + } + if (secureFactory != null) { + connectionsShed += secureFactory.shedConnections(percentage); + } + } + + response.put("connections_shed", connectionsShed); + response.put("percentage_requested", percentage); + + LOG.info("Shed {} connections ({}%)", connectionsShed, percentage); + } catch (final IOException e) { + response.setStatusCode(HttpServletResponse.SC_BAD_REQUEST); + response.put("error", "Invalid JSON or failed to read request body: " + e.getMessage()); + } + return response; + } + } + /** * Same as SrvrCommand but has extra "connections" entry. */ diff --git a/zookeeper-server/src/test/java/org/apache/zookeeper/server/ServerCnxnFactoryTest.java b/zookeeper-server/src/test/java/org/apache/zookeeper/server/ServerCnxnFactoryTest.java new file mode 100644 index 00000000000..f8fb125ca4e --- /dev/null +++ b/zookeeper-server/src/test/java/org/apache/zookeeper/server/ServerCnxnFactoryTest.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zookeeper.server; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockingDetails; +import java.util.Arrays; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; + +public class ServerCnxnFactoryTest { + public enum FactoryType { + NIO, NETTY + } + + private ServerCnxnFactory factory; + + @AfterEach + public void tearDown() { + if (factory != null) { + try { + factory.shutdown(); + } catch (Exception e) { + // Ignore all shutdown exceptions in tests since factory may not be fully initialized + } + } + } + + @ParameterizedTest + @EnumSource(FactoryType.class) + public void testShedConnections_InvalidPercentage(final FactoryType factoryType) { + factory = createFactory(factoryType); + assertThrows(IllegalArgumentException.class, () -> factory.shedConnections(-1)); + assertThrows(IllegalArgumentException.class, () -> factory.shedConnections(101)); + } + + @ParameterizedTest + @EnumSource(FactoryType.class) + public void testShedConnections_ValidPercentages(final FactoryType factoryType) { + factory = createFactory(factoryType); + + assertEquals(0, factory.shedConnections(0)); + assertEquals(0, factory.shedConnections(50)); + assertEquals(0, factory.shedConnections(100)); + } + + @ParameterizedTest + @EnumSource(FactoryType.class) + public void testShedConnections_DeterministicBehavior(final FactoryType factoryType) { + factory = createFactory(factoryType); + + // Create 4 mock connections for testing deterministic edge cases + final ServerCnxn[] mockCnxns = new ServerCnxn[4]; + for (int i = 0; i < 4; i++) { + mockCnxns[i] = mock(ServerCnxn.class); + factory.cnxns.add(mockCnxns[i]); + } + + // Test 0% shedding - should shed exactly 0 connections (deterministic) + int shedCount = factory.shedConnections(0); + assertEquals(0, shedCount, "0% shedding should shed exactly 0 connections"); + + // Verify no connections were actually closed + int actualClosedCount = countConnectionsShed(mockCnxns); + assertEquals(0, actualClosedCount, "No connections should be closed for 0% shedding"); + + // Test 100% shedding - should shed exactly all connections (deterministic) + shedCount = factory.shedConnections(100); + assertEquals(4, shedCount, "100% shedding should shed exactly all 4 connections"); + + // Verify all connections were actually closed with correct reason + actualClosedCount = countConnectionsShed(mockCnxns); + assertEquals(4, actualClosedCount, "All 4 connections should be closed for 100% shedding"); + } + + @ParameterizedTest + @EnumSource(FactoryType.class) + public void testShedConnections_SmallPercentageRoundsToZero(final FactoryType factoryType) { + factory = createFactory(factoryType); + + // Add single mock connection + final ServerCnxn mockCnxn = mock(ServerCnxn.class); + factory.cnxns.add(mockCnxn); + + // small percentage rounds to 0 + assertEquals(0, factory.shedConnections(1), "1% of 1 connection should round to 0"); + } + + @ParameterizedTest + @EnumSource(FactoryType.class) + public void testShedConnections_ErrorHandling(final FactoryType factoryType) { + factory = createFactory(factoryType); + + // Create mock connections where one will fail to close + final ServerCnxn[] mockCnxns = new ServerCnxn[4]; + for (int i = 0; i < 4; i++) { + mockCnxns[i] = mock(ServerCnxn.class); + factory.cnxns.add(mockCnxns[i]); + } + + // Make the second connection throw an exception when closed + doThrow(new RuntimeException("Connection close failed")) + .when(mockCnxns[1]).close(ServerCnxn.DisconnectReason.SHED_CONNECTIONS_COMMAND); + + // Test 100% shedding to ensure error handling works deterministically + final int shedCount = factory.shedConnections(100); + + // Since one connection throws an exception, only 3 should be successfully closed + assertEquals(3, shedCount, "Should successfully close 3 connections, 1 should fail"); + int actualClosedCount = countConnectionsShed(mockCnxns); + assertEquals(4, actualClosedCount, "All 4 connections should have close() called, even if one throws exception"); + } + + private ServerCnxnFactory createFactory(final FactoryType type) { + switch (type) { + case NIO: + return new NIOServerCnxnFactory(); + case NETTY: + return new NettyServerCnxnFactory(); + default: + throw new IllegalArgumentException("Unknown factory type: " + type); + } + } + + private int countConnectionsShed(final ServerCnxn[] connections) { + return (int) Arrays.stream(connections) + .filter(cnxn -> mockingDetails(cnxn).getInvocations().stream() + .anyMatch(invocation -> + invocation.getMethod().getName().equals("close") + && invocation.getArguments().length == 1 + && invocation.getArguments()[0].equals(ServerCnxn.DisconnectReason.SHED_CONNECTIONS_COMMAND) + )) + .count(); + } +} + diff --git a/zookeeper-server/src/test/java/org/apache/zookeeper/server/admin/CommandsTest.java b/zookeeper-server/src/test/java/org/apache/zookeeper/server/admin/CommandsTest.java index ef8448dd418..ef80e7778e4 100644 --- a/zookeeper-server/src/test/java/org/apache/zookeeper/server/admin/CommandsTest.java +++ b/zookeeper-server/src/test/java/org/apache/zookeeper/server/admin/CommandsTest.java @@ -359,6 +359,16 @@ public void testStatCommandSecureOnly() { assertThat(response.toMap().containsKey("secure_connections"), is(true)); } + @Test + public void testShedConnections() throws IOException, InterruptedException { + final Map kwargs = new HashMap<>(); + final InputStream inputStream = new ByteArrayInputStream("{\"percentage\": 25}".getBytes()); + final String authInfo = CommandAuthTest.buildAuthorizationForDigest(); + testCommand("shed_connections", kwargs, inputStream, authInfo, new HashMap<>(), HttpServletResponse.SC_OK, + new Field("percentage_requested", Integer.class), + new Field("connections_shed", Integer.class)); + } + private void testSnapshot(final boolean streaming) throws IOException, InterruptedException { System.setProperty(ADMIN_SNAPSHOT_ENABLED, "true"); System.setProperty(ADMIN_RATE_LIMITER_INTERVAL, "0"); diff --git a/zookeeper-server/src/test/java/org/apache/zookeeper/server/admin/ShedConnectionsCommandTest.java b/zookeeper-server/src/test/java/org/apache/zookeeper/server/admin/ShedConnectionsCommandTest.java new file mode 100644 index 00000000000..cc02d3cf8bf --- /dev/null +++ b/zookeeper-server/src/test/java/org/apache/zookeeper/server/admin/ShedConnectionsCommandTest.java @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zookeeper.server.admin; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.util.Map; +import javax.servlet.http.HttpServletResponse; +import org.apache.zookeeper.server.ServerCnxnFactory; +import org.apache.zookeeper.server.ZooKeeperServer; +import org.junit.jupiter.api.Test; + +public class ShedConnectionsCommandTest { + + private static final String VALID_JSON_25_PERCENT = "{\"percentage\": 25}"; + private static final String VALID_JSON_100_PERCENT = "{\"percentage\": 100}"; + private static final String VALID_JSON_1_PERCENT = "{\"percentage\": 1}"; + private static final String VALID_JSON_0_PERCENT = "{\"percentage\": 0}"; + + private static final String INVALID_JSON_OVER_100_PERCENT = "{\"percentage\": 101}"; + + private static final String INVALID_JSON_MISSING_FIELD = "{\"other\": 25}"; + private static final String INVALID_JSON_MALFORMED = "{\"percentage\": }"; + private static final String INVALID_JSON_EMPTY = "{}"; + + @Test + public void testValidPercentage25() { + validateSuccessfulShedCommand(25, 50, 30, VALID_JSON_25_PERCENT, true, true); + } + + @Test + public void testValidPercentage100() { + validateSuccessfulShedCommand(100, 20, 10, VALID_JSON_100_PERCENT, true, true); + } + + @Test + public void testValidPercentage1() { + validateSuccessfulShedCommand(1, 100, 0, VALID_JSON_1_PERCENT, true, false); + } + + @Test + public void testValidPercentage0() { + validateSuccessfulShedCommand(0, 100, 50, VALID_JSON_0_PERCENT, false, false); + } + + @Test + public void testInvalidPercentage101() { + validateFailedShedCommand(INVALID_JSON_OVER_100_PERCENT, "Percentage must be between 0 and 100", true); + } + + @Test + public void testInvalidNullInputStream() { + validateFailedShedCommand(null, "Request body is required", true); + } + + @Test + public void testEmptyJson() { + validateFailedShedCommand(INVALID_JSON_EMPTY, "Missing required field: percentage", true); + } + + @Test + public void testMissingPercentageParameter() { + validateFailedShedCommand(INVALID_JSON_MISSING_FIELD, "Missing required field: percentage", true); + } + + @Test + public void testMalformedJson() { + validateFailedShedCommand(INVALID_JSON_MALFORMED, "Invalid JSON or failed to read request body", false); + } + + @Test + public void testOnlyInsecureConnections() { + validateSuccessfulShedCommand(25, 40, 0, VALID_JSON_25_PERCENT, true, false); + } + + @Test + public void testOnlySecureConnections() { + validateSuccessfulShedCommand(25, 0, 60, VALID_JSON_25_PERCENT, false, true); + } + + @Test + public void testNoConnections() { + validateSuccessfulShedCommand(25, 0, 0, VALID_JSON_25_PERCENT, false, false); + } + + @Test + public void testMixedConnections() { + validateSuccessfulShedCommand(25, 30, 20, VALID_JSON_25_PERCENT, true, true); + } + + @Test + public void testCommandNames() { + final Commands.ShedConnectionsCommand command = new Commands.ShedConnectionsCommand(); + assertEquals(2, command.getNames().size()); + assertTrue(command.getNames().contains("shed")); + assertTrue(command.getNames().contains("shed_connections")); + } + + @Test + public void testAuthorizationRequired() { + final Commands.ShedConnectionsCommand command = new Commands.ShedConnectionsCommand(); + final AuthRequest authRequest = command.getAuthRequest(); + + assertNotNull(authRequest); + assertEquals(org.apache.zookeeper.ZooDefs.Perms.ALL, authRequest.getPermission()); + assertEquals(Commands.ROOT_PATH, authRequest.getPath()); + } + + private void validateSuccessfulShedCommand( + final int expectedPercentage, + final int insecureConnections, + final int secureConnections, + final String jsonInput, + final boolean shouldCallInsecureFactory, + final boolean shouldCallSecureFactory) { + + final Commands.ShedConnectionsCommand command = new Commands.ShedConnectionsCommand(); + final ZooKeeperServer zkServer = createMockZooKeeperServer(insecureConnections, secureConnections); + final InputStream inputStream = new ByteArrayInputStream(jsonInput.getBytes()); + final int totalConnections = insecureConnections + secureConnections; + + final CommandResponse response = command.runPost(zkServer, inputStream); + assertSuccessfulResponse(response, expectedPercentage, totalConnections); + assertFactoryCalls(zkServer, expectedPercentage, shouldCallInsecureFactory, shouldCallSecureFactory); + } + + private void validateFailedShedCommand( + final String jsonInput, + final String expectedError, + final boolean exactMatch) { + + final Commands.ShedConnectionsCommand command = new Commands.ShedConnectionsCommand(); + final ZooKeeperServer zkServer = createMockZooKeeperServer(10, 10); + final InputStream inputStream = jsonInput != null ? new ByteArrayInputStream(jsonInput.getBytes()) : null; + + final CommandResponse response = command.runPost(zkServer, inputStream); + + assertNotNull(response); + assertEquals(HttpServletResponse.SC_BAD_REQUEST, response.getStatusCode()); + + final Map result = response.toMap(); + final String actualError = (String) result.get("error"); + + if (exactMatch) { + assertEquals(expectedError, actualError); + } else { + assertTrue(actualError.contains(expectedError), + String.format("Expected error message to contain '%s', but was '%s'", expectedError, actualError)); + } + } + + private void assertSuccessfulResponse( + final CommandResponse response, + final int expectedPercentage, + final int totalConnections) { + + assertNotNull(response); + assertEquals(HttpServletResponse.SC_OK, response.getStatusCode()); + + final Map result = response.toMap(); + assertEquals(expectedPercentage, result.get("percentage_requested")); + + assertTrue(result.containsKey("connections_shed")); + final int actualShed = (Integer) result.get("connections_shed"); + assertTrue(actualShed >= 0, "Shed count should be non-negative"); + assertTrue(actualShed <= totalConnections, "Cannot shed more than total connections"); + + // For 0% and 100%, we can make exact assertions + if (expectedPercentage == 0) { + assertEquals(0, actualShed, "0% should shed exactly 0 connections"); + } else if (expectedPercentage == 100) { + assertEquals(totalConnections, actualShed, "100% should shed all connections"); + } + } + + private void assertFactoryCalls( + final ZooKeeperServer zkServer, + final int percentage, + final boolean shouldCallInsecureFactory, + final boolean shouldCallSecureFactory) { + + final ServerCnxnFactory factory = zkServer.getServerCnxnFactory(); + final ServerCnxnFactory secureFactory = zkServer.getSecureServerCnxnFactory(); + + if (factory != null) { + if (shouldCallInsecureFactory) { + verify(factory, times(1)).shedConnections(percentage); + } else { + verify(factory, never()).shedConnections(anyInt()); + } + } + + if (secureFactory != null) { + if (shouldCallSecureFactory) { + verify(secureFactory, times(1)).shedConnections(percentage); + } else { + verify(secureFactory, never()).shedConnections(anyInt()); + } + } + } + + private ZooKeeperServer createMockZooKeeperServer(int insecureConnections, int secureConnections) { + final ZooKeeperServer zkServer = mock(ZooKeeperServer.class); + final int totalConnections = insecureConnections + secureConnections; + + when(zkServer.getNumAliveConnections()).thenReturn(totalConnections); + + // Mock insecure factory + ServerCnxnFactory factory = null; + if (insecureConnections > 0) { + factory = createMockServerCnxnFactory(insecureConnections); + } + when(zkServer.getServerCnxnFactory()).thenReturn(factory); + + // Mock secure factory + ServerCnxnFactory secureFactory = null; + if (secureConnections > 0) { + secureFactory = createMockServerCnxnFactory(secureConnections); + } + when(zkServer.getSecureServerCnxnFactory()).thenReturn(secureFactory); + + return zkServer; + } + + private ServerCnxnFactory createMockServerCnxnFactory(int connections) { + final ServerCnxnFactory factory = mock(ServerCnxnFactory.class); + when(factory.getNumAliveConnections()).thenReturn(connections); + when(factory.shedConnections(anyInt())).thenAnswer(invocation -> { + int percentage = invocation.getArgument(0); + if (percentage == 0) { + return 0; + } + if (percentage == 100) { + return connections; + } + return (int) Math.ceil(connections * percentage / 100.0); + }); + return factory; + } +}