diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java | 156 |
1 files changed, 69 insertions, 87 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index 41ac1e17d93..6b589a22de5 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -55,15 +55,18 @@ public class RawRankProfile implements RankProfilesConfig.Producer { /** * Creates a raw rank profile from the given rank profile */ - public RawRankProfile(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedMlModels importedModels, AttributeFields attributeFields, ModelContext.Properties deployProperties) { + public RawRankProfile(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedMlModels importedModels, + AttributeFields attributeFields, ModelContext.Properties deployProperties) { this.name = rankProfile.getName(); - compressedProperties = compress(new Deriver(rankProfile, queryProfiles, importedModels, attributeFields, deployProperties).derive()); + compressedProperties = compress(new Deriver(rankProfile.compile(queryProfiles, importedModels), + attributeFields, deployProperties).derive()); } /** * Only for testing */ - public RawRankProfile(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedMlModels importedModels, AttributeFields attributeFields) { + public RawRankProfile(RankProfile rankProfile, QueryProfileRegistry queryProfiles, + ImportedMlModels importedModels, AttributeFields attributeFields) { this(rankProfile, queryProfiles, importedModels, attributeFields, new TestProperties()); } @@ -120,61 +123,74 @@ public class RawRankProfile implements RankProfilesConfig.Producer { private static class Deriver { - /** - * The field rank settings of this profile - */ - private Map<String, FieldRankSettings> fieldRankSettings = new java.util.LinkedHashMap<>(); - - private final RankProfile rankProfile; - private RankingExpression firstPhaseRanking = null; - private RankingExpression secondPhaseRanking = null; - - private Set<ReferenceNode> summaryFeatures = new LinkedHashSet<>(); - - private Set<ReferenceNode> rankFeatures = new LinkedHashSet<>(); - - private List<RankProfile.RankProperty> rankProperties = new ArrayList<>(); + private final Map<String, FieldRankSettings> fieldRankSettings = new java.util.LinkedHashMap<>(); + private final Set<ReferenceNode> summaryFeatures; + private final Set<ReferenceNode> rankFeatures; + private final List<RankProfile.RankProperty> rankProperties; /** * Rank properties for weight settings to make these available to feature executors */ - private List<RankProfile.RankProperty> boostAndWeightRankProperties = new ArrayList<>(); - - private boolean ignoreDefaultRankFeatures = false; - - private RankProfile.MatchPhaseSettings matchPhaseSettings = null; - - private int rerankCount = -1; - private int keepRankCount = -1; - private int numThreadsPerSearch = -1; - private int minHitsPerThread = -1; - private int numSearchPartitions = -1; - private double termwiseLimit = 1.0; - private double rankScoreDropLimit = -Double.MAX_VALUE; + private final List<RankProfile.RankProperty> boostAndWeightRankProperties = new ArrayList<>(); + + private final boolean ignoreDefaultRankFeatures; + private final RankProfile.MatchPhaseSettings matchPhaseSettings; + private final int rerankCount; + private final int keepRankCount; + private final int numThreadsPerSearch; + private final int minHitsPerThread; + private final int numSearchPartitions; + private final double termwiseLimit; + private final double rankScoreDropLimit; /** * The rank type definitions used to derive settings for the native rank features */ private final NativeRankTypeDefinitionSet nativeRankTypeDefinitions = new NativeRankTypeDefinitionSet("default"); - private final Map<String, String> attributeTypes; private final Map<String, String> queryFeatureTypes; - private final boolean useExternalExpressionFiles; + private final Set<String> filterFields = new java.util.LinkedHashSet<>(); - private Set<String> filterFields = new java.util.LinkedHashSet<>(); + private RankingExpression firstPhaseRanking; + private RankingExpression secondPhaseRanking; /** * Creates a raw rank profile from the given rank profile */ - Deriver(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedMlModels importedModels, - AttributeFields attributeFields, ModelContext.Properties deployProperties) + Deriver(RankProfile compiled, AttributeFields attributeFields, ModelContext.Properties deployProperties) { - this.rankProfile = rankProfile; - RankProfile compiled = rankProfile.compile(queryProfiles, importedModels); attributeTypes = compiled.getAttributeTypes(); queryFeatureTypes = compiled.getQueryFeatureTypes(); - useExternalExpressionFiles = deployProperties.featureFlags().useExternalRankExpressions(); - deriveRankingFeatures(compiled, deployProperties); + firstPhaseRanking = compiled.getFirstPhaseRanking(); + secondPhaseRanking = compiled.getSecondPhaseRanking(); + summaryFeatures = new LinkedHashSet<>(compiled.getSummaryFeatures()); + rankFeatures = compiled.getRankFeatures(); + rerankCount = compiled.getRerankCount(); + matchPhaseSettings = compiled.getMatchPhaseSettings(); + numThreadsPerSearch = compiled.getNumThreadsPerSearch(); + minHitsPerThread = compiled.getMinHitsPerThread(); + numSearchPartitions = compiled.getNumSearchPartitions(); + termwiseLimit = compiled.getTermwiseLimit().orElse(deployProperties.featureFlags().defaultTermwiseLimit()); + keepRankCount = compiled.getKeepRankCount(); + rankScoreDropLimit = compiled.getRankScoreDropLimit(); + ignoreDefaultRankFeatures = compiled.getIgnoreDefaultRankFeatures(); + rankProperties = new ArrayList<>(compiled.getRankProperties()); + + Map<String, RankProfile.RankingExpressionFunction> functions = compiled.getFunctions(); + List<ExpressionFunction> functionExpressions = functions.values().stream().map(f -> f.function()).collect(Collectors.toList()); + Map<String, String> functionProperties = new LinkedHashMap<>(); + SerializationContext functionSerializationContext = new SerializationContext(functionExpressions); + + if (firstPhaseRanking != null) { + functionProperties.putAll(firstPhaseRanking.getRankProperties(functionSerializationContext)); + } + if (secondPhaseRanking != null) { + functionProperties.putAll(secondPhaseRanking.getRankProperties(functionSerializationContext)); + } + + derivePropertiesAndSummaryFeaturesFromFunctions(functions, functionProperties, functionSerializationContext); + deriveOnnxModelFunctionsAndSummaryFeatures(compiled); + deriveRankTypeSetting(compiled, attributeFields); deriveFilterFields(compiled); deriveWeightProperties(compiled); @@ -184,44 +200,16 @@ public class RawRankProfile implements RankProfilesConfig.Producer { filterFields.addAll(rp.allFilterFields()); } - private void deriveRankingFeatures(RankProfile rankProfile, ModelContext.Properties deployProperties) { - firstPhaseRanking = rankProfile.getFirstPhaseRanking(); - secondPhaseRanking = rankProfile.getSecondPhaseRanking(); - summaryFeatures = new LinkedHashSet<>(rankProfile.getSummaryFeatures()); - rankFeatures = rankProfile.getRankFeatures(); - rerankCount = rankProfile.getRerankCount(); - matchPhaseSettings = rankProfile.getMatchPhaseSettings(); - numThreadsPerSearch = rankProfile.getNumThreadsPerSearch(); - minHitsPerThread = rankProfile.getMinHitsPerThread(); - numSearchPartitions = rankProfile.getNumSearchPartitions(); - termwiseLimit = rankProfile.getTermwiseLimit().orElse(deployProperties.featureFlags().defaultTermwiseLimit()); - keepRankCount = rankProfile.getKeepRankCount(); - rankScoreDropLimit = rankProfile.getRankScoreDropLimit(); - ignoreDefaultRankFeatures = rankProfile.getIgnoreDefaultRankFeatures(); - rankProperties = new ArrayList<>(rankProfile.getRankProperties()); - derivePropertiesAndSummaryFeaturesFromFunctions(rankProfile.getFunctions()); - deriveOnnxModelFunctionsAndSummaryFeatures(rankProfile); - } - - private void derivePropertiesAndSummaryFeaturesFromFunctions(Map<String, RankProfile.RankingExpressionFunction> functions) { + private void derivePropertiesAndSummaryFeaturesFromFunctions(Map<String, RankProfile.RankingExpressionFunction> functions, + Map<String, String> functionProperties, + SerializationContext functionContext) { if (functions.isEmpty()) return; - List<ExpressionFunction> functionExpressions = functions.values().stream().map(f -> f.function()).collect(Collectors.toList()); - Map<String, String> functionProperties = new LinkedHashMap<>(); - - if (firstPhaseRanking != null) { - functionProperties.putAll(firstPhaseRanking.getRankProperties(functionExpressions)); - } - if (secondPhaseRanking != null) { - functionProperties.putAll(secondPhaseRanking.getRankProperties(functionExpressions)); - } - - SerializationContext context = new SerializationContext(functionExpressions, null, functionProperties); - replaceFunctionSummaryFeatures(context); + replaceFunctionSummaryFeatures(functionContext); // First phase, second phase and summary features should add all required functions to the context. // However, we need to add any functions not referenced in those anyway for model-evaluation. - deriveFunctionProperties(functions, functionExpressions, functionProperties); + deriveFunctionProperties(functions, functionProperties, functionContext); for (Map.Entry<String, String> e : functionProperties.entrySet()) { rankProperties.add(new RankProfile.RankProperty(e.getKey(), e.getValue())); @@ -229,15 +217,13 @@ public class RawRankProfile implements RankProfilesConfig.Producer { } private void deriveFunctionProperties(Map<String, RankProfile.RankingExpressionFunction> functions, - List<ExpressionFunction> functionExpressions, - Map<String, String> functionProperties) { - SerializationContext context = new SerializationContext(functionExpressions, null, functionProperties); + Map<String, String> functionProperties, + SerializationContext context) { for (Map.Entry<String, RankProfile.RankingExpressionFunction> e : functions.entrySet()) { - if (useExternalExpressionFiles && rankProfile.getExpressionFile(e.getKey()) != null) continue; String propertyName = RankingExpression.propertyName(e.getKey()); if (context.serializedFunctions().containsKey(propertyName)) continue; - String expressionString = e.getValue().function().getBody().getRoot().toString(new StringBuilder(), context, null, null).toString(); + String expressionString = e.getValue().function().getBody().getRoot().toString(context).toString(); context.addFunctionSerialization(propertyName, expressionString); for (Map.Entry<String, TensorType> argumentType : e.getValue().function().argumentTypes().entrySet()) @@ -259,7 +245,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { ExpressionFunction function = context.getFunction(referenceNode.getName()); if (function != null) { String propertyName = RankingExpression.propertyName(referenceNode.getName()); - String expressionString = function.getBody().getRoot().toString(new StringBuilder(), context, null, null).toString(); + String expressionString = function.getBody().getRoot().toString(context).toString(); context.addFunctionSerialization(propertyName, expressionString); ReferenceNode newReferenceNode = new ReferenceNode("rankingExpression(" + referenceNode.getName() + ")", referenceNode.getArguments().expressions(), referenceNode.getOutput()); functionSummaryFeatures.put(referenceNode.getName(), newReferenceNode); @@ -355,8 +341,8 @@ public class RawRankProfile implements RankProfilesConfig.Producer { properties.add(new Pair<>(property.getName(), property.getValue())); } } - properties.addAll(deriveRankingPhaseRankProperties(firstPhaseRanking, rankProfile.getFirstPhaseFile(), RankProfile.FIRST_PHASE)); - properties.addAll(deriveRankingPhaseRankProperties(secondPhaseRanking, rankProfile.getSecondPhaseFile(), RankProfile.SECOND_PHASE)); + properties.addAll(deriveRankingPhaseRankProperties(firstPhaseRanking, RankProfile.FIRST_PHASE)); + properties.addAll(deriveRankingPhaseRankProperties(secondPhaseRanking, RankProfile.SECOND_PHASE)); for (FieldRankSettings settings : fieldRankSettings.values()) { properties.addAll(settings.deriveRankProperties()); } @@ -408,9 +394,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { if (ignoreDefaultRankFeatures) { properties.add(new Pair<>("vespa.dump.ignoredefaultfeatures", String.valueOf(true))); } - Iterator filterFieldsIterator = filterFields.iterator(); - while (filterFieldsIterator.hasNext()) { - String fieldName = (String) filterFieldsIterator.next(); + for (String fieldName : filterFields) { properties.add(new Pair<>("vespa.isfilterfield." + fieldName, String.valueOf(true))); } for (Map.Entry<String, String> attributeType : attributeTypes.entrySet()) { @@ -423,7 +407,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { return properties; } - private List<Pair<String, String>> deriveRankingPhaseRankProperties(RankingExpression expression, String fileName, String phase) { + private List<Pair<String, String>> deriveRankingPhaseRankProperties(RankingExpression expression, String phase) { List<Pair<String, String>> properties = new ArrayList<>(); if (expression == null) return properties; @@ -431,9 +415,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { if ("".equals(name)) name = phase; - if (useExternalExpressionFiles && (fileName != null)) { - properties.add(new Pair<>("vespa.rank." + phase, "rankingExpression(" + rankProfile.getUniqueExpressionName(name) + ")")); - } else if (expression.getRoot() instanceof ReferenceNode) { + if (expression.getRoot() instanceof ReferenceNode) { properties.add(new Pair<>("vespa.rank." + phase, expression.getRoot().toString())); } else { properties.add(new Pair<>("vespa.rank." + phase, "rankingExpression(" + name + ")")); |