From 28ac3f42805d540b87a0ff0e61434830809e0282 Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Thu, 2 Feb 2023 15:07:40 +0000 Subject: wire global-phase through config-model exactly like second-phase --- .../main/java/com/yahoo/schema/RankProfile.java | 35 ++++++++++++++++++++++ .../com/yahoo/schema/derived/RawRankProfile.java | 19 +++++++++++- .../yahoo/schema/parser/ConvertParsedRanking.java | 5 ++++ .../processing/OnnxModelConfigGenerator.java | 3 ++ .../processing/RankingExpressionTypeResolver.java | 1 + .../model/container/search/ContainerSearch.java | 34 ++++++++++----------- .../derived/rankingexpression/rank-profiles.cfg | 17 +++++++++++ .../derived/rankingexpression/rankexpression.sd | 16 ++++++++++ 8 files changed, 112 insertions(+), 18 deletions(-) (limited to 'config-model') diff --git a/config-model/src/main/java/com/yahoo/schema/RankProfile.java b/config-model/src/main/java/com/yahoo/schema/RankProfile.java index b6f003b2ce1..c1c0ad4f044 100644 --- a/config-model/src/main/java/com/yahoo/schema/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/RankProfile.java @@ -57,6 +57,7 @@ public class RankProfile implements Cloneable { public final static String FIRST_PHASE = "firstphase"; public final static String SECOND_PHASE = "secondphase"; + public final static String GLOBAL_PHASE = "globalphase"; /** The schema-unique name of this rank profile */ private final String name; @@ -79,9 +80,15 @@ public class RankProfile implements Cloneable { /** The ranking expression to be used for second phase */ private RankingExpressionFunction secondPhaseRanking = null; + /** The ranking expression to be used for global-phase */ + private RankingExpressionFunction globalPhaseRanking = null; + /** Number of hits to be reranked in second phase, -1 means use default */ private int rerankCount = -1; + /** Number of hits to be reranked in global-phase, -1 means use default */ + private int globalPhaseRerankCount = -1; + /** Mysterious attribute */ private int keepRankCount = -1; @@ -509,6 +516,26 @@ public class RankProfile implements Cloneable { } } + public RankingExpression getGlobalPhaseRanking() { + RankingExpressionFunction function = getGlobalPhase(); + if (function == null) return null; + return function.function().getBody(); + } + + public RankingExpressionFunction getGlobalPhase() { + if (globalPhaseRanking != null) return globalPhaseRanking; + return uniquelyInherited(p -> p.getGlobalPhase(), "global-phase expression").orElse(null); + } + + public void setGlobalPhaseRanking(String expression) { + try { + globalPhaseRanking = new RankingExpressionFunction(parseRankingExpression(GLOBAL_PHASE, Collections.emptyList(), expression), false); + } + catch (ParseException e) { + throw new IllegalArgumentException("Illegal global-phase ranking function", e); + } + } + // TODO: Below we have duplicate methods for summary and match features: Encapsulate this in a single parametrized // class instead (and probably make rank features work the same). @@ -667,6 +694,13 @@ public class RankProfile implements Cloneable { return uniquelyInherited(p -> p.getRerankCount(), c -> c >= 0, "rerank-count").orElse(-1); } + public void setGlobalPhaseRerankCount(int count) { this.globalPhaseRerankCount = count; } + + public int getGlobalPhaseRerankCount() { + if (globalPhaseRerankCount >= 0) return globalPhaseRerankCount; + return uniquelyInherited(p -> p.getGlobalPhaseRerankCount(), c -> c >= 0, "global-phase rerank-count").orElse(-1); + } + public void setNumThreadsPerSearch(int numThreads) { this.numThreadsPerSearch = numThreads; } public int getNumThreadsPerSearch() { @@ -966,6 +1000,7 @@ public class RankProfile implements Cloneable { firstPhaseRanking = compile(this.getFirstPhase(), queryProfiles, featureTypes, importedModels, constants(), inlineFunctions, expressionTransforms); secondPhaseRanking = compile(this.getSecondPhase(), queryProfiles, featureTypes, importedModels, constants(), inlineFunctions, expressionTransforms); + globalPhaseRanking = compile(this.getGlobalPhase(), queryProfiles, featureTypes, importedModels, constants(), inlineFunctions, expressionTransforms); // Function compiling second pass: compile all functions and insert previously compiled inline functions // TODO: This merges all functions from inherited profiles too and erases inheritance information. Not good. diff --git a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java index 0231e0e3c3f..31a38752bec 100644 --- a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java @@ -164,6 +164,8 @@ public class RawRankProfile implements RankProfilesConfig.Producer { private RankingExpression firstPhaseRanking; private RankingExpression secondPhaseRanking; + private RankingExpression globalPhaseRanking; + private final int globalPhaseRerankCount; /** * Creates a raw rank profile from the given rank profile @@ -177,10 +179,12 @@ public class RawRankProfile implements RankProfilesConfig.Producer { inputs = compiled.inputs(); firstPhaseRanking = compiled.getFirstPhaseRanking(); secondPhaseRanking = compiled.getSecondPhaseRanking(); + globalPhaseRanking = compiled.getGlobalPhaseRanking(); summaryFeatures = new LinkedHashSet<>(compiled.getSummaryFeatures()); matchFeatures = new LinkedHashSet<>(compiled.getMatchFeatures()); rankFeatures = compiled.getRankFeatures(); rerankCount = compiled.getRerankCount(); + globalPhaseRerankCount = compiled.getGlobalPhaseRerankCount(); matchPhaseSettings = compiled.getMatchPhaseSettings(); numThreadsPerSearch = compiled.getNumThreadsPerSearch(); minHitsPerThread = compiled.getMinHitsPerThread(); @@ -206,7 +210,9 @@ public class RawRankProfile implements RankProfilesConfig.Producer { if (secondPhaseRanking != null) { functionProperties.putAll(secondPhaseRanking.getRankProperties(functionSerializationContext)); } - + if (globalPhaseRanking != null) { + functionProperties.putAll(globalPhaseRanking.getRankProperties(functionSerializationContext)); + } derivePropertiesAndFeaturesFromFunctions(functions, functionProperties, functionSerializationContext); deriveOnnxModelFunctionsAndFeatures(compiled); @@ -360,12 +366,20 @@ public class RawRankProfile implements RankProfilesConfig.Producer { throw new IllegalArgumentException("Could not parse second phase expression", e); } } + else if (RankingExpression.propertyName(RankProfile.GLOBAL_PHASE).equals(property.getName())) { + try { + globalPhaseRanking = new RankingExpression(property.getValue()); + } catch (ParseException e) { + throw new IllegalArgumentException("Could not parse global-phase expression", e); + } + } else { properties.add(new Pair<>(property.getName(), property.getValue())); } } properties.addAll(deriveRankingPhaseRankProperties(firstPhaseRanking, RankProfile.FIRST_PHASE)); properties.addAll(deriveRankingPhaseRankProperties(secondPhaseRanking, RankProfile.SECOND_PHASE)); + properties.addAll(deriveRankingPhaseRankProperties(globalPhaseRanking, RankProfile.GLOBAL_PHASE)); for (FieldRankSettings settings : fieldRankSettings.values()) { properties.addAll(settings.deriveRankProperties()); } @@ -424,6 +438,9 @@ public class RawRankProfile implements RankProfilesConfig.Producer { if (keepRankCount > -1) { properties.add(new Pair<>("vespa.hitcollector.arraysize", keepRankCount + "")); } + if (globalPhaseRerankCount > -1) { + properties.add(new Pair<>("vespa.globalphase.rerankcount", globalPhaseRerankCount + "")); + } if (rankScoreDropLimit > -Double.MAX_VALUE) { properties.add(new Pair<>("vespa.hitcollector.rankscoredroplimit", rankScoreDropLimit + "")); } diff --git a/config-model/src/main/java/com/yahoo/schema/parser/ConvertParsedRanking.java b/config-model/src/main/java/com/yahoo/schema/parser/ConvertParsedRanking.java index bd628779b24..bdecf6332a0 100644 --- a/config-model/src/main/java/com/yahoo/schema/parser/ConvertParsedRanking.java +++ b/config-model/src/main/java/com/yahoo/schema/parser/ConvertParsedRanking.java @@ -84,6 +84,11 @@ public class ConvertParsedRanking { parsed.getSecondPhaseExpression().ifPresent (value -> profile.setSecondPhaseRanking(value)); + parsed.getGlobalPhaseExpression().ifPresent + (value -> profile.setGlobalPhaseRanking(value)); + parsed.getGlobalPhaseRerankCount().ifPresent + (value -> profile.setGlobalPhaseRerankCount(value)); + for (var value : parsed.getMatchFeatures()) { profile.addMatchFeatures(value); } diff --git a/config-model/src/main/java/com/yahoo/schema/processing/OnnxModelConfigGenerator.java b/config-model/src/main/java/com/yahoo/schema/processing/OnnxModelConfigGenerator.java index ce56a4320d3..338977fa679 100644 --- a/config-model/src/main/java/com/yahoo/schema/processing/OnnxModelConfigGenerator.java +++ b/config-model/src/main/java/com/yahoo/schema/processing/OnnxModelConfigGenerator.java @@ -50,6 +50,9 @@ public class OnnxModelConfigGenerator extends Processor { if (profile.getSecondPhaseRanking() != null) { process(profile.getSecondPhaseRanking().getRoot(), profile); } + if (profile.getGlobalPhaseRanking() != null) { + process(profile.getGlobalPhaseRanking().getRoot(), profile); + } for (Map.Entry function : profile.getFunctions().entrySet()) { process(function.getValue().function().getBody().getRoot(), profile); } diff --git a/config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java b/config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java index 871b79a7737..88b304e31c4 100644 --- a/config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java +++ b/config-model/src/main/java/com/yahoo/schema/processing/RankingExpressionTypeResolver.java @@ -86,6 +86,7 @@ public class RankingExpressionTypeResolver extends Processor { profile.getSummaryFeatures().forEach(f -> resolveType(f, "summary feature " + f, context)); ensureValidDouble(profile.getFirstPhaseRanking(), "first-phase expression", context); ensureValidDouble(profile.getSecondPhaseRanking(), "second-phase expression", context); + ensureValidDouble(profile.getGlobalPhaseRanking(), "global-phase expression", context); if ( ( context.tensorsAreUsed() || profile.isStrict()) && ! context.queryFeaturesNotDeclared().isEmpty() && ! warnedAbout.containsAll(context.queryFeaturesNotDeclared())) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java b/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java index f949f2d5cfc..c738905eecd 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java @@ -30,12 +30,12 @@ import static com.yahoo.vespa.model.container.PlatformBundles.SEARCH_AND_DOCPROC */ public class ContainerSearch extends ContainerSubsystem implements - IndexInfoConfig.Producer, - IlscriptsConfig.Producer, - QrSearchersConfig.Producer, - QueryProfilesConfig.Producer, + IndexInfoConfig.Producer, + IlscriptsConfig.Producer, + QrSearchersConfig.Producer, + QueryProfilesConfig.Producer, SemanticRulesConfig.Producer, - PageTemplatesConfig.Producer, + PageTemplatesConfig.Producer, SchemaInfoConfig.Producer { public static final String QUERY_PROFILE_REGISTRY_CLASS = "com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry"; @@ -132,20 +132,20 @@ public class ContainerSearch extends ContainerSubsystem @Override public void getConfig(QrSearchersConfig.Builder builder) { for (int i = 0; i < searchClusters.size(); i++) { - SearchCluster sys = findClusterWithId(searchClusters, i); - QrSearchersConfig.Searchcluster.Builder scB = new QrSearchersConfig.Searchcluster.Builder(). - name(sys.getClusterName()); - for (SchemaInfo spec : sys.schemas().values()) { - scB.searchdef(spec.fullSchema().getName()); - } - scB.rankprofiles(new QrSearchersConfig.Searchcluster.Rankprofiles.Builder().configid(sys.getConfigId())); - scB.indexingmode(QrSearchersConfig.Searchcluster.Indexingmode.Enum.valueOf(sys.getIndexingModeName())); - if ( ! (sys instanceof IndexedSearchCluster)) { - scB.storagecluster(new QrSearchersConfig.Searchcluster.Storagecluster.Builder(). - routespec(((StreamingSearchCluster)sys).getStorageRouteSpec())); + SearchCluster sys = findClusterWithId(searchClusters, i); + QrSearchersConfig.Searchcluster.Builder scB = new QrSearchersConfig.Searchcluster.Builder(). + name(sys.getClusterName()); + for (SchemaInfo spec : sys.schemas().values()) { + scB.searchdef(spec.fullSchema().getName()); + } + scB.rankprofiles(new QrSearchersConfig.Searchcluster.Rankprofiles.Builder().configid(sys.getConfigId())); + scB.indexingmode(QrSearchersConfig.Searchcluster.Indexingmode.Enum.valueOf(sys.getIndexingModeName())); + if ( ! (sys instanceof IndexedSearchCluster)) { + scB.storagecluster(new QrSearchersConfig.Searchcluster.Storagecluster.Builder(). + routespec(((StreamingSearchCluster)sys).getStorageRouteSpec())); } builder.searchcluster(scB); - } + } } private static SearchCluster findClusterWithId(List clusters, int index) { diff --git a/config-model/src/test/derived/rankingexpression/rank-profiles.cfg b/config-model/src/test/derived/rankingexpression/rank-profiles.cfg index 9291a690f7c..6ff063c785c 100644 --- a/config-model/src/test/derived/rankingexpression/rank-profiles.cfg +++ b/config-model/src/test/derived/rankingexpression/rank-profiles.cfg @@ -396,3 +396,20 @@ rankprofile[].fef.property[].name "rankingExpression(mybooleanexpression).rankin rankprofile[].fef.property[].value "5.0" rankprofile[].fef.property[].name "vespa.type.attribute.t1" rankprofile[].fef.property[].value "tensor(m{},v[3])" +rankprofile[].name "withglobalphase" +rankprofile[].fef.property[].name "rankingExpression(myplus).rankingScript" +rankprofile[].fef.property[].value "attribute(foo1) + attribute(foo2)" +rankprofile[].fef.property[].name "rankingExpression(mymul).rankingScript" +rankprofile[].fef.property[].value "attribute(t1) * query(fromq)" +rankprofile[].fef.property[].name "rankingExpression(mymul).type" +rankprofile[].fef.property[].value "tensor(m{},v[3])" +rankprofile[].fef.property[].name "vespa.rank.firstphase" +rankprofile[].fef.property[].value "attribute(foo1)" +rankprofile[].fef.property[].name "vespa.rank.globalphase" +rankprofile[].fef.property[].value "rankingExpression(globalphase)" +rankprofile[].fef.property[].name "rankingExpression(globalphase).rankingScript" +rankprofile[].fef.property[].value "rankingExpression(myplus) + reduce(rankingExpression(mymul), sum) + firstPhase" +rankprofile[].fef.property[].name "vespa.globalphase.rerankcount" +rankprofile[].fef.property[].value "42" +rankprofile[].fef.property[].name "vespa.type.attribute.t1" +rankprofile[].fef.property[].value "tensor(m{},v[3])" diff --git a/config-model/src/test/derived/rankingexpression/rankexpression.sd b/config-model/src/test/derived/rankingexpression/rankexpression.sd index 7d8c79da5fb..015767e3070 100644 --- a/config-model/src/test/derived/rankingexpression/rankexpression.sd +++ b/config-model/src/test/derived/rankingexpression/rankexpression.sd @@ -369,4 +369,20 @@ schema rankexpression { } } + rank-profile withglobalphase { + function myplus() { + expression: attribute(foo1)+attribute(foo2) + } + function mymul() { + expression: attribute(t1)*query(fromq) + } + first-phase { + expression: attribute(foo1) + } + global-phase { + expression: myplus()+sum(mymul())+firstPhase + rerank-count: 42 + } + } + } -- cgit v1.2.3