diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java | 40 |
1 files changed, 28 insertions, 12 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index ffc64c38f16..d14ad033a69 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -2,7 +2,6 @@ package ai.vespa.rankingexpression.importer.onnx; -import ai.vespa.rankingexpression.importer.operations.ExpandDims; import ai.vespa.rankingexpression.importer.operations.Gather; import ai.vespa.rankingexpression.importer.operations.OnnxCast; import ai.vespa.rankingexpression.importer.operations.Gemm; @@ -12,7 +11,10 @@ import ai.vespa.rankingexpression.importer.operations.Reduce; import ai.vespa.rankingexpression.importer.operations.Select; import ai.vespa.rankingexpression.importer.operations.Slice; import ai.vespa.rankingexpression.importer.operations.Softmax; +import ai.vespa.rankingexpression.importer.operations.Split; import ai.vespa.rankingexpression.importer.operations.Squeeze; +import ai.vespa.rankingexpression.importer.operations.Tile; +import ai.vespa.rankingexpression.importer.operations.Transpose; import ai.vespa.rankingexpression.importer.operations.Unsqueeze; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; @@ -32,6 +34,8 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.functions.ScalarFunctions; import onnx.Onnx; +import java.util.Collection; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Optional; @@ -53,19 +57,21 @@ class GraphImporter { private static IntermediateOperation mapOperation(Onnx.NodeProto node, List<IntermediateOperation> inputs, - IntermediateGraph graph) { + IntermediateGraph graph, + int outputIndex) { String type = node.getOpType(); String modelName = graph.name(); String nodeName = getNodeName(node); AttributeConverter attributes = AttributeConverter.convert(node); - return mapOperation(type, inputs, modelName, nodeName, attributes); + return mapOperation(type, inputs, modelName, nodeName, attributes, outputIndex); } static IntermediateOperation mapOperation(String opType, List<IntermediateOperation> inputs, String modelName, String nodeName, - AttributeConverter attributes) { + AttributeConverter attributes, + int outputIndex) { switch (opType.toLowerCase()) { case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs()); case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos()); @@ -115,12 +121,15 @@ class GraphImporter { case "slice": return new Slice(modelName, nodeName, inputs, attributes); case "softmax": return new Softmax(modelName, nodeName, inputs, attributes); case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); + case "split": return new Split(modelName, nodeName, inputs, attributes, outputIndex); case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes); case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt()); case "square": return new Map(modelName, nodeName, inputs, ScalarFunctions.square()); case "where": return new Select(modelName, nodeName, inputs); case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan()); case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh()); + case "tile": return new Tile(modelName, nodeName, inputs); + case "transpose": return new Transpose(modelName, nodeName, inputs, attributes); case "unsqueeze": return new Unsqueeze(modelName, nodeName, inputs, attributes); } @@ -168,11 +177,11 @@ class GraphImporter { OrderedTensorType defaultType = TypeConverter.typeFrom(tensorProto); operation = new Constant(intermediateGraph.name(), name, defaultType); operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type))); - } else { Onnx.NodeProto node = getNodeFromGraph(name, onnxGraph); + int outputIndex = getOutputIndex(node, name); List<IntermediateOperation> inputs = importOperationInputs(node, onnxGraph, intermediateGraph); - operation = mapOperation(node, inputs, intermediateGraph); + operation = mapOperation(node, inputs, intermediateGraph, outputIndex); // propagate constant values if all inputs are constant if (operation.isConstant()) { @@ -185,7 +194,7 @@ class GraphImporter { } } intermediateGraph.put(operation.name(), operation); - + intermediateGraph.put(name, operation); return operation; } @@ -296,6 +305,10 @@ class GraphImporter { return graph.getNodeList().stream().filter(node -> node.getName().equals(nodeName)).findFirst(); } + private static int getOutputIndex(Onnx.NodeProto node, String outputName) { + return node.getOutputCount() == 0 ? 0 : Math.max(node.getOutputList().indexOf(outputName), 0); + } + private static String getNodeName(Onnx.NodeProto node) { String nodeName = node.getName(); if (nodeName.length() > 0) @@ -307,11 +320,14 @@ class GraphImporter { } private static Set<String> getWarnings(IntermediateOperation op) { - Set<String> warnings = new HashSet<>(op.warnings()); - for (IntermediateOperation input : op.inputs()) { - warnings.addAll(getWarnings(input)); - } - return warnings; + java.util.Map<String, Set<String>> warnings = new HashMap<>(); + getWarnings(op, warnings); + return warnings.values().stream().flatMap(Collection::stream).collect(Collectors.toSet()); } + private static void getWarnings(IntermediateOperation op, java.util.Map<String, Set<String>> warnings) { + if (warnings.containsKey(op.name())) return; + op.inputs().forEach(input -> getWarnings(input, warnings)); + warnings.put(op.name(), new HashSet<>(op.warnings())); + } } |