diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema/expressiontransforms')
15 files changed, 183 insertions, 19 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java index 3d5d29fb3c7..9aaa8e38b80 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema.expressiontransforms; import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ConstantTensorTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ConstantTensorTransformer.java index a9eea3d2ead..1c7843113b3 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ConstantTensorTransformer.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ConstantTensorTransformer.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema.expressiontransforms; import com.yahoo.schema.FeatureNames; diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ExpressionTransforms.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ExpressionTransforms.java index 42e99f6aa45..42c8147b3dc 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ExpressionTransforms.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ExpressionTransforms.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema.expressiontransforms; import com.yahoo.searchlib.rankingexpression.RankingExpression; @@ -35,7 +35,8 @@ public class ExpressionTransforms { new FunctionShadower(), new TensorMaxMinTransformer(), new Simplifier(), - new BooleanExpressionTransformer()); + new BooleanExpressionTransformer(), + new NormalizerFunctionExpander()); } public RankingExpression transform(RankingExpression expression, RankProfileTransformContext context) { diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/FunctionInliner.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/FunctionInliner.java index 382d51747bb..2d64073a46c 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/FunctionInliner.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/FunctionInliner.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema.expressiontransforms; import com.yahoo.schema.RankProfile; diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/FunctionShadower.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/FunctionShadower.java index 702e4ea220e..f3bfdf4cf04 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/FunctionShadower.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/FunctionShadower.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema.expressiontransforms; import com.yahoo.schema.RankProfile; diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java index b30873fabee..ab18f9c83db 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema.expressiontransforms; import com.yahoo.schema.FeatureNames; @@ -14,6 +14,7 @@ import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; import com.yahoo.tensor.functions.Generate; import java.io.StringReader; +import java.util.Collection; import java.util.HashSet; import java.util.Set; import java.util.logging.Logger; @@ -29,19 +30,35 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> { private final Set<String> neededInputs; private final Set<String> handled = new HashSet<>(); + private final Set<String> availableNormalizers = new HashSet<>(); + private final Set<String> usedNormalizers = new HashSet<>(); public InputRecorder(Set<String> target) { this.neededInputs = target; } public void process(RankingExpression expression, RankProfileTransformContext context) { - transform(expression.getRoot(), new InputRecorderContext(context)); + process(expression.getRoot(), context); } - public void alreadyHandled(String name) { - handled.add(name); + public void process(ExpressionNode node, RankProfileTransformContext context) { + transform(node, new InputRecorderContext(context)); } + public void alreadyMatchFeatures(Collection<String> matchFeatures) { + for (String mf : matchFeatures) { + handled.add(mf); + } + } + + public void addKnownNormalizers(Collection<String> names) { + for (String name : names) { + availableNormalizers.add(name); + } + } + + public Set<String> normalizersUsed() { return this.usedNormalizers; } + @Override public ExpressionNode transform(ExpressionNode node, InputRecorderContext context) { if (node instanceof ReferenceNode r) { @@ -77,6 +94,10 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> { if (simpleFunctionOrIdentifier && context.localVariables().contains(name)) { return; } + if (simpleFunctionOrIdentifier && availableNormalizers.contains(name)) { + usedNormalizers.add(name); + return; + } if (ref.isSimpleRankingExpressionWrapper()) { name = ref.simpleArgument().get(); simpleFunctionOrIdentifier = true; @@ -113,13 +134,21 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> { } } if ("onnx".equals(name)) { - if (args.size() != 1) { + if (args.size() < 1) { throw new IllegalArgumentException("expected name of ONNX model as argument: " + feature); } var arg = args.expressions().get(0); var models = context.rankProfile().onnxModels(); var model = models.get(arg.toString()); if (model == null) { + var tmp = OnnxModelTransformer.transformFeature(feature, context.rankProfile()); + if (tmp instanceof ReferenceNode newRefNode) { + args = newRefNode.getArguments(); + arg = args.expressions().get(0); + model = models.get(arg.toString()); + } + } + if (model == null) { throw new IllegalArgumentException("missing onnx model: " + arg); } model.getInputMap().forEach((__, onnxInput) -> { diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorderContext.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorderContext.java index 54617374b67..5fc51f76669 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorderContext.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorderContext.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema.expressiontransforms; import com.yahoo.schema.RankProfile; diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/LightGBMFeatureConverter.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/LightGBMFeatureConverter.java index af5fa5ebeab..24e0252555c 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/LightGBMFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/LightGBMFeatureConverter.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema.expressiontransforms; import com.yahoo.path.Path; diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/NormalizerFunctionExpander.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/NormalizerFunctionExpander.java new file mode 100644 index 00000000000..a8fee966656 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/NormalizerFunctionExpander.java @@ -0,0 +1,134 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.schema.expressiontransforms; + +import com.yahoo.schema.FeatureNames; +import com.yahoo.schema.RankProfile.RankFeatureNormalizer; +import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.IfNode; +import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.searchlib.rankingexpression.transform.TransformContext; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.tensor.functions.Generate; + +import java.io.StringReader; +import java.util.HashSet; +import java.util.Set; +import java.util.logging.Logger; + +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.Iterator; +import java.util.List; +import java.util.ArrayList; + +/** + * Recognizes pseudo-functions and creates global-phase normalizers + * @author arnej + */ +public class NormalizerFunctionExpander extends ExpressionTransformer<RankProfileTransformContext> { + + public final static String NORMALIZE_LINEAR = "normalize_linear"; + public final static String RECIPROCAL_RANK = "reciprocal_rank"; + public final static String RECIPROCAL_RANK_FUSION = "reciprocal_rank_fusion"; + + @Override + public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { + if (node instanceof ReferenceNode r) { + node = transformReference(r, context); + } + if (node instanceof CompositeNode composite) { + node = transformChildren(composite, context); + } + return node; + } + + private ExpressionNode transformReference(ReferenceNode node, RankProfileTransformContext context) { + Reference ref = node.reference(); + String name = ref.name(); + if (ref.output() != null) { + return node; + } + var f = context.rankProfile().getFunctions().get(name); + if (f != null) { + // never transform declared functions + return node; + } + return switch(name) { + case RECIPROCAL_RANK_FUSION -> transform(expandRRF(ref), context); + case NORMALIZE_LINEAR -> transformNormLin(ref, context); + case RECIPROCAL_RANK -> transformRRank(ref, context); + default -> node; + }; + } + + private ExpressionNode expandRRF(Reference ref) { + var args = ref.arguments(); + if (args.size() < 2) { + throw new IllegalArgumentException("must have at least 2 arguments: " + ref); + } + List<ExpressionNode> children = new ArrayList<>(); + List<Operator> operators = new ArrayList<>(); + for (var arg : args.expressions()) { + if (! children.isEmpty()) operators.add(Operator.plus); + children.add(new ReferenceNode(RECIPROCAL_RANK, List.of(arg), null)); + } + // must be further transformed (see above) + return new OperationNode(children, operators); + } + + private ExpressionNode transformNormLin(Reference ref, RankProfileTransformContext context) { + var args = ref.arguments(); + if (args.size() != 1) { + throw new IllegalArgumentException("must have exactly 1 argument: " + ref); + } + var input = args.expressions().get(0); + if (input instanceof ReferenceNode inputRefNode) { + var inputRef = inputRefNode.reference(); + RankFeatureNormalizer normalizer = RankFeatureNormalizer.linear(ref, inputRef); + context.rankProfile().addFeatureNormalizer(normalizer); + var newRef = Reference.fromIdentifier(normalizer.name()); + return new ReferenceNode(newRef); + } else { + throw new IllegalArgumentException("the first argument must be a simple feature: " + ref + " => " + input.getClass()); + } + } + + private ExpressionNode transformRRank(Reference ref, RankProfileTransformContext context) { + var args = ref.arguments(); + if (args.size() < 1 || args.size() > 2) { + throw new IllegalArgumentException("must have 1 or 2 arguments: " + ref); + } + double k = 60.0; + if (args.size() == 2) { + var kArg = args.expressions().get(1); + if (kArg instanceof ConstantNode kNode) { + k = kNode.getValue().asDouble(); + } else { + throw new IllegalArgumentException("the second argument (k) must be a constant in: " + ref); + } + } + var input = args.expressions().get(0); + if (input instanceof ReferenceNode inputRefNode) { + var inputRef = inputRefNode.reference(); + RankFeatureNormalizer normalizer = RankFeatureNormalizer.rrank(ref, inputRef, k); + context.rankProfile().addFeatureNormalizer(normalizer); + var newRef = Reference.fromIdentifier(normalizer.name()); + return new ReferenceNode(newRef); + } else { + throw new IllegalArgumentException("the first argument must be a simple feature: " + ref); + } + } +} diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxFeatureConverter.java index 2277491cd47..ff9367c36a9 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxFeatureConverter.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema.expressiontransforms; import com.yahoo.path.Path; diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxModelTransformer.java index 8797deefcb6..4af4b31a2f8 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxModelTransformer.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxModelTransformer.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema.expressiontransforms; import com.yahoo.path.Path; diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/RankProfileTransformContext.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/RankProfileTransformContext.java index 890fa5e7a10..43f02ab238b 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/RankProfileTransformContext.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/RankProfileTransformContext.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema.expressiontransforms; import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels; diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TensorFlowFeatureConverter.java index e2bef0063ea..186452dd3b6 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TensorFlowFeatureConverter.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema.expressiontransforms; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java index 04a31a47190..c044b91bb7e 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema.expressiontransforms; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/XgboostFeatureConverter.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/XgboostFeatureConverter.java index b05f9ba9166..e71d97b4fcc 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/XgboostFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/XgboostFeatureConverter.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema.expressiontransforms; import com.yahoo.path.Path; |