diff --git a/jetty-core/jetty-util/src/main/java/org/eclipse/jetty/util/Blocker.java b/jetty-core/jetty-util/src/main/java/org/eclipse/jetty/util/Blocker.java index 1ef20424e282..da2da8a6179e 100644 --- a/jetty-core/jetty-util/src/main/java/org/eclipse/jetty/util/Blocker.java +++ b/jetty-core/jetty-util/src/main/java/org/eclipse/jetty/util/Blocker.java @@ -80,6 +80,12 @@ public Throwable fillInStackTrace() { return this; } + + @Override + public String toString() + { + return "ACQUIRED"; + } }; private static final Throwable SUCCEEDED = new Throwable() { @@ -88,6 +94,12 @@ public Throwable fillInStackTrace() { return this; } + + @Override + public String toString() + { + return "SUCCEEDED"; + } }; public interface Runnable extends java.lang.Runnable, AutoCloseable, Invocable @@ -434,5 +446,11 @@ public Runnable runnable() throws IOException _lock.unlock(); } } + + @Override + public String toString() + { + return "%s@%x[c=%s]".formatted(getClass().getSimpleName(), hashCode(), _completed); + } } } diff --git a/jetty-core/jetty-util/src/main/java/org/eclipse/jetty/util/ExceptionUtil.java b/jetty-core/jetty-util/src/main/java/org/eclipse/jetty/util/ExceptionUtil.java index 4ebb076ab555..46d158d20de3 100644 --- a/jetty-core/jetty-util/src/main/java/org/eclipse/jetty/util/ExceptionUtil.java +++ b/jetty-core/jetty-util/src/main/java/org/eclipse/jetty/util/ExceptionUtil.java @@ -13,6 +13,8 @@ package org.eclipse.jetty.util; +import java.io.PrintWriter; +import java.io.StringWriter; import java.lang.reflect.Constructor; import java.util.Arrays; import java.util.List; @@ -466,6 +468,15 @@ public static T get(CompletableFuture completableFuture) } } + public static String toString(Throwable x) + { + if (x == null) + return "null"; + StringWriter sw = new StringWriter(); + x.printStackTrace(new PrintWriter(sw)); + return sw.toString(); + } + private ExceptionUtil() { } diff --git a/jetty-ee10/jetty-ee10-servlet/src/main/java/org/eclipse/jetty/ee10/servlet/HttpOutput.java b/jetty-ee10/jetty-ee10-servlet/src/main/java/org/eclipse/jetty/ee10/servlet/HttpOutput.java index baf55e2ac821..d30184de93e6 100644 --- a/jetty-ee10/jetty-ee10-servlet/src/main/java/org/eclipse/jetty/ee10/servlet/HttpOutput.java +++ b/jetty-ee10/jetty-ee10-servlet/src/main/java/org/eclipse/jetty/ee10/servlet/HttpOutput.java @@ -263,7 +263,7 @@ private void onWriteComplete(boolean last, Throwable failure) try (AutoLock ignored = _channelState.lock()) { if (LOG.isDebugEnabled()) - state = stateString(); + state = lockedStateString(); // Transition to CLOSED state if we were the last write or we have failed if (last || failure != null) @@ -273,7 +273,7 @@ private void onWriteComplete(boolean last, Throwable failure) _closedCallback = null; if (failure == null) lockedReleaseBuffer(); - wake = updateApiState(failure); + wake = lockedUpdateApiState(failure); } else if (_state == State.CLOSE) { @@ -285,13 +285,13 @@ else if (_state == State.CLOSE) } else { - wake = updateApiState(null); + wake = lockedUpdateApiState(null); } - } - if (LOG.isDebugEnabled()) - LOG.debug("onWriteComplete({},{}) {}->{} c={} cb={} w={}", - last, failure, state, stateString(), BufferUtil.toDetailString(closeContent), closedCallback, wake, failure); + if (LOG.isDebugEnabled()) + LOG.debug("onWriteComplete({},{}) {}->{} c={} cb={} w={}", + last, failure, state, lockedStateString(), BufferUtil.toDetailString(closeContent), closedCallback, wake, failure); + } try { @@ -314,8 +314,10 @@ else if (closeContent != null) } } - private boolean updateApiState(Throwable failure) + private boolean lockedUpdateApiState(Throwable failure) { + assert _channelState.isLockHeldByCurrentThread(); + boolean wake = false; switch (_apiState) { @@ -342,7 +344,7 @@ private boolean updateApiState(Throwable failure) default: if (_state == State.CLOSED) break; - throw new IllegalStateException(stateString()); + throw new IllegalStateException(lockedStateString()); } return wake; } @@ -465,10 +467,10 @@ public void complete(Callback callback) break; } } - } - if (LOG.isDebugEnabled()) - LOG.debug("complete({}) {} s={} e={}, c={}", callback, stateString(), succeeded, error, BufferUtil.toDetailString(content)); + if (LOG.isDebugEnabled()) + LOG.debug("complete({}) {} s={} e={}, c={}", callback, lockedStateString(), succeeded, error, BufferUtil.toDetailString(content)); + } if (succeeded) { @@ -501,6 +503,7 @@ public void completed(Throwable ignored) @Override public void close() throws IOException { + RetainableByteBuffer aggregate = null; ByteBuffer content = null; Blocker.Callback blocker = null; try (AutoLock ignored = _channelState.lock()) @@ -549,7 +552,16 @@ public void close() throws IOException _apiState = ApiState.BLOCKED; _state = State.CLOSING; blocker = _writeBlocker.callback(); - content = _aggregate != null && _aggregate.hasRemaining() ? _aggregate.getByteBuffer() : BufferUtil.EMPTY_BUFFER; + aggregate = _aggregate; + if (aggregate != null && _aggregate.hasRemaining()) + { + aggregate.retain(); + content = aggregate.getByteBuffer(); + } + else + { + content = BufferUtil.EMPTY_BUFFER; + } break; case BLOCKED: @@ -567,7 +579,16 @@ public void close() throws IOException // Output is idle in async state, so we can do an async close _apiState = ApiState.PENDING; _state = State.CLOSING; - content = _aggregate != null && _aggregate.hasRemaining() ? _aggregate.getByteBuffer() : BufferUtil.EMPTY_BUFFER; + aggregate = _aggregate; + if (aggregate != null && _aggregate.hasRemaining()) + { + aggregate.retain(); + content = aggregate.getByteBuffer(); + } + else + { + content = BufferUtil.EMPTY_BUFFER; + } break; case UNREADY: @@ -580,10 +601,10 @@ public void close() throws IOException } break; } - } - if (LOG.isDebugEnabled()) - LOG.debug("close() {} c={} b={}", stateString(), BufferUtil.toDetailString(content), blocker); + if (LOG.isDebugEnabled()) + LOG.debug("close() {} c={} b={}", lockedStateString(), BufferUtil.toDetailString(content), blocker); + } if (content == null) { @@ -602,7 +623,10 @@ public void close() throws IOException if (blocker == null) { // Do an async close - channelWrite(content, true, new WriteCompleteCB()); + Callback callback = new WriteCompleteCB(); + if (aggregate != null) + callback = Callback.from(callback, aggregate::release); + channelWrite(content, true, callback); } else { @@ -611,6 +635,8 @@ public void close() throws IOException { channelWrite(content, true, blocker); b.block(); + if (aggregate != null) + aggregate.release(); onWriteComplete(true, null); } catch (Throwable t) @@ -701,7 +727,7 @@ public void flush() throws IOException case ASYNC: case PENDING: - throw new IllegalStateException("isReady() not called: " + stateString()); + throw new IllegalStateException("isReady() not called: " + lockedStateString()); case READY: _apiState = ApiState.PENDING; @@ -711,7 +737,7 @@ public void flush() throws IOException throw new WritePendingException(); default: - throw new IllegalStateException(stateString()); + throw new IllegalStateException(lockedStateString()); } } } @@ -795,7 +821,7 @@ public void write(byte[] b, int off, int len) throws IOException break; case ASYNC: - throw new IllegalStateException("isReady() not called: " + stateString()); + throw new IllegalStateException("isReady() not called: " + lockedStateString()); case READY: async = true; @@ -807,7 +833,7 @@ public void write(byte[] b, int off, int len) throws IOException throw new WritePendingException(); default: - throw new IllegalStateException(stateString()); + throw new IllegalStateException(lockedStateString()); } _written = written; @@ -823,7 +849,7 @@ public void write(byte[] b, int off, int len) throws IOException { if (LOG.isDebugEnabled()) LOG.debug("write(array) {} aggregated !flush {}", - stateString(), _aggregate); + lockedStateString(), _aggregate); return; } @@ -831,11 +857,11 @@ public void write(byte[] b, int off, int len) throws IOException off += filled; len -= filled; } - } - if (LOG.isDebugEnabled()) - LOG.debug("write(array) {} last={} agg={} flush=true async={}, len={} {}", - stateString(), last, aggregate, async, len, _aggregate); + if (LOG.isDebugEnabled()) + LOG.debug("write(array) {} last={} agg={} flush=true async={}, len={} {}", + lockedStateString(), last, aggregate, async, len, _aggregate); + } if (async) { @@ -928,7 +954,7 @@ public void write(ByteBuffer buffer) throws IOException break; case ASYNC: - throw new IllegalStateException("isReady() not called: " + stateString()); + throw new IllegalStateException("isReady() not called: " + lockedStateString()); case READY: async = true; @@ -940,7 +966,7 @@ public void write(ByteBuffer buffer) throws IOException throw new WritePendingException(); default: - throw new IllegalStateException(stateString()); + throw new IllegalStateException(lockedStateString()); } _written = written; } @@ -1010,7 +1036,7 @@ public void write(int b) throws IOException break; case ASYNC: - throw new IllegalStateException("isReady() not called: " + stateString()); + throw new IllegalStateException("isReady() not called: " + lockedStateString()); case READY: async = true; @@ -1022,7 +1048,7 @@ public void write(int b) throws IOException throw new WritePendingException(); default: - throw new IllegalStateException(stateString()); + throw new IllegalStateException(lockedStateString()); } _written = written; @@ -1282,7 +1308,7 @@ private boolean prepareSendContent(int len, Callback callback) } if (_apiState != ApiState.BLOCKING) - throw new IllegalStateException(stateString()); + throw new IllegalStateException(lockedStateString()); _apiState = ApiState.PENDING; if (len > 0) _written += len; @@ -1337,13 +1363,13 @@ public void resetBuffer() @Override public void setWriteListener(WriteListener writeListener) { - if (!_servletChannel.getServletRequestState().isAsync()) - throw new IllegalStateException("!ASYNC: " + stateString()); boolean wake; try (AutoLock ignored = _channelState.lock()) { + if (!_servletChannel.getServletRequestState().isAsync()) + throw new IllegalStateException("!ASYNC: " + lockedStateString()); if (_apiState != ApiState.BLOCKING) - throw new IllegalStateException("!OPEN" + stateString()); + throw new IllegalStateException("!OPEN" + lockedStateString()); _apiState = ApiState.READY; _writeListener = writeListener; wake = _servletChannel.getServletRequestState().onWritePossible(); @@ -1422,17 +1448,24 @@ public void writeCallback() } } - private String stateString() + private String lockedStateString() + { + assert _channelState.isLockHeldByCurrentThread(); + return unsafeStateString(); + } + + private String unsafeStateString() { - return String.format("s=%s,api=%s,sc=%b,e=%s", _state, _apiState, _softClose, _onError); + return String.format("s=%s,api=%s,sc=%b,e=%s,wb=%s", _state, _apiState, _softClose, _onError, _writeBlocker); } @Override public String toString() { - try (AutoLock ignored = _channelState.lock()) + try (AutoLock lock = _channelState.tryLock()) { - return String.format("%s@%x{%s}", this.getClass().getSimpleName(), hashCode(), stateString()); + boolean held = lock.isHeldByCurrentThread(); + return String.format("%s@%x{%s%s}", this.getClass().getSimpleName(), hashCode(), held ? "" : "?:", unsafeStateString()); } } diff --git a/jetty-ee10/jetty-ee10-servlet/src/main/java/org/eclipse/jetty/ee10/servlet/ServletApiRequest.java b/jetty-ee10/jetty-ee10-servlet/src/main/java/org/eclipse/jetty/ee10/servlet/ServletApiRequest.java index a4c910dbde61..1bc47b205296 100644 --- a/jetty-ee10/jetty-ee10-servlet/src/main/java/org/eclipse/jetty/ee10/servlet/ServletApiRequest.java +++ b/jetty-ee10/jetty-ee10-servlet/src/main/java/org/eclipse/jetty/ee10/servlet/ServletApiRequest.java @@ -1200,8 +1200,15 @@ public ServletInputStream getInputStream() throws IOException if (_inputState != ServletContextRequest.INPUT_NONE && _inputState != ServletContextRequest.INPUT_STREAM) throw new IllegalStateException("READER"); _inputState = ServletContextRequest.INPUT_STREAM; - // Try to write a 100 continue, ignoring failure result if it was not necessary. - _servletChannel.getResponse().writeInterim(HttpStatus.CONTINUE_100, HttpFields.EMPTY); + try + { + // Try to write a 100 continue, ignoring failure result if it was not necessary. + _servletChannel.getResponse().writeInterim(HttpStatus.CONTINUE_100, HttpFields.EMPTY); + } + catch (IllegalStateException ise) + { + throw new IOException(ise); + } return getServletRequestInfo().getHttpInput(); } diff --git a/jetty-ee10/jetty-ee10-servlet/src/main/java/org/eclipse/jetty/ee10/servlet/ServletChannelState.java b/jetty-ee10/jetty-ee10-servlet/src/main/java/org/eclipse/jetty/ee10/servlet/ServletChannelState.java index 9d007939b125..b470096b394b 100644 --- a/jetty-ee10/jetty-ee10-servlet/src/main/java/org/eclipse/jetty/ee10/servlet/ServletChannelState.java +++ b/jetty-ee10/jetty-ee10-servlet/src/main/java/org/eclipse/jetty/ee10/servlet/ServletChannelState.java @@ -228,6 +228,11 @@ AutoLock lock() return _lock.lock(); } + AutoLock tryLock() + { + return _lock.tryLock(); + } + boolean isLockHeldByCurrentThread() { return _lock.isHeldByCurrentThread(); diff --git a/jetty-ee10/jetty-ee10-servlet/src/test/java/org/eclipse/jetty/ee10/servlet/BlockingTest.java b/jetty-ee10/jetty-ee10-servlet/src/test/java/org/eclipse/jetty/ee10/servlet/BlockingTest.java new file mode 100644 index 000000000000..0a35130b4e1b --- /dev/null +++ b/jetty-ee10/jetty-ee10-servlet/src/test/java/org/eclipse/jetty/ee10/servlet/BlockingTest.java @@ -0,0 +1,665 @@ +// +// ======================================================================== +// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +// ======================================================================== +// + +package org.eclipse.jetty.ee10.servlet; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.net.Socket; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import jakarta.servlet.AsyncContext; +import jakarta.servlet.DispatcherType; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletOutputStream; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.eclipse.jetty.http.HttpTester; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.server.ServerConnector; +import org.eclipse.jetty.util.ExceptionUtil; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.awaitility.Awaitility.await; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.startsWith; +import static org.hamcrest.core.Is.is; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class BlockingTest +{ + private Server server; + private ServerConnector connector; + private ServletContextHandler context; + + @BeforeEach + public void setUp() + { + server = new Server(); + connector = new ServerConnector(server); + server.addConnector(connector); + + context = new ServletContextHandler("/ctx"); + server.setHandler(context); + } + + @AfterEach + public void tearDown() throws Exception + { + server.stop(); + } + + @Test + public void testBlockingReadThenNormalComplete() throws Exception + { + CountDownLatch started = new CountDownLatch(1); + CountDownLatch stopped = new CountDownLatch(1); + AtomicReference readException = new AtomicReference<>(); + AtomicReference threadRef = new AtomicReference<>(); + HttpServlet servlet = new HttpServlet() + { + @Override + protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException + { + Thread thread = new Thread(() -> + { + try + { + int b = req.getInputStream().read(); + if (b == '1') + { + started.countDown(); + if (req.getInputStream().read() > Integer.MIN_VALUE) + throw new IllegalStateException(); + } + } + catch (Throwable t) + { + readException.set(t); + stopped.countDown(); + } + }); + threadRef.set(thread); + thread.start(); + + try + { + // wait for thread to start and read first byte + assertTrue(started.await(10, TimeUnit.SECONDS)); + // give it time to block on second byte + await().atMost(5, TimeUnit.SECONDS).until(() -> + { + Thread.State state = thread.getState(); + return state == Thread.State.WAITING || state == Thread.State.TIMED_WAITING; + }); + } + catch (Throwable e) + { + throw new ServletException(e); + } + + resp.setStatus(200); + resp.setContentType("text/plain"); + resp.getOutputStream().print("OK\r\n"); + } + }; + context.addServlet(servlet, "/*"); + server.start(); + + StringBuilder request = new StringBuilder(); + request.append("POST /ctx/path/info HTTP/1.1\r\n") + .append("Host: localhost\r\n") + .append("Content-Type: test/data\r\n") + .append("Content-Length: 2\r\n") + .append("\r\n") + .append("1"); + + int port = connector.getLocalPort(); + try (Socket socket = new Socket("localhost", port)) + { + socket.setSoTimeout(10000); + OutputStream out = socket.getOutputStream(); + out.write(request.toString().getBytes(StandardCharsets.ISO_8859_1)); + + HttpTester.Response response = HttpTester.parseResponse(socket.getInputStream()); + assertThat(response, notNullValue()); + assertThat(response.getStatus(), is(200)); + assertThat(response.getContent(), containsString("OK")); + + // Async thread should have stopped + boolean await = stopped.await(10, TimeUnit.SECONDS); + if (!await) + { + StackTraceElement[] stackTrace = threadRef.get().getStackTrace(); + for (StackTraceElement stackTraceElement : stackTrace) + { + System.out.println(stackTraceElement); + } + } + assertTrue(await); + assertThat(readException.get(), instanceOf(IOException.class)); + } + } + + @Test + public void testBlockingCloseWhileReading() throws Exception + { + AtomicReference threadRef = new AtomicReference<>(); + AtomicReference threadFailure = new AtomicReference<>(); + + HttpServlet servlet = new HttpServlet() + { + @Override + protected void service(HttpServletRequest req, HttpServletResponse resp) + { + try + { + AsyncContext asyncContext = req.startAsync(); + ServletOutputStream outputStream = resp.getOutputStream(); + resp.setStatus(200); + resp.setContentType("text/plain"); + + Thread thread = new Thread(() -> + { + try + { + try + { + for (int i = 0; i < 5; i++) + { + int b = req.getInputStream().read(); + assertThat(b, not(is(-1))); + } + outputStream.write("All read.".getBytes(StandardCharsets.UTF_8)); + } + catch (IOException e) + { + throw new RuntimeException(e); + } + + // this read should throw IOException as the client has closed the connection + assertThrows(IOException.class, () -> req.getInputStream().read()); + + try + { + outputStream.close(); + } + catch (IOException e) + { + // can happen + } + finally + { + try + { + asyncContext.complete(); + } + catch (Exception e) + { + // tolerated + } + } + } + catch (Throwable x) + { + threadFailure.set(x); + } + }) + { + @Override + public String toString() + { + return super.toString() + " " + outputStream; + } + }; + threadRef.set(thread); + thread.start(); + } + catch (Exception e) + { + throw new RuntimeException(e); + } + } + }; + ServletContextHandler contextHandler = new ServletContextHandler(); + contextHandler.addServlet(servlet, "/*"); + + server.setHandler(contextHandler); + server.start(); + + String request = "POST /ctx/path/info HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Content-Type: test/data\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "10\r\n" + + "01234"; + + try (Socket socket = new Socket("localhost", connector.getLocalPort())) + { + socket.getOutputStream().write(request.getBytes(StandardCharsets.ISO_8859_1)); + + // Wait for handler thread to be started and for it to have read all bytes of the request. + await().pollInterval(1, TimeUnit.MICROSECONDS).atMost(5, TimeUnit.SECONDS).until(() -> + { + Thread thread = threadRef.get(); + return thread != null && (thread.getState() == Thread.State.WAITING || thread.getState() == Thread.State.TIMED_WAITING); + }); + } + threadRef.get().join(5000); + if (threadRef.get().isAlive()) + { + System.err.println("Blocked handler thread: " + threadRef.get().toString()); + for (StackTraceElement stackTraceElement : threadRef.get().getStackTrace()) + { + System.err.println("\tat " + stackTraceElement); + } + fail("handler thread should not be alive anymore"); + } + assertThat("handler thread should not be alive anymore", threadRef.get().isAlive(), is(false)); + assertThat("handler thread failed: " + ExceptionUtil.toString(threadFailure.get()), threadFailure.get(), nullValue()); + } + + @Test + public void testNormalCompleteThenBlockingRead() throws Exception + { + CountDownLatch started = new CountDownLatch(1); + CountDownLatch completed = new CountDownLatch(1); + CountDownLatch stopped = new CountDownLatch(1); + AtomicReference readException = new AtomicReference<>(); + AtomicReference threadRef = new AtomicReference<>(); + HttpServlet handler = new HttpServlet() + { + @Override + protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException + { + Thread thread = new Thread(() -> + { + try + { + int b = req.getInputStream().read(); + if (b == '1') + { + started.countDown(); + assertTrue(completed.await(10, TimeUnit.SECONDS)); + if (req.getInputStream().read() > Integer.MIN_VALUE) + throw new IllegalStateException(); + } + } + catch (Throwable t) + { + readException.set(t); + stopped.countDown(); + } + }); + threadRef.set(thread); + thread.start(); + + try + { + // wait for thread to start and read first byte + assertTrue(started.await(10, TimeUnit.SECONDS)); + // give it time to block on second byte + await().atMost(5, TimeUnit.SECONDS).until(() -> + { + Thread.State state = thread.getState(); + return state == Thread.State.WAITING || state == Thread.State.TIMED_WAITING; + }); + } + catch (Throwable e) + { + throw new ServletException(e); + } + + resp.setStatus(200); + resp.setContentType("text/plain"); + resp.getOutputStream().print("OK\r\n"); + } + }; + context.addServlet(handler, "/*"); + server.start(); + + StringBuilder request = new StringBuilder(); + request.append("POST /ctx/path/info HTTP/1.1\r\n") + .append("Host: localhost\r\n") + .append("Content-Type: test/data\r\n") + .append("Content-Length: 2\r\n") + .append("\r\n") + .append("1"); + + int port = connector.getLocalPort(); + try (Socket socket = new Socket("localhost", port)) + { + socket.setSoTimeout(10000); + OutputStream out = socket.getOutputStream(); + out.write(request.toString().getBytes(StandardCharsets.ISO_8859_1)); + + HttpTester.Response response = HttpTester.parseResponse(socket.getInputStream()); + assertThat(response, notNullValue()); + assertThat(response.getStatus(), is(200)); + assertThat(response.getContent(), containsString("OK")); + + completed.countDown(); + await().atMost(5, TimeUnit.SECONDS).until(() -> threadRef.get().getState() == Thread.State.TERMINATED); + + // Async thread should have stopped + assertTrue(stopped.await(10, TimeUnit.SECONDS)); + assertThat(readException.get(), instanceOf(IOException.class)); + } + } + + @Test + public void testStartAsyncThenBlockingReadThenTimeout() throws Exception + { + CountDownLatch started = new CountDownLatch(1); + CountDownLatch completed = new CountDownLatch(1); + CountDownLatch stopped = new CountDownLatch(1); + AtomicReference threadRef = new AtomicReference<>(); + AtomicReference readException = new AtomicReference<>(); + HttpServlet servlet = new HttpServlet() + { + @Override + protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException + { + if (req.getDispatcherType() != DispatcherType.ERROR) + { + AsyncContext async = req.startAsync(); + async.setTimeout(100); + + Thread thread = new Thread(() -> + { + try + { + int b = req.getInputStream().read(); + if (b == '1') + { + started.countDown(); + assertTrue(completed.await(10, TimeUnit.SECONDS)); + if (req.getInputStream().read() > Integer.MIN_VALUE) + throw new IllegalStateException(); + } + } + catch (Throwable t) + { + readException.set(t); + stopped.countDown(); + } + }); + threadRef.set(thread); + thread.start(); + + try + { + // wait for thread to start and read first byte + assertTrue(started.await(10, TimeUnit.SECONDS)); + // give it time to block on second byte + await().atMost(5, TimeUnit.SECONDS).until(() -> + { + Thread.State state = thread.getState(); + return state == Thread.State.WAITING || state == Thread.State.TIMED_WAITING; + }); + } + catch (Throwable e) + { + throw new ServletException(e); + } + } + } + }; + context.addServlet(servlet, "/*"); + server.start(); + + StringBuilder request = new StringBuilder(); + request.append("POST /ctx/path/info HTTP/1.1\r\n") + .append("Host: localhost\r\n") + .append("Content-Type: test/data\r\n") + .append("Content-Length: 2\r\n") + .append("\r\n") + .append("1"); + + int port = connector.getLocalPort(); + try (Socket socket = new Socket("localhost", port)) + { + socket.setSoTimeout(10000); + OutputStream out = socket.getOutputStream(); + out.write(request.toString().getBytes(StandardCharsets.ISO_8859_1)); + + HttpTester.Response response = HttpTester.parseResponse(socket.getInputStream()); + assertThat(response, notNullValue()); + assertThat(response.getStatus(), is(500)); + assertThat(response.getContent(), containsString("AsyncContext timeout")); + + completed.countDown(); + await().atMost(5, TimeUnit.SECONDS).until(() -> threadRef.get().getState() == Thread.State.TERMINATED); + + // Async thread should have stopped + assertTrue(stopped.await(10, TimeUnit.SECONDS)); + assertThat(readException.get(), instanceOf(IOException.class)); + } + } + + @Test + public void testBlockingReadThenSendError() throws Exception + { + CountDownLatch started = new CountDownLatch(1); + CountDownLatch stopped = new CountDownLatch(1); + AtomicReference readException = new AtomicReference<>(); + AtomicReference threadRef = new AtomicReference<>(); + HttpServlet servlet = new HttpServlet() + { + @Override + protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException + { + if (req.getDispatcherType() != DispatcherType.ERROR) + { + Thread thread = new Thread(() -> + { + try + { + int b = req.getInputStream().read(); + if (b == '1') + { + started.countDown(); + if (req.getInputStream().read() > Integer.MIN_VALUE) + throw new IllegalStateException(); + } + } + catch (Throwable t) + { + readException.set(t); + stopped.countDown(); + } + }); + threadRef.set(thread); + thread.start(); + + try + { + // wait for thread to start and read first byte + assertTrue(started.await(10, TimeUnit.SECONDS)); + // give it time to block on second byte + await().atMost(5, TimeUnit.SECONDS).until(() -> + { + Thread.State state = thread.getState(); + return state == Thread.State.WAITING || state == Thread.State.TIMED_WAITING; + }); + } + catch (Throwable e) + { + throw new ServletException(e); + } + + resp.sendError(499); + } + } + }; + context.addServlet(servlet, "/*"); + server.start(); + + StringBuilder request = new StringBuilder(); + request.append("POST /ctx/path/info HTTP/1.1\r\n") + .append("Host: localhost\r\n") + .append("Content-Type: test/data\r\n") + .append("Content-Length: 2\r\n") + .append("\r\n") + .append("1"); + + int port = connector.getLocalPort(); + try (Socket socket = new Socket("localhost", port)) + { + socket.setSoTimeout(5000); + OutputStream out = socket.getOutputStream(); + out.write(request.toString().getBytes(StandardCharsets.ISO_8859_1)); + + HttpTester.Response response = HttpTester.parseResponse(socket.getInputStream()); + assertThat(response, notNullValue()); + assertThat(response.getStatus(), is(499)); + + // Async thread should have stopped + boolean await = stopped.await(10, TimeUnit.SECONDS); + if (!await) + { + StackTraceElement[] stackTrace = threadRef.get().getStackTrace(); + for (StackTraceElement stackTraceElement : stackTrace) + { + System.err.println(stackTraceElement.toString()); + } + } + assertTrue(await); + assertThat(readException.get(), instanceOf(IOException.class)); + } + } + + @Test + public void testBlockingWriteThenNormalComplete() throws Exception + { + CountDownLatch started = new CountDownLatch(1); + CountDownLatch stopped = new CountDownLatch(1); + AtomicReference readException = new AtomicReference<>(); + HttpServlet servlet = new HttpServlet() + { + @Override + protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException + { + resp.setStatus(200); + resp.setContentType("text/plain"); + Thread thread = new Thread(() -> + { + try + { + byte[] data = new byte[16 * 1024]; + Arrays.fill(data, (byte)'X'); + data[data.length - 2] = '\r'; + data[data.length - 1] = '\n'; + OutputStream out = resp.getOutputStream(); + started.countDown(); + while (true) + out.write(data); + } + catch (Throwable t) + { + readException.set(t); + stopped.countDown(); + } + }); + thread.start(); + + try + { + // wait for thread to start and read first byte + assertTrue(started.await(10, TimeUnit.SECONDS)); + // give it time to block on write + await().atMost(5, TimeUnit.SECONDS).until(() -> + { + Thread.State state = thread.getState(); + return state == Thread.State.WAITING || state == Thread.State.TIMED_WAITING; + }); + } + catch (Throwable e) + { + throw new ServletException(e); + } + } + }; + context.addServlet(servlet, "/*"); + server.start(); + + StringBuilder request = new StringBuilder(); + request.append("GET /ctx/path/info HTTP/1.1\r\n") + .append("Host: localhost\r\n") + .append("\r\n"); + + int port = connector.getLocalPort(); + try (Socket socket = new Socket("localhost", port)) + { + socket.setSoTimeout(10000); + OutputStream out = socket.getOutputStream(); + out.write(request.toString().getBytes(StandardCharsets.ISO_8859_1)); + + BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream(), StandardCharsets.ISO_8859_1)); + + // Read the header + List header = new ArrayList<>(); + while (true) + { + String line = in.readLine(); + if (line.length() == 0) + break; + header.add(line); + } + assertThat(header.get(0), containsString("200 OK")); + + // read one line of content + String content = in.readLine(); + assertThat(content, is("4000")); + content = in.readLine(); + assertThat(content, startsWith("XXXXXXXX")); + + // check that writing thread is stopped by end of request handling + assertTrue(stopped.await(10, TimeUnit.SECONDS)); + + // read until last line + String last = null; + while (true) + { + String line = in.readLine(); + if (line == null) + break; + + last = line; + } + + // last line is not empty chunk, ie abnormal completion + assertThat(last, startsWith("XXXXX")); + assertThat(readException.get(), notNullValue()); + } + } +} diff --git a/jetty-ee11/jetty-ee11-servlet/src/main/java/org/eclipse/jetty/ee11/servlet/HttpOutput.java b/jetty-ee11/jetty-ee11-servlet/src/main/java/org/eclipse/jetty/ee11/servlet/HttpOutput.java index af7890f70302..1c1b12a1f343 100644 --- a/jetty-ee11/jetty-ee11-servlet/src/main/java/org/eclipse/jetty/ee11/servlet/HttpOutput.java +++ b/jetty-ee11/jetty-ee11-servlet/src/main/java/org/eclipse/jetty/ee11/servlet/HttpOutput.java @@ -263,7 +263,7 @@ private void onWriteComplete(boolean last, Throwable failure) try (AutoLock ignored = _channelState.lock()) { if (LOG.isDebugEnabled()) - state = stateString(); + state = lockedStateString(); // Transition to CLOSED state if we were the last write or we have failed if (last || failure != null) @@ -273,7 +273,7 @@ private void onWriteComplete(boolean last, Throwable failure) _closedCallback = null; if (failure == null) lockedReleaseBuffer(); - wake = updateApiState(failure); + wake = lockedUpdateApiState(failure); } else if (_state == State.CLOSE) { @@ -285,13 +285,13 @@ else if (_state == State.CLOSE) } else { - wake = updateApiState(null); + wake = lockedUpdateApiState(null); } - } - if (LOG.isDebugEnabled()) - LOG.debug("onWriteComplete({},{}) {}->{} c={} cb={} w={}", - last, failure, state, stateString(), BufferUtil.toDetailString(closeContent), closedCallback, wake, failure); + if (LOG.isDebugEnabled()) + LOG.debug("onWriteComplete({},{}) {}->{} c={} cb={} w={}", + last, failure, state, lockedStateString(), BufferUtil.toDetailString(closeContent), closedCallback, wake, failure); + } try { @@ -314,8 +314,10 @@ else if (closeContent != null) } } - private boolean updateApiState(Throwable failure) + private boolean lockedUpdateApiState(Throwable failure) { + assert _channelState.isLockHeldByCurrentThread(); + boolean wake = false; switch (_apiState) { @@ -342,7 +344,7 @@ private boolean updateApiState(Throwable failure) default: if (_state == State.CLOSED) break; - throw new IllegalStateException(stateString()); + throw new IllegalStateException(lockedStateString()); } return wake; } @@ -374,7 +376,7 @@ public ByteBuffer takeContentAndClose() try (AutoLock l = _channelState.lock()) { if (_state != State.OPEN) - throw new IllegalStateException(stateString()); + throw new IllegalStateException(lockedStateString()); ByteBuffer content = _aggregate != null && _aggregate.hasRemaining() ? BufferUtil.copy(_aggregate.getByteBuffer()) : BufferUtil.EMPTY_BUFFER; _state = State.CLOSED; lockedReleaseBuffer(); @@ -478,10 +480,10 @@ public void complete(Callback callback) break; } } - } - if (LOG.isDebugEnabled()) - LOG.debug("complete({}) {} s={} e={}, c={}", callback, stateString(), succeeded, error, BufferUtil.toDetailString(content)); + if (LOG.isDebugEnabled()) + LOG.debug("complete({}) {} s={} e={}, c={}", callback, lockedStateString(), succeeded, error, BufferUtil.toDetailString(content)); + } if (succeeded) { @@ -514,6 +516,7 @@ public void completed(Throwable ignored) @Override public void close() throws IOException { + RetainableByteBuffer aggregate = null; ByteBuffer content = null; Blocker.Callback blocker = null; try (AutoLock ignored = _channelState.lock()) @@ -562,7 +565,16 @@ public void close() throws IOException _apiState = ApiState.BLOCKED; _state = State.CLOSING; blocker = _writeBlocker.callback(); - content = _aggregate != null && _aggregate.hasRemaining() ? _aggregate.getByteBuffer() : BufferUtil.EMPTY_BUFFER; + aggregate = _aggregate; + if (aggregate != null && _aggregate.hasRemaining()) + { + aggregate.retain(); + content = aggregate.getByteBuffer(); + } + else + { + content = BufferUtil.EMPTY_BUFFER; + } break; case BLOCKED: @@ -580,7 +592,16 @@ public void close() throws IOException // Output is idle in async state, so we can do an async close _apiState = ApiState.PENDING; _state = State.CLOSING; - content = _aggregate != null && _aggregate.hasRemaining() ? _aggregate.getByteBuffer() : BufferUtil.EMPTY_BUFFER; + aggregate = _aggregate; + if (aggregate != null && _aggregate.hasRemaining()) + { + aggregate.retain(); + content = aggregate.getByteBuffer(); + } + else + { + content = BufferUtil.EMPTY_BUFFER; + } break; case UNREADY: @@ -593,10 +614,10 @@ public void close() throws IOException } break; } - } - if (LOG.isDebugEnabled()) - LOG.debug("close() {} c={} b={}", stateString(), BufferUtil.toDetailString(content), blocker); + if (LOG.isDebugEnabled()) + LOG.debug("close() {} c={} b={}", lockedStateString(), BufferUtil.toDetailString(content), blocker); + } if (content == null) { @@ -615,7 +636,10 @@ public void close() throws IOException if (blocker == null) { // Do an async close - channelWrite(content, true, new WriteCompleteCB()); + Callback callback = new WriteCompleteCB(); + if (aggregate != null) + callback = Callback.from(callback, aggregate::release); + channelWrite(content, true, callback); } else { @@ -624,6 +648,8 @@ public void close() throws IOException { channelWrite(content, true, blocker); b.block(); + if (aggregate != null) + aggregate.release(); onWriteComplete(true, null); } catch (Throwable t) @@ -714,7 +740,7 @@ public void flush() throws IOException case ASYNC: case PENDING: - throw new IllegalStateException("isReady() not called: " + stateString()); + throw new IllegalStateException("isReady() not called: " + lockedStateString()); case READY: _apiState = ApiState.PENDING; @@ -724,7 +750,7 @@ public void flush() throws IOException throw new WritePendingException(); default: - throw new IllegalStateException(stateString()); + throw new IllegalStateException(lockedStateString()); } } } @@ -808,7 +834,7 @@ public void write(byte[] b, int off, int len) throws IOException break; case ASYNC: - throw new IllegalStateException("isReady() not called: " + stateString()); + throw new IllegalStateException("isReady() not called: " + lockedStateString()); case READY: async = true; @@ -820,7 +846,7 @@ public void write(byte[] b, int off, int len) throws IOException throw new WritePendingException(); default: - throw new IllegalStateException(stateString()); + throw new IllegalStateException(lockedStateString()); } _written = written; @@ -836,7 +862,7 @@ public void write(byte[] b, int off, int len) throws IOException { if (LOG.isDebugEnabled()) LOG.debug("write(array) {} aggregated !flush {}", - stateString(), _aggregate); + lockedStateString(), _aggregate); return; } @@ -844,11 +870,11 @@ public void write(byte[] b, int off, int len) throws IOException off += filled; len -= filled; } - } - if (LOG.isDebugEnabled()) - LOG.debug("write(array) {} last={} agg={} flush=true async={}, len={} {}", - stateString(), last, aggregate, async, len, _aggregate); + if (LOG.isDebugEnabled()) + LOG.debug("write(array) {} last={} agg={} flush=true async={}, len={} {}", + lockedStateString(), last, aggregate, async, len, _aggregate); + } if (async) { @@ -941,7 +967,7 @@ public void write(ByteBuffer buffer) throws IOException break; case ASYNC: - throw new IllegalStateException("isReady() not called: " + stateString()); + throw new IllegalStateException("isReady() not called: " + lockedStateString()); case READY: async = true; @@ -953,7 +979,7 @@ public void write(ByteBuffer buffer) throws IOException throw new WritePendingException(); default: - throw new IllegalStateException(stateString()); + throw new IllegalStateException(lockedStateString()); } _written = written; } @@ -1023,7 +1049,7 @@ public void write(int b) throws IOException break; case ASYNC: - throw new IllegalStateException("isReady() not called: " + stateString()); + throw new IllegalStateException("isReady() not called: " + lockedStateString()); case READY: async = true; @@ -1035,7 +1061,7 @@ public void write(int b) throws IOException throw new WritePendingException(); default: - throw new IllegalStateException(stateString()); + throw new IllegalStateException(lockedStateString()); } _written = written; @@ -1295,7 +1321,7 @@ private boolean prepareSendContent(int len, Callback callback) } if (_apiState != ApiState.BLOCKING) - throw new IllegalStateException(stateString()); + throw new IllegalStateException(lockedStateString()); _apiState = ApiState.PENDING; if (len > 0) _written += len; @@ -1350,13 +1376,13 @@ public void resetBuffer() @Override public void setWriteListener(WriteListener writeListener) { - if (!_servletChannel.getServletRequestState().isAsync()) - throw new IllegalStateException("!ASYNC: " + stateString()); boolean wake; try (AutoLock ignored = _channelState.lock()) { + if (!_servletChannel.getServletRequestState().isAsync()) + throw new IllegalStateException("!ASYNC: " + lockedStateString()); if (_apiState != ApiState.BLOCKING) - throw new IllegalStateException("!OPEN" + stateString()); + throw new IllegalStateException("!OPEN" + lockedStateString()); _apiState = ApiState.READY; _writeListener = writeListener; wake = _servletChannel.getServletRequestState().onWritePossible(); @@ -1435,17 +1461,24 @@ public void writeCallback() } } - private String stateString() + private String lockedStateString() + { + assert _channelState.isLockHeldByCurrentThread(); + return unsafeStateString(); + } + + private String unsafeStateString() { - return String.format("s=%s,api=%s,sc=%b,e=%s", _state, _apiState, _softClose, _onError); + return String.format("s=%s,api=%s,sc=%b,e=%s,wb=%s", _state, _apiState, _softClose, _onError, _writeBlocker); } @Override public String toString() { - try (AutoLock ignored = _channelState.lock()) + try (AutoLock lock = _channelState.tryLock()) { - return String.format("%s@%x{%s}", this.getClass().getSimpleName(), hashCode(), stateString()); + boolean held = lock.isHeldByCurrentThread(); + return String.format("%s@%x{%s%s}", this.getClass().getSimpleName(), hashCode(), held ? "" : "?:", unsafeStateString()); } } diff --git a/jetty-ee11/jetty-ee11-servlet/src/main/java/org/eclipse/jetty/ee11/servlet/ServletApiRequest.java b/jetty-ee11/jetty-ee11-servlet/src/main/java/org/eclipse/jetty/ee11/servlet/ServletApiRequest.java index 9b14fb57dd6e..185892c0d872 100644 --- a/jetty-ee11/jetty-ee11-servlet/src/main/java/org/eclipse/jetty/ee11/servlet/ServletApiRequest.java +++ b/jetty-ee11/jetty-ee11-servlet/src/main/java/org/eclipse/jetty/ee11/servlet/ServletApiRequest.java @@ -1209,8 +1209,15 @@ public ServletInputStream getInputStream() throws IOException if (_inputState != ServletContextRequest.INPUT_NONE && _inputState != ServletContextRequest.INPUT_STREAM) throw new IllegalStateException("READER"); _inputState = ServletContextRequest.INPUT_STREAM; - // Try to write a 100 continue, ignoring failure result if it was not necessary. - _servletChannel.getResponse().writeInterim(HttpStatus.CONTINUE_100, HttpFields.EMPTY); + try + { + // Try to write a 100 continue, ignoring failure result if it was not necessary. + _servletChannel.getResponse().writeInterim(HttpStatus.CONTINUE_100, HttpFields.EMPTY); + } + catch (IllegalStateException ise) + { + throw new IOException(ise); + } return getServletRequestInfo().getHttpInput(); } diff --git a/jetty-ee11/jetty-ee11-servlet/src/main/java/org/eclipse/jetty/ee11/servlet/ServletChannelState.java b/jetty-ee11/jetty-ee11-servlet/src/main/java/org/eclipse/jetty/ee11/servlet/ServletChannelState.java index d9d9bfc74116..f56dd6f0bc54 100644 --- a/jetty-ee11/jetty-ee11-servlet/src/main/java/org/eclipse/jetty/ee11/servlet/ServletChannelState.java +++ b/jetty-ee11/jetty-ee11-servlet/src/main/java/org/eclipse/jetty/ee11/servlet/ServletChannelState.java @@ -223,6 +223,11 @@ AutoLock lock() return _lock.lock(); } + AutoLock tryLock() + { + return _lock.tryLock(); + } + boolean isLockHeldByCurrentThread() { return _lock.isHeldByCurrentThread(); diff --git a/jetty-ee11/jetty-ee11-servlet/src/test/java/org/eclipse/jetty/ee11/servlet/BlockingTest.java b/jetty-ee11/jetty-ee11-servlet/src/test/java/org/eclipse/jetty/ee11/servlet/BlockingTest.java new file mode 100644 index 000000000000..4a64b5c1a9ff --- /dev/null +++ b/jetty-ee11/jetty-ee11-servlet/src/test/java/org/eclipse/jetty/ee11/servlet/BlockingTest.java @@ -0,0 +1,665 @@ +// +// ======================================================================== +// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +// ======================================================================== +// + +package org.eclipse.jetty.ee11.servlet; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.net.Socket; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import jakarta.servlet.AsyncContext; +import jakarta.servlet.DispatcherType; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletOutputStream; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.eclipse.jetty.http.HttpTester; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.server.ServerConnector; +import org.eclipse.jetty.util.ExceptionUtil; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.awaitility.Awaitility.await; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.startsWith; +import static org.hamcrest.core.Is.is; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class BlockingTest +{ + private Server server; + private ServerConnector connector; + private ServletContextHandler context; + + @BeforeEach + public void setUp() + { + server = new Server(); + connector = new ServerConnector(server); + server.addConnector(connector); + + context = new ServletContextHandler("/ctx"); + server.setHandler(context); + } + + @AfterEach + public void tearDown() throws Exception + { + server.stop(); + } + + @Test + public void testBlockingReadThenNormalComplete() throws Exception + { + CountDownLatch started = new CountDownLatch(1); + CountDownLatch stopped = new CountDownLatch(1); + AtomicReference readException = new AtomicReference<>(); + AtomicReference threadRef = new AtomicReference<>(); + HttpServlet servlet = new HttpServlet() + { + @Override + protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException + { + Thread thread = new Thread(() -> + { + try + { + int b = req.getInputStream().read(); + if (b == '1') + { + started.countDown(); + if (req.getInputStream().read() > Integer.MIN_VALUE) + throw new IllegalStateException(); + } + } + catch (Throwable t) + { + readException.set(t); + stopped.countDown(); + } + }); + threadRef.set(thread); + thread.start(); + + try + { + // wait for thread to start and read first byte + assertTrue(started.await(10, TimeUnit.SECONDS)); + // give it time to block on second byte + await().atMost(5, TimeUnit.SECONDS).until(() -> + { + Thread.State state = thread.getState(); + return state == Thread.State.WAITING || state == Thread.State.TIMED_WAITING; + }); + } + catch (Throwable e) + { + throw new ServletException(e); + } + + resp.setStatus(200); + resp.setContentType("text/plain"); + resp.getOutputStream().print("OK\r\n"); + } + }; + context.addServlet(servlet, "/*"); + server.start(); + + StringBuilder request = new StringBuilder(); + request.append("POST /ctx/path/info HTTP/1.1\r\n") + .append("Host: localhost\r\n") + .append("Content-Type: test/data\r\n") + .append("Content-Length: 2\r\n") + .append("\r\n") + .append("1"); + + int port = connector.getLocalPort(); + try (Socket socket = new Socket("localhost", port)) + { + socket.setSoTimeout(10000); + OutputStream out = socket.getOutputStream(); + out.write(request.toString().getBytes(StandardCharsets.ISO_8859_1)); + + HttpTester.Response response = HttpTester.parseResponse(socket.getInputStream()); + assertThat(response, notNullValue()); + assertThat(response.getStatus(), is(200)); + assertThat(response.getContent(), containsString("OK")); + + // Async thread should have stopped + boolean await = stopped.await(10, TimeUnit.SECONDS); + if (!await) + { + StackTraceElement[] stackTrace = threadRef.get().getStackTrace(); + for (StackTraceElement stackTraceElement : stackTrace) + { + System.out.println(stackTraceElement); + } + } + assertTrue(await); + assertThat(readException.get(), instanceOf(IOException.class)); + } + } + + @Test + public void testBlockingCloseWhileReading() throws Exception + { + AtomicReference threadRef = new AtomicReference<>(); + AtomicReference threadFailure = new AtomicReference<>(); + + HttpServlet servlet = new HttpServlet() + { + @Override + protected void service(HttpServletRequest req, HttpServletResponse resp) + { + try + { + AsyncContext asyncContext = req.startAsync(); + ServletOutputStream outputStream = resp.getOutputStream(); + resp.setStatus(200); + resp.setContentType("text/plain"); + + Thread thread = new Thread(() -> + { + try + { + try + { + for (int i = 0; i < 5; i++) + { + int b = req.getInputStream().read(); + assertThat(b, not(is(-1))); + } + outputStream.write("All read.".getBytes(StandardCharsets.UTF_8)); + } + catch (IOException e) + { + throw new RuntimeException(e); + } + + // this read should throw IOException as the client has closed the connection + assertThrows(IOException.class, () -> req.getInputStream().read()); + + try + { + outputStream.close(); + } + catch (IOException e) + { + // can happen + } + finally + { + try + { + asyncContext.complete(); + } + catch (Exception e) + { + // tolerated + } + } + } + catch (Throwable x) + { + threadFailure.set(x); + } + }) + { + @Override + public String toString() + { + return super.toString() + " " + outputStream; + } + }; + threadRef.set(thread); + thread.start(); + } + catch (Exception e) + { + throw new RuntimeException(e); + } + } + }; + ServletContextHandler contextHandler = new ServletContextHandler(); + contextHandler.addServlet(servlet, "/*"); + + server.setHandler(contextHandler); + server.start(); + + String request = "POST /ctx/path/info HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Content-Type: test/data\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "10\r\n" + + "01234"; + + try (Socket socket = new Socket("localhost", connector.getLocalPort())) + { + socket.getOutputStream().write(request.getBytes(StandardCharsets.ISO_8859_1)); + + // Wait for handler thread to be started and for it to have read all bytes of the request. + await().pollInterval(1, TimeUnit.MICROSECONDS).atMost(5, TimeUnit.SECONDS).until(() -> + { + Thread thread = threadRef.get(); + return thread != null && (thread.getState() == Thread.State.WAITING || thread.getState() == Thread.State.TIMED_WAITING); + }); + } + threadRef.get().join(5000); + if (threadRef.get().isAlive()) + { + System.err.println("Blocked handler thread: " + threadRef.get().toString()); + for (StackTraceElement stackTraceElement : threadRef.get().getStackTrace()) + { + System.err.println("\tat " + stackTraceElement); + } + fail("handler thread should not be alive anymore"); + } + assertThat("handler thread should not be alive anymore", threadRef.get().isAlive(), is(false)); + assertThat("handler thread failed: " + ExceptionUtil.toString(threadFailure.get()), threadFailure.get(), nullValue()); + } + + @Test + public void testNormalCompleteThenBlockingRead() throws Exception + { + CountDownLatch started = new CountDownLatch(1); + CountDownLatch completed = new CountDownLatch(1); + CountDownLatch stopped = new CountDownLatch(1); + AtomicReference readException = new AtomicReference<>(); + AtomicReference threadRef = new AtomicReference<>(); + HttpServlet handler = new HttpServlet() + { + @Override + protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException + { + Thread thread = new Thread(() -> + { + try + { + int b = req.getInputStream().read(); + if (b == '1') + { + started.countDown(); + assertTrue(completed.await(10, TimeUnit.SECONDS)); + if (req.getInputStream().read() > Integer.MIN_VALUE) + throw new IllegalStateException(); + } + } + catch (Throwable t) + { + readException.set(t); + stopped.countDown(); + } + }); + threadRef.set(thread); + thread.start(); + + try + { + // wait for thread to start and read first byte + assertTrue(started.await(10, TimeUnit.SECONDS)); + // give it time to block on second byte + await().atMost(5, TimeUnit.SECONDS).until(() -> + { + Thread.State state = thread.getState(); + return state == Thread.State.WAITING || state == Thread.State.TIMED_WAITING; + }); + } + catch (Throwable e) + { + throw new ServletException(e); + } + + resp.setStatus(200); + resp.setContentType("text/plain"); + resp.getOutputStream().print("OK\r\n"); + } + }; + context.addServlet(handler, "/*"); + server.start(); + + StringBuilder request = new StringBuilder(); + request.append("POST /ctx/path/info HTTP/1.1\r\n") + .append("Host: localhost\r\n") + .append("Content-Type: test/data\r\n") + .append("Content-Length: 2\r\n") + .append("\r\n") + .append("1"); + + int port = connector.getLocalPort(); + try (Socket socket = new Socket("localhost", port)) + { + socket.setSoTimeout(10000); + OutputStream out = socket.getOutputStream(); + out.write(request.toString().getBytes(StandardCharsets.ISO_8859_1)); + + HttpTester.Response response = HttpTester.parseResponse(socket.getInputStream()); + assertThat(response, notNullValue()); + assertThat(response.getStatus(), is(200)); + assertThat(response.getContent(), containsString("OK")); + + completed.countDown(); + await().atMost(5, TimeUnit.SECONDS).until(() -> threadRef.get().getState() == Thread.State.TERMINATED); + + // Async thread should have stopped + assertTrue(stopped.await(10, TimeUnit.SECONDS)); + assertThat(readException.get(), instanceOf(IOException.class)); + } + } + + @Test + public void testStartAsyncThenBlockingReadThenTimeout() throws Exception + { + CountDownLatch started = new CountDownLatch(1); + CountDownLatch completed = new CountDownLatch(1); + CountDownLatch stopped = new CountDownLatch(1); + AtomicReference threadRef = new AtomicReference<>(); + AtomicReference readException = new AtomicReference<>(); + HttpServlet servlet = new HttpServlet() + { + @Override + protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException + { + if (req.getDispatcherType() != DispatcherType.ERROR) + { + AsyncContext async = req.startAsync(); + async.setTimeout(100); + + Thread thread = new Thread(() -> + { + try + { + int b = req.getInputStream().read(); + if (b == '1') + { + started.countDown(); + assertTrue(completed.await(10, TimeUnit.SECONDS)); + if (req.getInputStream().read() > Integer.MIN_VALUE) + throw new IllegalStateException(); + } + } + catch (Throwable t) + { + readException.set(t); + stopped.countDown(); + } + }); + threadRef.set(thread); + thread.start(); + + try + { + // wait for thread to start and read first byte + assertTrue(started.await(10, TimeUnit.SECONDS)); + // give it time to block on second byte + await().atMost(5, TimeUnit.SECONDS).until(() -> + { + Thread.State state = thread.getState(); + return state == Thread.State.WAITING || state == Thread.State.TIMED_WAITING; + }); + } + catch (Throwable e) + { + throw new ServletException(e); + } + } + } + }; + context.addServlet(servlet, "/*"); + server.start(); + + StringBuilder request = new StringBuilder(); + request.append("POST /ctx/path/info HTTP/1.1\r\n") + .append("Host: localhost\r\n") + .append("Content-Type: test/data\r\n") + .append("Content-Length: 2\r\n") + .append("\r\n") + .append("1"); + + int port = connector.getLocalPort(); + try (Socket socket = new Socket("localhost", port)) + { + socket.setSoTimeout(10000); + OutputStream out = socket.getOutputStream(); + out.write(request.toString().getBytes(StandardCharsets.ISO_8859_1)); + + HttpTester.Response response = HttpTester.parseResponse(socket.getInputStream()); + assertThat(response, notNullValue()); + assertThat(response.getStatus(), is(500)); + assertThat(response.getContent(), containsString("AsyncContext timeout")); + + completed.countDown(); + await().atMost(5, TimeUnit.SECONDS).until(() -> threadRef.get().getState() == Thread.State.TERMINATED); + + // Async thread should have stopped + assertTrue(stopped.await(10, TimeUnit.SECONDS)); + assertThat(readException.get(), instanceOf(IOException.class)); + } + } + + @Test + public void testBlockingReadThenSendError() throws Exception + { + CountDownLatch started = new CountDownLatch(1); + CountDownLatch stopped = new CountDownLatch(1); + AtomicReference readException = new AtomicReference<>(); + AtomicReference threadRef = new AtomicReference<>(); + HttpServlet servlet = new HttpServlet() + { + @Override + protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException + { + if (req.getDispatcherType() != DispatcherType.ERROR) + { + Thread thread = new Thread(() -> + { + try + { + int b = req.getInputStream().read(); + if (b == '1') + { + started.countDown(); + if (req.getInputStream().read() > Integer.MIN_VALUE) + throw new IllegalStateException(); + } + } + catch (Throwable t) + { + readException.set(t); + stopped.countDown(); + } + }); + threadRef.set(thread); + thread.start(); + + try + { + // wait for thread to start and read first byte + assertTrue(started.await(10, TimeUnit.SECONDS)); + // give it time to block on second byte + await().atMost(5, TimeUnit.SECONDS).until(() -> + { + Thread.State state = thread.getState(); + return state == Thread.State.WAITING || state == Thread.State.TIMED_WAITING; + }); + } + catch (Throwable e) + { + throw new ServletException(e); + } + + resp.sendError(499); + } + } + }; + context.addServlet(servlet, "/*"); + server.start(); + + StringBuilder request = new StringBuilder(); + request.append("POST /ctx/path/info HTTP/1.1\r\n") + .append("Host: localhost\r\n") + .append("Content-Type: test/data\r\n") + .append("Content-Length: 2\r\n") + .append("\r\n") + .append("1"); + + int port = connector.getLocalPort(); + try (Socket socket = new Socket("localhost", port)) + { + socket.setSoTimeout(5000); + OutputStream out = socket.getOutputStream(); + out.write(request.toString().getBytes(StandardCharsets.ISO_8859_1)); + + HttpTester.Response response = HttpTester.parseResponse(socket.getInputStream()); + assertThat(response, notNullValue()); + assertThat(response.getStatus(), is(499)); + + // Async thread should have stopped + boolean await = stopped.await(10, TimeUnit.SECONDS); + if (!await) + { + StackTraceElement[] stackTrace = threadRef.get().getStackTrace(); + for (StackTraceElement stackTraceElement : stackTrace) + { + System.err.println(stackTraceElement.toString()); + } + } + assertTrue(await); + assertThat(readException.get(), instanceOf(IOException.class)); + } + } + + @Test + public void testBlockingWriteThenNormalComplete() throws Exception + { + CountDownLatch started = new CountDownLatch(1); + CountDownLatch stopped = new CountDownLatch(1); + AtomicReference readException = new AtomicReference<>(); + HttpServlet servlet = new HttpServlet() + { + @Override + protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException + { + resp.setStatus(200); + resp.setContentType("text/plain"); + Thread thread = new Thread(() -> + { + try + { + byte[] data = new byte[16 * 1024]; + Arrays.fill(data, (byte)'X'); + data[data.length - 2] = '\r'; + data[data.length - 1] = '\n'; + OutputStream out = resp.getOutputStream(); + started.countDown(); + while (true) + out.write(data); + } + catch (Throwable t) + { + readException.set(t); + stopped.countDown(); + } + }); + thread.start(); + + try + { + // wait for thread to start and read first byte + assertTrue(started.await(10, TimeUnit.SECONDS)); + // give it time to block on write + await().atMost(5, TimeUnit.SECONDS).until(() -> + { + Thread.State state = thread.getState(); + return state == Thread.State.WAITING || state == Thread.State.TIMED_WAITING; + }); + } + catch (Throwable e) + { + throw new ServletException(e); + } + } + }; + context.addServlet(servlet, "/*"); + server.start(); + + StringBuilder request = new StringBuilder(); + request.append("GET /ctx/path/info HTTP/1.1\r\n") + .append("Host: localhost\r\n") + .append("\r\n"); + + int port = connector.getLocalPort(); + try (Socket socket = new Socket("localhost", port)) + { + socket.setSoTimeout(10000); + OutputStream out = socket.getOutputStream(); + out.write(request.toString().getBytes(StandardCharsets.ISO_8859_1)); + + BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream(), StandardCharsets.ISO_8859_1)); + + // Read the header + List header = new ArrayList<>(); + while (true) + { + String line = in.readLine(); + if (line.length() == 0) + break; + header.add(line); + } + assertThat(header.get(0), containsString("200 OK")); + + // read one line of content + String content = in.readLine(); + assertThat(content, is("4000")); + content = in.readLine(); + assertThat(content, startsWith("XXXXXXXX")); + + // check that writing thread is stopped by end of request handling + assertTrue(stopped.await(10, TimeUnit.SECONDS)); + + // read until last line + String last = null; + while (true) + { + String line = in.readLine(); + if (line == null) + break; + + last = line; + } + + // last line is not empty chunk, ie abnormal completion + assertThat(last, startsWith("XXXXX")); + assertThat(readException.get(), notNullValue()); + } + } +} diff --git a/jetty-ee9/jetty-ee9-nested/src/test/java/org/eclipse/jetty/ee9/nested/BlockingTest.java b/jetty-ee9/jetty-ee9-nested/src/test/java/org/eclipse/jetty/ee9/nested/BlockingTest.java index 11d731614b2c..4dd5c98305a4 100644 --- a/jetty-ee9/jetty-ee9-nested/src/test/java/org/eclipse/jetty/ee9/nested/BlockingTest.java +++ b/jetty-ee9/jetty-ee9-nested/src/test/java/org/eclipse/jetty/ee9/nested/BlockingTest.java @@ -23,7 +23,6 @@ import java.util.Arrays; import java.util.List; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.CyclicBarrier; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; @@ -36,7 +35,7 @@ import org.eclipse.jetty.http.HttpTester; import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.ServerConnector; -import org.eclipse.jetty.server.handler.gzip.GzipHandler; +import org.eclipse.jetty.util.ExceptionUtil; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -47,9 +46,12 @@ import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.startsWith; import static org.hamcrest.core.Is.is; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; public class BlockingTest { @@ -156,59 +158,79 @@ public void handle(String target, Request baseRequest, HttpServletRequest reques } @Test - public void testBlockingReadAndBlockingWriteGzipped() throws Exception + public void testBlockingCloseWhileReading() throws Exception { AtomicReference threadRef = new AtomicReference<>(); - CyclicBarrier barrier = new CyclicBarrier(2); + AtomicReference threadFailure = new AtomicReference<>(); - AbstractHandler handler = new AbstractHandler() + Handler handler = new AbstractHandler() { @Override - public void handle(String target, Request baseRequest, HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException + public void handle(String target, Request baseRequest, HttpServletRequest req, HttpServletResponse resp) { try { baseRequest.setHandled(true); - final AsyncContext asyncContext = baseRequest.startAsync(); - final ServletOutputStream outputStream = response.getOutputStream(); - final Thread thread = new Thread(() -> + AsyncContext asyncContext = req.startAsync(); + ServletOutputStream outputStream = resp.getOutputStream(); + resp.setStatus(200); + resp.setContentType("text/plain"); + + Thread thread = new Thread(() -> { try { - for (int i = 0; i < 5; i++) + try { - int b = baseRequest.getHttpInput().read(); - assertThat(b, not(is(-1))); + for (int i = 0; i < 5; i++) + { + int b = req.getInputStream().read(); + assertThat(b, not(is(-1))); + } + outputStream.write("All read.".getBytes(StandardCharsets.UTF_8)); } - outputStream.write("All read.".getBytes(StandardCharsets.UTF_8)); - barrier.await(); // notify that all bytes were read - baseRequest.getHttpInput().read(); // this read should throw IOException as the client has closed the connection - throw new AssertionError("should have thrown IOException"); - } - catch (Exception e) - { - //throw new RuntimeException(e); - } - finally - { + catch (IOException e) + { + throw new RuntimeException(e); + } + + // this read should throw IOException as the client has closed the connection + assertThrows(IOException.class, () -> req.getInputStream().read()); + try { outputStream.close(); } - catch (Exception e2) + catch (IOException e) + { + // can happen + } + finally { - //e2.printStackTrace(); + try + { + asyncContext.complete(); + } + catch (Exception e) + { + // tolerated + } } - asyncContext.complete(); } - }); + catch (Throwable x) + { + threadFailure.set(x); + } + }) + { + @Override + public String toString() + { + return super.toString() + " " + outputStream; + } + }; threadRef.set(thread); thread.start(); - barrier.await(); // notify that handler thread has started - - response.setStatus(200); - response.setContentType("text/plain"); - response.getOutputStream().print("OK\r\n"); } catch (Exception e) { @@ -219,35 +241,39 @@ public void handle(String target, Request baseRequest, HttpServletRequest reques ContextHandler contextHandler = new ContextHandler(); contextHandler.setHandler(handler); - GzipHandler gzipHandler = new GzipHandler(); - gzipHandler.setMinGzipSize(1); - gzipHandler.setHandler(contextHandler); - server.setHandler(gzipHandler); + server.setHandler(contextHandler); server.start(); - StringBuilder request = new StringBuilder(); - // partial chunked request - request.append("POST /ctx/path/info HTTP/1.1\r\n") - .append("Host: localhost\r\n") - .append("Accept-Encoding: gzip, *\r\n") - .append("Content-Type: test/data\r\n") - .append("Transfer-Encoding: chunked\r\n") - .append("\r\n") - .append("10\r\n") - .append("01234") - ; + String request = "POST /ctx/path/info HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Content-Type: test/data\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "10\r\n" + + "01234"; - int port = connector.getLocalPort(); - try (Socket socket = new Socket("localhost", port)) + try (Socket socket = new Socket("localhost", connector.getLocalPort())) { - socket.setSoLinger(true, 0); // send TCP RST upon close instead of FIN - OutputStream out = socket.getOutputStream(); - out.write(request.toString().getBytes(StandardCharsets.ISO_8859_1)); - barrier.await(); // wait for handler thread to be started - barrier.await(); // wait for all bytes of the request to be read + socket.getOutputStream().write(request.getBytes(StandardCharsets.ISO_8859_1)); + + // Wait for handler thread to be started and for it to have read all bytes of the request. + await().pollInterval(1, TimeUnit.MICROSECONDS).atMost(5, TimeUnit.SECONDS).until(() -> + { + Thread thread = threadRef.get(); + return thread != null && (thread.getState() == Thread.State.WAITING || thread.getState() == Thread.State.TIMED_WAITING); + }); } threadRef.get().join(5000); - assertThat("handler thread should not be alive anymore", threadRef.get().isAlive(), is(false)); + if (threadRef.get().isAlive()) + { + System.err.println("Blocked handler thread: " + threadRef.get().toString()); + for (StackTraceElement stackTraceElement : threadRef.get().getStackTrace()) + { + System.err.println("\tat " + stackTraceElement); + } + fail("handler thread should not be alive anymore"); + } + assertThat("handler thread failed: " + ExceptionUtil.toString(threadFailure.get()), threadFailure.get(), nullValue()); } @Test