summaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java/com
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-11-21 19:21:22 +0100
committerJon Bratseth <bratseth@oath.com>2018-11-21 19:21:22 +0100
commit99ca9b2907ff637fc6e4e0a61860923ac1c9dee5 (patch)
treed5a5e408d56e9165cd716e9531ab9bcec6a29e4a /searchlib/src/main/java/com
parent61cae2609740b51c180b2f507b5e4d0eb399fedc (diff)
Separate model integration into a separate module
This allows us to access model importers (such as TensorFlow) in config models without loading one instance per config model instance, which is not possible with TensorFlow because it depends on JNI code.
Diffstat (limited to 'searchlib/src/main/java/com')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java39
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java61
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java61
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java42
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ImportedModel.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java)26
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ImportedModels.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java)22
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ModelImporter.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java)6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java216
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java79
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java52
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java8
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/package-info.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java)4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java86
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java234
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java225
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java72
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostParser.java77
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostTree.java77
23 files changed, 33 insertions, 1364 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
deleted file mode 100644
index e6bb5f40b3f..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
+++ /dev/null
@@ -1,39 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.ml;
-
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx.GraphImporter;
-import onnx.Onnx;
-
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.IOException;
-
-/**
- * Converts a ONNX model into a ranking expression and set of constants.
- *
- * @author lesters
- */
-public class OnnxImporter extends ModelImporter {
-
- @Override
- public boolean canImport(String modelPath) {
- File modelFile = new File(modelPath);
- if ( ! modelFile.isFile()) return false;
-
- return modelFile.toString().endsWith(".onnx");
- }
-
- @Override
- public ImportedModel importModel(String modelName, String modelPath) {
- try (FileInputStream inputStream = new FileInputStream(modelPath)) {
- Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
- IntermediateGraph graph = GraphImporter.importGraph(modelName, model);
- return convertIntermediateGraphToModel(graph, modelPath);
- } catch (IOException e) {
- throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e);
- }
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java
deleted file mode 100644
index 7c18e04bae7..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java
+++ /dev/null
@@ -1,61 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
-
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.GraphImporter;
-import org.tensorflow.SavedModelBundle;
-
-import java.io.File;
-import java.io.IOException;
-
-/**
- * Converts a saved TensorFlow model into a ranking expression and set of constants.
- *
- * @author bratseth
- * @author lesters
- */
-public class TensorFlowImporter extends ModelImporter {
-
- @Override
- public boolean canImport(String modelPath) {
- File modelDir = new File(modelPath);
- if ( ! modelDir.isDirectory()) return false;
-
- // No other model types are stored in protobuf files thus far
- for (File file : modelDir.listFiles()) {
- if (file.toString().endsWith(".pbtxt")) return true;
- if (file.toString().endsWith(".pb")) return true;
- }
- return false;
- }
-
- /**
- * Imports a saved TensorFlow model from a directory.
- * The model should be saved as a .pbtxt or .pb file.
- *
- * @param modelName the name of the model to import, consisting of characters in [A-Za-z0-9_]
- * @param modelDir the directory containing the TensorFlow model files to import
- */
- @Override
- public ImportedModel importModel(String modelName, String modelDir) {
- try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
- return importModel(modelName, modelDir, model);
- }
- catch (IllegalArgumentException e) {
- throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
- }
- }
-
- /** Imports a TensorFlow model */
- ImportedModel importModel(String modelName, String modelDir, SavedModelBundle model) {
- try {
- IntermediateGraph graph = GraphImporter.importGraph(modelName, model);
- return convertIntermediateGraphToModel(graph, modelDir);
- }
- catch (IOException e) {
- throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e);
- }
- }
-
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java
deleted file mode 100644
index 25bac27f315..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/VariableConverter.java
+++ /dev/null
@@ -1,61 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
-
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.GraphImporter;
-import com.yahoo.tensor.serialization.JsonFormat;
-import com.yahoo.yolean.Exceptions;
-import org.tensorflow.SavedModelBundle;
-
-import java.nio.charset.StandardCharsets;
-
-/**
- * Converts TensorFlow Variables to the Vespa document format.
- * Intended to be used from the command line to convert trained tensors to document form.
- *
- * @author bratseth
- */
-public class VariableConverter {
-
- /**
- * Reads the tensor with the given TensorFlow name at the given model location,
- * and encodes it as UTF-8 Vespa document tensor JSON having the given ordered tensor type.
- * Note that order of dimensions in the tensor type does matter as the TensorFlow tensor
- * tensor dimensions are implicitly ordered.
- */
- public static byte[] importVariable(String modelDir, String tensorFlowVariableName, String orderedTypeSpec) {
- try (SavedModelBundle bundle = SavedModelBundle.load(modelDir, "serve")) {
- return JsonFormat.encode(TensorConverter.toVespaTensor(GraphImporter.readVariable(tensorFlowVariableName,
- bundle),
- OrderedTensorType.fromSpec(orderedTypeSpec)));
- }
- catch (IllegalArgumentException e) {
- throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
- }
- }
-
- public static void main(String[] args) {
- if ( args.length != 3) {
- System.out.println("Converts a TensorFlow variable into Vespa tensor document field value JSON:");
- System.out.println("A JSON map containing a 'cells' array, see");
- System.out.println("http://docs.vespa.ai/documentation/reference/document-json-put-format.html#tensor)");
- System.out.println("");
- System.out.println("Arguments: modelDirectory tensorFlowVariableName orderedTypeSpec");
- System.out.println(" - modelDirectory: The directory of the TensorFlow SavedModel");
- System.out.println(" - tensorFlowVariableName: The name of the TensorFlow variable to convert");
- System.out.println(" - orderedTypeSpec: The tensor type, e.g tensor(b[],a[10]), where dimensions are ");
- System.out.println(" ordered as given in the deployment log message starting by ");
- System.out.println(" 'Importing TensorFlow variable'");
- return;
- }
-
- try {
- System.out.println(new String(importVariable(args[0], args[1], args[2]), StandardCharsets.UTF_8));
- }
- catch (Exception e) {
- System.err.println("Import failed: " + Exceptions.toMessageString(e));
- }
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java
deleted file mode 100644
index 725f319a839..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java
+++ /dev/null
@@ -1,42 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
-
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.xgboost.XGBoostParser;
-import com.yahoo.searchlib.rankingexpression.parser.ParseException;
-
-import java.io.File;
-import java.io.IOException;
-
-/**
- * Converts a saved XGBoost model into a ranking expression.
- *
- * @author grace-lam
- * @author bratseth
- */
-public class XGBoostImporter extends ModelImporter {
-
- @Override
- public boolean canImport(String modelPath) {
- File modelFile = new File(modelPath);
- if ( ! modelFile.isFile()) return false;
-
- return modelFile.toString().endsWith(".json"); // No other models ends by json yet
- }
-
- @Override
- public ImportedModel importModel(String modelName, String modelPath) {
- try {
- ImportedModel model = new ImportedModel(modelName, modelPath);
- XGBoostParser parser = new XGBoostParser(modelPath);
- RankingExpression expression = new RankingExpression(parser.toRankingExpression());
- model.expression(modelName, expression);
- return model;
- } catch (IOException e) {
- throw new IllegalArgumentException("Could not import XGBoost model from '" + modelPath + "'", e);
- } catch (ParseException e) {
- throw new IllegalArgumentException("Could not parse ranking expression resulting from '" + modelPath + "'", e);
- }
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ImportedModel.java
index 59ec66b7209..854a5202916 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ImportedModel.java
@@ -1,7 +1,6 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
-import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.yahoo.collections.Pair;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
@@ -10,13 +9,11 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.ArrayList;
-import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
-import java.util.Objects;
import java.util.Optional;
import java.util.regex.Pattern;
@@ -64,8 +61,7 @@ public class ImportedModel {
/**
* Returns an immutable map of the small constants of this.
- * These should have sizes up to a few kb at most, and correspond to constant
- * values given in the TensorFlow or ONNX source.
+ * These should have sizes up to a few kb at most, and correspond to constant values given in the source model.
*/
public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); }
@@ -93,18 +89,18 @@ public class ImportedModel {
public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); }
/** Returns the given signature. If it does not already exist it is added to this. */
- Signature signature(String name) {
+ public Signature signature(String name) {
return signatures.computeIfAbsent(name, Signature::new);
}
/** Convenience method for returning a default signature */
- Signature defaultSignature() { return signature(defaultSignatureName); }
+ public Signature defaultSignature() { return signature(defaultSignatureName); }
- void input(String name, TensorType argumentType) { inputs.put(name, argumentType); }
- void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
- void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
- void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
- void function(String name, RankingExpression expression) { functions.put(name, expression); }
+ public void input(String name, TensorType argumentType) { inputs.put(name, argumentType); }
+ public void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
+ public void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
+ public void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
+ public void function(String name, RankingExpression expression) { functions.put(name, expression); }
/**
* Returns all the output expressions of this indexed by name. The names consist of one or two parts
@@ -116,9 +112,9 @@ public class ImportedModel {
List<Pair<String, ExpressionFunction>> expressions = new ArrayList<>();
for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) {
for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet())
- expressions.add(new Pair<>(signatureEntry.getKey() + "." + outputEntry.getKey(),
+ expressions.add(new Pair<>(signatureEntry.getKey() + "" + outputEntry.getKey(),
signatureEntry.getValue().outputExpression(outputEntry.getKey())
- .withName(signatureEntry.getKey() + "." + outputEntry.getKey())));
+ .withName(signatureEntry.getKey() + "" + outputEntry.getKey())));
if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs
expressions.add(new Pair<>(signatureEntry.getKey(),
new ExpressionFunction(signatureEntry.getKey(),
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ImportedModels.java
index 40d1ca8030a..55f1eef741c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ImportedModels.java
@@ -1,7 +1,6 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
-import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.yahoo.path.Path;
@@ -24,27 +23,26 @@ public class ImportedModels {
/** All imported models, indexed by their names */
private final ImmutableMap<String, ImportedModel> importedModels;
- private static final ImmutableList<ModelImporter> importers =
- ImmutableList.of(new TensorFlowImporter(), new OnnxImporter(), new XGBoostImporter());
-
/** Create a null imported models */
public ImportedModels() {
importedModels = ImmutableMap.of();
}
- public ImportedModels(File modelsDirectory) {
+ public ImportedModels(File modelsDirectory, Collection<ModelImporter> importers) {
Map<String, ImportedModel> models = new HashMap<>();
// Find all subdirectories recursively which contains a model we can read
- importRecursively(modelsDirectory, models);
+ importRecursively(modelsDirectory, models, importers);
importedModels = ImmutableMap.copyOf(models);
}
- private static void importRecursively(File dir, Map<String, ImportedModel> models) {
+ private static void importRecursively(File dir,
+ Map<String, ImportedModel> models,
+ Collection<ModelImporter> importers) {
if ( ! dir.isDirectory()) return;
Arrays.stream(dir.listFiles()).sorted().forEach(child -> {
- Optional<ModelImporter> importer = findImporterOf(child);
+ Optional<ModelImporter> importer = findImporterOf(child, importers);
if (importer.isPresent()) {
String name = toName(child);
ImportedModel existing = models.get(name);
@@ -54,12 +52,12 @@ public class ImportedModels {
models.put(name, importer.get().importModel(name, child));
}
else {
- importRecursively(child, models);
+ importRecursively(child, models, importers);
}
});
}
- private static Optional<ModelImporter> findImporterOf(File path) {
+ private static Optional<ModelImporter> findImporterOf(File path, Collection<ModelImporter> importers) {
return importers.stream().filter(item -> item.canImport(path.toString())).findFirst();
}
@@ -93,7 +91,7 @@ public class ImportedModels {
}
private static Path stripFileEnding(Path path) {
- int dotIndex = path.last().lastIndexOf(".");
+ int dotIndex = path.last().lastIndexOf("");
if (dotIndex <= 0) return path;
return path.withLast(path.last().substring(0, dotIndex));
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ModelImporter.java
index 481b7f9397a..1b6494e8ce8 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/ModelImporter.java
@@ -1,11 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
@@ -47,7 +45,7 @@ public abstract class ModelImporter {
* Takes an IntermediateGraph and converts it to a ImportedModel containing
* the actual Vespa ranking expressions.
*/
- static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) {
+ public static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) {
ImportedModel model = new ImportedModel(graph.name(), modelSource);
graph.optimize();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
deleted file mode 100644
index 3fe92440cae..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
+++ /dev/null
@@ -1,216 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx;
-
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Argument;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ConcatV2;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Identity;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Join;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Map;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.MatMul;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.NoOp;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Reshape;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Shape;
-import com.yahoo.tensor.functions.ScalarFunctions;
-import onnx.Onnx;
-
-import java.util.List;
-import java.util.stream.Collectors;
-
-/**
- * Converts an ONNX graph to a Vespa IntermediateGraph which is the basis
- * for generating Vespa ranking expressions.
- *
- * @author lesters
- */
-public class GraphImporter {
-
- public static IntermediateOperation mapOperation(Onnx.NodeProto node,
- List<IntermediateOperation> inputs,
- IntermediateGraph graph) {
- String nodeName = node.getName();
- String modelName = graph.name();
-
- switch (node.getOpType().toLowerCase()) {
- case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs());
- case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
- case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
- case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin());
- case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan());
- case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil());
- case "concat": return new ConcatV2(modelName, nodeName, inputs);
- case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos());
- case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
- case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
- case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal());
- case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp());
- case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor());
- case "greater": return new Join(modelName, nodeName, inputs, ScalarFunctions.greater());
- case "identity": return new Identity(modelName, nodeName, inputs);
- case "less": return new Join(modelName, nodeName, inputs, ScalarFunctions.less());
- case "log": return new Map(modelName, nodeName, inputs, ScalarFunctions.log());
- case "matmul": return new MatMul(modelName, nodeName, inputs);
- case "max": return new Join(modelName, nodeName, inputs, ScalarFunctions.max());
- case "min": return new Join(modelName, nodeName, inputs, ScalarFunctions.min());
- case "mean": return new Join(modelName, nodeName, inputs, ScalarFunctions.mean());
- case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
- case "neg": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg());
- case "pow": return new Join(modelName, nodeName, inputs, ScalarFunctions.pow());
- case "reshape": return new Reshape(modelName, nodeName, inputs);
- case "reciprocal": return new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal());
- case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
- case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
- case "shape": return new Shape(modelName, nodeName, inputs);
- case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin());
- case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt());
- case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid());
- case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
- case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan());
- case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh());
- }
-
- IntermediateOperation op = new NoOp(modelName, node.getName(), inputs);
- op.warning("Operation '" + node.getOpType() + "' is currently not implemented");
- return op;
- }
-
- public static IntermediateGraph importGraph(String modelName, Onnx.ModelProto model) {
- Onnx.GraphProto onnxGraph = model.getGraph();
-
- IntermediateGraph intermediateGraph = new IntermediateGraph(modelName);
- importOperations(onnxGraph, intermediateGraph);
- verifyOutputTypes(onnxGraph, intermediateGraph);
-
- return intermediateGraph;
- }
-
- private static void importOperations(Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) {
- for (Onnx.ValueInfoProto valueInfo : onnxGraph.getOutputList()) {
- importOperation(valueInfo.getName(), onnxGraph, intermediateGraph);
- }
- }
-
- private static IntermediateOperation importOperation(String name,
- Onnx.GraphProto onnxGraph,
- IntermediateGraph intermediateGraph) {
- if (intermediateGraph.alreadyImported(name)) {
- return intermediateGraph.get(name);
- }
- IntermediateOperation operation;
- if (isArgumentTensor(name, onnxGraph)) {
- Onnx.ValueInfoProto valueInfoProto = getArgumentTensor(name, onnxGraph);
- if (valueInfoProto == null)
- throw new IllegalArgumentException("Could not find argument tensor: " + name);
- OrderedTensorType type = TypeConverter.fromOnnxType(valueInfoProto.getType());
- operation = new Argument(intermediateGraph.name(), valueInfoProto.getName(), type);
-
- intermediateGraph.inputs(intermediateGraph.defaultSignature())
- .put(IntermediateOperation.namePartOf(name), operation.vespaName());
-
- } else if (isConstantTensor(name, onnxGraph)) {
- Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph);
- OrderedTensorType defaultType = OrderedTensorType.fromDimensionList(tensorProto.getDimsList());
- operation = new Constant(intermediateGraph.name(), name, defaultType);
- operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type)));
-
- } else {
- Onnx.NodeProto node = getNodeFromGraph(name, onnxGraph);
- List<IntermediateOperation> inputs = importOperationInputs(node, onnxGraph, intermediateGraph);
- operation = mapOperation(node, inputs, intermediateGraph);
-
- if (isOutputNode(name, onnxGraph)) {
- intermediateGraph.outputs(intermediateGraph.defaultSignature())
- .put(IntermediateOperation.namePartOf(name), operation.vespaName());
- }
- }
- intermediateGraph.put(operation.vespaName(), operation);
-
- return operation;
- }
-
- private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) {
- Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
- Onnx.TensorProto tensor = getConstantTensor(name, graph);
- return value != null && tensor == null;
- }
-
- private static boolean isConstantTensor(String name, Onnx.GraphProto graph) {
- Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
- Onnx.TensorProto tensor = getConstantTensor(name, graph);
- return value != null && tensor != null;
- }
-
- private static Onnx.ValueInfoProto getArgumentTensor(String name, Onnx.GraphProto graph) {
- for (Onnx.ValueInfoProto valueInfo : graph.getInputList()) {
- if (valueInfo.getName().equals(name)) {
- return valueInfo;
- }
- }
- return null;
- }
-
- private static Onnx.TensorProto getConstantTensor(String name, Onnx.GraphProto graph) {
- for (Onnx.TensorProto tensorProto : graph.getInitializerList()) {
- if (tensorProto.getName().equals(name)) {
- return tensorProto;
- }
- }
- return null;
- }
-
- private static boolean isOutputNode(String name, Onnx.GraphProto graph) {
- return getOutputNode(name, graph) != null;
- }
-
- private static Onnx.ValueInfoProto getOutputNode(String name, Onnx.GraphProto graph) {
- for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) {
- if (valueInfo.getName().equals(name)) {
- return valueInfo;
- }
- String nodeName = IntermediateOperation.namePartOf(valueInfo.getName());
- if (nodeName.equals(name)) {
- return valueInfo;
- }
- }
- return null;
- }
-
- private static List<IntermediateOperation> importOperationInputs(Onnx.NodeProto node,
- Onnx.GraphProto onnxGraph,
- IntermediateGraph intermediateGraph) {
- return node.getInputList().stream()
- .map(nodeName -> importOperation(nodeName, onnxGraph, intermediateGraph))
- .collect(Collectors.toList());
- }
-
- private static void verifyOutputTypes(Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) {
- for (String outputName : intermediateGraph.outputs(intermediateGraph.defaultSignature()).values()) {
- IntermediateOperation operation = intermediateGraph.get(outputName);
- Onnx.ValueInfoProto onnxNode = getOutputNode(outputName, onnxGraph);
- OrderedTensorType type = operation.type().orElseThrow(
- () -> new IllegalArgumentException("Output of '" + outputName + "' has no type."));
- TypeConverter.verifyType(onnxNode.getType(), type);
- }
- }
-
- private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) {
- boolean hasPortNumber = nodeName.contains(":");
- for (Onnx.NodeProto node : graph.getNodeList()) {
- if (hasPortNumber) {
- for (String outputName : node.getOutputList()) {
- if (outputName.equals(nodeName)) {
- return node;
- }
- }
- } else if (node.getName().equals(nodeName)) {
- return node;
- }
- }
- throw new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph");
- }
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java
deleted file mode 100644
index 18856d4a25f..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TensorConverter.java
+++ /dev/null
@@ -1,79 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx;
-
-import com.google.protobuf.ByteString;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
-import com.yahoo.tensor.IndexedTensor;
-import com.yahoo.tensor.Tensor;
-import onnx.Onnx;
-
-import java.nio.ByteBuffer;
-import java.nio.ByteOrder;
-import java.nio.FloatBuffer;
-
-/**
- * Converts Onnx tensors into Vespa tensors.
- *
- * @author lesters
- */
-public class TensorConverter {
-
- public static Tensor toVespaTensor(Onnx.TensorProto tensorProto, OrderedTensorType type) {
- Values values = readValuesOf(tensorProto);
- IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type.type());
- for (int i = 0; i < values.size(); i++) {
- builder.cellByDirectIndex(type.toDirectIndex(i), values.get(i));
- }
- return builder.build();
- }
-
- private static Values readValuesOf(Onnx.TensorProto tensorProto) {
- if (tensorProto.hasRawData()) {
- switch (tensorProto.getDataType()) {
- case FLOAT: return new RawFloatValues(tensorProto);
- }
- } else {
- switch (tensorProto.getDataType()) {
- case FLOAT: return new FloatValues(tensorProto);
- }
- }
- throw new IllegalArgumentException("Cannot convert a tensor with elements of type " +
- tensorProto.getDataType() + " to a Vespa tensor");
- }
-
- /** Allows reading values from buffers of various numeric types as bytes */
- private static abstract class Values {
- abstract double get(int i);
- abstract int size();
- }
-
- private static abstract class RawValues extends Values {
- ByteBuffer bytes(Onnx.TensorProto tensorProto) {
- ByteString byteString = tensorProto.getRawData();
- return byteString.asReadOnlyByteBuffer().order(ByteOrder.LITTLE_ENDIAN);
- }
- }
-
- private static class RawFloatValues extends RawValues {
- private final FloatBuffer values;
- private final int size;
- RawFloatValues(Onnx.TensorProto tensorProto) {
- values = bytes(tensorProto).asFloatBuffer();
- size = values.remaining();
- }
- @Override double get(int i) { return values.get(i); }
- @Override int size() { return size; }
- }
-
- private static class FloatValues extends Values {
- private final Onnx.TensorProto tensorProto;
- FloatValues(Onnx.TensorProto tensorProto) {
- this.tensorProto = tensorProto;
- }
- @Override double get(int i) { return tensorProto.getFloatData(i); }
- @Override int size() { return tensorProto.getFloatDataCount(); }
- }
-
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java
deleted file mode 100644
index 715c55d8323..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java
+++ /dev/null
@@ -1,52 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx;
-
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
-import com.yahoo.tensor.TensorType;
-import onnx.Onnx;
-
-/**
- * Converts and verifies ONNX tensor types into Vespa tensor types.
- *
- * @author lesters
- */
-public class TypeConverter {
-
- public static void verifyType(Onnx.TypeProto typeProto, OrderedTensorType type) {
- Onnx.TensorShapeProto shape = typeProto.getTensorType().getShape();
- if (shape != null) {
- if (shape.getDimCount() != type.rank()) {
- throw new IllegalArgumentException("Onnx shape of does not match Vespa shape");
- }
- for (int onnxIndex = 0; onnxIndex < type.dimensions().size(); ++onnxIndex) {
- int vespaIndex = type.dimensionMap(onnxIndex);
- Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(onnxIndex);
- TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex);
- if (onnxDimension.getDimValue() != vespaDimension.size().orElse(-1L)) {
- throw new IllegalArgumentException("Onnx dimensions of does not match Vespa dimensions");
- }
- }
- }
- }
-
- public static OrderedTensorType fromOnnxType(Onnx.TypeProto type) {
- return fromOnnxType(type, "d"); // standard naming convention: d0, d1, ...
- }
-
- public static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) {
- Onnx.TensorShapeProto shape = type.getTensorType().getShape();
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
- for (int i = 0; i < shape.getDimCount(); ++ i) {
- String dimensionName = dimensionPrefix + i;
- Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
- if (onnxDimension.getDimValue() >= 0) {
- builder.add(TensorType.Dimension.indexed(dimensionName, onnxDimension.getDimValue()));
- } else {
- builder.add(TensorType.Dimension.indexed(dimensionName));
- }
- }
- return builder.build();
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java
index 7fc2aae87d1..ab6a80a193a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java
@@ -1,8 +1,8 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.Rename;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java
index 1b8c62fe0e9..1f479cd2e0b 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java
@@ -1,8 +1,8 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
index 0eff8e8bc08..aab50f422be 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
@@ -180,7 +180,7 @@ public abstract class IntermediateOperation {
/**
* An interface mapping operation attributes to Vespa Values.
- * Adapter for differences in ONNX/TensorFlow.
+ * Adapter for differences in different model types.
*/
public interface AttributeMap {
Optional<Value> get(String key);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java
index 8413ed74118..2d401469e40 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java
@@ -1,8 +1,8 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.TensorFunction;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java
index 95a77c07590..6ce9abf2ec9 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java
@@ -3,8 +3,8 @@ package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java
index e91c2305f7d..ff87412396d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java
@@ -2,8 +2,8 @@
package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
@@ -24,8 +24,6 @@ import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
-import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.tensorSize;
-
public class Reshape extends IntermediateOperation {
public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) {
@@ -52,7 +50,7 @@ public class Reshape extends IntermediateOperation {
int size = cell.getValue().intValue();
if (size < 0) {
size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() /
- tensorSize(inputType.type()).intValue();
+ OrderedTensorType.tensorSize(inputType.type()).intValue();
}
outputTypeBuilder.add(TensorType.Dimension.indexed(
String.format("%s_%d", vespaName(), dimensionIndex), size));
@@ -82,7 +80,7 @@ public class Reshape extends IntermediateOperation {
}
public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
- if (!tensorSize(inputType).equals(tensorSize(outputType))) {
+ if (!OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) {
throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/package-info.java
index 1530754cc43..bb55ed768a6 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/package-info.java
@@ -1,8 +1,8 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
/**
- * ONNX integration
+ * Model integration
*/
@ExportPackage
-package com.yahoo.searchlib.rankingexpression.integration.ml;
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer;
import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java
deleted file mode 100644
index 89b75e8e3e2..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java
+++ /dev/null
@@ -1,86 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
-
-import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.NodeDef;
-
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-import java.util.stream.Collectors;
-
-/**
- * Converts TensorFlow node attributes to Vespa attribute values.
- *
- * @author lesters
- */
-public class AttributeConverter implements IntermediateOperation.AttributeMap {
-
- private final Map<String, AttrValue> attributeMap;
-
- public AttributeConverter(NodeDef node) {
- attributeMap = node.getAttrMap();
- }
-
- public static AttributeConverter convert(NodeDef node) {
- return new AttributeConverter(node);
- }
-
- @Override
- public Optional<Value> get(String key) {
- if (attributeMap.containsKey(key)) {
- AttrValue attrValue = attributeMap.get(key);
- if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) {
- return Optional.empty(); // requires type
- }
- if (attrValue.getValueCase() == AttrValue.ValueCase.B) {
- return Optional.of(new BooleanValue(attrValue.getB()));
- }
- if (attrValue.getValueCase() == AttrValue.ValueCase.I) {
- return Optional.of(new DoubleValue(attrValue.getI()));
- }
- if (attrValue.getValueCase() == AttrValue.ValueCase.F) {
- return Optional.of(new DoubleValue(attrValue.getF()));
- }
- }
- return Optional.empty();
- }
-
- @Override
- public Optional<Value> get(String key, OrderedTensorType type) {
- if (attributeMap.containsKey(key)) {
- AttrValue attrValue = attributeMap.get(key);
- if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) {
- return Optional.of(new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), type.type())));
- }
- }
- return get(key);
- }
-
- @Override
- public Optional<List<Value>> getList(String key) {
- if (attributeMap.containsKey(key)) {
- AttrValue attrValue = attributeMap.get(key);
- if (attrValue.getValueCase() == AttrValue.ValueCase.LIST) {
- AttrValue.ListValue listValue = attrValue.getList();
- if ( ! listValue.getBList().isEmpty()) {
- return Optional.of(listValue.getBList().stream().map(BooleanValue::new).collect(Collectors.toList()));
- }
- if ( ! listValue.getIList().isEmpty()) {
- return Optional.of(listValue.getIList().stream().map(DoubleValue::new).collect(Collectors.toList()));
- }
- if ( ! listValue.getFList().isEmpty()) {
- return Optional.of(listValue.getFList().stream().map(DoubleValue::new).collect(Collectors.toList()));
- }
- // add the rest
- }
- }
- return Optional.empty();
- }
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java
deleted file mode 100644
index e1b292f9e61..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java
+++ /dev/null
@@ -1,234 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
-
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Argument;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ConcatV2;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Const;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ExpandDims;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Identity;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Join;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Map;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.MatMul;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Mean;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Merge;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.NoOp;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.PlaceholderWithDefault;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Reshape;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Select;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Shape;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Squeeze;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Switch;
-import com.yahoo.tensor.functions.ScalarFunctions;
-import org.tensorflow.SavedModelBundle;
-import org.tensorflow.Session;
-import org.tensorflow.framework.GraphDef;
-import org.tensorflow.framework.MetaGraphDef;
-import org.tensorflow.framework.NodeDef;
-import org.tensorflow.framework.SignatureDef;
-import org.tensorflow.framework.TensorInfo;
-
-import java.io.IOException;
-import java.util.List;
-import java.util.stream.Collectors;
-
-/**
- * Converts a TensorFlow graph to a Vespa IntermediateGraph which is the basis
- * for generating Vespa ranking expressions.
- *
- * @author lesters
- */
-public class GraphImporter {
-
- public static IntermediateOperation mapOperation(NodeDef node,
- List<IntermediateOperation> inputs,
- IntermediateGraph graph) {
- String nodeName = node.getName();
- String modelName = graph.name();
- int nodePort = IntermediateOperation.indexPartOf(nodeName);
- OrderedTensorType nodeType = TypeConverter.fromTensorFlowType(node);
- AttributeConverter attributes = AttributeConverter.convert(node);
-
- switch (node.getOp().toLowerCase()) {
- // array ops
- case "concatv2": return new ConcatV2(modelName, nodeName, inputs);
- case "const": return new Const(modelName, nodeName, inputs, attributes, nodeType);
- case "expanddims": return new ExpandDims(modelName, nodeName, inputs);
- case "identity": return new Identity(modelName, nodeName, inputs);
- case "placeholder": return new Argument(modelName, nodeName, nodeType);
- case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, nodeName, inputs);
- case "reshape": return new Reshape(modelName, nodeName, inputs);
- case "shape": return new Shape(modelName, nodeName, inputs);
- case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes);
-
- // control flow
- case "merge": return new Merge(modelName, nodeName, inputs);
- case "switch": return new Switch(modelName, nodeName, inputs, nodePort);
-
- // math ops
- case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
- case "add_n": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
- case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
- case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
- case "realdiv": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
- case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor());
- case "matmul": return new MatMul(modelName, nodeName, inputs);
- case "maximum": return new Join(modelName, nodeName, inputs, ScalarFunctions.max());
- case "mean": return new Mean(modelName, nodeName, inputs, attributes);
- case "reducemean": return new Mean(modelName, nodeName, inputs, attributes);
- case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
- case "multiply": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
- case "rsqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.rsqrt());
- case "select": return new Select(modelName, nodeName, inputs);
- case "where3": return new Select(modelName, nodeName, inputs);
- case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid());
- case "squareddifference": return new Join(modelName, nodeName, inputs, ScalarFunctions.squareddifference());
- case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
- case "subtract": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
-
- // nn ops
- case "biasadd": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
- case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
- case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
- case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
-
- // state ops
- case "variable": return new Constant(modelName, nodeName, nodeType);
- case "variablev2": return new Constant(modelName, nodeName, nodeType);
-
- // evaluation no-ops
- case "stopgradient":return new Identity(modelName, nodeName, inputs);
- case "noop": return new NoOp(modelName, nodeName, inputs);
-
- }
-
- IntermediateOperation op = new NoOp(modelName, node.getName(), inputs);
- op.warning("Operation '" + node.getOp() + "' is currently not implemented");
- return op;
- }
-
- public static IntermediateGraph importGraph(String modelName, SavedModelBundle bundle) throws IOException {
- MetaGraphDef tfGraph = MetaGraphDef.parseFrom(bundle.metaGraphDef());
-
- IntermediateGraph intermediateGraph = new IntermediateGraph(modelName);
- importSignatures(tfGraph, intermediateGraph);
- importOperations(tfGraph, intermediateGraph, bundle);
- verifyOutputTypes(tfGraph, intermediateGraph);
-
- return intermediateGraph;
- }
-
- private static void importSignatures(MetaGraphDef tfGraph, IntermediateGraph intermediateGraph) {
- for (java.util.Map.Entry<String, SignatureDef> signatureEntry : tfGraph.getSignatureDefMap().entrySet()) {
- String signatureName = signatureEntry.getKey();
- java.util.Map<String, TensorInfo> inputInfoMap = signatureEntry.getValue().getInputsMap();
- for (java.util.Map.Entry<String, TensorInfo> input : inputInfoMap.entrySet()) {
- String inputName = input.getKey();
- String nodeName = input.getValue().getName();
- intermediateGraph.inputs(signatureName).put(inputName, IntermediateOperation.namePartOf(nodeName));
- }
- java.util.Map<String, TensorInfo> outputInfoMap = signatureEntry.getValue().getOutputsMap();
- for (java.util.Map.Entry<String, TensorInfo> output : outputInfoMap.entrySet()) {
- String outputName = output.getKey();
- String nodeName = output.getValue().getName();
- intermediateGraph.outputs(signatureName).put(outputName, IntermediateOperation.namePartOf(nodeName));
- }
- }
- }
-
- private static void importOperations(MetaGraphDef tfGraph,
- IntermediateGraph intermediateGraph,
- SavedModelBundle bundle) {
- for (String signatureName : intermediateGraph.signatures()) {
- for (String outputName : intermediateGraph.outputs(signatureName).values()) {
- importOperation(outputName, tfGraph.getGraphDef(), intermediateGraph, bundle);
- }
- }
- }
-
- private static IntermediateOperation importOperation(String nodeName,
- GraphDef tfGraph,
- IntermediateGraph intermediateGraph,
- SavedModelBundle bundle) {
- if (intermediateGraph.alreadyImported(nodeName)) {
- return intermediateGraph.get(nodeName);
- }
- NodeDef node = getTensorFlowNodeFromGraph(IntermediateOperation.namePartOf(nodeName), tfGraph);
- List<IntermediateOperation> inputs = importOperationInputs(node, tfGraph, intermediateGraph, bundle);
- IntermediateOperation operation = mapOperation(node, inputs, intermediateGraph);
- intermediateGraph.put(nodeName, operation);
-
- List<IntermediateOperation> controlInputs = importControlInputs(node, tfGraph, intermediateGraph, bundle);
- if (controlInputs.size() > 0) {
- operation.setControlInputs(controlInputs);
- }
-
- if (operation.isConstant()) {
- operation.setConstantValueFunction(
- type -> new TensorValue(TensorConverter.toVespaTensor(readVariable(nodeName, bundle), type)));
- }
-
- return operation;
- }
-
- private static List<IntermediateOperation> importOperationInputs(NodeDef node,
- GraphDef tfGraph,
- IntermediateGraph intermediateGraph,
- SavedModelBundle bundle) {
- return node.getInputList().stream()
- .filter(name -> ! isControlDependency(name))
- .map(nodeName -> importOperation(nodeName, tfGraph, intermediateGraph, bundle))
- .collect(Collectors.toList());
- }
-
- private static List<IntermediateOperation> importControlInputs(NodeDef node,
- GraphDef tfGraph,
- IntermediateGraph intermediateGraph,
- SavedModelBundle bundle) {
- return node.getInputList().stream()
- .filter(nodeName -> isControlDependency(nodeName))
- .map(nodeName -> importOperation(nodeName, tfGraph, intermediateGraph, bundle))
- .collect(Collectors.toList());
- }
-
- private static boolean isControlDependency(String name) {
- return name.startsWith("^");
- }
-
- private static NodeDef getTensorFlowNodeFromGraph(String name, GraphDef tfGraph) {
- for (NodeDef node : tfGraph.getNodeList()) {
- if (node.getName().equals(name)) {
- return node;
- }
- }
- throw new IllegalArgumentException("Could not find node '" + name + "'");
- }
-
- public static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) {
- Session.Runner fetched = bundle.session().runner().fetch(name);
- List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
- if (importedTensors.size() != 1)
- throw new IllegalStateException("Expected 1 tensor from fetching " + name +
- ", but got " + importedTensors.size());
- return importedTensors.get(0);
- }
-
- private static void verifyOutputTypes(MetaGraphDef tfGraph, IntermediateGraph intermediateGraph) {
- for (String signatureName : intermediateGraph.signatures()) {
- for (String outputName : intermediateGraph.outputs(signatureName).values()) {
- IntermediateOperation operation = intermediateGraph.get(outputName);
- NodeDef node = getTensorFlowNodeFromGraph(IntermediateOperation.namePartOf(operation.name()), tfGraph.getGraphDef());
- OrderedTensorType type = operation.type().orElseThrow(
- () -> new IllegalArgumentException("Output of '" + outputName + "' has no type."));
- TypeConverter.verifyType(node, type);
- }
- }
-
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java
deleted file mode 100644
index d2d0acfc964..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TensorConverter.java
+++ /dev/null
@@ -1,225 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
-
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
-import com.yahoo.tensor.IndexedTensor;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import org.tensorflow.framework.TensorProto;
-
-import java.nio.ByteBuffer;
-import java.nio.DoubleBuffer;
-import java.nio.FloatBuffer;
-import java.nio.IntBuffer;
-import java.nio.LongBuffer;
-
-
-/**
- * Converts TensorFlow tensors into Vespa tensors.
- *
- * @author bratseth
- * @author lesters
- */
-public class TensorConverter {
-
- public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) {
- return toVespaTensor(tfTensor, "d");
- }
-
- public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) {
- TensorType type = toVespaTensorType(tfTensor.shape(), dimensionPrefix);
- Values values = readValuesOf(tfTensor);
- IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
- for (int i = 0; i < values.size(); i++)
- builder.cellByDirectIndex(i, values.get(i));
- return builder.build();
- }
-
- public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, OrderedTensorType type) {
- Values values = readValuesOf(tfTensor);
- IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type.type());
- for (int i = 0; i < values.size(); i++) {
- builder.cellByDirectIndex(type.toDirectIndex(i), values.get(i));
- }
- return builder.build();
- }
-
- public static Tensor toVespaTensor(TensorProto tensorProto, TensorType type) {
- IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
- Values values = readValuesOf(tensorProto);
- for (int i = 0; i < values.size(); ++i) {
- builder.cellByDirectIndex(i, values.get(i));
- }
- return builder.build();
- }
-
- private static TensorType toVespaTensorType(long[] shape, String dimensionPrefix) {
- TensorType.Builder b = new TensorType.Builder();
- int dimensionIndex = 0;
- for (long dimensionSize : shape) {
- if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ...
- b.indexed(dimensionPrefix + (dimensionIndex++), dimensionSize);
- }
- return b.build();
- }
-
- public static Long tensorSize(TensorType type) {
- Long size = 1L;
- for (TensorType.Dimension dimension : type.dimensions()) {
- size *= dimensionSize(dimension);
- }
- return size;
- }
-
- public static Long dimensionSize(TensorType.Dimension dim) {
- return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size"));
- }
-
- private static Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) {
- switch (tfTensor.dataType()) {
- case DOUBLE: return new DoubleValues(tfTensor);
- case FLOAT: return new FloatValues(tfTensor);
- case BOOL: return new BoolValues(tfTensor);
- case UINT8: return new IntValues(tfTensor);
- case INT32: return new IntValues(tfTensor);
- case INT64: return new LongValues(tfTensor);
- }
- throw new IllegalArgumentException("Cannot convert a tensor with elements of type " +
- tfTensor.dataType() + " to a Vespa tensor");
- }
-
- private static Values readValuesOf(TensorProto tensorProto) {
- switch (tensorProto.getDtype()) {
- case DT_BOOL:
- return new ProtoBoolValues(tensorProto);
- case DT_HALF:
- return new ProtoHalfValues(tensorProto);
- case DT_INT16:
- case DT_INT32:
- return new ProtoIntValues(tensorProto);
- case DT_INT64:
- return new ProtoInt64Values(tensorProto);
- case DT_FLOAT:
- return new ProtoFloatValues(tensorProto);
- case DT_DOUBLE:
- return new ProtoDoubleValues(tensorProto);
- }
- throw new IllegalArgumentException("Unsupported data type in attribute tensor import");
- }
-
- /** Allows reading values from buffers of various numeric types as bytes */
- private static abstract class Values {
- abstract double get(int i);
- abstract int size();
- }
-
- private static abstract class TensorFlowValues extends Values {
- private final int size;
- TensorFlowValues(int size) {
- this.size = size;
- }
- @Override int size() { return this.size; }
- }
-
- private static class DoubleValues extends TensorFlowValues {
- private final DoubleBuffer values;
- DoubleValues(org.tensorflow.Tensor<?> tfTensor) {
- super(tfTensor.numElements());
- values = DoubleBuffer.allocate(tfTensor.numElements());
- tfTensor.writeTo(values);
- }
- @Override double get(int i) {
- return values.get(i);
- }
- }
-
- private static class FloatValues extends TensorFlowValues {
- private final FloatBuffer values;
- FloatValues(org.tensorflow.Tensor<?> tfTensor) {
- super(tfTensor.numElements());
- values = FloatBuffer.allocate(tfTensor.numElements());
- tfTensor.writeTo(values);
- }
- @Override double get(int i) {
- return values.get(i);
- }
- }
-
- private static class BoolValues extends TensorFlowValues {
- private final ByteBuffer values;
- BoolValues(org.tensorflow.Tensor<?> tfTensor) {
- super(tfTensor.numElements());
- values = ByteBuffer.allocate(tfTensor.numElements());
- tfTensor.writeTo(values);
- }
- @Override double get(int i) {
- return values.get(i);
- }
- }
-
- private static class IntValues extends TensorFlowValues {
- private final IntBuffer values;
- IntValues(org.tensorflow.Tensor<?> tfTensor) {
- super(tfTensor.numElements());
- values = IntBuffer.allocate(tfTensor.numElements());
- tfTensor.writeTo(values);
- }
- @Override double get(int i) {
- return values.get(i);
- }
- }
-
- private static class LongValues extends TensorFlowValues {
- private final LongBuffer values;
- LongValues(org.tensorflow.Tensor<?> tfTensor) {
- super(tfTensor.numElements());
- values = LongBuffer.allocate(tfTensor.numElements());
- tfTensor.writeTo(values);
- }
- @Override double get(int i) {
- return values.get(i);
- }
- }
-
- private static abstract class ProtoValues extends Values {
- protected final TensorProto tensorProto;
- protected ProtoValues(TensorProto tensorProto) { this.tensorProto = tensorProto; }
- }
-
- private static class ProtoBoolValues extends ProtoValues {
- ProtoBoolValues(TensorProto tensorProto) { super(tensorProto); }
- @Override double get(int i) { return tensorProto.getBoolVal(i) ? 1.0 : 0.0; }
- @Override int size() { return tensorProto.getBoolValCount(); }
- }
-
- private static class ProtoHalfValues extends ProtoValues {
- ProtoHalfValues(TensorProto tensorProto) { super(tensorProto); }
- @Override double get(int i) { return tensorProto.getHalfVal(i); }
- @Override int size() { return tensorProto.getHalfValCount(); }
- }
-
- private static class ProtoIntValues extends ProtoValues {
- ProtoIntValues(TensorProto tensorProto) { super(tensorProto); }
- @Override double get(int i) { return tensorProto.getIntVal(i); }
- @Override int size() { return tensorProto.getIntValCount(); }
- }
-
- private static class ProtoInt64Values extends ProtoValues {
- ProtoInt64Values(TensorProto tensorProto) { super(tensorProto); }
- @Override double get(int i) { return tensorProto.getInt64Val(i); }
- @Override int size() { return tensorProto.getInt64ValCount(); }
- }
-
- private static class ProtoFloatValues extends ProtoValues {
- ProtoFloatValues(TensorProto tensorProto) { super(tensorProto); }
- @Override double get(int i) { return tensorProto.getFloatVal(i); }
- @Override int size() { return tensorProto.getFloatValCount(); }
- }
-
- private static class ProtoDoubleValues extends ProtoValues {
- ProtoDoubleValues(TensorProto tensorProto) { super(tensorProto); }
- @Override double get(int i) { return tensorProto.getDoubleVal(i); }
- @Override int size() { return tensorProto.getDoubleValCount(); }
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java
deleted file mode 100644
index 67ad1edc312..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java
+++ /dev/null
@@ -1,72 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow;
-
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
-import com.yahoo.tensor.TensorType;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.NodeDef;
-import org.tensorflow.framework.TensorShapeProto;
-
-import java.util.List;
-
-/**
- * Converts and verifies TensorFlow tensor types into Vespa tensor types.
- *
- * @author lesters
- */
-public class TypeConverter {
-
- public static void verifyType(NodeDef node, OrderedTensorType type) {
- TensorShapeProto shape = tensorFlowShape(node);
- if (shape != null) {
- if (shape.getDimCount() != type.rank()) {
- throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " +
- "does not match Vespa shape");
- }
- for (int tensorFlowIndex = 0; tensorFlowIndex < type.dimensions().size(); ++tensorFlowIndex) {
- int vespaIndex = type.dimensionMap(tensorFlowIndex);
- TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex);
- TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex);
- if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) {
- throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " +
- "does not match Vespa dimensions");
- }
- }
- }
- }
-
- private static TensorShapeProto tensorFlowShape(NodeDef node) {
- AttrValue attrValueList = node.getAttrMap().get("_output_shapes");
- if (attrValueList == null) {
- throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
- "does not exist");
- }
- if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) {
- throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
- "is not of expected type");
- }
- List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList();
- return shapeList.get(0); // support multiple outputs?
- }
-
- public static OrderedTensorType fromTensorFlowType(NodeDef node) {
- return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ...
- }
-
- public static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
- TensorShapeProto shape = tensorFlowShape(node);
- for (int i = 0; i < shape.getDimCount(); ++ i) {
- String dimensionName = dimensionPrefix + i;
- TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i);
- if (tensorFlowDimension.getSize() >= 0) {
- builder.add(TensorType.Dimension.indexed(dimensionName, tensorFlowDimension.getSize()));
- } else {
- builder.add(TensorType.Dimension.indexed(dimensionName));
- }
- }
- return builder.build();
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostParser.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostParser.java
deleted file mode 100644
index fef8bfec81d..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostParser.java
+++ /dev/null
@@ -1,77 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.xgboost;
-
-import java.io.File;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-
-import com.fasterxml.jackson.core.JsonProcessingException;
-import com.fasterxml.jackson.databind.JsonNode;
-import com.fasterxml.jackson.databind.ObjectMapper;
-
-/**
- * @author grace-lam
- */
-public class XGBoostParser {
-
- private List<XGBoostTree> xgboostTrees;
-
- /**
- * Constructor stores parsed JSON trees.
- *
- * @param filePath XGBoost JSON output file.
- * @throws JsonProcessingException Fails JSON parsing.
- * @throws IOException Fails file reading.
- */
- public XGBoostParser(String filePath) throws JsonProcessingException, IOException {
- this.xgboostTrees = new ArrayList<>();
- ObjectMapper mapper = new ObjectMapper();
- JsonNode forestNode = mapper.readTree(new File(filePath));
- for (JsonNode treeNode : forestNode) {
- this.xgboostTrees.add(mapper.treeToValue(treeNode, XGBoostTree.class));
- }
- }
-
- /**
- * Converts parsed JSON trees to Vespa ranking expressions.
- *
- * @return Vespa ranking expressions.
- */
- public String toRankingExpression() {
- StringBuilder ret = new StringBuilder();
- for (int i = 0; i < xgboostTrees.size(); i++) {
- ret.append(treeToRankExp(xgboostTrees.get(i)));
- if (i != xgboostTrees.size() - 1) {
- ret.append(" + \n");
- }
- }
- return ret.toString();
- }
-
- /**
- * Recursive helper function for toRankingExpression().
- *
- * @param node XGBoost tree node to convert.
- * @return Vespa ranking expression for input node.
- */
- public String treeToRankExp(XGBoostTree node) {
- if (node.isLeaf()) {
- return Double.toString(node.getLeaf());
- } else {
- assert node.getChildren().size() == 2;
- String trueExp;
- String falseExp;
- if (node.getYes() == node.getChildren().get(0).getNodeid()) {
- trueExp = treeToRankExp(node.getChildren().get(0));
- falseExp = treeToRankExp(node.getChildren().get(1));
- } else {
- trueExp = treeToRankExp(node.getChildren().get(1));
- falseExp = treeToRankExp(node.getChildren().get(0));
- }
- return "if (" + node.getSplit() + " < " + Double.toString(node.getSplit_condition()) + ", " + trueExp + ", "
- + falseExp + ")";
- }
- }
-
-} \ No newline at end of file
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostTree.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostTree.java
deleted file mode 100644
index 6bbc9abe8ae..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostTree.java
+++ /dev/null
@@ -1,77 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.importer.xgboost;
-
-import java.util.List;
-
-/**
- * Outlines the JSON representation used for parsing the XGBoost output file.
- *
- * @author grace-lam
- */
-public class XGBoostTree {
-
- // ID of current node.
- private int nodeid;
- // Depth of current node w.r.t. the tree's root.
- private int depth;
- // Feature name used for split.
- private String split;
- // Feature value threshold to split on.
- private double split_condition;
- // Next node if feature value < split_condition.
- private int yes;
- // Next node if feature value >= split_condition.
- private int no;
- // Next node if feature value is missing.
- private int missing;
- // Response value for leaf node.
- private double leaf;
- // List of child nodes.
- private List<XGBoostTree> children;
-
- public int getNodeid() {
- return nodeid;
- }
-
- public int getDepth() {
- return depth;
- }
-
- public String getSplit() {
- return split;
- }
-
- public double getSplit_condition() {
- return split_condition;
- }
-
- public int getYes() {
- return yes;
- }
-
- public int getNo() {
- return no;
- }
-
- public int getMissing() {
- return missing;
- }
-
- public double getLeaf() {
- return leaf;
- }
-
- public List<XGBoostTree> getChildren() {
- return children;
- }
-
- /**
- * Check if current node is a leaf node.
- *
- * @return True if leaf, false otherwise.
- */
- public boolean isLeaf() {
- return children == null;
- }
-
-} \ No newline at end of file