diff --git a/server/src/main/java/org/opensearch/action/search/TransportMultiSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportMultiSearchAction.java index 146b4010af4b3..c9054bd59b975 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportMultiSearchAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportMultiSearchAction.java @@ -44,6 +44,8 @@ import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.tasks.TaskId; +import org.opensearch.tasks.CancellableTask; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -193,7 +195,7 @@ private void handleResponse(final int responseSlot, final MultiSearchResponse.It if (responseCounter.decrementAndGet() == 0) { assert requests.isEmpty(); finish(); - } else { + } else if (isCancelled(request.request.getParentTask()) == false) { if (thread == Thread.currentThread()) { // we are on the same thread, we need to fork to another thread to avoid recursive stack overflow on a single thread threadPool.generic() @@ -220,6 +222,14 @@ private long buildTookInMillis() { }); } + private boolean isCancelled(TaskId taskId) { + if (taskId.isSet()) { + CancellableTask task = taskManager.getCancellableTask(taskId.getId()); + return task != null && task.isCancelled(); + } + return false; + } + /** * Slots a search request * diff --git a/server/src/test/java/org/opensearch/action/search/TransportMultiSearchActionTests.java b/server/src/test/java/org/opensearch/action/search/TransportMultiSearchActionTests.java index 48970e2b96add..45980e7137ce4 100644 --- a/server/src/test/java/org/opensearch/action/search/TransportMultiSearchActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/TransportMultiSearchActionTests.java @@ -49,7 +49,9 @@ import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.tasks.CancellableTask; import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskListener; import org.opensearch.tasks.TaskManager; import org.opensearch.telemetry.tracing.noop.NoopTracer; import org.opensearch.test.OpenSearchTestCase; @@ -62,7 +64,9 @@ import java.util.IdentityHashMap; import java.util.List; import java.util.Set; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -289,4 +293,118 @@ public void testDefaultMaxConcurrentSearches() { assertThat(result, equalTo(1)); } + public void testCancellation() { + // Initialize dependencies of TransportMultiSearchAction + Settings settings = Settings.builder().put("node.name", TransportMultiSearchActionTests.class.getSimpleName()).build(); + ActionFilters actionFilters = mock(ActionFilters.class); + when(actionFilters.filters()).thenReturn(new ActionFilter[0]); + ThreadPool threadPool = new ThreadPool(settings); + TransportService transportService = new TransportService( + Settings.EMPTY, + mock(Transport.class), + threadPool, + TransportService.NOOP_TRANSPORT_INTERCEPTOR, + boundAddress -> DiscoveryNode.createLocal(settings, boundAddress.publishAddress(), UUIDs.randomBase64UUID()), + null, + Collections.emptySet(), + NoopTracer.INSTANCE + ) { + @Override + public TaskManager getTaskManager() { + return taskManager; + } + }; + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.state()).thenReturn(ClusterState.builder(new ClusterName("test")).build()); + + // Keep track of the number of concurrent searches started by multi search api, + // and if there are more searches than is allowed create an error and remember that. + int maxAllowedConcurrentSearches = 1; // Allow 1 search at a time. + AtomicInteger counter = new AtomicInteger(); + AtomicReference errorHolder = new AtomicReference<>(); + // randomize whether or not requests are executed asynchronously + ExecutorService executorService = threadPool.executor(ThreadPool.Names.GENERIC); + final Set requests = Collections.newSetFromMap(Collections.synchronizedMap(new IdentityHashMap<>())); + CountDownLatch countDownLatch = new CountDownLatch(1); + CancellableTask[] parentTask = new CancellableTask[1]; + NodeClient client = new NodeClient(settings, threadPool) { + @Override + public void search(final SearchRequest request, final ActionListener listener) { + if (parentTask[0] != null && parentTask[0].isCancelled()) { + fail("Should not execute search after parent task is cancelled"); + } + try { + countDownLatch.await(10, TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + + requests.add(request); + executorService.execute(() -> { + counter.decrementAndGet(); + listener.onResponse( + new SearchResponse( + InternalSearchResponse.empty(), + null, + 0, + 0, + 0, + 0L, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ) + ); + }); + } + + @Override + public String getLocalNodeId() { + return "local_node_id"; + } + }; + + TransportMultiSearchAction action = new TransportMultiSearchAction( + threadPool, + actionFilters, + transportService, + clusterService, + 10, + System::nanoTime, + client + ); + + // Execute the multi search api and fail if we find an error after executing: + try { + /* + * Allow for a large number of search requests in a single batch as previous implementations could stack overflow if the number + * of requests in a single batch was large + */ + int numSearchRequests = scaledRandomIntBetween(1024, 8192); + MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); + multiSearchRequest.maxConcurrentSearchRequests(maxAllowedConcurrentSearches); + for (int i = 0; i < numSearchRequests; i++) { + multiSearchRequest.add(new SearchRequest()); + } + MultiSearchResponse[] responses = new MultiSearchResponse[1]; + Exception[] exceptions = new Exception[1]; + parentTask[0] = (CancellableTask) action.execute(multiSearchRequest, new TaskListener<>() { + @Override + public void onResponse(Task task, MultiSearchResponse items) { + responses[0] = items; + } + + @Override + public void onFailure(Task task, Exception e) { + exceptions[0] = e; + } + }); + parentTask[0].cancel("Giving up"); + countDownLatch.countDown(); + + assertNull(responses[0]); + assertNull(exceptions[0]); + } finally { + assertTrue(OpenSearchTestCase.terminate(threadPool)); + } + } }