diff options
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java | 59 |
1 files changed, 29 insertions, 30 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index c60a9b85d10..b670eca9183 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -112,7 +112,7 @@ class GraphImporter { operation = new Argument(intermediateGraph.name(), valueInfoProto.getName(), type); intermediateGraph.inputs(intermediateGraph.defaultSignature()) - .put(IntermediateOperation.namePartOf(name), operation.vespaName()); + .put(IntermediateOperation.namePartOf(name), operation.name()); } else if (isConstantTensor(name, onnxGraph)) { Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph); @@ -132,14 +132,18 @@ class GraphImporter { if (isOutputNode(name, onnxGraph)) { intermediateGraph.outputs(intermediateGraph.defaultSignature()) - .put(IntermediateOperation.namePartOf(name), operation.vespaName()); + .put(IntermediateOperation.namePartOf(name), operation.name()); } } - intermediateGraph.put(operation.vespaName(), operation); + intermediateGraph.put(operation.name(), operation); return operation; } + // Rules for initializers in ONNX: + // When an initializer has the same name as a graph input, it specifies a default value for that input. + // When an initializer has a name different from all graph inputs, it specifies a constant value. + private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) { Onnx.ValueInfoProto value = getArgumentTensor(name, graph); Onnx.TensorProto tensor = getConstantTensor(name, graph); @@ -147,9 +151,7 @@ class GraphImporter { } private static boolean isConstantTensor(String name, Onnx.GraphProto graph) { - Onnx.ValueInfoProto value = getArgumentTensor(name, graph); - Onnx.TensorProto tensor = getConstantTensor(name, graph); - return value != null && tensor != null; + return getConstantTensor(name, graph) != null; } private static Onnx.ValueInfoProto getArgumentTensor(String name, Onnx.GraphProto graph) { @@ -206,36 +208,33 @@ class GraphImporter { } private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) { - Optional<Onnx.NodeProto> node; - if (nodeName.contains(":")) { - node = getNodeFromGraphOutputs(nodeName, graph); - } else { - node = getNodeFromGraphNames(nodeName, graph); - if (node.isEmpty()) { - node = getNodeFromGraphOutputs(nodeName, graph); - } - } - return node.orElseThrow(() -> new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph")); + Optional<Onnx.NodeProto> node = getNodeFromGraphNames(nodeName, graph); + if (node.isPresent()) + return node.get(); + + node = getNodeFromGraphOutputs(nodeName, graph); + if (node.isPresent()) + return node.get(); + + node = getNodeFromGraphInputs(nodeName, graph); + if (node.isPresent()) + return node.get(); + + throw new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph"); } private static Optional<Onnx.NodeProto> getNodeFromGraphOutputs(String nodeName, Onnx.GraphProto graph) { - for (Onnx.NodeProto node : graph.getNodeList()) { - for (String outputName : node.getOutputList()) { - if (outputName.equals(nodeName)) { - return Optional.of(node); - } - } - } - return Optional.empty(); + return graph.getNodeList().stream().filter(node -> + node.getOutputList().stream().anyMatch(name -> name.equals(nodeName))).findFirst(); + } + + private static Optional<Onnx.NodeProto> getNodeFromGraphInputs(String nodeName, Onnx.GraphProto graph) { + return graph.getNodeList().stream().filter(node -> + node.getInputList().stream().anyMatch(name -> name.equals(nodeName))).findFirst(); } private static Optional<Onnx.NodeProto> getNodeFromGraphNames(String nodeName, Onnx.GraphProto graph) { - for (Onnx.NodeProto node : graph.getNodeList()) { - if (node.getName().equals(nodeName)) { - return Optional.of(node); - } - } - return Optional.empty(); + return graph.getNodeList().stream().filter(node -> node.getName().equals(nodeName)).findFirst(); } private static String getNodeName(Onnx.NodeProto node) { |