summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author: Håkon Hallingstad <hakon@oath.com> 2018-06-06 14:25:23 +0200
committer: GitHub <noreply@github.com> 2018-06-06 14:25:23 +0200
commit: 62ae46a58d9501ad60431634f374b3cfa2856a48 (patch)
tree: 12ea280192a44f26b9718018c7cfb39b0c4c4735
parent: 240176d60c44507f4e6733c7512620e80554c8de (diff)
parent: 681963959794b47102d1a1cf72f215c72b0e2b51 (diff)
Merge pull request #6106 from vespa-engine/revert-6046-lesters/refactor-onnx-tensorflow-import
Revert "Refactor ONNX and TF import to use same code base"
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java674
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java636
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java677
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java22
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java26
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java242
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java30
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java47
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java107
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java216
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java52
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java26
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java85
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java234
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java72
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java326
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java112
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/DimensionRenamer.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamer.java)10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java26
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OrderedTensorType.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorType.java)154
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/TensorConverter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java)6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java64
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java)31
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Join.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java)22
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/MatMul.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java)15
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java32
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java139
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/package-info.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java)2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java411
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java)101
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java)9
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java210
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java97
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java255
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java)3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ConcatV2.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java)31
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java)53
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java)21
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java)13
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java145
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java)11
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java74
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java)29
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java)15
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java32
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java)19
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java)11
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java)24
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java)19
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java)13
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java)33
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java)22
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java)118
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java46
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/package-info.java8
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java)22
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java)4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java)4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java)6
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java)4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java)4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java)14
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java (renamed from searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java)2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java123
64 files changed, 3726 insertions, 2365 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java
deleted file mode 100644
index effa261be3b..00000000000
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java
+++ /dev/null
@@ -1,674 +0,0 @@
-package com.yahoo.searchdefinition.expressiontransforms;
-
-import com.google.common.base.Joiner;
-import com.yahoo.collections.Pair;
-import com.yahoo.config.application.api.ApplicationFile;
-import com.yahoo.config.application.api.ApplicationPackage;
-import com.yahoo.config.model.application.provider.FilesApplicationPackage;
-import com.yahoo.io.IOUtils;
-import com.yahoo.path.Path;
-import com.yahoo.search.query.profile.QueryProfileRegistry;
-import com.yahoo.searchdefinition.FeatureNames;
-import com.yahoo.searchdefinition.RankProfile;
-import com.yahoo.searchdefinition.RankingConstant;
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.Reference;
-import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
-import com.yahoo.searchlib.rankingexpression.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.rule.Arguments;
-import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
-import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
-import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
-import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
-import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
-import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.TypeContext;
-import com.yahoo.tensor.functions.Generate;
-import com.yahoo.tensor.functions.Join;
-import com.yahoo.tensor.functions.Reduce;
-import com.yahoo.tensor.functions.Rename;
-import com.yahoo.tensor.functions.ScalarFunctions;
-import com.yahoo.tensor.functions.TensorFunction;
-import com.yahoo.tensor.serialization.TypedBinaryFormat;
-
-import java.io.BufferedReader;
-import java.io.File;
-import java.io.IOException;
-import java.io.StringReader;
-import java.io.UncheckedIOException;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Optional;
-import java.util.Set;
-import java.util.stream.Collectors;
-
-/**
- * Base class for replacing instances of a pseudofeature for imported ML
- * ranking models with native Vespa ranking expressions.
- *
- * @author bratseth
- * @author lesters
- */
-abstract class MLImportFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
-
- ExpressionNode transformFromImportedModel(ImportedModel model,
- ModelStore store,
- RankProfile profile,
- QueryProfileRegistry queryProfiles) {
- // Add constants
- Set<String> constantsReplacedByMacros = new HashSet<>();
- model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
- model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles,
- constantsReplacedByMacros, k, v));
-
- // Find the specified expression
- ImportedModel.Signature signature = chooseSignature(model, store.arguments().signature());
- String output = chooseOutput(signature, store.arguments().output());
- if (signature.skippedOutputs().containsKey(output)) {
- String message = "Could not import model output '" + output + "'";
- if (!signature.skippedOutputs().get(output).isEmpty()) {
- message += ": " + signature.skippedOutputs().get(output);
- }
- if (!signature.importWarnings().isEmpty()) {
- message += ": " + String.join(", ", signature.importWarnings());
- }
- throw new IllegalArgumentException(message);
- }
-
- RankingExpression expression = model.expressions().get(output);
- expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
- verifyRequiredMacros(expression, model, profile, queryProfiles);
- addGeneratedMacros(model, profile);
- reduceBatchDimensions(expression, model, profile, queryProfiles);
-
- model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v));
-
- store.writeConverted(expression);
- return expression.getRoot();
- }
-
- ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) {
- for (Pair<String, Tensor> constant : store.readSmallConstants())
- profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
-
- for (RankingConstant constant : store.readLargeConstants()) {
- if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName()))
- profile.getSearch().addRankingConstant(constant);
- }
-
- for (Pair<String, RankingExpression> macro : store.readMacros()) {
- addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond());
- }
-
- return store.readConverted().getRoot();
- }
-
- /**
- * Returns the specified, existing signature, or the only signature if none is specified.
- * Throws IllegalArgumentException in all other cases.
- */
- private ImportedModel.Signature chooseSignature(ImportedModel importResult, Optional<String> signatureName) {
- if ( ! signatureName.isPresent()) {
- if (importResult.signatures().size() == 0)
- throw new IllegalArgumentException("No signatures are available");
- if (importResult.signatures().size() > 1)
- throw new IllegalArgumentException("Model has multiple signatures (" +
- Joiner.on(", ").join(importResult.signatures().keySet()) +
- "), one must be specified " +
- "as a second argument to tensorflow()");
- return importResult.signatures().values().stream().findFirst().get();
- }
- else {
- ImportedModel.Signature signature = importResult.signatures().get(signatureName.get());
- if (signature == null)
- throw new IllegalArgumentException("Model does not have the specified signature '" +
- signatureName.get() + "'");
- return signature;
- }
- }
-
- /**
- * Returns the specified, existing output expression, or the only output expression if no output name is specified.
- * Throws IllegalArgumentException in all other cases.
- */
- private String chooseOutput(ImportedModel.Signature signature, Optional<String> outputName) {
- if ( ! outputName.isPresent()) {
- if (signature.outputs().size() == 0)
- throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(signature));
- if (signature.outputs().size() > 1)
- throw new IllegalArgumentException(signature + " has multiple outputs (" +
- Joiner.on(", ").join(signature.outputs().keySet()) +
- "), one must be specified " +
- "as a third argument to tensorflow()");
- return signature.outputs().get(signature.outputs().keySet().stream().findFirst().get());
- }
- else {
- String output = signature.outputs().get(outputName.get());
- if (output == null) {
- if (signature.skippedOutputs().containsKey(outputName.get()))
- throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " +
- signature.skippedOutputs().get(outputName.get()));
- else
- throw new IllegalArgumentException("Model does not have the specified output '" +
- outputName.get() + "'");
- }
- return output;
- }
- }
-
- private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
- store.writeSmallConstant(constantName, constantValue);
- profile.addConstant(constantName, asValue(constantValue));
- }
-
- private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles,
- Set<String> constantsReplacedByMacros,
- String constantName, Tensor constantValue) {
- RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName);
- if (macroOverridingConstant != null) {
- TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles));
- if ( ! macroType.equals(constantValue.type()))
- throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " +
- typeMismatchExplanation(constantValue.type(), macroType));
- constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later
- }
- else {
- Path constantPath = store.writeLargeConstant(constantName, constantValue);
- if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) {
- profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(),
- constantPath.toString()));
- }
- }
- }
-
- private void transformGeneratedMacro(ModelStore store,
- Set<String> constantsReplacedByMacros,
- String macroName, RankingExpression expression) {
-
- expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
- store.writeMacro(macroName, expression);
- }
-
- private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
- if (profile.getMacros().containsKey(macroName)) {
- throw new IllegalArgumentException("Generated macro '" + macroName + "' already exists.");
- }
- profile.addMacro(macroName, false); // todo: inline if only used once
- RankProfile.Macro macro = profile.getMacros().get(macroName);
- macro.setRankingExpression(expression);
- macro.setTextualExpression(expression.getRoot().toString());
- }
-
- private String skippedOutputsDescription(ImportedModel.Signature signature) {
- if (signature.skippedOutputs().isEmpty()) return "";
- StringBuilder b = new StringBuilder(": ");
- signature.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v));
- return b.toString();
- }
-
- /**
- * Verify that the macros referred in the given expression exists in the given rank profile,
- * and return tensors of the types specified in requiredMacros.
- */
- private void verifyRequiredMacros(RankingExpression expression, ImportedModel model,
- RankProfile profile, QueryProfileRegistry queryProfiles) {
- Set<String> macroNames = new HashSet<>();
- addMacroNamesIn(expression.getRoot(), macroNames, model);
- for (String macroName : macroNames) {
- TensorType requiredType = model.requiredMacros().get(macroName);
- if (requiredType == null) continue; // Not a required macro
-
- RankProfile.Macro macro = profile.getMacros().get(macroName);
- if (macro == null)
- throw new IllegalArgumentException("Model refers input '" + macroName +
- "' of type " + requiredType + " but this macro is not present in " +
- profile);
- // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second
- // phase and summary features), as it may only resolve correctly given those bindings
- // Or, probably better, annotate the macros with type constraints here and verify during general
- // type verification
- TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles));
- if ( actualType == null)
- throw new IllegalArgumentException("Model refers input '" + macroName +
- "' of type " + requiredType +
- " which must be produced by a macro in the rank profile, but " +
- "this macro references a feature which is not declared");
- if ( ! actualType.isAssignableTo(requiredType))
- throw new IllegalArgumentException("Model refers input '" + macroName + "'. " +
- typeMismatchExplanation(requiredType, actualType));
- }
- }
-
- private String typeMismatchExplanation(TensorType requiredType, TensorType actualType) {
- return "The required type of this is " + requiredType + ", but this macro returns " + actualType +
- (actualType.rank() == 0 ? ". This is often due to missing declaration of query tensor features " +
- "in query profile types - see the documentation."
- : "");
- }
-
- /**
- * Add the generated macros to the rank profile
- */
- private void addGeneratedMacros(ImportedModel model, RankProfile profile) {
- model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v));
- }
-
- /**
- * Check if batch dimensions of inputs can be reduced out. If the input
- * macro specifies that a single exemplar should be evaluated, we can
- * reduce the batch dimension out.
- */
- private void reduceBatchDimensions(RankingExpression expression, ImportedModel model,
- RankProfile profile, QueryProfileRegistry queryProfiles) {
- TypeContext<Reference> typeContext = profile.typeContext(queryProfiles);
- TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
-
- // Check generated macros for inputs to reduce
- Set<String> macroNames = new HashSet<>();
- addMacroNamesIn(expression.getRoot(), macroNames, model);
- for (String macroName : macroNames) {
- if ( ! model.macros().containsKey(macroName)) {
- continue;
- }
- RankProfile.Macro macro = profile.getMacros().get(macroName);
- if (macro == null) {
- throw new IllegalArgumentException("Model refers to generated macro '" + macroName +
- "but this macro is not present in " + profile);
- }
- RankingExpression macroExpression = macro.getRankingExpression();
- macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext));
- }
-
- // Check expression for inputs to reduce
- ExpressionNode root = expression.getRoot();
- root = reduceBatchDimensionsAtInput(root, model, typeContext);
- TensorType typeAfterReducing = root.type(typeContext);
- root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing);
- expression.setRoot(root);
- }
-
- private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedModel model,
- TypeContext<Reference> typeContext) {
- if (node instanceof TensorFunctionNode) {
- TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
- if (tensorFunction instanceof Rename) {
- List<ExpressionNode> children = ((TensorFunctionNode)node).children();
- if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode) children.get(0);
- if (model.requiredMacros().containsKey(referenceNode.getName())) {
- return reduceBatchDimensionExpression(tensorFunction, typeContext);
- }
- }
- }
- }
- if (node instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode) node;
- if (model.requiredMacros().containsKey(referenceNode.getName())) {
- return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext);
- }
- }
- if (node instanceof CompositeNode) {
- List<ExpressionNode> children = ((CompositeNode)node).children();
- List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
- for (ExpressionNode child : children) {
- transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext));
- }
- return ((CompositeNode)node).setChildren(transformedChildren);
- }
- return node;
- }
-
- private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
- TensorFunction result = function;
- TensorType type = function.type(context);
- if (type.dimensions().size() > 1) {
- List<String> reduceDimensions = new ArrayList<>();
- for (TensorType.Dimension dimension : type.dimensions()) {
- if (dimension.size().orElse(-1L) == 1) {
- reduceDimensions.add(dimension.name());
- }
- }
- if (reduceDimensions.size() > 0) {
- result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions);
- }
- }
- return new TensorFunctionNode(result);
- }
-
- /**
- * If batch dimensions have been reduced away above, bring them back here
- * for any following computation of the tensor.
- * Todo: determine when this is not necessary!
- */
- private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
- if (after.equals(before)) {
- return node;
- }
- TensorType.Builder typeBuilder = new TensorType.Builder();
- for (TensorType.Dimension dimension : before.dimensions()) {
- if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) {
- typeBuilder.indexed(dimension.name(), 1);
- }
- }
- TensorType expandDimensionsType = typeBuilder.build();
- if (expandDimensionsType.dimensions().size() > 0) {
- ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0));
- Generate generatedFunction = new Generate(expandDimensionsType,
- new GeneratorLambdaFunctionNode(expandDimensionsType,
- generatedExpression)
- .asLongListToDoubleOperator());
- Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply());
- return new TensorFunctionNode(expand);
- }
- return node;
- }
-
- /**
- * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions.
- * This method does that for the given expression and returns the result.
- */
- private RankingExpression replaceConstantsByMacros(RankingExpression expression,
- Set<String> constantsReplacedByMacros) {
- if (constantsReplacedByMacros.isEmpty()) return expression;
- return new RankingExpression(expression.getName(),
- replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros));
- }
-
- private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) {
- if (node instanceof ReferenceNode) {
- Reference reference = ((ReferenceNode)node).reference();
- if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) {
- String argument = reference.simpleArgument().get();
- if (constantsReplacedByMacros.contains(argument))
- return new ReferenceNode(argument);
- }
- }
- if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above
- CompositeNode composite = (CompositeNode)node;
- return composite.setChildren(composite.children().stream()
- .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros))
- .collect(Collectors.toList()));
- }
- return node;
- }
-
- private void addMacroNamesIn(ExpressionNode node, Set<String> names, ImportedModel model) {
- if (node instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode)node;
- if (referenceNode.getOutput() == null) { // macro references cannot specify outputs
- names.add(referenceNode.getName());
- if (model.macros().containsKey(referenceNode.getName())) {
- addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model);
- }
- }
- }
- else if (node instanceof CompositeNode) {
- for (ExpressionNode child : ((CompositeNode)node).children())
- addMacroNamesIn(child, names, model);
- }
- }
-
- private Value asValue(Tensor tensor) {
- if (tensor.type().rank() == 0)
- return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors
- else
- return new TensorValue(tensor);
- }
-
- /**
-  * Provides read/write access to the correct directories of the application package given by the feature arguments
-  */
- static class ModelStore {
-
-     private final ApplicationPackage application;
-     private final FeatureArguments arguments;
-
-     ModelStore(ApplicationPackage application, FeatureArguments arguments) {
-         this.application = application;
-         this.arguments = arguments;
-     }
-
-     public FeatureArguments arguments() { return arguments; }
-
-     /** Returns whether a converted expression for these arguments exists in the application package. */
-     public boolean hasStoredModel() {
-         try {
-             return application.getFile(arguments.expressionPath()).exists();
-         }
-         catch (UnsupportedOperationException e) {
-             // Treated as "nothing stored" when the application package does not support this lookup
-             return false;
-         }
-     }
-
-     /**
-      * Returns the directory which contains the source model to use for these arguments
-      */
-     public File modelDir() {
-         return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath()));
-     }
-
-     /**
-      * Adds this expression to the application package, such that it can be read later.
-      */
-     void writeConverted(RankingExpression expression) {
-         application.getFile(arguments.expressionPath())
-                    .writeFile(new StringReader(expression.getRoot().toString()));
-     }
-
-     /** Reads the previously stored ranking expression for these arguments */
-     RankingExpression readConverted() {
-         try {
-             return new RankingExpression(application.getFile(arguments.expressionPath()).createReader());
-         }
-         catch (IOException e) {
-             throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e);
-         }
-         catch (ParseException e) {
-             throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
-         }
-     }
-
-     /** Adds this macro expression to the application package so it can be read later. */
-     void writeMacro(String name, RankingExpression expression) {
-         application.getFile(arguments.macrosPath()).appendFile(name + "\t" +
-                                                                expression.getRoot().toString() + "\n");
-     }
-
-     /**
-      * Reads the previously stored macro expressions for these arguments.
-      * Each line of the macros file is "name TAB expression".
-      *
-      * @throws IllegalStateException if a line is malformed or its expression cannot be parsed
-      * @throws UncheckedIOException if the file cannot be read
-      */
-     List<Pair<String, RankingExpression>> readMacros() {
-         ApplicationFile file = application.getFile(arguments.macrosPath());
-         if ( ! file.exists()) return Collections.emptyList();
-
-         List<Pair<String, RankingExpression>> macros = new ArrayList<>();
-         // try-with-resources: the reader was previously never closed
-         try (BufferedReader reader = new BufferedReader(file.createReader())) {
-             String line;
-             while (null != (line = reader.readLine())) {
-                 String[] parts = line.split("\t", 2); // limit 2: keep the expression intact
-                 if (parts.length < 2)
-                     throw new IllegalStateException("Malformed macro line in " + arguments.macrosPath() +
-                                                     ": '" + line + "'");
-                 try {
-                     macros.add(new Pair<>(parts[0], new RankingExpression(parts[1])));
-                 }
-                 catch (ParseException e) {
-                     // Report the file which was actually read
-                     throw new IllegalStateException("Could not parse " + arguments.macrosPath(), e);
-                 }
-             }
-             return macros;
-         }
-         catch (IOException e) {
-             throw new UncheckedIOException(e);
-         }
-     }
-
-     /**
-      * Reads the information about all the large (aka ranking) constants stored in the application package
-      * (the constant value itself is replicated with file distribution).
-      * Each constant file contains "name:type:path".
-      */
-     List<RankingConstant> readLargeConstants() {
-         try {
-             List<RankingConstant> constants = new ArrayList<>();
-             for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) {
-                 String content = IOUtils.readAll(constantFile.createReader());
-                 String[] parts = content.split(":", 3); // limit 3: the path itself may contain ':'
-                 if (parts.length < 3)
-                     throw new IllegalStateException("Malformed large constant file in " +
-                                                     arguments.largeConstantsPath() + ": '" + content + "'");
-                 constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2]));
-             }
-             return constants;
-         }
-         catch (IOException e) {
-             throw new UncheckedIOException(e);
-         }
-     }
-
-     /**
-      * Adds this constant to the application package as a file,
-      * such that it can be distributed using file distribution.
-      *
-      * @return the path to the stored constant, relative to the application package root
-      */
-     Path writeLargeConstant(String name, Tensor constant) {
-         Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants");
-
-         // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file:
-         Path constantPath = constantsPath.append(name + ".tbf");
-
-         // Remember the constant in a file we replicate in ZooKeeper
-         application.getFile(arguments.largeConstantsPath().append(name + ".constant"))
-                    .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath)));
-
-         // Write content explicitly as a file on the file system as this is distributed using file distribution
-         createIfNeeded(constantsPath);
-         IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant));
-         return correct(constantPath);
-     }
-
-     /**
-      * Reads the small constants stored for these arguments.
-      * Each line of the constants file is "name TAB type TAB value".
-      *
-      * @throws IllegalStateException if a line is malformed
-      * @throws UncheckedIOException if the file cannot be read
-      */
-     private List<Pair<String, Tensor>> readSmallConstants() {
-         ApplicationFile file = application.getFile(arguments.smallConstantsPath());
-         if ( ! file.exists()) return Collections.emptyList();
-
-         List<Pair<String, Tensor>> constants = new ArrayList<>();
-         // try-with-resources: the reader was previously never closed
-         try (BufferedReader reader = new BufferedReader(file.createReader())) {
-             String line;
-             while (null != (line = reader.readLine())) {
-                 String[] parts = line.split("\t", 3);
-                 if (parts.length < 3)
-                     throw new IllegalStateException("Malformed constant line in " + arguments.smallConstantsPath() +
-                                                     ": '" + line + "'");
-                 TensorType type = TensorType.fromSpec(parts[1]);
-                 constants.add(new Pair<>(parts[0], Tensor.from(type, parts[2])));
-             }
-             return constants;
-         }
-         catch (IOException e) {
-             throw new UncheckedIOException(e);
-         }
-     }
-
-     /**
-      * Append this constant to the single file used for small constants distributed as config
-      */
-     public void writeSmallConstant(String name, Tensor constant) {
-         // Secret file format for remembering constants:
-         application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" +
-                                                                        constant.type().toString() + "\t" +
-                                                                        constant.toString() + "\n");
-     }
-
-     /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */
-     private Path correct(Path path) {
-         if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed)
-             && ! path.elements().contains(FilesApplicationPackage.preprocessed)) {
-             return Path.fromString(FilesApplicationPackage.preprocessed).append(path);
-         }
-         else {
-             return path;
-         }
-     }
-
-     /** Creates the directory at the given path, including parents, unless it already exists. */
-     private void createIfNeeded(Path path) {
-         File dir = application.getFileReference(path);
-         if ( ! dir.exists()) {
-             if (!dir.mkdirs())
-                 throw new IllegalStateException("Could not create " + dir);
-         }
-     }
-
- }
-
- /** Encapsulates the arguments to the import feature */
- static abstract class FeatureArguments {
-
-     Path modelPath;
-
-     /** Optional arguments */
-     Optional<String> signature, output;
-
-     /** Returns modelPath with slashes and dots replaced by underscores */
-     public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); }
-
-     /** Returns relative path to this model below the "models/" dir in the application package */
-     public Path modelPath() { return modelPath; }
-     public Optional<String> signature() { return signature; }
-     public Optional<String> output() { return output; }
-
-     /** Path to the small constants file */
-     public Path smallConstantsPath() {
-         return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt");
-     }
-
-     /** Path to the large (ranking) constants directory */
-     public Path largeConstantsPath() {
-         return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants");
-     }
-
-     /** Path to the macros file */
-     public Path macrosPath() {
-         return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt");
-     }
-
-     /** Path to the file holding the converted expression for this signature and output */
-     public Path expressionPath() {
-         return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR
-                 .append(modelPath).append("expressions").append(expressionFileName());
-     }
-
-     /** Returns "[signature.][output.]expression", or "single.expression" when neither is given */
-     private String expressionFileName() {
-         StringBuilder fileName = new StringBuilder();
-         signature.ifPresent(s -> fileName.append(s).append("."));
-         output.ifPresent(s -> fileName.append(s).append("."));
-         if (fileName.length() == 0) // single signature and output
-             fileName.append("single.");
-         fileName.append("expression");
-         return fileName.toString();
-     }
-
-     /** Returns the argument at the given index as a string, or empty if there are not that many arguments */
-     Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
-         if (argumentIndex >= arguments.expressions().size())
-             return Optional.empty();
-         return Optional.of(asString(arguments.expressions().get(argumentIndex)));
-     }
-
-     /**
-      * Returns the source string of the given constant node with any enclosing quotes removed.
-      *
-      * @throws IllegalArgumentException if the node is not a constant, or has a start quote but no end quote
-      */
-     String asString(ExpressionNode node) {
-         if ( ! (node instanceof ConstantNode))
-             throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node + "'");
-         return stripQuotes(((ConstantNode)node).sourceString());
-     }
-
-     /**
-      * Removes one pair of enclosing quotes. Strings which do not start with a quote, including
-      * the empty string, are returned unchanged.
-      */
-     private String stripQuotes(String s) {
-         if (s.isEmpty() || ! isQuoteSign(s.codePointAt(0))) return s;
-         // A lone quote character has a start quote but no end quote
-         if (s.length() < 2 || ! isQuoteSign(s.codePointAt(s.length() - 1)))
-             throw new IllegalArgumentException("argument [" + s + "] is missing endquote");
-         return s.substring(1, s.length() - 1);
-     }
-
-     private boolean isQuoteSign(int c) {
-         return c == '\'' || c == '"';
-     }
-
- }
-}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
index 44eeb364603..1c41ad8284e 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
@@ -2,20 +2,58 @@
package com.yahoo.searchdefinition.expressiontransforms;
+import com.google.common.base.Joiner;
+import com.yahoo.collections.Pair;
+import com.yahoo.config.application.api.ApplicationFile;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.model.application.provider.FilesApplicationPackage;
+import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.searchdefinition.FeatureNames;
import com.yahoo.searchdefinition.RankProfile;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
-import com.yahoo.searchlib.rankingexpression.integration.ml.OnnxImporter;
+import com.yahoo.searchdefinition.RankingConstant;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxImporter;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxModel;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.Join;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.tensor.serialization.TypedBinaryFormat;
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
+import java.io.StringReader;
import java.io.UncheckedIOException;
+import java.util.ArrayList;
+import java.util.Collections;
import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
/**
* Replaces instances of the onnx(model-path, output)
@@ -25,12 +63,12 @@ import java.util.Optional;
* @author bratseth
* @author lesters
*/
-public class OnnxFeatureConverter extends MLImportFeatureConverter {
+public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
private final OnnxImporter onnxImporter = new OnnxImporter();
/** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
- private final Map<Path, ImportedModel> importedModels = new HashMap<>();
+ private final Map<Path, OnnxModel> importedModels = new HashMap<>();
@Override
public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
@@ -46,8 +84,7 @@ public class OnnxFeatureConverter extends MLImportFeatureConverter {
if ( ! feature.getName().equals("onnx")) return feature;
try {
- FeatureArguments arguments = new OnnxFeatureArguments(feature.getArguments());
- ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments);
+ ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments());
if ( ! store.hasStoredModel()) // not converted yet - access Onnx model files
return transformFromOnnxModel(store, context.rankProfile(), context.queryProfiles());
else
@@ -61,24 +98,597 @@ public class OnnxFeatureConverter extends MLImportFeatureConverter {
+ /**
+  * Imports the Onnx model given by the store arguments (reusing an earlier import of the same model path),
+  * registers its constants and generated macros in the rank profile, writes the converted
+  * macros and expression to the store, and returns the root of the chosen output expression.
+  *
+  * @throws IllegalArgumentException if the chosen output was skipped during import
+  */
private ExpressionNode transformFromOnnxModel(ModelStore store,
RankProfile profile,
QueryProfileRegistry queryProfiles) {
- ImportedModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
+ OnnxModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
k -> onnxImporter.importModel(store.arguments().modelName(),
- store.modelDir()));
- return transformFromImportedModel(model, store, profile, queryProfiles);
+ store.onnxModelDir()));
+
+ // Add constants
+ // Large constants which are overridden by a macro in the profile are collected in this set
+ Set<String> constantsReplacedByMacros = new HashSet<>();
+ model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
+ model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles,
+ constantsReplacedByMacros, k, v));
+
+ // Find the specified expression
+ String output = chooseOutput(model, store.arguments().output());
+ if (model.skippedOutputs().containsKey(output)) {
+ // Build a message from the skip reason and any import warnings, when present
+ String message = "Could not import Onnx model output '" + output + "'";
+ if (!model.skippedOutputs().get(output).isEmpty()) {
+ message += ": " + model.skippedOutputs().get(output);
+ }
+ if (!model.importWarnings().isEmpty()) {
+ message += ": " + String.join(", ", model.importWarnings());
+ }
+ throw new IllegalArgumentException(message);
+ }
+
+ RankingExpression expression = model.expressions().get(output);
+ expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
+ verifyRequiredMacros(expression, model, profile, queryProfiles);
+ addGeneratedMacros(model, profile);
+ reduceBatchDimensions(expression, model, profile, queryProfiles);
+
+ // Persist the generated macros so they can be read back by transformFromStoredModel
+ model.macros().forEach((k, v) -> transformGeneratedMacro(store, profile, constantsReplacedByMacros, k, v));
+
+ store.writeConverted(expression);
+ return expression.getRoot();
+ }
+
+ /**
+  * Rebuilds the rank profile state from a previously converted model: adds the stored small constants,
+  * large constants and macros to the profile and returns the root of the stored expression.
+  */
+ private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) {
+     store.readSmallConstants()
+          .forEach(small -> profile.addConstant(small.getFirst(), asValue(small.getSecond())));
+
+     for (RankingConstant large : store.readLargeConstants()) {
+         boolean alreadyRegistered = profile.getSearch().getRankingConstants().containsKey(large.getName());
+         if ( ! alreadyRegistered)
+             profile.getSearch().addRankingConstant(large);
+     }
+
+     store.readMacros()
+          .forEach(stored -> addGeneratedMacroToProfile(profile, stored.getFirst(), stored.getSecond()));
+
+     return store.readConverted().getRoot();
+ }
+
+ /**
+  * Picks the output to import: the named output when a name is given, otherwise the only output of the model.
+  *
+  * @throws IllegalArgumentException if no name is given and the model does not have exactly one output,
+  *                                  or if the named output does not exist or was skipped
+  */
+ private String chooseOutput(OnnxModel model, Optional<String> outputName) {
+     if (outputName.isPresent()) {
+         String requested = outputName.get();
+         String output = model.outputs().get(requested);
+         if (output != null) return output;
+
+         if (model.skippedOutputs().containsKey(requested))
+             throw new IllegalArgumentException("Could not use output '" + requested + "': " +
+                                                model.skippedOutputs().get(requested));
+         throw new IllegalArgumentException("Model does not have the specified output '" +
+                                            requested + "'");
+     }
+
+     // No name given: this is only valid when there is exactly one output to choose
+     int outputCount = model.outputs().size();
+     if (outputCount == 0)
+         throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(model));
+     if (outputCount > 1)
+         throw new IllegalArgumentException("Onnx model has multiple outputs (" +
+                                            Joiner.on(", ").join(model.outputs().keySet()) +
+                                            "), one must be specified " +
+                                            "as a second argument to onnx()");
+     String onlyName = model.outputs().keySet().stream().findFirst().get();
+     return model.outputs().get(onlyName);
}
- static class OnnxFeatureArguments extends FeatureArguments {
- public OnnxFeatureArguments(Arguments arguments) {
+ /** Writes the small constant to the model store and adds it as a constant to the rank profile. */
+ private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
+ store.writeSmallConstant(constantName, constantValue);
+ profile.addConstant(constantName, asValue(constantValue));
+ }
+
+ /**
+  * Handles one large constant of the model. If the profile has a macro with the same name, that macro
+  * overrides the constant (its type must equal the constant's type) and the name is recorded in
+  * constantsReplacedByMacros. Otherwise the constant is written to the store and registered as a
+  * ranking constant unless one with this name is already registered.
+  *
+  * @throws IllegalArgumentException if an overriding macro returns a different type than the constant
+  */
+ private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles,
+                                     Set<String> constantsReplacedByMacros,
+                                     String constantName, Tensor constantValue) {
+     RankProfile.Macro overridingMacro = profile.getMacros().get(constantName);
+     if (overridingMacro == null) {
+         Path storedAt = store.writeLargeConstant(constantName, constantValue);
+         boolean alreadyRegistered = profile.getSearch().getRankingConstants().containsKey(constantName);
+         if ( ! alreadyRegistered)
+             profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(),
+                                                                        storedAt.toString()));
+         return;
+     }
+
+     TensorType macroType = overridingMacro.getRankingExpression().type(profile.typeContext(queryProfiles));
+     if ( ! macroType.equals(constantValue.type()))
+         throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " +
+                                            "The required type of this is " + constantValue.type() +
+                                            ", but the macro returns " + macroType);
+     constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later
+ }
+
+ /** Writes the generated macro to the store, with overridden constants replaced by macro references. */
+ private void transformGeneratedMacro(ModelStore store, RankProfile profile,
+                                      Set<String> constantsReplacedByMacros,
+                                      String macroName, RankingExpression expression) {
+     RankingExpression withMacroReferences = replaceConstantsByMacros(expression, constantsReplacedByMacros);
+     store.writeMacro(macroName, withMacroReferences);
+ }
+
+ /**
+  * Adds a macro with the given name and expression to the rank profile.
+  *
+  * @throws IllegalArgumentException if the profile already has a macro with this name
+  */
+ private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
+     boolean nameTaken = profile.getMacros().containsKey(macroName);
+     if (nameTaken)
+         throw new IllegalArgumentException("Generated Onnx macro '" + macroName + "' already exists.");
+
+     profile.addMacro(macroName, false); // todo: inline if only used once
+     RankProfile.Macro added = profile.getMacros().get(macroName);
+     added.setRankingExpression(expression);
+     added.setTextualExpression(expression.getRoot().toString());
+ }
+
+ /**
+  * Returns a description of the outputs skipped during import, suitable for appending to an error message:
+  * ": " followed by one "Skipping output 'name': reason" entry per skipped output, separated by ", ".
+  * Returns the empty string if no outputs were skipped.
+  */
+ private String skippedOutputsDescription(OnnxModel model) {
+     if (model.skippedOutputs().isEmpty()) return "";
+     // Entries are joined with ", " - they were previously run together with no separator
+     return model.skippedOutputs().entrySet().stream()
+                 .map(skipped -> "Skipping output '" + skipped.getKey() + "': " + skipped.getValue())
+                 .collect(Collectors.joining(", ", ": ", ""));
+ }
+
+ /**
+  * Verifies that the required macros referenced from the given expression exist in the given rank profile,
+  * and that they return tensors assignable to the types the model requires.
+  *
+  * @throws IllegalArgumentException if a required macro is missing from the profile, its type cannot be
+  *                                  resolved, or its type is not assignable to the required type
+  */
+ private void verifyRequiredMacros(RankingExpression expression, OnnxModel model,
+                                   RankProfile profile, QueryProfileRegistry queryProfiles) {
+     Set<String> referencedNames = new HashSet<>();
+     addMacroNamesIn(expression.getRoot(), referencedNames, model);
+
+     for (String name : referencedNames) {
+         TensorType requiredType = model.requiredMacros().get(name);
+         if (requiredType == null) continue; // Not a required macro
+
+         RankProfile.Macro provided = profile.getMacros().get(name);
+         if (provided == null)
+             throw new IllegalArgumentException("Model refers Placeholder '" + name +
+                                                "' of type " + requiredType + " but this macro is not present in " +
+                                                profile);
+
+         // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second
+         // phase and summary features), as it may only resolve correctly given those bindings
+         // Or, probably better, annotate the macros with type constraints here and verify during general
+         // type verification
+         TensorType actualType = provided.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles));
+         String mismatchPrefix = "Model refers input '" + name +
+                                 "' of type " + requiredType +
+                                 " which must be produced by a macro in the rank profile, but ";
+         if (actualType == null)
+             throw new IllegalArgumentException(mismatchPrefix +
+                                                "this macro references a feature which is not declared");
+         if ( ! actualType.isAssignableTo(requiredType))
+             throw new IllegalArgumentException(mismatchPrefix +
+                                                "this macro produces type " + actualType);
+     }
+ }
+
+ /** Adds every macro generated by the model import to the rank profile. */
+ private void addGeneratedMacros(OnnxModel model, RankProfile profile) {
+     for (Map.Entry<String, RankingExpression> generated : model.macros().entrySet())
+         addGeneratedMacroToProfile(profile, generated.getKey(), generated.getValue());
+ }
+
+ /**
+  * Check if batch dimensions of inputs can be reduced out. If the input
+  * macro specifies that a single exemplar should be evaluated, we can
+  * reduce the batch dimension out.
+  *
+  * This rewrites both the generated macros referenced from the expression (in the rank profile)
+  * and the expression itself, and restores the reduced dimensions at the expression output
+  * so that its type stays as before.
+  *
+  * @throws IllegalArgumentException if a referenced generated macro is not present in the profile
+  */
+ private void reduceBatchDimensions(RankingExpression expression, OnnxModel model,
+                                    RankProfile profile, QueryProfileRegistry queryProfiles) {
+     TypeContext<Reference> typeContext = profile.typeContext(queryProfiles);
+     TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
+
+     // Check generated macros for inputs to reduce
+     Set<String> macroNames = new HashSet<>();
+     addMacroNamesIn(expression.getRoot(), macroNames, model);
+     for (String macroName : macroNames) {
+         if ( ! model.macros().containsKey(macroName)) {
+             continue;
+         }
+         RankProfile.Macro macro = profile.getMacros().get(macroName);
+         if (macro == null) {
+             // The closing quote and space after the macro name were missing from this message
+             throw new IllegalArgumentException("Model refers to generated macro '" + macroName +
+                                                "' but this macro is not present in " + profile);
+         }
+         RankingExpression macroExpression = macro.getRankingExpression();
+         macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext));
+     }
+
+     // Check expression for inputs to reduce
+     ExpressionNode root = expression.getRoot();
+     root = reduceBatchDimensionsAtInput(root, model, typeContext);
+     TensorType typeAfterReducing = root.type(typeContext);
+     root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing);
+     expression.setRoot(root);
+ }
+
+ /**
+  * Returns the given node tree with each reference to a required macro of the model, and each rename
+  * applied directly to such a reference, replaced by the result of reduceBatchDimensionExpression.
+  */
+ private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, OnnxModel model,
+                                                     TypeContext<Reference> typeContext) {
+     // Case 1: a rename whose single argument is a required input is reduced as a unit
+     if (node instanceof TensorFunctionNode) {
+         TensorFunctionNode functionNode = (TensorFunctionNode)node;
+         TensorFunction function = functionNode.function();
+         if (function instanceof Rename) {
+             List<ExpressionNode> arguments = functionNode.children();
+             boolean singleReference = arguments.size() == 1 && arguments.get(0) instanceof ReferenceNode;
+             if (singleReference
+                 && model.requiredMacros().containsKey(((ReferenceNode)arguments.get(0)).getName()))
+                 return reduceBatchDimensionExpression(function, typeContext);
+         }
+     }
+
+     // Case 2: a direct reference to a required input
+     if (node instanceof ReferenceNode
+         && model.requiredMacros().containsKey(((ReferenceNode)node).getName()))
+         return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext);
+
+     // Case 3: any other composite is rebuilt from its transformed children
+     if ( ! (node instanceof CompositeNode)) return node;
+     CompositeNode composite = (CompositeNode)node;
+     List<ExpressionNode> originalChildren = composite.children();
+     List<ExpressionNode> reducedChildren = new ArrayList<>(originalChildren.size());
+     originalChildren.forEach(child -> reducedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext)));
+     return composite.setChildren(reducedChildren);
+ }
+
+ /**
+  * Wraps the given function in a node. If the function type has more than one dimension,
+  * all its dimensions of size 1 are first sum-reduced away.
+  */
+ private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
+     TensorType type = function.type(context);
+     if (type.dimensions().size() <= 1)
+         return new TensorFunctionNode(function);
+
+     List<String> sizeOneDimensions = type.dimensions().stream()
+                                          .filter(dimension -> dimension.size().orElse(-1L) == 1)
+                                          .map(dimension -> dimension.name())
+                                          .collect(Collectors.toList());
+     if (sizeOneDimensions.isEmpty())
+         return new TensorFunctionNode(function);
+     return new TensorFunctionNode(new Reduce(function, Reduce.Aggregator.sum, sizeOneDimensions));
+ }
+
+ /**
+  * If batch dimensions have been reduced away above, bring them back here
+  * for any following computation of the tensor: the node is multiplied by a generated tensor of ones
+  * having the size 1 dimensions which are in the before type but not in the after type.
+  * Todo: determine when this is not necessary!
+  */
+ private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
+     if (after.equals(before)) return node;
+
+     TensorType.Builder lostDimensions = new TensorType.Builder();
+     before.dimensions().stream()
+           .filter(dimension -> dimension.size().orElse(-1L) == 1 && ! after.dimensionNames().contains(dimension.name()))
+           .forEach(dimension -> lostDimensions.indexed(dimension.name(), 1));
+     TensorType expandDimensionsType = lostDimensions.build();
+     if (expandDimensionsType.dimensions().isEmpty()) return node;
+
+     ExpressionNode one = new ConstantNode(new DoubleValue(1.0));
+     GeneratorLambdaFunctionNode generator = new GeneratorLambdaFunctionNode(expandDimensionsType, one);
+     Generate ones = new Generate(expandDimensionsType, generator.asLongListToDoubleOperator());
+     Join expanded = new Join(TensorFunctionNode.wrapArgument(node), ones, ScalarFunctions.multiply());
+     return new TensorFunctionNode(expanded);
+ }
+
+ /**
+  * When a constant c is overridden by a macro, occurrences of "constant(c)" in expressions must become "c".
+  * Returns a new expression with those replacements made, or the given expression itself
+  * if no constants are overridden.
+  */
+ private RankingExpression replaceConstantsByMacros(RankingExpression expression,
+                                                    Set<String> constantsReplacedByMacros) {
+     if ( ! constantsReplacedByMacros.isEmpty()) {
+         ExpressionNode rewrittenRoot = replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros);
+         return new RankingExpression(expression.getName(), rewrittenRoot);
+     }
+     return expression;
+ }
+
+ /**
+  * Returns a copy of the given node tree where each constant(name) reference whose name is in
+  * constantsReplacedByMacros is replaced by a plain reference to name.
+  */
+ private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) {
+     if (node instanceof ReferenceNode) {
+         Reference ref = ((ReferenceNode)node).reference();
+         boolean isConstantFeature = FeatureNames.isSimpleFeature(ref) && ref.name().equals("constant");
+         if (isConstantFeature) {
+             String constantName = ref.simpleArgument().get();
+             if (constantsReplacedByMacros.contains(constantName))
+                 return new ReferenceNode(constantName);
+         }
+     }
+     // Deliberately not an else: reference nodes which were not replaced above are also checked here
+     if ( ! (node instanceof CompositeNode)) return node;
+
+     CompositeNode parent = (CompositeNode)node;
+     List<ExpressionNode> rewrittenChildren = new ArrayList<>();
+     for (ExpressionNode child : parent.children())
+         rewrittenChildren.add(replaceConstantsByMacros(child, constantsReplacedByMacros));
+     return parent.setChildren(rewrittenChildren);
+ }
+
+ /**
+  * Collects into names the name of every output-less reference found in the given node tree,
+  * following references to macros generated by the model into their expressions.
+  */
+ private void addMacroNamesIn(ExpressionNode node, Set<String> names, OnnxModel model) {
+     if (node instanceof ReferenceNode) {
+         ReferenceNode reference = (ReferenceNode)node;
+         if (reference.getOutput() != null) return; // macro references cannot specify outputs
+
+         String name = reference.getName();
+         names.add(name);
+         if (model.macros().containsKey(name))
+             addMacroNamesIn(model.macros().get(name).getRoot(), names, model);
+         return; // the children of a reference node are not visited
+     }
+     if (node instanceof CompositeNode)
+         ((CompositeNode)node).children().forEach(child -> addMacroNamesIn(child, names, model));
+ }
+
+ /** Wraps the tensor as a value, using a plain double for rank 0 tensors. */
+ private Value asValue(Tensor tensor) {
+     // the backend gets offended by dimensionless tensors
+     return tensor.type().rank() == 0 ? new DoubleValue(tensor.asDouble()) : new TensorValue(tensor);
+ }
+
+ /**
+ * Provides read/write access to the correct directories of the application package given by the feature arguments
+ */
+ private static class ModelStore {
+
+ private final ApplicationPackage application;
+ private final FeatureArguments arguments;
+
+ /** Creates a store for the given application package, wrapping the feature arguments in a FeatureArguments. */
+ public ModelStore(ApplicationPackage application, Arguments arguments) {
+ this.application = application;
+ this.arguments = new FeatureArguments(arguments);
+ }
+
+ /** Returns the feature arguments this store was created with */
+ public FeatureArguments arguments() { return arguments; }
+
+ /**
+  * Returns whether a converted expression for these arguments exists in the application package.
+  * An UnsupportedOperationException from the lookup is treated as "not stored".
+  */
+ public boolean hasStoredModel() {
+     boolean stored;
+     try {
+         stored = application.getFile(arguments.expressionPath()).exists();
+     }
+     catch (UnsupportedOperationException e) {
+         stored = false;
+     }
+     return stored;
+ }
+
+ /**
+ * Returns the directory which contains the source model to use for these arguments
+ */
+ public File onnxModelDir() {
+ return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath()));
+ }
+
+ /**
+ * Adds this expression to the application package, such that it can be read later.
+ */
+ public void writeConverted(RankingExpression expression) {
+ application.getFile(arguments.expressionPath())
+ .writeFile(new StringReader(expression.getRoot().toString()));
+ }
+
+ /**
+  * Reads back the ranking expression previously stored for these arguments.
+  *
+  * @throws UncheckedIOException if the stored expression cannot be read
+  * @throws IllegalStateException if the stored expression cannot be parsed
+  */
+ public RankingExpression readConverted() {
+     ApplicationFile stored = application.getFile(arguments.expressionPath());
+     try {
+         return new RankingExpression(stored.createReader());
+     }
+     catch (IOException e) {
+         throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e);
+     }
+     catch (ParseException e) {
+         throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
+     }
+ }
+
+ /** Adds this macro expression to the application package to it can be read later. */
+ public void writeMacro(String name, RankingExpression expression) {
+ application.getFile(arguments.macrosPath()).appendFile(name + "\t" +
+ expression.getRoot().toString() + "\n");
+ }
+
+ /**
+  * Reads the previously stored macro expressions for these arguments.
+  * Each line of the macros file is "name TAB expression".
+  *
+  * @return the stored macros, or an empty list if there is no macros file
+  * @throws IllegalStateException if a line is malformed or its expression cannot be parsed
+  * @throws UncheckedIOException if the file cannot be read
+  */
+ public List<Pair<String, RankingExpression>> readMacros() {
+     ApplicationFile file = application.getFile(arguments.macrosPath());
+     if ( ! file.exists()) return Collections.emptyList();
+
+     List<Pair<String, RankingExpression>> macros = new ArrayList<>();
+     // try-with-resources: the reader was previously never closed
+     try (BufferedReader reader = new BufferedReader(file.createReader())) {
+         String line;
+         while (null != (line = reader.readLine())) {
+             String[] parts = line.split("\t", 2); // limit 2: keep the expression intact
+             if (parts.length < 2)
+                 throw new IllegalStateException("Malformed macro line in " + arguments.macrosPath() +
+                                                 ": '" + line + "'");
+             try {
+                 macros.add(new Pair<>(parts[0], new RankingExpression(parts[1])));
+             }
+             catch (ParseException e) {
+                 // Report the file which was actually read
+                 throw new IllegalStateException("Could not parse " + arguments.macrosPath(), e);
+             }
+         }
+         return macros;
+     }
+     catch (IOException e) {
+         throw new UncheckedIOException(e);
+     }
+ }
+
+ /**
+  * Reads the information about all the large (aka ranking) constants stored in the application package
+  * (the constant value itself is replicated with file distribution).
+  * Each constant file contains "name:type:path".
+  *
+  * @throws IllegalStateException if a constant file does not have all three parts
+  * @throws UncheckedIOException if a constant file cannot be read
+  */
+ public List<RankingConstant> readLargeConstants() {
+     try {
+         List<RankingConstant> constants = new ArrayList<>();
+         for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) {
+             String content = IOUtils.readAll(constantFile.createReader());
+             String[] parts = content.split(":", 3); // limit 3: the path itself may contain ':'
+             if (parts.length < 3)
+                 throw new IllegalStateException("Malformed large constant file in " +
+                                                 arguments.largeConstantsPath() + ": '" + content + "'");
+             constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2]));
+         }
+         return constants;
+     }
+     catch (IOException e) {
+         throw new UncheckedIOException(e);
+     }
+ }
+
+ /**
+  * Stores this constant as a file in the application package so it can be distributed using
+  * file distribution, and records its name, type and path in a separate ".constant" file.
+  *
+  * @return the path to the stored constant, relative to the application package root
+  */
+ public Path writeLargeConstant(String name, Tensor constant) {
+     Path constantsDir = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants");
+
+     // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file:
+     Path tensorFile = constantsDir.append(name + ".tbf");
+     Path correctedTensorFile = correct(tensorFile);
+
+     // Remember the constant in a file we replicate in ZooKeeper
+     ApplicationFile descriptor = application.getFile(arguments.largeConstantsPath().append(name + ".constant"));
+     descriptor.writeFile(new StringReader(name + ":" + constant.type() + ":" + correctedTensorFile));
+
+     // Write content explicitly as a file on the file system as this is distributed using file distribution
+     createIfNeeded(constantsDir);
+     IOUtils.writeFile(application.getFileReference(tensorFile), TypedBinaryFormat.encode(constant));
+     return correctedTensorFile;
+ }
+
+ /**
+  * Reads the small constants previously appended to the small constants file.
+  *
+  * @return the (name, tensor) pairs in file order, or an empty list if the file does not exist
+  * @throws UncheckedIOException if the file cannot be read
+  */
+ private List<Pair<String, Tensor>> readSmallConstants() {
+     try {
+         ApplicationFile file = application.getFile(arguments.smallConstantsPath());
+         if (!file.exists()) return Collections.emptyList();
+
+         List<Pair<String, Tensor>> constants = new ArrayList<>();
+         // try-with-resources so the underlying file reader is always released
+         try (BufferedReader reader = new BufferedReader(file.createReader())) {
+             String line;
+             while (null != (line = reader.readLine())) {
+                 // Each line holds name, tensor type spec and tensor value, separated by tabs
+                 String[] parts = line.split("\t");
+                 String name = parts[0];
+                 TensorType type = TensorType.fromSpec(parts[1]);
+                 Tensor tensor = Tensor.from(type, parts[2]);
+                 constants.add(new Pair<>(name, tensor));
+             }
+         }
+         return constants;
+     }
+     catch (IOException e) {
+         throw new UncheckedIOException(e);
+     }
+ }
+
+ /**
+ * Append this constant to the single file used for small constants distributed as config
+ */
+ public void writeSmallConstant(String name, Tensor constant) {
+ // Secret file format for remembering constants:
+ application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" +
+ constant.type().toString() + "\t" +
+ constant.toString() + "\n");
+ }
+
+ /**
+  * Workaround for being constructed with the .preprocessed dir as root while later being used outside it:
+  * prefixes the path with the preprocessed dir name when the application root is that dir
+  * and the path does not already contain it.
+  */
+ private Path correct(Path path) {
+     String rootPath = application.getFileReference(Path.fromString("")).getAbsolutePath();
+     if ( ! rootPath.endsWith(FilesApplicationPackage.preprocessed)) return path;
+     if (path.elements().contains(FilesApplicationPackage.preprocessed)) return path;
+     return Path.fromString(FilesApplicationPackage.preprocessed).append(path);
+ }
+
+ /** Creates the directory at the given application-relative path, including parents, unless it already exists */
+ private void createIfNeeded(Path path) {
+     File directory = application.getFileReference(path);
+     if (directory.exists()) return;
+     if (directory.mkdirs()) return;
+     throw new IllegalStateException("Could not create " + directory);
+ }
+
+ }
+
+ /** Encapsulates the 1 or 2 arguments to an onnx feature: the model path and, optionally, the output name */
+ private static class FeatureArguments {
+
+     private final Path modelPath;
+
+     /** Optional arguments */
+     private final Optional<String> output;
+
+     /**
+      * Parses the arguments of an onnx feature.
+      *
+      * @throws IllegalArgumentException if there are no arguments, more than 2 arguments,
+      *                                  or an argument is not a constant string
+      */
+     public FeatureArguments(Arguments arguments) {
          if (arguments.isEmpty())
              throw new IllegalArgumentException("An onnx node must take an argument pointing to " +
-                                                "the tensorflow model directory under [application]/models");
+                                                "the onnx model directory under [application]/models");
-         if (arguments.expressions().size() > 3)
+         // Only the model path (index 0) and the output (index 1) are read below,
+         // so the limit must agree with the "at most 2" message instead of silently ignoring a third argument
+         if (arguments.expressions().size() > 2)
              throw new IllegalArgumentException("An onnx feature can have at most 2 arguments");
          modelPath = Path.fromString(asString(arguments.expressions().get(0)));
          output = optionalArgument(1, arguments);
-         signature = Optional.of("default");
      }
+
+     /** Returns modelPath with slashes and dots replaced by underscores */
+     public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); }
+
+     /** Returns relative path to this model below the "models/" dir in the application package */
+     public Path modelPath() { return modelPath; }
+
+     /** Returns the output name given as the second argument, if any */
+     public Optional<String> output() { return output; }
+
+     /** Path to the small constants file */
+     public Path smallConstantsPath() {
+         return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt");
+     }
+
+     /** Path to the large (ranking) constants directory */
+     public Path largeConstantsPath() {
+         return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants");
+     }
+
+     /** Path to the macros file */
+     public Path macrosPath() {
+         return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt");
+     }
+
+     /** Path to the file holding the converted expression for these arguments */
+     public Path expressionPath() {
+         return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR
+                 .append(modelPath).append("expressions").append(expressionFileName());
+     }
+
+     /** Returns "[output].expression", or "single.expression" when no output is specified */
+     private String expressionFileName() {
+         StringBuilder fileName = new StringBuilder();
+         output.ifPresent(s -> fileName.append(s).append("."));
+         if (fileName.length() == 0) // no output specified
+             fileName.append("single.");
+         fileName.append("expression");
+         return fileName.toString();
+     }
+
+     /** Returns the argument at the given index as a string, or empty if there are not that many arguments */
+     private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
+         if (argumentIndex >= arguments.expressions().size())
+             return Optional.empty();
+         return Optional.of(asString(arguments.expressions().get(argumentIndex)));
+     }
+
+     /**
+      * Returns the source string of a constant node, with any surrounding quotes removed.
+      *
+      * @throws IllegalArgumentException if the node is not a constant
+      */
+     private String asString(ExpressionNode node) {
+         if ( ! (node instanceof ConstantNode))
+             throw new IllegalArgumentException("Expected a constant string as onnx argument, but got '" + node + "'");
+         return stripQuotes(((ConstantNode)node).sourceString());
+     }
+
+     /**
+      * Removes a leading and trailing quote sign. Strings not starting with a quote sign
+      * (including the empty string) are returned unchanged.
+      *
+      * @throws IllegalArgumentException if the string starts with a quote sign but has no matching end quote
+      */
+     private String stripQuotes(String s) {
+         if (s.isEmpty() || ! isQuoteSign(s.codePointAt(0))) return s;
+         // A lone quote character is both first and last: it has no end quote
+         if (s.length() < 2 || ! isQuoteSign(s.codePointAt(s.length() - 1)))
+             throw new IllegalArgumentException("onnx argument [" + s + "] is missing endquote");
+         return s.substring(1, s.length() - 1);
+     }
+
+     private boolean isQuoteSign(int c) {
+         return c == '\'' || c == '"';
+     }
+
  }
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index 27e1ad51b33..41da32f64c3 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -1,19 +1,59 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.expressiontransforms;
+import com.google.common.base.Joiner;
+import com.yahoo.collections.Pair;
+import com.yahoo.config.application.api.ApplicationFile;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.model.application.provider.FilesApplicationPackage;
+import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.searchdefinition.FeatureNames;
import com.yahoo.searchdefinition.RankProfile;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
-import com.yahoo.searchlib.rankingexpression.integration.ml.TensorFlowImporter;
+import com.yahoo.searchdefinition.RankingConstant;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel.Signature;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.Join;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.tensor.serialization.TypedBinaryFormat;
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
+import java.io.StringReader;
import java.io.UncheckedIOException;
+import java.util.ArrayList;
+import java.util.Collections;
import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
/**
* Replaces instances of the tensorflow(model-path, signature, output)
@@ -22,12 +62,12 @@ import java.util.Map;
*
* @author bratseth
*/
-public class TensorFlowFeatureConverter extends MLImportFeatureConverter {
+public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter();
/** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
- private final Map<Path, ImportedModel> importedModels = new HashMap<>();
+ private final Map<Path, TensorFlowModel> importedModels = new HashMap<>();
@Override
public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
@@ -43,8 +83,7 @@ public class TensorFlowFeatureConverter extends MLImportFeatureConverter {
if ( ! feature.getName().equals("tensorflow")) return feature;
try {
- FeatureArguments arguments = new TensorFlowFeatureArguments(feature.getArguments());
- ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments);
+ ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments());
if ( ! store.hasStoredModel()) // not converted yet - access TensorFlow model files
return transformFromTensorFlowModel(store, context.rankProfile(), context.queryProfiles());
else
@@ -56,19 +95,565 @@ public class TensorFlowFeatureConverter extends MLImportFeatureConverter {
}
+ /**
+ * Converts a tensorflow feature by importing the TensorFlow model (reusing a previously imported
+ * model for the same model path), registering its constants and generated macros,
+ * and writing the converted expression to the store.
+ *
+ * @return the root of the expression for the chosen signature and output
+ * @throws IllegalArgumentException if the chosen output was skipped during import
+ */
private ExpressionNode transformFromTensorFlowModel(ModelStore store,
- RankProfile profile,
- QueryProfileRegistry queryProfiles) {
- ImportedModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
- k -> tensorFlowImporter.importModel(store.arguments().modelName(),
- store.modelDir()));
- return transformFromImportedModel(model, store, profile, queryProfiles);
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles) {
+ // Import the model only once per model path
+ TensorFlowModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
+ k -> tensorFlowImporter.importModel(store.arguments().modelName(),
+ store.tensorFlowModelDir()));
+
+ // Add constants
+ // Names of large constants overridden by a macro in the profile are collected here
+ Set<String> constantsReplacedByMacros = new HashSet<>();
+ model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
+ model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles,
+ constantsReplacedByMacros, k, v));
+
+ // Find the specified expression
+ Signature signature = chooseSignature(model, store.arguments().signature());
+ String output = chooseOutput(signature, store.arguments().output());
+ if (signature.skippedOutputs().containsKey(output)) {
+ String message = "Could not import TensorFlow model output '" + output + "'";
+ if (!signature.skippedOutputs().get(output).isEmpty()) {
+ message += ": " + signature.skippedOutputs().get(output);
+ }
+ if (!signature.importWarnings().isEmpty()) {
+ message += ": " + String.join(", ", signature.importWarnings());
+ }
+ throw new IllegalArgumentException(message);
+ }
+
+ RankingExpression expression = model.expressions().get(output);
+ expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
+ verifyRequiredMacros(expression, model, profile, queryProfiles);
+ addGeneratedMacros(model, profile);
+ reduceBatchDimensions(expression, model, profile, queryProfiles);
+
+ // Store the generated macros so they can be read back by transformFromStoredModel
+ model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v));
+
+ store.writeConverted(expression);
+ return expression.getRoot();
+ }
+
+ /**
+  * Converts a tensorflow feature from the constants, macros and expression previously written to the store,
+  * adding the constants and macros to the given profile.
+  *
+  * @return the root of the stored expression
+  */
+ private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) {
+     store.readSmallConstants()
+          .forEach(small -> profile.addConstant(small.getFirst(), asValue(small.getSecond())));
+
+     for (RankingConstant large : store.readLargeConstants()) {
+         // Do not replace a ranking constant already registered under this name
+         if (profile.getSearch().getRankingConstants().containsKey(large.getName())) continue;
+         profile.getSearch().addRankingConstant(large);
+     }
+
+     store.readMacros()
+          .forEach(stored -> addGeneratedMacroToProfile(profile, stored.getFirst(), stored.getSecond()));
+
+     return store.readConverted().getRoot();
+ }
+
+ /**
+  * Returns the signature with the given name, or the only signature of the model if no name is given.
+  *
+  * @throws IllegalArgumentException if the named signature does not exist, or if no name is given
+  *                                  and the model does not have exactly one signature
+  */
+ private Signature chooseSignature(TensorFlowModel importResult, Optional<String> signatureName) {
+     if (signatureName.isPresent()) {
+         String requestedName = signatureName.get();
+         Signature requested = importResult.signatures().get(requestedName);
+         if (requested == null)
+             throw new IllegalArgumentException("Model does not have the specified signature '" +
+                                                requestedName + "'");
+         return requested;
+     }
+
+     int signatureCount = importResult.signatures().size();
+     if (signatureCount == 0)
+         throw new IllegalArgumentException("No signatures are available");
+     if (signatureCount > 1)
+         throw new IllegalArgumentException("Model has multiple signatures (" +
+                                            Joiner.on(", ").join(importResult.signatures().keySet()) +
+                                            "), one must be specified " +
+                                            "as a second argument to tensorflow()");
+     return importResult.signatures().values().stream().findFirst().get();
+ }
+
+ /**
+  * Returns the output expression name with the given output name in the signature,
+  * or the only output of the signature if no output name is given.
+  *
+  * @throws IllegalArgumentException if the named output does not exist or was skipped, or if no name is given
+  *                                  and the signature does not have exactly one output
+  */
+ private String chooseOutput(Signature signature, Optional<String> outputName) {
+     if (outputName.isPresent()) {
+         String requestedName = outputName.get();
+         String requested = signature.outputs().get(requestedName);
+         if (requested != null) return requested;
+
+         if (signature.skippedOutputs().containsKey(requestedName))
+             throw new IllegalArgumentException("Could not use output '" + requestedName + "': " +
+                                                signature.skippedOutputs().get(requestedName));
+         throw new IllegalArgumentException("Model does not have the specified output '" +
+                                            requestedName + "'");
+     }
+
+     int outputCount = signature.outputs().size();
+     if (outputCount == 0)
+         throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(signature));
+     if (outputCount > 1)
+         throw new IllegalArgumentException(signature + " has multiple outputs (" +
+                                            Joiner.on(", ").join(signature.outputs().keySet()) +
+                                            "), one must be specified " +
+                                            "as a third argument to tensorflow()");
+     String onlyOutputName = signature.outputs().keySet().stream().findFirst().get();
+     return signature.outputs().get(onlyOutputName);
+ }
+
+ /** Writes the small constant to the store and adds it as a constant of the rank profile */
+ private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
+     store.writeSmallConstant(constantName, constantValue);
+     Value profileValue = asValue(constantValue);
+     profile.addConstant(constantName, profileValue);
+ }
+
+ /**
+  * Handles a large constant of the imported model: if the profile has a macro with the same name, the macro
+  * overrides the constant (and its name is added to constantsReplacedByMacros); otherwise the constant is
+  * written to the store and registered as a ranking constant unless one with this name already exists.
+  *
+  * @throws IllegalArgumentException if an overriding macro does not have the same type as the constant
+  */
+ private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles,
+                                     Set<String> constantsReplacedByMacros,
+                                     String constantName, Tensor constantValue) {
+     RankProfile.Macro overridingMacro = profile.getMacros().get(constantName);
+     if (overridingMacro == null) {
+         Path storedPath = store.writeLargeConstant(constantName, constantValue);
+         if (profile.getSearch().getRankingConstants().containsKey(constantName)) return;
+         profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(),
+                                                                    storedPath.toString()));
+         return;
+     }
+
+     TensorType macroType = overridingMacro.getRankingExpression().type(profile.typeContext(queryProfiles));
+     if ( ! macroType.equals(constantValue.type()))
+         throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " +
+                                            typeMismatchExplanation(constantValue.type(), macroType));
+     constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later
+ }
+
+ /** Writes the generated macro to the store, after replacing references to constants overridden by macros */
+ private void transformGeneratedMacro(ModelStore store,
+                                      Set<String> constantsReplacedByMacros,
+                                      String macroName, RankingExpression expression) {
+     RankingExpression rewritten = replaceConstantsByMacros(expression, constantsReplacedByMacros);
+     store.writeMacro(macroName, rewritten);
+ }
+
+ /**
+  * Adds a macro with the given name and expression to the profile.
+  *
+  * @throws IllegalArgumentException if the profile already has a macro with this name
+  */
+ private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
+     if (profile.getMacros().containsKey(macroName))
+         throw new IllegalArgumentException("Generated TensorFlow macro '" + macroName + "' already exists.");
+
+     profile.addMacro(macroName, false); // todo: inline if only used once
+     RankProfile.Macro added = profile.getMacros().get(macroName);
+     added.setRankingExpression(expression);
+     String textual = expression.getRoot().toString();
+     added.setTextualExpression(textual);
+ }
+
+ /**
+  * Returns a description of the outputs skipped in this signature, prefixed by ": " so that it can be
+  * appended to an error message, or the empty string if no outputs were skipped.
+  */
+ private String skippedOutputsDescription(TensorFlowModel.Signature signature) {
+     if (signature.skippedOutputs().isEmpty()) return "";
+     // Separate the entries: appending them back to back made multiple skipped outputs run together
+     return ": " + signature.skippedOutputs().entrySet().stream()
+                            .map(skipped -> "Skipping output '" + skipped.getKey() + "': " + skipped.getValue())
+                            .collect(Collectors.joining(", "));
+ }
- static class TensorFlowFeatureArguments extends FeatureArguments {
- public TensorFlowFeatureArguments(Arguments arguments) {
+ /**
+ * Verifies that the required macros referred to in the given expression exist in the given rank profile,
+ * and that they return tensors assignable to the types specified in the model's requiredMacros.
+ *
+ * @throws IllegalArgumentException if a required macro is missing, its type cannot be resolved,
+ * or its type is not assignable to the required type
+ */
+ private void verifyRequiredMacros(RankingExpression expression, TensorFlowModel model,
+ RankProfile profile, QueryProfileRegistry queryProfiles) {
+ Set<String> macroNames = new HashSet<>();
+ addMacroNamesIn(expression.getRoot(), macroNames, model);
+ for (String macroName : macroNames) {
+ TensorType requiredType = model.requiredMacros().get(macroName);
+ if (requiredType == null) continue; // Not a required macro
+
+ RankProfile.Macro macro = profile.getMacros().get(macroName);
+ if (macro == null)
+ throw new IllegalArgumentException("Model refers placeholder '" + macroName +
+ "' of type " + requiredType + " but this macro is not present in " +
+ profile);
+ // TODO: We should verify this in the function reference(s) this is invoked from (starting from
+ // first/second phase and summary features), as it may only resolve correctly given those bindings.
+ // Or, probably better, annotate the macros with type constraints here and verify during general
+ // type verification
+ TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles));
+ if ( actualType == null)
+ throw new IllegalArgumentException("Model refers placeholder '" + macroName +
+ "' of type " + requiredType +
+ " which must be produced by a macro in the rank profile, but " +
+ "this macro references a feature which is not declared");
+ if ( ! actualType.isAssignableTo(requiredType))
+ throw new IllegalArgumentException("Model refers placeholder '" + macroName + "'. " +
+ typeMismatchExplanation(requiredType, actualType));
+ }
+ }
+
+ /**
+  * Returns an explanation of a mismatch between the required type and the type a macro returns,
+  * with a hint about missing query profile type declarations when the actual type has rank 0.
+  */
+ private String typeMismatchExplanation(TensorType requiredType, TensorType actualType) {
+     String explanation = "The required type of this is " + requiredType + ", but this macro returns " + actualType;
+     if (actualType.rank() == 0)
+         explanation += ". This is often due to missing declaration of query tensor features " +
+                        "in query profile types - see the documentation.";
+     return explanation;
+ }
+
+ /**
+  * Adds each macro generated by the model import to the rank profile
+  */
+ private void addGeneratedMacros(TensorFlowModel model, RankProfile profile) {
+     for (Map.Entry<String, RankingExpression> generated : model.macros().entrySet())
+         addGeneratedMacroToProfile(profile, generated.getKey(), generated.getValue());
+ }
+
+ /**
+  * Check if batch dimensions of inputs can be reduced out. If the input
+  * macro specifies that a single exemplar should be evaluated, we can
+  * reduce the batch dimension out.
+  *
+  * The roots of the given expression and of the generated macros it refers to are replaced in place.
+  *
+  * @throws IllegalArgumentException if a generated macro referred to by the expression is not in the profile
+  */
+ private void reduceBatchDimensions(RankingExpression expression, TensorFlowModel model,
+                                    RankProfile profile, QueryProfileRegistry queryProfiles) {
+     TypeContext<Reference> typeContext = profile.typeContext(queryProfiles);
+     TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
+
+     // Check generated macros for inputs to reduce
+     Set<String> macroNames = new HashSet<>();
+     addMacroNamesIn(expression.getRoot(), macroNames, model);
+     for (String macroName : macroNames) {
+         if ( ! model.macros().containsKey(macroName)) {
+             continue; // not generated by the model import
+         }
+         RankProfile.Macro macro = profile.getMacros().get(macroName);
+         if (macro == null) {
+             // The closing quote and space after the macro name were missing from this message
+             throw new IllegalArgumentException("Model refers to generated macro '" + macroName +
+                                                "' but this macro is not present in " + profile);
+         }
+         RankingExpression macroExpression = macro.getRankingExpression();
+         macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext));
+     }
+
+     // Check expression for inputs to reduce
+     ExpressionNode root = expression.getRoot();
+     root = reduceBatchDimensionsAtInput(root, model, typeContext);
+     TensorType typeAfterReducing = root.type(typeContext);
+     // Bring back any dimensions reduced away so the output type is unchanged
+     root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing);
+     expression.setRoot(root);
+ }
+
+ /**
+ * Returns the given expression tree with batch dimensions reduced at each reference to a required macro
+ * of the model. Such a reference is recognized either directly, or as the single argument of a rename.
+ * Other composite nodes are transformed recursively; all other nodes are returned as-is.
+ */
+ private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, TensorFlowModel model,
+ TypeContext<Reference> typeContext) {
+ // Case 1: a rename whose only argument is a reference to a required macro - reduce the renamed result
+ if (node instanceof TensorFunctionNode) {
+ TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
+ if (tensorFunction instanceof Rename) {
+ List<ExpressionNode> children = ((TensorFunctionNode)node).children();
+ if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
+ ReferenceNode referenceNode = (ReferenceNode) children.get(0);
+ if (model.requiredMacros().containsKey(referenceNode.getName())) {
+ return reduceBatchDimensionExpression(tensorFunction, typeContext);
+ }
+ }
+ }
+ }
+ // Case 2: a direct reference to a required macro
+ if (node instanceof ReferenceNode) {
+ ReferenceNode referenceNode = (ReferenceNode) node;
+ if (model.requiredMacros().containsKey(referenceNode.getName())) {
+ return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext);
+ }
+ }
+ // Case 3: any other composite node (including those falling through above) - transform the children
+ if (node instanceof CompositeNode) {
+ List<ExpressionNode> children = ((CompositeNode)node).children();
+ List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
+ for (ExpressionNode child : children) {
+ transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext));
+ }
+ return ((CompositeNode)node).setChildren(transformedChildren);
+ }
+ return node;
+ }
+
+ /**
+  * Returns a node for the given function where, if its type has more than one dimension,
+  * all dimensions of size 1 are reduced away by summing. Otherwise the function is wrapped unchanged.
+  */
+ private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
+     TensorType type = function.type(context);
+     if (type.dimensions().size() <= 1)
+         return new TensorFunctionNode(function);
+
+     List<String> sizeOneDimensions = type.dimensions().stream()
+                                          .filter(dimension -> dimension.size().orElse(-1L) == 1)
+                                          .map(dimension -> dimension.name())
+                                          .collect(Collectors.toList());
+     if (sizeOneDimensions.isEmpty())
+         return new TensorFunctionNode(function);
+
+     return new TensorFunctionNode(new Reduce(function, Reduce.Aggregator.sum, sizeOneDimensions));
+ }
+
+ /**
+ * If batch dimensions have been reduced away above, bring them back here
+ * for any following computation of the tensor.
+ * Todo: determine when this is not necessary!
+ *
+ * @param node the expression after reducing
+ * @param before the type of the expression before reducing
+ * @param after the type of the expression after reducing
+ * @return the node joined (by multiplication) with a generated tensor of ones having the size 1 dimensions
+ * of before which are missing in after, or the node itself if there are no such dimensions
+ */
+ private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
+ if (after.equals(before)) {
+ return node;
+ }
+ // Collect the size 1 dimensions which were present before reducing but are gone after
+ TensorType.Builder typeBuilder = new TensorType.Builder();
+ for (TensorType.Dimension dimension : before.dimensions()) {
+ if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) {
+ typeBuilder.indexed(dimension.name(), 1);
+ }
+ }
+ TensorType expandDimensionsType = typeBuilder.build();
+ if (expandDimensionsType.dimensions().size() > 0) {
+ // Multiplying by a tensor of 1.0 values over the missing dimensions adds them back
+ ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0));
+ Generate generatedFunction = new Generate(expandDimensionsType,
+ new GeneratorLambdaFunctionNode(expandDimensionsType,
+ generatedExpression)
+ .asLongListToDoubleOperator());
+ Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply());
+ return new TensorFunctionNode(expand);
+ }
+ return node;
+ }
+
+ /**
+  * If a constant c is overridden by a macro, instances of "constant(c)" must be replaced by "c" in expressions.
+  * Returns a new expression with the same name where this is done, or the given expression itself
+  * if no constants are replaced by macros.
+  */
+ private RankingExpression replaceConstantsByMacros(RankingExpression expression,
+                                                    Set<String> constantsReplacedByMacros) {
+     if (constantsReplacedByMacros.isEmpty()) return expression;
+
+     String name = expression.getName();
+     ExpressionNode rewrittenRoot = replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros);
+     return new RankingExpression(name, rewrittenRoot);
+ }
+
+ /**
+ * Returns the given node where each reference of the form constant(c), with c in constantsReplacedByMacros,
+ * is replaced by a reference to c. Composite nodes are transformed recursively.
+ */
+ private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) {
+ if (node instanceof ReferenceNode) {
+ Reference reference = ((ReferenceNode)node).reference();
+ if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) {
+ String argument = reference.simpleArgument().get();
+ if (constantsReplacedByMacros.contains(argument))
+ return new ReferenceNode(argument);
+ }
+ }
+ // References not replaced above fall through to here
+ if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above
+ CompositeNode composite = (CompositeNode)node;
+ return composite.setChildren(composite.children().stream()
+ .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros))
+ .collect(Collectors.toList()));
+ }
+ return node;
+ }
+
+ /**
+  * Adds to names the name of each reference without an output found in the given expression tree.
+  * When a name is that of a macro generated by the model, the names in that macro's expression are added too.
+  */
+ private void addMacroNamesIn(ExpressionNode node, Set<String> names, TensorFlowModel model) {
+     if (node instanceof ReferenceNode) {
+         ReferenceNode reference = (ReferenceNode)node;
+         if (reference.getOutput() != null) return; // macro references cannot specify outputs
+
+         names.add(reference.getName());
+         if (model.macros().containsKey(reference.getName()))
+             addMacroNamesIn(model.macros().get(reference.getName()).getRoot(), names, model);
+         return;
+     }
+
+     if ( ! (node instanceof CompositeNode)) return;
+     for (ExpressionNode child : ((CompositeNode)node).children())
+         addMacroNamesIn(child, names, model);
+ }
+
+ /** Returns the tensor as a value: a double value if it has no dimensions, a tensor value otherwise */
+ private Value asValue(Tensor tensor) {
+     boolean dimensionless = tensor.type().rank() == 0;
+     return dimensionless ? new DoubleValue(tensor.asDouble()) // the backend gets offended by dimensionless tensors
+                          : new TensorValue(tensor);
+ }
+
+ /**
+  * Provides read/write access to the correct directories of the application package given by the feature arguments
+  */
+ private static class ModelStore {
+
+     private final ApplicationPackage application;
+     private final FeatureArguments arguments;
+
+     public ModelStore(ApplicationPackage application, Arguments arguments) {
+         this.application = application;
+         this.arguments = new FeatureArguments(arguments);
+     }
+
+     public FeatureArguments arguments() { return arguments; }
+
+     /**
+      * Returns whether a converted expression exists for these arguments.
+      * Returns false if the application package does not support this file access.
+      */
+     public boolean hasStoredModel() {
+         try {
+             return application.getFile(arguments.expressionPath()).exists();
+         }
+         catch (UnsupportedOperationException e) {
+             return false;
+         }
+     }
+
+     /**
+      * Returns the directory under the models dir of the application package
+      * which contains the source model to use for these arguments
+      */
+     public File tensorFlowModelDir() {
+         return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath()));
+     }
+
+     /**
+      * Adds this expression to the application package, such that it can be read later.
+      */
+     public void writeConverted(RankingExpression expression) {
+         application.getFile(arguments.expressionPath())
+                    .writeFile(new StringReader(expression.getRoot().toString()));
+     }
+
+     /** Reads the previously stored ranking expression for these arguments */
+     public RankingExpression readConverted() {
+         try {
+             return new RankingExpression(application.getFile(arguments.expressionPath()).createReader());
+         }
+         catch (IOException e) {
+             throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e);
+         }
+         catch (ParseException e) {
+             throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
+         }
+     }
+
+     /** Adds this macro expression to the application package so it can be read later. */
+     public void writeMacro(String name, RankingExpression expression) {
+         application.getFile(arguments.macrosPath()).appendFile(name + "\t" +
+                                                                expression.getRoot().toString() + "\n");
+     }
+
+     /**
+      * Reads the previously stored macro expressions for these arguments
+      *
+      * @return the (name, expression) pairs in file order, or an empty list if no macros file exists
+      */
+     public List<Pair<String, RankingExpression>> readMacros() {
+         try {
+             ApplicationFile file = application.getFile(arguments.macrosPath());
+             if (!file.exists()) return Collections.emptyList();
+
+             List<Pair<String, RankingExpression>> macros = new ArrayList<>();
+             // try-with-resources so the underlying file reader is always released
+             try (BufferedReader reader = new BufferedReader(file.createReader())) {
+                 String line;
+                 while (null != (line = reader.readLine())) {
+                     // Lines have the form written by writeMacro: name<TAB>expression
+                     String[] parts = line.split("\t");
+                     String name = parts[0];
+                     try {
+                         RankingExpression expression = new RankingExpression(parts[1]);
+                         macros.add(new Pair<>(name, expression));
+                     }
+                     catch (ParseException e) {
+                         // Report the file actually being parsed: the macros file, not the expression file
+                         throw new IllegalStateException("Could not parse " + arguments.macrosPath(), e);
+                     }
+                 }
+             }
+             return macros;
+         }
+         catch (IOException e) {
+             throw new UncheckedIOException(e);
+         }
+     }
+
+     /**
+      * Reads the information about all the large (aka ranking) constants stored in the application package
+      * (the constant value itself is replicated with file distribution).
+      */
+     public List<RankingConstant> readLargeConstants() {
+         try {
+             List<RankingConstant> constants = new ArrayList<>();
+             for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) {
+                 // Content has the form written by writeLargeConstant: name:type:path
+                 String[] parts = IOUtils.readAll(constantFile.createReader()).split(":");
+                 constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2]));
+             }
+             return constants;
+         }
+         catch (IOException e) {
+             throw new UncheckedIOException(e);
+         }
+     }
+
+     /**
+      * Adds this constant to the application package as a file,
+      * such that it can be distributed using file distribution.
+      *
+      * @return the path to the stored constant, relative to the application package root
+      */
+     public Path writeLargeConstant(String name, Tensor constant) {
+         Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath()).append("constants");
+
+         // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file:
+         Path constantPath = constantsPath.append(name + ".tbf");
+
+         // Remember the constant in a file we replicate in ZooKeeper
+         application.getFile(arguments.largeConstantsPath().append(name + ".constant"))
+                    .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath)));
+
+         // Write content explicitly as a file on the file system as this is distributed using file distribution
+         createIfNeeded(constantsPath);
+         IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant));
+         return correct(constantPath);
+     }
+
+     /**
+      * Reads the small constants previously written by writeSmallConstant
+      *
+      * @return the (name, tensor) pairs in file order, or an empty list if no small constants file exists
+      */
+     private List<Pair<String, Tensor>> readSmallConstants() {
+         try {
+             ApplicationFile file = application.getFile(arguments.smallConstantsPath());
+             if (!file.exists()) return Collections.emptyList();
+
+             List<Pair<String, Tensor>> constants = new ArrayList<>();
+             // try-with-resources so the underlying file reader is always released
+             try (BufferedReader reader = new BufferedReader(file.createReader())) {
+                 String line;
+                 while (null != (line = reader.readLine())) {
+                     // Lines have the form written by writeSmallConstant: name<TAB>type<TAB>value
+                     String[] parts = line.split("\t");
+                     String name = parts[0];
+                     TensorType type = TensorType.fromSpec(parts[1]);
+                     Tensor tensor = Tensor.from(type, parts[2]);
+                     constants.add(new Pair<>(name, tensor));
+                 }
+             }
+             return constants;
+         }
+         catch (IOException e) {
+             throw new UncheckedIOException(e);
+         }
+     }
+
+     /**
+      * Append this constant to the single file used for small constants distributed as config
+      */
+     public void writeSmallConstant(String name, Tensor constant) {
+         // Secret file format for remembering constants:
+         application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" +
+                                                                        constant.type().toString() + "\t" +
+                                                                        constant.toString() + "\n");
+     }
+
+     /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */
+     private Path correct(Path path) {
+         if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed)
+             && ! path.elements().contains(FilesApplicationPackage.preprocessed)) {
+             return Path.fromString(FilesApplicationPackage.preprocessed).append(path);
+         }
+         else {
+             return path;
+         }
+     }
+
+     /** Creates the directory at the given path, including parents, unless it already exists */
+     private void createIfNeeded(Path path) {
+         File dir = application.getFileReference(path);
+         if ( ! dir.exists()) {
+             if (!dir.mkdirs())
+                 throw new IllegalStateException("Could not create " + dir);
+         }
+     }
+
+ }
+
+ /** Encapsulates the 1, 2 or 3 arguments to a tensorflow feature */
+ private static class FeatureArguments {
+
+ private final Path modelPath;
+
+ /** Optional arguments */
+ private final Optional<String> signature, output;
+
+ public FeatureArguments(Arguments arguments) {
if (arguments.isEmpty())
throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " +
- "the tensorflow model directory under [application]/models");
+ "the tensorflow model directory under [application]/models");
if (arguments.expressions().size() > 3)
throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments");
@@ -76,6 +661,68 @@ public class TensorFlowFeatureConverter extends MLImportFeatureConverter {
signature = optionalArgument(1, arguments);
output = optionalArgument(2, arguments);
}
+
+ /**
+ * Returns modelPath with slashes replaced by underscores,
+ * e.g. a model path of "mnist/saved" gives the name "mnist_saved"
+ */
+ public String modelName() { return modelPath.toString().replace('/', '_'); }
+
+ /** Returns relative path to this model below the "models/" dir in the application package */
+ public Path modelPath() { return modelPath; }
+ public Optional<String> signature() { return signature; } // the second feature argument, empty if not given
+ public Optional<String> output() { return output; } // the third feature argument, empty if not given
+
+ /**
+ * Path to the small constants file: "constants.txt" under this model's path below MODELS_GENERATED_DIR.
+ * Note that this is the only generated path which is not below MODELS_GENERATED_REPLICATED_DIR.
+ */
+ public Path smallConstantsPath() {
+ return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt");
+ }
+
+ /** Path to the large (ranking) constants directory: "constants" under this model's path below MODELS_GENERATED_REPLICATED_DIR */
+ public Path largeConstantsPath() {
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants");
+ }
+
+ /** Path to the macros file: "macros.txt" under this model's path below MODELS_GENERATED_REPLICATED_DIR */
+ public Path macrosPath() {
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt");
+ }
+ /**
+ * Path to the stored expression file for this combination of model, signature and output:
+ * a file named by expressionFileName() in the "expressions" directory under this model's path
+ * below MODELS_GENERATED_REPLICATED_DIR.
+ */
+ public Path expressionPath() {
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR
+ .append(modelPath).append("expressions").append(expressionFileName());
+ }
+ /**
+ * Returns the file name used to store the expression of these arguments:
+ * the signature and output, each followed by a dot when present, followed by "expression".
+ * When neither is present the name is "single.expression".
+ */
+ private String expressionFileName() {
+ StringBuilder fileName = new StringBuilder();
+ signature.ifPresent(s -> fileName.append(s).append("."));
+ output.ifPresent(s -> fileName.append(s).append("."));
+ if (fileName.length() == 0) // single signature and output
+ fileName.append("single.");
+ fileName.append("expression");
+ return fileName.toString();
+ }
+ /**
+ * Returns the argument at the given index as a string,
+ * or empty if there are not enough arguments to include that index.
+ *
+ * @throws IllegalArgumentException (from asString) if the argument is present but is not a constant
+ */
+ private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
+ if (argumentIndex >= arguments.expressions().size())
+ return Optional.empty();
+ return Optional.of(asString(arguments.expressions().get(argumentIndex)));
+ }
+
+ /**
+ * Returns the given feature argument as a string with any enclosing quote signs removed.
+ *
+ * @param node the feature argument, which must be a ConstantNode
+ * @return the source string of the constant, without enclosing quote signs
+ * @throws IllegalArgumentException if the node is not a constant, or its quoting is unbalanced
+ */
+ private String asString(ExpressionNode node) {
+ if ( ! (node instanceof ConstantNode))
+ throw new IllegalArgumentException("Expected a constant string as tensorflow argument, but got '" + node + "'");
+ return stripQuotes(((ConstantNode)node).sourceString());
+ }
+
+ /**
+ * Removes one pair of enclosing quote signs from the given string.
+ * A string which is empty or does not start with a quote sign is returned unchanged.
+ *
+ * @param s the string to strip quotes from
+ * @return the string without its first and last character if it was quoted, otherwise the string as-is
+ * @throws IllegalArgumentException if the string starts with a quote sign but has no separate end quote
+ */
+ private String stripQuotes(String s) {
+ // Guard the empty string: codePointAt(0) would otherwise throw StringIndexOutOfBoundsException
+ if (s.isEmpty() || ! isQuoteSign(s.codePointAt(0))) return s;
+ // Require at least two characters so a lone quote sign is not taken as both start and end quote
+ if (s.length() < 2 || ! isQuoteSign(s.codePointAt(s.length() - 1 )))
+ throw new IllegalArgumentException("tensorflow argument [" + s + "] is missing endquote");
+ return s.substring(1, s.length()-1);
+ }
+ /** Returns whether the given code point is a single quote (') or a double quote (") */
+ private boolean isQuoteSign(int c) {
+ return c == '\'' || c == '"';
+ }
+
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
index d9beab6e2f2..1c54d12d8b3 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -37,6 +37,15 @@ public class RankingExpressionWithOnnxTestCase {
}
@Test
+ public void testOnnxReference() throws ParseException { // References the ONNX model by file name only
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ "onnx('mnist_softmax.onnx')");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L)); // NOTE(review): 10 and 7840 are presumably the expected constant sizes - confirm against assertLargeConstant
+ assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L));
+ }
+
+ @Test
public void testOnnxReferenceWithConstantFeature() {
RankProfileSearchFixture search = fixtureWith("constant(mytensor)",
"onnx('mnist_softmax.onnx')",
@@ -113,6 +122,13 @@ public class RankingExpressionWithOnnxTestCase {
}
@Test
+ public void testOnnxReferenceSpecifyingOutput() { // Names the output 'add' explicitly; the same expression as without it is expected
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ "onnx('mnist_softmax.onnx', 'add')");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ }
+
+ @Test
public void testOnnxReferenceMissingMacro() throws ParseException {
try {
RankProfileSearchFixture search = new RankProfileSearchFixture(
@@ -129,7 +145,7 @@ public class RankingExpressionWithOnnxTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
"onnx('mnist_softmax.onnx'): " +
- "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " +
+ "Model refers Placeholder 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " +
"not present in rank profile 'my_profile'",
Exceptions.toMessageString(expected));
}
@@ -147,8 +163,8 @@ public class RankingExpressionWithOnnxTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
"onnx('mnist_softmax.onnx'): " +
- "Model refers input 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " +
- "but this macro returns tensor(d0[2],d5[10])",
+ "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) which must be produced " +
+ "by a macro in the rank profile, but this macro produces type tensor(d0[2],d5[10])",
Exceptions.toMessageString(expected));
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 7228af2b0de..d288a396732 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -162,7 +162,7 @@ public class RankingExpressionWithTensorFlowTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " +
"tensorflow('mnist_softmax/saved'): " +
- "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " +
+ "Model refers placeholder 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " +
"not present in rank profile 'my_profile'",
Exceptions.toMessageString(expected));
}
@@ -179,7 +179,7 @@ public class RankingExpressionWithTensorFlowTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " +
"tensorflow('mnist_softmax/saved'): " +
- "Model refers input 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " +
+ "Model refers placeholder 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " +
"but this macro returns tensor(d0[2],d5[10])",
Exceptions.toMessageString(expected));
}
@@ -305,9 +305,9 @@ public class RankingExpressionWithTensorFlowTestCase {
@Test
public void testMacroGeneration() {
- final String expression = "join(join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
+ final String expression = "join(join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
final String macroExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))";
- final String macroExpression2 = "join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))";
+ final String macroExpression2 = "join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))";
RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",
"tensorflow('mnist/saved')",
@@ -316,15 +316,15 @@ public class RankingExpressionWithTensorFlowTestCase {
"input",
new StoringApplicationPackage(applicationDir));
search.assertFirstPhaseExpression(expression, "my_profile");
- search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile");
- search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile");
+ search.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile");
+ search.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile");
}
@Test
public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException {
- final String expression = "join(join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
+ final String expression = "join(join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
final String macroExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))";
- final String macroExpression2 = "join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))";
+ final String macroExpression2 = "join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",
@@ -335,8 +335,8 @@ public class RankingExpressionWithTensorFlowTestCase {
application);
search.assertFirstPhaseExpression(expression, "my_profile");
assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search);
- search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile");
- search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile");
+ search.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile");
+ search.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile");
// At this point the expression is stored - copy application to another location which do not have a models dir
Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
@@ -353,8 +353,8 @@ public class RankingExpressionWithTensorFlowTestCase {
storedApplication);
searchFromStored.assertFirstPhaseExpression(expression, "my_profile");
assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search);
- searchFromStored.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile");
- searchFromStored.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile");
+ searchFromStored.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile");
+ searchFromStored.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile");
}
finally {
IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
@@ -465,7 +465,7 @@ public class RankingExpressionWithTensorFlowTestCase {
}
- static class StoringApplicationPackageFile extends ApplicationFile {
+ public static class StoringApplicationPackageFile extends ApplicationFile {
/** The path to the application package root */
private final Path root;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
deleted file mode 100644
index a658833b426..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
+++ /dev/null
@@ -1,242 +0,0 @@
-package com.yahoo.searchlib.rankingexpression.integration.ml;
-
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
-import com.yahoo.searchlib.rankingexpression.parser.ParseException;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.functions.Rename;
-import com.yahoo.tensor.functions.TensorFunction;
-import com.yahoo.yolean.Exceptions;
-
-import java.io.File;
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-import java.util.logging.Logger;
-
-/**
- * Base class for importing ML models (ONNX/TensorFlow) as native Vespa
- * ranking expressions. The general mechanism for import is for the
- * specific ML platform import implementations to create an
- * IntermediateGraph. This class offers common code to convert the
- * IntermediateGraph to Vespa ranking expressions and macros.
- *
- * @author lesters
- */
-public abstract class ModelImporter {
-
- private static final Logger log = Logger.getLogger(ModelImporter.class.getName());
-
- /**
- * The main import function, implemented by each ML platform importer.
- *
- * @param modelName the name of the model to import
- * @param modelPath the location of the model to import, given as a string
- * @return the imported model
- */
- public abstract ImportedModel importModel(String modelName, String modelPath);
- /**
- * Convenience overload which passes modelDir.toString() as the model path
- * to {@link #importModel(String, String)}.
- */
- public ImportedModel importModel(String modelName, File modelDir) {
- return importModel(modelName, modelDir.toString());
- }
-
- /**
- * Takes an IntermediateGraph and converts it to a ImportedModel containing
- * the actual Vespa ranking expressions.
- * The model is named by the graph, and optimize() is invoked on the given graph
- * before anything is converted.
- *
- * @param graph the intermediate graph to convert
- * @return a new model holding the signatures, expressions, constants, macros and warnings of the graph
- */
- static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph) {
- ImportedModel model = new ImportedModel(graph.name());
-
- graph.optimize();
-
- importSignatures(graph, model); // must come first: the following steps iterate over model.signatures()
- importExpressions(graph, model);
- reportWarnings(graph, model);
- logVariableTypes(graph);
-
- return model;
- }
- /**
- * Copies every signature of the graph into the model, including all entries of
- * the input and output maps of each signature.
- */
- private static void importSignatures(IntermediateGraph graph, ImportedModel model) {
- for (String signatureName : graph.signatures()) {
- ImportedModel.Signature signature = model.signature(signatureName);
- for (Map.Entry<String, String> input : graph.inputs(signatureName).entrySet()) {
- signature.input(input.getKey(), input.getValue());
- }
- for (Map.Entry<String, String> output : graph.outputs(signatureName).entrySet()) {
- signature.output(output.getKey(), output.getValue());
- }
- }
- }
- /**
- * Returns whether the name of the given operation is a value in the inputs of any signature of the model.
- * NOTE(review): no caller of this method is visible in this class.
- */
- private static boolean isSignatureInput(ImportedModel model, IntermediateOperation operation) {
- for (ImportedModel.Signature signature : model.signatures().values()) {
- for (String inputName : signature.inputs().values()) {
- if (inputName.equals(operation.name())) {
- return true;
- }
- }
- }
- return false;
- }
- /** Returns whether the name of the given operation is a value in the outputs of any signature of the model */
- private static boolean isSignatureOutput(ImportedModel model, IntermediateOperation operation) {
- for (ImportedModel.Signature signature : model.signatures().values()) {
- for (String outputName : signature.outputs().values()) {
- if (outputName.equals(operation.name())) {
- return true;
- }
- }
- }
- return false;
- }
-
- /**
- * Convert intermediate representation to Vespa ranking expressions.
- * Each output of each signature in the model is imported from the graph. An output is
- * recorded as skipped on its signature, instead of failing the import, when it produces
- * no function or when importing it throws IllegalArgumentException.
- */
- static void importExpressions(IntermediateGraph graph, ImportedModel model) {
- for (ImportedModel.Signature signature : model.signatures().values()) {
- for (String outputName : signature.outputs().values()) {
- try {
- Optional<TensorFunction> function = importExpression(graph.get(outputName), model); // NOTE(review): graph.get is a plain map lookup, so an output name missing from the graph presumably gives null here - confirm
- if (!function.isPresent()) {
- signature.skippedOutput(outputName, "No valid output function could be found.");
- }
- }
- catch (IllegalArgumentException e) {
- signature.skippedOutput(outputName, Exceptions.toMessageString(e));
- }
- }
- }
- }
- /**
- * Imports the given operation, and recursively all its inputs, into the model.
- *
- * @return the function of the operation, or empty if the operation has no type
- */
- private static Optional<TensorFunction> importExpression(IntermediateOperation operation, ImportedModel model) {
- if (!operation.type().isPresent()) {
- return Optional.empty();
- }
- if (operation.isConstant()) {
- return importConstant(operation, model); // the inputs of a constant operation are not imported
- }
- importExpressionInputs(operation, model); // inputs are imported before the operation itself
- importRankingExpression(operation, model);
- importArgumentExpression(operation, model);
- importMacroExpression(operation, model);
-
- return operation.function();
- }
- /** Imports each input of the given operation by calling importExpression on it; the returned functions are ignored */
- private static void importExpressionInputs(IntermediateOperation operation, ImportedModel model) {
- operation.inputs().forEach(input -> importExpression(input, model));
- }
- /**
- * Adds the value of a constant operation to the model, unless a constant with the same
- * Vespa name is already present. Tensor values of rank 0 are added as small constants,
- * other tensor values as large constants, and non-tensor values are not added at all.
- *
- * @return the function of the operation
- * @throws IllegalArgumentException if the operation has no constant value
- */
- private static Optional<TensorFunction> importConstant(IntermediateOperation operation, ImportedModel model) {
- String name = operation.vespaName();
- if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) {
- return operation.function();
- }
-
- Value value = operation.getConstantValue().orElseThrow(() ->
- new IllegalArgumentException("Operation '" + operation.vespaName() + "' " +
- "is constant but does not have a value."));
- if ( ! (value instanceof TensorValue)) {
- return operation.function(); // scalar values are inserted directly into the expression
- }
-
- Tensor tensor = value.asTensor();
- if (tensor.type().rank() == 0) {
- model.smallConstant(name, tensor);
- } else {
- model.largeConstant(name, tensor);
- }
- return operation.function();
- }
- /**
- * Adds the function of the given operation to the model as a ranking expression named by
- * the operation, unless the operation has no function or the model already has an expression
- * by that name. If the operation is the output of a signature, its function is wrapped in a
- * Rename to the dimension names given by OrderedTensorType.standardType when these differ.
- *
- * @throws RuntimeException if the function's string form cannot be parsed as a ranking expression
- */
- private static void importRankingExpression(IntermediateOperation operation, ImportedModel model) {
- if (operation.function().isPresent()) {
- String name = operation.name();
- if (!model.expressions().containsKey(name)) {
- TensorFunction function = operation.function().get();
-
- if (isSignatureOutput(model, operation)) {
- OrderedTensorType operationType = operation.type().get();
- OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType);
- if ( ! operationType.equals(standardNamingType)) {
- List<String> renameFrom = operationType.dimensionNames();
- List<String> renameTo = standardNamingType.dimensionNames();
- function = new Rename(function, renameFrom, renameTo);
- }
- }
-
- try {
- // We add all intermediate nodes imported as separate expressions. Only
- // those referenced from the output will be used. We parse the
- // TensorFunction here to convert it to a RankingExpression tree.
- model.expression(name, new RankingExpression(name, function.toString()));
- }
- catch (ParseException e) {
- throw new RuntimeException("Imported function " + function +
- " cannot be parsed as a ranking expression", e);
- }
- }
- }
- }
- /**
- * If the given operation is an input, registers it in the model both as an argument and as a
- * required macro, under its Vespa name and with the type given by OrderedTensorType.standardType.
- */
- private static void importArgumentExpression(IntermediateOperation operation, ImportedModel model) {
- if (operation.isInput()) {
- // All inputs must have dimensions with standard naming convention: d0, d1, ...
- OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get());
- model.argument(operation.vespaName(), standardNamingConvention.type());
- model.requiredMacro(operation.vespaName(), standardNamingConvention.type());
- }
- }
- /**
- * If the given operation has a macro, adds it to the model as a ranking expression
- * under the macro name of the operation.
- *
- * @throws RuntimeException if the macro's string form cannot be parsed as a ranking expression
- */
- private static void importMacroExpression(IntermediateOperation operation, ImportedModel model) {
- if (operation.macro().isPresent()) {
- TensorFunction function = operation.macro().get();
- try {
- model.macro(operation.macroName(), new RankingExpression(operation.macroName(), function.toString()));
- }
- catch (ParseException e) {
- throw new RuntimeException("Tensorflow function " + function + // NOTE(review): the message says Tensorflow although this class is also the base class of the ONNX importer
- " cannot be parsed as a ranking expression", e);
- }
- }
- }
-
- /**
- * Add any import warnings to the signature in the ImportedModel.
- * The warnings are collected from the operation of each output of each signature,
- * and from all operations these depend on.
- */
- private static void reportWarnings(IntermediateGraph graph, ImportedModel model) {
- for (ImportedModel.Signature signature : model.signatures().values()) {
- for (String outputName : signature.outputs().values()) {
- reportWarnings(graph.get(outputName), model);
- }
- }
- }
- /**
- * Adds the warnings of this operation, and recursively of all its inputs, to the model.
- * All warnings are added to the default signature of the model, whichever signature
- * the traversal started from.
- */
- private static void reportWarnings(IntermediateOperation operation, ImportedModel model) {
- for (String warning : operation.warnings()) {
- model.defaultSignature().importWarning(warning);
- }
- for (IntermediateOperation input : operation.inputs()) {
- reportWarnings(input, model);
- }
- }
-
- /**
- * Log all TensorFlow Variables (i.e file constants) imported as part of this with their ordered type.
- * This allows users to learn the exact types (including dimension order after renaming) of the Variables
- * such that these can be converted and fed to a parent document independently of the rest of the model
- * for fast model weight updates.
- * Every operation in the graph which is a Constant is logged at info level, whether or not it is
- * reachable from an output.
- */
- private static void logVariableTypes(IntermediateGraph graph) {
- for (IntermediateOperation operation : graph.operations()) {
- if ( ! (operation instanceof Constant)) continue;
- if ( ! operation.type().isPresent()) continue; // will not happen
- log.info("Importing TensorFlow variable " + operation.name() + " as " + operation.vespaName() + // NOTE(review): worded for TensorFlow, but this code is shared with the ONNX importer
- " of type " + operation.type().get());
- }
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
deleted file mode 100644
index d3dd2a1d418..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
+++ /dev/null
@@ -1,30 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.ml;
-
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx.GraphImporter;
-import onnx.Onnx;
-
-import java.io.FileInputStream;
-import java.io.IOException;
-
-/**
- * Converts an ONNX model into a ranking expression and set of constants.
- *
- * @author lesters
- */
-public class OnnxImporter extends ModelImporter {
- /**
- * Reads the file at the given path as an ONNX ModelProto, imports it as an intermediate graph
- * and converts that graph to an imported model.
- *
- * @param modelName the name passed on to the graph importer
- * @param modelPath the path of the ONNX model file to read
- * @throws IllegalArgumentException if reading the model file fails with an IOException
- */
- @Override
- public ImportedModel importModel(String modelName, String modelPath) {
- try (FileInputStream inputStream = new FileInputStream(modelPath)) {
- Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
- IntermediateGraph graph = GraphImporter.importGraph(modelName, model);
- return convertIntermediateGraphToModel(graph);
- } catch (IOException e) {
- throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e);
- }
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java
deleted file mode 100644
index ff584559a83..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java
+++ /dev/null
@@ -1,47 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
-
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.GraphImporter;
-import org.tensorflow.SavedModelBundle;
-
-import java.io.IOException;
-
-/**
- * Converts a saved TensorFlow model into a ranking expression and set of constants.
- *
- * @author bratseth
- * @author lesters
- */
-public class TensorFlowImporter extends ModelImporter {
-
- /**
- * Imports a saved TensorFlow model from a directory.
- * The model should be saved as a .pbtxt or .pb file.
- * The saved model is loaded with the "serve" tag and closed again when the import completes.
- * NOTE(review): this comment used to say the model name is taken from the pb/pbtxt file name,
- * but the code passes the modelName argument on unchanged - confirm which is intended.
- *
- * @param modelName the name of the model to import, consisting of characters in [A-Za-z0-9_]
- * @param modelDir the directory containing the TensorFlow model files to import
- * @return the imported model
- * @throws IllegalArgumentException if loading or importing the model throws IllegalArgumentException;
- * the exception is rethrown with the model directory added to the message
- */
- public ImportedModel importModel(String modelName, String modelDir) {
- try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
- return importModel(modelName, model);
- }
- catch (IllegalArgumentException e) {
- throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
- }
- }
-
- /**
- * Imports a TensorFlow model which is already loaded.
- * The model is imported as an intermediate graph which is then converted to an imported model.
- *
- * @throws IllegalArgumentException if importing the graph fails with an IOException
- */
- ImportedModel importModel(String modelName, SavedModelBundle model) {
- try {
- IntermediateGraph graph = GraphImporter.importGraph(modelName, model);
- return convertIntermediateGraphToModel(graph);
- }
- catch (IOException e) {
- throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e);
- }
- }
-
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java
deleted file mode 100644
index 39a8b211d09..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java
+++ /dev/null
@@ -1,107 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
-
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
-
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Set;
-
-/**
- * Holds an intermediate representation of an imported ONNX or TensorFlow
- * graph. After this intermediate representation is constructed, it is used to
- * simplify and optimize the computational graph and then converted into the
- * final ImportedModel that holds the Vespa ranking expressions for the model.
- *
- * The graph consists of operations indexed by a string key, and a set of named
- * signatures, each holding a map of inputs and a map of outputs. The values of
- * the output maps are used as keys into the operation index.
- *
- * @author lesters
- */
-public class IntermediateGraph {
-
- private final String modelName;
- private final Map<String, IntermediateOperation> index = new HashMap<>(); // operations by key
- private final Map<String, GraphSignature> signatures = new HashMap<>(); // signatures by name, created on first access
-
- private static class GraphSignature { // the input and output maps of a single signature
- final Map<String, String> inputs = new HashMap<>();
- final Map<String, String> outputs = new HashMap<>();
- }
-
- public IntermediateGraph(String modelName) {
- this.modelName = modelName;
- }
-
- public String name() { // the model name given at construction
- return modelName;
- }
-
- public IntermediateOperation put(String key, IntermediateOperation operation) { // returns the operation previously stored under this key, or null if none
- return index.put(key, operation);
- }
-
- public IntermediateOperation get(String key) { // returns null if no operation is stored under this key
- return index.get(key);
- }
-
- public Set<String> signatures() { // the names of all signatures accessed so far through inputs() or outputs()
- return signatures.keySet();
- }
-
- public Map<String, String> inputs(String signature) { // returns the live, mutable input map; creates the signature if it does not exist
- return signatures.computeIfAbsent(signature, (k) -> new GraphSignature()).inputs;
- }
-
- public Map<String, String> outputs(String signature) { // returns the live, mutable output map; creates the signature if it does not exist
- return signatures.computeIfAbsent(signature, (k) -> new GraphSignature()).outputs;
- }
-
- public String defaultSignature() { // always the constant name "default"; does not create the signature
- return "default";
- }
-
- public boolean alreadyImported(String key) { // whether an operation is stored under this key
- return index.containsKey(key);
- }
-
- public Collection<IntermediateOperation> operations() { // all stored operations, as a view of the index
- return index.values();
- }
-
- public void optimize() { // the only optimization performed is renaming of dimensions
- renameDimensions();
- }
-
- /**
- * Find dimension names to avoid excessive renaming while evaluating the model.
- * Constraints are first collected from the operations reachable from the outputs of all
- * signatures, the renamer is then solved, and the result is applied to the same operations.
- * NOTE(review): index.get(output) is used without a null check, so every signature output
- * is presumably expected to be present in the operation index - confirm.
- */
- private void renameDimensions() {
- DimensionRenamer renamer = new DimensionRenamer();
- for (String signature : signatures()) {
- for (String output : outputs(signature).values()) {
- addDimensionNameConstraints(index.get(output), renamer);
- }
- }
- renamer.solve();
- for (String signature : signatures()) {
- for (String output : outputs(signature).values()) {
- renameDimensions(index.get(output), renamer);
- }
- }
- }
-
- private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer) { // depth first: inputs add their constraints before the operation itself
- if (operation.type().isPresent()) { // operations without a type are skipped, together with their inputs
- operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer));
- operation.addDimensionNameConstraints(renamer);
- }
- }
-
- private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer) { // depth first: inputs are renamed before the operation itself
- if (operation.type().isPresent()) { // operations without a type are skipped, together with their inputs
- operation.inputs().forEach(input -> renameDimensions(input, renamer));
- operation.renameDimensions(renamer);
- }
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
deleted file mode 100644
index 3fe92440cae..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
+++ /dev/null
@@ -1,216 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx;
-
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Argument;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ConcatV2;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Identity;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Join;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Map;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.MatMul;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.NoOp;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Reshape;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Shape;
-import com.yahoo.tensor.functions.ScalarFunctions;
-import onnx.Onnx;
-
-import java.util.List;
-import java.util.stream.Collectors;
-
-/**
- * Converts an ONNX graph to a Vespa IntermediateGraph which is the basis
- * for generating Vespa ranking expressions.
- *
- * @author lesters
- */
-public class GraphImporter {
-
- public static IntermediateOperation mapOperation(Onnx.NodeProto node,
- List<IntermediateOperation> inputs,
- IntermediateGraph graph) {
- String nodeName = node.getName();
- String modelName = graph.name();
-
- switch (node.getOpType().toLowerCase()) {
- case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs());
- case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
- case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
- case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin());
- case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan());
- case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil());
- case "concat": return new ConcatV2(modelName, nodeName, inputs);
- case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos());
- case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
- case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
- case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal());
- case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp());
- case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor());
- case "greater": return new Join(modelName, nodeName, inputs, ScalarFunctions.greater());
- case "identity": return new Identity(modelName, nodeName, inputs);
- case "less": return new Join(modelName, nodeName, inputs, ScalarFunctions.less());
- case "log": return new Map(modelName, nodeName, inputs, ScalarFunctions.log());
- case "matmul": return new MatMul(modelName, nodeName, inputs);
- case "max": return new Join(modelName, nodeName, inputs, ScalarFunctions.max());
- case "min": return new Join(modelName, nodeName, inputs, ScalarFunctions.min());
- case "mean": return new Join(modelName, nodeName, inputs, ScalarFunctions.mean());
- case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
- case "neg": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg());
- case "pow": return new Join(modelName, nodeName, inputs, ScalarFunctions.pow());
- case "reshape": return new Reshape(modelName, nodeName, inputs);
- case "reciprocal": return new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal());
- case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
- case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
- case "shape": return new Shape(modelName, nodeName, inputs);
- case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin());
- case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt());
- case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid());
- case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
- case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan());
- case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh());
- }
-
- IntermediateOperation op = new NoOp(modelName, node.getName(), inputs);
- op.warning("Operation '" + node.getOpType() + "' is currently not implemented");
- return op;
- }
-
- public static IntermediateGraph importGraph(String modelName, Onnx.ModelProto model) {
- Onnx.GraphProto onnxGraph = model.getGraph();
-
- IntermediateGraph intermediateGraph = new IntermediateGraph(modelName);
- importOperations(onnxGraph, intermediateGraph);
- verifyOutputTypes(onnxGraph, intermediateGraph);
-
- return intermediateGraph;
- }
-
- private static void importOperations(Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) {
- for (Onnx.ValueInfoProto valueInfo : onnxGraph.getOutputList()) {
- importOperation(valueInfo.getName(), onnxGraph, intermediateGraph);
- }
- }
-
- private static IntermediateOperation importOperation(String name,
- Onnx.GraphProto onnxGraph,
- IntermediateGraph intermediateGraph) {
- if (intermediateGraph.alreadyImported(name)) {
- return intermediateGraph.get(name);
- }
- IntermediateOperation operation;
- if (isArgumentTensor(name, onnxGraph)) {
- Onnx.ValueInfoProto valueInfoProto = getArgumentTensor(name, onnxGraph);
- if (valueInfoProto == null)
- throw new IllegalArgumentException("Could not find argument tensor: " + name);
- OrderedTensorType type = TypeConverter.fromOnnxType(valueInfoProto.getType());
- operation = new Argument(intermediateGraph.name(), valueInfoProto.getName(), type);
-
- intermediateGraph.inputs(intermediateGraph.defaultSignature())
- .put(IntermediateOperation.namePartOf(name), operation.vespaName());
-
- } else if (isConstantTensor(name, onnxGraph)) {
- Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph);
- OrderedTensorType defaultType = OrderedTensorType.fromDimensionList(tensorProto.getDimsList());
- operation = new Constant(intermediateGraph.name(), name, defaultType);
- operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type)));
-
- } else {
- Onnx.NodeProto node = getNodeFromGraph(name, onnxGraph);
- List<IntermediateOperation> inputs = importOperationInputs(node, onnxGraph, intermediateGraph);
- operation = mapOperation(node, inputs, intermediateGraph);
-
- if (isOutputNode(name, onnxGraph)) {
- intermediateGraph.outputs(intermediateGraph.defaultSignature())
- .put(IntermediateOperation.namePartOf(name), operation.vespaName());
- }
- }
- intermediateGraph.put(operation.vespaName(), operation);
-
- return operation;
- }
-
- private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) {
- Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
- Onnx.TensorProto tensor = getConstantTensor(name, graph);
- return value != null && tensor == null;
- }
-
- private static boolean isConstantTensor(String name, Onnx.GraphProto graph) {
- Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
- Onnx.TensorProto tensor = getConstantTensor(name, graph);
- return value != null && tensor != null;
- }
-
- private static Onnx.ValueInfoProto getArgumentTensor(String name, Onnx.GraphProto graph) {
- for (Onnx.ValueInfoProto valueInfo : graph.getInputList()) {
- if (valueInfo.getName().equals(name)) {
- return valueInfo;
- }
- }
- return null;
- }
-
- private static Onnx.TensorProto getConstantTensor(String name, Onnx.GraphProto graph) {
- for (Onnx.TensorProto tensorProto : graph.getInitializerList()) {
- if (tensorProto.getName().equals(name)) {
- return tensorProto;
- }
- }
- return null;
- }
-
- private static boolean isOutputNode(String name, Onnx.GraphProto graph) {
- return getOutputNode(name, graph) != null;
- }
-
- private static Onnx.ValueInfoProto getOutputNode(String name, Onnx.GraphProto graph) {
- for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) {
- if (valueInfo.getName().equals(name)) {
- return valueInfo;
- }
- String nodeName = IntermediateOperation.namePartOf(valueInfo.getName());
- if (nodeName.equals(name)) {
- return valueInfo;
- }
- }
- return null;
- }
-
- private static List<IntermediateOperation> importOperationInputs(Onnx.NodeProto node,
- Onnx.GraphProto onnxGraph,
- IntermediateGraph intermediateGraph) {
- return node.getInputList().stream()
- .map(nodeName -> importOperation(nodeName, onnxGraph, intermediateGraph))
- .collect(Collectors.toList());
- }
-
- private static void verifyOutputTypes(Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) {
- for (String outputName : intermediateGraph.outputs(intermediateGraph.defaultSignature()).values()) {
- IntermediateOperation operation = intermediateGraph.get(outputName);
- Onnx.ValueInfoProto onnxNode = getOutputNode(outputName, onnxGraph);
- OrderedTensorType type = operation.type().orElseThrow(
- () -> new IllegalArgumentException("Output of '" + outputName + "' has no type."));
- TypeConverter.verifyType(onnxNode.getType(), type);
- }
- }
-
- private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) {
- boolean hasPortNumber = nodeName.contains(":");
- for (Onnx.NodeProto node : graph.getNodeList()) {
- if (hasPortNumber) {
- for (String outputName : node.getOutputList()) {
- if (outputName.equals(nodeName)) {
- return node;
- }
- }
- } else if (node.getName().equals(nodeName)) {
- return node;
- }
- }
- throw new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph");
- }
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java
deleted file mode 100644
index 715c55d8323..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java
+++ /dev/null
@@ -1,52 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx;
-
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
-import com.yahoo.tensor.TensorType;
-import onnx.Onnx;
-
-/**
- * Converts and verifies ONNX tensor types into Vespa tensor types.
- *
- * @author lesters
- */
-public class TypeConverter {
-
- public static void verifyType(Onnx.TypeProto typeProto, OrderedTensorType type) {
- Onnx.TensorShapeProto shape = typeProto.getTensorType().getShape();
- if (shape != null) {
- if (shape.getDimCount() != type.rank()) {
- throw new IllegalArgumentException("Onnx shape of does not match Vespa shape");
- }
- for (int onnxIndex = 0; onnxIndex < type.dimensions().size(); ++onnxIndex) {
- int vespaIndex = type.dimensionMap(onnxIndex);
- Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(onnxIndex);
- TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex);
- if (onnxDimension.getDimValue() != vespaDimension.size().orElse(-1L)) {
- throw new IllegalArgumentException("Onnx dimensions of does not match Vespa dimensions");
- }
- }
- }
- }
-
- public static OrderedTensorType fromOnnxType(Onnx.TypeProto type) {
- return fromOnnxType(type, "d"); // standard naming convention: d0, d1, ...
- }
-
- public static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) {
- Onnx.TensorShapeProto shape = type.getTensorType().getShape();
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
- for (int i = 0; i < shape.getDimCount(); ++ i) {
- String dimensionName = dimensionPrefix + i;
- Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
- if (onnxDimension.getDimValue() >= 0) {
- builder.add(TensorType.Dimension.indexed(dimensionName, onnxDimension.getDimValue()));
- } else {
- builder.add(TensorType.Dimension.indexed(dimensionName));
- }
- }
- return builder.build();
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java
deleted file mode 100644
index 19ba146492c..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
-import com.yahoo.tensor.functions.TensorFunction;
-
-import java.util.Collections;
-import java.util.List;
-
-public class NoOp extends IntermediateOperation {
-
- public NoOp(String modelName, String nodeName, List<IntermediateOperation> inputs) {
- super(modelName, nodeName, Collections.emptyList()); // don't propagate inputs
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- return null;
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- return null;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java
deleted file mode 100644
index a815cbc3944..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java
+++ /dev/null
@@ -1,85 +0,0 @@
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
-
-import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.NodeDef;
-
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-import java.util.stream.Collectors;
-
-/**
- * Converts TensorFlow node attributes to Vespa attribute values.
- *
- * @author lesters
- */
-public class AttributeConverter implements IntermediateOperation.AttributeMap {
-
- private final Map<String, AttrValue> attributeMap;
-
- public AttributeConverter(NodeDef node) {
- attributeMap = node.getAttrMap();
- }
-
- public static AttributeConverter convert(NodeDef node) {
- return new AttributeConverter(node);
- }
-
- @Override
- public Optional<Value> get(String key) {
- if (attributeMap.containsKey(key)) {
- AttrValue attrValue = attributeMap.get(key);
- if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) {
- return Optional.empty(); // requires type
- }
- if (attrValue.getValueCase() == AttrValue.ValueCase.B) {
- return Optional.of(new BooleanValue(attrValue.getB()));
- }
- if (attrValue.getValueCase() == AttrValue.ValueCase.I) {
- return Optional.of(new DoubleValue(attrValue.getI()));
- }
- if (attrValue.getValueCase() == AttrValue.ValueCase.F) {
- return Optional.of(new DoubleValue(attrValue.getF()));
- }
- }
- return Optional.empty();
- }
-
- @Override
- public Optional<Value> get(String key, OrderedTensorType type) {
- if (attributeMap.containsKey(key)) {
- AttrValue attrValue = attributeMap.get(key);
- if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) {
- return Optional.of(new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), type.type())));
- }
- }
- return get(key);
- }
-
- @Override
- public Optional<List<Value>> getList(String key) {
- if (attributeMap.containsKey(key)) {
- AttrValue attrValue = attributeMap.get(key);
- if (attrValue.getValueCase() == AttrValue.ValueCase.LIST) {
- AttrValue.ListValue listValue = attrValue.getList();
- if ( ! listValue.getBList().isEmpty()) {
- return Optional.of(listValue.getBList().stream().map(BooleanValue::new).collect(Collectors.toList()));
- }
- if ( ! listValue.getIList().isEmpty()) {
- return Optional.of(listValue.getIList().stream().map(DoubleValue::new).collect(Collectors.toList()));
- }
- if ( ! listValue.getFList().isEmpty()) {
- return Optional.of(listValue.getFList().stream().map(DoubleValue::new).collect(Collectors.toList()));
- }
- // add the rest
- }
- }
- return Optional.empty();
- }
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java
deleted file mode 100644
index e1b292f9e61..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java
+++ /dev/null
@@ -1,234 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
-
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Argument;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ConcatV2;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Const;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ExpandDims;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Identity;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Join;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Map;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.MatMul;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Mean;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Merge;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.NoOp;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.PlaceholderWithDefault;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Reshape;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Select;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Shape;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Squeeze;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Switch;
-import com.yahoo.tensor.functions.ScalarFunctions;
-import org.tensorflow.SavedModelBundle;
-import org.tensorflow.Session;
-import org.tensorflow.framework.GraphDef;
-import org.tensorflow.framework.MetaGraphDef;
-import org.tensorflow.framework.NodeDef;
-import org.tensorflow.framework.SignatureDef;
-import org.tensorflow.framework.TensorInfo;
-
-import java.io.IOException;
-import java.util.List;
-import java.util.stream.Collectors;
-
-/**
- * Converts a TensorFlow graph to a Vespa IntermediateGraph which is the basis
- * for generating Vespa ranking expressions.
- *
- * @author lesters
- */
-public class GraphImporter {
-
- public static IntermediateOperation mapOperation(NodeDef node,
- List<IntermediateOperation> inputs,
- IntermediateGraph graph) {
- String nodeName = node.getName();
- String modelName = graph.name();
- int nodePort = IntermediateOperation.indexPartOf(nodeName);
- OrderedTensorType nodeType = TypeConverter.fromTensorFlowType(node);
- AttributeConverter attributes = AttributeConverter.convert(node);
-
- switch (node.getOp().toLowerCase()) {
- // array ops
- case "concatv2": return new ConcatV2(modelName, nodeName, inputs);
- case "const": return new Const(modelName, nodeName, inputs, attributes, nodeType);
- case "expanddims": return new ExpandDims(modelName, nodeName, inputs);
- case "identity": return new Identity(modelName, nodeName, inputs);
- case "placeholder": return new Argument(modelName, nodeName, nodeType);
- case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, nodeName, inputs);
- case "reshape": return new Reshape(modelName, nodeName, inputs);
- case "shape": return new Shape(modelName, nodeName, inputs);
- case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes);
-
- // control flow
- case "merge": return new Merge(modelName, nodeName, inputs);
- case "switch": return new Switch(modelName, nodeName, inputs, nodePort);
-
- // math ops
- case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
- case "add_n": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
- case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
- case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
- case "realdiv": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
- case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor());
- case "matmul": return new MatMul(modelName, nodeName, inputs);
- case "maximum": return new Join(modelName, nodeName, inputs, ScalarFunctions.max());
- case "mean": return new Mean(modelName, nodeName, inputs, attributes);
- case "reducemean": return new Mean(modelName, nodeName, inputs, attributes);
- case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
- case "multiply": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
- case "rsqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.rsqrt());
- case "select": return new Select(modelName, nodeName, inputs);
- case "where3": return new Select(modelName, nodeName, inputs);
- case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid());
- case "squareddifference": return new Join(modelName, nodeName, inputs, ScalarFunctions.squareddifference());
- case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
- case "subtract": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
-
- // nn ops
- case "biasadd": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
- case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
- case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
- case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
-
- // state ops
- case "variable": return new Constant(modelName, nodeName, nodeType);
- case "variablev2": return new Constant(modelName, nodeName, nodeType);
-
- // evaluation no-ops
- case "stopgradient":return new Identity(modelName, nodeName, inputs);
- case "noop": return new NoOp(modelName, nodeName, inputs);
-
- }
-
- IntermediateOperation op = new NoOp(modelName, node.getName(), inputs);
- op.warning("Operation '" + node.getOp() + "' is currently not implemented");
- return op;
- }
-
- public static IntermediateGraph importGraph(String modelName, SavedModelBundle bundle) throws IOException {
- MetaGraphDef tfGraph = MetaGraphDef.parseFrom(bundle.metaGraphDef());
-
- IntermediateGraph intermediateGraph = new IntermediateGraph(modelName);
- importSignatures(tfGraph, intermediateGraph);
- importOperations(tfGraph, intermediateGraph, bundle);
- verifyOutputTypes(tfGraph, intermediateGraph);
-
- return intermediateGraph;
- }
-
- private static void importSignatures(MetaGraphDef tfGraph, IntermediateGraph intermediateGraph) {
- for (java.util.Map.Entry<String, SignatureDef> signatureEntry : tfGraph.getSignatureDefMap().entrySet()) {
- String signatureName = signatureEntry.getKey();
- java.util.Map<String, TensorInfo> inputInfoMap = signatureEntry.getValue().getInputsMap();
- for (java.util.Map.Entry<String, TensorInfo> input : inputInfoMap.entrySet()) {
- String inputName = input.getKey();
- String nodeName = input.getValue().getName();
- intermediateGraph.inputs(signatureName).put(inputName, IntermediateOperation.namePartOf(nodeName));
- }
- java.util.Map<String, TensorInfo> outputInfoMap = signatureEntry.getValue().getOutputsMap();
- for (java.util.Map.Entry<String, TensorInfo> output : outputInfoMap.entrySet()) {
- String outputName = output.getKey();
- String nodeName = output.getValue().getName();
- intermediateGraph.outputs(signatureName).put(outputName, IntermediateOperation.namePartOf(nodeName));
- }
- }
- }
-
- private static void importOperations(MetaGraphDef tfGraph,
- IntermediateGraph intermediateGraph,
- SavedModelBundle bundle) {
- for (String signatureName : intermediateGraph.signatures()) {
- for (String outputName : intermediateGraph.outputs(signatureName).values()) {
- importOperation(outputName, tfGraph.getGraphDef(), intermediateGraph, bundle);
- }
- }
- }
-
- private static IntermediateOperation importOperation(String nodeName,
- GraphDef tfGraph,
- IntermediateGraph intermediateGraph,
- SavedModelBundle bundle) {
- if (intermediateGraph.alreadyImported(nodeName)) {
- return intermediateGraph.get(nodeName);
- }
- NodeDef node = getTensorFlowNodeFromGraph(IntermediateOperation.namePartOf(nodeName), tfGraph);
- List<IntermediateOperation> inputs = importOperationInputs(node, tfGraph, intermediateGraph, bundle);
- IntermediateOperation operation = mapOperation(node, inputs, intermediateGraph);
- intermediateGraph.put(nodeName, operation);
-
- List<IntermediateOperation> controlInputs = importControlInputs(node, tfGraph, intermediateGraph, bundle);
- if (controlInputs.size() > 0) {
- operation.setControlInputs(controlInputs);
- }
-
- if (operation.isConstant()) {
- operation.setConstantValueFunction(
- type -> new TensorValue(TensorConverter.toVespaTensor(readVariable(nodeName, bundle), type)));
- }
-
- return operation;
- }
-
- private static List<IntermediateOperation> importOperationInputs(NodeDef node,
- GraphDef tfGraph,
- IntermediateGraph intermediateGraph,
- SavedModelBundle bundle) {
- return node.getInputList().stream()
- .filter(name -> ! isControlDependency(name))
- .map(nodeName -> importOperation(nodeName, tfGraph, intermediateGraph, bundle))
- .collect(Collectors.toList());
- }
-
- private static List<IntermediateOperation> importControlInputs(NodeDef node,
- GraphDef tfGraph,
- IntermediateGraph intermediateGraph,
- SavedModelBundle bundle) {
- return node.getInputList().stream()
- .filter(nodeName -> isControlDependency(nodeName))
- .map(nodeName -> importOperation(nodeName, tfGraph, intermediateGraph, bundle))
- .collect(Collectors.toList());
- }
-
- private static boolean isControlDependency(String name) {
- return name.startsWith("^");
- }
-
- private static NodeDef getTensorFlowNodeFromGraph(String name, GraphDef tfGraph) {
- for (NodeDef node : tfGraph.getNodeList()) {
- if (node.getName().equals(name)) {
- return node;
- }
- }
- throw new IllegalArgumentException("Could not find node '" + name + "'");
- }
-
- public static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) {
- Session.Runner fetched = bundle.session().runner().fetch(name);
- List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
- if (importedTensors.size() != 1)
- throw new IllegalStateException("Expected 1 tensor from fetching " + name +
- ", but got " + importedTensors.size());
- return importedTensors.get(0);
- }
-
- private static void verifyOutputTypes(MetaGraphDef tfGraph, IntermediateGraph intermediateGraph) {
- for (String signatureName : intermediateGraph.signatures()) {
- for (String outputName : intermediateGraph.outputs(signatureName).values()) {
- IntermediateOperation operation = intermediateGraph.get(outputName);
- NodeDef node = getTensorFlowNodeFromGraph(IntermediateOperation.namePartOf(operation.name()), tfGraph.getGraphDef());
- OrderedTensorType type = operation.type().orElseThrow(
- () -> new IllegalArgumentException("Output of '" + outputName + "' has no type."));
- TypeConverter.verifyType(node, type);
- }
- }
-
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java
deleted file mode 100644
index 67ad1edc312..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java
+++ /dev/null
@@ -1,72 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
-
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
-import com.yahoo.tensor.TensorType;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.NodeDef;
-import org.tensorflow.framework.TensorShapeProto;
-
-import java.util.List;
-
-/**
- * Converts and verifies TensorFlow tensor types into Vespa tensor types.
- *
- * @author lesters
- */
-public class TypeConverter {
-
- public static void verifyType(NodeDef node, OrderedTensorType type) {
- TensorShapeProto shape = tensorFlowShape(node);
- if (shape != null) {
- if (shape.getDimCount() != type.rank()) {
- throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " +
- "does not match Vespa shape");
- }
- for (int tensorFlowIndex = 0; tensorFlowIndex < type.dimensions().size(); ++tensorFlowIndex) {
- int vespaIndex = type.dimensionMap(tensorFlowIndex);
- TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex);
- TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex);
- if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) {
- throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " +
- "does not match Vespa dimensions");
- }
- }
- }
- }
-
- private static TensorShapeProto tensorFlowShape(NodeDef node) {
- AttrValue attrValueList = node.getAttrMap().get("_output_shapes");
- if (attrValueList == null) {
- throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
- "does not exist");
- }
- if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) {
- throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
- "is not of expected type");
- }
- List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList();
- return shapeList.get(0); // support multiple outputs?
- }
-
- public static OrderedTensorType fromTensorFlowType(NodeDef node) {
- return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ...
- }
-
- public static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
- TensorShapeProto shape = tensorFlowShape(node);
- for (int i = 0; i < shape.getDimCount(); ++ i) {
- String dimensionName = dimensionPrefix + i;
- TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i);
- if (tensorFlowDimension.getSize() >= 0) {
- builder.add(TensorType.Dimension.indexed(dimensionName, tensorFlowDimension.getSize()));
- } else {
- builder.add(TensorType.Dimension.indexed(dimensionName));
- }
- }
- return builder.build();
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java
new file mode 100644
index 00000000000..fa1f929cc80
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java
@@ -0,0 +1,326 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchlib.rankingexpression.integration.onnx;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Constant;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Argument;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OperationMapper;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.yolean.Exceptions;
+import onnx.Onnx;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.logging.Logger;
+import java.util.stream.Collectors;
+
+/**
+ * Converts an ONNX model into a ranking expression and set of constants.
+ *
+ * @author lesters
+ */
+public class OnnxImporter {
+
+ private static final Logger log = Logger.getLogger(OnnxImporter.class.getName());
+
+ public OnnxModel importModel(String modelName, File modelDir) {
+ return importModel(modelName, modelDir.toString());
+ }
+
+    /**
+     * Opens the file at the given path, parses it as an ONNX model protobuf and imports it.
+     *
+     * @param modelName the name to give the imported model
+     * @param modelPath path of the file to read
+     * @return the imported model
+     * @throws IllegalArgumentException if an IOException occurs while opening or parsing the file
+     */
+    public OnnxModel importModel(String modelName, String modelPath) {
+        try (FileInputStream modelStream = new FileInputStream(modelPath)) {
+            return importModel(modelName, Onnx.ModelProto.parseFrom(modelStream));
+        } catch (IOException e) {
+            throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e);
+        }
+    }
+
+ public OnnxModel importModel(String modelName, Onnx.ModelProto model) {
+ return importGraph(modelName, model.getGraph());
+ }
+
+    /**
+     * Runs the import steps in order: import the nodes, verify output types, settle
+     * dimension names, generate expressions and finally collect warnings.
+     */
+    private static OnnxModel importGraph(String modelName, Onnx.GraphProto graph) {
+        OnnxModel imported = new OnnxModel(modelName);
+        OperationIndex operations = new OperationIndex();
+
+        importNodes(graph, imported, operations);
+        verifyOutputTypes(graph, imported, operations);
+        findDimensionNames(imported, operations);
+        importExpressions(imported, operations);
+        reportWarnings(imported, operations);
+
+        return imported;
+    }
+
+    /** Imports every output of the graph; importNode pulls in their inputs recursively. */
+    private static void importNodes(Onnx.GraphProto graph, OnnxModel model, OperationIndex index) {
+        graph.getOutputList().forEach(output -> importNode(output.getName(), graph, model, index));
+    }
+
+ private static OnnxOperation importNode(String name, Onnx.GraphProto graph, OnnxModel model, OperationIndex index) {
+ if (index.alreadyImported(name)) {
+ return index.get(name);
+ }
+ OnnxOperation operation;
+ if (isArgumentTensor(name, graph)) {
+ operation = new Argument(getArgumentTensor(name, graph));
+ model.input(OnnxOperation.namePartOf(name), operation.vespaName());
+ } else if (isConstantTensor(name, graph)) {
+ operation = new Constant(model.name(), getConstantTensor(name, graph));
+ } else {
+ Onnx.NodeProto node = getNodeFromGraph(name, graph);
+ List<OnnxOperation> inputs = importNodeInputs(node, graph, model, index);
+ operation = OperationMapper.get(node, inputs);
+ if (isOutputNode(name, graph)) {
+ model.output(OnnxOperation.namePartOf(name), operation.vespaName());
+ }
+ }
+ index.put(operation.vespaName(), operation);
+
+ return operation;
+ }
+
+    /** Returns true if the name is a graph input which has no initializer. */
+    private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) {
+        boolean isGraphInput = getArgumentTensor(name, graph) != null;
+        boolean hasInitializer = getConstantTensor(name, graph) != null;
+        return isGraphInput && !hasInitializer;
+    }
+
+    /** Returns true if the name is a graph input which also has an initializer. */
+    private static boolean isConstantTensor(String name, Onnx.GraphProto graph) {
+        boolean isGraphInput = getArgumentTensor(name, graph) != null;
+        boolean hasInitializer = getConstantTensor(name, graph) != null;
+        return isGraphInput && hasInitializer;
+    }
+
+    /** Returns the first graph input with the given name, or null if there is none. */
+    private static Onnx.ValueInfoProto getArgumentTensor(String name, Onnx.GraphProto graph) {
+        return graph.getInputList().stream()
+                    .filter(input -> input.getName().equals(name))
+                    .findFirst()
+                    .orElse(null);
+    }
+
+    /** Returns the first graph initializer with the given name, or null if there is none. */
+    private static Onnx.TensorProto getConstantTensor(String name, Onnx.GraphProto graph) {
+        return graph.getInitializerList().stream()
+                    .filter(initializer -> initializer.getName().equals(name))
+                    .findFirst()
+                    .orElse(null);
+    }
+
+    /** Returns true if the name matches one of the graph outputs, see getOutputNode. */
+    private static boolean isOutputNode(String name, Onnx.GraphProto graph) {
+        Onnx.ValueInfoProto match = getOutputNode(name, graph);
+        return match != null;
+    }
+
+    /**
+     * Returns the first graph output whose full name, or whose name part as given by
+     * OnnxOperation.namePartOf, equals the given name. Returns null if none matches.
+     */
+    private static Onnx.ValueInfoProto getOutputNode(String name, Onnx.GraphProto graph) {
+        for (Onnx.ValueInfoProto output : graph.getOutputList()) {
+            String fullName = output.getName();
+            if (fullName.equals(name) || OnnxOperation.namePartOf(fullName).equals(name)) {
+                return output;
+            }
+        }
+        return null;
+    }
+
+    /** Imports each input of the node, in the node's input order, and returns the operations. */
+    private static List<OnnxOperation> importNodeInputs(Onnx.NodeProto node,
+                                                        Onnx.GraphProto graph,
+                                                        OnnxModel model,
+                                                        OperationIndex index) {
+        List<OnnxOperation> importedInputs = node.getInputList()
+                                                 .stream()
+                                                 .map(inputName -> importNode(inputName, graph, model, index))
+                                                 .collect(Collectors.toList());
+        return importedInputs;
+    }
+
+    /**
+     * Verifies that the type of each imported output agrees with the type declared
+     * for the corresponding output in the ONNX graph.
+     *
+     * @throws IllegalArgumentException if an output operation has no type, if no graph output
+     *                                  matches the output name, or if type verification fails
+     */
+    private static void verifyOutputTypes(Onnx.GraphProto graph, OnnxModel model, OperationIndex index) {
+        for (String outputName : model.outputs().values()) {
+            OrderedTensorType outputType = index.get(outputName).type().orElseThrow(
+                    () -> new IllegalArgumentException("Output of '" + outputName + "' has no type."));
+
+            // getOutputNode returns null when nothing matches; fail with a clear message
+            // instead of a NullPointerException on the getType() call below.
+            Onnx.ValueInfoProto onnxNode = getOutputNode(outputName, graph);
+            if (onnxNode == null) {
+                throw new IllegalArgumentException("Output '" + outputName + "' was not found among " +
+                                                   "the outputs of the ONNX graph");
+            }
+            outputType.verifyType(onnxNode.getType());
+        }
+    }
+
+
+    /**
+     * Finds dimension names to avoid excessive renaming while evaluating the model:
+     * collects constraints from all outputs, solves them, then applies the renaming.
+     */
+    private static void findDimensionNames(OnnxModel model, OperationIndex index) {
+        DimensionRenamer renamer = new DimensionRenamer();
+        Collection<String> outputNames = model.outputs().values();
+
+        outputNames.forEach(outputName -> addDimensionNameConstraints(index.get(outputName), renamer));
+        renamer.solve();
+        outputNames.forEach(outputName -> renameDimensions(index.get(outputName), renamer));
+    }
+
+    /** Adds the constraints of the inputs, then of the operation itself; untyped operations are skipped. */
+    private static void addDimensionNameConstraints(OnnxOperation operation, DimensionRenamer renamer) {
+        if (!operation.type().isPresent()) {
+            return;
+        }
+        for (OnnxOperation input : operation.inputs()) {
+            addDimensionNameConstraints(input, renamer);
+        }
+        operation.addDimensionNameConstraints(renamer);
+    }
+
+    /** Renames the dimensions of the inputs, then of the operation itself; untyped operations are skipped. */
+    private static void renameDimensions(OnnxOperation operation, DimensionRenamer renamer) {
+        if (!operation.type().isPresent()) {
+            return;
+        }
+        for (OnnxOperation input : operation.inputs()) {
+            renameDimensions(input, renamer);
+        }
+        operation.renameDimensions(renamer);
+    }
+
+    /**
+     * Imports the expression of every output. An output which yields no function, or
+     * which fails with an IllegalArgumentException, is recorded as skipped with the reason.
+     */
+    private static void importExpressions(OnnxModel model, OperationIndex index) {
+        for (String outputName : model.outputs().values()) {
+            try {
+                boolean imported = importExpression(index.get(outputName), model).isPresent();
+                if (!imported) {
+                    model.skippedOutput(outputName, "No valid output function could be found.");
+                }
+            } catch (IllegalArgumentException e) {
+                model.skippedOutput(outputName, Exceptions.toMessageString(e));
+            }
+        }
+    }
+
+    /**
+     * Imports the operation and, recursively, its inputs into the model.
+     *
+     * @return the function of the operation, or empty if the operation has no type
+     */
+    private static Optional<TensorFunction> importExpression(OnnxOperation operation, OnnxModel model) {
+        boolean typed = operation.type().isPresent();
+        if (!typed) {
+            return Optional.empty();
+        }
+        if (operation.isConstant()) {
+            return importConstant(operation, model);
+        }
+
+        // Inputs first, so that the expressions this one refers to are registered before it.
+        importInputExpressions(operation, model);
+        importRankingExpression(operation, model);
+        importArgumentExpression(operation, model);
+        return operation.function();
+    }
+
+    /** Imports the expression of each input of the operation. */
+    private static void importInputExpressions(OnnxOperation operation, OnnxModel model) {
+        for (OnnxOperation input : operation.inputs()) {
+            importExpression(input, model);
+        }
+    }
+
+    /**
+     * Registers the value of a constant operation in the model, unless it is already there.
+     * Rank 0 tensors are stored as small constants, all other tensors as large constants.
+     *
+     * @return the function of the operation
+     * @throws IllegalArgumentException if the operation has no constant value
+     */
+    private static Optional<TensorFunction> importConstant(OnnxOperation operation, OnnxModel model) {
+        String constantName = operation.vespaName();
+        boolean alreadyStored = model.largeConstants().containsKey(constantName)
+                                || model.smallConstants().containsKey(constantName);
+        if (alreadyStored) {
+            return operation.function();
+        }
+
+        Value constantValue = operation.getConstantValue().orElseThrow(() ->
+                new IllegalArgumentException("Operation '" + operation.vespaName() + "' " +
+                                             "is constant but does not have a value."));
+
+        // Only tensor values are stored; scalar values are inserted directly into the expression.
+        if (constantValue instanceof TensorValue) {
+            Tensor tensor = constantValue.asTensor();
+            if (tensor.type().rank() == 0) {
+                model.smallConstant(constantName, tensor);
+            } else {
+                model.largeConstant(constantName, tensor);
+            }
+        }
+        return operation.function();
+    }
+
+    /**
+     * Adds the function of the operation to the model as a ranking expression, unless the
+     * operation has no function or an expression with its name is already present.
+     * Functions of outputs are wrapped in a rename to the standard dimension names
+     * (d0, d1, ...) when their dimension names differ from those.
+     *
+     * @throws RuntimeException if the function cannot be parsed as a ranking expression
+     */
+    private static void importRankingExpression(OnnxOperation operation, OnnxModel model) {
+        if (!operation.function().isPresent()) {
+            return;
+        }
+        String name = operation.vespaName();
+        if (model.expressions().containsKey(name)) {
+            return;
+        }
+
+        TensorFunction function = operation.function().get();
+
+        // NOTE(review): 'name' is the Vespa name, while the keys of outputs() are set from
+        // OnnxOperation.namePartOf(...) - confirm that a key lookup is what is intended here.
+        if (model.outputs().containsKey(name)) {
+            OrderedTensorType operationType = operation.type().get();
+            OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType);
+            if (!operationType.equals(standardNamingType)) {
+                List<String> renameFrom = operationType.dimensionNames();
+                List<String> renameTo = standardNamingType.dimensionNames();
+                function = new Rename(function, renameFrom, renameTo);
+            }
+        }
+
+        try {
+            // We add all intermediate nodes imported as separate expressions. Only
+            // those referenced from the output will be used. We parse the
+            // TensorFunction here to convert it to a RankingExpression tree.
+            model.expression(name, new RankingExpression(name, function.toString()));
+        } catch (ParseException e) {
+            throw new RuntimeException("ONNX function " + function +
+                                       " cannot be parsed as a ranking expression", e);
+        }
+    }
+
+    /**
+     * Registers an input operation as both an argument and a required macro of the model,
+     * typed with the standard dimension naming convention. Does nothing for other operations.
+     */
+    private static void importArgumentExpression(OnnxOperation operation, OnnxModel model) {
+        if (!operation.isInput()) {
+            return;
+        }
+        // All inputs must have dimensions with standard naming convention: d0, d1, ...
+        OrderedTensorType standardType = OrderedTensorType.standardType(operation.type().get());
+        String argumentName = operation.vespaName();
+        model.argument(argumentName, standardType.type());
+        model.requiredMacro(operation.vespaName(), standardType.type());
+    }
+
+    /** Collects the warnings of every output operation, and of their inputs, into the model. */
+    private static void reportWarnings(OnnxModel model, OperationIndex index) {
+        model.outputs().values().forEach(outputName -> reportWarnings(model, index.get(outputName)));
+    }
+
+    /** Adds the warnings of the operation to the model, then recurses into its inputs. */
+    private static void reportWarnings(OnnxModel model, OnnxOperation operation) {
+        for (String message : operation.warnings()) {
+            model.importWarning(message);
+        }
+        operation.inputs().forEach(input -> reportWarnings(model, input));
+    }
+
+    /**
+     * Finds the graph node for the given name. A name containing ':' is matched against
+     * the output names of each node; any other name is matched against the node names.
+     *
+     * @throws IllegalArgumentException if no node matches
+     */
+    private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) {
+        if (nodeName.contains(":")) {
+            for (Onnx.NodeProto candidate : graph.getNodeList()) {
+                if (candidate.getOutputList().contains(nodeName)) {
+                    return candidate;
+                }
+            }
+        } else {
+            for (Onnx.NodeProto candidate : graph.getNodeList()) {
+                if (candidate.getName().equals(nodeName)) {
+                    return candidate;
+                }
+            }
+        }
+        throw new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph");
+    }
+
+ private static class OperationIndex {
+ private final Map<String, OnnxOperation> index = new HashMap<>();
+ public OnnxOperation put(String key, OnnxOperation operation) { return index.put(key, operation); }
+ public OnnxOperation get(String key) { return index.get(key); }
+ public boolean alreadyImported(String key) { return index.containsKey(key); }
+ public Collection<OnnxOperation> operations() { return index.values(); }
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java
new file mode 100644
index 00000000000..bd53afefc3f
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java
@@ -0,0 +1,112 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchlib.rankingexpression.integration.onnx;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.regex.Pattern;
+
+/**
+ * The result of importing an ONNX model into Vespa.
+ *
+ * @author bratseth
+ * @author lesters
+ */
+public class OnnxModel {
+
+    /** The pattern a model name must match in full. */
+    private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*");
+
+    private final String name;
+
+    /**
+     * Creates an empty model with the given name.
+     *
+     * @param name the model name, which can only contain the characters in [A-Za-z0-9_]
+     * @throws IllegalArgumentException if the name contains any other character
+     */
+    public OnnxModel(String name) {
+        if ( ! nameRegexp.matcher(name).matches())
+            throw new IllegalArgumentException("An ONNX model name can only contain [A-Za-z0-9_], but is '" +
+                                               name + "'");
+        this.name = name;
+    }
+
+    /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */
+    public String name() { return name; }
+
+    private final Map<String, String> inputs = new HashMap<>();
+    private final Map<String, String> outputs = new HashMap<>();
+    private final Map<String, String> skippedOutputs = new HashMap<>();
+    private final List<String> importWarnings = new ArrayList<>();
+
+    private final Map<String, TensorType> arguments = new HashMap<>();
+    private final Map<String, Tensor> smallConstants = new HashMap<>();
+    private final Map<String, Tensor> largeConstants = new HashMap<>();
+    private final Map<String, RankingExpression> expressions = new HashMap<>();
+    private final Map<String, RankingExpression> macros = new HashMap<>();
+    private final Map<String, TensorType> requiredMacros = new HashMap<>();
+
+    // Package private mutators, used while the model is being imported.
+    void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); }
+    void output(String name, String expressionName) { outputs.put(name, expressionName); }
+    void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); }
+    void importWarning(String warning) { importWarnings.add(warning); }
+    void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
+    void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
+    void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
+    void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
+    void macro(String name, RankingExpression expression) { macros.put(name, expression); }
+    void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); }
+
+    /**
+     * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name
+     * to argument name in the owner of this
+     */
+    public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); }
+
+    /** Returns arguments().get(inputs.get(name)), i.e. the type of the argument this input references */
+    public TensorType inputArgument(String inputName) { return arguments().get(inputs.get(inputName)); }
+
+    /** Returns an immutable map of the outputs of this, from output name to expression name */
+    public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); }
+
+    /**
+     * Returns an immutable map of the outputs of this which could not be imported,
+     * from output name to a string detailing the reason
+     */
+    public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); }
+
+    /**
+     * Returns an immutable list of possibly non-fatal warnings encountered during import.
+     */
+    public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); }
+
+    /** Returns expressions().get(outputs.get(outputName)), i.e. the expression this output references */
+    public RankingExpression outputExpression(String outputName) { return expressions().get(outputs.get(outputName)); }
+
+    /** Returns an immutable map of the arguments (inputs) of this */
+    public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
+
+    /**
+     * Returns an immutable map of the small constants of this.
+     */
+    public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); }
+
+    /**
+     * Returns an immutable map of the large constants of this.
+     */
+    public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); }
+
+    /**
+     * Returns an immutable map of the expressions of this - corresponding to ONNX nodes
+     * which are not inputs or constants.
+     */
+    public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
+
+    /** Returns an immutable map of macros that are part of this model */
+    public Map<String, RankingExpression> macros() { return Collections.unmodifiableMap(macros); }
+
+    /** Returns an immutable map of the macros that must be provided by the environment running this model */
+    public Map<String, TensorType> requiredMacros() { return Collections.unmodifiableMap(requiredMacros); }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/DimensionRenamer.java
index 38f1d2329e2..2524417cee0 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/DimensionRenamer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/DimensionRenamer.java
@@ -1,7 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
+package com.yahoo.searchlib.rankingexpression.integration.onnx.importer;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation;
import java.util.ArrayDeque;
import java.util.ArrayList;
@@ -47,7 +47,7 @@ public class DimensionRenamer {
/**
* Add a constraint between dimension names.
*/
- public void addConstraint(String from, String to, Constraint pred, IntermediateOperation operation) {
+ public void addConstraint(String from, String to, Constraint pred, OnnxOperation operation) {
Arc arc = new Arc(from, to, operation);
Arc opposite = arc.opposite();
constraints.put(arc, pred);
@@ -175,9 +175,9 @@ public class DimensionRenamer {
private final String from;
private final String to;
- private final IntermediateOperation operation;
+ private final OnnxOperation operation;
- Arc(String from, String to, IntermediateOperation operation) {
+ Arc(String from, String to, OnnxOperation operation) {
this.from = from;
this.to = to;
this.operation = operation;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java
new file mode 100644
index 00000000000..12090145d3a
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java
@@ -0,0 +1,26 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchlib.rankingexpression.integration.onnx.importer;
+
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Join;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.MatMul;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.NoOp;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import onnx.Onnx;
+
+import java.util.List;
+
+/**
+ * Maps ONNX nodes to the operations used during import.
+ */
+public class OperationMapper {
+
+    /**
+     * Returns the operation for the given ONNX node. The op type is matched case insensitively.
+     * Op types without a mapping yield a NoOp carrying a "not implemented" warning.
+     *
+     * @param node the ONNX node to map
+     * @param inputs the already imported inputs of the node
+     * @return the operation for the node
+     */
+    public static OnnxOperation get(Onnx.NodeProto node, List<OnnxOperation> inputs) {
+        // Lower-case with a fixed locale so that matching does not depend on the JVM default locale.
+        switch (node.getOpType().toLowerCase(java.util.Locale.ROOT)) {
+            case "add":    return new Join(node, inputs, ScalarFunctions.add());
+            case "matmul": return new MatMul(node, inputs);
+        }
+
+        OnnxOperation op = new NoOp(node, inputs);
+        op.warning("Operation '" + node.getOpType() + "' is currently not implemented");
+        return op;
+    }
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OrderedTensorType.java
index 209d73a9f38..812e9b8d678 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/OrderedTensorType.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OrderedTensorType.java
@@ -1,9 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
+package com.yahoo.searchlib.rankingexpression.integration.onnx.importer;
import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.TensorTypeParser;
+import onnx.Onnx;
import java.util.ArrayList;
import java.util.Collections;
@@ -13,9 +13,9 @@ import java.util.stream.Collectors;
/**
* A Vespa tensor type is ordered by the lexicographical ordering of dimension
- * names. Imported tensors have an explicit ordering of their dimensions.
+ * names. ONNX tensors have an explicit ordering of their dimensions.
* During import, we need to track the Vespa dimension that matches the
- * corresponding imported dimension as the ordering can change after
+ * corresponding ONNX dimension as the ordering can change after
* dimension renaming. That is the purpose of this class.
*
* @author lesters
@@ -25,14 +25,14 @@ public class OrderedTensorType {
private final TensorType type;
private final List<TensorType.Dimension> dimensions;
- private final long[] innerSizesOriginal;
+ private final long[] innerSizesOnnx;
private final long[] innerSizesVespa;
private final int[] dimensionMap;
private OrderedTensorType(List<TensorType.Dimension> dimensions) {
this.dimensions = Collections.unmodifiableList(dimensions);
this.type = new TensorType.Builder(dimensions).build();
- this.innerSizesOriginal = new long[dimensions.size()];
+ this.innerSizesOnnx = new long[dimensions.size()];
this.innerSizesVespa = new long[dimensions.size()];
this.dimensionMap = createDimensionMap();
}
@@ -54,10 +54,10 @@ public class OrderedTensorType {
if (numDimensions == 0) {
return null;
}
- innerSizesOriginal[numDimensions - 1] = 1;
+ innerSizesOnnx[numDimensions - 1] = 1;
innerSizesVespa[numDimensions - 1] = 1;
for (int i = numDimensions - 1; --i >= 0; ) {
- innerSizesOriginal[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesOriginal[i+1];
+ innerSizesOnnx[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesOnnx[i+1];
innerSizesVespa[i] = type.dimensions().get(i+1).size().orElse(-1L) * innerSizesVespa[i+1];
}
int[] mapping = new int[numDimensions];
@@ -74,15 +74,11 @@ public class OrderedTensorType {
return mapping;
}
- public int dimensionMap(int originalIndex) {
- return dimensionMap[originalIndex];
- }
-
/**
- * When dimension ordering between Vespa and imported differs, i.e.
+ * When dimension ordering between Vespa and Onnx differs, i.e.
* after dimension renaming, use the dimension map to read in values
* so that they are correctly laid out in memory for Vespa.
- * Used when importing tensors.
+ * Used when importing tensors from Onnx.
*/
public int toDirectIndex(int index) {
if (dimensions.size() == 0) {
@@ -94,9 +90,9 @@ public class OrderedTensorType {
int directIndex = 0;
long rest = index;
for (int i = 0; i < dimensions.size(); ++i) {
- long address = rest / innerSizesOriginal[i];
+ long address = rest / innerSizesOnnx[i];
directIndex += innerSizesVespa[dimensionMap[i]] * address;
- rest %= innerSizesOriginal[i];
+ rest %= innerSizesOnnx[i];
}
return directIndex;
}
@@ -120,6 +116,22 @@ public class OrderedTensorType {
return true;
}
+    /**
+     * Verifies that the shape of the given ONNX type agrees with this type: the rank must
+     * be the same, and each ONNX dimension value must equal the size of the Vespa dimension
+     * it maps to, where a Vespa dimension without a size is compared as -1.
+     *
+     * @param typeProto the ONNX type to verify this against
+     * @throws IllegalArgumentException if the ranks or any dimension sizes differ
+     */
+    public void verifyType(Onnx.TypeProto typeProto) {
+        Onnx.TensorShapeProto shape = typeProto.getTensorType().getShape();
+        // NOTE(review): protobuf getters normally return a default instance rather than null -
+        // confirm whether a hasShape() check was intended here.
+        if (shape != null) {
+            if (shape.getDimCount() != type.rank()) {
+                throw new IllegalArgumentException("Onnx shape does not match Vespa shape: Onnx rank is " +
+                                                   shape.getDimCount() + ", Vespa rank is " + type.rank());
+            }
+            for (int onnxIndex = 0; onnxIndex < dimensions.size(); ++onnxIndex) {
+                int vespaIndex = dimensionMap[onnxIndex];
+                Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(onnxIndex);
+                TensorType.Dimension vespaDimension = type().dimensions().get(vespaIndex);
+                if (onnxDimension.getDimValue() != vespaDimension.size().orElse(-1L)) {
+                    throw new IllegalArgumentException("Onnx dimension " + onnxIndex + " of size " +
+                                                       onnxDimension.getDimValue() +
+                                                       " does not match Vespa dimension " + vespaDimension);
+                }
+            }
+        }
+    }
public OrderedTensorType rename(DimensionRenamer renamer) {
List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size());
for (TensorType.Dimension dimension : dimensions) {
@@ -139,13 +151,18 @@ public class OrderedTensorType {
return new OrderedTensorType(renamedDimensions);
}
- public OrderedTensorType rename(String dimensionPrefix) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
- for (int i = 0; i < dimensions.size(); ++ i) {
+ public static OrderedTensorType fromOnnxType(Onnx.TypeProto type) {
+ return fromOnnxType(type, "d"); // standard naming convention: d0, d1, ...
+ }
+
+ public static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) {
+ Onnx.TensorShapeProto shape = type.getTensorType().getShape();
+ Builder builder = new Builder(shape);
+ for (int i = 0; i < shape.getDimCount(); ++ i) {
String dimensionName = dimensionPrefix + i;
- Optional<Long> dimSize = dimensions.get(i).size();
- if (dimSize.isPresent() && dimSize.get() >= 0) {
- builder.add(TensorType.Dimension.indexed(dimensionName, dimSize.get()));
+ Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
+ if (onnxDimension.getDimValue() >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, onnxDimension.getDimValue()));
} else {
builder.add(TensorType.Dimension.indexed(dimensionName));
}
@@ -153,13 +170,13 @@ public class OrderedTensorType {
return builder.build();
}
- public static OrderedTensorType standardType(OrderedTensorType type) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
- for (int i = 0; i < type.dimensions().size(); ++ i) {
- TensorType.Dimension dim = type.dimensions().get(i);
- String dimensionName = "d" + i;
- if (dim.size().isPresent() && dim.size().get() >= 0) {
- builder.add(TensorType.Dimension.indexed(dimensionName, dim.size().get()));
+ public static OrderedTensorType fromOnnxType(List<Long> dims, String dimensionPrefix) {
+ Builder builder = new Builder();
+ for (int i = 0; i < dims.size(); ++ i) {
+ String dimensionName = dimensionPrefix + i;
+ Long dimSize = dims.get(i);
+ if (dimSize >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, dimSize));
} else {
builder.add(TensorType.Dimension.indexed(dimensionName));
}
@@ -167,46 +184,13 @@ public class OrderedTensorType {
return builder.build();
}
- public static Long tensorSize(TensorType type) {
- Long size = 1L;
- for (TensorType.Dimension dimension : type.dimensions()) {
- size *= dimensionSize(dimension);
- }
- return size;
- }
-
- public static Long dimensionSize(TensorType.Dimension dim) {
- return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size"));
- }
-
- /**
- * Returns a string representation of this: A standard tensor type string where dimensions
- * are listed in the order of this rather than in the natural order of their names.
- */
- @Override
- public String toString() {
- return "tensor(" + dimensions.stream().map(TensorType.Dimension::toString).collect(Collectors.joining(",")) + ")";
- }
-
- /**
- * Creates an instance from the string representation of this: A standard tensor type string
- * where dimensions are listed in the order of this rather than the natural order of their names.
- */
- public static OrderedTensorType fromSpec(String typeSpec) {
- return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec(typeSpec));
- }
-
- public static OrderedTensorType fromDimensionList(List<Long> dims) {
- return fromDimensionList(dims, "d"); // standard naming convention: d0, d1, ...
- }
-
- public static OrderedTensorType fromDimensionList(List<Long> dims, String dimensionPrefix) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
- for (int i = 0; i < dims.size(); ++ i) {
- String dimensionName = dimensionPrefix + i;
- Long dimSize = dims.get(i);
- if (dimSize >= 0) {
- builder.add(TensorType.Dimension.indexed(dimensionName, dimSize));
+ public static OrderedTensorType standardType(OrderedTensorType type) {
+ Builder builder = new Builder();
+ for (int i = 0; i < type.dimensions().size(); ++ i) {
+ TensorType.Dimension dim = type.dimensions().get(i);
+ String dimensionName = "d" + i;
+ if (dim.size().isPresent() && dim.size().get() >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, dim.size().get()));
} else {
builder.add(TensorType.Dimension.indexed(dimensionName));
}
@@ -216,13 +200,45 @@ public class OrderedTensorType {
public static class Builder {
+ private final Onnx.TensorShapeProto shape;
private final List<TensorType.Dimension> dimensions;
+ public Builder(Onnx.TensorShapeProto shape) {
+ this.shape = shape;
+ this.dimensions = new ArrayList<>(shape.getDimCount());
+ }
+
public Builder() {
+ this.shape = null;
this.dimensions = new ArrayList<>();
}
public Builder add(TensorType.Dimension vespaDimension) {
+ if (shape != null) {
+ int index = dimensions.size();
+ Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(index);
+ long size = onnxDimension.getDimValue();
+ if (size >= 0) {
+ if (vespaDimension.type() != TensorType.Dimension.Type.indexedBound) {
+ throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " +
+ "dimension types");
+ }
+ if (!vespaDimension.size().isPresent()) {
+ throw new IllegalArgumentException("Tensor dimension is indexed bound but does " +
+ "not have a size");
+ }
+ if (vespaDimension.size().get() != size) {
+ throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " +
+ "dimension sizes. TensorFlow: " + size + " Vespa: " +
+ vespaDimension.size().get());
+ }
+ } else {
+ if (vespaDimension.type() != TensorType.Dimension.Type.indexedUnbound) {
+ throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " +
+ "dimension types");
+ }
+ }
+ }
this.dimensions.add(vespaDimension);
return this;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/TensorConverter.java
index 18856d4a25f..2912db03b5f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/TensorConverter.java
@@ -1,16 +1,17 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx;
+package com.yahoo.searchlib.rankingexpression.integration.onnx.importer;
import com.google.protobuf.ByteString;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import onnx.Onnx;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
+import java.util.List;
/**
* Converts Onnx tensors into Vespa tensors.
@@ -28,6 +29,7 @@ public class TensorConverter {
return builder.build();
}
+ /* todo: support more types */
private static Values readValuesOf(Onnx.TensorProto tensorProto) {
if (tensorProto.hasRawData()) {
switch (tensorProto.getDataType()) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java
new file mode 100644
index 00000000000..a8d8d63daf4
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java
@@ -0,0 +1,64 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.VariableTensor;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.TensorFunction;
+import onnx.Onnx;
+
+import java.util.Collections;
+import java.util.List;
+
+public class Argument extends OnnxOperation {
+
+ private Onnx.ValueInfoProto valueInfo;
+ private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ...
+
+ public Argument(Onnx.ValueInfoProto valueInfoProto) {
+ super(null, Collections.emptyList());
+ valueInfo = valueInfoProto;
+ standardNamingType = OrderedTensorType.fromOnnxType(valueInfo.getType());
+ }
+
+ @Override
+ public String vespaName() {
+ return vespaName(valueInfo.getName());
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ return OrderedTensorType.fromOnnxType(valueInfo.getType(), vespaName() + "_");
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ TensorFunction output = new VariableTensor(vespaName(), standardNamingType.type());
+ if (!standardNamingType.equals(type)) {
+ List<String> renameFrom = standardNamingType.dimensionNames();
+ List<String> renameTo = type.dimensionNames();
+ output = new Rename(output, renameFrom, renameTo);
+ }
+ return output;
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ for (TensorType.Dimension dimension : type.type().dimensions()) {
+ renamer.addDimension(dimension.name());
+ }
+ }
+
+ @Override
+ public boolean isInput() {
+ return true;
+ }
+
+ @Override
+ public boolean isConstant() {
+ return false;
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java
index 5e4abeaa234..13043a61a8e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java
@@ -1,34 +1,38 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.TensorConverter;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
+import onnx.Onnx;
import java.util.Collections;
import java.util.Optional;
-public class Constant extends IntermediateOperation {
+public class Constant extends OnnxOperation {
- private final String modelName;
+ final String modelName;
+ final Onnx.TensorProto tensorProto;
- public Constant(String modelName, String nodeName, OrderedTensorType type) {
- super(modelName, nodeName, Collections.emptyList());
+ public Constant(String modelName, Onnx.TensorProto tensorProto) {
+ super(null, Collections.emptyList());
this.modelName = modelName;
- this.type = type.rename(vespaName() + "_");
+ this.tensorProto = tensorProto;
}
/** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
@Override
public String vespaName() {
- return modelName + "_" + vespaName(name);
+ return modelName + "_" + vespaName(tensorProto.getName());
}
@Override
protected OrderedTensorType lazyGetType() {
- return type;
+ return OrderedTensorType.fromOnnxType(tensorProto.getDimsList(), vespaName() + "_");
}
@Override
@@ -36,14 +40,9 @@ public class Constant extends IntermediateOperation {
return null; // will be added by function() since this is constant.
}
- /**
- * Constant values are sent in via the constantValueFunction, as the
- * dimension names and thus the data layout depends on the dimension
- * renaming which happens after the conversion to intermediate graph.
- */
@Override
public Optional<Value> getConstantValue() {
- return Optional.ofNullable(constantValueFunction).map(func -> func.apply(type));
+ return Optional.of(new TensorValue(TensorConverter.toVespaTensor(tensorProto, type)));
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Join.java
index 8413ed74118..fe2004a528d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Join.java
@@ -1,22 +1,24 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.TensorFunction;
+import onnx.Onnx;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.List;
import java.util.function.DoubleBinaryOperator;
-public class Join extends IntermediateOperation {
+public class Join extends OnnxOperation {
private final DoubleBinaryOperator operator;
- public Join(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleBinaryOperator operator) {
- super(modelName, nodeName, inputs);
+ public Join(Onnx.NodeProto node, List<OnnxOperation> inputs, DoubleBinaryOperator operator) {
+ super(node, inputs);
this.operator = operator;
}
@@ -59,8 +61,8 @@ public class Join extends IntermediateOperation {
return null;
}
- IntermediateOperation a = largestInput();
- IntermediateOperation b = smallestInput();
+ OnnxOperation a = largestInput();
+ OnnxOperation b = smallestInput();
List<String> aDimensionsToReduce = new ArrayList<>();
List<String> bDimensionsToReduce = new ArrayList<>();
@@ -105,13 +107,13 @@ public class Join extends IntermediateOperation {
}
}
- private IntermediateOperation largestInput() {
+ private OnnxOperation largestInput() {
OrderedTensorType a = inputs.get(0).type().get();
OrderedTensorType b = inputs.get(1).type().get();
return a.rank() >= b.rank() ? inputs.get(0) : inputs.get(1);
}
- private IntermediateOperation smallestInput() {
+ private OnnxOperation smallestInput() {
OrderedTensorType a = inputs.get(0).type().get();
OrderedTensorType b = inputs.get(1).type().get();
return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/MatMul.java
index 52e223f9518..1b388e2ae89 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/MatMul.java
@@ -1,18 +1,21 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
+import onnx.Onnx;
+import java.util.Collections;
import java.util.List;
import java.util.Optional;
+import java.util.function.DoubleBinaryOperator;
-public class MatMul extends IntermediateOperation {
+public class MatMul extends OnnxOperation {
- public MatMul(String modelName, String nodeName, List<IntermediateOperation> inputs) {
- super(modelName, nodeName, inputs);
+ public MatMul(Onnx.NodeProto node, List<OnnxOperation> inputs) {
+ super(node, inputs);
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java
new file mode 100644
index 00000000000..b1136a0ce0a
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java
@@ -0,0 +1,32 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import onnx.Onnx;
+
+import java.util.Collections;
+import java.util.List;
+
+public class NoOp extends OnnxOperation {
+
+ public NoOp(Onnx.NodeProto node, List<OnnxOperation> inputs) {
+ super(node, Collections.emptyList()); // don't propagate inputs
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ return null;
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ return null;
+ }
+
+ @Override
+ public boolean isConstant() {
+ return true;
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java
new file mode 100644
index 00000000000..30f7b4f4711
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java
@@ -0,0 +1,139 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.functions.TensorFunction;
+import onnx.Onnx;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+import java.util.function.Function;
+
+/**
+ * Wraps an ONNX node and produces the respective Vespa tensor operation.
+ * During import, a graph of these operations is constructed. Then, the
+ * types are used to deduce sensible dimension names using the
+ * DimensionRenamer. After the types have been renamed, the proper
+ * Vespa expressions can be extracted.
+ *
+ * @author lesters
+ */
+public abstract class OnnxOperation {
+
+ protected final Onnx.NodeProto node; // can be null for onnx inputs and constants
+ protected final List<OnnxOperation> inputs;
+ protected final List<OnnxOperation> outputs = new ArrayList<>();
+ protected final List<String> importWarnings = new ArrayList<>();
+
+ protected OrderedTensorType type;
+ protected TensorFunction function;
+ protected Value constantValue = null;
+
+ OnnxOperation(Onnx.NodeProto node, List<OnnxOperation> inputs) {
+ this.node = node;
+ this.inputs = Collections.unmodifiableList(inputs);
+ this.inputs.forEach(i -> i.outputs.add(this));
+ }
+
+ protected abstract OrderedTensorType lazyGetType();
+ protected abstract TensorFunction lazyGetFunction();
+
+ /** Returns the Vespa tensor type of this operation if it exists */
+ public Optional<OrderedTensorType> type() {
+ if (type == null) {
+ type = lazyGetType();
+ }
+ return Optional.ofNullable(type);
+ }
+
+ /** Returns the Vespa tensor function implementing all operations from this node with inputs */
+ public Optional<TensorFunction> function() {
+ if (function == null) {
+ if (isConstant()) {
+ ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName()));
+ function = new TensorFunctionNode.TensorFunctionExpressionNode(constant);
+ } else {
+ function = lazyGetFunction();
+ }
+ }
+ return Optional.ofNullable(function);
+ }
+
+ /** Return Onnx node */
+ public Onnx.NodeProto node() { return node; }
+
+ /** Return unmodifiable list of inputs */
+ public List<OnnxOperation> inputs() { return inputs; }
+
+ /** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a macro. */
+ public List<OnnxOperation> outputs() { return Collections.unmodifiableList(outputs); }
+
+ /** Add dimension name constraints for this operation */
+ public void addDimensionNameConstraints(DimensionRenamer renamer) { }
+
+ /** Performs dimension rename for this operation */
+ public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); }
+
+ /** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */
+ public boolean isInput() { return false; }
+
+ /** Return true if this node is constant */
+ public boolean isConstant() { return inputs.stream().allMatch(OnnxOperation::isConstant); }
+
+ /** Gets the constant value if it exists */
+ public Optional<Value> getConstantValue() { return Optional.ofNullable(constantValue); }
+
+ /** Retrieve the valid Vespa name of this node */
+ public String vespaName() { return vespaName(node.getName()); }
+ public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; }
+
+ /** Retrieve the list of warnings produced during its lifetime */
+ public List<String> warnings() { return Collections.unmodifiableList(importWarnings); }
+
+ /** Add an import warning */
+ public void warning(String warning) { importWarnings.add(warning); }
+
+ boolean verifyInputs(int expected, Function<OnnxOperation, Optional<?>> func) {
+ if (inputs.size() != expected) {
+ throw new IllegalArgumentException("Expected " + expected + " inputs " +
+ "for '" + node.getName() + "', got " + inputs.size());
+ }
+ return inputs.stream().map(func).allMatch(Optional::isPresent);
+ }
+
+ boolean allInputTypesPresent(int expected) {
+ return verifyInputs(expected, OnnxOperation::type);
+ }
+
+ boolean allInputFunctionsPresent(int expected) {
+ return verifyInputs(expected, OnnxOperation::function);
+ }
+
+ /**
+ * A method signature input and output has the form name:index.
+ * This returns the name part without the index.
+ */
+ public static String namePartOf(String name) {
+ name = name.startsWith("^") ? name.substring(1) : name;
+ return name.split(":")[0];
+ }
+
+ /**
+ * This returns the output index part. Indexes are used for nodes with
+ * multiple outputs.
+ */
+ public static int indexPartOf(String name) {
+ int i = name.indexOf(":");
+ return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1));
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/package-info.java
index 1530754cc43..5cff8b03d40 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/package-info.java
@@ -3,6 +3,6 @@
* ONNX integration
*/
@ExportPackage
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.onnx;
import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
new file mode 100644
index 00000000000..e3c72830095
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -0,0 +1,411 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OperationMapper;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Const;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Variable;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.yolean.Exceptions;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+import org.tensorflow.framework.GraphDef;
+import org.tensorflow.framework.MetaGraphDef;
+import org.tensorflow.framework.NodeDef;
+import org.tensorflow.framework.SignatureDef;
+import org.tensorflow.framework.TensorInfo;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.logging.Logger;
+import java.util.stream.Collectors;
+
+/**
+ * Converts a saved TensorFlow model into a ranking expression and set of constants.
+ *
+ * @author bratseth
+ * @author lesters
+ */
+public class TensorFlowImporter {
+
+ private static final Logger log = Logger.getLogger(TensorFlowImporter.class.getName());
+
+ /**
+ * Imports a saved TensorFlow model from a directory.
+ * The model should be saved as a .pbtxt or .pb file.
+ * The name of the model is taken as the pb/pbtxt file name (not including the file ending).
+ *
+ * @param modelName the name of the model to import, consisting of characters in [A-Za-z0-9_]
+ * @param modelDir the directory containing the TensorFlow model files to import
+ */
+ public TensorFlowModel importModel(String modelName, String modelDir) {
+ try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
+
+ return importModel(modelName, model);
+ }
+ catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
+ }
+ }
+
+ public TensorFlowModel importModel(String modelName, File modelDir) {
+ return importModel(modelName, modelDir.toString());
+ }
+
+ /** Imports a TensorFlow model */
+ public TensorFlowModel importModel(String modelName, SavedModelBundle model) {
+ try {
+ return importGraph(modelName, MetaGraphDef.parseFrom(model.metaGraphDef()), model);
+ }
+ catch (IOException e) {
+ throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e);
+ }
+ }
+
+ /**
+ * Imports the TensorFlow graph by first importing the tensor types, then
+ * finding a suitable set of dimension names for each
+ * placeholder/constant/variable, then importing the expressions.
+ */
+ private static TensorFlowModel importGraph(String modelName, MetaGraphDef graph, SavedModelBundle bundle) {
+ TensorFlowModel model = new TensorFlowModel(modelName);
+ OperationIndex index = new OperationIndex();
+
+ importSignatures(graph, model);
+ importNodes(graph, model, index);
+ findDimensionNames(model, index);
+ importExpressions(model, index, bundle);
+
+ reportWarnings(model, index);
+ logVariableTypes(index);
+
+ return model;
+ }
+
+ private static void importSignatures(MetaGraphDef graph, TensorFlowModel model) {
+ for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) {
+ String signatureName = signatureEntry.getKey();
+ TensorFlowModel.Signature signature = model.signature(signatureName);
+
+ Map<String, TensorInfo> inputInfoMap = signatureEntry.getValue().getInputsMap();
+ for (Map.Entry<String, TensorInfo> input : inputInfoMap.entrySet()) {
+ String inputName = input.getKey();
+ signature.input(inputName, namePartOf(input.getValue().getName()));
+ }
+
+ Map<String, TensorInfo> outputInfoMap = signatureEntry.getValue().getOutputsMap();
+ for (Map.Entry<String, TensorInfo> output : outputInfoMap.entrySet()) {
+ String outputName = output.getKey();
+ signature.output(outputName, namePartOf(output.getValue().getName()));
+ }
+ }
+ }
+
+ private static boolean isSignatureInput(TensorFlowModel model, TensorFlowOperation operation) {
+ for (TensorFlowModel.Signature signature : model.signatures().values()) {
+ for (String inputName : signature.inputs().values()) {
+ if (inputName.equals(operation.node().getName())) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
+ private static boolean isSignatureOutput(TensorFlowModel model, TensorFlowOperation operation) {
+ for (TensorFlowModel.Signature signature : model.signatures().values()) {
+ for (String outputName : signature.outputs().values()) {
+ if (outputName.equals(operation.node().getName())) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
+ private static void importNodes(MetaGraphDef graph, TensorFlowModel model, OperationIndex index) {
+ for (TensorFlowModel.Signature signature : model.signatures().values()) {
+ for (String outputName : signature.outputs().values()) {
+ importNode(model.name(), outputName, graph.getGraphDef(), index);
+ }
+ }
+ }
+
+ private static TensorFlowOperation importNode(String modelName, String nodeName, GraphDef graph, OperationIndex index) {
+ if (index.alreadyImported(nodeName)) {
+ return index.get(nodeName);
+ }
+ NodeDef node = getTensorFlowNodeFromGraph(namePartOf(nodeName), graph);
+ List<TensorFlowOperation> inputs = importNodeInputs(modelName, node, graph, index);
+ TensorFlowOperation operation = OperationMapper.get(modelName, node, inputs, portPartOf(nodeName));
+ index.put(nodeName, operation);
+
+ List<TensorFlowOperation> controlInputs = importControlInputs(modelName, node, graph, index);
+ if (controlInputs.size() > 0) {
+ operation.setControlInputs(controlInputs);
+ }
+
+ return operation;
+ }
+
+ private static List<TensorFlowOperation> importNodeInputs(String modelName, NodeDef node, GraphDef graph, OperationIndex index) {
+ return node.getInputList().stream()
+ .filter(name -> ! isControlDependency(name))
+ .map(nodeName -> importNode(modelName, nodeName, graph, index))
+ .collect(Collectors.toList());
+ }
+
+ private static List<TensorFlowOperation> importControlInputs(String modelName, NodeDef node, GraphDef graph, OperationIndex index) {
+ return node.getInputList().stream()
+ .filter(nodeName -> isControlDependency(nodeName))
+ .map(nodeName -> importNode(modelName, nodeName, graph, index))
+ .collect(Collectors.toList());
+ }
+
+ private static boolean isControlDependency(String name) {
+ return name.startsWith("^");
+ }
+
+ /** Find dimension names to avoid excessive renaming while evaluating the model. */
+ private static void findDimensionNames(TensorFlowModel model, OperationIndex index) {
+ DimensionRenamer renamer = new DimensionRenamer();
+ for (TensorFlowModel.Signature signature : model.signatures().values()) {
+ for (String output : signature.outputs().values()) {
+ addDimensionNameConstraints(index.get(output), renamer);
+ }
+ }
+ renamer.solve();
+ for (TensorFlowModel.Signature signature : model.signatures().values()) {
+ for (String output : signature.outputs().values()) {
+ renameDimensions(index.get(output), renamer);
+ }
+ }
+ }
+
+ private static void addDimensionNameConstraints(TensorFlowOperation operation, DimensionRenamer renamer) {
+ if (operation.type().isPresent()) {
+ operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer));
+ operation.addDimensionNameConstraints(renamer);
+ }
+ }
+
+ private static void renameDimensions(TensorFlowOperation operation, DimensionRenamer renamer) {
+ if (operation.type().isPresent()) {
+ operation.inputs().forEach(input -> renameDimensions(input, renamer));
+ operation.renameDimensions(renamer);
+ }
+ }
+
+ private static void importExpressions(TensorFlowModel model, OperationIndex index, SavedModelBundle bundle) {
+ for (TensorFlowModel.Signature signature : model.signatures().values()) {
+ for (String outputName : signature.outputs().values()) {
+ try {
+ Optional<TensorFunction> function = importExpression(index.get(outputName), model, bundle);
+ if (!function.isPresent()) {
+ signature.skippedOutput(outputName, "No valid output function could be found.");
+ }
+ }
+ catch (IllegalArgumentException e) {
+ signature.skippedOutput(outputName, Exceptions.toMessageString(e));
+ }
+ }
+ }
+ }
+
+ private static Optional<TensorFunction> importExpression(TensorFlowOperation operation, TensorFlowModel model, SavedModelBundle bundle) {
+ if (!operation.type().isPresent()) {
+ return Optional.empty();
+ }
+ if (operation.isConstant()) {
+ return importConstant(model, operation, bundle);
+ }
+
+ importInputExpressions(operation, model, bundle);
+ importRankingExpression(model, operation);
+ importInputExpression(model, operation);
+ importMacroExpression(model, operation);
+
+ return operation.function();
+ }
+
+ private static void importInputExpressions(TensorFlowOperation operation, TensorFlowModel model,
+ SavedModelBundle bundle) {
+ operation.inputs().forEach(input -> importExpression(input, model, bundle));
+ }
+
+ private static void importMacroExpression(TensorFlowModel model, TensorFlowOperation operation) {
+ if (operation.macro().isPresent()) {
+ TensorFunction function = operation.macro().get();
+ try {
+ model.macro(operation.macroName(), new RankingExpression(operation.macroName(), function.toString()));
+ }
+ catch (ParseException e) {
+ throw new RuntimeException("Tensorflow function " + function +
+ " cannot be parsed as a ranking expression", e);
+ }
+ }
+ }
+
+ private static Optional<TensorFunction> importConstant(TensorFlowModel model, TensorFlowOperation operation,
+ SavedModelBundle bundle) {
+ String name = operation.vespaName();
+ if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) {
+ return operation.function();
+ }
+
+ Tensor tensor;
+ if (operation.getConstantValue().isPresent()) {
+ Value value = operation.getConstantValue().get();
+ if ( ! (value instanceof TensorValue)) {
+ return operation.function(); // scalar values are inserted directly into the expression
+ }
+ tensor = value.asTensor();
+ } else {
+ // Here we use the type from the operation, which will have correct dimension names after name resolving
+ tensor = TensorConverter.toVespaTensor(readVariable(operation.node().getName(), bundle),
+ operation.type().get());
+ operation.setConstantValue(new TensorValue(tensor));
+ }
+
+ if (tensor.type().rank() == 0) {
+ model.smallConstant(name, tensor);
+ } else {
+ model.largeConstant(name, tensor);
+ }
+ return operation.function();
+ }
+
+ static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) {
+ Session.Runner fetched = bundle.session().runner().fetch(name);
+ List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
+ if (importedTensors.size() != 1)
+ throw new IllegalStateException("Expected 1 tensor from fetching " + name +
+ ", but got " + importedTensors.size());
+ return importedTensors.get(0);
+ }
+
+ private static void importRankingExpression(TensorFlowModel model, TensorFlowOperation operation) {
+ if (operation.function().isPresent()) {
+ String name = operation.node().getName();
+ if (!model.expressions().containsKey(operation.node().getName())) {
+ TensorFunction function = operation.function().get();
+
+ // Make sure output adheres to standard naming convention
+ if (isSignatureOutput(model, operation)) {
+ OrderedTensorType operationType = operation.type().get();
+ OrderedTensorType standardNamingType = OrderedTensorType.fromTensorFlowType(operation.node());
+ if ( ! operationType.equals(standardNamingType)) {
+ List<String> renameFrom = operationType.dimensionNames();
+ List<String> renameTo = standardNamingType.dimensionNames();
+ function = new Rename(function, renameFrom, renameTo);
+ }
+ }
+
+ try {
+ // We add all intermediate nodes imported as separate expressions. Only
+ // those referenced in a signature output will be used. We parse the
+ // TensorFunction here to convert it to a RankingExpression tree.
+ model.expression(name, new RankingExpression(name, function.toString()));
+ }
+ catch (ParseException e) {
+ throw new RuntimeException("Tensorflow function " + function +
+ " cannot be parsed as a ranking expression", e);
+ }
+ }
+ }
+ }
+
+ private static void importInputExpression(TensorFlowModel model, TensorFlowOperation operation) {
+ if (operation.isInput() && isSignatureInput(model, operation)) {
+ // All inputs must have dimensions with standard naming convention: d0, d1, ...
+ OrderedTensorType standardNamingConvention = OrderedTensorType.fromTensorFlowType(operation.node());
+ model.argument(operation.node().getName(), standardNamingConvention.type());
+ model.requiredMacro(operation.vespaName(), standardNamingConvention.type());
+ }
+ }
+
+ private static void reportWarnings(TensorFlowModel model, OperationIndex index) {
+ for (TensorFlowModel.Signature signature : model.signatures().values()) {
+ for (String output : signature.outputs().values()) {
+ reportWarnings(index.get(output), signature);
+ }
+ }
+ }
+
+ /**
+ * Log all TensorFlow Variables (i.e. file constants) imported as part of this with their ordered type.
+ * This allows users to learn the exact types (including dimension order after renaming) of the Variables
+ * such that these can be converted and fed to a parent document independently of the rest of the model
+ * for fast model weight updates.
+ */
+ private static void logVariableTypes(OperationIndex index) {
+ for (TensorFlowOperation operation : index.operations()) {
+ if ( ! (operation instanceof Variable)) continue;
+ if ( ! operation.type().isPresent()) continue; // will not happen
+
+ log.info("Importing TensorFlow variable " + operation.node().getName() + " as " + operation.vespaName() +
+ " of type " + operation.type().get());
+ }
+ }
+
+ private static void reportWarnings(TensorFlowOperation operation, TensorFlowModel.Signature signature) {
+ for (String warning : operation.warnings()) {
+ signature.importWarning(warning);
+ }
+ for (TensorFlowOperation input : operation.inputs()) {
+ reportWarnings(input, signature);
+ }
+ }
+
+ private static NodeDef getTensorFlowNodeFromGraph(String name, GraphDef graph) {
+ for (NodeDef node : graph.getNodeList()) {
+ if (node.getName().equals(name)) {
+ return node;
+ }
+ }
+ throw new IllegalArgumentException("Could not find node '" + name + "'");
+ }
+
+ /**
+ * A method signature input and output has the form name:index.
+ * This returns the name part without the index.
+ */
+ private static String namePartOf(String name) {
+ name = name.startsWith("^") ? name.substring(1) : name;
+ return name.split(":")[0];
+ }
+
+ /**
+ * This returns the output port part. Indexes are used for nodes with
+ * multiple outputs.
+ */
+ private static int portPartOf(String name) {
+ int i = name.indexOf(":");
+ return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1));
+ }
+
+ private static class OperationIndex {
+
+ private final Map<String, TensorFlowOperation> index = new HashMap<>();
+ public TensorFlowOperation put(String key, TensorFlowOperation operation) { return index.put(key, operation); }
+ public TensorFlowOperation get(String key) { return index.get(key); }
+ public boolean alreadyImported(String key) { return index.containsKey(key); }
+ public Collection<TensorFlowOperation> operations() { return index.values(); }
+
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
index 4b49f17f74e..721214f9e94 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
@@ -1,4 +1,5 @@
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
@@ -12,61 +13,76 @@ import java.util.Map;
import java.util.regex.Pattern;
/**
- * The result of importing a model (TensorFlow or ONNX) into Vespa.
+ * The result of importing a TensorFlow model into Vespa.
+ * - A set of signatures which are named collections of inputs and outputs.
+ * - A set of named constant tensors represented by Variable nodes in TensorFlow.
+ * - A list of warning messages.
*
* @author bratseth
*/
-public class ImportedModel {
-
- private static final String defaultSignatureName = "default";
+// This object can be built incrementally within this package, but is immutable when observed from outside the package
+public class TensorFlowModel {
private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*");
- private final String name;
- private final Map<String, Signature> signatures = new HashMap<>();
- private final Map<String, TensorType> arguments = new HashMap<>();
- private final Map<String, Tensor> smallConstants = new HashMap<>();
- private final Map<String, Tensor> largeConstants = new HashMap<>();
- private final Map<String, RankingExpression> expressions = new HashMap<>();
- private final Map<String, RankingExpression> macros = new HashMap<>();
- private final Map<String, TensorType> requiredMacros = new HashMap<>();
+ private final String name;
/**
- * Creates a new imported model.
+ * Creates a TensorFlow model
*
* @param name the name of this mode, containing only characters in [A-Za-z0-9_]
*/
- public ImportedModel(String name) {
+ public TensorFlowModel(String name) {
if ( ! nameRegexp.matcher(name).matches())
- throw new IllegalArgumentException("An imported model name can only contain [A-Za-z0-9_], but is '" +
- name + "'");
+ throw new IllegalArgumentException("A TensorFlow model name can only contain [A-Za-z0-9_], but is '" +
+ name + "'");
this.name = name;
}
/** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */
public String name() { return name; }
+ private final Map<String, Signature> signatures = new HashMap<>();
+ private final Map<String, TensorType> arguments = new HashMap<>();
+ private final Map<String, Tensor> smallConstants = new HashMap<>();
+ private final Map<String, Tensor> largeConstants = new HashMap<>();
+ private final Map<String, RankingExpression> expressions = new HashMap<>();
+ private final Map<String, RankingExpression> macros = new HashMap<>();
+ private final Map<String, TensorType> requiredMacros = new HashMap<>();
+
+ void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
+ void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
+ void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
+ void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
+ void macro(String name, RankingExpression expression) { macros.put(name, expression); }
+ void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); }
+
+ /** Returns the given signature. If it does not already exist it is added to this. */
+ Signature signature(String name) {
+ return signatures.computeIfAbsent(name, Signature::new);
+ }
+
/** Returns an immutable map of the arguments ("Placeholders") of this */
public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
/**
* Returns an immutable map of the small constants of this.
* These should have sizes up to a few kb at most, and correspond to constant
- * values given in the TensorFlow or ONNX source.
+ * values given in the TensorFlow source.
*/
public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); }
/**
* Returns an immutable map of the large constants of this.
- * These can have sizes in gigabytes and must be distributed to nodes separately from configuration.
- * For TensorFlow this corresponds to Variable files stored separately.
+ * These can have sizes in gigabytes and must be distributed to nodes separately from configuration,
+ * and correspond to Variable files stored separately in TensorFlow.
*/
public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); }
/**
- * Returns an immutable map of the expressions of this - corresponding to graph nodes
- * which are not Inputs/Placeholders or Variables (which instead become respectively arguments and constants).
- * Note that only nodes recursively referenced by a placeholder/input are added.
+ * Returns an immutable map of the expressions of this - corresponding to TensorFlow nodes
+ * which are not Placeholders or Variables (which instead become respectively arguments and constants).
+ * Note that only nodes recursively referenced by a placeholder are added.
*/
public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
@@ -79,26 +95,9 @@ public class ImportedModel {
/** Returns an immutable map of the signatures of this */
public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); }
- /** Returns the given signature. If it does not already exist it is added to this. */
- Signature signature(String name) {
- return signatures.computeIfAbsent(name, Signature::new);
- }
-
- /** Convenience method for returning a default signature */
- Signature defaultSignature() { return signature(defaultSignatureName); }
-
- void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
- void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
- void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
- void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
- void macro(String name, RankingExpression expression) { macros.put(name, expression); }
- void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); }
-
/**
- * A signature is a set of named inputs and outputs, where the inputs maps to argument
- * ("placeholder") names+types, and outputs maps to expressions nodes.
- * Note that TensorFlow supports multiple signatures in their format, but ONNX has no explicit
- * concept of signatures. For now, we handle ONNX models as having a single signature.
+ * A signature is a set of named inputs and outputs, where the inputs map to argument ("placeholder") names+types,
+ * and the outputs map to expression nodes.
*/
public class Signature {
@@ -108,14 +107,19 @@ public class ImportedModel {
private final Map<String, String> skippedOutputs = new HashMap<>();
private final List<String> importWarnings = new ArrayList<>();
- public Signature(String name) {
+ Signature(String name) {
this.name = name;
}
+ void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); }
+ void output(String name, String expressionName) { outputs.put(name, expressionName); }
+ void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); }
+ void importWarning(String warning) { importWarnings.add(warning); }
+
public String name() { return name; }
/** Returns the result this is part of */
- public ImportedModel owner() { return ImportedModel.this; }
+ TensorFlowModel owner() { return TensorFlowModel.this; }
/**
* Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name
@@ -123,7 +127,7 @@ public class ImportedModel {
*/
public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); }
- /** Returns the type of the argument this input references */
+ /** Returns owner().arguments().get(inputs.get(inputName)), i.e. the type of the argument this input references */
public TensorType inputArgument(String inputName) { return owner().arguments().get(inputs.get(inputName)); }
/** Returns an immutable list of the expression names of this */
@@ -140,17 +144,12 @@ public class ImportedModel {
*/
public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); }
- /** Returns the expression this output references */
+ /** Returns owner().expressions().get(outputs.get(outputName)), i.e. the expression this output references */
public RankingExpression outputExpression(String outputName) { return owner().expressions().get(outputs.get(outputName)); }
@Override
public String toString() { return "signature '" + name + "'"; }
- void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); }
- void output(String name, String expressionName) { outputs.put(name, expressionName); }
- void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); }
- void importWarning(String warning) { importWarnings.add(warning); }
-
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverter.java
index e1294ec3e01..c5ac7ace0fc 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverter.java
@@ -1,8 +1,7 @@
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.GraphImporter;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
import com.yahoo.tensor.serialization.JsonFormat;
import com.yahoo.yolean.Exceptions;
import org.tensorflow.SavedModelBundle;
@@ -25,7 +24,7 @@ public class VariableConverter {
*/
public static byte[] importVariable(String modelDir, String tensorFlowVariableName, String orderedTypeSpec) {
try (SavedModelBundle bundle = SavedModelBundle.load(modelDir, "serve")) {
- return JsonFormat.encode(TensorConverter.toVespaTensor(GraphImporter.readVariable(tensorFlowVariableName,
+ return JsonFormat.encode(TensorConverter.toVespaTensor(TensorFlowImporter.readVariable(tensorFlowVariableName,
bundle),
OrderedTensorType.fromSpec(orderedTypeSpec)));
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java
new file mode 100644
index 00000000000..c1665d066a4
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java
@@ -0,0 +1,210 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation;
+
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+
+/**
+ * A constraint satisfier to find suitable dimension names to reduce the
+ * amount of necessary renaming during evaluation of an imported model.
+ *
+ * @author lesters
+ */
+public class DimensionRenamer {
+
+ private final String dimensionPrefix;
+ private final Map<String, List<Integer>> variables = new HashMap<>();
+ private final Map<Arc, Constraint> constraints = new HashMap<>();
+ private final Map<String, Integer> renames = new HashMap<>();
+
+ private int iterations = 0;
+
+ public DimensionRenamer() {
+ this("d");
+ }
+
+ public DimensionRenamer(String dimensionPrefix) {
+ this.dimensionPrefix = dimensionPrefix;
+ }
+
+ /**
+ * Add a dimension name variable.
+ */
+ public void addDimension(String name) {
+ variables.computeIfAbsent(name, d -> new ArrayList<>());
+ }
+
+ /**
+ * Add a constraint between dimension names.
+ */
+ public void addConstraint(String from, String to, Constraint pred, TensorFlowOperation operation) {
+ Arc arc = new Arc(from, to, operation);
+ Arc opposite = arc.opposite();
+ constraints.put(arc, pred);
+ constraints.put(opposite, (x,y) -> pred.test(y, x)); // make constraint graph symmetric
+ }
+
+ /**
+ * Retrieve resulting name of dimension after solving for constraints.
+ */
+ public Optional<String> dimensionNameOf(String name) {
+ if (!renames.containsKey(name)) {
+ return Optional.empty();
+ }
+ return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name)));
+ }
+
+ /**
+ * Perform iterative arc consistency until we have found a solution. After
+ * an initial iteration, the variables (dimensions) will have multiple
+ * valid values. Find a single valid assignment by iteratively locking one
+ * dimension after another, and running the arc consistency algorithm
+ * multiple times.
+ *
+ * This requires having constraints that result in an absolute ordering:
+ * equals, lesserThan and greaterThan do that, but adding notEquals does
+ * not typically result in a guaranteed ordering. If that is needed, the
+ * algorithm below needs to be adapted with a backtracking (tree) search
+ * to find solutions.
+ */
+ public void solve(int maxIterations) {
+ initialize();
+
+ // Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts
+
+ for (String dimension : variables.keySet()) {
+ List<Integer> values = variables.get(dimension);
+ if (values.size() > 1) {
+ if (!ac3()) {
+ throw new IllegalArgumentException("Dimension renamer unable to find a solution.");
+ }
+ values.sort(Integer::compare);
+ variables.put(dimension, Collections.singletonList(values.get(0)));
+ }
+ renames.put(dimension, variables.get(dimension).get(0));
+ if (iterations > maxIterations) {
+ throw new IllegalArgumentException("Dimension renamer unable to find a solution within " +
+ maxIterations + " iterations");
+ }
+ }
+
+ // Todo: handle failure more gracefully:
+ // If a solution can't be found, look at the operation node in the arc
+ // with the most remaining constraints, and inject a rename operation.
+ // Then run this algorithm again.
+ }
+
+ public void solve() {
+ solve(100000);
+ }
+
+ private void initialize() {
+ for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) {
+ List<Integer> values = variable.getValue();
+ for (int i = 0; i < variables.size(); ++i) {
+ values.add(i); // invariant: values are in increasing order
+ }
+ }
+ }
+
+ private boolean ac3() {
+ Deque<Arc> workList = new ArrayDeque<>(constraints.keySet());
+ while (!workList.isEmpty()) {
+ Arc arc = workList.pop();
+ iterations += 1;
+ if (revise(arc)) {
+ if (variables.get(arc.from).size() == 0) {
+ return false; // no solution found
+ }
+ for (Arc constraint : constraints.keySet()) {
+ if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) {
+ workList.add(constraint);
+ }
+ }
+ }
+ }
+ return true;
+ }
+
+ private boolean revise(Arc arc) {
+ boolean revised = false;
+ for(Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) {
+ Integer from = fromIterator.next();
+ boolean satisfied = false;
+ for (Iterator<Integer> toIterator = variables.get(arc.to).iterator(); toIterator.hasNext(); ) {
+ Integer to = toIterator.next();
+ if (constraints.get(arc).test(from, to)) {
+ satisfied = true;
+ }
+ }
+ if (!satisfied) {
+ fromIterator.remove();
+ revised = true;
+ }
+ }
+ return revised;
+ }
+
+ public interface Constraint {
+ boolean test(Integer x, Integer y);
+ }
+
+ public static boolean equals(Integer x, Integer y) {
+ return Objects.equals(x, y);
+ }
+
+ public static boolean lesserThan(Integer x, Integer y) {
+ return x < y;
+ }
+
+ public static boolean greaterThan(Integer x, Integer y) {
+ return x > y;
+ }
+
+ private static class Arc {
+
+ private final String from;
+ private final String to;
+ private final TensorFlowOperation operation;
+
+ Arc(String from, String to, TensorFlowOperation operation) {
+ this.from = from;
+ this.to = to;
+ this.operation = operation;
+ }
+
+ Arc opposite() {
+ return new Arc(to, from, operation);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(from, to);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj == null || !(obj instanceof Arc)) {
+ return false;
+ }
+ Arc other = (Arc) obj;
+ return Objects.equals(from, other.from) && Objects.equals(to, other.to);
+ }
+
+ @Override
+ public String toString() {
+ return String.format("%s -> %s", from, to);
+ }
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
new file mode 100644
index 00000000000..b665413a6b2
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
@@ -0,0 +1,97 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.ConcatV2;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Const;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.ExpandDims;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Identity;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Join;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Map;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Matmul;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Mean;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Merge;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.NoOp;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Placeholder;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.PlaceholderWithDefault;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Reshape;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Select;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Shape;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Squeeze;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Switch;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Variable;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+
+/**
+ * Maps from TensorFlow operations to Vespa operations.
+ *
+ * @author bratseth
+ * @author lesters
+ */
+public class OperationMapper {
+
+    /**
+     * Creates the operation corresponding to the op type of the given node.
+     * The op type is matched case insensitively. A node with an op type not
+     * listed here becomes a NoOp carrying a warning naming the op type.
+     *
+     * @param modelName passed on to the created operation
+     * @param node the TensorFlow node to create an operation for
+     * @param inputs passed on to the created operation
+     * @param port passed on to the created operation
+     * @return the operation created for this node, never null
+     */
+    public static TensorFlowOperation get(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
+        // Lower-case with the root locale: with the default locale the result depends on the
+        // environment (e.g. 'I' becomes a dotless i under a Turkish locale) and op types such
+        // as "Identity" would then match none of the labels below.
+        switch (node.getOp().toLowerCase(java.util.Locale.ROOT)) {
+            // array ops
+            case "concatv2":               return new ConcatV2(modelName, node, inputs, port);
+            case "const":                  return new Const(modelName, node, inputs, port);
+            case "expanddims":             return new ExpandDims(modelName, node, inputs, port);
+            case "identity":               return new Identity(modelName, node, inputs, port);
+            case "placeholder":            return new Placeholder(modelName, node, inputs, port);
+            case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, node, inputs, port);
+            case "reshape":                return new Reshape(modelName, node, inputs, port);
+            case "shape":                  return new Shape(modelName, node, inputs, port);
+            case "squeeze":                return new Squeeze(modelName, node, inputs, port);
+
+            // control flow
+            case "merge":                  return new Merge(modelName, node, inputs, port);
+            case "switch":                 return new Switch(modelName, node, inputs, port);
+
+            // math ops
+            case "add":                    return new Join(modelName, node, inputs, port, ScalarFunctions.add());
+            case "add_n":                  return new Join(modelName, node, inputs, port, ScalarFunctions.add());
+            case "acos":                   return new Map(modelName, node, inputs, port, ScalarFunctions.acos());
+            case "div":                    return new Join(modelName, node, inputs, port, ScalarFunctions.divide());
+            case "realdiv":                return new Join(modelName, node, inputs, port, ScalarFunctions.divide());
+            case "floor":                  return new Map(modelName, node, inputs, port, ScalarFunctions.floor());
+            case "matmul":                 return new Matmul(modelName, node, inputs, port);
+            case "maximum":                return new Join(modelName, node, inputs, port, ScalarFunctions.max());
+            case "mean":                   return new Mean(modelName, node, inputs, port);
+            case "reducemean":             return new Mean(modelName, node, inputs, port);
+            case "mul":                    return new Join(modelName, node, inputs, port, ScalarFunctions.multiply());
+            case "multiply":               return new Join(modelName, node, inputs, port, ScalarFunctions.multiply());
+            case "rsqrt":                  return new Map(modelName, node, inputs, port, ScalarFunctions.rsqrt());
+            case "select":                 return new Select(modelName, node, inputs, port);
+            case "where3":                 return new Select(modelName, node, inputs, port);
+            case "sigmoid":                return new Map(modelName, node, inputs, port, ScalarFunctions.sigmoid());
+            case "squareddifference":      return new Join(modelName, node, inputs, port, ScalarFunctions.squareddifference());
+            case "sub":                    return new Join(modelName, node, inputs, port, ScalarFunctions.subtract());
+            case "subtract":               return new Join(modelName, node, inputs, port, ScalarFunctions.subtract());
+
+            // nn ops
+            case "biasadd":                return new Join(modelName, node, inputs, port, ScalarFunctions.add());
+            case "elu":                    return new Map(modelName, node, inputs, port, ScalarFunctions.elu());
+            case "relu":                   return new Map(modelName, node, inputs, port, ScalarFunctions.relu());
+            case "selu":                   return new Map(modelName, node, inputs, port, ScalarFunctions.selu());
+
+            // state ops
+            case "variable":               return new Variable(modelName, node, inputs, port);
+            case "variablev2":             return new Variable(modelName, node, inputs, port);
+
+            // evaluation no-ops
+            case "stopgradient":           return new Identity(modelName, node, inputs, port);
+            case "noop":                   return new NoOp(modelName, node, inputs, port);
+        }
+
+        // Not a supported op type: import as a no-op and record why
+        TensorFlowOperation op = new NoOp(modelName, node, inputs, port);
+        op.warning("Operation '" + node.getOp() + "' is currently not implemented");
+        return op;
+    }
+
+}
+
+
+
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
new file mode 100644
index 00000000000..03a65333192
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
@@ -0,0 +1,255 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
+
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.TensorTypeParser;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.NodeDef;
+import org.tensorflow.framework.TensorShapeProto;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+/**
+ * A Vespa tensor type is ordered by the lexicographical ordering of dimension
+ * names. TensorFlow tensors have an explicit ordering of their dimensions.
+ * During import, we need to track the Vespa dimension that matches the
+ * corresponding TensorFlow dimension as the ordering can change after
+ * dimension renaming. That is the purpose of this class.
+ *
+ * @author lesters
+ */
+public class OrderedTensorType {
+
+ private final TensorType type;
+ private final List<TensorType.Dimension> dimensions;
+
+ private final long[] innerSizesTensorFlow;
+ private final long[] innerSizesVespa;
+ private final int[] dimensionMap;
+
+ private OrderedTensorType(List<TensorType.Dimension> dimensions) {
+ this.dimensions = Collections.unmodifiableList(dimensions);
+ this.type = new TensorType.Builder(dimensions).build();
+ this.innerSizesTensorFlow = new long[dimensions.size()];
+ this.innerSizesVespa = new long[dimensions.size()];
+ this.dimensionMap = createDimensionMap();
+ }
+
+ public TensorType type() {
+ return this.type;
+ }
+
+ public int rank() { return dimensions.size(); }
+
+ public List<TensorType.Dimension> dimensions() {
+ return dimensions;
+ }
+
+ public List<String> dimensionNames() {
+ return dimensions.stream().map(TensorType.Dimension::name).collect(Collectors.toList());
+ }
+
+ private int[] createDimensionMap() {
+ int numDimensions = dimensions.size();
+ if (numDimensions == 0) {
+ return null;
+ }
+ innerSizesTensorFlow[numDimensions - 1] = 1;
+ innerSizesVespa[numDimensions - 1] = 1;
+ for (int i = numDimensions - 1; --i >= 0; ) {
+ innerSizesTensorFlow[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesTensorFlow[i+1];
+ innerSizesVespa[i] = type.dimensions().get(i+1).size().orElse(-1L) * innerSizesVespa[i+1];
+ }
+ int[] mapping = new int[numDimensions];
+ for (int i = 0; i < numDimensions; ++i) {
+ TensorType.Dimension dim1 = dimensions().get(i);
+ for (int j = 0; j < numDimensions; ++j) {
+ TensorType.Dimension dim2 = type.dimensions().get(j);
+ if (dim1.equals(dim2)) {
+ mapping[i] = j;
+ break;
+ }
+ }
+ }
+ return mapping;
+ }
+
+ /**
+ * When dimension ordering between Vespa and TensorFlow differs, i.e.
+ * after dimension renaming, use the dimension map to read in values
+ * so that they are correctly laid out in memory for Vespa.
+ * Used when importing tensors from TensorFlow.
+ */
+ public int toDirectIndex(int index) {
+ if (dimensions.size() == 0) {
+ return 0;
+ }
+ if (dimensionMap == null) {
+ throw new IllegalArgumentException("Dimension map is not available");
+ }
+ int directIndex = 0;
+ long rest = index;
+ for (int i = 0; i < dimensions.size(); ++i) {
+ long address = rest / innerSizesTensorFlow[i];
+ directIndex += innerSizesVespa[dimensionMap[i]] * address;
+ rest %= innerSizesTensorFlow[i];
+ }
+ return directIndex;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj == null || !(obj instanceof OrderedTensorType)) {
+ return false;
+ }
+ OrderedTensorType other = (OrderedTensorType) obj;
+ if (dimensions.size() != dimensions.size()) {
+ return false;
+ }
+ List<TensorType.Dimension> thisDimensions = this.dimensions();
+ List<TensorType.Dimension> otherDimensions = other.dimensions();
+ for (int i = 0; i < thisDimensions.size(); ++i) {
+ if (!thisDimensions.get(i).equals(otherDimensions.get(i))) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ public void verifyType(NodeDef node) {
+ TensorShapeProto shape = tensorFlowShape(node);
+ if (shape != null) {
+ if (shape.getDimCount() != type.rank()) {
+ throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " +
+ "does not match Vespa shape");
+ }
+ for (int tensorFlowIndex = 0; tensorFlowIndex < dimensions.size(); ++tensorFlowIndex) {
+ int vespaIndex = dimensionMap[tensorFlowIndex];
+ TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex);
+ TensorType.Dimension vespaDimension = type().dimensions().get(vespaIndex);
+ if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) {
+ throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " +
+ "does not match Vespa dimensions");
+ }
+ }
+ }
+ }
+
+ private static TensorShapeProto tensorFlowShape(NodeDef node) {
+ AttrValue attrValueList = node.getAttrMap().get("_output_shapes");
+ if (attrValueList == null) {
+ throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
+ "does not exist");
+ }
+ if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) {
+ throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
+ "is not of expected type");
+ }
+ List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList();
+ return shapeList.get(0); // support multiple outputs?
+ }
+
+ public OrderedTensorType rename(DimensionRenamer renamer) {
+ List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size());
+ for (TensorType.Dimension dimension : dimensions) {
+ String oldName = dimension.name();
+ Optional<String> newName = renamer.dimensionNameOf(oldName);
+ if (!newName.isPresent())
+ return this; // presumably, already renamed
+ TensorType.Dimension.Type dimensionType = dimension.type();
+ if (dimensionType == TensorType.Dimension.Type.indexedBound) {
+ renamedDimensions.add(TensorType.Dimension.indexed(newName.get(), dimension.size().get()));
+ } else if (dimensionType == TensorType.Dimension.Type.indexedUnbound) {
+ renamedDimensions.add(TensorType.Dimension.indexed(newName.get()));
+ } else if (dimensionType == TensorType.Dimension.Type.mapped) {
+ renamedDimensions.add(TensorType.Dimension.mapped(newName.get()));
+ }
+ }
+ return new OrderedTensorType(renamedDimensions);
+ }
+
+ /**
+ * Returns a string representation of this: A standard tensor type string where dimensions
+ * are listed in the order of this rather than in the natural order of their names.
+ */
+ @Override
+ public String toString() {
+ return "tensor(" + dimensions.stream().map(TensorType.Dimension::toString).collect(Collectors.joining(",")) + ")";
+ }
+
+ /**
+ * Creates an instance from the string representation of this: A standard tensor type string
+ * where dimensions are listed in the order of this rather than the natural order of their names.
+ */
+ public static OrderedTensorType fromSpec(String typeSpec) {
+ return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec(typeSpec));
+ }
+
+ public static OrderedTensorType fromTensorFlowType(NodeDef node) {
+ return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ...
+ }
+
+ public static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) {
+ Builder builder = new Builder(node);
+ TensorShapeProto shape = tensorFlowShape(node);
+ for (int i = 0; i < shape.getDimCount(); ++ i) {
+ String dimensionName = dimensionPrefix + i;
+ TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i);
+ if (tensorFlowDimension.getSize() >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, tensorFlowDimension.getSize()));
+ } else {
+ builder.add(TensorType.Dimension.indexed(dimensionName));
+ }
+ }
+ return builder.build();
+ }
+
+ public static class Builder {
+
+ private final TensorShapeProto shape;
+ private final List<TensorType.Dimension> dimensions;
+
+ public Builder(NodeDef node) {
+ this.shape = tensorFlowShape(node);
+ this.dimensions = new ArrayList<>(shape.getDimCount());
+ }
+
+ public Builder add(TensorType.Dimension vespaDimension) {
+ int index = dimensions.size();
+ TensorShapeProto.Dim tensorFlowDimension = shape.getDim(index);
+ long size = tensorFlowDimension.getSize();
+ if (size >= 0) {
+ if (vespaDimension.type() != TensorType.Dimension.Type.indexedBound) {
+ throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " +
+ "dimension types");
+ }
+ if (!vespaDimension.size().isPresent()) {
+ throw new IllegalArgumentException("Tensor dimension is indexed bound but does " +
+ "not have a size");
+ }
+ if (vespaDimension.size().get() != size) {
+ throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " +
+ "dimension sizes. TensorFlow: " + size + " Vespa: " +
+ vespaDimension.size().get());
+ }
+ } else {
+ if (vespaDimension.type() != TensorType.Dimension.Type.indexedUnbound) {
+ throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " +
+ "dimension types");
+ }
+ }
+ this.dimensions.add(vespaDimension);
+ return this;
+ }
+
+ public OrderedTensorType build() {
+ return new OrderedTensorType(dimensions);
+ }
+
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java
index d2d0acfc964..3f55e622fdf 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java
@@ -1,7 +1,6 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ConcatV2.java
index 1b8c62fe0e9..4f5d61d75f9 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ConcatV2.java
@@ -1,37 +1,38 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
import java.util.List;
import java.util.Optional;
-public class ConcatV2 extends IntermediateOperation {
+public class ConcatV2 extends TensorFlowOperation {
private String concatDimensionName;
- public ConcatV2(String modelName, String nodeName, List<IntermediateOperation> inputs) {
- super(modelName, nodeName, inputs);
+ public ConcatV2(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(modelName, node, inputs, port);
}
@Override
protected OrderedTensorType lazyGetType() {
- if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) {
+ if (!inputs.stream().map(TensorFlowOperation::type).allMatch(Optional::isPresent)) {
return null;
}
- IntermediateOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input
+ TensorFlowOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input
if (!concatDimOp.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
+ throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " +
"concat dimension must be a constant.");
}
Tensor concatDimTensor = concatDimOp.getConstantValue().get().asTensor();
if (concatDimTensor.type().rank() != 0) {
- throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
+ throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " +
"concat dimension must be a scalar.");
}
@@ -43,7 +44,7 @@ public class ConcatV2 extends IntermediateOperation {
for (int i = 1; i < inputs.size() - 1; ++i) {
OrderedTensorType bType = inputs.get(i).type().get();
if (bType.rank() != aType.rank()) {
- throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
+ throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " +
"inputs must have save rank.");
}
for (int j = 0; j < aType.rank(); ++j) {
@@ -52,13 +53,13 @@ public class ConcatV2 extends IntermediateOperation {
if (j == concatDim) {
concatDimSize += dimSizeB;
} else if (dimSizeA != dimSizeB) {
- throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
+ throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " +
"input dimension " + j + " differs in input tensors.");
}
}
}
- OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node);
int dimensionIndex = 0;
for (TensorType.Dimension dimension : aType.dimensions()) {
if (dimensionIndex == concatDim) {
@@ -74,7 +75,7 @@ public class ConcatV2 extends IntermediateOperation {
@Override
protected TensorFunction lazyGetFunction() {
- if (!inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent)) {
+ if (!inputs.stream().map(TensorFlowOperation::function).allMatch(Optional::isPresent)) {
return null;
}
TensorFunction result = inputs.get(0).function().get();
@@ -87,7 +88,7 @@ public class ConcatV2 extends IntermediateOperation {
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
- if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) {
+ if (!inputs.stream().map(TensorFlowOperation::type).allMatch(Optional::isPresent)) {
return;
}
OrderedTensorType a = inputs.get(0).type().get();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java
index 3c0f8569c47..718e2a4b3c2 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java
@@ -1,38 +1,36 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.NodeDef;
import java.util.List;
import java.util.Optional;
-public class Const extends IntermediateOperation {
+public class Const extends TensorFlowOperation {
- private final AttributeMap attributeMap;
-
- public Const(String modelName,
- String nodeName,
- List<IntermediateOperation> inputs,
- AttributeMap attributeMap,
- OrderedTensorType type) {
- super(modelName, nodeName, inputs);
- this.attributeMap = attributeMap;
- this.type = type.rename(vespaName() + "_");
+ public Const(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(modelName, node, inputs, port);
setConstantValue(value());
}
@Override
protected OrderedTensorType lazyGetType() {
- return type;
+ return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_");
}
@Override
@@ -57,7 +55,7 @@ public class Const extends IntermediateOperation {
/** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
@Override
public String vespaName() {
- return modelName + "_" + super.vespaName();
+ return modelName() + "_" + super.vespaName();
}
@Override
@@ -79,11 +77,24 @@ public class Const extends IntermediateOperation {
}
private Value value() {
- Optional<Value> value = attributeMap.get("value", type);
- if ( ! value.isPresent()) {
- throw new IllegalArgumentException("Node '" + name + "' of type " +
- "const has missing or non-recognized 'value' attribute");
+ if ( ! node.getAttrMap().containsKey("value")) {
+ throw new IllegalArgumentException("Node '" + node.getName() + "' of type " +
+ "const has missing 'value' attribute");
+ }
+ AttrValue attrValue = node.getAttrMap().get("value");
+ if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) {
+ return new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), type().get().type()));
+ }
+ if (attrValue.getValueCase() == AttrValue.ValueCase.B) {
+ return new BooleanValue(attrValue.getB());
+ }
+ if (attrValue.getValueCase() == AttrValue.ValueCase.I) {
+ return new DoubleValue(attrValue.getI());
+ }
+ if (attrValue.getValueCase() == AttrValue.ValueCase.F) {
+ return new DoubleValue(attrValue.getF());
}
- return value.get();
+ throw new IllegalArgumentException("Requesting value of constant in " +
+ node.getName() + " but type is not recognized.");
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java
index 742ed8b89ab..2d0f4c7042b 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java
@@ -1,9 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
@@ -12,17 +12,18 @@ import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
-public class ExpandDims extends IntermediateOperation {
+public class ExpandDims extends TensorFlowOperation {
private List<String> expandDimensions;
- public ExpandDims(String modelName, String nodeName, List<IntermediateOperation> inputs) {
- super(modelName, nodeName, inputs);
+ public ExpandDims(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(modelName, node, inputs, port);
}
@Override
@@ -31,14 +32,14 @@ public class ExpandDims extends IntermediateOperation {
return null;
}
- IntermediateOperation axisOperation = inputs().get(1);
+ TensorFlowOperation axisOperation = inputs().get(1);
if (!axisOperation.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("ExpandDims in " + name + ": " +
+ throw new IllegalArgumentException("ExpandDims in " + node.getName() + ": " +
"axis must be a constant.");
}
Tensor axis = axisOperation.getConstantValue().get().asTensor();
if (axis.type().rank() != 0) {
- throw new IllegalArgumentException("ExpandDims in " + name + ": " +
+ throw new IllegalArgumentException("ExpandDims in " + node.getName() + ": " +
"axis argument must be a scalar.");
}
@@ -48,7 +49,7 @@ public class ExpandDims extends IntermediateOperation {
dimensionToInsert = inputType.dimensions().size() - dimensionToInsert;
}
- OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node);
expandDimensions = new ArrayList<>();
int dimensionIndex = 0;
for (TensorType.Dimension dimension : inputType.dimensions()) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java
index d29bd4b7a9e..1408e7e04f0 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java
@@ -1,21 +1,22 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
import java.util.List;
-public class Identity extends IntermediateOperation {
+public class Identity extends TensorFlowOperation {
- public Identity(String modelName, String nodeName, List<IntermediateOperation> inputs) {
- super(modelName, nodeName, inputs);
+ public Identity(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(modelName, node, inputs, port);
}
/** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
@Override
public String vespaName() {
- return modelName + "_" + super.vespaName();
+ return modelName() + "_" + super.vespaName();
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java
new file mode 100644
index 00000000000..6cbfe0dfb05
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java
@@ -0,0 +1,145 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import java.util.function.DoubleBinaryOperator;
+
+ /**
+  * A binary element-wise TensorFlow operation, imported as a Vespa tensor join using the
+  * operator supplied at construction. Handles TensorFlow-style broadcasting between the
+  * two inputs by reducing away size-1 dimensions before joining.
+  */
+ public class Join extends TensorFlowOperation {
+
+     private final DoubleBinaryOperator operator;
+
+     /**
+      * @param operator the scalar function applied to each pair of joined cells
+      */
+     public Join(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port, DoubleBinaryOperator operator) {
+         super(modelName, node, inputs, port);
+         this.operator = operator;
+     }
+
+     /**
+      * Computes the broadcast result type: the dimensions of the higher-rank input, where
+      * each bound size is the maximum of the sizes of the two aligned dimensions.
+      * Returns null unless allInputTypesPresent(2) holds.
+      */
+     @Override
+     protected OrderedTensorType lazyGetType() {
+         if (!allInputTypesPresent(2)) {
+             return null;
+         }
+         OrderedTensorType a = largestInput().type().get();
+         OrderedTensorType b = smallestInput().type().get();
+
+         // Well now we have potentially entered the wonderful world of "broadcasting"
+         // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+         // In broadcasting, the size of each dimension is compared element-wise,
+         // starting with the trailing dimensions and working forward. A special
+         // case occurs when the size of one dimension is 1, while the other is not.
+         // Then the dimension with size 1 is "stretched" to be of compatible size.
+         //
+         // An example:
+         //
+         // Tensor A: d0[5], d1[1], d2[3], d3[1]
+         // Tensor B:        d1[4], d2[1], d3[2]
+         //
+         // In TensorFlow and using the above rules of broadcasting, the resulting
+         // type is:
+         //     d0[5], d1[4], d2[3], d3[2]
+         //
+         // However, in Vespa's tensor logic, the join of the two above tensors would
+         // result in a tensor of type:
+         //     d0[5], d1[1], d2[1], d3[1]
+         //
+         // By reducing the dimensions of size 1 in each tensor before joining,
+         // we get equal results as in TensorFlow.
+
+         OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node);
+         int sizeDifference = a.rank() - b.rank();
+         for (int i = 0; i < a.rank(); ++i) {
+             TensorType.Dimension aDim = a.dimensions().get(i);
+             long size = aDim.size().orElse(-1L);
+
+             // Dimensions are aligned from the trailing end; leading dimensions of a have no partner in b
+             if (i - sizeDifference >= 0) {
+                 TensorType.Dimension bDim = b.dimensions().get(i - sizeDifference);
+                 size = Math.max(size, bDim.size().orElse(-1L));
+             }
+
+             if (aDim.type() == TensorType.Dimension.Type.indexedBound) {
+                 builder.add(TensorType.Dimension.indexed(aDim.name(), size));
+             } else if (aDim.type() == TensorType.Dimension.Type.indexedUnbound) {
+                 builder.add(TensorType.Dimension.indexed(aDim.name()));
+             } else if (aDim.type() == TensorType.Dimension.Type.mapped) {
+                 builder.add(TensorType.Dimension.mapped(aDim.name()));
+             }
+         }
+         return builder.build();
+     }
+
+     /**
+      * Builds the join of the two input functions. In each aligned dimension pair where
+      * exactly one side has size 1, that side's dimension is first removed with a sum
+      * reduce so the join broadcasts over it. Returns null unless both input types and
+      * both input functions are present.
+      */
+     @Override
+     protected TensorFunction lazyGetFunction() {
+         if (!allInputTypesPresent(2)) {
+             return null;
+         }
+         if (!allInputFunctionsPresent(2)) {
+             return null;
+         }
+
+         TensorFlowOperation a = largestInput();
+         TensorFlowOperation b = smallestInput();
+
+         List<String> aDimensionsToReduce = new ArrayList<>();
+         List<String> bDimensionsToReduce = new ArrayList<>();
+         int sizeDifference = a.type().get().rank() - b.type().get().rank();
+         for (int i = 0; i < b.type().get().rank(); ++i) {
+             TensorType.Dimension bDim = b.type().get().dimensions().get(i);
+             TensorType.Dimension aDim = a.type().get().dimensions().get(i + sizeDifference);
+             long bSize = bDim.size().orElse(-1L);
+             long aSize = aDim.size().orElse(-1L);
+             if (bSize == 1L && aSize != 1L) {
+                 bDimensionsToReduce.add(bDim.name());
+             }
+             if (aSize == 1L && bSize != 1L) {
+                 // Reduce a over its own dimension name: using b's name is only correct
+                 // when the two aligned dimensions already carry the same name
+                 aDimensionsToReduce.add(aDim.name());
+             }
+         }
+
+         TensorFunction aReducedFunction = a.function().get();
+         if (aDimensionsToReduce.size() > 0) {
+             aReducedFunction = new Reduce(a.function().get(), Reduce.Aggregator.sum, aDimensionsToReduce);
+         }
+         TensorFunction bReducedFunction = b.function().get();
+         if (bDimensionsToReduce.size() > 0) {
+             bReducedFunction = new Reduce(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce);
+         }
+
+         return new com.yahoo.tensor.functions.Join(aReducedFunction, bReducedFunction, operator);
+     }
+
+     /**
+      * Requests that each dimension of the lower-rank input gets the same name as the
+      * dimension of the higher-rank input it is aligned with (counting from the trailing end).
+      */
+     @Override
+     public void addDimensionNameConstraints(DimensionRenamer renamer) {
+         if (!allInputTypesPresent(2)) {
+             return;
+         }
+         OrderedTensorType a = largestInput().type().get();
+         OrderedTensorType b = smallestInput().type().get();
+         int sizeDifference = a.rank() - b.rank();
+         for (int i = 0; i < b.rank(); ++i) {
+             String bDim = b.dimensions().get(i).name();
+             String aDim = a.dimensions().get(i + sizeDifference).name();
+             renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this);
+         }
+     }
+
+     /** Returns the input of highest rank; the first input when ranks are equal. */
+     private TensorFlowOperation largestInput() {
+         OrderedTensorType a = inputs.get(0).type().get();
+         OrderedTensorType b = inputs.get(1).type().get();
+         return a.rank() >= b.rank() ? inputs.get(0) : inputs.get(1);
+     }
+
+     /** Returns the input of lowest rank; the second input when ranks are equal. */
+     private TensorFlowOperation smallestInput() {
+         OrderedTensorType a = inputs.get(0).type().get();
+         OrderedTensorType b = inputs.get(1).type().get();
+         return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1);
+     }
+
+ }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java
index f54ae83052f..c015f5ecba8 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java
@@ -1,19 +1,20 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
import java.util.List;
import java.util.Optional;
import java.util.function.DoubleUnaryOperator;
-public class Map extends IntermediateOperation {
+public class Map extends TensorFlowOperation {
private final DoubleUnaryOperator operator;
- public Map(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleUnaryOperator operator) {
- super(modelName, nodeName, inputs);
+ public Map(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port, DoubleUnaryOperator operator) {
+ super(modelName, node, inputs, port);
this.operator = operator;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java
new file mode 100644
index 00000000000..b2b9530a161
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java
@@ -0,0 +1,74 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+import java.util.Optional;
+
+ /**
+  * TensorFlow matrix multiplication of two inputs, imported as a Vespa Matmul function
+  * over the second dimension of the first input.
+  */
+ public class Matmul extends TensorFlowOperation {
+
+     public Matmul(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
+         super(modelName, node, inputs, port);
+     }
+
+     /**
+      * The result type consists of the first dimension of the first input followed by
+      * the second dimension of the second input. Returns null unless allInputTypesPresent(2) holds.
+      *
+      * @throws IllegalArgumentException if either input has rank below 2
+      */
+     @Override
+     protected OrderedTensorType lazyGetType() {
+         if (!allInputTypesPresent(2)) {
+             return null;
+         }
+         OrderedTensorType aType = inputs.get(0).type().get();
+         OrderedTensorType bType = inputs.get(1).type().get();
+         // Validate before indexing dimension 1, so a bad rank gives a descriptive error
+         requireRankAtLeastTwo(aType, bType);
+
+         OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node);
+         typeBuilder.add(aType.dimensions().get(0));
+         typeBuilder.add(bType.dimensions().get(1));
+         return typeBuilder.build();
+     }
+
+     /**
+      * Returns the matmul of the two input functions, or null when an input type or
+      * input function is not available.
+      *
+      * @throws IllegalArgumentException if either input has rank below 2, or the ranks differ
+      */
+     @Override
+     protected TensorFunction lazyGetFunction() {
+         if (!allInputTypesPresent(2)) {
+             return null;
+         }
+         OrderedTensorType aType = inputs.get(0).type().get();
+         OrderedTensorType bType = inputs.get(1).type().get();
+         requireRankAtLeastTwo(aType, bType);
+         if (aType.type().rank() != bType.type().rank())
+             throw new IllegalArgumentException("Tensors in matmul must have the same rank");
+
+         Optional<TensorFunction> aFunction = inputs.get(0).function();
+         Optional<TensorFunction> bFunction = inputs.get(1).function();
+         if (!aFunction.isPresent() || !bFunction.isPresent()) {
+             return null;
+         }
+         return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name());
+     }
+
+     /**
+      * Adds the naming constraints between the first two dimensions of each input.
+      *
+      * @throws IllegalArgumentException if either input has rank below 2
+      */
+     @Override
+     public void addDimensionNameConstraints(DimensionRenamer renamer) {
+         if (!allInputTypesPresent(2)) {
+             return;
+         }
+         requireRankAtLeastTwo(inputs.get(0).type().get(), inputs.get(1).type().get());
+         List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions();
+         List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions();
+
+         String aDim0 = aDimensions.get(0).name();
+         String aDim1 = aDimensions.get(1).name();
+         String bDim0 = bDimensions.get(0).name();
+         String bDim1 = bDimensions.get(1).name();
+
+         // The second dimension of a should have the same name as the first dimension of b
+         renamer.addConstraint(aDim1, bDim0, DimensionRenamer::equals, this);
+
+         // The first dimension of a should have a different name than the second dimension of b
+         renamer.addConstraint(aDim0, bDim1, DimensionRenamer::lesserThan, this);
+
+         // For efficiency, the dimensions to join over should be innermost - soft constraint
+         renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this);
+         renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this);
+     }
+
+     /** Throws IllegalArgumentException unless both types have at least two dimensions. */
+     private static void requireRankAtLeastTwo(OrderedTensorType aType, OrderedTensorType bType) {
+         if (aType.type().rank() < 2 || bType.type().rank() < 2) {
+             throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
+         }
+     }
+
+ }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java
index 95a77c07590..3eba872c6a0 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java
@@ -1,10 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
@@ -14,20 +13,20 @@ import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.NodeDef;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
-public class Mean extends IntermediateOperation {
+public class Mean extends TensorFlowOperation {
- private final AttributeMap attributeMap;
private List<String> reduceDimensions;
- public Mean(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
- super(modelName, nodeName, inputs);
- this.attributeMap = attributeMap;
+ public Mean(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(modelName, node, inputs, port);
}
@Override
@@ -35,9 +34,9 @@ public class Mean extends IntermediateOperation {
if (!allInputTypesPresent(2)) {
return null;
}
- IntermediateOperation reductionIndices = inputs.get(1);
+ TensorFlowOperation reductionIndices = inputs.get(1);
if (!reductionIndices.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("Mean in " + name + ": " +
+ throw new IllegalArgumentException("Mean in " + node.getName() + ": " +
"reduction indices must be a constant.");
}
Tensor indices = reductionIndices.getConstantValue().get().asTensor();
@@ -55,7 +54,7 @@ public class Mean extends IntermediateOperation {
return reducedType(inputType, shouldKeepDimensions());
}
- // optimization: if keepDims and one reduce dimension that has size 1: same as identity.
+ // todo: optimization: if keepDims and one reduce dimension that has size 1: same as identity.
@Override
protected TensorFunction lazyGetFunction() {
@@ -94,12 +93,12 @@ public class Mean extends IntermediateOperation {
}
private boolean shouldKeepDimensions() {
- Optional<Value> keepDims = attributeMap.get("keep_dims");
- return keepDims.isPresent() && keepDims.get().asBoolean();
+ AttrValue keepDimsAttr = node.getAttrMap().get("keep_dims");
+ return keepDimsAttr != null && keepDimsAttr.getB();
}
private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node);
for (TensorType.Dimension dimension: inputType.type().dimensions()) {
if (!reduceDimensions.contains(dimension.name())) {
builder.add(dimension);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java
index 9d9eca47b1c..4c95e67e184 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java
@@ -1,20 +1,21 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
import java.util.List;
-public class Merge extends IntermediateOperation {
+public class Merge extends TensorFlowOperation {
- public Merge(String modelName, String nodeName, List<IntermediateOperation> inputs) {
- super(modelName, nodeName, inputs);
+ public Merge(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(modelName, node, inputs, port);
}
@Override
protected OrderedTensorType lazyGetType() {
- for (IntermediateOperation operation : inputs) {
+ for (TensorFlowOperation operation : inputs) {
if (operation.type().isPresent()) {
return operation.type().get();
}
@@ -24,7 +25,7 @@ public class Merge extends IntermediateOperation {
@Override
protected TensorFunction lazyGetFunction() {
- for (IntermediateOperation operation : inputs) {
+ for (TensorFlowOperation operation : inputs) {
if (operation.function().isPresent()) {
return operation.function().get();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java
new file mode 100644
index 00000000000..d558ec89e87
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java
@@ -0,0 +1,32 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.Collections;
+import java.util.List;
+
+public class NoOp extends TensorFlowOperation { // operation that produces neither a tensor type nor a function
+
+ public NoOp(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(modelName, node, Collections.emptyList(), port); // don't propagate inputs: the 'inputs' argument is ignored and an empty list is passed on
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ return null; // this operation never yields a type
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ return null; // this operation never yields a function of its own
+ }
+
+ @Override
+ public boolean isConstant() {
+ return true; // always reported as constant, independent of any inputs
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java
index 7fc2aae87d1..1619c11427a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java
@@ -1,29 +1,28 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
-import java.util.Collections;
import java.util.List;
-public class Argument extends IntermediateOperation {
+public class Placeholder extends TensorFlowOperation {
private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ...
- public Argument(String modelName, String nodeName, OrderedTensorType type) {
- super(modelName, nodeName, Collections.emptyList());
- this.type = type.rename(vespaName() + "_");
- standardNamingType = OrderedTensorType.standardType(type);
+ public Placeholder(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(modelName, node, inputs, port);
+ standardNamingType = OrderedTensorType.fromTensorFlowType(node);
}
@Override
protected OrderedTensorType lazyGetType() {
- return type;
+ return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_");
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java
index 9299ae9be12..65ce7f00e34 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java
@@ -1,16 +1,17 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
import java.util.List;
import java.util.Optional;
-public class PlaceholderWithDefault extends IntermediateOperation {
+public class PlaceholderWithDefault extends TensorFlowOperation {
- public PlaceholderWithDefault(String modelName, String nodeName, List<IntermediateOperation> inputs) {
- super(modelName, nodeName, inputs);
+ public PlaceholderWithDefault(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(modelName, node, inputs, port);
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java
index e91c2305f7d..e7d90e5fc1f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java
@@ -1,9 +1,10 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
@@ -18,18 +19,19 @@ import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
-import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.tensorSize;
+import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.tensorSize;
-public class Reshape extends IntermediateOperation {
+public class Reshape extends TensorFlowOperation {
- public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) {
- super(modelName, nodeName, inputs);
+ public Reshape(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(modelName, node, inputs, port);
}
@Override
@@ -37,15 +39,15 @@ public class Reshape extends IntermediateOperation {
if (!allInputTypesPresent(2)) {
return null;
}
- IntermediateOperation newShape = inputs.get(1);
+ TensorFlowOperation newShape = inputs.get(1);
if (!newShape.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("Reshape in " + name + ": " +
+ throw new IllegalArgumentException("Reshape in " + node.getName() + ": " +
"shape input must be a constant.");
}
Tensor shape = newShape.getConstantValue().get().asTensor();
OrderedTensorType inputType = inputs.get(0).type().get();
- OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(node);
int dimensionIndex = 0;
for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) {
Tensor.Cell cell = cellIterator.next();
@@ -122,7 +124,7 @@ public class Reshape extends IntermediateOperation {
operators.add(0, ArithmeticOperator.MULTIPLY);
children.add(0, new ConstantNode(new DoubleValue(size)));
}
- size *= OrderedTensorType.dimensionSize(dimension);
+ size *= TensorConverter.dimensionSize(dimension);
if (i > 0) {
operators.add(0, ArithmeticOperator.PLUS);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java
index 927a4a368f9..5fdcb5a695f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java
@@ -1,23 +1,24 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
import java.util.List;
import java.util.function.DoubleBinaryOperator;
-import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.dimensionSize;
-import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.tensorSize;
+import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.dimensionSize;
+import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.tensorSize;
-public class Select extends IntermediateOperation {
+public class Select extends TensorFlowOperation {
- public Select(String modelName, String nodeName, List<IntermediateOperation> inputs) {
- super(modelName, nodeName, inputs);
+ public Select(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(modelName, node, inputs, port);
}
@Override
@@ -38,7 +39,7 @@ public class Select extends IntermediateOperation {
if (!allInputFunctionsPresent(3)) {
return null;
}
- IntermediateOperation conditionOperation = inputs().get(0);
+ TensorFlowOperation conditionOperation = inputs().get(0);
TensorFunction a = inputs().get(1).function().get();
TensorFunction b = inputs().get(2).function().get();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java
index da566909adc..af49d2c108b 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java
@@ -1,19 +1,20 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
import java.util.List;
-public class Shape extends IntermediateOperation {
+public class Shape extends TensorFlowOperation {
- public Shape(String modelName, String nodeName, List<IntermediateOperation> inputs) {
- super(modelName, nodeName, inputs);
+ public Shape(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(modelName, node, inputs, port);
createConstantValue();
}
@@ -23,7 +24,7 @@ public class Shape extends IntermediateOperation {
return null;
}
OrderedTensorType inputType = inputs.get(0).type().get();
- return new OrderedTensorType.Builder()
+ return new OrderedTensorType.Builder(node)
.add(TensorType.Dimension.indexed(vespaName(), inputType.dimensions().size()))
.build();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java
index c750c47e27e..17ce9e8b7cb 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java
@@ -1,26 +1,26 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.NodeDef;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
-public class Squeeze extends IntermediateOperation {
+public class Squeeze extends TensorFlowOperation {
- private final AttributeMap attributeMap;
private List<String> squeezeDimensions;
- public Squeeze(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
- super(modelName, nodeName, inputs);
- this.attributeMap = attributeMap;
+ public Squeeze(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(modelName, node, inputs, port);
}
@Override
@@ -31,21 +31,20 @@ public class Squeeze extends IntermediateOperation {
OrderedTensorType inputType = inputs.get(0).type().get();
squeezeDimensions = new ArrayList<>();
- Optional<List<Value>> squeezeDimsAttr = attributeMap.getList("squeeze_dims");
- if ( ! squeezeDimsAttr.isPresent()) {
+ AttrValue squeezeDimsAttr = node.getAttrMap().get("squeeze_dims");
+ if (squeezeDimsAttr == null) {
squeezeDimensions = inputType.type().dimensions().stream().
- filter(dim -> OrderedTensorType.dimensionSize(dim) == 1).
+ filter(dim -> TensorConverter.dimensionSize(dim) == 1).
map(TensorType.Dimension::name).
collect(Collectors.toList());
} else {
- squeezeDimensions = squeezeDimsAttr.get().stream().map(Value::asDouble).map(Double::intValue).
+ squeezeDimensions = squeezeDimsAttr.getList().getIList().stream().
map(i -> i < 0 ? inputType.type().dimensions().size() - i : i).
- map(i -> inputType.type().dimensions().get(i)).
- filter(dim -> OrderedTensorType.dimensionSize(dim) == 1).
+ map(i -> inputType.type().dimensions().get(i.intValue())).
+ filter(dim -> TensorConverter.dimensionSize(dim) == 1).
map(TensorType.Dimension::name).
collect(Collectors.toList());
}
-
return squeezeDimensions.isEmpty() ? inputType : reducedType(inputType);
}
@@ -73,7 +72,7 @@ public class Squeeze extends IntermediateOperation {
}
private OrderedTensorType reducedType(OrderedTensorType inputType) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node);
for (TensorType.Dimension dimension: inputType.type().dimensions()) {
if ( ! squeezeDimensions.contains(dimension.name())) {
builder.add(dimension);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java
index 0171d1ea171..de4d8862fd6 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java
@@ -1,19 +1,17 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
import java.util.List;
import java.util.Optional;
-public class Switch extends IntermediateOperation {
+public class Switch extends TensorFlowOperation {
- private final int port;
-
- public Switch(String modelName, String nodeName, List<IntermediateOperation> inputs, int port) {
- super(modelName, nodeName, inputs);
- this.port = port;
+ public Switch(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(modelName, node, inputs, port);
}
@Override
@@ -23,7 +21,7 @@ public class Switch extends IntermediateOperation {
}
Optional<OrderedTensorType> predicate = inputs.get(1).type();
if (predicate.get().type().rank() != 0) {
- throw new IllegalArgumentException("Switch in " + name + ": " +
+ throw new IllegalArgumentException("Switch in " + node.getName() + ": " +
"predicate must be a scalar");
}
return inputs.get(0).type().orElse(null);
@@ -31,13 +29,13 @@ public class Switch extends IntermediateOperation {
@Override
protected TensorFunction lazyGetFunction() {
- IntermediateOperation predicateOperation = inputs().get(1);
+ TensorFlowOperation predicateOperation = inputs().get(1);
if (!predicateOperation.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("Switch in " + name + ": " +
+ throw new IllegalArgumentException("Switch in " + node.getName() + ": " +
"predicate must be a constant");
}
if (port < 0 || port > 1) {
- throw new IllegalArgumentException("Switch in " + name + ": " +
+ throw new IllegalArgumentException("Switch in " + node.getName() + ": " +
"choice should be boolean");
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
index 43de29cedd5..3687bba8b85 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
@@ -1,16 +1,17 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
import java.util.ArrayList;
import java.util.Collections;
@@ -19,40 +20,43 @@ import java.util.Optional;
import java.util.function.Function;
/**
- * Wraps an imported operation node and produces the respective Vespa tensor
- * operation. During import, a graph of these operations are constructed. Then,
- * the types are used to deduce sensible dimension names using the
- * DimensionRenamer. After the types have been renamed, the proper Vespa
- * expressions can be extracted.
+ * Wraps a TensorFlow node and produces the respective Vespa tensor operation.
+ * During import, a graph of these operations is constructed. Then, the
+ * types are used to deduce sensible dimension names using the
+ * DimensionRenamer. After the types have been renamed, the proper
+ * Vespa expressions can be extracted.
*
* @author lesters
*/
-public abstract class IntermediateOperation {
+public abstract class TensorFlowOperation {
+
+ protected final static String MACRO_PREFIX = "tf_macro_";
- private final static String MACRO_PREFIX = "imported_ml_macro_";
+ private final String modelName;
- protected final String name;
- protected final String modelName;
- protected final List<IntermediateOperation> inputs;
- protected final List<IntermediateOperation> outputs = new ArrayList<>();
+ protected final NodeDef node;
+ protected final int port;
+ protected final List<TensorFlowOperation> inputs;
+ protected final List<TensorFlowOperation> outputs = new ArrayList<>();
+ protected final List<String> importWarnings = new ArrayList<>();
protected OrderedTensorType type;
protected TensorFunction function;
protected TensorFunction macro = null;
- private final List<String> importWarnings = new ArrayList<>();
private Value constantValue = null;
- private List<IntermediateOperation> controlInputs = Collections.emptyList();
+ private List<TensorFlowOperation> controlInputs = Collections.emptyList();
- protected Function<OrderedTensorType, Value> constantValueFunction = null;
-
- IntermediateOperation(String modelName, String name, List<IntermediateOperation> inputs) {
- this.name = name;
+ TensorFlowOperation(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
this.modelName = modelName;
+ this.node = node;
+ this.port = port;
this.inputs = Collections.unmodifiableList(inputs);
this.inputs.forEach(i -> i.outputs.add(this));
}
+ protected String modelName() { return modelName; }
+
protected abstract OrderedTensorType lazyGetType();
protected abstract TensorFunction lazyGetFunction();
@@ -61,6 +65,9 @@ public abstract class IntermediateOperation {
if (type == null) {
type = lazyGetType();
}
+ if (type != null) {
+ type.verifyType(node);
+ }
return Optional.ofNullable(type);
}
@@ -80,14 +87,14 @@ public abstract class IntermediateOperation {
return Optional.ofNullable(function);
}
- /** Returns original name of this operation node */
- public String name() { return name; }
+ /** Return TensorFlow node */
+ public NodeDef node() { return node; }
/** Return unmodifiable list of inputs */
- public List<IntermediateOperation> inputs() { return inputs; }
+ public List<TensorFlowOperation> inputs() { return inputs; }
/** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a macro. */
- public List<IntermediateOperation> outputs() { return Collections.unmodifiableList(outputs); }
+ public List<TensorFlowOperation> outputs() { return Collections.unmodifiableList(outputs); }
/** Returns a Vespa ranking expression that should be added as a macro */
public Optional<TensorFunction> macro() { return Optional.ofNullable(macro); }
@@ -102,34 +109,22 @@ public abstract class IntermediateOperation {
public boolean isInput() { return false; }
/** Return true if this node is constant */
- public boolean isConstant() { return inputs.stream().allMatch(IntermediateOperation::isConstant); }
+ public boolean isConstant() { return inputs.stream().allMatch(TensorFlowOperation::isConstant); }
/** Sets the constant value */
public void setConstantValue(Value value) { constantValue = value; }
/** Gets the constant value if it exists */
- public Optional<Value> getConstantValue() {
- if (constantValue != null) {
- return Optional.of(constantValue);
- }
- if (constantValueFunction != null) {
- return Optional.of(constantValueFunction.apply(type));
- }
- return Optional.empty();
- }
-
- /** Set the constant value function */
- public void setConstantValueFunction(Function<OrderedTensorType, Value> func) { this.constantValueFunction = func; }
+ public Optional<Value> getConstantValue() { return Optional.ofNullable(constantValue); }
/** Sets the external control inputs */
- public void setControlInputs(List<IntermediateOperation> inputs) { this.controlInputs = inputs; }
+ public void setControlInputs(List<TensorFlowOperation> inputs) { this.controlInputs = inputs; }
/** Retrieve the control inputs for this operation */
- public List<IntermediateOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); }
+ public List<TensorFlowOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); }
/** Retrieve the valid Vespa name of this node */
- public String vespaName() { return vespaName(name); }
- public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; }
+ public String vespaName() { return node.getName() != null ? node.getName().replace('/', '_') : null; }
/** Retrieve the valid Vespa name of this node if it is a macro */
public String macroName() { return vespaName() != null ? MACRO_PREFIX + modelName + "_" + vespaName() : null; }
@@ -140,48 +135,23 @@ public abstract class IntermediateOperation {
/** Set an input warning */
public void warning(String warning) { importWarnings.add(warning); }
- boolean verifyInputs(int expected, Function<IntermediateOperation, Optional<?>> func) {
+ boolean verifyInputs(int expected, Function<TensorFlowOperation, Optional<?>> func) {
+ if (!controlInputs.stream().map(func).allMatch(Optional::isPresent)) {
+ return false;
+ }
if (inputs.size() != expected) {
throw new IllegalArgumentException("Expected " + expected + " inputs " +
- "for '" + name + "', got " + inputs.size());
+ "for '" + node.getName() + "', got " + inputs.size());
}
return inputs.stream().map(func).allMatch(Optional::isPresent);
}
boolean allInputTypesPresent(int expected) {
- return verifyInputs(expected, IntermediateOperation::type);
+ return verifyInputs(expected, TensorFlowOperation::type);
}
boolean allInputFunctionsPresent(int expected) {
- return verifyInputs(expected, IntermediateOperation::function);
- }
-
- /**
- * A method signature input and output has the form name:index.
- * This returns the name part without the index.
- */
- public static String namePartOf(String name) {
- name = name.startsWith("^") ? name.substring(1) : name;
- return name.split(":")[0];
- }
-
- /**
- * This return the output index part. Indexes are used for nodes with
- * multiple outputs.
- */
- public static int indexPartOf(String name) {
- int i = name.indexOf(":");
- return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1));
- }
-
- /**
- * An interface mapping operation attributes to Vespa Values.
- * Adapter for differences in ONNX/TensorFlow.
- */
- public interface AttributeMap {
- Optional<Value> get(String key);
- Optional<Value> get(String key, OrderedTensorType type);
- Optional<List<Value>> getList(String key);
+ return verifyInputs(expected, TensorFlowOperation::function);
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java
new file mode 100644
index 00000000000..b18a8a9b212
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java
@@ -0,0 +1,46 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+
+public class Variable extends TensorFlowOperation { // TensorFlow variable node; always treated as constant (see isConstant)
+
+ public Variable(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(modelName, node, inputs, port);
+ }
+
+ /** Returns the inherited Vespa name prefixed by the model name and "_", to avoid constant name conflicts between models */
+ @Override
+ public String vespaName() {
+ return modelName() + "_" + super.vespaName();
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ return OrderedTensorType.fromTensorFlowType(node, super.vespaName() + "_"); // passes the un-prefixed name; presumably a dimension name prefix - TODO confirm
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ return null; // will be added by function() since this is constant.
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ for (TensorType.Dimension dimension : type.type().dimensions()) { // NOTE(review): reads the 'type' field directly; assumes it is already resolved - confirm
+ renamer.addDimension(dimension.name()); // registers every dimension name of this variable with the renamer
+ }
+ }
+
+ @Override
+ public boolean isConstant() {
+ return true; // always reported as constant
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/package-info.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/package-info.java
new file mode 100644
index 00000000000..9e53990a9d6
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/package-info.java
@@ -0,0 +1,8 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+/**
+ * TensorFlow integration
+ */
+@ExportPackage
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java
index a7926cd2e02..4b68cd40a08 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java
@@ -1,9 +1,11 @@
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.onnx;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
@@ -22,7 +24,7 @@ public class OnnxMnistSoftmaxImportTestCase {
@Test
public void testMnistSoftmaxImport() throws IOException {
- ImportedModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx");
+ OnnxModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx");
// Check constants
assertEquals(2, model.largeConstants().size());
@@ -46,7 +48,7 @@ public class OnnxMnistSoftmaxImportTestCase {
model.requiredMacros().get("Placeholder"));
// Check outputs
- RankingExpression output = model.defaultSignature().outputExpression("add");
+ RankingExpression output = model.outputExpression("add");
assertNotNull(output);
assertEquals("add", output.getName());
assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))",
@@ -66,12 +68,13 @@ public class OnnxMnistSoftmaxImportTestCase {
}
private Tensor evaluateTensorFlowModel(String path, Tensor argument, String input, String output) {
- ImportedModel model = new TensorFlowImporter().importModel("test", path);
+ SavedModelBundle tensorFlowModel = SavedModelBundle.load(path, "serve");
+ TensorFlowModel model = new TensorFlowImporter().importModel("test", tensorFlowModel);
return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input);
}
private Tensor evaluateOnnxModel(String path, Tensor argument, String input, String output) {
- ImportedModel model = new OnnxImporter().importModel("test", path);
+ OnnxModel model = new OnnxImporter().importModel("test", path);
return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input);
}
@@ -80,7 +83,14 @@ public class OnnxMnistSoftmaxImportTestCase {
return expression.evaluate(context).asTensor();
}
- private Context contextFrom(ImportedModel result) {
+ private Context contextFrom(TensorFlowModel result) {
+ MapContext context = new MapContext();
+ result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
+ result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
+ return context;
+ }
+
+ private Context contextFrom(OnnxModel result) {
MapContext context = new MapContext();
result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java
index bf9684082f4..0f5eec93feb 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/BatchNormImportTestCase.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import org.junit.Test;
@@ -15,7 +15,7 @@ public class BatchNormImportTestCase {
@Test
public void testBatchNormImport() {
TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/batch_norm/saved");
- ImportedModel.Signature signature = model.get().signature("serving_default");
+ TensorFlowModel.Signature signature = model.get().signature("serving_default");
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java
index c8c7ec798bb..74b0d11f1d6 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java
@@ -1,6 +1,6 @@
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
import org.junit.Test;
import static org.junit.Assert.assertTrue;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
index a63c7346335..50a467ec581 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.TensorType;
@@ -24,7 +24,7 @@ public class DropoutImportTestCase {
assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
model.get().requiredMacros().get("X"));
- ImportedModel.Signature signature = model.get().signature("serving_default");
+ TensorFlowModel.Signature signature = model.get().signature("serving_default");
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
@@ -32,7 +32,7 @@ public class DropoutImportTestCase {
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
assertEquals("outputs/Maximum", output.getName());
- assertEquals("join(join(imported_ml_macro_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_macro_test_outputs_BiasAdd, f(a,b)(max(a,b)))",
+ assertEquals("join(join(tf_macro_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), tf_macro_test_outputs_BiasAdd, f(a,b)(max(a,b)))",
output.getRoot().toString());
model.assertEqualResult("X", output.getName());
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
index bd7644be23b..9f919c452d6 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
@@ -1,5 +1,5 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
@@ -45,7 +45,7 @@ public class MnistSoftmaxImportTestCase {
// Check signatures
assertEquals(1, model.get().signatures().size());
- ImportedModel.Signature signature = model.get().signatures().get("serving_default");
+ TensorFlowModel.Signature signature = model.get().signatures().get("serving_default");
assertNotNull(signature);
// ... signature inputs
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java
index b2443082ab1..beec2ab1ead 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java
@@ -1,6 +1,6 @@
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
index 723c5f27914..7ca16939477 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
@@ -1,11 +1,11 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
@@ -28,7 +28,7 @@ import static org.junit.Assert.assertEquals;
public class TestableTensorFlowModel {
private SavedModelBundle tensorFlowModel;
- private ImportedModel model;
+ private TensorFlowModel model;
// Sizes of the input vector
private final int d0Size = 1;
@@ -39,7 +39,7 @@ public class TestableTensorFlowModel {
model = new TensorFlowImporter().importModel(modelName, tensorFlowModel);
}
- public ImportedModel get() { return model; }
+ public TensorFlowModel get() { return model; }
public void assertEqualResult(String inputName, String operationName) {
Tensor tfResult = tensorFlowExecute(tensorFlowModel, inputName, operationName);
@@ -66,7 +66,7 @@ public class TestableTensorFlowModel {
return TensorConverter.toVespaTensor(results.get(0));
}
- private Context contextFrom(ImportedModel result) {
+ private Context contextFrom(TensorFlowModel result) {
MapContext context = new MapContext();
result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
@@ -81,7 +81,7 @@ public class TestableTensorFlowModel {
return b.build();
}
- private void evaluateMacro(Context context, ImportedModel model, String macroName) {
+ private void evaluateMacro(Context context, TensorFlowModel model, String macroName) {
if (!context.names().contains(macroName)) {
RankingExpression e = model.macros().get(macroName);
evaluateMacroDependencies(context, model, e.getRoot());
@@ -89,7 +89,7 @@ public class TestableTensorFlowModel {
}
}
- private void evaluateMacroDependencies(Context context, ImportedModel model, ExpressionNode node) {
+ private void evaluateMacroDependencies(Context context, TensorFlowModel model, ExpressionNode node) {
if (node instanceof ReferenceNode) {
String name = node.toString();
if (model.macros().containsKey(name)) {
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java
index f94098e6255..051c2c60c95 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverterTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java
@@ -1,4 +1,4 @@
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import org.junit.Test;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index 3a66eef258d..944755c9db2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -22,37 +22,22 @@ public class ScalarFunctions {
public static DoubleBinaryOperator add() { return new Add(); }
public static DoubleBinaryOperator divide() { return new Divide(); }
public static DoubleBinaryOperator equal() { return new Equal(); }
- public static DoubleBinaryOperator greater() { return new Greater(); }
- public static DoubleBinaryOperator less() { return new Less(); }
public static DoubleBinaryOperator max() { return new Max(); }
public static DoubleBinaryOperator min() { return new Min(); }
- public static DoubleBinaryOperator mean() { return new Mean(); }
public static DoubleBinaryOperator multiply() { return new Multiply(); }
- public static DoubleBinaryOperator pow() { return new Pow(); }
public static DoubleBinaryOperator squareddifference() { return new SquaredDifference(); }
public static DoubleBinaryOperator subtract() { return new Subtract(); }
- public static DoubleUnaryOperator abs() { return new Abs(); }
public static DoubleUnaryOperator acos() { return new Acos(); }
- public static DoubleUnaryOperator asin() { return new Asin(); }
- public static DoubleUnaryOperator atan() { return new Atan(); }
- public static DoubleUnaryOperator ceil() { return new Ceil(); }
- public static DoubleUnaryOperator cos() { return new Cos(); }
public static DoubleUnaryOperator elu() { return new Elu(); }
public static DoubleUnaryOperator exp() { return new Exp(); }
public static DoubleUnaryOperator floor() { return new Floor(); }
- public static DoubleUnaryOperator log() { return new Log(); }
- public static DoubleUnaryOperator neg() { return new Neg(); }
- public static DoubleUnaryOperator reciprocal() { return new Reciprocal(); }
public static DoubleUnaryOperator relu() { return new Relu(); }
public static DoubleUnaryOperator rsqrt() { return new Rsqrt(); }
public static DoubleUnaryOperator selu() { return new Selu(); }
- public static DoubleUnaryOperator sin() { return new Sin(); }
public static DoubleUnaryOperator sigmoid() { return new Sigmoid(); }
public static DoubleUnaryOperator sqrt() { return new Sqrt(); }
public static DoubleUnaryOperator square() { return new Square(); }
- public static DoubleUnaryOperator tan() { return new Tan(); }
- public static DoubleUnaryOperator tanh() { return new Tanh(); }
public static Function<List<Long>, Double> random() { return new Random(); }
public static Function<List<Long>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); }
@@ -74,20 +59,6 @@ public class ScalarFunctions {
public String toString() { return "f(a,b)(a==b)"; }
}
- public static class Greater implements DoubleBinaryOperator {
- @Override
- public double applyAsDouble(double left, double right) { return left > right ? 1 : 0; }
- @Override
- public String toString() { return "f(a,b)(a > b)"; }
- }
-
- public static class Less implements DoubleBinaryOperator {
- @Override
- public double applyAsDouble(double left, double right) { return left < right ? 1 : 0; }
- @Override
- public String toString() { return "f(a,b)(a < b)"; }
- }
-
public static class Max implements DoubleBinaryOperator {
@Override
public double applyAsDouble(double left, double right) { return Math.max(left, right); }
@@ -102,13 +73,6 @@ public class ScalarFunctions {
public String toString() { return "f(a,b)(min(a, b))"; }
}
- public static class Mean implements DoubleBinaryOperator {
- @Override
- public double applyAsDouble(double left, double right) { return (left + right) / 2; }
- @Override
- public String toString() { return "f(a,b)((a + b) / 2)"; }
- }
-
public static class Multiply implements DoubleBinaryOperator {
@Override
public double applyAsDouble(double left, double right) { return left * right; }
@@ -116,13 +80,6 @@ public class ScalarFunctions {
public String toString() { return "f(a,b)(a * b)"; }
}
- public static class Pow implements DoubleBinaryOperator {
- @Override
- public double applyAsDouble(double left, double right) { return Math.pow(left, right); }
- @Override
- public String toString() { return "f(a,b)(pow(a, b))"; }
- }
-
public static class Divide implements DoubleBinaryOperator {
@Override
public double applyAsDouble(double left, double right) { return left / right; }
@@ -147,13 +104,6 @@ public class ScalarFunctions {
// Unary operators ------------------------------------------------------------------------------
- public static class Abs implements DoubleUnaryOperator {
- @Override
- public double applyAsDouble(double operand) { return Math.abs(operand); }
- @Override
- public String toString() { return "f(a)(fabs(a))"; }
- }
-
public static class Acos implements DoubleUnaryOperator {
@Override
public double applyAsDouble(double operand) { return Math.acos(operand); }
@@ -161,34 +111,6 @@ public class ScalarFunctions {
public String toString() { return "f(a)(acos(a))"; }
}
- public static class Asin implements DoubleUnaryOperator {
- @Override
- public double applyAsDouble(double operand) { return Math.asin(operand); }
- @Override
- public String toString() { return "f(a)(asin(a))"; }
- }
-
- public static class Atan implements DoubleUnaryOperator {
- @Override
- public double applyAsDouble(double operand) { return Math.atan(operand); }
- @Override
- public String toString() { return "f(a)(atan(a))"; }
- }
-
- public static class Ceil implements DoubleUnaryOperator {
- @Override
- public double applyAsDouble(double operand) { return Math.ceil(operand); }
- @Override
- public String toString() { return "f(a)(ceil(a))"; }
- }
-
- public static class Cos implements DoubleUnaryOperator {
- @Override
- public double applyAsDouble(double operand) { return Math.cos(operand); }
- @Override
- public String toString() { return "f(a)(cos(a))"; }
- }
-
public static class Elu implements DoubleUnaryOperator {
@Override
public double applyAsDouble(double operand) { return operand < 0 ? Math.exp(operand) -1 : operand; }
@@ -210,26 +132,6 @@ public class ScalarFunctions {
public String toString() { return "f(a)(floor(a))"; }
}
- public static class Log implements DoubleUnaryOperator {
- @Override
- public double applyAsDouble(double operand) { return Math.log(operand); }
- @Override
- public String toString() { return "f(a)(log(a))"; }
- }
-
- public static class Neg implements DoubleUnaryOperator {
- @Override
- public double applyAsDouble(double operand) { return -operand; }
- @Override
- public String toString() { return "f(a)(-a)"; }
- }
-
- public static class Reciprocal implements DoubleUnaryOperator {
- @Override
- public double applyAsDouble(double operand) { return 1.0 / operand; }
- @Override
- public String toString() { return "f(a)(1 / a)"; }
- }
public static class Relu implements DoubleUnaryOperator {
@Override
@@ -248,13 +150,6 @@ public class ScalarFunctions {
public String toString() { return String.format("f(a)(%f * if(a >= 0, a, %f*(exp(a)-1)))", scale, alpha); }
}
- public static class Sin implements DoubleUnaryOperator {
- @Override
- public double applyAsDouble(double operand) { return Math.sin(operand); }
- @Override
- public String toString() { return "f(a)(sin(a))"; }
- }
-
public static class Rsqrt implements DoubleUnaryOperator {
@Override
public double applyAsDouble(double operand) { return 1.0 / Math.sqrt(operand); }
@@ -277,29 +172,15 @@ public class ScalarFunctions {
}
public static class Square implements DoubleUnaryOperator {
+
@Override
public double applyAsDouble(double operand) { return operand * operand; }
- @Override
- public String toString() { return "f(a)(a * a)"; }
- }
- public static class Tan implements DoubleUnaryOperator {
- @Override
- public double applyAsDouble(double operand) { return Math.tan(operand); }
@Override
- public String toString() { return "f(a)(tan(a))"; }
- }
+ public String toString() { return "f(a)(a * a)"; }
- public static class Tanh implements DoubleUnaryOperator {
- @Override
- public double applyAsDouble(double operand) { return Math.tanh(operand); }
- @Override
- public String toString() { return "f(a)(tanh(a))"; }
}
-
-
-
// Variable-length operators -----------------------------------------------------------------------------
public static class EqualElements implements Function<List<Long>, Double> {