aboutsummaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-04-21 08:40:32 +0000
committerArne Juul <arnej@verizonmedia.com>2021-04-23 06:54:10 +0000
commit0db383e464bb24c525ffc4b3df51950a8f10444f (patch)
tree73ba7fbba0c10953082df29d5e8755425b061db5 /config-model
parenta3cd7169ee674d4f8fcdbc3a2e7a87a42ab20f20 (diff)
allow rank profiles representing models to have declared inputs
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java24
1 files changed, 24 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index 8bef4c39ba1..11f5438af4d 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -103,6 +103,8 @@ public class RankProfile implements Cloneable {
private Map<String, RankingExpressionFunction> functions = new LinkedHashMap<>();
+ private Map<Reference, TensorType> inputParameters = new LinkedHashMap<>();
+
private Set<String> filterFields = new HashSet<>();
private final RankProfileRegistry rankProfileRegistry;
@@ -578,6 +580,23 @@ public class RankProfile implements Cloneable {
return rankingExpressionFunction;
}
+ /**
+ * Use for rank profiles representing a model evaluation; it will assume
+ * that a input is provided with the declared type (for the purpose of
+ * type resolving).
+ **/
+ public void addInputParameter(String name, TensorType declaredType) {
+ Reference ref = Reference.fromIdentifier(name);
+ if (inputParameters.containsKey(ref)) {
+ TensorType hadType = inputParameters.get(ref);
+ if (! declaredType.equals(hadType)) {
+ throw new IllegalArgumentException("Tried to replace input parameter "+name+" with different type: "+
+ hadType+" -> "+declaredType);
+ }
+ }
+ inputParameters.put(ref, declaredType);
+ }
+
public RankingExpressionFunction findFunction(String name) {
RankingExpressionFunction function = functions.get(name);
return ((function == null) && (getInherited() != null))
@@ -677,6 +696,7 @@ public class RankProfile implements Cloneable {
clone.summaryFeatures = summaryFeatures != null ? new LinkedHashSet<>(this.summaryFeatures) : null;
clone.rankFeatures = rankFeatures != null ? new LinkedHashSet<>(this.rankFeatures) : null;
clone.rankProperties = new LinkedHashMap<>(this.rankProperties);
+ clone.inputParameters = new LinkedHashMap<>(this.inputParameters);
clone.functions = new LinkedHashMap<>(this.functions);
clone.filterFields = new HashSet<>(this.filterFields);
clone.constants = new HashMap<>(this.constants);
@@ -790,8 +810,12 @@ public class RankProfile implements Cloneable {
return typeContext(queryProfiles, collectFeatureTypes());
}
+ public MapEvaluationTypeContext typeContext() { return typeContext(new QueryProfileRegistry()); }
+
private Map<Reference, TensorType> collectFeatureTypes() {
Map<Reference, TensorType> featureTypes = new HashMap<>();
+ // Add input parameters
+ inputParameters.forEach((k, v) -> featureTypes.put(k, v));
// Add attributes
allFields().forEach(field -> addAttributeFeatureTypes(field, featureTypes));
allImportedFields().forEach(field -> addAttributeFeatureTypes(field, featureTypes));