aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/schema/RankProfile.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema/RankProfile.java')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/RankProfile.java120
1 files changed, 107 insertions, 13 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/RankProfile.java b/config-model/src/main/java/com/yahoo/schema/RankProfile.java
index dafe0b48698..0cfcddc6c57 100644
--- a/config-model/src/main/java/com/yahoo/schema/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/schema/RankProfile.java
@@ -1,4 +1,4 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.schema;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels;
@@ -22,6 +22,7 @@ import com.yahoo.searchlib.rankingexpression.FeatureList;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -30,6 +31,7 @@ import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.util.ArrayList;
+import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
@@ -626,8 +628,10 @@ public class RankProfile implements Cloneable {
}
private void addImplicitMatchFeatures(List<FeatureList> list) {
- if (matchFeatures == null)
- matchFeatures = new LinkedHashSet<>();
+ if (matchFeatures == null) {
+ var inherited = getMatchFeatures();
+ matchFeatures = new LinkedHashSet<>(inherited);
+ }
if (hiddenMatchFeatures == null)
hiddenMatchFeatures = new LinkedHashSet<>();
for (var features : list) {
@@ -1058,21 +1062,45 @@ public class RankProfile implements Cloneable {
functions = compileFunctions(this::getFunctions, queryProfiles, featureTypes, importedModels, inlineFunctions, expressionTransforms);
allFunctionsCached = null;
+ var context = new RankProfileTransformContext(this,
+ queryProfiles,
+ featureTypes,
+ importedModels,
+ constants(),
+ inlineFunctions);
+ var allNormalizers = getFeatureNormalizers();
+ verifyNoNormalizers("first-phase expression", firstPhaseRanking, allNormalizers, context);
+ verifyNoNormalizers("second-phase expression", secondPhaseRanking, allNormalizers, context);
+ for (ReferenceNode mf : getMatchFeatures()) {
+ verifyNoNormalizers("match-feature " + mf, mf, allNormalizers, context);
+ }
+ for (ReferenceNode sf : getSummaryFeatures()) {
+ verifyNoNormalizers("summary-feature " + sf, sf, allNormalizers, context);
+ }
if (globalPhaseRanking != null) {
- var context = new RankProfileTransformContext(this,
- queryProfiles,
- featureTypes,
- importedModels,
- constants(),
- inlineFunctions);
var needInputs = new HashSet<String>();
+ Set<String> userDeclaredMatchFeatures = new HashSet<>();
+ for (ReferenceNode mf : getMatchFeatures()) {
+ userDeclaredMatchFeatures.add(mf.toString());
+ }
var recorder = new InputRecorder(needInputs);
- if (matchFeatures != null) {
- for (ReferenceNode mf : matchFeatures) {
- recorder.alreadyHandled(mf.toString());
+ recorder.alreadyMatchFeatures(userDeclaredMatchFeatures);
+ recorder.addKnownNormalizers(allNormalizers.keySet());
+ recorder.process(globalPhaseRanking.function().getBody(), context);
+ for (var normalizerName : recorder.normalizersUsed()) {
+ var normalizer = allNormalizers.get(normalizerName);
+ var func = functions.get(normalizer.input());
+ if (func != null) {
+ verifyNoNormalizers("normalizer input " + normalizer.input(), func, allNormalizers, context);
+ if (! userDeclaredMatchFeatures.contains(normalizer.input())) {
+ var subRecorder = new InputRecorder(needInputs);
+ subRecorder.alreadyMatchFeatures(userDeclaredMatchFeatures);
+ subRecorder.process(func.function().getBody(), context);
+ }
+ } else {
+ needInputs.add(normalizer.input());
}
}
- recorder.process(globalPhaseRanking.function().getBody(), context);
List<FeatureList> addIfMissing = new ArrayList<>();
for (String input : needInputs) {
if (input.startsWith("constant(") || input.startsWith("query(")) {
@@ -1630,4 +1658,70 @@ public class RankProfile implements Cloneable {
}
+ public static record RankFeatureNormalizer(Reference original, String name, String input, String algo, double kparam) {
+ @Override
+ public String toString() {
+ return "normalizer{name=" + name + ",input=" + input + ",algo=" + algo + ",k=" + kparam + "}";
+ }
+ private static long hash(String s) {
+ int bob = com.yahoo.collections.BobHash.hash(s);
+ return bob + 0x100000000L;
+ }
+ public static RankFeatureNormalizer linear(Reference original, Reference inputRef) {
+ long h = hash(original.toString());
+ String name = "normalize@" + h + "@linear";
+ return new RankFeatureNormalizer(original, name, inputRef.toString(), "LINEAR", 0.0);
+ }
+ public static RankFeatureNormalizer rrank(Reference original, Reference inputRef, double k) {
+ long h = hash(original.toString());
+ String name = "normalize@" + h + "@rrank";
+ return new RankFeatureNormalizer(original, name, inputRef.toString(), "RRANK", k);
+ }
+ }
+
+ private List<RankFeatureNormalizer> featureNormalizers = new ArrayList<>();
+
+ public Map<String, RankFeatureNormalizer> getFeatureNormalizers() {
+ Map<String, RankFeatureNormalizer> all = new LinkedHashMap<>();
+ for (var inheritedProfile : inherited()) {
+ all.putAll(inheritedProfile.getFeatureNormalizers());
+ }
+ for (var n : featureNormalizers) {
+ all.put(n.name(), n);
+ }
+ return all;
+ }
+
+ public void addFeatureNormalizer(RankFeatureNormalizer n) {
+ if (functions.get(n.name()) != null) {
+ throw new IllegalArgumentException("cannot use name '" + name + "' for both function and normalizer");
+ }
+ featureNormalizers.add(n);
+ }
+
+ private void verifyNoNormalizers(String where, RankingExpressionFunction f, Map<String, RankFeatureNormalizer> allNormalizers, RankProfileTransformContext context) {
+ if (f == null) return;
+ verifyNoNormalizers(where, f.function(), allNormalizers, context);
+ }
+
+ private void verifyNoNormalizers(String where, ExpressionFunction func, Map<String, RankFeatureNormalizer> allNormalizers, RankProfileTransformContext context) {
+ if (func == null) return;
+ var body = func.getBody();
+ if (body == null) return;
+ verifyNoNormalizers(where, body.getRoot(), allNormalizers, context);
+ }
+
+ private void verifyNoNormalizers(String where, ExpressionNode node, Map<String, RankFeatureNormalizer> allNormalizers, RankProfileTransformContext context) {
+ var needInputs = new HashSet<String>();
+ var recorder = new InputRecorder(needInputs);
+ recorder.process(node, context);
+ for (var input : needInputs) {
+ var normalizer = allNormalizers.get(input);
+ if (normalizer != null) {
+ throw new IllegalArgumentException("Cannot use " + normalizer.original() + " from " + where + ", only valid in global-phase expression");
+ }
+ }
+ }
+
+
}