diff options
author | Håkon Hallingstad <hakon@oath.com> | 2018-06-06 14:25:23 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-06-06 14:25:23 +0200 |
commit | 62ae46a58d9501ad60431634f374b3cfa2856a48 (patch) | |
tree | 12ea280192a44f26b9718018c7cfb39b0c4c4735 | |
parent | 240176d60c44507f4e6733c7512620e80554c8de (diff) | |
parent | 681963959794b47102d1a1cf72f215c72b0e2b51 (diff) |
Merge pull request #6106 from vespa-engine/revert-6046-lesters/refactor-onnx-tensorflow-import
Revert "Refactor ONNX and TF import to use same code base"
64 files changed, 3726 insertions, 2365 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java deleted file mode 100644 index effa261be3b..00000000000 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java +++ /dev/null @@ -1,674 +0,0 @@ -package com.yahoo.searchdefinition.expressiontransforms; - -import com.google.common.base.Joiner; -import com.yahoo.collections.Pair; -import com.yahoo.config.application.api.ApplicationFile; -import com.yahoo.config.application.api.ApplicationPackage; -import com.yahoo.config.model.application.provider.FilesApplicationPackage; -import com.yahoo.io.IOUtils; -import com.yahoo.path.Path; -import com.yahoo.search.query.profile.QueryProfileRegistry; -import com.yahoo.searchdefinition.FeatureNames; -import com.yahoo.searchdefinition.RankProfile; -import com.yahoo.searchdefinition.RankingConstant; -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.Reference; -import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; -import com.yahoo.searchlib.rankingexpression.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.rule.Arguments; -import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; -import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; -import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; -import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; -import 
com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.TypeContext; -import com.yahoo.tensor.functions.Generate; -import com.yahoo.tensor.functions.Join; -import com.yahoo.tensor.functions.Reduce; -import com.yahoo.tensor.functions.Rename; -import com.yahoo.tensor.functions.ScalarFunctions; -import com.yahoo.tensor.functions.TensorFunction; -import com.yahoo.tensor.serialization.TypedBinaryFormat; - -import java.io.BufferedReader; -import java.io.File; -import java.io.IOException; -import java.io.StringReader; -import java.io.UncheckedIOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Optional; -import java.util.Set; -import java.util.stream.Collectors; - -/** - * Base class for replacing instances of a pseudofeature for imported ML - * ranking models with native Vespa ranking expressions. 
- * - * @author bratseth - * @author lesters - */ -abstract class MLImportFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { - - ExpressionNode transformFromImportedModel(ImportedModel model, - ModelStore store, - RankProfile profile, - QueryProfileRegistry queryProfiles) { - // Add constants - Set<String> constantsReplacedByMacros = new HashSet<>(); - model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); - model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles, - constantsReplacedByMacros, k, v)); - - // Find the specified expression - ImportedModel.Signature signature = chooseSignature(model, store.arguments().signature()); - String output = chooseOutput(signature, store.arguments().output()); - if (signature.skippedOutputs().containsKey(output)) { - String message = "Could not import model output '" + output + "'"; - if (!signature.skippedOutputs().get(output).isEmpty()) { - message += ": " + signature.skippedOutputs().get(output); - } - if (!signature.importWarnings().isEmpty()) { - message += ": " + String.join(", ", signature.importWarnings()); - } - throw new IllegalArgumentException(message); - } - - RankingExpression expression = model.expressions().get(output); - expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); - verifyRequiredMacros(expression, model, profile, queryProfiles); - addGeneratedMacros(model, profile); - reduceBatchDimensions(expression, model, profile, queryProfiles); - - model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v)); - - store.writeConverted(expression); - return expression.getRoot(); - } - - ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) { - for (Pair<String, Tensor> constant : store.readSmallConstants()) - profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); - - for (RankingConstant constant : 
store.readLargeConstants()) { - if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName())) - profile.getSearch().addRankingConstant(constant); - } - - for (Pair<String, RankingExpression> macro : store.readMacros()) { - addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond()); - } - - return store.readConverted().getRoot(); - } - - /** - * Returns the specified, existing signature, or the only signature if none is specified. - * Throws IllegalArgumentException in all other cases. - */ - private ImportedModel.Signature chooseSignature(ImportedModel importResult, Optional<String> signatureName) { - if ( ! signatureName.isPresent()) { - if (importResult.signatures().size() == 0) - throw new IllegalArgumentException("No signatures are available"); - if (importResult.signatures().size() > 1) - throw new IllegalArgumentException("Model has multiple signatures (" + - Joiner.on(", ").join(importResult.signatures().keySet()) + - "), one must be specified " + - "as a second argument to tensorflow()"); - return importResult.signatures().values().stream().findFirst().get(); - } - else { - ImportedModel.Signature signature = importResult.signatures().get(signatureName.get()); - if (signature == null) - throw new IllegalArgumentException("Model does not have the specified signature '" + - signatureName.get() + "'"); - return signature; - } - } - - /** - * Returns the specified, existing output expression, or the only output expression if no output name is specified. - * Throws IllegalArgumentException in all other cases. - */ - private String chooseOutput(ImportedModel.Signature signature, Optional<String> outputName) { - if ( ! 
outputName.isPresent()) { - if (signature.outputs().size() == 0) - throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(signature)); - if (signature.outputs().size() > 1) - throw new IllegalArgumentException(signature + " has multiple outputs (" + - Joiner.on(", ").join(signature.outputs().keySet()) + - "), one must be specified " + - "as a third argument to tensorflow()"); - return signature.outputs().get(signature.outputs().keySet().stream().findFirst().get()); - } - else { - String output = signature.outputs().get(outputName.get()); - if (output == null) { - if (signature.skippedOutputs().containsKey(outputName.get())) - throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " + - signature.skippedOutputs().get(outputName.get())); - else - throw new IllegalArgumentException("Model does not have the specified output '" + - outputName.get() + "'"); - } - return output; - } - } - - private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { - store.writeSmallConstant(constantName, constantValue); - profile.addConstant(constantName, asValue(constantValue)); - } - - private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, - Set<String> constantsReplacedByMacros, - String constantName, Tensor constantValue) { - RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName); - if (macroOverridingConstant != null) { - TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles)); - if ( ! macroType.equals(constantValue.type())) - throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. 
" + - typeMismatchExplanation(constantValue.type(), macroType)); - constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later - } - else { - Path constantPath = store.writeLargeConstant(constantName, constantValue); - if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) { - profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), - constantPath.toString())); - } - } - } - - private void transformGeneratedMacro(ModelStore store, - Set<String> constantsReplacedByMacros, - String macroName, RankingExpression expression) { - - expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); - store.writeMacro(macroName, expression); - } - - private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) { - if (profile.getMacros().containsKey(macroName)) { - throw new IllegalArgumentException("Generated macro '" + macroName + "' already exists."); - } - profile.addMacro(macroName, false); // todo: inline if only used once - RankProfile.Macro macro = profile.getMacros().get(macroName); - macro.setRankingExpression(expression); - macro.setTextualExpression(expression.getRoot().toString()); - } - - private String skippedOutputsDescription(ImportedModel.Signature signature) { - if (signature.skippedOutputs().isEmpty()) return ""; - StringBuilder b = new StringBuilder(": "); - signature.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v)); - return b.toString(); - } - - /** - * Verify that the macros referred in the given expression exists in the given rank profile, - * and return tensors of the types specified in requiredMacros. 
- */ - private void verifyRequiredMacros(RankingExpression expression, ImportedModel model, - RankProfile profile, QueryProfileRegistry queryProfiles) { - Set<String> macroNames = new HashSet<>(); - addMacroNamesIn(expression.getRoot(), macroNames, model); - for (String macroName : macroNames) { - TensorType requiredType = model.requiredMacros().get(macroName); - if (requiredType == null) continue; // Not a required macro - - RankProfile.Macro macro = profile.getMacros().get(macroName); - if (macro == null) - throw new IllegalArgumentException("Model refers input '" + macroName + - "' of type " + requiredType + " but this macro is not present in " + - profile); - // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second - // phase and summary features), as it may only resolve correctly given those bindings - // Or, probably better, annotate the macros with type constraints here and verify during general - // type verification - TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles)); - if ( actualType == null) - throw new IllegalArgumentException("Model refers input '" + macroName + - "' of type " + requiredType + - " which must be produced by a macro in the rank profile, but " + - "this macro references a feature which is not declared"); - if ( ! actualType.isAssignableTo(requiredType)) - throw new IllegalArgumentException("Model refers input '" + macroName + "'. " + - typeMismatchExplanation(requiredType, actualType)); - } - } - - private String typeMismatchExplanation(TensorType requiredType, TensorType actualType) { - return "The required type of this is " + requiredType + ", but this macro returns " + actualType + - (actualType.rank() == 0 ? ". This is often due to missing declaration of query tensor features " + - "in query profile types - see the documentation." 
- : ""); - } - - /** - * Add the generated macros to the rank profile - */ - private void addGeneratedMacros(ImportedModel model, RankProfile profile) { - model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v)); - } - - /** - * Check if batch dimensions of inputs can be reduced out. If the input - * macro specifies that a single exemplar should be evaluated, we can - * reduce the batch dimension out. - */ - private void reduceBatchDimensions(RankingExpression expression, ImportedModel model, - RankProfile profile, QueryProfileRegistry queryProfiles) { - TypeContext<Reference> typeContext = profile.typeContext(queryProfiles); - TensorType typeBeforeReducing = expression.getRoot().type(typeContext); - - // Check generated macros for inputs to reduce - Set<String> macroNames = new HashSet<>(); - addMacroNamesIn(expression.getRoot(), macroNames, model); - for (String macroName : macroNames) { - if ( ! model.macros().containsKey(macroName)) { - continue; - } - RankProfile.Macro macro = profile.getMacros().get(macroName); - if (macro == null) { - throw new IllegalArgumentException("Model refers to generated macro '" + macroName + - "but this macro is not present in " + profile); - } - RankingExpression macroExpression = macro.getRankingExpression(); - macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext)); - } - - // Check expression for inputs to reduce - ExpressionNode root = expression.getRoot(); - root = reduceBatchDimensionsAtInput(root, model, typeContext); - TensorType typeAfterReducing = root.type(typeContext); - root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing); - expression.setRoot(root); - } - - private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedModel model, - TypeContext<Reference> typeContext) { - if (node instanceof TensorFunctionNode) { - TensorFunction tensorFunction = ((TensorFunctionNode) node).function(); - if (tensorFunction 
instanceof Rename) { - List<ExpressionNode> children = ((TensorFunctionNode)node).children(); - if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode) children.get(0); - if (model.requiredMacros().containsKey(referenceNode.getName())) { - return reduceBatchDimensionExpression(tensorFunction, typeContext); - } - } - } - } - if (node instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode) node; - if (model.requiredMacros().containsKey(referenceNode.getName())) { - return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext); - } - } - if (node instanceof CompositeNode) { - List<ExpressionNode> children = ((CompositeNode)node).children(); - List<ExpressionNode> transformedChildren = new ArrayList<>(children.size()); - for (ExpressionNode child : children) { - transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext)); - } - return ((CompositeNode)node).setChildren(transformedChildren); - } - return node; - } - - private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) { - TensorFunction result = function; - TensorType type = function.type(context); - if (type.dimensions().size() > 1) { - List<String> reduceDimensions = new ArrayList<>(); - for (TensorType.Dimension dimension : type.dimensions()) { - if (dimension.size().orElse(-1L) == 1) { - reduceDimensions.add(dimension.name()); - } - } - if (reduceDimensions.size() > 0) { - result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions); - } - } - return new TensorFunctionNode(result); - } - - /** - * If batch dimensions have been reduced away above, bring them back here - * for any following computation of the tensor. - * Todo: determine when this is not necessary! 
- */ - private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { - if (after.equals(before)) { - return node; - } - TensorType.Builder typeBuilder = new TensorType.Builder(); - for (TensorType.Dimension dimension : before.dimensions()) { - if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) { - typeBuilder.indexed(dimension.name(), 1); - } - } - TensorType expandDimensionsType = typeBuilder.build(); - if (expandDimensionsType.dimensions().size() > 0) { - ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0)); - Generate generatedFunction = new Generate(expandDimensionsType, - new GeneratorLambdaFunctionNode(expandDimensionsType, - generatedExpression) - .asLongListToDoubleOperator()); - Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply()); - return new TensorFunctionNode(expand); - } - return node; - } - - /** - * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions. - * This method does that for the given expression and returns the result. 
- */ - private RankingExpression replaceConstantsByMacros(RankingExpression expression, - Set<String> constantsReplacedByMacros) { - if (constantsReplacedByMacros.isEmpty()) return expression; - return new RankingExpression(expression.getName(), - replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros)); - } - - private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) { - if (node instanceof ReferenceNode) { - Reference reference = ((ReferenceNode)node).reference(); - if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) { - String argument = reference.simpleArgument().get(); - if (constantsReplacedByMacros.contains(argument)) - return new ReferenceNode(argument); - } - } - if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above - CompositeNode composite = (CompositeNode)node; - return composite.setChildren(composite.children().stream() - .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros)) - .collect(Collectors.toList())); - } - return node; - } - - private void addMacroNamesIn(ExpressionNode node, Set<String> names, ImportedModel model) { - if (node instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode)node; - if (referenceNode.getOutput() == null) { // macro references cannot specify outputs - names.add(referenceNode.getName()); - if (model.macros().containsKey(referenceNode.getName())) { - addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model); - } - } - } - else if (node instanceof CompositeNode) { - for (ExpressionNode child : ((CompositeNode)node).children()) - addMacroNamesIn(child, names, model); - } - } - - private Value asValue(Tensor tensor) { - if (tensor.type().rank() == 0) - return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors - else - return new TensorValue(tensor); - } - - /** - * 
Provides read/write access to the correct directories of the application package given by the feature arguments - */ - static class ModelStore { - - private final ApplicationPackage application; - private final FeatureArguments arguments; - - ModelStore(ApplicationPackage application, FeatureArguments arguments) { - this.application = application; - this.arguments = arguments; - } - - public FeatureArguments arguments() { return arguments; } - - public boolean hasStoredModel() { - try { - return application.getFile(arguments.expressionPath()).exists(); - } - catch (UnsupportedOperationException e) { - return false; - } - } - - /** - * Returns the directory which contains the source model to use for these arguments - */ - public File modelDir() { - return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath())); - } - - /** - * Adds this expression to the application package, such that it can be read later. - */ - void writeConverted(RankingExpression expression) { - application.getFile(arguments.expressionPath()) - .writeFile(new StringReader(expression.getRoot().toString())); - } - - /** Reads the previously stored ranking expression for these arguments */ - RankingExpression readConverted() { - try { - return new RankingExpression(application.getFile(arguments.expressionPath()).createReader()); - } - catch (IOException e) { - throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e); - } - catch (ParseException e) { - throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); - } - } - - /** Adds this macro expression to the application package to it can be read later. 
*/ - void writeMacro(String name, RankingExpression expression) { - application.getFile(arguments.macrosPath()).appendFile(name + "\t" + - expression.getRoot().toString() + "\n"); - } - - /** Reads the previously stored macro expressions for these arguments */ - List<Pair<String, RankingExpression>> readMacros() { - try { - ApplicationFile file = application.getFile(arguments.macrosPath()); - if (!file.exists()) return Collections.emptyList(); - - List<Pair<String, RankingExpression>> macros = new ArrayList<>(); - BufferedReader reader = new BufferedReader(file.createReader()); - String line; - while (null != (line = reader.readLine())) { - String[] parts = line.split("\t"); - String name = parts[0]; - try { - RankingExpression expression = new RankingExpression(parts[1]); - macros.add(new Pair<>(name, expression)); - } - catch (ParseException e) { - throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); - } - } - return macros; - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - /** - * Reads the information about all the large (aka ranking) constants stored in the application package - * (the constant value itself is replicated with file distribution). - */ - List<RankingConstant> readLargeConstants() { - try { - List<RankingConstant> constants = new ArrayList<>(); - for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) { - String[] parts = IOUtils.readAll(constantFile.createReader()).split(":"); - constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2])); - } - return constants; - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - /** - * Adds this constant to the application package as a file, - * such that it can be distributed using file distribution. 
- * - * @return the path to the stored constant, relative to the application package root - */ - Path writeLargeConstant(String name, Tensor constant) { - Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants"); - - // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file: - Path constantPath = constantsPath.append(name + ".tbf"); - - // Remember the constant in a file we replicate in ZooKeeper - application.getFile(arguments.largeConstantsPath().append(name + ".constant")) - .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath))); - - // Write content explicitly as a file on the file system as this is distributed using file distribution - createIfNeeded(constantsPath); - IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant)); - return correct(constantPath); - } - - private List<Pair<String, Tensor>> readSmallConstants() { - try { - ApplicationFile file = application.getFile(arguments.smallConstantsPath()); - if (!file.exists()) return Collections.emptyList(); - - List<Pair<String, Tensor>> constants = new ArrayList<>(); - BufferedReader reader = new BufferedReader(file.createReader()); - String line; - while (null != (line = reader.readLine())) { - String[] parts = line.split("\t"); - String name = parts[0]; - TensorType type = TensorType.fromSpec(parts[1]); - Tensor tensor = Tensor.from(type, parts[2]); - constants.add(new Pair<>(name, tensor)); - } - return constants; - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - /** - * Append this constant to the single file used for small constants distributed as config - */ - public void writeSmallConstant(String name, Tensor constant) { - // Secret file format for remembering constants: - application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" + - constant.type().toString() + "\t" + - constant.toString() + "\n"); - } 
- - /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */ - private Path correct(Path path) { - if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed) - && ! path.elements().contains(FilesApplicationPackage.preprocessed)) { - return Path.fromString(FilesApplicationPackage.preprocessed).append(path); - } - else { - return path; - } - } - - private void createIfNeeded(Path path) { - File dir = application.getFileReference(path); - if ( ! dir.exists()) { - if (!dir.mkdirs()) - throw new IllegalStateException("Could not create " + dir); - } - } - - } - - /** Encapsulates the arguments to the import feature */ - static abstract class FeatureArguments { - - Path modelPath; - - /** Optional arguments */ - Optional<String> signature, output; - - /** Returns modelPath with slashes replaced by underscores */ - public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); } - - /** Returns relative path to this model below the "models/" dir in the application package */ - public Path modelPath() { return modelPath; } - public Optional<String> signature() { return signature; } - public Optional<String> output() { return output; } - - /** Path to the small constants file */ - public Path smallConstantsPath() { - return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt"); - } - - /** Path to the large (ranking) constants directory */ - public Path largeConstantsPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants"); - } - - /** Path to the macros file */ - public Path macrosPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt"); - } - - public Path expressionPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR - .append(modelPath).append("expressions").append(expressionFileName()); - 
} - - private String expressionFileName() { - StringBuilder fileName = new StringBuilder(); - signature.ifPresent(s -> fileName.append(s).append(".")); - output.ifPresent(s -> fileName.append(s).append(".")); - if (fileName.length() == 0) // single signature and output - fileName.append("single."); - fileName.append("expression"); - return fileName.toString(); - } - - Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { - if (argumentIndex >= arguments.expressions().size()) - return Optional.empty(); - return Optional.of(asString(arguments.expressions().get(argumentIndex))); - } - - String asString(ExpressionNode node) { - if ( ! (node instanceof ConstantNode)) - throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); - return stripQuotes(((ConstantNode)node).sourceString()); - } - - private String stripQuotes(String s) { - if ( ! isQuoteSign(s.codePointAt(0))) return s; - if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) - throw new IllegalArgumentException("argument [" + s + "] is missing endquote"); - return s.substring(1, s.length()-1); - } - - private boolean isQuoteSign(int c) { - return c == '\'' || c == '"'; - } - - } -} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index 44eeb364603..1c41ad8284e 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -2,20 +2,58 @@ package com.yahoo.searchdefinition.expressiontransforms; +import com.google.common.base.Joiner; +import com.yahoo.collections.Pair; +import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.application.api.ApplicationPackage; +import 
com.yahoo.config.model.application.provider.FilesApplicationPackage; +import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.searchdefinition.FeatureNames; import com.yahoo.searchdefinition.RankProfile; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; -import com.yahoo.searchlib.rankingexpression.integration.ml.OnnxImporter; +import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxImporter; +import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxModel; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.Join; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; +import 
com.yahoo.tensor.serialization.TypedBinaryFormat; +import java.io.BufferedReader; +import java.io.File; +import java.io.IOException; +import java.io.StringReader; import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; /** * Replaces instances of the onnx(model-path, output) @@ -25,12 +63,12 @@ import java.util.Optional; * @author bratseth * @author lesters */ -public class OnnxFeatureConverter extends MLImportFeatureConverter { +public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { private final OnnxImporter onnxImporter = new OnnxImporter(); /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ - private final Map<Path, ImportedModel> importedModels = new HashMap<>(); + private final Map<Path, OnnxModel> importedModels = new HashMap<>(); @Override public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { @@ -46,8 +84,7 @@ public class OnnxFeatureConverter extends MLImportFeatureConverter { if ( ! feature.getName().equals("onnx")) return feature; try { - FeatureArguments arguments = new OnnxFeatureArguments(feature.getArguments()); - ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments); + ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments()); if ( ! 
store.hasStoredModel()) // not converted yet - access Onnx model files return transformFromOnnxModel(store, context.rankProfile(), context.queryProfiles()); else @@ -61,24 +98,597 @@ public class OnnxFeatureConverter extends MLImportFeatureConverter { private ExpressionNode transformFromOnnxModel(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles) { - ImportedModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), + OnnxModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), k -> onnxImporter.importModel(store.arguments().modelName(), - store.modelDir())); - return transformFromImportedModel(model, store, profile, queryProfiles); + store.onnxModelDir())); + + // Add constants + Set<String> constantsReplacedByMacros = new HashSet<>(); + model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); + model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles, + constantsReplacedByMacros, k, v)); + + // Find the specified expression + String output = chooseOutput(model, store.arguments().output()); + if (model.skippedOutputs().containsKey(output)) { + String message = "Could not import Onnx model output '" + output + "'"; + if (!model.skippedOutputs().get(output).isEmpty()) { + message += ": " + model.skippedOutputs().get(output); + } + if (!model.importWarnings().isEmpty()) { + message += ": " + String.join(", ", model.importWarnings()); + } + throw new IllegalArgumentException(message); + } + + RankingExpression expression = model.expressions().get(output); + expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); + verifyRequiredMacros(expression, model, profile, queryProfiles); + addGeneratedMacros(model, profile); + reduceBatchDimensions(expression, model, profile, queryProfiles); + + model.macros().forEach((k, v) -> transformGeneratedMacro(store, profile, constantsReplacedByMacros, k, v)); + + 
store.writeConverted(expression); + return expression.getRoot(); + } + + private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) { + for (Pair<String, Tensor> constant : store.readSmallConstants()) + profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); + + for (RankingConstant constant : store.readLargeConstants()) { + if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName())) + profile.getSearch().addRankingConstant(constant); + } + + for (Pair<String, RankingExpression> macro : store.readMacros()) { + addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond()); + } + + return store.readConverted().getRoot(); + } + + /** + * Returns the specified, existing output expression, or the only output expression if no output name is specified. + * Throws IllegalArgumentException in all other cases. + */ + private String chooseOutput(OnnxModel model, Optional<String> outputName) { + if ( ! outputName.isPresent()) { + if (model.outputs().size() == 0) + throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(model)); + if (model.outputs().size() > 1) + throw new IllegalArgumentException("Onnx model has multiple outputs (" + + Joiner.on(", ").join(model.outputs().keySet()) + + "), one must be specified " + + "as a second argument to onnx()"); + return model.outputs().get(model.outputs().keySet().stream().findFirst().get()); + } + else { + String output = model.outputs().get(outputName.get()); + if (output == null) { + if (model.skippedOutputs().containsKey(outputName.get())) + throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " + + model.skippedOutputs().get(outputName.get())); + else + throw new IllegalArgumentException("Model does not have the specified output '" + + outputName.get() + "'"); + } + return output; + } } - static class OnnxFeatureArguments extends FeatureArguments { - public OnnxFeatureArguments(Arguments 
arguments) { + private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { + store.writeSmallConstant(constantName, constantValue); + profile.addConstant(constantName, asValue(constantValue)); + } + + private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, + Set<String> constantsReplacedByMacros, + String constantName, Tensor constantValue) { + RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName); + if (macroOverridingConstant != null) { + TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles)); + if ( ! macroType.equals(constantValue.type())) + throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " + + "The required type of this is " + constantValue.type() + + ", but the macro returns " + macroType); + constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later + } + else { + Path constantPath = store.writeLargeConstant(constantName, constantValue); + if ( ! 
profile.getSearch().getRankingConstants().containsKey(constantName)) { + profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), + constantPath.toString())); + } + } + } + + private void transformGeneratedMacro(ModelStore store, RankProfile profile, + Set<String> constantsReplacedByMacros, + String macroName, RankingExpression expression) { + + expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); + store.writeMacro(macroName, expression); + } + + private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) { + if (profile.getMacros().containsKey(macroName)) { + throw new IllegalArgumentException("Generated Onnx macro '" + macroName + "' already exists."); + } + profile.addMacro(macroName, false); // todo: inline if only used once + RankProfile.Macro macro = profile.getMacros().get(macroName); + macro.setRankingExpression(expression); + macro.setTextualExpression(expression.getRoot().toString()); + } + + private String skippedOutputsDescription(OnnxModel model) { + if (model.skippedOutputs().isEmpty()) return ""; + StringBuilder b = new StringBuilder(": "); + model.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v)); + return b.toString(); + } + + /** + * Verify that the macros referred in the given expression exists in the given rank profile, + * and return tensors of the types specified in requiredMacros. 
+ */ + private void verifyRequiredMacros(RankingExpression expression, OnnxModel model, + RankProfile profile, QueryProfileRegistry queryProfiles) { + Set<String> macroNames = new HashSet<>(); + addMacroNamesIn(expression.getRoot(), macroNames, model); + for (String macroName : macroNames) { + TensorType requiredType = model.requiredMacros().get(macroName); + if (requiredType == null) continue; // Not a required macro + + RankProfile.Macro macro = profile.getMacros().get(macroName); + if (macro == null) + throw new IllegalArgumentException("Model refers Placeholder '" + macroName + + "' of type " + requiredType + " but this macro is not present in " + + profile); + // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second + // phase and summary features), as it may only resolve correctly given those bindings + // Or, probably better, annotate the macros with type constraints here and verify during general + // type verification + TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles)); + if ( actualType == null) + throw new IllegalArgumentException("Model refers input '" + macroName + + "' of type " + requiredType + + " which must be produced by a macro in the rank profile, but " + + "this macro references a feature which is not declared"); + if ( ! actualType.isAssignableTo(requiredType)) + throw new IllegalArgumentException("Model refers input '" + macroName + + "' of type " + requiredType + + " which must be produced by a macro in the rank profile, but " + + "this macro produces type " + actualType); + } + } + + /** + * Add the generated macros to the rank profile + */ + private void addGeneratedMacros(OnnxModel model, RankProfile profile) { + model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v)); + } + + /** + * Check if batch dimensions of inputs can be reduced out. 
If the input + * macro specifies that a single exemplar should be evaluated, we can + * reduce the batch dimension out. + */ + private void reduceBatchDimensions(RankingExpression expression, OnnxModel model, + RankProfile profile, QueryProfileRegistry queryProfiles) { + TypeContext<Reference> typeContext = profile.typeContext(queryProfiles); + TensorType typeBeforeReducing = expression.getRoot().type(typeContext); + + // Check generated macros for inputs to reduce + Set<String> macroNames = new HashSet<>(); + addMacroNamesIn(expression.getRoot(), macroNames, model); + for (String macroName : macroNames) { + if ( ! model.macros().containsKey(macroName)) { + continue; + } + RankProfile.Macro macro = profile.getMacros().get(macroName); + if (macro == null) { + throw new IllegalArgumentException("Model refers to generated macro '" + macroName + + "but this macro is not present in " + profile); + } + RankingExpression macroExpression = macro.getRankingExpression(); + macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext)); + } + + // Check expression for inputs to reduce + ExpressionNode root = expression.getRoot(); + root = reduceBatchDimensionsAtInput(root, model, typeContext); + TensorType typeAfterReducing = root.type(typeContext); + root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing); + expression.setRoot(root); + } + + private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, OnnxModel model, + TypeContext<Reference> typeContext) { + if (node instanceof TensorFunctionNode) { + TensorFunction tensorFunction = ((TensorFunctionNode) node).function(); + if (tensorFunction instanceof Rename) { + List<ExpressionNode> children = ((TensorFunctionNode)node).children(); + if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode) children.get(0); + if (model.requiredMacros().containsKey(referenceNode.getName())) { + return 
reduceBatchDimensionExpression(tensorFunction, typeContext); + } + } + } + } + if (node instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode) node; + if (model.requiredMacros().containsKey(referenceNode.getName())) { + return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext); + } + } + if (node instanceof CompositeNode) { + List<ExpressionNode> children = ((CompositeNode)node).children(); + List<ExpressionNode> transformedChildren = new ArrayList<>(children.size()); + for (ExpressionNode child : children) { + transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext)); + } + return ((CompositeNode)node).setChildren(transformedChildren); + } + return node; + } + + private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) { + TensorFunction result = function; + TensorType type = function.type(context); + if (type.dimensions().size() > 1) { + List<String> reduceDimensions = new ArrayList<>(); + for (TensorType.Dimension dimension : type.dimensions()) { + if (dimension.size().orElse(-1L) == 1) { + reduceDimensions.add(dimension.name()); + } + } + if (reduceDimensions.size() > 0) { + result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions); + } + } + return new TensorFunctionNode(result); + } + + /** + * If batch dimensions have been reduced away above, bring them back here + * for any following computation of the tensor. + * Todo: determine when this is not necessary! 
+ */ + private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { + if (after.equals(before)) { + return node; + } + TensorType.Builder typeBuilder = new TensorType.Builder(); + for (TensorType.Dimension dimension : before.dimensions()) { + if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) { + typeBuilder.indexed(dimension.name(), 1); + } + } + TensorType expandDimensionsType = typeBuilder.build(); + if (expandDimensionsType.dimensions().size() > 0) { + ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0)); + Generate generatedFunction = new Generate(expandDimensionsType, + new GeneratorLambdaFunctionNode(expandDimensionsType, + generatedExpression) + .asLongListToDoubleOperator()); + Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply()); + return new TensorFunctionNode(expand); + } + return node; + } + + /** + * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions. + * This method does that for the given expression and returns the result. 
+ */ + private RankingExpression replaceConstantsByMacros(RankingExpression expression, + Set<String> constantsReplacedByMacros) { + if (constantsReplacedByMacros.isEmpty()) return expression; + return new RankingExpression(expression.getName(), + replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros)); + } + + private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) { + if (node instanceof ReferenceNode) { + Reference reference = ((ReferenceNode)node).reference(); + if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) { + String argument = reference.simpleArgument().get(); + if (constantsReplacedByMacros.contains(argument)) + return new ReferenceNode(argument); + } + } + if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above + CompositeNode composite = (CompositeNode)node; + return composite.setChildren(composite.children().stream() + .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros)) + .collect(Collectors.toList())); + } + return node; + } + + private void addMacroNamesIn(ExpressionNode node, Set<String> names, OnnxModel model) { + if (node instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode)node; + if (referenceNode.getOutput() == null) { // macro references cannot specify outputs + names.add(referenceNode.getName()); + if (model.macros().containsKey(referenceNode.getName())) { + addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model); + } + } + } + else if (node instanceof CompositeNode) { + for (ExpressionNode child : ((CompositeNode)node).children()) + addMacroNamesIn(child, names, model); + } + } + + private Value asValue(Tensor tensor) { + if (tensor.type().rank() == 0) + return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors + else + return new TensorValue(tensor); + } + + /** + * Provides 
read/write access to the correct directories of the application package given by the feature arguments + */ + private static class ModelStore { + + private final ApplicationPackage application; + private final FeatureArguments arguments; + + public ModelStore(ApplicationPackage application, Arguments arguments) { + this.application = application; + this.arguments = new FeatureArguments(arguments); + } + + public FeatureArguments arguments() { return arguments; } + + public boolean hasStoredModel() { + try { + return application.getFile(arguments.expressionPath()).exists(); + } + catch (UnsupportedOperationException e) { + return false; + } + } + + /** + * Returns the directory which contains the source model to use for these arguments + */ + public File onnxModelDir() { + return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath())); + } + + /** + * Adds this expression to the application package, such that it can be read later. + */ + public void writeConverted(RankingExpression expression) { + application.getFile(arguments.expressionPath()) + .writeFile(new StringReader(expression.getRoot().toString())); + } + + /** Reads the previously stored ranking expression for these arguments */ + public RankingExpression readConverted() { + try { + return new RankingExpression(application.getFile(arguments.expressionPath()).createReader()); + } + catch (IOException e) { + throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e); + } + catch (ParseException e) { + throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); + } + } + + /** Adds this macro expression to the application package to it can be read later. 
*/ + public void writeMacro(String name, RankingExpression expression) { + application.getFile(arguments.macrosPath()).appendFile(name + "\t" + + expression.getRoot().toString() + "\n"); + } + + /** Reads the previously stored macro expressions for these arguments */ + public List<Pair<String, RankingExpression>> readMacros() { + try { + ApplicationFile file = application.getFile(arguments.macrosPath()); + if (!file.exists()) return Collections.emptyList(); + + List<Pair<String, RankingExpression>> macros = new ArrayList<>(); + BufferedReader reader = new BufferedReader(file.createReader()); + String line; + while (null != (line = reader.readLine())) { + String[] parts = line.split("\t"); + String name = parts[0]; + try { + RankingExpression expression = new RankingExpression(parts[1]); + macros.add(new Pair<>(name, expression)); + } + catch (ParseException e) { + throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); + } + } + return macros; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + /** + * Reads the information about all the large (aka ranking) constants stored in the application package + * (the constant value itself is replicated with file distribution). + */ + public List<RankingConstant> readLargeConstants() { + try { + List<RankingConstant> constants = new ArrayList<>(); + for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) { + String[] parts = IOUtils.readAll(constantFile.createReader()).split(":"); + constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2])); + } + return constants; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + /** + * Adds this constant to the application package as a file, + * such that it can be distributed using file distribution. 
+ * + * @return the path to the stored constant, relative to the application package root + */ + public Path writeLargeConstant(String name, Tensor constant) { + Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants"); + + // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file: + Path constantPath = constantsPath.append(name + ".tbf"); + + // Remember the constant in a file we replicate in ZooKeeper + application.getFile(arguments.largeConstantsPath().append(name + ".constant")) + .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath))); + + // Write content explicitly as a file on the file system as this is distributed using file distribution + createIfNeeded(constantsPath); + IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant)); + return correct(constantPath); + } + + private List<Pair<String, Tensor>> readSmallConstants() { + try { + ApplicationFile file = application.getFile(arguments.smallConstantsPath()); + if (!file.exists()) return Collections.emptyList(); + + List<Pair<String, Tensor>> constants = new ArrayList<>(); + BufferedReader reader = new BufferedReader(file.createReader()); + String line; + while (null != (line = reader.readLine())) { + String[] parts = line.split("\t"); + String name = parts[0]; + TensorType type = TensorType.fromSpec(parts[1]); + Tensor tensor = Tensor.from(type, parts[2]); + constants.add(new Pair<>(name, tensor)); + } + return constants; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + /** + * Append this constant to the single file used for small constants distributed as config + */ + public void writeSmallConstant(String name, Tensor constant) { + // Secret file format for remembering constants: + application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" + + constant.type().toString() + "\t" + + constant.toString() + 
"\n"); + } + + /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */ + private Path correct(Path path) { + if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed) + && ! path.elements().contains(FilesApplicationPackage.preprocessed)) { + return Path.fromString(FilesApplicationPackage.preprocessed).append(path); + } + else { + return path; + } + } + + private void createIfNeeded(Path path) { + File dir = application.getFileReference(path); + if ( ! dir.exists()) { + if (!dir.mkdirs()) + throw new IllegalStateException("Could not create " + dir); + } + } + + } + + /** Encapsulates the 1, 2 or 3 arguments to a onnx feature */ + private static class FeatureArguments { + + private final Path modelPath; + + /** Optional arguments */ + private final Optional<String> output; + + public FeatureArguments(Arguments arguments) { if (arguments.isEmpty()) throw new IllegalArgumentException("An onnx node must take an argument pointing to " + - "the tensorflow model directory under [application]/models"); + "the onnx model directory under [application]/models"); if (arguments.expressions().size() > 3) throw new IllegalArgumentException("An onnx feature can have at most 2 arguments"); modelPath = Path.fromString(asString(arguments.expressions().get(0))); output = optionalArgument(1, arguments); - signature = Optional.of("default"); } + + /** Returns modelPath with slashes replaced by underscores */ + public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); } + + /** Returns relative path to this model below the "models/" dir in the application package */ + public Path modelPath() { return modelPath; } + public Optional<String> output() { return output; } + + /** Path to the small constants file */ + public Path smallConstantsPath() { + return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt"); + } + 
+ /** Path to the large (ranking) constants directory */ + public Path largeConstantsPath() { + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants"); + } + + /** Path to the macros file */ + public Path macrosPath() { + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt"); + } + + public Path expressionPath() { + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR + .append(modelPath).append("expressions").append(expressionFileName()); + } + + private String expressionFileName() { + StringBuilder fileName = new StringBuilder(); + output.ifPresent(s -> fileName.append(s).append(".")); + if (fileName.length() == 0) // single signature and output + fileName.append("single."); + fileName.append("expression"); + return fileName.toString(); + } + + private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { + if (argumentIndex >= arguments.expressions().size()) + return Optional.empty(); + return Optional.of(asString(arguments.expressions().get(argumentIndex))); + } + + private String asString(ExpressionNode node) { + if ( ! (node instanceof ConstantNode)) + throw new IllegalArgumentException("Expected a constant string as onnx argument, but got '" + node); + return stripQuotes(((ConstantNode)node).sourceString()); + } + + private String stripQuotes(String s) { + if ( ! isQuoteSign(s.codePointAt(0))) return s; + if ( ! 
isQuoteSign(s.codePointAt(s.length() - 1 ))) + throw new IllegalArgumentException("onnx argument [" + s + "] is missing endquote"); + return s.substring(1, s.length()-1); + } + + private boolean isQuoteSign(int c) { + return c == '\'' || c == '"'; + } + } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 27e1ad51b33..41da32f64c3 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -1,19 +1,59 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.expressiontransforms; +import com.google.common.base.Joiner; +import com.yahoo.collections.Pair; +import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.application.provider.FilesApplicationPackage; +import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.searchdefinition.FeatureNames; import com.yahoo.searchdefinition.RankProfile; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; -import com.yahoo.searchlib.rankingexpression.integration.ml.TensorFlowImporter; +import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import 
com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel.Signature; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.Join; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.tensor.serialization.TypedBinaryFormat; +import java.io.BufferedReader; +import java.io.File; +import java.io.IOException; +import java.io.StringReader; import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; +import java.util.List; import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; /** * Replaces instances of the tensorflow(model-path, signature, output) @@ -22,12 +62,12 @@ import java.util.Map; * * @author bratseth */ -public class TensorFlowFeatureConverter extends MLImportFeatureConverter { +public class TensorFlowFeatureConverter extends 
ExpressionTransformer<RankProfileTransformContext> { private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter(); /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ - private final Map<Path, ImportedModel> importedModels = new HashMap<>(); + private final Map<Path, TensorFlowModel> importedModels = new HashMap<>(); @Override public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { @@ -43,8 +83,7 @@ public class TensorFlowFeatureConverter extends MLImportFeatureConverter { if ( ! feature.getName().equals("tensorflow")) return feature; try { - FeatureArguments arguments = new TensorFlowFeatureArguments(feature.getArguments()); - ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments); + ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments()); if ( ! store.hasStoredModel()) // not converted yet - access TensorFlow model files return transformFromTensorFlowModel(store, context.rankProfile(), context.queryProfiles()); else @@ -56,19 +95,565 @@ public class TensorFlowFeatureConverter extends MLImportFeatureConverter { } private ExpressionNode transformFromTensorFlowModel(ModelStore store, - RankProfile profile, - QueryProfileRegistry queryProfiles) { - ImportedModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), - k -> tensorFlowImporter.importModel(store.arguments().modelName(), - store.modelDir())); - return transformFromImportedModel(model, store, profile, queryProfiles); + RankProfile profile, + QueryProfileRegistry queryProfiles) { + TensorFlowModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), + k -> tensorFlowImporter.importModel(store.arguments().modelName(), + store.tensorFlowModelDir())); + + // Add constants + Set<String> constantsReplacedByMacros = new HashSet<>(); + model.smallConstants().forEach((k, 
v) -> transformSmallConstant(store, profile, k, v)); + model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles, + constantsReplacedByMacros, k, v)); + + // Find the specified expression + Signature signature = chooseSignature(model, store.arguments().signature()); + String output = chooseOutput(signature, store.arguments().output()); + if (signature.skippedOutputs().containsKey(output)) { + String message = "Could not import TensorFlow model output '" + output + "'"; + if (!signature.skippedOutputs().get(output).isEmpty()) { + message += ": " + signature.skippedOutputs().get(output); + } + if (!signature.importWarnings().isEmpty()) { + message += ": " + String.join(", ", signature.importWarnings()); + } + throw new IllegalArgumentException(message); + } + + RankingExpression expression = model.expressions().get(output); + expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); + verifyRequiredMacros(expression, model, profile, queryProfiles); + addGeneratedMacros(model, profile); + reduceBatchDimensions(expression, model, profile, queryProfiles); + + model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v)); + + store.writeConverted(expression); + return expression.getRoot(); + } + + private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) { + for (Pair<String, Tensor> constant : store.readSmallConstants()) + profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); + + for (RankingConstant constant : store.readLargeConstants()) { + if ( ! 
profile.getSearch().getRankingConstants().containsKey(constant.getName())) + profile.getSearch().addRankingConstant(constant); + } + + for (Pair<String, RankingExpression> macro : store.readMacros()) { + addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond()); + } + + return store.readConverted().getRoot(); + } + + /** + * Returns the specified, existing signature, or the only signature if none is specified. + * Throws IllegalArgumentException in all other cases. + */ + private Signature chooseSignature(TensorFlowModel importResult, Optional<String> signatureName) { + if ( ! signatureName.isPresent()) { + if (importResult.signatures().size() == 0) + throw new IllegalArgumentException("No signatures are available"); + if (importResult.signatures().size() > 1) + throw new IllegalArgumentException("Model has multiple signatures (" + + Joiner.on(", ").join(importResult.signatures().keySet()) + + "), one must be specified " + + "as a second argument to tensorflow()"); + return importResult.signatures().values().stream().findFirst().get(); + } + else { + Signature signature = importResult.signatures().get(signatureName.get()); + if (signature == null) + throw new IllegalArgumentException("Model does not have the specified signature '" + + signatureName.get() + "'"); + return signature; + } + } + + /** + * Returns the specified, existing output expression, or the only output expression if no output name is specified. + * Throws IllegalArgumentException in all other cases. + */ + private String chooseOutput(Signature signature, Optional<String> outputName) { + if ( ! 
outputName.isPresent()) { + if (signature.outputs().size() == 0) + throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(signature)); + if (signature.outputs().size() > 1) + throw new IllegalArgumentException(signature + " has multiple outputs (" + + Joiner.on(", ").join(signature.outputs().keySet()) + + "), one must be specified " + + "as a third argument to tensorflow()"); + return signature.outputs().get(signature.outputs().keySet().stream().findFirst().get()); + } + else { + String output = signature.outputs().get(outputName.get()); + if (output == null) { + if (signature.skippedOutputs().containsKey(outputName.get())) + throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " + + signature.skippedOutputs().get(outputName.get())); + else + throw new IllegalArgumentException("Model does not have the specified output '" + + outputName.get() + "'"); + } + return output; + } + } + + private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { + store.writeSmallConstant(constantName, constantValue); + profile.addConstant(constantName, asValue(constantValue)); + } + + private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, + Set<String> constantsReplacedByMacros, + String constantName, Tensor constantValue) { + RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName); + if (macroOverridingConstant != null) { + TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles)); + if ( ! macroType.equals(constantValue.type())) + throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. 
" + + typeMismatchExplanation(constantValue.type(), macroType)); + constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later + } + else { + Path constantPath = store.writeLargeConstant(constantName, constantValue); + if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) { + profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), + constantPath.toString())); + } + } + } + + private void transformGeneratedMacro(ModelStore store, + Set<String> constantsReplacedByMacros, + String macroName, RankingExpression expression) { + + expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); + store.writeMacro(macroName, expression); + } + + private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) { + if (profile.getMacros().containsKey(macroName)) { + throw new IllegalArgumentException("Generated TensorFlow macro '" + macroName + "' already exists."); + } + profile.addMacro(macroName, false); // todo: inline if only used once + RankProfile.Macro macro = profile.getMacros().get(macroName); + macro.setRankingExpression(expression); + macro.setTextualExpression(expression.getRoot().toString()); + } + + private String skippedOutputsDescription(TensorFlowModel.Signature signature) { + if (signature.skippedOutputs().isEmpty()) return ""; + StringBuilder b = new StringBuilder(": "); + signature.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v)); + return b.toString(); } - static class TensorFlowFeatureArguments extends FeatureArguments { - public TensorFlowFeatureArguments(Arguments arguments) { + /** + * Verify that the macros referred in the given expression exists in the given rank profile, + * and return tensors of the types specified in requiredMacros. 
+ */ + private void verifyRequiredMacros(RankingExpression expression, TensorFlowModel model, + RankProfile profile, QueryProfileRegistry queryProfiles) { + Set<String> macroNames = new HashSet<>(); + addMacroNamesIn(expression.getRoot(), macroNames, model); + for (String macroName : macroNames) { + TensorType requiredType = model.requiredMacros().get(macroName); + if (requiredType == null) continue; // Not a required macro + + RankProfile.Macro macro = profile.getMacros().get(macroName); + if (macro == null) + throw new IllegalArgumentException("Model refers placeholder '" + macroName + + "' of type " + requiredType + " but this macro is not present in " + + profile); + // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second + // phase and summary features), as it may only resolve correctly given those bindings + // Or, probably better, annotate the macros with type constraints here and verify during general + // type verification + TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles)); + if ( actualType == null) + throw new IllegalArgumentException("Model refers placeholder '" + macroName + + "' of type " + requiredType + + " which must be produced by a macro in the rank profile, but " + + "this macro references a feature which is not declared"); + if ( ! actualType.isAssignableTo(requiredType)) + throw new IllegalArgumentException("Model refers placeholder '" + macroName + "'. " + + typeMismatchExplanation(requiredType, actualType)); + } + } + + private String typeMismatchExplanation(TensorType requiredType, TensorType actualType) { + return "The required type of this is " + requiredType + ", but this macro returns " + actualType + + (actualType.rank() == 0 ? ". This is often due to missing declaration of query tensor features " + + "in query profile types - see the documentation." 
+ : ""); + } + + /** + * Add the generated macros to the rank profile + */ + private void addGeneratedMacros(TensorFlowModel model, RankProfile profile) { + model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v)); + } + + /** + * Check if batch dimensions of inputs can be reduced out. If the input + * macro specifies that a single exemplar should be evaluated, we can + * reduce the batch dimension out. + */ + private void reduceBatchDimensions(RankingExpression expression, TensorFlowModel model, + RankProfile profile, QueryProfileRegistry queryProfiles) { + TypeContext<Reference> typeContext = profile.typeContext(queryProfiles); + TensorType typeBeforeReducing = expression.getRoot().type(typeContext); + + // Check generated macros for inputs to reduce + Set<String> macroNames = new HashSet<>(); + addMacroNamesIn(expression.getRoot(), macroNames, model); + for (String macroName : macroNames) { + if ( ! model.macros().containsKey(macroName)) { + continue; + } + RankProfile.Macro macro = profile.getMacros().get(macroName); + if (macro == null) { + throw new IllegalArgumentException("Model refers to generated macro '" + macroName + + "but this macro is not present in " + profile); + } + RankingExpression macroExpression = macro.getRankingExpression(); + macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext)); + } + + // Check expression for inputs to reduce + ExpressionNode root = expression.getRoot(); + root = reduceBatchDimensionsAtInput(root, model, typeContext); + TensorType typeAfterReducing = root.type(typeContext); + root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing); + expression.setRoot(root); + } + + private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, TensorFlowModel model, + TypeContext<Reference> typeContext) { + if (node instanceof TensorFunctionNode) { + TensorFunction tensorFunction = ((TensorFunctionNode) node).function(); + if 
(tensorFunction instanceof Rename) { + List<ExpressionNode> children = ((TensorFunctionNode)node).children(); + if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode) children.get(0); + if (model.requiredMacros().containsKey(referenceNode.getName())) { + return reduceBatchDimensionExpression(tensorFunction, typeContext); + } + } + } + } + if (node instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode) node; + if (model.requiredMacros().containsKey(referenceNode.getName())) { + return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext); + } + } + if (node instanceof CompositeNode) { + List<ExpressionNode> children = ((CompositeNode)node).children(); + List<ExpressionNode> transformedChildren = new ArrayList<>(children.size()); + for (ExpressionNode child : children) { + transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext)); + } + return ((CompositeNode)node).setChildren(transformedChildren); + } + return node; + } + + private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) { + TensorFunction result = function; + TensorType type = function.type(context); + if (type.dimensions().size() > 1) { + List<String> reduceDimensions = new ArrayList<>(); + for (TensorType.Dimension dimension : type.dimensions()) { + if (dimension.size().orElse(-1L) == 1) { + reduceDimensions.add(dimension.name()); + } + } + if (reduceDimensions.size() > 0) { + result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions); + } + } + return new TensorFunctionNode(result); + } + + /** + * If batch dimensions have been reduced away above, bring them back here + * for any following computation of the tensor. + * Todo: determine when this is not necessary! 
+ */ + private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { + if (after.equals(before)) { + return node; + } + TensorType.Builder typeBuilder = new TensorType.Builder(); + for (TensorType.Dimension dimension : before.dimensions()) { + if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) { + typeBuilder.indexed(dimension.name(), 1); + } + } + TensorType expandDimensionsType = typeBuilder.build(); + if (expandDimensionsType.dimensions().size() > 0) { + ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0)); + Generate generatedFunction = new Generate(expandDimensionsType, + new GeneratorLambdaFunctionNode(expandDimensionsType, + generatedExpression) + .asLongListToDoubleOperator()); + Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply()); + return new TensorFunctionNode(expand); + } + return node; + } + + /** + * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions. + * This method does that for the given expression and returns the result. 
+ */ + private RankingExpression replaceConstantsByMacros(RankingExpression expression, + Set<String> constantsReplacedByMacros) { + if (constantsReplacedByMacros.isEmpty()) return expression; + return new RankingExpression(expression.getName(), + replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros)); + } + + private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) { + if (node instanceof ReferenceNode) { + Reference reference = ((ReferenceNode)node).reference(); + if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) { + String argument = reference.simpleArgument().get(); + if (constantsReplacedByMacros.contains(argument)) + return new ReferenceNode(argument); + } + } + if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above + CompositeNode composite = (CompositeNode)node; + return composite.setChildren(composite.children().stream() + .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros)) + .collect(Collectors.toList())); + } + return node; + } + + private void addMacroNamesIn(ExpressionNode node, Set<String> names, TensorFlowModel model) { + if (node instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode)node; + if (referenceNode.getOutput() == null) { // macro references cannot specify outputs + names.add(referenceNode.getName()); + if (model.macros().containsKey(referenceNode.getName())) { + addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model); + } + } + } + else if (node instanceof CompositeNode) { + for (ExpressionNode child : ((CompositeNode)node).children()) + addMacroNamesIn(child, names, model); + } + } + + private Value asValue(Tensor tensor) { + if (tensor.type().rank() == 0) + return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors + else + return new TensorValue(tensor); + } + + /** + * 
Provides read/write access to the correct directories of the application package given by the feature arguments + */ + private static class ModelStore { + + private final ApplicationPackage application; + private final FeatureArguments arguments; + + public ModelStore(ApplicationPackage application, Arguments arguments) { + this.application = application; + this.arguments = new FeatureArguments(arguments); + } + + + + public FeatureArguments arguments() { return arguments; } + + public boolean hasStoredModel() { + try { + return application.getFile(arguments.expressionPath()).exists(); + } + catch (UnsupportedOperationException e) { + return false; + } + } + + /** + * Returns the directory which (if hasTensorFlowModels is true) + * contains the source model to use for these arguments + */ + public File tensorFlowModelDir() { + return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath())); + } + + /** + * Adds this expression to the application package, such that it can be read later. + */ + public void writeConverted(RankingExpression expression) { + application.getFile(arguments.expressionPath()) + .writeFile(new StringReader(expression.getRoot().toString())); + } + + /** Reads the previously stored ranking expression for these arguments */ + public RankingExpression readConverted() { + try { + return new RankingExpression(application.getFile(arguments.expressionPath()).createReader()); + } + catch (IOException e) { + throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e); + } + catch (ParseException e) { + throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); + } + } + + /** Adds this macro expression to the application package to it can be read later. 
*/ + public void writeMacro(String name, RankingExpression expression) { + application.getFile(arguments.macrosPath()).appendFile(name + "\t" + + expression.getRoot().toString() + "\n"); + } + + /** Reads the previously stored macro expressions for these arguments */ + public List<Pair<String, RankingExpression>> readMacros() { + try { + ApplicationFile file = application.getFile(arguments.macrosPath()); + if (!file.exists()) return Collections.emptyList(); + + List<Pair<String, RankingExpression>> macros = new ArrayList<>(); + BufferedReader reader = new BufferedReader(file.createReader()); + String line; + while (null != (line = reader.readLine())) { + String[] parts = line.split("\t"); + String name = parts[0]; + try { + RankingExpression expression = new RankingExpression(parts[1]); + macros.add(new Pair<>(name, expression)); + } + catch (ParseException e) { + throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); + } + } + return macros; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + /** + * Reads the information about all the large (aka ranking) constants stored in the application package + * (the constant value itself is replicated with file distribution). + */ + public List<RankingConstant> readLargeConstants() { + try { + List<RankingConstant> constants = new ArrayList<>(); + for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) { + String[] parts = IOUtils.readAll(constantFile.createReader()).split(":"); + constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2])); + } + return constants; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + /** + * Adds this constant to the application package as a file, + * such that it can be distributed using file distribution. 
+ * + * @return the path to the stored constant, relative to the application package root + */ + public Path writeLargeConstant(String name, Tensor constant) { + Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants"); + + // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file: + Path constantPath = constantsPath.append(name + ".tbf"); + + // Remember the constant in a file we replicate in ZooKeeper + application.getFile(arguments.largeConstantsPath().append(name + ".constant")) + .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath))); + + // Write content explicitly as a file on the file system as this is distributed using file distribution + createIfNeeded(constantsPath); + IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant)); + return correct(constantPath); + } + + private List<Pair<String, Tensor>> readSmallConstants() { + try { + ApplicationFile file = application.getFile(arguments.smallConstantsPath()); + if (!file.exists()) return Collections.emptyList(); + + List<Pair<String, Tensor>> constants = new ArrayList<>(); + BufferedReader reader = new BufferedReader(file.createReader()); + String line; + while (null != (line = reader.readLine())) { + String[] parts = line.split("\t"); + String name = parts[0]; + TensorType type = TensorType.fromSpec(parts[1]); + Tensor tensor = Tensor.from(type, parts[2]); + constants.add(new Pair<>(name, tensor)); + } + return constants; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + /** + * Append this constant to the single file used for small constants distributed as config + */ + public void writeSmallConstant(String name, Tensor constant) { + // Secret file format for remembering constants: + application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" + + constant.type().toString() + "\t" + + constant.toString() + 
"\n"); + } + + /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */ + private Path correct(Path path) { + if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed) + && ! path.elements().contains(FilesApplicationPackage.preprocessed)) { + return Path.fromString(FilesApplicationPackage.preprocessed).append(path); + } + else { + return path; + } + } + + private void createIfNeeded(Path path) { + File dir = application.getFileReference(path); + if ( ! dir.exists()) { + if (!dir.mkdirs()) + throw new IllegalStateException("Could not create " + dir); + } + } + + } + + /** Encapsulates the 1, 2 or 3 arguments to a tensorflow feature */ + private static class FeatureArguments { + + private final Path modelPath; + + /** Optional arguments */ + private final Optional<String> signature, output; + + public FeatureArguments(Arguments arguments) { if (arguments.isEmpty()) throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " + - "the tensorflow model directory under [application]/models"); + "the tensorflow model directory under [application]/models"); if (arguments.expressions().size() > 3) throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments"); @@ -76,6 +661,68 @@ public class TensorFlowFeatureConverter extends MLImportFeatureConverter { signature = optionalArgument(1, arguments); output = optionalArgument(2, arguments); } + + /** Returns modelPath with slashes replaced by underscores */ + public String modelName() { return modelPath.toString().replace('/', '_'); } + + /** Returns relative path to this model below the "models/" dir in the application package */ + public Path modelPath() { return modelPath; } + public Optional<String> signature() { return signature; } + public Optional<String> output() { return output; } + + /** Path to the small constants file */ + public Path 
smallConstantsPath() { + return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt"); + } + + /** Path to the large (ranking) constants directory */ + public Path largeConstantsPath() { + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants"); + } + + /** Path to the macros file */ + public Path macrosPath() { + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt"); + } + + public Path expressionPath() { + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR + .append(modelPath).append("expressions").append(expressionFileName()); + } + + private String expressionFileName() { + StringBuilder fileName = new StringBuilder(); + signature.ifPresent(s -> fileName.append(s).append(".")); + output.ifPresent(s -> fileName.append(s).append(".")); + if (fileName.length() == 0) // single signature and output + fileName.append("single."); + fileName.append("expression"); + return fileName.toString(); + } + + private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { + if (argumentIndex >= arguments.expressions().size()) + return Optional.empty(); + return Optional.of(asString(arguments.expressions().get(argumentIndex))); + } + + private String asString(ExpressionNode node) { + if ( ! (node instanceof ConstantNode)) + throw new IllegalArgumentException("Expected a constant string as tensorflow argument, but got '" + node); + return stripQuotes(((ConstantNode)node).sourceString()); + } + + private String stripQuotes(String s) { + if ( ! isQuoteSign(s.codePointAt(0))) return s; + if ( ! 
isQuoteSign(s.codePointAt(s.length() - 1 ))) + throw new IllegalArgumentException("tensorflow argument [" + s + "] is missing endquote"); + return s.substring(1, s.length()-1); + } + + private boolean isQuoteSign(int c) { + return c == '\'' || c == '"'; + } + } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index d9beab6e2f2..1c54d12d8b3 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -37,6 +37,15 @@ public class RankingExpressionWithOnnxTestCase { } @Test + public void testOnnxReference() throws ParseException { + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "onnx('mnist_softmax.onnx')"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L)); + assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L)); + } + + @Test public void testOnnxReferenceWithConstantFeature() { RankProfileSearchFixture search = fixtureWith("constant(mytensor)", "onnx('mnist_softmax.onnx')", @@ -113,6 +122,13 @@ public class RankingExpressionWithOnnxTestCase { } @Test + public void testOnnxReferenceSpecifyingOutput() { + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "onnx('mnist_softmax.onnx', 'add')"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + } + + @Test public void testOnnxReferenceMissingMacro() throws ParseException { try { RankProfileSearchFixture search = new RankProfileSearchFixture( @@ -129,7 +145,7 @@ public class RankingExpressionWithOnnxTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 
'my_profile' is invalid: Could not use Onnx model from " + "onnx('mnist_softmax.onnx'): " + - "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " + + "Model refers Placeholder 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " + "not present in rank profile 'my_profile'", Exceptions.toMessageString(expected)); } @@ -147,8 +163,8 @@ public class RankingExpressionWithOnnxTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + "onnx('mnist_softmax.onnx'): " + - "Model refers input 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " + - "but this macro returns tensor(d0[2],d5[10])", + "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) which must be produced " + + "by a macro in the rank profile, but this macro produces type tensor(d0[2],d5[10])", Exceptions.toMessageString(expected)); } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 7228af2b0de..d288a396732 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -162,7 +162,7 @@ public class RankingExpressionWithTensorFlowTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + "tensorflow('mnist_softmax/saved'): " + - "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " + + "Model refers placeholder 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " + "not present in rank profile 'my_profile'", Exceptions.toMessageString(expected)); } @@ -179,7 +179,7 @@ public 
class RankingExpressionWithTensorFlowTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " + "tensorflow('mnist_softmax/saved'): " + - "Model refers input 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " + + "Model refers placeholder 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " + "but this macro returns tensor(d0[2],d5[10])", Exceptions.toMessageString(expected)); } @@ -305,9 +305,9 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testMacroGeneration() { - final String expression = "join(join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; + final String expression = "join(join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; final String macroExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))"; - final String macroExpression2 = "join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), 
constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))"; + final String macroExpression2 = "join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))"; RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", "tensorflow('mnist/saved')", @@ -316,15 +316,15 @@ public class RankingExpressionWithTensorFlowTestCase { "input", new StoringApplicationPackage(applicationDir)); search.assertFirstPhaseExpression(expression, "my_profile"); - search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile"); - search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile"); + search.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile"); + search.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile"); } @Test public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException { - final String expression = "join(join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; + final String expression = "join(join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * 
b))"; final String macroExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))"; - final String macroExpression2 = "join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))"; + final String macroExpression2 = "join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))"; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", @@ -335,8 +335,8 @@ public class RankingExpressionWithTensorFlowTestCase { application); search.assertFirstPhaseExpression(expression, "my_profile"); assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search); - search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile"); - search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile"); + search.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile"); + search.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile"); // At this point the expression is stored - copy application to another location which do not have a models dir Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); @@ -353,8 +353,8 @@ public class RankingExpressionWithTensorFlowTestCase { 
storedApplication); searchFromStored.assertFirstPhaseExpression(expression, "my_profile"); assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search); - searchFromStored.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile"); - searchFromStored.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile"); + searchFromStored.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile"); + searchFromStored.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile"); } finally { IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); @@ -465,7 +465,7 @@ public class RankingExpressionWithTensorFlowTestCase { } - static class StoringApplicationPackageFile extends ApplicationFile { + public static class StoringApplicationPackageFile extends ApplicationFile { /** The path to the application package root */ private final Path root; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java deleted file mode 100644 index a658833b426..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java +++ /dev/null @@ -1,242 +0,0 @@ -package com.yahoo.searchlib.rankingexpression.integration.ml; - -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant; -import 
com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; -import com.yahoo.searchlib.rankingexpression.parser.ParseException; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.functions.Rename; -import com.yahoo.tensor.functions.TensorFunction; -import com.yahoo.yolean.Exceptions; - -import java.io.File; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.logging.Logger; - -/** - * Base class for importing ML models (ONNX/TensorFlow) as native Vespa - * ranking expressions. The general mechanism for import is for the - * specific ML platform import implementations to create an - * IntermediateGraph. This class offers common code to convert the - * IntermediateGraph to Vespa ranking expressions and macros. - * - * @author lesters - */ -public abstract class ModelImporter { - - private static final Logger log = Logger.getLogger(ModelImporter.class.getName()); - - /** - * The main import function. - */ - public abstract ImportedModel importModel(String modelName, String modelPath); - - public ImportedModel importModel(String modelName, File modelDir) { - return importModel(modelName, modelDir.toString()); - } - - /** - * Takes an IntermediateGraph and converts it to a ImportedModel containing - * the actual Vespa ranking expressions. 
- */ - static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph) { - ImportedModel model = new ImportedModel(graph.name()); - - graph.optimize(); - - importSignatures(graph, model); - importExpressions(graph, model); - reportWarnings(graph, model); - logVariableTypes(graph); - - return model; - } - - private static void importSignatures(IntermediateGraph graph, ImportedModel model) { - for (String signatureName : graph.signatures()) { - ImportedModel.Signature signature = model.signature(signatureName); - for (Map.Entry<String, String> input : graph.inputs(signatureName).entrySet()) { - signature.input(input.getKey(), input.getValue()); - } - for (Map.Entry<String, String> output : graph.outputs(signatureName).entrySet()) { - signature.output(output.getKey(), output.getValue()); - } - } - } - - private static boolean isSignatureInput(ImportedModel model, IntermediateOperation operation) { - for (ImportedModel.Signature signature : model.signatures().values()) { - for (String inputName : signature.inputs().values()) { - if (inputName.equals(operation.name())) { - return true; - } - } - } - return false; - } - - private static boolean isSignatureOutput(ImportedModel model, IntermediateOperation operation) { - for (ImportedModel.Signature signature : model.signatures().values()) { - for (String outputName : signature.outputs().values()) { - if (outputName.equals(operation.name())) { - return true; - } - } - } - return false; - } - - /** - * Convert intermediate representation to Vespa ranking expressions. 
- */ - static void importExpressions(IntermediateGraph graph, ImportedModel model) { - for (ImportedModel.Signature signature : model.signatures().values()) { - for (String outputName : signature.outputs().values()) { - try { - Optional<TensorFunction> function = importExpression(graph.get(outputName), model); - if (!function.isPresent()) { - signature.skippedOutput(outputName, "No valid output function could be found."); - } - } - catch (IllegalArgumentException e) { - signature.skippedOutput(outputName, Exceptions.toMessageString(e)); - } - } - } - } - - private static Optional<TensorFunction> importExpression(IntermediateOperation operation, ImportedModel model) { - if (!operation.type().isPresent()) { - return Optional.empty(); - } - if (operation.isConstant()) { - return importConstant(operation, model); - } - importExpressionInputs(operation, model); - importRankingExpression(operation, model); - importArgumentExpression(operation, model); - importMacroExpression(operation, model); - - return operation.function(); - } - - private static void importExpressionInputs(IntermediateOperation operation, ImportedModel model) { - operation.inputs().forEach(input -> importExpression(input, model)); - } - - private static Optional<TensorFunction> importConstant(IntermediateOperation operation, ImportedModel model) { - String name = operation.vespaName(); - if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) { - return operation.function(); - } - - Value value = operation.getConstantValue().orElseThrow(() -> - new IllegalArgumentException("Operation '" + operation.vespaName() + "' " + - "is constant but does not have a value.")); - if ( ! 
(value instanceof TensorValue)) { - return operation.function(); // scalar values are inserted directly into the expression - } - - Tensor tensor = value.asTensor(); - if (tensor.type().rank() == 0) { - model.smallConstant(name, tensor); - } else { - model.largeConstant(name, tensor); - } - return operation.function(); - } - - private static void importRankingExpression(IntermediateOperation operation, ImportedModel model) { - if (operation.function().isPresent()) { - String name = operation.name(); - if (!model.expressions().containsKey(name)) { - TensorFunction function = operation.function().get(); - - if (isSignatureOutput(model, operation)) { - OrderedTensorType operationType = operation.type().get(); - OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType); - if ( ! operationType.equals(standardNamingType)) { - List<String> renameFrom = operationType.dimensionNames(); - List<String> renameTo = standardNamingType.dimensionNames(); - function = new Rename(function, renameFrom, renameTo); - } - } - - try { - // We add all intermediate nodes imported as separate expressions. Only - // those referenced from the output will be used. We parse the - // TensorFunction here to convert it to a RankingExpression tree. - model.expression(name, new RankingExpression(name, function.toString())); - } - catch (ParseException e) { - throw new RuntimeException("Imported function " + function + - " cannot be parsed as a ranking expression", e); - } - } - } - } - - private static void importArgumentExpression(IntermediateOperation operation, ImportedModel model) { - if (operation.isInput()) { - // All inputs must have dimensions with standard naming convention: d0, d1, ... 
- OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get()); - model.argument(operation.vespaName(), standardNamingConvention.type()); - model.requiredMacro(operation.vespaName(), standardNamingConvention.type()); - } - } - - private static void importMacroExpression(IntermediateOperation operation, ImportedModel model) { - if (operation.macro().isPresent()) { - TensorFunction function = operation.macro().get(); - try { - model.macro(operation.macroName(), new RankingExpression(operation.macroName(), function.toString())); - } - catch (ParseException e) { - throw new RuntimeException("Tensorflow function " + function + - " cannot be parsed as a ranking expression", e); - } - } - } - - /** - * Add any import warnings to the signature in the ImportedModel. - */ - private static void reportWarnings(IntermediateGraph graph, ImportedModel model) { - for (ImportedModel.Signature signature : model.signatures().values()) { - for (String outputName : signature.outputs().values()) { - reportWarnings(graph.get(outputName), model); - } - } - } - - private static void reportWarnings(IntermediateOperation operation, ImportedModel model) { - for (String warning : operation.warnings()) { - model.defaultSignature().importWarning(warning); - } - for (IntermediateOperation input : operation.inputs()) { - reportWarnings(input, model); - } - } - - /** - * Log all TensorFlow Variables (i.e file constants) imported as part of this with their ordered type. - * This allows users to learn the exact types (including dimension order after renaming) of the Variables - * such that these can be converted and fed to a parent document independently of the rest of the model - * for fast model weight updates. - */ - private static void logVariableTypes(IntermediateGraph graph) { - for (IntermediateOperation operation : graph.operations()) { - if ( ! (operation instanceof Constant)) continue; - if ( ! 
operation.type().isPresent()) continue; // will not happen - log.info("Importing TensorFlow variable " + operation.name() + " as " + operation.vespaName() + - " of type " + operation.type().get()); - } - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java deleted file mode 100644 index d3dd2a1d418..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package com.yahoo.searchlib.rankingexpression.integration.ml; - -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx.GraphImporter; -import onnx.Onnx; - -import java.io.FileInputStream; -import java.io.IOException; - -/** - * Converts a ONNX model into a ranking expression and set of constants. 
- * - * @author lesters - */ -public class OnnxImporter extends ModelImporter { - - @Override - public ImportedModel importModel(String modelName, String modelPath) { - try (FileInputStream inputStream = new FileInputStream(modelPath)) { - Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); - IntermediateGraph graph = GraphImporter.importGraph(modelName, model); - return convertIntermediateGraphToModel(graph); - } catch (IOException e) { - throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e); - } - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java deleted file mode 100644 index ff584559a83..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; - -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.GraphImporter; -import org.tensorflow.SavedModelBundle; - -import java.io.IOException; - -/** - * Converts a saved TensorFlow model into a ranking expression and set of constants. - * - * @author bratseth - * @author lesters - */ -public class TensorFlowImporter extends ModelImporter { - - /** - * Imports a saved TensorFlow model from a directory. - * The model should be saved as a .pbtxt or .pb file. - * The name of the model is taken as the db/pbtxt file name (not including the file ending). 
- * - * @param modelName the name of the model to import, consisting of characters in [A-Za-z0-9_] - * @param modelDir the directory containing the TensorFlow model files to import - */ - public ImportedModel importModel(String modelName, String modelDir) { - try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { - return importModel(modelName, model); - } - catch (IllegalArgumentException e) { - throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e); - } - } - - /** Imports a TensorFlow model */ - ImportedModel importModel(String modelName, SavedModelBundle model) { - try { - IntermediateGraph graph = GraphImporter.importGraph(modelName, model); - return convertIntermediateGraphToModel(graph); - } - catch (IOException e) { - throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e); - } - } - - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java deleted file mode 100644 index 39a8b211d09..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package com.yahoo.searchlib.rankingexpression.integration.ml.importer; - -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; - -import java.util.Collection; -import java.util.HashMap; -import java.util.Map; -import java.util.Set; - -/** - * Holds an intermediate representation of an imported ONNX or TensorFlow - * graph. 
After this intermediate representation is constructed, it is used to - * simplify and optimize the computational graph and then converted into the - * final ImportedModel that holds the Vespa ranking expressions for the model. - * - * @author lesters - */ -public class IntermediateGraph { - - private final String modelName; - private final Map<String, IntermediateOperation> index = new HashMap<>(); - private final Map<String, GraphSignature> signatures = new HashMap<>(); - - private static class GraphSignature { - final Map<String, String> inputs = new HashMap<>(); - final Map<String, String> outputs = new HashMap<>(); - } - - public IntermediateGraph(String modelName) { - this.modelName = modelName; - } - - public String name() { - return modelName; - } - - public IntermediateOperation put(String key, IntermediateOperation operation) { - return index.put(key, operation); - } - - public IntermediateOperation get(String key) { - return index.get(key); - } - - public Set<String> signatures() { - return signatures.keySet(); - } - - public Map<String, String> inputs(String signature) { - return signatures.computeIfAbsent(signature, (k) -> new GraphSignature()).inputs; - } - - public Map<String, String> outputs(String signature) { - return signatures.computeIfAbsent(signature, (k) -> new GraphSignature()).outputs; - } - - public String defaultSignature() { - return "default"; - } - - public boolean alreadyImported(String key) { - return index.containsKey(key); - } - - public Collection<IntermediateOperation> operations() { - return index.values(); - } - - public void optimize() { - renameDimensions(); - } - - /** - * Find dimension names to avoid excessive renaming while evaluating the model. 
- */ - private void renameDimensions() { - DimensionRenamer renamer = new DimensionRenamer(); - for (String signature : signatures()) { - for (String output : outputs(signature).values()) { - addDimensionNameConstraints(index.get(output), renamer); - } - } - renamer.solve(); - for (String signature : signatures()) { - for (String output : outputs(signature).values()) { - renameDimensions(index.get(output), renamer); - } - } - } - - private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer) { - if (operation.type().isPresent()) { - operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer)); - operation.addDimensionNameConstraints(renamer); - } - } - - private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer) { - if (operation.type().isPresent()) { - operation.inputs().forEach(input -> renameDimensions(input, renamer)); - operation.renameDimensions(renamer); - } - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java deleted file mode 100644 index 3fe92440cae..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java +++ /dev/null @@ -1,216 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
- -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx; - -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Argument; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ConcatV2; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Identity; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Join; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Map; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.MatMul; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.NoOp; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Reshape; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Shape; -import com.yahoo.tensor.functions.ScalarFunctions; -import onnx.Onnx; - -import java.util.List; -import java.util.stream.Collectors; - -/** - * Converts an ONNX graph to a Vespa IntermediateGraph which is the basis - * for generating Vespa ranking expressions. 
- * - * @author lesters - */ -public class GraphImporter { - - public static IntermediateOperation mapOperation(Onnx.NodeProto node, - List<IntermediateOperation> inputs, - IntermediateGraph graph) { - String nodeName = node.getName(); - String modelName = graph.name(); - - switch (node.getOpType().toLowerCase()) { - case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs()); - case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); - case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos()); - case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin()); - case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan()); - case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil()); - case "concat": return new ConcatV2(modelName, nodeName, inputs); - case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos()); - case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); - case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu()); - case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal()); - case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp()); - case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor()); - case "greater": return new Join(modelName, nodeName, inputs, ScalarFunctions.greater()); - case "identity": return new Identity(modelName, nodeName, inputs); - case "less": return new Join(modelName, nodeName, inputs, ScalarFunctions.less()); - case "log": return new Map(modelName, nodeName, inputs, ScalarFunctions.log()); - case "matmul": return new MatMul(modelName, nodeName, inputs); - case "max": return new Join(modelName, nodeName, inputs, ScalarFunctions.max()); - case "min": return new Join(modelName, nodeName, inputs, ScalarFunctions.min()); - case "mean": return new Join(modelName, nodeName, 
inputs, ScalarFunctions.mean()); - case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply()); - case "neg": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg()); - case "pow": return new Join(modelName, nodeName, inputs, ScalarFunctions.pow()); - case "reshape": return new Reshape(modelName, nodeName, inputs); - case "reciprocal": return new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal()); - case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu()); - case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu()); - case "shape": return new Shape(modelName, nodeName, inputs); - case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin()); - case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt()); - case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid()); - case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); - case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan()); - case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh()); - } - - IntermediateOperation op = new NoOp(modelName, node.getName(), inputs); - op.warning("Operation '" + node.getOpType() + "' is currently not implemented"); - return op; - } - - public static IntermediateGraph importGraph(String modelName, Onnx.ModelProto model) { - Onnx.GraphProto onnxGraph = model.getGraph(); - - IntermediateGraph intermediateGraph = new IntermediateGraph(modelName); - importOperations(onnxGraph, intermediateGraph); - verifyOutputTypes(onnxGraph, intermediateGraph); - - return intermediateGraph; - } - - private static void importOperations(Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) { - for (Onnx.ValueInfoProto valueInfo : onnxGraph.getOutputList()) { - importOperation(valueInfo.getName(), onnxGraph, intermediateGraph); - } - } - - private static 
IntermediateOperation importOperation(String name, - Onnx.GraphProto onnxGraph, - IntermediateGraph intermediateGraph) { - if (intermediateGraph.alreadyImported(name)) { - return intermediateGraph.get(name); - } - IntermediateOperation operation; - if (isArgumentTensor(name, onnxGraph)) { - Onnx.ValueInfoProto valueInfoProto = getArgumentTensor(name, onnxGraph); - if (valueInfoProto == null) - throw new IllegalArgumentException("Could not find argument tensor: " + name); - OrderedTensorType type = TypeConverter.fromOnnxType(valueInfoProto.getType()); - operation = new Argument(intermediateGraph.name(), valueInfoProto.getName(), type); - - intermediateGraph.inputs(intermediateGraph.defaultSignature()) - .put(IntermediateOperation.namePartOf(name), operation.vespaName()); - - } else if (isConstantTensor(name, onnxGraph)) { - Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph); - OrderedTensorType defaultType = OrderedTensorType.fromDimensionList(tensorProto.getDimsList()); - operation = new Constant(intermediateGraph.name(), name, defaultType); - operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type))); - - } else { - Onnx.NodeProto node = getNodeFromGraph(name, onnxGraph); - List<IntermediateOperation> inputs = importOperationInputs(node, onnxGraph, intermediateGraph); - operation = mapOperation(node, inputs, intermediateGraph); - - if (isOutputNode(name, onnxGraph)) { - intermediateGraph.outputs(intermediateGraph.defaultSignature()) - .put(IntermediateOperation.namePartOf(name), operation.vespaName()); - } - } - intermediateGraph.put(operation.vespaName(), operation); - - return operation; - } - - private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) { - Onnx.ValueInfoProto value = getArgumentTensor(name, graph); - Onnx.TensorProto tensor = getConstantTensor(name, graph); - return value != null && tensor == null; - } - - private static boolean isConstantTensor(String name, 
Onnx.GraphProto graph) { - Onnx.ValueInfoProto value = getArgumentTensor(name, graph); - Onnx.TensorProto tensor = getConstantTensor(name, graph); - return value != null && tensor != null; - } - - private static Onnx.ValueInfoProto getArgumentTensor(String name, Onnx.GraphProto graph) { - for (Onnx.ValueInfoProto valueInfo : graph.getInputList()) { - if (valueInfo.getName().equals(name)) { - return valueInfo; - } - } - return null; - } - - private static Onnx.TensorProto getConstantTensor(String name, Onnx.GraphProto graph) { - for (Onnx.TensorProto tensorProto : graph.getInitializerList()) { - if (tensorProto.getName().equals(name)) { - return tensorProto; - } - } - return null; - } - - private static boolean isOutputNode(String name, Onnx.GraphProto graph) { - return getOutputNode(name, graph) != null; - } - - private static Onnx.ValueInfoProto getOutputNode(String name, Onnx.GraphProto graph) { - for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) { - if (valueInfo.getName().equals(name)) { - return valueInfo; - } - String nodeName = IntermediateOperation.namePartOf(valueInfo.getName()); - if (nodeName.equals(name)) { - return valueInfo; - } - } - return null; - } - - private static List<IntermediateOperation> importOperationInputs(Onnx.NodeProto node, - Onnx.GraphProto onnxGraph, - IntermediateGraph intermediateGraph) { - return node.getInputList().stream() - .map(nodeName -> importOperation(nodeName, onnxGraph, intermediateGraph)) - .collect(Collectors.toList()); - } - - private static void verifyOutputTypes(Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) { - for (String outputName : intermediateGraph.outputs(intermediateGraph.defaultSignature()).values()) { - IntermediateOperation operation = intermediateGraph.get(outputName); - Onnx.ValueInfoProto onnxNode = getOutputNode(outputName, onnxGraph); - OrderedTensorType type = operation.type().orElseThrow( - () -> new IllegalArgumentException("Output of '" + outputName + "' has no 
type.")); - TypeConverter.verifyType(onnxNode.getType(), type); - } - } - - private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) { - boolean hasPortNumber = nodeName.contains(":"); - for (Onnx.NodeProto node : graph.getNodeList()) { - if (hasPortNumber) { - for (String outputName : node.getOutputList()) { - if (outputName.equals(nodeName)) { - return node; - } - } - } else if (node.getName().equals(nodeName)) { - return node; - } - } - throw new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph"); - } -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java deleted file mode 100644 index 715c55d8323..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx; - -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; -import com.yahoo.tensor.TensorType; -import onnx.Onnx; - -/** - * Converts and verifies ONNX tensor types into Vespa tensor types. 
- * - * @author lesters - */ -public class TypeConverter { - - public static void verifyType(Onnx.TypeProto typeProto, OrderedTensorType type) { - Onnx.TensorShapeProto shape = typeProto.getTensorType().getShape(); - if (shape != null) { - if (shape.getDimCount() != type.rank()) { - throw new IllegalArgumentException("Onnx shape of does not match Vespa shape"); - } - for (int onnxIndex = 0; onnxIndex < type.dimensions().size(); ++onnxIndex) { - int vespaIndex = type.dimensionMap(onnxIndex); - Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(onnxIndex); - TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex); - if (onnxDimension.getDimValue() != vespaDimension.size().orElse(-1L)) { - throw new IllegalArgumentException("Onnx dimensions of does not match Vespa dimensions"); - } - } - } - } - - public static OrderedTensorType fromOnnxType(Onnx.TypeProto type) { - return fromOnnxType(type, "d"); // standard naming convention: d0, d1, ... - } - - public static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) { - Onnx.TensorShapeProto shape = type.getTensorType().getShape(); - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); - for (int i = 0; i < shape.getDimCount(); ++ i) { - String dimensionName = dimensionPrefix + i; - Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i); - if (onnxDimension.getDimValue() >= 0) { - builder.add(TensorType.Dimension.indexed(dimensionName, onnxDimension.getDimValue())); - } else { - builder.add(TensorType.Dimension.indexed(dimensionName)); - } - } - return builder.build(); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java deleted file mode 100644 index 19ba146492c..00000000000 --- 
a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; - -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; -import com.yahoo.tensor.functions.TensorFunction; - -import java.util.Collections; -import java.util.List; - -public class NoOp extends IntermediateOperation { - - public NoOp(String modelName, String nodeName, List<IntermediateOperation> inputs) { - super(modelName, nodeName, Collections.emptyList()); // don't propagate inputs - } - - @Override - protected OrderedTensorType lazyGetType() { - return null; - } - - @Override - protected TensorFunction lazyGetFunction() { - return null; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java deleted file mode 100644 index a815cbc3944..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java +++ /dev/null @@ -1,85 +0,0 @@ -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow; - -import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; -import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; -import 
org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.NodeDef; - -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; - -/** - * Converts TensorFlow node attributes to Vespa attribute values. - * - * @author lesters - */ -public class AttributeConverter implements IntermediateOperation.AttributeMap { - - private final Map<String, AttrValue> attributeMap; - - public AttributeConverter(NodeDef node) { - attributeMap = node.getAttrMap(); - } - - public static AttributeConverter convert(NodeDef node) { - return new AttributeConverter(node); - } - - @Override - public Optional<Value> get(String key) { - if (attributeMap.containsKey(key)) { - AttrValue attrValue = attributeMap.get(key); - if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) { - return Optional.empty(); // requires type - } - if (attrValue.getValueCase() == AttrValue.ValueCase.B) { - return Optional.of(new BooleanValue(attrValue.getB())); - } - if (attrValue.getValueCase() == AttrValue.ValueCase.I) { - return Optional.of(new DoubleValue(attrValue.getI())); - } - if (attrValue.getValueCase() == AttrValue.ValueCase.F) { - return Optional.of(new DoubleValue(attrValue.getF())); - } - } - return Optional.empty(); - } - - @Override - public Optional<Value> get(String key, OrderedTensorType type) { - if (attributeMap.containsKey(key)) { - AttrValue attrValue = attributeMap.get(key); - if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) { - return Optional.of(new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), type.type()))); - } - } - return get(key); - } - - @Override - public Optional<List<Value>> getList(String key) { - if (attributeMap.containsKey(key)) { - AttrValue attrValue = attributeMap.get(key); - if (attrValue.getValueCase() == AttrValue.ValueCase.LIST) { - AttrValue.ListValue listValue = attrValue.getList(); - if ( ! 
listValue.getBList().isEmpty()) { - return Optional.of(listValue.getBList().stream().map(BooleanValue::new).collect(Collectors.toList())); - } - if ( ! listValue.getIList().isEmpty()) { - return Optional.of(listValue.getIList().stream().map(DoubleValue::new).collect(Collectors.toList())); - } - if ( ! listValue.getFList().isEmpty()) { - return Optional.of(listValue.getFList().stream().map(DoubleValue::new).collect(Collectors.toList())); - } - // add the rest - } - } - return Optional.empty(); - } -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java deleted file mode 100644 index e1b292f9e61..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java +++ /dev/null @@ -1,234 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
- -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow; - -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Argument; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ConcatV2; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Const; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ExpandDims; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Identity; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Join; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Map; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.MatMul; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Mean; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Merge; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.NoOp; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.PlaceholderWithDefault; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Reshape; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Select; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Shape; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Squeeze; -import 
com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Switch; -import com.yahoo.tensor.functions.ScalarFunctions; -import org.tensorflow.SavedModelBundle; -import org.tensorflow.Session; -import org.tensorflow.framework.GraphDef; -import org.tensorflow.framework.MetaGraphDef; -import org.tensorflow.framework.NodeDef; -import org.tensorflow.framework.SignatureDef; -import org.tensorflow.framework.TensorInfo; - -import java.io.IOException; -import java.util.List; -import java.util.stream.Collectors; - -/** - * Converts a TensorFlow graph to a Vespa IntermediateGraph which is the basis - * for generating Vespa ranking expressions. - * - * @author lesters - */ -public class GraphImporter { - - public static IntermediateOperation mapOperation(NodeDef node, - List<IntermediateOperation> inputs, - IntermediateGraph graph) { - String nodeName = node.getName(); - String modelName = graph.name(); - int nodePort = IntermediateOperation.indexPartOf(nodeName); - OrderedTensorType nodeType = TypeConverter.fromTensorFlowType(node); - AttributeConverter attributes = AttributeConverter.convert(node); - - switch (node.getOp().toLowerCase()) { - // array ops - case "concatv2": return new ConcatV2(modelName, nodeName, inputs); - case "const": return new Const(modelName, nodeName, inputs, attributes, nodeType); - case "expanddims": return new ExpandDims(modelName, nodeName, inputs); - case "identity": return new Identity(modelName, nodeName, inputs); - case "placeholder": return new Argument(modelName, nodeName, nodeType); - case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, nodeName, inputs); - case "reshape": return new Reshape(modelName, nodeName, inputs); - case "shape": return new Shape(modelName, nodeName, inputs); - case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes); - - // control flow - case "merge": return new Merge(modelName, nodeName, inputs); - case "switch": return new Switch(modelName, nodeName, inputs, 
nodePort); - - // math ops - case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); - case "add_n": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); - case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos()); - case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); - case "realdiv": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); - case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor()); - case "matmul": return new MatMul(modelName, nodeName, inputs); - case "maximum": return new Join(modelName, nodeName, inputs, ScalarFunctions.max()); - case "mean": return new Mean(modelName, nodeName, inputs, attributes); - case "reducemean": return new Mean(modelName, nodeName, inputs, attributes); - case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply()); - case "multiply": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply()); - case "rsqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.rsqrt()); - case "select": return new Select(modelName, nodeName, inputs); - case "where3": return new Select(modelName, nodeName, inputs); - case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid()); - case "squareddifference": return new Join(modelName, nodeName, inputs, ScalarFunctions.squareddifference()); - case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); - case "subtract": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); - - // nn ops - case "biasadd": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); - case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu()); - case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu()); - case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu()); - - // state ops - case 
"variable": return new Constant(modelName, nodeName, nodeType); - case "variablev2": return new Constant(modelName, nodeName, nodeType); - - // evaluation no-ops - case "stopgradient":return new Identity(modelName, nodeName, inputs); - case "noop": return new NoOp(modelName, nodeName, inputs); - - } - - IntermediateOperation op = new NoOp(modelName, node.getName(), inputs); - op.warning("Operation '" + node.getOp() + "' is currently not implemented"); - return op; - } - - public static IntermediateGraph importGraph(String modelName, SavedModelBundle bundle) throws IOException { - MetaGraphDef tfGraph = MetaGraphDef.parseFrom(bundle.metaGraphDef()); - - IntermediateGraph intermediateGraph = new IntermediateGraph(modelName); - importSignatures(tfGraph, intermediateGraph); - importOperations(tfGraph, intermediateGraph, bundle); - verifyOutputTypes(tfGraph, intermediateGraph); - - return intermediateGraph; - } - - private static void importSignatures(MetaGraphDef tfGraph, IntermediateGraph intermediateGraph) { - for (java.util.Map.Entry<String, SignatureDef> signatureEntry : tfGraph.getSignatureDefMap().entrySet()) { - String signatureName = signatureEntry.getKey(); - java.util.Map<String, TensorInfo> inputInfoMap = signatureEntry.getValue().getInputsMap(); - for (java.util.Map.Entry<String, TensorInfo> input : inputInfoMap.entrySet()) { - String inputName = input.getKey(); - String nodeName = input.getValue().getName(); - intermediateGraph.inputs(signatureName).put(inputName, IntermediateOperation.namePartOf(nodeName)); - } - java.util.Map<String, TensorInfo> outputInfoMap = signatureEntry.getValue().getOutputsMap(); - for (java.util.Map.Entry<String, TensorInfo> output : outputInfoMap.entrySet()) { - String outputName = output.getKey(); - String nodeName = output.getValue().getName(); - intermediateGraph.outputs(signatureName).put(outputName, IntermediateOperation.namePartOf(nodeName)); - } - } - } - - private static void importOperations(MetaGraphDef tfGraph, - 
IntermediateGraph intermediateGraph, - SavedModelBundle bundle) { - for (String signatureName : intermediateGraph.signatures()) { - for (String outputName : intermediateGraph.outputs(signatureName).values()) { - importOperation(outputName, tfGraph.getGraphDef(), intermediateGraph, bundle); - } - } - } - - private static IntermediateOperation importOperation(String nodeName, - GraphDef tfGraph, - IntermediateGraph intermediateGraph, - SavedModelBundle bundle) { - if (intermediateGraph.alreadyImported(nodeName)) { - return intermediateGraph.get(nodeName); - } - NodeDef node = getTensorFlowNodeFromGraph(IntermediateOperation.namePartOf(nodeName), tfGraph); - List<IntermediateOperation> inputs = importOperationInputs(node, tfGraph, intermediateGraph, bundle); - IntermediateOperation operation = mapOperation(node, inputs, intermediateGraph); - intermediateGraph.put(nodeName, operation); - - List<IntermediateOperation> controlInputs = importControlInputs(node, tfGraph, intermediateGraph, bundle); - if (controlInputs.size() > 0) { - operation.setControlInputs(controlInputs); - } - - if (operation.isConstant()) { - operation.setConstantValueFunction( - type -> new TensorValue(TensorConverter.toVespaTensor(readVariable(nodeName, bundle), type))); - } - - return operation; - } - - private static List<IntermediateOperation> importOperationInputs(NodeDef node, - GraphDef tfGraph, - IntermediateGraph intermediateGraph, - SavedModelBundle bundle) { - return node.getInputList().stream() - .filter(name -> ! 
isControlDependency(name)) - .map(nodeName -> importOperation(nodeName, tfGraph, intermediateGraph, bundle)) - .collect(Collectors.toList()); - } - - private static List<IntermediateOperation> importControlInputs(NodeDef node, - GraphDef tfGraph, - IntermediateGraph intermediateGraph, - SavedModelBundle bundle) { - return node.getInputList().stream() - .filter(nodeName -> isControlDependency(nodeName)) - .map(nodeName -> importOperation(nodeName, tfGraph, intermediateGraph, bundle)) - .collect(Collectors.toList()); - } - - private static boolean isControlDependency(String name) { - return name.startsWith("^"); - } - - private static NodeDef getTensorFlowNodeFromGraph(String name, GraphDef tfGraph) { - for (NodeDef node : tfGraph.getNodeList()) { - if (node.getName().equals(name)) { - return node; - } - } - throw new IllegalArgumentException("Could not find node '" + name + "'"); - } - - public static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) { - Session.Runner fetched = bundle.session().runner().fetch(name); - List<org.tensorflow.Tensor<?>> importedTensors = fetched.run(); - if (importedTensors.size() != 1) - throw new IllegalStateException("Expected 1 tensor from fetching " + name + - ", but got " + importedTensors.size()); - return importedTensors.get(0); - } - - private static void verifyOutputTypes(MetaGraphDef tfGraph, IntermediateGraph intermediateGraph) { - for (String signatureName : intermediateGraph.signatures()) { - for (String outputName : intermediateGraph.outputs(signatureName).values()) { - IntermediateOperation operation = intermediateGraph.get(outputName); - NodeDef node = getTensorFlowNodeFromGraph(IntermediateOperation.namePartOf(operation.name()), tfGraph.getGraphDef()); - OrderedTensorType type = operation.type().orElseThrow( - () -> new IllegalArgumentException("Output of '" + outputName + "' has no type.")); - TypeConverter.verifyType(node, type); - } - } - - } - -} diff --git 
a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java deleted file mode 100644 index 67ad1edc312..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow; - -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; -import com.yahoo.tensor.TensorType; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.NodeDef; -import org.tensorflow.framework.TensorShapeProto; - -import java.util.List; - -/** - * Converts and verifies TensorFlow tensor types into Vespa tensor types. 
- * - * @author lesters - */ -public class TypeConverter { - - public static void verifyType(NodeDef node, OrderedTensorType type) { - TensorShapeProto shape = tensorFlowShape(node); - if (shape != null) { - if (shape.getDimCount() != type.rank()) { - throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " + - "does not match Vespa shape"); - } - for (int tensorFlowIndex = 0; tensorFlowIndex < type.dimensions().size(); ++tensorFlowIndex) { - int vespaIndex = type.dimensionMap(tensorFlowIndex); - TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex); - TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex); - if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) { - throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " + - "does not match Vespa dimensions"); - } - } - } - } - - private static TensorShapeProto tensorFlowShape(NodeDef node) { - AttrValue attrValueList = node.getAttrMap().get("_output_shapes"); - if (attrValueList == null) { - throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + - "does not exist"); - } - if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) { - throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + - "is not of expected type"); - } - List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList(); - return shapeList.get(0); // support multiple outputs? - } - - public static OrderedTensorType fromTensorFlowType(NodeDef node) { - return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ... 
- } - - public static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); - TensorShapeProto shape = tensorFlowShape(node); - for (int i = 0; i < shape.getDimCount(); ++ i) { - String dimensionName = dimensionPrefix + i; - TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i); - if (tensorFlowDimension.getSize() >= 0) { - builder.add(TensorType.Dimension.indexed(dimensionName, tensorFlowDimension.getSize())); - } else { - builder.add(TensorType.Dimension.indexed(dimensionName)); - } - } - return builder.build(); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java new file mode 100644 index 00000000000..fa1f929cc80 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java @@ -0,0 +1,326 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +package com.yahoo.searchlib.rankingexpression.integration.onnx; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Constant; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Argument; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OperationMapper; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.yolean.Exceptions; +import onnx.Onnx; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.logging.Logger; +import java.util.stream.Collectors; + +/** + * Converts a ONNX model into a ranking expression and set of constants. 
+ * + * @author lesters + */ +public class OnnxImporter { + + private static final Logger log = Logger.getLogger(OnnxImporter.class.getName()); + + public OnnxModel importModel(String modelName, File modelDir) { + return importModel(modelName, modelDir.toString()); + } + + public OnnxModel importModel(String modelName, String modelPath) { + try (FileInputStream inputStream = new FileInputStream(modelPath)) { + Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); + return importModel(modelName, model); + } catch (IOException e) { + throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e); + } + } + + public OnnxModel importModel(String modelName, Onnx.ModelProto model) { + return importGraph(modelName, model.getGraph()); + } + + private static OnnxModel importGraph(String modelName, Onnx.GraphProto graph) { + OnnxModel model = new OnnxModel(modelName); + OperationIndex index = new OperationIndex(); + + importNodes(graph, model, index); + verifyOutputTypes(graph, model, index); + findDimensionNames(model, index); + importExpressions(model, index); + + reportWarnings(model, index); + + return model; + } + + private static void importNodes(Onnx.GraphProto graph, OnnxModel model, OperationIndex index) { + for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) { + importNode(valueInfo.getName(), graph, model, index); + } + } + + private static OnnxOperation importNode(String name, Onnx.GraphProto graph, OnnxModel model, OperationIndex index) { + if (index.alreadyImported(name)) { + return index.get(name); + } + OnnxOperation operation; + if (isArgumentTensor(name, graph)) { + operation = new Argument(getArgumentTensor(name, graph)); + model.input(OnnxOperation.namePartOf(name), operation.vespaName()); + } else if (isConstantTensor(name, graph)) { + operation = new Constant(model.name(), getConstantTensor(name, graph)); + } else { + Onnx.NodeProto node = getNodeFromGraph(name, graph); + List<OnnxOperation> inputs = 
importNodeInputs(node, graph, model, index); + operation = OperationMapper.get(node, inputs); + if (isOutputNode(name, graph)) { + model.output(OnnxOperation.namePartOf(name), operation.vespaName()); + } + } + index.put(operation.vespaName(), operation); + + return operation; + } + + private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) { + Onnx.ValueInfoProto value = getArgumentTensor(name, graph); + Onnx.TensorProto tensor = getConstantTensor(name, graph); + return value != null && tensor == null; + } + + private static boolean isConstantTensor(String name, Onnx.GraphProto graph) { + Onnx.ValueInfoProto value = getArgumentTensor(name, graph); + Onnx.TensorProto tensor = getConstantTensor(name, graph); + return value != null && tensor != null; + } + + private static Onnx.ValueInfoProto getArgumentTensor(String name, Onnx.GraphProto graph) { + for (Onnx.ValueInfoProto valueInfo : graph.getInputList()) { + if (valueInfo.getName().equals(name)) { + return valueInfo; + } + } + return null; + } + + private static Onnx.TensorProto getConstantTensor(String name, Onnx.GraphProto graph) { + for (Onnx.TensorProto tensorProto : graph.getInitializerList()) { + if (tensorProto.getName().equals(name)) { + return tensorProto; + } + } + return null; + } + + private static boolean isOutputNode(String name, Onnx.GraphProto graph) { + return getOutputNode(name, graph) != null; + } + + private static Onnx.ValueInfoProto getOutputNode(String name, Onnx.GraphProto graph) { + for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) { + if (valueInfo.getName().equals(name)) { + return valueInfo; + } + String nodeName = OnnxOperation.namePartOf(valueInfo.getName()); + if (nodeName.equals(name)) { + return valueInfo; + } + } + return null; + } + + private static List<OnnxOperation> importNodeInputs(Onnx.NodeProto node, + Onnx.GraphProto graph, + OnnxModel model, + OperationIndex index) { + return node.getInputList().stream() + .map(nodeName -> importNode(nodeName, 
graph, model, index)) + .collect(Collectors.toList()); + } + + private static void verifyOutputTypes(Onnx.GraphProto graph, OnnxModel model, OperationIndex index) { + for (String outputName : model.outputs().values()) { + OnnxOperation operation = index.get(outputName); + Onnx.ValueInfoProto onnxNode = getOutputNode(outputName, graph); + operation.type().orElseThrow( + () -> new IllegalArgumentException("Output of '" + outputName + "' has no type.")) + .verifyType(onnxNode.getType()); + } + } + + + /** Find dimension names to avoid excessive renaming while evaluating the model. */ + private static void findDimensionNames(OnnxModel model, OperationIndex index) { + DimensionRenamer renamer = new DimensionRenamer(); + for (String output : model.outputs().values()) { + addDimensionNameConstraints(index.get(output), renamer); + } + renamer.solve(); + for (String output : model.outputs().values()) { + renameDimensions(index.get(output), renamer); + } + } + + private static void addDimensionNameConstraints(OnnxOperation operation, DimensionRenamer renamer) { + if (operation.type().isPresent()) { + operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer)); + operation.addDimensionNameConstraints(renamer); + } + } + + private static void renameDimensions(OnnxOperation operation, DimensionRenamer renamer) { + if (operation.type().isPresent()) { + operation.inputs().forEach(input -> renameDimensions(input, renamer)); + operation.renameDimensions(renamer); + } + } + + private static void importExpressions(OnnxModel model, OperationIndex index) { + for (String outputName : model.outputs().values()) { + try { + Optional<TensorFunction> function = importExpression(index.get(outputName), model); + if (!function.isPresent()) { + model.skippedOutput(outputName, "No valid output function could be found."); + } + } + catch (IllegalArgumentException e) { + model.skippedOutput(outputName, Exceptions.toMessageString(e)); + } + } + } + + private static 
Optional<TensorFunction> importExpression(OnnxOperation operation, OnnxModel model) { + if (!operation.type().isPresent()) { + return Optional.empty(); + } + if (operation.isConstant()) { + return importConstant(operation, model); + } + importInputExpressions(operation, model); + importRankingExpression(operation, model); + importArgumentExpression(operation, model); + + return operation.function(); + } + + private static void importInputExpressions(OnnxOperation operation, OnnxModel model) { + operation.inputs().forEach(input -> importExpression(input, model)); + } + + private static Optional<TensorFunction> importConstant(OnnxOperation operation, OnnxModel model) { + String name = operation.vespaName(); + if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) { + return operation.function(); + } + + Value value = operation.getConstantValue().orElseThrow(() -> + new IllegalArgumentException("Operation '" + operation.vespaName() + "' " + + "is constant but does not have a value.")); + if ( ! (value instanceof TensorValue)) { + return operation.function(); // scalar values are inserted directly into the expression + } + + Tensor tensor = value.asTensor(); + if (tensor.type().rank() == 0) { + model.smallConstant(name, tensor); + } else { + model.largeConstant(name, tensor); + } + return operation.function(); + } + + private static void importRankingExpression(OnnxOperation operation, OnnxModel model) { + if (operation.function().isPresent()) { + String name = operation.vespaName(); + if (!model.expressions().containsKey(name)) { + TensorFunction function = operation.function().get(); + + if (model.outputs().containsKey(name)) { + OrderedTensorType operationType = operation.type().get(); + OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType); + if ( ! 
operationType.equals(standardNamingType)) { + List<String> renameFrom = operationType.dimensionNames(); + List<String> renameTo = standardNamingType.dimensionNames(); + function = new Rename(function, renameFrom, renameTo); + } + } + + try { + // We add all intermediate nodes imported as separate expressions. Only + // those referenced from the output will be used. We parse the + // TensorFunction here to convert it to a RankingExpression tree. + model.expression(name, new RankingExpression(name, function.toString())); + } + catch (ParseException e) { + throw new RuntimeException("Tensorflow function " + function + + " cannot be parsed as a ranking expression", e); + } + } + } + } + + private static void importArgumentExpression(OnnxOperation operation, OnnxModel model) { + if (operation.isInput()) { + // All inputs must have dimensions with standard naming convention: d0, d1, ... + OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get()); + model.argument(operation.vespaName(), standardNamingConvention.type()); + model.requiredMacro(operation.vespaName(), standardNamingConvention.type()); + } + } + + private static void reportWarnings(OnnxModel model, OperationIndex index) { + for (String output : model.outputs().values()) { + reportWarnings(model, index.get(output)); + } + } + + private static void reportWarnings(OnnxModel model, OnnxOperation operation) { + for (String warning : operation.warnings()) { + model.importWarning(warning); + } + for (OnnxOperation input : operation.inputs()) { + reportWarnings(model, input); + } + } + + private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) { + boolean hasPortNumber = nodeName.contains(":"); + for (Onnx.NodeProto node : graph.getNodeList()) { + if (hasPortNumber) { + for (String outputName : node.getOutputList()) { + if (outputName.equals(nodeName)) { + return node; + } + } + } else if (node.getName().equals(nodeName)) { + return node; + } + } 
+ throw new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph"); + } + + private static class OperationIndex { + private final Map<String, OnnxOperation> index = new HashMap<>(); + public OnnxOperation put(String key, OnnxOperation operation) { return index.put(key, operation); } + public OnnxOperation get(String key) { return index.get(key); } + public boolean alreadyImported(String key) { return index.containsKey(key); } + public Collection<OnnxOperation> operations() { return index.values(); } + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java new file mode 100644 index 00000000000..bd53afefc3f --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java @@ -0,0 +1,112 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.searchlib.rankingexpression.integration.onnx; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.regex.Pattern; + +/** + * The result of importing an ONNX model into Vespa. + * + * @author bratseth + * @author lesters + */ +public class OnnxModel { + + private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*"); + + private final String name; + + public OnnxModel(String name) { + if ( ! 
nameRegexp.matcher(name).matches()) + throw new IllegalArgumentException("A TensorFlow model name can only contain [A-Za-z0-9_], but is '" + + name + "'"); + this.name = name; + } + + /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */ + public String name() { return name; } + + private final Map<String, String> inputs = new HashMap<>(); + private final Map<String, String> outputs = new HashMap<>(); + private final Map<String, String> skippedOutputs = new HashMap<>(); + private final List<String> importWarnings = new ArrayList<>(); + + private final Map<String, TensorType> arguments = new HashMap<>(); + private final Map<String, Tensor> smallConstants = new HashMap<>(); + private final Map<String, Tensor> largeConstants = new HashMap<>(); + private final Map<String, RankingExpression> expressions = new HashMap<>(); + private final Map<String, RankingExpression> macros = new HashMap<>(); + private final Map<String, TensorType> requiredMacros = new HashMap<>(); + + void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); } + void output(String name, String expressionName) { outputs.put(name, expressionName); } + void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); } + void importWarning(String warning) { importWarnings.add(warning); } + void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } + void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); } + void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); } + void expression(String name, RankingExpression expression) { expressions.put(name, expression); } + void macro(String name, RankingExpression expression) { macros.put(name, expression); } + void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); } + + /** + * Returns an immutable map of the inputs (evaluation context) of this. 
This is a map from input name + * to argument (Placeholder) name in the owner of this + */ + public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); } + + /** Returns arguments().get(inputs.get(name)), e.g the type of the argument this input references */ + public TensorType inputArgument(String inputName) { return arguments().get(inputs.get(inputName)); } + + /** Returns an immutable list of the expression names of this */ + public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); } + + /** + * Returns an immutable list of the outputs of this which could not be imported, + * with a string detailing the reason for each + */ + public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); } + + /** + * Returns an immutable list of possibly non-fatal warnings encountered during import. + */ + public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); } + + /** Returns expressions().get(outputs.get(outputName)), e.g the expression this output references */ + public RankingExpression outputExpression(String outputName) { return expressions().get(outputs.get(outputName)); } + + /** Returns an immutable map of the arguments (inputs) of this */ + public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); } + + /** + * Returns an immutable map of the small constants of this. + */ + public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); } + + /** + * Returns an immutable map of the large constants of this. + */ + public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); } + + /** + * Returns an immutable map of the expressions of this - corresponding to ONNX nodes + * which are not inputs or constants. 
+ */ + public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } + + /** Returns an immutable map of macros that are part of this model */ + public Map<String, RankingExpression> macros() { return Collections.unmodifiableMap(macros); } + + /** Returns an immutable map of the macros that must be provided by the environment running this model */ + public Map<String, TensorType> requiredMacros() { return Collections.unmodifiableMap(requiredMacros); } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/DimensionRenamer.java index 38f1d2329e2..2524417cee0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/DimensionRenamer.java @@ -1,7 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer; +package com.yahoo.searchlib.rankingexpression.integration.onnx.importer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation; import java.util.ArrayDeque; import java.util.ArrayList; @@ -47,7 +47,7 @@ public class DimensionRenamer { /** * Add a constraint between dimension names. 
*/ - public void addConstraint(String from, String to, Constraint pred, IntermediateOperation operation) { + public void addConstraint(String from, String to, Constraint pred, OnnxOperation operation) { Arc arc = new Arc(from, to, operation); Arc opposite = arc.opposite(); constraints.put(arc, pred); @@ -175,9 +175,9 @@ public class DimensionRenamer { private final String from; private final String to; - private final IntermediateOperation operation; + private final OnnxOperation operation; - Arc(String from, String to, IntermediateOperation operation) { + Arc(String from, String to, OnnxOperation operation) { this.from = from; this.to = to; this.operation = operation; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java new file mode 100644 index 00000000000..12090145d3a --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java @@ -0,0 +1,26 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +package com.yahoo.searchlib.rankingexpression.integration.onnx.importer; + +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Join; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.MatMul; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.NoOp; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation; +import com.yahoo.tensor.functions.ScalarFunctions; +import onnx.Onnx; + +import java.util.List; + +public class OperationMapper { + + public static OnnxOperation get(Onnx.NodeProto node, List<OnnxOperation> inputs) { + switch (node.getOpType().toLowerCase()) { + case "add": return new Join(node, inputs, ScalarFunctions.add()); + case "matmul": return new MatMul(node, inputs); + } + + OnnxOperation op = new NoOp(node, inputs); + op.warning("Operation '" + node.getOpType() + "' is currently not implemented"); + return op; + } +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OrderedTensorType.java index 209d73a9f38..812e9b8d678 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorType.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OrderedTensorType.java @@ -1,9 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer; +package com.yahoo.searchlib.rankingexpression.integration.onnx.importer; import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.TensorTypeParser; +import onnx.Onnx; import java.util.ArrayList; import java.util.Collections; @@ -13,9 +13,9 @@ import java.util.stream.Collectors; /** * A Vespa tensor type is ordered by the lexicographical ordering of dimension - * names. Imported tensors have an explicit ordering of their dimensions. + * names. ONNX tensors have an explicit ordering of their dimensions. * During import, we need to track the Vespa dimension that matches the - * corresponding imported dimension as the ordering can change after + * corresponding ONNX dimension as the ordering can change after * dimension renaming. That is the purpose of this class. * * @author lesters @@ -25,14 +25,14 @@ public class OrderedTensorType { private final TensorType type; private final List<TensorType.Dimension> dimensions; - private final long[] innerSizesOriginal; + private final long[] innerSizesOnnx; private final long[] innerSizesVespa; private final int[] dimensionMap; private OrderedTensorType(List<TensorType.Dimension> dimensions) { this.dimensions = Collections.unmodifiableList(dimensions); this.type = new TensorType.Builder(dimensions).build(); - this.innerSizesOriginal = new long[dimensions.size()]; + this.innerSizesOnnx = new long[dimensions.size()]; this.innerSizesVespa = new long[dimensions.size()]; this.dimensionMap = createDimensionMap(); } @@ -54,10 +54,10 @@ public class OrderedTensorType { if (numDimensions == 0) { return null; } - innerSizesOriginal[numDimensions - 1] = 1; + innerSizesOnnx[numDimensions - 1] = 1; innerSizesVespa[numDimensions - 1] = 1; for (int i = numDimensions - 1; --i >= 0; ) { - innerSizesOriginal[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesOriginal[i+1]; + innerSizesOnnx[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesOnnx[i+1]; 
innerSizesVespa[i] = type.dimensions().get(i+1).size().orElse(-1L) * innerSizesVespa[i+1]; } int[] mapping = new int[numDimensions]; @@ -74,15 +74,11 @@ public class OrderedTensorType { return mapping; } - public int dimensionMap(int originalIndex) { - return dimensionMap[originalIndex]; - } - /** - * When dimension ordering between Vespa and imported differs, i.e. + * When dimension ordering between Vespa and Onnx differs, i.e. * after dimension renaming, use the dimension map to read in values * so that they are correctly laid out in memory for Vespa. - * Used when importing tensors. + * Used when importing tensors from Onnx. */ public int toDirectIndex(int index) { if (dimensions.size() == 0) { @@ -94,9 +90,9 @@ public class OrderedTensorType { int directIndex = 0; long rest = index; for (int i = 0; i < dimensions.size(); ++i) { - long address = rest / innerSizesOriginal[i]; + long address = rest / innerSizesOnnx[i]; directIndex += innerSizesVespa[dimensionMap[i]] * address; - rest %= innerSizesOriginal[i]; + rest %= innerSizesOnnx[i]; } return directIndex; } @@ -120,6 +116,22 @@ public class OrderedTensorType { return true; } + public void verifyType(Onnx.TypeProto typeProto) { + Onnx.TensorShapeProto shape = typeProto.getTensorType().getShape(); + if (shape != null) { + if (shape.getDimCount() != type.rank()) { + throw new IllegalArgumentException("Onnx shape of does not match Vespa shape"); + } + for (int onnxIndex = 0; onnxIndex < dimensions.size(); ++onnxIndex) { + int vespaIndex = dimensionMap[onnxIndex]; + Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(onnxIndex); + TensorType.Dimension vespaDimension = type().dimensions().get(vespaIndex); + if (onnxDimension.getDimValue() != vespaDimension.size().orElse(-1L)) { + throw new IllegalArgumentException("TensorFlow dimensions of does not match Vespa dimensions"); + } + } + } + } public OrderedTensorType rename(DimensionRenamer renamer) { List<TensorType.Dimension> renamedDimensions = new 
ArrayList<>(dimensions.size()); for (TensorType.Dimension dimension : dimensions) { @@ -139,13 +151,18 @@ public class OrderedTensorType { return new OrderedTensorType(renamedDimensions); } - public OrderedTensorType rename(String dimensionPrefix) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); - for (int i = 0; i < dimensions.size(); ++ i) { + public static OrderedTensorType fromOnnxType(Onnx.TypeProto type) { + return fromOnnxType(type, "d"); // standard naming convention: d0, d1, ... + } + + public static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) { + Onnx.TensorShapeProto shape = type.getTensorType().getShape(); + Builder builder = new Builder(shape); + for (int i = 0; i < shape.getDimCount(); ++ i) { String dimensionName = dimensionPrefix + i; - Optional<Long> dimSize = dimensions.get(i).size(); - if (dimSize.isPresent() && dimSize.get() >= 0) { - builder.add(TensorType.Dimension.indexed(dimensionName, dimSize.get())); + Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i); + if (onnxDimension.getDimValue() >= 0) { + builder.add(TensorType.Dimension.indexed(dimensionName, onnxDimension.getDimValue())); } else { builder.add(TensorType.Dimension.indexed(dimensionName)); } @@ -153,13 +170,13 @@ public class OrderedTensorType { return builder.build(); } - public static OrderedTensorType standardType(OrderedTensorType type) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); - for (int i = 0; i < type.dimensions().size(); ++ i) { - TensorType.Dimension dim = type.dimensions().get(i); - String dimensionName = "d" + i; - if (dim.size().isPresent() && dim.size().get() >= 0) { - builder.add(TensorType.Dimension.indexed(dimensionName, dim.size().get())); + public static OrderedTensorType fromOnnxType(List<Long> dims, String dimensionPrefix) { + Builder builder = new Builder(); + for (int i = 0; i < dims.size(); ++ i) { + String dimensionName = dimensionPrefix + i; + Long dimSize = 
dims.get(i); + if (dimSize >= 0) { + builder.add(TensorType.Dimension.indexed(dimensionName, dimSize)); } else { builder.add(TensorType.Dimension.indexed(dimensionName)); } @@ -167,46 +184,13 @@ public class OrderedTensorType { return builder.build(); } - public static Long tensorSize(TensorType type) { - Long size = 1L; - for (TensorType.Dimension dimension : type.dimensions()) { - size *= dimensionSize(dimension); - } - return size; - } - - public static Long dimensionSize(TensorType.Dimension dim) { - return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size")); - } - - /** - * Returns a string representation of this: A standard tensor type string where dimensions - * are listed in the order of this rather than in the natural order of their names. - */ - @Override - public String toString() { - return "tensor(" + dimensions.stream().map(TensorType.Dimension::toString).collect(Collectors.joining(",")) + ")"; - } - - /** - * Creates an instance from the string representation of this: A standard tensor type string - * where dimensions are listed in the order of this rather than the natural order of their names. - */ - public static OrderedTensorType fromSpec(String typeSpec) { - return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec(typeSpec)); - } - - public static OrderedTensorType fromDimensionList(List<Long> dims) { - return fromDimensionList(dims, "d"); // standard naming convention: d0, d1, ... 
- } - - public static OrderedTensorType fromDimensionList(List<Long> dims, String dimensionPrefix) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); - for (int i = 0; i < dims.size(); ++ i) { - String dimensionName = dimensionPrefix + i; - Long dimSize = dims.get(i); - if (dimSize >= 0) { - builder.add(TensorType.Dimension.indexed(dimensionName, dimSize)); + public static OrderedTensorType standardType(OrderedTensorType type) { + Builder builder = new Builder(); + for (int i = 0; i < type.dimensions().size(); ++ i) { + TensorType.Dimension dim = type.dimensions().get(i); + String dimensionName = "d" + i; + if (dim.size().isPresent() && dim.size().get() >= 0) { + builder.add(TensorType.Dimension.indexed(dimensionName, dim.size().get())); } else { builder.add(TensorType.Dimension.indexed(dimensionName)); } @@ -216,13 +200,45 @@ public class OrderedTensorType { public static class Builder { + private final Onnx.TensorShapeProto shape; private final List<TensorType.Dimension> dimensions; + public Builder(Onnx.TensorShapeProto shape) { + this.shape = shape; + this.dimensions = new ArrayList<>(shape.getDimCount()); + } + public Builder() { + this.shape = null; this.dimensions = new ArrayList<>(); } public Builder add(TensorType.Dimension vespaDimension) { + if (shape != null) { + int index = dimensions.size(); + Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(index); + long size = onnxDimension.getDimValue(); + if (size >= 0) { + if (vespaDimension.type() != TensorType.Dimension.Type.indexedBound) { + throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " + + "dimension types"); + } + if (!vespaDimension.size().isPresent()) { + throw new IllegalArgumentException("Tensor dimension is indexed bound but does " + + "not have a size"); + } + if (vespaDimension.size().get() != size) { + throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " + + "dimension sizes. 
TensorFlow: " + size + " Vespa: " + + vespaDimension.size().get()); + } + } else { + if (vespaDimension.type() != TensorType.Dimension.Type.indexedUnbound) { + throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " + + "dimension types"); + } + } + } this.dimensions.add(vespaDimension); return this; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/TensorConverter.java index 18856d4a25f..2912db03b5f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/TensorConverter.java @@ -1,16 +1,17 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx; +package com.yahoo.searchlib.rankingexpression.integration.onnx.importer; import com.google.protobuf.ByteString; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import onnx.Onnx; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.FloatBuffer; +import java.util.List; /** * Converts Onnx tensors into Vespa tensors. 
@@ -28,6 +29,7 @@ public class TensorConverter { return builder.build(); } + /* todo: support more types */ private static Values readValuesOf(Onnx.TensorProto tensorProto) { if (tensorProto.hasRawData()) { switch (tensorProto.getDataType()) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java new file mode 100644 index 00000000000..a8d8d63daf4 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java @@ -0,0 +1,64 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.VariableTensor; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.TensorFunction; +import onnx.Onnx; + +import java.util.Collections; +import java.util.List; + +public class Argument extends OnnxOperation { + + private Onnx.ValueInfoProto valueInfo; + private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ... 
+ + public Argument(Onnx.ValueInfoProto valueInfoProto) { + super(null, Collections.emptyList()); + valueInfo = valueInfoProto; + standardNamingType = OrderedTensorType.fromOnnxType(valueInfo.getType()); + } + + @Override + public String vespaName() { + return vespaName(valueInfo.getName()); + } + + @Override + protected OrderedTensorType lazyGetType() { + return OrderedTensorType.fromOnnxType(valueInfo.getType(), vespaName() + "_"); + } + + @Override + protected TensorFunction lazyGetFunction() { + TensorFunction output = new VariableTensor(vespaName(), standardNamingType.type()); + if (!standardNamingType.equals(type)) { + List<String> renameFrom = standardNamingType.dimensionNames(); + List<String> renameTo = type.dimensionNames(); + output = new Rename(output, renameFrom, renameTo); + } + return output; + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + for (TensorType.Dimension dimension : type.type().dimensions()) { + renamer.addDimension(dimension.name()); + } + } + + @Override + public boolean isInput() { + return true; + } + + @Override + public boolean isConstant() { + return false; + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java index 5e4abeaa234..13043a61a8e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java @@ -1,34 +1,38 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.TensorConverter; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; +import onnx.Onnx; import java.util.Collections; import java.util.Optional; -public class Constant extends IntermediateOperation { +public class Constant extends OnnxOperation { - private final String modelName; + final String modelName; + final Onnx.TensorProto tensorProto; - public Constant(String modelName, String nodeName, OrderedTensorType type) { - super(modelName, nodeName, Collections.emptyList()); + public Constant(String modelName, Onnx.TensorProto tensorProto) { + super(null, Collections.emptyList()); this.modelName = modelName; - this.type = type.rename(vespaName() + "_"); + this.tensorProto = tensorProto; } /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ @Override public String vespaName() { - return modelName + "_" + vespaName(name); + return modelName + "_" + vespaName(tensorProto.getName()); } @Override protected OrderedTensorType lazyGetType() { - return type; + return OrderedTensorType.fromOnnxType(tensorProto.getDimsList(), vespaName() + "_"); } @Override @@ -36,14 +40,9 @@ public class Constant extends IntermediateOperation { return null; // will be added by function() since this is 
constant. } - /** - * Constant values are sent in via the constantValueFunction, as the - * dimension names and thus the data layout depends on the dimension - * renaming which happens after the conversion to intermediate graph. - */ @Override public Optional<Value> getConstantValue() { - return Optional.ofNullable(constantValueFunction).map(func -> func.apply(type)); + return Optional.of(new TensorValue(TensorConverter.toVespaTensor(tensorProto, type))); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Join.java index 8413ed74118..fe2004a528d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Join.java @@ -1,22 +1,24 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Reduce; import com.yahoo.tensor.functions.TensorFunction; +import onnx.Onnx; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.function.DoubleBinaryOperator; -public class Join extends IntermediateOperation { +public class Join extends OnnxOperation { private final DoubleBinaryOperator operator; - public Join(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleBinaryOperator operator) { - super(modelName, nodeName, inputs); + public Join(Onnx.NodeProto node, List<OnnxOperation> inputs, DoubleBinaryOperator operator) { + super(node, inputs); this.operator = operator; } @@ -59,8 +61,8 @@ public class Join extends IntermediateOperation { return null; } - IntermediateOperation a = largestInput(); - IntermediateOperation b = smallestInput(); + OnnxOperation a = largestInput(); + OnnxOperation b = smallestInput(); List<String> aDimensionsToReduce = new ArrayList<>(); List<String> bDimensionsToReduce = new ArrayList<>(); @@ -105,13 +107,13 @@ public class Join extends IntermediateOperation { } } - private IntermediateOperation largestInput() { + private OnnxOperation largestInput() { OrderedTensorType a = inputs.get(0).type().get(); OrderedTensorType b = inputs.get(1).type().get(); return a.rank() >= b.rank() ? 
inputs.get(0) : inputs.get(1); } - private IntermediateOperation smallestInput() { + private OnnxOperation smallestInput() { OrderedTensorType a = inputs.get(0).type().get(); OrderedTensorType b = inputs.get(1).type().get(); return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/MatMul.java index 52e223f9518..1b388e2ae89 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/MatMul.java @@ -1,18 +1,21 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; +import onnx.Onnx; +import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.function.DoubleBinaryOperator; -public class MatMul extends IntermediateOperation { +public class MatMul extends OnnxOperation { - public MatMul(String modelName, String nodeName, List<IntermediateOperation> inputs) { - super(modelName, nodeName, inputs); + public MatMul(Onnx.NodeProto node, List<OnnxOperation> inputs) { + super(node, 
inputs); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java new file mode 100644 index 00000000000..b1136a0ce0a --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java @@ -0,0 +1,32 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; +import onnx.Onnx; + +import java.util.Collections; +import java.util.List; + +public class NoOp extends OnnxOperation { + + public NoOp(Onnx.NodeProto node, List<OnnxOperation> inputs) { + super(node, Collections.emptyList()); // don't propagate inputs + } + + @Override + protected OrderedTensorType lazyGetType() { + return null; + } + + @Override + protected TensorFunction lazyGetFunction() { + return null; + } + + @Override + public boolean isConstant() { + return true; + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java new file mode 100644 index 00000000000..30f7b4f4711 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java @@ -0,0 +1,139 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations; + +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.functions.TensorFunction; +import onnx.Onnx; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.function.Function; + +/** + * Wraps an ONNX node and produces the respective Vespa tensor operation. + * During import, a graph of these operations are constructed. Then, the + * types are used to deduce sensible dimension names using the + * DimensionRenamer. After the types have been renamed, the proper + * Vespa expressions can be extracted. 
+ * + * @author lesters + */ +public abstract class OnnxOperation { + + protected final Onnx.NodeProto node; // can be null for onnx inputs and constants + protected final List<OnnxOperation> inputs; + protected final List<OnnxOperation> outputs = new ArrayList<>(); + protected final List<String> importWarnings = new ArrayList<>(); + + protected OrderedTensorType type; + protected TensorFunction function; + protected Value constantValue = null; + + OnnxOperation(Onnx.NodeProto node, List<OnnxOperation> inputs) { + this.node = node; + this.inputs = Collections.unmodifiableList(inputs); + this.inputs.forEach(i -> i.outputs.add(this)); + } + + protected abstract OrderedTensorType lazyGetType(); + protected abstract TensorFunction lazyGetFunction(); + + /** Returns the Vespa tensor type of this operation if it exists */ + public Optional<OrderedTensorType> type() { + if (type == null) { + type = lazyGetType(); + } + return Optional.ofNullable(type); + } + + /** Returns the Vespa tensor function implementing all operations from this node with inputs */ + public Optional<TensorFunction> function() { + if (function == null) { + if (isConstant()) { + ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName())); + function = new TensorFunctionNode.TensorFunctionExpressionNode(constant); + } else { + function = lazyGetFunction(); + } + } + return Optional.ofNullable(function); + } + + /** Return Onnx node */ + public Onnx.NodeProto node() { return node; } + + /** Return unmodifiable list of inputs */ + public List<OnnxOperation> inputs() { return inputs; } + + /** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a macro. 
*/ + public List<OnnxOperation> outputs() { return Collections.unmodifiableList(outputs); } + + /** Add dimension name constraints for this operation */ + public void addDimensionNameConstraints(DimensionRenamer renamer) { } + + /** Performs dimension rename for this operation */ + public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); } + + /** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */ + public boolean isInput() { return false; } + + /** Return true if this node is constant */ + public boolean isConstant() { return inputs.stream().allMatch(OnnxOperation::isConstant); } + + /** Gets the constant value if it exists */ + public Optional<Value> getConstantValue() { return Optional.ofNullable(constantValue); } + + /** Retrieve the valid Vespa name of this node */ + public String vespaName() { return vespaName(node.getName()); } + public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; } + + /** Retrieve the list of warnings produced during its lifetime */ + public List<String> warnings() { return Collections.unmodifiableList(importWarnings); } + + /** Set an input warning */ + public void warning(String warning) { importWarnings.add(warning); } + + boolean verifyInputs(int expected, Function<OnnxOperation, Optional<?>> func) { + if (inputs.size() != expected) { + throw new IllegalArgumentException("Expected " + expected + " inputs " + + "for '" + node.getName() + "', got " + inputs.size()); + } + return inputs.stream().map(func).allMatch(Optional::isPresent); + } + + boolean allInputTypesPresent(int expected) { + return verifyInputs(expected, OnnxOperation::type); + } + + boolean allInputFunctionsPresent(int expected) { + return verifyInputs(expected, OnnxOperation::function); + } + + /** + * A method signature input and output has the form name:index. + * This returns the name part without the index. 
+ */ + public static String namePartOf(String name) { + name = name.startsWith("^") ? name.substring(1) : name; + return name.split(":")[0]; + } + + /** + * This return the output index part. Indexes are used for nodes with + * multiple outputs. + */ + public static int indexPartOf(String name) { + int i = name.indexOf(":"); + return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1)); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/package-info.java index 1530754cc43..5cff8b03d40 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/package-info.java @@ -3,6 +3,6 @@ * ONNX integration */ @ExportPackage -package com.yahoo.searchlib.rankingexpression.integration.ml; +package com.yahoo.searchlib.rankingexpression.integration.onnx; import com.yahoo.osgi.annotation.ExportPackage; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java new file mode 100644 index 00000000000..e3c72830095 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -0,0 +1,411 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OperationMapper; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Const; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Variable; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.yolean.Exceptions; +import org.tensorflow.SavedModelBundle; +import org.tensorflow.Session; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.MetaGraphDef; +import org.tensorflow.framework.NodeDef; +import org.tensorflow.framework.SignatureDef; +import org.tensorflow.framework.TensorInfo; + +import java.io.File; +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.logging.Logger; +import java.util.stream.Collectors; + +/** + * Converts a saved TensorFlow model into a ranking expression and set of constants. 
+ * + * @author bratseth + * @author lesters + */ +public class TensorFlowImporter { + + private static final Logger log = Logger.getLogger(TensorFlowImporter.class.getName()); + + /** + * Imports a saved TensorFlow model from a directory. + * The model should be saved as a .pbtxt or .pb file. + * The name of the model is taken as the db/pbtxt file name (not including the file ending). + * + * @param modelName the name of the model to import, consisting of characters in [A-Za-z0-9_] + * @param modelDir the directory containing the TensorFlow model files to import + */ + public TensorFlowModel importModel(String modelName, String modelDir) { + try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { + + return importModel(modelName, model); + } + catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e); + } + } + + public TensorFlowModel importModel(String modelName, File modelDir) { + return importModel(modelName, modelDir.toString()); + } + + /** Imports a TensorFlow model */ + public TensorFlowModel importModel(String modelName, SavedModelBundle model) { + try { + return importGraph(modelName, MetaGraphDef.parseFrom(model.metaGraphDef()), model); + } + catch (IOException e) { + throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e); + } + } + + /** + * Imports the TensorFlow graph by first importing the tensor types, then + * finding a suitable set of dimensions names for each + * placeholder/constant/variable, then importing the expressions. 
+ */ + private static TensorFlowModel importGraph(String modelName, MetaGraphDef graph, SavedModelBundle bundle) { + TensorFlowModel model = new TensorFlowModel(modelName); + OperationIndex index = new OperationIndex(); + + importSignatures(graph, model); + importNodes(graph, model, index); + findDimensionNames(model, index); + importExpressions(model, index, bundle); + + reportWarnings(model, index); + logVariableTypes(index); + + return model; + } + + private static void importSignatures(MetaGraphDef graph, TensorFlowModel model) { + for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) { + String signatureName = signatureEntry.getKey(); + TensorFlowModel.Signature signature = model.signature(signatureName); + + Map<String, TensorInfo> inputInfoMap = signatureEntry.getValue().getInputsMap(); + for (Map.Entry<String, TensorInfo> input : inputInfoMap.entrySet()) { + String inputName = input.getKey(); + signature.input(inputName, namePartOf(input.getValue().getName())); + } + + Map<String, TensorInfo> outputInfoMap = signatureEntry.getValue().getOutputsMap(); + for (Map.Entry<String, TensorInfo> output : outputInfoMap.entrySet()) { + String outputName = output.getKey(); + signature.output(outputName, namePartOf(output.getValue().getName())); + } + } + } + + private static boolean isSignatureInput(TensorFlowModel model, TensorFlowOperation operation) { + for (TensorFlowModel.Signature signature : model.signatures().values()) { + for (String inputName : signature.inputs().values()) { + if (inputName.equals(operation.node().getName())) { + return true; + } + } + } + return false; + } + + private static boolean isSignatureOutput(TensorFlowModel model, TensorFlowOperation operation) { + for (TensorFlowModel.Signature signature : model.signatures().values()) { + for (String outputName : signature.outputs().values()) { + if (outputName.equals(operation.node().getName())) { + return true; + } + } + } + return false; + } + + private 
static void importNodes(MetaGraphDef graph, TensorFlowModel model, OperationIndex index) { + for (TensorFlowModel.Signature signature : model.signatures().values()) { + for (String outputName : signature.outputs().values()) { + importNode(model.name(), outputName, graph.getGraphDef(), index); + } + } + } + + private static TensorFlowOperation importNode(String modelName, String nodeName, GraphDef graph, OperationIndex index) { + if (index.alreadyImported(nodeName)) { + return index.get(nodeName); + } + NodeDef node = getTensorFlowNodeFromGraph(namePartOf(nodeName), graph); + List<TensorFlowOperation> inputs = importNodeInputs(modelName, node, graph, index); + TensorFlowOperation operation = OperationMapper.get(modelName, node, inputs, portPartOf(nodeName)); + index.put(nodeName, operation); + + List<TensorFlowOperation> controlInputs = importControlInputs(modelName, node, graph, index); + if (controlInputs.size() > 0) { + operation.setControlInputs(controlInputs); + } + + return operation; + } + + private static List<TensorFlowOperation> importNodeInputs(String modelName, NodeDef node, GraphDef graph, OperationIndex index) { + return node.getInputList().stream() + .filter(name -> ! isControlDependency(name)) + .map(nodeName -> importNode(modelName, nodeName, graph, index)) + .collect(Collectors.toList()); + } + + private static List<TensorFlowOperation> importControlInputs(String modelName, NodeDef node, GraphDef graph, OperationIndex index) { + return node.getInputList().stream() + .filter(nodeName -> isControlDependency(nodeName)) + .map(nodeName -> importNode(modelName, nodeName, graph, index)) + .collect(Collectors.toList()); + } + + private static boolean isControlDependency(String name) { + return name.startsWith("^"); + } + + /** Find dimension names to avoid excessive renaming while evaluating the model. 
*/ + private static void findDimensionNames(TensorFlowModel model, OperationIndex index) { + DimensionRenamer renamer = new DimensionRenamer(); + for (TensorFlowModel.Signature signature : model.signatures().values()) { + for (String output : signature.outputs().values()) { + addDimensionNameConstraints(index.get(output), renamer); + } + } + renamer.solve(); + for (TensorFlowModel.Signature signature : model.signatures().values()) { + for (String output : signature.outputs().values()) { + renameDimensions(index.get(output), renamer); + } + } + } + + private static void addDimensionNameConstraints(TensorFlowOperation operation, DimensionRenamer renamer) { + if (operation.type().isPresent()) { + operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer)); + operation.addDimensionNameConstraints(renamer); + } + } + + private static void renameDimensions(TensorFlowOperation operation, DimensionRenamer renamer) { + if (operation.type().isPresent()) { + operation.inputs().forEach(input -> renameDimensions(input, renamer)); + operation.renameDimensions(renamer); + } + } + + private static void importExpressions(TensorFlowModel model, OperationIndex index, SavedModelBundle bundle) { + for (TensorFlowModel.Signature signature : model.signatures().values()) { + for (String outputName : signature.outputs().values()) { + try { + Optional<TensorFunction> function = importExpression(index.get(outputName), model, bundle); + if (!function.isPresent()) { + signature.skippedOutput(outputName, "No valid output function could be found."); + } + } + catch (IllegalArgumentException e) { + signature.skippedOutput(outputName, Exceptions.toMessageString(e)); + } + } + } + } + + private static Optional<TensorFunction> importExpression(TensorFlowOperation operation, TensorFlowModel model, SavedModelBundle bundle) { + if (!operation.type().isPresent()) { + return Optional.empty(); + } + if (operation.isConstant()) { + return importConstant(model, operation, bundle); + } + 
+ importInputExpressions(operation, model, bundle); + importRankingExpression(model, operation); + importInputExpression(model, operation); + importMacroExpression(model, operation); + + return operation.function(); + } + + private static void importInputExpressions(TensorFlowOperation operation, TensorFlowModel model, + SavedModelBundle bundle) { + operation.inputs().forEach(input -> importExpression(input, model, bundle)); + } + + private static void importMacroExpression(TensorFlowModel model, TensorFlowOperation operation) { + if (operation.macro().isPresent()) { + TensorFunction function = operation.macro().get(); + try { + model.macro(operation.macroName(), new RankingExpression(operation.macroName(), function.toString())); + } + catch (ParseException e) { + throw new RuntimeException("Tensorflow function " + function + + " cannot be parsed as a ranking expression", e); + } + } + } + + private static Optional<TensorFunction> importConstant(TensorFlowModel model, TensorFlowOperation operation, + SavedModelBundle bundle) { + String name = operation.vespaName(); + if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) { + return operation.function(); + } + + Tensor tensor; + if (operation.getConstantValue().isPresent()) { + Value value = operation.getConstantValue().get(); + if ( ! 
(value instanceof TensorValue)) { + return operation.function(); // scalar values are inserted directly into the expression + } + tensor = value.asTensor(); + } else { + // Here we use the type from the operation, which will have correct dimension names after name resolving + tensor = TensorConverter.toVespaTensor(readVariable(operation.node().getName(), bundle), + operation.type().get()); + operation.setConstantValue(new TensorValue(tensor)); + } + + if (tensor.type().rank() == 0) { + model.smallConstant(name, tensor); + } else { + model.largeConstant(name, tensor); + } + return operation.function(); + } + + static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) { + Session.Runner fetched = bundle.session().runner().fetch(name); + List<org.tensorflow.Tensor<?>> importedTensors = fetched.run(); + if (importedTensors.size() != 1) + throw new IllegalStateException("Expected 1 tensor from fetching " + name + + ", but got " + importedTensors.size()); + return importedTensors.get(0); + } + + private static void importRankingExpression(TensorFlowModel model, TensorFlowOperation operation) { + if (operation.function().isPresent()) { + String name = operation.node().getName(); + if (!model.expressions().containsKey(operation.node().getName())) { + TensorFunction function = operation.function().get(); + + // Make sure output adheres to standard naming convention + if (isSignatureOutput(model, operation)) { + OrderedTensorType operationType = operation.type().get(); + OrderedTensorType standardNamingType = OrderedTensorType.fromTensorFlowType(operation.node()); + if ( ! operationType.equals(standardNamingType)) { + List<String> renameFrom = operationType.dimensionNames(); + List<String> renameTo = standardNamingType.dimensionNames(); + function = new Rename(function, renameFrom, renameTo); + } + } + + try { + // We add all intermediate nodes imported as separate expressions. Only + // those referenced in a signature output will be used. 
We parse the + // TensorFunction here to convert it to a RankingExpression tree. + model.expression(name, new RankingExpression(name, function.toString())); + } + catch (ParseException e) { + throw new RuntimeException("Tensorflow function " + function + + " cannot be parsed as a ranking expression", e); + } + } + } + } + + private static void importInputExpression(TensorFlowModel model, TensorFlowOperation operation) { + if (operation.isInput() && isSignatureInput(model, operation)) { + // All inputs must have dimensions with standard naming convention: d0, d1, ... + OrderedTensorType standardNamingConvention = OrderedTensorType.fromTensorFlowType(operation.node()); + model.argument(operation.node().getName(), standardNamingConvention.type()); + model.requiredMacro(operation.vespaName(), standardNamingConvention.type()); + } + } + + private static void reportWarnings(TensorFlowModel model, OperationIndex index) { + for (TensorFlowModel.Signature signature : model.signatures().values()) { + for (String output : signature.outputs().values()) { + reportWarnings(index.get(output), signature); + } + } + } + + /** + * Log all TensorFlow Variables (i.e file constants) imported as part of this with their ordered type. + * This allows users to learn the exact types (including dimension order after renaming) of the Variables + * such that these can be converted and fed to a parent document independently of the rest of the model + * for fast model weight updates. + */ + private static void logVariableTypes(OperationIndex index) { + for (TensorFlowOperation operation : index.operations()) { + if ( ! (operation instanceof Variable)) continue; + if ( ! 
operation.type().isPresent()) continue; // will not happen + + log.info("Importing TensorFlow variable " + operation.node().getName() + " as " + operation.vespaName() + + " of type " + operation.type().get()); + } + } + + private static void reportWarnings(TensorFlowOperation operation, TensorFlowModel.Signature signature) { + for (String warning : operation.warnings()) { + signature.importWarning(warning); + } + for (TensorFlowOperation input : operation.inputs()) { + reportWarnings(input, signature); + } + } + + private static NodeDef getTensorFlowNodeFromGraph(String name, GraphDef graph) { + for (NodeDef node : graph.getNodeList()) { + if (node.getName().equals(name)) { + return node; + } + } + throw new IllegalArgumentException("Could not find node '" + name + "'"); + } + + /** + * A method signature input and output has the form name:index. + * This returns the name part without the index. + */ + private static String namePartOf(String name) { + name = name.startsWith("^") ? name.substring(1) : name; + return name.split(":")[0]; + } + + /** + * This return the output port part. Indexes are used for nodes with + * multiple outputs. + */ + private static int portPartOf(String name) { + int i = name.indexOf(":"); + return i < 0 ? 
0 : Integer.parseInt(name.substring(i + 1)); + } + + private static class OperationIndex { + + private final Map<String, TensorFlowOperation> index = new HashMap<>(); + public TensorFlowOperation put(String key, TensorFlowOperation operation) { return index.put(key, operation); } + public TensorFlowOperation get(String key) { return index.get(key); } + public boolean alreadyImported(String key) { return index.containsKey(key); } + public Collection<TensorFlowOperation> operations() { return index.values(); } + + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java index 4b49f17f74e..721214f9e94 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java @@ -1,4 +1,5 @@ -package com.yahoo.searchlib.rankingexpression.integration.ml; +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; @@ -12,61 +13,76 @@ import java.util.Map; import java.util.regex.Pattern; /** - * The result of importing a model (TensorFlow or ONNX) into Vespa. + * The result of importing a TensorFlow model into Vespa. + * - A set of signatures which are named collections of inputs and outputs. + * - A set of named constant tensors represented by Variable nodes in TensorFlow. + * - A list of warning messages. 
* * @author bratseth */ -public class ImportedModel { - - private static final String defaultSignatureName = "default"; +// This object can be built incrementally within this package, but is immutable when observed from outside the package +public class TensorFlowModel { private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*"); - private final String name; - private final Map<String, Signature> signatures = new HashMap<>(); - private final Map<String, TensorType> arguments = new HashMap<>(); - private final Map<String, Tensor> smallConstants = new HashMap<>(); - private final Map<String, Tensor> largeConstants = new HashMap<>(); - private final Map<String, RankingExpression> expressions = new HashMap<>(); - private final Map<String, RankingExpression> macros = new HashMap<>(); - private final Map<String, TensorType> requiredMacros = new HashMap<>(); + private final String name; /** - * Creates a new imported model. + * Creates a TensorFlow model * * @param name the name of this mode, containing only characters in [A-Za-z0-9_] */ - public ImportedModel(String name) { + public TensorFlowModel(String name) { if ( ! 
nameRegexp.matcher(name).matches()) - throw new IllegalArgumentException("An imported model name can only contain [A-Za-z0-9_], but is '" + - name + "'"); + throw new IllegalArgumentException("A TensorFlow model name can only contain [A-Za-z0-9_], but is '" + + name + "'"); this.name = name; } /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */ public String name() { return name; } + private final Map<String, Signature> signatures = new HashMap<>(); + private final Map<String, TensorType> arguments = new HashMap<>(); + private final Map<String, Tensor> smallConstants = new HashMap<>(); + private final Map<String, Tensor> largeConstants = new HashMap<>(); + private final Map<String, RankingExpression> expressions = new HashMap<>(); + private final Map<String, RankingExpression> macros = new HashMap<>(); + private final Map<String, TensorType> requiredMacros = new HashMap<>(); + + void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } + void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); } + void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); } + void expression(String name, RankingExpression expression) { expressions.put(name, expression); } + void macro(String name, RankingExpression expression) { macros.put(name, expression); } + void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); } + + /** Returns the given signature. If it does not already exist it is added to this. */ + Signature signature(String name) { + return signatures.computeIfAbsent(name, Signature::new); + } + /** Returns an immutable map of the arguments ("Placeholders") of this */ public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); } /** * Returns an immutable map of the small constants of this. 
* These should have sizes up to a few kb at most, and correspond to constant - * values given in the TensorFlow or ONNX source. + * values given in the TensorFlow source. */ public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); } /** * Returns an immutable map of the large constants of this. - * These can have sizes in gigabytes and must be distributed to nodes separately from configuration. - * For TensorFlow this corresponds to Variable files stored separately. + * These can have sizes in gigabytes and must be distributed to nodes separately from configuration, + * and correspond to Variable files stored separately in TensorFlow. */ public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); } /** - * Returns an immutable map of the expressions of this - corresponding to graph nodes - * which are not Inputs/Placeholders or Variables (which instead become respectively arguments and constants). - * Note that only nodes recursively referenced by a placeholder/input are added. + * Returns an immutable map of the expressions of this - corresponding to TensorFlow nodes + * which are not Placeholders or Variables (which instead become respectively arguments and constants). + * Note that only nodes recursively referenced by a placeholder are added. */ public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } @@ -79,26 +95,9 @@ public class ImportedModel { /** Returns an immutable map of the signatures of this */ public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); } - /** Returns the given signature. If it does not already exist it is added to this. 
*/ - Signature signature(String name) { - return signatures.computeIfAbsent(name, Signature::new); - } - - /** Convenience method for returning a default signature */ - Signature defaultSignature() { return signature(defaultSignatureName); } - - void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } - void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); } - void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); } - void expression(String name, RankingExpression expression) { expressions.put(name, expression); } - void macro(String name, RankingExpression expression) { macros.put(name, expression); } - void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); } - /** - * A signature is a set of named inputs and outputs, where the inputs maps to argument - * ("placeholder") names+types, and outputs maps to expressions nodes. - * Note that TensorFlow supports multiple signatures in their format, but ONNX has no explicit - * concept of signatures. For now, we handle ONNX models as having a single signature. + * A signature is a set of named inputs and outputs, where the inputs maps to argument ("placeholder") names+types, + * and outputs maps to expressions nodes. 
*/ public class Signature { @@ -108,14 +107,19 @@ public class ImportedModel { private final Map<String, String> skippedOutputs = new HashMap<>(); private final List<String> importWarnings = new ArrayList<>(); - public Signature(String name) { + Signature(String name) { this.name = name; } + void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); } + void output(String name, String expressionName) { outputs.put(name, expressionName); } + void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); } + void importWarning(String warning) { importWarnings.add(warning); } + public String name() { return name; } /** Returns the result this is part of */ - public ImportedModel owner() { return ImportedModel.this; } + TensorFlowModel owner() { return TensorFlowModel.this; } /** * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name @@ -123,7 +127,7 @@ public class ImportedModel { */ public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); } - /** Returns the type of the argument this input references */ + /** Returns owner().arguments().get(inputs.get(name)), e.g the type of the argument this input references */ public TensorType inputArgument(String inputName) { return owner().arguments().get(inputs.get(inputName)); } /** Returns an immutable list of the expression names of this */ @@ -140,17 +144,12 @@ public class ImportedModel { */ public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); } - /** Returns the expression this output references */ + /** Returns owner().expressions().get(outputs.get(outputName)), e.g the expression this output references */ public RankingExpression outputExpression(String outputName) { return owner().expressions().get(outputs.get(outputName)); } @Override public String toString() { return "signature '" + name + "'"; } - void input(String inputName, String argumentName) { 
inputs.put(inputName, argumentName); } - void output(String name, String expressionName) { outputs.put(name, expressionName); } - void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); } - void importWarning(String warning) { importWarnings.add(warning); } - } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverter.java index e1294ec3e01..c5ac7ace0fc 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverter.java @@ -1,8 +1,7 @@ -package com.yahoo.searchlib.rankingexpression.integration.ml; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.GraphImporter; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; import com.yahoo.tensor.serialization.JsonFormat; import com.yahoo.yolean.Exceptions; import org.tensorflow.SavedModelBundle; @@ -25,7 +24,7 @@ public class VariableConverter { */ public static byte[] importVariable(String modelDir, String tensorFlowVariableName, String orderedTypeSpec) { try (SavedModelBundle bundle = SavedModelBundle.load(modelDir, "serve")) { - return JsonFormat.encode(TensorConverter.toVespaTensor(GraphImporter.readVariable(tensorFlowVariableName, + return JsonFormat.encode(TensorConverter.toVespaTensor(TensorFlowImporter.readVariable(tensorFlowVariableName, bundle), 
OrderedTensorType.fromSpec(orderedTypeSpec))); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java new file mode 100644 index 00000000000..c1665d066a4 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java @@ -0,0 +1,210 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Deque; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +/** + * A constraint satisfier to find suitable dimension names to reduce the + * amount of necessary renaming during evaluation of an imported model. + * + * @author lesters + */ +public class DimensionRenamer { + + private final String dimensionPrefix; + private final Map<String, List<Integer>> variables = new HashMap<>(); + private final Map<Arc, Constraint> constraints = new HashMap<>(); + private final Map<String, Integer> renames = new HashMap<>(); + + private int iterations = 0; + + public DimensionRenamer() { + this("d"); + } + + public DimensionRenamer(String dimensionPrefix) { + this.dimensionPrefix = dimensionPrefix; + } + + /** + * Add a dimension name variable. + */ + public void addDimension(String name) { + variables.computeIfAbsent(name, d -> new ArrayList<>()); + } + + /** + * Add a constraint between dimension names. 
+ */ + public void addConstraint(String from, String to, Constraint pred, TensorFlowOperation operation) { + Arc arc = new Arc(from, to, operation); + Arc opposite = arc.opposite(); + constraints.put(arc, pred); + constraints.put(opposite, (x,y) -> pred.test(y, x)); // make constraint graph symmetric + } + + /** + * Retrieve resulting name of dimension after solving for constraints. + */ + public Optional<String> dimensionNameOf(String name) { + if (!renames.containsKey(name)) { + return Optional.empty(); + } + return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name))); + } + + /** + * Perform iterative arc consistency until we have found a solution. After + * an initial iteration, the variables (dimensions) will have multiple + * valid values. Find a single valid assignment by iteratively locking one + * dimension after another, and running the arc consistency algorithm + * multiple times. + * + * This requires having constraints that result in an absolute ordering: + * equals, lesserThan and greaterThan do that, but adding notEquals does + * not typically result in a guaranteed ordering. If that is needed, the + * algorithm below needs to be adapted with a backtracking (tree) search + * to find solutions. 
+ */ + public void solve(int maxIterations) { + initialize(); + + // Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts + + for (String dimension : variables.keySet()) { + List<Integer> values = variables.get(dimension); + if (values.size() > 1) { + if (!ac3()) { + throw new IllegalArgumentException("Dimension renamer unable to find a solution."); + } + values.sort(Integer::compare); + variables.put(dimension, Collections.singletonList(values.get(0))); + } + renames.put(dimension, variables.get(dimension).get(0)); + if (iterations > maxIterations) { + throw new IllegalArgumentException("Dimension renamer unable to find a solution within " + + maxIterations + " iterations"); + } + } + + // Todo: handle failure more gracefully: + // If a solution can't be found, look at the operation node in the arc + // with the most remaining constraints, and inject a rename operation. + // Then run this algorithm again. + } + + public void solve() { + solve(100000); + } + + private void initialize() { + for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) { + List<Integer> values = variable.getValue(); + for (int i = 0; i < variables.size(); ++i) { + values.add(i); // invariant: values are in increasing order + } + } + } + + private boolean ac3() { + Deque<Arc> workList = new ArrayDeque<>(constraints.keySet()); + while (!workList.isEmpty()) { + Arc arc = workList.pop(); + iterations += 1; + if (revise(arc)) { + if (variables.get(arc.from).size() == 0) { + return false; // no solution found + } + for (Arc constraint : constraints.keySet()) { + if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) { + workList.add(constraint); + } + } + } + } + return true; + } + + private boolean revise(Arc arc) { + boolean revised = false; + for(Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) { + Integer from = fromIterator.next(); + boolean satisfied = false; + for 
(Iterator<Integer> toIterator = variables.get(arc.to).iterator(); toIterator.hasNext(); ) { + Integer to = toIterator.next(); + if (constraints.get(arc).test(from, to)) { + satisfied = true; + } + } + if (!satisfied) { + fromIterator.remove(); + revised = true; + } + } + return revised; + } + + public interface Constraint { + boolean test(Integer x, Integer y); + } + + public static boolean equals(Integer x, Integer y) { + return Objects.equals(x, y); + } + + public static boolean lesserThan(Integer x, Integer y) { + return x < y; + } + + public static boolean greaterThan(Integer x, Integer y) { + return x > y; + } + + private static class Arc { + + private final String from; + private final String to; + private final TensorFlowOperation operation; + + Arc(String from, String to, TensorFlowOperation operation) { + this.from = from; + this.to = to; + this.operation = operation; + } + + Arc opposite() { + return new Arc(to, from, operation); + } + + @Override + public int hashCode() { + return Objects.hash(from, to); + } + + @Override + public boolean equals(Object obj) { + if (obj == null || !(obj instanceof Arc)) { + return false; + } + Arc other = (Arc) obj; + return Objects.equals(from, other.from) && Objects.equals(to, other.to); + } + + @Override + public String toString() { + return String.format("%s -> %s", from, to); + } + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java new file mode 100644 index 00000000000..b665413a6b2 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java @@ -0,0 +1,97 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.ConcatV2; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Const; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.ExpandDims; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Identity; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Join; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Map; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Matmul; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Mean; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Merge; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.NoOp; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Placeholder; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.PlaceholderWithDefault; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Reshape; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Select; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Shape; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Squeeze; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Switch; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Variable; +import 
com.yahoo.tensor.functions.ScalarFunctions; +import org.tensorflow.framework.NodeDef; + +import java.util.List; + +/** + * Maps from TensorFlow operations to Vespa operations. + * + * @author bratseth + * @author lesters + */ +public class OperationMapper { + + public static TensorFlowOperation get(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { + switch (node.getOp().toLowerCase()) { + // array ops + case "concatv2": return new ConcatV2(modelName, node, inputs, port); + case "const": return new Const(modelName, node, inputs, port); + case "expanddims": return new ExpandDims(modelName, node, inputs, port); + case "identity": return new Identity(modelName, node, inputs, port); + case "placeholder": return new Placeholder(modelName, node, inputs, port); + case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, node, inputs, port); + case "reshape": return new Reshape(modelName, node, inputs, port); + case "shape": return new Shape(modelName, node, inputs, port); + case "squeeze": return new Squeeze(modelName, node, inputs, port); + + // control flow + case "merge": return new Merge(modelName, node, inputs, port); + case "switch": return new Switch(modelName, node, inputs, port); + + // math ops + case "add": return new Join(modelName, node, inputs, port, ScalarFunctions.add()); + case "add_n": return new Join(modelName, node, inputs, port, ScalarFunctions.add()); + case "acos": return new Map(modelName, node, inputs, port, ScalarFunctions.acos()); + case "div": return new Join(modelName, node, inputs, port, ScalarFunctions.divide()); + case "realdiv": return new Join(modelName, node, inputs, port, ScalarFunctions.divide()); + case "floor": return new Map(modelName, node, inputs, port, ScalarFunctions.floor()); + case "matmul": return new Matmul(modelName, node, inputs, port); + case "maximum": return new Join(modelName, node, inputs, port, ScalarFunctions.max()); + case "mean": return new Mean(modelName, node, inputs, 
port); + case "reducemean": return new Mean(modelName, node, inputs, port); + case "mul": return new Join(modelName, node, inputs, port, ScalarFunctions.multiply()); + case "multiply": return new Join(modelName, node, inputs, port, ScalarFunctions.multiply()); + case "rsqrt": return new Map(modelName, node, inputs, port, ScalarFunctions.rsqrt()); + case "select": return new Select(modelName, node, inputs, port); + case "where3": return new Select(modelName, node, inputs, port); + case "sigmoid": return new Map(modelName, node, inputs, port, ScalarFunctions.sigmoid()); + case "squareddifference": return new Join(modelName, node, inputs, port, ScalarFunctions.squareddifference()); + case "sub": return new Join(modelName, node, inputs, port, ScalarFunctions.subtract()); + case "subtract": return new Join(modelName, node, inputs, port, ScalarFunctions.subtract()); + + // nn ops + case "biasadd": return new Join(modelName, node, inputs, port, ScalarFunctions.add()); + case "elu": return new Map(modelName, node, inputs, port, ScalarFunctions.elu()); + case "relu": return new Map(modelName, node, inputs, port, ScalarFunctions.relu()); + case "selu": return new Map(modelName, node, inputs, port, ScalarFunctions.selu()); + + // state ops + case "variable": return new Variable(modelName, node, inputs, port); + case "variablev2": return new Variable(modelName, node, inputs, port); + + // evaluation no-ops + case "stopgradient":return new Identity(modelName, node, inputs, port); + case "noop": return new NoOp(modelName, node, inputs, port); + } + + TensorFlowOperation op = new NoOp(modelName, node, inputs, port); + op.warning("Operation '" + node.getOp() + "' is currently not implemented"); + return op; + } + +} + + + diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java new file mode 
100644 index 00000000000..03a65333192 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java @@ -0,0 +1,255 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer; + +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.TensorTypeParser; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.NodeDef; +import org.tensorflow.framework.TensorShapeProto; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * A Vespa tensor type is ordered by the lexicographical ordering of dimension + * names. TensorFlow tensors have an explicit ordering of their dimensions. + * During import, we need to track the Vespa dimension that matches the + * corresponding TensorFlow dimension as the ordering can change after + * dimension renaming. That is the purpose of this class. 
+ * + * @author lesters + */ +public class OrderedTensorType { + + private final TensorType type; + private final List<TensorType.Dimension> dimensions; + + private final long[] innerSizesTensorFlow; + private final long[] innerSizesVespa; + private final int[] dimensionMap; + + private OrderedTensorType(List<TensorType.Dimension> dimensions) { + this.dimensions = Collections.unmodifiableList(dimensions); + this.type = new TensorType.Builder(dimensions).build(); + this.innerSizesTensorFlow = new long[dimensions.size()]; + this.innerSizesVespa = new long[dimensions.size()]; + this.dimensionMap = createDimensionMap(); + } + + public TensorType type() { + return this.type; + } + + public int rank() { return dimensions.size(); } + + public List<TensorType.Dimension> dimensions() { + return dimensions; + } + + public List<String> dimensionNames() { + return dimensions.stream().map(TensorType.Dimension::name).collect(Collectors.toList()); + } + + private int[] createDimensionMap() { + int numDimensions = dimensions.size(); + if (numDimensions == 0) { + return null; + } + innerSizesTensorFlow[numDimensions - 1] = 1; + innerSizesVespa[numDimensions - 1] = 1; + for (int i = numDimensions - 1; --i >= 0; ) { + innerSizesTensorFlow[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesTensorFlow[i+1]; + innerSizesVespa[i] = type.dimensions().get(i+1).size().orElse(-1L) * innerSizesVespa[i+1]; + } + int[] mapping = new int[numDimensions]; + for (int i = 0; i < numDimensions; ++i) { + TensorType.Dimension dim1 = dimensions().get(i); + for (int j = 0; j < numDimensions; ++j) { + TensorType.Dimension dim2 = type.dimensions().get(j); + if (dim1.equals(dim2)) { + mapping[i] = j; + break; + } + } + } + return mapping; + } + + /** + * When dimension ordering between Vespa and TensorFlow differs, i.e. + * after dimension renaming, use the dimension map to read in values + * so that they are correctly laid out in memory for Vespa. + * Used when importing tensors from TensorFlow. 
+ */ + public int toDirectIndex(int index) { + if (dimensions.size() == 0) { + return 0; + } + if (dimensionMap == null) { + throw new IllegalArgumentException("Dimension map is not available"); + } + int directIndex = 0; + long rest = index; + for (int i = 0; i < dimensions.size(); ++i) { + long address = rest / innerSizesTensorFlow[i]; + directIndex += innerSizesVespa[dimensionMap[i]] * address; + rest %= innerSizesTensorFlow[i]; + } + return directIndex; + } + + @Override + public boolean equals(Object obj) { + if (obj == null || !(obj instanceof OrderedTensorType)) { + return false; + } + OrderedTensorType other = (OrderedTensorType) obj; + if (dimensions.size() != dimensions.size()) { + return false; + } + List<TensorType.Dimension> thisDimensions = this.dimensions(); + List<TensorType.Dimension> otherDimensions = other.dimensions(); + for (int i = 0; i < thisDimensions.size(); ++i) { + if (!thisDimensions.get(i).equals(otherDimensions.get(i))) { + return false; + } + } + return true; + } + + public void verifyType(NodeDef node) { + TensorShapeProto shape = tensorFlowShape(node); + if (shape != null) { + if (shape.getDimCount() != type.rank()) { + throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " + + "does not match Vespa shape"); + } + for (int tensorFlowIndex = 0; tensorFlowIndex < dimensions.size(); ++tensorFlowIndex) { + int vespaIndex = dimensionMap[tensorFlowIndex]; + TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex); + TensorType.Dimension vespaDimension = type().dimensions().get(vespaIndex); + if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) { + throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " + + "does not match Vespa dimensions"); + } + } + } + } + + private static TensorShapeProto tensorFlowShape(NodeDef node) { + AttrValue attrValueList = node.getAttrMap().get("_output_shapes"); + if (attrValueList == null) { + throw new 
IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + + "does not exist"); + } + if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) { + throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + + "is not of expected type"); + } + List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList(); + return shapeList.get(0); // support multiple outputs? + } + + public OrderedTensorType rename(DimensionRenamer renamer) { + List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size()); + for (TensorType.Dimension dimension : dimensions) { + String oldName = dimension.name(); + Optional<String> newName = renamer.dimensionNameOf(oldName); + if (!newName.isPresent()) + return this; // presumably, already renamed + TensorType.Dimension.Type dimensionType = dimension.type(); + if (dimensionType == TensorType.Dimension.Type.indexedBound) { + renamedDimensions.add(TensorType.Dimension.indexed(newName.get(), dimension.size().get())); + } else if (dimensionType == TensorType.Dimension.Type.indexedUnbound) { + renamedDimensions.add(TensorType.Dimension.indexed(newName.get())); + } else if (dimensionType == TensorType.Dimension.Type.mapped) { + renamedDimensions.add(TensorType.Dimension.mapped(newName.get())); + } + } + return new OrderedTensorType(renamedDimensions); + } + + /** + * Returns a string representation of this: A standard tensor type string where dimensions + * are listed in the order of this rather than in the natural order of their names. + */ + @Override + public String toString() { + return "tensor(" + dimensions.stream().map(TensorType.Dimension::toString).collect(Collectors.joining(",")) + ")"; + } + + /** + * Creates an instance from the string representation of this: A standard tensor type string + * where dimensions are listed in the order of this rather than the natural order of their names. 
+ */ + public static OrderedTensorType fromSpec(String typeSpec) { + return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec(typeSpec)); + } + + public static OrderedTensorType fromTensorFlowType(NodeDef node) { + return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ... + } + + public static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) { + Builder builder = new Builder(node); + TensorShapeProto shape = tensorFlowShape(node); + for (int i = 0; i < shape.getDimCount(); ++ i) { + String dimensionName = dimensionPrefix + i; + TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i); + if (tensorFlowDimension.getSize() >= 0) { + builder.add(TensorType.Dimension.indexed(dimensionName, tensorFlowDimension.getSize())); + } else { + builder.add(TensorType.Dimension.indexed(dimensionName)); + } + } + return builder.build(); + } + + public static class Builder { + + private final TensorShapeProto shape; + private final List<TensorType.Dimension> dimensions; + + public Builder(NodeDef node) { + this.shape = tensorFlowShape(node); + this.dimensions = new ArrayList<>(shape.getDimCount()); + } + + public Builder add(TensorType.Dimension vespaDimension) { + int index = dimensions.size(); + TensorShapeProto.Dim tensorFlowDimension = shape.getDim(index); + long size = tensorFlowDimension.getSize(); + if (size >= 0) { + if (vespaDimension.type() != TensorType.Dimension.Type.indexedBound) { + throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " + + "dimension types"); + } + if (!vespaDimension.size().isPresent()) { + throw new IllegalArgumentException("Tensor dimension is indexed bound but does " + + "not have a size"); + } + if (vespaDimension.size().get() != size) { + throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " + + "dimension sizes. 
TensorFlow: " + size + " Vespa: " + + vespaDimension.size().get()); + } + } else { + if (vespaDimension.type() != TensorType.Dimension.Type.indexedUnbound) { + throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " + + "dimension types"); + } + } + this.dimensions.add(vespaDimension); + return this; + } + + public OrderedTensorType build() { + return new OrderedTensorType(dimensions); + } + + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java index d2d0acfc964..3f55e622fdf 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java @@ -1,7 +1,6 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ConcatV2.java index 1b8c62fe0e9..4f5d61d75f9 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ConcatV2.java @@ -1,37 +1,38 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; import java.util.List; import java.util.Optional; -public class ConcatV2 extends IntermediateOperation { +public class ConcatV2 extends TensorFlowOperation { private String concatDimensionName; - public ConcatV2(String modelName, String nodeName, List<IntermediateOperation> inputs) { - super(modelName, nodeName, inputs); + public ConcatV2(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(modelName, node, inputs, port); } @Override protected OrderedTensorType lazyGetType() { - if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) { + if (!inputs.stream().map(TensorFlowOperation::type).allMatch(Optional::isPresent)) { return null; } - IntermediateOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input + TensorFlowOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input if (!concatDimOp.getConstantValue().isPresent()) { - throw new IllegalArgumentException("ConcatV2 in " + name + ": " + + throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " + "concat dimension must be a constant."); } Tensor concatDimTensor = concatDimOp.getConstantValue().get().asTensor(); if (concatDimTensor.type().rank() != 0) { - 
throw new IllegalArgumentException("ConcatV2 in " + name + ": " + + throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " + "concat dimension must be a scalar."); } @@ -43,7 +44,7 @@ public class ConcatV2 extends IntermediateOperation { for (int i = 1; i < inputs.size() - 1; ++i) { OrderedTensorType bType = inputs.get(i).type().get(); if (bType.rank() != aType.rank()) { - throw new IllegalArgumentException("ConcatV2 in " + name + ": " + + throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " + "inputs must have save rank."); } for (int j = 0; j < aType.rank(); ++j) { @@ -52,13 +53,13 @@ public class ConcatV2 extends IntermediateOperation { if (j == concatDim) { concatDimSize += dimSizeB; } else if (dimSizeA != dimSizeB) { - throw new IllegalArgumentException("ConcatV2 in " + name + ": " + + throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " + "input dimension " + j + " differs in input tensors."); } } } - OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node); int dimensionIndex = 0; for (TensorType.Dimension dimension : aType.dimensions()) { if (dimensionIndex == concatDim) { @@ -74,7 +75,7 @@ public class ConcatV2 extends IntermediateOperation { @Override protected TensorFunction lazyGetFunction() { - if (!inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent)) { + if (!inputs.stream().map(TensorFlowOperation::function).allMatch(Optional::isPresent)) { return null; } TensorFunction result = inputs.get(0).function().get(); @@ -87,7 +88,7 @@ public class ConcatV2 extends IntermediateOperation { @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { - if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) { + if (!inputs.stream().map(TensorFlowOperation::type).allMatch(Optional::isPresent)) { return; } OrderedTensorType a = 
inputs.get(0).type().get(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java index 3c0f8569c47..718e2a4b3c2 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java @@ -1,38 +1,36 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; 
+import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.NodeDef; import java.util.List; import java.util.Optional; -public class Const extends IntermediateOperation { +public class Const extends TensorFlowOperation { - private final AttributeMap attributeMap; - - public Const(String modelName, - String nodeName, - List<IntermediateOperation> inputs, - AttributeMap attributeMap, - OrderedTensorType type) { - super(modelName, nodeName, inputs); - this.attributeMap = attributeMap; - this.type = type.rename(vespaName() + "_"); + public Const(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(modelName, node, inputs, port); setConstantValue(value()); } @Override protected OrderedTensorType lazyGetType() { - return type; + return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_"); } @Override @@ -57,7 +55,7 @@ public class Const extends IntermediateOperation { /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ @Override public String vespaName() { - return modelName + "_" + super.vespaName(); + return modelName() + "_" + super.vespaName(); } @Override @@ -79,11 +77,24 @@ public class Const extends IntermediateOperation { } private Value value() { - Optional<Value> value = attributeMap.get("value", type); - if ( ! value.isPresent()) { - throw new IllegalArgumentException("Node '" + name + "' of type " + - "const has missing or non-recognized 'value' attribute"); + if ( ! 
node.getAttrMap().containsKey("value")) { + throw new IllegalArgumentException("Node '" + node.getName() + "' of type " + + "const has missing 'value' attribute"); + } + AttrValue attrValue = node.getAttrMap().get("value"); + if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) { + return new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), type().get().type())); + } + if (attrValue.getValueCase() == AttrValue.ValueCase.B) { + return new BooleanValue(attrValue.getB()); + } + if (attrValue.getValueCase() == AttrValue.ValueCase.I) { + return new DoubleValue(attrValue.getI()); + } + if (attrValue.getValueCase() == AttrValue.ValueCase.F) { + return new DoubleValue(attrValue.getF()); } - return value.get(); + throw new IllegalArgumentException("Requesting value of constant in " + + node.getName() + " but type is not recognized."); } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java index 742ed8b89ab..2d0f4c7042b 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java @@ -1,9 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; @@ -12,17 +12,18 @@ import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Generate; import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; import java.util.ArrayList; import java.util.List; import java.util.Optional; -public class ExpandDims extends IntermediateOperation { +public class ExpandDims extends TensorFlowOperation { private List<String> expandDimensions; - public ExpandDims(String modelName, String nodeName, List<IntermediateOperation> inputs) { - super(modelName, nodeName, inputs); + public ExpandDims(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(modelName, node, inputs, port); } @Override @@ -31,14 +32,14 @@ public class ExpandDims extends IntermediateOperation { return null; } - IntermediateOperation axisOperation = inputs().get(1); + TensorFlowOperation axisOperation = inputs().get(1); if (!axisOperation.getConstantValue().isPresent()) { - throw new IllegalArgumentException("ExpandDims in " + name + ": " + + throw new IllegalArgumentException("ExpandDims in " + node.getName() + ": " + "axis must be a 
constant."); } Tensor axis = axisOperation.getConstantValue().get().asTensor(); if (axis.type().rank() != 0) { - throw new IllegalArgumentException("ExpandDims in " + name + ": " + + throw new IllegalArgumentException("ExpandDims in " + node.getName() + ": " + "axis argument must be a scalar."); } @@ -48,7 +49,7 @@ public class ExpandDims extends IntermediateOperation { dimensionToInsert = inputType.dimensions().size() - dimensionToInsert; } - OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node); expandDimensions = new ArrayList<>(); int dimensionIndex = 0; for (TensorType.Dimension dimension : inputType.dimensions()) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java index d29bd4b7a9e..1408e7e04f0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java @@ -1,21 +1,22 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; import java.util.List; -public class Identity extends IntermediateOperation { +public class Identity extends TensorFlowOperation { - public Identity(String modelName, String nodeName, List<IntermediateOperation> inputs) { - super(modelName, nodeName, inputs); + public Identity(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(modelName, node, inputs, port); } /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ @Override public String vespaName() { - return modelName + "_" + super.vespaName(); + return modelName() + "_" + super.vespaName(); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java new file mode 100644 index 00000000000..6cbfe0dfb05 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java @@ -0,0 +1,145 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.function.DoubleBinaryOperator; + +public class Join extends TensorFlowOperation { + + private final DoubleBinaryOperator operator; + + public Join(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port, DoubleBinaryOperator operator) { + super(modelName, node, inputs, port); + this.operator = operator; + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) { + return null; + } + OrderedTensorType a = largestInput().type().get(); + OrderedTensorType b = smallestInput().type().get(); + + // Well now we have potentially entered the wonderful world of "broadcasting" + // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html + // In broadcasting, the size of each dimension is compared element-wise, + // starting with the trailing dimensions and working forward. A special + // case occurs when the size of one dimension is 1, while the other is not. + // Then the dimension with size 1 is "stretched" to be of compatible size. 
+ // + // An example: + // + // Tensor A: d0[5], d1[1], d2[3], d3[1] + // Tensor B: d1[4], d2[1], d3[2] + // + // In TensorFlow and using the above rules of broadcasting, the resulting + // type is: + // d0[5], d1[4], d2[3], d3[2] + // + // However, in Vespa's tensor logic, the join of the two above tensors would + // result in a tensor of type: + // d0[5], d1[1], d2[1], d3[1] + // + // By reducing the dimensions of size 1 in each tensor before joining, + // we get equal results as in TensorFlow. + + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node); + int sizeDifference = a.rank() - b.rank(); + for (int i = 0; i < a.rank(); ++i) { + TensorType.Dimension aDim = a.dimensions().get(i); + long size = aDim.size().orElse(-1L); + + if (i - sizeDifference >= 0) { + TensorType.Dimension bDim = b.dimensions().get(i - sizeDifference); + size = Math.max(size, bDim.size().orElse(-1L)); + } + + if (aDim.type() == TensorType.Dimension.Type.indexedBound) { + builder.add(TensorType.Dimension.indexed(aDim.name(), size)); + } else if (aDim.type() == TensorType.Dimension.Type.indexedUnbound) { + builder.add(TensorType.Dimension.indexed(aDim.name())); + } else if (aDim.type() == TensorType.Dimension.Type.mapped) { + builder.add(TensorType.Dimension.mapped(aDim.name())); + } + } + return builder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputTypesPresent(2)) { + return null; + } + if (!allInputFunctionsPresent(2)) { + return null; + } + + TensorFlowOperation a = largestInput(); + TensorFlowOperation b = smallestInput(); + + List<String> aDimensionsToReduce = new ArrayList<>(); + List<String> bDimensionsToReduce = new ArrayList<>(); + int sizeDifference = a.type().get().rank() - b.type().get().rank(); + for (int i = 0; i < b.type().get().rank(); ++i) { + TensorType.Dimension bDim = b.type().get().dimensions().get(i); + TensorType.Dimension aDim = a.type().get().dimensions().get(i + sizeDifference); + long bSize = 
bDim.size().orElse(-1L); + long aSize = aDim.size().orElse(-1L); + if (bSize == 1L && aSize != 1L) { + bDimensionsToReduce.add(bDim.name()); + } + if (aSize == 1L && bSize != 1L) { + aDimensionsToReduce.add(bDim.name()); + } + } + + TensorFunction aReducedFunction = a.function().get(); + if (aDimensionsToReduce.size() > 0) { + aReducedFunction = new Reduce(a.function().get(), Reduce.Aggregator.sum, aDimensionsToReduce); + } + TensorFunction bReducedFunction = b.function().get(); + if (bDimensionsToReduce.size() > 0) { + bReducedFunction = new Reduce(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce); + } + + return new com.yahoo.tensor.functions.Join(aReducedFunction, bReducedFunction, operator); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + if (!allInputTypesPresent(2)) { + return; + } + OrderedTensorType a = largestInput().type().get(); + OrderedTensorType b = smallestInput().type().get(); + int sizeDifference = a.rank() - b.rank(); + for (int i = 0; i < b.rank(); ++i) { + String bDim = b.dimensions().get(i).name(); + String aDim = a.dimensions().get(i + sizeDifference).name(); + renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this); + } + } + + private TensorFlowOperation largestInput() { + OrderedTensorType a = inputs.get(0).type().get(); + OrderedTensorType b = inputs.get(1).type().get(); + return a.rank() >= b.rank() ? inputs.get(0) : inputs.get(1); + } + + private TensorFlowOperation smallestInput() { + OrderedTensorType a = inputs.get(0).type().get(); + OrderedTensorType b = inputs.get(1).type().get(); + return a.rank() < b.rank() ? 
inputs.get(0) : inputs.get(1); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java index f54ae83052f..c015f5ecba8 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java @@ -1,19 +1,20 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; import java.util.List; import java.util.Optional; import java.util.function.DoubleUnaryOperator; -public class Map extends IntermediateOperation { +public class Map extends TensorFlowOperation { private final DoubleUnaryOperator operator; - public Map(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleUnaryOperator operator) { - super(modelName, nodeName, inputs); + public Map(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port, DoubleUnaryOperator operator) { + super(modelName, node, inputs, port); this.operator = operator; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java new file mode 100644 index 
00000000000..b2b9530a161 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java @@ -0,0 +1,74 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.List; +import java.util.Optional; + +public class Matmul extends TensorFlowOperation { + + public Matmul(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(modelName, node, inputs, port); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) { + return null; + } + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node); + typeBuilder.add(inputs.get(0).type().get().dimensions().get(0)); + typeBuilder.add(inputs.get(1).type().get().dimensions().get(1)); + return typeBuilder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputTypesPresent(2)) { + return null; + } + OrderedTensorType aType = inputs.get(0).type().get(); + OrderedTensorType bType = inputs.get(1).type().get(); + if (aType.type().rank() < 2 || bType.type().rank() < 2) + throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2"); + if (aType.type().rank() != bType.type().rank()) + throw new IllegalArgumentException("Tensors in matmul must have the same rank"); + + Optional<TensorFunction> aFunction = inputs.get(0).function(); + Optional<TensorFunction> bFunction = inputs.get(1).function(); + if (!aFunction.isPresent() || 
!bFunction.isPresent()) { + return null; + } + return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name()); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + if (!allInputTypesPresent(2)) { + return; + } + List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions(); + List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions(); + + String aDim0 = aDimensions.get(0).name(); + String aDim1 = aDimensions.get(1).name(); + String bDim0 = bDimensions.get(0).name(); + String bDim1 = bDimensions.get(1).name(); + + // The second dimension of a should have the same name as the first dimension of b + renamer.addConstraint(aDim1, bDim0, DimensionRenamer::equals, this); + + // The first dimension of a should have a different name than the second dimension of b + renamer.addConstraint(aDim0, bDim1, DimensionRenamer::lesserThan, this); + + // For efficiency, the dimensions to join over should be innermost - soft constraint + renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this); + renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java index 95a77c07590..3eba872c6a0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java @@ -1,10 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; @@ -14,20 +13,20 @@ import com.yahoo.tensor.functions.Generate; import com.yahoo.tensor.functions.Reduce; import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.NodeDef; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Optional; -public class Mean extends IntermediateOperation { +public class Mean extends TensorFlowOperation { - private final AttributeMap attributeMap; private List<String> reduceDimensions; - public Mean(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { - super(modelName, nodeName, inputs); - this.attributeMap = attributeMap; + public Mean(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(modelName, node, inputs, port); } @Override @@ -35,9 +34,9 @@ public class Mean extends IntermediateOperation { if (!allInputTypesPresent(2)) { return null; } - IntermediateOperation reductionIndices = inputs.get(1); + TensorFlowOperation 
reductionIndices = inputs.get(1); if (!reductionIndices.getConstantValue().isPresent()) { - throw new IllegalArgumentException("Mean in " + name + ": " + + throw new IllegalArgumentException("Mean in " + node.getName() + ": " + "reduction indices must be a constant."); } Tensor indices = reductionIndices.getConstantValue().get().asTensor(); @@ -55,7 +54,7 @@ public class Mean extends IntermediateOperation { return reducedType(inputType, shouldKeepDimensions()); } - // optimization: if keepDims and one reduce dimension that has size 1: same as identity. + // todo: optimization: if keepDims and one reduce dimension that has size 1: same as identity. @Override protected TensorFunction lazyGetFunction() { @@ -94,12 +93,12 @@ public class Mean extends IntermediateOperation { } private boolean shouldKeepDimensions() { - Optional<Value> keepDims = attributeMap.get("keep_dims"); - return keepDims.isPresent() && keepDims.get().asBoolean(); + AttrValue keepDimsAttr = node.getAttrMap().get("keep_dims"); + return keepDimsAttr != null && keepDimsAttr.getB(); } private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node); for (TensorType.Dimension dimension: inputType.type().dimensions()) { if (!reduceDimensions.contains(dimension.name())) { builder.add(dimension); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java index 9d9eca47b1c..4c95e67e184 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java @@ -1,20 +1,21 @@ // Copyright 2018 
Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; import java.util.List; -public class Merge extends IntermediateOperation { +public class Merge extends TensorFlowOperation { - public Merge(String modelName, String nodeName, List<IntermediateOperation> inputs) { - super(modelName, nodeName, inputs); + public Merge(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(modelName, node, inputs, port); } @Override protected OrderedTensorType lazyGetType() { - for (IntermediateOperation operation : inputs) { + for (TensorFlowOperation operation : inputs) { if (operation.type().isPresent()) { return operation.type().get(); } @@ -24,7 +25,7 @@ public class Merge extends IntermediateOperation { @Override protected TensorFunction lazyGetFunction() { - for (IntermediateOperation operation : inputs) { + for (TensorFlowOperation operation : inputs) { if (operation.function().isPresent()) { return operation.function().get(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java new file mode 100644 index 00000000000..d558ec89e87 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java @@ -0,0 +1,32 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. 
See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.Collections; +import java.util.List; + +public class NoOp extends TensorFlowOperation { + + public NoOp(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(modelName, node, Collections.emptyList(), port); // don't propagate inputs + } + + @Override + protected OrderedTensorType lazyGetType() { + return null; + } + + @Override + protected TensorFunction lazyGetFunction() { + return null; + } + + @Override + public boolean isConstant() { + return true; + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java index 7fc2aae87d1..1619c11427a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java @@ -1,29 +1,28 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.VariableTensor; import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; -import java.util.Collections; import java.util.List; -public class Argument extends IntermediateOperation { +public class Placeholder extends TensorFlowOperation { private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ... 
- public Argument(String modelName, String nodeName, OrderedTensorType type) { - super(modelName, nodeName, Collections.emptyList()); - this.type = type.rename(vespaName() + "_"); - standardNamingType = OrderedTensorType.standardType(type); + public Placeholder(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(modelName, node, inputs, port); + standardNamingType = OrderedTensorType.fromTensorFlowType(node); } @Override protected OrderedTensorType lazyGetType() { - return type; + return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_"); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java index 9299ae9be12..65ce7f00e34 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java @@ -1,16 +1,17 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; import java.util.List; import java.util.Optional; -public class PlaceholderWithDefault extends IntermediateOperation { +public class PlaceholderWithDefault extends TensorFlowOperation { - public PlaceholderWithDefault(String modelName, String nodeName, List<IntermediateOperation> inputs) { - super(modelName, nodeName, inputs); + public PlaceholderWithDefault(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(modelName, node, inputs, port); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java index e91c2305f7d..e7d90e5fc1f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java @@ -1,9 +1,10 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; @@ -18,18 +19,19 @@ import com.yahoo.tensor.functions.Generate; import com.yahoo.tensor.functions.Reduce; import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.stream.Collectors; -import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.tensorSize; +import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.tensorSize; -public class Reshape extends IntermediateOperation { +public class Reshape extends TensorFlowOperation { - public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) { - super(modelName, nodeName, inputs); + public Reshape(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(modelName, node, inputs, port); } @Override @@ -37,15 +39,15 @@ public class Reshape extends IntermediateOperation { if (!allInputTypesPresent(2)) { return null; } - 
IntermediateOperation newShape = inputs.get(1); + TensorFlowOperation newShape = inputs.get(1); if (!newShape.getConstantValue().isPresent()) { - throw new IllegalArgumentException("Reshape in " + name + ": " + + throw new IllegalArgumentException("Reshape in " + node.getName() + ": " + "shape input must be a constant."); } Tensor shape = newShape.getConstantValue().get().asTensor(); OrderedTensorType inputType = inputs.get(0).type().get(); - OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(node); int dimensionIndex = 0; for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) { Tensor.Cell cell = cellIterator.next(); @@ -122,7 +124,7 @@ public class Reshape extends IntermediateOperation { operators.add(0, ArithmeticOperator.MULTIPLY); children.add(0, new ConstantNode(new DoubleValue(size))); } - size *= OrderedTensorType.dimensionSize(dimension); + size *= TensorConverter.dimensionSize(dimension); if (i > 0) { operators.add(0, ArithmeticOperator.PLUS); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java index 927a4a368f9..5fdcb5a695f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java @@ -1,23 +1,24 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; import java.util.List; import java.util.function.DoubleBinaryOperator; -import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.dimensionSize; -import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.tensorSize; +import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.dimensionSize; +import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.tensorSize; -public class Select extends IntermediateOperation { +public class Select extends TensorFlowOperation { - public Select(String modelName, String nodeName, List<IntermediateOperation> inputs) { - super(modelName, nodeName, inputs); + public Select(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(modelName, node, inputs, port); } @Override @@ -38,7 +39,7 @@ public class Select extends IntermediateOperation { if (!allInputFunctionsPresent(3)) { return null; } - IntermediateOperation conditionOperation = inputs().get(0); + TensorFlowOperation conditionOperation = inputs().get(0); TensorFunction a = inputs().get(1).function().get(); TensorFunction b = inputs().get(2).function().get(); 
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java index da566909adc..af49d2c108b 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java @@ -1,19 +1,20 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; import java.util.List; -public class Shape extends IntermediateOperation { +public class Shape extends TensorFlowOperation { - public Shape(String modelName, String nodeName, List<IntermediateOperation> inputs) { - super(modelName, nodeName, inputs); + public Shape(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(modelName, node, inputs, port); createConstantValue(); } @@ -23,7 +24,7 @@ public class Shape extends IntermediateOperation { return null; } OrderedTensorType inputType = inputs.get(0).type().get(); - return new OrderedTensorType.Builder() + return new OrderedTensorType.Builder(node) .add(TensorType.Dimension.indexed(vespaName(), inputType.dimensions().size())) 
.build(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java index c750c47e27e..17ce9e8b7cb 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java @@ -1,26 +1,26 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Reduce; import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.NodeDef; import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; -public class Squeeze extends IntermediateOperation { +public class Squeeze extends TensorFlowOperation { - private final AttributeMap attributeMap; private List<String> squeezeDimensions; - public Squeeze(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap 
attributeMap) { - super(modelName, nodeName, inputs); - this.attributeMap = attributeMap; + public Squeeze(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(modelName, node, inputs, port); } @Override @@ -31,21 +31,20 @@ public class Squeeze extends IntermediateOperation { OrderedTensorType inputType = inputs.get(0).type().get(); squeezeDimensions = new ArrayList<>(); - Optional<List<Value>> squeezeDimsAttr = attributeMap.getList("squeeze_dims"); - if ( ! squeezeDimsAttr.isPresent()) { + AttrValue squeezeDimsAttr = node.getAttrMap().get("squeeze_dims"); + if (squeezeDimsAttr == null) { squeezeDimensions = inputType.type().dimensions().stream(). - filter(dim -> OrderedTensorType.dimensionSize(dim) == 1). + filter(dim -> TensorConverter.dimensionSize(dim) == 1). map(TensorType.Dimension::name). collect(Collectors.toList()); } else { - squeezeDimensions = squeezeDimsAttr.get().stream().map(Value::asDouble).map(Double::intValue). + squeezeDimensions = squeezeDimsAttr.getList().getIList().stream(). map(i -> i < 0 ? inputType.type().dimensions().size() - i : i). - map(i -> inputType.type().dimensions().get(i)). - filter(dim -> OrderedTensorType.dimensionSize(dim) == 1). + map(i -> inputType.type().dimensions().get(i.intValue())). + filter(dim -> TensorConverter.dimensionSize(dim) == 1). map(TensorType.Dimension::name). collect(Collectors.toList()); } - return squeezeDimensions.isEmpty() ? inputType : reducedType(inputType); } @@ -73,7 +72,7 @@ public class Squeeze extends IntermediateOperation { } private OrderedTensorType reducedType(OrderedTensorType inputType) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node); for (TensorType.Dimension dimension: inputType.type().dimensions()) { if ( ! 
squeezeDimensions.contains(dimension.name())) { builder.add(dimension); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java index 0171d1ea171..de4d8862fd6 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java @@ -1,19 +1,17 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; import java.util.List; import java.util.Optional; -public class Switch extends IntermediateOperation { +public class Switch extends TensorFlowOperation { - private final int port; - - public Switch(String modelName, String nodeName, List<IntermediateOperation> inputs, int port) { - super(modelName, nodeName, inputs); - this.port = port; + public Switch(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(modelName, node, inputs, port); } @Override @@ -23,7 +21,7 @@ public class Switch extends IntermediateOperation { } Optional<OrderedTensorType> predicate = inputs.get(1).type(); if (predicate.get().type().rank() != 0) { - throw new IllegalArgumentException("Switch in " + name + ": " + + throw new IllegalArgumentException("Switch in " + node.getName() + ": " + "predicate must be 
a scalar"); } return inputs.get(0).type().orElse(null); @@ -31,13 +29,13 @@ public class Switch extends IntermediateOperation { @Override protected TensorFunction lazyGetFunction() { - IntermediateOperation predicateOperation = inputs().get(1); + TensorFlowOperation predicateOperation = inputs().get(1); if (!predicateOperation.getConstantValue().isPresent()) { - throw new IllegalArgumentException("Switch in " + name + ": " + + throw new IllegalArgumentException("Switch in " + node.getName() + ": " + "predicate must be a constant"); } if (port < 0 || port > 1) { - throw new IllegalArgumentException("Switch in " + name + ": " + + throw new IllegalArgumentException("Switch in " + node.getName() + ": " + "choice should be boolean"); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java index 43de29cedd5..3687bba8b85 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java @@ -1,16 +1,17 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; -package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; - +import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.evaluation.VariableTensor; import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; import java.util.ArrayList; import java.util.Collections; @@ -19,40 +20,43 @@ import java.util.Optional; import java.util.function.Function; /** - * Wraps an imported operation node and produces the respective Vespa tensor - * operation. During import, a graph of these operations are constructed. Then, - * the types are used to deduce sensible dimension names using the - * DimensionRenamer. After the types have been renamed, the proper Vespa - * expressions can be extracted. + * Wraps a TensorFlow node and produces the respective Vespa tensor operation. + * During import, a graph of these operations are constructed. Then, the + * types are used to deduce sensible dimension names using the + * DimensionRenamer. After the types have been renamed, the proper + * Vespa expressions can be extracted. 
* * @author lesters */ -public abstract class IntermediateOperation { +public abstract class TensorFlowOperation { + + protected final static String MACRO_PREFIX = "tf_macro_"; - private final static String MACRO_PREFIX = "imported_ml_macro_"; + private final String modelName; - protected final String name; - protected final String modelName; - protected final List<IntermediateOperation> inputs; - protected final List<IntermediateOperation> outputs = new ArrayList<>(); + protected final NodeDef node; + protected final int port; + protected final List<TensorFlowOperation> inputs; + protected final List<TensorFlowOperation> outputs = new ArrayList<>(); + protected final List<String> importWarnings = new ArrayList<>(); protected OrderedTensorType type; protected TensorFunction function; protected TensorFunction macro = null; - private final List<String> importWarnings = new ArrayList<>(); private Value constantValue = null; - private List<IntermediateOperation> controlInputs = Collections.emptyList(); + private List<TensorFlowOperation> controlInputs = Collections.emptyList(); - protected Function<OrderedTensorType, Value> constantValueFunction = null; - - IntermediateOperation(String modelName, String name, List<IntermediateOperation> inputs) { - this.name = name; + TensorFlowOperation(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { this.modelName = modelName; + this.node = node; + this.port = port; this.inputs = Collections.unmodifiableList(inputs); this.inputs.forEach(i -> i.outputs.add(this)); } + protected String modelName() { return modelName; } + protected abstract OrderedTensorType lazyGetType(); protected abstract TensorFunction lazyGetFunction(); @@ -61,6 +65,9 @@ public abstract class IntermediateOperation { if (type == null) { type = lazyGetType(); } + if (type != null) { + type.verifyType(node); + } return Optional.ofNullable(type); } @@ -80,14 +87,14 @@ public abstract class IntermediateOperation { return 
Optional.ofNullable(function); } - /** Returns original name of this operation node */ - public String name() { return name; } + /** Return TensorFlow node */ + public NodeDef node() { return node; } /** Return unmodifiable list of inputs */ - public List<IntermediateOperation> inputs() { return inputs; } + public List<TensorFlowOperation> inputs() { return inputs; } /** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a macro. */ - public List<IntermediateOperation> outputs() { return Collections.unmodifiableList(outputs); } + public List<TensorFlowOperation> outputs() { return Collections.unmodifiableList(outputs); } /** Returns a Vespa ranking expression that should be added as a macro */ public Optional<TensorFunction> macro() { return Optional.ofNullable(macro); } @@ -102,34 +109,22 @@ public abstract class IntermediateOperation { public boolean isInput() { return false; } /** Return true if this node is constant */ - public boolean isConstant() { return inputs.stream().allMatch(IntermediateOperation::isConstant); } + public boolean isConstant() { return inputs.stream().allMatch(TensorFlowOperation::isConstant); } /** Sets the constant value */ public void setConstantValue(Value value) { constantValue = value; } /** Gets the constant value if it exists */ - public Optional<Value> getConstantValue() { - if (constantValue != null) { - return Optional.of(constantValue); - } - if (constantValueFunction != null) { - return Optional.of(constantValueFunction.apply(type)); - } - return Optional.empty(); - } - - /** Set the constant value function */ - public void setConstantValueFunction(Function<OrderedTensorType, Value> func) { this.constantValueFunction = func; } + public Optional<Value> getConstantValue() { return Optional.ofNullable(constantValue); } /** Sets the external control inputs */ - public void setControlInputs(List<IntermediateOperation> inputs) { this.controlInputs = inputs; } + public void 
setControlInputs(List<TensorFlowOperation> inputs) { this.controlInputs = inputs; } /** Retrieve the control inputs for this operation */ - public List<IntermediateOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); } + public List<TensorFlowOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); } /** Retrieve the valid Vespa name of this node */ - public String vespaName() { return vespaName(name); } - public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; } + public String vespaName() { return node.getName() != null ? node.getName().replace('/', '_') : null; } /** Retrieve the valid Vespa name of this node if it is a macro */ public String macroName() { return vespaName() != null ? MACRO_PREFIX + modelName + "_" + vespaName() : null; } @@ -140,48 +135,23 @@ public abstract class IntermediateOperation { /** Set an input warning */ public void warning(String warning) { importWarnings.add(warning); } - boolean verifyInputs(int expected, Function<IntermediateOperation, Optional<?>> func) { + boolean verifyInputs(int expected, Function<TensorFlowOperation, Optional<?>> func) { + if (!controlInputs.stream().map(func).allMatch(Optional::isPresent)) { + return false; + } if (inputs.size() != expected) { throw new IllegalArgumentException("Expected " + expected + " inputs " + - "for '" + name + "', got " + inputs.size()); + "for '" + node.getName() + "', got " + inputs.size()); } return inputs.stream().map(func).allMatch(Optional::isPresent); } boolean allInputTypesPresent(int expected) { - return verifyInputs(expected, IntermediateOperation::type); + return verifyInputs(expected, TensorFlowOperation::type); } boolean allInputFunctionsPresent(int expected) { - return verifyInputs(expected, IntermediateOperation::function); - } - - /** - * A method signature input and output has the form name:index. - * This returns the name part without the index. 
- */ - public static String namePartOf(String name) { - name = name.startsWith("^") ? name.substring(1) : name; - return name.split(":")[0]; - } - - /** - * This return the output index part. Indexes are used for nodes with - * multiple outputs. - */ - public static int indexPartOf(String name) { - int i = name.indexOf(":"); - return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1)); - } - - /** - * An interface mapping operation attributes to Vespa Values. - * Adapter for differences in ONNX/TensorFlow. - */ - public interface AttributeMap { - Optional<Value> get(String key); - Optional<Value> get(String key, OrderedTensorType type); - Optional<List<Value>> getList(String key); + return verifyInputs(expected, TensorFlowOperation::function); } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java new file mode 100644 index 00000000000..b18a8a9b212 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java @@ -0,0 +1,46 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.List; + +public class Variable extends TensorFlowOperation { + + public Variable(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(modelName, node, inputs, port); + } + + /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ + @Override + public String vespaName() { + return modelName() + "_" + super.vespaName(); + } + + @Override + protected OrderedTensorType lazyGetType() { + return OrderedTensorType.fromTensorFlowType(node, super.vespaName() + "_"); + } + + @Override + protected TensorFunction lazyGetFunction() { + return null; // will be added by function() since this is constant. + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + for (TensorType.Dimension dimension : type.type().dimensions()) { + renamer.addDimension(dimension.name()); + } + } + + @Override + public boolean isConstant() { + return true; + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/package-info.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/package-info.java new file mode 100644 index 00000000000..9e53990a9d6 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/package-info.java @@ -0,0 +1,8 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+/** + * Tensorflow integration + */ +@ExportPackage +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.osgi.annotation.ExportPackage; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java index a7926cd2e02..4b68cd40a08 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java @@ -1,9 +1,11 @@ -package com.yahoo.searchlib.rankingexpression.integration.ml; +package com.yahoo.searchlib.rankingexpression.integration.onnx; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; @@ -22,7 +24,7 @@ public class OnnxMnistSoftmaxImportTestCase { @Test public void testMnistSoftmaxImport() throws IOException { - ImportedModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx"); + OnnxModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx"); // Check constants assertEquals(2, model.largeConstants().size()); @@ -46,7 +48,7 @@ public class OnnxMnistSoftmaxImportTestCase { model.requiredMacros().get("Placeholder")); // Check outputs - RankingExpression output = 
model.defaultSignature().outputExpression("add"); + RankingExpression output = model.outputExpression("add"); assertNotNull(output); assertEquals("add", output.getName()); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))", @@ -66,12 +68,13 @@ public class OnnxMnistSoftmaxImportTestCase { } private Tensor evaluateTensorFlowModel(String path, Tensor argument, String input, String output) { - ImportedModel model = new TensorFlowImporter().importModel("test", path); + SavedModelBundle tensorFlowModel = SavedModelBundle.load(path, "serve"); + TensorFlowModel model = new TensorFlowImporter().importModel("test", tensorFlowModel); return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input); } private Tensor evaluateOnnxModel(String path, Tensor argument, String input, String output) { - ImportedModel model = new OnnxImporter().importModel("test", path); + OnnxModel model = new OnnxImporter().importModel("test", path); return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input); } @@ -80,7 +83,14 @@ public class OnnxMnistSoftmaxImportTestCase { return expression.evaluate(context).asTensor(); } - private Context contextFrom(ImportedModel result) { + private Context contextFrom(TensorFlowModel result) { + MapContext context = new MapContext(); + result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); + result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); + return context; + } + + private Context contextFrom(OnnxModel result) { MapContext context = new MapContext(); result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new 
TensorValue(tensor))); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java index bf9684082f4..0f5eec93feb 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; import org.junit.Test; @@ -15,7 +15,7 @@ public class BatchNormImportTestCase { @Test public void testBatchNormImport() { TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/batch_norm/saved"); - ImportedModel.Signature signature = model.get().signature("serving_default"); + TensorFlowModel.Signature signature = model.get().signature("serving_default"); assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java index c8c7ec798bb..74b0d11f1d6 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java @@ -1,6 +1,6 @@ -package com.yahoo.searchlib.rankingexpression.integration.ml; +package 
com.yahoo.searchlib.rankingexpression.integration.tensorflow; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; import org.junit.Test; import static org.junit.Assert.assertTrue; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java index a63c7346335..50a467ec581 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.TensorType; @@ -24,7 +24,7 @@ public class DropoutImportTestCase { assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), model.get().requiredMacros().get("X")); - ImportedModel.Signature signature = model.get().signature("serving_default"); + TensorFlowModel.Signature signature = model.get().signature("serving_default"); assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); @@ -32,7 +32,7 @@ public class DropoutImportTestCase { RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("outputs/Maximum", output.getName()); - assertEquals("join(join(imported_ml_macro_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_macro_test_outputs_BiasAdd, f(a,b)(max(a,b)))", 
+ assertEquals("join(join(tf_macro_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), tf_macro_test_outputs_BiasAdd, f(a,b)(max(a,b)))", output.getRoot().toString()); model.assertEqualResult("X", output.getName()); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java index bd7644be23b..9f919c452d6 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; @@ -45,7 +45,7 @@ public class MnistSoftmaxImportTestCase { // Check signatures assertEquals(1, model.get().signatures().size()); - ImportedModel.Signature signature = model.get().signatures().get("serving_default"); + TensorFlowModel.Signature signature = model.get().signatures().get("serving_default"); assertNotNull(signature); // ... 
signature inputs diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java index b2443082ab1..beec2ab1ead 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java @@ -1,6 +1,6 @@ -package com.yahoo.searchlib.rankingexpression.integration.ml; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; import org.junit.Test; import static org.junit.Assert.assertEquals; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java index 723c5f27914..7ca16939477 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java @@ -1,11 +1,11 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-package com.yahoo.searchlib.rankingexpression.integration.ml; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; @@ -28,7 +28,7 @@ import static org.junit.Assert.assertEquals; public class TestableTensorFlowModel { private SavedModelBundle tensorFlowModel; - private ImportedModel model; + private TensorFlowModel model; // Sizes of the input vector private final int d0Size = 1; @@ -39,7 +39,7 @@ public class TestableTensorFlowModel { model = new TensorFlowImporter().importModel(modelName, tensorFlowModel); } - public ImportedModel get() { return model; } + public TensorFlowModel get() { return model; } public void assertEqualResult(String inputName, String operationName) { Tensor tfResult = tensorFlowExecute(tensorFlowModel, inputName, operationName); @@ -66,7 +66,7 @@ public class TestableTensorFlowModel { return TensorConverter.toVespaTensor(results.get(0)); } - private Context contextFrom(ImportedModel result) { + private Context contextFrom(TensorFlowModel result) { MapContext context = new MapContext(); result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); @@ -81,7 +81,7 @@ public class TestableTensorFlowModel { return 
b.build(); } - private void evaluateMacro(Context context, ImportedModel model, String macroName) { + private void evaluateMacro(Context context, TensorFlowModel model, String macroName) { if (!context.names().contains(macroName)) { RankingExpression e = model.macros().get(macroName); evaluateMacroDependencies(context, model, e.getRoot()); @@ -89,7 +89,7 @@ public class TestableTensorFlowModel { } } - private void evaluateMacroDependencies(Context context, ImportedModel model, ExpressionNode node) { + private void evaluateMacroDependencies(Context context, TensorFlowModel model, ExpressionNode node) { if (node instanceof ReferenceNode) { String name = node.toString(); if (model.macros().containsKey(name)) { diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java index f94098e6255..051c2c60c95 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java @@ -1,4 +1,4 @@ -package com.yahoo.searchlib.rankingexpression.integration.ml; +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import org.junit.Test; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index 3a66eef258d..944755c9db2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -22,37 +22,22 @@ public class ScalarFunctions { public static DoubleBinaryOperator add() { return new Add(); } public static DoubleBinaryOperator divide() { return new Divide(); } public static DoubleBinaryOperator equal() { return new 
Equal(); } - public static DoubleBinaryOperator greater() { return new Greater(); } - public static DoubleBinaryOperator less() { return new Less(); } public static DoubleBinaryOperator max() { return new Max(); } public static DoubleBinaryOperator min() { return new Min(); } - public static DoubleBinaryOperator mean() { return new Mean(); } public static DoubleBinaryOperator multiply() { return new Multiply(); } - public static DoubleBinaryOperator pow() { return new Pow(); } public static DoubleBinaryOperator squareddifference() { return new SquaredDifference(); } public static DoubleBinaryOperator subtract() { return new Subtract(); } - public static DoubleUnaryOperator abs() { return new Abs(); } public static DoubleUnaryOperator acos() { return new Acos(); } - public static DoubleUnaryOperator asin() { return new Asin(); } - public static DoubleUnaryOperator atan() { return new Atan(); } - public static DoubleUnaryOperator ceil() { return new Ceil(); } - public static DoubleUnaryOperator cos() { return new Cos(); } public static DoubleUnaryOperator elu() { return new Elu(); } public static DoubleUnaryOperator exp() { return new Exp(); } public static DoubleUnaryOperator floor() { return new Floor(); } - public static DoubleUnaryOperator log() { return new Log(); } - public static DoubleUnaryOperator neg() { return new Neg(); } - public static DoubleUnaryOperator reciprocal() { return new Reciprocal(); } public static DoubleUnaryOperator relu() { return new Relu(); } public static DoubleUnaryOperator rsqrt() { return new Rsqrt(); } public static DoubleUnaryOperator selu() { return new Selu(); } - public static DoubleUnaryOperator sin() { return new Sin(); } public static DoubleUnaryOperator sigmoid() { return new Sigmoid(); } public static DoubleUnaryOperator sqrt() { return new Sqrt(); } public static DoubleUnaryOperator square() { return new Square(); } - public static DoubleUnaryOperator tan() { return new Tan(); } - public static DoubleUnaryOperator tanh() 
{ return new Tanh(); } public static Function<List<Long>, Double> random() { return new Random(); } public static Function<List<Long>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); } @@ -74,20 +59,6 @@ public class ScalarFunctions { public String toString() { return "f(a,b)(a==b)"; } } - public static class Greater implements DoubleBinaryOperator { - @Override - public double applyAsDouble(double left, double right) { return left > right ? 1 : 0; } - @Override - public String toString() { return "f(a,b)(a > b)"; } - } - - public static class Less implements DoubleBinaryOperator { - @Override - public double applyAsDouble(double left, double right) { return left < right ? 1 : 0; } - @Override - public String toString() { return "f(a,b)(a < b)"; } - } - public static class Max implements DoubleBinaryOperator { @Override public double applyAsDouble(double left, double right) { return Math.max(left, right); } @@ -102,13 +73,6 @@ public class ScalarFunctions { public String toString() { return "f(a,b)(min(a, b))"; } } - public static class Mean implements DoubleBinaryOperator { - @Override - public double applyAsDouble(double left, double right) { return (left + right) / 2; } - @Override - public String toString() { return "f(a,b)((a + b) / 2)"; } - } - public static class Multiply implements DoubleBinaryOperator { @Override public double applyAsDouble(double left, double right) { return left * right; } @@ -116,13 +80,6 @@ public class ScalarFunctions { public String toString() { return "f(a,b)(a * b)"; } } - public static class Pow implements DoubleBinaryOperator { - @Override - public double applyAsDouble(double left, double right) { return Math.pow(left, right); } - @Override - public String toString() { return "f(a,b)(pow(a, b))"; } - } - public static class Divide implements DoubleBinaryOperator { @Override public double applyAsDouble(double left, double right) { return left / right; } @@ -147,13 +104,6 @@ public class 
ScalarFunctions { // Unary operators ------------------------------------------------------------------------------ - public static class Abs implements DoubleUnaryOperator { - @Override - public double applyAsDouble(double operand) { return Math.abs(operand); } - @Override - public String toString() { return "f(a)(fabs(a))"; } - } - public static class Acos implements DoubleUnaryOperator { @Override public double applyAsDouble(double operand) { return Math.acos(operand); } @@ -161,34 +111,6 @@ public class ScalarFunctions { public String toString() { return "f(a)(acos(a))"; } } - public static class Asin implements DoubleUnaryOperator { - @Override - public double applyAsDouble(double operand) { return Math.asin(operand); } - @Override - public String toString() { return "f(a)(asin(a))"; } - } - - public static class Atan implements DoubleUnaryOperator { - @Override - public double applyAsDouble(double operand) { return Math.atan(operand); } - @Override - public String toString() { return "f(a)(atan(a))"; } - } - - public static class Ceil implements DoubleUnaryOperator { - @Override - public double applyAsDouble(double operand) { return Math.ceil(operand); } - @Override - public String toString() { return "f(a)(ceil(a))"; } - } - - public static class Cos implements DoubleUnaryOperator { - @Override - public double applyAsDouble(double operand) { return Math.cos(operand); } - @Override - public String toString() { return "f(a)(cos(a))"; } - } - public static class Elu implements DoubleUnaryOperator { @Override public double applyAsDouble(double operand) { return operand < 0 ? 
Math.exp(operand) -1 : operand; } @@ -210,26 +132,6 @@ public class ScalarFunctions { public String toString() { return "f(a)(floor(a))"; } } - public static class Log implements DoubleUnaryOperator { - @Override - public double applyAsDouble(double operand) { return Math.log(operand); } - @Override - public String toString() { return "f(a)(log(a))"; } - } - - public static class Neg implements DoubleUnaryOperator { - @Override - public double applyAsDouble(double operand) { return -operand; } - @Override - public String toString() { return "f(a)(-a)"; } - } - - public static class Reciprocal implements DoubleUnaryOperator { - @Override - public double applyAsDouble(double operand) { return 1.0 / operand; } - @Override - public String toString() { return "f(a)(1 / a)"; } - } public static class Relu implements DoubleUnaryOperator { @Override @@ -248,13 +150,6 @@ public class ScalarFunctions { public String toString() { return String.format("f(a)(%f * if(a >= 0, a, %f*(exp(a)-1)))", scale, alpha); } } - public static class Sin implements DoubleUnaryOperator { - @Override - public double applyAsDouble(double operand) { return Math.sin(operand); } - @Override - public String toString() { return "f(a)(sin(a))"; } - } - public static class Rsqrt implements DoubleUnaryOperator { @Override public double applyAsDouble(double operand) { return 1.0 / Math.sqrt(operand); } @@ -277,29 +172,15 @@ public class ScalarFunctions { } public static class Square implements DoubleUnaryOperator { + @Override public double applyAsDouble(double operand) { return operand * operand; } - @Override - public String toString() { return "f(a)(a * a)"; } - } - public static class Tan implements DoubleUnaryOperator { - @Override - public double applyAsDouble(double operand) { return Math.tan(operand); } @Override - public String toString() { return "f(a)(tan(a))"; } - } + public String toString() { return "f(a)(a * a)"; } - public static class Tanh implements DoubleUnaryOperator { - @Override - 
public double applyAsDouble(double operand) { return Math.tanh(operand); } - @Override - public String toString() { return "f(a)(tanh(a))"; } } - - - // Variable-length operators ----------------------------------------------------------------------------- public static class EqualElements implements Function<List<Long>, Double> { |