summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-06-06 15:43:18 +0200
committerLester Solbakken <lesters@oath.com>2018-06-06 15:43:18 +0200
commit0bf235c481d24d627c82901a84bef585fe84bbb2 (patch)
tree6cb6d0b192f56f3e8fdb533fb9603d3f927fe3c1 /searchlib
parent389801098797ab37c7bc4ac5a3888ef4d92214e7 (diff)
Refactor ONNX and TF import to use same code base
This reverts commit 681963959794b47102d1a1cf72f215c72b0e2b51.
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java)101
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java242
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java30
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java47
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverter.java)9
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamer.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/DimensionRenamer.java)10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java107
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorType.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OrderedTensorType.java)154
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java216
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/TensorConverter.java)6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java52
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java)19
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ConcatV2.java)31
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java)53
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java)31
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java)21
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java)13
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java)118
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Join.java)22
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java)11
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/MatMul.java)15
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java)29
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java)15
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java26
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java)11
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java)24
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java)19
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java)13
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java)33
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java)22
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java85
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java234
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java)3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java72
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/package-info.java)2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java326
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java112
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java26
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java64
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java32
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java139
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java411
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java210
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java97
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java255
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java145
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java74
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java32
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java46
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/package-info.java8
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java)4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java)4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java)6
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java)4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java)22
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java)4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java)14
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java)2
58 files changed, 1526 insertions, 2407 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
index 721214f9e94..4b49f17f74e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
@@ -1,5 +1,4 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
@@ -13,76 +12,61 @@ import java.util.Map;
import java.util.regex.Pattern;
/**
- * The result of importing a TensorFlow model into Vespa.
- * - A set of signatures which are named collections of inputs and outputs.
- * - A set of named constant tensors represented by Variable nodes in TensorFlow.
- * - A list of warning messages.
+ * The result of importing a model (TensorFlow or ONNX) into Vespa.
*
* @author bratseth
*/
-// This object can be built incrementally within this package, but is immutable when observed from outside the package
-public class TensorFlowModel {
+public class ImportedModel {
- private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*");
+ private static final String defaultSignatureName = "default";
+ private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*");
private final String name;
+ private final Map<String, Signature> signatures = new HashMap<>();
+ private final Map<String, TensorType> arguments = new HashMap<>();
+ private final Map<String, Tensor> smallConstants = new HashMap<>();
+ private final Map<String, Tensor> largeConstants = new HashMap<>();
+ private final Map<String, RankingExpression> expressions = new HashMap<>();
+ private final Map<String, RankingExpression> macros = new HashMap<>();
+ private final Map<String, TensorType> requiredMacros = new HashMap<>();
+
/**
- * Creates a TensorFlow model
+ * Creates a new imported model.
*
* @param name the name of this mode, containing only characters in [A-Za-z0-9_]
*/
- public TensorFlowModel(String name) {
+ public ImportedModel(String name) {
if ( ! nameRegexp.matcher(name).matches())
- throw new IllegalArgumentException("A TensorFlow model name can only contain [A-Za-z0-9_], but is '" +
- name + "'");
+ throw new IllegalArgumentException("An imported model name can only contain [A-Za-z0-9_], but is '" +
+ name + "'");
this.name = name;
}
/** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */
public String name() { return name; }
- private final Map<String, Signature> signatures = new HashMap<>();
- private final Map<String, TensorType> arguments = new HashMap<>();
- private final Map<String, Tensor> smallConstants = new HashMap<>();
- private final Map<String, Tensor> largeConstants = new HashMap<>();
- private final Map<String, RankingExpression> expressions = new HashMap<>();
- private final Map<String, RankingExpression> macros = new HashMap<>();
- private final Map<String, TensorType> requiredMacros = new HashMap<>();
-
- void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
- void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
- void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
- void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
- void macro(String name, RankingExpression expression) { macros.put(name, expression); }
- void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); }
-
- /** Returns the given signature. If it does not already exist it is added to this. */
- Signature signature(String name) {
- return signatures.computeIfAbsent(name, Signature::new);
- }
-
/** Returns an immutable map of the arguments ("Placeholders") of this */
public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
/**
* Returns an immutable map of the small constants of this.
* These should have sizes up to a few kb at most, and correspond to constant
- * values given in the TensorFlow source.
+ * values given in the TensorFlow or ONNX source.
*/
public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); }
/**
* Returns an immutable map of the large constants of this.
- * These can have sizes in gigabytes and must be distributed to nodes separately from configuration,
- * and correspond to Variable files stored separately in TensorFlow.
+ * These can have sizes in gigabytes and must be distributed to nodes separately from configuration.
+ * For TensorFlow this corresponds to Variable files stored separately.
*/
public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); }
/**
- * Returns an immutable map of the expressions of this - corresponding to TensorFlow nodes
- * which are not Placeholders or Variables (which instead become respectively arguments and constants).
- * Note that only nodes recursively referenced by a placeholder are added.
+ * Returns an immutable map of the expressions of this - corresponding to graph nodes
+ * which are not Inputs/Placeholders or Variables (which instead become respectively arguments and constants).
+ * Note that only nodes recursively referenced by a placeholder/input are added.
*/
public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
@@ -95,9 +79,26 @@ public class TensorFlowModel {
/** Returns an immutable map of the signatures of this */
public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); }
+ /** Returns the given signature. If it does not already exist it is added to this. */
+ Signature signature(String name) {
+ return signatures.computeIfAbsent(name, Signature::new);
+ }
+
+ /** Convenience method for returning a default signature */
+ Signature defaultSignature() { return signature(defaultSignatureName); }
+
+ void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
+ void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
+ void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
+ void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
+ void macro(String name, RankingExpression expression) { macros.put(name, expression); }
+ void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); }
+
/**
- * A signature is a set of named inputs and outputs, where the inputs maps to argument ("placeholder") names+types,
- * and outputs maps to expressions nodes.
+ * A signature is a set of named inputs and outputs, where the inputs maps to argument
+ * ("placeholder") names+types, and outputs maps to expressions nodes.
+ * Note that TensorFlow supports multiple signatures in their format, but ONNX has no explicit
+ * concept of signatures. For now, we handle ONNX models as having a single signature.
*/
public class Signature {
@@ -107,19 +108,14 @@ public class TensorFlowModel {
private final Map<String, String> skippedOutputs = new HashMap<>();
private final List<String> importWarnings = new ArrayList<>();
- Signature(String name) {
+ public Signature(String name) {
this.name = name;
}
- void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); }
- void output(String name, String expressionName) { outputs.put(name, expressionName); }
- void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); }
- void importWarning(String warning) { importWarnings.add(warning); }
-
public String name() { return name; }
/** Returns the result this is part of */
- TensorFlowModel owner() { return TensorFlowModel.this; }
+ public ImportedModel owner() { return ImportedModel.this; }
/**
* Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name
@@ -127,7 +123,7 @@ public class TensorFlowModel {
*/
public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); }
- /** Returns owner().arguments().get(inputs.get(name)), e.g the type of the argument this input references */
+ /** Returns the type of the argument this input references */
public TensorType inputArgument(String inputName) { return owner().arguments().get(inputs.get(inputName)); }
/** Returns an immutable list of the expression names of this */
@@ -144,12 +140,17 @@ public class TensorFlowModel {
*/
public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); }
- /** Returns owner().expressions().get(outputs.get(outputName)), e.g the expression this output references */
+ /** Returns the expression this output references */
public RankingExpression outputExpression(String outputName) { return owner().expressions().get(outputs.get(outputName)); }
@Override
public String toString() { return "signature '" + name + "'"; }
+ void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); }
+ void output(String name, String expressionName) { outputs.put(name, expressionName); }
+ void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); }
+ void importWarning(String warning) { importWarnings.add(warning); }
+
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
new file mode 100644
index 00000000000..a658833b426
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
@@ -0,0 +1,242 @@
+package com.yahoo.searchlib.rankingexpression.integration.ml;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.yolean.Exceptions;
+
+import java.io.File;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.logging.Logger;
+
+/**
+ * Base class for importing ML models (ONNX/TensorFlow) as native Vespa
+ * ranking expressions. The general mechanism for import is for the
+ * specific ML platform import implementations to create an
+ * IntermediateGraph. This class offers common code to convert the
+ * IntermediateGraph to Vespa ranking expressions and macros.
+ *
+ * @author lesters
+ */
+public abstract class ModelImporter {
+
+ private static final Logger log = Logger.getLogger(ModelImporter.class.getName());
+
+ /**
+ * The main import function.
+ */
+ public abstract ImportedModel importModel(String modelName, String modelPath);
+
+ public ImportedModel importModel(String modelName, File modelDir) {
+ return importModel(modelName, modelDir.toString());
+ }
+
+ /**
+ * Takes an IntermediateGraph and converts it to a ImportedModel containing
+ * the actual Vespa ranking expressions.
+ */
+ static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph) {
+ ImportedModel model = new ImportedModel(graph.name());
+
+ graph.optimize();
+
+ importSignatures(graph, model);
+ importExpressions(graph, model);
+ reportWarnings(graph, model);
+ logVariableTypes(graph);
+
+ return model;
+ }
+
+ private static void importSignatures(IntermediateGraph graph, ImportedModel model) {
+ for (String signatureName : graph.signatures()) {
+ ImportedModel.Signature signature = model.signature(signatureName);
+ for (Map.Entry<String, String> input : graph.inputs(signatureName).entrySet()) {
+ signature.input(input.getKey(), input.getValue());
+ }
+ for (Map.Entry<String, String> output : graph.outputs(signatureName).entrySet()) {
+ signature.output(output.getKey(), output.getValue());
+ }
+ }
+ }
+
+ private static boolean isSignatureInput(ImportedModel model, IntermediateOperation operation) {
+ for (ImportedModel.Signature signature : model.signatures().values()) {
+ for (String inputName : signature.inputs().values()) {
+ if (inputName.equals(operation.name())) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
+ private static boolean isSignatureOutput(ImportedModel model, IntermediateOperation operation) {
+ for (ImportedModel.Signature signature : model.signatures().values()) {
+ for (String outputName : signature.outputs().values()) {
+ if (outputName.equals(operation.name())) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
+ /**
+ * Convert intermediate representation to Vespa ranking expressions.
+ */
+ static void importExpressions(IntermediateGraph graph, ImportedModel model) {
+ for (ImportedModel.Signature signature : model.signatures().values()) {
+ for (String outputName : signature.outputs().values()) {
+ try {
+ Optional<TensorFunction> function = importExpression(graph.get(outputName), model);
+ if (!function.isPresent()) {
+ signature.skippedOutput(outputName, "No valid output function could be found.");
+ }
+ }
+ catch (IllegalArgumentException e) {
+ signature.skippedOutput(outputName, Exceptions.toMessageString(e));
+ }
+ }
+ }
+ }
+
+ private static Optional<TensorFunction> importExpression(IntermediateOperation operation, ImportedModel model) {
+ if (!operation.type().isPresent()) {
+ return Optional.empty();
+ }
+ if (operation.isConstant()) {
+ return importConstant(operation, model);
+ }
+ importExpressionInputs(operation, model);
+ importRankingExpression(operation, model);
+ importArgumentExpression(operation, model);
+ importMacroExpression(operation, model);
+
+ return operation.function();
+ }
+
+ private static void importExpressionInputs(IntermediateOperation operation, ImportedModel model) {
+ operation.inputs().forEach(input -> importExpression(input, model));
+ }
+
+ private static Optional<TensorFunction> importConstant(IntermediateOperation operation, ImportedModel model) {
+ String name = operation.vespaName();
+ if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) {
+ return operation.function();
+ }
+
+ Value value = operation.getConstantValue().orElseThrow(() ->
+ new IllegalArgumentException("Operation '" + operation.vespaName() + "' " +
+ "is constant but does not have a value."));
+ if ( ! (value instanceof TensorValue)) {
+ return operation.function(); // scalar values are inserted directly into the expression
+ }
+
+ Tensor tensor = value.asTensor();
+ if (tensor.type().rank() == 0) {
+ model.smallConstant(name, tensor);
+ } else {
+ model.largeConstant(name, tensor);
+ }
+ return operation.function();
+ }
+
+ private static void importRankingExpression(IntermediateOperation operation, ImportedModel model) {
+ if (operation.function().isPresent()) {
+ String name = operation.name();
+ if (!model.expressions().containsKey(name)) {
+ TensorFunction function = operation.function().get();
+
+ if (isSignatureOutput(model, operation)) {
+ OrderedTensorType operationType = operation.type().get();
+ OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType);
+ if ( ! operationType.equals(standardNamingType)) {
+ List<String> renameFrom = operationType.dimensionNames();
+ List<String> renameTo = standardNamingType.dimensionNames();
+ function = new Rename(function, renameFrom, renameTo);
+ }
+ }
+
+ try {
+ // We add all intermediate nodes imported as separate expressions. Only
+ // those referenced from the output will be used. We parse the
+ // TensorFunction here to convert it to a RankingExpression tree.
+ model.expression(name, new RankingExpression(name, function.toString()));
+ }
+ catch (ParseException e) {
+ throw new RuntimeException("Imported function " + function +
+ " cannot be parsed as a ranking expression", e);
+ }
+ }
+ }
+ }
+
+ private static void importArgumentExpression(IntermediateOperation operation, ImportedModel model) {
+ if (operation.isInput()) {
+ // All inputs must have dimensions with standard naming convention: d0, d1, ...
+ OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get());
+ model.argument(operation.vespaName(), standardNamingConvention.type());
+ model.requiredMacro(operation.vespaName(), standardNamingConvention.type());
+ }
+ }
+
+ private static void importMacroExpression(IntermediateOperation operation, ImportedModel model) {
+ if (operation.macro().isPresent()) {
+ TensorFunction function = operation.macro().get();
+ try {
+ model.macro(operation.macroName(), new RankingExpression(operation.macroName(), function.toString()));
+ }
+ catch (ParseException e) {
+ throw new RuntimeException("Tensorflow function " + function +
+ " cannot be parsed as a ranking expression", e);
+ }
+ }
+ }
+
+ /**
+ * Add any import warnings to the signature in the ImportedModel.
+ */
+ private static void reportWarnings(IntermediateGraph graph, ImportedModel model) {
+ for (ImportedModel.Signature signature : model.signatures().values()) {
+ for (String outputName : signature.outputs().values()) {
+ reportWarnings(graph.get(outputName), model);
+ }
+ }
+ }
+
+ private static void reportWarnings(IntermediateOperation operation, ImportedModel model) {
+ for (String warning : operation.warnings()) {
+ model.defaultSignature().importWarning(warning);
+ }
+ for (IntermediateOperation input : operation.inputs()) {
+ reportWarnings(input, model);
+ }
+ }
+
+ /**
+ * Log all TensorFlow Variables (i.e file constants) imported as part of this with their ordered type.
+ * This allows users to learn the exact types (including dimension order after renaming) of the Variables
+ * such that these can be converted and fed to a parent document independently of the rest of the model
+ * for fast model weight updates.
+ */
+ private static void logVariableTypes(IntermediateGraph graph) {
+ for (IntermediateOperation operation : graph.operations()) {
+ if ( ! (operation instanceof Constant)) continue;
+ if ( ! operation.type().isPresent()) continue; // will not happen
+ log.info("Importing TensorFlow variable " + operation.name() + " as " + operation.vespaName() +
+ " of type " + operation.type().get());
+ }
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
new file mode 100644
index 00000000000..d3dd2a1d418
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
@@ -0,0 +1,30 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchlib.rankingexpression.integration.ml;
+
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx.GraphImporter;
+import onnx.Onnx;
+
+import java.io.FileInputStream;
+import java.io.IOException;
+
+/**
+ * Converts a ONNX model into a ranking expression and set of constants.
+ *
+ * @author lesters
+ */
+public class OnnxImporter extends ModelImporter {
+
+ @Override
+ public ImportedModel importModel(String modelName, String modelPath) {
+ try (FileInputStream inputStream = new FileInputStream(modelPath)) {
+ Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
+ IntermediateGraph graph = GraphImporter.importGraph(modelName, model);
+ return convertIntermediateGraphToModel(graph);
+ } catch (IOException e) {
+ throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e);
+ }
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java
new file mode 100644
index 00000000000..ff584559a83
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java
@@ -0,0 +1,47 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.ml;
+
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.GraphImporter;
+import org.tensorflow.SavedModelBundle;
+
+import java.io.IOException;
+
+/**
+ * Converts a saved TensorFlow model into a ranking expression and set of constants.
+ *
+ * @author bratseth
+ * @author lesters
+ */
+public class TensorFlowImporter extends ModelImporter {
+
+ /**
+ * Imports a saved TensorFlow model from a directory.
+ * The model should be saved as a .pbtxt or .pb file.
+ * The name of the model is taken as the db/pbtxt file name (not including the file ending).
+ *
+ * @param modelName the name of the model to import, consisting of characters in [A-Za-z0-9_]
+ * @param modelDir the directory containing the TensorFlow model files to import
+ */
+ public ImportedModel importModel(String modelName, String modelDir) {
+ try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
+ return importModel(modelName, model);
+ }
+ catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
+ }
+ }
+
+ /** Imports a TensorFlow model */
+ ImportedModel importModel(String modelName, SavedModelBundle model) {
+ try {
+ IntermediateGraph graph = GraphImporter.importGraph(modelName, model);
+ return convertIntermediateGraphToModel(graph);
+ }
+ catch (IOException e) {
+ throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e);
+ }
+ }
+
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java
index c5ac7ace0fc..e1294ec3e01 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java
@@ -1,7 +1,8 @@
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.GraphImporter;
import com.yahoo.tensor.serialization.JsonFormat;
import com.yahoo.yolean.Exceptions;
import org.tensorflow.SavedModelBundle;
@@ -24,7 +25,7 @@ public class VariableConverter {
*/
public static byte[] importVariable(String modelDir, String tensorFlowVariableName, String orderedTypeSpec) {
try (SavedModelBundle bundle = SavedModelBundle.load(modelDir, "serve")) {
- return JsonFormat.encode(TensorConverter.toVespaTensor(TensorFlowImporter.readVariable(tensorFlowVariableName,
+ return JsonFormat.encode(TensorConverter.toVespaTensor(GraphImporter.readVariable(tensorFlowVariableName,
bundle),
OrderedTensorType.fromSpec(orderedTypeSpec)));
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/DimensionRenamer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamer.java
index 2524417cee0..38f1d2329e2 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/DimensionRenamer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamer.java
@@ -1,7 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
import java.util.ArrayDeque;
import java.util.ArrayList;
@@ -47,7 +47,7 @@ public class DimensionRenamer {
/**
* Add a constraint between dimension names.
*/
- public void addConstraint(String from, String to, Constraint pred, OnnxOperation operation) {
+ public void addConstraint(String from, String to, Constraint pred, IntermediateOperation operation) {
Arc arc = new Arc(from, to, operation);
Arc opposite = arc.opposite();
constraints.put(arc, pred);
@@ -175,9 +175,9 @@ public class DimensionRenamer {
private final String from;
private final String to;
- private final OnnxOperation operation;
+ private final IntermediateOperation operation;
- Arc(String from, String to, OnnxOperation operation) {
+ Arc(String from, String to, IntermediateOperation operation) {
this.from = from;
this.to = to;
this.operation = operation;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java
new file mode 100644
index 00000000000..39a8b211d09
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java
@@ -0,0 +1,107 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
+
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
+
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Holds an intermediate representation of an imported ONNX or TensorFlow
+ * graph. After this intermediate representation is constructed, it is used to
+ * simplify and optimize the computational graph and then converted into the
+ * final ImportedModel that holds the Vespa ranking expressions for the model.
+ *
+ * @author lesters
+ */
+public class IntermediateGraph {
+
+ private final String modelName;
+ private final Map<String, IntermediateOperation> index = new HashMap<>();
+ private final Map<String, GraphSignature> signatures = new HashMap<>();
+
+ private static class GraphSignature {
+ final Map<String, String> inputs = new HashMap<>();
+ final Map<String, String> outputs = new HashMap<>();
+ }
+
+ public IntermediateGraph(String modelName) {
+ this.modelName = modelName;
+ }
+
+ public String name() {
+ return modelName;
+ }
+
+ public IntermediateOperation put(String key, IntermediateOperation operation) {
+ return index.put(key, operation);
+ }
+
+ public IntermediateOperation get(String key) {
+ return index.get(key);
+ }
+
+ public Set<String> signatures() {
+ return signatures.keySet();
+ }
+
+ public Map<String, String> inputs(String signature) {
+ return signatures.computeIfAbsent(signature, (k) -> new GraphSignature()).inputs;
+ }
+
+ public Map<String, String> outputs(String signature) {
+ return signatures.computeIfAbsent(signature, (k) -> new GraphSignature()).outputs;
+ }
+
+ public String defaultSignature() {
+ return "default";
+ }
+
+ public boolean alreadyImported(String key) {
+ return index.containsKey(key);
+ }
+
+ public Collection<IntermediateOperation> operations() {
+ return index.values();
+ }
+
+ public void optimize() {
+ renameDimensions();
+ }
+
+ /**
+ * Find dimension names to avoid excessive renaming while evaluating the model.
+ */
+ private void renameDimensions() {
+ DimensionRenamer renamer = new DimensionRenamer();
+ for (String signature : signatures()) {
+ for (String output : outputs(signature).values()) {
+ addDimensionNameConstraints(index.get(output), renamer);
+ }
+ }
+ renamer.solve();
+ for (String signature : signatures()) {
+ for (String output : outputs(signature).values()) {
+ renameDimensions(index.get(output), renamer);
+ }
+ }
+ }
+
+ private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer) {
+ if (operation.type().isPresent()) {
+ operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer));
+ operation.addDimensionNameConstraints(renamer);
+ }
+ }
+
+ private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer) {
+ if (operation.type().isPresent()) {
+ operation.inputs().forEach(input -> renameDimensions(input, renamer));
+ operation.renameDimensions(renamer);
+ }
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorType.java
index 812e9b8d678..209d73a9f38 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OrderedTensorType.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorType.java
@@ -1,9 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
import com.yahoo.tensor.TensorType;
-import onnx.Onnx;
+import com.yahoo.tensor.TensorTypeParser;
import java.util.ArrayList;
import java.util.Collections;
@@ -13,9 +13,9 @@ import java.util.stream.Collectors;
/**
* A Vespa tensor type is ordered by the lexicographical ordering of dimension
- * names. ONNX tensors have an explicit ordering of their dimensions.
+ * names. Imported tensors have an explicit ordering of their dimensions.
* During import, we need to track the Vespa dimension that matches the
- * corresponding ONNX dimension as the ordering can change after
+ * corresponding imported dimension as the ordering can change after
* dimension renaming. That is the purpose of this class.
*
* @author lesters
@@ -25,14 +25,14 @@ public class OrderedTensorType {
private final TensorType type;
private final List<TensorType.Dimension> dimensions;
- private final long[] innerSizesOnnx;
+ private final long[] innerSizesOriginal;
private final long[] innerSizesVespa;
private final int[] dimensionMap;
private OrderedTensorType(List<TensorType.Dimension> dimensions) {
this.dimensions = Collections.unmodifiableList(dimensions);
this.type = new TensorType.Builder(dimensions).build();
- this.innerSizesOnnx = new long[dimensions.size()];
+ this.innerSizesOriginal = new long[dimensions.size()];
this.innerSizesVespa = new long[dimensions.size()];
this.dimensionMap = createDimensionMap();
}
@@ -54,10 +54,10 @@ public class OrderedTensorType {
if (numDimensions == 0) {
return null;
}
- innerSizesOnnx[numDimensions - 1] = 1;
+ innerSizesOriginal[numDimensions - 1] = 1;
innerSizesVespa[numDimensions - 1] = 1;
for (int i = numDimensions - 1; --i >= 0; ) {
- innerSizesOnnx[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesOnnx[i+1];
+ innerSizesOriginal[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesOriginal[i+1];
innerSizesVespa[i] = type.dimensions().get(i+1).size().orElse(-1L) * innerSizesVespa[i+1];
}
int[] mapping = new int[numDimensions];
@@ -74,11 +74,15 @@ public class OrderedTensorType {
return mapping;
}
+ public int dimensionMap(int originalIndex) {
+ return dimensionMap[originalIndex];
+ }
+
/**
- * When dimension ordering between Vespa and Onnx differs, i.e.
+ * When dimension ordering between Vespa and imported differs, i.e.
* after dimension renaming, use the dimension map to read in values
* so that they are correctly laid out in memory for Vespa.
- * Used when importing tensors from Onnx.
+ * Used when importing tensors.
*/
public int toDirectIndex(int index) {
if (dimensions.size() == 0) {
@@ -90,9 +94,9 @@ public class OrderedTensorType {
int directIndex = 0;
long rest = index;
for (int i = 0; i < dimensions.size(); ++i) {
- long address = rest / innerSizesOnnx[i];
+ long address = rest / innerSizesOriginal[i];
directIndex += innerSizesVespa[dimensionMap[i]] * address;
- rest %= innerSizesOnnx[i];
+ rest %= innerSizesOriginal[i];
}
return directIndex;
}
@@ -116,22 +120,6 @@ public class OrderedTensorType {
return true;
}
- public void verifyType(Onnx.TypeProto typeProto) {
- Onnx.TensorShapeProto shape = typeProto.getTensorType().getShape();
- if (shape != null) {
- if (shape.getDimCount() != type.rank()) {
- throw new IllegalArgumentException("Onnx shape of does not match Vespa shape");
- }
- for (int onnxIndex = 0; onnxIndex < dimensions.size(); ++onnxIndex) {
- int vespaIndex = dimensionMap[onnxIndex];
- Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(onnxIndex);
- TensorType.Dimension vespaDimension = type().dimensions().get(vespaIndex);
- if (onnxDimension.getDimValue() != vespaDimension.size().orElse(-1L)) {
- throw new IllegalArgumentException("TensorFlow dimensions of does not match Vespa dimensions");
- }
- }
- }
- }
public OrderedTensorType rename(DimensionRenamer renamer) {
List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size());
for (TensorType.Dimension dimension : dimensions) {
@@ -151,18 +139,13 @@ public class OrderedTensorType {
return new OrderedTensorType(renamedDimensions);
}
- public static OrderedTensorType fromOnnxType(Onnx.TypeProto type) {
- return fromOnnxType(type, "d"); // standard naming convention: d0, d1, ...
- }
-
- public static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) {
- Onnx.TensorShapeProto shape = type.getTensorType().getShape();
- Builder builder = new Builder(shape);
- for (int i = 0; i < shape.getDimCount(); ++ i) {
+ public OrderedTensorType rename(String dimensionPrefix) {
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ for (int i = 0; i < dimensions.size(); ++ i) {
String dimensionName = dimensionPrefix + i;
- Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
- if (onnxDimension.getDimValue() >= 0) {
- builder.add(TensorType.Dimension.indexed(dimensionName, onnxDimension.getDimValue()));
+ Optional<Long> dimSize = dimensions.get(i).size();
+ if (dimSize.isPresent() && dimSize.get() >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, dimSize.get()));
} else {
builder.add(TensorType.Dimension.indexed(dimensionName));
}
@@ -170,13 +153,13 @@ public class OrderedTensorType {
return builder.build();
}
- public static OrderedTensorType fromOnnxType(List<Long> dims, String dimensionPrefix) {
- Builder builder = new Builder();
- for (int i = 0; i < dims.size(); ++ i) {
- String dimensionName = dimensionPrefix + i;
- Long dimSize = dims.get(i);
- if (dimSize >= 0) {
- builder.add(TensorType.Dimension.indexed(dimensionName, dimSize));
+ public static OrderedTensorType standardType(OrderedTensorType type) {
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ for (int i = 0; i < type.dimensions().size(); ++ i) {
+ TensorType.Dimension dim = type.dimensions().get(i);
+ String dimensionName = "d" + i;
+ if (dim.size().isPresent() && dim.size().get() >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, dim.size().get()));
} else {
builder.add(TensorType.Dimension.indexed(dimensionName));
}
@@ -184,13 +167,46 @@ public class OrderedTensorType {
return builder.build();
}
- public static OrderedTensorType standardType(OrderedTensorType type) {
- Builder builder = new Builder();
- for (int i = 0; i < type.dimensions().size(); ++ i) {
- TensorType.Dimension dim = type.dimensions().get(i);
- String dimensionName = "d" + i;
- if (dim.size().isPresent() && dim.size().get() >= 0) {
- builder.add(TensorType.Dimension.indexed(dimensionName, dim.size().get()));
+ public static Long tensorSize(TensorType type) {
+ Long size = 1L;
+ for (TensorType.Dimension dimension : type.dimensions()) {
+ size *= dimensionSize(dimension);
+ }
+ return size;
+ }
+
+ public static Long dimensionSize(TensorType.Dimension dim) {
+ return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size"));
+ }
+
+ /**
+ * Returns a string representation of this: A standard tensor type string where dimensions
+ * are listed in the order of this rather than in the natural order of their names.
+ */
+ @Override
+ public String toString() {
+ return "tensor(" + dimensions.stream().map(TensorType.Dimension::toString).collect(Collectors.joining(",")) + ")";
+ }
+
+ /**
+ * Creates an instance from the string representation of this: A standard tensor type string
+ * where dimensions are listed in the order of this rather than the natural order of their names.
+ */
+ public static OrderedTensorType fromSpec(String typeSpec) {
+ return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec(typeSpec));
+ }
+
+ public static OrderedTensorType fromDimensionList(List<Long> dims) {
+ return fromDimensionList(dims, "d"); // standard naming convention: d0, d1, ...
+ }
+
+ public static OrderedTensorType fromDimensionList(List<Long> dims, String dimensionPrefix) {
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ for (int i = 0; i < dims.size(); ++ i) {
+ String dimensionName = dimensionPrefix + i;
+ Long dimSize = dims.get(i);
+ if (dimSize >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, dimSize));
} else {
builder.add(TensorType.Dimension.indexed(dimensionName));
}
@@ -200,45 +216,13 @@ public class OrderedTensorType {
public static class Builder {
- private final Onnx.TensorShapeProto shape;
private final List<TensorType.Dimension> dimensions;
- public Builder(Onnx.TensorShapeProto shape) {
- this.shape = shape;
- this.dimensions = new ArrayList<>(shape.getDimCount());
- }
-
public Builder() {
- this.shape = null;
this.dimensions = new ArrayList<>();
}
public Builder add(TensorType.Dimension vespaDimension) {
- if (shape != null) {
- int index = dimensions.size();
- Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(index);
- long size = onnxDimension.getDimValue();
- if (size >= 0) {
- if (vespaDimension.type() != TensorType.Dimension.Type.indexedBound) {
- throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " +
- "dimension types");
- }
- if (!vespaDimension.size().isPresent()) {
- throw new IllegalArgumentException("Tensor dimension is indexed bound but does " +
- "not have a size");
- }
- if (vespaDimension.size().get() != size) {
- throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " +
- "dimension sizes. TensorFlow: " + size + " Vespa: " +
- vespaDimension.size().get());
- }
- } else {
- if (vespaDimension.type() != TensorType.Dimension.Type.indexedUnbound) {
- throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " +
- "dimension types");
- }
- }
- }
this.dimensions.add(vespaDimension);
return this;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
new file mode 100644
index 00000000000..3fe92440cae
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
@@ -0,0 +1,216 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Argument;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ConcatV2;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Identity;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Join;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Map;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.MatMul;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.NoOp;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Reshape;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Shape;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import onnx.Onnx;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Converts an ONNX graph to a Vespa IntermediateGraph which is the basis
+ * for generating Vespa ranking expressions.
+ *
+ * @author lesters
+ */
+public class GraphImporter {
+
+ public static IntermediateOperation mapOperation(Onnx.NodeProto node,
+ List<IntermediateOperation> inputs,
+ IntermediateGraph graph) {
+ String nodeName = node.getName();
+ String modelName = graph.name();
+
+ switch (node.getOpType().toLowerCase()) {
+ case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs());
+ case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
+ case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
+ case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin());
+ case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan());
+ case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil());
+ case "concat": return new ConcatV2(modelName, nodeName, inputs);
+ case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos());
+ case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
+ case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
+ case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal());
+ case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp());
+ case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor());
+ case "greater": return new Join(modelName, nodeName, inputs, ScalarFunctions.greater());
+ case "identity": return new Identity(modelName, nodeName, inputs);
+ case "less": return new Join(modelName, nodeName, inputs, ScalarFunctions.less());
+ case "log": return new Map(modelName, nodeName, inputs, ScalarFunctions.log());
+ case "matmul": return new MatMul(modelName, nodeName, inputs);
+ case "max": return new Join(modelName, nodeName, inputs, ScalarFunctions.max());
+ case "min": return new Join(modelName, nodeName, inputs, ScalarFunctions.min());
+ case "mean": return new Join(modelName, nodeName, inputs, ScalarFunctions.mean());
+ case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
+ case "neg": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg());
+ case "pow": return new Join(modelName, nodeName, inputs, ScalarFunctions.pow());
+ case "reshape": return new Reshape(modelName, nodeName, inputs);
+ case "reciprocal": return new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal());
+ case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
+ case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
+ case "shape": return new Shape(modelName, nodeName, inputs);
+ case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin());
+ case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt());
+ case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid());
+ case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
+ case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan());
+ case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh());
+ }
+
+ IntermediateOperation op = new NoOp(modelName, node.getName(), inputs);
+ op.warning("Operation '" + node.getOpType() + "' is currently not implemented");
+ return op;
+ }
+
+ public static IntermediateGraph importGraph(String modelName, Onnx.ModelProto model) {
+ Onnx.GraphProto onnxGraph = model.getGraph();
+
+ IntermediateGraph intermediateGraph = new IntermediateGraph(modelName);
+ importOperations(onnxGraph, intermediateGraph);
+ verifyOutputTypes(onnxGraph, intermediateGraph);
+
+ return intermediateGraph;
+ }
+
+ private static void importOperations(Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) {
+ for (Onnx.ValueInfoProto valueInfo : onnxGraph.getOutputList()) {
+ importOperation(valueInfo.getName(), onnxGraph, intermediateGraph);
+ }
+ }
+
+ private static IntermediateOperation importOperation(String name,
+ Onnx.GraphProto onnxGraph,
+ IntermediateGraph intermediateGraph) {
+ if (intermediateGraph.alreadyImported(name)) {
+ return intermediateGraph.get(name);
+ }
+ IntermediateOperation operation;
+ if (isArgumentTensor(name, onnxGraph)) {
+ Onnx.ValueInfoProto valueInfoProto = getArgumentTensor(name, onnxGraph);
+ if (valueInfoProto == null)
+ throw new IllegalArgumentException("Could not find argument tensor: " + name);
+ OrderedTensorType type = TypeConverter.fromOnnxType(valueInfoProto.getType());
+ operation = new Argument(intermediateGraph.name(), valueInfoProto.getName(), type);
+
+ intermediateGraph.inputs(intermediateGraph.defaultSignature())
+ .put(IntermediateOperation.namePartOf(name), operation.vespaName());
+
+ } else if (isConstantTensor(name, onnxGraph)) {
+ Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph);
+ OrderedTensorType defaultType = OrderedTensorType.fromDimensionList(tensorProto.getDimsList());
+ operation = new Constant(intermediateGraph.name(), name, defaultType);
+ operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type)));
+
+ } else {
+ Onnx.NodeProto node = getNodeFromGraph(name, onnxGraph);
+ List<IntermediateOperation> inputs = importOperationInputs(node, onnxGraph, intermediateGraph);
+ operation = mapOperation(node, inputs, intermediateGraph);
+
+ if (isOutputNode(name, onnxGraph)) {
+ intermediateGraph.outputs(intermediateGraph.defaultSignature())
+ .put(IntermediateOperation.namePartOf(name), operation.vespaName());
+ }
+ }
+ intermediateGraph.put(operation.vespaName(), operation);
+
+ return operation;
+ }
+
+ private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) {
+ Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
+ Onnx.TensorProto tensor = getConstantTensor(name, graph);
+ return value != null && tensor == null;
+ }
+
+ private static boolean isConstantTensor(String name, Onnx.GraphProto graph) {
+ Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
+ Onnx.TensorProto tensor = getConstantTensor(name, graph);
+ return value != null && tensor != null;
+ }
+
+ private static Onnx.ValueInfoProto getArgumentTensor(String name, Onnx.GraphProto graph) {
+ for (Onnx.ValueInfoProto valueInfo : graph.getInputList()) {
+ if (valueInfo.getName().equals(name)) {
+ return valueInfo;
+ }
+ }
+ return null;
+ }
+
+ private static Onnx.TensorProto getConstantTensor(String name, Onnx.GraphProto graph) {
+ for (Onnx.TensorProto tensorProto : graph.getInitializerList()) {
+ if (tensorProto.getName().equals(name)) {
+ return tensorProto;
+ }
+ }
+ return null;
+ }
+
+ private static boolean isOutputNode(String name, Onnx.GraphProto graph) {
+ return getOutputNode(name, graph) != null;
+ }
+
+ private static Onnx.ValueInfoProto getOutputNode(String name, Onnx.GraphProto graph) {
+ for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) {
+ if (valueInfo.getName().equals(name)) {
+ return valueInfo;
+ }
+ String nodeName = IntermediateOperation.namePartOf(valueInfo.getName());
+ if (nodeName.equals(name)) {
+ return valueInfo;
+ }
+ }
+ return null;
+ }
+
+ private static List<IntermediateOperation> importOperationInputs(Onnx.NodeProto node,
+ Onnx.GraphProto onnxGraph,
+ IntermediateGraph intermediateGraph) {
+ return node.getInputList().stream()
+ .map(nodeName -> importOperation(nodeName, onnxGraph, intermediateGraph))
+ .collect(Collectors.toList());
+ }
+
+ private static void verifyOutputTypes(Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) {
+ for (String outputName : intermediateGraph.outputs(intermediateGraph.defaultSignature()).values()) {
+ IntermediateOperation operation = intermediateGraph.get(outputName);
+ Onnx.ValueInfoProto onnxNode = getOutputNode(outputName, onnxGraph);
+ OrderedTensorType type = operation.type().orElseThrow(
+ () -> new IllegalArgumentException("Output of '" + outputName + "' has no type."));
+ TypeConverter.verifyType(onnxNode.getType(), type);
+ }
+ }
+
+ private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) {
+ boolean hasPortNumber = nodeName.contains(":");
+ for (Onnx.NodeProto node : graph.getNodeList()) {
+ if (hasPortNumber) {
+ for (String outputName : node.getOutputList()) {
+ if (outputName.equals(nodeName)) {
+ return node;
+ }
+ }
+ } else if (node.getName().equals(nodeName)) {
+ return node;
+ }
+ }
+ throw new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph");
+ }
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java
index 2912db03b5f..18856d4a25f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/TensorConverter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java
@@ -1,17 +1,16 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx;
import com.google.protobuf.ByteString;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
import onnx.Onnx;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
-import java.util.List;
/**
* Converts Onnx tensors into Vespa tensors.
@@ -29,7 +28,6 @@ public class TensorConverter {
return builder.build();
}
- /* todo: support more types */
private static Values readValuesOf(Onnx.TensorProto tensorProto) {
if (tensorProto.hasRawData()) {
switch (tensorProto.getDataType()) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java
new file mode 100644
index 00000000000..715c55d8323
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java
@@ -0,0 +1,52 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx;
+
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import onnx.Onnx;
+
+/**
+ * Converts and verifies ONNX tensor types into Vespa tensor types.
+ *
+ * @author lesters
+ */
+public class TypeConverter {
+
+ public static void verifyType(Onnx.TypeProto typeProto, OrderedTensorType type) {
+ Onnx.TensorShapeProto shape = typeProto.getTensorType().getShape();
+ if (shape != null) {
+ if (shape.getDimCount() != type.rank()) {
+ throw new IllegalArgumentException("Onnx shape of does not match Vespa shape");
+ }
+ for (int onnxIndex = 0; onnxIndex < type.dimensions().size(); ++onnxIndex) {
+ int vespaIndex = type.dimensionMap(onnxIndex);
+ Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(onnxIndex);
+ TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex);
+ if (onnxDimension.getDimValue() != vespaDimension.size().orElse(-1L)) {
+ throw new IllegalArgumentException("Onnx dimensions of does not match Vespa dimensions");
+ }
+ }
+ }
+ }
+
+ public static OrderedTensorType fromOnnxType(Onnx.TypeProto type) {
+ return fromOnnxType(type, "d"); // standard naming convention: d0, d1, ...
+ }
+
+ public static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) {
+ Onnx.TensorShapeProto shape = type.getTensorType().getShape();
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ for (int i = 0; i < shape.getDimCount(); ++ i) {
+ String dimensionName = dimensionPrefix + i;
+ Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
+ if (onnxDimension.getDimValue() >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, onnxDimension.getDimValue()));
+ } else {
+ builder.add(TensorType.Dimension.indexed(dimensionName));
+ }
+ }
+ return builder.build();
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java
index 1619c11427a..7fc2aae87d1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java
@@ -1,28 +1,29 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
+import java.util.Collections;
import java.util.List;
-public class Placeholder extends TensorFlowOperation {
+public class Argument extends IntermediateOperation {
private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ...
- public Placeholder(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
- standardNamingType = OrderedTensorType.fromTensorFlowType(node);
+ public Argument(String modelName, String nodeName, OrderedTensorType type) {
+ super(modelName, nodeName, Collections.emptyList());
+ this.type = type.rename(vespaName() + "_");
+ standardNamingType = OrderedTensorType.standardType(type);
}
@Override
protected OrderedTensorType lazyGetType() {
- return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_");
+ return type;
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ConcatV2.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java
index 4f5d61d75f9..1b8c62fe0e9 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ConcatV2.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java
@@ -1,38 +1,37 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
import java.util.Optional;
-public class ConcatV2 extends TensorFlowOperation {
+public class ConcatV2 extends IntermediateOperation {
private String concatDimensionName;
- public ConcatV2(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public ConcatV2(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
protected OrderedTensorType lazyGetType() {
- if (!inputs.stream().map(TensorFlowOperation::type).allMatch(Optional::isPresent)) {
+ if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) {
return null;
}
- TensorFlowOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input
+ IntermediateOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input
if (!concatDimOp.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " +
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
"concat dimension must be a constant.");
}
Tensor concatDimTensor = concatDimOp.getConstantValue().get().asTensor();
if (concatDimTensor.type().rank() != 0) {
- throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " +
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
"concat dimension must be a scalar.");
}
@@ -44,7 +43,7 @@ public class ConcatV2 extends TensorFlowOperation {
for (int i = 1; i < inputs.size() - 1; ++i) {
OrderedTensorType bType = inputs.get(i).type().get();
if (bType.rank() != aType.rank()) {
- throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " +
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
"inputs must have save rank.");
}
for (int j = 0; j < aType.rank(); ++j) {
@@ -53,13 +52,13 @@ public class ConcatV2 extends TensorFlowOperation {
if (j == concatDim) {
concatDimSize += dimSizeB;
} else if (dimSizeA != dimSizeB) {
- throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " +
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
"input dimension " + j + " differs in input tensors.");
}
}
}
- OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node);
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
int dimensionIndex = 0;
for (TensorType.Dimension dimension : aType.dimensions()) {
if (dimensionIndex == concatDim) {
@@ -75,7 +74,7 @@ public class ConcatV2 extends TensorFlowOperation {
@Override
protected TensorFunction lazyGetFunction() {
- if (!inputs.stream().map(TensorFlowOperation::function).allMatch(Optional::isPresent)) {
+ if (!inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent)) {
return null;
}
TensorFunction result = inputs.get(0).function().get();
@@ -88,7 +87,7 @@ public class ConcatV2 extends TensorFlowOperation {
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
- if (!inputs.stream().map(TensorFlowOperation::type).allMatch(Optional::isPresent)) {
+ if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) {
return;
}
OrderedTensorType a = inputs.get(0).type().get();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java
index 718e2a4b3c2..3c0f8569c47 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java
@@ -1,36 +1,38 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
import com.yahoo.searchlib.rankingexpression.Reference;
-import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
import java.util.Optional;
-public class Const extends TensorFlowOperation {
+public class Const extends IntermediateOperation {
- public Const(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ private final AttributeMap attributeMap;
+
+ public Const(String modelName,
+ String nodeName,
+ List<IntermediateOperation> inputs,
+ AttributeMap attributeMap,
+ OrderedTensorType type) {
+ super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
+ this.type = type.rename(vespaName() + "_");
setConstantValue(value());
}
@Override
protected OrderedTensorType lazyGetType() {
- return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_");
+ return type;
}
@Override
@@ -55,7 +57,7 @@ public class Const extends TensorFlowOperation {
/** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
@Override
public String vespaName() {
- return modelName() + "_" + super.vespaName();
+ return modelName + "_" + super.vespaName();
}
@Override
@@ -77,24 +79,11 @@ public class Const extends TensorFlowOperation {
}
private Value value() {
- if ( ! node.getAttrMap().containsKey("value")) {
- throw new IllegalArgumentException("Node '" + node.getName() + "' of type " +
- "const has missing 'value' attribute");
- }
- AttrValue attrValue = node.getAttrMap().get("value");
- if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) {
- return new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), type().get().type()));
- }
- if (attrValue.getValueCase() == AttrValue.ValueCase.B) {
- return new BooleanValue(attrValue.getB());
- }
- if (attrValue.getValueCase() == AttrValue.ValueCase.I) {
- return new DoubleValue(attrValue.getI());
- }
- if (attrValue.getValueCase() == AttrValue.ValueCase.F) {
- return new DoubleValue(attrValue.getF());
+ Optional<Value> value = attributeMap.get("value", type);
+ if ( ! value.isPresent()) {
+ throw new IllegalArgumentException("Node '" + name + "' of type " +
+ "const has missing or non-recognized 'value' attribute");
}
- throw new IllegalArgumentException("Requesting value of constant in " +
- node.getName() + " but type is not recognized.");
+ return value.get();
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java
index 13043a61a8e..5e4abeaa234 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java
@@ -1,38 +1,34 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import onnx.Onnx;
import java.util.Collections;
import java.util.Optional;
-public class Constant extends OnnxOperation {
+public class Constant extends IntermediateOperation {
- final String modelName;
- final Onnx.TensorProto tensorProto;
+ private final String modelName;
- public Constant(String modelName, Onnx.TensorProto tensorProto) {
- super(null, Collections.emptyList());
+ public Constant(String modelName, String nodeName, OrderedTensorType type) {
+ super(modelName, nodeName, Collections.emptyList());
this.modelName = modelName;
- this.tensorProto = tensorProto;
+ this.type = type.rename(vespaName() + "_");
}
/** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
@Override
public String vespaName() {
- return modelName + "_" + vespaName(tensorProto.getName());
+ return modelName + "_" + vespaName(name);
}
@Override
protected OrderedTensorType lazyGetType() {
- return OrderedTensorType.fromOnnxType(tensorProto.getDimsList(), vespaName() + "_");
+ return type;
}
@Override
@@ -40,9 +36,14 @@ public class Constant extends OnnxOperation {
return null; // will be added by function() since this is constant.
}
+ /**
+ * Constant values are sent in via the constantValueFunction, as the
+ * dimension names and thus the data layout depends on the dimension
+ * renaming which happens after the conversion to intermediate graph.
+ */
@Override
public Optional<Value> getConstantValue() {
- return Optional.of(new TensorValue(TensorConverter.toVespaTensor(tensorProto, type)));
+ return Optional.ofNullable(constantValueFunction).map(func -> func.apply(type));
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java
index 2d0f4c7042b..742ed8b89ab 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java
@@ -1,9 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
@@ -12,18 +12,17 @@ import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
-public class ExpandDims extends TensorFlowOperation {
+public class ExpandDims extends IntermediateOperation {
private List<String> expandDimensions;
- public ExpandDims(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public ExpandDims(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
@@ -32,14 +31,14 @@ public class ExpandDims extends TensorFlowOperation {
return null;
}
- TensorFlowOperation axisOperation = inputs().get(1);
+ IntermediateOperation axisOperation = inputs().get(1);
if (!axisOperation.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("ExpandDims in " + node.getName() + ": " +
+ throw new IllegalArgumentException("ExpandDims in " + name + ": " +
"axis must be a constant.");
}
Tensor axis = axisOperation.getConstantValue().get().asTensor();
if (axis.type().rank() != 0) {
- throw new IllegalArgumentException("ExpandDims in " + node.getName() + ": " +
+ throw new IllegalArgumentException("ExpandDims in " + name + ": " +
"axis argument must be a scalar.");
}
@@ -49,7 +48,7 @@ public class ExpandDims extends TensorFlowOperation {
dimensionToInsert = inputType.dimensions().size() - dimensionToInsert;
}
- OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node);
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
expandDimensions = new ArrayList<>();
int dimensionIndex = 0;
for (TensorType.Dimension dimension : inputType.dimensions()) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java
index 1408e7e04f0..d29bd4b7a9e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java
@@ -1,22 +1,21 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
-public class Identity extends TensorFlowOperation {
+public class Identity extends IntermediateOperation {
- public Identity(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public Identity(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
/** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
@Override
public String vespaName() {
- return modelName() + "_" + super.vespaName();
+ return modelName + "_" + super.vespaName();
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
index 3687bba8b85..43de29cedd5 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
@@ -1,17 +1,16 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
+
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.ArrayList;
import java.util.Collections;
@@ -20,43 +19,40 @@ import java.util.Optional;
import java.util.function.Function;
/**
- * Wraps a TensorFlow node and produces the respective Vespa tensor operation.
- * During import, a graph of these operations are constructed. Then, the
- * types are used to deduce sensible dimension names using the
- * DimensionRenamer. After the types have been renamed, the proper
- * Vespa expressions can be extracted.
+ * Wraps an imported operation node and produces the respective Vespa tensor
+ * operation. During import, a graph of these operations are constructed. Then,
+ * the types are used to deduce sensible dimension names using the
+ * DimensionRenamer. After the types have been renamed, the proper Vespa
+ * expressions can be extracted.
*
* @author lesters
*/
-public abstract class TensorFlowOperation {
-
- protected final static String MACRO_PREFIX = "tf_macro_";
+public abstract class IntermediateOperation {
- private final String modelName;
+ private final static String MACRO_PREFIX = "imported_ml_macro_";
- protected final NodeDef node;
- protected final int port;
- protected final List<TensorFlowOperation> inputs;
- protected final List<TensorFlowOperation> outputs = new ArrayList<>();
- protected final List<String> importWarnings = new ArrayList<>();
+ protected final String name;
+ protected final String modelName;
+ protected final List<IntermediateOperation> inputs;
+ protected final List<IntermediateOperation> outputs = new ArrayList<>();
protected OrderedTensorType type;
protected TensorFunction function;
protected TensorFunction macro = null;
+ private final List<String> importWarnings = new ArrayList<>();
private Value constantValue = null;
- private List<TensorFlowOperation> controlInputs = Collections.emptyList();
+ private List<IntermediateOperation> controlInputs = Collections.emptyList();
- TensorFlowOperation(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ protected Function<OrderedTensorType, Value> constantValueFunction = null;
+
+ IntermediateOperation(String modelName, String name, List<IntermediateOperation> inputs) {
+ this.name = name;
this.modelName = modelName;
- this.node = node;
- this.port = port;
this.inputs = Collections.unmodifiableList(inputs);
this.inputs.forEach(i -> i.outputs.add(this));
}
- protected String modelName() { return modelName; }
-
protected abstract OrderedTensorType lazyGetType();
protected abstract TensorFunction lazyGetFunction();
@@ -65,9 +61,6 @@ public abstract class TensorFlowOperation {
if (type == null) {
type = lazyGetType();
}
- if (type != null) {
- type.verifyType(node);
- }
return Optional.ofNullable(type);
}
@@ -87,14 +80,14 @@ public abstract class TensorFlowOperation {
return Optional.ofNullable(function);
}
- /** Return TensorFlow node */
- public NodeDef node() { return node; }
+ /** Returns original name of this operation node */
+ public String name() { return name; }
/** Return unmodifiable list of inputs */
- public List<TensorFlowOperation> inputs() { return inputs; }
+ public List<IntermediateOperation> inputs() { return inputs; }
/** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a macro. */
- public List<TensorFlowOperation> outputs() { return Collections.unmodifiableList(outputs); }
+ public List<IntermediateOperation> outputs() { return Collections.unmodifiableList(outputs); }
/** Returns a Vespa ranking expression that should be added as a macro */
public Optional<TensorFunction> macro() { return Optional.ofNullable(macro); }
@@ -109,22 +102,34 @@ public abstract class TensorFlowOperation {
public boolean isInput() { return false; }
/** Return true if this node is constant */
- public boolean isConstant() { return inputs.stream().allMatch(TensorFlowOperation::isConstant); }
+ public boolean isConstant() { return inputs.stream().allMatch(IntermediateOperation::isConstant); }
/** Sets the constant value */
public void setConstantValue(Value value) { constantValue = value; }
/** Gets the constant value if it exists */
- public Optional<Value> getConstantValue() { return Optional.ofNullable(constantValue); }
+ public Optional<Value> getConstantValue() {
+ if (constantValue != null) {
+ return Optional.of(constantValue);
+ }
+ if (constantValueFunction != null) {
+ return Optional.of(constantValueFunction.apply(type));
+ }
+ return Optional.empty();
+ }
+
+ /** Set the constant value function */
+ public void setConstantValueFunction(Function<OrderedTensorType, Value> func) { this.constantValueFunction = func; }
/** Sets the external control inputs */
- public void setControlInputs(List<TensorFlowOperation> inputs) { this.controlInputs = inputs; }
+ public void setControlInputs(List<IntermediateOperation> inputs) { this.controlInputs = inputs; }
/** Retrieve the control inputs for this operation */
- public List<TensorFlowOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); }
+ public List<IntermediateOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); }
/** Retrieve the valid Vespa name of this node */
- public String vespaName() { return node.getName() != null ? node.getName().replace('/', '_') : null; }
+ public String vespaName() { return vespaName(name); }
+ public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; }
/** Retrieve the valid Vespa name of this node if it is a macro */
public String macroName() { return vespaName() != null ? MACRO_PREFIX + modelName + "_" + vespaName() : null; }
@@ -135,23 +140,48 @@ public abstract class TensorFlowOperation {
/** Set an input warning */
public void warning(String warning) { importWarnings.add(warning); }
- boolean verifyInputs(int expected, Function<TensorFlowOperation, Optional<?>> func) {
- if (!controlInputs.stream().map(func).allMatch(Optional::isPresent)) {
- return false;
- }
+ boolean verifyInputs(int expected, Function<IntermediateOperation, Optional<?>> func) {
if (inputs.size() != expected) {
throw new IllegalArgumentException("Expected " + expected + " inputs " +
- "for '" + node.getName() + "', got " + inputs.size());
+ "for '" + name + "', got " + inputs.size());
}
return inputs.stream().map(func).allMatch(Optional::isPresent);
}
boolean allInputTypesPresent(int expected) {
- return verifyInputs(expected, TensorFlowOperation::type);
+ return verifyInputs(expected, IntermediateOperation::type);
}
boolean allInputFunctionsPresent(int expected) {
- return verifyInputs(expected, TensorFlowOperation::function);
+ return verifyInputs(expected, IntermediateOperation::function);
+ }
+
+ /**
+ * A method signature input and output has the form name:index.
+ * This returns the name part without the index.
+ */
+ public static String namePartOf(String name) {
+ name = name.startsWith("^") ? name.substring(1) : name;
+ return name.split(":")[0];
+ }
+
+ /**
+ * This return the output index part. Indexes are used for nodes with
+ * multiple outputs.
+ */
+ public static int indexPartOf(String name) {
+ int i = name.indexOf(":");
+ return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1));
+ }
+
+ /**
+ * An interface mapping operation attributes to Vespa Values.
+ * Adapter for differences in ONNX/TensorFlow.
+ */
+ public interface AttributeMap {
+ Optional<Value> get(String key);
+ Optional<Value> get(String key, OrderedTensorType type);
+ Optional<List<Value>> getList(String key);
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java
index fe2004a528d..8413ed74118 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Join.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java
@@ -1,24 +1,22 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.TensorFunction;
-import onnx.Onnx;
import java.util.ArrayList;
-import java.util.Collections;
import java.util.List;
import java.util.function.DoubleBinaryOperator;
-public class Join extends OnnxOperation {
+public class Join extends IntermediateOperation {
private final DoubleBinaryOperator operator;
- public Join(Onnx.NodeProto node, List<OnnxOperation> inputs, DoubleBinaryOperator operator) {
- super(node, inputs);
+ public Join(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleBinaryOperator operator) {
+ super(modelName, nodeName, inputs);
this.operator = operator;
}
@@ -61,8 +59,8 @@ public class Join extends OnnxOperation {
return null;
}
- OnnxOperation a = largestInput();
- OnnxOperation b = smallestInput();
+ IntermediateOperation a = largestInput();
+ IntermediateOperation b = smallestInput();
List<String> aDimensionsToReduce = new ArrayList<>();
List<String> bDimensionsToReduce = new ArrayList<>();
@@ -107,13 +105,13 @@ public class Join extends OnnxOperation {
}
}
- private OnnxOperation largestInput() {
+ private IntermediateOperation largestInput() {
OrderedTensorType a = inputs.get(0).type().get();
OrderedTensorType b = inputs.get(1).type().get();
return a.rank() >= b.rank() ? inputs.get(0) : inputs.get(1);
}
- private OnnxOperation smallestInput() {
+ private IntermediateOperation smallestInput() {
OrderedTensorType a = inputs.get(0).type().get();
OrderedTensorType b = inputs.get(1).type().get();
return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java
index c015f5ecba8..f54ae83052f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java
@@ -1,20 +1,19 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
import java.util.Optional;
import java.util.function.DoubleUnaryOperator;
-public class Map extends TensorFlowOperation {
+public class Map extends IntermediateOperation {
private final DoubleUnaryOperator operator;
- public Map(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port, DoubleUnaryOperator operator) {
- super(modelName, node, inputs, port);
+ public Map(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleUnaryOperator operator) {
+ super(modelName, nodeName, inputs);
this.operator = operator;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/MatMul.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java
index 1b388e2ae89..52e223f9518 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/MatMul.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java
@@ -1,21 +1,18 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import onnx.Onnx;
-import java.util.Collections;
import java.util.List;
import java.util.Optional;
-import java.util.function.DoubleBinaryOperator;
-public class MatMul extends OnnxOperation {
+public class MatMul extends IntermediateOperation {
- public MatMul(Onnx.NodeProto node, List<OnnxOperation> inputs) {
- super(node, inputs);
+ public MatMul(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java
index 3eba872c6a0..95a77c07590 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java
@@ -1,9 +1,10 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
@@ -13,20 +14,20 @@ import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.NodeDef;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
-public class Mean extends TensorFlowOperation {
+public class Mean extends IntermediateOperation {
+ private final AttributeMap attributeMap;
private List<String> reduceDimensions;
- public Mean(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public Mean(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
+ super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
}
@Override
@@ -34,9 +35,9 @@ public class Mean extends TensorFlowOperation {
if (!allInputTypesPresent(2)) {
return null;
}
- TensorFlowOperation reductionIndices = inputs.get(1);
+ IntermediateOperation reductionIndices = inputs.get(1);
if (!reductionIndices.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("Mean in " + node.getName() + ": " +
+ throw new IllegalArgumentException("Mean in " + name + ": " +
"reduction indices must be a constant.");
}
Tensor indices = reductionIndices.getConstantValue().get().asTensor();
@@ -54,7 +55,7 @@ public class Mean extends TensorFlowOperation {
return reducedType(inputType, shouldKeepDimensions());
}
- // todo: optimization: if keepDims and one reduce dimension that has size 1: same as identity.
+ // optimization: if keepDims and one reduce dimension that has size 1: same as identity.
@Override
protected TensorFunction lazyGetFunction() {
@@ -93,12 +94,12 @@ public class Mean extends TensorFlowOperation {
}
private boolean shouldKeepDimensions() {
- AttrValue keepDimsAttr = node.getAttrMap().get("keep_dims");
- return keepDimsAttr != null && keepDimsAttr.getB();
+ Optional<Value> keepDims = attributeMap.get("keep_dims");
+ return keepDims.isPresent() && keepDims.get().asBoolean();
}
private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node);
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
for (TensorType.Dimension dimension: inputType.type().dimensions()) {
if (!reduceDimensions.contains(dimension.name())) {
builder.add(dimension);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java
index 4c95e67e184..9d9eca47b1c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java
@@ -1,21 +1,20 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
-public class Merge extends TensorFlowOperation {
+public class Merge extends IntermediateOperation {
- public Merge(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public Merge(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
protected OrderedTensorType lazyGetType() {
- for (TensorFlowOperation operation : inputs) {
+ for (IntermediateOperation operation : inputs) {
if (operation.type().isPresent()) {
return operation.type().get();
}
@@ -25,7 +24,7 @@ public class Merge extends TensorFlowOperation {
@Override
protected TensorFunction lazyGetFunction() {
- for (TensorFlowOperation operation : inputs) {
+ for (IntermediateOperation operation : inputs) {
if (operation.function().isPresent()) {
return operation.function().get();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java
new file mode 100644
index 00000000000..19ba146492c
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java
@@ -0,0 +1,26 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.Collections;
+import java.util.List;
+
+public class NoOp extends IntermediateOperation {
+
+ public NoOp(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, Collections.emptyList()); // don't propagate inputs
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ return null;
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ return null;
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java
index 65ce7f00e34..9299ae9be12 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java
@@ -1,17 +1,16 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
import java.util.Optional;
-public class PlaceholderWithDefault extends TensorFlowOperation {
+public class PlaceholderWithDefault extends IntermediateOperation {
- public PlaceholderWithDefault(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public PlaceholderWithDefault(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java
index e7d90e5fc1f..e91c2305f7d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java
@@ -1,10 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
@@ -19,19 +18,18 @@ import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
-import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.tensorSize;
+import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.tensorSize;
-public class Reshape extends TensorFlowOperation {
+public class Reshape extends IntermediateOperation {
- public Reshape(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
@@ -39,15 +37,15 @@ public class Reshape extends TensorFlowOperation {
if (!allInputTypesPresent(2)) {
return null;
}
- TensorFlowOperation newShape = inputs.get(1);
+ IntermediateOperation newShape = inputs.get(1);
if (!newShape.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("Reshape in " + node.getName() + ": " +
+ throw new IllegalArgumentException("Reshape in " + name + ": " +
"shape input must be a constant.");
}
Tensor shape = newShape.getConstantValue().get().asTensor();
OrderedTensorType inputType = inputs.get(0).type().get();
- OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(node);
+ OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder();
int dimensionIndex = 0;
for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) {
Tensor.Cell cell = cellIterator.next();
@@ -124,7 +122,7 @@ public class Reshape extends TensorFlowOperation {
operators.add(0, ArithmeticOperator.MULTIPLY);
children.add(0, new ConstantNode(new DoubleValue(size)));
}
- size *= TensorConverter.dimensionSize(dimension);
+ size *= OrderedTensorType.dimensionSize(dimension);
if (i > 0) {
operators.add(0, ArithmeticOperator.PLUS);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java
index 5fdcb5a695f..927a4a368f9 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java
@@ -1,24 +1,23 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
import java.util.function.DoubleBinaryOperator;
-import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.dimensionSize;
-import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.tensorSize;
+import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.dimensionSize;
+import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.tensorSize;
-public class Select extends TensorFlowOperation {
+public class Select extends IntermediateOperation {
- public Select(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public Select(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
@@ -39,7 +38,7 @@ public class Select extends TensorFlowOperation {
if (!allInputFunctionsPresent(3)) {
return null;
}
- TensorFlowOperation conditionOperation = inputs().get(0);
+ IntermediateOperation conditionOperation = inputs().get(0);
TensorFunction a = inputs().get(1).function().get();
TensorFunction b = inputs().get(2).function().get();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java
index af49d2c108b..da566909adc 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java
@@ -1,20 +1,19 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
-public class Shape extends TensorFlowOperation {
+public class Shape extends IntermediateOperation {
- public Shape(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public Shape(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
createConstantValue();
}
@@ -24,7 +23,7 @@ public class Shape extends TensorFlowOperation {
return null;
}
OrderedTensorType inputType = inputs.get(0).type().get();
- return new OrderedTensorType.Builder(node)
+ return new OrderedTensorType.Builder()
.add(TensorType.Dimension.indexed(vespaName(), inputType.dimensions().size()))
.build();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java
index 17ce9e8b7cb..c750c47e27e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java
@@ -1,26 +1,26 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.NodeDef;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
-public class Squeeze extends TensorFlowOperation {
+public class Squeeze extends IntermediateOperation {
+ private final AttributeMap attributeMap;
private List<String> squeezeDimensions;
- public Squeeze(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public Squeeze(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
+ super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
}
@Override
@@ -31,20 +31,21 @@ public class Squeeze extends TensorFlowOperation {
OrderedTensorType inputType = inputs.get(0).type().get();
squeezeDimensions = new ArrayList<>();
- AttrValue squeezeDimsAttr = node.getAttrMap().get("squeeze_dims");
- if (squeezeDimsAttr == null) {
+ Optional<List<Value>> squeezeDimsAttr = attributeMap.getList("squeeze_dims");
+ if ( ! squeezeDimsAttr.isPresent()) {
squeezeDimensions = inputType.type().dimensions().stream().
- filter(dim -> TensorConverter.dimensionSize(dim) == 1).
+ filter(dim -> OrderedTensorType.dimensionSize(dim) == 1).
map(TensorType.Dimension::name).
collect(Collectors.toList());
} else {
- squeezeDimensions = squeezeDimsAttr.getList().getIList().stream().
+ squeezeDimensions = squeezeDimsAttr.get().stream().map(Value::asDouble).map(Double::intValue).
map(i -> i < 0 ? inputType.type().dimensions().size() - i : i).
- map(i -> inputType.type().dimensions().get(i.intValue())).
- filter(dim -> TensorConverter.dimensionSize(dim) == 1).
+ map(i -> inputType.type().dimensions().get(i)).
+ filter(dim -> OrderedTensorType.dimensionSize(dim) == 1).
map(TensorType.Dimension::name).
collect(Collectors.toList());
}
+
return squeezeDimensions.isEmpty() ? inputType : reducedType(inputType);
}
@@ -72,7 +73,7 @@ public class Squeeze extends TensorFlowOperation {
}
private OrderedTensorType reducedType(OrderedTensorType inputType) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node);
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
for (TensorType.Dimension dimension: inputType.type().dimensions()) {
if ( ! squeezeDimensions.contains(dimension.name())) {
builder.add(dimension);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java
index de4d8862fd6..0171d1ea171 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java
@@ -1,17 +1,19 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
import java.util.Optional;
-public class Switch extends TensorFlowOperation {
+public class Switch extends IntermediateOperation {
- public Switch(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ private final int port;
+
+ public Switch(String modelName, String nodeName, List<IntermediateOperation> inputs, int port) {
+ super(modelName, nodeName, inputs);
+ this.port = port;
}
@Override
@@ -21,7 +23,7 @@ public class Switch extends TensorFlowOperation {
}
Optional<OrderedTensorType> predicate = inputs.get(1).type();
if (predicate.get().type().rank() != 0) {
- throw new IllegalArgumentException("Switch in " + node.getName() + ": " +
+ throw new IllegalArgumentException("Switch in " + name + ": " +
"predicate must be a scalar");
}
return inputs.get(0).type().orElse(null);
@@ -29,13 +31,13 @@ public class Switch extends TensorFlowOperation {
@Override
protected TensorFunction lazyGetFunction() {
- TensorFlowOperation predicateOperation = inputs().get(1);
+ IntermediateOperation predicateOperation = inputs().get(1);
if (!predicateOperation.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("Switch in " + node.getName() + ": " +
+ throw new IllegalArgumentException("Switch in " + name + ": " +
"predicate must be a constant");
}
if (port < 0 || port > 1) {
- throw new IllegalArgumentException("Switch in " + node.getName() + ": " +
+ throw new IllegalArgumentException("Switch in " + name + ": " +
"choice should be boolean");
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java
new file mode 100644
index 00000000000..a815cbc3944
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java
@@ -0,0 +1,85 @@
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+/**
+ * Converts TensorFlow node attributes to Vespa attribute values.
+ *
+ * @author lesters
+ */
+public class AttributeConverter implements IntermediateOperation.AttributeMap {
+
+ private final Map<String, AttrValue> attributeMap;
+
+ public AttributeConverter(NodeDef node) {
+ attributeMap = node.getAttrMap();
+ }
+
+ public static AttributeConverter convert(NodeDef node) {
+ return new AttributeConverter(node);
+ }
+
+ @Override
+ public Optional<Value> get(String key) {
+ if (attributeMap.containsKey(key)) {
+ AttrValue attrValue = attributeMap.get(key);
+ if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) {
+ return Optional.empty(); // requires type
+ }
+ if (attrValue.getValueCase() == AttrValue.ValueCase.B) {
+ return Optional.of(new BooleanValue(attrValue.getB()));
+ }
+ if (attrValue.getValueCase() == AttrValue.ValueCase.I) {
+ return Optional.of(new DoubleValue(attrValue.getI()));
+ }
+ if (attrValue.getValueCase() == AttrValue.ValueCase.F) {
+ return Optional.of(new DoubleValue(attrValue.getF()));
+ }
+ }
+ return Optional.empty();
+ }
+
+ @Override
+ public Optional<Value> get(String key, OrderedTensorType type) {
+ if (attributeMap.containsKey(key)) {
+ AttrValue attrValue = attributeMap.get(key);
+ if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) {
+ return Optional.of(new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), type.type())));
+ }
+ }
+ return get(key);
+ }
+
+ @Override
+ public Optional<List<Value>> getList(String key) {
+ if (attributeMap.containsKey(key)) {
+ AttrValue attrValue = attributeMap.get(key);
+ if (attrValue.getValueCase() == AttrValue.ValueCase.LIST) {
+ AttrValue.ListValue listValue = attrValue.getList();
+ if ( ! listValue.getBList().isEmpty()) {
+ return Optional.of(listValue.getBList().stream().map(BooleanValue::new).collect(Collectors.toList()));
+ }
+ if ( ! listValue.getIList().isEmpty()) {
+ return Optional.of(listValue.getIList().stream().map(DoubleValue::new).collect(Collectors.toList()));
+ }
+ if ( ! listValue.getFList().isEmpty()) {
+ return Optional.of(listValue.getFList().stream().map(DoubleValue::new).collect(Collectors.toList()));
+ }
+ // add the rest
+ }
+ }
+ return Optional.empty();
+ }
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java
new file mode 100644
index 00000000000..e1b292f9e61
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java
@@ -0,0 +1,234 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Argument;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ConcatV2;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Const;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ExpandDims;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Identity;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Join;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Map;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.MatMul;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Mean;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Merge;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.NoOp;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.PlaceholderWithDefault;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Reshape;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Select;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Shape;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Squeeze;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Switch;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+import org.tensorflow.framework.GraphDef;
+import org.tensorflow.framework.MetaGraphDef;
+import org.tensorflow.framework.NodeDef;
+import org.tensorflow.framework.SignatureDef;
+import org.tensorflow.framework.TensorInfo;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Converts a TensorFlow graph to a Vespa IntermediateGraph which is the basis
+ * for generating Vespa ranking expressions.
+ *
+ * @author lesters
+ */
+public class GraphImporter {
+
+ public static IntermediateOperation mapOperation(NodeDef node,
+ List<IntermediateOperation> inputs,
+ IntermediateGraph graph) {
+ String nodeName = node.getName();
+ String modelName = graph.name();
+ int nodePort = IntermediateOperation.indexPartOf(nodeName);
+ OrderedTensorType nodeType = TypeConverter.fromTensorFlowType(node);
+ AttributeConverter attributes = AttributeConverter.convert(node);
+
+ switch (node.getOp().toLowerCase()) {
+ // array ops
+ case "concatv2": return new ConcatV2(modelName, nodeName, inputs);
+ case "const": return new Const(modelName, nodeName, inputs, attributes, nodeType);
+ case "expanddims": return new ExpandDims(modelName, nodeName, inputs);
+ case "identity": return new Identity(modelName, nodeName, inputs);
+ case "placeholder": return new Argument(modelName, nodeName, nodeType);
+ case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, nodeName, inputs);
+ case "reshape": return new Reshape(modelName, nodeName, inputs);
+ case "shape": return new Shape(modelName, nodeName, inputs);
+ case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes);
+
+ // control flow
+ case "merge": return new Merge(modelName, nodeName, inputs);
+ case "switch": return new Switch(modelName, nodeName, inputs, nodePort);
+
+ // math ops
+ case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
+ case "add_n": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
+ case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
+ case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
+ case "realdiv": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
+ case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor());
+ case "matmul": return new MatMul(modelName, nodeName, inputs);
+ case "maximum": return new Join(modelName, nodeName, inputs, ScalarFunctions.max());
+ case "mean": return new Mean(modelName, nodeName, inputs, attributes);
+ case "reducemean": return new Mean(modelName, nodeName, inputs, attributes);
+ case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
+ case "multiply": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
+ case "rsqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.rsqrt());
+ case "select": return new Select(modelName, nodeName, inputs);
+ case "where3": return new Select(modelName, nodeName, inputs);
+ case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid());
+ case "squareddifference": return new Join(modelName, nodeName, inputs, ScalarFunctions.squareddifference());
+ case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
+ case "subtract": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
+
+ // nn ops
+ case "biasadd": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
+ case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
+ case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
+ case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
+
+ // state ops
+ case "variable": return new Constant(modelName, nodeName, nodeType);
+ case "variablev2": return new Constant(modelName, nodeName, nodeType);
+
+ // evaluation no-ops
+ case "stopgradient":return new Identity(modelName, nodeName, inputs);
+ case "noop": return new NoOp(modelName, nodeName, inputs);
+
+ }
+
+ IntermediateOperation op = new NoOp(modelName, node.getName(), inputs);
+ op.warning("Operation '" + node.getOp() + "' is currently not implemented");
+ return op;
+ }
+
+ public static IntermediateGraph importGraph(String modelName, SavedModelBundle bundle) throws IOException {
+ MetaGraphDef tfGraph = MetaGraphDef.parseFrom(bundle.metaGraphDef());
+
+ IntermediateGraph intermediateGraph = new IntermediateGraph(modelName);
+ importSignatures(tfGraph, intermediateGraph);
+ importOperations(tfGraph, intermediateGraph, bundle);
+ verifyOutputTypes(tfGraph, intermediateGraph);
+
+ return intermediateGraph;
+ }
+
+ private static void importSignatures(MetaGraphDef tfGraph, IntermediateGraph intermediateGraph) {
+ for (java.util.Map.Entry<String, SignatureDef> signatureEntry : tfGraph.getSignatureDefMap().entrySet()) {
+ String signatureName = signatureEntry.getKey();
+ java.util.Map<String, TensorInfo> inputInfoMap = signatureEntry.getValue().getInputsMap();
+ for (java.util.Map.Entry<String, TensorInfo> input : inputInfoMap.entrySet()) {
+ String inputName = input.getKey();
+ String nodeName = input.getValue().getName();
+ intermediateGraph.inputs(signatureName).put(inputName, IntermediateOperation.namePartOf(nodeName));
+ }
+ java.util.Map<String, TensorInfo> outputInfoMap = signatureEntry.getValue().getOutputsMap();
+ for (java.util.Map.Entry<String, TensorInfo> output : outputInfoMap.entrySet()) {
+ String outputName = output.getKey();
+ String nodeName = output.getValue().getName();
+ intermediateGraph.outputs(signatureName).put(outputName, IntermediateOperation.namePartOf(nodeName));
+ }
+ }
+ }
+
+ private static void importOperations(MetaGraphDef tfGraph,
+ IntermediateGraph intermediateGraph,
+ SavedModelBundle bundle) {
+ for (String signatureName : intermediateGraph.signatures()) {
+ for (String outputName : intermediateGraph.outputs(signatureName).values()) {
+ importOperation(outputName, tfGraph.getGraphDef(), intermediateGraph, bundle);
+ }
+ }
+ }
+
+ private static IntermediateOperation importOperation(String nodeName,
+ GraphDef tfGraph,
+ IntermediateGraph intermediateGraph,
+ SavedModelBundle bundle) {
+ if (intermediateGraph.alreadyImported(nodeName)) {
+ return intermediateGraph.get(nodeName);
+ }
+ NodeDef node = getTensorFlowNodeFromGraph(IntermediateOperation.namePartOf(nodeName), tfGraph);
+ List<IntermediateOperation> inputs = importOperationInputs(node, tfGraph, intermediateGraph, bundle);
+ IntermediateOperation operation = mapOperation(node, inputs, intermediateGraph);
+ intermediateGraph.put(nodeName, operation);
+
+ List<IntermediateOperation> controlInputs = importControlInputs(node, tfGraph, intermediateGraph, bundle);
+ if (controlInputs.size() > 0) {
+ operation.setControlInputs(controlInputs);
+ }
+
+ if (operation.isConstant()) {
+ operation.setConstantValueFunction(
+ type -> new TensorValue(TensorConverter.toVespaTensor(readVariable(nodeName, bundle), type)));
+ }
+
+ return operation;
+ }
+
+ private static List<IntermediateOperation> importOperationInputs(NodeDef node,
+ GraphDef tfGraph,
+ IntermediateGraph intermediateGraph,
+ SavedModelBundle bundle) {
+ return node.getInputList().stream()
+ .filter(name -> ! isControlDependency(name))
+ .map(nodeName -> importOperation(nodeName, tfGraph, intermediateGraph, bundle))
+ .collect(Collectors.toList());
+ }
+
+ private static List<IntermediateOperation> importControlInputs(NodeDef node,
+ GraphDef tfGraph,
+ IntermediateGraph intermediateGraph,
+ SavedModelBundle bundle) {
+ return node.getInputList().stream()
+ .filter(nodeName -> isControlDependency(nodeName))
+ .map(nodeName -> importOperation(nodeName, tfGraph, intermediateGraph, bundle))
+ .collect(Collectors.toList());
+ }
+
+ private static boolean isControlDependency(String name) {
+ return name.startsWith("^");
+ }
+
+ private static NodeDef getTensorFlowNodeFromGraph(String name, GraphDef tfGraph) {
+ for (NodeDef node : tfGraph.getNodeList()) {
+ if (node.getName().equals(name)) {
+ return node;
+ }
+ }
+ throw new IllegalArgumentException("Could not find node '" + name + "'");
+ }
+
+ public static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) {
+ Session.Runner fetched = bundle.session().runner().fetch(name);
+ List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
+ if (importedTensors.size() != 1)
+ throw new IllegalStateException("Expected 1 tensor from fetching " + name +
+ ", but got " + importedTensors.size());
+ return importedTensors.get(0);
+ }
+
+ private static void verifyOutputTypes(MetaGraphDef tfGraph, IntermediateGraph intermediateGraph) {
+ for (String signatureName : intermediateGraph.signatures()) {
+ for (String outputName : intermediateGraph.outputs(signatureName).values()) {
+ IntermediateOperation operation = intermediateGraph.get(outputName);
+ NodeDef node = getTensorFlowNodeFromGraph(IntermediateOperation.namePartOf(operation.name()), tfGraph.getGraphDef());
+ OrderedTensorType type = operation.type().orElseThrow(
+ () -> new IllegalArgumentException("Output of '" + outputName + "' has no type."));
+ TypeConverter.verifyType(node, type);
+ }
+ }
+
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java
index 3f55e622fdf..d2d0acfc964 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java
new file mode 100644
index 00000000000..67ad1edc312
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java
@@ -0,0 +1,72 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.NodeDef;
+import org.tensorflow.framework.TensorShapeProto;
+
+import java.util.List;
+
+/**
+ * Converts and verifies TensorFlow tensor types into Vespa tensor types.
+ *
+ * @author lesters
+ */
+public class TypeConverter {
+
+ public static void verifyType(NodeDef node, OrderedTensorType type) {
+ TensorShapeProto shape = tensorFlowShape(node);
+ if (shape != null) {
+ if (shape.getDimCount() != type.rank()) {
+ throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " +
+ "does not match Vespa shape");
+ }
+ for (int tensorFlowIndex = 0; tensorFlowIndex < type.dimensions().size(); ++tensorFlowIndex) {
+ int vespaIndex = type.dimensionMap(tensorFlowIndex);
+ TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex);
+ TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex);
+ if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) {
+ throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " +
+ "does not match Vespa dimensions");
+ }
+ }
+ }
+ }
+
+ private static TensorShapeProto tensorFlowShape(NodeDef node) {
+ AttrValue attrValueList = node.getAttrMap().get("_output_shapes");
+ if (attrValueList == null) {
+ throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
+ "does not exist");
+ }
+ if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) {
+ throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
+ "is not of expected type");
+ }
+ List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList();
+ return shapeList.get(0); // support multiple outputs?
+ }
+
+ public static OrderedTensorType fromTensorFlowType(NodeDef node) {
+ return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ...
+ }
+
+ public static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) {
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ TensorShapeProto shape = tensorFlowShape(node);
+ for (int i = 0; i < shape.getDimCount(); ++ i) {
+ String dimensionName = dimensionPrefix + i;
+ TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i);
+ if (tensorFlowDimension.getSize() >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, tensorFlowDimension.getSize()));
+ } else {
+ builder.add(TensorType.Dimension.indexed(dimensionName));
+ }
+ }
+ return builder.build();
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/package-info.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java
index 5cff8b03d40..1530754cc43 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/package-info.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java
@@ -3,6 +3,6 @@
* ONNX integration
*/
@ExportPackage
-package com.yahoo.searchlib.rankingexpression.integration.onnx;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java
deleted file mode 100644
index fa1f929cc80..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java
+++ /dev/null
@@ -1,326 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.onnx;
-
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Constant;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Argument;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OperationMapper;
-import com.yahoo.searchlib.rankingexpression.parser.ParseException;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.functions.Rename;
-import com.yahoo.tensor.functions.TensorFunction;
-import com.yahoo.yolean.Exceptions;
-import onnx.Onnx;
-
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.IOException;
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-import java.util.logging.Logger;
-import java.util.stream.Collectors;
-
-/**
- * Converts a ONNX model into a ranking expression and set of constants.
- *
- * @author lesters
- */
-public class OnnxImporter {
-
- private static final Logger log = Logger.getLogger(OnnxImporter.class.getName());
-
- public OnnxModel importModel(String modelName, File modelDir) {
- return importModel(modelName, modelDir.toString());
- }
-
- public OnnxModel importModel(String modelName, String modelPath) {
- try (FileInputStream inputStream = new FileInputStream(modelPath)) {
- Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
- return importModel(modelName, model);
- } catch (IOException e) {
- throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e);
- }
- }
-
- public OnnxModel importModel(String modelName, Onnx.ModelProto model) {
- return importGraph(modelName, model.getGraph());
- }
-
- private static OnnxModel importGraph(String modelName, Onnx.GraphProto graph) {
- OnnxModel model = new OnnxModel(modelName);
- OperationIndex index = new OperationIndex();
-
- importNodes(graph, model, index);
- verifyOutputTypes(graph, model, index);
- findDimensionNames(model, index);
- importExpressions(model, index);
-
- reportWarnings(model, index);
-
- return model;
- }
-
- private static void importNodes(Onnx.GraphProto graph, OnnxModel model, OperationIndex index) {
- for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) {
- importNode(valueInfo.getName(), graph, model, index);
- }
- }
-
- private static OnnxOperation importNode(String name, Onnx.GraphProto graph, OnnxModel model, OperationIndex index) {
- if (index.alreadyImported(name)) {
- return index.get(name);
- }
- OnnxOperation operation;
- if (isArgumentTensor(name, graph)) {
- operation = new Argument(getArgumentTensor(name, graph));
- model.input(OnnxOperation.namePartOf(name), operation.vespaName());
- } else if (isConstantTensor(name, graph)) {
- operation = new Constant(model.name(), getConstantTensor(name, graph));
- } else {
- Onnx.NodeProto node = getNodeFromGraph(name, graph);
- List<OnnxOperation> inputs = importNodeInputs(node, graph, model, index);
- operation = OperationMapper.get(node, inputs);
- if (isOutputNode(name, graph)) {
- model.output(OnnxOperation.namePartOf(name), operation.vespaName());
- }
- }
- index.put(operation.vespaName(), operation);
-
- return operation;
- }
-
- private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) {
- Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
- Onnx.TensorProto tensor = getConstantTensor(name, graph);
- return value != null && tensor == null;
- }
-
- private static boolean isConstantTensor(String name, Onnx.GraphProto graph) {
- Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
- Onnx.TensorProto tensor = getConstantTensor(name, graph);
- return value != null && tensor != null;
- }
-
- private static Onnx.ValueInfoProto getArgumentTensor(String name, Onnx.GraphProto graph) {
- for (Onnx.ValueInfoProto valueInfo : graph.getInputList()) {
- if (valueInfo.getName().equals(name)) {
- return valueInfo;
- }
- }
- return null;
- }
-
- private static Onnx.TensorProto getConstantTensor(String name, Onnx.GraphProto graph) {
- for (Onnx.TensorProto tensorProto : graph.getInitializerList()) {
- if (tensorProto.getName().equals(name)) {
- return tensorProto;
- }
- }
- return null;
- }
-
- private static boolean isOutputNode(String name, Onnx.GraphProto graph) {
- return getOutputNode(name, graph) != null;
- }
-
- private static Onnx.ValueInfoProto getOutputNode(String name, Onnx.GraphProto graph) {
- for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) {
- if (valueInfo.getName().equals(name)) {
- return valueInfo;
- }
- String nodeName = OnnxOperation.namePartOf(valueInfo.getName());
- if (nodeName.equals(name)) {
- return valueInfo;
- }
- }
- return null;
- }
-
- private static List<OnnxOperation> importNodeInputs(Onnx.NodeProto node,
- Onnx.GraphProto graph,
- OnnxModel model,
- OperationIndex index) {
- return node.getInputList().stream()
- .map(nodeName -> importNode(nodeName, graph, model, index))
- .collect(Collectors.toList());
- }
-
- private static void verifyOutputTypes(Onnx.GraphProto graph, OnnxModel model, OperationIndex index) {
- for (String outputName : model.outputs().values()) {
- OnnxOperation operation = index.get(outputName);
- Onnx.ValueInfoProto onnxNode = getOutputNode(outputName, graph);
- operation.type().orElseThrow(
- () -> new IllegalArgumentException("Output of '" + outputName + "' has no type."))
- .verifyType(onnxNode.getType());
- }
- }
-
-
- /** Find dimension names to avoid excessive renaming while evaluating the model. */
- private static void findDimensionNames(OnnxModel model, OperationIndex index) {
- DimensionRenamer renamer = new DimensionRenamer();
- for (String output : model.outputs().values()) {
- addDimensionNameConstraints(index.get(output), renamer);
- }
- renamer.solve();
- for (String output : model.outputs().values()) {
- renameDimensions(index.get(output), renamer);
- }
- }
-
- private static void addDimensionNameConstraints(OnnxOperation operation, DimensionRenamer renamer) {
- if (operation.type().isPresent()) {
- operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer));
- operation.addDimensionNameConstraints(renamer);
- }
- }
-
- private static void renameDimensions(OnnxOperation operation, DimensionRenamer renamer) {
- if (operation.type().isPresent()) {
- operation.inputs().forEach(input -> renameDimensions(input, renamer));
- operation.renameDimensions(renamer);
- }
- }
-
- private static void importExpressions(OnnxModel model, OperationIndex index) {
- for (String outputName : model.outputs().values()) {
- try {
- Optional<TensorFunction> function = importExpression(index.get(outputName), model);
- if (!function.isPresent()) {
- model.skippedOutput(outputName, "No valid output function could be found.");
- }
- }
- catch (IllegalArgumentException e) {
- model.skippedOutput(outputName, Exceptions.toMessageString(e));
- }
- }
- }
-
- private static Optional<TensorFunction> importExpression(OnnxOperation operation, OnnxModel model) {
- if (!operation.type().isPresent()) {
- return Optional.empty();
- }
- if (operation.isConstant()) {
- return importConstant(operation, model);
- }
- importInputExpressions(operation, model);
- importRankingExpression(operation, model);
- importArgumentExpression(operation, model);
-
- return operation.function();
- }
-
- private static void importInputExpressions(OnnxOperation operation, OnnxModel model) {
- operation.inputs().forEach(input -> importExpression(input, model));
- }
-
- private static Optional<TensorFunction> importConstant(OnnxOperation operation, OnnxModel model) {
- String name = operation.vespaName();
- if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) {
- return operation.function();
- }
-
- Value value = operation.getConstantValue().orElseThrow(() ->
- new IllegalArgumentException("Operation '" + operation.vespaName() + "' " +
- "is constant but does not have a value."));
- if ( ! (value instanceof TensorValue)) {
- return operation.function(); // scalar values are inserted directly into the expression
- }
-
- Tensor tensor = value.asTensor();
- if (tensor.type().rank() == 0) {
- model.smallConstant(name, tensor);
- } else {
- model.largeConstant(name, tensor);
- }
- return operation.function();
- }
-
- private static void importRankingExpression(OnnxOperation operation, OnnxModel model) {
- if (operation.function().isPresent()) {
- String name = operation.vespaName();
- if (!model.expressions().containsKey(name)) {
- TensorFunction function = operation.function().get();
-
- if (model.outputs().containsKey(name)) {
- OrderedTensorType operationType = operation.type().get();
- OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType);
- if ( ! operationType.equals(standardNamingType)) {
- List<String> renameFrom = operationType.dimensionNames();
- List<String> renameTo = standardNamingType.dimensionNames();
- function = new Rename(function, renameFrom, renameTo);
- }
- }
-
- try {
- // We add all intermediate nodes imported as separate expressions. Only
- // those referenced from the output will be used. We parse the
- // TensorFunction here to convert it to a RankingExpression tree.
- model.expression(name, new RankingExpression(name, function.toString()));
- }
- catch (ParseException e) {
- throw new RuntimeException("Tensorflow function " + function +
- " cannot be parsed as a ranking expression", e);
- }
- }
- }
- }
-
- private static void importArgumentExpression(OnnxOperation operation, OnnxModel model) {
- if (operation.isInput()) {
- // All inputs must have dimensions with standard naming convention: d0, d1, ...
- OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get());
- model.argument(operation.vespaName(), standardNamingConvention.type());
- model.requiredMacro(operation.vespaName(), standardNamingConvention.type());
- }
- }
-
- private static void reportWarnings(OnnxModel model, OperationIndex index) {
- for (String output : model.outputs().values()) {
- reportWarnings(model, index.get(output));
- }
- }
-
- private static void reportWarnings(OnnxModel model, OnnxOperation operation) {
- for (String warning : operation.warnings()) {
- model.importWarning(warning);
- }
- for (OnnxOperation input : operation.inputs()) {
- reportWarnings(model, input);
- }
- }
-
- private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) {
- boolean hasPortNumber = nodeName.contains(":");
- for (Onnx.NodeProto node : graph.getNodeList()) {
- if (hasPortNumber) {
- for (String outputName : node.getOutputList()) {
- if (outputName.equals(nodeName)) {
- return node;
- }
- }
- } else if (node.getName().equals(nodeName)) {
- return node;
- }
- }
- throw new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph");
- }
-
- private static class OperationIndex {
- private final Map<String, OnnxOperation> index = new HashMap<>();
- public OnnxOperation put(String key, OnnxOperation operation) { return index.put(key, operation); }
- public OnnxOperation get(String key) { return index.get(key); }
- public boolean alreadyImported(String key) { return index.containsKey(key); }
- public Collection<OnnxOperation> operations() { return index.values(); }
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java
deleted file mode 100644
index bd53afefc3f..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java
+++ /dev/null
@@ -1,112 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.onnx;
-
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.regex.Pattern;
-
-/**
- * The result of importing an ONNX model into Vespa.
- *
- * @author bratseth
- * @author lesters
- */
-public class OnnxModel {
-
- private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*");
-
- private final String name;
-
- public OnnxModel(String name) {
- if ( ! nameRegexp.matcher(name).matches())
- throw new IllegalArgumentException("A TensorFlow model name can only contain [A-Za-z0-9_], but is '" +
- name + "'");
- this.name = name;
- }
-
- /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */
- public String name() { return name; }
-
- private final Map<String, String> inputs = new HashMap<>();
- private final Map<String, String> outputs = new HashMap<>();
- private final Map<String, String> skippedOutputs = new HashMap<>();
- private final List<String> importWarnings = new ArrayList<>();
-
- private final Map<String, TensorType> arguments = new HashMap<>();
- private final Map<String, Tensor> smallConstants = new HashMap<>();
- private final Map<String, Tensor> largeConstants = new HashMap<>();
- private final Map<String, RankingExpression> expressions = new HashMap<>();
- private final Map<String, RankingExpression> macros = new HashMap<>();
- private final Map<String, TensorType> requiredMacros = new HashMap<>();
-
- void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); }
- void output(String name, String expressionName) { outputs.put(name, expressionName); }
- void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); }
- void importWarning(String warning) { importWarnings.add(warning); }
- void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
- void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
- void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
- void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
- void macro(String name, RankingExpression expression) { macros.put(name, expression); }
- void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); }
-
- /**
- * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name
- * to argument (Placeholder) name in the owner of this
- */
- public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); }
-
- /** Returns arguments().get(inputs.get(name)), e.g the type of the argument this input references */
- public TensorType inputArgument(String inputName) { return arguments().get(inputs.get(inputName)); }
-
- /** Returns an immutable list of the expression names of this */
- public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); }
-
- /**
- * Returns an immutable list of the outputs of this which could not be imported,
- * with a string detailing the reason for each
- */
- public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); }
-
- /**
- * Returns an immutable list of possibly non-fatal warnings encountered during import.
- */
- public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); }
-
- /** Returns expressions().get(outputs.get(outputName)), e.g the expression this output references */
- public RankingExpression outputExpression(String outputName) { return expressions().get(outputs.get(outputName)); }
-
- /** Returns an immutable map of the arguments (inputs) of this */
- public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
-
- /**
- * Returns an immutable map of the small constants of this.
- */
- public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); }
-
- /**
- * Returns an immutable map of the large constants of this.
- */
- public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); }
-
- /**
- * Returns an immutable map of the expressions of this - corresponding to ONNX nodes
- * which are not inputs or constants.
- */
- public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
-
- /** Returns an immutable map of macros that are part of this model */
- public Map<String, RankingExpression> macros() { return Collections.unmodifiableMap(macros); }
-
- /** Returns an immutable map of the macros that must be provided by the environment running this model */
- public Map<String, TensorType> requiredMacros() { return Collections.unmodifiableMap(requiredMacros); }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java
deleted file mode 100644
index 12090145d3a..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer;
-
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Join;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.MatMul;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.NoOp;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation;
-import com.yahoo.tensor.functions.ScalarFunctions;
-import onnx.Onnx;
-
-import java.util.List;
-
-public class OperationMapper {
-
- public static OnnxOperation get(Onnx.NodeProto node, List<OnnxOperation> inputs) {
- switch (node.getOpType().toLowerCase()) {
- case "add": return new Join(node, inputs, ScalarFunctions.add());
- case "matmul": return new MatMul(node, inputs);
- }
-
- OnnxOperation op = new NoOp(node, inputs);
- op.warning("Operation '" + node.getOpType() + "' is currently not implemented");
- return op;
- }
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java
deleted file mode 100644
index a8d8d63daf4..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java
+++ /dev/null
@@ -1,64 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.VariableTensor;
-import com.yahoo.tensor.functions.Rename;
-import com.yahoo.tensor.functions.TensorFunction;
-import onnx.Onnx;
-
-import java.util.Collections;
-import java.util.List;
-
-public class Argument extends OnnxOperation {
-
- private Onnx.ValueInfoProto valueInfo;
- private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ...
-
- public Argument(Onnx.ValueInfoProto valueInfoProto) {
- super(null, Collections.emptyList());
- valueInfo = valueInfoProto;
- standardNamingType = OrderedTensorType.fromOnnxType(valueInfo.getType());
- }
-
- @Override
- public String vespaName() {
- return vespaName(valueInfo.getName());
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- return OrderedTensorType.fromOnnxType(valueInfo.getType(), vespaName() + "_");
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- TensorFunction output = new VariableTensor(vespaName(), standardNamingType.type());
- if (!standardNamingType.equals(type)) {
- List<String> renameFrom = standardNamingType.dimensionNames();
- List<String> renameTo = type.dimensionNames();
- output = new Rename(output, renameFrom, renameTo);
- }
- return output;
- }
-
- @Override
- public void addDimensionNameConstraints(DimensionRenamer renamer) {
- for (TensorType.Dimension dimension : type.type().dimensions()) {
- renamer.addDimension(dimension.name());
- }
- }
-
- @Override
- public boolean isInput() {
- return true;
- }
-
- @Override
- public boolean isConstant() {
- return false;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java
deleted file mode 100644
index b1136a0ce0a..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
-import com.yahoo.tensor.functions.TensorFunction;
-import onnx.Onnx;
-
-import java.util.Collections;
-import java.util.List;
-
-public class NoOp extends OnnxOperation {
-
- public NoOp(Onnx.NodeProto node, List<OnnxOperation> inputs) {
- super(node, Collections.emptyList()); // don't propagate inputs
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- return null;
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- return null;
- }
-
- @Override
- public boolean isConstant() {
- return true;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java
deleted file mode 100644
index 30f7b4f4711..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java
+++ /dev/null
@@ -1,139 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
-
-import com.yahoo.searchlib.rankingexpression.Reference;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
-import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
-import com.yahoo.tensor.functions.TensorFunction;
-import onnx.Onnx;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.Optional;
-import java.util.function.Function;
-
-/**
- * Wraps an ONNX node and produces the respective Vespa tensor operation.
- * During import, a graph of these operations are constructed. Then, the
- * types are used to deduce sensible dimension names using the
- * DimensionRenamer. After the types have been renamed, the proper
- * Vespa expressions can be extracted.
- *
- * @author lesters
- */
-public abstract class OnnxOperation {
-
- protected final Onnx.NodeProto node; // can be null for onnx inputs and constants
- protected final List<OnnxOperation> inputs;
- protected final List<OnnxOperation> outputs = new ArrayList<>();
- protected final List<String> importWarnings = new ArrayList<>();
-
- protected OrderedTensorType type;
- protected TensorFunction function;
- protected Value constantValue = null;
-
- OnnxOperation(Onnx.NodeProto node, List<OnnxOperation> inputs) {
- this.node = node;
- this.inputs = Collections.unmodifiableList(inputs);
- this.inputs.forEach(i -> i.outputs.add(this));
- }
-
- protected abstract OrderedTensorType lazyGetType();
- protected abstract TensorFunction lazyGetFunction();
-
- /** Returns the Vespa tensor type of this operation if it exists */
- public Optional<OrderedTensorType> type() {
- if (type == null) {
- type = lazyGetType();
- }
- return Optional.ofNullable(type);
- }
-
- /** Returns the Vespa tensor function implementing all operations from this node with inputs */
- public Optional<TensorFunction> function() {
- if (function == null) {
- if (isConstant()) {
- ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName()));
- function = new TensorFunctionNode.TensorFunctionExpressionNode(constant);
- } else {
- function = lazyGetFunction();
- }
- }
- return Optional.ofNullable(function);
- }
-
- /** Return Onnx node */
- public Onnx.NodeProto node() { return node; }
-
- /** Return unmodifiable list of inputs */
- public List<OnnxOperation> inputs() { return inputs; }
-
- /** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a macro. */
- public List<OnnxOperation> outputs() { return Collections.unmodifiableList(outputs); }
-
- /** Add dimension name constraints for this operation */
- public void addDimensionNameConstraints(DimensionRenamer renamer) { }
-
- /** Performs dimension rename for this operation */
- public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); }
-
- /** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */
- public boolean isInput() { return false; }
-
- /** Return true if this node is constant */
- public boolean isConstant() { return inputs.stream().allMatch(OnnxOperation::isConstant); }
-
- /** Gets the constant value if it exists */
- public Optional<Value> getConstantValue() { return Optional.ofNullable(constantValue); }
-
- /** Retrieve the valid Vespa name of this node */
- public String vespaName() { return vespaName(node.getName()); }
- public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; }
-
- /** Retrieve the list of warnings produced during its lifetime */
- public List<String> warnings() { return Collections.unmodifiableList(importWarnings); }
-
- /** Set an input warning */
- public void warning(String warning) { importWarnings.add(warning); }
-
- boolean verifyInputs(int expected, Function<OnnxOperation, Optional<?>> func) {
- if (inputs.size() != expected) {
- throw new IllegalArgumentException("Expected " + expected + " inputs " +
- "for '" + node.getName() + "', got " + inputs.size());
- }
- return inputs.stream().map(func).allMatch(Optional::isPresent);
- }
-
- boolean allInputTypesPresent(int expected) {
- return verifyInputs(expected, OnnxOperation::type);
- }
-
- boolean allInputFunctionsPresent(int expected) {
- return verifyInputs(expected, OnnxOperation::function);
- }
-
- /**
- * A method signature input and output has the form name:index.
- * This returns the name part without the index.
- */
- public static String namePartOf(String name) {
- name = name.startsWith("^") ? name.substring(1) : name;
- return name.split(":")[0];
- }
-
- /**
- * This return the output index part. Indexes are used for nodes with
- * multiple outputs.
- */
- public static int indexPartOf(String name) {
- int i = name.indexOf(":");
- return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1));
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
deleted file mode 100644
index e3c72830095..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ /dev/null
@@ -1,411 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
-
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OperationMapper;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Const;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Variable;
-import com.yahoo.searchlib.rankingexpression.parser.ParseException;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.functions.Rename;
-import com.yahoo.tensor.functions.TensorFunction;
-import com.yahoo.yolean.Exceptions;
-import org.tensorflow.SavedModelBundle;
-import org.tensorflow.Session;
-import org.tensorflow.framework.GraphDef;
-import org.tensorflow.framework.MetaGraphDef;
-import org.tensorflow.framework.NodeDef;
-import org.tensorflow.framework.SignatureDef;
-import org.tensorflow.framework.TensorInfo;
-
-import java.io.File;
-import java.io.IOException;
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-import java.util.logging.Logger;
-import java.util.stream.Collectors;
-
-/**
- * Converts a saved TensorFlow model into a ranking expression and set of constants.
- *
- * @author bratseth
- * @author lesters
- */
-public class TensorFlowImporter {
-
- private static final Logger log = Logger.getLogger(TensorFlowImporter.class.getName());
-
- /**
- * Imports a saved TensorFlow model from a directory.
- * The model should be saved as a .pbtxt or .pb file.
- * The name of the model is taken as the db/pbtxt file name (not including the file ending).
- *
- * @param modelName the name of the model to import, consisting of characters in [A-Za-z0-9_]
- * @param modelDir the directory containing the TensorFlow model files to import
- */
- public TensorFlowModel importModel(String modelName, String modelDir) {
- try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
-
- return importModel(modelName, model);
- }
- catch (IllegalArgumentException e) {
- throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
- }
- }
-
- public TensorFlowModel importModel(String modelName, File modelDir) {
- return importModel(modelName, modelDir.toString());
- }
-
- /** Imports a TensorFlow model */
- public TensorFlowModel importModel(String modelName, SavedModelBundle model) {
- try {
- return importGraph(modelName, MetaGraphDef.parseFrom(model.metaGraphDef()), model);
- }
- catch (IOException e) {
- throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e);
- }
- }
-
- /**
- * Imports the TensorFlow graph by first importing the tensor types, then
- * finding a suitable set of dimensions names for each
- * placeholder/constant/variable, then importing the expressions.
- */
- private static TensorFlowModel importGraph(String modelName, MetaGraphDef graph, SavedModelBundle bundle) {
- TensorFlowModel model = new TensorFlowModel(modelName);
- OperationIndex index = new OperationIndex();
-
- importSignatures(graph, model);
- importNodes(graph, model, index);
- findDimensionNames(model, index);
- importExpressions(model, index, bundle);
-
- reportWarnings(model, index);
- logVariableTypes(index);
-
- return model;
- }
-
- private static void importSignatures(MetaGraphDef graph, TensorFlowModel model) {
- for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) {
- String signatureName = signatureEntry.getKey();
- TensorFlowModel.Signature signature = model.signature(signatureName);
-
- Map<String, TensorInfo> inputInfoMap = signatureEntry.getValue().getInputsMap();
- for (Map.Entry<String, TensorInfo> input : inputInfoMap.entrySet()) {
- String inputName = input.getKey();
- signature.input(inputName, namePartOf(input.getValue().getName()));
- }
-
- Map<String, TensorInfo> outputInfoMap = signatureEntry.getValue().getOutputsMap();
- for (Map.Entry<String, TensorInfo> output : outputInfoMap.entrySet()) {
- String outputName = output.getKey();
- signature.output(outputName, namePartOf(output.getValue().getName()));
- }
- }
- }
-
- private static boolean isSignatureInput(TensorFlowModel model, TensorFlowOperation operation) {
- for (TensorFlowModel.Signature signature : model.signatures().values()) {
- for (String inputName : signature.inputs().values()) {
- if (inputName.equals(operation.node().getName())) {
- return true;
- }
- }
- }
- return false;
- }
-
- private static boolean isSignatureOutput(TensorFlowModel model, TensorFlowOperation operation) {
- for (TensorFlowModel.Signature signature : model.signatures().values()) {
- for (String outputName : signature.outputs().values()) {
- if (outputName.equals(operation.node().getName())) {
- return true;
- }
- }
- }
- return false;
- }
-
- private static void importNodes(MetaGraphDef graph, TensorFlowModel model, OperationIndex index) {
- for (TensorFlowModel.Signature signature : model.signatures().values()) {
- for (String outputName : signature.outputs().values()) {
- importNode(model.name(), outputName, graph.getGraphDef(), index);
- }
- }
- }
-
- private static TensorFlowOperation importNode(String modelName, String nodeName, GraphDef graph, OperationIndex index) {
- if (index.alreadyImported(nodeName)) {
- return index.get(nodeName);
- }
- NodeDef node = getTensorFlowNodeFromGraph(namePartOf(nodeName), graph);
- List<TensorFlowOperation> inputs = importNodeInputs(modelName, node, graph, index);
- TensorFlowOperation operation = OperationMapper.get(modelName, node, inputs, portPartOf(nodeName));
- index.put(nodeName, operation);
-
- List<TensorFlowOperation> controlInputs = importControlInputs(modelName, node, graph, index);
- if (controlInputs.size() > 0) {
- operation.setControlInputs(controlInputs);
- }
-
- return operation;
- }
-
- private static List<TensorFlowOperation> importNodeInputs(String modelName, NodeDef node, GraphDef graph, OperationIndex index) {
- return node.getInputList().stream()
- .filter(name -> ! isControlDependency(name))
- .map(nodeName -> importNode(modelName, nodeName, graph, index))
- .collect(Collectors.toList());
- }
-
- private static List<TensorFlowOperation> importControlInputs(String modelName, NodeDef node, GraphDef graph, OperationIndex index) {
- return node.getInputList().stream()
- .filter(nodeName -> isControlDependency(nodeName))
- .map(nodeName -> importNode(modelName, nodeName, graph, index))
- .collect(Collectors.toList());
- }
-
- private static boolean isControlDependency(String name) {
- return name.startsWith("^");
- }
-
- /** Find dimension names to avoid excessive renaming while evaluating the model. */
- private static void findDimensionNames(TensorFlowModel model, OperationIndex index) {
- DimensionRenamer renamer = new DimensionRenamer();
- for (TensorFlowModel.Signature signature : model.signatures().values()) {
- for (String output : signature.outputs().values()) {
- addDimensionNameConstraints(index.get(output), renamer);
- }
- }
- renamer.solve();
- for (TensorFlowModel.Signature signature : model.signatures().values()) {
- for (String output : signature.outputs().values()) {
- renameDimensions(index.get(output), renamer);
- }
- }
- }
-
- private static void addDimensionNameConstraints(TensorFlowOperation operation, DimensionRenamer renamer) {
- if (operation.type().isPresent()) {
- operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer));
- operation.addDimensionNameConstraints(renamer);
- }
- }
-
- private static void renameDimensions(TensorFlowOperation operation, DimensionRenamer renamer) {
- if (operation.type().isPresent()) {
- operation.inputs().forEach(input -> renameDimensions(input, renamer));
- operation.renameDimensions(renamer);
- }
- }
-
- private static void importExpressions(TensorFlowModel model, OperationIndex index, SavedModelBundle bundle) {
- for (TensorFlowModel.Signature signature : model.signatures().values()) {
- for (String outputName : signature.outputs().values()) {
- try {
- Optional<TensorFunction> function = importExpression(index.get(outputName), model, bundle);
- if (!function.isPresent()) {
- signature.skippedOutput(outputName, "No valid output function could be found.");
- }
- }
- catch (IllegalArgumentException e) {
- signature.skippedOutput(outputName, Exceptions.toMessageString(e));
- }
- }
- }
- }
-
- private static Optional<TensorFunction> importExpression(TensorFlowOperation operation, TensorFlowModel model, SavedModelBundle bundle) {
- if (!operation.type().isPresent()) {
- return Optional.empty();
- }
- if (operation.isConstant()) {
- return importConstant(model, operation, bundle);
- }
-
- importInputExpressions(operation, model, bundle);
- importRankingExpression(model, operation);
- importInputExpression(model, operation);
- importMacroExpression(model, operation);
-
- return operation.function();
- }
-
- private static void importInputExpressions(TensorFlowOperation operation, TensorFlowModel model,
- SavedModelBundle bundle) {
- operation.inputs().forEach(input -> importExpression(input, model, bundle));
- }
-
- private static void importMacroExpression(TensorFlowModel model, TensorFlowOperation operation) {
- if (operation.macro().isPresent()) {
- TensorFunction function = operation.macro().get();
- try {
- model.macro(operation.macroName(), new RankingExpression(operation.macroName(), function.toString()));
- }
- catch (ParseException e) {
- throw new RuntimeException("Tensorflow function " + function +
- " cannot be parsed as a ranking expression", e);
- }
- }
- }
-
- private static Optional<TensorFunction> importConstant(TensorFlowModel model, TensorFlowOperation operation,
- SavedModelBundle bundle) {
- String name = operation.vespaName();
- if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) {
- return operation.function();
- }
-
- Tensor tensor;
- if (operation.getConstantValue().isPresent()) {
- Value value = operation.getConstantValue().get();
- if ( ! (value instanceof TensorValue)) {
- return operation.function(); // scalar values are inserted directly into the expression
- }
- tensor = value.asTensor();
- } else {
- // Here we use the type from the operation, which will have correct dimension names after name resolving
- tensor = TensorConverter.toVespaTensor(readVariable(operation.node().getName(), bundle),
- operation.type().get());
- operation.setConstantValue(new TensorValue(tensor));
- }
-
- if (tensor.type().rank() == 0) {
- model.smallConstant(name, tensor);
- } else {
- model.largeConstant(name, tensor);
- }
- return operation.function();
- }
-
- static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) {
- Session.Runner fetched = bundle.session().runner().fetch(name);
- List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
- if (importedTensors.size() != 1)
- throw new IllegalStateException("Expected 1 tensor from fetching " + name +
- ", but got " + importedTensors.size());
- return importedTensors.get(0);
- }
-
- private static void importRankingExpression(TensorFlowModel model, TensorFlowOperation operation) {
- if (operation.function().isPresent()) {
- String name = operation.node().getName();
- if (!model.expressions().containsKey(operation.node().getName())) {
- TensorFunction function = operation.function().get();
-
- // Make sure output adheres to standard naming convention
- if (isSignatureOutput(model, operation)) {
- OrderedTensorType operationType = operation.type().get();
- OrderedTensorType standardNamingType = OrderedTensorType.fromTensorFlowType(operation.node());
- if ( ! operationType.equals(standardNamingType)) {
- List<String> renameFrom = operationType.dimensionNames();
- List<String> renameTo = standardNamingType.dimensionNames();
- function = new Rename(function, renameFrom, renameTo);
- }
- }
-
- try {
- // We add all intermediate nodes imported as separate expressions. Only
- // those referenced in a signature output will be used. We parse the
- // TensorFunction here to convert it to a RankingExpression tree.
- model.expression(name, new RankingExpression(name, function.toString()));
- }
- catch (ParseException e) {
- throw new RuntimeException("Tensorflow function " + function +
- " cannot be parsed as a ranking expression", e);
- }
- }
- }
- }
-
- private static void importInputExpression(TensorFlowModel model, TensorFlowOperation operation) {
- if (operation.isInput() && isSignatureInput(model, operation)) {
- // All inputs must have dimensions with standard naming convention: d0, d1, ...
- OrderedTensorType standardNamingConvention = OrderedTensorType.fromTensorFlowType(operation.node());
- model.argument(operation.node().getName(), standardNamingConvention.type());
- model.requiredMacro(operation.vespaName(), standardNamingConvention.type());
- }
- }
-
- private static void reportWarnings(TensorFlowModel model, OperationIndex index) {
- for (TensorFlowModel.Signature signature : model.signatures().values()) {
- for (String output : signature.outputs().values()) {
- reportWarnings(index.get(output), signature);
- }
- }
- }
-
- /**
- * Log all TensorFlow Variables (i.e file constants) imported as part of this with their ordered type.
- * This allows users to learn the exact types (including dimension order after renaming) of the Variables
- * such that these can be converted and fed to a parent document independently of the rest of the model
- * for fast model weight updates.
- */
- private static void logVariableTypes(OperationIndex index) {
- for (TensorFlowOperation operation : index.operations()) {
- if ( ! (operation instanceof Variable)) continue;
- if ( ! operation.type().isPresent()) continue; // will not happen
-
- log.info("Importing TensorFlow variable " + operation.node().getName() + " as " + operation.vespaName() +
- " of type " + operation.type().get());
- }
- }
-
- private static void reportWarnings(TensorFlowOperation operation, TensorFlowModel.Signature signature) {
- for (String warning : operation.warnings()) {
- signature.importWarning(warning);
- }
- for (TensorFlowOperation input : operation.inputs()) {
- reportWarnings(input, signature);
- }
- }
-
- private static NodeDef getTensorFlowNodeFromGraph(String name, GraphDef graph) {
- for (NodeDef node : graph.getNodeList()) {
- if (node.getName().equals(name)) {
- return node;
- }
- }
- throw new IllegalArgumentException("Could not find node '" + name + "'");
- }
-
- /**
- * A method signature input and output has the form name:index.
- * This returns the name part without the index.
- */
- private static String namePartOf(String name) {
- name = name.startsWith("^") ? name.substring(1) : name;
- return name.split(":")[0];
- }
-
- /**
- * This return the output port part. Indexes are used for nodes with
- * multiple outputs.
- */
- private static int portPartOf(String name) {
- int i = name.indexOf(":");
- return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1));
- }
-
- private static class OperationIndex {
-
- private final Map<String, TensorFlowOperation> index = new HashMap<>();
- public TensorFlowOperation put(String key, TensorFlowOperation operation) { return index.put(key, operation); }
- public TensorFlowOperation get(String key) { return index.get(key); }
- public boolean alreadyImported(String key) { return index.containsKey(key); }
- public Collection<TensorFlowOperation> operations() { return index.values(); }
-
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java
deleted file mode 100644
index c1665d066a4..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java
+++ /dev/null
@@ -1,210 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
-
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation;
-
-import java.util.ArrayDeque;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.Deque;
-import java.util.HashMap;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
-import java.util.Objects;
-import java.util.Optional;
-
-/**
- * A constraint satisfier to find suitable dimension names to reduce the
- * amount of necessary renaming during evaluation of an imported model.
- *
- * @author lesters
- */
-public class DimensionRenamer {
-
- private final String dimensionPrefix;
- private final Map<String, List<Integer>> variables = new HashMap<>();
- private final Map<Arc, Constraint> constraints = new HashMap<>();
- private final Map<String, Integer> renames = new HashMap<>();
-
- private int iterations = 0;
-
- public DimensionRenamer() {
- this("d");
- }
-
- public DimensionRenamer(String dimensionPrefix) {
- this.dimensionPrefix = dimensionPrefix;
- }
-
- /**
- * Add a dimension name variable.
- */
- public void addDimension(String name) {
- variables.computeIfAbsent(name, d -> new ArrayList<>());
- }
-
- /**
- * Add a constraint between dimension names.
- */
- public void addConstraint(String from, String to, Constraint pred, TensorFlowOperation operation) {
- Arc arc = new Arc(from, to, operation);
- Arc opposite = arc.opposite();
- constraints.put(arc, pred);
- constraints.put(opposite, (x,y) -> pred.test(y, x)); // make constraint graph symmetric
- }
-
- /**
- * Retrieve resulting name of dimension after solving for constraints.
- */
- public Optional<String> dimensionNameOf(String name) {
- if (!renames.containsKey(name)) {
- return Optional.empty();
- }
- return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name)));
- }
-
- /**
- * Perform iterative arc consistency until we have found a solution. After
- * an initial iteration, the variables (dimensions) will have multiple
- * valid values. Find a single valid assignment by iteratively locking one
- * dimension after another, and running the arc consistency algorithm
- * multiple times.
- *
- * This requires having constraints that result in an absolute ordering:
- * equals, lesserThan and greaterThan do that, but adding notEquals does
- * not typically result in a guaranteed ordering. If that is needed, the
- * algorithm below needs to be adapted with a backtracking (tree) search
- * to find solutions.
- */
- public void solve(int maxIterations) {
- initialize();
-
- // Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts
-
- for (String dimension : variables.keySet()) {
- List<Integer> values = variables.get(dimension);
- if (values.size() > 1) {
- if (!ac3()) {
- throw new IllegalArgumentException("Dimension renamer unable to find a solution.");
- }
- values.sort(Integer::compare);
- variables.put(dimension, Collections.singletonList(values.get(0)));
- }
- renames.put(dimension, variables.get(dimension).get(0));
- if (iterations > maxIterations) {
- throw new IllegalArgumentException("Dimension renamer unable to find a solution within " +
- maxIterations + " iterations");
- }
- }
-
- // Todo: handle failure more gracefully:
- // If a solution can't be found, look at the operation node in the arc
- // with the most remaining constraints, and inject a rename operation.
- // Then run this algorithm again.
- }
-
- public void solve() {
- solve(100000);
- }
-
- private void initialize() {
- for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) {
- List<Integer> values = variable.getValue();
- for (int i = 0; i < variables.size(); ++i) {
- values.add(i); // invariant: values are in increasing order
- }
- }
- }
-
- private boolean ac3() {
- Deque<Arc> workList = new ArrayDeque<>(constraints.keySet());
- while (!workList.isEmpty()) {
- Arc arc = workList.pop();
- iterations += 1;
- if (revise(arc)) {
- if (variables.get(arc.from).size() == 0) {
- return false; // no solution found
- }
- for (Arc constraint : constraints.keySet()) {
- if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) {
- workList.add(constraint);
- }
- }
- }
- }
- return true;
- }
-
- private boolean revise(Arc arc) {
- boolean revised = false;
- for(Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) {
- Integer from = fromIterator.next();
- boolean satisfied = false;
- for (Iterator<Integer> toIterator = variables.get(arc.to).iterator(); toIterator.hasNext(); ) {
- Integer to = toIterator.next();
- if (constraints.get(arc).test(from, to)) {
- satisfied = true;
- }
- }
- if (!satisfied) {
- fromIterator.remove();
- revised = true;
- }
- }
- return revised;
- }
-
- public interface Constraint {
- boolean test(Integer x, Integer y);
- }
-
- public static boolean equals(Integer x, Integer y) {
- return Objects.equals(x, y);
- }
-
- public static boolean lesserThan(Integer x, Integer y) {
- return x < y;
- }
-
- public static boolean greaterThan(Integer x, Integer y) {
- return x > y;
- }
-
- private static class Arc {
-
- private final String from;
- private final String to;
- private final TensorFlowOperation operation;
-
- Arc(String from, String to, TensorFlowOperation operation) {
- this.from = from;
- this.to = to;
- this.operation = operation;
- }
-
- Arc opposite() {
- return new Arc(to, from, operation);
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(from, to);
- }
-
- @Override
- public boolean equals(Object obj) {
- if (obj == null || !(obj instanceof Arc)) {
- return false;
- }
- Arc other = (Arc) obj;
- return Objects.equals(from, other.from) && Objects.equals(to, other.to);
- }
-
- @Override
- public String toString() {
- return String.format("%s -> %s", from, to);
- }
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
deleted file mode 100644
index b665413a6b2..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
+++ /dev/null
@@ -1,97 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
-
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.ConcatV2;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Const;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.ExpandDims;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Identity;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Join;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Map;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Matmul;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Mean;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Merge;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.NoOp;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Placeholder;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.PlaceholderWithDefault;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Reshape;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Select;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Shape;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Squeeze;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Switch;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Variable;
-import com.yahoo.tensor.functions.ScalarFunctions;
-import org.tensorflow.framework.NodeDef;
-
-import java.util.List;
-
-/**
- * Maps from TensorFlow operations to Vespa operations.
- *
- * @author bratseth
- * @author lesters
- */
-public class OperationMapper {
-
- public static TensorFlowOperation get(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- switch (node.getOp().toLowerCase()) {
- // array ops
- case "concatv2": return new ConcatV2(modelName, node, inputs, port);
- case "const": return new Const(modelName, node, inputs, port);
- case "expanddims": return new ExpandDims(modelName, node, inputs, port);
- case "identity": return new Identity(modelName, node, inputs, port);
- case "placeholder": return new Placeholder(modelName, node, inputs, port);
- case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, node, inputs, port);
- case "reshape": return new Reshape(modelName, node, inputs, port);
- case "shape": return new Shape(modelName, node, inputs, port);
- case "squeeze": return new Squeeze(modelName, node, inputs, port);
-
- // control flow
- case "merge": return new Merge(modelName, node, inputs, port);
- case "switch": return new Switch(modelName, node, inputs, port);
-
- // math ops
- case "add": return new Join(modelName, node, inputs, port, ScalarFunctions.add());
- case "add_n": return new Join(modelName, node, inputs, port, ScalarFunctions.add());
- case "acos": return new Map(modelName, node, inputs, port, ScalarFunctions.acos());
- case "div": return new Join(modelName, node, inputs, port, ScalarFunctions.divide());
- case "realdiv": return new Join(modelName, node, inputs, port, ScalarFunctions.divide());
- case "floor": return new Map(modelName, node, inputs, port, ScalarFunctions.floor());
- case "matmul": return new Matmul(modelName, node, inputs, port);
- case "maximum": return new Join(modelName, node, inputs, port, ScalarFunctions.max());
- case "mean": return new Mean(modelName, node, inputs, port);
- case "reducemean": return new Mean(modelName, node, inputs, port);
- case "mul": return new Join(modelName, node, inputs, port, ScalarFunctions.multiply());
- case "multiply": return new Join(modelName, node, inputs, port, ScalarFunctions.multiply());
- case "rsqrt": return new Map(modelName, node, inputs, port, ScalarFunctions.rsqrt());
- case "select": return new Select(modelName, node, inputs, port);
- case "where3": return new Select(modelName, node, inputs, port);
- case "sigmoid": return new Map(modelName, node, inputs, port, ScalarFunctions.sigmoid());
- case "squareddifference": return new Join(modelName, node, inputs, port, ScalarFunctions.squareddifference());
- case "sub": return new Join(modelName, node, inputs, port, ScalarFunctions.subtract());
- case "subtract": return new Join(modelName, node, inputs, port, ScalarFunctions.subtract());
-
- // nn ops
- case "biasadd": return new Join(modelName, node, inputs, port, ScalarFunctions.add());
- case "elu": return new Map(modelName, node, inputs, port, ScalarFunctions.elu());
- case "relu": return new Map(modelName, node, inputs, port, ScalarFunctions.relu());
- case "selu": return new Map(modelName, node, inputs, port, ScalarFunctions.selu());
-
- // state ops
- case "variable": return new Variable(modelName, node, inputs, port);
- case "variablev2": return new Variable(modelName, node, inputs, port);
-
- // evaluation no-ops
- case "stopgradient":return new Identity(modelName, node, inputs, port);
- case "noop": return new NoOp(modelName, node, inputs, port);
- }
-
- TensorFlowOperation op = new NoOp(modelName, node, inputs, port);
- op.warning("Operation '" + node.getOp() + "' is currently not implemented");
- return op;
- }
-
-}
-
-
-
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
deleted file mode 100644
index 03a65333192..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
+++ /dev/null
@@ -1,255 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
-
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.TensorTypeParser;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.NodeDef;
-import org.tensorflow.framework.TensorShapeProto;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.Optional;
-import java.util.stream.Collectors;
-
-/**
- * A Vespa tensor type is ordered by the lexicographical ordering of dimension
- * names. TensorFlow tensors have an explicit ordering of their dimensions.
- * During import, we need to track the Vespa dimension that matches the
- * corresponding TensorFlow dimension as the ordering can change after
- * dimension renaming. That is the purpose of this class.
- *
- * @author lesters
- */
-public class OrderedTensorType {
-
- private final TensorType type;
- private final List<TensorType.Dimension> dimensions;
-
- private final long[] innerSizesTensorFlow;
- private final long[] innerSizesVespa;
- private final int[] dimensionMap;
-
- private OrderedTensorType(List<TensorType.Dimension> dimensions) {
- this.dimensions = Collections.unmodifiableList(dimensions);
- this.type = new TensorType.Builder(dimensions).build();
- this.innerSizesTensorFlow = new long[dimensions.size()];
- this.innerSizesVespa = new long[dimensions.size()];
- this.dimensionMap = createDimensionMap();
- }
-
- public TensorType type() {
- return this.type;
- }
-
- public int rank() { return dimensions.size(); }
-
- public List<TensorType.Dimension> dimensions() {
- return dimensions;
- }
-
- public List<String> dimensionNames() {
- return dimensions.stream().map(TensorType.Dimension::name).collect(Collectors.toList());
- }
-
- private int[] createDimensionMap() {
- int numDimensions = dimensions.size();
- if (numDimensions == 0) {
- return null;
- }
- innerSizesTensorFlow[numDimensions - 1] = 1;
- innerSizesVespa[numDimensions - 1] = 1;
- for (int i = numDimensions - 1; --i >= 0; ) {
- innerSizesTensorFlow[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesTensorFlow[i+1];
- innerSizesVespa[i] = type.dimensions().get(i+1).size().orElse(-1L) * innerSizesVespa[i+1];
- }
- int[] mapping = new int[numDimensions];
- for (int i = 0; i < numDimensions; ++i) {
- TensorType.Dimension dim1 = dimensions().get(i);
- for (int j = 0; j < numDimensions; ++j) {
- TensorType.Dimension dim2 = type.dimensions().get(j);
- if (dim1.equals(dim2)) {
- mapping[i] = j;
- break;
- }
- }
- }
- return mapping;
- }
-
- /**
- * When dimension ordering between Vespa and TensorFlow differs, i.e.
- * after dimension renaming, use the dimension map to read in values
- * so that they are correctly laid out in memory for Vespa.
- * Used when importing tensors from TensorFlow.
- */
- public int toDirectIndex(int index) {
- if (dimensions.size() == 0) {
- return 0;
- }
- if (dimensionMap == null) {
- throw new IllegalArgumentException("Dimension map is not available");
- }
- int directIndex = 0;
- long rest = index;
- for (int i = 0; i < dimensions.size(); ++i) {
- long address = rest / innerSizesTensorFlow[i];
- directIndex += innerSizesVespa[dimensionMap[i]] * address;
- rest %= innerSizesTensorFlow[i];
- }
- return directIndex;
- }
-
- @Override
- public boolean equals(Object obj) {
- if (obj == null || !(obj instanceof OrderedTensorType)) {
- return false;
- }
- OrderedTensorType other = (OrderedTensorType) obj;
- if (dimensions.size() != dimensions.size()) {
- return false;
- }
- List<TensorType.Dimension> thisDimensions = this.dimensions();
- List<TensorType.Dimension> otherDimensions = other.dimensions();
- for (int i = 0; i < thisDimensions.size(); ++i) {
- if (!thisDimensions.get(i).equals(otherDimensions.get(i))) {
- return false;
- }
- }
- return true;
- }
-
- public void verifyType(NodeDef node) {
- TensorShapeProto shape = tensorFlowShape(node);
- if (shape != null) {
- if (shape.getDimCount() != type.rank()) {
- throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " +
- "does not match Vespa shape");
- }
- for (int tensorFlowIndex = 0; tensorFlowIndex < dimensions.size(); ++tensorFlowIndex) {
- int vespaIndex = dimensionMap[tensorFlowIndex];
- TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex);
- TensorType.Dimension vespaDimension = type().dimensions().get(vespaIndex);
- if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) {
- throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " +
- "does not match Vespa dimensions");
- }
- }
- }
- }
-
- private static TensorShapeProto tensorFlowShape(NodeDef node) {
- AttrValue attrValueList = node.getAttrMap().get("_output_shapes");
- if (attrValueList == null) {
- throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
- "does not exist");
- }
- if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) {
- throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
- "is not of expected type");
- }
- List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList();
- return shapeList.get(0); // support multiple outputs?
- }
-
- public OrderedTensorType rename(DimensionRenamer renamer) {
- List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size());
- for (TensorType.Dimension dimension : dimensions) {
- String oldName = dimension.name();
- Optional<String> newName = renamer.dimensionNameOf(oldName);
- if (!newName.isPresent())
- return this; // presumably, already renamed
- TensorType.Dimension.Type dimensionType = dimension.type();
- if (dimensionType == TensorType.Dimension.Type.indexedBound) {
- renamedDimensions.add(TensorType.Dimension.indexed(newName.get(), dimension.size().get()));
- } else if (dimensionType == TensorType.Dimension.Type.indexedUnbound) {
- renamedDimensions.add(TensorType.Dimension.indexed(newName.get()));
- } else if (dimensionType == TensorType.Dimension.Type.mapped) {
- renamedDimensions.add(TensorType.Dimension.mapped(newName.get()));
- }
- }
- return new OrderedTensorType(renamedDimensions);
- }
-
- /**
- * Returns a string representation of this: A standard tensor type string where dimensions
- * are listed in the order of this rather than in the natural order of their names.
- */
- @Override
- public String toString() {
- return "tensor(" + dimensions.stream().map(TensorType.Dimension::toString).collect(Collectors.joining(",")) + ")";
- }
-
- /**
- * Creates an instance from the string representation of this: A standard tensor type string
- * where dimensions are listed in the order of this rather than the natural order of their names.
- */
- public static OrderedTensorType fromSpec(String typeSpec) {
- return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec(typeSpec));
- }
-
- public static OrderedTensorType fromTensorFlowType(NodeDef node) {
- return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ...
- }
-
- public static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) {
- Builder builder = new Builder(node);
- TensorShapeProto shape = tensorFlowShape(node);
- for (int i = 0; i < shape.getDimCount(); ++ i) {
- String dimensionName = dimensionPrefix + i;
- TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i);
- if (tensorFlowDimension.getSize() >= 0) {
- builder.add(TensorType.Dimension.indexed(dimensionName, tensorFlowDimension.getSize()));
- } else {
- builder.add(TensorType.Dimension.indexed(dimensionName));
- }
- }
- return builder.build();
- }
-
- public static class Builder {
-
- private final TensorShapeProto shape;
- private final List<TensorType.Dimension> dimensions;
-
- public Builder(NodeDef node) {
- this.shape = tensorFlowShape(node);
- this.dimensions = new ArrayList<>(shape.getDimCount());
- }
-
- public Builder add(TensorType.Dimension vespaDimension) {
- int index = dimensions.size();
- TensorShapeProto.Dim tensorFlowDimension = shape.getDim(index);
- long size = tensorFlowDimension.getSize();
- if (size >= 0) {
- if (vespaDimension.type() != TensorType.Dimension.Type.indexedBound) {
- throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " +
- "dimension types");
- }
- if (!vespaDimension.size().isPresent()) {
- throw new IllegalArgumentException("Tensor dimension is indexed bound but does " +
- "not have a size");
- }
- if (vespaDimension.size().get() != size) {
- throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " +
- "dimension sizes. TensorFlow: " + size + " Vespa: " +
- vespaDimension.size().get());
- }
- } else {
- if (vespaDimension.type() != TensorType.Dimension.Type.indexedUnbound) {
- throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " +
- "dimension types");
- }
- }
- this.dimensions.add(vespaDimension);
- return this;
- }
-
- public OrderedTensorType build() {
- return new OrderedTensorType(dimensions);
- }
-
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java
deleted file mode 100644
index 6cbfe0dfb05..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java
+++ /dev/null
@@ -1,145 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.Reduce;
-import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Optional;
-import java.util.function.DoubleBinaryOperator;
-
-public class Join extends TensorFlowOperation {
-
- private final DoubleBinaryOperator operator;
-
- public Join(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port, DoubleBinaryOperator operator) {
- super(modelName, node, inputs, port);
- this.operator = operator;
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- OrderedTensorType a = largestInput().type().get();
- OrderedTensorType b = smallestInput().type().get();
-
- // Well now we have potentially entered the wonderful world of "broadcasting"
- // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
- // In broadcasting, the size of each dimension is compared element-wise,
- // starting with the trailing dimensions and working forward. A special
- // case occurs when the size of one dimension is 1, while the other is not.
- // Then the dimension with size 1 is "stretched" to be of compatible size.
- //
- // An example:
- //
- // Tensor A: d0[5], d1[1], d2[3], d3[1]
- // Tensor B: d1[4], d2[1], d3[2]
- //
- // In TensorFlow and using the above rules of broadcasting, the resulting
- // type is:
- // d0[5], d1[4], d2[3], d2[2]
- //
- // However, in Vespa's tensor logic, the join of the two above tensors would
- // result in a tensor of type:
- // d0[5], d1[1], d2[1], d3[1]
- //
- // By reducing the dimensions of size 1 in each tensor before joining,
- // we get equal results as in TensorFlow.
-
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node);
- int sizeDifference = a.rank() - b.rank();
- for (int i = 0; i < a.rank(); ++i) {
- TensorType.Dimension aDim = a.dimensions().get(i);
- long size = aDim.size().orElse(-1L);
-
- if (i - sizeDifference >= 0) {
- TensorType.Dimension bDim = b.dimensions().get(i - sizeDifference);
- size = Math.max(size, bDim.size().orElse(-1L));
- }
-
- if (aDim.type() == TensorType.Dimension.Type.indexedBound) {
- builder.add(TensorType.Dimension.indexed(aDim.name(), size));
- } else if (aDim.type() == TensorType.Dimension.Type.indexedUnbound) {
- builder.add(TensorType.Dimension.indexed(aDim.name()));
- } else if (aDim.type() == TensorType.Dimension.Type.mapped) {
- builder.add(TensorType.Dimension.mapped(aDim.name()));
- }
- }
- return builder.build();
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- if (!allInputFunctionsPresent(2)) {
- return null;
- }
-
- TensorFlowOperation a = largestInput();
- TensorFlowOperation b = smallestInput();
-
- List<String> aDimensionsToReduce = new ArrayList<>();
- List<String> bDimensionsToReduce = new ArrayList<>();
- int sizeDifference = a.type().get().rank() - b.type().get().rank();
- for (int i = 0; i < b.type().get().rank(); ++i) {
- TensorType.Dimension bDim = b.type().get().dimensions().get(i);
- TensorType.Dimension aDim = a.type().get().dimensions().get(i + sizeDifference);
- long bSize = bDim.size().orElse(-1L);
- long aSize = aDim.size().orElse(-1L);
- if (bSize == 1L && aSize != 1L) {
- bDimensionsToReduce.add(bDim.name());
- }
- if (aSize == 1L && bSize != 1L) {
- aDimensionsToReduce.add(bDim.name());
- }
- }
-
- TensorFunction aReducedFunction = a.function().get();
- if (aDimensionsToReduce.size() > 0) {
- aReducedFunction = new Reduce(a.function().get(), Reduce.Aggregator.sum, aDimensionsToReduce);
- }
- TensorFunction bReducedFunction = b.function().get();
- if (bDimensionsToReduce.size() > 0) {
- bReducedFunction = new Reduce(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce);
- }
-
- return new com.yahoo.tensor.functions.Join(aReducedFunction, bReducedFunction, operator);
- }
-
- @Override
- public void addDimensionNameConstraints(DimensionRenamer renamer) {
- if (!allInputTypesPresent(2)) {
- return;
- }
- OrderedTensorType a = largestInput().type().get();
- OrderedTensorType b = smallestInput().type().get();
- int sizeDifference = a.rank() - b.rank();
- for (int i = 0; i < b.rank(); ++i) {
- String bDim = b.dimensions().get(i).name();
- String aDim = a.dimensions().get(i + sizeDifference).name();
- renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this);
- }
- }
-
- private TensorFlowOperation largestInput() {
- OrderedTensorType a = inputs.get(0).type().get();
- OrderedTensorType b = inputs.get(1).type().get();
- return a.rank() >= b.rank() ? inputs.get(0) : inputs.get(1);
- }
-
- private TensorFlowOperation smallestInput() {
- OrderedTensorType a = inputs.get(0).type().get();
- OrderedTensorType b = inputs.get(1).type().get();
- return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1);
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java
deleted file mode 100644
index b2b9530a161..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java
+++ /dev/null
@@ -1,74 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
-
-import java.util.List;
-import java.util.Optional;
-
-public class Matmul extends TensorFlowOperation {
-
- public Matmul(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node);
- typeBuilder.add(inputs.get(0).type().get().dimensions().get(0));
- typeBuilder.add(inputs.get(1).type().get().dimensions().get(1));
- return typeBuilder.build();
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- OrderedTensorType aType = inputs.get(0).type().get();
- OrderedTensorType bType = inputs.get(1).type().get();
- if (aType.type().rank() < 2 || bType.type().rank() < 2)
- throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
- if (aType.type().rank() != bType.type().rank())
- throw new IllegalArgumentException("Tensors in matmul must have the same rank");
-
- Optional<TensorFunction> aFunction = inputs.get(0).function();
- Optional<TensorFunction> bFunction = inputs.get(1).function();
- if (!aFunction.isPresent() || !bFunction.isPresent()) {
- return null;
- }
- return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name());
- }
-
- @Override
- public void addDimensionNameConstraints(DimensionRenamer renamer) {
- if (!allInputTypesPresent(2)) {
- return;
- }
- List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions();
- List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions();
-
- String aDim0 = aDimensions.get(0).name();
- String aDim1 = aDimensions.get(1).name();
- String bDim0 = bDimensions.get(0).name();
- String bDim1 = bDimensions.get(1).name();
-
- // The second dimension of a should have the same name as the first dimension of b
- renamer.addConstraint(aDim1, bDim0, DimensionRenamer::equals, this);
-
- // The first dimension of a should have a different name than the second dimension of b
- renamer.addConstraint(aDim0, bDim1, DimensionRenamer::lesserThan, this);
-
- // For efficiency, the dimensions to join over should be innermost - soft constraint
- renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this);
- renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this);
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java
deleted file mode 100644
index d558ec89e87..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
-
-import java.util.Collections;
-import java.util.List;
-
-public class NoOp extends TensorFlowOperation {
-
- public NoOp(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, Collections.emptyList(), port); // don't propagate inputs
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- return null;
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- return null;
- }
-
- @Override
- public boolean isConstant() {
- return true;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java
deleted file mode 100644
index b18a8a9b212..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java
+++ /dev/null
@@ -1,46 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
-
-import java.util.List;
-
-public class Variable extends TensorFlowOperation {
-
- public Variable(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
- }
-
- /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
- @Override
- public String vespaName() {
- return modelName() + "_" + super.vespaName();
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- return OrderedTensorType.fromTensorFlowType(node, super.vespaName() + "_");
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- return null; // will be added by function() since this is constant.
- }
-
- @Override
- public void addDimensionNameConstraints(DimensionRenamer renamer) {
- for (TensorType.Dimension dimension : type.type().dimensions()) {
- renamer.addDimension(dimension.name());
- }
- }
-
- @Override
- public boolean isConstant() {
- return true;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/package-info.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/package-info.java
deleted file mode 100644
index 9e53990a9d6..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/package-info.java
+++ /dev/null
@@ -1,8 +0,0 @@
-// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-/**
- * Tensorflow integration
- */
-@ExportPackage
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
-
-import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
index 0f5eec93feb..bf9684082f4 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import org.junit.Test;
@@ -15,7 +15,7 @@ public class BatchNormImportTestCase {
@Test
public void testBatchNormImport() {
TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/batch_norm/saved");
- TensorFlowModel.Signature signature = model.get().signature("serving_default");
+ ImportedModel.Signature signature = model.get().signature("serving_default");
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java
index 74b0d11f1d6..c8c7ec798bb 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java
@@ -1,6 +1,6 @@
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import org.junit.Test;
import static org.junit.Assert.assertTrue;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
index 50a467ec581..a63c7346335 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.TensorType;
@@ -24,7 +24,7 @@ public class DropoutImportTestCase {
assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
model.get().requiredMacros().get("X"));
- TensorFlowModel.Signature signature = model.get().signature("serving_default");
+ ImportedModel.Signature signature = model.get().signature("serving_default");
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
@@ -32,7 +32,7 @@ public class DropoutImportTestCase {
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
assertEquals("outputs/Maximum", output.getName());
- assertEquals("join(join(tf_macro_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), tf_macro_test_outputs_BiasAdd, f(a,b)(max(a,b)))",
+ assertEquals("join(join(imported_ml_macro_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_macro_test_outputs_BiasAdd, f(a,b)(max(a,b)))",
output.getRoot().toString());
model.assertEqualResult("X", output.getName());
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java
index 9f919c452d6..bd7644be23b 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
@@ -45,7 +45,7 @@ public class MnistSoftmaxImportTestCase {
// Check signatures
assertEquals(1, model.get().signatures().size());
- TensorFlowModel.Signature signature = model.get().signatures().get("serving_default");
+ ImportedModel.Signature signature = model.get().signatures().get("serving_default");
assertNotNull(signature);
// ... signature inputs
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
index 4b68cd40a08..a7926cd2e02 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
@@ -1,11 +1,9 @@
-package com.yahoo.searchlib.rankingexpression.integration.onnx;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
@@ -24,7 +22,7 @@ public class OnnxMnistSoftmaxImportTestCase {
@Test
public void testMnistSoftmaxImport() throws IOException {
- OnnxModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx");
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx");
// Check constants
assertEquals(2, model.largeConstants().size());
@@ -48,7 +46,7 @@ public class OnnxMnistSoftmaxImportTestCase {
model.requiredMacros().get("Placeholder"));
// Check outputs
- RankingExpression output = model.outputExpression("add");
+ RankingExpression output = model.defaultSignature().outputExpression("add");
assertNotNull(output);
assertEquals("add", output.getName());
assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))",
@@ -68,13 +66,12 @@ public class OnnxMnistSoftmaxImportTestCase {
}
private Tensor evaluateTensorFlowModel(String path, Tensor argument, String input, String output) {
- SavedModelBundle tensorFlowModel = SavedModelBundle.load(path, "serve");
- TensorFlowModel model = new TensorFlowImporter().importModel("test", tensorFlowModel);
+ ImportedModel model = new TensorFlowImporter().importModel("test", path);
return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input);
}
private Tensor evaluateOnnxModel(String path, Tensor argument, String input, String output) {
- OnnxModel model = new OnnxImporter().importModel("test", path);
+ ImportedModel model = new OnnxImporter().importModel("test", path);
return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input);
}
@@ -83,14 +80,7 @@ public class OnnxMnistSoftmaxImportTestCase {
return expression.evaluate(context).asTensor();
}
- private Context contextFrom(TensorFlowModel result) {
- MapContext context = new MapContext();
- result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
- result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
- return context;
- }
-
- private Context contextFrom(OnnxModel result) {
+ private Context contextFrom(ImportedModel result) {
MapContext context = new MapContext();
result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java
index beec2ab1ead..b2443082ab1 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java
@@ -1,6 +1,6 @@
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java
index 7ca16939477..723c5f27914 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java
@@ -1,11 +1,11 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
@@ -28,7 +28,7 @@ import static org.junit.Assert.assertEquals;
public class TestableTensorFlowModel {
private SavedModelBundle tensorFlowModel;
- private TensorFlowModel model;
+ private ImportedModel model;
// Sizes of the input vector
private final int d0Size = 1;
@@ -39,7 +39,7 @@ public class TestableTensorFlowModel {
model = new TensorFlowImporter().importModel(modelName, tensorFlowModel);
}
- public TensorFlowModel get() { return model; }
+ public ImportedModel get() { return model; }
public void assertEqualResult(String inputName, String operationName) {
Tensor tfResult = tensorFlowExecute(tensorFlowModel, inputName, operationName);
@@ -66,7 +66,7 @@ public class TestableTensorFlowModel {
return TensorConverter.toVespaTensor(results.get(0));
}
- private Context contextFrom(TensorFlowModel result) {
+ private Context contextFrom(ImportedModel result) {
MapContext context = new MapContext();
result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
@@ -81,7 +81,7 @@ public class TestableTensorFlowModel {
return b.build();
}
- private void evaluateMacro(Context context, TensorFlowModel model, String macroName) {
+ private void evaluateMacro(Context context, ImportedModel model, String macroName) {
if (!context.names().contains(macroName)) {
RankingExpression e = model.macros().get(macroName);
evaluateMacroDependencies(context, model, e.getRoot());
@@ -89,7 +89,7 @@ public class TestableTensorFlowModel {
}
}
- private void evaluateMacroDependencies(Context context, TensorFlowModel model, ExpressionNode node) {
+ private void evaluateMacroDependencies(Context context, ImportedModel model, ExpressionNode node) {
if (node instanceof ReferenceNode) {
String name = node.toString();
if (model.macros().containsKey(name)) {
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java
index 051c2c60c95..f94098e6255 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java
@@ -1,4 +1,4 @@
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
import org.junit.Test;