From 35995299077105e10f08187c508d5ffafaef50b6 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Tue, 7 May 2019 14:55:45 +0200 Subject: ONNX: use node output as name if not explicitly set --- .../importer/onnx/GraphImporter.java | 52 +++++++++++++++++----- .../importer/onnx/SimpleImportTestCase.java | 41 +++++++++++++++++ .../src/test/models/onnx/simple/simple.onnx | 23 ++++++++++ .../src/test/models/onnx/simple/simple.py | 32 +++++++++++++ 4 files changed, 137 insertions(+), 11 deletions(-) create mode 100644 model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java create mode 100644 model-integration/src/test/models/onnx/simple/simple.onnx create mode 100755 model-integration/src/test/models/onnx/simple/simple.py (limited to 'model-integration') diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index a469e666d93..419bc7ddf28 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -20,6 +20,7 @@ import com.yahoo.tensor.functions.ScalarFunctions; import onnx.Onnx; import java.util.List; +import java.util.Optional; import java.util.stream.Collectors; /** @@ -33,8 +34,8 @@ class GraphImporter { private static IntermediateOperation mapOperation(Onnx.NodeProto node, List inputs, IntermediateGraph graph) { - String nodeName = node.getName(); String modelName = graph.name(); + String nodeName = getNodeName(node); switch (node.getOpType().toLowerCase()) { case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs()); @@ -74,7 +75,7 @@ class GraphImporter { case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh()); } - IntermediateOperation op = new NoOp(modelName, node.getName(), inputs); + IntermediateOperation op = new NoOp(modelName, nodeName, inputs); op.warning("Operation '" + node.getOpType() + "' is currently not implemented"); return op; } @@ -199,18 +200,47 @@ class GraphImporter { } private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) { - boolean hasPortNumber = nodeName.contains(":"); + Optional node; + if (nodeName.contains(":")) { + node = getNodeFromGraphOutputs(nodeName, graph); + } else { + node = getNodeFromGraphNames(nodeName, graph); + if (node.isEmpty()) { + node = getNodeFromGraphOutputs(nodeName, graph); + } + } + return node.orElseThrow(() -> new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph")); + } + + private static Optional getNodeFromGraphOutputs(String nodeName, Onnx.GraphProto graph) { for (Onnx.NodeProto node : graph.getNodeList()) { - if (hasPortNumber) { - for (String outputName : node.getOutputList()) { - if (outputName.equals(nodeName)) { - return node; - } + for (String outputName : node.getOutputList()) { + if (outputName.equals(nodeName)) { + return Optional.of(node); } - } else if (node.getName().equals(nodeName)) { - return node; } } - throw new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph"); + return Optional.empty(); + } + + private static Optional getNodeFromGraphNames(String nodeName, Onnx.GraphProto graph) { + for (Onnx.NodeProto node : graph.getNodeList()) { + if (node.getName().equals(nodeName)) { + return Optional.of(node); + } + } + return Optional.empty(); + } + + private static String getNodeName(Onnx.NodeProto node) { + String nodeName = node.getName(); + if (nodeName.length() > 0) + return nodeName; + if (node.getOutputCount() == 1) + return node.getOutput(0); + throw new IllegalArgumentException("Unable to find a suitable name for node '" + node.toString() + "'. " + + "Either no explicit name given or no single output name."); } + + } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java new file mode 100644 index 00000000000..d1dea730da5 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java @@ -0,0 +1,41 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package ai.vespa.rankingexpression.importer.onnx; + +import ai.vespa.rankingexpression.importer.ImportedModel; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * @author lesters + */ +public class SimpleImportTestCase { + + @Test + public void testSimpleOnnxModelImport() { + ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/simple.onnx"); + + MapContext context = new MapContext(); + context.put("query_tensor", new TensorValue(Tensor.Builder.of(TensorType.fromSpec("tensor(d0[1],d1[4])")). + cell(0.1, 0, 0). + cell(0.2, 0, 1). + cell(0.3, 0, 2). + cell(0.4, 0, 3).build())); + context.put("attribute_tensor", new TensorValue(Tensor.Builder.of(TensorType.fromSpec("tensor(d0[4],d1[1])")). + cell(0.1, 0, 0). + cell(0.2, 1, 0). + cell(0.3, 2, 0). + cell(0.4, 3, 0).build())); + context.put("bias_tensor", new TensorValue(Tensor.Builder.of(TensorType.fromSpec("tensor(d0[1],d1[1])")). + cell(1.0, 0, 0).build())); + + Tensor result = model.expressions().get("output").evaluate(context).asTensor(); + assertEquals(result, Tensor.from("tensor(d0[1],d1[1]):{{d0:0,d1:0}:1.3}")); + } + +} diff --git a/model-integration/src/test/models/onnx/simple/simple.onnx b/model-integration/src/test/models/onnx/simple/simple.onnx new file mode 100644 index 00000000000..1c746c90efa --- /dev/null +++ b/model-integration/src/test/models/onnx/simple/simple.onnx @@ -0,0 +1,23 @@ + simple.py:ã +0 + query_tensor +attribute_tensormatmul"MatMul +" +matmul + bias_tensoroutput"addsimple_scoringZ + query_tensor +  + +Z" +attribute_tensor +  + +Z + bias_tensor +  + +b +output +  + +B diff --git a/model-integration/src/test/models/onnx/simple/simple.py b/model-integration/src/test/models/onnx/simple/simple.py new file mode 100755 index 00000000000..4471ed812b8 --- /dev/null +++ b/model-integration/src/test/models/onnx/simple/simple.py @@ -0,0 +1,32 @@ +import onnx +from onnx import helper, TensorProto + +QUERY_TENSOR = helper.make_tensor_value_info('query_tensor', TensorProto.FLOAT, [1, 4]) +ATTRIBUTE_TENSOR = helper.make_tensor_value_info('attribute_tensor', TensorProto.FLOAT, [4, 1]) +BIAS_TENSOR = helper.make_tensor_value_info('bias_tensor', TensorProto.FLOAT, [1, 1]) +OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 1]) + +nodes = [ + helper.make_node( + 'MatMul', + ['query_tensor', 'attribute_tensor'], + ['matmul'], + ), + helper.make_node( + 'add', + ['matmul', 'bias_tensor'], + ['output'], + ), +] +graph_def = helper.make_graph( + nodes, + 'simple_scoring', + [ + QUERY_TENSOR, + ATTRIBUTE_TENSOR, + BIAS_TENSOR, + ], + [OUTPUT], +) +model_def = helper.make_model(graph_def, producer_name='simple.py') +onnx.save(model_def, 'simple.onnx') -- cgit v1.2.3