diff options
author | Jon Bratseth <bratseth@gmail.com> | 2022-05-16 12:35:32 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2022-05-16 12:35:32 +0200 |
commit | 1d63b5d81c057a8fe99812be22abac38c8195241 (patch) | |
tree | 97bb5db1fb81040c479cc160234948ea66a3100e /config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java | |
parent | 640e8893fdb07b6f607d94de5dae24bdf305e705 (diff) |
Add model support for Onnx models in rank profiles
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java | 48 |
1 files changed, 33 insertions, 15 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index fd7e7e57f00..07f3048af04 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -118,6 +118,8 @@ public class RankProfile implements Cloneable { private Map<Reference, Constant> constants = new HashMap<>(); + private Map<String, OnnxModel> onnxModels = new HashMap<>(); + private Set<String> filterFields = new HashSet<>(); private final RankProfileRegistry rankProfileRegistry; @@ -128,8 +130,6 @@ public class RankProfile implements Cloneable { private Boolean strict; - /** Global onnx models not tied to a schema */ - private final OnnxModels onnxModels; private final ApplicationPackage applicationPackage; private final DeployLogger deployLogger; @@ -144,7 +144,6 @@ public class RankProfile implements Cloneable { public RankProfile(String name, Schema schema, RankProfileRegistry rankProfileRegistry) { this.name = Objects.requireNonNull(name, "name cannot be null"); this.schema = Objects.requireNonNull(schema, "schema cannot be null"); - this.onnxModels = null; this.rankProfileRegistry = rankProfileRegistry; this.applicationPackage = schema.applicationPackage(); this.deployLogger = schema.getDeployLogger(); @@ -156,11 +155,10 @@ public class RankProfile implements Cloneable { * @param name the name of the new profile */ public RankProfile(String name, ApplicationPackage applicationPackage, DeployLogger deployLogger, - RankProfileRegistry rankProfileRegistry, OnnxModels onnxModels) { + RankProfileRegistry rankProfileRegistry) { this.name = Objects.requireNonNull(name, "name cannot be null"); this.schema = null; this.rankProfileRegistry = rankProfileRegistry; - this.onnxModels = onnxModels; this.applicationPackage = applicationPackage; this.deployLogger = deployLogger; } @@ -175,10 +173,6 @@ public class RankProfile implements Cloneable { return applicationPackage; } - public Map<String, OnnxModel> onnxModels() { - return schema != null ? schema.onnxModels().asMap() : onnxModels.asMap(); - } - private Stream<ImmutableSDField> allFields() { if (schema == null) return Stream.empty(); if (allFieldsList == null) { @@ -411,6 +405,9 @@ public class RankProfile implements Cloneable { constants.put(constant.name(), constant); } + /** Returns an unmodifiable view of the constants declared in this */ + public Map<Reference, Constant> declaredConstants() { return Collections.unmodifiableMap(constants); } + /** Returns an unmodifiable view of the constants available in this */ public Map<Reference, Constant> constants() { Map<Reference, Constant> allConstants = new HashMap<>(); @@ -430,8 +427,31 @@ public class RankProfile implements Cloneable { return allConstants; } - /** Returns an unmodifiable view of the constants declared in this */ - public Map<Reference, Constant> declaredConstants() { return Collections.unmodifiableMap(constants); } + public void add(OnnxModel model) { + onnxModels.put(model.getName(), model); + } + + /** Returns an unmodifiable map of the onnx models declared in this. */ + public Map<String, OnnxModel> declaredOnnxModels() { return onnxModels; } + + /** Returns an unmodifiable map of the onnx models available in this. */ + public Map<String, OnnxModel> onnxModels() { + Map<String, OnnxModel> allModels = new HashMap<>(); + for (var inheritedProfile : inherited()) { + for (var model : inheritedProfile.onnxModels().values()) { + if (allModels.containsKey(model.getName())) + throw new IllegalArgumentException(model + "' is present in " + + inheritedProfile + " inherited by " + + this + ", but is also present in another profile inherited by it"); + allModels.put(model.getName(), model); + } + } + + if (schema != null) + allModels.putAll(schema.onnxModels()); + allModels.putAll(onnxModels); + return allModels; + } public void addAttributeType(String attributeName, String attributeType) { attributeTypes.addType(attributeName, attributeType); @@ -1064,10 +1084,8 @@ public class RankProfile implements Cloneable { } // Add output types for ONNX models - for (Map.Entry<String, OnnxModel> entry : onnxModels().entrySet()) { - String modelName = entry.getKey(); - OnnxModel model = entry.getValue(); - Arguments args = new Arguments(new ReferenceNode(modelName)); + for (var model : onnxModels().values()) { + Arguments args = new Arguments(new ReferenceNode(model.getName())); Map<String, TensorType> inputTypes = resolveOnnxInputTypes(model, context); TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), inputTypes); |