diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java | 524 |
1 files changed, 524 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java new file mode 100644 index 00000000000..a8a9b4c8755 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java @@ -0,0 +1,524 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.schema.derived; + +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels; +import com.google.common.collect.ImmutableList; +import com.yahoo.collections.Pair; +import com.yahoo.compress.Compressor; +import com.yahoo.config.model.api.ModelContext; +import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.schema.FeatureNames; +import com.yahoo.schema.OnnxModel; +import com.yahoo.schema.LargeRankExpressions; +import com.yahoo.schema.RankExpressionBody; +import com.yahoo.schema.document.RankType; +import com.yahoo.schema.RankProfile; +import com.yahoo.schema.expressiontransforms.OnnxModelTransformer; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; +import com.yahoo.tensor.TensorType; +import com.yahoo.vespa.config.search.RankProfilesConfig; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.ListIterator; +import java.util.Map; +import java.util.OptionalDouble; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * A rank profile derived from a search definition, containing exactly the features available natively in the server + * + * @author bratseth + */ +public class RawRankProfile implements RankProfilesConfig.Producer { + + /** A reusable compressor with default settings */ + private static final Compressor compressor = new Compressor(); + + private static final String keyEndMarker = "\r="; + private static final String valueEndMarker = "\r\n"; + + private final String name; + private final Compressor.Compression compressedProperties; + + /** The compiled profile this is created from. */ + private final RankProfile compiled; + + /** Creates a raw rank profile from the given rank profile. */ + public RawRankProfile(RankProfile rankProfile, LargeRankExpressions largeExpressions, + QueryProfileRegistry queryProfiles, ImportedMlModels importedModels, + AttributeFields attributeFields, ModelContext.Properties deployProperties) { + this.name = rankProfile.name(); + compiled = rankProfile.compile(queryProfiles, importedModels); + compressedProperties = compress(new Deriver(compiled, attributeFields, deployProperties, queryProfiles) + .derive(largeExpressions)); + } + + public RankProfile compiled() { return compiled; } + + private Compressor.Compression compress(List<Pair<String, String>> properties) { + StringBuilder b = new StringBuilder(); + for (Pair<String, String> property : properties) + b.append(property.getFirst()).append(keyEndMarker).append(property.getSecond()).append(valueEndMarker); + return compressor.compress(b.toString().getBytes(StandardCharsets.UTF_8)); + } + + private List<Pair<String, String>> decompress(Compressor.Compression compression) { + String propertiesString = new String(compressor.decompress(compression), StandardCharsets.UTF_8); + if (propertiesString.isEmpty()) return ImmutableList.of(); + + ImmutableList.Builder<Pair<String, String>> properties = new ImmutableList.Builder<>(); + for (int pos = 0; pos < propertiesString.length();) { + int keyEndPos = propertiesString.indexOf(keyEndMarker, pos); + String key = propertiesString.substring(pos, keyEndPos); + pos = keyEndPos + keyEndMarker.length(); + int valueEndPos = propertiesString.indexOf(valueEndMarker, pos); + String value = propertiesString.substring(pos, valueEndPos); + pos = valueEndPos + valueEndMarker.length(); + properties.add(new Pair<>(key, value)); + } + return properties.build(); + } + + public String getName() { return name; } + + private void getRankProperties(RankProfilesConfig.Rankprofile.Builder b) { + RankProfilesConfig.Rankprofile.Fef.Builder fefB = new RankProfilesConfig.Rankprofile.Fef.Builder(); + for (Pair<String, String> p : decompress(compressedProperties)) + fefB.property(new RankProfilesConfig.Rankprofile.Fef.Property.Builder().name(p.getFirst()).value(p.getSecond())); + b.fef(fefB); + } + + /** + * Returns the properties of this as an unmodifiable list. + * Note: This method is expensive. + */ + public List<Pair<String, String>> configProperties() { return decompress(compressedProperties); } + + @Override + public void getConfig(RankProfilesConfig.Builder builder) { + RankProfilesConfig.Rankprofile.Builder b = new RankProfilesConfig.Rankprofile.Builder().name(getName()); + getRankProperties(b); + builder.rankprofile(b); + } + + @Override + public String toString() { + return " rank profile " + name; + } + + private static class Deriver { + + private final Map<String, FieldRankSettings> fieldRankSettings = new java.util.LinkedHashMap<>(); + private final Set<ReferenceNode> summaryFeatures; + private final Set<ReferenceNode> matchFeatures; + private final Set<ReferenceNode> rankFeatures; + private final Map<String, String> featureRenames = new java.util.LinkedHashMap<>(); + private final List<RankProfile.RankProperty> rankProperties; + + /** + * Rank properties for weight settings to make these available to feature executors + */ + private final List<RankProfile.RankProperty> boostAndWeightRankProperties = new ArrayList<>(); + + private final boolean ignoreDefaultRankFeatures; + private final RankProfile.MatchPhaseSettings matchPhaseSettings; + private final int rerankCount; + private final int keepRankCount; + private final int numThreadsPerSearch; + private final int minHitsPerThread; + private final int numSearchPartitions; + private final double termwiseLimit; + private final OptionalDouble postFilterThreshold; + private final OptionalDouble approximateThreshold; + private final double rankScoreDropLimit; + private final boolean mapBackRankingExpressionFeatures; + + /** + * The rank type definitions used to derive settings for the native rank features + */ + private final NativeRankTypeDefinitionSet nativeRankTypeDefinitions = new NativeRankTypeDefinitionSet("default"); + private final Map<String, String> attributeTypes; + private final Map<Reference, RankProfile.Input> inputs; + private final Set<String> filterFields = new java.util.LinkedHashSet<>(); + private final String rankprofileName; + + private RankingExpression firstPhaseRanking; + private RankingExpression secondPhaseRanking; + + /** + * Creates a raw rank profile from the given rank profile + */ + Deriver(RankProfile compiled, + AttributeFields attributeFields, + ModelContext.Properties deployProperties, + QueryProfileRegistry queryProfiles) { + rankprofileName = compiled.name(); + attributeTypes = compiled.getAttributeTypes(); + inputs = compiled.inputs(); + firstPhaseRanking = compiled.getFirstPhaseRanking(); + secondPhaseRanking = compiled.getSecondPhaseRanking(); + summaryFeatures = new LinkedHashSet<>(compiled.getSummaryFeatures()); + matchFeatures = new LinkedHashSet<>(compiled.getMatchFeatures()); + rankFeatures = compiled.getRankFeatures(); + rerankCount = compiled.getRerankCount(); + matchPhaseSettings = compiled.getMatchPhaseSettings(); + numThreadsPerSearch = compiled.getNumThreadsPerSearch(); + minHitsPerThread = compiled.getMinHitsPerThread(); + numSearchPartitions = compiled.getNumSearchPartitions(); + termwiseLimit = compiled.getTermwiseLimit().orElse(deployProperties.featureFlags().defaultTermwiseLimit()); + postFilterThreshold = compiled.getPostFilterThreshold(); + approximateThreshold = compiled.getApproximateThreshold(); + keepRankCount = compiled.getKeepRankCount(); + rankScoreDropLimit = compiled.getRankScoreDropLimit(); + mapBackRankingExpressionFeatures = deployProperties.featureFlags().avoidRenamingSummaryFeatures(); + ignoreDefaultRankFeatures = compiled.getIgnoreDefaultRankFeatures(); + rankProperties = new ArrayList<>(compiled.getRankProperties()); + + Map<String, RankProfile.RankingExpressionFunction> functions = compiled.getFunctions(); + List<ExpressionFunction> functionExpressions = functions.values().stream().map(f -> f.function()).collect(Collectors.toList()); + Map<String, String> functionProperties = new LinkedHashMap<>(); + SerializationContext functionSerializationContext = new SerializationContext(functionExpressions, + Map.of(), + compiled.typeContext(queryProfiles)); + + if (firstPhaseRanking != null) { + functionProperties.putAll(firstPhaseRanking.getRankProperties(functionSerializationContext)); + } + if (secondPhaseRanking != null) { + functionProperties.putAll(secondPhaseRanking.getRankProperties(functionSerializationContext)); + } + + derivePropertiesAndFeaturesFromFunctions(functions, functionProperties, functionSerializationContext); + deriveOnnxModelFunctionsAndFeatures(compiled); + + deriveRankTypeSetting(compiled, attributeFields); + deriveFilterFields(compiled); + deriveWeightProperties(compiled); + } + + private void deriveFilterFields(RankProfile rp) { + filterFields.addAll(rp.allFilterFields()); + } + + private void derivePropertiesAndFeaturesFromFunctions(Map<String, RankProfile.RankingExpressionFunction> functions, + Map<String, String> functionProperties, + SerializationContext functionContext) { + if (functions.isEmpty()) return; + + replaceFunctionFeatures(summaryFeatures, functionContext); + replaceFunctionFeatures(matchFeatures, functionContext); + + // First phase, second phase and summary features should add all required functions to the context. + // However, we need to add any functions not referenced in those anyway for model-evaluation. + deriveFunctionProperties(functions, functionProperties, functionContext); + + for (Map.Entry<String, String> e : functionProperties.entrySet()) { + rankProperties.add(new RankProfile.RankProperty(e.getKey(), e.getValue())); + } + } + + private void deriveFunctionProperties(Map<String, RankProfile.RankingExpressionFunction> functions, + Map<String, String> functionProperties, + SerializationContext context) { + for (Map.Entry<String, RankProfile.RankingExpressionFunction> e : functions.entrySet()) { + String propertyName = RankingExpression.propertyName(e.getKey()); + if (context.serializedFunctions().containsKey(propertyName)) continue; + + String expressionString = e.getValue().function().getBody().getRoot().toString(context).toString(); + + context.addFunctionSerialization(propertyName, expressionString); + for (Map.Entry<String, TensorType> argumentType : e.getValue().function().argumentTypes().entrySet()) + context.addArgumentTypeSerialization(e.getKey(), argumentType.getKey(), argumentType.getValue()); + if (e.getValue().function().returnType().isPresent()) + context.addFunctionTypeSerialization(e.getKey(), e.getValue().function().returnType().get()); + // else if (e.getValue().function().arguments().isEmpty()) TODO: Enable this check when we resolve all types + // throw new IllegalStateException("Type of function '" + e.getKey() + "' is not resolved"); + } + functionProperties.putAll(context.serializedFunctions()); + } + + private void replaceFunctionFeatures(Set<ReferenceNode> features, SerializationContext context) { + if (features == null) return; + Map<String, ReferenceNode> functionFeatures = new LinkedHashMap<>(); + for (Iterator<ReferenceNode> i = features.iterator(); i.hasNext(); ) { + ReferenceNode referenceNode = i.next(); + // Is the feature a function? + ExpressionFunction function = context.getFunction(referenceNode.getName()); + if (function != null) { + String propertyName = RankingExpression.propertyName(referenceNode.getName()); + String expressionString = function.getBody().getRoot().toString(context).toString(); + context.addFunctionSerialization(propertyName, expressionString); + ReferenceNode backendReferenceNode = new ReferenceNode("rankingExpression(" + referenceNode.getName() + ")", + referenceNode.getArguments().expressions(), + referenceNode.getOutput()); + if (mapBackRankingExpressionFeatures) { + // tell backend to map back to the name the user expects: + featureRenames.put(backendReferenceNode.toString(), referenceNode.toString()); + } + functionFeatures.put(referenceNode.getName(), backendReferenceNode); + i.remove(); // Will add the expanded one in next block + } + } + // Then, replace the features that were functions + for (Map.Entry<String, ReferenceNode> e : functionFeatures.entrySet()) { + features.add(e.getValue()); + } + } + + private void deriveWeightProperties(RankProfile rankProfile) { + + for (RankProfile.RankSetting setting : rankProfile.rankSettings()) { + if (setting.getType() != RankProfile.RankSetting.Type.WEIGHT) continue; + boostAndWeightRankProperties.add(new RankProfile.RankProperty("vespa.fieldweight." + setting.getFieldName(), + String.valueOf(setting.getIntValue()))); + } + } + + /** + * Adds the type boosts from a rank profile + */ + private void deriveRankTypeSetting(RankProfile rankProfile, AttributeFields attributeFields) { + for (Iterator<RankProfile.RankSetting> i = rankProfile.rankSettingIterator(); i.hasNext(); ) { + RankProfile.RankSetting setting = i.next(); + if (setting.getType() != RankProfile.RankSetting.Type.RANKTYPE) continue; + + deriveNativeRankTypeSetting(setting.getFieldName(), (RankType) setting.getValue(), attributeFields, + hasDefaultRankTypeSetting(rankProfile, setting.getFieldName())); + } + } + + private void deriveNativeRankTypeSetting(String fieldName, RankType rankType, AttributeFields attributeFields, + boolean isDefaultSetting) { + if (isDefaultSetting) return; + + NativeRankTypeDefinition definition = nativeRankTypeDefinitions.getRankTypeDefinition(rankType); + if (definition == null) throw new IllegalArgumentException("In field '" + fieldName + "': " + + rankType + " is known but has no implementation. " + + "Supported rank types: " + + nativeRankTypeDefinitions.types().keySet()); + + FieldRankSettings settings = deriveFieldRankSettings(fieldName); + for (Iterator<NativeTable> i = definition.rankSettingIterator(); i.hasNext(); ) { + NativeTable table = i.next(); + // only add index field tables if we are processing an index field and + // only add attribute field tables if we are processing an attribute field + if ((FieldRankSettings.isIndexFieldTable(table) && attributeFields.getAttribute(fieldName) == null) || + (FieldRankSettings.isAttributeFieldTable(table) && attributeFields.getAttribute(fieldName) != null)) { + settings.addTable(table); + } + } + } + + private boolean hasDefaultRankTypeSetting(RankProfile rankProfile, String fieldName) { + RankProfile.RankSetting setting = + rankProfile.getRankSetting(fieldName, RankProfile.RankSetting.Type.RANKTYPE); + return setting != null && setting.getValue().equals(RankType.DEFAULT); + } + + private FieldRankSettings deriveFieldRankSettings(String fieldName) { + FieldRankSettings settings = fieldRankSettings.get(fieldName); + if (settings == null) { + settings = new FieldRankSettings(fieldName); + fieldRankSettings.put(fieldName, settings); + } + return settings; + } + + /** Derives the properties this produces */ + public List<Pair<String, String>> derive(LargeRankExpressions largeRankExpressions) { + List<Pair<String, String>> properties = new ArrayList<>(); + for (RankProfile.RankProperty property : rankProperties) { + if (RankingExpression.propertyName(RankProfile.FIRST_PHASE).equals(property.getName())) { + // Could have been set by function expansion. Set expressions, then skip this property. + try { + firstPhaseRanking = new RankingExpression(property.getValue()); + } catch (ParseException e) { + throw new IllegalArgumentException("Could not parse first phase expression", e); + } + } + else if (RankingExpression.propertyName(RankProfile.SECOND_PHASE).equals(property.getName())) { + try { + secondPhaseRanking = new RankingExpression(property.getValue()); + } catch (ParseException e) { + throw new IllegalArgumentException("Could not parse second phase expression", e); + } + } + else { + properties.add(new Pair<>(property.getName(), property.getValue())); + } + } + properties.addAll(deriveRankingPhaseRankProperties(firstPhaseRanking, RankProfile.FIRST_PHASE)); + properties.addAll(deriveRankingPhaseRankProperties(secondPhaseRanking, RankProfile.SECOND_PHASE)); + for (FieldRankSettings settings : fieldRankSettings.values()) { + properties.addAll(settings.deriveRankProperties()); + } + for (RankProfile.RankProperty property : boostAndWeightRankProperties) { + properties.add(new Pair<>(property.getName(), property.getValue())); + } + for (ReferenceNode feature : summaryFeatures) { + properties.add(new Pair<>("vespa.summary.feature", feature.toString())); + } + for (ReferenceNode feature : matchFeatures) { + properties.add(new Pair<>("vespa.match.feature", feature.toString())); + } + for (ReferenceNode feature : rankFeatures) { + properties.add(new Pair<>("vespa.dump.feature", feature.toString())); + } + for (var entry : featureRenames.entrySet()) { + properties.add(new Pair<>("vespa.feature.rename", entry.getKey())); + properties.add(new Pair<>("vespa.feature.rename", entry.getValue())); + } + if (numThreadsPerSearch > 0) { + properties.add(new Pair<>("vespa.matching.numthreadspersearch", numThreadsPerSearch + "")); + } + if (minHitsPerThread > 0) { + properties.add(new Pair<>("vespa.matching.minhitsperthread", minHitsPerThread + "")); + } + if (numSearchPartitions >= 0) { + properties.add(new Pair<>("vespa.matching.numsearchpartitions", numSearchPartitions + "")); + } + if (termwiseLimit < 1.0) { + properties.add(new Pair<>("vespa.matching.termwise_limit", termwiseLimit + "")); + } + if (postFilterThreshold.isPresent()) { + properties.add(new Pair<>("vespa.matching.global_filter.upper_limit", String.valueOf(postFilterThreshold.getAsDouble()))); + } + if (approximateThreshold.isPresent()) { + properties.add(new Pair<>("vespa.matching.global_filter.lower_limit", String.valueOf(approximateThreshold.getAsDouble()))); + } + if (matchPhaseSettings != null) { + properties.add(new Pair<>("vespa.matchphase.degradation.attribute", matchPhaseSettings.getAttribute())); + properties.add(new Pair<>("vespa.matchphase.degradation.ascendingorder", matchPhaseSettings.getAscending() + "")); + properties.add(new Pair<>("vespa.matchphase.degradation.maxhits", matchPhaseSettings.getMaxHits() + "")); + properties.add(new Pair<>("vespa.matchphase.degradation.maxfiltercoverage", matchPhaseSettings.getMaxFilterCoverage() + "")); + properties.add(new Pair<>("vespa.matchphase.degradation.samplepercentage", matchPhaseSettings.getEvaluationPoint() + "")); + properties.add(new Pair<>("vespa.matchphase.degradation.postfiltermultiplier", matchPhaseSettings.getPrePostFilterTippingPoint() + "")); + RankProfile.DiversitySettings diversitySettings = matchPhaseSettings.getDiversity(); + if (diversitySettings != null) { + properties.add(new Pair<>("vespa.matchphase.diversity.attribute", diversitySettings.getAttribute())); + properties.add(new Pair<>("vespa.matchphase.diversity.mingroups", String.valueOf(diversitySettings.getMinGroups()))); + properties.add(new Pair<>("vespa.matchphase.diversity.cutoff.factor", String.valueOf(diversitySettings.getCutoffFactor()))); + properties.add(new Pair<>("vespa.matchphase.diversity.cutoff.strategy", String.valueOf(diversitySettings.getCutoffStrategy()))); + } + } + if (rerankCount > -1) { + properties.add(new Pair<>("vespa.hitcollector.heapsize", rerankCount + "")); + } + if (keepRankCount > -1) { + properties.add(new Pair<>("vespa.hitcollector.arraysize", keepRankCount + "")); + } + if (rankScoreDropLimit > -Double.MAX_VALUE) { + properties.add(new Pair<>("vespa.hitcollector.rankscoredroplimit", rankScoreDropLimit + "")); + } + if (ignoreDefaultRankFeatures) { + properties.add(new Pair<>("vespa.dump.ignoredefaultfeatures", String.valueOf(true))); + } + for (String fieldName : filterFields) { + properties.add(new Pair<>("vespa.isfilterfield." + fieldName, String.valueOf(true))); + } + for (Map.Entry<String, String> attributeType : attributeTypes.entrySet()) { + properties.add(new Pair<>("vespa.type.attribute." + attributeType.getKey(), attributeType.getValue())); + } + + for (var input : inputs.values()) { + if (FeatureNames.isQueryFeature(input.name())) { + if (input.type().rank() > 0) // Proton does not like representing the double type as a rank 0 tensor + properties.add(new Pair<>("vespa.type.query." + input.name().arguments().expressions().get(0), + input.type().toString())); + if (input.defaultValue().isPresent()) { + properties.add(new Pair<>(input.name().toString(), + input.type().rank() == 0 ? + String.valueOf(input.defaultValue().get().asDouble()) : + input.defaultValue().get().toString(true, false))); + } + } + } + if (properties.size() >= 1000000) throw new IllegalArgumentException("Too many rank properties"); + distributeLargeExpressionsAsFiles(properties, largeRankExpressions); + return properties; + } + + private void distributeLargeExpressionsAsFiles(List<Pair<String, String>> properties, LargeRankExpressions largeRankExpressions) { + for (ListIterator<Pair<String, String>> iter = properties.listIterator(); iter.hasNext();) { + Pair<String, String> property = iter.next(); + String expression = property.getSecond(); + if (expression.length() > largeRankExpressions.limit()) { + String propertyName = property.getFirst(); + String functionName = RankingExpression.extractScriptName(propertyName); + if (functionName != null) { + String mangledName = rankprofileName + "." + functionName; + largeRankExpressions.add(new RankExpressionBody(mangledName, ByteBuffer.wrap(expression.getBytes(StandardCharsets.UTF_8)))); + iter.set(new Pair<>(RankingExpression.propertyExpressionName(functionName), mangledName)); + } + } + } + } + + private List<Pair<String, String>> deriveRankingPhaseRankProperties(RankingExpression expression, String phase) { + List<Pair<String, String>> properties = new ArrayList<>(); + if (expression == null) return properties; + + String name = expression.getName(); + if ("".equals(name)) + name = phase; + + if (expression.getRoot() instanceof ReferenceNode) { + properties.add(new Pair<>("vespa.rank." + phase, expression.getRoot().toString())); + } else { + properties.add(new Pair<>("vespa.rank." + phase, "rankingExpression(" + name + ")")); + properties.add(new Pair<>(RankingExpression.propertyName(name), expression.getRoot().toString())); + } + return properties; + } + + private void deriveOnnxModelFunctionsAndFeatures(RankProfile rankProfile) { + if (rankProfile.schema() == null) return; + if (rankProfile.onnxModels().isEmpty()) return; + replaceOnnxFunctionInputs(rankProfile); + replaceImplicitOnnxConfigFeatures(summaryFeatures, rankProfile); + replaceImplicitOnnxConfigFeatures(matchFeatures, rankProfile); + } + + private void replaceOnnxFunctionInputs(RankProfile rankProfile) { + Set<String> functionNames = rankProfile.getFunctions().keySet(); + if (functionNames.isEmpty()) return; + for (OnnxModel onnxModel: rankProfile.onnxModels().values()) { + for (Map.Entry<String, String> mapping : onnxModel.getInputMap().entrySet()) { + String source = mapping.getValue(); + if (functionNames.contains(source)) { + onnxModel.addInputNameMapping(mapping.getKey(), "rankingExpression(" + source + ")"); + } + } + } + } + + private void replaceImplicitOnnxConfigFeatures(Set<ReferenceNode> features, RankProfile rankProfile) { + if (features == null || features.isEmpty()) return; + Set<ReferenceNode> replacedFeatures = new HashSet<>(); + for (Iterator<ReferenceNode> i = features.iterator(); i.hasNext(); ) { + ReferenceNode referenceNode = i.next(); + ReferenceNode replacedNode = (ReferenceNode) OnnxModelTransformer.transformFeature(referenceNode, rankProfile); + if (referenceNode != replacedNode) { + replacedFeatures.add(replacedNode); + i.remove(); + } + } + features.addAll(replacedFeatures); + } + + } + +} |