From 0c55dc92a3bf889c67fac1ca855e6e33e1994904 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Mon, 9 Oct 2023 09:44:29 +0200 Subject: Update copyright --- config-model/src/test/derived/rankingexpression/rankexpression.sd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'config-model/src/test/derived/rankingexpression/rankexpression.sd') diff --git a/config-model/src/test/derived/rankingexpression/rankexpression.sd b/config-model/src/test/derived/rankingexpression/rankexpression.sd index b0de2c60299..16dff61b63a 100644 --- a/config-model/src/test/derived/rankingexpression/rankexpression.sd +++ b/config-model/src/test/derived/rankingexpression/rankexpression.sd @@ -1,4 +1,4 @@ -# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. schema rankexpression { document rankexpression { -- cgit v1.2.3 From 5467c074d4922bdd3ed330a9079b4e3291f82861 Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Thu, 12 Oct 2023 11:46:18 +0000 Subject: allow configuring normalizers --- .../main/java/com/yahoo/schema/RankProfile.java | 112 ++++++++++++-- .../com/yahoo/schema/derived/RawRankProfile.java | 21 ++- .../expressiontransforms/ExpressionTransforms.java | 3 +- .../schema/expressiontransforms/InputRecorder.java | 37 ++++- .../NormalizerFunctionExpander.java | 134 ++++++++++++++++ .../derived/rankingexpression/rank-profiles.cfg | 62 ++++++++ .../derived/rankingexpression/rankexpression.sd | 28 ++++ .../com/yahoo/schema/NoNormalizersTestCase.java | 170 +++++++++++++++++++++ 8 files changed, 549 insertions(+), 18 deletions(-) create mode 100644 config-model/src/main/java/com/yahoo/schema/expressiontransforms/NormalizerFunctionExpander.java create mode 100644 config-model/src/test/java/com/yahoo/schema/NoNormalizersTestCase.java (limited to 'config-model/src/test/derived/rankingexpression/rankexpression.sd') diff --git a/config-model/src/main/java/com/yahoo/schema/RankProfile.java b/config-model/src/main/java/com/yahoo/schema/RankProfile.java index 6007a1cf4b1..e2577f4f834 100644 --- a/config-model/src/main/java/com/yahoo/schema/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/RankProfile.java @@ -22,6 +22,7 @@ import com.yahoo.searchlib.rankingexpression.FeatureList; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -30,6 +31,7 @@ import java.io.IOException; import java.io.Reader; import java.io.StringReader; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -1058,21 +1060,45 @@ public class RankProfile implements Cloneable { functions = compileFunctions(this::getFunctions, queryProfiles, featureTypes, importedModels, inlineFunctions, expressionTransforms); allFunctionsCached = null; + var context = new RankProfileTransformContext(this, + queryProfiles, + featureTypes, + importedModels, + constants(), + inlineFunctions); + var allNormalizers = getFeatureNormalizers(); + verifyNoNormalizers("first-phase expression", firstPhaseRanking, allNormalizers, context); + verifyNoNormalizers("second-phase expression", secondPhaseRanking, allNormalizers, context); + for (ReferenceNode mf : getMatchFeatures()) { + verifyNoNormalizers("match-feature " + mf, mf, allNormalizers, context); + } + for (ReferenceNode sf : getSummaryFeatures()) { + verifyNoNormalizers("summary-feature " + sf, sf, allNormalizers, context); + } if (globalPhaseRanking != null) { - var context = new RankProfileTransformContext(this, - queryProfiles, - featureTypes, - importedModels, - constants(), - inlineFunctions); var needInputs = new HashSet(); + Set userDeclaredMatchFeatures = new HashSet<>(); + for (ReferenceNode mf : getMatchFeatures()) { + userDeclaredMatchFeatures.add(mf.toString()); + } var recorder = new InputRecorder(needInputs); - if (matchFeatures != null) { - for (ReferenceNode mf : matchFeatures) { - recorder.alreadyHandled(mf.toString()); + recorder.alreadyMatchFeatures(userDeclaredMatchFeatures); + recorder.addKnownNormalizers(allNormalizers.keySet()); + recorder.process(globalPhaseRanking.function().getBody(), context); + for (var normalizerName : recorder.normalizersUsed()) { + var normalizer = allNormalizers.get(normalizerName); + var func = functions.get(normalizer.input()); + if (func != null) { + verifyNoNormalizers("normalizer input " + normalizer.input(), func, allNormalizers, context); + if (! userDeclaredMatchFeatures.contains(normalizer.input())) { + var subRecorder = new InputRecorder(needInputs); + subRecorder.alreadyMatchFeatures(userDeclaredMatchFeatures); + subRecorder.process(func.function().getBody(), context); + } + } else { + needInputs.add(normalizer.input()); } } - recorder.process(globalPhaseRanking.function().getBody(), context); List addIfMissing = new ArrayList<>(); for (String input : needInputs) { if (input.startsWith("constant(") || input.startsWith("query(")) { @@ -1630,4 +1656,70 @@ public class RankProfile implements Cloneable { } + public static record RankFeatureNormalizer(Reference original, String name, String input, String algo, double kparam) { + @Override + public String toString() { + return "normalizer{name=" + name + ",input=" + input + ",algo=" + algo + ",k=" + kparam + "}"; + } + private static long hash(String s) { + int bob = com.yahoo.collections.BobHash.hash(s); + return bob + 0x100000000L; + } + public static RankFeatureNormalizer linear(Reference original, Reference inputRef) { + long h = hash(original.toString()); + String name = "normalize@" + h + "@linear"; + return new RankFeatureNormalizer(original, name, inputRef.toString(), "LINEAR", 0.0); + } + public static RankFeatureNormalizer rrank(Reference original, Reference inputRef, double k) { + long h = hash(original.toString()); + String name = "normalize@" + h + "@rrank"; + return new RankFeatureNormalizer(original, name, inputRef.toString(), "RRANK", k); + } + } + + private List featureNormalizers = new ArrayList<>(); + + public Map getFeatureNormalizers() { + Map all = new LinkedHashMap<>(); + for (var inheritedProfile : inherited()) { + all.putAll(inheritedProfile.getFeatureNormalizers()); + } + for (var n : featureNormalizers) { + all.put(n.name(), n); + } + return all; + } + + public void addFeatureNormalizer(RankFeatureNormalizer n) { + if (functions.get(n.name()) != null) { + throw new IllegalArgumentException("cannot use name '" + name + "' for both function and normalizer"); + } + featureNormalizers.add(n); + } + + private void verifyNoNormalizers(String where, RankingExpressionFunction f, Map allNormalizers, RankProfileTransformContext context) { + if (f == null) return; + verifyNoNormalizers(where, f.function(), allNormalizers, context); + } + + private void verifyNoNormalizers(String where, ExpressionFunction func, Map allNormalizers, RankProfileTransformContext context) { + if (func == null) return; + var body = func.getBody(); + if (body == null) return; + verifyNoNormalizers(where, body.getRoot(), allNormalizers, context); + } + + private void verifyNoNormalizers(String where, ExpressionNode node, Map allNormalizers, RankProfileTransformContext context) { + var needInputs = new HashSet(); + var recorder = new InputRecorder(needInputs); + recorder.process(node, context); + for (var input : needInputs) { + var normalizer = allNormalizers.get(input); + if (normalizer != null) { + throw new IllegalArgumentException("Cannot use " + normalizer.original() + " from " + where + ", only valid in global-phase expression"); + } + } + } + + } diff --git a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java index 05e5f17ea3d..eb9f7d44c91 100644 --- a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java @@ -54,6 +54,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { private final String name; private final Compressor.Compression compressedProperties; + private final Map featureNormalizers; /** The compiled profile this is created from. */ private final Collection constants; @@ -66,13 +67,14 @@ public class RawRankProfile implements RankProfilesConfig.Producer { this.name = rankProfile.name(); /* * Forget the RankProfiles as soon as possible. They can become very large and memory hungry - * Especially do not refer then through any member variables due to the RawRankProfile living forever. + * Especially do not refer them through any member variables due to the RawRankProfile living forever. */ RankProfile compiled = rankProfile.compile(queryProfiles, importedModels); constants = compiled.constants().values(); onnxModels = compiled.onnxModels().values(); - compressedProperties = compress(new Deriver(compiled, attributeFields, deployProperties, queryProfiles) - .derive(largeExpressions)); + var deriver = new Deriver(compiled, attributeFields, deployProperties, queryProfiles); + compressedProperties = compress(deriver.derive(largeExpressions)); + this.featureNormalizers = compiled.getFeatureNormalizers(); } public Collection constants() { return constants; } @@ -111,6 +113,18 @@ public class RawRankProfile implements RankProfilesConfig.Producer { b.fef(fefB); } + private void buildNormalizers(RankProfilesConfig.Rankprofile.Builder b) { + for (var normalizer : featureNormalizers.values()) { + var nBuilder = new RankProfilesConfig.Rankprofile.Normalizer.Builder(); + nBuilder.name(normalizer.name()); + nBuilder.input(normalizer.input()); + var algo = RankProfilesConfig.Rankprofile.Normalizer.Algo.Enum.valueOf(normalizer.algo()); + nBuilder.algo(algo); + nBuilder.kparam(normalizer.kparam()); + b.normalizer(nBuilder); + } + } + /** * Returns the properties of this as an unmodifiable list. * Note: This method is expensive. @@ -121,6 +135,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { public void getConfig(RankProfilesConfig.Builder builder) { RankProfilesConfig.Rankprofile.Builder b = new RankProfilesConfig.Rankprofile.Builder().name(getName()); getRankProperties(b); + buildNormalizers(b); builder.rankprofile(b); } diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ExpressionTransforms.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ExpressionTransforms.java index cf46bedf223..42c8147b3dc 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ExpressionTransforms.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ExpressionTransforms.java @@ -35,7 +35,8 @@ public class ExpressionTransforms { new FunctionShadower(), new TensorMaxMinTransformer(), new Simplifier(), - new BooleanExpressionTransformer()); + new BooleanExpressionTransformer(), + new NormalizerFunctionExpander()); } public RankingExpression transform(RankingExpression expression, RankProfileTransformContext context) { diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java index 1128aaf3681..ab18f9c83db 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java @@ -14,6 +14,7 @@ import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; import com.yahoo.tensor.functions.Generate; import java.io.StringReader; +import java.util.Collection; import java.util.HashSet; import java.util.Set; import java.util.logging.Logger; @@ -29,19 +30,35 @@ public class InputRecorder extends ExpressionTransformer { private final Set neededInputs; private final Set handled = new HashSet<>(); + private final Set availableNormalizers = new HashSet<>(); + private final Set usedNormalizers = new HashSet<>(); public InputRecorder(Set target) { this.neededInputs = target; } public void process(RankingExpression expression, RankProfileTransformContext context) { - transform(expression.getRoot(), new InputRecorderContext(context)); + process(expression.getRoot(), context); } - public void alreadyHandled(String name) { - handled.add(name); + public void process(ExpressionNode node, RankProfileTransformContext context) { + transform(node, new InputRecorderContext(context)); } + public void alreadyMatchFeatures(Collection matchFeatures) { + for (String mf : matchFeatures) { + handled.add(mf); + } + } + + public void addKnownNormalizers(Collection names) { + for (String name : names) { + availableNormalizers.add(name); + } + } + + public Set normalizersUsed() { return this.usedNormalizers; } + @Override public ExpressionNode transform(ExpressionNode node, InputRecorderContext context) { if (node instanceof ReferenceNode r) { @@ -77,6 +94,10 @@ public class InputRecorder extends ExpressionTransformer { if (simpleFunctionOrIdentifier && context.localVariables().contains(name)) { return; } + if (simpleFunctionOrIdentifier && availableNormalizers.contains(name)) { + usedNormalizers.add(name); + return; + } if (ref.isSimpleRankingExpressionWrapper()) { name = ref.simpleArgument().get(); simpleFunctionOrIdentifier = true; @@ -113,12 +134,20 @@ public class InputRecorder extends ExpressionTransformer { } } if ("onnx".equals(name)) { - if (args.size() != 1) { + if (args.size() < 1) { throw new IllegalArgumentException("expected name of ONNX model as argument: " + feature); } var arg = args.expressions().get(0); var models = context.rankProfile().onnxModels(); var model = models.get(arg.toString()); + if (model == null) { + var tmp = OnnxModelTransformer.transformFeature(feature, context.rankProfile()); + if (tmp instanceof ReferenceNode newRefNode) { + args = newRefNode.getArguments(); + arg = args.expressions().get(0); + model = models.get(arg.toString()); + } + } if (model == null) { throw new IllegalArgumentException("missing onnx model: " + arg); } diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/NormalizerFunctionExpander.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/NormalizerFunctionExpander.java new file mode 100644 index 00000000000..a8fee966656 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/NormalizerFunctionExpander.java @@ -0,0 +1,134 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.schema.expressiontransforms; + +import com.yahoo.schema.FeatureNames; +import com.yahoo.schema.RankProfile.RankFeatureNormalizer; +import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.IfNode; +import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.searchlib.rankingexpression.transform.TransformContext; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.tensor.functions.Generate; + +import java.io.StringReader; +import java.util.HashSet; +import java.util.Set; +import java.util.logging.Logger; + +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.Iterator; +import java.util.List; +import java.util.ArrayList; + +/** + * Recognizes pseudo-functions and creates global-phase normalizers + * @author arnej + */ +public class NormalizerFunctionExpander extends ExpressionTransformer { + + public final static String NORMALIZE_LINEAR = "normalize_linear"; + public final static String RECIPROCAL_RANK = "reciprocal_rank"; + public final static String RECIPROCAL_RANK_FUSION = "reciprocal_rank_fusion"; + + @Override + public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { + if (node instanceof ReferenceNode r) { + node = transformReference(r, context); + } + if (node instanceof CompositeNode composite) { + node = transformChildren(composite, context); + } + return node; + } + + private ExpressionNode transformReference(ReferenceNode node, RankProfileTransformContext context) { + Reference ref = node.reference(); + String name = ref.name(); + if (ref.output() != null) { + return node; + } + var f = context.rankProfile().getFunctions().get(name); + if (f != null) { + // never transform declared functions + return node; + } + return switch(name) { + case RECIPROCAL_RANK_FUSION -> transform(expandRRF(ref), context); + case NORMALIZE_LINEAR -> transformNormLin(ref, context); + case RECIPROCAL_RANK -> transformRRank(ref, context); + default -> node; + }; + } + + private ExpressionNode expandRRF(Reference ref) { + var args = ref.arguments(); + if (args.size() < 2) { + throw new IllegalArgumentException("must have at least 2 arguments: " + ref); + } + List children = new ArrayList<>(); + List operators = new ArrayList<>(); + for (var arg : args.expressions()) { + if (! children.isEmpty()) operators.add(Operator.plus); + children.add(new ReferenceNode(RECIPROCAL_RANK, List.of(arg), null)); + } + // must be further transformed (see above) + return new OperationNode(children, operators); + } + + private ExpressionNode transformNormLin(Reference ref, RankProfileTransformContext context) { + var args = ref.arguments(); + if (args.size() != 1) { + throw new IllegalArgumentException("must have exactly 1 argument: " + ref); + } + var input = args.expressions().get(0); + if (input instanceof ReferenceNode inputRefNode) { + var inputRef = inputRefNode.reference(); + RankFeatureNormalizer normalizer = RankFeatureNormalizer.linear(ref, inputRef); + context.rankProfile().addFeatureNormalizer(normalizer); + var newRef = Reference.fromIdentifier(normalizer.name()); + return new ReferenceNode(newRef); + } else { + throw new IllegalArgumentException("the first argument must be a simple feature: " + ref + " => " + input.getClass()); + } + } + + private ExpressionNode transformRRank(Reference ref, RankProfileTransformContext context) { + var args = ref.arguments(); + if (args.size() < 1 || args.size() > 2) { + throw new IllegalArgumentException("must have 1 or 2 arguments: " + ref); + } + double k = 60.0; + if (args.size() == 2) { + var kArg = args.expressions().get(1); + if (kArg instanceof ConstantNode kNode) { + k = kNode.getValue().asDouble(); + } else { + throw new IllegalArgumentException("the second argument (k) must be a constant in: " + ref); + } + } + var input = args.expressions().get(0); + if (input instanceof ReferenceNode inputRefNode) { + var inputRef = inputRefNode.reference(); + RankFeatureNormalizer normalizer = RankFeatureNormalizer.rrank(ref, inputRef, k); + context.rankProfile().addFeatureNormalizer(normalizer); + var newRef = Reference.fromIdentifier(normalizer.name()); + return new ReferenceNode(newRef); + } else { + throw new IllegalArgumentException("the first argument must be a simple feature: " + ref); + } + } +} diff --git a/config-model/src/test/derived/rankingexpression/rank-profiles.cfg b/config-model/src/test/derived/rankingexpression/rank-profiles.cfg index b0f7d0f2477..b3257c962dd 100644 --- a/config-model/src/test/derived/rankingexpression/rank-profiles.cfg +++ b/config-model/src/test/derived/rankingexpression/rank-profiles.cfg @@ -520,3 +520,65 @@ rankprofile[].fef.property[].name "vespa.type.attribute.t1" rankprofile[].fef.property[].value "tensor(m{},v[3])" rankprofile[].fef.property[].name "vespa.type.query.v" rankprofile[].fef.property[].value "tensor(v[3])" +rankprofile[].name "withnorm" +rankprofile[].fef.property[].name "rankingExpression(normBar).rankingScript" +rankprofile[].fef.property[].value "attribute(foo1) + attribute(year)" +rankprofile[].fef.property[].name "vespa.rank.firstphase" +rankprofile[].fef.property[].value "attribute(foo1)" +rankprofile[].fef.property[].name "vespa.rank.globalphase" +rankprofile[].fef.property[].value "rankingExpression(globalphase)" +rankprofile[].fef.property[].name "rankingExpression(globalphase).rankingScript" +rankprofile[].fef.property[].value "normalize@3551296680@linear + normalize@2879443254@rrank" +rankprofile[].fef.property[].name "vespa.match.feature" +rankprofile[].fef.property[].value "nativeRank" +rankprofile[].fef.property[].name "vespa.match.feature" +rankprofile[].fef.property[].value "attribute(year)" +rankprofile[].fef.property[].name "vespa.match.feature" +rankprofile[].fef.property[].value "attribute(foo1)" +rankprofile[].fef.property[].name "vespa.hidden.matchfeature" +rankprofile[].fef.property[].value "attribute(year)" +rankprofile[].fef.property[].name "vespa.hidden.matchfeature" +rankprofile[].fef.property[].value "attribute(foo1)" +rankprofile[].fef.property[].name "vespa.globalphase.rerankcount" +rankprofile[].fef.property[].value "123" +rankprofile[].fef.property[].name "vespa.type.attribute.t1" +rankprofile[].fef.property[].value "tensor(m{},v[3])" +rankprofile[].normalizer[].name "normalize@3551296680@linear" +rankprofile[].normalizer[].input "nativeRank" +rankprofile[].normalizer[].algo LINEAR +rankprofile[].normalizer[].kparam 0.0 +rankprofile[].normalizer[].name "normalize@2879443254@rrank" +rankprofile[].normalizer[].input "normBar" +rankprofile[].normalizer[].algo RRANK +rankprofile[].normalizer[].kparam 42.0 +rankprofile[].name "withfusion" +rankprofile[].fef.property[].name "rankingExpression(normBar).rankingScript" +rankprofile[].fef.property[].value "attribute(foo1) + attribute(year)" +rankprofile[].fef.property[].name "vespa.rank.firstphase" +rankprofile[].fef.property[].value "attribute(foo1)" +rankprofile[].fef.property[].name "vespa.rank.globalphase" +rankprofile[].fef.property[].value "rankingExpression(globalphase)" +rankprofile[].fef.property[].name "rankingExpression(globalphase).rankingScript" +rankprofile[].fef.property[].value "normalize@5385018767@rrank + normalize@3221316369@rrank" +rankprofile[].fef.property[].name "vespa.match.feature" +rankprofile[].fef.property[].value "nativeRank" +rankprofile[].fef.property[].name "vespa.match.feature" +rankprofile[].fef.property[].value "attribute(year)" +rankprofile[].fef.property[].name "vespa.match.feature" +rankprofile[].fef.property[].value "attribute(foo1)" +rankprofile[].fef.property[].name "vespa.hidden.matchfeature" +rankprofile[].fef.property[].value "attribute(year)" +rankprofile[].fef.property[].name "vespa.hidden.matchfeature" +rankprofile[].fef.property[].value "attribute(foo1)" +rankprofile[].fef.property[].name "vespa.globalphase.rerankcount" +rankprofile[].fef.property[].value "456" +rankprofile[].fef.property[].name "vespa.type.attribute.t1" +rankprofile[].fef.property[].value "tensor(m{},v[3])" +rankprofile[].normalizer[].name "normalize@5385018767@rrank" +rankprofile[].normalizer[].input "normBar" +rankprofile[].normalizer[].algo RRANK +rankprofile[].normalizer[].kparam 60.0 +rankprofile[].normalizer[].name "normalize@3221316369@rrank" +rankprofile[].normalizer[].input "nativeRank" +rankprofile[].normalizer[].algo RRANK +rankprofile[].normalizer[].kparam 60.0 diff --git a/config-model/src/test/derived/rankingexpression/rankexpression.sd b/config-model/src/test/derived/rankingexpression/rankexpression.sd index 16dff61b63a..15537f1f9d0 100644 --- a/config-model/src/test/derived/rankingexpression/rankexpression.sd +++ b/config-model/src/test/derived/rankingexpression/rankexpression.sd @@ -441,4 +441,32 @@ schema rankexpression { } } + rank-profile withnorm { + first-phase { + expression: attribute(foo1) + } + function normBar() { + expression: attribute(foo1) + attribute(year) + } + global-phase { + expression: normalize_linear(nativeRank) + reciprocal_rank(normBar(), 42.0) + rerank-count: 123 + } + match-features: nativeRank + } + + rank-profile withfusion { + first-phase { + expression: attribute(foo1) + } + function normBar() { + expression: attribute(foo1) + attribute(year) + } + global-phase { + expression: reciprocal_rank_fusion(normBar, nativeRank) + rerank-count: 456 + } + match-features: nativeRank + } + } diff --git a/config-model/src/test/java/com/yahoo/schema/NoNormalizersTestCase.java b/config-model/src/test/java/com/yahoo/schema/NoNormalizersTestCase.java new file mode 100644 index 00000000000..f1620f7415c --- /dev/null +++ b/config-model/src/test/java/com/yahoo/schema/NoNormalizersTestCase.java @@ -0,0 +1,170 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.schema; + +import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.schema.parser.ParseException; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Tests rank profiles with normalizers in bad places + * + * @author arnej + */ +public class NoNormalizersTestCase extends AbstractSchemaTestCase { + + static String wrapError(String core) { + return "Cannot use " + core + ", only valid in global-phase expression"; + } + + void compileSchema(String schema) throws ParseException { + RankProfileRegistry registry = new RankProfileRegistry(); + var qp = new QueryProfileRegistry(); + ApplicationBuilder builder = new ApplicationBuilder(registry, qp); + builder.addSchema(schema); + builder.build(true); + for (RankProfile rp : registry.all()) { + rp.compile(qp, new ImportedMlModels()); + } + } + + @Test + void requireThatNormalizerInFirstPhaseIsChecked() throws ParseException { + try { + compileSchema(""" + search test { + document test { } + rank-profile p1 { + first-phase { + expression: normalize_linear(nativeRank) + } + } + } + """); + fail(); + } catch (IllegalArgumentException e) { + assertEquals("Rank profile 'p1' is invalid", e.getMessage()); + assertEquals(wrapError("normalize_linear(nativeRank) from first-phase expression"), e.getCause().getMessage()); + } + } + + @Test + void requireThatNormalizerInSecondPhaseIsChecked() throws ParseException { + try { + compileSchema(""" + search test { + document test { + field title type string { + indexing: index + } + } + rank-profile p2 { + function foobar() { + expression: 42 + reciprocal_rank(whatever, 1.0) + } + function whatever() { + expression: fieldMatch(title) + } + first-phase { + expression: nativeRank + } + second-phase { + expression: foobar + } + } + } + """); + fail(); + } catch (IllegalArgumentException e) { + assertEquals("Rank profile 'p2' is invalid", e.getMessage()); + assertEquals(wrapError("reciprocal_rank(whatever,1.0) from second-phase expression"), e.getCause().getMessage()); + } + } + + @Test + void requireThatNormalizerInMatchFeatureIsChecked() throws ParseException { + try { + compileSchema(""" + search test { + document test { } + rank-profile p3 { + function foobar() { + expression: normalize_linear(nativeRank) + } + first-phase { + expression: nativeRank + } + match-features { + nativeRank + foobar + } + } + } + """); + fail(); + } catch (IllegalArgumentException e) { + assertEquals("Rank profile 'p3' is invalid", e.getMessage()); + assertEquals(wrapError("normalize_linear(nativeRank) from match-feature foobar"), e.getCause().getMessage()); + } + } + + @Test + void requireThatNormalizerInSummaryFeatureIsChecked() throws ParseException { + try { + compileSchema(""" + search test { + document test { } + rank-profile p4 { + function foobar() { + expression: normalize_linear(nativeRank) + } + first-phase { + expression: nativeRank + } + summary-features { + nativeRank + foobar + } + } + } + """); + fail(); + } catch (IllegalArgumentException e) { + assertEquals("Rank profile 'p4' is invalid", e.getMessage()); + assertEquals(wrapError("normalize_linear(nativeRank) from summary-feature foobar"), e.getCause().getMessage()); + } + } + + @Test + void requireThatNormalizerInNormalizerIsChecked() throws ParseException { + try { + compileSchema(""" + search test { + document test { + field title type string { + indexing: index + } + } + rank-profile p5 { + function foobar() { + expression: reciprocal_rank(nativeRank) + } + first-phase { + expression: nativeRank + } + global-phase { + expression: normalize_linear(fieldMatch(title)) + normalize_linear(foobar) + } + } + } + """); + fail(); + } catch (IllegalArgumentException e) { + assertEquals("Rank profile 'p5' is invalid", e.getMessage()); + assertEquals(wrapError("reciprocal_rank(nativeRank) from normalizer input foobar"), e.getCause().getMessage()); + } + } +} -- cgit v1.2.3