summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-08-22 11:44:52 +0200
committerJon Bratseth <bratseth@oath.com>2018-08-22 11:44:52 +0200
commit7392f9fdbee5f0a52ac9c056376b659b32500c60 (patch)
tree96449c35ccbdf0a4e8f6e93b3e57350be78e5efd
parent4f59bae2e4a90a3064311bda6ef1158f48182250 (diff)
Automatically figure out the right importer
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java14
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java11
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java9
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java15
6 files changed, 40 insertions, 13 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java
index 6f3fb0e1768..7cf0a5d8b76 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java
@@ -1,9 +1,13 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.expressiontransforms;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
import com.yahoo.path.Path;
import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
import com.yahoo.searchlib.rankingexpression.integration.ml.ModelImporter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.OnnxImporter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.TensorFlowImporter;
import java.io.File;
import java.util.HashMap;
@@ -16,13 +20,12 @@ import java.util.Map;
*/
class ImportedModels {
- private final ModelImporter modelImporter;
-
/** The cache of already imported models */
private final Map<String, ImportedModel> importedModels = new HashMap<>();
- ImportedModels(ModelImporter modelImporter) {
- this.modelImporter = modelImporter;
+ private final ImmutableList<ModelImporter> importers = ImmutableList.of(new TensorFlowImporter(), new OnnxImporter());
+
+ ImportedModels() {
}
/**
@@ -34,7 +37,8 @@ class ImportedModels {
*/
public ImportedModel get(File modelPath) {
String modelName = toName(modelPath);
- return importedModels.computeIfAbsent(modelName, __ -> modelImporter.importModel(modelName, modelPath));
+ ModelImporter importer = importers.stream().filter(item -> item.canImport(modelPath.toString())).findFirst().get();
+ return importedModels.computeIfAbsent(modelName, __ -> importer.importModel(modelName, modelPath));
}
private static String toName(File modelPath) {
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
index b6cc5df22f6..0b68a67acff 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
@@ -24,7 +24,7 @@ import java.util.Map;
*/
public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
- private final ImportedModels importedOnnxModels = new ImportedModels(new OnnxImporter());
+ private final ImportedModels importedOnnxModels = new ImportedModels();
/** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
private final Map<Path, ConvertedModel> convertedOnnxModels = new HashMap<>();
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index 774b166c45a..4f15fb5a291 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -22,7 +22,7 @@ import java.util.Map;
*/
public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
- private final ImportedModels importedTensorFlowModels = new ImportedModels(new TensorFlowImporter());
+ private final ImportedModels importedTensorFlowModels = new ImportedModels();
/** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
private final Map<Path, ConvertedModel> convertedTensorFlowModels = new HashMap<>();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
index a658833b426..41817eb3e62 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
@@ -32,13 +32,14 @@ public abstract class ModelImporter {
private static final Logger log = Logger.getLogger(ModelImporter.class.getName());
- /**
- * The main import function.
- */
+ /** Returns whether the file or directory at the given path is of the tyope which can be imported by this */
+ public abstract boolean canImport(String modelPath);
+
+ /** Imports the given model */
public abstract ImportedModel importModel(String modelName, String modelPath);
- public ImportedModel importModel(String modelName, File modelDir) {
- return importModel(modelName, modelDir.toString());
+ public final ImportedModel importModel(String modelName, File modelPath) {
+ return importModel(modelName, modelPath.toString());
}
/**
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
index d3dd2a1d418..eafd18a6f83 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
@@ -6,6 +6,7 @@ import com.yahoo.searchlib.rankingexpression.integration.ml.importer.Intermediat
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx.GraphImporter;
import onnx.Onnx;
+import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
@@ -17,6 +18,14 @@ import java.io.IOException;
public class OnnxImporter extends ModelImporter {
@Override
+ public boolean canImport(String modelPath) {
+ File modelFile = new File(modelPath);
+ if ( ! modelFile.isFile()) return false;
+
+ return modelFile.toString().endsWith(".onnx");
+ }
+
+ @Override
public ImportedModel importModel(String modelName, String modelPath) {
try (FileInputStream inputStream = new FileInputStream(modelPath)) {
Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java
index 303ba228fa6..afd01b3d7da 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java
@@ -5,6 +5,7 @@ import com.yahoo.searchlib.rankingexpression.integration.ml.importer.Intermediat
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.GraphImporter;
import org.tensorflow.SavedModelBundle;
+import java.io.File;
import java.io.IOException;
/**
@@ -15,10 +16,22 @@ import java.io.IOException;
*/
public class TensorFlowImporter extends ModelImporter {
+ @Override
+ public boolean canImport(String modelPath) {
+ File modelDir = new File(modelPath);
+ if ( ! modelDir.isDirectory()) return false;
+
+ // No other model types are stored in protobuf files thus far
+ for (File file : modelDir.listFiles()) {
+ if (file.toString().endsWith(".pbtxt")) return true;
+ if (file.toString().endsWith(".pb")) return true;
+ }
+ return false;
+ }
+
/**
* Imports a saved TensorFlow model from a directory.
* The model should be saved as a .pbtxt or .pb file.
- * The name of the model is taken as the db/pbtxt file name (not including the file ending).
*
* @param modelName the name of the model to import, consisting of characters in [A-Za-z0-9_]
* @param modelDir the directory containing the TensorFlow model files to import