// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.onnx;

import ai.vespa.rankingexpression.importer.operations.ConstantOfShape;
import ai.vespa.rankingexpression.importer.operations.Expand;
import ai.vespa.rankingexpression.importer.operations.Gather;
import ai.vespa.rankingexpression.importer.operations.OnnxConstant;
import ai.vespa.rankingexpression.importer.operations.OnnxCast;
import ai.vespa.rankingexpression.importer.operations.Gemm;
import ai.vespa.rankingexpression.importer.operations.ConcatReduce;
import ai.vespa.rankingexpression.importer.operations.OnnxConcat;
import ai.vespa.rankingexpression.importer.operations.Range;
import ai.vespa.rankingexpression.importer.operations.Reduce;
import ai.vespa.rankingexpression.importer.operations.Select;
import ai.vespa.rankingexpression.importer.operations.Slice;
import ai.vespa.rankingexpression.importer.operations.Softmax;
import ai.vespa.rankingexpression.importer.operations.Split;
import ai.vespa.rankingexpression.importer.operations.Squeeze;
import ai.vespa.rankingexpression.importer.operations.Tile;
import ai.vespa.rankingexpression.importer.operations.Transpose;
import ai.vespa.rankingexpression.importer.operations.Unsqueeze;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.operations.Argument;
import ai.vespa.rankingexpression.importer.operations.Constant;
import ai.vespa.rankingexpression.importer.operations.Identity;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import ai.vespa.rankingexpression.importer.operations.Join;
import ai.vespa.rankingexpression.importer.operations.Map;
import ai.vespa.rankingexpression.importer.operations.MatMul;
import ai.vespa.rankingexpression.importer.operations.NoOp;
import ai.vespa.rankingexpression.importer.operations.Reshape;
import ai.vespa.rankingexpression.importer.operations.Shape;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.functions.ScalarFunctions;
import onnx.Onnx;

import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Converts an ONNX graph to a Vespa IntermediateGraph which is the basis
 * for generating Vespa ranking expressions.
 *
 * @author lesters
 */
class GraphImporter {

    // Fallback attribute values used when a node does not specify them explicitly.
    private static final Value eluAlpha = DoubleValue.frozen(1.0);
    private static final Value seluAlpha = DoubleValue.frozen(1.6732632423543772848170429916717);
    private static final Value seluGamma = DoubleValue.frozen(1.0507009873554804934193349852946);
    private static final Value leakyReluAlpha = DoubleValue.frozen(0.01);

    /**
     * Maps an ONNX node to an intermediate operation, deriving the operation type,
     * model name, node name and attributes from the node and the graph.
     *
     * @param node        the ONNX node to convert
     * @param inputs      the already imported inputs of the node
     * @param graph       the graph being built; supplies the model name
     * @param outputIndex index of the node output this operation represents
     * @return the intermediate operation for the node
     */
    private static IntermediateOperation mapOperation(Onnx.NodeProto node,
                                                      List<IntermediateOperation> inputs,
                                                      IntermediateGraph graph,
                                                      int outputIndex) {
        String type = node.getOpType();
        String modelName = graph.name();
        String nodeName = getNodeName(node);
        AttributeConverter attributes = AttributeConverter.convert(node);
        return mapOperation(type, inputs, modelName, nodeName, attributes, outputIndex);
    }

    /**
     * Maps an ONNX operation type to the intermediate operation implementing it.
     * The type is matched case-insensitively. Types without a mapping yield a
     * {@link NoOp} carrying a warning rather than failing immediately.
     *
     * @param opType      the ONNX operation type, e.g. "MatMul"
     * @param inputs      the already imported inputs of the operation
     * @param modelName   name of the model the operation belongs to
     * @param nodeName    name of the node
     * @param attributes  the node attributes
     * @param outputIndex index of the node output this operation represents (passed to "split")
     * @return the mapped operation, or a {@link NoOp} with a warning for unknown types
     */
    static IntermediateOperation mapOperation(String opType,
                                              List<IntermediateOperation> inputs,
                                              String modelName,
                                              String nodeName,
                                              AttributeConverter attributes,
                                              int outputIndex) {
        // Locale.ROOT keeps the lookup independent of the JVM default locale
        // (e.g. Turkish casing rules would otherwise break types containing 'I').
        return switch (opType.toLowerCase(Locale.ROOT)) {
            case "abs" -> new Map(modelName, nodeName, inputs, ScalarFunctions.abs());
            case "acos" -> new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
            case "add" -> new Join(modelName, nodeName, inputs, ScalarFunctions.add());
            case "asin" -> new Map(modelName, nodeName, inputs, ScalarFunctions.asin());
            case "atan" -> new Map(modelName, nodeName, inputs, ScalarFunctions.atan());
            case "cast" -> new OnnxCast(modelName, nodeName, inputs, attributes);
            case "ceil" -> new Map(modelName, nodeName, inputs, ScalarFunctions.ceil());
            case "concat" -> new OnnxConcat(modelName, nodeName, inputs, attributes);
            case "constant" -> new OnnxConstant(modelName, nodeName, inputs, attributes);
            case "constantofshape" -> new ConstantOfShape(modelName, nodeName, inputs, attributes);
            case "cos" -> new Map(modelName, nodeName, inputs, ScalarFunctions.cos());
            case "div" -> new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
            case "elu" -> new Map(modelName, nodeName, inputs,
                                  ScalarFunctions.elu(attributes.get("alpha").orElse(eluAlpha).asDouble()));
            case "erf" -> new Map(modelName, nodeName, inputs, ScalarFunctions.erf());
            case "equal" -> new Join(modelName, nodeName, inputs, ScalarFunctions.equal());
            case "exp" -> new Map(modelName, nodeName, inputs, ScalarFunctions.exp());
            case "expand" -> new Expand(modelName, nodeName, inputs);
            case "floor" -> new Map(modelName, nodeName, inputs, ScalarFunctions.floor());
            case "gather" -> new Gather(modelName, nodeName, inputs, attributes);
            case "gemm" -> new Gemm(modelName, nodeName, inputs, attributes);
            case "greater" -> new Join(modelName, nodeName, inputs, ScalarFunctions.greater());
            case "identity" -> new Identity(modelName, nodeName, inputs);
            case "less" -> new Join(modelName, nodeName, inputs, ScalarFunctions.less());
            case "log" -> new Map(modelName, nodeName, inputs, ScalarFunctions.log());
            case "matmul" -> new MatMul(modelName, nodeName, inputs);
            case "max" -> new ConcatReduce(modelName, nodeName, inputs,
                                           com.yahoo.tensor.functions.Reduce.Aggregator.max);
            case "min" -> new ConcatReduce(modelName, nodeName, inputs,
                                           com.yahoo.tensor.functions.Reduce.Aggregator.min);
            case "mean" -> new ConcatReduce(modelName, nodeName, inputs,
                                            com.yahoo.tensor.functions.Reduce.Aggregator.avg);
            case "mul" -> new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
            case "neg" -> new Map(modelName, nodeName, inputs, ScalarFunctions.neg());
            case "pow" -> new Join(modelName, nodeName, inputs, ScalarFunctions.pow());
            case "range" -> new Range(modelName, nodeName, inputs);
            case "reshape" -> new Reshape(modelName, nodeName, inputs, attributes);
            case "reducel1" -> new Reduce(modelName, nodeName, inputs, attributes,
                                          com.yahoo.tensor.functions.Reduce.Aggregator.sum,
                                          ScalarFunctions.abs(), null);
            case "reducel2" -> new Reduce(modelName, nodeName, inputs, attributes,
                                          com.yahoo.tensor.functions.Reduce.Aggregator.sum,
                                          ScalarFunctions.square(), ScalarFunctions.sqrt());
            case "reducelogsum" -> new Reduce(modelName, nodeName, inputs, attributes,
                                              com.yahoo.tensor.functions.Reduce.Aggregator.sum,
                                              null, ScalarFunctions.log());
            case "reducelogsumexp" -> new Reduce(modelName, nodeName, inputs, attributes,
                                                 com.yahoo.tensor.functions.Reduce.Aggregator.sum,
                                                 ScalarFunctions.exp(), ScalarFunctions.log());
            case "reducemax" -> new Reduce(modelName, nodeName, inputs, attributes,
                                           com.yahoo.tensor.functions.Reduce.Aggregator.max);
            case "reducemean" -> new Reduce(modelName, nodeName, inputs, attributes,
                                            com.yahoo.tensor.functions.Reduce.Aggregator.avg);
            case "reducemin" -> new Reduce(modelName, nodeName, inputs, attributes,
                                           com.yahoo.tensor.functions.Reduce.Aggregator.min);
            case "reduceprod" -> new Reduce(modelName, nodeName, inputs, attributes,
                                            com.yahoo.tensor.functions.Reduce.Aggregator.prod);
            case "reducesum" -> new Reduce(modelName, nodeName, inputs, attributes,
                                           com.yahoo.tensor.functions.Reduce.Aggregator.sum);
            case "reducesumsquare" -> new Reduce(modelName, nodeName, inputs, attributes,
                                                 com.yahoo.tensor.functions.Reduce.Aggregator.sum,
                                                 ScalarFunctions.square(), null);
            case "reciprocal" -> new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal());
            case "relu" -> new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
            case "selu" -> new Map(modelName, nodeName, inputs,
                                   ScalarFunctions.selu(attributes.get("gamma").orElse(seluGamma).asDouble(),
                                                        attributes.get("alpha").orElse(seluAlpha).asDouble()));
            case "leakyrelu" -> new Map(modelName, nodeName, inputs,
                                        ScalarFunctions.leakyrelu(attributes.get("alpha").orElse(leakyReluAlpha).asDouble()));
            case "shape" -> new Shape(modelName, nodeName, inputs);
            case "sigmoid" -> new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid());
            case "sin" -> new Map(modelName, nodeName, inputs, ScalarFunctions.sin());
            case "slice" -> new Slice(modelName, nodeName, inputs, attributes);
            case "softmax" -> new Softmax(modelName, nodeName, inputs, attributes);
            case "sub" -> new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
            case "split" -> new Split(modelName, nodeName, inputs, attributes, outputIndex);
            case "squeeze" -> new Squeeze(modelName, nodeName, inputs, attributes);
            case "sqrt" -> new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt());
            case "square" -> new Map(modelName, nodeName, inputs, ScalarFunctions.square());
            case "where" -> new Select(modelName, nodeName, inputs);
            case "tan" -> new Map(modelName, nodeName, inputs, ScalarFunctions.tan());
            case "tanh" -> new Map(modelName, nodeName, inputs, ScalarFunctions.tanh());
            case "tile" -> new Tile(modelName, nodeName, inputs);
            case "transpose" -> new Transpose(modelName, nodeName, inputs, attributes);
            case "unsqueeze" -> new Unsqueeze(modelName, nodeName, inputs, attributes);
            default -> notImplemented(opType, inputs, modelName, nodeName);
        };
    }

    /** Creates the placeholder operation, with a warning attached, for an unmapped operation type. */
    private static IntermediateOperation notImplemented(String opType,
                                                        List<IntermediateOperation> inputs,
                                                        String modelName,
                                                        String nodeName) {
        IntermediateOperation op = new NoOp(modelName, nodeName, inputs);
        op.warning("Operation '" + opType + "' is currently not implemented");
        return op;
    }

    /**
     * Imports the graph of the given ONNX model into a new intermediate graph.
     *
     * @param modelName name given to the resulting graph
     * @param model     the ONNX model to import
     * @return the imported graph
     * @throws IllegalArgumentException if an operation reachable from a registered output
     *                                  carries a warning, or an output type cannot be verified
     */
    static IntermediateGraph importGraph(String modelName, Onnx.ModelProto model) {
        Onnx.GraphProto onnxGraph = model.getGraph();

        IntermediateGraph intermediateGraph = new IntermediateGraph(modelName);
        importOperations(onnxGraph, intermediateGraph);
        verifyNoWarnings(intermediateGraph);
        verifyOutputTypes(onnxGraph, intermediateGraph);

        return intermediateGraph;
    }

    /** Imports every graph output; their dependencies are imported recursively. */
    private static void importOperations(Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) {
        for (Onnx.ValueInfoProto valueInfo : onnxGraph.getOutputList()) {
            importOperation(valueInfo.getName(), onnxGraph, intermediateGraph);
        }
    }

    /**
     * Imports the operation producing the given name, importing its inputs first.
     * The name is resolved as a graph argument, then as a constant (initializer),
     * and otherwise as a node. Already imported names are returned from the graph.
     *
     * @param name              the tensor or node name to import
     * @param onnxGraph         the ONNX graph to read from
     * @param intermediateGraph the graph being built
     * @return the imported (or previously imported) operation
     */
    private static IntermediateOperation importOperation(String name,
                                                         Onnx.GraphProto onnxGraph,
                                                         IntermediateGraph intermediateGraph) {
        if (intermediateGraph.alreadyImported(name)) {
            return intermediateGraph.get(name);
        }
        IntermediateOperation operation;
        if (isArgumentTensor(name, onnxGraph)) {
            Onnx.ValueInfoProto valueInfoProto = getArgumentTensor(name, onnxGraph);
            if (valueInfoProto == null)
                throw new IllegalArgumentException("Could not find argument tensor '" + name + "'");
            OrderedTensorType type = TypeConverter.typeFrom(valueInfoProto.getType());
            operation = new Argument(intermediateGraph.name(), valueInfoProto.getName(), type);

            intermediateGraph.inputs(intermediateGraph.defaultSignature())
                    .put(IntermediateOperation.namePartOf(name), operation.vespaName());

        } else if (isConstantTensor(name, onnxGraph)) {
            Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph);
            OrderedTensorType defaultType = TypeConverter.typeFrom(tensorProto);
            operation = new Constant(intermediateGraph.name(), name, defaultType);
            operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type)));

        } else {
            Onnx.NodeProto node = getNodeFromGraph(name, onnxGraph);
            int outputIndex = getOutputIndex(node, name);
            List<IntermediateOperation> inputs = importOperationInputs(node, onnxGraph, intermediateGraph);
            operation = mapOperation(node, inputs, intermediateGraph, outputIndex);

            // propagate constant values if all inputs are constant
            if (operation.isConstant()) {
                operation.setConstantValueFunction(operation::evaluateAsConstant);
            }

            if (isOutputNode(name, onnxGraph)) {
                intermediateGraph.outputs(intermediateGraph.defaultSignature())
                        .put(IntermediateOperation.namePartOf(name), operation.name());
            }
        }
        // Register under both the operation's own name and the requested name,
        // so later lookups by either one find the same operation.
        intermediateGraph.put(operation.name(), operation);
        intermediateGraph.put(name, operation);

        return operation;
    }

    // Rules for initializers in ONNX:
    // When an initializer has the same name as a graph input, it specifies a default value for that input.
    // When an initializer has a name different from all graph inputs, it specifies a constant value.

    /** Returns whether the name is a graph input that has no initializer of the same name. */
    private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) {
        Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
        Onnx.TensorProto tensor = getConstantTensor(name, graph);
        return value != null && tensor == null;
    }

    /** Returns whether the graph has an initializer with the given name. */
    private static boolean isConstantTensor(String name, Onnx.GraphProto graph) {
        return getConstantTensor(name, graph) != null;
    }

    /** Returns the graph input with the given name, or null if there is none. */
    private static Onnx.ValueInfoProto getArgumentTensor(String name, Onnx.GraphProto graph) {
        for (Onnx.ValueInfoProto valueInfo : graph.getInputList()) {
            if (valueInfo.getName().equals(name)) {
                return valueInfo;
            }
        }
        return null;
    }

    /** Returns the graph initializer with the given name, or null if there is none. */
    private static Onnx.TensorProto getConstantTensor(String name, Onnx.GraphProto graph) {
        for (Onnx.TensorProto tensorProto : graph.getInitializerList()) {
            if (tensorProto.getName().equals(name)) {
                return tensorProto;
            }
        }
        return null;
    }

    /** Returns whether the name denotes one of the graph outputs. */
    private static boolean isOutputNode(String name, Onnx.GraphProto graph) {
        return getOutputNode(name, graph) != null;
    }

    /**
     * Returns the graph output matching the name, either by its full name or by the
     * name part of its full name, or null if no output matches.
     */
    private static Onnx.ValueInfoProto getOutputNode(String name, Onnx.GraphProto graph) {
        for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) {
            if (valueInfo.getName().equals(name)) {
                return valueInfo;
            }
            String nodeName = IntermediateOperation.namePartOf(valueInfo.getName());
            if (nodeName.equals(name)) {
                return valueInfo;
            }
        }
        return null;
    }

    /** Imports all inputs of the node, in the order the node lists them. */
    private static List<IntermediateOperation> importOperationInputs(Onnx.NodeProto node,
                                                                     Onnx.GraphProto onnxGraph,
                                                                     IntermediateGraph intermediateGraph) {
        return node.getInputList().stream()
                .map(nodeName -> importOperation(nodeName, onnxGraph, intermediateGraph))
                .toList();
    }

    /**
     * Fails the import if any operation reachable from a registered output has warnings.
     *
     * @throws IllegalArgumentException listing the collected warnings, one per line
     */
    private static void verifyNoWarnings(IntermediateGraph intermediateGraph) {
        for (java.util.Map.Entry<String, String> output : intermediateGraph.outputs(intermediateGraph.defaultSignature()).entrySet()) {
            IntermediateOperation operation = intermediateGraph.get(output.getValue());
            Set<String> warnings = getWarnings(operation);
            if (!warnings.isEmpty()) {
                throw new IllegalArgumentException("Could not import " + intermediateGraph.name() + ": " + String.join("\n", warnings));
            }
        }
    }

    /**
     * Verifies the type of each registered output against the type declared for it in the ONNX graph.
     *
     * @throws IllegalArgumentException if an output operation has no type, or its
     *                                  declaration cannot be found in the ONNX graph
     */
    private static void verifyOutputTypes(Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) {
        for (java.util.Map.Entry<String, String> output : intermediateGraph.outputs(intermediateGraph.defaultSignature()).entrySet()) {
            IntermediateOperation operation = intermediateGraph.get(output.getValue());
            Onnx.ValueInfoProto onnxNode = getOutputNode(output.getKey(), onnxGraph);
            OrderedTensorType type = operation.type().orElseThrow(
                    () -> new IllegalArgumentException("Output of '" + output.getValue() + "' has no type."));
            // Fail with a descriptive message instead of a NullPointerException below.
            if (onnxNode == null)
                throw new IllegalArgumentException("Output '" + output.getKey() + "' was not found among the outputs of the ONNX graph");
            TypeConverter.verifyType(onnxNode.getType(), type);
        }
    }

    /**
     * Finds the node a name refers to: first by node name, then as a node producing
     * the name as an output, and finally as a node taking the name as an input.
     *
     * @throws IllegalArgumentException if no node matches
     */
    private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) {
        // Optional.or evaluates each fallback lazily, only when the previous lookup found nothing.
        return getNodeFromGraphNames(nodeName, graph)
                .or(() -> getNodeFromGraphOutputs(nodeName, graph))
                .or(() -> getNodeFromGraphInputs(nodeName, graph))
                .orElseThrow(() -> new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph"));
    }

    /** Returns the first node listing the given name among its outputs. */
    private static Optional<Onnx.NodeProto> getNodeFromGraphOutputs(String nodeName, Onnx.GraphProto graph) {
        return graph.getNodeList().stream()
                .filter(node -> node.getOutputList().contains(nodeName))
                .findFirst();
    }

    /** Returns the first node listing the given name among its inputs. */
    private static Optional<Onnx.NodeProto> getNodeFromGraphInputs(String nodeName, Onnx.GraphProto graph) {
        return graph.getNodeList().stream()
                .filter(node -> node.getInputList().contains(nodeName))
                .findFirst();
    }

    /** Returns the first node whose own name equals the given name. */
    private static Optional<Onnx.NodeProto> getNodeFromGraphNames(String nodeName, Onnx.GraphProto graph) {
        return graph.getNodeList().stream()
                .filter(node -> node.getName().equals(nodeName))
                .findFirst();
    }

    /**
     * Returns the position of the output name among the node's outputs,
     * or 0 if the node has no outputs or does not list the name.
     */
    private static int getOutputIndex(Onnx.NodeProto node, String outputName) {
        return node.getOutputCount() == 0 ? 0 : Math.max(node.getOutputList().indexOf(outputName), 0);
    }

    /**
     * Returns the explicit node name if set, otherwise the name of the node's single output.
     *
     * @throws IllegalArgumentException if the node has no name and not exactly one output
     */
    private static String getNodeName(Onnx.NodeProto node) {
        String nodeName = node.getName();
        if (!nodeName.isEmpty())
            return nodeName;
        if (node.getOutputCount() == 1)
            return node.getOutput(0);
        throw new IllegalArgumentException("Unable to find a suitable name for node '" + node.toString() + "'. " +
                                           "Either no explicit name given or no single output name.");
    }

    /** Returns the distinct warnings of the operation and of everything it transitively depends on. */
    private static Set<String> getWarnings(IntermediateOperation op) {
        java.util.Map<String, Set<String>> warnings = new HashMap<>();
        getWarnings(op, warnings);
        return warnings.values().stream().flatMap(Collection::stream).collect(Collectors.toSet());
    }

    /** Collects warnings per operation name, skipping operations that are already recorded. */
    private static void getWarnings(IntermediateOperation op, java.util.Map<String, Set<String>> warnings) {
        if (warnings.containsKey(op.name())) return;
        op.inputs().forEach(input -> getWarnings(input, warnings));
        warnings.put(op.name(), new HashSet<>(op.warnings()));
    }

}