diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema/RankProfile.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/schema/RankProfile.java | 120 |
1 files changed, 107 insertions, 13 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/RankProfile.java b/config-model/src/main/java/com/yahoo/schema/RankProfile.java index dafe0b48698..0cfcddc6c57 100644 --- a/config-model/src/main/java/com/yahoo/schema/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/RankProfile.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema; import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels; @@ -22,6 +22,7 @@ import com.yahoo.searchlib.rankingexpression.FeatureList; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -30,6 +31,7 @@ import java.io.IOException; import java.io.Reader; import java.io.StringReader; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -626,8 +628,10 @@ public class RankProfile implements Cloneable { } private void addImplicitMatchFeatures(List<FeatureList> list) { - if (matchFeatures == null) - matchFeatures = new LinkedHashSet<>(); + if (matchFeatures == null) { + var inherited = getMatchFeatures(); + matchFeatures = new LinkedHashSet<>(inherited); + } if (hiddenMatchFeatures == null) hiddenMatchFeatures = new LinkedHashSet<>(); for (var features : list) { @@ -1058,21 +1062,45 @@ public class RankProfile implements Cloneable { functions = compileFunctions(this::getFunctions, queryProfiles, featureTypes, importedModels, inlineFunctions, expressionTransforms); allFunctionsCached = null; + var context = new RankProfileTransformContext(this, + queryProfiles, + featureTypes, + importedModels, + constants(), + inlineFunctions); + var allNormalizers = getFeatureNormalizers(); + verifyNoNormalizers("first-phase expression", firstPhaseRanking, allNormalizers, context); + verifyNoNormalizers("second-phase expression", secondPhaseRanking, allNormalizers, context); + for (ReferenceNode mf : getMatchFeatures()) { + verifyNoNormalizers("match-feature " + mf, mf, allNormalizers, context); + } + for (ReferenceNode sf : getSummaryFeatures()) { + verifyNoNormalizers("summary-feature " + sf, sf, allNormalizers, context); + } if (globalPhaseRanking != null) { - var context = new RankProfileTransformContext(this, - queryProfiles, - featureTypes, - importedModels, - constants(), - inlineFunctions); var needInputs = new HashSet<String>(); + Set<String> userDeclaredMatchFeatures = new HashSet<>(); + for (ReferenceNode mf : getMatchFeatures()) { + userDeclaredMatchFeatures.add(mf.toString()); + } var recorder = new InputRecorder(needInputs); - if (matchFeatures != null) { - for (ReferenceNode mf : matchFeatures) { - recorder.alreadyHandled(mf.toString()); + recorder.alreadyMatchFeatures(userDeclaredMatchFeatures); + recorder.addKnownNormalizers(allNormalizers.keySet()); + recorder.process(globalPhaseRanking.function().getBody(), context); + for (var normalizerName : recorder.normalizersUsed()) { + var normalizer = allNormalizers.get(normalizerName); + var func = functions.get(normalizer.input()); + if (func != null) { + verifyNoNormalizers("normalizer input " + normalizer.input(), func, allNormalizers, context); + if (! userDeclaredMatchFeatures.contains(normalizer.input())) { + var subRecorder = new InputRecorder(needInputs); + subRecorder.alreadyMatchFeatures(userDeclaredMatchFeatures); + subRecorder.process(func.function().getBody(), context); + } + } else { + needInputs.add(normalizer.input()); } } - recorder.process(globalPhaseRanking.function().getBody(), context); List<FeatureList> addIfMissing = new ArrayList<>(); for (String input : needInputs) { if (input.startsWith("constant(") || input.startsWith("query(")) { @@ -1630,4 +1658,70 @@ public class RankProfile implements Cloneable { } + public static record RankFeatureNormalizer(Reference original, String name, String input, String algo, double kparam) { + @Override + public String toString() { + return "normalizer{name=" + name + ",input=" + input + ",algo=" + algo + ",k=" + kparam + "}"; + } + private static long hash(String s) { + int bob = com.yahoo.collections.BobHash.hash(s); + return bob + 0x100000000L; + } + public static RankFeatureNormalizer linear(Reference original, Reference inputRef) { + long h = hash(original.toString()); + String name = "normalize@" + h + "@linear"; + return new RankFeatureNormalizer(original, name, inputRef.toString(), "LINEAR", 0.0); + } + public static RankFeatureNormalizer rrank(Reference original, Reference inputRef, double k) { + long h = hash(original.toString()); + String name = "normalize@" + h + "@rrank"; + return new RankFeatureNormalizer(original, name, inputRef.toString(), "RRANK", k); + } + } + + private List<RankFeatureNormalizer> featureNormalizers = new ArrayList<>(); + + public Map<String, RankFeatureNormalizer> getFeatureNormalizers() { + Map<String, RankFeatureNormalizer> all = new LinkedHashMap<>(); + for (var inheritedProfile : inherited()) { + all.putAll(inheritedProfile.getFeatureNormalizers()); + } + for (var n : featureNormalizers) { + all.put(n.name(), n); + } + return all; + } + + public void addFeatureNormalizer(RankFeatureNormalizer n) { + if (functions.get(n.name()) != null) { + throw new IllegalArgumentException("cannot use name '" + name + "' for both function and normalizer"); + } + featureNormalizers.add(n); + } + + private void verifyNoNormalizers(String where, RankingExpressionFunction f, Map<String, RankFeatureNormalizer> allNormalizers, RankProfileTransformContext context) { + if (f == null) return; + verifyNoNormalizers(where, f.function(), allNormalizers, context); + } + + private void verifyNoNormalizers(String where, ExpressionFunction func, Map<String, RankFeatureNormalizer> allNormalizers, RankProfileTransformContext context) { + if (func == null) return; + var body = func.getBody(); + if (body == null) return; + verifyNoNormalizers(where, body.getRoot(), allNormalizers, context); + } + + private void verifyNoNormalizers(String where, ExpressionNode node, Map<String, RankFeatureNormalizer> allNormalizers, RankProfileTransformContext context) { + var needInputs = new HashSet<String>(); + var recorder = new InputRecorder(needInputs); + recorder.process(node, context); + for (var input : needInputs) { + var normalizer = allNormalizers.get(input); + if (normalizer != null) { + throw new IllegalArgumentException("Cannot use " + normalizer.original() + " from " + where + ", only valid in global-phase expression"); + } + } + } + + } |