diff options
author | Lester Solbakken <lesters@oath.com> | 2019-11-22 10:58:41 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2019-11-22 10:58:41 +0100 |
commit | 2a48b55bed1ce90bcf33e04579d00d7b8c993d5e (patch) | |
tree | c4f0bb1a16e33ad85645261fe8eb4e116929a4aa /model-integration | |
parent | f3e934cdeae3fceb6bf952dde2f5b0b90b02bfa7 (diff) |
Insert correct names in intermediate graph for ONNX to avoid re-importing
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java | 59 |
1 files changed, 29 insertions, 30 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index c60a9b85d10..b670eca9183 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -112,7 +112,7 @@ class GraphImporter { operation = new Argument(intermediateGraph.name(), valueInfoProto.getName(), type); intermediateGraph.inputs(intermediateGraph.defaultSignature()) - .put(IntermediateOperation.namePartOf(name), operation.vespaName()); + .put(IntermediateOperation.namePartOf(name), operation.name()); } else if (isConstantTensor(name, onnxGraph)) { Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph); @@ -132,14 +132,18 @@ class GraphImporter { if (isOutputNode(name, onnxGraph)) { intermediateGraph.outputs(intermediateGraph.defaultSignature()) - .put(IntermediateOperation.namePartOf(name), operation.vespaName()); + .put(IntermediateOperation.namePartOf(name), operation.name()); } } - intermediateGraph.put(operation.vespaName(), operation); + intermediateGraph.put(operation.name(), operation); return operation; } + // Rules for initializers in ONNX: + // When an initializer has the same name as a graph input, it specifies a default value for that input. + // When an initializer has a name different from all graph inputs, it specifies a constant value. + private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) { Onnx.ValueInfoProto value = getArgumentTensor(name, graph); Onnx.TensorProto tensor = getConstantTensor(name, graph); @@ -147,9 +151,7 @@ class GraphImporter { } private static boolean isConstantTensor(String name, Onnx.GraphProto graph) { - Onnx.ValueInfoProto value = getArgumentTensor(name, graph); - Onnx.TensorProto tensor = getConstantTensor(name, graph); - return value != null && tensor != null; + return getConstantTensor(name, graph) != null; } private static Onnx.ValueInfoProto getArgumentTensor(String name, Onnx.GraphProto graph) { @@ -206,36 +208,33 @@ class GraphImporter { } private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) { - Optional<Onnx.NodeProto> node; - if (nodeName.contains(":")) { - node = getNodeFromGraphOutputs(nodeName, graph); - } else { - node = getNodeFromGraphNames(nodeName, graph); - if (node.isEmpty()) { - node = getNodeFromGraphOutputs(nodeName, graph); - } - } - return node.orElseThrow(() -> new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph")); + Optional<Onnx.NodeProto> node = getNodeFromGraphNames(nodeName, graph); + if (node.isPresent()) + return node.get(); + + node = getNodeFromGraphOutputs(nodeName, graph); + if (node.isPresent()) + return node.get(); + + node = getNodeFromGraphInputs(nodeName, graph); + if (node.isPresent()) + return node.get(); + + throw new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph"); } private static Optional<Onnx.NodeProto> getNodeFromGraphOutputs(String nodeName, Onnx.GraphProto graph) { - for (Onnx.NodeProto node : graph.getNodeList()) { - for (String outputName : node.getOutputList()) { - if (outputName.equals(nodeName)) { - return Optional.of(node); - } - } - } - return Optional.empty(); + return graph.getNodeList().stream().filter(node -> + node.getOutputList().stream().anyMatch(name -> name.equals(nodeName))).findFirst(); + } + + private static Optional<Onnx.NodeProto> getNodeFromGraphInputs(String nodeName, Onnx.GraphProto graph) { + return graph.getNodeList().stream().filter(node -> + node.getInputList().stream().anyMatch(name -> name.equals(nodeName))).findFirst(); } private static Optional<Onnx.NodeProto> getNodeFromGraphNames(String nodeName, Onnx.GraphProto graph) { - for (Onnx.NodeProto node : graph.getNodeList()) { - if (node.getName().equals(nodeName)) { - return Optional.of(node); - } - } - return Optional.empty(); + return graph.getNodeList().stream().filter(node -> node.getName().equals(nodeName)).findFirst(); } private static String getNodeName(Onnx.NodeProto node) { |