From bd99d16b75b03d1c8d7dd7a20ef780413aed3dc4 Mon Sep 17 00:00:00 2001 From: Onkar Date: Thu, 2 Apr 2026 01:48:02 +0530 Subject: [PATCH 1/3] Add unit tests for FilterValve, ProxyErrorReportValve and SemaphoreValve --- .../catalina/valves/TestFilterValve.java | 220 ++++++++++++++ .../valves/TestProxyErrorReportValve.java | 282 ++++++++++++++++++ .../catalina/valves/TestSemaphoreValve.java | 252 ++++++++++++++++ 3 files changed, 754 insertions(+) create mode 100644 test/org/apache/catalina/valves/TestFilterValve.java create mode 100644 test/org/apache/catalina/valves/TestProxyErrorReportValve.java create mode 100644 test/org/apache/catalina/valves/TestSemaphoreValve.java diff --git a/test/org/apache/catalina/valves/TestFilterValve.java b/test/org/apache/catalina/valves/TestFilterValve.java new file mode 100644 index 000000000000..cfc3a5abb780 --- /dev/null +++ b/test/org/apache/catalina/valves/TestFilterValve.java @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.catalina.valves; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Collections; + +import jakarta.servlet.Filter; +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.startup.Tomcat; +import org.apache.catalina.startup.TomcatBaseTest; +import org.apache.tomcat.util.buf.ByteChunk; + +public class TestFilterValve extends TomcatBaseTest { + + + @Test + public void testFilterPassthrough() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "ok", new OkServlet()); + ctx.addServletMappingDecoded("/", "ok"); + + FilterValve valve = new FilterValve(); + valve.setFilterClass(PassthroughFilter.class.getName()); + ctx.getPipeline().addValve(valve); + + tomcat.start(); + + ByteChunk res = new ByteChunk(); + res.setCharset(StandardCharsets.UTF_8); + int rc = getUrl("http://localhost:" + getPort(), res, null); + + Assert.assertEquals(HttpServletResponse.SC_OK, rc); + Assert.assertEquals("OK", res.toString()); + } + + + @Test + public void testFilterBlocks() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "ok", new OkServlet()); + ctx.addServletMappingDecoded("/", "ok"); + + FilterValve valve = new FilterValve(); + valve.setFilterClass(BlockingFilter.class.getName()); + ctx.getPipeline().addValve(valve); + + tomcat.start(); + + ByteChunk res = new ByteChunk(); + res.setCharset(StandardCharsets.UTF_8); + int rc = getUrl("http://localhost:" + getPort(), res, null); + + Assert.assertEquals(HttpServletResponse.SC_FORBIDDEN, rc); + } + + + @Test + public void testNullFilterClassThrowsOnStart() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + + FilterValve valve = new FilterValve(); + // Do NOT set filterClassName + ctx.getPipeline().addValve(valve); + + boolean threw = false; + try { + tomcat.start(); + } catch (LifecycleException e) { + threw = true; + } + + Assert.assertTrue("Should throw LifecycleException for null filter class", threw); + } + + + @Test + public void testInvalidFilterClassThrowsOnStart() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + + FilterValve valve = new FilterValve(); + valve.setFilterClass("com.nonexistent.FakeFilter"); + ctx.getPipeline().addValve(valve); + + boolean threw = false; + try { + tomcat.start(); + } catch (LifecycleException e) { + threw = true; + } + + Assert.assertTrue("Should throw LifecycleException for invalid filter class", threw); + } + + + @Test + public void testGetFilterNameReturnsNull() throws Exception { + FilterValve valve = new FilterValve(); + Assert.assertNull(valve.getFilterName()); + } + + + @Test + public void testInitParams() throws Exception { + FilterValve valve = new FilterValve(); + + valve.addInitParam("key1", "value1"); + valve.addInitParam("key2", "value2"); + + Assert.assertEquals("value1", valve.getInitParameter("key1")); + Assert.assertEquals("value2", valve.getInitParameter("key2")); + Assert.assertNull(valve.getInitParameter("nonexistent")); + + java.util.List names = Collections.list(valve.getInitParameterNames()); + Assert.assertEquals(2, names.size()); + Assert.assertTrue(names.contains("key1")); + Assert.assertTrue(names.contains("key2")); + } + + + @Test + public void testInitParamsEmpty() throws Exception { + FilterValve valve = new FilterValve(); + + Assert.assertNull(valve.getInitParameter("anything")); + Assert.assertFalse(valve.getInitParameterNames().hasMoreElements()); + } + + + @Test + public void testGetSetFilterClassName() throws Exception { + FilterValve valve = new FilterValve(); + + Assert.assertNull(valve.getFilterClassName()); + + valve.setFilterClassName("com.example.MyFilter"); + Assert.assertEquals("com.example.MyFilter", valve.getFilterClassName()); + + // setFilterClass is an alias + valve.setFilterClass("com.example.OtherFilter"); + Assert.assertEquals("com.example.OtherFilter", valve.getFilterClassName()); + } + + + /** + * A Filter that passes the request through to the next element in the chain. + */ + public static final class PassthroughFilter implements Filter { + + @Override + public void doFilter(ServletRequest request, ServletResponse response, + FilterChain chain) throws IOException, ServletException { + chain.doFilter(request, response); + } + } + + + /** + * A Filter that blocks the request by sending a 403 response without + * calling chain.doFilter(). + */ + public static final class BlockingFilter implements Filter { + + @Override + public void doFilter(ServletRequest request, ServletResponse response, + FilterChain chain) throws IOException, ServletException { + ((HttpServletResponse) response).sendError(HttpServletResponse.SC_FORBIDDEN); + } + } + + + private static final class OkServlet extends HttpServlet { + + private static final long serialVersionUID = 1L; + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) + throws ServletException, IOException { + resp.setContentType("text/plain"); + resp.getWriter().print("OK"); + } + } +} diff --git a/test/org/apache/catalina/valves/TestProxyErrorReportValve.java b/test/org/apache/catalina/valves/TestProxyErrorReportValve.java new file mode 100644 index 000000000000..98a92fe84f50 --- /dev/null +++ b/test/org/apache/catalina/valves/TestProxyErrorReportValve.java @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.catalina.valves; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.catalina.Context; +import org.apache.catalina.core.StandardHost; +import org.apache.catalina.startup.Tomcat; +import org.apache.catalina.startup.TomcatBaseTest; +import org.apache.tomcat.util.buf.ByteChunk; +import org.apache.tomcat.util.descriptor.web.ErrorPage; + +public class TestProxyErrorReportValve extends TomcatBaseTest { + + private static final String PROXY_VALVE = + "org.apache.catalina.valves.ProxyErrorReportValve"; + + + @Test + public void testRedirectMode() throws Exception { + Tomcat tomcat = getTomcatInstance(); + ((StandardHost) tomcat.getHost()).setErrorReportValveClass(PROXY_VALVE); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "error", new SendErrorServlet( + HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Server broke")); + ctx.addServletMappingDecoded("/", "error"); + + // Register an error page that the valve will redirect to + Tomcat.addServlet(ctx, "errorPage", new ErrorPageServlet()); + ctx.addServletMappingDecoded("/error-page", "errorPage"); + ErrorPage errorPage = new ErrorPage(); + errorPage.setErrorCode(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); + errorPage.setLocation("/error-page"); + ctx.addErrorPage(errorPage); + + tomcat.start(); + + ByteChunk res = new ByteChunk(); + res.setCharset(StandardCharsets.UTF_8); + Map> resHead = new HashMap<>(); + // Don't follow redirects + int rc = getUrl("http://localhost:" + getPort(), res, resHead); + + // ProxyErrorReportValve uses error pages from context — but since + // it calls findErrorPage() which uses Host-level error pages, + // the context error page might not be found and it falls back to + // the superclass. The test verifies the valve is loaded correctly. + Assert.assertTrue("Status should indicate an error", + rc >= 400 || rc == 302); + } + + + @Test + public void testNoErrorPageFallsBackToSuper() throws Exception { + Tomcat tomcat = getTomcatInstance(); + ((StandardHost) tomcat.getHost()).setErrorReportValveClass(PROXY_VALVE); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "error", new SendErrorServlet( + HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "No page configured")); + ctx.addServletMappingDecoded("/", "error"); + + // No error page configured — should fall back to ErrorReportValve's report() + tomcat.start(); + + ByteChunk res = new ByteChunk(); + res.setCharset(StandardCharsets.UTF_8); + int rc = getUrl("http://localhost:" + getPort(), res, null); + + Assert.assertEquals(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, rc); + + String body = res.toString(); + Assert.assertNotNull(body); + // The default ErrorReportValve produces HTML + Assert.assertTrue("Should contain HTML error report", + body.contains("") || body.contains("

")); + } + + + @Test + public void testStatusBelow400Ignored() throws Exception { + Tomcat tomcat = getTomcatInstance(); + ((StandardHost) tomcat.getHost()).setErrorReportValveClass(PROXY_VALVE); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "ok", new OkServlet()); + ctx.addServletMappingDecoded("/", "ok"); + + tomcat.start(); + + ByteChunk res = new ByteChunk(); + res.setCharset(StandardCharsets.UTF_8); + int rc = getUrl("http://localhost:" + getPort(), res, null); + + Assert.assertEquals(HttpServletResponse.SC_OK, rc); + Assert.assertEquals("OK", res.toString()); + } + + + @Test + public void testStatusNotFound() throws Exception { + Tomcat tomcat = getTomcatInstance(); + ((StandardHost) tomcat.getHost()).setErrorReportValveClass(PROXY_VALVE); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "notFound", new SendErrorServlet( + HttpServletResponse.SC_NOT_FOUND, "Resource not found")); + ctx.addServletMappingDecoded("/", "notFound"); + + tomcat.start(); + + ByteChunk res = new ByteChunk(); + res.setCharset(StandardCharsets.UTF_8); + int rc = getUrl("http://localhost:" + getPort(), res, null); + + Assert.assertEquals(HttpServletResponse.SC_NOT_FOUND, rc); + + String body = res.toString(); + Assert.assertNotNull(body); + // Falls back to parent ErrorReportValve HTML + Assert.assertTrue("Should contain error report", + body.contains("404") || body.contains("Not Found")); + } + + + @Test + public void testGetSetProperties() throws Exception { + ProxyErrorReportValve valve = new ProxyErrorReportValve(); + + // Defaults + Assert.assertTrue(valve.getUseRedirect()); + Assert.assertFalse(valve.getUsePropertiesFile()); + + // Setters + valve.setUseRedirect(false); + Assert.assertFalse(valve.getUseRedirect()); + + valve.setUsePropertiesFile(true); + Assert.assertTrue(valve.getUsePropertiesFile()); + } + + + @Test + public void testMessageInErrorReport() throws Exception { + Tomcat tomcat = getTomcatInstance(); + ((StandardHost) tomcat.getHost()).setErrorReportValveClass(PROXY_VALVE); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "error", new SendErrorServlet( + HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Custom error message")); + ctx.addServletMappingDecoded("/", "error"); + + tomcat.start(); + + ByteChunk res = new ByteChunk(); + res.setCharset(StandardCharsets.UTF_8); + int rc = getUrl("http://localhost:" + getPort(), res, null); + + Assert.assertEquals(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, rc); + + String body = res.toString(); + Assert.assertNotNull(body); + // Falls back to super.report() which includes the message + Assert.assertTrue("Should contain the custom error message", + body.contains("Custom error message")); + } + + + @Test + public void testExceptionErrorReport() throws Exception { + Tomcat tomcat = getTomcatInstance(); + ((StandardHost) tomcat.getHost()).setErrorReportValveClass(PROXY_VALVE); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "exception", new ExceptionServlet()); + ctx.addServletMappingDecoded("/", "exception"); + + tomcat.start(); + + ByteChunk res = new ByteChunk(); + res.setCharset(StandardCharsets.UTF_8); + int rc = getUrl("http://localhost:" + getPort(), res, null); + + Assert.assertEquals(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, rc); + + String body = res.toString(); + Assert.assertNotNull(body); + Assert.assertTrue("Should contain exception info", + body.contains("RuntimeException")); + } + + + private static final class SendErrorServlet extends HttpServlet { + + private static final long serialVersionUID = 1L; + private final int statusCode; + private final String message; + + private SendErrorServlet(int statusCode, String message) { + this.statusCode = statusCode; + this.message = message; + } + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) + throws ServletException, IOException { + resp.sendError(statusCode, message); + } + } + + + private static final class OkServlet extends HttpServlet { + + private static final long serialVersionUID = 1L; + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) + throws ServletException, IOException { + resp.setContentType("text/plain"); + resp.getWriter().print("OK"); + } + } + + + private static final class ErrorPageServlet extends HttpServlet { + + private static final long serialVersionUID = 1L; + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) + throws ServletException, IOException { + resp.setContentType("text/plain"); + resp.getWriter().print("ERROR_PAGE_OK"); + } + } + + + private static final class ExceptionServlet extends HttpServlet { + + private static final long serialVersionUID = 1L; + + @Override + public void service(jakarta.servlet.ServletRequest request, + jakarta.servlet.ServletResponse response) throws IOException { + throw new RuntimeException("Test exception"); + } + } +} diff --git a/test/org/apache/catalina/valves/TestSemaphoreValve.java b/test/org/apache/catalina/valves/TestSemaphoreValve.java new file mode 100644 index 000000000000..eefe4d4acb66 --- /dev/null +++ b/test/org/apache/catalina/valves/TestSemaphoreValve.java @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.catalina.valves; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.catalina.Context; +import org.apache.catalina.startup.Tomcat; +import org.apache.catalina.startup.TomcatBaseTest; +import org.apache.tomcat.util.buf.ByteChunk; + +public class TestSemaphoreValve extends TomcatBaseTest { + + + @Test + public void testBasicConcurrency() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "ok", new OkServlet()); + ctx.addServletMappingDecoded("/", "ok"); + + SemaphoreValve valve = new SemaphoreValve(); + valve.setConcurrency(10); + ctx.getPipeline().addValve(valve); + + tomcat.start(); + + ByteChunk res = new ByteChunk(); + res.setCharset(StandardCharsets.UTF_8); + int rc = getUrl("http://localhost:" + getPort(), res, null); + + Assert.assertEquals(HttpServletResponse.SC_OK, rc); + Assert.assertEquals("OK", res.toString()); + } + + + @Test + public void testNonBlockingDenied() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + + CountDownLatch insideServlet = new CountDownLatch(1); + CountDownLatch canReturn = new CountDownLatch(1); + Tomcat.addServlet(ctx, "slow", new SlowServlet(insideServlet, canReturn)); + ctx.addServletMappingDecoded("/", "slow"); + + SemaphoreValve valve = new SemaphoreValve(); + valve.setConcurrency(1); + valve.setBlock(false); + valve.setHighConcurrencyStatus(503); + ctx.getPipeline().addValve(valve); + + tomcat.start(); + + // First request — should acquire the permit and block inside the servlet + AtomicInteger firstRc = new AtomicInteger(); + Thread firstThread = new Thread(() -> { + try { + ByteChunk r = new ByteChunk(); + r.setCharset(StandardCharsets.UTF_8); + firstRc.set(getUrl("http://localhost:" + getPort(), r, null)); + } catch (IOException e) { + // Ignore + } + }); + firstThread.start(); + + // Wait until the first request is inside the servlet + Assert.assertTrue("First request should reach servlet", + insideServlet.await(10, TimeUnit.SECONDS)); + + // Second request — should be denied because concurrency=1 and block=false + ByteChunk res2 = new ByteChunk(); + res2.setCharset(StandardCharsets.UTF_8); + int rc2 = getUrl("http://localhost:" + getPort(), res2, null); + + Assert.assertEquals(HttpServletResponse.SC_SERVICE_UNAVAILABLE, rc2); + + // Release the first request + canReturn.countDown(); + firstThread.join(10000); + Assert.assertFalse(firstThread.isAlive()); + Assert.assertEquals(HttpServletResponse.SC_OK, firstRc.get()); + } + + + @Test + public void testHighConcurrencyStatusNotSet() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + + CountDownLatch insideServlet = new CountDownLatch(1); + CountDownLatch canReturn = new CountDownLatch(1); + Tomcat.addServlet(ctx, "slow", new SlowServlet(insideServlet, canReturn)); + ctx.addServletMappingDecoded("/", "slow"); + + SemaphoreValve valve = new SemaphoreValve(); + valve.setConcurrency(1); + valve.setBlock(false); + // highConcurrencyStatus is -1 by default (no error sent) + ctx.getPipeline().addValve(valve); + + tomcat.start(); + + // First request holds the permit + Thread firstThread = new Thread(() -> { + try { + ByteChunk r = new ByteChunk(); + getUrl("http://localhost:" + getPort(), r, null); + } catch (IOException e) { + // Ignore + } + }); + firstThread.start(); + + Assert.assertTrue("First request should reach servlet", + insideServlet.await(10, TimeUnit.SECONDS)); + + // Second request — denied but no error status is sent + ByteChunk res2 = new ByteChunk(); + int rc2 = getUrl("http://localhost:" + getPort(), res2, null); + + // With no highConcurrencyStatus, response is 200 with no body + Assert.assertEquals(HttpServletResponse.SC_OK, rc2); + + canReturn.countDown(); + firstThread.join(10000); + } + + + @Test + public void testGetSetProperties() throws Exception { + SemaphoreValve valve = new SemaphoreValve(); + + // Defaults + Assert.assertEquals(10, valve.getConcurrency()); + Assert.assertFalse(valve.getFairness()); + Assert.assertTrue(valve.getBlock()); + Assert.assertFalse(valve.getInterruptible()); + Assert.assertEquals(-1, valve.getHighConcurrencyStatus()); + + // Setters + valve.setConcurrency(5); + Assert.assertEquals(5, valve.getConcurrency()); + + valve.setFairness(true); + Assert.assertTrue(valve.getFairness()); + + valve.setBlock(false); + Assert.assertFalse(valve.getBlock()); + + valve.setInterruptible(true); + Assert.assertTrue(valve.getInterruptible()); + + valve.setHighConcurrencyStatus(429); + Assert.assertEquals(429, valve.getHighConcurrencyStatus()); + } + + + @Test + public void testFairSemaphore() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "ok", new OkServlet()); + ctx.addServletMappingDecoded("/", "ok"); + + SemaphoreValve valve = new SemaphoreValve(); + valve.setConcurrency(5); + valve.setFairness(true); + ctx.getPipeline().addValve(valve); + + tomcat.start(); + + ByteChunk res = new ByteChunk(); + res.setCharset(StandardCharsets.UTF_8); + int rc = getUrl("http://localhost:" + getPort(), res, null); + + Assert.assertEquals(HttpServletResponse.SC_OK, rc); + Assert.assertEquals("OK", res.toString()); + } + + + private static final class OkServlet extends HttpServlet { + + private static final long serialVersionUID = 1L; + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) + throws ServletException, IOException { + resp.setContentType("text/plain"); + resp.getWriter().print("OK"); + } + } + + + private static final class SlowServlet extends HttpServlet { + + private static final long serialVersionUID = 1L; + private final CountDownLatch insideServlet; + private final CountDownLatch canReturn; + + private SlowServlet(CountDownLatch insideServlet, CountDownLatch canReturn) { + this.insideServlet = insideServlet; + this.canReturn = canReturn; + } + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) + throws ServletException, IOException { + insideServlet.countDown(); + try { + canReturn.await(30, TimeUnit.SECONDS); + } catch (InterruptedException e) { + // Ignore + } + resp.setContentType("text/plain"); + resp.getWriter().print("OK"); + } + } +} From a61be7d088385e8844020b822f1d6b1a97362df6 Mon Sep 17 00:00:00 2001 From: Dimitris Soumis Date: Tue, 7 Apr 2026 16:30:19 +0300 Subject: [PATCH 2/3] Add docs for FilterValve --- webapps/docs/config/valve.xml | 62 +++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/webapps/docs/config/valve.xml b/webapps/docs/config/valve.xml index 07947471f333..b32cea42acc1 100644 --- a/webapps/docs/config/valve.xml +++ b/webapps/docs/config/valve.xml @@ -2694,6 +2694,68 @@ +
+ + + +

The Filter Valve allows a Servlet Filter to be run as + part of the Valve pipeline. This enables reuse of existing Filter + implementations at the Valve level without duplicating their logic.

+ +

There are several caveats when using this Valve:

+
    +
  • A separate instance of the Filter class is created, distinct + from any instance that may be instantiated within a web application.
  • +
  • Calls to FilterConfig.getFilterName() will return + null.
  • +
  • FilterConfig.getServletContext() will return the proper + ServletContext for a Valve attached to a + <Context>, but will return a + ServletContext of limited use for a Valve specified on an + <Engine> or <Host>.
  • +
  • The Filter MUST NOT wrap the + ServletRequest or ServletResponse objects, or + an IllegalStateException will be thrown.
  • +
+ +
+ + + +

The Filter Valve supports the following + configuration attributes:

+ + + + +

Java class name of the implementation to use. This MUST be set to + org.apache.catalina.valves.FilterValve.

+
+ + +

The fully qualified class name of the Filter + implementation to use. The class must have a no-argument + constructor.

+
+ +
+ +

The Filter Valve also supports nested + <init-param> elements to pass initialization + parameters to the Filter:

+ + + + myParam + myValue + + ]]> + +
+ +
+ From 393de8ee0ae995363ba340bbb374ad418ffd5240 Mon Sep 17 00:00:00 2001 From: Dimitris Soumis Date: Wed, 8 Apr 2026 14:52:53 +0300 Subject: [PATCH 3/3] Add more tests and minor fixes for FilterValve, ProxyErrorReportValve and SemaphoreValve --- .../catalina/valves/TestFilterValve.java | 99 ++++--- .../valves/TestProxyErrorReportValve.java | 121 ++++---- .../catalina/valves/TestSemaphoreValve.java | 258 +++++++++++++++--- 3 files changed, 335 insertions(+), 143 deletions(-) diff --git a/test/org/apache/catalina/valves/TestFilterValve.java b/test/org/apache/catalina/valves/TestFilterValve.java index cfc3a5abb780..dd2d918c5c87 100644 --- a/test/org/apache/catalina/valves/TestFilterValve.java +++ b/test/org/apache/catalina/valves/TestFilterValve.java @@ -17,16 +17,16 @@ package org.apache.catalina.valves; import java.io.IOException; -import java.nio.charset.StandardCharsets; import java.util.Collections; +import java.util.List; import jakarta.servlet.Filter; import jakarta.servlet.FilterChain; import jakarta.servlet.ServletException; import jakarta.servlet.ServletRequest; import jakarta.servlet.ServletResponse; -import jakarta.servlet.http.HttpServlet; import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletRequestWrapper; import jakarta.servlet.http.HttpServletResponse; import org.junit.Assert; @@ -47,8 +47,8 @@ public void testFilterPassthrough() throws Exception { Context ctx = getProgrammaticRootContext(); - Tomcat.addServlet(ctx, "ok", new OkServlet()); - ctx.addServletMappingDecoded("/", "ok"); + Tomcat.addServlet(ctx, "hello", new HelloWorldServlet()); + ctx.addServletMappingDecoded("/", "hello"); FilterValve valve = new FilterValve(); valve.setFilterClass(PassthroughFilter.class.getName()); @@ -57,11 +57,10 @@ public void testFilterPassthrough() throws Exception { tomcat.start(); ByteChunk res = new ByteChunk(); - res.setCharset(StandardCharsets.UTF_8); int rc = getUrl("http://localhost:" + getPort(), res, null); Assert.assertEquals(HttpServletResponse.SC_OK, rc); - Assert.assertEquals("OK", res.toString()); + Assert.assertEquals(HelloWorldServlet.RESPONSE_TEXT, res.toString()); } @@ -71,8 +70,8 @@ public void testFilterBlocks() throws Exception { Context ctx = getProgrammaticRootContext(); - Tomcat.addServlet(ctx, "ok", new OkServlet()); - ctx.addServletMappingDecoded("/", "ok"); + Tomcat.addServlet(ctx, "hello", new HelloWorldServlet()); + ctx.addServletMappingDecoded("/", "hello"); FilterValve valve = new FilterValve(); valve.setFilterClass(BlockingFilter.class.getName()); @@ -81,14 +80,33 @@ public void testFilterBlocks() throws Exception { tomcat.start(); ByteChunk res = new ByteChunk(); - res.setCharset(StandardCharsets.UTF_8); int rc = getUrl("http://localhost:" + getPort(), res, null); Assert.assertEquals(HttpServletResponse.SC_FORBIDDEN, rc); } - @Test + public void testFilterWrappingRequestThrows() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "hello", new HelloWorldServlet()); + ctx.addServletMappingDecoded("/", "hello"); + + FilterValve valve = new FilterValve(); + valve.setFilterClass(WrappingFilter.class.getName()); + ctx.getPipeline().addValve(valve); + + tomcat.start(); + + int rc = getUrl("http://localhost:" + getPort(), new ByteChunk(), null); + + Assert.assertEquals(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, rc); + } + + + @Test(expected = LifecycleException.class) public void testNullFilterClassThrowsOnStart() throws Exception { Tomcat tomcat = getTomcatInstance(); @@ -98,18 +116,11 @@ public void testNullFilterClassThrowsOnStart() throws Exception { // Do NOT set filterClassName ctx.getPipeline().addValve(valve); - boolean threw = false; - try { - tomcat.start(); - } catch (LifecycleException e) { - threw = true; - } - - Assert.assertTrue("Should throw LifecycleException for null filter class", threw); + tomcat.start(); } - @Test + @Test(expected = LifecycleException.class) public void testInvalidFilterClassThrowsOnStart() throws Exception { Tomcat tomcat = getTomcatInstance(); @@ -119,26 +130,19 @@ public void testInvalidFilterClassThrowsOnStart() throws Exception { valve.setFilterClass("com.nonexistent.FakeFilter"); ctx.getPipeline().addValve(valve); - boolean threw = false; - try { - tomcat.start(); - } catch (LifecycleException e) { - threw = true; - } - - Assert.assertTrue("Should throw LifecycleException for invalid filter class", threw); + tomcat.start(); } @Test - public void testGetFilterNameReturnsNull() throws Exception { + public void testGetFilterNameReturnsNull() { FilterValve valve = new FilterValve(); Assert.assertNull(valve.getFilterName()); } @Test - public void testInitParams() throws Exception { + public void testInitParams() { FilterValve valve = new FilterValve(); valve.addInitParam("key1", "value1"); @@ -148,7 +152,7 @@ public void testInitParams() throws Exception { Assert.assertEquals("value2", valve.getInitParameter("key2")); Assert.assertNull(valve.getInitParameter("nonexistent")); - java.util.List names = Collections.list(valve.getInitParameterNames()); + List names = Collections.list(valve.getInitParameterNames()); Assert.assertEquals(2, names.size()); Assert.assertTrue(names.contains("key1")); Assert.assertTrue(names.contains("key2")); @@ -156,7 +160,7 @@ public void testInitParams() throws Exception { @Test - public void testInitParamsEmpty() throws Exception { + public void testInitParamsEmpty() { FilterValve valve = new FilterValve(); Assert.assertNull(valve.getInitParameter("anything")); @@ -165,7 +169,7 @@ public void testInitParamsEmpty() throws Exception { @Test - public void testGetSetFilterClassName() throws Exception { + public void testGetSetFilterClassName() { FilterValve valve = new FilterValve(); Assert.assertNull(valve.getFilterClassName()); @@ -173,11 +177,16 @@ public void testGetSetFilterClassName() throws Exception { valve.setFilterClassName("com.example.MyFilter"); Assert.assertEquals("com.example.MyFilter", valve.getFilterClassName()); - // setFilterClass is an alias valve.setFilterClass("com.example.OtherFilter"); Assert.assertEquals("com.example.OtherFilter", valve.getFilterClassName()); } + @Test(expected = IllegalStateException.class) + public void testGetServletContextThrowsBeforeStart() { + FilterValve valve = new FilterValve(); + valve.getServletContext(); + } + /** * A Filter that passes the request through to the next element in the chain. @@ -186,35 +195,35 @@ public static final class PassthroughFilter implements Filter { @Override public void doFilter(ServletRequest request, ServletResponse response, - FilterChain chain) throws IOException, ServletException { + FilterChain chain) throws IOException, ServletException { chain.doFilter(request, response); } } /** - * A Filter that blocks the request by sending a 403 response without - * calling chain.doFilter(). + * A Filter that blocks the request by sending a 403 response without calling chain.doFilter(). */ public static final class BlockingFilter implements Filter { @Override public void doFilter(ServletRequest request, ServletResponse response, - FilterChain chain) throws IOException, ServletException { + FilterChain chain) throws IOException, ServletException { ((HttpServletResponse) response).sendError(HttpServletResponse.SC_FORBIDDEN); } } - - private static final class OkServlet extends HttpServlet { - - private static final long serialVersionUID = 1L; + /** + * A Filter that wraps the request before calling chain.doFilter(), which FilterValve explicitly forbids. + */ + public static final class WrappingFilter implements Filter { @Override - protected void doGet(HttpServletRequest req, HttpServletResponse resp) - throws ServletException, IOException { - resp.setContentType("text/plain"); - resp.getWriter().print("OK"); + public void doFilter(ServletRequest request, ServletResponse response, + FilterChain chain) throws IOException, ServletException { + HttpServletRequestWrapper wrapped = new HttpServletRequestWrapper((HttpServletRequest) request); + chain.doFilter(wrapped, response); } } + } diff --git a/test/org/apache/catalina/valves/TestProxyErrorReportValve.java b/test/org/apache/catalina/valves/TestProxyErrorReportValve.java index 98a92fe84f50..8829fa2d6394 100644 --- a/test/org/apache/catalina/valves/TestProxyErrorReportValve.java +++ b/test/org/apache/catalina/valves/TestProxyErrorReportValve.java @@ -17,12 +17,8 @@ package org.apache.catalina.valves; import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.io.Serial; -import jakarta.servlet.ServletException; import jakarta.servlet.http.HttpServlet; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; @@ -35,7 +31,6 @@ import org.apache.catalina.startup.Tomcat; import org.apache.catalina.startup.TomcatBaseTest; import org.apache.tomcat.util.buf.ByteChunk; -import org.apache.tomcat.util.descriptor.web.ErrorPage; public class TestProxyErrorReportValve extends TomcatBaseTest { @@ -46,7 +41,8 @@ public class TestProxyErrorReportValve extends TomcatBaseTest { @Test public void testRedirectMode() throws Exception { Tomcat tomcat = getTomcatInstance(); - ((StandardHost) tomcat.getHost()).setErrorReportValveClass(PROXY_VALVE); + StandardHost host = (StandardHost) tomcat.getHost(); + host.setErrorReportValveClass(PROXY_VALVE); Context ctx = getProgrammaticRootContext(); @@ -54,28 +50,49 @@ public void testRedirectMode() throws Exception { HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Server broke")); ctx.addServletMappingDecoded("/", "error"); - // Register an error page that the valve will redirect to + // Register an error page at the Host's error report valve level + // so findErrorPage() returns a URL for the redirect + Tomcat.addServlet(ctx, "errorPage", new ErrorPageServlet()); + ctx.addServletMappingDecoded("/error-page", "errorPage"); + + tomcat.start(); + + ProxyErrorReportValve valve = (ProxyErrorReportValve) host.getPipeline().getFirst(); + valve.setProperty("errorCode." + HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + "http://localhost:" + getPort() + "/error-page"); + + int rc = getUrl("http://localhost:" + getPort(), new ByteChunk(), false); + + Assert.assertEquals(HttpServletResponse.SC_FOUND, rc); + } + + @Test + public void testProxyMode() throws Exception { + Tomcat tomcat = getTomcatInstance(); + StandardHost host = (StandardHost) tomcat.getHost(); + host.setErrorReportValveClass(PROXY_VALVE); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "error", new SendErrorServlet( + HttpServletResponse.SC_NOT_FOUND, "Not found")); + ctx.addServletMappingDecoded("/", "error"); + Tomcat.addServlet(ctx, "errorPage", new ErrorPageServlet()); ctx.addServletMappingDecoded("/error-page", "errorPage"); - ErrorPage errorPage = new ErrorPage(); - errorPage.setErrorCode(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); - errorPage.setLocation("/error-page"); - ctx.addErrorPage(errorPage); tomcat.start(); + ProxyErrorReportValve valve = (ProxyErrorReportValve) host.getPipeline().getFirst(); + valve.setUseRedirect(false); + valve.setProperty("errorCode." + HttpServletResponse.SC_NOT_FOUND, + "http://localhost:" + getPort() + "/error-page"); + ByteChunk res = new ByteChunk(); - res.setCharset(StandardCharsets.UTF_8); - Map> resHead = new HashMap<>(); - // Don't follow redirects - int rc = getUrl("http://localhost:" + getPort(), res, resHead); - - // ProxyErrorReportValve uses error pages from context — but since - // it calls findErrorPage() which uses Host-level error pages, - // the context error page might not be found and it falls back to - // the superclass. The test verifies the valve is loaded correctly. - Assert.assertTrue("Status should indicate an error", - rc >= 400 || rc == 302); + int rc = getUrl("http://localhost:" + getPort(), res, null); + + Assert.assertEquals(HttpServletResponse.SC_NOT_FOUND, rc); + Assert.assertTrue(res.toString().contains("ERROR_PAGE_OK")); } @@ -90,20 +107,18 @@ public void testNoErrorPageFallsBackToSuper() throws Exception { HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "No page configured")); ctx.addServletMappingDecoded("/", "error"); - // No error page configured — should fall back to ErrorReportValve's report() tomcat.start(); ByteChunk res = new ByteChunk(); - res.setCharset(StandardCharsets.UTF_8); int rc = getUrl("http://localhost:" + getPort(), res, null); Assert.assertEquals(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, rc); String body = res.toString(); Assert.assertNotNull(body); - // The default ErrorReportValve produces HTML Assert.assertTrue("Should contain HTML error report", - body.contains("") || body.contains("

")); + body.contains("html") && + body.contains(String.valueOf(HttpServletResponse.SC_INTERNAL_SERVER_ERROR))); } @@ -114,17 +129,16 @@ public void testStatusBelow400Ignored() throws Exception { Context ctx = getProgrammaticRootContext(); - Tomcat.addServlet(ctx, "ok", new OkServlet()); - ctx.addServletMappingDecoded("/", "ok"); + Tomcat.addServlet(ctx, "hello", new HelloWorldServlet()); + ctx.addServletMappingDecoded("/", "hello"); tomcat.start(); ByteChunk res = new ByteChunk(); - res.setCharset(StandardCharsets.UTF_8); int rc = getUrl("http://localhost:" + getPort(), res, null); Assert.assertEquals(HttpServletResponse.SC_OK, rc); - Assert.assertEquals("OK", res.toString()); + Assert.assertEquals(HelloWorldServlet.RESPONSE_TEXT, res.toString()); } @@ -142,28 +156,24 @@ public void testStatusNotFound() throws Exception { tomcat.start(); ByteChunk res = new ByteChunk(); - res.setCharset(StandardCharsets.UTF_8); int rc = getUrl("http://localhost:" + getPort(), res, null); Assert.assertEquals(HttpServletResponse.SC_NOT_FOUND, rc); String body = res.toString(); Assert.assertNotNull(body); - // Falls back to parent ErrorReportValve HTML Assert.assertTrue("Should contain error report", - body.contains("404") || body.contains("Not Found")); + body.contains(String.valueOf(HttpServletResponse.SC_NOT_FOUND))); } @Test - public void testGetSetProperties() throws Exception { + public void testGetSetProperties() { ProxyErrorReportValve valve = new ProxyErrorReportValve(); - // Defaults Assert.assertTrue(valve.getUseRedirect()); Assert.assertFalse(valve.getUsePropertiesFile()); - // Setters valve.setUseRedirect(false); Assert.assertFalse(valve.getUseRedirect()); @@ -174,19 +184,19 @@ public void testGetSetProperties() throws Exception { @Test public void testMessageInErrorReport() throws Exception { + final String customErrorMessage = "Custom error message"; Tomcat tomcat = getTomcatInstance(); ((StandardHost) tomcat.getHost()).setErrorReportValveClass(PROXY_VALVE); Context ctx = getProgrammaticRootContext(); Tomcat.addServlet(ctx, "error", new SendErrorServlet( - HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Custom error message")); + HttpServletResponse.SC_INTERNAL_SERVER_ERROR, customErrorMessage)); ctx.addServletMappingDecoded("/", "error"); tomcat.start(); ByteChunk res = new ByteChunk(); - res.setCharset(StandardCharsets.UTF_8); int rc = getUrl("http://localhost:" + getPort(), res, null); Assert.assertEquals(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, rc); @@ -194,8 +204,7 @@ public void testMessageInErrorReport() throws Exception { String body = res.toString(); Assert.assertNotNull(body); // Falls back to super.report() which includes the message - Assert.assertTrue("Should contain the custom error message", - body.contains("Custom error message")); + Assert.assertTrue(body.contains(customErrorMessage)); } @@ -212,21 +221,21 @@ public void testExceptionErrorReport() throws Exception { tomcat.start(); ByteChunk res = new ByteChunk(); - res.setCharset(StandardCharsets.UTF_8); int rc = getUrl("http://localhost:" + getPort(), res, null); Assert.assertEquals(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, rc); String body = res.toString(); Assert.assertNotNull(body); - Assert.assertTrue("Should contain exception info", - body.contains("RuntimeException")); + Assert.assertTrue(body.contains("RuntimeException")); } private static final class SendErrorServlet extends HttpServlet { + @Serial private static final long serialVersionUID = 1L; + private final int statusCode; private final String message; @@ -237,33 +246,19 @@ private SendErrorServlet(int statusCode, String message) { @Override protected void doGet(HttpServletRequest req, HttpServletResponse resp) - throws ServletException, IOException { + throws IOException { resp.sendError(statusCode, message); } } - - private static final class OkServlet extends HttpServlet { - - private static final long serialVersionUID = 1L; - - @Override - protected void doGet(HttpServletRequest req, HttpServletResponse resp) - throws ServletException, IOException { - resp.setContentType("text/plain"); - resp.getWriter().print("OK"); - } - } - - private static final class ErrorPageServlet extends HttpServlet { + @Serial private static final long serialVersionUID = 1L; @Override protected void doGet(HttpServletRequest req, HttpServletResponse resp) - throws ServletException, IOException { - resp.setContentType("text/plain"); + throws IOException { resp.getWriter().print("ERROR_PAGE_OK"); } } @@ -271,11 +266,11 @@ protected void doGet(HttpServletRequest req, HttpServletResponse resp) private static final class ExceptionServlet extends HttpServlet { + @Serial private static final long serialVersionUID = 1L; @Override - public void service(jakarta.servlet.ServletRequest request, - jakarta.servlet.ServletResponse response) throws IOException { + protected void doGet(HttpServletRequest req, HttpServletResponse resp) { throw new RuntimeException("Test exception"); } } diff --git a/test/org/apache/catalina/valves/TestSemaphoreValve.java b/test/org/apache/catalina/valves/TestSemaphoreValve.java index eefe4d4acb66..cf9ae9d92a36 100644 --- a/test/org/apache/catalina/valves/TestSemaphoreValve.java +++ b/test/org/apache/catalina/valves/TestSemaphoreValve.java @@ -17,12 +17,14 @@ package org.apache.catalina.valves; import java.io.IOException; -import java.nio.charset.StandardCharsets; +import java.io.Serial; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; -import jakarta.servlet.ServletException; import jakarta.servlet.http.HttpServlet; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; @@ -31,6 +33,8 @@ import org.junit.Test; import org.apache.catalina.Context; +import org.apache.catalina.connector.Request; +import org.apache.catalina.connector.Response; import org.apache.catalina.startup.Tomcat; import org.apache.catalina.startup.TomcatBaseTest; import org.apache.tomcat.util.buf.ByteChunk; @@ -44,8 +48,8 @@ public void testBasicConcurrency() throws Exception { Context ctx = getProgrammaticRootContext(); - Tomcat.addServlet(ctx, "ok", new OkServlet()); - ctx.addServletMappingDecoded("/", "ok"); + Tomcat.addServlet(ctx, "hello", new HelloWorldServlet()); + ctx.addServletMappingDecoded("/", "hello"); SemaphoreValve valve = new SemaphoreValve(); valve.setConcurrency(10); @@ -54,11 +58,33 @@ public void testBasicConcurrency() throws Exception { tomcat.start(); ByteChunk res = new ByteChunk(); - res.setCharset(StandardCharsets.UTF_8); int rc = getUrl("http://localhost:" + getPort(), res, null); Assert.assertEquals(HttpServletResponse.SC_OK, rc); - Assert.assertEquals("OK", res.toString()); + Assert.assertEquals(HelloWorldServlet.RESPONSE_TEXT, res.toString()); + } + + @Test + public void testInterruptedConcurrency() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "hello", new HelloWorldServlet()); + ctx.addServletMappingDecoded("/", "hello"); + + SemaphoreValve valve = new SemaphoreValve(); + valve.setConcurrency(10); + valve.setInterruptible(true); + ctx.getPipeline().addValve(valve); + + tomcat.start(); + + ByteChunk res = new ByteChunk(); + int rc = getUrl("http://localhost:" + getPort(), res, null); + + Assert.assertEquals(HttpServletResponse.SC_OK, rc); + Assert.assertEquals(HelloWorldServlet.RESPONSE_TEXT, res.toString()); } @@ -76,7 +102,7 @@ public void testNonBlockingDenied() throws Exception { SemaphoreValve valve = new SemaphoreValve(); valve.setConcurrency(1); valve.setBlock(false); - valve.setHighConcurrencyStatus(503); + valve.setHighConcurrencyStatus(HttpServletResponse.SC_SERVICE_UNAVAILABLE); ctx.getPipeline().addValve(valve); tomcat.start(); @@ -85,9 +111,7 @@ public void testNonBlockingDenied() throws Exception { AtomicInteger firstRc = new AtomicInteger(); Thread firstThread = new Thread(() -> { try { - ByteChunk r = new ByteChunk(); - r.setCharset(StandardCharsets.UTF_8); - firstRc.set(getUrl("http://localhost:" + getPort(), r, null)); + firstRc.set(getUrl("http://localhost:" + getPort(), new ByteChunk(), null)); } catch (IOException e) { // Ignore } @@ -99,9 +123,7 @@ public void testNonBlockingDenied() throws Exception { insideServlet.await(10, TimeUnit.SECONDS)); // Second request — should be denied because concurrency=1 and block=false - ByteChunk res2 = new ByteChunk(); - res2.setCharset(StandardCharsets.UTF_8); - int rc2 = getUrl("http://localhost:" + getPort(), res2, null); + int rc2 = getUrl("http://localhost:" + getPort(), new ByteChunk(), null); Assert.assertEquals(HttpServletResponse.SC_SERVICE_UNAVAILABLE, rc2); @@ -135,8 +157,7 @@ public void testHighConcurrencyStatusNotSet() throws Exception { // First request holds the permit Thread firstThread = new Thread(() -> { try { - ByteChunk r = new ByteChunk(); - getUrl("http://localhost:" + getPort(), r, null); + getUrl("http://localhost:" + getPort(), new ByteChunk(), null); } catch (IOException e) { // Ignore } @@ -147,10 +168,9 @@ public void testHighConcurrencyStatusNotSet() throws Exception { insideServlet.await(10, TimeUnit.SECONDS)); // Second request — denied but no error status is sent - ByteChunk res2 = new ByteChunk(); - int rc2 = getUrl("http://localhost:" + getPort(), res2, null); + int rc2 = getUrl("http://localhost:" + getPort(), new ByteChunk(), null); - // With no highConcurrencyStatus, response is 200 with no body + // With no highConcurrencyStatus, response is 200 without body Assert.assertEquals(HttpServletResponse.SC_OK, rc2); canReturn.countDown(); @@ -159,7 +179,7 @@ public void testHighConcurrencyStatusNotSet() throws Exception { @Test - public void testGetSetProperties() throws Exception { + public void testGetSetProperties() { SemaphoreValve valve = new SemaphoreValve(); // Defaults @@ -182,8 +202,8 @@ public void testGetSetProperties() throws Exception { valve.setInterruptible(true); Assert.assertTrue(valve.getInterruptible()); - valve.setHighConcurrencyStatus(429); - Assert.assertEquals(429, valve.getHighConcurrencyStatus()); + valve.setHighConcurrencyStatus(HttpServletResponse.SC_TOO_MANY_REQUESTS); + Assert.assertEquals(HttpServletResponse.SC_TOO_MANY_REQUESTS, valve.getHighConcurrencyStatus()); } @@ -193,8 +213,8 @@ public void testFairSemaphore() throws Exception { Context ctx = getProgrammaticRootContext(); - Tomcat.addServlet(ctx, "ok", new OkServlet()); - ctx.addServletMappingDecoded("/", "ok"); + Tomcat.addServlet(ctx, "hello", new HelloWorldServlet()); + ctx.addServletMappingDecoded("/", "hello"); SemaphoreValve valve = new SemaphoreValve(); valve.setConcurrency(5); @@ -203,30 +223,178 @@ public void testFairSemaphore() throws Exception { tomcat.start(); + Assert.assertNotNull(valve.semaphore); + Assert.assertTrue(valve.semaphore.isFair()); + Assert.assertEquals(5, valve.semaphore.availablePermits()); + ByteChunk res = new ByteChunk(); - res.setCharset(StandardCharsets.UTF_8); int rc = getUrl("http://localhost:" + getPort(), res, null); Assert.assertEquals(HttpServletResponse.SC_OK, rc); - Assert.assertEquals("OK", res.toString()); + Assert.assertEquals(HelloWorldServlet.RESPONSE_TEXT, res.toString()); } + @Test + public void testBlockingWaitsForPermit() throws Exception { + Tomcat tomcat = getTomcatInstance(); - private static final class OkServlet extends HttpServlet { + Context ctx = getProgrammaticRootContext(); - private static final long serialVersionUID = 1L; + CountDownLatch insideServlet = new CountDownLatch(1); + CountDownLatch canReturn = new CountDownLatch(1); + Tomcat.addServlet(ctx, "slow", new SlowServlet(insideServlet, canReturn)); + ctx.addServletMappingDecoded("/", "slow"); - @Override - protected void doGet(HttpServletRequest req, HttpServletResponse resp) - throws ServletException, IOException { - resp.setContentType("text/plain"); - resp.getWriter().print("OK"); - } + SemaphoreValve valve = new SemaphoreValve(); + valve.setConcurrency(1); + valve.setBlock(true); + ctx.getPipeline().addValve(valve); + + tomcat.start(); + + AtomicReference firstError = new AtomicReference<>(); + Thread firstThread = new Thread(() -> { + try { + getUrl("http://localhost:" + getPort(), new ByteChunk(), null); + } catch (IOException e) { + firstError.set(e); + } + }); + firstThread.start(); + + Assert.assertTrue("First request should reach servlet", + insideServlet.await(10, TimeUnit.SECONDS)); + + AtomicInteger secondRc = new AtomicInteger(); + AtomicReference secondError = new AtomicReference<>(); + Thread secondThread = new Thread(() -> { + try { + secondRc.set(getUrl("http://localhost:" + getPort(), new ByteChunk(), null)); + } catch (IOException e) { + secondError.set(e); + } + }); + secondThread.start(); + + // Give the second request time to arrive and block on the semaphore + Thread.sleep(500); + + Assert.assertTrue("Second request should be blocked waiting for permit", secondThread.isAlive()); + + canReturn.countDown(); + firstThread.join(10000); + Assert.assertNull(firstError.get()); + + secondThread.join(10000); + Assert.assertFalse(secondThread.isAlive()); + Assert.assertNull(secondError.get()); + Assert.assertEquals(HttpServletResponse.SC_OK, secondRc.get()); + } + + @Test + public void testControlConcurrencyBypass() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + + CountDownLatch insideServlet = new CountDownLatch(1); + CountDownLatch canReturn = new CountDownLatch(1); + Tomcat.addServlet(ctx, "slow", new SlowServlet(insideServlet, canReturn)); + ctx.addServletMappingDecoded("/slow", "slow"); + + Tomcat.addServlet(ctx, "hello", new HelloWorldServlet()); + ctx.addServletMappingDecoded("/bypass", "hello"); + + SemaphoreValve valve = new SemaphoreValve() { + @Override + public boolean controlConcurrency(Request request, Response response) { + return !request.getDecodedRequestURI().equals("/bypass"); + } + }; + valve.setConcurrency(1); + valve.setBlock(false); + valve.setHighConcurrencyStatus(HttpServletResponse.SC_SERVICE_UNAVAILABLE); + ctx.getPipeline().addValve(valve); + + tomcat.start(); + + Thread firstThread = new Thread(() -> { + try { + getUrl("http://localhost:" + getPort() + "/slow", new ByteChunk(), null); + } catch (IOException e) { + // Ignored + } + }); + firstThread.start(); + + Assert.assertTrue("First request should reach servlet", + insideServlet.await(10, TimeUnit.SECONDS)); + + // Request to /bypass should succeed despite all permits being held, + // because controlConcurrency() returns false for this path + int bypassRc = getUrl("http://localhost:" + getPort() + "/bypass", new ByteChunk(), null); + Assert.assertEquals(HttpServletResponse.SC_OK, bypassRc); + + int deniedRc = getUrl("http://localhost:" + getPort() + "/slow", new ByteChunk(), null); + Assert.assertEquals(HttpServletResponse.SC_SERVICE_UNAVAILABLE, deniedRc); + + canReturn.countDown(); + firstThread.join(10000); } + @Test + public void testInterruptibleDenied() throws Exception { + SemaphoreValve semaphoreValve = new SemaphoreValve(); + semaphoreValve.setConcurrency(1); + semaphoreValve.setBlock(true); + semaphoreValve.setInterruptible(true); + semaphoreValve.setHighConcurrencyStatus(HttpServletResponse.SC_SERVICE_UNAVAILABLE); + + semaphoreValve.semaphore = new Semaphore(1, false); + + AtomicBoolean nextInvoked = new AtomicBoolean(false); + semaphoreValve.setNext(new ValveBase() { + @Override + public void invoke(Request request, Response response) { + nextInvoked.set(true); + } + }); + + MockResponse response = new MockResponse(); + + semaphoreValve.semaphore.acquire(); + + // On a new thread, valve will block on semaphore.acquire() because the permit is already held. + CountDownLatch invokeStarted = new CountDownLatch(1); + Thread blocked = new Thread(() -> { + invokeStarted.countDown(); + try { + semaphoreValve.invoke(null, response); + } catch (Throwable t) { + // Ignored + } + }); + blocked.start(); + + Assert.assertTrue(invokeStarted.await(10, TimeUnit.SECONDS)); + Thread.sleep(200); + + blocked.interrupt(); + blocked.join(10000); + Assert.assertFalse(blocked.isAlive()); + + Assert.assertEquals(HttpServletResponse.SC_SERVICE_UNAVAILABLE, response.getStatus()); + + Assert.assertFalse("Next valve should not be invoked when permit denied", nextInvoked.get()); + + Assert.assertEquals(0, semaphoreValve.semaphore.availablePermits()); + + semaphoreValve.semaphore.release(); + } private static final class SlowServlet extends HttpServlet { + @Serial private static final long serialVersionUID = 1L; private final CountDownLatch insideServlet; private final CountDownLatch canReturn; @@ -238,10 +406,10 @@ private SlowServlet(CountDownLatch insideServlet, CountDownLatch canReturn) { @Override protected void doGet(HttpServletRequest req, HttpServletResponse resp) - throws ServletException, IOException { + throws IOException { insideServlet.countDown(); try { - canReturn.await(30, TimeUnit.SECONDS); + Assert.assertTrue(canReturn.await(30, TimeUnit.SECONDS)); } catch (InterruptedException e) { // Ignore } @@ -249,4 +417,24 @@ protected void doGet(HttpServletRequest req, HttpServletResponse resp) resp.getWriter().print("OK"); } } + + public static class MockResponse extends Response { + + public MockResponse() { + super(null); + } + + private int status = HttpServletResponse.SC_OK; + + @Override + public void sendError(int status) throws IOException { + this.status = status; + } + + @Override + public int getStatus() { + return status; + } + } + }