From e46ee256cdd3f996c9d5c034317f672383ff2fa1 Mon Sep 17 00:00:00 2001 From: Henning Baldersheim Date: Sun, 12 Sep 2021 17:07:25 +0200 Subject: Revert "Revert "Balder/wire executor to ml model importing"" --- .../com/yahoo/config/model/deploy/DeployState.java | 7 +- .../java/com/yahoo/vespa/model/VespaModel.java | 42 ++++++--- .../processing/RankProfileSearchFixture.java | 5 +- .../rankingexpression/importer/ImportedModel.java | 1 - .../rankingexpression/importer/ImportedModels.java | 99 ---------------------- .../importer/configmodelview/ImportedMlModels.java | 53 ++++++++---- 6 files changed, 75 insertions(+), 132 deletions(-) delete mode 100644 model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java index 249ca71117a..13769be9ec1 100644 --- a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java +++ b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java @@ -145,7 +145,7 @@ public class DeployState implements ConfigDefinitionStore { this.zone = zone; this.queryProfiles = queryProfiles; // TODO: Remove this by seeing how pagetemplates are propagated this.semanticRules = semanticRules; // TODO: Remove this by seeing how pagetemplates are propagated - this.importedModels = importMlModels(applicationPackage, modelImporters, deployLogger); + this.importedModels = importMlModels(applicationPackage, modelImporters, deployLogger, executor); this.validationOverrides = applicationPackage.getValidationOverrides().map(ValidationOverrides::fromXml) .orElse(ValidationOverrides.empty); @@ -211,9 +211,10 @@ public class DeployState implements ConfigDefinitionStore { private static ImportedMlModels importMlModels(ApplicationPackage applicationPackage, Collection modelImporters, - DeployLogger deployLogger) { + DeployLogger deployLogger, + ExecutorService executor) { File importFrom = applicationPackage.getFileReference(ApplicationPackage.MODELS_DIR); - ImportedMlModels importedModels = new ImportedMlModels(importFrom, modelImporters); + ImportedMlModels importedModels = new ImportedMlModels(importFrom, executor, modelImporters); for (var entry : importedModels.getSkippedModels().entrySet()) { deployLogger.logApplicationPackage(Level.WARNING, "Skipping import of model " + entry.getKey() + " as an exception " + "occurred during import. Error: " + entry.getValue()); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index 3f1cf130aff..c6c2fea5900 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -2,7 +2,6 @@ package com.yahoo.vespa.model; import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel; -import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels; import com.yahoo.collections.Pair; import com.yahoo.component.Version; import com.yahoo.config.ConfigInstance; @@ -78,6 +77,8 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.Collectors; @@ -184,8 +185,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri VespaModelBuilder builder = new VespaDomBuilder(); root = builder.getRoot(VespaModel.ROOT_CONFIGID, deployState, this); - createGlobalRankProfiles(deployState.getDeployLogger(), deployState.getImportedModels(), - deployState.rankProfileRegistry(), deployState.getQueryProfiles()); + createGlobalRankProfiles(deployState); rankProfileList = new RankProfileList(null, // null search -> global rankingConstants, largeRankExpressions, @@ -291,18 +291,24 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri * Creates a rank profile not attached to any search definition, for each imported model in the application package, * and adds it to the given rank profile registry. */ - private void createGlobalRankProfiles(DeployLogger deployLogger, ImportedMlModels importedModels, - RankProfileRegistry rankProfileRegistry, - QueryProfiles queryProfiles) { - if ( ! importedModels.all().isEmpty()) { // models/ directory is available - for (ImportedMlModel model : importedModels.all()) { + private void createGlobalRankProfiles(DeployState deployState) { + var importedModels = deployState.getImportedModels().all(); + DeployLogger deployLogger = deployState.getDeployLogger(); + RankProfileRegistry rankProfileRegistry = deployState.rankProfileRegistry(); + QueryProfiles queryProfiles = deployState.getQueryProfiles(); + List > futureModels = new ArrayList<>(); + if ( ! importedModels.isEmpty()) { // models/ directory is available + for (ImportedMlModel model : importedModels) { // Due to automatic naming not guaranteeing unique names, there must be a 1-1 between OnnxModels and global RankProfiles. OnnxModels onnxModels = onnxModelInfoFromSource(model); RankProfile profile = new RankProfile(model.name(), this, deployLogger, rankProfileRegistry, onnxModels); rankProfileRegistry.add(profile); - ConvertedModel convertedModel = ConvertedModel.fromSource(new ModelName(model.name()), - model.name(), profile, queryProfiles.getRegistry(), model); - convertedModel.expressions().values().forEach(f -> profile.addFunction(f, false)); + futureModels.add(deployState.getExecutor().submit(() -> { + ConvertedModel convertedModel = ConvertedModel.fromSource(new ModelName(model.name()), + model.name(), profile, queryProfiles.getRegistry(), model); + convertedModel.expressions().values().forEach(f -> profile.addFunction(f, false)); + return convertedModel; + })); } } else { // generated and stored model information may be available instead @@ -314,8 +320,18 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri OnnxModels onnxModels = onnxModelInfoFromStore(modelName); RankProfile profile = new RankProfile(modelName, this, deployLogger, rankProfileRegistry, onnxModels); rankProfileRegistry.add(profile); - ConvertedModel convertedModel = ConvertedModel.fromStore(new ModelName(modelName), modelName, profile); - convertedModel.expressions().values().forEach(f -> profile.addFunction(f, false)); + futureModels.add(deployState.getExecutor().submit(() -> { + ConvertedModel convertedModel = ConvertedModel.fromStore(new ModelName(modelName), modelName, profile); + convertedModel.expressions().values().forEach(f -> profile.addFunction(f, false)); + return convertedModel; + })); + } + } + for (var futureConvertedModel : futureModels) { + try { + futureConvertedModel.get(); + } catch (ExecutionException |InterruptedException e) { + throw new RuntimeException(e); } } new Processing().processRankProfiles(deployLogger, rankProfileRegistry, queryProfiles, true, false); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index 010b33597f3..9c363ea0628 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -24,6 +24,8 @@ import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import static org.junit.Assert.assertEquals; @@ -43,6 +45,7 @@ class RankProfileSearchFixture { private final QueryProfileRegistry queryProfileRegistry; private final Search search; private final Map compiledRankProfiles = new HashMap<>(); + private final ExecutorService executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors()); public RankProfileRegistry getRankProfileRegistry() { return rankProfileRegistry; @@ -105,7 +108,7 @@ class RankProfileSearchFixture { public RankProfile compileRankProfile(String rankProfile, Path applicationDir) { RankProfile compiled = rankProfileRegistry.get(search, rankProfile) .compile(queryProfileRegistry, - new ImportedMlModels(applicationDir.toFile(), importers)); + new ImportedMlModels(applicationDir.toFile(), executor, importers)); compiledRankProfiles.put(rankProfile, compiled); return compiled; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java index cf92cbc1e89..0152669ef78 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java @@ -13,7 +13,6 @@ import com.yahoo.tensor.TensorType; import java.io.File; import java.io.IOException; -import java.io.StringReader; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java deleted file mode 100644 index fc576df0f09..00000000000 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.rankingexpression.importer; - -import com.yahoo.path.Path; - -import java.io.File; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.Optional; - -// TODO: Remove this class after November 2018 -public class ImportedModels { - - /** All imported models, indexed by their names */ - private final Map importedModels; - - /** Create a null imported models */ - public ImportedModels() { - importedModels = Collections.emptyMap(); - } - - public ImportedModels(File modelsDirectory, Collection importers) { - Map models = new HashMap<>(); - - // Find all subdirectories recursively which contains a model we can read - importRecursively(modelsDirectory, models, importers); - importedModels = Collections.unmodifiableMap(models); - } - - /** - * Returns the model at the given location in the application package. - * - * @param modelPath the path to this model (file or directory, depending on model type) - * under the application package, both from the root or relative to the - * models directory works - * @return the model at this path or null if none - */ - public ImportedModel get(File modelPath) { - return importedModels.get(toName(modelPath)); - } - - /** Returns an immutable collection of all the imported models */ - public Collection all() { - return importedModels.values(); - } - - private static void importRecursively(File dir, - Map models, - Collection importers) { - if ( ! dir.isDirectory()) return; - - Arrays.stream(dir.listFiles()).sorted().forEach(child -> { - Optional importer = findImporterOf(child, importers); - if (importer.isPresent()) { - String name = toName(child); - ImportedModel existing = models.get(name); - if (existing != null) - throw new IllegalArgumentException("The models in " + child + " and " + existing.source() + - " both resolve to the model name '" + name + "'"); - models.put(name, importer.get().importModel(name, child)); - } - else { - importRecursively(child, models, importers); - } - }); - } - - private static Optional findImporterOf(File path, Collection importers) { - return importers.stream().filter(item -> item.canImport(path.toString())).findFirst(); - } - - private static String toName(File modelFile) { - Path modelPath = Path.fromString(modelFile.toString()); - if (modelFile.isFile()) - modelPath = stripFileEnding(modelPath); - String localPath = concatenateAfterModelsDirectory(modelPath); - return localPath.replace('.', '_'); - } - - private static Path stripFileEnding(Path path) { - int dotIndex = path.last().lastIndexOf("."); - if (dotIndex <= 0) return path; - return path.withLast(path.last().substring(0, dotIndex)); - } - - private static String concatenateAfterModelsDirectory(Path path) { - boolean afterModels = false; - StringBuilder result = new StringBuilder(); - for (String element : path.elements()) { - if (afterModels) result.append(element).append("_"); - if (element.equals("models")) afterModels = true; - } - return result.substring(0, result.length()-1); - } - -} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java index 4039de85e31..8ea3b00d423 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java @@ -10,6 +10,10 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; /** * All models imported from the models/ directory in the application package. @@ -24,18 +28,29 @@ public class ImportedMlModels { private final Map importedModels; /** Models that were not imported due to some error */ - private final Map skippedModels = new HashMap<>(); + private final Map skippedModels = new ConcurrentHashMap<>(); /** Create a null imported models */ public ImportedMlModels() { importedModels = Collections.emptyMap(); } - public ImportedMlModels(File modelsDirectory, Collection importers) { - Map models = new HashMap<>(); + public ImportedMlModels(File modelsDirectory, ExecutorService executor, Collection importers) { + Map> futureModels = new HashMap<>(); // Find all subdirectories recursively which contains a model we can read - importRecursively(modelsDirectory, models, importers, skippedModels); + importRecursively(modelsDirectory, executor, futureModels, importers, skippedModels); + Map models = new HashMap<>(); + futureModels.forEach((name, future) -> { + try { + ImportedMlModel model = future.get(); + if (model != null) { + models.put(name, model); + } + } catch (InterruptedException | ExecutionException e) { + skippedModels.put(name, e.getMessage()); + } + }); importedModels = Collections.unmodifiableMap(models); } @@ -61,7 +76,8 @@ public class ImportedMlModels { } private static void importRecursively(File dir, - Map models, + ExecutorService executor, + Map> models, Collection importers, Map skippedModels) { if ( ! dir.isDirectory()) return; @@ -70,19 +86,26 @@ public class ImportedMlModels { Optional importer = findImporterOf(child, importers); if (importer.isPresent()) { String name = toName(child); - ImportedMlModel existing = models.get(name); - if (existing != null) - throw new IllegalArgumentException("The models in " + child + " and " + existing.source() + - " both resolve to the model name '" + name + "'"); - try { - ImportedMlModel importedModel = importer.get().importModel(name, child); - models.put(name, importedModel); - } catch (RuntimeException e) { - skippedModels.put(name, e.getMessage()); + Future existing = models.get(name); + if (existing != null) { + try { + throw new IllegalArgumentException("The models in " + child + " and " + existing.get().source() + + " both resolve to the model name '" + name + "'"); + } catch (InterruptedException | ExecutionException e) {} } + + Future future = executor.submit(() -> { + try { + return importer.get().importModel(name, child); + } catch (RuntimeException e) { + skippedModels.put(name, e.getMessage()); + } + return null; + }); + models.put(name, future); } else { - importRecursively(child, models, importers, skippedModels); + importRecursively(child, executor, models, importers, skippedModels); } }); } -- cgit v1.2.3