Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions zookeeper-docs/src/main/resources/markdown/zookeeperAdmin.md
Original file line number Diff line number Diff line change
Expand Up @@ -2771,6 +2771,11 @@ Available commands include:
Server information.
Returns multiple fields giving a brief overview of server state.

* *shed_connections/shed* :
Attempts to shed approximately the specified percentage of connections.
Requires "percentage": (int)
Returns "connections_shed" (int) and "percentage_requested" (int)

* *snapshot/snap* :
Takes a snapshot of the current server in the datadir and stream out data.
Optional query parameter:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ public enum DisconnectReason {
AUTH_PROVIDER_NOT_FOUND("auth provider not found"),
FAILED_HANDSHAKE("Unsuccessful handshake"),
CLIENT_RATE_LIMIT("Client hits rate limiting threshold"),
CLIENT_CNX_LIMIT("Client hits connection limiting threshold");
CLIENT_CNX_LIMIT("Client hits connection limiting threshold"),
SHED_CONNECTIONS_COMMAND("shed_connections_command");

String disconnectReason;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Supplier;
import javax.management.JMException;
import javax.security.auth.callback.CallbackHandler;
Expand Down Expand Up @@ -160,6 +161,56 @@ public final void setZooKeeperServer(ZooKeeperServer zks) {

public abstract void closeAll(ServerCnxn.DisconnectReason reason);

/**
* Attempts to shed approximately the specified percentage of connections.
*
* @param percentage [0-100] percentage of connections to shed
* @return actual number of connections successfully closed (may vary due to randomness)
* @throws IllegalArgumentException if percentage not in [0, 100]
*/
public int shedConnections(final int percentage) {
if (percentage < 0 || percentage > 100) {
throw new IllegalArgumentException("percentage must be between 0 and 100, got: " + percentage);
}

final int totalConnections = cnxns.size();
if (percentage == 0 || totalConnections == 0) {
return 0;
}

int actualShedCount = 0;
// For 100%, close all connections deterministically
if (percentage == 100) {
for (final ServerCnxn cnxn : cnxns) {
try {
cnxn.close(ServerCnxn.DisconnectReason.SHED_CONNECTIONS_COMMAND);
actualShedCount++;
} catch (final Exception e) {
LOG.warn("Failed to close connection for session 0x{}: {}",
Long.toHexString(cnxn.getSessionId()), e.getMessage());
}
}
} else {
// For other percentages, use probabilistic approach
final ThreadLocalRandom random = ThreadLocalRandom.current();
final double probability = percentage / 100.0;

for (final ServerCnxn cnxn : cnxns) {
if (random.nextDouble() < probability) {
try {
cnxn.close(ServerCnxn.DisconnectReason.SHED_CONNECTIONS_COMMAND);
actualShedCount++;
} catch (final Exception e) {
LOG.warn("Failed to close connection for session 0x{}: {}",
Long.toHexString(cnxn.getSessionId()), e.getMessage());
}
}
}
}
LOG.info("Shed {} out of {} connections ({}%)", actualShedCount, totalConnections, percentage);
return actualShedCount;
}

public static ServerCnxnFactory createFactory() throws IOException {
String serverCnxnFactoryName = System.getProperty(ZOOKEEPER_SERVER_CNXN_FACTORY);
if (serverCnxnFactoryName == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@

import static org.apache.zookeeper.server.persistence.FileSnap.SNAPSHOT_FILE_PREFIX;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
Expand Down Expand Up @@ -292,6 +295,7 @@ public static Command getCommand(String cmdName) {
registerCommand(new RestoreCommand());
registerCommand(new RuokCommand());
registerCommand(new SetTraceMaskCommand());
registerCommand(new ShedConnectionsCommand());
registerCommand(new SnapshotCommand());
registerCommand(new SrvrCommand());
registerCommand(new StatCommand());
Expand Down Expand Up @@ -863,6 +867,77 @@ public CommandResponse runGet(ZooKeeperServer zkServer, Map<String, String> kwar

}

/**
* Attempts to shed approximately the specified percentage of connections.
*
* Request: JSON input stream containing the following required field:
* - "percentage": Integer [0-100] - percentage of connections to attempt shedding
* value must be between 0 (no connections) and 100 (all connections).
*
* Response: JSON output stream containing:
* - "connections_shed": Integer - actual number of connections successfully closed
* may vary due to randomness.
* - "percentage_requested": Integer - the percentage that was requested
*/
public static class ShedConnectionsCommand extends PostCommand {
private static final String FIELD_PERCENTAGE = "percentage";

public ShedConnectionsCommand() {
super(Arrays.asList("shed_connections", "shed"), true, new AuthRequest(ZooDefs.Perms.ALL, ROOT_PATH));
}

@Override
public CommandResponse runPost(final ZooKeeperServer zkServer, final InputStream inputStream) {
final CommandResponse response = initializeResponse();

if (inputStream == null) {
response.setStatusCode(HttpServletResponse.SC_BAD_REQUEST);
response.put("error", "Request body is required");
return response;
}

try {
final ObjectMapper mapper = new ObjectMapper();
final JsonNode jsonNode = mapper.readTree(inputStream);

if (!jsonNode.has(FIELD_PERCENTAGE)) {
response.setStatusCode(HttpServletResponse.SC_BAD_REQUEST);
response.put("error", "Missing required field: " + FIELD_PERCENTAGE);
return response;
}

final int percentage = jsonNode.get(FIELD_PERCENTAGE).asInt();
if (percentage < 0 || percentage > 100) {
response.setStatusCode(HttpServletResponse.SC_BAD_REQUEST);
response.put("error", "Percentage must be between 0 and 100");
return response;
}

final ServerCnxnFactory factory = zkServer.getServerCnxnFactory();
final ServerCnxnFactory secureFactory = zkServer.getSecureServerCnxnFactory();

int connectionsShed = 0;
if (percentage > 0) {
if (factory != null) {
connectionsShed += factory.shedConnections(percentage);
}
if (secureFactory != null) {
connectionsShed += secureFactory.shedConnections(percentage);
}
}

response.put("connections_shed", connectionsShed);
response.put("percentage_requested", percentage);

LOG.info("Shed {} connections ({}%)", connectionsShed, percentage);
} catch (final IOException e) {
response.setStatusCode(HttpServletResponse.SC_BAD_REQUEST);
response.put("error", "Invalid JSON or failed to read request body: " + e.getMessage());
}
return response;
}
}

/**
* Same as SrvrCommand but has extra "connections" entry.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.zookeeper.server;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockingDetails;
import java.util.Arrays;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;

public class ServerCnxnFactoryTest {
public enum FactoryType {
NIO, NETTY
}

private ServerCnxnFactory factory;

@AfterEach
public void tearDown() {
if (factory != null) {
try {
factory.shutdown();
} catch (Exception e) {
// Ignore all shutdown exceptions in tests since factory may not be fully initialized
}
}
}

@ParameterizedTest
@EnumSource(FactoryType.class)
public void testShedConnections_InvalidPercentage(final FactoryType factoryType) {
factory = createFactory(factoryType);
assertThrows(IllegalArgumentException.class, () -> factory.shedConnections(-1));
assertThrows(IllegalArgumentException.class, () -> factory.shedConnections(101));
}

@ParameterizedTest
@EnumSource(FactoryType.class)
public void testShedConnections_ValidPercentages(final FactoryType factoryType) {
factory = createFactory(factoryType);

assertEquals(0, factory.shedConnections(0));
assertEquals(0, factory.shedConnections(50));
assertEquals(0, factory.shedConnections(100));
}

@ParameterizedTest
@EnumSource(FactoryType.class)
public void testShedConnections_DeterministicBehavior(final FactoryType factoryType) {
factory = createFactory(factoryType);

// Create 4 mock connections for testing deterministic edge cases
final ServerCnxn[] mockCnxns = new ServerCnxn[4];
for (int i = 0; i < 4; i++) {
mockCnxns[i] = mock(ServerCnxn.class);
factory.cnxns.add(mockCnxns[i]);
}

// Test 0% shedding - should shed exactly 0 connections (deterministic)
int shedCount = factory.shedConnections(0);
assertEquals(0, shedCount, "0% shedding should shed exactly 0 connections");

// Verify no connections were actually closed
int actualClosedCount = countConnectionsShed(mockCnxns);
assertEquals(0, actualClosedCount, "No connections should be closed for 0% shedding");

// Test 100% shedding - should shed exactly all connections (deterministic)
shedCount = factory.shedConnections(100);
assertEquals(4, shedCount, "100% shedding should shed exactly all 4 connections");

// Verify all connections were actually closed with correct reason
actualClosedCount = countConnectionsShed(mockCnxns);
assertEquals(4, actualClosedCount, "All 4 connections should be closed for 100% shedding");
}

@ParameterizedTest
@EnumSource(FactoryType.class)
public void testShedConnections_SmallPercentageRoundsToZero(final FactoryType factoryType) {
factory = createFactory(factoryType);

// Add single mock connection
final ServerCnxn mockCnxn = mock(ServerCnxn.class);
factory.cnxns.add(mockCnxn);

// small percentage rounds to 0
assertEquals(0, factory.shedConnections(1), "1% of 1 connection should round to 0");
}

@ParameterizedTest
@EnumSource(FactoryType.class)
public void testShedConnections_ErrorHandling(final FactoryType factoryType) {
factory = createFactory(factoryType);

// Create mock connections where one will fail to close
final ServerCnxn[] mockCnxns = new ServerCnxn[4];
for (int i = 0; i < 4; i++) {
mockCnxns[i] = mock(ServerCnxn.class);
factory.cnxns.add(mockCnxns[i]);
}

// Make the second connection throw an exception when closed
doThrow(new RuntimeException("Connection close failed"))
.when(mockCnxns[1]).close(ServerCnxn.DisconnectReason.SHED_CONNECTIONS_COMMAND);

// Test 100% shedding to ensure error handling works deterministically
final int shedCount = factory.shedConnections(100);

// Since one connection throws an exception, only 3 should be successfully closed
assertEquals(3, shedCount, "Should successfully close 3 connections, 1 should fail");
int actualClosedCount = countConnectionsShed(mockCnxns);
assertEquals(4, actualClosedCount, "All 4 connections should have close() called, even if one throws exception");
}

private ServerCnxnFactory createFactory(final FactoryType type) {
switch (type) {
case NIO:
return new NIOServerCnxnFactory();
case NETTY:
return new NettyServerCnxnFactory();
default:
throw new IllegalArgumentException("Unknown factory type: " + type);
}
}

private int countConnectionsShed(final ServerCnxn[] connections) {
return (int) Arrays.stream(connections)
.filter(cnxn -> mockingDetails(cnxn).getInvocations().stream()
.anyMatch(invocation ->
invocation.getMethod().getName().equals("close")
&& invocation.getArguments().length == 1
&& invocation.getArguments()[0].equals(ServerCnxn.DisconnectReason.SHED_CONNECTIONS_COMMAND)
))
.count();
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,16 @@ public void testStatCommandSecureOnly() {
assertThat(response.toMap().containsKey("secure_connections"), is(true));
}

@Test
public void testShedConnections() throws IOException, InterruptedException {
final Map<String, String> kwargs = new HashMap<>();
final InputStream inputStream = new ByteArrayInputStream("{\"percentage\": 25}".getBytes());
final String authInfo = CommandAuthTest.buildAuthorizationForDigest();
testCommand("shed_connections", kwargs, inputStream, authInfo, new HashMap<>(), HttpServletResponse.SC_OK,
new Field("percentage_requested", Integer.class),
new Field("connections_shed", Integer.class));
}

private void testSnapshot(final boolean streaming) throws IOException, InterruptedException {
System.setProperty(ADMIN_SNAPSHOT_ENABLED, "true");
System.setProperty(ADMIN_RATE_LIMITER_INTERVAL, "0");
Expand Down
Loading