diff --git a/processing/src/main/java/org/apache/druid/java/util/http/client/NettyHttpClient.java b/processing/src/main/java/org/apache/druid/java/util/http/client/NettyHttpClient.java index 3ab3719180fd..76a50a711273 100644 --- a/processing/src/main/java/org/apache/druid/java/util/http/client/NettyHttpClient.java +++ b/processing/src/main/java/org/apache/druid/java/util/http/client/NettyHttpClient.java @@ -219,21 +219,34 @@ public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) log.debug("[%s] Got response: %s", requestDesc, httpResponse.getStatus()); } - HttpResponseHandler.TrafficCop trafficCop = resumeChunkNum -> { - synchronized (watermarkLock) { - resumeWatermark = Math.max(resumeWatermark, resumeChunkNum); - - if (suspendWatermark >= 0 && resumeWatermark >= suspendWatermark) { - suspendWatermark = -1; - channel.setReadable(true); - long backPressureDuration = System.nanoTime() - backPressureStartTimeNs; - log.debug("[%s] Resumed reads from channel (chunkNum = %,d).", requestDesc, resumeChunkNum); - return backPressureDuration; + HttpResponseHandler.TrafficCop trafficCop = new HttpResponseHandler.TrafficCop() + { + @Override + public long resume(long resumeChunkNum) + { + synchronized (watermarkLock) { + resumeWatermark = Math.max(resumeWatermark, resumeChunkNum); + + if (suspendWatermark >= 0 && resumeWatermark >= suspendWatermark) { + suspendWatermark = -1; + channel.setReadable(true); + long backPressureDuration = System.nanoTime() - backPressureStartTimeNs; + log.debug("[%s] Resumed reads from channel (chunkNum = %,d).", requestDesc, resumeChunkNum); + return backPressureDuration; + } } + + return 0; //If we didn't resume, don't know if backpressure was happening } - return 0; //If we didn't resume, don't know if backpressure was happening + @Override + public void abort() + { + log.debug("[%s] Aborted connection at caller's request.", requestDesc); + channel.close(); + } }; + response = handler.handleResponse(httpResponse, trafficCop); if (response.isFinished()) { retVal.set((Final) response.getObj()); diff --git a/processing/src/main/java/org/apache/druid/java/util/http/client/response/HttpResponseHandler.java b/processing/src/main/java/org/apache/druid/java/util/http/client/response/HttpResponseHandler.java index 43bca7ace1d1..10229176295e 100644 --- a/processing/src/main/java/org/apache/druid/java/util/http/client/response/HttpResponseHandler.java +++ b/processing/src/main/java/org/apache/druid/java/util/http/client/response/HttpResponseHandler.java @@ -98,5 +98,13 @@ interface TrafficCop * @return time that backpressure was applied (channel was closed for reads) */ long resume(long chunkNum); + + /** + * Closes the underlying connection, abandoning any remaining response. + * + * Intended for callers that decide they no longer need the rest of the response (for example, because the + * consumer of the resulting stream has been closed early) + */ + void abort(); } } diff --git a/server/src/main/java/org/apache/druid/client/DirectDruidClient.java b/server/src/main/java/org/apache/druid/client/DirectDruidClient.java index 3c746e410701..336bee51012e 100644 --- a/server/src/main/java/org/apache/druid/client/DirectDruidClient.java +++ b/server/src/main/java/org/apache/druid/client/DirectDruidClient.java @@ -181,6 +181,9 @@ public Sequence run(final QueryPlus queryPlus, final ResponseContext conte private final AtomicLong channelSuspendedTime = new AtomicLong(0); private final BlockingQueue queue = new LinkedBlockingQueue<>(); private final AtomicBoolean done = new AtomicBoolean(false); + // Set when the consumer closes the result stream early (see the SequenceInputStream.close() override in + // handleResponse). Once set, incoming chunks are dropped rather than buffered. + private final AtomicBoolean discard = new AtomicBoolean(false); private final AtomicBoolean nodeMetricsEmitted = new AtomicBoolean(false); private final AtomicReference fail = new AtomicReference<>(); private final AtomicReference trafficCopRef = new AtomicReference<>(); @@ -202,6 +205,12 @@ private QueryMetrics> acquireResponseMetrics() */ private boolean enqueue(ChannelBuffer buffer, long chunkNum) throws InterruptedException { + // If the consumer has abandoned the response (see the SequenceInputStream.close() override below), drop the + // chunk instead of buffering it, and keep reads flowing (continueReading = true) so we never suspend the + // channel while it is being wound down. + if (discard.get()) { + return true; + } // Increment queuedByteCount before queueing the object, so queuedByteCount is at least as high as // the actual number of queued bytes at any particular time. final InputStreamHolder holder = InputStreamHolder.fromChannelBuffer(buffer, chunkNum); @@ -282,6 +291,12 @@ public int read() throws IOException @Override public boolean hasMoreElements() { + // If the consumer abandoned the stream (close() ran), report end-of-stream. After discard is set + // enqueue() drops every chunk, so a further read would otherwise block in dequeue() until the + // query timeout and then throw a misleading QueryTimeoutException. + if (discard.get()) { + return false; + } if (fail.get() != null) { throw new RE(fail.get()); } @@ -310,7 +325,33 @@ public InputStream nextElement() } } } - ), + ) + { + /** + * Closing this stream means the caller no longer needs the response. The default + * {@link SequenceInputStream#close()} would drain the entire remaining response off the wire first + * we want to avoid. Instead, abandon the response and force-close the connection + */ + @Override + public void close() + { + final TrafficCop trafficCop; + synchronized (done) { + if (done.get()) { + return; + } + // Stop buffering further chunks (see enqueue()) and drop anything already buffered so the + // underlying Netty ChannelBuffers can be released. + discard.set(true); + queue.clear(); + trafficCop = trafficCopRef.get(); + } + if (trafficCop == null) { + return; + } + trafficCop.abort(); + } + }, continueReading ); } diff --git a/server/src/test/java/org/apache/druid/client/DirectDruidClientAbortHttpTest.java b/server/src/test/java/org/apache/druid/client/DirectDruidClientAbortHttpTest.java new file mode 100644 index 000000000000..051c058f959d --- /dev/null +++ b/server/src/test/java/org/apache/druid/client/DirectDruidClientAbortHttpTest.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.client; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableMap; +import org.apache.druid.common.utils.SocketUtil; +import org.apache.druid.error.DruidException; +import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.common.concurrent.Execs; +import org.apache.druid.java.util.common.guava.Sequence; +import org.apache.druid.java.util.common.guava.Yielder; +import org.apache.druid.java.util.common.guava.Yielders; +import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.java.util.common.lifecycle.Lifecycle; +import org.apache.druid.java.util.http.client.HttpClientConfig; +import org.apache.druid.java.util.http.client.HttpClientInit; +import org.apache.druid.query.BaseQuery; +import org.apache.druid.query.Druids; +import org.apache.druid.query.QueryPlus; +import org.apache.druid.query.QueryRunnerFactoryConglomerate; +import org.apache.druid.query.QueryRunnerTestHelper; +import org.apache.druid.query.Result; +import org.apache.druid.segment.TestHelper; +import org.apache.druid.server.QueryStackTests; +import org.apache.druid.server.metrics.NoopServiceEmitter; +import org.eclipse.jetty.ee8.servlet.ServletContextHandler; +import org.eclipse.jetty.ee8.servlet.ServletHolder; +import org.eclipse.jetty.server.Server; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import javax.servlet.ServletOutputStream; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +public class DirectDruidClientAbortHttpTest +{ + @Test + public void testEarlyStreamClose() throws Exception + { + final ObjectMapper mapper = TestHelper.makeJsonMapper(); + + final AtomicBoolean serverSawDisconnect = new AtomicBoolean(false); + final CountDownLatch responseSent = new CountDownLatch(1); + final CountDownLatch connectionTerminated = new CountDownLatch(1); + final CountDownLatch terminationDetected = new CountDownLatch(1); + + final int port = SocketUtil.findOpenPort(0); + final Server server = new Server(port); + final ServletContextHandler handler = new ServletContextHandler(); + handler.addServlet( + new ServletHolder(new HttpServlet() + { + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) + { + final ServletOutputStream out; + try { + out = resp.getOutputStream(); + resp.setStatus(HttpServletResponse.SC_OK); + out.write(StringUtils.toUtf8("[{\"timestamp\":\"2014-01-01T01:02:03Z\", \"result\": 42.0}]")); + out.flush(); + } + catch (Exception e) { + throw DruidException.defensive(e, "Encountered some issue while sending the real bytes"); + } + + responseSent.countDown(); + + try { + connectionTerminated.await(); + final long endTime = System.currentTimeMillis() + 100; + while (System.currentTimeMillis() < endTime) { + out.write(0); + out.flush(); + } + } + catch (IOException e) { + serverSawDisconnect.set(true); + terminationDetected.countDown(); + } + catch (Exception e) { + throw DruidException.defensive( + e, + "Encountered some issue while awaiting for the connection to be terminated" + ); + } + } + }), + "/*" + ); + server.setHandler(handler); + + final Lifecycle lifecycle = new Lifecycle(); + final ScheduledExecutorService queryCancellationExecutor = + Execs.scheduledSingleThreaded("DirectDruidClientAbortHttpTest-cancel-%d"); + final Closer closer = Closer.create(); + try { + server.start(); + lifecycle.start(); + + final QueryRunnerFactoryConglomerate conglomerate = + QueryStackTests.createQueryRunnerFactoryConglomerate(closer); + + final DirectDruidClient directDruidClient = new DirectDruidClient( + conglomerate, + QueryRunnerTestHelper.NOOP_QUERYWATCHER, + mapper, + HttpClientInit.createClient( + HttpClientConfig.builder().withNumConnections(1).build(), + lifecycle + ), + "http", + "localhost:" + port, + new NoopServiceEmitter(), + queryCancellationExecutor + ); + + final Map queryContext = ImmutableMap.of( + DirectDruidClient.QUERY_FAIL_TIME, System.currentTimeMillis() + 60_000L, + BaseQuery.QUERY_ID, "abort-test" + ); + + final Sequence results = directDruidClient.run( + QueryPlus.wrap(Druids.newTimeBoundaryQueryBuilder().dataSource("test").context(queryContext).build()), + DirectDruidClient.makeResponseContextForQuery() + ); + + responseSent.await(); + + final AtomicInteger resultCount = new AtomicInteger(0); + final Yielder yielder = Yielders.each(results); + try { + Assertions.assertFalse(yielder.isDone(), "expected at least one result before stopping"); + final Result result = (Result) yielder.get(); + Assertions.assertEquals(DateTimes.of("2014-01-01T01:02:03Z"), result.getTimestamp()); + resultCount.incrementAndGet(); + } + finally { + connectionTerminated.countDown(); + yielder.close(); + } + + Assertions.assertEquals(1, resultCount.get(), "expected exactly one result before stopping"); + if (!terminationDetected.await(5, TimeUnit.SECONDS)) { + Assertions.fail("Test did not complete in 5 seconds!?"); + } + Assertions.assertTrue(serverSawDisconnect.get(), "server should have marked the connection as disconnected"); + } + finally { + queryCancellationExecutor.shutdownNow(); + closer.close(); + lifecycle.stop(); + server.stop(); + } + } +} diff --git a/server/src/test/java/org/apache/druid/client/TestHttpClient.java b/server/src/test/java/org/apache/druid/client/TestHttpClient.java index ef4da264aadb..4116bc5e16fb 100644 --- a/server/src/test/java/org/apache/druid/client/TestHttpClient.java +++ b/server/src/test/java/org/apache/druid/client/TestHttpClient.java @@ -62,7 +62,20 @@ */ public class TestHttpClient implements HttpClient { - private static final TrafficCop NOOP_TRAFFIC_COP = checkNum -> 0L; + public static final TrafficCop NOOP_TRAFFIC_COP = new TrafficCop() + { + @Override + public long resume(long chunkNum) + { + return 0; + } + + @Override + public void abort() + { + + } + }; private static final int RESPONSE_CTX_HEADER_LEN_LIMIT = 7 * 1024; private final Map servers = new HashMap<>(); diff --git a/server/src/test/java/org/apache/druid/rpc/MockServiceClient.java b/server/src/test/java/org/apache/druid/rpc/MockServiceClient.java index 021db219d963..4dacf0e62ee8 100644 --- a/server/src/test/java/org/apache/druid/rpc/MockServiceClient.java +++ b/server/src/test/java/org/apache/druid/rpc/MockServiceClient.java @@ -21,6 +21,7 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.client.TestHttpClient; import org.apache.druid.java.util.common.Either; import org.apache.druid.java.util.http.client.response.ClientResponse; import org.apache.druid.java.util.http.client.response.HttpResponseHandler; @@ -60,7 +61,7 @@ public ListenableFuture asyncRequest( if (expectation.response.isValue()) { final ClientResponse response = - handler.done(handler.handleResponse(expectation.response.valueOrThrow(), chunkNum -> 0)); + handler.done(handler.handleResponse(expectation.response.valueOrThrow(), TestHttpClient.NOOP_TRAFFIC_COP)); return Futures.immediateFuture(response.getObj()); } else { return Futures.immediateFailedFuture(expectation.response.error()); diff --git a/server/src/test/java/org/apache/druid/server/coordinator/simulate/TestSegmentLoadingHttpClient.java b/server/src/test/java/org/apache/druid/server/coordinator/simulate/TestSegmentLoadingHttpClient.java index ae911c4d8ae1..eda96d28b0b5 100644 --- a/server/src/test/java/org/apache/druid/server/coordinator/simulate/TestSegmentLoadingHttpClient.java +++ b/server/src/test/java/org/apache/druid/server/coordinator/simulate/TestSegmentLoadingHttpClient.java @@ -25,6 +25,7 @@ import com.google.common.util.concurrent.ListeningScheduledExecutorService; import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.SettableFuture; +import org.apache.druid.client.TestHttpClient; import org.apache.druid.java.util.http.client.HttpClient; import org.apache.druid.java.util.http.client.Request; import org.apache.druid.java.util.http.client.response.HttpResponseHandler; @@ -52,7 +53,6 @@ public class TestSegmentLoadingHttpClient implements HttpClient { - private static final HttpResponseHandler.TrafficCop NOOP_TRAFFIC_COP = checkNum -> 0L; private static final DataSegmentChangeCallback NOOP_CALLBACK = () -> { }; @@ -107,7 +107,7 @@ private Final processRequest( final HttpResponse failureResponse = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.NOT_FOUND); failureResponse.setContent(ChannelBuffers.EMPTY_BUFFER); - handler.handleResponse(failureResponse, NOOP_TRAFFIC_COP); + handler.handleResponse(failureResponse, TestHttpClient.NOOP_TRAFFIC_COP); return (Final) new ByteArrayInputStream(new byte[0]); } @@ -122,7 +122,7 @@ private Final processRequest( final HttpResponse response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); response.setContent(ChannelBuffers.EMPTY_BUFFER); - handler.handleResponse(response, NOOP_TRAFFIC_COP); + handler.handleResponse(response, TestHttpClient.NOOP_TRAFFIC_COP); return (Final) new ByteArrayInputStream(serializedContent); } catch (Exception e) { @@ -157,7 +157,7 @@ private ListenableFuture getCapabilities(HttpRespon // Set response content and status final HttpResponse response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); response.setContent(ChannelBuffers.EMPTY_BUFFER); - handler.handleResponse(response, NOOP_TRAFFIC_COP); + handler.handleResponse(response, TestHttpClient.NOOP_TRAFFIC_COP); // Serialize SettableFuture future = SettableFuture.create();