diff options
author | Lester Solbakken <lesters@oath.com> | 2018-05-28 11:33:17 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-05-28 11:33:17 +0200 |
commit | 88d06ec474f727d41963b6aa65c2382ccc01c3f5 (patch) | |
tree | d2a871f2e6870daadf674fca0f350692cbdc42a3 /searchlib | |
parent | 3c1334090cef6fb0891515040ad900702275ccea (diff) |
Add ONNX pseudo ranking feature
Diffstat (limited to 'searchlib')
5 files changed, 177 insertions, 52 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java index 047d1b187f5..295f9228316 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java @@ -13,8 +13,10 @@ import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.yolean.Exceptions; import onnx.Onnx; +import java.io.File; import java.io.FileInputStream; import java.io.IOException; import java.util.Collection; @@ -22,6 +24,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.logging.Logger; import java.util.stream.Collectors; /** @@ -31,48 +34,64 @@ import java.util.stream.Collectors; */ public class OnnxImporter { - public OnnxModel importModel(String modelPath, String outputNode) { + private static final Logger log = Logger.getLogger(OnnxImporter.class.getName()); + + public OnnxModel importModel(String modelName, File modelDir) { + return importModel(modelName, modelDir.toString()); + } + + public OnnxModel importModel(String modelName, String modelPath) { try (FileInputStream inputStream = new FileInputStream(modelPath)) { Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); - return importModel(model, outputNode); + return importModel(modelName, model); } catch (IOException e) { throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e); } } - public OnnxModel importModel(Onnx.ModelProto model, String outputNode) { - return importGraph(model.getGraph(), outputNode); + public OnnxModel importModel(String modelName, Onnx.ModelProto model) { + return importGraph(modelName, model.getGraph()); } - private static OnnxModel importGraph(Onnx.GraphProto graph, String outputNode) { - OnnxModel model = new OnnxModel(outputNode); + private static OnnxModel importGraph(String modelName, Onnx.GraphProto graph) { + OnnxModel model = new OnnxModel(modelName); OperationIndex index = new OperationIndex(); - OnnxOperation output = importNode(outputNode, graph, index); - output.type().orElseThrow(() -> new IllegalArgumentException("Output of '" + outputNode + "' has no type.")) - .verifyType(getOutputNode(outputNode, graph).getType()); + importNodes(graph, model, index); + verifyOutputTypes(graph, model, index); + findDimensionNames(model, index); + importExpressions(model, index); - findDimensionNames(output); - importExpressions(output, model); + reportWarnings(model, index); return model; } - private static OnnxOperation importNode(String nodeName, Onnx.GraphProto graph, OperationIndex index) { - if (index.alreadyImported(nodeName)) { - return index.get(nodeName); + private static void importNodes(Onnx.GraphProto graph, OnnxModel model, OperationIndex index) { + for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) { + importNode(valueInfo.getName(), graph, model, index); + } + } + + private static OnnxOperation importNode(String name, Onnx.GraphProto graph, OnnxModel model, OperationIndex index) { + if (index.alreadyImported(name)) { + return index.get(name); } OnnxOperation operation; - if (isArgumentTensor(nodeName, graph)) { - operation = new Argument(getArgumentTensor(nodeName, graph)); - } else if (isConstantTensor(nodeName, graph)) { - operation = new Constant(getConstantTensor(nodeName, graph)); + if (isArgumentTensor(name, graph)) { + operation = new Argument(getArgumentTensor(name, graph)); + model.input(OnnxOperation.namePartOf(name), operation.vespaName()); + } else if (isConstantTensor(name, graph)) { + operation = new Constant(model.name(), getConstantTensor(name, graph)); } else { - Onnx.NodeProto node = getNodeFromGraph(nodeName, graph); - List<OnnxOperation> inputs = importNodeInputs(node, graph, index); + Onnx.NodeProto node = getNodeFromGraph(name, graph); + List<OnnxOperation> inputs = importNodeInputs(node, graph, model, index); operation = OperationMapper.get(node, inputs); + if (isOutputNode(name, graph)) { + model.output(OnnxOperation.namePartOf(name), operation.vespaName()); + } } - index.put(nodeName, operation); + index.put(operation.vespaName(), operation); return operation; } @@ -113,8 +132,11 @@ public class OnnxImporter { private static Onnx.ValueInfoProto getOutputNode(String name, Onnx.GraphProto graph) { for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) { - Onnx.NodeProto node = getNodeFromGraph(valueInfo.getName(), graph); - if (node.getName().equals(name)) { + if (valueInfo.getName().equals(name)) { + return valueInfo; + } + String nodeName = OnnxOperation.namePartOf(valueInfo.getName()); + if (nodeName.equals(name)) { return valueInfo; } } @@ -123,18 +145,34 @@ public class OnnxImporter { private static List<OnnxOperation> importNodeInputs(Onnx.NodeProto node, Onnx.GraphProto graph, + OnnxModel model, OperationIndex index) { return node.getInputList().stream() - .map(nodeName -> importNode(nodeName, graph, index)) + .map(nodeName -> importNode(nodeName, graph, model, index)) .collect(Collectors.toList()); } + private static void verifyOutputTypes(Onnx.GraphProto graph, OnnxModel model, OperationIndex index) { + for (String outputName : model.outputs().values()) { + OnnxOperation operation = index.get(outputName); + Onnx.ValueInfoProto onnxNode = getOutputNode(outputName, graph); + operation.type().orElseThrow( + () -> new IllegalArgumentException("Output of '" + outputName + "' has no type.")) + .verifyType(onnxNode.getType()); + } + } + + /** Find dimension names to avoid excessive renaming while evaluating the model. */ - private static void findDimensionNames(OnnxOperation output) { + private static void findDimensionNames(OnnxModel model, OperationIndex index) { DimensionRenamer renamer = new DimensionRenamer(); - addDimensionNameConstraints(output, renamer); + for (String output : model.outputs().values()) { + addDimensionNameConstraints(index.get(output), renamer); + } renamer.solve(); - renameDimensions(output, renamer); + for (String output : model.outputs().values()) { + renameDimensions(index.get(output), renamer); + } } private static void addDimensionNameConstraints(OnnxOperation operation, DimensionRenamer renamer) { @@ -151,10 +189,17 @@ public class OnnxImporter { } } - private static void importExpressions(OnnxOperation output, OnnxModel model) { - Optional<TensorFunction> function = importExpression(output, model); - if (!function.isPresent()) { - throw new IllegalArgumentException("No valid output function could be found."); + private static void importExpressions(OnnxModel model, OperationIndex index) { + for (String outputName : model.outputs().values()) { + try { + Optional<TensorFunction> function = importExpression(index.get(outputName), model); + if (!function.isPresent()) { + model.skippedOutput(outputName, "No valid output function could be found."); + } + } + catch (IllegalArgumentException e) { + model.skippedOutput(outputName, Exceptions.toMessageString(e)); + } } } @@ -167,7 +212,7 @@ public class OnnxImporter { } importInputExpressions(operation, model); importRankingExpression(operation, model); - importInputExpression(operation, model); + importArgumentExpression(operation, model); return operation.function(); } @@ -204,7 +249,7 @@ public class OnnxImporter { if (!model.expressions().containsKey(name)) { TensorFunction function = operation.function().get(); - if (name.equals(model.output())) { + if (model.outputs().containsKey(name)) { OrderedTensorType operationType = operation.type().get(); OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType); if ( ! operationType.equals(standardNamingType)) { @@ -228,7 +273,7 @@ public class OnnxImporter { } } - private static void importInputExpression(OnnxOperation operation, OnnxModel model) { + private static void importArgumentExpression(OnnxOperation operation, OnnxModel model) { if (operation.isInput()) { // All inputs must have dimensions with standard naming convention: d0, d1, ... OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get()); @@ -237,6 +282,20 @@ public class OnnxImporter { } } + private static void reportWarnings(OnnxModel model, OperationIndex index) { + for (String output : model.outputs().values()) { + reportWarnings(model, index.get(output)); + } + } + + private static void reportWarnings(OnnxModel model, OnnxOperation operation) { + for (String warning : operation.warnings()) { + model.importWarning(warning); + } + for (OnnxOperation input : operation.inputs()) { + reportWarnings(model, input); + } + } private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) { boolean hasPortNumber = nodeName.contains(":"); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java index df108fcbbe7..027c1d7ff9d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java @@ -14,29 +14,73 @@ import java.util.regex.Pattern; /** * The result of importing an ONNX model into Vespa. * + * @author bratseth * @author lesters */ public class OnnxModel { - public OnnxModel(String outputNode) { - this.output = outputNode; + private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*"); + + private final String name; + + public OnnxModel(String name) { + if ( ! nameRegexp.matcher(name).matches()) + throw new IllegalArgumentException("A TensorFlow model name can only contain [A-Za-z0-9_], but is '" + + name + "'"); + this.name = name; } - private final String output; + /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */ + public String name() { return name; } + + private final Map<String, String> inputs = new HashMap<>(); + private final Map<String, String> outputs = new HashMap<>(); + private final Map<String, String> skippedOutputs = new HashMap<>(); + private final List<String> importWarnings = new ArrayList<>(); + private final Map<String, TensorType> arguments = new HashMap<>(); private final Map<String, Tensor> smallConstants = new HashMap<>(); private final Map<String, Tensor> largeConstants = new HashMap<>(); private final Map<String, RankingExpression> expressions = new HashMap<>(); + private final Map<String, RankingExpression> macros = new HashMap<>(); private final Map<String, TensorType> requiredMacros = new HashMap<>(); + void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); } + void output(String name, String expressionName) { outputs.put(name, expressionName); } + void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); } + void importWarning(String warning) { importWarnings.add(warning); } void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); } void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); } void expression(String name, RankingExpression expression) { expressions.put(name, expression); } + void macro(String name, RankingExpression expression) { macros.put(name, expression); } void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); } - /** Return the name of the output node for this model */ - public String output() { return output; } + /** + * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name + * to argument (Placeholder) name in the owner of this + */ + public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); } + + /** Returns arguments().get(inputs.get(name)), e.g the type of the argument this input references */ + public TensorType inputArgument(String inputName) { return arguments().get(inputs.get(inputName)); } + + /** Returns an immutable list of the expression names of this */ + public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); } + + /** + * Returns an immutable list of the outputs of this which could not be imported, + * with a string detailing the reason for each + */ + public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); } + + /** + * Returns an immutable list of possibly non-fatal warnings encountered during import. + */ + public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); } + + /** Returns expressions().get(outputs.get(outputName)), e.g the expression this output references */ + public RankingExpression outputExpression(String outputName) { return expressions().get(outputs.get(outputName)); } /** Returns an immutable map of the arguments (inputs) of this */ public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); } @@ -57,6 +101,9 @@ public class OnnxModel { */ public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } + /** Returns an immutable map of macros that are part of this model */ + public Map<String, RankingExpression> macros() { return Collections.unmodifiableMap(macros); } + /** Returns an immutable map of the macros that must be provided by the environment running this model */ public Map<String, TensorType> requiredMacros() { return Collections.unmodifiableMap(requiredMacros); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java index ab650bf8d77..b5494477227 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java @@ -15,18 +15,19 @@ import java.util.Optional; public class Constant extends OnnxOperation { + final String modelName; final Onnx.TensorProto tensorProto; - public Constant(Onnx.TensorProto tensorProto) { + public Constant(String modelName, Onnx.TensorProto tensorProto) { super(null, Collections.emptyList()); + this.modelName = modelName; this.tensorProto = tensorProto; } /** todo: Constant names are prefixed by "modelName_" to avoid name conflicts between models */ @Override public String vespaName() { -// return modelName() + "_" + super.vespaName(); - return vespaName(tensorProto.getName()); + return modelName + "_" + vespaName(tensorProto.getName()); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java index 2c8003f5951..3c9f01c5e1c 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java @@ -92,7 +92,7 @@ public abstract class OnnxOperation { /** Retrieve the valid Vespa name of this node */ public String vespaName() { return vespaName(node.getName()); } - public String vespaName(String name) { return name != null ? name.replace('/', '_').replace(':','_') : null; } + public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; } /** Retrieve the list of warnings produced during its lifetime */ public List<String> warnings() { return Collections.unmodifiableList(importWarnings); } @@ -116,4 +116,22 @@ public abstract class OnnxOperation { return verifyInputs(expected, OnnxOperation::function); } + /** + * A method signature input and output has the form name:index. + * This returns the name part without the index. + */ + public static String namePartOf(String name) { + name = name.startsWith("^") ? name.substring(1) : name; + return name.split(":")[0]; + } + + /** + * This return the output index part. Indexes are used for nodes with + * multiple outputs. + */ + public static int indexPartOf(String name) { + int i = name.indexOf(":"); + return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1)); + } + } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java index e118c2b885a..4b68cd40a08 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java @@ -24,18 +24,18 @@ public class OnnxMnistSoftmaxImportTestCase { @Test public void testMnistSoftmaxImport() throws IOException { - OnnxModel model = new OnnxImporter().importModel("src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx", "add"); + OnnxModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx"); // Check constants assertEquals(2, model.largeConstants().size()); - Tensor constant0 = model.largeConstants().get("Variable_0"); + Tensor constant0 = model.largeConstants().get("test_Variable"); assertNotNull(constant0); assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(), constant0.type()); assertEquals(7840, constant0.size()); - Tensor constant1 = model.largeConstants().get("Variable_1_0"); + Tensor constant1 = model.largeConstants().get("test_Variable_1"); assertNotNull(constant1); assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type()); @@ -43,15 +43,15 @@ public class OnnxMnistSoftmaxImportTestCase { // Check required macros (inputs) assertEquals(1, model.requiredMacros().size()); - assertTrue(model.requiredMacros().containsKey("Placeholder_0")); + assertTrue(model.requiredMacros().containsKey("Placeholder")); assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), - model.requiredMacros().get("Placeholder_0")); + model.requiredMacros().get("Placeholder")); // Check outputs - RankingExpression output = model.expressions().get("add"); + RankingExpression output = model.outputExpression("add"); assertNotNull(output); assertEquals("add", output.getName()); - assertEquals("join(reduce(join(rename(Placeholder_0, (d0, d1), (d0, d2)), constant(Variable_0), f(a,b)(a * b)), sum, d2), constant(Variable_1_0), f(a,b)(a + b))", + assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))", output.getRoot().toString()); } @@ -62,7 +62,7 @@ public class OnnxMnistSoftmaxImportTestCase { Tensor argument = placeholderArgument(); Tensor tensorFlowResult = evaluateTensorFlowModel(tfModelPath, argument, "Placeholder", "add"); - Tensor onnxResult = evaluateOnnxModel(onnxModelPath, argument, "Placeholder_0", "add"); + Tensor onnxResult = evaluateOnnxModel(onnxModelPath, argument, "Placeholder", "add"); assertEquals("Operation 'add' produces equal results", tensorFlowResult, onnxResult); } @@ -74,7 +74,7 @@ public class OnnxMnistSoftmaxImportTestCase { } private Tensor evaluateOnnxModel(String path, Tensor argument, String input, String output) { - OnnxModel model = new OnnxImporter().importModel(path, output); + OnnxModel model = new OnnxImporter().importModel("test", path); return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input); } |