diff options
author | Lester Solbakken <lesters@oath.com> | 2018-06-06 15:43:18 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-06-06 15:43:18 +0200 |
commit | 0bf235c481d24d627c82901a84bef585fe84bbb2 (patch) | |
tree | 6cb6d0b192f56f3e8fdb533fb9603d3f927fe3c1 /searchlib | |
parent | 389801098797ab37c7bc4ac5a3888ef4d92214e7 (diff) |
Refactor ONNX and TF import to use same code base
This reverts commit 681963959794b47102d1a1cf72f215c72b0e2b51.
Diffstat (limited to 'searchlib')
58 files changed, 1526 insertions, 2407 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java index 721214f9e94..4b49f17f74e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java @@ -1,5 +1,4 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; +package com.yahoo.searchlib.rankingexpression.integration.ml; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; @@ -13,76 +12,61 @@ import java.util.Map; import java.util.regex.Pattern; /** - * The result of importing a TensorFlow model into Vespa. - * - A set of signatures which are named collections of inputs and outputs. - * - A set of named constant tensors represented by Variable nodes in TensorFlow. - * - A list of warning messages. + * The result of importing a model (TensorFlow or ONNX) into Vespa. * * @author bratseth */ -// This object can be built incrementally within this package, but is immutable when observed from outside the package -public class TensorFlowModel { +public class ImportedModel { - private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*"); + private static final String defaultSignatureName = "default"; + private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*"); private final String name; + private final Map<String, Signature> signatures = new HashMap<>(); + private final Map<String, TensorType> arguments = new HashMap<>(); + private final Map<String, Tensor> smallConstants = new HashMap<>(); + private final Map<String, Tensor> largeConstants = new HashMap<>(); + private final Map<String, RankingExpression> expressions = new HashMap<>(); + private final Map<String, RankingExpression> macros = new HashMap<>(); + private final Map<String, TensorType> requiredMacros = new HashMap<>(); + /** - * Creates a TensorFlow model + * Creates a new imported model. * * @param name the name of this mode, containing only characters in [A-Za-z0-9_] */ - public TensorFlowModel(String name) { + public ImportedModel(String name) { if ( ! nameRegexp.matcher(name).matches()) - throw new IllegalArgumentException("A TensorFlow model name can only contain [A-Za-z0-9_], but is '" + - name + "'"); + throw new IllegalArgumentException("An imported model name can only contain [A-Za-z0-9_], but is '" + + name + "'"); this.name = name; } /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */ public String name() { return name; } - private final Map<String, Signature> signatures = new HashMap<>(); - private final Map<String, TensorType> arguments = new HashMap<>(); - private final Map<String, Tensor> smallConstants = new HashMap<>(); - private final Map<String, Tensor> largeConstants = new HashMap<>(); - private final Map<String, RankingExpression> expressions = new HashMap<>(); - private final Map<String, RankingExpression> macros = new HashMap<>(); - private final Map<String, TensorType> requiredMacros = new HashMap<>(); - - void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } - void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); } - void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); } - void expression(String name, RankingExpression expression) { expressions.put(name, expression); } - void macro(String name, RankingExpression expression) { macros.put(name, expression); } - void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); } - - /** Returns the given signature. If it does not already exist it is added to this. */ - Signature signature(String name) { - return signatures.computeIfAbsent(name, Signature::new); - } - /** Returns an immutable map of the arguments ("Placeholders") of this */ public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); } /** * Returns an immutable map of the small constants of this. * These should have sizes up to a few kb at most, and correspond to constant - * values given in the TensorFlow source. + * values given in the TensorFlow or ONNX source. */ public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); } /** * Returns an immutable map of the large constants of this. - * These can have sizes in gigabytes and must be distributed to nodes separately from configuration, - * and correspond to Variable files stored separately in TensorFlow. + * These can have sizes in gigabytes and must be distributed to nodes separately from configuration. + * For TensorFlow this corresponds to Variable files stored separately. */ public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); } /** - * Returns an immutable map of the expressions of this - corresponding to TensorFlow nodes - * which are not Placeholders or Variables (which instead become respectively arguments and constants). - * Note that only nodes recursively referenced by a placeholder are added. + * Returns an immutable map of the expressions of this - corresponding to graph nodes + * which are not Inputs/Placeholders or Variables (which instead become respectively arguments and constants). + * Note that only nodes recursively referenced by a placeholder/input are added. */ public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } @@ -95,9 +79,26 @@ public class TensorFlowModel { /** Returns an immutable map of the signatures of this */ public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); } + /** Returns the given signature. If it does not already exist it is added to this. */ + Signature signature(String name) { + return signatures.computeIfAbsent(name, Signature::new); + } + + /** Convenience method for returning a default signature */ + Signature defaultSignature() { return signature(defaultSignatureName); } + + void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } + void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); } + void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); } + void expression(String name, RankingExpression expression) { expressions.put(name, expression); } + void macro(String name, RankingExpression expression) { macros.put(name, expression); } + void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); } + /** - * A signature is a set of named inputs and outputs, where the inputs maps to argument ("placeholder") names+types, - * and outputs maps to expressions nodes. + * A signature is a set of named inputs and outputs, where the inputs maps to argument + * ("placeholder") names+types, and outputs maps to expressions nodes. + * Note that TensorFlow supports multiple signatures in their format, but ONNX has no explicit + * concept of signatures. For now, we handle ONNX models as having a single signature. */ public class Signature { @@ -107,19 +108,14 @@ public class TensorFlowModel { private final Map<String, String> skippedOutputs = new HashMap<>(); private final List<String> importWarnings = new ArrayList<>(); - Signature(String name) { + public Signature(String name) { this.name = name; } - void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); } - void output(String name, String expressionName) { outputs.put(name, expressionName); } - void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); } - void importWarning(String warning) { importWarnings.add(warning); } - public String name() { return name; } /** Returns the result this is part of */ - TensorFlowModel owner() { return TensorFlowModel.this; } + public ImportedModel owner() { return ImportedModel.this; } /** * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name @@ -127,7 +123,7 @@ public class TensorFlowModel { */ public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); } - /** Returns owner().arguments().get(inputs.get(name)), e.g the type of the argument this input references */ + /** Returns the type of the argument this input references */ public TensorType inputArgument(String inputName) { return owner().arguments().get(inputs.get(inputName)); } /** Returns an immutable list of the expression names of this */ @@ -144,12 +140,17 @@ public class TensorFlowModel { */ public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); } - /** Returns owner().expressions().get(outputs.get(outputName)), e.g the expression this output references */ + /** Returns the expression this output references */ public RankingExpression outputExpression(String outputName) { return owner().expressions().get(outputs.get(outputName)); } @Override public String toString() { return "signature '" + name + "'"; } + void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); } + void output(String name, String expressionName) { outputs.put(name, expressionName); } + void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); } + void importWarning(String warning) { importWarnings.add(warning); } + } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java new file mode 100644 index 00000000000..a658833b426 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java @@ -0,0 +1,242 @@ +package com.yahoo.searchlib.rankingexpression.integration.ml; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.yolean.Exceptions; + +import java.io.File; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.logging.Logger; + +/** + * Base class for importing ML models (ONNX/TensorFlow) as native Vespa + * ranking expressions. The general mechanism for import is for the + * specific ML platform import implementations to create an + * IntermediateGraph. This class offers common code to convert the + * IntermediateGraph to Vespa ranking expressions and macros. + * + * @author lesters + */ +public abstract class ModelImporter { + + private static final Logger log = Logger.getLogger(ModelImporter.class.getName()); + + /** + * The main import function. + */ + public abstract ImportedModel importModel(String modelName, String modelPath); + + public ImportedModel importModel(String modelName, File modelDir) { + return importModel(modelName, modelDir.toString()); + } + + /** + * Takes an IntermediateGraph and converts it to a ImportedModel containing + * the actual Vespa ranking expressions. + */ + static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph) { + ImportedModel model = new ImportedModel(graph.name()); + + graph.optimize(); + + importSignatures(graph, model); + importExpressions(graph, model); + reportWarnings(graph, model); + logVariableTypes(graph); + + return model; + } + + private static void importSignatures(IntermediateGraph graph, ImportedModel model) { + for (String signatureName : graph.signatures()) { + ImportedModel.Signature signature = model.signature(signatureName); + for (Map.Entry<String, String> input : graph.inputs(signatureName).entrySet()) { + signature.input(input.getKey(), input.getValue()); + } + for (Map.Entry<String, String> output : graph.outputs(signatureName).entrySet()) { + signature.output(output.getKey(), output.getValue()); + } + } + } + + private static boolean isSignatureInput(ImportedModel model, IntermediateOperation operation) { + for (ImportedModel.Signature signature : model.signatures().values()) { + for (String inputName : signature.inputs().values()) { + if (inputName.equals(operation.name())) { + return true; + } + } + } + return false; + } + + private static boolean isSignatureOutput(ImportedModel model, IntermediateOperation operation) { + for (ImportedModel.Signature signature : model.signatures().values()) { + for (String outputName : signature.outputs().values()) { + if (outputName.equals(operation.name())) { + return true; + } + } + } + return false; + } + + /** + * Convert intermediate representation to Vespa ranking expressions. + */ + static void importExpressions(IntermediateGraph graph, ImportedModel model) { + for (ImportedModel.Signature signature : model.signatures().values()) { + for (String outputName : signature.outputs().values()) { + try { + Optional<TensorFunction> function = importExpression(graph.get(outputName), model); + if (!function.isPresent()) { + signature.skippedOutput(outputName, "No valid output function could be found."); + } + } + catch (IllegalArgumentException e) { + signature.skippedOutput(outputName, Exceptions.toMessageString(e)); + } + } + } + } + + private static Optional<TensorFunction> importExpression(IntermediateOperation operation, ImportedModel model) { + if (!operation.type().isPresent()) { + return Optional.empty(); + } + if (operation.isConstant()) { + return importConstant(operation, model); + } + importExpressionInputs(operation, model); + importRankingExpression(operation, model); + importArgumentExpression(operation, model); + importMacroExpression(operation, model); + + return operation.function(); + } + + private static void importExpressionInputs(IntermediateOperation operation, ImportedModel model) { + operation.inputs().forEach(input -> importExpression(input, model)); + } + + private static Optional<TensorFunction> importConstant(IntermediateOperation operation, ImportedModel model) { + String name = operation.vespaName(); + if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) { + return operation.function(); + } + + Value value = operation.getConstantValue().orElseThrow(() -> + new IllegalArgumentException("Operation '" + operation.vespaName() + "' " + + "is constant but does not have a value.")); + if ( ! (value instanceof TensorValue)) { + return operation.function(); // scalar values are inserted directly into the expression + } + + Tensor tensor = value.asTensor(); + if (tensor.type().rank() == 0) { + model.smallConstant(name, tensor); + } else { + model.largeConstant(name, tensor); + } + return operation.function(); + } + + private static void importRankingExpression(IntermediateOperation operation, ImportedModel model) { + if (operation.function().isPresent()) { + String name = operation.name(); + if (!model.expressions().containsKey(name)) { + TensorFunction function = operation.function().get(); + + if (isSignatureOutput(model, operation)) { + OrderedTensorType operationType = operation.type().get(); + OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType); + if ( ! operationType.equals(standardNamingType)) { + List<String> renameFrom = operationType.dimensionNames(); + List<String> renameTo = standardNamingType.dimensionNames(); + function = new Rename(function, renameFrom, renameTo); + } + } + + try { + // We add all intermediate nodes imported as separate expressions. Only + // those referenced from the output will be used. We parse the + // TensorFunction here to convert it to a RankingExpression tree. + model.expression(name, new RankingExpression(name, function.toString())); + } + catch (ParseException e) { + throw new RuntimeException("Imported function " + function + + " cannot be parsed as a ranking expression", e); + } + } + } + } + + private static void importArgumentExpression(IntermediateOperation operation, ImportedModel model) { + if (operation.isInput()) { + // All inputs must have dimensions with standard naming convention: d0, d1, ... + OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get()); + model.argument(operation.vespaName(), standardNamingConvention.type()); + model.requiredMacro(operation.vespaName(), standardNamingConvention.type()); + } + } + + private static void importMacroExpression(IntermediateOperation operation, ImportedModel model) { + if (operation.macro().isPresent()) { + TensorFunction function = operation.macro().get(); + try { + model.macro(operation.macroName(), new RankingExpression(operation.macroName(), function.toString())); + } + catch (ParseException e) { + throw new RuntimeException("Tensorflow function " + function + + " cannot be parsed as a ranking expression", e); + } + } + } + + /** + * Add any import warnings to the signature in the ImportedModel. + */ + private static void reportWarnings(IntermediateGraph graph, ImportedModel model) { + for (ImportedModel.Signature signature : model.signatures().values()) { + for (String outputName : signature.outputs().values()) { + reportWarnings(graph.get(outputName), model); + } + } + } + + private static void reportWarnings(IntermediateOperation operation, ImportedModel model) { + for (String warning : operation.warnings()) { + model.defaultSignature().importWarning(warning); + } + for (IntermediateOperation input : operation.inputs()) { + reportWarnings(input, model); + } + } + + /** + * Log all TensorFlow Variables (i.e file constants) imported as part of this with their ordered type. + * This allows users to learn the exact types (including dimension order after renaming) of the Variables + * such that these can be converted and fed to a parent document independently of the rest of the model + * for fast model weight updates. + */ + private static void logVariableTypes(IntermediateGraph graph) { + for (IntermediateOperation operation : graph.operations()) { + if ( ! (operation instanceof Constant)) continue; + if ( ! operation.type().isPresent()) continue; // will not happen + log.info("Importing TensorFlow variable " + operation.name() + " as " + operation.vespaName() + + " of type " + operation.type().get()); + } + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java new file mode 100644 index 00000000000..d3dd2a1d418 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java @@ -0,0 +1,30 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.searchlib.rankingexpression.integration.ml; + +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx.GraphImporter; +import onnx.Onnx; + +import java.io.FileInputStream; +import java.io.IOException; + +/** + * Converts a ONNX model into a ranking expression and set of constants. + * + * @author lesters + */ +public class OnnxImporter extends ModelImporter { + + @Override + public ImportedModel importModel(String modelName, String modelPath) { + try (FileInputStream inputStream = new FileInputStream(modelPath)) { + Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); + IntermediateGraph graph = GraphImporter.importGraph(modelName, model); + return convertIntermediateGraphToModel(graph); + } catch (IOException e) { + throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e); + } + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java new file mode 100644 index 00000000000..ff584559a83 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java @@ -0,0 +1,47 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.ml; + +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.GraphImporter; +import org.tensorflow.SavedModelBundle; + +import java.io.IOException; + +/** + * Converts a saved TensorFlow model into a ranking expression and set of constants. + * + * @author bratseth + * @author lesters + */ +public class TensorFlowImporter extends ModelImporter { + + /** + * Imports a saved TensorFlow model from a directory. + * The model should be saved as a .pbtxt or .pb file. + * The name of the model is taken as the db/pbtxt file name (not including the file ending). + * + * @param modelName the name of the model to import, consisting of characters in [A-Za-z0-9_] + * @param modelDir the directory containing the TensorFlow model files to import + */ + public ImportedModel importModel(String modelName, String modelDir) { + try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { + return importModel(modelName, model); + } + catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e); + } + } + + /** Imports a TensorFlow model */ + ImportedModel importModel(String modelName, SavedModelBundle model) { + try { + IntermediateGraph graph = GraphImporter.importGraph(modelName, model); + return convertIntermediateGraphToModel(graph); + } + catch (IOException e) { + throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e); + } + } + + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java index c5ac7ace0fc..e1294ec3e01 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java @@ -1,7 +1,8 @@ -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; +package com.yahoo.searchlib.rankingexpression.integration.ml; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.GraphImporter; import com.yahoo.tensor.serialization.JsonFormat; import com.yahoo.yolean.Exceptions; import org.tensorflow.SavedModelBundle; @@ -24,7 +25,7 @@ public class VariableConverter { */ public static byte[] importVariable(String modelDir, String tensorFlowVariableName, String orderedTypeSpec) { try (SavedModelBundle bundle = SavedModelBundle.load(modelDir, "serve")) { - return JsonFormat.encode(TensorConverter.toVespaTensor(TensorFlowImporter.readVariable(tensorFlowVariableName, + return JsonFormat.encode(TensorConverter.toVespaTensor(GraphImporter.readVariable(tensorFlowVariableName, bundle), OrderedTensorType.fromSpec(orderedTypeSpec))); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/DimensionRenamer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamer.java index 2524417cee0..38f1d2329e2 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/DimensionRenamer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamer.java @@ -1,7 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.onnx.importer; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; import java.util.ArrayDeque; import java.util.ArrayList; @@ -47,7 +47,7 @@ public class DimensionRenamer { /** * Add a constraint between dimension names. */ - public void addConstraint(String from, String to, Constraint pred, OnnxOperation operation) { + public void addConstraint(String from, String to, Constraint pred, IntermediateOperation operation) { Arc arc = new Arc(from, to, operation); Arc opposite = arc.opposite(); constraints.put(arc, pred); @@ -175,9 +175,9 @@ public class DimensionRenamer { private final String from; private final String to; - private final OnnxOperation operation; + private final IntermediateOperation operation; - Arc(String from, String to, OnnxOperation operation) { + Arc(String from, String to, IntermediateOperation operation) { this.from = from; this.to = to; this.operation = operation; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java new file mode 100644 index 00000000000..39a8b211d09 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java @@ -0,0 +1,107 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.searchlib.rankingexpression.integration.ml.importer; + +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; + +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +/** + * Holds an intermediate representation of an imported ONNX or TensorFlow + * graph. After this intermediate representation is constructed, it is used to + * simplify and optimize the computational graph and then converted into the + * final ImportedModel that holds the Vespa ranking expressions for the model. + * + * @author lesters + */ +public class IntermediateGraph { + + private final String modelName; + private final Map<String, IntermediateOperation> index = new HashMap<>(); + private final Map<String, GraphSignature> signatures = new HashMap<>(); + + private static class GraphSignature { + final Map<String, String> inputs = new HashMap<>(); + final Map<String, String> outputs = new HashMap<>(); + } + + public IntermediateGraph(String modelName) { + this.modelName = modelName; + } + + public String name() { + return modelName; + } + + public IntermediateOperation put(String key, IntermediateOperation operation) { + return index.put(key, operation); + } + + public IntermediateOperation get(String key) { + return index.get(key); + } + + public Set<String> signatures() { + return signatures.keySet(); + } + + public Map<String, String> inputs(String signature) { + return signatures.computeIfAbsent(signature, (k) -> new GraphSignature()).inputs; + } + + public Map<String, String> outputs(String signature) { + return signatures.computeIfAbsent(signature, (k) -> new GraphSignature()).outputs; + } + + public String defaultSignature() { + return "default"; + } + + public boolean alreadyImported(String key) { + return index.containsKey(key); + } + + public Collection<IntermediateOperation> operations() { + return index.values(); + } + + public void optimize() { + renameDimensions(); + } + + /** + * Find dimension names to avoid excessive renaming while evaluating the model. + */ + private void renameDimensions() { + DimensionRenamer renamer = new DimensionRenamer(); + for (String signature : signatures()) { + for (String output : outputs(signature).values()) { + addDimensionNameConstraints(index.get(output), renamer); + } + } + renamer.solve(); + for (String signature : signatures()) { + for (String output : outputs(signature).values()) { + renameDimensions(index.get(output), renamer); + } + } + } + + private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer) { + if (operation.type().isPresent()) { + operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer)); + operation.addDimensionNameConstraints(renamer); + } + } + + private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer) { + if (operation.type().isPresent()) { + operation.inputs().forEach(input -> renameDimensions(input, renamer)); + operation.renameDimensions(renamer); + } + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorType.java index 812e9b8d678..209d73a9f38 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OrderedTensorType.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorType.java @@ -1,9 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.onnx.importer; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer; import com.yahoo.tensor.TensorType; -import onnx.Onnx; +import com.yahoo.tensor.TensorTypeParser; import java.util.ArrayList; import java.util.Collections; @@ -13,9 +13,9 @@ import java.util.stream.Collectors; /** * A Vespa tensor type is ordered by the lexicographical ordering of dimension - * names. ONNX tensors have an explicit ordering of their dimensions. + * names. Imported tensors have an explicit ordering of their dimensions. * During import, we need to track the Vespa dimension that matches the - * corresponding ONNX dimension as the ordering can change after + * corresponding imported dimension as the ordering can change after * dimension renaming. That is the purpose of this class. * * @author lesters @@ -25,14 +25,14 @@ public class OrderedTensorType { private final TensorType type; private final List<TensorType.Dimension> dimensions; - private final long[] innerSizesOnnx; + private final long[] innerSizesOriginal; private final long[] innerSizesVespa; private final int[] dimensionMap; private OrderedTensorType(List<TensorType.Dimension> dimensions) { this.dimensions = Collections.unmodifiableList(dimensions); this.type = new TensorType.Builder(dimensions).build(); - this.innerSizesOnnx = new long[dimensions.size()]; + this.innerSizesOriginal = new long[dimensions.size()]; this.innerSizesVespa = new long[dimensions.size()]; this.dimensionMap = createDimensionMap(); } @@ -54,10 +54,10 @@ public class OrderedTensorType { if (numDimensions == 0) { return null; } - innerSizesOnnx[numDimensions - 1] = 1; + innerSizesOriginal[numDimensions - 1] = 1; innerSizesVespa[numDimensions - 1] = 1; for (int i = numDimensions - 1; --i >= 0; ) { - innerSizesOnnx[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesOnnx[i+1]; + innerSizesOriginal[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesOriginal[i+1]; innerSizesVespa[i] = type.dimensions().get(i+1).size().orElse(-1L) * innerSizesVespa[i+1]; } int[] mapping = new int[numDimensions]; @@ -74,11 +74,15 @@ public class OrderedTensorType { return mapping; } + public int dimensionMap(int originalIndex) { + return dimensionMap[originalIndex]; + } + /** - * When dimension ordering between Vespa and Onnx differs, i.e. + * When dimension ordering between Vespa and imported differs, i.e. * after dimension renaming, use the dimension map to read in values * so that they are correctly laid out in memory for Vespa. - * Used when importing tensors from Onnx. + * Used when importing tensors. */ public int toDirectIndex(int index) { if (dimensions.size() == 0) { @@ -90,9 +94,9 @@ public class OrderedTensorType { int directIndex = 0; long rest = index; for (int i = 0; i < dimensions.size(); ++i) { - long address = rest / innerSizesOnnx[i]; + long address = rest / innerSizesOriginal[i]; directIndex += innerSizesVespa[dimensionMap[i]] * address; - rest %= innerSizesOnnx[i]; + rest %= innerSizesOriginal[i]; } return directIndex; } @@ -116,22 +120,6 @@ public class OrderedTensorType { return true; } - public void verifyType(Onnx.TypeProto typeProto) { - Onnx.TensorShapeProto shape = typeProto.getTensorType().getShape(); - if (shape != null) { - if (shape.getDimCount() != type.rank()) { - throw new IllegalArgumentException("Onnx shape of does not match Vespa shape"); - } - for (int onnxIndex = 0; onnxIndex < dimensions.size(); ++onnxIndex) { - int vespaIndex = dimensionMap[onnxIndex]; - Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(onnxIndex); - TensorType.Dimension vespaDimension = type().dimensions().get(vespaIndex); - if (onnxDimension.getDimValue() != vespaDimension.size().orElse(-1L)) { - throw new IllegalArgumentException("TensorFlow dimensions of does not match Vespa dimensions"); - } - } - } - } public OrderedTensorType rename(DimensionRenamer renamer) { List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size()); for (TensorType.Dimension dimension : dimensions) { @@ -151,18 +139,13 @@ public class OrderedTensorType { return new OrderedTensorType(renamedDimensions); } - public static OrderedTensorType fromOnnxType(Onnx.TypeProto type) { - return fromOnnxType(type, "d"); // standard naming convention: d0, d1, ... - } - - public static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) { - Onnx.TensorShapeProto shape = type.getTensorType().getShape(); - Builder builder = new Builder(shape); - for (int i = 0; i < shape.getDimCount(); ++ i) { + public OrderedTensorType rename(String dimensionPrefix) { + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + for (int i = 0; i < dimensions.size(); ++ i) { String dimensionName = dimensionPrefix + i; - Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i); - if (onnxDimension.getDimValue() >= 0) { - builder.add(TensorType.Dimension.indexed(dimensionName, onnxDimension.getDimValue())); + Optional<Long> dimSize = dimensions.get(i).size(); + if (dimSize.isPresent() && dimSize.get() >= 0) { + builder.add(TensorType.Dimension.indexed(dimensionName, dimSize.get())); } else { builder.add(TensorType.Dimension.indexed(dimensionName)); } @@ -170,13 +153,13 @@ public class OrderedTensorType { return builder.build(); } - public static OrderedTensorType fromOnnxType(List<Long> dims, String dimensionPrefix) { - Builder builder = new Builder(); - for (int i = 0; i < dims.size(); ++ i) { - String dimensionName = dimensionPrefix + i; - Long dimSize = dims.get(i); - if (dimSize >= 0) { - builder.add(TensorType.Dimension.indexed(dimensionName, dimSize)); + public static OrderedTensorType standardType(OrderedTensorType type) { + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + for (int i = 0; i < type.dimensions().size(); ++ i) { + TensorType.Dimension dim = type.dimensions().get(i); + String dimensionName = "d" + i; + if (dim.size().isPresent() && dim.size().get() >= 0) { + builder.add(TensorType.Dimension.indexed(dimensionName, dim.size().get())); } else { builder.add(TensorType.Dimension.indexed(dimensionName)); } @@ -184,13 +167,46 @@ public class OrderedTensorType { return builder.build(); } - public static OrderedTensorType standardType(OrderedTensorType type) { - Builder builder = new Builder(); - for (int i = 0; i < type.dimensions().size(); ++ i) { - TensorType.Dimension dim = type.dimensions().get(i); - String dimensionName = "d" + i; - if (dim.size().isPresent() && dim.size().get() >= 0) { - builder.add(TensorType.Dimension.indexed(dimensionName, dim.size().get())); + public static Long tensorSize(TensorType type) { + Long size = 1L; + for (TensorType.Dimension dimension : type.dimensions()) { + size *= dimensionSize(dimension); + } + return size; + } + + public static Long dimensionSize(TensorType.Dimension dim) { + return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size")); + } + + /** + * Returns a string representation of this: A standard tensor type string where dimensions + * are listed in the order of this rather than in the natural order of their names. + */ + @Override + public String toString() { + return "tensor(" + dimensions.stream().map(TensorType.Dimension::toString).collect(Collectors.joining(",")) + ")"; + } + + /** + * Creates an instance from the string representation of this: A standard tensor type string + * where dimensions are listed in the order of this rather than the natural order of their names. + */ + public static OrderedTensorType fromSpec(String typeSpec) { + return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec(typeSpec)); + } + + public static OrderedTensorType fromDimensionList(List<Long> dims) { + return fromDimensionList(dims, "d"); // standard naming convention: d0, d1, ... + } + + public static OrderedTensorType fromDimensionList(List<Long> dims, String dimensionPrefix) { + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + for (int i = 0; i < dims.size(); ++ i) { + String dimensionName = dimensionPrefix + i; + Long dimSize = dims.get(i); + if (dimSize >= 0) { + builder.add(TensorType.Dimension.indexed(dimensionName, dimSize)); } else { builder.add(TensorType.Dimension.indexed(dimensionName)); } @@ -200,45 +216,13 @@ public class OrderedTensorType { public static class Builder { - private final Onnx.TensorShapeProto shape; private final List<TensorType.Dimension> dimensions; - public Builder(Onnx.TensorShapeProto shape) { - this.shape = shape; - this.dimensions = new ArrayList<>(shape.getDimCount()); - } - public Builder() { - this.shape = null; this.dimensions = new ArrayList<>(); } public Builder add(TensorType.Dimension vespaDimension) { - if (shape != null) { - int index = dimensions.size(); - Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(index); - long size = onnxDimension.getDimValue(); - if (size >= 0) { - if (vespaDimension.type() != TensorType.Dimension.Type.indexedBound) { - throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " + - "dimension types"); - } - if (!vespaDimension.size().isPresent()) { - throw new IllegalArgumentException("Tensor dimension is indexed bound but does " + - "not have a size"); - } - if (vespaDimension.size().get() != size) { - throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " + - "dimension sizes. TensorFlow: " + size + " Vespa: " + - vespaDimension.size().get()); - } - } else { - if (vespaDimension.type() != TensorType.Dimension.Type.indexedUnbound) { - throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " + - "dimension types"); - } - } - } this.dimensions.add(vespaDimension); return this; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java new file mode 100644 index 00000000000..3fe92440cae --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java @@ -0,0 +1,216 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx; + +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Argument; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ConcatV2; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Identity; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Join; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Map; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.MatMul; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.NoOp; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Reshape; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Shape; +import com.yahoo.tensor.functions.ScalarFunctions; +import onnx.Onnx; + +import java.util.List; +import java.util.stream.Collectors; + +/** + * Converts an ONNX graph to a Vespa IntermediateGraph which is the basis + * for generating Vespa ranking expressions. + * + * @author lesters + */ +public class GraphImporter { + + public static IntermediateOperation mapOperation(Onnx.NodeProto node, + List<IntermediateOperation> inputs, + IntermediateGraph graph) { + String nodeName = node.getName(); + String modelName = graph.name(); + + switch (node.getOpType().toLowerCase()) { + case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs()); + case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); + case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos()); + case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin()); + case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan()); + case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil()); + case "concat": return new ConcatV2(modelName, nodeName, inputs); + case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos()); + case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); + case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu()); + case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal()); + case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp()); + case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor()); + case "greater": return new Join(modelName, nodeName, inputs, ScalarFunctions.greater()); + case "identity": return new Identity(modelName, nodeName, inputs); + case "less": return new Join(modelName, nodeName, inputs, ScalarFunctions.less()); + case "log": return new Map(modelName, nodeName, inputs, ScalarFunctions.log()); + case "matmul": return new MatMul(modelName, nodeName, inputs); + case "max": return new Join(modelName, nodeName, inputs, ScalarFunctions.max()); + case "min": return new Join(modelName, nodeName, inputs, ScalarFunctions.min()); + case "mean": return new Join(modelName, nodeName, inputs, ScalarFunctions.mean()); + case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply()); + case "neg": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg()); + case "pow": return new Join(modelName, nodeName, inputs, ScalarFunctions.pow()); + case "reshape": return new Reshape(modelName, nodeName, inputs); + case "reciprocal": return new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal()); + case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu()); + case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu()); + case "shape": return new Shape(modelName, nodeName, inputs); + case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin()); + case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt()); + case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid()); + case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); + case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan()); + case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh()); + } + + IntermediateOperation op = new NoOp(modelName, node.getName(), inputs); + op.warning("Operation '" + node.getOpType() + "' is currently not implemented"); + return op; + } + + public static IntermediateGraph importGraph(String modelName, Onnx.ModelProto model) { + Onnx.GraphProto onnxGraph = model.getGraph(); + + IntermediateGraph intermediateGraph = new IntermediateGraph(modelName); + importOperations(onnxGraph, intermediateGraph); + verifyOutputTypes(onnxGraph, intermediateGraph); + + return intermediateGraph; + } + + private static void importOperations(Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) { + for (Onnx.ValueInfoProto valueInfo : onnxGraph.getOutputList()) { + importOperation(valueInfo.getName(), onnxGraph, intermediateGraph); + } + } + + private static IntermediateOperation importOperation(String name, + Onnx.GraphProto onnxGraph, + IntermediateGraph intermediateGraph) { + if (intermediateGraph.alreadyImported(name)) { + return intermediateGraph.get(name); + } + IntermediateOperation operation; + if (isArgumentTensor(name, onnxGraph)) { + Onnx.ValueInfoProto valueInfoProto = getArgumentTensor(name, onnxGraph); + if (valueInfoProto == null) + throw new IllegalArgumentException("Could not find argument tensor: " + name); + OrderedTensorType type = TypeConverter.fromOnnxType(valueInfoProto.getType()); + operation = new Argument(intermediateGraph.name(), valueInfoProto.getName(), type); + + intermediateGraph.inputs(intermediateGraph.defaultSignature()) + .put(IntermediateOperation.namePartOf(name), operation.vespaName()); + + } else if (isConstantTensor(name, onnxGraph)) { + Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph); + OrderedTensorType defaultType = OrderedTensorType.fromDimensionList(tensorProto.getDimsList()); + operation = new Constant(intermediateGraph.name(), name, defaultType); + operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type))); + + } else { + Onnx.NodeProto node = getNodeFromGraph(name, onnxGraph); + List<IntermediateOperation> inputs = importOperationInputs(node, onnxGraph, intermediateGraph); + operation = mapOperation(node, inputs, intermediateGraph); + + if (isOutputNode(name, onnxGraph)) { + intermediateGraph.outputs(intermediateGraph.defaultSignature()) + .put(IntermediateOperation.namePartOf(name), operation.vespaName()); + } + } + intermediateGraph.put(operation.vespaName(), operation); + + return operation; + } + + private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) { + Onnx.ValueInfoProto value = getArgumentTensor(name, graph); + Onnx.TensorProto tensor = getConstantTensor(name, graph); + return value != null && tensor == null; + } + + private static boolean isConstantTensor(String name, Onnx.GraphProto graph) { + Onnx.ValueInfoProto value = getArgumentTensor(name, graph); + Onnx.TensorProto tensor = getConstantTensor(name, graph); + return value != null && tensor != null; + } + + private static Onnx.ValueInfoProto getArgumentTensor(String name, Onnx.GraphProto graph) { + for (Onnx.ValueInfoProto valueInfo : graph.getInputList()) { + if (valueInfo.getName().equals(name)) { + return valueInfo; + } + } + return null; + } + + private static Onnx.TensorProto getConstantTensor(String name, Onnx.GraphProto graph) { + for (Onnx.TensorProto tensorProto : graph.getInitializerList()) { + if (tensorProto.getName().equals(name)) { + return tensorProto; + } + } + return null; + } + + private static boolean isOutputNode(String name, Onnx.GraphProto graph) { + return getOutputNode(name, graph) != null; + } + + private static Onnx.ValueInfoProto getOutputNode(String name, Onnx.GraphProto graph) { + for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) { + if (valueInfo.getName().equals(name)) { + return valueInfo; + } + String nodeName = IntermediateOperation.namePartOf(valueInfo.getName()); + if (nodeName.equals(name)) { + return valueInfo; + } + } + return null; + } + + private static List<IntermediateOperation> importOperationInputs(Onnx.NodeProto node, + Onnx.GraphProto onnxGraph, + IntermediateGraph intermediateGraph) { + return node.getInputList().stream() + .map(nodeName -> importOperation(nodeName, onnxGraph, intermediateGraph)) + .collect(Collectors.toList()); + } + + private static void verifyOutputTypes(Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) { + for (String outputName : intermediateGraph.outputs(intermediateGraph.defaultSignature()).values()) { + IntermediateOperation operation = intermediateGraph.get(outputName); + Onnx.ValueInfoProto onnxNode = getOutputNode(outputName, onnxGraph); + OrderedTensorType type = operation.type().orElseThrow( + () -> new IllegalArgumentException("Output of '" + outputName + "' has no type.")); + TypeConverter.verifyType(onnxNode.getType(), type); + } + } + + private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) { + boolean hasPortNumber = nodeName.contains(":"); + for (Onnx.NodeProto node : graph.getNodeList()) { + if (hasPortNumber) { + for (String outputName : node.getOutputList()) { + if (outputName.equals(nodeName)) { + return node; + } + } + } else if (node.getName().equals(nodeName)) { + return node; + } + } + throw new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph"); + } +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java index 2912db03b5f..18856d4a25f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/TensorConverter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java @@ -1,17 +1,16 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.onnx.importer; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx; import com.google.protobuf.ByteString; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; import onnx.Onnx; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.FloatBuffer; -import java.util.List; /** * Converts Onnx tensors into Vespa tensors. @@ -29,7 +28,6 @@ public class TensorConverter { return builder.build(); } - /* todo: support more types */ private static Values readValuesOf(Onnx.TensorProto tensorProto) { if (tensorProto.hasRawData()) { switch (tensorProto.getDataType()) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java new file mode 100644 index 00000000000..715c55d8323 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java @@ -0,0 +1,52 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx; + +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.tensor.TensorType; +import onnx.Onnx; + +/** + * Converts and verifies ONNX tensor types into Vespa tensor types. + * + * @author lesters + */ +public class TypeConverter { + + public static void verifyType(Onnx.TypeProto typeProto, OrderedTensorType type) { + Onnx.TensorShapeProto shape = typeProto.getTensorType().getShape(); + if (shape != null) { + if (shape.getDimCount() != type.rank()) { + throw new IllegalArgumentException("Onnx shape of does not match Vespa shape"); + } + for (int onnxIndex = 0; onnxIndex < type.dimensions().size(); ++onnxIndex) { + int vespaIndex = type.dimensionMap(onnxIndex); + Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(onnxIndex); + TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex); + if (onnxDimension.getDimValue() != vespaDimension.size().orElse(-1L)) { + throw new IllegalArgumentException("Onnx dimensions of does not match Vespa dimensions"); + } + } + } + } + + public static OrderedTensorType fromOnnxType(Onnx.TypeProto type) { + return fromOnnxType(type, "d"); // standard naming convention: d0, d1, ... + } + + public static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) { + Onnx.TensorShapeProto shape = type.getTensorType().getShape(); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + for (int i = 0; i < shape.getDimCount(); ++ i) { + String dimensionName = dimensionPrefix + i; + Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i); + if (onnxDimension.getDimValue() >= 0) { + builder.add(TensorType.Dimension.indexed(dimensionName, onnxDimension.getDimValue())); + } else { + builder.add(TensorType.Dimension.indexed(dimensionName)); + } + } + return builder.build(); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java index 1619c11427a..7fc2aae87d1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java @@ -1,28 +1,29 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.VariableTensor; import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; +import java.util.Collections; import java.util.List; -public class Placeholder extends TensorFlowOperation { +public class Argument extends IntermediateOperation { private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ... - public Placeholder(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); - standardNamingType = OrderedTensorType.fromTensorFlowType(node); + public Argument(String modelName, String nodeName, OrderedTensorType type) { + super(modelName, nodeName, Collections.emptyList()); + this.type = type.rename(vespaName() + "_"); + standardNamingType = OrderedTensorType.standardType(type); } @Override protected OrderedTensorType lazyGetType() { - return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_"); + return type; } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ConcatV2.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java index 4f5d61d75f9..1b8c62fe0e9 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ConcatV2.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java @@ -1,38 +1,37 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; import java.util.List; import java.util.Optional; -public class ConcatV2 extends TensorFlowOperation { +public class ConcatV2 extends IntermediateOperation { private String concatDimensionName; - public ConcatV2(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); + public ConcatV2(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); } @Override protected OrderedTensorType lazyGetType() { - if (!inputs.stream().map(TensorFlowOperation::type).allMatch(Optional::isPresent)) { + if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) { return null; } - TensorFlowOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input + IntermediateOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input if (!concatDimOp.getConstantValue().isPresent()) { - throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " + + throw new IllegalArgumentException("ConcatV2 in " + name + ": " + "concat dimension must be a constant."); } Tensor concatDimTensor = concatDimOp.getConstantValue().get().asTensor(); if (concatDimTensor.type().rank() != 0) { - throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " + + throw new IllegalArgumentException("ConcatV2 in " + name + ": " + "concat dimension must be a scalar."); } @@ -44,7 +43,7 @@ public class ConcatV2 extends TensorFlowOperation { for (int i = 1; i < inputs.size() - 1; ++i) { OrderedTensorType bType = inputs.get(i).type().get(); if (bType.rank() != aType.rank()) { - throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " + + throw new IllegalArgumentException("ConcatV2 in " + name + ": " + "inputs must have save rank."); } for (int j = 0; j < aType.rank(); ++j) { @@ -53,13 +52,13 @@ public class ConcatV2 extends TensorFlowOperation { if (j == concatDim) { concatDimSize += dimSizeB; } else if (dimSizeA != dimSizeB) { - throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " + + throw new IllegalArgumentException("ConcatV2 in " + name + ": " + "input dimension " + j + " differs in input tensors."); } } } - OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node); + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(); int dimensionIndex = 0; for (TensorType.Dimension dimension : aType.dimensions()) { if (dimensionIndex == concatDim) { @@ -75,7 +74,7 @@ public class ConcatV2 extends TensorFlowOperation { @Override protected TensorFunction lazyGetFunction() { - if (!inputs.stream().map(TensorFlowOperation::function).allMatch(Optional::isPresent)) { + if (!inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent)) { return null; } TensorFunction result = inputs.get(0).function().get(); @@ -88,7 +87,7 @@ public class ConcatV2 extends TensorFlowOperation { @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { - if (!inputs.stream().map(TensorFlowOperation::type).allMatch(Optional::isPresent)) { + if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) { return; } OrderedTensorType a = inputs.get(0).type().get(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java index 718e2a4b3c2..3c0f8569c47 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java @@ -1,36 +1,38 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; import com.yahoo.searchlib.rankingexpression.Reference; -import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; -import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.NodeDef; import java.util.List; import java.util.Optional; -public class Const extends TensorFlowOperation { +public class Const extends IntermediateOperation { - public Const(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); + private final AttributeMap attributeMap; + + public Const(String modelName, + String nodeName, + List<IntermediateOperation> inputs, + AttributeMap attributeMap, + OrderedTensorType type) { + super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; + this.type = type.rename(vespaName() + "_"); setConstantValue(value()); } @Override protected OrderedTensorType lazyGetType() { - return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_"); + return type; } @Override @@ -55,7 +57,7 @@ public class Const extends TensorFlowOperation { /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ @Override public String vespaName() { - return modelName() + "_" + super.vespaName(); + return modelName + "_" + super.vespaName(); } @Override @@ -77,24 +79,11 @@ public class Const extends TensorFlowOperation { } private Value value() { - if ( ! node.getAttrMap().containsKey("value")) { - throw new IllegalArgumentException("Node '" + node.getName() + "' of type " + - "const has missing 'value' attribute"); - } - AttrValue attrValue = node.getAttrMap().get("value"); - if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) { - return new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), type().get().type())); - } - if (attrValue.getValueCase() == AttrValue.ValueCase.B) { - return new BooleanValue(attrValue.getB()); - } - if (attrValue.getValueCase() == AttrValue.ValueCase.I) { - return new DoubleValue(attrValue.getI()); - } - if (attrValue.getValueCase() == AttrValue.ValueCase.F) { - return new DoubleValue(attrValue.getF()); + Optional<Value> value = attributeMap.get("value", type); + if ( ! value.isPresent()) { + throw new IllegalArgumentException("Node '" + name + "' of type " + + "const has missing or non-recognized 'value' attribute"); } - throw new IllegalArgumentException("Requesting value of constant in " + - node.getName() + " but type is not recognized."); + return value.get(); } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java index 13043a61a8e..5e4abeaa234 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java @@ -1,38 +1,34 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.TensorConverter; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; -import onnx.Onnx; import java.util.Collections; import java.util.Optional; -public class Constant extends OnnxOperation { +public class Constant extends IntermediateOperation { - final String modelName; - final Onnx.TensorProto tensorProto; + private final String modelName; - public Constant(String modelName, Onnx.TensorProto tensorProto) { - super(null, Collections.emptyList()); + public Constant(String modelName, String nodeName, OrderedTensorType type) { + super(modelName, nodeName, Collections.emptyList()); this.modelName = modelName; - this.tensorProto = tensorProto; + this.type = type.rename(vespaName() + "_"); } /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ @Override public String vespaName() { - return modelName + "_" + vespaName(tensorProto.getName()); + return modelName + "_" + vespaName(name); } @Override protected OrderedTensorType lazyGetType() { - return OrderedTensorType.fromOnnxType(tensorProto.getDimsList(), vespaName() + "_"); + return type; } @Override @@ -40,9 +36,14 @@ public class Constant extends OnnxOperation { return null; // will be added by function() since this is constant. } + /** + * Constant values are sent in via the constantValueFunction, as the + * dimension names and thus the data layout depends on the dimension + * renaming which happens after the conversion to intermediate graph. + */ @Override public Optional<Value> getConstantValue() { - return Optional.of(new TensorValue(TensorConverter.toVespaTensor(tensorProto, type))); + return Optional.ofNullable(constantValueFunction).map(func -> func.apply(type)); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java index 2d0f4c7042b..742ed8b89ab 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java @@ -1,9 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; @@ -12,18 +12,17 @@ import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Generate; import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; import java.util.ArrayList; import java.util.List; import java.util.Optional; -public class ExpandDims extends TensorFlowOperation { +public class ExpandDims extends IntermediateOperation { private List<String> expandDimensions; - public ExpandDims(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); + public ExpandDims(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); } @Override @@ -32,14 +31,14 @@ public class ExpandDims extends TensorFlowOperation { return null; } - TensorFlowOperation axisOperation = inputs().get(1); + IntermediateOperation axisOperation = inputs().get(1); if (!axisOperation.getConstantValue().isPresent()) { - throw new IllegalArgumentException("ExpandDims in " + node.getName() + ": " + + throw new IllegalArgumentException("ExpandDims in " + name + ": " + "axis must be a constant."); } Tensor axis = axisOperation.getConstantValue().get().asTensor(); if (axis.type().rank() != 0) { - throw new IllegalArgumentException("ExpandDims in " + node.getName() + ": " + + throw new IllegalArgumentException("ExpandDims in " + name + ": " + "axis argument must be a scalar."); } @@ -49,7 +48,7 @@ public class ExpandDims extends TensorFlowOperation { dimensionToInsert = inputType.dimensions().size() - dimensionToInsert; } - OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node); + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(); expandDimensions = new ArrayList<>(); int dimensionIndex = 0; for (TensorType.Dimension dimension : inputType.dimensions()) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java index 1408e7e04f0..d29bd4b7a9e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java @@ -1,22 +1,21 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; import java.util.List; -public class Identity extends TensorFlowOperation { +public class Identity extends IntermediateOperation { - public Identity(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); + public Identity(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); } /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ @Override public String vespaName() { - return modelName() + "_" + super.vespaName(); + return modelName + "_" + super.vespaName(); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java index 3687bba8b85..43de29cedd5 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java @@ -1,17 +1,16 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; -import com.yahoo.searchlib.rankingexpression.RankingExpression; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; + import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.evaluation.VariableTensor; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; import java.util.ArrayList; import java.util.Collections; @@ -20,43 +19,40 @@ import java.util.Optional; import java.util.function.Function; /** - * Wraps a TensorFlow node and produces the respective Vespa tensor operation. - * During import, a graph of these operations are constructed. Then, the - * types are used to deduce sensible dimension names using the - * DimensionRenamer. After the types have been renamed, the proper - * Vespa expressions can be extracted. + * Wraps an imported operation node and produces the respective Vespa tensor + * operation. During import, a graph of these operations are constructed. Then, + * the types are used to deduce sensible dimension names using the + * DimensionRenamer. After the types have been renamed, the proper Vespa + * expressions can be extracted. * * @author lesters */ -public abstract class TensorFlowOperation { - - protected final static String MACRO_PREFIX = "tf_macro_"; +public abstract class IntermediateOperation { - private final String modelName; + private final static String MACRO_PREFIX = "imported_ml_macro_"; - protected final NodeDef node; - protected final int port; - protected final List<TensorFlowOperation> inputs; - protected final List<TensorFlowOperation> outputs = new ArrayList<>(); - protected final List<String> importWarnings = new ArrayList<>(); + protected final String name; + protected final String modelName; + protected final List<IntermediateOperation> inputs; + protected final List<IntermediateOperation> outputs = new ArrayList<>(); protected OrderedTensorType type; protected TensorFunction function; protected TensorFunction macro = null; + private final List<String> importWarnings = new ArrayList<>(); private Value constantValue = null; - private List<TensorFlowOperation> controlInputs = Collections.emptyList(); + private List<IntermediateOperation> controlInputs = Collections.emptyList(); - TensorFlowOperation(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { + protected Function<OrderedTensorType, Value> constantValueFunction = null; + + IntermediateOperation(String modelName, String name, List<IntermediateOperation> inputs) { + this.name = name; this.modelName = modelName; - this.node = node; - this.port = port; this.inputs = Collections.unmodifiableList(inputs); this.inputs.forEach(i -> i.outputs.add(this)); } - protected String modelName() { return modelName; } - protected abstract OrderedTensorType lazyGetType(); protected abstract TensorFunction lazyGetFunction(); @@ -65,9 +61,6 @@ public abstract class TensorFlowOperation { if (type == null) { type = lazyGetType(); } - if (type != null) { - type.verifyType(node); - } return Optional.ofNullable(type); } @@ -87,14 +80,14 @@ public abstract class TensorFlowOperation { return Optional.ofNullable(function); } - /** Return TensorFlow node */ - public NodeDef node() { return node; } + /** Returns original name of this operation node */ + public String name() { return name; } /** Return unmodifiable list of inputs */ - public List<TensorFlowOperation> inputs() { return inputs; } + public List<IntermediateOperation> inputs() { return inputs; } /** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a macro. */ - public List<TensorFlowOperation> outputs() { return Collections.unmodifiableList(outputs); } + public List<IntermediateOperation> outputs() { return Collections.unmodifiableList(outputs); } /** Returns a Vespa ranking expression that should be added as a macro */ public Optional<TensorFunction> macro() { return Optional.ofNullable(macro); } @@ -109,22 +102,34 @@ public abstract class TensorFlowOperation { public boolean isInput() { return false; } /** Return true if this node is constant */ - public boolean isConstant() { return inputs.stream().allMatch(TensorFlowOperation::isConstant); } + public boolean isConstant() { return inputs.stream().allMatch(IntermediateOperation::isConstant); } /** Sets the constant value */ public void setConstantValue(Value value) { constantValue = value; } /** Gets the constant value if it exists */ - public Optional<Value> getConstantValue() { return Optional.ofNullable(constantValue); } + public Optional<Value> getConstantValue() { + if (constantValue != null) { + return Optional.of(constantValue); + } + if (constantValueFunction != null) { + return Optional.of(constantValueFunction.apply(type)); + } + return Optional.empty(); + } + + /** Set the constant value function */ + public void setConstantValueFunction(Function<OrderedTensorType, Value> func) { this.constantValueFunction = func; } /** Sets the external control inputs */ - public void setControlInputs(List<TensorFlowOperation> inputs) { this.controlInputs = inputs; } + public void setControlInputs(List<IntermediateOperation> inputs) { this.controlInputs = inputs; } /** Retrieve the control inputs for this operation */ - public List<TensorFlowOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); } + public List<IntermediateOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); } /** Retrieve the valid Vespa name of this node */ - public String vespaName() { return node.getName() != null ? node.getName().replace('/', '_') : null; } + public String vespaName() { return vespaName(name); } + public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; } /** Retrieve the valid Vespa name of this node if it is a macro */ public String macroName() { return vespaName() != null ? MACRO_PREFIX + modelName + "_" + vespaName() : null; } @@ -135,23 +140,48 @@ public abstract class TensorFlowOperation { /** Set an input warning */ public void warning(String warning) { importWarnings.add(warning); } - boolean verifyInputs(int expected, Function<TensorFlowOperation, Optional<?>> func) { - if (!controlInputs.stream().map(func).allMatch(Optional::isPresent)) { - return false; - } + boolean verifyInputs(int expected, Function<IntermediateOperation, Optional<?>> func) { if (inputs.size() != expected) { throw new IllegalArgumentException("Expected " + expected + " inputs " + - "for '" + node.getName() + "', got " + inputs.size()); + "for '" + name + "', got " + inputs.size()); } return inputs.stream().map(func).allMatch(Optional::isPresent); } boolean allInputTypesPresent(int expected) { - return verifyInputs(expected, TensorFlowOperation::type); + return verifyInputs(expected, IntermediateOperation::type); } boolean allInputFunctionsPresent(int expected) { - return verifyInputs(expected, TensorFlowOperation::function); + return verifyInputs(expected, IntermediateOperation::function); + } + + /** + * A method signature input and output has the form name:index. + * This returns the name part without the index. + */ + public static String namePartOf(String name) { + name = name.startsWith("^") ? name.substring(1) : name; + return name.split(":")[0]; + } + + /** + * This return the output index part. Indexes are used for nodes with + * multiple outputs. + */ + public static int indexPartOf(String name) { + int i = name.indexOf(":"); + return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1)); + } + + /** + * An interface mapping operation attributes to Vespa Values. + * Adapter for differences in ONNX/TensorFlow. + */ + public interface AttributeMap { + Optional<Value> get(String key); + Optional<Value> get(String key, OrderedTensorType type); + Optional<List<Value>> getList(String key); } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java index fe2004a528d..8413ed74118 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Join.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java @@ -1,24 +1,22 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Reduce; import com.yahoo.tensor.functions.TensorFunction; -import onnx.Onnx; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.function.DoubleBinaryOperator; -public class Join extends OnnxOperation { +public class Join extends IntermediateOperation { private final DoubleBinaryOperator operator; - public Join(Onnx.NodeProto node, List<OnnxOperation> inputs, DoubleBinaryOperator operator) { - super(node, inputs); + public Join(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleBinaryOperator operator) { + super(modelName, nodeName, inputs); this.operator = operator; } @@ -61,8 +59,8 @@ public class Join extends OnnxOperation { return null; } - OnnxOperation a = largestInput(); - OnnxOperation b = smallestInput(); + IntermediateOperation a = largestInput(); + IntermediateOperation b = smallestInput(); List<String> aDimensionsToReduce = new ArrayList<>(); List<String> bDimensionsToReduce = new ArrayList<>(); @@ -107,13 +105,13 @@ public class Join extends OnnxOperation { } } - private OnnxOperation largestInput() { + private IntermediateOperation largestInput() { OrderedTensorType a = inputs.get(0).type().get(); OrderedTensorType b = inputs.get(1).type().get(); return a.rank() >= b.rank() ? inputs.get(0) : inputs.get(1); } - private OnnxOperation smallestInput() { + private IntermediateOperation smallestInput() { OrderedTensorType a = inputs.get(0).type().get(); OrderedTensorType b = inputs.get(1).type().get(); return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java index c015f5ecba8..f54ae83052f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java @@ -1,20 +1,19 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; import java.util.List; import java.util.Optional; import java.util.function.DoubleUnaryOperator; -public class Map extends TensorFlowOperation { +public class Map extends IntermediateOperation { private final DoubleUnaryOperator operator; - public Map(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port, DoubleUnaryOperator operator) { - super(modelName, node, inputs, port); + public Map(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleUnaryOperator operator) { + super(modelName, nodeName, inputs); this.operator = operator; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/MatMul.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java index 1b388e2ae89..52e223f9518 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/MatMul.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java @@ -1,21 +1,18 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; -import onnx.Onnx; -import java.util.Collections; import java.util.List; import java.util.Optional; -import java.util.function.DoubleBinaryOperator; -public class MatMul extends OnnxOperation { +public class MatMul extends IntermediateOperation { - public MatMul(Onnx.NodeProto node, List<OnnxOperation> inputs) { - super(node, inputs); + public MatMul(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java index 3eba872c6a0..95a77c07590 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java @@ -1,9 +1,10 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; @@ -13,20 +14,20 @@ import com.yahoo.tensor.functions.Generate; import com.yahoo.tensor.functions.Reduce; import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.NodeDef; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Optional; -public class Mean extends TensorFlowOperation { +public class Mean extends IntermediateOperation { + private final AttributeMap attributeMap; private List<String> reduceDimensions; - public Mean(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); + public Mean(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { + super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; } @Override @@ -34,9 +35,9 @@ public class Mean extends TensorFlowOperation { if (!allInputTypesPresent(2)) { return null; } - TensorFlowOperation reductionIndices = inputs.get(1); + IntermediateOperation reductionIndices = inputs.get(1); if (!reductionIndices.getConstantValue().isPresent()) { - throw new IllegalArgumentException("Mean in " + node.getName() + ": " + + throw new IllegalArgumentException("Mean in " + name + ": " + "reduction indices must be a constant."); } Tensor indices = reductionIndices.getConstantValue().get().asTensor(); @@ -54,7 +55,7 @@ public class Mean extends TensorFlowOperation { return reducedType(inputType, shouldKeepDimensions()); } - // todo: optimization: if keepDims and one reduce dimension that has size 1: same as identity. + // optimization: if keepDims and one reduce dimension that has size 1: same as identity. @Override protected TensorFunction lazyGetFunction() { @@ -93,12 +94,12 @@ public class Mean extends TensorFlowOperation { } private boolean shouldKeepDimensions() { - AttrValue keepDimsAttr = node.getAttrMap().get("keep_dims"); - return keepDimsAttr != null && keepDimsAttr.getB(); + Optional<Value> keepDims = attributeMap.get("keep_dims"); + return keepDims.isPresent() && keepDims.get().asBoolean(); } private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); for (TensorType.Dimension dimension: inputType.type().dimensions()) { if (!reduceDimensions.contains(dimension.name())) { builder.add(dimension); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java index 4c95e67e184..9d9eca47b1c 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java @@ -1,21 +1,20 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; import java.util.List; -public class Merge extends TensorFlowOperation { +public class Merge extends IntermediateOperation { - public Merge(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); + public Merge(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); } @Override protected OrderedTensorType lazyGetType() { - for (TensorFlowOperation operation : inputs) { + for (IntermediateOperation operation : inputs) { if (operation.type().isPresent()) { return operation.type().get(); } @@ -25,7 +24,7 @@ public class Merge extends TensorFlowOperation { @Override protected TensorFunction lazyGetFunction() { - for (TensorFlowOperation operation : inputs) { + for (IntermediateOperation operation : inputs) { if (operation.function().isPresent()) { return operation.function().get(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java new file mode 100644 index 00000000000..19ba146492c --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java @@ -0,0 +1,26 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.Collections; +import java.util.List; + +public class NoOp extends IntermediateOperation { + + public NoOp(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, Collections.emptyList()); // don't propagate inputs + } + + @Override + protected OrderedTensorType lazyGetType() { + return null; + } + + @Override + protected TensorFunction lazyGetFunction() { + return null; + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java index 65ce7f00e34..9299ae9be12 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java @@ -1,17 +1,16 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; import java.util.List; import java.util.Optional; -public class PlaceholderWithDefault extends TensorFlowOperation { +public class PlaceholderWithDefault extends IntermediateOperation { - public PlaceholderWithDefault(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); + public PlaceholderWithDefault(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java index e7d90e5fc1f..e91c2305f7d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java @@ -1,10 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; @@ -19,19 +18,18 @@ import com.yahoo.tensor.functions.Generate; import com.yahoo.tensor.functions.Reduce; import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.stream.Collectors; -import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.tensorSize; +import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.tensorSize; -public class Reshape extends TensorFlowOperation { +public class Reshape extends IntermediateOperation { - public Reshape(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); + public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); } @Override @@ -39,15 +37,15 @@ public class Reshape extends TensorFlowOperation { if (!allInputTypesPresent(2)) { return null; } - TensorFlowOperation newShape = inputs.get(1); + IntermediateOperation newShape = inputs.get(1); if (!newShape.getConstantValue().isPresent()) { - throw new IllegalArgumentException("Reshape in " + node.getName() + ": " + + throw new IllegalArgumentException("Reshape in " + name + ": " + "shape input must be a constant."); } Tensor shape = newShape.getConstantValue().get().asTensor(); OrderedTensorType inputType = inputs.get(0).type().get(); - OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(node); + OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(); int dimensionIndex = 0; for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) { Tensor.Cell cell = cellIterator.next(); @@ -124,7 +122,7 @@ public class Reshape extends TensorFlowOperation { operators.add(0, ArithmeticOperator.MULTIPLY); children.add(0, new ConstantNode(new DoubleValue(size))); } - size *= TensorConverter.dimensionSize(dimension); + size *= OrderedTensorType.dimensionSize(dimension); if (i > 0) { operators.add(0, ArithmeticOperator.PLUS); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java index 5fdcb5a695f..927a4a368f9 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java @@ -1,24 +1,23 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; import java.util.List; import java.util.function.DoubleBinaryOperator; -import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.dimensionSize; -import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.tensorSize; +import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.dimensionSize; +import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.tensorSize; -public class Select extends TensorFlowOperation { +public class Select extends IntermediateOperation { - public Select(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); + public Select(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); } @Override @@ -39,7 +38,7 @@ public class Select extends TensorFlowOperation { if (!allInputFunctionsPresent(3)) { return null; } - TensorFlowOperation conditionOperation = inputs().get(0); + IntermediateOperation conditionOperation = inputs().get(0); TensorFunction a = inputs().get(1).function().get(); TensorFunction b = inputs().get(2).function().get(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java index af49d2c108b..da566909adc 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java @@ -1,20 +1,19 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; import java.util.List; -public class Shape extends TensorFlowOperation { +public class Shape extends IntermediateOperation { - public Shape(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); + public Shape(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); createConstantValue(); } @@ -24,7 +23,7 @@ public class Shape extends TensorFlowOperation { return null; } OrderedTensorType inputType = inputs.get(0).type().get(); - return new OrderedTensorType.Builder(node) + return new OrderedTensorType.Builder() .add(TensorType.Dimension.indexed(vespaName(), inputType.dimensions().size())) .build(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java index 17ce9e8b7cb..c750c47e27e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java @@ -1,26 +1,26 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Reduce; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.NodeDef; import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; -public class Squeeze extends TensorFlowOperation { +public class Squeeze extends IntermediateOperation { + private final AttributeMap attributeMap; private List<String> squeezeDimensions; - public Squeeze(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); + public Squeeze(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { + super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; } @Override @@ -31,20 +31,21 @@ public class Squeeze extends TensorFlowOperation { OrderedTensorType inputType = inputs.get(0).type().get(); squeezeDimensions = new ArrayList<>(); - AttrValue squeezeDimsAttr = node.getAttrMap().get("squeeze_dims"); - if (squeezeDimsAttr == null) { + Optional<List<Value>> squeezeDimsAttr = attributeMap.getList("squeeze_dims"); + if ( ! squeezeDimsAttr.isPresent()) { squeezeDimensions = inputType.type().dimensions().stream(). - filter(dim -> TensorConverter.dimensionSize(dim) == 1). + filter(dim -> OrderedTensorType.dimensionSize(dim) == 1). map(TensorType.Dimension::name). collect(Collectors.toList()); } else { - squeezeDimensions = squeezeDimsAttr.getList().getIList().stream(). + squeezeDimensions = squeezeDimsAttr.get().stream().map(Value::asDouble).map(Double::intValue). map(i -> i < 0 ? inputType.type().dimensions().size() - i : i). - map(i -> inputType.type().dimensions().get(i.intValue())). - filter(dim -> TensorConverter.dimensionSize(dim) == 1). + map(i -> inputType.type().dimensions().get(i)). + filter(dim -> OrderedTensorType.dimensionSize(dim) == 1). map(TensorType.Dimension::name). collect(Collectors.toList()); } + return squeezeDimensions.isEmpty() ? inputType : reducedType(inputType); } @@ -72,7 +73,7 @@ public class Squeeze extends TensorFlowOperation { } private OrderedTensorType reducedType(OrderedTensorType inputType) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); for (TensorType.Dimension dimension: inputType.type().dimensions()) { if ( ! squeezeDimensions.contains(dimension.name())) { builder.add(dimension); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java index de4d8862fd6..0171d1ea171 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java @@ -1,17 +1,19 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; import java.util.List; import java.util.Optional; -public class Switch extends TensorFlowOperation { +public class Switch extends IntermediateOperation { - public Switch(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); + private final int port; + + public Switch(String modelName, String nodeName, List<IntermediateOperation> inputs, int port) { + super(modelName, nodeName, inputs); + this.port = port; } @Override @@ -21,7 +23,7 @@ public class Switch extends TensorFlowOperation { } Optional<OrderedTensorType> predicate = inputs.get(1).type(); if (predicate.get().type().rank() != 0) { - throw new IllegalArgumentException("Switch in " + node.getName() + ": " + + throw new IllegalArgumentException("Switch in " + name + ": " + "predicate must be a scalar"); } return inputs.get(0).type().orElse(null); @@ -29,13 +31,13 @@ public class Switch extends TensorFlowOperation { @Override protected TensorFunction lazyGetFunction() { - TensorFlowOperation predicateOperation = inputs().get(1); + IntermediateOperation predicateOperation = inputs().get(1); if (!predicateOperation.getConstantValue().isPresent()) { - throw new IllegalArgumentException("Switch in " + node.getName() + ": " + + throw new IllegalArgumentException("Switch in " + name + ": " + "predicate must be a constant"); } if (port < 0 || port > 1) { - throw new IllegalArgumentException("Switch in " + node.getName() + ": " + + throw new IllegalArgumentException("Switch in " + name + ": " + "choice should be boolean"); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java new file mode 100644 index 00000000000..a815cbc3944 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java @@ -0,0 +1,85 @@ +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow; + +import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.NodeDef; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * Converts TensorFlow node attributes to Vespa attribute values. + * + * @author lesters + */ +public class AttributeConverter implements IntermediateOperation.AttributeMap { + + private final Map<String, AttrValue> attributeMap; + + public AttributeConverter(NodeDef node) { + attributeMap = node.getAttrMap(); + } + + public static AttributeConverter convert(NodeDef node) { + return new AttributeConverter(node); + } + + @Override + public Optional<Value> get(String key) { + if (attributeMap.containsKey(key)) { + AttrValue attrValue = attributeMap.get(key); + if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) { + return Optional.empty(); // requires type + } + if (attrValue.getValueCase() == AttrValue.ValueCase.B) { + return Optional.of(new BooleanValue(attrValue.getB())); + } + if (attrValue.getValueCase() == AttrValue.ValueCase.I) { + return Optional.of(new DoubleValue(attrValue.getI())); + } + if (attrValue.getValueCase() == AttrValue.ValueCase.F) { + return Optional.of(new DoubleValue(attrValue.getF())); + } + } + return Optional.empty(); + } + + @Override + public Optional<Value> get(String key, OrderedTensorType type) { + if (attributeMap.containsKey(key)) { + AttrValue attrValue = attributeMap.get(key); + if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) { + return Optional.of(new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), type.type()))); + } + } + return get(key); + } + + @Override + public Optional<List<Value>> getList(String key) { + if (attributeMap.containsKey(key)) { + AttrValue attrValue = attributeMap.get(key); + if (attrValue.getValueCase() == AttrValue.ValueCase.LIST) { + AttrValue.ListValue listValue = attrValue.getList(); + if ( ! listValue.getBList().isEmpty()) { + return Optional.of(listValue.getBList().stream().map(BooleanValue::new).collect(Collectors.toList())); + } + if ( ! listValue.getIList().isEmpty()) { + return Optional.of(listValue.getIList().stream().map(DoubleValue::new).collect(Collectors.toList())); + } + if ( ! listValue.getFList().isEmpty()) { + return Optional.of(listValue.getFList().stream().map(DoubleValue::new).collect(Collectors.toList())); + } + // add the rest + } + } + return Optional.empty(); + } +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java new file mode 100644 index 00000000000..e1b292f9e61 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java @@ -0,0 +1,234 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow; + +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Argument; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ConcatV2; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Const; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ExpandDims; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Identity; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Join; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Map; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.MatMul; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Mean; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Merge; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.NoOp; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.PlaceholderWithDefault; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Reshape; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Select; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Shape; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Squeeze; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Switch; +import com.yahoo.tensor.functions.ScalarFunctions; +import org.tensorflow.SavedModelBundle; +import org.tensorflow.Session; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.MetaGraphDef; +import org.tensorflow.framework.NodeDef; +import org.tensorflow.framework.SignatureDef; +import org.tensorflow.framework.TensorInfo; + +import java.io.IOException; +import java.util.List; +import java.util.stream.Collectors; + +/** + * Converts a TensorFlow graph to a Vespa IntermediateGraph which is the basis + * for generating Vespa ranking expressions. + * + * @author lesters + */ +public class GraphImporter { + + public static IntermediateOperation mapOperation(NodeDef node, + List<IntermediateOperation> inputs, + IntermediateGraph graph) { + String nodeName = node.getName(); + String modelName = graph.name(); + int nodePort = IntermediateOperation.indexPartOf(nodeName); + OrderedTensorType nodeType = TypeConverter.fromTensorFlowType(node); + AttributeConverter attributes = AttributeConverter.convert(node); + + switch (node.getOp().toLowerCase()) { + // array ops + case "concatv2": return new ConcatV2(modelName, nodeName, inputs); + case "const": return new Const(modelName, nodeName, inputs, attributes, nodeType); + case "expanddims": return new ExpandDims(modelName, nodeName, inputs); + case "identity": return new Identity(modelName, nodeName, inputs); + case "placeholder": return new Argument(modelName, nodeName, nodeType); + case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, nodeName, inputs); + case "reshape": return new Reshape(modelName, nodeName, inputs); + case "shape": return new Shape(modelName, nodeName, inputs); + case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes); + + // control flow + case "merge": return new Merge(modelName, nodeName, inputs); + case "switch": return new Switch(modelName, nodeName, inputs, nodePort); + + // math ops + case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); + case "add_n": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); + case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos()); + case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); + case "realdiv": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); + case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor()); + case "matmul": return new MatMul(modelName, nodeName, inputs); + case "maximum": return new Join(modelName, nodeName, inputs, ScalarFunctions.max()); + case "mean": return new Mean(modelName, nodeName, inputs, attributes); + case "reducemean": return new Mean(modelName, nodeName, inputs, attributes); + case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply()); + case "multiply": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply()); + case "rsqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.rsqrt()); + case "select": return new Select(modelName, nodeName, inputs); + case "where3": return new Select(modelName, nodeName, inputs); + case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid()); + case "squareddifference": return new Join(modelName, nodeName, inputs, ScalarFunctions.squareddifference()); + case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); + case "subtract": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); + + // nn ops + case "biasadd": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); + case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu()); + case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu()); + case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu()); + + // state ops + case "variable": return new Constant(modelName, nodeName, nodeType); + case "variablev2": return new Constant(modelName, nodeName, nodeType); + + // evaluation no-ops + case "stopgradient":return new Identity(modelName, nodeName, inputs); + case "noop": return new NoOp(modelName, nodeName, inputs); + + } + + IntermediateOperation op = new NoOp(modelName, node.getName(), inputs); + op.warning("Operation '" + node.getOp() + "' is currently not implemented"); + return op; + } + + public static IntermediateGraph importGraph(String modelName, SavedModelBundle bundle) throws IOException { + MetaGraphDef tfGraph = MetaGraphDef.parseFrom(bundle.metaGraphDef()); + + IntermediateGraph intermediateGraph = new IntermediateGraph(modelName); + importSignatures(tfGraph, intermediateGraph); + importOperations(tfGraph, intermediateGraph, bundle); + verifyOutputTypes(tfGraph, intermediateGraph); + + return intermediateGraph; + } + + private static void importSignatures(MetaGraphDef tfGraph, IntermediateGraph intermediateGraph) { + for (java.util.Map.Entry<String, SignatureDef> signatureEntry : tfGraph.getSignatureDefMap().entrySet()) { + String signatureName = signatureEntry.getKey(); + java.util.Map<String, TensorInfo> inputInfoMap = signatureEntry.getValue().getInputsMap(); + for (java.util.Map.Entry<String, TensorInfo> input : inputInfoMap.entrySet()) { + String inputName = input.getKey(); + String nodeName = input.getValue().getName(); + intermediateGraph.inputs(signatureName).put(inputName, IntermediateOperation.namePartOf(nodeName)); + } + java.util.Map<String, TensorInfo> outputInfoMap = signatureEntry.getValue().getOutputsMap(); + for (java.util.Map.Entry<String, TensorInfo> output : outputInfoMap.entrySet()) { + String outputName = output.getKey(); + String nodeName = output.getValue().getName(); + intermediateGraph.outputs(signatureName).put(outputName, IntermediateOperation.namePartOf(nodeName)); + } + } + } + + private static void importOperations(MetaGraphDef tfGraph, + IntermediateGraph intermediateGraph, + SavedModelBundle bundle) { + for (String signatureName : intermediateGraph.signatures()) { + for (String outputName : intermediateGraph.outputs(signatureName).values()) { + importOperation(outputName, tfGraph.getGraphDef(), intermediateGraph, bundle); + } + } + } + + private static IntermediateOperation importOperation(String nodeName, + GraphDef tfGraph, + IntermediateGraph intermediateGraph, + SavedModelBundle bundle) { + if (intermediateGraph.alreadyImported(nodeName)) { + return intermediateGraph.get(nodeName); + } + NodeDef node = getTensorFlowNodeFromGraph(IntermediateOperation.namePartOf(nodeName), tfGraph); + List<IntermediateOperation> inputs = importOperationInputs(node, tfGraph, intermediateGraph, bundle); + IntermediateOperation operation = mapOperation(node, inputs, intermediateGraph); + intermediateGraph.put(nodeName, operation); + + List<IntermediateOperation> controlInputs = importControlInputs(node, tfGraph, intermediateGraph, bundle); + if (controlInputs.size() > 0) { + operation.setControlInputs(controlInputs); + } + + if (operation.isConstant()) { + operation.setConstantValueFunction( + type -> new TensorValue(TensorConverter.toVespaTensor(readVariable(nodeName, bundle), type))); + } + + return operation; + } + + private static List<IntermediateOperation> importOperationInputs(NodeDef node, + GraphDef tfGraph, + IntermediateGraph intermediateGraph, + SavedModelBundle bundle) { + return node.getInputList().stream() + .filter(name -> ! isControlDependency(name)) + .map(nodeName -> importOperation(nodeName, tfGraph, intermediateGraph, bundle)) + .collect(Collectors.toList()); + } + + private static List<IntermediateOperation> importControlInputs(NodeDef node, + GraphDef tfGraph, + IntermediateGraph intermediateGraph, + SavedModelBundle bundle) { + return node.getInputList().stream() + .filter(nodeName -> isControlDependency(nodeName)) + .map(nodeName -> importOperation(nodeName, tfGraph, intermediateGraph, bundle)) + .collect(Collectors.toList()); + } + + private static boolean isControlDependency(String name) { + return name.startsWith("^"); + } + + private static NodeDef getTensorFlowNodeFromGraph(String name, GraphDef tfGraph) { + for (NodeDef node : tfGraph.getNodeList()) { + if (node.getName().equals(name)) { + return node; + } + } + throw new IllegalArgumentException("Could not find node '" + name + "'"); + } + + public static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) { + Session.Runner fetched = bundle.session().runner().fetch(name); + List<org.tensorflow.Tensor<?>> importedTensors = fetched.run(); + if (importedTensors.size() != 1) + throw new IllegalStateException("Expected 1 tensor from fetching " + name + + ", but got " + importedTensors.size()); + return importedTensors.get(0); + } + + private static void verifyOutputTypes(MetaGraphDef tfGraph, IntermediateGraph intermediateGraph) { + for (String signatureName : intermediateGraph.signatures()) { + for (String outputName : intermediateGraph.outputs(signatureName).values()) { + IntermediateOperation operation = intermediateGraph.get(outputName); + NodeDef node = getTensorFlowNodeFromGraph(IntermediateOperation.namePartOf(operation.name()), tfGraph.getGraphDef()); + OrderedTensorType type = operation.type().orElseThrow( + () -> new IllegalArgumentException("Output of '" + outputName + "' has no type.")); + TypeConverter.verifyType(node, type); + } + } + + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java index 3f55e622fdf..d2d0acfc964 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer; +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java new file mode 100644 index 00000000000..67ad1edc312 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java @@ -0,0 +1,72 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow; + +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.tensor.TensorType; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.NodeDef; +import org.tensorflow.framework.TensorShapeProto; + +import java.util.List; + +/** + * Converts and verifies TensorFlow tensor types into Vespa tensor types. + * + * @author lesters + */ +public class TypeConverter { + + public static void verifyType(NodeDef node, OrderedTensorType type) { + TensorShapeProto shape = tensorFlowShape(node); + if (shape != null) { + if (shape.getDimCount() != type.rank()) { + throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " + + "does not match Vespa shape"); + } + for (int tensorFlowIndex = 0; tensorFlowIndex < type.dimensions().size(); ++tensorFlowIndex) { + int vespaIndex = type.dimensionMap(tensorFlowIndex); + TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex); + TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex); + if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) { + throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " + + "does not match Vespa dimensions"); + } + } + } + } + + private static TensorShapeProto tensorFlowShape(NodeDef node) { + AttrValue attrValueList = node.getAttrMap().get("_output_shapes"); + if (attrValueList == null) { + throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + + "does not exist"); + } + if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) { + throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + + "is not of expected type"); + } + List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList(); + return shapeList.get(0); // support multiple outputs? + } + + public static OrderedTensorType fromTensorFlowType(NodeDef node) { + return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ... + } + + public static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) { + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + TensorShapeProto shape = tensorFlowShape(node); + for (int i = 0; i < shape.getDimCount(); ++ i) { + String dimensionName = dimensionPrefix + i; + TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i); + if (tensorFlowDimension.getSize() >= 0) { + builder.add(TensorType.Dimension.indexed(dimensionName, tensorFlowDimension.getSize())); + } else { + builder.add(TensorType.Dimension.indexed(dimensionName)); + } + } + return builder.build(); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/package-info.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java index 5cff8b03d40..1530754cc43 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/package-info.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java @@ -3,6 +3,6 @@ * ONNX integration */ @ExportPackage -package com.yahoo.searchlib.rankingexpression.integration.onnx; +package com.yahoo.searchlib.rankingexpression.integration.ml; import com.yahoo.osgi.annotation.ExportPackage; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java deleted file mode 100644 index fa1f929cc80..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java +++ /dev/null @@ -1,326 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package com.yahoo.searchlib.rankingexpression.integration.onnx; - -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Constant; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Argument; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OperationMapper; -import com.yahoo.searchlib.rankingexpression.parser.ParseException; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.functions.Rename; -import com.yahoo.tensor.functions.TensorFunction; -import com.yahoo.yolean.Exceptions; -import onnx.Onnx; - -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.util.Collection; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.logging.Logger; -import java.util.stream.Collectors; - -/** - * Converts a ONNX model into a ranking expression and set of constants. - * - * @author lesters - */ -public class OnnxImporter { - - private static final Logger log = Logger.getLogger(OnnxImporter.class.getName()); - - public OnnxModel importModel(String modelName, File modelDir) { - return importModel(modelName, modelDir.toString()); - } - - public OnnxModel importModel(String modelName, String modelPath) { - try (FileInputStream inputStream = new FileInputStream(modelPath)) { - Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); - return importModel(modelName, model); - } catch (IOException e) { - throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e); - } - } - - public OnnxModel importModel(String modelName, Onnx.ModelProto model) { - return importGraph(modelName, model.getGraph()); - } - - private static OnnxModel importGraph(String modelName, Onnx.GraphProto graph) { - OnnxModel model = new OnnxModel(modelName); - OperationIndex index = new OperationIndex(); - - importNodes(graph, model, index); - verifyOutputTypes(graph, model, index); - findDimensionNames(model, index); - importExpressions(model, index); - - reportWarnings(model, index); - - return model; - } - - private static void importNodes(Onnx.GraphProto graph, OnnxModel model, OperationIndex index) { - for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) { - importNode(valueInfo.getName(), graph, model, index); - } - } - - private static OnnxOperation importNode(String name, Onnx.GraphProto graph, OnnxModel model, OperationIndex index) { - if (index.alreadyImported(name)) { - return index.get(name); - } - OnnxOperation operation; - if (isArgumentTensor(name, graph)) { - operation = new Argument(getArgumentTensor(name, graph)); - model.input(OnnxOperation.namePartOf(name), operation.vespaName()); - } else if (isConstantTensor(name, graph)) { - operation = new Constant(model.name(), getConstantTensor(name, graph)); - } else { - Onnx.NodeProto node = getNodeFromGraph(name, graph); - List<OnnxOperation> inputs = importNodeInputs(node, graph, model, index); - operation = OperationMapper.get(node, inputs); - if (isOutputNode(name, graph)) { - model.output(OnnxOperation.namePartOf(name), operation.vespaName()); - } - } - index.put(operation.vespaName(), operation); - - return operation; - } - - private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) { - Onnx.ValueInfoProto value = getArgumentTensor(name, graph); - Onnx.TensorProto tensor = getConstantTensor(name, graph); - return value != null && tensor == null; - } - - private static boolean isConstantTensor(String name, Onnx.GraphProto graph) { - Onnx.ValueInfoProto value = getArgumentTensor(name, graph); - Onnx.TensorProto tensor = getConstantTensor(name, graph); - return value != null && tensor != null; - } - - private static Onnx.ValueInfoProto getArgumentTensor(String name, Onnx.GraphProto graph) { - for (Onnx.ValueInfoProto valueInfo : graph.getInputList()) { - if (valueInfo.getName().equals(name)) { - return valueInfo; - } - } - return null; - } - - private static Onnx.TensorProto getConstantTensor(String name, Onnx.GraphProto graph) { - for (Onnx.TensorProto tensorProto : graph.getInitializerList()) { - if (tensorProto.getName().equals(name)) { - return tensorProto; - } - } - return null; - } - - private static boolean isOutputNode(String name, Onnx.GraphProto graph) { - return getOutputNode(name, graph) != null; - } - - private static Onnx.ValueInfoProto getOutputNode(String name, Onnx.GraphProto graph) { - for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) { - if (valueInfo.getName().equals(name)) { - return valueInfo; - } - String nodeName = OnnxOperation.namePartOf(valueInfo.getName()); - if (nodeName.equals(name)) { - return valueInfo; - } - } - return null; - } - - private static List<OnnxOperation> importNodeInputs(Onnx.NodeProto node, - Onnx.GraphProto graph, - OnnxModel model, - OperationIndex index) { - return node.getInputList().stream() - .map(nodeName -> importNode(nodeName, graph, model, index)) - .collect(Collectors.toList()); - } - - private static void verifyOutputTypes(Onnx.GraphProto graph, OnnxModel model, OperationIndex index) { - for (String outputName : model.outputs().values()) { - OnnxOperation operation = index.get(outputName); - Onnx.ValueInfoProto onnxNode = getOutputNode(outputName, graph); - operation.type().orElseThrow( - () -> new IllegalArgumentException("Output of '" + outputName + "' has no type.")) - .verifyType(onnxNode.getType()); - } - } - - - /** Find dimension names to avoid excessive renaming while evaluating the model. */ - private static void findDimensionNames(OnnxModel model, OperationIndex index) { - DimensionRenamer renamer = new DimensionRenamer(); - for (String output : model.outputs().values()) { - addDimensionNameConstraints(index.get(output), renamer); - } - renamer.solve(); - for (String output : model.outputs().values()) { - renameDimensions(index.get(output), renamer); - } - } - - private static void addDimensionNameConstraints(OnnxOperation operation, DimensionRenamer renamer) { - if (operation.type().isPresent()) { - operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer)); - operation.addDimensionNameConstraints(renamer); - } - } - - private static void renameDimensions(OnnxOperation operation, DimensionRenamer renamer) { - if (operation.type().isPresent()) { - operation.inputs().forEach(input -> renameDimensions(input, renamer)); - operation.renameDimensions(renamer); - } - } - - private static void importExpressions(OnnxModel model, OperationIndex index) { - for (String outputName : model.outputs().values()) { - try { - Optional<TensorFunction> function = importExpression(index.get(outputName), model); - if (!function.isPresent()) { - model.skippedOutput(outputName, "No valid output function could be found."); - } - } - catch (IllegalArgumentException e) { - model.skippedOutput(outputName, Exceptions.toMessageString(e)); - } - } - } - - private static Optional<TensorFunction> importExpression(OnnxOperation operation, OnnxModel model) { - if (!operation.type().isPresent()) { - return Optional.empty(); - } - if (operation.isConstant()) { - return importConstant(operation, model); - } - importInputExpressions(operation, model); - importRankingExpression(operation, model); - importArgumentExpression(operation, model); - - return operation.function(); - } - - private static void importInputExpressions(OnnxOperation operation, OnnxModel model) { - operation.inputs().forEach(input -> importExpression(input, model)); - } - - private static Optional<TensorFunction> importConstant(OnnxOperation operation, OnnxModel model) { - String name = operation.vespaName(); - if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) { - return operation.function(); - } - - Value value = operation.getConstantValue().orElseThrow(() -> - new IllegalArgumentException("Operation '" + operation.vespaName() + "' " + - "is constant but does not have a value.")); - if ( ! (value instanceof TensorValue)) { - return operation.function(); // scalar values are inserted directly into the expression - } - - Tensor tensor = value.asTensor(); - if (tensor.type().rank() == 0) { - model.smallConstant(name, tensor); - } else { - model.largeConstant(name, tensor); - } - return operation.function(); - } - - private static void importRankingExpression(OnnxOperation operation, OnnxModel model) { - if (operation.function().isPresent()) { - String name = operation.vespaName(); - if (!model.expressions().containsKey(name)) { - TensorFunction function = operation.function().get(); - - if (model.outputs().containsKey(name)) { - OrderedTensorType operationType = operation.type().get(); - OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType); - if ( ! operationType.equals(standardNamingType)) { - List<String> renameFrom = operationType.dimensionNames(); - List<String> renameTo = standardNamingType.dimensionNames(); - function = new Rename(function, renameFrom, renameTo); - } - } - - try { - // We add all intermediate nodes imported as separate expressions. Only - // those referenced from the output will be used. We parse the - // TensorFunction here to convert it to a RankingExpression tree. - model.expression(name, new RankingExpression(name, function.toString())); - } - catch (ParseException e) { - throw new RuntimeException("Tensorflow function " + function + - " cannot be parsed as a ranking expression", e); - } - } - } - } - - private static void importArgumentExpression(OnnxOperation operation, OnnxModel model) { - if (operation.isInput()) { - // All inputs must have dimensions with standard naming convention: d0, d1, ... - OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get()); - model.argument(operation.vespaName(), standardNamingConvention.type()); - model.requiredMacro(operation.vespaName(), standardNamingConvention.type()); - } - } - - private static void reportWarnings(OnnxModel model, OperationIndex index) { - for (String output : model.outputs().values()) { - reportWarnings(model, index.get(output)); - } - } - - private static void reportWarnings(OnnxModel model, OnnxOperation operation) { - for (String warning : operation.warnings()) { - model.importWarning(warning); - } - for (OnnxOperation input : operation.inputs()) { - reportWarnings(model, input); - } - } - - private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) { - boolean hasPortNumber = nodeName.contains(":"); - for (Onnx.NodeProto node : graph.getNodeList()) { - if (hasPortNumber) { - for (String outputName : node.getOutputList()) { - if (outputName.equals(nodeName)) { - return node; - } - } - } else if (node.getName().equals(nodeName)) { - return node; - } - } - throw new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph"); - } - - private static class OperationIndex { - private final Map<String, OnnxOperation> index = new HashMap<>(); - public OnnxOperation put(String key, OnnxOperation operation) { return index.put(key, operation); } - public OnnxOperation get(String key) { return index.get(key); } - public boolean alreadyImported(String key) { return index.containsKey(key); } - public Collection<OnnxOperation> operations() { return index.values(); } - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java deleted file mode 100644 index bd53afefc3f..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package com.yahoo.searchlib.rankingexpression.integration.onnx; - -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.regex.Pattern; - -/** - * The result of importing an ONNX model into Vespa. - * - * @author bratseth - * @author lesters - */ -public class OnnxModel { - - private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*"); - - private final String name; - - public OnnxModel(String name) { - if ( ! nameRegexp.matcher(name).matches()) - throw new IllegalArgumentException("A TensorFlow model name can only contain [A-Za-z0-9_], but is '" + - name + "'"); - this.name = name; - } - - /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */ - public String name() { return name; } - - private final Map<String, String> inputs = new HashMap<>(); - private final Map<String, String> outputs = new HashMap<>(); - private final Map<String, String> skippedOutputs = new HashMap<>(); - private final List<String> importWarnings = new ArrayList<>(); - - private final Map<String, TensorType> arguments = new HashMap<>(); - private final Map<String, Tensor> smallConstants = new HashMap<>(); - private final Map<String, Tensor> largeConstants = new HashMap<>(); - private final Map<String, RankingExpression> expressions = new HashMap<>(); - private final Map<String, RankingExpression> macros = new HashMap<>(); - private final Map<String, TensorType> requiredMacros = new HashMap<>(); - - void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); } - void output(String name, String expressionName) { outputs.put(name, expressionName); } - void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); } - void importWarning(String warning) { importWarnings.add(warning); } - void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } - void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); } - void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); } - void expression(String name, RankingExpression expression) { expressions.put(name, expression); } - void macro(String name, RankingExpression expression) { macros.put(name, expression); } - void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); } - - /** - * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name - * to argument (Placeholder) name in the owner of this - */ - public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); } - - /** Returns arguments().get(inputs.get(name)), e.g the type of the argument this input references */ - public TensorType inputArgument(String inputName) { return arguments().get(inputs.get(inputName)); } - - /** Returns an immutable list of the expression names of this */ - public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); } - - /** - * Returns an immutable list of the outputs of this which could not be imported, - * with a string detailing the reason for each - */ - public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); } - - /** - * Returns an immutable list of possibly non-fatal warnings encountered during import. - */ - public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); } - - /** Returns expressions().get(outputs.get(outputName)), e.g the expression this output references */ - public RankingExpression outputExpression(String outputName) { return expressions().get(outputs.get(outputName)); } - - /** Returns an immutable map of the arguments (inputs) of this */ - public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); } - - /** - * Returns an immutable map of the small constants of this. - */ - public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); } - - /** - * Returns an immutable map of the large constants of this. - */ - public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); } - - /** - * Returns an immutable map of the expressions of this - corresponding to ONNX nodes - * which are not inputs or constants. - */ - public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } - - /** Returns an immutable map of macros that are part of this model */ - public Map<String, RankingExpression> macros() { return Collections.unmodifiableMap(macros); } - - /** Returns an immutable map of the macros that must be provided by the environment running this model */ - public Map<String, TensorType> requiredMacros() { return Collections.unmodifiableMap(requiredMacros); } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java deleted file mode 100644 index 12090145d3a..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package com.yahoo.searchlib.rankingexpression.integration.onnx.importer; - -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Join; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.MatMul; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.NoOp; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation; -import com.yahoo.tensor.functions.ScalarFunctions; -import onnx.Onnx; - -import java.util.List; - -public class OperationMapper { - - public static OnnxOperation get(Onnx.NodeProto node, List<OnnxOperation> inputs) { - switch (node.getOpType().toLowerCase()) { - case "add": return new Join(node, inputs, ScalarFunctions.add()); - case "matmul": return new MatMul(node, inputs); - } - - OnnxOperation op = new NoOp(node, inputs); - op.warning("Operation '" + node.getOpType() + "' is currently not implemented"); - return op; - } -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java deleted file mode 100644 index a8d8d63daf4..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations; - -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.VariableTensor; -import com.yahoo.tensor.functions.Rename; -import com.yahoo.tensor.functions.TensorFunction; -import onnx.Onnx; - -import java.util.Collections; -import java.util.List; - -public class Argument extends OnnxOperation { - - private Onnx.ValueInfoProto valueInfo; - private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ... - - public Argument(Onnx.ValueInfoProto valueInfoProto) { - super(null, Collections.emptyList()); - valueInfo = valueInfoProto; - standardNamingType = OrderedTensorType.fromOnnxType(valueInfo.getType()); - } - - @Override - public String vespaName() { - return vespaName(valueInfo.getName()); - } - - @Override - protected OrderedTensorType lazyGetType() { - return OrderedTensorType.fromOnnxType(valueInfo.getType(), vespaName() + "_"); - } - - @Override - protected TensorFunction lazyGetFunction() { - TensorFunction output = new VariableTensor(vespaName(), standardNamingType.type()); - if (!standardNamingType.equals(type)) { - List<String> renameFrom = standardNamingType.dimensionNames(); - List<String> renameTo = type.dimensionNames(); - output = new Rename(output, renameFrom, renameTo); - } - return output; - } - - @Override - public void addDimensionNameConstraints(DimensionRenamer renamer) { - for (TensorType.Dimension dimension : type.type().dimensions()) { - renamer.addDimension(dimension.name()); - } - } - - @Override - public boolean isInput() { - return true; - } - - @Override - public boolean isConstant() { - return false; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java deleted file mode 100644 index b1136a0ce0a..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations; - -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; -import com.yahoo.tensor.functions.TensorFunction; -import onnx.Onnx; - -import java.util.Collections; -import java.util.List; - -public class NoOp extends OnnxOperation { - - public NoOp(Onnx.NodeProto node, List<OnnxOperation> inputs) { - super(node, Collections.emptyList()); // don't propagate inputs - } - - @Override - protected OrderedTensorType lazyGetType() { - return null; - } - - @Override - protected TensorFunction lazyGetFunction() { - return null; - } - - @Override - public boolean isConstant() { - return true; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java deleted file mode 100644 index 30f7b4f4711..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations; - -import com.yahoo.searchlib.rankingexpression.Reference; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; -import com.yahoo.tensor.functions.TensorFunction; -import onnx.Onnx; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Optional; -import java.util.function.Function; - -/** - * Wraps an ONNX node and produces the respective Vespa tensor operation. - * During import, a graph of these operations are constructed. Then, the - * types are used to deduce sensible dimension names using the - * DimensionRenamer. After the types have been renamed, the proper - * Vespa expressions can be extracted. - * - * @author lesters - */ -public abstract class OnnxOperation { - - protected final Onnx.NodeProto node; // can be null for onnx inputs and constants - protected final List<OnnxOperation> inputs; - protected final List<OnnxOperation> outputs = new ArrayList<>(); - protected final List<String> importWarnings = new ArrayList<>(); - - protected OrderedTensorType type; - protected TensorFunction function; - protected Value constantValue = null; - - OnnxOperation(Onnx.NodeProto node, List<OnnxOperation> inputs) { - this.node = node; - this.inputs = Collections.unmodifiableList(inputs); - this.inputs.forEach(i -> i.outputs.add(this)); - } - - protected abstract OrderedTensorType lazyGetType(); - protected abstract TensorFunction lazyGetFunction(); - - /** Returns the Vespa tensor type of this operation if it exists */ - public Optional<OrderedTensorType> type() { - if (type == null) { - type = lazyGetType(); - } - return Optional.ofNullable(type); - } - - /** Returns the Vespa tensor function implementing all operations from this node with inputs */ - public Optional<TensorFunction> function() { - if (function == null) { - if (isConstant()) { - ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName())); - function = new TensorFunctionNode.TensorFunctionExpressionNode(constant); - } else { - function = lazyGetFunction(); - } - } - return Optional.ofNullable(function); - } - - /** Return Onnx node */ - public Onnx.NodeProto node() { return node; } - - /** Return unmodifiable list of inputs */ - public List<OnnxOperation> inputs() { return inputs; } - - /** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a macro. */ - public List<OnnxOperation> outputs() { return Collections.unmodifiableList(outputs); } - - /** Add dimension name constraints for this operation */ - public void addDimensionNameConstraints(DimensionRenamer renamer) { } - - /** Performs dimension rename for this operation */ - public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); } - - /** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */ - public boolean isInput() { return false; } - - /** Return true if this node is constant */ - public boolean isConstant() { return inputs.stream().allMatch(OnnxOperation::isConstant); } - - /** Gets the constant value if it exists */ - public Optional<Value> getConstantValue() { return Optional.ofNullable(constantValue); } - - /** Retrieve the valid Vespa name of this node */ - public String vespaName() { return vespaName(node.getName()); } - public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; } - - /** Retrieve the list of warnings produced during its lifetime */ - public List<String> warnings() { return Collections.unmodifiableList(importWarnings); } - - /** Set an input warning */ - public void warning(String warning) { importWarnings.add(warning); } - - boolean verifyInputs(int expected, Function<OnnxOperation, Optional<?>> func) { - if (inputs.size() != expected) { - throw new IllegalArgumentException("Expected " + expected + " inputs " + - "for '" + node.getName() + "', got " + inputs.size()); - } - return inputs.stream().map(func).allMatch(Optional::isPresent); - } - - boolean allInputTypesPresent(int expected) { - return verifyInputs(expected, OnnxOperation::type); - } - - boolean allInputFunctionsPresent(int expected) { - return verifyInputs(expected, OnnxOperation::function); - } - - /** - * A method signature input and output has the form name:index. - * This returns the name part without the index. - */ - public static String namePartOf(String name) { - name = name.startsWith("^") ? name.substring(1) : name; - return name.split(":")[0]; - } - - /** - * This return the output index part. Indexes are used for nodes with - * multiple outputs. - */ - public static int indexPartOf(String name) { - int i = name.indexOf(":"); - return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1)); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java deleted file mode 100644 index e3c72830095..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ /dev/null @@ -1,411 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; - -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OperationMapper; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Const; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Variable; -import com.yahoo.searchlib.rankingexpression.parser.ParseException; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.functions.Rename; -import com.yahoo.tensor.functions.TensorFunction; -import com.yahoo.yolean.Exceptions; -import org.tensorflow.SavedModelBundle; -import org.tensorflow.Session; -import org.tensorflow.framework.GraphDef; -import org.tensorflow.framework.MetaGraphDef; -import org.tensorflow.framework.NodeDef; -import org.tensorflow.framework.SignatureDef; -import org.tensorflow.framework.TensorInfo; - -import java.io.File; -import java.io.IOException; -import java.util.Collection; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.logging.Logger; -import java.util.stream.Collectors; - -/** - * Converts a saved TensorFlow model into a ranking expression and set of constants. - * - * @author bratseth - * @author lesters - */ -public class TensorFlowImporter { - - private static final Logger log = Logger.getLogger(TensorFlowImporter.class.getName()); - - /** - * Imports a saved TensorFlow model from a directory. - * The model should be saved as a .pbtxt or .pb file. - * The name of the model is taken as the db/pbtxt file name (not including the file ending). - * - * @param modelName the name of the model to import, consisting of characters in [A-Za-z0-9_] - * @param modelDir the directory containing the TensorFlow model files to import - */ - public TensorFlowModel importModel(String modelName, String modelDir) { - try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { - - return importModel(modelName, model); - } - catch (IllegalArgumentException e) { - throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e); - } - } - - public TensorFlowModel importModel(String modelName, File modelDir) { - return importModel(modelName, modelDir.toString()); - } - - /** Imports a TensorFlow model */ - public TensorFlowModel importModel(String modelName, SavedModelBundle model) { - try { - return importGraph(modelName, MetaGraphDef.parseFrom(model.metaGraphDef()), model); - } - catch (IOException e) { - throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e); - } - } - - /** - * Imports the TensorFlow graph by first importing the tensor types, then - * finding a suitable set of dimensions names for each - * placeholder/constant/variable, then importing the expressions. - */ - private static TensorFlowModel importGraph(String modelName, MetaGraphDef graph, SavedModelBundle bundle) { - TensorFlowModel model = new TensorFlowModel(modelName); - OperationIndex index = new OperationIndex(); - - importSignatures(graph, model); - importNodes(graph, model, index); - findDimensionNames(model, index); - importExpressions(model, index, bundle); - - reportWarnings(model, index); - logVariableTypes(index); - - return model; - } - - private static void importSignatures(MetaGraphDef graph, TensorFlowModel model) { - for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) { - String signatureName = signatureEntry.getKey(); - TensorFlowModel.Signature signature = model.signature(signatureName); - - Map<String, TensorInfo> inputInfoMap = signatureEntry.getValue().getInputsMap(); - for (Map.Entry<String, TensorInfo> input : inputInfoMap.entrySet()) { - String inputName = input.getKey(); - signature.input(inputName, namePartOf(input.getValue().getName())); - } - - Map<String, TensorInfo> outputInfoMap = signatureEntry.getValue().getOutputsMap(); - for (Map.Entry<String, TensorInfo> output : outputInfoMap.entrySet()) { - String outputName = output.getKey(); - signature.output(outputName, namePartOf(output.getValue().getName())); - } - } - } - - private static boolean isSignatureInput(TensorFlowModel model, TensorFlowOperation operation) { - for (TensorFlowModel.Signature signature : model.signatures().values()) { - for (String inputName : signature.inputs().values()) { - if (inputName.equals(operation.node().getName())) { - return true; - } - } - } - return false; - } - - private static boolean isSignatureOutput(TensorFlowModel model, TensorFlowOperation operation) { - for (TensorFlowModel.Signature signature : model.signatures().values()) { - for (String outputName : signature.outputs().values()) { - if (outputName.equals(operation.node().getName())) { - return true; - } - } - } - return false; - } - - private static void importNodes(MetaGraphDef graph, TensorFlowModel model, OperationIndex index) { - for (TensorFlowModel.Signature signature : model.signatures().values()) { - for (String outputName : signature.outputs().values()) { - importNode(model.name(), outputName, graph.getGraphDef(), index); - } - } - } - - private static TensorFlowOperation importNode(String modelName, String nodeName, GraphDef graph, OperationIndex index) { - if (index.alreadyImported(nodeName)) { - return index.get(nodeName); - } - NodeDef node = getTensorFlowNodeFromGraph(namePartOf(nodeName), graph); - List<TensorFlowOperation> inputs = importNodeInputs(modelName, node, graph, index); - TensorFlowOperation operation = OperationMapper.get(modelName, node, inputs, portPartOf(nodeName)); - index.put(nodeName, operation); - - List<TensorFlowOperation> controlInputs = importControlInputs(modelName, node, graph, index); - if (controlInputs.size() > 0) { - operation.setControlInputs(controlInputs); - } - - return operation; - } - - private static List<TensorFlowOperation> importNodeInputs(String modelName, NodeDef node, GraphDef graph, OperationIndex index) { - return node.getInputList().stream() - .filter(name -> ! isControlDependency(name)) - .map(nodeName -> importNode(modelName, nodeName, graph, index)) - .collect(Collectors.toList()); - } - - private static List<TensorFlowOperation> importControlInputs(String modelName, NodeDef node, GraphDef graph, OperationIndex index) { - return node.getInputList().stream() - .filter(nodeName -> isControlDependency(nodeName)) - .map(nodeName -> importNode(modelName, nodeName, graph, index)) - .collect(Collectors.toList()); - } - - private static boolean isControlDependency(String name) { - return name.startsWith("^"); - } - - /** Find dimension names to avoid excessive renaming while evaluating the model. */ - private static void findDimensionNames(TensorFlowModel model, OperationIndex index) { - DimensionRenamer renamer = new DimensionRenamer(); - for (TensorFlowModel.Signature signature : model.signatures().values()) { - for (String output : signature.outputs().values()) { - addDimensionNameConstraints(index.get(output), renamer); - } - } - renamer.solve(); - for (TensorFlowModel.Signature signature : model.signatures().values()) { - for (String output : signature.outputs().values()) { - renameDimensions(index.get(output), renamer); - } - } - } - - private static void addDimensionNameConstraints(TensorFlowOperation operation, DimensionRenamer renamer) { - if (operation.type().isPresent()) { - operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer)); - operation.addDimensionNameConstraints(renamer); - } - } - - private static void renameDimensions(TensorFlowOperation operation, DimensionRenamer renamer) { - if (operation.type().isPresent()) { - operation.inputs().forEach(input -> renameDimensions(input, renamer)); - operation.renameDimensions(renamer); - } - } - - private static void importExpressions(TensorFlowModel model, OperationIndex index, SavedModelBundle bundle) { - for (TensorFlowModel.Signature signature : model.signatures().values()) { - for (String outputName : signature.outputs().values()) { - try { - Optional<TensorFunction> function = importExpression(index.get(outputName), model, bundle); - if (!function.isPresent()) { - signature.skippedOutput(outputName, "No valid output function could be found."); - } - } - catch (IllegalArgumentException e) { - signature.skippedOutput(outputName, Exceptions.toMessageString(e)); - } - } - } - } - - private static Optional<TensorFunction> importExpression(TensorFlowOperation operation, TensorFlowModel model, SavedModelBundle bundle) { - if (!operation.type().isPresent()) { - return Optional.empty(); - } - if (operation.isConstant()) { - return importConstant(model, operation, bundle); - } - - importInputExpressions(operation, model, bundle); - importRankingExpression(model, operation); - importInputExpression(model, operation); - importMacroExpression(model, operation); - - return operation.function(); - } - - private static void importInputExpressions(TensorFlowOperation operation, TensorFlowModel model, - SavedModelBundle bundle) { - operation.inputs().forEach(input -> importExpression(input, model, bundle)); - } - - private static void importMacroExpression(TensorFlowModel model, TensorFlowOperation operation) { - if (operation.macro().isPresent()) { - TensorFunction function = operation.macro().get(); - try { - model.macro(operation.macroName(), new RankingExpression(operation.macroName(), function.toString())); - } - catch (ParseException e) { - throw new RuntimeException("Tensorflow function " + function + - " cannot be parsed as a ranking expression", e); - } - } - } - - private static Optional<TensorFunction> importConstant(TensorFlowModel model, TensorFlowOperation operation, - SavedModelBundle bundle) { - String name = operation.vespaName(); - if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) { - return operation.function(); - } - - Tensor tensor; - if (operation.getConstantValue().isPresent()) { - Value value = operation.getConstantValue().get(); - if ( ! (value instanceof TensorValue)) { - return operation.function(); // scalar values are inserted directly into the expression - } - tensor = value.asTensor(); - } else { - // Here we use the type from the operation, which will have correct dimension names after name resolving - tensor = TensorConverter.toVespaTensor(readVariable(operation.node().getName(), bundle), - operation.type().get()); - operation.setConstantValue(new TensorValue(tensor)); - } - - if (tensor.type().rank() == 0) { - model.smallConstant(name, tensor); - } else { - model.largeConstant(name, tensor); - } - return operation.function(); - } - - static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) { - Session.Runner fetched = bundle.session().runner().fetch(name); - List<org.tensorflow.Tensor<?>> importedTensors = fetched.run(); - if (importedTensors.size() != 1) - throw new IllegalStateException("Expected 1 tensor from fetching " + name + - ", but got " + importedTensors.size()); - return importedTensors.get(0); - } - - private static void importRankingExpression(TensorFlowModel model, TensorFlowOperation operation) { - if (operation.function().isPresent()) { - String name = operation.node().getName(); - if (!model.expressions().containsKey(operation.node().getName())) { - TensorFunction function = operation.function().get(); - - // Make sure output adheres to standard naming convention - if (isSignatureOutput(model, operation)) { - OrderedTensorType operationType = operation.type().get(); - OrderedTensorType standardNamingType = OrderedTensorType.fromTensorFlowType(operation.node()); - if ( ! operationType.equals(standardNamingType)) { - List<String> renameFrom = operationType.dimensionNames(); - List<String> renameTo = standardNamingType.dimensionNames(); - function = new Rename(function, renameFrom, renameTo); - } - } - - try { - // We add all intermediate nodes imported as separate expressions. Only - // those referenced in a signature output will be used. We parse the - // TensorFunction here to convert it to a RankingExpression tree. - model.expression(name, new RankingExpression(name, function.toString())); - } - catch (ParseException e) { - throw new RuntimeException("Tensorflow function " + function + - " cannot be parsed as a ranking expression", e); - } - } - } - } - - private static void importInputExpression(TensorFlowModel model, TensorFlowOperation operation) { - if (operation.isInput() && isSignatureInput(model, operation)) { - // All inputs must have dimensions with standard naming convention: d0, d1, ... - OrderedTensorType standardNamingConvention = OrderedTensorType.fromTensorFlowType(operation.node()); - model.argument(operation.node().getName(), standardNamingConvention.type()); - model.requiredMacro(operation.vespaName(), standardNamingConvention.type()); - } - } - - private static void reportWarnings(TensorFlowModel model, OperationIndex index) { - for (TensorFlowModel.Signature signature : model.signatures().values()) { - for (String output : signature.outputs().values()) { - reportWarnings(index.get(output), signature); - } - } - } - - /** - * Log all TensorFlow Variables (i.e file constants) imported as part of this with their ordered type. - * This allows users to learn the exact types (including dimension order after renaming) of the Variables - * such that these can be converted and fed to a parent document independently of the rest of the model - * for fast model weight updates. - */ - private static void logVariableTypes(OperationIndex index) { - for (TensorFlowOperation operation : index.operations()) { - if ( ! (operation instanceof Variable)) continue; - if ( ! operation.type().isPresent()) continue; // will not happen - - log.info("Importing TensorFlow variable " + operation.node().getName() + " as " + operation.vespaName() + - " of type " + operation.type().get()); - } - } - - private static void reportWarnings(TensorFlowOperation operation, TensorFlowModel.Signature signature) { - for (String warning : operation.warnings()) { - signature.importWarning(warning); - } - for (TensorFlowOperation input : operation.inputs()) { - reportWarnings(input, signature); - } - } - - private static NodeDef getTensorFlowNodeFromGraph(String name, GraphDef graph) { - for (NodeDef node : graph.getNodeList()) { - if (node.getName().equals(name)) { - return node; - } - } - throw new IllegalArgumentException("Could not find node '" + name + "'"); - } - - /** - * A method signature input and output has the form name:index. - * This returns the name part without the index. - */ - private static String namePartOf(String name) { - name = name.startsWith("^") ? name.substring(1) : name; - return name.split(":")[0]; - } - - /** - * This return the output port part. Indexes are used for nodes with - * multiple outputs. - */ - private static int portPartOf(String name) { - int i = name.indexOf(":"); - return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1)); - } - - private static class OperationIndex { - - private final Map<String, TensorFlowOperation> index = new HashMap<>(); - public TensorFlowOperation put(String key, TensorFlowOperation operation) { return index.put(key, operation); } - public TensorFlowOperation get(String key) { return index.get(key); } - public boolean alreadyImported(String key) { return index.containsKey(key); } - public Collection<TensorFlowOperation> operations() { return index.values(); } - - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java deleted file mode 100644 index c1665d066a4..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java +++ /dev/null @@ -1,210 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer; - -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation; - -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Deque; -import java.util.HashMap; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; - -/** - * A constraint satisfier to find suitable dimension names to reduce the - * amount of necessary renaming during evaluation of an imported model. - * - * @author lesters - */ -public class DimensionRenamer { - - private final String dimensionPrefix; - private final Map<String, List<Integer>> variables = new HashMap<>(); - private final Map<Arc, Constraint> constraints = new HashMap<>(); - private final Map<String, Integer> renames = new HashMap<>(); - - private int iterations = 0; - - public DimensionRenamer() { - this("d"); - } - - public DimensionRenamer(String dimensionPrefix) { - this.dimensionPrefix = dimensionPrefix; - } - - /** - * Add a dimension name variable. - */ - public void addDimension(String name) { - variables.computeIfAbsent(name, d -> new ArrayList<>()); - } - - /** - * Add a constraint between dimension names. - */ - public void addConstraint(String from, String to, Constraint pred, TensorFlowOperation operation) { - Arc arc = new Arc(from, to, operation); - Arc opposite = arc.opposite(); - constraints.put(arc, pred); - constraints.put(opposite, (x,y) -> pred.test(y, x)); // make constraint graph symmetric - } - - /** - * Retrieve resulting name of dimension after solving for constraints. - */ - public Optional<String> dimensionNameOf(String name) { - if (!renames.containsKey(name)) { - return Optional.empty(); - } - return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name))); - } - - /** - * Perform iterative arc consistency until we have found a solution. After - * an initial iteration, the variables (dimensions) will have multiple - * valid values. Find a single valid assignment by iteratively locking one - * dimension after another, and running the arc consistency algorithm - * multiple times. - * - * This requires having constraints that result in an absolute ordering: - * equals, lesserThan and greaterThan do that, but adding notEquals does - * not typically result in a guaranteed ordering. If that is needed, the - * algorithm below needs to be adapted with a backtracking (tree) search - * to find solutions. - */ - public void solve(int maxIterations) { - initialize(); - - // Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts - - for (String dimension : variables.keySet()) { - List<Integer> values = variables.get(dimension); - if (values.size() > 1) { - if (!ac3()) { - throw new IllegalArgumentException("Dimension renamer unable to find a solution."); - } - values.sort(Integer::compare); - variables.put(dimension, Collections.singletonList(values.get(0))); - } - renames.put(dimension, variables.get(dimension).get(0)); - if (iterations > maxIterations) { - throw new IllegalArgumentException("Dimension renamer unable to find a solution within " + - maxIterations + " iterations"); - } - } - - // Todo: handle failure more gracefully: - // If a solution can't be found, look at the operation node in the arc - // with the most remaining constraints, and inject a rename operation. - // Then run this algorithm again. - } - - public void solve() { - solve(100000); - } - - private void initialize() { - for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) { - List<Integer> values = variable.getValue(); - for (int i = 0; i < variables.size(); ++i) { - values.add(i); // invariant: values are in increasing order - } - } - } - - private boolean ac3() { - Deque<Arc> workList = new ArrayDeque<>(constraints.keySet()); - while (!workList.isEmpty()) { - Arc arc = workList.pop(); - iterations += 1; - if (revise(arc)) { - if (variables.get(arc.from).size() == 0) { - return false; // no solution found - } - for (Arc constraint : constraints.keySet()) { - if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) { - workList.add(constraint); - } - } - } - } - return true; - } - - private boolean revise(Arc arc) { - boolean revised = false; - for(Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) { - Integer from = fromIterator.next(); - boolean satisfied = false; - for (Iterator<Integer> toIterator = variables.get(arc.to).iterator(); toIterator.hasNext(); ) { - Integer to = toIterator.next(); - if (constraints.get(arc).test(from, to)) { - satisfied = true; - } - } - if (!satisfied) { - fromIterator.remove(); - revised = true; - } - } - return revised; - } - - public interface Constraint { - boolean test(Integer x, Integer y); - } - - public static boolean equals(Integer x, Integer y) { - return Objects.equals(x, y); - } - - public static boolean lesserThan(Integer x, Integer y) { - return x < y; - } - - public static boolean greaterThan(Integer x, Integer y) { - return x > y; - } - - private static class Arc { - - private final String from; - private final String to; - private final TensorFlowOperation operation; - - Arc(String from, String to, TensorFlowOperation operation) { - this.from = from; - this.to = to; - this.operation = operation; - } - - Arc opposite() { - return new Arc(to, from, operation); - } - - @Override - public int hashCode() { - return Objects.hash(from, to); - } - - @Override - public boolean equals(Object obj) { - if (obj == null || !(obj instanceof Arc)) { - return false; - } - Arc other = (Arc) obj; - return Objects.equals(from, other.from) && Objects.equals(to, other.to); - } - - @Override - public String toString() { - return String.format("%s -> %s", from, to); - } - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java deleted file mode 100644 index b665413a6b2..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer; - -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.ConcatV2; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Const; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.ExpandDims; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Identity; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Join; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Map; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Matmul; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Mean; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Merge; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.NoOp; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Placeholder; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.PlaceholderWithDefault; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Reshape; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Select; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Shape; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Squeeze; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Switch; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Variable; -import com.yahoo.tensor.functions.ScalarFunctions; -import org.tensorflow.framework.NodeDef; - -import java.util.List; - -/** - * Maps from TensorFlow operations to Vespa operations. - * - * @author bratseth - * @author lesters - */ -public class OperationMapper { - - public static TensorFlowOperation get(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - switch (node.getOp().toLowerCase()) { - // array ops - case "concatv2": return new ConcatV2(modelName, node, inputs, port); - case "const": return new Const(modelName, node, inputs, port); - case "expanddims": return new ExpandDims(modelName, node, inputs, port); - case "identity": return new Identity(modelName, node, inputs, port); - case "placeholder": return new Placeholder(modelName, node, inputs, port); - case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, node, inputs, port); - case "reshape": return new Reshape(modelName, node, inputs, port); - case "shape": return new Shape(modelName, node, inputs, port); - case "squeeze": return new Squeeze(modelName, node, inputs, port); - - // control flow - case "merge": return new Merge(modelName, node, inputs, port); - case "switch": return new Switch(modelName, node, inputs, port); - - // math ops - case "add": return new Join(modelName, node, inputs, port, ScalarFunctions.add()); - case "add_n": return new Join(modelName, node, inputs, port, ScalarFunctions.add()); - case "acos": return new Map(modelName, node, inputs, port, ScalarFunctions.acos()); - case "div": return new Join(modelName, node, inputs, port, ScalarFunctions.divide()); - case "realdiv": return new Join(modelName, node, inputs, port, ScalarFunctions.divide()); - case "floor": return new Map(modelName, node, inputs, port, ScalarFunctions.floor()); - case "matmul": return new Matmul(modelName, node, inputs, port); - case "maximum": return new Join(modelName, node, inputs, port, ScalarFunctions.max()); - case "mean": return new Mean(modelName, node, inputs, port); - case "reducemean": return new Mean(modelName, node, inputs, port); - case "mul": return new Join(modelName, node, inputs, port, ScalarFunctions.multiply()); - case "multiply": return new Join(modelName, node, inputs, port, ScalarFunctions.multiply()); - case "rsqrt": return new Map(modelName, node, inputs, port, ScalarFunctions.rsqrt()); - case "select": return new Select(modelName, node, inputs, port); - case "where3": return new Select(modelName, node, inputs, port); - case "sigmoid": return new Map(modelName, node, inputs, port, ScalarFunctions.sigmoid()); - case "squareddifference": return new Join(modelName, node, inputs, port, ScalarFunctions.squareddifference()); - case "sub": return new Join(modelName, node, inputs, port, ScalarFunctions.subtract()); - case "subtract": return new Join(modelName, node, inputs, port, ScalarFunctions.subtract()); - - // nn ops - case "biasadd": return new Join(modelName, node, inputs, port, ScalarFunctions.add()); - case "elu": return new Map(modelName, node, inputs, port, ScalarFunctions.elu()); - case "relu": return new Map(modelName, node, inputs, port, ScalarFunctions.relu()); - case "selu": return new Map(modelName, node, inputs, port, ScalarFunctions.selu()); - - // state ops - case "variable": return new Variable(modelName, node, inputs, port); - case "variablev2": return new Variable(modelName, node, inputs, port); - - // evaluation no-ops - case "stopgradient":return new Identity(modelName, node, inputs, port); - case "noop": return new NoOp(modelName, node, inputs, port); - } - - TensorFlowOperation op = new NoOp(modelName, node, inputs, port); - op.warning("Operation '" + node.getOp() + "' is currently not implemented"); - return op; - } - -} - - - diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java deleted file mode 100644 index 03a65333192..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java +++ /dev/null @@ -1,255 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer; - -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.TensorTypeParser; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.NodeDef; -import org.tensorflow.framework.TensorShapeProto; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Optional; -import java.util.stream.Collectors; - -/** - * A Vespa tensor type is ordered by the lexicographical ordering of dimension - * names. TensorFlow tensors have an explicit ordering of their dimensions. - * During import, we need to track the Vespa dimension that matches the - * corresponding TensorFlow dimension as the ordering can change after - * dimension renaming. That is the purpose of this class. - * - * @author lesters - */ -public class OrderedTensorType { - - private final TensorType type; - private final List<TensorType.Dimension> dimensions; - - private final long[] innerSizesTensorFlow; - private final long[] innerSizesVespa; - private final int[] dimensionMap; - - private OrderedTensorType(List<TensorType.Dimension> dimensions) { - this.dimensions = Collections.unmodifiableList(dimensions); - this.type = new TensorType.Builder(dimensions).build(); - this.innerSizesTensorFlow = new long[dimensions.size()]; - this.innerSizesVespa = new long[dimensions.size()]; - this.dimensionMap = createDimensionMap(); - } - - public TensorType type() { - return this.type; - } - - public int rank() { return dimensions.size(); } - - public List<TensorType.Dimension> dimensions() { - return dimensions; - } - - public List<String> dimensionNames() { - return dimensions.stream().map(TensorType.Dimension::name).collect(Collectors.toList()); - } - - private int[] createDimensionMap() { - int numDimensions = dimensions.size(); - if (numDimensions == 0) { - return null; - } - innerSizesTensorFlow[numDimensions - 1] = 1; - innerSizesVespa[numDimensions - 1] = 1; - for (int i = numDimensions - 1; --i >= 0; ) { - innerSizesTensorFlow[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesTensorFlow[i+1]; - innerSizesVespa[i] = type.dimensions().get(i+1).size().orElse(-1L) * innerSizesVespa[i+1]; - } - int[] mapping = new int[numDimensions]; - for (int i = 0; i < numDimensions; ++i) { - TensorType.Dimension dim1 = dimensions().get(i); - for (int j = 0; j < numDimensions; ++j) { - TensorType.Dimension dim2 = type.dimensions().get(j); - if (dim1.equals(dim2)) { - mapping[i] = j; - break; - } - } - } - return mapping; - } - - /** - * When dimension ordering between Vespa and TensorFlow differs, i.e. - * after dimension renaming, use the dimension map to read in values - * so that they are correctly laid out in memory for Vespa. - * Used when importing tensors from TensorFlow. - */ - public int toDirectIndex(int index) { - if (dimensions.size() == 0) { - return 0; - } - if (dimensionMap == null) { - throw new IllegalArgumentException("Dimension map is not available"); - } - int directIndex = 0; - long rest = index; - for (int i = 0; i < dimensions.size(); ++i) { - long address = rest / innerSizesTensorFlow[i]; - directIndex += innerSizesVespa[dimensionMap[i]] * address; - rest %= innerSizesTensorFlow[i]; - } - return directIndex; - } - - @Override - public boolean equals(Object obj) { - if (obj == null || !(obj instanceof OrderedTensorType)) { - return false; - } - OrderedTensorType other = (OrderedTensorType) obj; - if (dimensions.size() != dimensions.size()) { - return false; - } - List<TensorType.Dimension> thisDimensions = this.dimensions(); - List<TensorType.Dimension> otherDimensions = other.dimensions(); - for (int i = 0; i < thisDimensions.size(); ++i) { - if (!thisDimensions.get(i).equals(otherDimensions.get(i))) { - return false; - } - } - return true; - } - - public void verifyType(NodeDef node) { - TensorShapeProto shape = tensorFlowShape(node); - if (shape != null) { - if (shape.getDimCount() != type.rank()) { - throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " + - "does not match Vespa shape"); - } - for (int tensorFlowIndex = 0; tensorFlowIndex < dimensions.size(); ++tensorFlowIndex) { - int vespaIndex = dimensionMap[tensorFlowIndex]; - TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex); - TensorType.Dimension vespaDimension = type().dimensions().get(vespaIndex); - if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) { - throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " + - "does not match Vespa dimensions"); - } - } - } - } - - private static TensorShapeProto tensorFlowShape(NodeDef node) { - AttrValue attrValueList = node.getAttrMap().get("_output_shapes"); - if (attrValueList == null) { - throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + - "does not exist"); - } - if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) { - throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + - "is not of expected type"); - } - List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList(); - return shapeList.get(0); // support multiple outputs? - } - - public OrderedTensorType rename(DimensionRenamer renamer) { - List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size()); - for (TensorType.Dimension dimension : dimensions) { - String oldName = dimension.name(); - Optional<String> newName = renamer.dimensionNameOf(oldName); - if (!newName.isPresent()) - return this; // presumably, already renamed - TensorType.Dimension.Type dimensionType = dimension.type(); - if (dimensionType == TensorType.Dimension.Type.indexedBound) { - renamedDimensions.add(TensorType.Dimension.indexed(newName.get(), dimension.size().get())); - } else if (dimensionType == TensorType.Dimension.Type.indexedUnbound) { - renamedDimensions.add(TensorType.Dimension.indexed(newName.get())); - } else if (dimensionType == TensorType.Dimension.Type.mapped) { - renamedDimensions.add(TensorType.Dimension.mapped(newName.get())); - } - } - return new OrderedTensorType(renamedDimensions); - } - - /** - * Returns a string representation of this: A standard tensor type string where dimensions - * are listed in the order of this rather than in the natural order of their names. - */ - @Override - public String toString() { - return "tensor(" + dimensions.stream().map(TensorType.Dimension::toString).collect(Collectors.joining(",")) + ")"; - } - - /** - * Creates an instance from the string representation of this: A standard tensor type string - * where dimensions are listed in the order of this rather than the natural order of their names. - */ - public static OrderedTensorType fromSpec(String typeSpec) { - return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec(typeSpec)); - } - - public static OrderedTensorType fromTensorFlowType(NodeDef node) { - return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ... - } - - public static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) { - Builder builder = new Builder(node); - TensorShapeProto shape = tensorFlowShape(node); - for (int i = 0; i < shape.getDimCount(); ++ i) { - String dimensionName = dimensionPrefix + i; - TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i); - if (tensorFlowDimension.getSize() >= 0) { - builder.add(TensorType.Dimension.indexed(dimensionName, tensorFlowDimension.getSize())); - } else { - builder.add(TensorType.Dimension.indexed(dimensionName)); - } - } - return builder.build(); - } - - public static class Builder { - - private final TensorShapeProto shape; - private final List<TensorType.Dimension> dimensions; - - public Builder(NodeDef node) { - this.shape = tensorFlowShape(node); - this.dimensions = new ArrayList<>(shape.getDimCount()); - } - - public Builder add(TensorType.Dimension vespaDimension) { - int index = dimensions.size(); - TensorShapeProto.Dim tensorFlowDimension = shape.getDim(index); - long size = tensorFlowDimension.getSize(); - if (size >= 0) { - if (vespaDimension.type() != TensorType.Dimension.Type.indexedBound) { - throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " + - "dimension types"); - } - if (!vespaDimension.size().isPresent()) { - throw new IllegalArgumentException("Tensor dimension is indexed bound but does " + - "not have a size"); - } - if (vespaDimension.size().get() != size) { - throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " + - "dimension sizes. TensorFlow: " + size + " Vespa: " + - vespaDimension.size().get()); - } - } else { - if (vespaDimension.type() != TensorType.Dimension.Type.indexedUnbound) { - throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " + - "dimension types"); - } - } - this.dimensions.add(vespaDimension); - return this; - } - - public OrderedTensorType build() { - return new OrderedTensorType(dimensions); - } - - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java deleted file mode 100644 index 6cbfe0dfb05..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; - -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.Reduce; -import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; - -import java.util.ArrayList; -import java.util.List; -import java.util.Optional; -import java.util.function.DoubleBinaryOperator; - -public class Join extends TensorFlowOperation { - - private final DoubleBinaryOperator operator; - - public Join(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port, DoubleBinaryOperator operator) { - super(modelName, node, inputs, port); - this.operator = operator; - } - - @Override - protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } - OrderedTensorType a = largestInput().type().get(); - OrderedTensorType b = smallestInput().type().get(); - - // Well now we have potentially entered the wonderful world of "broadcasting" - // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html - // In broadcasting, the size of each dimension is compared element-wise, - // starting with the trailing dimensions and working forward. A special - // case occurs when the size of one dimension is 1, while the other is not. - // Then the dimension with size 1 is "stretched" to be of compatible size. - // - // An example: - // - // Tensor A: d0[5], d1[1], d2[3], d3[1] - // Tensor B: d1[4], d2[1], d3[2] - // - // In TensorFlow and using the above rules of broadcasting, the resulting - // type is: - // d0[5], d1[4], d2[3], d2[2] - // - // However, in Vespa's tensor logic, the join of the two above tensors would - // result in a tensor of type: - // d0[5], d1[1], d2[1], d3[1] - // - // By reducing the dimensions of size 1 in each tensor before joining, - // we get equal results as in TensorFlow. - - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node); - int sizeDifference = a.rank() - b.rank(); - for (int i = 0; i < a.rank(); ++i) { - TensorType.Dimension aDim = a.dimensions().get(i); - long size = aDim.size().orElse(-1L); - - if (i - sizeDifference >= 0) { - TensorType.Dimension bDim = b.dimensions().get(i - sizeDifference); - size = Math.max(size, bDim.size().orElse(-1L)); - } - - if (aDim.type() == TensorType.Dimension.Type.indexedBound) { - builder.add(TensorType.Dimension.indexed(aDim.name(), size)); - } else if (aDim.type() == TensorType.Dimension.Type.indexedUnbound) { - builder.add(TensorType.Dimension.indexed(aDim.name())); - } else if (aDim.type() == TensorType.Dimension.Type.mapped) { - builder.add(TensorType.Dimension.mapped(aDim.name())); - } - } - return builder.build(); - } - - @Override - protected TensorFunction lazyGetFunction() { - if (!allInputTypesPresent(2)) { - return null; - } - if (!allInputFunctionsPresent(2)) { - return null; - } - - TensorFlowOperation a = largestInput(); - TensorFlowOperation b = smallestInput(); - - List<String> aDimensionsToReduce = new ArrayList<>(); - List<String> bDimensionsToReduce = new ArrayList<>(); - int sizeDifference = a.type().get().rank() - b.type().get().rank(); - for (int i = 0; i < b.type().get().rank(); ++i) { - TensorType.Dimension bDim = b.type().get().dimensions().get(i); - TensorType.Dimension aDim = a.type().get().dimensions().get(i + sizeDifference); - long bSize = bDim.size().orElse(-1L); - long aSize = aDim.size().orElse(-1L); - if (bSize == 1L && aSize != 1L) { - bDimensionsToReduce.add(bDim.name()); - } - if (aSize == 1L && bSize != 1L) { - aDimensionsToReduce.add(bDim.name()); - } - } - - TensorFunction aReducedFunction = a.function().get(); - if (aDimensionsToReduce.size() > 0) { - aReducedFunction = new Reduce(a.function().get(), Reduce.Aggregator.sum, aDimensionsToReduce); - } - TensorFunction bReducedFunction = b.function().get(); - if (bDimensionsToReduce.size() > 0) { - bReducedFunction = new Reduce(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce); - } - - return new com.yahoo.tensor.functions.Join(aReducedFunction, bReducedFunction, operator); - } - - @Override - public void addDimensionNameConstraints(DimensionRenamer renamer) { - if (!allInputTypesPresent(2)) { - return; - } - OrderedTensorType a = largestInput().type().get(); - OrderedTensorType b = smallestInput().type().get(); - int sizeDifference = a.rank() - b.rank(); - for (int i = 0; i < b.rank(); ++i) { - String bDim = b.dimensions().get(i).name(); - String aDim = a.dimensions().get(i + sizeDifference).name(); - renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this); - } - } - - private TensorFlowOperation largestInput() { - OrderedTensorType a = inputs.get(0).type().get(); - OrderedTensorType b = inputs.get(1).type().get(); - return a.rank() >= b.rank() ? inputs.get(0) : inputs.get(1); - } - - private TensorFlowOperation smallestInput() { - OrderedTensorType a = inputs.get(0).type().get(); - OrderedTensorType b = inputs.get(1).type().get(); - return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java deleted file mode 100644 index b2b9530a161..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; - -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; - -import java.util.List; -import java.util.Optional; - -public class Matmul extends TensorFlowOperation { - - public Matmul(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); - } - - @Override - protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } - OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node); - typeBuilder.add(inputs.get(0).type().get().dimensions().get(0)); - typeBuilder.add(inputs.get(1).type().get().dimensions().get(1)); - return typeBuilder.build(); - } - - @Override - protected TensorFunction lazyGetFunction() { - if (!allInputTypesPresent(2)) { - return null; - } - OrderedTensorType aType = inputs.get(0).type().get(); - OrderedTensorType bType = inputs.get(1).type().get(); - if (aType.type().rank() < 2 || bType.type().rank() < 2) - throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2"); - if (aType.type().rank() != bType.type().rank()) - throw new IllegalArgumentException("Tensors in matmul must have the same rank"); - - Optional<TensorFunction> aFunction = inputs.get(0).function(); - Optional<TensorFunction> bFunction = inputs.get(1).function(); - if (!aFunction.isPresent() || !bFunction.isPresent()) { - return null; - } - return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name()); - } - - @Override - public void addDimensionNameConstraints(DimensionRenamer renamer) { - if (!allInputTypesPresent(2)) { - return; - } - List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions(); - List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions(); - - String aDim0 = aDimensions.get(0).name(); - String aDim1 = aDimensions.get(1).name(); - String bDim0 = bDimensions.get(0).name(); - String bDim1 = bDimensions.get(1).name(); - - // The second dimension of a should have the same name as the first dimension of b - renamer.addConstraint(aDim1, bDim0, DimensionRenamer::equals, this); - - // The first dimension of a should have a different name than the second dimension of b - renamer.addConstraint(aDim0, bDim1, DimensionRenamer::lesserThan, this); - - // For efficiency, the dimensions to join over should be innermost - soft constraint - renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this); - renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java deleted file mode 100644 index d558ec89e87..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; - -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; -import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; - -import java.util.Collections; -import java.util.List; - -public class NoOp extends TensorFlowOperation { - - public NoOp(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, Collections.emptyList(), port); // don't propagate inputs - } - - @Override - protected OrderedTensorType lazyGetType() { - return null; - } - - @Override - protected TensorFunction lazyGetFunction() { - return null; - } - - @Override - public boolean isConstant() { - return true; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java deleted file mode 100644 index b18a8a9b212..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; - -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.framework.NodeDef; - -import java.util.List; - -public class Variable extends TensorFlowOperation { - - public Variable(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) { - super(modelName, node, inputs, port); - } - - /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ - @Override - public String vespaName() { - return modelName() + "_" + super.vespaName(); - } - - @Override - protected OrderedTensorType lazyGetType() { - return OrderedTensorType.fromTensorFlowType(node, super.vespaName() + "_"); - } - - @Override - protected TensorFunction lazyGetFunction() { - return null; // will be added by function() since this is constant. - } - - @Override - public void addDimensionNameConstraints(DimensionRenamer renamer) { - for (TensorType.Dimension dimension : type.type().dimensions()) { - renamer.addDimension(dimension.name()); - } - } - - @Override - public boolean isConstant() { - return true; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/package-info.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/package-info.java deleted file mode 100644 index 9e53990a9d6..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/package-info.java +++ /dev/null @@ -1,8 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -/** - * Tensorflow integration - */ -@ExportPackage -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; - -import com.yahoo.osgi.annotation.ExportPackage; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java index 0f5eec93feb..bf9684082f4 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; +package com.yahoo.searchlib.rankingexpression.integration.ml; import com.yahoo.searchlib.rankingexpression.RankingExpression; import org.junit.Test; @@ -15,7 +15,7 @@ public class BatchNormImportTestCase { @Test public void testBatchNormImport() { TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/batch_norm/saved"); - TensorFlowModel.Signature signature = model.get().signature("serving_default"); + ImportedModel.Signature signature = model.get().signature("serving_default"); assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java index 74b0d11f1d6..c8c7ec798bb 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java @@ -1,6 +1,6 @@ -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; +package com.yahoo.searchlib.rankingexpression.integration.ml; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; import org.junit.Test; import static org.junit.Assert.assertTrue; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java index 50a467ec581..a63c7346335 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; +package com.yahoo.searchlib.rankingexpression.integration.ml; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.TensorType; @@ -24,7 +24,7 @@ public class DropoutImportTestCase { assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), model.get().requiredMacros().get("X")); - TensorFlowModel.Signature signature = model.get().signature("serving_default"); + ImportedModel.Signature signature = model.get().signature("serving_default"); assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); @@ -32,7 +32,7 @@ public class DropoutImportTestCase { RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("outputs/Maximum", output.getName()); - assertEquals("join(join(tf_macro_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), tf_macro_test_outputs_BiasAdd, f(a,b)(max(a,b)))", + assertEquals("join(join(imported_ml_macro_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_macro_test_outputs_BiasAdd, f(a,b)(max(a,b)))", output.getRoot().toString()); model.assertEqualResult("X", output.getName()); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java index 9f919c452d6..bd7644be23b 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; +package com.yahoo.searchlib.rankingexpression.integration.ml; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; @@ -45,7 +45,7 @@ public class MnistSoftmaxImportTestCase { // Check signatures assertEquals(1, model.get().signatures().size()); - TensorFlowModel.Signature signature = model.get().signatures().get("serving_default"); + ImportedModel.Signature signature = model.get().signatures().get("serving_default"); assertNotNull(signature); // ... signature inputs diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java index 4b68cd40a08..a7926cd2e02 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java @@ -1,11 +1,9 @@ -package com.yahoo.searchlib.rankingexpression.integration.onnx; +package com.yahoo.searchlib.rankingexpression.integration.ml; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; @@ -24,7 +22,7 @@ public class OnnxMnistSoftmaxImportTestCase { @Test public void testMnistSoftmaxImport() throws IOException { - OnnxModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx"); + ImportedModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx"); // Check constants assertEquals(2, model.largeConstants().size()); @@ -48,7 +46,7 @@ public class OnnxMnistSoftmaxImportTestCase { model.requiredMacros().get("Placeholder")); // Check outputs - RankingExpression output = model.outputExpression("add"); + RankingExpression output = model.defaultSignature().outputExpression("add"); assertNotNull(output); assertEquals("add", output.getName()); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))", @@ -68,13 +66,12 @@ public class OnnxMnistSoftmaxImportTestCase { } private Tensor evaluateTensorFlowModel(String path, Tensor argument, String input, String output) { - SavedModelBundle tensorFlowModel = SavedModelBundle.load(path, "serve"); - TensorFlowModel model = new TensorFlowImporter().importModel("test", tensorFlowModel); + ImportedModel model = new TensorFlowImporter().importModel("test", path); return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input); } private Tensor evaluateOnnxModel(String path, Tensor argument, String input, String output) { - OnnxModel model = new OnnxImporter().importModel("test", path); + ImportedModel model = new OnnxImporter().importModel("test", path); return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input); } @@ -83,14 +80,7 @@ public class OnnxMnistSoftmaxImportTestCase { return expression.evaluate(context).asTensor(); } - private Context contextFrom(TensorFlowModel result) { - MapContext context = new MapContext(); - result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); - result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); - return context; - } - - private Context contextFrom(OnnxModel result) { + private Context contextFrom(ImportedModel result) { MapContext context = new MapContext(); result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java index beec2ab1ead..b2443082ab1 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java @@ -1,6 +1,6 @@ -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; +package com.yahoo.searchlib.rankingexpression.integration.ml; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import org.junit.Test; import static org.junit.Assert.assertEquals; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java index 7ca16939477..723c5f27914 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java @@ -1,11 +1,11 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; +package com.yahoo.searchlib.rankingexpression.integration.ml; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; @@ -28,7 +28,7 @@ import static org.junit.Assert.assertEquals; public class TestableTensorFlowModel { private SavedModelBundle tensorFlowModel; - private TensorFlowModel model; + private ImportedModel model; // Sizes of the input vector private final int d0Size = 1; @@ -39,7 +39,7 @@ public class TestableTensorFlowModel { model = new TensorFlowImporter().importModel(modelName, tensorFlowModel); } - public TensorFlowModel get() { return model; } + public ImportedModel get() { return model; } public void assertEqualResult(String inputName, String operationName) { Tensor tfResult = tensorFlowExecute(tensorFlowModel, inputName, operationName); @@ -66,7 +66,7 @@ public class TestableTensorFlowModel { return TensorConverter.toVespaTensor(results.get(0)); } - private Context contextFrom(TensorFlowModel result) { + private Context contextFrom(ImportedModel result) { MapContext context = new MapContext(); result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); @@ -81,7 +81,7 @@ public class TestableTensorFlowModel { return b.build(); } - private void evaluateMacro(Context context, TensorFlowModel model, String macroName) { + private void evaluateMacro(Context context, ImportedModel model, String macroName) { if (!context.names().contains(macroName)) { RankingExpression e = model.macros().get(macroName); evaluateMacroDependencies(context, model, e.getRoot()); @@ -89,7 +89,7 @@ public class TestableTensorFlowModel { } } - private void evaluateMacroDependencies(Context context, TensorFlowModel model, ExpressionNode node) { + private void evaluateMacroDependencies(Context context, ImportedModel model, ExpressionNode node) { if (node instanceof ReferenceNode) { String name = node.toString(); if (model.macros().containsKey(name)) { diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java index 051c2c60c95..f94098e6255 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java @@ -1,4 +1,4 @@ -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; +package com.yahoo.searchlib.rankingexpression.integration.ml; import org.junit.Test; |