diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java | 39 |
1 files changed, 34 insertions, 5 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java index b30873fabee..ab18f9c83db 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema.expressiontransforms; import com.yahoo.schema.FeatureNames; @@ -14,6 +14,7 @@ import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; import com.yahoo.tensor.functions.Generate; import java.io.StringReader; +import java.util.Collection; import java.util.HashSet; import java.util.Set; import java.util.logging.Logger; @@ -29,19 +30,35 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> { private final Set<String> neededInputs; private final Set<String> handled = new HashSet<>(); + private final Set<String> availableNormalizers = new HashSet<>(); + private final Set<String> usedNormalizers = new HashSet<>(); public InputRecorder(Set<String> target) { this.neededInputs = target; } public void process(RankingExpression expression, RankProfileTransformContext context) { - transform(expression.getRoot(), new InputRecorderContext(context)); + process(expression.getRoot(), context); } - public void alreadyHandled(String name) { - handled.add(name); + public void process(ExpressionNode node, RankProfileTransformContext context) { + transform(node, new InputRecorderContext(context)); } + public void alreadyMatchFeatures(Collection<String> matchFeatures) { + for (String mf : matchFeatures) { + handled.add(mf); + } + } + + public void addKnownNormalizers(Collection<String> names) { + for (String name : names) { + availableNormalizers.add(name); + } + } + + public Set<String> normalizersUsed() { return this.usedNormalizers; } + @Override public ExpressionNode transform(ExpressionNode node, InputRecorderContext context) { if (node instanceof ReferenceNode r) { @@ -77,6 +94,10 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> { if (simpleFunctionOrIdentifier && context.localVariables().contains(name)) { return; } + if (simpleFunctionOrIdentifier && availableNormalizers.contains(name)) { + usedNormalizers.add(name); + return; + } if (ref.isSimpleRankingExpressionWrapper()) { name = ref.simpleArgument().get(); simpleFunctionOrIdentifier = true; @@ -113,13 +134,21 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> { } } if ("onnx".equals(name)) { - if (args.size() != 1) { + if (args.size() < 1) { throw new IllegalArgumentException("expected name of ONNX model as argument: " + feature); } var arg = args.expressions().get(0); var models = context.rankProfile().onnxModels(); var model = models.get(arg.toString()); if (model == null) { + var tmp = OnnxModelTransformer.transformFeature(feature, context.rankProfile()); + if (tmp instanceof ReferenceNode newRefNode) { + args = newRefNode.getArguments(); + arg = args.expressions().get(0); + model = models.get(arg.toString()); + } + } + if (model == null) { throw new IllegalArgumentException("missing onnx model: " + arg); } model.getInputMap().forEach((__, onnxInput) -> { |