diff options
author | Lester Solbakken <lesters@users.noreply.github.com> | 2018-09-04 23:51:48 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-09-04 23:51:48 +0200 |
commit | 24753b7ccd7a51ab1a8c850e09e7a234c4df7453 (patch) | |
tree | cf2638f5ea8781d4115a530ab9ce31633c9fe5fc /searchlib | |
parent | 9cbbb4c62926c7e4ed4e1855bb807b95c3c666ec (diff) | |
parent | 68995e136c5ae87df33e5fbf261291d19ccd1929 (diff) |
Merge pull request #6778 from vespa-engine/bratseth/check-for-model-name-collisions
Check for model name collisions
Diffstat (limited to 'searchlib')
8 files changed, 33 insertions, 19 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java index f7fe91cb56f..ac5eefcc5b2 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java @@ -23,6 +23,7 @@ public class ImportedModel { private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*"); private final String name; + private final String source; private final Map<String, Signature> signatures = new HashMap<>(); private final Map<String, TensorType> arguments = new HashMap<>(); @@ -36,16 +37,21 @@ public class ImportedModel { * Creates a new imported model. * * @param name the name of this mode, containing only characters in [A-Za-z0-9_] + * @param source the source path (directory or file) of this model */ - public ImportedModel(String name) { + public ImportedModel(String name, String source) { if ( ! nameRegexp.matcher(name).matches()) throw new IllegalArgumentException("An imported model name can only contain [A-Za-z0-9_], but is '" + name + "'"); this.name = name; + this.source = source; } /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */ public String name() { return name; } + /** Returns the source path (directiry or file) of this model */ + public String source() { return source; } + /** Returns an immutable map of the arguments ("Placeholders") of this */ public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java index 92cb8c3f360..40d1ca8030a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java @@ -6,7 +6,10 @@ import com.google.common.collect.ImmutableMap; import com.yahoo.path.Path; import java.io.File; +import java.util.Arrays; import java.util.Collection; +import java.util.HashMap; +import java.util.Map; import java.util.Optional; /** @@ -30,25 +33,30 @@ public class ImportedModels { } public ImportedModels(File modelsDirectory) { - ImmutableMap.Builder<String, ImportedModel> builder = new ImmutableMap.Builder<>(); + Map<String, ImportedModel> models = new HashMap<>(); // Find all subdirectories recursively which contains a model we can read - importRecursively(modelsDirectory, builder); - importedModels = builder.build(); + importRecursively(modelsDirectory, models); + importedModels = ImmutableMap.copyOf(models); } - private static void importRecursively(File dir, ImmutableMap.Builder<String, ImportedModel> builder) { + private static void importRecursively(File dir, Map<String, ImportedModel> models) { if ( ! dir.isDirectory()) return; - for (File child : dir.listFiles()) { + + Arrays.stream(dir.listFiles()).sorted().forEach(child -> { Optional<ModelImporter> importer = findImporterOf(child); if (importer.isPresent()) { String name = toName(child); - builder.put(name, importer.get().importModel(name, child)); + ImportedModel existing = models.get(name); + if (existing != null) + throw new IllegalArgumentException("The models in " + child + " and " + existing.source() + + " both resolve to the model name '" + name + "'"); + models.put(name, importer.get().importModel(name, child)); } else { - importRecursively(child, builder); + importRecursively(child, models); } - } + }); } private static Optional<ModelImporter> findImporterOf(File path) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java index 13718935cef..9833e52cb61 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java @@ -46,8 +46,8 @@ public abstract class ModelImporter { * Takes an IntermediateGraph and converts it to a ImportedModel containing * the actual Vespa ranking expressions. */ - static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph) { - ImportedModel model = new ImportedModel(graph.name()); + static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) { + ImportedModel model = new ImportedModel(graph.name(), modelSource); graph.optimize(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java index 187e2f2e29d..917b0d6a389 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java @@ -31,7 +31,7 @@ public class OnnxImporter extends ModelImporter { try (FileInputStream inputStream = new FileInputStream(modelPath)) { Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); IntermediateGraph graph = GraphImporter.importGraph(modelName, model); - return convertIntermediateGraphToModel(graph); + return convertIntermediateGraphToModel(graph, modelPath); } catch (IOException e) { throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java index afd01b3d7da..7c18e04bae7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java @@ -39,7 +39,7 @@ public class TensorFlowImporter extends ModelImporter { @Override public ImportedModel importModel(String modelName, String modelDir) { try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { - return importModel(modelName, model); + return importModel(modelName, modelDir, model); } catch (IllegalArgumentException e) { throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e); @@ -47,10 +47,10 @@ public class TensorFlowImporter extends ModelImporter { } /** Imports a TensorFlow model */ - ImportedModel importModel(String modelName, SavedModelBundle model) { + ImportedModel importModel(String modelName, String modelDir, SavedModelBundle model) { try { IntermediateGraph graph = GraphImporter.importGraph(modelName, model); - return convertIntermediateGraphToModel(graph); + return convertIntermediateGraphToModel(graph, modelDir); } catch (IOException e) { throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java index e08214579db..725f319a839 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java @@ -27,7 +27,7 @@ public class XGBoostImporter extends ModelImporter { @Override public ImportedModel importModel(String modelName, String modelPath) { try { - ImportedModel model = new ImportedModel(modelName); + ImportedModel model = new ImportedModel(modelName, modelPath); XGBoostParser parser = new XGBoostParser(modelPath); RankingExpression expression = new RankingExpression(parser.toRankingExpression()); model.expression(modelName, expression); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java index 39a8b211d09..eee92862e7f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java @@ -10,8 +10,8 @@ import java.util.Map; import java.util.Set; /** - * Holds an intermediate representation of an imported ONNX or TensorFlow - * graph. After this intermediate representation is constructed, it is used to + * Holds an intermediate representation of an imported model graph. + * After this intermediate representation is constructed, it is used to * simplify and optimize the computational graph and then converted into the * final ImportedModel that holds the Vespa ranking expressions for the model. * diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java index 723c5f27914..273eafad0d9 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java @@ -36,7 +36,7 @@ public class TestableTensorFlowModel { public TestableTensorFlowModel(String modelName, String modelDir) { tensorFlowModel = SavedModelBundle.load(modelDir, "serve"); - model = new TensorFlowImporter().importModel(modelName, tensorFlowModel); + model = new TensorFlowImporter().importModel(modelName, modelDir, tensorFlowModel); } public ImportedModel get() { return model; } |