aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java39
1 files changed, 34 insertions, 5 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
index b30873fabee..ab18f9c83db 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
@@ -1,4 +1,4 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.schema.expressiontransforms;
import com.yahoo.schema.FeatureNames;
@@ -14,6 +14,7 @@ import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.tensor.functions.Generate;
import java.io.StringReader;
+import java.util.Collection;
import java.util.HashSet;
import java.util.Set;
import java.util.logging.Logger;
@@ -29,19 +30,35 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> {
private final Set<String> neededInputs;
private final Set<String> handled = new HashSet<>();
+ private final Set<String> availableNormalizers = new HashSet<>();
+ private final Set<String> usedNormalizers = new HashSet<>();
public InputRecorder(Set<String> target) {
this.neededInputs = target;
}
public void process(RankingExpression expression, RankProfileTransformContext context) {
- transform(expression.getRoot(), new InputRecorderContext(context));
+ process(expression.getRoot(), context);
}
- public void alreadyHandled(String name) {
- handled.add(name);
+ public void process(ExpressionNode node, RankProfileTransformContext context) {
+ transform(node, new InputRecorderContext(context));
}
+ public void alreadyMatchFeatures(Collection<String> matchFeatures) {
+ for (String mf : matchFeatures) {
+ handled.add(mf);
+ }
+ }
+
+ public void addKnownNormalizers(Collection<String> names) {
+ for (String name : names) {
+ availableNormalizers.add(name);
+ }
+ }
+
+ public Set<String> normalizersUsed() { return this.usedNormalizers; }
+
@Override
public ExpressionNode transform(ExpressionNode node, InputRecorderContext context) {
if (node instanceof ReferenceNode r) {
@@ -77,6 +94,10 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> {
if (simpleFunctionOrIdentifier && context.localVariables().contains(name)) {
return;
}
+ if (simpleFunctionOrIdentifier && availableNormalizers.contains(name)) {
+ usedNormalizers.add(name);
+ return;
+ }
if (ref.isSimpleRankingExpressionWrapper()) {
name = ref.simpleArgument().get();
simpleFunctionOrIdentifier = true;
@@ -113,13 +134,21 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> {
}
}
if ("onnx".equals(name)) {
- if (args.size() != 1) {
+ if (args.size() < 1) {
throw new IllegalArgumentException("expected name of ONNX model as argument: " + feature);
}
var arg = args.expressions().get(0);
var models = context.rankProfile().onnxModels();
var model = models.get(arg.toString());
if (model == null) {
+ var tmp = OnnxModelTransformer.transformFeature(feature, context.rankProfile());
+ if (tmp instanceof ReferenceNode newRefNode) {
+ args = newRefNode.getArguments();
+ arg = args.expressions().get(0);
+ model = models.get(arg.toString());
+ }
+ }
+ if (model == null) {
throw new IllegalArgumentException("missing onnx model: " + arg);
}
model.getInputMap().forEach((__, onnxInput) -> {