From 88d06ec474f727d41963b6aa65c2382ccc01c3f5 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Mon, 28 May 2018 11:33:17 +0200 Subject: Add ONNX pseudo ranking feature --- .../expressiontransforms/ExpressionTransforms.java | 1 + .../expressiontransforms/OnnxFeatureConverter.java | 693 +++++++++++++++++++++ .../integration/onnx/models/mnist_softmax.onnx | Bin 0 -> 31758 bytes .../RankingExpressionWithOnnxTestCase.java | 333 ++++++++++ .../RankingExpressionWithTensorFlowTestCase.java | 4 +- 5 files changed, 1029 insertions(+), 2 deletions(-) create mode 100644 config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java create mode 100644 config-model/src/test/integration/onnx/models/mnist_softmax.onnx create mode 100644 config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java (limited to 'config-model/src') diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java index 67d60b08ab0..6ca16c1559d 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java @@ -21,6 +21,7 @@ public class ExpressionTransforms { private final List transforms = ImmutableList.of(new TensorFlowFeatureConverter(), + new OnnxFeatureConverter(), new ConstantDereferencer(), new ConstantTensorTransformer(), new MacroInliner(), diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java new file mode 100644 index 00000000000..6d44b130996 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -0,0 +1,693 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.expressiontransforms; + +import com.google.common.base.Joiner; +import com.yahoo.collections.Pair; +import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.application.provider.FilesApplicationPackage; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; +import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.searchdefinition.FeatureNames; +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxImporter; +import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxModel; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.Join; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.tensor.serialization.TypedBinaryFormat; + +import java.io.BufferedReader; +import java.io.File; +import java.io.IOException; +import java.io.StringReader; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Replaces instances of the onnx(model-path, output) + * pseudofeature with the native Vespa ranking expression implementing + * the same computation. + * + * @author bratseth + * @author lesters + */ +public class OnnxFeatureConverter extends ExpressionTransformer { + + private final OnnxImporter onnxImporter = new OnnxImporter(); + + /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ + private final Map importedModels = new HashMap<>(); + + @Override + public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { + if (node instanceof ReferenceNode) + return transformFeature((ReferenceNode) node, context); + else if (node instanceof CompositeNode) + return super.transformChildren((CompositeNode) node, context); + else + return node; + } + + private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { + if ( ! feature.getName().equals("onnx")) return feature; + + try { + ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments()); + if ( ! store.hasStoredModel()) // not converted yet - access Onnx model files + return transformFromOnnxModel(store, context.rankProfile(), context.queryProfiles()); + else + return transformFromStoredModel(store, context.rankProfile()); + } + catch (IllegalArgumentException | UncheckedIOException e) { + throw new IllegalArgumentException("Could not use Onnx model from " + feature, e); + } + } + + private ExpressionNode transformFromOnnxModel(ModelStore store, + RankProfile profile, + QueryProfileRegistry queryProfiles) { + OnnxModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), + k -> onnxImporter.importModel(store.arguments().modelName(), + store.onnxModelDir())); + + // Add constants + Set constantsReplacedByMacros = new HashSet<>(); + model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); + model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles, + constantsReplacedByMacros, k, v)); + + // Find the specified expression + String output = chooseOutput(model, store.arguments().output()); + if (model.skippedOutputs().containsKey(output)) { + String message = "Could not import Onnx model output '" + output + "'"; + if (!model.skippedOutputs().get(output).isEmpty()) { + message += ": " + model.skippedOutputs().get(output); + } + if (!model.importWarnings().isEmpty()) { + message += ": " + String.join(", ", model.importWarnings()); + } + throw new IllegalArgumentException(message); + } + + RankingExpression expression = model.expressions().get(output); + expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); + verifyRequiredMacros(expression, model, profile, queryProfiles); + addGeneratedMacros(model, profile); + reduceBatchDimensions(expression, model, profile, queryProfiles); + + model.macros().forEach((k, v) -> transformGeneratedMacro(store, profile, constantsReplacedByMacros, k, v)); + + store.writeConverted(expression); + return expression.getRoot(); + } + + private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) { + for (Pair constant : store.readSmallConstants()) + profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); + + for (RankingConstant constant : store.readLargeConstants()) { + if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName())) + profile.getSearch().addRankingConstant(constant); + } + + for (Pair macro : store.readMacros()) { + addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond()); + } + + return store.readConverted().getRoot(); + } + + /** + * Returns the specified, existing output expression, or the only output expression if no output name is specified. + * Throws IllegalArgumentException in all other cases. + */ + private String chooseOutput(OnnxModel model, Optional outputName) { + if ( ! outputName.isPresent()) { + if (model.outputs().size() == 0) + throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(model)); + if (model.outputs().size() > 1) + throw new IllegalArgumentException("Onnx model has multiple outputs (" + + Joiner.on(", ").join(model.outputs().keySet()) + + "), one must be specified " + + "as a second argument to onnx()"); + return model.outputs().get(model.outputs().keySet().stream().findFirst().get()); + } + else { + String output = model.outputs().get(outputName.get()); + if (output == null) { + if (model.skippedOutputs().containsKey(outputName.get())) + throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " + + model.skippedOutputs().get(outputName.get())); + else + throw new IllegalArgumentException("Model does not have the specified output '" + + outputName.get() + "'"); + } + return output; + } + } + + private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { + store.writeSmallConstant(constantName, constantValue); + profile.addConstant(constantName, asValue(constantValue)); + } + + private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, + Set constantsReplacedByMacros, + String constantName, Tensor constantValue) { + RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName); + if (macroOverridingConstant != null) { + TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles)); + if ( ! macroType.equals(constantValue.type())) + throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " + + "The required type of this is " + constantValue.type() + + ", but the macro returns " + macroType); + constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later + } + else { + Path constantPath = store.writeLargeConstant(constantName, constantValue); + if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) { + profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), + constantPath.toString())); + } + } + } + + private void transformGeneratedMacro(ModelStore store, RankProfile profile, + Set constantsReplacedByMacros, + String macroName, RankingExpression expression) { + + expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); + store.writeMacro(macroName, expression); + } + + private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) { + if (profile.getMacros().containsKey(macroName)) { + throw new IllegalArgumentException("Generated Onnx macro '" + macroName + "' already exists."); + } + profile.addMacro(macroName, false); // todo: inline if only used once + RankProfile.Macro macro = profile.getMacros().get(macroName); + macro.setRankingExpression(expression); + macro.setTextualExpression(expression.getRoot().toString()); + } + + private String skippedOutputsDescription(OnnxModel model) { + if (model.skippedOutputs().isEmpty()) return ""; + StringBuilder b = new StringBuilder(": "); + model.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v)); + return b.toString(); + } + + /** + * Verify that the macros referred in the given expression exists in the given rank profile, + * and return tensors of the types specified in requiredMacros. + */ + private void verifyRequiredMacros(RankingExpression expression, OnnxModel model, + RankProfile profile, QueryProfileRegistry queryProfiles) { + Set macroNames = new HashSet<>(); + addMacroNamesIn(expression.getRoot(), macroNames, model); + for (String macroName : macroNames) { + TensorType requiredType = model.requiredMacros().get(macroName); + if (requiredType == null) continue; // Not a required macro + + RankProfile.Macro macro = profile.getMacros().get(macroName); + if (macro == null) + throw new IllegalArgumentException("Model refers Placeholder '" + macroName + + "' of type " + requiredType + " but this macro is not present in " + + profile); + // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second + // phase and summary features), as it may only resolve correctly given those bindings + // Or, probably better, annotate the macros with type constraints here and verify during general + // type verification + TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles)); + if ( actualType == null) + throw new IllegalArgumentException("Model refers input '" + macroName + + "' of type " + requiredType + + " which must be produced by a macro in the rank profile, but " + + "this macro references a feature which is not declared"); + if ( ! actualType.isAssignableTo(requiredType)) + throw new IllegalArgumentException("Model refers input '" + macroName + + "' of type " + requiredType + + " which must be produced by a macro in the rank profile, but " + + "this macro produces type " + actualType); + } + } + + /** + * Add the generated macros to the rank profile + */ + private void addGeneratedMacros(OnnxModel model, RankProfile profile) { + model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v)); + } + + /** + * Check if batch dimensions of inputs can be reduced out. If the input + * macro specifies that a single exemplar should be evaluated, we can + * reduce the batch dimension out. + */ + private void reduceBatchDimensions(RankingExpression expression, OnnxModel model, + RankProfile profile, QueryProfileRegistry queryProfiles) { + TypeContext typeContext = profile.typeContext(queryProfiles); + TensorType typeBeforeReducing = expression.getRoot().type(typeContext); + + // Check generated macros for inputs to reduce + Set macroNames = new HashSet<>(); + addMacroNamesIn(expression.getRoot(), macroNames, model); + for (String macroName : macroNames) { + if ( ! model.macros().containsKey(macroName)) { + continue; + } + RankProfile.Macro macro = profile.getMacros().get(macroName); + if (macro == null) { + throw new IllegalArgumentException("Model refers to generated macro '" + macroName + + "but this macro is not present in " + profile); + } + RankingExpression macroExpression = macro.getRankingExpression(); + macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext)); + } + + // Check expression for inputs to reduce + ExpressionNode root = expression.getRoot(); + root = reduceBatchDimensionsAtInput(root, model, typeContext); + TensorType typeAfterReducing = root.type(typeContext); + root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing); + expression.setRoot(root); + } + + private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, OnnxModel model, + TypeContext typeContext) { + if (node instanceof TensorFunctionNode) { + TensorFunction tensorFunction = ((TensorFunctionNode) node).function(); + if (tensorFunction instanceof Rename) { + List children = ((TensorFunctionNode)node).children(); + if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode) children.get(0); + if (model.requiredMacros().containsKey(referenceNode.getName())) { + return reduceBatchDimensionExpression(tensorFunction, typeContext); + } + } + } + } + if (node instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode) node; + if (model.requiredMacros().containsKey(referenceNode.getName())) { + return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext); + } + } + if (node instanceof CompositeNode) { + List children = ((CompositeNode)node).children(); + List transformedChildren = new ArrayList<>(children.size()); + for (ExpressionNode child : children) { + transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext)); + } + return ((CompositeNode)node).setChildren(transformedChildren); + } + return node; + } + + private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext context) { + TensorFunction result = function; + TensorType type = function.type(context); + if (type.dimensions().size() > 1) { + List reduceDimensions = new ArrayList<>(); + for (TensorType.Dimension dimension : type.dimensions()) { + if (dimension.size().orElse(-1L) == 1) { + reduceDimensions.add(dimension.name()); + } + } + if (reduceDimensions.size() > 0) { + result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions); + } + } + return new TensorFunctionNode(result); + } + + /** + * If batch dimensions have been reduced away above, bring them back here + * for any following computation of the tensor. + * Todo: determine when this is not necessary! + */ + private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { + if (after.equals(before)) { + return node; + } + TensorType.Builder typeBuilder = new TensorType.Builder(); + for (TensorType.Dimension dimension : before.dimensions()) { + if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) { + typeBuilder.indexed(dimension.name(), 1); + } + } + TensorType expandDimensionsType = typeBuilder.build(); + if (expandDimensionsType.dimensions().size() > 0) { + ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0)); + Generate generatedFunction = new Generate(expandDimensionsType, + new GeneratorLambdaFunctionNode(expandDimensionsType, + generatedExpression) + .asLongListToDoubleOperator()); + Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply()); + return new TensorFunctionNode(expand); + } + return node; + } + + /** + * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions. + * This method does that for the given expression and returns the result. + */ + private RankingExpression replaceConstantsByMacros(RankingExpression expression, + Set constantsReplacedByMacros) { + if (constantsReplacedByMacros.isEmpty()) return expression; + return new RankingExpression(expression.getName(), + replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros)); + } + + private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set constantsReplacedByMacros) { + if (node instanceof ReferenceNode) { + Reference reference = ((ReferenceNode)node).reference(); + if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) { + String argument = reference.simpleArgument().get(); + if (constantsReplacedByMacros.contains(argument)) + return new ReferenceNode(argument); + } + } + if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above + CompositeNode composite = (CompositeNode)node; + return composite.setChildren(composite.children().stream() + .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros)) + .collect(Collectors.toList())); + } + return node; + } + + private void addMacroNamesIn(ExpressionNode node, Set names, OnnxModel model) { + if (node instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode)node; + if (referenceNode.getOutput() == null) { // macro references cannot specify outputs + names.add(referenceNode.getName()); + if (model.macros().containsKey(referenceNode.getName())) { + addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model); + } + } + } + else if (node instanceof CompositeNode) { + for (ExpressionNode child : ((CompositeNode)node).children()) + addMacroNamesIn(child, names, model); + } + } + + private Value asValue(Tensor tensor) { + if (tensor.type().rank() == 0) + return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors + else + return new TensorValue(tensor); + } + + /** + * Provides read/write access to the correct directories of the application package given by the feature arguments + */ + private static class ModelStore { + + private final ApplicationPackage application; + private final FeatureArguments arguments; + + public ModelStore(ApplicationPackage application, Arguments arguments) { + this.application = application; + this.arguments = new FeatureArguments(arguments); + } + + public FeatureArguments arguments() { return arguments; } + + public boolean hasStoredModel() { + try { + return application.getFile(arguments.expressionPath()).exists(); + } + catch (UnsupportedOperationException e) { + return false; + } + } + + /** + * Returns the directory which contains the source model to use for these arguments + */ + public File onnxModelDir() { + return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath())); + } + + /** + * Adds this expression to the application package, such that it can be read later. + */ + public void writeConverted(RankingExpression expression) { + application.getFile(arguments.expressionPath()) + .writeFile(new StringReader(expression.getRoot().toString())); + } + + /** Reads the previously stored ranking expression for these arguments */ + public RankingExpression readConverted() { + try { + return new RankingExpression(application.getFile(arguments.expressionPath()).createReader()); + } + catch (IOException e) { + throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e); + } + catch (ParseException e) { + throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); + } + } + + /** Adds this macro expression to the application package to it can be read later. */ + public void writeMacro(String name, RankingExpression expression) { + application.getFile(arguments.macrosPath()).appendFile(name + "\t" + + expression.getRoot().toString() + "\n"); + } + + /** Reads the previously stored macro expressions for these arguments */ + public List> readMacros() { + try { + ApplicationFile file = application.getFile(arguments.macrosPath()); + if (!file.exists()) return Collections.emptyList(); + + List> macros = new ArrayList<>(); + BufferedReader reader = new BufferedReader(file.createReader()); + String line; + while (null != (line = reader.readLine())) { + String[] parts = line.split("\t"); + String name = parts[0]; + try { + RankingExpression expression = new RankingExpression(parts[1]); + macros.add(new Pair<>(name, expression)); + } + catch (ParseException e) { + throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); + } + } + return macros; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + /** + * Reads the information about all the large (aka ranking) constants stored in the application package + * (the constant value itself is replicated with file distribution). + */ + public List readLargeConstants() { + try { + List constants = new ArrayList<>(); + for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) { + String[] parts = IOUtils.readAll(constantFile.createReader()).split(":"); + constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2])); + } + return constants; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + /** + * Adds this constant to the application package as a file, + * such that it can be distributed using file distribution. + * + * @return the path to the stored constant, relative to the application package root + */ + public Path writeLargeConstant(String name, Tensor constant) { + Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants"); + + // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file: + Path constantPath = constantsPath.append(name + ".tbf"); + + // Remember the constant in a file we replicate in ZooKeeper + application.getFile(arguments.largeConstantsPath().append(name + ".constant")) + .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath))); + + // Write content explicitly as a file on the file system as this is distributed using file distribution + createIfNeeded(constantsPath); + IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant)); + return correct(constantPath); + } + + private List> readSmallConstants() { + try { + ApplicationFile file = application.getFile(arguments.smallConstantsPath()); + if (!file.exists()) return Collections.emptyList(); + + List> constants = new ArrayList<>(); + BufferedReader reader = new BufferedReader(file.createReader()); + String line; + while (null != (line = reader.readLine())) { + String[] parts = line.split("\t"); + String name = parts[0]; + TensorType type = TensorType.fromSpec(parts[1]); + Tensor tensor = Tensor.from(type, parts[2]); + constants.add(new Pair<>(name, tensor)); + } + return constants; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + /** + * Append this constant to the single file used for small constants distributed as config + */ + public void writeSmallConstant(String name, Tensor constant) { + // Secret file format for remembering constants: + application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" + + constant.type().toString() + "\t" + + constant.toString() + "\n"); + } + + /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */ + private Path correct(Path path) { + if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed) + && ! path.elements().contains(FilesApplicationPackage.preprocessed)) { + return Path.fromString(FilesApplicationPackage.preprocessed).append(path); + } + else { + return path; + } + } + + private void createIfNeeded(Path path) { + File dir = application.getFileReference(path); + if ( ! dir.exists()) { + if (!dir.mkdirs()) + throw new IllegalStateException("Could not create " + dir); + } + } + + } + + /** Encapsulates the 1, 2 or 3 arguments to a onnx feature */ + private static class FeatureArguments { + + private final Path modelPath; + + /** Optional arguments */ + private final Optional output; + + public FeatureArguments(Arguments arguments) { + if (arguments.isEmpty()) + throw new IllegalArgumentException("An onnx node must take an argument pointing to " + + "the onnx model directory under [application]/models"); + if (arguments.expressions().size() > 3) + throw new IllegalArgumentException("An onnx feature can have at most 2 arguments"); + + modelPath = Path.fromString(asString(arguments.expressions().get(0))); + output = optionalArgument(1, arguments); + } + + /** Returns modelPath with slashes replaced by underscores */ + public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); } + + /** Returns relative path to this model below the "models/" dir in the application package */ + public Path modelPath() { return modelPath; } + public Optional output() { return output; } + + /** Path to the small constants file */ + public Path smallConstantsPath() { + return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt"); + } + + /** Path to the large (ranking) constants directory */ + public Path largeConstantsPath() { + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants"); + } + + /** Path to the macros file */ + public Path macrosPath() { + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt"); + } + + public Path expressionPath() { + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR + .append(modelPath).append("expressions").append(expressionFileName()); + } + + private String expressionFileName() { + StringBuilder fileName = new StringBuilder(); + output.ifPresent(s -> fileName.append(s).append(".")); + if (fileName.length() == 0) // single signature and output + fileName.append("single."); + fileName.append("expression"); + return fileName.toString(); + } + + private Optional optionalArgument(int argumentIndex, Arguments arguments) { + if (argumentIndex >= arguments.expressions().size()) + return Optional.empty(); + return Optional.of(asString(arguments.expressions().get(argumentIndex))); + } + + private String asString(ExpressionNode node) { + if ( ! (node instanceof ConstantNode)) + throw new IllegalArgumentException("Expected a constant string as onnx argument, but got '" + node); + return stripQuotes(((ConstantNode)node).sourceString()); + } + + private String stripQuotes(String s) { + if ( ! isQuoteSign(s.codePointAt(0))) return s; + if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) + throw new IllegalArgumentException("onnx argument [" + s + "] is missing endquote"); + return s.substring(1, s.length()-1); + } + + private boolean isQuoteSign(int c) { + return c == '\'' || c == '"'; + } + + } + +} diff --git a/config-model/src/test/integration/onnx/models/mnist_softmax.onnx b/config-model/src/test/integration/onnx/models/mnist_softmax.onnx new file mode 100644 index 00000000000..a86019bf53a Binary files /dev/null and b/config-model/src/test/integration/onnx/models/mnist_softmax.onnx differ diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java new file mode 100644 index 00000000000..5f2f9e9ffaa --- /dev/null +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -0,0 +1,333 @@ +package com.yahoo.searchdefinition.processing; + +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.io.GrowableByteBuffer; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; +import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.serialization.TypedBinaryFormat; +import com.yahoo.yolean.Exceptions; +import org.junit.After; +import org.junit.Test; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Optional; + +import com.yahoo.searchdefinition.processing.RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage; + +import static junit.framework.TestCase.assertTrue; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; + +public class RankingExpressionWithOnnxTestCase { + + private final Path applicationDir = Path.fromString("src/test/integration/onnx/"); + private final static String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_onnx_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_onnx_Variable_1), f(a,b)(a + b))"; + + @After + public void removeGeneratedConstantTensorFiles() { + IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + } + + @Test + public void testOnnxReference() throws ParseException { + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "onnx('mnist_softmax.onnx')"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L)); + assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L)); + } + + @Test + public void testOnnxReferenceWithConstantFeature() { + RankProfileSearchFixture search = fixtureWith("constant(mytensor)", + "onnx('mnist_softmax.onnx')", + "constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }", + null); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L)); + assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L)); + } + + @Test + public void testOnnxReferenceWithQueryFeature() { + String queryProfile = ""; + String queryProfileType = "" + + " " + + ""; + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, + queryProfile, + queryProfileType); + RankProfileSearchFixture search = fixtureWith("query(mytensor)", + "onnx('mnist_softmax.onnx')", + null, + null, + "Placeholder", + application); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L)); + assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L)); + } + + @Test + public void testOnnxReferenceWithDocumentFeature() { + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); + RankProfileSearchFixture search = fixtureWith("attribute(mytensor)", + "onnx('mnist_softmax.onnx')", + null, + "field mytensor type tensor(d0[],d1[784]) { indexing: attribute }", + "Placeholder", + application); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L)); + assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L)); + } + + + @Test + public void testOnnxReferenceWithFeatureCombination() { + String queryProfile = ""; + String queryProfileType = "" + + " " + + ""; + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, + queryProfile, + queryProfileType); + RankProfileSearchFixture search = fixtureWith("sum(query(mytensor) * attribute(mytensor) * constant(mytensor),d2)", + "onnx('mnist_softmax.onnx')", + "constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }", + "field mytensor type tensor(d0[],d1[784]) { indexing: attribute }", + "Placeholder", + application); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L)); + assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L)); + } + + + @Test + public void testNestedOnnxReference() { + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "5 + sum(onnx('mnist_softmax.onnx'))"); + search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); + assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L)); + assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L)); + } + + @Test + public void testOnnxReferenceSpecifyingOutput() { + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "onnx('mnist_softmax.onnx', 'add')"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + } + + @Test + public void testOnnxReferenceMissingMacro() throws ParseException { + try { + RankProfileSearchFixture search = new RankProfileSearchFixture( + new StoringApplicationPackage(applicationDir), + new QueryProfileRegistry(), + " rank-profile my_profile {\n" + + " first-phase {\n" + + " expression: onnx('mnist_softmax.onnx')" + + " }\n" + + " }"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + fail("Expecting exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + + "onnx('mnist_softmax.onnx'): " + + "Model refers Placeholder 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " + + "not present in rank profile 'my_profile'", + Exceptions.toMessageString(expected)); + } + } + + + @Test + public void testOnnxReferenceWithWrongMacroType() { + try { + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d5[10])(0.0)", + "onnx('mnist_softmax.onnx')"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + fail("Expecting exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + + "onnx('mnist_softmax.onnx'): " + + "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) which must be produced " + + "by a macro in the rank profile, but this macro produces type tensor(d0[2],d5[10])", + Exceptions.toMessageString(expected)); + } + } + + @Test + public void testOnnxReferenceSpecifyingNonExistingOutput() { + try { + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "onnx('mnist_softmax.onnx', 'y')"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + fail("Expecting exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + + "onnx('mnist_softmax.onnx','y'): " + + "Model does not have the specified output 'y'", + Exceptions.toMessageString(expected)); + } + } + + @Test + public void testImportingFromStoredExpressions() throws IOException { + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "onnx('mnist_softmax.onnx')"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + + assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L)); + assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L)); + + // At this point the expression is stored - copy application to another location which do not have a models dir + Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); + try { + storedApplicationDirectory.toFile().mkdirs(); + IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); + RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "onnx('mnist_softmax.onnx')", + null, + null, + "Placeholder", + storedApplication); + searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile"); + // Verify that the constants exists, but don't verify the content as we are not + // simulating file distribution in this test + assertLargeConstant("mnist_softmax_onnx_Variable_1", searchFromStored, Optional.empty()); + assertLargeConstant("mnist_softmax_onnx_Variable", searchFromStored, Optional.empty()); + } + finally { + IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); + } + } + + @Test + public void testImportingFromStoredExpressionsWithMacroOverridingConstant() throws IOException { + String rankProfile = + " rank-profile my_profile {\n" + + " macro Placeholder() {\n" + + " expression: tensor(d0[2],d1[784])(0.0)\n" + + " }\n" + + " macro mnist_softmax_onnx_Variable() {\n" + + " expression: tensor(d1[10],d2[784])(0.0)\n" + + " }\n" + + " first-phase {\n" + + " expression: onnx('mnist_softmax.onnx')" + + " }\n" + + " }"; + + + String vespaExpressionWithoutConstant = + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), mnist_softmax_onnx_Variable, f(a,b)(a * b)), sum, d2), constant(mnist_softmax_onnx_Variable_1), f(a,b)(a + b))"; + RankProfileSearchFixture search = fixtureWith(rankProfile, new StoringApplicationPackage(applicationDir)); + search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); + + assertNull("Constant overridden by macro is not added", + search.search().getRankingConstants().get("mnist_softmax_onnx_Variable")); + assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L)); + + // At this point the expression is stored - copy application to another location which do not have a models dir + Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); + try { + storedApplicationDirectory.toFile().mkdirs(); + IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); + RankProfileSearchFixture searchFromStored = fixtureWith(rankProfile, storedApplication); + searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); + assertNull("Constant overridden by macro is not added", + searchFromStored.search().getRankingConstants().get("mnist_softmax_onnx_Variable")); + assertLargeConstant("mnist_softmax_onnx_Variable_1", searchFromStored, Optional.of(10L)); + } finally { + IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); + } + } + + /** + * Verifies that the constant with the given name exists, and - only if an expected size is given - + * that the content of the constant is available and has the expected size. + */ + private void assertLargeConstant(String name, RankProfileSearchFixture search, Optional expectedSize) { + try { + Path constantApplicationPackagePath = Path.fromString("models.generated/mnist_softmax.onnx/constants").append(name + ".tbf"); + RankingConstant rankingConstant = search.search().getRankingConstants().get(name); + assertEquals(name, rankingConstant.getName()); + assertTrue(rankingConstant.getFileName().endsWith(constantApplicationPackagePath.toString())); + + if (expectedSize.isPresent()) { + Path constantPath = applicationDir.append(constantApplicationPackagePath); + assertTrue("Constant file '" + constantPath + "' has been written", + constantPath.toFile().exists()); + Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(), + GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantPath.toFile()))); + assertEquals(expectedSize.get().longValue(), deserializedConstant.size()); + } + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) { + return fixtureWith(placeholderExpression, firstPhaseExpression, null, null, "Placeholder", + new StoringApplicationPackage(applicationDir)); + } + + private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression, + String constant, String field) { + return fixtureWith(placeholderExpression, firstPhaseExpression, constant, field, "Placeholder", + new StoringApplicationPackage(applicationDir)); + } + + private RankProfileSearchFixture fixtureWith(String rankProfile, StoringApplicationPackage application) { + try { + return new RankProfileSearchFixture(application, application.getQueryProfiles(), + rankProfile, null, null); + } + catch (ParseException e) { + throw new IllegalArgumentException(e); + } + } + + private RankProfileSearchFixture fixtureWith(String macroExpression, + String firstPhaseExpression, + String constant, + String field, + String macroName, + StoringApplicationPackage application) { + try { + return new RankProfileSearchFixture( + application, + application.getQueryProfiles(), + " rank-profile my_profile {\n" + + " macro " + macroName + "() {\n" + + " expression: " + macroExpression + + " }\n" + + " first-phase {\n" + + " expression: " + firstPhaseExpression + + " }\n" + + " }", + constant, + field); + } + catch (ParseException e) { + throw new IllegalArgumentException(e); + } + } + +} diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 55754605843..d288a396732 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -439,7 +439,7 @@ public class RankingExpressionWithTensorFlowTestCase { } } - private static class StoringApplicationPackage extends MockApplicationPackage { + static class StoringApplicationPackage extends MockApplicationPackage { private final File root; @@ -465,7 +465,7 @@ public class RankingExpressionWithTensorFlowTestCase { } - private static class StoringApplicationPackageFile extends ApplicationFile { + public static class StoringApplicationPackageFile extends ApplicationFile { /** The path to the application package root */ private final Path root; -- cgit v1.2.3