// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.prelude.searcher; import com.yahoo.search.Query; import com.yahoo.search.Result; import com.yahoo.search.Searcher; import com.yahoo.processing.request.CompoundName; import com.yahoo.search.result.ErrorMessage; import com.yahoo.search.result.Hit; import com.yahoo.search.result.HitGroup; import com.yahoo.search.searchchain.Execution; import java.util.*; /** *

Groups hits according to sddocname.

* *

For each group, the desired number of hits can be specified.

* * @author tonytv */
public class MultipleResultsSearcher extends Searcher {

    private static final String propertyPrefix = "multipleresultsets.";
    private static final CompoundName additionalHitsFactorName =
            new CompoundName(propertyPrefix + "additionalHitsFactor");
    private static final CompoundName maxTimesRetrieveHeterogeneousHitsName =
            new CompoundName(propertyPrefix + "maxTimesRetrieveHeterogeneousHits");
    private static final CompoundName numHits = new CompoundName(propertyPrefix + "numHits");

    /**
     * Runs the query, then fetches additional hits for every requested document group
     * that has fewer hits than its target, and finally returns the hits partitioned
     * into one hit group per document name.
     *
     * @param query the query; its multipleresultsets.* properties control the grouping
     * @param e the execution used to search and fill
     * @return the partitioned result, or a result containing only an invalid query parameter
     *         error if the multipleresultsets.* properties could not be parsed
     */
    @Override
    public Result search(Query query, Execution e) {
        try {
            Parameters parameters = new Parameters(query);
            query.trace("MultipleResultsSearcher: " + parameters, false, 2);

            HitsRetriever hitsRetriever = new HitsRetriever(query, e, parameters);
            for (DocumentGroup documentGroup : parameters.documentGroups) {
                if (hitsRetriever.numHits(documentGroup) < documentGroup.targetNumberOfDocuments) {
                    hitsRetriever.retrieveMoreHits(documentGroup);
                }
            }
            return hitsRetriever.createMultipleResultSets();
        } catch (ParameterException exception) {
            Result result = new Result(query);
            result.hits().addError(ErrorMessage.createInvalidQueryParameter(exception.msg));
            return result;
        }
    }

    /**
     * Executes the query (repeatedly if needed) and accumulates the hits per document group.
     * Static: it needs no reference to the enclosing searcher.
     */
    private static class HitsRetriever {

        PartitionedResult partitionedResult;

        private int numRetrieveMoreHitsCalls = 0;
        private int nextOffset;
        private final Query query;
        private final Parameters parameters;
        private final int hits;
        private final int offset;
        private final Execution execution;
        private Result initialResult;

        HitsRetriever(Query query, Execution execution, Parameters parameters) throws ParameterException {
            this.offset = query.getOffset();
            this.hits = query.getHits();
            this.nextOffset = query.getOffset() + query.getHits();
            this.query = query;
            this.parameters = parameters;
            this.execution = execution;

            // initialResult is still null while this first retrieveHits() call runs.
            initialResult = retrieveHits();
            partitionedResult = new PartitionedResult(parameters.documentGroups, initialResult);
        }

        /**
         * Fetches more hits for the given group: first by widening the unrestricted query
         * (at most maxTimesRetrieveHeterogeneousHits - 1 times in total across all groups),
         * then by a query restricted to the group's document name.
         */
        void retrieveMoreHits(DocumentGroup documentGroup) {
            if (++numRetrieveMoreHitsCalls < parameters.maxTimesRetrieveHeterogeneousHits) {
                retrieveHeterogenousHits();
                if (numHits(documentGroup) < documentGroup.targetNumberOfDocuments) {
                    retrieveMoreHits(documentGroup);
                }
            } else {
                retrieveRemainingHitsForGroup(documentGroup);
            }
        }

        /** Fetches the next window of unrestricted hits, capped at 1000 hits per call. */
        void retrieveHeterogenousHits() {
            final int maxNumHitsToRetrieve = 1000;
            int numHitsToRetrieve = Math.min((int) (hits * parameters.additionalHitsFactor), maxNumHitsToRetrieve);

            try {
                query.setWindow(nextOffset, numHitsToRetrieve);
                partitionedResult.addHits(retrieveHits());
            } finally {
                restoreWindow();
                nextOffset += numHitsToRetrieve;
            }
        }

        private void restoreWindow() {
            query.setWindow(offset, hits);
        }

        /** Fetches the hits still missing for the group using a query restricted to its document name. */
        void retrieveRemainingHitsForGroup(DocumentGroup documentGroup) {
            // Take a copy: the set returned by getRestrict() is mutated in place below,
            // so keeping only a reference to it would lose the original restrict list.
            Set<String> oldRestrictList = new LinkedHashSet<>(query.getModel().getRestrict());
            try {
                int numMissingHits = documentGroup.targetNumberOfDocuments - numHits(documentGroup);
                int offset = numHits(documentGroup);

                query.getModel().getRestrict().clear();
                query.getModel().getRestrict().add(documentGroup.documentName);
                query.setWindow(offset, numMissingHits);
                partitionedResult.addHits(retrieveHits());
            } finally {
                restoreWindow();
                query.getModel().getRestrict().clear();
                query.getModel().getRestrict().addAll(oldRestrictList);
            }
        }

        int numHits(DocumentGroup documentGroup) {
            return partitionedResult.numHits(documentGroup.documentName);
        }

        /**
         * Empties the initial result's hit list, crops every group to its target size and
         * inserts the ungrouped hits and the groups back into the initial result.
         */
        Result createMultipleResultSets() {
            Iterator<Hit> i = initialResult.hits().iterator();
            while (i.hasNext()) {
                i.next();
                i.remove();
            }

            for (DocumentGroup group : parameters.documentGroups) {
                partitionedResult.cropResultSet(group.documentName, group.targetNumberOfDocuments);
            }

            partitionedResult.insertInto(initialResult.hits());
            return initialResult;
        }

        /**
         * Searches and fills with the current query state. Errors of follow-up results
         * are copied into the initial result, which is the one eventually returned.
         */
        private Result retrieveHits() {
            Result result = execution.search(query);
            // ensure that field sddocname is available
            execution.fill(result); // TODO: Suffices to fill attributes

            // On the first call initialResult is null: the returned result becomes the
            // initial result and already carries its own errors. addError is used on the
            // hit group since the initial result may not have an error hit yet.
            if (initialResult != null && result.hits().getErrorHit() != null) {
                for (ErrorMessage error : result.hits().getErrorHit().errors()) {
                    initialResult.hits().addError(error);
                }
            }
            return result;
        }
    }

    /** Hits partitioned by document name. Assumes that field sddocname is available. */
    private static class PartitionedResult {

        private final Map<String, HitGroup> resultSets = new HashMap<>();
        private final List<Hit> otherHits = new ArrayList<>();

        PartitionedResult(List<DocumentGroup> documentGroups, Result result) throws ParameterException {
            for (DocumentGroup group : documentGroups) {
                addGroup(group);
            }
            addHits(result, true);
        }

        /**
         * Adds each hit of the result to the group matching its sddocname.
         *
         * @param addOtherHits whether hits matching no group are kept (as ungrouped hits) or dropped
         */
        void addHits(Result result, boolean addOtherHits) {
            for (Hit hit : result.hits()) {
                add(hit, addOtherHits);
            }
        }

        /** Adds the hits matching a group, dropping all others. */
        void addHits(Result result) {
            addHits(result, false);
        }

        void add(Hit hit, boolean addOtherHits) {
            String documentName = (String) hit.getField(Hit.SDDOCNAME_FIELD);
            HitGroup resultSet = (documentName == null) ? null : resultSets.get(documentName);
            if (resultSet != null) {
                resultSet.add(hit);
            } else if (addOtherHits) {
                otherHits.add(hit);
            }
        }

        int numHits(String documentName) {
            return resultSets.get(documentName).size();
        }

        /** Adds the ungrouped hits, then every group (given the target's ordering), to the given hit group. */
        void insertInto(HitGroup group) {
            for (Hit hit : otherHits) {
                group.add(hit);
            }
            for (HitGroup resultSet : resultSets.values()) {
                resultSet.copyOrdering(group);
                group.add(resultSet);
            }
        }

        void cropResultSet(String documentName, int numDocuments) {
            resultSets.get(documentName).trim(0, numDocuments);
        }

        private void addGroup(DocumentGroup group) throws ParameterException {
            final String documentName = group.documentName;
            HitGroup previous = resultSets.put(documentName, new HitGroup(documentName) {
                private static final long serialVersionUID = 5732822886080288688L;
            });
            if (previous != null) {
                throw new ParameterException("Document name " + documentName + " mentioned multiple times");
            }
        }
    }

    //examples:
    //multipleresultsets.numhits=music:10,movies:20
    //multipleresultsets.additionalhitsFactor=0.8
    //multipleresultsets.maxtimesretrieveheterogeneoushits=2
    /** The multipleresultsets.* settings read from the query properties. */
    private static class Parameters {

        final List<DocumentGroup> documentGroups = new ArrayList<>();
        double additionalHitsFactor = 0.8;
        int maxTimesRetrieveHeterogeneousHits = 2;

        Parameters(Query query) throws ParameterException {
            readNumHitsSpecification(query);
            readMaxTimesRetrieveHeterogeneousHits(query);
            readAdditionalHitsFactor(query);
        }

        private void readAdditionalHitsFactor(Query query) throws ParameterException {
            String additionalHitsFactorStr = query.properties().getString(additionalHitsFactorName);
            if (additionalHitsFactorStr == null) return;

            try {
                additionalHitsFactor = Double.parseDouble(additionalHitsFactorStr);
            } catch (NumberFormatException e) {
                throw new ParameterException(
                        "Expected floating point number, got '" + additionalHitsFactorStr + "'.");
            }
        }

        private void readMaxTimesRetrieveHeterogeneousHits(Query query) {
            maxTimesRetrieveHeterogeneousHits = query.properties().getInteger(
                    maxTimesRetrieveHeterogeneousHitsName,
                    maxTimesRetrieveHeterogeneousHits);
        }

        private void readNumHitsSpecification(Query query) throws ParameterException {
            //example numHitsSpecification: "music:10,movies:20"
            String numHitsSpecification = query.properties().getString(numHits);
            if (numHitsSpecification == null) return;

            for (String s : numHitsSpecification.split(",")) {
                handleDocumentNameWithNumberOfHits(s);
            }
        }

        @Override
        public String toString() {
            StringBuilder s = new StringBuilder();
            s.append("additionalHitsFactor=").append(additionalHitsFactor)
             .append(", maxTimesRetrieveHeterogeneousHits=").append(maxTimesRetrieveHeterogeneousHits)
             .append(", numHitsSpecification='");
            for (DocumentGroup group : documentGroups) {
                s.append(group.documentName).append(":").append(group.targetNumberOfDocuments).append(", ");
            }
            s.append("'");
            return s.toString();
        }

        //example input: music:10
        private void handleDocumentNameWithNumberOfHits(String s) throws ParameterException {
            String[] documentNameWithNumberOfHits = s.split(":");

            if (documentNameWithNumberOfHits.length != 2) {
                String msg = "Expected a single ':' in '" + s + "'.";
                if (documentNameWithNumberOfHits.length > 2)
                    msg += " Please check for missing commas.";
                throw new ParameterException(msg);
            }

            String documentName = documentNameWithNumberOfHits[0].trim();
            int numberOfHits;
            try {
                numberOfHits = Integer.parseInt(documentNameWithNumberOfHits[1].trim());
            } catch (NumberFormatException e) {
                throw new ParameterException(
                        "Expected an integer but got '" + documentNameWithNumberOfHits[1] + "'");
            }
            numRequestedHits(documentName, numberOfHits);
        }

        private void numRequestedHits(String documentName, int numHits) {
            documentGroups.add(new DocumentGroup(documentName, numHits));
        }
    }

    /** A document name together with the number of hits requested for it. */
    private static class DocumentGroup {

        final String documentName;
        final int targetNumberOfDocuments;

        DocumentGroup(String documentName, int targetNumberOfDocuments) {
            this.documentName = documentName;
            this.targetNumberOfDocuments = targetNumberOfDocuments;
        }
    }

    /** Signals an invalid multipleresultsets.* query property; msg is reported back to the client. */
    @SuppressWarnings("serial")
    private static class ParameterException extends Exception {

        String msg;

        ParameterException(String msg) {
            super(msg);
            this.msg = msg;
        }
    }
}