diff options
author | Arne Juul <arnej@verizonmedia.com> | 2021-04-21 08:40:32 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2021-04-23 06:54:10 +0000 |
commit | 0db383e464bb24c525ffc4b3df51950a8f10444f (patch) | |
tree | 73ba7fbba0c10953082df29d5e8755425b061db5 /config-model | |
parent | a3cd7169ee674d4f8fcdbc3a2e7a87a42ab20f20 (diff) |
allow rank profiles representing models to have declared inputs
Diffstat (limited to 'config-model')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 8bef4c39ba1..11f5438af4d 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -103,6 +103,8 @@ public class RankProfile implements Cloneable { private Map<String, RankingExpressionFunction> functions = new LinkedHashMap<>(); + private Map<Reference, TensorType> inputParameters = new LinkedHashMap<>(); + private Set<String> filterFields = new HashSet<>(); private final RankProfileRegistry rankProfileRegistry; @@ -578,6 +580,23 @@ public class RankProfile implements Cloneable { return rankingExpressionFunction; } + /** + * Use for rank profiles representing a model evaluation; it will assume + * that a input is provided with the declared type (for the purpose of + * type resolving). + **/ + public void addInputParameter(String name, TensorType declaredType) { + Reference ref = Reference.fromIdentifier(name); + if (inputParameters.containsKey(ref)) { + TensorType hadType = inputParameters.get(ref); + if (! declaredType.equals(hadType)) { + throw new IllegalArgumentException("Tried to replace input parameter "+name+" with different type: "+ + hadType+" -> "+declaredType); + } + } + inputParameters.put(ref, declaredType); + } + public RankingExpressionFunction findFunction(String name) { RankingExpressionFunction function = functions.get(name); return ((function == null) && (getInherited() != null)) @@ -677,6 +696,7 @@ public class RankProfile implements Cloneable { clone.summaryFeatures = summaryFeatures != null ? new LinkedHashSet<>(this.summaryFeatures) : null; clone.rankFeatures = rankFeatures != null ? new LinkedHashSet<>(this.rankFeatures) : null; clone.rankProperties = new LinkedHashMap<>(this.rankProperties); + clone.inputParameters = new LinkedHashMap<>(this.inputParameters); clone.functions = new LinkedHashMap<>(this.functions); clone.filterFields = new HashSet<>(this.filterFields); clone.constants = new HashMap<>(this.constants); @@ -790,8 +810,12 @@ public class RankProfile implements Cloneable { return typeContext(queryProfiles, collectFeatureTypes()); } + public MapEvaluationTypeContext typeContext() { return typeContext(new QueryProfileRegistry()); } + private Map<Reference, TensorType> collectFeatureTypes() { Map<Reference, TensorType> featureTypes = new HashMap<>(); + // Add input parameters + inputParameters.forEach((k, v) -> featureTypes.put(k, v)); // Add attributes allFields().forEach(field -> addAttributeFeatureTypes(field, featureTypes)); allImportedFields().forEach(field -> addAttributeFeatureTypes(field, featureTypes)); |