aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java91
1 files changed, 54 insertions, 37 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index 714953fbd45..280fe354149 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -2,11 +2,16 @@
package ai.vespa.rankingexpression.importer.onnx;
+import ai.vespa.rankingexpression.importer.operations.Gemm;
+import ai.vespa.rankingexpression.importer.operations.OnnxConcat;
+import ai.vespa.rankingexpression.importer.operations.Reduce;
+import ai.vespa.rankingexpression.importer.operations.Select;
+import ai.vespa.rankingexpression.importer.operations.Softmax;
+import ai.vespa.rankingexpression.importer.operations.Squeeze;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.operations.Argument;
-import ai.vespa.rankingexpression.importer.operations.ConcatV2;
import ai.vespa.rankingexpression.importer.operations.Constant;
import ai.vespa.rankingexpression.importer.operations.Identity;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
@@ -36,6 +41,7 @@ class GraphImporter {
IntermediateGraph graph) {
String modelName = graph.name();
String nodeName = getNodeName(node);
+ AttributeConverter attributes = AttributeConverter.convert(node);
switch (node.getOpType().toLowerCase()) {
case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs());
@@ -44,13 +50,14 @@ class GraphImporter {
case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin());
case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan());
case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil());
- case "concat": return new ConcatV2(modelName, nodeName, inputs);
+ case "concat": return new OnnxConcat(modelName, nodeName, inputs, attributes);
case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos());
case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal());
case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp());
case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor());
+ case "gemm": return new Gemm(modelName, nodeName, inputs, attributes);
case "greater": return new Join(modelName, nodeName, inputs, ScalarFunctions.greater());
case "identity": return new Identity(modelName, nodeName, inputs);
case "less": return new Join(modelName, nodeName, inputs, ScalarFunctions.less());
@@ -63,15 +70,21 @@ class GraphImporter {
case "neg": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg());
case "pow": return new Join(modelName, nodeName, inputs, ScalarFunctions.pow());
case "reshape": return new Reshape(modelName, nodeName, inputs);
+ case "reducesum": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum);
+ case "reducemean": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.avg);
case "reciprocal": return new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal());
case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
+ case "leakyrelu": return new Map(modelName, nodeName, inputs, ScalarFunctions.leakyrelu());
case "shape": return new Shape(modelName, nodeName, inputs);
- case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin());
- case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt());
case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid());
+ case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin());
+ case "softmax": return new Softmax(modelName, nodeName, inputs);
case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
+ case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes);
+ case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt());
case "square": return new Map(modelName, nodeName, inputs, ScalarFunctions.square());
+ case "where": return new Select(modelName, nodeName, inputs);
case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan());
case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh());
}
@@ -125,16 +138,25 @@ class GraphImporter {
List<IntermediateOperation> inputs = importOperationInputs(node, onnxGraph, intermediateGraph);
operation = mapOperation(node, inputs, intermediateGraph);
+ // propagate constant values if all inputs are constant
+ if (operation.isConstant()) {
+ operation.setConstantValueFunction(operation::evaluateAsConstant);
+ }
+
if (isOutputNode(name, onnxGraph)) {
intermediateGraph.outputs(intermediateGraph.defaultSignature())
- .put(IntermediateOperation.namePartOf(name), operation.vespaName());
+ .put(IntermediateOperation.namePartOf(name), operation.name());
}
}
- intermediateGraph.put(operation.vespaName(), operation);
+ intermediateGraph.put(operation.name(), operation);
return operation;
}
+ // Rules for initializers in ONNX:
+ // When an initializer has the same name as a graph input, it specifies a default value for that input.
+ // When an initializer has a name different from all graph inputs, it specifies a constant value.
+
private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) {
Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
Onnx.TensorProto tensor = getConstantTensor(name, graph);
@@ -142,9 +164,7 @@ class GraphImporter {
}
private static boolean isConstantTensor(String name, Onnx.GraphProto graph) {
- Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
- Onnx.TensorProto tensor = getConstantTensor(name, graph);
- return value != null && tensor != null;
+ return getConstantTensor(name, graph) != null;
}
private static Onnx.ValueInfoProto getArgumentTensor(String name, Onnx.GraphProto graph) {
@@ -191,46 +211,43 @@ class GraphImporter {
}
private static void verifyOutputTypes(Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) {
- for (String outputName : intermediateGraph.outputs(intermediateGraph.defaultSignature()).values()) {
- IntermediateOperation operation = intermediateGraph.get(outputName);
- Onnx.ValueInfoProto onnxNode = getOutputNode(outputName, onnxGraph);
+ for (java.util.Map.Entry<String, String> output : intermediateGraph.outputs(intermediateGraph.defaultSignature()).entrySet()) {
+ IntermediateOperation operation = intermediateGraph.get(output.getValue());
+ Onnx.ValueInfoProto onnxNode = getOutputNode(output.getKey(), onnxGraph);
OrderedTensorType type = operation.type().orElseThrow(
- () -> new IllegalArgumentException("Output of '" + outputName + "' has no type."));
+ () -> new IllegalArgumentException("Output of '" + output.getValue() + "' has no type."));
TypeConverter.verifyType(onnxNode.getType(), type);
}
}
private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) {
- Optional<Onnx.NodeProto> node;
- if (nodeName.contains(":")) {
- node = getNodeFromGraphOutputs(nodeName, graph);
- } else {
- node = getNodeFromGraphNames(nodeName, graph);
- if (node.isEmpty()) {
- node = getNodeFromGraphOutputs(nodeName, graph);
- }
- }
- return node.orElseThrow(() -> new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph"));
+ Optional<Onnx.NodeProto> node = getNodeFromGraphNames(nodeName, graph);
+ if (node.isPresent())
+ return node.get();
+
+ node = getNodeFromGraphOutputs(nodeName, graph);
+ if (node.isPresent())
+ return node.get();
+
+ node = getNodeFromGraphInputs(nodeName, graph);
+ if (node.isPresent())
+ return node.get();
+
+ throw new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph");
}
private static Optional<Onnx.NodeProto> getNodeFromGraphOutputs(String nodeName, Onnx.GraphProto graph) {
- for (Onnx.NodeProto node : graph.getNodeList()) {
- for (String outputName : node.getOutputList()) {
- if (outputName.equals(nodeName)) {
- return Optional.of(node);
- }
- }
- }
- return Optional.empty();
+ return graph.getNodeList().stream().filter(node ->
+ node.getOutputList().stream().anyMatch(name -> name.equals(nodeName))).findFirst();
+ }
+
+ private static Optional<Onnx.NodeProto> getNodeFromGraphInputs(String nodeName, Onnx.GraphProto graph) {
+ return graph.getNodeList().stream().filter(node ->
+ node.getInputList().stream().anyMatch(name -> name.equals(nodeName))).findFirst();
}
private static Optional<Onnx.NodeProto> getNodeFromGraphNames(String nodeName, Onnx.GraphProto graph) {
- for (Onnx.NodeProto node : graph.getNodeList()) {
- if (node.getName().equals(nodeName)) {
- return Optional.of(node);
- }
- }
- return Optional.empty();
+ return graph.getNodeList().stream().filter(node -> node.getName().equals(nodeName)).findFirst();
}
private static String getNodeName(Onnx.NodeProto node) {