diff --git a/dd-java-agent/instrumentation/tomcat/tomcat-appsec/tomcat-appsec-7.0/build.gradle b/dd-java-agent/instrumentation/tomcat/tomcat-appsec/tomcat-appsec-7.0/build.gradle index f5542d48255..1d3cc39c2c4 100644 --- a/dd-java-agent/instrumentation/tomcat/tomcat-appsec/tomcat-appsec-7.0/build.gradle +++ b/dd-java-agent/instrumentation/tomcat/tomcat-appsec/tomcat-appsec-7.0/build.gradle @@ -15,6 +15,13 @@ muzzle { extraDependency 'org.apache.tomcat:tomcat-catalina:7.0.4' assertInverse = true } + pass { + name = 'glassfish' + group = 'org.glassfish.main.extras' + module = 'glassfish-embedded-all' + versions = '[4.0, 6.1.0)' // GlassFish 6.1.0+ uses jakarta.* namespace; our advice uses javax.servlet.http.Part + assertInverse = true + } } apply from: "$rootDir/gradle/java.gradle" @@ -22,10 +29,15 @@ apply from: "$rootDir/gradle/java.gradle" dependencies { compileOnly group: 'org.apache.tomcat', name: 'tomcat-catalina', version: '7.0.4' compileOnly group: 'org.apache.tomcat', name: 'tomcat-coyote', version: '7.0.4' + // Servlet 3.1 API needed to reference Part.getSubmittedFileName() in GlassFishMultipartInstrumentation. + // tomcat-catalina:7.0.4 provides only Servlet 3.0 (no getSubmittedFileName); GlassFish 4+ is Servlet 3.1. + compileOnly group: 'javax.servlet', name: 'javax.servlet-api', version: '3.1.0' implementation project(':dd-java-agent:instrumentation:tomcat:tomcat-common') testImplementation group: 'org.apache.tomcat', name: 'tomcat-catalina', version: '7.0.4' testImplementation group: 'org.apache.tomcat', name: 'tomcat-coyote', version: '7.0.4' + testImplementation group: 'javax.servlet', name: 'javax.servlet-api', version: '3.1.0' + testImplementation libs.bundles.mockito } // testing happens in tomcat-5.5 module diff --git a/dd-java-agent/instrumentation/tomcat/tomcat-appsec/tomcat-appsec-7.0/src/main/java/datadog/trace/instrumentation/tomcat7/GlassFishBlockingHelper.java b/dd-java-agent/instrumentation/tomcat/tomcat-appsec/tomcat-appsec-7.0/src/main/java/datadog/trace/instrumentation/tomcat7/GlassFishBlockingHelper.java new file mode 100644 index 00000000000..e4ffae15a06 --- /dev/null +++ b/dd-java-agent/instrumentation/tomcat/tomcat-appsec/tomcat-appsec-7.0/src/main/java/datadog/trace/instrumentation/tomcat7/GlassFishBlockingHelper.java @@ -0,0 +1,164 @@ +package datadog.trace.instrumentation.tomcat7; + +import datadog.appsec.api.blocking.BlockingContentType; +import datadog.trace.api.Config; +import datadog.trace.api.gateway.BlockResponseFunction; +import datadog.trace.api.gateway.Flow; +import datadog.trace.api.gateway.RequestContext; +import datadog.trace.api.http.MultipartContentDecoder; +import datadog.trace.bootstrap.blocking.BlockingActionHelper; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.Part; + +public final class GlassFishBlockingHelper { + + public static final int MAX_FILE_CONTENT_COUNT = Config.get().getAppSecMaxFileContentCount(); + public static final int MAX_FILE_CONTENT_BYTES = Config.get().getAppSecMaxFileContentBytes(); + + /** + * Attempts to commit a blocking response via the registered {@link BlockResponseFunction} or via + * the Servlet API fallback, then marks the trace segment as effectively blocked. + * + *

Returns {@code true} if the response was committed (regardless of whether {@link + * datadog.trace.api.internal.TraceSegment#effectivelyBlocked()} succeeded). Returns {@code false} + * if no response could be committed. + */ + public static boolean tryBlock( + RequestContext reqCtx, + HttpServletRequest fallbackReq, + HttpServletResponse fallbackResp, + Flow.Action.RequestBlockingAction rba) { + try { + BlockResponseFunction brf = reqCtx.getBlockResponseFunction(); + if (brf != null) { + brf.tryCommitBlockingResponse(reqCtx.getTraceSegment(), rba); + } else if (!commitBlocking(fallbackReq, fallbackResp, rba)) { + return false; + } + } catch (Exception ignored) { + return false; + } + // Response was committed — mark as blocked on a best-effort basis. + // effectivelyBlocked() can throw if the span is already finished; that must not suppress the + // true return value since the response has already been sent to the client. + try { + reqCtx.getTraceSegment().effectivelyBlocked(); + } catch (Exception ignored) { + } + return true; + } + + /** + * Collects filenames and file contents from the given multipart parts, fires the AppSec IG + * callbacks, and commits a blocking response if the WAF requests one. + * + *

Returns {@code true} if a blocking response was committed (the caller should replace the + * parts collection with an empty list to prevent further processing). + */ + public static boolean processPartsAndBlock( + Collection parts, + RequestContext reqCtx, + HttpServletRequest fallbackReq, + HttpServletResponse fallbackResp, + BiFunction, Flow> filenamesCb, + BiFunction, Flow> contentCb) { + List filenames = null; + List contents = null; + for (Object partObj : parts) { + try { + if (!(partObj instanceof Part)) { + continue; + } + Part part = (Part) partObj; + String filename = part.getSubmittedFileName(); + if (filename == null) { + continue; + } + if (filenamesCb != null && !filename.isEmpty()) { + if (filenames == null) { + filenames = new ArrayList<>(); + } + filenames.add(filename); + } + if (contentCb != null) { + if (contents == null) { + contents = new ArrayList<>(); + } + if (contents.size() < MAX_FILE_CONTENT_COUNT) { + try (InputStream is = part.getInputStream()) { + contents.add( + MultipartContentDecoder.readInputStream( + is, MAX_FILE_CONTENT_BYTES, part.getContentType())); + } catch (Exception ignored) { + contents.add(""); + } + } + } + } catch (Exception ignored) { + } + } + + if (filenames != null && !filenames.isEmpty()) { + Flow flow = filenamesCb.apply(reqCtx, filenames); + Flow.Action action = flow.getAction(); + if (action instanceof Flow.Action.RequestBlockingAction) { + if (tryBlock( + reqCtx, fallbackReq, fallbackResp, (Flow.Action.RequestBlockingAction) action)) { + return true; + } + } + } + + if (contents != null && !contents.isEmpty()) { + Flow contentFlow = contentCb.apply(reqCtx, contents); + Flow.Action contentAction = contentFlow.getAction(); + if (contentAction instanceof Flow.Action.RequestBlockingAction) { + return tryBlock( + reqCtx, fallbackReq, fallbackResp, (Flow.Action.RequestBlockingAction) contentAction); + } + } + + return false; + } + + public static boolean commitBlocking( + HttpServletRequest request, + HttpServletResponse response, + Flow.Action.RequestBlockingAction rba) { + if (response == null) { + return false; + } + try { + if (response.isCommitted()) { + return false; + } + response.reset(); + response.setStatus(BlockingActionHelper.getHttpCode(rba.getStatusCode())); + for (Map.Entry e : rba.getExtraHeaders().entrySet()) { + response.setHeader(e.getKey(), e.getValue()); + } + if (rba.getBlockingContentType() != BlockingContentType.NONE) { + String accept = request != null ? request.getHeader("Accept") : null; + BlockingActionHelper.TemplateType type = + BlockingActionHelper.determineTemplateType(rba.getBlockingContentType(), accept); + byte[] body = BlockingActionHelper.getTemplate(type, rba.getSecurityResponseId()); + if (body != null) { + response.setHeader("Content-Type", BlockingActionHelper.getContentType(type)); + response.setHeader("Content-Length", Integer.toString(body.length)); + response.getOutputStream().write(body); + } + } + response.flushBuffer(); + return true; + } catch (Exception e) { + return false; + } + } +} diff --git a/dd-java-agent/instrumentation/tomcat/tomcat-appsec/tomcat-appsec-7.0/src/main/java/datadog/trace/instrumentation/tomcat7/GlassFishMultipartInstrumentation.java b/dd-java-agent/instrumentation/tomcat/tomcat-appsec/tomcat-appsec-7.0/src/main/java/datadog/trace/instrumentation/tomcat7/GlassFishMultipartInstrumentation.java new file mode 100644 index 00000000000..00df4717397 --- /dev/null +++ b/dd-java-agent/instrumentation/tomcat/tomcat-appsec/tomcat-appsec-7.0/src/main/java/datadog/trace/instrumentation/tomcat7/GlassFishMultipartInstrumentation.java @@ -0,0 +1,130 @@ +package datadog.trace.instrumentation.tomcat7; + +import static datadog.trace.agent.tooling.bytebuddy.matcher.NameMatchers.named; +import static datadog.trace.api.gateway.Events.EVENTS; +import static net.bytebuddy.matcher.ElementMatchers.isPublic; +import static net.bytebuddy.matcher.ElementMatchers.takesArguments; + +import com.google.auto.service.AutoService; +import datadog.trace.agent.tooling.Instrumenter; +import datadog.trace.agent.tooling.InstrumenterModule; +import datadog.trace.api.gateway.CallbackProvider; +import datadog.trace.api.gateway.Flow; +import datadog.trace.api.gateway.RequestContext; +import datadog.trace.api.gateway.RequestContextSlot; +import datadog.trace.bootstrap.instrumentation.api.AgentSpan; +import datadog.trace.bootstrap.instrumentation.api.AgentTracer; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.function.BiFunction; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import net.bytebuddy.asm.Advice; + +/** + * GlassFish/Payara does not have {@code Request.parseParts()} — instead {@code Request.getParts()} + * delegates to {@code org.apache.catalina.fileupload.Multipart.getParts()}. This instrumentation + * hooks that GlassFish-specific class to report uploaded file names and contents to the AppSec WAF + * via the {@code requestFilesFilenames} and {@code requestFilesContent} IG events. + * + *

Because {@code org.apache.catalina.fileupload.Multipart} does not exist in standard Tomcat, + * this instrumentation is automatically skipped by ByteBuddy on non-GlassFish containers. + * + *

This advice casts each {@code Part} through the {@code javax.servlet.http.Part} interface + * (which {@code org.apache.catalina.fileupload.PartItem} implements) to avoid Java module-system + * access restrictions that prevent reflective invocation of methods on GlassFish-internal classes. + */ +@AutoService(InstrumenterModule.class) +public class GlassFishMultipartInstrumentation extends InstrumenterModule.AppSec + implements Instrumenter.ForSingleType, Instrumenter.HasMethodAdvice { + + public GlassFishMultipartInstrumentation() { + super("tomcat"); + } + + @Override + public String muzzleDirective() { + return "glassfish"; + } + + @Override + public String instrumentedType() { + return "org.apache.catalina.fileupload.Multipart"; + } + + @Override + public String[] helperClassNames() { + return new String[] { + "datadog.trace.instrumentation.tomcat7.GlassFishBlockingHelper", + }; + } + + @Override + public void methodAdvice(MethodTransformer transformer) { + transformer.applyAdvice( + named("getParts").and(takesArguments(0)).and(isPublic()), + getClass().getName() + "$GetPartsAdvice"); + } + + public static class GetPartsAdvice { + + @Advice.OnMethodExit(suppress = Throwable.class, onThrowable = Throwable.class) + static void after( + @Advice.This Object thisMultipart, + @Advice.Return(readOnly = false) Collection parts, + @Advice.Thrown Throwable t) { + if (t != null || parts == null || parts.isEmpty()) { + return; + } + + AgentSpan agentSpan = AgentTracer.activeSpan(); + if (agentSpan == null) { + return; + } + RequestContext reqCtx = agentSpan.getRequestContext(); + if (reqCtx == null || reqCtx.getData(RequestContextSlot.APPSEC) == null) { + return; + } + + CallbackProvider cbp = AgentTracer.get().getCallbackProvider(RequestContextSlot.APPSEC); + BiFunction, Flow> filenamesCb = + cbp.getCallback(EVENTS.requestFilesFilenames()); + BiFunction, Flow> contentCb = + cbp.getCallback(EVENTS.requestFilesContent()); + if (filenamesCb == null && contentCb == null) { + return; + } + + // Extract servlet request/response for fallback blocking when no BlockResponseFunction is + // registered (Payara: TomcatServerInstrumentation is muzzled out for Payara's response type). + // setAccessible works here because this code is inlined into Multipart.getParts() — + // the same module as the private field's owner class. + HttpServletRequest fallbackReq = null; + HttpServletResponse fallbackResp = null; + try { + Field f = thisMultipart.getClass().getDeclaredField("request"); + f.setAccessible(true); + Object catReq = f.get(thisMultipart); + if (catReq instanceof HttpServletRequest) { + fallbackReq = (HttpServletRequest) catReq; + } + if (catReq != null) { + Method m = catReq.getClass().getMethod("getResponse"); + Object catResp = m.invoke(catReq); + if (catResp instanceof HttpServletResponse) { + fallbackResp = (HttpServletResponse) catResp; + } + } + } catch (Exception ignored) { + } + + if (GlassFishBlockingHelper.processPartsAndBlock( + parts, reqCtx, fallbackReq, fallbackResp, filenamesCb, contentCb)) { + parts = Collections.emptyList(); + } + } + } +} diff --git a/dd-java-agent/instrumentation/tomcat/tomcat-appsec/tomcat-appsec-7.0/src/test/java/datadog/trace/instrumentation/tomcat7/GlassFishBlockingHelperTest.java b/dd-java-agent/instrumentation/tomcat/tomcat-appsec/tomcat-appsec-7.0/src/test/java/datadog/trace/instrumentation/tomcat7/GlassFishBlockingHelperTest.java new file mode 100644 index 00000000000..58cbe6eec20 --- /dev/null +++ b/dd-java-agent/instrumentation/tomcat/tomcat-appsec/tomcat-appsec-7.0/src/test/java/datadog/trace/instrumentation/tomcat7/GlassFishBlockingHelperTest.java @@ -0,0 +1,385 @@ +package datadog.trace.instrumentation.tomcat7; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.contains; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import datadog.appsec.api.blocking.BlockingContentType; +import datadog.trace.api.gateway.BlockResponseFunction; +import datadog.trace.api.gateway.Flow; +import datadog.trace.api.gateway.RequestContext; +import datadog.trace.api.internal.TraceSegment; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.function.BiFunction; +import javax.servlet.ServletOutputStream; +import javax.servlet.WriteListener; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.Part; +import org.junit.jupiter.api.Test; + +class GlassFishBlockingHelperTest { + + // ------- commitBlocking() ------- + + @Test + void commitBlocking_nullResponse_returnsFalse() { + assertFalse(GlassFishBlockingHelper.commitBlocking(null, null, rba(403))); + } + + @Test + void commitBlocking_committedResponse_returnsFalse() { + HttpServletResponse resp = mock(HttpServletResponse.class); + when(resp.isCommitted()).thenReturn(true); + assertFalse(GlassFishBlockingHelper.commitBlocking(null, resp, rba(403))); + } + + @Test + void commitBlocking_blockingContentTypeNone_setsStatusWithoutBody() throws IOException { + HttpServletResponse resp = mock(HttpServletResponse.class); + when(resp.isCommitted()).thenReturn(false); + + assertTrue( + GlassFishBlockingHelper.commitBlocking( + null, resp, new Flow.Action.RequestBlockingAction(403, BlockingContentType.NONE))); + + verify(resp).setStatus(403); + verify(resp).flushBuffer(); + verify(resp, never()).setHeader(eq("Content-Type"), any()); + verify(resp, never()).getOutputStream(); + } + + @Test + void commitBlocking_withJsonAccept_writesJsonBody() throws IOException { + HttpServletRequest req = mock(HttpServletRequest.class); + when(req.getHeader("Accept")).thenReturn("application/json"); + TestServletOutputStream out = new TestServletOutputStream(); + HttpServletResponse resp = mock(HttpServletResponse.class); + when(resp.isCommitted()).thenReturn(false); + when(resp.getOutputStream()).thenReturn(out); + + assertTrue(GlassFishBlockingHelper.commitBlocking(req, resp, rba(403))); + + verify(resp).setStatus(403); + verify(resp).setHeader(eq("Content-Type"), contains("json")); + verify(resp).setHeader(eq("Content-Length"), any()); + assertTrue(out.getBytes().length > 0); + verify(resp).flushBuffer(); + } + + @Test + void commitBlocking_withHtmlAccept_writesHtmlBody() throws IOException { + HttpServletRequest req = mock(HttpServletRequest.class); + when(req.getHeader("Accept")).thenReturn("text/html"); + TestServletOutputStream out = new TestServletOutputStream(); + HttpServletResponse resp = mock(HttpServletResponse.class); + when(resp.isCommitted()).thenReturn(false); + when(resp.getOutputStream()).thenReturn(out); + + assertTrue(GlassFishBlockingHelper.commitBlocking(req, resp, rba(403))); + + verify(resp).setHeader(eq("Content-Type"), contains("html")); + assertTrue(out.getBytes().length > 0); + } + + @Test + void commitBlocking_nullRequest_defaultsToJsonBody() throws IOException { + TestServletOutputStream out = new TestServletOutputStream(); + HttpServletResponse resp = mock(HttpServletResponse.class); + when(resp.isCommitted()).thenReturn(false); + when(resp.getOutputStream()).thenReturn(out); + + assertTrue(GlassFishBlockingHelper.commitBlocking(null, resp, rba(403))); + + verify(resp).setStatus(403); + assertTrue(out.getBytes().length > 0); + } + + @Test + void commitBlocking_ioException_returnsFalse() throws IOException { + HttpServletResponse resp = mock(HttpServletResponse.class); + when(resp.isCommitted()).thenReturn(false); + when(resp.getOutputStream()).thenThrow(new IOException("stream error")); + + assertFalse(GlassFishBlockingHelper.commitBlocking(null, resp, rba(403))); + } + + // ------- tryBlock() ------- + + @Test + void tryBlock_withBrf_commitsViaFunctionAndReturnsTrue() throws Exception { + TraceSegment segment = mock(TraceSegment.class); + BlockResponseFunction brf = mock(BlockResponseFunction.class); + RequestContext reqCtx = mockReqCtx(brf, segment); + + Flow.Action.RequestBlockingAction action = rba(403); + assertTrue(GlassFishBlockingHelper.tryBlock(reqCtx, null, null, action)); + + verify(brf).tryCommitBlockingResponse(segment, action); + verify(segment).effectivelyBlocked(); + } + + @Test + void tryBlock_noBrf_fallbackSucceeds_returnsTrue() throws IOException { + TraceSegment segment = mock(TraceSegment.class); + RequestContext reqCtx = mockReqCtx(null, segment); + TestServletOutputStream out = new TestServletOutputStream(); + HttpServletResponse resp = mock(HttpServletResponse.class); + when(resp.isCommitted()).thenReturn(false); + when(resp.getOutputStream()).thenReturn(out); + + assertTrue(GlassFishBlockingHelper.tryBlock(reqCtx, null, resp, rba(403))); + verify(segment).effectivelyBlocked(); + } + + @Test + void tryBlock_noBrf_nullFallbackResponse_returnsFalse() { + RequestContext reqCtx = mock(RequestContext.class); + when(reqCtx.getBlockResponseFunction()).thenReturn(null); + + assertFalse(GlassFishBlockingHelper.tryBlock(reqCtx, null, null, rba(403))); + verify(reqCtx, never()).getTraceSegment(); + } + + @Test + void tryBlock_brfThrows_returnsFalse() throws Exception { + TraceSegment segment = mock(TraceSegment.class); + BlockResponseFunction brf = mock(BlockResponseFunction.class); + RequestContext reqCtx = mockReqCtx(brf, segment); + doThrow(new RuntimeException("commit failed")) + .when(brf) + .tryCommitBlockingResponse(any(), any(Flow.Action.RequestBlockingAction.class)); + + assertFalse(GlassFishBlockingHelper.tryBlock(reqCtx, null, null, rba(403))); + verify(segment, never()).effectivelyBlocked(); + } + + @Test + void tryBlock_effectivelyBlockedThrows_stillReturnsTrue() throws Exception { + TraceSegment segment = mock(TraceSegment.class); + BlockResponseFunction brf = mock(BlockResponseFunction.class); + RequestContext reqCtx = mockReqCtx(brf, segment); + doThrow(new RuntimeException("span already finished")).when(segment).effectivelyBlocked(); + + assertTrue(GlassFishBlockingHelper.tryBlock(reqCtx, null, null, rba(403))); + } + + // ------- processPartsAndBlock() ------- + + @Test + void processPartsAndBlock_formField_skipped() throws Exception { + Part formField = mockPart(null, "text/plain", new byte[0]); + RequestContext reqCtx = mockReqCtx(null, mock(TraceSegment.class)); + BiFunction, Flow> filenamesCb = mockPassThroughCb(); + + assertFalse( + GlassFishBlockingHelper.processPartsAndBlock( + Collections.singletonList(formField), reqCtx, null, null, filenamesCb, null)); + verify(formField).getSubmittedFileName(); + verify(formField, never()).getInputStream(); + } + + @Test + void processPartsAndBlock_emptyFilename_notAddedToFilenames_butContentRead() throws Exception { + byte[] content = "data".getBytes(); + Part filePart = mockPart("", "application/octet-stream", content); + RequestContext reqCtx = mockReqCtx(null, mock(TraceSegment.class)); + BiFunction, Flow> filenamesCb = mockPassThroughCb(); + BiFunction, Flow> contentCb = mockPassThroughCb(); + + assertFalse( + GlassFishBlockingHelper.processPartsAndBlock( + Collections.singletonList(filePart), reqCtx, null, null, filenamesCb, contentCb)); + + verify(filePart).getInputStream(); + verify(filenamesCb, never()).apply(any(), any()); + verify(contentCb).apply(eq(reqCtx), any()); + } + + @Test + void processPartsAndBlock_normalFilename_reportedViaFilenamesCb() throws Exception { + Part filePart = mockPart("file.txt", "text/plain", "hello".getBytes()); + RequestContext reqCtx = mockReqCtx(null, mock(TraceSegment.class)); + BiFunction, Flow> filenamesCb = mockPassThroughCb(); + + assertFalse( + GlassFishBlockingHelper.processPartsAndBlock( + Collections.singletonList(filePart), reqCtx, null, null, filenamesCb, null)); + + verify(filenamesCb).apply(eq(reqCtx), eq(Collections.singletonList("file.txt"))); + } + + @Test + void processPartsAndBlock_contentRead_reportedViaContentCb() throws Exception { + Part filePart = mockPart("file.bin", "application/octet-stream", new byte[] {1, 2, 3}); + RequestContext reqCtx = mockReqCtx(null, mock(TraceSegment.class)); + BiFunction, Flow> contentCb = mockPassThroughCb(); + + assertFalse( + GlassFishBlockingHelper.processPartsAndBlock( + Collections.singletonList(filePart), reqCtx, null, null, null, contentCb)); + + verify(contentCb).apply(eq(reqCtx), any()); + } + + @Test + void processPartsAndBlock_maxFilesLimit_enforced() throws Exception { + int limit = GlassFishBlockingHelper.MAX_FILE_CONTENT_COUNT; + Part[] tooMany = new Part[limit + 1]; + for (int i = 0; i <= limit; i++) { + tooMany[i] = mockPart("f" + i + ".bin", "application/octet-stream", new byte[0]); + } + RequestContext reqCtx = mockReqCtx(null, mock(TraceSegment.class)); + BiFunction, Flow> contentCb = mockPassThroughCb(); + + assertFalse( + GlassFishBlockingHelper.processPartsAndBlock( + Arrays.asList(tooMany), reqCtx, null, null, null, contentCb)); + + verify(contentCb).apply(eq(reqCtx), any(List.class)); + verify(tooMany[limit], never()).getInputStream(); + } + + @Test + @SuppressWarnings("unchecked") + void processPartsAndBlock_getInputStreamThrows_emptyStringFallback() throws Exception { + Part filePart = mock(Part.class); + when(filePart.getSubmittedFileName()).thenReturn("bad.bin"); + when(filePart.getInputStream()).thenThrow(new IOException("disk error")); + RequestContext reqCtx = mockReqCtx(null, mock(TraceSegment.class)); + BiFunction, Flow> contentCb = mockPassThroughCb(); + + assertFalse( + GlassFishBlockingHelper.processPartsAndBlock( + Collections.singletonList(filePart), reqCtx, null, null, null, contentCb)); + + verify(contentCb).apply(eq(reqCtx), eq(Collections.singletonList(""))); + } + + @Test + void processPartsAndBlock_filenamesCbBlocks_contentCbNotFired() throws Exception { + Part filePart = mockPart("evil.exe", "application/octet-stream", "content".getBytes()); + TraceSegment segment = mock(TraceSegment.class); + RequestContext reqCtx = mockReqCtx(null, segment); + TestServletOutputStream out = new TestServletOutputStream(); + HttpServletResponse resp = mock(HttpServletResponse.class); + when(resp.isCommitted()).thenReturn(false); + when(resp.getOutputStream()).thenReturn(out); + BiFunction, Flow> filenamesCb = mockBlockingCb(403); + BiFunction, Flow> contentCb = mockPassThroughCb(); + + assertTrue( + GlassFishBlockingHelper.processPartsAndBlock( + Collections.singletonList(filePart), reqCtx, null, resp, filenamesCb, contentCb)); + + verify(contentCb, never()).apply(any(), any()); + } + + @Test + void processPartsAndBlock_contentCbBlocks_returnsTrue() throws Exception { + Part filePart = mockPart("upload.bin", "application/octet-stream", "payload".getBytes()); + TraceSegment segment = mock(TraceSegment.class); + RequestContext reqCtx = mockReqCtx(null, segment); + TestServletOutputStream out = new TestServletOutputStream(); + HttpServletResponse resp = mock(HttpServletResponse.class); + when(resp.isCommitted()).thenReturn(false); + when(resp.getOutputStream()).thenReturn(out); + BiFunction, Flow> filenamesCb = mockPassThroughCb(); + BiFunction, Flow> contentCb = mockBlockingCb(403); + + assertTrue( + GlassFishBlockingHelper.processPartsAndBlock( + Collections.singletonList(filePart), reqCtx, null, resp, filenamesCb, contentCb)); + } + + @Test + void processPartsAndBlock_nonPartObject_skipped() { + RequestContext reqCtx = mockReqCtx(null, mock(TraceSegment.class)); + BiFunction, Flow> filenamesCb = mockPassThroughCb(); + + assertFalse( + GlassFishBlockingHelper.processPartsAndBlock( + Collections.singletonList("not-a-part"), reqCtx, null, null, filenamesCb, null)); + + verify(filenamesCb, never()).apply(any(), any()); + } + + // ------- Helpers ------- + + private static Flow.Action.RequestBlockingAction rba(int statusCode) { + return new Flow.Action.RequestBlockingAction(statusCode, BlockingContentType.AUTO); + } + + private static RequestContext mockReqCtx(BlockResponseFunction brf, TraceSegment segment) { + RequestContext reqCtx = mock(RequestContext.class); + when(reqCtx.getBlockResponseFunction()).thenReturn(brf); + when(reqCtx.getTraceSegment()).thenReturn(segment); + return reqCtx; + } + + private static Part mockPart(String submittedFilename, String contentType, byte[] content) + throws Exception { + Part part = mock(Part.class); + when(part.getSubmittedFileName()).thenReturn(submittedFilename); + when(part.getContentType()).thenReturn(contentType); + when(part.getInputStream()).thenAnswer(ignored -> new java.io.ByteArrayInputStream(content)); + return part; + } + + @SuppressWarnings("unchecked") + private static BiFunction, Flow> mockPassThroughCb() { + BiFunction, Flow> cb = mock(BiFunction.class); + Flow flow = mock(Flow.class); + when(flow.getAction()).thenReturn(Flow.Action.Noop.INSTANCE); + when(cb.apply(any(), any())).thenReturn(flow); + return cb; + } + + @SuppressWarnings("unchecked") + private static BiFunction, Flow> mockBlockingCb( + int statusCode) { + BiFunction, Flow> cb = mock(BiFunction.class); + Flow flow = mock(Flow.class); + when(flow.getAction()) + .thenReturn(new Flow.Action.RequestBlockingAction(statusCode, BlockingContentType.AUTO)); + when(cb.apply(any(), any())).thenReturn(flow); + return cb; + } + + private static final class TestServletOutputStream extends ServletOutputStream { + private final ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + + @Override + public boolean isReady() { + return true; + } + + @Override + public void setWriteListener(WriteListener listener) {} + + @Override + public void write(int b) throws IOException { + buffer.write(b); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + buffer.write(b, off, len); + } + + public byte[] getBytes() { + return buffer.toByteArray(); + } + } +}