aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java56
1 files changed, 45 insertions, 11 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index ffc64c38f16..c98a5c7d4f5 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -2,7 +2,6 @@
package ai.vespa.rankingexpression.importer.onnx;
-import ai.vespa.rankingexpression.importer.operations.ExpandDims;
import ai.vespa.rankingexpression.importer.operations.Gather;
import ai.vespa.rankingexpression.importer.operations.OnnxCast;
import ai.vespa.rankingexpression.importer.operations.Gemm;
@@ -12,7 +11,10 @@ import ai.vespa.rankingexpression.importer.operations.Reduce;
import ai.vespa.rankingexpression.importer.operations.Select;
import ai.vespa.rankingexpression.importer.operations.Slice;
import ai.vespa.rankingexpression.importer.operations.Softmax;
+import ai.vespa.rankingexpression.importer.operations.Split;
import ai.vespa.rankingexpression.importer.operations.Squeeze;
+import ai.vespa.rankingexpression.importer.operations.Tile;
+import ai.vespa.rankingexpression.importer.operations.Transpose;
import ai.vespa.rankingexpression.importer.operations.Unsqueeze;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
@@ -32,6 +34,8 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.functions.ScalarFunctions;
import onnx.Onnx;
+import java.util.Collection;
+import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
@@ -53,19 +57,21 @@ class GraphImporter {
private static IntermediateOperation mapOperation(Onnx.NodeProto node,
List<IntermediateOperation> inputs,
- IntermediateGraph graph) {
+ IntermediateGraph graph,
+ int outputIndex) {
String type = node.getOpType();
String modelName = graph.name();
String nodeName = getNodeName(node);
AttributeConverter attributes = AttributeConverter.convert(node);
- return mapOperation(type, inputs, modelName, nodeName, attributes);
+ return mapOperation(type, inputs, modelName, nodeName, attributes, outputIndex);
}
static IntermediateOperation mapOperation(String opType,
List<IntermediateOperation> inputs,
String modelName,
String nodeName,
- AttributeConverter attributes) {
+ AttributeConverter attributes,
+ int outputIndex) {
switch (opType.toLowerCase()) {
case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs());
case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
@@ -115,17 +121,21 @@ class GraphImporter {
case "slice": return new Slice(modelName, nodeName, inputs, attributes);
case "softmax": return new Softmax(modelName, nodeName, inputs, attributes);
case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
+ case "split": return new Split(modelName, nodeName, inputs, attributes, outputIndex);
case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes);
case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt());
case "square": return new Map(modelName, nodeName, inputs, ScalarFunctions.square());
case "where": return new Select(modelName, nodeName, inputs);
case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan());
case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh());
+ case "tile": return new Tile(modelName, nodeName, inputs);
+ case "transpose": return new Transpose(modelName, nodeName, inputs, attributes);
case "unsqueeze": return new Unsqueeze(modelName, nodeName, inputs, attributes);
}
IntermediateOperation op = new NoOp(modelName, nodeName, inputs);
op.warning("Operation '" + opType + "' is currently not implemented");
+ System.out.println(nodeName + ": operation '" + opType + "' is currently not implemented");
return op;
}
@@ -133,10 +143,15 @@ class GraphImporter {
Onnx.GraphProto onnxGraph = model.getGraph();
IntermediateGraph intermediateGraph = new IntermediateGraph(modelName);
+ System.out.println("Importing operations...");
importOperations(onnxGraph, intermediateGraph);
+ System.out.println("Verifying no warnings...");
verifyNoWarnings(intermediateGraph);
+ System.out.println("Verifying output types...");
verifyOutputTypes(onnxGraph, intermediateGraph);
+ System.out.println("Ok...");
+
return intermediateGraph;
}
@@ -150,8 +165,10 @@ class GraphImporter {
Onnx.GraphProto onnxGraph,
IntermediateGraph intermediateGraph) {
if (intermediateGraph.alreadyImported(name)) {
+// System.out.println("Trying to import '" + name + "' but is was already imported.");
return intermediateGraph.get(name);
}
+// System.out.println("Importing '" + name + "' ...");
IntermediateOperation operation;
if (isArgumentTensor(name, onnxGraph)) {
Onnx.ValueInfoProto valueInfoProto = getArgumentTensor(name, onnxGraph);
@@ -163,16 +180,21 @@ class GraphImporter {
intermediateGraph.inputs(intermediateGraph.defaultSignature())
.put(IntermediateOperation.namePartOf(name), operation.vespaName());
+// System.out.println(" '" + name + "' imported as argument...");
+
} else if (isConstantTensor(name, onnxGraph)) {
Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph);
OrderedTensorType defaultType = TypeConverter.typeFrom(tensorProto);
operation = new Constant(intermediateGraph.name(), name, defaultType);
operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type)));
+// System.out.println(" '" + name + "' imported as constant...");
+
} else {
Onnx.NodeProto node = getNodeFromGraph(name, onnxGraph);
+ int outputIndex = getOutputIndex(node, name);
List<IntermediateOperation> inputs = importOperationInputs(node, onnxGraph, intermediateGraph);
- operation = mapOperation(node, inputs, intermediateGraph);
+ operation = mapOperation(node, inputs, intermediateGraph, outputIndex);
// propagate constant values if all inputs are constant
if (operation.isConstant()) {
@@ -183,8 +205,12 @@ class GraphImporter {
intermediateGraph.outputs(intermediateGraph.defaultSignature())
.put(IntermediateOperation.namePartOf(name), operation.name());
}
+
+// System.out.println(" '" + name + "' imported as normal...");
+
}
intermediateGraph.put(operation.name(), operation);
+ intermediateGraph.put(name, operation);
return operation;
}
@@ -262,7 +288,8 @@ class GraphImporter {
Onnx.ValueInfoProto onnxNode = getOutputNode(output.getKey(), onnxGraph);
OrderedTensorType type = operation.type().orElseThrow(
() -> new IllegalArgumentException("Output of '" + output.getValue() + "' has no type."));
- TypeConverter.verifyType(onnxNode.getType(), type);
+ System.out.println(onnxNode.getType() + " vs. " + type);
+ //TypeConverter.verifyType(onnxNode.getType(), type);
}
}
@@ -296,6 +323,10 @@ class GraphImporter {
return graph.getNodeList().stream().filter(node -> node.getName().equals(nodeName)).findFirst();
}
+ private static int getOutputIndex(Onnx.NodeProto node, String outputName) {
+ return node.getOutputCount() == 0 ? 0 : Math.max(node.getOutputList().indexOf(outputName), 0);
+ }
+
private static String getNodeName(Onnx.NodeProto node) {
String nodeName = node.getName();
if (nodeName.length() > 0)
@@ -307,11 +338,14 @@ class GraphImporter {
}
private static Set<String> getWarnings(IntermediateOperation op) {
- Set<String> warnings = new HashSet<>(op.warnings());
- for (IntermediateOperation input : op.inputs()) {
- warnings.addAll(getWarnings(input));
- }
- return warnings;
+ java.util.Map<String, Set<String>> warnings = new HashMap<>();
+ getWarnings(op, warnings);
+ return warnings.values().stream().flatMap(Collection::stream).collect(Collectors.toSet());
}
+ private static void getWarnings(IntermediateOperation op, java.util.Map<String, Set<String>> warnings) {
+ if (warnings.containsKey(op.name())) return;
+ op.inputs().forEach(input -> getWarnings(input, warnings));
+ warnings.put(op.name(), new HashSet<>(op.warnings()));
+ }
}