diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2021-09-10 23:41:55 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-09-10 23:41:55 +0200 |
commit | b439a3506cf9e93b8e572c14457fb1e952182ae7 (patch) | |
tree | 0c96859b9eb9b49670cea04456968fb21031cf2e /config-model | |
parent | 32dd2f430a08c9c310055a843f29676bba8bd184 (diff) |
Revert "Balder/wire executor to ml model importing"
Diffstat (limited to 'config-model')
3 files changed, 17 insertions, 37 deletions
diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java index 13769be9ec1..249ca71117a 100644 --- a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java +++ b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java @@ -145,7 +145,7 @@ public class DeployState implements ConfigDefinitionStore { this.zone = zone; this.queryProfiles = queryProfiles; // TODO: Remove this by seeing how pagetemplates are propagated this.semanticRules = semanticRules; // TODO: Remove this by seeing how pagetemplates are propagated - this.importedModels = importMlModels(applicationPackage, modelImporters, deployLogger, executor); + this.importedModels = importMlModels(applicationPackage, modelImporters, deployLogger); this.validationOverrides = applicationPackage.getValidationOverrides().map(ValidationOverrides::fromXml) .orElse(ValidationOverrides.empty); @@ -211,10 +211,9 @@ public class DeployState implements ConfigDefinitionStore { private static ImportedMlModels importMlModels(ApplicationPackage applicationPackage, Collection<MlModelImporter> modelImporters, - DeployLogger deployLogger, - ExecutorService executor) { + DeployLogger deployLogger) { File importFrom = applicationPackage.getFileReference(ApplicationPackage.MODELS_DIR); - ImportedMlModels importedModels = new ImportedMlModels(importFrom, executor, modelImporters); + ImportedMlModels importedModels = new ImportedMlModels(importFrom, modelImporters); for (var entry : importedModels.getSkippedModels().entrySet()) { deployLogger.logApplicationPackage(Level.WARNING, "Skipping import of model " + entry.getKey() + " as an exception " + "occurred during import. Error: " + entry.getValue()); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index c6c2fea5900..3f1cf130aff 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.model; import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels; import com.yahoo.collections.Pair; import com.yahoo.component.Version; import com.yahoo.config.ConfigInstance; @@ -77,8 +78,6 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Future; import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.Collectors; @@ -185,7 +184,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri VespaModelBuilder builder = new VespaDomBuilder(); root = builder.getRoot(VespaModel.ROOT_CONFIGID, deployState, this); - createGlobalRankProfiles(deployState); + createGlobalRankProfiles(deployState.getDeployLogger(), deployState.getImportedModels(), + deployState.rankProfileRegistry(), deployState.getQueryProfiles()); rankProfileList = new RankProfileList(null, // null search -> global rankingConstants, largeRankExpressions, @@ -291,24 +291,18 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri * Creates a rank profile not attached to any search definition, for each imported model in the application package, * and adds it to the given rank profile registry. */ - private void createGlobalRankProfiles(DeployState deployState) { - var importedModels = deployState.getImportedModels().all(); - DeployLogger deployLogger = deployState.getDeployLogger(); - RankProfileRegistry rankProfileRegistry = deployState.rankProfileRegistry(); - QueryProfiles queryProfiles = deployState.getQueryProfiles(); - List <Future<ConvertedModel>> futureModels = new ArrayList<>(); - if ( ! importedModels.isEmpty()) { // models/ directory is available - for (ImportedMlModel model : importedModels) { + private void createGlobalRankProfiles(DeployLogger deployLogger, ImportedMlModels importedModels, + RankProfileRegistry rankProfileRegistry, + QueryProfiles queryProfiles) { + if ( ! importedModels.all().isEmpty()) { // models/ directory is available + for (ImportedMlModel model : importedModels.all()) { // Due to automatic naming not guaranteeing unique names, there must be a 1-1 between OnnxModels and global RankProfiles. OnnxModels onnxModels = onnxModelInfoFromSource(model); RankProfile profile = new RankProfile(model.name(), this, deployLogger, rankProfileRegistry, onnxModels); rankProfileRegistry.add(profile); - futureModels.add(deployState.getExecutor().submit(() -> { - ConvertedModel convertedModel = ConvertedModel.fromSource(new ModelName(model.name()), - model.name(), profile, queryProfiles.getRegistry(), model); - convertedModel.expressions().values().forEach(f -> profile.addFunction(f, false)); - return convertedModel; - })); + ConvertedModel convertedModel = ConvertedModel.fromSource(new ModelName(model.name()), + model.name(), profile, queryProfiles.getRegistry(), model); + convertedModel.expressions().values().forEach(f -> profile.addFunction(f, false)); } } else { // generated and stored model information may be available instead @@ -320,18 +314,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri OnnxModels onnxModels = onnxModelInfoFromStore(modelName); RankProfile profile = new RankProfile(modelName, this, deployLogger, rankProfileRegistry, onnxModels); rankProfileRegistry.add(profile); - futureModels.add(deployState.getExecutor().submit(() -> { - ConvertedModel convertedModel = ConvertedModel.fromStore(new ModelName(modelName), modelName, profile); - convertedModel.expressions().values().forEach(f -> profile.addFunction(f, false)); - return convertedModel; - })); - } - } - for (var futureConvertedModel : futureModels) { - try { - futureConvertedModel.get(); - } catch (ExecutionException |InterruptedException e) { - throw new RuntimeException(e); + ConvertedModel convertedModel = ConvertedModel.fromStore(new ModelName(modelName), modelName, profile); + convertedModel.expressions().values().forEach(f -> profile.addFunction(f, false)); } } new Processing().processRankProfiles(deployLogger, rankProfileRegistry, queryProfiles, true, false); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index 9c363ea0628..010b33597f3 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -24,8 +24,6 @@ import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; import static org.junit.Assert.assertEquals; @@ -45,7 +43,6 @@ class RankProfileSearchFixture { private final QueryProfileRegistry queryProfileRegistry; private final Search search; private final Map<String, RankProfile> compiledRankProfiles = new HashMap<>(); - private final ExecutorService executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors()); public RankProfileRegistry getRankProfileRegistry() { return rankProfileRegistry; @@ -108,7 +105,7 @@ class RankProfileSearchFixture { public RankProfile compileRankProfile(String rankProfile, Path applicationDir) { RankProfile compiled = rankProfileRegistry.get(search, rankProfile) .compile(queryProfileRegistry, - new ImportedMlModels(applicationDir.toFile(), executor, importers)); + new ImportedMlModels(applicationDir.toFile(), importers)); compiledRankProfiles.put(rankProfile, compiled); return compiled; } |