diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2021-09-09 22:30:33 +0200 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2021-09-09 22:30:33 +0200 |
commit | 63dd94db0e2566317cecff9dc67d796432702f55 (patch) | |
tree | 253040e15034982975b024c1230fc8b1fa6919a6 /model-integration | |
parent | a10158e45f95c7f286f3e926c1ed767adbb2c2e4 (diff) |
Wire in and use an executor for ml model importing too.
Diffstat (limited to 'model-integration')
2 files changed, 38 insertions, 114 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java deleted file mode 100644 index fc576df0f09..00000000000 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.rankingexpression.importer; - -import com.yahoo.path.Path; - -import java.io.File; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.Optional; - -// TODO: Remove this class after November 2018 -public class ImportedModels { - - /** All imported models, indexed by their names */ - private final Map<String, ImportedModel> importedModels; - - /** Create a null imported models */ - public ImportedModels() { - importedModels = Collections.emptyMap(); - } - - public ImportedModels(File modelsDirectory, Collection<ModelImporter> importers) { - Map<String, ImportedModel> models = new HashMap<>(); - - // Find all subdirectories recursively which contains a model we can read - importRecursively(modelsDirectory, models, importers); - importedModels = Collections.unmodifiableMap(models); - } - - /** - * Returns the model at the given location in the application package. - * - * @param modelPath the path to this model (file or directory, depending on model type) - * under the application package, both from the root or relative to the - * models directory works - * @return the model at this path or null if none - */ - public ImportedModel get(File modelPath) { - return importedModels.get(toName(modelPath)); - } - - /** Returns an immutable collection of all the imported models */ - public Collection<ImportedModel> all() { - return importedModels.values(); - } - - private static void importRecursively(File dir, - Map<String, ImportedModel> models, - Collection<ModelImporter> importers) { - if ( ! dir.isDirectory()) return; - - Arrays.stream(dir.listFiles()).sorted().forEach(child -> { - Optional<ModelImporter> importer = findImporterOf(child, importers); - if (importer.isPresent()) { - String name = toName(child); - ImportedModel existing = models.get(name); - if (existing != null) - throw new IllegalArgumentException("The models in " + child + " and " + existing.source() + - " both resolve to the model name '" + name + "'"); - models.put(name, importer.get().importModel(name, child)); - } - else { - importRecursively(child, models, importers); - } - }); - } - - private static Optional<ModelImporter> findImporterOf(File path, Collection<ModelImporter> importers) { - return importers.stream().filter(item -> item.canImport(path.toString())).findFirst(); - } - - private static String toName(File modelFile) { - Path modelPath = Path.fromString(modelFile.toString()); - if (modelFile.isFile()) - modelPath = stripFileEnding(modelPath); - String localPath = concatenateAfterModelsDirectory(modelPath); - return localPath.replace('.', '_'); - } - - private static Path stripFileEnding(Path path) { - int dotIndex = path.last().lastIndexOf("."); - if (dotIndex <= 0) return path; - return path.withLast(path.last().substring(0, dotIndex)); - } - - private static String concatenateAfterModelsDirectory(Path path) { - boolean afterModels = false; - StringBuilder result = new StringBuilder(); - for (String element : path.elements()) { - if (afterModels) result.append(element).append("_"); - if (element.equals("models")) afterModels = true; - } - return result.substring(0, result.length()-1); - } - -} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java index 4039de85e31..8ea3b00d423 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java @@ -10,6 +10,10 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; /** * All models imported from the models/ directory in the application package. @@ -24,18 +28,29 @@ public class ImportedMlModels { private final Map<String, ImportedMlModel> importedModels; /** Models that were not imported due to some error */ - private final Map<String, String> skippedModels = new HashMap<>(); + private final Map<String, String> skippedModels = new ConcurrentHashMap<>(); /** Create a null imported models */ public ImportedMlModels() { importedModels = Collections.emptyMap(); } - public ImportedMlModels(File modelsDirectory, Collection<MlModelImporter> importers) { - Map<String, ImportedMlModel> models = new HashMap<>(); + public ImportedMlModels(File modelsDirectory, ExecutorService executor, Collection<MlModelImporter> importers) { + Map<String, Future<ImportedMlModel>> futureModels = new HashMap<>(); // Find all subdirectories recursively which contains a model we can read - importRecursively(modelsDirectory, models, importers, skippedModels); + importRecursively(modelsDirectory, executor, futureModels, importers, skippedModels); + Map<String, ImportedMlModel> models = new HashMap<>(); + futureModels.forEach((name, future) -> { + try { + ImportedMlModel model = future.get(); + if (model != null) { + models.put(name, model); + } + } catch (InterruptedException | ExecutionException e) { + skippedModels.put(name, e.getMessage()); + } + }); importedModels = Collections.unmodifiableMap(models); } @@ -61,7 +76,8 @@ public class ImportedMlModels { } private static void importRecursively(File dir, - Map<String, ImportedMlModel> models, + ExecutorService executor, + Map<String, Future<ImportedMlModel>> models, Collection<MlModelImporter> importers, Map<String, String> skippedModels) { if ( ! dir.isDirectory()) return; @@ -70,19 +86,26 @@ public class ImportedMlModels { Optional<MlModelImporter> importer = findImporterOf(child, importers); if (importer.isPresent()) { String name = toName(child); - ImportedMlModel existing = models.get(name); - if (existing != null) - throw new IllegalArgumentException("The models in " + child + " and " + existing.source() + - " both resolve to the model name '" + name + "'"); - try { - ImportedMlModel importedModel = importer.get().importModel(name, child); - models.put(name, importedModel); - } catch (RuntimeException e) { - skippedModels.put(name, e.getMessage()); + Future<ImportedMlModel> existing = models.get(name); + if (existing != null) { + try { + throw new IllegalArgumentException("The models in " + child + " and " + existing.get().source() + + " both resolve to the model name '" + name + "'"); + } catch (InterruptedException | ExecutionException e) {} } + + Future<ImportedMlModel> future = executor.submit(() -> { + try { + return importer.get().importModel(name, child); + } catch (RuntimeException e) { + skippedModels.put(name, e.getMessage()); + } + return null; + }); + models.put(name, future); } else { - importRecursively(child, models, importers, skippedModels); + importRecursively(child, executor, models, importers, skippedModels); } }); } |