diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java | 56 |
1 files changed, 45 insertions, 11 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index ffc64c38f16..c98a5c7d4f5 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -2,7 +2,6 @@ package ai.vespa.rankingexpression.importer.onnx; -import ai.vespa.rankingexpression.importer.operations.ExpandDims; import ai.vespa.rankingexpression.importer.operations.Gather; import ai.vespa.rankingexpression.importer.operations.OnnxCast; import ai.vespa.rankingexpression.importer.operations.Gemm; @@ -12,7 +11,10 @@ import ai.vespa.rankingexpression.importer.operations.Reduce; import ai.vespa.rankingexpression.importer.operations.Select; import ai.vespa.rankingexpression.importer.operations.Slice; import ai.vespa.rankingexpression.importer.operations.Softmax; +import ai.vespa.rankingexpression.importer.operations.Split; import ai.vespa.rankingexpression.importer.operations.Squeeze; +import ai.vespa.rankingexpression.importer.operations.Tile; +import ai.vespa.rankingexpression.importer.operations.Transpose; import ai.vespa.rankingexpression.importer.operations.Unsqueeze; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; @@ -32,6 +34,8 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.functions.ScalarFunctions; import onnx.Onnx; +import java.util.Collection; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Optional; @@ -53,19 +57,21 @@ class GraphImporter { private static IntermediateOperation mapOperation(Onnx.NodeProto node, List<IntermediateOperation> inputs, - IntermediateGraph graph) { + IntermediateGraph graph, + int outputIndex) { String type = node.getOpType(); String modelName = graph.name(); String nodeName = getNodeName(node); AttributeConverter attributes = AttributeConverter.convert(node); - return mapOperation(type, inputs, modelName, nodeName, attributes); + return mapOperation(type, inputs, modelName, nodeName, attributes, outputIndex); } static IntermediateOperation mapOperation(String opType, List<IntermediateOperation> inputs, String modelName, String nodeName, - AttributeConverter attributes) { + AttributeConverter attributes, + int outputIndex) { switch (opType.toLowerCase()) { case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs()); case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos()); @@ -115,17 +121,21 @@ class GraphImporter { case "slice": return new Slice(modelName, nodeName, inputs, attributes); case "softmax": return new Softmax(modelName, nodeName, inputs, attributes); case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); + case "split": return new Split(modelName, nodeName, inputs, attributes, outputIndex); case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes); case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt()); case "square": return new Map(modelName, nodeName, inputs, ScalarFunctions.square()); case "where": return new Select(modelName, nodeName, inputs); case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan()); case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh()); + case "tile": return new Tile(modelName, nodeName, inputs); + case "transpose": return new Transpose(modelName, nodeName, inputs, attributes); case "unsqueeze": return new Unsqueeze(modelName, nodeName, inputs, attributes); } IntermediateOperation op = new NoOp(modelName, nodeName, inputs); op.warning("Operation '" + opType + "' is currently not implemented"); + System.out.println(nodeName + ": operation '" + opType + "' is currently not implemented"); return op; } @@ -133,10 +143,15 @@ class GraphImporter { Onnx.GraphProto onnxGraph = model.getGraph(); IntermediateGraph intermediateGraph = new IntermediateGraph(modelName); + System.out.println("Importing operations..."); importOperations(onnxGraph, intermediateGraph); + System.out.println("Verifying no warnings..."); verifyNoWarnings(intermediateGraph); + System.out.println("Verifying output types..."); verifyOutputTypes(onnxGraph, intermediateGraph); + System.out.println("Ok..."); + return intermediateGraph; } @@ -150,8 +165,10 @@ class GraphImporter { Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) { if (intermediateGraph.alreadyImported(name)) { +// System.out.println("Trying to import '" + name + "' but is was already imported."); return intermediateGraph.get(name); } +// System.out.println("Importing '" + name + "' ..."); IntermediateOperation operation; if (isArgumentTensor(name, onnxGraph)) { Onnx.ValueInfoProto valueInfoProto = getArgumentTensor(name, onnxGraph); @@ -163,16 +180,21 @@ class GraphImporter { intermediateGraph.inputs(intermediateGraph.defaultSignature()) .put(IntermediateOperation.namePartOf(name), operation.vespaName()); +// System.out.println(" '" + name + "' imported as argument..."); + } else if (isConstantTensor(name, onnxGraph)) { Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph); OrderedTensorType defaultType = TypeConverter.typeFrom(tensorProto); operation = new Constant(intermediateGraph.name(), name, defaultType); operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type))); +// System.out.println(" '" + name + "' imported as constant..."); + } else { Onnx.NodeProto node = getNodeFromGraph(name, onnxGraph); + int outputIndex = getOutputIndex(node, name); List<IntermediateOperation> inputs = importOperationInputs(node, onnxGraph, intermediateGraph); - operation = mapOperation(node, inputs, intermediateGraph); + operation = mapOperation(node, inputs, intermediateGraph, outputIndex); // propagate constant values if all inputs are constant if (operation.isConstant()) { @@ -183,8 +205,12 @@ class GraphImporter { intermediateGraph.outputs(intermediateGraph.defaultSignature()) .put(IntermediateOperation.namePartOf(name), operation.name()); } + +// System.out.println(" '" + name + "' imported as normal..."); + } intermediateGraph.put(operation.name(), operation); + intermediateGraph.put(name, operation); return operation; } @@ -262,7 +288,8 @@ class GraphImporter { Onnx.ValueInfoProto onnxNode = getOutputNode(output.getKey(), onnxGraph); OrderedTensorType type = operation.type().orElseThrow( () -> new IllegalArgumentException("Output of '" + output.getValue() + "' has no type.")); - TypeConverter.verifyType(onnxNode.getType(), type); + System.out.println(onnxNode.getType() + " vs. " + type); + //TypeConverter.verifyType(onnxNode.getType(), type); } } @@ -296,6 +323,10 @@ class GraphImporter { return graph.getNodeList().stream().filter(node -> node.getName().equals(nodeName)).findFirst(); } + private static int getOutputIndex(Onnx.NodeProto node, String outputName) { + return node.getOutputCount() == 0 ? 0 : Math.max(node.getOutputList().indexOf(outputName), 0); + } + private static String getNodeName(Onnx.NodeProto node) { String nodeName = node.getName(); if (nodeName.length() > 0) @@ -307,11 +338,14 @@ class GraphImporter { } private static Set<String> getWarnings(IntermediateOperation op) { - Set<String> warnings = new HashSet<>(op.warnings()); - for (IntermediateOperation input : op.inputs()) { - warnings.addAll(getWarnings(input)); - } - return warnings; + java.util.Map<String, Set<String>> warnings = new HashMap<>(); + getWarnings(op, warnings); + return warnings.values().stream().flatMap(Collection::stream).collect(Collectors.toSet()); } + private static void getWarnings(IntermediateOperation op, java.util.Map<String, Set<String>> warnings) { + if (warnings.containsKey(op.name())) return; + op.inputs().forEach(input -> getWarnings(input, warnings)); + warnings.put(op.name(), new HashSet<>(op.warnings())); + } } |