aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java25
1 files changed, 21 insertions, 4 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
index 3774e64c886..7fad077ceb2 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
@@ -3,11 +3,14 @@ package ai.vespa.rankingexpression.importer;
import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import ai.vespa.rankingexpression.importer.operations.Constant;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.TensorFunction;
@@ -15,9 +18,11 @@ import com.yahoo.text.ExpressionFormatter;
import com.yahoo.yolean.Exceptions;
import java.io.File;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
@@ -122,8 +127,16 @@ public abstract class ModelImporter implements MlModelImporter {
return operation.function();
}
+ private static boolean isImported(IntermediateOperation operation, ImportedModel model) {
+ return model.expressions().containsKey(operation.name()); // test for others?
+ }
+
private static void importExpressionInputs(IntermediateOperation operation, ImportedModel model) {
- operation.inputs().forEach(input -> importExpression(input, model));
+ operation.inputs().forEach(input -> {
+ if ( ! isImported(operation, model)) {
+ importExpression(input, model);
+ }
+ });
}
private static Optional<TensorFunction> importConstant(IntermediateOperation operation, ImportedModel model) {
@@ -206,18 +219,22 @@ public abstract class ModelImporter implements MlModelImporter {
private static void reportWarnings(IntermediateGraph graph, ImportedModel model) {
for (ImportedModel.Signature signature : model.signatures().values()) {
for (String outputName : signature.outputs().values()) {
- reportWarnings(graph.get(outputName), model);
+ reportWarnings(graph.get(outputName), model, new HashSet<String>());
}
}
}
- private static void reportWarnings(IntermediateOperation operation, ImportedModel model) {
+ private static void reportWarnings(IntermediateOperation operation, ImportedModel model, Set<String> reported) {
+ if (reported.contains(operation.name())) {
+ return;
+ }
for (String warning : operation.warnings()) {
// If we want to report warnings, that code goes here
}
for (IntermediateOperation input : operation.inputs()) {
- reportWarnings(input, model);
+ reportWarnings(input, model, reported);
}
+ reported.add(operation.name());
}
/**