summaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java40
1 files changed, 28 insertions, 12 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index ffc64c38f16..d14ad033a69 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -2,7 +2,6 @@
package ai.vespa.rankingexpression.importer.onnx;
-import ai.vespa.rankingexpression.importer.operations.ExpandDims;
import ai.vespa.rankingexpression.importer.operations.Gather;
import ai.vespa.rankingexpression.importer.operations.OnnxCast;
import ai.vespa.rankingexpression.importer.operations.Gemm;
@@ -12,7 +11,10 @@ import ai.vespa.rankingexpression.importer.operations.Reduce;
import ai.vespa.rankingexpression.importer.operations.Select;
import ai.vespa.rankingexpression.importer.operations.Slice;
import ai.vespa.rankingexpression.importer.operations.Softmax;
+import ai.vespa.rankingexpression.importer.operations.Split;
import ai.vespa.rankingexpression.importer.operations.Squeeze;
+import ai.vespa.rankingexpression.importer.operations.Tile;
+import ai.vespa.rankingexpression.importer.operations.Transpose;
import ai.vespa.rankingexpression.importer.operations.Unsqueeze;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
@@ -32,6 +34,8 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.functions.ScalarFunctions;
import onnx.Onnx;
+import java.util.Collection;
+import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
@@ -53,19 +57,21 @@ class GraphImporter {
private static IntermediateOperation mapOperation(Onnx.NodeProto node,
List<IntermediateOperation> inputs,
- IntermediateGraph graph) {
+ IntermediateGraph graph,
+ int outputIndex) {
String type = node.getOpType();
String modelName = graph.name();
String nodeName = getNodeName(node);
AttributeConverter attributes = AttributeConverter.convert(node);
- return mapOperation(type, inputs, modelName, nodeName, attributes);
+ return mapOperation(type, inputs, modelName, nodeName, attributes, outputIndex);
}
static IntermediateOperation mapOperation(String opType,
List<IntermediateOperation> inputs,
String modelName,
String nodeName,
- AttributeConverter attributes) {
+ AttributeConverter attributes,
+ int outputIndex) {
switch (opType.toLowerCase()) {
case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs());
case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
@@ -115,12 +121,15 @@ class GraphImporter {
case "slice": return new Slice(modelName, nodeName, inputs, attributes);
case "softmax": return new Softmax(modelName, nodeName, inputs, attributes);
case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
+ case "split": return new Split(modelName, nodeName, inputs, attributes, outputIndex);
case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes);
case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt());
case "square": return new Map(modelName, nodeName, inputs, ScalarFunctions.square());
case "where": return new Select(modelName, nodeName, inputs);
case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan());
case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh());
+ case "tile": return new Tile(modelName, nodeName, inputs);
+ case "transpose": return new Transpose(modelName, nodeName, inputs, attributes);
case "unsqueeze": return new Unsqueeze(modelName, nodeName, inputs, attributes);
}
@@ -168,11 +177,11 @@ class GraphImporter {
OrderedTensorType defaultType = TypeConverter.typeFrom(tensorProto);
operation = new Constant(intermediateGraph.name(), name, defaultType);
operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type)));
-
} else {
Onnx.NodeProto node = getNodeFromGraph(name, onnxGraph);
+ int outputIndex = getOutputIndex(node, name);
List<IntermediateOperation> inputs = importOperationInputs(node, onnxGraph, intermediateGraph);
- operation = mapOperation(node, inputs, intermediateGraph);
+ operation = mapOperation(node, inputs, intermediateGraph, outputIndex);
// propagate constant values if all inputs are constant
if (operation.isConstant()) {
@@ -185,7 +194,7 @@ class GraphImporter {
}
}
intermediateGraph.put(operation.name(), operation);
-
+ intermediateGraph.put(name, operation);
return operation;
}
@@ -296,6 +305,10 @@ class GraphImporter {
return graph.getNodeList().stream().filter(node -> node.getName().equals(nodeName)).findFirst();
}
+ private static int getOutputIndex(Onnx.NodeProto node, String outputName) {
+ return node.getOutputCount() == 0 ? 0 : Math.max(node.getOutputList().indexOf(outputName), 0);
+ }
+
private static String getNodeName(Onnx.NodeProto node) {
String nodeName = node.getName();
if (nodeName.length() > 0)
@@ -307,11 +320,14 @@ class GraphImporter {
}
private static Set<String> getWarnings(IntermediateOperation op) {
- Set<String> warnings = new HashSet<>(op.warnings());
- for (IntermediateOperation input : op.inputs()) {
- warnings.addAll(getWarnings(input));
- }
- return warnings;
+ java.util.Map<String, Set<String>> warnings = new HashMap<>();
+ getWarnings(op, warnings);
+ return warnings.values().stream().flatMap(Collection::stream).collect(Collectors.toSet());
}
+ private static void getWarnings(IntermediateOperation op, java.util.Map<String, Set<String>> warnings) {
+ if (warnings.containsKey(op.name())) return;
+ op.inputs().forEach(input -> getWarnings(input, warnings));
+ warnings.put(op.name(), new HashSet<>(op.warnings()));
+ }
}