diff --git a/changelog/unreleased/SOLR-18245-rerank-cutoff-value.yml b/changelog/unreleased/SOLR-18245-rerank-cutoff-value.yml new file mode 100644 index 000000000000..ad1bf506b78b --- /dev/null +++ b/changelog/unreleased/SOLR-18245-rerank-cutoff-value.yml @@ -0,0 +1,8 @@ +title: Optional reRankCutoff field returned in responseHeader. Used to return main sort value(s) from the lowest ranked document eligible for inclusion in rerank. +type: added +authors: + - name: Darren Shaw + nick: shawdm +links: + - name: SOLR-18245 + url: https://issues.apache.org/jira/browse/SOLR-18245 diff --git a/solr/core/src/java/org/apache/solr/handler/component/QueryComponent.java b/solr/core/src/java/org/apache/solr/handler/component/QueryComponent.java index 9bc219e6d538..3c8c2a7a5bdd 100644 --- a/solr/core/src/java/org/apache/solr/handler/component/QueryComponent.java +++ b/solr/core/src/java/org/apache/solr/handler/component/QueryComponent.java @@ -80,6 +80,7 @@ import org.apache.solr.schema.IndexSchema; import org.apache.solr.schema.SchemaField; import org.apache.solr.schema.SortableTextField; +import org.apache.solr.search.AbstractReRankQuery; import org.apache.solr.search.CursorMark; import org.apache.solr.search.DocIterator; import org.apache.solr.search.DocList; @@ -1028,6 +1029,8 @@ protected void mergeIds(ResponseBuilder rb, ShardRequest sreq) { boolean maxHitsTerminatedEarly = false; long approximateTotalHits = 0; int failedShardCount = 0; + int failedShardCountForReRankCutoff = 0; + NamedList reRankCutoffByShard = null; for (ShardResponse srsp : sreq.responses) { SolrDocumentList docs = null; NamedList responseHeader = null; @@ -1115,6 +1118,23 @@ protected void mergeIds(ResponseBuilder rb, ShardRequest sreq) { rb, srsp, "responseHeader", false)); } + final Object shardReRankCutoffValue = + responseHeader.get(AbstractReRankQuery.RERANK_CUTOFF_RESPONSE_HEADER_KEY); + if (shardReRankCutoffValue != null) { + if (reRankCutoffByShard == null) { + reRankCutoffByShard = new 
SimpleOrderedMap<>(); + } + String shard = srsp.getShard(); + if (StrUtils.isNullOrEmpty(shard)) { + shard = srsp.getShardAddress(); + } + if (StrUtils.isNullOrEmpty(shard)) { + failedShardCountForReRankCutoff += 1; + shard = "unknown_shard_" + failedShardCountForReRankCutoff; + } + reRankCutoffByShard.add(shard, shardReRankCutoffValue); + } + final boolean thisResponseIsPartial; thisResponseIsPartial = Boolean.TRUE.equals( @@ -1256,6 +1276,12 @@ protected void mergeIds(ResponseBuilder rb, ShardRequest sreq) { SolrQueryResponse.RESPONSE_HEADER_APPROXIMATE_TOTAL_HITS_KEY, approximateTotalHits); } } + + if (reRankCutoffByShard != null) { + rb.rsp + .getResponseHeader() + .add(AbstractReRankQuery.RERANK_CUTOFF_BY_SHARD_RESPONSE_HEADER_KEY, reRankCutoffByShard); + } } protected void setResultIdsAndResponseDocs( diff --git a/solr/core/src/java/org/apache/solr/handler/component/ResponseBuilder.java b/solr/core/src/java/org/apache/solr/handler/component/ResponseBuilder.java index 413a962683b8..aa85f9a69b44 100644 --- a/solr/core/src/java/org/apache/solr/handler/component/ResponseBuilder.java +++ b/solr/core/src/java/org/apache/solr/handler/component/ResponseBuilder.java @@ -462,6 +462,7 @@ public QueryCommand createQueryCommand() { cmd.setQuery(wrap(getQuery())) .setFilterList(getFilters()) .setSort(getSortSpec().getSort()) + .setSortSchemaFields(getSortSpec().getSchemaFields()) .setOffset(getSortSpec().getOffset()) .setLen(getSortSpec().getCount()) .setFlags(getFieldFlags()) diff --git a/solr/core/src/java/org/apache/solr/search/AbstractReRankQuery.java b/solr/core/src/java/org/apache/solr/search/AbstractReRankQuery.java index 5c026e557620..2b8f7d879093 100644 --- a/solr/core/src/java/org/apache/solr/search/AbstractReRankQuery.java +++ b/solr/core/src/java/org/apache/solr/search/AbstractReRankQuery.java @@ -34,6 +34,11 @@ import org.apache.solr.request.SolrRequestInfo; public abstract class AbstractReRankQuery extends RankQuery { + public static final String 
RERANK_CUTOFF_RESPONSE_HEADER_KEY = "reRankCutoff"; + public static final String RERANK_CUTOFF_BY_SHARD_RESPONSE_HEADER_KEY = "reRankCutoffByShard"; + public static final String RERANK_CUTOFF_ECHO_REQUEST_CONTEXT_KEY = + "solr.rerank.echoReRankCutoff"; + protected Query mainQuery; protected final int reRankDocs; protected final Rescorer reRankQueryRescorer; diff --git a/solr/core/src/java/org/apache/solr/search/QueryCommand.java b/solr/core/src/java/org/apache/solr/search/QueryCommand.java index 81857a7b568f..8442217464b7 100755 --- a/solr/core/src/java/org/apache/solr/search/QueryCommand.java +++ b/solr/core/src/java/org/apache/solr/search/QueryCommand.java @@ -21,6 +21,7 @@ import java.util.List; import org.apache.lucene.search.Query; import org.apache.lucene.search.Sort; +import org.apache.solr.schema.SchemaField; /** * A query request command to avoid having to change the method signatures if we want to pass @@ -33,6 +34,7 @@ public class QueryCommand { private boolean isQueryCancellable; private List filterList; private Sort sort; + private List sortSchemaFields; private int offset; private int len; private int supersetMaxDoc; @@ -108,6 +110,15 @@ public QueryCommand setSort(Sort sort) { return this; } + public List getSortSchemaFields() { + return sortSchemaFields; + } + + public QueryCommand setSortSchemaFields(List sortSchemaFields) { + this.sortSchemaFields = sortSchemaFields; + return this; + } + public int getOffset() { return offset; } diff --git a/solr/core/src/java/org/apache/solr/search/ReRankCollector.java b/solr/core/src/java/org/apache/solr/search/ReRankCollector.java index ed97f783ae8c..44d89d8ebfd8 100644 --- a/solr/core/src/java/org/apache/solr/search/ReRankCollector.java +++ b/solr/core/src/java/org/apache/solr/search/ReRankCollector.java @@ -20,11 +20,14 @@ import com.carrotsearch.hppc.IntFloatMap; import com.carrotsearch.hppc.IntIntHashMap; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import 
java.util.Comparator; +import java.util.List; import java.util.Map; import java.util.Set; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.Query; @@ -39,8 +42,10 @@ import org.apache.lucene.search.TopScoreDocCollectorManager; import org.apache.lucene.util.BytesRef; import org.apache.solr.common.SolrException; +import org.apache.solr.common.util.NamedList; import org.apache.solr.handler.component.QueryElevationComponent; import org.apache.solr.request.SolrRequestInfo; +import org.apache.solr.schema.SchemaField; /* A TopDocsCollector used by reranking queries. */ public class ReRankCollector extends TopDocsCollector { @@ -52,6 +57,7 @@ public class ReRankCollector extends TopDocsCollector { private final Set boostedPriority; // order is the "priority" private final Rescorer reRankQueryRescorer; private final Sort sort; + private final List sortSchemaFields; private final Query query; private ReRankScaler reRankScaler; private ReRankOperator reRankOperator; @@ -85,6 +91,7 @@ public ReRankCollector( this.boostedPriority = boostedPriority; this.query = cmd.getQuery(); Sort sort = cmd.getSort(); + this.sortSchemaFields = cmd.getSortSchemaFields(); int maxDoc = searcher.getIndexReader().maxDoc(); int numHits = Math.min(Math.max(this.reRankDocs, length), maxDoc); if (sort == null) { @@ -132,6 +139,7 @@ public TopDocs topDocs(int start, int howMany) { } ScoreDoc[] mainScoreDocs = mainDocs.scoreDocs; + final Object reRankCutoffValue = getReRankCutoffValue(mainScoreDocs); boolean zeroOutScores = reRankScaler != null && reRankScaler.scaleScores(); IntFloatMap docToOriginalScore = new IntFloatHashMap(); ScoreDoc[] mainScoreDocsClone = deepClone(mainScoreDocs, docToOriginalScore, zeroOutScores); @@ -174,6 +182,8 @@ public TopDocs topDocs(int start, int howMany) { rescoredDocs.scoreDocs, new 
BoostedComp(boostedDocs, mainDocs.scoreDocs, maxScore)); } + updateResponseHeader(reRankCutoffValue); + if (howMany == rescoredDocs.scoreDocs.length) { if (reRankScaler != null && reRankScaler.scaleScores()) { rescoredDocs.scoreDocs = @@ -224,6 +234,74 @@ private TopDocs toRescoredDocs(TopDocs topDocs, IntFloatMap originalScores) { return new TopDocs(topDocs.totalHits, rescoredDocs); } + private Object getReRankCutoffValue(ScoreDoc[] mainScoreDocs) { + final ScoreDoc cutoffDoc = mainScoreDocs[Math.min(mainScoreDocs.length, reRankDocs) - 1]; + if (sort == null) { + return cutoffDoc.score; + } + + if (!(cutoffDoc instanceof FieldDoc fieldDoc) + || fieldDoc.fields == null + || fieldDoc.fields.length == 0) { + return cutoffDoc.score; + } + + if (fieldDoc.fields.length == 1) { + return marshalSortValue(fieldDoc.fields[0], 0); + } + + final List cutoffSortValues = new ArrayList<>(fieldDoc.fields.length); + for (int i = 0; i < fieldDoc.fields.length; i++) { + cutoffSortValues.add(marshalSortValue(fieldDoc.fields[i], i)); + } + return cutoffSortValues; + } + + private Object marshalSortValue(Object sortValue, int index) { + if (sortValue == null) { + return null; + } + + if (sortSchemaFields != null && index < sortSchemaFields.size()) { + final SchemaField schemaField = sortSchemaFields.get(index); + if (schemaField != null) { + return schemaField.getType().marshalSortValue(sortValue); + } + } + + if (sortValue instanceof BytesRef bytesRef) { + return bytesRef.utf8ToString(); + } + + return sortValue; + } + + private void updateResponseHeader(Object reRankCutoffValue) { + SolrRequestInfo info = SolrRequestInfo.getRequestInfo(); + if (info == null || info.getReq() == null || info.getRsp() == null) { + return; + } + + Map requestContext = info.getReq().getContext(); + if (!Boolean.TRUE.equals( + requestContext.get(AbstractReRankQuery.RERANK_CUTOFF_ECHO_REQUEST_CONTEXT_KEY))) { + return; + } + + final NamedList responseHeader = info.getRsp().getResponseHeader(); + if 
(responseHeader == null) { + return; + } + + final int existingIndex = + responseHeader.indexOf(AbstractReRankQuery.RERANK_CUTOFF_RESPONSE_HEADER_KEY, 0); + if (existingIndex >= 0) { + responseHeader.setVal(existingIndex, reRankCutoffValue); + } else { + responseHeader.add(AbstractReRankQuery.RERANK_CUTOFF_RESPONSE_HEADER_KEY, reRankCutoffValue); + } + } + private ScoreDoc[] deepClone( ScoreDoc[] scoreDocs, IntFloatMap originalScoreMap, boolean zeroOut) { ScoreDoc[] scoreDocs1 = new ScoreDoc[scoreDocs.length]; diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/search/LTRQParserPlugin.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/search/LTRQParserPlugin.java index e3ba6cffe1d9..ee9c54d2d7f9 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/search/LTRQParserPlugin.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/search/LTRQParserPlugin.java @@ -41,6 +41,7 @@ import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.rest.ManagedResource; import org.apache.solr.rest.ManagedResourceObserver; +import org.apache.solr.search.AbstractReRankQuery; import org.apache.solr.search.QParser; import org.apache.solr.search.QParserPlugin; import org.apache.solr.search.SyntaxError; @@ -75,6 +76,9 @@ public class LTRQParserPlugin extends QParserPlugin /** query parser plugin:the param that will select how the number of document to rerank */ public static final String RERANK_DOCS = "reRankDocs"; + /** query parser plugin: include the candidate cutoff in the response header */ + public static final String ECHO_RERANK_CUTOFF = "echoReRankCutoff"; + /** query parser plugin: default interleaving algorithm */ public static final String DEFAULT_INTERLEAVING_ALGORITHM = Interleaving.TEAM_DRAFT; @@ -229,6 +233,12 @@ public Query parse() throws SyntaxError { throw new SolrException( SolrException.ErrorCode.BAD_REQUEST, "Must rerank at least 1 document"); } + + req.getContext() + .put( + AbstractReRankQuery.RERANK_CUTOFF_ECHO_REQUEST_CONTEXT_KEY, + 
localParams.getBool(ECHO_RERANK_CUTOFF, false)); + if (!isInterleaving) { SolrQueryRequestContextUtils.setScoringQueries(req, new LTRScoringQuery[] {rerankingQuery}); return new LTRQuery(rerankingQuery, reRankDocs); diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/AbstractLTRSolrCloudTestBase.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/AbstractLTRSolrCloudTestBase.java new file mode 100644 index 000000000000..1cbba4225268 --- /dev/null +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/AbstractLTRSolrCloudTestBase.java @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.ltr; + +import static java.util.stream.Collectors.toList; + +import java.nio.file.Path; +import java.util.Collections; +import java.util.List; +import java.util.stream.IntStream; +import org.apache.commons.io.file.PathUtils; +import org.apache.solr.client.solrj.request.CollectionAdminRequest; +import org.apache.solr.client.solrj.response.CollectionAdminResponse; +import org.apache.solr.cloud.MiniSolrCloudCluster; +import org.apache.solr.common.SolrInputDocument; +import org.apache.solr.embedded.JettyConfig; +import org.apache.solr.embedded.JettySolrRunner; +import org.apache.solr.ltr.feature.FieldValueFeature; +import org.apache.solr.ltr.feature.OriginalScoreFeature; +import org.apache.solr.ltr.feature.SolrFeature; +import org.apache.solr.ltr.feature.ValueFeature; +import org.apache.solr.ltr.model.LinearModel; +import org.junit.AfterClass; + +public abstract class AbstractLTRSolrCloudTestBase extends TestRerankBase { + + protected MiniSolrCloudCluster solrCluster; + + protected abstract int numberOfShards(); + + protected abstract int numberOfReplicas(); + + @Override + public void setUp() throws Exception { + super.setUp(); + setupTestInit("solrconfig-ltr.xml", "schema.xml", true); + System.setProperty("solr.index.updatelog.enabled", "true"); + + final int numShards = numberOfShards(); + final int numReplicas = numberOfReplicas(); + setupSolrCluster(numShards, numReplicas, numShards * numReplicas); + } + + @Override + public void tearDown() throws Exception { + restTestHarness = null; + if (solrCluster != null) { + solrCluster.shutdown(); + solrCluster = null; + } + super.tearDown(); + } + + protected void setupSolrCluster(int numShards, int numReplicas, int numServers) throws Exception { + solrCluster = new MiniSolrCloudCluster(numServers, tmpSolrHome, JettyConfig.builder().build()); + Path configDir = tmpSolrHome.resolve("collection1/conf"); + solrCluster.uploadConfigSet(configDir, "conf1"); + + createCollection(COLLECTION, "conf1", 
numShards, numReplicas); + indexDocuments(COLLECTION); + for (JettySolrRunner solrRunner : solrCluster.getJettySolrRunners()) { + if (!solrRunner.getCoreContainer().getCores().isEmpty()) { + String coreName = solrRunner.getCoreContainer().getCores().iterator().next().getName(); + restTestHarness = solrRunner.getRestClient(coreName); + break; + } + } + loadModelsAndFeatures(); + } + + private void createCollection(String name, String config, int numShards, int numReplicas) + throws Exception { + CollectionAdminResponse response; + CollectionAdminRequest.Create create = + CollectionAdminRequest.createCollection(name, config, numShards, numReplicas); + response = create.process(solrCluster.getSolrClient()); + + if (response.getStatus() != 0 || response.getErrorMessages() != null) { + fail("Could not create collection. Response" + response); + } + solrCluster.waitForActiveCollection(name, numShards, numShards * numReplicas); + } + + private void indexDocument( + String collection, String id, String title, String description, int popularity) + throws Exception { + SolrInputDocument doc = new SolrInputDocument(); + doc.setField("id", id); + doc.setField("title", title); + doc.setField("description", description); + doc.setField("popularity", popularity); + if (popularity != 1) { + // check that empty values will be read as default + doc.setField("dvIntField", popularity); + doc.setField("dvLongField", popularity); + doc.setField("dvFloatField", ((float) popularity) / 10); + doc.setField("dvDoubleField", ((double) popularity) / 10); + doc.setField("dvStrNumField", popularity % 2 == 0 ? "F" : "T"); + doc.setField("dvStrBoolField", popularity % 2 == 0 ? 
"T" : "F"); + } + solrCluster.getSolrClient().add(collection, doc); + } + + private void indexDocuments(final String collection) throws Exception { + final int collectionSize = 8; + // put documents in random order to check that advanceExact is working correctly + List docIds = IntStream.rangeClosed(1, collectionSize).boxed().collect(toList()); + Collections.shuffle(docIds, random()); + + int docCounter = 1; + for (int docId : docIds) { + final int popularity = docId; + indexDocument(collection, String.valueOf(docId), "a1", "bloom", popularity); + // maybe commit in the middle in order to check that everything works fine for multi-segment + // case + if (docCounter == collectionSize / 2 && random().nextBoolean()) { + solrCluster.getSolrClient().commit(collection); + } + docCounter++; + } + solrCluster.getSolrClient().commit(collection, true, true); + } + + private void loadModelsAndFeatures() throws Exception { + final String featureStore = "test"; + final String[] featureNames = + new String[] { + "powpularityS", + "c3", + "original", + "dvIntFieldFeature", + "dvLongFieldFeature", + "dvFloatFieldFeature", + "dvDoubleFieldFeature", + "dvStrNumFieldFeature", + "dvStrBoolFieldFeature" + }; + final String jsonModelParams = + "{\"weights\":{\"powpularityS\":1.0,\"c3\":1.0,\"original\":0.1," + + "\"dvIntFieldFeature\":0.1,\"dvLongFieldFeature\":0.1," + + "\"dvFloatFieldFeature\":0.1,\"dvDoubleFieldFeature\":0.1,\"dvStrNumFieldFeature\":0.1,\"dvStrBoolFieldFeature\":0.1}}"; + + loadFeature( + featureNames[0], + SolrFeature.class.getName(), + featureStore, + "{\"q\":\"{!func}pow(popularity,2)\"}"); + loadFeature(featureNames[1], ValueFeature.class.getName(), featureStore, "{\"value\":2}"); + loadFeature(featureNames[2], OriginalScoreFeature.class.getName(), featureStore, null); + loadFeature( + featureNames[3], + FieldValueFeature.class.getName(), + featureStore, + "{\"field\":\"dvIntField\"}"); + loadFeature( + featureNames[4], + FieldValueFeature.class.getName(), + 
featureStore, + "{\"field\":\"dvLongField\"}"); + loadFeature( + featureNames[5], + FieldValueFeature.class.getName(), + featureStore, + "{\"field\":\"dvFloatField\"}"); + loadFeature( + featureNames[6], + FieldValueFeature.class.getName(), + featureStore, + "{\"field\":\"dvDoubleField\",\"defaultValue\":-4.0}"); + loadFeature( + featureNames[7], + FieldValueFeature.class.getName(), + featureStore, + "{\"field\":\"dvStrNumField\",\"defaultValue\":-5}"); + loadFeature( + featureNames[8], + FieldValueFeature.class.getName(), + featureStore, + "{\"field\":\"dvStrBoolField\"}"); + + loadModel( + "powpularityS-model", + LinearModel.class.getName(), + featureNames, + featureStore, + jsonModelParams); + reloadCollection(COLLECTION); + } + + private void reloadCollection(String collection) throws Exception { + CollectionAdminRequest.Reload reloadRequest = + CollectionAdminRequest.reloadCollection(collection); + CollectionAdminResponse response = reloadRequest.process(solrCluster.getSolrClient()); + assertEquals(0, response.getStatus()); + assertTrue(response.isSuccess()); + } + + @AfterClass + public static void after() throws Exception { + if (null != tmpSolrHome) { + PathUtils.deleteDirectory(tmpSolrHome); + tmpSolrHome = null; + } + } +} diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTROnSolrCloud.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTROnSolrCloud.java index 31a70b9652fc..df7dba8a8306 100644 --- a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTROnSolrCloud.java +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTROnSolrCloud.java @@ -15,56 +15,22 @@ */ package org.apache.solr.ltr; -import static java.util.stream.Collectors.toList; - import com.carrotsearch.randomizedtesting.annotations.ThreadLeakLingering; -import java.nio.file.Path; -import java.util.Collections; -import java.util.List; -import java.util.stream.IntStream; -import org.apache.commons.io.file.PathUtils; -import 
org.apache.solr.client.solrj.request.CollectionAdminRequest; import org.apache.solr.client.solrj.request.SolrQuery; -import org.apache.solr.client.solrj.response.CollectionAdminResponse; import org.apache.solr.client.solrj.response.QueryResponse; -import org.apache.solr.cloud.MiniSolrCloudCluster; -import org.apache.solr.common.SolrInputDocument; -import org.apache.solr.embedded.JettyConfig; -import org.apache.solr.embedded.JettySolrRunner; -import org.apache.solr.ltr.feature.FieldValueFeature; -import org.apache.solr.ltr.feature.OriginalScoreFeature; -import org.apache.solr.ltr.feature.SolrFeature; -import org.apache.solr.ltr.feature.ValueFeature; -import org.apache.solr.ltr.model.LinearModel; -import org.junit.AfterClass; import org.junit.Test; @ThreadLeakLingering(linger = 10) -public class TestLTROnSolrCloud extends TestRerankBase { - - private MiniSolrCloudCluster solrCluster; - String solrconfig = "solrconfig-ltr.xml"; - String schema = "schema.xml"; +public class TestLTROnSolrCloud extends AbstractLTRSolrCloudTestBase { @Override - public void setUp() throws Exception { - super.setUp(); - setupTestInit(solrconfig, schema, true); - System.setProperty("solr.index.updatelog.enabled", "true"); - - int numberOfShards = random().nextInt(4) + 1; - int numberOfReplicas = random().nextInt(2) + 1; - - int numberOfNodes = numberOfShards * numberOfReplicas; - - setupSolrCluster(numberOfShards, numberOfReplicas, numberOfNodes); + protected int numberOfShards() { + return random().nextInt(4) + 1; } @Override - public void tearDown() throws Exception { - restTestHarness = null; - solrCluster.shutdown(); - super.tearDown(); + protected int numberOfReplicas() { + return random().nextInt(2) + 1; } @Test @@ -319,155 +285,4 @@ public void testSimpleQuery() throws Exception { assertEquals("1", queryResponse.getResults().get(7).get("id").toString()); assertEquals(result7_features, queryResponse.getResults().get(7).get("features").toString()); } - - private void 
setupSolrCluster(int numShards, int numReplicas, int numServers) throws Exception { - solrCluster = new MiniSolrCloudCluster(numServers, tmpSolrHome, JettyConfig.builder().build()); - Path configDir = tmpSolrHome.resolve("collection1/conf"); - solrCluster.uploadConfigSet(configDir, "conf1"); - - createCollection(COLLECTION, "conf1", numShards, numReplicas); - indexDocuments(COLLECTION); - for (JettySolrRunner solrRunner : solrCluster.getJettySolrRunners()) { - if (!solrRunner.getCoreContainer().getCores().isEmpty()) { - String coreName = solrRunner.getCoreContainer().getCores().iterator().next().getName(); - restTestHarness = solrRunner.getRestClient(coreName); - break; - } - } - loadModelsAndFeatures(); - } - - private void createCollection(String name, String config, int numShards, int numReplicas) - throws Exception { - CollectionAdminResponse response; - CollectionAdminRequest.Create create = - CollectionAdminRequest.createCollection(name, config, numShards, numReplicas); - response = create.process(solrCluster.getSolrClient()); - - if (response.getStatus() != 0 || response.getErrorMessages() != null) { - fail("Could not create collection. Response" + response); - } - solrCluster.waitForActiveCollection(name, numShards, numShards * numReplicas); - } - - void indexDocument(String collection, String id, String title, String description, int popularity) - throws Exception { - SolrInputDocument doc = new SolrInputDocument(); - doc.setField("id", id); - doc.setField("title", title); - doc.setField("description", description); - doc.setField("popularity", popularity); - if (popularity != 1) { - // check that empty values will be read as default - doc.setField("dvIntField", popularity); - doc.setField("dvLongField", popularity); - doc.setField("dvFloatField", ((float) popularity) / 10); - doc.setField("dvDoubleField", ((double) popularity) / 10); - doc.setField("dvStrNumField", popularity % 2 == 0 ? "F" : "T"); - doc.setField("dvStrBoolField", popularity % 2 == 0 ? 
"T" : "F"); - } - solrCluster.getSolrClient().add(collection, doc); - } - - private void indexDocuments(final String collection) throws Exception { - final int collectionSize = 8; - // put documents in random order to check that advanceExact is working correctly - List docIds = IntStream.rangeClosed(1, collectionSize).boxed().collect(toList()); - Collections.shuffle(docIds, random()); - - int docCounter = 1; - for (int docId : docIds) { - final int popularity = docId; - indexDocument(collection, String.valueOf(docId), "a1", "bloom", popularity); - // maybe commit in the middle in order to check that everything works fine for multi-segment - // case - if (docCounter == collectionSize / 2 && random().nextBoolean()) { - solrCluster.getSolrClient().commit(collection); - } - docCounter++; - } - solrCluster.getSolrClient().commit(collection, true, true); - } - - private void loadModelsAndFeatures() throws Exception { - final String featureStore = "test"; - final String[] featureNames = - new String[] { - "powpularityS", - "c3", - "original", - "dvIntFieldFeature", - "dvLongFieldFeature", - "dvFloatFieldFeature", - "dvDoubleFieldFeature", - "dvStrNumFieldFeature", - "dvStrBoolFieldFeature" - }; - final String jsonModelParams = - "{\"weights\":{\"powpularityS\":1.0,\"c3\":1.0,\"original\":0.1," - + "\"dvIntFieldFeature\":0.1,\"dvLongFieldFeature\":0.1," - + "\"dvFloatFieldFeature\":0.1,\"dvDoubleFieldFeature\":0.1,\"dvStrNumFieldFeature\":0.1,\"dvStrBoolFieldFeature\":0.1}}"; - - loadFeature( - featureNames[0], - SolrFeature.class.getName(), - featureStore, - "{\"q\":\"{!func}pow(popularity,2)\"}"); - loadFeature(featureNames[1], ValueFeature.class.getName(), featureStore, "{\"value\":2}"); - loadFeature(featureNames[2], OriginalScoreFeature.class.getName(), featureStore, null); - loadFeature( - featureNames[3], - FieldValueFeature.class.getName(), - featureStore, - "{\"field\":\"dvIntField\"}"); - loadFeature( - featureNames[4], - FieldValueFeature.class.getName(), - 
featureStore, - "{\"field\":\"dvLongField\"}"); - loadFeature( - featureNames[5], - FieldValueFeature.class.getName(), - featureStore, - "{\"field\":\"dvFloatField\"}"); - loadFeature( - featureNames[6], - FieldValueFeature.class.getName(), - featureStore, - "{\"field\":\"dvDoubleField\",\"defaultValue\":-4.0}"); - loadFeature( - featureNames[7], - FieldValueFeature.class.getName(), - featureStore, - "{\"field\":\"dvStrNumField\",\"defaultValue\":-5}"); - loadFeature( - featureNames[8], - FieldValueFeature.class.getName(), - featureStore, - "{\"field\":\"dvStrBoolField\"}"); - - loadModel( - "powpularityS-model", - LinearModel.class.getName(), - featureNames, - featureStore, - jsonModelParams); - reloadCollection(COLLECTION); - } - - private void reloadCollection(String collection) throws Exception { - CollectionAdminRequest.Reload reloadRequest = - CollectionAdminRequest.reloadCollection(collection); - CollectionAdminResponse response = reloadRequest.process(solrCluster.getSolrClient()); - assertEquals(0, response.getStatus()); - assertTrue(response.isSuccess()); - } - - @AfterClass - public static void after() throws Exception { - if (null != tmpSolrHome) { - PathUtils.deleteDirectory(tmpSolrHome); - tmpSolrHome = null; - } - } } diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTRQParserPlugin.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTRQParserPlugin.java index 8f793845e3a4..8f5f140f4fdf 100644 --- a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTRQParserPlugin.java +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTRQParserPlugin.java @@ -231,4 +231,84 @@ public void ltr_expensiveFeatureRescoringWithinTimeAllowed_shouldReturnRerankedR // original score for the 4th document due to reRankDocs=3 limit "/response/docs/[3]/score==1.0"); } + + @Test + public void ltrEchoReRankCutoff_shouldAddHeaderField() throws Exception { + final String solrQuery = "_query_:{!edismax qf='id' v='8^=10 9^=5 7^=3 6^=1'}"; + final SolrQuery 
query = new SolrQuery(); + query.setQuery(solrQuery); + query.setFields("id", "score"); + query.setRows(2); + query.add("rq", "{!ltr model=6029760550880411648 reRankDocs=3 echoReRankCutoff=true}"); + + assertJQ("/query" + query.toQueryString(), "/responseHeader/reRankCutoff==3.0"); + } + + @Test + public void ltrEchoReRankCutoff_defaultShouldNotAddHeaderField() throws Exception { + final String solrQuery = "_query_:{!edismax qf='id' v='8^=10 9^=5 7^=3 6^=1'}"; + final SolrQuery query = new SolrQuery(); + query.setQuery(solrQuery); + query.setFields("id", "score"); + query.setRows(2); + query.add("rq", "{!ltr model=6029760550880411648 reRankDocs=3}"); + + assertJQ("/query" + query.toQueryString(), "!/responseHeader/reRankCutoff=="); + } + + @Test + public void ltrEchoReRankCutoff_falseShouldNotAddHeaderField() throws Exception { + final String solrQuery = "_query_:{!edismax qf='id' v='8^=10 9^=5 7^=3 6^=1'}"; + final SolrQuery query = new SolrQuery(); + query.setQuery(solrQuery); + query.setFields("id", "score"); + query.setRows(2); + query.add("rq", "{!ltr model=6029760550880411648 reRankDocs=3 echoReRankCutoff=false}"); + + assertJQ("/query" + query.toQueryString(), "!/responseHeader/reRankCutoff=="); + } + + @Test + public void ltrEchoReRankCutoff_withNoResults_shouldNotAddHeaderField() throws Exception { + final SolrQuery query = new SolrQuery(); + query.setQuery("title:title_that_does_not_exist"); + query.add("fl", "*,[fv]"); + query.add("rows", "3"); + query.add("rq", "{!ltr model=6029760550880411648 reRankDocs=3 echoReRankCutoff=true}"); + + assertJQ( + "/query" + query.toQueryString(), + "/response/numFound/==0", + "!/responseHeader/reRankCutoff=="); + } + + @Test + public void ltrEchoReRankCutoff_withFieldSortShouldAddSortCutoffValue() throws Exception { + final String solrQuery = "_query_:{!edismax qf='id' v='8^=10 9^=5 7^=3 6^=1'}"; + final SolrQuery query = new SolrQuery(); + query.setQuery(solrQuery); + query.setFields("id", "score"); + 
query.setRows(2); + query.setSort("id", SolrQuery.ORDER.asc); + query.add("rq", "{!ltr model=6029760550880411648 reRankDocs=3 echoReRankCutoff=true}"); + + assertJQ("/query" + query.toQueryString(), "/responseHeader/reRankCutoff=='8'"); + } + + @Test + public void ltrEchoReRankCutoff_withMixedScoreAndFieldSort_shouldAddBothValuesInOrder() + throws Exception { + final String solrQuery = "_query_:{!edismax qf='id' v='8^=10 9^=5 7^=3 6^=1'}"; + final SolrQuery query = new SolrQuery(); + query.setQuery(solrQuery); + query.setFields("id", "score"); + query.setRows(2); + query.add("sort", "score desc,id asc"); + query.add("rq", "{!ltr model=6029760550880411648 reRankDocs=3 echoReRankCutoff=true}"); + + assertJQ( + "/query" + query.toQueryString(), + "/responseHeader/reRankCutoff/[0]==3.0", + "/responseHeader/reRankCutoff/[1]=='7'"); + } } diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTRReRankCutoffOnSolrCloud.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTRReRankCutoffOnSolrCloud.java new file mode 100644 index 000000000000..71367bef5422 --- /dev/null +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTRReRankCutoffOnSolrCloud.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakLingering; +import java.util.List; +import org.apache.solr.client.solrj.request.QueryRequest; +import org.apache.solr.client.solrj.request.SolrQuery; +import org.apache.solr.client.solrj.response.QueryResponse; +import org.apache.solr.common.util.NamedList; +import org.apache.solr.search.AbstractReRankQuery; +import org.junit.Test; + +@ThreadLeakLingering(linger = 10) +public class TestLTRReRankCutoffOnSolrCloud extends AbstractLTRSolrCloudTestBase { + + private static final int NUM_SHARDS = 3; + private static final int NUM_REPLICAS = 2; + + @Override + protected int numberOfShards() { + return NUM_SHARDS; + } + + @Override + protected int numberOfReplicas() { + return NUM_REPLICAS; + } + + @Test + public void distributedRerankCutoffScore_shouldBeReturnedForShardsWithHits() throws Exception { + final SolrQuery query = newBaseRerankQuery(); + query.add("rq", "{!ltr model=powpularityS-model reRankDocs=3 echoReRankCutoff=true}"); + + final QueryRequest queryRequest = new QueryRequest(query); + queryRequest.setPath("/query"); + final QueryResponse queryResponse = + queryRequest.process(solrCluster.getSolrClient(), COLLECTION); + + @SuppressWarnings("unchecked") + final NamedList perShardCutoff = + (NamedList) + queryResponse + .getResponseHeader() + .get(AbstractReRankQuery.RERANK_CUTOFF_BY_SHARD_RESPONSE_HEADER_KEY); + assertNotNull(perShardCutoff); + + for (int i = 0; i < perShardCutoff.size(); i++) { + assertNotNull(perShardCutoff.getName(i)); + assertNotNull(perShardCutoff.getVal(i)); + assertTrue(perShardCutoff.getVal(i) instanceof List); + assertEquals(2, ((List) perShardCutoff.getVal(i)).size()); + } + } + + @Test + public void distributedRerankCutoffScore_defaultShouldNotBeReturnedPerShard() throws Exception { + final SolrQuery query = 
newBaseRerankQuery(); + query.add("rq", "{!ltr model=powpularityS-model reRankDocs=3}"); + + final QueryRequest queryRequest = new QueryRequest(query); + queryRequest.setPath("/query"); + final QueryResponse queryResponse = + queryRequest.process(solrCluster.getSolrClient(), COLLECTION); + assertNull( + queryResponse + .getResponseHeader() + .get(AbstractReRankQuery.RERANK_CUTOFF_BY_SHARD_RESPONSE_HEADER_KEY)); + } + + @Test + public void distributedRerankCutoffScore_falseShouldNotBeReturnedPerShard() throws Exception { + final SolrQuery query = newBaseRerankQuery(); + query.add("rq", "{!ltr model=powpularityS-model reRankDocs=3 echoReRankCutoff=false}"); + + final QueryRequest queryRequest = new QueryRequest(query); + queryRequest.setPath("/query"); + final QueryResponse queryResponse = + queryRequest.process(solrCluster.getSolrClient(), COLLECTION); + assertNull( + queryResponse + .getResponseHeader() + .get(AbstractReRankQuery.RERANK_CUTOFF_BY_SHARD_RESPONSE_HEADER_KEY)); + } + + private SolrQuery newBaseRerankQuery() { + final SolrQuery query = new SolrQuery("{!func}sub(8,field(popularity))"); + query.setFields("id", "score"); + query.setRows(4); + query.setParam("sort", "score desc,id asc"); + return query; + } +} diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTRWithSort.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTRWithSort.java index af8664aff24f..52647fa5959a 100644 --- a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTRWithSort.java +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTRWithSort.java @@ -95,6 +95,51 @@ public void testRankingSolrSort() throws Exception { assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/score==1.0"); } + @Test + public void ltrEchoReRankCutoff_withFunctionSortShouldAddFunctionValue() throws Exception { + loadFeature( + "powpularityS", SolrFeature.class.getName(), "{\"q\":\"{!func}pow(popularity,2)\"}"); + + loadModel( + "powpularityS-model", + LinearModel.class.getName(), 
+ new String[] {"powpularityS"}, + "{\"weights\":{\"powpularityS\":1.0}}"); + + final SolrQuery query = new SolrQuery(); + query.setQuery("title:a1"); + query.add("fl", "id,score"); + query.add("rows", "4"); + query.add("sort", "pow(popularity,2) desc"); + query.add("rq", "{!ltr model=powpularityS-model reRankDocs=4 echoReRankCutoff=true}"); + + assertJQ("/query" + query.toQueryString(), "/responseHeader/reRankCutoff==25.0"); + } + + @Test + public void ltrEchoReRankCutoff_withMultipleFunctionSortsShouldAddAllValues() throws Exception { + loadFeature( + "powpularityS", SolrFeature.class.getName(), "{\"q\":\"{!func}pow(popularity,2)\"}"); + + loadModel( + "powpularityS-model", + LinearModel.class.getName(), + new String[] {"powpularityS"}, + "{\"weights\":{\"powpularityS\":1.0}}"); + + final SolrQuery query = new SolrQuery(); + query.setQuery("title:a1"); + query.add("fl", "id,score"); + query.add("rows", "4"); + query.add("sort", "pow(popularity,2) desc,sum(popularity,1) desc"); + query.add("rq", "{!ltr model=powpularityS-model reRankDocs=4 echoReRankCutoff=true}"); + + assertJQ( + "/query" + query.toQueryString(), + "/responseHeader/reRankCutoff/[0]==25.0", + "/responseHeader/reRankCutoff/[1]==6.0"); + } + @Test public void interleavingTwoModelsWithSort_shouldInterleave() throws Exception { TeamDraftInterleaving.setRANDOM( diff --git a/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc b/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc index cf13add48d88..b98af2b6fbca 100644 --- a/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc +++ b/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc @@ -518,6 +518,58 @@ The output will include feature values as a comma-separated list, resembling the }} ---- +==== Rerank Cutoff Score +The rerank cutoff score is the first-pass ranking value of the lowest ranked document eligible for reranking. 
It tells you the score or sort value that a document would have needed to beat to be included in the reranking stage. To obtain the rerank cutoff score, add the `echoReRankCutoff` parameter to the `ltr` query. + +[source,text] +---- +http://localhost:8983/solr/techproducts/query?q=test&rq={!ltr model=myModel reRankDocs=100 echoReRankCutoff=true}&fl=id,score +---- + +When set to `true`, Solr adds a `reRankCutoff` entry to the `responseHeader`; the entry is omitted when the query matches no documents. For the default first-pass relevance ranking, the value is the relevance score of the last document eligible for reranking. When an explicit sort is used, `reRankCutoff` reflects the sort value of that last eligible document. +[source,json] +---- +"responseHeader": { + "status": 0, + "QTime": 1, + "reRankCutoff": 12.345 +} +---- + +When multiple sort clauses are used (e.g. `sort=price desc,discount asc`), `reRankCutoff` is a list containing the cutoff value for each sort clause, in the order the clauses are given: + +[source,json] +---- +"responseHeader": { + "status": 0, + "QTime": 1, + "reRankCutoff": [ + 11.99, + 2.20 + ] +} +---- + +For distributed queries, when `echoReRankCutoff` is `true`, Solr returns shard-local cutoff values in `reRankCutoffByShard`: + +[source,json] +---- +"responseHeader": { + "status": 0, + "QTime": 3, + "reRankCutoffByShard": { + "shard1": [ + 3.0, + 2.0 + ], + "shard2": [ + 5.0, + 2.3 + ] + } +} +---- + === Running a Rerank Query and Query Limits Apache Solr allows to define Query Limits to interrupt particularly expensive queries (xref:query-guide:common-query-parameters.adoc#timeallowed-parameter[Time Allowed], xref:query-guide:common-query-parameters.adoc#cpuallowed-parameter[Cpu Allowed]).