summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@users.noreply.github.com>2018-09-04 23:51:48 +0200
committerGitHub <noreply@github.com>2018-09-04 23:51:48 +0200
commit24753b7ccd7a51ab1a8c850e09e7a234c4df7453 (patch)
treecf2638f5ea8781d4115a530ab9ce31633c9fe5fc /searchlib
parent9cbbb4c62926c7e4ed4e1855bb807b95c3c666ec (diff)
parent68995e136c5ae87df33e5fbf261291d19ccd1929 (diff)
Merge pull request #6778 from vespa-engine/bratseth/check-for-model-name-collisions
Check for model name collisions
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java8
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java24
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java2
8 files changed, 33 insertions, 19 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
index f7fe91cb56f..ac5eefcc5b2 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
@@ -23,6 +23,7 @@ public class ImportedModel {
private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*");
private final String name;
+ private final String source;
private final Map<String, Signature> signatures = new HashMap<>();
private final Map<String, TensorType> arguments = new HashMap<>();
@@ -36,16 +37,21 @@ public class ImportedModel {
* Creates a new imported model.
*
* @param name the name of this mode, containing only characters in [A-Za-z0-9_]
+ * @param source the source path (directory or file) of this model
*/
- public ImportedModel(String name) {
+ public ImportedModel(String name, String source) {
if ( ! nameRegexp.matcher(name).matches())
throw new IllegalArgumentException("An imported model name can only contain [A-Za-z0-9_], but is '" + name + "'");
this.name = name;
+ this.source = source;
}
/** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */
public String name() { return name; }
+ /** Returns the source path (directiry or file) of this model */
+ public String source() { return source; }
+
/** Returns an immutable map of the arguments ("Placeholders") of this */
public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java
index 92cb8c3f360..40d1ca8030a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java
@@ -6,7 +6,10 @@ import com.google.common.collect.ImmutableMap;
import com.yahoo.path.Path;
import java.io.File;
+import java.util.Arrays;
import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
import java.util.Optional;
/**
@@ -30,25 +33,30 @@ public class ImportedModels {
}
public ImportedModels(File modelsDirectory) {
- ImmutableMap.Builder<String, ImportedModel> builder = new ImmutableMap.Builder<>();
+ Map<String, ImportedModel> models = new HashMap<>();
// Find all subdirectories recursively which contains a model we can read
- importRecursively(modelsDirectory, builder);
- importedModels = builder.build();
+ importRecursively(modelsDirectory, models);
+ importedModels = ImmutableMap.copyOf(models);
}
- private static void importRecursively(File dir, ImmutableMap.Builder<String, ImportedModel> builder) {
+ private static void importRecursively(File dir, Map<String, ImportedModel> models) {
if ( ! dir.isDirectory()) return;
- for (File child : dir.listFiles()) {
+
+ Arrays.stream(dir.listFiles()).sorted().forEach(child -> {
Optional<ModelImporter> importer = findImporterOf(child);
if (importer.isPresent()) {
String name = toName(child);
- builder.put(name, importer.get().importModel(name, child));
+ ImportedModel existing = models.get(name);
+ if (existing != null)
+ throw new IllegalArgumentException("The models in " + child + " and " + existing.source() +
+ " both resolve to the model name '" + name + "'");
+ models.put(name, importer.get().importModel(name, child));
}
else {
- importRecursively(child, builder);
+ importRecursively(child, models);
}
- }
+ });
}
private static Optional<ModelImporter> findImporterOf(File path) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
index 13718935cef..9833e52cb61 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
@@ -46,8 +46,8 @@ public abstract class ModelImporter {
* Takes an IntermediateGraph and converts it to a ImportedModel containing
* the actual Vespa ranking expressions.
*/
- static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph) {
- ImportedModel model = new ImportedModel(graph.name());
+ static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) {
+ ImportedModel model = new ImportedModel(graph.name(), modelSource);
graph.optimize();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
index 187e2f2e29d..917b0d6a389 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
@@ -31,7 +31,7 @@ public class OnnxImporter extends ModelImporter {
try (FileInputStream inputStream = new FileInputStream(modelPath)) {
Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
IntermediateGraph graph = GraphImporter.importGraph(modelName, model);
- return convertIntermediateGraphToModel(graph);
+ return convertIntermediateGraphToModel(graph, modelPath);
} catch (IOException e) {
throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java
index afd01b3d7da..7c18e04bae7 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java
@@ -39,7 +39,7 @@ public class TensorFlowImporter extends ModelImporter {
@Override
public ImportedModel importModel(String modelName, String modelDir) {
try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
- return importModel(modelName, model);
+ return importModel(modelName, modelDir, model);
}
catch (IllegalArgumentException e) {
throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
@@ -47,10 +47,10 @@ public class TensorFlowImporter extends ModelImporter {
}
/** Imports a TensorFlow model */
- ImportedModel importModel(String modelName, SavedModelBundle model) {
+ ImportedModel importModel(String modelName, String modelDir, SavedModelBundle model) {
try {
IntermediateGraph graph = GraphImporter.importGraph(modelName, model);
- return convertIntermediateGraphToModel(graph);
+ return convertIntermediateGraphToModel(graph, modelDir);
}
catch (IOException e) {
throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java
index e08214579db..725f319a839 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XGBoostImporter.java
@@ -27,7 +27,7 @@ public class XGBoostImporter extends ModelImporter {
@Override
public ImportedModel importModel(String modelName, String modelPath) {
try {
- ImportedModel model = new ImportedModel(modelName);
+ ImportedModel model = new ImportedModel(modelName, modelPath);
XGBoostParser parser = new XGBoostParser(modelPath);
RankingExpression expression = new RankingExpression(parser.toRankingExpression());
model.expression(modelName, expression);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java
index 39a8b211d09..eee92862e7f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/IntermediateGraph.java
@@ -10,8 +10,8 @@ import java.util.Map;
import java.util.Set;
/**
- * Holds an intermediate representation of an imported ONNX or TensorFlow
- * graph. After this intermediate representation is constructed, it is used to
+ * Holds an intermediate representation of an imported model graph.
+ * After this intermediate representation is constructed, it is used to
* simplify and optimize the computational graph and then converted into the
* final ImportedModel that holds the Vespa ranking expressions for the model.
*
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java
index 723c5f27914..273eafad0d9 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java
@@ -36,7 +36,7 @@ public class TestableTensorFlowModel {
public TestableTensorFlowModel(String modelName, String modelDir) {
tensorFlowModel = SavedModelBundle.load(modelDir, "serve");
- model = new TensorFlowImporter().importModel(modelName, tensorFlowModel);
+ model = new TensorFlowImporter().importModel(modelName, modelDir, tensorFlowModel);
}
public ImportedModel get() { return model; }