diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-08-22 11:44:52 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-08-22 11:44:52 +0200 |
commit | 7392f9fdbee5f0a52ac9c056376b659b32500c60 (patch) | |
tree | 96449c35ccbdf0a4e8f6e93b3e57350be78e5efd | |
parent | 4f59bae2e4a90a3064311bda6ef1158f48182250 (diff) |
Automatically figure out the right importer
6 files changed, 40 insertions, 13 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java index 6f3fb0e1768..7cf0a5d8b76 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java @@ -1,9 +1,13 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.expressiontransforms; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.yahoo.path.Path; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; import com.yahoo.searchlib.rankingexpression.integration.ml.ModelImporter; +import com.yahoo.searchlib.rankingexpression.integration.ml.OnnxImporter; +import com.yahoo.searchlib.rankingexpression.integration.ml.TensorFlowImporter; import java.io.File; import java.util.HashMap; @@ -16,13 +20,12 @@ import java.util.Map; */ class ImportedModels { - private final ModelImporter modelImporter; - /** The cache of already imported models */ private final Map<String, ImportedModel> importedModels = new HashMap<>(); - ImportedModels(ModelImporter modelImporter) { - this.modelImporter = modelImporter; + private final ImmutableList<ModelImporter> importers = ImmutableList.of(new TensorFlowImporter(), new OnnxImporter()); + + ImportedModels() { } /** @@ -34,7 +37,8 @@ class ImportedModels { */ public ImportedModel get(File modelPath) { String modelName = toName(modelPath); - return importedModels.computeIfAbsent(modelName, __ -> modelImporter.importModel(modelName, modelPath)); + ModelImporter importer = importers.stream().filter(item -> item.canImport(modelPath.toString())).findFirst().get(); + return importedModels.computeIfAbsent(modelName, __ -> importer.importModel(modelName, modelPath)); } private static String toName(File modelPath) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index b6cc5df22f6..0b68a67acff 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -24,7 +24,7 @@ import java.util.Map; */ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { - private final ImportedModels importedOnnxModels = new ImportedModels(new OnnxImporter()); + private final ImportedModels importedOnnxModels = new ImportedModels(); /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ private final Map<Path, ConvertedModel> convertedOnnxModels = new HashMap<>(); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 774b166c45a..4f15fb5a291 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -22,7 +22,7 @@ import java.util.Map; */ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { - private final ImportedModels importedTensorFlowModels = new ImportedModels(new TensorFlowImporter()); + private final ImportedModels importedTensorFlowModels = new ImportedModels(); /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ private final Map<Path, ConvertedModel> convertedTensorFlowModels = new HashMap<>(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java index a658833b426..41817eb3e62 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java @@ -32,13 +32,14 @@ public abstract class ModelImporter { private static final Logger log = Logger.getLogger(ModelImporter.class.getName()); - /** - * The main import function. - */ + /** Returns whether the file or directory at the given path is of the tyope which can be imported by this */ + public abstract boolean canImport(String modelPath); + + /** Imports the given model */ public abstract ImportedModel importModel(String modelName, String modelPath); - public ImportedModel importModel(String modelName, File modelDir) { - return importModel(modelName, modelDir.toString()); + public final ImportedModel importModel(String modelName, File modelPath) { + return importModel(modelName, modelPath.toString()); } /** diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java index d3dd2a1d418..eafd18a6f83 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java @@ -6,6 +6,7 @@ import com.yahoo.searchlib.rankingexpression.integration.ml.importer.Intermediat import com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx.GraphImporter; import onnx.Onnx; +import java.io.File; import java.io.FileInputStream; import java.io.IOException; @@ -17,6 +18,14 @@ import java.io.IOException; public class OnnxImporter extends ModelImporter { @Override + public boolean canImport(String modelPath) { + File modelFile = new File(modelPath); + if ( ! modelFile.isFile()) return false; + + return modelFile.toString().endsWith(".onnx"); + } + + @Override public ImportedModel importModel(String modelName, String modelPath) { try (FileInputStream inputStream = new FileInputStream(modelPath)) { Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java index 303ba228fa6..afd01b3d7da 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java @@ -5,6 +5,7 @@ import com.yahoo.searchlib.rankingexpression.integration.ml.importer.Intermediat import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.GraphImporter; import org.tensorflow.SavedModelBundle; +import java.io.File; import java.io.IOException; /** @@ -15,10 +16,22 @@ import java.io.IOException; */ public class TensorFlowImporter extends ModelImporter { + @Override + public boolean canImport(String modelPath) { + File modelDir = new File(modelPath); + if ( ! modelDir.isDirectory()) return false; + + // No other model types are stored in protobuf files thus far + for (File file : modelDir.listFiles()) { + if (file.toString().endsWith(".pbtxt")) return true; + if (file.toString().endsWith(".pb")) return true; + } + return false; + } + /** * Imports a saved TensorFlow model from a directory. * The model should be saved as a .pbtxt or .pb file. - * The name of the model is taken as the db/pbtxt file name (not including the file ending). * * @param modelName the name of the model to import, consisting of characters in [A-Za-z0-9_] * @param modelDir the directory containing the TensorFlow model files to import |