diff options
author | Arne Juul <arnej@vespa.ai> | 2023-10-18 11:32:46 +0000 |
---|---|---|
committer | Arne Juul <arnej@vespa.ai> | 2023-10-18 13:54:06 +0000 |
commit | 9af2a2c2b78510bec7b4f8017bcb98e1da7e3e2a (patch) | |
tree | 4bfd7bbbc999a0251ab55e6a189720cc74f61ebe /container-search/src/main/java | |
parent | 24f8d5ab63a193aeb6660eafb83a22956a679a75 (diff) |
add defaults extraction and unit test
Diffstat (limited to 'container-search/src/main/java')
-rw-r--r-- | container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java | 69 |
1 files changed, 68 insertions, 1 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java index e5cd09d3a18..084c2c290eb 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java @@ -4,6 +4,7 @@ package com.yahoo.search.ranking; import ai.vespa.models.evaluation.FunctionEvaluator; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.vespa.config.search.RankProfilesConfig; import java.util.*; @@ -30,6 +31,71 @@ class GlobalPhaseSetup { this.defaultValues = defaultValues; } + static class DefaultQueryFeatureExtractor { + final String baseName; + final String qfName; + TensorType type = null; + Tensor value = null; + DefaultQueryFeatureExtractor(String unwrappedQueryFeature) { + baseName = unwrappedQueryFeature; + qfName = "query(" + baseName + ")"; + } + List<String> lookingFor() { + return List.of(qfName, "vespa.type.query." + baseName); + } + void accept(String key, String propValue) { + if (key.equals(qfName)) { + this.value = Tensor.from(propValue); + } else { + this.type = TensorType.fromSpec(propValue); + } + } + Tensor extract() { + if (value != null) { + return value; + } + if (type != null) { + return Tensor.Builder.of(type).build(); + } + return Tensor.from(0.0); + } + } + + static private Map<String, Tensor> extraDefaultQueryFeatureValues(RankProfilesConfig.Rankprofile rp, + List<String> fromQuery, + List<NormalizerSetup> normalizers) + { + Map<String, DefaultQueryFeatureExtractor> extractors = new HashMap<>(); + for (String fn : fromQuery) { + extractors.put(fn, new DefaultQueryFeatureExtractor(fn)); + } + for (var n : normalizers) { + for (String fn : n.inputEvalSpec().fromQuery()) { + extractors.put(fn, new DefaultQueryFeatureExtractor(fn)); + } + } + Map<String, DefaultQueryFeatureExtractor> targets = new HashMap<>(); + for (var extractor : extractors.values()) { + for (String key : extractor.lookingFor()) { + var old = targets.put(key, extractor); + if (old != null) { + throw new IllegalStateException("Multiple targets for key: " + key); + } + } + } + for (var prop : rp.fef().property()) { + var extractor = targets.get(prop.name()); + if (extractor != null) { + extractor.accept(prop.name(), prop.value()); + } + } + Map<String, Tensor> defaultValues = new HashMap<>(); + for (var extractor : extractors.values()) { + defaultValues.put(extractor.qfName, extractor.extract()); + } + return defaultValues; + } + static GlobalPhaseSetup maybeMakeSetup(RankProfilesConfig.Rankprofile rp, RankProfilesEvaluator modelEvaluator) { var model = modelEvaluator.modelForRankProfile(rp.name()); Map<String, RankProfilesConfig.Rankprofile.Normalizer> availableNormalizers = new HashMap<>(); @@ -104,7 +170,8 @@ class GlobalPhaseSetup { } Supplier<Evaluator> supplier = SimpleEvaluator.wrap(functionEvaluatorSource); var gfun = new FunEvalSpec(supplier, fromQuery, fromMF); - return new GlobalPhaseSetup(gfun, rerankCount, namesToHide, normalizers, Collections.emptyMap()); + var defaultValues = extraDefaultQueryFeatureValues(rp, fromQuery, normalizers); + return new GlobalPhaseSetup(gfun, rerankCount, namesToHide, normalizers, defaultValues); } return null; } |