summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2022-05-16 12:35:32 +0200
committerJon Bratseth <bratseth@gmail.com>2022-05-16 12:35:32 +0200
commit1d63b5d81c057a8fe99812be22abac38c8195241 (patch)
tree97bb5db1fb81040c479cc160234948ea66a3100e /config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
parent640e8893fdb07b6f607d94de5dae24bdf305e705 (diff)
Add model support for Onnx models in rank profiles
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java48
1 files changed, 33 insertions, 15 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index fd7e7e57f00..07f3048af04 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -118,6 +118,8 @@ public class RankProfile implements Cloneable {
private Map<Reference, Constant> constants = new HashMap<>();
+ private Map<String, OnnxModel> onnxModels = new HashMap<>();
+
private Set<String> filterFields = new HashSet<>();
private final RankProfileRegistry rankProfileRegistry;
@@ -128,8 +130,6 @@ public class RankProfile implements Cloneable {
private Boolean strict;
- /** Global onnx models not tied to a schema */
- private final OnnxModels onnxModels;
private final ApplicationPackage applicationPackage;
private final DeployLogger deployLogger;
@@ -144,7 +144,6 @@ public class RankProfile implements Cloneable {
public RankProfile(String name, Schema schema, RankProfileRegistry rankProfileRegistry) {
this.name = Objects.requireNonNull(name, "name cannot be null");
this.schema = Objects.requireNonNull(schema, "schema cannot be null");
- this.onnxModels = null;
this.rankProfileRegistry = rankProfileRegistry;
this.applicationPackage = schema.applicationPackage();
this.deployLogger = schema.getDeployLogger();
@@ -156,11 +155,10 @@ public class RankProfile implements Cloneable {
* @param name the name of the new profile
*/
public RankProfile(String name, ApplicationPackage applicationPackage, DeployLogger deployLogger,
- RankProfileRegistry rankProfileRegistry, OnnxModels onnxModels) {
+ RankProfileRegistry rankProfileRegistry) {
this.name = Objects.requireNonNull(name, "name cannot be null");
this.schema = null;
this.rankProfileRegistry = rankProfileRegistry;
- this.onnxModels = onnxModels;
this.applicationPackage = applicationPackage;
this.deployLogger = deployLogger;
}
@@ -175,10 +173,6 @@ public class RankProfile implements Cloneable {
return applicationPackage;
}
- public Map<String, OnnxModel> onnxModels() {
- return schema != null ? schema.onnxModels().asMap() : onnxModels.asMap();
- }
-
private Stream<ImmutableSDField> allFields() {
if (schema == null) return Stream.empty();
if (allFieldsList == null) {
@@ -411,6 +405,9 @@ public class RankProfile implements Cloneable {
constants.put(constant.name(), constant);
}
+ /** Returns an unmodifiable view of the constants declared in this */
+ public Map<Reference, Constant> declaredConstants() { return Collections.unmodifiableMap(constants); }
+
/** Returns an unmodifiable view of the constants available in this */
public Map<Reference, Constant> constants() {
Map<Reference, Constant> allConstants = new HashMap<>();
@@ -430,8 +427,31 @@ public class RankProfile implements Cloneable {
return allConstants;
}
- /** Returns an unmodifiable view of the constants declared in this */
- public Map<Reference, Constant> declaredConstants() { return Collections.unmodifiableMap(constants); }
+ public void add(OnnxModel model) {
+ onnxModels.put(model.getName(), model);
+ }
+
+ /** Returns an unmodifiable map of the onnx models declared in this. */
+ public Map<String, OnnxModel> declaredOnnxModels() { return onnxModels; }
+
+ /** Returns an unmodifiable map of the onnx models available in this. */
+ public Map<String, OnnxModel> onnxModels() {
+ Map<String, OnnxModel> allModels = new HashMap<>();
+ for (var inheritedProfile : inherited()) {
+ for (var model : inheritedProfile.onnxModels().values()) {
+ if (allModels.containsKey(model.getName()))
+ throw new IllegalArgumentException(model + "' is present in " +
+ inheritedProfile + " inherited by " +
+ this + ", but is also present in another profile inherited by it");
+ allModels.put(model.getName(), model);
+ }
+ }
+
+ if (schema != null)
+ allModels.putAll(schema.onnxModels());
+ allModels.putAll(onnxModels);
+ return allModels;
+ }
public void addAttributeType(String attributeName, String attributeType) {
attributeTypes.addType(attributeName, attributeType);
@@ -1064,10 +1084,8 @@ public class RankProfile implements Cloneable {
}
// Add output types for ONNX models
- for (Map.Entry<String, OnnxModel> entry : onnxModels().entrySet()) {
- String modelName = entry.getKey();
- OnnxModel model = entry.getValue();
- Arguments args = new Arguments(new ReferenceNode(modelName));
+ for (var model : onnxModels().values()) {
+ Arguments args = new Arguments(new ReferenceNode(model.getName()));
Map<String, TensorType> inputTypes = resolveOnnxInputTypes(model, context);
TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), inputTypes);