summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-05-07 14:55:45 +0200
committerLester Solbakken <lesters@oath.com>2019-05-07 14:55:45 +0200
commit35995299077105e10f08187c508d5ffafaef50b6 (patch)
treef9666671404895421bbd0b5ed38ecad6541acace /model-integration
parent50ccc80a1242a1013c0510e11ff3f8fc18f2ded0 (diff)
ONNX: use node output as name if not explicitly set
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java52
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java41
-rw-r--r--model-integration/src/test/models/onnx/simple/simple.onnx23
-rwxr-xr-xmodel-integration/src/test/models/onnx/simple/simple.py32
4 files changed, 137 insertions, 11 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index a469e666d93..419bc7ddf28 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -20,6 +20,7 @@ import com.yahoo.tensor.functions.ScalarFunctions;
import onnx.Onnx;
import java.util.List;
+import java.util.Optional;
import java.util.stream.Collectors;
/**
@@ -33,8 +34,8 @@ class GraphImporter {
private static IntermediateOperation mapOperation(Onnx.NodeProto node,
List<IntermediateOperation> inputs,
IntermediateGraph graph) {
- String nodeName = node.getName();
String modelName = graph.name();
+ String nodeName = getNodeName(node);
switch (node.getOpType().toLowerCase()) {
case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs());
@@ -74,7 +75,7 @@ class GraphImporter {
case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh());
}
- IntermediateOperation op = new NoOp(modelName, node.getName(), inputs);
+ IntermediateOperation op = new NoOp(modelName, nodeName, inputs);
op.warning("Operation '" + node.getOpType() + "' is currently not implemented");
return op;
}
@@ -199,18 +200,47 @@ class GraphImporter {
}
private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) {
- boolean hasPortNumber = nodeName.contains(":");
+ Optional<Onnx.NodeProto> node;
+ if (nodeName.contains(":")) {
+ node = getNodeFromGraphOutputs(nodeName, graph);
+ } else {
+ node = getNodeFromGraphNames(nodeName, graph);
+ if (node.isEmpty()) {
+ node = getNodeFromGraphOutputs(nodeName, graph);
+ }
+ }
+ return node.orElseThrow(() -> new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph"));
+ }
+
+ private static Optional<Onnx.NodeProto> getNodeFromGraphOutputs(String nodeName, Onnx.GraphProto graph) {
for (Onnx.NodeProto node : graph.getNodeList()) {
- if (hasPortNumber) {
- for (String outputName : node.getOutputList()) {
- if (outputName.equals(nodeName)) {
- return node;
- }
+ for (String outputName : node.getOutputList()) {
+ if (outputName.equals(nodeName)) {
+ return Optional.of(node);
}
- } else if (node.getName().equals(nodeName)) {
- return node;
}
}
- throw new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph");
+ return Optional.empty();
+ }
+
+ private static Optional<Onnx.NodeProto> getNodeFromGraphNames(String nodeName, Onnx.GraphProto graph) {
+ for (Onnx.NodeProto node : graph.getNodeList()) {
+ if (node.getName().equals(nodeName)) {
+ return Optional.of(node);
+ }
+ }
+ return Optional.empty();
+ }
+
+ private static String getNodeName(Onnx.NodeProto node) {
+ String nodeName = node.getName();
+ if (nodeName.length() > 0)
+ return nodeName;
+ if (node.getOutputCount() == 1)
+ return node.getOutput(0);
+ throw new IllegalArgumentException("Unable to find a suitable name for node '" + node.toString() + "'. " +
+ "Either no explicit name given or no single output name.");
}
+
+
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
new file mode 100644
index 00000000000..d1dea730da5
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
@@ -0,0 +1,41 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.rankingexpression.importer.onnx;
+
+import ai.vespa.rankingexpression.importer.ImportedModel;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author lesters
+ */
+public class SimpleImportTestCase {
+
+ @Test
+ public void testSimpleOnnxModelImport() {
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/simple.onnx");
+
+ MapContext context = new MapContext();
+ context.put("query_tensor", new TensorValue(Tensor.Builder.of(TensorType.fromSpec("tensor(d0[1],d1[4])")).
+ cell(0.1, 0, 0).
+ cell(0.2, 0, 1).
+ cell(0.3, 0, 2).
+ cell(0.4, 0, 3).build()));
+ context.put("attribute_tensor", new TensorValue(Tensor.Builder.of(TensorType.fromSpec("tensor(d0[4],d1[1])")).
+ cell(0.1, 0, 0).
+ cell(0.2, 1, 0).
+ cell(0.3, 2, 0).
+ cell(0.4, 3, 0).build()));
+ context.put("bias_tensor", new TensorValue(Tensor.Builder.of(TensorType.fromSpec("tensor(d0[1],d1[1])")).
+ cell(1.0, 0, 0).build()));
+
+ Tensor result = model.expressions().get("output").evaluate(context).asTensor();
+ assertEquals(result, Tensor.from("tensor(d0[1],d1[1]):{{d0:0,d1:0}:1.3}"));
+ }
+
+}
diff --git a/model-integration/src/test/models/onnx/simple/simple.onnx b/model-integration/src/test/models/onnx/simple/simple.onnx
new file mode 100644
index 00000000000..1c746c90efa
--- /dev/null
+++ b/model-integration/src/test/models/onnx/simple/simple.onnx
@@ -0,0 +1,23 @@
+ simple.py:ã
+0
+ query_tensor
+attribute_tensormatmul"MatMul
+"
+matmul
+ bias_tensoroutput"addsimple_scoringZ
+ query_tensor
+ 
+
+Z"
+attribute_tensor
+ 
+
+Z
+ bias_tensor
+ 
+
+b
+output
+ 
+
+B
diff --git a/model-integration/src/test/models/onnx/simple/simple.py b/model-integration/src/test/models/onnx/simple/simple.py
new file mode 100755
index 00000000000..4471ed812b8
--- /dev/null
+++ b/model-integration/src/test/models/onnx/simple/simple.py
@@ -0,0 +1,32 @@
+import onnx
+from onnx import helper, TensorProto
+
+QUERY_TENSOR = helper.make_tensor_value_info('query_tensor', TensorProto.FLOAT, [1, 4])
+ATTRIBUTE_TENSOR = helper.make_tensor_value_info('attribute_tensor', TensorProto.FLOAT, [4, 1])
+BIAS_TENSOR = helper.make_tensor_value_info('bias_tensor', TensorProto.FLOAT, [1, 1])
+OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 1])
+
+nodes = [
+ helper.make_node(
+ 'MatMul',
+ ['query_tensor', 'attribute_tensor'],
+ ['matmul'],
+ ),
+ helper.make_node(
+ 'add',
+ ['matmul', 'bias_tensor'],
+ ['output'],
+ ),
+]
+graph_def = helper.make_graph(
+ nodes,
+ 'simple_scoring',
+ [
+ QUERY_TENSOR,
+ ATTRIBUTE_TENSOR,
+ BIAS_TENSOR,
+ ],
+ [OUTPUT],
+)
+model_def = helper.make_model(graph_def, producer_name='simple.py')
+onnx.save(model_def, 'simple.onnx')