aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-06-06 15:43:18 +0200
committerLester Solbakken <lesters@oath.com>2018-06-06 15:43:18 +0200
commit0bf235c481d24d627c82901a84bef585fe84bbb2 (patch)
tree6cb6d0b192f56f3e8fdb533fb9603d3f927fe3c1
parent389801098797ab37c7bc4ac5a3888ef4d92214e7 (diff)
Refactor ONNX and TF import to use same code base
This reverts commit 681963959794b47102d1a1cf72f215c72b0e2b51.
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java674
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java636
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java677
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java22
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java26
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java)101
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java242
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java30
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java47
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverter.java)9
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamer.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/DimensionRenamer.java)10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java107
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorType.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OrderedTensorType.java)154
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java216
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/TensorConverter.java)6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java52
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java)19
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ConcatV2.java)31
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java)53
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java)31
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java)21
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java)13
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java)118
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Join.java)22
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java)11
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/MatMul.java)15
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java)29
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java)15
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java26
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java)11
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java)24
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java)19
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java)13
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java)33
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java)22
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java85
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java234
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java)3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java72
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/package-info.java)2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java326
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java112
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java26
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java64
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java32
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java139
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java411
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java210
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java97
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java255
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java145
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java74
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java32
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java46
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/package-info.java8
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java)4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java)4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java)6
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java)4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java)22
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java)4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java)14
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java)2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java123
64 files changed, 2365 insertions, 3726 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java
new file mode 100644
index 00000000000..effa261be3b
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java
@@ -0,0 +1,674 @@
+package com.yahoo.searchdefinition.expressiontransforms;
+
+import com.google.common.base.Joiner;
+import com.yahoo.collections.Pair;
+import com.yahoo.config.application.api.ApplicationFile;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.model.application.provider.FilesApplicationPackage;
+import com.yahoo.io.IOUtils;
+import com.yahoo.path.Path;
+import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.searchdefinition.FeatureNames;
+import com.yahoo.searchdefinition.RankProfile;
+import com.yahoo.searchdefinition.RankingConstant;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.rule.Arguments;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.Join;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.tensor.serialization.TypedBinaryFormat;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
+import java.io.StringReader;
+import java.io.UncheckedIOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Base class for replacing instances of a pseudofeature for imported ML
+ * ranking models with native Vespa ranking expressions.
+ *
+ * @author bratseth
+ * @author lesters
+ */
+abstract class MLImportFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
+
+ ExpressionNode transformFromImportedModel(ImportedModel model,
+ ModelStore store,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles) {
+ // Add constants
+ Set<String> constantsReplacedByMacros = new HashSet<>();
+ model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
+ model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles,
+ constantsReplacedByMacros, k, v));
+
+ // Find the specified expression
+ ImportedModel.Signature signature = chooseSignature(model, store.arguments().signature());
+ String output = chooseOutput(signature, store.arguments().output());
+ if (signature.skippedOutputs().containsKey(output)) {
+ String message = "Could not import model output '" + output + "'";
+ if (!signature.skippedOutputs().get(output).isEmpty()) {
+ message += ": " + signature.skippedOutputs().get(output);
+ }
+ if (!signature.importWarnings().isEmpty()) {
+ message += ": " + String.join(", ", signature.importWarnings());
+ }
+ throw new IllegalArgumentException(message);
+ }
+
+ RankingExpression expression = model.expressions().get(output);
+ expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
+ verifyRequiredMacros(expression, model, profile, queryProfiles);
+ addGeneratedMacros(model, profile);
+ reduceBatchDimensions(expression, model, profile, queryProfiles);
+
+ model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v));
+
+ store.writeConverted(expression);
+ return expression.getRoot();
+ }
+
+ ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) {
+ for (Pair<String, Tensor> constant : store.readSmallConstants())
+ profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
+
+ for (RankingConstant constant : store.readLargeConstants()) {
+ if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName()))
+ profile.getSearch().addRankingConstant(constant);
+ }
+
+ for (Pair<String, RankingExpression> macro : store.readMacros()) {
+ addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond());
+ }
+
+ return store.readConverted().getRoot();
+ }
+
+ /**
+ * Returns the specified, existing signature, or the only signature if none is specified.
+ * Throws IllegalArgumentException in all other cases.
+ */
+ private ImportedModel.Signature chooseSignature(ImportedModel importResult, Optional<String> signatureName) {
+ if ( ! signatureName.isPresent()) {
+ if (importResult.signatures().size() == 0)
+ throw new IllegalArgumentException("No signatures are available");
+ if (importResult.signatures().size() > 1)
+ throw new IllegalArgumentException("Model has multiple signatures (" +
+ Joiner.on(", ").join(importResult.signatures().keySet()) +
+ "), one must be specified " +
+ "as a second argument to tensorflow()");
+ return importResult.signatures().values().stream().findFirst().get();
+ }
+ else {
+ ImportedModel.Signature signature = importResult.signatures().get(signatureName.get());
+ if (signature == null)
+ throw new IllegalArgumentException("Model does not have the specified signature '" +
+ signatureName.get() + "'");
+ return signature;
+ }
+ }
+
+ /**
+ * Returns the specified, existing output expression, or the only output expression if no output name is specified.
+ * Throws IllegalArgumentException in all other cases.
+ */
+ private String chooseOutput(ImportedModel.Signature signature, Optional<String> outputName) {
+ if ( ! outputName.isPresent()) {
+ if (signature.outputs().size() == 0)
+ throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(signature));
+ if (signature.outputs().size() > 1)
+ throw new IllegalArgumentException(signature + " has multiple outputs (" +
+ Joiner.on(", ").join(signature.outputs().keySet()) +
+ "), one must be specified " +
+ "as a third argument to tensorflow()");
+ return signature.outputs().get(signature.outputs().keySet().stream().findFirst().get());
+ }
+ else {
+ String output = signature.outputs().get(outputName.get());
+ if (output == null) {
+ if (signature.skippedOutputs().containsKey(outputName.get()))
+ throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " +
+ signature.skippedOutputs().get(outputName.get()));
+ else
+ throw new IllegalArgumentException("Model does not have the specified output '" +
+ outputName.get() + "'");
+ }
+ return output;
+ }
+ }
+
+ private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
+ store.writeSmallConstant(constantName, constantValue);
+ profile.addConstant(constantName, asValue(constantValue));
+ }
+
+ private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles,
+ Set<String> constantsReplacedByMacros,
+ String constantName, Tensor constantValue) {
+ RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName);
+ if (macroOverridingConstant != null) {
+ TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles));
+ if ( ! macroType.equals(constantValue.type()))
+ throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " +
+ typeMismatchExplanation(constantValue.type(), macroType));
+ constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later
+ }
+ else {
+ Path constantPath = store.writeLargeConstant(constantName, constantValue);
+ if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) {
+ profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(),
+ constantPath.toString()));
+ }
+ }
+ }
+
+ private void transformGeneratedMacro(ModelStore store,
+ Set<String> constantsReplacedByMacros,
+ String macroName, RankingExpression expression) {
+
+ expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
+ store.writeMacro(macroName, expression);
+ }
+
+ private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
+ if (profile.getMacros().containsKey(macroName)) {
+ throw new IllegalArgumentException("Generated macro '" + macroName + "' already exists.");
+ }
+ profile.addMacro(macroName, false); // todo: inline if only used once
+ RankProfile.Macro macro = profile.getMacros().get(macroName);
+ macro.setRankingExpression(expression);
+ macro.setTextualExpression(expression.getRoot().toString());
+ }
+
+ private String skippedOutputsDescription(ImportedModel.Signature signature) {
+ if (signature.skippedOutputs().isEmpty()) return "";
+ StringBuilder b = new StringBuilder(": ");
+ signature.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v));
+ return b.toString();
+ }
+
+ /**
+ * Verify that the macros referred in the given expression exists in the given rank profile,
+ * and return tensors of the types specified in requiredMacros.
+ */
+ private void verifyRequiredMacros(RankingExpression expression, ImportedModel model,
+ RankProfile profile, QueryProfileRegistry queryProfiles) {
+ Set<String> macroNames = new HashSet<>();
+ addMacroNamesIn(expression.getRoot(), macroNames, model);
+ for (String macroName : macroNames) {
+ TensorType requiredType = model.requiredMacros().get(macroName);
+ if (requiredType == null) continue; // Not a required macro
+
+ RankProfile.Macro macro = profile.getMacros().get(macroName);
+ if (macro == null)
+ throw new IllegalArgumentException("Model refers input '" + macroName +
+ "' of type " + requiredType + " but this macro is not present in " +
+ profile);
+ // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second
+ // phase and summary features), as it may only resolve correctly given those bindings
+ // Or, probably better, annotate the macros with type constraints here and verify during general
+ // type verification
+ TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles));
+ if ( actualType == null)
+ throw new IllegalArgumentException("Model refers input '" + macroName +
+ "' of type " + requiredType +
+ " which must be produced by a macro in the rank profile, but " +
+ "this macro references a feature which is not declared");
+ if ( ! actualType.isAssignableTo(requiredType))
+ throw new IllegalArgumentException("Model refers input '" + macroName + "'. " +
+ typeMismatchExplanation(requiredType, actualType));
+ }
+ }
+
+ private String typeMismatchExplanation(TensorType requiredType, TensorType actualType) {
+ return "The required type of this is " + requiredType + ", but this macro returns " + actualType +
+ (actualType.rank() == 0 ? ". This is often due to missing declaration of query tensor features " +
+ "in query profile types - see the documentation."
+ : "");
+ }
+
+ /**
+ * Add the generated macros to the rank profile
+ */
+ private void addGeneratedMacros(ImportedModel model, RankProfile profile) {
+ model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v));
+ }
+
+ /**
+ * Check if batch dimensions of inputs can be reduced out. If the input
+ * macro specifies that a single exemplar should be evaluated, we can
+ * reduce the batch dimension out.
+ */
+ private void reduceBatchDimensions(RankingExpression expression, ImportedModel model,
+ RankProfile profile, QueryProfileRegistry queryProfiles) {
+ TypeContext<Reference> typeContext = profile.typeContext(queryProfiles);
+ TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
+
+ // Check generated macros for inputs to reduce
+ Set<String> macroNames = new HashSet<>();
+ addMacroNamesIn(expression.getRoot(), macroNames, model);
+ for (String macroName : macroNames) {
+ if ( ! model.macros().containsKey(macroName)) {
+ continue;
+ }
+ RankProfile.Macro macro = profile.getMacros().get(macroName);
+ if (macro == null) {
+ throw new IllegalArgumentException("Model refers to generated macro '" + macroName +
+ "but this macro is not present in " + profile);
+ }
+ RankingExpression macroExpression = macro.getRankingExpression();
+ macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext));
+ }
+
+ // Check expression for inputs to reduce
+ ExpressionNode root = expression.getRoot();
+ root = reduceBatchDimensionsAtInput(root, model, typeContext);
+ TensorType typeAfterReducing = root.type(typeContext);
+ root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing);
+ expression.setRoot(root);
+ }
+
+ private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedModel model,
+ TypeContext<Reference> typeContext) {
+ if (node instanceof TensorFunctionNode) {
+ TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
+ if (tensorFunction instanceof Rename) {
+ List<ExpressionNode> children = ((TensorFunctionNode)node).children();
+ if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
+ ReferenceNode referenceNode = (ReferenceNode) children.get(0);
+ if (model.requiredMacros().containsKey(referenceNode.getName())) {
+ return reduceBatchDimensionExpression(tensorFunction, typeContext);
+ }
+ }
+ }
+ }
+ if (node instanceof ReferenceNode) {
+ ReferenceNode referenceNode = (ReferenceNode) node;
+ if (model.requiredMacros().containsKey(referenceNode.getName())) {
+ return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext);
+ }
+ }
+ if (node instanceof CompositeNode) {
+ List<ExpressionNode> children = ((CompositeNode)node).children();
+ List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
+ for (ExpressionNode child : children) {
+ transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext));
+ }
+ return ((CompositeNode)node).setChildren(transformedChildren);
+ }
+ return node;
+ }
+
+ private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
+ TensorFunction result = function;
+ TensorType type = function.type(context);
+ if (type.dimensions().size() > 1) {
+ List<String> reduceDimensions = new ArrayList<>();
+ for (TensorType.Dimension dimension : type.dimensions()) {
+ if (dimension.size().orElse(-1L) == 1) {
+ reduceDimensions.add(dimension.name());
+ }
+ }
+ if (reduceDimensions.size() > 0) {
+ result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions);
+ }
+ }
+ return new TensorFunctionNode(result);
+ }
+
+ /**
+ * If batch dimensions have been reduced away above, bring them back here
+ * for any following computation of the tensor.
+ * Todo: determine when this is not necessary!
+ */
+ private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
+ if (after.equals(before)) {
+ return node;
+ }
+ TensorType.Builder typeBuilder = new TensorType.Builder();
+ for (TensorType.Dimension dimension : before.dimensions()) {
+ if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) {
+ typeBuilder.indexed(dimension.name(), 1);
+ }
+ }
+ TensorType expandDimensionsType = typeBuilder.build();
+ if (expandDimensionsType.dimensions().size() > 0) {
+ ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0));
+ Generate generatedFunction = new Generate(expandDimensionsType,
+ new GeneratorLambdaFunctionNode(expandDimensionsType,
+ generatedExpression)
+ .asLongListToDoubleOperator());
+ Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply());
+ return new TensorFunctionNode(expand);
+ }
+ return node;
+ }
+
+ /**
+ * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions.
+ * This method does that for the given expression and returns the result.
+ */
+ private RankingExpression replaceConstantsByMacros(RankingExpression expression,
+ Set<String> constantsReplacedByMacros) {
+ if (constantsReplacedByMacros.isEmpty()) return expression;
+ return new RankingExpression(expression.getName(),
+ replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros));
+ }
+
+ private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) {
+ if (node instanceof ReferenceNode) {
+ Reference reference = ((ReferenceNode)node).reference();
+ if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) {
+ String argument = reference.simpleArgument().get();
+ if (constantsReplacedByMacros.contains(argument))
+ return new ReferenceNode(argument);
+ }
+ }
+ if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above
+ CompositeNode composite = (CompositeNode)node;
+ return composite.setChildren(composite.children().stream()
+ .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros))
+ .collect(Collectors.toList()));
+ }
+ return node;
+ }
+
+ private void addMacroNamesIn(ExpressionNode node, Set<String> names, ImportedModel model) {
+ if (node instanceof ReferenceNode) {
+ ReferenceNode referenceNode = (ReferenceNode)node;
+ if (referenceNode.getOutput() == null) { // macro references cannot specify outputs
+ names.add(referenceNode.getName());
+ if (model.macros().containsKey(referenceNode.getName())) {
+ addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model);
+ }
+ }
+ }
+ else if (node instanceof CompositeNode) {
+ for (ExpressionNode child : ((CompositeNode)node).children())
+ addMacroNamesIn(child, names, model);
+ }
+ }
+
+ private Value asValue(Tensor tensor) {
+ if (tensor.type().rank() == 0)
+ return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors
+ else
+ return new TensorValue(tensor);
+ }
+
+ /**
+ * Provides read/write access to the correct directories of the application package given by the feature arguments
+ */
+ static class ModelStore {
+
+ private final ApplicationPackage application;
+ private final FeatureArguments arguments;
+
+ ModelStore(ApplicationPackage application, FeatureArguments arguments) {
+ this.application = application;
+ this.arguments = arguments;
+ }
+
+ public FeatureArguments arguments() { return arguments; }
+
+ public boolean hasStoredModel() {
+ try {
+ return application.getFile(arguments.expressionPath()).exists();
+ }
+ catch (UnsupportedOperationException e) {
+ return false;
+ }
+ }
+
+ /**
+ * Returns the directory which contains the source model to use for these arguments
+ */
+ public File modelDir() {
+ return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath()));
+ }
+
+ /**
+ * Adds this expression to the application package, such that it can be read later.
+ */
+ void writeConverted(RankingExpression expression) {
+ application.getFile(arguments.expressionPath())
+ .writeFile(new StringReader(expression.getRoot().toString()));
+ }
+
+ /** Reads the previously stored ranking expression for these arguments */
+ RankingExpression readConverted() {
+ try {
+ return new RankingExpression(application.getFile(arguments.expressionPath()).createReader());
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e);
+ }
+ catch (ParseException e) {
+ throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
+ }
+ }
+
+ /** Adds this macro expression to the application package to it can be read later. */
+ void writeMacro(String name, RankingExpression expression) {
+ application.getFile(arguments.macrosPath()).appendFile(name + "\t" +
+ expression.getRoot().toString() + "\n");
+ }
+
+ /** Reads the previously stored macro expressions for these arguments */
+ List<Pair<String, RankingExpression>> readMacros() {
+ try {
+ ApplicationFile file = application.getFile(arguments.macrosPath());
+ if (!file.exists()) return Collections.emptyList();
+
+ List<Pair<String, RankingExpression>> macros = new ArrayList<>();
+ BufferedReader reader = new BufferedReader(file.createReader());
+ String line;
+ while (null != (line = reader.readLine())) {
+ String[] parts = line.split("\t");
+ String name = parts[0];
+ try {
+ RankingExpression expression = new RankingExpression(parts[1]);
+ macros.add(new Pair<>(name, expression));
+ }
+ catch (ParseException e) {
+ throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
+ }
+ }
+ return macros;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ /**
+ * Reads the information about all the large (aka ranking) constants stored in the application package
+ * (the constant value itself is replicated with file distribution).
+ */
+ List<RankingConstant> readLargeConstants() {
+ try {
+ List<RankingConstant> constants = new ArrayList<>();
+ for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) {
+ String[] parts = IOUtils.readAll(constantFile.createReader()).split(":");
+ constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2]));
+ }
+ return constants;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ /**
+ * Adds this constant to the application package as a file,
+ * such that it can be distributed using file distribution.
+ *
+ * @return the path to the stored constant, relative to the application package root
+ */
+ Path writeLargeConstant(String name, Tensor constant) {
+ Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants");
+
+ // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file:
+ Path constantPath = constantsPath.append(name + ".tbf");
+
+ // Remember the constant in a file we replicate in ZooKeeper
+ application.getFile(arguments.largeConstantsPath().append(name + ".constant"))
+ .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath)));
+
+ // Write content explicitly as a file on the file system as this is distributed using file distribution
+ createIfNeeded(constantsPath);
+ IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant));
+ return correct(constantPath);
+ }
+
+ private List<Pair<String, Tensor>> readSmallConstants() {
+ try {
+ ApplicationFile file = application.getFile(arguments.smallConstantsPath());
+ if (!file.exists()) return Collections.emptyList();
+
+ List<Pair<String, Tensor>> constants = new ArrayList<>();
+ BufferedReader reader = new BufferedReader(file.createReader());
+ String line;
+ while (null != (line = reader.readLine())) {
+ String[] parts = line.split("\t");
+ String name = parts[0];
+ TensorType type = TensorType.fromSpec(parts[1]);
+ Tensor tensor = Tensor.from(type, parts[2]);
+ constants.add(new Pair<>(name, tensor));
+ }
+ return constants;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ /**
+ * Append this constant to the single file used for small constants distributed as config
+ */
+ public void writeSmallConstant(String name, Tensor constant) {
+ // Secret file format for remembering constants:
+ application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" +
+ constant.type().toString() + "\t" +
+ constant.toString() + "\n");
+ }
+
+ /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */
+ private Path correct(Path path) {
+ if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed)
+ && ! path.elements().contains(FilesApplicationPackage.preprocessed)) {
+ return Path.fromString(FilesApplicationPackage.preprocessed).append(path);
+ }
+ else {
+ return path;
+ }
+ }
+
+ private void createIfNeeded(Path path) {
+ File dir = application.getFileReference(path);
+ if ( ! dir.exists()) {
+ if (!dir.mkdirs())
+ throw new IllegalStateException("Could not create " + dir);
+ }
+ }
+
+ }
+
+ /** Encapsulates the arguments to the import feature */
+ static abstract class FeatureArguments {
+
+ Path modelPath;
+
+ /** Optional arguments */
+ Optional<String> signature, output;
+
+ /** Returns modelPath with slashes replaced by underscores */
+ public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); }
+
+ /** Returns relative path to this model below the "models/" dir in the application package */
+ public Path modelPath() { return modelPath; }
+ public Optional<String> signature() { return signature; }
+ public Optional<String> output() { return output; }
+
+ /** Path to the small constants file */
+ public Path smallConstantsPath() {
+ return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt");
+ }
+
+ /** Path to the large (ranking) constants directory */
+ public Path largeConstantsPath() {
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants");
+ }
+
+ /** Path to the macros file */
+ public Path macrosPath() {
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt");
+ }
+
+ public Path expressionPath() {
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR
+ .append(modelPath).append("expressions").append(expressionFileName());
+ }
+
+ private String expressionFileName() {
+ StringBuilder fileName = new StringBuilder();
+ signature.ifPresent(s -> fileName.append(s).append("."));
+ output.ifPresent(s -> fileName.append(s).append("."));
+ if (fileName.length() == 0) // single signature and output
+ fileName.append("single.");
+ fileName.append("expression");
+ return fileName.toString();
+ }
+
+ Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
+ if (argumentIndex >= arguments.expressions().size())
+ return Optional.empty();
+ return Optional.of(asString(arguments.expressions().get(argumentIndex)));
+ }
+
+ String asString(ExpressionNode node) {
+ if ( ! (node instanceof ConstantNode))
+ throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
+ return stripQuotes(((ConstantNode)node).sourceString());
+ }
+
+ private String stripQuotes(String s) {
+ if ( ! isQuoteSign(s.codePointAt(0))) return s;
+ if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
+ throw new IllegalArgumentException("argument [" + s + "] is missing endquote");
+ return s.substring(1, s.length()-1);
+ }
+
+ private boolean isQuoteSign(int c) {
+ return c == '\'' || c == '"';
+ }
+
+ }
+}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
index 1c41ad8284e..44eeb364603 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
@@ -2,58 +2,20 @@
package com.yahoo.searchdefinition.expressiontransforms;
-import com.google.common.base.Joiner;
-import com.yahoo.collections.Pair;
-import com.yahoo.config.application.api.ApplicationFile;
-import com.yahoo.config.application.api.ApplicationPackage;
-import com.yahoo.config.model.application.provider.FilesApplicationPackage;
-import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
-import com.yahoo.searchdefinition.FeatureNames;
import com.yahoo.searchdefinition.RankProfile;
-import com.yahoo.searchdefinition.RankingConstant;
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.Reference;
-import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxImporter;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxModel;
-import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import com.yahoo.searchlib.rankingexpression.integration.ml.OnnxImporter;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
-import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
-import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
-import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.TypeContext;
-import com.yahoo.tensor.functions.Generate;
-import com.yahoo.tensor.functions.Join;
-import com.yahoo.tensor.functions.Reduce;
-import com.yahoo.tensor.functions.Rename;
-import com.yahoo.tensor.functions.ScalarFunctions;
-import com.yahoo.tensor.functions.TensorFunction;
-import com.yahoo.tensor.serialization.TypedBinaryFormat;
-import java.io.BufferedReader;
-import java.io.File;
-import java.io.IOException;
-import java.io.StringReader;
import java.io.UncheckedIOException;
-import java.util.ArrayList;
-import java.util.Collections;
import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
import java.util.Map;
import java.util.Optional;
-import java.util.Set;
-import java.util.stream.Collectors;
/**
* Replaces instances of the onnx(model-path, output)
@@ -63,12 +25,12 @@ import java.util.stream.Collectors;
* @author bratseth
* @author lesters
*/
-public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
+public class OnnxFeatureConverter extends MLImportFeatureConverter {
private final OnnxImporter onnxImporter = new OnnxImporter();
/** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
- private final Map<Path, OnnxModel> importedModels = new HashMap<>();
+ private final Map<Path, ImportedModel> importedModels = new HashMap<>();
@Override
public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
@@ -84,7 +46,8 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
if ( ! feature.getName().equals("onnx")) return feature;
try {
- ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments());
+ FeatureArguments arguments = new OnnxFeatureArguments(feature.getArguments());
+ ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments);
if ( ! store.hasStoredModel()) // not converted yet - access Onnx model files
return transformFromOnnxModel(store, context.rankProfile(), context.queryProfiles());
else
@@ -98,597 +61,24 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
private ExpressionNode transformFromOnnxModel(ModelStore store,
RankProfile profile,
QueryProfileRegistry queryProfiles) {
- OnnxModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
+ ImportedModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
k -> onnxImporter.importModel(store.arguments().modelName(),
- store.onnxModelDir()));
-
- // Add constants
- Set<String> constantsReplacedByMacros = new HashSet<>();
- model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
- model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles,
- constantsReplacedByMacros, k, v));
-
- // Find the specified expression
- String output = chooseOutput(model, store.arguments().output());
- if (model.skippedOutputs().containsKey(output)) {
- String message = "Could not import Onnx model output '" + output + "'";
- if (!model.skippedOutputs().get(output).isEmpty()) {
- message += ": " + model.skippedOutputs().get(output);
- }
- if (!model.importWarnings().isEmpty()) {
- message += ": " + String.join(", ", model.importWarnings());
- }
- throw new IllegalArgumentException(message);
- }
-
- RankingExpression expression = model.expressions().get(output);
- expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
- verifyRequiredMacros(expression, model, profile, queryProfiles);
- addGeneratedMacros(model, profile);
- reduceBatchDimensions(expression, model, profile, queryProfiles);
-
- model.macros().forEach((k, v) -> transformGeneratedMacro(store, profile, constantsReplacedByMacros, k, v));
-
- store.writeConverted(expression);
- return expression.getRoot();
- }
-
- private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) {
- for (Pair<String, Tensor> constant : store.readSmallConstants())
- profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
-
- for (RankingConstant constant : store.readLargeConstants()) {
- if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName()))
- profile.getSearch().addRankingConstant(constant);
- }
-
- for (Pair<String, RankingExpression> macro : store.readMacros()) {
- addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond());
- }
-
- return store.readConverted().getRoot();
- }
-
- /**
- * Returns the specified, existing output expression, or the only output expression if no output name is specified.
- * Throws IllegalArgumentException in all other cases.
- */
- private String chooseOutput(OnnxModel model, Optional<String> outputName) {
- if ( ! outputName.isPresent()) {
- if (model.outputs().size() == 0)
- throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(model));
- if (model.outputs().size() > 1)
- throw new IllegalArgumentException("Onnx model has multiple outputs (" +
- Joiner.on(", ").join(model.outputs().keySet()) +
- "), one must be specified " +
- "as a second argument to onnx()");
- return model.outputs().get(model.outputs().keySet().stream().findFirst().get());
- }
- else {
- String output = model.outputs().get(outputName.get());
- if (output == null) {
- if (model.skippedOutputs().containsKey(outputName.get()))
- throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " +
- model.skippedOutputs().get(outputName.get()));
- else
- throw new IllegalArgumentException("Model does not have the specified output '" +
- outputName.get() + "'");
- }
- return output;
- }
+ store.modelDir()));
+ return transformFromImportedModel(model, store, profile, queryProfiles);
}
- private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
- store.writeSmallConstant(constantName, constantValue);
- profile.addConstant(constantName, asValue(constantValue));
- }
-
- private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles,
- Set<String> constantsReplacedByMacros,
- String constantName, Tensor constantValue) {
- RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName);
- if (macroOverridingConstant != null) {
- TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles));
- if ( ! macroType.equals(constantValue.type()))
- throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " +
- "The required type of this is " + constantValue.type() +
- ", but the macro returns " + macroType);
- constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later
- }
- else {
- Path constantPath = store.writeLargeConstant(constantName, constantValue);
- if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) {
- profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(),
- constantPath.toString()));
- }
- }
- }
-
- private void transformGeneratedMacro(ModelStore store, RankProfile profile,
- Set<String> constantsReplacedByMacros,
- String macroName, RankingExpression expression) {
-
- expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
- store.writeMacro(macroName, expression);
- }
-
- private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
- if (profile.getMacros().containsKey(macroName)) {
- throw new IllegalArgumentException("Generated Onnx macro '" + macroName + "' already exists.");
- }
- profile.addMacro(macroName, false); // todo: inline if only used once
- RankProfile.Macro macro = profile.getMacros().get(macroName);
- macro.setRankingExpression(expression);
- macro.setTextualExpression(expression.getRoot().toString());
- }
-
- private String skippedOutputsDescription(OnnxModel model) {
- if (model.skippedOutputs().isEmpty()) return "";
- StringBuilder b = new StringBuilder(": ");
- model.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v));
- return b.toString();
- }
-
- /**
- * Verify that the macros referred in the given expression exists in the given rank profile,
- * and return tensors of the types specified in requiredMacros.
- */
- private void verifyRequiredMacros(RankingExpression expression, OnnxModel model,
- RankProfile profile, QueryProfileRegistry queryProfiles) {
- Set<String> macroNames = new HashSet<>();
- addMacroNamesIn(expression.getRoot(), macroNames, model);
- for (String macroName : macroNames) {
- TensorType requiredType = model.requiredMacros().get(macroName);
- if (requiredType == null) continue; // Not a required macro
-
- RankProfile.Macro macro = profile.getMacros().get(macroName);
- if (macro == null)
- throw new IllegalArgumentException("Model refers Placeholder '" + macroName +
- "' of type " + requiredType + " but this macro is not present in " +
- profile);
- // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second
- // phase and summary features), as it may only resolve correctly given those bindings
- // Or, probably better, annotate the macros with type constraints here and verify during general
- // type verification
- TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles));
- if ( actualType == null)
- throw new IllegalArgumentException("Model refers input '" + macroName +
- "' of type " + requiredType +
- " which must be produced by a macro in the rank profile, but " +
- "this macro references a feature which is not declared");
- if ( ! actualType.isAssignableTo(requiredType))
- throw new IllegalArgumentException("Model refers input '" + macroName +
- "' of type " + requiredType +
- " which must be produced by a macro in the rank profile, but " +
- "this macro produces type " + actualType);
- }
- }
-
- /**
- * Add the generated macros to the rank profile
- */
- private void addGeneratedMacros(OnnxModel model, RankProfile profile) {
- model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v));
- }
-
- /**
- * Check if batch dimensions of inputs can be reduced out. If the input
- * macro specifies that a single exemplar should be evaluated, we can
- * reduce the batch dimension out.
- */
- private void reduceBatchDimensions(RankingExpression expression, OnnxModel model,
- RankProfile profile, QueryProfileRegistry queryProfiles) {
- TypeContext<Reference> typeContext = profile.typeContext(queryProfiles);
- TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
-
- // Check generated macros for inputs to reduce
- Set<String> macroNames = new HashSet<>();
- addMacroNamesIn(expression.getRoot(), macroNames, model);
- for (String macroName : macroNames) {
- if ( ! model.macros().containsKey(macroName)) {
- continue;
- }
- RankProfile.Macro macro = profile.getMacros().get(macroName);
- if (macro == null) {
- throw new IllegalArgumentException("Model refers to generated macro '" + macroName +
- "but this macro is not present in " + profile);
- }
- RankingExpression macroExpression = macro.getRankingExpression();
- macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext));
- }
-
- // Check expression for inputs to reduce
- ExpressionNode root = expression.getRoot();
- root = reduceBatchDimensionsAtInput(root, model, typeContext);
- TensorType typeAfterReducing = root.type(typeContext);
- root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing);
- expression.setRoot(root);
- }
-
- private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, OnnxModel model,
- TypeContext<Reference> typeContext) {
- if (node instanceof TensorFunctionNode) {
- TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
- if (tensorFunction instanceof Rename) {
- List<ExpressionNode> children = ((TensorFunctionNode)node).children();
- if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode) children.get(0);
- if (model.requiredMacros().containsKey(referenceNode.getName())) {
- return reduceBatchDimensionExpression(tensorFunction, typeContext);
- }
- }
- }
- }
- if (node instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode) node;
- if (model.requiredMacros().containsKey(referenceNode.getName())) {
- return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext);
- }
- }
- if (node instanceof CompositeNode) {
- List<ExpressionNode> children = ((CompositeNode)node).children();
- List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
- for (ExpressionNode child : children) {
- transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext));
- }
- return ((CompositeNode)node).setChildren(transformedChildren);
- }
- return node;
- }
-
- private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
- TensorFunction result = function;
- TensorType type = function.type(context);
- if (type.dimensions().size() > 1) {
- List<String> reduceDimensions = new ArrayList<>();
- for (TensorType.Dimension dimension : type.dimensions()) {
- if (dimension.size().orElse(-1L) == 1) {
- reduceDimensions.add(dimension.name());
- }
- }
- if (reduceDimensions.size() > 0) {
- result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions);
- }
- }
- return new TensorFunctionNode(result);
- }
-
- /**
- * If batch dimensions have been reduced away above, bring them back here
- * for any following computation of the tensor.
- * Todo: determine when this is not necessary!
- */
- private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
- if (after.equals(before)) {
- return node;
- }
- TensorType.Builder typeBuilder = new TensorType.Builder();
- for (TensorType.Dimension dimension : before.dimensions()) {
- if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) {
- typeBuilder.indexed(dimension.name(), 1);
- }
- }
- TensorType expandDimensionsType = typeBuilder.build();
- if (expandDimensionsType.dimensions().size() > 0) {
- ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0));
- Generate generatedFunction = new Generate(expandDimensionsType,
- new GeneratorLambdaFunctionNode(expandDimensionsType,
- generatedExpression)
- .asLongListToDoubleOperator());
- Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply());
- return new TensorFunctionNode(expand);
- }
- return node;
- }
-
- /**
- * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions.
- * This method does that for the given expression and returns the result.
- */
- private RankingExpression replaceConstantsByMacros(RankingExpression expression,
- Set<String> constantsReplacedByMacros) {
- if (constantsReplacedByMacros.isEmpty()) return expression;
- return new RankingExpression(expression.getName(),
- replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros));
- }
-
- private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) {
- if (node instanceof ReferenceNode) {
- Reference reference = ((ReferenceNode)node).reference();
- if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) {
- String argument = reference.simpleArgument().get();
- if (constantsReplacedByMacros.contains(argument))
- return new ReferenceNode(argument);
- }
- }
- if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above
- CompositeNode composite = (CompositeNode)node;
- return composite.setChildren(composite.children().stream()
- .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros))
- .collect(Collectors.toList()));
- }
- return node;
- }
-
- private void addMacroNamesIn(ExpressionNode node, Set<String> names, OnnxModel model) {
- if (node instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode)node;
- if (referenceNode.getOutput() == null) { // macro references cannot specify outputs
- names.add(referenceNode.getName());
- if (model.macros().containsKey(referenceNode.getName())) {
- addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model);
- }
- }
- }
- else if (node instanceof CompositeNode) {
- for (ExpressionNode child : ((CompositeNode)node).children())
- addMacroNamesIn(child, names, model);
- }
- }
-
- private Value asValue(Tensor tensor) {
- if (tensor.type().rank() == 0)
- return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors
- else
- return new TensorValue(tensor);
- }
-
- /**
- * Provides read/write access to the correct directories of the application package given by the feature arguments
- */
- private static class ModelStore {
-
- private final ApplicationPackage application;
- private final FeatureArguments arguments;
-
- public ModelStore(ApplicationPackage application, Arguments arguments) {
- this.application = application;
- this.arguments = new FeatureArguments(arguments);
- }
-
- public FeatureArguments arguments() { return arguments; }
-
- public boolean hasStoredModel() {
- try {
- return application.getFile(arguments.expressionPath()).exists();
- }
- catch (UnsupportedOperationException e) {
- return false;
- }
- }
-
- /**
- * Returns the directory which contains the source model to use for these arguments
- */
- public File onnxModelDir() {
- return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath()));
- }
-
- /**
- * Adds this expression to the application package, such that it can be read later.
- */
- public void writeConverted(RankingExpression expression) {
- application.getFile(arguments.expressionPath())
- .writeFile(new StringReader(expression.getRoot().toString()));
- }
-
- /** Reads the previously stored ranking expression for these arguments */
- public RankingExpression readConverted() {
- try {
- return new RankingExpression(application.getFile(arguments.expressionPath()).createReader());
- }
- catch (IOException e) {
- throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e);
- }
- catch (ParseException e) {
- throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
- }
- }
-
- /** Adds this macro expression to the application package to it can be read later. */
- public void writeMacro(String name, RankingExpression expression) {
- application.getFile(arguments.macrosPath()).appendFile(name + "\t" +
- expression.getRoot().toString() + "\n");
- }
-
- /** Reads the previously stored macro expressions for these arguments */
- public List<Pair<String, RankingExpression>> readMacros() {
- try {
- ApplicationFile file = application.getFile(arguments.macrosPath());
- if (!file.exists()) return Collections.emptyList();
-
- List<Pair<String, RankingExpression>> macros = new ArrayList<>();
- BufferedReader reader = new BufferedReader(file.createReader());
- String line;
- while (null != (line = reader.readLine())) {
- String[] parts = line.split("\t");
- String name = parts[0];
- try {
- RankingExpression expression = new RankingExpression(parts[1]);
- macros.add(new Pair<>(name, expression));
- }
- catch (ParseException e) {
- throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
- }
- }
- return macros;
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- /**
- * Reads the information about all the large (aka ranking) constants stored in the application package
- * (the constant value itself is replicated with file distribution).
- */
- public List<RankingConstant> readLargeConstants() {
- try {
- List<RankingConstant> constants = new ArrayList<>();
- for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) {
- String[] parts = IOUtils.readAll(constantFile.createReader()).split(":");
- constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2]));
- }
- return constants;
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- /**
- * Adds this constant to the application package as a file,
- * such that it can be distributed using file distribution.
- *
- * @return the path to the stored constant, relative to the application package root
- */
- public Path writeLargeConstant(String name, Tensor constant) {
- Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants");
-
- // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file:
- Path constantPath = constantsPath.append(name + ".tbf");
-
- // Remember the constant in a file we replicate in ZooKeeper
- application.getFile(arguments.largeConstantsPath().append(name + ".constant"))
- .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath)));
-
- // Write content explicitly as a file on the file system as this is distributed using file distribution
- createIfNeeded(constantsPath);
- IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant));
- return correct(constantPath);
- }
-
- private List<Pair<String, Tensor>> readSmallConstants() {
- try {
- ApplicationFile file = application.getFile(arguments.smallConstantsPath());
- if (!file.exists()) return Collections.emptyList();
-
- List<Pair<String, Tensor>> constants = new ArrayList<>();
- BufferedReader reader = new BufferedReader(file.createReader());
- String line;
- while (null != (line = reader.readLine())) {
- String[] parts = line.split("\t");
- String name = parts[0];
- TensorType type = TensorType.fromSpec(parts[1]);
- Tensor tensor = Tensor.from(type, parts[2]);
- constants.add(new Pair<>(name, tensor));
- }
- return constants;
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- /**
- * Append this constant to the single file used for small constants distributed as config
- */
- public void writeSmallConstant(String name, Tensor constant) {
- // Secret file format for remembering constants:
- application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" +
- constant.type().toString() + "\t" +
- constant.toString() + "\n");
- }
-
- /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */
- private Path correct(Path path) {
- if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed)
- && ! path.elements().contains(FilesApplicationPackage.preprocessed)) {
- return Path.fromString(FilesApplicationPackage.preprocessed).append(path);
- }
- else {
- return path;
- }
- }
-
- private void createIfNeeded(Path path) {
- File dir = application.getFileReference(path);
- if ( ! dir.exists()) {
- if (!dir.mkdirs())
- throw new IllegalStateException("Could not create " + dir);
- }
- }
-
- }
-
- /** Encapsulates the 1, 2 or 3 arguments to a onnx feature */
- private static class FeatureArguments {
-
- private final Path modelPath;
-
- /** Optional arguments */
- private final Optional<String> output;
-
- public FeatureArguments(Arguments arguments) {
+ static class OnnxFeatureArguments extends FeatureArguments {
+ public OnnxFeatureArguments(Arguments arguments) {
if (arguments.isEmpty())
throw new IllegalArgumentException("An onnx node must take an argument pointing to " +
- "the onnx model directory under [application]/models");
+ "the tensorflow model directory under [application]/models");
if (arguments.expressions().size() > 3)
throw new IllegalArgumentException("An onnx feature can have at most 2 arguments");
modelPath = Path.fromString(asString(arguments.expressions().get(0)));
output = optionalArgument(1, arguments);
+ signature = Optional.of("default");
}
-
- /** Returns modelPath with slashes replaced by underscores */
- public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); }
-
- /** Returns relative path to this model below the "models/" dir in the application package */
- public Path modelPath() { return modelPath; }
- public Optional<String> output() { return output; }
-
- /** Path to the small constants file */
- public Path smallConstantsPath() {
- return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt");
- }
-
- /** Path to the large (ranking) constants directory */
- public Path largeConstantsPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants");
- }
-
- /** Path to the macros file */
- public Path macrosPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt");
- }
-
- public Path expressionPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR
- .append(modelPath).append("expressions").append(expressionFileName());
- }
-
- private String expressionFileName() {
- StringBuilder fileName = new StringBuilder();
- output.ifPresent(s -> fileName.append(s).append("."));
- if (fileName.length() == 0) // single signature and output
- fileName.append("single.");
- fileName.append("expression");
- return fileName.toString();
- }
-
- private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
- if (argumentIndex >= arguments.expressions().size())
- return Optional.empty();
- return Optional.of(asString(arguments.expressions().get(argumentIndex)));
- }
-
- private String asString(ExpressionNode node) {
- if ( ! (node instanceof ConstantNode))
- throw new IllegalArgumentException("Expected a constant string as onnx argument, but got '" + node);
- return stripQuotes(((ConstantNode)node).sourceString());
- }
-
- private String stripQuotes(String s) {
- if ( ! isQuoteSign(s.codePointAt(0))) return s;
- if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
- throw new IllegalArgumentException("onnx argument [" + s + "] is missing endquote");
- return s.substring(1, s.length()-1);
- }
-
- private boolean isQuoteSign(int c) {
- return c == '\'' || c == '"';
- }
-
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index 41da32f64c3..27e1ad51b33 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -1,59 +1,19 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.expressiontransforms;
-import com.google.common.base.Joiner;
-import com.yahoo.collections.Pair;
-import com.yahoo.config.application.api.ApplicationFile;
-import com.yahoo.config.application.api.ApplicationPackage;
-import com.yahoo.config.model.application.provider.FilesApplicationPackage;
-import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
-import com.yahoo.searchdefinition.FeatureNames;
import com.yahoo.searchdefinition.RankProfile;
-import com.yahoo.searchdefinition.RankingConstant;
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.Reference;
-import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel.Signature;
-import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import com.yahoo.searchlib.rankingexpression.integration.ml.TensorFlowImporter;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
-import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
-import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
-import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.TypeContext;
-import com.yahoo.tensor.functions.Generate;
-import com.yahoo.tensor.functions.Join;
-import com.yahoo.tensor.functions.Reduce;
-import com.yahoo.tensor.functions.Rename;
-import com.yahoo.tensor.functions.ScalarFunctions;
-import com.yahoo.tensor.functions.TensorFunction;
-import com.yahoo.tensor.serialization.TypedBinaryFormat;
-import java.io.BufferedReader;
-import java.io.File;
-import java.io.IOException;
-import java.io.StringReader;
import java.io.UncheckedIOException;
-import java.util.ArrayList;
-import java.util.Collections;
import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
import java.util.Map;
-import java.util.Optional;
-import java.util.Set;
-import java.util.stream.Collectors;
/**
* Replaces instances of the tensorflow(model-path, signature, output)
@@ -62,12 +22,12 @@ import java.util.stream.Collectors;
*
* @author bratseth
*/
-public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
+public class TensorFlowFeatureConverter extends MLImportFeatureConverter {
private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter();
/** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
- private final Map<Path, TensorFlowModel> importedModels = new HashMap<>();
+ private final Map<Path, ImportedModel> importedModels = new HashMap<>();
@Override
public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
@@ -83,7 +43,8 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
if ( ! feature.getName().equals("tensorflow")) return feature;
try {
- ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments());
+ FeatureArguments arguments = new TensorFlowFeatureArguments(feature.getArguments());
+ ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments);
if ( ! store.hasStoredModel()) // not converted yet - access TensorFlow model files
return transformFromTensorFlowModel(store, context.rankProfile(), context.queryProfiles());
else
@@ -95,565 +56,19 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
}
private ExpressionNode transformFromTensorFlowModel(ModelStore store,
- RankProfile profile,
- QueryProfileRegistry queryProfiles) {
- TensorFlowModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
- k -> tensorFlowImporter.importModel(store.arguments().modelName(),
- store.tensorFlowModelDir()));
-
- // Add constants
- Set<String> constantsReplacedByMacros = new HashSet<>();
- model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
- model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles,
- constantsReplacedByMacros, k, v));
-
- // Find the specified expression
- Signature signature = chooseSignature(model, store.arguments().signature());
- String output = chooseOutput(signature, store.arguments().output());
- if (signature.skippedOutputs().containsKey(output)) {
- String message = "Could not import TensorFlow model output '" + output + "'";
- if (!signature.skippedOutputs().get(output).isEmpty()) {
- message += ": " + signature.skippedOutputs().get(output);
- }
- if (!signature.importWarnings().isEmpty()) {
- message += ": " + String.join(", ", signature.importWarnings());
- }
- throw new IllegalArgumentException(message);
- }
-
- RankingExpression expression = model.expressions().get(output);
- expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
- verifyRequiredMacros(expression, model, profile, queryProfiles);
- addGeneratedMacros(model, profile);
- reduceBatchDimensions(expression, model, profile, queryProfiles);
-
- model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v));
-
- store.writeConverted(expression);
- return expression.getRoot();
- }
-
- private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) {
- for (Pair<String, Tensor> constant : store.readSmallConstants())
- profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
-
- for (RankingConstant constant : store.readLargeConstants()) {
- if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName()))
- profile.getSearch().addRankingConstant(constant);
- }
-
- for (Pair<String, RankingExpression> macro : store.readMacros()) {
- addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond());
- }
-
- return store.readConverted().getRoot();
- }
-
- /**
- * Returns the specified, existing signature, or the only signature if none is specified.
- * Throws IllegalArgumentException in all other cases.
- */
- private Signature chooseSignature(TensorFlowModel importResult, Optional<String> signatureName) {
- if ( ! signatureName.isPresent()) {
- if (importResult.signatures().size() == 0)
- throw new IllegalArgumentException("No signatures are available");
- if (importResult.signatures().size() > 1)
- throw new IllegalArgumentException("Model has multiple signatures (" +
- Joiner.on(", ").join(importResult.signatures().keySet()) +
- "), one must be specified " +
- "as a second argument to tensorflow()");
- return importResult.signatures().values().stream().findFirst().get();
- }
- else {
- Signature signature = importResult.signatures().get(signatureName.get());
- if (signature == null)
- throw new IllegalArgumentException("Model does not have the specified signature '" +
- signatureName.get() + "'");
- return signature;
- }
- }
-
- /**
- * Returns the specified, existing output expression, or the only output expression if no output name is specified.
- * Throws IllegalArgumentException in all other cases.
- */
- private String chooseOutput(Signature signature, Optional<String> outputName) {
- if ( ! outputName.isPresent()) {
- if (signature.outputs().size() == 0)
- throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(signature));
- if (signature.outputs().size() > 1)
- throw new IllegalArgumentException(signature + " has multiple outputs (" +
- Joiner.on(", ").join(signature.outputs().keySet()) +
- "), one must be specified " +
- "as a third argument to tensorflow()");
- return signature.outputs().get(signature.outputs().keySet().stream().findFirst().get());
- }
- else {
- String output = signature.outputs().get(outputName.get());
- if (output == null) {
- if (signature.skippedOutputs().containsKey(outputName.get()))
- throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " +
- signature.skippedOutputs().get(outputName.get()));
- else
- throw new IllegalArgumentException("Model does not have the specified output '" +
- outputName.get() + "'");
- }
- return output;
- }
- }
-
- private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
- store.writeSmallConstant(constantName, constantValue);
- profile.addConstant(constantName, asValue(constantValue));
- }
-
- private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles,
- Set<String> constantsReplacedByMacros,
- String constantName, Tensor constantValue) {
- RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName);
- if (macroOverridingConstant != null) {
- TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles));
- if ( ! macroType.equals(constantValue.type()))
- throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " +
- typeMismatchExplanation(constantValue.type(), macroType));
- constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later
- }
- else {
- Path constantPath = store.writeLargeConstant(constantName, constantValue);
- if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) {
- profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(),
- constantPath.toString()));
- }
- }
- }
-
- private void transformGeneratedMacro(ModelStore store,
- Set<String> constantsReplacedByMacros,
- String macroName, RankingExpression expression) {
-
- expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
- store.writeMacro(macroName, expression);
- }
-
- private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
- if (profile.getMacros().containsKey(macroName)) {
- throw new IllegalArgumentException("Generated TensorFlow macro '" + macroName + "' already exists.");
- }
- profile.addMacro(macroName, false); // todo: inline if only used once
- RankProfile.Macro macro = profile.getMacros().get(macroName);
- macro.setRankingExpression(expression);
- macro.setTextualExpression(expression.getRoot().toString());
- }
-
- private String skippedOutputsDescription(TensorFlowModel.Signature signature) {
- if (signature.skippedOutputs().isEmpty()) return "";
- StringBuilder b = new StringBuilder(": ");
- signature.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v));
- return b.toString();
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles) {
+ ImportedModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
+ k -> tensorFlowImporter.importModel(store.arguments().modelName(),
+ store.modelDir()));
+ return transformFromImportedModel(model, store, profile, queryProfiles);
}
- /**
- * Verify that the macros referred in the given expression exists in the given rank profile,
- * and return tensors of the types specified in requiredMacros.
- */
- private void verifyRequiredMacros(RankingExpression expression, TensorFlowModel model,
- RankProfile profile, QueryProfileRegistry queryProfiles) {
- Set<String> macroNames = new HashSet<>();
- addMacroNamesIn(expression.getRoot(), macroNames, model);
- for (String macroName : macroNames) {
- TensorType requiredType = model.requiredMacros().get(macroName);
- if (requiredType == null) continue; // Not a required macro
-
- RankProfile.Macro macro = profile.getMacros().get(macroName);
- if (macro == null)
- throw new IllegalArgumentException("Model refers placeholder '" + macroName +
- "' of type " + requiredType + " but this macro is not present in " +
- profile);
- // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second
- // phase and summary features), as it may only resolve correctly given those bindings
- // Or, probably better, annotate the macros with type constraints here and verify during general
- // type verification
- TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles));
- if ( actualType == null)
- throw new IllegalArgumentException("Model refers placeholder '" + macroName +
- "' of type " + requiredType +
- " which must be produced by a macro in the rank profile, but " +
- "this macro references a feature which is not declared");
- if ( ! actualType.isAssignableTo(requiredType))
- throw new IllegalArgumentException("Model refers placeholder '" + macroName + "'. " +
- typeMismatchExplanation(requiredType, actualType));
- }
- }
-
- private String typeMismatchExplanation(TensorType requiredType, TensorType actualType) {
- return "The required type of this is " + requiredType + ", but this macro returns " + actualType +
- (actualType.rank() == 0 ? ". This is often due to missing declaration of query tensor features " +
- "in query profile types - see the documentation."
- : "");
- }
-
- /**
- * Add the generated macros to the rank profile
- */
- private void addGeneratedMacros(TensorFlowModel model, RankProfile profile) {
- model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v));
- }
-
- /**
- * Check if batch dimensions of inputs can be reduced out. If the input
- * macro specifies that a single exemplar should be evaluated, we can
- * reduce the batch dimension out.
- */
- private void reduceBatchDimensions(RankingExpression expression, TensorFlowModel model,
- RankProfile profile, QueryProfileRegistry queryProfiles) {
- TypeContext<Reference> typeContext = profile.typeContext(queryProfiles);
- TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
-
- // Check generated macros for inputs to reduce
- Set<String> macroNames = new HashSet<>();
- addMacroNamesIn(expression.getRoot(), macroNames, model);
- for (String macroName : macroNames) {
- if ( ! model.macros().containsKey(macroName)) {
- continue;
- }
- RankProfile.Macro macro = profile.getMacros().get(macroName);
- if (macro == null) {
- throw new IllegalArgumentException("Model refers to generated macro '" + macroName +
- "but this macro is not present in " + profile);
- }
- RankingExpression macroExpression = macro.getRankingExpression();
- macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext));
- }
-
- // Check expression for inputs to reduce
- ExpressionNode root = expression.getRoot();
- root = reduceBatchDimensionsAtInput(root, model, typeContext);
- TensorType typeAfterReducing = root.type(typeContext);
- root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing);
- expression.setRoot(root);
- }
-
- private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, TensorFlowModel model,
- TypeContext<Reference> typeContext) {
- if (node instanceof TensorFunctionNode) {
- TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
- if (tensorFunction instanceof Rename) {
- List<ExpressionNode> children = ((TensorFunctionNode)node).children();
- if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode) children.get(0);
- if (model.requiredMacros().containsKey(referenceNode.getName())) {
- return reduceBatchDimensionExpression(tensorFunction, typeContext);
- }
- }
- }
- }
- if (node instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode) node;
- if (model.requiredMacros().containsKey(referenceNode.getName())) {
- return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext);
- }
- }
- if (node instanceof CompositeNode) {
- List<ExpressionNode> children = ((CompositeNode)node).children();
- List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
- for (ExpressionNode child : children) {
- transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext));
- }
- return ((CompositeNode)node).setChildren(transformedChildren);
- }
- return node;
- }
-
- private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
- TensorFunction result = function;
- TensorType type = function.type(context);
- if (type.dimensions().size() > 1) {
- List<String> reduceDimensions = new ArrayList<>();
- for (TensorType.Dimension dimension : type.dimensions()) {
- if (dimension.size().orElse(-1L) == 1) {
- reduceDimensions.add(dimension.name());
- }
- }
- if (reduceDimensions.size() > 0) {
- result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions);
- }
- }
- return new TensorFunctionNode(result);
- }
-
- /**
- * If batch dimensions have been reduced away above, bring them back here
- * for any following computation of the tensor.
- * Todo: determine when this is not necessary!
- */
- private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
- if (after.equals(before)) {
- return node;
- }
- TensorType.Builder typeBuilder = new TensorType.Builder();
- for (TensorType.Dimension dimension : before.dimensions()) {
- if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) {
- typeBuilder.indexed(dimension.name(), 1);
- }
- }
- TensorType expandDimensionsType = typeBuilder.build();
- if (expandDimensionsType.dimensions().size() > 0) {
- ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0));
- Generate generatedFunction = new Generate(expandDimensionsType,
- new GeneratorLambdaFunctionNode(expandDimensionsType,
- generatedExpression)
- .asLongListToDoubleOperator());
- Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply());
- return new TensorFunctionNode(expand);
- }
- return node;
- }
-
- /**
- * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions.
- * This method does that for the given expression and returns the result.
- */
- private RankingExpression replaceConstantsByMacros(RankingExpression expression,
- Set<String> constantsReplacedByMacros) {
- if (constantsReplacedByMacros.isEmpty()) return expression;
- return new RankingExpression(expression.getName(),
- replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros));
- }
-
- private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) {
- if (node instanceof ReferenceNode) {
- Reference reference = ((ReferenceNode)node).reference();
- if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) {
- String argument = reference.simpleArgument().get();
- if (constantsReplacedByMacros.contains(argument))
- return new ReferenceNode(argument);
- }
- }
- if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above
- CompositeNode composite = (CompositeNode)node;
- return composite.setChildren(composite.children().stream()
- .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros))
- .collect(Collectors.toList()));
- }
- return node;
- }
-
- private void addMacroNamesIn(ExpressionNode node, Set<String> names, TensorFlowModel model) {
- if (node instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode)node;
- if (referenceNode.getOutput() == null) { // macro references cannot specify outputs
- names.add(referenceNode.getName());
- if (model.macros().containsKey(referenceNode.getName())) {
- addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model);
- }
- }
- }
- else if (node instanceof CompositeNode) {
- for (ExpressionNode child : ((CompositeNode)node).children())
- addMacroNamesIn(child, names, model);
- }
- }
-
- private Value asValue(Tensor tensor) {
- if (tensor.type().rank() == 0)
- return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors
- else
- return new TensorValue(tensor);
- }
-
- /**
- * Provides read/write access to the correct directories of the application package given by the feature arguments
- */
- private static class ModelStore {
-
- private final ApplicationPackage application;
- private final FeatureArguments arguments;
-
- public ModelStore(ApplicationPackage application, Arguments arguments) {
- this.application = application;
- this.arguments = new FeatureArguments(arguments);
- }
-
-
-
- public FeatureArguments arguments() { return arguments; }
-
- public boolean hasStoredModel() {
- try {
- return application.getFile(arguments.expressionPath()).exists();
- }
- catch (UnsupportedOperationException e) {
- return false;
- }
- }
-
- /**
- * Returns the directory which (if hasTensorFlowModels is true)
- * contains the source model to use for these arguments
- */
- public File tensorFlowModelDir() {
- return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath()));
- }
-
- /**
- * Adds this expression to the application package, such that it can be read later.
- */
- public void writeConverted(RankingExpression expression) {
- application.getFile(arguments.expressionPath())
- .writeFile(new StringReader(expression.getRoot().toString()));
- }
-
- /** Reads the previously stored ranking expression for these arguments */
- public RankingExpression readConverted() {
- try {
- return new RankingExpression(application.getFile(arguments.expressionPath()).createReader());
- }
- catch (IOException e) {
- throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e);
- }
- catch (ParseException e) {
- throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
- }
- }
-
- /** Adds this macro expression to the application package to it can be read later. */
- public void writeMacro(String name, RankingExpression expression) {
- application.getFile(arguments.macrosPath()).appendFile(name + "\t" +
- expression.getRoot().toString() + "\n");
- }
-
- /** Reads the previously stored macro expressions for these arguments */
- public List<Pair<String, RankingExpression>> readMacros() {
- try {
- ApplicationFile file = application.getFile(arguments.macrosPath());
- if (!file.exists()) return Collections.emptyList();
-
- List<Pair<String, RankingExpression>> macros = new ArrayList<>();
- BufferedReader reader = new BufferedReader(file.createReader());
- String line;
- while (null != (line = reader.readLine())) {
- String[] parts = line.split("\t");
- String name = parts[0];
- try {
- RankingExpression expression = new RankingExpression(parts[1]);
- macros.add(new Pair<>(name, expression));
- }
- catch (ParseException e) {
- throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
- }
- }
- return macros;
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- /**
- * Reads the information about all the large (aka ranking) constants stored in the application package
- * (the constant value itself is replicated with file distribution).
- */
- public List<RankingConstant> readLargeConstants() {
- try {
- List<RankingConstant> constants = new ArrayList<>();
- for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) {
- String[] parts = IOUtils.readAll(constantFile.createReader()).split(":");
- constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2]));
- }
- return constants;
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- /**
- * Adds this constant to the application package as a file,
- * such that it can be distributed using file distribution.
- *
- * @return the path to the stored constant, relative to the application package root
- */
- public Path writeLargeConstant(String name, Tensor constant) {
- Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants");
-
- // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file:
- Path constantPath = constantsPath.append(name + ".tbf");
-
- // Remember the constant in a file we replicate in ZooKeeper
- application.getFile(arguments.largeConstantsPath().append(name + ".constant"))
- .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath)));
-
- // Write content explicitly as a file on the file system as this is distributed using file distribution
- createIfNeeded(constantsPath);
- IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant));
- return correct(constantPath);
- }
-
- private List<Pair<String, Tensor>> readSmallConstants() {
- try {
- ApplicationFile file = application.getFile(arguments.smallConstantsPath());
- if (!file.exists()) return Collections.emptyList();
-
- List<Pair<String, Tensor>> constants = new ArrayList<>();
- BufferedReader reader = new BufferedReader(file.createReader());
- String line;
- while (null != (line = reader.readLine())) {
- String[] parts = line.split("\t");
- String name = parts[0];
- TensorType type = TensorType.fromSpec(parts[1]);
- Tensor tensor = Tensor.from(type, parts[2]);
- constants.add(new Pair<>(name, tensor));
- }
- return constants;
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- /**
- * Append this constant to the single file used for small constants distributed as config
- */
- public void writeSmallConstant(String name, Tensor constant) {
- // Secret file format for remembering constants:
- application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" +
- constant.type().toString() + "\t" +
- constant.toString() + "\n");
- }
-
- /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */
- private Path correct(Path path) {
- if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed)
- && ! path.elements().contains(FilesApplicationPackage.preprocessed)) {
- return Path.fromString(FilesApplicationPackage.preprocessed).append(path);
- }
- else {
- return path;
- }
- }
-
- private void createIfNeeded(Path path) {
- File dir = application.getFileReference(path);
- if ( ! dir.exists()) {
- if (!dir.mkdirs())
- throw new IllegalStateException("Could not create " + dir);
- }
- }
-
- }
-
- /** Encapsulates the 1, 2 or 3 arguments to a tensorflow feature */
- private static class FeatureArguments {
-
- private final Path modelPath;
-
- /** Optional arguments */
- private final Optional<String> signature, output;
-
- public FeatureArguments(Arguments arguments) {
+ static class TensorFlowFeatureArguments extends FeatureArguments {
+ public TensorFlowFeatureArguments(Arguments arguments) {
if (arguments.isEmpty())
throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " +
- "the tensorflow model directory under [application]/models");
+ "the tensorflow model directory under [application]/models");
if (arguments.expressions().size() > 3)
throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments");
@@ -661,68 +76,6 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
signature = optionalArgument(1, arguments);
output = optionalArgument(2, arguments);
}
-
- /** Returns modelPath with slashes replaced by underscores */
- public String modelName() { return modelPath.toString().replace('/', '_'); }
-
- /** Returns relative path to this model below the "models/" dir in the application package */
- public Path modelPath() { return modelPath; }
- public Optional<String> signature() { return signature; }
- public Optional<String> output() { return output; }
-
- /** Path to the small constants file */
- public Path smallConstantsPath() {
- return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt");
- }
-
- /** Path to the large (ranking) constants directory */
- public Path largeConstantsPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants");
- }
-
- /** Path to the macros file */
- public Path macrosPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt");
- }
-
- public Path expressionPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR
- .append(modelPath).append("expressions").append(expressionFileName());
- }
-
- private String expressionFileName() {
- StringBuilder fileName = new StringBuilder();
- signature.ifPresent(s -> fileName.append(s).append("."));
- output.ifPresent(s -> fileName.append(s).append("."));
- if (fileName.length() == 0) // single signature and output
- fileName.append("single.");
- fileName.append("expression");
- return fileName.toString();
- }
-
- private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
- if (argumentIndex >= arguments.expressions().size())
- return Optional.empty();
- return Optional.of(asString(arguments.expressions().get(argumentIndex)));
- }
-
- private String asString(ExpressionNode node) {
- if ( ! (node instanceof ConstantNode))
- throw new IllegalArgumentException("Expected a constant string as tensorflow argument, but got '" + node);
- return stripQuotes(((ConstantNode)node).sourceString());
- }
-
- private String stripQuotes(String s) {
- if ( ! isQuoteSign(s.codePointAt(0))) return s;
- if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
- throw new IllegalArgumentException("tensorflow argument [" + s + "] is missing endquote");
- return s.substring(1, s.length()-1);
- }
-
- private boolean isQuoteSign(int c) {
- return c == '\'' || c == '"';
- }
-
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
index 1c54d12d8b3..d9beab6e2f2 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -37,15 +37,6 @@ public class RankingExpressionWithOnnxTestCase {
}
@Test
- public void testOnnxReference() throws ParseException {
- RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
- "onnx('mnist_softmax.onnx')");
- search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L));
- assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L));
- }
-
- @Test
public void testOnnxReferenceWithConstantFeature() {
RankProfileSearchFixture search = fixtureWith("constant(mytensor)",
"onnx('mnist_softmax.onnx')",
@@ -122,13 +113,6 @@ public class RankingExpressionWithOnnxTestCase {
}
@Test
- public void testOnnxReferenceSpecifyingOutput() {
- RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
- "onnx('mnist_softmax.onnx', 'add')");
- search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- }
-
- @Test
public void testOnnxReferenceMissingMacro() throws ParseException {
try {
RankProfileSearchFixture search = new RankProfileSearchFixture(
@@ -145,7 +129,7 @@ public class RankingExpressionWithOnnxTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
"onnx('mnist_softmax.onnx'): " +
- "Model refers Placeholder 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " +
+ "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " +
"not present in rank profile 'my_profile'",
Exceptions.toMessageString(expected));
}
@@ -163,8 +147,8 @@ public class RankingExpressionWithOnnxTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
"onnx('mnist_softmax.onnx'): " +
- "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) which must be produced " +
- "by a macro in the rank profile, but this macro produces type tensor(d0[2],d5[10])",
+ "Model refers input 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " +
+ "but this macro returns tensor(d0[2],d5[10])",
Exceptions.toMessageString(expected));
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index d288a396732..7228af2b0de 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -162,7 +162,7 @@ public class RankingExpressionWithTensorFlowTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " +
"tensorflow('mnist_softmax/saved'): " +
- "Model refers placeholder 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " +
+ "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " +
"not present in rank profile 'my_profile'",
Exceptions.toMessageString(expected));
}
@@ -179,7 +179,7 @@ public class RankingExpressionWithTensorFlowTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " +
"tensorflow('mnist_softmax/saved'): " +
- "Model refers placeholder 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " +
+ "Model refers input 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " +
"but this macro returns tensor(d0[2],d5[10])",
Exceptions.toMessageString(expected));
}
@@ -305,9 +305,9 @@ public class RankingExpressionWithTensorFlowTestCase {
@Test
public void testMacroGeneration() {
- final String expression = "join(join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
+ final String expression = "join(join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
final String macroExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))";
- final String macroExpression2 = "join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))";
+ final String macroExpression2 = "join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))";
RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",
"tensorflow('mnist/saved')",
@@ -316,15 +316,15 @@ public class RankingExpressionWithTensorFlowTestCase {
"input",
new StoringApplicationPackage(applicationDir));
search.assertFirstPhaseExpression(expression, "my_profile");
- search.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile");
- search.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile");
+ search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile");
+ search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile");
}
@Test
public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException {
- final String expression = "join(join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
+ final String expression = "join(join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
final String macroExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))";
- final String macroExpression2 = "join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))";
+ final String macroExpression2 = "join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",
@@ -335,8 +335,8 @@ public class RankingExpressionWithTensorFlowTestCase {
application);
search.assertFirstPhaseExpression(expression, "my_profile");
assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search);
- search.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile");
- search.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile");
+ search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile");
+ search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile");
// At this point the expression is stored - copy application to another location which do not have a models dir
Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
@@ -353,8 +353,8 @@ public class RankingExpressionWithTensorFlowTestCase {
storedApplication);
searchFromStored.assertFirstPhaseExpression(expression, "my_profile");
assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search);
- searchFromStored.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile");
- searchFromStored.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile");
+ searchFromStored.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile");
+ searchFromStored.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile");
}
finally {
IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
@@ -465,7 +465,7 @@ public class RankingExpressionWithTensorFlowTestCase {
}
- public static class StoringApplicationPackageFile extends ApplicationFile {
+ static class StoringApplicationPackageFile extends ApplicationFile {
/** The path to the application package root */
private final Path root;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
index 721214f9e94..4b49f17f74e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
@@ -1,5 +1,4 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
@@ -13,76 +12,61 @@ import java.util.Map;
import java.util.regex.Pattern;
/**
- * The result of importing a TensorFlow model into Vespa.
- * - A set of signatures which are named collections of inputs and outputs.
- * - A set of named constant tensors represented by Variable nodes in TensorFlow.
- * - A list of warning messages.
+ * The result of importing a model (TensorFlow or ONNX) into Vespa.
*
* @author bratseth
*/
-// This object can be built incrementally within this package, but is immutable when observed from outside the package
-public class TensorFlowModel {
+public class ImportedModel {
- private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*");
+ private static final String defaultSignatureName = "default";
+ private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*");
private final String name;
+ private final Map<String, Signature> signatures = new HashMap<>();
+ private final Map<String, TensorType> arguments = new HashMap<>();
+ private final Map<String, Tensor> smallConstants = new HashMap<>();
+ private final Map<String, Tensor> largeConstants = new HashMap<>();
+ private final Map<String, RankingExpression> expressions = new HashMap<>();
+ private final Map<String, RankingExpression> macros = new HashMap<>();
+ private final Map<String, TensorType> requiredMacros = new HashMap<>();
+
/**
- * Creates a TensorFlow model
+ * Creates a new imported model.
*
* @param name the name of this mode, containing only characters in [A-Za-z0-9_]
*/
- public TensorFlowModel(String name) {
+ public ImportedModel(String name) {
if ( ! nameRegexp.matcher(name).matches())
- throw new IllegalArgumentException("A TensorFlow model name can only contain [A-Za-z0-9_], but is '" +
- name + "'");
+ throw new IllegalArgumentException("An imported model name can only contain [A-Za-z0-9_], but is '" +
+ name + "'");
this.name = name;
}
/** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */
public String name() { return name; }
- private final Map<String, Signature> signatures = new HashMap<>();
- private final Map<String, TensorType> arguments = new HashMap<>();
- private final Map<String, Tensor> smallConstants = new HashMap<>();
- private final Map<String, Tensor> largeConstants = new HashMap<>();
- private final Map<String, RankingExpression> expressions = new HashMap<>();
- private final Map<String, RankingExpression> macros = new HashMap<>();
- private final Map<String, TensorType> requiredMacros = new HashMap<>();
-
- void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
- void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
- void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
- void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
- void macro(String name, RankingExpression expression) { macros.put(name, expression); }
- void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); }
-
- /** Returns the given signature. If it does not already exist it is added to this. */
- Signature signature(String name) {
- return signatures.computeIfAbsent(name, Signature::new);
- }
-
/** Returns an immutable map of the arguments ("Placeholders") of this */
public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
/**
* Returns an immutable map of the small constants of this.
* These should have sizes up to a few kb at most, and correspond to constant
- * values given in the TensorFlow source.
+ * values given in the TensorFlow or ONNX source.
*/
public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); }
/**
* Returns an immutable map of the large constants of this.
- * These can have sizes in gigabytes and must be distributed to nodes separately from configuration,
- * and correspond to Variable files stored separately in TensorFlow.
+ * These can have sizes in gigabytes and must be distributed to nodes separately from configuration.
+ * For TensorFlow this corresponds to Variable files stored separately.
*/
public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); }
/**
- * Returns an immutable map of the expressions of this - corresponding to TensorFlow nodes
- * which are not Placeholders or Variables (which instead become respectively arguments and constants).
- * Note that only nodes recursively referenced by a placeholder are added.
+ * Returns an immutable map of the expressions of this - corresponding to graph nodes
+ * which are not Inputs/Placeholders or Variables (which instead become respectively arguments and constants).
+ * Note that only nodes recursively referenced by a placeholder/input are added.
*/
public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
@@ -95,9 +79,26 @@ public class TensorFlowModel {
/** Returns an immutable map of the signatures of this */
public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); }
+ /** Returns the given signature. If it does not already exist it is added to this. */
+ Signature signature(String name) {
+ return signatures.computeIfAbsent(name, Signature::new);
+ }
+
+ /** Convenience method for returning a default signature */
+ Signature defaultSignature() { return signature(defaultSignatureName); }
+
+ void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
+ void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
+ void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
+ void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
+ void macro(String name, RankingExpression expression) { macros.put(name, expression); }
+ void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); }
+
/**
- * A signature is a set of named inputs and outputs, where the inputs maps to argument ("placeholder") names+types,
- * and outputs maps to expressions nodes.
+ * A signature is a set of named inputs and outputs, where the inputs maps to argument
+ * ("placeholder") names+types, and outputs maps to expressions nodes.
+ * Note that TensorFlow supports multiple signatures in their format, but ONNX has no explicit
+ * concept of signatures. For now, we handle ONNX models as having a single signature.
*/
public class Signature {
@@ -107,19 +108,14 @@ public class TensorFlowModel {
private final Map<String, String> skippedOutputs = new HashMap<>();
private final List<String> importWarnings = new ArrayList<>();
- Signature(String name) {
+ public Signature(String name) {
this.name = name;
}
- void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); }
- void output(String name, String expressionName) { outputs.put(name, expressionName); }
- void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); }
- void importWarning(String warning) { importWarnings.add(warning); }
-
public String name() { return name; }
/** Returns the result this is part of */
- TensorFlowModel owner() { return TensorFlowModel.this; }
+ public ImportedModel owner() { return ImportedModel.this; }
/**
* Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name
@@ -127,7 +123,7 @@ public class TensorFlowModel {
*/
public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); }
- /** Returns owner().arguments().get(inputs.get(name)), e.g the type of the argument this input references */
+ /** Returns the type of the argument this input references */
public TensorType inputArgument(String inputName) { return owner().arguments().get(inputs.get(inputName)); }
/** Returns an immutable list of the expression names of this */
@@ -144,12 +140,17 @@ public class TensorFlowModel {
*/
public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); }
- /** Returns owner().expressions().get(outputs.get(outputName)), e.g the expression this output references */
+ /** Returns the expression this output references */
public RankingExpression outputExpression(String outputName) { return owner().expressions().get(outputs.get(outputName)); }
@Override
public String toString() { return "signature '" + name + "'"; }
+ void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); }
+ void output(String name, String expressionName) { outputs.put(name, expressionName); }
+ void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); }
+ void importWarning(String warning) { importWarnings.add(warning); }
+
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
new file mode 100644
index 00000000000..a658833b426
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
@@ -0,0 +1,242 @@
+package com.yahoo.searchlib.rankingexpression.integration.ml;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.yolean.Exceptions;
+
+import java.io.File;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.logging.Logger;
+
+/**
+ * Base class for importing ML models (ONNX/TensorFlow) as native Vespa
+ * ranking expressions. The general mechanism for import is for the
+ * specific ML platform import implementations to create an
+ * IntermediateGraph. This class offers common code to convert the
+ * IntermediateGraph to Vespa ranking expressions and macros.
+ *
+ * @author lesters
+ */
+public abstract class ModelImporter {
+
+ private static final Logger log = Logger.getLogger(ModelImporter.class.getName());
+
+ /**
+ * The main import function.
+ */
+ public abstract ImportedModel importModel(String modelName, String modelPath);
+
+ public ImportedModel importModel(String modelName, File modelDir) {
+ return importModel(modelName, modelDir.toString());
+ }
+
+ /**
+ * Takes an IntermediateGraph and converts it to a ImportedModel containing
+ * the actual Vespa ranking expressions.
+ */
+ static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph) {
+ ImportedModel model = new ImportedModel(graph.name());
+
+ graph.optimize();
+
+ importSignatures(graph, model);
+ importExpressions(graph, model);
+ reportWarnings(graph, model);
+ logVariableTypes(graph);
+
+ return model;
+ }
+
+ private static void importSignatures(IntermediateGraph graph, ImportedModel model) {
+ for (String signatureName : graph.signatures()) {
+ ImportedModel.Signature signature = model.signature(signatureName);
+ for (Map.Entry<String, String> input : graph.inputs(signatureName).entrySet()) {
+ signature.input(input.getKey(), input.getValue());
+ }
+ for (Map.Entry<String, String> output : graph.outputs(signatureName).entrySet()) {
+ signature.output(output.getKey(), output.getValue());
+ }
+ }
+ }
+
+ private static boolean isSignatureInput(ImportedModel model, IntermediateOperation operation) {
+ for (ImportedModel.Signature signature : model.signatures().values()) {
+ for (String inputName : signature.inputs().values()) {
+ if (inputName.equals(operation.name())) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
+ private static boolean isSignatureOutput(ImportedModel model, IntermediateOperation operation) {
+ for (ImportedModel.Signature signature : model.signatures().values()) {
+ for (String outputName : signature.outputs().values()) {
+ if (outputName.equals(operation.name())) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
+ /**
+ * Convert intermediate representation to Vespa ranking expressions.
+ */
+ static void importExpressions(IntermediateGraph graph, ImportedModel model) {
+ for (ImportedModel.Signature signature : model.signatures().values()) {
+ for (String outputName : signature.outputs().values()) {
+ try {
+ Optional<TensorFunction> function = importExpression(graph.get(outputName), model);
+ if (!function.isPresent()) {
+ signature.skippedOutput(outputName, "No valid output function could be found.");
+ }
+ }
+ catch (IllegalArgumentException e) {
+ signature.skippedOutput(outputName, Exceptions.toMessageString(e));
+ }
+ }
+ }
+ }
+
+ private static Optional<TensorFunction> importExpression(IntermediateOperation operation, ImportedModel model) {
+ if (!operation.type().isPresent()) {
+ return Optional.empty();
+ }
+ if (operation.isConstant()) {
+ return importConstant(operation, model);
+ }
+ importExpressionInputs(operation, model);
+ importRankingExpression(operation, model);
+ importArgumentExpression(operation, model);
+ importMacroExpression(operation, model);
+
+ return operation.function();
+ }
+
+ private static void importExpressionInputs(IntermediateOperation operation, ImportedModel model) {
+ operation.inputs().forEach(input -> importExpression(input, model));
+ }
+
+ private static Optional<TensorFunction> importConstant(IntermediateOperation operation, ImportedModel model) {
+ String name = operation.vespaName();
+ if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) {
+ return operation.function();
+ }
+
+ Value value = operation.getConstantValue().orElseThrow(() ->
+ new IllegalArgumentException("Operation '" + operation.vespaName() + "' " +
+ "is constant but does not have a value."));
+ if ( ! (value instanceof TensorValue)) {
+ return operation.function(); // scalar values are inserted directly into the expression
+ }
+
+ Tensor tensor = value.asTensor();
+ if (tensor.type().rank() == 0) {
+ model.smallConstant(name, tensor);
+ } else {
+ model.largeConstant(name, tensor);
+ }
+ return operation.function();
+ }
+
+ private static void importRankingExpression(IntermediateOperation operation, ImportedModel model) {
+ if (operation.function().isPresent()) {
+ String name = operation.name();
+ if (!model.expressions().containsKey(name)) {
+ TensorFunction function = operation.function().get();
+
+ if (isSignatureOutput(model, operation)) {
+ OrderedTensorType operationType = operation.type().get();
+ OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType);
+ if ( ! operationType.equals(standardNamingType)) {
+ List<String> renameFrom = operationType.dimensionNames();
+ List<String> renameTo = standardNamingType.dimensionNames();
+ function = new Rename(function, renameFrom, renameTo);
+ }
+ }
+
+ try {
+ // We add all intermediate nodes imported as separate expressions. Only
+ // those referenced from the output will be used. We parse the
+ // TensorFunction here to convert it to a RankingExpression tree.
+ model.expression(name, new RankingExpression(name, function.toString()));
+ }
+ catch (ParseException e) {
+ throw new RuntimeException("Imported function " + function +
+ " cannot be parsed as a ranking expression", e);
+ }
+ }
+ }
+ }
+
+ private static void importArgumentExpression(IntermediateOperation operation, ImportedModel model) {
+ if (operation.isInput()) {
+ // All inputs must have dimensions with standard naming convention: d0, d1, ...
+ OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get());
+ model.argument(operation.vespaName(), standardNamingConvention.type());
+ model.requiredMacro(operation.vespaName(), standardNamingConvention.type());
+ }
+ }
+
+ private static void importMacroExpression(IntermediateOperation operation, ImportedModel model) {
+ if (operation.macro().isPresent()) {
+ TensorFunction function = operation.macro().get();
+ try {
+ model.macro(operation.macroName(), new RankingExpression(operation.macroName(), function.toString()));
+ }
+ catch (ParseException e) {
+ throw new RuntimeException("Tensorflow function " + function +
+ " cannot be parsed as a ranking expression", e);
+ }
+ }
+ }
+
+ /**
+ * Add any import warnings to the signature in the ImportedModel.
+ */
+ private static void reportWarnings(IntermediateGraph graph, ImportedModel model) {
+ for (ImportedModel.Signature signature : model.signatures().values()) {
+ for (String outputName : signature.outputs().values()) {
+ reportWarnings(graph.get(outputName), model);
+ }
+ }
+ }
+
+ private static void reportWarnings(IntermediateOperation operation, ImportedModel model) {
+ for (String warning : operation.warnings()) {
+ model.defaultSignature().importWarning(warning);
+ }
+ for (IntermediateOperation input : operation.inputs()) {
+ reportWarnings(input, model);
+ }
+ }
+
+ /**
+ * Log all TensorFlow Variables (i.e file constants) imported as part of this with their ordered type.
+ * This allows users to learn the exact types (including dimension order after renaming) of the Variables
+ * such that these can be converted and fed to a parent document independently of the rest of the model
+ * for fast model weight updates.
+ */
+ private static void logVariableTypes(IntermediateGraph graph) {
+ for (IntermediateOperation operation : graph.operations()) {
+ if ( ! (operation instanceof Constant)) continue;
+ if ( ! operation.type().isPresent()) continue; // will not happen
+ log.info("Importing TensorFlow variable " + operation.name() + " as " + operation.vespaName() +
+ " of type " + operation.type().get());
+ }
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
new file mode 100644
index 00000000000..d3dd2a1d418
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
@@ -0,0 +1,30 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchlib.rankingexpression.integration.ml;
+
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx.GraphImporter;
+import onnx.Onnx;
+
+import java.io.FileInputStream;
+import java.io.IOException;
+
+/**
+ * Converts a ONNX model into a ranking expression and set of constants.
+ *
+ * @author lesters
+ */
+public class OnnxImporter extends ModelImporter {
+
+ @Override
+ public ImportedModel importModel(String modelName, String modelPath) {
+ try (FileInputStream inputStream = new FileInputStream(modelPath)) {
+ Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
+ IntermediateGraph graph = GraphImporter.importGraph(modelName, model);
+ return convertIntermediateGraphToModel(graph);
+ } catch (IOException e) {
+ throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e);
+ }
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java
new file mode 100644
index 00000000000..ff584559a83
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java
@@ -0,0 +1,47 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.ml;
+
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.GraphImporter;
+import org.tensorflow.SavedModelBundle;
+
+import java.io.IOException;
+
+/**
+ * Converts a saved TensorFlow model into a ranking expression and set of constants.
+ *
+ * @author bratseth
+ * @author lesters
+ */
+public class TensorFlowImporter extends ModelImporter {
+
+ /**
+ * Imports a saved TensorFlow model from a directory.
+ * The model should be saved as a .pbtxt or .pb file.
+ * The name of the model is taken as the db/pbtxt file name (not including the file ending).
+ *
+ * @param modelName the name of the model to import, consisting of characters in [A-Za-z0-9_]
+ * @param modelDir the directory containing the TensorFlow model files to import
+ */
+ public ImportedModel importModel(String modelName, String modelDir) {
+ try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
+ return importModel(modelName, model);
+ }
+ catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
+ }
+ }
+
+ /** Imports a TensorFlow model */
+ ImportedModel importModel(String modelName, SavedModelBundle model) {
+ try {
+ IntermediateGraph graph = GraphImporter.importGraph(modelName, model);
+ return convertIntermediateGraphToModel(graph);
+ }
+ catch (IOException e) {
+ throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e);
+ }
+ }
+
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java
index c5ac7ace0fc..e1294ec3e01 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java
@@ -1,7 +1,8 @@
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.GraphImporter;
import com.yahoo.tensor.serialization.JsonFormat;
import com.yahoo.yolean.Exceptions;
import org.tensorflow.SavedModelBundle;
@@ -24,7 +25,7 @@ public class VariableConverter {
*/
public static byte[] importVariable(String modelDir, String tensorFlowVariableName, String orderedTypeSpec) {
try (SavedModelBundle bundle = SavedModelBundle.load(modelDir, "serve")) {
- return JsonFormat.encode(TensorConverter.toVespaTensor(TensorFlowImporter.readVariable(tensorFlowVariableName,
+ return JsonFormat.encode(TensorConverter.toVespaTensor(GraphImporter.readVariable(tensorFlowVariableName,
bundle),
OrderedTensorType.fromSpec(orderedTypeSpec)));
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/DimensionRenamer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamer.java
index 2524417cee0..38f1d2329e2 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/DimensionRenamer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamer.java
@@ -1,7 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
import java.util.ArrayDeque;
import java.util.ArrayList;
@@ -47,7 +47,7 @@ public class DimensionRenamer {
/**
* Add a constraint between dimension names.
*/
- public void addConstraint(String from, String to, Constraint pred, OnnxOperation operation) {
+ public void addConstraint(String from, String to, Constraint pred, IntermediateOperation operation) {
Arc arc = new Arc(from, to, operation);
Arc opposite = arc.opposite();
constraints.put(arc, pred);
@@ -175,9 +175,9 @@ public class DimensionRenamer {
private final String from;
private final String to;
- private final OnnxOperation operation;
+ private final IntermediateOperation operation;
- Arc(String from, String to, OnnxOperation operation) {
+ Arc(String from, String to, IntermediateOperation operation) {
this.from = from;
this.to = to;
this.operation = operation;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java
new file mode 100644
index 00000000000..39a8b211d09
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java
@@ -0,0 +1,107 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
+
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
+
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Holds an intermediate representation of an imported ONNX or TensorFlow
+ * graph. After this intermediate representation is constructed, it is used to
+ * simplify and optimize the computational graph and then converted into the
+ * final ImportedModel that holds the Vespa ranking expressions for the model.
+ *
+ * @author lesters
+ */
+public class IntermediateGraph {
+
+ private final String modelName;
+ private final Map<String, IntermediateOperation> index = new HashMap<>();
+ private final Map<String, GraphSignature> signatures = new HashMap<>();
+
+ private static class GraphSignature {
+ final Map<String, String> inputs = new HashMap<>();
+ final Map<String, String> outputs = new HashMap<>();
+ }
+
+ public IntermediateGraph(String modelName) {
+ this.modelName = modelName;
+ }
+
+ public String name() {
+ return modelName;
+ }
+
+ public IntermediateOperation put(String key, IntermediateOperation operation) {
+ return index.put(key, operation);
+ }
+
+ public IntermediateOperation get(String key) {
+ return index.get(key);
+ }
+
+ public Set<String> signatures() {
+ return signatures.keySet();
+ }
+
+ public Map<String, String> inputs(String signature) {
+ return signatures.computeIfAbsent(signature, (k) -> new GraphSignature()).inputs;
+ }
+
+ public Map<String, String> outputs(String signature) {
+ return signatures.computeIfAbsent(signature, (k) -> new GraphSignature()).outputs;
+ }
+
+ public String defaultSignature() {
+ return "default";
+ }
+
+ public boolean alreadyImported(String key) {
+ return index.containsKey(key);
+ }
+
+ public Collection<IntermediateOperation> operations() {
+ return index.values();
+ }
+
+ public void optimize() {
+ renameDimensions();
+ }
+
+ /**
+ * Find dimension names to avoid excessive renaming while evaluating the model.
+ */
+ private void renameDimensions() {
+ DimensionRenamer renamer = new DimensionRenamer();
+ for (String signature : signatures()) {
+ for (String output : outputs(signature).values()) {
+ addDimensionNameConstraints(index.get(output), renamer);
+ }
+ }
+ renamer.solve();
+ for (String signature : signatures()) {
+ for (String output : outputs(signature).values()) {
+ renameDimensions(index.get(output), renamer);
+ }
+ }
+ }
+
+ private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer) {
+ if (operation.type().isPresent()) {
+ operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer));
+ operation.addDimensionNameConstraints(renamer);
+ }
+ }
+
+ private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer) {
+ if (operation.type().isPresent()) {
+ operation.inputs().forEach(input -> renameDimensions(input, renamer));
+ operation.renameDimensions(renamer);
+ }
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorType.java
index 812e9b8d678..209d73a9f38 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OrderedTensorType.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorType.java
@@ -1,9 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
import com.yahoo.tensor.TensorType;
-import onnx.Onnx;
+import com.yahoo.tensor.TensorTypeParser;
import java.util.ArrayList;
import java.util.Collections;
@@ -13,9 +13,9 @@ import java.util.stream.Collectors;
/**
* A Vespa tensor type is ordered by the lexicographical ordering of dimension
- * names. ONNX tensors have an explicit ordering of their dimensions.
+ * names. Imported tensors have an explicit ordering of their dimensions.
* During import, we need to track the Vespa dimension that matches the
- * corresponding ONNX dimension as the ordering can change after
+ * corresponding imported dimension as the ordering can change after
* dimension renaming. That is the purpose of this class.
*
* @author lesters
@@ -25,14 +25,14 @@ public class OrderedTensorType {
private final TensorType type;
private final List<TensorType.Dimension> dimensions;
- private final long[] innerSizesOnnx;
+ private final long[] innerSizesOriginal;
private final long[] innerSizesVespa;
private final int[] dimensionMap;
private OrderedTensorType(List<TensorType.Dimension> dimensions) {
this.dimensions = Collections.unmodifiableList(dimensions);
this.type = new TensorType.Builder(dimensions).build();
- this.innerSizesOnnx = new long[dimensions.size()];
+ this.innerSizesOriginal = new long[dimensions.size()];
this.innerSizesVespa = new long[dimensions.size()];
this.dimensionMap = createDimensionMap();
}
@@ -54,10 +54,10 @@ public class OrderedTensorType {
if (numDimensions == 0) {
return null;
}
- innerSizesOnnx[numDimensions - 1] = 1;
+ innerSizesOriginal[numDimensions - 1] = 1;
innerSizesVespa[numDimensions - 1] = 1;
for (int i = numDimensions - 1; --i >= 0; ) {
- innerSizesOnnx[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesOnnx[i+1];
+ innerSizesOriginal[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesOriginal[i+1];
innerSizesVespa[i] = type.dimensions().get(i+1).size().orElse(-1L) * innerSizesVespa[i+1];
}
int[] mapping = new int[numDimensions];
@@ -74,11 +74,15 @@ public class OrderedTensorType {
return mapping;
}
+ public int dimensionMap(int originalIndex) {
+ return dimensionMap[originalIndex];
+ }
+
/**
- * When dimension ordering between Vespa and Onnx differs, i.e.
+ * When dimension ordering between Vespa and imported differs, i.e.
* after dimension renaming, use the dimension map to read in values
* so that they are correctly laid out in memory for Vespa.
- * Used when importing tensors from Onnx.
+ * Used when importing tensors.
*/
public int toDirectIndex(int index) {
if (dimensions.size() == 0) {
@@ -90,9 +94,9 @@ public class OrderedTensorType {
int directIndex = 0;
long rest = index;
for (int i = 0; i < dimensions.size(); ++i) {
- long address = rest / innerSizesOnnx[i];
+ long address = rest / innerSizesOriginal[i];
directIndex += innerSizesVespa[dimensionMap[i]] * address;
- rest %= innerSizesOnnx[i];
+ rest %= innerSizesOriginal[i];
}
return directIndex;
}
@@ -116,22 +120,6 @@ public class OrderedTensorType {
return true;
}
- public void verifyType(Onnx.TypeProto typeProto) {
- Onnx.TensorShapeProto shape = typeProto.getTensorType().getShape();
- if (shape != null) {
- if (shape.getDimCount() != type.rank()) {
- throw new IllegalArgumentException("Onnx shape of does not match Vespa shape");
- }
- for (int onnxIndex = 0; onnxIndex < dimensions.size(); ++onnxIndex) {
- int vespaIndex = dimensionMap[onnxIndex];
- Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(onnxIndex);
- TensorType.Dimension vespaDimension = type().dimensions().get(vespaIndex);
- if (onnxDimension.getDimValue() != vespaDimension.size().orElse(-1L)) {
- throw new IllegalArgumentException("TensorFlow dimensions of does not match Vespa dimensions");
- }
- }
- }
- }
public OrderedTensorType rename(DimensionRenamer renamer) {
List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size());
for (TensorType.Dimension dimension : dimensions) {
@@ -151,18 +139,13 @@ public class OrderedTensorType {
return new OrderedTensorType(renamedDimensions);
}
- public static OrderedTensorType fromOnnxType(Onnx.TypeProto type) {
- return fromOnnxType(type, "d"); // standard naming convention: d0, d1, ...
- }
-
- public static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) {
- Onnx.TensorShapeProto shape = type.getTensorType().getShape();
- Builder builder = new Builder(shape);
- for (int i = 0; i < shape.getDimCount(); ++ i) {
+ public OrderedTensorType rename(String dimensionPrefix) {
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ for (int i = 0; i < dimensions.size(); ++ i) {
String dimensionName = dimensionPrefix + i;
- Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
- if (onnxDimension.getDimValue() >= 0) {
- builder.add(TensorType.Dimension.indexed(dimensionName, onnxDimension.getDimValue()));
+ Optional<Long> dimSize = dimensions.get(i).size();
+ if (dimSize.isPresent() && dimSize.get() >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, dimSize.get()));
} else {
builder.add(TensorType.Dimension.indexed(dimensionName));
}
@@ -170,13 +153,13 @@ public class OrderedTensorType {
return builder.build();
}
- public static OrderedTensorType fromOnnxType(List<Long> dims, String dimensionPrefix) {
- Builder builder = new Builder();
- for (int i = 0; i < dims.size(); ++ i) {
- String dimensionName = dimensionPrefix + i;
- Long dimSize = dims.get(i);
- if (dimSize >= 0) {
- builder.add(TensorType.Dimension.indexed(dimensionName, dimSize));
+ public static OrderedTensorType standardType(OrderedTensorType type) {
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ for (int i = 0; i < type.dimensions().size(); ++ i) {
+ TensorType.Dimension dim = type.dimensions().get(i);
+ String dimensionName = "d" + i;
+ if (dim.size().isPresent() && dim.size().get() >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, dim.size().get()));
} else {
builder.add(TensorType.Dimension.indexed(dimensionName));
}
@@ -184,13 +167,46 @@ public class OrderedTensorType {
return builder.build();
}
- public static OrderedTensorType standardType(OrderedTensorType type) {
- Builder builder = new Builder();
- for (int i = 0; i < type.dimensions().size(); ++ i) {
- TensorType.Dimension dim = type.dimensions().get(i);
- String dimensionName = "d" + i;
- if (dim.size().isPresent() && dim.size().get() >= 0) {
- builder.add(TensorType.Dimension.indexed(dimensionName, dim.size().get()));
+ public static Long tensorSize(TensorType type) {
+ Long size = 1L;
+ for (TensorType.Dimension dimension : type.dimensions()) {
+ size *= dimensionSize(dimension);
+ }
+ return size;
+ }
+
+ public static Long dimensionSize(TensorType.Dimension dim) {
+ return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size"));
+ }
+
+ /**
+ * Returns a string representation of this: A standard tensor type string where dimensions
+ * are listed in the order of this rather than in the natural order of their names.
+ */
+ @Override
+ public String toString() {
+ return "tensor(" + dimensions.stream().map(TensorType.Dimension::toString).collect(Collectors.joining(",")) + ")";
+ }
+
+ /**
+ * Creates an instance from the string representation of this: A standard tensor type string
+ * where dimensions are listed in the order of this rather than the natural order of their names.
+ */
+ public static OrderedTensorType fromSpec(String typeSpec) {
+ return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec(typeSpec));
+ }
+
+ public static OrderedTensorType fromDimensionList(List<Long> dims) {
+ return fromDimensionList(dims, "d"); // standard naming convention: d0, d1, ...
+ }
+
+ public static OrderedTensorType fromDimensionList(List<Long> dims, String dimensionPrefix) {
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ for (int i = 0; i < dims.size(); ++ i) {
+ String dimensionName = dimensionPrefix + i;
+ Long dimSize = dims.get(i);
+ if (dimSize >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, dimSize));
} else {
builder.add(TensorType.Dimension.indexed(dimensionName));
}
@@ -200,45 +216,13 @@ public class OrderedTensorType {
public static class Builder {
- private final Onnx.TensorShapeProto shape;
private final List<TensorType.Dimension> dimensions;
- public Builder(Onnx.TensorShapeProto shape) {
- this.shape = shape;
- this.dimensions = new ArrayList<>(shape.getDimCount());
- }
-
public Builder() {
- this.shape = null;
this.dimensions = new ArrayList<>();
}
public Builder add(TensorType.Dimension vespaDimension) {
- if (shape != null) {
- int index = dimensions.size();
- Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(index);
- long size = onnxDimension.getDimValue();
- if (size >= 0) {
- if (vespaDimension.type() != TensorType.Dimension.Type.indexedBound) {
- throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " +
- "dimension types");
- }
- if (!vespaDimension.size().isPresent()) {
- throw new IllegalArgumentException("Tensor dimension is indexed bound but does " +
- "not have a size");
- }
- if (vespaDimension.size().get() != size) {
- throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " +
- "dimension sizes. TensorFlow: " + size + " Vespa: " +
- vespaDimension.size().get());
- }
- } else {
- if (vespaDimension.type() != TensorType.Dimension.Type.indexedUnbound) {
- throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " +
- "dimension types");
- }
- }
- }
this.dimensions.add(vespaDimension);
return this;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
new file mode 100644
index 00000000000..3fe92440cae
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
@@ -0,0 +1,216 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Argument;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ConcatV2;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Identity;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Join;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Map;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.MatMul;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.NoOp;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Reshape;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Shape;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import onnx.Onnx;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Converts an ONNX graph to a Vespa IntermediateGraph which is the basis
+ * for generating Vespa ranking expressions.
+ *
+ * @author lesters
+ */
+public class GraphImporter {
+
+ public static IntermediateOperation mapOperation(Onnx.NodeProto node,
+ List<IntermediateOperation> inputs,
+ IntermediateGraph graph) {
+ String nodeName = node.getName();
+ String modelName = graph.name();
+
+ switch (node.getOpType().toLowerCase()) {
+ case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs());
+ case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
+ case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
+ case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin());
+ case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan());
+ case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil());
+ case "concat": return new ConcatV2(modelName, nodeName, inputs);
+ case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos());
+ case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
+ case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
+ case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal());
+ case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp());
+ case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor());
+ case "greater": return new Join(modelName, nodeName, inputs, ScalarFunctions.greater());
+ case "identity": return new Identity(modelName, nodeName, inputs);
+ case "less": return new Join(modelName, nodeName, inputs, ScalarFunctions.less());
+ case "log": return new Map(modelName, nodeName, inputs, ScalarFunctions.log());
+ case "matmul": return new MatMul(modelName, nodeName, inputs);
+ case "max": return new Join(modelName, nodeName, inputs, ScalarFunctions.max());
+ case "min": return new Join(modelName, nodeName, inputs, ScalarFunctions.min());
+ case "mean": return new Join(modelName, nodeName, inputs, ScalarFunctions.mean());
+ case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
+ case "neg": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg());
+ case "pow": return new Join(modelName, nodeName, inputs, ScalarFunctions.pow());
+ case "reshape": return new Reshape(modelName, nodeName, inputs);
+ case "reciprocal": return new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal());
+ case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
+ case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
+ case "shape": return new Shape(modelName, nodeName, inputs);
+ case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin());
+ case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt());
+ case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid());
+ case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
+ case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan());
+ case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh());
+ }
+
+ IntermediateOperation op = new NoOp(modelName, node.getName(), inputs);
+ op.warning("Operation '" + node.getOpType() + "' is currently not implemented");
+ return op;
+ }
+
+ public static IntermediateGraph importGraph(String modelName, Onnx.ModelProto model) {
+ Onnx.GraphProto onnxGraph = model.getGraph();
+
+ IntermediateGraph intermediateGraph = new IntermediateGraph(modelName);
+ importOperations(onnxGraph, intermediateGraph);
+ verifyOutputTypes(onnxGraph, intermediateGraph);
+
+ return intermediateGraph;
+ }
+
+ private static void importOperations(Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) {
+ for (Onnx.ValueInfoProto valueInfo : onnxGraph.getOutputList()) {
+ importOperation(valueInfo.getName(), onnxGraph, intermediateGraph);
+ }
+ }
+
+ private static IntermediateOperation importOperation(String name,
+ Onnx.GraphProto onnxGraph,
+ IntermediateGraph intermediateGraph) {
+ if (intermediateGraph.alreadyImported(name)) {
+ return intermediateGraph.get(name);
+ }
+ IntermediateOperation operation;
+ if (isArgumentTensor(name, onnxGraph)) {
+ Onnx.ValueInfoProto valueInfoProto = getArgumentTensor(name, onnxGraph);
+ if (valueInfoProto == null)
+ throw new IllegalArgumentException("Could not find argument tensor: " + name);
+ OrderedTensorType type = TypeConverter.fromOnnxType(valueInfoProto.getType());
+ operation = new Argument(intermediateGraph.name(), valueInfoProto.getName(), type);
+
+ intermediateGraph.inputs(intermediateGraph.defaultSignature())
+ .put(IntermediateOperation.namePartOf(name), operation.vespaName());
+
+ } else if (isConstantTensor(name, onnxGraph)) {
+ Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph);
+ OrderedTensorType defaultType = OrderedTensorType.fromDimensionList(tensorProto.getDimsList());
+ operation = new Constant(intermediateGraph.name(), name, defaultType);
+ operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type)));
+
+ } else {
+ Onnx.NodeProto node = getNodeFromGraph(name, onnxGraph);
+ List<IntermediateOperation> inputs = importOperationInputs(node, onnxGraph, intermediateGraph);
+ operation = mapOperation(node, inputs, intermediateGraph);
+
+ if (isOutputNode(name, onnxGraph)) {
+ intermediateGraph.outputs(intermediateGraph.defaultSignature())
+ .put(IntermediateOperation.namePartOf(name), operation.vespaName());
+ }
+ }
+ intermediateGraph.put(operation.vespaName(), operation);
+
+ return operation;
+ }
+
+ private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) {
+ Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
+ Onnx.TensorProto tensor = getConstantTensor(name, graph);
+ return value != null && tensor == null;
+ }
+
+ private static boolean isConstantTensor(String name, Onnx.GraphProto graph) {
+ Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
+ Onnx.TensorProto tensor = getConstantTensor(name, graph);
+ return value != null && tensor != null;
+ }
+
+ private static Onnx.ValueInfoProto getArgumentTensor(String name, Onnx.GraphProto graph) {
+ for (Onnx.ValueInfoProto valueInfo : graph.getInputList()) {
+ if (valueInfo.getName().equals(name)) {
+ return valueInfo;
+ }
+ }
+ return null;
+ }
+
+ private static Onnx.TensorProto getConstantTensor(String name, Onnx.GraphProto graph) {
+ for (Onnx.TensorProto tensorProto : graph.getInitializerList()) {
+ if (tensorProto.getName().equals(name)) {
+ return tensorProto;
+ }
+ }
+ return null;
+ }
+
+ private static boolean isOutputNode(String name, Onnx.GraphProto graph) {
+ return getOutputNode(name, graph) != null;
+ }
+
+ private static Onnx.ValueInfoProto getOutputNode(String name, Onnx.GraphProto graph) {
+ for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) {
+ if (valueInfo.getName().equals(name)) {
+ return valueInfo;
+ }
+ String nodeName = IntermediateOperation.namePartOf(valueInfo.getName());
+ if (nodeName.equals(name)) {
+ return valueInfo;
+ }
+ }
+ return null;
+ }
+
+ private static List<IntermediateOperation> importOperationInputs(Onnx.NodeProto node,
+ Onnx.GraphProto onnxGraph,
+ IntermediateGraph intermediateGraph) {
+ return node.getInputList().stream()
+ .map(nodeName -> importOperation(nodeName, onnxGraph, intermediateGraph))
+ .collect(Collectors.toList());
+ }
+
+ private static void verifyOutputTypes(Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) {
+ for (String outputName : intermediateGraph.outputs(intermediateGraph.defaultSignature()).values()) {
+ IntermediateOperation operation = intermediateGraph.get(outputName);
+ Onnx.ValueInfoProto onnxNode = getOutputNode(outputName, onnxGraph);
+ OrderedTensorType type = operation.type().orElseThrow(
+ () -> new IllegalArgumentException("Output of '" + outputName + "' has no type."));
+ TypeConverter.verifyType(onnxNode.getType(), type);
+ }
+ }
+
+ private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) {
+ boolean hasPortNumber = nodeName.contains(":");
+ for (Onnx.NodeProto node : graph.getNodeList()) {
+ if (hasPortNumber) {
+ for (String outputName : node.getOutputList()) {
+ if (outputName.equals(nodeName)) {
+ return node;
+ }
+ }
+ } else if (node.getName().equals(nodeName)) {
+ return node;
+ }
+ }
+ throw new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph");
+ }
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java
index 2912db03b5f..18856d4a25f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/TensorConverter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java
@@ -1,17 +1,16 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx;
import com.google.protobuf.ByteString;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
import onnx.Onnx;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
-import java.util.List;
/**
* Converts Onnx tensors into Vespa tensors.
@@ -29,7 +28,6 @@ public class TensorConverter {
return builder.build();
}
- /* todo: support more types */
private static Values readValuesOf(Onnx.TensorProto tensorProto) {
if (tensorProto.hasRawData()) {
switch (tensorProto.getDataType()) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java
new file mode 100644
index 00000000000..715c55d8323
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java
@@ -0,0 +1,52 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx;
+
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import onnx.Onnx;
+
+/**
+ * Converts and verifies ONNX tensor types into Vespa tensor types.
+ *
+ * @author lesters
+ */
+public class TypeConverter {
+
+ public static void verifyType(Onnx.TypeProto typeProto, OrderedTensorType type) {
+ Onnx.TensorShapeProto shape = typeProto.getTensorType().getShape();
+ if (shape != null) {
+ if (shape.getDimCount() != type.rank()) {
+ throw new IllegalArgumentException("Onnx shape of does not match Vespa shape");
+ }
+ for (int onnxIndex = 0; onnxIndex < type.dimensions().size(); ++onnxIndex) {
+ int vespaIndex = type.dimensionMap(onnxIndex);
+ Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(onnxIndex);
+ TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex);
+ if (onnxDimension.getDimValue() != vespaDimension.size().orElse(-1L)) {
+ throw new IllegalArgumentException("Onnx dimensions of does not match Vespa dimensions");
+ }
+ }
+ }
+ }
+
+ public static OrderedTensorType fromOnnxType(Onnx.TypeProto type) {
+ return fromOnnxType(type, "d"); // standard naming convention: d0, d1, ...
+ }
+
+ public static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) {
+ Onnx.TensorShapeProto shape = type.getTensorType().getShape();
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ for (int i = 0; i < shape.getDimCount(); ++ i) {
+ String dimensionName = dimensionPrefix + i;
+ Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
+ if (onnxDimension.getDimValue() >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, onnxDimension.getDimValue()));
+ } else {
+ builder.add(TensorType.Dimension.indexed(dimensionName));
+ }
+ }
+ return builder.build();
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java
index 1619c11427a..7fc2aae87d1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java
@@ -1,28 +1,29 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
+import java.util.Collections;
import java.util.List;
-public class Placeholder extends TensorFlowOperation {
+public class Argument extends IntermediateOperation {
private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ...
- public Placeholder(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
- standardNamingType = OrderedTensorType.fromTensorFlowType(node);
+ public Argument(String modelName, String nodeName, OrderedTensorType type) {
+ super(modelName, nodeName, Collections.emptyList());
+ this.type = type.rename(vespaName() + "_");
+ standardNamingType = OrderedTensorType.standardType(type);
}
@Override
protected OrderedTensorType lazyGetType() {
- return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_");
+ return type;
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ConcatV2.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java
index 4f5d61d75f9..1b8c62fe0e9 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ConcatV2.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java
@@ -1,38 +1,37 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
import java.util.Optional;
-public class ConcatV2 extends TensorFlowOperation {
+public class ConcatV2 extends IntermediateOperation {
private String concatDimensionName;
- public ConcatV2(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public ConcatV2(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
protected OrderedTensorType lazyGetType() {
- if (!inputs.stream().map(TensorFlowOperation::type).allMatch(Optional::isPresent)) {
+ if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) {
return null;
}
- TensorFlowOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input
+ IntermediateOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input
if (!concatDimOp.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " +
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
"concat dimension must be a constant.");
}
Tensor concatDimTensor = concatDimOp.getConstantValue().get().asTensor();
if (concatDimTensor.type().rank() != 0) {
- throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " +
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
"concat dimension must be a scalar.");
}
@@ -44,7 +43,7 @@ public class ConcatV2 extends TensorFlowOperation {
for (int i = 1; i < inputs.size() - 1; ++i) {
OrderedTensorType bType = inputs.get(i).type().get();
if (bType.rank() != aType.rank()) {
- throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " +
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
"inputs must have save rank.");
}
for (int j = 0; j < aType.rank(); ++j) {
@@ -53,13 +52,13 @@ public class ConcatV2 extends TensorFlowOperation {
if (j == concatDim) {
concatDimSize += dimSizeB;
} else if (dimSizeA != dimSizeB) {
- throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " +
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
"input dimension " + j + " differs in input tensors.");
}
}
}
- OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node);
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
int dimensionIndex = 0;
for (TensorType.Dimension dimension : aType.dimensions()) {
if (dimensionIndex == concatDim) {
@@ -75,7 +74,7 @@ public class ConcatV2 extends TensorFlowOperation {
@Override
protected TensorFunction lazyGetFunction() {
- if (!inputs.stream().map(TensorFlowOperation::function).allMatch(Optional::isPresent)) {
+ if (!inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent)) {
return null;
}
TensorFunction result = inputs.get(0).function().get();
@@ -88,7 +87,7 @@ public class ConcatV2 extends TensorFlowOperation {
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
- if (!inputs.stream().map(TensorFlowOperation::type).allMatch(Optional::isPresent)) {
+ if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) {
return;
}
OrderedTensorType a = inputs.get(0).type().get();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java
index 718e2a4b3c2..3c0f8569c47 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java
@@ -1,36 +1,38 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
import com.yahoo.searchlib.rankingexpression.Reference;
-import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
import java.util.Optional;
-public class Const extends TensorFlowOperation {
+public class Const extends IntermediateOperation {
- public Const(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ private final AttributeMap attributeMap;
+
+ public Const(String modelName,
+ String nodeName,
+ List<IntermediateOperation> inputs,
+ AttributeMap attributeMap,
+ OrderedTensorType type) {
+ super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
+ this.type = type.rename(vespaName() + "_");
setConstantValue(value());
}
@Override
protected OrderedTensorType lazyGetType() {
- return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_");
+ return type;
}
@Override
@@ -55,7 +57,7 @@ public class Const extends TensorFlowOperation {
/** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
@Override
public String vespaName() {
- return modelName() + "_" + super.vespaName();
+ return modelName + "_" + super.vespaName();
}
@Override
@@ -77,24 +79,11 @@ public class Const extends TensorFlowOperation {
}
private Value value() {
- if ( ! node.getAttrMap().containsKey("value")) {
- throw new IllegalArgumentException("Node '" + node.getName() + "' of type " +
- "const has missing 'value' attribute");
- }
- AttrValue attrValue = node.getAttrMap().get("value");
- if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) {
- return new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), type().get().type()));
- }
- if (attrValue.getValueCase() == AttrValue.ValueCase.B) {
- return new BooleanValue(attrValue.getB());
- }
- if (attrValue.getValueCase() == AttrValue.ValueCase.I) {
- return new DoubleValue(attrValue.getI());
- }
- if (attrValue.getValueCase() == AttrValue.ValueCase.F) {
- return new DoubleValue(attrValue.getF());
+ Optional<Value> value = attributeMap.get("value", type);
+ if ( ! value.isPresent()) {
+ throw new IllegalArgumentException("Node '" + name + "' of type " +
+ "const has missing or non-recognized 'value' attribute");
}
- throw new IllegalArgumentException("Requesting value of constant in " +
- node.getName() + " but type is not recognized.");
+ return value.get();
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java
index 13043a61a8e..5e4abeaa234 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java
@@ -1,38 +1,34 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import onnx.Onnx;
import java.util.Collections;
import java.util.Optional;
-public class Constant extends OnnxOperation {
+public class Constant extends IntermediateOperation {
- final String modelName;
- final Onnx.TensorProto tensorProto;
+ private final String modelName;
- public Constant(String modelName, Onnx.TensorProto tensorProto) {
- super(null, Collections.emptyList());
+ public Constant(String modelName, String nodeName, OrderedTensorType type) {
+ super(modelName, nodeName, Collections.emptyList());
this.modelName = modelName;
- this.tensorProto = tensorProto;
+ this.type = type.rename(vespaName() + "_");
}
/** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
@Override
public String vespaName() {
- return modelName + "_" + vespaName(tensorProto.getName());
+ return modelName + "_" + vespaName(name);
}
@Override
protected OrderedTensorType lazyGetType() {
- return OrderedTensorType.fromOnnxType(tensorProto.getDimsList(), vespaName() + "_");
+ return type;
}
@Override
@@ -40,9 +36,14 @@ public class Constant extends OnnxOperation {
return null; // will be added by function() since this is constant.
}
+ /**
+ * Constant values are sent in via the constantValueFunction, as the
+ * dimension names and thus the data layout depends on the dimension
+ * renaming which happens after the conversion to intermediate graph.
+ */
@Override
public Optional<Value> getConstantValue() {
- return Optional.of(new TensorValue(TensorConverter.toVespaTensor(tensorProto, type)));
+ return Optional.ofNullable(constantValueFunction).map(func -> func.apply(type));
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java
index 2d0f4c7042b..742ed8b89ab 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java
@@ -1,9 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
@@ -12,18 +12,17 @@ import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
-public class ExpandDims extends TensorFlowOperation {
+public class ExpandDims extends IntermediateOperation {
private List<String> expandDimensions;
- public ExpandDims(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public ExpandDims(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
@@ -32,14 +31,14 @@ public class ExpandDims extends TensorFlowOperation {
return null;
}
- TensorFlowOperation axisOperation = inputs().get(1);
+ IntermediateOperation axisOperation = inputs().get(1);
if (!axisOperation.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("ExpandDims in " + node.getName() + ": " +
+ throw new IllegalArgumentException("ExpandDims in " + name + ": " +
"axis must be a constant.");
}
Tensor axis = axisOperation.getConstantValue().get().asTensor();
if (axis.type().rank() != 0) {
- throw new IllegalArgumentException("ExpandDims in " + node.getName() + ": " +
+ throw new IllegalArgumentException("ExpandDims in " + name + ": " +
"axis argument must be a scalar.");
}
@@ -49,7 +48,7 @@ public class ExpandDims extends TensorFlowOperation {
dimensionToInsert = inputType.dimensions().size() - dimensionToInsert;
}
- OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node);
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
expandDimensions = new ArrayList<>();
int dimensionIndex = 0;
for (TensorType.Dimension dimension : inputType.dimensions()) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java
index 1408e7e04f0..d29bd4b7a9e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java
@@ -1,22 +1,21 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
-public class Identity extends TensorFlowOperation {
+public class Identity extends IntermediateOperation {
- public Identity(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public Identity(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
/** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
@Override
public String vespaName() {
- return modelName() + "_" + super.vespaName();
+ return modelName + "_" + super.vespaName();
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
index 3687bba8b85..43de29cedd5 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
@@ -1,17 +1,16 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
+
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.ArrayList;
import java.util.Collections;
@@ -20,43 +19,40 @@ import java.util.Optional;
import java.util.function.Function;
/**
- * Wraps a TensorFlow node and produces the respective Vespa tensor operation.
- * During import, a graph of these operations are constructed. Then, the
- * types are used to deduce sensible dimension names using the
- * DimensionRenamer. After the types have been renamed, the proper
- * Vespa expressions can be extracted.
+ * Wraps an imported operation node and produces the respective Vespa tensor
+ * operation. During import, a graph of these operations are constructed. Then,
+ * the types are used to deduce sensible dimension names using the
+ * DimensionRenamer. After the types have been renamed, the proper Vespa
+ * expressions can be extracted.
*
* @author lesters
*/
-public abstract class TensorFlowOperation {
-
- protected final static String MACRO_PREFIX = "tf_macro_";
+public abstract class IntermediateOperation {
- private final String modelName;
+ private final static String MACRO_PREFIX = "imported_ml_macro_";
- protected final NodeDef node;
- protected final int port;
- protected final List<TensorFlowOperation> inputs;
- protected final List<TensorFlowOperation> outputs = new ArrayList<>();
- protected final List<String> importWarnings = new ArrayList<>();
+ protected final String name;
+ protected final String modelName;
+ protected final List<IntermediateOperation> inputs;
+ protected final List<IntermediateOperation> outputs = new ArrayList<>();
protected OrderedTensorType type;
protected TensorFunction function;
protected TensorFunction macro = null;
+ private final List<String> importWarnings = new ArrayList<>();
private Value constantValue = null;
- private List<TensorFlowOperation> controlInputs = Collections.emptyList();
+ private List<IntermediateOperation> controlInputs = Collections.emptyList();
- TensorFlowOperation(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ protected Function<OrderedTensorType, Value> constantValueFunction = null;
+
+ IntermediateOperation(String modelName, String name, List<IntermediateOperation> inputs) {
+ this.name = name;
this.modelName = modelName;
- this.node = node;
- this.port = port;
this.inputs = Collections.unmodifiableList(inputs);
this.inputs.forEach(i -> i.outputs.add(this));
}
- protected String modelName() { return modelName; }
-
protected abstract OrderedTensorType lazyGetType();
protected abstract TensorFunction lazyGetFunction();
@@ -65,9 +61,6 @@ public abstract class TensorFlowOperation {
if (type == null) {
type = lazyGetType();
}
- if (type != null) {
- type.verifyType(node);
- }
return Optional.ofNullable(type);
}
@@ -87,14 +80,14 @@ public abstract class TensorFlowOperation {
return Optional.ofNullable(function);
}
- /** Return TensorFlow node */
- public NodeDef node() { return node; }
+ /** Returns original name of this operation node */
+ public String name() { return name; }
/** Return unmodifiable list of inputs */
- public List<TensorFlowOperation> inputs() { return inputs; }
+ public List<IntermediateOperation> inputs() { return inputs; }
/** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a macro. */
- public List<TensorFlowOperation> outputs() { return Collections.unmodifiableList(outputs); }
+ public List<IntermediateOperation> outputs() { return Collections.unmodifiableList(outputs); }
/** Returns a Vespa ranking expression that should be added as a macro */
public Optional<TensorFunction> macro() { return Optional.ofNullable(macro); }
@@ -109,22 +102,34 @@ public abstract class TensorFlowOperation {
public boolean isInput() { return false; }
/** Return true if this node is constant */
- public boolean isConstant() { return inputs.stream().allMatch(TensorFlowOperation::isConstant); }
+ public boolean isConstant() { return inputs.stream().allMatch(IntermediateOperation::isConstant); }
/** Sets the constant value */
public void setConstantValue(Value value) { constantValue = value; }
/** Gets the constant value if it exists */
- public Optional<Value> getConstantValue() { return Optional.ofNullable(constantValue); }
+ public Optional<Value> getConstantValue() {
+ if (constantValue != null) {
+ return Optional.of(constantValue);
+ }
+ if (constantValueFunction != null) {
+ return Optional.of(constantValueFunction.apply(type));
+ }
+ return Optional.empty();
+ }
+
+ /** Set the constant value function */
+ public void setConstantValueFunction(Function<OrderedTensorType, Value> func) { this.constantValueFunction = func; }
/** Sets the external control inputs */
- public void setControlInputs(List<TensorFlowOperation> inputs) { this.controlInputs = inputs; }
+ public void setControlInputs(List<IntermediateOperation> inputs) { this.controlInputs = inputs; }
/** Retrieve the control inputs for this operation */
- public List<TensorFlowOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); }
+ public List<IntermediateOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); }
/** Retrieve the valid Vespa name of this node */
- public String vespaName() { return node.getName() != null ? node.getName().replace('/', '_') : null; }
+ public String vespaName() { return vespaName(name); }
+ public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; }
/** Retrieve the valid Vespa name of this node if it is a macro */
public String macroName() { return vespaName() != null ? MACRO_PREFIX + modelName + "_" + vespaName() : null; }
@@ -135,23 +140,48 @@ public abstract class TensorFlowOperation {
/** Set an input warning */
public void warning(String warning) { importWarnings.add(warning); }
- boolean verifyInputs(int expected, Function<TensorFlowOperation, Optional<?>> func) {
- if (!controlInputs.stream().map(func).allMatch(Optional::isPresent)) {
- return false;
- }
+ boolean verifyInputs(int expected, Function<IntermediateOperation, Optional<?>> func) {
if (inputs.size() != expected) {
throw new IllegalArgumentException("Expected " + expected + " inputs " +
- "for '" + node.getName() + "', got " + inputs.size());
+ "for '" + name + "', got " + inputs.size());
}
return inputs.stream().map(func).allMatch(Optional::isPresent);
}
boolean allInputTypesPresent(int expected) {
- return verifyInputs(expected, TensorFlowOperation::type);
+ return verifyInputs(expected, IntermediateOperation::type);
}
boolean allInputFunctionsPresent(int expected) {
- return verifyInputs(expected, TensorFlowOperation::function);
+ return verifyInputs(expected, IntermediateOperation::function);
+ }
+
+ /**
+ * A method signature input and output has the form name:index.
+ * This returns the name part without the index.
+ */
+ public static String namePartOf(String name) {
+ name = name.startsWith("^") ? name.substring(1) : name;
+ return name.split(":")[0];
+ }
+
+ /**
+ * This return the output index part. Indexes are used for nodes with
+ * multiple outputs.
+ */
+ public static int indexPartOf(String name) {
+ int i = name.indexOf(":");
+ return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1));
+ }
+
+ /**
+ * An interface mapping operation attributes to Vespa Values.
+ * Adapter for differences in ONNX/TensorFlow.
+ */
+ public interface AttributeMap {
+ Optional<Value> get(String key);
+ Optional<Value> get(String key, OrderedTensorType type);
+ Optional<List<Value>> getList(String key);
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java
index fe2004a528d..8413ed74118 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Join.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java
@@ -1,24 +1,22 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.TensorFunction;
-import onnx.Onnx;
import java.util.ArrayList;
-import java.util.Collections;
import java.util.List;
import java.util.function.DoubleBinaryOperator;
-public class Join extends OnnxOperation {
+public class Join extends IntermediateOperation {
private final DoubleBinaryOperator operator;
- public Join(Onnx.NodeProto node, List<OnnxOperation> inputs, DoubleBinaryOperator operator) {
- super(node, inputs);
+ public Join(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleBinaryOperator operator) {
+ super(modelName, nodeName, inputs);
this.operator = operator;
}
@@ -61,8 +59,8 @@ public class Join extends OnnxOperation {
return null;
}
- OnnxOperation a = largestInput();
- OnnxOperation b = smallestInput();
+ IntermediateOperation a = largestInput();
+ IntermediateOperation b = smallestInput();
List<String> aDimensionsToReduce = new ArrayList<>();
List<String> bDimensionsToReduce = new ArrayList<>();
@@ -107,13 +105,13 @@ public class Join extends OnnxOperation {
}
}
- private OnnxOperation largestInput() {
+ private IntermediateOperation largestInput() {
OrderedTensorType a = inputs.get(0).type().get();
OrderedTensorType b = inputs.get(1).type().get();
return a.rank() >= b.rank() ? inputs.get(0) : inputs.get(1);
}
- private OnnxOperation smallestInput() {
+ private IntermediateOperation smallestInput() {
OrderedTensorType a = inputs.get(0).type().get();
OrderedTensorType b = inputs.get(1).type().get();
return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java
index c015f5ecba8..f54ae83052f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java
@@ -1,20 +1,19 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
import java.util.Optional;
import java.util.function.DoubleUnaryOperator;
-public class Map extends TensorFlowOperation {
+public class Map extends IntermediateOperation {
private final DoubleUnaryOperator operator;
- public Map(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port, DoubleUnaryOperator operator) {
- super(modelName, node, inputs, port);
+ public Map(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleUnaryOperator operator) {
+ super(modelName, nodeName, inputs);
this.operator = operator;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/MatMul.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java
index 1b388e2ae89..52e223f9518 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/MatMul.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java
@@ -1,21 +1,18 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import onnx.Onnx;
-import java.util.Collections;
import java.util.List;
import java.util.Optional;
-import java.util.function.DoubleBinaryOperator;
-public class MatMul extends OnnxOperation {
+public class MatMul extends IntermediateOperation {
- public MatMul(Onnx.NodeProto node, List<OnnxOperation> inputs) {
- super(node, inputs);
+ public MatMul(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java
index 3eba872c6a0..95a77c07590 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java
@@ -1,9 +1,10 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
@@ -13,20 +14,20 @@ import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.NodeDef;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
-public class Mean extends TensorFlowOperation {
+public class Mean extends IntermediateOperation {
+ private final AttributeMap attributeMap;
private List<String> reduceDimensions;
- public Mean(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public Mean(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
+ super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
}
@Override
@@ -34,9 +35,9 @@ public class Mean extends TensorFlowOperation {
if (!allInputTypesPresent(2)) {
return null;
}
- TensorFlowOperation reductionIndices = inputs.get(1);
+ IntermediateOperation reductionIndices = inputs.get(1);
if (!reductionIndices.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("Mean in " + node.getName() + ": " +
+ throw new IllegalArgumentException("Mean in " + name + ": " +
"reduction indices must be a constant.");
}
Tensor indices = reductionIndices.getConstantValue().get().asTensor();
@@ -54,7 +55,7 @@ public class Mean extends TensorFlowOperation {
return reducedType(inputType, shouldKeepDimensions());
}
- // todo: optimization: if keepDims and one reduce dimension that has size 1: same as identity.
+ // optimization: if keepDims and one reduce dimension that has size 1: same as identity.
@Override
protected TensorFunction lazyGetFunction() {
@@ -93,12 +94,12 @@ public class Mean extends TensorFlowOperation {
}
private boolean shouldKeepDimensions() {
- AttrValue keepDimsAttr = node.getAttrMap().get("keep_dims");
- return keepDimsAttr != null && keepDimsAttr.getB();
+ Optional<Value> keepDims = attributeMap.get("keep_dims");
+ return keepDims.isPresent() && keepDims.get().asBoolean();
}
private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node);
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
for (TensorType.Dimension dimension: inputType.type().dimensions()) {
if (!reduceDimensions.contains(dimension.name())) {
builder.add(dimension);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java
index 4c95e67e184..9d9eca47b1c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java
@@ -1,21 +1,20 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
-public class Merge extends TensorFlowOperation {
+public class Merge extends IntermediateOperation {
- public Merge(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public Merge(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
protected OrderedTensorType lazyGetType() {
- for (TensorFlowOperation operation : inputs) {
+ for (IntermediateOperation operation : inputs) {
if (operation.type().isPresent()) {
return operation.type().get();
}
@@ -25,7 +24,7 @@ public class Merge extends TensorFlowOperation {
@Override
protected TensorFunction lazyGetFunction() {
- for (TensorFlowOperation operation : inputs) {
+ for (IntermediateOperation operation : inputs) {
if (operation.function().isPresent()) {
return operation.function().get();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java
new file mode 100644
index 00000000000..19ba146492c
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java
@@ -0,0 +1,26 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.Collections;
+import java.util.List;
+
+public class NoOp extends IntermediateOperation {
+
+ public NoOp(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, Collections.emptyList()); // don't propagate inputs
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ return null;
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ return null;
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java
index 65ce7f00e34..9299ae9be12 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java
@@ -1,17 +1,16 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
import java.util.Optional;
-public class PlaceholderWithDefault extends TensorFlowOperation {
+public class PlaceholderWithDefault extends IntermediateOperation {
- public PlaceholderWithDefault(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public PlaceholderWithDefault(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java
index e7d90e5fc1f..e91c2305f7d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java
@@ -1,10 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
@@ -19,19 +18,18 @@ import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
-import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.tensorSize;
+import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.tensorSize;
-public class Reshape extends TensorFlowOperation {
+public class Reshape extends IntermediateOperation {
- public Reshape(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
@@ -39,15 +37,15 @@ public class Reshape extends TensorFlowOperation {
if (!allInputTypesPresent(2)) {
return null;
}
- TensorFlowOperation newShape = inputs.get(1);
+ IntermediateOperation newShape = inputs.get(1);
if (!newShape.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("Reshape in " + node.getName() + ": " +
+ throw new IllegalArgumentException("Reshape in " + name + ": " +
"shape input must be a constant.");
}
Tensor shape = newShape.getConstantValue().get().asTensor();
OrderedTensorType inputType = inputs.get(0).type().get();
- OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(node);
+ OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder();
int dimensionIndex = 0;
for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) {
Tensor.Cell cell = cellIterator.next();
@@ -124,7 +122,7 @@ public class Reshape extends TensorFlowOperation {
operators.add(0, ArithmeticOperator.MULTIPLY);
children.add(0, new ConstantNode(new DoubleValue(size)));
}
- size *= TensorConverter.dimensionSize(dimension);
+ size *= OrderedTensorType.dimensionSize(dimension);
if (i > 0) {
operators.add(0, ArithmeticOperator.PLUS);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java
index 5fdcb5a695f..927a4a368f9 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java
@@ -1,24 +1,23 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
import java.util.function.DoubleBinaryOperator;
-import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.dimensionSize;
-import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.tensorSize;
+import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.dimensionSize;
+import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.tensorSize;
-public class Select extends TensorFlowOperation {
+public class Select extends IntermediateOperation {
- public Select(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public Select(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
@@ -39,7 +38,7 @@ public class Select extends TensorFlowOperation {
if (!allInputFunctionsPresent(3)) {
return null;
}
- TensorFlowOperation conditionOperation = inputs().get(0);
+ IntermediateOperation conditionOperation = inputs().get(0);
TensorFunction a = inputs().get(1).function().get();
TensorFunction b = inputs().get(2).function().get();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java
index af49d2c108b..da566909adc 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java
@@ -1,20 +1,19 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
-public class Shape extends TensorFlowOperation {
+public class Shape extends IntermediateOperation {
- public Shape(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public Shape(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
createConstantValue();
}
@@ -24,7 +23,7 @@ public class Shape extends TensorFlowOperation {
return null;
}
OrderedTensorType inputType = inputs.get(0).type().get();
- return new OrderedTensorType.Builder(node)
+ return new OrderedTensorType.Builder()
.add(TensorType.Dimension.indexed(vespaName(), inputType.dimensions().size()))
.build();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java
index 17ce9e8b7cb..c750c47e27e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java
@@ -1,26 +1,26 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.NodeDef;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
-public class Squeeze extends TensorFlowOperation {
+public class Squeeze extends IntermediateOperation {
+ private final AttributeMap attributeMap;
private List<String> squeezeDimensions;
- public Squeeze(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ public Squeeze(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
+ super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
}
@Override
@@ -31,20 +31,21 @@ public class Squeeze extends TensorFlowOperation {
OrderedTensorType inputType = inputs.get(0).type().get();
squeezeDimensions = new ArrayList<>();
- AttrValue squeezeDimsAttr = node.getAttrMap().get("squeeze_dims");
- if (squeezeDimsAttr == null) {
+ Optional<List<Value>> squeezeDimsAttr = attributeMap.getList("squeeze_dims");
+ if ( ! squeezeDimsAttr.isPresent()) {
squeezeDimensions = inputType.type().dimensions().stream().
- filter(dim -> TensorConverter.dimensionSize(dim) == 1).
+ filter(dim -> OrderedTensorType.dimensionSize(dim) == 1).
map(TensorType.Dimension::name).
collect(Collectors.toList());
} else {
- squeezeDimensions = squeezeDimsAttr.getList().getIList().stream().
+ squeezeDimensions = squeezeDimsAttr.get().stream().map(Value::asDouble).map(Double::intValue).
map(i -> i < 0 ? inputType.type().dimensions().size() - i : i).
- map(i -> inputType.type().dimensions().get(i.intValue())).
- filter(dim -> TensorConverter.dimensionSize(dim) == 1).
+ map(i -> inputType.type().dimensions().get(i)).
+ filter(dim -> OrderedTensorType.dimensionSize(dim) == 1).
map(TensorType.Dimension::name).
collect(Collectors.toList());
}
+
return squeezeDimensions.isEmpty() ? inputType : reducedType(inputType);
}
@@ -72,7 +73,7 @@ public class Squeeze extends TensorFlowOperation {
}
private OrderedTensorType reducedType(OrderedTensorType inputType) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node);
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
for (TensorType.Dimension dimension: inputType.type().dimensions()) {
if ( ! squeezeDimensions.contains(dimension.name())) {
builder.add(dimension);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java
index de4d8862fd6..0171d1ea171 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java
@@ -1,17 +1,19 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
import java.util.List;
import java.util.Optional;
-public class Switch extends TensorFlowOperation {
+public class Switch extends IntermediateOperation {
- public Switch(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
+ private final int port;
+
+ public Switch(String modelName, String nodeName, List<IntermediateOperation> inputs, int port) {
+ super(modelName, nodeName, inputs);
+ this.port = port;
}
@Override
@@ -21,7 +23,7 @@ public class Switch extends TensorFlowOperation {
}
Optional<OrderedTensorType> predicate = inputs.get(1).type();
if (predicate.get().type().rank() != 0) {
- throw new IllegalArgumentException("Switch in " + node.getName() + ": " +
+ throw new IllegalArgumentException("Switch in " + name + ": " +
"predicate must be a scalar");
}
return inputs.get(0).type().orElse(null);
@@ -29,13 +31,13 @@ public class Switch extends TensorFlowOperation {
@Override
protected TensorFunction lazyGetFunction() {
- TensorFlowOperation predicateOperation = inputs().get(1);
+ IntermediateOperation predicateOperation = inputs().get(1);
if (!predicateOperation.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("Switch in " + node.getName() + ": " +
+ throw new IllegalArgumentException("Switch in " + name + ": " +
"predicate must be a constant");
}
if (port < 0 || port > 1) {
- throw new IllegalArgumentException("Switch in " + node.getName() + ": " +
+ throw new IllegalArgumentException("Switch in " + name + ": " +
"choice should be boolean");
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java
new file mode 100644
index 00000000000..a815cbc3944
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java
@@ -0,0 +1,85 @@
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+/**
+ * Converts TensorFlow node attributes to Vespa attribute values.
+ *
+ * @author lesters
+ */
+public class AttributeConverter implements IntermediateOperation.AttributeMap {
+
+ private final Map<String, AttrValue> attributeMap;
+
+ public AttributeConverter(NodeDef node) {
+ attributeMap = node.getAttrMap();
+ }
+
+ public static AttributeConverter convert(NodeDef node) {
+ return new AttributeConverter(node);
+ }
+
+ @Override
+ public Optional<Value> get(String key) {
+ if (attributeMap.containsKey(key)) {
+ AttrValue attrValue = attributeMap.get(key);
+ if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) {
+ return Optional.empty(); // requires type
+ }
+ if (attrValue.getValueCase() == AttrValue.ValueCase.B) {
+ return Optional.of(new BooleanValue(attrValue.getB()));
+ }
+ if (attrValue.getValueCase() == AttrValue.ValueCase.I) {
+ return Optional.of(new DoubleValue(attrValue.getI()));
+ }
+ if (attrValue.getValueCase() == AttrValue.ValueCase.F) {
+ return Optional.of(new DoubleValue(attrValue.getF()));
+ }
+ }
+ return Optional.empty();
+ }
+
+ @Override
+ public Optional<Value> get(String key, OrderedTensorType type) {
+ if (attributeMap.containsKey(key)) {
+ AttrValue attrValue = attributeMap.get(key);
+ if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) {
+ return Optional.of(new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), type.type())));
+ }
+ }
+ return get(key);
+ }
+
+ @Override
+ public Optional<List<Value>> getList(String key) {
+ if (attributeMap.containsKey(key)) {
+ AttrValue attrValue = attributeMap.get(key);
+ if (attrValue.getValueCase() == AttrValue.ValueCase.LIST) {
+ AttrValue.ListValue listValue = attrValue.getList();
+ if ( ! listValue.getBList().isEmpty()) {
+ return Optional.of(listValue.getBList().stream().map(BooleanValue::new).collect(Collectors.toList()));
+ }
+ if ( ! listValue.getIList().isEmpty()) {
+ return Optional.of(listValue.getIList().stream().map(DoubleValue::new).collect(Collectors.toList()));
+ }
+ if ( ! listValue.getFList().isEmpty()) {
+ return Optional.of(listValue.getFList().stream().map(DoubleValue::new).collect(Collectors.toList()));
+ }
+ // add the rest
+ }
+ }
+ return Optional.empty();
+ }
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java
new file mode 100644
index 00000000000..e1b292f9e61
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java
@@ -0,0 +1,234 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Argument;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ConcatV2;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Const;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ExpandDims;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Identity;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Join;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Map;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.MatMul;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Mean;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Merge;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.NoOp;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.PlaceholderWithDefault;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Reshape;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Select;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Shape;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Squeeze;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Switch;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+import org.tensorflow.framework.GraphDef;
+import org.tensorflow.framework.MetaGraphDef;
+import org.tensorflow.framework.NodeDef;
+import org.tensorflow.framework.SignatureDef;
+import org.tensorflow.framework.TensorInfo;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Converts a TensorFlow graph to a Vespa IntermediateGraph which is the basis
+ * for generating Vespa ranking expressions.
+ *
+ * @author lesters
+ */
+public class GraphImporter {
+
+ public static IntermediateOperation mapOperation(NodeDef node,
+ List<IntermediateOperation> inputs,
+ IntermediateGraph graph) {
+ String nodeName = node.getName();
+ String modelName = graph.name();
+ int nodePort = IntermediateOperation.indexPartOf(nodeName);
+ OrderedTensorType nodeType = TypeConverter.fromTensorFlowType(node);
+ AttributeConverter attributes = AttributeConverter.convert(node);
+
+ switch (node.getOp().toLowerCase()) {
+ // array ops
+ case "concatv2": return new ConcatV2(modelName, nodeName, inputs);
+ case "const": return new Const(modelName, nodeName, inputs, attributes, nodeType);
+ case "expanddims": return new ExpandDims(modelName, nodeName, inputs);
+ case "identity": return new Identity(modelName, nodeName, inputs);
+ case "placeholder": return new Argument(modelName, nodeName, nodeType);
+ case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, nodeName, inputs);
+ case "reshape": return new Reshape(modelName, nodeName, inputs);
+ case "shape": return new Shape(modelName, nodeName, inputs);
+ case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes);
+
+ // control flow
+ case "merge": return new Merge(modelName, nodeName, inputs);
+ case "switch": return new Switch(modelName, nodeName, inputs, nodePort);
+
+ // math ops
+ case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
+ case "add_n": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
+ case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
+ case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
+ case "realdiv": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
+ case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor());
+ case "matmul": return new MatMul(modelName, nodeName, inputs);
+ case "maximum": return new Join(modelName, nodeName, inputs, ScalarFunctions.max());
+ case "mean": return new Mean(modelName, nodeName, inputs, attributes);
+ case "reducemean": return new Mean(modelName, nodeName, inputs, attributes);
+ case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
+ case "multiply": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
+ case "rsqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.rsqrt());
+ case "select": return new Select(modelName, nodeName, inputs);
+ case "where3": return new Select(modelName, nodeName, inputs);
+ case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid());
+ case "squareddifference": return new Join(modelName, nodeName, inputs, ScalarFunctions.squareddifference());
+ case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
+ case "subtract": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
+
+ // nn ops
+ case "biasadd": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
+ case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
+ case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
+ case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
+
+ // state ops
+ case "variable": return new Constant(modelName, nodeName, nodeType);
+ case "variablev2": return new Constant(modelName, nodeName, nodeType);
+
+ // evaluation no-ops
+ case "stopgradient":return new Identity(modelName, nodeName, inputs);
+ case "noop": return new NoOp(modelName, nodeName, inputs);
+
+ }
+
+ IntermediateOperation op = new NoOp(modelName, node.getName(), inputs);
+ op.warning("Operation '" + node.getOp() + "' is currently not implemented");
+ return op;
+ }
+
+ public static IntermediateGraph importGraph(String modelName, SavedModelBundle bundle) throws IOException {
+ MetaGraphDef tfGraph = MetaGraphDef.parseFrom(bundle.metaGraphDef());
+
+ IntermediateGraph intermediateGraph = new IntermediateGraph(modelName);
+ importSignatures(tfGraph, intermediateGraph);
+ importOperations(tfGraph, intermediateGraph, bundle);
+ verifyOutputTypes(tfGraph, intermediateGraph);
+
+ return intermediateGraph;
+ }
+
+ private static void importSignatures(MetaGraphDef tfGraph, IntermediateGraph intermediateGraph) {
+ for (java.util.Map.Entry<String, SignatureDef> signatureEntry : tfGraph.getSignatureDefMap().entrySet()) {
+ String signatureName = signatureEntry.getKey();
+ java.util.Map<String, TensorInfo> inputInfoMap = signatureEntry.getValue().getInputsMap();
+ for (java.util.Map.Entry<String, TensorInfo> input : inputInfoMap.entrySet()) {
+ String inputName = input.getKey();
+ String nodeName = input.getValue().getName();
+ intermediateGraph.inputs(signatureName).put(inputName, IntermediateOperation.namePartOf(nodeName));
+ }
+ java.util.Map<String, TensorInfo> outputInfoMap = signatureEntry.getValue().getOutputsMap();
+ for (java.util.Map.Entry<String, TensorInfo> output : outputInfoMap.entrySet()) {
+ String outputName = output.getKey();
+ String nodeName = output.getValue().getName();
+ intermediateGraph.outputs(signatureName).put(outputName, IntermediateOperation.namePartOf(nodeName));
+ }
+ }
+ }
+
+ private static void importOperations(MetaGraphDef tfGraph,
+ IntermediateGraph intermediateGraph,
+ SavedModelBundle bundle) {
+ for (String signatureName : intermediateGraph.signatures()) {
+ for (String outputName : intermediateGraph.outputs(signatureName).values()) {
+ importOperation(outputName, tfGraph.getGraphDef(), intermediateGraph, bundle);
+ }
+ }
+ }
+
+ private static IntermediateOperation importOperation(String nodeName,
+ GraphDef tfGraph,
+ IntermediateGraph intermediateGraph,
+ SavedModelBundle bundle) {
+ if (intermediateGraph.alreadyImported(nodeName)) {
+ return intermediateGraph.get(nodeName);
+ }
+ NodeDef node = getTensorFlowNodeFromGraph(IntermediateOperation.namePartOf(nodeName), tfGraph);
+ List<IntermediateOperation> inputs = importOperationInputs(node, tfGraph, intermediateGraph, bundle);
+ IntermediateOperation operation = mapOperation(node, inputs, intermediateGraph);
+ intermediateGraph.put(nodeName, operation);
+
+ List<IntermediateOperation> controlInputs = importControlInputs(node, tfGraph, intermediateGraph, bundle);
+ if (controlInputs.size() > 0) {
+ operation.setControlInputs(controlInputs);
+ }
+
+ if (operation.isConstant()) {
+ operation.setConstantValueFunction(
+ type -> new TensorValue(TensorConverter.toVespaTensor(readVariable(nodeName, bundle), type)));
+ }
+
+ return operation;
+ }
+
+ private static List<IntermediateOperation> importOperationInputs(NodeDef node,
+ GraphDef tfGraph,
+ IntermediateGraph intermediateGraph,
+ SavedModelBundle bundle) {
+ return node.getInputList().stream()
+ .filter(name -> ! isControlDependency(name))
+ .map(nodeName -> importOperation(nodeName, tfGraph, intermediateGraph, bundle))
+ .collect(Collectors.toList());
+ }
+
+ private static List<IntermediateOperation> importControlInputs(NodeDef node,
+ GraphDef tfGraph,
+ IntermediateGraph intermediateGraph,
+ SavedModelBundle bundle) {
+ return node.getInputList().stream()
+ .filter(nodeName -> isControlDependency(nodeName))
+ .map(nodeName -> importOperation(nodeName, tfGraph, intermediateGraph, bundle))
+ .collect(Collectors.toList());
+ }
+
+ private static boolean isControlDependency(String name) {
+ return name.startsWith("^");
+ }
+
+ private static NodeDef getTensorFlowNodeFromGraph(String name, GraphDef tfGraph) {
+ for (NodeDef node : tfGraph.getNodeList()) {
+ if (node.getName().equals(name)) {
+ return node;
+ }
+ }
+ throw new IllegalArgumentException("Could not find node '" + name + "'");
+ }
+
+ public static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) {
+ Session.Runner fetched = bundle.session().runner().fetch(name);
+ List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
+ if (importedTensors.size() != 1)
+ throw new IllegalStateException("Expected 1 tensor from fetching " + name +
+ ", but got " + importedTensors.size());
+ return importedTensors.get(0);
+ }
+
+ private static void verifyOutputTypes(MetaGraphDef tfGraph, IntermediateGraph intermediateGraph) {
+ for (String signatureName : intermediateGraph.signatures()) {
+ for (String outputName : intermediateGraph.outputs(signatureName).values()) {
+ IntermediateOperation operation = intermediateGraph.get(outputName);
+ NodeDef node = getTensorFlowNodeFromGraph(IntermediateOperation.namePartOf(operation.name()), tfGraph.getGraphDef());
+ OrderedTensorType type = operation.type().orElseThrow(
+ () -> new IllegalArgumentException("Output of '" + outputName + "' has no type."));
+ TypeConverter.verifyType(node, type);
+ }
+ }
+
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java
index 3f55e622fdf..d2d0acfc964 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java
new file mode 100644
index 00000000000..67ad1edc312
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java
@@ -0,0 +1,72 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.NodeDef;
+import org.tensorflow.framework.TensorShapeProto;
+
+import java.util.List;
+
+/**
+ * Converts and verifies TensorFlow tensor types into Vespa tensor types.
+ *
+ * @author lesters
+ */
+public class TypeConverter {
+
+ public static void verifyType(NodeDef node, OrderedTensorType type) {
+ TensorShapeProto shape = tensorFlowShape(node);
+ if (shape != null) {
+ if (shape.getDimCount() != type.rank()) {
+ throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " +
+ "does not match Vespa shape");
+ }
+ for (int tensorFlowIndex = 0; tensorFlowIndex < type.dimensions().size(); ++tensorFlowIndex) {
+ int vespaIndex = type.dimensionMap(tensorFlowIndex);
+ TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex);
+ TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex);
+ if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) {
+ throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " +
+ "does not match Vespa dimensions");
+ }
+ }
+ }
+ }
+
+ private static TensorShapeProto tensorFlowShape(NodeDef node) {
+ AttrValue attrValueList = node.getAttrMap().get("_output_shapes");
+ if (attrValueList == null) {
+ throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
+ "does not exist");
+ }
+ if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) {
+ throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
+ "is not of expected type");
+ }
+ List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList();
+ return shapeList.get(0); // support multiple outputs?
+ }
+
+ public static OrderedTensorType fromTensorFlowType(NodeDef node) {
+ return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ...
+ }
+
+ public static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) {
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ TensorShapeProto shape = tensorFlowShape(node);
+ for (int i = 0; i < shape.getDimCount(); ++ i) {
+ String dimensionName = dimensionPrefix + i;
+ TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i);
+ if (tensorFlowDimension.getSize() >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, tensorFlowDimension.getSize()));
+ } else {
+ builder.add(TensorType.Dimension.indexed(dimensionName));
+ }
+ }
+ return builder.build();
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/package-info.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java
index 5cff8b03d40..1530754cc43 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/package-info.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java
@@ -3,6 +3,6 @@
* ONNX integration
*/
@ExportPackage
-package com.yahoo.searchlib.rankingexpression.integration.onnx;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java
deleted file mode 100644
index fa1f929cc80..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java
+++ /dev/null
@@ -1,326 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.onnx;
-
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Constant;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Argument;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OperationMapper;
-import com.yahoo.searchlib.rankingexpression.parser.ParseException;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.functions.Rename;
-import com.yahoo.tensor.functions.TensorFunction;
-import com.yahoo.yolean.Exceptions;
-import onnx.Onnx;
-
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.IOException;
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-import java.util.logging.Logger;
-import java.util.stream.Collectors;
-
-/**
- * Converts a ONNX model into a ranking expression and set of constants.
- *
- * @author lesters
- */
-public class OnnxImporter {
-
- private static final Logger log = Logger.getLogger(OnnxImporter.class.getName());
-
- public OnnxModel importModel(String modelName, File modelDir) {
- return importModel(modelName, modelDir.toString());
- }
-
- public OnnxModel importModel(String modelName, String modelPath) {
- try (FileInputStream inputStream = new FileInputStream(modelPath)) {
- Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
- return importModel(modelName, model);
- } catch (IOException e) {
- throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e);
- }
- }
-
- public OnnxModel importModel(String modelName, Onnx.ModelProto model) {
- return importGraph(modelName, model.getGraph());
- }
-
- private static OnnxModel importGraph(String modelName, Onnx.GraphProto graph) {
- OnnxModel model = new OnnxModel(modelName);
- OperationIndex index = new OperationIndex();
-
- importNodes(graph, model, index);
- verifyOutputTypes(graph, model, index);
- findDimensionNames(model, index);
- importExpressions(model, index);
-
- reportWarnings(model, index);
-
- return model;
- }
-
- private static void importNodes(Onnx.GraphProto graph, OnnxModel model, OperationIndex index) {
- for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) {
- importNode(valueInfo.getName(), graph, model, index);
- }
- }
-
- private static OnnxOperation importNode(String name, Onnx.GraphProto graph, OnnxModel model, OperationIndex index) {
- if (index.alreadyImported(name)) {
- return index.get(name);
- }
- OnnxOperation operation;
- if (isArgumentTensor(name, graph)) {
- operation = new Argument(getArgumentTensor(name, graph));
- model.input(OnnxOperation.namePartOf(name), operation.vespaName());
- } else if (isConstantTensor(name, graph)) {
- operation = new Constant(model.name(), getConstantTensor(name, graph));
- } else {
- Onnx.NodeProto node = getNodeFromGraph(name, graph);
- List<OnnxOperation> inputs = importNodeInputs(node, graph, model, index);
- operation = OperationMapper.get(node, inputs);
- if (isOutputNode(name, graph)) {
- model.output(OnnxOperation.namePartOf(name), operation.vespaName());
- }
- }
- index.put(operation.vespaName(), operation);
-
- return operation;
- }
-
- private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) {
- Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
- Onnx.TensorProto tensor = getConstantTensor(name, graph);
- return value != null && tensor == null;
- }
-
- private static boolean isConstantTensor(String name, Onnx.GraphProto graph) {
- Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
- Onnx.TensorProto tensor = getConstantTensor(name, graph);
- return value != null && tensor != null;
- }
-
- private static Onnx.ValueInfoProto getArgumentTensor(String name, Onnx.GraphProto graph) {
- for (Onnx.ValueInfoProto valueInfo : graph.getInputList()) {
- if (valueInfo.getName().equals(name)) {
- return valueInfo;
- }
- }
- return null;
- }
-
- private static Onnx.TensorProto getConstantTensor(String name, Onnx.GraphProto graph) {
- for (Onnx.TensorProto tensorProto : graph.getInitializerList()) {
- if (tensorProto.getName().equals(name)) {
- return tensorProto;
- }
- }
- return null;
- }
-
- private static boolean isOutputNode(String name, Onnx.GraphProto graph) {
- return getOutputNode(name, graph) != null;
- }
-
- private static Onnx.ValueInfoProto getOutputNode(String name, Onnx.GraphProto graph) {
- for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) {
- if (valueInfo.getName().equals(name)) {
- return valueInfo;
- }
- String nodeName = OnnxOperation.namePartOf(valueInfo.getName());
- if (nodeName.equals(name)) {
- return valueInfo;
- }
- }
- return null;
- }
-
- private static List<OnnxOperation> importNodeInputs(Onnx.NodeProto node,
- Onnx.GraphProto graph,
- OnnxModel model,
- OperationIndex index) {
- return node.getInputList().stream()
- .map(nodeName -> importNode(nodeName, graph, model, index))
- .collect(Collectors.toList());
- }
-
- private static void verifyOutputTypes(Onnx.GraphProto graph, OnnxModel model, OperationIndex index) {
- for (String outputName : model.outputs().values()) {
- OnnxOperation operation = index.get(outputName);
- Onnx.ValueInfoProto onnxNode = getOutputNode(outputName, graph);
- operation.type().orElseThrow(
- () -> new IllegalArgumentException("Output of '" + outputName + "' has no type."))
- .verifyType(onnxNode.getType());
- }
- }
-
-
- /** Find dimension names to avoid excessive renaming while evaluating the model. */
- private static void findDimensionNames(OnnxModel model, OperationIndex index) {
- DimensionRenamer renamer = new DimensionRenamer();
- for (String output : model.outputs().values()) {
- addDimensionNameConstraints(index.get(output), renamer);
- }
- renamer.solve();
- for (String output : model.outputs().values()) {
- renameDimensions(index.get(output), renamer);
- }
- }
-
- private static void addDimensionNameConstraints(OnnxOperation operation, DimensionRenamer renamer) {
- if (operation.type().isPresent()) {
- operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer));
- operation.addDimensionNameConstraints(renamer);
- }
- }
-
- private static void renameDimensions(OnnxOperation operation, DimensionRenamer renamer) {
- if (operation.type().isPresent()) {
- operation.inputs().forEach(input -> renameDimensions(input, renamer));
- operation.renameDimensions(renamer);
- }
- }
-
- private static void importExpressions(OnnxModel model, OperationIndex index) {
- for (String outputName : model.outputs().values()) {
- try {
- Optional<TensorFunction> function = importExpression(index.get(outputName), model);
- if (!function.isPresent()) {
- model.skippedOutput(outputName, "No valid output function could be found.");
- }
- }
- catch (IllegalArgumentException e) {
- model.skippedOutput(outputName, Exceptions.toMessageString(e));
- }
- }
- }
-
- private static Optional<TensorFunction> importExpression(OnnxOperation operation, OnnxModel model) {
- if (!operation.type().isPresent()) {
- return Optional.empty();
- }
- if (operation.isConstant()) {
- return importConstant(operation, model);
- }
- importInputExpressions(operation, model);
- importRankingExpression(operation, model);
- importArgumentExpression(operation, model);
-
- return operation.function();
- }
-
- private static void importInputExpressions(OnnxOperation operation, OnnxModel model) {
- operation.inputs().forEach(input -> importExpression(input, model));
- }
-
- private static Optional<TensorFunction> importConstant(OnnxOperation operation, OnnxModel model) {
- String name = operation.vespaName();
- if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) {
- return operation.function();
- }
-
- Value value = operation.getConstantValue().orElseThrow(() ->
- new IllegalArgumentException("Operation '" + operation.vespaName() + "' " +
- "is constant but does not have a value."));
- if ( ! (value instanceof TensorValue)) {
- return operation.function(); // scalar values are inserted directly into the expression
- }
-
- Tensor tensor = value.asTensor();
- if (tensor.type().rank() == 0) {
- model.smallConstant(name, tensor);
- } else {
- model.largeConstant(name, tensor);
- }
- return operation.function();
- }
-
- private static void importRankingExpression(OnnxOperation operation, OnnxModel model) {
- if (operation.function().isPresent()) {
- String name = operation.vespaName();
- if (!model.expressions().containsKey(name)) {
- TensorFunction function = operation.function().get();
-
- if (model.outputs().containsKey(name)) {
- OrderedTensorType operationType = operation.type().get();
- OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType);
- if ( ! operationType.equals(standardNamingType)) {
- List<String> renameFrom = operationType.dimensionNames();
- List<String> renameTo = standardNamingType.dimensionNames();
- function = new Rename(function, renameFrom, renameTo);
- }
- }
-
- try {
- // We add all intermediate nodes imported as separate expressions. Only
- // those referenced from the output will be used. We parse the
- // TensorFunction here to convert it to a RankingExpression tree.
- model.expression(name, new RankingExpression(name, function.toString()));
- }
- catch (ParseException e) {
- throw new RuntimeException("Tensorflow function " + function +
- " cannot be parsed as a ranking expression", e);
- }
- }
- }
- }
-
- private static void importArgumentExpression(OnnxOperation operation, OnnxModel model) {
- if (operation.isInput()) {
- // All inputs must have dimensions with standard naming convention: d0, d1, ...
- OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get());
- model.argument(operation.vespaName(), standardNamingConvention.type());
- model.requiredMacro(operation.vespaName(), standardNamingConvention.type());
- }
- }
-
- private static void reportWarnings(OnnxModel model, OperationIndex index) {
- for (String output : model.outputs().values()) {
- reportWarnings(model, index.get(output));
- }
- }
-
- private static void reportWarnings(OnnxModel model, OnnxOperation operation) {
- for (String warning : operation.warnings()) {
- model.importWarning(warning);
- }
- for (OnnxOperation input : operation.inputs()) {
- reportWarnings(model, input);
- }
- }
-
- private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) {
- boolean hasPortNumber = nodeName.contains(":");
- for (Onnx.NodeProto node : graph.getNodeList()) {
- if (hasPortNumber) {
- for (String outputName : node.getOutputList()) {
- if (outputName.equals(nodeName)) {
- return node;
- }
- }
- } else if (node.getName().equals(nodeName)) {
- return node;
- }
- }
- throw new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph");
- }
-
- private static class OperationIndex {
- private final Map<String, OnnxOperation> index = new HashMap<>();
- public OnnxOperation put(String key, OnnxOperation operation) { return index.put(key, operation); }
- public OnnxOperation get(String key) { return index.get(key); }
- public boolean alreadyImported(String key) { return index.containsKey(key); }
- public Collection<OnnxOperation> operations() { return index.values(); }
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java
deleted file mode 100644
index bd53afefc3f..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java
+++ /dev/null
@@ -1,112 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.onnx;
-
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.regex.Pattern;
-
-/**
- * The result of importing an ONNX model into Vespa.
- *
- * @author bratseth
- * @author lesters
- */
-public class OnnxModel {
-
- private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*");
-
- private final String name;
-
- public OnnxModel(String name) {
- if ( ! nameRegexp.matcher(name).matches())
- throw new IllegalArgumentException("A TensorFlow model name can only contain [A-Za-z0-9_], but is '" +
- name + "'");
- this.name = name;
- }
-
- /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */
- public String name() { return name; }
-
- private final Map<String, String> inputs = new HashMap<>();
- private final Map<String, String> outputs = new HashMap<>();
- private final Map<String, String> skippedOutputs = new HashMap<>();
- private final List<String> importWarnings = new ArrayList<>();
-
- private final Map<String, TensorType> arguments = new HashMap<>();
- private final Map<String, Tensor> smallConstants = new HashMap<>();
- private final Map<String, Tensor> largeConstants = new HashMap<>();
- private final Map<String, RankingExpression> expressions = new HashMap<>();
- private final Map<String, RankingExpression> macros = new HashMap<>();
- private final Map<String, TensorType> requiredMacros = new HashMap<>();
-
- void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); }
- void output(String name, String expressionName) { outputs.put(name, expressionName); }
- void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); }
- void importWarning(String warning) { importWarnings.add(warning); }
- void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
- void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
- void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
- void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
- void macro(String name, RankingExpression expression) { macros.put(name, expression); }
- void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); }
-
- /**
- * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name
- * to argument (Placeholder) name in the owner of this
- */
- public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); }
-
- /** Returns arguments().get(inputs.get(name)), e.g the type of the argument this input references */
- public TensorType inputArgument(String inputName) { return arguments().get(inputs.get(inputName)); }
-
- /** Returns an immutable list of the expression names of this */
- public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); }
-
- /**
- * Returns an immutable list of the outputs of this which could not be imported,
- * with a string detailing the reason for each
- */
- public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); }
-
- /**
- * Returns an immutable list of possibly non-fatal warnings encountered during import.
- */
- public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); }
-
- /** Returns expressions().get(outputs.get(outputName)), e.g the expression this output references */
- public RankingExpression outputExpression(String outputName) { return expressions().get(outputs.get(outputName)); }
-
- /** Returns an immutable map of the arguments (inputs) of this */
- public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
-
- /**
- * Returns an immutable map of the small constants of this.
- */
- public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); }
-
- /**
- * Returns an immutable map of the large constants of this.
- */
- public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); }
-
- /**
- * Returns an immutable map of the expressions of this - corresponding to ONNX nodes
- * which are not inputs or constants.
- */
- public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
-
- /** Returns an immutable map of macros that are part of this model */
- public Map<String, RankingExpression> macros() { return Collections.unmodifiableMap(macros); }
-
- /** Returns an immutable map of the macros that must be provided by the environment running this model */
- public Map<String, TensorType> requiredMacros() { return Collections.unmodifiableMap(requiredMacros); }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java
deleted file mode 100644
index 12090145d3a..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer;
-
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Join;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.MatMul;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.NoOp;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation;
-import com.yahoo.tensor.functions.ScalarFunctions;
-import onnx.Onnx;
-
-import java.util.List;
-
-public class OperationMapper {
-
- public static OnnxOperation get(Onnx.NodeProto node, List<OnnxOperation> inputs) {
- switch (node.getOpType().toLowerCase()) {
- case "add": return new Join(node, inputs, ScalarFunctions.add());
- case "matmul": return new MatMul(node, inputs);
- }
-
- OnnxOperation op = new NoOp(node, inputs);
- op.warning("Operation '" + node.getOpType() + "' is currently not implemented");
- return op;
- }
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java
deleted file mode 100644
index a8d8d63daf4..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java
+++ /dev/null
@@ -1,64 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.VariableTensor;
-import com.yahoo.tensor.functions.Rename;
-import com.yahoo.tensor.functions.TensorFunction;
-import onnx.Onnx;
-
-import java.util.Collections;
-import java.util.List;
-
-public class Argument extends OnnxOperation {
-
- private Onnx.ValueInfoProto valueInfo;
- private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ...
-
- public Argument(Onnx.ValueInfoProto valueInfoProto) {
- super(null, Collections.emptyList());
- valueInfo = valueInfoProto;
- standardNamingType = OrderedTensorType.fromOnnxType(valueInfo.getType());
- }
-
- @Override
- public String vespaName() {
- return vespaName(valueInfo.getName());
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- return OrderedTensorType.fromOnnxType(valueInfo.getType(), vespaName() + "_");
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- TensorFunction output = new VariableTensor(vespaName(), standardNamingType.type());
- if (!standardNamingType.equals(type)) {
- List<String> renameFrom = standardNamingType.dimensionNames();
- List<String> renameTo = type.dimensionNames();
- output = new Rename(output, renameFrom, renameTo);
- }
- return output;
- }
-
- @Override
- public void addDimensionNameConstraints(DimensionRenamer renamer) {
- for (TensorType.Dimension dimension : type.type().dimensions()) {
- renamer.addDimension(dimension.name());
- }
- }
-
- @Override
- public boolean isInput() {
- return true;
- }
-
- @Override
- public boolean isConstant() {
- return false;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java
deleted file mode 100644
index b1136a0ce0a..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
-import com.yahoo.tensor.functions.TensorFunction;
-import onnx.Onnx;
-
-import java.util.Collections;
-import java.util.List;
-
-public class NoOp extends OnnxOperation {
-
- public NoOp(Onnx.NodeProto node, List<OnnxOperation> inputs) {
- super(node, Collections.emptyList()); // don't propagate inputs
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- return null;
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- return null;
- }
-
- @Override
- public boolean isConstant() {
- return true;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java
deleted file mode 100644
index 30f7b4f4711..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java
+++ /dev/null
@@ -1,139 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
-
-import com.yahoo.searchlib.rankingexpression.Reference;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
-import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
-import com.yahoo.tensor.functions.TensorFunction;
-import onnx.Onnx;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.Optional;
-import java.util.function.Function;
-
-/**
- * Wraps an ONNX node and produces the respective Vespa tensor operation.
- * During import, a graph of these operations are constructed. Then, the
- * types are used to deduce sensible dimension names using the
- * DimensionRenamer. After the types have been renamed, the proper
- * Vespa expressions can be extracted.
- *
- * @author lesters
- */
-public abstract class OnnxOperation {
-
- protected final Onnx.NodeProto node; // can be null for onnx inputs and constants
- protected final List<OnnxOperation> inputs;
- protected final List<OnnxOperation> outputs = new ArrayList<>();
- protected final List<String> importWarnings = new ArrayList<>();
-
- protected OrderedTensorType type;
- protected TensorFunction function;
- protected Value constantValue = null;
-
- OnnxOperation(Onnx.NodeProto node, List<OnnxOperation> inputs) {
- this.node = node;
- this.inputs = Collections.unmodifiableList(inputs);
- this.inputs.forEach(i -> i.outputs.add(this));
- }
-
- protected abstract OrderedTensorType lazyGetType();
- protected abstract TensorFunction lazyGetFunction();
-
- /** Returns the Vespa tensor type of this operation if it exists */
- public Optional<OrderedTensorType> type() {
- if (type == null) {
- type = lazyGetType();
- }
- return Optional.ofNullable(type);
- }
-
- /** Returns the Vespa tensor function implementing all operations from this node with inputs */
- public Optional<TensorFunction> function() {
- if (function == null) {
- if (isConstant()) {
- ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName()));
- function = new TensorFunctionNode.TensorFunctionExpressionNode(constant);
- } else {
- function = lazyGetFunction();
- }
- }
- return Optional.ofNullable(function);
- }
-
- /** Return Onnx node */
- public Onnx.NodeProto node() { return node; }
-
- /** Return unmodifiable list of inputs */
- public List<OnnxOperation> inputs() { return inputs; }
-
- /** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a macro. */
- public List<OnnxOperation> outputs() { return Collections.unmodifiableList(outputs); }
-
- /** Add dimension name constraints for this operation */
- public void addDimensionNameConstraints(DimensionRenamer renamer) { }
-
- /** Performs dimension rename for this operation */
- public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); }
-
- /** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */
- public boolean isInput() { return false; }
-
- /** Return true if this node is constant */
- public boolean isConstant() { return inputs.stream().allMatch(OnnxOperation::isConstant); }
-
- /** Gets the constant value if it exists */
- public Optional<Value> getConstantValue() { return Optional.ofNullable(constantValue); }
-
- /** Retrieve the valid Vespa name of this node */
- public String vespaName() { return vespaName(node.getName()); }
- public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; }
-
- /** Retrieve the list of warnings produced during its lifetime */
- public List<String> warnings() { return Collections.unmodifiableList(importWarnings); }
-
- /** Set an input warning */
- public void warning(String warning) { importWarnings.add(warning); }
-
- boolean verifyInputs(int expected, Function<OnnxOperation, Optional<?>> func) {
- if (inputs.size() != expected) {
- throw new IllegalArgumentException("Expected " + expected + " inputs " +
- "for '" + node.getName() + "', got " + inputs.size());
- }
- return inputs.stream().map(func).allMatch(Optional::isPresent);
- }
-
- boolean allInputTypesPresent(int expected) {
- return verifyInputs(expected, OnnxOperation::type);
- }
-
- boolean allInputFunctionsPresent(int expected) {
- return verifyInputs(expected, OnnxOperation::function);
- }
-
- /**
- * A method signature input and output has the form name:index.
- * This returns the name part without the index.
- */
- public static String namePartOf(String name) {
- name = name.startsWith("^") ? name.substring(1) : name;
- return name.split(":")[0];
- }
-
- /**
- * This return the output index part. Indexes are used for nodes with
- * multiple outputs.
- */
- public static int indexPartOf(String name) {
- int i = name.indexOf(":");
- return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1));
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
deleted file mode 100644
index e3c72830095..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ /dev/null
@@ -1,411 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
-
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OperationMapper;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Const;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Variable;
-import com.yahoo.searchlib.rankingexpression.parser.ParseException;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.functions.Rename;
-import com.yahoo.tensor.functions.TensorFunction;
-import com.yahoo.yolean.Exceptions;
-import org.tensorflow.SavedModelBundle;
-import org.tensorflow.Session;
-import org.tensorflow.framework.GraphDef;
-import org.tensorflow.framework.MetaGraphDef;
-import org.tensorflow.framework.NodeDef;
-import org.tensorflow.framework.SignatureDef;
-import org.tensorflow.framework.TensorInfo;
-
-import java.io.File;
-import java.io.IOException;
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-import java.util.logging.Logger;
-import java.util.stream.Collectors;
-
-/**
- * Converts a saved TensorFlow model into a ranking expression and set of constants.
- *
- * @author bratseth
- * @author lesters
- */
-public class TensorFlowImporter {
-
- private static final Logger log = Logger.getLogger(TensorFlowImporter.class.getName());
-
- /**
- * Imports a saved TensorFlow model from a directory.
- * The model should be saved as a .pbtxt or .pb file.
- * The name of the model is taken as the db/pbtxt file name (not including the file ending).
- *
- * @param modelName the name of the model to import, consisting of characters in [A-Za-z0-9_]
- * @param modelDir the directory containing the TensorFlow model files to import
- */
- public TensorFlowModel importModel(String modelName, String modelDir) {
- try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
-
- return importModel(modelName, model);
- }
- catch (IllegalArgumentException e) {
- throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
- }
- }
-
- public TensorFlowModel importModel(String modelName, File modelDir) {
- return importModel(modelName, modelDir.toString());
- }
-
- /** Imports a TensorFlow model */
- public TensorFlowModel importModel(String modelName, SavedModelBundle model) {
- try {
- return importGraph(modelName, MetaGraphDef.parseFrom(model.metaGraphDef()), model);
- }
- catch (IOException e) {
- throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e);
- }
- }
-
- /**
- * Imports the TensorFlow graph by first importing the tensor types, then
- * finding a suitable set of dimensions names for each
- * placeholder/constant/variable, then importing the expressions.
- */
- private static TensorFlowModel importGraph(String modelName, MetaGraphDef graph, SavedModelBundle bundle) {
- TensorFlowModel model = new TensorFlowModel(modelName);
- OperationIndex index = new OperationIndex();
-
- importSignatures(graph, model);
- importNodes(graph, model, index);
- findDimensionNames(model, index);
- importExpressions(model, index, bundle);
-
- reportWarnings(model, index);
- logVariableTypes(index);
-
- return model;
- }
-
- private static void importSignatures(MetaGraphDef graph, TensorFlowModel model) {
- for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) {
- String signatureName = signatureEntry.getKey();
- TensorFlowModel.Signature signature = model.signature(signatureName);
-
- Map<String, TensorInfo> inputInfoMap = signatureEntry.getValue().getInputsMap();
- for (Map.Entry<String, TensorInfo> input : inputInfoMap.entrySet()) {
- String inputName = input.getKey();
- signature.input(inputName, namePartOf(input.getValue().getName()));
- }
-
- Map<String, TensorInfo> outputInfoMap = signatureEntry.getValue().getOutputsMap();
- for (Map.Entry<String, TensorInfo> output : outputInfoMap.entrySet()) {
- String outputName = output.getKey();
- signature.output(outputName, namePartOf(output.getValue().getName()));
- }
- }
- }
-
- private static boolean isSignatureInput(TensorFlowModel model, TensorFlowOperation operation) {
- for (TensorFlowModel.Signature signature : model.signatures().values()) {
- for (String inputName : signature.inputs().values()) {
- if (inputName.equals(operation.node().getName())) {
- return true;
- }
- }
- }
- return false;
- }
-
- private static boolean isSignatureOutput(TensorFlowModel model, TensorFlowOperation operation) {
- for (TensorFlowModel.Signature signature : model.signatures().values()) {
- for (String outputName : signature.outputs().values()) {
- if (outputName.equals(operation.node().getName())) {
- return true;
- }
- }
- }
- return false;
- }
-
- private static void importNodes(MetaGraphDef graph, TensorFlowModel model, OperationIndex index) {
- for (TensorFlowModel.Signature signature : model.signatures().values()) {
- for (String outputName : signature.outputs().values()) {
- importNode(model.name(), outputName, graph.getGraphDef(), index);
- }
- }
- }
-
- private static TensorFlowOperation importNode(String modelName, String nodeName, GraphDef graph, OperationIndex index) {
- if (index.alreadyImported(nodeName)) {
- return index.get(nodeName);
- }
- NodeDef node = getTensorFlowNodeFromGraph(namePartOf(nodeName), graph);
- List<TensorFlowOperation> inputs = importNodeInputs(modelName, node, graph, index);
- TensorFlowOperation operation = OperationMapper.get(modelName, node, inputs, portPartOf(nodeName));
- index.put(nodeName, operation);
-
- List<TensorFlowOperation> controlInputs = importControlInputs(modelName, node, graph, index);
- if (controlInputs.size() > 0) {
- operation.setControlInputs(controlInputs);
- }
-
- return operation;
- }
-
- private static List<TensorFlowOperation> importNodeInputs(String modelName, NodeDef node, GraphDef graph, OperationIndex index) {
- return node.getInputList().stream()
- .filter(name -> ! isControlDependency(name))
- .map(nodeName -> importNode(modelName, nodeName, graph, index))
- .collect(Collectors.toList());
- }
-
- private static List<TensorFlowOperation> importControlInputs(String modelName, NodeDef node, GraphDef graph, OperationIndex index) {
- return node.getInputList().stream()
- .filter(nodeName -> isControlDependency(nodeName))
- .map(nodeName -> importNode(modelName, nodeName, graph, index))
- .collect(Collectors.toList());
- }
-
- private static boolean isControlDependency(String name) {
- return name.startsWith("^");
- }
-
- /** Find dimension names to avoid excessive renaming while evaluating the model. */
- private static void findDimensionNames(TensorFlowModel model, OperationIndex index) {
- DimensionRenamer renamer = new DimensionRenamer();
- for (TensorFlowModel.Signature signature : model.signatures().values()) {
- for (String output : signature.outputs().values()) {
- addDimensionNameConstraints(index.get(output), renamer);
- }
- }
- renamer.solve();
- for (TensorFlowModel.Signature signature : model.signatures().values()) {
- for (String output : signature.outputs().values()) {
- renameDimensions(index.get(output), renamer);
- }
- }
- }
-
- private static void addDimensionNameConstraints(TensorFlowOperation operation, DimensionRenamer renamer) {
- if (operation.type().isPresent()) {
- operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer));
- operation.addDimensionNameConstraints(renamer);
- }
- }
-
- private static void renameDimensions(TensorFlowOperation operation, DimensionRenamer renamer) {
- if (operation.type().isPresent()) {
- operation.inputs().forEach(input -> renameDimensions(input, renamer));
- operation.renameDimensions(renamer);
- }
- }
-
- private static void importExpressions(TensorFlowModel model, OperationIndex index, SavedModelBundle bundle) {
- for (TensorFlowModel.Signature signature : model.signatures().values()) {
- for (String outputName : signature.outputs().values()) {
- try {
- Optional<TensorFunction> function = importExpression(index.get(outputName), model, bundle);
- if (!function.isPresent()) {
- signature.skippedOutput(outputName, "No valid output function could be found.");
- }
- }
- catch (IllegalArgumentException e) {
- signature.skippedOutput(outputName, Exceptions.toMessageString(e));
- }
- }
- }
- }
-
- private static Optional<TensorFunction> importExpression(TensorFlowOperation operation, TensorFlowModel model, SavedModelBundle bundle) {
- if (!operation.type().isPresent()) {
- return Optional.empty();
- }
- if (operation.isConstant()) {
- return importConstant(model, operation, bundle);
- }
-
- importInputExpressions(operation, model, bundle);
- importRankingExpression(model, operation);
- importInputExpression(model, operation);
- importMacroExpression(model, operation);
-
- return operation.function();
- }
-
- private static void importInputExpressions(TensorFlowOperation operation, TensorFlowModel model,
- SavedModelBundle bundle) {
- operation.inputs().forEach(input -> importExpression(input, model, bundle));
- }
-
- private static void importMacroExpression(TensorFlowModel model, TensorFlowOperation operation) {
- if (operation.macro().isPresent()) {
- TensorFunction function = operation.macro().get();
- try {
- model.macro(operation.macroName(), new RankingExpression(operation.macroName(), function.toString()));
- }
- catch (ParseException e) {
- throw new RuntimeException("Tensorflow function " + function +
- " cannot be parsed as a ranking expression", e);
- }
- }
- }
-
- private static Optional<TensorFunction> importConstant(TensorFlowModel model, TensorFlowOperation operation,
- SavedModelBundle bundle) {
- String name = operation.vespaName();
- if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) {
- return operation.function();
- }
-
- Tensor tensor;
- if (operation.getConstantValue().isPresent()) {
- Value value = operation.getConstantValue().get();
- if ( ! (value instanceof TensorValue)) {
- return operation.function(); // scalar values are inserted directly into the expression
- }
- tensor = value.asTensor();
- } else {
- // Here we use the type from the operation, which will have correct dimension names after name resolving
- tensor = TensorConverter.toVespaTensor(readVariable(operation.node().getName(), bundle),
- operation.type().get());
- operation.setConstantValue(new TensorValue(tensor));
- }
-
- if (tensor.type().rank() == 0) {
- model.smallConstant(name, tensor);
- } else {
- model.largeConstant(name, tensor);
- }
- return operation.function();
- }
-
- static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) {
- Session.Runner fetched = bundle.session().runner().fetch(name);
- List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
- if (importedTensors.size() != 1)
- throw new IllegalStateException("Expected 1 tensor from fetching " + name +
- ", but got " + importedTensors.size());
- return importedTensors.get(0);
- }
-
- private static void importRankingExpression(TensorFlowModel model, TensorFlowOperation operation) {
- if (operation.function().isPresent()) {
- String name = operation.node().getName();
- if (!model.expressions().containsKey(operation.node().getName())) {
- TensorFunction function = operation.function().get();
-
- // Make sure output adheres to standard naming convention
- if (isSignatureOutput(model, operation)) {
- OrderedTensorType operationType = operation.type().get();
- OrderedTensorType standardNamingType = OrderedTensorType.fromTensorFlowType(operation.node());
- if ( ! operationType.equals(standardNamingType)) {
- List<String> renameFrom = operationType.dimensionNames();
- List<String> renameTo = standardNamingType.dimensionNames();
- function = new Rename(function, renameFrom, renameTo);
- }
- }
-
- try {
- // We add all intermediate nodes imported as separate expressions. Only
- // those referenced in a signature output will be used. We parse the
- // TensorFunction here to convert it to a RankingExpression tree.
- model.expression(name, new RankingExpression(name, function.toString()));
- }
- catch (ParseException e) {
- throw new RuntimeException("Tensorflow function " + function +
- " cannot be parsed as a ranking expression", e);
- }
- }
- }
- }
-
- private static void importInputExpression(TensorFlowModel model, TensorFlowOperation operation) {
- if (operation.isInput() && isSignatureInput(model, operation)) {
- // All inputs must have dimensions with standard naming convention: d0, d1, ...
- OrderedTensorType standardNamingConvention = OrderedTensorType.fromTensorFlowType(operation.node());
- model.argument(operation.node().getName(), standardNamingConvention.type());
- model.requiredMacro(operation.vespaName(), standardNamingConvention.type());
- }
- }
-
- private static void reportWarnings(TensorFlowModel model, OperationIndex index) {
- for (TensorFlowModel.Signature signature : model.signatures().values()) {
- for (String output : signature.outputs().values()) {
- reportWarnings(index.get(output), signature);
- }
- }
- }
-
- /**
- * Log all TensorFlow Variables (i.e file constants) imported as part of this with their ordered type.
- * This allows users to learn the exact types (including dimension order after renaming) of the Variables
- * such that these can be converted and fed to a parent document independently of the rest of the model
- * for fast model weight updates.
- */
- private static void logVariableTypes(OperationIndex index) {
- for (TensorFlowOperation operation : index.operations()) {
- if ( ! (operation instanceof Variable)) continue;
- if ( ! operation.type().isPresent()) continue; // will not happen
-
- log.info("Importing TensorFlow variable " + operation.node().getName() + " as " + operation.vespaName() +
- " of type " + operation.type().get());
- }
- }
-
- private static void reportWarnings(TensorFlowOperation operation, TensorFlowModel.Signature signature) {
- for (String warning : operation.warnings()) {
- signature.importWarning(warning);
- }
- for (TensorFlowOperation input : operation.inputs()) {
- reportWarnings(input, signature);
- }
- }
-
- private static NodeDef getTensorFlowNodeFromGraph(String name, GraphDef graph) {
- for (NodeDef node : graph.getNodeList()) {
- if (node.getName().equals(name)) {
- return node;
- }
- }
- throw new IllegalArgumentException("Could not find node '" + name + "'");
- }
-
- /**
- * A method signature input and output has the form name:index.
- * This returns the name part without the index.
- */
- private static String namePartOf(String name) {
- name = name.startsWith("^") ? name.substring(1) : name;
- return name.split(":")[0];
- }
-
- /**
- * This return the output port part. Indexes are used for nodes with
- * multiple outputs.
- */
- private static int portPartOf(String name) {
- int i = name.indexOf(":");
- return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1));
- }
-
- private static class OperationIndex {
-
- private final Map<String, TensorFlowOperation> index = new HashMap<>();
- public TensorFlowOperation put(String key, TensorFlowOperation operation) { return index.put(key, operation); }
- public TensorFlowOperation get(String key) { return index.get(key); }
- public boolean alreadyImported(String key) { return index.containsKey(key); }
- public Collection<TensorFlowOperation> operations() { return index.values(); }
-
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java
deleted file mode 100644
index c1665d066a4..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java
+++ /dev/null
@@ -1,210 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
-
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation;
-
-import java.util.ArrayDeque;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.Deque;
-import java.util.HashMap;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
-import java.util.Objects;
-import java.util.Optional;
-
-/**
- * A constraint satisfier to find suitable dimension names to reduce the
- * amount of necessary renaming during evaluation of an imported model.
- *
- * @author lesters
- */
-public class DimensionRenamer {
-
- private final String dimensionPrefix;
- private final Map<String, List<Integer>> variables = new HashMap<>();
- private final Map<Arc, Constraint> constraints = new HashMap<>();
- private final Map<String, Integer> renames = new HashMap<>();
-
- private int iterations = 0;
-
- public DimensionRenamer() {
- this("d");
- }
-
- public DimensionRenamer(String dimensionPrefix) {
- this.dimensionPrefix = dimensionPrefix;
- }
-
- /**
- * Add a dimension name variable.
- */
- public void addDimension(String name) {
- variables.computeIfAbsent(name, d -> new ArrayList<>());
- }
-
- /**
- * Add a constraint between dimension names.
- */
- public void addConstraint(String from, String to, Constraint pred, TensorFlowOperation operation) {
- Arc arc = new Arc(from, to, operation);
- Arc opposite = arc.opposite();
- constraints.put(arc, pred);
- constraints.put(opposite, (x,y) -> pred.test(y, x)); // make constraint graph symmetric
- }
-
- /**
- * Retrieve resulting name of dimension after solving for constraints.
- */
- public Optional<String> dimensionNameOf(String name) {
- if (!renames.containsKey(name)) {
- return Optional.empty();
- }
- return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name)));
- }
-
- /**
- * Perform iterative arc consistency until we have found a solution. After
- * an initial iteration, the variables (dimensions) will have multiple
- * valid values. Find a single valid assignment by iteratively locking one
- * dimension after another, and running the arc consistency algorithm
- * multiple times.
- *
- * This requires having constraints that result in an absolute ordering:
- * equals, lesserThan and greaterThan do that, but adding notEquals does
- * not typically result in a guaranteed ordering. If that is needed, the
- * algorithm below needs to be adapted with a backtracking (tree) search
- * to find solutions.
- */
- public void solve(int maxIterations) {
- initialize();
-
- // Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts
-
- for (String dimension : variables.keySet()) {
- List<Integer> values = variables.get(dimension);
- if (values.size() > 1) {
- if (!ac3()) {
- throw new IllegalArgumentException("Dimension renamer unable to find a solution.");
- }
- values.sort(Integer::compare);
- variables.put(dimension, Collections.singletonList(values.get(0)));
- }
- renames.put(dimension, variables.get(dimension).get(0));
- if (iterations > maxIterations) {
- throw new IllegalArgumentException("Dimension renamer unable to find a solution within " +
- maxIterations + " iterations");
- }
- }
-
- // Todo: handle failure more gracefully:
- // If a solution can't be found, look at the operation node in the arc
- // with the most remaining constraints, and inject a rename operation.
- // Then run this algorithm again.
- }
-
- public void solve() {
- solve(100000);
- }
-
- private void initialize() {
- for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) {
- List<Integer> values = variable.getValue();
- for (int i = 0; i < variables.size(); ++i) {
- values.add(i); // invariant: values are in increasing order
- }
- }
- }
-
- private boolean ac3() {
- Deque<Arc> workList = new ArrayDeque<>(constraints.keySet());
- while (!workList.isEmpty()) {
- Arc arc = workList.pop();
- iterations += 1;
- if (revise(arc)) {
- if (variables.get(arc.from).size() == 0) {
- return false; // no solution found
- }
- for (Arc constraint : constraints.keySet()) {
- if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) {
- workList.add(constraint);
- }
- }
- }
- }
- return true;
- }
-
- private boolean revise(Arc arc) {
- boolean revised = false;
- for(Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) {
- Integer from = fromIterator.next();
- boolean satisfied = false;
- for (Iterator<Integer> toIterator = variables.get(arc.to).iterator(); toIterator.hasNext(); ) {
- Integer to = toIterator.next();
- if (constraints.get(arc).test(from, to)) {
- satisfied = true;
- }
- }
- if (!satisfied) {
- fromIterator.remove();
- revised = true;
- }
- }
- return revised;
- }
-
- public interface Constraint {
- boolean test(Integer x, Integer y);
- }
-
- public static boolean equals(Integer x, Integer y) {
- return Objects.equals(x, y);
- }
-
- public static boolean lesserThan(Integer x, Integer y) {
- return x < y;
- }
-
- public static boolean greaterThan(Integer x, Integer y) {
- return x > y;
- }
-
- private static class Arc {
-
- private final String from;
- private final String to;
- private final TensorFlowOperation operation;
-
- Arc(String from, String to, TensorFlowOperation operation) {
- this.from = from;
- this.to = to;
- this.operation = operation;
- }
-
- Arc opposite() {
- return new Arc(to, from, operation);
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(from, to);
- }
-
- @Override
- public boolean equals(Object obj) {
- if (obj == null || !(obj instanceof Arc)) {
- return false;
- }
- Arc other = (Arc) obj;
- return Objects.equals(from, other.from) && Objects.equals(to, other.to);
- }
-
- @Override
- public String toString() {
- return String.format("%s -> %s", from, to);
- }
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
deleted file mode 100644
index b665413a6b2..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
+++ /dev/null
@@ -1,97 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
-
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.ConcatV2;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Const;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.ExpandDims;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Identity;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Join;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Map;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Matmul;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Mean;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Merge;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.NoOp;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Placeholder;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.PlaceholderWithDefault;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Reshape;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Select;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Shape;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Squeeze;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Switch;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Variable;
-import com.yahoo.tensor.functions.ScalarFunctions;
-import org.tensorflow.framework.NodeDef;
-
-import java.util.List;
-
-/**
- * Maps from TensorFlow operations to Vespa operations.
- *
- * @author bratseth
- * @author lesters
- */
-public class OperationMapper {
-
- public static TensorFlowOperation get(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- switch (node.getOp().toLowerCase()) {
- // array ops
- case "concatv2": return new ConcatV2(modelName, node, inputs, port);
- case "const": return new Const(modelName, node, inputs, port);
- case "expanddims": return new ExpandDims(modelName, node, inputs, port);
- case "identity": return new Identity(modelName, node, inputs, port);
- case "placeholder": return new Placeholder(modelName, node, inputs, port);
- case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, node, inputs, port);
- case "reshape": return new Reshape(modelName, node, inputs, port);
- case "shape": return new Shape(modelName, node, inputs, port);
- case "squeeze": return new Squeeze(modelName, node, inputs, port);
-
- // control flow
- case "merge": return new Merge(modelName, node, inputs, port);
- case "switch": return new Switch(modelName, node, inputs, port);
-
- // math ops
- case "add": return new Join(modelName, node, inputs, port, ScalarFunctions.add());
- case "add_n": return new Join(modelName, node, inputs, port, ScalarFunctions.add());
- case "acos": return new Map(modelName, node, inputs, port, ScalarFunctions.acos());
- case "div": return new Join(modelName, node, inputs, port, ScalarFunctions.divide());
- case "realdiv": return new Join(modelName, node, inputs, port, ScalarFunctions.divide());
- case "floor": return new Map(modelName, node, inputs, port, ScalarFunctions.floor());
- case "matmul": return new Matmul(modelName, node, inputs, port);
- case "maximum": return new Join(modelName, node, inputs, port, ScalarFunctions.max());
- case "mean": return new Mean(modelName, node, inputs, port);
- case "reducemean": return new Mean(modelName, node, inputs, port);
- case "mul": return new Join(modelName, node, inputs, port, ScalarFunctions.multiply());
- case "multiply": return new Join(modelName, node, inputs, port, ScalarFunctions.multiply());
- case "rsqrt": return new Map(modelName, node, inputs, port, ScalarFunctions.rsqrt());
- case "select": return new Select(modelName, node, inputs, port);
- case "where3": return new Select(modelName, node, inputs, port);
- case "sigmoid": return new Map(modelName, node, inputs, port, ScalarFunctions.sigmoid());
- case "squareddifference": return new Join(modelName, node, inputs, port, ScalarFunctions.squareddifference());
- case "sub": return new Join(modelName, node, inputs, port, ScalarFunctions.subtract());
- case "subtract": return new Join(modelName, node, inputs, port, ScalarFunctions.subtract());
-
- // nn ops
- case "biasadd": return new Join(modelName, node, inputs, port, ScalarFunctions.add());
- case "elu": return new Map(modelName, node, inputs, port, ScalarFunctions.elu());
- case "relu": return new Map(modelName, node, inputs, port, ScalarFunctions.relu());
- case "selu": return new Map(modelName, node, inputs, port, ScalarFunctions.selu());
-
- // state ops
- case "variable": return new Variable(modelName, node, inputs, port);
- case "variablev2": return new Variable(modelName, node, inputs, port);
-
- // evaluation no-ops
- case "stopgradient":return new Identity(modelName, node, inputs, port);
- case "noop": return new NoOp(modelName, node, inputs, port);
- }
-
- TensorFlowOperation op = new NoOp(modelName, node, inputs, port);
- op.warning("Operation '" + node.getOp() + "' is currently not implemented");
- return op;
- }
-
-}
-
-
-
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
deleted file mode 100644
index 03a65333192..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
+++ /dev/null
@@ -1,255 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
-
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.TensorTypeParser;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.NodeDef;
-import org.tensorflow.framework.TensorShapeProto;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.Optional;
-import java.util.stream.Collectors;
-
-/**
- * A Vespa tensor type is ordered by the lexicographical ordering of dimension
- * names. TensorFlow tensors have an explicit ordering of their dimensions.
- * During import, we need to track the Vespa dimension that matches the
- * corresponding TensorFlow dimension as the ordering can change after
- * dimension renaming. That is the purpose of this class.
- *
- * @author lesters
- */
-public class OrderedTensorType {
-
- private final TensorType type;
- private final List<TensorType.Dimension> dimensions;
-
- private final long[] innerSizesTensorFlow;
- private final long[] innerSizesVespa;
- private final int[] dimensionMap;
-
- private OrderedTensorType(List<TensorType.Dimension> dimensions) {
- this.dimensions = Collections.unmodifiableList(dimensions);
- this.type = new TensorType.Builder(dimensions).build();
- this.innerSizesTensorFlow = new long[dimensions.size()];
- this.innerSizesVespa = new long[dimensions.size()];
- this.dimensionMap = createDimensionMap();
- }
-
- public TensorType type() {
- return this.type;
- }
-
- public int rank() { return dimensions.size(); }
-
- public List<TensorType.Dimension> dimensions() {
- return dimensions;
- }
-
- public List<String> dimensionNames() {
- return dimensions.stream().map(TensorType.Dimension::name).collect(Collectors.toList());
- }
-
- private int[] createDimensionMap() {
- int numDimensions = dimensions.size();
- if (numDimensions == 0) {
- return null;
- }
- innerSizesTensorFlow[numDimensions - 1] = 1;
- innerSizesVespa[numDimensions - 1] = 1;
- for (int i = numDimensions - 1; --i >= 0; ) {
- innerSizesTensorFlow[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesTensorFlow[i+1];
- innerSizesVespa[i] = type.dimensions().get(i+1).size().orElse(-1L) * innerSizesVespa[i+1];
- }
- int[] mapping = new int[numDimensions];
- for (int i = 0; i < numDimensions; ++i) {
- TensorType.Dimension dim1 = dimensions().get(i);
- for (int j = 0; j < numDimensions; ++j) {
- TensorType.Dimension dim2 = type.dimensions().get(j);
- if (dim1.equals(dim2)) {
- mapping[i] = j;
- break;
- }
- }
- }
- return mapping;
- }
-
- /**
- * When dimension ordering between Vespa and TensorFlow differs, i.e.
- * after dimension renaming, use the dimension map to read in values
- * so that they are correctly laid out in memory for Vespa.
- * Used when importing tensors from TensorFlow.
- */
- public int toDirectIndex(int index) {
- if (dimensions.size() == 0) {
- return 0;
- }
- if (dimensionMap == null) {
- throw new IllegalArgumentException("Dimension map is not available");
- }
- int directIndex = 0;
- long rest = index;
- for (int i = 0; i < dimensions.size(); ++i) {
- long address = rest / innerSizesTensorFlow[i];
- directIndex += innerSizesVespa[dimensionMap[i]] * address;
- rest %= innerSizesTensorFlow[i];
- }
- return directIndex;
- }
-
- @Override
- public boolean equals(Object obj) {
- if (obj == null || !(obj instanceof OrderedTensorType)) {
- return false;
- }
- OrderedTensorType other = (OrderedTensorType) obj;
- if (dimensions.size() != dimensions.size()) {
- return false;
- }
- List<TensorType.Dimension> thisDimensions = this.dimensions();
- List<TensorType.Dimension> otherDimensions = other.dimensions();
- for (int i = 0; i < thisDimensions.size(); ++i) {
- if (!thisDimensions.get(i).equals(otherDimensions.get(i))) {
- return false;
- }
- }
- return true;
- }
-
- public void verifyType(NodeDef node) {
- TensorShapeProto shape = tensorFlowShape(node);
- if (shape != null) {
- if (shape.getDimCount() != type.rank()) {
- throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " +
- "does not match Vespa shape");
- }
- for (int tensorFlowIndex = 0; tensorFlowIndex < dimensions.size(); ++tensorFlowIndex) {
- int vespaIndex = dimensionMap[tensorFlowIndex];
- TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex);
- TensorType.Dimension vespaDimension = type().dimensions().get(vespaIndex);
- if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) {
- throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " +
- "does not match Vespa dimensions");
- }
- }
- }
- }
-
- private static TensorShapeProto tensorFlowShape(NodeDef node) {
- AttrValue attrValueList = node.getAttrMap().get("_output_shapes");
- if (attrValueList == null) {
- throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
- "does not exist");
- }
- if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) {
- throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
- "is not of expected type");
- }
- List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList();
- return shapeList.get(0); // support multiple outputs?
- }
-
- public OrderedTensorType rename(DimensionRenamer renamer) {
- List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size());
- for (TensorType.Dimension dimension : dimensions) {
- String oldName = dimension.name();
- Optional<String> newName = renamer.dimensionNameOf(oldName);
- if (!newName.isPresent())
- return this; // presumably, already renamed
- TensorType.Dimension.Type dimensionType = dimension.type();
- if (dimensionType == TensorType.Dimension.Type.indexedBound) {
- renamedDimensions.add(TensorType.Dimension.indexed(newName.get(), dimension.size().get()));
- } else if (dimensionType == TensorType.Dimension.Type.indexedUnbound) {
- renamedDimensions.add(TensorType.Dimension.indexed(newName.get()));
- } else if (dimensionType == TensorType.Dimension.Type.mapped) {
- renamedDimensions.add(TensorType.Dimension.mapped(newName.get()));
- }
- }
- return new OrderedTensorType(renamedDimensions);
- }
-
- /**
- * Returns a string representation of this: A standard tensor type string where dimensions
- * are listed in the order of this rather than in the natural order of their names.
- */
- @Override
- public String toString() {
- return "tensor(" + dimensions.stream().map(TensorType.Dimension::toString).collect(Collectors.joining(",")) + ")";
- }
-
- /**
- * Creates an instance from the string representation of this: A standard tensor type string
- * where dimensions are listed in the order of this rather than the natural order of their names.
- */
- public static OrderedTensorType fromSpec(String typeSpec) {
- return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec(typeSpec));
- }
-
- public static OrderedTensorType fromTensorFlowType(NodeDef node) {
- return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ...
- }
-
- public static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) {
- Builder builder = new Builder(node);
- TensorShapeProto shape = tensorFlowShape(node);
- for (int i = 0; i < shape.getDimCount(); ++ i) {
- String dimensionName = dimensionPrefix + i;
- TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i);
- if (tensorFlowDimension.getSize() >= 0) {
- builder.add(TensorType.Dimension.indexed(dimensionName, tensorFlowDimension.getSize()));
- } else {
- builder.add(TensorType.Dimension.indexed(dimensionName));
- }
- }
- return builder.build();
- }
-
- public static class Builder {
-
- private final TensorShapeProto shape;
- private final List<TensorType.Dimension> dimensions;
-
- public Builder(NodeDef node) {
- this.shape = tensorFlowShape(node);
- this.dimensions = new ArrayList<>(shape.getDimCount());
- }
-
- public Builder add(TensorType.Dimension vespaDimension) {
- int index = dimensions.size();
- TensorShapeProto.Dim tensorFlowDimension = shape.getDim(index);
- long size = tensorFlowDimension.getSize();
- if (size >= 0) {
- if (vespaDimension.type() != TensorType.Dimension.Type.indexedBound) {
- throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " +
- "dimension types");
- }
- if (!vespaDimension.size().isPresent()) {
- throw new IllegalArgumentException("Tensor dimension is indexed bound but does " +
- "not have a size");
- }
- if (vespaDimension.size().get() != size) {
- throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " +
- "dimension sizes. TensorFlow: " + size + " Vespa: " +
- vespaDimension.size().get());
- }
- } else {
- if (vespaDimension.type() != TensorType.Dimension.Type.indexedUnbound) {
- throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " +
- "dimension types");
- }
- }
- this.dimensions.add(vespaDimension);
- return this;
- }
-
- public OrderedTensorType build() {
- return new OrderedTensorType(dimensions);
- }
-
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java
deleted file mode 100644
index 6cbfe0dfb05..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java
+++ /dev/null
@@ -1,145 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.Reduce;
-import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Optional;
-import java.util.function.DoubleBinaryOperator;
-
-public class Join extends TensorFlowOperation {
-
- private final DoubleBinaryOperator operator;
-
- public Join(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port, DoubleBinaryOperator operator) {
- super(modelName, node, inputs, port);
- this.operator = operator;
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- OrderedTensorType a = largestInput().type().get();
- OrderedTensorType b = smallestInput().type().get();
-
- // Well now we have potentially entered the wonderful world of "broadcasting"
- // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
- // In broadcasting, the size of each dimension is compared element-wise,
- // starting with the trailing dimensions and working forward. A special
- // case occurs when the size of one dimension is 1, while the other is not.
- // Then the dimension with size 1 is "stretched" to be of compatible size.
- //
- // An example:
- //
- // Tensor A: d0[5], d1[1], d2[3], d3[1]
- // Tensor B: d1[4], d2[1], d3[2]
- //
- // In TensorFlow and using the above rules of broadcasting, the resulting
- // type is:
- // d0[5], d1[4], d2[3], d2[2]
- //
- // However, in Vespa's tensor logic, the join of the two above tensors would
- // result in a tensor of type:
- // d0[5], d1[1], d2[1], d3[1]
- //
- // By reducing the dimensions of size 1 in each tensor before joining,
- // we get equal results as in TensorFlow.
-
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node);
- int sizeDifference = a.rank() - b.rank();
- for (int i = 0; i < a.rank(); ++i) {
- TensorType.Dimension aDim = a.dimensions().get(i);
- long size = aDim.size().orElse(-1L);
-
- if (i - sizeDifference >= 0) {
- TensorType.Dimension bDim = b.dimensions().get(i - sizeDifference);
- size = Math.max(size, bDim.size().orElse(-1L));
- }
-
- if (aDim.type() == TensorType.Dimension.Type.indexedBound) {
- builder.add(TensorType.Dimension.indexed(aDim.name(), size));
- } else if (aDim.type() == TensorType.Dimension.Type.indexedUnbound) {
- builder.add(TensorType.Dimension.indexed(aDim.name()));
- } else if (aDim.type() == TensorType.Dimension.Type.mapped) {
- builder.add(TensorType.Dimension.mapped(aDim.name()));
- }
- }
- return builder.build();
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- if (!allInputFunctionsPresent(2)) {
- return null;
- }
-
- TensorFlowOperation a = largestInput();
- TensorFlowOperation b = smallestInput();
-
- List<String> aDimensionsToReduce = new ArrayList<>();
- List<String> bDimensionsToReduce = new ArrayList<>();
- int sizeDifference = a.type().get().rank() - b.type().get().rank();
- for (int i = 0; i < b.type().get().rank(); ++i) {
- TensorType.Dimension bDim = b.type().get().dimensions().get(i);
- TensorType.Dimension aDim = a.type().get().dimensions().get(i + sizeDifference);
- long bSize = bDim.size().orElse(-1L);
- long aSize = aDim.size().orElse(-1L);
- if (bSize == 1L && aSize != 1L) {
- bDimensionsToReduce.add(bDim.name());
- }
- if (aSize == 1L && bSize != 1L) {
- aDimensionsToReduce.add(bDim.name());
- }
- }
-
- TensorFunction aReducedFunction = a.function().get();
- if (aDimensionsToReduce.size() > 0) {
- aReducedFunction = new Reduce(a.function().get(), Reduce.Aggregator.sum, aDimensionsToReduce);
- }
- TensorFunction bReducedFunction = b.function().get();
- if (bDimensionsToReduce.size() > 0) {
- bReducedFunction = new Reduce(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce);
- }
-
- return new com.yahoo.tensor.functions.Join(aReducedFunction, bReducedFunction, operator);
- }
-
- @Override
- public void addDimensionNameConstraints(DimensionRenamer renamer) {
- if (!allInputTypesPresent(2)) {
- return;
- }
- OrderedTensorType a = largestInput().type().get();
- OrderedTensorType b = smallestInput().type().get();
- int sizeDifference = a.rank() - b.rank();
- for (int i = 0; i < b.rank(); ++i) {
- String bDim = b.dimensions().get(i).name();
- String aDim = a.dimensions().get(i + sizeDifference).name();
- renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this);
- }
- }
-
- private TensorFlowOperation largestInput() {
- OrderedTensorType a = inputs.get(0).type().get();
- OrderedTensorType b = inputs.get(1).type().get();
- return a.rank() >= b.rank() ? inputs.get(0) : inputs.get(1);
- }
-
- private TensorFlowOperation smallestInput() {
- OrderedTensorType a = inputs.get(0).type().get();
- OrderedTensorType b = inputs.get(1).type().get();
- return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1);
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java
deleted file mode 100644
index b2b9530a161..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java
+++ /dev/null
@@ -1,74 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
-
-import java.util.List;
-import java.util.Optional;
-
-public class Matmul extends TensorFlowOperation {
-
- public Matmul(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node);
- typeBuilder.add(inputs.get(0).type().get().dimensions().get(0));
- typeBuilder.add(inputs.get(1).type().get().dimensions().get(1));
- return typeBuilder.build();
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- OrderedTensorType aType = inputs.get(0).type().get();
- OrderedTensorType bType = inputs.get(1).type().get();
- if (aType.type().rank() < 2 || bType.type().rank() < 2)
- throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
- if (aType.type().rank() != bType.type().rank())
- throw new IllegalArgumentException("Tensors in matmul must have the same rank");
-
- Optional<TensorFunction> aFunction = inputs.get(0).function();
- Optional<TensorFunction> bFunction = inputs.get(1).function();
- if (!aFunction.isPresent() || !bFunction.isPresent()) {
- return null;
- }
- return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name());
- }
-
- @Override
- public void addDimensionNameConstraints(DimensionRenamer renamer) {
- if (!allInputTypesPresent(2)) {
- return;
- }
- List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions();
- List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions();
-
- String aDim0 = aDimensions.get(0).name();
- String aDim1 = aDimensions.get(1).name();
- String bDim0 = bDimensions.get(0).name();
- String bDim1 = bDimensions.get(1).name();
-
- // The second dimension of a should have the same name as the first dimension of b
- renamer.addConstraint(aDim1, bDim0, DimensionRenamer::equals, this);
-
- // The first dimension of a should have a different name than the second dimension of b
- renamer.addConstraint(aDim0, bDim1, DimensionRenamer::lesserThan, this);
-
- // For efficiency, the dimensions to join over should be innermost - soft constraint
- renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this);
- renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this);
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java
deleted file mode 100644
index d558ec89e87..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
-
-import java.util.Collections;
-import java.util.List;
-
-public class NoOp extends TensorFlowOperation {
-
- public NoOp(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, Collections.emptyList(), port); // don't propagate inputs
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- return null;
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- return null;
- }
-
- @Override
- public boolean isConstant() {
- return true;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java
deleted file mode 100644
index b18a8a9b212..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java
+++ /dev/null
@@ -1,46 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.framework.NodeDef;
-
-import java.util.List;
-
-public class Variable extends TensorFlowOperation {
-
- public Variable(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
- super(modelName, node, inputs, port);
- }
-
- /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
- @Override
- public String vespaName() {
- return modelName() + "_" + super.vespaName();
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- return OrderedTensorType.fromTensorFlowType(node, super.vespaName() + "_");
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- return null; // will be added by function() since this is constant.
- }
-
- @Override
- public void addDimensionNameConstraints(DimensionRenamer renamer) {
- for (TensorType.Dimension dimension : type.type().dimensions()) {
- renamer.addDimension(dimension.name());
- }
- }
-
- @Override
- public boolean isConstant() {
- return true;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/package-info.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/package-info.java
deleted file mode 100644
index 9e53990a9d6..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/package-info.java
+++ /dev/null
@@ -1,8 +0,0 @@
-// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-/**
- * Tensorflow integration
- */
-@ExportPackage
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
-
-import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
index 0f5eec93feb..bf9684082f4 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import org.junit.Test;
@@ -15,7 +15,7 @@ public class BatchNormImportTestCase {
@Test
public void testBatchNormImport() {
TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/batch_norm/saved");
- TensorFlowModel.Signature signature = model.get().signature("serving_default");
+ ImportedModel.Signature signature = model.get().signature("serving_default");
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java
index 74b0d11f1d6..c8c7ec798bb 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java
@@ -1,6 +1,6 @@
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import org.junit.Test;
import static org.junit.Assert.assertTrue;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
index 50a467ec581..a63c7346335 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.TensorType;
@@ -24,7 +24,7 @@ public class DropoutImportTestCase {
assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
model.get().requiredMacros().get("X"));
- TensorFlowModel.Signature signature = model.get().signature("serving_default");
+ ImportedModel.Signature signature = model.get().signature("serving_default");
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
@@ -32,7 +32,7 @@ public class DropoutImportTestCase {
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
assertEquals("outputs/Maximum", output.getName());
- assertEquals("join(join(tf_macro_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), tf_macro_test_outputs_BiasAdd, f(a,b)(max(a,b)))",
+ assertEquals("join(join(imported_ml_macro_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_macro_test_outputs_BiasAdd, f(a,b)(max(a,b)))",
output.getRoot().toString());
model.assertEqualResult("X", output.getName());
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java
index 9f919c452d6..bd7644be23b 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
@@ -45,7 +45,7 @@ public class MnistSoftmaxImportTestCase {
// Check signatures
assertEquals(1, model.get().signatures().size());
- TensorFlowModel.Signature signature = model.get().signatures().get("serving_default");
+ ImportedModel.Signature signature = model.get().signatures().get("serving_default");
assertNotNull(signature);
// ... signature inputs
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
index 4b68cd40a08..a7926cd2e02 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
@@ -1,11 +1,9 @@
-package com.yahoo.searchlib.rankingexpression.integration.onnx;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
@@ -24,7 +22,7 @@ public class OnnxMnistSoftmaxImportTestCase {
@Test
public void testMnistSoftmaxImport() throws IOException {
- OnnxModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx");
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx");
// Check constants
assertEquals(2, model.largeConstants().size());
@@ -48,7 +46,7 @@ public class OnnxMnistSoftmaxImportTestCase {
model.requiredMacros().get("Placeholder"));
// Check outputs
- RankingExpression output = model.outputExpression("add");
+ RankingExpression output = model.defaultSignature().outputExpression("add");
assertNotNull(output);
assertEquals("add", output.getName());
assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))",
@@ -68,13 +66,12 @@ public class OnnxMnistSoftmaxImportTestCase {
}
private Tensor evaluateTensorFlowModel(String path, Tensor argument, String input, String output) {
- SavedModelBundle tensorFlowModel = SavedModelBundle.load(path, "serve");
- TensorFlowModel model = new TensorFlowImporter().importModel("test", tensorFlowModel);
+ ImportedModel model = new TensorFlowImporter().importModel("test", path);
return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input);
}
private Tensor evaluateOnnxModel(String path, Tensor argument, String input, String output) {
- OnnxModel model = new OnnxImporter().importModel("test", path);
+ ImportedModel model = new OnnxImporter().importModel("test", path);
return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input);
}
@@ -83,14 +80,7 @@ public class OnnxMnistSoftmaxImportTestCase {
return expression.evaluate(context).asTensor();
}
- private Context contextFrom(TensorFlowModel result) {
- MapContext context = new MapContext();
- result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
- result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
- return context;
- }
-
- private Context contextFrom(OnnxModel result) {
+ private Context contextFrom(ImportedModel result) {
MapContext context = new MapContext();
result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java
index beec2ab1ead..b2443082ab1 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java
@@ -1,6 +1,6 @@
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java
index 7ca16939477..723c5f27914 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java
@@ -1,11 +1,11 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
@@ -28,7 +28,7 @@ import static org.junit.Assert.assertEquals;
public class TestableTensorFlowModel {
private SavedModelBundle tensorFlowModel;
- private TensorFlowModel model;
+ private ImportedModel model;
// Sizes of the input vector
private final int d0Size = 1;
@@ -39,7 +39,7 @@ public class TestableTensorFlowModel {
model = new TensorFlowImporter().importModel(modelName, tensorFlowModel);
}
- public TensorFlowModel get() { return model; }
+ public ImportedModel get() { return model; }
public void assertEqualResult(String inputName, String operationName) {
Tensor tfResult = tensorFlowExecute(tensorFlowModel, inputName, operationName);
@@ -66,7 +66,7 @@ public class TestableTensorFlowModel {
return TensorConverter.toVespaTensor(results.get(0));
}
- private Context contextFrom(TensorFlowModel result) {
+ private Context contextFrom(ImportedModel result) {
MapContext context = new MapContext();
result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
@@ -81,7 +81,7 @@ public class TestableTensorFlowModel {
return b.build();
}
- private void evaluateMacro(Context context, TensorFlowModel model, String macroName) {
+ private void evaluateMacro(Context context, ImportedModel model, String macroName) {
if (!context.names().contains(macroName)) {
RankingExpression e = model.macros().get(macroName);
evaluateMacroDependencies(context, model, e.getRoot());
@@ -89,7 +89,7 @@ public class TestableTensorFlowModel {
}
}
- private void evaluateMacroDependencies(Context context, TensorFlowModel model, ExpressionNode node) {
+ private void evaluateMacroDependencies(Context context, ImportedModel model, ExpressionNode node) {
if (node instanceof ReferenceNode) {
String name = node.toString();
if (model.macros().containsKey(name)) {
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java
index 051c2c60c95..f94098e6255 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java
@@ -1,4 +1,4 @@
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.ml;
import org.junit.Test;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index 944755c9db2..3a66eef258d 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -22,22 +22,37 @@ public class ScalarFunctions {
public static DoubleBinaryOperator add() { return new Add(); }
public static DoubleBinaryOperator divide() { return new Divide(); }
public static DoubleBinaryOperator equal() { return new Equal(); }
+ public static DoubleBinaryOperator greater() { return new Greater(); }
+ public static DoubleBinaryOperator less() { return new Less(); }
public static DoubleBinaryOperator max() { return new Max(); }
public static DoubleBinaryOperator min() { return new Min(); }
+ public static DoubleBinaryOperator mean() { return new Mean(); }
public static DoubleBinaryOperator multiply() { return new Multiply(); }
+ public static DoubleBinaryOperator pow() { return new Pow(); }
public static DoubleBinaryOperator squareddifference() { return new SquaredDifference(); }
public static DoubleBinaryOperator subtract() { return new Subtract(); }
+ public static DoubleUnaryOperator abs() { return new Abs(); }
public static DoubleUnaryOperator acos() { return new Acos(); }
+ public static DoubleUnaryOperator asin() { return new Asin(); }
+ public static DoubleUnaryOperator atan() { return new Atan(); }
+ public static DoubleUnaryOperator ceil() { return new Ceil(); }
+ public static DoubleUnaryOperator cos() { return new Cos(); }
public static DoubleUnaryOperator elu() { return new Elu(); }
public static DoubleUnaryOperator exp() { return new Exp(); }
public static DoubleUnaryOperator floor() { return new Floor(); }
+ public static DoubleUnaryOperator log() { return new Log(); }
+ public static DoubleUnaryOperator neg() { return new Neg(); }
+ public static DoubleUnaryOperator reciprocal() { return new Reciprocal(); }
public static DoubleUnaryOperator relu() { return new Relu(); }
public static DoubleUnaryOperator rsqrt() { return new Rsqrt(); }
public static DoubleUnaryOperator selu() { return new Selu(); }
+ public static DoubleUnaryOperator sin() { return new Sin(); }
public static DoubleUnaryOperator sigmoid() { return new Sigmoid(); }
public static DoubleUnaryOperator sqrt() { return new Sqrt(); }
public static DoubleUnaryOperator square() { return new Square(); }
+ public static DoubleUnaryOperator tan() { return new Tan(); }
+ public static DoubleUnaryOperator tanh() { return new Tanh(); }
public static Function<List<Long>, Double> random() { return new Random(); }
public static Function<List<Long>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); }
@@ -59,6 +74,20 @@ public class ScalarFunctions {
public String toString() { return "f(a,b)(a==b)"; }
}
+ public static class Greater implements DoubleBinaryOperator {
+ @Override
+ public double applyAsDouble(double left, double right) { return left > right ? 1 : 0; }
+ @Override
+ public String toString() { return "f(a,b)(a > b)"; }
+ }
+
+ public static class Less implements DoubleBinaryOperator {
+ @Override
+ public double applyAsDouble(double left, double right) { return left < right ? 1 : 0; }
+ @Override
+ public String toString() { return "f(a,b)(a < b)"; }
+ }
+
public static class Max implements DoubleBinaryOperator {
@Override
public double applyAsDouble(double left, double right) { return Math.max(left, right); }
@@ -73,6 +102,13 @@ public class ScalarFunctions {
public String toString() { return "f(a,b)(min(a, b))"; }
}
+ public static class Mean implements DoubleBinaryOperator {
+ @Override
+ public double applyAsDouble(double left, double right) { return (left + right) / 2; }
+ @Override
+ public String toString() { return "f(a,b)((a + b) / 2)"; }
+ }
+
public static class Multiply implements DoubleBinaryOperator {
@Override
public double applyAsDouble(double left, double right) { return left * right; }
@@ -80,6 +116,13 @@ public class ScalarFunctions {
public String toString() { return "f(a,b)(a * b)"; }
}
+ public static class Pow implements DoubleBinaryOperator {
+ @Override
+ public double applyAsDouble(double left, double right) { return Math.pow(left, right); }
+ @Override
+ public String toString() { return "f(a,b)(pow(a, b))"; }
+ }
+
public static class Divide implements DoubleBinaryOperator {
@Override
public double applyAsDouble(double left, double right) { return left / right; }
@@ -104,6 +147,13 @@ public class ScalarFunctions {
// Unary operators ------------------------------------------------------------------------------
+ public static class Abs implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.abs(operand); }
+ @Override
+ public String toString() { return "f(a)(fabs(a))"; }
+ }
+
public static class Acos implements DoubleUnaryOperator {
@Override
public double applyAsDouble(double operand) { return Math.acos(operand); }
@@ -111,6 +161,34 @@ public class ScalarFunctions {
public String toString() { return "f(a)(acos(a))"; }
}
+ public static class Asin implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.asin(operand); }
+ @Override
+ public String toString() { return "f(a)(asin(a))"; }
+ }
+
+ public static class Atan implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.atan(operand); }
+ @Override
+ public String toString() { return "f(a)(atan(a))"; }
+ }
+
+ public static class Ceil implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.ceil(operand); }
+ @Override
+ public String toString() { return "f(a)(ceil(a))"; }
+ }
+
+ public static class Cos implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.cos(operand); }
+ @Override
+ public String toString() { return "f(a)(cos(a))"; }
+ }
+
public static class Elu implements DoubleUnaryOperator {
@Override
public double applyAsDouble(double operand) { return operand < 0 ? Math.exp(operand) -1 : operand; }
@@ -132,6 +210,26 @@ public class ScalarFunctions {
public String toString() { return "f(a)(floor(a))"; }
}
+ public static class Log implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.log(operand); }
+ @Override
+ public String toString() { return "f(a)(log(a))"; }
+ }
+
+ public static class Neg implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return -operand; }
+ @Override
+ public String toString() { return "f(a)(-a)"; }
+ }
+
+ public static class Reciprocal implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return 1.0 / operand; }
+ @Override
+ public String toString() { return "f(a)(1 / a)"; }
+ }
public static class Relu implements DoubleUnaryOperator {
@Override
@@ -150,6 +248,13 @@ public class ScalarFunctions {
public String toString() { return String.format("f(a)(%f * if(a >= 0, a, %f*(exp(a)-1)))", scale, alpha); }
}
+ public static class Sin implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.sin(operand); }
+ @Override
+ public String toString() { return "f(a)(sin(a))"; }
+ }
+
public static class Rsqrt implements DoubleUnaryOperator {
@Override
public double applyAsDouble(double operand) { return 1.0 / Math.sqrt(operand); }
@@ -172,15 +277,29 @@ public class ScalarFunctions {
}
public static class Square implements DoubleUnaryOperator {
-
@Override
public double applyAsDouble(double operand) { return operand * operand; }
-
@Override
public String toString() { return "f(a)(a * a)"; }
+ }
+ public static class Tan implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.tan(operand); }
+ @Override
+ public String toString() { return "f(a)(tan(a))"; }
}
+ public static class Tanh implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.tanh(operand); }
+ @Override
+ public String toString() { return "f(a)(tanh(a))"; }
+ }
+
+
+
+
// Variable-length operators -----------------------------------------------------------------------------
public static class EqualElements implements Function<List<Long>, Double> {