From 88d06ec474f727d41963b6aa65c2382ccc01c3f5 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Mon, 28 May 2018 11:33:17 +0200 Subject: Add ONNX pseudo ranking feature --- .../expressiontransforms/ExpressionTransforms.java | 1 + .../expressiontransforms/OnnxFeatureConverter.java | 693 +++++++++++++++++++++ .../integration/onnx/models/mnist_softmax.onnx | Bin 0 -> 31758 bytes .../RankingExpressionWithOnnxTestCase.java | 333 ++++++++++ .../RankingExpressionWithTensorFlowTestCase.java | 4 +- .../integration/onnx/OnnxImporter.java | 127 +++- .../integration/onnx/OnnxModel.java | 57 +- .../onnx/importer/operations/Constant.java | 7 +- .../onnx/importer/operations/OnnxOperation.java | 20 +- .../onnx/OnnxMnistSoftmaxImportTestCase.java | 18 +- 10 files changed, 1206 insertions(+), 54 deletions(-) create mode 100644 config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java create mode 100644 config-model/src/test/integration/onnx/models/mnist_softmax.onnx create mode 100644 config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java index 67d60b08ab0..6ca16c1559d 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java @@ -21,6 +21,7 @@ public class ExpressionTransforms { private final List transforms = ImmutableList.of(new TensorFlowFeatureConverter(), + new OnnxFeatureConverter(), new ConstantDereferencer(), new ConstantTensorTransformer(), new MacroInliner(), diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java new file mode 100644 index 00000000000..6d44b130996 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -0,0 +1,693 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package com.yahoo.searchdefinition.expressiontransforms; + +import com.google.common.base.Joiner; +import com.yahoo.collections.Pair; +import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.application.provider.FilesApplicationPackage; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; +import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.searchdefinition.FeatureNames; +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxImporter; +import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxModel; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.Join; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.tensor.serialization.TypedBinaryFormat; + +import java.io.BufferedReader; +import java.io.File; +import java.io.IOException; +import java.io.StringReader; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Replaces instances of the onnx(model-path, output) + * pseudofeature with the native Vespa ranking expression implementing + * the same computation. + * + * @author bratseth + * @author lesters + */ +public class OnnxFeatureConverter extends ExpressionTransformer { + + private final OnnxImporter onnxImporter = new OnnxImporter(); + + /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ + private final Map importedModels = new HashMap<>(); + + @Override + public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { + if (node instanceof ReferenceNode) + return transformFeature((ReferenceNode) node, context); + else if (node instanceof CompositeNode) + return super.transformChildren((CompositeNode) node, context); + else + return node; + } + + private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { + if ( ! 
feature.getName().equals("onnx")) return feature; + + try { + ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments()); + if ( ! store.hasStoredModel()) // not converted yet - access Onnx model files + return transformFromOnnxModel(store, context.rankProfile(), context.queryProfiles()); + else + return transformFromStoredModel(store, context.rankProfile()); + } + catch (IllegalArgumentException | UncheckedIOException e) { + throw new IllegalArgumentException("Could not use Onnx model from " + feature, e); + } + } + + private ExpressionNode transformFromOnnxModel(ModelStore store, + RankProfile profile, + QueryProfileRegistry queryProfiles) { + OnnxModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), + k -> onnxImporter.importModel(store.arguments().modelName(), + store.onnxModelDir())); + + // Add constants + Set constantsReplacedByMacros = new HashSet<>(); + model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); + model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles, + constantsReplacedByMacros, k, v)); + + // Find the specified expression + String output = chooseOutput(model, store.arguments().output()); + if (model.skippedOutputs().containsKey(output)) { + String message = "Could not import Onnx model output '" + output + "'"; + if (!model.skippedOutputs().get(output).isEmpty()) { + message += ": " + model.skippedOutputs().get(output); + } + if (!model.importWarnings().isEmpty()) { + message += ": " + String.join(", ", model.importWarnings()); + } + throw new IllegalArgumentException(message); + } + + RankingExpression expression = model.expressions().get(output); + expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); + verifyRequiredMacros(expression, model, profile, queryProfiles); + addGeneratedMacros(model, profile); + reduceBatchDimensions(expression, model, profile, queryProfiles); + + model.macros().forEach((k, v) -> transformGeneratedMacro(store, profile, constantsReplacedByMacros, k, v)); + + store.writeConverted(expression); + return expression.getRoot(); + } + + private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) { + for (Pair constant : store.readSmallConstants()) + profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); + + for (RankingConstant constant : store.readLargeConstants()) { + if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName())) + profile.getSearch().addRankingConstant(constant); + } + + for (Pair macro : store.readMacros()) { + addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond()); + } + + return store.readConverted().getRoot(); + } + + /** + * Returns the specified, existing output expression, or the only output expression if no output name is specified. + * Throws IllegalArgumentException in all other cases. + */ + private String chooseOutput(OnnxModel model, Optional outputName) { + if ( ! 
outputName.isPresent()) { + if (model.outputs().size() == 0) + throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(model)); + if (model.outputs().size() > 1) + throw new IllegalArgumentException("Onnx model has multiple outputs (" + + Joiner.on(", ").join(model.outputs().keySet()) + + "), one must be specified " + + "as a second argument to onnx()"); + return model.outputs().get(model.outputs().keySet().stream().findFirst().get()); + } + else { + String output = model.outputs().get(outputName.get()); + if (output == null) { + if (model.skippedOutputs().containsKey(outputName.get())) + throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " + + model.skippedOutputs().get(outputName.get())); + else + throw new IllegalArgumentException("Model does not have the specified output '" + + outputName.get() + "'"); + } + return output; + } + } + + private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { + store.writeSmallConstant(constantName, constantValue); + profile.addConstant(constantName, asValue(constantValue)); + } + + private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, + Set constantsReplacedByMacros, + String constantName, Tensor constantValue) { + RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName); + if (macroOverridingConstant != null) { + TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles)); + if ( ! macroType.equals(constantValue.type())) + throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " + + "The required type of this is " + constantValue.type() + + ", but the macro returns " + macroType); + constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later + } + else { + Path constantPath = store.writeLargeConstant(constantName, constantValue); + if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) { + profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), + constantPath.toString())); + } + } + } + + private void transformGeneratedMacro(ModelStore store, RankProfile profile, + Set constantsReplacedByMacros, + String macroName, RankingExpression expression) { + + expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); + store.writeMacro(macroName, expression); + } + + private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) { + if (profile.getMacros().containsKey(macroName)) { + throw new IllegalArgumentException("Generated Onnx macro '" + macroName + "' already exists."); + } + profile.addMacro(macroName, false); // todo: inline if only used once + RankProfile.Macro macro = profile.getMacros().get(macroName); + macro.setRankingExpression(expression); + macro.setTextualExpression(expression.getRoot().toString()); + } + + private String skippedOutputsDescription(OnnxModel model) { + if (model.skippedOutputs().isEmpty()) return ""; + StringBuilder b = new StringBuilder(": "); + model.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v)); + return b.toString(); + } + + /** + * Verify that the macros referred in the given expression exists in the given rank profile, + * and return tensors of the types specified in requiredMacros. 
+ */ + private void verifyRequiredMacros(RankingExpression expression, OnnxModel model, + RankProfile profile, QueryProfileRegistry queryProfiles) { + Set macroNames = new HashSet<>(); + addMacroNamesIn(expression.getRoot(), macroNames, model); + for (String macroName : macroNames) { + TensorType requiredType = model.requiredMacros().get(macroName); + if (requiredType == null) continue; // Not a required macro + + RankProfile.Macro macro = profile.getMacros().get(macroName); + if (macro == null) + throw new IllegalArgumentException("Model refers Placeholder '" + macroName + + "' of type " + requiredType + " but this macro is not present in " + + profile); + // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second + // phase and summary features), as it may only resolve correctly given those bindings + // Or, probably better, annotate the macros with type constraints here and verify during general + // type verification + TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles)); + if ( actualType == null) + throw new IllegalArgumentException("Model refers input '" + macroName + + "' of type " + requiredType + + " which must be produced by a macro in the rank profile, but " + + "this macro references a feature which is not declared"); + if ( ! actualType.isAssignableTo(requiredType)) + throw new IllegalArgumentException("Model refers input '" + macroName + + "' of type " + requiredType + + " which must be produced by a macro in the rank profile, but " + + "this macro produces type " + actualType); + } + } + + /** + * Add the generated macros to the rank profile + */ + private void addGeneratedMacros(OnnxModel model, RankProfile profile) { + model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v)); + } + + /** + * Check if batch dimensions of inputs can be reduced out. If the input + * macro specifies that a single exemplar should be evaluated, we can + * reduce the batch dimension out. + */ + private void reduceBatchDimensions(RankingExpression expression, OnnxModel model, + RankProfile profile, QueryProfileRegistry queryProfiles) { + TypeContext typeContext = profile.typeContext(queryProfiles); + TensorType typeBeforeReducing = expression.getRoot().type(typeContext); + + // Check generated macros for inputs to reduce + Set macroNames = new HashSet<>(); + addMacroNamesIn(expression.getRoot(), macroNames, model); + for (String macroName : macroNames) { + if ( ! 
model.macros().containsKey(macroName)) { + continue; + } + RankProfile.Macro macro = profile.getMacros().get(macroName); + if (macro == null) { + throw new IllegalArgumentException("Model refers to generated macro '" + macroName + + "' but this macro is not present in " + profile); + } + RankingExpression macroExpression = macro.getRankingExpression(); + macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext)); + } + + // Check expression for inputs to reduce + ExpressionNode root = expression.getRoot(); + root = reduceBatchDimensionsAtInput(root, model, typeContext); + TensorType typeAfterReducing = root.type(typeContext); + root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing); + expression.setRoot(root); + } + + private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, OnnxModel model, + TypeContext typeContext) { + if (node instanceof TensorFunctionNode) { + TensorFunction tensorFunction = ((TensorFunctionNode) node).function(); + if (tensorFunction instanceof Rename) { + List children = ((TensorFunctionNode)node).children(); + if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode) children.get(0); + if (model.requiredMacros().containsKey(referenceNode.getName())) { + return reduceBatchDimensionExpression(tensorFunction, typeContext); + } + } + } + } + if (node instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode) node; + if (model.requiredMacros().containsKey(referenceNode.getName())) { + return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext); + } + } + if (node instanceof CompositeNode) { + List children = ((CompositeNode)node).children(); + List transformedChildren = new ArrayList<>(children.size()); + for (ExpressionNode child : children) { + transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext)); + } + return ((CompositeNode)node).setChildren(transformedChildren); + } + return node; + } + + private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext context) { + TensorFunction result = function; + TensorType type = function.type(context); + if (type.dimensions().size() > 1) { + List reduceDimensions = new ArrayList<>(); + for (TensorType.Dimension dimension : type.dimensions()) { + if (dimension.size().orElse(-1L) == 1) { + reduceDimensions.add(dimension.name()); + } + } + if (reduceDimensions.size() > 0) { + result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions); + } + } + return new TensorFunctionNode(result); + } + + /** + * If batch dimensions have been reduced away above, bring them back here + * for any following computation of the tensor. + * Todo: determine when this is not necessary!
+ */ + private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { + if (after.equals(before)) { + return node; + } + TensorType.Builder typeBuilder = new TensorType.Builder(); + for (TensorType.Dimension dimension : before.dimensions()) { + if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) { + typeBuilder.indexed(dimension.name(), 1); + } + } + TensorType expandDimensionsType = typeBuilder.build(); + if (expandDimensionsType.dimensions().size() > 0) { + ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0)); + Generate generatedFunction = new Generate(expandDimensionsType, + new GeneratorLambdaFunctionNode(expandDimensionsType, + generatedExpression) + .asLongListToDoubleOperator()); + Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply()); + return new TensorFunctionNode(expand); + } + return node; + } + + /** + * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions. + * This method does that for the given expression and returns the result. + */ + private RankingExpression replaceConstantsByMacros(RankingExpression expression, + Set constantsReplacedByMacros) { + if (constantsReplacedByMacros.isEmpty()) return expression; + return new RankingExpression(expression.getName(), + replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros)); + } + + private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set constantsReplacedByMacros) { + if (node instanceof ReferenceNode) { + Reference reference = ((ReferenceNode)node).reference(); + if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) { + String argument = reference.simpleArgument().get(); + if (constantsReplacedByMacros.contains(argument)) + return new ReferenceNode(argument); + } + } + if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above + CompositeNode composite = (CompositeNode)node; + return composite.setChildren(composite.children().stream() + .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros)) + .collect(Collectors.toList())); + } + return node; + } + + private void addMacroNamesIn(ExpressionNode node, Set names, OnnxModel model) { + if (node instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode)node; + if (referenceNode.getOutput() == null) { // macro references cannot specify outputs + names.add(referenceNode.getName()); + if (model.macros().containsKey(referenceNode.getName())) { + addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model); + } + } + } + else if (node instanceof CompositeNode) { + for (ExpressionNode child : ((CompositeNode)node).children()) + addMacroNamesIn(child, names, model); + } + } + + private Value asValue(Tensor tensor) { + if (tensor.type().rank() == 0) + return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors + else + return new TensorValue(tensor); + } + + /** + * Provides read/write access to the correct directories of the application package given by the feature arguments + */ + private static class ModelStore { + + private final ApplicationPackage application; + private final FeatureArguments arguments; + + public ModelStore(ApplicationPackage application, Arguments arguments) { + this.application = application; + this.arguments = new 
FeatureArguments(arguments); + } + + public FeatureArguments arguments() { return arguments; } + + public boolean hasStoredModel() { + try { + return application.getFile(arguments.expressionPath()).exists(); + } + catch (UnsupportedOperationException e) { + return false; + } + } + + /** + * Returns the directory which contains the source model to use for these arguments + */ + public File onnxModelDir() { + return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath())); + } + + /** + * Adds this expression to the application package, such that it can be read later. + */ + public void writeConverted(RankingExpression expression) { + application.getFile(arguments.expressionPath()) + .writeFile(new StringReader(expression.getRoot().toString())); + } + + /** Reads the previously stored ranking expression for these arguments */ + public RankingExpression readConverted() { + try { + return new RankingExpression(application.getFile(arguments.expressionPath()).createReader()); + } + catch (IOException e) { + throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e); + } + catch (ParseException e) { + throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); + } + } + + /** Adds this macro expression to the application package to it can be read later. */ + public void writeMacro(String name, RankingExpression expression) { + application.getFile(arguments.macrosPath()).appendFile(name + "\t" + + expression.getRoot().toString() + "\n"); + } + + /** Reads the previously stored macro expressions for these arguments */ + public List> readMacros() { + try { + ApplicationFile file = application.getFile(arguments.macrosPath()); + if (!file.exists()) return Collections.emptyList(); + + List> macros = new ArrayList<>(); + BufferedReader reader = new BufferedReader(file.createReader()); + String line; + while (null != (line = reader.readLine())) { + String[] parts = line.split("\t"); + String name = parts[0]; + try { + RankingExpression expression = new RankingExpression(parts[1]); + macros.add(new Pair<>(name, expression)); + } + catch (ParseException e) { + throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); + } + } + return macros; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + /** + * Reads the information about all the large (aka ranking) constants stored in the application package + * (the constant value itself is replicated with file distribution). + */ + public List readLargeConstants() { + try { + List constants = new ArrayList<>(); + for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) { + String[] parts = IOUtils.readAll(constantFile.createReader()).split(":"); + constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2])); + } + return constants; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + /** + * Adds this constant to the application package as a file, + * such that it can be distributed using file distribution. 
+ * + * @return the path to the stored constant, relative to the application package root + */ + public Path writeLargeConstant(String name, Tensor constant) { + Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants"); + + // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file: + Path constantPath = constantsPath.append(name + ".tbf"); + + // Remember the constant in a file we replicate in ZooKeeper + application.getFile(arguments.largeConstantsPath().append(name + ".constant")) + .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath))); + + // Write content explicitly as a file on the file system as this is distributed using file distribution + createIfNeeded(constantsPath); + IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant)); + return correct(constantPath); + } + + private List> readSmallConstants() { + try { + ApplicationFile file = application.getFile(arguments.smallConstantsPath()); + if (!file.exists()) return Collections.emptyList(); + + List> constants = new ArrayList<>(); + BufferedReader reader = new BufferedReader(file.createReader()); + String line; + while (null != (line = reader.readLine())) { + String[] parts = line.split("\t"); + String name = parts[0]; + TensorType type = TensorType.fromSpec(parts[1]); + Tensor tensor = Tensor.from(type, parts[2]); + constants.add(new Pair<>(name, tensor)); + } + return constants; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + /** + * Append this constant to the single file used for small constants distributed as config + */ + public void writeSmallConstant(String name, Tensor constant) { + // Secret file format for remembering constants: + application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" + + constant.type().toString() + "\t" + + constant.toString() + "\n"); + } + + /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */ + private Path correct(Path path) { + if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed) + && ! path.elements().contains(FilesApplicationPackage.preprocessed)) { + return Path.fromString(FilesApplicationPackage.preprocessed).append(path); + } + else { + return path; + } + } + + private void createIfNeeded(Path path) { + File dir = application.getFileReference(path); + if ( ! 
dir.exists()) { + if (!dir.mkdirs()) + throw new IllegalStateException("Could not create " + dir); + } + } + + } + + /** Encapsulates the 1 or 2 arguments to an onnx feature */ + private static class FeatureArguments { + + private final Path modelPath; + + /** Optional arguments */ + private final Optional output; + + public FeatureArguments(Arguments arguments) { + if (arguments.isEmpty()) + throw new IllegalArgumentException("An onnx node must take an argument pointing to " + + "the onnx model file under [application]/models"); + if (arguments.expressions().size() > 2) + throw new IllegalArgumentException("An onnx feature can have at most 2 arguments"); + + modelPath = Path.fromString(asString(arguments.expressions().get(0))); + output = optionalArgument(1, arguments); + } + + /** Returns modelPath with slashes and dots replaced by underscores */ + public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); } + + /** Returns relative path to this model below the "models/" dir in the application package */ + public Path modelPath() { return modelPath; } + public Optional output() { return output; } + + /** Path to the small constants file */ + public Path smallConstantsPath() { + return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt"); + } + + /** Path to the large (ranking) constants directory */ + public Path largeConstantsPath() { + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants"); + } + + /** Path to the macros file */ + public Path macrosPath() { + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt"); + } + + public Path expressionPath() { + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR + .append(modelPath).append("expressions").append(expressionFileName()); + } + + private String expressionFileName() { + StringBuilder fileName = new StringBuilder(); + output.ifPresent(s -> fileName.append(s).append(".")); + if (fileName.length() == 0) // single output + fileName.append("single."); + fileName.append("expression"); + return fileName.toString(); + } + + private Optional optionalArgument(int argumentIndex, Arguments arguments) { + if (argumentIndex >= arguments.expressions().size()) + return Optional.empty(); + return Optional.of(asString(arguments.expressions().get(argumentIndex))); + } + + private String asString(ExpressionNode node) { + if ( ! (node instanceof ConstantNode)) + throw new IllegalArgumentException("Expected a constant string as onnx argument, but got '" + node + "'"); + return stripQuotes(((ConstantNode)node).sourceString()); + } + + private String stripQuotes(String s) { + if ( ! isQuoteSign(s.codePointAt(0))) return s; + if ( !
isQuoteSign(s.codePointAt(s.length() - 1 ))) + throw new IllegalArgumentException("onnx argument [" + s + "] is missing endquote"); + return s.substring(1, s.length()-1); + } + + private boolean isQuoteSign(int c) { + return c == '\'' || c == '"'; + } + + } + +} diff --git a/config-model/src/test/integration/onnx/models/mnist_softmax.onnx b/config-model/src/test/integration/onnx/models/mnist_softmax.onnx new file mode 100644 index 00000000000..a86019bf53a Binary files /dev/null and b/config-model/src/test/integration/onnx/models/mnist_softmax.onnx differ diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java new file mode 100644 index 00000000000..5f2f9e9ffaa --- /dev/null +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -0,0 +1,333 @@ +package com.yahoo.searchdefinition.processing; + +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.io.GrowableByteBuffer; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; +import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.serialization.TypedBinaryFormat; +import com.yahoo.yolean.Exceptions; +import org.junit.After; +import org.junit.Test; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Optional; + +import com.yahoo.searchdefinition.processing.RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage; + +import static junit.framework.TestCase.assertTrue; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; + +public class RankingExpressionWithOnnxTestCase { + + private final Path applicationDir = Path.fromString("src/test/integration/onnx/"); + private final static String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_onnx_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_onnx_Variable_1), f(a,b)(a + b))"; + + @After + public void removeGeneratedConstantTensorFiles() { + IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + } + + @Test + public void testOnnxReference() throws ParseException { + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "onnx('mnist_softmax.onnx')"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L)); + assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L)); + } + + @Test + public void testOnnxReferenceWithConstantFeature() { + RankProfileSearchFixture search = fixtureWith("constant(mytensor)", + "onnx('mnist_softmax.onnx')", + "constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }", + null); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L)); + assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L)); + } + + @Test + public void testOnnxReferenceWithQueryFeature() { + String queryProfile = ""; + String queryProfileType = "" + + " " + + ""; + StoringApplicationPackage application = new 
StoringApplicationPackage(applicationDir, + queryProfile, + queryProfileType); + RankProfileSearchFixture search = fixtureWith("query(mytensor)", + "onnx('mnist_softmax.onnx')", + null, + null, + "Placeholder", + application); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L)); + assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L)); + } + + @Test + public void testOnnxReferenceWithDocumentFeature() { + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); + RankProfileSearchFixture search = fixtureWith("attribute(mytensor)", + "onnx('mnist_softmax.onnx')", + null, + "field mytensor type tensor(d0[],d1[784]) { indexing: attribute }", + "Placeholder", + application); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L)); + assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L)); + } + + + @Test + public void testOnnxReferenceWithFeatureCombination() { + String queryProfile = ""; + String queryProfileType = "" + + " " + + ""; + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, + queryProfile, + queryProfileType); + RankProfileSearchFixture search = fixtureWith("sum(query(mytensor) * attribute(mytensor) * constant(mytensor),d2)", + "onnx('mnist_softmax.onnx')", + "constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }", + "field mytensor type tensor(d0[],d1[784]) { indexing: attribute }", + "Placeholder", + application); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L)); + assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L)); + } + + + @Test + public void testNestedOnnxReference() { + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "5 + sum(onnx('mnist_softmax.onnx'))"); + search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); + assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L)); + assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L)); + } + + @Test + public void testOnnxReferenceSpecifyingOutput() { + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "onnx('mnist_softmax.onnx', 'add')"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + } + + @Test + public void testOnnxReferenceMissingMacro() throws ParseException { + try { + RankProfileSearchFixture search = new RankProfileSearchFixture( + new StoringApplicationPackage(applicationDir), + new QueryProfileRegistry(), + " rank-profile my_profile {\n" + + " first-phase {\n" + + " expression: onnx('mnist_softmax.onnx')" + + " }\n" + + " }"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + fail("Expecting exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + + "onnx('mnist_softmax.onnx'): " + + "Model refers Placeholder 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " + + "not present in rank profile 'my_profile'", + Exceptions.toMessageString(expected)); + } + } + + + @Test + public void testOnnxReferenceWithWrongMacroType() { + try { + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d5[10])(0.0)", + 
"onnx('mnist_softmax.onnx')"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + fail("Expecting exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + + "onnx('mnist_softmax.onnx'): " + + "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) which must be produced " + + "by a macro in the rank profile, but this macro produces type tensor(d0[2],d5[10])", + Exceptions.toMessageString(expected)); + } + } + + @Test + public void testOnnxReferenceSpecifyingNonExistingOutput() { + try { + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "onnx('mnist_softmax.onnx', 'y')"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + fail("Expecting exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + + "onnx('mnist_softmax.onnx','y'): " + + "Model does not have the specified output 'y'", + Exceptions.toMessageString(expected)); + } + } + + @Test + public void testImportingFromStoredExpressions() throws IOException { + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "onnx('mnist_softmax.onnx')"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + + assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L)); + assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L)); + + // At this point the expression is stored - copy application to another location which do not have a models dir + Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); + try { + storedApplicationDirectory.toFile().mkdirs(); + IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); + RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "onnx('mnist_softmax.onnx')", + null, + null, + "Placeholder", + storedApplication); + searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile"); + // Verify that the constants exists, but don't verify the content as we are not + // simulating file distribution in this test + assertLargeConstant("mnist_softmax_onnx_Variable_1", searchFromStored, Optional.empty()); + assertLargeConstant("mnist_softmax_onnx_Variable", searchFromStored, Optional.empty()); + } + finally { + IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); + } + } + + @Test + public void testImportingFromStoredExpressionsWithMacroOverridingConstant() throws IOException { + String rankProfile = + " rank-profile my_profile {\n" + + " macro Placeholder() {\n" + + " expression: tensor(d0[2],d1[784])(0.0)\n" + + " }\n" + + " macro mnist_softmax_onnx_Variable() {\n" + + " expression: tensor(d1[10],d2[784])(0.0)\n" + + " }\n" + + " first-phase {\n" + + " expression: onnx('mnist_softmax.onnx')" + + " }\n" + + " }"; + + + String vespaExpressionWithoutConstant = + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), mnist_softmax_onnx_Variable, f(a,b)(a * b)), sum, d2), constant(mnist_softmax_onnx_Variable_1), f(a,b)(a + b))"; + RankProfileSearchFixture search = fixtureWith(rankProfile, new StoringApplicationPackage(applicationDir)); + 
search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); + + assertNull("Constant overridden by macro is not added", + search.search().getRankingConstants().get("mnist_softmax_onnx_Variable")); + assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L)); + + // At this point the expression is stored - copy application to another location which do not have a models dir + Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); + try { + storedApplicationDirectory.toFile().mkdirs(); + IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); + RankProfileSearchFixture searchFromStored = fixtureWith(rankProfile, storedApplication); + searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); + assertNull("Constant overridden by macro is not added", + searchFromStored.search().getRankingConstants().get("mnist_softmax_onnx_Variable")); + assertLargeConstant("mnist_softmax_onnx_Variable_1", searchFromStored, Optional.of(10L)); + } finally { + IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); + } + } + + /** + * Verifies that the constant with the given name exists, and - only if an expected size is given - + * that the content of the constant is available and has the expected size. + */ + private void assertLargeConstant(String name, RankProfileSearchFixture search, Optional expectedSize) { + try { + Path constantApplicationPackagePath = Path.fromString("models.generated/mnist_softmax.onnx/constants").append(name + ".tbf"); + RankingConstant rankingConstant = search.search().getRankingConstants().get(name); + assertEquals(name, rankingConstant.getName()); + assertTrue(rankingConstant.getFileName().endsWith(constantApplicationPackagePath.toString())); + + if (expectedSize.isPresent()) { + Path constantPath = applicationDir.append(constantApplicationPackagePath); + assertTrue("Constant file '" + constantPath + "' has been written", + constantPath.toFile().exists()); + Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(), + GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantPath.toFile()))); + assertEquals(expectedSize.get().longValue(), deserializedConstant.size()); + } + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) { + return fixtureWith(placeholderExpression, firstPhaseExpression, null, null, "Placeholder", + new StoringApplicationPackage(applicationDir)); + } + + private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression, + String constant, String field) { + return fixtureWith(placeholderExpression, firstPhaseExpression, constant, field, "Placeholder", + new StoringApplicationPackage(applicationDir)); + } + + private RankProfileSearchFixture fixtureWith(String rankProfile, StoringApplicationPackage application) { + try { + return new RankProfileSearchFixture(application, application.getQueryProfiles(), + rankProfile, null, null); + } + catch (ParseException e) { + throw new IllegalArgumentException(e); + } + } + + private RankProfileSearchFixture fixtureWith(String macroExpression, + String firstPhaseExpression, + String constant, + String field, + String 
macroName, + StoringApplicationPackage application) { + try { + return new RankProfileSearchFixture( + application, + application.getQueryProfiles(), + " rank-profile my_profile {\n" + + " macro " + macroName + "() {\n" + + " expression: " + macroExpression + + " }\n" + + " first-phase {\n" + + " expression: " + firstPhaseExpression + + " }\n" + + " }", + constant, + field); + } + catch (ParseException e) { + throw new IllegalArgumentException(e); + } + } + +} diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 55754605843..d288a396732 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -439,7 +439,7 @@ public class RankingExpressionWithTensorFlowTestCase { } } - private static class StoringApplicationPackage extends MockApplicationPackage { + static class StoringApplicationPackage extends MockApplicationPackage { private final File root; @@ -465,7 +465,7 @@ public class RankingExpressionWithTensorFlowTestCase { } - private static class StoringApplicationPackageFile extends ApplicationFile { + public static class StoringApplicationPackageFile extends ApplicationFile { /** The path to the application package root */ private final Path root; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java index 047d1b187f5..295f9228316 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java @@ -13,8 +13,10 @@ import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.yolean.Exceptions; import onnx.Onnx; +import java.io.File; import java.io.FileInputStream; import java.io.IOException; import java.util.Collection; @@ -22,6 +24,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.logging.Logger; import java.util.stream.Collectors; /** @@ -31,48 +34,64 @@ import java.util.stream.Collectors; */ public class OnnxImporter { - public OnnxModel importModel(String modelPath, String outputNode) { + private static final Logger log = Logger.getLogger(OnnxImporter.class.getName()); + + public OnnxModel importModel(String modelName, File modelDir) { + return importModel(modelName, modelDir.toString()); + } + + public OnnxModel importModel(String modelName, String modelPath) { try (FileInputStream inputStream = new FileInputStream(modelPath)) { Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); - return importModel(model, outputNode); + return importModel(modelName, model); } catch (IOException e) { throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e); } } - public OnnxModel importModel(Onnx.ModelProto model, String outputNode) { - return importGraph(model.getGraph(), outputNode); + public OnnxModel importModel(String modelName, Onnx.ModelProto model) { + return 
importGraph(modelName, model.getGraph()); } - private static OnnxModel importGraph(Onnx.GraphProto graph, String outputNode) { - OnnxModel model = new OnnxModel(outputNode); + private static OnnxModel importGraph(String modelName, Onnx.GraphProto graph) { + OnnxModel model = new OnnxModel(modelName); OperationIndex index = new OperationIndex(); - OnnxOperation output = importNode(outputNode, graph, index); - output.type().orElseThrow(() -> new IllegalArgumentException("Output of '" + outputNode + "' has no type.")) - .verifyType(getOutputNode(outputNode, graph).getType()); + importNodes(graph, model, index); + verifyOutputTypes(graph, model, index); + findDimensionNames(model, index); + importExpressions(model, index); - findDimensionNames(output); - importExpressions(output, model); + reportWarnings(model, index); return model; } - private static OnnxOperation importNode(String nodeName, Onnx.GraphProto graph, OperationIndex index) { - if (index.alreadyImported(nodeName)) { - return index.get(nodeName); + private static void importNodes(Onnx.GraphProto graph, OnnxModel model, OperationIndex index) { + for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) { + importNode(valueInfo.getName(), graph, model, index); + } + } + + private static OnnxOperation importNode(String name, Onnx.GraphProto graph, OnnxModel model, OperationIndex index) { + if (index.alreadyImported(name)) { + return index.get(name); } OnnxOperation operation; - if (isArgumentTensor(nodeName, graph)) { - operation = new Argument(getArgumentTensor(nodeName, graph)); - } else if (isConstantTensor(nodeName, graph)) { - operation = new Constant(getConstantTensor(nodeName, graph)); + if (isArgumentTensor(name, graph)) { + operation = new Argument(getArgumentTensor(name, graph)); + model.input(OnnxOperation.namePartOf(name), operation.vespaName()); + } else if (isConstantTensor(name, graph)) { + operation = new Constant(model.name(), getConstantTensor(name, graph)); } else { - Onnx.NodeProto node = getNodeFromGraph(nodeName, graph); - List inputs = importNodeInputs(node, graph, index); + Onnx.NodeProto node = getNodeFromGraph(name, graph); + List inputs = importNodeInputs(node, graph, model, index); operation = OperationMapper.get(node, inputs); + if (isOutputNode(name, graph)) { + model.output(OnnxOperation.namePartOf(name), operation.vespaName()); + } } - index.put(nodeName, operation); + index.put(operation.vespaName(), operation); return operation; } @@ -113,8 +132,11 @@ public class OnnxImporter { private static Onnx.ValueInfoProto getOutputNode(String name, Onnx.GraphProto graph) { for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) { - Onnx.NodeProto node = getNodeFromGraph(valueInfo.getName(), graph); - if (node.getName().equals(name)) { + if (valueInfo.getName().equals(name)) { + return valueInfo; + } + String nodeName = OnnxOperation.namePartOf(valueInfo.getName()); + if (nodeName.equals(name)) { return valueInfo; } } @@ -123,18 +145,34 @@ public class OnnxImporter { private static List importNodeInputs(Onnx.NodeProto node, Onnx.GraphProto graph, + OnnxModel model, OperationIndex index) { return node.getInputList().stream() - .map(nodeName -> importNode(nodeName, graph, index)) + .map(nodeName -> importNode(nodeName, graph, model, index)) .collect(Collectors.toList()); } + private static void verifyOutputTypes(Onnx.GraphProto graph, OnnxModel model, OperationIndex index) { + for (String outputName : model.outputs().values()) { + OnnxOperation operation = index.get(outputName); + Onnx.ValueInfoProto 
onnxNode = getOutputNode(outputName, graph); + operation.type().orElseThrow( + () -> new IllegalArgumentException("Output of '" + outputName + "' has no type.")) + .verifyType(onnxNode.getType()); + } + } + + /** Find dimension names to avoid excessive renaming while evaluating the model. */ - private static void findDimensionNames(OnnxOperation output) { + private static void findDimensionNames(OnnxModel model, OperationIndex index) { DimensionRenamer renamer = new DimensionRenamer(); - addDimensionNameConstraints(output, renamer); + for (String output : model.outputs().values()) { + addDimensionNameConstraints(index.get(output), renamer); + } renamer.solve(); - renameDimensions(output, renamer); + for (String output : model.outputs().values()) { + renameDimensions(index.get(output), renamer); + } } private static void addDimensionNameConstraints(OnnxOperation operation, DimensionRenamer renamer) { @@ -151,10 +189,17 @@ public class OnnxImporter { } } - private static void importExpressions(OnnxOperation output, OnnxModel model) { - Optional function = importExpression(output, model); - if (!function.isPresent()) { - throw new IllegalArgumentException("No valid output function could be found."); + private static void importExpressions(OnnxModel model, OperationIndex index) { + for (String outputName : model.outputs().values()) { + try { + Optional function = importExpression(index.get(outputName), model); + if (!function.isPresent()) { + model.skippedOutput(outputName, "No valid output function could be found."); + } + } + catch (IllegalArgumentException e) { + model.skippedOutput(outputName, Exceptions.toMessageString(e)); + } } } @@ -167,7 +212,7 @@ public class OnnxImporter { } importInputExpressions(operation, model); importRankingExpression(operation, model); - importInputExpression(operation, model); + importArgumentExpression(operation, model); return operation.function(); } @@ -204,7 +249,7 @@ public class OnnxImporter { if (!model.expressions().containsKey(name)) { TensorFunction function = operation.function().get(); - if (name.equals(model.output())) { + if (model.outputs().containsKey(name)) { OrderedTensorType operationType = operation.type().get(); OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType); if ( ! operationType.equals(standardNamingType)) { @@ -228,7 +273,7 @@ public class OnnxImporter { } } - private static void importInputExpression(OnnxOperation operation, OnnxModel model) { + private static void importArgumentExpression(OnnxOperation operation, OnnxModel model) { if (operation.isInput()) { // All inputs must have dimensions with standard naming convention: d0, d1, ... 
OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get()); @@ -237,6 +282,20 @@ public class OnnxImporter { } } + private static void reportWarnings(OnnxModel model, OperationIndex index) { + for (String output : model.outputs().values()) { + reportWarnings(model, index.get(output)); + } + } + + private static void reportWarnings(OnnxModel model, OnnxOperation operation) { + for (String warning : operation.warnings()) { + model.importWarning(warning); + } + for (OnnxOperation input : operation.inputs()) { + reportWarnings(model, input); + } + } private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) { boolean hasPortNumber = nodeName.contains(":"); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java index df108fcbbe7..027c1d7ff9d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java @@ -14,29 +14,73 @@ import java.util.regex.Pattern; /** * The result of importing an ONNX model into Vespa. * + * @author bratseth * @author lesters */ public class OnnxModel { - public OnnxModel(String outputNode) { - this.output = outputNode; + private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*"); + + private final String name; + + public OnnxModel(String name) { + if ( ! nameRegexp.matcher(name).matches()) + throw new IllegalArgumentException("An ONNX model name can only contain [A-Za-z0-9_], but is '" + + name + "'"); + this.name = name; } - private final String output; + /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */ + public String name() { return name; } + + private final Map inputs = new HashMap<>(); + private final Map outputs = new HashMap<>(); + private final Map skippedOutputs = new HashMap<>(); + private final List importWarnings = new ArrayList<>(); + private final Map arguments = new HashMap<>(); private final Map smallConstants = new HashMap<>(); private final Map largeConstants = new HashMap<>(); private final Map expressions = new HashMap<>(); + private final Map macros = new HashMap<>(); private final Map requiredMacros = new HashMap<>(); + void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); } + void output(String name, String expressionName) { outputs.put(name, expressionName); } + void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); } + void importWarning(String warning) { importWarnings.add(warning); } void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); } void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); } void expression(String name, RankingExpression expression) { expressions.put(name, expression); } + void macro(String name, RankingExpression expression) { macros.put(name, expression); } void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); } - /** Return the name of the output node for this model */ - public String output() { return output; } + /** + * Returns an immutable map of the inputs (evaluation context) of this.
+     * to argument (Placeholder) name in the owner of this
+     */
+    public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); }
+
+    /** Returns arguments().get(inputs.get(name)), e.g. the type of the argument this input references */
+    public TensorType inputArgument(String inputName) { return arguments().get(inputs.get(inputName)); }
+
+    /** Returns an immutable map of the outputs of this, from output name to expression name */
+    public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); }
+
+    /**
+     * Returns an immutable map of the outputs of this which could not be imported,
+     * with a string detailing the reason for each
+     */
+    public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); }
+
+    /**
+     * Returns an immutable list of possibly non-fatal warnings encountered during import.
+     */
+    public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); }
+
+    /** Returns expressions().get(outputs.get(outputName)), e.g. the expression this output references */
+    public RankingExpression outputExpression(String outputName) { return expressions().get(outputs.get(outputName)); }
 
     /** Returns an immutable map of the arguments (inputs) of this */
     public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
 
@@ -57,6 +101,9 @@ public class OnnxModel {
      */
     public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
 
+    /** Returns an immutable map of macros that are part of this model */
+    public Map<String, RankingExpression> macros() { return Collections.unmodifiableMap(macros); }
+
     /** Returns an immutable map of the macros that must be provided by the environment running this model */
     public Map<String, TensorType> requiredMacros() { return Collections.unmodifiableMap(requiredMacros); }
 
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java
index ab650bf8d77..b5494477227 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java
@@ -15,18 +15,19 @@ import java.util.Optional;
 
 public class Constant extends OnnxOperation {
 
+    final String modelName;
     final Onnx.TensorProto tensorProto;
 
-    public Constant(Onnx.TensorProto tensorProto) {
+    public Constant(String modelName, Onnx.TensorProto tensorProto) {
         super(null, Collections.emptyList());
+        this.modelName = modelName;
         this.tensorProto = tensorProto;
     }
 
     /** todo: Constant names are prefixed by "modelName_" to avoid name conflicts between models */
     @Override
     public String vespaName() {
-//        return modelName() + "_" + super.vespaName();
-        return vespaName(tensorProto.getName());
+        return modelName + "_" + vespaName(tensorProto.getName());
     }
 
     @Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java
index 2c8003f5951..3c9f01c5e1c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java
@@ -92,7 +92,7 @@ public abstract class OnnxOperation {
     /** Retrieve the valid Vespa name of this node */
     public String vespaName() { return vespaName(node.getName()); }
 
-    public String vespaName(String name) { return name != null ? name.replace('/', '_').replace(':','_') : null; }
+    public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; }
 
     /** Retrieve the list of warnings produced during its lifetime */
     public List<String> warnings() { return Collections.unmodifiableList(importWarnings); }
 
@@ -116,4 +116,22 @@ public abstract class OnnxOperation {
         return verifyInputs(expected, OnnxOperation::function);
     }
 
+    /**
+     * A method signature input and output has the form name:index.
+     * This returns the name part without the index.
+     */
+    public static String namePartOf(String name) {
+        name = name.startsWith("^") ? name.substring(1) : name;
+        return name.split(":")[0];
+    }
+
+    /**
+     * This returns the output index part. Indexes are used for nodes with
+     * multiple outputs.
+     */
+    public static int indexPartOf(String name) {
+        int i = name.indexOf(":");
+        return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1));
+    }
+
 }
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java
index e118c2b885a..4b68cd40a08 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java
@@ -24,18 +24,18 @@ public class OnnxMnistSoftmaxImportTestCase {
 
     @Test
     public void testMnistSoftmaxImport() throws IOException {
-        OnnxModel model = new OnnxImporter().importModel("src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx", "add");
+        OnnxModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx");
 
         // Check constants
         assertEquals(2, model.largeConstants().size());
 
-        Tensor constant0 = model.largeConstants().get("Variable_0");
+        Tensor constant0 = model.largeConstants().get("test_Variable");
         assertNotNull(constant0);
         assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(),
                      constant0.type());
         assertEquals(7840, constant0.size());
 
-        Tensor constant1 = model.largeConstants().get("Variable_1_0");
+        Tensor constant1 = model.largeConstants().get("test_Variable_1");
         assertNotNull(constant1);
         assertEquals(new TensorType.Builder().indexed("d1", 10).build(),
                      constant1.type());
 
@@ -43,15 +43,15 @@ public class OnnxMnistSoftmaxImportTestCase {
         // Check required macros (inputs)
         assertEquals(1, model.requiredMacros().size());
-        assertTrue(model.requiredMacros().containsKey("Placeholder_0"));
+        assertTrue(model.requiredMacros().containsKey("Placeholder"));
         assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
-                     model.requiredMacros().get("Placeholder_0"));
+                     model.requiredMacros().get("Placeholder"));
 
         // Check outputs
-        RankingExpression output = model.expressions().get("add");
+        RankingExpression output = model.outputExpression("add");
         assertNotNull(output);
         assertEquals("add", output.getName());
-        assertEquals("join(reduce(join(rename(Placeholder_0, (d0, d1), (d0, d2)), constant(Variable_0), f(a,b)(a * b)), sum, d2), constant(Variable_1_0), f(a,b)(a + b))",
+        assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))",
                     output.getRoot().toString());
     }
 
@@ -62,7 +62,7 @@ public class OnnxMnistSoftmaxImportTestCase {
 
         Tensor argument = placeholderArgument();
         Tensor tensorFlowResult = evaluateTensorFlowModel(tfModelPath, argument, "Placeholder", "add");
-        Tensor onnxResult = evaluateOnnxModel(onnxModelPath, argument, "Placeholder_0", "add");
+        Tensor onnxResult = evaluateOnnxModel(onnxModelPath, argument, "Placeholder", "add");
 
         assertEquals("Operation 'add' produces equal results", tensorFlowResult, onnxResult);
     }
 
@@ -74,7 +74,7 @@ public class OnnxMnistSoftmaxImportTestCase {
     }
 
     private Tensor evaluateOnnxModel(String path, Tensor argument, String input, String output) {
-        OnnxModel model = new OnnxImporter().importModel(path, output);
+        OnnxModel model = new OnnxImporter().importModel("test", path);
         return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input);
     }
-- 
cgit v1.2.3
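
Reviewer note (not part of the patch): a minimal usage sketch of the new multi-output import API, based only on the signatures introduced above. The model name "mnist" and the file path are placeholder values, and the throws clause mirrors the test case; this is an illustration, not an authoritative example from the codebase.

import java.io.IOException;

import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxImporter;
import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxModel;

public class OnnxImportExample {

    public static void main(String[] args) throws IOException {
        // Import under a model name; constant names are prefixed with "<name>_" to avoid clashes between models.
        OnnxModel model = new OnnxImporter().importModel("mnist", "models/mnist_softmax.onnx");

        // Every graph output now gets its own ranking expression, keyed by output name.
        for (String output : model.outputs().keySet()) {
            RankingExpression expression = model.outputExpression(output);
            System.out.println(output + " -> " + expression.getRoot());
        }

        // Outputs that could not be converted are reported instead of failing the whole import.
        model.skippedOutputs().forEach((output, reason) -> System.out.println("Skipped " + output + ": " + reason));
        model.importWarnings().forEach(warning -> System.out.println("Warning: " + warning));
    }
}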
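
Similarly, a small illustration (again not part of the patch) of the name/index helpers added to OnnxOperation, which split "name:index" port references; the expected values follow directly from the implementation shown above.

import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation;

public class NamePartExample {

    public static void main(String[] args) {
        System.out.println(OnnxOperation.namePartOf("Placeholder:1"));  // "Placeholder" - index stripped
        System.out.println(OnnxOperation.namePartOf("^Placeholder"));   // "Placeholder" - leading '^' stripped
        System.out.println(OnnxOperation.indexPartOf("Placeholder:1")); // 1
        System.out.println(OnnxOperation.indexPartOf("Placeholder"));   // 0 when no index is present
    }
}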