summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java14
1 files changed, 9 insertions, 5 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
index af072c5b59a..7f578f07fe3 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
@@ -2,7 +2,6 @@
package com.yahoo.schema.expressiontransforms;
import com.yahoo.schema.FeatureNames;
-import com.yahoo.schema.RankProfile;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
@@ -12,13 +11,12 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
-import com.yahoo.tensor.functions.DynamicTensor;
import com.yahoo.tensor.functions.Generate;
-import com.yahoo.tensor.functions.Slice;
import java.io.StringReader;
import java.util.HashSet;
import java.util.Set;
+import java.util.logging.Logger;
/**
* Analyzes expression to figure out what inputs it needs
@@ -27,6 +25,8 @@ import java.util.Set;
*/
public class InputRecorder extends ExpressionTransformer<InputRecorderContext> {
+ private static final Logger log = Logger.getLogger(InputRecorder.class.getName());
+
private final Set<String> neededInputs;
private final Set<String> handled = new HashSet<>();
@@ -120,7 +120,11 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> {
if (model == null) {
throw new IllegalArgumentException("missing onnx model: " + arg);
}
- for (String onnxInput : model.getInputMap().values()) {
+ model.getInputMap().forEach((onnxName, onnxInput) -> {
+ if (model.getInitializers().contains(onnxName)) {
+ log.fine(() -> "For input '%s': skipping name '%s' as it's an initializer".formatted(onnxInput, onnxName));
+ return;
+ }
var reader = new StringReader(onnxInput);
try {
var asExpression = new RankingExpression(reader);
@@ -128,7 +132,7 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> {
} catch (ParseException e) {
throw new IllegalArgumentException("illegal onnx input '" + onnxInput + "': " + e.getMessage());
}
- }
+ });
return;
}
neededInputs.add(feature.toString());