aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-11-22 10:58:41 +0100
committerLester Solbakken <lesters@oath.com>2019-11-22 10:58:41 +0100
commit2a48b55bed1ce90bcf33e04579d00d7b8c993d5e (patch)
treec4f0bb1a16e33ad85645261fe8eb4e116929a4aa /model-integration
parentf3e934cdeae3fceb6bf952dde2f5b0b90b02bfa7 (diff)
Insert correct names in intermediate graph for ONNX to avoid re-importing
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java59
1 files changed, 29 insertions, 30 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index c60a9b85d10..b670eca9183 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -112,7 +112,7 @@ class GraphImporter {
operation = new Argument(intermediateGraph.name(), valueInfoProto.getName(), type);
intermediateGraph.inputs(intermediateGraph.defaultSignature())
- .put(IntermediateOperation.namePartOf(name), operation.vespaName());
+ .put(IntermediateOperation.namePartOf(name), operation.name());
} else if (isConstantTensor(name, onnxGraph)) {
Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph);
@@ -132,14 +132,18 @@ class GraphImporter {
if (isOutputNode(name, onnxGraph)) {
intermediateGraph.outputs(intermediateGraph.defaultSignature())
- .put(IntermediateOperation.namePartOf(name), operation.vespaName());
+ .put(IntermediateOperation.namePartOf(name), operation.name());
}
}
- intermediateGraph.put(operation.vespaName(), operation);
+ intermediateGraph.put(operation.name(), operation);
return operation;
}
+ // Rules for initializers in ONNX:
+ // When an initializer has the same name as a graph input, it specifies a default value for that input.
+ // When an initializer has a name different from all graph inputs, it specifies a constant value.
+
private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) {
Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
Onnx.TensorProto tensor = getConstantTensor(name, graph);
@@ -147,9 +151,7 @@ class GraphImporter {
}
private static boolean isConstantTensor(String name, Onnx.GraphProto graph) {
- Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
- Onnx.TensorProto tensor = getConstantTensor(name, graph);
- return value != null && tensor != null;
+ return getConstantTensor(name, graph) != null;
}
private static Onnx.ValueInfoProto getArgumentTensor(String name, Onnx.GraphProto graph) {
@@ -206,36 +208,33 @@ class GraphImporter {
}
private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) {
- Optional<Onnx.NodeProto> node;
- if (nodeName.contains(":")) {
- node = getNodeFromGraphOutputs(nodeName, graph);
- } else {
- node = getNodeFromGraphNames(nodeName, graph);
- if (node.isEmpty()) {
- node = getNodeFromGraphOutputs(nodeName, graph);
- }
- }
- return node.orElseThrow(() -> new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph"));
+ Optional<Onnx.NodeProto> node = getNodeFromGraphNames(nodeName, graph);
+ if (node.isPresent())
+ return node.get();
+
+ node = getNodeFromGraphOutputs(nodeName, graph);
+ if (node.isPresent())
+ return node.get();
+
+ node = getNodeFromGraphInputs(nodeName, graph);
+ if (node.isPresent())
+ return node.get();
+
+ throw new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph");
}
private static Optional<Onnx.NodeProto> getNodeFromGraphOutputs(String nodeName, Onnx.GraphProto graph) {
- for (Onnx.NodeProto node : graph.getNodeList()) {
- for (String outputName : node.getOutputList()) {
- if (outputName.equals(nodeName)) {
- return Optional.of(node);
- }
- }
- }
- return Optional.empty();
+ return graph.getNodeList().stream().filter(node ->
+ node.getOutputList().stream().anyMatch(name -> name.equals(nodeName))).findFirst();
+ }
+
+ private static Optional<Onnx.NodeProto> getNodeFromGraphInputs(String nodeName, Onnx.GraphProto graph) {
+ return graph.getNodeList().stream().filter(node ->
+ node.getInputList().stream().anyMatch(name -> name.equals(nodeName))).findFirst();
}
private static Optional<Onnx.NodeProto> getNodeFromGraphNames(String nodeName, Onnx.GraphProto graph) {
- for (Onnx.NodeProto node : graph.getNodeList()) {
- if (node.getName().equals(nodeName)) {
- return Optional.of(node);
- }
- }
- return Optional.empty();
+ return graph.getNodeList().stream().filter(node -> node.getName().equals(nodeName)).findFirst();
}
private static String getNodeName(Onnx.NodeProto node) {