diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java | 25 |
1 files changed, 21 insertions, 4 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java index 3774e64c886..7fad077ceb2 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java @@ -3,11 +3,14 @@ package ai.vespa.rankingexpression.importer; import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import ai.vespa.rankingexpression.importer.operations.Constant; import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.TensorFunction; @@ -15,9 +18,11 @@ import com.yahoo.text.ExpressionFormatter; import com.yahoo.yolean.Exceptions; import java.io.File; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.logging.Level; import java.util.logging.Logger; @@ -122,8 +127,16 @@ public abstract class ModelImporter implements MlModelImporter { return operation.function(); } + private static boolean isImported(IntermediateOperation operation, ImportedModel model) { + return model.expressions().containsKey(operation.name()); // test for others? + } + private static void importExpressionInputs(IntermediateOperation operation, ImportedModel model) { - operation.inputs().forEach(input -> importExpression(input, model)); + operation.inputs().forEach(input -> { + if ( ! isImported(operation, model)) { + importExpression(input, model); + } + }); } private static Optional<TensorFunction> importConstant(IntermediateOperation operation, ImportedModel model) { @@ -206,18 +219,22 @@ public abstract class ModelImporter implements MlModelImporter { private static void reportWarnings(IntermediateGraph graph, ImportedModel model) { for (ImportedModel.Signature signature : model.signatures().values()) { for (String outputName : signature.outputs().values()) { - reportWarnings(graph.get(outputName), model); + reportWarnings(graph.get(outputName), model, new HashSet<String>()); } } } - private static void reportWarnings(IntermediateOperation operation, ImportedModel model) { + private static void reportWarnings(IntermediateOperation operation, ImportedModel model, Set<String> reported) { + if (reported.contains(operation.name())) { + return; + } for (String warning : operation.warnings()) { // If we want to report warnings, that code goes here } for (IntermediateOperation input : operation.inputs()) { - reportWarnings(input, model); + reportWarnings(input, model, reported); } + reported.add(operation.name()); } /** |