diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java | 91 |
1 files changed, 54 insertions, 37 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index 714953fbd45..280fe354149 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -2,11 +2,16 @@ package ai.vespa.rankingexpression.importer.onnx; +import ai.vespa.rankingexpression.importer.operations.Gemm; +import ai.vespa.rankingexpression.importer.operations.OnnxConcat; +import ai.vespa.rankingexpression.importer.operations.Reduce; +import ai.vespa.rankingexpression.importer.operations.Select; +import ai.vespa.rankingexpression.importer.operations.Softmax; +import ai.vespa.rankingexpression.importer.operations.Squeeze; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import ai.vespa.rankingexpression.importer.IntermediateGraph; import ai.vespa.rankingexpression.importer.OrderedTensorType; import ai.vespa.rankingexpression.importer.operations.Argument; -import ai.vespa.rankingexpression.importer.operations.ConcatV2; import ai.vespa.rankingexpression.importer.operations.Constant; import ai.vespa.rankingexpression.importer.operations.Identity; import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; @@ -36,6 +41,7 @@ class GraphImporter { IntermediateGraph graph) { String modelName = graph.name(); String nodeName = getNodeName(node); + AttributeConverter attributes = AttributeConverter.convert(node); switch (node.getOpType().toLowerCase()) { case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs()); @@ -44,13 +50,14 @@ class GraphImporter { case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin()); case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan()); case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil()); - case "concat": return new ConcatV2(modelName, nodeName, inputs); + case "concat": return new OnnxConcat(modelName, nodeName, inputs, attributes); case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos()); case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu()); case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal()); case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp()); case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor()); + case "gemm": return new Gemm(modelName, nodeName, inputs, attributes); case "greater": return new Join(modelName, nodeName, inputs, ScalarFunctions.greater()); case "identity": return new Identity(modelName, nodeName, inputs); case "less": return new Join(modelName, nodeName, inputs, ScalarFunctions.less()); @@ -63,15 +70,21 @@ class GraphImporter { case "neg": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg()); case "pow": return new Join(modelName, nodeName, inputs, ScalarFunctions.pow()); case "reshape": return new Reshape(modelName, nodeName, inputs); + case "reducesum": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum); + case "reducemean": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.avg); case "reciprocal": return new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal()); case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu()); case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu()); + case "leakyrelu": return new Map(modelName, nodeName, inputs, ScalarFunctions.leakyrelu()); case "shape": return new Shape(modelName, nodeName, inputs); - case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin()); - case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt()); case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid()); + case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin()); + case "softmax": return new Softmax(modelName, nodeName, inputs); case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); + case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes); + case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt()); case "square": return new Map(modelName, nodeName, inputs, ScalarFunctions.square()); + case "where": return new Select(modelName, nodeName, inputs); case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan()); case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh()); } @@ -125,16 +138,25 @@ class GraphImporter { List<IntermediateOperation> inputs = importOperationInputs(node, onnxGraph, intermediateGraph); operation = mapOperation(node, inputs, intermediateGraph); + // propagate constant values if all inputs are constant + if (operation.isConstant()) { + operation.setConstantValueFunction(operation::evaluateAsConstant); + } + if (isOutputNode(name, onnxGraph)) { intermediateGraph.outputs(intermediateGraph.defaultSignature()) - .put(IntermediateOperation.namePartOf(name), operation.vespaName()); + .put(IntermediateOperation.namePartOf(name), operation.name()); } } - intermediateGraph.put(operation.vespaName(), operation); + intermediateGraph.put(operation.name(), operation); return operation; } + // Rules for initializers in ONNX: + // When an initializer has the same name as a graph input, it specifies a default value for that input. + // When an initializer has a name different from all graph inputs, it specifies a constant value. + private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) { Onnx.ValueInfoProto value = getArgumentTensor(name, graph); Onnx.TensorProto tensor = getConstantTensor(name, graph); @@ -142,9 +164,7 @@ class GraphImporter { } private static boolean isConstantTensor(String name, Onnx.GraphProto graph) { - Onnx.ValueInfoProto value = getArgumentTensor(name, graph); - Onnx.TensorProto tensor = getConstantTensor(name, graph); - return value != null && tensor != null; + return getConstantTensor(name, graph) != null; } private static Onnx.ValueInfoProto getArgumentTensor(String name, Onnx.GraphProto graph) { @@ -191,46 +211,43 @@ class GraphImporter { } private static void verifyOutputTypes(Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) { - for (String outputName : intermediateGraph.outputs(intermediateGraph.defaultSignature()).values()) { - IntermediateOperation operation = intermediateGraph.get(outputName); - Onnx.ValueInfoProto onnxNode = getOutputNode(outputName, onnxGraph); + for (java.util.Map.Entry<String, String> output : intermediateGraph.outputs(intermediateGraph.defaultSignature()).entrySet()) { + IntermediateOperation operation = intermediateGraph.get(output.getValue()); + Onnx.ValueInfoProto onnxNode = getOutputNode(output.getKey(), onnxGraph); OrderedTensorType type = operation.type().orElseThrow( - () -> new IllegalArgumentException("Output of '" + outputName + "' has no type.")); + () -> new IllegalArgumentException("Output of '" + output.getValue() + "' has no type.")); TypeConverter.verifyType(onnxNode.getType(), type); } } private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) { - Optional<Onnx.NodeProto> node; - if (nodeName.contains(":")) { - node = getNodeFromGraphOutputs(nodeName, graph); - } else { - node = getNodeFromGraphNames(nodeName, graph); - if (node.isEmpty()) { - node = getNodeFromGraphOutputs(nodeName, graph); - } - } - return node.orElseThrow(() -> new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph")); + Optional<Onnx.NodeProto> node = getNodeFromGraphNames(nodeName, graph); + if (node.isPresent()) + return node.get(); + + node = getNodeFromGraphOutputs(nodeName, graph); + if (node.isPresent()) + return node.get(); + + node = getNodeFromGraphInputs(nodeName, graph); + if (node.isPresent()) + return node.get(); + + throw new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph"); } private static Optional<Onnx.NodeProto> getNodeFromGraphOutputs(String nodeName, Onnx.GraphProto graph) { - for (Onnx.NodeProto node : graph.getNodeList()) { - for (String outputName : node.getOutputList()) { - if (outputName.equals(nodeName)) { - return Optional.of(node); - } - } - } - return Optional.empty(); + return graph.getNodeList().stream().filter(node -> + node.getOutputList().stream().anyMatch(name -> name.equals(nodeName))).findFirst(); + } + + private static Optional<Onnx.NodeProto> getNodeFromGraphInputs(String nodeName, Onnx.GraphProto graph) { + return graph.getNodeList().stream().filter(node -> + node.getInputList().stream().anyMatch(name -> name.equals(nodeName))).findFirst(); } private static Optional<Onnx.NodeProto> getNodeFromGraphNames(String nodeName, Onnx.GraphProto graph) { - for (Onnx.NodeProto node : graph.getNodeList()) { - if (node.getName().equals(nodeName)) { - return Optional.of(node); - } - } - return Optional.empty(); + return graph.getNodeList().stream().filter(node -> node.getName().equals(nodeName)).findFirst(); } private static String getNodeName(Onnx.NodeProto node) { |