diff options
author | Lester Solbakken <lesters@users.noreply.github.com> | 2018-06-06 13:38:37 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-06-06 13:38:37 +0200 |
commit | 240176d60c44507f4e6733c7512620e80554c8de (patch) | |
tree | 7b1f54a9acd169da88a524f4899ddcf76d02db28 /config-model | |
parent | e4f626c587cf1cc4d5c05da5e15523f4162107f0 (diff) | |
parent | e4626398c7e9c1b4b0fa5dbd974e1696c377dd77 (diff) |
Merge pull request #6046 from vespa-engine/lesters/refactor-onnx-tensorflow-import
Refactor ONNX and TF import to use same code base
Diffstat (limited to 'config-model')
5 files changed, 718 insertions, 1317 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java new file mode 100644 index 00000000000..effa261be3b --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java @@ -0,0 +1,674 @@ +package com.yahoo.searchdefinition.expressiontransforms; + +import com.google.common.base.Joiner; +import com.yahoo.collections.Pair; +import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.application.provider.FilesApplicationPackage; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; +import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.searchdefinition.FeatureNames; +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.Join; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.tensor.serialization.TypedBinaryFormat; + +import java.io.BufferedReader; +import java.io.File; +import java.io.IOException; +import java.io.StringReader; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Base class for replacing instances of a pseudofeature for imported ML + * ranking models with native Vespa ranking expressions. + * + * @author bratseth + * @author lesters + */ +abstract class MLImportFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { + + ExpressionNode transformFromImportedModel(ImportedModel model, + ModelStore store, + RankProfile profile, + QueryProfileRegistry queryProfiles) { + // Add constants + Set<String> constantsReplacedByMacros = new HashSet<>(); + model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); + model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles, + constantsReplacedByMacros, k, v)); + + // Find the specified expression + ImportedModel.Signature signature = chooseSignature(model, store.arguments().signature()); + String output = chooseOutput(signature, store.arguments().output()); + if (signature.skippedOutputs().containsKey(output)) { + String message = "Could not import model output '" + output + "'"; + if (!signature.skippedOutputs().get(output).isEmpty()) { + message += ": " + signature.skippedOutputs().get(output); + } + if (!signature.importWarnings().isEmpty()) { + message += ": " + String.join(", ", signature.importWarnings()); + } + throw new IllegalArgumentException(message); + } + + RankingExpression expression = model.expressions().get(output); + expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); + verifyRequiredMacros(expression, model, profile, queryProfiles); + addGeneratedMacros(model, profile); + reduceBatchDimensions(expression, model, profile, queryProfiles); + + model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v)); + + store.writeConverted(expression); + return expression.getRoot(); + } + + ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) { + for (Pair<String, Tensor> constant : store.readSmallConstants()) + profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); + + for (RankingConstant constant : store.readLargeConstants()) { + if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName())) + profile.getSearch().addRankingConstant(constant); + } + + for (Pair<String, RankingExpression> macro : store.readMacros()) { + addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond()); + } + + return store.readConverted().getRoot(); + } + + /** + * Returns the specified, existing signature, or the only signature if none is specified. + * Throws IllegalArgumentException in all other cases. + */ + private ImportedModel.Signature chooseSignature(ImportedModel importResult, Optional<String> signatureName) { + if ( ! signatureName.isPresent()) { + if (importResult.signatures().size() == 0) + throw new IllegalArgumentException("No signatures are available"); + if (importResult.signatures().size() > 1) + throw new IllegalArgumentException("Model has multiple signatures (" + + Joiner.on(", ").join(importResult.signatures().keySet()) + + "), one must be specified " + + "as a second argument to tensorflow()"); + return importResult.signatures().values().stream().findFirst().get(); + } + else { + ImportedModel.Signature signature = importResult.signatures().get(signatureName.get()); + if (signature == null) + throw new IllegalArgumentException("Model does not have the specified signature '" + + signatureName.get() + "'"); + return signature; + } + } + + /** + * Returns the specified, existing output expression, or the only output expression if no output name is specified. + * Throws IllegalArgumentException in all other cases. + */ + private String chooseOutput(ImportedModel.Signature signature, Optional<String> outputName) { + if ( ! outputName.isPresent()) { + if (signature.outputs().size() == 0) + throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(signature)); + if (signature.outputs().size() > 1) + throw new IllegalArgumentException(signature + " has multiple outputs (" + + Joiner.on(", ").join(signature.outputs().keySet()) + + "), one must be specified " + + "as a third argument to tensorflow()"); + return signature.outputs().get(signature.outputs().keySet().stream().findFirst().get()); + } + else { + String output = signature.outputs().get(outputName.get()); + if (output == null) { + if (signature.skippedOutputs().containsKey(outputName.get())) + throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " + + signature.skippedOutputs().get(outputName.get())); + else + throw new IllegalArgumentException("Model does not have the specified output '" + + outputName.get() + "'"); + } + return output; + } + } + + private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { + store.writeSmallConstant(constantName, constantValue); + profile.addConstant(constantName, asValue(constantValue)); + } + + private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, + Set<String> constantsReplacedByMacros, + String constantName, Tensor constantValue) { + RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName); + if (macroOverridingConstant != null) { + TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles)); + if ( ! macroType.equals(constantValue.type())) + throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " + + typeMismatchExplanation(constantValue.type(), macroType)); + constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later + } + else { + Path constantPath = store.writeLargeConstant(constantName, constantValue); + if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) { + profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), + constantPath.toString())); + } + } + } + + private void transformGeneratedMacro(ModelStore store, + Set<String> constantsReplacedByMacros, + String macroName, RankingExpression expression) { + + expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); + store.writeMacro(macroName, expression); + } + + private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) { + if (profile.getMacros().containsKey(macroName)) { + throw new IllegalArgumentException("Generated macro '" + macroName + "' already exists."); + } + profile.addMacro(macroName, false); // todo: inline if only used once + RankProfile.Macro macro = profile.getMacros().get(macroName); + macro.setRankingExpression(expression); + macro.setTextualExpression(expression.getRoot().toString()); + } + + private String skippedOutputsDescription(ImportedModel.Signature signature) { + if (signature.skippedOutputs().isEmpty()) return ""; + StringBuilder b = new StringBuilder(": "); + signature.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v)); + return b.toString(); + } + + /** + * Verify that the macros referred in the given expression exists in the given rank profile, + * and return tensors of the types specified in requiredMacros. + */ + private void verifyRequiredMacros(RankingExpression expression, ImportedModel model, + RankProfile profile, QueryProfileRegistry queryProfiles) { + Set<String> macroNames = new HashSet<>(); + addMacroNamesIn(expression.getRoot(), macroNames, model); + for (String macroName : macroNames) { + TensorType requiredType = model.requiredMacros().get(macroName); + if (requiredType == null) continue; // Not a required macro + + RankProfile.Macro macro = profile.getMacros().get(macroName); + if (macro == null) + throw new IllegalArgumentException("Model refers input '" + macroName + + "' of type " + requiredType + " but this macro is not present in " + + profile); + // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second + // phase and summary features), as it may only resolve correctly given those bindings + // Or, probably better, annotate the macros with type constraints here and verify during general + // type verification + TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles)); + if ( actualType == null) + throw new IllegalArgumentException("Model refers input '" + macroName + + "' of type " + requiredType + + " which must be produced by a macro in the rank profile, but " + + "this macro references a feature which is not declared"); + if ( ! actualType.isAssignableTo(requiredType)) + throw new IllegalArgumentException("Model refers input '" + macroName + "'. " + + typeMismatchExplanation(requiredType, actualType)); + } + } + + private String typeMismatchExplanation(TensorType requiredType, TensorType actualType) { + return "The required type of this is " + requiredType + ", but this macro returns " + actualType + + (actualType.rank() == 0 ? ". This is often due to missing declaration of query tensor features " + + "in query profile types - see the documentation." + : ""); + } + + /** + * Add the generated macros to the rank profile + */ + private void addGeneratedMacros(ImportedModel model, RankProfile profile) { + model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v)); + } + + /** + * Check if batch dimensions of inputs can be reduced out. If the input + * macro specifies that a single exemplar should be evaluated, we can + * reduce the batch dimension out. + */ + private void reduceBatchDimensions(RankingExpression expression, ImportedModel model, + RankProfile profile, QueryProfileRegistry queryProfiles) { + TypeContext<Reference> typeContext = profile.typeContext(queryProfiles); + TensorType typeBeforeReducing = expression.getRoot().type(typeContext); + + // Check generated macros for inputs to reduce + Set<String> macroNames = new HashSet<>(); + addMacroNamesIn(expression.getRoot(), macroNames, model); + for (String macroName : macroNames) { + if ( ! model.macros().containsKey(macroName)) { + continue; + } + RankProfile.Macro macro = profile.getMacros().get(macroName); + if (macro == null) { + throw new IllegalArgumentException("Model refers to generated macro '" + macroName + + "but this macro is not present in " + profile); + } + RankingExpression macroExpression = macro.getRankingExpression(); + macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext)); + } + + // Check expression for inputs to reduce + ExpressionNode root = expression.getRoot(); + root = reduceBatchDimensionsAtInput(root, model, typeContext); + TensorType typeAfterReducing = root.type(typeContext); + root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing); + expression.setRoot(root); + } + + private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedModel model, + TypeContext<Reference> typeContext) { + if (node instanceof TensorFunctionNode) { + TensorFunction tensorFunction = ((TensorFunctionNode) node).function(); + if (tensorFunction instanceof Rename) { + List<ExpressionNode> children = ((TensorFunctionNode)node).children(); + if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode) children.get(0); + if (model.requiredMacros().containsKey(referenceNode.getName())) { + return reduceBatchDimensionExpression(tensorFunction, typeContext); + } + } + } + } + if (node instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode) node; + if (model.requiredMacros().containsKey(referenceNode.getName())) { + return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext); + } + } + if (node instanceof CompositeNode) { + List<ExpressionNode> children = ((CompositeNode)node).children(); + List<ExpressionNode> transformedChildren = new ArrayList<>(children.size()); + for (ExpressionNode child : children) { + transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext)); + } + return ((CompositeNode)node).setChildren(transformedChildren); + } + return node; + } + + private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) { + TensorFunction result = function; + TensorType type = function.type(context); + if (type.dimensions().size() > 1) { + List<String> reduceDimensions = new ArrayList<>(); + for (TensorType.Dimension dimension : type.dimensions()) { + if (dimension.size().orElse(-1L) == 1) { + reduceDimensions.add(dimension.name()); + } + } + if (reduceDimensions.size() > 0) { + result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions); + } + } + return new TensorFunctionNode(result); + } + + /** + * If batch dimensions have been reduced away above, bring them back here + * for any following computation of the tensor. + * Todo: determine when this is not necessary! + */ + private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { + if (after.equals(before)) { + return node; + } + TensorType.Builder typeBuilder = new TensorType.Builder(); + for (TensorType.Dimension dimension : before.dimensions()) { + if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) { + typeBuilder.indexed(dimension.name(), 1); + } + } + TensorType expandDimensionsType = typeBuilder.build(); + if (expandDimensionsType.dimensions().size() > 0) { + ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0)); + Generate generatedFunction = new Generate(expandDimensionsType, + new GeneratorLambdaFunctionNode(expandDimensionsType, + generatedExpression) + .asLongListToDoubleOperator()); + Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply()); + return new TensorFunctionNode(expand); + } + return node; + } + + /** + * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions. + * This method does that for the given expression and returns the result. + */ + private RankingExpression replaceConstantsByMacros(RankingExpression expression, + Set<String> constantsReplacedByMacros) { + if (constantsReplacedByMacros.isEmpty()) return expression; + return new RankingExpression(expression.getName(), + replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros)); + } + + private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) { + if (node instanceof ReferenceNode) { + Reference reference = ((ReferenceNode)node).reference(); + if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) { + String argument = reference.simpleArgument().get(); + if (constantsReplacedByMacros.contains(argument)) + return new ReferenceNode(argument); + } + } + if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above + CompositeNode composite = (CompositeNode)node; + return composite.setChildren(composite.children().stream() + .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros)) + .collect(Collectors.toList())); + } + return node; + } + + private void addMacroNamesIn(ExpressionNode node, Set<String> names, ImportedModel model) { + if (node instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode)node; + if (referenceNode.getOutput() == null) { // macro references cannot specify outputs + names.add(referenceNode.getName()); + if (model.macros().containsKey(referenceNode.getName())) { + addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model); + } + } + } + else if (node instanceof CompositeNode) { + for (ExpressionNode child : ((CompositeNode)node).children()) + addMacroNamesIn(child, names, model); + } + } + + private Value asValue(Tensor tensor) { + if (tensor.type().rank() == 0) + return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors + else + return new TensorValue(tensor); + } + + /** + * Provides read/write access to the correct directories of the application package given by the feature arguments + */ + static class ModelStore { + + private final ApplicationPackage application; + private final FeatureArguments arguments; + + ModelStore(ApplicationPackage application, FeatureArguments arguments) { + this.application = application; + this.arguments = arguments; + } + + public FeatureArguments arguments() { return arguments; } + + public boolean hasStoredModel() { + try { + return application.getFile(arguments.expressionPath()).exists(); + } + catch (UnsupportedOperationException e) { + return false; + } + } + + /** + * Returns the directory which contains the source model to use for these arguments + */ + public File modelDir() { + return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath())); + } + + /** + * Adds this expression to the application package, such that it can be read later. + */ + void writeConverted(RankingExpression expression) { + application.getFile(arguments.expressionPath()) + .writeFile(new StringReader(expression.getRoot().toString())); + } + + /** Reads the previously stored ranking expression for these arguments */ + RankingExpression readConverted() { + try { + return new RankingExpression(application.getFile(arguments.expressionPath()).createReader()); + } + catch (IOException e) { + throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e); + } + catch (ParseException e) { + throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); + } + } + + /** Adds this macro expression to the application package to it can be read later. */ + void writeMacro(String name, RankingExpression expression) { + application.getFile(arguments.macrosPath()).appendFile(name + "\t" + + expression.getRoot().toString() + "\n"); + } + + /** Reads the previously stored macro expressions for these arguments */ + List<Pair<String, RankingExpression>> readMacros() { + try { + ApplicationFile file = application.getFile(arguments.macrosPath()); + if (!file.exists()) return Collections.emptyList(); + + List<Pair<String, RankingExpression>> macros = new ArrayList<>(); + BufferedReader reader = new BufferedReader(file.createReader()); + String line; + while (null != (line = reader.readLine())) { + String[] parts = line.split("\t"); + String name = parts[0]; + try { + RankingExpression expression = new RankingExpression(parts[1]); + macros.add(new Pair<>(name, expression)); + } + catch (ParseException e) { + throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); + } + } + return macros; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + /** + * Reads the information about all the large (aka ranking) constants stored in the application package + * (the constant value itself is replicated with file distribution). + */ + List<RankingConstant> readLargeConstants() { + try { + List<RankingConstant> constants = new ArrayList<>(); + for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) { + String[] parts = IOUtils.readAll(constantFile.createReader()).split(":"); + constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2])); + } + return constants; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + /** + * Adds this constant to the application package as a file, + * such that it can be distributed using file distribution. + * + * @return the path to the stored constant, relative to the application package root + */ + Path writeLargeConstant(String name, Tensor constant) { + Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants"); + + // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file: + Path constantPath = constantsPath.append(name + ".tbf"); + + // Remember the constant in a file we replicate in ZooKeeper + application.getFile(arguments.largeConstantsPath().append(name + ".constant")) + .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath))); + + // Write content explicitly as a file on the file system as this is distributed using file distribution + createIfNeeded(constantsPath); + IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant)); + return correct(constantPath); + } + + private List<Pair<String, Tensor>> readSmallConstants() { + try { + ApplicationFile file = application.getFile(arguments.smallConstantsPath()); + if (!file.exists()) return Collections.emptyList(); + + List<Pair<String, Tensor>> constants = new ArrayList<>(); + BufferedReader reader = new BufferedReader(file.createReader()); + String line; + while (null != (line = reader.readLine())) { + String[] parts = line.split("\t"); + String name = parts[0]; + TensorType type = TensorType.fromSpec(parts[1]); + Tensor tensor = Tensor.from(type, parts[2]); + constants.add(new Pair<>(name, tensor)); + } + return constants; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + /** + * Append this constant to the single file used for small constants distributed as config + */ + public void writeSmallConstant(String name, Tensor constant) { + // Secret file format for remembering constants: + application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" + + constant.type().toString() + "\t" + + constant.toString() + "\n"); + } + + /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */ + private Path correct(Path path) { + if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed) + && ! path.elements().contains(FilesApplicationPackage.preprocessed)) { + return Path.fromString(FilesApplicationPackage.preprocessed).append(path); + } + else { + return path; + } + } + + private void createIfNeeded(Path path) { + File dir = application.getFileReference(path); + if ( ! dir.exists()) { + if (!dir.mkdirs()) + throw new IllegalStateException("Could not create " + dir); + } + } + + } + + /** Encapsulates the arguments to the import feature */ + static abstract class FeatureArguments { + + Path modelPath; + + /** Optional arguments */ + Optional<String> signature, output; + + /** Returns modelPath with slashes replaced by underscores */ + public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); } + + /** Returns relative path to this model below the "models/" dir in the application package */ + public Path modelPath() { return modelPath; } + public Optional<String> signature() { return signature; } + public Optional<String> output() { return output; } + + /** Path to the small constants file */ + public Path smallConstantsPath() { + return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt"); + } + + /** Path to the large (ranking) constants directory */ + public Path largeConstantsPath() { + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants"); + } + + /** Path to the macros file */ + public Path macrosPath() { + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt"); + } + + public Path expressionPath() { + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR + .append(modelPath).append("expressions").append(expressionFileName()); + } + + private String expressionFileName() { + StringBuilder fileName = new StringBuilder(); + signature.ifPresent(s -> fileName.append(s).append(".")); + output.ifPresent(s -> fileName.append(s).append(".")); + if (fileName.length() == 0) // single signature and output + fileName.append("single."); + fileName.append("expression"); + return fileName.toString(); + } + + Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { + if (argumentIndex >= arguments.expressions().size()) + return Optional.empty(); + return Optional.of(asString(arguments.expressions().get(argumentIndex))); + } + + String asString(ExpressionNode node) { + if ( ! (node instanceof ConstantNode)) + throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); + return stripQuotes(((ConstantNode)node).sourceString()); + } + + private String stripQuotes(String s) { + if ( ! isQuoteSign(s.codePointAt(0))) return s; + if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) + throw new IllegalArgumentException("argument [" + s + "] is missing endquote"); + return s.substring(1, s.length()-1); + } + + private boolean isQuoteSign(int c) { + return c == '\'' || c == '"'; + } + + } +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index 1c41ad8284e..44eeb364603 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -2,58 +2,20 @@ package com.yahoo.searchdefinition.expressiontransforms; -import com.google.common.base.Joiner; -import com.yahoo.collections.Pair; -import com.yahoo.config.application.api.ApplicationFile; -import com.yahoo.config.application.api.ApplicationPackage; -import com.yahoo.config.model.application.provider.FilesApplicationPackage; -import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; -import com.yahoo.searchdefinition.FeatureNames; import com.yahoo.searchdefinition.RankProfile; -import com.yahoo.searchdefinition.RankingConstant; -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.Reference; -import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxImporter; -import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxModel; -import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import com.yahoo.searchlib.rankingexpression.integration.ml.OnnxImporter; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; -import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; -import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.TypeContext; -import com.yahoo.tensor.functions.Generate; -import com.yahoo.tensor.functions.Join; -import com.yahoo.tensor.functions.Reduce; -import com.yahoo.tensor.functions.Rename; -import com.yahoo.tensor.functions.ScalarFunctions; -import com.yahoo.tensor.functions.TensorFunction; -import com.yahoo.tensor.serialization.TypedBinaryFormat; -import java.io.BufferedReader; -import java.io.File; -import java.io.IOException; -import java.io.StringReader; import java.io.UncheckedIOException; -import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; -import java.util.stream.Collectors; /** * Replaces instances of the onnx(model-path, output) @@ -63,12 +25,12 @@ import java.util.stream.Collectors; * @author bratseth * @author lesters */ -public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { +public class OnnxFeatureConverter extends MLImportFeatureConverter { private final OnnxImporter onnxImporter = new OnnxImporter(); /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ - private final Map<Path, OnnxModel> importedModels = new HashMap<>(); + private final Map<Path, ImportedModel> importedModels = new HashMap<>(); @Override public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { @@ -84,7 +46,8 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans if ( ! feature.getName().equals("onnx")) return feature; try { - ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments()); + FeatureArguments arguments = new OnnxFeatureArguments(feature.getArguments()); + ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments); if ( ! store.hasStoredModel()) // not converted yet - access Onnx model files return transformFromOnnxModel(store, context.rankProfile(), context.queryProfiles()); else @@ -98,597 +61,24 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans private ExpressionNode transformFromOnnxModel(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles) { - OnnxModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), + ImportedModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), k -> onnxImporter.importModel(store.arguments().modelName(), - store.onnxModelDir())); - - // Add constants - Set<String> constantsReplacedByMacros = new HashSet<>(); - model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); - model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles, - constantsReplacedByMacros, k, v)); - - // Find the specified expression - String output = chooseOutput(model, store.arguments().output()); - if (model.skippedOutputs().containsKey(output)) { - String message = "Could not import Onnx model output '" + output + "'"; - if (!model.skippedOutputs().get(output).isEmpty()) { - message += ": " + model.skippedOutputs().get(output); - } - if (!model.importWarnings().isEmpty()) { - message += ": " + String.join(", ", model.importWarnings()); - } - throw new IllegalArgumentException(message); - } - - RankingExpression expression = model.expressions().get(output); - expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); - verifyRequiredMacros(expression, model, profile, queryProfiles); - addGeneratedMacros(model, profile); - reduceBatchDimensions(expression, model, profile, queryProfiles); - - model.macros().forEach((k, v) -> transformGeneratedMacro(store, profile, constantsReplacedByMacros, k, v)); - - store.writeConverted(expression); - return expression.getRoot(); - } - - private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) { - for (Pair<String, Tensor> constant : store.readSmallConstants()) - profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); - - for (RankingConstant constant : store.readLargeConstants()) { - if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName())) - profile.getSearch().addRankingConstant(constant); - } - - for (Pair<String, RankingExpression> macro : store.readMacros()) { - addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond()); - } - - return store.readConverted().getRoot(); - } - - /** - * Returns the specified, existing output expression, or the only output expression if no output name is specified. - * Throws IllegalArgumentException in all other cases. - */ - private String chooseOutput(OnnxModel model, Optional<String> outputName) { - if ( ! outputName.isPresent()) { - if (model.outputs().size() == 0) - throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(model)); - if (model.outputs().size() > 1) - throw new IllegalArgumentException("Onnx model has multiple outputs (" + - Joiner.on(", ").join(model.outputs().keySet()) + - "), one must be specified " + - "as a second argument to onnx()"); - return model.outputs().get(model.outputs().keySet().stream().findFirst().get()); - } - else { - String output = model.outputs().get(outputName.get()); - if (output == null) { - if (model.skippedOutputs().containsKey(outputName.get())) - throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " + - model.skippedOutputs().get(outputName.get())); - else - throw new IllegalArgumentException("Model does not have the specified output '" + - outputName.get() + "'"); - } - return output; - } + store.modelDir())); + return transformFromImportedModel(model, store, profile, queryProfiles); } - private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { - store.writeSmallConstant(constantName, constantValue); - profile.addConstant(constantName, asValue(constantValue)); - } - - private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, - Set<String> constantsReplacedByMacros, - String constantName, Tensor constantValue) { - RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName); - if (macroOverridingConstant != null) { - TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles)); - if ( ! macroType.equals(constantValue.type())) - throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " + - "The required type of this is " + constantValue.type() + - ", but the macro returns " + macroType); - constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later - } - else { - Path constantPath = store.writeLargeConstant(constantName, constantValue); - if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) { - profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), - constantPath.toString())); - } - } - } - - private void transformGeneratedMacro(ModelStore store, RankProfile profile, - Set<String> constantsReplacedByMacros, - String macroName, RankingExpression expression) { - - expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); - store.writeMacro(macroName, expression); - } - - private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) { - if (profile.getMacros().containsKey(macroName)) { - throw new IllegalArgumentException("Generated Onnx macro '" + macroName + "' already exists."); - } - profile.addMacro(macroName, false); // todo: inline if only used once - RankProfile.Macro macro = profile.getMacros().get(macroName); - macro.setRankingExpression(expression); - macro.setTextualExpression(expression.getRoot().toString()); - } - - private String skippedOutputsDescription(OnnxModel model) { - if (model.skippedOutputs().isEmpty()) return ""; - StringBuilder b = new StringBuilder(": "); - model.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v)); - return b.toString(); - } - - /** - * Verify that the macros referred in the given expression exists in the given rank profile, - * and return tensors of the types specified in requiredMacros. - */ - private void verifyRequiredMacros(RankingExpression expression, OnnxModel model, - RankProfile profile, QueryProfileRegistry queryProfiles) { - Set<String> macroNames = new HashSet<>(); - addMacroNamesIn(expression.getRoot(), macroNames, model); - for (String macroName : macroNames) { - TensorType requiredType = model.requiredMacros().get(macroName); - if (requiredType == null) continue; // Not a required macro - - RankProfile.Macro macro = profile.getMacros().get(macroName); - if (macro == null) - throw new IllegalArgumentException("Model refers Placeholder '" + macroName + - "' of type " + requiredType + " but this macro is not present in " + - profile); - // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second - // phase and summary features), as it may only resolve correctly given those bindings - // Or, probably better, annotate the macros with type constraints here and verify during general - // type verification - TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles)); - if ( actualType == null) - throw new IllegalArgumentException("Model refers input '" + macroName + - "' of type " + requiredType + - " which must be produced by a macro in the rank profile, but " + - "this macro references a feature which is not declared"); - if ( ! actualType.isAssignableTo(requiredType)) - throw new IllegalArgumentException("Model refers input '" + macroName + - "' of type " + requiredType + - " which must be produced by a macro in the rank profile, but " + - "this macro produces type " + actualType); - } - } - - /** - * Add the generated macros to the rank profile - */ - private void addGeneratedMacros(OnnxModel model, RankProfile profile) { - model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v)); - } - - /** - * Check if batch dimensions of inputs can be reduced out. If the input - * macro specifies that a single exemplar should be evaluated, we can - * reduce the batch dimension out. - */ - private void reduceBatchDimensions(RankingExpression expression, OnnxModel model, - RankProfile profile, QueryProfileRegistry queryProfiles) { - TypeContext<Reference> typeContext = profile.typeContext(queryProfiles); - TensorType typeBeforeReducing = expression.getRoot().type(typeContext); - - // Check generated macros for inputs to reduce - Set<String> macroNames = new HashSet<>(); - addMacroNamesIn(expression.getRoot(), macroNames, model); - for (String macroName : macroNames) { - if ( ! model.macros().containsKey(macroName)) { - continue; - } - RankProfile.Macro macro = profile.getMacros().get(macroName); - if (macro == null) { - throw new IllegalArgumentException("Model refers to generated macro '" + macroName + - "but this macro is not present in " + profile); - } - RankingExpression macroExpression = macro.getRankingExpression(); - macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext)); - } - - // Check expression for inputs to reduce - ExpressionNode root = expression.getRoot(); - root = reduceBatchDimensionsAtInput(root, model, typeContext); - TensorType typeAfterReducing = root.type(typeContext); - root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing); - expression.setRoot(root); - } - - private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, OnnxModel model, - TypeContext<Reference> typeContext) { - if (node instanceof TensorFunctionNode) { - TensorFunction tensorFunction = ((TensorFunctionNode) node).function(); - if (tensorFunction instanceof Rename) { - List<ExpressionNode> children = ((TensorFunctionNode)node).children(); - if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode) children.get(0); - if (model.requiredMacros().containsKey(referenceNode.getName())) { - return reduceBatchDimensionExpression(tensorFunction, typeContext); - } - } - } - } - if (node instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode) node; - if (model.requiredMacros().containsKey(referenceNode.getName())) { - return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext); - } - } - if (node instanceof CompositeNode) { - List<ExpressionNode> children = ((CompositeNode)node).children(); - List<ExpressionNode> transformedChildren = new ArrayList<>(children.size()); - for (ExpressionNode child : children) { - transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext)); - } - return ((CompositeNode)node).setChildren(transformedChildren); - } - return node; - } - - private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) { - TensorFunction result = function; - TensorType type = function.type(context); - if (type.dimensions().size() > 1) { - List<String> reduceDimensions = new ArrayList<>(); - for (TensorType.Dimension dimension : type.dimensions()) { - if (dimension.size().orElse(-1L) == 1) { - reduceDimensions.add(dimension.name()); - } - } - if (reduceDimensions.size() > 0) { - result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions); - } - } - return new TensorFunctionNode(result); - } - - /** - * If batch dimensions have been reduced away above, bring them back here - * for any following computation of the tensor. - * Todo: determine when this is not necessary! - */ - private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { - if (after.equals(before)) { - return node; - } - TensorType.Builder typeBuilder = new TensorType.Builder(); - for (TensorType.Dimension dimension : before.dimensions()) { - if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) { - typeBuilder.indexed(dimension.name(), 1); - } - } - TensorType expandDimensionsType = typeBuilder.build(); - if (expandDimensionsType.dimensions().size() > 0) { - ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0)); - Generate generatedFunction = new Generate(expandDimensionsType, - new GeneratorLambdaFunctionNode(expandDimensionsType, - generatedExpression) - .asLongListToDoubleOperator()); - Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply()); - return new TensorFunctionNode(expand); - } - return node; - } - - /** - * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions. - * This method does that for the given expression and returns the result. - */ - private RankingExpression replaceConstantsByMacros(RankingExpression expression, - Set<String> constantsReplacedByMacros) { - if (constantsReplacedByMacros.isEmpty()) return expression; - return new RankingExpression(expression.getName(), - replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros)); - } - - private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) { - if (node instanceof ReferenceNode) { - Reference reference = ((ReferenceNode)node).reference(); - if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) { - String argument = reference.simpleArgument().get(); - if (constantsReplacedByMacros.contains(argument)) - return new ReferenceNode(argument); - } - } - if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above - CompositeNode composite = (CompositeNode)node; - return composite.setChildren(composite.children().stream() - .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros)) - .collect(Collectors.toList())); - } - return node; - } - - private void addMacroNamesIn(ExpressionNode node, Set<String> names, OnnxModel model) { - if (node instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode)node; - if (referenceNode.getOutput() == null) { // macro references cannot specify outputs - names.add(referenceNode.getName()); - if (model.macros().containsKey(referenceNode.getName())) { - addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model); - } - } - } - else if (node instanceof CompositeNode) { - for (ExpressionNode child : ((CompositeNode)node).children()) - addMacroNamesIn(child, names, model); - } - } - - private Value asValue(Tensor tensor) { - if (tensor.type().rank() == 0) - return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors - else - return new TensorValue(tensor); - } - - /** - * Provides read/write access to the correct directories of the application package given by the feature arguments - */ - private static class ModelStore { - - private final ApplicationPackage application; - private final FeatureArguments arguments; - - public ModelStore(ApplicationPackage application, Arguments arguments) { - this.application = application; - this.arguments = new FeatureArguments(arguments); - } - - public FeatureArguments arguments() { return arguments; } - - public boolean hasStoredModel() { - try { - return application.getFile(arguments.expressionPath()).exists(); - } - catch (UnsupportedOperationException e) { - return false; - } - } - - /** - * Returns the directory which contains the source model to use for these arguments - */ - public File onnxModelDir() { - return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath())); - } - - /** - * Adds this expression to the application package, such that it can be read later. - */ - public void writeConverted(RankingExpression expression) { - application.getFile(arguments.expressionPath()) - .writeFile(new StringReader(expression.getRoot().toString())); - } - - /** Reads the previously stored ranking expression for these arguments */ - public RankingExpression readConverted() { - try { - return new RankingExpression(application.getFile(arguments.expressionPath()).createReader()); - } - catch (IOException e) { - throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e); - } - catch (ParseException e) { - throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); - } - } - - /** Adds this macro expression to the application package to it can be read later. */ - public void writeMacro(String name, RankingExpression expression) { - application.getFile(arguments.macrosPath()).appendFile(name + "\t" + - expression.getRoot().toString() + "\n"); - } - - /** Reads the previously stored macro expressions for these arguments */ - public List<Pair<String, RankingExpression>> readMacros() { - try { - ApplicationFile file = application.getFile(arguments.macrosPath()); - if (!file.exists()) return Collections.emptyList(); - - List<Pair<String, RankingExpression>> macros = new ArrayList<>(); - BufferedReader reader = new BufferedReader(file.createReader()); - String line; - while (null != (line = reader.readLine())) { - String[] parts = line.split("\t"); - String name = parts[0]; - try { - RankingExpression expression = new RankingExpression(parts[1]); - macros.add(new Pair<>(name, expression)); - } - catch (ParseException e) { - throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); - } - } - return macros; - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - /** - * Reads the information about all the large (aka ranking) constants stored in the application package - * (the constant value itself is replicated with file distribution). - */ - public List<RankingConstant> readLargeConstants() { - try { - List<RankingConstant> constants = new ArrayList<>(); - for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) { - String[] parts = IOUtils.readAll(constantFile.createReader()).split(":"); - constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2])); - } - return constants; - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - /** - * Adds this constant to the application package as a file, - * such that it can be distributed using file distribution. - * - * @return the path to the stored constant, relative to the application package root - */ - public Path writeLargeConstant(String name, Tensor constant) { - Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants"); - - // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file: - Path constantPath = constantsPath.append(name + ".tbf"); - - // Remember the constant in a file we replicate in ZooKeeper - application.getFile(arguments.largeConstantsPath().append(name + ".constant")) - .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath))); - - // Write content explicitly as a file on the file system as this is distributed using file distribution - createIfNeeded(constantsPath); - IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant)); - return correct(constantPath); - } - - private List<Pair<String, Tensor>> readSmallConstants() { - try { - ApplicationFile file = application.getFile(arguments.smallConstantsPath()); - if (!file.exists()) return Collections.emptyList(); - - List<Pair<String, Tensor>> constants = new ArrayList<>(); - BufferedReader reader = new BufferedReader(file.createReader()); - String line; - while (null != (line = reader.readLine())) { - String[] parts = line.split("\t"); - String name = parts[0]; - TensorType type = TensorType.fromSpec(parts[1]); - Tensor tensor = Tensor.from(type, parts[2]); - constants.add(new Pair<>(name, tensor)); - } - return constants; - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - /** - * Append this constant to the single file used for small constants distributed as config - */ - public void writeSmallConstant(String name, Tensor constant) { - // Secret file format for remembering constants: - application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" + - constant.type().toString() + "\t" + - constant.toString() + "\n"); - } - - /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */ - private Path correct(Path path) { - if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed) - && ! path.elements().contains(FilesApplicationPackage.preprocessed)) { - return Path.fromString(FilesApplicationPackage.preprocessed).append(path); - } - else { - return path; - } - } - - private void createIfNeeded(Path path) { - File dir = application.getFileReference(path); - if ( ! dir.exists()) { - if (!dir.mkdirs()) - throw new IllegalStateException("Could not create " + dir); - } - } - - } - - /** Encapsulates the 1, 2 or 3 arguments to a onnx feature */ - private static class FeatureArguments { - - private final Path modelPath; - - /** Optional arguments */ - private final Optional<String> output; - - public FeatureArguments(Arguments arguments) { + static class OnnxFeatureArguments extends FeatureArguments { + public OnnxFeatureArguments(Arguments arguments) { if (arguments.isEmpty()) throw new IllegalArgumentException("An onnx node must take an argument pointing to " + - "the onnx model directory under [application]/models"); + "the tensorflow model directory under [application]/models"); if (arguments.expressions().size() > 3) throw new IllegalArgumentException("An onnx feature can have at most 2 arguments"); modelPath = Path.fromString(asString(arguments.expressions().get(0))); output = optionalArgument(1, arguments); + signature = Optional.of("default"); } - - /** Returns modelPath with slashes replaced by underscores */ - public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); } - - /** Returns relative path to this model below the "models/" dir in the application package */ - public Path modelPath() { return modelPath; } - public Optional<String> output() { return output; } - - /** Path to the small constants file */ - public Path smallConstantsPath() { - return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt"); - } - - /** Path to the large (ranking) constants directory */ - public Path largeConstantsPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants"); - } - - /** Path to the macros file */ - public Path macrosPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt"); - } - - public Path expressionPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR - .append(modelPath).append("expressions").append(expressionFileName()); - } - - private String expressionFileName() { - StringBuilder fileName = new StringBuilder(); - output.ifPresent(s -> fileName.append(s).append(".")); - if (fileName.length() == 0) // single signature and output - fileName.append("single."); - fileName.append("expression"); - return fileName.toString(); - } - - private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { - if (argumentIndex >= arguments.expressions().size()) - return Optional.empty(); - return Optional.of(asString(arguments.expressions().get(argumentIndex))); - } - - private String asString(ExpressionNode node) { - if ( ! (node instanceof ConstantNode)) - throw new IllegalArgumentException("Expected a constant string as onnx argument, but got '" + node); - return stripQuotes(((ConstantNode)node).sourceString()); - } - - private String stripQuotes(String s) { - if ( ! isQuoteSign(s.codePointAt(0))) return s; - if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) - throw new IllegalArgumentException("onnx argument [" + s + "] is missing endquote"); - return s.substring(1, s.length()-1); - } - - private boolean isQuoteSign(int c) { - return c == '\'' || c == '"'; - } - } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 41da32f64c3..27e1ad51b33 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -1,59 +1,19 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.expressiontransforms; -import com.google.common.base.Joiner; -import com.yahoo.collections.Pair; -import com.yahoo.config.application.api.ApplicationFile; -import com.yahoo.config.application.api.ApplicationPackage; -import com.yahoo.config.model.application.provider.FilesApplicationPackage; -import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; -import com.yahoo.searchdefinition.FeatureNames; import com.yahoo.searchdefinition.RankProfile; -import com.yahoo.searchdefinition.RankingConstant; -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.Reference; -import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel.Signature; -import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import com.yahoo.searchlib.rankingexpression.integration.ml.TensorFlowImporter; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; -import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; -import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.TypeContext; -import com.yahoo.tensor.functions.Generate; -import com.yahoo.tensor.functions.Join; -import com.yahoo.tensor.functions.Reduce; -import com.yahoo.tensor.functions.Rename; -import com.yahoo.tensor.functions.ScalarFunctions; -import com.yahoo.tensor.functions.TensorFunction; -import com.yahoo.tensor.serialization.TypedBinaryFormat; -import java.io.BufferedReader; -import java.io.File; -import java.io.IOException; -import java.io.StringReader; import java.io.UncheckedIOException; -import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; -import java.util.List; import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.stream.Collectors; /** * Replaces instances of the tensorflow(model-path, signature, output) @@ -62,12 +22,12 @@ import java.util.stream.Collectors; * * @author bratseth */ -public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { +public class TensorFlowFeatureConverter extends MLImportFeatureConverter { private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter(); /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ - private final Map<Path, TensorFlowModel> importedModels = new HashMap<>(); + private final Map<Path, ImportedModel> importedModels = new HashMap<>(); @Override public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { @@ -83,7 +43,8 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil if ( ! feature.getName().equals("tensorflow")) return feature; try { - ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments()); + FeatureArguments arguments = new TensorFlowFeatureArguments(feature.getArguments()); + ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments); if ( ! store.hasStoredModel()) // not converted yet - access TensorFlow model files return transformFromTensorFlowModel(store, context.rankProfile(), context.queryProfiles()); else @@ -95,565 +56,19 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } private ExpressionNode transformFromTensorFlowModel(ModelStore store, - RankProfile profile, - QueryProfileRegistry queryProfiles) { - TensorFlowModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), - k -> tensorFlowImporter.importModel(store.arguments().modelName(), - store.tensorFlowModelDir())); - - // Add constants - Set<String> constantsReplacedByMacros = new HashSet<>(); - model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); - model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles, - constantsReplacedByMacros, k, v)); - - // Find the specified expression - Signature signature = chooseSignature(model, store.arguments().signature()); - String output = chooseOutput(signature, store.arguments().output()); - if (signature.skippedOutputs().containsKey(output)) { - String message = "Could not import TensorFlow model output '" + output + "'"; - if (!signature.skippedOutputs().get(output).isEmpty()) { - message += ": " + signature.skippedOutputs().get(output); - } - if (!signature.importWarnings().isEmpty()) { - message += ": " + String.join(", ", signature.importWarnings()); - } - throw new IllegalArgumentException(message); - } - - RankingExpression expression = model.expressions().get(output); - expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); - verifyRequiredMacros(expression, model, profile, queryProfiles); - addGeneratedMacros(model, profile); - reduceBatchDimensions(expression, model, profile, queryProfiles); - - model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v)); - - store.writeConverted(expression); - return expression.getRoot(); - } - - private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) { - for (Pair<String, Tensor> constant : store.readSmallConstants()) - profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); - - for (RankingConstant constant : store.readLargeConstants()) { - if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName())) - profile.getSearch().addRankingConstant(constant); - } - - for (Pair<String, RankingExpression> macro : store.readMacros()) { - addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond()); - } - - return store.readConverted().getRoot(); - } - - /** - * Returns the specified, existing signature, or the only signature if none is specified. - * Throws IllegalArgumentException in all other cases. - */ - private Signature chooseSignature(TensorFlowModel importResult, Optional<String> signatureName) { - if ( ! signatureName.isPresent()) { - if (importResult.signatures().size() == 0) - throw new IllegalArgumentException("No signatures are available"); - if (importResult.signatures().size() > 1) - throw new IllegalArgumentException("Model has multiple signatures (" + - Joiner.on(", ").join(importResult.signatures().keySet()) + - "), one must be specified " + - "as a second argument to tensorflow()"); - return importResult.signatures().values().stream().findFirst().get(); - } - else { - Signature signature = importResult.signatures().get(signatureName.get()); - if (signature == null) - throw new IllegalArgumentException("Model does not have the specified signature '" + - signatureName.get() + "'"); - return signature; - } - } - - /** - * Returns the specified, existing output expression, or the only output expression if no output name is specified. - * Throws IllegalArgumentException in all other cases. - */ - private String chooseOutput(Signature signature, Optional<String> outputName) { - if ( ! outputName.isPresent()) { - if (signature.outputs().size() == 0) - throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(signature)); - if (signature.outputs().size() > 1) - throw new IllegalArgumentException(signature + " has multiple outputs (" + - Joiner.on(", ").join(signature.outputs().keySet()) + - "), one must be specified " + - "as a third argument to tensorflow()"); - return signature.outputs().get(signature.outputs().keySet().stream().findFirst().get()); - } - else { - String output = signature.outputs().get(outputName.get()); - if (output == null) { - if (signature.skippedOutputs().containsKey(outputName.get())) - throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " + - signature.skippedOutputs().get(outputName.get())); - else - throw new IllegalArgumentException("Model does not have the specified output '" + - outputName.get() + "'"); - } - return output; - } - } - - private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { - store.writeSmallConstant(constantName, constantValue); - profile.addConstant(constantName, asValue(constantValue)); - } - - private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, - Set<String> constantsReplacedByMacros, - String constantName, Tensor constantValue) { - RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName); - if (macroOverridingConstant != null) { - TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles)); - if ( ! macroType.equals(constantValue.type())) - throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " + - typeMismatchExplanation(constantValue.type(), macroType)); - constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later - } - else { - Path constantPath = store.writeLargeConstant(constantName, constantValue); - if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) { - profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), - constantPath.toString())); - } - } - } - - private void transformGeneratedMacro(ModelStore store, - Set<String> constantsReplacedByMacros, - String macroName, RankingExpression expression) { - - expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); - store.writeMacro(macroName, expression); - } - - private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) { - if (profile.getMacros().containsKey(macroName)) { - throw new IllegalArgumentException("Generated TensorFlow macro '" + macroName + "' already exists."); - } - profile.addMacro(macroName, false); // todo: inline if only used once - RankProfile.Macro macro = profile.getMacros().get(macroName); - macro.setRankingExpression(expression); - macro.setTextualExpression(expression.getRoot().toString()); - } - - private String skippedOutputsDescription(TensorFlowModel.Signature signature) { - if (signature.skippedOutputs().isEmpty()) return ""; - StringBuilder b = new StringBuilder(": "); - signature.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v)); - return b.toString(); + RankProfile profile, + QueryProfileRegistry queryProfiles) { + ImportedModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), + k -> tensorFlowImporter.importModel(store.arguments().modelName(), + store.modelDir())); + return transformFromImportedModel(model, store, profile, queryProfiles); } - /** - * Verify that the macros referred in the given expression exists in the given rank profile, - * and return tensors of the types specified in requiredMacros. - */ - private void verifyRequiredMacros(RankingExpression expression, TensorFlowModel model, - RankProfile profile, QueryProfileRegistry queryProfiles) { - Set<String> macroNames = new HashSet<>(); - addMacroNamesIn(expression.getRoot(), macroNames, model); - for (String macroName : macroNames) { - TensorType requiredType = model.requiredMacros().get(macroName); - if (requiredType == null) continue; // Not a required macro - - RankProfile.Macro macro = profile.getMacros().get(macroName); - if (macro == null) - throw new IllegalArgumentException("Model refers placeholder '" + macroName + - "' of type " + requiredType + " but this macro is not present in " + - profile); - // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second - // phase and summary features), as it may only resolve correctly given those bindings - // Or, probably better, annotate the macros with type constraints here and verify during general - // type verification - TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles)); - if ( actualType == null) - throw new IllegalArgumentException("Model refers placeholder '" + macroName + - "' of type " + requiredType + - " which must be produced by a macro in the rank profile, but " + - "this macro references a feature which is not declared"); - if ( ! actualType.isAssignableTo(requiredType)) - throw new IllegalArgumentException("Model refers placeholder '" + macroName + "'. " + - typeMismatchExplanation(requiredType, actualType)); - } - } - - private String typeMismatchExplanation(TensorType requiredType, TensorType actualType) { - return "The required type of this is " + requiredType + ", but this macro returns " + actualType + - (actualType.rank() == 0 ? ". This is often due to missing declaration of query tensor features " + - "in query profile types - see the documentation." - : ""); - } - - /** - * Add the generated macros to the rank profile - */ - private void addGeneratedMacros(TensorFlowModel model, RankProfile profile) { - model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v)); - } - - /** - * Check if batch dimensions of inputs can be reduced out. If the input - * macro specifies that a single exemplar should be evaluated, we can - * reduce the batch dimension out. - */ - private void reduceBatchDimensions(RankingExpression expression, TensorFlowModel model, - RankProfile profile, QueryProfileRegistry queryProfiles) { - TypeContext<Reference> typeContext = profile.typeContext(queryProfiles); - TensorType typeBeforeReducing = expression.getRoot().type(typeContext); - - // Check generated macros for inputs to reduce - Set<String> macroNames = new HashSet<>(); - addMacroNamesIn(expression.getRoot(), macroNames, model); - for (String macroName : macroNames) { - if ( ! model.macros().containsKey(macroName)) { - continue; - } - RankProfile.Macro macro = profile.getMacros().get(macroName); - if (macro == null) { - throw new IllegalArgumentException("Model refers to generated macro '" + macroName + - "but this macro is not present in " + profile); - } - RankingExpression macroExpression = macro.getRankingExpression(); - macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext)); - } - - // Check expression for inputs to reduce - ExpressionNode root = expression.getRoot(); - root = reduceBatchDimensionsAtInput(root, model, typeContext); - TensorType typeAfterReducing = root.type(typeContext); - root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing); - expression.setRoot(root); - } - - private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, TensorFlowModel model, - TypeContext<Reference> typeContext) { - if (node instanceof TensorFunctionNode) { - TensorFunction tensorFunction = ((TensorFunctionNode) node).function(); - if (tensorFunction instanceof Rename) { - List<ExpressionNode> children = ((TensorFunctionNode)node).children(); - if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode) children.get(0); - if (model.requiredMacros().containsKey(referenceNode.getName())) { - return reduceBatchDimensionExpression(tensorFunction, typeContext); - } - } - } - } - if (node instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode) node; - if (model.requiredMacros().containsKey(referenceNode.getName())) { - return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext); - } - } - if (node instanceof CompositeNode) { - List<ExpressionNode> children = ((CompositeNode)node).children(); - List<ExpressionNode> transformedChildren = new ArrayList<>(children.size()); - for (ExpressionNode child : children) { - transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext)); - } - return ((CompositeNode)node).setChildren(transformedChildren); - } - return node; - } - - private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) { - TensorFunction result = function; - TensorType type = function.type(context); - if (type.dimensions().size() > 1) { - List<String> reduceDimensions = new ArrayList<>(); - for (TensorType.Dimension dimension : type.dimensions()) { - if (dimension.size().orElse(-1L) == 1) { - reduceDimensions.add(dimension.name()); - } - } - if (reduceDimensions.size() > 0) { - result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions); - } - } - return new TensorFunctionNode(result); - } - - /** - * If batch dimensions have been reduced away above, bring them back here - * for any following computation of the tensor. - * Todo: determine when this is not necessary! - */ - private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { - if (after.equals(before)) { - return node; - } - TensorType.Builder typeBuilder = new TensorType.Builder(); - for (TensorType.Dimension dimension : before.dimensions()) { - if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) { - typeBuilder.indexed(dimension.name(), 1); - } - } - TensorType expandDimensionsType = typeBuilder.build(); - if (expandDimensionsType.dimensions().size() > 0) { - ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0)); - Generate generatedFunction = new Generate(expandDimensionsType, - new GeneratorLambdaFunctionNode(expandDimensionsType, - generatedExpression) - .asLongListToDoubleOperator()); - Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply()); - return new TensorFunctionNode(expand); - } - return node; - } - - /** - * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions. - * This method does that for the given expression and returns the result. - */ - private RankingExpression replaceConstantsByMacros(RankingExpression expression, - Set<String> constantsReplacedByMacros) { - if (constantsReplacedByMacros.isEmpty()) return expression; - return new RankingExpression(expression.getName(), - replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros)); - } - - private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) { - if (node instanceof ReferenceNode) { - Reference reference = ((ReferenceNode)node).reference(); - if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) { - String argument = reference.simpleArgument().get(); - if (constantsReplacedByMacros.contains(argument)) - return new ReferenceNode(argument); - } - } - if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above - CompositeNode composite = (CompositeNode)node; - return composite.setChildren(composite.children().stream() - .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros)) - .collect(Collectors.toList())); - } - return node; - } - - private void addMacroNamesIn(ExpressionNode node, Set<String> names, TensorFlowModel model) { - if (node instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode)node; - if (referenceNode.getOutput() == null) { // macro references cannot specify outputs - names.add(referenceNode.getName()); - if (model.macros().containsKey(referenceNode.getName())) { - addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model); - } - } - } - else if (node instanceof CompositeNode) { - for (ExpressionNode child : ((CompositeNode)node).children()) - addMacroNamesIn(child, names, model); - } - } - - private Value asValue(Tensor tensor) { - if (tensor.type().rank() == 0) - return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors - else - return new TensorValue(tensor); - } - - /** - * Provides read/write access to the correct directories of the application package given by the feature arguments - */ - private static class ModelStore { - - private final ApplicationPackage application; - private final FeatureArguments arguments; - - public ModelStore(ApplicationPackage application, Arguments arguments) { - this.application = application; - this.arguments = new FeatureArguments(arguments); - } - - - - public FeatureArguments arguments() { return arguments; } - - public boolean hasStoredModel() { - try { - return application.getFile(arguments.expressionPath()).exists(); - } - catch (UnsupportedOperationException e) { - return false; - } - } - - /** - * Returns the directory which (if hasTensorFlowModels is true) - * contains the source model to use for these arguments - */ - public File tensorFlowModelDir() { - return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath())); - } - - /** - * Adds this expression to the application package, such that it can be read later. - */ - public void writeConverted(RankingExpression expression) { - application.getFile(arguments.expressionPath()) - .writeFile(new StringReader(expression.getRoot().toString())); - } - - /** Reads the previously stored ranking expression for these arguments */ - public RankingExpression readConverted() { - try { - return new RankingExpression(application.getFile(arguments.expressionPath()).createReader()); - } - catch (IOException e) { - throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e); - } - catch (ParseException e) { - throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); - } - } - - /** Adds this macro expression to the application package to it can be read later. */ - public void writeMacro(String name, RankingExpression expression) { - application.getFile(arguments.macrosPath()).appendFile(name + "\t" + - expression.getRoot().toString() + "\n"); - } - - /** Reads the previously stored macro expressions for these arguments */ - public List<Pair<String, RankingExpression>> readMacros() { - try { - ApplicationFile file = application.getFile(arguments.macrosPath()); - if (!file.exists()) return Collections.emptyList(); - - List<Pair<String, RankingExpression>> macros = new ArrayList<>(); - BufferedReader reader = new BufferedReader(file.createReader()); - String line; - while (null != (line = reader.readLine())) { - String[] parts = line.split("\t"); - String name = parts[0]; - try { - RankingExpression expression = new RankingExpression(parts[1]); - macros.add(new Pair<>(name, expression)); - } - catch (ParseException e) { - throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); - } - } - return macros; - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - /** - * Reads the information about all the large (aka ranking) constants stored in the application package - * (the constant value itself is replicated with file distribution). - */ - public List<RankingConstant> readLargeConstants() { - try { - List<RankingConstant> constants = new ArrayList<>(); - for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) { - String[] parts = IOUtils.readAll(constantFile.createReader()).split(":"); - constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2])); - } - return constants; - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - /** - * Adds this constant to the application package as a file, - * such that it can be distributed using file distribution. - * - * @return the path to the stored constant, relative to the application package root - */ - public Path writeLargeConstant(String name, Tensor constant) { - Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants"); - - // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file: - Path constantPath = constantsPath.append(name + ".tbf"); - - // Remember the constant in a file we replicate in ZooKeeper - application.getFile(arguments.largeConstantsPath().append(name + ".constant")) - .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath))); - - // Write content explicitly as a file on the file system as this is distributed using file distribution - createIfNeeded(constantsPath); - IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant)); - return correct(constantPath); - } - - private List<Pair<String, Tensor>> readSmallConstants() { - try { - ApplicationFile file = application.getFile(arguments.smallConstantsPath()); - if (!file.exists()) return Collections.emptyList(); - - List<Pair<String, Tensor>> constants = new ArrayList<>(); - BufferedReader reader = new BufferedReader(file.createReader()); - String line; - while (null != (line = reader.readLine())) { - String[] parts = line.split("\t"); - String name = parts[0]; - TensorType type = TensorType.fromSpec(parts[1]); - Tensor tensor = Tensor.from(type, parts[2]); - constants.add(new Pair<>(name, tensor)); - } - return constants; - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - /** - * Append this constant to the single file used for small constants distributed as config - */ - public void writeSmallConstant(String name, Tensor constant) { - // Secret file format for remembering constants: - application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" + - constant.type().toString() + "\t" + - constant.toString() + "\n"); - } - - /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */ - private Path correct(Path path) { - if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed) - && ! path.elements().contains(FilesApplicationPackage.preprocessed)) { - return Path.fromString(FilesApplicationPackage.preprocessed).append(path); - } - else { - return path; - } - } - - private void createIfNeeded(Path path) { - File dir = application.getFileReference(path); - if ( ! dir.exists()) { - if (!dir.mkdirs()) - throw new IllegalStateException("Could not create " + dir); - } - } - - } - - /** Encapsulates the 1, 2 or 3 arguments to a tensorflow feature */ - private static class FeatureArguments { - - private final Path modelPath; - - /** Optional arguments */ - private final Optional<String> signature, output; - - public FeatureArguments(Arguments arguments) { + static class TensorFlowFeatureArguments extends FeatureArguments { + public TensorFlowFeatureArguments(Arguments arguments) { if (arguments.isEmpty()) throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " + - "the tensorflow model directory under [application]/models"); + "the tensorflow model directory under [application]/models"); if (arguments.expressions().size() > 3) throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments"); @@ -661,68 +76,6 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil signature = optionalArgument(1, arguments); output = optionalArgument(2, arguments); } - - /** Returns modelPath with slashes replaced by underscores */ - public String modelName() { return modelPath.toString().replace('/', '_'); } - - /** Returns relative path to this model below the "models/" dir in the application package */ - public Path modelPath() { return modelPath; } - public Optional<String> signature() { return signature; } - public Optional<String> output() { return output; } - - /** Path to the small constants file */ - public Path smallConstantsPath() { - return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt"); - } - - /** Path to the large (ranking) constants directory */ - public Path largeConstantsPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants"); - } - - /** Path to the macros file */ - public Path macrosPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt"); - } - - public Path expressionPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR - .append(modelPath).append("expressions").append(expressionFileName()); - } - - private String expressionFileName() { - StringBuilder fileName = new StringBuilder(); - signature.ifPresent(s -> fileName.append(s).append(".")); - output.ifPresent(s -> fileName.append(s).append(".")); - if (fileName.length() == 0) // single signature and output - fileName.append("single."); - fileName.append("expression"); - return fileName.toString(); - } - - private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { - if (argumentIndex >= arguments.expressions().size()) - return Optional.empty(); - return Optional.of(asString(arguments.expressions().get(argumentIndex))); - } - - private String asString(ExpressionNode node) { - if ( ! (node instanceof ConstantNode)) - throw new IllegalArgumentException("Expected a constant string as tensorflow argument, but got '" + node); - return stripQuotes(((ConstantNode)node).sourceString()); - } - - private String stripQuotes(String s) { - if ( ! isQuoteSign(s.codePointAt(0))) return s; - if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) - throw new IllegalArgumentException("tensorflow argument [" + s + "] is missing endquote"); - return s.substring(1, s.length()-1); - } - - private boolean isQuoteSign(int c) { - return c == '\'' || c == '"'; - } - } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index 1c54d12d8b3..d9beab6e2f2 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -37,15 +37,6 @@ public class RankingExpressionWithOnnxTestCase { } @Test - public void testOnnxReference() throws ParseException { - RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", - "onnx('mnist_softmax.onnx')"); - search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L)); - assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L)); - } - - @Test public void testOnnxReferenceWithConstantFeature() { RankProfileSearchFixture search = fixtureWith("constant(mytensor)", "onnx('mnist_softmax.onnx')", @@ -122,13 +113,6 @@ public class RankingExpressionWithOnnxTestCase { } @Test - public void testOnnxReferenceSpecifyingOutput() { - RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", - "onnx('mnist_softmax.onnx', 'add')"); - search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - } - - @Test public void testOnnxReferenceMissingMacro() throws ParseException { try { RankProfileSearchFixture search = new RankProfileSearchFixture( @@ -145,7 +129,7 @@ public class RankingExpressionWithOnnxTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + "onnx('mnist_softmax.onnx'): " + - "Model refers Placeholder 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " + + "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " + "not present in rank profile 'my_profile'", Exceptions.toMessageString(expected)); } @@ -163,8 +147,8 @@ public class RankingExpressionWithOnnxTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + "onnx('mnist_softmax.onnx'): " + - "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) which must be produced " + - "by a macro in the rank profile, but this macro produces type tensor(d0[2],d5[10])", + "Model refers input 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " + + "but this macro returns tensor(d0[2],d5[10])", Exceptions.toMessageString(expected)); } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index d288a396732..7228af2b0de 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -162,7 +162,7 @@ public class RankingExpressionWithTensorFlowTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + "tensorflow('mnist_softmax/saved'): " + - "Model refers placeholder 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " + + "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " + "not present in rank profile 'my_profile'", Exceptions.toMessageString(expected)); } @@ -179,7 +179,7 @@ public class RankingExpressionWithTensorFlowTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + "tensorflow('mnist_softmax/saved'): " + - "Model refers placeholder 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " + + "Model refers input 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " + "but this macro returns tensor(d0[2],d5[10])", Exceptions.toMessageString(expected)); } @@ -305,9 +305,9 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testMacroGeneration() { - final String expression = "join(join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; + final String expression = "join(join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; final String macroExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))"; - final String macroExpression2 = "join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))"; + final String macroExpression2 = "join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))"; RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", "tensorflow('mnist/saved')", @@ -316,15 +316,15 @@ public class RankingExpressionWithTensorFlowTestCase { "input", new StoringApplicationPackage(applicationDir)); search.assertFirstPhaseExpression(expression, "my_profile"); - search.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile"); - search.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile"); + search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile"); + search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile"); } @Test public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException { - final String expression = "join(join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; + final String expression = "join(join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; final String macroExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))"; - final String macroExpression2 = "join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))"; + final String macroExpression2 = "join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))"; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", @@ -335,8 +335,8 @@ public class RankingExpressionWithTensorFlowTestCase { application); search.assertFirstPhaseExpression(expression, "my_profile"); assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search); - search.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile"); - search.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile"); + search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile"); + search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile"); // At this point the expression is stored - copy application to another location which do not have a models dir Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); @@ -353,8 +353,8 @@ public class RankingExpressionWithTensorFlowTestCase { storedApplication); searchFromStored.assertFirstPhaseExpression(expression, "my_profile"); assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search); - searchFromStored.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile"); - searchFromStored.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile"); + searchFromStored.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile"); + searchFromStored.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile"); } finally { IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); @@ -465,7 +465,7 @@ public class RankingExpressionWithTensorFlowTestCase { } - public static class StoringApplicationPackageFile extends ApplicationFile { + static class StoringApplicationPackageFile extends ApplicationFile { /** The path to the application package root */ private final Path root; |