From 79a55aeadac600f8f2e32ed17b9832854bf851e7 Mon Sep 17 00:00:00 2001 From: Bjørn Christian Seime Date: Thu, 7 Oct 2021 11:08:23 +0200 Subject: Refactor async completion logic for read and write path Ensure that failure from request read, write or async context correctly terminates request processing. --- .../server/jetty/ErrorResponseContentCreator.java | 7 +- .../http/server/jetty/HttpRequestDispatch.java | 123 ++++++------ .../server/jetty/JDiscFilterInvokerFilter.java | 2 +- .../jdisc/http/server/jetty/JDiscHttpServlet.java | 2 +- .../server/jetty/ServletOutputStreamWriter.java | 62 ++---- .../http/server/jetty/ServletRequestReader.java | 220 +++++++++------------ .../server/jetty/ServletResponseController.java | 193 ++++++++---------- .../jetty/ErrorResponseContentCreatorTest.java | 3 +- 8 files changed, 258 insertions(+), 354 deletions(-) (limited to 'container-core') diff --git a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ErrorResponseContentCreator.java b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ErrorResponseContentCreator.java index cd21dccde0e..e33f8fd178d 100644 --- a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ErrorResponseContentCreator.java +++ b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ErrorResponseContentCreator.java @@ -5,7 +5,6 @@ import org.eclipse.jetty.util.ByteArrayISO8859Writer; import org.eclipse.jetty.util.StringUtil; import java.io.IOException; -import java.util.Optional; /** * Creates HTML body having the status code, error message and request uri. @@ -14,12 +13,12 @@ import java.util.Optional; * * @author bjorncs */ -public class ErrorResponseContentCreator { +class ErrorResponseContentCreator { private final ByteArrayISO8859Writer writer = new ByteArrayISO8859Writer(2048); - public byte[] createErrorContent(String requestUri, int statusCode, Optional message) { - String sanitizedString = message.map(StringUtil::sanitizeXmlString).orElse(""); + byte[] createErrorContent(String requestUri, int statusCode, String message) { + String sanitizedString = message != null ? StringUtil.sanitizeXmlString(message) : ""; String statusCodeString = Integer.toString(statusCode); writer.resetWriter(); try { diff --git a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java index 512d78d4537..e81d060a938 100644 --- a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java +++ b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/HttpRequestDispatch.java @@ -22,7 +22,8 @@ import org.eclipse.jetty.server.Request; import org.eclipse.jetty.util.Callback; import javax.servlet.AsyncContext; -import javax.servlet.ServletInputStream; +import javax.servlet.AsyncEvent; +import javax.servlet.AsyncListener; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.IOException; @@ -32,9 +33,6 @@ import java.util.Arrays; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.TimeoutException; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.BiConsumer; -import java.util.function.Consumer; import java.util.logging.Level; import java.util.logging.Logger; @@ -53,16 +51,13 @@ class HttpRequestDispatch { private final static String CHARSET_ANNOTATION = ";charset="; private final JDiscContext jDiscContext; - private final AsyncContext async; private final Request jettyRequest; private final ServletResponseController servletResponseController; private final RequestHandler requestHandler; private final RequestMetricReporter metricReporter; - private final BiConsumer completeRequestCallback; - private final AtomicBoolean completeRequestCalled = new AtomicBoolean(false); - public HttpRequestDispatch(JDiscContext jDiscContext, + HttpRequestDispatch(JDiscContext jDiscContext, AccessLogEntry accessLogEntry, Context metricContext, HttpServletRequest servletRequest, @@ -79,45 +74,71 @@ class HttpRequestDispatch { metricReporter, jDiscContext.developerMode()); shutdownConnectionGracefullyIfThresholdReached(jettyRequest); - this.async = servletRequest.startAsync(); - async.setTimeout(0); metricReporter.uriLength(jettyRequest.getOriginalURI().length()); - completeRequestCallback = this::handleCompleteRequestCallback; } - public void dispatch() throws IOException { + void dispatchRequest() { + CompletableFuture requestCompletion = startServletAsyncExecution(); ServletRequestReader servletRequestReader; try { servletRequestReader = handleRequest(); - } catch (Throwable throwable) { - servletResponseController.trySendError(throwable); - servletResponseController.finishedFuture().whenComplete((result, exception) -> - completeRequestCallback.accept(null, throwable)); + } catch (Throwable t) { + servletResponseController.finishedFuture() + .whenComplete((__, ___) -> requestCompletion.completeExceptionally(t)); + servletResponseController.fail(t); return; } - try { - onError(servletRequestReader.finishedFuture, servletResponseController::trySendError); - onError(servletResponseController.finishedFuture(), servletRequestReader::onError); - CompletableFuture.allOf(servletRequestReader.finishedFuture, servletResponseController.finishedFuture()) - .whenComplete(completeRequestCallback); - } catch (Throwable throwable) { - log.log(Level.WARNING, "Failed registering finished listeners.", throwable); - } + servletRequestReader.finishedFuture().whenComplete((__, t) -> { + if (t != null) servletResponseController.fail(t); + }); + servletResponseController.finishedFuture().whenComplete((__, t) -> { + if (t != null) servletRequestReader.fail(t); + }); + CompletableFuture.allOf(servletRequestReader.finishedFuture(), servletResponseController.finishedFuture()) + .whenComplete((r, t) -> { + if (t != null) requestCompletion.completeExceptionally(t); + else requestCompletion.complete(null); + }); } + ContentChannel dispatchFilterRequest(Response response) { + try { - private void handleCompleteRequestCallback(Void result, Throwable error) - { - boolean alreadyCalled = completeRequestCalled.getAndSet(true); - if (alreadyCalled) { - AssertionError e = new AssertionError("completeRequest called more than once"); - log.log(Level.WARNING, "Assertion failed.", e); - throw e; + CompletableFuture requestCompletion = startServletAsyncExecution(); + jettyRequest.getInputStream().close(); + ContentChannel responseContentChannel = servletResponseController.responseHandler().handleResponse(response); + servletResponseController.finishedFuture() + .whenComplete((r, t) -> { + if (t != null) requestCompletion.completeExceptionally(t); + else requestCompletion.complete(null); + }); + return responseContentChannel; + } catch (IOException e) { + throw throwUnchecked(e); } + } - boolean reportedError = false; + private CompletableFuture startServletAsyncExecution() { + CompletableFuture requestCompletion = new CompletableFuture<>(); + AsyncContext asyncCtx = jettyRequest.startAsync(); + asyncCtx.setTimeout(0); + asyncCtx.addListener(new AsyncListener() { + @Override public void onStartAsync(AsyncEvent event) {} + @Override public void onComplete(AsyncEvent event) { requestCompletion.complete(null); } + @Override public void onTimeout(AsyncEvent event) { + requestCompletion.completeExceptionally(new TimeoutException("Timeout from AsyncContext")); + } + @Override public void onError(AsyncEvent event) { + requestCompletion.completeExceptionally(event.getThrowable()); + } + }); + requestCompletion.whenComplete((__, t) -> onRequestFinished(asyncCtx, t)); + return requestCompletion; + } + private void onRequestFinished(AsyncContext asyncCtx, Throwable error) { + boolean reportedError = false; if (error != null) { if (isErrorOfType(error, EofException.class, IOException.class)) { log.log(Level.FINE, @@ -138,7 +159,7 @@ class HttpRequestDispatch { } try { - async.complete(); + asyncCtx.complete(); log.finest(() -> "Request completed successfully: " + jettyRequest.getRequestURI()); } catch (Throwable throwable) { Level level = reportedError ? Level.FINE: Level.WARNING; @@ -190,47 +211,17 @@ class HttpRequestDispatch { private ServletRequestReader handleRequest() throws IOException { HttpRequest jdiscRequest = HttpRequestFactory.newJDiscRequest(jDiscContext.container, jettyRequest); ContentChannel requestContentChannel; - try (ResourceReference ref = References.fromResource(jdiscRequest)) { HttpRequestFactory.copyHeaders(jettyRequest, jdiscRequest); - requestContentChannel = requestHandler.handleRequest(jdiscRequest, servletResponseController.responseHandler); + requestContentChannel = requestHandler.handleRequest(jdiscRequest, servletResponseController.responseHandler()); } - - //TODO If the below method throws requestContentChannel will not be close and there is a reference leak - ServletInputStream servletInputStream = jettyRequest.getInputStream(); - - ServletRequestReader servletRequestReader = new ServletRequestReader(servletInputStream, - requestContentChannel, - jDiscContext.janitor, - metricReporter); - //TODO If the below method throws servletRequestReader will not complete and // requestContentChannel will not be closed and there is a reference leak // Ditto for the servletInputStream - servletInputStream.setReadListener(servletRequestReader); - return servletRequestReader; + return new ServletRequestReader( + jettyRequest.getInputStream(), requestContentChannel, jDiscContext.janitor, metricReporter); } - private static void onError(CompletableFuture future, Consumer errorHandler) { - future.whenComplete((result, exception) -> { - if (exception != null) { - errorHandler.accept(exception); - } - }); - } - - ContentChannel handleRequestFilterResponse(Response response) { - try { - jettyRequest.getInputStream().close(); - ContentChannel responseContentChannel = servletResponseController.responseHandler.handleResponse(response); - servletResponseController.finishedFuture().whenComplete(completeRequestCallback); - return responseContentChannel; - } catch (IOException e) { - throw throwUnchecked(e); - } - } - - private static RequestHandler newRequestHandler(JDiscContext context, AccessLogEntry accessLogEntry, HttpServletRequest servletRequest) { diff --git a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscFilterInvokerFilter.java b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscFilterInvokerFilter.java index 2904d79ad41..304585c6176 100644 --- a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscFilterInvokerFilter.java +++ b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscFilterInvokerFilter.java @@ -126,7 +126,7 @@ class JDiscFilterInvokerFilter implements Filter { throw new RuntimeException("Can't return response from filter asynchronously"); HttpRequestDispatch requestDispatch = createRequestDispatch(httpRequest, httpResponse); - return requestDispatch.handleRequestFilterResponse(jdiscResponse); + return requestDispatch.dispatchFilterRequest(jdiscResponse); }; } diff --git a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscHttpServlet.java b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscHttpServlet.java index 7e1445ffa4f..5451097b717 100644 --- a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscHttpServlet.java +++ b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/JDiscHttpServlet.java @@ -110,7 +110,7 @@ class JDiscHttpServlet extends HttpServlet { try { switch (request.getDispatcherType()) { case REQUEST: - new HttpRequestDispatch(context, accessLogEntry, getMetricContext(request), request, response).dispatch(); + new HttpRequestDispatch(context, accessLogEntry, getMetricContext(request), request, response).dispatchRequest(); break; default: if (log.isLoggable(Level.INFO)) { diff --git a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletOutputStreamWriter.java b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletOutputStreamWriter.java index 696fd2d51ad..72e3f3255a3 100644 --- a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletOutputStreamWriter.java +++ b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletOutputStreamWriter.java @@ -10,7 +10,6 @@ import java.nio.ByteBuffer; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Deque; -import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.function.Consumer; import java.util.logging.Level; @@ -22,7 +21,7 @@ import static com.yahoo.jdisc.http.server.jetty.CompletionHandlerUtils.NOOP_COMP * @author Tony Vaagenes * @author bjorncs */ -public class ServletOutputStreamWriter { +class ServletOutputStreamWriter { /** Rules: * 1) Don't modify the output stream without isReady returning true (write/flush/close). * Multiple modification calls without interleaving isReady calls are not allowed. @@ -66,31 +65,16 @@ public class ServletOutputStreamWriter { * * The future might complete in the servlet framework thread, user thread or executor thread. */ - final CompletableFuture finishedFuture = new CompletableFuture<>(); + private final CompletableFuture finishedFuture = new CompletableFuture<>(); - public ServletOutputStreamWriter(ServletOutputStream outputStream, Janitor janitor, RequestMetricReporter metricReporter) { + ServletOutputStreamWriter(ServletOutputStream outputStream, Janitor janitor, RequestMetricReporter metricReporter) { this.outputStream = outputStream; this.janitor = janitor; this.metricReporter = metricReporter; } - public void sendErrorContentAndCloseAsync(ByteBuffer errorContent) { - synchronized (monitor) { - // Assert that no content has been written as it is too late to write error response if the response is committed. - assertStateIs(state, State.NOT_STARTED); - queueErrorContent_holdingLock(errorContent); - state = State.WAITING_FOR_WRITE_POSSIBLE_CALLBACK; - outputStream.setWriteListener(writeListener); - } - } - - private void queueErrorContent_holdingLock(ByteBuffer errorContent) { - responseContentQueue.addLast(new ResponseContentPart(errorContent, NOOP_COMPLETION_HANDLER)); - responseContentQueue.addLast(new ResponseContentPart(CLOSE_STREAM_BUFFER, NOOP_COMPLETION_HANDLER)); - } - - public void writeBuffer(ByteBuffer buf, CompletionHandler handler) { + void writeBuffer(ByteBuffer buf, CompletionHandler handler) { boolean thisThreadShouldWrite = false; synchronized (monitor) { @@ -121,13 +105,13 @@ public class ServletOutputStreamWriter { } } - public void close(CompletionHandler handler) { - writeBuffer(CLOSE_STREAM_BUFFER, handler); - } + void fail(Throwable t) { setFinished(t); } - public void close() { - close(NOOP_COMPLETION_HANDLER); - } + void close(CompletionHandler handler) { writeBuffer(CLOSE_STREAM_BUFFER, handler); } + + void close() { close(NOOP_COMPLETION_HANDLER); } + + CompletableFuture finishedFuture() { return finishedFuture; } private void writeBuffersInQueueToOutputStream() { boolean lastOperationWasFlush = false; @@ -165,29 +149,28 @@ public class ServletOutputStreamWriter { if (contentPart.buf == CLOSE_STREAM_BUFFER) { callCompletionHandlerWhenDone(contentPart.handler, outputStream::close); - setFinished(Optional.empty()); + setFinished(null); return; } else { writeBufferToOutputStream(contentPart); } - } catch (Throwable e) { - setFinished(Optional.of(e)); + } catch (Throwable t) { + setFinished(t); return; } } } - private void setFinished(Optional e) { + private void setFinished(Throwable t) { synchronized (monitor) { state = State.FINISHED_OR_ERROR; if (!responseContentQueue.isEmpty()) { - failAllParts_holdingLock(e.orElse(new IllegalStateException("ContentChannel closed."))); + failAllParts_holdingLock(t != null ? t : new IllegalStateException("ContentChannel closed.")); } } - assert !Thread.holdsLock(monitor); - if (e.isPresent()) { - finishedFuture.completeExceptionally(e.get()); + if (t != null) { + finishedFuture.completeExceptionally(t); } else { finishedFuture.complete(null); } @@ -255,13 +238,9 @@ public class ServletOutputStreamWriter { } } - public void fail(Throwable t) { - setFinished(Optional.of(t)); - } - private final WriteListener writeListener = new WriteListener() { @Override - public void onWritePossible() throws IOException { + public void onWritePossible() { synchronized (monitor) { if (state == State.FINISHED_OR_ERROR) { return; @@ -274,10 +253,7 @@ public class ServletOutputStreamWriter { writeBuffersInQueueToOutputStream(); } - @Override - public void onError(Throwable t) { - setFinished(Optional.of(t)); - } + @Override public void onError(Throwable t) { setFinished(t); } }; private static class ResponseContentPart { diff --git a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletRequestReader.java b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletRequestReader.java index e2bf5711e15..666132087af 100644 --- a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletRequestReader.java +++ b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletRequestReader.java @@ -1,7 +1,6 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.jdisc.http.server.jetty; -import com.google.common.base.Preconditions; import com.yahoo.jdisc.handler.CompletionHandler; import com.yahoo.jdisc.handler.ContentChannel; @@ -9,6 +8,7 @@ import javax.servlet.ReadListener; import javax.servlet.ServletInputStream; import java.io.IOException; import java.nio.ByteBuffer; +import java.util.Objects; import java.util.concurrent.CompletableFuture; import java.util.logging.Level; import java.util.logging.Logger; @@ -25,8 +25,12 @@ import java.util.logging.Logger; * error reporting might be async. * Since we have tests that first reports errors and then closes the response content, * it's important that errors are delivered synchronously. + * + * @author Tony Vaagenes + * @author Oyvind Bakksjo + * @author bjorncs */ -class ServletRequestReader implements ReadListener { +class ServletRequestReader { private enum State { READING, ALL_DATA_READ, REQUEST_CONTENT_CLOSED @@ -38,12 +42,12 @@ class ServletRequestReader implements ReadListener { private final Object monitor = new Object(); - private final ServletInputStream servletInputStream; + private final ServletInputStream in; private final ContentChannel requestContentChannel; - private final Janitor janitor; private final RequestMetricReporter metricReporter; + private Throwable errorDuringRead; private int bytesRead; /** @@ -87,82 +91,91 @@ class ServletRequestReader implements ReadListener { * If calls to those methods does not close the request content channel immediately, * there is some outstanding completion callback that will later come in and complete the request. */ - final CompletableFuture finishedFuture = new CompletableFuture<>(); + private final CompletableFuture finishedFuture = new CompletableFuture<>(); - public ServletRequestReader( - ServletInputStream servletInputStream, + ServletRequestReader( + ServletInputStream in, ContentChannel requestContentChannel, Janitor janitor, RequestMetricReporter metricReporter) { - - Preconditions.checkNotNull(servletInputStream); - Preconditions.checkNotNull(requestContentChannel); - Preconditions.checkNotNull(janitor); - Preconditions.checkNotNull(metricReporter); - - this.servletInputStream = servletInputStream; - this.requestContentChannel = requestContentChannel; - this.janitor = janitor; - this.metricReporter = metricReporter; + this.in = Objects.requireNonNull(in); + this.requestContentChannel = Objects.requireNonNull(requestContentChannel); + this.janitor = Objects.requireNonNull(janitor); + this.metricReporter = Objects.requireNonNull(metricReporter); + in.setReadListener(new Listener()); } - @Override - public void onDataAvailable() throws IOException { - while (servletInputStream.isReady()) { - final byte[] buffer = new byte[BUFFER_SIZE_BYTES]; - int numBytesRead; + CompletableFuture finishedFuture() { return finishedFuture; } - synchronized (monitor) { - numBytesRead = servletInputStream.read(buffer); - if (numBytesRead < 0) { - // End of stream; there should be no more data available, ever. - return; - } - if (state != State.READING) { - //We have a failure, so no point in giving the buffer to the user. - assert finishedFuture.isCompletedExceptionally(); - return; + private class Listener implements ReadListener { + + @Override + public void onDataAvailable() throws IOException { + while (in.isReady()) { + final byte[] buffer = new byte[BUFFER_SIZE_BYTES]; + int numBytesRead; + + synchronized (monitor) { + numBytesRead = in.read(buffer); + if (numBytesRead < 0) { + // End of stream; there should be no more data available, ever. + return; + } + if (state != State.READING) { + //We have a failure, so no point in giving the buffer to the user. + assert finishedFuture.isCompletedExceptionally(); + return; + } + //wait for both + // - requestContentChannel.write to finish + // - the write completion handler to be called + numberOfOutstandingUserCalls += 2; + bytesRead += numBytesRead; } - //wait for both - // - requestContentChannel.write to finish - // - the write completion handler to be called - numberOfOutstandingUserCalls += 2; - bytesRead += numBytesRead; - } - try { - requestContentChannel.write(ByteBuffer.wrap(buffer, 0, numBytesRead), writeCompletionHandler); - metricReporter.successfulRead(numBytesRead); - } - catch (Throwable t) { - finishedFuture.completeExceptionally(t); - } - finally { - //decrease due to this method completing. - decreaseOutstandingUserCallsAndCloseRequestContentChannelConditionally(); + try { + requestContentChannel.write(ByteBuffer.wrap(buffer, 0, numBytesRead), new CompletionHandler() { + @Override + public void completed() { + decreaseOutstandingUserCallsAndCloseRequestContentChannelConditionally(); + } + @Override + public void failed(final Throwable t) { + decreaseOutstandingUserCallsAndCloseRequestContentChannelConditionally(); + finishedFuture.completeExceptionally(t); + } + }); + metricReporter.successfulRead(numBytesRead); + } catch (Throwable t) { + finishedFuture.completeExceptionally(t); + } finally { + //decrease due to this method completing. + decreaseOutstandingUserCallsAndCloseRequestContentChannelConditionally(); + } } } + + @Override public void onError(final Throwable t) { fail(t); } + @Override public void onAllDataRead() { doneReading(null); } + } + + void fail(Throwable t) { + doneReading(t); + finishedFuture.completeExceptionally(t); } private void decreaseOutstandingUserCallsAndCloseRequestContentChannelConditionally() { boolean shouldCloseRequestContentChannel; - synchronized (monitor) { assertStateNotEquals(state, State.REQUEST_CONTENT_CLOSED); - - numberOfOutstandingUserCalls -= 1; - - shouldCloseRequestContentChannel = numberOfOutstandingUserCalls == 0 && - (finishedFuture.isDone() || state == State.ALL_DATA_READ); - + shouldCloseRequestContentChannel = numberOfOutstandingUserCalls == 0 && state == State.ALL_DATA_READ; if (shouldCloseRequestContentChannel) { state = State.REQUEST_CONTENT_CLOSED; } } - if (shouldCloseRequestContentChannel) { - janitor.scheduleTask(this::closeCompletionHandler_noThrow); + janitor.scheduleTask(this::closeRequestContentChannel); } } @@ -178,22 +191,14 @@ class ServletRequestReader implements ReadListener { } } - @Override - public void onAllDataRead() { - doneReading(); - } - - private void doneReading() { - final boolean shouldCloseRequestContentChannel; - + private void doneReading(Throwable t) { + boolean shouldCloseRequestContentChannel; int bytesRead; - synchronized (monitor) { - if (state != State.READING) { - return; - } + synchronized (monitor) { + errorDuringRead = t; + if (state != State.READING) return; state = State.ALL_DATA_READ; - shouldCloseRequestContentChannel = numberOfOutstandingUserCalls == 0; if (shouldCloseRequestContentChannel) { state = State.REQUEST_CONTENT_CLOSED; @@ -202,69 +207,32 @@ class ServletRequestReader implements ReadListener { } if (shouldCloseRequestContentChannel) { - closeCompletionHandler_noThrow(); + closeRequestContentChannel(); } - metricReporter.contentSize(bytesRead); } - private void closeCompletionHandler_noThrow() { - //Cannot complete finishedFuture directly in completed(), as any exceptions after this fact will be ignored. - // E.g. - // close(CompletionHandler completionHandler) { - // completionHandler.completed(); - // throw new RuntimeException - // } - - CompletableFuture completedCalledFuture = new CompletableFuture<>(); - - CompletionHandler closeCompletionHandler = new CompletionHandler() { - @Override - public void completed() { - completedCalledFuture.complete(null); - } - - @Override - public void failed(final Throwable t) { - finishedFuture.completeExceptionally(t); - } - }; - + private void closeRequestContentChannel() { + Throwable readError; + synchronized (monitor) { readError = this.errorDuringRead; } try { - requestContentChannel.close(closeCompletionHandler); - //if close did not cause an exception, - // is it safe to pipe the result of the completionHandlerInvokedFuture into finishedFuture - completedCalledFuture.whenComplete(this::setFinishedFuture); - } catch (final Throwable t) { + if (readError != null) requestContentChannel.onError(readError); + //Cannot complete finishedFuture directly in completed(), as any exceptions after this fact will be ignored. + // E.g. + // close(CompletionHandler completionHandler) { + // completionHandler.completed(); + // throw new RuntimeException + // } + CompletableFuture completedCalledFuture = new CompletableFuture<>(); + requestContentChannel.close(new CompletionHandler() { + @Override public void completed() { completedCalledFuture.complete(null); } + @Override public void failed(Throwable t) { finishedFuture.completeExceptionally(t); } + }); + // Propagate successful completion as close did not throw an exception + completedCalledFuture.whenComplete((__, ___) -> finishedFuture.complete(null)); + } catch (Throwable t) { finishedFuture.completeExceptionally(t); } } - private void setFinishedFuture(Void result, Throwable throwable) { - if (throwable != null) { - finishedFuture.completeExceptionally(throwable); - } else { - finishedFuture.complete(null); - } - } - - @Override - public void onError(final Throwable t) { - finishedFuture.completeExceptionally(t); - requestContentChannel.onError(t); - doneReading(); - } - - private final CompletionHandler writeCompletionHandler = new CompletionHandler() { - @Override - public void completed() { - decreaseOutstandingUserCallsAndCloseRequestContentChannelConditionally(); - } - - @Override - public void failed(final Throwable t) { - finishedFuture.completeExceptionally(t); - decreaseOutstandingUserCallsAndCloseRequestContentChannelConditionally(); - } - }; } diff --git a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletResponseController.java b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletResponseController.java index d61a3745653..2b5b7a3420f 100644 --- a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletResponseController.java +++ b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletResponseController.java @@ -30,9 +30,9 @@ import static com.yahoo.jdisc.http.server.jetty.CompletionHandlerUtils.NOOP_COMP * @author Tony Vaagenes * @author bjorncs */ -public class ServletResponseController { +class ServletResponseController { - private static Logger log = Logger.getLogger(ServletResponseController.class.getName()); + private static final Logger log = Logger.getLogger(ServletResponseController.class.getName()); /** * The servlet spec does not require (Http)ServletResponse nor ServletOutputStream to be thread-safe. Therefore, @@ -49,12 +49,12 @@ public class ServletResponseController { private final ErrorResponseContentCreator errorResponseContentCreator = new ErrorResponseContentCreator(); //all calls to the servletOutputStreamWriter must hold the monitor first to ensure visibility of servletResponse changes. - private final ServletOutputStreamWriter servletOutputStreamWriter; + private final ServletOutputStreamWriter out; // GuardedBy("monitor") private boolean responseCommitted = false; - public ServletResponseController( + ServletResponseController( HttpServletRequest servletRequest, HttpServletResponse servletResponse, Janitor janitor, @@ -64,10 +64,61 @@ public class ServletResponseController { this.servletRequest = servletRequest; this.servletResponse = servletResponse; this.developerMode = developerMode; - this.servletOutputStreamWriter = - new ServletOutputStreamWriter(servletResponse.getOutputStream(), janitor, metricReporter); + this.out = new ServletOutputStreamWriter(servletResponse.getOutputStream(), janitor, metricReporter); } + void fail(Throwable t) { + try { + trySendError(t); + } catch (Throwable suppressed) { + t.addSuppressed(suppressed); + } finally { + out.close(); + } + } + + /** + * When this future completes there will be no more calls against the servlet output stream or servlet response. + * The framework is still allowed to invoke us though. + * + * The future might complete in the servlet framework thread, user thread or executor thread. + */ + CompletableFuture finishedFuture() { return out.finishedFuture(); } + + ResponseHandler responseHandler() { return responseHandler; } + + private void trySendError(Throwable t) { + synchronized (monitor) { + if (!responseCommitted) { + responseCommitted = true; + servletResponse.setHeader(HttpHeaders.Names.EXPIRES, null); + servletResponse.setHeader(HttpHeaders.Names.LAST_MODIFIED, null); + servletResponse.setHeader(HttpHeaders.Names.CACHE_CONTROL, null); + servletResponse.setHeader(HttpHeaders.Names.CONTENT_TYPE, null); + servletResponse.setHeader(HttpHeaders.Names.CONTENT_LENGTH, null); + String reasonPhrase = getReasonPhrase(t, developerMode); + int statusCode = getStatusCode(t); + setStatus(servletResponse, statusCode, reasonPhrase); + // If we are allowed to have a body + if (statusCode != HttpServletResponse.SC_NO_CONTENT && + statusCode != HttpServletResponse.SC_NOT_MODIFIED && + statusCode != HttpServletResponse.SC_PARTIAL_CONTENT && + statusCode >= HttpServletResponse.SC_OK) { + servletResponse.setHeader(HttpHeaders.Names.CACHE_CONTROL, "must-revalidate,no-cache,no-store"); + servletResponse.setContentType(MimeTypes.Type.TEXT_HTML_8859_1.toString()); + byte[] errorContent = errorResponseContentCreator + .createErrorContent(servletRequest.getRequestURI(), statusCode, reasonPhrase); + servletResponse.setContentLength(errorContent.length); + out.writeBuffer(ByteBuffer.wrap(errorContent), NOOP_COMPLETION_HANDLER); + } else { + servletResponse.setContentLength(0); + } + } else { + RuntimeException exceptionWithStackTrace = new RuntimeException(t); + log.log(Level.FINE, "Response already committed, can't change response code", exceptionWithStackTrace); + } + } + } private static int getStatusCode(Throwable t) { if (t instanceof BindingNotFoundException) { @@ -96,75 +147,6 @@ public class ServletResponseController { } } - - public void trySendError(Throwable t) { - final boolean responseWasCommitted; - try { - synchronized (monitor) { - String reasonPhrase = getReasonPhrase(t, developerMode); - int statusCode = getStatusCode(t); - responseWasCommitted = responseCommitted; - if (!responseCommitted) { - responseCommitted = true; - sendErrorAsync(statusCode, reasonPhrase); - } - } - } catch (Throwable e) { - servletOutputStreamWriter.fail(t); - return; - } - - //Must be evaluated after state transition for test purposes(See ConformanceTestException) - //Done outside the monitor since it causes a callback in tests. - if (responseWasCommitted) { - RuntimeException exceptionWithStackTrace = new RuntimeException(t); - log.log(Level.FINE, "Response already committed, can't change response code", exceptionWithStackTrace); - // TODO: should always have failed here, but that breaks test assumptions. Doing soft close instead. - //assert !Thread.holdsLock(monitor); - //servletOutputStreamWriter.fail(t); - servletOutputStreamWriter.close(); - } - - } - - /** - * Async version of {@link org.eclipse.jetty.server.Response#sendError(int, String)}. - */ - private void sendErrorAsync(int statusCode, String reasonPhrase) { - servletResponse.setHeader(HttpHeaders.Names.EXPIRES, null); - servletResponse.setHeader(HttpHeaders.Names.LAST_MODIFIED, null); - servletResponse.setHeader(HttpHeaders.Names.CACHE_CONTROL, null); - servletResponse.setHeader(HttpHeaders.Names.CONTENT_TYPE, null); - servletResponse.setHeader(HttpHeaders.Names.CONTENT_LENGTH, null); - setStatus(servletResponse, statusCode, Optional.of(reasonPhrase)); - - // If we are allowed to have a body - if (statusCode != HttpServletResponse.SC_NO_CONTENT && - statusCode != HttpServletResponse.SC_NOT_MODIFIED && - statusCode != HttpServletResponse.SC_PARTIAL_CONTENT && - statusCode >= HttpServletResponse.SC_OK) { - servletResponse.setHeader(HttpHeaders.Names.CACHE_CONTROL, "must-revalidate,no-cache,no-store"); - servletResponse.setContentType(MimeTypes.Type.TEXT_HTML_8859_1.toString()); - byte[] errorContent = errorResponseContentCreator - .createErrorContent(servletRequest.getRequestURI(), statusCode, Optional.ofNullable(reasonPhrase)); - servletResponse.setContentLength(errorContent.length); - servletOutputStreamWriter.sendErrorContentAndCloseAsync(ByteBuffer.wrap(errorContent)); - } else { - servletResponse.setContentLength(0); - servletOutputStreamWriter.close(); - } - } - - /** - * When this future completes there will be no more calls against the servlet output stream or servlet response. - * The framework is still allowed to invoke us though. - * - * The future might complete in the servlet framework thread, user thread or executor thread. - */ - public CompletableFuture finishedFuture() { - return servletOutputStreamWriter.finishedFuture; - } - private void setResponse(Response jdiscResponse) { synchronized (monitor) { servletRequest.setAttribute(HttpResponseStatisticsCollector.requestTypeAttribute, jdiscResponse.getRequestType()); @@ -176,57 +158,46 @@ public class ServletResponseController { //TODO: should throw an exception here, but this breaks unit tests. //The failures will now instead happen when writing buffers. - servletOutputStreamWriter.close(); + out.close(); return; } - setStatus_holdingLock(jdiscResponse, servletResponse); - setHeaders_holdingLock(jdiscResponse, servletResponse); - } - } - - private static void setHeaders_holdingLock(Response jdiscResponse, HttpServletResponse servletResponse) { - for (final Map.Entry entry : jdiscResponse.headers().entries()) { - servletResponse.addHeader(entry.getKey(), entry.getValue()); - } - - if (servletResponse.getContentType() == null) { - servletResponse.setContentType("text/plain;charset=utf-8"); - } - } - - private static void setStatus_holdingLock(Response jdiscResponse, HttpServletResponse servletResponse) { - if (jdiscResponse instanceof HttpResponse) { - setStatus(servletResponse, jdiscResponse.getStatus(), Optional.ofNullable(((HttpResponse) jdiscResponse).getMessage())); - } else { - setStatus(servletResponse, jdiscResponse.getStatus(), getErrorMessage(jdiscResponse)); + if (jdiscResponse instanceof HttpResponse) { + setStatus(servletResponse, jdiscResponse.getStatus(), ((HttpResponse) jdiscResponse).getMessage()); + } else { + String message = Optional.ofNullable(jdiscResponse.getError()) + .flatMap(error -> Optional.ofNullable(error.getMessage())) + .orElse(null); + setStatus(servletResponse, jdiscResponse.getStatus(), message); + } + for (final Map.Entry entry : jdiscResponse.headers().entries()) { + servletResponse.addHeader(entry.getKey(), entry.getValue()); + } + if (servletResponse.getContentType() == null) { + servletResponse.setContentType("text/plain;charset=utf-8"); + } } } @SuppressWarnings("deprecation") - private static void setStatus(HttpServletResponse response, int statusCode, Optional reasonPhrase) { - if (reasonPhrase.isPresent()) { + private static void setStatus(HttpServletResponse response, int statusCode, String reasonPhrase) { + if (reasonPhrase != null) { // Sets the status line: a status code along with a custom message. // Using a custom status message is deprecated in the Servlet API. No alternative exist. - response.setStatus(statusCode, reasonPhrase.get()); // DEPRECATED + response.setStatus(statusCode, reasonPhrase); // DEPRECATED } else { response.setStatus(statusCode); } } - private static Optional getErrorMessage(Response jdiscResponse) { - return Optional.ofNullable(jdiscResponse.getError()).flatMap( - error -> Optional.ofNullable(error.getMessage())); - } - - private void commitResponse() { + private void ensureCommitted() { synchronized (monitor) { responseCommitted = true; } } - public final ResponseHandler responseHandler = new ResponseHandler() { + private final ResponseHandler responseHandler = new ResponseHandler() { @Override public ContentChannel handleResponse(Response response) { setResponse(response); @@ -234,17 +205,17 @@ public class ServletResponseController { } }; - public final ContentChannel responseContentChannel = new ContentChannel() { + private final ContentChannel responseContentChannel = new ContentChannel() { @Override public void write(ByteBuffer buf, CompletionHandler handler) { - commitResponse(); - servletOutputStreamWriter.writeBuffer(buf, handlerOrNoopHandler(handler)); + ensureCommitted(); + out.writeBuffer(buf, handlerOrNoopHandler(handler)); } @Override public void close(CompletionHandler handler) { - commitResponse(); - servletOutputStreamWriter.close(handlerOrNoopHandler(handler)); + ensureCommitted(); + out.close(handlerOrNoopHandler(handler)); } private CompletionHandler handlerOrNoopHandler(CompletionHandler handler) { diff --git a/container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/ErrorResponseContentCreatorTest.java b/container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/ErrorResponseContentCreatorTest.java index d66f22801f7..3fa8e154826 100644 --- a/container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/ErrorResponseContentCreatorTest.java +++ b/container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/ErrorResponseContentCreatorTest.java @@ -6,7 +6,6 @@ import org.junit.Test; import javax.servlet.http.HttpServletResponse; import java.nio.charset.StandardCharsets; -import java.util.Optional; import static org.junit.Assert.assertEquals; @@ -36,7 +35,7 @@ public class ErrorResponseContentCreatorTest { byte[] rawContent = c.createErrorContent( "http://foo.bar", HttpServletResponse.SC_OK, - Optional.of("My custom error message")); + "My custom error message"); String actualHtml = new String(rawContent, StandardCharsets.ISO_8859_1); assertEquals(expectedHtml, actualHtml); } -- cgit v1.2.3 From 47b73921208a52d8f356be4ff87ed0f679791827 Mon Sep 17 00:00:00 2001 From: Bjørn Christian Seime Date: Thu, 7 Oct 2021 12:07:52 +0200 Subject: Ensure writer is closed while holding lock Handler can inject its response content if lock is not held between write and close. --- .../server/jetty/ServletResponseController.java | 66 +++++++++++----------- 1 file changed, 33 insertions(+), 33 deletions(-) (limited to 'container-core') diff --git a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletResponseController.java b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletResponseController.java index 2b5b7a3420f..fd6084d2384 100644 --- a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletResponseController.java +++ b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletResponseController.java @@ -68,12 +68,14 @@ class ServletResponseController { } void fail(Throwable t) { - try { - trySendError(t); - } catch (Throwable suppressed) { - t.addSuppressed(suppressed); - } finally { - out.close(); + synchronized (monitor) { + try { + trySendError(t); + } catch (Throwable suppressed) { + t.addSuppressed(suppressed); + } finally { + out.close(); + } } } @@ -88,35 +90,33 @@ class ServletResponseController { ResponseHandler responseHandler() { return responseHandler; } private void trySendError(Throwable t) { - synchronized (monitor) { - if (!responseCommitted) { - responseCommitted = true; - servletResponse.setHeader(HttpHeaders.Names.EXPIRES, null); - servletResponse.setHeader(HttpHeaders.Names.LAST_MODIFIED, null); - servletResponse.setHeader(HttpHeaders.Names.CACHE_CONTROL, null); - servletResponse.setHeader(HttpHeaders.Names.CONTENT_TYPE, null); - servletResponse.setHeader(HttpHeaders.Names.CONTENT_LENGTH, null); - String reasonPhrase = getReasonPhrase(t, developerMode); - int statusCode = getStatusCode(t); - setStatus(servletResponse, statusCode, reasonPhrase); - // If we are allowed to have a body - if (statusCode != HttpServletResponse.SC_NO_CONTENT && - statusCode != HttpServletResponse.SC_NOT_MODIFIED && - statusCode != HttpServletResponse.SC_PARTIAL_CONTENT && - statusCode >= HttpServletResponse.SC_OK) { - servletResponse.setHeader(HttpHeaders.Names.CACHE_CONTROL, "must-revalidate,no-cache,no-store"); - servletResponse.setContentType(MimeTypes.Type.TEXT_HTML_8859_1.toString()); - byte[] errorContent = errorResponseContentCreator - .createErrorContent(servletRequest.getRequestURI(), statusCode, reasonPhrase); - servletResponse.setContentLength(errorContent.length); - out.writeBuffer(ByteBuffer.wrap(errorContent), NOOP_COMPLETION_HANDLER); - } else { - servletResponse.setContentLength(0); - } + if (!responseCommitted) { + responseCommitted = true; + servletResponse.setHeader(HttpHeaders.Names.EXPIRES, null); + servletResponse.setHeader(HttpHeaders.Names.LAST_MODIFIED, null); + servletResponse.setHeader(HttpHeaders.Names.CACHE_CONTROL, null); + servletResponse.setHeader(HttpHeaders.Names.CONTENT_TYPE, null); + servletResponse.setHeader(HttpHeaders.Names.CONTENT_LENGTH, null); + String reasonPhrase = getReasonPhrase(t, developerMode); + int statusCode = getStatusCode(t); + setStatus(servletResponse, statusCode, reasonPhrase); + // If we are allowed to have a body + if (statusCode != HttpServletResponse.SC_NO_CONTENT && + statusCode != HttpServletResponse.SC_NOT_MODIFIED && + statusCode != HttpServletResponse.SC_PARTIAL_CONTENT && + statusCode >= HttpServletResponse.SC_OK) { + servletResponse.setHeader(HttpHeaders.Names.CACHE_CONTROL, "must-revalidate,no-cache,no-store"); + servletResponse.setContentType(MimeTypes.Type.TEXT_HTML_8859_1.toString()); + byte[] errorContent = errorResponseContentCreator + .createErrorContent(servletRequest.getRequestURI(), statusCode, reasonPhrase); + servletResponse.setContentLength(errorContent.length); + out.writeBuffer(ByteBuffer.wrap(errorContent), NOOP_COMPLETION_HANDLER); } else { - RuntimeException exceptionWithStackTrace = new RuntimeException(t); - log.log(Level.FINE, "Response already committed, can't change response code", exceptionWithStackTrace); + servletResponse.setContentLength(0); } + } else { + RuntimeException exceptionWithStackTrace = new RuntimeException(t); + log.log(Level.FINE, "Response already committed, can't change response code", exceptionWithStackTrace); } } -- cgit v1.2.3