From acaddce3017b8a0a2f91f516a189ad5c144247cd Mon Sep 17 00:00:00 2001 From: Bjørn Christian Seime Date: Mon, 22 Nov 2021 13:32:21 +0100 Subject: Delay registration of read listener until failure wiring is complete Handle exceptions from getInputStream() and setReadListener() equally to exceptions from listener's onError(). By delaying registration the completion of finishedFuture will trigger an error response immediately. --- .../http/server/jetty/HttpRequestDispatch.java | 8 ++---- .../http/server/jetty/ServletRequestReader.java | 32 ++++++++++++++++++---- 2 files changed, 29 insertions(+), 11 deletions(-) (limited to 'container-core/src/main/java/com/yahoo/jdisc') diff --git a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java index 779a5f65673..aedbb3afb69 100644 --- a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java +++ b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java @@ -100,6 +100,8 @@ class HttpRequestDispatch { if (t != null) requestCompletion.completeExceptionally(t); else requestCompletion.complete(null); }); + // Start the reader after wiring of "finished futures" are complete + servletRequestReader.start(); } ContentChannel dispatchFilterRequest(Response response) { @@ -217,11 +219,7 @@ class HttpRequestDispatch { HttpRequestFactory.copyHeaders(jettyRequest, jdiscRequest); requestContentChannel = requestHandler.handleRequest(jdiscRequest, servletResponseController.responseHandler()); } - //TODO If the below method throws servletRequestReader will not complete and - // requestContentChannel will not be closed and there is a reference leak - // Ditto for the servletInputStream - return new ServletRequestReader( - jettyRequest.getInputStream(), requestContentChannel, jDiscContext.janitor, metricReporter); + return new ServletRequestReader(jettyRequest, requestContentChannel, jDiscContext.janitor, metricReporter); } private static RequestHandler newRequestHandler(JDiscContext context, diff --git a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletRequestReader.java b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletRequestReader.java index 1def9ccaab1..43050a53f58 100644 --- a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletRequestReader.java +++ b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletRequestReader.java @@ -6,6 +6,7 @@ import com.yahoo.jdisc.handler.ContentChannel; import javax.servlet.ReadListener; import javax.servlet.ServletInputStream; +import javax.servlet.http.HttpServletRequest; import java.io.IOException; import java.nio.ByteBuffer; import java.util.Objects; @@ -33,7 +34,7 @@ import java.util.logging.Logger; class ServletRequestReader { private enum State { - READING, ALL_DATA_READ, REQUEST_CONTENT_CLOSED + NOT_STARTED, READING, ALL_DATA_READ, REQUEST_CONTENT_CLOSED } private static final Logger log = Logger.getLogger(ServletRequestReader.class.getName()); @@ -42,11 +43,12 @@ class ServletRequestReader { private final Object monitor = new Object(); - private final ServletInputStream in; + private final HttpServletRequest req; private final ContentChannel requestContentChannel; private final Janitor janitor; private final RequestMetricReporter metricReporter; + private ServletInputStream in; private Throwable errorDuringRead; private int bytesRead; @@ -63,7 +65,7 @@ class ServletRequestReader { * (i.e. when being called from user code, don't call back into user code.) */ // GuardedBy("monitor") - private State state = State.READING; + private State state = State.NOT_STARTED; /** * Number of calls that we're waiting for from user code. @@ -94,15 +96,31 @@ class ServletRequestReader { private final CompletableFuture finishedFuture = new CompletableFuture<>(); ServletRequestReader( - ServletInputStream in, + HttpServletRequest req, ContentChannel requestContentChannel, Janitor janitor, RequestMetricReporter metricReporter) { - this.in = Objects.requireNonNull(in); + this.req = Objects.requireNonNull(req); this.requestContentChannel = Objects.requireNonNull(requestContentChannel); this.janitor = Objects.requireNonNull(janitor); this.metricReporter = Objects.requireNonNull(metricReporter); - in.setReadListener(new Listener()); + } + + /** Register read listener to start reading request data */ + void start() { + try { + ServletInputStream in; + synchronized (monitor) { + if (state != State.NOT_STARTED) throw new IllegalStateException("State=" + state); + in = req.getInputStream(); // may throw + this.in = in; + state = State.READING; + } + // Not holding monitor in case listener is invoked from this thread + in.setReadListener(new Listener()); // may throw + } catch (Throwable t) { + fail(t); + } } CompletableFuture finishedFuture() { return finishedFuture; } @@ -111,6 +129,8 @@ class ServletRequestReader { @Override public void onDataAvailable() throws IOException { + ServletInputStream in; + synchronized (monitor) { in = ServletRequestReader.this.in; } while (in.isReady()) { final byte[] buffer = new byte[BUFFER_SIZE_BYTES]; int numBytesRead; -- cgit v1.2.3