Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions changelog/unreleased/SOLR-18245-rerank-cutoff-value.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
title: Optional reRankCutoff field returned in responseHeader. Used to return main sort value(s) from the lowest ranked document eligible for inclusion in rerank.
type: added
authors:
- name: Darren Shaw
nick: shawdm
links:
- name: SOLR-18245
url: https://issues.apache.org/jira/browse/SOLR-18245
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
import org.apache.solr.schema.IndexSchema;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.schema.SortableTextField;
import org.apache.solr.search.AbstractReRankQuery;
import org.apache.solr.search.CursorMark;
import org.apache.solr.search.DocIterator;
import org.apache.solr.search.DocList;
Expand Down Expand Up @@ -1028,6 +1029,8 @@ protected void mergeIds(ResponseBuilder rb, ShardRequest sreq) {
boolean maxHitsTerminatedEarly = false;
long approximateTotalHits = 0;
int failedShardCount = 0;
int failedShardCountForReRankCutoff = 0;
NamedList<Object> reRankCutoffByShard = null;
for (ShardResponse srsp : sreq.responses) {
SolrDocumentList docs = null;
NamedList<?> responseHeader = null;
Expand Down Expand Up @@ -1115,6 +1118,23 @@ protected void mergeIds(ResponseBuilder rb, ShardRequest sreq) {
rb, srsp, "responseHeader", false));
}

final Object shardReRankCutoffValue =
responseHeader.get(AbstractReRankQuery.RERANK_CUTOFF_RESPONSE_HEADER_KEY);
if (shardReRankCutoffValue != null) {
if (reRankCutoffByShard == null) {
reRankCutoffByShard = new SimpleOrderedMap<>();
}
String shard = srsp.getShard();
if (StrUtils.isNullOrEmpty(shard)) {
shard = srsp.getShardAddress();
}
if (StrUtils.isNullOrEmpty(shard)) {
failedShardCountForReRankCutoff += 1;
shard = "unknown_shard_" + failedShardCountForReRankCutoff;
}
reRankCutoffByShard.add(shard, shardReRankCutoffValue);
}

final boolean thisResponseIsPartial;
thisResponseIsPartial =
Boolean.TRUE.equals(
Expand Down Expand Up @@ -1256,6 +1276,12 @@ protected void mergeIds(ResponseBuilder rb, ShardRequest sreq) {
SolrQueryResponse.RESPONSE_HEADER_APPROXIMATE_TOTAL_HITS_KEY, approximateTotalHits);
}
}

if (reRankCutoffByShard != null) {
rb.rsp
.getResponseHeader()
.add(AbstractReRankQuery.RERANK_CUTOFF_BY_SHARD_RESPONSE_HEADER_KEY, reRankCutoffByShard);
}
}

protected void setResultIdsAndResponseDocs(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ public QueryCommand createQueryCommand() {
cmd.setQuery(wrap(getQuery()))
.setFilterList(getFilters())
.setSort(getSortSpec().getSort())
.setSortSchemaFields(getSortSpec().getSchemaFields())
.setOffset(getSortSpec().getOffset())
.setLen(getSortSpec().getCount())
.setFlags(getFieldFlags())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@
import org.apache.solr.request.SolrRequestInfo;

public abstract class AbstractReRankQuery extends RankQuery {
/**
 * Response header key under which {@code ReRankCollector} stores the rerank cutoff value for the
 * current request.
 */
public static final String RERANK_CUTOFF_RESPONSE_HEADER_KEY = "reRankCutoff";
/**
 * Response header key used for the map of per-shard cutoff values (shard name to that shard's
 * {@link #RERANK_CUTOFF_RESPONSE_HEADER_KEY} value) assembled while merging shard responses.
 */
public static final String RERANK_CUTOFF_BY_SHARD_RESPONSE_HEADER_KEY = "reRankCutoffByShard";
/**
 * Request context key holding a {@link Boolean}; the cutoff is only written to the response
 * header when the value stored under this key is {@link Boolean#TRUE}.
 */
public static final String RERANK_CUTOFF_ECHO_REQUEST_CONTEXT_KEY =
"solr.rerank.echoReRankCutoff";

protected Query mainQuery;
protected final int reRankDocs;
protected final Rescorer reRankQueryRescorer;
Expand Down
11 changes: 11 additions & 0 deletions solr/core/src/java/org/apache/solr/search/QueryCommand.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.List;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Sort;
import org.apache.solr.schema.SchemaField;

/**
* A query request command to avoid having to change the method signatures if we want to pass
Expand All @@ -33,6 +34,7 @@ public class QueryCommand {
private boolean isQueryCancellable;
private List<Query> filterList;
private Sort sort;
private List<SchemaField> sortSchemaFields;
private int offset;
private int len;
private int supersetMaxDoc;
Expand Down Expand Up @@ -108,6 +110,15 @@ public QueryCommand setSort(Sort sort) {
return this;
}

/**
 * Returns the schema fields previously supplied via {@link #setSortSchemaFields(List)}.
 *
 * <p>NOTE(review): no default is visible for this field, so this presumably returns {@code null}
 * when the setter was never called; callers should null-check.
 */
public List<SchemaField> getSortSchemaFields() {
return sortSchemaFields;
}

/**
 * Stores the schema fields associated with this command's sort.
 *
 * <p>NOTE(review): consumers look entries up by sort-field position, so the list is presumably
 * expected to be positionally aligned with the fields of {@link #setSort(Sort)}; confirm against
 * the caller that builds the sort spec.
 *
 * @param sortSchemaFields the schema fields to store; the reference is kept as-is (no copy)
 * @return this command, to allow call chaining
 */
public QueryCommand setSortSchemaFields(List<SchemaField> sortSchemaFields) {
this.sortSchemaFields = sortSchemaFields;
return this;
}

public int getOffset() {
return offset;
}
Expand Down
78 changes: 78 additions & 0 deletions solr/core/src/java/org/apache/solr/search/ReRankCollector.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@
import com.carrotsearch.hppc.IntFloatMap;
import com.carrotsearch.hppc.IntIntHashMap;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
Expand All @@ -39,8 +42,10 @@
import org.apache.lucene.search.TopScoreDocCollectorManager;
import org.apache.lucene.util.BytesRef;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.handler.component.QueryElevationComponent;
import org.apache.solr.request.SolrRequestInfo;
import org.apache.solr.schema.SchemaField;

/* A TopDocsCollector used by reranking queries. */
public class ReRankCollector extends TopDocsCollector<ScoreDoc> {
Expand All @@ -52,6 +57,7 @@ public class ReRankCollector extends TopDocsCollector<ScoreDoc> {
private final Set<BytesRef> boostedPriority; // order is the "priority"
private final Rescorer reRankQueryRescorer;
private final Sort sort;
private final List<SchemaField> sortSchemaFields;
private final Query query;
private ReRankScaler reRankScaler;
private ReRankOperator reRankOperator;
Expand Down Expand Up @@ -85,6 +91,7 @@ public ReRankCollector(
this.boostedPriority = boostedPriority;
this.query = cmd.getQuery();
Sort sort = cmd.getSort();
this.sortSchemaFields = cmd.getSortSchemaFields();
int maxDoc = searcher.getIndexReader().maxDoc();
int numHits = Math.min(Math.max(this.reRankDocs, length), maxDoc);
if (sort == null) {
Expand Down Expand Up @@ -132,6 +139,7 @@ public TopDocs topDocs(int start, int howMany) {
}

ScoreDoc[] mainScoreDocs = mainDocs.scoreDocs;
final Object reRankCutoffValue = getReRankCutoffValue(mainScoreDocs);
boolean zeroOutScores = reRankScaler != null && reRankScaler.scaleScores();
IntFloatMap docToOriginalScore = new IntFloatHashMap();
ScoreDoc[] mainScoreDocsClone = deepClone(mainScoreDocs, docToOriginalScore, zeroOutScores);
Expand Down Expand Up @@ -174,6 +182,8 @@ public TopDocs topDocs(int start, int howMany) {
rescoredDocs.scoreDocs, new BoostedComp(boostedDocs, mainDocs.scoreDocs, maxScore));
}

updateResponseHeader(reRankCutoffValue);

if (howMany == rescoredDocs.scoreDocs.length) {
if (reRankScaler != null && reRankScaler.scaleScores()) {
rescoredDocs.scoreDocs =
Expand Down Expand Up @@ -224,6 +234,74 @@ private TopDocs toRescoredDocs(TopDocs topDocs, IntFloatMap originalScores) {
return new TopDocs(topDocs.totalHits, rescoredDocs);
}

/**
 * Extracts the main-query sort value(s) of the last document that falls inside the rerank
 * window, i.e. the document at position {@code min(mainScoreDocs.length, reRankDocs) - 1}.
 *
 * <ul>
 *   <li>When there is no sort, or the cutoff document carries no sort field values, the
 *       document's score is returned (boxed as a {@link Float}).
 *   <li>When there is exactly one sort field value, that single marshalled value is returned.
 *   <li>When there are several, a {@link List} of the marshalled values is returned, in sort
 *       field order.
 * </ul>
 *
 * @param mainScoreDocs the documents collected for the main query, in main-query order
 * @return the cutoff value as described above, or {@code null} when there is no cutoff document
 *     (empty/absent {@code mainScoreDocs} or a rerank window smaller than one document)
 */
private Object getReRankCutoffValue(ScoreDoc[] mainScoreDocs) {
  // Without this guard the index below evaluates to -1 and throws
  // ArrayIndexOutOfBoundsException; the helper should not depend on its caller's pre-checks.
  if (mainScoreDocs == null || mainScoreDocs.length == 0 || reRankDocs < 1) {
    return null;
  }

  final ScoreDoc cutoffDoc = mainScoreDocs[Math.min(mainScoreDocs.length, reRankDocs) - 1];

  // Fall back to the score whenever no per-field sort values are available.
  if (sort == null
      || !(cutoffDoc instanceof FieldDoc fieldDoc)
      || fieldDoc.fields == null
      || fieldDoc.fields.length == 0) {
    return cutoffDoc.score;
  }

  // A single sort field is reported as a bare value rather than a one-element list.
  if (fieldDoc.fields.length == 1) {
    return marshalSortValue(fieldDoc.fields[0], 0);
  }

  final List<Object> cutoffSortValues = new ArrayList<>(fieldDoc.fields.length);
  for (int i = 0; i < fieldDoc.fields.length; i++) {
    cutoffSortValues.add(marshalSortValue(fieldDoc.fields[i], i));
  }
  return cutoffSortValues;
}

/**
 * Converts one raw sort value into the form that is placed in the response.
 *
 * <p>A {@code null} value is passed through unchanged. Otherwise, if a schema field is known for
 * the given sort position, the conversion is delegated to that field's type. Failing that, a
 * {@link BytesRef} is decoded as UTF-8 text and anything else is returned as-is.
 *
 * @param sortValue the raw sort value, may be {@code null}
 * @param index position of the value within the sort; used to look up the matching schema field
 * @return the converted value, or {@code null} if {@code sortValue} was {@code null}
 */
private Object marshalSortValue(Object sortValue, int index) {
  if (sortValue == null) {
    return null;
  }

  final boolean slotKnown = sortSchemaFields != null && index < sortSchemaFields.size();
  final SchemaField fieldForSlot = slotKnown ? sortSchemaFields.get(index) : null;
  if (fieldForSlot != null) {
    return fieldForSlot.getType().marshalSortValue(sortValue);
  }

  // No schema field for this position: only binary terms need decoding.
  return sortValue instanceof BytesRef rawBytes ? rawBytes.utf8ToString() : sortValue;
}

/**
 * Publishes the rerank cutoff value in the current request's response header.
 *
 * <p>Nothing is written unless a request/response pair is available for the current thread, the
 * request context holds {@link Boolean#TRUE} under {@link
 * AbstractReRankQuery#RERANK_CUTOFF_ECHO_REQUEST_CONTEXT_KEY}, and the response has a header. An
 * existing entry for the key is overwritten in place; otherwise a new entry is appended.
 *
 * @param reRankCutoffValue the value to store under {@link
 *     AbstractReRankQuery#RERANK_CUTOFF_RESPONSE_HEADER_KEY}
 */
private void updateResponseHeader(Object reRankCutoffValue) {
  final SolrRequestInfo requestInfo = SolrRequestInfo.getRequestInfo();
  if (requestInfo == null) {
    return;
  }
  final var request = requestInfo.getReq();
  if (request == null) {
    return;
  }
  final var response = requestInfo.getRsp();
  if (response == null) {
    return;
  }

  // The echo flag is opt-in; anything other than Boolean.TRUE means "do not report".
  final Object echoFlag =
      request.getContext().get(AbstractReRankQuery.RERANK_CUTOFF_ECHO_REQUEST_CONTEXT_KEY);
  if (!Boolean.TRUE.equals(echoFlag)) {
    return;
  }

  final NamedList<Object> header = response.getResponseHeader();
  if (header == null) {
    return;
  }

  final String headerKey = AbstractReRankQuery.RERANK_CUTOFF_RESPONSE_HEADER_KEY;
  final int slot = header.indexOf(headerKey, 0);
  if (slot < 0) {
    header.add(headerKey, reRankCutoffValue);
    return;
  }
  // Replace rather than append so the header never carries duplicate cutoff entries.
  header.setVal(slot, reRankCutoffValue);
}

private ScoreDoc[] deepClone(
ScoreDoc[] scoreDocs, IntFloatMap originalScoreMap, boolean zeroOut) {
ScoreDoc[] scoreDocs1 = new ScoreDoc[scoreDocs.length];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.rest.ManagedResource;
import org.apache.solr.rest.ManagedResourceObserver;
import org.apache.solr.search.AbstractReRankQuery;
import org.apache.solr.search.QParser;
import org.apache.solr.search.QParserPlugin;
import org.apache.solr.search.SyntaxError;
Expand Down Expand Up @@ -75,6 +76,9 @@ public class LTRQParserPlugin extends QParserPlugin
/** query parser plugin: the param that selects the number of documents to rerank */
public static final String RERANK_DOCS = "reRankDocs";

/** query parser plugin: include the candidate cutoff in the response header */
public static final String ECHO_RERANK_CUTOFF = "echoReRankCutoff";

/** query parser plugin: default interleaving algorithm */
public static final String DEFAULT_INTERLEAVING_ALGORITHM = Interleaving.TEAM_DRAFT;

Expand Down Expand Up @@ -229,6 +233,12 @@ public Query parse() throws SyntaxError {
throw new SolrException(
SolrException.ErrorCode.BAD_REQUEST, "Must rerank at least 1 document");
}

req.getContext()
.put(
AbstractReRankQuery.RERANK_CUTOFF_ECHO_REQUEST_CONTEXT_KEY,
localParams.getBool(ECHO_RERANK_CUTOFF, false));

if (!isInterleaving) {
SolrQueryRequestContextUtils.setScoringQueries(req, new LTRScoringQuery[] {rerankingQuery});
return new LTRQuery(rerankingQuery, reRankDocs);
Expand Down
Loading