diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2021-09-10 23:41:55 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-09-10 23:41:55 +0200 |
commit | b439a3506cf9e93b8e572c14457fb1e952182ae7 (patch) | |
tree | 0c96859b9eb9b49670cea04456968fb21031cf2e /model-integration | |
parent | 32dd2f430a08c9c310055a843f29676bba8bd184 (diff) |
Revert "Balder/wire executor to ml model importing"
Diffstat (limited to 'model-integration')
3 files changed, 115 insertions, 38 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java index 0152669ef78..cf92cbc1e89 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java @@ -13,6 +13,7 @@ import com.yahoo.tensor.TensorType; import java.io.File; import java.io.IOException; +import java.io.StringReader; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java new file mode 100644 index 00000000000..fc576df0f09 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java @@ -0,0 +1,99 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer; + +import com.yahoo.path.Path; + +import java.io.File; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +// TODO: Remove this class after November 2018 +public class ImportedModels { + + /** All imported models, indexed by their names */ + private final Map<String, ImportedModel> importedModels; + + /** Create a null imported models */ + public ImportedModels() { + importedModels = Collections.emptyMap(); + } + + public ImportedModels(File modelsDirectory, Collection<ModelImporter> importers) { + Map<String, ImportedModel> models = new HashMap<>(); + + // Find all subdirectories recursively which contains a model we can read + importRecursively(modelsDirectory, models, importers); + importedModels = Collections.unmodifiableMap(models); + } + + /** + * Returns the model at the given location in the application package. + * + * @param modelPath the path to this model (file or directory, depending on model type) + * under the application package, both from the root or relative to the + * models directory works + * @return the model at this path or null if none + */ + public ImportedModel get(File modelPath) { + return importedModels.get(toName(modelPath)); + } + + /** Returns an immutable collection of all the imported models */ + public Collection<ImportedModel> all() { + return importedModels.values(); + } + + private static void importRecursively(File dir, + Map<String, ImportedModel> models, + Collection<ModelImporter> importers) { + if ( ! dir.isDirectory()) return; + + Arrays.stream(dir.listFiles()).sorted().forEach(child -> { + Optional<ModelImporter> importer = findImporterOf(child, importers); + if (importer.isPresent()) { + String name = toName(child); + ImportedModel existing = models.get(name); + if (existing != null) + throw new IllegalArgumentException("The models in " + child + " and " + existing.source() + + " both resolve to the model name '" + name + "'"); + models.put(name, importer.get().importModel(name, child)); + } + else { + importRecursively(child, models, importers); + } + }); + } + + private static Optional<ModelImporter> findImporterOf(File path, Collection<ModelImporter> importers) { + return importers.stream().filter(item -> item.canImport(path.toString())).findFirst(); + } + + private static String toName(File modelFile) { + Path modelPath = Path.fromString(modelFile.toString()); + if (modelFile.isFile()) + modelPath = stripFileEnding(modelPath); + String localPath = concatenateAfterModelsDirectory(modelPath); + return localPath.replace('.', '_'); + } + + private static Path stripFileEnding(Path path) { + int dotIndex = path.last().lastIndexOf("."); + if (dotIndex <= 0) return path; + return path.withLast(path.last().substring(0, dotIndex)); + } + + private static String concatenateAfterModelsDirectory(Path path) { + boolean afterModels = false; + StringBuilder result = new StringBuilder(); + for (String element : path.elements()) { + if (afterModels) result.append(element).append("_"); + if (element.equals("models")) afterModels = true; + } + return result.substring(0, result.length()-1); + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java index 8ea3b00d423..4039de85e31 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java @@ -10,10 +10,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Future; /** * All models imported from the models/ directory in the application package. @@ -28,29 +24,18 @@ public class ImportedMlModels { private final Map<String, ImportedMlModel> importedModels; /** Models that were not imported due to some error */ - private final Map<String, String> skippedModels = new ConcurrentHashMap<>(); + private final Map<String, String> skippedModels = new HashMap<>(); /** Create a null imported models */ public ImportedMlModels() { importedModels = Collections.emptyMap(); } - public ImportedMlModels(File modelsDirectory, ExecutorService executor, Collection<MlModelImporter> importers) { - Map<String, Future<ImportedMlModel>> futureModels = new HashMap<>(); + public ImportedMlModels(File modelsDirectory, Collection<MlModelImporter> importers) { + Map<String, ImportedMlModel> models = new HashMap<>(); // Find all subdirectories recursively which contains a model we can read - importRecursively(modelsDirectory, executor, futureModels, importers, skippedModels); - Map<String, ImportedMlModel> models = new HashMap<>(); - futureModels.forEach((name, future) -> { - try { - ImportedMlModel model = future.get(); - if (model != null) { - models.put(name, model); - } - } catch (InterruptedException | ExecutionException e) { - skippedModels.put(name, e.getMessage()); - } - }); + importRecursively(modelsDirectory, models, importers, skippedModels); importedModels = Collections.unmodifiableMap(models); } @@ -76,8 +61,7 @@ public class ImportedMlModels { } private static void importRecursively(File dir, - ExecutorService executor, - Map<String, Future<ImportedMlModel>> models, + Map<String, ImportedMlModel> models, Collection<MlModelImporter> importers, Map<String, String> skippedModels) { if ( ! dir.isDirectory()) return; @@ -86,26 +70,19 @@ public class ImportedMlModels { Optional<MlModelImporter> importer = findImporterOf(child, importers); if (importer.isPresent()) { String name = toName(child); - Future<ImportedMlModel> existing = models.get(name); - if (existing != null) { - try { - throw new IllegalArgumentException("The models in " + child + " and " + existing.get().source() + - " both resolve to the model name '" + name + "'"); - } catch (InterruptedException | ExecutionException e) {} + ImportedMlModel existing = models.get(name); + if (existing != null) + throw new IllegalArgumentException("The models in " + child + " and " + existing.source() + + " both resolve to the model name '" + name + "'"); + try { + ImportedMlModel importedModel = importer.get().importModel(name, child); + models.put(name, importedModel); + } catch (RuntimeException e) { + skippedModels.put(name, e.getMessage()); } - - Future<ImportedMlModel> future = executor.submit(() -> { - try { - return importer.get().importModel(name, child); - } catch (RuntimeException e) { - skippedModels.put(name, e.getMessage()); - } - return null; - }); - models.put(name, future); } else { - importRecursively(child, executor, models, importers, skippedModels); + importRecursively(child, models, importers, skippedModels); } }); } |