diff options
Diffstat (limited to 'config-model')
3 files changed, 23 insertions, 19 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 3c1f00f1252..b9e241a06e4 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -120,6 +120,9 @@ public class RankProfile implements Cloneable { private List<ImmutableSDField> allFieldsList; + /** Global onnx models not tied to a search definition */ + private final OnnxModels onnxModels; + /** * Creates a new rank profile for a particular search definition * @@ -132,6 +135,7 @@ public class RankProfile implements Cloneable { this.name = Objects.requireNonNull(name, "name cannot be null"); this.search = Objects.requireNonNull(search, "search cannot be null"); this.model = null; + this.onnxModels = null; this.rankProfileRegistry = rankProfileRegistry; } @@ -141,11 +145,12 @@ public class RankProfile implements Cloneable { * @param name the name of the new profile * @param model the model owning this profile */ - public RankProfile(String name, VespaModel model, RankProfileRegistry rankProfileRegistry) { + public RankProfile(String name, VespaModel model, RankProfileRegistry rankProfileRegistry, OnnxModels onnxModels) { this.name = Objects.requireNonNull(name, "name cannot be null"); this.search = null; this.model = Objects.requireNonNull(model, "model cannot be null"); this.rankProfileRegistry = rankProfileRegistry; + this.onnxModels = onnxModels; } public String getName() { return name; } @@ -164,7 +169,7 @@ public class RankProfile implements Cloneable { } public Map<String, OnnxModel> onnxModels() { - return search != null ? search.onnxModels().asMap() : model.onnxModels().asMap(); + return search != null ? search.onnxModels().asMap() : onnxModels.asMap(); } private Stream<ImmutableSDField> allFields() { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index 623ddcc6f6c..69157cea050 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -12,6 +12,7 @@ import com.yahoo.config.FileReference; import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.application.api.DeployLogger; +import com.yahoo.config.application.api.FileRegistry; import com.yahoo.config.application.api.ValidationId; import com.yahoo.config.application.api.ValidationOverrides; import com.yahoo.config.model.ApplicationConfigProducerRoot; @@ -129,8 +130,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri /** Large rank expression files of this */ private final LargeRankExpressions largeRankExpressions; - /** Large rank expression files of this */ - private final OnnxModels onnxModels; + private final FileRegistry fileRegistry; /** The validation overrides of this. This is never null. */ private final ValidationOverrides validationOverrides; @@ -175,9 +175,9 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri throws IOException, SAXException { super("vespamodel"); version = deployState.getVespaVersion(); + fileRegistry = deployState.getFileRegistry(); largeRankExpressions = new LargeRankExpressions(deployState.getFileRegistry()); rankingConstants = new RankingConstants(deployState.getFileRegistry()); - onnxModels = new OnnxModels(deployState.getFileRegistry()); validationOverrides = deployState.validationOverrides(); applicationPackage = deployState.getApplicationPackage(); provisioned = deployState.provisioned(); @@ -189,7 +189,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri rankProfileList = new RankProfileList(null, // null search -> global rankingConstants, largeRankExpressions, - onnxModels, + new OnnxModels(deployState.getFileRegistry()), AttributeFields.empty, deployState.rankProfileRegistry(), deployState.getQueryProfiles().getRegistry(), @@ -270,8 +270,6 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri public LargeRankExpressions rankExpressionFiles() { return largeRankExpressions; } - public OnnxModels onnxModels() { return onnxModels; } - /** Creates a mutable model with no services instantiated */ public static VespaModel createIncomplete(DeployState deployState) throws IOException, SAXException { return new VespaModel(new NullConfigModelRegistry(), deployState, false, @@ -302,8 +300,9 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri QueryProfiles queryProfiles) { if ( ! importedModels.all().isEmpty()) { // models/ directory is available for (ImportedMlModel model : importedModels.all()) { - onnxModelInfoFromSource(model); - RankProfile profile = new RankProfile(model.name(), this, rankProfileRegistry); + // Due to automatic naming not guaranteeing unique names, there must be a 1-1 between OnnxModels and global RankProfiles. + OnnxModels onnxModels = onnxModelInfoFromSource(model); + RankProfile profile = new RankProfile(model.name(), this, rankProfileRegistry, onnxModels); rankProfileRegistry.add(profile); ConvertedModel convertedModel = ConvertedModel.fromSource(new ModelName(model.name()), model.name(), profile, queryProfiles.getRegistry(), model); @@ -315,8 +314,9 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri for (ApplicationFile generatedModelDir : generatedModelsDir.listFiles()) { String modelName = generatedModelDir.getPath().last(); if (modelName.contains(".")) continue; // Name space: Not a global profile - onnxModelInfoFromStore(modelName); - RankProfile profile = new RankProfile(modelName, this, rankProfileRegistry); + // Due to automatic naming not guaranteeing unique names, there must be a 1-1 between OnnxModels and global RankProfiles. + OnnxModels onnxModels = onnxModelInfoFromStore(modelName); + RankProfile profile = new RankProfile(modelName, this, rankProfileRegistry, onnxModels); rankProfileRegistry.add(profile); ConvertedModel convertedModel = ConvertedModel.fromStore(new ModelName(modelName), modelName, profile); convertedModel.expressions().values().forEach(f -> profile.addFunction(f, false)); @@ -325,7 +325,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri new Processing().processRankProfiles(deployLogger, rankProfileRegistry, queryProfiles, true, false); } - private void onnxModelInfoFromSource(ImportedMlModel model) { + private OnnxModels onnxModelInfoFromSource(ImportedMlModel model) { + OnnxModels onnxModels = new OnnxModels(fileRegistry); if (model.modelType().equals(ImportedMlModel.ModelType.ONNX)) { String path = model.source(); String applicationPath = this.applicationPackage.getFileReference(Path.fromString("")).toString(); @@ -334,11 +335,14 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri } loadOnnxModelInfo(onnxModels, model.name(), path); } + return onnxModels; } - private void onnxModelInfoFromStore(String modelName) { + private OnnxModels onnxModelInfoFromStore(String modelName) { String path = ApplicationPackage.MODELS_DIR.append(modelName + ".onnx").toString(); + OnnxModels onnxModels = new OnnxModels(fileRegistry); loadOnnxModelInfo(onnxModels, modelName, path); + return onnxModels; } private void loadOnnxModelInfo(OnnxModels onnxModels, String name, String path) { diff --git a/config-model/src/test/java/com/yahoo/vespa/model/utils/FileSenderTest.java b/config-model/src/test/java/com/yahoo/vespa/model/utils/FileSenderTest.java index 773929434cd..2133a9ba899 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/utils/FileSenderTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/utils/FileSenderTest.java @@ -58,11 +58,6 @@ public class FileSenderTest { } @Override - public String fileSourceHost() { - return null; - } - - @Override public List<Entry> export() { return null; } |