aboutsummaryrefslogtreecommitdiffstats
path: root/container-search/src/main
diff options
context:
space:
mode:
authorArne Juul <arnej@vespa.ai>2023-10-18 11:32:46 +0000
committerArne Juul <arnej@vespa.ai>2023-10-18 13:54:06 +0000
commit9af2a2c2b78510bec7b4f8017bcb98e1da7e3e2a (patch)
tree4bfd7bbbc999a0251ab55e6a189720cc74f61ebe /container-search/src/main
parent24f8d5ab63a193aeb6660eafb83a22956a679a75 (diff)
add defaults extraction and unit test
Diffstat (limited to 'container-search/src/main')
-rw-r--r--container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java69
1 files changed, 68 insertions, 1 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java
index e5cd09d3a18..084c2c290eb 100644
--- a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java
+++ b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java
@@ -4,6 +4,7 @@ package com.yahoo.search.ranking;
import ai.vespa.models.evaluation.FunctionEvaluator;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import java.util.*;
@@ -30,6 +31,71 @@ class GlobalPhaseSetup {
this.defaultValues = defaultValues;
}
+ static class DefaultQueryFeatureExtractor {
+ final String baseName;
+ final String qfName;
+ TensorType type = null;
+ Tensor value = null;
+ DefaultQueryFeatureExtractor(String unwrappedQueryFeature) {
+ baseName = unwrappedQueryFeature;
+ qfName = "query(" + baseName + ")";
+ }
+ List<String> lookingFor() {
+ return List.of(qfName, "vespa.type.query." + baseName);
+ }
+ void accept(String key, String propValue) {
+ if (key.equals(qfName)) {
+ this.value = Tensor.from(propValue);
+ } else {
+ this.type = TensorType.fromSpec(propValue);
+ }
+ }
+ Tensor extract() {
+ if (value != null) {
+ return value;
+ }
+ if (type != null) {
+ return Tensor.Builder.of(type).build();
+ }
+ return Tensor.from(0.0);
+ }
+ }
+
+ static private Map<String, Tensor> extraDefaultQueryFeatureValues(RankProfilesConfig.Rankprofile rp,
+ List<String> fromQuery,
+ List<NormalizerSetup> normalizers)
+ {
+ Map<String, DefaultQueryFeatureExtractor> extractors = new HashMap<>();
+ for (String fn : fromQuery) {
+ extractors.put(fn, new DefaultQueryFeatureExtractor(fn));
+ }
+ for (var n : normalizers) {
+ for (String fn : n.inputEvalSpec().fromQuery()) {
+ extractors.put(fn, new DefaultQueryFeatureExtractor(fn));
+ }
+ }
+ Map<String, DefaultQueryFeatureExtractor> targets = new HashMap<>();
+ for (var extractor : extractors.values()) {
+ for (String key : extractor.lookingFor()) {
+ var old = targets.put(key, extractor);
+ if (old != null) {
+ throw new IllegalStateException("Multiple targets for key: " + key);
+ }
+ }
+ }
+ for (var prop : rp.fef().property()) {
+ var extractor = targets.get(prop.name());
+ if (extractor != null) {
+ extractor.accept(prop.name(), prop.value());
+ }
+ }
+ Map<String, Tensor> defaultValues = new HashMap<>();
+ for (var extractor : extractors.values()) {
+ defaultValues.put(extractor.qfName, extractor.extract());
+ }
+ return defaultValues;
+ }
+
static GlobalPhaseSetup maybeMakeSetup(RankProfilesConfig.Rankprofile rp, RankProfilesEvaluator modelEvaluator) {
var model = modelEvaluator.modelForRankProfile(rp.name());
Map<String, RankProfilesConfig.Rankprofile.Normalizer> availableNormalizers = new HashMap<>();
@@ -104,7 +170,8 @@ class GlobalPhaseSetup {
}
Supplier<Evaluator> supplier = SimpleEvaluator.wrap(functionEvaluatorSource);
var gfun = new FunEvalSpec(supplier, fromQuery, fromMF);
- return new GlobalPhaseSetup(gfun, rerankCount, namesToHide, normalizers, Collections.emptyMap());
+ var defaultValues = extraDefaultQueryFeatureValues(rp, fromQuery, normalizers);
+ return new GlobalPhaseSetup(gfun, rerankCount, namesToHide, normalizers, defaultValues);
}
return null;
}