summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java156
1 files changed, 69 insertions, 87 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
index 41ac1e17d93..6b589a22de5 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
@@ -55,15 +55,18 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
/**
* Creates a raw rank profile from the given rank profile
*/
- public RawRankProfile(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedMlModels importedModels, AttributeFields attributeFields, ModelContext.Properties deployProperties) {
+ public RawRankProfile(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedMlModels importedModels,
+ AttributeFields attributeFields, ModelContext.Properties deployProperties) {
this.name = rankProfile.getName();
- compressedProperties = compress(new Deriver(rankProfile, queryProfiles, importedModels, attributeFields, deployProperties).derive());
+ compressedProperties = compress(new Deriver(rankProfile.compile(queryProfiles, importedModels),
+ attributeFields, deployProperties).derive());
}
/**
* Only for testing
*/
- public RawRankProfile(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedMlModels importedModels, AttributeFields attributeFields) {
+ public RawRankProfile(RankProfile rankProfile, QueryProfileRegistry queryProfiles,
+ ImportedMlModels importedModels, AttributeFields attributeFields) {
this(rankProfile, queryProfiles, importedModels, attributeFields, new TestProperties());
}
@@ -120,61 +123,74 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
private static class Deriver {
- /**
- * The field rank settings of this profile
- */
- private Map<String, FieldRankSettings> fieldRankSettings = new java.util.LinkedHashMap<>();
-
- private final RankProfile rankProfile;
- private RankingExpression firstPhaseRanking = null;
- private RankingExpression secondPhaseRanking = null;
-
- private Set<ReferenceNode> summaryFeatures = new LinkedHashSet<>();
-
- private Set<ReferenceNode> rankFeatures = new LinkedHashSet<>();
-
- private List<RankProfile.RankProperty> rankProperties = new ArrayList<>();
+ private final Map<String, FieldRankSettings> fieldRankSettings = new java.util.LinkedHashMap<>();
+ private final Set<ReferenceNode> summaryFeatures;
+ private final Set<ReferenceNode> rankFeatures;
+ private final List<RankProfile.RankProperty> rankProperties;
/**
* Rank properties for weight settings to make these available to feature executors
*/
- private List<RankProfile.RankProperty> boostAndWeightRankProperties = new ArrayList<>();
-
- private boolean ignoreDefaultRankFeatures = false;
-
- private RankProfile.MatchPhaseSettings matchPhaseSettings = null;
-
- private int rerankCount = -1;
- private int keepRankCount = -1;
- private int numThreadsPerSearch = -1;
- private int minHitsPerThread = -1;
- private int numSearchPartitions = -1;
- private double termwiseLimit = 1.0;
- private double rankScoreDropLimit = -Double.MAX_VALUE;
+ private final List<RankProfile.RankProperty> boostAndWeightRankProperties = new ArrayList<>();
+
+ private final boolean ignoreDefaultRankFeatures;
+ private final RankProfile.MatchPhaseSettings matchPhaseSettings;
+ private final int rerankCount;
+ private final int keepRankCount;
+ private final int numThreadsPerSearch;
+ private final int minHitsPerThread;
+ private final int numSearchPartitions;
+ private final double termwiseLimit;
+ private final double rankScoreDropLimit;
/**
* The rank type definitions used to derive settings for the native rank features
*/
private final NativeRankTypeDefinitionSet nativeRankTypeDefinitions = new NativeRankTypeDefinitionSet("default");
-
private final Map<String, String> attributeTypes;
private final Map<String, String> queryFeatureTypes;
- private final boolean useExternalExpressionFiles;
+ private final Set<String> filterFields = new java.util.LinkedHashSet<>();
- private Set<String> filterFields = new java.util.LinkedHashSet<>();
+ private RankingExpression firstPhaseRanking;
+ private RankingExpression secondPhaseRanking;
/**
* Creates a raw rank profile from the given rank profile
*/
- Deriver(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedMlModels importedModels,
- AttributeFields attributeFields, ModelContext.Properties deployProperties)
+ Deriver(RankProfile compiled, AttributeFields attributeFields, ModelContext.Properties deployProperties)
{
- this.rankProfile = rankProfile;
- RankProfile compiled = rankProfile.compile(queryProfiles, importedModels);
attributeTypes = compiled.getAttributeTypes();
queryFeatureTypes = compiled.getQueryFeatureTypes();
- useExternalExpressionFiles = deployProperties.featureFlags().useExternalRankExpressions();
- deriveRankingFeatures(compiled, deployProperties);
+ firstPhaseRanking = compiled.getFirstPhaseRanking();
+ secondPhaseRanking = compiled.getSecondPhaseRanking();
+ summaryFeatures = new LinkedHashSet<>(compiled.getSummaryFeatures());
+ rankFeatures = compiled.getRankFeatures();
+ rerankCount = compiled.getRerankCount();
+ matchPhaseSettings = compiled.getMatchPhaseSettings();
+ numThreadsPerSearch = compiled.getNumThreadsPerSearch();
+ minHitsPerThread = compiled.getMinHitsPerThread();
+ numSearchPartitions = compiled.getNumSearchPartitions();
+ termwiseLimit = compiled.getTermwiseLimit().orElse(deployProperties.featureFlags().defaultTermwiseLimit());
+ keepRankCount = compiled.getKeepRankCount();
+ rankScoreDropLimit = compiled.getRankScoreDropLimit();
+ ignoreDefaultRankFeatures = compiled.getIgnoreDefaultRankFeatures();
+ rankProperties = new ArrayList<>(compiled.getRankProperties());
+
+ Map<String, RankProfile.RankingExpressionFunction> functions = compiled.getFunctions();
+ List<ExpressionFunction> functionExpressions = functions.values().stream().map(f -> f.function()).collect(Collectors.toList());
+ Map<String, String> functionProperties = new LinkedHashMap<>();
+ SerializationContext functionSerializationContext = new SerializationContext(functionExpressions);
+
+ if (firstPhaseRanking != null) {
+ functionProperties.putAll(firstPhaseRanking.getRankProperties(functionSerializationContext));
+ }
+ if (secondPhaseRanking != null) {
+ functionProperties.putAll(secondPhaseRanking.getRankProperties(functionSerializationContext));
+ }
+
+ derivePropertiesAndSummaryFeaturesFromFunctions(functions, functionProperties, functionSerializationContext);
+ deriveOnnxModelFunctionsAndSummaryFeatures(compiled);
+
deriveRankTypeSetting(compiled, attributeFields);
deriveFilterFields(compiled);
deriveWeightProperties(compiled);
@@ -184,44 +200,16 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
filterFields.addAll(rp.allFilterFields());
}
- private void deriveRankingFeatures(RankProfile rankProfile, ModelContext.Properties deployProperties) {
- firstPhaseRanking = rankProfile.getFirstPhaseRanking();
- secondPhaseRanking = rankProfile.getSecondPhaseRanking();
- summaryFeatures = new LinkedHashSet<>(rankProfile.getSummaryFeatures());
- rankFeatures = rankProfile.getRankFeatures();
- rerankCount = rankProfile.getRerankCount();
- matchPhaseSettings = rankProfile.getMatchPhaseSettings();
- numThreadsPerSearch = rankProfile.getNumThreadsPerSearch();
- minHitsPerThread = rankProfile.getMinHitsPerThread();
- numSearchPartitions = rankProfile.getNumSearchPartitions();
- termwiseLimit = rankProfile.getTermwiseLimit().orElse(deployProperties.featureFlags().defaultTermwiseLimit());
- keepRankCount = rankProfile.getKeepRankCount();
- rankScoreDropLimit = rankProfile.getRankScoreDropLimit();
- ignoreDefaultRankFeatures = rankProfile.getIgnoreDefaultRankFeatures();
- rankProperties = new ArrayList<>(rankProfile.getRankProperties());
- derivePropertiesAndSummaryFeaturesFromFunctions(rankProfile.getFunctions());
- deriveOnnxModelFunctionsAndSummaryFeatures(rankProfile);
- }
-
- private void derivePropertiesAndSummaryFeaturesFromFunctions(Map<String, RankProfile.RankingExpressionFunction> functions) {
+ private void derivePropertiesAndSummaryFeaturesFromFunctions(Map<String, RankProfile.RankingExpressionFunction> functions,
+ Map<String, String> functionProperties,
+ SerializationContext functionContext) {
if (functions.isEmpty()) return;
- List<ExpressionFunction> functionExpressions = functions.values().stream().map(f -> f.function()).collect(Collectors.toList());
- Map<String, String> functionProperties = new LinkedHashMap<>();
-
- if (firstPhaseRanking != null) {
- functionProperties.putAll(firstPhaseRanking.getRankProperties(functionExpressions));
- }
- if (secondPhaseRanking != null) {
- functionProperties.putAll(secondPhaseRanking.getRankProperties(functionExpressions));
- }
-
- SerializationContext context = new SerializationContext(functionExpressions, null, functionProperties);
- replaceFunctionSummaryFeatures(context);
+ replaceFunctionSummaryFeatures(functionContext);
// First phase, second phase and summary features should add all required functions to the context.
// However, we need to add any functions not referenced in those anyway for model-evaluation.
- deriveFunctionProperties(functions, functionExpressions, functionProperties);
+ deriveFunctionProperties(functions, functionProperties, functionContext);
for (Map.Entry<String, String> e : functionProperties.entrySet()) {
rankProperties.add(new RankProfile.RankProperty(e.getKey(), e.getValue()));
@@ -229,15 +217,13 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
}
private void deriveFunctionProperties(Map<String, RankProfile.RankingExpressionFunction> functions,
- List<ExpressionFunction> functionExpressions,
- Map<String, String> functionProperties) {
- SerializationContext context = new SerializationContext(functionExpressions, null, functionProperties);
+ Map<String, String> functionProperties,
+ SerializationContext context) {
for (Map.Entry<String, RankProfile.RankingExpressionFunction> e : functions.entrySet()) {
- if (useExternalExpressionFiles && rankProfile.getExpressionFile(e.getKey()) != null) continue;
String propertyName = RankingExpression.propertyName(e.getKey());
if (context.serializedFunctions().containsKey(propertyName)) continue;
- String expressionString = e.getValue().function().getBody().getRoot().toString(new StringBuilder(), context, null, null).toString();
+ String expressionString = e.getValue().function().getBody().getRoot().toString(context).toString();
context.addFunctionSerialization(propertyName, expressionString);
for (Map.Entry<String, TensorType> argumentType : e.getValue().function().argumentTypes().entrySet())
@@ -259,7 +245,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
ExpressionFunction function = context.getFunction(referenceNode.getName());
if (function != null) {
String propertyName = RankingExpression.propertyName(referenceNode.getName());
- String expressionString = function.getBody().getRoot().toString(new StringBuilder(), context, null, null).toString();
+ String expressionString = function.getBody().getRoot().toString(context).toString();
context.addFunctionSerialization(propertyName, expressionString);
ReferenceNode newReferenceNode = new ReferenceNode("rankingExpression(" + referenceNode.getName() + ")", referenceNode.getArguments().expressions(), referenceNode.getOutput());
functionSummaryFeatures.put(referenceNode.getName(), newReferenceNode);
@@ -355,8 +341,8 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
properties.add(new Pair<>(property.getName(), property.getValue()));
}
}
- properties.addAll(deriveRankingPhaseRankProperties(firstPhaseRanking, rankProfile.getFirstPhaseFile(), RankProfile.FIRST_PHASE));
- properties.addAll(deriveRankingPhaseRankProperties(secondPhaseRanking, rankProfile.getSecondPhaseFile(), RankProfile.SECOND_PHASE));
+ properties.addAll(deriveRankingPhaseRankProperties(firstPhaseRanking, RankProfile.FIRST_PHASE));
+ properties.addAll(deriveRankingPhaseRankProperties(secondPhaseRanking, RankProfile.SECOND_PHASE));
for (FieldRankSettings settings : fieldRankSettings.values()) {
properties.addAll(settings.deriveRankProperties());
}
@@ -408,9 +394,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
if (ignoreDefaultRankFeatures) {
properties.add(new Pair<>("vespa.dump.ignoredefaultfeatures", String.valueOf(true)));
}
- Iterator filterFieldsIterator = filterFields.iterator();
- while (filterFieldsIterator.hasNext()) {
- String fieldName = (String) filterFieldsIterator.next();
+ for (String fieldName : filterFields) {
properties.add(new Pair<>("vespa.isfilterfield." + fieldName, String.valueOf(true)));
}
for (Map.Entry<String, String> attributeType : attributeTypes.entrySet()) {
@@ -423,7 +407,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
return properties;
}
- private List<Pair<String, String>> deriveRankingPhaseRankProperties(RankingExpression expression, String fileName, String phase) {
+ private List<Pair<String, String>> deriveRankingPhaseRankProperties(RankingExpression expression, String phase) {
List<Pair<String, String>> properties = new ArrayList<>();
if (expression == null) return properties;
@@ -431,9 +415,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
if ("".equals(name))
name = phase;
- if (useExternalExpressionFiles && (fileName != null)) {
- properties.add(new Pair<>("vespa.rank." + phase, "rankingExpression(" + rankProfile.getUniqueExpressionName(name) + ")"));
- } else if (expression.getRoot() instanceof ReferenceNode) {
+ if (expression.getRoot() instanceof ReferenceNode) {
properties.add(new Pair<>("vespa.rank." + phase, expression.getRoot().toString()));
} else {
properties.add(new Pair<>("vespa.rank." + phase, "rankingExpression(" + name + ")"));