aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2021-09-10 23:41:55 +0200
committerGitHub <noreply@github.com>2021-09-10 23:41:55 +0200
commitb439a3506cf9e93b8e572c14457fb1e952182ae7 (patch)
tree0c96859b9eb9b49670cea04456968fb21031cf2e /model-integration
parent32dd2f430a08c9c310055a843f29676bba8bd184 (diff)
Revert "Balder/wire executor to ml model importing"
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java1
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java99
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java53
3 files changed, 115 insertions, 38 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
index 0152669ef78..cf92cbc1e89 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
@@ -13,6 +13,7 @@ import com.yahoo.tensor.TensorType;
import java.io.File;
import java.io.IOException;
+import java.io.StringReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java
new file mode 100644
index 00000000000..fc576df0f09
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java
@@ -0,0 +1,99 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer;
+
+import com.yahoo.path.Path;
+
+import java.io.File;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+
+// TODO: Remove this class after November 2018
+public class ImportedModels {
+
+ /** All imported models, indexed by their names */
+ private final Map<String, ImportedModel> importedModels;
+
+ /** Create a null imported models */
+ public ImportedModels() {
+ importedModels = Collections.emptyMap();
+ }
+
+ public ImportedModels(File modelsDirectory, Collection<ModelImporter> importers) {
+ Map<String, ImportedModel> models = new HashMap<>();
+
+ // Find all subdirectories recursively which contains a model we can read
+ importRecursively(modelsDirectory, models, importers);
+ importedModels = Collections.unmodifiableMap(models);
+ }
+
+ /**
+ * Returns the model at the given location in the application package.
+ *
+ * @param modelPath the path to this model (file or directory, depending on model type)
+ * under the application package, both from the root or relative to the
+ * models directory works
+ * @return the model at this path or null if none
+ */
+ public ImportedModel get(File modelPath) {
+ return importedModels.get(toName(modelPath));
+ }
+
+ /** Returns an immutable collection of all the imported models */
+ public Collection<ImportedModel> all() {
+ return importedModels.values();
+ }
+
+ private static void importRecursively(File dir,
+ Map<String, ImportedModel> models,
+ Collection<ModelImporter> importers) {
+ if ( ! dir.isDirectory()) return;
+
+ Arrays.stream(dir.listFiles()).sorted().forEach(child -> {
+ Optional<ModelImporter> importer = findImporterOf(child, importers);
+ if (importer.isPresent()) {
+ String name = toName(child);
+ ImportedModel existing = models.get(name);
+ if (existing != null)
+ throw new IllegalArgumentException("The models in " + child + " and " + existing.source() +
+ " both resolve to the model name '" + name + "'");
+ models.put(name, importer.get().importModel(name, child));
+ }
+ else {
+ importRecursively(child, models, importers);
+ }
+ });
+ }
+
+ private static Optional<ModelImporter> findImporterOf(File path, Collection<ModelImporter> importers) {
+ return importers.stream().filter(item -> item.canImport(path.toString())).findFirst();
+ }
+
+ private static String toName(File modelFile) {
+ Path modelPath = Path.fromString(modelFile.toString());
+ if (modelFile.isFile())
+ modelPath = stripFileEnding(modelPath);
+ String localPath = concatenateAfterModelsDirectory(modelPath);
+ return localPath.replace('.', '_');
+ }
+
+ private static Path stripFileEnding(Path path) {
+ int dotIndex = path.last().lastIndexOf(".");
+ if (dotIndex <= 0) return path;
+ return path.withLast(path.last().substring(0, dotIndex));
+ }
+
+ private static String concatenateAfterModelsDirectory(Path path) {
+ boolean afterModels = false;
+ StringBuilder result = new StringBuilder();
+ for (String element : path.elements()) {
+ if (afterModels) result.append(element).append("_");
+ if (element.equals("models")) afterModels = true;
+ }
+ return result.substring(0, result.length()-1);
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java
index 8ea3b00d423..4039de85e31 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java
@@ -10,10 +10,6 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Future;
/**
* All models imported from the models/ directory in the application package.
@@ -28,29 +24,18 @@ public class ImportedMlModels {
private final Map<String, ImportedMlModel> importedModels;
/** Models that were not imported due to some error */
- private final Map<String, String> skippedModels = new ConcurrentHashMap<>();
+ private final Map<String, String> skippedModels = new HashMap<>();
/** Create a null imported models */
public ImportedMlModels() {
importedModels = Collections.emptyMap();
}
- public ImportedMlModels(File modelsDirectory, ExecutorService executor, Collection<MlModelImporter> importers) {
- Map<String, Future<ImportedMlModel>> futureModels = new HashMap<>();
+ public ImportedMlModels(File modelsDirectory, Collection<MlModelImporter> importers) {
+ Map<String, ImportedMlModel> models = new HashMap<>();
// Find all subdirectories recursively which contains a model we can read
- importRecursively(modelsDirectory, executor, futureModels, importers, skippedModels);
- Map<String, ImportedMlModel> models = new HashMap<>();
- futureModels.forEach((name, future) -> {
- try {
- ImportedMlModel model = future.get();
- if (model != null) {
- models.put(name, model);
- }
- } catch (InterruptedException | ExecutionException e) {
- skippedModels.put(name, e.getMessage());
- }
- });
+ importRecursively(modelsDirectory, models, importers, skippedModels);
importedModels = Collections.unmodifiableMap(models);
}
@@ -76,8 +61,7 @@ public class ImportedMlModels {
}
private static void importRecursively(File dir,
- ExecutorService executor,
- Map<String, Future<ImportedMlModel>> models,
+ Map<String, ImportedMlModel> models,
Collection<MlModelImporter> importers,
Map<String, String> skippedModels) {
if ( ! dir.isDirectory()) return;
@@ -86,26 +70,19 @@ public class ImportedMlModels {
Optional<MlModelImporter> importer = findImporterOf(child, importers);
if (importer.isPresent()) {
String name = toName(child);
- Future<ImportedMlModel> existing = models.get(name);
- if (existing != null) {
- try {
- throw new IllegalArgumentException("The models in " + child + " and " + existing.get().source() +
- " both resolve to the model name '" + name + "'");
- } catch (InterruptedException | ExecutionException e) {}
+ ImportedMlModel existing = models.get(name);
+ if (existing != null)
+ throw new IllegalArgumentException("The models in " + child + " and " + existing.source() +
+ " both resolve to the model name '" + name + "'");
+ try {
+ ImportedMlModel importedModel = importer.get().importModel(name, child);
+ models.put(name, importedModel);
+ } catch (RuntimeException e) {
+ skippedModels.put(name, e.getMessage());
}
-
- Future<ImportedMlModel> future = executor.submit(() -> {
- try {
- return importer.get().importModel(name, child);
- } catch (RuntimeException e) {
- skippedModels.put(name, e.getMessage());
- }
- return null;
- });
- models.put(name, future);
}
else {
- importRecursively(child, executor, models, importers, skippedModels);
+ importRecursively(child, models, importers, skippedModels);
}
});
}