From c033d567f2feff35385686aa389eba2872af5d5e Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Mon, 20 Nov 2023 12:21:08 -0800 Subject: [PATCH 01/22] AwaitsFix AsyncOperatorTests#testFailure Tracked at #102264 --- .../org/elasticsearch/compute/operator/AsyncOperatorTests.java | 1 + 1 file changed, 1 insertion(+) diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/AsyncOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/AsyncOperatorTests.java index 00b046abdca24..2283083512f13 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/AsyncOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/AsyncOperatorTests.java @@ -185,6 +185,7 @@ protected void doClose() { operator.close(); } + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/102264") public void testFailure() throws Exception { DriverContext driverContext = driverContext(); final SequenceLongBlockSourceOperator sourceOperator = new SequenceLongBlockSourceOperator( From eabb21e89922c95412c803cee1e171abe512dbeb Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Mon, 20 Nov 2023 17:03:56 -0800 Subject: [PATCH 02/22] Enable YAML tests in ESQL mixed cluster module (#102392) With the YAML tests moved to x-pack core, we now can include them in the ESQL mixed cluster tests. --- .../esql/qa/server/mixed-cluster/build.gradle | 5 ++- .../xpack/esql/qa/mixed/EsqlClientYamlIT.java | 34 +++++++++++++++++++ .../rest-api-spec/test/esql/30_types.yml | 4 +-- 3 files changed, 40 insertions(+), 3 deletions(-) create mode 100644 x-pack/plugin/esql/qa/server/mixed-cluster/src/test/java/org/elasticsearch/xpack/esql/qa/mixed/EsqlClientYamlIT.java diff --git a/x-pack/plugin/esql/qa/server/mixed-cluster/build.gradle b/x-pack/plugin/esql/qa/server/mixed-cluster/build.gradle index 10f993124652d..01955adb3af0c 100644 --- a/x-pack/plugin/esql/qa/server/mixed-cluster/build.gradle +++ b/x-pack/plugin/esql/qa/server/mixed-cluster/build.gradle @@ -16,7 +16,10 @@ dependencies { restResources { restApi { - include '_common', 'bulk', 'indices', 'esql', 'xpack', 'enrich' + include '_common', 'bulk', 'get', 'indices', 'esql', 'xpack', 'enrich', 'cluster' + } + restTests { + includeXpack 'esql' } } diff --git a/x-pack/plugin/esql/qa/server/mixed-cluster/src/test/java/org/elasticsearch/xpack/esql/qa/mixed/EsqlClientYamlIT.java b/x-pack/plugin/esql/qa/server/mixed-cluster/src/test/java/org/elasticsearch/xpack/esql/qa/mixed/EsqlClientYamlIT.java new file mode 100644 index 0000000000000..0965c5506c6a1 --- /dev/null +++ b/x-pack/plugin/esql/qa/server/mixed-cluster/src/test/java/org/elasticsearch/xpack/esql/qa/mixed/EsqlClientYamlIT.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.qa.mixed; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate; +import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase; +import org.elasticsearch.xpack.esql.qa.rest.EsqlSpecTestCase; +import org.junit.After; +import org.junit.Before; + +public class EsqlClientYamlIT extends ESClientYamlSuiteTestCase { + + public EsqlClientYamlIT(final ClientYamlTestCandidate testCandidate) { + super(testCandidate); + } + + @ParametersFactory + public static Iterable parameters() throws Exception { + return createParameters(); + } + + @Before + @After + public void assertRequestBreakerEmpty() throws Exception { + EsqlSpecTestCase.assertRequestBreakerEmpty(); + } +} diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/30_types.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/30_types.yml index 406ae169872a2..7e80ee8dd6904 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/30_types.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/30_types.yml @@ -1,8 +1,8 @@ --- setup: - skip: - version: " - 8.10.99" - reason: "ESQL is available in 8.11+" + version: " - 8.11.99" + reason: "more field loading added in 8.12+" features: warnings --- From f264323f5d4a34c0f2ed9e8e1034fe6d07b6e797 Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Tue, 21 Nov 2023 17:06:06 +1100 Subject: [PATCH 03/22] Minor tweak for free space reporting (#102276) This PR adds free space size in error message. It also removes the special handling of negative size since the JDK bug is fixed. --- server/src/main/java/org/elasticsearch/env/Environment.java | 6 +----- .../org/elasticsearch/blobcache/shared/SharedBytes.java | 4 +++- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/env/Environment.java b/server/src/main/java/org/elasticsearch/env/Environment.java index 44cf74c4339a6..2f738eb1412a5 100644 --- a/server/src/main/java/org/elasticsearch/env/Environment.java +++ b/server/src/main/java/org/elasticsearch/env/Environment.java @@ -326,11 +326,7 @@ public static FileStore getFileStore(final Path path) throws IOException { public static long getUsableSpace(Path path) throws IOException { long freeSpaceInBytes = Environment.getFileStore(path).getUsableSpace(); - - /* See: https://bugs.openjdk.java.net/browse/JDK-8162520 */ - if (freeSpaceInBytes < 0) { - freeSpaceInBytes = Long.MAX_VALUE; - } + assert freeSpaceInBytes >= 0; return freeSpaceInBytes; } diff --git a/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBytes.java b/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBytes.java index 04347aaf6bff2..530cbbe6c6184 100644 --- a/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBytes.java +++ b/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBytes.java @@ -134,7 +134,9 @@ public static Path findCacheSnapshotCacheFilePath(NodeEnvironment environment, l if (usableSpace > fileSize) { return p; } else { - throw new IOException("Not enough free space for cache file of size [" + fileSize + "] in path [" + path + "]"); + throw new IOException( + "Not enough free space [" + usableSpace + "] for cache file of size [" + fileSize + "] in path [" + path + "]" + ); } } From 9a2c2a4f0741c598ab8e834d69267ece15443591 Mon Sep 17 00:00:00 2001 From: David Turner Date: Tue, 21 Nov 2023 07:49:58 +0000 Subject: [PATCH 04/22] Unmute #102010 (#102383) This has only been observed to fail on a single PR branch, and although the changes in that PR seem unrelated we don't have enough to go on nor can we reproduce it on latest `main`. Closes #102010 (at least temporarily) --- .../java/org/elasticsearch/upgrades/CcrRollingUpgradeIT.java | 1 - 1 file changed, 1 deletion(-) diff --git a/x-pack/qa/rolling-upgrade-multi-cluster/src/test/java/org/elasticsearch/upgrades/CcrRollingUpgradeIT.java b/x-pack/qa/rolling-upgrade-multi-cluster/src/test/java/org/elasticsearch/upgrades/CcrRollingUpgradeIT.java index 0b7ab1fe5980d..92c751ca28948 100644 --- a/x-pack/qa/rolling-upgrade-multi-cluster/src/test/java/org/elasticsearch/upgrades/CcrRollingUpgradeIT.java +++ b/x-pack/qa/rolling-upgrade-multi-cluster/src/test/java/org/elasticsearch/upgrades/CcrRollingUpgradeIT.java @@ -239,7 +239,6 @@ public void testCannotFollowLeaderInUpgradedCluster() throws Exception { } } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/102010") public void testBiDirectionalIndexFollowing() throws Exception { logger.info("clusterName={}, upgradeState={}", clusterName, upgradeState); From 4779a28d4e1efdc2c0b0d879365d3e53cd7a44dc Mon Sep 17 00:00:00 2001 From: David Turner Date: Tue, 21 Nov 2023 08:09:46 +0000 Subject: [PATCH 05/22] Fix ConcurrentSeqNoVersioningIT aborts (#102380) Today we treat an abort (timeout) of the linearizability checker as a linearizability failure, but really it is inconclusive and should not itself fail the test. Closes #102255 --- .../ConcurrentSeqNoVersioningIT.java | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/server/src/internalClusterTest/java/org/elasticsearch/versioning/ConcurrentSeqNoVersioningIT.java b/server/src/internalClusterTest/java/org/elasticsearch/versioning/ConcurrentSeqNoVersioningIT.java index 0f9b63f2757cf..3fca413ea5c41 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/versioning/ConcurrentSeqNoVersioningIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/versioning/ConcurrentSeqNoVersioningIT.java @@ -118,7 +118,6 @@ public class ConcurrentSeqNoVersioningIT extends AbstractDisruptionTestCase { // multiple threads doing CAS updates. // Wait up to 1 minute (+10s in thread to ensure it does not time out) for threads to complete previous round before initiating next // round. - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/102255") public void testSeqNoCASLinearizability() { final int disruptTimeSeconds = scaledRandomIntBetween(1, 8); @@ -434,24 +433,31 @@ public void assertLinearizable() { history ); LinearizabilityChecker.SequentialSpec spec = new CASSequentialSpec(initialVersion); - boolean linearizable = false; + Boolean linearizable = null; try { linearizable = LinearizabilityChecker.isLinearizable(spec, history, missingResponseGenerator()); } catch (LinearizabilityCheckAborted e) { - logger.warn("linearizability check check was aborted", e); + logger.warn("linearizability check was aborted, assuming linearizable", e); } finally { try { - if (linearizable) { + if (Boolean.TRUE.equals(linearizable)) { // ensure that we can serialize all histories. writeHistory(new OutputStreamStreamOutput(OutputStream.nullOutputStream()), history); } else { - logger.error("Linearizability check failed. Spec: {}, initial version: {}", spec, initialVersion); + final var outcome = linearizable == null ? "inconclusive" : "unlinearizable"; + + logger.error( + "Linearizability check did not succeed. Spec: {}, initial version: {}, outcome: {}", + spec, + initialVersion, + outcome + ); // we dump base64 encoded data, since the nature of this test is that it does not reproduce even with same seed. try ( var chunkedLoggingStream = ChunkedLoggingStream.create( logger, Level.ERROR, - "raw unlinearizable history in partition " + id, + "raw " + outcome + " history in partition " + id, ReferenceDocs.LOGGING // any old docs link will do ); var output = new OutputStreamStreamOutput(chunkedLoggingStream) @@ -462,20 +468,20 @@ public void assertLinearizable() { var chunkedLoggingStream = ChunkedLoggingStream.create( logger, Level.ERROR, - "visualisation of unlinearizable history in partition " + id, + "visualisation of " + outcome + " history in partition " + id, ReferenceDocs.LOGGING // any old docs link will do ); var writer = new OutputStreamWriter(chunkedLoggingStream, StandardCharsets.UTF_8) ) { LinearizabilityChecker.writeVisualisation(spec, history, missingResponseGenerator(), writer); } + assertNull("Must not be unlinearizable", linearizable); } } catch (IOException e) { logger.error("failure writing out history", e); fail(e); } } - assertTrue("Must be linearizable", linearizable); } } From bc7cd6aeffa5c1c6d8816b61f1ae2025335d6616 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Witek?= Date: Tue, 21 Nov 2023 09:17:15 +0100 Subject: [PATCH 06/22] [Transform] Pass source query to FieldCapabilitiesAction (as index_filter) for better performance (#102379) --- docs/changelog/102379.yaml | 6 ++++++ .../integration/TransformPivotRestIT.java | 21 +++++++++++++++++-- .../transform/transforms/pivot/Pivot.java | 10 ++++++++- .../transforms/pivot/SchemaUtil.java | 6 ++++++ .../AggregationSchemaAndResultTests.java | 21 +++++++++++++++++-- .../transforms/pivot/SchemaUtilTests.java | 7 +++++++ 6 files changed, 66 insertions(+), 5 deletions(-) create mode 100644 docs/changelog/102379.yaml diff --git a/docs/changelog/102379.yaml b/docs/changelog/102379.yaml new file mode 100644 index 0000000000000..0773b137779a5 --- /dev/null +++ b/docs/changelog/102379.yaml @@ -0,0 +1,6 @@ +pr: 102379 +summary: Pass source query to `_field_caps` (as `index_filter`) when deducing destination index mappings for better + performance +area: Transform +type: enhancement +issues: [] diff --git a/x-pack/plugin/transform/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/transform/integration/TransformPivotRestIT.java b/x-pack/plugin/transform/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/transform/integration/TransformPivotRestIT.java index ffafa133989a6..925e6d5381770 100644 --- a/x-pack/plugin/transform/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/transform/integration/TransformPivotRestIT.java +++ b/x-pack/plugin/transform/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/transform/integration/TransformPivotRestIT.java @@ -1123,8 +1123,24 @@ public void testContinuousDateHistogramPivot() throws Exception { assertEquals(11, totalStars, 0); } - @SuppressWarnings("unchecked") public void testPreviewTransform() throws Exception { + testPreviewTransform(""); + } + + public void testPreviewTransformWithQuery() throws Exception { + testPreviewTransform(""" + , + "query": { + "range": { + "timestamp": { + "gte": 123456789 + } + } + }"""); + } + + @SuppressWarnings("unchecked") + private void testPreviewTransform(String queryJson) throws Exception { setupDataAccessRole(DATA_ACCESS_ROLE, REVIEWS_INDEX_NAME); final Request createPreviewRequest = createRequestWithAuth( "POST", @@ -1136,6 +1152,7 @@ public void testPreviewTransform() throws Exception { { "source": { "index": "%s" + %s }, "pivot": { "group_by": { @@ -1159,7 +1176,7 @@ public void testPreviewTransform() throws Exception { } } } - }""", REVIEWS_INDEX_NAME); + }""", REVIEWS_INDEX_NAME, queryJson); createPreviewRequest.setJsonEntity(config); diff --git a/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/pivot/Pivot.java b/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/pivot/Pivot.java index a6d4b06151a0f..3edc0b281fa41 100644 --- a/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/pivot/Pivot.java +++ b/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/pivot/Pivot.java @@ -97,7 +97,15 @@ public void deduceMappings( listener.onResponse(emptyMap()); return; } - SchemaUtil.deduceMappings(client, headers, config, sourceConfig.getIndex(), sourceConfig.getRuntimeMappings(), listener); + SchemaUtil.deduceMappings( + client, + headers, + config, + sourceConfig.getIndex(), + sourceConfig.getQueryConfig().getQuery(), + sourceConfig.getRuntimeMappings(), + listener + ); } /** diff --git a/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/pivot/SchemaUtil.java b/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/pivot/SchemaUtil.java index 14259bffdb43d..5cacee644fe3c 100644 --- a/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/pivot/SchemaUtil.java +++ b/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/pivot/SchemaUtil.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.Tuple; import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.index.mapper.NumberFieldMapper; +import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.xpack.core.ClientHelper; @@ -89,6 +90,7 @@ public static Object dropFloatingPointComponentIfTypeRequiresIt(String type, dou * @param client Client from which to make requests against the cluster * @param config The PivotConfig for which to deduce destination mapping * @param sourceIndex Source index that contains the data to pivot + * @param sourceQuery Source index query to apply * @param runtimeMappings Source runtime mappings * @param listener Listener to alert on success or failure. */ @@ -97,6 +99,7 @@ public static void deduceMappings( final Map headers, final PivotConfig config, final String[] sourceIndex, + final QueryBuilder sourceQuery, final Map runtimeMappings, final ActionListener> listener ) { @@ -145,6 +148,7 @@ public static void deduceMappings( client, headers, sourceIndex, + sourceQuery, allFieldNames.values().stream().filter(Objects::nonNull).toArray(String[]::new), runtimeMappings, ActionListener.wrap( @@ -248,6 +252,7 @@ static void getSourceFieldMappings( Client client, Map headers, String[] index, + QueryBuilder query, String[] fields, Map runtimeMappings, ActionListener> listener @@ -257,6 +262,7 @@ static void getSourceFieldMappings( return; } FieldCapabilitiesRequest fieldCapabilitiesRequest = new FieldCapabilitiesRequest().indices(index) + .indexFilter(query) .fields(fields) .runtimeFields(runtimeMappings) .indicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN); diff --git a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/pivot/AggregationSchemaAndResultTests.java b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/pivot/AggregationSchemaAndResultTests.java index a2dda2a1603f1..9221dd36271f7 100644 --- a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/pivot/AggregationSchemaAndResultTests.java +++ b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/pivot/AggregationSchemaAndResultTests.java @@ -18,6 +18,7 @@ import org.elasticsearch.action.support.ActionTestUtils; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.Strings; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.TermQueryBuilder; import org.elasticsearch.search.aggregations.Aggregation; import org.elasticsearch.search.aggregations.AggregationBuilders; @@ -147,7 +148,15 @@ public void testBasic() throws InterruptedException { .count(); this.>assertAsync( - listener -> SchemaUtil.deduceMappings(client, emptyMap(), pivotConfig, new String[] { "source-index" }, emptyMap(), listener), + listener -> SchemaUtil.deduceMappings( + client, + emptyMap(), + pivotConfig, + new String[] { "source-index" }, + QueryBuilders.matchAllQuery(), + emptyMap(), + listener + ), mappings -> { assertEquals("Mappings were: " + mappings, numGroupsWithoutScripts + 15, mappings.size()); assertEquals("long", mappings.get("max_rating")); @@ -219,7 +228,15 @@ public void testNested() throws InterruptedException { .count(); this.>assertAsync( - listener -> SchemaUtil.deduceMappings(client, emptyMap(), pivotConfig, new String[] { "source-index" }, emptyMap(), listener), + listener -> SchemaUtil.deduceMappings( + client, + emptyMap(), + pivotConfig, + new String[] { "source-index" }, + QueryBuilders.matchAllQuery(), + emptyMap(), + listener + ), mappings -> { assertEquals(numGroupsWithoutScripts + 12, mappings.size()); assertEquals("long", mappings.get("filter_1")); diff --git a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/pivot/SchemaUtilTests.java b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/pivot/SchemaUtilTests.java index 778ca4bf7767d..881d578cb4536 100644 --- a/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/pivot/SchemaUtilTests.java +++ b/x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/pivot/SchemaUtilTests.java @@ -17,6 +17,7 @@ import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse; import org.elasticsearch.action.support.ActionTestUtils; import org.elasticsearch.common.Strings; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.client.NoOpClient; import org.elasticsearch.threadpool.ThreadPool; @@ -104,6 +105,7 @@ public void testGetSourceFieldMappings() throws InterruptedException { client, emptyMap(), new String[] { "index-1", "index-2" }, + QueryBuilders.matchAllQuery(), null, emptyMap(), listener @@ -120,6 +122,7 @@ public void testGetSourceFieldMappings() throws InterruptedException { client, emptyMap(), new String[] { "index-1", "index-2" }, + QueryBuilders.matchAllQuery(), new String[] {}, emptyMap(), listener @@ -136,6 +139,7 @@ public void testGetSourceFieldMappings() throws InterruptedException { client, emptyMap(), null, + QueryBuilders.matchAllQuery(), new String[] { "field-1", "field-2" }, emptyMap(), listener @@ -152,6 +156,7 @@ public void testGetSourceFieldMappings() throws InterruptedException { client, emptyMap(), new String[] {}, + QueryBuilders.matchAllQuery(), new String[] { "field-1", "field-2" }, emptyMap(), listener @@ -168,6 +173,7 @@ public void testGetSourceFieldMappings() throws InterruptedException { client, emptyMap(), new String[] { "index-1", "index-2" }, + QueryBuilders.matchAllQuery(), new String[] { "field-1", "field-2" }, emptyMap(), listener @@ -196,6 +202,7 @@ public void testGetSourceFieldMappingsWithRuntimeMappings() throws InterruptedEx client, emptyMap(), new String[] { "index-1", "index-2" }, + QueryBuilders.matchAllQuery(), new String[] { "field-1", "field-2" }, runtimeMappings, listener From d9865bcdb44c69ae194ca75a490579d7b317125b Mon Sep 17 00:00:00 2001 From: Rene Groeschke Date: Tue, 21 Nov 2023 10:21:02 +0100 Subject: [PATCH 07/22] Update build cache setup to use Gradle Enterprise connector (#102180) * Update build cache setup to use Gradle Enterprise connector * Cleanup outdated artifactory setup * Remove duplicate configuration on CI * Fix bwc builds --- .ci/init.gradle | 83 +++++-------------- .../groovy/elasticsearch.build-scan.gradle | 6 -- .../gradle/internal/BwcSetupExtension.java | 5 ++ 3 files changed, 27 insertions(+), 67 deletions(-) diff --git a/.ci/init.gradle b/.ci/init.gradle index 4b2cbd1907ca0..f708fd411f5a6 100644 --- a/.ci/init.gradle +++ b/.ci/init.gradle @@ -10,8 +10,6 @@ initscript { } } -boolean USE_ARTIFACTORY = false - if (System.getenv('VAULT_ADDR') == null) { // When trying to reproduce errors outside of CI, it can be useful to allow this to just return rather than blowing up if (System.getenv('CI') == null) { @@ -50,75 +48,38 @@ final Vault vault = new Vault( .engineVersion(1) .token(vaultToken) .build() -) - .withRetries(5, 1000) +).withRetries(5, 1000) -if (USE_ARTIFACTORY) { - final Map artifactoryCredentials = vault.logical() - .read("${vaultPathPrefix}/artifactory.elstc.co") - .getData() - logger.info("Using elastic artifactory repos") - Closure configCache = { - return { - name "artifactory-gradle-release" - url "https://artifactory.elstc.co/artifactory/gradle-release" - credentials { - username artifactoryCredentials.get("username") - password artifactoryCredentials.get("token") - } - } - } - settingsEvaluated { settings -> - settings.pluginManagement { - repositories { - maven configCache() - } - } - } - projectsLoaded { - allprojects { - buildscript { - repositories { - maven configCache() - } - } - repositories { - maven configCache() - } - } - } -} - gradle.settingsEvaluated { settings -> settings.pluginManager.withPlugin("com.gradle.enterprise") { - settings.gradleEnterprise { - server = 'https://gradle-enterprise.elastic.co' - } + configureGradleEnterprise(settings) } } +void configureGradleEnterprise(def settings) { + settings.gradleEnterprise { + server = 'https://gradle-enterprise.elastic.co' + buildScan.publishAlways() + } -final String buildCacheUrl = System.getProperty('org.elasticsearch.build.cache.url') -final boolean buildCachePush = Boolean.valueOf(System.getProperty('org.elasticsearch.build.cache.push', 'false')) + def isCI = System.getenv("CI") == "true" + settings.buildCache { + local { + // Disable the local build cache in CI since we use ephemeral workers and it incurs an IO penalty + enabled = isCI == false + } + remote(settings.gradleEnterprise.buildCache) { + if (isCI) { + final boolean buildCachePush = Boolean.valueOf(System.getProperty('org.elasticsearch.build.cache.push', 'false')) + final Map buildCacheCredentials = System.getenv("GRADLE_BUILD_CACHE_USERNAME") ? [:] : vault.logical() + .read("${vaultPathPrefix}/gradle-build-cache") + .getData() + def username = System.getenv("GRADLE_BUILD_CACHE_USERNAME") ?: buildCacheCredentials.get("username") + def password = System.getenv("GRADLE_BUILD_CACHE_PASSWORD") ?: buildCacheCredentials.get("password") -if (buildCacheUrl) { - final Map buildCacheCredentials = System.getenv("GRADLE_BUILD_CACHE_USERNAME") ? [:] : vault.logical() - .read("${vaultPathPrefix}/gradle-build-cache") - .getData() - gradle.settingsEvaluated { settings -> - settings.buildCache { - local { - // Disable the local build cache in CI since we use ephemeral workers and it incurs an IO penalty - enabled = false - } - remote(HttpBuildCache) { - url = buildCacheUrl push = buildCachePush - credentials { - username = System.getenv("GRADLE_BUILD_CACHE_USERNAME") ?: buildCacheCredentials.get("username") - password = System.getenv("GRADLE_BUILD_CACHE_PASSWORD") ?: buildCacheCredentials.get("password") - } + usernameAndPassword(username, password) } } } diff --git a/build-tools-internal/src/main/groovy/elasticsearch.build-scan.gradle b/build-tools-internal/src/main/groovy/elasticsearch.build-scan.gradle index 2bb00faae38be..f1bd3017ced68 100644 --- a/build-tools-internal/src/main/groovy/elasticsearch.build-scan.gradle +++ b/build-tools-internal/src/main/groovy/elasticsearch.build-scan.gradle @@ -15,12 +15,6 @@ buildScan { URL jenkinsUrl = System.getenv('JENKINS_URL') ? new URL(System.getenv('JENKINS_URL')) : null String buildKiteUrl = System.getenv('BUILDKITE_BUILD_URL') ? System.getenv('BUILDKITE_BUILD_URL') : null - // Automatically publish scans from Elasticsearch CI - if (jenkinsUrl?.host?.endsWith('elastic.co') || jenkinsUrl?.host?.endsWith('elastic.dev') || System.getenv('BUILDKITE') == 'true') { - publishAlways() - buildScan.server = 'https://gradle-enterprise.elastic.co' - } - background { tag OS.current().name() tag Architecture.current().name() diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/BwcSetupExtension.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/BwcSetupExtension.java index d71c893cdd20f..fcb9a3528483e 100644 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/BwcSetupExtension.java +++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/BwcSetupExtension.java @@ -83,6 +83,11 @@ private TaskProvider createRunBwcGradleTask( return getJavaHome(Integer.parseInt(minimumCompilerVersion)); })); + // temporally workaround for reworked gradle enterprise setup + // removed when PR https://github.com/elastic/elasticsearch/pull/102180 backported + // to all BWC branches + loggedExec.getEnvironment().put("BUILDKITE", "false"); + if (BuildParams.isCi() && OS.current() != OS.WINDOWS) { // TODO: Disabled for now until we can figure out why files are getting corrupted // loggedExec.getEnvironment().put("GRADLE_RO_DEP_CACHE", System.getProperty("user.home") + "/gradle_ro_cache"); From 8b13e1d07fe599ad7ee512eaad3116f0c6868d47 Mon Sep 17 00:00:00 2001 From: David Turner Date: Tue, 21 Nov 2023 09:46:24 +0000 Subject: [PATCH 08/22] Fix testSnapshotWithStuckNode (#102398) This test sometimes relies on repository cleanup to remove all but the `index.latest` and `index-N` blobs, but in fact repo cleanup will leave behind the `index-(N-1)` blob too. This commit relaxes the test to account for this, but then strengthens it to assert that the blobs left in the repo are exactly the ones we expect. Closes #101573 --- .../DedicatedClusterSnapshotRestoreIT.java | 27 +++++++++++++++---- .../AbstractSnapshotIntegTestCase.java | 7 ++++- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/server/src/internalClusterTest/java/org/elasticsearch/snapshots/DedicatedClusterSnapshotRestoreIT.java b/server/src/internalClusterTest/java/org/elasticsearch/snapshots/DedicatedClusterSnapshotRestoreIT.java index 4b02e26815524..1d6ed5d5a177c 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/snapshots/DedicatedClusterSnapshotRestoreIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/snapshots/DedicatedClusterSnapshotRestoreIT.java @@ -92,6 +92,7 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertFutureThrows; import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.either; import static org.hamcrest.Matchers.empty; @@ -104,6 +105,7 @@ import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.startsWith; @ClusterScope(scope = Scope.TEST, numDataNodes = 0) public class DedicatedClusterSnapshotRestoreIT extends AbstractSnapshotIntegTestCase { @@ -199,13 +201,28 @@ public void testSnapshotWithStuckNode() throws Exception { () -> clusterAdmin().prepareGetSnapshots("test-repo").setSnapshots("test-snap").execute().actionGet() ); - logger.info("--> Go through a loop of creating and deleting a snapshot to trigger repository cleanup"); + logger.info("--> trigger repository cleanup"); clusterAdmin().prepareCleanupRepository("test-repo").get(); - // Expect two files to remain in the repository: - // (1) index-(N+1) - // (2) index-latest - assertFileCount(repo, 2); + // Expect two or three files to remain in the repository: + // (1) index-latest + // (2) index-(N+1) + // (3) index-N (maybe: a fully successful deletion removes this, but cleanup does not, see #100718) + + final var blobPaths = getAllFilesInDirectoryAndDescendants(repo); + final var blobPathsString = blobPaths.toString(); + assertTrue(blobPathsString, blobPaths.remove(repo.resolve(BlobStoreRepository.INDEX_LATEST_BLOB))); + assertThat(blobPathsString, blobPaths, anyOf(hasSize(1), hasSize(2))); + final var repoGenerations = blobPaths.stream().mapToLong(blobPath -> { + final var blobName = repo.relativize(blobPath).toString(); + assertThat(blobPathsString, blobName, startsWith(BlobStoreRepository.INDEX_FILE_PREFIX)); + return Long.parseLong(blobName.substring(BlobStoreRepository.INDEX_FILE_PREFIX.length())); + }).toArray(); + + if (repoGenerations.length == 2) { + assertEquals(blobPathsString, 1, Math.abs(repoGenerations[0] - repoGenerations[1])); + } + logger.info("--> done"); } diff --git a/test/framework/src/main/java/org/elasticsearch/snapshots/AbstractSnapshotIntegTestCase.java b/test/framework/src/main/java/org/elasticsearch/snapshots/AbstractSnapshotIntegTestCase.java index 3d1526943c055..6c1b24538cd9f 100644 --- a/test/framework/src/main/java/org/elasticsearch/snapshots/AbstractSnapshotIntegTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/snapshots/AbstractSnapshotIntegTestCase.java @@ -196,9 +196,14 @@ public static long getFailureCount(String repository) { } public static void assertFileCount(Path dir, int expectedCount) throws IOException { + final List found = getAllFilesInDirectoryAndDescendants(dir); + assertEquals("Unexpected file count, found: [" + found + "].", expectedCount, found.size()); + } + + protected static List getAllFilesInDirectoryAndDescendants(Path dir) throws IOException { final List found = new ArrayList<>(); forEachFileRecursively(dir, ((path, basicFileAttributes) -> found.add(path))); - assertEquals("Unexpected file count, found: [" + found + "].", expectedCount, found.size()); + return found; } protected void stopNode(final String node) throws IOException { From 0d4091f78275fb1540ec4e714bdde17c0b5b0242 Mon Sep 17 00:00:00 2001 From: David Turner Date: Tue, 21 Nov 2023 10:33:19 +0000 Subject: [PATCH 09/22] Simplify callers of IndexShard#openEngineAndRecoverFromTranslog (#102344) Reduces the size & complexity of the changes proposed in #96774; relates #96767, #96607 --- .../index/shard/StoreRecovery.java | 143 ++++++++++-------- .../index/shard/IndexShardTests.java | 6 +- 2 files changed, 85 insertions(+), 64 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/shard/StoreRecovery.java b/server/src/main/java/org/elasticsearch/index/shard/StoreRecovery.java index ded3ffa4ebcc0..bc5a4b02116a7 100644 --- a/server/src/main/java/org/elasticsearch/index/shard/StoreRecovery.java +++ b/server/src/main/java/org/elasticsearch/index/shard/StoreRecovery.java @@ -31,12 +31,12 @@ import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.Maps; import org.elasticsearch.common.util.concurrent.EsExecutors; -import org.elasticsearch.common.util.concurrent.ListenableFuture; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.engine.Engine; -import org.elasticsearch.index.engine.EngineException; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.seqno.SequenceNumbers; import org.elasticsearch.index.snapshots.IndexShardRestoreFailedException; @@ -48,6 +48,7 @@ import org.elasticsearch.threadpool.ThreadPool; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -308,7 +309,7 @@ void recoverFromRepository(final IndexShard indexShard, Repository repository, A RecoverySource.Type recoveryType = indexShard.recoveryState().getRecoverySource().getType(); assert recoveryType == RecoverySource.Type.SNAPSHOT : "expected snapshot recovery type: " + recoveryType; SnapshotRecoverySource recoverySource = (SnapshotRecoverySource) indexShard.recoveryState().getRecoverySource(); - restore(indexShard, repository, recoverySource, recoveryListener(indexShard, listener)); + restore(indexShard, repository, recoverySource, recoveryListener(indexShard, listener).map(ignored -> true)); } else { listener.onResponse(false); } @@ -408,15 +409,19 @@ private ActionListener recoveryListener(IndexShard indexShard, ActionLi * Recovers the state of the shard from the store. */ private void internalRecoverFromStore(IndexShard indexShard, ActionListener outerListener) { - indexShard.preRecovery(outerListener.delegateFailureAndWrap((listener, ignored) -> { - final RecoveryState recoveryState = indexShard.recoveryState(); - final boolean indexShouldExists = recoveryState.getRecoverySource().getType() != RecoverySource.Type.EMPTY_STORE; - indexShard.prepareForIndexRecovery(); - SegmentInfos si = null; - final Store store = indexShard.store(); - store.incRef(); - boolean triggeredPostRecovery = false; - try { + final List releasables = new ArrayList<>(1); + SubscribableListener + + .newForked(indexShard::preRecovery) + + .andThen((l, ignored) -> { + final RecoveryState recoveryState = indexShard.recoveryState(); + final boolean indexShouldExists = recoveryState.getRecoverySource().getType() != RecoverySource.Type.EMPTY_STORE; + indexShard.prepareForIndexRecovery(); + SegmentInfos si = null; + final Store store = indexShard.store(); + store.incRef(); + releasables.add(store::decRef); try { store.failIfCorrupted(); try { @@ -480,16 +485,16 @@ private void internalRecoverFromStore(IndexShard indexShard, ActionListener { + if (e instanceof IndexShardRecoveryException) { + l.onFailure(e); + } else { + l.onFailure(new IndexShardRecoveryException(shardId, "failed to recover from gateway", e)); } - } - })); + }), () -> Releasables.close(releasables))); } private static void writeEmptyRetentionLeasesFile(IndexShard indexShard) throws IOException { @@ -513,31 +518,24 @@ private void restore( IndexShard indexShard, Repository repository, SnapshotRecoverySource restoreSource, - ActionListener outerListener + ActionListener outerListener ) { logger.debug("restoring from {} ...", indexShard.recoveryState().getRecoverySource()); - indexShard.preRecovery(outerListener.delegateFailure((listener, ignored) -> { - final RecoveryState.Translog translogState = indexShard.recoveryState().getTranslog(); - if (restoreSource == null) { - listener.onFailure(new IndexShardRestoreFailedException(shardId, "empty restore source")); - return; - } - if (logger.isTraceEnabled()) { - logger.trace("[{}] restoring shard [{}]", restoreSource.snapshot(), shardId); - } - final ActionListener restoreListener = ActionListener.wrap(v -> { - indexShard.getIndexEventListener().afterFilesRestoredFromRepository(indexShard); - final Store store = indexShard.store(); - bootstrap(indexShard, store); - assert indexShard.shardRouting.primary() : "only primary shards can recover from store"; - writeEmptyRetentionLeasesFile(indexShard); - indexShard.openEngineAndRecoverFromTranslog(); - indexShard.getEngine().fillSeqNoGaps(indexShard.getPendingPrimaryTerm()); - indexShard.finalizeRecovery(); - indexShard.postRecovery("restore done", listener.map(voidValue -> true)); - }, e -> listener.onFailure(new IndexShardRestoreFailedException(shardId, "restore failed", e))); - try { + record ShardAndIndexIds(IndexId indexId, ShardId shardId) {} + + SubscribableListener + + .newForked(indexShard::preRecovery) + + .andThen((shardAndIndexIdsListener, ignored) -> { + final RecoveryState.Translog translogState = indexShard.recoveryState().getTranslog(); + if (restoreSource == null) { + throw new IndexShardRestoreFailedException(shardId, "empty restore source"); + } + if (logger.isTraceEnabled()) { + logger.trace("[{}] restoring shard [{}]", restoreSource.snapshot(), shardId); + } translogState.totalOperations(0); translogState.totalOperationsOnStart(0); indexShard.prepareForIndexRecovery(); @@ -548,37 +546,56 @@ private void restore( } else { snapshotShardId = new ShardId(indexId.getName(), IndexMetadata.INDEX_UUID_NA_VALUE, shardId.id()); } - final ListenableFuture indexIdListener = new ListenableFuture<>(); - // If the index UUID was not found in the recovery source we will have to load RepositoryData and resolve it by index name if (indexId.getId().equals(IndexMetadata.INDEX_UUID_NA_VALUE)) { - // BwC path, running against an old version master that did not add the IndexId to the recovery source + // BwC path, running against an old version master that did not add the IndexId to the recovery source. If the index + // UUID was not found in the recovery source we will have to load RepositoryData and resolve it by index name repository.getRepositoryData( // TODO no need to fork back to GENERIC if using cached repo data, see #101445 EsExecutors.DIRECT_EXECUTOR_SERVICE, new ThreadedActionListener<>( indexShard.getThreadPool().generic(), - indexIdListener.map(repositoryData -> repositoryData.resolveIndexId(indexId.getName())) + shardAndIndexIdsListener.map( + repositoryData -> new ShardAndIndexIds(repositoryData.resolveIndexId(indexId.getName()), snapshotShardId) + ) ) ); } else { - indexIdListener.onResponse(indexId); + shardAndIndexIdsListener.onResponse(new ShardAndIndexIds(indexId, snapshotShardId)); } + }) + + .andThen((restoreListener, shardAndIndexId) -> { assert indexShard.getEngineOrNull() == null; - indexIdListener.addListener(restoreListener.delegateFailureAndWrap((l, idx) -> { - assert ThreadPool.assertCurrentThreadPool(ThreadPool.Names.GENERIC, ThreadPool.Names.SNAPSHOT); - repository.restoreShard( - indexShard.store(), - restoreSource.snapshot().getSnapshotId(), - idx, - snapshotShardId, - indexShard.recoveryState(), - l - ); - })); - } catch (Exception e) { - restoreListener.onFailure(e); - } - })); + assert ThreadPool.assertCurrentThreadPool(ThreadPool.Names.GENERIC, ThreadPool.Names.SNAPSHOT); + repository.restoreShard( + indexShard.store(), + restoreSource.snapshot().getSnapshotId(), + shardAndIndexId.indexId(), + shardAndIndexId.shardId(), + indexShard.recoveryState(), + restoreListener + ); + }) + + .andThen((l, ignored) -> { + indexShard.getIndexEventListener().afterFilesRestoredFromRepository(indexShard); + final Store store = indexShard.store(); + bootstrap(indexShard, store); + assert indexShard.shardRouting.primary() : "only primary shards can recover from store"; + writeEmptyRetentionLeasesFile(indexShard); + indexShard.openEngineAndRecoverFromTranslog(); + indexShard.getEngine().fillSeqNoGaps(indexShard.getPendingPrimaryTerm()); + indexShard.finalizeRecovery(); + indexShard.postRecovery("restore done", l); + }) + + .addListener(outerListener.delegateResponse((l, e) -> { + if (e instanceof IndexShardRestoreFailedException) { + l.onFailure(e); + } else { + l.onFailure(new IndexShardRestoreFailedException(shardId, "restore failed", e)); + } + })); } public static void bootstrap(final IndexShard indexShard, final Store store) throws IOException { diff --git a/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java b/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java index 3d1337af4edfb..69cf587793e55 100644 --- a/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java +++ b/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java @@ -3536,7 +3536,11 @@ public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IO IndexShardRecoveryException.class, () -> newStartedShard(p -> corruptedShard, true) ); - assertThat(indexShardRecoveryException.getMessage(), equalTo("failed recovery")); + assertThat(indexShardRecoveryException.getMessage(), equalTo("failed to recover from gateway")); + assertThat( + asInstanceOf(RecoveryFailedException.class, indexShardRecoveryException.getCause()).getMessage(), + containsString("Recovery failed") + ); appender.assertAllExpectationsMatched(); } finally { From 269439e3460e46df9254929b76c5201e42ad3450 Mon Sep 17 00:00:00 2001 From: David Turner Date: Tue, 21 Nov 2023 10:36:43 +0000 Subject: [PATCH 10/22] AwaitsFix for #102405 --- .../xpack/autoscaling/existence/FrozenExistenceDeciderIT.java | 1 + 1 file changed, 1 insertion(+) diff --git a/x-pack/plugin/autoscaling/src/internalClusterTest/java/org/elasticsearch/xpack/autoscaling/existence/FrozenExistenceDeciderIT.java b/x-pack/plugin/autoscaling/src/internalClusterTest/java/org/elasticsearch/xpack/autoscaling/existence/FrozenExistenceDeciderIT.java index 60c6732b45000..497734fd5ac28 100644 --- a/x-pack/plugin/autoscaling/src/internalClusterTest/java/org/elasticsearch/xpack/autoscaling/existence/FrozenExistenceDeciderIT.java +++ b/x-pack/plugin/autoscaling/src/internalClusterTest/java/org/elasticsearch/xpack/autoscaling/existence/FrozenExistenceDeciderIT.java @@ -82,6 +82,7 @@ protected Collection> nodePlugins() { ); } + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/102405") public void testZeroToOne() throws Exception { internalCluster().startMasterOnlyNode(); setupRepoAndPolicy(); From 4f4974fae82776cb5bd4d870a1fddffbd97c4d3e Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Tue, 21 Nov 2023 12:08:34 +0100 Subject: [PATCH 11/22] Ref count fetch phase results (#102324) Making the fetch result ref counted as the second to last step to making `SearchHit` ref counted. * Fix bug in existing ref counting * Make search results collector ref counted * Misc. small adjustments to counting * Closing/decrementing a bunch of things in tests --- .../search/AbstractSearchAsyncAction.java | 9 +- .../search/ArraySearchPhaseResults.java | 32 + .../search/CanMatchPreFilterSearchPhase.java | 20 + .../action/search/CountedCollector.java | 7 +- .../action/search/DfsQueryPhase.java | 2 +- .../action/search/FetchSearchPhase.java | 1 + .../search/QueryPhaseResultConsumer.java | 10 +- .../SearchDfsQueryThenFetchAsyncAction.java | 3 +- .../action/search/SearchPhaseResults.java | 3 +- .../SearchQueryThenFetchAsyncAction.java | 3 - ...SearchScrollQueryThenFetchAsyncAction.java | 18 +- .../search/DefaultSearchContext.java | 1 + .../elasticsearch/search/SearchService.java | 6 +- .../search/fetch/FetchSearchResult.java | 25 + .../search/fetch/QueryFetchSearchResult.java | 12 +- .../search/internal/SubSearchContext.java | 1 + .../AbstractSearchAsyncActionTests.java | 68 +- .../action/search/CountedCollectorTests.java | 122 +- .../action/search/DfsQueryPhaseTests.java | 122 +- .../action/search/FetchSearchPhaseTests.java | 999 +++++++------- .../action/search/MockSearchPhaseContext.java | 12 +- .../search/QueryPhaseResultConsumerTests.java | 38 +- .../action/search/SearchAsyncActionTests.java | 415 +++--- .../search/SearchPhaseControllerTests.java | 1193 +++++++++-------- .../SearchQueryThenFetchAsyncActionTests.java | 102 +- .../SearchProfileResultsBuilderTests.java | 57 +- 26 files changed, 1785 insertions(+), 1496 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java index b56cb0ca5926c..426553769e8a1 100644 --- a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java @@ -160,13 +160,16 @@ abstract class AbstractSearchAsyncAction exten this.executor = executor; this.request = request; this.task = task; - this.listener = ActionListener.runAfter(listener, this::releaseContext); + this.listener = ActionListener.runAfter(listener, () -> Releasables.close(releasables)); this.nodeIdToConnection = nodeIdToConnection; this.concreteIndexBoosts = concreteIndexBoosts; this.clusterStateVersion = clusterState.version(); this.minTransportVersion = clusterState.getMinTransportVersion(); this.aliasFilter = aliasFilter; this.results = resultConsumer; + // register the release of the query consumer to free up the circuit breaker memory + // at the end of the search + addReleasable(resultConsumer::decRef); this.clusters = clusters; } @@ -189,10 +192,6 @@ public void addReleasable(Releasable releasable) { releasables.add(releasable); } - public void releaseContext() { - Releasables.close(releasables); - } - /** * Builds how long it took to execute the search. */ diff --git a/server/src/main/java/org/elasticsearch/action/search/ArraySearchPhaseResults.java b/server/src/main/java/org/elasticsearch/action/search/ArraySearchPhaseResults.java index 9f61042320f3e..b4fd0107f731f 100644 --- a/server/src/main/java/org/elasticsearch/action/search/ArraySearchPhaseResults.java +++ b/server/src/main/java/org/elasticsearch/action/search/ArraySearchPhaseResults.java @@ -9,7 +9,10 @@ package org.elasticsearch.action.search; import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.core.AbstractRefCounted; +import org.elasticsearch.core.RefCounted; import org.elasticsearch.search.SearchPhaseResult; +import org.elasticsearch.transport.LeakTracker; import java.util.stream.Stream; @@ -19,6 +22,8 @@ class ArraySearchPhaseResults extends SearchPhaseResults { final AtomicArray results; + private final RefCounted refCounted = LeakTracker.wrap(AbstractRefCounted.of(this::doClose)); + ArraySearchPhaseResults(int size) { super(size); this.results = new AtomicArray<>(size); @@ -32,9 +37,16 @@ Stream getSuccessfulResults() { void consumeResult(Result result, Runnable next) { assert results.get(result.getShardIndex()) == null : "shardIndex: " + result.getShardIndex() + " is already set"; results.set(result.getShardIndex(), result); + result.incRef(); next.run(); } + protected void doClose() { + for (Result result : getAtomicArray().asList()) { + result.decRef(); + } + } + boolean hasResult(int shardIndex) { return results.get(shardIndex) != null; } @@ -43,4 +55,24 @@ boolean hasResult(int shardIndex) { AtomicArray getAtomicArray() { return results; } + + @Override + public void incRef() { + refCounted.incRef(); + } + + @Override + public boolean tryIncRef() { + return refCounted.tryIncRef(); + } + + @Override + public boolean decRef() { + return refCounted.decRef(); + } + + @Override + public boolean hasReferences() { + return refCounted.hasReferences(); + } } diff --git a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java index 6e553f254ee8b..db17136ae29af 100644 --- a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java @@ -480,6 +480,26 @@ synchronized FixedBitSet getPossibleMatches() { Stream getSuccessfulResults() { return Stream.empty(); } + + @Override + public void incRef() { + + } + + @Override + public boolean tryIncRef() { + return false; + } + + @Override + public boolean decRef() { + return false; + } + + @Override + public boolean hasReferences() { + return false; + } } private GroupShardsIterator getIterator( diff --git a/server/src/main/java/org/elasticsearch/action/search/CountedCollector.java b/server/src/main/java/org/elasticsearch/action/search/CountedCollector.java index 34b33770efd55..d5605b280f385 100644 --- a/server/src/main/java/org/elasticsearch/action/search/CountedCollector.java +++ b/server/src/main/java/org/elasticsearch/action/search/CountedCollector.java @@ -25,6 +25,7 @@ final class CountedCollector { CountedCollector(ArraySearchPhaseResults resultConsumer, int expectedOps, Runnable onFinish, SearchPhaseContext context) { this.resultConsumer = resultConsumer; + resultConsumer.incRef(); this.counter = new CountDown(expectedOps); this.onFinish = onFinish; this.context = context; @@ -37,7 +38,11 @@ final class CountedCollector { void countDown() { assert counter.isCountedDown() == false : "more operations executed than specified"; if (counter.countDown()) { - onFinish.run(); + try { + onFinish.run(); + } finally { + resultConsumer.decRef(); + } } } diff --git a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java index 128281ead4046..ce2c86be4b4e6 100644 --- a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java @@ -66,7 +66,7 @@ final class DfsQueryPhase extends SearchPhase { // register the release of the query consumer to free up the circuit breaker memory // at the end of the search - context.addReleasable(queryResult); + context.addReleasable(queryResult::decRef); } @Override diff --git a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java index afffde13cf641..e8d3ded154f55 100644 --- a/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java @@ -67,6 +67,7 @@ final class FetchSearchPhase extends SearchPhase { ); } this.fetchResults = new ArraySearchPhaseResults<>(resultConsumer.getNumShards()); + context.addReleasable(fetchResults::decRef); this.queryResults = resultConsumer.getAtomicArray(); this.aggregatedDfs = aggregatedDfs; this.nextPhaseFactory = nextPhaseFactory; diff --git a/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java b/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java index 73061298d8f7e..fdd543fc8758f 100644 --- a/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java +++ b/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java @@ -51,7 +51,7 @@ * needed to reduce the aggregations is estimated and a {@link CircuitBreakingException} is thrown if it * exceeds the maximum memory allowed in this breaker. */ -public class QueryPhaseResultConsumer extends ArraySearchPhaseResults implements Releasable { +public class QueryPhaseResultConsumer extends ArraySearchPhaseResults { private static final Logger logger = LogManager.getLogger(QueryPhaseResultConsumer.class); private final Executor executor; @@ -104,8 +104,12 @@ public QueryPhaseResultConsumer( } @Override - public void close() { - pendingMerges.close(); + protected void doClose() { + try { + super.doClose(); + } finally { + pendingMerges.close(); + } } @Override diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java index c7ad250892160..2fcb792f821c9 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java @@ -64,6 +64,7 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction clusters ); this.queryPhaseResultConsumer = queryPhaseResultConsumer; + addReleasable(queryPhaseResultConsumer::decRef); this.progressListener = task.getProgressListener(); // don't build the SearchShard list (can be expensive) if the SearchProgressListener won't use it if (progressListener != SearchProgressListener.NOOP) { @@ -90,7 +91,7 @@ protected SearchPhase getNextPhase(final SearchPhaseResults res final List dfsSearchResults = results.getAtomicArray().asList(); final AggregatedDfs aggregatedDfs = SearchPhaseController.aggregateDfs(dfsSearchResults); final List mergedKnnResults = SearchPhaseController.mergeKnnResults(getRequest(), dfsSearchResults); - + queryPhaseResultConsumer.incRef(); return new DfsQueryPhase( dfsSearchResults, aggregatedDfs, diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseResults.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseResults.java index edabbc86b4b31..11b8e0a0792a3 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseResults.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseResults.java @@ -9,6 +9,7 @@ package org.elasticsearch.action.search; import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.core.RefCounted; import org.elasticsearch.search.SearchPhaseResult; import java.util.stream.Stream; @@ -16,7 +17,7 @@ /** * This class acts as a basic result collection that can be extended to do on-the-fly reduction or result processing */ -abstract class SearchPhaseResults { +abstract class SearchPhaseResults implements RefCounted { private final int numShards; SearchPhaseResults(int numShards) { diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java index 2dfd46182266c..8cf4ee9b75f76 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -74,9 +74,6 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction hits = null)); + public FetchSearchResult() {} public FetchSearchResult(ShardSearchContextId id, SearchShardTarget shardTarget) { @@ -90,4 +95,24 @@ public int counterGetAndIncrement() { public ProfileResult profileResult() { return profileResult; } + + @Override + public void incRef() { + refCounted.incRef(); + } + + @Override + public boolean tryIncRef() { + return refCounted.tryIncRef(); + } + + @Override + public boolean decRef() { + return refCounted.decRef(); + } + + @Override + public boolean hasReferences() { + return refCounted.hasReferences(); + } } diff --git a/server/src/main/java/org/elasticsearch/search/fetch/QueryFetchSearchResult.java b/server/src/main/java/org/elasticsearch/search/fetch/QueryFetchSearchResult.java index 193f8c04664bf..bb838c29ff54c 100644 --- a/server/src/main/java/org/elasticsearch/search/fetch/QueryFetchSearchResult.java +++ b/server/src/main/java/org/elasticsearch/search/fetch/QueryFetchSearchResult.java @@ -26,17 +26,21 @@ public final class QueryFetchSearchResult extends SearchPhaseResult { private final FetchSearchResult fetchResult; private final RefCounted refCounted; + public static QueryFetchSearchResult of(QuerySearchResult queryResult, FetchSearchResult fetchResult) { + // We're acquiring a copy, we should incRef it + queryResult.incRef(); + fetchResult.incRef(); + return new QueryFetchSearchResult(queryResult, fetchResult); + } + public QueryFetchSearchResult(StreamInput in) throws IOException { // These get a ref count of 1 when we create them, so we don't need to incRef here this(new QuerySearchResult(in), new FetchSearchResult(in)); } - public QueryFetchSearchResult(QuerySearchResult queryResult, FetchSearchResult fetchResult) { + private QueryFetchSearchResult(QuerySearchResult queryResult, FetchSearchResult fetchResult) { this.queryResult = queryResult; this.fetchResult = fetchResult; - // We're acquiring a copy, we should incRef it - this.queryResult.incRef(); - this.fetchResult.incRef(); refCounted = LeakTracker.wrap(AbstractRefCounted.of(() -> { queryResult.decRef(); fetchResult.decRef(); diff --git a/server/src/main/java/org/elasticsearch/search/internal/SubSearchContext.java b/server/src/main/java/org/elasticsearch/search/internal/SubSearchContext.java index 78a218bb3cd1b..8567677aca30a 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/SubSearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/internal/SubSearchContext.java @@ -58,6 +58,7 @@ public SubSearchContext(SearchContext context) { super(context); context.addReleasable(this); this.fetchSearchResult = new FetchSearchResult(); + addReleasable(fetchSearchResult::decRef); this.querySearchResult = new QuerySearchResult(); } diff --git a/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java index 5cbb7ab228c0c..7f5b5f7716f3e 100644 --- a/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java @@ -134,47 +134,45 @@ public void testTookWithRealClock() { private void runTestTook(final boolean controlled) { final AtomicLong expected = new AtomicLong(); - AbstractSearchAsyncAction action = createAction( - new SearchRequest(), - new ArraySearchPhaseResults<>(10), - null, - controlled, - expected - ); - final long actual = action.buildTookInMillis(); - if (controlled) { - // with a controlled clock, we can assert the exact took time - assertThat(actual, equalTo(TimeUnit.NANOSECONDS.toMillis(expected.get()))); - } else { - // with a real clock, the best we can say is that it took as long as we spun for - assertThat(actual, greaterThanOrEqualTo(TimeUnit.NANOSECONDS.toMillis(expected.get()))); + var result = new ArraySearchPhaseResults<>(10); + try { + AbstractSearchAsyncAction action = createAction(new SearchRequest(), result, null, controlled, expected); + final long actual = action.buildTookInMillis(); + if (controlled) { + // with a controlled clock, we can assert the exact took time + assertThat(actual, equalTo(TimeUnit.NANOSECONDS.toMillis(expected.get()))); + } else { + // with a real clock, the best we can say is that it took as long as we spun for + assertThat(actual, greaterThanOrEqualTo(TimeUnit.NANOSECONDS.toMillis(expected.get()))); + } + } finally { + result.decRef(); } } public void testBuildShardSearchTransportRequest() { SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(randomBoolean()); final AtomicLong expected = new AtomicLong(); - AbstractSearchAsyncAction action = createAction( - searchRequest, - new ArraySearchPhaseResults<>(10), - null, - false, - expected - ); - String clusterAlias = randomBoolean() ? null : randomAlphaOfLengthBetween(5, 10); - SearchShardIterator iterator = new SearchShardIterator( - clusterAlias, - new ShardId(new Index("name", "foo"), 1), - Collections.emptyList(), - new OriginalIndices(new String[] { "name", "name1" }, IndicesOptions.strictExpand()) - ); - ShardSearchRequest shardSearchTransportRequest = action.buildShardSearchRequest(iterator, 10); - assertEquals(IndicesOptions.strictExpand(), shardSearchTransportRequest.indicesOptions()); - assertArrayEquals(new String[] { "name", "name1" }, shardSearchTransportRequest.indices()); - assertEquals(new MatchAllQueryBuilder(), shardSearchTransportRequest.getAliasFilter().getQueryBuilder()); - assertEquals(2.0f, shardSearchTransportRequest.indexBoost(), 0.0f); - assertArrayEquals(new String[] { "name", "name1" }, shardSearchTransportRequest.indices()); - assertEquals(clusterAlias, shardSearchTransportRequest.getClusterAlias()); + var result = new ArraySearchPhaseResults<>(10); + try { + AbstractSearchAsyncAction action = createAction(searchRequest, result, null, false, expected); + String clusterAlias = randomBoolean() ? null : randomAlphaOfLengthBetween(5, 10); + SearchShardIterator iterator = new SearchShardIterator( + clusterAlias, + new ShardId(new Index("name", "foo"), 1), + Collections.emptyList(), + new OriginalIndices(new String[] { "name", "name1" }, IndicesOptions.strictExpand()) + ); + ShardSearchRequest shardSearchTransportRequest = action.buildShardSearchRequest(iterator, 10); + assertEquals(IndicesOptions.strictExpand(), shardSearchTransportRequest.indicesOptions()); + assertArrayEquals(new String[] { "name", "name1" }, shardSearchTransportRequest.indices()); + assertEquals(new MatchAllQueryBuilder(), shardSearchTransportRequest.getAliasFilter().getQueryBuilder()); + assertEquals(2.0f, shardSearchTransportRequest.indexBoost(), 0.0f); + assertArrayEquals(new String[] { "name", "name1" }, shardSearchTransportRequest.indices()); + assertEquals(clusterAlias, shardSearchTransportRequest.getClusterAlias()); + } finally { + result.decRef(); + } } public void testSendSearchResponseDisallowPartialFailures() { diff --git a/server/src/test/java/org/elasticsearch/action/search/CountedCollectorTests.java b/server/src/test/java/org/elasticsearch/action/search/CountedCollectorTests.java index 659d4de8552c3..838e13d6026c7 100644 --- a/server/src/test/java/org/elasticsearch/action/search/CountedCollectorTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/CountedCollectorTests.java @@ -24,73 +24,77 @@ public class CountedCollectorTests extends ESTestCase { public void testCollect() throws InterruptedException { ArraySearchPhaseResults consumer = new ArraySearchPhaseResults<>(randomIntBetween(1, 100)); - List state = new ArrayList<>(); - int numResultsExpected = randomIntBetween(1, consumer.getAtomicArray().length()); - MockSearchPhaseContext context = new MockSearchPhaseContext(consumer.getAtomicArray().length()); - CountDownLatch latch = new CountDownLatch(1); - boolean maybeFork = randomBoolean(); - Executor executor = (runnable) -> { - if (randomBoolean() && maybeFork) { - new Thread(runnable).start(); + try { + List state = new ArrayList<>(); + int numResultsExpected = randomIntBetween(1, consumer.getAtomicArray().length()); + MockSearchPhaseContext context = new MockSearchPhaseContext(consumer.getAtomicArray().length()); + CountDownLatch latch = new CountDownLatch(1); + boolean maybeFork = randomBoolean(); + Executor executor = (runnable) -> { + if (randomBoolean() && maybeFork) { + new Thread(runnable).start(); - } else { - runnable.run(); - } - }; - CountedCollector collector = new CountedCollector<>(consumer, numResultsExpected, latch::countDown, context); - for (int i = 0; i < numResultsExpected; i++) { - int shardID = i; - switch (randomIntBetween(0, 2)) { - case 0 -> { - state.add(0); - executor.execute(() -> collector.countDown()); + } else { + runnable.run(); } - case 1 -> { - state.add(1); - executor.execute(() -> { - DfsSearchResult dfsSearchResult = new DfsSearchResult( - new ShardSearchContextId(UUIDs.randomBase64UUID(), shardID), - null, - null + }; + CountedCollector collector = new CountedCollector<>(consumer, numResultsExpected, latch::countDown, context); + for (int i = 0; i < numResultsExpected; i++) { + int shardID = i; + switch (randomIntBetween(0, 2)) { + case 0 -> { + state.add(0); + executor.execute(() -> collector.countDown()); + } + case 1 -> { + state.add(1); + executor.execute(() -> { + DfsSearchResult dfsSearchResult = new DfsSearchResult( + new ShardSearchContextId(UUIDs.randomBase64UUID(), shardID), + null, + null + ); + dfsSearchResult.setShardIndex(shardID); + dfsSearchResult.setSearchShardTarget(new SearchShardTarget("foo", new ShardId("bar", "baz", shardID), null)); + collector.onResult(dfsSearchResult); + }); + } + case 2 -> { + state.add(2); + executor.execute( + () -> collector.onFailure( + shardID, + new SearchShardTarget("foo", new ShardId("bar", "baz", shardID), null), + new RuntimeException("boom") + ) ); - dfsSearchResult.setShardIndex(shardID); - dfsSearchResult.setSearchShardTarget(new SearchShardTarget("foo", new ShardId("bar", "baz", shardID), null)); - collector.onResult(dfsSearchResult); - }); - } - case 2 -> { - state.add(2); - executor.execute( - () -> collector.onFailure( - shardID, - new SearchShardTarget("foo", new ShardId("bar", "baz", shardID), null), - new RuntimeException("boom") - ) - ); + } + default -> fail("unknown state"); } - default -> fail("unknown state"); } - } - latch.await(); - assertEquals(numResultsExpected, state.size()); - AtomicArray results = consumer.getAtomicArray(); - for (int i = 0; i < numResultsExpected; i++) { - switch (state.get(i)) { - case 0 -> assertNull(results.get(i)); - case 1 -> { - assertNotNull(results.get(i)); - assertEquals(i, results.get(i).getContextId().getId()); + latch.await(); + assertEquals(numResultsExpected, state.size()); + AtomicArray results = consumer.getAtomicArray(); + for (int i = 0; i < numResultsExpected; i++) { + switch (state.get(i)) { + case 0 -> assertNull(results.get(i)); + case 1 -> { + assertNotNull(results.get(i)); + assertEquals(i, results.get(i).getContextId().getId()); + } + case 2 -> { + final int shardId = i; + assertEquals(1, context.failures.stream().filter(f -> f.shardId() == shardId).count()); + } + default -> fail("unknown state"); } - case 2 -> { - final int shardId = i; - assertEquals(1, context.failures.stream().filter(f -> f.shardId() == shardId).count()); - } - default -> fail("unknown state"); } - } - for (int i = numResultsExpected; i < results.length(); i++) { - assertNull("index: " + i, results.get(i)); + for (int i = numResultsExpected; i < results.length(); i++) { + assertNull("index: " + i, results.get(i)); + } + } finally { + consumer.decRef(); } } } diff --git a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java index b896ae3d3f025..21c1e9b0470b5 100644 --- a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java @@ -44,6 +44,10 @@ import java.util.List; import java.util.concurrent.atomic.AtomicReference; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.mock; + public class DfsQueryPhaseTests extends ESTestCase { private static DfsSearchResult newSearchResult(int shardIndex, ShardSearchContextId contextId, SearchShardTarget target) { @@ -130,26 +134,30 @@ public void sendExecuteQuery( results.length(), exc -> {} ); - DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, null, consumer, (response) -> new SearchPhase("test") { - @Override - public void run() throws IOException { - responseRef.set(response.results); - } - }, mockSearchPhaseContext); - assertEquals("dfs_query", phase.getName()); - phase.run(); - mockSearchPhaseContext.assertNoFailure(); - assertNotNull(responseRef.get()); - assertNotNull(responseRef.get().get(0)); - assertNull(responseRef.get().get(0).fetchResult()); - assertEquals(1, responseRef.get().get(0).queryResult().topDocs().topDocs.totalHits.value); - assertEquals(42, responseRef.get().get(0).queryResult().topDocs().topDocs.scoreDocs[0].doc); - assertNotNull(responseRef.get().get(1)); - assertNull(responseRef.get().get(1).fetchResult()); - assertEquals(1, responseRef.get().get(1).queryResult().topDocs().topDocs.totalHits.value); - assertEquals(84, responseRef.get().get(1).queryResult().topDocs().topDocs.scoreDocs[0].doc); - assertTrue(mockSearchPhaseContext.releasedSearchContexts.isEmpty()); - assertEquals(2, mockSearchPhaseContext.numSuccess.get()); + try { + DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, null, consumer, (response) -> new SearchPhase("test") { + @Override + public void run() throws IOException { + responseRef.set(response.results); + } + }, mockSearchPhaseContext); + assertEquals("dfs_query", phase.getName()); + phase.run(); + mockSearchPhaseContext.assertNoFailure(); + assertNotNull(responseRef.get()); + assertNotNull(responseRef.get().get(0)); + assertNull(responseRef.get().get(0).fetchResult()); + assertEquals(1, responseRef.get().get(0).queryResult().topDocs().topDocs.totalHits.value); + assertEquals(42, responseRef.get().get(0).queryResult().topDocs().topDocs.scoreDocs[0].doc); + assertNotNull(responseRef.get().get(1)); + assertNull(responseRef.get().get(1).fetchResult()); + assertEquals(1, responseRef.get().get(1).queryResult().topDocs().topDocs.totalHits.value); + assertEquals(84, responseRef.get().get(1).queryResult().topDocs().topDocs.scoreDocs[0].doc); + assertTrue(mockSearchPhaseContext.releasedSearchContexts.isEmpty()); + assertEquals(2, mockSearchPhaseContext.numSuccess.get()); + } finally { + consumer.decRef(); + } } public void testDfsWith1ShardFailed() throws IOException { @@ -212,28 +220,32 @@ public void sendExecuteQuery( results.length(), exc -> {} ); - DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, null, consumer, (response) -> new SearchPhase("test") { - @Override - public void run() throws IOException { - responseRef.set(response.results); - } - }, mockSearchPhaseContext); - assertEquals("dfs_query", phase.getName()); - phase.run(); - mockSearchPhaseContext.assertNoFailure(); - assertNotNull(responseRef.get()); - assertNotNull(responseRef.get().get(0)); - assertNull(responseRef.get().get(0).fetchResult()); - assertEquals(1, responseRef.get().get(0).queryResult().topDocs().topDocs.totalHits.value); - assertEquals(42, responseRef.get().get(0).queryResult().topDocs().topDocs.scoreDocs[0].doc); - assertNull(responseRef.get().get(1)); + try { + DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, null, consumer, (response) -> new SearchPhase("test") { + @Override + public void run() throws IOException { + responseRef.set(response.results); + } + }, mockSearchPhaseContext); + assertEquals("dfs_query", phase.getName()); + phase.run(); + mockSearchPhaseContext.assertNoFailure(); + assertNotNull(responseRef.get()); + assertNotNull(responseRef.get().get(0)); + assertNull(responseRef.get().get(0).fetchResult()); + assertEquals(1, responseRef.get().get(0).queryResult().topDocs().topDocs.totalHits.value); + assertEquals(42, responseRef.get().get(0).queryResult().topDocs().topDocs.scoreDocs[0].doc); + assertNull(responseRef.get().get(1)); - assertEquals(1, mockSearchPhaseContext.numSuccess.get()); - assertEquals(1, mockSearchPhaseContext.failures.size()); - assertTrue(mockSearchPhaseContext.failures.get(0).getCause() instanceof MockDirectoryWrapper.FakeIOException); - assertEquals(1, mockSearchPhaseContext.releasedSearchContexts.size()); - assertTrue(mockSearchPhaseContext.releasedSearchContexts.contains(new ShardSearchContextId("", 2L))); - assertNull(responseRef.get().get(1)); + assertEquals(1, mockSearchPhaseContext.numSuccess.get()); + assertEquals(1, mockSearchPhaseContext.failures.size()); + assertTrue(mockSearchPhaseContext.failures.get(0).getCause() instanceof MockDirectoryWrapper.FakeIOException); + assertEquals(1, mockSearchPhaseContext.releasedSearchContexts.size()); + assertTrue(mockSearchPhaseContext.releasedSearchContexts.contains(new ShardSearchContextId("", 2L))); + assertNull(responseRef.get().get(1)); + } finally { + consumer.decRef(); + } } public void testFailPhaseOnException() throws IOException { @@ -278,7 +290,7 @@ public void sendExecuteQuery( queryResult.decRef(); } } else if (request.contextId().getId() == 2) { - throw new UncheckedIOException(new MockDirectoryWrapper.FakeIOException()); + listener.onFailure(new UncheckedIOException(new MockDirectoryWrapper.FakeIOException())); } else { fail("no such request ID: " + request.contextId()); } @@ -296,15 +308,21 @@ public void sendExecuteQuery( results.length(), exc -> {} ); - DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, null, consumer, (response) -> new SearchPhase("test") { - @Override - public void run() throws IOException { - responseRef.set(response.results); - } - }, mockSearchPhaseContext); - assertEquals("dfs_query", phase.getName()); - expectThrows(UncheckedIOException.class, phase::run); - assertTrue(mockSearchPhaseContext.releasedSearchContexts.isEmpty()); // phase execution will clean up on the contexts + try { + DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, null, consumer, (response) -> new SearchPhase("test") { + @Override + public void run() throws IOException { + responseRef.set(response.results); + } + }, mockSearchPhaseContext); + assertEquals("dfs_query", phase.getName()); + phase.run(); + assertThat(mockSearchPhaseContext.failures, hasSize(1)); + assertThat(mockSearchPhaseContext.failures.get(0).getCause(), instanceOf(UncheckedIOException.class)); + assertThat(mockSearchPhaseContext.releasedSearchContexts, hasSize(1)); // phase execution will clean up on the contexts + } finally { + consumer.decRef(); + } } public void testRewriteShardSearchRequestWithRank() { @@ -317,7 +335,7 @@ public void testRewriteShardSearchRequestWithRank() { ); MockSearchPhaseContext mspc = new MockSearchPhaseContext(2); mspc.searchTransport = new SearchTransportService(null, null, null); - DfsQueryPhase dqp = new DfsQueryPhase(null, null, dkrs, null, null, mspc); + DfsQueryPhase dqp = new DfsQueryPhase(null, null, dkrs, mock(QueryPhaseResultConsumer.class), null, mspc); QueryBuilder bm25 = new TermQueryBuilder("field", "term"); SearchSourceBuilder ssb = new SearchSourceBuilder().query(bm25) diff --git a/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java index 3fa5c6fc4283a..3d66c4bc2793f 100644 --- a/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java @@ -38,7 +38,9 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicInteger; +import static org.hamcrest.Matchers.arrayWithSize; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.nullValue; public class FetchSearchPhaseTests extends ESTestCase { @@ -56,59 +58,71 @@ public void testShortcutQueryAndFetchOptimization() { 1, exc -> {} ); - boolean hasHits = randomBoolean(); - boolean profiled = hasHits && randomBoolean(); - final int numHits; - if (hasHits) { - QuerySearchResult queryResult = new QuerySearchResult(); - queryResult.setSearchShardTarget(new SearchShardTarget("node0", new ShardId("index", "index", 0), null)); - queryResult.topDocs( - new TopDocsAndMaxScore( - new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(42, 1.0F) }), - 1.0F - ), - new DocValueFormat[0] - ); - addProfiling(profiled, queryResult); - queryResult.size(1); - FetchSearchResult fetchResult = new FetchSearchResult(); - fetchResult.setSearchShardTarget(queryResult.getSearchShardTarget()); - SearchHits hits = new SearchHits(new SearchHit[] { new SearchHit(42) }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F); - fetchResult.shardResult(hits, fetchProfile(profiled)); - QueryFetchSearchResult fetchSearchResult = new QueryFetchSearchResult(queryResult, fetchResult); - try { - fetchSearchResult.setShardIndex(0); - results.consumeResult(fetchSearchResult, () -> {}); - } finally { - fetchSearchResult.decRef(); + try { + boolean hasHits = randomBoolean(); + boolean profiled = hasHits && randomBoolean(); + final int numHits; + if (hasHits) { + QuerySearchResult queryResult = new QuerySearchResult(); + queryResult.setSearchShardTarget(new SearchShardTarget("node0", new ShardId("index", "index", 0), null)); + queryResult.topDocs( + new TopDocsAndMaxScore( + new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(42, 1.0F) }), + 1.0F + ), + new DocValueFormat[0] + ); + addProfiling(profiled, queryResult); + queryResult.size(1); + FetchSearchResult fetchResult = new FetchSearchResult(); + try { + fetchResult.setSearchShardTarget(queryResult.getSearchShardTarget()); + SearchHits hits = new SearchHits( + new SearchHit[] { new SearchHit(42) }, + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + 1.0F + ); + fetchResult.shardResult(hits, fetchProfile(profiled)); + QueryFetchSearchResult fetchSearchResult = QueryFetchSearchResult.of(queryResult, fetchResult); + try { + fetchSearchResult.setShardIndex(0); + results.consumeResult(fetchSearchResult, () -> {}); + } finally { + fetchSearchResult.decRef(); + } + numHits = 1; + } finally { + fetchResult.decRef(); + } + } else { + numHits = 0; } - numHits = 1; - } else { - numHits = 0; - } - FetchSearchPhase phase = new FetchSearchPhase( - results, - null, - mockSearchPhaseContext, - (searchResponse, scrollId) -> new SearchPhase("test") { - @Override - public void run() { - mockSearchPhaseContext.sendSearchResponse(searchResponse, null); + FetchSearchPhase phase = new FetchSearchPhase( + results, + null, + mockSearchPhaseContext, + (searchResponse, scrollId) -> new SearchPhase("test") { + @Override + public void run() { + mockSearchPhaseContext.sendSearchResponse(searchResponse, null); + } } + ); + assertEquals("fetch", phase.getName()); + phase.run(); + mockSearchPhaseContext.assertNoFailure(); + SearchResponse searchResponse = mockSearchPhaseContext.searchResponse.get(); + assertNotNull(searchResponse); + assertEquals(numHits, searchResponse.getHits().getTotalHits().value); + if (numHits != 0) { + assertEquals(42, searchResponse.getHits().getAt(0).docId()); } - ); - assertEquals("fetch", phase.getName()); - phase.run(); - mockSearchPhaseContext.assertNoFailure(); - SearchResponse searchResponse = mockSearchPhaseContext.searchResponse.get(); - assertNotNull(searchResponse); - assertEquals(numHits, searchResponse.getHits().getTotalHits().value); - if (numHits != 0) { - assertEquals(42, searchResponse.getHits().getAt(0).docId()); + assertProfiles(profiled, 1, searchResponse); + assertTrue(mockSearchPhaseContext.releasedSearchContexts.isEmpty()); + } finally { + results.decRef(); } - assertProfiles(profiled, 1, searchResponse); - assertTrue(mockSearchPhaseContext.releasedSearchContexts.isEmpty()); } private void assertProfiles(boolean profiled, int totalShards, SearchResponse searchResponse) { @@ -134,93 +148,109 @@ public void testFetchTwoDocument() { 2, exc -> {} ); - int resultSetSize = randomIntBetween(2, 10); - boolean profiled = randomBoolean(); - - ShardSearchContextId ctx1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); - SearchShardTarget shard1Target = new SearchShardTarget("node1", new ShardId("test", "na", 0), null); - SearchShardTarget shard2Target = new SearchShardTarget("node2", new ShardId("test", "na", 1), null); - QuerySearchResult queryResult = new QuerySearchResult(ctx1, shard1Target, null); try { - queryResult.topDocs( - new TopDocsAndMaxScore( - new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(42, 1.0F) }), - 2.0F - ), - new DocValueFormat[0] - ); - queryResult.size(resultSetSize); // the size of the result set - queryResult.setShardIndex(0); - addProfiling(profiled, queryResult); - results.consumeResult(queryResult, () -> {}); + int resultSetSize = randomIntBetween(2, 10); + boolean profiled = randomBoolean(); - } finally { - queryResult.decRef(); - } + ShardSearchContextId ctx1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); + SearchShardTarget shard1Target = new SearchShardTarget("node1", new ShardId("test", "na", 0), null); + SearchShardTarget shard2Target = new SearchShardTarget("node2", new ShardId("test", "na", 1), null); + QuerySearchResult queryResult = new QuerySearchResult(ctx1, shard1Target, null); + try { + queryResult.topDocs( + new TopDocsAndMaxScore( + new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(42, 1.0F) }), + 2.0F + ), + new DocValueFormat[0] + ); + queryResult.size(resultSetSize); // the size of the result set + queryResult.setShardIndex(0); + addProfiling(profiled, queryResult); + results.consumeResult(queryResult, () -> {}); - final ShardSearchContextId ctx2 = new ShardSearchContextId(UUIDs.base64UUID(), 321); - try { - queryResult = new QuerySearchResult(ctx2, shard2Target, null); - queryResult.topDocs( - new TopDocsAndMaxScore( - new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(84, 2.0F) }), - 2.0F - ), - new DocValueFormat[0] - ); - queryResult.size(resultSetSize); - queryResult.setShardIndex(1); - addProfiling(profiled, queryResult); - results.consumeResult(queryResult, () -> {}); - } finally { - queryResult.decRef(); - } + } finally { + queryResult.decRef(); + } - mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) { - @Override - public void sendExecuteFetch( - Transport.Connection connection, - ShardFetchSearchRequest request, - SearchTask task, - SearchActionListener listener - ) { - FetchSearchResult fetchResult = new FetchSearchResult(); - SearchHits hits; - if (request.contextId().equals(ctx2)) { - fetchResult.setSearchShardTarget(shard2Target); - hits = new SearchHits(new SearchHit[] { new SearchHit(84) }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 2.0F); - } else { - assertEquals(ctx1, request.contextId()); - fetchResult.setSearchShardTarget(shard1Target); - hits = new SearchHits(new SearchHit[] { new SearchHit(42) }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F); - } - fetchResult.shardResult(hits, fetchProfile(profiled)); - listener.onResponse(fetchResult); + final ShardSearchContextId ctx2 = new ShardSearchContextId(UUIDs.base64UUID(), 321); + try { + queryResult = new QuerySearchResult(ctx2, shard2Target, null); + queryResult.topDocs( + new TopDocsAndMaxScore( + new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(84, 2.0F) }), + 2.0F + ), + new DocValueFormat[0] + ); + queryResult.size(resultSetSize); + queryResult.setShardIndex(1); + addProfiling(profiled, queryResult); + results.consumeResult(queryResult, () -> {}); + } finally { + queryResult.decRef(); } - }; - FetchSearchPhase phase = new FetchSearchPhase( - results, - null, - mockSearchPhaseContext, - (searchResponse, scrollId) -> new SearchPhase("test") { + + mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) { @Override - public void run() { - mockSearchPhaseContext.sendSearchResponse(searchResponse, null); + public void sendExecuteFetch( + Transport.Connection connection, + ShardFetchSearchRequest request, + SearchTask task, + SearchActionListener listener + ) { + FetchSearchResult fetchResult = new FetchSearchResult(); + try { + SearchHits hits; + if (request.contextId().equals(ctx2)) { + fetchResult.setSearchShardTarget(shard2Target); + hits = new SearchHits( + new SearchHit[] { new SearchHit(84) }, + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + 2.0F + ); + } else { + assertEquals(ctx1, request.contextId()); + fetchResult.setSearchShardTarget(shard1Target); + hits = new SearchHits( + new SearchHit[] { new SearchHit(42) }, + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + 1.0F + ); + } + fetchResult.shardResult(hits, fetchProfile(profiled)); + listener.onResponse(fetchResult); + } finally { + fetchResult.decRef(); + } } - } - ); - assertEquals("fetch", phase.getName()); - phase.run(); - mockSearchPhaseContext.assertNoFailure(); - SearchResponse searchResponse = mockSearchPhaseContext.searchResponse.get(); - assertNotNull(searchResponse); - assertEquals(2, searchResponse.getHits().getTotalHits().value); - assertEquals(84, searchResponse.getHits().getAt(0).docId()); - assertEquals(42, searchResponse.getHits().getAt(1).docId()); - assertEquals(0, searchResponse.getFailedShards()); - assertEquals(2, searchResponse.getSuccessfulShards()); - assertProfiles(profiled, 2, searchResponse); - assertTrue(mockSearchPhaseContext.releasedSearchContexts.isEmpty()); + }; + FetchSearchPhase phase = new FetchSearchPhase( + results, + null, + mockSearchPhaseContext, + (searchResponse, scrollId) -> new SearchPhase("test") { + @Override + public void run() { + mockSearchPhaseContext.sendSearchResponse(searchResponse, null); + } + } + ); + assertEquals("fetch", phase.getName()); + phase.run(); + mockSearchPhaseContext.assertNoFailure(); + SearchResponse searchResponse = mockSearchPhaseContext.searchResponse.get(); + assertNotNull(searchResponse); + assertEquals(2, searchResponse.getHits().getTotalHits().value); + assertEquals(84, searchResponse.getHits().getAt(0).docId()); + assertEquals(42, searchResponse.getHits().getAt(1).docId()); + assertEquals(0, searchResponse.getFailedShards()); + assertEquals(2, searchResponse.getSuccessfulShards()); + assertProfiles(profiled, 2, searchResponse); + assertTrue(mockSearchPhaseContext.releasedSearchContexts.isEmpty()); + } finally { + results.decRef(); + } } public void testFailFetchOneDoc() { @@ -235,108 +265,116 @@ public void testFailFetchOneDoc() { 2, exc -> {} ); - int resultSetSize = randomIntBetween(2, 10); - boolean profiled = randomBoolean(); - - final ShardSearchContextId ctx = new ShardSearchContextId(UUIDs.base64UUID(), 123); - SearchShardTarget shard1Target = new SearchShardTarget("node1", new ShardId("test", "na", 0), null); - QuerySearchResult queryResult = new QuerySearchResult(ctx, shard1Target, null); try { - queryResult.topDocs( - new TopDocsAndMaxScore( - new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(42, 1.0F) }), - 2.0F - ), - new DocValueFormat[0] - ); - queryResult.size(resultSetSize); // the size of the result set - queryResult.setShardIndex(0); - addProfiling(profiled, queryResult); - results.consumeResult(queryResult, () -> {}); - } finally { - queryResult.decRef(); - } + int resultSetSize = randomIntBetween(2, 10); + boolean profiled = randomBoolean(); - SearchShardTarget shard2Target = new SearchShardTarget("node2", new ShardId("test", "na", 1), null); - queryResult = new QuerySearchResult(new ShardSearchContextId("", 321), shard2Target, null); - try { - queryResult.topDocs( - new TopDocsAndMaxScore( - new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(84, 2.0F) }), - 2.0F - ), - new DocValueFormat[0] - ); - queryResult.size(resultSetSize); - queryResult.setShardIndex(1); - addProfiling(profiled, queryResult); - results.consumeResult(queryResult, () -> {}); - } finally { - queryResult.decRef(); - } + final ShardSearchContextId ctx = new ShardSearchContextId(UUIDs.base64UUID(), 123); + SearchShardTarget shard1Target = new SearchShardTarget("node1", new ShardId("test", "na", 0), null); + QuerySearchResult queryResult = new QuerySearchResult(ctx, shard1Target, null); + try { + queryResult.topDocs( + new TopDocsAndMaxScore( + new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(42, 1.0F) }), + 2.0F + ), + new DocValueFormat[0] + ); + queryResult.size(resultSetSize); // the size of the result set + queryResult.setShardIndex(0); + addProfiling(profiled, queryResult); + results.consumeResult(queryResult, () -> {}); + } finally { + queryResult.decRef(); + } - mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) { - @Override - public void sendExecuteFetch( - Transport.Connection connection, - ShardFetchSearchRequest request, - SearchTask task, - SearchActionListener listener - ) { - if (request.contextId().getId() == 321) { - FetchSearchResult fetchResult = new FetchSearchResult(); - fetchResult.setSearchShardTarget(shard1Target); - SearchHits hits = new SearchHits( - new SearchHit[] { new SearchHit(84) }, - new TotalHits(1, TotalHits.Relation.EQUAL_TO), + SearchShardTarget shard2Target = new SearchShardTarget("node2", new ShardId("test", "na", 1), null); + queryResult = new QuerySearchResult(new ShardSearchContextId("", 321), shard2Target, null); + try { + queryResult.topDocs( + new TopDocsAndMaxScore( + new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(84, 2.0F) }), 2.0F - ); - fetchResult.shardResult(hits, fetchProfile(profiled)); - listener.onResponse(fetchResult); - } else { - listener.onFailure(new MockDirectoryWrapper.FakeIOException()); - } + ), + new DocValueFormat[0] + ); + queryResult.size(resultSetSize); + queryResult.setShardIndex(1); + addProfiling(profiled, queryResult); + results.consumeResult(queryResult, () -> {}); + } finally { + queryResult.decRef(); } - }; - FetchSearchPhase phase = new FetchSearchPhase( - results, - null, - mockSearchPhaseContext, - (searchResponse, scrollId) -> new SearchPhase("test") { + + mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) { @Override - public void run() { - mockSearchPhaseContext.sendSearchResponse(searchResponse, null); + public void sendExecuteFetch( + Transport.Connection connection, + ShardFetchSearchRequest request, + SearchTask task, + SearchActionListener listener + ) { + if (request.contextId().getId() == 321) { + FetchSearchResult fetchResult = new FetchSearchResult(); + try { + fetchResult.setSearchShardTarget(shard1Target); + SearchHits hits = new SearchHits( + new SearchHit[] { new SearchHit(84) }, + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + 2.0F + ); + fetchResult.shardResult(hits, fetchProfile(profiled)); + listener.onResponse(fetchResult); + } finally { + fetchResult.decRef(); + } + } else { + listener.onFailure(new MockDirectoryWrapper.FakeIOException()); + } + } + }; + FetchSearchPhase phase = new FetchSearchPhase( + results, + null, + mockSearchPhaseContext, + (searchResponse, scrollId) -> new SearchPhase("test") { + @Override + public void run() { + mockSearchPhaseContext.sendSearchResponse(searchResponse, null); + } } - } - ); - assertEquals("fetch", phase.getName()); - phase.run(); - mockSearchPhaseContext.assertNoFailure(); - SearchResponse searchResponse = mockSearchPhaseContext.searchResponse.get(); - assertNotNull(searchResponse); - assertEquals(2, searchResponse.getHits().getTotalHits().value); - assertEquals(84, searchResponse.getHits().getAt(0).docId()); - assertEquals(1, searchResponse.getFailedShards()); - assertEquals(1, searchResponse.getSuccessfulShards()); - assertEquals(1, searchResponse.getShardFailures().length); - assertTrue(searchResponse.getShardFailures()[0].getCause() instanceof MockDirectoryWrapper.FakeIOException); - assertEquals(1, mockSearchPhaseContext.releasedSearchContexts.size()); - if (profiled) { - /* - * Shard 2 failed to fetch but still searched so it will have - * profiling information for the search on both shards but only - * for the fetch on the successful shard. - */ - assertThat(searchResponse.getProfileResults().values().size(), equalTo(2)); - assertThat(searchResponse.getProfileResults().get(shard1Target.toString()).getFetchPhase(), nullValue()); - assertThat( - searchResponse.getProfileResults().get(shard2Target.toString()).getFetchPhase().getTime(), - equalTo(FETCH_PROFILE_TIME) ); - } else { - assertThat(searchResponse.getProfileResults(), equalTo(Map.of())); + assertEquals("fetch", phase.getName()); + phase.run(); + mockSearchPhaseContext.assertNoFailure(); + SearchResponse searchResponse = mockSearchPhaseContext.searchResponse.get(); + assertNotNull(searchResponse); + assertEquals(2, searchResponse.getHits().getTotalHits().value); + assertEquals(84, searchResponse.getHits().getAt(0).docId()); + assertEquals(1, searchResponse.getFailedShards()); + assertEquals(1, searchResponse.getSuccessfulShards()); + assertEquals(1, searchResponse.getShardFailures().length); + assertTrue(searchResponse.getShardFailures()[0].getCause() instanceof MockDirectoryWrapper.FakeIOException); + assertEquals(1, mockSearchPhaseContext.releasedSearchContexts.size()); + if (profiled) { + /* + * Shard 2 failed to fetch but still searched so it will have + * profiling information for the search on both shards but only + * for the fetch on the successful shard. + */ + assertThat(searchResponse.getProfileResults().values().size(), equalTo(2)); + assertThat(searchResponse.getProfileResults().get(shard1Target.toString()).getFetchPhase(), nullValue()); + assertThat( + searchResponse.getProfileResults().get(shard2Target.toString()).getFetchPhase().getTime(), + equalTo(FETCH_PROFILE_TIME) + ); + } else { + assertThat(searchResponse.getProfileResults(), equalTo(Map.of())); + } + assertTrue(mockSearchPhaseContext.releasedSearchContexts.contains(ctx)); + } finally { + results.decRef(); } - assertTrue(mockSearchPhaseContext.releasedSearchContexts.contains(ctx)); } public void testFetchDocsConcurrently() throws InterruptedException { @@ -355,95 +393,103 @@ public void testFetchDocsConcurrently() throws InterruptedException { numHits, exc -> {} ); - SearchShardTarget[] shardTargets = new SearchShardTarget[numHits]; - for (int i = 0; i < numHits; i++) { - shardTargets[i] = new SearchShardTarget("node1", new ShardId("test", "na", i), null); - QuerySearchResult queryResult = new QuerySearchResult(new ShardSearchContextId("", i), shardTargets[i], null); - queryResult.topDocs( - new TopDocsAndMaxScore( - new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(i + 1, i) }), - i - ), - new DocValueFormat[0] - ); - try { - queryResult.size(resultSetSize); // the size of the result set - queryResult.setShardIndex(i); - addProfiling(profiled, queryResult); - results.consumeResult(queryResult, () -> {}); - } finally { - queryResult.decRef(); - } - } - mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) { - @Override - public void sendExecuteFetch( - Transport.Connection connection, - ShardFetchSearchRequest request, - SearchTask task, - SearchActionListener listener - ) { - new Thread(() -> { - FetchSearchResult fetchResult = new FetchSearchResult(); - fetchResult.setSearchShardTarget(shardTargets[(int) request.contextId().getId()]); - SearchHits hits = new SearchHits( - new SearchHit[] { new SearchHit((int) (request.contextId().getId() + 1)) }, - new TotalHits(1, TotalHits.Relation.EQUAL_TO), - 100F + try { + SearchShardTarget[] shardTargets = new SearchShardTarget[numHits]; + for (int i = 0; i < numHits; i++) { + shardTargets[i] = new SearchShardTarget("node1", new ShardId("test", "na", i), null); + QuerySearchResult queryResult = new QuerySearchResult(new ShardSearchContextId("", i), shardTargets[i], null); + try { + queryResult.topDocs( + new TopDocsAndMaxScore( + new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(i + 1, i) }), + i + ), + new DocValueFormat[0] ); - fetchResult.shardResult(hits, fetchProfile(profiled)); - listener.onResponse(fetchResult); - }).start(); + queryResult.size(resultSetSize); // the size of the result set + queryResult.setShardIndex(i); + addProfiling(profiled, queryResult); + results.consumeResult(queryResult, () -> {}); + } finally { + queryResult.decRef(); + } } - }; - CountDownLatch latch = new CountDownLatch(1); - FetchSearchPhase phase = new FetchSearchPhase( - results, - null, - mockSearchPhaseContext, - (searchResponse, scrollId) -> new SearchPhase("test") { + mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) { @Override - public void run() { - mockSearchPhaseContext.sendSearchResponse(searchResponse, null); - latch.countDown(); + public void sendExecuteFetch( + Transport.Connection connection, + ShardFetchSearchRequest request, + SearchTask task, + SearchActionListener listener + ) { + new Thread(() -> { + FetchSearchResult fetchResult = new FetchSearchResult(); + try { + fetchResult.setSearchShardTarget(shardTargets[(int) request.contextId().getId()]); + SearchHits hits = new SearchHits( + new SearchHit[] { new SearchHit((int) (request.contextId().getId() + 1)) }, + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + 100F + ); + fetchResult.shardResult(hits, fetchProfile(profiled)); + listener.onResponse(fetchResult); + } finally { + fetchResult.decRef(); + } + }).start(); } + }; + CountDownLatch latch = new CountDownLatch(1); + FetchSearchPhase phase = new FetchSearchPhase( + results, + null, + mockSearchPhaseContext, + (searchResponse, scrollId) -> new SearchPhase("test") { + @Override + public void run() { + mockSearchPhaseContext.sendSearchResponse(searchResponse, null); + latch.countDown(); + } + } + ); + assertEquals("fetch", phase.getName()); + phase.run(); + latch.await(); + mockSearchPhaseContext.assertNoFailure(); + SearchResponse searchResponse = mockSearchPhaseContext.searchResponse.get(); + assertNotNull(searchResponse); + assertEquals(numHits, searchResponse.getHits().getTotalHits().value); + assertEquals(Math.min(numHits, resultSetSize), searchResponse.getHits().getHits().length); + SearchHit[] hits = searchResponse.getHits().getHits(); + for (int i = 0; i < hits.length; i++) { + assertNotNull(hits[i]); + assertEquals("index: " + i, numHits - i, hits[i].docId()); + assertEquals("index: " + i, numHits - 1 - i, (int) hits[i].getScore()); } - ); - assertEquals("fetch", phase.getName()); - phase.run(); - latch.await(); - mockSearchPhaseContext.assertNoFailure(); - SearchResponse searchResponse = mockSearchPhaseContext.searchResponse.get(); - assertNotNull(searchResponse); - assertEquals(numHits, searchResponse.getHits().getTotalHits().value); - assertEquals(Math.min(numHits, resultSetSize), searchResponse.getHits().getHits().length); - SearchHit[] hits = searchResponse.getHits().getHits(); - for (int i = 0; i < hits.length; i++) { - assertNotNull(hits[i]); - assertEquals("index: " + i, numHits - i, hits[i].docId()); - assertEquals("index: " + i, numHits - 1 - i, (int) hits[i].getScore()); - } - assertEquals(0, searchResponse.getFailedShards()); - assertEquals(numHits, searchResponse.getSuccessfulShards()); - if (profiled) { - assertThat(searchResponse.getProfileResults().values().size(), equalTo(numHits)); - int count = 0; - for (SearchProfileShardResult profileShardResult : searchResponse.getProfileResults().values()) { - if (profileShardResult.getFetchPhase() != null) { - count++; - assertThat(profileShardResult.getFetchPhase().getTime(), equalTo(FETCH_PROFILE_TIME)); + assertEquals(0, searchResponse.getFailedShards()); + assertEquals(numHits, searchResponse.getSuccessfulShards()); + if (profiled) { + assertThat(searchResponse.getProfileResults().values().size(), equalTo(numHits)); + int count = 0; + for (SearchProfileShardResult profileShardResult : searchResponse.getProfileResults().values()) { + if (profileShardResult.getFetchPhase() != null) { + count++; + assertThat(profileShardResult.getFetchPhase().getTime(), equalTo(FETCH_PROFILE_TIME)); + } } + assertThat(count, equalTo(Math.min(numHits, resultSetSize))); + } else { + assertThat(searchResponse.getProfileResults(), equalTo(Map.of())); } - assertThat(count, equalTo(Math.min(numHits, resultSetSize))); - } else { - assertThat(searchResponse.getProfileResults(), equalTo(Map.of())); + int sizeReleasedContexts = Math.max(0, numHits - resultSetSize); // all non fetched results will be freed + assertEquals( + mockSearchPhaseContext.releasedSearchContexts.toString(), + sizeReleasedContexts, + mockSearchPhaseContext.releasedSearchContexts.size() + ); + } finally { + results.decRef(); } - int sizeReleasedContexts = Math.max(0, numHits - resultSetSize); // all non fetched results will be freed - assertEquals( - mockSearchPhaseContext.releasedSearchContexts.toString(), - sizeReleasedContexts, - mockSearchPhaseContext.releasedSearchContexts.size() - ); } public void testExceptionFailsPhase() { @@ -458,87 +504,103 @@ public void testExceptionFailsPhase() { 2, exc -> {} ); - int resultSetSize = randomIntBetween(2, 10); - boolean profiled = randomBoolean(); - - SearchShardTarget shard1Target = new SearchShardTarget("node1", new ShardId("test", "na", 0), null); - SearchShardTarget shard2Target = new SearchShardTarget("node1", new ShardId("test", "na", 0), null); - QuerySearchResult queryResult = new QuerySearchResult(new ShardSearchContextId("", 123), shard1Target, null); - try { - queryResult.topDocs( - new TopDocsAndMaxScore( - new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(42, 1.0F) }), - 2.0F - ), - new DocValueFormat[0] - ); - queryResult.size(resultSetSize); // the size of the result set - queryResult.setShardIndex(0); - addProfiling(profiled, queryResult); - results.consumeResult(queryResult, () -> {}); - } finally { - queryResult.decRef(); - } try { - queryResult = new QuerySearchResult(new ShardSearchContextId("", 321), shard2Target, null); - queryResult.topDocs( - new TopDocsAndMaxScore( - new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(84, 2.0F) }), - 2.0F - ), - new DocValueFormat[0] - ); - queryResult.size(resultSetSize); - queryResult.setShardIndex(1); - addProfiling(profiled, queryResult); - results.consumeResult(queryResult, () -> {}); - } finally { - queryResult.decRef(); - } + int resultSetSize = randomIntBetween(2, 10); + boolean profiled = randomBoolean(); - AtomicInteger numFetches = new AtomicInteger(0); - mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) { - @Override - public void sendExecuteFetch( - Transport.Connection connection, - ShardFetchSearchRequest request, - SearchTask task, - SearchActionListener listener - ) { - FetchSearchResult fetchResult = new FetchSearchResult(); - if (numFetches.incrementAndGet() == 1) { - throw new RuntimeException("BOOM"); - } - SearchHits hits; - if (request.contextId().getId() == 321) { - fetchResult.setSearchShardTarget(shard2Target); - hits = new SearchHits(new SearchHit[] { new SearchHit(84) }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 2.0F); - } else { - fetchResult.setSearchShardTarget(shard1Target); - assertEquals(request, 123); - hits = new SearchHits(new SearchHit[] { new SearchHit(42) }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F); - } - fetchResult.shardResult(hits, fetchProfile(profiled)); - listener.onResponse(fetchResult); + SearchShardTarget shard1Target = new SearchShardTarget("node1", new ShardId("test", "na", 0), null); + SearchShardTarget shard2Target = new SearchShardTarget("node1", new ShardId("test", "na", 0), null); + QuerySearchResult queryResult = new QuerySearchResult(new ShardSearchContextId("", 123), shard1Target, null); + try { + queryResult.topDocs( + new TopDocsAndMaxScore( + new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(42, 1.0F) }), + 2.0F + ), + new DocValueFormat[0] + ); + queryResult.size(resultSetSize); // the size of the result set + queryResult.setShardIndex(0); + addProfiling(profiled, queryResult); + results.consumeResult(queryResult, () -> {}); + } finally { + queryResult.decRef(); + } + queryResult = new QuerySearchResult(new ShardSearchContextId("", 321), shard2Target, null); + try { + queryResult.topDocs( + new TopDocsAndMaxScore( + new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(84, 2.0F) }), + 2.0F + ), + new DocValueFormat[0] + ); + queryResult.size(resultSetSize); + queryResult.setShardIndex(1); + addProfiling(profiled, queryResult); + results.consumeResult(queryResult, () -> {}); + } finally { + queryResult.decRef(); } - }; - FetchSearchPhase phase = new FetchSearchPhase( - results, - null, - mockSearchPhaseContext, - (searchResponse, scrollId) -> new SearchPhase("test") { + + AtomicInteger numFetches = new AtomicInteger(0); + mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) { @Override - public void run() { - mockSearchPhaseContext.sendSearchResponse(searchResponse, null); + public void sendExecuteFetch( + Transport.Connection connection, + ShardFetchSearchRequest request, + SearchTask task, + SearchActionListener listener + ) { + FetchSearchResult fetchResult = new FetchSearchResult(); + try { + if (numFetches.incrementAndGet() == 1) { + listener.onFailure(new RuntimeException("BOOM")); + return; + } + SearchHits hits; + if (request.contextId().getId() == 321) { + fetchResult.setSearchShardTarget(shard2Target); + hits = new SearchHits( + new SearchHit[] { new SearchHit(84) }, + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + 2.0F + ); + } else { + fetchResult.setSearchShardTarget(shard1Target); + assertEquals(request, 123); + hits = new SearchHits( + new SearchHit[] { new SearchHit(42) }, + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + 1.0F + ); + } + fetchResult.shardResult(hits, fetchProfile(profiled)); + listener.onResponse(fetchResult); + } finally { + fetchResult.decRef(); + } } - } - ); - assertEquals("fetch", phase.getName()); - phase.run(); - assertNotNull(mockSearchPhaseContext.phaseFailure.get()); - assertEquals(mockSearchPhaseContext.phaseFailure.get().getMessage(), "BOOM"); - assertNull(mockSearchPhaseContext.searchResponse.get()); - assertTrue(mockSearchPhaseContext.releasedSearchContexts.isEmpty()); + }; + FetchSearchPhase phase = new FetchSearchPhase( + results, + null, + mockSearchPhaseContext, + (searchResponse, scrollId) -> new SearchPhase("test") { + @Override + public void run() { + mockSearchPhaseContext.sendSearchResponse(searchResponse, null); + } + } + ); + assertEquals("fetch", phase.getName()); + phase.run(); + assertNotNull(mockSearchPhaseContext.searchResponse.get()); + assertThat(mockSearchPhaseContext.searchResponse.get().getShardFailures(), arrayWithSize(1)); + assertThat(mockSearchPhaseContext.releasedSearchContexts, hasSize(1)); + } finally { + results.decRef(); + } } public void testCleanupIrrelevantContexts() { // contexts that are not fetched should be cleaned up @@ -553,100 +615,109 @@ public void testCleanupIrrelevantContexts() { // contexts that are not fetched s 2, exc -> {} ); - int resultSetSize = 1; - boolean profiled = randomBoolean(); - - final ShardSearchContextId ctx1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); - SearchShardTarget shard1Target = new SearchShardTarget("node1", new ShardId("test", "na", 0), null); - QuerySearchResult queryResult = new QuerySearchResult(ctx1, shard1Target, null); try { - queryResult.topDocs( - new TopDocsAndMaxScore( - new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(42, 1.0F) }), - 2.0F - ), - new DocValueFormat[0] - ); - queryResult.size(resultSetSize); // the size of the result set - queryResult.setShardIndex(0); - addProfiling(profiled, queryResult); - results.consumeResult(queryResult, () -> {}); - } finally { - queryResult.decRef(); - } - final ShardSearchContextId ctx2 = new ShardSearchContextId(UUIDs.base64UUID(), 321); - SearchShardTarget shard2Target = new SearchShardTarget("node2", new ShardId("test", "na", 1), null); - queryResult = new QuerySearchResult(ctx2, shard2Target, null); - try { - queryResult.topDocs( - new TopDocsAndMaxScore( - new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(84, 2.0F) }), - 2.0F - ), - new DocValueFormat[0] - ); - queryResult.size(resultSetSize); - queryResult.setShardIndex(1); - addProfiling(profiled, queryResult); - results.consumeResult(queryResult, () -> {}); - } finally { - queryResult.decRef(); - } + int resultSetSize = 1; + boolean profiled = randomBoolean(); - mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) { - @Override - public void sendExecuteFetch( - Transport.Connection connection, - ShardFetchSearchRequest request, - SearchTask task, - SearchActionListener listener - ) { - FetchSearchResult fetchResult = new FetchSearchResult(); - if (request.contextId().getId() == 321) { - fetchResult.setSearchShardTarget(shard1Target); - SearchHits hits = new SearchHits( - new SearchHit[] { new SearchHit(84) }, - new TotalHits(1, TotalHits.Relation.EQUAL_TO), + final ShardSearchContextId ctx1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); + SearchShardTarget shard1Target = new SearchShardTarget("node1", new ShardId("test", "na", 0), null); + QuerySearchResult queryResult = new QuerySearchResult(ctx1, shard1Target, null); + try { + queryResult.topDocs( + new TopDocsAndMaxScore( + new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(42, 1.0F) }), 2.0F - ); - fetchResult.shardResult(hits, fetchProfile(profiled)); - } else { - fail("requestID 123 should not be fetched but was"); - } - listener.onResponse(fetchResult); + ), + new DocValueFormat[0] + ); + queryResult.size(resultSetSize); // the size of the result set + queryResult.setShardIndex(0); + addProfiling(profiled, queryResult); + results.consumeResult(queryResult, () -> {}); + } finally { + queryResult.decRef(); } - }; - FetchSearchPhase phase = new FetchSearchPhase( - results, - null, - mockSearchPhaseContext, - (searchResponse, scrollId) -> new SearchPhase("test") { + final ShardSearchContextId ctx2 = new ShardSearchContextId(UUIDs.base64UUID(), 321); + SearchShardTarget shard2Target = new SearchShardTarget("node2", new ShardId("test", "na", 1), null); + queryResult = new QuerySearchResult(ctx2, shard2Target, null); + try { + queryResult.topDocs( + new TopDocsAndMaxScore( + new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(84, 2.0F) }), + 2.0F + ), + new DocValueFormat[0] + ); + queryResult.size(resultSetSize); + queryResult.setShardIndex(1); + addProfiling(profiled, queryResult); + results.consumeResult(queryResult, () -> {}); + } finally { + queryResult.decRef(); + } + + mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) { @Override - public void run() { - mockSearchPhaseContext.sendSearchResponse(searchResponse, null); + public void sendExecuteFetch( + Transport.Connection connection, + ShardFetchSearchRequest request, + SearchTask task, + SearchActionListener listener + ) { + FetchSearchResult fetchResult = new FetchSearchResult(); + try { + if (request.contextId().getId() == 321) { + fetchResult.setSearchShardTarget(shard1Target); + SearchHits hits = new SearchHits( + new SearchHit[] { new SearchHit(84) }, + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + 2.0F + ); + fetchResult.shardResult(hits, fetchProfile(profiled)); + } else { + fail("requestID 123 should not be fetched but was"); + } + listener.onResponse(fetchResult); + } finally { + fetchResult.decRef(); + } + } + }; + FetchSearchPhase phase = new FetchSearchPhase( + results, + null, + mockSearchPhaseContext, + (searchResponse, scrollId) -> new SearchPhase("test") { + @Override + public void run() { + mockSearchPhaseContext.sendSearchResponse(searchResponse, null); + } } - } - ); - assertEquals("fetch", phase.getName()); - phase.run(); - mockSearchPhaseContext.assertNoFailure(); - SearchResponse searchResponse = mockSearchPhaseContext.searchResponse.get(); - assertNotNull(searchResponse); - assertEquals(2, searchResponse.getHits().getTotalHits().value); - assertEquals(1, searchResponse.getHits().getHits().length); - assertEquals(84, searchResponse.getHits().getAt(0).docId()); - assertEquals(0, searchResponse.getFailedShards()); - assertEquals(2, searchResponse.getSuccessfulShards()); - if (profiled) { - assertThat(searchResponse.getProfileResults().size(), equalTo(2)); - assertThat(searchResponse.getProfileResults().get(shard1Target.toString()).getFetchPhase(), nullValue()); - assertThat( - searchResponse.getProfileResults().get(shard2Target.toString()).getFetchPhase().getTime(), - equalTo(FETCH_PROFILE_TIME) ); + assertEquals("fetch", phase.getName()); + phase.run(); + mockSearchPhaseContext.assertNoFailure(); + SearchResponse searchResponse = mockSearchPhaseContext.searchResponse.get(); + assertNotNull(searchResponse); + assertEquals(2, searchResponse.getHits().getTotalHits().value); + assertEquals(1, searchResponse.getHits().getHits().length); + assertEquals(84, searchResponse.getHits().getAt(0).docId()); + assertEquals(0, searchResponse.getFailedShards()); + assertEquals(2, searchResponse.getSuccessfulShards()); + if (profiled) { + assertThat(searchResponse.getProfileResults().size(), equalTo(2)); + assertThat(searchResponse.getProfileResults().get(shard1Target.toString()).getFetchPhase(), nullValue()); + assertThat( + searchResponse.getProfileResults().get(shard2Target.toString()).getFetchPhase().getTime(), + equalTo(FETCH_PROFILE_TIME) + ); + } + assertEquals(1, mockSearchPhaseContext.releasedSearchContexts.size()); + assertTrue(mockSearchPhaseContext.releasedSearchContexts.contains(ctx1)); + } finally { + results.decRef(); } - assertEquals(1, mockSearchPhaseContext.releasedSearchContexts.size()); - assertTrue(mockSearchPhaseContext.releasedSearchContexts.contains(ctx1)); + } private void addProfiling(boolean profiled, QuerySearchResult queryResult) { diff --git a/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java b/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java index dcddecb88323d..71156517b0306 100644 --- a/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java +++ b/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.internal.InternalSearchResponse; @@ -43,6 +44,8 @@ public final class MockSearchPhaseContext implements SearchPhaseContext { final SearchRequest searchRequest = new SearchRequest(); final AtomicReference searchResponse = new AtomicReference<>(); + private final List releasables = new ArrayList<>(); + public MockSearchPhaseContext(int numShards) { this.numShards = numShards; numSuccess = new AtomicInteger(numShards); @@ -137,12 +140,17 @@ public void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPhase) { @Override public void addReleasable(Releasable releasable) { - // Noop + releasables.add(releasable); } @Override public void execute(Runnable command) { - command.run(); + try { + command.run(); + } finally { + Releasables.close(releasables); + releasables.clear(); + } } @Override diff --git a/server/src/test/java/org/elasticsearch/action/search/QueryPhaseResultConsumerTests.java b/server/src/test/java/org/elasticsearch/action/search/QueryPhaseResultConsumerTests.java index f18d69c442b4b..6035950ca4635 100644 --- a/server/src/test/java/org/elasticsearch/action/search/QueryPhaseResultConsumerTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/QueryPhaseResultConsumerTests.java @@ -116,26 +116,30 @@ public void testProgressListenerExceptionsAreCaught() throws Exception { return curr; }) ); + try { + + CountDownLatch partialReduceLatch = new CountDownLatch(10); + + for (int i = 0; i < 10; i++) { + SearchShardTarget searchShardTarget = new SearchShardTarget("node", new ShardId("index", "uuid", i), null); + QuerySearchResult querySearchResult = new QuerySearchResult(); + TopDocs topDocs = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), new DocValueFormat[0]); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(i); + queryPhaseResultConsumer.consumeResult(querySearchResult, partialReduceLatch::countDown); + } - CountDownLatch partialReduceLatch = new CountDownLatch(10); + assertEquals(10, searchProgressListener.onQueryResult.get()); + assertTrue(partialReduceLatch.await(10, TimeUnit.SECONDS)); + assertNull(onPartialMergeFailure.get()); + assertEquals(8, searchProgressListener.onPartialReduce.get()); - for (int i = 0; i < 10; i++) { - SearchShardTarget searchShardTarget = new SearchShardTarget("node", new ShardId("index", "uuid", i), null); - QuerySearchResult querySearchResult = new QuerySearchResult(); - TopDocs topDocs = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]); - querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), new DocValueFormat[0]); - querySearchResult.setSearchShardTarget(searchShardTarget); - querySearchResult.setShardIndex(i); - queryPhaseResultConsumer.consumeResult(querySearchResult, partialReduceLatch::countDown); + queryPhaseResultConsumer.reduce(); + assertEquals(1, searchProgressListener.onFinalReduce.get()); + } finally { + queryPhaseResultConsumer.decRef(); } - - assertEquals(10, searchProgressListener.onQueryResult.get()); - assertTrue(partialReduceLatch.await(10, TimeUnit.SECONDS)); - assertNull(onPartialMergeFailure.get()); - assertEquals(8, searchProgressListener.onPartialReduce.get()); - - queryPhaseResultConsumer.reduce(); - assertEquals(1, searchProgressListener.onFinalReduce.get()); } private static class ThrowingSearchProgressListener extends SearchProgressListener { diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java index 0c2670348f9f6..430e66c116744 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java @@ -177,7 +177,6 @@ public void testLimitConcurrentShardRequests() throws InterruptedException { } CountDownLatch latch = new CountDownLatch(1); AtomicBoolean searchPhaseDidRun = new AtomicBoolean(false); - ActionListener responseListener = ActionTestUtils.assertNoFailureListener(response -> {}); DiscoveryNode primaryNode = DiscoveryNodeUtils.create("node_1"); // for the sake of this test we place the replica on the same node. ie. this is not a mistake since we limit per node now DiscoveryNode replicaNode = DiscoveryNodeUtils.create("node_1"); @@ -199,71 +198,80 @@ public void testLimitConcurrentShardRequests() throws InterruptedException { Map aliasFilters = Collections.singletonMap("_na_", AliasFilter.EMPTY); CountDownLatch awaitInitialRequests = new CountDownLatch(1); AtomicInteger numRequests = new AtomicInteger(0); - AbstractSearchAsyncAction asyncAction = new AbstractSearchAsyncAction( - "test", - logger, - transportService, - (cluster, node) -> { - assert cluster == null : "cluster was not null: " + cluster; - return lookup.get(node); - }, - aliasFilters, - Collections.emptyMap(), - null, - request, - responseListener, - shardsIter, - new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0), - ClusterState.EMPTY_STATE, - null, - new ArraySearchPhaseResults<>(shardsIter.size()), - request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY - ) { - - @Override - protected void executePhaseOnShard( - SearchShardIterator shardIt, - SearchShardTarget shard, - SearchActionListener listener + var results = new ArraySearchPhaseResults(shardsIter.size()); + try { + AbstractSearchAsyncAction asyncAction = new AbstractSearchAsyncAction<>( + "test", + logger, + transportService, + (cluster, node) -> { + assert cluster == null : "cluster was not null: " + cluster; + return lookup.get(node); + }, + aliasFilters, + Collections.emptyMap(), + null, + request, + ActionTestUtils.assertNoFailureListener(response -> {}), + shardsIter, + new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0), + ClusterState.EMPTY_STATE, + null, + results, + request.getMaxConcurrentShardRequests(), + SearchResponse.Clusters.EMPTY ) { - seenShard.computeIfAbsent(shard.getShardId(), (i) -> { - numRequests.incrementAndGet(); // only count this once per shard copy - return Boolean.TRUE; - }); - new Thread(() -> { - safeAwait(awaitInitialRequests); - Transport.Connection connection = getConnection(null, shard.getNodeId()); - TestSearchPhaseResult testSearchPhaseResult = new TestSearchPhaseResult( - new ShardSearchContextId(UUIDs.randomBase64UUID(), contextIdGenerator.incrementAndGet()), - connection.getNode() - ); - if (shardFailures[shard.getShardId().id()]) { - listener.onFailure(new RuntimeException()); - } else { - listener.onResponse(testSearchPhaseResult); - } - }).start(); - } + @Override + protected void executePhaseOnShard( + SearchShardIterator shardIt, + SearchShardTarget shard, + SearchActionListener listener + ) { + seenShard.computeIfAbsent(shard.getShardId(), (i) -> { + numRequests.incrementAndGet(); // only count this once per shard copy + return Boolean.TRUE; + }); + + new Thread(() -> { + safeAwait(awaitInitialRequests); + Transport.Connection connection = getConnection(null, shard.getNodeId()); + TestSearchPhaseResult testSearchPhaseResult = new TestSearchPhaseResult( + new ShardSearchContextId(UUIDs.randomBase64UUID(), contextIdGenerator.incrementAndGet()), + connection.getNode() + ); + try { + if (shardFailures[shard.getShardId().id()]) { + listener.onFailure(new RuntimeException()); + } else { + listener.onResponse(testSearchPhaseResult); + } + } finally { + testSearchPhaseResult.decRef(); + } + }).start(); + } - @Override - protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { - return new SearchPhase("test") { - @Override - public void run() { - assertTrue(searchPhaseDidRun.compareAndSet(false, true)); - latch.countDown(); - } - }; - } - }; - asyncAction.start(); - assertEquals(numConcurrent, numRequests.get()); - awaitInitialRequests.countDown(); - latch.await(); - assertTrue(searchPhaseDidRun.get()); - assertEquals(numShards, numRequests.get()); + @Override + protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { + return new SearchPhase("test") { + @Override + public void run() { + assertTrue(searchPhaseDidRun.compareAndSet(false, true)); + latch.countDown(); + } + }; + } + }; + asyncAction.start(); + assertEquals(numConcurrent, numRequests.get()); + awaitInitialRequests.countDown(); + latch.await(); + assertTrue(searchPhaseDidRun.get()); + assertEquals(numShards, numRequests.get()); + } finally { + results.decRef(); + } } public void testFanOutAndCollect() throws InterruptedException { @@ -304,82 +312,87 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI ExecutorService executor = Executors.newFixedThreadPool(randomIntBetween(1, Runtime.getRuntime().availableProcessors())); final CountDownLatch latch = new CountDownLatch(1); final AtomicBoolean latchTriggered = new AtomicBoolean(); - AbstractSearchAsyncAction asyncAction = new AbstractSearchAsyncAction( - "test", - logger, - transportService, - (cluster, node) -> { - assert cluster == null : "cluster was not null: " + cluster; - return lookup.get(node); - }, - aliasFilters, - Collections.emptyMap(), - executor, - request, - responseListener, - shardsIter, - new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0), - ClusterState.EMPTY_STATE, - null, - new ArraySearchPhaseResults<>(shardsIter.size()), - request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY - ) { - TestSearchResponse response = new TestSearchResponse(); - - @Override - protected void executePhaseOnShard( - SearchShardIterator shardIt, - SearchShardTarget shard, - SearchActionListener listener + var results = new ArraySearchPhaseResults(shardsIter.size()); + try { + AbstractSearchAsyncAction asyncAction = new AbstractSearchAsyncAction<>( + "test", + logger, + transportService, + (cluster, node) -> { + assert cluster == null : "cluster was not null: " + cluster; + return lookup.get(node); + }, + aliasFilters, + Collections.emptyMap(), + executor, + request, + responseListener, + shardsIter, + new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0), + ClusterState.EMPTY_STATE, + null, + results, + request.getMaxConcurrentShardRequests(), + SearchResponse.Clusters.EMPTY ) { - assertTrue("shard: " + shard.getShardId() + " has been queried twice", response.queried.add(shard.getShardId())); - Transport.Connection connection = getConnection(null, shard.getNodeId()); - TestSearchPhaseResult testSearchPhaseResult = new TestSearchPhaseResult( - new ShardSearchContextId(UUIDs.randomBase64UUID(), contextIdGenerator.incrementAndGet()), - connection.getNode() - ); - Set ids = nodeToContextMap.computeIfAbsent(connection.getNode(), (n) -> newConcurrentSet()); - ids.add(testSearchPhaseResult.getContextId()); - if (randomBoolean()) { - listener.onResponse(testSearchPhaseResult); - } else { - new Thread(() -> listener.onResponse(testSearchPhaseResult)).start(); + final TestSearchResponse response = new TestSearchResponse(); + + @Override + protected void executePhaseOnShard( + SearchShardIterator shardIt, + SearchShardTarget shard, + SearchActionListener listener + ) { + assertTrue("shard: " + shard.getShardId() + " has been queried twice", response.queried.add(shard.getShardId())); + Transport.Connection connection = getConnection(null, shard.getNodeId()); + TestSearchPhaseResult testSearchPhaseResult = new TestSearchPhaseResult( + new ShardSearchContextId(UUIDs.randomBase64UUID(), contextIdGenerator.incrementAndGet()), + connection.getNode() + ); + Set ids = nodeToContextMap.computeIfAbsent(connection.getNode(), (n) -> newConcurrentSet()); + ids.add(testSearchPhaseResult.getContextId()); + if (randomBoolean()) { + listener.onResponse(testSearchPhaseResult); + } else { + new Thread(() -> listener.onResponse(testSearchPhaseResult)).start(); + } } - } - @Override - protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { - return new SearchPhase("test") { - @Override - public void run() { - for (int i = 0; i < results.getNumShards(); i++) { - TestSearchPhaseResult result = results.getAtomicArray().get(i); - assertEquals(result.node.getId(), result.getSearchShardTarget().getNodeId()); - sendReleaseSearchContext(result.getContextId(), new MockConnection(result.node), OriginalIndices.NONE); + @Override + protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { + return new SearchPhase("test") { + @Override + public void run() { + for (int i = 0; i < results.getNumShards(); i++) { + TestSearchPhaseResult result = results.getAtomicArray().get(i); + assertEquals(result.node.getId(), result.getSearchShardTarget().getNodeId()); + sendReleaseSearchContext(result.getContextId(), new MockConnection(result.node), OriginalIndices.NONE); + } + responseListener.onResponse(response); + if (latchTriggered.compareAndSet(false, true) == false) { + throw new AssertionError("latch triggered twice"); + } + latch.countDown(); } - responseListener.onResponse(response); - if (latchTriggered.compareAndSet(false, true) == false) { - throw new AssertionError("latch triggered twice"); - } - latch.countDown(); - } - }; + }; + } + }; + asyncAction.start(); + latch.await(); + assertNotNull(response.get()); + assertFalse(nodeToContextMap.isEmpty()); + assertTrue(nodeToContextMap.toString(), nodeToContextMap.containsKey(primaryNode) || nodeToContextMap.containsKey(replicaNode)); + assertEquals(shardsIter.size(), numFreedContext.get()); + if (nodeToContextMap.containsKey(primaryNode)) { + assertTrue(nodeToContextMap.get(primaryNode).toString(), nodeToContextMap.get(primaryNode).isEmpty()); + } else { + assertTrue(nodeToContextMap.get(replicaNode).toString(), nodeToContextMap.get(replicaNode).isEmpty()); } - }; - asyncAction.start(); - latch.await(); - assertNotNull(response.get()); - assertFalse(nodeToContextMap.isEmpty()); - assertTrue(nodeToContextMap.toString(), nodeToContextMap.containsKey(primaryNode) || nodeToContextMap.containsKey(replicaNode)); - assertEquals(shardsIter.size(), numFreedContext.get()); - if (nodeToContextMap.containsKey(primaryNode)) { - assertTrue(nodeToContextMap.get(primaryNode).toString(), nodeToContextMap.get(primaryNode).isEmpty()); - } else { - assertTrue(nodeToContextMap.get(replicaNode).toString(), nodeToContextMap.get(replicaNode).isEmpty()); + final List runnables = executor.shutdownNow(); + assertThat(runnables, equalTo(Collections.emptyList())); + } finally { + results.decRef(); } - final List runnables = executor.shutdownNow(); - assertThat(runnables, equalTo(Collections.emptyList())); } public void testFanOutAndFail() throws InterruptedException { @@ -424,7 +437,7 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI lookup.put(replicaNode.getId(), new MockConnection(replicaNode)); Map aliasFilters = Collections.singletonMap("_na_", AliasFilter.EMPTY); ExecutorService executor = Executors.newFixedThreadPool(randomIntBetween(1, Runtime.getRuntime().availableProcessors())); - AbstractSearchAsyncAction asyncAction = new AbstractSearchAsyncAction( + AbstractSearchAsyncAction asyncAction = new AbstractSearchAsyncAction<>( "test", logger, transportService, @@ -445,7 +458,7 @@ public void sendFreeContext(Transport.Connection connection, ShardSearchContextI request.getMaxConcurrentShardRequests(), SearchResponse.Clusters.EMPTY ) { - TestSearchResponse response = new TestSearchResponse(); + final TestSearchResponse response = new TestSearchResponse(); @Override protected void executePhaseOnShard( @@ -506,7 +519,6 @@ public void testAllowPartialResults() throws InterruptedException { request.setMaxConcurrentShardRequests(numConcurrent); int numShards = randomIntBetween(5, 10); AtomicBoolean searchPhaseDidRun = new AtomicBoolean(false); - ActionListener responseListener = ActionTestUtils.assertNoFailureListener(response -> {}); DiscoveryNode primaryNode = DiscoveryNodeUtils.create("node_1"); // for the sake of this test we place the replica on the same node. ie. this is not a mistake since we limit per node now DiscoveryNode replicaNode = DiscoveryNodeUtils.create("node_1"); @@ -530,69 +542,78 @@ public void testAllowPartialResults() throws InterruptedException { Map aliasFilters = Collections.singletonMap("_na_", AliasFilter.EMPTY); AtomicInteger numRequests = new AtomicInteger(0); AtomicInteger numFailReplicas = new AtomicInteger(0); - AbstractSearchAsyncAction asyncAction = new AbstractSearchAsyncAction<>( - "test", - logger, - transportService, - (cluster, node) -> { - assert cluster == null : "cluster was not null: " + cluster; - return lookup.get(node); - }, - aliasFilters, - Collections.emptyMap(), - null, - request, - responseListener, - shardsIter, - new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0), - ClusterState.EMPTY_STATE, - null, - new ArraySearchPhaseResults<>(shardsIter.size()), - request.getMaxConcurrentShardRequests(), - SearchResponse.Clusters.EMPTY - ) { - - @Override - protected void executePhaseOnShard( - SearchShardIterator shardIt, - SearchShardTarget shard, - SearchActionListener listener + var results = new ArraySearchPhaseResults(shardsIter.size()); + try { + AbstractSearchAsyncAction asyncAction = new AbstractSearchAsyncAction<>( + "test", + logger, + transportService, + (cluster, node) -> { + assert cluster == null : "cluster was not null: " + cluster; + return lookup.get(node); + }, + aliasFilters, + Collections.emptyMap(), + null, + request, + ActionTestUtils.assertNoFailureListener(response -> {}), + shardsIter, + new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0), + ClusterState.EMPTY_STATE, + null, + results, + request.getMaxConcurrentShardRequests(), + SearchResponse.Clusters.EMPTY ) { - seenShard.computeIfAbsent(shard.getShardId(), (i) -> { - numRequests.incrementAndGet(); // only count this once per shard copy - return Boolean.TRUE; - }); - new Thread(() -> { - Transport.Connection connection = getConnection(null, shard.getNodeId()); - TestSearchPhaseResult testSearchPhaseResult = new TestSearchPhaseResult( - new ShardSearchContextId(UUIDs.randomBase64UUID(), contextIdGenerator.incrementAndGet()), - connection.getNode() - ); - if (shardIt.remaining() > 0) { - numFailReplicas.incrementAndGet(); - listener.onFailure(new RuntimeException()); - } else { - listener.onResponse(testSearchPhaseResult); - } - }).start(); - } - @Override - protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { - return new SearchPhase("test") { - @Override - public void run() { - assertTrue(searchPhaseDidRun.compareAndSet(false, true)); - latch.countDown(); - } - }; - } - }; - asyncAction.start(); - latch.await(); - assertTrue(searchPhaseDidRun.get()); - assertEquals(numShards, numRequests.get()); - assertThat(numFailReplicas.get(), greaterThanOrEqualTo(1)); + @Override + protected void executePhaseOnShard( + SearchShardIterator shardIt, + SearchShardTarget shard, + SearchActionListener listener + ) { + seenShard.computeIfAbsent(shard.getShardId(), (i) -> { + numRequests.incrementAndGet(); // only count this once per shard copy + return Boolean.TRUE; + }); + new Thread(() -> { + Transport.Connection connection = getConnection(null, shard.getNodeId()); + TestSearchPhaseResult testSearchPhaseResult = new TestSearchPhaseResult( + new ShardSearchContextId(UUIDs.randomBase64UUID(), contextIdGenerator.incrementAndGet()), + connection.getNode() + ); + try { + if (shardIt.remaining() > 0) { + numFailReplicas.incrementAndGet(); + listener.onFailure(new RuntimeException()); + } else { + listener.onResponse(testSearchPhaseResult); + } + } finally { + testSearchPhaseResult.decRef(); + } + }).start(); + } + + @Override + protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { + return new SearchPhase("test") { + @Override + public void run() { + assertTrue(searchPhaseDidRun.compareAndSet(false, true)); + latch.countDown(); + } + }; + } + }; + asyncAction.start(); + latch.await(); + assertTrue(searchPhaseDidRun.get()); + assertEquals(numShards, numRequests.get()); + assertThat(numFailReplicas.get(), greaterThanOrEqualTo(1)); + } finally { + results.decRef(); + } } public void testSkipUnavailableSearchShards() throws InterruptedException { diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java index 9172b541a8236..0dcb6abe3a86e 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java @@ -293,57 +293,61 @@ public void testMerge() { reducedQueryPhase.suggest(), profile ); - InternalSearchResponse mergedResponse = SearchPhaseController.merge(false, reducedQueryPhase, fetchResults); - if (trackTotalHits == SearchContext.TRACK_TOTAL_HITS_DISABLED) { - assertNull(mergedResponse.hits.getTotalHits()); - } else { - assertThat(mergedResponse.hits.getTotalHits().value, equalTo(0L)); - assertEquals(mergedResponse.hits.getTotalHits().relation, Relation.EQUAL_TO); - } - for (SearchHit hit : mergedResponse.hits().getHits()) { - SearchPhaseResult searchPhaseResult = fetchResults.get(hit.getShard().getShardId().id()); - assertSame(searchPhaseResult.getSearchShardTarget(), hit.getShard()); - } - int suggestSize = 0; - for (Suggest.Suggestion s : reducedQueryPhase.suggest()) { - suggestSize += s.getEntries().stream().mapToInt(e -> e.getOptions().size()).sum(); - } - assertThat(suggestSize, lessThanOrEqualTo(maxSuggestSize)); - assertThat( - mergedResponse.hits().getHits().length, - equalTo(reducedQueryPhase.sortedTopDocs().scoreDocs().length - suggestSize) - ); - Suggest suggestResult = mergedResponse.suggest(); - for (Suggest.Suggestion suggestion : reducedQueryPhase.suggest()) { - assertThat(suggestion, instanceOf(CompletionSuggestion.class)); - if (suggestion.getEntries().get(0).getOptions().size() > 0) { - CompletionSuggestion suggestionResult = suggestResult.getSuggestion(suggestion.getName()); - assertNotNull(suggestionResult); - List options = suggestionResult.getEntries().get(0).getOptions(); - assertThat(options.size(), equalTo(suggestion.getEntries().get(0).getOptions().size())); - for (CompletionSuggestion.Entry.Option option : options) { - assertNotNull(option.getHit()); - SearchPhaseResult searchPhaseResult = fetchResults.get(option.getHit().getShard().getShardId().id()); - assertSame(searchPhaseResult.getSearchShardTarget(), option.getHit().getShard()); - } + try { + InternalSearchResponse mergedResponse = SearchPhaseController.merge(false, reducedQueryPhase, fetchResults); + if (trackTotalHits == SearchContext.TRACK_TOTAL_HITS_DISABLED) { + assertNull(mergedResponse.hits.getTotalHits()); + } else { + assertThat(mergedResponse.hits.getTotalHits().value, equalTo(0L)); + assertEquals(mergedResponse.hits.getTotalHits().relation, Relation.EQUAL_TO); } - } - if (profile) { - assertThat(mergedResponse.profile().entrySet(), hasSize(nShards)); - assertThat( - // All shards should have a query profile - mergedResponse.profile().toString(), - mergedResponse.profile().values().stream().filter(r -> r.getQueryProfileResults() != null).count(), - equalTo((long) nShards) - ); + for (SearchHit hit : mergedResponse.hits().getHits()) { + SearchPhaseResult searchPhaseResult = fetchResults.get(hit.getShard().getShardId().id()); + assertSame(searchPhaseResult.getSearchShardTarget(), hit.getShard()); + } + int suggestSize = 0; + for (Suggest.Suggestion s : reducedQueryPhase.suggest()) { + suggestSize += s.getEntries().stream().mapToInt(e -> e.getOptions().size()).sum(); + } + assertThat(suggestSize, lessThanOrEqualTo(maxSuggestSize)); assertThat( - // Some or all shards should have a fetch profile - mergedResponse.profile().toString(), - mergedResponse.profile().values().stream().filter(r -> r.getFetchPhase() != null).count(), - both(greaterThan(0L)).and(lessThanOrEqualTo((long) nShards)) + mergedResponse.hits().getHits().length, + equalTo(reducedQueryPhase.sortedTopDocs().scoreDocs().length - suggestSize) ); - } else { - assertThat(mergedResponse.profile(), is(anEmptyMap())); + Suggest suggestResult = mergedResponse.suggest(); + for (Suggest.Suggestion suggestion : reducedQueryPhase.suggest()) { + assertThat(suggestion, instanceOf(CompletionSuggestion.class)); + if (suggestion.getEntries().get(0).getOptions().size() > 0) { + CompletionSuggestion suggestionResult = suggestResult.getSuggestion(suggestion.getName()); + assertNotNull(suggestionResult); + List options = suggestionResult.getEntries().get(0).getOptions(); + assertThat(options.size(), equalTo(suggestion.getEntries().get(0).getOptions().size())); + for (CompletionSuggestion.Entry.Option option : options) { + assertNotNull(option.getHit()); + SearchPhaseResult searchPhaseResult = fetchResults.get(option.getHit().getShard().getShardId().id()); + assertSame(searchPhaseResult.getSearchShardTarget(), option.getHit().getShard()); + } + } + } + if (profile) { + assertThat(mergedResponse.profile().entrySet(), hasSize(nShards)); + assertThat( + // All shards should have a query profile + mergedResponse.profile().toString(), + mergedResponse.profile().values().stream().filter(r -> r.getQueryProfileResults() != null).count(), + equalTo((long) nShards) + ); + assertThat( + // Some or all shards should have a fetch profile + mergedResponse.profile().toString(), + mergedResponse.profile().values().stream().filter(r -> r.getFetchPhase() != null).count(), + both(greaterThan(0L)).and(lessThanOrEqualTo((long) nShards)) + ); + } else { + assertThat(mergedResponse.profile(), is(anEmptyMap())); + } + } finally { + fetchResults.asList().forEach(TransportMessage::decRef); } } finally { queryResults.asList().forEach(TransportMessage::decRef); @@ -407,22 +411,27 @@ protected boolean lessThan(RankDoc a, RankDoc b) { reducedQueryPhase.suggest(), false ); - InternalSearchResponse mergedResponse = SearchPhaseController.merge(false, reducedQueryPhase, fetchResults); - if (trackTotalHits == SearchContext.TRACK_TOTAL_HITS_DISABLED) { - assertNull(mergedResponse.hits.getTotalHits()); - } else { - assertThat(mergedResponse.hits.getTotalHits().value, equalTo(0L)); - assertEquals(mergedResponse.hits.getTotalHits().relation, Relation.EQUAL_TO); - } - int rank = 1; - for (SearchHit hit : mergedResponse.hits().getHits()) { - SearchPhaseResult searchPhaseResult = fetchResults.get(hit.getShard().getShardId().id()); - assertSame(searchPhaseResult.getSearchShardTarget(), hit.getShard()); - assertEquals(rank++, hit.getRank()); + try { + InternalSearchResponse mergedResponse = SearchPhaseController.merge(false, reducedQueryPhase, fetchResults); + if (trackTotalHits == SearchContext.TRACK_TOTAL_HITS_DISABLED) { + assertNull(mergedResponse.hits.getTotalHits()); + } else { + assertThat(mergedResponse.hits.getTotalHits().value, equalTo(0L)); + assertEquals(mergedResponse.hits.getTotalHits().relation, Relation.EQUAL_TO); + } + int rank = 1; + for (SearchHit hit : mergedResponse.hits().getHits()) { + SearchPhaseResult searchPhaseResult = fetchResults.get(hit.getShard().getShardId().id()); + assertSame(searchPhaseResult.getSearchShardTarget(), hit.getShard()); + assertEquals(rank++, hit.getRank()); + } + assertThat(mergedResponse.hits().getHits().length, equalTo(reducedQueryPhase.sortedTopDocs().scoreDocs().length)); + assertThat(mergedResponse.profile(), is(anEmptyMap())); + } finally { + fetchResults.asList().forEach(TransportMessage::decRef); } - assertThat(mergedResponse.hits().getHits().length, equalTo(reducedQueryPhase.sortedTopDocs().scoreDocs().length)); - assertThat(mergedResponse.profile(), is(anEmptyMap())); } finally { + queryResults.asList().forEach(TransportMessage::decRef); } } @@ -609,109 +618,113 @@ private void consumerTestCase(int numEmptyResponses) throws Exception { 3 + numEmptyResponses, exc -> {} ); - if (numEmptyResponses == 0) { - assertEquals(0, reductions.size()); - } - if (numEmptyResponses > 0) { - QuerySearchResult empty = QuerySearchResult.nullInstance(); - int shardId = 2 + numEmptyResponses; - empty.setShardIndex(2 + numEmptyResponses); - empty.setSearchShardTarget(new SearchShardTarget("node", new ShardId("a", "b", shardId), null)); - consumer.consumeResult(empty, latch::countDown); - numEmptyResponses--; - } - - QuerySearchResult result = new QuerySearchResult( - new ShardSearchContextId("", 0), - new SearchShardTarget("node", new ShardId("a", "b", 0), null), - null - ); try { - result.topDocs( - new TopDocsAndMaxScore(new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), Float.NaN), - new DocValueFormat[0] + if (numEmptyResponses == 0) { + assertEquals(0, reductions.size()); + } + if (numEmptyResponses > 0) { + QuerySearchResult empty = QuerySearchResult.nullInstance(); + int shardId = 2 + numEmptyResponses; + empty.setShardIndex(2 + numEmptyResponses); + empty.setSearchShardTarget(new SearchShardTarget("node", new ShardId("a", "b", shardId), null)); + consumer.consumeResult(empty, latch::countDown); + numEmptyResponses--; + } + + QuerySearchResult result = new QuerySearchResult( + new ShardSearchContextId("", 0), + new SearchShardTarget("node", new ShardId("a", "b", 0), null), + null ); - InternalAggregations aggs = InternalAggregations.from(singletonList(new Max("test", 1.0D, DocValueFormat.RAW, emptyMap()))); - result.aggregations(aggs); - result.setShardIndex(0); - consumer.consumeResult(result, latch::countDown); - } finally { - result.decRef(); - } - result = new QuerySearchResult( - new ShardSearchContextId("", 1), - new SearchShardTarget("node", new ShardId("a", "b", 0), null), - null - ); - try { - result.topDocs( - new TopDocsAndMaxScore(new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), Float.NaN), - new DocValueFormat[0] + try { + result.topDocs( + new TopDocsAndMaxScore(new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), Float.NaN), + new DocValueFormat[0] + ); + InternalAggregations aggs = InternalAggregations.from(singletonList(new Max("test", 1.0D, DocValueFormat.RAW, emptyMap()))); + result.aggregations(aggs); + result.setShardIndex(0); + consumer.consumeResult(result, latch::countDown); + } finally { + result.decRef(); + } + result = new QuerySearchResult( + new ShardSearchContextId("", 1), + new SearchShardTarget("node", new ShardId("a", "b", 0), null), + null ); - InternalAggregations aggs = InternalAggregations.from(singletonList(new Max("test", 3.0D, DocValueFormat.RAW, emptyMap()))); - result.aggregations(aggs); - result.setShardIndex(2); - consumer.consumeResult(result, latch::countDown); - } finally { - result.decRef(); - } - result = new QuerySearchResult( - new ShardSearchContextId("", 1), - new SearchShardTarget("node", new ShardId("a", "b", 0), null), - null - ); - try { - result.topDocs( - new TopDocsAndMaxScore(new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), Float.NaN), - new DocValueFormat[0] + try { + result.topDocs( + new TopDocsAndMaxScore(new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), Float.NaN), + new DocValueFormat[0] + ); + InternalAggregations aggs = InternalAggregations.from(singletonList(new Max("test", 3.0D, DocValueFormat.RAW, emptyMap()))); + result.aggregations(aggs); + result.setShardIndex(2); + consumer.consumeResult(result, latch::countDown); + } finally { + result.decRef(); + } + result = new QuerySearchResult( + new ShardSearchContextId("", 1), + new SearchShardTarget("node", new ShardId("a", "b", 0), null), + null ); - InternalAggregations aggs = InternalAggregations.from(singletonList(new Max("test", 2.0D, DocValueFormat.RAW, emptyMap()))); - result.aggregations(aggs); - result.setShardIndex(1); - consumer.consumeResult(result, latch::countDown); - } finally { - result.decRef(); - } - while (numEmptyResponses > 0) { - result = QuerySearchResult.nullInstance(); try { - int shardId = 2 + numEmptyResponses; - result.setShardIndex(shardId); - result.setSearchShardTarget(new SearchShardTarget("node", new ShardId("a", "b", shardId), null)); + result.topDocs( + new TopDocsAndMaxScore(new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), Float.NaN), + new DocValueFormat[0] + ); + InternalAggregations aggs = InternalAggregations.from(singletonList(new Max("test", 2.0D, DocValueFormat.RAW, emptyMap()))); + result.aggregations(aggs); + result.setShardIndex(1); consumer.consumeResult(result, latch::countDown); } finally { result.decRef(); } - numEmptyResponses--; - } - latch.await(); - final int numTotalReducePhases; - if (numShards > bufferSize) { - if (bufferSize == 2) { - assertEquals(1, ((QueryPhaseResultConsumer) consumer).getNumReducePhases()); - assertEquals(1, reductions.size()); - assertEquals(false, reductions.get(0)); - numTotalReducePhases = 2; + while (numEmptyResponses > 0) { + result = QuerySearchResult.nullInstance(); + try { + int shardId = 2 + numEmptyResponses; + result.setShardIndex(shardId); + result.setSearchShardTarget(new SearchShardTarget("node", new ShardId("a", "b", shardId), null)); + consumer.consumeResult(result, latch::countDown); + } finally { + result.decRef(); + } + numEmptyResponses--; + } + latch.await(); + final int numTotalReducePhases; + if (numShards > bufferSize) { + if (bufferSize == 2) { + assertEquals(1, ((QueryPhaseResultConsumer) consumer).getNumReducePhases()); + assertEquals(1, reductions.size()); + assertEquals(false, reductions.get(0)); + numTotalReducePhases = 2; + } else { + assertEquals(0, ((QueryPhaseResultConsumer) consumer).getNumReducePhases()); + assertEquals(0, reductions.size()); + numTotalReducePhases = 1; + } } else { - assertEquals(0, ((QueryPhaseResultConsumer) consumer).getNumReducePhases()); assertEquals(0, reductions.size()); numTotalReducePhases = 1; } - } else { - assertEquals(0, reductions.size()); - numTotalReducePhases = 1; - } - SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); - assertEquals(numTotalReducePhases, reduce.numReducePhases()); - assertEquals(numTotalReducePhases, reductions.size()); - assertAggReduction(request); - Max max = (Max) reduce.aggregations().asList().get(0); - assertEquals(3.0D, max.value(), 0.0D); - assertFalse(reduce.sortedTopDocs().isSortedByField()); - assertNull(reduce.sortedTopDocs().sortFields()); - assertNull(reduce.sortedTopDocs().collapseField()); - assertNull(reduce.sortedTopDocs().collapseValues()); + SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); + assertEquals(numTotalReducePhases, reduce.numReducePhases()); + assertEquals(numTotalReducePhases, reductions.size()); + assertAggReduction(request); + Max max = (Max) reduce.aggregations().asList().get(0); + assertEquals(3.0D, max.value(), 0.0D); + assertFalse(reduce.sortedTopDocs().isSortedByField()); + assertNull(reduce.sortedTopDocs().sortFields()); + assertNull(reduce.sortedTopDocs().collapseField()); + assertNull(reduce.sortedTopDocs().collapseValues()); + } finally { + consumer.decRef(); + } } public void testConsumerConcurrently() throws Exception { @@ -730,58 +743,62 @@ public void testConsumerConcurrently() throws Exception { expectedNumResults, exc -> {} ); - AtomicInteger max = new AtomicInteger(); - Thread[] threads = new Thread[expectedNumResults]; - CountDownLatch latch = new CountDownLatch(expectedNumResults); - for (int i = 0; i < expectedNumResults; i++) { - int id = i; - threads[i] = new Thread(() -> { - int number = randomIntBetween(1, 1000); - max.updateAndGet(prev -> Math.max(prev, number)); - QuerySearchResult result = new QuerySearchResult( - new ShardSearchContextId("", id), - new SearchShardTarget("node", new ShardId("a", "b", id), null), - null - ); - try { - result.topDocs( - new TopDocsAndMaxScore( - new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(0, number) }), - number - ), - new DocValueFormat[0] - ); - InternalAggregations aggs = InternalAggregations.from( - Collections.singletonList(new Max("test", (double) number, DocValueFormat.RAW, Collections.emptyMap())) + try { + AtomicInteger max = new AtomicInteger(); + Thread[] threads = new Thread[expectedNumResults]; + CountDownLatch latch = new CountDownLatch(expectedNumResults); + for (int i = 0; i < expectedNumResults; i++) { + int id = i; + threads[i] = new Thread(() -> { + int number = randomIntBetween(1, 1000); + max.updateAndGet(prev -> Math.max(prev, number)); + QuerySearchResult result = new QuerySearchResult( + new ShardSearchContextId("", id), + new SearchShardTarget("node", new ShardId("a", "b", id), null), + null ); - result.aggregations(aggs); - result.setShardIndex(id); - result.size(1); - consumer.consumeResult(result, latch::countDown); - } finally { - result.decRef(); - } + try { + result.topDocs( + new TopDocsAndMaxScore( + new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(0, number) }), + number + ), + new DocValueFormat[0] + ); + InternalAggregations aggs = InternalAggregations.from( + Collections.singletonList(new Max("test", (double) number, DocValueFormat.RAW, Collections.emptyMap())) + ); + result.aggregations(aggs); + result.setShardIndex(id); + result.size(1); + consumer.consumeResult(result, latch::countDown); + } finally { + result.decRef(); + } - }); - threads[i].start(); - } - for (int i = 0; i < expectedNumResults; i++) { - threads[i].join(); - } - latch.await(); + }); + threads[i].start(); + } + for (int i = 0; i < expectedNumResults; i++) { + threads[i].join(); + } + latch.await(); - SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); - assertAggReduction(request); - Max internalMax = (Max) reduce.aggregations().asList().get(0); - assertEquals(max.get(), internalMax.value(), 0.0D); - assertEquals(1, reduce.sortedTopDocs().scoreDocs().length); - assertEquals(max.get(), reduce.maxScore(), 0.0f); - assertEquals(expectedNumResults, reduce.totalHits().value); - assertEquals(max.get(), reduce.sortedTopDocs().scoreDocs()[0].score, 0.0f); - assertFalse(reduce.sortedTopDocs().isSortedByField()); - assertNull(reduce.sortedTopDocs().sortFields()); - assertNull(reduce.sortedTopDocs().collapseField()); - assertNull(reduce.sortedTopDocs().collapseValues()); + SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); + assertAggReduction(request); + Max internalMax = (Max) reduce.aggregations().asList().get(0); + assertEquals(max.get(), internalMax.value(), 0.0D); + assertEquals(1, reduce.sortedTopDocs().scoreDocs().length); + assertEquals(max.get(), reduce.maxScore(), 0.0f); + assertEquals(expectedNumResults, reduce.totalHits().value); + assertEquals(max.get(), reduce.sortedTopDocs().scoreDocs()[0].score, 0.0f); + assertFalse(reduce.sortedTopDocs().isSortedByField()); + assertNull(reduce.sortedTopDocs().sortFields()); + assertNull(reduce.sortedTopDocs().collapseField()); + assertNull(reduce.sortedTopDocs().collapseValues()); + } finally { + consumer.decRef(); + } } public void testConsumerOnlyAggs() throws Exception { @@ -799,45 +816,49 @@ public void testConsumerOnlyAggs() throws Exception { expectedNumResults, exc -> {} ); - AtomicInteger max = new AtomicInteger(); - CountDownLatch latch = new CountDownLatch(expectedNumResults); - for (int i = 0; i < expectedNumResults; i++) { - int number = randomIntBetween(1, 1000); - max.updateAndGet(prev -> Math.max(prev, number)); - QuerySearchResult result = new QuerySearchResult( - new ShardSearchContextId("", i), - new SearchShardTarget("node", new ShardId("a", "b", i), null), - null - ); - try { - result.topDocs( - new TopDocsAndMaxScore(new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), number), - new DocValueFormat[0] - ); - InternalAggregations aggs = InternalAggregations.from( - Collections.singletonList(new Max("test", (double) number, DocValueFormat.RAW, Collections.emptyMap())) + try { + AtomicInteger max = new AtomicInteger(); + CountDownLatch latch = new CountDownLatch(expectedNumResults); + for (int i = 0; i < expectedNumResults; i++) { + int number = randomIntBetween(1, 1000); + max.updateAndGet(prev -> Math.max(prev, number)); + QuerySearchResult result = new QuerySearchResult( + new ShardSearchContextId("", i), + new SearchShardTarget("node", new ShardId("a", "b", i), null), + null ); - result.aggregations(aggs); - result.setShardIndex(i); - result.size(1); - consumer.consumeResult(result, latch::countDown); - } finally { - result.decRef(); + try { + result.topDocs( + new TopDocsAndMaxScore(new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), number), + new DocValueFormat[0] + ); + InternalAggregations aggs = InternalAggregations.from( + Collections.singletonList(new Max("test", (double) number, DocValueFormat.RAW, Collections.emptyMap())) + ); + result.aggregations(aggs); + result.setShardIndex(i); + result.size(1); + consumer.consumeResult(result, latch::countDown); + } finally { + result.decRef(); + } } - } - latch.await(); + latch.await(); - SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); - assertAggReduction(request); - Max internalMax = (Max) reduce.aggregations().asList().get(0); - assertEquals(max.get(), internalMax.value(), 0.0D); - assertEquals(0, reduce.sortedTopDocs().scoreDocs().length); - assertEquals(max.get(), reduce.maxScore(), 0.0f); - assertEquals(expectedNumResults, reduce.totalHits().value); - assertFalse(reduce.sortedTopDocs().isSortedByField()); - assertNull(reduce.sortedTopDocs().sortFields()); - assertNull(reduce.sortedTopDocs().collapseField()); - assertNull(reduce.sortedTopDocs().collapseValues()); + SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); + assertAggReduction(request); + Max internalMax = (Max) reduce.aggregations().asList().get(0); + assertEquals(max.get(), internalMax.value(), 0.0D); + assertEquals(0, reduce.sortedTopDocs().scoreDocs().length); + assertEquals(max.get(), reduce.maxScore(), 0.0f); + assertEquals(expectedNumResults, reduce.totalHits().value); + assertFalse(reduce.sortedTopDocs().isSortedByField()); + assertNull(reduce.sortedTopDocs().sortFields()); + assertNull(reduce.sortedTopDocs().collapseField()); + assertNull(reduce.sortedTopDocs().collapseValues()); + } finally { + consumer.decRef(); + } } public void testConsumerOnlyHits() throws Exception { @@ -857,42 +878,46 @@ public void testConsumerOnlyHits() throws Exception { expectedNumResults, exc -> {} ); - AtomicInteger max = new AtomicInteger(); - CountDownLatch latch = new CountDownLatch(expectedNumResults); - for (int i = 0; i < expectedNumResults; i++) { - int number = randomIntBetween(1, 1000); - max.updateAndGet(prev -> Math.max(prev, number)); - QuerySearchResult result = new QuerySearchResult( - new ShardSearchContextId("", i), - new SearchShardTarget("node", new ShardId("a", "b", i), null), - null - ); - try { - result.topDocs( - new TopDocsAndMaxScore( - new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(0, number) }), - number - ), - new DocValueFormat[0] + try { + AtomicInteger max = new AtomicInteger(); + CountDownLatch latch = new CountDownLatch(expectedNumResults); + for (int i = 0; i < expectedNumResults; i++) { + int number = randomIntBetween(1, 1000); + max.updateAndGet(prev -> Math.max(prev, number)); + QuerySearchResult result = new QuerySearchResult( + new ShardSearchContextId("", i), + new SearchShardTarget("node", new ShardId("a", "b", i), null), + null ); - result.setShardIndex(i); - result.size(1); - consumer.consumeResult(result, latch::countDown); - } finally { - result.decRef(); + try { + result.topDocs( + new TopDocsAndMaxScore( + new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(0, number) }), + number + ), + new DocValueFormat[0] + ); + result.setShardIndex(i); + result.size(1); + consumer.consumeResult(result, latch::countDown); + } finally { + result.decRef(); + } } + latch.await(); + SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); + assertAggReduction(request); + assertEquals(1, reduce.sortedTopDocs().scoreDocs().length); + assertEquals(max.get(), reduce.maxScore(), 0.0f); + assertEquals(expectedNumResults, reduce.totalHits().value); + assertEquals(max.get(), reduce.sortedTopDocs().scoreDocs()[0].score, 0.0f); + assertFalse(reduce.sortedTopDocs().isSortedByField()); + assertNull(reduce.sortedTopDocs().sortFields()); + assertNull(reduce.sortedTopDocs().collapseField()); + assertNull(reduce.sortedTopDocs().collapseValues()); + } finally { + consumer.decRef(); } - latch.await(); - SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); - assertAggReduction(request); - assertEquals(1, reduce.sortedTopDocs().scoreDocs().length); - assertEquals(max.get(), reduce.maxScore(), 0.0f); - assertEquals(expectedNumResults, reduce.totalHits().value); - assertEquals(max.get(), reduce.sortedTopDocs().scoreDocs()[0].score, 0.0f); - assertFalse(reduce.sortedTopDocs().isSortedByField()); - assertNull(reduce.sortedTopDocs().sortFields()); - assertNull(reduce.sortedTopDocs().collapseField()); - assertNull(reduce.sortedTopDocs().collapseValues()); } private void assertAggReduction(SearchRequest searchRequest) { @@ -920,43 +945,47 @@ public void testReduceTopNWithFromOffset() throws Exception { 4, exc -> {} ); - int score = 100; - CountDownLatch latch = new CountDownLatch(4); - for (int i = 0; i < 4; i++) { - QuerySearchResult result = new QuerySearchResult( - new ShardSearchContextId("", i), - new SearchShardTarget("node", new ShardId("a", "b", i), null), - null - ); - try { - ScoreDoc[] docs = new ScoreDoc[3]; - for (int j = 0; j < docs.length; j++) { - docs[j] = new ScoreDoc(0, score--); - } - result.topDocs( - new TopDocsAndMaxScore(new TopDocs(new TotalHits(3, TotalHits.Relation.EQUAL_TO), docs), docs[0].score), - new DocValueFormat[0] + try { + int score = 100; + CountDownLatch latch = new CountDownLatch(4); + for (int i = 0; i < 4; i++) { + QuerySearchResult result = new QuerySearchResult( + new ShardSearchContextId("", i), + new SearchShardTarget("node", new ShardId("a", "b", i), null), + null ); - result.setShardIndex(i); - result.size(5); - result.from(5); - consumer.consumeResult(result, latch::countDown); - } finally { - result.decRef(); + try { + ScoreDoc[] docs = new ScoreDoc[3]; + for (int j = 0; j < docs.length; j++) { + docs[j] = new ScoreDoc(0, score--); + } + result.topDocs( + new TopDocsAndMaxScore(new TopDocs(new TotalHits(3, TotalHits.Relation.EQUAL_TO), docs), docs[0].score), + new DocValueFormat[0] + ); + result.setShardIndex(i); + result.size(5); + result.from(5); + consumer.consumeResult(result, latch::countDown); + } finally { + result.decRef(); + } } + latch.await(); + // 4*3 results = 12 we get result 5 to 10 here with from=5 and size=5 + SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); + ScoreDoc[] scoreDocs = reduce.sortedTopDocs().scoreDocs(); + assertEquals(5, scoreDocs.length); + assertEquals(100.f, reduce.maxScore(), 0.0f); + assertEquals(12, reduce.totalHits().value); + assertEquals(95.0f, scoreDocs[0].score, 0.0f); + assertEquals(94.0f, scoreDocs[1].score, 0.0f); + assertEquals(93.0f, scoreDocs[2].score, 0.0f); + assertEquals(92.0f, scoreDocs[3].score, 0.0f); + assertEquals(91.0f, scoreDocs[4].score, 0.0f); + } finally { + consumer.decRef(); } - latch.await(); - // 4*3 results = 12 we get result 5 to 10 here with from=5 and size=5 - SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); - ScoreDoc[] scoreDocs = reduce.sortedTopDocs().scoreDocs(); - assertEquals(5, scoreDocs.length); - assertEquals(100.f, reduce.maxScore(), 0.0f); - assertEquals(12, reduce.totalHits().value); - assertEquals(95.0f, scoreDocs[0].score, 0.0f); - assertEquals(94.0f, scoreDocs[1].score, 0.0f); - assertEquals(93.0f, scoreDocs[2].score, 0.0f); - assertEquals(92.0f, scoreDocs[3].score, 0.0f); - assertEquals(91.0f, scoreDocs[4].score, 0.0f); } public void testConsumerSortByField() throws Exception { @@ -974,41 +1003,45 @@ public void testConsumerSortByField() throws Exception { expectedNumResults, exc -> {} ); - AtomicInteger max = new AtomicInteger(); - SortField[] sortFields = { new SortField("field", SortField.Type.INT, true) }; - DocValueFormat[] docValueFormats = { DocValueFormat.RAW }; - CountDownLatch latch = new CountDownLatch(expectedNumResults); - for (int i = 0; i < expectedNumResults; i++) { - int number = randomIntBetween(1, 1000); - max.updateAndGet(prev -> Math.max(prev, number)); - FieldDoc[] fieldDocs = { new FieldDoc(0, Float.NaN, new Object[] { number }) }; - TopDocs topDocs = new TopFieldDocs(new TotalHits(1, Relation.EQUAL_TO), fieldDocs, sortFields); - QuerySearchResult result = new QuerySearchResult( - new ShardSearchContextId("", i), - new SearchShardTarget("node", new ShardId("a", "b", i), null), - null - ); - try { - result.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), docValueFormats); - result.setShardIndex(i); - result.size(size); - consumer.consumeResult(result, latch::countDown); - } finally { - result.decRef(); + try { + AtomicInteger max = new AtomicInteger(); + SortField[] sortFields = { new SortField("field", SortField.Type.INT, true) }; + DocValueFormat[] docValueFormats = { DocValueFormat.RAW }; + CountDownLatch latch = new CountDownLatch(expectedNumResults); + for (int i = 0; i < expectedNumResults; i++) { + int number = randomIntBetween(1, 1000); + max.updateAndGet(prev -> Math.max(prev, number)); + FieldDoc[] fieldDocs = { new FieldDoc(0, Float.NaN, new Object[] { number }) }; + TopDocs topDocs = new TopFieldDocs(new TotalHits(1, Relation.EQUAL_TO), fieldDocs, sortFields); + QuerySearchResult result = new QuerySearchResult( + new ShardSearchContextId("", i), + new SearchShardTarget("node", new ShardId("a", "b", i), null), + null + ); + try { + result.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), docValueFormats); + result.setShardIndex(i); + result.size(size); + consumer.consumeResult(result, latch::countDown); + } finally { + result.decRef(); + } } + latch.await(); + SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); + assertAggReduction(request); + assertEquals(Math.min(expectedNumResults, size), reduce.sortedTopDocs().scoreDocs().length); + assertEquals(expectedNumResults, reduce.totalHits().value); + assertEquals(max.get(), ((FieldDoc) reduce.sortedTopDocs().scoreDocs()[0]).fields[0]); + assertTrue(reduce.sortedTopDocs().isSortedByField()); + assertEquals(1, reduce.sortedTopDocs().sortFields().length); + assertEquals("field", reduce.sortedTopDocs().sortFields()[0].getField()); + assertEquals(SortField.Type.INT, reduce.sortedTopDocs().sortFields()[0].getType()); + assertNull(reduce.sortedTopDocs().collapseField()); + assertNull(reduce.sortedTopDocs().collapseValues()); + } finally { + consumer.decRef(); } - latch.await(); - SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); - assertAggReduction(request); - assertEquals(Math.min(expectedNumResults, size), reduce.sortedTopDocs().scoreDocs().length); - assertEquals(expectedNumResults, reduce.totalHits().value); - assertEquals(max.get(), ((FieldDoc) reduce.sortedTopDocs().scoreDocs()[0]).fields[0]); - assertTrue(reduce.sortedTopDocs().isSortedByField()); - assertEquals(1, reduce.sortedTopDocs().sortFields().length); - assertEquals("field", reduce.sortedTopDocs().sortFields()[0].getField()); - assertEquals(SortField.Type.INT, reduce.sortedTopDocs().sortFields()[0].getType()); - assertNull(reduce.sortedTopDocs().collapseField()); - assertNull(reduce.sortedTopDocs().collapseValues()); } public void testConsumerFieldCollapsing() throws Exception { @@ -1026,45 +1059,49 @@ public void testConsumerFieldCollapsing() throws Exception { expectedNumResults, exc -> {} ); - SortField[] sortFields = { new SortField("field", SortField.Type.STRING) }; - BytesRef a = new BytesRef("a"); - BytesRef b = new BytesRef("b"); - BytesRef c = new BytesRef("c"); - Object[] collapseValues = new Object[] { a, b, c }; - DocValueFormat[] docValueFormats = { DocValueFormat.RAW }; - CountDownLatch latch = new CountDownLatch(expectedNumResults); - for (int i = 0; i < expectedNumResults; i++) { - Object[] values = { randomFrom(collapseValues) }; - FieldDoc[] fieldDocs = { new FieldDoc(0, Float.NaN, values) }; - TopDocs topDocs = new TopFieldGroups("field", new TotalHits(1, Relation.EQUAL_TO), fieldDocs, sortFields, values); - QuerySearchResult result = new QuerySearchResult( - new ShardSearchContextId("", i), - new SearchShardTarget("node", new ShardId("a", "b", i), null), - null - ); - try { - result.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), docValueFormats); - result.setShardIndex(i); - result.size(size); - consumer.consumeResult(result, latch::countDown); - } finally { - result.decRef(); + try { + SortField[] sortFields = { new SortField("field", SortField.Type.STRING) }; + BytesRef a = new BytesRef("a"); + BytesRef b = new BytesRef("b"); + BytesRef c = new BytesRef("c"); + Object[] collapseValues = new Object[] { a, b, c }; + DocValueFormat[] docValueFormats = { DocValueFormat.RAW }; + CountDownLatch latch = new CountDownLatch(expectedNumResults); + for (int i = 0; i < expectedNumResults; i++) { + Object[] values = { randomFrom(collapseValues) }; + FieldDoc[] fieldDocs = { new FieldDoc(0, Float.NaN, values) }; + TopDocs topDocs = new TopFieldGroups("field", new TotalHits(1, Relation.EQUAL_TO), fieldDocs, sortFields, values); + QuerySearchResult result = new QuerySearchResult( + new ShardSearchContextId("", i), + new SearchShardTarget("node", new ShardId("a", "b", i), null), + null + ); + try { + result.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), docValueFormats); + result.setShardIndex(i); + result.size(size); + consumer.consumeResult(result, latch::countDown); + } finally { + result.decRef(); + } } + latch.await(); + SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); + assertAggReduction(request); + assertEquals(3, reduce.sortedTopDocs().scoreDocs().length); + assertEquals(expectedNumResults, reduce.totalHits().value); + assertEquals(a, ((FieldDoc) reduce.sortedTopDocs().scoreDocs()[0]).fields[0]); + assertEquals(b, ((FieldDoc) reduce.sortedTopDocs().scoreDocs()[1]).fields[0]); + assertEquals(c, ((FieldDoc) reduce.sortedTopDocs().scoreDocs()[2]).fields[0]); + assertTrue(reduce.sortedTopDocs().isSortedByField()); + assertEquals(1, reduce.sortedTopDocs().sortFields().length); + assertEquals("field", reduce.sortedTopDocs().sortFields()[0].getField()); + assertEquals(SortField.Type.STRING, reduce.sortedTopDocs().sortFields()[0].getType()); + assertEquals("field", reduce.sortedTopDocs().collapseField()); + assertArrayEquals(collapseValues, reduce.sortedTopDocs().collapseValues()); + } finally { + consumer.decRef(); } - latch.await(); - SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); - assertAggReduction(request); - assertEquals(3, reduce.sortedTopDocs().scoreDocs().length); - assertEquals(expectedNumResults, reduce.totalHits().value); - assertEquals(a, ((FieldDoc) reduce.sortedTopDocs().scoreDocs()[0]).fields[0]); - assertEquals(b, ((FieldDoc) reduce.sortedTopDocs().scoreDocs()[1]).fields[0]); - assertEquals(c, ((FieldDoc) reduce.sortedTopDocs().scoreDocs()[2]).fields[0]); - assertTrue(reduce.sortedTopDocs().isSortedByField()); - assertEquals(1, reduce.sortedTopDocs().sortFields().length); - assertEquals("field", reduce.sortedTopDocs().sortFields()[0].getField()); - assertEquals(SortField.Type.STRING, reduce.sortedTopDocs().sortFields()[0].getType()); - assertEquals("field", reduce.sortedTopDocs().collapseField()); - assertArrayEquals(collapseValues, reduce.sortedTopDocs().collapseValues()); } public void testConsumerSuggestions() throws Exception { @@ -1081,102 +1118,106 @@ public void testConsumerSuggestions() throws Exception { expectedNumResults, exc -> {} ); - int maxScoreTerm = -1; - int maxScorePhrase = -1; - int maxScoreCompletion = -1; - CountDownLatch latch = new CountDownLatch(expectedNumResults); - for (int i = 0; i < expectedNumResults; i++) { - QuerySearchResult result = new QuerySearchResult( - new ShardSearchContextId("", i), - new SearchShardTarget("node", new ShardId("a", "b", i), null), - null - ); - try { - List>> suggestions = - new ArrayList<>(); - { - TermSuggestion termSuggestion = new TermSuggestion("term", 1, SortBy.SCORE); - TermSuggestion.Entry entry = new TermSuggestion.Entry(new Text("entry"), 0, 10); - int numOptions = randomIntBetween(1, 10); - for (int j = 0; j < numOptions; j++) { - int score = numOptions - j; - maxScoreTerm = Math.max(maxScoreTerm, score); - entry.addOption(new TermSuggestion.Entry.Option(new Text("option"), randomInt(), score)); + try { + int maxScoreTerm = -1; + int maxScorePhrase = -1; + int maxScoreCompletion = -1; + CountDownLatch latch = new CountDownLatch(expectedNumResults); + for (int i = 0; i < expectedNumResults; i++) { + QuerySearchResult result = new QuerySearchResult( + new ShardSearchContextId("", i), + new SearchShardTarget("node", new ShardId("a", "b", i), null), + null + ); + try { + List>> suggestions = + new ArrayList<>(); + { + TermSuggestion termSuggestion = new TermSuggestion("term", 1, SortBy.SCORE); + TermSuggestion.Entry entry = new TermSuggestion.Entry(new Text("entry"), 0, 10); + int numOptions = randomIntBetween(1, 10); + for (int j = 0; j < numOptions; j++) { + int score = numOptions - j; + maxScoreTerm = Math.max(maxScoreTerm, score); + entry.addOption(new TermSuggestion.Entry.Option(new Text("option"), randomInt(), score)); + } + termSuggestion.addTerm(entry); + suggestions.add(termSuggestion); } - termSuggestion.addTerm(entry); - suggestions.add(termSuggestion); - } - { - PhraseSuggestion phraseSuggestion = new PhraseSuggestion("phrase", 1); - PhraseSuggestion.Entry entry = new PhraseSuggestion.Entry(new Text("entry"), 0, 10); - int numOptions = randomIntBetween(1, 10); - for (int j = 0; j < numOptions; j++) { - int score = numOptions - j; - maxScorePhrase = Math.max(maxScorePhrase, score); - entry.addOption(new PhraseSuggestion.Entry.Option(new Text("option"), new Text("option"), score)); + { + PhraseSuggestion phraseSuggestion = new PhraseSuggestion("phrase", 1); + PhraseSuggestion.Entry entry = new PhraseSuggestion.Entry(new Text("entry"), 0, 10); + int numOptions = randomIntBetween(1, 10); + for (int j = 0; j < numOptions; j++) { + int score = numOptions - j; + maxScorePhrase = Math.max(maxScorePhrase, score); + entry.addOption(new PhraseSuggestion.Entry.Option(new Text("option"), new Text("option"), score)); + } + phraseSuggestion.addTerm(entry); + suggestions.add(phraseSuggestion); } - phraseSuggestion.addTerm(entry); - suggestions.add(phraseSuggestion); - } - { - CompletionSuggestion completionSuggestion = new CompletionSuggestion("completion", 1, false); - CompletionSuggestion.Entry entry = new CompletionSuggestion.Entry(new Text("entry"), 0, 10); - int numOptions = randomIntBetween(1, 10); - for (int j = 0; j < numOptions; j++) { - int score = numOptions - j; - maxScoreCompletion = Math.max(maxScoreCompletion, score); - CompletionSuggestion.Entry.Option option = new CompletionSuggestion.Entry.Option( - j, - new Text("option"), - score, - Collections.emptyMap() - ); - entry.addOption(option); + { + CompletionSuggestion completionSuggestion = new CompletionSuggestion("completion", 1, false); + CompletionSuggestion.Entry entry = new CompletionSuggestion.Entry(new Text("entry"), 0, 10); + int numOptions = randomIntBetween(1, 10); + for (int j = 0; j < numOptions; j++) { + int score = numOptions - j; + maxScoreCompletion = Math.max(maxScoreCompletion, score); + CompletionSuggestion.Entry.Option option = new CompletionSuggestion.Entry.Option( + j, + new Text("option"), + score, + Collections.emptyMap() + ); + entry.addOption(option); + } + completionSuggestion.addTerm(entry); + suggestions.add(completionSuggestion); } - completionSuggestion.addTerm(entry); - suggestions.add(completionSuggestion); + result.suggest(new Suggest(suggestions)); + result.topDocs(new TopDocsAndMaxScore(Lucene.EMPTY_TOP_DOCS, Float.NaN), new DocValueFormat[0]); + result.setShardIndex(i); + result.size(0); + consumer.consumeResult(result, latch::countDown); + } finally { + result.decRef(); } - result.suggest(new Suggest(suggestions)); - result.topDocs(new TopDocsAndMaxScore(Lucene.EMPTY_TOP_DOCS, Float.NaN), new DocValueFormat[0]); - result.setShardIndex(i); - result.size(0); - consumer.consumeResult(result, latch::countDown); - } finally { - result.decRef(); } + latch.await(); + SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); + assertEquals(3, reduce.suggest().size()); + { + TermSuggestion term = reduce.suggest().getSuggestion("term"); + assertEquals(1, term.getEntries().size()); + assertEquals(1, term.getEntries().get(0).getOptions().size()); + assertEquals(maxScoreTerm, term.getEntries().get(0).getOptions().get(0).getScore(), 0f); + } + { + PhraseSuggestion phrase = reduce.suggest().getSuggestion("phrase"); + assertEquals(1, phrase.getEntries().size()); + assertEquals(1, phrase.getEntries().get(0).getOptions().size()); + assertEquals(maxScorePhrase, phrase.getEntries().get(0).getOptions().get(0).getScore(), 0f); + } + { + CompletionSuggestion completion = reduce.suggest().getSuggestion("completion"); + assertEquals(1, completion.getSize()); + assertEquals(1, completion.getOptions().size()); + CompletionSuggestion.Entry.Option option = completion.getOptions().get(0); + assertEquals(maxScoreCompletion, option.getScore(), 0f); + } + assertAggReduction(request); + assertEquals(1, reduce.sortedTopDocs().scoreDocs().length); + assertEquals(maxScoreCompletion, reduce.sortedTopDocs().scoreDocs()[0].score, 0f); + assertEquals(0, reduce.sortedTopDocs().scoreDocs()[0].doc); + assertNotEquals(-1, reduce.sortedTopDocs().scoreDocs()[0].shardIndex); + assertEquals(0, reduce.totalHits().value); + assertFalse(reduce.sortedTopDocs().isSortedByField()); + assertNull(reduce.sortedTopDocs().sortFields()); + assertNull(reduce.sortedTopDocs().collapseField()); + assertNull(reduce.sortedTopDocs().collapseValues()); + } finally { + consumer.decRef(); } - latch.await(); - SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); - assertEquals(3, reduce.suggest().size()); - { - TermSuggestion term = reduce.suggest().getSuggestion("term"); - assertEquals(1, term.getEntries().size()); - assertEquals(1, term.getEntries().get(0).getOptions().size()); - assertEquals(maxScoreTerm, term.getEntries().get(0).getOptions().get(0).getScore(), 0f); - } - { - PhraseSuggestion phrase = reduce.suggest().getSuggestion("phrase"); - assertEquals(1, phrase.getEntries().size()); - assertEquals(1, phrase.getEntries().get(0).getOptions().size()); - assertEquals(maxScorePhrase, phrase.getEntries().get(0).getOptions().get(0).getScore(), 0f); - } - { - CompletionSuggestion completion = reduce.suggest().getSuggestion("completion"); - assertEquals(1, completion.getSize()); - assertEquals(1, completion.getOptions().size()); - CompletionSuggestion.Entry.Option option = completion.getOptions().get(0); - assertEquals(maxScoreCompletion, option.getScore(), 0f); - } - assertAggReduction(request); - assertEquals(1, reduce.sortedTopDocs().scoreDocs().length); - assertEquals(maxScoreCompletion, reduce.sortedTopDocs().scoreDocs()[0].score, 0f); - assertEquals(0, reduce.sortedTopDocs().scoreDocs()[0].doc); - assertNotEquals(-1, reduce.sortedTopDocs().scoreDocs()[0].shardIndex); - assertEquals(0, reduce.totalHits().value); - assertFalse(reduce.sortedTopDocs().isSortedByField()); - assertNull(reduce.sortedTopDocs().sortFields()); - assertNull(reduce.sortedTopDocs().collapseField()); - assertNull(reduce.sortedTopDocs().collapseValues()); } public void testProgressListener() throws Exception { @@ -1224,63 +1265,67 @@ public void onFinalReduce(List shards, TotalHits totalHits, Interna expectedNumResults, exc -> {} ); - AtomicInteger max = new AtomicInteger(); - Thread[] threads = new Thread[expectedNumResults]; - CountDownLatch latch = new CountDownLatch(expectedNumResults); - for (int i = 0; i < expectedNumResults; i++) { - int id = i; - threads[i] = new Thread(() -> { - int number = randomIntBetween(1, 1000); - max.updateAndGet(prev -> Math.max(prev, number)); - QuerySearchResult result = new QuerySearchResult( - new ShardSearchContextId("", id), - new SearchShardTarget("node", new ShardId("a", "b", id), null), - null - ); - try { - result.topDocs( - new TopDocsAndMaxScore( - new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(0, number) }), - number - ), - new DocValueFormat[0] - ); - InternalAggregations aggs = InternalAggregations.from( - Collections.singletonList(new Max("test", (double) number, DocValueFormat.RAW, Collections.emptyMap())) + try { + AtomicInteger max = new AtomicInteger(); + Thread[] threads = new Thread[expectedNumResults]; + CountDownLatch latch = new CountDownLatch(expectedNumResults); + for (int i = 0; i < expectedNumResults; i++) { + int id = i; + threads[i] = new Thread(() -> { + int number = randomIntBetween(1, 1000); + max.updateAndGet(prev -> Math.max(prev, number)); + QuerySearchResult result = new QuerySearchResult( + new ShardSearchContextId("", id), + new SearchShardTarget("node", new ShardId("a", "b", id), null), + null ); - result.aggregations(aggs); - result.setShardIndex(id); - result.size(1); - consumer.consumeResult(result, latch::countDown); - } finally { - result.decRef(); - } - }); - threads[i].start(); - } - for (int i = 0; i < expectedNumResults; i++) { - threads[i].join(); - } - latch.await(); - SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); - assertAggReduction(request); - Max internalMax = (Max) reduce.aggregations().asList().get(0); - assertEquals(max.get(), internalMax.value(), 0.0D); - assertEquals(1, reduce.sortedTopDocs().scoreDocs().length); - assertEquals(max.get(), reduce.maxScore(), 0.0f); - assertEquals(expectedNumResults, reduce.totalHits().value); - assertEquals(max.get(), reduce.sortedTopDocs().scoreDocs()[0].score, 0.0f); - assertFalse(reduce.sortedTopDocs().isSortedByField()); - assertNull(reduce.sortedTopDocs().sortFields()); - assertNull(reduce.sortedTopDocs().collapseField()); - assertNull(reduce.sortedTopDocs().collapseValues()); + try { + result.topDocs( + new TopDocsAndMaxScore( + new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(0, number) }), + number + ), + new DocValueFormat[0] + ); + InternalAggregations aggs = InternalAggregations.from( + Collections.singletonList(new Max("test", (double) number, DocValueFormat.RAW, Collections.emptyMap())) + ); + result.aggregations(aggs); + result.setShardIndex(id); + result.size(1); + consumer.consumeResult(result, latch::countDown); + } finally { + result.decRef(); + } + }); + threads[i].start(); + } + for (int i = 0; i < expectedNumResults; i++) { + threads[i].join(); + } + latch.await(); + SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce(); + assertAggReduction(request); + Max internalMax = (Max) reduce.aggregations().asList().get(0); + assertEquals(max.get(), internalMax.value(), 0.0D); + assertEquals(1, reduce.sortedTopDocs().scoreDocs().length); + assertEquals(max.get(), reduce.maxScore(), 0.0f); + assertEquals(expectedNumResults, reduce.totalHits().value); + assertEquals(max.get(), reduce.sortedTopDocs().scoreDocs()[0].score, 0.0f); + assertFalse(reduce.sortedTopDocs().isSortedByField()); + assertNull(reduce.sortedTopDocs().sortFields()); + assertNull(reduce.sortedTopDocs().collapseField()); + assertNull(reduce.sortedTopDocs().collapseValues()); - assertEquals(reduce.aggregations(), finalAggsListener.get()); - assertEquals(reduce.totalHits(), totalHitsListener.get()); + assertEquals(reduce.aggregations(), finalAggsListener.get()); + assertEquals(reduce.totalHits(), totalHitsListener.get()); - assertEquals(expectedNumResults, numQueryResultListener.get()); - assertEquals(0, numQueryFailureListener.get()); - assertEquals(numReduceListener.get(), reduce.numReducePhases()); + assertEquals(expectedNumResults, numQueryResultListener.get()); + assertEquals(0, numQueryFailureListener.get()); + assertEquals(numReduceListener.get(), reduce.numReducePhases()); + } finally { + consumer.decRef(); + } } } @@ -1311,55 +1356,58 @@ private void testReduceCase(int numShards, int bufferSize, boolean shouldFail) t numShards, exc -> hasConsumedFailure.set(true) ); - CountDownLatch latch = new CountDownLatch(numShards); - Thread[] threads = new Thread[numShards]; - for (int i = 0; i < numShards; i++) { - final int index = i; - threads[index] = new Thread(() -> { - QuerySearchResult result = new QuerySearchResult( - new ShardSearchContextId(UUIDs.randomBase64UUID(), index), - new SearchShardTarget("node", new ShardId("a", "b", index), null), - null - ); - try { - result.topDocs( - new TopDocsAndMaxScore( - new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), Lucene.EMPTY_SCORE_DOCS), - Float.NaN - ), - new DocValueFormat[0] - ); - InternalAggregations aggs = InternalAggregations.from( - Collections.singletonList(new Max("test", 0d, DocValueFormat.RAW, Collections.emptyMap())) + try { + CountDownLatch latch = new CountDownLatch(numShards); + Thread[] threads = new Thread[numShards]; + for (int i = 0; i < numShards; i++) { + final int index = i; + threads[index] = new Thread(() -> { + QuerySearchResult result = new QuerySearchResult( + new ShardSearchContextId(UUIDs.randomBase64UUID(), index), + new SearchShardTarget("node", new ShardId("a", "b", index), null), + null ); - result.aggregations(aggs); - result.setShardIndex(index); - result.size(1); - consumer.consumeResult(result, latch::countDown); - } finally { - result.decRef(); + try { + result.topDocs( + new TopDocsAndMaxScore( + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), Lucene.EMPTY_SCORE_DOCS), + Float.NaN + ), + new DocValueFormat[0] + ); + InternalAggregations aggs = InternalAggregations.from( + Collections.singletonList(new Max("test", 0d, DocValueFormat.RAW, Collections.emptyMap())) + ); + result.aggregations(aggs); + result.setShardIndex(index); + result.size(1); + consumer.consumeResult(result, latch::countDown); + } finally { + result.decRef(); + } + }); + threads[index].start(); + } + for (int i = 0; i < numShards; i++) { + threads[i].join(); + } + latch.await(); + if (shouldFail) { + if (shouldFailPartial == false) { + circuitBreaker.shouldBreak.set(true); + } else { + circuitBreaker.shouldBreak.set(false); } - }); - threads[index].start(); - } - for (int i = 0; i < numShards; i++) { - threads[i].join(); - } - latch.await(); - if (shouldFail) { - if (shouldFailPartial == false) { - circuitBreaker.shouldBreak.set(true); - } else { + CircuitBreakingException exc = expectThrows(CircuitBreakingException.class, () -> consumer.reduce()); + assertEquals(shouldFailPartial, hasConsumedFailure.get()); + assertThat(exc.getMessage(), containsString("")); circuitBreaker.shouldBreak.set(false); + } else { + consumer.reduce(); } - CircuitBreakingException exc = expectThrows(CircuitBreakingException.class, () -> consumer.reduce()); - assertEquals(shouldFailPartial, hasConsumedFailure.get()); - assertThat(exc.getMessage(), containsString("")); - circuitBreaker.shouldBreak.set(false); - } else { - consumer.reduce(); + } finally { + consumer.decRef(); } - consumer.close(); assertThat(circuitBreaker.allocated, equalTo(0L)); } @@ -1371,17 +1419,16 @@ public void testFailConsumeAggs() throws Exception { request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")).size(0)); request.setBatchedReduceSize(bufferSize); AtomicBoolean hasConsumedFailure = new AtomicBoolean(); - try ( - QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults( - fixedExecutor, - new NoopCircuitBreaker(CircuitBreaker.REQUEST), - () -> false, - SearchProgressListener.NOOP, - request, - expectedNumResults, - exc -> hasConsumedFailure.set(true) - ) - ) { + QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults( + fixedExecutor, + new NoopCircuitBreaker(CircuitBreaker.REQUEST), + () -> false, + SearchProgressListener.NOOP, + request, + expectedNumResults, + exc -> hasConsumedFailure.set(true) + ); + try { for (int i = 0; i < expectedNumResults; i++) { final int index = i; QuerySearchResult result = new QuerySearchResult( @@ -1406,6 +1453,8 @@ public void testFailConsumeAggs() throws Exception { } } assertNull(consumer.reduce().aggregations()); + } finally { + consumer.decRef(); } } diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java index 7270326933dea..3097376de7a41 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java @@ -203,59 +203,63 @@ public void sendExecuteQuery( shardsIter.size(), exc -> {} ); - SearchQueryThenFetchAsyncAction action = new SearchQueryThenFetchAsyncAction( - logger, - searchTransportService, - (clusterAlias, node) -> lookup.get(node), - Collections.singletonMap("_na_", AliasFilter.EMPTY), - Collections.emptyMap(), - EsExecutors.DIRECT_EXECUTOR_SERVICE, - resultConsumer, - searchRequest, - null, - shardsIter, - timeProvider, - new ClusterState.Builder(new ClusterName("test")).build(), - task, - SearchResponse.Clusters.EMPTY - ) { - @Override - protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { - return new SearchPhase("test") { - @Override - public void run() { - latch.countDown(); - } - }; - } - }; - action.start(); - latch.await(); - assertThat(successfulOps.get(), equalTo(numShards)); - if (withScroll) { - assertFalse(canReturnNullResponse.get()); - assertThat(numWithTopDocs.get(), equalTo(0)); - } else { - assertTrue(canReturnNullResponse.get()); - if (withCollapse) { + try { + SearchQueryThenFetchAsyncAction action = new SearchQueryThenFetchAsyncAction( + logger, + searchTransportService, + (clusterAlias, node) -> lookup.get(node), + Collections.singletonMap("_na_", AliasFilter.EMPTY), + Collections.emptyMap(), + EsExecutors.DIRECT_EXECUTOR_SERVICE, + resultConsumer, + searchRequest, + null, + shardsIter, + timeProvider, + new ClusterState.Builder(new ClusterName("test")).build(), + task, + SearchResponse.Clusters.EMPTY + ) { + @Override + protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { + return new SearchPhase("test") { + @Override + public void run() { + latch.countDown(); + } + }; + } + }; + action.start(); + latch.await(); + assertThat(successfulOps.get(), equalTo(numShards)); + if (withScroll) { + assertFalse(canReturnNullResponse.get()); assertThat(numWithTopDocs.get(), equalTo(0)); } else { - assertThat(numWithTopDocs.get(), greaterThanOrEqualTo(1)); + assertTrue(canReturnNullResponse.get()); + if (withCollapse) { + assertThat(numWithTopDocs.get(), equalTo(0)); + } else { + assertThat(numWithTopDocs.get(), greaterThanOrEqualTo(1)); + } } + SearchPhaseController.ReducedQueryPhase phase = action.results.reduce(); + assertThat(phase.numReducePhases(), greaterThanOrEqualTo(1)); + if (withScroll) { + assertThat(phase.totalHits().value, equalTo((long) numShards)); + assertThat(phase.totalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); + } else { + assertThat(phase.totalHits().value, equalTo(2L)); + assertThat(phase.totalHits().relation, equalTo(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO)); + } + assertThat(phase.sortedTopDocs().scoreDocs().length, equalTo(1)); + assertThat(phase.sortedTopDocs().scoreDocs()[0], instanceOf(FieldDoc.class)); + assertThat(((FieldDoc) phase.sortedTopDocs().scoreDocs()[0]).fields.length, equalTo(1)); + assertThat(((FieldDoc) phase.sortedTopDocs().scoreDocs()[0]).fields[0], equalTo(0)); + } finally { + resultConsumer.decRef(); } - SearchPhaseController.ReducedQueryPhase phase = action.results.reduce(); - assertThat(phase.numReducePhases(), greaterThanOrEqualTo(1)); - if (withScroll) { - assertThat(phase.totalHits().value, equalTo((long) numShards)); - assertThat(phase.totalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); - } else { - assertThat(phase.totalHits().value, equalTo(2L)); - assertThat(phase.totalHits().relation, equalTo(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO)); - } - assertThat(phase.sortedTopDocs().scoreDocs().length, equalTo(1)); - assertThat(phase.sortedTopDocs().scoreDocs()[0], instanceOf(FieldDoc.class)); - assertThat(((FieldDoc) phase.sortedTopDocs().scoreDocs()[0]).fields.length, equalTo(1)); - assertThat(((FieldDoc) phase.sortedTopDocs().scoreDocs()[0]).fields[0], equalTo(0)); } @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/101932") diff --git a/server/src/test/java/org/elasticsearch/search/profile/SearchProfileResultsBuilderTests.java b/server/src/test/java/org/elasticsearch/search/profile/SearchProfileResultsBuilderTests.java index 9d801f0303386..66879e5a90a3f 100644 --- a/server/src/test/java/org/elasticsearch/search/profile/SearchProfileResultsBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/profile/SearchProfileResultsBuilderTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.fetch.FetchSearchResult; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.TransportMessage; import java.util.List; import java.util.Map; @@ -30,24 +31,32 @@ public void testFetchWithoutQuery() { randomValueOtherThanMany(searchPhase::containsKey, SearchProfileResultsBuilderTests::randomTarget), null ); - Exception e = expectThrows(IllegalStateException.class, () -> builder(searchPhase).build(List.of(fetchPhase))); - assertThat( - e.getMessage(), - matchesPattern( - "Profile returned fetch phase information for .+ but didn't return query phase information\\. Query phase keys were .+" - ) - ); + try { + Exception e = expectThrows(IllegalStateException.class, () -> builder(searchPhase).build(List.of(fetchPhase))); + assertThat( + e.getMessage(), + matchesPattern( + "Profile returned fetch phase information for .+ but didn't return query phase information\\. Query phase keys were .+" + ) + ); + } finally { + fetchPhase.decRef(); + } } public void testQueryWithoutAnyFetch() { Map searchPhase = randomSearchPhaseResults(between(1, 2)); FetchSearchResult fetchPhase = fetchResult(searchPhase.keySet().iterator().next(), null); - SearchProfileResults result = builder(searchPhase).build(List.of(fetchPhase)); - assertThat( - result.getShardResults().values().stream().filter(r -> r.getQueryPhase() != null).count(), - equalTo((long) searchPhase.size()) - ); - assertThat(result.getShardResults().values().stream().filter(r -> r.getFetchPhase() != null).count(), equalTo(0L)); + try { + SearchProfileResults result = builder(searchPhase).build(List.of(fetchPhase)); + assertThat( + result.getShardResults().values().stream().filter(r -> r.getQueryPhase() != null).count(), + equalTo((long) searchPhase.size()) + ); + assertThat(result.getShardResults().values().stream().filter(r -> r.getFetchPhase() != null).count(), equalTo(0L)); + } finally { + fetchPhase.decRef(); + } } public void testQueryAndFetch() { @@ -56,15 +65,19 @@ public void testQueryAndFetch() { .stream() .map(e -> fetchResult(e.getKey(), new ProfileResult("fetch", "", Map.of(), Map.of(), 1, List.of()))) .collect(toList()); - SearchProfileResults result = builder(searchPhase).build(fetchPhase); - assertThat( - result.getShardResults().values().stream().filter(r -> r.getQueryPhase() != null).count(), - equalTo((long) searchPhase.size()) - ); - assertThat( - result.getShardResults().values().stream().filter(r -> r.getFetchPhase() != null).count(), - equalTo((long) searchPhase.size()) - ); + try { + SearchProfileResults result = builder(searchPhase).build(fetchPhase); + assertThat( + result.getShardResults().values().stream().filter(r -> r.getQueryPhase() != null).count(), + equalTo((long) searchPhase.size()) + ); + assertThat( + result.getShardResults().values().stream().filter(r -> r.getFetchPhase() != null).count(), + equalTo((long) searchPhase.size()) + ); + } finally { + fetchPhase.forEach(TransportMessage::decRef); + } } private static Map randomSearchPhaseResults(int size) { From 6af01d05678829fd44253b0a506de3a37b29ed02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Istv=C3=A1n=20Zolt=C3=A1n=20Szab=C3=B3?= Date: Tue, 21 Nov 2023 12:21:34 +0100 Subject: [PATCH 12/22] [DOCS] Adds Search Labs links to the ES landing page (#102401) * [DOCS] Adds Search Labs links to the ES landing page. * [DOCS] Addresses feedback. --- docs/reference/landing-page.asciidoc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/reference/landing-page.asciidoc b/docs/reference/landing-page.asciidoc index a53a5770fe030..5d955acb0c6c3 100644 --- a/docs/reference/landing-page.asciidoc +++ b/docs/reference/landing-page.asciidoc @@ -215,6 +215,12 @@
  • Plugins and integrations
  • +
  • + Search Labs +
  • +
  • + Notebook examples +
  • From eb443417eab689ac833eec00f628ba2d59a449b3 Mon Sep 17 00:00:00 2001 From: Venkata Krishnan <110184468+pernelkanic@users.noreply.github.com> Date: Tue, 21 Nov 2023 17:05:19 +0530 Subject: [PATCH 13/22] Update scripted-metric-aggregation.asciidoc (#101899) --- .../aggregations/metrics/scripted-metric-aggregation.asciidoc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/reference/aggregations/metrics/scripted-metric-aggregation.asciidoc b/docs/reference/aggregations/metrics/scripted-metric-aggregation.asciidoc index 2bedcd4698b42..d7d837b2f8364 100644 --- a/docs/reference/aggregations/metrics/scripted-metric-aggregation.asciidoc +++ b/docs/reference/aggregations/metrics/scripted-metric-aggregation.asciidoc @@ -127,8 +127,7 @@ init_script:: Executed prior to any collection of documents. Allows the ag + In the above example, the `init_script` creates an array `transactions` in the `state` object. -map_script:: Executed once per document collected. This is a required script. If no combine_script is specified, the resulting state - needs to be stored in the `state` object. +map_script:: Executed once per document collected. This is a required script. + In the above example, the `map_script` checks the value of the type field. If the value is 'sale' the value of the amount field is added to the transactions array. If the value of the type field is not 'sale' the negated value of the amount field is added From bd10775b025ea5ae2062388aaec263080f0695b9 Mon Sep 17 00:00:00 2001 From: Albert Zaharovits Date: Tue, 21 Nov 2023 14:11:08 +0200 Subject: [PATCH 14/22] Grant API Key API with JWTs (#101904) Introduces support for JWTs to the grant API Key API. Callers can now pass-in a JWT in the request, like: POST /_security/api_key/grant { "grant_type": "access_token", "access_token" : "some.signed.JWT", "client_authentication": { // optional "scheme": "SharedSecret", "value": "ES-Client-Authentication header value after scheme" } } The JWT will be authenticated by a backing JWT realm and a new API Key will be returned for the authenticated user. --- docs/changelog/101904.yaml | 5 + .../rest-api/security/grant-api-keys.asciidoc | 43 +++- .../authentication/jwt-realm.asciidoc | 6 +- .../org/elasticsearch/TransportVersions.java | 1 + x-pack/plugin/core/build.gradle | 23 +- .../licenses/nimbus-jose-jwt-LICENSE.txt | 0 .../licenses/nimbus-jose-jwt-NOTICE.txt | 0 .../core/src/main/java/module-info.java | 1 + .../xpack/core/security/action/Grant.java | 67 ++++- .../authc/jwt/JwtAuthenticationToken.java | 12 +- .../security/authc/jwt/JwtRealmSettings.java | 2 + .../core}/security/authc/jwt/JwtUtil.java | 3 +- .../src/javaRestTest/resources/jwk/README.md | 2 +- .../xpack/security/authc/JwtRealmAuthIT.java | 3 +- .../authc/apikey/ApiKeySingleNodeTests.java | 117 +++++++++ .../authc/jwt/JwtRealmSingleNodeTests.java | 232 ++++++++++++++---- .../profile/SecurityDomainIntegTests.java | 7 +- .../security/src/main/java/module-info.java | 1 - .../security/action/TransportGrantAction.java | 7 - .../security/authc/jwt/JwkSetLoader.java | 1 + .../security/authc/jwt/JwkValidateUtil.java | 1 + .../security/authc/jwt/JwtAuthenticator.java | 1 + .../xpack/security/authc/jwt/JwtRealm.java | 5 +- .../authc/jwt/JwtSignatureValidator.java | 3 +- .../oidc/OpenIdConnectAuthenticator.java | 2 +- .../action/apikey/RestGrantApiKeyAction.java | 27 +- .../TransportGrantApiKeyActionTests.java | 80 +++++- .../authc/jwt/JwtAuthenticatorTests.java | 1 + .../xpack/security/authc/jwt/JwtIssuer.java | 1 + .../authc/jwt/JwtRealmAuthenticateTests.java | 1 + .../authc/jwt/JwtRealmGenerateTests.java | 3 +- .../security/authc/jwt/JwtRealmInspector.java | 1 + .../security/authc/jwt/JwtRealmTestCase.java | 1 + .../xpack/security/authc/jwt/JwtTestCase.java | 2 +- .../authc/jwt/JwtTokenExtractionTests.java | 3 +- .../security/authc/jwt/JwtUtilTests.java | 1 + .../apikey/RestGrantApiKeyActionTests.java | 70 ++++++ .../xpack/security/authc/apikey/README.md | 21 ++ .../authc/apikey/rsa-private-jwkset.json | 23 ++ .../authc/apikey/rsa-public-jwkset.json | 16 ++ .../apikey/serialized-signed-RS256-jwt.txt | 1 + .../rest-api-spec/test/api_key/12_grant.yml | 36 +++ .../security/authc/jwt/JwtWithOidcAuthIT.java | 6 +- 43 files changed, 736 insertions(+), 103 deletions(-) create mode 100644 docs/changelog/101904.yaml rename x-pack/plugin/{security => core}/licenses/nimbus-jose-jwt-LICENSE.txt (100%) rename x-pack/plugin/{security => core}/licenses/nimbus-jose-jwt-NOTICE.txt (100%) rename x-pack/plugin/{security/src/main/java/org/elasticsearch/xpack => core/src/main/java/org/elasticsearch/xpack/core}/security/authc/jwt/JwtAuthenticationToken.java (91%) rename x-pack/plugin/{security/src/main/java/org/elasticsearch/xpack => core/src/main/java/org/elasticsearch/xpack/core}/security/authc/jwt/JwtUtil.java (99%) create mode 100644 x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/action/apikey/RestGrantApiKeyActionTests.java create mode 100644 x-pack/plugin/security/src/test/resources/org/elasticsearch/xpack/security/authc/apikey/README.md create mode 100644 x-pack/plugin/security/src/test/resources/org/elasticsearch/xpack/security/authc/apikey/rsa-private-jwkset.json create mode 100644 x-pack/plugin/security/src/test/resources/org/elasticsearch/xpack/security/authc/apikey/rsa-public-jwkset.json create mode 100644 x-pack/plugin/security/src/test/resources/org/elasticsearch/xpack/security/authc/apikey/serialized-signed-RS256-jwt.txt diff --git a/docs/changelog/101904.yaml b/docs/changelog/101904.yaml new file mode 100644 index 0000000000000..cad422cc52e15 --- /dev/null +++ b/docs/changelog/101904.yaml @@ -0,0 +1,5 @@ +pr: 101904 +summary: Allow granting API keys with JWT as the access_token +area: Security +type: feature +issues: [] diff --git a/docs/reference/rest-api/security/grant-api-keys.asciidoc b/docs/reference/rest-api/security/grant-api-keys.asciidoc index ad16f602d32c2..8feb6c3cd5f52 100644 --- a/docs/reference/rest-api/security/grant-api-keys.asciidoc +++ b/docs/reference/rest-api/security/grant-api-keys.asciidoc @@ -15,7 +15,7 @@ Creates an API key on behalf of another user. [[security-api-grant-api-key-prereqs]] ==== {api-prereq-title} -* To use this API, you must have the `grant_api_key` cluster privilege. +* To use this API, you must have the `grant_api_key` or the `manage_api_key` cluster privilege. [[security-api-grant-api-key-desc]] ==== {api-description-title} @@ -23,10 +23,13 @@ Creates an API key on behalf of another user. This API is similar to <>, however it creates the API key for a user that is different than the user that runs the API. -The caller must have authentication credentials (either an access token, -or a username and password) for the user on whose behalf the API key will be -created. It is not possible to use this API to create an API key without that -user's credentials. +The caller must have authentication credentials for the user on whose behalf +the API key will be created. It is not possible to use this API to create an +API key without that user's credentials. +The supported user authentication credentials types are: + * username and password + * <> + * <> The user, for whom the authentication credentials is provided, can optionally <> (impersonate) another user. @@ -55,8 +58,11 @@ The following parameters can be specified in the body of a POST request: `access_token`:: (Required*, string) -The user's access token. If you specify the `access_token` grant type, this -parameter is required. It is not valid with other grant types. +The user's <>, or JWT. Both <> and +<> JWT token types are supported, and they depend on the underlying JWT realm configuration. +The created API key will have a point in time snapshot of permissions of the user authenticated with this token +(or even more restricted permissions, see the `role_descriptors` parameter). +If you specify the `access_token` grant type, this parameter is required. It is not valid with other grant types. `api_key`:: (Required, object) @@ -83,15 +89,32 @@ It supports nested data structure. Within the `metadata` object, keys beginning with `_` are reserved for system usage. +`client_authentication`:: +(Optional, object) When using the `access_token` grant type, and when supplying a +JWT, this specifies the client authentication for <> that +need it (i.e. what's normally specified by the `ES-Client-Authentication` request header). + +`scheme`::: +(Required, string) The scheme (case-sensitive) as it's supplied in the +`ES-Client-Authentication` request header. Currently, the only supported +value is <>. + +`value`::: +(Required, string) The value that follows the scheme for the client credentials +as it's supplied in the `ES-Client-Authentication` request header. For example, +if the request header would be `ES-Client-Authentication: SharedSecret myShar3dS3cret` +if the client were to authenticate directly with a JWT, then `value` here should +be `myShar3dS3cret`. + `grant_type`:: (Required, string) The type of grant. Supported grant types are: `access_token`,`password`. `access_token`::: (Required*, string) -In this type of grant, you must supply an access token that was created by the -{es} token service. For more information, see -<> and <>. +In this type of grant, you must supply either an access token, that was created by the +{es} token service (see <> and <>), +or a <> (either a JWT `access_token` or a JWT `id_token`). `password`::: In this type of grant, you must supply the user ID and password for which you diff --git a/docs/reference/security/authentication/jwt-realm.asciidoc b/docs/reference/security/authentication/jwt-realm.asciidoc index 142c93286c2e9..68e20380449a5 100644 --- a/docs/reference/security/authentication/jwt-realm.asciidoc +++ b/docs/reference/security/authentication/jwt-realm.asciidoc @@ -123,8 +123,9 @@ Instructs the realm to treat and validate incoming JWTs as ID Tokens (`id_token` Specifies the client authentication type as `shared_secret`, which means that the client is authenticated using an HTTP request header that must match a pre-configured secret value. The client must provide this shared secret with -every request in the `ES-Client-Authentication` header. The header value must be a -case-sensitive match to the realm's `client_authentication.shared_secret`. +every request in the `ES-Client-Authentication` header and using the +`SharedSecret` scheme. The header value must be a case-sensitive match +to the realm's `client_authentication.shared_secret`. `allowed_issuer`:: Sets a verifiable identifier for your JWT issuer. This value is typically a @@ -519,6 +520,7 @@ After mapping the roles, you can make an <> to {es} using a JWT and include the `ES-Client-Authentication` header: +[[jwt-auth-shared-secret-scheme-example]] [source,sh] ---- curl -s -X GET -H "Authorization: Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJhdWQiOlsiZXMwMSIsImVzMDIiLCJlczAzIl0sInN1YiI6InVzZXIyIiwiaXNzIjoibXktaXNzdWVyIiwiZXhwIjo0MDcwOTA4ODAwLCJpYXQiOjk0NjY4NDgwMCwiZW1haWwiOiJ1c2VyMkBzb21ldGhpbmcuZXhhbXBsZS5jb20ifQ.UgO_9w--EoRyUKcWM5xh9SimTfMzl1aVu6ZBsRWhxQA" -H "ES-Client-Authentication: sharedsecret test-secret" https://localhost:9200/_security/_authenticate diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 5ad1d43c0d4f8..bde73ec5b801f 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -173,6 +173,7 @@ static TransportVersion def(int id) { public static final TransportVersion ML_INFERENCE_OPENAI_ADDED = def(8_542_00_0); public static final TransportVersion SHUTDOWN_MIGRATION_STATUS_INCLUDE_COUNTS = def(8_543_00_0); public static final TransportVersion TRANSFORM_GET_CHECKPOINT_QUERY_AND_CLUSTER_ADDED = def(8_544_00_0); + public static final TransportVersion GRANT_API_KEY_CLIENT_AUTHENTICATION_ADDED = def(8_545_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/core/build.gradle b/x-pack/plugin/core/build.gradle index 8368f6e43dfcf..b4ee3a99a754f 100644 --- a/x-pack/plugin/core/build.gradle +++ b/x-pack/plugin/core/build.gradle @@ -48,6 +48,7 @@ dependencies { // security deps api 'com.unboundid:unboundid-ldapsdk:6.0.3' + api "com.nimbusds:nimbus-jose-jwt:9.23" implementation project(":x-pack:plugin:core:template-resources") @@ -131,7 +132,27 @@ tasks.named("thirdPartyAudit").configure { //commons-logging provided dependencies 'javax.servlet.ServletContextEvent', 'javax.servlet.ServletContextListener', - 'javax.jms.Message' + 'javax.jms.Message', + // Optional dependency of nimbus-jose-jwt for handling Ed25519 signatures and ECDH with X25519 (RFC 8037) + 'com.google.crypto.tink.subtle.Ed25519Sign', + 'com.google.crypto.tink.subtle.Ed25519Sign$KeyPair', + 'com.google.crypto.tink.subtle.Ed25519Verify', + 'com.google.crypto.tink.subtle.X25519', + 'com.google.crypto.tink.subtle.XChaCha20Poly1305', + // optional dependencies for nimbus-jose-jwt + 'org.bouncycastle.asn1.pkcs.PrivateKeyInfo', + 'org.bouncycastle.asn1.x509.AlgorithmIdentifier', + 'org.bouncycastle.asn1.x509.SubjectPublicKeyInfo', + 'org.bouncycastle.cert.X509CertificateHolder', + 'org.bouncycastle.cert.jcajce.JcaX509CertificateHolder', + 'org.bouncycastle.crypto.InvalidCipherTextException', + 'org.bouncycastle.crypto.engines.AESEngine', + 'org.bouncycastle.crypto.modes.GCMBlockCipher', + 'org.bouncycastle.jcajce.provider.BouncyCastleFipsProvider', + 'org.bouncycastle.jce.provider.BouncyCastleProvider', + 'org.bouncycastle.openssl.PEMKeyPair', + 'org.bouncycastle.openssl.PEMParser', + 'org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter' ) } diff --git a/x-pack/plugin/security/licenses/nimbus-jose-jwt-LICENSE.txt b/x-pack/plugin/core/licenses/nimbus-jose-jwt-LICENSE.txt similarity index 100% rename from x-pack/plugin/security/licenses/nimbus-jose-jwt-LICENSE.txt rename to x-pack/plugin/core/licenses/nimbus-jose-jwt-LICENSE.txt diff --git a/x-pack/plugin/security/licenses/nimbus-jose-jwt-NOTICE.txt b/x-pack/plugin/core/licenses/nimbus-jose-jwt-NOTICE.txt similarity index 100% rename from x-pack/plugin/security/licenses/nimbus-jose-jwt-NOTICE.txt rename to x-pack/plugin/core/licenses/nimbus-jose-jwt-NOTICE.txt diff --git a/x-pack/plugin/core/src/main/java/module-info.java b/x-pack/plugin/core/src/main/java/module-info.java index c4c978f656d21..f77b2eba19e6d 100644 --- a/x-pack/plugin/core/src/main/java/module-info.java +++ b/x-pack/plugin/core/src/main/java/module-info.java @@ -22,6 +22,7 @@ requires unboundid.ldapsdk; requires org.elasticsearch.tdigest; requires org.elasticsearch.xcore.templates; + requires com.nimbusds.jose.jwt; exports org.elasticsearch.index.engine.frozen; exports org.elasticsearch.license; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/Grant.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/Grant.java index f24ddbd86c937..41f1f50b6f7f0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/Grant.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/Grant.java @@ -7,13 +7,17 @@ package org.elasticsearch.xpack.core.security.action; +import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.security.authc.AuthenticationToken; +import org.elasticsearch.xpack.core.security.authc.jwt.JwtAuthenticationToken; +import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; import org.elasticsearch.xpack.core.security.authc.support.BearerToken; import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken; @@ -33,6 +37,24 @@ public class Grant implements Writeable { private SecureString password; private SecureString accessToken; private String runAsUsername; + private ClientAuthentication clientAuthentication; + + public record ClientAuthentication(String scheme, SecureString value) implements Writeable { + + public ClientAuthentication(SecureString value) { + this(JwtRealmSettings.HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME, value); + } + + ClientAuthentication(StreamInput in) throws IOException { + this(in.readString(), in.readSecureString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(scheme); + out.writeSecureString(value); + } + } public Grant() {} @@ -46,6 +68,11 @@ public Grant(StreamInput in) throws IOException { } else { this.runAsUsername = null; } + if (in.getTransportVersion().onOrAfter(TransportVersions.GRANT_API_KEY_CLIENT_AUTHENTICATION_ADDED)) { + this.clientAuthentication = in.readOptionalWriteable(ClientAuthentication::new); + } else { + this.clientAuthentication = null; + } } public void writeTo(StreamOutput out) throws IOException { @@ -56,6 +83,9 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_4_0)) { out.writeOptionalString(runAsUsername); } + if (out.getTransportVersion().onOrAfter(TransportVersions.GRANT_API_KEY_CLIENT_AUTHENTICATION_ADDED)) { + out.writeOptionalWriteable(clientAuthentication); + } } public String getType() { @@ -78,6 +108,10 @@ public String getRunAsUsername() { return runAsUsername; } + public ClientAuthentication getClientAuthentication() { + return clientAuthentication; + } + public void setType(String type) { this.type = type; } @@ -98,12 +132,31 @@ public void setRunAsUsername(String runAsUsername) { this.runAsUsername = runAsUsername; } + public void setClientAuthentication(ClientAuthentication clientAuthentication) { + this.clientAuthentication = clientAuthentication; + } + public AuthenticationToken getAuthenticationToken() { assert validate(null) == null : "grant is invalid"; return switch (type) { case PASSWORD_GRANT_TYPE -> new UsernamePasswordToken(username, password); - case ACCESS_TOKEN_GRANT_TYPE -> new BearerToken(accessToken); - default -> null; + case ACCESS_TOKEN_GRANT_TYPE -> { + SecureString clientAuthentication = this.clientAuthentication != null ? this.clientAuthentication.value() : null; + AuthenticationToken token = JwtAuthenticationToken.tryParseJwt(accessToken, clientAuthentication); + if (token != null) { + yield token; + } + if (clientAuthentication != null) { + clientAuthentication.close(); + throw new ElasticsearchSecurityException( + "[client_authentication] not supported with the supplied access_token type", + RestStatus.BAD_REQUEST + ); + } + // here we effectively assume it's an ES access token (from the {@code TokenService}) + yield new BearerToken(accessToken); + } + default -> throw new ElasticsearchSecurityException("the grant type [{}] is not supported", type); }; } @@ -114,10 +167,20 @@ public ActionRequestValidationException validate(ActionRequestValidationExceptio validationException = validateRequiredField("username", username, validationException); validationException = validateRequiredField("password", password, validationException); validationException = validateUnsupportedField("access_token", accessToken, validationException); + if (clientAuthentication != null) { + return addValidationError("[client_authentication] is not supported for grant_type [" + type + "]", validationException); + } } else if (type.equals(ACCESS_TOKEN_GRANT_TYPE)) { validationException = validateRequiredField("access_token", accessToken, validationException); validationException = validateUnsupportedField("username", username, validationException); validationException = validateUnsupportedField("password", password, validationException); + if (clientAuthentication != null + && JwtRealmSettings.HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME.equals(clientAuthentication.scheme.trim()) == false) { + return addValidationError( + "[client_authentication.scheme] must be set to [" + JwtRealmSettings.HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME + "]", + validationException + ); + } } else { validationException = addValidationError("grant_type [" + type + "] is not supported", validationException); } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticationToken.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtAuthenticationToken.java similarity index 91% rename from x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticationToken.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtAuthenticationToken.java index 9ca0ddb42e663..698938ee2f78f 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticationToken.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtAuthenticationToken.java @@ -4,7 +4,7 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -package org.elasticsearch.xpack.security.authc.jwt; +package org.elasticsearch.xpack.core.security.authc.jwt; import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.SignedJWT; @@ -29,11 +29,19 @@ public class JwtAuthenticationToken implements AuthenticationToken { @Nullable private final SecureString clientAuthenticationSharedSecret; + public static JwtAuthenticationToken tryParseJwt(SecureString userCredentials, @Nullable SecureString clientCredentials) { + SignedJWT signedJWT = JwtUtil.parseSignedJWT(userCredentials); + if (signedJWT == null) { + return null; + } + return new JwtAuthenticationToken(signedJWT, JwtUtil.sha256(userCredentials), clientCredentials); + } + /** * Store a mandatory JWT and optional Shared Secret. * @param signedJWT The JWT parsed from the end-user credentials * @param userCredentialsHash The hash of the end-user credentials is used to compute the key for user cache at the realm level. - * See also {@link JwtRealm#authenticate}. + * See also {@code JwtRealm#authenticate}. * @param clientAuthenticationSharedSecret URL-safe Shared Secret for Client authentication. Required by some JWT realms. */ public JwtAuthenticationToken( diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtRealmSettings.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtRealmSettings.java index 9a4fdae51e81b..1903dd5146f69 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtRealmSettings.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtRealmSettings.java @@ -33,6 +33,8 @@ */ public class JwtRealmSettings { + public static final String HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME = "SharedSecret"; + private JwtRealmSettings() {} public static final String TYPE = "jwt"; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtil.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtUtil.java similarity index 99% rename from x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtil.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtUtil.java index 928ecd7fa265d..d70b76f8bc574 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtil.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtUtil.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.security.authc.jwt; +package org.elasticsearch.xpack.core.security.authc.jwt; import com.nimbusds.jose.JWSObject; import com.nimbusds.jose.jwk.JWK; @@ -47,7 +47,6 @@ import org.elasticsearch.env.Environment; import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; import org.elasticsearch.xpack.core.ssl.SSLService; import java.io.InputStream; diff --git a/x-pack/plugin/security/qa/jwt-realm/src/javaRestTest/resources/jwk/README.md b/x-pack/plugin/security/qa/jwt-realm/src/javaRestTest/resources/jwk/README.md index a36129e2a13ed..daf62f92bb0fa 100644 --- a/x-pack/plugin/security/qa/jwt-realm/src/javaRestTest/resources/jwk/README.md +++ b/x-pack/plugin/security/qa/jwt-realm/src/javaRestTest/resources/jwk/README.md @@ -6,7 +6,7 @@ These files are created by running the tests in `JwtRealmGenerateTests`. Those tests generate the yaml settings, the keystore settings and the JWK Sets -for each sample reaml +for each sample realm. Copy the output from the test output into the applicable file (you may wish to run it through `jq` first in order to make it more readable). diff --git a/x-pack/plugin/security/qa/smoke-test-all-realms/src/javaRestTest/java/org/elasticsearch/xpack/security/authc/JwtRealmAuthIT.java b/x-pack/plugin/security/qa/smoke-test-all-realms/src/javaRestTest/java/org/elasticsearch/xpack/security/authc/JwtRealmAuthIT.java index fb761079dc8e6..0c0218c51bacc 100644 --- a/x-pack/plugin/security/qa/smoke-test-all-realms/src/javaRestTest/java/org/elasticsearch/xpack/security/authc/JwtRealmAuthIT.java +++ b/x-pack/plugin/security/qa/smoke-test-all-realms/src/javaRestTest/java/org/elasticsearch/xpack/security/authc/JwtRealmAuthIT.java @@ -11,6 +11,7 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.client.RequestOptions; import org.elasticsearch.xpack.core.security.authc.Authentication; +import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; import org.elasticsearch.xpack.security.authc.jwt.JwtRealm; import java.io.IOException; @@ -34,7 +35,7 @@ public void testAuthenticationUsingJwtRealm() throws IOException { final RequestOptions.Builder options = RequestOptions.DEFAULT.toBuilder() .addHeader( JwtRealm.HEADER_CLIENT_AUTHENTICATION, - JwtRealm.HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME + " " + HEADER_CLIENT_SECRET + JwtRealmSettings.HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME + " " + HEADER_CLIENT_SECRET ) .addHeader(JwtRealm.HEADER_END_USER_AUTHENTICATION, JwtRealm.HEADER_END_USER_AUTHENTICATION_SCHEME + " " + HEADER_JWT); diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/apikey/ApiKeySingleNodeTests.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/apikey/ApiKeySingleNodeTests.java index dcc98a8e6df7d..638eb129e89a5 100644 --- a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/apikey/ApiKeySingleNodeTests.java +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/apikey/ApiKeySingleNodeTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; import org.elasticsearch.core.Tuple; import org.elasticsearch.index.query.QueryBuilders; @@ -64,8 +65,10 @@ import org.elasticsearch.xpack.core.security.action.token.CreateTokenResponse; import org.elasticsearch.xpack.core.security.action.user.AuthenticateAction; import org.elasticsearch.xpack.core.security.action.user.AuthenticateRequest; +import org.elasticsearch.xpack.core.security.action.user.AuthenticateResponse; import org.elasticsearch.xpack.core.security.action.user.PutUserAction; import org.elasticsearch.xpack.core.security.action.user.PutUserRequest; +import org.elasticsearch.xpack.core.security.authc.Authentication; import org.elasticsearch.xpack.core.security.authc.AuthenticationServiceField; import org.elasticsearch.xpack.core.security.authc.support.Hasher; import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken; @@ -99,6 +102,7 @@ import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.lessThanOrEqualTo; +import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; public class ApiKeySingleNodeTests extends SecuritySingleNodeTestCase { @@ -108,6 +112,19 @@ protected Settings nodeSettings() { Settings.Builder builder = Settings.builder().put(super.nodeSettings()); builder.put(XPackSettings.API_KEY_SERVICE_ENABLED_SETTING.getKey(), true); builder.put(XPackSettings.TOKEN_SERVICE_ENABLED_SETTING.getKey(), true); + builder.put("xpack.security.authc.realms.jwt.jwt1.order", 2) + .put("xpack.security.authc.realms.jwt.jwt1.allowed_audiences", "https://audience.example.com/") + .put("xpack.security.authc.realms.jwt.jwt1.allowed_issuer", "https://issuer.example.com/") + .put( + "xpack.security.authc.realms.jwt.jwt1.pkc_jwkset_path", + getDataPath("/org/elasticsearch/xpack/security/authc/apikey/rsa-public-jwkset.json") + ) + .put("xpack.security.authc.realms.jwt.jwt1.client_authentication.type", "NONE") + .put("xpack.security.authc.realms.jwt.jwt1.claims.name", "name") + .put("xpack.security.authc.realms.jwt.jwt1.claims.dn", "dn") + .put("xpack.security.authc.realms.jwt.jwt1.claims.groups", "roles") + .put("xpack.security.authc.realms.jwt.jwt1.claims.principal", "sub") + .put("xpack.security.authc.realms.jwt.jwt1.claims.mail", "mail"); return builder.build(); } @@ -402,6 +419,106 @@ public void testGrantApiKeyForUserWithRunAs() throws IOException { ); } + public void testGrantAPIKeyFromTokens() throws IOException { + final String jwtToken; + try (var in = getDataInputStream("/org/elasticsearch/xpack/security/authc/apikey/serialized-signed-RS256-jwt.txt")) { + jwtToken = new String(in.readAllBytes(), StandardCharsets.UTF_8); + } + getSecurityClient().putRole(new RoleDescriptor("user1_role", new String[] { "manage_token" }, null, null)); + String role_mapping_rules = """ + { + "enabled": true, + "roles": "user1_role", + "rules": { + "all": [ + { + "field": { + "realm.name": "jwt1" + } + }, + { + "field": { + "username": "user1" + } + } + ] + } + } + """; + getSecurityClient().putRoleMapping( + "user1_role_mapping", + XContentHelper.convertToMap(XContentType.JSON.xContent(), role_mapping_rules, true) + ); + // grant API Key for regular ES access tokens (itself created from JWT credentials) + { + // get ES access token from JWT + final TestSecurityClient.OAuth2Token oAuth2Token = getSecurityClient( + RequestOptions.DEFAULT.toBuilder().addHeader("Authorization", "Bearer " + jwtToken).build() + ).createTokenWithClientCredentialsGrant(); + String apiKeyName = randomAlphaOfLength(8); + GrantApiKeyRequest grantApiKeyRequest = new GrantApiKeyRequest(); + grantApiKeyRequest.getGrant().setType("access_token"); + grantApiKeyRequest.getGrant().setAccessToken(new SecureString(oAuth2Token.accessToken().toCharArray())); + grantApiKeyRequest.setRefreshPolicy(randomFrom(IMMEDIATE, WAIT_UNTIL)); + grantApiKeyRequest.getApiKeyRequest().setName(apiKeyName); + CreateApiKeyResponse createApiKeyResponse = client().execute(GrantApiKeyAction.INSTANCE, grantApiKeyRequest).actionGet(); + // use the API Key to check it's legit + assertThat(createApiKeyResponse.getName(), is(apiKeyName)); + assertThat(createApiKeyResponse.getId(), notNullValue()); + assertThat(createApiKeyResponse.getKey(), notNullValue()); + final String apiKeyId = createApiKeyResponse.getId(); + final String base64ApiKeyKeyValue = Base64.getEncoder() + .encodeToString((apiKeyId + ":" + createApiKeyResponse.getKey()).getBytes(StandardCharsets.UTF_8)); + AuthenticateResponse authenticateResponse = client().filterWithHeader( + Collections.singletonMap("Authorization", "ApiKey " + base64ApiKeyKeyValue) + ).execute(AuthenticateAction.INSTANCE, AuthenticateRequest.INSTANCE).actionGet(); + assertThat(authenticateResponse.authentication().getEffectiveSubject().getUser().principal(), is("user1")); + assertThat(authenticateResponse.authentication().getAuthenticationType(), is(Authentication.AuthenticationType.API_KEY)); + // BUT client_authentication is not supported with the ES access token + { + GrantApiKeyRequest wrongGrantApiKeyRequest = new GrantApiKeyRequest(); + wrongGrantApiKeyRequest.setRefreshPolicy(randomFrom(IMMEDIATE, WAIT_UNTIL, NONE)); + wrongGrantApiKeyRequest.getApiKeyRequest().setName(randomAlphaOfLength(8)); + wrongGrantApiKeyRequest.getGrant().setType("access_token"); + wrongGrantApiKeyRequest.getGrant().setAccessToken(new SecureString(oAuth2Token.accessToken().toCharArray())); + wrongGrantApiKeyRequest.getGrant() + .setClientAuthentication(new Grant.ClientAuthentication(new SecureString("whatever".toCharArray()))); + ElasticsearchSecurityException e = expectThrows( + ElasticsearchSecurityException.class, + () -> client().execute(GrantApiKeyAction.INSTANCE, wrongGrantApiKeyRequest).actionGet() + ); + assertThat(e.getMessage(), containsString("[client_authentication] not supported with the supplied access_token type")); + } + } + // grant API Key for JWT token + { + String apiKeyName = randomAlphaOfLength(8); + GrantApiKeyRequest grantApiKeyRequest = new GrantApiKeyRequest(); + grantApiKeyRequest.getGrant().setType("access_token"); + grantApiKeyRequest.getGrant().setAccessToken(new SecureString(jwtToken.toCharArray())); + grantApiKeyRequest.setRefreshPolicy(randomFrom(IMMEDIATE, WAIT_UNTIL)); + grantApiKeyRequest.getApiKeyRequest().setName(apiKeyName); + // client authentication is ignored for JWTs that don't require it + if (randomBoolean()) { + grantApiKeyRequest.getGrant() + .setClientAuthentication(new Grant.ClientAuthentication(new SecureString("whatever".toCharArray()))); + } + CreateApiKeyResponse createApiKeyResponse = client().execute(GrantApiKeyAction.INSTANCE, grantApiKeyRequest).actionGet(); + // use the API Key to check it's legit + assertThat(createApiKeyResponse.getName(), is(apiKeyName)); + assertThat(createApiKeyResponse.getId(), notNullValue()); + assertThat(createApiKeyResponse.getKey(), notNullValue()); + final String apiKeyId = createApiKeyResponse.getId(); + final String base64ApiKeyKeyValue = Base64.getEncoder() + .encodeToString((apiKeyId + ":" + createApiKeyResponse.getKey()).getBytes(StandardCharsets.UTF_8)); + AuthenticateResponse authenticateResponse = client().filterWithHeader( + Collections.singletonMap("Authorization", "ApiKey " + base64ApiKeyKeyValue) + ).execute(AuthenticateAction.INSTANCE, AuthenticateRequest.INSTANCE).actionGet(); + assertThat(authenticateResponse.authentication().getEffectiveSubject().getUser().principal(), is("user1")); + assertThat(authenticateResponse.authentication().getAuthenticationType(), is(Authentication.AuthenticationType.API_KEY)); + } + } + public void testInvalidateApiKeyWillRecordTimestamp() { CreateApiKeyRequest createApiKeyRequest = new CreateApiKeyRequest( randomAlphaOfLengthBetween(3, 8), diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmSingleNodeTests.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmSingleNodeTests.java index d84b93fa6f638..c9b43afd4322d 100644 --- a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmSingleNodeTests.java +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmSingleNodeTests.java @@ -16,12 +16,14 @@ import com.nimbusds.jwt.SignedJWT; import org.apache.http.HttpEntity; +import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.client.Request; import org.elasticsearch.client.RequestOptions; import org.elasticsearch.client.Response; import org.elasticsearch.client.ResponseException; import org.elasticsearch.client.RestClient; import org.elasticsearch.common.settings.MockSecureSettings; +import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.xcontent.XContentHelper; @@ -34,7 +36,17 @@ import org.elasticsearch.test.junit.annotations.TestLogging; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.XPackSettings; +import org.elasticsearch.xpack.core.security.action.Grant; +import org.elasticsearch.xpack.core.security.action.apikey.CreateApiKeyResponse; +import org.elasticsearch.xpack.core.security.action.apikey.GrantApiKeyAction; +import org.elasticsearch.xpack.core.security.action.apikey.GrantApiKeyRequest; +import org.elasticsearch.xpack.core.security.action.user.AuthenticateAction; +import org.elasticsearch.xpack.core.security.action.user.AuthenticateRequest; +import org.elasticsearch.xpack.core.security.action.user.AuthenticateResponse; +import org.elasticsearch.xpack.core.security.authc.Authentication; import org.elasticsearch.xpack.core.security.authc.Realm; +import org.elasticsearch.xpack.core.security.authc.jwt.JwtAuthenticationToken; +import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; import org.elasticsearch.xpack.security.LocalStateSecurity; import org.elasticsearch.xpack.security.Security; import org.elasticsearch.xpack.security.authc.Realms; @@ -44,17 +56,22 @@ import java.text.ParseException; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.Base64; import java.util.Date; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import static org.elasticsearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE; +import static org.elasticsearch.action.support.WriteRequest.RefreshPolicy.WAIT_UNTIL; import static org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings.CLIENT_AUTH_SHARED_SECRET_ROTATION_GRACE_PERIOD; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; public class JwtRealmSingleNodeTests extends SecuritySingleNodeTestCase { @@ -134,6 +151,76 @@ protected boolean addMockHttpTransport() { return false; } + @TestLogging(value = "org.elasticsearch.xpack.security.authc.jwt:DEBUG", reason = "failures can be very difficult to troubleshoot") + public void testGrantApiKeyForJWT() throws Exception { + final JWTClaimsSet.Builder jwtClaims = new JWTClaimsSet.Builder(); + final String subject; + final String sharedSecret; + // id_token or access_token + if (randomBoolean()) { + subject = "me"; + // JWT "id_token" valid for jwt0 + jwtClaims.audience("es-01") + .issuer("my-issuer-01") + .subject(subject) + .claim("groups", "admin") + .issueTime(Date.from(Instant.now())) + .expirationTime(Date.from(Instant.now().plusSeconds(600))) + .build(); + sharedSecret = jwt0SharedSecret; + } else { + subject = "me@example.com"; + // JWT "access_token" valid for jwt2 + jwtClaims.audience("es-03") + .issuer("my-issuer-03") + .subject("user-03") + .claim("groups", "admin") + .claim("email", subject) + .issueTime(Date.from(Instant.now())) + .expirationTime(Date.from(Instant.now().plusSeconds(300))); + sharedSecret = jwt2SharedSecret; + } + { + // JWT is valid but the client authentication is NOT + GrantApiKeyRequest grantApiKeyRequest = getGrantApiKeyForJWT(getSignedJWT(jwtClaims.build()), randomFrom("WRONG", null)); + ElasticsearchSecurityException e = expectThrows( + ElasticsearchSecurityException.class, + () -> client().execute(GrantApiKeyAction.INSTANCE, grantApiKeyRequest).actionGet() + ); + assertThat(e.getMessage(), containsString("unable to authenticate user")); + } + { + // both JWT and client authentication are valid + GrantApiKeyRequest grantApiKeyRequest = getGrantApiKeyForJWT(getSignedJWT(jwtClaims.build()), sharedSecret); + CreateApiKeyResponse createApiKeyResponse = client().execute(GrantApiKeyAction.INSTANCE, grantApiKeyRequest).actionGet(); + assertThat(createApiKeyResponse.getId(), notNullValue()); + assertThat(createApiKeyResponse.getKey(), notNullValue()); + assertThat(createApiKeyResponse.getName(), is(grantApiKeyRequest.getApiKeyRequest().getName())); + final String base64ApiKeyKeyValue = Base64.getEncoder() + .encodeToString((createApiKeyResponse.getId() + ":" + createApiKeyResponse.getKey()).getBytes(StandardCharsets.UTF_8)); + AuthenticateResponse authenticateResponse = client().filterWithHeader(Map.of("Authorization", "ApiKey " + base64ApiKeyKeyValue)) + .execute(AuthenticateAction.INSTANCE, AuthenticateRequest.INSTANCE) + .actionGet(); + assertThat(authenticateResponse.authentication().getEffectiveSubject().getUser().principal(), is(subject)); + assertThat(authenticateResponse.authentication().getAuthenticationType(), is(Authentication.AuthenticationType.API_KEY)); + } + { + // client authentication is valid but the JWT is not + final SignedJWT wrongJWT; + if (randomBoolean()) { + wrongJWT = getSignedJWT(jwtClaims.build(), ("wrong key that's longer than 256 bits").getBytes(StandardCharsets.UTF_8)); + } else { + wrongJWT = getSignedJWT(jwtClaims.audience("wrong audience claim value").build()); + } + GrantApiKeyRequest grantApiKeyRequest = getGrantApiKeyForJWT(wrongJWT, sharedSecret); + ElasticsearchSecurityException e = expectThrows( + ElasticsearchSecurityException.class, + () -> client().execute(GrantApiKeyAction.INSTANCE, grantApiKeyRequest).actionGet() + ); + assertThat(e.getMessage(), containsString("unable to authenticate user")); + } + } + @SuppressWarnings("unchecked") public void testInvalidJWTDoesNotFallbackToAnonymousAccess() throws Exception { // anonymous access works when no valid Bearer @@ -327,73 +414,101 @@ public void testClientSecretRotation() throws Exception { 200, client.performRequest(getRequest(getSignedJWT(jwt2Claims.build()), jwt2SharedSecret)).getStatusLine().getStatusCode() ); - // update the secret in the secure settings - final MockSecureSettings newSecureSettings = new MockSecureSettings(); - newSecureSettings.setString( - "xpack.security.authc.realms.jwt." + realm0.name() + ".client_authentication.shared_secret", - "realm0updatedSecret" - ); - newSecureSettings.setString( - "xpack.security.authc.realms.jwt." + realm1.name() + ".client_authentication.shared_secret", - "realm1updatedSecret" - ); - newSecureSettings.setString( - "xpack.security.authc.realms.jwt." + realm2.name() + ".client_authentication.shared_secret", - "realm2updatedSecret" - ); - // reload settings final PluginsService plugins = getInstanceFromNode(PluginsService.class); final LocalStateSecurity localStateSecurity = plugins.filterPlugins(LocalStateSecurity.class).findFirst().get(); - for (Plugin p : localStateSecurity.plugins()) { - if (p instanceof Security securityPlugin) { - Settings.Builder newSettingsBuilder = Settings.builder().setSecureSettings(newSecureSettings); - securityPlugin.reload(newSettingsBuilder.build()); + // update the secret in the secure settings + try { + final MockSecureSettings newSecureSettings = new MockSecureSettings(); + newSecureSettings.setString( + "xpack.security.authc.realms.jwt." + realm0.name() + ".client_authentication.shared_secret", + "realm0updatedSecret" + ); + newSecureSettings.setString( + "xpack.security.authc.realms.jwt." + realm1.name() + ".client_authentication.shared_secret", + "realm1updatedSecret" + ); + newSecureSettings.setString( + "xpack.security.authc.realms.jwt." + realm2.name() + ".client_authentication.shared_secret", + "realm2updatedSecret" + ); + // reload settings + for (Plugin p : localStateSecurity.plugins()) { + if (p instanceof Security securityPlugin) { + Settings.Builder newSettingsBuilder = Settings.builder().setSecureSettings(newSecureSettings); + securityPlugin.reload(newSettingsBuilder.build()); + } + } + // ensure the old value still works for realm 0 (default grace period) + assertEquals( + 200, + client.performRequest(getRequest(getSignedJWT(jwt0Claims.build()), jwt0SharedSecret)).getStatusLine().getStatusCode() + ); + assertEquals( + 200, + client.performRequest(getRequest(getSignedJWT(jwt0Claims.build()), "realm0updatedSecret")).getStatusLine().getStatusCode() + ); + // ensure the old value still works for realm 1 (explicit grace period) + assertEquals( + 200, + client.performRequest(getRequest(getSignedJWT(jwt1Claims.build()), jwt1SharedSecret)).getStatusLine().getStatusCode() + ); + assertEquals( + 200, + client.performRequest(getRequest(getSignedJWT(jwt1Claims.build()), "realm1updatedSecret")).getStatusLine().getStatusCode() + ); + // ensure the old value does not work for realm 2 (no grace period) + ResponseException exception = expectThrows( + ResponseException.class, + () -> client.performRequest(getRequest(getSignedJWT(jwt2Claims.build()), jwt2SharedSecret)).getStatusLine().getStatusCode() + ); + assertEquals(401, exception.getResponse().getStatusLine().getStatusCode()); + assertEquals( + 200, + client.performRequest(getRequest(getSignedJWT(jwt2Claims.build()), "realm2updatedSecret")).getStatusLine().getStatusCode() + ); + } finally { + // update them back to their original values + final MockSecureSettings newSecureSettings = new MockSecureSettings(); + newSecureSettings.setString( + "xpack.security.authc.realms.jwt." + realm0.name() + ".client_authentication.shared_secret", + jwt0SharedSecret + ); + newSecureSettings.setString( + "xpack.security.authc.realms.jwt." + realm1.name() + ".client_authentication.shared_secret", + jwt1SharedSecret + ); + newSecureSettings.setString( + "xpack.security.authc.realms.jwt." + realm2.name() + ".client_authentication.shared_secret", + jwt2SharedSecret + ); + // reload settings + for (Plugin p : localStateSecurity.plugins()) { + if (p instanceof Security securityPlugin) { + Settings.Builder newSettingsBuilder = Settings.builder().setSecureSettings(newSecureSettings); + securityPlugin.reload(newSettingsBuilder.build()); + } } } - // ensure the old value still works for realm 0 (default grace period) - assertEquals( - 200, - client.performRequest(getRequest(getSignedJWT(jwt0Claims.build()), jwt0SharedSecret)).getStatusLine().getStatusCode() - ); - assertEquals( - 200, - client.performRequest(getRequest(getSignedJWT(jwt0Claims.build()), "realm0updatedSecret")).getStatusLine().getStatusCode() - ); - // ensure the old value still works for realm 1 (explicit grace period) - assertEquals( - 200, - client.performRequest(getRequest(getSignedJWT(jwt1Claims.build()), jwt1SharedSecret)).getStatusLine().getStatusCode() - ); - assertEquals( - 200, - client.performRequest(getRequest(getSignedJWT(jwt1Claims.build()), "realm1updatedSecret")).getStatusLine().getStatusCode() - ); - // ensure the old value does not work for realm 2 (no grace period) - ResponseException exception = expectThrows( - ResponseException.class, - () -> client.performRequest(getRequest(getSignedJWT(jwt2Claims.build()), jwt2SharedSecret)).getStatusLine().getStatusCode() - ); - assertEquals(401, exception.getResponse().getStatusLine().getStatusCode()); - assertEquals( - 200, - client.performRequest(getRequest(getSignedJWT(jwt2Claims.build()), "realm2updatedSecret")).getStatusLine().getStatusCode() - ); } - private SignedJWT getSignedJWT(JWTClaimsSet claimsSet) throws Exception { + private SignedJWT getSignedJWT(JWTClaimsSet claimsSet, byte[] hmacKeyBytes) throws Exception { JWSHeader jwtHeader = new JWSHeader.Builder(JWSAlgorithm.HS256).build(); - OctetSequenceKey.Builder jwt0signer = new OctetSequenceKey.Builder(jwtHmacKey.getBytes(StandardCharsets.UTF_8)); + OctetSequenceKey.Builder jwt0signer = new OctetSequenceKey.Builder(hmacKeyBytes); jwt0signer.algorithm(JWSAlgorithm.HS256); SignedJWT jwt = new SignedJWT(jwtHeader, claimsSet); jwt.sign(new MACSigner(jwt0signer.build())); return jwt; } - private Request getRequest(SignedJWT jwt, String shardSecret) { + private SignedJWT getSignedJWT(JWTClaimsSet claimsSet) throws Exception { + return getSignedJWT(claimsSet, jwtHmacKey.getBytes(StandardCharsets.UTF_8)); + } + + private Request getRequest(SignedJWT jwt, String sharedSecret) { Request request = new Request("GET", "/_security/_authenticate"); RequestOptions.Builder options = RequestOptions.DEFAULT.toBuilder(); options.addHeader("Authorization", "Bearer " + jwt.serialize()); - options.addHeader("ES-Client-Authentication", "SharedSecret " + shardSecret); + options.addHeader("ES-Client-Authentication", "SharedSecret " + sharedSecret); request.setOptions(options); return request; } @@ -446,9 +561,22 @@ private ThreadContext prepareThreadContext(SignedJWT signedJWT, String clientSec if (clientSecret != null) { threadContext.putHeader( JwtRealm.HEADER_CLIENT_AUTHENTICATION, - JwtRealm.HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME + " " + clientSecret + JwtRealmSettings.HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME + " " + clientSecret ); } return threadContext; } + + private static GrantApiKeyRequest getGrantApiKeyForJWT(SignedJWT signedJWT, String sharedSecret) { + GrantApiKeyRequest grantApiKeyRequest = new GrantApiKeyRequest(); + grantApiKeyRequest.getGrant().setType("access_token"); + grantApiKeyRequest.getGrant().setAccessToken(new SecureString(signedJWT.serialize().toCharArray())); + if (sharedSecret != null) { + grantApiKeyRequest.getGrant() + .setClientAuthentication(new Grant.ClientAuthentication("SharedSecret", new SecureString(sharedSecret.toCharArray()))); + } + grantApiKeyRequest.setRefreshPolicy(randomFrom(IMMEDIATE, WAIT_UNTIL)); + grantApiKeyRequest.getApiKeyRequest().setName(randomAlphaOfLength(8)); + return grantApiKeyRequest; + } } diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/profile/SecurityDomainIntegTests.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/profile/SecurityDomainIntegTests.java index 0892c6f88873f..8d025972fdeeb 100644 --- a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/profile/SecurityDomainIntegTests.java +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/profile/SecurityDomainIntegTests.java @@ -29,6 +29,7 @@ import org.elasticsearch.xpack.core.security.action.token.CreateTokenRequest; import org.elasticsearch.xpack.core.security.action.token.CreateTokenResponse; import org.elasticsearch.xpack.core.security.action.token.RefreshTokenAction; +import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; import org.elasticsearch.xpack.core.security.authc.support.mapper.expressiondsl.ExpressionParser; import org.elasticsearch.xpack.core.watcher.support.xcontent.XContentSource; import org.elasticsearch.xpack.security.authc.jwt.JwtRealm; @@ -200,7 +201,7 @@ public void testTokenRefreshUnderSameUsernameInDomain() throws IOException { var refreshTokenResponse = client().filterWithHeader( Map.of( JwtRealm.HEADER_CLIENT_AUTHENTICATION, - JwtRealm.HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME + " " + HEADER_SECRET_JWT_REALM_1, + JwtRealmSettings.HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME + " " + HEADER_SECRET_JWT_REALM_1, JwtRealm.HEADER_END_USER_AUTHENTICATION, JwtRealm.HEADER_END_USER_AUTHENTICATION_SCHEME + " " + HEADER_JWT_REALM_1 ) @@ -211,7 +212,7 @@ public void testTokenRefreshUnderSameUsernameInDomain() throws IOException { createTokenResponse = client().filterWithHeader( Map.of( JwtRealm.HEADER_CLIENT_AUTHENTICATION, - JwtRealm.HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME + " " + HEADER_SECRET_JWT_REALM_1, + JwtRealmSettings.HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME + " " + HEADER_SECRET_JWT_REALM_1, JwtRealm.HEADER_END_USER_AUTHENTICATION, JwtRealm.HEADER_END_USER_AUTHENTICATION_SCHEME + " " + HEADER_JWT_REALM_1 ) @@ -292,7 +293,7 @@ public void testTokenRefreshFailsForUsernameOutsideDomain() throws IOException { () -> client().filterWithHeader( Map.of( JwtRealm.HEADER_CLIENT_AUTHENTICATION, - JwtRealm.HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME + " " + HEADER_SECRET_JWT_REALM_2, + JwtRealmSettings.HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME + " " + HEADER_SECRET_JWT_REALM_2, JwtRealm.HEADER_END_USER_AUTHENTICATION, JwtRealm.HEADER_END_USER_AUTHENTICATION_SCHEME + " " + HEADER_JWT_REALM_2 ) diff --git a/x-pack/plugin/security/src/main/java/module-info.java b/x-pack/plugin/security/src/main/java/module-info.java index 75fa7ccf09e0f..316f640b65476 100644 --- a/x-pack/plugin/security/src/main/java/module-info.java +++ b/x-pack/plugin/security/src/main/java/module-info.java @@ -34,7 +34,6 @@ requires org.opensaml.saml; requires org.opensaml.saml.impl; requires org.opensaml.security.impl; - requires org.opensaml.security; requires org.opensaml.xmlsec.impl; requires org.opensaml.xmlsec; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/TransportGrantAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/TransportGrantAction.java index 7b65c01f7691e..881d1340ebc3f 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/TransportGrantAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/TransportGrantAction.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.security.action; -import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionResponse; @@ -56,12 +55,6 @@ public final void doExecute(Task task, Request request, ActionListener try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { final AuthenticationToken authenticationToken = request.getGrant().getAuthenticationToken(); assert authenticationToken != null : "authentication token must not be null"; - if (authenticationToken == null) { - listener.onFailure( - new ElasticsearchSecurityException("the grant type [{}] is not supported", request.getGrant().getType()) - ); - return; - } final String runAsUsername = request.getGrant().getRunAsUsername(); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwkSetLoader.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwkSetLoader.java index 063cc85ea0187..0266fc7488e29 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwkSetLoader.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwkSetLoader.java @@ -22,6 +22,7 @@ import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmSettings; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; +import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import org.elasticsearch.xpack.core.ssl.SSLService; import java.io.IOException; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwkValidateUtil.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwkValidateUtil.java index 89391f91a2731..cc07b7dfa8381 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwkValidateUtil.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwkValidateUtil.java @@ -24,6 +24,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.SettingsException; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; +import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import java.nio.charset.StandardCharsets; import java.security.PublicKey; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticator.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticator.java index e122ecf4eb1ab..9c1deff9ed891 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticator.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticator.java @@ -18,6 +18,7 @@ import org.elasticsearch.core.Releasable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.xpack.core.security.authc.RealmConfig; +import org.elasticsearch.xpack.core.security.authc.jwt.JwtAuthenticationToken; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; import org.elasticsearch.xpack.core.ssl.SSLService; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java index eb2517f8e54e4..d8b0575c54d36 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java @@ -30,7 +30,9 @@ import org.elasticsearch.xpack.core.security.authc.Realm; import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmSettings; +import org.elasticsearch.xpack.core.security.authc.jwt.JwtAuthenticationToken; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; +import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import org.elasticsearch.xpack.core.security.authc.support.CachingRealm; import org.elasticsearch.xpack.core.security.authc.support.UserRoleMapper; import org.elasticsearch.xpack.core.security.support.CacheIteratorHelper; @@ -64,7 +66,6 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable { public static final String HEADER_END_USER_AUTHENTICATION = "Authorization"; public static final String HEADER_CLIENT_AUTHENTICATION = "ES-Client-Authentication"; public static final String HEADER_END_USER_AUTHENTICATION_SCHEME = "Bearer"; - public static final String HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME = "SharedSecret"; private final Cache jwtCache; private final CacheIteratorHelper jwtCacheHelper; @@ -193,7 +194,7 @@ public AuthenticationToken token(final ThreadContext threadContext) { final SecureString clientCredentials = JwtUtil.getHeaderValue( threadContext, JwtRealm.HEADER_CLIENT_AUTHENTICATION, - JwtRealm.HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME, + JwtRealmSettings.HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME, true ); return new JwtAuthenticationToken(signedJWT, JwtUtil.sha256(userCredentials), clientCredentials); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtSignatureValidator.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtSignatureValidator.java index b1ee1b77998ec..e183ee7d73ac2 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtSignatureValidator.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtSignatureValidator.java @@ -35,13 +35,14 @@ import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmSettings; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; +import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import org.elasticsearch.xpack.core.ssl.SSLService; import java.util.Arrays; import java.util.List; import java.util.stream.Stream; -import static org.elasticsearch.xpack.security.authc.jwt.JwtUtil.toStringRedactSignature; +import static org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil.toStringRedactSignature; public interface JwtSignatureValidator extends Releasable { diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/oidc/OpenIdConnectAuthenticator.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/oidc/OpenIdConnectAuthenticator.java index 0f34850b861b7..e637bda19d886 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/oidc/OpenIdConnectAuthenticator.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/oidc/OpenIdConnectAuthenticator.java @@ -91,9 +91,9 @@ import org.elasticsearch.watcher.ResourceWatcherService; import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmSettings; +import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import org.elasticsearch.xpack.core.security.authc.oidc.OpenIdConnectRealmSettings; import org.elasticsearch.xpack.core.ssl.SSLService; -import org.elasticsearch.xpack.security.authc.jwt.JwtUtil; import java.io.IOException; import java.net.URI; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/apikey/RestGrantApiKeyAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/apikey/RestGrantApiKeyAction.java index 46d2fa4605f9d..d07c5529e3ca1 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/apikey/RestGrantApiKeyAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/apikey/RestGrantApiKeyAction.java @@ -20,9 +20,11 @@ import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.security.action.Grant; import org.elasticsearch.xpack.core.security.action.apikey.CreateApiKeyRequestBuilder; import org.elasticsearch.xpack.core.security.action.apikey.CreateApiKeyResponse; import org.elasticsearch.xpack.core.security.action.apikey.GrantApiKeyAction; @@ -44,6 +46,18 @@ @ServerlessScope(Scope.INTERNAL) public final class RestGrantApiKeyAction extends ApiKeyBaseRestHandler implements RestRequestFilter { + private static final ConstructingObjectParser CLIENT_AUTHENTICATION_PARSER = + new ConstructingObjectParser<>("client_authentication", a -> new Grant.ClientAuthentication((String) a[0], (SecureString) a[1])); + static { + CLIENT_AUTHENTICATION_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("scheme")); + CLIENT_AUTHENTICATION_PARSER.declareField( + ConstructingObjectParser.constructorArg(), + RestGrantApiKeyAction::getSecureString, + new ParseField("value"), + ObjectParser.ValueType.STRING + ); + } + static final ObjectParser PARSER = new ObjectParser<>("grant_api_key_request", GrantApiKeyRequest::new); static { PARSER.declareString((req, str) -> req.getGrant().setType(str), new ParseField("grant_type")); @@ -61,6 +75,11 @@ public final class RestGrantApiKeyAction extends ApiKeyBaseRestHandler implement ObjectParser.ValueType.STRING ); PARSER.declareString((req, str) -> req.getGrant().setRunAsUsername(str), new ParseField("run_as")); + PARSER.declareObject( + (req, clientAuthentication) -> req.getGrant().setClientAuthentication(clientAuthentication), + CLIENT_AUTHENTICATION_PARSER, + new ParseField("client_authentication") + ); PARSER.declareObject( (req, api) -> req.setApiKeyRequest(api), (parser, ignore) -> CreateApiKeyRequestBuilder.parse(parser), @@ -88,11 +107,15 @@ public String getName() { return "xpack_security_grant_api_key"; } + public static GrantApiKeyRequest fromXContent(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + @Override protected RestChannelConsumer innerPrepareRequest(final RestRequest request, final NodeClient client) throws IOException { String refresh = request.param("refresh"); try (XContentParser parser = request.contentParser()) { - final GrantApiKeyRequest grantRequest = PARSER.parse(parser, null); + final GrantApiKeyRequest grantRequest = fromXContent(parser); if (refresh != null) { grantRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.parse(refresh)); } else { @@ -115,7 +138,7 @@ protected RestChannelConsumer innerPrepareRequest(final RestRequest request, fin } } - private static final Set FILTERED_FIELDS = Set.of("password", "access_token"); + private static final Set FILTERED_FIELDS = Set.of("password", "access_token", "client_authentication.value"); @Override public Set getFilteredFields() { diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/apikey/TransportGrantApiKeyActionTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/apikey/TransportGrantApiKeyActionTests.java index c8c996f37ebfc..5b077c615f9eb 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/apikey/TransportGrantApiKeyActionTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/apikey/TransportGrantApiKeyActionTests.java @@ -10,8 +10,10 @@ import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.rest.RestStatus; @@ -19,6 +21,7 @@ import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.security.action.Grant; import org.elasticsearch.xpack.core.security.action.apikey.CreateApiKeyRequest; import org.elasticsearch.xpack.core.security.action.apikey.CreateApiKeyResponse; import org.elasticsearch.xpack.core.security.action.apikey.GrantApiKeyAction; @@ -70,7 +73,6 @@ public class TransportGrantApiKeyActionTests extends ESTestCase { private ApiKeyUserRoleDescriptorResolver resolver; private AuthenticationService authenticationService; private ThreadPool threadPool; - private TransportService transportService; private AuthorizationService authorizationService; @Before @@ -87,7 +89,7 @@ public void setupMocks() throws Exception { action = new TransportGrantApiKeyAction( transportService, - mock(ActionFilters.class), + new ActionFilters(Set.of()), threadContext, authenticationService, authorizationService, @@ -136,12 +138,70 @@ public void testGrantApiKeyWithUsernamePassword() { setupApiKeyServiceWithRoleResolution(authentication, request, response); final PlainActionFuture future = new PlainActionFuture<>(); - action.doExecute(null, request, future); + action.execute(null, request, future); assertThat(future.actionGet(), sameInstance(response)); verify(authorizationService, never()).authorize(any(), any(), any(), anyActionListener()); } + public void testClientAuthenticationForNonJWTFails() { + final GrantApiKeyRequest request = mockRequest(); + request.getGrant().setType("access_token"); + request.getGrant().setAccessToken(new SecureString("obviously a non JWT token".toCharArray())); + // only JWT tokens support client authentication + request.getGrant().setClientAuthentication(new Grant.ClientAuthentication(new SecureString("whatever".toCharArray()))); + + final PlainActionFuture future = new PlainActionFuture<>(); + action.execute(null, request, future); + + final ElasticsearchStatusException exception = expectThrows(ElasticsearchStatusException.class, future::actionGet); + assertThat(exception, throwableWithMessage("[client_authentication] not supported with the supplied access_token type")); + + verifyNoMoreInteractions(authenticationService); + verifyNoMoreInteractions(authorizationService); + verifyNoMoreInteractions(apiKeyService); + verifyNoMoreInteractions(resolver); + } + + public void testClientAuthenticationWithUsernamePasswordFails() { + final GrantApiKeyRequest request = mockRequest(); + request.getGrant().setType("password"); + request.getGrant().setUsername(randomAlphaOfLengthBetween(4, 12)); + request.getGrant().setPassword(new SecureString(randomAlphaOfLengthBetween(8, 24).toCharArray())); + // username & password does not support client authentication + request.getGrant().setClientAuthentication(new Grant.ClientAuthentication(new SecureString("whatever".toCharArray()))); + + final PlainActionFuture future = new PlainActionFuture<>(); + action.execute(null, request, future); + + final ActionRequestValidationException exception = expectThrows(ActionRequestValidationException.class, future::actionGet); + assertThat(exception.getMessage(), containsString("[client_authentication] is not supported for grant_type [password]")); + + verifyNoMoreInteractions(authenticationService); + verifyNoMoreInteractions(authorizationService); + verifyNoMoreInteractions(apiKeyService); + verifyNoMoreInteractions(resolver); + } + + public void testUnsupportedClientAuthenticationScheme() { + final GrantApiKeyRequest request = mockRequest(); + request.getGrant().setType("access_token"); + request.getGrant().setAccessToken(new SecureString("some token".toCharArray())); + request.getGrant() + .setClientAuthentication(new Grant.ClientAuthentication("wrong scheme", new SecureString("whatever".toCharArray()))); + + final PlainActionFuture future = new PlainActionFuture<>(); + action.execute(null, request, future); + + final ActionRequestValidationException exception = expectThrows(ActionRequestValidationException.class, future::actionGet); + assertThat(exception.getMessage(), containsString("[client_authentication.scheme] must be set to [SharedSecret]")); + + verifyNoMoreInteractions(authenticationService); + verifyNoMoreInteractions(authorizationService); + verifyNoMoreInteractions(apiKeyService); + verifyNoMoreInteractions(resolver); + } + public void testGrantApiKeyWithAccessToken() { final String username = randomAlphaOfLengthBetween(4, 12); final Authentication authentication = buildAuthentication(username); @@ -173,7 +233,7 @@ public void testGrantApiKeyWithAccessToken() { setupApiKeyServiceWithRoleResolution(authentication, request, response); final PlainActionFuture future = new PlainActionFuture<>(); - action.doExecute(null, request, future); + action.execute(null, request, future); assertThat(future.actionGet(), sameInstance(response)); verify(authorizationService, never()).authorize(any(), any(), any(), anyActionListener()); @@ -227,7 +287,7 @@ public void testGrantApiKeyWithInvalidatedCredentials() { setupApiKeyServiceWithRoleResolution(authentication, request, response); final PlainActionFuture future = new PlainActionFuture<>(); - action.doExecute(null, request, future); + action.execute(null, request, future); final ElasticsearchStatusException exception = expectThrows(ElasticsearchStatusException.class, future::actionGet); assertThat(exception, throwableWithMessage("authentication failed for testing")); @@ -285,7 +345,7 @@ public void testGrantWithRunAs() { .authorize(eq(authentication), eq(AuthenticateAction.NAME), any(AuthenticateRequest.class), anyActionListener()); final PlainActionFuture future = new PlainActionFuture<>(); - action.doExecute(null, request, future); + action.execute(null, request, future); assertThat(future.actionGet(), sameInstance(response)); verify(authorizationService).authorize( @@ -343,7 +403,7 @@ public void testGrantWithRunAsFailureDueToAuthorization() { .authorize(eq(authentication), eq(AuthenticateAction.NAME), any(AuthenticateRequest.class), anyActionListener()); final PlainActionFuture future = new PlainActionFuture<>(); - action.doExecute(null, request, future); + action.execute(null, request, future); assertThat(expectThrows(ElasticsearchSecurityException.class, future::actionGet), sameInstance(e)); verify(authorizationService).authorize( @@ -376,7 +436,7 @@ public void testGrantFailureDueToUnsupportedRunAs() { .authenticate(eq(GrantApiKeyAction.NAME), same(request), any(AuthenticationToken.class), anyActionListener()); final PlainActionFuture future = new PlainActionFuture<>(); - action.doExecute(null, request, future); + action.execute(null, request, future); final ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, future::actionGet); assertThat(e.getMessage(), containsString("the provided grant credentials do not support run-as")); @@ -402,7 +462,9 @@ private CreateApiKeyResponse mockResponse(GrantApiKeyRequest request) { private GrantApiKeyRequest mockRequest() { final String keyName = randomAlphaOfLengthBetween(6, 32); final GrantApiKeyRequest request = new GrantApiKeyRequest(); - request.setApiKeyRequest(new CreateApiKeyRequest(keyName, List.of(), null)); + CreateApiKeyRequest createApiKeyRequest = new CreateApiKeyRequest(keyName, List.of(), null); + createApiKeyRequest.setRefreshPolicy(randomFrom(WriteRequest.RefreshPolicy.values())); + request.setApiKeyRequest(createApiKeyRequest); return request; } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticatorTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticatorTests.java index 4d16732104237..dd1a984a0dcb5 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticatorTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticatorTests.java @@ -23,6 +23,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmSettings; +import org.elasticsearch.xpack.core.security.authc.jwt.JwtAuthenticationToken; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; import org.junit.Before; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtIssuer.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtIssuer.java index 789ac04c40622..3d4d9eae6acd0 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtIssuer.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtIssuer.java @@ -14,6 +14,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; +import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import org.elasticsearch.xpack.core.security.user.User; import java.io.Closeable; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java index 4f7b82a16e8f1..bf6c64242701b 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java @@ -25,6 +25,7 @@ import org.elasticsearch.xpack.core.security.authc.AuthenticationToken; import org.elasticsearch.xpack.core.security.authc.Realm; import org.elasticsearch.xpack.core.security.authc.RealmSettings; +import org.elasticsearch.xpack.core.security.authc.jwt.JwtAuthenticationToken; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; import org.elasticsearch.xpack.core.security.user.User; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmGenerateTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmGenerateTests.java index 2f77923c6c50f..7a0e138305b83 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmGenerateTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmGenerateTests.java @@ -23,6 +23,7 @@ import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmSettings; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; +import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import org.elasticsearch.xpack.core.security.authc.support.DelegatedAuthorizationSettings; import org.elasticsearch.xpack.core.security.authc.support.UserRoleMapper; import org.elasticsearch.xpack.core.security.user.User; @@ -428,7 +429,7 @@ private void printArtifacts( + (Strings.hasText(clientSecret) ? JwtRealm.HEADER_CLIENT_AUTHENTICATION + ": " - + JwtRealm.HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME + + JwtRealmSettings.HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME + " " + clientSecret + "\n" diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmInspector.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmInspector.java index 7697849179acf..40a613a0907c8 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmInspector.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmInspector.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; +import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import org.elasticsearch.xpack.core.security.authc.support.ClaimSetting; import java.net.URI; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java index ffc1fec1f5788..1bc49cb628464 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java @@ -28,6 +28,7 @@ import org.elasticsearch.xpack.core.security.authc.Realm; import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmSettings; +import org.elasticsearch.xpack.core.security.authc.jwt.JwtAuthenticationToken; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings.ClientAuthenticationType; import org.elasticsearch.xpack.core.security.authc.support.DelegatedAuthorizationSettings; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtTestCase.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtTestCase.java index cfb153d233e9d..f244544460ebf 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtTestCase.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtTestCase.java @@ -623,7 +623,7 @@ public ThreadContext createThreadContext(final CharSequence jwt, final CharSeque if (sharedSecret != null) { requestThreadContext.putHeader( JwtRealm.HEADER_CLIENT_AUTHENTICATION, - JwtRealm.HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME + " " + sharedSecret + JwtRealmSettings.HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME + " " + sharedSecret ); } return requestThreadContext; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtTokenExtractionTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtTokenExtractionTests.java index 7cfac9978081b..8662561aca1ae 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtTokenExtractionTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtTokenExtractionTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.security.authc.AuthenticationToken; import org.elasticsearch.xpack.core.security.authc.Realm; +import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; import org.elasticsearch.xpack.core.security.authc.support.BearerToken; import org.elasticsearch.xpack.security.authc.AuthenticationService; import org.elasticsearch.xpack.security.authc.Authenticator; @@ -69,7 +70,7 @@ public void testRealmLetsThroughInvalidJWTs() { if (randomBoolean()) { threadContext.putHeader( JwtRealm.HEADER_CLIENT_AUTHENTICATION, - JwtRealm.HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME + " " + "some shared secret" + JwtRealmSettings.HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME + " " + "some shared secret" ); } AuthenticationToken authenticationToken = realmsAuthenticator.extractCredentials(context); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtilTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtilTests.java index 6fab33b4d6adf..7d90dffd7517c 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtilTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtilTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.SettingsException; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; +import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/action/apikey/RestGrantApiKeyActionTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/action/apikey/RestGrantApiKeyActionTests.java new file mode 100644 index 0000000000000..e6744544a34da --- /dev/null +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/action/apikey/RestGrantApiKeyActionTests.java @@ -0,0 +1,70 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.security.rest.action.apikey; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.security.action.apikey.GrantApiKeyRequest; + +import static org.hamcrest.Matchers.is; + +public class RestGrantApiKeyActionTests extends ESTestCase { + + public void testParseXContentForGrantApiKeyRequest() throws Exception { + final String grantType = randomAlphaOfLength(8); + final String username = randomAlphaOfLength(8); + final String password = randomAlphaOfLength(8); + final String accessToken = randomAlphaOfLength(8); + final String clientAuthenticationScheme = randomAlphaOfLength(8); + final String clientAuthenticationValue = randomAlphaOfLength(8); + final String apiKeyName = randomAlphaOfLength(8); + final String apiKeyExpiration = randomTimeValue(); + final String runAs = randomAlphaOfLength(8); + try ( + XContentParser content = createParser( + XContentFactory.jsonBuilder() + .startObject() + .field("grant_type", grantType) + .field("username", username) + .field("password", password) + .field("access_token", accessToken) + .startObject("client_authentication") + .field("scheme", clientAuthenticationScheme) + .field("value", clientAuthenticationValue) + .endObject() + .startObject("api_key") + .field("name", apiKeyName) + .field("expiration", apiKeyExpiration) + .endObject() + .field("run_as", runAs) + .endObject() + ) + ) { + GrantApiKeyRequest grantApiKeyRequest = RestGrantApiKeyAction.fromXContent(content); + assertThat(grantApiKeyRequest.getGrant().getType(), is(grantType)); + assertThat(grantApiKeyRequest.getGrant().getUsername(), is(username)); + assertThat(grantApiKeyRequest.getGrant().getPassword(), is(new SecureString(password.toCharArray()))); + assertThat(grantApiKeyRequest.getGrant().getAccessToken(), is(new SecureString(accessToken.toCharArray()))); + assertThat(grantApiKeyRequest.getGrant().getClientAuthentication().scheme(), is(clientAuthenticationScheme)); + assertThat( + grantApiKeyRequest.getGrant().getClientAuthentication().value(), + is(new SecureString(clientAuthenticationValue.toCharArray())) + ); + assertThat(grantApiKeyRequest.getGrant().getRunAsUsername(), is(runAs)); + assertThat(grantApiKeyRequest.getApiKeyRequest().getName(), is(apiKeyName)); + assertThat( + grantApiKeyRequest.getApiKeyRequest().getExpiration(), + is(TimeValue.parseTimeValue(apiKeyExpiration, "api_key.expiration")) + ); + } + } + +} diff --git a/x-pack/plugin/security/src/test/resources/org/elasticsearch/xpack/security/authc/apikey/README.md b/x-pack/plugin/security/src/test/resources/org/elasticsearch/xpack/security/authc/apikey/README.md new file mode 100644 index 0000000000000..39b62a1bdfb00 --- /dev/null +++ b/x-pack/plugin/security/src/test/resources/org/elasticsearch/xpack/security/authc/apikey/README.md @@ -0,0 +1,21 @@ +# Explanation of files in this directory + +`rsa-private-jwkset.json`, `rsa-public-jwkset.json` +----------------------------------------------------------------------- + +These files are created by running the tests in `JwtRealmGenerateTests`. + +Those tests generate the yaml settings, the keystore settings and the JWK Sets +for each sample realm. + +Copy the output from the test output into the applicable file (you may wish to +run it through `jq` first in order to make it more readable). + +------- + +If additional keys are needed (e.g. to add more algorithms / key sizes) we can +either extend the existing JWKSets with another set of keys (that is, modify the +existing method in `JwtRealmGenerateTests` so that it creates more keys in the +same JWKSet, then re-run and replace the files on disk) or create new files ( +that is, add another test to `JwtRealmGenerateTests` that creates the relevant +realm and/or JWKSet, then re-run and add the new files to disk). diff --git a/x-pack/plugin/security/src/test/resources/org/elasticsearch/xpack/security/authc/apikey/rsa-private-jwkset.json b/x-pack/plugin/security/src/test/resources/org/elasticsearch/xpack/security/authc/apikey/rsa-private-jwkset.json new file mode 100644 index 0000000000000..e11e83ea95812 --- /dev/null +++ b/x-pack/plugin/security/src/test/resources/org/elasticsearch/xpack/security/authc/apikey/rsa-private-jwkset.json @@ -0,0 +1,23 @@ +{ + "keys": [ + { + "p": "2HaFCsfMmm56qeSlBEnQqSLpUM2S7pyqeGnGdrR9unoBrK_jvJHiW-pI0SWN70iTykPCCiXKTP_NWB4tPV-dB-jtYfwGaQKlKHBdi_5ZgZjcFt3nL-rVSdBIRYjIDqg-zWDyXeT-A98fUKJs9RJlktGjCI6EKVH9pubS-NPG1ts", + "kty": "RSA", + "q": "vF0nR0kH3TKwHZheU0U4ewbnG0rdgN9Lcx5dm2uYMKO77ifC8zWlufm-wTB-SfiA5mJIYGd-_kEdU0scOorJ2RZzxzWQ06CXD1gnWUwhSH-3yqALy5ip8qgtEJO2dLXHX-qfmdhQLvuXyPYg6dD7hF8SZ3A8Dixvk95T8HmmmfM", + "d": "A6XYotHcFGNPfQh5o6B1NDv9wMVb4wdnJd-xdMSgZcmtuiZgXIfXG3L1HROCbSGpfdyP7rjfTHjdcfjJroRpmeTVE-d5pQGmOS8px2UiLfCpjpaZXcNJlNcJeuJFTpcNKOfu5miusbuBcNulu355UpjUJ8uU75qBffjVwQSqKcTQzHn7dvRaChCNwtxPzsFM71HkPk3gfwMJPSeaJeswBSxo7OnCaOJQLiDuaD5EJEgeG7sNd4pwWeq_BC4wNRa8FVv9o18PjVIIn_lFG6eSVKLere4i8bV5qhYJS7bC-Z5mCNIJVanX5-iHyEPiB8L89_kXmLBY5wUoVG0h-HTJRQ", + "e": "AQAB", + "use": "sig", + "kid": "test-rsa-key", + "key_ops": [ + "verify", + "sign" + ], + "qi": "gfz4sx8cdErhh1NxVrHyHFm1APae_2qVits-HLEeIrDsXtqU0KKI68JWflrVD1PYMG7wm2rQkNVL66hKgLjF7GciFboYjDbYp0ezKwxHEHMaK6g9Ts7lm1Ukcu-ujSNgM6H5-LyeChchiwIegUxL_PfiuTxCFlvZND3t9g9TpuM", + "dp": "E5OMbrApIeJR96F1BxFB7Ln6jdM5sZi7lg5C46NpJ383PY3es8Qhy5xBn6Cc2IIg048TMmW-iP4tbQW1o7JM-lUnetAXKFIT67dVzn5fS-guJ2dELEI5reZHUvqO1TyECYD2CmXWTzVTmLBH2FYkl4WcD_8Lls0SepCvjc9hUTc", + "dq": "QPV7GzlTTfQyCjLGrL7daIgL4TyjQNNAyNV7AKNNW9DLeakasRcaLRW0tBkOJGJfyZOxVBW9FN_NxjDL7mB4lbYJfXS6mlDyZ2dGQqRfggoRjv48sxzV1woqaGIYdQ1PUYOvQLX5iQpY4QQIe7oHUzIaPbPV8ile3Ua5-d9qFgM", + "n": "n0XN-JSI02G0rJfoI9Upj_rhmdudJTre9b1evE34kPunSICvJFy4E3Q-Fkc0z2hQa3UigA1Og1qMramkH70p1nWBk0gRIy8cn9_CWuPJQ9pkf5mpFGhEPi_PNVcvY9WpAr6lJ1OWq0Tu1g8b_qu3L5wvLdV8iQeFSIpxON2USNQj-HmoA37ZQcIzWTWZO6dcjTeIRSUEW6V2OVF9UkkS6qXqNGDQoqUXKPO1VElY5mQQTY0X71aUI_B0_gtIYO0iKjNzyYCKTeNZuX4WZbamqAwBApujT7hvZByLHFVbWipGZ2Hl82avUM_yrMr6oLI6UUtzXmvk1pzfn7WPsSjU4Q" + } + ] +} + + diff --git a/x-pack/plugin/security/src/test/resources/org/elasticsearch/xpack/security/authc/apikey/rsa-public-jwkset.json b/x-pack/plugin/security/src/test/resources/org/elasticsearch/xpack/security/authc/apikey/rsa-public-jwkset.json new file mode 100644 index 0000000000000..5e6a88f403690 --- /dev/null +++ b/x-pack/plugin/security/src/test/resources/org/elasticsearch/xpack/security/authc/apikey/rsa-public-jwkset.json @@ -0,0 +1,16 @@ +{ + "keys": [ + { + "kty": "RSA", + "e": "AQAB", + "use": "sig", + "kid": "test-rsa-key", + "key_ops": [ + "verify", + "sign" + ], + "n": "n0XN-JSI02G0rJfoI9Upj_rhmdudJTre9b1evE34kPunSICvJFy4E3Q-Fkc0z2hQa3UigA1Og1qMramkH70p1nWBk0gRIy8cn9_CWuPJQ9pkf5mpFGhEPi_PNVcvY9WpAr6lJ1OWq0Tu1g8b_qu3L5wvLdV8iQeFSIpxON2USNQj-HmoA37ZQcIzWTWZO6dcjTeIRSUEW6V2OVF9UkkS6qXqNGDQoqUXKPO1VElY5mQQTY0X71aUI_B0_gtIYO0iKjNzyYCKTeNZuX4WZbamqAwBApujT7hvZByLHFVbWipGZ2Hl82avUM_yrMr6oLI6UUtzXmvk1pzfn7WPsSjU4Q" + } + ] +} + diff --git a/x-pack/plugin/security/src/test/resources/org/elasticsearch/xpack/security/authc/apikey/serialized-signed-RS256-jwt.txt b/x-pack/plugin/security/src/test/resources/org/elasticsearch/xpack/security/authc/apikey/serialized-signed-RS256-jwt.txt new file mode 100644 index 0000000000000..b247f40dd1667 --- /dev/null +++ b/x-pack/plugin/security/src/test/resources/org/elasticsearch/xpack/security/authc/apikey/serialized-signed-RS256-jwt.txt @@ -0,0 +1 @@ +eyJraWQiOiJ0ZXN0LXJzYS1rZXkiLCJhbGciOiJSUzI1NiJ9.eyJpc3MiOiJodHRwczpcL1wvaXNzdWVyLmV4YW1wbGUuY29tXC8iLCJhdWQiOiJodHRwczpcL1wvYXVkaWVuY2UuZXhhbXBsZS5jb21cLyIsInN1YiI6InVzZXIxIiwiZXhwIjo0MDcwOTA4ODAwLCJpYXQiOjk0NjY4NDgwMH0.IcJZXIEyd98T198_K4YOBE_4yJDbnNYugituAf_-M7nNI_rGAwD7uecK85xMco8mr0TSlyQWpbazHeOP4dh9jln27_Llf-D4xZeykESrlhM3zkMwUbDML2reM96NoTN42c_Cj5V9pZCEmcbk1BumnkmDD-RCTx4b_cB8CjiR4ODxXFpVnoJB-PdGFt7rImjkO0yacuUF09XOR-uUxH09WkqtmqoCnp-geSqNZbVzb2Kt1bTq66B0Wfiz6sG_cpM-NdhJ-JUZMO_oCJ9mfyje9fH5F1x8LA063qLVABRvQSEWP3t4wIRAnqS3Hj0sDqjfNBdcBgSCBY0_G8NmHw4toA \ No newline at end of file diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/api_key/12_grant.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/api_key/12_grant.yml index cf158ad73a13e..9b2c0de0cfed1 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/api_key/12_grant.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/api_key/12_grant.yml @@ -192,6 +192,7 @@ teardown: - transform_and_set: { login_creds: "#base64EncodeCredentials(id,api_key)" } - match: { encoded: $login_creds } + # verify the granted API Key - do: headers: Authorization: ApiKey ${login_creds} @@ -210,6 +211,41 @@ teardown: - match: { _nodes.failed: 0 } +--- +"Test grant api key with non-JWT token and client authentication fails": + - do: + security.get_token: + body: + grant_type: "password" + username: "api_key_grant_target_user" + password: "x-pack-test-password-2" + + - match: { type: "Bearer" } + - is_true: access_token + - set: { access_token: token } + + - do: + headers: + Authorization: "Basic YXBpX2tleV9ncmFudGVyOngtcGFjay10ZXN0LXBhc3N3b3Jk" # api_key_granter + catch: bad_request + security.grant_api_key: + body: > + { + "api_key": { + "name": "wrong-api-key" + }, + "grant_type": "access_token", + "access_token": "$token", + "client_authentication": { + "scheme": "SharedSecret", + "value": "whatever" + } + } + + - match: { "error.type": "security_exception" } + - match: + "error.reason": "[client_authentication] not supported with the supplied access_token type" + --- "Test grant api key forbidden": - do: diff --git a/x-pack/qa/oidc-op-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/authc/jwt/JwtWithOidcAuthIT.java b/x-pack/qa/oidc-op-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/authc/jwt/JwtWithOidcAuthIT.java index 47e4b02d63648..2d3fc611758b0 100644 --- a/x-pack/qa/oidc-op-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/authc/jwt/JwtWithOidcAuthIT.java +++ b/x-pack/qa/oidc-op-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/authc/jwt/JwtWithOidcAuthIT.java @@ -22,6 +22,7 @@ import org.elasticsearch.rest.RestUtils; import org.elasticsearch.test.TestMatchers; import org.elasticsearch.test.TestSecurityClient; +import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; import org.elasticsearch.xpack.core.security.user.User; import org.elasticsearch.xpack.security.authc.oidc.C2IdOpTestCase; import org.hamcrest.Matchers; @@ -151,7 +152,10 @@ private Map authenticateWithJwtAndSharedSecret(String idJwt, Str final Map authenticateResponse = super.callAuthenticateApiUsingBearerToken( idJwt, RequestOptions.DEFAULT.toBuilder() - .addHeader(JwtRealm.HEADER_CLIENT_AUTHENTICATION, JwtRealm.HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME + " " + sharedSecret) + .addHeader( + JwtRealm.HEADER_CLIENT_AUTHENTICATION, + JwtRealmSettings.HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME + " " + sharedSecret + ) .build() ); return authenticateResponse; From b35bba34b168b3ed260d0a76fc34d6e9b1a8fbad Mon Sep 17 00:00:00 2001 From: David Turner Date: Tue, 21 Nov 2023 12:12:07 +0000 Subject: [PATCH 15/22] Tighten up exception handling in validateJoinRequest (#102370) This method is pretty much the classic fan-out pattern, except that it might also throw an exception if validation fails before sending the first request. This commit reimplements it using more modern tooling, feeds all exceptions to the supplied listener, and removes the exception-mangling coming from the use of `ListenableActionFuture`. --- .../cluster/coordination/Coordinator.java | 89 ++++++++----------- 1 file changed, 35 insertions(+), 54 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/coordination/Coordinator.java b/server/src/main/java/org/elasticsearch/cluster/coordination/Coordinator.java index c3c757bb335e4..3da890b37ade8 100644 --- a/server/src/main/java/org/elasticsearch/cluster/coordination/Coordinator.java +++ b/server/src/main/java/org/elasticsearch/cluster/coordination/Coordinator.java @@ -13,7 +13,7 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ChannelActionListener; -import org.elasticsearch.action.support.ListenableActionFuture; +import org.elasticsearch.action.support.RefCountingListener; import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.action.support.ThreadedActionListener; import org.elasticsearch.client.internal.Client; @@ -634,21 +634,11 @@ private void handleJoinRequest(JoinRequest joinRequest, ActionListener joi transportService.connectToNode(joinRequest.getSourceNode(), new ActionListener<>() { @Override public void onResponse(Releasable response) { - boolean retainConnection = false; - try { - validateJoinRequest( - joinRequest, - ActionListener.runBefore(joinListener, () -> Releasables.close(response)) - .delegateFailure((l, ignored) -> processJoinRequest(joinRequest, l)) - ); - retainConnection = true; - } catch (Exception e) { - joinListener.onFailure(e); - } finally { - if (retainConnection == false) { - Releasables.close(response); - } - } + validateJoinRequest( + joinRequest, + ActionListener.runBefore(joinListener, () -> Releasables.close(response)) + .delegateFailure((l, ignored) -> processJoinRequest(joinRequest, l)) + ); } @Override @@ -687,48 +677,39 @@ private void validateJoinRequest(JoinRequest joinRequest, ActionListener v // - if we're already master that it can make sense of the current cluster state. // - we have a healthy PING channel to the node - final ClusterState stateForJoinValidation = getStateForJoinValidationService(); - final ListenableActionFuture validateStateListener = new ListenableActionFuture<>(); - if (stateForJoinValidation != null) { - assert stateForJoinValidation.nodes().isLocalNodeElectedMaster(); - onJoinValidators.forEach(a -> a.accept(joinRequest.getSourceNode(), stateForJoinValidation)); - if (stateForJoinValidation.getBlocks().hasGlobalBlock(STATE_NOT_RECOVERED_BLOCK) == false) { - // We do this in a couple of places including the cluster update thread. This one here is really just best effort to ensure - // we fail as fast as possible. - NodeJoinExecutor.ensureVersionBarrier( - joinRequest.getSourceNode().getVersion(), - stateForJoinValidation.getNodes().getMinNodeVersion() - ); - } - sendJoinValidate(joinRequest.getSourceNode(), validateStateListener); - } else { - sendJoinPing(joinRequest.getSourceNode(), TransportRequestOptions.Type.STATE, validateStateListener); - } + try (var listeners = new RefCountingListener(validateListener)) { + // The join will be rejected if any of these steps fail, but we wait them all to complete, particularly state validation, since + // the node will retry and we don't want lots of cluster states in flight. - sendJoinPing(joinRequest.getSourceNode(), TransportRequestOptions.Type.PING, new ActionListener<>() { - @Override - public void onResponse(Void ignored) { - validateStateListener.addListener(validateListener); - } + ActionListener.completeWith(listeners.acquire(), () -> { + final ClusterState stateForJoinValidation = getStateForJoinValidationService(); + if (stateForJoinValidation == null) { + return null; + } - @Override - public void onFailure(Exception e) { - // The join will be rejected, but we wait for the state validation to complete as well since the node will retry and we - // don't want lots of cluster states in flight. - validateStateListener.addListener(new ActionListener<>() { - @Override - public void onResponse(Void ignored) { - validateListener.onFailure(e); - } + assert stateForJoinValidation.nodes().isLocalNodeElectedMaster(); + onJoinValidators.forEach(a -> a.accept(joinRequest.getSourceNode(), stateForJoinValidation)); + if (stateForJoinValidation.getBlocks().hasGlobalBlock(STATE_NOT_RECOVERED_BLOCK) == false) { + // We do this in a couple of places including the cluster update thread. This one here is really just best effort to + // ensure we fail as fast as possible. + NodeJoinExecutor.ensureVersionBarrier( + joinRequest.getSourceNode().getVersion(), + stateForJoinValidation.getNodes().getMinNodeVersion() + ); + } + sendJoinValidate(joinRequest.getSourceNode(), listeners.acquire()); + return null; + }); - @Override - public void onFailure(Exception e2) { - e2.addSuppressed(e); - validateListener.onFailure(e2); - } - }); + if (listeners.isFailing() == false) { + // We may not have sent a state for validation, so just ping both channel types. + sendJoinPing(joinRequest.getSourceNode(), TransportRequestOptions.Type.PING, listeners.acquire()); + sendJoinPing(joinRequest.getSourceNode(), TransportRequestOptions.Type.STATE, listeners.acquire()); } - }); + } catch (Exception e) { + logger.error("unexpected exception in validateJoinRequest", e); + assert false : e; + } } private void sendJoinValidate(DiscoveryNode discoveryNode, ActionListener listener) { From 942553aa9c6fc87317938674df049d883d197eea Mon Sep 17 00:00:00 2001 From: Craig Taverner Date: Tue, 21 Nov 2023 14:34:24 +0100 Subject: [PATCH 16/22] Fix flaky test for #102406 (#102407) The mutateInstance just happens to generate the exact same id with this particular random seed, so the test fails. Just checking that the id is the same and generating another one is sufficient to fix this. Fixes https://github.com/elastic/elasticsearch/issues/102406 --- .../xpack/fleet/action/PostSecretResponseTests.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/fleet/src/test/java/org/elasticsearch/xpack/fleet/action/PostSecretResponseTests.java b/x-pack/plugin/fleet/src/test/java/org/elasticsearch/xpack/fleet/action/PostSecretResponseTests.java index 6100bdd30c9cf..c47c2e553a509 100644 --- a/x-pack/plugin/fleet/src/test/java/org/elasticsearch/xpack/fleet/action/PostSecretResponseTests.java +++ b/x-pack/plugin/fleet/src/test/java/org/elasticsearch/xpack/fleet/action/PostSecretResponseTests.java @@ -24,6 +24,7 @@ protected PostSecretResponse createTestInstance() { @Override protected PostSecretResponse mutateInstance(PostSecretResponse instance) { - return new PostSecretResponse(randomAlphaOfLengthBetween(2, 10)); + String id = randomValueOtherThan(instance.id(), () -> randomAlphaOfLengthBetween(2, 10)); + return new PostSecretResponse(id); } } From 8616178831763c221d8634e64cfb7ac4406c59df Mon Sep 17 00:00:00 2001 From: Craig Taverner Date: Tue, 21 Nov 2023 15:01:43 +0100 Subject: [PATCH 17/22] Mute two flaky tests for #102339 (#102403) These two fail about 2% of the time, which is actually quite a high failure rate. The issue to fix is https://github.com/elastic/elasticsearch/issues/102339 --- .../resources/rest-api-spec/test/simulate.ingest/10_basic.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/simulate.ingest/10_basic.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/simulate.ingest/10_basic.yml index 89011750479a0..38aaaa9847efb 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/simulate.ingest/10_basic.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/simulate.ingest/10_basic.yml @@ -118,6 +118,8 @@ setup: - skip: features: headers + version: "all" + reason: "AwaitsFix https://github.com/elastic/elasticsearch/issues/102339" - do: headers: @@ -195,6 +197,8 @@ setup: - skip: features: headers + version: "all" + reason: "AwaitsFix https://github.com/elastic/elasticsearch/issues/102339" - do: headers: From fb19e49117e3d987c85b9af670adc8a910bd021d Mon Sep 17 00:00:00 2001 From: Kostas Krikellas <131142368+kkrik-es@users.noreply.github.com> Date: Tue, 21 Nov 2023 16:17:09 +0200 Subject: [PATCH 18/22] Update expected test output (#102412) A `having` bucket selector errors out when applied to a doc lacking the expected value in the `buckets_path`. The existing yaml test didn't reflect that but assumes that such docs don't affect results. Adding a refresh in the bulk operation exposes this problem. Fixes #102295 --- .../rest-api-spec/test/aggregations/top_hits.yml | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/modules/aggregations/src/yamlRestTest/resources/rest-api-spec/test/aggregations/top_hits.yml b/modules/aggregations/src/yamlRestTest/resources/rest-api-spec/test/aggregations/top_hits.yml index 4d4848e8aebc3..5cf0265374b08 100644 --- a/modules/aggregations/src/yamlRestTest/resources/rest-api-spec/test/aggregations/top_hits.yml +++ b/modules/aggregations/src/yamlRestTest/resources/rest-api-spec/test/aggregations/top_hits.yml @@ -600,10 +600,13 @@ synthetic _source: - do: bulk: index: test2 + refresh: true body: - { index: { } } - { gender: 3 } - do: + # The script can't process a bucket without a salary value for gender '3'. + catch: /path not supported for \[top_salary_hits\]:\ \[_source.salary\]./ search: index: test2 size: 0 @@ -630,13 +633,6 @@ synthetic _source: ts: top_salary_hits[_source.salary] script: "params.ts < 8000" - # Empty bucket for gender '3' affects nothing. - - length: { aggregations.genders.buckets: 1} - - match: { aggregations.genders.buckets.0.top_salary_hits.hits.total.value: 4} - - match: { aggregations.genders.buckets.0.top_salary_hits.hits.hits.0._source.gender: 1} - - match: { aggregations.genders.buckets.0.top_salary_hits.hits.hits.0._source.salary: 4000} - - match: { aggregations.genders.buckets.0.top_salary_hits.hits.hits.0._source.birth_date: 1982} - - do: catch: /path not supported for \[top_salary_hits\]:\ \[_source.nosuchfield\]./ search: From b1a523f8cf6c27ce3c8431452f04858c95fd421a Mon Sep 17 00:00:00 2001 From: Luigi Dell'Aquila Date: Tue, 21 Nov 2023 15:19:32 +0100 Subject: [PATCH 19/22] ESQL: Enhance physical operation provider to extract multiple fields at once (#102408) --- .../planner/EsPhysicalOperationProviders.java | 23 ++++--------------- 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java index 1dddee5ed54ea..966db6f02c9ba 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java @@ -63,34 +63,21 @@ public List searchContexts() { @Override public final PhysicalOperation fieldExtractPhysicalOperation(FieldExtractExec fieldExtractExec, PhysicalOperation source) { Layout.Builder layout = source.layout.builder(); - var sourceAttr = fieldExtractExec.sourceAttribute(); - - PhysicalOperation op = source; + List readers = searchContexts.stream().map(s -> s.searcher().getIndexReader()).toList(); + List fields = new ArrayList<>(); + int docChannel = source.layout.get(sourceAttr.id()).channel(); for (Attribute attr : fieldExtractExec.attributesToExtract()) { if (attr instanceof FieldAttribute fa && fa.getExactInfo().hasExact()) { attr = fa.exactAttribute(); } layout.append(attr); - Layout previousLayout = op.layout; - DataType dataType = attr.dataType(); String fieldName = attr.name(); List loaders = BlockReaderFactories.loaders(searchContexts, fieldName, EsqlDataTypes.isUnsupported(dataType)); - List readers = searchContexts.stream().map(s -> s.searcher().getIndexReader()).toList(); - - int docChannel = previousLayout.get(sourceAttr.id()).channel(); - - op = op.with( - new ValuesSourceReaderOperator.Factory( - List.of(new ValuesSourceReaderOperator.FieldInfo(fieldName, loaders)), - readers, - docChannel - ), - layout.build() - ); + fields.add(new ValuesSourceReaderOperator.FieldInfo(fieldName, loaders)); } - return op; + return source.with(new ValuesSourceReaderOperator.Factory(fields, readers, docChannel), layout.build()); } public static Function querySupplier(QueryBuilder queryBuilder) { From 8a22586e2ba648029f8a382f7966ef6339a5de37 Mon Sep 17 00:00:00 2001 From: Craig Taverner Date: Tue, 21 Nov 2023 15:23:26 +0100 Subject: [PATCH 20/22] Mute flaky test for #101876 (#102404) This failed four times in the last week. The issue to fix is https://github.com/elastic/elasticsearch/issues/101876 --- .../java/org/elasticsearch/snapshots/ConcurrentSnapshotsIT.java | 1 + 1 file changed, 1 insertion(+) diff --git a/server/src/internalClusterTest/java/org/elasticsearch/snapshots/ConcurrentSnapshotsIT.java b/server/src/internalClusterTest/java/org/elasticsearch/snapshots/ConcurrentSnapshotsIT.java index f91a848ed2362..7fdb77909e8e9 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/snapshots/ConcurrentSnapshotsIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/snapshots/ConcurrentSnapshotsIT.java @@ -1078,6 +1078,7 @@ public void testEquivalentDeletesAreDeduplicated() throws Exception { } } + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/101876") public void testMasterFailoverOnFinalizationLoop() throws Exception { internalCluster().startMasterOnlyNodes(3); final String dataNode = internalCluster().startDataOnlyNode(); From 9cd96df179d581567c7e45ea4d8bccc4fb89e083 Mon Sep 17 00:00:00 2001 From: Luca Cavanna Date: Tue, 21 Nov 2023 15:35:49 +0100 Subject: [PATCH 21/22] Add support for index_filter to open pit (#102388) The open point in time API accepts a list of indices and opens a point in time view against those indices. Like we do already for field caps, this commit allows users to provide an index_filter parameter as part of the request body, that will be used to execute the can match phase and exclude the indices that can't possibly match such filter. Closes #99740 --- docs/changelog/102388.yaml | 6 ++ .../search/point-in-time-api.asciidoc | 9 ++- .../rest/yaml/CcsCommonYamlTestSuiteIT.java | 5 +- .../rest-api-spec/api/open_point_in_time.json | 6 +- .../test/search/350_point_in_time.yml | 37 +++++++++-- .../action/search/PointInTimeIT.java | 50 ++++++++++++++ .../org/elasticsearch/TransportVersions.java | 1 + .../action/search/OpenPointInTimeRequest.java | 17 +++++ .../search/RestOpenPointInTimeAction.java | 19 +++++- .../TransportOpenPointInTimeAction.java | 66 ++++++++++++++++++- x-pack/qa/runtime-fields/build.gradle | 1 + 11 files changed, 205 insertions(+), 12 deletions(-) create mode 100644 docs/changelog/102388.yaml diff --git a/docs/changelog/102388.yaml b/docs/changelog/102388.yaml new file mode 100644 index 0000000000000..3e65e46949bda --- /dev/null +++ b/docs/changelog/102388.yaml @@ -0,0 +1,6 @@ +pr: 102388 +summary: Add support for `index_filter` to open pit +area: Search +type: enhancement +issues: + - 99740 diff --git a/docs/reference/search/point-in-time-api.asciidoc b/docs/reference/search/point-in-time-api.asciidoc index 0403f9b04b2d1..2e32324cb44d9 100644 --- a/docs/reference/search/point-in-time-api.asciidoc +++ b/docs/reference/search/point-in-time-api.asciidoc @@ -22,6 +22,13 @@ or alias. To search a <> for an alias, you must have the `read` index privilege for the alias's data streams or indices. +[[point-in-time-api-request-body]] +==== {api-request-body-title} + +`index_filter`:: +(Optional, <> Allows to filter indices if the provided +query rewrites to `match_none` on every shard. + [[point-in-time-api-example]] ==== {api-examples-title} @@ -60,7 +67,7 @@ POST /_search <1> or <> as these parameters are copied from the point in time. <2> Just like regular searches, you can <>, up to the first 10,000 hits. If you +`size` to page through search results>>, up to the first 10,000 hits. If you want to retrieve more hits, use PIT with <>. <3> The `id` parameter tells Elasticsearch to execute the request using contexts from this point in time. diff --git a/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/CcsCommonYamlTestSuiteIT.java b/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/CcsCommonYamlTestSuiteIT.java index 5ad525b472b12..7c1514d2d1a6a 100644 --- a/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/CcsCommonYamlTestSuiteIT.java +++ b/qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/CcsCommonYamlTestSuiteIT.java @@ -389,10 +389,7 @@ private boolean shouldReplaceIndexWithRemote(String apiName) { if (apiName.equals("search") || apiName.equals("msearch") || apiName.equals("async_search.submit")) { final String testCandidateTestPath = testCandidate.getTestPath(); - if (testCandidateTestPath.equals("search/350_point_in_time/basic") - || testCandidateTestPath.equals("search/350_point_in_time/point-in-time with slicing") - || testCandidateTestPath.equals("search/350_point_in_time/msearch") - || testCandidateTestPath.equals("search/350_point_in_time/wildcard") + if (testCandidateTestPath.startsWith("search/350_point_in_time") || testCandidateTestPath.equals("async_search/20-with-poin-in-time/Async search with point in time")) { return false; } diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/open_point_in_time.json b/rest-api-spec/src/main/resources/rest-api-spec/api/open_point_in_time.json index a25c3fee32571..bce8dfd794dca 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/open_point_in_time.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/open_point_in_time.json @@ -7,7 +7,8 @@ "stability":"stable", "visibility":"public", "headers":{ - "accept": [ "application/json"] + "accept": [ "application/json"], + "content_type": ["application/json"] }, "url":{ "paths":[ @@ -55,6 +56,9 @@ "description": "Specific the time to live for the point in time", "required": true } + }, + "body":{ + "description":"An index_filter specified with the Query DSL" } } } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/350_point_in_time.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/350_point_in_time.yml index bc3479b705180..7e78450931df5 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/350_point_in_time.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/350_point_in_time.yml @@ -6,19 +6,19 @@ setup: index: index: test id: "1" - body: { id: 1, foo: bar, age: 18 } + body: { id: 1, foo: bar, age: 18, birth: "2022-01-01" } - do: index: index: test id: "42" - body: { id: 42, foo: bar, age: 18 } + body: { id: 42, foo: bar, age: 18, birth: "2022-02-01" } - do: index: index: test id: "172" - body: { id: 172, foo: bar, age: 24 } + body: { id: 172, foo: bar, age: 24, birth: "2022-03-01" } - do: indices.create: @@ -28,7 +28,7 @@ setup: index: index: test2 id: "45" - body: { id: 45, foo: bar, age: 19 } + body: { id: 45, foo: bar, age: 19, birth: "2023-01-01" } - do: indices.refresh: @@ -235,3 +235,32 @@ setup: close_point_in_time: body: id: "$point_in_time_id" + +--- +"point-in-time with index filter": + - skip: + version: " - 8.11.99" + reason: "support for index filter was added in 8.12" + - do: + open_point_in_time: + index: test* + keep_alive: 5m + body: { index_filter: { range: { birth: { gte: "2023-01-01" }}}} + - set: {id: point_in_time_id} + + - do: + search: + body: + size: 1 + pit: + id: "$point_in_time_id" + + - match: {hits.total.value: 1 } + - length: {hits.hits: 1 } + - match: {hits.hits.0._index: test2 } + - match: {hits.hits.0._id: "45" } + + - do: + close_point_in_time: + body: + id: "$point_in_time_id" diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/search/PointInTimeIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/search/PointInTimeIT.java index bb7658f5011e3..27100e925a0a2 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/search/PointInTimeIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/search/PointInTimeIT.java @@ -50,6 +50,7 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertFailures; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailures; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; import static org.hamcrest.Matchers.arrayWithSize; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; @@ -152,6 +153,55 @@ public void testMultipleIndices() { } } + public void testIndexFilter() { + int numDocs = randomIntBetween(1, 9); + for (int i = 1; i <= 3; i++) { + String index = "index-" + i; + createIndex(index); + for (int j = 1; j <= numDocs; j++) { + String id = Integer.toString(j); + client().prepareIndex(index).setId(id).setSource("@timestamp", "2023-0" + i + "-0" + j).get(); + } + } + refresh(); + + { + + OpenPointInTimeRequest request = new OpenPointInTimeRequest("*").keepAlive(TimeValue.timeValueMinutes(2)); + final OpenPointInTimeResponse response = client().execute(OpenPointInTimeAction.INSTANCE, request).actionGet(); + try { + SearchContextId searchContextId = SearchContextId.decode(writableRegistry(), response.getPointInTimeId()); + String[] actualIndices = searchContextId.getActualIndices(); + assertEquals(3, actualIndices.length); + } finally { + closePointInTime(response.getPointInTimeId()); + } + } + { + OpenPointInTimeRequest request = new OpenPointInTimeRequest("*").keepAlive(TimeValue.timeValueMinutes(2)); + request.indexFilter(new RangeQueryBuilder("@timestamp").gte("2023-03-01")); + final OpenPointInTimeResponse response = client().execute(OpenPointInTimeAction.INSTANCE, request).actionGet(); + String pitId = response.getPointInTimeId(); + try { + SearchContextId searchContextId = SearchContextId.decode(writableRegistry(), pitId); + String[] actualIndices = searchContextId.getActualIndices(); + assertEquals(1, actualIndices.length); + assertEquals("index-3", actualIndices[0]); + assertResponse(prepareSearch().setPointInTime(new PointInTimeBuilder(pitId)).setSize(50), resp -> { + assertNoFailures(resp); + assertHitCount(resp, numDocs); + assertNotNull(resp.pointInTimeId()); + assertThat(resp.pointInTimeId(), equalTo(pitId)); + for (SearchHit hit : resp.getHits()) { + assertEquals("index-3", hit.getIndex()); + } + }); + } finally { + closePointInTime(pitId); + } + } + } + public void testRelocation() throws Exception { internalCluster().ensureAtLeastNumDataNodes(4); createIndex("test", Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, between(0, 1)).build()); diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index bde73ec5b801f..535192eeefd47 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -174,6 +174,7 @@ static TransportVersion def(int id) { public static final TransportVersion SHUTDOWN_MIGRATION_STATUS_INCLUDE_COUNTS = def(8_543_00_0); public static final TransportVersion TRANSFORM_GET_CHECKPOINT_QUERY_AND_CLUSTER_ADDED = def(8_544_00_0); public static final TransportVersion GRANT_API_KEY_CLIENT_AUTHENTICATION_ADDED = def(8_545_00_0); + public static final TransportVersion PIT_WITH_INDEX_FILTER = def(8_546_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/action/search/OpenPointInTimeRequest.java b/server/src/main/java/org/elasticsearch/action/search/OpenPointInTimeRequest.java index 633e56b97a833..39813a883c428 100644 --- a/server/src/main/java/org/elasticsearch/action/search/OpenPointInTimeRequest.java +++ b/server/src/main/java/org/elasticsearch/action/search/OpenPointInTimeRequest.java @@ -17,6 +17,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskId; @@ -38,6 +39,8 @@ public final class OpenPointInTimeRequest extends ActionRequest implements Indic @Nullable private String preference; + private QueryBuilder indexFilter; + public static final IndicesOptions DEFAULT_INDICES_OPTIONS = SearchRequest.DEFAULT_INDICES_OPTIONS; public OpenPointInTimeRequest(String... indices) { @@ -54,6 +57,9 @@ public OpenPointInTimeRequest(StreamInput in) throws IOException { if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_500_020)) { this.maxConcurrentShardRequests = in.readVInt(); } + if (in.getTransportVersion().onOrAfter(TransportVersions.PIT_WITH_INDEX_FILTER)) { + this.indexFilter = in.readOptionalNamedWriteable(QueryBuilder.class); + } } @Override @@ -67,6 +73,9 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_500_020)) { out.writeVInt(maxConcurrentShardRequests); } + if (out.getTransportVersion().onOrAfter(TransportVersions.PIT_WITH_INDEX_FILTER)) { + out.writeOptionalWriteable(indexFilter); + } } @Override @@ -153,6 +162,14 @@ public void maxConcurrentShardRequests(int maxConcurrentShardRequests) { this.maxConcurrentShardRequests = maxConcurrentShardRequests; } + public void indexFilter(QueryBuilder indexFilter) { + this.indexFilter = indexFilter; + } + + public QueryBuilder indexFilter() { + return indexFilter; + } + @Override public boolean allowsRemoteIndices() { return true; diff --git a/server/src/main/java/org/elasticsearch/action/search/RestOpenPointInTimeAction.java b/server/src/main/java/org/elasticsearch/action/search/RestOpenPointInTimeAction.java index 815deac07dfcd..627fdd88cc308 100644 --- a/server/src/main/java/org/elasticsearch/action/search/RestOpenPointInTimeAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/RestOpenPointInTimeAction.java @@ -17,9 +17,13 @@ import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.ParseField; +import java.io.IOException; import java.util.List; +import static org.elasticsearch.index.query.AbstractQueryBuilder.parseTopLevelQuery; import static org.elasticsearch.rest.RestRequest.Method.POST; @ServerlessScope(Scope.PUBLIC) @@ -36,7 +40,7 @@ public List routes() { } @Override - public RestChannelConsumer prepareRequest(final RestRequest request, final NodeClient client) { + public RestChannelConsumer prepareRequest(final RestRequest request, final NodeClient client) throws IOException { final String[] indices = Strings.splitStringByCommaToArray(request.param("index")); final OpenPointInTimeRequest openRequest = new OpenPointInTimeRequest(indices); openRequest.indicesOptions(IndicesOptions.fromRequest(request, OpenPointInTimeRequest.DEFAULT_INDICES_OPTIONS)); @@ -50,6 +54,19 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC ); openRequest.maxConcurrentShardRequests(maxConcurrentShardRequests); } + request.withContentOrSourceParamParserOrNull(parser -> { + if (parser != null) { + PARSER.parse(parser, openRequest, null); + } + }); + return channel -> client.execute(OpenPointInTimeAction.INSTANCE, openRequest, new RestToXContentListener<>(channel)); } + + private static final ObjectParser PARSER = new ObjectParser<>("open_point_in_time_request"); + private static final ParseField INDEX_FILTER_FIELD = new ParseField("index_filter"); + + static { + PARSER.declareObject(OpenPointInTimeRequest::indexFilter, (p, c) -> parseTopLevelQuery(p), INDEX_FILTER_FIELD); + } } diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java index ae3c735e079e9..b247a915ad923 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java @@ -8,6 +8,8 @@ package org.elasticsearch.action.search; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListenerResponseHandler; import org.elasticsearch.action.IndicesRequest; @@ -28,6 +30,7 @@ import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchService; import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.search.internal.InternalSearchResponse; import org.elasticsearch.search.internal.ShardSearchContextId; @@ -47,6 +50,9 @@ import java.util.function.BiFunction; public class TransportOpenPointInTimeAction extends HandledTransportAction { + + private static final Logger logger = LogManager.getLogger(TransportOpenPointInTimeAction.class); + public static final String OPEN_SHARD_READER_CONTEXT_NAME = "indices:data/read/open_reader_context"; private final TransportSearchAction transportSearchAction; @@ -93,7 +99,8 @@ protected void doExecute(Task task, OpenPointInTimeRequest request, ActionListen .indicesOptions(request.indicesOptions()) .preference(request.preference()) .routing(request.routing()) - .allowPartialSearchResults(false); + .allowPartialSearchResults(false) + .source(new SearchSourceBuilder().query(request.indexFilter())); searchRequest.setMaxConcurrentShardRequests(request.maxConcurrentShardRequests()); searchRequest.setCcsMinimizeRoundtrips(false); transportSearchAction.executeRequest((SearchTask) task, searchRequest, listener.map(r -> { @@ -125,6 +132,63 @@ public SearchPhase newSearchPhase( boolean preFilter, ThreadPool threadPool, SearchResponse.Clusters clusters + ) { + if (SearchService.canRewriteToMatchNone(searchRequest.source())) { + return new CanMatchPreFilterSearchPhase( + logger, + searchTransportService, + connectionLookup, + aliasFilter, + concreteIndexBoosts, + threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION), + searchRequest, + shardIterators, + timeProvider, + task, + false, + searchService.getCoordinatorRewriteContextProvider(timeProvider::absoluteStartMillis), + listener.delegateFailureAndWrap( + (searchResponseActionListener, searchShardIterators) -> openPointInTimePhase( + task, + searchRequest, + executor, + searchShardIterators, + timeProvider, + connectionLookup, + clusterState, + aliasFilter, + concreteIndexBoosts, + clusters + ).start() + ) + ); + } else { + return openPointInTimePhase( + task, + searchRequest, + executor, + shardIterators, + timeProvider, + connectionLookup, + clusterState, + aliasFilter, + concreteIndexBoosts, + clusters + ); + } + } + + SearchPhase openPointInTimePhase( + SearchTask task, + SearchRequest searchRequest, + Executor executor, + GroupShardsIterator shardIterators, + TransportSearchAction.SearchTimeProvider timeProvider, + BiFunction connectionLookup, + ClusterState clusterState, + Map aliasFilter, + Map concreteIndexBoosts, + SearchResponse.Clusters clusters ) { assert searchRequest.getMaxConcurrentShardRequests() == pitRequest.maxConcurrentShardRequests() : searchRequest.getMaxConcurrentShardRequests() + " != " + pitRequest.maxConcurrentShardRequests(); diff --git a/x-pack/qa/runtime-fields/build.gradle b/x-pack/qa/runtime-fields/build.gradle index 1a9e913932eb7..dd7d0abc24b19 100644 --- a/x-pack/qa/runtime-fields/build.gradle +++ b/x-pack/qa/runtime-fields/build.gradle @@ -74,6 +74,7 @@ subprojects { 'search/115_multiple_field_collapsing/two levels fields collapsing', // Field collapsing on a runtime field does not work 'search/111_field_collapsing_with_max_score/*', // Field collapsing on a runtime field does not work 'field_caps/30_index_filter/Field caps with index filter', // We don't support filtering field caps on runtime fields. What should we do? + 'search/350_point_in_time/point-in-time with index filter', // We don't support filtering pit on runtime fields. 'aggregations/filters_bucket/cache busting', // runtime keyword does not support split_queries_on_whitespace 'search/140_pre_filter_search_shards/pre_filter_shard_size with shards that have no hit', //completion suggester does not return options when the context field is a geo_point runtime field From 7345e643ba7c678342f1270f47ac7efa800197d3 Mon Sep 17 00:00:00 2001 From: Mark Tozzi Date: Tue, 21 Nov 2023 09:58:07 -0500 Subject: [PATCH 22/22] [ES|QL] pow function always returns double (#102183) This corrects an earlier mistake in the ES|QL language design. Initially we had thought to have pow return the same type as its inputs, but in practice even for integer inputs this quickly grows out of the representable range, and we returned null much of the time. This also created a lot of edge cases around casting to/from doubles (which the underlying java function uses). The version in this PR follows the java spec, by always casting its inputs to doubles, and returning a double. Doing it this way also allows for a rather significant reduction in lines of code. I removed many of the tests covering pow specific edge cases. This seems reasonable to me as I expect java.lang.math.pow to be well behaved and most of those edge cases were around type testing which no longer applies. At the same time, this simplification allows us to leverage the new scalar function testing framework, which means better null coverage, better type coverage, and much easier extensibility. We do consider this a breaking change, but as the feature is still in tech preview and this is a relatively small surface area, we are not too concerned with disruptions. Resolves #99055 Relates to #100558 --------- Co-authored-by: Elastic Machine --- docs/changelog/102183.yaml | 13 + docs/reference/esql/functions/pow.asciidoc | 59 +-- .../esql/functions/types/pow.asciidoc | 4 +- .../src/main/resources/math.csv-spec | 70 ++- .../src/main/resources/show.csv-spec | 4 +- ...DoubleEvaluator.java => PowEvaluator.java} | 12 +- .../function/scalar/math/PowIntEvaluator.java | 131 ----- .../scalar/math/PowLongEvaluator.java | 131 ----- .../expression/function/scalar/math/Pow.java | 93 +--- .../expression/function/TestCaseSupplier.java | 42 +- .../function/scalar/math/PowTests.java | 455 ++---------------- 11 files changed, 138 insertions(+), 876 deletions(-) create mode 100644 docs/changelog/102183.yaml rename x-pack/plugin/esql/src/main/java/generated/org/elasticsearch/xpack/esql/expression/function/scalar/math/{PowDoubleEvaluator.java => PowEvaluator.java} (89%) delete mode 100644 x-pack/plugin/esql/src/main/java/generated/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowIntEvaluator.java delete mode 100644 x-pack/plugin/esql/src/main/java/generated/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowLongEvaluator.java diff --git a/docs/changelog/102183.yaml b/docs/changelog/102183.yaml new file mode 100644 index 0000000000000..3daa1418ba5d0 --- /dev/null +++ b/docs/changelog/102183.yaml @@ -0,0 +1,13 @@ +pr: 102183 +summary: "[ES|QL] pow function always returns double" +area: ES|QL +type: "breaking" +issues: + - 99055 +breaking: + title: "[ES|QL] pow function always returns double" + area: REST API + details: "In ES|QL, the pow function no longer returns the type of its inputs, instead\ + \ always returning a double." + impact: low. Most queries should continue to function with the change. + notable: false diff --git a/docs/reference/esql/functions/pow.asciidoc b/docs/reference/esql/functions/pow.asciidoc index 9f7805bfd3eae..b13151c8cbd76 100644 --- a/docs/reference/esql/functions/pow.asciidoc +++ b/docs/reference/esql/functions/pow.asciidoc @@ -5,7 +5,8 @@ image::esql/functions/signature/pow.svg[Embedded,opts=inline] Returns the value of a base (first argument) raised to the power of an exponent (second argument). -Both arguments must be numeric. +Both arguments must be numeric. The output is always a double. Note that it is still possible to overflow +a double result here; in that case, null will be returned. [source.merge.styled,esql] ---- @@ -16,62 +17,6 @@ include::{esql-specs}/math.csv-spec[tag=powDI] include::{esql-specs}/math.csv-spec[tag=powDI-result] |=== -[discrete] -==== Type rules - -The type of the returned value is determined by the types of the base and exponent. -The following rules are applied to determine the result type: - -* If either of the base or exponent are of a floating point type, the result will be a double -* Otherwise, if either the base or the exponent are 64-bit (long or unsigned long), the result will be a long -* Otherwise, the result will be a 32-bit integer (this covers all other numeric types, including int, short and byte) - -For example, using simple integers as arguments will lead to an integer result: - -[source.merge.styled,esql] ----- -include::{esql-specs}/math.csv-spec[tag=powII] ----- -[%header.monospaced.styled,format=dsv,separator=|] -|=== -include::{esql-specs}/math.csv-spec[tag=powII-result] -|=== - -NOTE: The actual power function is performed using double precision values for all cases. -This means that for very large non-floating point values there is a small chance that the -operation can lead to slightly different answers than expected. -However, a more likely outcome of very large non-floating point values is numerical overflow. - -[discrete] -==== Arithmetic errors - -Arithmetic errors and numeric overflow do not result in an error. Instead, the result will be `null` -and a warning for the `ArithmeticException` added. -For example: - -[source.merge.styled,esql] ----- -include::{esql-specs}/math.csv-spec[tag=powULOverrun] ----- -[%header.monospaced.styled,format=dsv,separator=|] -|=== -include::{esql-specs}/math.csv-spec[tag=powULOverrun-warning] -|=== -[%header.monospaced.styled,format=dsv,separator=|] -|=== -include::{esql-specs}/math.csv-spec[tag=powULOverrun-result] -|=== - -If it is desired to protect against numerical overruns, use `TO_DOUBLE` on either of the arguments: - -[source.merge.styled,esql] ----- -include::{esql-specs}/math.csv-spec[tag=pow2d] ----- -[%header.monospaced.styled,format=dsv,separator=|] -|=== -include::{esql-specs}/math.csv-spec[tag=pow2d-result] -|=== [discrete] ==== Fractional exponents diff --git a/docs/reference/esql/functions/types/pow.asciidoc b/docs/reference/esql/functions/types/pow.asciidoc index 37bddc60c118f..3965732945c50 100644 --- a/docs/reference/esql/functions/types/pow.asciidoc +++ b/docs/reference/esql/functions/types/pow.asciidoc @@ -4,7 +4,7 @@ base | exponent | result double | double | double double | integer | double integer | double | double -integer | integer | integer +integer | integer | double long | double | double -long | integer | long +long | integer | double |=== diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/math.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/math.csv-spec index 70b416a8a9c02..a6e24e9d45289 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/math.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/math.csv-spec @@ -200,10 +200,10 @@ height:double | s:double 1.53 | 0.34 ; -powSalarySquared +powSalarySquared#[skip:-8.11.99,reason:return type changed in 8.12] from employees | eval s = pow(to_long(salary) - 75000, 2) + 10000 | keep salary, s | sort salary desc | limit 4; -salary:integer | s:long +salary:integer | s:double 74999 | 10001 74970 | 10900 74572 | 193184 @@ -328,14 +328,14 @@ base:integer | exponent:double | s:double // end::powID-sqrt-result[] ; -powSqrtNeg +powSqrtNeg#[skip:-8.11.99,reason:return type changed in 8.12] // tag::powNeg-sqrt[] ROW base = -4, exponent = 0.5 | EVAL s = POW(base, exponent) // end::powNeg-sqrt[] ; warning:Line 2:12: evaluation of [POW(base, exponent)] failed, treating result as null. Only first 20 failures recorded. -warning:Line 2:12: java.lang.ArithmeticException: invalid result: pow(-4.0, 0.5) +warning:Line 2:12: java.lang.ArithmeticException: invalid result when computing pow // tag::powNeg-sqrt-result[] base:integer | exponent:double | s:double @@ -356,23 +356,19 @@ base:double | exponent:integer | result:double // end::powDI-result[] ; -powIntInt -// tag::powII[] +powIntInt#[skip:-8.11.99,reason:return type changed in 8.12] ROW base = 2, exponent = 2 | EVAL s = POW(base, exponent) -// end::powII[] ; -// tag::powII-result[] -base:integer | exponent:integer | s:integer -2 | 2 | 4 -// end::powII-result[] +base:integer | exponent:integer | s:double +2 | 2 | 4.0 ; -powIntIntPlusInt +powIntIntPlusInt#[skip:-8.11.99,reason:return type changed in 8.12] row s = 1 + pow(2, 2); -s:integer +s:double 5 ; @@ -383,24 +379,24 @@ s:double 5 ; -powIntUL +powIntUL#[skip:-8.11.99,reason:return type changed in 8.12] row x = pow(1, 9223372036854775808); -x:long +x:double 1 ; -powLongUL +powLongUL#[skip:-8.11.99,reason:return type changed in 8.12] row x = to_long(1) | eval x = pow(x, 9223372036854775808); -x:long +x:double 1 ; -powUnsignedLongUL +powUnsignedLongUL#[skip:-8.11.99,reason:return type changed in 8.12] row x = to_ul(1) | eval x = pow(x, 9223372036854775808); -x:long +x:double 1 ; @@ -411,36 +407,28 @@ x:double 1.0 ; -powIntULOverrun +powIntULOverrun#[skip:-8.11.99,reason:return type changed in 8.12] row x = pow(2, 9223372036854775808); warning:Line 1:9: evaluation of [pow(2, 9223372036854775808)] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:9: java.lang.ArithmeticException: long overflow +warning:Line 1:9: java.lang.ArithmeticException: invalid result when computing pow -x:long +x:double null ; -powULInt +powULInt#[skip:-8.11.99,reason:return type changed in 8.12] row x = pow(to_unsigned_long(9223372036854775807), 1); -x:long +x:double 9223372036854775807 ; -powULIntOverrun -// tag::powULOverrun[] +powULIntOverrun#[skip:-8.11.99,reason:return type changed in 8.12] ROW x = POW(9223372036854775808, 2) -// end::powULOverrun[] ; -// tag::powULOverrun-warning[] -warning:Line 1:9: evaluation of [POW(9223372036854775808, 2)] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:9: java.lang.ArithmeticException: long overflow -// end::powULOverrun-warning[] -// tag::powULOverrun-result[] -x:long -null -// end::powULOverrun-result[] +x:double +8.507059173023462E37 ; powULInt_2d @@ -455,20 +443,18 @@ x:double // end::pow2d-result[] ; -powULLong +powULLong#[skip:-8.11.99,reason:return type changed in 8.12] row x = to_long(10) | eval x = pow(to_unsigned_long(10), x); -x:long +x:double 10000000000 ; -powULLongOverrun +powULLongOverrun#[skip:-8.11.99,reason:return type changed in 8.12] row x = to_long(100) | eval x = pow(to_unsigned_long(10), x); -warning:Line 1:33: evaluation of [pow(to_unsigned_long(10), x)] failed, treating result as null. Only first 20 failures recorded. -warning:Line 1:33: java.lang.ArithmeticException: long overflow -x:long -null +x:double +1.0E100 ; powULDouble diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/show.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/show.csv-spec index d3abef8021f66..f056e6f2d81bf 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/show.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/show.csv-spec @@ -56,7 +56,7 @@ mv_sum |? mv_sum(arg1:?) now |? now() | null |null | null |? | "" | null | false percentile |? percentile(arg1:?, arg2:?) |[arg1, arg2] |[?, ?] |["", ""] |? | "" | [false, false] | false pi |? pi() | null | null | null |? | "" | null | false -pow |"? pow(base:integer|long|double, exponent:integer|double)" |[base, exponent] |["integer|long|double", "integer|double"] |["", ""] |? | "" | [false, false] | false +pow |"? pow(base:integer|unsigned_long|long|double, exponent:integer|unsigned_long|long|double)" |[base, exponent] |["integer|unsigned_long|long|double", "integer|unsigned_long|long|double"] |["", ""] |? | "" | [false, false] | false replace |"? replace(arg1:?, arg2:?, arg3:?)" | [arg1, arg2, arg3] | [?, ?, ?] |["", "", ""] |? | "" | [false, false, false]| false right |"? right(string:keyword, length:integer)" |[string, length] |["keyword", "integer"] |["", ""] |? | "" | [false, false] | false round |? round(arg1:?, arg2:?) |[arg1, arg2] |[?, ?] |["", ""] |? | "" | [false, false] | false @@ -145,7 +145,7 @@ synopsis:keyword ? now() ? percentile(arg1:?, arg2:?) ? pi() -"? pow(base:integer|long|double, exponent:integer|double)" +"? pow(base:integer|unsigned_long|long|double, exponent:integer|unsigned_long|long|double)" "? replace(arg1:?, arg2:?, arg3:?)" "? right(string:keyword, length:integer)" ? round(arg1:?, arg2:?) diff --git a/x-pack/plugin/esql/src/main/java/generated/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowDoubleEvaluator.java b/x-pack/plugin/esql/src/main/java/generated/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowEvaluator.java similarity index 89% rename from x-pack/plugin/esql/src/main/java/generated/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowDoubleEvaluator.java rename to x-pack/plugin/esql/src/main/java/generated/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowEvaluator.java index 104385e0e51ef..12fe99b0ab2be 100644 --- a/x-pack/plugin/esql/src/main/java/generated/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowDoubleEvaluator.java +++ b/x-pack/plugin/esql/src/main/java/generated/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowEvaluator.java @@ -21,7 +21,7 @@ * {@link EvalOperator.ExpressionEvaluator} implementation for {@link Pow}. * This class is generated. Do not edit it. */ -public final class PowDoubleEvaluator implements EvalOperator.ExpressionEvaluator { +public final class PowEvaluator implements EvalOperator.ExpressionEvaluator { private final Warnings warnings; private final EvalOperator.ExpressionEvaluator base; @@ -30,7 +30,7 @@ public final class PowDoubleEvaluator implements EvalOperator.ExpressionEvaluato private final DriverContext driverContext; - public PowDoubleEvaluator(Source source, EvalOperator.ExpressionEvaluator base, + public PowEvaluator(Source source, EvalOperator.ExpressionEvaluator base, EvalOperator.ExpressionEvaluator exponent, DriverContext driverContext) { this.warnings = new Warnings(source); this.base = base; @@ -95,7 +95,7 @@ public DoubleBlock eval(int positionCount, DoubleVector baseVector, DoubleVector @Override public String toString() { - return "PowDoubleEvaluator[" + "base=" + base + ", exponent=" + exponent + "]"; + return "PowEvaluator[" + "base=" + base + ", exponent=" + exponent + "]"; } @Override @@ -118,13 +118,13 @@ public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory base, } @Override - public PowDoubleEvaluator get(DriverContext context) { - return new PowDoubleEvaluator(source, base.get(context), exponent.get(context), context); + public PowEvaluator get(DriverContext context) { + return new PowEvaluator(source, base.get(context), exponent.get(context), context); } @Override public String toString() { - return "PowDoubleEvaluator[" + "base=" + base + ", exponent=" + exponent + "]"; + return "PowEvaluator[" + "base=" + base + ", exponent=" + exponent + "]"; } } } diff --git a/x-pack/plugin/esql/src/main/java/generated/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowIntEvaluator.java b/x-pack/plugin/esql/src/main/java/generated/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowIntEvaluator.java deleted file mode 100644 index f9d0842a1ab74..0000000000000 --- a/x-pack/plugin/esql/src/main/java/generated/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowIntEvaluator.java +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one -// or more contributor license agreements. Licensed under the Elastic License -// 2.0; you may not use this file except in compliance with the Elastic License -// 2.0. -package org.elasticsearch.xpack.esql.expression.function.scalar.math; - -import java.lang.ArithmeticException; -import java.lang.Override; -import java.lang.String; -import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.DoubleBlock; -import org.elasticsearch.compute.data.DoubleVector; -import org.elasticsearch.compute.data.IntBlock; -import org.elasticsearch.compute.data.Page; -import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.compute.operator.EvalOperator; -import org.elasticsearch.core.Releasables; -import org.elasticsearch.xpack.esql.expression.function.Warnings; -import org.elasticsearch.xpack.ql.tree.Source; - -/** - * {@link EvalOperator.ExpressionEvaluator} implementation for {@link Pow}. - * This class is generated. Do not edit it. - */ -public final class PowIntEvaluator implements EvalOperator.ExpressionEvaluator { - private final Warnings warnings; - - private final EvalOperator.ExpressionEvaluator base; - - private final EvalOperator.ExpressionEvaluator exponent; - - private final DriverContext driverContext; - - public PowIntEvaluator(Source source, EvalOperator.ExpressionEvaluator base, - EvalOperator.ExpressionEvaluator exponent, DriverContext driverContext) { - this.warnings = new Warnings(source); - this.base = base; - this.exponent = exponent; - this.driverContext = driverContext; - } - - @Override - public Block.Ref eval(Page page) { - try (Block.Ref baseRef = base.eval(page)) { - DoubleBlock baseBlock = (DoubleBlock) baseRef.block(); - try (Block.Ref exponentRef = exponent.eval(page)) { - DoubleBlock exponentBlock = (DoubleBlock) exponentRef.block(); - DoubleVector baseVector = baseBlock.asVector(); - if (baseVector == null) { - return Block.Ref.floating(eval(page.getPositionCount(), baseBlock, exponentBlock)); - } - DoubleVector exponentVector = exponentBlock.asVector(); - if (exponentVector == null) { - return Block.Ref.floating(eval(page.getPositionCount(), baseBlock, exponentBlock)); - } - return Block.Ref.floating(eval(page.getPositionCount(), baseVector, exponentVector)); - } - } - } - - public IntBlock eval(int positionCount, DoubleBlock baseBlock, DoubleBlock exponentBlock) { - try(IntBlock.Builder result = driverContext.blockFactory().newIntBlockBuilder(positionCount)) { - position: for (int p = 0; p < positionCount; p++) { - if (baseBlock.isNull(p) || baseBlock.getValueCount(p) != 1) { - result.appendNull(); - continue position; - } - if (exponentBlock.isNull(p) || exponentBlock.getValueCount(p) != 1) { - result.appendNull(); - continue position; - } - try { - result.appendInt(Pow.processInt(baseBlock.getDouble(baseBlock.getFirstValueIndex(p)), exponentBlock.getDouble(exponentBlock.getFirstValueIndex(p)))); - } catch (ArithmeticException e) { - warnings.registerException(e); - result.appendNull(); - } - } - return result.build(); - } - } - - public IntBlock eval(int positionCount, DoubleVector baseVector, DoubleVector exponentVector) { - try(IntBlock.Builder result = driverContext.blockFactory().newIntBlockBuilder(positionCount)) { - position: for (int p = 0; p < positionCount; p++) { - try { - result.appendInt(Pow.processInt(baseVector.getDouble(p), exponentVector.getDouble(p))); - } catch (ArithmeticException e) { - warnings.registerException(e); - result.appendNull(); - } - } - return result.build(); - } - } - - @Override - public String toString() { - return "PowIntEvaluator[" + "base=" + base + ", exponent=" + exponent + "]"; - } - - @Override - public void close() { - Releasables.closeExpectNoException(base, exponent); - } - - static class Factory implements EvalOperator.ExpressionEvaluator.Factory { - private final Source source; - - private final EvalOperator.ExpressionEvaluator.Factory base; - - private final EvalOperator.ExpressionEvaluator.Factory exponent; - - public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory base, - EvalOperator.ExpressionEvaluator.Factory exponent) { - this.source = source; - this.base = base; - this.exponent = exponent; - } - - @Override - public PowIntEvaluator get(DriverContext context) { - return new PowIntEvaluator(source, base.get(context), exponent.get(context), context); - } - - @Override - public String toString() { - return "PowIntEvaluator[" + "base=" + base + ", exponent=" + exponent + "]"; - } - } -} diff --git a/x-pack/plugin/esql/src/main/java/generated/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowLongEvaluator.java b/x-pack/plugin/esql/src/main/java/generated/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowLongEvaluator.java deleted file mode 100644 index 1aba4fe7f0655..0000000000000 --- a/x-pack/plugin/esql/src/main/java/generated/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowLongEvaluator.java +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one -// or more contributor license agreements. Licensed under the Elastic License -// 2.0; you may not use this file except in compliance with the Elastic License -// 2.0. -package org.elasticsearch.xpack.esql.expression.function.scalar.math; - -import java.lang.ArithmeticException; -import java.lang.Override; -import java.lang.String; -import org.elasticsearch.compute.data.Block; -import org.elasticsearch.compute.data.DoubleBlock; -import org.elasticsearch.compute.data.DoubleVector; -import org.elasticsearch.compute.data.LongBlock; -import org.elasticsearch.compute.data.Page; -import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.compute.operator.EvalOperator; -import org.elasticsearch.core.Releasables; -import org.elasticsearch.xpack.esql.expression.function.Warnings; -import org.elasticsearch.xpack.ql.tree.Source; - -/** - * {@link EvalOperator.ExpressionEvaluator} implementation for {@link Pow}. - * This class is generated. Do not edit it. - */ -public final class PowLongEvaluator implements EvalOperator.ExpressionEvaluator { - private final Warnings warnings; - - private final EvalOperator.ExpressionEvaluator base; - - private final EvalOperator.ExpressionEvaluator exponent; - - private final DriverContext driverContext; - - public PowLongEvaluator(Source source, EvalOperator.ExpressionEvaluator base, - EvalOperator.ExpressionEvaluator exponent, DriverContext driverContext) { - this.warnings = new Warnings(source); - this.base = base; - this.exponent = exponent; - this.driverContext = driverContext; - } - - @Override - public Block.Ref eval(Page page) { - try (Block.Ref baseRef = base.eval(page)) { - DoubleBlock baseBlock = (DoubleBlock) baseRef.block(); - try (Block.Ref exponentRef = exponent.eval(page)) { - DoubleBlock exponentBlock = (DoubleBlock) exponentRef.block(); - DoubleVector baseVector = baseBlock.asVector(); - if (baseVector == null) { - return Block.Ref.floating(eval(page.getPositionCount(), baseBlock, exponentBlock)); - } - DoubleVector exponentVector = exponentBlock.asVector(); - if (exponentVector == null) { - return Block.Ref.floating(eval(page.getPositionCount(), baseBlock, exponentBlock)); - } - return Block.Ref.floating(eval(page.getPositionCount(), baseVector, exponentVector)); - } - } - } - - public LongBlock eval(int positionCount, DoubleBlock baseBlock, DoubleBlock exponentBlock) { - try(LongBlock.Builder result = driverContext.blockFactory().newLongBlockBuilder(positionCount)) { - position: for (int p = 0; p < positionCount; p++) { - if (baseBlock.isNull(p) || baseBlock.getValueCount(p) != 1) { - result.appendNull(); - continue position; - } - if (exponentBlock.isNull(p) || exponentBlock.getValueCount(p) != 1) { - result.appendNull(); - continue position; - } - try { - result.appendLong(Pow.processLong(baseBlock.getDouble(baseBlock.getFirstValueIndex(p)), exponentBlock.getDouble(exponentBlock.getFirstValueIndex(p)))); - } catch (ArithmeticException e) { - warnings.registerException(e); - result.appendNull(); - } - } - return result.build(); - } - } - - public LongBlock eval(int positionCount, DoubleVector baseVector, DoubleVector exponentVector) { - try(LongBlock.Builder result = driverContext.blockFactory().newLongBlockBuilder(positionCount)) { - position: for (int p = 0; p < positionCount; p++) { - try { - result.appendLong(Pow.processLong(baseVector.getDouble(p), exponentVector.getDouble(p))); - } catch (ArithmeticException e) { - warnings.registerException(e); - result.appendNull(); - } - } - return result.build(); - } - } - - @Override - public String toString() { - return "PowLongEvaluator[" + "base=" + base + ", exponent=" + exponent + "]"; - } - - @Override - public void close() { - Releasables.closeExpectNoException(base, exponent); - } - - static class Factory implements EvalOperator.ExpressionEvaluator.Factory { - private final Source source; - - private final EvalOperator.ExpressionEvaluator.Factory base; - - private final EvalOperator.ExpressionEvaluator.Factory exponent; - - public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory base, - EvalOperator.ExpressionEvaluator.Factory exponent) { - this.source = source; - this.base = base; - this.exponent = exponent; - } - - @Override - public PowLongEvaluator get(DriverContext context) { - return new PowLongEvaluator(source, base.get(context), exponent.get(context), context); - } - - @Override - public String toString() { - return "PowLongEvaluator[" + "base=" + base + ", exponent=" + exponent + "]"; - } - } -} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Pow.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Pow.java index 48db81fefbc98..9e160e7c2f15f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Pow.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Pow.java @@ -25,7 +25,6 @@ import java.util.Objects; import java.util.function.Function; -import static org.elasticsearch.xpack.esql.expression.function.scalar.math.Cast.cast; import static org.elasticsearch.xpack.ql.expression.TypeResolutions.ParamOrdinal.FIRST; import static org.elasticsearch.xpack.ql.expression.TypeResolutions.ParamOrdinal.SECOND; import static org.elasticsearch.xpack.ql.expression.TypeResolutions.isNumeric; @@ -37,13 +36,13 @@ public class Pow extends ScalarFunction implements OptionalArgument, EvaluatorMa public Pow( Source source, - @Param(name = "base", type = { "integer", "long", "double" }) Expression base, - @Param(name = "exponent", type = { "integer", "double" }) Expression exponent + @Param(name = "base", type = { "integer", "unsigned_long", "long", "double" }) Expression base, + @Param(name = "exponent", type = { "integer", "unsigned_long", "long", "double" }) Expression exponent ) { super(source, Arrays.asList(base, exponent)); this.base = base; this.exponent = exponent; - this.dataType = determineDataType(base, exponent); + this.dataType = DataTypes.DOUBLE; } @Override @@ -70,65 +69,19 @@ public Object fold() { return EvaluatorMapper.super.fold(); } - @Evaluator(extraName = "Double", warnExceptions = { ArithmeticException.class }) + @Evaluator(warnExceptions = { ArithmeticException.class }) static double process(double base, double exponent) { return validateAsDouble(base, exponent); } - @Evaluator(extraName = "Long", warnExceptions = { ArithmeticException.class }) - static long processLong(double base, double exponent) { - return exponent == 1 ? validateAsLong(base) : validateAsLong(base, exponent); - } - - @Evaluator(extraName = "Int", warnExceptions = { ArithmeticException.class }) - static int processInt(double base, double exponent) { - return exponent == 1 ? validateAsInt(base) : validateAsInt(base, exponent); - } - private static double validateAsDouble(double base, double exponent) { double result = Math.pow(base, exponent); - if (Double.isNaN(result)) { - throw new ArithmeticException("invalid result: pow(" + base + ", " + exponent + ")"); + if (Double.isNaN(result) || Double.isInfinite(result)) { + throw new ArithmeticException("invalid result when computing pow"); } return result; } - private static long validateAsLong(double base, double exponent) { - double result = Math.pow(base, exponent); - if (Double.isNaN(result)) { - throw new ArithmeticException("invalid result: pow(" + base + ", " + exponent + ")"); - } - return validateAsLong(result); - } - - private static long validateAsLong(double value) { - if (Double.compare(value, Long.MAX_VALUE) > 0) { - throw new ArithmeticException("long overflow"); - } - if (Double.compare(value, Long.MIN_VALUE) < 0) { - throw new ArithmeticException("long overflow"); - } - return (long) value; - } - - private static int validateAsInt(double base, double exponent) { - double result = Math.pow(base, exponent); - if (Double.isNaN(result)) { - throw new ArithmeticException("invalid result: pow(" + base + ", " + exponent + ")"); - } - return validateAsInt(result); - } - - private static int validateAsInt(double value) { - if (Double.compare(value, Integer.MAX_VALUE) > 0) { - throw new ArithmeticException("integer overflow"); - } - if (Double.compare(value, Integer.MIN_VALUE) < 0) { - throw new ArithmeticException("integer overflow"); - } - return (int) value; - } - @Override public final Expression replaceChildren(List newChildren) { return new Pow(source(), newChildren.get(0), newChildren.get(1)); @@ -152,16 +105,6 @@ public DataType dataType() { return dataType; } - private static DataType determineDataType(Expression base, Expression exponent) { - if (base.dataType().isRational() || exponent.dataType().isRational()) { - return DataTypes.DOUBLE; - } - if (base.dataType().size() == Long.BYTES || exponent.dataType().size() == Long.BYTES) { - return DataTypes.LONG; - } - return DataTypes.INTEGER; - } - @Override public ScriptTemplate asScript() { throw new UnsupportedOperationException("functions do not support scripting"); @@ -169,27 +112,9 @@ public ScriptTemplate asScript() { @Override public ExpressionEvaluator.Factory toEvaluator(Function toEvaluator) { - var baseEvaluator = toEvaluator.apply(base); - var exponentEvaluator = toEvaluator.apply(exponent); - if (dataType == DataTypes.DOUBLE) { - return new PowDoubleEvaluator.Factory( - source(), - cast(base.dataType(), DataTypes.DOUBLE, baseEvaluator), - cast(exponent.dataType(), DataTypes.DOUBLE, exponentEvaluator) - ); - } else if (dataType == DataTypes.LONG) { - return new PowLongEvaluator.Factory( - source(), - cast(base.dataType(), DataTypes.DOUBLE, baseEvaluator), - cast(exponent.dataType(), DataTypes.DOUBLE, exponentEvaluator) - ); - } else { - return new PowIntEvaluator.Factory( - source(), - cast(base.dataType(), DataTypes.DOUBLE, baseEvaluator), - cast(exponent.dataType(), DataTypes.DOUBLE, exponentEvaluator) - ); - } + var baseEval = Cast.cast(base.dataType(), DataTypes.DOUBLE, toEvaluator.apply(base)); + var expEval = Cast.cast(exponent.dataType(), DataTypes.DOUBLE, toEvaluator.apply(exponent)); + return new PowEvaluator.Factory(source(), baseEval, expEval); } @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java index e49776db1edea..8603cea9e873c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java @@ -26,7 +26,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.function.DoubleBinaryOperator; +import java.util.function.BinaryOperator; import java.util.function.DoubleFunction; import java.util.function.Function; import java.util.function.IntFunction; @@ -138,27 +138,28 @@ public static List forBinaryCastingToDouble( String name, String lhsName, String rhsName, - DoubleBinaryOperator expected, + BinaryOperator expected, Double lhsMin, Double lhsMax, Double rhsMin, Double rhsMax, List warnings ) { - List suppliers = new ArrayList<>(); - List lhsSuppliers = new ArrayList<>(); - List rhsSuppliers = new ArrayList<>(); - - lhsSuppliers.addAll(intCases(lhsMin.intValue(), lhsMax.intValue())); - lhsSuppliers.addAll(longCases(lhsMin.longValue(), lhsMax.longValue())); - lhsSuppliers.addAll(ulongCases(BigInteger.valueOf((long) Math.ceil(lhsMin)), BigInteger.valueOf((long) Math.floor(lhsMax)))); - lhsSuppliers.addAll(doubleCases(lhsMin, lhsMax)); - - rhsSuppliers.addAll(intCases(rhsMin.intValue(), rhsMax.intValue())); - rhsSuppliers.addAll(longCases(rhsMin.longValue(), rhsMax.longValue())); - rhsSuppliers.addAll(ulongCases(BigInteger.valueOf((long) Math.ceil(rhsMin)), BigInteger.valueOf((long) Math.floor(rhsMax)))); - rhsSuppliers.addAll(doubleCases(rhsMin, rhsMax)); + List lhsSuppliers = castToDoubleSuppliersFromRange(lhsMin, lhsMax); + List rhsSuppliers = castToDoubleSuppliersFromRange(rhsMin, rhsMax); + return forBinaryCastingToDouble(name, lhsName, rhsName, expected, lhsSuppliers, rhsSuppliers, warnings); + } + public static List forBinaryCastingToDouble( + String name, + String lhsName, + String rhsName, + BinaryOperator expected, + List lhsSuppliers, + List rhsSuppliers, + List warnings + ) { + List suppliers = new ArrayList<>(); for (TypedDataSupplier lhsSupplier : lhsSuppliers) { for (TypedDataSupplier rhsSupplier : rhsSuppliers) { String caseName = lhsSupplier.name() + ", " + rhsSupplier.name(); @@ -182,7 +183,7 @@ public static List forBinaryCastingToDouble( List.of(lhsTyped, rhsTyped), name + "[" + lhsName + "=" + lhsEvalName + ", " + rhsName + "=" + rhsEvalName + "]", DataTypes.DOUBLE, - equalTo(expected.applyAsDouble(lhs.doubleValue(), rhs.doubleValue())) + equalTo(expected.apply(lhs.doubleValue(), rhs.doubleValue())) ); for (String warning : warnings) { testCase = testCase.withWarning(warning); @@ -195,6 +196,15 @@ public static List forBinaryCastingToDouble( return suppliers; } + public static List castToDoubleSuppliersFromRange(Double Min, Double Max) { + List suppliers = new ArrayList<>(); + suppliers.addAll(intCases(Min.intValue(), Max.intValue())); + suppliers.addAll(longCases(Min.longValue(), Max.longValue())); + suppliers.addAll(ulongCases(BigInteger.valueOf((long) Math.ceil(Min)), BigInteger.valueOf((long) Math.floor(Max)))); + suppliers.addAll(doubleCases(Min, Max)); + return suppliers; + } + /** * Generate positive test cases for a unary function operating on an {@link DataTypes#INTEGER}. */ diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowTests.java index 58f56e54c7245..c8b316d8e6bfb 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowTests.java @@ -10,7 +10,6 @@ import com.carrotsearch.randomizedtesting.annotations.Name; import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; import org.elasticsearch.xpack.esql.expression.function.scalar.AbstractScalarFunctionTestCase; import org.elasticsearch.xpack.ql.expression.Expression; @@ -21,8 +20,6 @@ import java.util.List; import java.util.function.Supplier; -import static org.hamcrest.Matchers.equalTo; - public class PowTests extends AbstractScalarFunctionTestCase { public PowTests(@Name("TestCase") Supplier testCaseSupplier) { this.testCase = testCaseSupplier.get(); @@ -30,415 +27,63 @@ public PowTests(@Name("TestCase") Supplier testCaseSu @ParametersFactory public static Iterable parameters() { - return parameterSuppliersFromTypedData(List.of(new TestCaseSupplier("pow(, )", () -> { - double base = 1 / randomDouble(); - int exponent = between(-30, 30); - return new TestCaseSupplier.TestCase( + // Positive number bases + List suppliers = TestCaseSupplier.forBinaryCastingToDouble( + "PowEvaluator", + "base", + "exponent", + Math::pow, + // 143^143 is still representable, but 144^144 is infinite + 1d, + 143d, + -143d, + 143d, + List.of() + ); + // Anything to 0 is 1 + suppliers.addAll( + TestCaseSupplier.forBinaryCastingToDouble( + "PowEvaluator", + "base", + "exponent", + (b, e) -> 1d, + // 143^143 is still representable, but 144^144 is infinite + TestCaseSupplier.castToDoubleSuppliersFromRange(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY), List.of( - new TestCaseSupplier.TypedData(base, DataTypes.DOUBLE, "arg"), - new TestCaseSupplier.TypedData(exponent, DataTypes.INTEGER, "exp") + new TestCaseSupplier.TypedDataSupplier("<0 double>", () -> 0d, DataTypes.DOUBLE), + new TestCaseSupplier.TypedDataSupplier("<-0 double>", () -> -0d, DataTypes.DOUBLE) ), - "PowDoubleEvaluator[base=Attribute[channel=0], exponent=CastIntToDoubleEvaluator[v=Attribute[channel=1]]]", - DataTypes.DOUBLE, - equalTo(Math.pow(base, exponent)) - ); - }), - new TestCaseSupplier( - "pow(NaN, 1)", - () -> new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(Double.NaN, DataTypes.DOUBLE, "base"), - new TestCaseSupplier.TypedData(1.0d, DataTypes.DOUBLE, "exp") - ), - "PowDoubleEvaluator[base=Attribute[channel=0], exponent=Attribute[channel=1]]", - DataTypes.DOUBLE, - equalTo(null) - ).withWarning("Line -1:-1: evaluation of [] failed, treating result as null. Only first 20 failures recorded.") - .withWarning("Line -1:-1: java.lang.ArithmeticException: invalid result: pow(NaN, 1.0)") - ), - new TestCaseSupplier( - "pow(1, NaN)", - () -> new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(1.0d, DataTypes.DOUBLE, "base"), - new TestCaseSupplier.TypedData(Double.NaN, DataTypes.DOUBLE, "exp") - ), - "PowDoubleEvaluator[base=Attribute[channel=0], exponent=Attribute[channel=1]]", - DataTypes.DOUBLE, - equalTo(null) - ).withWarning("Line -1:-1: evaluation of [] failed, treating result as null. Only first 20 failures recorded.") - .withWarning("Line -1:-1: java.lang.ArithmeticException: invalid result: pow(1.0, NaN)") - ), - new TestCaseSupplier( - "pow(NaN, 0)", - () -> new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(Double.NaN, DataTypes.DOUBLE, "base"), - new TestCaseSupplier.TypedData(0d, DataTypes.DOUBLE, "exp") - ), - "PowDoubleEvaluator[base=Attribute[channel=0], exponent=Attribute[channel=1]]", - DataTypes.DOUBLE, - equalTo(1d) - ) - ), - new TestCaseSupplier( - "pow(0, 0)", - () -> new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(0d, DataTypes.DOUBLE, "base"), - new TestCaseSupplier.TypedData(0d, DataTypes.DOUBLE, "exp") - ), - "PowDoubleEvaluator[base=Attribute[channel=0], exponent=Attribute[channel=1]]", - DataTypes.DOUBLE, - equalTo(1d) - ) - ), - new TestCaseSupplier( - "pow(1, 1)", - () -> new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(1, DataTypes.INTEGER, "base"), - new TestCaseSupplier.TypedData(1, DataTypes.INTEGER, "base") - ), - "PowIntEvaluator[base=CastIntToDoubleEvaluator[v=Attribute[channel=0]], " - + "exponent=CastIntToDoubleEvaluator[v=Attribute[channel=1]]]", - DataTypes.INTEGER, - equalTo(1) - ) - ), - new TestCaseSupplier( - "pow(integer, 0)", - () -> new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(randomValueOtherThan(0, ESTestCase::randomInt), DataTypes.INTEGER, "base"), - new TestCaseSupplier.TypedData(0, DataTypes.INTEGER, "exp") - ), - "PowIntEvaluator[base=CastIntToDoubleEvaluator[v=Attribute[channel=0]], " - + "exponent=CastIntToDoubleEvaluator[v=Attribute[channel=1]]]", - DataTypes.INTEGER, - equalTo(1) - ) - ), - new TestCaseSupplier("pow(integer, 2)", () -> { - int base = randomIntBetween(-1000, 1000); - return new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(base, DataTypes.INTEGER, "base"), - new TestCaseSupplier.TypedData(2, DataTypes.INTEGER, "exp") - ), - "PowIntEvaluator[base=CastIntToDoubleEvaluator[v=Attribute[channel=0]], " - + "exponent=CastIntToDoubleEvaluator[v=Attribute[channel=1]]]", - DataTypes.INTEGER, - equalTo((int) Math.pow(base, 2)) - ); - }), - new TestCaseSupplier( - "integer overflow case", - () -> new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(Integer.MAX_VALUE, DataTypes.INTEGER, "base"), - new TestCaseSupplier.TypedData(2, DataTypes.INTEGER, "exp") - ), - "PowIntEvaluator[base=CastIntToDoubleEvaluator[v=Attribute[channel=0]], " - + "exponent=CastIntToDoubleEvaluator[v=Attribute[channel=1]]]", - DataTypes.INTEGER, - equalTo(null) - ).withWarning("Line -1:-1: evaluation of [] failed, treating result as null. Only first 20 failures recorded.") - .withWarning("Line -1:-1: java.lang.ArithmeticException: integer overflow") - ), - new TestCaseSupplier( - "long overflow case", - () -> new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(Long.MAX_VALUE, DataTypes.LONG, "base"), - new TestCaseSupplier.TypedData(2, DataTypes.INTEGER, "exp") - ), - "PowLongEvaluator[base=CastLongToDoubleEvaluator[v=Attribute[channel=0]], " - + "exponent=CastIntToDoubleEvaluator[v=Attribute[channel=1]]]", - DataTypes.LONG, - equalTo(null) - ).withWarning("Line -1:-1: evaluation of [] failed, treating result as null. Only first 20 failures recorded.") - .withWarning("Line -1:-1: java.lang.ArithmeticException: long overflow") - ), - new TestCaseSupplier( - "pow(2, 0.5) == sqrt(2)", - () -> new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(2, DataTypes.INTEGER, "base"), - new TestCaseSupplier.TypedData(0.5, DataTypes.DOUBLE, "exp") - ), - "PowDoubleEvaluator[base=CastIntToDoubleEvaluator[v=Attribute[channel=0]], exponent=Attribute[channel=1]]", - DataTypes.DOUBLE, - equalTo(Math.sqrt(2)) - ) - ), - new TestCaseSupplier( - "pow(2.0, 0.5) == sqrt(2)", - () -> new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(2d, DataTypes.DOUBLE, "base"), - new TestCaseSupplier.TypedData(0.5, DataTypes.DOUBLE, "exp") - ), - "PowDoubleEvaluator[base=Attribute[channel=0], exponent=Attribute[channel=1]]", - DataTypes.DOUBLE, - equalTo(Math.sqrt(2)) - ) - ), - new TestCaseSupplier("pow(integer, double)", () -> { - // Positive numbers to a non-integer power - int base = randomIntBetween(1, 1000); - double exp = randomDoubleBetween(-10.0, 10.0, true); - double expected = Math.pow(base, exp); - TestCaseSupplier.TestCase testCase = new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(base, DataTypes.INTEGER, "base"), - new TestCaseSupplier.TypedData(exp, DataTypes.DOUBLE, "exp") - ), - "PowDoubleEvaluator[base=CastIntToDoubleEvaluator[v=Attribute[channel=0]], exponent=Attribute[channel=1]]", - DataTypes.DOUBLE, - equalTo(expected) - ); - return testCase; - }), - new TestCaseSupplier("fractional power of negative integer is null", () -> { - // Negative numbers to a non-integer power are NaN - int base = randomIntBetween(-1000, -1); - double exp = randomDouble(); // between 0 and 1 - TestCaseSupplier.TestCase testCase = new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(base, DataTypes.INTEGER, "base"), - new TestCaseSupplier.TypedData(exp, DataTypes.DOUBLE, "exp") - ), - "PowDoubleEvaluator[base=CastIntToDoubleEvaluator[v=Attribute[channel=0]], exponent=Attribute[channel=1]]", - DataTypes.DOUBLE, - equalTo(null) - ).withWarning("Line -1:-1: evaluation of [] failed, treating result as null. Only first 20 failures recorded.") - .withWarning("Line -1:-1: java.lang.ArithmeticException: invalid result: pow(" + (double) base + ", " + exp + ")"); - return testCase; - }), - new TestCaseSupplier( - "pow(123, -1)", - () -> new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(123, DataTypes.INTEGER, "base"), - new TestCaseSupplier.TypedData(-1, DataTypes.INTEGER, "exp") - ), - "PowIntEvaluator[base=CastIntToDoubleEvaluator[v=Attribute[channel=0]], " - + "exponent=CastIntToDoubleEvaluator[v=Attribute[channel=1]]]", - DataTypes.INTEGER, - equalTo(0) - ) - ), - new TestCaseSupplier( - "pow(123L, -1)", - () -> new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(123L, DataTypes.LONG, "base"), - new TestCaseSupplier.TypedData(-1, DataTypes.INTEGER, "exp") - ), - "PowLongEvaluator[base=CastLongToDoubleEvaluator[v=Attribute[channel=0]], " - + "exponent=CastIntToDoubleEvaluator[v=Attribute[channel=1]]]", - DataTypes.LONG, - equalTo(0L) - ) - ), - new TestCaseSupplier( - "pow(123D, -1)", - () -> new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(123.0, DataTypes.DOUBLE, "base"), - new TestCaseSupplier.TypedData(-1, DataTypes.INTEGER, "exp") - ), - "PowDoubleEvaluator[base=Attribute[channel=0], exponent=CastIntToDoubleEvaluator[v=Attribute[channel=1]]]", - DataTypes.DOUBLE, - equalTo(1D / 123D) - ) - ), - new TestCaseSupplier("pow(integer, 1)", () -> { - int base = randomInt(); - return new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(base, DataTypes.INTEGER, "base"), - new TestCaseSupplier.TypedData(1, DataTypes.INTEGER, "exp") - ), - "PowIntEvaluator[base=CastIntToDoubleEvaluator[v=Attribute[channel=0]], " - + "exponent=CastIntToDoubleEvaluator[v=Attribute[channel=1]]]", - DataTypes.INTEGER, - equalTo(base) - ); - }), - new TestCaseSupplier( - "pow(1L, 1)", - () -> new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(1L, DataTypes.LONG, "base"), - new TestCaseSupplier.TypedData(1, DataTypes.INTEGER, "exp") - ), - "PowLongEvaluator[base=CastLongToDoubleEvaluator[v=Attribute[channel=0]], " - + "exponent=CastIntToDoubleEvaluator[v=Attribute[channel=1]]]", - DataTypes.LONG, - equalTo(1L) - ) - ), - new TestCaseSupplier("pow(long, 1)", () -> { - // Avoid double precision loss - long base = randomLongBetween(-1L << 51, 1L << 51); - return new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(base, DataTypes.LONG, "base"), - new TestCaseSupplier.TypedData(1, DataTypes.INTEGER, "exp") - ), - "PowLongEvaluator[base=CastLongToDoubleEvaluator[v=Attribute[channel=0]], " - + "exponent=CastIntToDoubleEvaluator[v=Attribute[channel=1]]]", - DataTypes.LONG, - equalTo(base) - ); - }), - new TestCaseSupplier("long-double overflow", () -> { - long base = 4339622345450989181L; // Not exactly representable as a double - long expected = 4339622345450989056L; - return new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(base, DataTypes.LONG, "base"), - new TestCaseSupplier.TypedData(1, DataTypes.INTEGER, "exp") - ), - "PowLongEvaluator[base=CastLongToDoubleEvaluator[v=Attribute[channel=0]], " - + "exponent=CastIntToDoubleEvaluator[v=Attribute[channel=1]]]", - DataTypes.LONG, - equalTo(expected) - ); - }), - new TestCaseSupplier("pow(long, 0)", () -> { - long base = randomLong(); - return new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(base, DataTypes.LONG, "base"), - new TestCaseSupplier.TypedData(0, DataTypes.INTEGER, "exp") - ), - "PowLongEvaluator[base=CastLongToDoubleEvaluator[v=Attribute[channel=0]], " - + "exponent=CastIntToDoubleEvaluator[v=Attribute[channel=1]]]", - DataTypes.LONG, - equalTo(1L) - ); - }), - new TestCaseSupplier("pow(long, 2)", () -> { - long base = randomLongBetween(-1000, 1000); - return new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(base, DataTypes.LONG, "base"), - new TestCaseSupplier.TypedData(2, DataTypes.INTEGER, "exp") - ), - "PowLongEvaluator[base=CastLongToDoubleEvaluator[v=Attribute[channel=0]], " - + "exponent=CastIntToDoubleEvaluator[v=Attribute[channel=1]]]", - DataTypes.LONG, - equalTo((long) Math.pow(base, 2)) - ); - }), - new TestCaseSupplier("pow(long, double)", () -> { - // Negative numbers to non-integer power are NaN - long base = randomLongBetween(1, 1000); - double exp = randomDoubleBetween(-10.0, 10.0, true); - double expected = Math.pow(base, exp); - TestCaseSupplier.TestCase testCase = new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(base, DataTypes.LONG, "base"), - new TestCaseSupplier.TypedData(exp, DataTypes.DOUBLE, "exp") - ), - "PowDoubleEvaluator[base=CastLongToDoubleEvaluator[v=Attribute[channel=0]], exponent=Attribute[channel=1]]", - DataTypes.DOUBLE, - equalTo(expected) - ); - return testCase; - }), - new TestCaseSupplier( - "pow(1D, 1)", - () -> new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(1D, DataTypes.DOUBLE, "base"), - new TestCaseSupplier.TypedData(1, DataTypes.INTEGER, "exp") - ), - "PowDoubleEvaluator[base=Attribute[channel=0], exponent=CastIntToDoubleEvaluator[v=Attribute[channel=1]]]", - DataTypes.DOUBLE, - equalTo(1D) + List.of() + ) + ); + + // Add null cases before the rest of the error cases, so messages are correct. + suppliers = anyNullIsNull(true, suppliers); + + // Overflow should be null + suppliers.addAll( + TestCaseSupplier.forBinaryCastingToDouble( + "PowEvaluator", + "base", + "exponent", + (b, e) -> null, + // 143^143 is still representable, but 144^144 is infinite + 144d, + Double.POSITIVE_INFINITY, + 144d, + Double.POSITIVE_INFINITY, + List.of( + "Line -1:-1: evaluation of [] failed, treating result as null. Only first 20 failures recorded.", + "Line -1:-1: java.lang.ArithmeticException: invalid result when computing pow" ) - ), - new TestCaseSupplier("pow(double, 1)", () -> { - double base; - if (randomBoolean()) { - base = randomDouble(); - } else { - // Sometimes pick a large number - base = 1 / randomDouble(); - } - return new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(base, DataTypes.DOUBLE, "base"), - new TestCaseSupplier.TypedData(1, DataTypes.INTEGER, "exp") - ), - "PowDoubleEvaluator[base=Attribute[channel=0], exponent=CastIntToDoubleEvaluator[v=Attribute[channel=1]]]", - DataTypes.DOUBLE, - equalTo(base) - ); - }), - new TestCaseSupplier("pow(double, 0)", () -> { - double base; - if (randomBoolean()) { - base = randomDouble(); - } else { - // Sometimes pick a large number - base = 1 / randomDouble(); - } - return new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(base, DataTypes.DOUBLE, "base"), - new TestCaseSupplier.TypedData(0, DataTypes.INTEGER, "exp") - ), - "PowDoubleEvaluator[base=Attribute[channel=0], exponent=CastIntToDoubleEvaluator[v=Attribute[channel=1]]]", - DataTypes.DOUBLE, - equalTo(1D) - ); - }), - new TestCaseSupplier("pow(double, 2)", () -> { - double base = randomDoubleBetween(-1000, 1000, true); - return new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(base, DataTypes.DOUBLE, "base"), - new TestCaseSupplier.TypedData(2, DataTypes.INTEGER, "exp") - ), - "PowDoubleEvaluator[base=Attribute[channel=0], exponent=CastIntToDoubleEvaluator[v=Attribute[channel=1]]]", - DataTypes.DOUBLE, - equalTo(Math.pow(base, 2)) - ); - }), - new TestCaseSupplier("pow(double, double)", () -> { - // Negative numbers to a non-integer power are NaN - double base = randomDoubleBetween(0, 1000, true); - double exp = randomDoubleBetween(-10.0, 10.0, true); - TestCaseSupplier.TestCase testCase = new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(base, DataTypes.DOUBLE, "base"), - new TestCaseSupplier.TypedData(exp, DataTypes.DOUBLE, "exp") - ), - "PowDoubleEvaluator[base=Attribute[channel=0], exponent=Attribute[channel=1]]", - DataTypes.DOUBLE, - equalTo(Math.pow(base, exp)) - ); - return testCase; - }) - )); + ) + ); + return parameterSuppliersFromTypedData(errorsForCasesWithoutExamples(suppliers)); } @Override protected DataType expectedType(List argTypes) { - var base = argTypes.get(0); - var exp = argTypes.get(1); - if (base.isRational() || exp.isRational()) { - return DataTypes.DOUBLE; - } else if (base.size() == Long.BYTES || exp.size() == Long.BYTES) { - return DataTypes.LONG; - } else { - return DataTypes.INTEGER; - } + return DataTypes.DOUBLE; } @Override