summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHarald Musum <musum@verizonmedia.com>2021-09-12 18:59:44 +0200
committerGitHub <noreply@github.com>2021-09-12 18:59:44 +0200
commitd83a5c6e35b106ccbfb7ea0a41cf1b8749bd28ac (patch)
tree94fe8a60c3ab4d4af84b7bda5de66fdcd5225e0c
parent302d4358bdb28ab1f2ebfff20f1cafb2d04f8835 (diff)
parenta92dba40abd00419801378d00812d27ea54838c3 (diff)
Merge pull request #19083 from vespa-engine/revert-19081-revert-19078-balder/wire-executor-to-ml-model-importing
Revert "Revert "Balder/wire executor to ml model importing""
-rw-r--r--config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java7
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java42
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java1
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java99
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java58
6 files changed, 81 insertions, 131 deletions
diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java
index 249ca71117a..13769be9ec1 100644
--- a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java
+++ b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java
@@ -145,7 +145,7 @@ public class DeployState implements ConfigDefinitionStore {
this.zone = zone;
this.queryProfiles = queryProfiles; // TODO: Remove this by seeing how pagetemplates are propagated
this.semanticRules = semanticRules; // TODO: Remove this by seeing how pagetemplates are propagated
- this.importedModels = importMlModels(applicationPackage, modelImporters, deployLogger);
+ this.importedModels = importMlModels(applicationPackage, modelImporters, deployLogger, executor);
this.validationOverrides = applicationPackage.getValidationOverrides().map(ValidationOverrides::fromXml)
.orElse(ValidationOverrides.empty);
@@ -211,9 +211,10 @@ public class DeployState implements ConfigDefinitionStore {
private static ImportedMlModels importMlModels(ApplicationPackage applicationPackage,
Collection<MlModelImporter> modelImporters,
- DeployLogger deployLogger) {
+ DeployLogger deployLogger,
+ ExecutorService executor) {
File importFrom = applicationPackage.getFileReference(ApplicationPackage.MODELS_DIR);
- ImportedMlModels importedModels = new ImportedMlModels(importFrom, modelImporters);
+ ImportedMlModels importedModels = new ImportedMlModels(importFrom, executor, modelImporters);
for (var entry : importedModels.getSkippedModels().entrySet()) {
deployLogger.logApplicationPackage(Level.WARNING, "Skipping import of model " + entry.getKey() + " as an exception " +
"occurred during import. Error: " + entry.getValue());
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
index 3f1cf130aff..c6c2fea5900 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
@@ -2,7 +2,6 @@
package com.yahoo.vespa.model;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel;
-import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels;
import com.yahoo.collections.Pair;
import com.yahoo.component.Version;
import com.yahoo.config.ConfigInstance;
@@ -78,6 +77,8 @@ import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
@@ -184,8 +185,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
VespaModelBuilder builder = new VespaDomBuilder();
root = builder.getRoot(VespaModel.ROOT_CONFIGID, deployState, this);
- createGlobalRankProfiles(deployState.getDeployLogger(), deployState.getImportedModels(),
- deployState.rankProfileRegistry(), deployState.getQueryProfiles());
+ createGlobalRankProfiles(deployState);
rankProfileList = new RankProfileList(null, // null search -> global
rankingConstants,
largeRankExpressions,
@@ -291,18 +291,24 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
* Creates a rank profile not attached to any search definition, for each imported model in the application package,
* and adds it to the given rank profile registry.
*/
- private void createGlobalRankProfiles(DeployLogger deployLogger, ImportedMlModels importedModels,
- RankProfileRegistry rankProfileRegistry,
- QueryProfiles queryProfiles) {
- if ( ! importedModels.all().isEmpty()) { // models/ directory is available
- for (ImportedMlModel model : importedModels.all()) {
+ private void createGlobalRankProfiles(DeployState deployState) {
+ var importedModels = deployState.getImportedModels().all();
+ DeployLogger deployLogger = deployState.getDeployLogger();
+ RankProfileRegistry rankProfileRegistry = deployState.rankProfileRegistry();
+ QueryProfiles queryProfiles = deployState.getQueryProfiles();
+ List <Future<ConvertedModel>> futureModels = new ArrayList<>();
+ if ( ! importedModels.isEmpty()) { // models/ directory is available
+ for (ImportedMlModel model : importedModels) {
// Due to automatic naming not guaranteeing unique names, there must be a 1-1 between OnnxModels and global RankProfiles.
OnnxModels onnxModels = onnxModelInfoFromSource(model);
RankProfile profile = new RankProfile(model.name(), this, deployLogger, rankProfileRegistry, onnxModels);
rankProfileRegistry.add(profile);
- ConvertedModel convertedModel = ConvertedModel.fromSource(new ModelName(model.name()),
- model.name(), profile, queryProfiles.getRegistry(), model);
- convertedModel.expressions().values().forEach(f -> profile.addFunction(f, false));
+ futureModels.add(deployState.getExecutor().submit(() -> {
+ ConvertedModel convertedModel = ConvertedModel.fromSource(new ModelName(model.name()),
+ model.name(), profile, queryProfiles.getRegistry(), model);
+ convertedModel.expressions().values().forEach(f -> profile.addFunction(f, false));
+ return convertedModel;
+ }));
}
}
else { // generated and stored model information may be available instead
@@ -314,8 +320,18 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
OnnxModels onnxModels = onnxModelInfoFromStore(modelName);
RankProfile profile = new RankProfile(modelName, this, deployLogger, rankProfileRegistry, onnxModels);
rankProfileRegistry.add(profile);
- ConvertedModel convertedModel = ConvertedModel.fromStore(new ModelName(modelName), modelName, profile);
- convertedModel.expressions().values().forEach(f -> profile.addFunction(f, false));
+ futureModels.add(deployState.getExecutor().submit(() -> {
+ ConvertedModel convertedModel = ConvertedModel.fromStore(new ModelName(modelName), modelName, profile);
+ convertedModel.expressions().values().forEach(f -> profile.addFunction(f, false));
+ return convertedModel;
+ }));
+ }
+ }
+ for (var futureConvertedModel : futureModels) {
+ try {
+ futureConvertedModel.get();
+ } catch (ExecutionException |InterruptedException e) {
+ throw new RuntimeException(e);
}
}
new Processing().processRankProfiles(deployLogger, rankProfileRegistry, queryProfiles, true, false);
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
index 010b33597f3..9c363ea0628 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
@@ -24,6 +24,8 @@ import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
import static org.junit.Assert.assertEquals;
@@ -43,6 +45,7 @@ class RankProfileSearchFixture {
private final QueryProfileRegistry queryProfileRegistry;
private final Search search;
private final Map<String, RankProfile> compiledRankProfiles = new HashMap<>();
+ private final ExecutorService executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
public RankProfileRegistry getRankProfileRegistry() {
return rankProfileRegistry;
@@ -105,7 +108,7 @@ class RankProfileSearchFixture {
public RankProfile compileRankProfile(String rankProfile, Path applicationDir) {
RankProfile compiled = rankProfileRegistry.get(search, rankProfile)
.compile(queryProfileRegistry,
- new ImportedMlModels(applicationDir.toFile(), importers));
+ new ImportedMlModels(applicationDir.toFile(), executor, importers));
compiledRankProfiles.put(rankProfile, compiled);
return compiled;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
index cf92cbc1e89..0152669ef78 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
@@ -13,7 +13,6 @@ import com.yahoo.tensor.TensorType;
import java.io.File;
import java.io.IOException;
-import java.io.StringReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java
deleted file mode 100644
index fc576df0f09..00000000000
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java
+++ /dev/null
@@ -1,99 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.rankingexpression.importer;
-
-import com.yahoo.path.Path;
-
-import java.io.File;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Optional;
-
-// TODO: Remove this class after November 2018
-public class ImportedModels {
-
- /** All imported models, indexed by their names */
- private final Map<String, ImportedModel> importedModels;
-
- /** Create a null imported models */
- public ImportedModels() {
- importedModels = Collections.emptyMap();
- }
-
- public ImportedModels(File modelsDirectory, Collection<ModelImporter> importers) {
- Map<String, ImportedModel> models = new HashMap<>();
-
- // Find all subdirectories recursively which contains a model we can read
- importRecursively(modelsDirectory, models, importers);
- importedModels = Collections.unmodifiableMap(models);
- }
-
- /**
- * Returns the model at the given location in the application package.
- *
- * @param modelPath the path to this model (file or directory, depending on model type)
- * under the application package, both from the root or relative to the
- * models directory works
- * @return the model at this path or null if none
- */
- public ImportedModel get(File modelPath) {
- return importedModels.get(toName(modelPath));
- }
-
- /** Returns an immutable collection of all the imported models */
- public Collection<ImportedModel> all() {
- return importedModels.values();
- }
-
- private static void importRecursively(File dir,
- Map<String, ImportedModel> models,
- Collection<ModelImporter> importers) {
- if ( ! dir.isDirectory()) return;
-
- Arrays.stream(dir.listFiles()).sorted().forEach(child -> {
- Optional<ModelImporter> importer = findImporterOf(child, importers);
- if (importer.isPresent()) {
- String name = toName(child);
- ImportedModel existing = models.get(name);
- if (existing != null)
- throw new IllegalArgumentException("The models in " + child + " and " + existing.source() +
- " both resolve to the model name '" + name + "'");
- models.put(name, importer.get().importModel(name, child));
- }
- else {
- importRecursively(child, models, importers);
- }
- });
- }
-
- private static Optional<ModelImporter> findImporterOf(File path, Collection<ModelImporter> importers) {
- return importers.stream().filter(item -> item.canImport(path.toString())).findFirst();
- }
-
- private static String toName(File modelFile) {
- Path modelPath = Path.fromString(modelFile.toString());
- if (modelFile.isFile())
- modelPath = stripFileEnding(modelPath);
- String localPath = concatenateAfterModelsDirectory(modelPath);
- return localPath.replace('.', '_');
- }
-
- private static Path stripFileEnding(Path path) {
- int dotIndex = path.last().lastIndexOf(".");
- if (dotIndex <= 0) return path;
- return path.withLast(path.last().substring(0, dotIndex));
- }
-
- private static String concatenateAfterModelsDirectory(Path path) {
- boolean afterModels = false;
- StringBuilder result = new StringBuilder();
- for (String element : path.elements()) {
- if (afterModels) result.append(element).append("_");
- if (element.equals("models")) afterModels = true;
- }
- return result.substring(0, result.length()-1);
- }
-
-}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java
index 4039de85e31..294a4782001 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/configmodelview/ImportedMlModels.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.configmodelview;
+import com.yahoo.concurrent.InThreadExecutorService;
import com.yahoo.path.Path;
import java.io.File;
@@ -10,6 +11,10 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
/**
* All models imported from the models/ directory in the application package.
@@ -24,18 +29,35 @@ public class ImportedMlModels {
private final Map<String, ImportedMlModel> importedModels;
/** Models that were not imported due to some error */
- private final Map<String, String> skippedModels = new HashMap<>();
+ private final Map<String, String> skippedModels = new ConcurrentHashMap<>();
/** Create a null imported models */
public ImportedMlModels() {
importedModels = Collections.emptyMap();
}
+ /** Will disappear shortly */
+ @Deprecated
public ImportedMlModels(File modelsDirectory, Collection<MlModelImporter> importers) {
- Map<String, ImportedMlModel> models = new HashMap<>();
+ this(modelsDirectory, new InThreadExecutorService(), importers);
+ }
+
+ public ImportedMlModels(File modelsDirectory, ExecutorService executor, Collection<MlModelImporter> importers) {
+ Map<String, Future<ImportedMlModel>> futureModels = new HashMap<>();
// Find all subdirectories recursively which contains a model we can read
- importRecursively(modelsDirectory, models, importers, skippedModels);
+ importRecursively(modelsDirectory, executor, futureModels, importers, skippedModels);
+ Map<String, ImportedMlModel> models = new HashMap<>();
+ futureModels.forEach((name, future) -> {
+ try {
+ ImportedMlModel model = future.get();
+ if (model != null) {
+ models.put(name, model);
+ }
+ } catch (InterruptedException | ExecutionException e) {
+ skippedModels.put(name, e.getMessage());
+ }
+ });
importedModels = Collections.unmodifiableMap(models);
}
@@ -61,7 +83,8 @@ public class ImportedMlModels {
}
private static void importRecursively(File dir,
- Map<String, ImportedMlModel> models,
+ ExecutorService executor,
+ Map<String, Future<ImportedMlModel>> models,
Collection<MlModelImporter> importers,
Map<String, String> skippedModels) {
if ( ! dir.isDirectory()) return;
@@ -70,19 +93,26 @@ public class ImportedMlModels {
Optional<MlModelImporter> importer = findImporterOf(child, importers);
if (importer.isPresent()) {
String name = toName(child);
- ImportedMlModel existing = models.get(name);
- if (existing != null)
- throw new IllegalArgumentException("The models in " + child + " and " + existing.source() +
- " both resolve to the model name '" + name + "'");
- try {
- ImportedMlModel importedModel = importer.get().importModel(name, child);
- models.put(name, importedModel);
- } catch (RuntimeException e) {
- skippedModels.put(name, e.getMessage());
+ Future<ImportedMlModel> existing = models.get(name);
+ if (existing != null) {
+ try {
+ throw new IllegalArgumentException("The models in " + child + " and " + existing.get().source() +
+ " both resolve to the model name '" + name + "'");
+ } catch (InterruptedException | ExecutionException e) {}
}
+
+ Future<ImportedMlModel> future = executor.submit(() -> {
+ try {
+ return importer.get().importModel(name, child);
+ } catch (RuntimeException e) {
+ skippedModels.put(name, e.getMessage());
+ }
+ return null;
+ });
+ models.put(name, future);
}
else {
- importRecursively(child, models, importers, skippedModels);
+ importRecursively(child, executor, models, importers, skippedModels);
}
});
}